├── .gitignore
├── LICENSE
├── README.md
├── finetune.py
├── notebooks
│   ├── 01_check_llama.ipynb
│   └── 02_ai3_llama.ipynb
├── pefty_llama
│   ├── configuration.py
│   ├── modeling.py
│   ├── modeling_peft.py
│   └── peft
│       ├── __init__.py
│       ├── adapter.py
│       ├── bitfit.py
│       ├── configuration.py
│       ├── ia3.py
│       ├── lora.py
│       ├── prefix_adapter.py
│       ├── prefix_tuning.py
│       └── prompt_tuning.py
├── requirements.txt
├── setup.py
└── tokenize_dataset.py
/.gitignore:
--------------------------------------------------------------------------------
1 | model_checkpoints
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # My PEFTy LLaMa
2 |
3 |
4 |
5 |
6 |
7 | Minimal implementations of multiple PEFT methods for LLaMA fine-tuning.
8 |
9 | # Supported methods
10 |
11 | | Method | Status | Paper |
12 | | --- | --- | --- |
13 | | (IA)3 | ✅ | [arxiv.org/abs/2205.05638](https://arxiv.org/abs/2205.05638) |
14 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import math
4 | from dataclasses import dataclass, field
5 | import tqdm.auto as tqdm
6 | from typing import Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.utils.data import Dataset
12 |
13 | import datasets
14 | import transformers
15 | from transformers import (
16 | HfArgumentParser,
17 | Trainer,
18 | TrainingArguments,
19 | )
20 | from pefty_llama.peft import PeftConfig
21 | from pefty_llama.modeling_peft import create_model, set_peft_requires_grad
22 |
23 |
24 | @dataclass
25 | class FinetuneArguments:
26 | dataset_path: str = field()
27 | hf_path: str = field()
28 | model_name: str = field(default="7b")
29 | use_8bit: bool = field(default=False)
30 |
31 |
32 | class CastOutputToFloat(nn.Sequential):
33 | def forward(self, x): return super().forward(x).to(torch.float32)
34 |
35 |
36 | def only_tunable_params(model):
37 | requires_grad = {k: v.requires_grad for k, v in model.named_parameters()}
38 | return {
39 | k: v
40 | for k, v in model.state_dict().items()
41 | if k in requires_grad and requires_grad[k]
42 | }
43 |
44 |
45 | class ModifiedTrainer(Trainer):
46 |
47 | def compute_loss(self, model, inputs, return_outputs=False):
48 | batch_size = inputs["input_ids"].shape[0]
49 |
50 | labels = inputs["input_ids"]
51 | input_ids = torch.cat([
52 | torch.ones(batch_size, 1).long().to(labels.device),
53 | inputs["input_ids"][:, :-1],
54 | ], dim=1)
55 |
 56 |         # input_ids are the labels shifted right by one (a dummy BOS token is prepended, the last token dropped), so logits[:, i] is scored against labels[:, i]
57 | logits = model(input_ids=input_ids)
58 |
59 | loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
60 | loss = loss_fct(logits.reshape(
61 | -1, logits.size(-1)), labels.reshape(-1)
62 | )
63 | if return_outputs:
64 | return loss, logits
65 | else:
66 | return loss
67 |
68 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
69 | # If we are executing this function, we are the process zero, so we don't check for that.
70 | output_dir = output_dir if output_dir is not None else self.args.output_dir
71 | os.makedirs(output_dir, exist_ok=True)
72 | torch.save(
73 | only_tunable_params(self.model),
 74 |             os.path.join(output_dir, "checkpoint.p"),
75 | )
76 |
77 | # Good practice: save your training arguments together with the trained model
78 | torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
79 |
80 | def _final_ops_before_train(self):
81 | pass
82 |
83 |
84 | def data_collator(features: list) -> dict:
85 | return {
86 | "input_ids": torch.stack([torch.LongTensor(f["input_ids"]) for f in features]),
87 | }
88 |
89 |
90 | def save_tunable_parameters(model, path):
91 | saved_params = {
92 | k: v.to("cpu")
93 | for k, v in model.named_parameters()
94 | if v.requires_grad
95 | }
96 | torch.save(saved_params, path)
97 |
98 |
99 | def main():
100 | finetune_args, peft_config, training_args = HfArgumentParser((
101 | FinetuneArguments,
102 | PeftConfig,
103 | TrainingArguments,
104 | )).parse_args_into_dataclasses()
105 |
106 | print("Setup Data")
107 | training_args.remove_unused_columns = False
108 | dataset = datasets.load_from_disk(finetune_args.dataset_path)
109 |
110 | print("Setup Model")
111 | model = create_model(
112 | model_name=finetune_args.model_name,
113 | peft_config=peft_config,
114 | hf_path=finetune_args.hf_path,
115 | use_8bit=finetune_args.use_8bit,
116 | )
117 | set_peft_requires_grad(model)
118 | if finetune_args.use_8bit:
119 | model.lm_head = CastOutputToFloat(model.lm_head)
120 | if training_args.gradient_checkpointing:
121 | print("Enabling gradient checkpointing")
122 | model.gradient_checkpointing_enable()
123 | model.enable_input_require_grads()
124 |
125 | print("Train")
126 | trainer = ModifiedTrainer(
127 | model=model,
128 | train_dataset=dataset,
129 | args=training_args,
130 | data_collator=data_collator
131 | )
132 | trainer.train()
133 | save_tunable_parameters(model, os.path.join(training_args.output_dir, "params.p"))
134 |
135 |
136 | if __name__ == "__main__":
137 | main()
138 |
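139 | # Example launch (a sketch; paths and hyperparameter values are placeholders).
140 | # FinetuneArguments supplies --dataset_path / --hf_path / --model_name / --use_8bit,
141 | # TrainingArguments supplies the usual HF trainer flags, and PEFT-specific options
142 | # (method selection etc.) come from pefty_llama.peft.PeftConfig; add those flags as
143 | # defined there. The dataset is expected to be a `datasets` dataset saved to disk
144 | # with an "input_ids" column (see tokenize_dataset.py).
145 | #
146 | #   python finetune.py \
147 | #       --dataset_path ./data/tokenized_dataset \
148 | #       --hf_path ./model_checkpoints/llama-7b-hf \
149 | #       --model_name 7b \
150 | #       --output_dir ./model_checkpoints/peft_run \
151 | #       --per_device_train_batch_size 4 \
152 | #       --learning_rate 1e-4 \
153 | #       --num_train_epochs 1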
--------------------------------------------------------------------------------
/notebooks/01_check_llama.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ['BITSANDBYTES_NOWELCOME'] = '1'"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [
18 | {
19 | "name": "stdout",
20 | "output_type": "stream",
21 | "text": [
22 | "bin /mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so\n"
23 | ]
24 | },
25 | {
26 | "name": "stderr",
27 | "output_type": "stream",
28 | "text": [
29 | "/mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
30 | " from .autonotebook import tqdm as notebook_tqdm\n",
31 | "100%|██████████| 33/33 [00:07<00:00, 4.34it/s]\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "from pefty_llama.modeling import create_model\n",
37 | "\n",
38 | "model = create_model(\"7b\", hf_path=\"../model_checkpoints/llama-7b-hf\")"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 3,
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "name": "stderr",
48 | "output_type": "stream",
49 | "text": [
50 | "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
51 | "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n",
52 | "The class this function is called from is 'LlamaTokenizer'.\n"
53 | ]
54 | }
55 | ],
56 | "source": [
57 | "from transformers import LlamaTokenizer\n",
58 | "\n",
59 | "tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\")"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 4,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "input_ids = tokenizer(\"Hello world!\", return_tensors=\"pt\").input_ids\n",
69 | "input_ids = input_ids.to(\"cuda\")\n",
70 | "\n",
71 | "output = model.generate(input_ids, generation_length=50)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 5,
77 | "metadata": {},
78 | "outputs": [
79 | {
80 | "data": {
81 | "text/plain": [
82 | "' ⁇ Hello world!, I am a student of the University of _____________. I am currently enrolled in the _____________ program. I am writing to you to request a letter of recommendation.\\nI am currently enrolled in the _____________ program at'"
83 | ]
84 | },
85 | "execution_count": 5,
86 | "metadata": {},
87 | "output_type": "execute_result"
88 | }
89 | ],
90 | "source": [
91 | "tokenizer.decode(output[0])"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": []
100 | }
101 | ],
102 | "metadata": {
103 | "kernelspec": {
104 | "display_name": "pefty_llama",
105 | "language": "python",
106 | "name": "pefty_llama"
107 | },
108 | "language_info": {
109 | "codemirror_mode": {
110 | "name": "ipython",
111 | "version": 3
112 | },
113 | "file_extension": ".py",
114 | "mimetype": "text/x-python",
115 | "name": "python",
116 | "nbconvert_exporter": "python",
117 | "pygments_lexer": "ipython3",
118 | "version": "3.10.10"
119 | },
120 | "orig_nbformat": 4
121 | },
122 | "nbformat": 4,
123 | "nbformat_minor": 2
124 | }
125 |
--------------------------------------------------------------------------------
/notebooks/02_ai3_llama.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ['BITSANDBYTES_NOWELCOME'] = '1'\n",
11 | "\n",
12 | "import torch\n",
13 | "from transformers import LlamaTokenizer\n",
14 | "from pefty_llama.modeling import create_model"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [
22 | {
23 | "name": "stderr",
24 | "output_type": "stream",
25 | "text": [
26 | "/mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
27 | " from .autonotebook import tqdm as notebook_tqdm\n"
28 | ]
29 | },
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "bin /mnt/shared_home/vlialin/miniconda3/envs/pefty_llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so\n"
35 | ]
36 | },
37 | {
38 | "name": "stderr",
39 | "output_type": "stream",
40 | "text": [
41 | "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
42 | "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n",
43 | "The class this function is called from is 'LlamaTokenizer'.\n",
44 | "100%|██████████| 33/33 [00:07<00:00, 4.38it/s]\n"
45 | ]
46 | }
47 | ],
48 | "source": [
49 | "tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\")\n",
50 | "model = create_model(\"7b\", hf_path=\"../model_checkpoints/llama-7b-hf\")"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "Total trainable parameters: 6,738,415,616\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
68 | "print(f\"Total trainable parameters: {total_trainable_params:,}\")"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 3,
74 | "metadata": {},
75 | "outputs": [
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | " ⁇ 42 is the answer first of its own kind.\n",
81 | "The 2\n"
82 | ]
83 | }
84 | ],
85 | "source": [
86 | "input_ids = tokenizer(\"42 is the answer\", return_tensors=\"pt\").input_ids\n",
87 | "input_ids = input_ids.to(\"cuda\")\n",
88 | "out1 = model.generate(input_ids, generation_length=10)\n",
89 | "print(tokenizer.decode(out1[0]))"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 4,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "from pefty_llama.peft.ia3 import IA3\n",
99 | "model = IA3(model).to(\"cuda\")"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 7,
105 | "metadata": {},
106 | "outputs": [
107 | {
108 | "name": "stdout",
109 | "output_type": "stream",
110 | "text": [
111 | " ⁇ 42 is the answer first of its own kind.\n",
112 | "The 2\n"
113 | ]
114 | }
115 | ],
116 | "source": [
117 | "out2 = model.generate(input_ids, generation_length=10)\n",
118 | "print(tokenizer.decode(out2[0]))\n",
119 | "assert torch.all(out1 == out2), \"At initialization, the model should produce the same output as the original model.\""
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 8,
125 | "metadata": {},
126 | "outputs": [
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "Total trainable parameters: 614,400\n"
132 | ]
133 | }
134 | ],
135 | "source": [
136 | "total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
137 | "print(f\"Total trainable parameters: {total_trainable_params:,}\")"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 5,
143 | "metadata": {},
144 | "outputs": [
145 | {
146 | "data": {
147 | "text/plain": [
148 | "IA3(\n",
149 | " (base_model): LLaMAModel(\n",
150 | " (model): LLaMAInnerModel(\n",
151 | " (embed_tokens): Embedding(32000, 4096)\n",
152 | " (layers): ModuleList(\n",
153 | " (0-31): 32 x LLaMALayer(\n",
154 | " (self_attn): IA3Attention(\n",
155 | " (q_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n",
156 | " (k_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n",
157 | " (v_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n",
158 | " (o_proj): NoInitLinear(in_features=4096, out_features=4096, bias=False)\n",
159 | " (rotary_emb): RotaryEmbedding()\n",
160 | " )\n",
161 | " (mlp): IA3MLP(\n",
162 | " (gate_proj): NoInitLinear(in_features=4096, out_features=11008, bias=False)\n",
163 | " (up_proj): NoInitLinear(in_features=4096, out_features=11008, bias=False)\n",
164 | " (down_proj): NoInitLinear(in_features=11008, out_features=4096, bias=False)\n",
165 | " )\n",
166 | " (input_layernorm): RMSNorm()\n",
167 | " (post_attention_layernorm): RMSNorm()\n",
168 | " )\n",
169 | " )\n",
170 | " (norm): RMSNorm()\n",
171 | " )\n",
172 | " (lm_head): NoInitLinear(in_features=4096, out_features=32000, bias=False)\n",
173 | " )\n",
174 | ")"
175 | ]
176 | },
177 | "execution_count": 5,
178 | "metadata": {},
179 | "output_type": "execute_result"
180 | }
181 | ],
182 | "source": [
183 | "model"
184 | ]
185 | }
186 | ],
187 | "metadata": {
188 | "kernelspec": {
189 | "display_name": "pefty_llama",
190 | "language": "python",
191 | "name": "pefty_llama"
192 | },
193 | "language_info": {
194 | "codemirror_mode": {
195 | "name": "ipython",
196 | "version": 3
197 | },
198 | "file_extension": ".py",
199 | "mimetype": "text/x-python",
200 | "name": "python",
201 | "nbconvert_exporter": "python",
202 | "pygments_lexer": "ipython3",
203 | "version": "3.10.10"
204 | },
205 | "orig_nbformat": 4
206 | },
207 | "nbformat": 4,
208 | "nbformat_minor": 2
209 | }
210 |
--------------------------------------------------------------------------------
/pefty_llama/configuration.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import dataclasses
3 | import torch
4 |
5 |
6 | @dataclasses.dataclass
7 | class LLaMAConfig:
8 | dim: int
9 | n_layers: int
10 | n_heads: int
11 | vocab_size: int = 32000
12 | max_seq_length: int = 2048
13 | dtype: Any = torch.float16
14 | pad_token_id: int = 0
15 | bos_token_id: int = 1
16 | eos_token_id: int = 2
17 | use_8bit: bool = False
18 | gradient_checkpointing: bool = False
19 |
20 | @property
21 | def head_dim(self):
22 | return self.dim // self.n_heads
23 |
24 | def to_dict(self):
25 | return dataclasses.asdict(self)
26 |
27 |
28 | LLAMA_7B_CONFIG = LLaMAConfig(
29 | dim=4096,
30 | n_layers=32,
31 | n_heads=32,
32 | )
33 | DEBUG_CONFIG = LLaMAConfig(
34 | dim=64,
35 | n_layers=3,
36 | n_heads=4,
37 | )
38 |
39 | LLAMA_CONFIG_DICT = {
40 | "7b": LLAMA_7B_CONFIG,
41 | "debug": DEBUG_CONFIG,
42 | }
43 |
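44 | # Example (sketch): LLAMA_CONFIG_DICT["7b"].head_dim == 4096 // 32 == 128;
45 | # LLAMA_CONFIG_DICT["debug"] is a small 3-layer, 4-head config (head_dim == 64 // 4 == 16) for quick tests.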
--------------------------------------------------------------------------------
/pefty_llama/modeling.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/zphang/minimal-llama/blob/c37e481136f118a16f77f50cdf5e867ed5dafbf9/minimal_llama/pref/llama_simple2.py
2 |
3 | import os
4 | import json
5 | import math
6 | import dataclasses
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | import bitsandbytes as bnb
13 | import tqdm.auto as tqdm
14 |
15 | from accelerate import init_empty_weights
16 | from transformers.utils.bitsandbytes import set_module_8bit_tensor_to_device
17 | from transformers import (
18 | LlamaConfig as HF_LlamaConfig,
19 | LlamaForCausalLM as HF_Llama,
20 | )
21 |
22 |
23 | @dataclasses.dataclass
24 | class LLaMAConfig:
25 | dim: int
26 | n_layers: int
27 | n_heads: int
28 | vocab_size: int = 32000
29 | max_seq_length: int = 2048
30 |     dtype: torch.dtype = torch.float16
31 | pad_token_id: int = 0
32 | bos_token_id: int = 1
33 | eos_token_id: int = 2
34 | use_8bit: bool = False
35 |
36 | @property
37 | def head_dim(self):
38 | return self.dim // self.n_heads
39 |
40 |
41 | LLAMA_7B_CONFIG = LLaMAConfig(
42 | dim=4096,
43 | n_layers=32,
44 | n_heads=32,
45 | )
46 |
47 | LLAMA_CONFIG_DICT = {
48 | "7b": LLAMA_7B_CONFIG,
49 | }
50 |
51 |
52 | class LLaMAModel(nn.Module):
53 | def __init__(self, config: LLaMAConfig):
54 | super().__init__()
55 | self.config = config
56 | self.model = LLaMAInnerModel(config)
57 | self.lm_head = NoInitLinear(config.dim, config.vocab_size, bias=False, dtype=config.dtype)
58 |
59 | @classmethod
60 | def from_pretrained(cls, model_name_or_path, use_8bit=False):
61 | """Load model from a huggingface model name or path."""
62 | hf_config = HF_LlamaConfig.from_pretrained(model_name_or_path)
63 |
64 | config = LLaMAConfig(
65 | vocab_size=hf_config.vocab_size,
66 | dim=hf_config.hidden_size,
67 | n_layers=hf_config.num_hidden_layers,
68 | n_heads=hf_config.num_attention_heads,
69 | max_seq_length=hf_config.max_position_embeddings,
70 |             dtype=hf_config.torch_dtype,
71 | pad_token_id=hf_config.pad_token_id,
72 | bos_token_id=hf_config.bos_token_id,
73 | eos_token_id=hf_config.eos_token_id,
74 | use_8bit=use_8bit,
75 | )
76 |
77 | raise NotImplementedError()
78 | model = cls(config)
79 |
80 | # Load weights from huggingface model to the disk if needed
81 | if os.path.isdir(model_name_or_path):
82 | hf_model_path = model_name_or_path
83 | else:
84 | hf_model_path = hf_config.cache_dir
85 |             hf_model = HF_Llama.from_pretrained(hf_model_path, config=hf_config)
86 | hf_model.save_pretrained(hf_model_path)
87 |
88 | return model
89 |
90 | def forward(self,
91 | input_ids):
92 | """Forward pass (with full decode sequence, intended for training or loss-scoring)
93 |
94 | :param input_ids: [batch_size, seq_len]
95 | :return: logits [batch_size, seq_len]
96 | """
97 | # 1) Create masks
98 | # decoder mask
99 | # [batch_size, num_heads=1, q_len=seq_len, kv_len=seq_len]
100 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype)
101 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids)
102 | cos, sin = self.get_cos_sin(rope_embed_ids)
103 |
104 | # 2) Forward pass
105 | # [batch_size, seq_len, hidden_dim]
106 | model_out = self.model(
107 | input_ids,
108 | attention_mask=attention_mask,
109 | cos=cos, sin=sin,
110 | )
111 | # [batch_size, seq_len, vocab_size]
112 | logits = self.lm_head(model_out["hidden_states"])
113 | return logits
114 |
115 | def init_kv_cache(self, input_ids):
116 | # noinspection GrazieInspection
117 | """Initialize KV cache for decoding.
118 |
119 | A KV cache consists of a list of dicts (one per layer):
120 | dict(
121 | key = [batch_size, num_heads, kv_seq_len=0, head_dim]
122 | value = [batch_size, num_heads, kv_seq_len=0, head_dim]
123 | )
124 |
125 | :param input_ids: [batch_size, dec_seq_len]
126 | :return: 0-length kv_cache
127 | """
128 | kv_cache = []
129 | batch_size = input_ids.shape[0]
130 | num_heads = self.config.n_heads
131 | head_dim = self.config.head_dim
132 | for layer in self.model.layers:
133 | device = layer.input_layernorm.weight.device
134 | kv_cache.append({
135 | "key": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype),
136 | "value": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype),
137 | })
138 | return kv_cache
139 |
140 |     def generate(self, input_ids, generation_length: int = 20):
141 | """Generate tokens with efficient caching of KV.
142 |
143 | TODO: Add stopping conditions
144 | TODO: Add sampling capabilities
145 |
146 | :param input_ids: [batch_size, enc_seq_len]
147 | :param generation_length: int
148 | :return: [batch_size, generation_length]
149 | """
150 | original_input_ids = input_ids
151 | batch_size, seq_len = input_ids.shape
152 | # noinspection PyUnresolvedReferences
153 | num_valid_tokens = (input_ids != self.config.pad_token_id).long().sum(dim=1)
154 |
155 | # 1) Setup
156 | if input_ids is None:
157 | # [batch_size, dec_seq_len=1]
158 | input_ids = torch.LongTensor(
159 | [[self.config.pad_token_id]] * batch_size
160 |             ).to(self.lm_head.weight.device)
161 | # See: init_kv_cache. list[dict]
162 | kv_cache = self.init_kv_cache(input_ids)
163 | generated_token_ids_list = [original_input_ids]
164 | total_seq_len = seq_len
165 |
166 | # 2) First encoding
167 |         # [batch_size=1, num_heads=1, q_len=seq_len, kv_len=seq_len]
168 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype)
169 | # dict(
170 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim]
171 | # kv_cache = list[dict(
172 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
173 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
174 | # )]
175 | # )
176 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids)
177 | cos, sin = self.get_cos_sin(rope_embed_ids)
178 | model_out = self.model(
179 | input_ids=input_ids,
180 | attention_mask=attention_mask,
181 | cos=cos, sin=sin,
182 | kv_cache=kv_cache,
183 | )
184 | logits = self.lm_head(model_out["hidden_states"])
185 | kv_cache = model_out["kv_cache"]
186 | generated_token_ids = logits.argmax(-1)[
187 | torch.arange(batch_size, dtype=torch.long, device=input_ids.device),
188 | num_valid_tokens-1,
189 | ][:, None]
190 | generated_token_ids_list.append(generated_token_ids)
191 | input_ids = generated_token_ids
192 |
193 | # 2.1 shift KV cache
194 | for layer_kv_cache in kv_cache:
195 |             # shift_kv_cache_right handles the full batch internally, so no per-example loop is needed
196 |             layer_kv_cache["key"] = shift_kv_cache_right(
197 |                 layer_kv_cache["key"], num_valid_tokens=num_valid_tokens)
198 |             layer_kv_cache["value"] = shift_kv_cache_right(
199 |                 layer_kv_cache["value"], num_valid_tokens=num_valid_tokens)
200 |
201 | # 3) Subsequent steps
202 | for decode_step in range(generation_length-1):
203 | num_valid_tokens += 1
204 | total_seq_len += 1
205 |             # [batch_size, num_heads=1, q_len=1, kv_len=total_seq_len]
206 | attention_mask = convert_mask_to_soft_mask(create_generation_attention_mask(
207 | batch_size=batch_size,
208 | seq_len=total_seq_len,
209 | num_valid_tokens=num_valid_tokens,
210 | device=input_ids.device,
211 | ), dtype=self.config.dtype)
212 | # dict(
213 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim]
214 | # kv_cache = list[dict(
215 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
216 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
217 | # )]
218 | # )
219 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) + num_valid_tokens
220 | cos, sin = self.get_cos_sin(rope_embed_ids)
221 | model_out = self.model(
222 | input_ids=input_ids,
223 | attention_mask=attention_mask,
224 | kv_cache=kv_cache,
225 | cos=cos, sin=sin,
226 | )
227 | # [batch_size, dec_seq_len=1, vocab_size]
228 | logits = self.lm_head(model_out["hidden_states"])
229 | kv_cache = model_out["kv_cache"]
230 | # [batch_size, dec_seq_len=1]
231 | generated_token_ids = logits.argmax(-1)[:, -1:]
232 | generated_token_ids_list.append(generated_token_ids)
233 | input_ids = generated_token_ids
234 | return torch.cat(generated_token_ids_list, dim=1)
235 |
236 | def get_cos_sin(self, rope_embed_ids):
237 | cos = F.embedding(
238 | rope_embed_ids,
239 | self.model.layers[0].self_attn.rotary_emb.cos_cached[0, 0]
240 | ).to(self.config.dtype)
241 | sin = F.embedding(
242 | rope_embed_ids,
243 | self.model.layers[0].self_attn.rotary_emb.sin_cached[0, 0]
244 | ).to(self.config.dtype)
245 | cos, sin = cos[:, None, :, :], sin[:, None, :, :]
246 | return cos, sin
247 |
248 |
249 | class LLaMAInnerModel(nn.Module):
250 | def __init__(self, config: LLaMAConfig):
251 | super().__init__()
252 | self.config = config
253 | self.embed_tokens = nn.Embedding(config.vocab_size, config.dim, dtype=config.dtype)
254 | self.layers = nn.ModuleList([
255 | LLaMALayer(config=config)
256 | for _ in range(config.n_layers)
257 | ])
258 | self.norm = RMSNorm(dim=config.dim)
259 |
260 | def forward(self,
261 | input_ids,
262 | attention_mask,
263 | cos, sin,
264 | kv_cache=None):
265 | """
266 | :param input_ids: [batch_size, seq_len]
267 | :param attention_mask: [batch_size=1, num_heads=1, seq_len, seq_len]
268 | :param kv_cache: See init_kv_cache.
269 | We use the presence of kv_cache to determine if we're generating
270 | :param cos:
271 | :param sin:
272 | """
273 | hidden_states = self.embed_tokens(input_ids)
274 |
275 | new_kv_cache = []
276 | for layer_i, layer in enumerate(self.layers):
277 | if kv_cache:
278 | # dict(
279 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
280 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
281 | # )
282 | layer_kv_cache = kv_cache[layer_i]
283 | else:
284 | layer_kv_cache = None
285 |
286 | layer_out = layer(
287 | hidden_states=hidden_states,
288 | attention_mask=attention_mask,
289 | kv_cache=layer_kv_cache,
290 | cos=cos, sin=sin,
291 | )
292 | hidden_states = layer_out["hidden_states"]
293 | if kv_cache:
294 | new_kv_cache.append(layer_out["kv_cache"])
295 | hidden_states = self.norm(hidden_states)
296 | output = {
297 | "hidden_states": hidden_states
298 | }
299 | if kv_cache:
300 | output["kv_cache"] = new_kv_cache
301 | return output
302 |
303 |
304 | class LLaMALayer(nn.Module):
305 | def __init__(self, config: LLaMAConfig):
306 | super().__init__()
307 | self.config = config
308 | self.self_attn = Attention(config=config)
309 | self.mlp = MLP(config=config)
310 | self.input_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)
311 | self.post_attention_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)
312 |
313 | def forward(
314 | self,
315 | hidden_states,
316 | attention_mask,
317 | cos, sin,
318 | kv_cache=None,
319 | ):
320 | # 1) Self-attention
321 | # [batch_size, seq_len, hidden_dim]
322 | normed_hidden_states = self.input_layernorm(hidden_states)
323 | # dict(
324 | # attn_output = [batch_size, seq_len, hidden_dim]
325 | # kv_cache = dict(
326 | # key = [batch_size, num_heads, kv_seq_len, head_dim]
327 | # value = [batch_size, num_heads, kv_seq_len, head_dim]
328 | # )
329 | # )
330 | check_nan(normed_hidden_states)
331 | raw_self_attn_output = self.self_attn(
332 | hidden_states=normed_hidden_states,
333 | attention_mask=attention_mask,
334 | kv_cache=kv_cache,
335 | cos=cos, sin=sin,
336 | )
337 | # [batch_size, seq_len, hidden_dim]
338 | hidden_states = hidden_states + raw_self_attn_output["attn_output"]
339 | check_nan(hidden_states)
340 | # 2) FFN
341 | # [batch_size, seq_len, hidden_dim]
342 | hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
343 | check_nan(hidden_states)
344 | if kv_cache:
345 | return {
346 | "hidden_states": hidden_states,
347 | "kv_cache": raw_self_attn_output["kv_cache"],
348 | }
349 |
350 | return {"hidden_states": hidden_states}
351 |
352 |
353 | class MLP(nn.Module):
354 | def __init__(
355 | self,
356 | config: LLaMAConfig,
357 | multiple_of: int = 256,
358 | ):
359 | super().__init__()
360 | dim = config.dim
361 | hidden_dim = 4 * dim
362 | hidden_dim = int(2 * hidden_dim / 3)
363 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
364 |
365 | if config.use_8bit:
366 | self.gate_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False)
367 | self.up_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False)
368 | self.down_proj = NoInit8bitLinear(hidden_dim, dim, bias=False, threshold=6.0, has_fp16_weights=False)
369 | else:
370 | self.gate_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype)
371 | self.up_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype)
372 | self.down_proj = NoInitLinear(hidden_dim, dim, bias=False, dtype=config.dtype)
373 |
374 | def forward(self, x):
375 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
376 |
377 |
378 | class RMSNorm(torch.nn.Module):
379 | def __init__(self, dim: int, eps: float = 1e-6, dtype=torch.float16):
380 | super().__init__()
381 | self.eps = eps
382 | self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
383 |
384 | def _norm(self, x):
385 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
386 |
387 | def forward(self, x):
388 | output = self._norm(x.float()).type_as(x)
389 | return output * self.weight
390 |
391 |
392 | class Attention(nn.Module):
393 | def __init__(self, config: LLaMAConfig):
394 | super().__init__()
395 | self.config = config
396 | self.n_heads = config.n_heads
397 | self.head_dim = config.dim // config.n_heads
398 |
399 | if config.use_8bit:
400 | self.q_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
401 | self.k_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
402 | self.v_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
403 | self.o_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
404 | else:
405 | self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
406 | self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
407 | self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
408 | self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
409 | self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
410 |
411 | def forward(self, hidden_states, attention_mask, cos, sin, kv_cache=None):
412 | """
413 | precomputed_kv_hidden_states is for init (pre-compute KV activations, e.g. for added prefixes)
414 | kv_cache is for generation (cached past KV)
415 | """
416 | batch_size, q_seq_len, hidden_dim = hidden_states.size()
417 |
418 | # (batch_size, num_heads, q_seq_len, head_dim)
419 | query_states = self.q_proj(hidden_states).view(
420 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
421 | key_states = self.k_proj(hidden_states).view(
422 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
423 | value_states = self.v_proj(hidden_states).view(
424 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
425 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos=cos, sin=sin)
426 | if kv_cache:
427 | key_states = torch.cat([kv_cache["key"], key_states], dim=2)
428 | value_states = torch.cat([kv_cache["value"], value_states], dim=2)
429 |
430 | attn_output = torch.nn.functional.scaled_dot_product_attention(
431 | query=query_states,
432 | key=key_states,
433 | value=value_states,
434 | attn_mask=attention_mask,
435 | )
436 | # (batch_size, q_seq_len, hidden_dim)
437 | attn_output = attn_output.transpose(1, 2).contiguous().view(
438 | batch_size, q_seq_len, hidden_dim,
439 | )
440 | attn_output = self.o_proj(attn_output)
441 | check_nan(attn_output)
442 | if kv_cache:
443 | new_kv_cache = {"key": key_states, "value": value_states}
444 | return {"attn_output": attn_output, "kv_cache": new_kv_cache}
445 |
446 | return {"attn_output": attn_output}
447 |
448 |
449 | class RotaryEmbedding(torch.nn.Module):
450 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
451 | super().__init__()
452 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device=device) / dim))
453 | self.register_buffer("inv_freq", inv_freq)
454 |
455 | # Build here to make `torch.jit.trace` work.
456 | self.max_seq_len_cached = max_position_embeddings
457 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device).to(self.inv_freq.dtype)
458 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
459 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
460 | emb = torch.cat((freqs, freqs), dim=-1)
461 | self.cos_cached = emb.cos()[None, None, :, :]
462 | self.sin_cached = emb.sin()[None, None, :, :]
463 |
464 | def forward(self, x, seq_len=None):
465 | # x: [bs, num_attention_heads, seq_len, head_size]
466 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
467 | if seq_len > self.max_seq_len_cached:
468 | self.max_seq_len_cached = seq_len
469 | t = torch.arange(self.max_seq_len_cached, device=x.device).to(self.inv_freq.dtype)
470 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
471 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
472 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
473 | self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype)
474 | self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype)
475 | return (
476 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
477 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
478 | )
479 |
480 |
481 | def rotate_half(x):
482 | """Rotates half the hidden dims of the input."""
483 | x1 = x[..., : x.shape[-1] // 2]
484 | x2 = x[..., x.shape[-1] // 2:]
485 | return torch.cat((-x2, x1), dim=-1)
486 |
487 |
488 | def apply_rotary_pos_emb(q, k, cos, sin):
489 | q_embed = (q * cos) + (rotate_half(q) * sin)
490 | k_embed = (k * cos) + (rotate_half(k) * sin)
491 | return q_embed, k_embed
492 |
493 |
494 | def create_attention_mask(input_ids,
495 | dtype=torch.float32,
496 | return_soft_mask=True):
497 | """Create mask for decoder attention.
498 |
499 | Decoder masks have two use-cases:
500 |
501 | 1) Training, where we see the full decoder sequence. In that case,
502 | we want a causal mask.
503 |
504 | 2) Generation, where we only see one token at once. In that case,
505 | it doesn't really matter what we give, we can just give a 1.
506 | (i.e. seq_len = 1)
507 |
508 | Note that in both cases we do not care about which decoder_input_ids
509 | are valid, and also we can always simply broadcast over the batch size
510 | and heads.
511 |
512 | :param input_ids: [batch_size, seq_len]
513 | :param dtype: dtype
514 | :param return_soft_mask: whether to return mask or logits-mask
515 | :return: float [batch_size=1, num_heads=1, q_len=seq_len, kv_len=seq_len]
516 | """
517 | batch_size, seq_length = input_ids.shape
518 | # [seq_len]
519 | seq_ids = torch.arange(seq_length, device=input_ids.device)
520 | # [seq_len, seq_len]
521 | causal_mask = seq_ids[None, :].repeat(seq_length, 1) <= seq_ids[:, None]
522 | # [batch_size=1, num_heads=1, seq_len, seq_len]
523 | causal_mask = causal_mask[None, None, :, :]
524 | if return_soft_mask:
525 | return convert_mask_to_soft_mask(causal_mask, dtype=dtype)
526 | else:
527 | return causal_mask
528 |
529 |
530 | def convert_mask_to_soft_mask(mask, dtype):
531 | """Convert binary mask to mask that can be added to logits.
532 |
533 | (i.e. 0 for attention, large negative for masked)
534 | """
535 | mask = mask.to(dtype=dtype)
536 | mask = (1.0 - mask) * torch.finfo(dtype).min
537 | return mask
538 |
539 |
540 | class NoInitLinear(nn.Linear):
541 | def reset_parameters(self) -> None:
542 | pass
543 |
544 |
545 | class NoInit8bitLinear(bnb.nn.Linear8bitLt):
546 | def reset_parameters(self) -> None:
547 | pass
548 |
549 |
550 | def get_linear_class(use_8bit=False):
551 | if use_8bit:
552 | return NoInit8bitLinear
553 | else:
554 | return NoInitLinear
555 |
556 |
557 | class NoInitEmbedding(nn.Embedding):
558 | def reset_parameters(self) -> None:
559 | pass
560 |
561 |
562 | def check_nan(x):
563 | if torch.isnan(x).any():
564 | import pdb
565 | pdb.set_trace()
566 |
567 |
568 | def create_model(model_name, hf_path, use_8bit=False, device=None):
569 | config = LLAMA_CONFIG_DICT[model_name]
570 |
571 | with open(os.path.join(hf_path, "pytorch_model.bin.index.json")) as f:
572 | weight_map = json.load(f)["weight_map"]
573 |
574 | filename_list = sorted(list(set(weight_map.values())))
575 | if device is None:
576 | # TODO: Local rank
577 | device = torch.device("cuda:0")
578 | if use_8bit:
579 | config = dataclasses.replace(config, use_8bit=True)
580 | with init_empty_weights():
581 | model = LLaMAModel(config=config)
582 | state_keys = set(model.state_dict())
583 | filename_list = sorted(list(set(weight_map.values())))
584 | for filename in tqdm.tqdm(filename_list):
585 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu")
586 | for k, v in loaded.items():
587 | set_module_8bit_tensor_to_device(model, tensor_name=k, device=device, value=v)
588 | state_keys.remove(k)
589 | assert not state_keys
590 | else:
591 | # noinspection PyUnresolvedReferences
592 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
593 | model = LLaMAModel(config=config).cuda()
594 | torch.set_default_tensor_type(torch.FloatTensor)
595 | state_keys = set(model.state_dict())
596 | for filename in tqdm.tqdm(filename_list):
597 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu")
598 | model.load_state_dict(loaded, strict=False)
599 | for k in loaded:
600 | state_keys.remove(k)
601 | return model
602 |
603 |
604 | def shift_kv_cache_right(layer_cache, num_valid_tokens):
605 | """
606 | :param layer_cache: left-aligned kv cache element, [batch_size, num_heads, seq_len, dim]
607 | :param num_valid_tokens: [batch_size]
608 | :return:
609 | """
610 | batch_size = layer_cache.shape[0]
611 | # noinspection PyUnresolvedReferences
612 | return torch.stack([
613 | torch.cat([
614 | layer_cache[i, :, num_valid_tokens[i]:, :],
615 | layer_cache[i, :, :num_valid_tokens[i], :],
616 | ], dim=1)
617 | for i in range(batch_size)
618 | ], dim=0)
619 |
620 |
621 | def create_generation_attention_mask(batch_size, seq_len, num_valid_tokens, device):
622 | """
623 | :param batch_size: int
624 | :param seq_len: int
625 | :param num_valid_tokens: [batch_size]
626 | :param device:
627 | :return:
628 | """
629 | # For right-aligned, based on num_valid_tokens
630 | # noinspection PyTypeChecker
631 | attn_mask = torch.zeros([batch_size, 1, 1, seq_len], dtype=bool)
632 | for i in range(batch_size):
633 | valid = num_valid_tokens[i]
634 | # noinspection PyTypeChecker
635 | # attn_mask[i, 0, -valid:, -valid:] = torch.tril(torch.ones([valid, valid], dtype=bool))
636 | attn_mask[i, 0, 0, -valid:] = True
637 | return attn_mask.to(device=device)
638 |
639 |
640 | def create_casual_attention_mask(seq_len, device):
641 | # noinspection PyTypeChecker
642 | attn_mask = torch.tril(torch.ones([seq_len, seq_len], dtype=bool))[None, None, :, :]
643 | return attn_mask.to(device=device)
644 |
645 |
646 | def create_rope_embed_ids(input_ids):
647 | pad_token_id = 0
648 | max_position = 2047 # These will not actually be used, as they are masked out by the attention mask
649 | x = (input_ids != pad_token_id).cumsum(-1) - 1
650 | x[input_ids == pad_token_id] = max_position
651 | return x
652 |
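653 | # Worked example (sketch): for input_ids = [[0, 0, 5, 6, 7]] with pad_token_id = 0,
654 | # (input_ids != 0).cumsum(-1) - 1 gives [[-1, -1, 0, 1, 2]], and the pad positions are then
655 | # overwritten with max_position, yielding rope ids [[2047, 2047, 0, 1, 2]]: real tokens are
656 | # numbered from 0 and pads are parked at a position the attention mask ignores.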
--------------------------------------------------------------------------------
/pefty_llama/modeling_peft.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/zphang/minimal-llama/blob/c37e481136f118a16f77f50cdf5e867ed5dafbf9/minimal_llama/pref/llama_simple2.py
2 |
3 | import os
4 | import json
5 | import math
6 | import dataclasses
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | import bitsandbytes as bnb
13 | import tqdm.auto as tqdm
14 |
15 | from accelerate import init_empty_weights
16 | from transformers.utils.bitsandbytes import set_module_8bit_tensor_to_device
17 | from transformers import (
18 | LlamaConfig as HF_LlamaConfig,
19 | LlamaForCausalLM as HF_Llama,
20 | )
21 | import pefty_llama.peft as peft
22 | from pefty_llama.configuration import LLaMAConfig, LLAMA_CONFIG_DICT
23 |
24 |
25 | class LLaMAModel(nn.Module):
26 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig):
27 | super().__init__()
28 | self.config = config
29 | self.peft_config = peft_config
30 | self.model = LLaMAInnerModel(config=config, peft_config=peft_config)
31 | self.lm_head = NoInitLinear(config.dim, config.vocab_size, bias=False, dtype=config.dtype)
32 |
33 | if self.peft_config.peft_mode == peft.PEFT_PREFIX:
34 | self.peft_prefixes = peft.SoftPrefixes(config=config, peft_config=peft_config)
35 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding:
36 | self.peft_lora_lm_head = peft.LoRA(config=config, peft_config=peft_config,
37 | output_dim=config.vocab_size)
38 |
39 | def forward(self,
40 | input_ids):
41 | """Forward pass (with full decode sequence, intended for training or loss-scoring)
42 |
43 | :param input_ids: [batch_size, seq_len]
44 | :return: logits [batch_size, seq_len]
45 | """
46 | # 1) Create masks
47 | # decoder mask
48 | # [batch_size, num_heads=1, q_len=seq_len, kv_len=seq_len]
49 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype)
50 | input_ids_for_rope = input_ids
51 | if self.peft_config.peft_mode == peft.PEFT_PREFIX:
52 | attention_mask = torch.cat([
53 | zeros_like([1, 1, input_ids.shape[1], self.peft_config.num_prefix_tokens], tensor=attention_mask),
54 | attention_mask,
55 | ], dim=3)
56 |
57 |         if self.peft_config.peft_mode == peft.PEFT_PROMPT:
58 | input_ids_for_rope = torch.cat([
59 | torch.ones([input_ids.shape[0], self.peft_config.num_prefix_tokens],
60 | dtype=input_ids.dtype, device=input_ids.device),
61 | input_ids,
62 | ], dim=1)
63 | # Easier to just remake the attention mask
64 | attention_mask = create_attention_mask(input_ids=input_ids_for_rope, dtype=self.config.dtype)
65 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids_for_rope)
66 | cos, sin = self.get_cos_sin(rope_embed_ids)
67 |
68 | if self.peft_config.peft_mode == peft.PEFT_PREFIX:
69 | kv_cache = self.peft_prefixes(batch_size=input_ids.shape[0])
70 | else:
71 | kv_cache = None
72 |
73 | # 2) Forward pass
74 | # [batch_size, seq_len, hidden_dim]
75 | model_out = self.model(
76 | input_ids,
77 | attention_mask=attention_mask,
78 | cos=cos, sin=sin,
79 | kv_cache=kv_cache,
80 | )
81 | # [batch_size, seq_len, vocab_size]
82 | logits = self.lm_head(model_out["hidden_states"])
83 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding:
84 | logits += self.peft_lora_lm_head(model_out["hidden_states"])
85 | return logits
86 |
87 | def init_kv_cache(self, input_ids):
88 | # noinspection GrazieInspection
89 | """Initialize KV cache for decoding.
90 |
91 | A KV cache consists of a list of dicts (one per layer):
92 | dict(
93 | key = [batch_size, num_heads, kv_seq_len=0, head_dim]
94 | value = [batch_size, num_heads, kv_seq_len=0, head_dim]
95 | )
96 |
97 | :param input_ids: [batch_size, dec_seq_len]
98 | :return: 0-length kv_cache
99 | """
100 | kv_cache = []
101 | batch_size = input_ids.shape[0]
102 | num_heads = self.config.n_heads
103 | head_dim = self.config.head_dim
104 | for layer in self.model.layers:
105 | device = layer.input_layernorm.weight.device
106 | kv_cache.append({
107 | "key": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype),
108 | "value": torch.zeros([batch_size, num_heads, 0, head_dim]).to(device=device, dtype=self.config.dtype),
109 | })
110 | return kv_cache
111 |
112 | def generate(self, input_ids, generation_length: int = 20,
113 | return_output_only=True):
114 | """Generate tokens with efficient caching of KV.
115 |
116 | TODO: Add stopping conditions
117 | TODO: Add sampling capabilities
118 |
119 | :param input_ids: [batch_size, enc_seq_len]
120 | :param generation_length: int
121 | :param return_output_only:
122 | :return: [batch_size, generation_length]
123 | """
124 | original_input_ids = input_ids
125 | batch_size, seq_len = input_ids.shape
126 | # noinspection PyUnresolvedReferences
127 | num_valid_tokens = (input_ids != self.config.pad_token_id).long().sum(dim=1)
128 |
129 | # 1) Setup
130 | if input_ids is None:
131 | # [batch_size, dec_seq_len=1]
132 | input_ids = torch.LongTensor(
133 | [[self.config.pad_token_id]] * batch_size
134 |             ).to(self.lm_head.weight.device)
135 | # See: init_kv_cache. list[dict]
136 | if self.peft_config.peft_mode == peft.PEFT_PREFIX:
137 | kv_cache = self.peft_prefixes(batch_size=input_ids.shape[0])
138 | num_valid_kv_cache = num_valid_tokens + self.peft_config.num_prefix_tokens
139 | else:
140 | kv_cache = self.init_kv_cache(input_ids)
141 | num_valid_kv_cache = num_valid_tokens
142 | generated_token_ids_list = [original_input_ids]
143 | total_seq_len = seq_len
144 |
145 | # 2) First encoding
146 | # [batch_size=1, num_heads=1, q_len=1, kv_len=1]
147 | attention_mask = create_attention_mask(input_ids=input_ids, dtype=self.config.dtype)
148 | input_ids_for_rope = input_ids
149 | # dict(
150 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim]
151 | # kv_cache = list[dict(
152 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
153 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
154 | # )]
155 | # )
156 | if self.peft_config.peft_mode in (peft.PEFT_PREFIX, peft.PEFT_PROMPT):
157 | num_prefix_tokens = self.peft_config.num_prefix_tokens
158 | total_seq_len += num_prefix_tokens
159 | # [batch_size, num_heads=1, q_len=seq_len, kv_len=num_prefix_tokens + dec_seq_len]
160 | attention_mask = torch.cat([
161 | zeros_like([1, 1, input_ids.shape[1], num_prefix_tokens], tensor=attention_mask),
162 | attention_mask,
163 | ], dim=3)
164 |
165 |             if self.peft_config.peft_mode == peft.PEFT_PROMPT:
166 | input_ids_for_rope = torch.cat([
167 | torch.ones([input_ids.shape[0], self.peft_config.num_prefix_tokens],
168 | dtype=input_ids.dtype, device=input_ids.device),
169 | input_ids,
170 | ], dim=1)
171 | # Easier to just remake the attention mask
172 | attention_mask = create_attention_mask(input_ids=input_ids_for_rope, dtype=self.config.dtype)
173 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids_for_rope)
174 | cos, sin = self.get_cos_sin(rope_embed_ids)
175 | model_out = self.model(
176 | input_ids=input_ids,
177 | attention_mask=attention_mask,
178 | cos=cos, sin=sin,
179 | kv_cache=kv_cache,
180 | )
181 | logits = self.lm_head(model_out["hidden_states"])
182 | kv_cache = model_out["kv_cache"]
183 | generated_token_ids = logits.argmax(-1)[
184 | torch.arange(batch_size, dtype=torch.long, device=input_ids.device),
185 | num_valid_tokens-1,
186 | ][:, None]
187 | generated_token_ids_list.append(generated_token_ids)
188 | input_ids = generated_token_ids
189 |
190 | # 3) Subsequent steps
191 | for decode_step in range(generation_length-1):
192 | num_valid_tokens += 1
193 | total_seq_len += 1
194 | # [batch_size=1, num_heads=1, q_len=1, kv_len=1]
195 | attention_mask = convert_mask_to_soft_mask(create_generation_attention_mask(
196 | batch_size=batch_size,
197 | seq_len=total_seq_len,
198 | num_valid_tokens=num_valid_tokens,
199 | device=input_ids.device,
200 | ), dtype=self.config.dtype)
201 | # dict(
202 | # hidden_states = [batch_size, dec_seq_len=decode_step+1, hidden_dim]
203 | # kv_cache = list[dict(
204 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
205 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
206 | # )]
207 | # )
208 | rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) + num_valid_tokens[:, None]
209 | cos, sin = self.get_cos_sin(rope_embed_ids)
210 | model_out = self.model(
211 | input_ids=input_ids,
212 | attention_mask=attention_mask,
213 | kv_cache=kv_cache,
214 | cos=cos, sin=sin,
215 | )
216 | # [batch_size, dec_seq_len=1, vocab_size]
217 | logits = self.lm_head(model_out["hidden_states"])
218 | kv_cache = model_out["kv_cache"]
219 | # [batch_size, dec_seq_len=1]
220 | generated_token_ids = logits.argmax(-1)[:, -1:]
221 | generated_token_ids_list.append(generated_token_ids)
222 | input_ids = generated_token_ids
223 | output = torch.cat(generated_token_ids_list, dim=1)
224 | if return_output_only:
225 | output = output[:, seq_len:]
226 | return output
227 |
228 | def get_cos_sin(self, rope_embed_ids):
229 | cos = F.embedding(
230 | rope_embed_ids,
231 | self.model.layers[0].self_attn.rotary_emb.cos_cached[0, 0].to(rope_embed_ids.device)
232 | ).to(self.config.dtype)
233 | sin = F.embedding(
234 | rope_embed_ids,
235 | self.model.layers[0].self_attn.rotary_emb.sin_cached[0, 0].to(rope_embed_ids.device)
236 | ).to(self.config.dtype)
237 | cos, sin = cos[:, None, :, :], sin[:, None, :, :]
238 | return cos, sin
239 |
240 | def gradient_checkpointing_enable(self):
241 | self.config.gradient_checkpointing = True
242 |
243 | def enable_input_require_grads(self):
244 | def make_inputs_require_grads(module, input, output):
245 | output.requires_grad_(True)
246 | self.model.embed_tokens.register_forward_hook(make_inputs_require_grads)
247 |
248 |
249 | class LLaMAInnerModel(nn.Module):
250 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig):
251 | super().__init__()
252 | self.config = config
253 | self.peft_config = peft_config
254 | self.embed_tokens = nn.Embedding(config.vocab_size, config.dim, dtype=config.dtype)
255 | self.layers = nn.ModuleList([
256 | LLaMALayer(config=config, peft_config=peft_config)
257 | for _ in range(config.n_layers)
258 | ])
259 |         self.norm = RMSNorm(dim=config.dim, dtype=config.dtype)
260 |
261 | if self.peft_config.peft_mode == peft.PEFT_PROMPT:
262 | self.peft_prompt = peft.AddSoftPrompt(config=config, peft_config=peft_config)
263 |
264 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding:
265 | self.peft_lora_embed = peft.LoRAEmbed(config=config, peft_config=peft_config)
266 |
267 | def forward(self,
268 | input_ids,
269 | attention_mask,
270 | cos, sin,
271 | kv_cache=None):
272 | """
273 | :param input_ids: [batch_size, seq_len]
274 | :param attention_mask: [batch_size=1, num_heads=1, seq_len, seq_len]
275 | :param cos: for RoPE
276 | :param sin: for RoPE
277 | :param kv_cache: See init_kv_cache.
278 | """
279 | hidden_states = self.embed_tokens(input_ids).to(self.config.dtype)
280 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_embedding:
281 | hidden_states += self.peft_lora_embed(input_ids).to(self.config.dtype)
282 |
283 | if self.peft_config.peft_mode == peft.PEFT_PROMPT:
284 | if kv_cache is None or kv_cache[0]["key"].shape[2] == 0:
285 | # Only add prompt if kv_cache is None (full forward pass) or if kv_cache is empty (first decode step)
286 | hidden_states = self.peft_prompt(hidden_states)
287 |
288 | new_kv_cache = []
289 | for layer_i, layer in enumerate(self.layers):
290 | if kv_cache:
291 | # dict(
292 | # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
293 | # value = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]
294 | # )
295 | layer_kv_cache = kv_cache[layer_i]
296 | else:
297 | layer_kv_cache = None
298 |
299 | if self.config.gradient_checkpointing:
300 | layer_out = torch.utils.checkpoint.checkpoint(
301 | layer,
302 | hidden_states,
303 | attention_mask,
304 | cos, sin,
305 | layer_kv_cache,
306 | )
307 | else:
308 | layer_out = layer(
309 | hidden_states=hidden_states,
310 | attention_mask=attention_mask,
311 | cos=cos, sin=sin,
312 | kv_cache=layer_kv_cache,
313 | )
314 | hidden_states, out_layer_kv_cache = layer_out
315 | if kv_cache:
316 | new_kv_cache.append(out_layer_kv_cache)
317 | hidden_states = self.norm(hidden_states)
318 | output = {
319 | "hidden_states": hidden_states
320 | }
321 | if kv_cache:
322 | output["kv_cache"] = new_kv_cache
323 | return output
324 |
325 |
326 | class LLaMALayer(nn.Module):
327 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig):
328 | super().__init__()
329 | self.config = config
330 | self.peft_config = peft_config
331 | self.self_attn = Attention(config=config, peft_config=peft_config)
332 | self.mlp = MLP(config=config, peft_config=peft_config)
333 | self.input_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)
334 | self.post_attention_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)
335 |
336 | if self.peft_config.peft_mode == peft.PEFT_ADAPTER:
337 |             if self.peft_config.adapter_version == peft.ADAPTER_VERSION_HOULSBY:
338 |                 self.peft_adapter_attn = peft.Adapter(config=config, peft_config=peft_config)
339 |             self.peft_adapter_mlp = peft.Adapter(config=config, peft_config=peft_config)
340 |
341 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
342 | self.peft_input_layernorm_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config)
343 | self.peft_post_attention_layernorm_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config)
344 |
345 | def forward(
346 | self,
347 | hidden_states,
348 | attention_mask,
349 | cos, sin,
350 | kv_cache=None,
351 | ):
352 | # 1) Self-attention
353 | # [batch_size, seq_len, hidden_dim]
354 | normed_hidden_states = self.input_layernorm(hidden_states).to(self.config.dtype)
355 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
356 | normed_hidden_states = self.peft_input_layernorm_bias(normed_hidden_states)
357 | # dict(
358 | # attn_output = [batch_size, seq_len, hidden_dim]
359 | # kv_cache = dict(
360 | # key = [batch_size, num_heads, kv_seq_len, head_dim]
361 | # value = [batch_size, num_heads, kv_seq_len, head_dim]
362 | # )
363 | # )
364 | check_nan(normed_hidden_states)
365 | raw_self_attn_output = self.self_attn(
366 | hidden_states=normed_hidden_states,
367 | attention_mask=attention_mask,
368 | kv_cache=kv_cache,
369 | cos=cos, sin=sin,
370 | )
371 | # [batch_size, seq_len, hidden_dim]
372 | attn_out = raw_self_attn_output["attn_output"]
373 | if self.peft_config.peft_mode == peft.PEFT_ADAPTER \
374 | and self.peft_config.adapter_version == peft.ADAPTER_VERSION_HOULSBY:
375 | attn_out = self.peft_adapter_attn(attn_out)
376 |
377 | # [batch_size, seq_len, hidden_dim]
378 | hidden_states = hidden_states + attn_out
379 | check_nan(hidden_states)
380 | # 2) FFN
381 | # [batch_size, seq_len, hidden_dim]
382 | post_normed_hidden_states = self.post_attention_layernorm(hidden_states)
383 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
384 | post_normed_hidden_states = self.peft_post_attention_layernorm_bias(post_normed_hidden_states)
385 |
386 | mlp_out = self.mlp(post_normed_hidden_states)
387 | if self.peft_config.peft_mode == peft.PEFT_ADAPTER:
388 | mlp_out = self.peft_adapter_mlp(mlp_out)
389 |
390 | hidden_states = hidden_states + mlp_out
391 | check_nan(hidden_states)
392 | # if kv_cache:
393 | # return {
394 | # "hidden_states": hidden_states,
395 | # "kv_cache": raw_self_attn_output["kv_cache"],
396 | # }
397 | #
398 | # return {"hidden_states": hidden_states}
399 | if kv_cache:
400 | return hidden_states, raw_self_attn_output["kv_cache"]
401 | else:
402 | return hidden_states, None
403 |
404 |
405 | class MLP(nn.Module):
406 | def __init__(
407 | self,
408 | config: LLaMAConfig,
409 | peft_config: peft.PeftConfig,
410 | multiple_of: int = 256,
411 | ):
412 | super().__init__()
413 | self.config = config
414 | self.peft_config = peft_config
415 | dim = config.dim
416 | hidden_dim = 4 * dim
417 | hidden_dim = int(2 * hidden_dim / 3)
418 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
419 |
420 | if config.use_8bit:
421 | self.gate_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False)
422 | self.up_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False)
423 | self.down_proj = NoInit8bitLinear(hidden_dim, dim, bias=False, threshold=6.0, has_fp16_weights=False)
424 | else:
425 | self.gate_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype)
426 | self.up_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype)
427 | self.down_proj = NoInitLinear(hidden_dim, dim, bias=False, dtype=config.dtype)
428 |
429 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_mlp:
430 | self.gate_proj_lora = peft.LoRA(config=config, peft_config=peft_config,
431 | input_dim=dim, output_dim=hidden_dim)
432 | self.up_proj_lora = peft.LoRA(config=config, peft_config=peft_config,
433 | input_dim=dim, output_dim=hidden_dim)
434 | self.down_proj_lora = peft.LoRA(config=config, peft_config=peft_config,
435 |                                             input_dim=hidden_dim, output_dim=dim)
436 | if self.peft_config.peft_mode == peft.PEFT_IA3:
437 | self.peft_ia3 = peft.IA3ForMLP(config, peft_config=peft_config)
438 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
439 | self.peft_gate_proj_bias = peft.BitFitAddBias(dim=hidden_dim, peft_config=peft_config)
440 | self.peft_up_proj_bias = peft.BitFitAddBias(dim=hidden_dim, peft_config=peft_config)
441 | self.peft_down_proj_bias = peft.BitFitAddBias(dim=dim, peft_config=peft_config)
442 |
443 | def forward(self, x):
444 | gate_proj = self.gate_proj(x)
445 | up_proj = self.up_proj(x)
446 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_mlp:
447 | gate_proj += self.gate_proj_lora(x)
448 | up_proj += self.up_proj_lora(x)
449 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
450 | gate_proj = self.peft_gate_proj_bias(gate_proj)
451 |             up_proj = self.peft_up_proj_bias(up_proj)
452 |
453 | intermediate_state = F.silu(gate_proj) * up_proj
454 | if self.peft_config.peft_mode == peft.PEFT_IA3:
455 | intermediate_state = self.peft_ia3(intermediate_state)
456 |
457 | down_proj = self.down_proj(intermediate_state)
458 | if self.peft_config.peft_mode == peft.PEFT_LORA and self.peft_config.lora_mlp:
459 |             down_proj += self.down_proj_lora(intermediate_state)
460 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
461 | down_proj = self.peft_down_proj_bias(down_proj)
462 |
463 | return down_proj
464 |
465 |
466 | class RMSNorm(torch.nn.Module):
467 | def __init__(self, dim: int, eps: float = 1e-6, dtype=torch.float16):
468 | super().__init__()
469 | self.eps = eps
470 | self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
471 |
472 | def _norm(self, x):
473 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
474 |
475 | def forward(self, x):
476 | output = self._norm(x.float()).type_as(x)
477 | return output * self.weight
478 |
479 |
480 | class Attention(nn.Module):
481 | def __init__(self, config: LLaMAConfig, peft_config: peft.PeftConfig):
482 | super().__init__()
483 | self.config = config
484 | self.peft_config = peft_config
485 | self.n_heads = config.n_heads
486 | self.head_dim = config.dim // config.n_heads
487 |
488 | if config.use_8bit:
489 | self.q_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
490 | self.k_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
491 | self.v_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
492 | self.o_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
493 | else:
494 | self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
495 | self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
496 | self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
497 | self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
498 | self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
499 |
500 | if self.peft_config.peft_mode == peft.PEFT_LORA:
501 | self.peft_q_proj_lora = peft.LoRA(config=config, peft_config=peft_config)
502 | self.peft_v_proj_lora = peft.LoRA(config=config, peft_config=peft_config)
503 | if self.peft_config.peft_mode == peft.PEFT_IA3:
504 | self.peft_ia3 = peft.IA3ForAttn(config, peft_config=peft_config)
505 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
506 | self.peft_q_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config)
507 | self.peft_k_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config)
508 | self.peft_v_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config)
509 | self.peft_o_proj_bias = peft.BitFitAddBias(dim=config.dim, peft_config=peft_config)
510 | if self.peft_config.peft_mode == peft.PEFT_PREFIX_ADAPTER:
511 | self.peft_prefix_adapter = peft.PrefixAdapter(config=config, peft_config=peft_config)
512 |
513 | def forward(self, hidden_states, attention_mask, cos, sin, kv_cache=None):
514 | """
515 | precomputed_kv_hidden_states is for init (pre-compute KV activations, e.g. for added prefixes)
516 | kv_cache is for generation (cached past KV)
517 | """
518 | batch_size, q_seq_len, hidden_dim = hidden_states.size()
519 |
520 | # (batch_size, num_heads, q_seq_len, head_dim)
521 | query_states = self.q_proj(hidden_states)
522 | key_states = self.k_proj(hidden_states)
523 | value_states = self.v_proj(hidden_states)
524 |
525 | if self.peft_config.peft_mode == peft.PEFT_LORA:
526 | query_states += self.peft_q_proj_lora(hidden_states)
527 | value_states += self.peft_v_proj_lora(hidden_states)
528 | if self.peft_config.peft_mode == peft.PEFT_IA3:
529 | key_states, value_states = self.peft_ia3(key_states, value_states)
530 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
531 | query_states = self.peft_q_proj_bias(query_states)
532 | key_states = self.peft_k_proj_bias(key_states)
533 | value_states = self.peft_v_proj_bias(value_states)
534 |
535 | query_states = query_states.view(
536 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
537 | key_states = key_states.view(
538 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
539 | value_states = value_states.view(
540 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
541 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos=cos, sin=sin)
542 |
543 | if kv_cache:
544 | key_states = torch.cat([kv_cache["key"], key_states], dim=2)
545 | value_states = torch.cat([kv_cache["value"], value_states], dim=2)
546 |
547 | attn_output = torch.nn.functional.scaled_dot_product_attention(
548 | query=query_states,
549 | key=key_states,
550 | value=value_states,
551 | attn_mask=attention_mask,
552 | )
553 |
554 | if self.peft_config.peft_mode == peft.PEFT_PREFIX_ADAPTER:
555 | attn_output = attn_output + self.peft_prefix_adapter(query_states=query_states)
556 |
557 | # (batch_size, q_seq_len, hidden_dim)
558 | attn_output = attn_output.transpose(1, 2).contiguous().view(
559 | batch_size, q_seq_len, hidden_dim,
560 | )
561 | attn_output = self.o_proj(attn_output)
562 | if self.peft_config.peft_mode == peft.PEFT_BITFIT:
563 | attn_output = self.peft_o_proj_bias(attn_output)
564 |
565 | check_nan(attn_output)
566 | if kv_cache:
567 | new_kv_cache = {"key": key_states, "value": value_states}
568 | return {"attn_output": attn_output, "kv_cache": new_kv_cache}
569 | else:
570 | return {"attn_output": attn_output}
571 |
572 |
573 | class RotaryEmbedding(torch.nn.Module):
574 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
575 | super().__init__()
576 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device=device) / dim))
577 | self.register_buffer("inv_freq", inv_freq)
578 |
579 | # Build here to make `torch.jit.trace` work.
580 | self.max_seq_len_cached = max_position_embeddings
581 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device).to(self.inv_freq.dtype)
582 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
583 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
584 | emb = torch.cat((freqs, freqs), dim=-1)
585 | self.cos_cached = emb.cos()[None, None, :, :]
586 | self.sin_cached = emb.sin()[None, None, :, :]
587 |
588 | def forward(self, x, seq_len=None):
589 | # x: [bs, num_attention_heads, seq_len, head_size]
590 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
591 | if seq_len > self.max_seq_len_cached:
592 | self.max_seq_len_cached = seq_len
593 | t = torch.arange(self.max_seq_len_cached, device=x.device).to(self.inv_freq.dtype)
594 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
595 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
596 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
597 | self.cos_cached = emb.cos()[None, None, :, :].to(dtype=x.dtype)
598 | self.sin_cached = emb.sin()[None, None, :, :].to(dtype=x.dtype)
599 | return (
600 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
601 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device),
602 | )
603 |
604 |
605 | def rotate_half(x):
606 | """Rotates half the hidden dims of the input."""
607 | x1 = x[..., : x.shape[-1] // 2]
608 | x2 = x[..., x.shape[-1] // 2:]
609 | return torch.cat((-x2, x1), dim=-1)
610 |
611 |
612 | def apply_rotary_pos_emb(q, k, cos, sin):
613 | q_embed = (q * cos) + (rotate_half(q) * sin)
614 | k_embed = (k * cos) + (rotate_half(k) * sin)
615 | return q_embed, k_embed
616 |
617 |
618 | def create_attention_mask(input_ids,
619 | dtype=torch.float32,
620 | return_soft_mask=True):
621 | """Create mask for decoder attention.
622 |
623 | Decoder masks have two use-cases:
624 |
625 | 1) Training, where we see the full decoder sequence. In that case,
626 | we want a causal mask.
627 |
628 | 2) Generation, where we only see one token at once. In that case,
629 | it doesn't really matter what we give, we can just give a 1.
630 | (i.e. seq_len = 1)
631 |
632 | Note that in both cases we do not care about which decoder_input_ids
633 | are valid, and also we can always simply broadcast over the batch size
634 | and heads.
635 |
636 | :param input_ids: [batch_size, seq_len]
637 | :param dtype: dtype
638 | :param return_soft_mask: whether to return mask or logits-mask
639 | :return: float [batch_size=1, num_heads=1, q_len=seq_len, kv_len=seq_len]
640 | """
641 | batch_size, seq_length = input_ids.shape
642 | # [seq_len]
643 | seq_ids = torch.arange(seq_length, device=input_ids.device)
644 | # [seq_len, seq_len]
645 | causal_mask = seq_ids[None, :].repeat(seq_length, 1) <= seq_ids[:, None]
646 | # [batch_size=1, num_heads=1, seq_len, seq_len]
647 | causal_mask = causal_mask[None, None, :, :]
648 | if return_soft_mask:
649 | return convert_mask_to_soft_mask(causal_mask, dtype=dtype)
650 | else:
651 | return causal_mask
652 |
653 |
654 | def convert_mask_to_soft_mask(mask, dtype):
655 | """Convert binary mask to mask that can be added to logits.
656 |
657 | (i.e. 0 for attention, large negative for masked)
658 | """
659 | mask = mask.to(dtype=dtype)
660 | mask = (1.0 - mask) * torch.finfo(dtype).min
661 | return mask
662 |
663 |
664 | class NoInitLinear(nn.Linear):
665 | def reset_parameters(self) -> None:
666 | pass
667 |
668 |
669 | class NoInit8bitLinear(bnb.nn.Linear8bitLt):
670 | def reset_parameters(self) -> None:
671 | pass
672 |
673 |
674 | def get_linear_class(use_8bit=False):
675 | if use_8bit:
676 | return NoInit8bitLinear
677 | else:
678 | return NoInitLinear
679 |
680 |
681 | class NoInitEmbedding(nn.Embedding):
682 | def reset_parameters(self) -> None:
683 | pass
684 |
685 |
686 | def check_nan(x):
687 | # if torch.isnan(x).any():
688 | # import pdb
689 | # pdb.set_trace()
690 | pass
691 |
692 |
693 | def create_model(model_name, hf_path, peft_config: peft.PeftConfig, use_8bit=False, device=None):
694 | config = LLAMA_CONFIG_DICT[model_name]
695 |
696 | with open(os.path.join(hf_path, "pytorch_model.bin.index.json")) as f:
697 | weight_map = json.load(f)["weight_map"]
698 |
699 | filename_list = sorted(list(set(weight_map.values())))
700 | if device is None:
701 | # TODO: Local rank
702 | device = torch.device("cuda:0")
703 | if use_8bit:
704 | config = dataclasses.replace(config, use_8bit=True)
705 | with init_empty_weights():
706 | model = LLaMAModel(config=config, peft_config=peft_config)
707 | if model_name == "debug":
708 | return model
709 | state_keys = set(model.state_dict())
710 | filename_list = sorted(list(set(weight_map.values())))
711 | for filename in tqdm.tqdm(filename_list):
712 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu")
713 | for k, v in loaded.items():
714 | set_module_8bit_tensor_to_device(model, tensor_name=k, device=device, value=v)
715 | state_keys.remove(k)
716 | assert not state_keys
717 | else:
718 | # noinspection PyUnresolvedReferences
719 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
720 | model = LLaMAModel(config=config, peft_config=peft_config).cuda()
721 | torch.set_default_tensor_type(torch.FloatTensor)
722 | if model_name == "debug":
723 | return model
724 | state_keys = set(model.state_dict())
725 | for filename in tqdm.tqdm(filename_list):
726 | loaded = torch.load(os.path.join(hf_path, filename), map_location="cpu")
727 | model.load_state_dict(loaded, strict=False)
728 | for k in loaded:
729 | state_keys.remove(k)
730 | return model
731 |
732 |
733 | def set_peft_requires_grad(model: LLaMAModel):
734 | for p in model.parameters():
735 | p.requires_grad_(False)
736 | if model.peft_config.peft_mode == peft.PEFT_PREFIX:
737 | _set_requires_grad_if_str_in_name(model, substr="peft_prefix")
738 | elif model.peft_config.peft_mode == peft.PEFT_PROMPT:
739 | _set_requires_grad_if_str_in_name(model, substr="peft_prompt")
740 | elif model.peft_config.peft_mode == peft.PEFT_ADAPTER:
741 | _set_requires_grad_if_str_in_name(model, substr="peft_adapter_")
742 | elif model.peft_config.peft_mode == peft.PEFT_PREFIX_ADAPTER:
743 | _set_requires_grad_if_str_in_name(model, substr="peft_prefix_adapter")
744 | elif model.peft_config.peft_mode == peft.PEFT_LORA:
745 | _set_requires_grad_if_str_in_name(model, substr="_lora")
746 | elif model.peft_config.peft_mode == peft.PEFT_IA3:
747 | _set_requires_grad_if_str_in_name(model, substr="ia3")
748 | elif model.peft_config.peft_mode == peft.PEFT_BITFIT:
749 | _set_requires_grad_if_str_in_name(model, substr="bias")
750 | elif model.peft_config.peft_mode == peft.NO_PEFT:
751 | pass
752 | else:
753 | raise KeyError(model.peft_config.peft_mode)
754 |
755 |
756 | def _set_requires_grad_if_str_in_name(model, substr):
757 | for n, p in model.named_parameters():
758 | if substr in n:
759 | p.requires_grad_(True)
760 | print(f"Tuning: {n}")
761 |
762 |
763 | def create_generation_attention_mask(batch_size, seq_len, num_valid_tokens, device):
764 | """
765 | :param batch_size: int
766 | :param seq_len: int
767 | :param num_valid_tokens: [batch_size]
768 | :param device:
769 | :return:
770 | """
771 | # For right-aligned, based on num_valid_tokens
772 | # noinspection PyTypeChecker
773 | attn_mask = torch.zeros([batch_size, 1, 1, seq_len], dtype=bool)
774 | for i in range(batch_size):
775 | valid = num_valid_tokens[i]
776 | # noinspection PyTypeChecker
777 | # attn_mask[i, 0, -valid:, -valid:] = torch.tril(torch.ones([valid, valid], dtype=bool))
778 | attn_mask[i, 0, 0, -valid:] = True
779 | return attn_mask.to(device=device)
780 |
781 |
782 | def create_casual_attention_mask(seq_len, device):
783 | # noinspection PyTypeChecker
784 | attn_mask = torch.tril(torch.ones([seq_len, seq_len], dtype=bool))[None, None, :, :]
785 | return attn_mask.to(device=device)
786 |
787 |
788 | def create_rope_embed_ids(input_ids):
789 | pad_token_id = 0
790 | max_position = 2047 # These will not actually be used, as they are masked out by the attention mask
791 | x = (input_ids != pad_token_id).cumsum(-1) - 1
792 | x[input_ids == pad_token_id] = max_position
793 | return x
794 |
795 |
796 | def zeros_like(shape, tensor):
797 | return torch.zeros(shape).type_as(tensor).to(tensor.device)
798 |
--------------------------------------------------------------------------------
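
A minimal usage sketch for modeling_peft.py above, assuming a converted Hugging Face LLaMA checkpoint on disk and a CUDA device; the path and the "7b" model-name key are placeholders (valid names are the keys of LLAMA_CONFIG_DICT in pefty_llama/configuration.py):

    import torch
    import pefty_llama.peft as peft
    from pefty_llama.modeling_peft import create_model, set_peft_requires_grad

    peft_config = peft.PeftConfig(peft_mode=peft.PEFT_LORA)
    model = create_model("7b", hf_path="/path/to/hf_llama", peft_config=peft_config)
    set_peft_requires_grad(model)  # freeze the base model, unfreeze only the LoRA parameters

    input_ids = torch.randint(1, 32000, (1, 16), device="cuda")  # 32000 = usual LLaMA vocab size
    logits = model(input_ids)                                    # [batch_size, seq_len, vocab_size]
    generated = model.generate(input_ids, generation_length=20)
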
/pefty_llama/peft/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration import *
2 | from .ia3 import IA3ForAttn, IA3ForMLP
3 | from .bitfit import BitFitAddBias
4 | from .lora import LoRA, LoRAEmbed
5 | from .prefix_tuning import SoftPrefixes
6 | from .prompt_tuning import AddSoftPrompt
7 | from .adapter import Adapter
8 | from .prefix_adapter import PrefixAdapter
9 |
--------------------------------------------------------------------------------
/pefty_llama/peft/adapter.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from pefty_llama.configuration import LLaMAConfig
4 | from .configuration import PeftConfig
5 |
6 |
7 | class Adapter(nn.Module):
8 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
9 | super().__init__()
10 | self.config = config
11 | self.peft_config = peft_config
12 | self.down_proj = nn.Linear(
13 | config.dim, peft_config.adapter_hidden_size, bias=False,
14 | dtype=peft_config.peft_dtype,
15 | )
16 | self.up_proj = nn.Linear(
17 | peft_config.adapter_hidden_size, config.dim, bias=False,
18 | dtype=peft_config.peft_dtype,
19 | )
20 |
21 | def forward(self, hidden_states):
22 | hidden_states = hidden_states.to(self.peft_config.peft_dtype)
23 | out = self.up_proj(F.gelu(self.down_proj(hidden_states))) + hidden_states
24 | return out.to(self.config.dtype)
25 |
--------------------------------------------------------------------------------
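
The adapter above is a bottleneck with a residual connection: out = h + up_proj(GELU(down_proj(h))). A small standalone sketch of the same computation (the hidden and bottleneck sizes below are made up for illustration):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    dim, bottleneck = 4096, 64                  # hypothetical sizes
    down = nn.Linear(dim, bottleneck, bias=False)
    up = nn.Linear(bottleneck, dim, bias=False)

    h = torch.randn(2, 10, dim)
    out = up(F.gelu(down(h))) + h               # shape-preserving: [2, 10, 4096]
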
/pefty_llama/peft/bitfit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .configuration import PeftConfig
4 |
5 |
6 | class BitFitAddBias(nn.Module):
7 | def __init__(self, dim: int, peft_config: PeftConfig):
8 | super().__init__()
9 | self.peft_config = peft_config
10 | self.bias = nn.Parameter(torch.zeros(dim, dtype=peft_config.peft_dtype))
11 |
12 | def forward(self, hidden_state):
13 | input_dtype = hidden_state.dtype
14 | return (hidden_state.to(self.peft_config.peft_dtype) + self.bias).to(input_dtype)
15 |
--------------------------------------------------------------------------------
/pefty_llama/peft/configuration.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 |
6 | PEFT_PREFIX = "prefix"
7 | PEFT_PROMPT = "prompt"
8 | PEFT_ADAPTER = "adapter"
9 | PEFT_PREFIX_ADAPTER = "prefix_adapter"
10 | PEFT_LORA = "lora"
11 | PEFT_IA3 = "ia3"
12 | PEFT_BITFIT = "bitfit"
13 | NO_PEFT = "nothing"
14 |
15 | ADAPTER_VERSION_HOULSBY = "houlsby"
16 | ADAPTER_VERSION_PFEIFFER = "pfeiffer"
17 |
18 |
19 | @dataclass
20 | class PeftConfig:
21 | peft_mode: str = field()
22 | peft_dtype: Any = field(default=torch.float32)
23 |
24 | # Used by prompt, prefix, prefix_adapter
25 | num_prefix_tokens: int = field(default=16)
26 |
27 | # Prefix
28 | prefix_use_mlp: bool = field(default=True)
29 |     prefix_mlp_intermediate_size: Optional[int] = field(default=None)
30 |
31 | # LoRA
32 | lora_rank: int = field(default=8)
33 | lora_alpha: int = field(default=16)
34 | lora_mlp: bool = field(default=False)
35 | lora_embedding: bool = field(default=False)
36 |
37 | # Adapter
38 | adapter_hidden_size: int = field(default=64)
39 | adapter_version: str = field(default=ADAPTER_VERSION_PFEIFFER) # houlsby, pfeiffer
40 |
41 | def check(self):
42 | assert self.peft_mode in (
43 | PEFT_PREFIX, PEFT_PREFIX_ADAPTER, PEFT_PROMPT, PEFT_ADAPTER,
44 |             PEFT_LORA, PEFT_IA3, PEFT_BITFIT,
45 | NO_PEFT,
46 | )
47 |
--------------------------------------------------------------------------------
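
For example, a config that applies LoRA to the attention and MLP projections could be built like this (the values shown are just the dataclass defaults, not tuned recommendations):

    from pefty_llama.peft import PeftConfig, PEFT_LORA

    peft_config = PeftConfig(peft_mode=PEFT_LORA, lora_rank=8, lora_alpha=16, lora_mlp=True)
    peft_config.check()  # asserts that peft_mode is one of the supported modes
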
/pefty_llama/peft/ia3.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from pefty_llama.modeling import LLaMAModel, NoInitLinear, NoInit8bitLinear, RotaryEmbedding, apply_rotary_pos_emb, check_nan
9 | from pefty_llama.configuration import LLaMAConfig
10 | from .configuration import PeftConfig
11 |
12 |
13 | class IA3Attention(nn.Module):
14 | def __init__(self, config: LLaMAConfig):
15 | super().__init__()
16 | self.config = config
17 | self.n_heads = config.n_heads
18 | self.head_dim = config.dim // config.n_heads
19 |
20 | if config.use_8bit:
21 | self.q_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
22 | self.k_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
23 | self.v_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
24 | self.o_proj = NoInit8bitLinear(config.dim, config.dim, bias=False, threshold=6.0, has_fp16_weights=False)
25 | else:
26 | self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
27 | self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
28 | self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
29 | self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
30 | self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
31 |
32 | # IA3-specific parameters:
33 | self.peft_l_k = nn.Parameter(torch.ones(1, self.n_heads, 1, self.head_dim, dtype=config.dtype))
34 | self.peft_l_v = nn.Parameter(torch.ones(1, self.n_heads, 1, self.head_dim, dtype=config.dtype))
35 |
36 | def forward(self, hidden_states, attention_mask, cos, sin, kv_cache=None):
37 | """
38 | precomputed_kv_hidden_states is for init (pre-compute KV activations, e.g. for added prefixes)
39 | kv_cache is for generation (cached past KV)
40 | """
41 | batch_size, q_seq_len, hidden_dim = hidden_states.size()
42 |
43 | # (batch_size, num_heads, q_seq_len, head_dim)
44 | query_states = self.q_proj(hidden_states).view(
45 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
46 | key_states = self.k_proj(hidden_states).view(
47 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
48 | value_states = self.v_proj(hidden_states).view(
49 | batch_size, q_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
50 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos=cos, sin=sin)
51 | if kv_cache:
52 | key_states = torch.cat([kv_cache["key"], key_states], dim=2)
53 | value_states = torch.cat([kv_cache["value"], value_states], dim=2)
54 |
55 | # IA3-specific:
56 |         key_states = key_states * self.peft_l_k
57 | value_states = value_states * self.peft_l_v
58 | # end of IA3-specific
59 |
60 | scores = torch.matmul(
61 | query_states, key_states.transpose(3, 2).type_as(query_states) / math.sqrt(self.head_dim)
62 | )
63 | scores += attention_mask
64 |
65 | # (batch_size, num_heads, q_seq_len, kv_seq_len)
66 | attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)
67 | # (batch_size, num_heads, q_seq_len, head_dim)
68 | attn_output = torch.matmul(attn_weights, value_states.type_as(query_states))
69 | # (batch_size, q_seq_len, hidden_dim)
70 | attn_output = attn_output.transpose(1, 2).contiguous().view(
71 | batch_size, q_seq_len, hidden_dim,
72 | )
73 | attn_output = self.o_proj(attn_output)
74 | check_nan(attn_output)
75 | if kv_cache:
76 | new_kv_cache = {"key": key_states, "value": value_states}
77 | return {"attn_output": attn_output, "kv_cache": new_kv_cache}
78 | else:
79 | return {"attn_output": attn_output}
80 |
81 |
82 | class IA3MLP(nn.Module):
83 | def __init__(
84 | self,
85 | config: LLaMAConfig,
86 | multiple_of: int = 256,
87 | ):
88 | super().__init__()
89 | dim = config.dim
90 | hidden_dim = 4 * dim
91 | hidden_dim = int(2 * hidden_dim / 3)
92 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
93 |
94 | if config.use_8bit:
95 | self.gate_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False)
96 | self.up_proj = NoInit8bitLinear(dim, hidden_dim, bias=False, threshold=6.0, has_fp16_weights=False)
97 | self.down_proj = NoInit8bitLinear(hidden_dim, dim, bias=False, threshold=6.0, has_fp16_weights=False)
98 | else:
99 | self.gate_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype)
100 | self.up_proj = NoInitLinear(dim, hidden_dim, bias=False, dtype=config.dtype)
101 | self.down_proj = NoInitLinear(hidden_dim, dim, bias=False, dtype=config.dtype)
102 |
103 | # IA3-specific parameters:
104 | self.peft_l_ffn = nn.Parameter(torch.ones(1, 1, hidden_dim, dtype=config.dtype))
105 |
106 | def forward(self, x):
107 | h = F.silu(self.gate_proj(x)) * self.up_proj(x)
108 | # IA3-specific:
109 | h = h * self.peft_l_ffn
110 | # end of IA3-specific
111 | return self.down_proj(h)
112 |
113 |
114 | class IA3(nn.Module):
115 | def __init__(self, model: LLaMAModel):
116 | super().__init__()
117 | self.base_model = model
118 | model_config = model.config
119 |
120 | for layer in self.base_model.model.layers:
121 | # you also need to copy the parameters of the layer to the new layer
122 | patched_attn = IA3Attention(model_config)
123 | current_attn = layer.self_attn
124 | patched_attn.q_proj.weight = current_attn.q_proj.weight
125 | patched_attn.k_proj.weight = current_attn.k_proj.weight
126 | patched_attn.v_proj.weight = current_attn.v_proj.weight
127 | patched_attn.o_proj.weight = current_attn.o_proj.weight
128 | patched_attn.rotary_emb = current_attn.rotary_emb
129 |
130 | layer.self_attn = patched_attn
131 | del current_attn
132 |
133 | patched_mlp = IA3MLP(model_config)
134 | current_mlp = layer.mlp
135 | patched_mlp.gate_proj.weight = current_mlp.gate_proj.weight
136 | patched_mlp.up_proj.weight = current_mlp.up_proj.weight
137 | patched_mlp.down_proj.weight = current_mlp.down_proj.weight
138 |
139 | layer.mlp = patched_mlp
140 | del current_mlp
141 |
142 | # cleanup memory freed by deleting the old layers
143 | if torch.cuda.is_available():
144 | torch.cuda.empty_cache()
145 | gc.collect()
146 |
147 | for name, param in self.base_model.named_parameters():
148 | if "peft_" in name: continue
149 | param.requires_grad = False
150 |
151 | # monkey patch the methods
152 | self.forward = self.base_model.forward
153 | self.generate = self.base_model.generate
154 |
155 |
156 | class IA3ForAttn(nn.Module):
157 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
158 | super().__init__()
159 | self.config = config
160 | self.peft_config = peft_config
161 | self.n_heads = config.n_heads
162 | self.head_dim = config.dim // config.n_heads
163 |
164 | self.peft_l_k = nn.Parameter(torch.ones(config.dim, dtype=peft_config.peft_dtype))
165 | self.peft_l_v = nn.Parameter(torch.ones(config.dim, dtype=peft_config.peft_dtype))
166 |
167 | def forward(self, key_states, value_states):
168 | return (
169 | (key_states.to(self.peft_config.peft_dtype) * self.peft_l_k).to(self.config.dtype),
170 | (value_states.to(self.peft_config.peft_dtype) * self.peft_l_v).to(self.config.dtype),
171 | )
172 |
173 |
174 | class IA3ForMLP(nn.Module):
175 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
176 | super().__init__()
177 | self.config = config
178 | self.peft_config = peft_config
179 | multiple_of = 256
180 | intermediate_dim = 4 * config.dim
181 | intermediate_dim = int(2 * intermediate_dim / 3)
182 | intermediate_dim = multiple_of * ((intermediate_dim + multiple_of - 1) // multiple_of)
183 |
184 | self.peft_l_ffn = nn.Parameter(torch.ones(1, 1, intermediate_dim, dtype=peft_config.peft_dtype))
185 |
186 | def forward(self, intermediate_state):
187 | return (intermediate_state.to(self.peft_config.peft_dtype) * self.peft_l_ffn).to(self.config.dtype)
188 |
--------------------------------------------------------------------------------
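
IA3 only learns per-feature rescaling vectors initialized to ones, so it is an exact no-op until training moves them. A tiny sketch of the elementwise rescaling that IA3ForAttn applies to the key states (the hidden size is hypothetical):

    import torch

    dim = 4096                                # hypothetical hidden size
    l_k = torch.ones(dim)                     # the trained rescaling vector
    key_states = torch.randn(2, 10, dim)

    scaled = key_states * l_k                 # broadcast over batch and sequence
    assert torch.equal(scaled, key_states)    # identity while l_k is still at its init value
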
/pefty_llama/peft/lora.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from pefty_llama.configuration import LLaMAConfig
6 | from .configuration import PeftConfig
7 |
8 |
9 | class LoRA(nn.Module):
10 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig,
11 | input_dim: Optional[int] = None,
12 | output_dim: Optional[int] = None,
13 | ):
14 | super().__init__()
15 | self.config = config
16 | self.peft_config = peft_config
17 |
18 | if input_dim is None:
19 | input_dim = self.config.dim
20 | if output_dim is None:
21 | output_dim = self.config.dim
22 | self.lora_down = nn.Parameter(torch.randn(input_dim, peft_config.lora_rank, dtype=peft_config.peft_dtype))
23 | self.lora_up = nn.Parameter(torch.zeros(peft_config.lora_rank, output_dim, dtype=peft_config.peft_dtype))
24 | self.rank = peft_config.lora_rank
25 | self.scaling = peft_config.lora_alpha / peft_config.lora_rank
26 |
27 | def forward(self, hidden_states):
28 | hidden_states = hidden_states.to(self.peft_config.peft_dtype)
29 |         lora_out = torch.einsum("ij,bsi->bsj", (self.lora_down @ self.lora_up), hidden_states)
30 |         return (self.scaling * lora_out).to(self.config.dtype)  # only the low-rank delta; callers add it to the base output
31 |
32 |
33 | class LoRAEmbed(nn.Module):
34 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
35 | super().__init__()
36 | self.config = config
37 | self.peft_config = peft_config
38 |
39 | self.lora_down = nn.Parameter(torch.randn(config.vocab_size, peft_config.lora_rank, dtype=peft_config.peft_dtype))
40 | self.lora_up = nn.Parameter(torch.zeros(peft_config.lora_rank, config.dim, dtype=peft_config.peft_dtype))
41 | self.rank = peft_config.lora_rank
42 | self.scaling = peft_config.lora_alpha / peft_config.lora_rank
43 |
44 | def forward(self, input_ids):
45 | embedding_matrix = self.lora_down @ self.lora_up
46 | return F.embedding(input_ids, embedding_matrix)
47 |
--------------------------------------------------------------------------------
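
LoRA parameterizes a low-rank update lora_down @ lora_up scaled by alpha/rank; because lora_up is initialized to zeros, the update is exactly zero before training. A standalone sketch of the same einsum with made-up dimensions:

    import torch

    dim, rank, alpha = 4096, 8, 16
    x = torch.randn(2, 10, dim)
    lora_down = torch.randn(dim, rank)
    lora_up = torch.zeros(rank, dim)

    delta = torch.einsum("ij,bsi->bsj", lora_down @ lora_up, x) * (alpha / rank)
    assert torch.count_nonzero(delta) == 0    # zero update until lora_up is trained
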
/pefty_llama/peft/prefix_adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from pefty_llama.configuration import LLaMAConfig
5 | from .configuration import PeftConfig
6 |
7 |
8 | class PrefixAdapter(nn.Module):
9 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
10 | super().__init__()
11 | self.config = config
12 | self.peft_config = peft_config
13 | # "batch_size"=1, num_heads, num_prefix_tokens, head_dim
14 | self.prefix_k = nn.Parameter(torch.randn(
15 | 1, config.n_heads, peft_config.num_prefix_tokens, config.head_dim, dtype=peft_config.peft_dtype))
16 | self.prefix_v = nn.Parameter(torch.randn(
17 | 1, config.n_heads, peft_config.num_prefix_tokens, config.head_dim, dtype=peft_config.peft_dtype))
18 | self.gate = nn.Parameter(torch.zeros(1, config.n_heads, 1, 1))
19 |
20 | def forward(self, query_states):
21 | batch_size, num_heads, q_seq_len, head_dim = query_states.shape
22 | # "batch_size"=1, num_heads, num_prefix_tokens, head_dim
23 | prefix_k = self.prefix_k.expand(batch_size, -1, -1, -1)
24 | prefix_v = self.prefix_v.expand(batch_size, -1, -1, -1)
25 | attn_output = torch.nn.functional.scaled_dot_product_attention(
26 | query=query_states.to(self.peft_config.peft_dtype),
27 | key=prefix_k,
28 | value=prefix_v,
29 | )
30 |         return (torch.tanh(self.gate) * attn_output).to(self.config.dtype)
31 |
--------------------------------------------------------------------------------
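
The gate above is initialized to zeros, so the extra prefix attention contributes nothing at the start of training (zero-init gating in the spirit of LLaMA-Adapter). A quick sketch with made-up shapes:

    import torch

    gate = torch.zeros(1, 32, 1, 1)               # as initialized above (hypothetical n_heads=32)
    prefix_out = torch.randn(2, 32, 10, 128)      # whatever the prefix attention produces
    assert torch.count_nonzero(torch.tanh(gate) * prefix_out) == 0  # tanh(0) = 0
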
/pefty_llama/peft/prefix_tuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pefty_llama.configuration import LLaMAConfig
4 | from .configuration import PeftConfig
5 |
6 |
7 | class SoftPrefixes(nn.Module):
8 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
9 | super().__init__()
10 | self.config = config
11 | self.peft_config = peft_config
12 | if self.peft_config.prefix_use_mlp:
13 | if self.peft_config.prefix_mlp_intermediate_size is not None:
14 | intermediate_size = self.peft_config.prefix_mlp_intermediate_size
15 | else:
16 | intermediate_size = self.config.dim
17 |
18 | self.initial = nn.Parameter(
19 | torch.randn(peft_config.num_prefix_tokens, config.dim, dtype=peft_config.peft_dtype)
20 | )
21 | self.mlp = torch.nn.Sequential(
22 | torch.nn.Linear(config.dim, intermediate_size, dtype=peft_config.peft_dtype),
23 | torch.nn.Tanh(),
24 | torch.nn.Linear(intermediate_size, config.n_layers * 2 * config.dim, dtype=peft_config.peft_dtype),
25 | )
26 | else:
27 | self.soft_prompt = nn.Parameter(torch.randn(
28 | peft_config.num_prefix_tokens, config.n_layers * 2 * config.dim,
29 | dtype=peft_config.peft_dtype
30 | ))
31 |
32 | def forward(self, batch_size):
33 | if self.peft_config.prefix_use_mlp:
34 | out = self.mlp(self.initial)
35 | else:
36 |             out = self.soft_prompt
37 |         # num_prefix_tokens, layers, k/v, num_heads, head_dim
38 | out = out.view(self.peft_config.num_prefix_tokens, self.config.n_layers, 2,
39 | self.config.n_heads, self.config.head_dim).to(self.config.dtype)
40 | return [
41 | {
42 | "key": out[:, layer, 0, :, :].permute(1, 0, 2).unsqueeze(0).expand(batch_size, -1, -1, -1),
43 | "value": out[:, layer, 1, :, :].permute(1, 0, 2).unsqueeze(0).expand(batch_size, -1, -1, -1),
44 | }
45 | for layer in range(self.config.n_layers)
46 | ]
47 |
--------------------------------------------------------------------------------
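
SoftPrefixes returns one {key, value} pair per layer, each shaped [batch_size, n_heads, num_prefix_tokens, head_dim], which the model prepends to that layer's KV cache. A sketch of the reshaping with made-up sizes:

    import torch

    num_prefix_tokens, n_layers, n_heads, head_dim = 16, 32, 32, 128
    flat = torch.randn(num_prefix_tokens, n_layers * 2 * n_heads * head_dim)

    out = flat.view(num_prefix_tokens, n_layers, 2, n_heads, head_dim)
    key_layer0 = out[:, 0, 0].permute(1, 0, 2).unsqueeze(0)
    print(key_layer0.shape)  # torch.Size([1, 32, 16, 128])
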
/pefty_llama/peft/prompt_tuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pefty_llama.configuration import LLaMAConfig
4 | from .configuration import PeftConfig
5 |
6 |
7 | class AddSoftPrompt(nn.Module):
8 | def __init__(self, config: LLaMAConfig, peft_config: PeftConfig):
9 | super().__init__()
10 | self.peft_config = peft_config
11 | self.soft_prompt = nn.Parameter(
12 | torch.randn(peft_config.num_prefix_tokens, config.dim, dtype=peft_config.peft_dtype)
13 | )
14 |
15 | def forward(self, hidden_states):
16 | batch_size, seq_len, dim = hidden_states.shape
17 |         soft_prompt = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1).to(hidden_states.dtype)
18 | return torch.cat([soft_prompt, hidden_states], dim=1)
19 |
--------------------------------------------------------------------------------
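
Prompt tuning just prepends learned embeddings to the embedded input, so the effective sequence length grows by num_prefix_tokens. A sketch with illustrative sizes:

    import torch

    num_prefix_tokens, dim = 16, 4096
    soft_prompt = torch.randn(num_prefix_tokens, dim)
    hidden_states = torch.randn(2, 10, dim)

    out = torch.cat([soft_prompt.unsqueeze(0).expand(2, -1, -1), hidden_states], dim=1)
    print(out.shape)  # torch.Size([2, 26, 4096])
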
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | tqdm
3 | transformers
4 | accelerate
5 | bitsandbytes
6 | sentencepiece
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as f:
4 | long_description = f.read()
5 |
6 | with open("requirements.txt", "r") as f:
7 | requires = f.read().splitlines()
8 |
9 | setuptools.setup(
10 | name="pefty_llama",
11 | version="0.0.1",
12 | author="Vlad Lialin",
13 | author_email="vlad.lialin@gmail.com",
14 | description="Minimal implementations of multiple PEFT methods for LLaMA fine-tuning",
15 | url="https://github.com/Guitaricet/my_pefty_llama",
16 | packages=setuptools.find_packages(),
17 |     install_requires=requires,
18 | )
19 |
--------------------------------------------------------------------------------
/tokenize_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import tqdm.auto as tqdm
6 |
7 | import datasets
8 | import transformers
9 |
10 |
11 | def read_jsonl(path):
12 | # Manually open because .splitlines is different from iterating over lines
13 | with open(path, "r") as f:
14 | for line in f:
15 | yield json.loads(line)
16 |
17 |
18 | def read_lm_dataformat(path):
19 | import lm_dataformat
20 | reader = lm_dataformat.Reader(path)
21 | yield from reader.stream_data()
22 |
23 |
24 | def main():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--tokenizer_path", type=str)
27 | parser.add_argument("--data_path", type=str)
28 | parser.add_argument("--data_format", type=str, default="jsonl")
29 | parser.add_argument("--save_path", type=str)
30 | parser.add_argument("--max_seq_length", type=int, default=2048)
31 | parser.add_argument("--shard_size", type=int, default=100000)
32 | args = parser.parse_args()
33 | os.makedirs(args.save_path, exist_ok=True)
34 |
35 | tokenizer = transformers.LlamaTokenizer.from_pretrained(args.tokenizer_path)
36 |
37 | all_tokenized = []
38 | if args.data_format == "jsonl":
39 | reader = read_jsonl(args.data_path)
40 | elif args.data_format == "lm_dataformat":
41 | reader = read_lm_dataformat(args.data_path)
42 | else:
43 | raise KeyError(args.data_format)
44 |
45 | total = 0
46 | shards = 0
47 | for elem in tqdm.tqdm(reader):
48 | text = elem["text"] if args.data_format == "jsonl" else elem
49 | tokenized = tokenizer.encode(text)
50 | num_chunks = len(tokenized) // args.max_seq_length
51 | for j in range(num_chunks):
52 | chunk = tokenized[
53 | j * args.max_seq_length: (j + 1) * args.max_seq_length
54 | ]
55 | all_tokenized.append(chunk)
56 | total += 1
57 | if len(all_tokenized) == args.shard_size:
58 | ds = datasets.Dataset.from_dict({"input_ids": all_tokenized})
59 | ds.save_to_disk(os.path.join(args.save_path, "shard_{:05d}".format(shards)))
60 | all_tokenized = []
61 | shards += 1
62 |
63 | if len(all_tokenized) > 0:
64 | ds = datasets.Dataset.from_dict({"input_ids": all_tokenized})
65 | ds.save_to_disk(os.path.join(args.save_path, "shard_{:05d}".format(shards)))
66 |
67 | print(f"Generated {total} samples in {shards} shards.")
68 |
69 |
70 | if __name__ == "__main__":
71 | main()
72 |
--------------------------------------------------------------------------------
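
An example invocation of the script above (all paths are placeholders; the flags match the argparse definitions):

    python tokenize_dataset.py \
        --tokenizer_path /path/to/llama_tokenizer \
        --data_path /path/to/data.jsonl \
        --data_format jsonl \
        --save_path /path/to/tokenized_dataset \
        --max_seq_length 2048 \
        --shard_size 100000
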