├── .gitignore ├── LICENSE ├── README.md ├── finetune.py ├── notebooks ├── 01_check_llama.ipynb └── 02_ai3_llama.ipynb ├── pefty_llama ├── configuration.py ├── modeling.py ├── modeling_peft.py └── peft │ ├── __init__.py │ ├── adapter.py │ ├── bitfit.py │ ├── configuration.py │ ├── ia3.py │ ├── lora.py │ ├── prefix_adapter.py │ ├── prefix_tuning.py │ └── prompt_tuning.py ├── requirements.txt ├── setup.py └── tokenize_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | model_checkpoints 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # My PEFTy LLaMa 2 | 3 |
4 | my_pefty_llama 5 |
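A quick (IA)3 sketch, mirroring `notebooks/02_ai3_llama.ipynb`. The tokenizer name and checkpoint path below are the ones assumed in the notebooks, so point them at your own converted LLaMA-7B weights:

```python
from transformers import LlamaTokenizer

from pefty_llama.modeling import create_model
from pefty_llama.peft.ia3 import IA3

# Tokenizer and HF-format LLaMA-7B checkpoint (paths as used in the notebooks).
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = create_model("7b", hf_path="model_checkpoints/llama-7b-hf")

# Wrap the base model with (IA)3 rescaling vectors; after wrapping, only those
# vectors remain trainable (~0.6M parameters in the notebook).
model = IA3(model).to("cuda")

input_ids = tokenizer("42 is the answer", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, generation_length=10)
print(tokenizer.decode(output[0]))
```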
6 | 7 | Minimal implementations of multiple PEFT methods for LLaMA fine-tuning. 8 | 9 | # Supported methods 10 | 11 | | Method | Status | Paper | 12 | | --- | --- | --- | 13 | | (IA)3 | ✅ | [arxiv.org/abs/2205.05638](https://arxiv.org/abs/2205.05638) | 14 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | from dataclasses import dataclass, field 5 | import tqdm.auto as tqdm 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import Dataset 12 | 13 | import datasets 14 | import transformers 15 | from transformers import ( 16 | HfArgumentParser, 17 | Trainer, 18 | TrainingArguments, 19 | ) 20 | from pefty_llama.peft import PeftConfig 21 | from pefty_llama.modeling_peft import create_model, set_peft_requires_grad 22 | 23 | 24 | @dataclass 25 | class FinetuneArguments: 26 | dataset_path: str = field() 27 | hf_path: str = field() 28 | model_name: str = field(default="7b") 29 | use_8bit: bool = field(default=False) 30 | 31 | 32 | class CastOutputToFloat(nn.Sequential): 33 | def forward(self, x): return super().forward(x).to(torch.float32) 34 | 35 | 36 | def only_tunable_params(model): 37 | requires_grad = {k: v.requires_grad for k, v in model.named_parameters()} 38 | return { 39 | k: v 40 | for k, v in model.state_dict().items() 41 | if k in requires_grad and requires_grad[k] 42 | } 43 | 44 | 45 | class ModifiedTrainer(Trainer): 46 | 47 | def compute_loss(self, model, inputs, return_outputs=False): 48 | batch_size = inputs["input_ids"].shape[0] 49 | 50 | labels = inputs["input_ids"] 51 | input_ids = torch.cat([ 52 | torch.ones(batch_size, 1).long().to(labels.device), 53 | inputs["input_ids"][:, :-1], 54 | ], dim=1) 55 | 56 | # logits will be 1 block shorter than input_ids, since we're dropping off the first block 57 | logits = model(input_ids=input_ids) 58 | 59 | loss_fct = nn.CrossEntropyLoss(ignore_index=-100) 60 | loss = loss_fct(logits.reshape( 61 | -1, logits.size(-1)), labels.reshape(-1) 62 | ) 63 | if return_outputs: 64 | return loss, logits 65 | else: 66 | return loss 67 | 68 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 69 | # If we are executing this function, we are the process zero, so we don't check for that. 
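        # Note: only parameters with requires_grad=True (for PEFT runs, the adapter weights
        # selected by set_peft_requires_grad) are gathered by only_tunable_params below, so the
        # saved checkpoint.p stays small compared to the full LLaMA state dict.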
70 | output_dir = output_dir if output_dir is not None else self.args.output_dir 71 | os.makedirs(output_dir, exist_ok=True) 72 | torch.save( 73 | only_tunable_params(self.model), 74 | os.path.join(output_dir, f"checkpoint.p"), 75 | ) 76 | 77 | # Good practice: save your training arguments together with the trained model 78 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 79 | 80 | def _final_ops_before_train(self): 81 | pass 82 | 83 | 84 | def data_collator(features: list) -> dict: 85 | return { 86 | "input_ids": torch.stack([torch.LongTensor(f["input_ids"]) for f in features]), 87 | } 88 | 89 | 90 | def save_tunable_parameters(model, path): 91 | saved_params = { 92 | k: v.to("cpu") 93 | for k, v in model.named_parameters() 94 | if v.requires_grad 95 | } 96 | torch.save(saved_params, path) 97 | 98 | 99 | def main(): 100 | finetune_args, peft_config, training_args = HfArgumentParser(( 101 | FinetuneArguments, 102 | PeftConfig, 103 | TrainingArguments, 104 | )).parse_args_into_dataclasses() 105 | 106 | print("Setup Data") 107 | training_args.remove_unused_columns = False 108 | dataset = datasets.load_from_disk(finetune_args.dataset_path) 109 | 110 | print("Setup Model") 111 | model = create_model( 112 | model_name=finetune_args.model_name, 113 | peft_config=peft_config, 114 | hf_path=finetune_args.hf_path, 115 | use_8bit=finetune_args.use_8bit, 116 | ) 117 | set_peft_requires_grad(model) 118 | if finetune_args.use_8bit: 119 | model.lm_head = CastOutputToFloat(model.lm_head) 120 | if training_args.gradient_checkpointing: 121 | print("Enabling gradient checkpointing") 122 | model.gradient_checkpointing_enable() 123 | model.enable_input_require_grads() 124 | 125 | print("Train") 126 | trainer = ModifiedTrainer( 127 | model=model, 128 | train_dataset=dataset, 129 | args=training_args, 130 | data_collator=data_collator 131 | ) 132 | trainer.train() 133 | save_tunable_parameters(model, os.path.join(training_args.output_dir, "params.p")) 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /notebooks/01_check_llama.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.environ['BITSANDBYTES_NOWELCOME'] = '1'" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "bin /mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so\n" 23 | ] 24 | }, 25 | { 26 | "name": "stderr", 27 | "output_type": "stream", 28 | "text": [ 29 | "/mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 30 | " from .autonotebook import tqdm as notebook_tqdm\n", 31 | "100%|██████████| 33/33 [00:07<00:00, 4.34it/s]\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "from pefty_llama.modeling import create_model\n", 37 | "\n", 38 | "model = create_model(\"7b\", hf_path=\"../model_checkpoints/llama-7b-hf\")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", 51 | "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", 52 | "The class this function is called from is 'LlamaTokenizer'.\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "from transformers import LlamaTokenizer\n", 58 | "\n", 59 | "tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "input_ids = tokenizer(\"Hello world!\", return_tensors=\"pt\").input_ids\n", 69 | "input_ids = input_ids.to(\"cuda\")\n", 70 | "\n", 71 | "output = model.generate(input_ids, generation_length=50)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 5, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "' ⁇ Hello world!, I am a student of the University of _____________. I am currently enrolled in the _____________ program. I am writing to you to request a letter of recommendation.\\nI am currently enrolled in the _____________ program at'" 83 | ] 84 | }, 85 | "execution_count": 5, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "tokenizer.decode(output[0])" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [] 100 | } 101 | ], 102 | "metadata": { 103 | "kernelspec": { 104 | "display_name": "pefty_llama", 105 | "language": "python", 106 | "name": "pefty_llama" 107 | }, 108 | "language_info": { 109 | "codemirror_mode": { 110 | "name": "ipython", 111 | "version": 3 112 | }, 113 | "file_extension": ".py", 114 | "mimetype": "text/x-python", 115 | "name": "python", 116 | "nbconvert_exporter": "python", 117 | "pygments_lexer": "ipython3", 118 | "version": "3.10.10" 119 | }, 120 | "orig_nbformat": 4 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 2 124 | } 125 | -------------------------------------------------------------------------------- /notebooks/02_ai3_llama.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.environ['BITSANDBYTES_NOWELCOME'] = '1'\n", 11 | "\n", 12 | "import torch\n", 13 | "from transformers import LlamaTokenizer\n", 14 | "from pefty_llama.modeling import create_model" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stderr", 24 | "output_type": "stream", 25 | "text": [ 26 | "/mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. 
Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 27 | " from .autonotebook import tqdm as notebook_tqdm\n" 28 | ] 29 | }, 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "bin /mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so\n" 35 | ] 36 | }, 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", 42 | "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", 43 | "The class this function is called from is 'LlamaTokenizer'.\n", 44 | "100%|██████████| 33/33 [00:07<00:00, 4.38it/s]\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\")\n", 50 | "model = create_model(\"7b\", hf_path=\"../model_checkpoints/llama-7b-hf\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Total trainable parameters: 6,738,415,616\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 68 | "print(f\"Total trainable parameters: {total_trainable_params:,}\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | " ⁇ 42 is the answer first of its own kind.\n", 81 | "The 2\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "input_ids = tokenizer(\"42 is the answer\", return_tensors=\"pt\").input_ids\n", 87 | "input_ids = input_ids.to(\"cuda\")\n", 88 | "out1 = model.generate(input_ids, generation_length=10)\n", 89 | "print(tokenizer.decode(out1[0]))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "from pefty_llama.peft.ia3 import IA3\n", 99 | "model = IA3(model).to(\"cuda\")" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 7, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | " ⁇ 42 is the answer first of its own kind.\n", 112 | "The 2\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "out2 = model.generate(input_ids, generation_length=10)\n", 118 | "print(tokenizer.decode(out2[0]))\n", 119 | "assert torch.all(out1 == out2), \"At initialization, the model should produce the same output as the original model.\"" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 8, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Total trainable parameters: 614,400\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 137 | "print(f\"Total trainable parameters: {total_trainable_params:,}\")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 5, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "IA3(\n", 149 | " (base_model): LLaMAModel(\n", 150 | " (model): LLaMAInnerModel(\n", 151 | " 
(embed_tokens): Embedding(32000, 4096)\n", 152 | " (layers): ModuleList(\n", 153 | " (0-31): 32 x LLaMALayer(\n", 154 | " (self_attn): IA3Attention(\n", 155 | " (q_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n", 156 | " (k_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n", 157 | " (v_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n", 158 | " (o_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n", 159 | " (rotary_emb): RotaryEmbedding()\n", 160 | " )\n", 161 | " (mlp): IA3MLP(\n", 162 | " (gate_proj): NoInitLinear(in_features=4096, out_features=11008, bias=False)\n", 163 | " (up_proj): NoInitLinear(in_features=4096, out_features=11008, bias=False)\n", 164 | " (down_proj): NoInitLinear(in_features=11008, out_features=4096, bias=False)\n", 165 | " )\n", 166 | " (input_layernorm): RMSNorm()\n", 167 | " (post_attention_layernorm): RMSNorm()\n", 168 | " )\n", 169 | " )\n", 170 | " (norm): RMSNorm()\n", 171 | " )\n", 172 | " (lm_head): NoInitLinear(in_features=4096, out_features=32000, bias=False)\n", 173 | " )\n", 174 | ")" 175 | ] 176 | }, 177 | "execution_count": 5, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "model" 184 | ] 185 | } 186 | ], 187 | "metadata": { 188 | "kernelspec": { 189 | "display_name": "pefty_llama", 190 | "language": "python", 191 | "name": "pefty_llama" 192 | }, 193 | "language_info": { 194 | "codemirror_mode": { 195 | "name": "ipython", 196 | "version": 3 197 | }, 198 | "file_extension": ".py", 199 | "mimetype": "text/x-python", 200 | "name": "python", 201 | "nbconvert_exporter": "python", 202 | "pygments_lexer": "ipython3", 203 | "version": "3.10.10" 204 | }, 205 | "orig_nbformat": 4 206 | }, 207 | "nbformat": 4, 208 | "nbformat_minor": 2 209 | } 210 | -------------------------------------------------------------------------------- /pefty_llama/configuration.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import dataclasses 3 | import torch 4 | 5 | 6 | @dataclasses.dataclass 7 | class LLaMAConfig: 8 | dim: int 9 | n_layers: int 10 | n_heads: int 11 | vocab_size: int = 32000 12 | max_seq_length: int = 2048 13 | dtype: Any = torch.float16 14 | pad_token_id: int = 0 15 | bos_token_id: int = 1 16 | eos_token_id: int = 2 17 | use_8bit: bool = False 18 | gradient_checkpointing: bool = False 19 | 20 | @property 21 | def head_dim(self): 22 | return self.dim // self.n_heads 23 | 24 | def to_dict(self): 25 | return dataclasses.asdict(self) 26 | 27 | 28 | LLAMA_7B_CONFIG = LLaMAConfig( 29 | dim=4096, 30 | n_layers=32, 31 | n_heads=32, 32 | ) 33 | DEBUG_CONFIG = LLaMAConfig( 34 | dim=64, 35 | n_layers=3, 36 | n_heads=4, 37 | ) 38 | 39 | LLAMA_CONFIG_DICT = { 40 | "7b": LLAMA_7B_CONFIG, 41 | "debug": DEBUG_CONFIG, 42 | } 43 | -------------------------------------------------------------------------------- /pefty_llama/modeling.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/zphang/minimal-llama/blob/c37e481136f118a16f77f50cdf5e867ed5dafbf9/minimal_llama/pref/llama_simple2.py 2 | 3 | import os 4 | import json 5 | import math 6 | import dataclasses 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | import bitsandbytes as bnb 13 | import tqdm.auto as tqdm 14 | 15 | from accelerate import init_empty_weights 16 | from transformers.utils.bitsandbytes import 
set_module_8bit_tensor_to_device 17 | from transformers import ( 18 | LlamaConfig as HF_LlamaConfig, 19 | LlamaForCausalLM as HF_Llama, 20 | ) 21 | 22 | 23 | @dataclasses.dataclass 24 | class LLaMAConfig: 25 | dim: int 26 | n_layers: int 27 | n_heads: int 28 | vocab_size: int = 32000 29 | max_seq_length: int = 2048 30 | dtype = torch.float16 31 | pad_token_id: int = 0 32 | bos_token_id: int = 1 33 | eos_token_id: int = 2 34 | use_8bit: bool = False 35 | 36 | @property 37 | def head_dim(self): 38 | return self.dim // self.n_heads 39 | 40 | 41 | LLAMA_7B_CONFIG = LLaMAConfig( 42 | dim=4096, 43 | n_layers=32, 44 | n_heads=32, 45 | ) 46 | 47 | LLAMA_CONFIG_DICT = { 48 | "7b": LLAMA_7B_CONFIG, 49 | } 50 | 51 | 52 | class LLaMAModel(nn.Module): 53 | def __init__(self, config: LLaMAConfig): 54 | super().__init__() 55 | self.config = config 56 | self.model = LLaMAInnerModel(config) 57 | self.lm_head = NoInitLinear(config.dim, config.vocab_size, bias=False, dtype=config.dtype) 58 | 59 | @classmethod 60 | def from_pretrained(cls, model_name_or_path, use_8bit=False): 61 | """Load model from a huggingface model name or path.""" 62 | hf_config = HF_LlamaConfig.from_pretrained(model_name_or_path) 63 | 64 | config = LLaMAConfig( 65 | vocab_size=hf_config.vocab_size, 66 | dim=hf_config.hidden_size, 67 | n_layers=hf_config.num_hidden_layers, 68 | n_heads=hf_config.num_attention_heads, 69 | max_seq_length=hf_config.max_position_embeddings, 70 | dtype=hf_config.dtype, 71 | pad_token_id=hf_config.pad_token_id, 72 | bos_token_id=hf_config.bos_token_id, 73 | eos_token_id=hf_config.eos_token_id, 74 | use_8bit=use_8bit, 75 | ) 76 | 77 | raise NotImplementedError() 78 | model = cls(config) 79 | 80 | # Load weights from huggingface model to the disk if needed 81 | if os.path.isdir(model_name_or_path): 82 | hf_model_path = model_name_or_path 83 | else: 84 | hf_model_path = hf_config.cache_dir 85 | hf_model = HF_LLaMA.from_pretrained(hf_model_path, config=hf_config) 86 | hf_model.save_pretrained(hf_model_path) 87 | 88 | return model 89 | 90 | def forward(self, 91 | input_ids): 92 | """Forward pass (with full decode sequence, intended for training or loss-scoring) 93 | 94 | :param input_ids: [batch_size, seq_len] 95 | :return: logits [batch_size, seq_len] 96 | """ 97 | # 1) Create masks 98 | # decoder mask 99 | # [batch_size, num_heads=1, q_len=seq_len, kv_len=seq_len] 100 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype) 101 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) 102 | cos, sin = self.get_cos_sin(rope_embed_ids) 103 | 104 | # 2) Forward pass 105 | # [batch_size, seq_len, hidden_dim] 106 | model_out = self.model( 107 | input_ids, 108 | attention_mask=attention_mask, 109 | cos=cos, sin=sin, 110 | ) 111 | # [batch_size, seq_len, vocab_size] 112 | logits = self.lm_head(model_out["hidden_states"]) 113 | return logits 114 | 115 | def init_kv_cache(self, input_ids): 116 | # noinspection GrazieInspection 117 | """Initialize KV cache for decoding. 
118 | 119 | A KV cache consists of a list of dicts (one per layer): 120 | dict( 121 | key = [batch_size, num_heads, kv_seq_len=0, head_dim] 122 | value = [batch_size, num_heads, kv_seq_len=0, head_dim] 123 | ) 124 | 125 | :param input_ids: [batch_size, dec_seq_len] 126 | :return: 0-length kv_cache 127 | """ 128 | kv_cache = [] 129 | batch_size = input_ids.shape[0] 130 | num_heads = self.config.n_heads 131 | head_dim = self.config.head_dim 132 | for layer in self.model.layers: 133 | device = layer.input_layernorm.weight.device 134 | kv_cache.append({ 135 | "key": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype), 136 | "value": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype), 137 | }) 138 | return kv_cache 139 | 140 | def generate(self, input_ids, generation_length: int = 20): 141 | """Generate tokens with efficient caching of KV. 142 | 143 | TODO: Add stopping conditions 144 | TODO: Add sampling capabilities 145 | 146 | :param input_ids: [batch_size, enc_seq_len] 147 | :param generation_length: int 148 | :return: [batch_size, generation_length] 149 | """ 150 | original_input_ids = input_ids 151 | batch_size, seq_len = input_ids.shape 152 | # noinspection PyUnresolvedReferences 153 | num_valid_tokens = (input_ids != self.config.pad_token_id).long().sum(dim=1) 154 | 155 | # 1) Setup 156 | if input_ids is None: 157 | # [batch_size, dec_seq_len=1] 158 | input_ids = torch.LongTensor( 159 | [[self.config.pad_token_id]] * batch_size 160 | ).to(self.lm_head.weight.device) 161 | # See: init_kv_cache. list[dict] 162 | kv_cache = self.init_kv_cache(input_ids) 163 | generated_token_ids_list = [original_input_ids] 164 | total_seq_len = seq_len 165 | 166 | # 2) First encoding 167 | # [batch_size=1, num_heads=1, q_len=1, kv_len=1] 168 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype) 169 | # dict( 170 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim] 171 | # kv_cache = list[dict( 172 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 173 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 174 | # )] 175 | # ) 176 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) 177 | cos, sin = self.get_cos_sin(rope_embed_ids) 178 | model_out = self.model( 179 | input_ids=input_ids, 180 | attention_mask=attention_mask, 181 | cos=cos, sin=sin, 182 | kv_cache=kv_cache, 183 | ) 184 | logits = self.lm_head(model_out["hidden_states"]) 185 | kv_cache = model_out["kv_cache"] 186 | generated_token_ids = logits.argmax(-1)[ 187 | torch.arange(batch_size, dtype=torch.long, device=input_ids.device), 188 | num_valid_tokens-1, 189 | ][:, None] 190 | generated_token_ids_list.append(generated_token_ids) 191 | input_ids = generated_token_ids 192 | 193 | # 2.1 shift KV cache 194 | for layer_kv_cache in kv_cache: 195 | # shift_kv_cache_right operates on the full batch at once, so no per-example loop is needed 196 | layer_kv_cache["key"] = shift_kv_cache_right( 197 | layer_kv_cache["key"], num_valid_tokens=num_valid_tokens) 198 | layer_kv_cache["value"] = shift_kv_cache_right( 199 | layer_kv_cache["value"], num_valid_tokens=num_valid_tokens) 200 | 201 | # 3) Subsequent steps 202 | for decode_step in range(generation_length-1): 203 | num_valid_tokens += 1 204 | total_seq_len += 1 205 | # [batch_size=1, num_heads=1, q_len=1, kv_len=1] 206 | attention_mask = convert_mask_to_soft_mask(create_generation_attention_mask( 207 | batch_size=batch_size, 208 | seq_len=total_seq_len, 209 | num_valid_tokens=num_valid_tokens,
210 | device=input_ids.device, 211 | ), dtype=self.config.dtype) 212 | # dict( 213 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim] 214 | # kv_cache = list[dict( 215 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 216 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 217 | # )] 218 | # ) 219 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) + num_valid_tokens[:, None] 220 | cos, sin = self.get_cos_sin(rope_embed_ids) 221 | model_out = self.model( 222 | input_ids=input_ids, 223 | attention_mask=attention_mask, 224 | kv_cache=kv_cache, 225 | cos=cos, sin=sin, 226 | ) 227 | # [batch_size, dec_seq_len=1, vocab_size] 228 | logits = self.lm_head(model_out["hidden_states"]) 229 | kv_cache = model_out["kv_cache"] 230 | # [batch_size, dec_seq_len=1] 231 | generated_token_ids = logits.argmax(-1)[:, -1:] 232 | generated_token_ids_list.append(generated_token_ids) 233 | input_ids = generated_token_ids 234 | return torch.cat(generated_token_ids_list, dim=1) 235 | 236 | def get_cos_sin(self, rope_embed_ids): 237 | cos = F.embedding( 238 | rope_embed_ids, 239 | self.model.layers[0].self_attn.rotary_emb.cos_cached[0, 0] 240 | ).to(self.config.dtype) 241 | sin = F.embedding( 242 | rope_embed_ids, 243 | self.model.layers[0].self_attn.rotary_emb.sin_cached[0, 0] 244 | ).to(self.config.dtype) 245 | cos, sin = cos[:, None, :, :], sin[:, None, :, :] 246 | return cos, sin 247 | 248 | 249 | class LLaMAInnerModel(nn.Module): 250 | def __init__(self, config: LLaMAConfig): 251 | super().__init__() 252 | self.config = config 253 | self.embed_tokens = nn.Embedding(config.vocab_size, config.dim, dtype=config.dtype) 254 | self.layers = nn.ModuleList([ 255 | LLaMALayer(config=config) 256 | for _ in range(config.n_layers) 257 | ]) 258 | self.norm = RMSNorm(dim=config.dim) 259 | 260 | def forward(self, 261 | input_ids, 262 | attention_mask, 263 | cos, sin, 264 | kv_cache=None): 265 | """ 266 | :param input_ids: [batch_size, seq_len] 267 | :param attention_mask: [batch_size=1, num_heads=1, seq_len, seq_len] 268 | :param kv_cache: See init_kv_cache.
269 | We use the presence of kv_cache to determine if we're generating 270 | :param cos: 271 | :param sin: 272 | """ 273 | hidden_states = self.embed_tokens(input_ids) 274 | 275 | new_kv_cache = [] 276 | for layer_i, layer in enumerate(self.layers): 277 | if kv_cache: 278 | # dict( 279 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 280 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 281 | # ) 282 | layer_kv_cache = kv_cache[layer_i] 283 | else: 284 | layer_kv_cache = None 285 | 286 | layer_out = layer( 287 | hidden_states=hidden_states, 288 | attention_mask=attention_mask, 289 | kv_cache=layer_kv_cache, 290 | cos=cos, sin=sin, 291 | ) 292 | hidden_states = layer_out["hidden_states"] 293 | if kv_cache: 294 | new_kv_cache.append(layer_out["kv_cache"]) 295 | hidden_states = self.norm(hidden_states) 296 | output = { 297 | "hidden_states": hidden_states 298 | } 299 | if kv_cache: 300 | output["kv_cache"] = new_kv_cache 301 | return output 302 | 303 | 304 | class LLaMALayer(nn.Module): 305 | def __init__(self, config: LLaMAConfig): 306 | super().__init__() 307 | self.config = config 308 | self.self_attn = Attention(config=config) 309 | self.mlp = MLP(config=config) 310 | self.input_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype) 311 | self.post_attention_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype) 312 | 313 | def forward( 314 | self, 315 | hidden_states, 316 | attention_mask, 317 | cos, sin, 318 | kv_cache=None, 319 | ): 320 | # 1) Self-attention 321 | # [batch_size, seq_len, hidden_dim] 322 | normed_hidden_states = self.input_layernorm(hidden_states) 323 | # dict( 324 | # attn_output = [batch_size, seq_len, hidden_dim] 325 | # kv_cache = dict( 326 | # key = [batch_size, num_heads, kv_seq_len, head_dim] 327 | # value = [batch_size, num_heads, kv_seq_len, head_dim] 328 | # ) 329 | # ) 330 | check_nan(normed_hidden_states) 331 | raw_self_attn_output = self.self_attn( 332 | hidden_states=normed_hidden_states, 333 | attention_mask=attention_mask, 334 | kv_cache=kv_cache, 335 | cos=cos, sin=sin, 336 | ) 337 | # [batch_size, seq_len, hidden_dim] 338 | hidden_states = hidden_states + raw_self_attn_output["attn_output"] 339 | check_nan(hidden_states) 340 | # 2) FFN 341 | # [batch_size, seq_len, hidden_dim] 342 | hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states)) 343 | check_nan(hidden_states) 344 | if kv_cache: 345 | return { 346 | "hidden_states": hidden_states, 347 | "kv_cache": raw_self_attn_output["kv_cache"], 348 | } 349 | 350 | return {"hidden_states": hidden_states} 351 | 352 | 353 | class MLP(nn.Module): 354 | def __init__( 355 | self, 356 | config: LLaMAConfig, 357 | multiple_of: int = 256, 358 | ): 359 | super().__init__() 360 | dim = config.dim 361 | hidden_dim = 4 * dim 362 | hidden_dim = int(2 * hidden_dim / 3) 363 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 364 | 365 | if config.use_8bit: 366 | self.gate_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False) 367 | self.up_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False) 368 | self.down_proj = NoInit8bitLinear(hidden_dim, dim, bias=False, threshold=6.0, has_fp16_weights=False) 369 | else: 370 | self.gate_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype) 371 | self.up_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype) 372 | self.down_proj = NoInitLinear(hidden_dim, dim, bias=False, 
dtype=config.dtype) 373 | 374 | def forward(self, x): 375 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 376 | 377 | 378 | class RMSNorm(torch.nn.Module): 379 | def __init__(self, dim: int, eps: float = 1e-6, dtype=torch.float16): 380 | super().__init__() 381 | self.eps = eps 382 | self.weight = nn.Parameter(torch.ones(dim, dtype=dtype)) 383 | 384 | def _norm(self, x): 385 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 386 | 387 | def forward(self, x): 388 | output = self._norm(x.float()).type_as(x) 389 | return output * self.weight 390 | 391 | 392 | class Attention(nn.Module): 393 | def __init__(self, config: LLaMAConfig): 394 | super().__init__() 395 | self.config = config 396 | self.n_heads = config.n_heads 397 | self.head_dim = config.dim // config.n_heads 398 | 399 | if config.use_8bit: 400 | self.q_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 401 | self.k_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 402 | self.v_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 403 | self.o_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 404 | else: 405 | self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 406 | self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 407 | self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 408 | self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 409 | self.rotary_emb = RotaryEmbedding(dim=self.head_dim) 410 | 411 | def forward(self, hidden_states, attention_mask, cos, sin, kv_cache=None): 412 | """ 413 | precomputed_kv_hidden_states is for init (pre-compute KV activations, e.g. 
for added prefixes) 414 | kv_cache is for generation (cached past KV) 415 | """ 416 | batch_size, q_seq_len, hidden_dim = hidden_states.size() 417 | 418 | # (batch_size, num_heads, q_seq_len, head_dim) 419 | query_states = self.q_proj(hidden_states).view( 420 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 421 | key_states = self.k_proj(hidden_states).view( 422 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 423 | value_states = self.v_proj(hidden_states).view( 424 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 425 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos=cos, sin=sin) 426 | if kv_cache: 427 | key_states = torch.cat([kv_cache["key"], key_states], dim=2) 428 | value_states = torch.cat([kv_cache["value"], value_states], dim=2) 429 | 430 | attn_output = torch.nn.functional.scaled_dot_product_attention( 431 | query=query_states, 432 | key=key_states, 433 | value=value_states, 434 | attn_mask=attention_mask, 435 | ) 436 | # (batch_size, q_seq_len, hidden_dim) 437 | attn_output = attn_output.transpose(1, 2).contiguous().view( 438 | batch_size, q_seq_len, hidden_dim, 439 | ) 440 | attn_output = self.o_proj(attn_output) 441 | check_nan(attn_output) 442 | if kv_cache: 443 | new_kv_cache = {"key": key_states, "value": value_states} 444 | return {"attn_output": attn_output, "kv_cache": new_kv_cache} 445 | 446 | return {"attn_output": attn_output} 447 | 448 | 449 | class RotaryEmbedding(torch.nn.Module): 450 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 451 | super().__init__() 452 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device=device) / dim)) 453 | self.register_buffer("inv_freq", inv_freq) 454 | 455 | # Build here to make `torch.jit.trace` work. 456 | self.max_seq_len_cached = max_position_embeddings 457 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device).to(self.inv_freq.dtype) 458 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 459 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 460 | emb = torch.cat((freqs, freqs), dim=-1) 461 | self.cos_cached = emb.cos()[None, None, :, :] 462 | self.sin_cached = emb.sin()[None, None, :, :] 463 | 464 | def forward(self, x, seq_len=None): 465 | # x: [bs, num_attention_heads, seq_len, head_size] 466 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
467 | if seq_len > self.max_seq_len_cached: 468 | self.max_seq_len_cached = seq_len 469 | t = torch.arange(self.max_seq_len_cached, device=x.device).to(self.inv_freq.dtype) 470 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 471 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 472 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 473 | self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype) 474 | self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype) 475 | return ( 476 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), 477 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), 478 | ) 479 | 480 | 481 | def rotate_half(x): 482 | """Rotates half the hidden dims of the input.""" 483 | x1 = x[..., : x.shape[-1] // 2] 484 | x2 = x[..., x.shape[-1] // 2:] 485 | return torch.cat((-x2, x1), dim=-1) 486 | 487 | 488 | def apply_rotary_pos_emb(q, k, cos, sin): 489 | q_embed = (q * cos) + (rotate_half(q) * sin) 490 | k_embed = (k * cos) + (rotate_half(k) * sin) 491 | return q_embed, k_embed 492 | 493 | 494 | def create_attention_mask(input_ids, 495 | dtype=torch.float32, 496 | return_soft_mask=True): 497 | """Create mask for decoder attention. 498 | 499 | Decoder masks have two use-cases: 500 | 501 | 1) Training, where we see the full decoder sequence. In that case, 502 | we want a causal mask. 503 | 504 | 2) Generation, where we only see one token at once. In that case, 505 | it doesn't really matter what we give, we can just give a 1. 506 | (i.e. seq_len = 1) 507 | 508 | Note that in both cases we do not care about which decoder_input_ids 509 | are valid, and also we can always simply broadcast over the batch size 510 | and heads. 511 | 512 | :param input_ids: [batch_size, seq_len] 513 | :param dtype: dtype 514 | :param return_soft_mask: whether to return mask or logits-mask 515 | :return: float [batch_size=1, num_heads=1, q_len=seq_len, kv_len=seq_len] 516 | """ 517 | batch_size, seq_length = input_ids.shape 518 | # [seq_len] 519 | seq_ids = torch.arange(seq_length, device=input_ids.device) 520 | # [seq_len, seq_len] 521 | causal_mask = seq_ids[None, :].repeat(seq_length, 1) <= seq_ids[:, None] 522 | # [batch_size=1, num_heads=1, seq_len, seq_len] 523 | causal_mask = causal_mask[None, None, :, :] 524 | if return_soft_mask: 525 | return convert_mask_to_soft_mask(causal_mask, dtype=dtype) 526 | else: 527 | return causal_mask 528 | 529 | 530 | def convert_mask_to_soft_mask(mask, dtype): 531 | """Convert binary mask to mask that can be added to logits. 532 | 533 | (i.e. 
0 for attention, large negative for masked) 534 | """ 535 | mask = mask.to(dtype=dtype) 536 | mask = (1.0 - mask) * torch.finfo(dtype).min 537 | return mask 538 | 539 | 540 | class NoInitLinear(nn.Linear): 541 | def reset_parameters(self) -> None: 542 | pass 543 | 544 | 545 | class NoInit8bitLinear(bnb.nn.Linear8bitLt): 546 | def reset_parameters(self) -> None: 547 | pass 548 | 549 | 550 | def get_linear_class(use_8bit=False): 551 | if use_8bit: 552 | return NoInit8bitLinear 553 | else: 554 | return NoInitLinear 555 | 556 | 557 | class NoInitEmbedding(nn.Embedding): 558 | def reset_parameters(self) -> None: 559 | pass 560 | 561 | 562 | def check_nan(x): 563 | if torch.isnan(x).any(): 564 | import pdb 565 | pdb.set_trace() 566 | 567 | 568 | def create_model(model_name, hf_path, use_8bit=False, device=None): 569 | config = LLAMA_CONFIG_DICT[model_name] 570 | 571 | with open(os.path.join(hf_path, "pytorch_model.bin.index.json")) as f: 572 | weight_map = json.load(f)["weight_map"] 573 | 574 | filename_list = sorted(list(set(weight_map.values()))) 575 | if device is None: 576 | # TODO: Local rank 577 | device = torch.device("cuda:0") 578 | if use_8bit: 579 | config = dataclasses.replace(config, use_8bit=True) 580 | with init_empty_weights(): 581 | model = LLaMAModel(config=config) 582 | state_keys = set(model.state_dict()) 583 | filename_list = sorted(list(set(weight_map.values()))) 584 | for filename in tqdm.tqdm(filename_list): 585 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu") 586 | for k, v in loaded.items(): 587 | set_module_8bit_tensor_to_device(model, tensor_name=k, device=device, value=v) 588 | state_keys.remove(k) 589 | assert not state_keys 590 | else: 591 | # noinspection PyUnresolvedReferences 592 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 593 | model = LLaMAModel(config=config).cuda() 594 | torch.set_default_tensor_type(torch.FloatTensor) 595 | state_keys = set(model.state_dict()) 596 | for filename in tqdm.tqdm(filename_list): 597 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu") 598 | model.load_state_dict(loaded, strict=False) 599 | for k in loaded: 600 | state_keys.remove(k) 601 | return model 602 | 603 | 604 | def shift_kv_cache_right(layer_cache, num_valid_tokens): 605 | """ 606 | :param layer_cache: left-aligned kv cache element, [batch_size, num_heads, seq_len, dim] 607 | :param num_valid_tokens: [batch_size] 608 | :return: 609 | """ 610 | batch_size = layer_cache.shape[0] 611 | # noinspection PyUnresolvedReferences 612 | return torch.stack([ 613 | torch.cat([ 614 | layer_cache[i, :, num_valid_tokens[i]:, :], 615 | layer_cache[i, :, :num_valid_tokens[i], :], 616 | ], dim=1) 617 | for i in range(batch_size) 618 | ], dim=0) 619 | 620 | 621 | def create_generation_attention_mask(batch_size, seq_len, num_valid_tokens, device): 622 | """ 623 | :param batch_size: int 624 | :param seq_len: int 625 | :param num_valid_tokens: [batch_size] 626 | :param device: 627 | :return: 628 | """ 629 | # For right-aligned, based on num_valid_tokens 630 | # noinspection PyTypeChecker 631 | attn_mask = torch.zeros([batch_size, 1, 1, seq_len], dtype=bool) 632 | for i in range(batch_size): 633 | valid = num_valid_tokens[i] 634 | # noinspection PyTypeChecker 635 | # attn_mask[i, 0, -valid:, -valid:] = torch.tril(torch.ones([valid, valid], dtype=bool)) 636 | attn_mask[i, 0, 0, -valid:] = True 637 | return attn_mask.to(device=device) 638 | 639 | 640 | def create_casual_attention_mask(seq_len, device): 641 | # noinspection PyTypeChecker 
642 | attn_mask = torch.tril(torch.ones([seq_len, seq_len], dtype=bool))[None, None, :, :] 643 | return attn_mask.to(device=device) 644 | 645 | 646 | def create_rope_embed_ids(input_ids): 647 | pad_token_id = 0 648 | max_position = 2047 # These will not actually be used, as they are masked out by the attention mask 649 | x = (input_ids != pad_token_id).cumsum(-1) - 1 650 | x[input_ids == pad_token_id] = max_position 651 | return x 652 | -------------------------------------------------------------------------------- /pefty_llama/modeling_peft.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/zphang/minimal-llama/blob/c37e481136f118a16f77f50cdf5e867ed5dafbf9/minimal_llama/pref/llama_simple2.py 2 | 3 | import os 4 | import json 5 | import math 6 | import dataclasses 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | import bitsandbytes as bnb 13 | import tqdm.auto as tqdm 14 | 15 | from accelerate import init_empty_weights 16 | from transformers.utils.bitsandbytes import set_module_8bit_tensor_to_device 17 | from transformers import ( 18 | LlamaConfig as HF_LlamaConfig, 19 | LlamaForCausalLM as HF_Llama, 20 | ) 21 | import pefty_llama.peft as peft 22 | from pefty_llama.configuration import LLaMAConfig, LLAMA_CONFIG_DICT 23 | 24 | 25 | class LLaMAModel(nn.Module): 26 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig): 27 | super().__init__() 28 | self.config = config 29 | self.peft_config = peft_config 30 | self.model = LLaMAInnerModel(config=config, peft_config=peft_config) 31 | self.lm_head = NoInitLinear(config.dim, config.vocab_size, bias=False, dtype=config.dtype) 32 | 33 | if self.peft_config.peft_mode == peft.PEFT_PREFIX: 34 | self.peft_prefixes = peft.SoftPrefixes(config=config, peft_config=peft_config) 35 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding: 36 | self.peft_lora_lm_head = peft.LoRA(config=config, peft_config=peft_config, 37 | output_dim=config.vocab_size) 38 | 39 | def forward(self, 40 | input_ids): 41 | """Forward pass (with full decode sequence, intended for training or loss-scoring) 42 | 43 | :param input_ids: [batch_size, seq_len] 44 | :return: logits [batch_size, seq_len] 45 | """ 46 | # 1) Create masks 47 | # decoder mask 48 | # [batch_size, num_heads=1, q_len=seq_len, kv_len=seq_len] 49 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype) 50 | input_ids_for_rope = input_ids 51 | if self.peft_config.peft_mode == peft.PEFT_PREFIX: 52 | attention_mask = torch.cat([ 53 | zeros_like([1, 1, input_ids.shape[1], self.peft_config.num_prefix_tokens], tensor=attention_mask), 54 | attention_mask, 55 | ], dim=3) 56 | 57 | if self.peft_config.peft_mode in peft.PEFT_PROMPT: 58 | input_ids_for_rope = torch.cat([ 59 | torch.ones([input_ids.shape[0], self.peft_config.num_prefix_tokens], 60 | dtype=input_ids.dtype, device=input_ids.device), 61 | input_ids, 62 | ], dim=1) 63 | # Easier to just remake the attention mask 64 | attention_mask = create_attention_mask(input_ids=input_ids_for_rope, dtype=self.config.dtype) 65 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids_for_rope) 66 | cos, sin = self.get_cos_sin(rope_embed_ids) 67 | 68 | if self.peft_config.peft_mode == peft.PEFT_PREFIX: 69 | kv_cache = self.peft_prefixes(batch_size=input_ids.shape[0]) 70 | else: 71 | kv_cache = None 72 | 73 | # 2) Forward pass 74 | # [batch_size, seq_len, hidden_dim] 75 | model_out = 
self.model( 76 | input_ids, 77 | attention_mask=attention_mask, 78 | cos=cos, sin=sin, 79 | kv_cache=kv_cache, 80 | ) 81 | # [batch_size, seq_len, vocab_size] 82 | logits = self.lm_head(model_out["hidden_states"]) 83 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding: 84 | logits += self.peft_lora_lm_head(model_out["hidden_states"]) 85 | return logits 86 | 87 | def init_kv_cache(self, input_ids): 88 | # noinspection GrazieInspection 89 | """Initialize KV cache for decoding. 90 | 91 | A KV cache consists of a list of dicts (one per layer): 92 | dict( 93 | key = [batch_size, num_heads, kv_seq_len=0, head_dim] 94 | value = [batch_size, num_heads, kv_seq_len=0, head_dim] 95 | ) 96 | 97 | :param input_ids: [batch_size, dec_seq_len] 98 | :return: 0-length kv_cache 99 | """ 100 | kv_cache = [] 101 | batch_size = input_ids.shape[0] 102 | num_heads = self.config.n_heads 103 | head_dim = self.config.head_dim 104 | for layer in self.model.layers: 105 | device = layer.input_layernorm.weight.device 106 | kv_cache.append({ 107 | "key": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype), 108 | "value": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype), 109 | }) 110 | return kv_cache 111 | 112 | def generate(self, input_ids, generation_length: int = 20, 113 | return_output_only=True): 114 | """Generate tokens with efficient caching of KV. 115 | 116 | TODO: Add stopping conditions 117 | TODO: Add sampling capabilities 118 | 119 | :param input_ids: [batch_size, enc_seq_len] 120 | :param generation_length: int 121 | :param return_output_only: 122 | :return: [batch_size, generation_length] 123 | """ 124 | original_input_ids = input_ids 125 | batch_size, seq_len = input_ids.shape 126 | # noinspection PyUnresolvedReferences 127 | num_valid_tokens = (input_ids != self.config.pad_token_id).long().sum(dim=1) 128 | 129 | # 1) Setup 130 | if input_ids is None: 131 | # [batch_size, dec_seq_len=1] 132 | input_ids = torch.LongTensor( 133 | [[self.config.pad_token_id]] * batch_size 134 | ).to(self.lm_head.weights.device) 135 | # See: init_kv_cache. 
list[dict] 136 | if self.peft_config.peft_mode == peft.PEFT_PREFIX: 137 | kv_cache = self.peft_prefixes(batch_size=input_ids.shape[0]) 138 | num_valid_kv_cache = num_valid_tokens + self.peft_config.num_prefix_tokens 139 | else: 140 | kv_cache = self.init_kv_cache(input_ids) 141 | num_valid_kv_cache = num_valid_tokens 142 | generated_token_ids_list = [original_input_ids] 143 | total_seq_len = seq_len 144 | 145 | # 2) First encoding 146 | # [batch_size=1, num_heads=1, q_len=1, kv_len=1] 147 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype) 148 | input_ids_for_rope = input_ids 149 | # dict( 150 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim] 151 | # kv_cache = list[dict( 152 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 153 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 154 | # )] 155 | # ) 156 | if self.peft_config.peft_mode in (peft.PEFT_PREFIX, peft.PEFT_PROMPT): 157 | num_prefix_tokens = self.peft_config.num_prefix_tokens 158 | total_seq_len += num_prefix_tokens 159 | # [batch_size, num_heads=1, q_len=seq_len, kv_len=num_prefix_tokens + dec_seq_len] 160 | attention_mask = torch.cat([ 161 | zeros_like([1, 1, input_ids.shape[1], num_prefix_tokens], tensor=attention_mask), 162 | attention_mask, 163 | ], dim=3) 164 | 165 | if self.peft_config.peft_mode in peft.PEFT_PROMPT: 166 | input_ids_for_rope = torch.cat([ 167 | torch.ones([input_ids.shape[0], self.peft_config.num_prefix_tokens], 168 | dtype=input_ids.dtype, device=input_ids.device), 169 | input_ids, 170 | ], dim=1) 171 | # Easier to just remake the attention mask 172 | attention_mask = create_attention_mask(input_ids=input_ids_for_rope, dtype=self.config.dtype) 173 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids_for_rope) 174 | cos, sin = self.get_cos_sin(rope_embed_ids) 175 | model_out = self.model( 176 | input_ids=input_ids, 177 | attention_mask=attention_mask, 178 | cos=cos, sin=sin, 179 | kv_cache=kv_cache, 180 | ) 181 | logits = self.lm_head(model_out["hidden_states"]) 182 | kv_cache = model_out["kv_cache"] 183 | generated_token_ids = logits.argmax(-1)[ 184 | torch.arange(batch_size, dtype=torch.long, device=input_ids.device), 185 | num_valid_tokens-1, 186 | ][:, None] 187 | generated_token_ids_list.append(generated_token_ids) 188 | input_ids = generated_token_ids 189 | 190 | # 3) Subsequent steps 191 | for decode_step in range(generation_length-1): 192 | num_valid_tokens += 1 193 | total_seq_len += 1 194 | # [batch_size=1, num_heads=1, q_len=1, kv_len=1] 195 | attention_mask = convert_mask_to_soft_mask(create_generation_attention_mask( 196 | batch_size=batch_size, 197 | seq_len=total_seq_len, 198 | num_valid_tokens=num_valid_tokens, 199 | device=input_ids.device, 200 | ), dtype=self.config.dtype) 201 | # dict( 202 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim] 203 | # kv_cache = list[dict( 204 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 205 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 206 | # )] 207 | # ) 208 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) + num_valid_tokens[:, None] 209 | cos, sin = self.get_cos_sin(rope_embed_ids) 210 | model_out = self.model( 211 | input_ids=input_ids, 212 | attention_mask=attention_mask, 213 | kv_cache=kv_cache, 214 | cos=cos, sin=sin, 215 | ) 216 | # [batch_size, dec_seq_len=1, vocab_size] 217 | logits = self.lm_head(model_out["hidden_states"]) 218 | kv_cache = 
model_out["kv_cache"] 219 | # [batch_size, dec_seq_len=1] 220 | generated_token_ids = logits.argmax(-1)[:, -1:] 221 | generated_token_ids_list.append(generated_token_ids) 222 | input_ids = generated_token_ids 223 | output = torch.cat(generated_token_ids_list, dim=1) 224 | if return_output_only: 225 | output = output[:, seq_len:] 226 | return output 227 | 228 | def get_cos_sin(self, rope_embed_ids): 229 | cos = F.embedding( 230 | rope_embed_ids, 231 | self.model.layers[0].self_attn.rotary_emb.cos_cached[0, 0].to(rope_embed_ids.device) 232 | ).to(self.config.dtype) 233 | sin = F.embedding( 234 | rope_embed_ids, 235 | self.model.layers[0].self_attn.rotary_emb.sin_cached[0, 0].to(rope_embed_ids.device) 236 | ).to(self.config.dtype) 237 | cos, sin = cos[:, None, :, :], sin[:, None, :, :] 238 | return cos, sin 239 | 240 | def gradient_checkpointing_enable(self): 241 | self.config.gradient_checkpointing = True 242 | 243 | def enable_input_require_grads(self): 244 | def make_inputs_require_grads(module, input, output): 245 | output.requires_grad_(True) 246 | self.model.embed_tokens.register_forward_hook(make_inputs_require_grads) 247 | 248 | 249 | class LLaMAInnerModel(nn.Module): 250 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig): 251 | super().__init__() 252 | self.config = config 253 | self.peft_config = peft_config 254 | self.embed_tokens = nn.Embedding(config.vocab_size, config.dim, dtype=config.dtype) 255 | self.layers = nn.ModuleList([ 256 | LLaMALayer(config=config, peft_config=peft_config) 257 | for _ in range(config.n_layers) 258 | ]) 259 | self.norm = RMSNorm(dim=config.dim) 260 | 261 | if self.peft_config.peft_mode == peft.PEFT_PROMPT: 262 | self.peft_prompt = peft.AddSoftPrompt(config=config, peft_config=peft_config) 263 | 264 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding: 265 | self.peft_lora_embed = peft.LoRAEmbed(config=config, peft_config=peft_config) 266 | 267 | def forward(self, 268 | input_ids, 269 | attention_mask, 270 | cos, sin, 271 | kv_cache=None): 272 | """ 273 | :param input_ids: [batch_size, seq_len] 274 | :param attention_mask: [batch_size=1, num_heads=1, seq_len, seq_len] 275 | :param cos: for RoPE 276 | :param sin: for RoPE 277 | :param kv_cache: See init_kv_cache. 
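A list with one dict per layer, each holding "key" and "value" tensors of shape [batch_size, num_heads, kv_seq_len, head_dim]; pass None for a plain full forward pass without caching.
:return: dict with "hidden_states" [batch_size, seq_len, hidden_dim] plus, when a kv_cache was passed in, the updated "kv_cache"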
278 | """ 279 | hidden_states = self.embed_tokens(input_ids).to(self.config.dtype) 280 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding: 281 | hidden_states += self.peft_lora_embed(input_ids).to(self.config.dtype) 282 | 283 | if self.peft_config.peft_mode == peft.PEFT_PROMPT: 284 | if kv_cache is None or kv_cache[0]["key"].shape[2] == 0: 285 | # Only add prompt if kv_cache is None (full forward pass) or if kv_cache is empty (first decode step) 286 | hidden_states = self.peft_prompt(hidden_states) 287 | 288 | new_kv_cache = [] 289 | for layer_i, layer in enumerate(self.layers): 290 | if kv_cache: 291 | # dict( 292 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 293 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim] 294 | # ) 295 | layer_kv_cache = kv_cache[layer_i] 296 | else: 297 | layer_kv_cache = None 298 | 299 | if self.config.gradient_checkpointing: 300 | layer_out = torch.utils.checkpoint.checkpoint( 301 | layer, 302 | hidden_states, 303 | attention_mask, 304 | cos, sin, 305 | layer_kv_cache, 306 | ) 307 | else: 308 | layer_out = layer( 309 | hidden_states=hidden_states, 310 | attention_mask=attention_mask, 311 | cos=cos, sin=sin, 312 | kv_cache=layer_kv_cache, 313 | ) 314 | hidden_states, out_layer_kv_cache = layer_out 315 | if kv_cache: 316 | new_kv_cache.append(out_layer_kv_cache) 317 | hidden_states = self.norm(hidden_states) 318 | output = { 319 | "hidden_states": hidden_states 320 | } 321 | if kv_cache: 322 | output["kv_cache"] = new_kv_cache 323 | return output 324 | 325 | 326 | class LLaMALayer(nn.Module): 327 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig): 328 | super().__init__() 329 | self.config = config 330 | self.peft_config = peft_config 331 | self.self_attn = Attention(config=config, peft_config=peft_config) 332 | self.mlp = MLP(config=config, peft_config=peft_config) 333 | self.input_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype) 334 | self.post_attention_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype) 335 | 336 | if self.peft_config.peft_mode == peft.PEFT_ADAPTER: 337 | if self.peft_config.adapter_version == "houlsby": 338 | self.peft_adapter_attn = peft.Adapter(config=config, peft_config=peft_config) 339 | self.peft_adapter_mlp = peft.Adapter(config=config, peft_config=peft_config) 340 | 341 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 342 | self.peft_input_layernorm_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config) 343 | self.peft_post_attention_layernorm_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config) 344 | 345 | def forward( 346 | self, 347 | hidden_states, 348 | attention_mask, 349 | cos, sin, 350 | kv_cache=None, 351 | ): 352 | # 1) Self-attention 353 | # [batch_size, seq_len, hidden_dim] 354 | normed_hidden_states = self.input_layernorm(hidden_states).to(self.config.dtype) 355 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 356 | normed_hidden_states = self.peft_input_layernorm_bias(normed_hidden_states) 357 | # dict( 358 | # attn_output = [batch_size, seq_len, hidden_dim] 359 | # kv_cache = dict( 360 | # key = [batch_size, num_heads, kv_seq_len, head_dim] 361 | # value = [batch_size, num_heads, kv_seq_len, head_dim] 362 | # ) 363 | # ) 364 | check_nan(normed_hidden_states) 365 | raw_self_attn_output = self.self_attn( 366 | hidden_states=normed_hidden_states, 367 | attention_mask=attention_mask, 368 | kv_cache=kv_cache, 369 | cos=cos, sin=sin, 370 | ) 371 | # [batch_size, seq_len, hidden_dim] 
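# For Houlsby-style adapters, the attention output below is additionally passed through a bottleneck adapter before the residual connection; Pfeiffer-style adapters only adapt the MLP branch further down.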
372 | attn_out = raw_self_attn_output["attn_output"] 373 | if self.peft_config.peft_mode == peft.PEFT_ADAPTER \ 374 | and self.peft_config.adapter_version == peft.ADAPTER_VERSION_HOULSBY: 375 | attn_out = self.peft_adapter_attn(attn_out) 376 | 377 | # [batch_size, seq_len, hidden_dim] 378 | hidden_states = hidden_states + attn_out 379 | check_nan(hidden_states) 380 | # 2) FFN 381 | # [batch_size, seq_len, hidden_dim] 382 | post_normed_hidden_states = self.post_attention_layernorm(hidden_states) 383 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 384 | post_normed_hidden_states = self.peft_post_attention_layernorm_bias(post_normed_hidden_states) 385 | 386 | mlp_out = self.mlp(post_normed_hidden_states) 387 | if self.peft_config.peft_mode == peft.PEFT_ADAPTER: 388 | mlp_out = self.peft_adapter_mlp(mlp_out) 389 | 390 | hidden_states = hidden_states + mlp_out 391 | check_nan(hidden_states) 392 | # if kv_cache: 393 | # return { 394 | # "hidden_states": hidden_states, 395 | # "kv_cache": raw_self_attn_output["kv_cache"], 396 | # } 397 | # 398 | # return {"hidden_states": hidden_states} 399 | if kv_cache: 400 | return hidden_states, raw_self_attn_output["kv_cache"] 401 | else: 402 | return hidden_states, None 403 | 404 | 405 | class MLP(nn.Module): 406 | def __init__( 407 | self, 408 | config: LLaMAConfig, 409 | peft_config: peft.PeftConfig, 410 | multiple_of: int = 256, 411 | ): 412 | super().__init__() 413 | self.config = config 414 | self.peft_config = peft_config 415 | dim = config.dim 416 | hidden_dim = 4 * dim 417 | hidden_dim = int(2 * hidden_dim / 3) 418 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 419 | 420 | if config.use_8bit: 421 | self.gate_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False) 422 | self.up_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False) 423 | self.down_proj = NoInit8bitLinear(hidden_dim, dim, bias=False, threshold=6.0, has_fp16_weights=False) 424 | else: 425 | self.gate_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype) 426 | self.up_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype) 427 | self.down_proj = NoInitLinear(hidden_dim, dim, bias=False, dtype=config.dtype) 428 | 429 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_mlp: 430 | self.gate_proj_lora = peft.LoRA(config=config, peft_config=peft_config, 431 | input_dim=dim, output_dim=hidden_dim) 432 | self.up_proj_lora = peft.LoRA(config=config, peft_config=peft_config, 433 | input_dim=dim, output_dim=hidden_dim) 434 | self.down_proj_lora = peft.LoRA(config=config, peft_config=peft_config, 435 | input_dim=dim, output_dim=hidden_dim) 436 | if self.peft_config.peft_mode == peft.PEFT_IA3: 437 | self.peft_ia3 = peft.IA3ForMLP(config, peft_config=peft_config) 438 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 439 | self.peft_gate_proj_bias = peft.BitFitAddBias(dim=hidden_dim, peft_config=peft_config) 440 | self.peft_up_proj_bias = peft.BitFitAddBias(dim=hidden_dim, peft_config=peft_config) 441 | self.peft_down_proj_bias = peft.BitFitAddBias(dim=dim, peft_config=peft_config) 442 | 443 | def forward(self, x): 444 | gate_proj = self.gate_proj(x) 445 | up_proj = self.up_proj(x) 446 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_mlp: 447 | gate_proj += self.gate_proj_lora(x) 448 | up_proj += self.up_proj_lora(x) 449 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 450 | gate_proj = 
self.peft_gate_proj_bias(gate_proj) 451 | up_proj = self.peft_up_proj_bias(up_proj) 452 | 453 | intermediate_state = F.silu(gate_proj) * up_proj 454 | if self.peft_config.peft_mode == peft.PEFT_IA3: 455 | intermediate_state = self.peft_ia3(intermediate_state) 456 | 457 | down_proj = self.down_proj(intermediate_state) 458 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_mlp: 459 | down_proj = self.down_proj_lora(x) 460 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 461 | down_proj = self.peft_down_proj_bias(down_proj) 462 | 463 | return down_proj 464 | 465 | 466 | class RMSNorm(torch.nn.Module): 467 | def __init__(self, dim: int, eps: float = 1e-6, dtype=torch.float16): 468 | super().__init__() 469 | self.eps = eps 470 | self.weight = nn.Parameter(torch.ones(dim, dtype=dtype)) 471 | 472 | def _norm(self, x): 473 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 474 | 475 | def forward(self, x): 476 | output = self._norm(x.float()).type_as(x) 477 | return output * self.weight 478 | 479 | 480 | class Attention(nn.Module): 481 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig): 482 | super().__init__() 483 | self.config = config 484 | self.peft_config = peft_config 485 | self.n_heads = config.n_heads 486 | self.head_dim = config.dim // config.n_heads 487 | 488 | if config.use_8bit: 489 | self.q_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 490 | self.k_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 491 | self.v_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 492 | self.o_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 493 | else: 494 | self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 495 | self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 496 | self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 497 | self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 498 | self.rotary_emb = RotaryEmbedding(dim=self.head_dim) 499 | 500 | if self.peft_config.peft_mode == peft.PEFT_LORA: 501 | self.peft_q_proj_lora = peft.LoRA(config=config, peft_config=peft_config) 502 | self.peft_v_proj_lora = peft.LoRA(config=config, peft_config=peft_config) 503 | if self.peft_config.peft_mode == peft.PEFT_IA3: 504 | self.peft_ia3 = peft.IA3ForAttn(config, peft_config=peft_config) 505 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 506 | self.peft_q_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config) 507 | self.peft_k_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config) 508 | self.peft_v_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config) 509 | self.peft_o_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config) 510 | if self.peft_config.peft_mode == peft.PEFT_PREFIX_ADAPTER: 511 | self.peft_prefix_adapter = peft.PrefixAdapter(config=config, peft_config=peft_config) 512 | 513 | def forward(self, hidden_states, attention_mask, cos, sin, kv_cache=None): 514 | """ 515 | precomputed_kv_hidden_states is for init (pre-compute KV activations, e.g.
for added prefixes) 516 | kv_cache is for generation (cached past KV) 517 | """ 518 | batch_size, q_seq_len, hidden_dim = hidden_states.size() 519 | 520 | # (batch_size, num_heads, q_seq_len, head_dim) 521 | query_states = self.q_proj(hidden_states) 522 | key_states = self.k_proj(hidden_states) 523 | value_states = self.v_proj(hidden_states) 524 | 525 | if self.peft_config.peft_mode == peft.PEFT_LORA: 526 | query_states += self.peft_q_proj_lora(hidden_states) 527 | value_states += self.peft_v_proj_lora(hidden_states) 528 | if self.peft_config.peft_mode == peft.PEFT_IA3: 529 | key_states, value_states = self.peft_ia3(key_states, value_states) 530 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 531 | query_states = self.peft_q_proj_bias(query_states) 532 | key_states = self.peft_k_proj_bias(key_states) 533 | value_states = self.peft_v_proj_bias(value_states) 534 | 535 | query_states = query_states.view( 536 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 537 | key_states = key_states.view( 538 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 539 | value_states = value_states.view( 540 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 541 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos=cos, sin=sin) 542 | 543 | if kv_cache: 544 | key_states = torch.cat([kv_cache["key"], key_states], dim=2) 545 | value_states = torch.cat([kv_cache["value"], value_states], dim=2) 546 | 547 | attn_output = torch.nn.functional.scaled_dot_product_attention( 548 | query=query_states, 549 | key=key_states, 550 | value=value_states, 551 | attn_mask=attention_mask, 552 | ) 553 | 554 | if self.peft_config.peft_mode == peft.PEFT_PREFIX_ADAPTER: 555 | attn_output = attn_output + self.peft_prefix_adapter(query_states=query_states) 556 | 557 | # (batch_size, q_seq_len, hidden_dim) 558 | attn_output = attn_output.transpose(1, 2).contiguous().view( 559 | batch_size, q_seq_len, hidden_dim, 560 | ) 561 | attn_output = self.o_proj(attn_output) 562 | if self.peft_config.peft_mode == peft.PEFT_BITFIT: 563 | attn_output = self.peft_o_proj_bias(attn_output) 564 | 565 | check_nan(attn_output) 566 | if kv_cache: 567 | new_kv_cache = {"key": key_states, "value": value_states} 568 | return {"attn_output": attn_output, "kv_cache": new_kv_cache} 569 | else: 570 | return {"attn_output": attn_output} 571 | 572 | 573 | class RotaryEmbedding(torch.nn.Module): 574 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 575 | super().__init__() 576 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device=device) / dim)) 577 | self.register_buffer("inv_freq", inv_freq) 578 | 579 | # Build here to make `torch.jit.trace` work. 580 | self.max_seq_len_cached = max_position_embeddings 581 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device).to(self.inv_freq.dtype) 582 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 583 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 584 | emb = torch.cat((freqs, freqs), dim=-1) 585 | self.cos_cached = emb.cos()[None, None, :, :] 586 | self.sin_cached = emb.sin()[None, None, :, :] 587 | 588 | def forward(self, x, seq_len=None): 589 | # x: [bs, num_attention_heads, seq_len, head_size] 590 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
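# cos_cached / sin_cached have shape [1, 1, max_seq_len_cached, head_dim]; they are rebuilt below whenever a sequence longer than the cached maximum arrives.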
591 | if seq_len > self.max_seq_len_cached: 592 | self.max_seq_len_cached = seq_len 593 | t = torch.arange(self.max_seq_len_cached, device=x.device).to(self.inv_freq.dtype) 594 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 595 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 596 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 597 | self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype) 598 | self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype) 599 | return ( 600 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), 601 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), 602 | ) 603 | 604 | 605 | def rotate_half(x): 606 | """Rotates half the hidden dims of the input.""" 607 | x1 = x[..., : x.shape[-1] // 2] 608 | x2 = x[..., x.shape[-1] // 2:] 609 | return torch.cat((-x2, x1), dim=-1) 610 | 611 | 612 | def apply_rotary_pos_emb(q, k, cos, sin): 613 | q_embed = (q * cos) + (rotate_half(q) * sin) 614 | k_embed = (k * cos) + (rotate_half(k) * sin) 615 | return q_embed, k_embed 616 | 617 | 618 | def create_attention_mask(input_ids, 619 | dtype=torch.float32, 620 | return_soft_mask=True): 621 | """Create mask for decoder attention. 622 | 623 | Decoder masks have two use-cases: 624 | 625 | 1) Training, where we see the full decoder sequence. In that case, 626 | we want a causal mask. 627 | 628 | 2) Generation, where we only see one token at once. In that case, 629 | it doesn't really matter what we give, we can just give a 1. 630 | (i.e. seq_len = 1) 631 | 632 | Note that in both cases we do not care about which decoder_input_ids 633 | are valid, and also we can always simply broadcast over the batch size 634 | and heads. 635 | 636 | :param input_ids: [batch_size, seq_len] 637 | :param dtype: dtype 638 | :param return_soft_mask: whether to return mask or logits-mask 639 | :return: float [batch_size=1, num_heads=1, q_len=seq_len, kv_len=seq_len] 640 | """ 641 | batch_size, seq_length = input_ids.shape 642 | # [seq_len] 643 | seq_ids = torch.arange(seq_length, device=input_ids.device) 644 | # [seq_len, seq_len] 645 | causal_mask = seq_ids[None, :].repeat(seq_length, 1) <= seq_ids[:, None] 646 | # [batch_size=1, num_heads=1, seq_len, seq_len] 647 | causal_mask = causal_mask[None, None, :, :] 648 | if return_soft_mask: 649 | return convert_mask_to_soft_mask(causal_mask, dtype=dtype) 650 | else: 651 | return causal_mask 652 | 653 | 654 | def convert_mask_to_soft_mask(mask, dtype): 655 | """Convert binary mask to mask that can be added to logits. 656 | 657 | (i.e. 
0 for attention, large negative for masked) 658 | """ 659 | mask = mask.to(dtype=dtype) 660 | mask = (1.0 - mask) * torch.finfo(dtype).min 661 | return mask 662 | 663 | 664 | class NoInitLinear(nn.Linear): 665 | def reset_parameters(self) -> None: 666 | pass 667 | 668 | 669 | class NoInit8bitLinear(bnb.nn.Linear8bitLt): 670 | def reset_parameters(self) -> None: 671 | pass 672 | 673 | 674 | def get_linear_class(use_8bit=False): 675 | if use_8bit: 676 | return NoInit8bitLinear 677 | else: 678 | return NoInitLinear 679 | 680 | 681 | class NoInitEmbedding(nn.Embedding): 682 | def reset_parameters(self) -> None: 683 | pass 684 | 685 | 686 | def check_nan(x): 687 | # if torch.isnan(x).any(): 688 | # import pdb 689 | # pdb.set_trace() 690 | pass 691 | 692 | 693 | def create_model(model_name, hf_path, peft_config: peft.PeftConfig, use_8bit=False, device=None): 694 | config = LLAMA_CONFIG_DICT[model_name] 695 | 696 | with open(os.path.join(hf_path, "pytorch_model.bin.index.json")) as f: 697 | weight_map = json.load(f)["weight_map"] 698 | 699 | filename_list = sorted(list(set(weight_map.values()))) 700 | if device is None: 701 | # TODO: Local rank 702 | device = torch.device("cuda:0") 703 | if use_8bit: 704 | config = dataclasses.replace(config, use_8bit=True) 705 | with init_empty_weights(): 706 | model = LLaMAModel(config=config, peft_config=peft_config) 707 | if model_name == "debug": 708 | return model 709 | state_keys = set(model.state_dict()) 710 | filename_list = sorted(list(set(weight_map.values()))) 711 | for filename in tqdm.tqdm(filename_list): 712 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu") 713 | for k, v in loaded.items(): 714 | set_module_8bit_tensor_to_device(model, tensor_name=k, device=device, value=v) 715 | state_keys.remove(k) 716 | assert not state_keys 717 | else: 718 | # noinspection PyUnresolvedReferences 719 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 720 | model = LLaMAModel(config=config, peft_config=peft_config).cuda() 721 | torch.set_default_tensor_type(torch.FloatTensor) 722 | if model_name == "debug": 723 | return model 724 | state_keys = set(model.state_dict()) 725 | for filename in tqdm.tqdm(filename_list): 726 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu") 727 | model.load_state_dict(loaded, strict=False) 728 | for k in loaded: 729 | state_keys.remove(k) 730 | return model 731 | 732 | 733 | def set_peft_requires_grad(model: LLaMAModel): 734 | for p in model.parameters(): 735 | p.requires_grad_(False) 736 | if model.peft_config.peft_mode == peft.PEFT_PREFIX: 737 | _set_requires_grad_if_str_in_name(model, substr="peft_prefix") 738 | elif model.peft_config.peft_mode == peft.PEFT_PROMPT: 739 | _set_requires_grad_if_str_in_name(model, substr="peft_prompt") 740 | elif model.peft_config.peft_mode == peft.PEFT_ADAPTER: 741 | _set_requires_grad_if_str_in_name(model, substr="peft_adapter_") 742 | elif model.peft_config.peft_mode == peft.PEFT_PREFIX_ADAPTER: 743 | _set_requires_grad_if_str_in_name(model, substr="peft_prefix_adapter") 744 | elif model.peft_config.peft_mode == peft.PEFT_LORA: 745 | _set_requires_grad_if_str_in_name(model, substr="_lora") 746 | elif model.peft_config.peft_mode == peft.PEFT_IA3: 747 | _set_requires_grad_if_str_in_name(model, substr="ia3") 748 | elif model.peft_config.peft_mode == peft.PEFT_BITFIT: 749 | _set_requires_grad_if_str_in_name(model, substr="bias") 750 | elif model.peft_config.peft_mode == peft.NO_PEFT: 751 | pass 752 | else: 753 | raise 
KeyError(model.peft_config.peft_mode) 754 | 755 | 756 | def _set_requires_grad_if_str_in_name(model, substr): 757 | for n, p in model.named_parameters(): 758 | if substr in n: 759 | p.requires_grad_(True) 760 | print(f"Tuning: {n}") 761 | 762 | 763 | def create_generation_attention_mask(batch_size, seq_len, num_valid_tokens, device): 764 | """ 765 | :param batch_size: int 766 | :param seq_len: int 767 | :param num_valid_tokens: [batch_size] 768 | :param device: 769 | :return: 770 | """ 771 | # For right-aligned, based on num_valid_tokens 772 | # noinspection PyTypeChecker 773 | attn_mask = torch.zeros([batch_size, 1, 1, seq_len], dtype=bool) 774 | for i in range(batch_size): 775 | valid = num_valid_tokens[i] 776 | # noinspection PyTypeChecker 777 | # attn_mask[i, 0, -valid:, -valid:] = torch.tril(torch.ones([valid, valid], dtype=bool)) 778 | attn_mask[i, 0, 0, -valid:] = True 779 | return attn_mask.to(device=device) 780 | 781 | 782 | def create_casual_attention_mask(seq_len, device): 783 | # noinspection PyTypeChecker 784 | attn_mask = torch.tril(torch.ones([seq_len, seq_len], dtype=bool))[None, None, :, :] 785 | return attn_mask.to(device=device) 786 | 787 | 788 | def create_rope_embed_ids(input_ids): 789 | pad_token_id = 0 790 | max_position = 2047 # These will not actually be used, as they are masked out by the attention mask 791 | x = (input_ids != pad_token_id).cumsum(-1) - 1 792 | x[input_ids == pad_token_id] = max_position 793 | return x 794 | 795 | 796 | def zeros_like(shape, tensor): 797 | return torch.zeros(shape).type_as(tensor).to(tensor.device) 798 | -------------------------------------------------------------------------------- /pefty_llama/peft/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration import * 2 | from .ia3 import IA3ForAttn, IA3ForMLP 3 | from .bitfit import BitFitAddBias 4 | from .lora import LoRA, LoRAEmbed 5 | from .prefix_tuning import SoftPrefixes 6 | from .prompt_tuning import AddSoftPrompt 7 | from .adapter import Adapter 8 | from .prefix_adapter import PrefixAdapter 9 | -------------------------------------------------------------------------------- /pefty_llama/peft/adapter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from pefty_llama.configuration import LLaMAConfig 4 | from .configuration import PeftConfig 5 | 6 | 7 | class Adapter(nn.Module): 8 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 9 | super().__init__() 10 | self.config = config 11 | self.peft_config = peft_config 12 | self.down_proj = nn.Linear( 13 | config.dim, peft_config.adapter_hidden_size, bias=False, 14 | dtype=peft_config.peft_dtype, 15 | ) 16 | self.up_proj = nn.Linear( 17 | peft_config.adapter_hidden_size, config.dim, bias=False, 18 | dtype=peft_config.peft_dtype, 19 | ) 20 | 21 | def forward(self, hidden_states): 22 | hidden_states = hidden_states.to(self.peft_config.peft_dtype) 23 | out = self.up_proj(F.gelu(self.down_proj(hidden_states))) + hidden_states 24 | return out.to(self.config.dtype) 25 | -------------------------------------------------------------------------------- /pefty_llama/peft/bitfit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .configuration import PeftConfig 4 | 5 | 6 | class BitFitAddBias(nn.Module): 7 | def __init__(self, dim: int, peft_config: PeftConfig): 8 | 
super().__init__() 9 | self.peft_config = peft_config 10 | self.bias = nn.Parameter(torch.zeros(dim, dtype=peft_config.peft_dtype)) 11 | 12 | def forward(self, hidden_state): 13 | input_dtype = hidden_state.dtype 14 | return (hidden_state.to(self.peft_config.peft_dtype) + self.bias).to(input_dtype) 15 | -------------------------------------------------------------------------------- /pefty_llama/peft/configuration.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | 6 | PEFT_PREFIX = "prefix" 7 | PEFT_PROMPT = "prompt" 8 | PEFT_ADAPTER = "adapter" 9 | PEFT_PREFIX_ADAPTER = "prefix_adapter" 10 | PEFT_LORA = "lora" 11 | PEFT_IA3 = "ia3" 12 | PEFT_BITFIT = "bitfit" 13 | NO_PEFT = "nothing" 14 | 15 | ADAPTER_VERSION_HOULSBY = "houlsby" 16 | ADAPTER_VERSION_PFEIFFER = "pfeiffer" 17 | 18 | 19 | @dataclass 20 | class PeftConfig: 21 | peft_mode: str = field() 22 | peft_dtype: Any = field(default=torch.float32) 23 | 24 | # Used by prompt, prefix, prefix_adapter 25 | num_prefix_tokens: int = field(default=16) 26 | 27 | # Prefix 28 | prefix_use_mlp: bool = field(default=True) 29 | prefix_mlp_intermediate_size: int = field(default=None) 30 | 31 | # LoRA 32 | lora_rank: int = field(default=8) 33 | lora_alpha: int = field(default=16) 34 | lora_mlp: bool = field(default=False) 35 | lora_embedding: bool = field(default=False) 36 | 37 | # Adapter 38 | adapter_hidden_size: int = field(default=64) 39 | adapter_version: str = field(default=ADAPTER_VERSION_PFEIFFER) # houlsby, pfeiffer 40 | 41 | def check(self): 42 | assert self.peft_mode in ( 43 | PEFT_PREFIX, PEFT_PREFIX_ADAPTER, PEFT_PROMPT, PEFT_ADAPTER, 44 | PEFT_LORA, PEFT_IA3, PEFT_BITFIT, 45 | NO_PEFT, 46 | ) 47 | -------------------------------------------------------------------------------- /pefty_llama/peft/ia3.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from pefty_llama.modeling import LLaMAModel, NoInitLinear, NoInit8bitLinear, RotaryEmbedding, apply_rotary_pos_emb, check_nan 9 | from pefty_llama.configuration import LLaMAConfig 10 | from .configuration import PeftConfig 11 | 12 | 13 | class IA3Attention(nn.Module): 14 | def __init__(self, config: LLaMAConfig): 15 | super().__init__() 16 | self.config = config 17 | self.n_heads = config.n_heads 18 | self.head_dim = config.dim // config.n_heads 19 | 20 | if config.use_8bit: 21 | self.q_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 22 | self.k_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 23 | self.v_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 24 | self.o_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False) 25 | else: 26 | self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 27 | self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 28 | self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 29 | self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype) 30 | self.rotary_emb = RotaryEmbedding(dim=self.head_dim) 31 | 32 | # IA3-specific parameters: 33 | self.peft_l_k = nn.Parameter(torch.ones(1, self.n_heads, 1, 
self.head_dim, dtype=config.dtype)) 34 | self.peft_l_v = nn.Parameter(torch.ones(1, self.n_heads, 1, self.head_dim, dtype=config.dtype)) 35 | 36 | def forward(self, hidden_states, attention_mask, cos, sin, kv_cache=None): 37 | """ 38 | precomputed_kv_hidden_states is for init (pre-compute KV activations, e.g. for added prefixes) 39 | kv_cache is for generation (cached past KV) 40 | """ 41 | batch_size, q_seq_len, hidden_dim = hidden_states.size() 42 | 43 | # (batch_size, num_heads, q_seq_len, head_dim) 44 | query_states = self.q_proj(hidden_states).view( 45 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 46 | key_states = self.k_proj(hidden_states).view( 47 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 48 | value_states = self.v_proj(hidden_states).view( 49 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2) 50 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos=cos, sin=sin) 51 | if kv_cache: 52 | key_states = torch.cat([kv_cache["key"], key_states], dim=2) 53 | value_states = torch.cat([kv_cache["value"], value_states], dim=2) 54 | 55 | # IA3-specific: 56 | query_states = query_states * self.peft_l_k 57 | value_states = value_states * self.peft_l_v 58 | # end of IA3-specific 59 | 60 | scores = torch.matmul( 61 | query_states, key_states.transpose(3, 2).type_as(query_states) / math.sqrt(self.head_dim) 62 | ) 63 | scores += attention_mask 64 | 65 | # (batch_size, num_heads, q_seq_len, kv_seq_len) 66 | attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores) 67 | # (batch_size, num_heads, q_seq_len, head_dim) 68 | attn_output = torch.matmul(attn_weights, value_states.type_as(query_states)) 69 | # (batch_size, q_seq_len, hidden_dim) 70 | attn_output = attn_output.transpose(1, 2).contiguous().view( 71 | batch_size, q_seq_len, hidden_dim, 72 | ) 73 | attn_output = self.o_proj(attn_output) 74 | check_nan(attn_output) 75 | if kv_cache: 76 | new_kv_cache = {"key": key_states, "value": value_states} 77 | return {"attn_output": attn_output, "kv_cache": new_kv_cache} 78 | else: 79 | return {"attn_output": attn_output} 80 | 81 | 82 | class IA3MLP(nn.Module): 83 | def __init__( 84 | self, 85 | config: LLaMAConfig, 86 | multiple_of: int = 256, 87 | ): 88 | super().__init__() 89 | dim = config.dim 90 | hidden_dim = 4 * dim 91 | hidden_dim = int(2 * hidden_dim / 3) 92 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 93 | 94 | if config.use_8bit: 95 | self.gate_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False) 96 | self.up_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False) 97 | self.down_proj = NoInit8bitLinear(hidden_dim, dim, bias=False, threshold=6.0, has_fp16_weights=False) 98 | else: 99 | self.gate_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype) 100 | self.up_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype) 101 | self.down_proj = NoInitLinear(hidden_dim, dim, bias=False, dtype=config.dtype) 102 | 103 | # IA3-specific parameters: 104 | self.peft_l_ffn = nn.Parameter(torch.ones(1, 1, hidden_dim, dtype=config.dtype)) 105 | 106 | def forward(self, x): 107 | h = F.silu(self.gate_proj(x)) * self.up_proj(x) 108 | # IA3-specific: 109 | h = h * self.peft_l_ffn 110 | # end of IA3-specific 111 | return self.down_proj(h) 112 | 113 | 114 | class IA3(nn.Module): 115 | def __init__(self, model: LLaMAModel): 116 | super().__init__() 117 | self.base_model = model 118 | 
model_config = model.config 119 | 120 | for layer in self.base_model.model.layers: 121 | # you also need to copy the parameters of the layer to the new layer 122 | patched_attn = IA3Attention(model_config) 123 | current_attn = layer.self_attn 124 | patched_attn.q_proj.weight = current_attn.q_proj.weight 125 | patched_attn.k_proj.weight = current_attn.k_proj.weight 126 | patched_attn.v_proj.weight = current_attn.v_proj.weight 127 | patched_attn.o_proj.weight = current_attn.o_proj.weight 128 | patched_attn.rotary_emb = current_attn.rotary_emb 129 | 130 | layer.self_attn = patched_attn 131 | del current_attn 132 | 133 | patched_mlp = IA3MLP(model_config) 134 | current_mlp = layer.mlp 135 | patched_mlp.gate_proj.weight = current_mlp.gate_proj.weight 136 | patched_mlp.up_proj.weight = current_mlp.up_proj.weight 137 | patched_mlp.down_proj.weight = current_mlp.down_proj.weight 138 | 139 | layer.mlp = patched_mlp 140 | del current_mlp 141 | 142 | # cleanup memory freed by deleting the old layers 143 | if torch.cuda.is_available(): 144 | torch.cuda.empty_cache() 145 | gc.collect() 146 | 147 | for name, param in self.base_model.named_parameters(): 148 | if "peft_" in name: continue 149 | param.requires_grad = False 150 | 151 | # monkey patch the methods 152 | self.forward = self.base_model.forward 153 | self.generate = self.base_model.generate 154 | 155 | 156 | class IA3ForAttn(nn.Module): 157 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 158 | super().__init__() 159 | self.config = config 160 | self.peft_config = peft_config 161 | self.n_heads = config.n_heads 162 | self.head_dim = config.dim // config.n_heads 163 | 164 | self.peft_l_k = nn.Parameter(torch.ones(config.dim, dtype=peft_config.peft_dtype)) 165 | self.peft_l_v = nn.Parameter(torch.ones(config.dim, dtype=peft_config.peft_dtype)) 166 | 167 | def forward(self, key_states, value_states): 168 | return ( 169 | (key_states.to(self.peft_config.peft_dtype) * self.peft_l_k).to(self.config.dtype), 170 | (value_states.to(self.peft_config.peft_dtype) * self.peft_l_v).to(self.config.dtype), 171 | ) 172 | 173 | 174 | class IA3ForMLP(nn.Module): 175 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 176 | super().__init__() 177 | self.config = config 178 | self.peft_config = peft_config 179 | multiple_of = 256 180 | intermediate_dim = 4 * config.dim 181 | intermediate_dim = int(2 * intermediate_dim / 3) 182 | intermediate_dim = multiple_of * ((intermediate_dim + multiple_of - 1) // multiple_of) 183 | 184 | self.peft_l_ffn = nn.Parameter(torch.ones(1, 1, intermediate_dim, dtype=peft_config.peft_dtype)) 185 | 186 | def forward(self, intermediate_state): 187 | return (intermediate_state.to(self.peft_config.peft_dtype) * self.peft_l_ffn).to(self.config.dtype) 188 | -------------------------------------------------------------------------------- /pefty_llama/peft/lora.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from pefty_llama.configuration import LLaMAConfig 6 | from .configuration import PeftConfig 7 | 8 | 9 | class LoRA(nn.Module): 10 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig, 11 | input_dim: Optional[int] = None, 12 | output_dim: Optional[int] = None, 13 | ): 14 | super().__init__() 15 | self.config = config 16 | self.peft_config = peft_config 17 | 18 | if input_dim is None: 19 | input_dim = self.config.dim 20 | if output_dim is None: 
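# Like input_dim above, fall back to the model hidden size, giving the square (dim x dim) update used by the q_proj/v_proj LoRA in the attention block.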
21 | output_dim = self.config.dim 22 | self.lora_down = nn.Parameter(torch.randn(input_dim, peft_config.lora_rank, dtype=peft_config.peft_dtype)) 23 | self.lora_up = nn.Parameter(torch.zeros(peft_config.lora_rank, output_dim, dtype=peft_config.peft_dtype)) 24 | self.rank = peft_config.lora_rank 25 | self.scaling = peft_config.lora_alpha / peft_config.lora_rank 26 | 27 | def forward(self, hidden_states): 28 | hidden_states = hidden_states.to(self.peft_config.peft_dtype) 29 | lora_out = torch.einsum("ij,bsi->bsj", (self.lora_down @ self.lora_up), hidden_states) / self.rank 30 | return (hidden_states + self.scaling * lora_out).to(self.config.dtype) 31 | 32 | 33 | class LoRAEmbed(nn.Module): 34 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 35 | super().__init__() 36 | self.config = config 37 | self.peft_config = peft_config 38 | 39 | self.lora_down = nn.Parameter(torch.randn(config.vocab_size, peft_config.lora_rank, dtype=peft_config.peft_dtype)) 40 | self.lora_up = nn.Parameter(torch.zeros(peft_config.lora_rank, config.dim, dtype=peft_config.peft_dtype)) 41 | self.rank = peft_config.lora_rank 42 | self.scaling = peft_config.lora_alpha / peft_config.lora_rank 43 | 44 | def forward(self, input_ids): 45 | embedding_matrix = self.lora_down @ self.lora_up 46 | return F.embedding(input_ids, embedding_matrix) 47 | -------------------------------------------------------------------------------- /pefty_llama/peft/prefix_adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pefty_llama.configuration import LLaMAConfig 5 | from .configuration import PeftConfig 6 | 7 | 8 | class PrefixAdapter(nn.Module): 9 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 10 | super().__init__() 11 | self.config = config 12 | self.peft_config = peft_config 13 | # "batch_size"=1, num_heads, num_prefix_tokens, head_dim 14 | self.prefix_k = nn.Parameter(torch.randn( 15 | 1, config.n_heads, peft_config.num_prefix_tokens, config.head_dim, dtype=peft_config.peft_dtype)) 16 | self.prefix_v = nn.Parameter(torch.randn( 17 | 1, config.n_heads, peft_config.num_prefix_tokens, config.head_dim, dtype=peft_config.peft_dtype)) 18 | self.gate = nn.Parameter(torch.zeros(1, config.n_heads, 1, 1)) 19 | 20 | def forward(self, query_states): 21 | batch_size, num_heads, q_seq_len, head_dim = query_states.shape 22 | # "batch_size"=1, num_heads, num_prefix_tokens, head_dim 23 | prefix_k = self.prefix_k.expand(batch_size, -1, -1, -1) 24 | prefix_v = self.prefix_v.expand(batch_size, -1, -1, -1) 25 | attn_output = torch.nn.functional.scaled_dot_product_attention( 26 | query=query_states.to(self.peft_config.peft_dtype), 27 | key=prefix_k, 28 | value=prefix_v, 29 | ) 30 | return (F.tanh(self.gate) * attn_output).to(self.config.dtype) 31 | -------------------------------------------------------------------------------- /pefty_llama/peft/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pefty_llama.configuration import LLaMAConfig 4 | from .configuration import PeftConfig 5 | 6 | 7 | class SoftPrefixes(nn.Module): 8 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 9 | super().__init__() 10 | self.config = config 11 | self.peft_config = peft_config 12 | if self.peft_config.prefix_use_mlp: 13 | if self.peft_config.prefix_mlp_intermediate_size is not None: 14 | 
intermediate_size = self.peft_config.prefix_mlp_intermediate_size 15 | else: 16 | intermediate_size = self.config.dim 17 | 18 | self.initial = nn.Parameter( 19 | torch.randn(peft_config.num_prefix_tokens, config.dim, dtype=peft_config.peft_dtype) 20 | ) 21 | self.mlp = torch.nn.Sequential( 22 | torch.nn.Linear(config.dim, intermediate_size, dtype=peft_config.peft_dtype), 23 | torch.nn.Tanh(), 24 | torch.nn.Linear(intermediate_size, config.n_layers * 2 * config.dim, dtype=peft_config.peft_dtype), 25 | ) 26 | else: 27 | self.soft_prompt = nn.Parameter(torch.randn( 28 | peft_config.num_prefix_tokens, config.n_layers * 2 * config.dim, 29 | dtype=peft_config.peft_dtype 30 | )) 31 | 32 | def forward(self, batch_size): 33 | if self.peft_config.prefix_use_mlp: 34 | out = self.mlp(self.initial) 35 | else: 36 | out = self.soft_prompt 37 | # layers, k/v, num_prefix_tokens, num_heads, head_dim 38 | out = out.view(self.peft_config.num_prefix_tokens, self.config.n_layers, 2, 39 | self.config.n_heads, self.config.head_dim).to(self.config.dtype) 40 | return [ 41 | { 42 | "key": out[:, layer, 0, :, :].permute(1, 0, 2).unsqueeze(0).expand(batch_size, -1, -1, -1), 43 | "value": out[:, layer, 1, :, :].permute(1, 0, 2).unsqueeze(0).expand(batch_size, -1, -1, -1), 44 | } 45 | for layer in range(self.config.n_layers) 46 | ] 47 | -------------------------------------------------------------------------------- /pefty_llama/peft/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pefty_llama.configuration import LLaMAConfig 4 | from .configuration import PeftConfig 5 | 6 | 7 | class AddSoftPrompt(nn.Module): 8 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig): 9 | super().__init__() 10 | self.peft_config = peft_config 11 | self.soft_prompt = nn.Parameter( 12 | torch.randn(peft_config.num_prefix_tokens, config.dim, dtype=peft_config.peft_dtype) 13 | ) 14 | 15 | def forward(self, hidden_states): 16 | batch_size, seq_len, dim = hidden_states.shape 17 | soft_prompt = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1).to(hidden_states.dtype) 18 | return torch.cat([soft_prompt, hidden_states], dim=1) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | tqdm 3 | transformers 4 | accelerate 5 | bitsandbytes 6 | sentencepiece 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | with open("requirements.txt", "r") as f: 7 | requires = f.read().splitlines() 8 | 9 | setuptools.setup( 10 | name="pefty_llama", 11 | version="0.0.1", 12 | author="Vlad Lialin", 13 | author_email="vlad.lialin@gmail.com", 14 | description="Minimal implementations of multiple PEFT methods for LLaMA fine-tuning", 15 | url="https://github.com/Guitaricet/my_pefty_llama", 16 | packages=setuptools.find_packages(), 17 | install_requires=requires, 18 | ) 19 | -------------------------------------------------------------------------------- /tokenize_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import tqdm.auto as tqdm 6 | 7 | import datasets 8 | import transformers 9 | 10 | 11 | 
def read_jsonl(path): 12 | # Manually open because .splitlines is different from iterating over lines 13 | with open(path, "r") as f: 14 | for line in f: 15 | yield json.loads(line) 16 | 17 | 18 | def read_lm_dataformat(path): 19 | import lm_dataformat 20 | reader = lm_dataformat.Reader(path) 21 | yield from reader.stream_data() 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--tokenizer_path", type=str) 27 | parser.add_argument("--data_path", type=str) 28 | parser.add_argument("--data_format", type=str, default="jsonl") 29 | parser.add_argument("--save_path", type=str) 30 | parser.add_argument("--max_seq_length", type=int, default=2048) 31 | parser.add_argument("--shard_size", type=int, default=100000) 32 | args = parser.parse_args() 33 | os.makedirs(args.save_path, exist_ok=True) 34 | 35 | tokenizer = transformers.LlamaTokenizer.from_pretrained(args.tokenizer_path) 36 | 37 | all_tokenized = [] 38 | if args.data_format == "jsonl": 39 | reader = read_jsonl(args.data_path) 40 | elif args.data_format == "lm_dataformat": 41 | reader = read_lm_dataformat(args.data_path) 42 | else: 43 | raise KeyError(args.data_format) 44 | 45 | total = 0 46 | shards = 0 47 | for elem in tqdm.tqdm(reader): 48 | text = elem["text"] if args.data_format == "jsonl" else elem 49 | tokenized = tokenizer.encode(text) 50 | num_chunks = len(tokenized) // args.max_seq_length 51 | for j in range(num_chunks): 52 | chunk = tokenized[ 53 | j * args.max_seq_length: (j + 1) * args.max_seq_length 54 | ] 55 | all_tokenized.append(chunk) 56 | total += 1 57 | if len(all_tokenized) == args.shard_size: 58 | ds = datasets.Dataset.from_dict({"input_ids": all_tokenized}) 59 | ds.save_to_disk(os.path.join(args.save_path, "shard_{:05d}".format(shards))) 60 | all_tokenized = [] 61 | shards += 1 62 | 63 | if len(all_tokenized) > 0: 64 | ds = datasets.Dataset.from_dict({"input_ids": all_tokenized}) 65 | ds.save_to_disk(os.path.join(args.save_path, "shard_{:05d}".format(shards))) 66 | 67 | print(f"Generated {total} samples in {shards} shards.") 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | --------------------------------------------------------------------------------