├── CITATION.cff
├── LICENSE
├── README.md
├── configs
│   ├── llama_100m.json
│   ├── llama_130m.json
│   ├── llama_1b.json
│   ├── llama_20m.json
│   ├── llama_250m.json
│   ├── llama_350m.json
│   ├── llama_35m.json
│   ├── llama_3b.json
│   ├── llama_40m.json
│   ├── llama_60m.json
│   ├── llama_71m.json
│   ├── llama_7b.json
│   └── llama_9m.json
├── exp_requirements.txt
├── galore_torch
│   ├── __init__.py
│   ├── adafactor.py
│   ├── adamw.py
│   ├── adamw8bit.py
│   ├── galore_projector.py
│   └── galore_projector_tensor.py
├── imgs
│   ├── galore_code_box.png
│   └── subspace_learning.png
├── peft_pretraining
│   ├── args_utils.py
│   ├── dataloader.py
│   ├── modeling_llama.py
│   └── training_utils.py
├── requirements.txt
├── run_glue.py
├── scripts
│   ├── benchmark_c4
│   │   ├── llama_130m.sh
│   │   ├── llama_1b.sh
│   │   ├── llama_350m.sh
│   │   ├── llama_60m.sh
│   │   └── llama_7b.sh
│   ├── single_gpu
│   │   ├── llama_7b.sh
│   │   └── llama_7b_checkpointing.sh
│   └── tensor_test
│       └── neural_operator.py
├── setup.py
└── torchrun_main.py
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: "GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection"
3 | version: 1.0.0
4 | message: "If you use this software, please cite it as below."
5 | authors:
6 | - family-names: "Zhao"
7 | given-names: "Jiawei"
8 | year: 2024
9 | repository-code: "https://github.com/jiaweizzhao/GaLore"
10 | url: "https://arxiv.org/abs/2403.03507"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GaLore
2 |
3 | This repo contains the pre-release version of the GaLore algorithm, proposed in [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507).
4 |
5 | Gradient Low-Rank Projection (GaLore) is a low-rank training strategy that allows *full-parameter* learning while being more *memory-efficient* than common low-rank adaptation methods, such as LoRA.
6 | As a gradient projection method, GaLore is independent of the choice of optimizer and can be plugged into existing optimizers with only two lines of code, as shown in Algorithm 1 below.
7 |
8 |
9 | ![Algorithm 1: GaLore optimizer code box](imgs/galore_code_box.png)
10 |
11 |
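Below is a minimal, self-contained sketch of the idea behind Algorithm 1. It is illustrative only: it assumes a wide weight matrix (so a left projection is used), substitutes a plain scaled-gradient step for Adam, and omits the periodic projector refresh; see `galore_torch/galore_projector.py` and the optimizer classes for the actual implementation.

```python
import torch

def project(grad, rank=128):
    # Build a low-rank projector from the SVD of the gradient (cf. galore_projector.py).
    U, _, _ = torch.linalg.svd(grad.float(), full_matrices=False)
    P = U[:, :rank]              # left orthogonal factor: (out_dim, rank)
    return P, P.t() @ grad       # projected gradient: (rank, in_dim)

W = torch.randn(512, 1024)       # toy weight matrix
G = torch.randn_like(W)          # toy gradient

P, low_rank_grad = project(G)              # line 1: project the gradient into the low-rank subspace
low_rank_update = -0.01 * low_rank_grad    # stand-in for the Adam-style update computed in the subspace
W += P @ low_rank_update                   # line 2: project the update back and apply it
```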
12 | ## News
13 |
14 |
15 | - **2024-09-01**: We are working on GaLore 2, which is a more efficient and accessible version of GaLore. Please stay tuned!
16 | - **2024-07-11**: We release Q-GaLore: Quantized GaLore with INT4 Projection. [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)]
17 |
18 | - **2024-07-01**: GaLore has been accepted to ICML 2024 as an Oral presentation!
19 |
20 | - **2024-04-20**: Please join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to discuss with us and the community.
21 |
22 | ## Installation
23 |
24 | ### Install GaLore optimizer
25 | Install from pip:
26 | ```bash
27 | pip install galore-torch
28 | ```
29 |
30 | or if you want to install from source:
31 |
32 | ```bash
33 | git clone git@github.com:jiaweizzhao/GaLore.git
34 | cd GaLore
35 | pip install -e .
36 | ```
37 |
38 | ### Install experiment dependencies
39 |
40 | ```bash
41 | pip install -r exp_requirements.txt
42 | ```
43 |
44 | Our experiment scripts are tested on Python 3.8 with PyTorch 2.1.
45 |
46 | ## Usage
47 |
48 | ### Save optimizer memory using GaLore optimizers
49 |
50 | ```python
51 | from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
52 | # define param groups as galore_params and non_galore_params
53 | param_groups = [{'params': non_galore_params},
54 | {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}]
55 | optimizer = GaLoreAdamW(param_groups, lr=0.01)
56 | ```
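How to split the parameters into `galore_params` and `non_galore_params` is up to you. One possible split, as a sketch (given a `model`, applying GaLore only to the 2-D weight matrices of `nn.Linear` modules; the module selection used in `torchrun_main.py` may differ):

```python
import torch.nn as nn

# Project only the 2-D weights of Linear layers; keep everything else in a regular group.
galore_params = [m.weight for m in model.modules() if isinstance(m, nn.Linear)]
galore_ids = {id(p) for p in galore_params}
non_galore_params = [p for p in model.parameters() if id(p) not in galore_ids]
```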
57 | ### Save weight gradient memory using per-layer weight updates
58 |
59 | We use `register_post_accumulate_grad_hook` provided by [PyTorch](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html) (`torch>=2.1.0`) to enable per-layer weight updates. An example is shown below:
60 |
61 | ```python
62 | # define an optimizer for each parameter p, and store them in optimizer_dict
63 | for p in model.parameters():
64 | if p.requires_grad:
65 | optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01)
66 |
67 | # define a hook function to update the parameter p during the backward pass
68 | def optimizer_hook(p):
69 | if p.grad is None:
70 | return
71 | optimizer_dict[p].step()
72 | optimizer_dict[p].zero_grad()
73 |
74 | # Register the hook onto every parameter
75 | for p in model.parameters():
76 | if p.requires_grad:
77 | p.register_post_accumulate_grad_hook(optimizer_hook)
78 | ```
79 | More details can be found in [torchrun_main.py](https://github.com/jiaweizzhao/GaLore/blob/a6bc1650984b1c090a4e108d7c0e3109ee7ad844/torchrun_main.py#L334).
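With this pattern the optimizer work happens inside the backward pass, so the training loop itself (a sketch, assuming the hooks above have been registered) no longer calls `optimizer.step()` or `optimizer.zero_grad()`:

```python
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()  # each parameter is updated and its grad cleared by its own hook
```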
80 |
81 | ## Benchmark 1: Pre-Training LLaMA on C4 dataset
82 | `torchrun_main.py` is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for various model sizes are in the `scripts/benchmark_c4` folder.
83 | For example, to train a 60M model on C4, run the following:
84 |
85 | ```bash
86 | # LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
87 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \
88 | --model_config configs/llama_60m.json \
89 | --lr 0.01 \
90 | --galore_scale 0.25 \
91 | --rank 128 \
92 | --update_proj_gap 200 \
93 | --batch_size 256 \
94 | --total_batch_size 512 \
95 | --num_training_steps 10000 \
96 | --warmup_steps 1000 \
97 | --weight_decay 0 \
98 | --dtype bfloat16 \
99 | --eval_every 1000 \
100 | --optimizer galore_adamw
101 | ```
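The GaLore-specific flags correspond to the optimizer hyperparameters from the Usage section above; roughly (a sketch of the mapping, not the exact construction in `torchrun_main.py`):

```python
param_groups = [
    {'params': non_galore_params},
    {'params': galore_params,
     'rank': 128,             # --rank
     'update_proj_gap': 200,  # --update_proj_gap
     'scale': 0.25,           # --galore_scale
     'proj_type': 'std'},
]
optimizer = GaLoreAdamW(param_groups, lr=0.01)  # --lr, with --optimizer galore_adamw
```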
102 |
103 | ### Train a 7B model on a single GPU with 24GB of memory
104 | To train a 7B model on a single GPU such as an NVIDIA RTX 4090, all you need to do is specify `--optimizer=galore_adamw8bit_per_layer`, which enables `GaLoreAdamW8bit` with per-layer weight updates.
105 | With activation checkpointing, you can maintain a batch size of 16 (tested on an NVIDIA RTX 4090).
106 |
107 | ```bash
108 | # LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
109 | # bsz=16, 22.8G,
110 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \
111 | --model_config configs/llama_7b.json \
112 | --lr 0.005 \
113 | --galore_scale 0.25 \
114 | --rank 1024 \
115 | --update_proj_gap 500 \
116 | --batch_size 16 \
117 | --total_batch_size 512 \
118 | --activation_checkpointing \
119 | --num_training_steps 150000 \
120 | --warmup_steps 15000 \
121 | --weight_decay 0 \
122 | --grad_clipping 1.0 \
123 | --dtype bfloat16 \
124 | --eval_every 1000 \
125 | --single_gpu \
126 | --optimizer galore_adamw8bit_per_layer
127 | ```
128 |
129 | Currently, the per-layer weight update technique is only supported for single-GPU training (`--single_gpu`) without `nn.parallel.DistributedDataParallel`. We are working on supporting multi-GPU training with per-layer weight updates.
130 |
131 | ## Benchmark 2: Fine-Tuning RoBERTa on GLUE tasks
132 | `run_glue.py` is the main script for fine-tuning RoBERTa models on GLUE tasks with GaLore. An example script is shown below:
133 |
134 | ```bash
135 | python run_glue.py \
136 | --model_name_or_path roberta-base \
137 | --task_name mrpc \
138 | --enable_galore \
139 | --lora_all_modules \
140 | --max_length 512 \
141 | --seed=1234 \
142 | --lora_r 4 \
143 | --galore_scale 4 \
144 | --per_device_train_batch_size 16 \
145 | --update_proj_gap 500 \
146 | --learning_rate 3e-5 \
147 | --num_train_epochs 30 \
148 | --output_dir results/ft/roberta_base/mrpc
149 | ```
150 |
151 | ## Citation
152 | ```bibtex
153 | @misc{zhao2024galore,
154 | title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection},
155 | author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
156 | year={2024},
157 | eprint={2403.03507},
158 | archivePrefix={arXiv},
159 | primaryClass={cs.LG}
160 | }
161 | ```
--------------------------------------------------------------------------------
/configs/llama_100m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 640,
9 | "intermediate_size": 1708,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 10,
14 | "num_hidden_layers": 12,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32100
20 | }
--------------------------------------------------------------------------------
/configs/llama_130m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 768,
9 | "intermediate_size": 2048,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 12,
14 | "num_hidden_layers": 12,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_1b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 2048,
9 | "intermediate_size": 5461,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 24,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_20m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 256,
9 | "intermediate_size": 688,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 4,
14 | "num_hidden_layers": 4,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_250m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 768,
9 | "intermediate_size": 2560,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 16,
14 | "num_hidden_layers": 24,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_350m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 1024,
9 | "intermediate_size": 2736,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 16,
14 | "num_hidden_layers": 24,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_35m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 384,
9 | "intermediate_size": 1024,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 6,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_3b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 2560,
9 | "intermediate_size": 6848,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 32,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_40m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 416,
9 | "intermediate_size": 1024,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 8,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_60m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 512,
9 | "intermediate_size": 1376,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 8,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_71m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 512,
9 | "intermediate_size": 1368,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 8,
14 | "num_hidden_layers": 12,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 4096,
9 | "intermediate_size": 11008,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 2048,
12 | "model_type": "llama",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 32,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/configs/llama_9m.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LLaMAForCausalLM"
4 | ],
5 | "bos_token_id": 0,
6 | "eos_token_id": 1,
7 | "hidden_act": "silu",
8 | "hidden_size": 128,
9 | "intermediate_size": 352,
10 | "initializer_range": 0.02,
11 | "max_sequence_length": 1024,
12 | "model_type": "llama",
13 | "num_attention_heads": 4,
14 | "num_hidden_layers": 4,
15 | "pad_token_id": -1,
16 | "rms_norm_eps": 1e-06,
17 | "transformers_version": "4.28.1",
18 | "use_cache": true,
19 | "vocab_size": 32000
20 | }
--------------------------------------------------------------------------------
/exp_requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers==4.31.0
3 | tokenizers
4 | datasets
5 | peft
6 | wandb
7 | loguru
8 | nvitop
9 | lion-pytorch
10 | matplotlib
11 | bitsandbytes
12 | scipy
13 | scikit-learn
14 | evaluate
--------------------------------------------------------------------------------
/galore_torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .adafactor import Adafactor as GaLoreAdafactor
2 | from .adamw import AdamW as GaLoreAdamW
3 | from .adamw8bit import AdamW8bit as GaLoreAdamW8bit
--------------------------------------------------------------------------------
/galore_torch/adafactor.py:
--------------------------------------------------------------------------------
1 | # copy dependencies from transformers/optimization.py
2 | import math
3 |
4 | import torch
5 | from torch import nn
6 | from torch.optim import Optimizer
7 |
8 |
9 | from transformers.utils.versions import require_version
10 |
11 | from .galore_projector import GaLoreProjector
12 | from .galore_projector_tensor import GaLoreProjectorTensor
13 |
14 |
15 | class Adafactor(Optimizer):
16 | """
17 | AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam; adapted from the original fairseq code:
18 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
19 |
20 | Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
21 | this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
22 | `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
23 | `relative_step=False`.
24 |
25 | Arguments:
26 | params (`Iterable[nn.parameter.Parameter]`):
27 | Iterable of parameters to optimize or dictionaries defining parameter groups.
28 | lr (`float`, *optional*):
29 | The external learning rate.
30 | eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
31 | Regularization constants for square gradient and parameter scale respectively
32 | clip_threshold (`float`, *optional*, defaults to 1.0):
33 | Threshold of root mean square of final gradient update
34 | decay_rate (`float`, *optional*, defaults to -0.8):
35 | Coefficient used to compute running averages of square
36 | beta1 (`float`, *optional*):
37 | Coefficient used for computing running averages of gradient
38 | weight_decay (`float`, *optional*, defaults to 0.0):
39 | Weight decay (L2 penalty)
40 | scale_parameter (`bool`, *optional*, defaults to `True`):
41 | If True, learning rate is scaled by root mean square
42 | relative_step (`bool`, *optional*, defaults to `True`):
43 | If True, time-dependent learning rate is computed instead of external learning rate
44 | warmup_init (`bool`, *optional*, defaults to `False`):
45 | Time-dependent learning rate computation depends on whether warm-up initialization is being used
46 |
47 | This implementation handles low-precision (FP16, bfloat16) values, but it has not been thoroughly tested.
48 |
49 | Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
50 |
51 | - Training without LR warmup or clip_threshold is not recommended.
52 |
53 | - use scheduled LR warm-up to fixed LR
54 | - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
55 | - Disable relative updates
56 | - Use scale_parameter=False
57 | - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
58 |
59 | Example:
60 |
61 | ```python
62 | Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
63 | ```
64 |
65 | Others reported the following combination to work well:
66 |
67 | ```python
68 | Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
69 | ```
70 |
71 | When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
72 | scheduler as following:
73 |
74 | ```python
75 | from transformers.optimization import Adafactor, AdafactorSchedule
76 |
77 | optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
78 | lr_scheduler = AdafactorSchedule(optimizer)
79 | trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
80 | ```
81 |
82 | Usage:
83 |
84 | ```python
85 | # replace AdamW with Adafactor
86 | optimizer = Adafactor(
87 | model.parameters(),
88 | lr=1e-3,
89 | eps=(1e-30, 1e-3),
90 | clip_threshold=1.0,
91 | decay_rate=-0.8,
92 | beta1=None,
93 | weight_decay=0.0,
94 | relative_step=False,
95 | scale_parameter=False,
96 | warmup_init=False,
97 | )
98 | ```"""
99 |
100 | def __init__(
101 | self,
102 | params,
103 | lr=None,
104 | eps=(1e-30, 1e-3),
105 | clip_threshold=1.0,
106 | decay_rate=-0.8,
107 | beta1=None,
108 | weight_decay=0.0,
109 | scale_parameter=True,
110 | relative_step=True,
111 | warmup_init=False,
112 | ):
113 | require_version("torch>=1.5.0") # add_ with alpha
114 | if lr is not None and relative_step:
115 | raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
116 | if warmup_init and not relative_step:
117 | raise ValueError("`warmup_init=True` requires `relative_step=True`")
118 |
119 | defaults = {
120 | "lr": lr,
121 | "eps": eps,
122 | "clip_threshold": clip_threshold,
123 | "decay_rate": decay_rate,
124 | "beta1": beta1,
125 | "weight_decay": weight_decay,
126 | "scale_parameter": scale_parameter,
127 | "relative_step": relative_step,
128 | "warmup_init": warmup_init,
129 | }
130 | super().__init__(params, defaults)
131 |
132 | @staticmethod
133 | def _get_lr(param_group, param_state):
134 | rel_step_sz = param_group["lr"]
135 | if param_group["relative_step"]:
136 | min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
137 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
138 | param_scale = 1.0
139 | if param_group["scale_parameter"]:
140 | param_scale = max(param_group["eps"][1], param_state["RMS"])
141 | return param_scale * rel_step_sz
142 |
143 | @staticmethod
144 | def _get_options(param_group, param_shape):
145 | factored = len(param_shape) >= 2
146 | use_first_moment = param_group["beta1"] is not None
147 | return factored, use_first_moment
148 |
149 | @staticmethod
150 | def _rms(tensor):
151 | return tensor.norm(2) / (tensor.numel() ** 0.5)
152 |
153 | @staticmethod
154 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
155 | # copy from fairseq's adafactor implementation:
156 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
157 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
158 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
159 | return torch.mul(r_factor, c_factor)
160 |
161 | @torch.no_grad()
162 | def step(self, closure=None):
163 | """
164 | Performs a single optimization step
165 |
166 | Arguments:
167 | closure (callable, optional): A closure that reevaluates the model
168 | and returns the loss.
169 | """
170 | loss = None
171 | if closure is not None:
172 | loss = closure()
173 |
174 | for group in self.param_groups:
175 | for p in group["params"]:
176 | if p.grad is None:
177 | continue
178 | grad = p.grad
179 | if grad.dtype in {torch.float16, torch.bfloat16}:
180 | grad = grad.float()
181 | if grad.is_sparse:
182 | raise RuntimeError("Adafactor does not support sparse gradients.")
183 |
184 | state = self.state[p]
185 |
186 | if "step" not in state:
187 | state["step"] = 0
188 |
189 | if 'dim' not in group:
190 | group['dim'] = 2
191 |
192 | # GaLore Projection
193 | if "rank" in group:
194 | if "projector" not in state:
195 | if group['dim'] <=2:
196 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
197 | else:
198 | state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
199 |
200 | grad = state["projector"].project(grad, state["step"])
201 |
202 | grad_shape = grad.shape
203 |
204 | factored, use_first_moment = self._get_options(group, grad_shape)
205 | # State Initialization
206 | if "RMS" not in state:
207 | state["step"] = 0
208 |
209 | if use_first_moment:
210 | # Exponential moving average of gradient values
211 | state["exp_avg"] = torch.zeros_like(grad)
212 | if factored:
213 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
214 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
215 | else:
216 | state["exp_avg_sq"] = torch.zeros_like(grad)
217 |
218 | state["RMS"] = 0
219 | else:
220 | if use_first_moment:
221 | state["exp_avg"] = state["exp_avg"].to(grad)
222 | if factored:
223 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
224 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
225 | else:
226 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
227 |
228 | p_data_fp32 = p
229 | if p.dtype in {torch.float16, torch.bfloat16}:
230 | p_data_fp32 = p_data_fp32.float()
231 |
232 | state["step"] += 1
233 | state["RMS"] = self._rms(p_data_fp32)
234 | lr = self._get_lr(group, state)
235 |
236 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
237 | update = (grad**2) + group["eps"][0]
238 | if factored:
239 | exp_avg_sq_row = state["exp_avg_sq_row"]
240 | exp_avg_sq_col = state["exp_avg_sq_col"]
241 |
242 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
243 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
244 |
245 | # Approximation of exponential moving average of square of gradient
246 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
247 | update.mul_(grad)
248 | else:
249 | exp_avg_sq = state["exp_avg_sq"]
250 |
251 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
252 | update = exp_avg_sq.rsqrt().mul_(grad)
253 |
254 | update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
255 | update.mul_(lr)
256 |
257 | if use_first_moment:
258 | exp_avg = state["exp_avg"]
259 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
260 | update = exp_avg
261 |
262 | # GaLore Projection Back
263 | if "rank" in group:
264 | update = state["projector"].project_back(update)
265 |
266 | if group["weight_decay"] != 0:
267 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
268 |
269 | p_data_fp32.add_(-update)
270 |
271 | if p.dtype in {torch.float16, torch.bfloat16}:
272 | p.copy_(p_data_fp32)
273 |
274 | return loss
275 |
--------------------------------------------------------------------------------
/galore_torch/adamw.py:
--------------------------------------------------------------------------------
1 | # copy dependencies from transformers/optimization.py
2 | import math
3 | import warnings
4 | from typing import Callable, Iterable, Tuple
5 |
6 | import torch
7 | from torch import nn
8 | from torch.optim import Optimizer
9 |
10 | from transformers.utils.versions import require_version
11 |
12 | from .galore_projector import GaLoreProjector
13 | from .galore_projector_tensor import GaLoreProjectorTensor
14 |
15 |
16 | class AdamW(Optimizer):
17 | """
18 | Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
19 | Regularization](https://arxiv.org/abs/1711.05101).
20 |
21 | Parameters:
22 | params (`Iterable[nn.parameter.Parameter]`):
23 | Iterable of parameters to optimize or dictionaries defining parameter groups.
24 | lr (`float`, *optional*, defaults to 0.001):
25 | The learning rate to use.
26 | betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
27 | Adam's betas parameters (b1, b2).
28 | eps (`float`, *optional*, defaults to 1e-06):
29 | Adam's epsilon for numerical stability.
30 | weight_decay (`float`, *optional*, defaults to 0.0):
31 | Decoupled weight decay to apply.
32 | correct_bias (`bool`, *optional*, defaults to `True`):
33 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
34 | no_deprecation_warning (`bool`, *optional*, defaults to `False`):
35 | A flag used to disable the deprecation warning (set to `True` to disable the warning).
36 | """
37 |
38 | def __init__(
39 | self,
40 | params: Iterable[nn.parameter.Parameter],
41 | lr: float = 1e-3,
42 | betas: Tuple[float, float] = (0.9, 0.999),
43 | eps: float = 1e-6,
44 | weight_decay: float = 0.0,
45 | correct_bias: bool = True,
46 | no_deprecation_warning: bool = False,
47 | ):
48 | if not no_deprecation_warning:
49 | warnings.warn(
50 | "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
51 | " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
52 | " warning",
53 | FutureWarning,
54 | )
55 | require_version("torch>=1.5.0") # add_ with alpha
56 | if lr < 0.0:
57 | raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
58 | if not 0.0 <= betas[0] < 1.0:
59 | raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
60 | if not 0.0 <= betas[1] < 1.0:
61 | raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
62 | if not 0.0 <= eps:
63 | raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
64 | defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
65 | super().__init__(params, defaults)
66 |
67 | @torch.no_grad()
68 | def step(self, closure: Callable = None):
69 | """
70 | Performs a single optimization step.
71 |
72 | Arguments:
73 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
74 | """
75 | loss = None
76 | if closure is not None:
77 | loss = closure()
78 |
79 | for group in self.param_groups:
80 | for p in group["params"]:
81 | if p.grad is None:
82 | continue
83 | grad = p.grad
84 | if grad.is_sparse:
85 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
86 |
87 | state = self.state[p]
88 |
89 | if "step" not in state:
90 | state["step"] = 0
91 |
92 | if 'dim' not in group:
93 | group['dim'] = 2
94 |
95 | # GaLore Projection
96 | if "rank" in group:
97 | if "projector" not in state:
98 | if group['dim'] <=2:
99 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
100 | else:
101 | state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
102 | grad = state["projector"].project(grad, state["step"])
103 |
104 | # State initialization
105 | if "exp_avg" not in state:
106 | # Exponential moving average of gradient values
107 | state["exp_avg"] = torch.zeros_like(grad)
108 | # Exponential moving average of squared gradient values
109 | state["exp_avg_sq"] = torch.zeros_like(grad)
110 |
111 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
112 | beta1, beta2 = group["betas"]
113 |
114 | state["step"] += 1
115 |
116 | # Decay the first and second moment running average coefficient
117 | # In-place operations to update the averages at the same time
118 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
119 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
120 | denom = exp_avg_sq.sqrt().add_(group["eps"])
121 |
122 | step_size = group["lr"]
123 | if group["correct_bias"]: # No bias correction for Bert
124 | bias_correction1 = 1.0 - beta1 ** state["step"]
125 | bias_correction2 = 1.0 - beta2 ** state["step"]
126 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
127 |
128 | # compute norm gradient
129 | norm_grad = exp_avg / denom
130 |
131 | # GaLore Projection Back
132 | if "rank" in group:
133 | norm_grad = state["projector"].project_back(norm_grad)
134 |
135 | p.add_(norm_grad, alpha=-step_size)
136 |
137 | # Just adding the square of the weights to the loss function is *not*
138 | # the correct way of using L2 regularization/weight decay with Adam,
139 | # since that will interact with the m and v parameters in strange ways.
140 | #
141 | # Instead we want to decay the weights in a manner that doesn't interact
142 | # with the m/v parameters. This is equivalent to adding the square
143 | # of the weights to the loss with plain (non-momentum) SGD.
144 | # Add weight decay at the end (fixed version)
145 | if group["weight_decay"] > 0.0:
146 | p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
147 |
148 | return loss
149 |
--------------------------------------------------------------------------------
/galore_torch/adamw8bit.py:
--------------------------------------------------------------------------------
1 | from bitsandbytes.optim.optimizer import Optimizer2State
2 |
3 | import torch
4 |
5 | from .galore_projector import GaLoreProjector
6 | from .galore_projector_tensor import GaLoreProjectorTensor
7 |
8 |
9 | class AdamW8bit(Optimizer2State):
10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
11 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
12 |
13 | @torch.no_grad()
14 | def step(self, closure=None):
15 | """Performs a single optimization step.
16 |
17 | Arguments:
18 | closure (callable, optional): A closure that reevaluates the model
19 | and returns the loss.
20 | """
21 | loss = None
22 | if closure is not None:
23 | with torch.enable_grad():
24 | loss = closure()
25 |
26 | overflows = []
27 |
28 | if not self.initialized:
29 | self.check_overrides()
30 | self.to_gpu() # needed for fairseq pure fp16 training
31 | self.initialized = True
32 |
33 | #if self.is_paged: self.page_mng.prefetch_all()
34 | for gindex, group in enumerate(self.param_groups):
35 | for pindex, p in enumerate(group["params"]):
36 | if p.grad is None:
37 | continue
38 | state = self.state[p]
39 |
40 | if "step" not in state:
41 | state["step"] = 0
42 |
43 | if 'dim' not in group:
44 | group['dim'] = 2
45 |
46 | # GaLore Projection
47 | if "rank" in group:
48 | if "projector" not in state:
49 | if group['dim'] <= 2:
50 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
51 | else:
52 | state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
53 | if 'weight_decay' in group and group['weight_decay'] > 0:
54 | # ensure that the weight decay is not applied to the norm grad
55 | group['weight_decay_saved'] = group['weight_decay']
56 | group['weight_decay'] = 0
57 |
58 | grad = state["projector"].project(p.grad, state["step"])
59 |
60 | # suboptimal implementation
61 | p.saved_data = p.data.clone()
62 | p.data = grad.clone().to(p.data.dtype).to(p.data.device)
63 | p.data.zero_()
64 | p.grad = grad
65 |
66 | if 'state1' not in state:
67 | self.init_state(group, p, gindex, pindex)
68 |
69 | self.prefetch_state(p)
70 | self.update_step(group, p, gindex, pindex)
71 | torch.cuda.synchronize()
72 |
73 | # GaLore Projection Back
74 | if "rank" in group:
75 | p.data = p.saved_data.add_(state["projector"].project_back(p.data))
76 |
77 | # apply weight decay
78 | if 'weight_decay_saved' in group:
79 | p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay_saved'])
80 | group['weight_decay'] = group['weight_decay_saved']
81 | del group['weight_decay_saved']
82 |
83 | if self.is_paged:
84 | # all paged operation are asynchronous, we need
85 | # to sync to make sure all tensors are in the right state
86 | torch.cuda.synchronize()
87 |
88 |
89 | return loss
--------------------------------------------------------------------------------
/galore_torch/galore_projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class GaLoreProjector:
4 | def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
5 | self.rank = rank
6 | self.verbose = verbose
7 | self.update_proj_gap = update_proj_gap
8 | self.scale = scale
9 | self.ortho_matrix = None
10 | self.proj_type = proj_type
11 |
12 | def project(self, full_rank_grad, iter):
13 | if self.proj_type == 'std':
14 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
15 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
16 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
17 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t().to(full_rank_grad.device.type))
18 | else:
19 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
20 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
21 | low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type), full_rank_grad)
22 | elif self.proj_type == 'reverse_std':
23 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
24 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
25 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
26 | low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type),full_rank_grad)
27 | else:
28 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
29 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
30 | low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t().to(full_rank_grad.device.type))
31 | elif self.proj_type == 'right':
32 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
33 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
34 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t().to(full_rank_grad.device.type))
35 | elif self.proj_type == 'left':
36 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
37 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
38 | low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type), full_rank_grad)
39 | elif self.proj_type == 'full':
40 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
41 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
42 | low_rank_grad = torch.matmul(self.ortho_matrix[0].t().to(full_rank_grad.device.type), full_rank_grad) @ self.ortho_matrix[1].t().to(full_rank_grad.device.type)
43 |
44 | return low_rank_grad
45 |
46 | def project_back(self, low_rank_grad):
47 | if self.proj_type == 'std':
48 | if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
49 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type))
50 | else:
51 | full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)
52 | elif self.proj_type == 'reverse_std':
53 | if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
54 | full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)
55 | else:
56 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type))
57 | elif self.proj_type == 'right':
58 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type))
59 | elif self.proj_type == 'left':
60 | full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)
61 | elif self.proj_type == 'full':
62 | full_rank_grad = torch.matmul(self.ortho_matrix[0].to(low_rank_grad.device.type), low_rank_grad) @ self.ortho_matrix[1].to(low_rank_grad.device.type)
63 |
64 |
65 | return full_rank_grad * self.scale
66 |
67 |
68 | # svd decomposition
69 | def get_orthogonal_matrix(self, weights, rank, type):
70 | module_params = weights
71 |
72 | if module_params.data.dtype != torch.float:
73 | float_data = False
74 | original_type = module_params.data.dtype
75 | original_device = module_params.data.device
76 | matrix = module_params.data.float()
77 | else:
78 | float_data = True
79 | matrix = module_params.data
80 |
81 | U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)
82 |
83 | # make the smaller of the two factors the orthogonal projection matrix
84 | if type=='right':
85 | B = Vh[:rank, :]
86 | if not float_data:
87 | B = B.to(original_device).type(original_type)
88 | return B
89 | elif type=='left':
90 | A = U[:, :rank]
91 | if not float_data:
92 | A = A.to(original_device).type(original_type)
93 | return A
94 | elif type=='full':
95 | A = U[:, :rank]
96 | B = Vh[:rank, :]
97 | if not float_data:
98 | A = A.to(original_device).type(original_type)
99 | B = B.to(original_device).type(original_type)
100 | return [A, B]
101 | else:
102 | raise ValueError('type should be left, right or full')
103 |
--------------------------------------------------------------------------------
/galore_torch/galore_projector_tensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tensorly.decomposition import tucker
3 | from tensorly import tenalg
4 |
5 | # The GaLoreProjectorTensor class implements a projection method using orthogonal tensor
6 | # decomposition for low-rank approximation of gradients of general tensors of dimension > 2.
7 | # We use Tucker decomposition from the tensorly library: https://tensorly.org/stable/index.html
8 | class GaLoreProjectorTensor:
9 | """
10 | A class that represents a projector for the GaLore algorithm.
11 |
12 | Args:
13 | rank (int): The rank of the projector.
14 | verbose (bool, optional): Whether to print verbose output. Defaults to False.
15 | update_proj_gap (int, optional): The number of iterations between updating the orthogonal matrix. Defaults to 200.
16 | scale (float, optional): The scaling factor for the projected gradients. Defaults to 1.0.
17 | """
18 |
19 | def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0):
20 | self.rank = rank
21 | self.verbose = verbose
22 | self.update_proj_gap = update_proj_gap
23 | self.scale = scale
24 | self.ortho_matrix = None
25 | self.transformed_low_rank = None
26 |
27 | def project(self, full_rank_grad, iter):
28 | """
29 | Projects the full-rank gradients onto the low-rank subspace.
30 |
31 | Args:
32 | full_rank_grad (torch.Tensor): The full-rank gradients.
33 | iter (int): The current iteration.
34 |
35 | Returns:
36 | torch.Tensor: The transformed low-rank gradients.
37 | """
38 | if self.ortho_matrix is None and iter % self.update_proj_gap == 0:
39 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank)
40 | self.transformed_low_rank = self.transform(self.ortho_matrix, full_rank_grad)
41 | return self.transformed_low_rank
42 |
43 | def project_back(self, low_rank_grad):
44 | """
45 | Projects the low-rank gradients back to the full-rank space.
46 |
47 | Args:
48 | low_rank_grad (torch.Tensor): The low-rank gradients.
49 |
50 | Returns:
51 | torch.Tensor: The full-rank gradients.
52 | """
53 | full_rank_grad = self.inverse_transform(self.ortho_matrix, self.transformed_low_rank)
54 | return full_rank_grad * self.scale
55 |
56 | # tucker decomposition
57 | def get_orthogonal_matrix(self, weights, rank_all):
58 | """
59 | Computes the orthogonal factors using Tucker decomposition.
60 |
61 | Args:
62 | weights (torch.Tensor): The weights to decompose.
63 | rank_all (int): The desired rank of the decomposition.
64 |
65 | Returns:
66 | tuple: A tuple containing the core and factors of the orthogonal matrix.
67 | """
68 | module_params = weights
69 | if module_params.data.dtype != torch.float:
70 | matrix = module_params.data.float()
71 | else:
72 | matrix = module_params.data
73 | tucker_tensor = tucker(matrix, rank=rank_all)
74 | return tucker_tensor
75 |
76 | def transform(self, tensor, x):
77 | """
78 | Transforms the input tensor using the factors of the orthogonal matrix.
79 |
80 | Args:
81 | tensor (tuple): A tuple containing the core and factors of the orthogonal matrix.
82 | x (torch.Tensor): The input tensor.
83 |
84 | Returns:
85 | torch.Tensor: The transformed tensor.
86 | """
87 | _, factors = tensor
88 | return tenalg.multi_mode_dot(x, factors, transpose=True)
89 |
90 | def inverse_transform(self, tensor, x):
91 | """
92 | Inverse transforms the input tensor using the factors of the orthogonal matrix.
93 |
94 | Args:
95 | tensor (tuple): A tuple containing the core and factors of the orthogonal matrix.
96 | x (torch.Tensor): The input tensor.
97 |
98 | Returns:
99 | torch.Tensor: The inverse transformed tensor.
100 | """
101 | _, factors = tensor
102 | return tenalg.multi_mode_dot(x, factors)
103 |
--------------------------------------------------------------------------------
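A minimal usage sketch for `GaLoreProjectorTensor`, assuming tensorly is switched to its PyTorch backend so `tucker` operates directly on `torch` tensors; the gradient shape, rank, and scale below are hypothetical. Note that the Tucker factors are fitted on the first call whose `iter` is a multiple of `update_proj_gap` and reused afterwards.

```python
import torch
import tensorly as tl
from galore_torch.galore_projector_tensor import GaLoreProjectorTensor

tl.set_backend("pytorch")  # assumption: run tucker directly on torch tensors

projector = GaLoreProjectorTensor(rank=2, update_proj_gap=200, scale=0.25)

# A >2-D "gradient", e.g. for a hypothetical convolution weight of shape (out, in, kh, kw).
grad = torch.randn(16, 8, 3, 3)

for step in range(3):
    low_rank = projector.project(grad, iter=step)   # factors fitted at step 0, reused afterwards
    restored = projector.project_back(low_rank)     # back to the original shape, scaled by `scale`

print(low_rank.shape)   # torch.Size([2, 2, 2, 2])
print(restored.shape)   # torch.Size([16, 8, 3, 3])
```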
/imgs/galore_code_box.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaweizzhao/GaLore/2cc66f88cce189e505affbb91042a8e77f5bf4e9/imgs/galore_code_box.png
--------------------------------------------------------------------------------
/imgs/subspace_learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaweizzhao/GaLore/2cc66f88cce189e505affbb91042a8e77f5bf4e9/imgs/subspace_learning.png
--------------------------------------------------------------------------------
/peft_pretraining/args_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | from loguru import logger
5 |
6 |
7 | def check_args_torchrun_main(args):
8 |
9 | if args.save_dir is None:
10 | # use checkpoints / model name, date and time as save directory
11 | args.save_dir = f"checkpoints/{args.model_config.split('/')[-1].rstrip('.json')}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
12 |
13 | if args.tags is not None:
14 | args.tags = args.tags.split(",")
15 |
16 | if args.total_batch_size is None:
17 | args.gradient_accumulation = args.gradient_accumulation or 1
18 | args.total_batch_size = args.batch_size * args.gradient_accumulation
19 |
20 | assert args.total_batch_size % args.batch_size == 0, "total_batch_size must be divisible by batch_size"
21 |
22 | if args.max_train_tokens is not None:
23 | args.num_training_steps = args.max_train_tokens // args.total_batch_size
24 | logger.info(f"Training for {args.num_training_steps} update steps")
25 |
26 | if args.continue_from is not None:
27 | assert os.path.exists(args.continue_from), f"--continue_from={args.continue_from} does not exist"
28 |
29 | if args.dtype in ["fp16", "float16"]:
30 | raise NotImplementedError("fp16 is not supported in torchrun_main.py. Use deepspeed_main.py instead (but it seems to have bugs)")
31 |
32 | return args
33 |
--------------------------------------------------------------------------------
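A minimal sketch of what `check_args_torchrun_main` derives, using a hypothetical `argparse.Namespace` in place of the parser defined in torchrun_main.py (all values below are made up):

```python
from argparse import Namespace
from peft_pretraining.args_utils import check_args_torchrun_main

# Hypothetical values; the real ones come from torchrun_main.py's argument parser.
args = Namespace(
    save_dir=None,
    model_config="configs/llama_60m.json",
    tags="galore,c4",
    batch_size=16,
    gradient_accumulation=4,
    total_batch_size=None,
    max_train_tokens=1_000_000,
    continue_from=None,
    dtype="bfloat16",
)

args = check_args_torchrun_main(args)
print(args.total_batch_size)    # 64 = batch_size * gradient_accumulation
print(args.tags)                # ['galore', 'c4']
print(args.num_training_steps)  # 1_000_000 // 64 = 15625
print(args.save_dir)            # checkpoints/llama_60m-<date>-<time>
```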
/peft_pretraining/dataloader.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import torch
4 | from torch.utils.data import IterableDataset, get_worker_info
5 |
6 |
7 | class PreprocessedIterableDataset(IterableDataset):
8 | def __init__(self, data, tokenizer, batch_size, max_length):
9 | super().__init__()
10 | self.data = data
11 | self.tokenizer = tokenizer
12 | self.batch_size = batch_size
13 | self.max_length = max_length
14 |
15 | def __iter__(self):
16 | worker_info = get_worker_info()
17 | if worker_info is None:
18 | # If no worker_info is provided, we are not using DataLoader workers, so yield all data
19 | iter_data = iter(self.data)
20 | else:
21 | # If using DataLoader workers, yield a subset of the data for this worker
22 | worker_id = worker_info.id
23 | num_workers = worker_info.num_workers
24 | iter_data = itertools.islice(self.data, worker_id, None, num_workers)
25 |
26 | batch = []
27 | for example in iter_data:
28 | tokenized_example = self.tokenizer(
29 | example["text"],
30 | max_length=self.max_length,
31 | truncation=True,
32 | padding="max_length",
33 | return_tensors="pt",
34 | )
35 | batch.append(tokenized_example)
36 |
37 | if len(batch) == self.batch_size:
38 | yield self._format_batch(batch)
39 | batch = []
40 |
41 | if batch:
42 | yield self._format_batch(batch)
43 |
44 | def _format_batch(self, batch):
45 | input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
46 | attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
47 |
48 | return {"input_ids": input_ids, "attention_mask": attention_mask}
49 |
--------------------------------------------------------------------------------
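`PreprocessedIterableDataset` tokenizes on the fly and yields already collated batches, sharding examples across DataLoader workers with `itertools.islice`. A minimal sketch with an in-memory list standing in for the streaming corpus; the tokenizer choice is an assumption (any tokenizer that defines a pad token works with `padding="max_length"`):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from peft_pretraining.dataloader import PreprocessedIterableDataset

# Hypothetical in-memory stand-in for the streaming corpus used during pretraining.
raw_data = [{"text": f"example document number {i}"} for i in range(10)]

tokenizer = AutoTokenizer.from_pretrained("t5-base")  # assumption: any tokenizer with a pad token

dataset = PreprocessedIterableDataset(raw_data, tokenizer, batch_size=4, max_length=32)

# The dataset yields fully formed batches, so the DataLoader does no extra batching.
loader = DataLoader(dataset, batch_size=None, num_workers=2)

for batch in loader:
    # Each worker yields [4, 32] batches plus one final partial batch from its shard.
    print(batch["input_ids"].shape, batch["attention_mask"].shape)
```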
/peft_pretraining/modeling_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch LLaMA model."""
21 | import math
22 | from typing import List, Optional, Tuple, Union
23 |
24 | import torch
25 | import torch.utils.checkpoint
26 | from torch import nn
27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28 |
29 | from transformers.activations import ACT2FN
30 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31 | from transformers.modeling_utils import PreTrainedModel
32 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33 | from transformers.models.llama.configuration_llama import LlamaConfig
34 |
35 |
36 | logger = logging.get_logger(__name__)
37 |
38 | _CONFIG_FOR_DOC = "LlamaConfig"
39 |
40 |
41 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42 | def _make_causal_mask(
43 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44 | ):
45 | """
46 | Make causal mask used for bi-directional self-attention.
47 | """
48 | bsz, tgt_len = input_ids_shape
49 | mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50 | mask_cond = torch.arange(mask.size(-1), device=device)
51 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52 | mask = mask.to(dtype)
53 |
54 | if past_key_values_length > 0:
55 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57 |
58 |
59 | # Copied from transformers.models.bart.modeling_bart._expand_mask
60 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61 | """
62 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63 | """
64 | bsz, src_len = mask.size()
65 | tgt_len = tgt_len if tgt_len is not None else src_len
66 |
67 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68 |
69 | inverted_mask = 1.0 - expanded_mask
70 |
71 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72 |
73 |
74 | class LlamaRMSNorm(nn.Module):
75 | def __init__(self, hidden_size, eps=1e-6):
76 | """
77 | LlamaRMSNorm is equivalent to T5LayerNorm
78 | """
79 | super().__init__()
80 | self.weight = nn.Parameter(torch.ones(hidden_size))
81 | self.variance_epsilon = eps
82 |
83 | def forward(self, hidden_states):
84 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86 |
87 | # convert into half-precision if necessary
88 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
89 | hidden_states = hidden_states.to(self.weight.dtype)
90 |
91 | return self.weight * hidden_states
92 |
93 |
94 | class LlamaRotaryEmbedding(torch.nn.Module):
95 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
96 | super().__init__()
97 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
98 | self.register_buffer("inv_freq", inv_freq)
99 |
100 | # Build here to make `torch.jit.trace` work.
101 | self.max_seq_len_cached = max_position_embeddings
102 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
103 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
104 | # Differs from the paper, but uses an equivalent permutation that yields the same result
105 | emb = torch.cat((freqs, freqs), dim=-1)
106 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
107 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
108 |
109 | def forward(self, x, seq_len=None):
110 | # x: [bs, num_attention_heads, seq_len, head_size]
111 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
112 | if seq_len > self.max_seq_len_cached:
113 | self.max_seq_len_cached = seq_len
114 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
115 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
116 | # Differs from the paper, but uses an equivalent permutation that yields the same result
117 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
118 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
119 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
120 | return (
121 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
122 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123 | )
124 |
125 |
126 | def rotate_half(x):
127 | """Rotates half the hidden dims of the input."""
128 | x1 = x[..., : x.shape[-1] // 2]
129 | x2 = x[..., x.shape[-1] // 2 :]
130 | return torch.cat((-x2, x1), dim=-1)
131 |
132 |
133 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
134 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
135 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
136 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
137 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
138 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
139 | q_embed = (q * cos) + (rotate_half(q) * sin)
140 | k_embed = (k * cos) + (rotate_half(k) * sin)
141 | return q_embed, k_embed
142 |
143 |
144 | class LlamaMLP(nn.Module):
145 | def __init__(
146 | self,
147 | hidden_size: int,
148 | intermediate_size: int,
149 | hidden_act: str,
150 | ):
151 | super().__init__()
152 | self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
153 | self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
154 | self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
155 | self.act_fn = ACT2FN[hidden_act]
156 |
157 | def forward(self, x):
158 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
159 |
160 |
161 | class LlamaAttention(nn.Module):
162 | """Multi-headed attention from 'Attention Is All You Need' paper"""
163 |
164 | def __init__(self, config: LlamaConfig):
165 | super().__init__()
166 | self.config = config
167 | self.hidden_size = config.hidden_size
168 | self.num_heads = config.num_attention_heads
169 | self.head_dim = self.hidden_size // self.num_heads
170 | self.max_position_embeddings = config.max_position_embeddings
171 |
172 | if (self.head_dim * self.num_heads) != self.hidden_size:
173 | raise ValueError(
174 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
175 | f" and `num_heads`: {self.num_heads})."
176 | )
177 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
178 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
179 | self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
180 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
181 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
182 |
183 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
184 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
185 |
186 | def forward(
187 | self,
188 | hidden_states: torch.Tensor,
189 | attention_mask: Optional[torch.Tensor] = None,
190 | position_ids: Optional[torch.LongTensor] = None,
191 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
192 | output_attentions: bool = False,
193 | use_cache: bool = False,
194 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
195 | bsz, q_len, _ = hidden_states.size()
196 |
197 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
199 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
200 |
201 | kv_seq_len = key_states.shape[-2]
202 | if past_key_value is not None:
203 | kv_seq_len += past_key_value[0].shape[-2]
204 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
205 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
206 | # [bsz, nh, t, hd]
207 |
208 | if past_key_value is not None:
209 | # reuse k, v, self_attention
210 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
211 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
212 |
213 | past_key_value = (key_states, value_states) if use_cache else None
214 |
215 | if attention_mask is not None:
216 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
217 | raise ValueError(
218 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
219 | )
220 |
221 | # WARNING: padding mask is ignored, causal is always applied
222 | attn_output = torch.nn.functional.scaled_dot_product_attention(
223 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True,
224 | )
225 |
226 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
227 | raise ValueError(
228 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
229 | f" {attn_output.size()}"
230 | )
231 |
232 | attn_output = attn_output.transpose(1, 2)
233 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
234 |
235 | attn_output = self.o_proj(attn_output)
236 |
237 | if not output_attentions:
238 | attn_weights = None
239 |
240 | return attn_output, attn_weights, past_key_value
241 |
242 |
243 | class LlamaDecoderLayer(nn.Module):
244 | def __init__(self, config: LlamaConfig):
245 | super().__init__()
246 | self.hidden_size = config.hidden_size
247 | self.self_attn = LlamaAttention(config=config)
248 | self.mlp = LlamaMLP(
249 | hidden_size=self.hidden_size,
250 | intermediate_size=config.intermediate_size,
251 | hidden_act=config.hidden_act,
252 | )
253 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
254 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
255 |
256 | def forward(
257 | self,
258 | hidden_states: torch.Tensor,
259 | attention_mask: Optional[torch.Tensor] = None,
260 | position_ids: Optional[torch.LongTensor] = None,
261 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
262 | output_attentions: Optional[bool] = False,
263 | use_cache: Optional[bool] = False,
264 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
265 | """
266 | Args:
267 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
268 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
269 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
270 | output_attentions (`bool`, *optional*):
271 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
272 | returned tensors for more detail.
273 | use_cache (`bool`, *optional*):
274 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
275 | (see `past_key_values`).
276 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
277 | """
278 |
279 | residual = hidden_states
280 |
281 | hidden_states = self.input_layernorm(hidden_states)
282 |
283 | # Self Attention
284 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
285 | hidden_states=hidden_states,
286 | attention_mask=attention_mask,
287 | position_ids=position_ids,
288 | past_key_value=past_key_value,
289 | output_attentions=output_attentions,
290 | use_cache=use_cache,
291 | )
292 | hidden_states = residual + hidden_states
293 |
294 | # Fully Connected
295 | residual = hidden_states
296 | hidden_states = self.post_attention_layernorm(hidden_states)
297 | hidden_states = self.mlp(hidden_states)
298 | hidden_states = residual + hidden_states
299 |
300 | outputs = (hidden_states,)
301 |
302 | if output_attentions:
303 | outputs += (self_attn_weights,)
304 |
305 | if use_cache:
306 | outputs += (present_key_value,)
307 |
308 | return outputs
309 |
310 |
311 | LLAMA_START_DOCSTRING = r"""
312 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
313 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
314 | etc.)
315 |
316 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
317 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
318 | and behavior.
319 |
320 | Parameters:
321 | config ([`LlamaConfig`]):
322 | Model configuration class with all the parameters of the model. Initializing with a config file does not
323 | load the weights associated with the model, only the configuration. Check out the
324 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
325 | """
326 |
327 |
328 | @add_start_docstrings(
329 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
330 | LLAMA_START_DOCSTRING,
331 | )
332 | class LlamaPreTrainedModel(PreTrainedModel):
333 | config_class = LlamaConfig
334 | base_model_prefix = "model"
335 | supports_gradient_checkpointing = True
336 | _no_split_modules = ["LlamaDecoderLayer"]
337 | _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
338 |
339 | def _init_weights(self, module):
340 | std = self.config.initializer_range
341 | if isinstance(module, nn.Linear):
342 | module.weight.data.normal_(mean=0.0, std=std)
343 | if module.bias is not None:
344 | module.bias.data.zero_()
345 | elif isinstance(module, nn.Embedding):
346 | module.weight.data.normal_(mean=0.0, std=std)
347 | if module.padding_idx is not None:
348 | module.weight.data[module.padding_idx].zero_()
349 |
350 | def _set_gradient_checkpointing(self, module, value=False):
351 | if isinstance(module, LlamaModel):
352 | module.gradient_checkpointing = value
353 |
354 |
355 | LLAMA_INPUTS_DOCSTRING = r"""
356 | Args:
357 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
358 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
359 | it.
360 |
361 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
362 | [`PreTrainedTokenizer.__call__`] for details.
363 |
364 | [What are input IDs?](../glossary#input-ids)
365 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
366 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
367 |
368 | - 1 for tokens that are **not masked**,
369 | - 0 for tokens that are **masked**.
370 |
371 | [What are attention masks?](../glossary#attention-mask)
372 |
373 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
374 | [`PreTrainedTokenizer.__call__`] for details.
375 |
376 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
377 | `past_key_values`).
378 |
379 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
380 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
381 | information on the default strategy.
382 |
383 | - 1 indicates the head is **not masked**,
384 | - 0 indicates the head is **masked**.
385 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
386 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
387 | config.n_positions - 1]`.
388 |
389 | [What are position IDs?](../glossary#position-ids)
390 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
391 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
392 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
393 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
394 |
395 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
396 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
397 |
398 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
399 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
400 | `decoder_input_ids` of shape `(batch_size, sequence_length)`.
401 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
402 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
403 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
404 | model's internal embedding lookup matrix.
405 | use_cache (`bool`, *optional*):
406 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
407 | `past_key_values`).
408 | output_attentions (`bool`, *optional*):
409 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
410 | tensors for more detail.
411 | output_hidden_states (`bool`, *optional*):
412 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
413 | more detail.
414 | return_dict (`bool`, *optional*):
415 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
416 | """
417 |
418 |
419 | @add_start_docstrings(
420 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
421 | LLAMA_START_DOCSTRING,
422 | )
423 | class LlamaModel(LlamaPreTrainedModel):
424 | """
425 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
426 |
427 | Args:
428 | config: LlamaConfig
429 | """
430 |
431 | def __init__(self, config: LlamaConfig):
432 | super().__init__(config)
433 | self.padding_idx = config.pad_token_id
434 | self.vocab_size = config.vocab_size
435 |
436 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
437 | self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
438 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
439 |
440 | self.gradient_checkpointing = False
441 | # Initialize weights and apply final processing
442 | self.post_init()
443 |
444 | def get_input_embeddings(self):
445 | return self.embed_tokens
446 |
447 | def set_input_embeddings(self, value):
448 | self.embed_tokens = value
449 |
450 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
451 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
452 | # create causal mask
453 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
454 | combined_attention_mask = None
455 | if input_shape[-1] > 1:
456 | combined_attention_mask = _make_causal_mask(
457 | input_shape,
458 | inputs_embeds.dtype,
459 | device=inputs_embeds.device,
460 | past_key_values_length=past_key_values_length,
461 | )
462 |
463 | if attention_mask is not None:
464 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
465 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
466 | inputs_embeds.device
467 | )
468 | combined_attention_mask = (
469 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
470 | )
471 |
472 | return combined_attention_mask
473 |
474 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
475 | def forward(
476 | self,
477 | input_ids: torch.LongTensor = None,
478 | attention_mask: Optional[torch.Tensor] = None,
479 | position_ids: Optional[torch.LongTensor] = None,
480 | past_key_values: Optional[List[torch.FloatTensor]] = None,
481 | inputs_embeds: Optional[torch.FloatTensor] = None,
482 | use_cache: Optional[bool] = None,
483 | output_attentions: Optional[bool] = None,
484 | output_hidden_states: Optional[bool] = None,
485 | return_dict: Optional[bool] = None,
486 | ) -> Union[Tuple, BaseModelOutputWithPast]:
487 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
488 | output_hidden_states = (
489 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
490 | )
491 | use_cache = use_cache if use_cache is not None else self.config.use_cache
492 |
493 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494 |
495 | # retrieve input_ids and inputs_embeds
496 | if input_ids is not None and inputs_embeds is not None:
497 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
498 | elif input_ids is not None:
499 | batch_size, seq_length = input_ids.shape
500 | elif inputs_embeds is not None:
501 | batch_size, seq_length, _ = inputs_embeds.shape
502 | else:
503 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
504 |
505 | seq_length_with_past = seq_length
506 | past_key_values_length = 0
507 |
508 | if past_key_values is not None:
509 | past_key_values_length = past_key_values[0][0].shape[2]
510 | seq_length_with_past = seq_length_with_past + past_key_values_length
511 |
512 | if position_ids is None:
513 | device = input_ids.device if input_ids is not None else inputs_embeds.device
514 | position_ids = torch.arange(
515 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
516 | )
517 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
518 | else:
519 | position_ids = position_ids.view(-1, seq_length).long()
520 |
521 | if inputs_embeds is None:
522 | inputs_embeds = self.embed_tokens(input_ids)
523 | # embed positions
524 | if attention_mask is None:
525 | attention_mask = torch.ones(
526 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
527 | )
528 | attention_mask = self._prepare_decoder_attention_mask(
529 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
530 | )
531 |
532 | hidden_states = inputs_embeds
533 |
534 | if self.gradient_checkpointing and self.training:
535 | if use_cache:
536 | logger.warning_once(
537 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
538 | )
539 | use_cache = False
540 |
541 | # decoder layers
542 | all_hidden_states = () if output_hidden_states else None
543 | all_self_attns = () if output_attentions else None
544 | next_decoder_cache = () if use_cache else None
545 |
546 | for idx, decoder_layer in enumerate(self.layers):
547 | if output_hidden_states:
548 | all_hidden_states += (hidden_states,)
549 |
550 | past_key_value = past_key_values[idx] if past_key_values is not None else None
551 |
552 | if self.gradient_checkpointing and self.training:
553 |
554 | def create_custom_forward(module):
555 | def custom_forward(*inputs):
556 | # None for past_key_value
557 | return module(*inputs, output_attentions, None)
558 |
559 | return custom_forward
560 |
561 | layer_outputs = torch.utils.checkpoint.checkpoint(
562 | create_custom_forward(decoder_layer),
563 | hidden_states,
564 | attention_mask,
565 | position_ids,
566 | None,
567 | )
568 | else:
569 | layer_outputs = decoder_layer(
570 | hidden_states,
571 | attention_mask=attention_mask,
572 | position_ids=position_ids,
573 | past_key_value=past_key_value,
574 | output_attentions=output_attentions,
575 | use_cache=use_cache,
576 | )
577 |
578 | hidden_states = layer_outputs[0]
579 |
580 | if use_cache:
581 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
582 |
583 | if output_attentions:
584 | all_self_attns += (layer_outputs[1],)
585 |
586 | hidden_states = self.norm(hidden_states)
587 |
588 | # add hidden states from the last decoder layer
589 | if output_hidden_states:
590 | all_hidden_states += (hidden_states,)
591 |
592 | next_cache = next_decoder_cache if use_cache else None
593 | if not return_dict:
594 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
595 | return BaseModelOutputWithPast(
596 | last_hidden_state=hidden_states,
597 | past_key_values=next_cache,
598 | hidden_states=all_hidden_states,
599 | attentions=all_self_attns,
600 | )
601 |
602 |
603 | class LlamaForCausalLM(LlamaPreTrainedModel):
604 | def __init__(self, config):
605 | super().__init__(config)
606 | self.model = LlamaModel(config)
607 |
608 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
609 |
610 | # Initialize weights and apply final processing
611 | self.post_init()
612 |
613 | def get_input_embeddings(self):
614 | return self.model.embed_tokens
615 |
616 | def set_input_embeddings(self, value):
617 | self.model.embed_tokens = value
618 |
619 | def get_output_embeddings(self):
620 | return self.lm_head
621 |
622 | def set_output_embeddings(self, new_embeddings):
623 | self.lm_head = new_embeddings
624 |
625 | def set_decoder(self, decoder):
626 | self.model = decoder
627 |
628 | def get_decoder(self):
629 | return self.model
630 |
631 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
632 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
633 | def forward(
634 | self,
635 | input_ids: torch.LongTensor = None,
636 | attention_mask: Optional[torch.Tensor] = None,
637 | position_ids: Optional[torch.LongTensor] = None,
638 | past_key_values: Optional[List[torch.FloatTensor]] = None,
639 | inputs_embeds: Optional[torch.FloatTensor] = None,
640 | labels: Optional[torch.LongTensor] = None,
641 | use_cache: Optional[bool] = None,
642 | output_attentions: Optional[bool] = None,
643 | output_hidden_states: Optional[bool] = None,
644 | return_dict: Optional[bool] = None,
645 | ) -> Union[Tuple, CausalLMOutputWithPast]:
646 | r"""
647 | Args:
648 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
649 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
650 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
651 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
652 |
653 | Returns:
654 |
655 | Example:
656 |
657 | ```python
658 | >>> from transformers import AutoTokenizer, LlamaForCausalLM
659 |
660 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
661 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
662 |
663 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
664 | >>> inputs = tokenizer(prompt, return_tensors="pt")
665 |
666 | >>> # Generate
667 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
668 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
669 | "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
670 | ```"""
671 |
672 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
673 | output_hidden_states = (
674 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
675 | )
676 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
677 |
678 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
679 | outputs = self.model(
680 | input_ids=input_ids,
681 | attention_mask=attention_mask,
682 | position_ids=position_ids,
683 | past_key_values=past_key_values,
684 | inputs_embeds=inputs_embeds,
685 | use_cache=use_cache,
686 | output_attentions=output_attentions,
687 | output_hidden_states=output_hidden_states,
688 | return_dict=return_dict,
689 | )
690 |
691 | hidden_states = outputs[0]
692 | logits = self.lm_head(hidden_states)
693 |
694 | loss = None
695 | if labels is not None:
696 | # NOTE: a significant optimization may be possible here;
697 | # the copy operation observed in the debugger may originate from this block
698 |
699 | # Shift so that tokens < n predict n
700 | shift_logits = logits[..., :-1, :].contiguous()
701 | shift_labels = labels[..., 1:].contiguous()
702 | # Flatten the tokens
703 | loss_fct = CrossEntropyLoss()
704 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
705 | shift_labels = shift_labels.view(-1)
706 | # Enable model parallelism
707 | shift_labels = shift_labels.to(shift_logits.device)
708 | loss = loss_fct(shift_logits, shift_labels)
709 |
710 | if not return_dict:
711 | output = (logits,) + outputs[1:]
712 | return (loss,) + output if loss is not None else output
713 |
714 | return CausalLMOutputWithPast(
715 | loss=loss,
716 | logits=logits,
717 | past_key_values=outputs.past_key_values,
718 | hidden_states=outputs.hidden_states,
719 | attentions=outputs.attentions,
720 | )
721 |
722 | def prepare_inputs_for_generation(
723 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
724 | ):
725 | if past_key_values:
726 | input_ids = input_ids[:, -1:]
727 |
728 | position_ids = kwargs.get("position_ids", None)
729 | if attention_mask is not None and position_ids is None:
730 | # create position_ids on the fly for batch generation
731 | position_ids = attention_mask.long().cumsum(-1) - 1
732 | position_ids.masked_fill_(attention_mask == 0, 1)
733 | if past_key_values:
734 | position_ids = position_ids[:, -1].unsqueeze(-1)
735 |
736 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
737 | if inputs_embeds is not None and past_key_values is None:
738 | model_inputs = {"inputs_embeds": inputs_embeds}
739 | else:
740 | model_inputs = {"input_ids": input_ids}
741 |
742 | model_inputs.update(
743 | {
744 | "position_ids": position_ids,
745 | "past_key_values": past_key_values,
746 | "use_cache": kwargs.get("use_cache"),
747 | "attention_mask": attention_mask,
748 | }
749 | )
750 | return model_inputs
751 |
752 | @staticmethod
753 | def _reorder_cache(past_key_values, beam_idx):
754 | reordered_past = ()
755 | for layer_past in past_key_values:
756 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
757 | return reordered_past
758 |
759 |
760 | @add_start_docstrings(
761 | """
762 | The LLaMa Model transformer with a sequence classification head on top (linear layer).
763 |
764 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
765 | (e.g. GPT-2) do.
766 |
767 | Since it does classification on the last token, it needs to know the position of the last token. If a
768 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
769 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
770 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
771 | each row of the batch).
772 | """,
773 | LLAMA_START_DOCSTRING,
774 | )
775 | class LlamaForSequenceClassification(LlamaPreTrainedModel):
776 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
777 |
778 | def __init__(self, config):
779 | super().__init__(config)
780 | self.num_labels = config.num_labels
781 | self.model = LlamaModel(config)
782 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
783 |
784 | # Initialize weights and apply final processing
785 | self.post_init()
786 |
787 | def get_input_embeddings(self):
788 | return self.model.embed_tokens
789 |
790 | def set_input_embeddings(self, value):
791 | self.model.embed_tokens = value
792 |
793 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
794 | def forward(
795 | self,
796 | input_ids: torch.LongTensor = None,
797 | attention_mask: Optional[torch.Tensor] = None,
798 | position_ids: Optional[torch.LongTensor] = None,
799 | past_key_values: Optional[List[torch.FloatTensor]] = None,
800 | inputs_embeds: Optional[torch.FloatTensor] = None,
801 | labels: Optional[torch.LongTensor] = None,
802 | use_cache: Optional[bool] = None,
803 | output_attentions: Optional[bool] = None,
804 | output_hidden_states: Optional[bool] = None,
805 | return_dict: Optional[bool] = None,
806 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
807 | r"""
808 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
809 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
810 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
811 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
812 | """
813 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
814 |
815 | transformer_outputs = self.model(
816 | input_ids,
817 | attention_mask=attention_mask,
818 | position_ids=position_ids,
819 | past_key_values=past_key_values,
820 | inputs_embeds=inputs_embeds,
821 | use_cache=use_cache,
822 | output_attentions=output_attentions,
823 | output_hidden_states=output_hidden_states,
824 | return_dict=return_dict,
825 | )
826 | hidden_states = transformer_outputs[0]
827 | logits = self.score(hidden_states)
828 |
829 | if input_ids is not None:
830 | batch_size = input_ids.shape[0]
831 | else:
832 | batch_size = inputs_embeds.shape[0]
833 |
834 | if self.config.pad_token_id is None and batch_size != 1:
835 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
836 | if self.config.pad_token_id is None:
837 | sequence_lengths = -1
838 | else:
839 | if input_ids is not None:
840 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
841 | else:
842 | sequence_lengths = -1
843 |
844 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
845 |
846 | loss = None
847 | if labels is not None:
848 | labels = labels.to(logits.device)
849 | if self.config.problem_type is None:
850 | if self.num_labels == 1:
851 | self.config.problem_type = "regression"
852 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
853 | self.config.problem_type = "single_label_classification"
854 | else:
855 | self.config.problem_type = "multi_label_classification"
856 |
857 | if self.config.problem_type == "regression":
858 | loss_fct = MSELoss()
859 | if self.num_labels == 1:
860 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
861 | else:
862 | loss = loss_fct(pooled_logits, labels)
863 | elif self.config.problem_type == "single_label_classification":
864 | loss_fct = CrossEntropyLoss()
865 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
866 | elif self.config.problem_type == "multi_label_classification":
867 | loss_fct = BCEWithLogitsLoss()
868 | loss = loss_fct(pooled_logits, labels)
869 | if not return_dict:
870 | output = (pooled_logits,) + transformer_outputs[1:]
871 | return ((loss,) + output) if loss is not None else output
872 |
873 | return SequenceClassifierOutputWithPast(
874 | loss=loss,
875 | logits=pooled_logits,
876 | past_key_values=transformer_outputs.past_key_values,
877 | hidden_states=transformer_outputs.hidden_states,
878 | attentions=transformer_outputs.attentions,
879 | )
880 |
--------------------------------------------------------------------------------
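A minimal smoke-test sketch for the modified LLaMA implementation above, using a deliberately tiny, hypothetical `LlamaConfig` so the forward pass runs on CPU; it assumes a `transformers` version compatible with this vendored file (it relies on `LlamaConfig` and `ACT2FN`). Note that `LlamaAttention.forward` always applies a causal mask through `scaled_dot_product_attention` and ignores the padding mask (see the WARNING comment in the code above).

```python
import torch
from transformers import LlamaConfig
from peft_pretraining.modeling_llama import LlamaForCausalLM

# Hypothetical tiny dimensions, chosen only to keep the example fast.
config = LlamaConfig(
    vocab_size=256,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=128,
)
model = LlamaForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = model(input_ids=input_ids, labels=input_ids)  # labels give next-token cross-entropy

print(outputs.logits.shape)  # torch.Size([2, 16, 256])
print(outputs.loss)          # scalar loss
```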
/peft_pretraining/training_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | from torch.optim.lr_scheduler import LambdaLR
6 | import transformers
7 |
8 |
9 | def get_scheculer(
10 | optimizer,
11 | *,
12 | scheduler_type,
13 | num_training_steps,
14 | warmup_steps,
15 | min_lr_ratio,
16 | cycle_length=None,
17 | restart_warmup_steps=None,
18 | adjust_step=0,
19 | last_epoch=-1,
20 | ):
21 | if adjust_step != 0 and scheduler_type != "cosine_restarts":
22 | raise ValueError("adjust_step is only supported for cosine_restarts scheduler")
23 |
24 | if scheduler_type == "linear":
25 | return transformers.get_linear_schedule_with_warmup(
26 | optimizer,
27 | num_warmup_steps=warmup_steps,
28 | num_training_steps=num_training_steps,
29 | last_epoch=last_epoch,
30 | )
31 | if scheduler_type == "cosine":
32 | return get_cyclical_cosine_schedule_with_min_lr(
33 | optimizer,
34 | num_warmup_steps=warmup_steps,
35 | num_training_steps=num_training_steps,
36 | cycle_length=cycle_length,
37 | min_lr_ratio=min_lr_ratio,
38 | last_epoch=last_epoch,
39 | )
40 | if scheduler_type == "cosine_restarts":
41 | assert restart_warmup_steps is not None, "restart_warmup_steps must be specified for cosine_restarts scheduler"
42 | return get_cosine_schedule_with_multiple_warmups(
43 | optimizer,
44 | num_training_steps=num_training_steps,
45 | first_warmup_steps=warmup_steps,
46 | restart_warmup_steps=restart_warmup_steps,
47 | restart_every=cycle_length,
48 | min_lr_ratio=min_lr_ratio,
49 | last_epoch=last_epoch,
50 | adjust_step=adjust_step,
51 | )
52 |
53 | raise NotImplementedError(f"Scheduler {scheduler_type} is not implemented")
54 |
55 |
56 | def get_cyclical_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, num_training_steps, cycle_length, min_lr_ratio=0.1, last_epoch=-1):
57 | assert cycle_length is not None or num_training_steps is not None, "You must specify either cycle_length or num_training_steps"
58 |
59 | if cycle_length is None:
60 | cycle_length = num_training_steps
61 |
62 | if num_training_steps % cycle_length != 0:
63 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by cycle_length ({cycle_length})")
64 |
65 | lr_lambda = partial(
66 | _get_cyclical_cosine_schedule_with_min_lr_lambda,
67 | num_warmup_steps=num_warmup_steps,
68 | cycle_length=cycle_length,
69 | min_lr_ratio=min_lr_ratio,
70 | )
71 | return LambdaLR(optimizer, lr_lambda, last_epoch)
72 |
73 |
74 | def get_cosine_schedule_with_multiple_warmups(
75 | optimizer,
76 | *,
77 | num_training_steps,
78 | first_warmup_steps,
79 | restart_warmup_steps,
80 | restart_every,
81 | min_lr_ratio=0.1,
82 | adjust_step=0,
83 | last_epoch=-1,
84 | ):
85 | if restart_every is None:
86 | raise ValueError("restart_every must be specified for cosine_restarts scheduler")
87 |
88 | if num_training_steps % restart_every != 0:
89 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by restart_every ({restart_every})")
90 |
91 | lr_lambda = partial(
92 | _get_cosine_schedule_with_multiple_warmups_lambda,
93 | num_training_steps=num_training_steps,
94 | first_warmup_steps=first_warmup_steps,
95 | restart_warmup_steps=restart_warmup_steps,
96 | restart_every=restart_every,
97 | min_lr_ratio=min_lr_ratio,
98 | adjust_step=adjust_step,
99 | )
100 | return LambdaLR(optimizer, lr_lambda, last_epoch)
101 |
102 |
103 | @torch.no_grad()
104 | def random_pruning(tensor, prune_ratio):
105 | """
106 | Performs random pruning dimensionality reduction.
107 | Only reduces the inner dimensionality, does not affect the shape of the tensor
108 | """
109 | random_pruning_mask = torch.rand_like(tensor) > prune_ratio
110 | tensor = tensor * random_pruning_mask
111 | return tensor
112 |
113 |
114 | @torch.no_grad()
115 | def magnitude_pruning(tensor, prune_ratio):
116 | """
117 | Performs magnitude pruning dimensionality reduction.
118 | Only reduces the inner dimensionality, does not affect the shape of the tensor
119 | """
120 | tensor_magnitude = torch.abs(tensor)
121 | threshold = torch.quantile(tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio).to(dtype=tensor.dtype)
122 |
123 | mask = tensor_magnitude > threshold
124 | tensor = tensor * mask.to(dtype=tensor.dtype)
125 | return tensor
126 |
127 |
128 | def _get_cyclical_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio):
129 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
130 |
131 | # compute where we are in the current cycle
132 | cycle_step = current_step % cycle_length
133 |
134 | if cycle_step < num_warmup_steps:
135 | if current_step != cycle_step:
136 | if cycle_step < 2:
137 | return 1e-7
138 | return float(cycle_step) / float(max(1, num_warmup_steps))
139 |
140 | progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps))
141 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
142 |
143 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
144 |
145 |
146 | def _get_cosine_schedule_with_multiple_warmups_lambda(
147 | current_step,
148 | *,
149 | num_training_steps,
150 | first_warmup_steps,
151 | restart_warmup_steps,
152 | restart_every,
153 | min_lr_ratio,
154 | adjust_step,
155 | ):
156 | """
157 | Args:
158 | adjust_step: useful when continuing training from a warmed-up checkpoint;
159 | it allows syncing the resets by reducing the number of steps
160 | after the first warmup and before the first reset.
161 | This way, ReLoRA resets can be kept in sync with the optimizer resets.
162 | """
163 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
164 | assert restart_every > 0, "restart_every must be positive"
165 | assert adjust_step + first_warmup_steps < num_training_steps, "warmup + adjust_step is more than full training steps"
166 | assert adjust_step + first_warmup_steps < restart_every, "the first reset will happen before the warmup is done"
167 |
168 | if current_step < first_warmup_steps:
169 | return float(current_step) / float(max(1, first_warmup_steps))
170 |
171 | _current_step = current_step + adjust_step
172 |
173 | restart_step = _current_step % restart_every
174 | restart_number = _current_step // restart_every
175 |
176 | if restart_step < restart_warmup_steps:
177 | # get the expected lr multiplier at the end of the warmup
178 | end_of_warmup_progress = (
179 | float(restart_number * restart_every) /
180 | float(max(1, num_training_steps - first_warmup_steps))
181 | )
182 |
183 | _cosine_decay = 0.5 * (1.0 + math.cos(math.pi * end_of_warmup_progress))
184 | warmup_lr_multiplier = min_lr_ratio + (1.0 - min_lr_ratio) * _cosine_decay
185 |
186 | return float(restart_step) / float(max(1, restart_warmup_steps)) * warmup_lr_multiplier
187 |
188 | progress = float(_current_step - first_warmup_steps) / float(max(1, num_training_steps - first_warmup_steps))
189 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
190 |
191 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
192 |
193 |
194 | def collate_fn(batch_list):
195 | batch = {
196 | "input_ids": torch.stack([torch.Tensor(example["input_ids"]).long() for example in batch_list]),
197 | "attention_mask": torch.stack([torch.Tensor(example["attention_mask"]).long() for example in batch_list]),
198 | }
199 | return batch
200 |
201 |
202 | def batch_fn(dataset, batch_size):
203 | batch = []
204 | for example in dataset:
205 | batch.append(example)
206 | if len(batch) == batch_size:
207 | batch = collate_fn(batch)
208 | yield batch
209 | batch = []
210 | if len(batch) > 0:
211 | yield batch
212 |
213 |
214 | def max_train_tokens_to_number(max_train_tokens):
215 | if max_train_tokens.endswith("M"):
216 | return int(max_train_tokens.rstrip("M")) * 1_000_000
217 | elif max_train_tokens.endswith("B"):
218 | return int(max_train_tokens.rstrip("B")) * 1_000_000_000
219 | else:
220 | return int(max_train_tokens)
221 |
--------------------------------------------------------------------------------
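A minimal sketch of the cyclical cosine schedule with a minimum learning-rate floor, built through the helper above (called `get_scheculer`, as spelled in the source); the optimizer, peak LR, and step counts below are hypothetical:

```python
import torch
from peft_pretraining.training_utils import get_scheculer  # name as spelled in the source

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)

scheduler = get_scheculer(
    optimizer,
    scheduler_type="cosine",
    num_training_steps=1_000,
    warmup_steps=100,
    min_lr_ratio=0.1,  # decay to 10% of the peak LR instead of to zero
)

lrs = []
for _ in range(1_000):
    lrs.append(scheduler.get_last_lr()[0])  # record the LR used at this step
    optimizer.step()
    scheduler.step()

print(max(lrs))  # ~1e-3, reached at the end of warmup
print(lrs[-1])   # ~1e-4, i.e. min_lr_ratio * peak
```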
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | bitsandbytes
4 | tensorly
5 |
--------------------------------------------------------------------------------
/run_glue.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
16 | import argparse
17 | import json
18 | import logging
19 | import math
20 | import os
21 | import random
22 | from pathlib import Path
23 |
24 | import datasets
25 | import evaluate
26 | import torch
27 | from accelerate import Accelerator
28 | from accelerate.logging import get_logger
29 | from accelerate.utils import set_seed
30 | from datasets import load_dataset
31 | from huggingface_hub import Repository, create_repo
32 | from torch.utils.data import DataLoader
33 | from tqdm.auto import tqdm
34 |
35 | import transformers
36 | from transformers import (
37 | AutoConfig,
38 | AutoModelForSequenceClassification,
39 | AutoTokenizer,
40 | DataCollatorWithPadding,
41 | PretrainedConfig,
42 | SchedulerType,
43 | default_data_collator,
44 | get_scheduler,
45 | LlamaForSequenceClassification
46 | )
47 | from transformers.utils import check_min_version, send_example_telemetry
48 | from transformers.utils.versions import require_version
49 |
50 | from galore_torch import GaLoreAdamW
51 |
52 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
53 | # check_min_version("4.38.0.dev0")
54 |
55 | logger = get_logger(__name__)
56 |
57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
58 |
59 | task_to_keys = {
60 | "cola": ("sentence", None),
61 | "mnli": ("premise", "hypothesis"),
62 | "mrpc": ("sentence1", "sentence2"),
63 | "qnli": ("question", "sentence"),
64 | "qqp": ("question1", "question2"),
65 | "rte": ("sentence1", "sentence2"),
66 | "sst2": ("sentence", None),
67 | "stsb": ("sentence1", "sentence2"),
68 | "wnli": ("sentence1", "sentence2"),
69 | }
70 |
71 |
72 | def parse_args():
73 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
74 |
75 | # LoRA hyperparameters
76 | parser.add_argument("--lora_r", type=int, default=8)
77 | parser.add_argument("--load_pretrained_model", type=str, default=None)
78 |
79 | parser.add_argument(
80 | "--task_name",
81 | type=str,
82 | default=None,
83 | help="The name of the glue task to train on.",
84 | choices=list(task_to_keys.keys()),
85 | )
86 | parser.add_argument(
87 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
88 | )
89 | parser.add_argument(
90 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
91 | )
92 | parser.add_argument(
93 | "--max_length",
94 | type=int,
95 | default=128,
96 | help=(
97 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
98 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
99 | ),
100 | )
101 | parser.add_argument(
102 | "--pad_to_max_length",
103 | action="store_true",
104 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
105 | )
106 | parser.add_argument(
107 | "--model_name_or_path",
108 | type=str,
109 | help="Path to pretrained model or model identifier from huggingface.co/models.",
110 | required=True,
111 | )
112 | parser.add_argument(
113 | "--use_slow_tokenizer",
114 | action="store_true",
115 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
116 | )
117 | parser.add_argument(
118 | "--per_device_train_batch_size",
119 | type=int,
120 | default=8,
121 | help="Batch size (per device) for the training dataloader.",
122 | )
123 | parser.add_argument(
124 | "--per_device_eval_batch_size",
125 | type=int,
126 | default=8,
127 | help="Batch size (per device) for the evaluation dataloader.",
128 | )
129 | parser.add_argument(
130 | "--learning_rate",
131 | type=float,
132 | default=5e-5,
133 | help="Initial learning rate (after the potential warmup period) to use.",
134 | )
135 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
136 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
137 | parser.add_argument(
138 | "--max_train_steps",
139 | type=int,
140 | default=None,
141 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
142 | )
143 | parser.add_argument(
144 | "--gradient_accumulation_steps",
145 | type=int,
146 | default=1,
147 | help="Number of updates steps to accumulate before performing a backward/update pass.",
148 | )
149 | parser.add_argument(
150 | "--lr_scheduler_type",
151 | type=SchedulerType,
152 | default="linear",
153 | help="The scheduler type to use.",
154 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
155 | )
156 | parser.add_argument(
157 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
158 | )
159 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
160 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
161 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
162 | parser.add_argument(
163 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
164 | )
165 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
166 | parser.add_argument(
167 | "--trust_remote_code",
168 | type=bool,
169 | default=False,
170 | help=(
171 |             "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
172 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
173 | "execute code present on the Hub on your local machine."
174 | ),
175 | )
176 | parser.add_argument(
177 | "--checkpointing_steps",
178 | type=str,
179 | default=None,
180 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
181 | )
182 | parser.add_argument(
183 | "--resume_from_checkpoint",
184 | type=str,
185 | default=None,
186 | help="If the training should continue from a checkpoint folder.",
187 | )
188 | parser.add_argument(
189 | "--with_tracking",
190 | action="store_true",
191 | help="Whether to enable experiment trackers for logging.",
192 | )
193 | parser.add_argument(
194 | "--report_to",
195 | type=str,
196 | default="all",
197 | help=(
198 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
199 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
200 | "Only applicable when `--with_tracking` is passed."
201 | ),
202 | )
203 | parser.add_argument(
204 | "--ignore_mismatched_sizes",
205 | action="store_true",
206 |         help="Whether or not to allow loading a pretrained model whose head dimensions are different.",
207 | )
208 |
209 |     # GaLore / low-rank fine-tuning options
210 |     parser.add_argument("--enable_galore", action="store_true", help="Whether or not to use the low-rank (GaLore) optimizer.")
211 |     # number of optimizer steps between updates of the gradient projection subspace
212 |     parser.add_argument("--update_proj_gap", type=int, default=50)
213 |     # scale factor applied to the low-rank update
214 |     parser.add_argument("--galore_scale", type=float, default=1.0)
215 |     # projection type ("std" by default)
216 |     parser.add_argument("--proj_type", type=str, default="std")
217 |     # apply LoRA/GaLore to all linear modules instead of only the query/value projections
218 |     parser.add_argument("--lora_all_modules", action="store_true", help="Whether or not to use LoRA for all modules.")
219 |     # fine-tune/evaluate a LLaMA model (LlamaForSequenceClassification) instead of an AutoModel
220 |     parser.add_argument("--eval_llama", action="store_true", help="Whether or not to evaluate a LLaMA model.")
221 |     # low-rank method name (used for wandb sweeps)
222 |     parser.add_argument("--low_rank_method", type=str, default=None, help="Low-rank method for wandb sweeps.")
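    # A hypothetical invocation of this script with GaLore enabled (the model, task,
    # and hyperparameter values below are illustrative, not tuned recommendations):
    #
    #   python run_glue.py \
    #       --model_name_or_path roberta-base \
    #       --task_name mrpc \
    #       --enable_galore \
    #       --lora_all_modules \
    #       --lora_r 8 \
    #       --galore_scale 2.0 \
    #       --update_proj_gap 500 \
    #       --learning_rate 1e-5 \
    #       --num_train_epochs 3 \
    #       --output_dir results/roberta_base_mrpc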
223 |
224 | args = parser.parse_args()
225 |
226 | # Sanity checks
227 | if args.task_name is None and args.train_file is None and args.validation_file is None:
228 | raise ValueError("Need either a task name or a training/validation file.")
229 | else:
230 | if args.train_file is not None:
231 | extension = args.train_file.split(".")[-1]
232 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
233 | if args.validation_file is not None:
234 | extension = args.validation_file.split(".")[-1]
235 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
236 |
237 | if args.push_to_hub:
238 | assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
239 |
240 | return args
241 |
242 |
243 | def main():
244 | args = parse_args()
245 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
246 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
247 | send_example_telemetry("run_glue_no_trainer", args)
248 |
249 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
250 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
251 | # in the environment
252 | accelerator = (
253 | Accelerator(log_with=args.report_to, project_dir=args.output_dir) if args.with_tracking else Accelerator()
254 | )
255 | # Make one log on every process with the configuration for debugging.
256 | logging.basicConfig(
257 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
258 | datefmt="%m/%d/%Y %H:%M:%S",
259 | level=logging.INFO,
260 | )
261 | logger.info(accelerator.state, main_process_only=False)
262 | if accelerator.is_local_main_process:
263 | datasets.utils.logging.set_verbosity_warning()
264 | transformers.utils.logging.set_verbosity_info()
265 | else:
266 | datasets.utils.logging.set_verbosity_error()
267 | transformers.utils.logging.set_verbosity_error()
268 |
269 | # If passed along, set the training seed now.
270 | if args.seed is not None:
271 | set_seed(args.seed)
272 |
273 | # Handle the repository creation
274 | if accelerator.is_main_process:
275 | if args.push_to_hub:
276 |             # Retrieve or infer repo_name
277 | repo_name = args.hub_model_id
278 | if repo_name is None:
279 | repo_name = Path(args.output_dir).absolute().name
280 | # Create repo and retrieve repo_id
281 | repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
282 | # Clone repo locally
283 | repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
284 |
285 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
286 | if "step_*" not in gitignore:
287 | gitignore.write("step_*\n")
288 | if "epoch_*" not in gitignore:
289 | gitignore.write("epoch_*\n")
290 | elif args.output_dir is not None:
291 | os.makedirs(args.output_dir, exist_ok=True)
292 | accelerator.wait_for_everyone()
293 |
294 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
295 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
296 |
297 |     # For CSV/JSON files, this script uses the column called 'label' as the labels, and as the sentence pair the
298 |     # columns called 'sentence1' and 'sentence2' if they exist, or otherwise the first two columns not named
299 |     # 'label' (provided at least two such columns exist).
300 |
301 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
302 | # single column. You can easily tweak this behavior (see below)
303 |
304 |     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
305 | # download the dataset.
306 | if args.task_name is not None:
307 | # Downloading and loading a dataset from the hub.
308 | raw_datasets = load_dataset("glue", args.task_name)
309 | else:
310 | # Loading the dataset from local csv or json file.
311 | data_files = {}
312 | if args.train_file is not None:
313 | data_files["train"] = args.train_file
314 | if args.validation_file is not None:
315 | data_files["validation"] = args.validation_file
316 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
317 | raw_datasets = load_dataset(extension, data_files=data_files)
318 | # See more about loading any type of standard or custom dataset at
319 | # https://huggingface.co/docs/datasets/loading_datasets.
320 |
321 | # Labels
322 | if args.task_name is not None:
323 | is_regression = args.task_name == "stsb"
324 | if not is_regression:
325 | label_list = raw_datasets["train"].features["label"].names
326 | num_labels = len(label_list)
327 | else:
328 | num_labels = 1
329 | else:
330 | # Trying to have good defaults here, don't hesitate to tweak to your needs.
331 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
332 | if is_regression:
333 | num_labels = 1
334 | else:
335 | # A useful fast method:
336 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
337 | label_list = raw_datasets["train"].unique("label")
338 | label_list.sort() # Let's sort it for determinism
339 | num_labels = len(label_list)
340 |
341 | # Load pretrained model and tokenizer
342 | #
343 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
344 | # download model & vocab.
345 |
346 | if not args.eval_llama:
347 | config = AutoConfig.from_pretrained(
348 | args.model_name_or_path,
349 | num_labels=num_labels,
350 | finetuning_task=args.task_name,
351 | trust_remote_code=args.trust_remote_code,
352 | )
353 | tokenizer = AutoTokenizer.from_pretrained(
354 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
355 | )
356 | model = AutoModelForSequenceClassification.from_pretrained(
357 | args.model_name_or_path,
358 | from_tf=bool(".ckpt" in args.model_name_or_path),
359 | config=config,
360 | ignore_mismatched_sizes=args.ignore_mismatched_sizes,
361 | trust_remote_code=args.trust_remote_code,
362 | )
363 | else:
364 | config = AutoConfig.from_pretrained(args.model_name_or_path)
365 | setattr(config, 'num_labels', num_labels)
366 | setattr(config, 'finetuning_task', args.task_name)
367 | tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length)
368 | tokenizer.padding_side = "left"
369 | model = LlamaForSequenceClassification(
370 | config
371 | )
372 | ## load pretrained model
373 | if args.load_pretrained_model:
374 | logger.info("*" * 40)
375 | logger.info(f"Loading model from {args.load_pretrained_model}")
376 | checkpoint_path = os.path.join(args.load_pretrained_model, "pytorch_model.bin")
377 | checkpoint = torch.load(checkpoint_path, map_location="cpu")
378 | for key in checkpoint.keys():
379 | if key not in model.state_dict().keys():
380 | print(f"key {key} not in model state dict")
381 |
382 | for key in model.state_dict().keys():
383 | if key not in checkpoint.keys():
384 | print(f"key {key} not in checkpoint")
385 | model.load_state_dict(checkpoint, strict=False)
386 | logger.info(f"Model successfully loaded (strict=False policy)")
387 | logger.info("*" * 40)
388 |
389 |     # modules whose weights will be projected (GaLore / LoRA target modules)
390 | if not args.lora_all_modules:
391 | target_modules_list = ["q_proj", "v_proj"]
392 | else:
393 | print('Enabling LoRA for all modules')
394 | target_modules_list = ["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"]
395 |
396 |     # different target modules for BERT-family models
397 | if 'bert' in args.model_name_or_path:
398 | if not args.lora_all_modules:
399 | target_modules_list = ["query"]
400 | else:
401 | print('Enabling LoRA for all modules')
402 | target_modules_list = ["query", "value", "key", "intermediate.dense", "output.dense"]
403 |
404 | # Preprocessing the datasets
405 | if args.task_name is not None:
406 | sentence1_key, sentence2_key = task_to_keys[args.task_name]
407 | else:
408 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
409 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
410 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
411 | sentence1_key, sentence2_key = "sentence1", "sentence2"
412 | else:
413 | if len(non_label_column_names) >= 2:
414 | sentence1_key, sentence2_key = non_label_column_names[:2]
415 | else:
416 | sentence1_key, sentence2_key = non_label_column_names[0], None
417 |
418 | # Some models have set the order of the labels to use, so let's make sure we do use it.
419 | label_to_id = None
420 | if (
421 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
422 | and args.task_name is not None
423 | and not is_regression
424 | ):
425 | # Some have all caps in their config, some don't.
426 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
427 | if sorted(label_name_to_id.keys()) == sorted(label_list):
428 | logger.info(
429 | f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
430 | "Using it!"
431 | )
432 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
433 | else:
434 | logger.warning(
435 |                     "Your model seems to have been trained with labels, but they don't match the dataset: "
436 | f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
437 | "\nIgnoring the model labels as a result.",
438 | )
439 | elif args.task_name is None and not is_regression:
440 | label_to_id = {v: i for i, v in enumerate(label_list)}
441 |
442 | if label_to_id is not None:
443 | model.config.label2id = label_to_id
444 | model.config.id2label = {id: label for label, id in config.label2id.items()}
445 | elif args.task_name is not None and not is_regression:
446 | model.config.label2id = {l: i for i, l in enumerate(label_list)}
447 | model.config.id2label = {id: label for label, id in config.label2id.items()}
448 |
449 | padding = "max_length" if args.pad_to_max_length else False
450 |
451 | def preprocess_function(examples):
452 | # Tokenize the texts
453 | texts = (
454 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
455 | )
456 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True)
457 |
458 | if "label" in examples:
459 | if label_to_id is not None:
460 | # Map labels to IDs (not necessary for GLUE tasks)
461 | result["labels"] = [label_to_id[l] for l in examples["label"]]
462 | else:
463 | # In all cases, rename the column to labels because the model will expect that.
464 | result["labels"] = examples["label"]
465 | return result
466 |
467 | with accelerator.main_process_first():
468 | processed_datasets = raw_datasets.map(
469 | preprocess_function,
470 | batched=True,
471 | remove_columns=raw_datasets["train"].column_names,
472 | desc="Running tokenizer on dataset",
473 | )
474 |
475 | train_dataset = processed_datasets["train"]
476 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
477 |
478 | # Log a few random samples from the training set:
479 | for index in random.sample(range(len(train_dataset)), 3):
480 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
481 |
482 | # DataLoaders creation:
483 | if args.pad_to_max_length:
484 |         # If padding was already done to max length, we use the default data collator that will just convert everything
485 | # to tensors.
486 | data_collator = default_data_collator
487 | else:
488 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
489 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
490 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
491 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))
492 |
493 | train_dataloader = DataLoader(
494 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
495 | )
496 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
497 |
498 | # Optimizer
499 | # Split weights in two groups, one with weight decay and the other not.
500 | no_decay = ["bias", "LayerNorm.weight"]
501 | optimizer_grouped_parameters = [
502 | {
503 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
504 | "weight_decay": args.weight_decay,
505 | },
506 | {
507 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
508 | "weight_decay": 0.0,
509 | },
510 | ]
511 |
512 | if not args.enable_galore:
513 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
514 | else:
515 | from torch import nn
516 |         # select the layers whose weights GaLore will project
517 | # target_modules_list = ["attn", "mlp"]
518 | # target_modules_list = ["q_proj", "v_proj"]
519 | galore_params = []
520 | for module_name, module in model.named_modules():
521 | if not isinstance(module, nn.Linear):
522 | continue
523 |
524 | if not any(target_key in module_name for target_key in target_modules_list):
525 | continue
526 |
527 | print('enable GaLore for weights in module: ', module_name)
528 | galore_params.append(module.weight)
529 |
530 | id_galore_params = [id(p) for p in galore_params]
531 |         # put every parameter that GaLore does not handle into a separate, regular group
532 |         regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
533 |         # then pass both groups to GaLoreAdamW
534 | param_groups = [{'params': regular_params},
535 | {'params': galore_params, 'rank': args.lora_r, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}]
536 | optimizer = GaLoreAdamW(param_groups, lr=args.learning_rate)
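        # In the GaLore group above, `rank` (reusing --lora_r) is the dimension of the
        # gradient subspace, `update_proj_gap` is the number of optimizer steps between
        # re-computations of the projection matrix, `scale` rescales the low-rank update,
        # and `proj_type` selects the projection variant ("std" by default).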
537 |
538 |
539 | # Scheduler and math around the number of training steps.
540 | overrode_max_train_steps = False
541 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
542 | if args.max_train_steps is None:
543 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
544 | overrode_max_train_steps = True
545 |
546 | lr_scheduler = get_scheduler(
547 | name=args.lr_scheduler_type,
548 | optimizer=optimizer,
549 | num_warmup_steps=args.num_warmup_steps,
550 | num_training_steps=args.max_train_steps,
551 | )
552 |
553 | # Prepare everything with our `accelerator`.
554 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
555 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
556 | )
557 |
558 | # We need to recalculate our total training steps as the size of the training dataloader may have changed
559 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
560 | if overrode_max_train_steps:
561 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
562 | # Afterwards we recalculate our number of training epochs
563 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
564 |
565 | # Figure out how many steps we should save the Accelerator states
566 | checkpointing_steps = args.checkpointing_steps
567 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
568 | checkpointing_steps = int(checkpointing_steps)
569 |
570 | # We need to initialize the trackers we use, and also store our configuration.
571 |     # The trackers initialize automatically on the main process.
572 | if args.with_tracking:
573 | experiment_config = vars(args)
574 | # TensorBoard cannot log Enums, need the raw value
575 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
576 | accelerator.init_trackers("glue_no_trainer", experiment_config)
577 |
578 | # Get the metric function
579 | if args.task_name is not None:
580 | metric = evaluate.load("glue", args.task_name)
581 | else:
582 | metric = evaluate.load("accuracy")
583 |
584 | # Train!
585 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
586 |
587 | logger.info("***** Running training *****")
588 | logger.info(f" Num examples = {len(train_dataset)}")
589 | logger.info(f" Num Epochs = {args.num_train_epochs}")
590 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
591 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
592 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
593 | logger.info(f" Total optimization steps = {args.max_train_steps}")
594 | # Only show the progress bar once on each machine.
595 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
596 | completed_steps = 0
597 | starting_epoch = 0
598 | # Potentially load in the weights and states from a previous save
599 | if args.resume_from_checkpoint:
600 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
601 | checkpoint_path = args.resume_from_checkpoint
602 | path = os.path.basename(args.resume_from_checkpoint)
603 | else:
604 | # Get the most recent checkpoint
605 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
606 | dirs.sort(key=os.path.getctime)
607 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
608 | checkpoint_path = path
609 | path = os.path.basename(checkpoint_path)
610 |
611 | accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
612 | accelerator.load_state(checkpoint_path)
613 | # Extract `epoch_{i}` or `step_{i}`
614 | training_difference = os.path.splitext(path)[0]
615 |
616 | if "epoch" in training_difference:
617 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
618 | resume_step = None
619 | completed_steps = starting_epoch * num_update_steps_per_epoch
620 | else:
621 | # need to multiply `gradient_accumulation_steps` to reflect real steps
622 | resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
623 | starting_epoch = resume_step // len(train_dataloader)
624 | completed_steps = resume_step // args.gradient_accumulation_steps
625 | resume_step -= starting_epoch * len(train_dataloader)
626 |
627 | # update the progress_bar if load from checkpoint
628 | progress_bar.update(completed_steps)
629 |
630 | for epoch in range(starting_epoch, args.num_train_epochs):
631 | model.train()
632 | if args.with_tracking:
633 | total_loss = 0
634 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
635 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint
636 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
637 | else:
638 | active_dataloader = train_dataloader
639 | for step, batch in enumerate(active_dataloader):
640 |
641 | outputs = model(**batch)
642 | loss = outputs.loss
643 | # We keep track of the loss at each epoch
644 | if args.with_tracking:
645 | total_loss += loss.detach().float()
646 | loss = loss / args.gradient_accumulation_steps
647 | accelerator.backward(loss)
648 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
649 | optimizer.step()
650 | lr_scheduler.step()
651 | optimizer.zero_grad()
652 | progress_bar.update(1)
653 | completed_steps += 1
654 |
655 | if isinstance(checkpointing_steps, int):
656 | if completed_steps % checkpointing_steps == 0:
657 | output_dir = f"step_{completed_steps}"
658 | if args.output_dir is not None:
659 | output_dir = os.path.join(args.output_dir, output_dir)
660 | accelerator.save_state(output_dir)
661 |
662 | if completed_steps >= args.max_train_steps:
663 | break
664 |
665 | model.eval()
666 | samples_seen = 0
667 | for step, batch in enumerate(eval_dataloader):
668 | with torch.no_grad():
669 | outputs = model(**batch)
670 | predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
671 | predictions, references = accelerator.gather((predictions, batch["labels"]))
672 | # If we are in a multiprocess environment, the last batch has duplicates
673 | if accelerator.num_processes > 1:
674 | if step == len(eval_dataloader) - 1:
675 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
676 | references = references[: len(eval_dataloader.dataset) - samples_seen]
677 | else:
678 | samples_seen += references.shape[0]
679 | metric.add_batch(
680 | predictions=predictions,
681 | references=references,
682 | )
683 |
684 | eval_metric = metric.compute()
685 | logger.info(f"epoch {epoch}: {eval_metric}")
686 |
687 | if args.with_tracking:
688 | accelerator.log(
689 | {
690 | "accuracy" if args.task_name is not None else "glue": eval_metric,
691 | "train_loss": total_loss.item() / len(train_dataloader),
692 | "epoch": epoch,
693 | "step": completed_steps,
694 | },
695 | step=completed_steps,
696 | )
697 |
698 | if args.push_to_hub and epoch < args.num_train_epochs - 1:
699 | accelerator.wait_for_everyone()
700 | unwrapped_model = accelerator.unwrap_model(model)
701 | unwrapped_model.save_pretrained(
702 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
703 | )
704 | if accelerator.is_main_process:
705 | tokenizer.save_pretrained(args.output_dir)
706 | repo.push_to_hub(
707 | commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
708 | )
709 |
710 | if args.checkpointing_steps == "epoch":
711 | output_dir = f"epoch_{epoch}"
712 | if args.output_dir is not None:
713 | output_dir = os.path.join(args.output_dir, output_dir)
714 | accelerator.save_state(output_dir)
715 |
716 | if args.with_tracking:
717 | accelerator.end_training()
718 |
719 | if args.output_dir is not None:
720 | accelerator.wait_for_everyone()
721 | unwrapped_model = accelerator.unwrap_model(model)
722 | unwrapped_model.save_pretrained(
723 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
724 | )
725 | if accelerator.is_main_process:
726 | tokenizer.save_pretrained(args.output_dir)
727 | if args.push_to_hub:
728 | repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
729 |
730 | if args.task_name == "mnli":
731 | # Final evaluation on mismatched validation set
732 | eval_dataset = processed_datasets["validation_mismatched"]
733 | eval_dataloader = DataLoader(
734 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
735 | )
736 | eval_dataloader = accelerator.prepare(eval_dataloader)
737 |
738 | model.eval()
739 | for step, batch in enumerate(eval_dataloader):
740 | outputs = model(**batch)
741 | predictions = outputs.logits.argmax(dim=-1)
742 | metric.add_batch(
743 | predictions=accelerator.gather(predictions),
744 | references=accelerator.gather(batch["labels"]),
745 | )
746 |
747 | eval_metric = metric.compute()
748 | logger.info(f"mnli-mm: {eval_metric}")
749 |
750 | if args.output_dir is not None:
751 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
752 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
753 | json.dump(all_results, f)
754 |
755 |
756 | if __name__ == "__main__":
757 | main()
--------------------------------------------------------------------------------
/scripts/benchmark_c4/llama_130m.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-130M, GaLore-Adam, 1 A100, 1 Node
2 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \
3 | --model_config configs/llama_130m.json \
4 | --lr 0.01 \
5 | --galore_scale 0.25 \
6 | --rank 256 \
7 | --update_proj_gap 200 \
8 | --batch_size 256 \
9 | --total_batch_size 512 \
10 | --num_training_steps 20000 \
11 | --warmup_steps 2000 \
12 | --weight_decay 0 \
13 | --dtype bfloat16 \
14 | --eval_every 1000 \
15 | --optimizer galore_adamw
--------------------------------------------------------------------------------
/scripts/benchmark_c4/llama_1b.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-1B, GaLore-Adam, 8 A100, 1 Node
2 | torchrun --standalone --nproc_per_node 8 torchrun_main.py \
3 | --model_config configs/llama_1b.json \
4 | --lr 0.01 \
5 | --galore_scale 0.25 \
6 | --rank 1024 \
7 | --update_proj_gap 200 \
8 | --batch_size 16 \
9 | --total_batch_size 512 \
10 | --num_training_steps 100000 \
11 | --warmup_steps 10000 \
12 | --weight_decay 0 \
13 | --dtype bfloat16 \
14 | --eval_every 1000 \
15 | --optimizer galore_adamw
--------------------------------------------------------------------------------
/scripts/benchmark_c4/llama_350m.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-350M, GaLore-Adam, 4 A100, 1 Node
2 | torchrun --standalone --nproc_per_node 4 torchrun_main.py \
3 | --model_config configs/llama_350m.json \
4 | --lr 0.01 \
5 | --galore_scale 0.25 \
6 | --rank 256 \
7 | --update_proj_gap 200 \
8 | --batch_size 128 \
9 | --total_batch_size 512 \
10 | --num_training_steps 60000 \
11 | --warmup_steps 6000 \
12 | --weight_decay 0 \
13 | --dtype bfloat16 \
14 | --eval_every 1000 \
15 | --optimizer galore_adamw
--------------------------------------------------------------------------------
/scripts/benchmark_c4/llama_60m.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
2 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \
3 | --model_config configs/llama_60m.json \
4 | --lr 0.01 \
5 | --galore_scale 0.25 \
6 | --rank 128 \
7 | --update_proj_gap 200 \
8 | --batch_size 256 \
9 | --total_batch_size 512 \
10 | --num_training_steps 10000 \
11 | --warmup_steps 1000 \
12 | --weight_decay 0 \
13 | --dtype bfloat16 \
14 | --eval_every 1000 \
15 | --optimizer galore_adamw
--------------------------------------------------------------------------------
/scripts/benchmark_c4/llama_7b.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-7B, GaLore-Adam, 8 A100, 8 Node
2 | torchrun --standalone --nnodes 8 --nproc_per_node 8 torchrun_main.py \
3 | --model_config configs/llama_7b.json \
4 | --lr 0.005 \
5 | --galore_scale 0.25 \
6 | --rank 1024 \
7 | --update_proj_gap 500 \
8 | --batch_size 8 \
9 | --total_batch_size 512 \
10 | --num_training_steps 150000 \
11 | --warmup_steps 15000 \
12 | --weight_decay 0 \
13 | --grad_clipping 1.0 \
14 | --dtype bfloat16 \
15 | --eval_every 1000 \
16 | --optimizer galore_adamw
--------------------------------------------------------------------------------
/scripts/single_gpu/llama_7b.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-7B, 8-bit GaLore-Adam, single GPU
2 | # 22.72G, 0.37s/it
3 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \
4 | --model_config configs/llama_7b.json \
5 | --lr 0.005 \
6 | --galore_scale 0.25 \
7 | --rank 1024 \
8 | --update_proj_gap 500 \
9 | --batch_size 1 \
10 | --total_batch_size 512 \
11 | --num_training_steps 150000 \
12 | --warmup_steps 15000 \
13 | --weight_decay 0 \
14 | --grad_clipping 1.0 \
15 | --dtype bfloat16 \
16 | --eval_every 1000 \
17 | --single_gpu \
18 | --optimizer galore_adamw8bit_per_layer
--------------------------------------------------------------------------------
/scripts/single_gpu/llama_7b_checkpointing.sh:
--------------------------------------------------------------------------------
1 | # LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
2 | # bsz=16, 22.8G
3 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \
4 | --model_config configs/llama_7b.json \
5 | --lr 0.005 \
6 | --galore_scale 0.25 \
7 | --rank 1024 \
8 | --update_proj_gap 500 \
9 | --batch_size 16 \
10 | --total_batch_size 512 \
11 | --activation_checkpointing \
12 | --num_training_steps 150000 \
13 | --warmup_steps 15000 \
14 | --weight_decay 0 \
15 | --grad_clipping 1.0 \
16 | --dtype bfloat16 \
17 | --eval_every 1000 \
18 | --single_gpu \
19 | --optimizer galore_adamw8bit_per_layer
--------------------------------------------------------------------------------
/scripts/tensor_test/neural_operator.py:
--------------------------------------------------------------------------------
1 | """
2 | Training a neural operator on Darcy-Flow - Author Robert Joseph
3 | ========================================
4 | In this example, we demonstrate how to train on the small Darcy-Flow dataset shipped with the package, combining incremental FNO and incremental resolution with GaLore tensor decomposition.
5 | 
6 | This script assumes the neuraloperator library is installed; instructions can be found at https://github.com/NeuralOperator/neuraloperator
7 | """
8 |
9 | # %%
10 | #
11 | import torch
12 | import matplotlib.pyplot as plt
13 | import sys
14 | from neuralop.training.callbacks import BasicLoggerCallback
15 | from neuralop.models import FNO
16 | from neuralop import Trainer
17 | from neuralop.datasets import load_darcy_flow_small
18 | from neuralop.utils import count_model_params
19 | from neuralop.training.callbacks import IncrementalCallback
20 | from neuralop.datasets import data_transforms
21 | from neuralop import LpLoss, H1Loss
22 | from neuralop.training import AdamW
24 |
25 |
26 | # %%
27 | # Loading the Darcy flow dataset
28 | train_loader, test_loaders, data_processor = load_darcy_flow_small(
29 | n_train=1000, batch_size=32,
30 | test_resolutions=[16, 32], n_tests=[100, 50],
31 | test_batch_sizes=[32, 32],
32 | positional_encoding=True
33 | )
34 | # %%
35 | # Choose device
36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 |
38 | # %%
39 | # Set up the incremental FNO model
40 | # We start with 10 modes in each dimension (out of a maximum of 20)
41 | # and let the incremental callback grow them (here via the loss-gap criterion)
42 |
43 | starting_modes = (10, 10)
44 | incremental = False
45 |
46 | model = FNO(
47 | max_n_modes=(20, 20),
48 | n_modes=starting_modes,
49 | hidden_channels=64,
50 | in_channels=1,
51 | out_channels=1,
52 | n_layers=4
53 | )
54 | callbacks = [
55 | IncrementalCallback(
56 | incremental_loss_gap=True,
57 | incremental_grad=False,
58 | incremental_grad_eps=0.9999,
59 | incremental_buffer=5,
60 | incremental_max_iter=1,
61 | incremental_grad_max_iter=2,
62 | )
63 | ]
64 | model = model.to(device)
65 | n_params = count_model_params(model)
66 | galore_params = []
67 | galore_params.extend(list(model.fno_blocks.convs.parameters()))
68 | print(galore_params[0].shape, galore_params[1].shape, galore_params[2].shape, galore_params[3].shape)
69 | galore_params.pop(0)
70 | id_galore_params = [id(p) for p in galore_params]
71 | # make parameters without "rank" to another group
72 | regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
73 | # then call galore_adamw
74 | # In this case we have a 5-D tensor representing the weights in the spectral layers of the FNO.
75 | # A good rule of thumb for tensor decomposition is to limit the rank (as a fraction of the full rank) to at most 0.75, and to increase the number of epochs and tune the learning rate accordingly compared to the baseline.
76 | # Low-rank decomposition takes longer to converge, but it is more memory efficient.
77 | param_groups = [{'params': regular_params},
78 | {'params': galore_params, 'rank': 0.2 , 'update_proj_gap': 10, 'scale': 0.25, 'proj_type': "std", 'dim': 5}]
79 | optimizer = AdamW(param_groups, lr=0.01)
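# Note: unlike the LLM scripts, `rank` is given here as a fraction of the full rank (0.2)
# rather than an integer, and 'dim': 5 marks these parameters as 5-D spectral weight
# tensors, which are presumably handled by the tensor variant of the GaLore projector
# rather than the 2-D matrix one.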
80 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
81 | data_transform = data_transforms.IncrementalDataProcessor(
82 | in_normalizer=None,
83 | out_normalizer=None,
84 | positional_encoding=None,
85 | device=device,
86 | dataset_sublist=[2, 1],
87 | dataset_resolution=16,
88 | dataset_indices=[2, 3],
89 | epoch_gap=10,
90 | verbose=True,
91 | )
92 |
93 | data_transform = data_transform.to(device)
94 | # %%
95 | # Set up the losses
96 | l2loss = LpLoss(d=2, p=2)
97 | h1loss = H1Loss(d=2)
98 | train_loss = h1loss
99 | eval_losses = {"h1": h1loss, "l2": l2loss}
100 | print("\n### OPTIMIZER rank ###\n", optimizer)
101 | sys.stdout.flush()
102 |
103 | # Finally pass all of these to the Trainer
104 | trainer = Trainer(
105 | model=model,
106 | n_epochs=100,
107 | data_processor=data_transform,
108 | callbacks=callbacks,
109 | device=device,
110 | verbose=True,
111 | )
112 |
113 | # %%
114 | # Train the model
115 | trainer.train(
116 | train_loader,
117 | test_loaders,
118 | optimizer,
119 | scheduler,
120 | regularizer=False,
121 | training_loss=train_loss,
122 | eval_losses=eval_losses,
123 | )
124 |
125 | # %%
126 | # Plot the prediction, and compare with the ground-truth
127 | # Note that we trained on a very small resolution for
128 | # a very small number of epochs
129 | # In practice, we would train at larger resolution, on many more samples.
130 | #
131 | # However, for practicality, we created a minimal example that
132 | # i) fits in just a few MB of memory
133 | # ii) can be trained quickly on CPU
134 | #
135 | # In practice we would train a Neural Operator on one or multiple GPUs
136 |
137 | test_samples = test_loaders[32].dataset
138 |
139 | fig = plt.figure(figsize=(7, 7))
140 | for index in range(3):
141 | data = test_samples[index]
142 | # Input x
143 | x = data["x"].to(device)
144 | # Ground-truth
145 | y = data["y"].to(device)
146 | # Model prediction
147 | out = model(x.unsqueeze(0))
148 | ax = fig.add_subplot(3, 3, index * 3 + 1)
149 | x = x.cpu().squeeze().detach().numpy()
150 | y = y.cpu().squeeze().detach().numpy()
151 | ax.imshow(x, cmap="gray")
152 | if index == 0:
153 | ax.set_title("Input x")
154 | plt.xticks([], [])
155 | plt.yticks([], [])
156 |
157 | ax = fig.add_subplot(3, 3, index * 3 + 2)
158 | ax.imshow(y.squeeze())
159 | if index == 0:
160 | ax.set_title("Ground-truth y")
161 | plt.xticks([], [])
162 | plt.yticks([], [])
163 |
164 | ax = fig.add_subplot(3, 3, index * 3 + 3)
165 | ax.imshow(out.cpu().squeeze().detach().numpy())
166 | if index == 0:
167 | ax.set_title("Model prediction")
168 | plt.xticks([], [])
169 | plt.yticks([], [])
170 |
171 | fig.suptitle("Inputs, ground-truth output and prediction.", y=0.98)
172 | plt.tight_layout()
173 | fig.show()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | with open("requirements.txt") as f:
4 | required = f.read().splitlines()
5 |
6 | setup(
7 | name="galore-torch",
8 | version="1.0",
9 | description="GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection",
10 | url="https://github.com/jiaweizzhao/GaLore",
11 | author="Jiawei Zhao",
12 | author_email="jiawei@caltech.edu",
13 | license="Apache 2.0",
14 | packages=["galore_torch"],
15 | install_requires=required,
16 | )
--------------------------------------------------------------------------------
/torchrun_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | import random
5 | import argparse
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.data
11 | import torch.distributed as dist
12 |
13 | import transformers
14 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
15 | from transformers import LlamaForCausalLM as HF_LlamaForCausalLM
16 |
17 | import datasets
18 | import datasets.distributed
19 | import wandb
20 |
21 | from tqdm import tqdm
22 | from loguru import logger
23 |
24 | from peft_pretraining import training_utils, args_utils
25 | from peft_pretraining.dataloader import PreprocessedIterableDataset
26 | from peft_pretraining.modeling_llama import LlamaForCausalLM
27 |
28 | import bitsandbytes as bnb
29 | from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
30 |
31 | transformers.logging.set_verbosity_error()
32 |
33 | def parse_args(args):
34 | parser = argparse.ArgumentParser()
35 |
36 | parser.add_argument("--model_config", type=str, required=True)
37 | parser.add_argument("--use_hf_model", default=False, action="store_true")
38 | parser.add_argument("--continue_from", type=str, default=None)
39 | parser.add_argument("--batch_size", type=int, required=True)
40 | parser.add_argument("--gradient_accumulation", type=int, default=None)
41 | parser.add_argument("--total_batch_size", type=int, default=None)
42 | parser.add_argument("--max_length", type=int, default=256)
43 | parser.add_argument("--optimizer", default="Adam")
44 | parser.add_argument("--lr", type=float, default=1e-4)
45 | parser.add_argument("--scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_restarts"])
46 | parser.add_argument("--min_lr_ratio", type=float, default=0.1)
47 | parser.add_argument("--activation_checkpointing", action="store_true")
48 | parser.add_argument("--weight_decay", type=float, default=0.0)
49 | parser.add_argument("--warmup_steps", type=int, default=1_000)
50 | parser.add_argument("--eval_every", type=int, default=5_000)
51 | parser.add_argument("--num_training_steps", type=int, default=10_000,
52 | help="Number of **update steps** to train for. "
53 | "Notice that gradient accumulation is taken into account.")
54 | parser.add_argument("--max_train_tokens", type=training_utils.max_train_tokens_to_number, default=None,
55 | help="Number of tokens to train on. Overwrites num_training_steps. "
56 | "You can use M and B suffixes, e.g. 100M or 1B.")
57 | parser.add_argument("--save_every", type=int, default=10_000)
58 | parser.add_argument("--save_dir", type=str, default=None)
59 | parser.add_argument("--tags", type=str, default=None)
60 | parser.add_argument("--dtype", type=str, default="bfloat16" if torch.cuda.is_bf16_supported() else "float32")
61 | parser.add_argument("--workers", type=int, default=8)
62 | parser.add_argument("--seed", type=int, default=0)
63 | parser.add_argument("--name", type=str, default="test")
64 | parser.add_argument("--grad_clipping", type=float, default=0.0)
65 | # beta1 for adafactor
66 | parser.add_argument("--beta1", type=float, default=0.0)
67 |
68 | # GaLore parameters
69 | parser.add_argument("--rank", type=int, default=128)
70 | parser.add_argument("--update_proj_gap", type=int, default=50)
71 | parser.add_argument("--galore_scale", type=float, default=1.0)
72 | parser.add_argument("--proj_type", type=str, default="std")
73 |
74 | # disable ddp, single_gpu
75 | parser.add_argument("--single_gpu", default=False, action="store_true")
76 |
77 | args = parser.parse_args(args)
78 |
79 | args = args_utils.check_args_torchrun_main(args)
80 | return args
81 |
82 |
83 | @torch.no_grad()
84 | def evaluate_model(model, preprocess_batched, pad_idx, global_rank, world_size, device, batch_size):
85 | _time = time.time()
86 |     val_data = datasets.load_dataset("allenai/c4", "en", split="validation", streaming=True)
87 | val_data = val_data.shuffle(seed=42)
88 | logger.info(f"Loaded validation dataset in {time.time() - _time:.2f} seconds")
89 |
90 | if not args.single_gpu:
91 | val_data = datasets.distributed.split_dataset_by_node(val_data, rank=global_rank, world_size=world_size)
92 |
93 | val_data_mapped = val_data.map(
94 | preprocess_batched,
95 | batched=True,
96 | remove_columns=["text", "timestamp", "url"],
97 | )
98 | val_data_mapped.batch = lambda batch_size: training_utils.batch_fn(val_data_mapped, batch_size)
99 |
100 | target_eval_tokens = 10_000_000
101 | evaluated_on_tokens = 0
102 | total_loss = torch.tensor(0.0).to(device)
103 | total_batches = 1
104 | logger.info(f"Eval set prepared in {time.time() - _time:.2f} seconds")
105 |
106 | for batch in val_data_mapped.batch(batch_size=batch_size):
107 | if evaluated_on_tokens > target_eval_tokens:
108 | break
109 | total_batches += 1
110 |
111 | batch = {k: v.to(device) for k, v in batch.items()}
112 | labels = batch["input_ids"].clone()
113 | labels[labels == pad_idx] = -100
114 | loss = model(**batch, labels=labels).loss
115 | total_loss += loss.detach()
116 |
117 | evaluated_on_tokens += (batch["input_ids"] != pad_idx).sum().item() * world_size
118 |
119 | total_loss = total_loss / total_batches
120 |
121 | # Gather losses across all GPUs
122 | gathered_losses = [torch.zeros_like(total_loss) for _ in range(world_size)]
123 | dist.all_gather(gathered_losses, total_loss)
124 | total_loss = sum([t.item() for t in gathered_losses]) / world_size
125 |
126 | return total_loss, evaluated_on_tokens
127 |
128 |
129 | def main(args):
130 | torch.manual_seed(args.seed)
131 | np.random.seed(args.seed)
132 | random.seed(args.seed)
133 |
134 | assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
135 | global_rank = int(os.environ['RANK'])
136 | local_rank = int(os.environ["LOCAL_RANK"])
137 | world_size = int(os.environ["WORLD_SIZE"])
138 | torch.cuda.set_device(local_rank)
139 |
140 | logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
141 |
142 | dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
143 |
144 | logger.info("Process group initialized")
145 | device = f"cuda:{local_rank}"
146 |
147 | if args.total_batch_size is not None:
148 | if args.gradient_accumulation is None:
149 | assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
150 | args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
151 | assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
152 |
153 | assert args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size, \
154 | "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
155 |
156 | # turn off logger
157 | if global_rank != 0: logger.remove()
158 |
159 | # initialize wandb without config (it is passed later)
160 | if global_rank == 0:
161 | wandb.init(project="galore-c4")
162 |
163 | logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
164 | logger.info("*" * 40)
165 |     logger.info("Starting training with the following arguments:")
166 | for k, v in vars(args).items():
167 | logger.info(f"{k:30} {v}")
168 | logger.info("*" * 40)
169 |
170 | data = datasets.load_dataset("allenai/c4", "en", split="train", streaming=True)
171 |
172 | seed_for_shuffle = 42
173 |
174 | logger.info(f"Shuffling data with seed {seed_for_shuffle}")
175 | data: datasets.Dataset = data.shuffle(seed=seed_for_shuffle)
176 | if not args.single_gpu:
177 | data = datasets.distributed.split_dataset_by_node(
178 | data, rank=global_rank, world_size=world_size,
179 | )
180 |
181 | # it doesn't matter which tokenizer we use, because we train from scratch
182 | # T5 tokenizer was trained on C4 and we are also training on C4, so it's a good choice
183 | tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length)
184 |
185 | def preprocess_batched(batch):
186 | batch = tokenizer(
187 | batch["text"],
188 | max_length=args.max_length,
189 | truncation=True,
190 | padding="max_length",
191 | return_tensors="pt",
192 | )
193 | return batch
194 |
195 | dataset = PreprocessedIterableDataset(data, tokenizer, batch_size=args.batch_size, max_length=args.max_length)
196 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=args.workers)
197 |
198 | model_config = AutoConfig.from_pretrained(args.model_config)
199 | if args.use_hf_model:
200 | model: HF_LlamaForCausalLM = AutoModelForCausalLM.from_config(model_config)
201 | else:
202 | model = LlamaForCausalLM(model_config)
203 |
204 | if args.activation_checkpointing:
205 | model.gradient_checkpointing_enable()
206 |
207 | global_step = 0
208 | update_step = 0
209 | beginning_step = 0
210 | tokens_seen = 0
211 | tokens_seen_before = 0
212 |
213 | if args.continue_from is not None:
214 | logger.info("*" * 40)
215 | logger.info(f"Loading model from {args.continue_from}")
216 | checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
217 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True)
218 | logger.info(f"Model successfully loaded (strict=True policy)")
219 |
220 | if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
221 | logger.info(f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}")
222 | with open(os.path.join(args.continue_from, "training_state.json")) as f:
223 | _old_state = json.load(f)
224 | global_step = _old_state["global_step"]
225 | update_step = _old_state["update_step"]
226 | tokens_seen = _old_state["tokens_seen"]
227 | tokens_seen_before = _old_state["tokens_seen_before"]
228 | logger.info(f"global_step : {global_step}")
229 | logger.info(f"update_step : {update_step}")
230 | logger.info(f"tokens_seen : {tokens_seen}")
231 | logger.info(f"tokens_seen_before: {tokens_seen_before}")
232 | logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
233 | else:
234 | logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
235 | logger.info("*" * 40)
236 |
237 |
238 | if args.dtype in ["bf16", "bfloat16"]:
239 | model = model.to(device=device, dtype=torch.bfloat16)
240 | else:
241 | model = model.to(device=device)
242 |
243 | n_total_params = sum(p.numel() for p in model.parameters())
244 | trainable_params = [p for p in model.parameters() if p.requires_grad]
245 | # Initialize wandb
246 | run_config = dict(vars(args))
247 | run_config.update({
248 | "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
249 | "total_params_M": n_total_params / 1_000_000,
250 | "dataset": 'c4',
251 | "model": model_config.to_dict(),
252 | "world_size": world_size,
253 | "device": str(device),
254 | })
255 |
256 | if global_rank == 0:
257 | wandb.config.update(run_config, allow_val_change=True)
258 | wandb.save(os.path.abspath(__file__), policy="now") # save current script
259 | # fix tqdm visual length to 80 so that the progress bar
260 | # doesn't jump around when changing from external display to laptop
261 | pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
262 |
263 | if 'galore' in args.optimizer.lower():
264 |         # collect the parameters that GaLore will project: weights of nn.Linear modules whose name contains "attn" or "mlp"
265 | galore_params = []
266 | target_modules_list = ["attn", "mlp"]
267 | for module_name, module in model.named_modules():
268 | if not isinstance(module, nn.Linear):
269 | continue
270 |
271 | if not any(target_key in module_name for target_key in target_modules_list):
272 | continue
273 |
274 | print('enable GaLore for weights in module: ', module_name)
275 | galore_params.append(module.weight)
276 | id_galore_params = [id(p) for p in galore_params]
277 |         # put every parameter that GaLore does not handle into a separate, regular group
278 |         regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
279 |         # then pass both groups to the GaLore optimizer selected below
280 | param_groups = [{'params': regular_params},
281 | {'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}]
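        # With these two groups, the weights and gradients themselves stay full-rank; only the
        # optimizer moment buffers for the GaLore group live in the rank-`args.rank`
        # subspace, which is where the optimizer-state memory saving comes from.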
282 |
283 | # print params and trainable params
284 | logger.info(f"\n{model}\n")
285 | logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
286 | logger.info(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000:.2f}M")
287 | if 'galore' in args.optimizer.lower():
288 | logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in galore_params) / 1_000_000:.2f}M")
289 | logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
290 |
291 | layer_wise_flag = False
292 | if args.optimizer.lower() == "adam":
293 | optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
294 | elif args.optimizer.lower() == "galore_adamw":
295 |         # GaLoreAdamW takes the param_groups built above instead of trainable_params
296 | optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
297 | # implement sgd
298 | elif args.optimizer.lower() == "sgd":
299 | optimizer = torch.optim.SGD(trainable_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.beta1)
300 | # implement adafactor
301 | elif args.optimizer.lower() == "adafactor":
302 | args.beta1 = None if args.beta1 == 0.0 else args.beta1
303 | optimizer = transformers.optimization.Adafactor(
304 | trainable_params,
305 | lr=args.lr,
306 | eps=(1e-30, 1e-3),
307 | clip_threshold=1.0,
308 | decay_rate=-0.8,
309 | beta1=args.beta1,
310 | weight_decay=args.weight_decay,
311 | relative_step=False,
312 | scale_parameter=False,
313 | warmup_init=False,
314 | )
315 | # low-rank adafactor
316 | elif args.optimizer.lower() == "galore_adafactor":
317 | args.beta1 = None if args.beta1 == 0.0 else args.beta1
318 | optimizer = GaLoreAdafactor(
319 | param_groups,
320 | lr=args.lr,
321 | eps=(1e-30, 1e-3),
322 | clip_threshold=1.0,
323 | decay_rate=-0.8,
324 | beta1=args.beta1,
325 | weight_decay=args.weight_decay,
326 | relative_step=False,
327 | scale_parameter=False,
328 | warmup_init=False,
329 | )
330 | # 8-bit Adam
331 | elif args.optimizer.lower() == "adam8bit":
332 | optimizer = bnb.optim.Adam8bit(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
333 | elif args.optimizer.lower() == "galore_adamw8bit":
334 | optimizer = GaLoreAdamW8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)
335 | elif args.optimizer.lower() == 'galore_adamw8bit_per_layer':
336 |         # TODO: the scheduler appears to be called twice per update step (needs checking); for now we double num_training_steps, warmup_steps, and update_proj_gap to compensate
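        # In this per-layer mode, each parameter gets its own 8-bit optimizer (GaLoreAdamW8bit
        # for projected weights, plain Adam8bit otherwise) plus its own scheduler, and both are
        # stepped from a post-accumulate-grad hook as soon as that parameter's gradient is ready.
        # The gradient is zeroed immediately afterwards, so per-parameter updates happen as
        # gradients become available instead of waiting for a single global optimizer.step();
        # the hook API, register_post_accumulate_grad_hook, is available in recent PyTorch releases.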
337 | optimizer_dict = {}
338 | for p in model.parameters():
339 | if p.requires_grad:
340 | if id(p) in id_galore_params:
341 | optimizer_dict[p] = GaLoreAdamW8bit([{'params': [p], 'rank': args.rank, 'update_proj_gap': args.update_proj_gap * 2, 'scale': args.galore_scale, 'proj_type': args.proj_type}], lr=args.lr, weight_decay=args.weight_decay)
342 | else:
343 | optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=args.lr, weight_decay=args.weight_decay)
344 |
345 | # get scheduler dict
346 | scheduler_dict = {}
347 | for p in model.parameters():
348 | if p.requires_grad:
349 | scheduler_dict[p] = training_utils.get_scheculer(
350 | optimizer=optimizer_dict[p],
351 | scheduler_type=args.scheduler,
352 | num_training_steps=args.num_training_steps * 2,
353 | warmup_steps=args.warmup_steps * 2,
354 | min_lr_ratio=args.min_lr_ratio,
355 | )
356 |
357 | def optimizer_hook(p):
358 | if p.grad is None:
359 | return
360 | optimizer_dict[p].step()
361 | optimizer_dict[p].zero_grad()
362 | scheduler_dict[p].step()
363 |
364 | # Register the hook onto every parameter
365 | for p in model.parameters():
366 | if p.requires_grad:
367 | p.register_post_accumulate_grad_hook(optimizer_hook)
368 |
369 | layer_wise_flag = True
370 |
371 | else:
372 | raise ValueError(f"Optimizer {args.optimizer} not supported")
373 |
374 | if not layer_wise_flag:
375 | scheduler = training_utils.get_scheculer(
376 | optimizer=optimizer,
377 | scheduler_type=args.scheduler,
378 | num_training_steps=args.num_training_steps,
379 | warmup_steps=args.warmup_steps,
380 | min_lr_ratio=args.min_lr_ratio,
381 | )
382 |
383 | if not args.single_gpu:
384 | model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
385 | model,
386 | device_ids=[local_rank],
387 | output_device=local_rank,
388 | broadcast_buffers=False,
389 | )
390 |
391 | # global steps and others are defined above
392 | pad_idx = tokenizer.pad_token_id
393 | update_time = time.time()
394 | local_step = 0 # when continue_from is used, local_step != global_step
395 |
396 | # ##############################
397 | # TRAINING LOOP
398 | # we'll never go through all the data, so no need for epochs
399 | # ##############################
400 |
401 | for batch_idx, batch in enumerate(dataloader):
402 |
403 | global_step += 1
404 | local_step += 1
405 |
406 | if update_step > args.num_training_steps:
407 |             logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.")
408 | print(f"Rank {global_rank} stopping training.")
409 | break
410 |
411 | batch = {k: v.to(device) for k, v in batch.items()}
412 | labels = batch["input_ids"].clone()
413 | labels[labels == pad_idx] = -100
414 | tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
415 |
416 | loss = model(**batch, labels=labels).loss
417 | scaled_loss = loss / args.gradient_accumulation
418 | scaled_loss.backward()
419 |
420 | if global_step % args.gradient_accumulation != 0:
421 | continue
422 |
423 |
424 | # The below code is only executed during the update step
425 |
426 | # add grad clipping
427 | if args.grad_clipping != 0.0: torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
428 |
429 | if global_rank == 0: pbar.update(1)
430 |
431 | if not layer_wise_flag:
432 | optimizer.step()
433 | scheduler.step()
434 | optimizer.zero_grad()
435 |
436 | update_step += 1
437 | update_time = time.time() - update_time
438 |
439 | # save checkpoint by save_every
440 | if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
441 | current_model_directory = f"{args.save_dir}/model_{update_step}"
442 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
443 | os.makedirs(args.save_dir, exist_ok=True)
444 | model.module.save_pretrained(current_model_directory, max_shard_size='100GB')
445 |
446 | optimizer_checkpoint = {
447 | "optimizer": optimizer.state_dict(),
448 | "scheduler": scheduler.state_dict(),
449 | "update_step": update_step,
450 | "global_step": global_step,
451 | "config": run_config,
452 | "wandb": wandb.run.dir,
453 | "dtype": args.dtype,
454 | }
455 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
456 |
457 | training_state_checkpoint = {
458 | "global_step": global_step,
459 | "update_step": update_step,
460 | "tokens_seen": tokens_seen,
461 | "tokens_seen_before": tokens_seen_before,
462 | "update_time": update_time,
463 | }
464 | with open(f"{current_model_directory}/training_state.json", "w") as f:
465 | json.dump(training_state_checkpoint, f, indent=4)
466 |
467 | # save wandb related info
468 | wandb_info = {
469 | "wandb_id": wandb.run.id,
470 | }
471 | with open(f"{args.save_dir}/wandb.json", "w") as f:
472 | json.dump(wandb_info, f, indent=4)
473 |
474 | # evaluation
475 | if update_step % args.eval_every == 0:
476 | logger.info(f"Performing evaluation at step {update_step}")
477 | total_loss, evaluated_on_tokens = evaluate_model(
478 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size
479 | )
480 | if global_rank == 0:
481 | wandb.log({
482 | "final_eval_loss": total_loss,
483 | "final_eval_tokens": evaluated_on_tokens,
484 | },
485 | step=global_step,
486 | )
487 | logger.info(f"Eval loss at step {update_step}: {total_loss}")
488 |
489 | if not layer_wise_flag:
490 | lr = optimizer.param_groups[0]["lr"]
491 | else:
492 | lr = list(optimizer_dict.values())[0].param_groups[0]["lr"]
493 | tokens_in_update = tokens_seen - tokens_seen_before
494 | tokens_seen_before = tokens_seen
495 | batches_in_update = args.gradient_accumulation * world_size
496 |
497 | if global_rank == 0:
498 | wandb.log({
499 | "loss": loss.item(),
500 | "lr": lr,
501 | "update_step": update_step,
502 | "tokens_seen": tokens_seen,
503 | "throughput_tokens": tokens_in_update / update_time,
504 | "throughput_examples": args.total_batch_size / update_time,
505 | "throughput_batches": batches_in_update / update_time,
506 | },
507 | step=global_step,
508 | )
509 | update_time = time.time()
510 |
511 | # ##############################
512 | # END of training loop
513 | # ##############################
514 | logger.info("Training finished")
515 | if global_rank == 0: pbar.close()
516 |
517 | current_model_directory = f"{args.save_dir}/model_{update_step}"
518 | if global_rank == 0 and not os.path.exists(current_model_directory):
519 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
520 | os.makedirs(args.save_dir, exist_ok=True)
521 | model.module.save_pretrained(current_model_directory)
522 |
523 | optimizer_checkpoint = {
524 | "optimizer": optimizer.state_dict(),
525 | "scheduler": scheduler.state_dict(),
526 | "update_step": update_step,
527 | "global_step": global_step,
528 | "config": run_config,
529 | "wandb": wandb.run.dir,
530 | "dtype": args.dtype,
531 | }
532 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
533 |
534 | training_state_checkpoint = {
535 | "global_step": global_step,
536 | "update_step": update_step,
537 | "tokens_seen": tokens_seen,
538 | "tokens_seen_before": tokens_seen_before,
539 | "update_time": update_time,
540 | }
541 | with open(f"{current_model_directory}/training_state.json", "w") as f:
542 | json.dump(training_state_checkpoint, f, indent=4)
543 |
544 | # Final evaluation
545 | logger.info("Running final evaluation")
546 | model.eval()
547 | del loss, optimizer, scheduler
548 | import gc; gc.collect()
549 | torch.cuda.empty_cache()
550 |
551 | total_loss, evaluated_on_tokens = evaluate_model(
552 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size
553 | )
554 |
555 | if global_rank == 0:
556 | wandb.log({
557 | "final_eval_loss": total_loss,
558 | "final_eval_tokens": evaluated_on_tokens,
559 | },
560 | step=global_step,
561 | )
562 | logger.info(f"Final eval loss: {total_loss}")
563 |
564 | logger.info("Script finished successfully")
565 | print(f"Rank {global_rank} finished successfully")
566 |
567 |
568 | if __name__ == "__main__":
569 | print("Starting script")
570 | args = parse_args(None)
571 | main(args)
572 |
--------------------------------------------------------------------------------
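
For reference only (not part of the repository), below is a minimal standalone sketch of the per-parameter update pattern that torchrun_main.py uses for the layer-wise optimizer branch: each trainable parameter gets its own optimizer, and a post-accumulate-grad hook applies that parameter's update as soon as its gradient is ready, so a full-model optimizer step is never needed. Plain AdamW stands in for the repository's GaLore optimizers, the per-parameter schedulers are omitted for brevity, and the tiny Sequential model, tensor sizes, and learning rate are illustrative assumptions. Requires torch >= 2.1 for Tensor.register_post_accumulate_grad_hook.

    import torch

    # Toy model; any nn.Module works the same way.
    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

    # One optimizer per parameter (a GaLore optimizer would be constructed analogously).
    optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters() if p.requires_grad}

    def optimizer_hook(p):
        # Called right after p.grad has been fully accumulated for this backward pass.
        if p.grad is None:
            return
        optimizer_dict[p].step()
        optimizer_dict[p].zero_grad()

    for p in model.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(optimizer_hook)

    # With the hooks registered, a plain backward pass triggers every update;
    # no global optimizer.step() is required.
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()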