├── image
│   └── qalora.png
├── requirements.txt
├── merge.py
├── LICENSE
├── environment.yaml
├── README.md
├── peft_utils.py
└── qalora.py
/image/qalora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuhuixu1993/qa-lora/HEAD/image/qalora.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bert-score==0.3.13
2 | evaluate==0.4.0
3 | rouge-score==0.1.2
4 | scikit-learn==1.2.2
5 | sentencepiece==0.1.99
6 | wandb==0.15.2
7 | transformers==4.31.0
8 | peft==0.4.0
9 | accelerate==0.21.0
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | model_path = 'path of the quantized model'
3 | lora_path = 'path of the saved LoRA adapters'
4 | merged_path = 'target path of the merged model'
5 | scale = 16 / 64  # lora_alpha / lora_r used during fine-tuning
6 | group_size = 32  # GPTQ group size; must match the value used in peft_utils.py
7 | 
8 | model = torch.load(model_path, map_location='cpu')
9 | lora = torch.load(lora_path, map_location='cpu')
10 | tmp_keys = [key[17:-14] for key in lora.keys() if 'lora_A' in key]  # strip 'base_model.model.' prefix and '.lora_A.weight' suffix
11 | for tmp_key in tmp_keys:  # fold each LoRA update (B @ A) into the per-group zero points
12 |     model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / group_size / model[tmp_key+'.scales']
13 | 
14 | torch.save(model, merged_path)
15 | 
16 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 YuhuiXu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: alpaca 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.05.30=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.9=h7f8727e_0 15 | - pip=23.1.2=py38h06a4308_0 16 | - python=3.8.16=h955ad1f_4 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=67.8.0=py38h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.38.4=py38h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==1.2.0 26 | - accelerate==0.21.0.dev0 27 | - addict==2.4.0 28 | - aiohttp==3.8.4 29 | - aiosignal==1.3.1 30 | - appdirs==1.4.4 31 | - async-timeout==4.0.2 32 | - attrs==22.2.0 33 | - auto-gptq==0.3.0.dev0 34 | - bert-score==0.3.13 35 | - bitsandbytes==0.39.0 36 | - certifi==2022.9.24 37 | - charset-normalizer==3.0.1 38 | - click==8.1.3 39 | - cmake==3.26.3 40 | - colorama==0.4.5 41 | - contourpy==1.0.6 42 | - cycler==0.11.0 43 | - datasets==2.12.0 44 | - dill==0.3.5.1 45 | - evaluate==0.4.0 46 | - fonttools==4.39.4 47 | - frozenlist==1.3.3 48 | - fsspec==2022.11.0 49 | - gitdb==4.0.10 50 | - gitpython==3.1.31 51 | - huggingface-hub==0.14.1 52 | - importlib-resources==5.12.0 53 | - jinja2==3.1.2 54 | - joblib==1.2.0 55 | - lazy-import==0.2.2 56 | - lit==16.0.5 57 | - lxml==4.8.0 58 | - markupsafe==2.1.2 59 | - matplotlib==3.7.1 60 | - mpmath==1.2.1 61 | - multidict==6.0.2 62 | - multiprocess==0.70.12.2 63 | - networkx==2.8.8 64 | - ninja==1.11.1 65 | - nltk==3.8.1 66 | - numpy==1.24.2 67 | - packaging==23.1 68 | - pandas==2.0.0 69 | - pathlib2==2.3.7.post1 70 | - pathtools==0.1.2 71 | - peft==0.4.0.dev0 72 | - pillow==9.3.0 73 | - protobuf==3.20.2 74 | - psutil==5.9.4 75 | - pyarrow==10.0.1 76 | - pyparsing==3.0.9 77 | - python-dateutil==2.8.2 78 | - pytz==2022.6 79 | - pyyaml==6.0 80 | - requests==2.31.0 81 | - responses==0.18.0 82 | - rouge==1.0.0 83 | - rouge-score==0.1.2 84 | - safetensors==0.3.1 85 | - scikit-learn==1.2.2 86 | - scipy==1.10.1 87 | - sentencepiece==0.1.99 88 | - sentry-sdk==1.24.0 89 | - setproctitle==1.3.2 90 | - six==1.16.0 91 | - smmap==5.0.0 92 | - sympy==1.11.1 93 | - terminaltables==3.1.10 94 | - threadpoolctl==3.1.0 95 | - tokenizers==0.13.3 96 | - torch==2.0.0+cu117 97 | - tqdm==4.63.1 98 | - transformers==4.31.0.dev0 99 | - triton==2.0.0 100 | - typing-extensions==4.6.2 101 | - tzdata==2022.7 102 | - urllib3==1.26.16 103 | - wandb==0.15.2 104 | - xformers==0.0.20 105 | - xxhash==3.0.0 106 | - yapf==0.32.0 107 | - yarl==1.8.1 108 | - zipp==3.15.0 109 | 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QA-LoRA 2 | 3 | QA-LoRA has been accepted by ICLR 2024! 4 | 5 | This repository provides the official PyTorch implementation of [QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/pdf/2309.14717.pdf). 6 | 7 |
8 | [figure: image/qalora.png]
9 | 
10 | 
11 | QA-LoRA is easily implemented with a few lines of code, and it equips the original LoRA with two-fold abilities: (i) during fine-tuning, the LLM's weights are quantized (e.g., into INT4) to reduce time and memory usage; (ii) after fine-tuning, the LLM and auxiliary weights are naturally integrated into a quantized model without loss of accuracy.
12 | 
13 | ## Todo list
14 | Fix the conflict with the newest AutoGPTQ version.
15 | 
16 | ## Installation
17 | ```bash
18 | conda create -n qalora python=3.8
19 | conda activate qalora
20 | conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia
21 | git clone -b v0.3.0 https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
22 | pip install .
23 | cd ..
24 | pip install bitsandbytes
25 | pip install -r requirements.txt
26 | pip install protobuf==3.20.*
27 | ```
28 | Replace the `peft_utils.py` in your local auto-gptq installation (`<python path>/auto_gptq/utils/peft_utils.py`) with the one provided in this repo.
29 | For users of [GPTQLORA](https://github.com/qwopqwop200/gptqlora), replacing the `peft_utils.py` file is the only change needed.
30 | 
31 | 
32 | ## Quantization
33 | We use [GPTQ](https://github.com/qwopqwop200/GPTQ-for-LLaMa) for quantization.
34 | Quantization settings: `bits=4`, `group-size=32`, `act-order=False`.
35 | If you change the group size, you need to change `group_size` in `peft_utils.py` and `merge.py` accordingly.
36 | 
37 | ## Training
38 | ```bash
39 | python qalora.py --model_path <path to the quantized model>
40 | ```
41 | 
42 | The file structure of the model checkpoint is as follows:
43 | ```
44 | config.json             llama7b-4bit-32g.bin   special_tokens_map.json   tokenizer_config.json
45 | generation_config.json  quantize_config.json   tokenizer.model
46 | ```
47 | 
48 | ## Merge
49 | Note that our trained LoRA modules can be perfectly merged into the quantized model. We provide a simple merge script (`merge.py`) in this repo.
50 | 
51 | ## Notice
52 | ### About the implementations
53 | There are two possible implementations of the dimension reduction (mapping x from D_in to D_in//L). Both are mathematically equivalent; a small numerical check of the relation between them is given after the quantization note below.
54 | #### The first one (this repo)
55 | It adopts an average-pooling operation, so the adapter weights are divided by the group size (L) during merging (refer to `merge.py`).
56 | ```python
57 | adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x)))) * scale).type_as(result)
58 | model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / group_size / model[tmp_key+'.scales']
59 | ```
60 | #### The second one
61 | It utilizes a sum operation, so the adapters do not need to be divided during merging.
62 | 
63 | ```python
64 | adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x) * group_size))) * scale).type_as(result)
65 | model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / model[tmp_key+'.scales']
66 | ```
67 | 
68 | ### About the quantization
69 | 
70 | Some GPTQ implementations, such as [GPTQ-for-llama](https://github.com/qwopqwop200/GPTQ-for-LLaMa), further pack the zeros into `qzeros`. You need to decode the `qzeros` first and restore the fp16-format zeros before running `merge.py`; a minimal unpacking sketch follows.
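The snippet below is a minimal decoding sketch (not part of this repo), assuming the common 4-bit packing used by GPTQ-for-llama/AutoGPTQ: eight zero points per `int32` word, lowest nibble first. The helper name `unpack_qzeros` is ours, and some older GPTQ checkpoints additionally store the zero points with an offset of one that the kernel adds back, so check your quantizer version.

```python
import torch

def unpack_qzeros(qzeros: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Decode int32-packed GPTQ zero points into one integer per (group, output channel)."""
    pack_factor = 32 // bits                      # 8 zero points per int32 word when bits=4
    mask = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qzeros.device)
    # (n_groups, out_features // pack_factor, 1) >> (pack_factor,) broadcasts to
    # (n_groups, out_features // pack_factor, pack_factor); the mask keeps the low `bits` bits
    unpacked = (qzeros.unsqueeze(-1) >> shifts) & mask
    return unpacked.reshape(qzeros.shape[0], -1)  # (n_groups, out_features)

# zeros = unpack_qzeros(model[tmp_key + '.qzeros']).to(torch.float16)
# ...apply the merge.py update to `zeros`, then store it back under the '.qzeros' key
```

After the merge update is applied, the zero points are no longer integers, so the merged model keeps them in fp16 rather than re-packing them, which is what `merge.py` assumes.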
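Going back to the two implementations under "About the implementations" above: the only difference is whether each input group is averaged or summed, and the factor of `group_size` between those two reductions is exactly what the extra `/ group_size` in the first merge formula compensates for. A small numerical check of this relation (not from the repo; `group_size=32` and a toy input shape are assumed):

```python
import torch

group_size = 32
x = torch.randn(4, 128)                          # 4 rows, D_in = 128
qa_pool = torch.nn.AvgPool1d(group_size)         # the pooling layer used in peft_utils.py

avg = qa_pool(x)                                 # first implementation: average over each group
summed = x.reshape(4, -1, group_size).sum(-1)    # second implementation: sum over each group

# the two reductions differ exactly by a factor of group_size
print(torch.allclose(avg * group_size, summed, atol=1e-5))  # True
```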
71 | ## Acknowledgements 72 | Our code is based on [QLoRA](https://github.com/artidoro/qlora), [GPTQLORA](https://github.com/qwopqwop200/gptqlora), [Auto-GPTQ](https://github.com/PanQiWei/AutoGPTQ/tree/main) 73 | -------------------------------------------------------------------------------- /peft_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import re 3 | from contextlib import contextmanager 4 | from dataclasses import asdict 5 | from enum import Enum 6 | from typing import List, Optional 7 | 8 | import torch 9 | from peft import get_peft_model, PeftConfig, PeftModel, PeftType 10 | from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING 11 | from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel, Embedding 12 | from peft.tuners.adalora import AdaLoraConfig, AdaLoraLayer, AdaLoraModel 13 | from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING 14 | from peft.utils.other import _get_submodules 15 | 16 | from ..modeling._base import BaseGPTQForCausalLM 17 | 18 | 19 | group_size = 32 # quantization group_size 20 | 21 | 22 | class GPTQLoraConfig(LoraConfig): 23 | injected_fused_attention: bool = False 24 | injected_fused_mlp: bool = False 25 | 26 | 27 | class GPTQLoraLinear(torch.nn.Linear, LoraLayer): 28 | def __init__( 29 | self, 30 | adapter_name: str, 31 | linear_module: torch.nn.Linear, 32 | r: int = 0, 33 | lora_alpha: int = 1, 34 | lora_dropout: float = 0.0, 35 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) 36 | **kwargs, 37 | ): 38 | init_lora_weights = kwargs.pop("init_lora_weights", True) 39 | 40 | torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features) 41 | LoraLayer.__init__(self, linear_module.in_features//group_size, linear_module.out_features) 42 | 43 | self.linear_module = linear_module 44 | 45 | self.weight.requires_grad = False 46 | self.weight = self.linear_module.weight 47 | self.bias = self.linear_module.bias 48 | self.fan_in_fan_out = fan_in_fan_out 49 | if fan_in_fan_out: 50 | self.weight.data = self.weight.data.T 51 | 52 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) 53 | self.active_adapter = adapter_name 54 | self.qa_pool = torch.nn.AvgPool1d(group_size) # using pooling layer to conduct sum operation 55 | 56 | def reset_lora_parameters(self, adapter_name): 57 | if adapter_name in self.lora_A.keys(): 58 | torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight) 59 | torch.nn.init.zeros_(self.lora_B[adapter_name].weight) 60 | 61 | def merge(self): 62 | raise NotImplementedError("gptq model not support merge lora adapter") 63 | 64 | def unmerge(self): 65 | raise NotImplementedError("gptq model not support unmerge lora adapter") 66 | 67 | def forward(self, x: torch.Tensor): 68 | previous_dtype = x.dtype 69 | if self.active_adapter not in self.lora_A.keys(): 70 | return self.linear_module(x) 71 | if self.disable_adapters: 72 | if self.r[self.active_adapter] > 0 and self.merged: 73 | self.unmerge() 74 | result = self.linear_module(x) 75 | elif self.r[self.active_adapter] > 0 and not self.merged: 76 | result = self.linear_module(x) 77 | 78 | lora_B = self.lora_B[self.active_adapter] 79 | lora_A = self.lora_A[self.active_adapter] 80 | lora_dropout = self.lora_dropout[self.active_adapter] 81 | scale = self.scaling[self.active_adapter] 82 | 83 | x = x.type_as(lora_A.weight.data) 84 | adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x)))) * 
scale).type_as(result) 85 | result += adapter_result 86 | else: 87 | result = self.linear_module(x) 88 | 89 | result = result.to(previous_dtype) 90 | 91 | return result 92 | 93 | 94 | class GPTQLoraModel(LoraModel): 95 | def _find_and_replace(self, adapter_name): 96 | lora_config = self.peft_config[adapter_name] 97 | is_target_modules_in_base_model = False 98 | kwargs = { 99 | "r": lora_config.r, 100 | "lora_alpha": lora_config.lora_alpha, 101 | "lora_dropout": lora_config.lora_dropout, 102 | "fan_in_fan_out": lora_config.fan_in_fan_out, 103 | "init_lora_weights": lora_config.init_lora_weights, 104 | } 105 | key_list = [key for key, _ in self.model.named_modules()] 106 | for key in key_list: 107 | if isinstance(lora_config.target_modules, str): 108 | target_module_found = re.fullmatch(lora_config.target_modules, key) 109 | else: 110 | target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) 111 | if target_module_found: 112 | if not is_target_modules_in_base_model: 113 | is_target_modules_in_base_model = True 114 | parent, target, target_name = _get_submodules(self.model, key) 115 | bias = False 116 | if hasattr(target, "bias"): 117 | bias = target.bias is not None 118 | 119 | if isinstance(target, LoraLayer): 120 | target.update_layer( 121 | adapter_name, 122 | lora_config.r, 123 | lora_config.lora_alpha, 124 | lora_config.lora_dropout, 125 | lora_config.init_lora_weights, 126 | ) 127 | else: 128 | if isinstance(target, torch.nn.Embedding): 129 | embedding_kwargs = kwargs.copy() 130 | embedding_kwargs.pop("fan_in_fan_out", None) 131 | in_features, out_features = target.num_embeddings, target.embedding_dim 132 | new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs) 133 | else: 134 | if isinstance(target, torch.nn.Linear): 135 | if kwargs["fan_in_fan_out"]: 136 | warnings.warn( 137 | "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " 138 | "Setting fan_in_fan_out to False." 139 | ) 140 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False 141 | else: 142 | raise ValueError( 143 | f"Target module {target} is not supported. " 144 | f"Currently, only `torch.nn.Linear` and its subclasses are supported." 145 | ) 146 | new_module = GPTQLoraLinear(adapter_name, target, **kwargs) 147 | 148 | self._replace_module(parent, target_name, new_module, target) 149 | if not is_target_modules_in_base_model: 150 | raise ValueError( 151 | f"Target modules {lora_config.target_modules} not found in the base model. " 152 | f"Please check the target modules and try again." 
153 | ) 154 | 155 | def _replace_module(self, parent_module, child_name, new_module, old_module): 156 | setattr(parent_module, child_name, new_module) 157 | if not isinstance(new_module, GPTQLoraLinear): 158 | new_module.weight = old_module.weight 159 | if hasattr(old_module, "bias"): 160 | if old_module.bias is not None: 161 | new_module.bias = old_module.bias 162 | 163 | if getattr(old_module, "state", None) is not None: 164 | new_module.state = old_module.state 165 | new_module.to(old_module.weight.device) 166 | 167 | # dispatch to correct device 168 | for name, module in new_module.named_modules(): 169 | if "lora_" in name: 170 | module.to(old_module.weight.device) 171 | 172 | def merge_adapter(self): 173 | raise NotImplementedError("gptq model not support merge ada lora adapter") 174 | 175 | def unmerge_adapter(self): 176 | raise NotImplementedError("gptq model not support unmerge ada lora adapter") 177 | 178 | def merge_and_unload(self): 179 | raise NotImplementedError("gptq model not support merge and unload") 180 | 181 | 182 | class GPTQAdaLoraConfig(AdaLoraConfig): 183 | injected_fused_attention: bool = False 184 | injected_fused_mlp: bool = False 185 | 186 | 187 | class GPTQSVDLinear(torch.nn.Linear, AdaLoraLayer): 188 | def __init__( 189 | self, 190 | adapter_name: str, 191 | linear_module: torch.nn.Linear, 192 | r: int = 0, 193 | lora_alpha: int = 1, 194 | lora_dropout: float = 0.0, 195 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) 196 | **kwargs, 197 | ): 198 | init_lora_weights = kwargs.pop("init_lora_weights", True) 199 | 200 | torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features) 201 | AdaLoraLayer.__init__(self, linear_module.in_features, linear_module.out_features) 202 | 203 | self.linear_module = linear_module 204 | 205 | self.weight.requires_grad = False 206 | self.weight = self.linear_module.weight 207 | self.bias = self.linear_module.bias 208 | self.fan_in_fan_out = fan_in_fan_out 209 | if fan_in_fan_out: 210 | self.weight.data = self.weight.data.T 211 | 212 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) 213 | self.active_adapter = adapter_name 214 | 215 | def merge(self): 216 | raise NotImplementedError("gptq model not support merge lora adapter") 217 | 218 | def unmerge(self): 219 | raise NotImplementedError("gptq model not support unmerge lora adapter") 220 | 221 | def forward(self, x: torch.Tensor): 222 | if self.active_adapter not in self.lora_A.keys(): 223 | return self.linear_module(x) 224 | if self.disable_adapters: 225 | if self.r[self.active_adapter] > 0 and self.merged: 226 | self.unmerge() 227 | result = self.linear_module(x) 228 | elif self.r[self.active_adapter] > 0 and not self.merged: 229 | result = self.linear_module(x) 230 | result += ( 231 | ( 232 | self.lora_dropout[self.active_adapter](x) 233 | @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T 234 | @ self.lora_B[self.active_adapter].T 235 | ) 236 | * self.scaling[self.active_adapter] 237 | / (self.ranknum[self.active_adapter] + 1e-5) 238 | ) 239 | else: 240 | result = self.linear_module(x) 241 | return result 242 | 243 | 244 | class GPTQAdaLoraModel(AdaLoraModel): 245 | def _find_and_replace(self, adapter_name): 246 | lora_config = self.peft_config[adapter_name] 247 | is_target_modules_in_base_model = False 248 | kwargs = { 249 | "r": lora_config.init_r, 250 | "lora_alpha": lora_config.lora_alpha, 251 | "lora_dropout": 
lora_config.lora_dropout, 252 | "fan_in_fan_out": lora_config.fan_in_fan_out, 253 | "init_lora_weights": lora_config.init_lora_weights, 254 | } 255 | key_list = [key for key, _ in self.model.named_modules()] 256 | for key in key_list: 257 | if isinstance(lora_config.target_modules, str): 258 | target_module_found = re.fullmatch(lora_config.target_modules, key) 259 | else: 260 | target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) 261 | if target_module_found: 262 | if not is_target_modules_in_base_model: 263 | is_target_modules_in_base_model = True 264 | parent, target, target_name = _get_submodules(self.model, key) 265 | bias = target.bias is not None 266 | if isinstance(target, LoraLayer): 267 | target.update_layer( 268 | adapter_name, 269 | lora_config.init_r, 270 | lora_config.lora_alpha, 271 | lora_config.lora_dropout, 272 | lora_config.init_lora_weights, 273 | ) 274 | else: 275 | if isinstance(target, torch.nn.Linear): 276 | in_features, out_features = target.in_features, target.out_features 277 | if kwargs["fan_in_fan_out"]: 278 | warnings.warn( 279 | "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " 280 | "Setting fan_in_fan_out to False." 281 | ) 282 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False 283 | else: 284 | raise ValueError( 285 | f"Target module {target} is not supported. " 286 | f"Currently, only `torch.nn.Linear` and its subclasses are supported." 287 | ) 288 | new_module = GPTQSVDLinear(adapter_name, target, **kwargs) 289 | 290 | self._replace_module(parent, target_name, new_module, target) 291 | if not is_target_modules_in_base_model: 292 | raise ValueError( 293 | f"Target modules {lora_config.target_modules} not found in the base model. " 294 | f"Please check the target modules and try again." 
295 | ) 296 | 297 | def _replace_module(self, parent_module, child_name, new_module, old_module): 298 | setattr(parent_module, child_name, new_module) 299 | 300 | # dispatch to correct device 301 | for name, module in new_module.named_modules(): 302 | if "lora_" in name: 303 | module.to(old_module.weight.device) 304 | 305 | def merge_adapter(self): 306 | raise NotImplementedError("gptq model not support merge ada lora adapter") 307 | 308 | def unmerge_adapter(self): 309 | raise NotImplementedError("gptq model not support unmerge ada lora adapter") 310 | 311 | def merge_and_unload(self): 312 | raise NotImplementedError("gptq model not support merge and unload") 313 | 314 | 315 | def find_all_linear_names(model: BaseGPTQForCausalLM, ignore: Optional[List[str]] = None, ignore_lm_head: bool = True): 316 | if not ignore: 317 | ignore = [] 318 | lm_head_name = model.lm_head_name 319 | if ignore_lm_head and lm_head_name not in ignore: 320 | ignore.append(lm_head_name) 321 | results = set() 322 | for n, m in model.named_modules(): 323 | if isinstance(m, torch.nn.Linear): 324 | res = n.split('.')[-1] 325 | if res not in ignore: 326 | results.add(res) 327 | return list(results) 328 | 329 | 330 | @contextmanager 331 | def hijack_peft_mappings(): 332 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig 333 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel 334 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig 335 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel 336 | 337 | try: 338 | yield 339 | except: 340 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig 341 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel 342 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig 343 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel 344 | raise 345 | finally: 346 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig 347 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel 348 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig 349 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel 350 | 351 | 352 | def get_gptq_peft_model( 353 | model: BaseGPTQForCausalLM, 354 | peft_config: PeftConfig = None, 355 | model_id: str = None, 356 | adapter_name: str = "default", 357 | auto_find_all_linears: bool = True, 358 | train_mode: bool = False 359 | ): 360 | if train_mode and not model.trainable: 361 | model.enable_trainable_mode() 362 | if train_mode and not peft_config: 363 | raise ValueError("peft_config not specified when in train mode.") 364 | if not train_mode and not model_id: 365 | raise ValueError("model_id(where to load adapters) not specified when in inference mode.") 366 | 367 | if model.fused_attn_module_type is not None and not model.injected_fused_attention: 368 | peft_types = [PeftType.LORA.value, PeftType.ADALORA.value] 369 | warnings.warn( 370 | f"You can just ignore this warning if the peft type you use isn't in {peft_types}.\n" 371 | f"{model.__class__.__name__} supports injecting fused attention but not enables this time. " 372 | "If you are training adapters, you must also disable fused attention injection when loading quantized " 373 | "base model at inference time, otherwise adapters may not be added to base model properly. " 374 | "If you are loading adapters to do inference, you can reference to adapter's config file to check " 375 | "whether the adapters are trained using base model that not enable fused attention injection." 
376 | ) 377 | if model.injected_fused_mlp: 378 | raise NotImplementedError("GPTQ model that enables fused mlp injection is not supported to integrate with peft.") 379 | 380 | if train_mode: 381 | peft_type = peft_config.peft_type 382 | if not isinstance(peft_type, str): 383 | peft_type = peft_type.value 384 | if peft_type in [PeftType.LORA.value, PeftType.ADALORA.value]: 385 | if auto_find_all_linears: 386 | peft_config.target_modules = find_all_linear_names(model, ignore_lm_head=True) 387 | if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig): 388 | peft_config = GPTQLoraConfig(**peft_config.to_dict()) 389 | if peft_type == PeftType.ADALORA.value and not isinstance(peft_config, GPTQAdaLoraConfig): 390 | peft_config = GPTQAdaLoraConfig(**peft_config.to_dict()) 391 | peft_config.injected_fused_attention = model.injected_fused_attention 392 | peft_config.injected_fused_mlp = model.injected_fused_mlp 393 | if peft_type == PeftType.ADAPTION_PROMPT.value: 394 | if peft_config.adapter_layers > model.config.num_hidden_layers: 395 | warnings.warn( 396 | f"model has only {model.config.num_hidden_layers} layers " 397 | f"but adapter_layers is set to {peft_config.adapter_layers}, " 398 | f"will reset value to {model.config.num_hidden_layers}." 399 | ) 400 | peft_config.adapter_layers = model.config.num_hidden_layers 401 | if model.injected_fused_attention: 402 | raise NotImplementedError( 403 | "model with fused attention injected isn't supported to use ADAPTION_PROMPT peft type yet." 404 | ) 405 | 406 | with hijack_peft_mappings(): 407 | try: 408 | if train_mode: 409 | peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name) 410 | else: 411 | peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name) 412 | except: 413 | raise NotImplementedError( 414 | f"{model.__class__.__name__} not support {peft_config.peft_type.value} peft type yet." 415 | ) 416 | 417 | return peft_model 418 | 419 | 420 | __all__ = [ 421 | "GPTQLoraConfig", 422 | "GPTQLoraModel", 423 | "GPTQAdaLoraConfig", 424 | "GPTQAdaLoraModel", 425 | "find_all_linear_names", 426 | "get_gptq_peft_model" 427 | ] 428 | -------------------------------------------------------------------------------- /qalora.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | 4 | from collections import defaultdict 5 | import copy 6 | import json 7 | import os 8 | from os.path import exists, join, isdir 9 | from dataclasses import dataclass, field 10 | import sys 11 | from typing import Optional, Dict, Sequence 12 | import numpy as np 13 | from tqdm import tqdm 14 | import logging 15 | 16 | import torch 17 | import transformers 18 | from torch.nn.utils.rnn import pad_sequence 19 | import argparse 20 | from transformers import ( 21 | AutoTokenizer, 22 | AutoModelForCausalLM, 23 | set_seed, 24 | Seq2SeqTrainer, 25 | LlamaTokenizerFast 26 | ) 27 | from datasets import load_dataset 28 | import evaluate 29 | 30 | from peft import ( 31 | LoraConfig, 32 | get_peft_model_state_dict, 33 | set_peft_model_state_dict, 34 | PeftModel 35 | ) 36 | from peft.tuners.lora import LoraLayer 37 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 38 | from auto_gptq.utils.peft_utils import get_gptq_peft_model, GPTQLoraConfig 39 | from auto_gptq import AutoGPTQForCausalLM 40 | from auto_gptq.nn_modules.qlinear import GeneralQuantLinear 41 | 42 | torch.backends.cuda.matmul.allow_tf32 = True 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | IGNORE_INDEX = -100 47 | DEFAULT_PAD_TOKEN = "[PAD]" 48 | 49 | import os 50 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 51 | 52 | def prepare_model_for_int8_training(model, use_gradient_checkpointing=True): 53 | r""" 54 | This method wraps the entire protocol for preparing a model before running a training. This includes: 55 | 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm 56 | head to fp32 57 | 58 | Args: 59 | model, (`transformers.PreTrainedModel`): 60 | The loaded model from `transformers` 61 | """ 62 | for name, param in model.named_parameters(): 63 | # freeze base model's layers 64 | param.requires_grad = False 65 | 66 | if use_gradient_checkpointing: 67 | # For backward compatibility 68 | if hasattr(model, "enable_input_require_grads"): 69 | model.enable_input_require_grads() 70 | else: 71 | 72 | def make_inputs_require_grad(module, input, output): 73 | output.requires_grad_(True) 74 | 75 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 76 | 77 | # enable gradient checkpointing for memory efficiency 78 | model.gradient_checkpointing_enable() 79 | 80 | model.lm_head = model.lm_head.float() 81 | for _, param in model.named_parameters(): 82 | if param.dtype == torch.float16: 83 | param = param.float() 84 | 85 | return model 86 | 87 | @dataclass 88 | class ModelArguments: 89 | model_path: Optional[str] = field( 90 | default="./llama-7b/" 91 | ) 92 | trust_remote_code: Optional[bool] = field( 93 | default=False, 94 | metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} 95 | ) 96 | 97 | @dataclass 98 | class DataArguments: 99 | eval_dataset_size: int = field( 100 | default=1024, metadata={"help": "Size of validation dataset."} 101 | ) 102 | max_train_samples: Optional[int] = field( 103 | default=None, 104 | metadata={ 105 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 106 | "value if set." 107 | }, 108 | ) 109 | max_eval_samples: Optional[int] = field( 110 | default=None, 111 | metadata={ 112 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 113 | "value if set." 
114 | }, 115 | ) 116 | source_max_len: int = field( 117 | default=1024, 118 | metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."}, 119 | ) 120 | target_max_len: int = field( 121 | default=256, 122 | metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."}, 123 | ) 124 | dataset: str = field( 125 | default='alpaca', 126 | metadata={"help": "Which dataset to finetune on. See datamodule for options."} 127 | ) 128 | 129 | @dataclass 130 | class TrainingArguments(transformers.Seq2SeqTrainingArguments): 131 | cache_dir: Optional[str] = field( 132 | default=None 133 | ) 134 | train_on_source: Optional[bool] = field( 135 | default=False, 136 | metadata={"help": "Whether to train on the input in addition to the target text."} 137 | ) 138 | mmlu_split: Optional[str] = field( 139 | default='eval', 140 | metadata={"help": "The MMLU split to run on"} 141 | ) 142 | mmlu_dataset: Optional[str] = field( 143 | default='mmlu-fs', 144 | metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."} 145 | ) 146 | do_mmlu_eval: Optional[bool] = field( 147 | default=False, 148 | metadata={"help": "Whether to run the MMLU evaluation."} 149 | ) 150 | max_mmlu_samples: Optional[int] = field( 151 | default=None, 152 | metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."} 153 | ) 154 | mmlu_source_max_len: int = field( 155 | default=2048, 156 | metadata={"help": "Maximum source sequence length for mmlu."} 157 | ) 158 | full_finetune: bool = field( 159 | default=False, 160 | metadata={"help": "Finetune the entire model without adapters."} 161 | ) 162 | adam8bit: bool = field( 163 | default=False, 164 | metadata={"help": "Use 8-bit adam."} 165 | ) 166 | lora_r: int = field( 167 | default=64, 168 | metadata={"help": "Lora R dimension."} 169 | ) 170 | lora_alpha: float = field( 171 | default=16, 172 | metadata={"help": " Lora alpha."} 173 | ) 174 | lora_dropout: float = field( 175 | default=0.0, 176 | metadata={"help":"Lora dropout."} 177 | ) 178 | max_memory_MB: int = field( 179 | default=24000, 180 | metadata={"help": "Free memory per gpu."} 181 | ) 182 | report_to: str = field( 183 | default='none', 184 | metadata={"help": "To use wandb or something else for reporting."} 185 | ) 186 | output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'}) 187 | optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'}) 188 | per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'}) 189 | gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'}) 190 | max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'}) 191 | weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) # use lora dropout instead for regularization if needed 192 | learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate'}) 193 | remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'}) 194 | max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. 
This is tuned and works well for all models tested.'}) 195 | gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'}) 196 | do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'}) 197 | lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'}) 198 | warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'}) 199 | logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'}) 200 | group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'}) 201 | save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'}) 202 | save_steps: int = field(default=250, metadata={"help": 'How often to save a model'}) 203 | save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'}) 204 | 205 | @dataclass 206 | class GenerationArguments: 207 | # For more hyperparameters check: 208 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig 209 | # Length arguments 210 | max_new_tokens: Optional[int] = field( 211 | default=256, 212 | metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops" 213 | "if predict_with_generate is set."} 214 | ) 215 | min_new_tokens : Optional[int] = field( 216 | default=None, 217 | metadata={"help": "Minimum number of new tokens to generate."} 218 | ) 219 | 220 | # Generation strategy 221 | do_sample: Optional[bool] = field(default=False) 222 | num_beams: Optional[int] = field(default=1) 223 | num_beam_groups: Optional[int] = field(default=1) 224 | penalty_alpha: Optional[float] = field(default=None) 225 | use_cache: Optional[bool] = field(default=True) 226 | 227 | # Hyperparameters for logit manipulation 228 | temperature: Optional[float] = field(default=1.0) 229 | top_k: Optional[int] = field(default=50) 230 | top_p: Optional[float] = field(default=1.0) 231 | typical_p: Optional[float] = field(default=1.0) 232 | diversity_penalty: Optional[float] = field(default=0.0) 233 | repetition_penalty: Optional[float] = field(default=1.0) 234 | length_penalty: Optional[float] = field(default=1.0) 235 | no_repeat_ngram_size: Optional[int] = field(default=0) 236 | 237 | def find_all_linear_names(args, model): 238 | cls = GeneralQuantLinear if not(args.full_finetune) else torch.nn.Linear 239 | lora_module_names = set() 240 | for name, module in model.named_modules(): 241 | if isinstance(module, cls): 242 | names = name.split('.') 243 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 244 | 245 | 246 | if 'lm_head' in lora_module_names: # needed for 16-bit 247 | lora_module_names.remove('lm_head') 248 | return list(lora_module_names) 249 | 250 | 251 | class SavePeftModelCallback(transformers.TrainerCallback): 252 | def save_model(self, args, state, kwargs): 253 | print('Saving PEFT checkpoint...') 254 | if state.best_model_checkpoint is not None: 255 | checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model") 256 | else: 257 | checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") 258 | 259 | peft_model_path = 
os.path.join(checkpoint_folder, "adapter_model") 260 | kwargs["model"].save_pretrained(peft_model_path) 261 | 262 | pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin") 263 | if os.path.exists(pytorch_model_path): 264 | os.remove(pytorch_model_path) 265 | 266 | def on_save(self, args, state, control, **kwargs): 267 | self.save_model(args, state, kwargs) 268 | return control 269 | 270 | def on_train_end(self, args, state, control, **kwargs): 271 | def touch(fname, times=None): 272 | with open(fname, 'a'): 273 | os.utime(fname, times) 274 | 275 | touch(join(args.output_dir, 'completed')) 276 | self.save_model(args, state, kwargs) 277 | 278 | def get_accelerate_model(args, checkpoint_dir): 279 | 280 | n_gpus = torch.cuda.device_count() 281 | max_memory = f'{args.max_memory_MB}MB' 282 | max_memory = {i: max_memory for i in range(n_gpus)} 283 | 284 | if args.full_finetune: assert args.bits in [16, 32] 285 | 286 | print(f'loading base model {args.model_path}...') 287 | model = AutoGPTQForCausalLM.from_quantized( 288 | args.model_path, 289 | device_map='auto', 290 | max_memory=max_memory, 291 | trust_remote_code=args.trust_remote_code, 292 | inject_fused_attention = False, 293 | inject_fused_mlp = False, 294 | use_triton=True, 295 | warmup_triton=False, 296 | trainable=True 297 | ) 298 | model.model.quantize_config = model.quantize_config 299 | model.train() 300 | 301 | setattr(model, 'model_parallel', True) 302 | setattr(model, 'is_parallelizable', True) 303 | #modules = find_all_linear_names(args, model) 304 | 305 | model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) 306 | 307 | if not args.full_finetune: 308 | model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing) 309 | if args.gradient_checkpointing: 310 | model.gradient_checkpointing_enable() 311 | 312 | config = GPTQLoraConfig( 313 | r=args.lora_r, 314 | lora_alpha=args.lora_alpha, 315 | #target_modules=modules, 316 | lora_dropout=args.lora_dropout, 317 | bias="none", 318 | task_type="CAUSAL_LM", 319 | ) 320 | if not args.full_finetune: 321 | if checkpoint_dir is not None: 322 | print("Loading adapters from checkpoint.") 323 | model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model')) 324 | for name, p in model.named_parameters(): 325 | if 'lora' in name: 326 | print(name, p.sum()) 327 | else: 328 | print(f'adding LoRA modules...') 329 | model = get_gptq_peft_model(model, config, auto_find_all_linears=True, train_mode=True) 330 | 331 | if args.gradient_checkpointing: 332 | if hasattr(model, "enable_input_require_grads"): 333 | model.enable_input_require_grads() 334 | else: 335 | def make_inputs_require_grad(module, input, output): 336 | output.requires_grad_(True) 337 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 338 | 339 | 340 | for name, module in model.named_modules(): 341 | if isinstance(module, LoraLayer): 342 | if args.bf16: 343 | module = module.to(torch.bfloat16) 344 | if 'norm' in name: 345 | module = module.to(torch.float32) 346 | if 'lm_head' in name or 'embed_tokens' in name: 347 | if hasattr(module, 'weight'): 348 | if args.bf16 and module.weight.dtype == torch.float32: 349 | module = module.to(torch.bfloat16) 350 | return model 351 | 352 | def print_trainable_parameters(args, model): 353 | """ 354 | Prints the number of trainable parameters in the model. 
355 | """ 356 | trainable_params = 0 357 | all_param = 0 358 | for _, param in model.named_parameters(): 359 | all_param += param.numel() 360 | if param.requires_grad: 361 | trainable_params += param.numel() 362 | try: 363 | trainable_params /= (32//model.quantize_config.bits) 364 | except: 365 | pass 366 | print(f"trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param}") 367 | 368 | def smart_tokenizer_and_embedding_resize( 369 | special_tokens_dict: Dict, 370 | tokenizer: transformers.PreTrainedTokenizer, 371 | model: transformers.PreTrainedModel, 372 | ): 373 | """Resize tokenizer and embedding. 374 | 375 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 376 | """ 377 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 378 | model.resize_token_embeddings(len(tokenizer)) 379 | 380 | if num_new_tokens > 0: 381 | input_embeddings = model.get_input_embeddings().weight.data 382 | output_embeddings = model.get_output_embeddings().weight.data 383 | 384 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 385 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 386 | 387 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 388 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 389 | 390 | @dataclass 391 | class DataCollatorForCausalLM(object): 392 | tokenizer: transformers.PreTrainedTokenizer 393 | source_max_len: int 394 | target_max_len: int 395 | train_on_source: bool 396 | predict_with_generate: bool 397 | 398 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 399 | # Extract elements 400 | sources = [example['input'] for example in instances] 401 | targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances] 402 | # Tokenize 403 | tokenized_sources_with_prompt = self.tokenizer( 404 | sources, 405 | max_length=self.source_max_len, 406 | truncation=True, 407 | ) 408 | tokenized_targets = self.tokenizer( 409 | targets, 410 | max_length=self.target_max_len, 411 | truncation=True, 412 | add_special_tokens=False, 413 | ) 414 | # Build the input and labels for causal LM 415 | input_ids = [] 416 | labels = [] 417 | for tokenized_source, tokenized_target in zip( 418 | tokenized_sources_with_prompt['input_ids'], 419 | tokenized_targets['input_ids'] 420 | ): 421 | if not self.predict_with_generate: 422 | input_ids.append(torch.tensor(tokenized_source + tokenized_target)) 423 | if not self.train_on_source: 424 | labels.append( 425 | torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target)) 426 | ) 427 | else: 428 | labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target))) 429 | else: 430 | input_ids.append(torch.tensor(tokenized_source)) 431 | # Apply padding 432 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 433 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None 434 | data_dict = { 435 | 'input_ids': input_ids, 436 | 'attention_mask':input_ids.ne(self.tokenizer.pad_token_id), 437 | } 438 | if labels is not None: 439 | data_dict['labels'] = labels 440 | return data_dict 441 | 442 | def extract_unnatural_instructions_data(examples, extract_reformulations=False): 443 | out = { 444 | 'input': [], 445 | 'output': [], 446 | } 447 | for example_instances in 
examples['instances']: 448 | for instance in example_instances: 449 | out['input'].append(instance['instruction_with_input']) 450 | out['output'].append(instance['output']) 451 | if extract_reformulations: 452 | for example_reformulations in examples['reformulations']: 453 | if example_reformulations is not None: 454 | for instance in example_reformulations: 455 | out['input'].append(instance['instruction_with_input']) 456 | out['output'].append(instance['output']) 457 | return out 458 | 459 | PROMPT_DICT = { 460 | "prompt_input": ( 461 | "Below is an instruction that describes a task, paired with an input that provides further context. " 462 | "Write a response that appropriately completes the request.\n\n" 463 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: " 464 | ), 465 | "prompt_no_input": ( 466 | "Below is an instruction that describes a task. " 467 | "Write a response that appropriately completes the request.\n\n" 468 | "### Instruction:\n{instruction}\n\n### Response: " 469 | ), 470 | } 471 | 472 | def extract_alpaca_dataset(example): 473 | if example.get("input", "") != "": 474 | prompt_format = PROMPT_DICT["prompt_input"] 475 | else: 476 | prompt_format = PROMPT_DICT["prompt_no_input"] 477 | return {'input': prompt_format.format(**example)} 478 | 479 | def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict: 480 | """ 481 | Make dataset and collator for supervised fine-tuning. 482 | Datasets are expected to have the following columns: { `input`, `output` } 483 | 484 | Available datasets to be selected with `dataset` argument: 485 | - alpaca, 52002 examples 486 | - alpaca cleaned, 51942 examples 487 | - chip2 (OIG), 210289 examples 488 | - self-instruct, 82612 examples 489 | - hh-rlhf (Anthropic), 160800 examples 490 | - longform, 23.7k examples 491 | 492 | Coming soon: 493 | - unnatural instructions core, 66010 examples 494 | - unnatural instructions full, 240670 examples 495 | - alpaca-gpt4, 52002 examples 496 | - unnatural-instructions-gpt4, 9000 examples 497 | - oa-rlhf (OpenAssistant) primary message tree only, 9209 examples 498 | - oa-rlhf-assistant (OpenAssistant) all assistant replies with ranking 499 | - supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used) 500 | - flan (FLAN v2), up to 20M examples available 501 | 502 | Not Available: 503 | - vicuna, not released at the moment. 504 | """ 505 | # Load dataset. 
506 | # Alpaca 507 | if args.dataset == 'alpaca': 508 | dataset = load_dataset("tatsu-lab/alpaca") 509 | dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction']) 510 | # Alpaca clean 511 | elif args.dataset == 'alpaca-clean': 512 | dataset = load_dataset("yahma/alpaca-cleaned") 513 | dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction']) 514 | # Chip2 515 | elif args.dataset == 'chip2': 516 | dataset = load_dataset("laion/OIG", data_files='unified_chip2.jsonl') 517 | dataset = dataset.map(lambda x: { 518 | 'input': x['text'].split('\n: ')[0].replace(': ', ''), 519 | 'output': x['text'].split('\n: ')[1], 520 | }, remove_columns=['text', 'metadata']) 521 | # Self Instruct 522 | elif args.dataset == 'self-instruct': 523 | dataset = load_dataset("yizhongw/self_instruct", name='self_instruct') 524 | for old, new in [["prompt", "input"], ["completion", "output"]]: 525 | dataset = dataset.rename_column(old, new) 526 | # Anthropic rlhf 527 | elif args.dataset == 'hh-rlhf': 528 | dataset = load_dataset("Anthropic/hh-rlhf") 529 | dataset = dataset.map(lambda x: { 530 | 'input': '', 531 | 'output': x['chosen'] 532 | }, remove_columns=['chosen', 'rejected']) 533 | # LongForm 534 | elif args.dataset == 'longform': 535 | dataset = load_dataset("akoksal/LongForm") 536 | elif args.dataset == 'vicuna': 537 | raise NotImplementedError("Vicuna data was not released.") 538 | else: 539 | raise NotImplementedError(f"Dataset {args.dataset} not implemented yet.") 540 | 541 | # Split train/eval, reduce size 542 | if args.do_eval or args.do_predict: 543 | if 'eval' in dataset: 544 | eval_dataset = dataset['eval'] 545 | else: 546 | print('Splitting train dataset in train and validation according to `eval_dataset_size`') 547 | dataset = dataset["train"].train_test_split( 548 | test_size=args.eval_dataset_size, shuffle=True, seed=42 549 | ) 550 | eval_dataset = dataset['test'] 551 | if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples: 552 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 553 | if args.group_by_length: 554 | eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) 555 | if args.do_train: 556 | train_dataset = dataset['train'] 557 | if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples: 558 | train_dataset = train_dataset.select(range(args.max_train_samples)) 559 | if args.group_by_length: 560 | train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) 561 | 562 | data_collator = DataCollatorForCausalLM( 563 | tokenizer=tokenizer, 564 | source_max_len=args.source_max_len, 565 | target_max_len=args.target_max_len, 566 | train_on_source=args.train_on_source, 567 | predict_with_generate=args.predict_with_generate, 568 | ) 569 | return dict( 570 | train_dataset=train_dataset if args.do_train else None, 571 | eval_dataset=eval_dataset if args.do_eval else None, 572 | predict_dataset=eval_dataset if args.do_predict else None, 573 | data_collator=data_collator 574 | ) 575 | 576 | def get_last_checkpoint(checkpoint_dir): 577 | if isdir(checkpoint_dir): 578 | is_completed = exists(join(checkpoint_dir, 'completed')) 579 | if is_completed: return None, True # already finished 580 | max_step = 0 581 | for filename in os.listdir(checkpoint_dir): 582 | if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'): 583 | max_step = max(max_step, int(filename.replace('checkpoint-', ''))) 584 | if max_step == 0: 
return None, is_completed # training started, but no checkpoint 585 | checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}') 586 | print(f"Found a previous checkpoint at: {checkpoint_dir}") 587 | return checkpoint_dir, is_completed # checkpoint found! 588 | return None, False # first training 589 | 590 | def train(): 591 | hfparser = transformers.HfArgumentParser(( 592 | ModelArguments, DataArguments, TrainingArguments, GenerationArguments 593 | )) 594 | model_args, data_args, training_args, generation_args, extra_args = \ 595 | hfparser.parse_args_into_dataclasses(return_remaining_strings=True) 596 | training_args.generation_config = transformers.GenerationConfig(**vars(generation_args)) 597 | args = argparse.Namespace( 598 | **vars(model_args), **vars(data_args), **vars(training_args) 599 | ) 600 | 601 | 602 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir) 603 | if completed_training: 604 | print('Detected that training was already completed!') 605 | 606 | model = get_accelerate_model(args, checkpoint_dir) 607 | training_args.skip_loading_checkpoint_weights=True 608 | 609 | resume_from_checkpoint = checkpoint_dir 610 | if resume_from_checkpoint: 611 | # Check the available weights and load them 612 | checkpoint_name = os.path.join( 613 | checkpoint_dir, "pytorch_model.bin" 614 | ) # Full checkpoint 615 | if not os.path.exists(checkpoint_name): 616 | checkpoint_path = os.path.join( 617 | checkpoint_dir, "adapter_model" 618 | ) 619 | 620 | checkpoint_name = os.path.join( 621 | checkpoint_path, "adapter_model.bin" 622 | ) # only LoRA model - LoRA config above has to fit 623 | resume_from_checkpoint = ( 624 | False # So the trainer won't try loading its state 625 | ) 626 | # The two files above have a different name depending on how they were saved, but are actually the same. 627 | if os.path.exists(checkpoint_name): 628 | print(f"Restarting from {checkpoint_name}") 629 | adapters_weights = torch.load(checkpoint_name) 630 | set_peft_model_state_dict(model, adapters_weights) 631 | else: 632 | print(f"Checkpoint {checkpoint_name} not found") 633 | 634 | model.config.use_cache = False 635 | print_trainable_parameters(args, model) 636 | print('loaded model') 637 | set_seed(args.seed) 638 | 639 | # Tokenizer 640 | tokenizer = AutoTokenizer.from_pretrained( 641 | args.model_path, 642 | cache_dir=args.cache_dir, 643 | padding_side="right", 644 | use_fast=True, 645 | ) 646 | 647 | if tokenizer.pad_token is None: 648 | smart_tokenizer_and_embedding_resize( 649 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 650 | tokenizer=tokenizer, 651 | model=model, 652 | ) 653 | 654 | if isinstance(tokenizer, LlamaTokenizerFast): 655 | # LLaMA tokenizer may not have correct special tokens set. 656 | # Check and add them if missing to prevent them from being parsed into different tokens. 657 | # Note that these are present in the vocabulary. 658 | # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token. 
659 | tokenizer.add_special_tokens( 660 | { 661 | "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), 662 | "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), 663 | "unk_token": tokenizer.convert_ids_to_tokens(model.config.pad_token_id), 664 | } 665 | ) 666 | 667 | data_module = make_data_module(tokenizer=tokenizer, args=args) 668 | trainer = Seq2SeqTrainer( 669 | model=model, 670 | tokenizer=tokenizer, 671 | args=training_args, 672 | **{k:v for k,v in data_module.items() if k != 'predict_dataset'}, 673 | ) 674 | 675 | # Callbacks 676 | if not args.full_finetune: 677 | trainer.add_callback(SavePeftModelCallback) 678 | if args.do_mmlu_eval: 679 | if args.mmlu_dataset == 'mmlu-zs': 680 | mmlu_dataset = load_dataset("json", data_files={ 681 | 'eval': 'data/mmlu/zero_shot_mmlu_val.json', 682 | 'test': 'data/mmlu/zero_shot_mmlu_test.json', 683 | }) 684 | mmlu_dataset = mmlu_dataset.remove_columns('subject') 685 | # MMLU Five-shot (Eval/Test only) 686 | elif args.mmlu_dataset == 'mmlu' or args.mmlu_dataset == 'mmlu-fs': 687 | mmlu_dataset = load_dataset("json", data_files={ 688 | 'eval': 'data/mmlu/five_shot_mmlu_val.json', 689 | 'test': 'data/mmlu/five_shot_mmlu_test.json', 690 | }) 691 | # mmlu_dataset = mmlu_dataset.remove_columns('subject') 692 | mmlu_dataset = mmlu_dataset[args.mmlu_split] 693 | if args.max_mmlu_samples is not None: 694 | mmlu_dataset = mmlu_dataset.select(range(args.max_mmlu_samples)) 695 | abcd_idx = [ 696 | tokenizer("A", add_special_tokens=False).input_ids[0], 697 | tokenizer("B", add_special_tokens=False).input_ids[0], 698 | tokenizer("C", add_special_tokens=False).input_ids[0], 699 | tokenizer("D", add_special_tokens=False).input_ids[0], 700 | ] 701 | accuracy = evaluate.load("accuracy") 702 | 703 | class MMLUEvalCallback(transformers.TrainerCallback): 704 | def on_evaluate(self, args, state, control, model, **kwargs): 705 | data_loader = trainer.get_eval_dataloader(mmlu_dataset) 706 | source_max_len = trainer.data_collator.source_max_len 707 | trainer.data_collator.source_max_len = args.mmlu_source_max_len 708 | trainer.model.eval() 709 | preds, refs = [], [] 710 | loss_mmlu = 0 711 | for batch in tqdm(data_loader, total=len(data_loader)): 712 | (loss, logits, labels) = trainer.prediction_step(trainer.model,batch,prediction_loss_only=False,) 713 | # There are two tokens, the output, and eos token. 714 | for i, logit in enumerate(logits): 715 | label_non_zero_id = (batch['labels'][i] != -100).nonzero()[0][0] 716 | logit_abcd = logit[label_non_zero_id-1][abcd_idx] 717 | preds.append(torch.argmax(logit_abcd).item()) 718 | labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:,0] 719 | for label in labels.tolist(): 720 | if label in abcd_idx: 721 | refs += [abcd_idx.index(label)] 722 | 723 | loss_mmlu += loss.item() 724 | # Extract results by subject. 
725 | results = {'mmlu_loss':loss_mmlu/len(data_loader)} 726 | subject = mmlu_dataset['subject'] 727 | subjects = {s:{'refs':[], 'preds':[]} for s in set(subject)} 728 | for s,p,r in zip(subject, preds, refs): 729 | subjects[s]['preds'].append(p) 730 | subjects[s]['refs'].append(r) 731 | subject_scores = [] 732 | for subject in subjects: 733 | subject_score = accuracy.compute( 734 | references=subjects[subject]['refs'], 735 | predictions=subjects[subject]['preds'] 736 | )['accuracy'] 737 | results[f'mmlu_{args.mmlu_split}_accuracy_{subject}'] = subject_score 738 | subject_scores.append(subject_score) 739 | results[f'mmlu_{args.mmlu_split}_accuracy'] = np.mean(subject_scores) 740 | trainer.log(results) 741 | trainer.data_collator.source_max_len = source_max_len 742 | 743 | trainer.add_callback(MMLUEvalCallback) 744 | 745 | # Verifying the datatypes. 746 | dtypes = {} 747 | for _, p in model.named_parameters(): 748 | dtype = p.dtype 749 | if dtype not in dtypes: dtypes[dtype] = 0 750 | dtypes[dtype] += p.numel() 751 | total = 0 752 | for k, v in dtypes.items(): total+= v 753 | for k, v in dtypes.items(): 754 | print(k, v, v/total) 755 | 756 | all_metrics = {"run_name": args.run_name} 757 | # Training 758 | if args.do_train: 759 | train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint) 760 | metrics = train_result.metrics 761 | trainer.log_metrics("train", metrics) 762 | trainer.save_metrics("train", metrics) 763 | trainer.save_state() 764 | all_metrics.update(metrics) 765 | # Evaluation 766 | if args.do_eval: 767 | logger.info("*** Evaluate ***") 768 | metrics = trainer.evaluate(metric_key_prefix="eval") 769 | trainer.log_metrics("eval", metrics) 770 | trainer.save_metrics("eval", metrics) 771 | all_metrics.update(metrics) 772 | # Prediction 773 | if args.do_predict: 774 | logger.info("*** Predict ***") 775 | prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'],metric_key_prefix="predict") 776 | prediction_metrics = prediction_output.metrics 777 | predictions = prediction_output.predictions 778 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id) 779 | predictions = tokenizer.batch_decode( 780 | predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 781 | ) 782 | with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w') as fout: 783 | for i, example in enumerate(data_module['predict_dataset']): 784 | example['prediction_with_input'] = predictions[i].strip() 785 | example['prediction'] = predictions[i].replace(example['input'], '').strip() 786 | fout.write(json.dumps(example) + '\n') 787 | print(prediction_metrics) 788 | trainer.log_metrics("predict", prediction_metrics) 789 | trainer.save_metrics("predict", prediction_metrics) 790 | all_metrics.update(prediction_metrics) 791 | 792 | if (args.do_train or args.do_eval or args.do_predict): 793 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout: 794 | fout.write(json.dumps(all_metrics)) 795 | 796 | if __name__ == "__main__": 797 | train() 798 | --------------------------------------------------------------------------------