├── CITATION.cff ├── LICENSE ├── README.md ├── configs ├── llama_100m.json ├── llama_130m.json ├── llama_1b.json ├── llama_20m.json ├── llama_250m.json ├── llama_350m.json ├── llama_35m.json ├── llama_3b.json ├── llama_40m.json ├── llama_60m.json ├── llama_71m.json ├── llama_7b.json └── llama_9m.json ├── exp_requirements.txt ├── galore_torch ├── __init__.py ├── adafactor.py ├── adamw.py ├── adamw8bit.py ├── galore_projector.py └── galore_projector_tensor.py ├── imgs ├── galore_code_box.png └── subspace_learning.png ├── peft_pretraining ├── args_utils.py ├── dataloader.py ├── modeling_llama.py └── training_utils.py ├── requirements.txt ├── run_glue.py ├── scripts ├── benchmark_c4 │ ├── llama_130m.sh │ ├── llama_1b.sh │ ├── llama_350m.sh │ ├── llama_60m.sh │ └── llama_7b.sh ├── single_gpu │ ├── llama_7b.sh │ └── llama_7b_checkpointing.sh └── tensor_test │ └── neural_operator.py ├── setup.py └── torchrun_main.py /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection" 3 | version: 1.0.0 4 | message: "If you use this software, please cite it as below." 5 | authors: 6 | - family-names: "Jiawei" 7 | given-names: "Zhao" 8 | year: 2024 9 | repository-code: "https://arxiv.org/abs/2403.03507" 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GaLore 2 | 3 | This repo contains the pre-release version of GaLore algorithm, proposed by [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507). 
4 | 5 | Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows *full-parameter* learning but is more *memory-efficient* than common low-rank adaptation methods, such as LoRA. 6 | As a gradient projection method, GaLore is independent of the choice of optimizers and can be easily plugged into existing ones with only two lines of code, as shown in Algorithm 1 below. 7 | 8 |
9 | <img src="imgs/galore_code_box.png" alt="Image 2" /> 10 |
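In practice, the two-line change shown above is just how the optimizer is constructed; the rest of the training loop stays the same. A minimal sketch, assuming a standard PyTorch loop and treating `galore_params` (e.g. the model's 2-D weight matrices), `non_galore_params`, `model`, `train_loader`, and `compute_loss` as placeholders for your own setup:

```python
from galore_torch import GaLoreAdamW

# line 1: put the GaLore target weights in their own param group with GaLore hyperparameters
param_groups = [{'params': non_galore_params},
                {'params': galore_params, 'rank': 128, 'update_proj_gap': 200,
                 'scale': 0.25, 'proj_type': 'std'}]
# line 2: use the GaLore optimizer as a drop-in replacement for AdamW
optimizer = GaLoreAdamW(param_groups, lr=0.01)

# the training loop itself is unchanged
for batch in train_loader:
    loss = compute_loss(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```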
11 | 12 | ## News 13 | 14 | 15 | - **2024-09-01**: We are working on GaLore 2, which is a more efficient and accessible version of GaLore. Please stay tuned! 16 | - **2024-07-11**: We release Q-GaLore: Quantized GaLore with INT4 Projection. [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)] 17 | 18 | - **2024-07-01**: GaLore is accepted to ICML 2024 as Oral! 19 | 20 | - **2024-04-20**: Please join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to discuss with us and the community. 21 | 22 | ## Installation 23 | 24 | ### Install GaLore optimizer 25 | Install from pip: 26 | ```bash 27 | pip install galore-torch 28 | ``` 29 | 30 | or if you want to install from source: 31 | 32 | ```bash 33 | git clone git@github.com:jiaweizzhao/GaLore.git 34 | cd GaLore 35 | pip install -e . 36 | ``` 37 | 38 | ### Install experiment dependencies 39 | 40 | ```bash 41 | pip install -r exp_requirements.txt 42 | ``` 43 | 44 | Our experiment scripts are tested on Python 3.8 with PyTorch 2.1. 45 | 46 | ## Usage 47 | 48 | ### Save optimizer memory using GaLore optimizers 49 | 50 | ```python 51 | from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor 52 | # define param groups as galore_params and non_galore_params 53 | param_groups = [{'params': non_galore_params}, 54 | {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}] 55 | optimizer = GaLoreAdamW(param_groups, lr=0.01) 56 | ``` 57 | ### Save weight gradient memory using per-layer weight updates 58 | 59 | We use `register_post_accumulate_grad_hook` provided by [PyTorch](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html) (`torch>=2.1.0`) to enable per-layer weight updates. An example is shown below: 60 | 61 | ```python 62 | # define an optimizer for each parameter p, and store them in optimizer_dict 63 | for p in model.parameters(): 64 | if p.requires_grad: 65 | optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01) 66 | 67 | # define a hook function to update the parameter p during the backward pass 68 | def optimizer_hook(p): 69 | if p.grad is None: 70 | return 71 | optimizer_dict[p].step() 72 | optimizer_dict[p].zero_grad() 73 | 74 | # Register the hook onto every parameter 75 | for p in model.parameters(): 76 | if p.requires_grad: 77 | p.register_post_accumulate_grad_hook(optimizer_hook) 78 | ``` 79 | More details can be found in [torchrun_main.py](https://github.com/jiaweizzhao/GaLore/blob/a6bc1650984b1c090a4e108d7c0e3109ee7ad844/torchrun_main.py#L334). 80 | 81 | ## Benchmark 1: Pre-Training LLaMA on C4 dataset 82 | `torchrun_main.py` is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for various sizes of models are in `scripts/benchmark_c4` folder. 
83 | For example, to train a 60m model on C4, do the following: 84 | 85 | ```bash 86 | # LLaMA-60M, GaLore-Adam, 1 A100, 1 Node 87 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \ 88 | --model_config configs/llama_60m.json \ 89 | --lr 0.01 \ 90 | --galore_scale 0.25 \ 91 | --rank 128 \ 92 | --update_proj_gap 200 \ 93 | --batch_size 256 \ 94 | --total_batch_size 512 \ 95 | --num_training_steps 10000 \ 96 | --warmup_steps 1000 \ 97 | --weight_decay 0 \ 98 | --dtype bfloat16 \ 99 | --eval_every 1000 \ 100 | --optimizer galore_adamw 101 | ``` 102 | 103 | ### Train 7B model with a single GPU with 24GB memory 104 | To train a 7B model with a single GPU such as NVIDIA RTX 4090, all you need to do is to specify `--optimizer=galore_adamw8bit_per_layer`, which enables `GaLoreAdamW8bit` with per-layer weight updates. 105 | With activation checkpointing, you can maintain a batch size of 16 tested on NVIDIA RTX 4090. 106 | 107 | ```bash 108 | # LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing 109 | # bsz=16, 22.8G, 110 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \ 111 | --model_config configs/llama_7b.json \ 112 | --lr 0.005 \ 113 | --galore_scale 0.25 \ 114 | --rank 1024 \ 115 | --update_proj_gap 500 \ 116 | --batch_size 16 \ 117 | --total_batch_size 512 \ 118 | --activation_checkpointing \ 119 | --num_training_steps 150000 \ 120 | --warmup_steps 15000 \ 121 | --weight_decay 0 \ 122 | --grad_clipping 1.0 \ 123 | --dtype bfloat16 \ 124 | --eval_every 1000 \ 125 | --single_gpu \ 126 | --optimizer galore_adamw8bit_per_layer 127 | ``` 128 | 129 | Currently per-layer weight updates technique is only supported for single GPU training (`--single_gpu`) without using `nn.parallel.DistributedDataParallel`. We are working on supporting multi-GPU training with per-layer weight updates. 130 | 131 | ## Benchmark 2: Fine-Tuning RoBERTa on GLUE tasks 132 | `run_glue.py` is the main script for fine-tuning RoBERTa models on GLUE tasks with GaLore. 
An example script is shown below: 133 | 134 | ```bash 135 | python run_glue.py \ 136 | --model_name_or_path roberta-base \ 137 | --task_name mrpc \ 138 | --enable_galore \ 139 | --lora_all_modules \ 140 | --max_length 512 \ 141 | --seed=1234 \ 142 | --lora_r 4 \ 143 | --galore_scale 4 \ 144 | --per_device_train_batch_size 16 \ 145 | --update_proj_gap 500 \ 146 | --learning_rate 3e-5 \ 147 | --num_train_epochs 30 \ 148 | --output_dir results/ft/roberta_base/mrpc 149 | ``` 150 | 151 | ## Citation 152 | ```bibtex 153 | @misc{zhao2024galore, 154 | title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection}, 155 | author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian}, 156 | year={2024}, 157 | eprint={2403.03507}, 158 | archivePrefix={arXiv}, 159 | primaryClass={cs.LG} 160 | } 161 | ``` -------------------------------------------------------------------------------- /configs/llama_100m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 640, 9 | "intermediate_size": 1708, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 10, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_130m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 768, 9 | "intermediate_size": 2048, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 12, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_1b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 2048, 9 | "intermediate_size": 5461, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_20m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 256, 9 | "intermediate_size": 688, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 4, 14 | "num_hidden_layers": 4, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | 
"vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_250m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 768, 9 | "intermediate_size": 2560, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_350m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 1024, 9 | "intermediate_size": 2736, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_35m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 384, 9 | "intermediate_size": 1024, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 6, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_3b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 2560, 9 | "intermediate_size": 6848, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_40m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 416, 9 | "intermediate_size": 1024, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 8, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_60m.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 512, 9 | "intermediate_size": 1376, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 8, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_71m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 512, 9 | "intermediate_size": 1368, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 4096, 9 | "intermediate_size": 11008, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 2048, 12 | "model_type": "llama", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_9m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 128, 9 | "intermediate_size": 352, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 4, 14 | "num_hidden_layers": 4, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /exp_requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers==4.31.0 3 | tokenizers 4 | datasets 5 | peft 6 | wandb 7 | loguru 8 | nvitop 9 | lion-pytorch 10 | matplotlib 11 | bitsandbytes 12 | scipy 13 | scikit-learn 14 | evaluate -------------------------------------------------------------------------------- /galore_torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .adafactor import Adafactor as GaLoreAdafactor 2 | from .adamw import AdamW as GaLoreAdamW 3 | from .adamw8bit import AdamW8bit as GaLoreAdamW8bit -------------------------------------------------------------------------------- /galore_torch/adafactor.py: -------------------------------------------------------------------------------- 1 | # copy dependencies from 
transformers/optimization.py 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from torch.optim import Optimizer 7 | 8 | 9 | from transformers.utils.versions import require_version 10 | 11 | from .galore_projector import GaLoreProjector 12 | from .galore_projector_tensor import GaLoreProjectorTensor 13 | 14 | 15 | class Adafactor(Optimizer): 16 | """ 17 | AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: 18 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py 19 | 20 | Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that 21 | this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and 22 | `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and 23 | `relative_step=False`. 24 | 25 | Arguments: 26 | params (`Iterable[nn.parameter.Parameter]`): 27 | Iterable of parameters to optimize or dictionaries defining parameter groups. 28 | lr (`float`, *optional*): 29 | The external learning rate. 30 | eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): 31 | Regularization constants for square gradient and parameter scale respectively 32 | clip_threshold (`float`, *optional*, defaults to 1.0): 33 | Threshold of root mean square of final gradient update 34 | decay_rate (`float`, *optional*, defaults to -0.8): 35 | Coefficient used to compute running averages of square 36 | beta1 (`float`, *optional*): 37 | Coefficient used for computing running averages of gradient 38 | weight_decay (`float`, *optional*, defaults to 0.0): 39 | Weight decay (L2 penalty) 40 | scale_parameter (`bool`, *optional*, defaults to `True`): 41 | If True, learning rate is scaled by root mean square 42 | relative_step (`bool`, *optional*, defaults to `True`): 43 | If True, time-dependent learning rate is computed instead of external learning rate 44 | warmup_init (`bool`, *optional*, defaults to `False`): 45 | Time-dependent learning rate computation depends on whether warm-up initialization is being used 46 | 47 | This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. 48 | 49 | Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): 50 | 51 | - Training without LR warmup or clip_threshold is not recommended. 
52 | 53 | - use scheduled LR warm-up to fixed LR 54 | - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) 55 | - Disable relative updates 56 | - Use scale_parameter=False 57 | - Additional optimizer operations like gradient clipping should not be used alongside Adafactor 58 | 59 | Example: 60 | 61 | ```python 62 | Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) 63 | ``` 64 | 65 | Others reported the following combination to work well: 66 | 67 | ```python 68 | Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) 69 | ``` 70 | 71 | When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] 72 | scheduler as following: 73 | 74 | ```python 75 | from transformers.optimization import Adafactor, AdafactorSchedule 76 | 77 | optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) 78 | lr_scheduler = AdafactorSchedule(optimizer) 79 | trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) 80 | ``` 81 | 82 | Usage: 83 | 84 | ```python 85 | # replace AdamW with Adafactor 86 | optimizer = Adafactor( 87 | model.parameters(), 88 | lr=1e-3, 89 | eps=(1e-30, 1e-3), 90 | clip_threshold=1.0, 91 | decay_rate=-0.8, 92 | beta1=None, 93 | weight_decay=0.0, 94 | relative_step=False, 95 | scale_parameter=False, 96 | warmup_init=False, 97 | ) 98 | ```""" 99 | 100 | def __init__( 101 | self, 102 | params, 103 | lr=None, 104 | eps=(1e-30, 1e-3), 105 | clip_threshold=1.0, 106 | decay_rate=-0.8, 107 | beta1=None, 108 | weight_decay=0.0, 109 | scale_parameter=True, 110 | relative_step=True, 111 | warmup_init=False, 112 | ): 113 | require_version("torch>=1.5.0") # add_ with alpha 114 | if lr is not None and relative_step: 115 | raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") 116 | if warmup_init and not relative_step: 117 | raise ValueError("`warmup_init=True` requires `relative_step=True`") 118 | 119 | defaults = { 120 | "lr": lr, 121 | "eps": eps, 122 | "clip_threshold": clip_threshold, 123 | "decay_rate": decay_rate, 124 | "beta1": beta1, 125 | "weight_decay": weight_decay, 126 | "scale_parameter": scale_parameter, 127 | "relative_step": relative_step, 128 | "warmup_init": warmup_init, 129 | } 130 | super().__init__(params, defaults) 131 | 132 | @staticmethod 133 | def _get_lr(param_group, param_state): 134 | rel_step_sz = param_group["lr"] 135 | if param_group["relative_step"]: 136 | min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 137 | rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) 138 | param_scale = 1.0 139 | if param_group["scale_parameter"]: 140 | param_scale = max(param_group["eps"][1], param_state["RMS"]) 141 | return param_scale * rel_step_sz 142 | 143 | @staticmethod 144 | def _get_options(param_group, param_shape): 145 | factored = len(param_shape) >= 2 146 | use_first_moment = param_group["beta1"] is not None 147 | return factored, use_first_moment 148 | 149 | @staticmethod 150 | def _rms(tensor): 151 | return tensor.norm(2) / (tensor.numel() ** 0.5) 152 | 153 | @staticmethod 154 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 155 | # copy from fairseq's adafactor implementation: 156 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 157 | r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, 
keepdim=True)).rsqrt_().unsqueeze(-1) 158 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 159 | return torch.mul(r_factor, c_factor) 160 | 161 | @torch.no_grad() 162 | def step(self, closure=None): 163 | """ 164 | Performs a single optimization step 165 | 166 | Arguments: 167 | closure (callable, optional): A closure that reevaluates the model 168 | and returns the loss. 169 | """ 170 | loss = None 171 | if closure is not None: 172 | loss = closure() 173 | 174 | for group in self.param_groups: 175 | for p in group["params"]: 176 | if p.grad is None: 177 | continue 178 | grad = p.grad 179 | if grad.dtype in {torch.float16, torch.bfloat16}: 180 | grad = grad.float() 181 | if grad.is_sparse: 182 | raise RuntimeError("Adafactor does not support sparse gradients.") 183 | 184 | state = self.state[p] 185 | 186 | if "step" not in state: 187 | state["step"] = 0 188 | 189 | if 'dim' not in group: 190 | group['dim'] = 2 191 | 192 | # GaLore Projection 193 | if "rank" in group: 194 | if "projector" not in state: 195 | if group['dim'] <=2: 196 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 197 | else: 198 | state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 199 | 200 | grad = state["projector"].project(grad, state["step"]) 201 | 202 | grad_shape = grad.shape 203 | 204 | factored, use_first_moment = self._get_options(group, grad_shape) 205 | # State Initialization 206 | if "RMS" not in state: 207 | state["step"] = 0 208 | 209 | if use_first_moment: 210 | # Exponential moving average of gradient values 211 | state["exp_avg"] = torch.zeros_like(grad) 212 | if factored: 213 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) 214 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 215 | else: 216 | state["exp_avg_sq"] = torch.zeros_like(grad) 217 | 218 | state["RMS"] = 0 219 | else: 220 | if use_first_moment: 221 | state["exp_avg"] = state["exp_avg"].to(grad) 222 | if factored: 223 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) 224 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) 225 | else: 226 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) 227 | 228 | p_data_fp32 = p 229 | if p.dtype in {torch.float16, torch.bfloat16}: 230 | p_data_fp32 = p_data_fp32.float() 231 | 232 | state["step"] += 1 233 | state["RMS"] = self._rms(p_data_fp32) 234 | lr = self._get_lr(group, state) 235 | 236 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) 237 | update = (grad**2) + group["eps"][0] 238 | if factored: 239 | exp_avg_sq_row = state["exp_avg_sq_row"] 240 | exp_avg_sq_col = state["exp_avg_sq_col"] 241 | 242 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) 243 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) 244 | 245 | # Approximation of exponential moving average of square of gradient 246 | update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 247 | update.mul_(grad) 248 | else: 249 | exp_avg_sq = state["exp_avg_sq"] 250 | 251 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) 252 | update = exp_avg_sq.rsqrt().mul_(grad) 253 | 254 | update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) 255 | update.mul_(lr) 256 | 257 | if use_first_moment: 258 | exp_avg = state["exp_avg"] 259 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - 
group["beta1"])) 260 | update = exp_avg 261 | 262 | # GaLore Projection Back 263 | if "rank" in group: 264 | update = state["projector"].project_back(update) 265 | 266 | if group["weight_decay"] != 0: 267 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) 268 | 269 | p_data_fp32.add_(-update) 270 | 271 | if p.dtype in {torch.float16, torch.bfloat16}: 272 | p.copy_(p_data_fp32) 273 | 274 | return loss 275 | -------------------------------------------------------------------------------- /galore_torch/adamw.py: -------------------------------------------------------------------------------- 1 | # copy dependencies from transformers/optimization.py 2 | import math 3 | import warnings 4 | from typing import Callable, Iterable, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | from torch.optim import Optimizer 9 | 10 | from transformers.utils.versions import require_version 11 | 12 | from .galore_projector import GaLoreProjector 13 | from .galore_projector_tensor import GaLoreProjectorTensor 14 | 15 | 16 | class AdamW(Optimizer): 17 | """ 18 | Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay 19 | Regularization](https://arxiv.org/abs/1711.05101). 20 | 21 | Parameters: 22 | params (`Iterable[nn.parameter.Parameter]`): 23 | Iterable of parameters to optimize or dictionaries defining parameter groups. 24 | lr (`float`, *optional*, defaults to 0.001): 25 | The learning rate to use. 26 | betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): 27 | Adam's betas parameters (b1, b2). 28 | eps (`float`, *optional*, defaults to 1e-06): 29 | Adam's epsilon for numerical stability. 30 | weight_decay (`float`, *optional*, defaults to 0.0): 31 | Decoupled weight decay to apply. 32 | correct_bias (`bool`, *optional*, defaults to `True`): 33 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). 34 | no_deprecation_warning (`bool`, *optional*, defaults to `False`): 35 | A flag used to disable the deprecation warning (set to `True` to disable the warning). 36 | """ 37 | 38 | def __init__( 39 | self, 40 | params: Iterable[nn.parameter.Parameter], 41 | lr: float = 1e-3, 42 | betas: Tuple[float, float] = (0.9, 0.999), 43 | eps: float = 1e-6, 44 | weight_decay: float = 0.0, 45 | correct_bias: bool = True, 46 | no_deprecation_warning: bool = False, 47 | ): 48 | if not no_deprecation_warning: 49 | warnings.warn( 50 | "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch" 51 | " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this" 52 | " warning", 53 | FutureWarning, 54 | ) 55 | require_version("torch>=1.5.0") # add_ with alpha 56 | if lr < 0.0: 57 | raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") 58 | if not 0.0 <= betas[0] < 1.0: 59 | raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)") 60 | if not 0.0 <= betas[1] < 1.0: 61 | raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") 62 | if not 0.0 <= eps: 63 | raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") 64 | defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias} 65 | super().__init__(params, defaults) 66 | 67 | @torch.no_grad() 68 | def step(self, closure: Callable = None): 69 | """ 70 | Performs a single optimization step. 
71 | 72 | Arguments: 73 | closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 74 | """ 75 | loss = None 76 | if closure is not None: 77 | loss = closure() 78 | 79 | for group in self.param_groups: 80 | for p in group["params"]: 81 | if p.grad is None: 82 | continue 83 | grad = p.grad 84 | if grad.is_sparse: 85 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 86 | 87 | state = self.state[p] 88 | 89 | if "step" not in state: 90 | state["step"] = 0 91 | 92 | if 'dim' not in group: 93 | group['dim'] = 2 94 | 95 | # GaLore Projection 96 | if "rank" in group: 97 | if "projector" not in state: 98 | if group['dim'] <=2: 99 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 100 | else: 101 | state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 102 | grad = state["projector"].project(grad, state["step"]) 103 | 104 | # State initialization 105 | if "exp_avg" not in state: 106 | # Exponential moving average of gradient values 107 | state["exp_avg"] = torch.zeros_like(grad) 108 | # Exponential moving average of squared gradient values 109 | state["exp_avg_sq"] = torch.zeros_like(grad) 110 | 111 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 112 | beta1, beta2 = group["betas"] 113 | 114 | state["step"] += 1 115 | 116 | # Decay the first and second moment running average coefficient 117 | # In-place operations to update the averages at the same time 118 | exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) 119 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 120 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 121 | 122 | step_size = group["lr"] 123 | if group["correct_bias"]: # No bias correction for Bert 124 | bias_correction1 = 1.0 - beta1 ** state["step"] 125 | bias_correction2 = 1.0 - beta2 ** state["step"] 126 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 127 | 128 | # compute norm gradient 129 | norm_grad = exp_avg / denom 130 | 131 | # GaLore Projection Back 132 | if "rank" in group: 133 | norm_grad = state["projector"].project_back(norm_grad) 134 | 135 | p.add_(norm_grad, alpha=-step_size) 136 | 137 | # Just adding the square of the weights to the loss function is *not* 138 | # the correct way of using L2 regularization/weight decay with Adam, 139 | # since that will interact with the m and v parameters in strange ways. 140 | # 141 | # Instead we want to decay the weights in a manner that doesn't interact 142 | # with the m/v parameters. This is equivalent to adding the square 143 | # of the weights to the loss with plain (non-momentum) SGD. 
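# Concretely, the decoupled update below is p <- p - lr * weight_decay * p, applied independently of the Adam moment estimates.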
144 | # Add weight decay at the end (fixed version) 145 | if group["weight_decay"] > 0.0: 146 | p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) 147 | 148 | return loss 149 | -------------------------------------------------------------------------------- /galore_torch/adamw8bit.py: -------------------------------------------------------------------------------- 1 | from bitsandbytes.optim.optimizer import Optimizer2State 2 | 3 | import torch 4 | 5 | from .galore_projector import GaLoreProjector 6 | from .galore_projector_tensor import GaLoreProjectorTensor 7 | 8 | 9 | class AdamW8bit(Optimizer2State): 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 11 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) 12 | 13 | @torch.no_grad() 14 | def step(self, closure=None): 15 | """Performs a single optimization step. 16 | 17 | Arguments: 18 | closure (callable, optional): A closure that reevaluates the model 19 | and returns the loss. 20 | """ 21 | loss = None 22 | if closure is not None: 23 | with torch.enable_grad(): 24 | loss = closure() 25 | 26 | overflows = [] 27 | 28 | if not self.initialized: 29 | self.check_overrides() 30 | self.to_gpu() # needed for fairseq pure fp16 training 31 | self.initialized = True 32 | 33 | #if self.is_paged: self.page_mng.prefetch_all() 34 | for gindex, group in enumerate(self.param_groups): 35 | for pindex, p in enumerate(group["params"]): 36 | if p.grad is None: 37 | continue 38 | state = self.state[p] 39 | 40 | if "step" not in state: 41 | state["step"] = 0 42 | 43 | if 'dim' not in group: 44 | group['dim'] = 2 45 | 46 | # GaLore Projection 47 | if "rank" in group: 48 | if "projector" not in state: 49 | if group['dim'] <= 2: 50 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 51 | else: 52 | state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 53 | if 'weight_decay' in group and group['weight_decay'] > 0: 54 | # ensure that the weight decay is not applied to the norm grad 55 | group['weight_decay_saved'] = group['weight_decay'] 56 | group['weight_decay'] = 0 57 | 58 | grad = state["projector"].project(p.grad, state["step"]) 59 | 60 | # suboptimal implementation 61 | p.saved_data = p.data.clone() 62 | p.data = grad.clone().to(p.data.dtype).to(p.data.device) 63 | p.data.zero_() 64 | p.grad = grad 65 | 66 | if 'state1' not in state: 67 | self.init_state(group, p, gindex, pindex) 68 | 69 | self.prefetch_state(p) 70 | self.update_step(group, p, gindex, pindex) 71 | torch.cuda.synchronize() 72 | 73 | # GaLore Projection Back 74 | if "rank" in group: 75 | p.data = p.saved_data.add_(state["projector"].project_back(p.data)) 76 | 77 | # apply weight decay 78 | if 'weight_decay_saved' in group: 79 | p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay_saved']) 80 | group['weight_decay'] = group['weight_decay_saved'] 81 | del group['weight_decay_saved'] 82 | 83 | if self.is_paged: 84 | # all paged operation are asynchronous, we need 85 | # to sync to make sure all tensors are in the right state 86 | torch.cuda.synchronize() 87 | 88 | 89 | return loss 
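The `galore_adamw8bit_per_layer` option referenced in the README combines this 8-bit optimizer with per-layer weight updates. Below is a hedged sketch of that wiring, following the hook pattern from the README; `model`, the hyperparameter values, and the `p.dim() == 2` filter for choosing GaLore parameters are illustrative assumptions rather than the exact logic of `torchrun_main.py`:

```python
from galore_torch import GaLoreAdamW8bit

# one optimizer per trainable parameter; GaLore settings only for 2-D weight matrices
optimizer_dict = {}
for p in model.parameters():
    if not p.requires_grad:
        continue
    if p.dim() == 2:
        optimizer_dict[p] = GaLoreAdamW8bit(
            [{'params': p, 'rank': 1024, 'update_proj_gap': 500,
              'scale': 0.25, 'proj_type': 'std'}], lr=0.005)
    else:
        optimizer_dict[p] = GaLoreAdamW8bit([{'params': p}], lr=0.005)

# step and clear each parameter as soon as its gradient has been accumulated (torch >= 2.1)
def optimizer_hook(p):
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)
```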
-------------------------------------------------------------------------------- /galore_torch/galore_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class GaLoreProjector: 4 | def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'): 5 | self.rank = rank 6 | self.verbose = verbose 7 | self.update_proj_gap = update_proj_gap 8 | self.scale = scale 9 | self.ortho_matrix = None 10 | self.proj_type = proj_type 11 | 12 | def project(self, full_rank_grad, iter): 13 | if self.proj_type == 'std': 14 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: 15 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 16 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') 17 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t().to(full_rank_grad.device.type)) 18 | else: 19 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 20 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') 21 | low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type), full_rank_grad) 22 | elif self.proj_type == 'reverse_std': 23 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: 24 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 25 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') 26 | low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type),full_rank_grad) 27 | else: 28 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 29 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') 30 | low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t().to(full_rank_grad.device.type)) 31 | elif self.proj_type == 'right': 32 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 33 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') 34 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t().to(full_rank_grad.device.type)) 35 | elif self.proj_type == 'left': 36 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 37 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') 38 | low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type), full_rank_grad) 39 | elif self.proj_type == 'full': 40 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 41 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full') 42 | low_rank_grad = torch.matmul(self.ortho_matrix[0].t().to(full_rank_grad.device.type), full_rank_grad) @ self.ortho_matrix[1].t().to(full_rank_grad.device.type) 43 | 44 | return low_rank_grad 45 | 46 | def project_back(self, low_rank_grad): 47 | if self.proj_type == 'std': 48 | if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: 49 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type)) 50 | else: 51 | full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad) 52 | elif self.proj_type == 'reverse_std': 53 | if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std 54 | full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad) 55 | else: 56 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type)) 57 | 
elif self.proj_type == 'right': 58 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type)) 59 | elif self.proj_type == 'left': 60 | full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad) 61 | elif self.proj_type == 'full': 62 | full_rank_grad = torch.matmul(self.ortho_matrix[0].to(low_rank_grad.device.type), low_rank_grad) @ self.ortho_matrix[1].to(low_rank_grad.device.type) 63 | 64 | 65 | return full_rank_grad * self.scale 66 | 67 | 68 | # svd decomposition 69 | def get_orthogonal_matrix(self, weights, rank, type): 70 | module_params = weights 71 | 72 | if module_params.data.dtype != torch.float: 73 | float_data = False 74 | original_type = module_params.data.dtype 75 | original_device = module_params.data.device 76 | matrix = module_params.data.float() 77 | else: 78 | float_data = True 79 | matrix = module_params.data 80 | 81 | U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) 82 | 83 | #make the smaller matrix always to be orthogonal matrix 84 | if type=='right': 85 | B = Vh[:rank, :] 86 | if not float_data: 87 | B = B.to(original_device).type(original_type) 88 | return B 89 | elif type=='left': 90 | A = U[:, :rank] 91 | if not float_data: 92 | A = A.to(original_device).type(original_type) 93 | return A 94 | elif type=='full': 95 | A = U[:, :rank] 96 | B = Vh[:rank, :] 97 | if not float_data: 98 | A = A.to(original_device).type(original_type) 99 | B = B.to(original_device).type(original_type) 100 | return [A, B] 101 | else: 102 | raise ValueError('type should be left, right or full') 103 | -------------------------------------------------------------------------------- /galore_torch/galore_projector_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tensorly.decomposition import tucker 3 | from tensorly import tenalg 4 | 5 | # The GaLoreProjector class in Python implements a projection method using orthogonal matrix 6 | # decomposition for low-rank approximation of gradients for general tensors of dimension >2. 7 | # We use tensor decomposition using tensorly library: https://tensorly.org/stable/index.html 8 | class GaLoreProjectorTensor: 9 | """ 10 | A class that represents a projector for the GaLore algorithm. 11 | 12 | Args: 13 | rank (int): The rank of the projector. 14 | verbose (bool, optional): Whether to print verbose output. Defaults to False. 15 | update_proj_gap (int, optional): The number of iterations between updating the orthogonal matrix. Defaults to 200. 16 | scale (float, optional): The scaling factor for the projected gradients. Defaults to 1.0. 17 | """ 18 | 19 | def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0): 20 | self.rank = rank 21 | self.verbose = verbose 22 | self.update_proj_gap = update_proj_gap 23 | self.scale = scale 24 | self.ortho_matrix = None 25 | self.transformed_low_rank = None 26 | 27 | def project(self, full_rank_grad, iter): 28 | """ 29 | Projects the full-rank gradients onto the low-rank subspace. 30 | 31 | Args: 32 | full_rank_grad (torch.Tensor): The full-rank gradients. 33 | iter (int): The current iteration. 34 | 35 | Returns: 36 | torch.Tensor: The transformed low-rank gradients. 
37 | """ 38 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 39 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank) 40 | self.transformed_low_rank = self.transform(self.ortho_matrix, full_rank_grad) 41 | return self.transformed_low_rank 42 | 43 | def project_back(self, low_rank_grad): 44 | """ 45 | Projects the low-rank gradients back to the full-rank space. 46 | 47 | Args: 48 | low_rank_grad (torch.Tensor): The low-rank gradients. 49 | 50 | Returns: 51 | torch.Tensor: The full-rank gradients. 52 | """ 53 | full_rank_grad = self.inverse_transform(self.ortho_matrix, low_rank_grad) 54 | return full_rank_grad * self.scale 55 | 56 | # tucker decomposition 57 | def get_orthogonal_matrix(self, weights, rank_all): 58 | """ 59 | Computes the orthogonal factors using Tucker decomposition. 60 | 61 | Args: 62 | weights (torch.Tensor): The weights to decompose. 63 | rank_all (int): The desired rank of the decomposition. 64 | 65 | Returns: 66 | tuple: A tuple containing the core and factors of the orthogonal matrix. 67 | """ 68 | module_params = weights 69 | if module_params.data.dtype != torch.float: 70 | matrix = module_params.data.float() 71 | else: 72 | matrix = module_params.data 73 | tucker_tensor = tucker(matrix, rank=rank_all) 74 | return tucker_tensor 75 | 76 | def transform(self, tensor, x): 77 | """ 78 | Transforms the input tensor using the factors of the orthogonal matrix. 79 | 80 | Args: 81 | tensor (tuple): A tuple containing the core and factors of the orthogonal matrix. 82 | x (torch.Tensor): The input tensor. 83 | 84 | Returns: 85 | torch.Tensor: The transformed tensor. 86 | """ 87 | _, factors = tensor 88 | return tenalg.multi_mode_dot(x, factors, transpose=True) 89 | 90 | def inverse_transform(self, tensor, x): 91 | """ 92 | Inverse transforms the input tensor using the factors of the orthogonal matrix. 93 | 94 | Args: 95 | tensor (tuple): A tuple containing the core and factors of the orthogonal matrix. 96 | x (torch.Tensor): The input tensor. 97 | 98 | Returns: 99 | torch.Tensor: The inverse transformed tensor.
100 | """ 101 | _, factors = tensor 102 | return tenalg.multi_mode_dot(x, factors) 103 | -------------------------------------------------------------------------------- /imgs/galore_code_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaweizzhao/GaLore/2cc66f88cce189e505affbb91042a8e77f5bf4e9/imgs/galore_code_box.png -------------------------------------------------------------------------------- /imgs/subspace_learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaweizzhao/GaLore/2cc66f88cce189e505affbb91042a8e77f5bf4e9/imgs/subspace_learning.png -------------------------------------------------------------------------------- /peft_pretraining/args_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | from loguru import logger 5 | 6 | 7 | def check_args_torchrun_main(args): 8 | 9 | if args.save_dir is None: 10 | # use checkpoints / model name, date and time as save directory 11 | args.save_dir = f"checkpoints/{args.model_config.split('/')[-1].rstrip('.json')}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" 12 | 13 | if args.tags is not None: 14 | args.tags = args.tags.split(",") 15 | 16 | if args.total_batch_size is None: 17 | args.gradient_accumulation = args.gradient_accumulation or 1 18 | args.total_batch_size = args.batch_size * args.gradient_accumulation 19 | 20 | assert args.total_batch_size % args.batch_size == 0, "total_batch_size must be divisible by batch_size" 21 | 22 | if args.max_train_tokens is not None: 23 | args.num_training_steps = args.max_train_tokens // args.total_batch_size 24 | logger.info(f"Training for {args.num_training_steps} update steps") 25 | 26 | if args.continue_from is not None: 27 | assert os.path.exists(args.continue_from), f"--continue_from={args.continue_from} does not exist" 28 | 29 | if args.dtype in ["fp16", "float16"]: 30 | raise NotImplementedError("fp16 is not supported in torchrun_main.py. 
Use deepspeed_main.py instead (but it seems to have bugs)") 31 | 32 | return args 33 | -------------------------------------------------------------------------------- /peft_pretraining/dataloader.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from torch.utils.data import IterableDataset, get_worker_info 5 | 6 | 7 | class PreprocessedIterableDataset(IterableDataset): 8 | def __init__(self, data, tokenizer, batch_size, max_length): 9 | super().__init__() 10 | self.data = data 11 | self.tokenizer = tokenizer 12 | self.batch_size = batch_size 13 | self.max_length = max_length 14 | 15 | def __iter__(self): 16 | worker_info = get_worker_info() 17 | if worker_info is None: 18 | # If no worker_info is provided, we are not using DataLoader workers, so yield all data 19 | iter_data = iter(self.data) 20 | else: 21 | # If using DataLoader workers, yield a subset of the data for this worker 22 | worker_id = worker_info.id 23 | num_workers = worker_info.num_workers 24 | iter_data = itertools.islice(self.data, worker_id, None, num_workers) 25 | 26 | batch = [] 27 | for example in iter_data: 28 | tokenized_example = self.tokenizer( 29 | example["text"], 30 | max_length=self.max_length, 31 | truncation=True, 32 | padding="max_length", 33 | return_tensors="pt", 34 | ) 35 | batch.append(tokenized_example) 36 | 37 | if len(batch) == self.batch_size: 38 | yield self._format_batch(batch) 39 | batch = [] 40 | 41 | if batch: 42 | yield self._format_batch(batch) 43 | 44 | def _format_batch(self, batch): 45 | input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) 46 | attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch]) 47 | 48 | return {"input_ids": input_ids, "attention_mask": attention_mask} 49 | -------------------------------------------------------------------------------- /peft_pretraining/modeling_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
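# A minimal usage sketch for the PreprocessedIterableDataset defined above in
# peft_pretraining/dataloader.py. The streaming C4 split and the t5-base tokenizer are
# illustrative assumptions (t5-base mirrors the tokenizer choice in run_glue.py), and the
# import assumes the repository root is on PYTHONPATH.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from peft_pretraining.dataloader import PreprocessedIterableDataset

tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=256)
raw_data = load_dataset("allenai/c4", "en", split="train", streaming=True)
dataset = PreprocessedIterableDataset(raw_data, tokenizer, batch_size=8, max_length=256)

# batch_size=None: the dataset already yields fully collated batches of 8 sequences
loader = DataLoader(dataset, batch_size=None, num_workers=2)
batch = next(iter(loader))  # {"input_ids": (8, 256), "attention_mask": (8, 256)}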
20 | """ PyTorch LLaMA model.""" 21 | import math 22 | from typing import List, Optional, Tuple, Union 23 | 24 | import torch 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 31 | from transformers.modeling_utils import PreTrainedModel 32 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings 33 | from transformers.models.llama.configuration_llama import LlamaConfig 34 | 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | _CONFIG_FOR_DOC = "LlamaConfig" 39 | 40 | 41 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 42 | def _make_causal_mask( 43 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 44 | ): 45 | """ 46 | Make causal mask used for bi-directional self-attention. 47 | """ 48 | bsz, tgt_len = input_ids_shape 49 | mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) 50 | mask_cond = torch.arange(mask.size(-1), device=device) 51 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 52 | mask = mask.to(dtype) 53 | 54 | if past_key_values_length > 0: 55 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 56 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 57 | 58 | 59 | # Copied from transformers.models.bart.modeling_bart._expand_mask 60 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 61 | """ 62 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 63 | """ 64 | bsz, src_len = mask.size() 65 | tgt_len = tgt_len if tgt_len is not None else src_len 66 | 67 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 68 | 69 | inverted_mask = 1.0 - expanded_mask 70 | 71 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 72 | 73 | 74 | class LlamaRMSNorm(nn.Module): 75 | def __init__(self, hidden_size, eps=1e-6): 76 | """ 77 | LlamaRMSNorm is equivalent to T5LayerNorm 78 | """ 79 | super().__init__() 80 | self.weight = nn.Parameter(torch.ones(hidden_size)) 81 | self.variance_epsilon = eps 82 | 83 | def forward(self, hidden_states): 84 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 85 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 86 | 87 | # convert into half-precision if necessary 88 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 89 | hidden_states = hidden_states.to(self.weight.dtype) 90 | 91 | return self.weight * hidden_states 92 | 93 | 94 | class LlamaRotaryEmbedding(torch.nn.Module): 95 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 96 | super().__init__() 97 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 98 | self.register_buffer("inv_freq", inv_freq) 99 | 100 | # Build here to make `torch.jit.trace` work. 
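        # The cached cos/sin tables built below have shape [1, 1, max_position_embeddings, dim];
        # forward() only rebuilds them when a sequence longer than max_seq_len_cached is requested.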
101 | self.max_seq_len_cached = max_position_embeddings 102 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 103 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 104 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 105 | emb = torch.cat((freqs, freqs), dim=-1) 106 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 107 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 108 | 109 | def forward(self, x, seq_len=None): 110 | # x: [bs, num_attention_heads, seq_len, head_size] 111 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 112 | if seq_len > self.max_seq_len_cached: 113 | self.max_seq_len_cached = seq_len 114 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 115 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 116 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 117 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 118 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 119 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 120 | return ( 121 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 122 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 123 | ) 124 | 125 | 126 | def rotate_half(x): 127 | """Rotates half the hidden dims of the input.""" 128 | x1 = x[..., : x.shape[-1] // 2] 129 | x2 = x[..., x.shape[-1] // 2 :] 130 | return torch.cat((-x2, x1), dim=-1) 131 | 132 | 133 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 134 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 135 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 136 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 137 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 138 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 139 | q_embed = (q * cos) + (rotate_half(q) * sin) 140 | k_embed = (k * cos) + (rotate_half(k) * sin) 141 | return q_embed, k_embed 142 | 143 | 144 | class LlamaMLP(nn.Module): 145 | def __init__( 146 | self, 147 | hidden_size: int, 148 | intermediate_size: int, 149 | hidden_act: str, 150 | ): 151 | super().__init__() 152 | self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) 153 | self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) 154 | self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) 155 | self.act_fn = ACT2FN[hidden_act] 156 | 157 | def forward(self, x): 158 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 159 | 160 | 161 | class LlamaAttention(nn.Module): 162 | """Multi-headed attention from 'Attention Is All You Need' paper""" 163 | 164 | def __init__(self, config: LlamaConfig): 165 | super().__init__() 166 | self.config = config 167 | self.hidden_size = config.hidden_size 168 | self.num_heads = config.num_attention_heads 169 | self.head_dim = self.hidden_size // self.num_heads 170 | self.max_position_embeddings = config.max_position_embeddings 171 | 172 | if (self.head_dim * self.num_heads) != self.hidden_size: 173 | raise ValueError( 174 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 175 | f" and `num_heads`: {self.num_heads})." 
176 | ) 177 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 178 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 179 | self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 180 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 181 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) 182 | 183 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 184 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 185 | 186 | def forward( 187 | self, 188 | hidden_states: torch.Tensor, 189 | attention_mask: Optional[torch.Tensor] = None, 190 | position_ids: Optional[torch.LongTensor] = None, 191 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 192 | output_attentions: bool = False, 193 | use_cache: bool = False, 194 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 195 | bsz, q_len, _ = hidden_states.size() 196 | 197 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 198 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 199 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 200 | 201 | kv_seq_len = key_states.shape[-2] 202 | if past_key_value is not None: 203 | kv_seq_len += past_key_value[0].shape[-2] 204 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 205 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 206 | # [bsz, nh, t, hd] 207 | 208 | if past_key_value is not None: 209 | # reuse k, v, self_attention 210 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 211 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 212 | 213 | past_key_value = (key_states, value_states) if use_cache else None 214 | 215 | if attention_mask is not None: 216 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 217 | raise ValueError( 218 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 219 | ) 220 | 221 | # WARNING: padding mask is ignored, causal is always applied 222 | attn_output = torch.nn.functional.scaled_dot_product_attention( 223 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True, 224 | ) 225 | 226 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 227 | raise ValueError( 228 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 229 | f" {attn_output.size()}" 230 | ) 231 | 232 | attn_output = attn_output.transpose(1, 2) 233 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 234 | 235 | attn_output = self.o_proj(attn_output) 236 | 237 | if not output_attentions: 238 | attn_weights = None 239 | 240 | return attn_output, attn_weights, past_key_value 241 | 242 | 243 | class LlamaDecoderLayer(nn.Module): 244 | def __init__(self, config: LlamaConfig): 245 | super().__init__() 246 | self.hidden_size = config.hidden_size 247 | self.self_attn = LlamaAttention(config=config) 248 | self.mlp = LlamaMLP( 249 | hidden_size=self.hidden_size, 250 | intermediate_size=config.intermediate_size, 251 | hidden_act=config.hidden_act, 252 | ) 253 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 254 | 
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 255 | 256 | def forward( 257 | self, 258 | hidden_states: torch.Tensor, 259 | attention_mask: Optional[torch.Tensor] = None, 260 | position_ids: Optional[torch.LongTensor] = None, 261 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 262 | output_attentions: Optional[bool] = False, 263 | use_cache: Optional[bool] = False, 264 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 265 | """ 266 | Args: 267 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 268 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 269 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 270 | output_attentions (`bool`, *optional*): 271 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 272 | returned tensors for more detail. 273 | use_cache (`bool`, *optional*): 274 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 275 | (see `past_key_values`). 276 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 277 | """ 278 | 279 | residual = hidden_states 280 | 281 | hidden_states = self.input_layernorm(hidden_states) 282 | 283 | # Self Attention 284 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 285 | hidden_states=hidden_states, 286 | attention_mask=attention_mask, 287 | position_ids=position_ids, 288 | past_key_value=past_key_value, 289 | output_attentions=output_attentions, 290 | use_cache=use_cache, 291 | ) 292 | hidden_states = residual + hidden_states 293 | 294 | # Fully Connected 295 | residual = hidden_states 296 | hidden_states = self.post_attention_layernorm(hidden_states) 297 | hidden_states = self.mlp(hidden_states) 298 | hidden_states = residual + hidden_states 299 | 300 | outputs = (hidden_states,) 301 | 302 | if output_attentions: 303 | outputs += (self_attn_weights,) 304 | 305 | if use_cache: 306 | outputs += (present_key_value,) 307 | 308 | return outputs 309 | 310 | 311 | LLAMA_START_DOCSTRING = r""" 312 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 313 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 314 | etc.) 315 | 316 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 317 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 318 | and behavior. 319 | 320 | Parameters: 321 | config ([`LlamaConfig`]): 322 | Model configuration class with all the parameters of the model. Initializing with a config file does not 323 | load the weights associated with the model, only the configuration. Check out the 324 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
325 | """ 326 | 327 | 328 | @add_start_docstrings( 329 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 330 | LLAMA_START_DOCSTRING, 331 | ) 332 | class LlamaPreTrainedModel(PreTrainedModel): 333 | config_class = LlamaConfig 334 | base_model_prefix = "model" 335 | supports_gradient_checkpointing = True 336 | _no_split_modules = ["LlamaDecoderLayer"] 337 | _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] 338 | 339 | def _init_weights(self, module): 340 | std = self.config.initializer_range 341 | if isinstance(module, nn.Linear): 342 | module.weight.data.normal_(mean=0.0, std=std) 343 | if module.bias is not None: 344 | module.bias.data.zero_() 345 | elif isinstance(module, nn.Embedding): 346 | module.weight.data.normal_(mean=0.0, std=std) 347 | if module.padding_idx is not None: 348 | module.weight.data[module.padding_idx].zero_() 349 | 350 | def _set_gradient_checkpointing(self, module, value=False): 351 | if isinstance(module, LlamaModel): 352 | module.gradient_checkpointing = value 353 | 354 | 355 | LLAMA_INPUTS_DOCSTRING = r""" 356 | Args: 357 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 358 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 359 | it. 360 | 361 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 362 | [`PreTrainedTokenizer.__call__`] for details. 363 | 364 | [What are input IDs?](../glossary#input-ids) 365 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 366 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 367 | 368 | - 1 for tokens that are **not masked**, 369 | - 0 for tokens that are **masked**. 370 | 371 | [What are attention masks?](../glossary#attention-mask) 372 | 373 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 374 | [`PreTrainedTokenizer.__call__`] for details. 375 | 376 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 377 | `past_key_values`). 378 | 379 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 380 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 381 | information on the default strategy. 382 | 383 | - 1 indicates the head is **not masked**, 384 | - 0 indicates the head is **masked**. 385 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 386 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 387 | config.n_positions - 1]`. 388 | 389 | [What are position IDs?](../glossary#position-ids) 390 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 391 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 392 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 393 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 394 | 395 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 396 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
397 | 398 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 399 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 400 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 401 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 402 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 403 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 404 | model's internal embedding lookup matrix. 405 | use_cache (`bool`, *optional*): 406 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 407 | `past_key_values`). 408 | output_attentions (`bool`, *optional*): 409 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 410 | tensors for more detail. 411 | output_hidden_states (`bool`, *optional*): 412 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 413 | more detail. 414 | return_dict (`bool`, *optional*): 415 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 416 | """ 417 | 418 | 419 | @add_start_docstrings( 420 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 421 | LLAMA_START_DOCSTRING, 422 | ) 423 | class LlamaModel(LlamaPreTrainedModel): 424 | """ 425 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 426 | 427 | Args: 428 | config: LlamaConfig 429 | """ 430 | 431 | def __init__(self, config: LlamaConfig): 432 | super().__init__(config) 433 | self.padding_idx = config.pad_token_id 434 | self.vocab_size = config.vocab_size 435 | 436 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 437 | self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 438 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 439 | 440 | self.gradient_checkpointing = False 441 | # Initialize weights and apply final processing 442 | self.post_init() 443 | 444 | def get_input_embeddings(self): 445 | return self.embed_tokens 446 | 447 | def set_input_embeddings(self, value): 448 | self.embed_tokens = value 449 | 450 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 451 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 452 | # create causal mask 453 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 454 | combined_attention_mask = None 455 | if input_shape[-1] > 1: 456 | combined_attention_mask = _make_causal_mask( 457 | input_shape, 458 | inputs_embeds.dtype, 459 | device=inputs_embeds.device, 460 | past_key_values_length=past_key_values_length, 461 | ) 462 | 463 | if attention_mask is not None: 464 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 465 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 466 | inputs_embeds.device 467 | ) 468 | combined_attention_mask = ( 469 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 470 | ) 471 | 472 | return combined_attention_mask 473 | 474 | 
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 475 | def forward( 476 | self, 477 | input_ids: torch.LongTensor = None, 478 | attention_mask: Optional[torch.Tensor] = None, 479 | position_ids: Optional[torch.LongTensor] = None, 480 | past_key_values: Optional[List[torch.FloatTensor]] = None, 481 | inputs_embeds: Optional[torch.FloatTensor] = None, 482 | use_cache: Optional[bool] = None, 483 | output_attentions: Optional[bool] = None, 484 | output_hidden_states: Optional[bool] = None, 485 | return_dict: Optional[bool] = None, 486 | ) -> Union[Tuple, BaseModelOutputWithPast]: 487 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 488 | output_hidden_states = ( 489 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 490 | ) 491 | use_cache = use_cache if use_cache is not None else self.config.use_cache 492 | 493 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 494 | 495 | # retrieve input_ids and inputs_embeds 496 | if input_ids is not None and inputs_embeds is not None: 497 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 498 | elif input_ids is not None: 499 | batch_size, seq_length = input_ids.shape 500 | elif inputs_embeds is not None: 501 | batch_size, seq_length, _ = inputs_embeds.shape 502 | else: 503 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 504 | 505 | seq_length_with_past = seq_length 506 | past_key_values_length = 0 507 | 508 | if past_key_values is not None: 509 | past_key_values_length = past_key_values[0][0].shape[2] 510 | seq_length_with_past = seq_length_with_past + past_key_values_length 511 | 512 | if position_ids is None: 513 | device = input_ids.device if input_ids is not None else inputs_embeds.device 514 | position_ids = torch.arange( 515 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 516 | ) 517 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 518 | else: 519 | position_ids = position_ids.view(-1, seq_length).long() 520 | 521 | if inputs_embeds is None: 522 | inputs_embeds = self.embed_tokens(input_ids) 523 | # embed positions 524 | if attention_mask is None: 525 | attention_mask = torch.ones( 526 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 527 | ) 528 | attention_mask = self._prepare_decoder_attention_mask( 529 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 530 | ) 531 | 532 | hidden_states = inputs_embeds 533 | 534 | if self.gradient_checkpointing and self.training: 535 | if use_cache: 536 | logger.warning_once( 537 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
538 | ) 539 | use_cache = False 540 | 541 | # decoder layers 542 | all_hidden_states = () if output_hidden_states else None 543 | all_self_attns = () if output_attentions else None 544 | next_decoder_cache = () if use_cache else None 545 | 546 | for idx, decoder_layer in enumerate(self.layers): 547 | if output_hidden_states: 548 | all_hidden_states += (hidden_states,) 549 | 550 | past_key_value = past_key_values[idx] if past_key_values is not None else None 551 | 552 | if self.gradient_checkpointing and self.training: 553 | 554 | def create_custom_forward(module): 555 | def custom_forward(*inputs): 556 | # None for past_key_value 557 | return module(*inputs, output_attentions, None) 558 | 559 | return custom_forward 560 | 561 | layer_outputs = torch.utils.checkpoint.checkpoint( 562 | create_custom_forward(decoder_layer), 563 | hidden_states, 564 | attention_mask, 565 | position_ids, 566 | None, 567 | ) 568 | else: 569 | layer_outputs = decoder_layer( 570 | hidden_states, 571 | attention_mask=attention_mask, 572 | position_ids=position_ids, 573 | past_key_value=past_key_value, 574 | output_attentions=output_attentions, 575 | use_cache=use_cache, 576 | ) 577 | 578 | hidden_states = layer_outputs[0] 579 | 580 | if use_cache: 581 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 582 | 583 | if output_attentions: 584 | all_self_attns += (layer_outputs[1],) 585 | 586 | hidden_states = self.norm(hidden_states) 587 | 588 | # add hidden states from the last decoder layer 589 | if output_hidden_states: 590 | all_hidden_states += (hidden_states,) 591 | 592 | next_cache = next_decoder_cache if use_cache else None 593 | if not return_dict: 594 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 595 | return BaseModelOutputWithPast( 596 | last_hidden_state=hidden_states, 597 | past_key_values=next_cache, 598 | hidden_states=all_hidden_states, 599 | attentions=all_self_attns, 600 | ) 601 | 602 | 603 | class LlamaForCausalLM(LlamaPreTrainedModel): 604 | def __init__(self, config): 605 | super().__init__(config) 606 | self.model = LlamaModel(config) 607 | 608 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 609 | 610 | # Initialize weights and apply final processing 611 | self.post_init() 612 | 613 | def get_input_embeddings(self): 614 | return self.model.embed_tokens 615 | 616 | def set_input_embeddings(self, value): 617 | self.model.embed_tokens = value 618 | 619 | def get_output_embeddings(self): 620 | return self.lm_head 621 | 622 | def set_output_embeddings(self, new_embeddings): 623 | self.lm_head = new_embeddings 624 | 625 | def set_decoder(self, decoder): 626 | self.model = decoder 627 | 628 | def get_decoder(self): 629 | return self.model 630 | 631 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 632 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 633 | def forward( 634 | self, 635 | input_ids: torch.LongTensor = None, 636 | attention_mask: Optional[torch.Tensor] = None, 637 | position_ids: Optional[torch.LongTensor] = None, 638 | past_key_values: Optional[List[torch.FloatTensor]] = None, 639 | inputs_embeds: Optional[torch.FloatTensor] = None, 640 | labels: Optional[torch.LongTensor] = None, 641 | use_cache: Optional[bool] = None, 642 | output_attentions: Optional[bool] = None, 643 | output_hidden_states: Optional[bool] = None, 644 | return_dict: Optional[bool] = None, 645 | ) -> Union[Tuple, CausalLMOutputWithPast]: 646 | r""" 
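        Run the LLaMA decoder, apply the language-modeling head, and, if `labels` are given,
        compute the shifted next-token cross-entropy loss.
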
647 | Args: 648 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 649 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 650 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 651 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 652 | 653 | Returns: 654 | 655 | Example: 656 | 657 | ```python 658 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 659 | 660 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 661 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 662 | 663 | >>> prompt = "Hey, are you consciours? Can you talk to me?" 664 | >>> inputs = tokenizer(prompt, return_tensors="pt") 665 | 666 | >>> # Generate 667 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 668 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 669 | "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 670 | ```""" 671 | 672 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 673 | output_hidden_states = ( 674 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 675 | ) 676 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 677 | 678 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 679 | outputs = self.model( 680 | input_ids=input_ids, 681 | attention_mask=attention_mask, 682 | position_ids=position_ids, 683 | past_key_values=past_key_values, 684 | inputs_embeds=inputs_embeds, 685 | use_cache=use_cache, 686 | output_attentions=output_attentions, 687 | output_hidden_states=output_hidden_states, 688 | return_dict=return_dict, 689 | ) 690 | 691 | hidden_states = outputs[0] 692 | logits = self.lm_head(hidden_states) 693 | 694 | loss = None 695 | if labels is not None: 696 | # NOTE: big optimization could be done here (?) 
697 | # maybe the copy operation that you saw in the debugger was happening here 698 | 699 | # Shift so that tokens < n predict n 700 | shift_logits = logits[..., :-1, :].contiguous() 701 | shift_labels = labels[..., 1:].contiguous() 702 | # Flatten the tokens 703 | loss_fct = CrossEntropyLoss() 704 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 705 | shift_labels = shift_labels.view(-1) 706 | # Enable model parallelism 707 | shift_labels = shift_labels.to(shift_logits.device) 708 | loss = loss_fct(shift_logits, shift_labels) 709 | 710 | if not return_dict: 711 | output = (logits,) + outputs[1:] 712 | return (loss,) + output if loss is not None else output 713 | 714 | return CausalLMOutputWithPast( 715 | loss=loss, 716 | logits=logits, 717 | past_key_values=outputs.past_key_values, 718 | hidden_states=outputs.hidden_states, 719 | attentions=outputs.attentions, 720 | ) 721 | 722 | def prepare_inputs_for_generation( 723 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 724 | ): 725 | if past_key_values: 726 | input_ids = input_ids[:, -1:] 727 | 728 | position_ids = kwargs.get("position_ids", None) 729 | if attention_mask is not None and position_ids is None: 730 | # create position_ids on the fly for batch generation 731 | position_ids = attention_mask.long().cumsum(-1) - 1 732 | position_ids.masked_fill_(attention_mask == 0, 1) 733 | if past_key_values: 734 | position_ids = position_ids[:, -1].unsqueeze(-1) 735 | 736 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 737 | if inputs_embeds is not None and past_key_values is None: 738 | model_inputs = {"inputs_embeds": inputs_embeds} 739 | else: 740 | model_inputs = {"input_ids": input_ids} 741 | 742 | model_inputs.update( 743 | { 744 | "position_ids": position_ids, 745 | "past_key_values": past_key_values, 746 | "use_cache": kwargs.get("use_cache"), 747 | "attention_mask": attention_mask, 748 | } 749 | ) 750 | return model_inputs 751 | 752 | @staticmethod 753 | def _reorder_cache(past_key_values, beam_idx): 754 | reordered_past = () 755 | for layer_past in past_key_values: 756 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 757 | return reordered_past 758 | 759 | 760 | @add_start_docstrings( 761 | """ 762 | The LLaMa Model transformer with a sequence classification head on top (linear layer). 763 | 764 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models 765 | (e.g. GPT-2) do. 766 | 767 | Since it does classification on the last token, it requires to know the position of the last token. If a 768 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 769 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 770 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 771 | each row of the batch). 
772 | """, 773 | LLAMA_START_DOCSTRING, 774 | ) 775 | class LlamaForSequenceClassification(LlamaPreTrainedModel): 776 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"] 777 | 778 | def __init__(self, config): 779 | super().__init__(config) 780 | self.num_labels = config.num_labels 781 | self.model = LlamaModel(config) 782 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 783 | 784 | # Initialize weights and apply final processing 785 | self.post_init() 786 | 787 | def get_input_embeddings(self): 788 | return self.model.embed_tokens 789 | 790 | def set_input_embeddings(self, value): 791 | self.model.embed_tokens = value 792 | 793 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 794 | def forward( 795 | self, 796 | input_ids: torch.LongTensor = None, 797 | attention_mask: Optional[torch.Tensor] = None, 798 | position_ids: Optional[torch.LongTensor] = None, 799 | past_key_values: Optional[List[torch.FloatTensor]] = None, 800 | inputs_embeds: Optional[torch.FloatTensor] = None, 801 | labels: Optional[torch.LongTensor] = None, 802 | use_cache: Optional[bool] = None, 803 | output_attentions: Optional[bool] = None, 804 | output_hidden_states: Optional[bool] = None, 805 | return_dict: Optional[bool] = None, 806 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 807 | r""" 808 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 809 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 810 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 811 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 812 | """ 813 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 814 | 815 | transformer_outputs = self.model( 816 | input_ids, 817 | attention_mask=attention_mask, 818 | position_ids=position_ids, 819 | past_key_values=past_key_values, 820 | inputs_embeds=inputs_embeds, 821 | use_cache=use_cache, 822 | output_attentions=output_attentions, 823 | output_hidden_states=output_hidden_states, 824 | return_dict=return_dict, 825 | ) 826 | hidden_states = transformer_outputs[0] 827 | logits = self.score(hidden_states) 828 | 829 | if input_ids is not None: 830 | batch_size = input_ids.shape[0] 831 | else: 832 | batch_size = inputs_embeds.shape[0] 833 | 834 | if self.config.pad_token_id is None and batch_size != 1: 835 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 836 | if self.config.pad_token_id is None: 837 | sequence_lengths = -1 838 | else: 839 | if input_ids is not None: 840 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) 841 | else: 842 | sequence_lengths = -1 843 | 844 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 845 | 846 | loss = None 847 | if labels is not None: 848 | labels = labels.to(logits.device) 849 | if self.config.problem_type is None: 850 | if self.num_labels == 1: 851 | self.config.problem_type = "regression" 852 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 853 | self.config.problem_type = "single_label_classification" 854 | else: 855 | self.config.problem_type = "multi_label_classification" 856 | 857 | if self.config.problem_type == "regression": 858 | loss_fct = MSELoss() 859 | if self.num_labels == 1: 860 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 861 | else: 862 | loss = 
loss_fct(pooled_logits, labels) 863 | elif self.config.problem_type == "single_label_classification": 864 | loss_fct = CrossEntropyLoss() 865 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 866 | elif self.config.problem_type == "multi_label_classification": 867 | loss_fct = BCEWithLogitsLoss() 868 | loss = loss_fct(pooled_logits, labels) 869 | if not return_dict: 870 | output = (pooled_logits,) + transformer_outputs[1:] 871 | return ((loss,) + output) if loss is not None else output 872 | 873 | return SequenceClassifierOutputWithPast( 874 | loss=loss, 875 | logits=pooled_logits, 876 | past_key_values=transformer_outputs.past_key_values, 877 | hidden_states=transformer_outputs.hidden_states, 878 | attentions=transformer_outputs.attentions, 879 | ) 880 | -------------------------------------------------------------------------------- /peft_pretraining/training_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | from torch.optim.lr_scheduler import LambdaLR 6 | import transformers 7 | 8 | 9 | def get_scheculer( 10 | optimizer, 11 | *, 12 | scheduler_type, 13 | num_training_steps, 14 | warmup_steps, 15 | min_lr_ratio, 16 | cycle_length=None, 17 | restart_warmup_steps=None, 18 | adjust_step=0, 19 | last_epoch=-1, 20 | ): 21 | if adjust_step != 0 and scheduler_type != "cosine_restarts": 22 | raise ValueError("adjust_step is only supported for cosine_restarts scheduler") 23 | 24 | if scheduler_type == "linear": 25 | return transformers.get_linear_schedule_with_warmup( 26 | optimizer, 27 | num_warmup_steps=warmup_steps, 28 | num_training_steps=num_training_steps, 29 | last_epoch=last_epoch, 30 | ) 31 | if scheduler_type == "cosine": 32 | return get_cyclical_cosine_schedule_with_min_lr( 33 | optimizer, 34 | num_warmup_steps=warmup_steps, 35 | num_training_steps=num_training_steps, 36 | cycle_length=cycle_length, 37 | min_lr_ratio=min_lr_ratio, 38 | last_epoch=last_epoch, 39 | ) 40 | if scheduler_type == "cosine_restarts": 41 | assert restart_warmup_steps is not None, "restart_warmup_steps must be specified for cosine_restarts scheduler" 42 | return get_cosine_schedule_with_multiple_warmups( 43 | optimizer, 44 | num_training_steps=num_training_steps, 45 | first_warmup_steps=warmup_steps, 46 | restart_warmup_steps=restart_warmup_steps, 47 | restart_every=cycle_length, 48 | min_lr_ratio=min_lr_ratio, 49 | last_epoch=last_epoch, 50 | adjust_step=adjust_step, 51 | ) 52 | 53 | raise NotImplementedError(f"Scheduler {scheduler_type} is not implemented") 54 | 55 | 56 | def get_cyclical_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, num_training_steps, cycle_length, min_lr_ratio=0.1, last_epoch=-1): 57 | assert cycle_length is not None or num_training_steps is not None, "You must specify either cycle_length or num_training_steps" 58 | 59 | if cycle_length is None: 60 | cycle_length = num_training_steps 61 | 62 | if num_training_steps % cycle_length != 0: 63 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by cycle_length ({cycle_length})") 64 | 65 | lr_lambda = partial( 66 | _get_cyclical_cosine_schedule_with_min_lr_lambda, 67 | num_warmup_steps=num_warmup_steps, 68 | cycle_length=cycle_length, 69 | min_lr_ratio=min_lr_ratio, 70 | ) 71 | return LambdaLR(optimizer, lr_lambda, last_epoch) 72 | 73 | 74 | def get_cosine_schedule_with_multiple_warmups( 75 | optimizer, 76 | *, 77 | num_training_steps, 78 | first_warmup_steps, 79 | 
restart_warmup_steps, 80 | restart_every, 81 | min_lr_ratio=0.1, 82 | adjust_step=0, 83 | last_epoch=-1, 84 | ): 85 | if restart_every is None: 86 | raise ValueError("restart_every must be specified for cosine_restarts scheduler") 87 | 88 | if num_training_steps % restart_every != 0: 89 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by restart_every ({restart_every})") 90 | 91 | lr_lambda = partial( 92 | _get_cosine_schedule_with_multiple_warmups_lambda, 93 | num_training_steps=num_training_steps, 94 | first_warmup_steps=first_warmup_steps, 95 | restart_warmup_steps=restart_warmup_steps, 96 | restart_every=restart_every, 97 | min_lr_ratio=min_lr_ratio, 98 | adjust_step=adjust_step, 99 | ) 100 | return LambdaLR(optimizer, lr_lambda, last_epoch) 101 | 102 | 103 | @torch.no_grad() 104 | def random_pruning(tensor, prune_ratio): 105 | """ 106 | Performs random pruning dimensionality reduction. 107 | Only reduces the inner dimensionality, does not affect the shape of the tensor 108 | """ 109 | random_pruning_mask = torch.rand_like(tensor) > prune_ratio 110 | tensor = tensor * random_pruning_mask 111 | return tensor 112 | 113 | 114 | @torch.no_grad() 115 | def magnitude_pruning(tensor, prune_ratio): 116 | """ 117 | Performs magnitude pruning dimensionality reduction. 118 | Only reduces the inner dimensionality, does not affect the shape of the tensor 119 | """ 120 | tensor_magnitude = torch.abs(tensor) 121 | threshold = torch.quantile(tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio).to(dtype=tensor.dtype) 122 | 123 | mask = tensor_magnitude > threshold 124 | tensor = tensor * mask.to(dtype=tensor.dtype) 125 | return tensor 126 | 127 | 128 | def _get_cyclical_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio): 129 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]" 130 | 131 | # compute where we are in the current cycle 132 | cycle_step = current_step % cycle_length 133 | 134 | if cycle_step < num_warmup_steps: 135 | if current_step != cycle_step: 136 | if cycle_step < 2: 137 | return 1e-7 138 | return float(cycle_step) / float(max(1, num_warmup_steps)) 139 | 140 | progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps)) 141 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) 142 | 143 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay 144 | 145 | 146 | def _get_cosine_schedule_with_multiple_warmups_lambda( 147 | current_step, 148 | *, 149 | num_training_steps, 150 | first_warmup_steps, 151 | restart_warmup_steps, 152 | restart_every, 153 | min_lr_ratio, 154 | adjust_step, 155 | ): 156 | """ 157 | Args: 158 | adjust_step: useful when continuing training from a warmed up checkpoint, 159 | it allows to sync the resets by reducing the number of steps 160 | after the first warmup and before the first reset. 161 | Thus, your ReLoRA resets can be synced with the optimizer resets. 
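
    Note:
        At each restart the multiplier ramps back up over `restart_warmup_steps` steps to
        approximately the value of the global cosine envelope at that point, so the schedule
        still decays overall from 1.0 towards `min_lr_ratio`.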
162 | """ 163 | assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]" 164 | assert restart_every > 0, "restart_every must be positive" 165 | assert adjust_step + first_warmup_steps < num_training_steps, "warmup + adjust_step is more than full training steps" 166 | assert adjust_step + first_warmup_steps < restart_every, "the first reset will happen before the warmup is done" 167 | 168 | if current_step < first_warmup_steps: 169 | return float(current_step) / float(max(1, first_warmup_steps)) 170 | 171 | _current_step = current_step + adjust_step 172 | 173 | restart_step = _current_step % restart_every 174 | restart_number = _current_step // restart_every 175 | 176 | if restart_step < restart_warmup_steps: 177 | # get expected lr multipler at the end of the warmup 178 | end_of_warmup_progress = ( 179 | float(restart_number * restart_every) / 180 | float(max(1, num_training_steps - first_warmup_steps)) 181 | ) 182 | 183 | _cosine_decay = 0.5 * (1.0 + math.cos(math.pi * end_of_warmup_progress)) 184 | warmup_lr_multiplier = min_lr_ratio + (1.0 - min_lr_ratio) * _cosine_decay 185 | 186 | return float(restart_step) / float(max(1, restart_warmup_steps)) * warmup_lr_multiplier 187 | 188 | progress = float(_current_step - first_warmup_steps) / float(max(1, num_training_steps - first_warmup_steps)) 189 | cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) 190 | 191 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay 192 | 193 | 194 | def collate_fn(batch_list): 195 | batch = { 196 | "input_ids": torch.stack([torch.Tensor(example["input_ids"]).long() for example in batch_list]), 197 | "attention_mask": torch.stack([torch.Tensor(example["attention_mask"]).long() for example in batch_list]), 198 | } 199 | return batch 200 | 201 | 202 | def batch_fn(dataset, batch_size): 203 | batch = [] 204 | for example in dataset: 205 | batch.append(example) 206 | if len(batch) == batch_size: 207 | batch = collate_fn(batch) 208 | yield batch 209 | batch = [] 210 | if len(batch) > 0: 211 | yield batch 212 | 213 | 214 | def max_train_tokens_to_number(max_train_tokens): 215 | if max_train_tokens.endswith("M"): 216 | return int(max_train_tokens.rstrip("M")) * 1_000_000 217 | elif max_train_tokens.endswith("B"): 218 | return int(max_train_tokens.rstrip("B")) * 1_000_000_000 219 | else: 220 | return int(max_train_tokens) 221 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | bitsandbytes 4 | tensorly 5 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
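# A minimal sketch of the scheduler helpers defined above in peft_pretraining/training_utils.py
# (the function keeps its original, misspelled name `get_scheculer`). The tiny optimizer and the
# step counts are illustrative values; the import assumes the repository root is on PYTHONPATH.
import torch

from peft_pretraining.training_utils import get_scheculer, max_train_tokens_to_number

assert max_train_tokens_to_number("1B") == 1_000_000_000  # "M"/"B" suffixes are supported

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(16, 16))], lr=1e-3)
scheduler = get_scheculer(
    optimizer,
    scheduler_type="cosine_restarts",
    num_training_steps=1_000,
    warmup_steps=100,
    min_lr_ratio=0.1,
    cycle_length=500,          # num_training_steps must be divisible by cycle_length
    restart_warmup_steps=50,   # required whenever scheduler_type == "cosine_restarts"
)
for _ in range(1_000):
    optimizer.step()
    scheduler.step()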
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE.""" 16 | import argparse 17 | import json 18 | import logging 19 | import math 20 | import os 21 | import random 22 | from pathlib import Path 23 | 24 | import datasets 25 | import evaluate 26 | import torch 27 | from accelerate import Accelerator 28 | from accelerate.logging import get_logger 29 | from accelerate.utils import set_seed 30 | from datasets import load_dataset 31 | from huggingface_hub import Repository, create_repo 32 | from torch.utils.data import DataLoader 33 | from tqdm.auto import tqdm 34 | 35 | import transformers 36 | from transformers import ( 37 | AutoConfig, 38 | AutoModelForSequenceClassification, 39 | AutoTokenizer, 40 | DataCollatorWithPadding, 41 | PretrainedConfig, 42 | SchedulerType, 43 | default_data_collator, 44 | get_scheduler, 45 | LlamaForSequenceClassification 46 | ) 47 | from transformers.utils import check_min_version, send_example_telemetry 48 | from transformers.utils.versions import require_version 49 | 50 | from galore_torch import GaLoreAdamW 51 | 52 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 53 | # check_min_version("4.38.0.dev0") 54 | 55 | logger = get_logger(__name__) 56 | 57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 58 | 59 | task_to_keys = { 60 | "cola": ("sentence", None), 61 | "mnli": ("premise", "hypothesis"), 62 | "mrpc": ("sentence1", "sentence2"), 63 | "qnli": ("question", "sentence"), 64 | "qqp": ("question1", "question2"), 65 | "rte": ("sentence1", "sentence2"), 66 | "sst2": ("sentence", None), 67 | "stsb": ("sentence1", "sentence2"), 68 | "wnli": ("sentence1", "sentence2"), 69 | } 70 | 71 | 72 | def parse_args(): 73 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") 74 | 75 | # LoRA hyperparameters 76 | parser.add_argument("--lora_r", type=int, default=8) 77 | parser.add_argument("--load_pretrained_model", type=str, default=None) 78 | 79 | parser.add_argument( 80 | "--task_name", 81 | type=str, 82 | default=None, 83 | help="The name of the glue task to train on.", 84 | choices=list(task_to_keys.keys()), 85 | ) 86 | parser.add_argument( 87 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 88 | ) 89 | parser.add_argument( 90 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 91 | ) 92 | parser.add_argument( 93 | "--max_length", 94 | type=int, 95 | default=128, 96 | help=( 97 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 98 | " sequences shorter will be padded if `--pad_to_max_length` is passed." 99 | ), 100 | ) 101 | parser.add_argument( 102 | "--pad_to_max_length", 103 | action="store_true", 104 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 105 | ) 106 | parser.add_argument( 107 | "--model_name_or_path", 108 | type=str, 109 | help="Path to pretrained model or model identifier from huggingface.co/models.", 110 | required=True, 111 | ) 112 | parser.add_argument( 113 | "--use_slow_tokenizer", 114 | action="store_true", 115 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 116 | ) 117 | parser.add_argument( 118 | "--per_device_train_batch_size", 119 | type=int, 120 | default=8, 121 | help="Batch size (per device) for the training dataloader.", 122 | ) 123 | parser.add_argument( 124 | "--per_device_eval_batch_size", 125 | type=int, 126 | default=8, 127 | help="Batch size (per device) for the evaluation dataloader.", 128 | ) 129 | parser.add_argument( 130 | "--learning_rate", 131 | type=float, 132 | default=5e-5, 133 | help="Initial learning rate (after the potential warmup period) to use.", 134 | ) 135 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 136 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 137 | parser.add_argument( 138 | "--max_train_steps", 139 | type=int, 140 | default=None, 141 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 142 | ) 143 | parser.add_argument( 144 | "--gradient_accumulation_steps", 145 | type=int, 146 | default=1, 147 | help="Number of updates steps to accumulate before performing a backward/update pass.", 148 | ) 149 | parser.add_argument( 150 | "--lr_scheduler_type", 151 | type=SchedulerType, 152 | default="linear", 153 | help="The scheduler type to use.", 154 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 155 | ) 156 | parser.add_argument( 157 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 158 | ) 159 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 160 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 161 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 162 | parser.add_argument( 163 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 164 | ) 165 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") 166 | parser.add_argument( 167 | "--trust_remote_code", 168 | type=bool, 169 | default=False, 170 | help=( 171 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" 172 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 173 | "execute code present on the Hub on your local machine." 
174 | ), 175 | ) 176 | parser.add_argument( 177 | "--checkpointing_steps", 178 | type=str, 179 | default=None, 180 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 181 | ) 182 | parser.add_argument( 183 | "--resume_from_checkpoint", 184 | type=str, 185 | default=None, 186 | help="If the training should continue from a checkpoint folder.", 187 | ) 188 | parser.add_argument( 189 | "--with_tracking", 190 | action="store_true", 191 | help="Whether to enable experiment trackers for logging.", 192 | ) 193 | parser.add_argument( 194 | "--report_to", 195 | type=str, 196 | default="all", 197 | help=( 198 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 199 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. ' 200 | "Only applicable when `--with_tracking` is passed." 201 | ), 202 | ) 203 | parser.add_argument( 204 | "--ignore_mismatched_sizes", 205 | action="store_true", 206 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.", 207 | ) 208 | 209 | # support enable_galore 210 | parser.add_argument("--enable_galore", action="store_true", help="Whether or not to use low rank optimizer.") 211 | # update_proj_gap 212 | parser.add_argument("--update_proj_gap", type=int, default=50) 213 | # galore_scale 214 | parser.add_argument("--galore_scale", type=float, default=1.0) 215 | # proj_type 216 | parser.add_argument("--proj_type", type=str, default="std") 217 | # lora_all_modules 218 | parser.add_argument("--lora_all_modules", action="store_true", help="Whether or not to use lora for all modules.") 219 | # eval_llama 220 | parser.add_argument("--eval_llama", action="store_true", help="Whether or not to evaluate llama model.") 221 | # low_rank_method 222 | parser.add_argument("--low_rank_method", type=str, default=None, help="low rank method for wandb sweep") 223 | 224 | args = parser.parse_args() 225 | 226 | # Sanity checks 227 | if args.task_name is None and args.train_file is None and args.validation_file is None: 228 | raise ValueError("Need either a task name or a training/validation file.") 229 | else: 230 | if args.train_file is not None: 231 | extension = args.train_file.split(".")[-1] 232 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 233 | if args.validation_file is not None: 234 | extension = args.validation_file.split(".")[-1] 235 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 236 | 237 | if args.push_to_hub: 238 | assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 239 | 240 | return args 241 | 242 | 243 | def main(): 244 | args = parse_args() 245 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 246 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 247 | send_example_telemetry("run_glue_no_trainer", args) 248 | 249 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 
250 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 251 | # in the environment 252 | accelerator = ( 253 | Accelerator(log_with=args.report_to, project_dir=args.output_dir) if args.with_tracking else Accelerator() 254 | ) 255 | # Make one log on every process with the configuration for debugging. 256 | logging.basicConfig( 257 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 258 | datefmt="%m/%d/%Y %H:%M:%S", 259 | level=logging.INFO, 260 | ) 261 | logger.info(accelerator.state, main_process_only=False) 262 | if accelerator.is_local_main_process: 263 | datasets.utils.logging.set_verbosity_warning() 264 | transformers.utils.logging.set_verbosity_info() 265 | else: 266 | datasets.utils.logging.set_verbosity_error() 267 | transformers.utils.logging.set_verbosity_error() 268 | 269 | # If passed along, set the training seed now. 270 | if args.seed is not None: 271 | set_seed(args.seed) 272 | 273 | # Handle the repository creation 274 | if accelerator.is_main_process: 275 | if args.push_to_hub: 276 | # Retrieve of infer repo_name 277 | repo_name = args.hub_model_id 278 | if repo_name is None: 279 | repo_name = Path(args.output_dir).absolute().name 280 | # Create repo and retrieve repo_id 281 | repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id 282 | # Clone repo locally 283 | repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token) 284 | 285 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 286 | if "step_*" not in gitignore: 287 | gitignore.write("step_*\n") 288 | if "epoch_*" not in gitignore: 289 | gitignore.write("epoch_*\n") 290 | elif args.output_dir is not None: 291 | os.makedirs(args.output_dir, exist_ok=True) 292 | accelerator.wait_for_everyone() 293 | 294 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 295 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 296 | 297 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 298 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 299 | # label if at least two columns are provided. 300 | 301 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 302 | # single column. You can easily tweak this behavior (see below) 303 | 304 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 305 | # download the dataset. 306 | if args.task_name is not None: 307 | # Downloading and loading a dataset from the hub. 308 | raw_datasets = load_dataset("glue", args.task_name) 309 | else: 310 | # Loading the dataset from local csv or json file. 311 | data_files = {} 312 | if args.train_file is not None: 313 | data_files["train"] = args.train_file 314 | if args.validation_file is not None: 315 | data_files["validation"] = args.validation_file 316 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] 317 | raw_datasets = load_dataset(extension, data_files=data_files) 318 | # See more about loading any type of standard or custom dataset at 319 | # https://huggingface.co/docs/datasets/loading_datasets. 
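# As a concrete illustration (the file name and rows below are hypothetical), a --train_file
# such as train.json would hold one JSON object per line, e.g.
#   {"sentence1": "The cat sat on the mat.", "sentence2": "A cat is on a mat.", "label": 1}
#   {"sentence1": "The cat sat on the mat.", "sentence2": "Stock prices fell sharply.", "label": 0}
# A CSV --train_file carries the same columns as a header row followed by one example per line,
# and --validation_file follows the same format.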
320 | 321 | # Labels 322 | if args.task_name is not None: 323 | is_regression = args.task_name == "stsb" 324 | if not is_regression: 325 | label_list = raw_datasets["train"].features["label"].names 326 | num_labels = len(label_list) 327 | else: 328 | num_labels = 1 329 | else: 330 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 331 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 332 | if is_regression: 333 | num_labels = 1 334 | else: 335 | # A useful fast method: 336 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 337 | label_list = raw_datasets["train"].unique("label") 338 | label_list.sort() # Let's sort it for determinism 339 | num_labels = len(label_list) 340 | 341 | # Load pretrained model and tokenizer 342 | # 343 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 344 | # download model & vocab. 345 | 346 | if not args.eval_llama: 347 | config = AutoConfig.from_pretrained( 348 | args.model_name_or_path, 349 | num_labels=num_labels, 350 | finetuning_task=args.task_name, 351 | trust_remote_code=args.trust_remote_code, 352 | ) 353 | tokenizer = AutoTokenizer.from_pretrained( 354 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code 355 | ) 356 | model = AutoModelForSequenceClassification.from_pretrained( 357 | args.model_name_or_path, 358 | from_tf=bool(".ckpt" in args.model_name_or_path), 359 | config=config, 360 | ignore_mismatched_sizes=args.ignore_mismatched_sizes, 361 | trust_remote_code=args.trust_remote_code, 362 | ) 363 | else: 364 | config = AutoConfig.from_pretrained(args.model_name_or_path) 365 | setattr(config, 'num_labels', num_labels) 366 | setattr(config, 'finetuning_task', args.task_name) 367 | tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length) 368 | tokenizer.padding_side = "left" 369 | model = LlamaForSequenceClassification( 370 | config 371 | ) 372 | ## load pretrained model 373 | if args.load_pretrained_model: 374 | logger.info("*" * 40) 375 | logger.info(f"Loading model from {args.load_pretrained_model}") 376 | checkpoint_path = os.path.join(args.load_pretrained_model, "pytorch_model.bin") 377 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 378 | for key in checkpoint.keys(): 379 | if key not in model.state_dict().keys(): 380 | print(f"key {key} not in model state dict") 381 | 382 | for key in model.state_dict().keys(): 383 | if key not in checkpoint.keys(): 384 | print(f"key {key} not in checkpoint") 385 | model.load_state_dict(checkpoint, strict=False) 386 | logger.info(f"Model successfully loaded (strict=False policy)") 387 | logger.info("*" * 40) 388 | 389 | # project modules 390 | if not args.lora_all_modules: 391 | target_modules_list = ["q_proj", "v_proj"] 392 | else: 393 | print('Enabling LoRA for all modules') 394 | target_modules_list = ["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"] 395 | 396 | # other modules for bert-family modules 397 | if 'bert' in args.model_name_or_path: 398 | if not args.lora_all_modules: 399 | target_modules_list = ["query"] 400 | else: 401 | print('Enabling LoRA for all modules') 402 | target_modules_list = ["query", "value", "key", "intermediate.dense", "output.dense"] 403 | 404 | # Preprocessing the datasets 405 | if args.task_name is not None: 406 | sentence1_key, sentence2_key = task_to_keys[args.task_name] 407 
| else: 408 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 409 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 410 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 411 | sentence1_key, sentence2_key = "sentence1", "sentence2" 412 | else: 413 | if len(non_label_column_names) >= 2: 414 | sentence1_key, sentence2_key = non_label_column_names[:2] 415 | else: 416 | sentence1_key, sentence2_key = non_label_column_names[0], None 417 | 418 | # Some models have set the order of the labels to use, so let's make sure we do use it. 419 | label_to_id = None 420 | if ( 421 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 422 | and args.task_name is not None 423 | and not is_regression 424 | ): 425 | # Some have all caps in their config, some don't. 426 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 427 | if sorted(label_name_to_id.keys()) == sorted(label_list): 428 | logger.info( 429 | f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " 430 | "Using it!" 431 | ) 432 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} 433 | else: 434 | logger.warning( 435 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 436 | f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}." 437 | "\nIgnoring the model labels as a result.", 438 | ) 439 | elif args.task_name is None and not is_regression: 440 | label_to_id = {v: i for i, v in enumerate(label_list)} 441 | 442 | if label_to_id is not None: 443 | model.config.label2id = label_to_id 444 | model.config.id2label = {id: label for label, id in config.label2id.items()} 445 | elif args.task_name is not None and not is_regression: 446 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 447 | model.config.id2label = {id: label for label, id in config.label2id.items()} 448 | 449 | padding = "max_length" if args.pad_to_max_length else False 450 | 451 | def preprocess_function(examples): 452 | # Tokenize the texts 453 | texts = ( 454 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 455 | ) 456 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) 457 | 458 | if "label" in examples: 459 | if label_to_id is not None: 460 | # Map labels to IDs (not necessary for GLUE tasks) 461 | result["labels"] = [label_to_id[l] for l in examples["label"]] 462 | else: 463 | # In all cases, rename the column to labels because the model will expect that. 
464 | result["labels"] = examples["label"] 465 | return result 466 | 467 | with accelerator.main_process_first(): 468 | processed_datasets = raw_datasets.map( 469 | preprocess_function, 470 | batched=True, 471 | remove_columns=raw_datasets["train"].column_names, 472 | desc="Running tokenizer on dataset", 473 | ) 474 | 475 | train_dataset = processed_datasets["train"] 476 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] 477 | 478 | # Log a few random samples from the training set: 479 | for index in random.sample(range(len(train_dataset)), 3): 480 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 481 | 482 | # DataLoaders creation: 483 | if args.pad_to_max_length: 484 | # If padding was already done ot max length, we use the default data collator that will just convert everything 485 | # to tensors. 486 | data_collator = default_data_collator 487 | else: 488 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of 489 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple 490 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). 491 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) 492 | 493 | train_dataloader = DataLoader( 494 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 495 | ) 496 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 497 | 498 | # Optimizer 499 | # Split weights in two groups, one with weight decay and the other not. 500 | no_decay = ["bias", "LayerNorm.weight"] 501 | optimizer_grouped_parameters = [ 502 | { 503 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 504 | "weight_decay": args.weight_decay, 505 | }, 506 | { 507 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 508 | "weight_decay": 0.0, 509 | }, 510 | ] 511 | 512 | if not args.enable_galore: 513 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 514 | else: 515 | from torch import nn 516 | # label layers for galore optimizer 517 | # target_modules_list = ["attn", "mlp"] 518 | # target_modules_list = ["q_proj", "v_proj"] 519 | galore_params = [] 520 | for module_name, module in model.named_modules(): 521 | if not isinstance(module, nn.Linear): 522 | continue 523 | 524 | if not any(target_key in module_name for target_key in target_modules_list): 525 | continue 526 | 527 | print('enable GaLore for weights in module: ', module_name) 528 | galore_params.append(module.weight) 529 | 530 | id_galore_params = [id(p) for p in galore_params] 531 | # make parameters without "rank" to another group 532 | regular_params = [p for p in model.parameters() if id(p) not in id_galore_params] 533 | # then call galore_adamw 534 | param_groups = [{'params': regular_params}, 535 | {'params': galore_params, 'rank': args.lora_r, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}] 536 | optimizer = GaLoreAdamW(param_groups, lr=args.learning_rate) 537 | 538 | 539 | # Scheduler and math around the number of training steps. 
540 | overrode_max_train_steps = False 541 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 542 | if args.max_train_steps is None: 543 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 544 | overrode_max_train_steps = True 545 | 546 | lr_scheduler = get_scheduler( 547 | name=args.lr_scheduler_type, 548 | optimizer=optimizer, 549 | num_warmup_steps=args.num_warmup_steps, 550 | num_training_steps=args.max_train_steps, 551 | ) 552 | 553 | # Prepare everything with our `accelerator`. 554 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 555 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 556 | ) 557 | 558 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 559 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 560 | if overrode_max_train_steps: 561 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 562 | # Afterwards we recalculate our number of training epochs 563 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 564 | 565 | # Figure out how many steps we should save the Accelerator states 566 | checkpointing_steps = args.checkpointing_steps 567 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 568 | checkpointing_steps = int(checkpointing_steps) 569 | 570 | # We need to initialize the trackers we use, and also store our configuration. 571 | # The trackers initializes automatically on the main process. 572 | if args.with_tracking: 573 | experiment_config = vars(args) 574 | # TensorBoard cannot log Enums, need the raw value 575 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 576 | accelerator.init_trackers("glue_no_trainer", experiment_config) 577 | 578 | # Get the metric function 579 | if args.task_name is not None: 580 | metric = evaluate.load("glue", args.task_name) 581 | else: 582 | metric = evaluate.load("accuracy") 583 | 584 | # Train! 585 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 586 | 587 | logger.info("***** Running training *****") 588 | logger.info(f" Num examples = {len(train_dataset)}") 589 | logger.info(f" Num Epochs = {args.num_train_epochs}") 590 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 591 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 592 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 593 | logger.info(f" Total optimization steps = {args.max_train_steps}") 594 | # Only show the progress bar once on each machine. 
595 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 596 | completed_steps = 0 597 | starting_epoch = 0 598 | # Potentially load in the weights and states from a previous save 599 | if args.resume_from_checkpoint: 600 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 601 | checkpoint_path = args.resume_from_checkpoint 602 | path = os.path.basename(args.resume_from_checkpoint) 603 | else: 604 | # Get the most recent checkpoint 605 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 606 | dirs.sort(key=os.path.getctime) 607 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 608 | checkpoint_path = path 609 | path = os.path.basename(checkpoint_path) 610 | 611 | accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") 612 | accelerator.load_state(checkpoint_path) 613 | # Extract `epoch_{i}` or `step_{i}` 614 | training_difference = os.path.splitext(path)[0] 615 | 616 | if "epoch" in training_difference: 617 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 618 | resume_step = None 619 | completed_steps = starting_epoch * num_update_steps_per_epoch 620 | else: 621 | # need to multiply `gradient_accumulation_steps` to reflect real steps 622 | resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps 623 | starting_epoch = resume_step // len(train_dataloader) 624 | completed_steps = resume_step // args.gradient_accumulation_steps 625 | resume_step -= starting_epoch * len(train_dataloader) 626 | 627 | # update the progress_bar if load from checkpoint 628 | progress_bar.update(completed_steps) 629 | 630 | for epoch in range(starting_epoch, args.num_train_epochs): 631 | model.train() 632 | if args.with_tracking: 633 | total_loss = 0 634 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: 635 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint 636 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) 637 | else: 638 | active_dataloader = train_dataloader 639 | for step, batch in enumerate(active_dataloader): 640 | 641 | outputs = model(**batch) 642 | loss = outputs.loss 643 | # We keep track of the loss at each epoch 644 | if args.with_tracking: 645 | total_loss += loss.detach().float() 646 | loss = loss / args.gradient_accumulation_steps 647 | accelerator.backward(loss) 648 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 649 | optimizer.step() 650 | lr_scheduler.step() 651 | optimizer.zero_grad() 652 | progress_bar.update(1) 653 | completed_steps += 1 654 | 655 | if isinstance(checkpointing_steps, int): 656 | if completed_steps % checkpointing_steps == 0: 657 | output_dir = f"step_{completed_steps}" 658 | if args.output_dir is not None: 659 | output_dir = os.path.join(args.output_dir, output_dir) 660 | accelerator.save_state(output_dir) 661 | 662 | if completed_steps >= args.max_train_steps: 663 | break 664 | 665 | model.eval() 666 | samples_seen = 0 667 | for step, batch in enumerate(eval_dataloader): 668 | with torch.no_grad(): 669 | outputs = model(**batch) 670 | predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() 671 | predictions, references = accelerator.gather((predictions, batch["labels"])) 672 | # If we are in a multiprocess environment, the last batch has duplicates 673 | if accelerator.num_processes > 1: 674 | if 
step == len(eval_dataloader) - 1: 675 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] 676 | references = references[: len(eval_dataloader.dataset) - samples_seen] 677 | else: 678 | samples_seen += references.shape[0] 679 | metric.add_batch( 680 | predictions=predictions, 681 | references=references, 682 | ) 683 | 684 | eval_metric = metric.compute() 685 | logger.info(f"epoch {epoch}: {eval_metric}") 686 | 687 | if args.with_tracking: 688 | accelerator.log( 689 | { 690 | "accuracy" if args.task_name is not None else "glue": eval_metric, 691 | "train_loss": total_loss.item() / len(train_dataloader), 692 | "epoch": epoch, 693 | "step": completed_steps, 694 | }, 695 | step=completed_steps, 696 | ) 697 | 698 | if args.push_to_hub and epoch < args.num_train_epochs - 1: 699 | accelerator.wait_for_everyone() 700 | unwrapped_model = accelerator.unwrap_model(model) 701 | unwrapped_model.save_pretrained( 702 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 703 | ) 704 | if accelerator.is_main_process: 705 | tokenizer.save_pretrained(args.output_dir) 706 | repo.push_to_hub( 707 | commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True 708 | ) 709 | 710 | if args.checkpointing_steps == "epoch": 711 | output_dir = f"epoch_{epoch}" 712 | if args.output_dir is not None: 713 | output_dir = os.path.join(args.output_dir, output_dir) 714 | accelerator.save_state(output_dir) 715 | 716 | if args.with_tracking: 717 | accelerator.end_training() 718 | 719 | if args.output_dir is not None: 720 | accelerator.wait_for_everyone() 721 | unwrapped_model = accelerator.unwrap_model(model) 722 | unwrapped_model.save_pretrained( 723 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 724 | ) 725 | if accelerator.is_main_process: 726 | tokenizer.save_pretrained(args.output_dir) 727 | if args.push_to_hub: 728 | repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) 729 | 730 | if args.task_name == "mnli": 731 | # Final evaluation on mismatched validation set 732 | eval_dataset = processed_datasets["validation_mismatched"] 733 | eval_dataloader = DataLoader( 734 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 735 | ) 736 | eval_dataloader = accelerator.prepare(eval_dataloader) 737 | 738 | model.eval() 739 | for step, batch in enumerate(eval_dataloader): 740 | outputs = model(**batch) 741 | predictions = outputs.logits.argmax(dim=-1) 742 | metric.add_batch( 743 | predictions=accelerator.gather(predictions), 744 | references=accelerator.gather(batch["labels"]), 745 | ) 746 | 747 | eval_metric = metric.compute() 748 | logger.info(f"mnli-mm: {eval_metric}") 749 | 750 | if args.output_dir is not None: 751 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()} 752 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 753 | json.dump(all_results, f) 754 | 755 | 756 | if __name__ == "__main__": 757 | main() -------------------------------------------------------------------------------- /scripts/benchmark_c4/llama_130m.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-130M, GaLore-Adam, 1 A100, 1 Node 2 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \ 3 | --model_config configs/llama_130m.json \ 4 | --lr 0.01 \ 5 | --galore_scale 0.25 \ 6 | --rank 256 \ 7 | --update_proj_gap 200 \ 8 | --batch_size 256 \ 9 | --total_batch_size 512 \ 10 | 
--num_training_steps 20000 \ 11 | --warmup_steps 2000 \ 12 | --weight_decay 0 \ 13 | --dtype bfloat16 \ 14 | --eval_every 1000 \ 15 | --optimizer galore_adamw -------------------------------------------------------------------------------- /scripts/benchmark_c4/llama_1b.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-1B, GaLore-Adam, 8 A100, 1 Node 2 | torchrun --standalone --nproc_per_node 8 torchrun_main.py \ 3 | --model_config configs/llama_1b.json \ 4 | --lr 0.01 \ 5 | --galore_scale 0.25 \ 6 | --rank 1024 \ 7 | --update_proj_gap 200 \ 8 | --batch_size 16 \ 9 | --total_batch_size 512 \ 10 | --num_training_steps 100000 \ 11 | --warmup_steps 10000 \ 12 | --weight_decay 0 \ 13 | --dtype bfloat16 \ 14 | --eval_every 1000 \ 15 | --optimizer galore_adamw -------------------------------------------------------------------------------- /scripts/benchmark_c4/llama_350m.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-350M, GaLore-Adam, 4 A100, 1 Node 2 | torchrun --standalone --nproc_per_node 4 torchrun_main.py \ 3 | --model_config configs/llama_350m.json \ 4 | --lr 0.01 \ 5 | --galore_scale 0.25 \ 6 | --rank 256 \ 7 | --update_proj_gap 200 \ 8 | --batch_size 128 \ 9 | --total_batch_size 512 \ 10 | --num_training_steps 60000 \ 11 | --warmup_steps 6000 \ 12 | --weight_decay 0 \ 13 | --dtype bfloat16 \ 14 | --eval_every 1000 \ 15 | --optimizer galore_adamw -------------------------------------------------------------------------------- /scripts/benchmark_c4/llama_60m.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-60M, GaLore-Adam, 1 A100, 1 Node 2 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \ 3 | --model_config configs/llama_60m.json \ 4 | --lr 0.01 \ 5 | --galore_scale 0.25 \ 6 | --rank 128 \ 7 | --update_proj_gap 200 \ 8 | --batch_size 256 \ 9 | --total_batch_size 512 \ 10 | --num_training_steps 10000 \ 11 | --warmup_steps 1000 \ 12 | --weight_decay 0 \ 13 | --dtype bfloat16 \ 14 | --eval_every 1000 \ 15 | --optimizer galore_adamw -------------------------------------------------------------------------------- /scripts/benchmark_c4/llama_7b.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-7B, GaLore-Adam, 8 A100, 8 Node 2 | torchrun --standalone --nnodes 8 --nproc_per_node 8 torchrun_main.py \ 3 | --model_config configs/llama_7b.json \ 4 | --lr 0.005 \ 5 | --galore_scale 0.25 \ 6 | --rank 1024 \ 7 | --update_proj_gap 500 \ 8 | --batch_size 8 \ 9 | --total_batch_size 512 \ 10 | --num_training_steps 150000 \ 11 | --warmup_steps 15000 \ 12 | --weight_decay 0 \ 13 | --grad_clipping 1.0 \ 14 | --dtype bfloat16 \ 15 | --eval_every 1000 \ 16 | --optimizer galore_adamw -------------------------------------------------------------------------------- /scripts/single_gpu/llama_7b.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-7B, 8-bit GaLore-Adam, single GPU 2 | # 22.72G, 0.37s/it 3 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \ 4 | --model_config configs/llama_7b.json \ 5 | --lr 0.005 \ 6 | --galore_scale 0.25 \ 7 | --rank 1024 \ 8 | --update_proj_gap 500 \ 9 | --batch_size 1 \ 10 | --total_batch_size 512 \ 11 | --num_training_steps 150000 \ 12 | --warmup_steps 15000 \ 13 | --weight_decay 0 \ 14 | --grad_clipping 1.0 \ 15 | --dtype bfloat16 \ 16 | --eval_every 1000 \ 17 | --single_gpu \ 18 | --optimizer galore_adamw8bit_per_layer 
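For reference while reading the scripts above and the training code below: the GaLore flags (--rank, --update_proj_gap, --galore_scale, --proj_type) are consumed as optimizer parameter groups. The following is a minimal sketch of that mapping, mirroring the pattern used in run_glue.py and torchrun_main.py; the toy two-layer model and the exact hyperparameter values are stand-ins, not part of the repository.

import torch.nn as nn
from galore_torch import GaLoreAdamW

# Toy stand-in for the attention/MLP weight matrices that GaLore targets.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Collect the weights to be optimized in a low-rank subspace (the real scripts
# additionally filter module names, e.g. "attn"/"mlp" or "q_proj"/"v_proj").
galore_params = [m.weight for m in model.modules() if isinstance(m, nn.Linear)]
id_galore_params = {id(p) for p in galore_params}
regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]

param_groups = [
    {"params": regular_params},
    # rank / update_proj_gap / scale / proj_type correspond to
    # --rank / --update_proj_gap / --galore_scale / --proj_type above.
    {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
optimizer = GaLoreAdamW(param_groups, lr=0.01, weight_decay=0.0)

In the benchmark scripts, only the learning rate, rank, and update_proj_gap vary with model size (galore_scale stays at 0.25), and per-device batch sizes are chosen so that batch_size * num_gpus * gradient_accumulation equals --total_batch_size 512, as torchrun_main.py asserts below.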
-------------------------------------------------------------------------------- /scripts/single_gpu/llama_7b_checkpointing.sh: -------------------------------------------------------------------------------- 1 | # LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing 2 | # bsz=16, 22.8G, 3 | torchrun --standalone --nproc_per_node 1 torchrun_main.py \ 4 | --model_config configs/llama_7b.json \ 5 | --lr 0.005 \ 6 | --galore_scale 0.25 \ 7 | --rank 1024 \ 8 | --update_proj_gap 500 \ 9 | --batch_size 16 \ 10 | --total_batch_size 512 \ 11 | --activation_checkpointing \ 12 | --num_training_steps 150000 \ 13 | --warmup_steps 15000 \ 14 | --weight_decay 0 \ 15 | --grad_clipping 1.0 \ 16 | --dtype bfloat16 \ 17 | --eval_every 1000 \ 18 | --single_gpu \ 19 | --optimizer galore_adamw8bit_per_layer -------------------------------------------------------------------------------- /scripts/tensor_test/neural_operator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training a neural operator on Darcy-Flow - Author Robert Joseph 3 | ======================================== 4 | In this example, we demonstrate how to use the small Darcy-Flow example we ship with the package on Incremental FNO and Incremental Resolution as well as using Galore tensor decomposition. 5 | 6 | Assuming one installs the neuraloperator library: Instructions can be found here: https://github.com/NeuralOperator/neuraloperator 7 | """ 8 | 9 | # %% 10 | # 11 | import torch 12 | import matplotlib.pyplot as plt 13 | import sys 14 | from neuralop.training.callbacks import BasicLoggerCallback 15 | from neuralop.models import FNO 16 | from neuralop import Trainer 17 | from neuralop.datasets import load_darcy_flow_small 18 | from neuralop.utils import count_model_params 19 | from neuralop.training.callbacks import IncrementalCallback 20 | from neuralop.datasets import data_transforms 21 | from neuralop import LpLoss, H1Loss 22 | from neuralop.training import AdamW 23 | from neuralop.utils import count_model_params 24 | 25 | 26 | # %% 27 | # Loading the Darcy flow dataset 28 | train_loader, test_loaders, data_processor = load_darcy_flow_small( 29 | n_train=1000, batch_size=32, 30 | test_resolutions=[16, 32], n_tests=[100, 50], 31 | test_batch_sizes=[32, 32], 32 | positional_encoding=True 33 | ) 34 | # %% 35 | # Choose device 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | # %% 39 | # Set up the incremental FNO model 40 | # We start with 2 modes in each dimension 41 | # We choose to update the modes by the incremental gradient explained algorithm 42 | 43 | starting_modes = (10, 10) 44 | incremental = False 45 | 46 | model = FNO( 47 | max_n_modes=(20, 20), 48 | n_modes=starting_modes, 49 | hidden_channels=64, 50 | in_channels=1, 51 | out_channels=1, 52 | n_layers=4 53 | ) 54 | callbacks = [ 55 | IncrementalCallback( 56 | incremental_loss_gap=True, 57 | incremental_grad=False, 58 | incremental_grad_eps=0.9999, 59 | incremental_buffer=5, 60 | incremental_max_iter=1, 61 | incremental_grad_max_iter=2, 62 | ) 63 | ] 64 | model = model.to(device) 65 | n_params = count_model_params(model) 66 | galore_params = [] 67 | galore_params.extend(list(model.fno_blocks.convs.parameters())) 68 | print(galore_params[0].shape, galore_params[1].shape, galore_params[2].shape, galore_params[3].shape) 69 | galore_params.pop(0) 70 | id_galore_params = [id(p) for p in galore_params] 71 | # make parameters without "rank" to another group 72 | regular_params = [p for p in 
model.parameters() if id(p) not in id_galore_params] 73 | # then call galore_adamw 74 | # In this case we have a 5d tensor representing the weights in the spectral layers of the FNO 75 | # A good rule of thumb for tensor decomposition is to limit the rank to at most 0.75, and to increase the epochs and tune the lr accordingly compared to the baseline. 76 | # Low rank decomposition takes longer to converge, but it is more memory efficient. 77 | param_groups = [{'params': regular_params}, 78 | {'params': galore_params, 'rank': 0.2, 'update_proj_gap': 10, 'scale': 0.25, 'proj_type': "std", 'dim': 5}] 79 | optimizer = AdamW(param_groups, lr=0.01) 80 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30) 81 | data_transform = data_transforms.IncrementalDataProcessor( 82 | in_normalizer=None, 83 | out_normalizer=None, 84 | positional_encoding=None, 85 | device=device, 86 | dataset_sublist=[2, 1], 87 | dataset_resolution=16, 88 | dataset_indices=[2, 3], 89 | epoch_gap=10, 90 | verbose=True, 91 | ) 92 | 93 | data_transform = data_transform.to(device) 94 | # %% 95 | # Set up the losses 96 | l2loss = LpLoss(d=2, p=2) 97 | h1loss = H1Loss(d=2) 98 | train_loss = h1loss 99 | eval_losses = {"h1": h1loss, "l2": l2loss} 100 | print("\n### OPTIMIZER rank ###\n", optimizer) 101 | sys.stdout.flush() 102 | 103 | # Finally pass all of these to the Trainer 104 | trainer = Trainer( 105 | model=model, 106 | n_epochs=100, 107 | data_processor=data_transform, 108 | callbacks=callbacks, 109 | device=device, 110 | verbose=True, 111 | ) 112 | 113 | # %% 114 | # Train the model 115 | trainer.train( 116 | train_loader, 117 | test_loaders, 118 | optimizer, 119 | scheduler, 120 | regularizer=False, 121 | training_loss=train_loss, 122 | eval_losses=eval_losses, 123 | ) 124 | 125 | # %% 126 | # Plot the prediction, and compare with the ground-truth 127 | # Note that we trained on a very small resolution for 128 | # a very small number of epochs 129 | # In practice, we would train at larger resolution, on many more samples.
130 | # 131 | # However, for practicity, we created a minimal example that 132 | # i) fits in just a few Mb of memory 133 | # ii) can be trained quickly on CPU 134 | # 135 | # In practice we would train a Neural Operator on one or multiple GPUs 136 | 137 | test_samples = test_loaders[32].dataset 138 | 139 | fig = plt.figure(figsize=(7, 7)) 140 | for index in range(3): 141 | data = test_samples[index] 142 | # Input x 143 | x = data["x"].to(device) 144 | # Ground-truth 145 | y = data["y"].to(device) 146 | # Model prediction 147 | out = model(x.unsqueeze(0)) 148 | ax = fig.add_subplot(3, 3, index * 3 + 1) 149 | x = x.cpu().squeeze().detach().numpy() 150 | y = y.cpu().squeeze().detach().numpy() 151 | ax.imshow(x, cmap="gray") 152 | if index == 0: 153 | ax.set_title("Input x") 154 | plt.xticks([], []) 155 | plt.yticks([], []) 156 | 157 | ax = fig.add_subplot(3, 3, index * 3 + 2) 158 | ax.imshow(y.squeeze()) 159 | if index == 0: 160 | ax.set_title("Ground-truth y") 161 | plt.xticks([], []) 162 | plt.yticks([], []) 163 | 164 | ax = fig.add_subplot(3, 3, index * 3 + 3) 165 | ax.imshow(out.cpu().squeeze().detach().numpy()) 166 | if index == 0: 167 | ax.set_title("Model prediction") 168 | plt.xticks([], []) 169 | plt.yticks([], []) 170 | 171 | fig.suptitle("Inputs, ground-truth output and prediction.", y=0.98) 172 | plt.tight_layout() 173 | fig.show() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("requirements.txt") as f: 4 | required = f.read().splitlines() 5 | 6 | setup( 7 | name="galore-torch", 8 | version="1.0", 9 | description="GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection", 10 | url="https://github.com/jiaweizzhao/GaLore", 11 | author="Jiawei Zhao", 12 | author_email="jiawei@caltech.edu", 13 | license="Apache 2.0", 14 | packages=["galore_torch"], 15 | install_requires=required, 16 | ) -------------------------------------------------------------------------------- /torchrun_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import random 5 | import argparse 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.data 11 | import torch.distributed as dist 12 | 13 | import transformers 14 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM 15 | from transformers import LlamaForCausalLM as HF_LlamaForCausalLM 16 | 17 | import datasets 18 | import datasets.distributed 19 | import wandb 20 | 21 | from tqdm import tqdm 22 | from loguru import logger 23 | 24 | from peft_pretraining import training_utils, args_utils 25 | from peft_pretraining.dataloader import PreprocessedIterableDataset 26 | from peft_pretraining.modeling_llama import LlamaForCausalLM 27 | 28 | import bitsandbytes as bnb 29 | from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor 30 | 31 | transformers.logging.set_verbosity_error() 32 | 33 | def parse_args(args): 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument("--model_config", type=str, required=True) 37 | parser.add_argument("--use_hf_model", default=False, action="store_true") 38 | parser.add_argument("--continue_from", type=str, default=None) 39 | parser.add_argument("--batch_size", type=int, required=True) 40 | parser.add_argument("--gradient_accumulation", type=int, default=None) 41 | 
parser.add_argument("--total_batch_size", type=int, default=None) 42 | parser.add_argument("--max_length", type=int, default=256) 43 | parser.add_argument("--optimizer", default="Adam") 44 | parser.add_argument("--lr", type=float, default=1e-4) 45 | parser.add_argument("--scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_restarts"]) 46 | parser.add_argument("--min_lr_ratio", type=float, default=0.1) 47 | parser.add_argument("--activation_checkpointing", action="store_true") 48 | parser.add_argument("--weight_decay", type=float, default=0.0) 49 | parser.add_argument("--warmup_steps", type=int, default=1_000) 50 | parser.add_argument("--eval_every", type=int, default=5_000) 51 | parser.add_argument("--num_training_steps", type=int, default=10_000, 52 | help="Number of **update steps** to train for. " 53 | "Notice that gradient accumulation is taken into account.") 54 | parser.add_argument("--max_train_tokens", type=training_utils.max_train_tokens_to_number, default=None, 55 | help="Number of tokens to train on. Overwrites num_training_steps. " 56 | "You can use M and B suffixes, e.g. 100M or 1B.") 57 | parser.add_argument("--save_every", type=int, default=10_000) 58 | parser.add_argument("--save_dir", type=str, default=None) 59 | parser.add_argument("--tags", type=str, default=None) 60 | parser.add_argument("--dtype", type=str, default="bfloat16" if torch.cuda.is_bf16_supported() else "float32") 61 | parser.add_argument("--workers", type=int, default=8) 62 | parser.add_argument("--seed", type=int, default=0) 63 | parser.add_argument("--name", type=str, default="test") 64 | parser.add_argument("--grad_clipping", type=float, default=0.0) 65 | # beta1 for adafactor 66 | parser.add_argument("--beta1", type=float, default=0.0) 67 | 68 | # GaLore parameters 69 | parser.add_argument("--rank", type=int, default=128) 70 | parser.add_argument("--update_proj_gap", type=int, default=50) 71 | parser.add_argument("--galore_scale", type=float, default=1.0) 72 | parser.add_argument("--proj_type", type=str, default="std") 73 | 74 | # disable ddp, single_gpu 75 | parser.add_argument("--single_gpu", default=False, action="store_true") 76 | 77 | args = parser.parse_args(args) 78 | 79 | args = args_utils.check_args_torchrun_main(args) 80 | return args 81 | 82 | 83 | @torch.no_grad() 84 | def evaluate_model(model, preprocess_batched, pad_idx, global_rank, world_size, device, batch_size): 85 | _time = time.time() 86 | val_data = datasets.load_dataset("c4", "en", split="validation", streaming=True) #DGX 87 | val_data = val_data.shuffle(seed=42) 88 | logger.info(f"Loaded validation dataset in {time.time() - _time:.2f} seconds") 89 | 90 | if not args.single_gpu: 91 | val_data = datasets.distributed.split_dataset_by_node(val_data, rank=global_rank, world_size=world_size) 92 | 93 | val_data_mapped = val_data.map( 94 | preprocess_batched, 95 | batched=True, 96 | remove_columns=["text", "timestamp", "url"], 97 | ) 98 | val_data_mapped.batch = lambda batch_size: training_utils.batch_fn(val_data_mapped, batch_size) 99 | 100 | target_eval_tokens = 10_000_000 101 | evaluated_on_tokens = 0 102 | total_loss = torch.tensor(0.0).to(device) 103 | total_batches = 1 104 | logger.info(f"Eval set prepared in {time.time() - _time:.2f} seconds") 105 | 106 | for batch in val_data_mapped.batch(batch_size=batch_size): 107 | if evaluated_on_tokens > target_eval_tokens: 108 | break 109 | total_batches += 1 110 | 111 | batch = {k: v.to(device) for k, v in batch.items()} 112 | labels = batch["input_ids"].clone() 113 
| labels[labels == pad_idx] = -100 114 | loss = model(**batch, labels=labels).loss 115 | total_loss += loss.detach() 116 | 117 | evaluated_on_tokens += (batch["input_ids"] != pad_idx).sum().item() * world_size 118 | 119 | total_loss = total_loss / total_batches 120 | 121 | # Gather losses across all GPUs 122 | gathered_losses = [torch.zeros_like(total_loss) for _ in range(world_size)] 123 | dist.all_gather(gathered_losses, total_loss) 124 | total_loss = sum([t.item() for t in gathered_losses]) / world_size 125 | 126 | return total_loss, evaluated_on_tokens 127 | 128 | 129 | def main(args): 130 | torch.manual_seed(args.seed) 131 | np.random.seed(args.seed) 132 | random.seed(args.seed) 133 | 134 | assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK" 135 | global_rank = int(os.environ['RANK']) 136 | local_rank = int(os.environ["LOCAL_RANK"]) 137 | world_size = int(os.environ["WORLD_SIZE"]) 138 | torch.cuda.set_device(local_rank) 139 | 140 | logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}") 141 | 142 | dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size) 143 | 144 | logger.info("Process group initialized") 145 | device = f"cuda:{local_rank}" 146 | 147 | if args.total_batch_size is not None: 148 | if args.gradient_accumulation is None: 149 | assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size" 150 | args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size) 151 | assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0" 152 | 153 | assert args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size, \ 154 | "gradient_accumulation * batch_size * world_size must be equal to total_batch_size" 155 | 156 | # turn off logger 157 | if global_rank != 0: logger.remove() 158 | 159 | # initialize wandb without config (it is passed later) 160 | if global_rank == 0: 161 | wandb.init(project="galore-c4") 162 | 163 | logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)") 164 | logger.info("*" * 40) 165 | logger.info(f"Starting training with the arguments") 166 | for k, v in vars(args).items(): 167 | logger.info(f"{k:30} {v}") 168 | logger.info("*" * 40) 169 | 170 | data = datasets.load_dataset("allenai/c4", "en", split="train", streaming=True) 171 | 172 | seed_for_shuffle = 42 173 | 174 | logger.info(f"Shuffling data with seed {seed_for_shuffle}") 175 | data: datasets.Dataset = data.shuffle(seed=seed_for_shuffle) 176 | if not args.single_gpu: 177 | data = datasets.distributed.split_dataset_by_node( 178 | data, rank=global_rank, world_size=world_size, 179 | ) 180 | 181 | # it doesn't matter which tokenizer we use, because we train from scratch 182 | # T5 tokenizer was trained on C4 and we are also training on C4, so it's a good choice 183 | tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length) 184 | 185 | def preprocess_batched(batch): 186 | batch = tokenizer( 187 | batch["text"], 188 | max_length=args.max_length, 189 | truncation=True, 190 | padding="max_length", 191 | return_tensors="pt", 192 | ) 193 | return batch 194 | 195 | dataset = PreprocessedIterableDataset(data, tokenizer, batch_size=args.batch_size, max_length=args.max_length) 196 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=args.workers) 197 | 198 | model_config = AutoConfig.from_pretrained(args.model_config) 199 | if args.use_hf_model: 200 | 
model: HF_LlamaForCausalLM = AutoModelForCausalLM.from_config(model_config) 201 | else: 202 | model = LlamaForCausalLM(model_config) 203 | 204 | if args.activation_checkpointing: 205 | model.gradient_checkpointing_enable() 206 | 207 | global_step = 0 208 | update_step = 0 209 | beginning_step = 0 210 | tokens_seen = 0 211 | tokens_seen_before = 0 212 | 213 | if args.continue_from is not None: 214 | logger.info("*" * 40) 215 | logger.info(f"Loading model from {args.continue_from}") 216 | checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin") 217 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True) 218 | logger.info(f"Model successfully loaded (strict=True policy)") 219 | 220 | if os.path.exists(os.path.join(args.continue_from, "training_state.json")): 221 | logger.info(f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}") 222 | with open(os.path.join(args.continue_from, "training_state.json")) as f: 223 | _old_state = json.load(f) 224 | global_step = _old_state["global_step"] 225 | update_step = _old_state["update_step"] 226 | tokens_seen = _old_state["tokens_seen"] 227 | tokens_seen_before = _old_state["tokens_seen_before"] 228 | logger.info(f"global_step : {global_step}") 229 | logger.info(f"update_step : {update_step}") 230 | logger.info(f"tokens_seen : {tokens_seen}") 231 | logger.info(f"tokens_seen_before: {tokens_seen_before}") 232 | logger.info(f"Will train for {args.num_training_steps - update_step} update steps") 233 | else: 234 | logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero") 235 | logger.info("*" * 40) 236 | 237 | 238 | if args.dtype in ["bf16", "bfloat16"]: 239 | model = model.to(device=device, dtype=torch.bfloat16) 240 | else: 241 | model = model.to(device=device) 242 | 243 | n_total_params = sum(p.numel() for p in model.parameters()) 244 | trainable_params = [p for p in model.parameters() if p.requires_grad] 245 | # Initialize wandb 246 | run_config = dict(vars(args)) 247 | run_config.update({ 248 | "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler 249 | "total_params_M": n_total_params / 1_000_000, 250 | "dataset": 'c4', 251 | "model": model_config.to_dict(), 252 | "world_size": world_size, 253 | "device": str(device), 254 | }) 255 | 256 | if global_rank == 0: 257 | wandb.config.update(run_config, allow_val_change=True) 258 | wandb.save(os.path.abspath(__file__), policy="now") # save current script 259 | # fix tqdm visual length to 80 so that the progress bar 260 | # doesn't jump around when changing from external display to laptop 261 | pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80) 262 | 263 | if 'galore' in args.optimizer.lower(): 264 | # make parameters with "rank" to a single group, if param_name has "mlp" or "attn" 265 | galore_params = [] 266 | target_modules_list = ["attn", "mlp"] 267 | for module_name, module in model.named_modules(): 268 | if not isinstance(module, nn.Linear): 269 | continue 270 | 271 | if not any(target_key in module_name for target_key in target_modules_list): 272 | continue 273 | 274 | print('enable GaLore for weights in module: ', module_name) 275 | galore_params.append(module.weight) 276 | id_galore_params = [id(p) for p in galore_params] 277 | # make parameters without "rank" to another group 278 | regular_params = [p for p in model.parameters() if id(p) not in id_galore_params] 279 | # then call galore_adamw 280 | 
param_groups = [{'params': regular_params}, 281 | {'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}] 282 | 283 | # print params and trainable params 284 | logger.info(f"\n{model}\n") 285 | logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M") 286 | logger.info(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000:.2f}M") 287 | if 'galore' in args.optimizer.lower(): 288 | logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in galore_params) / 1_000_000:.2f}M") 289 | logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps") 290 | 291 | layer_wise_flag = False 292 | if args.optimizer.lower() == "adam": 293 | optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay) 294 | elif args.optimizer.lower() == "galore_adamw": 295 | # redefine way to call galore_adamw 296 | optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay) 297 | # implement sgd 298 | elif args.optimizer.lower() == "sgd": 299 | optimizer = torch.optim.SGD(trainable_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.beta1) 300 | # implement adafactor 301 | elif args.optimizer.lower() == "adafactor": 302 | args.beta1 = None if args.beta1 == 0.0 else args.beta1 303 | optimizer = transformers.optimization.Adafactor( 304 | trainable_params, 305 | lr=args.lr, 306 | eps=(1e-30, 1e-3), 307 | clip_threshold=1.0, 308 | decay_rate=-0.8, 309 | beta1=args.beta1, 310 | weight_decay=args.weight_decay, 311 | relative_step=False, 312 | scale_parameter=False, 313 | warmup_init=False, 314 | ) 315 | # low-rank adafactor 316 | elif args.optimizer.lower() == "galore_adafactor": 317 | args.beta1 = None if args.beta1 == 0.0 else args.beta1 318 | optimizer = GaLoreAdafactor( 319 | param_groups, 320 | lr=args.lr, 321 | eps=(1e-30, 1e-3), 322 | clip_threshold=1.0, 323 | decay_rate=-0.8, 324 | beta1=args.beta1, 325 | weight_decay=args.weight_decay, 326 | relative_step=False, 327 | scale_parameter=False, 328 | warmup_init=False, 329 | ) 330 | # 8-bit Adam 331 | elif args.optimizer.lower() == "adam8bit": 332 | optimizer = bnb.optim.Adam8bit(trainable_params, lr=args.lr, weight_decay=args.weight_decay) 333 | elif args.optimizer.lower() == "galore_adamw8bit": 334 | optimizer = GaLoreAdamW8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay) 335 | elif args.optimizer.lower() == 'galore_adamw8bit_per_layer': 336 | # TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap 337 | optimizer_dict = {} 338 | for p in model.parameters(): 339 | if p.requires_grad: 340 | if id(p) in id_galore_params: 341 | optimizer_dict[p] = GaLoreAdamW8bit([{'params': [p], 'rank': args.rank, 'update_proj_gap': args.update_proj_gap * 2, 'scale': args.galore_scale, 'proj_type': args.proj_type}], lr=args.lr, weight_decay=args.weight_decay) 342 | else: 343 | optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=args.lr, weight_decay=args.weight_decay) 344 | 345 | # get scheduler dict 346 | scheduler_dict = {} 347 | for p in model.parameters(): 348 | if p.requires_grad: 349 | scheduler_dict[p] = training_utils.get_scheculer( 350 | optimizer=optimizer_dict[p], 351 | scheduler_type=args.scheduler, 352 | num_training_steps=args.num_training_steps * 2, 353 | warmup_steps=args.warmup_steps * 2, 354 | 
min_lr_ratio=args.min_lr_ratio, 355 | ) 356 | 357 | def optimizer_hook(p): 358 | if p.grad is None: 359 | return 360 | optimizer_dict[p].step() 361 | optimizer_dict[p].zero_grad() 362 | scheduler_dict[p].step() 363 | 364 | # Register the hook onto every parameter 365 | for p in model.parameters(): 366 | if p.requires_grad: 367 | p.register_post_accumulate_grad_hook(optimizer_hook) 368 | 369 | layer_wise_flag = True 370 | 371 | else: 372 | raise ValueError(f"Optimizer {args.optimizer} not supported") 373 | 374 | if not layer_wise_flag: 375 | scheduler = training_utils.get_scheculer( 376 | optimizer=optimizer, 377 | scheduler_type=args.scheduler, 378 | num_training_steps=args.num_training_steps, 379 | warmup_steps=args.warmup_steps, 380 | min_lr_ratio=args.min_lr_ratio, 381 | ) 382 | 383 | if not args.single_gpu: 384 | model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel( 385 | model, 386 | device_ids=[local_rank], 387 | output_device=local_rank, 388 | broadcast_buffers=False, 389 | ) 390 | 391 | # global steps and others are defined above 392 | pad_idx = tokenizer.pad_token_id 393 | update_time = time.time() 394 | local_step = 0 # when continue_from is used, local_step != global_step 395 | 396 | # ############################## 397 | # TRAINING LOOP 398 | # we'll never go through all the data, so no need for epochs 399 | # ############################## 400 | 401 | for batch_idx, batch in enumerate(dataloader): 402 | 403 | global_step += 1 404 | local_step += 1 405 | 406 | if update_step > args.num_training_steps: 407 | logger.info(f"Reached max number of update steps (f{args.num_training_steps}). Stopping training.") 408 | print(f"Rank {global_rank} stopping training.") 409 | break 410 | 411 | batch = {k: v.to(device) for k, v in batch.items()} 412 | labels = batch["input_ids"].clone() 413 | labels[labels == pad_idx] = -100 414 | tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size 415 | 416 | loss = model(**batch, labels=labels).loss 417 | scaled_loss = loss / args.gradient_accumulation 418 | scaled_loss.backward() 419 | 420 | if global_step % args.gradient_accumulation != 0: 421 | continue 422 | 423 | 424 | # The below code is only executed during the update step 425 | 426 | # add grad clipping 427 | if args.grad_clipping != 0.0: torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping) 428 | 429 | if global_rank == 0: pbar.update(1) 430 | 431 | if not layer_wise_flag: 432 | optimizer.step() 433 | scheduler.step() 434 | optimizer.zero_grad() 435 | 436 | update_step += 1 437 | update_time = time.time() - update_time 438 | 439 | # save checkpoint by save_every 440 | if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0: 441 | current_model_directory = f"{args.save_dir}/model_{update_step}" 442 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}") 443 | os.makedirs(args.save_dir, exist_ok=True) 444 | model.module.save_pretrained(current_model_directory, max_shard_size='100GB') 445 | 446 | optimizer_checkpoint = { 447 | "optimizer": optimizer.state_dict(), 448 | "scheduler": scheduler.state_dict(), 449 | "update_step": update_step, 450 | "global_step": global_step, 451 | "config": run_config, 452 | "wandb": wandb.run.dir, 453 | "dtype": args.dtype, 454 | } 455 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt") 456 | 457 | training_state_checkpoint = { 458 | "global_step": global_step, 459 | "update_step": update_step, 460 | 
"tokens_seen": tokens_seen, 461 | "tokens_seen_before": tokens_seen_before, 462 | "update_time": update_time, 463 | } 464 | with open(f"{current_model_directory}/training_state.json", "w") as f: 465 | json.dump(training_state_checkpoint, f, indent=4) 466 | 467 | # save wandb related info 468 | wandb_info = { 469 | "wandb_id": wandb.run.id, 470 | } 471 | with open(f"{args.save_dir}/wandb.json", "w") as f: 472 | json.dump(wandb_info, f, indent=4) 473 | 474 | # evaluation 475 | if update_step % args.eval_every == 0: 476 | logger.info(f"Performing evaluation at step {update_step}") 477 | total_loss, evaluated_on_tokens = evaluate_model( 478 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size 479 | ) 480 | if global_rank == 0: 481 | wandb.log({ 482 | "final_eval_loss": total_loss, 483 | "final_eval_tokens": evaluated_on_tokens, 484 | }, 485 | step=global_step, 486 | ) 487 | logger.info(f"Eval loss at step {update_step}: {total_loss}") 488 | 489 | if not layer_wise_flag: 490 | lr = optimizer.param_groups[0]["lr"] 491 | else: 492 | lr = list(optimizer_dict.values())[0].param_groups[0]["lr"] 493 | tokens_in_update = tokens_seen - tokens_seen_before 494 | tokens_seen_before = tokens_seen 495 | batches_in_update = args.gradient_accumulation * world_size 496 | 497 | if global_rank == 0: 498 | wandb.log({ 499 | "loss": loss.item(), 500 | "lr": lr, 501 | "update_step": update_step, 502 | "tokens_seen": tokens_seen, 503 | "throughput_tokens": tokens_in_update / update_time, 504 | "throughput_examples": args.total_batch_size / update_time, 505 | "throughput_batches": batches_in_update / update_time, 506 | }, 507 | step=global_step, 508 | ) 509 | update_time = time.time() 510 | 511 | # ############################## 512 | # END of training loop 513 | # ############################## 514 | logger.info("Training finished") 515 | if global_rank == 0: pbar.close() 516 | 517 | current_model_directory = f"{args.save_dir}/model_{update_step}" 518 | if global_rank == 0 and not os.path.exists(current_model_directory): 519 | logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}") 520 | os.makedirs(args.save_dir, exist_ok=True) 521 | model.module.save_pretrained(current_model_directory) 522 | 523 | optimizer_checkpoint = { 524 | "optimizer": optimizer.state_dict(), 525 | "scheduler": scheduler.state_dict(), 526 | "update_step": update_step, 527 | "global_step": global_step, 528 | "config": run_config, 529 | "wandb": wandb.run.dir, 530 | "dtype": args.dtype, 531 | } 532 | torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt") 533 | 534 | training_state_checkpoint = { 535 | "global_step": global_step, 536 | "update_step": update_step, 537 | "tokens_seen": tokens_seen, 538 | "tokens_seen_before": tokens_seen_before, 539 | "update_time": update_time, 540 | } 541 | with open(f"{current_model_directory}/training_state.json", "w") as f: 542 | json.dump(training_state_checkpoint, f, indent=4) 543 | 544 | # Final evaluation 545 | logger.info("Running final evaluation") 546 | model.eval() 547 | del loss, optimizer, scheduler 548 | import gc; gc.collect() 549 | torch.cuda.empty_cache() 550 | 551 | total_loss, evaluated_on_tokens = evaluate_model( 552 | model, preprocess_batched, pad_idx, global_rank, world_size, device, args.batch_size 553 | ) 554 | 555 | if global_rank == 0: 556 | wandb.log({ 557 | "final_eval_loss": total_loss, 558 | "final_eval_tokens": evaluated_on_tokens, 559 | }, 560 | step=global_step, 561 | ) 562 | 
logger.info(f"Final eval loss: {total_loss}") 563 | 564 | logger.info("Script finished successfully") 565 | print(f"Rank {global_rank} finished successfully") 566 | 567 | 568 | if __name__ == "__main__": 569 | print("Starting script") 570 | args = parse_args(None) 571 | main(args) 572 | --------------------------------------------------------------------------------