├── image
│   └── qalora.png
├── requirements.txt
├── merge.py
├── LICENSE
├── environment.yaml
├── README.md
├── peft_utils.py
└── qalora.py
/image/qalora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuhuixu1993/qa-lora/HEAD/image/qalora.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bert-score==0.3.13
2 | evaluate==0.4.0
3 | rouge-score==0.1.2
4 | scikit-learn==1.2.2
5 | sentencepiece==0.1.99
6 | wandb==0.15.2
7 | transformers==4.31.0
8 | peft==0.4.0
9 | accelerate==0.21.0
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | model_path = 'path of the quantized model'
3 | lora_path = 'path of the saved LoRA adapters'
4 | merged_path = 'target path of the merged model'
5 | scale = 16 / 64  # LoRA scaling factor: lora_alpha / lora_r (defaults 16 and 64 in qalora.py)
6 | group_size = 32  # must match the group size used when quantizing the base model
7 |
8 | model = torch.load(model_path, map_location='cpu')
9 | lora = torch.load(lora_path, map_location='cpu')
10 | tmp_keys = [key[17:-14] for key in lora.keys() if 'lora_A' in key]  # strip the 'base_model.model.' prefix and '.lora_A.weight' suffix
11 | for tmp_key in tmp_keys:  # fold each LoRA pair into the layer's quantization zero-points
12 | model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / group_size /model[tmp_key+'.scales']
13 |
14 | torch.save(model, merged_path)
15 |
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 YuhuiXu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: alpaca
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.05.30=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_0
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=3.0.9=h7f8727e_0
15 | - pip=23.1.2=py38h06a4308_0
16 | - python=3.8.16=h955ad1f_4
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=67.8.0=py38h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - wheel=0.38.4=py38h06a4308_0
22 | - xz=5.4.2=h5eee18b_0
23 | - zlib=1.2.13=h5eee18b_0
24 | - pip:
25 | - absl-py==1.2.0
26 | - accelerate==0.21.0.dev0
27 | - addict==2.4.0
28 | - aiohttp==3.8.4
29 | - aiosignal==1.3.1
30 | - appdirs==1.4.4
31 | - async-timeout==4.0.2
32 | - attrs==22.2.0
33 | - auto-gptq==0.3.0.dev0
34 | - bert-score==0.3.13
35 | - bitsandbytes==0.39.0
36 | - certifi==2022.9.24
37 | - charset-normalizer==3.0.1
38 | - click==8.1.3
39 | - cmake==3.26.3
40 | - colorama==0.4.5
41 | - contourpy==1.0.6
42 | - cycler==0.11.0
43 | - datasets==2.12.0
44 | - dill==0.3.5.1
45 | - evaluate==0.4.0
46 | - fonttools==4.39.4
47 | - frozenlist==1.3.3
48 | - fsspec==2022.11.0
49 | - gitdb==4.0.10
50 | - gitpython==3.1.31
51 | - huggingface-hub==0.14.1
52 | - importlib-resources==5.12.0
53 | - jinja2==3.1.2
54 | - joblib==1.2.0
55 | - lazy-import==0.2.2
56 | - lit==16.0.5
57 | - lxml==4.8.0
58 | - markupsafe==2.1.2
59 | - matplotlib==3.7.1
60 | - mpmath==1.2.1
61 | - multidict==6.0.2
62 | - multiprocess==0.70.12.2
63 | - networkx==2.8.8
64 | - ninja==1.11.1
65 | - nltk==3.8.1
66 | - numpy==1.24.2
67 | - packaging==23.1
68 | - pandas==2.0.0
69 | - pathlib2==2.3.7.post1
70 | - pathtools==0.1.2
71 | - peft==0.4.0.dev0
72 | - pillow==9.3.0
73 | - protobuf==3.20.2
74 | - psutil==5.9.4
75 | - pyarrow==10.0.1
76 | - pyparsing==3.0.9
77 | - python-dateutil==2.8.2
78 | - pytz==2022.6
79 | - pyyaml==6.0
80 | - requests==2.31.0
81 | - responses==0.18.0
82 | - rouge==1.0.0
83 | - rouge-score==0.1.2
84 | - safetensors==0.3.1
85 | - scikit-learn==1.2.2
86 | - scipy==1.10.1
87 | - sentencepiece==0.1.99
88 | - sentry-sdk==1.24.0
89 | - setproctitle==1.3.2
90 | - six==1.16.0
91 | - smmap==5.0.0
92 | - sympy==1.11.1
93 | - terminaltables==3.1.10
94 | - threadpoolctl==3.1.0
95 | - tokenizers==0.13.3
96 | - torch==2.0.0+cu117
97 | - tqdm==4.63.1
98 | - transformers==4.31.0.dev0
99 | - triton==2.0.0
100 | - typing-extensions==4.6.2
101 | - tzdata==2022.7
102 | - urllib3==1.26.16
103 | - wandb==0.15.2
104 | - xformers==0.0.20
105 | - xxhash==3.0.0
106 | - yapf==0.32.0
107 | - yarl==1.8.1
108 | - zipp==3.15.0
109 |
110 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QA-LoRA
2 |
3 | QA-LoRA has been accepted by ICLR 2024!
4 |
5 | This repository provides the official PyTorch implementation of [QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/pdf/2309.14717.pdf).
6 |
7 |
8 | ![QA-LoRA](image/qalora.png)
9 |
10 |
11 | QA-LoRA is easily implemented with a few lines of code, and it equips the original LoRA with two-fold abilities: (i) during fine-tuning, the LLM's weights are quantized (e.g., into INT4) to reduce time and memory usage; (ii) after fine-tuning, the LLM and auxiliary weights are naturally integrated into a quantized model without loss of accuracy.
12 |
13 | ## Todo list
14 | Fix the incompatibility with the newest AutoGPTQ version.
15 |
16 | ## Installation
17 | ```bash
18 | conda create -n qalora python=3.8
19 | conda activate qalora
20 | conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia
21 | git clone -b v0.3.0 https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
22 | pip install .
23 | cd ..
24 | pip install bitsandbytes
25 | pip install -r requirements.txt
26 | pip install protobuf==3.20.*
27 | ```
28 | Replace the `peft_utils.py` in your AutoGPTQ installation (python path/auto_gptq/utils/peft_utils.py) with the one from this repo.
29 | For users of [GPTQLORA](https://github.com/qwopqwop200/gptqlora), you only need to replace the `peft_utils.py` file.
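If you would rather locate the installed package programmatically, a minimal sketch (assuming `auto_gptq` is importable in the current environment and that you run it from this repo's root; back up the original file first) is:

```python
# Minimal sketch (not part of the repo): overwrite the installed auto_gptq
# peft_utils.py with this repo's version.
import os
import shutil

import auto_gptq

target = os.path.join(os.path.dirname(auto_gptq.__file__), "utils", "peft_utils.py")
shutil.copy("peft_utils.py", target)
print(f"replaced {target}")
```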
30 |
31 |
32 | ## Quantization
33 | We use [GPTQ](https://github.com/qwopqwop200/GPTQ-for-LLaMa) for quantization.
34 | Settings: `bits=4`, `group-size=32`, `act-order=False`.
35 | If you change the group size, you need to change `group_size` in `peft_utils.py` and `merge.py` accordingly.
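If you prefer AutoGPTQ's own quantization API over GPTQ-for-LLaMa, a rough sketch with the same settings might look like the following (paths and calibration data are placeholders you must supply):

```python
# Rough sketch (assumption: AutoGPTQ's quantization API; paths and calibration
# text below are placeholders, not files from this repo).
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

pretrained_dir = "path/to/llama-7b"          # placeholder
quantized_dir = "path/to/llama7b-4bit-32g"   # placeholder

tokenizer = AutoTokenizer.from_pretrained(pretrained_dir, use_fast=True)
# Use a real calibration set here; a single sentence is only for illustration.
examples = [tokenizer("Example calibration text.", return_tensors="pt")]

quantize_config = BaseQuantizeConfig(bits=4, group_size=32, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(pretrained_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_dir)
tokenizer.save_pretrained(quantized_dir)
```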
36 |
37 | ## Training
38 | ```bash
39 | python qalora.py --model_path <path_to_quantized_model>
40 | ```
41 |
42 | The file structure of the model checkpoint is as follows:
43 | ```
44 | config.json llama7b-4bit-32g.bin special_tokens_map.json tokenizer_config.json
45 | generation_config.json quantize_config.json tokenizer.model
46 | ```
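For reference, `qalora.py` (in `get_accelerate_model`) loads a checkpoint with this layout roughly as follows; fused attention/MLP injection stays disabled so that the LoRA modules can be attached:

```python
# Condensed from get_accelerate_model() in qalora.py: load the quantized base model
# in trainable mode with fused-kernel injection disabled so LoRA modules can be attached.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "path/to/llama7b-4bit-32g",   # directory containing the files listed above
    device_map="auto",
    inject_fused_attention=False,
    inject_fused_mlp=False,
    use_triton=True,
    warmup_triton=False,
    trainable=True,
)
```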
47 |
48 | ## Merge
49 | Note that our trained LoRA modules can be perfectly merged into the quantized model. We provide a simple merge script (`merge.py`) in this repo.
50 |
51 | ## Notice
52 | ### About the implementations
53 | There are two possible implementations of the dimension reduction (pooling x from D_in down to D_in // group_size). Both are mathematically equivalent.
54 | #### The first one (this repo)
55 | It adopts an average-pooling operation, so the adapter weights have to be divided by the group size during the merge (refer to `merge.py`).
56 | ```python
57 | adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x)))) * scale).type_as(result)
58 | model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / group_size / model[tmp_key+'.scales']
59 | ```
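As a sanity check of that division, the following standalone snippet (with illustrative shapes and names, assuming the usual GPTQ dequantization `w = scales * (q - zeros)` on unpacked values) verifies that the merge reproduces the dequantized weight plus the adapter update:

```python
# Standalone check (illustrative shapes/names): folding the pooled LoRA update into
# the quantization zero-points, as merge.py does, reproduces the dequantized weight
# plus the adapter update.
import torch

group_size, n_groups, d_out, r = 32, 4, 8, 2
d_in = group_size * n_groups
q = torch.randint(0, 16, (d_in, d_out)).float()        # unpacked 4-bit codes
scales = torch.rand(n_groups, d_out) + 0.5
zeros = torch.randint(0, 16, (n_groups, d_out)).float()
A, B = torch.randn(r, n_groups), torch.randn(d_out, r)
scale = 16 / 64                                         # lora_alpha / lora_r

def expand(z):  # broadcast a per-group tensor to one row per input feature
    return z.repeat_interleave(group_size, dim=0)

w = expand(scales) * (q - expand(zeros))                       # dequantized weight
w_target = w + expand((B @ A).t()) * scale / group_size        # + pooled adapter update

zeros_merged = zeros - (B @ A).t() * scale / group_size / scales  # merge.py's update
w_merged = expand(scales) * (q - expand(zeros_merged))
torch.testing.assert_close(w_merged, w_target)
```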
60 | #### The second one
61 | It utilizes a sum operation; the adapter weights do not need to be divided during the merge.
62 | 
63 | ```python
64 | adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x) * group_size))) * scale).type_as(result)
65 | model[tmp_key+'.qzeros'] -= (lora['base_model.model.'+tmp_key+'.lora_B.weight'] @ lora['base_model.model.'+tmp_key+'.lora_A.weight']).t() * scale / model[tmp_key+'.scales']
66 | ```
67 |
68 | ### About the quantization
69 |
70 | Some GPTQ implementations such as [GPTQ-for-llama](https://github.com/qwopqwop200/GPTQ-for-LLaMa) further compress the zeros into packed `qzeros`. You need to decode the `qzeros` first and restore fp16-format zeros before merging.
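A rough sketch of that decoding step, assuming the common 4-bit packing used by those implementations (eight 4-bit values per int32, stored with a -1 offset; the helper name is only illustrative):

```python
# Illustrative helper (not from this repo): unpack int32 qzeros of shape
# [n_groups, out_features // 8] into fp16 zero-points of shape [n_groups, out_features].
# Assumes 4-bit quantization and the -1 storage offset used by that packing scheme.
import torch

def unpack_qzeros_4bit(qzeros: torch.Tensor) -> torch.Tensor:
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qzeros.device)
    z = (qzeros.unsqueeze(-1) >> shifts) & 0xF   # [n_groups, out_features // 8, 8]
    z = z.reshape(qzeros.shape[0], -1)           # [n_groups, out_features]
    return (z + 1).to(torch.float16)             # undo the -1 applied when packing
```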
71 | ## Acknowledgements
72 | Our code is based on [QLoRA](https://github.com/artidoro/qlora), [GPTQLORA](https://github.com/qwopqwop200/gptqlora), and [Auto-GPTQ](https://github.com/PanQiWei/AutoGPTQ/tree/main).
73 |
--------------------------------------------------------------------------------
/peft_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import re
3 | from contextlib import contextmanager
4 | from dataclasses import asdict
5 | from enum import Enum
6 | from typing import List, Optional
7 |
8 | import torch
9 | from peft import get_peft_model, PeftConfig, PeftModel, PeftType
10 | from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
11 | from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel, Embedding
12 | from peft.tuners.adalora import AdaLoraConfig, AdaLoraLayer, AdaLoraModel
13 | from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
14 | from peft.utils.other import _get_submodules
15 |
16 | from ..modeling._base import BaseGPTQForCausalLM
17 |
18 |
19 | group_size = 32 # quantization group_size
20 |
21 |
22 | class GPTQLoraConfig(LoraConfig):
23 | injected_fused_attention: bool = False
24 | injected_fused_mlp: bool = False
25 |
26 |
27 | class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
28 | def __init__(
29 | self,
30 | adapter_name: str,
31 | linear_module: torch.nn.Linear,
32 | r: int = 0,
33 | lora_alpha: int = 1,
34 | lora_dropout: float = 0.0,
35 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
36 | **kwargs,
37 | ):
38 | init_lora_weights = kwargs.pop("init_lora_weights", True)
39 |
40 | torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
41 |         LoraLayer.__init__(self, linear_module.in_features // group_size, linear_module.out_features)  # QA-LoRA: lora_A operates on the group-pooled input (in_features // group_size)
42 |
43 | self.linear_module = linear_module
44 |
45 | self.weight.requires_grad = False
46 | self.weight = self.linear_module.weight
47 | self.bias = self.linear_module.bias
48 | self.fan_in_fan_out = fan_in_fan_out
49 | if fan_in_fan_out:
50 | self.weight.data = self.weight.data.T
51 |
52 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
53 | self.active_adapter = adapter_name
54 |         self.qa_pool = torch.nn.AvgPool1d(group_size)  # average x within each quantization group; the 1/group_size factor is compensated at merge time (see merge.py)
55 |
56 | def reset_lora_parameters(self, adapter_name):
57 | if adapter_name in self.lora_A.keys():
58 | torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
59 | torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
60 |
61 | def merge(self):
62 | raise NotImplementedError("gptq model not support merge lora adapter")
63 |
64 | def unmerge(self):
65 | raise NotImplementedError("gptq model not support unmerge lora adapter")
66 |
67 | def forward(self, x: torch.Tensor):
68 | previous_dtype = x.dtype
69 | if self.active_adapter not in self.lora_A.keys():
70 | return self.linear_module(x)
71 | if self.disable_adapters:
72 | if self.r[self.active_adapter] > 0 and self.merged:
73 | self.unmerge()
74 | result = self.linear_module(x)
75 | elif self.r[self.active_adapter] > 0 and not self.merged:
76 | result = self.linear_module(x)
77 |
78 | lora_B = self.lora_B[self.active_adapter]
79 | lora_A = self.lora_A[self.active_adapter]
80 | lora_dropout = self.lora_dropout[self.active_adapter]
81 | scale = self.scaling[self.active_adapter]
82 |
83 | x = x.type_as(lora_A.weight.data)
84 | adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x)))) * scale).type_as(result)
85 | result += adapter_result
86 | else:
87 | result = self.linear_module(x)
88 |
89 | result = result.to(previous_dtype)
90 |
91 | return result
92 |
93 |
94 | class GPTQLoraModel(LoraModel):
95 | def _find_and_replace(self, adapter_name):
96 | lora_config = self.peft_config[adapter_name]
97 | is_target_modules_in_base_model = False
98 | kwargs = {
99 | "r": lora_config.r,
100 | "lora_alpha": lora_config.lora_alpha,
101 | "lora_dropout": lora_config.lora_dropout,
102 | "fan_in_fan_out": lora_config.fan_in_fan_out,
103 | "init_lora_weights": lora_config.init_lora_weights,
104 | }
105 | key_list = [key for key, _ in self.model.named_modules()]
106 | for key in key_list:
107 | if isinstance(lora_config.target_modules, str):
108 | target_module_found = re.fullmatch(lora_config.target_modules, key)
109 | else:
110 | target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
111 | if target_module_found:
112 | if not is_target_modules_in_base_model:
113 | is_target_modules_in_base_model = True
114 | parent, target, target_name = _get_submodules(self.model, key)
115 | bias = False
116 | if hasattr(target, "bias"):
117 | bias = target.bias is not None
118 |
119 | if isinstance(target, LoraLayer):
120 | target.update_layer(
121 | adapter_name,
122 | lora_config.r,
123 | lora_config.lora_alpha,
124 | lora_config.lora_dropout,
125 | lora_config.init_lora_weights,
126 | )
127 | else:
128 | if isinstance(target, torch.nn.Embedding):
129 | embedding_kwargs = kwargs.copy()
130 | embedding_kwargs.pop("fan_in_fan_out", None)
131 | in_features, out_features = target.num_embeddings, target.embedding_dim
132 | new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
133 | else:
134 | if isinstance(target, torch.nn.Linear):
135 | if kwargs["fan_in_fan_out"]:
136 | warnings.warn(
137 | "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
138 | "Setting fan_in_fan_out to False."
139 | )
140 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
141 | else:
142 | raise ValueError(
143 | f"Target module {target} is not supported. "
144 | f"Currently, only `torch.nn.Linear` and its subclasses are supported."
145 | )
146 | new_module = GPTQLoraLinear(adapter_name, target, **kwargs)
147 |
148 | self._replace_module(parent, target_name, new_module, target)
149 | if not is_target_modules_in_base_model:
150 | raise ValueError(
151 | f"Target modules {lora_config.target_modules} not found in the base model. "
152 | f"Please check the target modules and try again."
153 | )
154 |
155 | def _replace_module(self, parent_module, child_name, new_module, old_module):
156 | setattr(parent_module, child_name, new_module)
157 | if not isinstance(new_module, GPTQLoraLinear):
158 | new_module.weight = old_module.weight
159 | if hasattr(old_module, "bias"):
160 | if old_module.bias is not None:
161 | new_module.bias = old_module.bias
162 |
163 | if getattr(old_module, "state", None) is not None:
164 | new_module.state = old_module.state
165 | new_module.to(old_module.weight.device)
166 |
167 | # dispatch to correct device
168 | for name, module in new_module.named_modules():
169 | if "lora_" in name:
170 | module.to(old_module.weight.device)
171 |
172 | def merge_adapter(self):
173 | raise NotImplementedError("gptq model not support merge ada lora adapter")
174 |
175 | def unmerge_adapter(self):
176 | raise NotImplementedError("gptq model not support unmerge ada lora adapter")
177 |
178 | def merge_and_unload(self):
179 | raise NotImplementedError("gptq model not support merge and unload")
180 |
181 |
182 | class GPTQAdaLoraConfig(AdaLoraConfig):
183 | injected_fused_attention: bool = False
184 | injected_fused_mlp: bool = False
185 |
186 |
187 | class GPTQSVDLinear(torch.nn.Linear, AdaLoraLayer):
188 | def __init__(
189 | self,
190 | adapter_name: str,
191 | linear_module: torch.nn.Linear,
192 | r: int = 0,
193 | lora_alpha: int = 1,
194 | lora_dropout: float = 0.0,
195 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
196 | **kwargs,
197 | ):
198 | init_lora_weights = kwargs.pop("init_lora_weights", True)
199 |
200 | torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
201 | AdaLoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)
202 |
203 | self.linear_module = linear_module
204 |
205 | self.weight.requires_grad = False
206 | self.weight = self.linear_module.weight
207 | self.bias = self.linear_module.bias
208 | self.fan_in_fan_out = fan_in_fan_out
209 | if fan_in_fan_out:
210 | self.weight.data = self.weight.data.T
211 |
212 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
213 | self.active_adapter = adapter_name
214 |
215 | def merge(self):
216 | raise NotImplementedError("gptq model not support merge lora adapter")
217 |
218 | def unmerge(self):
219 | raise NotImplementedError("gptq model not support unmerge lora adapter")
220 |
221 | def forward(self, x: torch.Tensor):
222 | if self.active_adapter not in self.lora_A.keys():
223 | return self.linear_module(x)
224 | if self.disable_adapters:
225 | if self.r[self.active_adapter] > 0 and self.merged:
226 | self.unmerge()
227 | result = self.linear_module(x)
228 | elif self.r[self.active_adapter] > 0 and not self.merged:
229 | result = self.linear_module(x)
230 | result += (
231 | (
232 | self.lora_dropout[self.active_adapter](x)
233 | @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
234 | @ self.lora_B[self.active_adapter].T
235 | )
236 | * self.scaling[self.active_adapter]
237 | / (self.ranknum[self.active_adapter] + 1e-5)
238 | )
239 | else:
240 | result = self.linear_module(x)
241 | return result
242 |
243 |
244 | class GPTQAdaLoraModel(AdaLoraModel):
245 | def _find_and_replace(self, adapter_name):
246 | lora_config = self.peft_config[adapter_name]
247 | is_target_modules_in_base_model = False
248 | kwargs = {
249 | "r": lora_config.init_r,
250 | "lora_alpha": lora_config.lora_alpha,
251 | "lora_dropout": lora_config.lora_dropout,
252 | "fan_in_fan_out": lora_config.fan_in_fan_out,
253 | "init_lora_weights": lora_config.init_lora_weights,
254 | }
255 | key_list = [key for key, _ in self.model.named_modules()]
256 | for key in key_list:
257 | if isinstance(lora_config.target_modules, str):
258 | target_module_found = re.fullmatch(lora_config.target_modules, key)
259 | else:
260 | target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
261 | if target_module_found:
262 | if not is_target_modules_in_base_model:
263 | is_target_modules_in_base_model = True
264 | parent, target, target_name = _get_submodules(self.model, key)
265 | bias = target.bias is not None
266 | if isinstance(target, LoraLayer):
267 | target.update_layer(
268 | adapter_name,
269 | lora_config.init_r,
270 | lora_config.lora_alpha,
271 | lora_config.lora_dropout,
272 | lora_config.init_lora_weights,
273 | )
274 | else:
275 | if isinstance(target, torch.nn.Linear):
276 | in_features, out_features = target.in_features, target.out_features
277 | if kwargs["fan_in_fan_out"]:
278 | warnings.warn(
279 | "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
280 | "Setting fan_in_fan_out to False."
281 | )
282 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
283 | else:
284 | raise ValueError(
285 | f"Target module {target} is not supported. "
286 | f"Currently, only `torch.nn.Linear` and its subclasses are supported."
287 | )
288 | new_module = GPTQSVDLinear(adapter_name, target, **kwargs)
289 |
290 | self._replace_module(parent, target_name, new_module, target)
291 | if not is_target_modules_in_base_model:
292 | raise ValueError(
293 | f"Target modules {lora_config.target_modules} not found in the base model. "
294 | f"Please check the target modules and try again."
295 | )
296 |
297 | def _replace_module(self, parent_module, child_name, new_module, old_module):
298 | setattr(parent_module, child_name, new_module)
299 |
300 | # dispatch to correct device
301 | for name, module in new_module.named_modules():
302 | if "lora_" in name:
303 | module.to(old_module.weight.device)
304 |
305 | def merge_adapter(self):
306 | raise NotImplementedError("gptq model not support merge ada lora adapter")
307 |
308 | def unmerge_adapter(self):
309 | raise NotImplementedError("gptq model not support unmerge ada lora adapter")
310 |
311 | def merge_and_unload(self):
312 | raise NotImplementedError("gptq model not support merge and unload")
313 |
314 |
315 | def find_all_linear_names(model: BaseGPTQForCausalLM, ignore: Optional[List[str]] = None, ignore_lm_head: bool = True):
316 | if not ignore:
317 | ignore = []
318 | lm_head_name = model.lm_head_name
319 | if ignore_lm_head and lm_head_name not in ignore:
320 | ignore.append(lm_head_name)
321 | results = set()
322 | for n, m in model.named_modules():
323 | if isinstance(m, torch.nn.Linear):
324 | res = n.split('.')[-1]
325 | if res not in ignore:
326 | results.add(res)
327 | return list(results)
328 |
329 |
330 | @contextmanager
331 | def hijack_peft_mappings():
332 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
333 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
334 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
335 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
336 |
337 | try:
338 | yield
339 | except:
340 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
341 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
342 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
343 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
344 | raise
345 | finally:
346 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
347 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
348 | PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADALORA] = GPTQAdaLoraConfig
349 | PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADALORA] = GPTQAdaLoraModel
350 |
351 |
352 | def get_gptq_peft_model(
353 | model: BaseGPTQForCausalLM,
354 | peft_config: PeftConfig = None,
355 | model_id: str = None,
356 | adapter_name: str = "default",
357 | auto_find_all_linears: bool = True,
358 | train_mode: bool = False
359 | ):
360 | if train_mode and not model.trainable:
361 | model.enable_trainable_mode()
362 | if train_mode and not peft_config:
363 | raise ValueError("peft_config not specified when in train mode.")
364 | if not train_mode and not model_id:
365 | raise ValueError("model_id(where to load adapters) not specified when in inference mode.")
366 |
367 | if model.fused_attn_module_type is not None and not model.injected_fused_attention:
368 | peft_types = [PeftType.LORA.value, PeftType.ADALORA.value]
369 | warnings.warn(
370 | f"You can just ignore this warning if the peft type you use isn't in {peft_types}.\n"
371 | f"{model.__class__.__name__} supports injecting fused attention but not enables this time. "
372 | "If you are training adapters, you must also disable fused attention injection when loading quantized "
373 | "base model at inference time, otherwise adapters may not be added to base model properly. "
374 | "If you are loading adapters to do inference, you can reference to adapter's config file to check "
375 | "whether the adapters are trained using base model that not enable fused attention injection."
376 | )
377 | if model.injected_fused_mlp:
378 | raise NotImplementedError("GPTQ model that enables fused mlp injection is not supported to integrate with peft.")
379 |
380 | if train_mode:
381 | peft_type = peft_config.peft_type
382 | if not isinstance(peft_type, str):
383 | peft_type = peft_type.value
384 | if peft_type in [PeftType.LORA.value, PeftType.ADALORA.value]:
385 | if auto_find_all_linears:
386 | peft_config.target_modules = find_all_linear_names(model, ignore_lm_head=True)
387 | if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig):
388 | peft_config = GPTQLoraConfig(**peft_config.to_dict())
389 | if peft_type == PeftType.ADALORA.value and not isinstance(peft_config, GPTQAdaLoraConfig):
390 | peft_config = GPTQAdaLoraConfig(**peft_config.to_dict())
391 | peft_config.injected_fused_attention = model.injected_fused_attention
392 | peft_config.injected_fused_mlp = model.injected_fused_mlp
393 | if peft_type == PeftType.ADAPTION_PROMPT.value:
394 | if peft_config.adapter_layers > model.config.num_hidden_layers:
395 | warnings.warn(
396 | f"model has only {model.config.num_hidden_layers} layers "
397 | f"but adapter_layers is set to {peft_config.adapter_layers}, "
398 | f"will reset value to {model.config.num_hidden_layers}."
399 | )
400 | peft_config.adapter_layers = model.config.num_hidden_layers
401 | if model.injected_fused_attention:
402 | raise NotImplementedError(
403 | "model with fused attention injected isn't supported to use ADAPTION_PROMPT peft type yet."
404 | )
405 |
406 | with hijack_peft_mappings():
407 | try:
408 | if train_mode:
409 | peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
410 | else:
411 | peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
412 | except:
413 | raise NotImplementedError(
414 | f"{model.__class__.__name__} not support {peft_config.peft_type.value} peft type yet."
415 | )
416 |
417 | return peft_model
418 |
419 |
420 | __all__ = [
421 | "GPTQLoraConfig",
422 | "GPTQLoraModel",
423 | "GPTQAdaLoraConfig",
424 | "GPTQAdaLoraModel",
425 | "find_all_linear_names",
426 | "get_gptq_peft_model"
427 | ]
428 |
--------------------------------------------------------------------------------
/qalora.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the MIT license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | from collections import defaultdict
5 | import copy
6 | import json
7 | import os
8 | from os.path import exists, join, isdir
9 | from dataclasses import dataclass, field
10 | import sys
11 | from typing import Optional, Dict, Sequence
12 | import numpy as np
13 | from tqdm import tqdm
14 | import logging
15 |
16 | import torch
17 | import transformers
18 | from torch.nn.utils.rnn import pad_sequence
19 | import argparse
20 | from transformers import (
21 | AutoTokenizer,
22 | AutoModelForCausalLM,
23 | set_seed,
24 | Seq2SeqTrainer,
25 | LlamaTokenizerFast
26 | )
27 | from datasets import load_dataset
28 | import evaluate
29 |
30 | from peft import (
31 | LoraConfig,
32 | get_peft_model_state_dict,
33 | set_peft_model_state_dict,
34 | PeftModel
35 | )
36 | from peft.tuners.lora import LoraLayer
37 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
38 | from auto_gptq.utils.peft_utils import get_gptq_peft_model, GPTQLoraConfig
39 | from auto_gptq import AutoGPTQForCausalLM
40 | from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
41 |
42 | torch.backends.cuda.matmul.allow_tf32 = True
43 |
44 | logger = logging.getLogger(__name__)
45 |
46 | IGNORE_INDEX = -100
47 | DEFAULT_PAD_TOKEN = "[PAD]"
48 |
49 | import os
50 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
51 |
52 | def prepare_model_for_int8_training(model, use_gradient_checkpointing=True):
53 | r"""
54 | This method wraps the entire protocol for preparing a model before running a training. This includes:
55 | 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
56 |     1- casting the layernorm in fp32, 2- making the output embedding layer require grads, 3- upcasting the lm
57 |     head to fp32
58 | Args:
59 | model, (`transformers.PreTrainedModel`):
60 | The loaded model from `transformers`
61 | """
62 | for name, param in model.named_parameters():
63 | # freeze base model's layers
64 | param.requires_grad = False
65 |
66 | if use_gradient_checkpointing:
67 | # For backward compatibility
68 | if hasattr(model, "enable_input_require_grads"):
69 | model.enable_input_require_grads()
70 | else:
71 |
72 | def make_inputs_require_grad(module, input, output):
73 | output.requires_grad_(True)
74 |
75 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
76 |
77 | # enable gradient checkpointing for memory efficiency
78 | model.gradient_checkpointing_enable()
79 |
80 | model.lm_head = model.lm_head.float()
81 | for _, param in model.named_parameters():
82 | if param.dtype == torch.float16:
83 |             param.data = param.data.float()  # cast in place; rebinding the loop variable has no effect
84 |
85 | return model
86 |
87 | @dataclass
88 | class ModelArguments:
89 | model_path: Optional[str] = field(
90 | default="./llama-7b/"
91 | )
92 | trust_remote_code: Optional[bool] = field(
93 | default=False,
94 | metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
95 | )
96 |
97 | @dataclass
98 | class DataArguments:
99 | eval_dataset_size: int = field(
100 | default=1024, metadata={"help": "Size of validation dataset."}
101 | )
102 | max_train_samples: Optional[int] = field(
103 | default=None,
104 | metadata={
105 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
106 | "value if set."
107 | },
108 | )
109 | max_eval_samples: Optional[int] = field(
110 | default=None,
111 | metadata={
112 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
113 | "value if set."
114 | },
115 | )
116 | source_max_len: int = field(
117 | default=1024,
118 | metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
119 | )
120 | target_max_len: int = field(
121 | default=256,
122 | metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
123 | )
124 | dataset: str = field(
125 | default='alpaca',
126 | metadata={"help": "Which dataset to finetune on. See datamodule for options."}
127 | )
128 |
129 | @dataclass
130 | class TrainingArguments(transformers.Seq2SeqTrainingArguments):
131 | cache_dir: Optional[str] = field(
132 | default=None
133 | )
134 | train_on_source: Optional[bool] = field(
135 | default=False,
136 | metadata={"help": "Whether to train on the input in addition to the target text."}
137 | )
138 | mmlu_split: Optional[str] = field(
139 | default='eval',
140 | metadata={"help": "The MMLU split to run on"}
141 | )
142 | mmlu_dataset: Optional[str] = field(
143 | default='mmlu-fs',
144 | metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."}
145 | )
146 | do_mmlu_eval: Optional[bool] = field(
147 | default=False,
148 | metadata={"help": "Whether to run the MMLU evaluation."}
149 | )
150 | max_mmlu_samples: Optional[int] = field(
151 | default=None,
152 |         metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMLU dataset."}
153 | )
154 | mmlu_source_max_len: int = field(
155 | default=2048,
156 | metadata={"help": "Maximum source sequence length for mmlu."}
157 | )
158 | full_finetune: bool = field(
159 | default=False,
160 | metadata={"help": "Finetune the entire model without adapters."}
161 | )
162 | adam8bit: bool = field(
163 | default=False,
164 | metadata={"help": "Use 8-bit adam."}
165 | )
166 | lora_r: int = field(
167 | default=64,
168 | metadata={"help": "Lora R dimension."}
169 | )
170 | lora_alpha: float = field(
171 | default=16,
172 | metadata={"help": " Lora alpha."}
173 | )
174 | lora_dropout: float = field(
175 | default=0.0,
176 | metadata={"help":"Lora dropout."}
177 | )
178 | max_memory_MB: int = field(
179 | default=24000,
180 | metadata={"help": "Free memory per gpu."}
181 | )
182 | report_to: str = field(
183 | default='none',
184 | metadata={"help": "To use wandb or something else for reporting."}
185 | )
186 | output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
187 | optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'})
188 | per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
189 | gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'})
190 | max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'})
191 | weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) # use lora dropout instead for regularization if needed
192 |     learning_rate: float = field(default=0.0002, metadata={"help": 'The learning rate'})
193 |     remove_unused_columns: bool = field(default=False, metadata={"help": 'Remove unused columns. Needed to make this codebase work.'})
194 | max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
195 | gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'})
196 | do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
197 | lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
198 | warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'})
199 | logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
200 | group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
201 | save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'})
202 | save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
203 | save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
204 |
205 | @dataclass
206 | class GenerationArguments:
207 | # For more hyperparameters check:
208 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
209 | # Length arguments
210 | max_new_tokens: Optional[int] = field(
211 | default=256,
212 |         metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops "
213 | "if predict_with_generate is set."}
214 | )
215 | min_new_tokens : Optional[int] = field(
216 | default=None,
217 | metadata={"help": "Minimum number of new tokens to generate."}
218 | )
219 |
220 | # Generation strategy
221 | do_sample: Optional[bool] = field(default=False)
222 | num_beams: Optional[int] = field(default=1)
223 | num_beam_groups: Optional[int] = field(default=1)
224 | penalty_alpha: Optional[float] = field(default=None)
225 | use_cache: Optional[bool] = field(default=True)
226 |
227 | # Hyperparameters for logit manipulation
228 | temperature: Optional[float] = field(default=1.0)
229 | top_k: Optional[int] = field(default=50)
230 | top_p: Optional[float] = field(default=1.0)
231 | typical_p: Optional[float] = field(default=1.0)
232 | diversity_penalty: Optional[float] = field(default=0.0)
233 | repetition_penalty: Optional[float] = field(default=1.0)
234 | length_penalty: Optional[float] = field(default=1.0)
235 | no_repeat_ngram_size: Optional[int] = field(default=0)
236 |
237 | def find_all_linear_names(args, model):
238 | cls = GeneralQuantLinear if not(args.full_finetune) else torch.nn.Linear
239 | lora_module_names = set()
240 | for name, module in model.named_modules():
241 | if isinstance(module, cls):
242 | names = name.split('.')
243 | lora_module_names.add(names[0] if len(names) == 1 else names[-1])
244 |
245 |
246 | if 'lm_head' in lora_module_names: # needed for 16-bit
247 | lora_module_names.remove('lm_head')
248 | return list(lora_module_names)
249 |
250 |
251 | class SavePeftModelCallback(transformers.TrainerCallback):
252 | def save_model(self, args, state, kwargs):
253 | print('Saving PEFT checkpoint...')
254 | if state.best_model_checkpoint is not None:
255 | checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
256 | else:
257 | checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
258 |
259 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
260 | kwargs["model"].save_pretrained(peft_model_path)
261 |
262 | pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
263 | if os.path.exists(pytorch_model_path):
264 | os.remove(pytorch_model_path)
265 |
266 | def on_save(self, args, state, control, **kwargs):
267 | self.save_model(args, state, kwargs)
268 | return control
269 |
270 | def on_train_end(self, args, state, control, **kwargs):
271 | def touch(fname, times=None):
272 | with open(fname, 'a'):
273 | os.utime(fname, times)
274 |
275 | touch(join(args.output_dir, 'completed'))
276 | self.save_model(args, state, kwargs)
277 |
278 | def get_accelerate_model(args, checkpoint_dir):
279 |
280 | n_gpus = torch.cuda.device_count()
281 | max_memory = f'{args.max_memory_MB}MB'
282 | max_memory = {i: max_memory for i in range(n_gpus)}
283 |
284 | if args.full_finetune: assert args.bits in [16, 32]
285 |
286 | print(f'loading base model {args.model_path}...')
287 | model = AutoGPTQForCausalLM.from_quantized(
288 | args.model_path,
289 | device_map='auto',
290 | max_memory=max_memory,
291 | trust_remote_code=args.trust_remote_code,
292 | inject_fused_attention = False,
293 | inject_fused_mlp = False,
294 | use_triton=True,
295 | warmup_triton=False,
296 | trainable=True
297 | )
298 | model.model.quantize_config = model.quantize_config
299 | model.train()
300 |
301 | setattr(model, 'model_parallel', True)
302 | setattr(model, 'is_parallelizable', True)
303 | #modules = find_all_linear_names(args, model)
304 |
305 | model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
306 |
307 | if not args.full_finetune:
308 | model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
309 | if args.gradient_checkpointing:
310 | model.gradient_checkpointing_enable()
311 |
312 | config = GPTQLoraConfig(
313 | r=args.lora_r,
314 | lora_alpha=args.lora_alpha,
315 | #target_modules=modules,
316 | lora_dropout=args.lora_dropout,
317 | bias="none",
318 | task_type="CAUSAL_LM",
319 | )
320 | if not args.full_finetune:
321 | if checkpoint_dir is not None:
322 | print("Loading adapters from checkpoint.")
323 | model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model'))
324 | for name, p in model.named_parameters():
325 | if 'lora' in name:
326 | print(name, p.sum())
327 | else:
328 | print(f'adding LoRA modules...')
329 | model = get_gptq_peft_model(model, config, auto_find_all_linears=True, train_mode=True)
330 |
331 | if args.gradient_checkpointing:
332 | if hasattr(model, "enable_input_require_grads"):
333 | model.enable_input_require_grads()
334 | else:
335 | def make_inputs_require_grad(module, input, output):
336 | output.requires_grad_(True)
337 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
338 |
339 |
340 | for name, module in model.named_modules():
341 | if isinstance(module, LoraLayer):
342 | if args.bf16:
343 | module = module.to(torch.bfloat16)
344 | if 'norm' in name:
345 | module = module.to(torch.float32)
346 | if 'lm_head' in name or 'embed_tokens' in name:
347 | if hasattr(module, 'weight'):
348 | if args.bf16 and module.weight.dtype == torch.float32:
349 | module = module.to(torch.bfloat16)
350 | return model
351 |
352 | def print_trainable_parameters(args, model):
353 | """
354 | Prints the number of trainable parameters in the model.
355 | """
356 | trainable_params = 0
357 | all_param = 0
358 | for _, param in model.named_parameters():
359 | all_param += param.numel()
360 | if param.requires_grad:
361 | trainable_params += param.numel()
362 | try:
363 | trainable_params /= (32//model.quantize_config.bits)
364 | except:
365 | pass
366 | print(f"trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param}")
367 |
368 | def smart_tokenizer_and_embedding_resize(
369 | special_tokens_dict: Dict,
370 | tokenizer: transformers.PreTrainedTokenizer,
371 | model: transformers.PreTrainedModel,
372 | ):
373 | """Resize tokenizer and embedding.
374 |
375 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
376 | """
377 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
378 | model.resize_token_embeddings(len(tokenizer))
379 |
380 | if num_new_tokens > 0:
381 | input_embeddings = model.get_input_embeddings().weight.data
382 | output_embeddings = model.get_output_embeddings().weight.data
383 |
384 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
385 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
386 |
387 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
388 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
389 |
390 | @dataclass
391 | class DataCollatorForCausalLM(object):
392 | tokenizer: transformers.PreTrainedTokenizer
393 | source_max_len: int
394 | target_max_len: int
395 | train_on_source: bool
396 | predict_with_generate: bool
397 |
398 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
399 | # Extract elements
400 | sources = [example['input'] for example in instances]
401 | targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
402 | # Tokenize
403 | tokenized_sources_with_prompt = self.tokenizer(
404 | sources,
405 | max_length=self.source_max_len,
406 | truncation=True,
407 | )
408 | tokenized_targets = self.tokenizer(
409 | targets,
410 | max_length=self.target_max_len,
411 | truncation=True,
412 | add_special_tokens=False,
413 | )
414 | # Build the input and labels for causal LM
415 | input_ids = []
416 | labels = []
417 | for tokenized_source, tokenized_target in zip(
418 | tokenized_sources_with_prompt['input_ids'],
419 | tokenized_targets['input_ids']
420 | ):
421 | if not self.predict_with_generate:
422 | input_ids.append(torch.tensor(tokenized_source + tokenized_target))
423 | if not self.train_on_source:
424 | labels.append(
425 | torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
426 | )
427 | else:
428 | labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
429 | else:
430 | input_ids.append(torch.tensor(tokenized_source))
431 | # Apply padding
432 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
433 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
434 | data_dict = {
435 | 'input_ids': input_ids,
436 | 'attention_mask':input_ids.ne(self.tokenizer.pad_token_id),
437 | }
438 | if labels is not None:
439 | data_dict['labels'] = labels
440 | return data_dict
441 |
442 | def extract_unnatural_instructions_data(examples, extract_reformulations=False):
443 | out = {
444 | 'input': [],
445 | 'output': [],
446 | }
447 | for example_instances in examples['instances']:
448 | for instance in example_instances:
449 | out['input'].append(instance['instruction_with_input'])
450 | out['output'].append(instance['output'])
451 | if extract_reformulations:
452 | for example_reformulations in examples['reformulations']:
453 | if example_reformulations is not None:
454 | for instance in example_reformulations:
455 | out['input'].append(instance['instruction_with_input'])
456 | out['output'].append(instance['output'])
457 | return out
458 |
459 | PROMPT_DICT = {
460 | "prompt_input": (
461 | "Below is an instruction that describes a task, paired with an input that provides further context. "
462 | "Write a response that appropriately completes the request.\n\n"
463 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
464 | ),
465 | "prompt_no_input": (
466 | "Below is an instruction that describes a task. "
467 | "Write a response that appropriately completes the request.\n\n"
468 | "### Instruction:\n{instruction}\n\n### Response: "
469 | ),
470 | }
471 |
472 | def extract_alpaca_dataset(example):
473 | if example.get("input", "") != "":
474 | prompt_format = PROMPT_DICT["prompt_input"]
475 | else:
476 | prompt_format = PROMPT_DICT["prompt_no_input"]
477 | return {'input': prompt_format.format(**example)}
478 |
479 | def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
480 | """
481 | Make dataset and collator for supervised fine-tuning.
482 | Datasets are expected to have the following columns: { `input`, `output` }
483 |
484 | Available datasets to be selected with `dataset` argument:
485 | - alpaca, 52002 examples
486 | - alpaca cleaned, 51942 examples
487 | - chip2 (OIG), 210289 examples
488 | - self-instruct, 82612 examples
489 | - hh-rlhf (Anthropic), 160800 examples
490 | - longform, 23.7k examples
491 |
492 | Coming soon:
493 | - unnatural instructions core, 66010 examples
494 | - unnatural instructions full, 240670 examples
495 | - alpaca-gpt4, 52002 examples
496 | - unnatural-instructions-gpt4, 9000 examples
497 | - oa-rlhf (OpenAssistant) primary message tree only, 9209 examples
498 | - oa-rlhf-assistant (OpenAssistant) all assistant replies with ranking
499 | - supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
500 | - flan (FLAN v2), up to 20M examples available
501 |
502 | Not Available:
503 | - vicuna, not released at the moment.
504 | """
505 | # Load dataset.
506 | # Alpaca
507 | if args.dataset == 'alpaca':
508 | dataset = load_dataset("tatsu-lab/alpaca")
509 | dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
510 | # Alpaca clean
511 | elif args.dataset == 'alpaca-clean':
512 | dataset = load_dataset("yahma/alpaca-cleaned")
513 | dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
514 | # Chip2
515 | elif args.dataset == 'chip2':
516 | dataset = load_dataset("laion/OIG", data_files='unified_chip2.jsonl')
517 | dataset = dataset.map(lambda x: {
518 | 'input': x['text'].split('\n: ')[0].replace(': ', ''),
519 | 'output': x['text'].split('\n: ')[1],
520 | }, remove_columns=['text', 'metadata'])
521 | # Self Instruct
522 | elif args.dataset == 'self-instruct':
523 | dataset = load_dataset("yizhongw/self_instruct", name='self_instruct')
524 | for old, new in [["prompt", "input"], ["completion", "output"]]:
525 | dataset = dataset.rename_column(old, new)
526 | # Anthropic rlhf
527 | elif args.dataset == 'hh-rlhf':
528 | dataset = load_dataset("Anthropic/hh-rlhf")
529 | dataset = dataset.map(lambda x: {
530 | 'input': '',
531 | 'output': x['chosen']
532 | }, remove_columns=['chosen', 'rejected'])
533 | # LongForm
534 | elif args.dataset == 'longform':
535 | dataset = load_dataset("akoksal/LongForm")
536 | elif args.dataset == 'vicuna':
537 | raise NotImplementedError("Vicuna data was not released.")
538 | else:
539 | raise NotImplementedError(f"Dataset {args.dataset} not implemented yet.")
540 |
541 | # Split train/eval, reduce size
542 | if args.do_eval or args.do_predict:
543 | if 'eval' in dataset:
544 | eval_dataset = dataset['eval']
545 | else:
546 | print('Splitting train dataset in train and validation according to `eval_dataset_size`')
547 | dataset = dataset["train"].train_test_split(
548 | test_size=args.eval_dataset_size, shuffle=True, seed=42
549 | )
550 | eval_dataset = dataset['test']
551 | if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples:
552 | eval_dataset = eval_dataset.select(range(args.max_eval_samples))
553 | if args.group_by_length:
554 | eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
555 | if args.do_train:
556 | train_dataset = dataset['train']
557 | if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
558 | train_dataset = train_dataset.select(range(args.max_train_samples))
559 | if args.group_by_length:
560 | train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
561 |
562 | data_collator = DataCollatorForCausalLM(
563 | tokenizer=tokenizer,
564 | source_max_len=args.source_max_len,
565 | target_max_len=args.target_max_len,
566 | train_on_source=args.train_on_source,
567 | predict_with_generate=args.predict_with_generate,
568 | )
569 | return dict(
570 | train_dataset=train_dataset if args.do_train else None,
571 | eval_dataset=eval_dataset if args.do_eval else None,
572 | predict_dataset=eval_dataset if args.do_predict else None,
573 | data_collator=data_collator
574 | )
575 |
576 | def get_last_checkpoint(checkpoint_dir):
577 | if isdir(checkpoint_dir):
578 | is_completed = exists(join(checkpoint_dir, 'completed'))
579 | if is_completed: return None, True # already finished
580 | max_step = 0
581 | for filename in os.listdir(checkpoint_dir):
582 | if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'):
583 | max_step = max(max_step, int(filename.replace('checkpoint-', '')))
584 | if max_step == 0: return None, is_completed # training started, but no checkpoint
585 | checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
586 | print(f"Found a previous checkpoint at: {checkpoint_dir}")
587 | return checkpoint_dir, is_completed # checkpoint found!
588 | return None, False # first training
589 |
590 | def train():
591 | hfparser = transformers.HfArgumentParser((
592 | ModelArguments, DataArguments, TrainingArguments, GenerationArguments
593 | ))
594 | model_args, data_args, training_args, generation_args, extra_args = \
595 | hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
596 | training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
597 | args = argparse.Namespace(
598 | **vars(model_args), **vars(data_args), **vars(training_args)
599 | )
600 |
601 |
602 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
603 | if completed_training:
604 | print('Detected that training was already completed!')
605 |
606 | model = get_accelerate_model(args, checkpoint_dir)
607 | training_args.skip_loading_checkpoint_weights=True
608 |
609 | resume_from_checkpoint = checkpoint_dir
610 | if resume_from_checkpoint:
611 | # Check the available weights and load them
612 | checkpoint_name = os.path.join(
613 | checkpoint_dir, "pytorch_model.bin"
614 | ) # Full checkpoint
615 | if not os.path.exists(checkpoint_name):
616 | checkpoint_path = os.path.join(
617 | checkpoint_dir, "adapter_model"
618 | )
619 |
620 | checkpoint_name = os.path.join(
621 | checkpoint_path, "adapter_model.bin"
622 | ) # only LoRA model - LoRA config above has to fit
623 | resume_from_checkpoint = (
624 | False # So the trainer won't try loading its state
625 | )
626 | # The two files above have a different name depending on how they were saved, but are actually the same.
627 | if os.path.exists(checkpoint_name):
628 | print(f"Restarting from {checkpoint_name}")
629 | adapters_weights = torch.load(checkpoint_name)
630 | set_peft_model_state_dict(model, adapters_weights)
631 | else:
632 | print(f"Checkpoint {checkpoint_name} not found")
633 |
634 | model.config.use_cache = False
635 | print_trainable_parameters(args, model)
636 | print('loaded model')
637 | set_seed(args.seed)
638 |
639 | # Tokenizer
640 | tokenizer = AutoTokenizer.from_pretrained(
641 | args.model_path,
642 | cache_dir=args.cache_dir,
643 | padding_side="right",
644 | use_fast=True,
645 | )
646 |
647 | if tokenizer.pad_token is None:
648 | smart_tokenizer_and_embedding_resize(
649 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
650 | tokenizer=tokenizer,
651 | model=model,
652 | )
653 |
654 | if isinstance(tokenizer, LlamaTokenizerFast):
655 | # LLaMA tokenizer may not have correct special tokens set.
656 | # Check and add them if missing to prevent them from being parsed into different tokens.
657 | # Note that these are present in the vocabulary.
658 |         # Note also that `model.config.pad_token_id` is 0, which corresponds to the `<unk>` token.
659 | tokenizer.add_special_tokens(
660 | {
661 | "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
662 | "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
663 | "unk_token": tokenizer.convert_ids_to_tokens(model.config.pad_token_id),
664 | }
665 | )
666 |
667 | data_module = make_data_module(tokenizer=tokenizer, args=args)
668 | trainer = Seq2SeqTrainer(
669 | model=model,
670 | tokenizer=tokenizer,
671 | args=training_args,
672 | **{k:v for k,v in data_module.items() if k != 'predict_dataset'},
673 | )
674 |
675 | # Callbacks
676 | if not args.full_finetune:
677 | trainer.add_callback(SavePeftModelCallback)
678 | if args.do_mmlu_eval:
679 | if args.mmlu_dataset == 'mmlu-zs':
680 | mmlu_dataset = load_dataset("json", data_files={
681 | 'eval': 'data/mmlu/zero_shot_mmlu_val.json',
682 | 'test': 'data/mmlu/zero_shot_mmlu_test.json',
683 | })
684 | mmlu_dataset = mmlu_dataset.remove_columns('subject')
685 | # MMLU Five-shot (Eval/Test only)
686 | elif args.mmlu_dataset == 'mmlu' or args.mmlu_dataset == 'mmlu-fs':
687 | mmlu_dataset = load_dataset("json", data_files={
688 | 'eval': 'data/mmlu/five_shot_mmlu_val.json',
689 | 'test': 'data/mmlu/five_shot_mmlu_test.json',
690 | })
691 | # mmlu_dataset = mmlu_dataset.remove_columns('subject')
692 | mmlu_dataset = mmlu_dataset[args.mmlu_split]
693 | if args.max_mmlu_samples is not None:
694 | mmlu_dataset = mmlu_dataset.select(range(args.max_mmlu_samples))
695 | abcd_idx = [
696 | tokenizer("A", add_special_tokens=False).input_ids[0],
697 | tokenizer("B", add_special_tokens=False).input_ids[0],
698 | tokenizer("C", add_special_tokens=False).input_ids[0],
699 | tokenizer("D", add_special_tokens=False).input_ids[0],
700 | ]
701 | accuracy = evaluate.load("accuracy")
702 |
703 | class MMLUEvalCallback(transformers.TrainerCallback):
704 | def on_evaluate(self, args, state, control, model, **kwargs):
705 | data_loader = trainer.get_eval_dataloader(mmlu_dataset)
706 | source_max_len = trainer.data_collator.source_max_len
707 | trainer.data_collator.source_max_len = args.mmlu_source_max_len
708 | trainer.model.eval()
709 | preds, refs = [], []
710 | loss_mmlu = 0
711 | for batch in tqdm(data_loader, total=len(data_loader)):
712 | (loss, logits, labels) = trainer.prediction_step(trainer.model,batch,prediction_loss_only=False,)
713 | # There are two tokens, the output, and eos token.
714 | for i, logit in enumerate(logits):
715 | label_non_zero_id = (batch['labels'][i] != -100).nonzero()[0][0]
716 | logit_abcd = logit[label_non_zero_id-1][abcd_idx]
717 | preds.append(torch.argmax(logit_abcd).item())
718 | labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:,0]
719 | for label in labels.tolist():
720 | if label in abcd_idx:
721 | refs += [abcd_idx.index(label)]
722 |
723 | loss_mmlu += loss.item()
724 | # Extract results by subject.
725 | results = {'mmlu_loss':loss_mmlu/len(data_loader)}
726 | subject = mmlu_dataset['subject']
727 | subjects = {s:{'refs':[], 'preds':[]} for s in set(subject)}
728 | for s,p,r in zip(subject, preds, refs):
729 | subjects[s]['preds'].append(p)
730 | subjects[s]['refs'].append(r)
731 | subject_scores = []
732 | for subject in subjects:
733 | subject_score = accuracy.compute(
734 | references=subjects[subject]['refs'],
735 | predictions=subjects[subject]['preds']
736 | )['accuracy']
737 | results[f'mmlu_{args.mmlu_split}_accuracy_{subject}'] = subject_score
738 | subject_scores.append(subject_score)
739 | results[f'mmlu_{args.mmlu_split}_accuracy'] = np.mean(subject_scores)
740 | trainer.log(results)
741 | trainer.data_collator.source_max_len = source_max_len
742 |
743 | trainer.add_callback(MMLUEvalCallback)
744 |
745 | # Verifying the datatypes.
746 | dtypes = {}
747 | for _, p in model.named_parameters():
748 | dtype = p.dtype
749 | if dtype not in dtypes: dtypes[dtype] = 0
750 | dtypes[dtype] += p.numel()
751 | total = 0
752 | for k, v in dtypes.items(): total+= v
753 | for k, v in dtypes.items():
754 | print(k, v, v/total)
755 |
756 | all_metrics = {"run_name": args.run_name}
757 | # Training
758 | if args.do_train:
759 | train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
760 | metrics = train_result.metrics
761 | trainer.log_metrics("train", metrics)
762 | trainer.save_metrics("train", metrics)
763 | trainer.save_state()
764 | all_metrics.update(metrics)
765 | # Evaluation
766 | if args.do_eval:
767 | logger.info("*** Evaluate ***")
768 | metrics = trainer.evaluate(metric_key_prefix="eval")
769 | trainer.log_metrics("eval", metrics)
770 | trainer.save_metrics("eval", metrics)
771 | all_metrics.update(metrics)
772 | # Prediction
773 | if args.do_predict:
774 | logger.info("*** Predict ***")
775 | prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'],metric_key_prefix="predict")
776 | prediction_metrics = prediction_output.metrics
777 | predictions = prediction_output.predictions
778 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
779 | predictions = tokenizer.batch_decode(
780 | predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
781 | )
782 | with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w') as fout:
783 | for i, example in enumerate(data_module['predict_dataset']):
784 | example['prediction_with_input'] = predictions[i].strip()
785 | example['prediction'] = predictions[i].replace(example['input'], '').strip()
786 | fout.write(json.dumps(example) + '\n')
787 | print(prediction_metrics)
788 | trainer.log_metrics("predict", prediction_metrics)
789 | trainer.save_metrics("predict", prediction_metrics)
790 | all_metrics.update(prediction_metrics)
791 |
792 | if (args.do_train or args.do_eval or args.do_predict):
793 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
794 | fout.write(json.dumps(all_metrics))
795 |
796 | if __name__ == "__main__":
797 | train()
798 |
--------------------------------------------------------------------------------