├── README.md ├── calibration ├── llama │ ├── calibration.py │ ├── generate_bias.sh │ ├── generate_scale_factor.sh │ └── run_calibration.py └── opt │ ├── calibration.py │ ├── generate_bias.sh │ ├── generate_scale_factor.sh │ └── run_calibration.py ├── figures └── tender-flow.png ├── lm-eval.patch ├── models ├── modeling_llama.py ├── modeling_llama_orig.py ├── modeling_llama_tender.py ├── modeling_opt.py ├── modeling_opt_orig.py └── modeling_opt_tender.py ├── requirements.txt ├── scripts ├── datautils.py ├── llama.py ├── opt.py ├── table_2.py ├── table_3.py ├── table_7.py └── utils.py └── setup.sh /README.md: -------------------------------------------------------------------------------- 1 | # Tender: Accelerating Large Language Models via Tensor Decomposition and Runtime Requantization 2 | [[paper](https://arxiv.org/abs/2406.12930)] 3 | 4 | ![tender-flow](figures/tender-flow.png) 5 | 6 | ## Overview 7 | 8 | This repository contains the code for the ISCA'24 paper `Tender: Accelerating Large Language Models via Tensor Decomposition and Runtime Requantization`. Some of the code snippets are adapted from the SmoothQuant (ICML'23) and GPTQ (ICLR'23) GitHub repositories. 9 | 10 | ## Abstract 11 | 12 | Large language models (LLMs) demonstrate outstanding performance in various tasks in machine learning and have thus become one of the most important workloads in today's computing landscape. However, deploying LLM inference poses challenges due to the high compute and memory requirements stemming from the enormous model size and the difficulty of running it in the integer pipelines. In this paper, we present Tender, an algorithm-hardware co-design solution that enables efficient deployment of LLM inference at low precision. Based on our analysis of outlier values in LLMs, we propose a decomposed quantization technique in which the scale factors of decomposed matrices are powers of two apart. The proposed scheme allows us to avoid explicit requantization (i.e., dequantization/quantization) when accumulating the partial sums from the decomposed matrices, with a minimal extension to the commodity tensor compute hardware. Our evaluation shows that Tender achieves higher accuracy and inference performance compared to the state-of-the-art methods while also being significantly less intrusive to the existing accelerators. 13 | 14 | ## Directory Structure 15 | 16 | - calibration: Calibration scripts for generating the scale factor, channel bias, and group index. 17 | - opt: Calibration scripts for OPT. 18 | - llama: Calibration scripts for Llama-2 / LLaMA. 19 | - models: Tender implementation. 20 | - scripts: Scripts for running the perplexity and accuracy evaluations. 21 | 22 | ## Setup 23 | 24 | ### Prerequisite 25 | 26 | Fetch Llama-2 from [here](https://llama.meta.com/llama-downloads). You may also need to convert the model to Hugging Face format using the `convert_llama_weights_to_hf.py` script in `transformers/src/transformers/models/llama`. 27 | 28 | ```sh 29 | conda create -n tender python=3.9 30 | conda activate tender 31 | conda install ninja 32 | pip install -r requirements.txt 33 | git clone -b v4.35-release https://github.com/huggingface/transformers.git 34 | cd transformers 35 | pip install -e . 36 | cd .. && bash setup.sh 37 | ``` 38 | 39 | NOTE: `setup.sh` renames the original model code from `modeling_xx.py` to `modeling_xx_orig.py` (e.g., `modeling_opt.py` -> `modeling_opt_orig.py`) in the transformers library.
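For intuition before running the calibration below: Tender groups the channels of each row chunk so that neighboring groups use quantization scales that are exactly a factor of two apart (the `tmax / 2**(decomp_factor - 1 - i)` thresholds computed in `calibration.py`), which lets the integer partial sums of the groups be combined with bit shifts instead of dequantizing and requantizing each group. The following standalone NumPy sketch only illustrates that idea; it is not the repository's implementation, and the toy sizes, the pre-quantized integer weights, and all variable names are made up for this example.

```python
import numpy as np

# Illustrative only: quantize one activation chunk group-by-group with
# power-of-two-apart scales and combine the integer partial sums by shifting.
rng = np.random.default_rng(0)
q_bits, decomp_factor = 8, 4
qmax = 2 ** (q_bits - 1) - 1

x = rng.normal(size=(4, 64))                 # one row chunk of activations
x[:, 5] *= 40.0                              # an outlier channel
w = rng.integers(-8, 8, size=(64, 16))       # toy weights, already integer (scale 1.0)

cmax = np.abs(x).max(axis=0)                 # per-channel max ("cmax" in the scripts)
tmax = np.abs(x).max()                       # per-chunk max   ("tmax" in the scripts)
thresholds = np.array([tmax / 2 ** (decomp_factor - 1 - i) for i in range(decomp_factor)])
group = np.minimum(np.searchsorted(thresholds, cmax), decomp_factor - 1)  # round-up grouping

acc = np.zeros((x.shape[0], w.shape[1]), dtype=np.int64)
for g in range(decomp_factor):
    cols = np.flatnonzero(group == g)
    if cols.size == 0:
        continue
    scale_g = thresholds[g] / qmax           # group scales differ by exactly 2x
    xq = np.clip(np.round(x[:, cols] / scale_g), -qmax, qmax).astype(np.int64)
    # Align this group's integer partial sums to the finest scale with a shift;
    # no per-group floating-point dequantization/requantization is needed.
    acc += (xq @ w[cols, :]) << g
out = acc * (thresholds[0] / qmax)           # single dequantization at the very end

print("max abs error vs. float matmul:", np.abs(out - x @ w).max())
```

The grouping mirrors the round-up rule in the calibration scripts (a channel whose `cmax` falls between two thresholds is assigned to the larger one); the actual calibration additionally derives a per-channel bias and selects between round-up and round-down grouping (see `select_best_scheme` in `calibration.py`).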
40 | 41 | ### Calibration 42 | 43 | Tender requires offline calibration to determine the scale factors, biases, and channel group indices. To calibrate the models, run the following commands: 44 | 45 | ```sh 46 | cd calibration/opt 47 | bash generate_bias.sh 48 | bash generate_scale_factor.sh 49 | 50 | export LLAMA2_PATH=/path/to/llama-2 51 | cd ../llama 52 | bash generate_bias.sh 53 | bash generate_scale_factor.sh 54 | ``` 55 | 56 | The above commands generate the channel bias, scale factor, and channel group index for each row chunk. 57 | 58 | ## Perplexity Evaluation 59 | 60 | To reproduce Tables 2 and 3, run the following commands: 61 | 62 | ```sh 63 | export LLAMA2_PATH=/path/to/llama-2 64 | cd scripts 65 | python table_2.py 66 | python table_3.py 67 | ``` 68 | 69 | ## Accuracy Evaluation 70 | 71 | To reproduce Table 7, you need the source code of lm-evaluation-harness and the MX emulation library: 72 | 73 | ```sh 74 | git clone https://github.com/microsoft/microxcaling.git 75 | cd microxcaling 76 | git checkout 94741 && pip install --no-deps . 77 | cd .. 78 | git clone https://github.com/EleutherAI/lm-evaluation-harness.git 79 | cd lm-evaluation-harness 80 | git checkout 1736d && pip install -e . 81 | # apply the patch for supporting Tender 82 | cp ../lm-eval.patch . && git apply lm-eval.patch 83 | ``` 84 | 85 | After setting up the libraries, run the following commands: 86 | 87 | ```sh 88 | cd scripts 89 | python table_7.py 90 | ``` 91 | 92 | The results might be slightly different depending on the GPU. 93 | 94 | ## Citation 95 | 96 | If you find Tender useful and relevant to your research, please cite our paper. 97 | 98 | ```bibtex 99 | @inproceedings{lee-isca24, 100 | title={Tender: Accelerating Large Language Models via Tensor Decomposition and Runtime Requantization}, 101 | author={Lee, Jungi and Lee, Wonbeom and Sim, Jaewoong}, 102 | booktitle={Proceedings of the 51st Annual International Symposium on Computer Architecture}, 103 | year={2024} 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /calibration/llama/calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from datasets import load_dataset 5 | import functools 6 | from collections import defaultdict 7 | 8 | from functools import partial 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | def forward_by_layer(model, inputs, num_samples, seqlen): 13 | 14 | dev = torch.device("cuda:0") 15 | 16 | inputs = inputs.input_ids 17 | 18 | use_cache = model.config.use_cache 19 | model.config.use_cache = False 20 | layers = model.model.layers 21 | 22 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 23 | layers[0] = layers[0].to(dev) 24 | 25 | dtype = next(iter(model.parameters())).dtype 26 | inps = torch.zeros((num_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 27 | cache = {'i': 0, 'attention_mask': None} 28 | 29 | class Catcher(nn.Module): 30 | def __init__(self, module): 31 | super().__init__() 32 | self.module = module 33 | 34 | def forward(self, inp, **kwargs): 35 | inps[cache['i']] = inp # 1, seqlen, dim 36 | cache['i'] += 1 37 | cache['attention_mask'] = kwargs['attention_mask'] 38 | cache['position_ids'] = kwargs['position_ids'] 39 | raise ValueError 40 | 41 | layers[0] = Catcher(layers[0]) 42 | for i in range(num_samples): 43 | batch = inputs[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 44 | try: 45 | model(batch) 46
| except ValueError: 47 | pass 48 | layers[0] = layers[0].module 49 | 50 | layers[0] = layers[0].cpu() 51 | model.model.embed_tokens = model.model.embed_tokens.cpu() 52 | torch.cuda.empty_cache() 53 | 54 | outs = torch.zeros_like(inps) 55 | attention_mask = cache['attention_mask'] 56 | position_ids = cache['position_ids'] 57 | 58 | for i in range(len(layers)): 59 | layer = layers[i].to(dev) 60 | 61 | for j in range(num_samples): 62 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 63 | layers[i] = layer.cpu() 64 | del layer 65 | torch.cuda.empty_cache() 66 | inps, outs = outs, inps 67 | print(i,end=' ',flush=True) 68 | print() 69 | 70 | if model.model.norm is not None: 71 | model.model.norm = model.model.norm.to(dev) 72 | model.lm_head = model.lm_head.to(dev) 73 | 74 | nlls = [] 75 | inputs = inputs.to(dev) 76 | for i in range(num_samples): 77 | hidden_states = inps[i].unsqueeze(0) 78 | if model.model.norm is not None: 79 | hidden_states = model.model.norm(hidden_states) 80 | lm_logits = model.lm_head(hidden_states) 81 | shift_logits = lm_logits[:, :-1, :].contiguous() 82 | shift_labels = inputs[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:] 83 | loss_fct = nn.CrossEntropyLoss() 84 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 85 | neg_log_likelihood = loss.float() * model.seqlen 86 | nlls.append(neg_log_likelihood) 87 | 88 | model.config.use_cache = use_cache 89 | 90 | return torch.stack(nlls).sum() 91 | 92 | def select_best_scheme(scale_factors, model, inputs, quant_mha = False): 93 | nll_sum = [] 94 | for i, scale_factor in enumerate(scale_factors): 95 | for layer in model.model.layers: 96 | attn = layer.self_attn 97 | mlp = layer.mlp 98 | prefix = "model.layers." 
+ str(attn.layer_idx) 99 | 100 | name = prefix + ".self_attn" + "h_tmax" 101 | attn.h_tmax = scale_factor[name] 102 | name = prefix + ".self_attn" + "h_cmax" 103 | attn.h_group_index = scale_factor[name] 104 | name = prefix + ".self_attn" + "o_tmax" 105 | attn.o_tmax = scale_factor[name] 106 | name = prefix + ".self_attn" + "o_cmax" 107 | attn.o_group_index = scale_factor[name] 108 | 109 | if quant_mha: 110 | name = prefix + ".self_attn" + "q_tmax" 111 | attn.q_tmax = scale_factor[name] 112 | name = prefix + ".self_attn" + "q_cmax" 113 | attn.q_group_index = scale_factor[name] 114 | name = prefix + ".self_attn" + "s_tmax" 115 | attn.s_tmax = scale_factor[name] 116 | name = prefix + ".self_attn" + "s_cmax" 117 | attn.s_group_index = scale_factor[name] 118 | 119 | name = prefix + ".self_attn" + "k_scale" 120 | attn.k_scale = scale_factor[name] 121 | name = prefix + ".self_attn" + "v_scale" 122 | attn.v_scale = scale_factor[name] 123 | 124 | name = prefix + "fc1_tmax" 125 | mlp.fc1_tmax = scale_factor[name] 126 | name = prefix + "fc1_cmax" 127 | mlp.fc1_group_index = scale_factor[name] 128 | name = prefix + "fc2_tmax" 129 | mlp.fc2_tmax = scale_factor[name] 130 | name = prefix + "fc2_cmax" 131 | mlp.fc2_group_index = scale_factor[name] 132 | 133 | nll = forward_by_layer(model, inputs, 1, model.seqlen) 134 | ppl = torch.exp(nll / (1 * model.seqlen)).item() 135 | print("index %d ppl %f"%(i, ppl)) 136 | nll_sum.append(nll.item()) 137 | 138 | idx = np.argmin(np.array(nll_sum)) 139 | if idx==0: 140 | scheme = "rdn" 141 | elif idx==1: 142 | scheme = "rup" 143 | 144 | print("scheme %s selected"%(scheme),flush=True) 145 | 146 | return scale_factors[idx] 147 | 148 | def get_scale_factor(model, tokenizer, dataset_path, num_samples=512, seq_len=512, quant_mha = False): 149 | model.eval() 150 | model.seqlen = seq_len 151 | scale_factor = {} 152 | 153 | def stat_tensor(attn, name): 154 | h_tmax = attn.h_tmax_cal # chunks 155 | o_tmax = attn.o_tmax_cal 156 | 157 | h_cmax = attn.h_cmax_cal # chunks, hidden_dim 158 | o_cmax = attn.o_cmax_cal 159 | 160 | tmaxes = [h_tmax, o_tmax] 161 | cmaxes = [h_cmax, o_cmax] 162 | names = ["h", "o"] 163 | 164 | if quant_mha: 165 | k_scale = attn.k_scale_cal 166 | v_scale = attn.v_scale_cal 167 | 168 | q_tmax = attn.q_tmax_cal # b*h, chunks 169 | s_tmax = attn.s_tmax_cal # b*h, chunks 170 | q_cmax = attn.q_cmax_cal # b*h, chunks, head_dim 171 | s_cmax = attn.s_cmax_cal # b*h, chunks, head_dim 172 | tmaxes.extend([q_tmax, s_tmax]) 173 | cmaxes.extend([q_cmax, s_cmax]) 174 | names.extend(["q", "s"]) 175 | 176 | if name in scale_factor: 177 | for i in range(len(names)): 178 | old_tmax = scale_factor[name + names[i] + "_tmax"] 179 | new_tmax = tmaxes[i] 180 | old_cmax = scale_factor[name + names[i] + "_cmax"] 181 | new_cmax = cmaxes[i] 182 | 183 | scale_factor[name + names[i] + "_tmax"] = torch.where(old_tmax > new_tmax, old_tmax, new_tmax) 184 | scale_factor[name + names[i] + "_cmax"] = torch.where(old_cmax > new_cmax, old_cmax, new_cmax) 185 | if quant_mha: 186 | old_k_scale = scale_factor[name + "k_scale"] 187 | old_v_scale = scale_factor[name + "v_scale"] 188 | scale_factor[name + "k_scale"] = torch.where(old_k_scale > k_scale, old_k_scale, k_scale) 189 | scale_factor[name + "v_scale"] = torch.where(old_v_scale > v_scale, old_v_scale, v_scale) 190 | else: 191 | for i in range(len(names)): 192 | scale_factor[name + names[i] + "_tmax"] = tmaxes[i] 193 | scale_factor[name + names[i] + "_cmax"] = cmaxes[i] 194 | if quant_mha: 195 | scale_factor[name + "k_scale"] = k_scale 196 | 
scale_factor[name + "v_scale"] = v_scale 197 | scale_factor[name] = True 198 | 199 | def decoder_layer_stat_tensor(decoder, name): 200 | fc1_tmax = decoder.mlp.fc1_tmax_cal 201 | fc2_tmax = decoder.mlp.fc2_tmax_cal 202 | 203 | fc1_cmax = decoder.mlp.fc1_cmax_cal # chunks, hidden_dim 204 | fc2_cmax = decoder.mlp.fc2_cmax_cal 205 | 206 | tmaxes = [fc1_tmax, fc2_tmax] 207 | cmaxes = [fc1_cmax, fc2_cmax] 208 | names = ["fc1", "fc2"] 209 | 210 | if name in scale_factor: 211 | for i in range(len(names)): 212 | old_tmax = scale_factor[name + names[i] + "_tmax"] 213 | new_tmax = tmaxes[i] 214 | old_cmax = scale_factor[name + names[i] + "_cmax"] 215 | new_cmax = cmaxes[i] 216 | 217 | scale_factor[name + names[i] + "_tmax"] = torch.where(old_tmax > new_tmax, old_tmax, new_tmax) 218 | scale_factor[name + names[i] + "_cmax"] = torch.where(old_cmax > new_cmax, old_cmax, new_cmax) 219 | else: 220 | for i in range(len(names)): 221 | scale_factor[name + names[i] + "_tmax"] = tmaxes[i] 222 | scale_factor[name + names[i] + "_cmax"] = cmaxes[i] 223 | scale_factor[name] = True 224 | 225 | def stat_input_hook(m, hidden_states, output_attentions, name): 226 | stat_tensor(m, name) 227 | def decoder_layer_stat_input_hook(m, hidden_states, output_attentions, name): 228 | decoder_layer_stat_tensor(m, name) 229 | 230 | hooks = [] 231 | for name, m in model.named_modules(): 232 | if name.endswith('self_attn'): 233 | hooks.append( 234 | m.register_forward_hook( 235 | functools.partial(stat_input_hook, name=name)) 236 | ) 237 | if name.endswith('layers'): 238 | layer_index = 0 239 | for layer in m: 240 | hooks.append( 241 | layer.register_forward_hook( 242 | functools.partial(decoder_layer_stat_input_hook, name=name+"."+str(layer_index))) 243 | ) 244 | layer_index += 1 245 | 246 | dataset = load_dataset("json", data_files=dataset_path, split="train") 247 | dataset = dataset.shuffle(seed=42) 248 | 249 | inputs = tokenizer("\n\n".join(dataset['text'][:1000]), return_tensors='pt') 250 | 251 | forward_by_layer(model, inputs, num_samples, seq_len) 252 | 253 | for h in hooks: 254 | h.remove() 255 | 256 | decomp_factor = model.model.layers[0].self_attn.decomp_factor 257 | 258 | # Static calibration: Chooses between round up and round down 259 | # Runtime: Round up 260 | import copy 261 | scale_factor_rdn = copy.deepcopy(scale_factor) 262 | scale_factor_rup = copy.deepcopy(scale_factor) 263 | 264 | for name in scale_factor: 265 | if "tmax" in name: 266 | tmax = scale_factor[name] # chunks 267 | cmax = scale_factor[name.replace("tmax", "cmax")] # chunks, hidden_dim 268 | 269 | thresholds = [] 270 | for i in range(decomp_factor): 271 | thresh = (tmax / (2**(decomp_factor-1-i))).unsqueeze(-1) # chunks, 1 272 | thresholds.append(thresh) 273 | 274 | group_index_rdn = torch.zeros_like(cmax) # chunks, hidden_dim 275 | group_index_rup = torch.zeros_like(cmax) 276 | 277 | for i in range(decomp_factor): 278 | if i == 0: 279 | mask = cmax <= thresholds[i] 280 | group_index_rup = torch.where(mask, i, group_index_rup) 281 | group_index_rdn = torch.where(mask, i, group_index_rdn) 282 | else: 283 | mask = torch.logical_and((thresholds[i-1] < cmax), (cmax <= thresholds[i])) 284 | group_index_rup = torch.where(mask, i, group_index_rup) 285 | 286 | group_index_rdn = torch.where(mask, i-1, group_index_rdn) 287 | 288 | scale_factor_rdn[name.replace("tmax", "cmax")] = group_index_rdn 289 | scale_factor_rup[name.replace("tmax", "cmax")] = group_index_rup 290 | 291 | scale_factors = [scale_factor_rdn, scale_factor_rup] 292 | scale_factor = 
select_best_scheme(scale_factors, model, inputs, quant_mha) 293 | 294 | return scale_factor 295 | 296 | def get_bias(model, tokenizer, dataset_path, num_samples = 512, seq_len = 512, quant_mha = False): 297 | model.eval() 298 | model.seqlen = seq_len 299 | device = next(model.parameters()).device 300 | bias = {} 301 | 302 | def stat_tensor(attn, name): 303 | h_ch_bias = attn.h_ch_bias_cal # chunks, 1, hidden_dim 304 | 305 | biases = [h_ch_bias] 306 | bias_names = ["h_ch_bias"] 307 | 308 | if quant_mha: 309 | q_ch_bias = attn.q_ch_bias_cal # heads, chunks, 1, head_dim 310 | k_ch_bias = attn.k_ch_bias_cal # heads, 1, head_dim 311 | biases.extend([q_ch_bias, k_ch_bias]) 312 | bias_names.extend(['q_ch_bias', 'k_ch_bias']) 313 | 314 | if name in bias: 315 | for i in range(len(biases)): 316 | bias[name + bias_names[i]] = bias[name + bias_names[i]] + biases[i] 317 | else: 318 | for i in range(len(biases)): 319 | bias[name + bias_names[i]] = biases[i] 320 | bias[name] = True 321 | 322 | def decoder_layer_stat_tensor(decoder, name): 323 | h_ch_bias = decoder.mlp.h_ch_bias_cal 324 | 325 | biases = [h_ch_bias] 326 | bias_names = ["h_ch_bias"] 327 | 328 | if name in bias: 329 | for i in range(len(biases)): 330 | bias[name + bias_names[i]] = bias[name + bias_names[i]] + biases[i] 331 | else: 332 | for i in range(len(biases)): 333 | bias[name + bias_names[i]] = biases[i] 334 | bias[name] = True 335 | 336 | def stat_input_hook(m, hidden_states, output_attentions, name): 337 | stat_tensor(m, name) 338 | 339 | def decoder_layer_stat_input_hook(m, hidden_states, output_attentions, name): 340 | decoder_layer_stat_tensor(m, name) 341 | 342 | hooks = [] 343 | for name, m in model.named_modules(): 344 | if name.endswith('self_attn'): 345 | hooks.append( 346 | m.register_forward_hook( 347 | functools.partial(stat_input_hook, name=name)) 348 | ) 349 | if name.endswith('layers'): 350 | layer_index = 0 351 | for layer in m: 352 | hooks.append( 353 | layer.register_forward_hook( 354 | functools.partial(decoder_layer_stat_input_hook, name=name+"."+str(layer_index))) 355 | ) 356 | layer_index += 1 357 | 358 | dataset = load_dataset("json", data_files=dataset_path, split="train") 359 | dataset = dataset.shuffle(seed=42) 360 | 361 | inputs = tokenizer("\n\n".join(dataset['text'][:1000]), return_tensors='pt') 362 | forward_by_layer(model, inputs, num_samples, seq_len) 363 | 364 | for h in hooks: 365 | h.remove() 366 | 367 | for name in bias: 368 | bias[name] = bias[name] / num_samples 369 | 370 | return bias 371 | -------------------------------------------------------------------------------- /calibration/llama/generate_bias.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNK_SIZE=256 4 | 5 | # Llama-2 - Linear 6 | for nsamples in 128;do 7 | for seqlen in 2048;do 8 | for size in 7b 13b 70b;do 9 | for bits in 4 8;do 10 | case ${size} in 11 | 7b) 12 | if [ ${bits} = 4 ]; then 13 | decomp=14 14 | elif [ ${bits} = 8 ]; then 15 | decomp=8 16 | fi 17 | ;; 18 | 13b) 19 | if [ ${bits} = 4 ]; then 20 | decomp=16 21 | elif [ ${bits} = 8 ]; then 22 | decomp=14 23 | fi 24 | ;; 25 | 70b) 26 | if [ ${bits} = 4 ]; then 27 | decomp=20 28 | elif [ ${bits} = 8 ]; then 29 | decomp=16 30 | fi 31 | ;; 32 | esac 33 | echo calibrating llama-2-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 34 | echo linear only 35 | python run_calibration.py \ 36 | --model-name "${LLAMA2_PATH}/llama-2-${size}" \ 37 | --target "bias" \ 38 | --output-path 
"llama-2-bias/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp.pt" \ 39 | --dataset-path "../../data/val.jsonl.zst" \ 40 | --num-samples ${nsamples} \ 41 | --seq-len ${seqlen} \ 42 | --q_bits ${bits} \ 43 | --decomp_factor ${decomp} \ 44 | --chunk_size ${CHUNK_SIZE} 45 | done 46 | done 47 | done 48 | done 49 | 50 | # Llama-1 - Linear 51 | for nsamples in 128;do 52 | for seqlen in 2048;do 53 | for size in 7b 13b;do 54 | for bits in 4 8;do 55 | decomp=14 56 | echo calibrating llama-1-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 57 | echo linear only 58 | 59 | case ${size} in 60 | 7b) 61 | model_name="baffo32/decapoda-research-llama-7B-hf" 62 | ;; 63 | 13b) 64 | model_name="JG22/decapoda-research-llama-13b" 65 | ;; 66 | esac 67 | python run_calibration.py \ 68 | --model-name ${model_name} \ 69 | --target "bias" \ 70 | --output-path "llama-1-bias/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp.pt" \ 71 | --dataset-path "../../data/val.jsonl.zst" \ 72 | --num-samples ${nsamples} \ 73 | --seq-len ${seqlen} \ 74 | --q_bits ${bits} \ 75 | --decomp_factor ${decomp} \ 76 | --chunk_size ${CHUNK_SIZE} 77 | done 78 | done 79 | done 80 | done 81 | -------------------------------------------------------------------------------- /calibration/llama/generate_scale_factor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNK_SIZE=256 4 | 5 | # Llama-2 - Linear 6 | for nsamples in 128;do 7 | for seqlen in 2048;do 8 | for size in 7b 13b 70b;do 9 | for bits in 4 8;do 10 | case ${size} in 11 | 7b) 12 | if [ ${bits} = 4 ]; then 13 | decomp=14 14 | elif [ ${bits} = 8 ]; then 15 | decomp=8 16 | fi 17 | ;; 18 | 13b) 19 | if [ ${bits} = 4 ]; then 20 | decomp=16 21 | elif [ ${bits} = 8 ]; then 22 | decomp=14 23 | fi 24 | ;; 25 | 70b) 26 | if [ ${bits} = 4 ]; then 27 | decomp=20 28 | elif [ ${bits} = 8 ]; then 29 | decomp=16 30 | fi 31 | ;; 32 | esac 33 | echo calibrating llama-2-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 34 | echo linear only 35 | python run_calibration.py \ 36 | --model-name "${LLAMA2_PATH}/llama-2-${size}" \ 37 | --target "scale" \ 38 | --output-path "llama-2-scale/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp.pt" \ 39 | --dataset-path "../../data/val.jsonl.zst" \ 40 | --num-samples ${nsamples} \ 41 | --seq-len ${seqlen} \ 42 | --q_bits ${bits} \ 43 | --decomp_factor ${decomp} \ 44 | --chunk_size ${CHUNK_SIZE} 45 | done 46 | done 47 | done 48 | done 49 | 50 | # Llama-1 - Linear 51 | for nsamples in 128;do 52 | for seqlen in 2048;do 53 | for size in 7b 13b;do 54 | for bits in 4 8;do 55 | decomp=14 56 | echo calibrating llama-1-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 57 | echo linear only 58 | 59 | case ${size} in 60 | 7b) 61 | model_name="baffo32/decapoda-research-llama-7B-hf" 62 | ;; 63 | 13b) 64 | model_name="JG22/decapoda-research-llama-13b" 65 | ;; 66 | esac 67 | python run_calibration.py \ 68 | --model-name ${model_name} \ 69 | --target "scale" \ 70 | --output-path "llama-1-scale/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp.pt" \ 71 | --dataset-path "../../data/val.jsonl.zst" \ 72 | --num-samples ${nsamples} \ 73 | --seq-len ${seqlen} \ 74 | --q_bits ${bits} \ 75 | --decomp_factor ${decomp} \ 76 | --chunk_size ${CHUNK_SIZE} 77 | done 78 | done 79 | done 80 | done 81 | -------------------------------------------------------------------------------- /calibration/llama/run_calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import os 3 | 4 | import argparse 5 | from calibration import get_scale_factor, get_bias 6 | 7 | def build_model_and_tokenizer(model_name, seq_len): 8 | if os.path.exists('../../models/modeling_llama.py'): 9 | os.system('rm ../../models/modeling_llama.py') 10 | cwd = os.getcwd() 11 | os.chdir('../../models/') 12 | os.system('ln -s modeling_llama_tender.py modeling_llama.py') 13 | os.chdir(cwd) 14 | 15 | from transformers import ( 16 | LlamaForCausalLM, 17 | LlamaTokenizer, 18 | ) 19 | 20 | kwargs = {"torch_dtype": torch.float16, "device_map": "cpu"} 21 | model = LlamaForCausalLM.from_pretrained(model_name, **kwargs) 22 | tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=seq_len) 23 | return model, tokenizer 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--model-name', type=str, 28 | required=True, help='model name') 29 | parser.add_argument('--target', type=str, choices=['scale', 'bias'], required=True, 30 | help='Calibrate scale factor or bias') 31 | parser.add_argument('--output-path', type=str, default='biases/llama-7b.pt', 32 | help='where to save the result') 33 | parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst', 34 | help='location of the calibration dataset, we use the validation set of the Pile dataset') 35 | parser.add_argument('--num-samples', type=int, default=512) 36 | parser.add_argument('--seq-len', type=int, default=512) 37 | parser.add_argument('--q_bits', type=int, default=8, 38 | help='Number of bits for quantization') 39 | parser.add_argument('--decomp_factor', type=int, default=8, 40 | help='Number of column groups') 41 | parser.add_argument('--chunk_size', type=int, default=256, 42 | help='Size of row chunk') 43 | parser.add_argument('--quant_mha', action='store_true', 44 | help='Whether to quantize multi-head-attention') 45 | args = parser.parse_args() 46 | return args 47 | 48 | @torch.no_grad() 49 | def main(): 50 | args = parse_args() 51 | model, tokenizer = build_model_and_tokenizer(args.model_name, args.seq_len) 52 | 53 | for layer in model.model.layers: 54 | layer.self_attn.quant_mha = args.quant_mha 55 | 56 | layer.self_attn.q_bits = args.q_bits 57 | layer.mlp.q_bits = args.q_bits 58 | 59 | layer.self_attn.decomp_factor = args.decomp_factor 60 | layer.mlp.decomp_factor = args.decomp_factor 61 | 62 | layer.self_attn.chunk_size = args.chunk_size 63 | layer.mlp.chunk_size = args.chunk_size 64 | 65 | if args.target == 'scale': 66 | result = get_scale_factor(model, tokenizer, args.dataset_path, 67 | args.num_samples, args.seq_len, args.quant_mha) 68 | else: 69 | result = get_bias(model, tokenizer, args.dataset_path, 70 | args.num_samples, args.seq_len, args.quant_mha) 71 | 72 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 73 | torch.save(result, args.output_path) 74 | 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /calibration/opt/calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from datasets import load_dataset 5 | import functools 6 | from collections import defaultdict 7 | 8 | from functools import partial 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | def forward_by_layer(model, inputs, num_samples, seqlen): 13 | 14 | dev = torch.device("cuda:0") 15 | 16 | inputs = inputs.input_ids 17 | 18 | use_cache = model.config.use_cache 19 | 
model.config.use_cache = False 20 | layers = model.model.decoder.layers 21 | 22 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 23 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 24 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 25 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 26 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 27 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 28 | layers[0] = layers[0].to(dev) 29 | 30 | dtype = next(iter(model.parameters())).dtype 31 | inps = torch.zeros((num_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 32 | #inps = [0 for _ in range(num_samples)] 33 | cache = {'i': 0, 'attention_mask': None} 34 | 35 | class Catcher(nn.Module): 36 | def __init__(self, module): 37 | super().__init__() 38 | self.module = module 39 | 40 | def forward(self, inp, **kwargs): 41 | inps[cache['i']] = inp # 1, seqlen, dim 42 | cache['i'] += 1 43 | #if kwargs['attention_mask'].shape[-1] == seqlen: 44 | cache['attention_mask'] = kwargs['attention_mask'] 45 | raise ValueError 46 | 47 | # Layer 0 -> Attention mask, position embedding 48 | layers[0] = Catcher(layers[0]) 49 | for i in range(num_samples): 50 | batch = inputs[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 51 | try: 52 | model(batch) 53 | except ValueError: 54 | pass 55 | layers[0] = layers[0].module 56 | 57 | layers[0] = layers[0].cpu() 58 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 59 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 60 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 61 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 62 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 63 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 64 | torch.cuda.empty_cache() 65 | 66 | outs = torch.zeros_like(inps) 67 | attention_mask = cache['attention_mask'] 68 | 69 | # Entire layer 70 | for i in range(len(layers)): 71 | layer = layers[i].to(dev) 72 | 73 | for j in range(num_samples): 74 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 75 | 76 | layers[i] = layer.cpu() 77 | del layer 78 | torch.cuda.empty_cache() 79 | inps, outs = outs, inps 80 | print(i, end=' ', flush=True) 81 | print() 82 | 83 | if model.model.decoder.final_layer_norm is not None: 84 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 85 | if model.model.decoder.project_out is not None: 86 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 87 | model.lm_head = model.lm_head.to(dev) 88 | 89 | # Final layer norm and lm_head 90 | nlls = [] 91 | inputs = inputs.to(dev) 92 | for i in range(num_samples): 93 | hidden_states = inps[i].unsqueeze(0) 94 | if model.model.decoder.final_layer_norm is not None: 95 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 96 | if model.model.decoder.project_out is not None: 97 | hidden_states = model.model.decoder.project_out(hidden_states) 98 | lm_logits = model.lm_head(hidden_states) 99 | shift_logits = lm_logits[:, :-1, :].contiguous() 100 | shift_labels = inputs[ 101 | :, (i * model.seqlen):((i + 1) * model.seqlen) 102 | ][:, 1:] 103 | loss_fct = nn.CrossEntropyLoss() 104 | loss = loss_fct(shift_logits.view(-1, 
shift_logits.size(-1)), shift_labels.view(-1)) 105 | neg_log_likelihood = loss.float() * model.seqlen 106 | nlls.append(neg_log_likelihood) 107 | 108 | model.config.use_cache = use_cache 109 | 110 | return torch.stack(nlls).sum() 111 | 112 | def select_best_scheme(scale_factors, model, inputs, quant_mha = False): 113 | nll_sum = [] 114 | for i, scale_factor in enumerate(scale_factors): 115 | for layer in model.model.decoder.layers: 116 | attn = layer.self_attn 117 | prefix = "model.decoder.layers." + str(attn.layer_idx) 118 | 119 | name = prefix + ".self_attn" + "h_tmax" 120 | attn.h_tmax = scale_factor[name] 121 | name = prefix + ".self_attn" + "h_cmax" 122 | attn.h_group_index = scale_factor[name] 123 | name = prefix + ".self_attn" + "o_tmax" 124 | attn.o_tmax = scale_factor[name] 125 | name = prefix + ".self_attn" + "o_cmax" 126 | attn.o_group_index = scale_factor[name] 127 | 128 | if quant_mha: 129 | name = prefix + ".self_attn" + "q_tmax" 130 | attn.q_tmax = scale_factor[name] 131 | name = prefix + ".self_attn" + "q_cmax" 132 | attn.q_group_index = scale_factor[name] 133 | name = prefix + ".self_attn" + "s_tmax" 134 | attn.s_tmax = scale_factor[name] 135 | name = prefix + ".self_attn" + "s_cmax" 136 | attn.s_group_index = scale_factor[name] 137 | 138 | name = prefix + ".self_attn" + "k_scale" 139 | attn.k_scale = scale_factor[name] 140 | name = prefix + ".self_attn" + "v_scale" 141 | attn.v_scale = scale_factor[name] 142 | 143 | name = prefix + "fc1_tmax" 144 | layer.fc1_tmax = scale_factor[name] 145 | name = prefix + "fc1_cmax" 146 | layer.fc1_group_index = scale_factor[name] 147 | name = prefix + "fc2_tmax" 148 | layer.fc2_tmax = scale_factor[name] 149 | name = prefix + "fc2_cmax" 150 | layer.fc2_group_index = scale_factor[name] 151 | 152 | nll = forward_by_layer(model, inputs, 1, model.seqlen) 153 | ppl = torch.exp(nll / (1 * model.seqlen)).item() 154 | print("index %d ppl %f"%(i, ppl)) 155 | nll_sum.append(nll.item()) 156 | 157 | idx = np.argmin(np.array(nll_sum)) 158 | if idx==0: 159 | scheme = "rdn" 160 | elif idx==1: 161 | scheme = "rup" 162 | 163 | print("scheme %s selected"%(scheme),flush=True) 164 | 165 | return scale_factors[idx] 166 | 167 | def get_scale_factor(model, tokenizer, dataset_path, num_samples=512, seq_len=512, quant_mha = False): 168 | model.eval() 169 | model.seqlen = seq_len 170 | scale_factor = {} 171 | 172 | def stat_tensor(attn, name): 173 | h_tmax = attn.h_tmax_cal # chunks 174 | o_tmax = attn.o_tmax_cal 175 | 176 | h_cmax = attn.h_cmax_cal # chunks, hidden_dim 177 | o_cmax = attn.o_cmax_cal 178 | 179 | tmaxes = [h_tmax, o_tmax] 180 | cmaxes = [h_cmax, o_cmax] 181 | names = ["h", "o"] 182 | 183 | if quant_mha: 184 | k_scale = attn.k_scale_cal 185 | v_scale = attn.v_scale_cal 186 | 187 | q_tmax = attn.q_tmax_cal # b*h, chunks 188 | s_tmax = attn.s_tmax_cal # b*h, chunks 189 | q_cmax = attn.q_cmax_cal # b*h, chunks, head_dim 190 | s_cmax = attn.s_cmax_cal # b*h, chunks, head_dim 191 | tmaxes.extend([q_tmax, s_tmax]) 192 | cmaxes.extend([q_cmax, s_cmax]) 193 | names.extend(["q", "s"]) 194 | 195 | if name in scale_factor: 196 | for i in range(len(names)): 197 | old_tmax = scale_factor[name + names[i] + "_tmax"] 198 | new_tmax = tmaxes[i] 199 | old_cmax = scale_factor[name + names[i] + "_cmax"] 200 | new_cmax = cmaxes[i] 201 | 202 | scale_factor[name + names[i] + "_tmax"] = torch.where(old_tmax > new_tmax, old_tmax, new_tmax) 203 | scale_factor[name + names[i] + "_cmax"] = torch.where(old_cmax > new_cmax, old_cmax, new_cmax) 204 | if quant_mha: 205 | 
old_k_scale = scale_factor[name + "k_scale"] 206 | old_v_scale = scale_factor[name + "v_scale"] 207 | scale_factor[name + "k_scale"] = torch.where(old_k_scale > k_scale, old_k_scale, k_scale) 208 | scale_factor[name + "v_scale"] = torch.where(old_v_scale > v_scale, old_v_scale, v_scale) 209 | else: 210 | for i in range(len(names)): 211 | scale_factor[name + names[i] + "_tmax"] = tmaxes[i] 212 | scale_factor[name + names[i] + "_cmax"] = cmaxes[i] 213 | if quant_mha: 214 | scale_factor[name + "k_scale"] = k_scale 215 | scale_factor[name + "v_scale"] = v_scale 216 | scale_factor[name] = True 217 | 218 | def decoder_layer_stat_tensor(decoder, name): 219 | fc1_tmax = decoder.fc1_tmax_cal 220 | fc2_tmax = decoder.fc2_tmax_cal 221 | 222 | fc1_cmax = decoder.fc1_cmax_cal # chunks, hidden_dim 223 | fc2_cmax = decoder.fc2_cmax_cal 224 | 225 | tmaxes = [fc1_tmax, fc2_tmax] 226 | cmaxes = [fc1_cmax, fc2_cmax] 227 | names = ["fc1", "fc2"] 228 | 229 | if name in scale_factor: 230 | for i in range(len(names)): 231 | old_tmax = scale_factor[name + names[i] + "_tmax"] 232 | new_tmax = tmaxes[i] 233 | old_cmax = scale_factor[name + names[i] + "_cmax"] 234 | new_cmax = cmaxes[i] 235 | 236 | scale_factor[name + names[i] + "_tmax"] = torch.where(old_tmax > new_tmax, old_tmax, new_tmax) 237 | scale_factor[name + names[i] + "_cmax"] = torch.where(old_cmax > new_cmax, old_cmax, new_cmax) 238 | else: 239 | for i in range(len(names)): 240 | scale_factor[name + names[i] + "_tmax"] = tmaxes[i] 241 | scale_factor[name + names[i] + "_cmax"] = cmaxes[i] 242 | scale_factor[name] = True 243 | 244 | def stat_input_hook(m, hidden_states, output_attentions, name): 245 | stat_tensor(m, name) 246 | def decoder_layer_stat_input_hook(m, hidden_states, output_attentions, name): 247 | decoder_layer_stat_tensor(m, name) 248 | 249 | hooks = [] 250 | for name, m in model.named_modules(): 251 | if name.endswith('self_attn'): 252 | hooks.append( 253 | m.register_forward_hook( 254 | functools.partial(stat_input_hook, name=name)) 255 | ) 256 | if name.endswith('layers'): 257 | layer_index = 0 258 | for layer in m: 259 | hooks.append( 260 | layer.register_forward_hook( 261 | functools.partial(decoder_layer_stat_input_hook, name=name+"."+str(layer_index))) 262 | ) 263 | layer_index += 1 264 | 265 | dataset = load_dataset("json", data_files=dataset_path, split="train") 266 | dataset = dataset.shuffle(seed=42) 267 | 268 | inputs = tokenizer("\n\n".join(dataset['text'][:1000]), return_tensors='pt') 269 | 270 | forward_by_layer(model, inputs, num_samples, seq_len) 271 | 272 | for h in hooks: 273 | h.remove() 274 | 275 | decomp_factor = model.model.decoder.layers[0].decomp_factor 276 | 277 | # Static calibration: Chooses between round up and round down 278 | # Runtime: Round up 279 | import copy 280 | scale_factor_rdn = copy.deepcopy(scale_factor) 281 | scale_factor_rup = copy.deepcopy(scale_factor) 282 | 283 | for name in scale_factor: 284 | if "tmax" in name: 285 | tmax = scale_factor[name] # chunks 286 | cmax = scale_factor[name.replace("tmax", "cmax")] # chunks, hidden_dim 287 | 288 | thresholds = [] 289 | for i in range(decomp_factor): 290 | thresh = (tmax / (2**(decomp_factor-1-i))).unsqueeze(-1) # chunks, 1 291 | thresholds.append(thresh) 292 | 293 | group_index_rdn = torch.zeros_like(cmax) # chunks, hidden_dim 294 | group_index_rup = torch.zeros_like(cmax) 295 | 296 | for i in range(decomp_factor): 297 | if i == 0: 298 | mask = cmax <= thresholds[i] 299 | group_index_rup = torch.where(mask, i, group_index_rup) 300 | group_index_rdn = 
torch.where(mask, i, group_index_rdn) 301 | else: 302 | mask = torch.logical_and((thresholds[i-1] < cmax), (cmax <= thresholds[i])) 303 | group_index_rup = torch.where(mask, i, group_index_rup) 304 | 305 | group_index_rdn = torch.where(mask, i-1, group_index_rdn) 306 | 307 | scale_factor_rdn[name.replace("tmax", "cmax")] = group_index_rdn 308 | scale_factor_rup[name.replace("tmax", "cmax")] = group_index_rup 309 | 310 | scale_factors = [scale_factor_rdn, scale_factor_rup] 311 | scale_factor = select_best_scheme(scale_factors, model, inputs, quant_mha) 312 | 313 | return scale_factor 314 | 315 | def get_bias(model, tokenizer, dataset_path, num_samples = 512, seq_len = 512, quant_mha = False): 316 | model.eval() 317 | model.seqlen = seq_len 318 | device = next(model.parameters()).device 319 | bias = {} 320 | 321 | def stat_tensor(attn, name): 322 | h_ch_bias = attn.h_ch_bias_cal # chunks, 1, hidden_dim 323 | o_ch_bias= attn.o_ch_bias_cal 324 | 325 | biases = [h_ch_bias, o_ch_bias] 326 | bias_names = ["h_ch_bias", "o_ch_bias"] 327 | 328 | if quant_mha: 329 | q_ch_bias = attn.q_ch_bias_cal # heads, chunks, 1, head_dim 330 | k_ch_bias = attn.k_ch_bias_cal # heads, 1, head_dim 331 | biases.extend([q_ch_bias, k_ch_bias]) 332 | bias_names.extend(['q_ch_bias', 'k_ch_bias']) 333 | 334 | if name in bias: 335 | for i in range(len(biases)): 336 | bias[name + bias_names[i]] = bias[name + bias_names[i]] + biases[i] 337 | else: 338 | for i in range(len(biases)): 339 | bias[name + bias_names[i]] = biases[i] 340 | bias[name] = True 341 | 342 | def decoder_layer_stat_tensor(decoder, name): 343 | h_ch_bias = decoder.h_ch_bias_cal 344 | 345 | biases = [h_ch_bias] 346 | bias_names = ["h_ch_bias"] 347 | 348 | if name in bias: 349 | for i in range(len(biases)): 350 | bias[name + bias_names[i]] = bias[name + bias_names[i]] + biases[i] 351 | else: 352 | for i in range(len(biases)): 353 | bias[name + bias_names[i]] = biases[i] 354 | bias[name] = True 355 | 356 | def stat_input_hook(m, hidden_states, output_attentions, name): 357 | stat_tensor(m, name) 358 | 359 | def decoder_layer_stat_input_hook(m, hidden_states, output_attentions, name): 360 | decoder_layer_stat_tensor(m, name) 361 | 362 | hooks = [] 363 | for name, m in model.named_modules(): 364 | if name.endswith('self_attn'): 365 | hooks.append( 366 | m.register_forward_hook( 367 | functools.partial(stat_input_hook, name=name)) 368 | ) 369 | if name.endswith('layers'): 370 | layer_index = 0 371 | for layer in m: 372 | hooks.append( 373 | layer.register_forward_hook( 374 | functools.partial(decoder_layer_stat_input_hook, name=name+"."+str(layer_index))) 375 | ) 376 | layer_index += 1 377 | 378 | dataset = load_dataset("json", data_files=dataset_path, split="train") 379 | dataset = dataset.shuffle(seed=42) 380 | 381 | inputs = tokenizer("\n\n".join(dataset['text'][:1000]), return_tensors='pt') 382 | forward_by_layer(model, inputs, num_samples, seq_len) 383 | 384 | for h in hooks: 385 | h.remove() 386 | 387 | for name in bias: 388 | bias[name] = bias[name] / num_samples 389 | 390 | return bias 391 | -------------------------------------------------------------------------------- /calibration/opt/generate_bias.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNK_SIZE=256 4 | 5 | # OPT - Linear 6 | for nsamples in 128;do 7 | for seqlen in 2048;do 8 | for size in 6.7b 13b 66b;do 9 | for bits in 4 8;do 10 | case ${size} in 11 | 6.7b) 12 | if [ ${bits} = 4 ]; then 13 | decomp=8 14 | elif [ ${bits} = 8 ]; 
then 15 | decomp=4 16 | fi 17 | ;; 18 | 13b) 19 | if [ ${bits} = 4 ]; then 20 | decomp=8 21 | elif [ ${bits} = 8 ]; then 22 | decomp=4 23 | fi 24 | ;; 25 | 66b) 26 | if [ ${bits} = 4 ]; then 27 | decomp=10 28 | elif [ ${bits} = 8 ]; then 29 | decomp=8 30 | fi 31 | ;; 32 | esac 33 | echo calibrating opt-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 34 | echo linear only 35 | python run_calibration.py \ 36 | --model-name "facebook/opt-${size}" \ 37 | --target "bias" \ 38 | --output-path "bias/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp.pt" \ 39 | --dataset-path "../../data/val.jsonl.zst" \ 40 | --num-samples ${nsamples} \ 41 | --seq-len ${seqlen} \ 42 | --q_bits ${bits} \ 43 | --decomp_factor ${decomp} \ 44 | --chunk_size ${CHUNK_SIZE} 45 | done 46 | done 47 | done 48 | done 49 | 50 | 51 | # OPT - Linear & MHA 52 | for nsamples in 128;do 53 | for seqlen in 2048;do 54 | for size in 6.7b;do 55 | for bits in 4 8;do 56 | case ${size} in 57 | 6.7b) 58 | if [ ${bits} = 4 ]; then 59 | decomp=8 60 | elif [ ${bits} = 8 ]; then 61 | decomp=4 62 | fi 63 | ;; 64 | 13b) 65 | if [ ${bits} = 4 ]; then 66 | decomp=8 67 | elif [ ${bits} = 8 ]; then 68 | decomp=4 69 | fi 70 | ;; 71 | 66b) 72 | if [ ${bits} = 4 ]; then 73 | decomp=10 74 | elif [ ${bits} = 8 ]; then 75 | decomp=8 76 | fi 77 | ;; 78 | esac 79 | echo calibrating opt-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 80 | echo linear + mha 81 | python run_calibration.py \ 82 | --model-name "facebook/opt-${size}" \ 83 | --target "bias" \ 84 | --output-path "bias/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp_mha.pt" \ 85 | --dataset-path "../../data/val.jsonl.zst" \ 86 | --num-samples ${nsamples} \ 87 | --seq-len ${seqlen} \ 88 | --q_bits ${bits} \ 89 | --decomp_factor ${decomp} \ 90 | --chunk_size ${CHUNK_SIZE} \ 91 | --quant_mha 92 | done 93 | done 94 | done 95 | done 96 | 97 | -------------------------------------------------------------------------------- /calibration/opt/generate_scale_factor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CHUNK_SIZE=256 3 | 4 | # OPT - Linear 5 | for nsamples in 128;do 6 | for seqlen in 2048;do 7 | for size in 6.7b 13b 66b;do 8 | for bits in 4 8;do 9 | case ${size} in 10 | 6.7b) 11 | if [ ${bits} = 4 ]; then 12 | decomp=8 13 | elif [ ${bits} = 8 ]; then 14 | decomp=4 15 | fi 16 | ;; 17 | 13b) 18 | if [ ${bits} = 4 ]; then 19 | decomp=8 20 | elif [ ${bits} = 8 ]; then 21 | decomp=4 22 | fi 23 | ;; 24 | 66b) 25 | if [ ${bits} = 4 ]; then 26 | decomp=10 27 | elif [ ${bits} = 8 ]; then 28 | decomp=8 29 | fi 30 | ;; 31 | esac 32 | echo calibrating opt-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 33 | echo linear only 34 | python run_calibration.py \ 35 | --model-name "facebook/opt-${size}" \ 36 | --target "scale" \ 37 | --output-path "scale/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp.pt" \ 38 | --dataset-path "../../data/val.jsonl.zst" \ 39 | --num-samples ${nsamples} \ 40 | --seq-len ${seqlen} \ 41 | --q_bits ${bits} \ 42 | --decomp_factor ${decomp} \ 43 | --chunk_size ${CHUNK_SIZE} 44 | done 45 | done 46 | done 47 | done 48 | 49 | # OPT - Linear & MHA 50 | for nsamples in 128;do 51 | for seqlen in 2048;do 52 | for size in 6.7b;do 53 | for bits in 4 8;do 54 | case ${size} in 55 | 6.7b) 56 | if [ ${bits} = 4 ]; then 57 | decomp=8 58 | elif [ ${bits} = 8 ]; then 59 | decomp=4 60 | fi 61 | ;; 62 | 13b) 63 | if [ ${bits} = 4 ]; then 64 | decomp=8 65 | elif [ ${bits} = 8 ]; then 66 | decomp=4 67 | fi 68 | ;; 69 | 66b) 70 | if [ 
${bits} = 4 ]; then 71 | decomp=10 72 | elif [ ${bits} = 8 ]; then 73 | decomp=8 74 | fi 75 | ;; 76 | esac 77 | echo calibrating opt-${size} ${bits}bit chunk size of ${CHUNK_SIZE} 78 | echo linear + mha 79 | python run_calibration.py \ 80 | --model-name "facebook/opt-${size}" \ 81 | --target "scale" \ 82 | --output-path "scale/${seqlen}_${size}_${nsamples}_${bits}bit_${decomp}decomp_mha.pt" \ 83 | --dataset-path "../../data/val.jsonl.zst" \ 84 | --num-samples ${nsamples} \ 85 | --seq-len ${seqlen} \ 86 | --q_bits ${bits} \ 87 | --decomp_factor ${decomp} \ 88 | --chunk_size ${CHUNK_SIZE} \ 89 | --quant_mha 90 | done 91 | done 92 | done 93 | done 94 | 95 | -------------------------------------------------------------------------------- /calibration/opt/run_calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | ) 8 | import argparse 9 | from calibration import get_scale_factor, get_bias 10 | 11 | 12 | def build_model_and_tokenizer(model_name, seq_len): 13 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=seq_len) 14 | if os.path.exists('../../models/modeling_opt.py'): 15 | os.system('rm ../../models/modeling_opt.py') 16 | cwd = os.getcwd() 17 | os.chdir('../../models/') 18 | os.system('ln -s modeling_opt_tender.py modeling_opt.py') 19 | os.chdir(cwd) 20 | 21 | kwargs = {"torch_dtype": torch.float16, "device_map": "cpu"} 22 | model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) 23 | return model, tokenizer 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--model-name', type=str, 28 | default='facebook/opt-1.3b', help='model name') 29 | parser.add_argument('--target', type=str, choices=['scale', 'bias'], required=True, 30 | help='Calibrate scale factor or bias') 31 | parser.add_argument('--output-path', type=str, default='scales/opt-1.3b.pt', 32 | help='where to save the result') 33 | parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst', 34 | help='location of the calibration dataset, we use the validation set of the Pile dataset') 35 | parser.add_argument('--num-samples', type=int, default=512) 36 | parser.add_argument('--seq-len', type=int, default=512) 37 | parser.add_argument('--q_bits', type=int, default=8, 38 | help='Number of bits for quantization') 39 | parser.add_argument('--decomp_factor', type=int, default=8, 40 | help='Number of groups for classification') 41 | parser.add_argument('--chunk_size', type=int, default=256, 42 | help='Size of row chunk') 43 | parser.add_argument('--quant_mha', action='store_true', 44 | help='Whether to quantize multi-head-attention') 45 | args = parser.parse_args() 46 | return args 47 | 48 | @torch.no_grad() 49 | def main(): 50 | args = parse_args() 51 | model, tokenizer = build_model_and_tokenizer(args.model_name, args.seq_len) 52 | 53 | for layer in model.model.decoder.layers: 54 | layer.self_attn.quant_mha = args.quant_mha 55 | 56 | layer.self_attn.q_bits = args.q_bits 57 | layer.q_bits = args.q_bits 58 | 59 | layer.self_attn.decomp_factor = args.decomp_factor 60 | layer.decomp_factor = args.decomp_factor 61 | 62 | layer.self_attn.chunk_size = args.chunk_size 63 | layer.chunk_size = args.chunk_size 64 | 65 | if args.target == 'scale': 66 | result = get_scale_factor(model, tokenizer, args.dataset_path, 67 | args.num_samples, args.seq_len, args.quant_mha) 68 | 69 | else: 70 | result = 
get_bias(model, tokenizer, args.dataset_path, 71 | args.num_samples, args.seq_len, args.quant_mha) 72 | 73 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 74 | torch.save(result, args.output_path) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /figures/tender-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snu-comparch/Tender/ed7f5d0e81c05b46f80bd896a83011b32869bd80/figures/tender-flow.png -------------------------------------------------------------------------------- /lm-eval.patch: -------------------------------------------------------------------------------- 1 | diff --git a/lm_eval/models/gpt2.py b/lm_eval/models/gpt2.py 2 | index 77fa415e..5ddef5a4 100644 3 | --- a/lm_eval/models/gpt2.py 4 | +++ b/lm_eval/models/gpt2.py 5 | @@ -15,6 +15,60 @@ def _get_dtype( 6 | _torch_dtype = dtype 7 | return _torch_dtype 8 | 9 | +def set_params_tender_opt(model): 10 | + q_bits = 4 11 | + decomp_factor = 8 12 | + chunk_size = 16 13 | + for layer in model.model.decoder.layers: 14 | + attn = layer.self_attn 15 | + 16 | + attn.quant_mha = True 17 | + 18 | + attn.q_bits = q_bits 19 | + layer.q_bits = q_bits 20 | + 21 | + attn.decomp_factor = decomp_factor 22 | + layer.decomp_factor = decomp_factor 23 | + 24 | + attn.chunk_size = chunk_size 25 | + layer.chunk_size = chunk_size 26 | + 27 | + attn.quant_out_bf16 = True 28 | + layer.quant_out_bf16 = True 29 | + 30 | + model.quant_lm_head = True 31 | + model.lm_head_tender.q_bits = q_bits 32 | + model.lm_head_tender.decomp_factor = decomp_factor 33 | + model.lm_head_tender.chunk_size = chunk_size 34 | + model.lm_head_tender.quant_out_bf16 = True 35 | + 36 | +def set_params_tender_llama(model): 37 | + q_bits = 4 38 | + decomp_factor = 14 39 | + chunk_size = 16 40 | + for layer in model.model.layers: 41 | + attn = layer.self_attn 42 | + mlp = layer.mlp 43 | + 44 | + attn.quant_mha = True 45 | + 46 | + attn.q_bits = q_bits 47 | + mlp.q_bits = q_bits 48 | + 49 | + attn.decomp_factor = decomp_factor 50 | + mlp.decomp_factor = decomp_factor 51 | + 52 | + attn.chunk_size = chunk_size 53 | + mlp.chunk_size = chunk_size 54 | + 55 | + attn.quant_out_bf16 = True 56 | + layer.quant_out_bf16 = True 57 | + 58 | + model.quant_lm_head = True 59 | + model.lm_head_tender.q_bits = q_bits 60 | + model.lm_head_tender.decomp_factor = decomp_factor 61 | + model.lm_head_tender.chunk_size = chunk_size 62 | + model.lm_head_tender.quant_out_bf16 = True 63 | 64 | class HFLM(BaseLM): 65 | 66 | @@ -34,6 +88,7 @@ class HFLM(BaseLM): 67 | load_in_8bit: Optional[bool] = False, 68 | trust_remote_code: Optional[bool] = False, 69 | dtype: Optional[Union[str, torch.dtype]]="auto", 70 | + scheme="base", 71 | ): 72 | super().__init__() 73 | 74 | @@ -90,11 +145,25 @@ class HFLM(BaseLM): 75 | torch_dtype=_get_dtype(dtype), 76 | trust_remote_code=trust_remote_code, 77 | ).to(self.device) 78 | - self.tokenizer = transformers.AutoTokenizer.from_pretrained( 79 | - tokenizer if tokenizer else pretrained, 80 | - revision=revision, 81 | - trust_remote_code=trust_remote_code, 82 | - ) 83 | + if scheme=="tender": 84 | + if 'opt' in pretrained: 85 | + set_params_tender_opt(self.model) 86 | + elif 'llama' in pretrained: 87 | + set_params_tender_llama(self.model) 88 | + 89 | + try: 90 | + self.tokenizer = transformers.AutoTokenizer.from_pretrained( 91 | + tokenizer if tokenizer else pretrained, 92 | + revision=revision, 93 | 
+ trust_remote_code=trust_remote_code, 94 | + ) 95 | + except: 96 | + from transformers import LlamaTokenizer 97 | + self.tokenizer = LlamaTokenizer.from_pretrained( 98 | + tokenizer if tokenizer else pretrained, 99 | + revision=revision, 100 | + trust_remote_code=trust_remote_code, 101 | + ) 102 | 103 | else: 104 | raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel') 105 | -------------------------------------------------------------------------------- /models/modeling_llama_orig.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch LLaMA model.""" 21 | import math 22 | from typing import List, Optional, Tuple, Union 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from torch import nn 28 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 29 | 30 | from ...activations import ACT2FN 31 | from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 32 | from ...modeling_utils import PreTrainedModel 33 | from ...pytorch_utils import ALL_LAYERNORM_LAYERS 34 | from ...utils import ( 35 | add_start_docstrings, 36 | add_start_docstrings_to_model_forward, 37 | is_flash_attn_available, 38 | logging, 39 | replace_return_docstrings, 40 | ) 41 | from .configuration_llama import LlamaConfig 42 | 43 | 44 | if is_flash_attn_available(): 45 | from flash_attn import flash_attn_func, flash_attn_varlen_func 46 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 47 | 48 | 49 | logger = logging.get_logger(__name__) 50 | 51 | _CONFIG_FOR_DOC = "LlamaConfig" 52 | 53 | 54 | def _get_unpad_data(padding_mask): 55 | seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) 56 | indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() 57 | max_seqlen_in_batch = seqlens_in_batch.max().item() 58 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 59 | return ( 60 | indices, 61 | cu_seqlens, 62 | max_seqlen_in_batch, 63 | ) 64 | 65 | 66 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 67 | def _make_causal_mask( 68 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 69 | ): 70 | """ 71 | Make causal mask used for bi-directional self-attention. 
72 | """ 73 | bsz, tgt_len = input_ids_shape 74 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 75 | mask_cond = torch.arange(mask.size(-1), device=device) 76 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 77 | mask = mask.to(dtype) 78 | 79 | if past_key_values_length > 0: 80 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 81 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 82 | 83 | 84 | # Copied from transformers.models.bart.modeling_bart._expand_mask 85 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 86 | """ 87 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 88 | """ 89 | bsz, src_len = mask.size() 90 | tgt_len = tgt_len if tgt_len is not None else src_len 91 | 92 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 93 | 94 | inverted_mask = 1.0 - expanded_mask 95 | 96 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 97 | 98 | 99 | class LlamaRMSNorm(nn.Module): 100 | def __init__(self, hidden_size, eps=1e-6): 101 | """ 102 | LlamaRMSNorm is equivalent to T5LayerNorm 103 | """ 104 | super().__init__() 105 | self.weight = nn.Parameter(torch.ones(hidden_size)) 106 | self.variance_epsilon = eps 107 | 108 | def forward(self, hidden_states): 109 | input_dtype = hidden_states.dtype 110 | hidden_states = hidden_states.to(torch.float32) 111 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 112 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 113 | return self.weight * hidden_states.to(input_dtype) 114 | 115 | 116 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) 117 | 118 | 119 | class LlamaRotaryEmbedding(nn.Module): 120 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 121 | super().__init__() 122 | 123 | self.dim = dim 124 | self.max_position_embeddings = max_position_embeddings 125 | self.base = base 126 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 127 | self.register_buffer("inv_freq", inv_freq, persistent=False) 128 | 129 | # Build here to make `torch.jit.trace` work. 130 | self._set_cos_sin_cache( 131 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 132 | ) 133 | 134 | def _set_cos_sin_cache(self, seq_len, device, dtype): 135 | self.max_seq_len_cached = seq_len 136 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 137 | 138 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 139 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 140 | emb = torch.cat((freqs, freqs), dim=-1) 141 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 142 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 143 | 144 | def forward(self, x, seq_len=None): 145 | # x: [bs, num_attention_heads, seq_len, head_size] 146 | if seq_len > self.max_seq_len_cached: 147 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 148 | 149 | return ( 150 | self.cos_cached[:seq_len].to(dtype=x.dtype), 151 | self.sin_cached[:seq_len].to(dtype=x.dtype), 152 | ) 153 | 154 | 155 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 156 | """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" 157 | 158 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 159 | self.scaling_factor = scaling_factor 160 | super().__init__(dim, max_position_embeddings, base, device) 161 | 162 | def _set_cos_sin_cache(self, seq_len, device, dtype): 163 | self.max_seq_len_cached = seq_len 164 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 165 | t = t / self.scaling_factor 166 | 167 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 168 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 169 | emb = torch.cat((freqs, freqs), dim=-1) 170 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 171 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 172 | 173 | 174 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 175 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 176 | 177 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 178 | self.scaling_factor = scaling_factor 179 | super().__init__(dim, max_position_embeddings, base, device) 180 | 181 | def _set_cos_sin_cache(self, seq_len, device, dtype): 182 | self.max_seq_len_cached = seq_len 183 | 184 | if seq_len > self.max_position_embeddings: 185 | base = self.base * ( 186 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 187 | ) ** (self.dim / (self.dim - 2)) 188 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 189 | self.register_buffer("inv_freq", inv_freq, persistent=False) 190 | 191 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 192 | 193 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 194 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 195 | emb = torch.cat((freqs, freqs), dim=-1) 196 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 197 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 198 | 199 | 200 | def rotate_half(x): 201 | """Rotates half the hidden dims of the input.""" 202 | x1 = x[..., : x.shape[-1] // 2] 203 | x2 = x[..., x.shape[-1] // 2 :] 204 | return torch.cat((-x2, x1), dim=-1) 205 | 206 | 207 | # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb 208 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 209 | cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] 210 | sin = sin[position_ids].unsqueeze(1) 211 | q_embed = (q * cos) + (rotate_half(q) * sin) 212 | k_embed = (k * cos) + (rotate_half(k) * sin) 213 | return q_embed, k_embed 214 | 215 | 216 | class LlamaMLP(nn.Module): 217 | def __init__(self, config): 218 | super().__init__() 219 | self.config = config 220 | self.hidden_size = config.hidden_size 221 | self.intermediate_size = config.intermediate_size 222 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 223 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 224 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 225 | self.act_fn = ACT2FN[config.hidden_act] 226 | 227 | def forward(self, x): 228 | if self.config.pretraining_tp > 1: 229 | slice = self.intermediate_size // 
self.config.pretraining_tp 230 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 231 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 232 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 233 | 234 | gate_proj = torch.cat( 235 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 236 | ) 237 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 238 | 239 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 240 | down_proj = [ 241 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 242 | ] 243 | down_proj = sum(down_proj) 244 | else: 245 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 246 | 247 | return down_proj 248 | 249 | 250 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 251 | """ 252 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 253 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 254 | """ 255 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 256 | if n_rep == 1: 257 | return hidden_states 258 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 259 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 260 | 261 | 262 | class LlamaAttention(nn.Module): 263 | """Multi-headed attention from 'Attention Is All You Need' paper""" 264 | 265 | def __init__(self, config: LlamaConfig): 266 | super().__init__() 267 | self.config = config 268 | self.hidden_size = config.hidden_size 269 | self.num_heads = config.num_attention_heads 270 | self.head_dim = self.hidden_size // self.num_heads 271 | self.num_key_value_heads = config.num_key_value_heads 272 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 273 | self.max_position_embeddings = config.max_position_embeddings 274 | self.rope_theta = config.rope_theta 275 | 276 | if (self.head_dim * self.num_heads) != self.hidden_size: 277 | raise ValueError( 278 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 279 | f" and `num_heads`: {self.num_heads})." 
280 | ) 281 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 282 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 283 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 284 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) 285 | self._init_rope() 286 | 287 | def _init_rope(self): 288 | if self.config.rope_scaling is None: 289 | self.rotary_emb = LlamaRotaryEmbedding( 290 | self.head_dim, 291 | max_position_embeddings=self.max_position_embeddings, 292 | base=self.rope_theta, 293 | ) 294 | else: 295 | scaling_type = self.config.rope_scaling["type"] 296 | scaling_factor = self.config.rope_scaling["factor"] 297 | if scaling_type == "linear": 298 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 299 | self.head_dim, 300 | max_position_embeddings=self.max_position_embeddings, 301 | scaling_factor=scaling_factor, 302 | base=self.rope_theta, 303 | ) 304 | elif scaling_type == "dynamic": 305 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 306 | self.head_dim, 307 | max_position_embeddings=self.max_position_embeddings, 308 | scaling_factor=scaling_factor, 309 | base=self.rope_theta, 310 | ) 311 | else: 312 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 313 | 314 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 315 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 316 | 317 | def forward( 318 | self, 319 | hidden_states: torch.Tensor, 320 | attention_mask: Optional[torch.Tensor] = None, 321 | position_ids: Optional[torch.LongTensor] = None, 322 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 323 | output_attentions: bool = False, 324 | use_cache: bool = False, 325 | padding_mask: Optional[torch.LongTensor] = None, 326 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 327 | bsz, q_len, _ = hidden_states.size() 328 | 329 | if self.config.pretraining_tp > 1: 330 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 331 | query_slices = self.q_proj.weight.split( 332 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 333 | ) 334 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 335 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 336 | 337 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 338 | query_states = torch.cat(query_states, dim=-1) 339 | 340 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 341 | key_states = torch.cat(key_states, dim=-1) 342 | 343 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 344 | value_states = torch.cat(value_states, dim=-1) 345 | 346 | else: 347 | query_states = self.q_proj(hidden_states) 348 | key_states = self.k_proj(hidden_states) 349 | value_states = self.v_proj(hidden_states) 350 | 351 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 352 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 353 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 354 | 355 | kv_seq_len = key_states.shape[-2] 356 | if 
past_key_value is not None: 357 | kv_seq_len += past_key_value[0].shape[-2] 358 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 359 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 360 | 361 | if past_key_value is not None: 362 | # reuse k, v, self_attention 363 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 364 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 365 | 366 | past_key_value = (key_states, value_states) if use_cache else None 367 | 368 | key_states = repeat_kv(key_states, self.num_key_value_groups) 369 | value_states = repeat_kv(value_states, self.num_key_value_groups) 370 | 371 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 372 | 373 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 374 | raise ValueError( 375 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 376 | f" {attn_weights.size()}" 377 | ) 378 | 379 | if attention_mask is not None: 380 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 381 | raise ValueError( 382 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 383 | ) 384 | attn_weights = attn_weights + attention_mask 385 | 386 | # upcast attention to fp32 387 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 388 | attn_output = torch.matmul(attn_weights, value_states) 389 | 390 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 391 | raise ValueError( 392 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 393 | f" {attn_output.size()}" 394 | ) 395 | 396 | attn_output = attn_output.transpose(1, 2).contiguous() 397 | 398 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 399 | 400 | if self.config.pretraining_tp > 1: 401 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 402 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 403 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 404 | else: 405 | attn_output = self.o_proj(attn_output) 406 | 407 | if not output_attentions: 408 | attn_weights = None 409 | 410 | return attn_output, attn_weights, past_key_value 411 | 412 | 413 | class LlamaFlashAttention2(LlamaAttention): 414 | """ 415 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays 416 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 417 | flash attention and deal with padding tokens in case the input contains any of them. 
418 | """ 419 | 420 | def forward( 421 | self, 422 | hidden_states: torch.Tensor, 423 | attention_mask: Optional[torch.Tensor] = None, 424 | position_ids: Optional[torch.LongTensor] = None, 425 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 426 | output_attentions: bool = False, 427 | use_cache: bool = False, 428 | padding_mask: Optional[torch.LongTensor] = None, 429 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 430 | # LlamaFlashAttention2 attention does not support output_attentions 431 | output_attentions = False 432 | 433 | bsz, q_len, _ = hidden_states.size() 434 | 435 | query_states = self.q_proj(hidden_states) 436 | key_states = self.k_proj(hidden_states) 437 | value_states = self.v_proj(hidden_states) 438 | 439 | # Flash attention requires the input to have the shape 440 | # batch_size x seq_length x head_dime x hidden_dim 441 | # therefore we just need to keep the original shape 442 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 443 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 444 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 445 | 446 | kv_seq_len = key_states.shape[-2] 447 | if past_key_value is not None: 448 | kv_seq_len += past_key_value[0].shape[-2] 449 | 450 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 451 | 452 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 453 | 454 | if past_key_value is not None: 455 | # reuse k, v, self_attention 456 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 457 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 458 | 459 | past_key_value = (key_states, value_states) if use_cache else None 460 | 461 | query_states = query_states.transpose(1, 2) 462 | key_states = key_states.transpose(1, 2) 463 | value_states = value_states.transpose(1, 2) 464 | 465 | # TODO: llama does not have dropout in the config?? 466 | # It is recommended to use dropout with FA according to the docs 467 | # when training. 468 | dropout_rate = 0.0 # if not self.training else self.attn_dropout 469 | 470 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 471 | # therefore the input hidden states gets silently casted in float32. Hence, we need 472 | # cast them back in float16 just to be sure everything works as expected. 473 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 474 | # in fp32. (LlamaRMSNorm handles it correctly) 475 | input_dtype = query_states.dtype 476 | if input_dtype == torch.float32: 477 | logger.warning_once( 478 | "The input hidden states seems to be silently casted in float32, this might be related to" 479 | " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 480 | " float16." 
481 | ) 482 | 483 | query_states = query_states.to(torch.float16) 484 | key_states = key_states.to(torch.float16) 485 | value_states = value_states.to(torch.float16) 486 | 487 | attn_output = self._flash_attention_forward( 488 | query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate 489 | ) 490 | 491 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 492 | attn_output = self.o_proj(attn_output) 493 | 494 | if not output_attentions: 495 | attn_weights = None 496 | 497 | return attn_output, attn_weights, past_key_value 498 | 499 | def _flash_attention_forward( 500 | self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None 501 | ): 502 | """ 503 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 504 | first unpad the input, then computes the attention scores and pad the final attention scores. 505 | 506 | Args: 507 | query_states (`torch.Tensor`): 508 | Input query states to be passed to Flash Attention API 509 | key_states (`torch.Tensor`): 510 | Input key states to be passed to Flash Attention API 511 | value_states (`torch.Tensor`): 512 | Input value states to be passed to Flash Attention API 513 | padding_mask (`torch.Tensor`): 514 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 515 | position of padding tokens and 1 for the position of non-padding tokens. 516 | dropout (`int`, *optional*): 517 | Attention dropout 518 | softmax_scale (`float`, *optional*): 519 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 520 | """ 521 | # Contains at least one padding token in the sequence 522 | if padding_mask is not None: 523 | batch_size = query_states.shape[0] 524 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 525 | query_states, key_states, value_states, padding_mask, query_length 526 | ) 527 | 528 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 529 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 530 | 531 | attn_output_unpad = flash_attn_varlen_func( 532 | query_states, 533 | key_states, 534 | value_states, 535 | cu_seqlens_q=cu_seqlens_q, 536 | cu_seqlens_k=cu_seqlens_k, 537 | max_seqlen_q=max_seqlen_in_batch_q, 538 | max_seqlen_k=max_seqlen_in_batch_k, 539 | dropout_p=dropout, 540 | softmax_scale=softmax_scale, 541 | causal=True, 542 | ) 543 | 544 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 545 | else: 546 | attn_output = flash_attn_func( 547 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True 548 | ) 549 | 550 | return attn_output 551 | 552 | def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): 553 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) 554 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 555 | 556 | key_layer = index_first_axis( 557 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 558 | ) 559 | value_layer = index_first_axis( 560 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 561 | ) 562 | if query_length == kv_seq_len: 563 | query_layer = index_first_axis( 564 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k 565 | ) 566 | cu_seqlens_q = cu_seqlens_k 567 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 
568 | indices_q = indices_k 569 | elif query_length == 1: 570 | max_seqlen_in_batch_q = 1 571 | cu_seqlens_q = torch.arange( 572 | batch_size + 1, dtype=torch.int32, device=query_layer.device 573 | ) # There is a memcpy here, that is very bad. 574 | indices_q = cu_seqlens_q[:-1] 575 | query_layer = query_layer.squeeze(1) 576 | else: 577 | # The -q_len: slice assumes left padding. 578 | padding_mask = padding_mask[:, -query_length:] 579 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) 580 | 581 | return ( 582 | query_layer, 583 | key_layer, 584 | value_layer, 585 | indices_q, 586 | (cu_seqlens_q, cu_seqlens_k), 587 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 588 | ) 589 | 590 | 591 | class LlamaDecoderLayer(nn.Module): 592 | def __init__(self, config: LlamaConfig): 593 | super().__init__() 594 | self.hidden_size = config.hidden_size 595 | self.self_attn = ( 596 | LlamaAttention(config=config) 597 | if not getattr(config, "_flash_attn_2_enabled", False) 598 | else LlamaFlashAttention2(config=config) 599 | ) 600 | self.mlp = LlamaMLP(config) 601 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 602 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 603 | 604 | def forward( 605 | self, 606 | hidden_states: torch.Tensor, 607 | attention_mask: Optional[torch.Tensor] = None, 608 | position_ids: Optional[torch.LongTensor] = None, 609 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 610 | output_attentions: Optional[bool] = False, 611 | use_cache: Optional[bool] = False, 612 | padding_mask: Optional[torch.LongTensor] = None, 613 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 614 | """ 615 | Args: 616 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 617 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 618 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 619 | output_attentions (`bool`, *optional*): 620 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 621 | returned tensors for more detail. 622 | use_cache (`bool`, *optional*): 623 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 624 | (see `past_key_values`). 
625 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 626 | """ 627 | 628 | residual = hidden_states 629 | 630 | hidden_states = self.input_layernorm(hidden_states) 631 | 632 | # Self Attention 633 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 634 | hidden_states=hidden_states, 635 | attention_mask=attention_mask, 636 | position_ids=position_ids, 637 | past_key_value=past_key_value, 638 | output_attentions=output_attentions, 639 | use_cache=use_cache, 640 | padding_mask=padding_mask, 641 | ) 642 | hidden_states = residual + hidden_states 643 | 644 | # Fully Connected 645 | residual = hidden_states 646 | hidden_states = self.post_attention_layernorm(hidden_states) 647 | hidden_states = self.mlp(hidden_states) 648 | hidden_states = residual + hidden_states 649 | 650 | outputs = (hidden_states,) 651 | 652 | if output_attentions: 653 | outputs += (self_attn_weights,) 654 | 655 | if use_cache: 656 | outputs += (present_key_value,) 657 | 658 | return outputs 659 | 660 | 661 | LLAMA_START_DOCSTRING = r""" 662 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 663 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 664 | etc.) 665 | 666 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 667 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 668 | and behavior. 669 | 670 | Parameters: 671 | config ([`LlamaConfig`]): 672 | Model configuration class with all the parameters of the model. Initializing with a config file does not 673 | load the weights associated with the model, only the configuration. Check out the 674 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 675 | """ 676 | 677 | 678 | @add_start_docstrings( 679 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 680 | LLAMA_START_DOCSTRING, 681 | ) 682 | class LlamaPreTrainedModel(PreTrainedModel): 683 | config_class = LlamaConfig 684 | base_model_prefix = "model" 685 | supports_gradient_checkpointing = True 686 | _no_split_modules = ["LlamaDecoderLayer"] 687 | _skip_keys_device_placement = "past_key_values" 688 | _supports_flash_attn_2 = True 689 | 690 | def _init_weights(self, module): 691 | std = self.config.initializer_range 692 | if isinstance(module, nn.Linear): 693 | module.weight.data.normal_(mean=0.0, std=std) 694 | if module.bias is not None: 695 | module.bias.data.zero_() 696 | elif isinstance(module, nn.Embedding): 697 | module.weight.data.normal_(mean=0.0, std=std) 698 | if module.padding_idx is not None: 699 | module.weight.data[module.padding_idx].zero_() 700 | 701 | def _set_gradient_checkpointing(self, module, value=False): 702 | if isinstance(module, LlamaModel): 703 | module.gradient_checkpointing = value 704 | 705 | 706 | LLAMA_INPUTS_DOCSTRING = r""" 707 | Args: 708 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 709 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 710 | it. 711 | 712 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 713 | [`PreTrainedTokenizer.__call__`] for details. 
714 | 715 | [What are input IDs?](../glossary#input-ids) 716 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 717 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 718 | 719 | - 1 for tokens that are **not masked**, 720 | - 0 for tokens that are **masked**. 721 | 722 | [What are attention masks?](../glossary#attention-mask) 723 | 724 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 725 | [`PreTrainedTokenizer.__call__`] for details. 726 | 727 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 728 | `past_key_values`). 729 | 730 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 731 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 732 | information on the default strategy. 733 | 734 | - 1 indicates the head is **not masked**, 735 | - 0 indicates the head is **masked**. 736 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 737 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 738 | config.n_positions - 1]`. 739 | 740 | [What are position IDs?](../glossary#position-ids) 741 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 742 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 743 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 744 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 745 | 746 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 747 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 748 | 749 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 750 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 751 | of shape `(batch_size, sequence_length)`. 752 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 753 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 754 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 755 | model's internal embedding lookup matrix. 756 | use_cache (`bool`, *optional*): 757 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 758 | `past_key_values`). 759 | output_attentions (`bool`, *optional*): 760 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 761 | tensors for more detail. 762 | output_hidden_states (`bool`, *optional*): 763 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 764 | more detail. 765 | return_dict (`bool`, *optional*): 766 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
767 | """ 768 | 769 | 770 | @add_start_docstrings( 771 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 772 | LLAMA_START_DOCSTRING, 773 | ) 774 | class LlamaModel(LlamaPreTrainedModel): 775 | """ 776 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 777 | 778 | Args: 779 | config: LlamaConfig 780 | """ 781 | 782 | def __init__(self, config: LlamaConfig): 783 | super().__init__(config) 784 | self.padding_idx = config.pad_token_id 785 | self.vocab_size = config.vocab_size 786 | 787 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 788 | self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 789 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 790 | 791 | self.gradient_checkpointing = False 792 | # Initialize weights and apply final processing 793 | self.post_init() 794 | 795 | def get_input_embeddings(self): 796 | return self.embed_tokens 797 | 798 | def set_input_embeddings(self, value): 799 | self.embed_tokens = value 800 | 801 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 802 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 803 | # create causal mask 804 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 805 | combined_attention_mask = None 806 | if input_shape[-1] > 1: 807 | combined_attention_mask = _make_causal_mask( 808 | input_shape, 809 | inputs_embeds.dtype, 810 | device=inputs_embeds.device, 811 | past_key_values_length=past_key_values_length, 812 | ) 813 | 814 | if attention_mask is not None: 815 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 816 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 817 | inputs_embeds.device 818 | ) 819 | combined_attention_mask = ( 820 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 821 | ) 822 | 823 | return combined_attention_mask 824 | 825 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 826 | def forward( 827 | self, 828 | input_ids: torch.LongTensor = None, 829 | attention_mask: Optional[torch.Tensor] = None, 830 | position_ids: Optional[torch.LongTensor] = None, 831 | past_key_values: Optional[List[torch.FloatTensor]] = None, 832 | inputs_embeds: Optional[torch.FloatTensor] = None, 833 | use_cache: Optional[bool] = None, 834 | output_attentions: Optional[bool] = None, 835 | output_hidden_states: Optional[bool] = None, 836 | return_dict: Optional[bool] = None, 837 | ) -> Union[Tuple, BaseModelOutputWithPast]: 838 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 839 | output_hidden_states = ( 840 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 841 | ) 842 | use_cache = use_cache if use_cache is not None else self.config.use_cache 843 | 844 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 845 | 846 | # retrieve input_ids and inputs_embeds 847 | if input_ids is not None and inputs_embeds is not None: 848 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 849 | elif input_ids is not None: 850 | batch_size, seq_length = input_ids.shape 851 | elif inputs_embeds is not None: 852 | batch_size, seq_length, _ = 
inputs_embeds.shape 853 | else: 854 | raise ValueError("You have to specify either input_ids or inputs_embeds") 855 | 856 | seq_length_with_past = seq_length 857 | past_key_values_length = 0 858 | 859 | if past_key_values is not None: 860 | past_key_values_length = past_key_values[0][0].shape[2] 861 | seq_length_with_past = seq_length_with_past + past_key_values_length 862 | 863 | if position_ids is None: 864 | device = input_ids.device if input_ids is not None else inputs_embeds.device 865 | position_ids = torch.arange( 866 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 867 | ) 868 | position_ids = position_ids.unsqueeze(0) 869 | 870 | if inputs_embeds is None: 871 | inputs_embeds = self.embed_tokens(input_ids) 872 | # embed positions 873 | if attention_mask is None: 874 | attention_mask = torch.ones( 875 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 876 | ) 877 | padding_mask = None 878 | else: 879 | if 0 in attention_mask: 880 | padding_mask = attention_mask 881 | else: 882 | padding_mask = None 883 | 884 | attention_mask = self._prepare_decoder_attention_mask( 885 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 886 | ) 887 | 888 | hidden_states = inputs_embeds 889 | 890 | if self.gradient_checkpointing and self.training: 891 | if use_cache: 892 | logger.warning_once( 893 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 894 | ) 895 | use_cache = False 896 | 897 | # decoder layers 898 | all_hidden_states = () if output_hidden_states else None 899 | all_self_attns = () if output_attentions else None 900 | next_decoder_cache = () if use_cache else None 901 | 902 | for idx, decoder_layer in enumerate(self.layers): 903 | if output_hidden_states: 904 | all_hidden_states += (hidden_states,) 905 | 906 | past_key_value = past_key_values[idx] if past_key_values is not None else None 907 | 908 | if self.gradient_checkpointing and self.training: 909 | 910 | def create_custom_forward(module): 911 | def custom_forward(*inputs): 912 | # None for past_key_value 913 | return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) 914 | 915 | return custom_forward 916 | 917 | layer_outputs = torch.utils.checkpoint.checkpoint( 918 | create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids 919 | ) 920 | else: 921 | layer_outputs = decoder_layer( 922 | hidden_states, 923 | attention_mask=attention_mask, 924 | position_ids=position_ids, 925 | past_key_value=past_key_value, 926 | output_attentions=output_attentions, 927 | use_cache=use_cache, 928 | padding_mask=padding_mask, 929 | ) 930 | 931 | hidden_states = layer_outputs[0] 932 | 933 | if use_cache: 934 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 935 | 936 | if output_attentions: 937 | all_self_attns += (layer_outputs[1],) 938 | 939 | hidden_states = self.norm(hidden_states) 940 | 941 | # add hidden states from the last decoder layer 942 | if output_hidden_states: 943 | all_hidden_states += (hidden_states,) 944 | 945 | next_cache = next_decoder_cache if use_cache else None 946 | if not return_dict: 947 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 948 | return BaseModelOutputWithPast( 949 | last_hidden_state=hidden_states, 950 | past_key_values=next_cache, 951 | hidden_states=all_hidden_states, 952 | attentions=all_self_attns, 953 | ) 954 | 955 | 956 | class 
LlamaForCausalLM(LlamaPreTrainedModel): 957 | _tied_weights_keys = ["lm_head.weight"] 958 | 959 | def __init__(self, config): 960 | super().__init__(config) 961 | self.model = LlamaModel(config) 962 | self.vocab_size = config.vocab_size 963 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 964 | 965 | # Initialize weights and apply final processing 966 | self.post_init() 967 | 968 | def get_input_embeddings(self): 969 | return self.model.embed_tokens 970 | 971 | def set_input_embeddings(self, value): 972 | self.model.embed_tokens = value 973 | 974 | def get_output_embeddings(self): 975 | return self.lm_head 976 | 977 | def set_output_embeddings(self, new_embeddings): 978 | self.lm_head = new_embeddings 979 | 980 | def set_decoder(self, decoder): 981 | self.model = decoder 982 | 983 | def get_decoder(self): 984 | return self.model 985 | 986 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 987 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 988 | def forward( 989 | self, 990 | input_ids: torch.LongTensor = None, 991 | attention_mask: Optional[torch.Tensor] = None, 992 | position_ids: Optional[torch.LongTensor] = None, 993 | past_key_values: Optional[List[torch.FloatTensor]] = None, 994 | inputs_embeds: Optional[torch.FloatTensor] = None, 995 | labels: Optional[torch.LongTensor] = None, 996 | use_cache: Optional[bool] = None, 997 | output_attentions: Optional[bool] = None, 998 | output_hidden_states: Optional[bool] = None, 999 | return_dict: Optional[bool] = None, 1000 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1001 | r""" 1002 | Args: 1003 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1004 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1005 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1006 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1007 | 1008 | Returns: 1009 | 1010 | Example: 1011 | 1012 | ```python 1013 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 1014 | 1015 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1016 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1017 | 1018 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1019 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1020 | 1021 | >>> # Generate 1022 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1023 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1024 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1025 | ```""" 1026 | 1027 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1028 | output_hidden_states = ( 1029 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1030 | ) 1031 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1032 | 1033 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1034 | outputs = self.model( 1035 | input_ids=input_ids, 1036 | attention_mask=attention_mask, 1037 | position_ids=position_ids, 1038 | past_key_values=past_key_values, 1039 | inputs_embeds=inputs_embeds, 1040 | use_cache=use_cache, 1041 | output_attentions=output_attentions, 1042 | output_hidden_states=output_hidden_states, 1043 | return_dict=return_dict, 1044 | ) 1045 | 1046 | hidden_states = outputs[0] 1047 | if self.config.pretraining_tp > 1: 1048 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 1049 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 1050 | logits = torch.cat(logits, dim=-1) 1051 | else: 1052 | logits = self.lm_head(hidden_states) 1053 | logits = logits.float() 1054 | 1055 | loss = None 1056 | if labels is not None: 1057 | # Shift so that tokens < n predict n 1058 | shift_logits = logits[..., :-1, :].contiguous() 1059 | shift_labels = labels[..., 1:].contiguous() 1060 | # Flatten the tokens 1061 | loss_fct = CrossEntropyLoss() 1062 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1063 | shift_labels = shift_labels.view(-1) 1064 | # Enable model parallelism 1065 | shift_labels = shift_labels.to(shift_logits.device) 1066 | loss = loss_fct(shift_logits, shift_labels) 1067 | 1068 | if not return_dict: 1069 | output = (logits,) + outputs[1:] 1070 | return (loss,) + output if loss is not None else output 1071 | 1072 | return CausalLMOutputWithPast( 1073 | loss=loss, 1074 | logits=logits, 1075 | past_key_values=outputs.past_key_values, 1076 | hidden_states=outputs.hidden_states, 1077 | attentions=outputs.attentions, 1078 | ) 1079 | 1080 | def prepare_inputs_for_generation( 1081 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1082 | ): 1083 | if past_key_values: 1084 | input_ids = input_ids[:, -1:] 1085 | 1086 | position_ids = kwargs.get("position_ids", None) 1087 | if attention_mask is not None and position_ids is None: 1088 | # create position_ids on the fly for batch generation 1089 | position_ids = attention_mask.long().cumsum(-1) - 1 1090 | position_ids.masked_fill_(attention_mask == 0, 1) 1091 | if past_key_values: 1092 | position_ids = position_ids[:, -1].unsqueeze(-1) 1093 | 1094 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1095 | if inputs_embeds is not None and past_key_values is None: 1096 | model_inputs = {"inputs_embeds": inputs_embeds} 1097 | else: 1098 | model_inputs = {"input_ids": input_ids} 1099 | 1100 | model_inputs.update( 1101 | { 1102 | "position_ids": position_ids, 1103 | "past_key_values": past_key_values, 1104 | "use_cache": kwargs.get("use_cache"), 1105 | "attention_mask": attention_mask, 1106 | } 1107 | ) 1108 | return model_inputs 1109 | 1110 | @staticmethod 1111 | def _reorder_cache(past_key_values, beam_idx): 1112 | reordered_past = () 1113 | for layer_past in past_key_values: 1114 | reordered_past += ( 1115 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in 
layer_past), 1116 | ) 1117 | return reordered_past 1118 | 1119 | 1120 | @add_start_docstrings( 1121 | """ 1122 | The LLaMa Model transformer with a sequence classification head on top (linear layer). 1123 | 1124 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1125 | (e.g. GPT-2) do. 1126 | 1127 | Since it does classification on the last token, it requires to know the position of the last token. If a 1128 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1129 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1130 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1131 | each row of the batch). 1132 | """, 1133 | LLAMA_START_DOCSTRING, 1134 | ) 1135 | class LlamaForSequenceClassification(LlamaPreTrainedModel): 1136 | def __init__(self, config): 1137 | super().__init__(config) 1138 | self.num_labels = config.num_labels 1139 | self.model = LlamaModel(config) 1140 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1141 | 1142 | # Initialize weights and apply final processing 1143 | self.post_init() 1144 | 1145 | def get_input_embeddings(self): 1146 | return self.model.embed_tokens 1147 | 1148 | def set_input_embeddings(self, value): 1149 | self.model.embed_tokens = value 1150 | 1151 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1152 | def forward( 1153 | self, 1154 | input_ids: torch.LongTensor = None, 1155 | attention_mask: Optional[torch.Tensor] = None, 1156 | position_ids: Optional[torch.LongTensor] = None, 1157 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1158 | inputs_embeds: Optional[torch.FloatTensor] = None, 1159 | labels: Optional[torch.LongTensor] = None, 1160 | use_cache: Optional[bool] = None, 1161 | output_attentions: Optional[bool] = None, 1162 | output_hidden_states: Optional[bool] = None, 1163 | return_dict: Optional[bool] = None, 1164 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1165 | r""" 1166 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1167 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1168 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1169 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1170 | """ 1171 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1172 | 1173 | transformer_outputs = self.model( 1174 | input_ids, 1175 | attention_mask=attention_mask, 1176 | position_ids=position_ids, 1177 | past_key_values=past_key_values, 1178 | inputs_embeds=inputs_embeds, 1179 | use_cache=use_cache, 1180 | output_attentions=output_attentions, 1181 | output_hidden_states=output_hidden_states, 1182 | return_dict=return_dict, 1183 | ) 1184 | hidden_states = transformer_outputs[0] 1185 | logits = self.score(hidden_states) 1186 | 1187 | if input_ids is not None: 1188 | batch_size = input_ids.shape[0] 1189 | else: 1190 | batch_size = inputs_embeds.shape[0] 1191 | 1192 | if self.config.pad_token_id is None and batch_size != 1: 1193 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 1194 | if self.config.pad_token_id is None: 1195 | sequence_lengths = -1 1196 | else: 1197 | if input_ids is not None: 1198 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( 1199 | logits.device 1200 | ) 1201 | else: 1202 | sequence_lengths = -1 1203 | 1204 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1205 | 1206 | loss = None 1207 | if labels is not None: 1208 | labels = labels.to(logits.device) 1209 | if self.config.problem_type is None: 1210 | if self.num_labels == 1: 1211 | self.config.problem_type = "regression" 1212 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1213 | self.config.problem_type = "single_label_classification" 1214 | else: 1215 | self.config.problem_type = "multi_label_classification" 1216 | 1217 | if self.config.problem_type == "regression": 1218 | loss_fct = MSELoss() 1219 | if self.num_labels == 1: 1220 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1221 | else: 1222 | loss = loss_fct(pooled_logits, labels) 1223 | elif self.config.problem_type == "single_label_classification": 1224 | loss_fct = CrossEntropyLoss() 1225 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1226 | elif self.config.problem_type == "multi_label_classification": 1227 | loss_fct = BCEWithLogitsLoss() 1228 | loss = loss_fct(pooled_logits, labels) 1229 | if not return_dict: 1230 | output = (pooled_logits,) + transformer_outputs[1:] 1231 | return ((loss,) + output) if loss is not None else output 1232 | 1233 | return SequenceClassifierOutputWithPast( 1234 | loss=loss, 1235 | logits=pooled_logits, 1236 | past_key_values=transformer_outputs.past_key_values, 1237 | hidden_states=transformer_outputs.hidden_states, 1238 | attentions=transformer_outputs.attentions, 1239 | ) 1240 | -------------------------------------------------------------------------------- /models/modeling_opt_orig.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch OPT model.""" 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.utils.checkpoint 20 | from torch import nn 21 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 22 | 23 | from ...activations import ACT2FN 24 | from ...modeling_outputs import ( 25 | BaseModelOutputWithPast, 26 | CausalLMOutputWithPast, 27 | QuestionAnsweringModelOutput, 28 | SequenceClassifierOutputWithPast, 29 | ) 30 | from ...modeling_utils import PreTrainedModel 31 | from ...utils import ( 32 | add_code_sample_docstrings, 33 | add_start_docstrings, 34 | add_start_docstrings_to_model_forward, 35 | logging, 36 | replace_return_docstrings, 37 | ) 38 | from .configuration_opt import OPTConfig 39 | 40 | 41 | logger = logging.get_logger(__name__) 42 | 43 | _CHECKPOINT_FOR_DOC = "facebook/opt-350m" 44 | _CONFIG_FOR_DOC = "OPTConfig" 45 | 46 | # Base model docstring 47 | _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] 48 | 49 | # SequenceClassification docstring 50 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" 51 | _SEQ_CLASS_EXPECTED_LOSS = 1.71 52 | _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" 53 | 54 | OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 55 | "facebook/opt-125m", 56 | "facebook/opt-350m", 57 | "facebook/opt-1.3b", 58 | "facebook/opt-2.7b", 59 | "facebook/opt-6.7b", 60 | "facebook/opt-13b", 61 | "facebook/opt-30b", 62 | # See all OPT models at https://huggingface.co/models?filter=opt 63 | ] 64 | 65 | 66 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 67 | def _make_causal_mask( 68 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 69 | ): 70 | """ 71 | Make causal mask used for bi-directional self-attention. 72 | """ 73 | bsz, tgt_len = input_ids_shape 74 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 75 | mask_cond = torch.arange(mask.size(-1), device=device) 76 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 77 | mask = mask.to(dtype) 78 | 79 | if past_key_values_length > 0: 80 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 81 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 82 | 83 | 84 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 85 | """ 86 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 87 | """ 88 | bsz, src_len = mask.size() 89 | tgt_len = tgt_len if tgt_len is not None else src_len 90 | 91 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 92 | 93 | inverted_mask = 1.0 - expanded_mask 94 | 95 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 96 | 97 | 98 | class OPTLearnedPositionalEmbedding(nn.Embedding): 99 | """ 100 | This module learns positional embeddings up to a fixed maximum size. 101 | """ 102 | 103 | def __init__(self, num_embeddings: int, embedding_dim: int): 104 | # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 105 | # and adjust num_embeddings appropriately. 
Other models don't have this hack 106 | self.offset = 2 107 | super().__init__(num_embeddings + self.offset, embedding_dim) 108 | 109 | def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): 110 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 111 | attention_mask = attention_mask.long() 112 | 113 | # create positions depending on attention_mask 114 | positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 115 | 116 | # cut positions if `past_key_values_length` is > 0 117 | positions = positions[:, past_key_values_length:] 118 | 119 | return super().forward(positions + self.offset) 120 | 121 | 122 | class OPTAttention(nn.Module): 123 | """Multi-headed attention from 'Attention Is All You Need' paper""" 124 | 125 | def __init__( 126 | self, 127 | embed_dim: int, 128 | num_heads: int, 129 | dropout: float = 0.0, 130 | is_decoder: bool = False, 131 | bias: bool = True, 132 | ): 133 | super().__init__() 134 | self.embed_dim = embed_dim 135 | self.num_heads = num_heads 136 | self.dropout = dropout 137 | self.head_dim = embed_dim // num_heads 138 | 139 | if (self.head_dim * num_heads) != self.embed_dim: 140 | raise ValueError( 141 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 142 | f" and `num_heads`: {num_heads})." 143 | ) 144 | self.scaling = self.head_dim**-0.5 145 | self.is_decoder = is_decoder 146 | 147 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 148 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 149 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 150 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 151 | 152 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 153 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 154 | 155 | def forward( 156 | self, 157 | hidden_states: torch.Tensor, 158 | key_value_states: Optional[torch.Tensor] = None, 159 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 160 | attention_mask: Optional[torch.Tensor] = None, 161 | layer_head_mask: Optional[torch.Tensor] = None, 162 | output_attentions: bool = False, 163 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 164 | """Input shape: Batch x Time x Channel""" 165 | 166 | # if key_value_states are provided this layer is used as a cross-attention layer 167 | # for the decoder 168 | is_cross_attention = key_value_states is not None 169 | 170 | bsz, tgt_len, _ = hidden_states.size() 171 | 172 | # get query proj 173 | query_states = self.q_proj(hidden_states) * self.scaling 174 | # get key, value proj 175 | if is_cross_attention and past_key_value is not None: 176 | # reuse k,v, cross_attentions 177 | key_states = past_key_value[0] 178 | value_states = past_key_value[1] 179 | elif is_cross_attention: 180 | # cross_attentions 181 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 182 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 183 | elif past_key_value is not None: 184 | # reuse k, v, self_attention 185 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 186 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 187 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 188 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 189 | else: 190 | # self_attention 191 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 192 | value_states = 
self._shape(self.v_proj(hidden_states), -1, bsz) 193 | 194 | if self.is_decoder: 195 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 196 | # Further calls to cross_attention layer can then reuse all cross-attention 197 | # key/value_states (first "if" case) 198 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 199 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 200 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 201 | # if encoder bi-directional self-attention `past_key_value` is always `None` 202 | past_key_value = (key_states, value_states) 203 | 204 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 205 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 206 | key_states = key_states.view(*proj_shape) 207 | value_states = value_states.view(*proj_shape) 208 | 209 | src_len = key_states.size(1) 210 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 211 | 212 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 213 | raise ValueError( 214 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 215 | f" {attn_weights.size()}" 216 | ) 217 | 218 | if attention_mask is not None: 219 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 220 | raise ValueError( 221 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 222 | ) 223 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 224 | attn_weights = torch.max( 225 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) 226 | ) 227 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 228 | 229 | # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 230 | if attn_weights.dtype == torch.float16: 231 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) 232 | else: 233 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 234 | 235 | if layer_head_mask is not None: 236 | if layer_head_mask.size() != (self.num_heads,): 237 | raise ValueError( 238 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 239 | f" {layer_head_mask.size()}" 240 | ) 241 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 242 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 243 | 244 | if output_attentions: 245 | # this operation is a bit awkward, but it's required to 246 | # make sure that attn_weights keeps its gradient. 
247 | # In order to do so, attn_weights have to be reshaped 248 | # twice and have to be reused in the following 249 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 250 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 251 | else: 252 | attn_weights_reshaped = None 253 | 254 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 255 | 256 | attn_output = torch.bmm(attn_probs, value_states) 257 | 258 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 259 | raise ValueError( 260 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 261 | f" {attn_output.size()}" 262 | ) 263 | 264 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 265 | attn_output = attn_output.transpose(1, 2) 266 | 267 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 268 | # partitioned across GPUs when using tensor-parallelism. 269 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 270 | 271 | attn_output = self.out_proj(attn_output) 272 | 273 | return attn_output, attn_weights_reshaped, past_key_value 274 | 275 | 276 | class OPTDecoderLayer(nn.Module): 277 | def __init__(self, config: OPTConfig): 278 | super().__init__() 279 | self.embed_dim = config.hidden_size 280 | self.self_attn = OPTAttention( 281 | embed_dim=self.embed_dim, 282 | num_heads=config.num_attention_heads, 283 | dropout=config.attention_dropout, 284 | is_decoder=True, 285 | bias=config.enable_bias, 286 | ) 287 | self.do_layer_norm_before = config.do_layer_norm_before 288 | self.dropout = config.dropout 289 | self.activation_fn = ACT2FN[config.activation_function] 290 | 291 | self.self_attn_layer_norm = nn.LayerNorm( 292 | self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine 293 | ) 294 | self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) 295 | self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) 296 | self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) 297 | 298 | def forward( 299 | self, 300 | hidden_states: torch.Tensor, 301 | attention_mask: Optional[torch.Tensor] = None, 302 | layer_head_mask: Optional[torch.Tensor] = None, 303 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 304 | output_attentions: Optional[bool] = False, 305 | use_cache: Optional[bool] = False, 306 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 307 | """ 308 | Args: 309 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 310 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 311 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 312 | layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size 313 | `(encoder_attention_heads,)`. 314 | output_attentions (`bool`, *optional*): 315 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 316 | returned tensors for more detail. 317 | use_cache (`bool`, *optional*): 318 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 319 | (see `past_key_values`). 
320 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 321 | """ 322 | 323 | residual = hidden_states 324 | 325 | # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention 326 | if self.do_layer_norm_before: 327 | hidden_states = self.self_attn_layer_norm(hidden_states) 328 | 329 | # Self Attention 330 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 331 | hidden_states=hidden_states, 332 | past_key_value=past_key_value, 333 | attention_mask=attention_mask, 334 | layer_head_mask=layer_head_mask, 335 | output_attentions=output_attentions, 336 | ) 337 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 338 | hidden_states = residual + hidden_states 339 | 340 | # 350m applies layer norm AFTER attention 341 | if not self.do_layer_norm_before: 342 | hidden_states = self.self_attn_layer_norm(hidden_states) 343 | 344 | # Fully Connected 345 | hidden_states_shape = hidden_states.shape 346 | hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) 347 | residual = hidden_states 348 | 349 | # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention 350 | if self.do_layer_norm_before: 351 | hidden_states = self.final_layer_norm(hidden_states) 352 | 353 | hidden_states = self.fc1(hidden_states) 354 | hidden_states = self.activation_fn(hidden_states) 355 | 356 | hidden_states = self.fc2(hidden_states) 357 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 358 | 359 | hidden_states = (residual + hidden_states).view(hidden_states_shape) 360 | 361 | # 350m applies layer norm AFTER attention 362 | if not self.do_layer_norm_before: 363 | hidden_states = self.final_layer_norm(hidden_states) 364 | 365 | outputs = (hidden_states,) 366 | 367 | if output_attentions: 368 | outputs += (self_attn_weights,) 369 | 370 | if use_cache: 371 | outputs += (present_key_value,) 372 | 373 | return outputs 374 | 375 | 376 | OPT_START_DOCSTRING = r""" 377 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 378 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads 379 | etc.) 380 | 381 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 382 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage 383 | and behavior. 384 | 385 | Parameters: 386 | config ([`OPTConfig`]): 387 | Model configuration class with all the parameters of the model. Initializing with a config file does not 388 | load the weights associated with the model, only the configuration. Check out the 389 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
390 | """ 391 | 392 | 393 | @add_start_docstrings( 394 | "The bare OPT Model outputting raw hidden-states without any specific head on top.", 395 | OPT_START_DOCSTRING, 396 | ) 397 | class OPTPreTrainedModel(PreTrainedModel): 398 | config_class = OPTConfig 399 | base_model_prefix = "model" 400 | supports_gradient_checkpointing = True 401 | _no_split_modules = ["OPTDecoderLayer"] 402 | 403 | def _init_weights(self, module): 404 | std = self.config.init_std 405 | if isinstance(module, nn.Linear): 406 | module.weight.data.normal_(mean=0.0, std=std) 407 | if module.bias is not None: 408 | module.bias.data.zero_() 409 | elif isinstance(module, nn.Embedding): 410 | module.weight.data.normal_(mean=0.0, std=std) 411 | if module.padding_idx is not None: 412 | module.weight.data[module.padding_idx].zero_() 413 | 414 | def _set_gradient_checkpointing(self, module, value=False): 415 | if isinstance(module, (OPTDecoder)): 416 | module.gradient_checkpointing = value 417 | 418 | 419 | OPT_INPUTS_DOCSTRING = r""" 420 | Args: 421 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 422 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 423 | it. 424 | 425 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 426 | [`PreTrainedTokenizer.__call__`] for details. 427 | 428 | [What are input IDs?](../glossary#input-ids) 429 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 430 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 431 | 432 | - 1 for tokens that are **not masked**, 433 | - 0 for tokens that are **masked**. 434 | 435 | [What are attention masks?](../glossary#attention-mask) 436 | 437 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 438 | [`PreTrainedTokenizer.__call__`] for details. 439 | 440 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 441 | `past_key_values`). 442 | 443 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 444 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 445 | information on the default strategy. 446 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 447 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: 448 | 449 | - 1 indicates the head is **not masked**, 450 | - 0 indicates the head is **masked**. 451 | 452 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 453 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 454 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 455 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 456 | 457 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 458 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
459 | 460 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 461 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 462 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 463 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 464 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 465 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 466 | model's internal embedding lookup matrix. 467 | use_cache (`bool`, *optional*): 468 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 469 | `past_key_values`). 470 | output_attentions (`bool`, *optional*): 471 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 472 | tensors for more detail. 473 | output_hidden_states (`bool`, *optional*): 474 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 475 | more detail. 476 | return_dict (`bool`, *optional*): 477 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 478 | """ 479 | 480 | 481 | class OPTDecoder(OPTPreTrainedModel): 482 | """ 483 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] 484 | 485 | Args: 486 | config: OPTConfig 487 | """ 488 | 489 | def __init__(self, config: OPTConfig): 490 | super().__init__(config) 491 | self.dropout = config.dropout 492 | self.layerdrop = config.layerdrop 493 | self.padding_idx = config.pad_token_id 494 | self.max_target_positions = config.max_position_embeddings 495 | self.vocab_size = config.vocab_size 496 | 497 | self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) 498 | self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) 499 | 500 | if config.word_embed_proj_dim != config.hidden_size: 501 | self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) 502 | else: 503 | self.project_out = None 504 | 505 | if config.word_embed_proj_dim != config.hidden_size: 506 | self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) 507 | else: 508 | self.project_in = None 509 | 510 | # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility 511 | # with checkpoints that have been fine-tuned before transformers v4.20.1 512 | # see https://github.com/facebookresearch/metaseq/pull/164 513 | if config.do_layer_norm_before and not config._remove_final_layer_norm: 514 | self.final_layer_norm = nn.LayerNorm( 515 | config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine 516 | ) 517 | else: 518 | self.final_layer_norm = None 519 | 520 | self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 521 | 522 | self.gradient_checkpointing = False 523 | # Initialize weights and apply final processing 524 | self.post_init() 525 | 526 | def get_input_embeddings(self): 527 | return self.embed_tokens 528 | 529 | def set_input_embeddings(self, value): 530 | self.embed_tokens = value 531 | 532 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 533 | def 
_prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 534 | # create causal mask 535 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 536 | combined_attention_mask = None 537 | if input_shape[-1] > 1: 538 | combined_attention_mask = _make_causal_mask( 539 | input_shape, 540 | inputs_embeds.dtype, 541 | device=inputs_embeds.device, 542 | past_key_values_length=past_key_values_length, 543 | ) 544 | 545 | if attention_mask is not None: 546 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 547 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 548 | inputs_embeds.device 549 | ) 550 | combined_attention_mask = ( 551 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 552 | ) 553 | 554 | return combined_attention_mask 555 | 556 | def forward( 557 | self, 558 | input_ids: torch.LongTensor = None, 559 | attention_mask: Optional[torch.Tensor] = None, 560 | head_mask: Optional[torch.Tensor] = None, 561 | past_key_values: Optional[List[torch.FloatTensor]] = None, 562 | inputs_embeds: Optional[torch.FloatTensor] = None, 563 | use_cache: Optional[bool] = None, 564 | output_attentions: Optional[bool] = None, 565 | output_hidden_states: Optional[bool] = None, 566 | return_dict: Optional[bool] = None, 567 | ) -> Union[Tuple, BaseModelOutputWithPast]: 568 | r""" 569 | Args: 570 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 571 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 572 | provide it. 573 | 574 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 575 | [`PreTrainedTokenizer.__call__`] for details. 576 | 577 | [What are input IDs?](../glossary#input-ids) 578 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 579 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 580 | 581 | - 1 for tokens that are **not masked**, 582 | - 0 for tokens that are **masked**. 583 | 584 | [What are attention masks?](../glossary#attention-mask) 585 | head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): 586 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 587 | 588 | - 1 indicates the head is **not masked**, 589 | - 0 indicates the head is **masked**. 590 | 591 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 592 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 593 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 594 | 595 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 596 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 597 | 598 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 599 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 600 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
601 | 602 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 603 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 604 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 605 | than the model's internal embedding lookup matrix. 606 | output_attentions (`bool`, *optional*): 607 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 608 | returned tensors for more detail. 609 | output_hidden_states (`bool`, *optional*): 610 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 611 | for more detail. 612 | return_dict (`bool`, *optional*): 613 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 614 | """ 615 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 616 | output_hidden_states = ( 617 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 618 | ) 619 | use_cache = use_cache if use_cache is not None else self.config.use_cache 620 | 621 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 622 | 623 | # retrieve input_ids and inputs_embeds 624 | if input_ids is not None and inputs_embeds is not None: 625 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 626 | elif input_ids is not None: 627 | input_shape = input_ids.size() 628 | input_ids = input_ids.view(-1, input_shape[-1]) 629 | elif inputs_embeds is not None: 630 | input_shape = inputs_embeds.size()[:-1] 631 | else: 632 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 633 | 634 | if inputs_embeds is None: 635 | inputs_embeds = self.embed_tokens(input_ids) 636 | 637 | batch_size, seq_length = input_shape 638 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 639 | # required mask seq length can be calculated via length of past 640 | mask_seq_length = past_key_values_length + seq_length 641 | 642 | # embed positions 643 | if attention_mask is None: 644 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 645 | elif attention_mask.shape[1] != mask_seq_length: 646 | raise ValueError( 647 | f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " 648 | f"{mask_seq_length} (sum of the lengths of current and past inputs)" 649 | ) 650 | causal_attention_mask = self._prepare_decoder_attention_mask( 651 | attention_mask, input_shape, inputs_embeds, past_key_values_length 652 | ) 653 | pos_embeds = self.embed_positions(attention_mask, past_key_values_length) 654 | 655 | if self.project_in is not None: 656 | inputs_embeds = self.project_in(inputs_embeds) 657 | 658 | hidden_states = inputs_embeds + pos_embeds 659 | 660 | if self.gradient_checkpointing and self.training: 661 | if use_cache: 662 | logger.warning_once( 663 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
664 | ) 665 | use_cache = False 666 | 667 | # decoder layers 668 | all_hidden_states = () if output_hidden_states else None 669 | all_self_attns = () if output_attentions else None 670 | next_decoder_cache = () if use_cache else None 671 | 672 | # check if head_mask has a correct number of layers specified if desired 673 | for attn_mask, mask_name in zip([head_mask], ["head_mask"]): 674 | if attn_mask is not None: 675 | if attn_mask.size()[0] != (len(self.layers)): 676 | raise ValueError( 677 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" 678 | f" {head_mask.size()[0]}." 679 | ) 680 | 681 | for idx, decoder_layer in enumerate(self.layers): 682 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 683 | if output_hidden_states: 684 | all_hidden_states += (hidden_states,) 685 | 686 | if self.training: 687 | dropout_probability = torch.rand([]) 688 | if dropout_probability < self.layerdrop: 689 | continue 690 | 691 | past_key_value = past_key_values[idx] if past_key_values is not None else None 692 | 693 | if self.gradient_checkpointing and self.training: 694 | 695 | def create_custom_forward(module): 696 | def custom_forward(*inputs): 697 | # None for past_key_value 698 | return module(*inputs, output_attentions, None) 699 | 700 | return custom_forward 701 | 702 | layer_outputs = torch.utils.checkpoint.checkpoint( 703 | create_custom_forward(decoder_layer), 704 | hidden_states, 705 | causal_attention_mask, 706 | head_mask[idx] if head_mask is not None else None, 707 | None, 708 | ) 709 | else: 710 | layer_outputs = decoder_layer( 711 | hidden_states, 712 | attention_mask=causal_attention_mask, 713 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 714 | past_key_value=past_key_value, 715 | output_attentions=output_attentions, 716 | use_cache=use_cache, 717 | ) 718 | 719 | hidden_states = layer_outputs[0] 720 | 721 | if use_cache: 722 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 723 | 724 | if output_attentions: 725 | all_self_attns += (layer_outputs[1],) 726 | 727 | if self.final_layer_norm is not None: 728 | hidden_states = self.final_layer_norm(hidden_states) 729 | 730 | if self.project_out is not None: 731 | hidden_states = self.project_out(hidden_states) 732 | 733 | # add hidden states from the last decoder layer 734 | if output_hidden_states: 735 | all_hidden_states += (hidden_states,) 736 | 737 | next_cache = next_decoder_cache if use_cache else None 738 | if not return_dict: 739 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 740 | return BaseModelOutputWithPast( 741 | last_hidden_state=hidden_states, 742 | past_key_values=next_cache, 743 | hidden_states=all_hidden_states, 744 | attentions=all_self_attns, 745 | ) 746 | 747 | 748 | @add_start_docstrings( 749 | "The bare OPT Model outputting raw hidden-states without any specific head on top.", 750 | OPT_START_DOCSTRING, 751 | ) 752 | class OPTModel(OPTPreTrainedModel): 753 | def __init__(self, config: OPTConfig): 754 | super().__init__(config) 755 | self.decoder = OPTDecoder(config) 756 | # Initialize weights and apply final processing 757 | self.post_init() 758 | 759 | def get_input_embeddings(self): 760 | return self.decoder.embed_tokens 761 | 762 | def set_input_embeddings(self, value): 763 | self.decoder.embed_tokens = value 764 | 765 | def get_decoder(self): 766 | return self.decoder 767 | 768 | @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) 769 | 
@add_code_sample_docstrings( 770 | checkpoint=_CHECKPOINT_FOR_DOC, 771 | output_type=BaseModelOutputWithPast, 772 | config_class=_CONFIG_FOR_DOC, 773 | expected_output=_EXPECTED_OUTPUT_SHAPE, 774 | ) 775 | def forward( 776 | self, 777 | input_ids: torch.LongTensor = None, 778 | attention_mask: Optional[torch.Tensor] = None, 779 | head_mask: Optional[torch.Tensor] = None, 780 | past_key_values: Optional[List[torch.FloatTensor]] = None, 781 | inputs_embeds: Optional[torch.FloatTensor] = None, 782 | use_cache: Optional[bool] = None, 783 | output_attentions: Optional[bool] = None, 784 | output_hidden_states: Optional[bool] = None, 785 | return_dict: Optional[bool] = None, 786 | ) -> Union[Tuple, BaseModelOutputWithPast]: 787 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 788 | output_hidden_states = ( 789 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 790 | ) 791 | use_cache = use_cache if use_cache is not None else self.config.use_cache 792 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 793 | 794 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 795 | decoder_outputs = self.decoder( 796 | input_ids=input_ids, 797 | attention_mask=attention_mask, 798 | head_mask=head_mask, 799 | past_key_values=past_key_values, 800 | inputs_embeds=inputs_embeds, 801 | use_cache=use_cache, 802 | output_attentions=output_attentions, 803 | output_hidden_states=output_hidden_states, 804 | return_dict=return_dict, 805 | ) 806 | 807 | if not return_dict: 808 | return decoder_outputs 809 | 810 | return BaseModelOutputWithPast( 811 | last_hidden_state=decoder_outputs.last_hidden_state, 812 | past_key_values=decoder_outputs.past_key_values, 813 | hidden_states=decoder_outputs.hidden_states, 814 | attentions=decoder_outputs.attentions, 815 | ) 816 | 817 | 818 | class OPTForCausalLM(OPTPreTrainedModel): 819 | _tied_weights_keys = ["lm_head.weight"] 820 | 821 | def __init__(self, config): 822 | super().__init__(config) 823 | self.model = OPTModel(config) 824 | 825 | # the lm_head weight is automatically tied to the embed tokens weight 826 | self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) 827 | 828 | # Initialize weights and apply final processing 829 | self.post_init() 830 | 831 | def get_input_embeddings(self): 832 | return self.model.decoder.embed_tokens 833 | 834 | def set_input_embeddings(self, value): 835 | self.model.decoder.embed_tokens = value 836 | 837 | def get_output_embeddings(self): 838 | return self.lm_head 839 | 840 | def set_output_embeddings(self, new_embeddings): 841 | self.lm_head = new_embeddings 842 | 843 | def set_decoder(self, decoder): 844 | self.model.decoder = decoder 845 | 846 | def get_decoder(self): 847 | return self.model.decoder 848 | 849 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 850 | def forward( 851 | self, 852 | input_ids: torch.LongTensor = None, 853 | attention_mask: Optional[torch.Tensor] = None, 854 | head_mask: Optional[torch.Tensor] = None, 855 | past_key_values: Optional[List[torch.FloatTensor]] = None, 856 | inputs_embeds: Optional[torch.FloatTensor] = None, 857 | labels: Optional[torch.LongTensor] = None, 858 | use_cache: Optional[bool] = None, 859 | output_attentions: Optional[bool] = None, 860 | output_hidden_states: Optional[bool] = None, 861 | return_dict: Optional[bool] = None, 862 | ) -> Union[Tuple, 
CausalLMOutputWithPast]: 863 | r""" 864 | Args: 865 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 866 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 867 | provide it. 868 | 869 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 870 | [`PreTrainedTokenizer.__call__`] for details. 871 | 872 | [What are input IDs?](../glossary#input-ids) 873 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 874 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 875 | 876 | - 1 for tokens that are **not masked**, 877 | - 0 for tokens that are **masked**. 878 | 879 | [What are attention masks?](../glossary#attention-mask) 880 | head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): 881 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 882 | 883 | - 1 indicates the head is **not masked**, 884 | - 0 indicates the head is **masked**. 885 | 886 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 887 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 888 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 889 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional 890 | tensors are only required when the model is used as a decoder in a Sequence to Sequence model. 891 | 892 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 893 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 894 | 895 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 896 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 897 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 898 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 899 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 900 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 901 | than the model's internal embedding lookup matrix. 902 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 903 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 904 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 905 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 906 | use_cache (`bool`, *optional*): 907 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 908 | (see `past_key_values`). 909 | output_attentions (`bool`, *optional*): 910 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 911 | returned tensors for more detail. 912 | output_hidden_states (`bool`, *optional*): 913 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 914 | for more detail. 
915 | return_dict (`bool`, *optional*): 916 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 917 | 918 | Returns: 919 | 920 | Example: 921 | 922 | ```python 923 | >>> from transformers import AutoTokenizer, OPTForCausalLM 924 | 925 | >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") 926 | >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") 927 | 928 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 929 | >>> inputs = tokenizer(prompt, return_tensors="pt") 930 | 931 | >>> # Generate 932 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 933 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 934 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." 935 | ```""" 936 | 937 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 938 | output_hidden_states = ( 939 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 940 | ) 941 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 942 | 943 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 944 | outputs = self.model.decoder( 945 | input_ids=input_ids, 946 | attention_mask=attention_mask, 947 | head_mask=head_mask, 948 | past_key_values=past_key_values, 949 | inputs_embeds=inputs_embeds, 950 | use_cache=use_cache, 951 | output_attentions=output_attentions, 952 | output_hidden_states=output_hidden_states, 953 | return_dict=return_dict, 954 | ) 955 | 956 | logits = self.lm_head(outputs[0]).contiguous() 957 | 958 | loss = None 959 | if labels is not None: 960 | # move labels to correct device to enable model parallelism 961 | labels = labels.to(logits.device) 962 | # Shift so that tokens < n predict n 963 | shift_logits = logits[..., :-1, :].contiguous() 964 | shift_labels = labels[..., 1:].contiguous() 965 | # Flatten the tokens 966 | loss_fct = CrossEntropyLoss() 967 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 968 | 969 | if not return_dict: 970 | output = (logits,) + outputs[1:] 971 | return (loss,) + output if loss is not None else output 972 | 973 | return CausalLMOutputWithPast( 974 | loss=loss, 975 | logits=logits, 976 | past_key_values=outputs.past_key_values, 977 | hidden_states=outputs.hidden_states, 978 | attentions=outputs.attentions, 979 | ) 980 | 981 | def prepare_inputs_for_generation( 982 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 983 | ): 984 | if past_key_values: 985 | input_ids = input_ids[:, -1:] 986 | 987 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 988 | if inputs_embeds is not None and past_key_values is None: 989 | model_inputs = {"inputs_embeds": inputs_embeds} 990 | else: 991 | model_inputs = {"input_ids": input_ids} 992 | 993 | model_inputs.update( 994 | { 995 | "past_key_values": past_key_values, 996 | "use_cache": kwargs.get("use_cache"), 997 | "attention_mask": attention_mask, 998 | } 999 | ) 1000 | return model_inputs 1001 | 1002 | @staticmethod 1003 | def _reorder_cache(past_key_values, beam_idx): 1004 | reordered_past = () 1005 | for layer_past in past_key_values: 1006 | reordered_past += ( 1007 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 1008 | ) 1009 | return reordered_past 1010 | 
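# --- Editorial note (added in this dump; not part of the OPT model code above): ---
# `prepare_inputs_for_generation` keeps only the last token of `input_ids` once `past_key_values`
# is populated, since earlier positions are already covered by the cached keys/values.
# `_reorder_cache` is the beam-search hook: `beam_idx` lists the surviving beams, and
# `index_select(0, beam_idx)` reorders each cached key/value tensor along its leading
# (batch * beam) dimension so the cache stays aligned with the reordered hypotheses.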
1011 | 1012 | @add_start_docstrings( 1013 | """ 1014 | The OPT Model transformer with a sequence classification head on top (linear layer). 1015 | 1016 | [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1017 | (e.g. GPT-2) do. 1018 | 1019 | Since it does classification on the last token, it requires to know the position of the last token. If a 1020 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1021 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1022 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1023 | each row of the batch). 1024 | """, 1025 | OPT_START_DOCSTRING, 1026 | ) 1027 | class OPTForSequenceClassification(OPTPreTrainedModel): 1028 | def __init__(self, config: OPTConfig): 1029 | super().__init__(config) 1030 | self.num_labels = config.num_labels 1031 | self.model = OPTModel(config) 1032 | self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) 1033 | 1034 | # Initialize weights and apply final processing 1035 | self.post_init() 1036 | 1037 | @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) 1038 | @add_code_sample_docstrings( 1039 | checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, 1040 | output_type=SequenceClassifierOutputWithPast, 1041 | config_class=_CONFIG_FOR_DOC, 1042 | expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, 1043 | expected_loss=_SEQ_CLASS_EXPECTED_LOSS, 1044 | ) 1045 | def forward( 1046 | self, 1047 | input_ids: Optional[torch.LongTensor] = None, 1048 | attention_mask: Optional[torch.FloatTensor] = None, 1049 | head_mask: Optional[torch.FloatTensor] = None, 1050 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1051 | inputs_embeds: Optional[torch.FloatTensor] = None, 1052 | labels: Optional[torch.LongTensor] = None, 1053 | use_cache: Optional[bool] = None, 1054 | output_attentions: Optional[bool] = None, 1055 | output_hidden_states: Optional[bool] = None, 1056 | return_dict: Optional[bool] = None, 1057 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1058 | r""" 1059 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1060 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1061 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1062 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1063 | """ 1064 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1065 | 1066 | transformer_outputs = self.model( 1067 | input_ids, 1068 | past_key_values=past_key_values, 1069 | attention_mask=attention_mask, 1070 | head_mask=head_mask, 1071 | inputs_embeds=inputs_embeds, 1072 | use_cache=use_cache, 1073 | output_attentions=output_attentions, 1074 | output_hidden_states=output_hidden_states, 1075 | return_dict=return_dict, 1076 | ) 1077 | hidden_states = transformer_outputs[0] 1078 | logits = self.score(hidden_states) 1079 | 1080 | if input_ids is not None: 1081 | batch_size, sequence_length = input_ids.shape[:2] 1082 | else: 1083 | batch_size, sequence_length = inputs_embeds.shape[:2] 1084 | 1085 | if self.config.pad_token_id is None: 1086 | sequence_lengths = -1 1087 | else: 1088 | if input_ids is not None: 1089 | sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( 1090 | logits.device 1091 | ) 1092 | else: 1093 | sequence_lengths = -1 1094 | logger.warning( 1095 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " 1096 | "unexpected if using padding tokens in conjunction with `inputs_embeds.`" 1097 | ) 1098 | 1099 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 1100 | 1101 | loss = None 1102 | if labels is not None: 1103 | if self.config.problem_type is None: 1104 | if self.num_labels == 1: 1105 | self.config.problem_type = "regression" 1106 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1107 | self.config.problem_type = "single_label_classification" 1108 | else: 1109 | self.config.problem_type = "multi_label_classification" 1110 | 1111 | if self.config.problem_type == "regression": 1112 | loss_fct = MSELoss() 1113 | if self.num_labels == 1: 1114 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1115 | else: 1116 | loss = loss_fct(pooled_logits, labels) 1117 | elif self.config.problem_type == "single_label_classification": 1118 | loss_fct = CrossEntropyLoss() 1119 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1120 | elif self.config.problem_type == "multi_label_classification": 1121 | loss_fct = BCEWithLogitsLoss() 1122 | loss = loss_fct(pooled_logits, labels) 1123 | if not return_dict: 1124 | output = (pooled_logits,) + transformer_outputs[1:] 1125 | return ((loss,) + output) if loss is not None else output 1126 | 1127 | return SequenceClassifierOutputWithPast( 1128 | loss=loss, 1129 | logits=pooled_logits, 1130 | past_key_values=transformer_outputs.past_key_values, 1131 | hidden_states=transformer_outputs.hidden_states, 1132 | attentions=transformer_outputs.attentions, 1133 | ) 1134 | 1135 | def get_input_embeddings(self): 1136 | return self.model.decoder.embed_tokens 1137 | 1138 | def set_input_embeddings(self, value): 1139 | self.model.decoder.embed_tokens = value 1140 | 1141 | 1142 | @add_start_docstrings( 1143 | """ 1144 | The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD 1145 | (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
1146 | """, 1147 | OPT_START_DOCSTRING, 1148 | ) 1149 | class OPTForQuestionAnswering(OPTPreTrainedModel): 1150 | def __init__(self, config: OPTConfig): 1151 | super().__init__(config) 1152 | self.model = OPTModel(config) 1153 | self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) 1154 | 1155 | # Initialize weights and apply final processing 1156 | self.post_init() 1157 | 1158 | @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) 1159 | @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) 1160 | def forward( 1161 | self, 1162 | input_ids: Optional[torch.LongTensor] = None, 1163 | attention_mask: Optional[torch.FloatTensor] = None, 1164 | head_mask: Optional[torch.FloatTensor] = None, 1165 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 1166 | inputs_embeds: Optional[torch.FloatTensor] = None, 1167 | start_positions: Optional[torch.LongTensor] = None, 1168 | end_positions: Optional[torch.LongTensor] = None, 1169 | use_cache: Optional[bool] = None, 1170 | output_attentions: Optional[bool] = None, 1171 | output_hidden_states: Optional[bool] = None, 1172 | return_dict: Optional[bool] = None, 1173 | ) -> Union[Tuple, QuestionAnsweringModelOutput]: 1174 | r""" 1175 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1176 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1177 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1178 | are not taken into account for computing the loss. 1179 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1180 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1181 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1182 | are not taken into account for computing the loss. 1183 | 1184 | Returns: 1185 | 1186 | Example: 1187 | 1188 | ```python 1189 | >>> from transformers import AutoTokenizer, OPTForQuestionAnswering 1190 | >>> import torch 1191 | 1192 | >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT 1193 | >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") 1194 | 1195 | >>> # note: we are loading a OPTForQuestionAnswering from the hub here, 1196 | >>> # so the head will be randomly initialized, hence the predictions will be random 1197 | >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") 1198 | 1199 | >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" 1200 | 1201 | >>> inputs = tokenizer(question, text, return_tensors="pt") 1202 | >>> with torch.no_grad(): 1203 | ... outputs = model(**inputs) 1204 | 1205 | >>> answer_start_index = outputs.start_logits.argmax() 1206 | >>> answer_end_index = outputs.end_logits.argmax() 1207 | 1208 | >>> answer_offset = len(tokenizer(question)[0]) 1209 | 1210 | >>> predict_answer_tokens = inputs.input_ids[ 1211 | ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 1212 | ... 
] 1213 | >>> predicted = tokenizer.decode(predict_answer_tokens) 1214 | >>> predicted 1215 | ' a nice puppet' 1216 | ```""" 1217 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1218 | 1219 | transformer_outputs = self.model( 1220 | input_ids, 1221 | past_key_values=past_key_values, 1222 | attention_mask=attention_mask, 1223 | head_mask=head_mask, 1224 | inputs_embeds=inputs_embeds, 1225 | use_cache=use_cache, 1226 | output_attentions=output_attentions, 1227 | output_hidden_states=output_hidden_states, 1228 | return_dict=return_dict, 1229 | ) 1230 | hidden_states = transformer_outputs[0] 1231 | 1232 | logits = self.qa_outputs(hidden_states) 1233 | start_logits, end_logits = logits.split(1, dim=-1) 1234 | start_logits = start_logits.squeeze(-1).contiguous() 1235 | end_logits = end_logits.squeeze(-1).contiguous() 1236 | 1237 | total_loss = None 1238 | if start_positions is not None and end_positions is not None: 1239 | # If we are on multi-GPU, split add a dimension 1240 | if len(start_positions.size()) > 1: 1241 | start_positions = start_positions.squeeze(-1) 1242 | if len(end_positions.size()) > 1: 1243 | end_positions = end_positions.squeeze(-1) 1244 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1245 | ignored_index = start_logits.size(1) 1246 | start_positions = start_positions.clamp(0, ignored_index) 1247 | end_positions = end_positions.clamp(0, ignored_index) 1248 | 1249 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1250 | start_loss = loss_fct(start_logits, start_positions) 1251 | end_loss = loss_fct(end_logits, end_positions) 1252 | total_loss = (start_loss + end_loss) / 2 1253 | 1254 | if not return_dict: 1255 | output = (start_logits, end_logits) + transformer_outputs[2:] 1256 | return ((total_loss,) + output) if total_loss is not None else output 1257 | 1258 | return QuestionAnsweringModelOutput( 1259 | loss=total_loss, 1260 | start_logits=start_logits, 1261 | end_logits=end_logits, 1262 | hidden_states=transformer_outputs.hidden_states, 1263 | attentions=transformer_outputs.attentions, 1264 | ) 1265 | 1266 | def get_input_embeddings(self): 1267 | return self.model.decoder.embed_tokens 1268 | 1269 | def set_input_embeddings(self, value): 1270 | self.model.decoder.embed_tokens = value 1271 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision==0.15.2 3 | torchaudio==2.0.2 4 | datasets==2.14.7 5 | huggingface-hub==0.17.3 6 | accelerate>=0.12.0 7 | sentencepiece 8 | zstandard 9 | -------------------------------------------------------------------------------- /scripts/datautils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from datasets import load_dataset 5 | from transformers import GPT2Tokenizer, LlamaTokenizer 6 | 7 | 8 | def set_seed(seed): 9 | np.random.seed(seed) 10 | torch.random.manual_seed(seed) 11 | random.seed(seed) 12 | 13 | def get_wikitext2(seqlen, model): 14 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 15 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 16 | 17 | if 'llama' in model: 18 | tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False) 19 | else: 20 | tokenizer = GPT2Tokenizer.from_pretrained(model, use_fast=False) 21 | testenc = 
tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 22 | 23 | return testenc 24 | 25 | def get_ptb(seqlen, model): 26 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 27 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 28 | 29 | if 'llama' in model: 30 | tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False) 31 | else: 32 | tokenizer = GPT2Tokenizer.from_pretrained(model, use_fast=False) 33 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 34 | 35 | return testenc 36 | 37 | def get_loaders( 38 | name, seed=0, seqlen=2048, model='', it=0, c4_num=0 39 | ): 40 | set_seed(seed) 41 | if 'wikitext2' in name: 42 | return get_wikitext2(seqlen, model) 43 | if 'ptb' in name: 44 | return get_ptb(seqlen, model) 45 | -------------------------------------------------------------------------------- /scripts/llama.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | import argparse 6 | from datautils import * 7 | 8 | 9 | def get_llama_base(model, seqlen): 10 | 11 | def skip(*args, **kwargs): 12 | pass 13 | 14 | torch.nn.init.kaiming_uniform_ = skip 15 | torch.nn.init.uniform_ = skip 16 | torch.nn.init.normal_ = skip 17 | from transformers import LlamaForCausalLM 18 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map='cpu') 19 | 20 | model.seqlen = seqlen 21 | return model 22 | 23 | 24 | @torch.no_grad() 25 | def llama_eval(model, testenc, eval_samples): 26 | print('Evaluating ', end='') 27 | 28 | dev = torch.device('cuda:0') 29 | testenc = testenc.input_ids 30 | if eval_samples: 31 | nsamples = eval_samples 32 | else: 33 | nsamples = min(1000, testenc.numel() // model.seqlen) 34 | print("nsamples: ", nsamples) 35 | 36 | use_cache = model.config.use_cache 37 | model.config.use_cache = False 38 | layers = model.model.layers 39 | 40 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 41 | layers[0] = layers[0].to(dev) 42 | 43 | dtype = next(iter(model.parameters())).dtype 44 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 45 | cache = {'i': 0, 'attention_mask': None} 46 | 47 | class Catcher(nn.Module): 48 | 49 | def __init__(self, module): 50 | super().__init__() 51 | self.module = module 52 | 53 | def forward(self, inp, **kwargs): 54 | inps[cache['i']] = inp 55 | cache['i'] += 1 56 | cache['attention_mask'] = kwargs['attention_mask'] 57 | cache['position_ids'] = kwargs['position_ids'] 58 | raise ValueError 59 | 60 | layers[0] = Catcher(layers[0]) 61 | for i in range(nsamples): 62 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 63 | try: 64 | model(batch) 65 | except ValueError: 66 | pass 67 | layers[0] = layers[0].module 68 | 69 | layers[0] = layers[0].cpu() 70 | model.model.embed_tokens = model.model.embed_tokens.cpu() 71 | torch.cuda.empty_cache() 72 | 73 | outs = torch.zeros_like(inps) 74 | attention_mask = cache['attention_mask'] 75 | position_ids = cache['position_ids'] 76 | 77 | for i in range(len(layers)): 78 | layer = layers[i].to(dev) 79 | 80 | for j in range(nsamples): 81 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 82 | layers[i] = layer.cpu() 83 | del layer 84 | torch.cuda.empty_cache() 85 | inps, outs = outs, inps 86 | print(i, end=' ', flush=True) 87 | print() 88 | 89 | if model.model.norm is not None: 90 | 
model.model.norm = model.model.norm.to(dev) 91 | model.lm_head = model.lm_head.to(dev) 92 | 93 | testenc = testenc.to(dev) 94 | nlls = [] 95 | for i in range(nsamples): 96 | hidden_states = inps[i].unsqueeze(0) 97 | if model.model.norm is not None: 98 | hidden_states = model.model.norm(hidden_states) 99 | lm_logits = model.lm_head(hidden_states) 100 | shift_logits = lm_logits[:, :-1, :].contiguous() 101 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:] 102 | loss_fct = nn.CrossEntropyLoss() 103 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 104 | neg_log_likelihood = loss.float() * model.seqlen 105 | nlls.append(neg_log_likelihood) 106 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 107 | print(ppl.item()) 108 | print("perplexity: %f"%ppl.item()) 109 | 110 | model.config.use_cache = use_cache 111 | 112 | if __name__ == '__main__': 113 | 114 | parser = argparse.ArgumentParser() 115 | 116 | parser.add_argument( 117 | '--model', type=str, 118 | help='Llama model to load' 119 | ) 120 | parser.add_argument( 121 | '--eval_dataset', type=str, 122 | help='evaluation dataset' 123 | ) 124 | parser.add_argument( 125 | '--seq_len', type=int, 126 | help='model sequence length' 127 | ) 128 | parser.add_argument( 129 | '--eval_samples', type=int, default=0, 130 | help='number of sample evaluation dataset' 131 | ) 132 | parser.add_argument('--seed', type=int, default=0, 133 | help='Random seed for data load' 134 | ) 135 | parser.add_argument( 136 | '--q_bits', type=int, default=0, 137 | help='Number of bits for quantization' 138 | ) 139 | parser.add_argument( 140 | '--decomp_factor', type=int, default=0, 141 | help='Number of groups for classification' 142 | ) 143 | parser.add_argument( 144 | '--chunk_size', type=int, default=256, 145 | help='Size of row chunk' 146 | ) 147 | parser.add_argument( 148 | '--quant_mha', action='store_true', 149 | help='Whether to quantize multi-head-attention' 150 | ) 151 | parser.add_argument( 152 | '--scale_factor', type=str, default="", 153 | help='path to scale factor learned from calibration data.' 154 | ) 155 | parser.add_argument( 156 | '--bias', type=str, default="", 157 | help='path to bias learned from calibration data.' 158 | ) 159 | 160 | args = parser.parse_args() 161 | 162 | model = get_llama_base(args.model, args.seq_len) 163 | model.eval() 164 | 165 | if args.scale_factor: 166 | scale_factor = torch.load(args.scale_factor) 167 | for layer in model.model.layers: 168 | attn = layer.self_attn 169 | mlp = layer.mlp 170 | prefix = "model.layers." 
+ str(attn.layer_idx) 171 | 172 | name = prefix + ".self_attn" + "h_tmax" 173 | attn.h_tmax = scale_factor[name] 174 | name = prefix + ".self_attn" + "h_cmax" 175 | attn.h_group_index = scale_factor[name] 176 | name = prefix + ".self_attn" + "o_tmax" 177 | attn.o_tmax = scale_factor[name] 178 | name = prefix + ".self_attn" + "o_cmax" 179 | attn.o_group_index = scale_factor[name] 180 | 181 | if args.quant_mha: 182 | name = prefix + ".self_attn" + "q_tmax" 183 | attn.q_tmax = scale_factor[name] 184 | name = prefix + ".self_attn" + "q_cmax" 185 | attn.q_group_index = scale_factor[name] 186 | name = prefix + ".self_attn" + "s_tmax" 187 | attn.s_tmax = scale_factor[name] 188 | name = prefix + ".self_attn" + "s_cmax" 189 | attn.s_group_index = scale_factor[name] 190 | 191 | name = prefix + ".self_attn" + "k_scale" 192 | attn.k_scale = scale_factor[name] 193 | name = prefix + ".self_attn" + "v_scale" 194 | attn.v_scale = scale_factor[name] 195 | 196 | name = prefix + "fc1_tmax" 197 | mlp.fc1_tmax = scale_factor[name] 198 | name = prefix + "fc1_cmax" 199 | mlp.fc1_group_index = scale_factor[name] 200 | name = prefix + "fc2_tmax" 201 | mlp.fc2_tmax = scale_factor[name] 202 | name = prefix + "fc2_cmax" 203 | mlp.fc2_group_index = scale_factor[name] 204 | 205 | if args.bias: 206 | bias = torch.load(args.bias) 207 | for layer in model.model.layers: 208 | attn = layer.self_attn 209 | mlp = layer.mlp 210 | prefix = "model.layers." + str(attn.layer_idx) 211 | 212 | name = prefix + ".self_attn" + "h_ch_bias" 213 | attn.h_ch_bias = bias[name] 214 | 215 | if args.quant_mha: 216 | name = prefix + ".self_attn" + "q_ch_bias" 217 | attn.q_ch_bias = bias[name] 218 | name = prefix + ".self_attn" + "k_zero_bias" 219 | attn.k_ch_bias = bias[name] 220 | 221 | name = prefix + "h_ch_bias" 222 | mlp.h_ch_bias = bias[name] 223 | 224 | for layer in model.model.layers: 225 | layer.self_attn.quant_mha = args.quant_mha 226 | 227 | layer.self_attn.q_bits = args.q_bits 228 | layer.mlp.q_bits = args.q_bits 229 | 230 | layer.self_attn.decomp_factor = args.decomp_factor 231 | layer.mlp.decomp_factor = args.decomp_factor 232 | 233 | layer.self_attn.chunk_size = args.chunk_size 234 | layer.mlp.chunk_size = args.chunk_size 235 | 236 | testloader= get_loaders( 237 | args.eval_dataset, seed = args.seed, model = args.model, seqlen = model.seqlen 238 | ) 239 | 240 | llama_eval(model, testloader, args.eval_samples) 241 | 242 | -------------------------------------------------------------------------------- /scripts/opt.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | import argparse 6 | from datautils import * 7 | 8 | 9 | def get_opt_base(model, seq_len): 10 | def skip(*args, **kwargs): 11 | pass 12 | torch.nn.init.kaiming_uniform_ = skip 13 | torch.nn.init.uniform_ = skip 14 | torch.nn.init.normal_ = skip 15 | from transformers import OPTForCausalLM 16 | model = OPTForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map='cpu') 17 | 18 | model.seqlen = seq_len 19 | if model.config.max_position_embeddings < seq_len: 20 | print(f"Warning: Given seqlen {model.seqlen} is larger than max length {model.config.max_position_embeddings}") 21 | 22 | return model 23 | 24 | @torch.no_grad() 25 | def opt_eval(model, testenc, eval_samples): 26 | print('Evaluating ', end='') 27 | 28 | dev = torch.device('cuda:0') 29 | testenc = testenc.input_ids 30 | if eval_samples: 31 | nsamples = eval_samples 32 | else: 33 | nsamples = min(1000, 
testenc.numel() // model.seqlen) 34 | print("nsamples: ", nsamples) 35 | 36 | use_cache = model.config.use_cache 37 | model.config.use_cache = False 38 | layers = model.model.decoder.layers 39 | 40 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 41 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 42 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 43 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 44 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 45 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 46 | layers[0] = layers[0].to(dev) 47 | 48 | dtype = next(iter(model.parameters())).dtype 49 | inps = torch.zeros( 50 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 51 | ) 52 | cache = {'i': 0, 'attention_mask': None} 53 | 54 | class Catcher(nn.Module): 55 | def __init__(self, module): 56 | super().__init__() 57 | self.module = module 58 | def forward(self, inp, **kwargs): 59 | inps[cache['i']] = inp 60 | cache['i'] += 1 61 | cache['attention_mask'] = kwargs['attention_mask'] 62 | raise ValueError 63 | layers[0] = Catcher(layers[0]) 64 | for i in range(nsamples): 65 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 66 | try: 67 | model(batch) 68 | except ValueError: 69 | pass 70 | layers[0] = layers[0].module 71 | 72 | layers[0] = layers[0].cpu() 73 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 74 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 75 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 76 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 77 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 78 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 79 | torch.cuda.empty_cache() 80 | 81 | outs = torch.zeros_like(inps) 82 | attention_mask = cache['attention_mask'] 83 | 84 | for i in range(len(layers)): 85 | layer = layers[i].to(dev) 86 | 87 | for j in range(nsamples): 88 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 89 | layers[i] = layer.cpu() 90 | del layer 91 | torch.cuda.empty_cache() 92 | inps, outs = outs, inps 93 | print(i, end=' ',flush=True) 94 | print() 95 | 96 | if model.model.decoder.final_layer_norm is not None: 97 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 98 | if model.model.decoder.project_out is not None: 99 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 100 | model.lm_head = model.lm_head.to(dev) 101 | 102 | testenc = testenc.to(dev) 103 | nlls = [] 104 | for i in range(nsamples): 105 | hidden_states = inps[i].unsqueeze(0) 106 | if model.model.decoder.final_layer_norm is not None: 107 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 108 | if model.model.decoder.project_out is not None: 109 | hidden_states = model.model.decoder.project_out(hidden_states) 110 | lm_logits = model.lm_head(hidden_states) 111 | shift_logits = lm_logits[:, :-1, :].contiguous() 112 | shift_labels = testenc[ 113 | :, (i * model.seqlen):((i + 1) * model.seqlen) 114 | ][:, 1:] 115 | loss_fct = nn.CrossEntropyLoss() 116 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 117 | neg_log_likelihood = loss.float() * model.seqlen 118 | 
nlls.append(neg_log_likelihood) 119 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 120 | print("Perplexity: ", ppl.item(), flush=True) 121 | 122 | model.config.use_cache = use_cache 123 | 124 | 125 | if __name__ == '__main__': 126 | 127 | parser = argparse.ArgumentParser() 128 | 129 | parser.add_argument( 130 | '--model', type=str, 131 | help='OPT model to load; pass `facebook/opt-X`.' 132 | ) 133 | parser.add_argument( 134 | '--eval_dataset', type=str, 135 | help='evaluation dataset' 136 | ) 137 | parser.add_argument( 138 | '--seq_len', type=int, 139 | help='model sequence length' 140 | ) 141 | parser.add_argument( 142 | '--eval_samples', type=int, default=0, 143 | help='number of sample evaluation dataset' 144 | ) 145 | parser.add_argument('--seed', type=int, default=0, 146 | help='Random seed for data load' 147 | ) 148 | parser.add_argument( 149 | '--q_bits', type=int, default=0, 150 | help='Number of bits for quantization' 151 | ) 152 | parser.add_argument( 153 | '--decomp_factor', type=int, default=0, 154 | help='Number of channel groups' 155 | ) 156 | parser.add_argument( 157 | '--chunk_size', type=int, default=256, 158 | help='Size of row chunk' 159 | ) 160 | parser.add_argument( 161 | '--quant_mha', action='store_true', 162 | help='Whether to quantize multi-head-attention' 163 | ) 164 | parser.add_argument( 165 | '--scale_factor', type=str, default="", 166 | help='path to scale factor learned from calibration data.' 167 | ) 168 | parser.add_argument( 169 | '--bias', type=str, default="", 170 | help='path to bias learned from calibration data.' 171 | ) 172 | 173 | args = parser.parse_args() 174 | model = get_opt_base(args.model, args.seq_len) 175 | 176 | if args.scale_factor: 177 | scale_factor = torch.load(args.scale_factor) 178 | for layer in model.model.decoder.layers: 179 | attn = layer.self_attn 180 | prefix = "model.decoder.layers." + str(attn.layer_idx) 181 | 182 | name = prefix + ".self_attn" + "h_tmax" 183 | attn.h_tmax = scale_factor[name] 184 | name = prefix + ".self_attn" + "h_cmax" 185 | attn.h_group_index = scale_factor[name] 186 | name = prefix + ".self_attn" + "o_tmax" 187 | attn.o_tmax = scale_factor[name] 188 | name = prefix + ".self_attn" + "o_cmax" 189 | attn.o_group_index = scale_factor[name] 190 | 191 | if args.quant_mha: 192 | name = prefix + ".self_attn" + "q_tmax" 193 | attn.q_tmax = scale_factor[name] 194 | name = prefix + ".self_attn" + "q_cmax" 195 | attn.q_group_index = scale_factor[name] 196 | name = prefix + ".self_attn" + "s_tmax" 197 | attn.s_tmax = scale_factor[name] 198 | name = prefix + ".self_attn" + "s_cmax" 199 | attn.s_group_index = scale_factor[name] 200 | 201 | name = prefix + ".self_attn" + "k_scale" 202 | attn.k_scale = scale_factor[name] 203 | name = prefix + ".self_attn" + "v_scale" 204 | attn.v_scale = scale_factor[name] 205 | 206 | name = prefix + "fc1_tmax" 207 | layer.fc1_tmax = scale_factor[name] 208 | name = prefix + "fc1_cmax" 209 | layer.fc1_group_index = scale_factor[name] 210 | name = prefix + "fc2_tmax" 211 | layer.fc2_tmax = scale_factor[name] 212 | name = prefix + "fc2_cmax" 213 | layer.fc2_group_index = scale_factor[name] 214 | 215 | if args.bias: 216 | bias = torch.load(args.bias) 217 | for layer in model.model.decoder.layers: 218 | attn = layer.self_attn 219 | prefix = "model.decoder.layers." 
+ str(attn.layer_idx) 220 | 221 | name = prefix + ".self_attn" + "h_ch_bias" 222 | attn.h_ch_bias = bias[name] 223 | name = prefix + ".self_attn" + "o_ch_bias" 224 | attn.o_ch_bias = bias[name] 225 | 226 | if args.quant_mha: 227 | name = prefix + ".self_attn" + "q_ch_bias" 228 | attn.q_ch_bias = bias[name] 229 | name = prefix + ".self_attn" + "k_ch_bias" 230 | attn.k_ch_bias = bias[name] 231 | 232 | name = prefix + "h_ch_bias" 233 | layer.h_ch_bias = bias[name] 234 | 235 | model.eval() 236 | 237 | for layer in model.model.decoder.layers: 238 | layer.self_attn.quant_mha = args.quant_mha 239 | 240 | layer.self_attn.q_bits = args.q_bits 241 | layer.q_bits = args.q_bits 242 | 243 | layer.self_attn.decomp_factor = args.decomp_factor 244 | layer.decomp_factor = args.decomp_factor 245 | 246 | layer.self_attn.chunk_size = args.chunk_size 247 | layer.chunk_size = args.chunk_size 248 | 249 | testloader = get_loaders( 250 | args.eval_dataset, seed = args.seed, seqlen = model.seqlen, model = args.model 251 | ) 252 | 253 | opt_eval(model, testloader, args.eval_samples) 254 | -------------------------------------------------------------------------------- /scripts/table_2.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import os 3 | 4 | # If you set EVAL_SAMPLES 0, evaluate over full dataset. (no sampling) 5 | EVAL_SAMPLES = 0 6 | llama_2_dir = os.environ['LLAMA2_PATH'] 7 | 8 | print('='*10 + ' OPT baseline ' + '='*10) 9 | set_symlink_opt('modeling_opt_orig.py') 10 | for SIZE in ['6.7b', '13b', '66b']: 11 | for SEQLEN in [2048]: 12 | for DATASET in ["wikitext2", 'ptb']: 13 | cmd = "CUDA_VISIBLE_DEVICES=0 python opt.py " 14 | cmd += "--model facebook/opt-%s "%(SIZE) 15 | cmd += "--eval_dataset %s "%(DATASET) 16 | cmd += "--seq_len %d "%(SEQLEN) 17 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 18 | print(cmd) 19 | os.system(cmd) 20 | print("-------------------------------------------") 21 | 22 | print('='*10 + ' OPT Tender-INT4 ' + '='*10) 23 | set_symlink_opt('modeling_opt_tender.py') 24 | for SIZE in ['6.7b', '13b', '66b']: 25 | for SEQLEN in [2048]: 26 | for DATASET in ["wikitext2", 'ptb']: 27 | for BITS in [4, 8]: 28 | DECOMP = opt_decomp_params(SIZE, BITS) 29 | cmd = "CUDA_VISIBLE_DEVICES=0 python opt.py " 30 | cmd += "--model facebook/opt-%s "%(SIZE) 31 | cmd += "--eval_dataset %s "%(DATASET) 32 | cmd += "--seq_len %d "%(SEQLEN) 33 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 34 | cmd += "--q_bits %d "%(BITS) 35 | cmd += "--decomp_factor %d "%(DECOMP) 36 | cmd += "--chunk_size %d "%(256) 37 | cmd += "--scale_factor %s "%(f"../calibration/opt/scale/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 38 | cmd += "--bias %s "%(f"../calibration/opt/bias/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 39 | print(cmd) 40 | os.system(cmd) 41 | print("-------------------------------------------") 42 | 43 | print('='*10 + ' Llama-2 baseline ' + '='*10) 44 | set_symlink_llama('modeling_llama_orig.py') 45 | for SIZE in ['7b', '13b', '70b']: 46 | for SEQLEN in [2048]: 47 | for DATASET in ["wikitext2", "ptb"]: 48 | cmd = "CUDA_VISIBLE_DEVICES=0 python llama.py " 49 | cmd += "--model %s/llama-2-%s "%(llama_2_dir, SIZE) 50 | cmd += "--eval_dataset %s "%(DATASET) 51 | cmd += "--seq_len %d "%(SEQLEN) 52 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 53 | print(cmd) 54 | os.system(cmd) 55 | print("-------------------------------------------") 56 | 57 | print('='*10 + ' Llama-2 Tender-INT4 ' + '='*10) 58 | set_symlink_llama('modeling_llama_tender.py') 59 | for SIZE 
in ['7b', '13b', '70b']: 60 | for SEQLEN in [2048]: 61 | for DATASET in ["wikitext2", 'ptb']: 62 | for BITS in [4, 8]: 63 | DECOMP = llama2_decomp_params(SIZE, BITS) 64 | cmd = "CUDA_VISIBLE_DEVICES=0 python llama.py " 65 | cmd += "--model %s/llama-2-%s "%(llama_2_dir, SIZE) 66 | cmd += "--eval_dataset %s "%(DATASET) 67 | cmd += "--seq_len %d "%(SEQLEN) 68 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 69 | cmd += "--q_bits %d "%(BITS) 70 | cmd += "--decomp_factor %d "%(DECOMP) 71 | cmd += "--chunk_size %d "%(256) 72 | cmd += "--scale_factor %s "%(f"../calibration/llama/llama-2-scale/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 73 | cmd += "--bias %s "%(f"../calibration/llama/llama-2-bias/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 74 | print(cmd) 75 | os.system(cmd) 76 | print("-------------------------------------------") 77 | 78 | print('='*10 + ' Llama-1 baseline ' + '='*10) 79 | set_symlink_llama('modeling_llama_orig.py') 80 | model_name = {"7b": "baffo32/decapoda-research-llama-7B-hf", 81 | "13b": "JG22/decapoda-research-llama-13b"} 82 | for SIZE in ['7b', '13b']: 83 | for SEQLEN in [2048]: 84 | for DATASET in ["wikitext2", "ptb"]: 85 | cmd = "CUDA_VISIBLE_DEVICES=0 python llama.py " 86 | cmd += "--model %s "%(model_name[SIZE]) 87 | cmd += "--eval_dataset %s "%(DATASET) 88 | cmd += "--seq_len %d "%(SEQLEN) 89 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 90 | print(cmd) 91 | os.system(cmd) 92 | print("-------------------------------------------") 93 | 94 | print('='*10 + ' Llama-1 Tender-INT4 ' + '='*10) 95 | set_symlink_llama('modeling_llama_tender.py') 96 | for SIZE in ['7b', '13b']: 97 | for SEQLEN in [2048]: 98 | for DATASET in ["wikitext2", 'ptb']: 99 | for BITS in [4, 8]: 100 | DECOMP = 14 101 | cmd = "CUDA_VISIBLE_DEVICES=0 python llama.py " 102 | cmd += "--model %s "%(model_name[SIZE]) 103 | cmd += "--eval_dataset %s "%(DATASET) 104 | cmd += "--seq_len %d "%(SEQLEN) 105 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 106 | cmd += "--q_bits %d "%(BITS) 107 | cmd += "--decomp_factor %d "%(DECOMP) 108 | cmd += "--chunk_size %d "%(256) 109 | cmd += "--scale_factor %s "%(f"../calibration/llama/llama-1-scale/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 110 | cmd += "--bias %s "%(f"../calibration/llama/llama-1-bias/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 111 | print(cmd) 112 | os.system(cmd) 113 | print("-------------------------------------------") 114 | -------------------------------------------------------------------------------- /scripts/table_3.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # If you set EVAL_SAMPLES 0, evaluate over full dataset. 
(no sampling) 4 | EVAL_SAMPLES = 0 5 | 6 | print('='*10 + ' OPT baseline ' + '='*10) 7 | set_symlink_opt('modeling_opt_orig.py') 8 | for SIZE in ['6.7b']: 9 | for SEQLEN in [2048, 256, 32]: 10 | for DATASET in ["wikitext2", 'ptb']: 11 | cmd = "CUDA_VISIBLE_DEVICES=0 python opt.py " 12 | cmd += "--model facebook/opt-%s "%(SIZE) 13 | cmd += "--eval_dataset %s "%(DATASET) 14 | cmd += "--seq_len %d "%(SEQLEN) 15 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 16 | print(cmd) 17 | os.system(cmd) 18 | print("-------------------------------------------") 19 | 20 | print('='*10 + ' OPT Tender-INT4 ' + '='*10) 21 | set_symlink_opt('modeling_opt_tender.py') 22 | for SIZE in ['6.7b']: 23 | for SEQLEN in [2048, 256, 32]: 24 | for DATASET in ["wikitext2", 'ptb']: 25 | for BITS in [4, 8]: 26 | DECOMP = opt_decomp_params(SIZE, BITS) 27 | cmd = "CUDA_VISIBLE_DEVICES=0 python opt.py " 28 | cmd += "--model facebook/opt-%s "%(SIZE) 29 | cmd += "--eval_dataset %s "%(DATASET) 30 | cmd += "--seq_len %d "%(SEQLEN) 31 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 32 | cmd += "--q_bits %d "%(BITS) 33 | cmd += "--decomp_factor %d "%(DECOMP) 34 | cmd += "--chunk_size %d "%(256) 35 | cmd += "--scale_factor %s "%(f"../calibration/opt/scale/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 36 | cmd += "--bias %s "%(f"../calibration/opt/bias/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp.pt") 37 | print(cmd) 38 | os.system(cmd) 39 | print("-------------------------------------------") 40 | 41 | print('='*10 + ' OPT Tender-INT4 (all)' + '='*10) 42 | for SIZE in ['6.7b']: 43 | for SEQLEN in [2048, 256, 32]: 44 | for DATASET in ["wikitext2", 'ptb']: 45 | for BITS in [4, 8]: 46 | DECOMP = opt_decomp_params(SIZE, BITS) 47 | cmd = "CUDA_VISIBLE_DEVICES=0 python opt.py " 48 | cmd += "--model facebook/opt-%s "%(SIZE) 49 | cmd += "--eval_dataset %s "%(DATASET) 50 | cmd += "--seq_len %d "%(SEQLEN) 51 | cmd += "--eval_samples %d "%(EVAL_SAMPLES) 52 | cmd += "--q_bits %d "%(BITS) 53 | cmd += "--decomp_factor %d "%(DECOMP) 54 | cmd += "--chunk_size %d "%(256) 55 | cmd += "--quant_mha " 56 | cmd += "--scale_factor %s "%(f"../calibration/opt/scale/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp_mha.pt") 57 | cmd += "--bias %s "%(f"../calibration/opt/bias/2048_{SIZE}_128_{BITS}bit_{DECOMP}decomp_mha.pt") 58 | print(cmd) 59 | os.system(cmd) 60 | print("-------------------------------------------") 61 | -------------------------------------------------------------------------------- /scripts/table_7.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import os 3 | 4 | tasks=[ 5 | "hellaswag" , 6 | "lambada_openai", 7 | "arc_challenge", 8 | "arc_easy", 9 | "wic", 10 | "anli_r2", 11 | "winogrande", 12 | "hendrycksTest-college_computer_science", 13 | "hendrycksTest-international_law", 14 | "hendrycksTest-jurisprudence" 15 | ] 16 | task_str = ",".join(tasks) 17 | 18 | os.chdir("../lm-evaluation-harness") 19 | 20 | ## OPT 21 | # Baseline 22 | print('='*10 + ' OPT Baseline ' + '='*10) 23 | set_symlink_opt('modeling_opt_orig.py') 24 | cmd = ("python main.py " 25 | "--model hf-causal " 26 | "--model_args pretrained=facebook/opt-6.7b,dtype=float32,scheme=base " 27 | f"--tasks {task_str} " 28 | "--num_fewshot 0 " 29 | "--no_cache " 30 | "--device cuda:0 ") 31 | print(cmd) 32 | os.system(cmd) 33 | 34 | # Tender 35 | print('='*10 + ' OPT Tender-INT4 ' + '='*10) 36 | set_symlink_opt('modeling_opt_tender.py') 37 | cmd = ("python main.py " 38 | "--model hf-causal " 39 | "--model_args 
pretrained=facebook/opt-6.7b,dtype=float32,scheme=tender " 40 | f"--tasks {task_str} " 41 | "--num_fewshot 0 " 42 | "--no_cache " 43 | "--device cuda:0 ") 44 | print(cmd) 45 | os.system(cmd) 46 | 47 | ## Llama 48 | # Baseline 49 | print('='*10 + ' LLaMA Baseline ' + '='*10) 50 | set_symlink_llama('modeling_llama_orig.py') 51 | cmd = ("python main.py " 52 | "--model hf-causal " 53 | "--model_args pretrained=baffo32/decapoda-research-llama-7B-hf,dtype=float32,scheme=base " 54 | f"--tasks {task_str} " 55 | "--num_fewshot 0 " 56 | "--no_cache " 57 | "--device cuda:0 ") 58 | print(cmd) 59 | os.system(cmd) 60 | 61 | # Tender 62 | print('='*10 + ' LLaMA Tender-INT4 ' + '='*10) 63 | set_symlink_llama('modeling_llama_tender.py') 64 | cmd = ("python main.py " 65 | "--model hf-causal " 66 | "--model_args pretrained=baffo32/decapoda-research-llama-7B-hf,dtype=float32,scheme=tender " 67 | f"--tasks {task_str} " 68 | "--num_fewshot 0 " 69 | "--no_cache " 70 | "--device cuda:0 ") 71 | print(cmd) 72 | os.system(cmd) 73 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def opt_decomp_params(size, bits): 4 | if size == "6.7b": 5 | if bits == 4: 6 | decomp = 8 7 | elif bits == 8: 8 | decomp = 4 9 | elif size == "13b": 10 | if bits == 4: 11 | decomp = 8 12 | elif bits == 8: 13 | decomp = 4 14 | elif size == "66b": 15 | if bits == 4: 16 | decomp = 10 17 | elif bits == 8: 18 | decomp = 8 19 | else: 20 | raise ValueError 21 | return decomp 22 | 23 | def llama2_decomp_params(size, bits): 24 | if size == "7b": 25 | if bits == 4: 26 | decomp = 14 27 | elif bits == 8: 28 | decomp = 8 29 | elif size == "13b": 30 | if bits == 4: 31 | decomp = 16 32 | elif bits == 8: 33 | decomp = 14 34 | elif size == "70b": 35 | if bits == 4: 36 | decomp = 20 37 | elif bits == 8: 38 | decomp = 16 39 | else: 40 | raise ValueError 41 | return decomp 42 | 43 | 44 | def set_symlink_opt(name): 45 | if not os.path.exists(f'../models/{name}'): 46 | print(f'no such file in ../models/{name}') 47 | exit(1) 48 | 49 | if os.path.exists(f'../models/modeling_opt.py'): 50 | os.system(f'rm ../models/modeling_opt.py') 51 | 52 | os.system(f'ln -s ../models/{name} ../models/modeling_opt.py') 53 | 54 | 55 | def set_symlink_llama(name): 56 | if not os.path.exists(f'../models/{name}'): 57 | print(f'no such file in ../models/{name}') 58 | exit(1) 59 | 60 | if os.path.exists(f'../models/modeling_llama.py'): 61 | os.system(f'rm ../models/modeling_llama.py') 62 | 63 | os.system(f'ln -s ../models/{name} ../models/modeling_llama.py') 64 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # dataset 4 | wget https://huggingface.co/datasets/mit-han-lab/pile-val-backup/resolve/main/val.jsonl.zst 5 | mkdir -p ./data 6 | mv val.jsonl.zst ./data 7 | 8 | # model setup 9 | 10 | CWD=${PWD} 11 | cd transformers/src/transformers/models 12 | 13 | for model in llama opt;do 14 | mv ${model}/modeling_${model}.py ${model}/modeling_${model}_orig.py 15 | ln -s ${CWD}/models/modeling_${model}.py ${model}/modeling_${model}.py 16 | done 17 | --------------------------------------------------------------------------------
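For reference, `table_2.py`, `table_3.py`, and `table_7.py` above only assemble and shell out to commands accepted by `scripts/opt.py` and `scripts/llama.py`. Below is a minimal sketch of a single standalone Tender-INT4 perplexity run for OPT-6.7B; it assumes the calibration step has already produced the scale-factor and bias files under the default names used in `table_2.py`, and the decomposition factor of 8 mirrors `opt_decomp_params('6.7b', 4)` in `scripts/utils.py`.

```sh
# Point transformers at the Tender implementation of OPT
# (the same symlink switch that set_symlink_opt() in scripts/utils.py performs).
cd scripts
python -c "from utils import set_symlink_opt; set_symlink_opt('modeling_opt_tender.py')"

# Evaluate wikitext2 perplexity at INT4 with 8 channel groups and 256-row chunks.
CUDA_VISIBLE_DEVICES=0 python opt.py \
  --model facebook/opt-6.7b \
  --eval_dataset wikitext2 \
  --seq_len 2048 \
  --eval_samples 0 \
  --q_bits 4 \
  --decomp_factor 8 \
  --chunk_size 256 \
  --scale_factor ../calibration/opt/scale/2048_6.7b_128_4bit_8decomp.pt \
  --bias ../calibration/opt/bias/2048_6.7b_128_4bit_8decomp.pt
```

The same pattern applies to `llama.py`, using the `llama-2-scale`/`llama-2-bias` (or `llama-1-*`) calibration directories and the decomposition factors returned by `llama2_decomp_params()`.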