├── .gitignore ├── README.md ├── fake_quant ├── README.md ├── data_utils.py ├── eval_utils.py ├── gptaq_utils.py ├── gptq_utils.py ├── hadamard_utils.py ├── main.py ├── model_utils.py ├── monkeypatch.py ├── quant_utils.py ├── requirements.txt ├── rotation_utils.py ├── run_llama.sh └── utils.py ├── img └── readme_intro.png ├── spinquant ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SpinQuant.png ├── eval_utils │ ├── gptaq_utils.py │ ├── gptq_utils.py │ ├── main.py │ ├── modeling_llama.py │ └── rotation_utils.py ├── optimize_rotation.py ├── ptq.py ├── requirement.txt ├── scripts │ ├── 10_optimize_rotation.sh │ ├── 11_optimize_rotation_fsdp.sh │ ├── 2_eval_ptq.sh │ ├── 31_optimize_rotation_executorch.sh │ └── 32_eval_ptq_executorch.sh ├── train_utils │ ├── apply_r3_r4.py │ ├── fsdp_trainer.py │ ├── main.py │ ├── modeling_llama_quant.py │ ├── optimizer.py │ ├── quant_linear.py │ └── rtn_utils.py └── utils │ ├── convert_to_executorch.py │ ├── data_utils.py │ ├── eval_utils.py │ ├── fuse_norm_utils.py │ ├── hadamard_utils.py │ ├── model_utils.py │ ├── monkeypatch.py │ ├── process_args.py │ ├── quant_utils.py │ └── utils.py └── vit_quant ├── README.md ├── data_utils.py ├── eval_utils.py ├── gptaq_utils.py ├── gptq_utils.py ├── main.py ├── model_utils.py ├── quant_utils.py ├── run.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | 10 | # Packages # 11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | 21 | # Logs and databases # 22 | ###################### 23 | *.log 24 | *.sql 25 | *.sqlite 26 | 27 | # OS generated files # 28 | ###################### 29 | .DS_Store 30 | .DS_Store? 
31 | ._* 32 | *.idea* 33 | .Spotlight-V100 34 | .Trashes 35 | *.xml 36 | *.iml 37 | *.pyc 38 | ehthumbs.db 39 | Thumbs.db 40 | 41 | # Ignore ImageNet Pre-trained Models # 42 | ###################################### 43 | *ImageNet/checkpoint/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 |

GPTAQ: Efficient Finetuning-Free Quantization with Asymmetric Calibration [ICML 2025]

4 |

5 | 6 |

7 | 8 |

9 | 10 | The official PyTorch implementation of GPTAQ. 11 | 12 | Unlike the previous GPTQ method, which independently calibrates each layer, we always match the quantized layer’s output to the exact output in the full-precision model, resulting in a scheme that we call asymmetric calibration. Such a scheme can effectively reduce the quantization error accumulated in previous layers. We analyze this problem using optimal brain compression to derive a closed-form solution. The new solution explicitly minimizes the quantization error as well as the accumulated asymmetry error. Furthermore, we utilize various techniques to parallelize the solution calculation, including channel parallelization, neuron decomposition, and Cholesky reformulation for matrix fusion. As a result, GPTAQ is easy to implement, requiring only 20 more lines of code than GPTQ while improving its performance under low-bit quantization. 13 | 14 | ## Update: Name change to GPTAQ 15 | 16 | We are updating our code to use the new name *`GPTAQ`* (formerly GPTQv2). 17 | 18 | ## Update: GPTQv2 is integrated into GPTQModel 19 | 20 | The GPTQv2 method is integrated into the [GPTQModel](https://github.com/ModelCloud/GPTQModel/tree/main) library and can be enabled with a single argument (`v2=True` in `QuantizeConfig`). 
21 | 22 | You can install GPTQModel: 23 | 24 | ```shell 25 | pip install -v gptqmodel --no-build-isolation 26 | ``` 27 | 28 | Quantize LLaMA3.1-8B-Instruct 29 | 30 | ```python 31 | import tempfile 32 | 33 | from datasets import load_dataset 34 | from gptqmodel import GPTQModel, QuantizeConfig 35 | from gptqmodel.quantization import FORMAT 36 | from gptqmodel.utils.eval import EVAL 37 | from logbar import LogBar 38 | 39 | log = LogBar.shared() 40 | 41 | MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" 42 | CFG_BITS = 4 43 | CFG_GROUPSIZE = 128 44 | CFG_V2 = True 45 | INPUTS_MAX_LENGTH = 2048 # in tokens 46 | QUANT_SAVE_PATH = f"/your_path/gptq_v2_{CFG_V2}_bit_{CFG_BITS}_gpsize_{CFG_GROUPSIZE}_llama_3.1_8B_Instruct" 47 | 48 | def get_calib_data(tokenizer, rows: int): 49 | 50 | calibration_dataset = load_dataset( 51 | "json", 52 | data_files="/your_path/dataset/c4-train.00000-of-01024.json.gz", 53 | split="train") 54 | 55 | datas = [] 56 | for index, sample in enumerate(calibration_dataset): 57 | tokenized = tokenizer(sample["text"]) 58 | if len(tokenized.data['input_ids']) <= INPUTS_MAX_LENGTH: 59 | datas.append(tokenized) 60 | if len(datas) >= rows: 61 | break 62 | 63 | return datas 64 | 65 | quant_config = QuantizeConfig( 66 | bits=CFG_BITS, 67 | group_size=CFG_GROUPSIZE, 68 | format=FORMAT.GPTQ, 69 | desc_act=True, 70 | sym=True, 71 | v2=CFG_V2, 72 | ) 73 | 74 | log.info(f"QuantConfig: {quant_config}") 75 | log.info(f"Save Path: {QUANT_SAVE_PATH}") 76 | 77 | # load un-quantized native model 78 | model = GPTQModel.load(MODEL_ID, quant_config) 79 | 80 | # load calibration data 81 | calibration_dataset = get_calib_data(tokenizer=model.tokenizer, rows=256) 82 | 83 | model.quantize(calibration_dataset, batch_size=1) 84 | 85 | model.save(QUANT_SAVE_PATH) 86 | log.info(f"Quant Model Saved to: {QUANT_SAVE_PATH}") 87 | ``` 88 | 89 | Evaluation on Arc_challenge and GSM8K: 90 | 91 | ```python 92 | # eval 93 | from lm_eval.tasks import TaskManager 94 | from lm_eval.utils import 
make_table 95 | 96 | with tempfile.TemporaryDirectory() as tmp_dir: 97 | results = GPTQModel.eval( 98 | QUANT_SAVE_PATH, 99 | tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.GSM8K_PLATINUM_COT], 100 | apply_chat_template=True, 101 | random_seed=898, 102 | output_path= tmp_dir, 103 | ) 104 | 105 | print(make_table(results)) 106 | if "groups" in results: 107 | print(make_table(results, "groups")) 108 | ``` 109 | 110 | 111 | Performance comparison (GPTQv2 outperforms GPTQ on GSM8K using 1 fewer bit): 112 | 113 | 114 | v1 ([checkpoints](https://huggingface.co/ModelCloud/GPTQ-v1-Llama-3.1-8B-Instruct 115 | )): 116 | 117 | 118 | | Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr| 119 | |------------------|------:|----------------|-----:|-----------|---|-----:|---|-----:| 120 | |arc_challenge| 1|none | 0|acc |↑ |0.5000|± |0.0146| 121 | | | |none | 0|acc_norm|↑ |0.5128|± |0.0146| 122 | |gsm8k_platinum_cot| 3|flexible-extract| 8|exact_match|↑ |0.3995|± |0.0141| 123 | | | |strict-match | 8|exact_match|↑ |0.2548|± |0.0125| 124 | 125 | 126 | v2 ([checkpoints](https://huggingface.co/ModelCloud/GPTQ-v2-Llama-3.1-8B-Instruct)): 127 | 128 | | Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr| 129 | |------------------|------:|----------------|-----:|-----------|---|-----:|---|-----:| 130 | |arc_challenge| 1|none | 0|acc |↑ |0.5034|± |0.0146| 131 | | | |none | 0|acc_norm|↑ |0.5068|± |0.0146| 132 | |gsm8k_platinum_cot| 3|flexible-extract| 8|exact_match|↑ |0.7601|± |0.0123| 133 | | | |strict-match | 8|exact_match|↑ |0.5211|± |0.0144| 134 | 135 | 136 | 137 | 138 | 139 | ## Code Structure 140 | 141 | We provide several directories to reproduce the paper results. 142 | 143 | 1. [**fake_quant**](./fake_quant) for reproducing QuaRot+GPTQ/GPTAQ 144 | 2. [**spinquant**](./spinquant) for reproducing SpinQuant+GPTQ/GPTAQ 145 | 3. [**vit_quant**](./vit_quant) for reproducing vision transformer quantization results 146 | 147 | [//]: # (4. 
**GPTQModel**, a forked version of GPTQModel to support GPTQv2 to deploy weight-only quantization model ) 148 | 149 | We recommend using separate environments for different experiments to ensure the results match. 150 | 151 | 152 | ## Acknowledgement 153 | 154 | Our code is built upon several repositories: 155 | 156 | [https://github.com/IST-DASLab/gptq](https://github.com/IST-DASLab/gptq) 157 | 158 | [https://github.com/spcl/QuaRot](https://github.com/spcl/QuaRot) 159 | 160 | [https://github.com/facebookresearch/SpinQuant/tree/main](https://github.com/facebookresearch/SpinQuant/tree/main) 161 | 162 | 163 | ## Star History 164 | [![Star History Chart](https://api.star-history.com/svg?repos=Intelligent-Computing-Lab-Yale/GPTQv2&type=Date)](https://star-history.com/#Intelligent-Computing-Lab-Yale/GPTQv2) 165 | 166 | ## Contact 167 | 168 | Yuhang Li (*yuhang.li@yale.edu*) 169 | 170 | ## Citations 171 | 172 | If you find our work useful, please consider giving us a star and citing our paper: 173 | ```bibtex 174 | @article{li2025gptqv2, 175 | title={GPTAQ: Efficient Finetuning-Free Quantization for Asymmetric Calibration}, 176 | author={Yuhang Li and Ruokai Yin and Donghyun Lee and Shiting Xiao and Priyadarshini Panda}, 177 | year={2025}, 178 | journal={arXiv preprint arXiv:2504.02692}, 179 | } 180 | ``` -------------------------------------------------------------------------------- /fake_quant/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Fake Quantization with QuaRot 4 | 5 | This code is built on [QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs](https://github.com/spcl/QuaRot) 6 | 7 | ## Installation 8 | 9 | We recommend installing the Python environment with the original codebase's requirements: 10 | 11 | ```bash 12 | conda create -n quarot python=3.9 13 | conda activate quarot 14 | pip install -r requirements.txt 15 | ``` 16 | Additionally, to apply the Hadamard transformation, build the
[fast-hadamard-transform](https://github.com/Dao-AILab/fast-hadamard-transform) package from source. 17 | 18 | 19 | ## Language Generation and Zero-Shot Evaluations 20 | 21 | Currently, this code supports **LLaMA-2 and LLaMA-3** models (we did not test OPT or LLaMA-1). 22 | You can simply run `main.py` to reproduce the results in the paper. The important arguments are: 23 | 24 | - `--model`: the model name (or path to the weights) 25 | - `--bsz`: the batch size for PPL evaluation 26 | - `--rotate`: whether we want to rotate the model 27 | - `--lm_eval`: whether we want to run LM-Eval for Zero-Shot tasks 28 | - `--tasks`: the tasks for LM-Eval 29 | - `--cal_dataset`: the calibration dataset for GPTQ quantization 30 | - `--a_bits`: the number of bits for activation quantization 31 | - `--w_bits`: the number of bits for weight quantization 32 | - `--v_bits`: the number of bits for value quantization 33 | - `--k_bits`: the number of bits for key quantization 34 | - `--w_clip`: Whether we want to clip the weights 35 | - `--a_clip_ratio`: The ratio of clipping for activation 36 | - `--k_clip_ratio`: The ratio of clipping for key 37 | - `--v_clip_ratio`: The ratio of clipping for value 38 | - `--w_asym`: Whether we want to use asymmetric quantization for weights 39 | - `--a_asym`: Whether we want to use asymmetric quantization for activation 40 | - `--v_asym`: Whether we want to use asymmetric quantization for value 41 | - `--k_asym`: Whether we want to use asymmetric quantization for key 42 | - `--a_groupsize`: The group size for activation quantization 43 | - `--w_groupsize`: The group size for weight quantization 44 | - `--v_groupsize`: The group size for value quantization 45 | - `--k_groupsize`: The group size for key quantization 46 | - `--use_v2`: Turn on GPTQv2 quantization (recommended) 47 | - `--enable_aq_calibration`: Enable activation quantization during calibration (recommended) 48 | 49 | 50 | We provide a script `run_llama.sh` to reproduce the results.
51 | 52 | -------------------------------------------------------------------------------- /fake_quant/data_utils.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import random 3 | import transformers 4 | 5 | def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False): 6 | 7 | if hf_token is None: 8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 9 | else: 10 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 11 | 12 | if eval_mode: 13 | testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 15 | return testenc 16 | else: 17 | traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 18 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 19 | random.seed(seed) 20 | trainloader = [] 21 | for _ in range(nsamples): 22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 23 | j = i + seqlen 24 | inp = trainenc.input_ids[:, i:j] 25 | tar = inp.clone() 26 | tar[:, :-1] = -100 27 | trainloader.append((inp, tar)) 28 | return trainloader 29 | 30 | def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False): 31 | 32 | if hf_token is None: 33 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 34 | else: 35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 36 | 37 | if eval_mode: 38 | valdata = datasets.load_dataset( 39 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 40 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 41 | valenc = valenc.input_ids[:, :(256 * seqlen)] 42 | class TokenizerWrapper: 43 | def __init__(self, input_ids): 44 | self.input_ids = input_ids 45 | valenc 
= TokenizerWrapper(valenc) 46 | return valenc 47 | else: 48 | traindata = datasets.load_dataset( 49 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 50 | 51 | random.seed(seed) 52 | trainloader = [] 53 | for _ in range(nsamples): 54 | while True: 55 | i = random.randint(0, len(traindata) - 1) 56 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 57 | if trainenc.input_ids.shape[1] >= seqlen: 58 | break 59 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 60 | j = i + seqlen 61 | inp = trainenc.input_ids[:, i:j] 62 | tar = inp.clone() 63 | tar[:, :-1] = -100 64 | trainloader.append((inp, tar)) 65 | return trainloader 66 | 67 | 68 | 69 | def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False): 70 | 71 | 72 | if hf_token is None: 73 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 74 | else: 75 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 76 | 77 | if eval_mode: 78 | testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test') 79 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 80 | return testenc 81 | else: 82 | traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train') 83 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 84 | random.seed(seed) 85 | trainloader = [] 86 | for _ in range(nsamples): 87 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 88 | j = i + seqlen 89 | inp = trainenc.input_ids[:, i:j] 90 | tar = inp.clone() 91 | tar[:, :-1] = -100 92 | trainloader.append((inp, tar)) 93 | return trainloader 94 | 95 | 96 | def get_loaders( 97 | name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False 98 | ): 99 | if 'wikitext2' in name: 100 | return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode) 101 | if 'ptb' in name: 102 | return 
get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode) 103 | if 'c4' in name: 104 | return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode) -------------------------------------------------------------------------------- /fake_quant/eval_utils.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import model_utils 3 | import quant_utils 4 | import torch 5 | import os 6 | import logging 7 | from tqdm import tqdm 8 | 9 | 10 | @torch.no_grad() 11 | def evaluator(model, testenc, dev, args): 12 | 13 | model.eval() 14 | 15 | if 'opt' in args.model: 16 | opt_type = True 17 | llama_type = False 18 | elif 'meta' in args.model: 19 | llama_type = True 20 | opt_type = False 21 | else: 22 | raise ValueError(f'Unknown model {args.model}') 23 | 24 | use_cache = model.config.use_cache 25 | model.config.use_cache = False 26 | 27 | if opt_type: 28 | layers = model.model.decoder.layers 29 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 30 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 31 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 32 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 33 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 34 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 35 | 36 | elif llama_type: 37 | layers = model.model.layers 38 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 39 | 40 | layers[0] = layers[0].to(dev) 41 | 42 | # Convert the whole text of evaluation dataset into batches of sequences. 43 | input_ids = testenc.input_ids # (1, text_len) 44 | nsamples = input_ids.numel() // model.seqlen # The tail is truncated. 
45 | input_ids = input_ids[:, :nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev) # (nsamples, seqlen) 46 | 47 | batch_size = args.bsz 48 | input_ids = [input_ids[i:i + batch_size] for i in range(0, nsamples, batch_size)] 49 | nbatches = len(input_ids) 50 | 51 | dtype = next(iter(model.parameters())).dtype 52 | # The input of the first decoder layer. 53 | inps = torch.zeros( 54 | (nbatches, batch_size, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 55 | ) 56 | inps = [0] * nbatches 57 | cache = {'i': 0, 'attention_mask': None} 58 | class Catcher(torch.nn.Module): 59 | def __init__(self, module): 60 | super().__init__() 61 | self.module = module 62 | def forward(self, inp, **kwargs): 63 | inps[cache['i']] = inp 64 | cache['i'] += 1 65 | cache['attention_mask'] = kwargs['attention_mask'] 66 | if llama_type: 67 | cache['position_ids'] = kwargs['position_ids'] 68 | raise ValueError 69 | layers[0] = Catcher(layers[0]) 70 | 71 | for i in range(nbatches): 72 | batch = input_ids[i] 73 | try: 74 | model(batch) 75 | except ValueError: 76 | pass 77 | layers[0] = layers[0].module 78 | layers[0] = layers[0].cpu() 79 | 80 | if opt_type: 81 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 82 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 83 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 84 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 85 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 86 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 87 | elif llama_type: 88 | model.model.embed_tokens = model.model.embed_tokens.cpu() 89 | position_ids = cache['position_ids'] 90 | 91 | torch.cuda.empty_cache() 92 | outs = [0] * nbatches 93 | attention_mask = cache['attention_mask'] 94 | 95 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"): 96 | layer = layers[i].to(dev) 97 | 
98 | # Dump the layer input and output 99 | if args.capture_layer_io and args.layer_idx == i: 100 | captured_io = model_utils.capture_layer_io(model_utils.get_model_type(model), layer, inps) 101 | save_path = model_utils.get_layer_io_save_path(args) 102 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 103 | torch.save(captured_io, save_path) 104 | logging.info(f'Dumped layer input and output to: {save_path}') 105 | 106 | for j in range(nbatches): 107 | if opt_type: 108 | outs[j] = layer(inps[j], attention_mask=attention_mask)[0] 109 | elif llama_type: 110 | outs[j] = layer(inps[j], attention_mask=attention_mask, position_ids=position_ids)[0] 111 | layers[i] = layer.cpu() 112 | del layer 113 | torch.cuda.empty_cache() 114 | inps, outs = outs, inps 115 | 116 | if opt_type: 117 | if model.model.decoder.final_layer_norm is not None: 118 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 119 | if model.model.decoder.project_out is not None: 120 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 121 | 122 | elif llama_type: 123 | if model.model.norm is not None: 124 | model.model.norm = model.model.norm.to(dev) 125 | 126 | model.lm_head = model.lm_head.to(dev) 127 | nlls = [] 128 | loss_fct = torch.nn.CrossEntropyLoss(reduction = "none") 129 | for i in range(nbatches): 130 | hidden_states = inps[i] 131 | if opt_type: 132 | if model.model.decoder.final_layer_norm is not None: 133 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 134 | if model.model.decoder.project_out is not None: 135 | hidden_states = model.model.decoder.project_out(hidden_states) 136 | elif llama_type: 137 | if model.model.norm is not None: 138 | hidden_states = model.model.norm(hidden_states) 139 | lm_logits = model.lm_head(hidden_states) 140 | shift_logits = lm_logits[:, :-1, :] 141 | shift_labels = input_ids[i][:, 1:] 142 | loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels) 143 | neg_log_likelihood = 
loss.float().mean(dim=1) 144 | nlls.append(neg_log_likelihood) 145 | nlls_tensor = torch.cat(nlls) 146 | ppl = torch.exp(nlls_tensor.mean()) 147 | model.config.use_cache = use_cache 148 | logging.info(f'\n{args.eval_dataset.upper()} PPL: {ppl.item():.3f}') 149 | return ppl.item() 150 | -------------------------------------------------------------------------------- /fake_quant/gptaq_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | import quant_utils 8 | import model_utils 9 | import logging 10 | import functools 11 | 12 | torch.backends.cuda.matmul.allow_tf32 = False 13 | torch.backends.cudnn.allow_tf32 = False 14 | 15 | class GPTAQ: 16 | 17 | def __init__(self, layer): 18 | self.layer = layer 19 | self.dev = self.layer.weight.device 20 | W = layer.weight.data.clone() 21 | self.rows = W.shape[0] 22 | self.columns = W.shape[1] 23 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 24 | self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev) 25 | self.nsamples = 0 26 | self.fp_inp = [] 27 | 28 | def add_batch(self, inp, out): 29 | 30 | if len(inp.shape) == 2: 31 | inp = inp.unsqueeze(0) 32 | tmp = inp.shape[0] 33 | if len(inp.shape) == 3: 34 | inp = inp.reshape((-1, inp.shape[-1])) 35 | 36 | inp = inp.t() 37 | 38 | self.H *= self.nsamples / (self.nsamples + tmp) 39 | self.dXXT *= self.nsamples / (self.nsamples + tmp) 40 | self.nsamples += tmp 41 | inp = math.sqrt(2 / self.nsamples) * inp.float() 42 | self.H += inp.matmul(inp.t()) 43 | dX = self.fp_inp[0].float() * math.sqrt(2 / self.nsamples) - inp 44 | self.dXXT += dX.matmul(inp.t()) 45 | 46 | del self.fp_inp[0] 47 | 48 | def fasterquant( 49 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False, alpha=0.25 50 | ): 51 | W = self.layer.weight.data.clone() 52 | W = W.float() 53 | 54 | if not 
self.quantizer.ready(): 55 | self.quantizer.find_params(W) 56 | 57 | H = self.H 58 | del self.H 59 | dead = torch.diag(H) == 0 60 | H[dead, dead] = 1 61 | W[:, dead] = 0 62 | self.dXXT[:, dead] = 0 63 | 64 | if static_groups: 65 | import copy 66 | groups = [] 67 | for i in range(0, self.columns, groupsize): 68 | quantizer = copy.deepcopy(self.quantizer) 69 | quantizer.find_params(W[:, i:(i + groupsize)]) 70 | groups.append(quantizer) 71 | 72 | if actorder: 73 | perm = torch.argsort(torch.diag(H), descending=True) 74 | W = W[:, perm] 75 | H = H[perm][:, perm] 76 | self.dXXT = self.dXXT[perm][:, perm] 77 | invperm = torch.argsort(perm) 78 | 79 | Losses = torch.zeros_like(W) 80 | Q = torch.zeros_like(W) 81 | 82 | damp = percdamp * torch.mean(torch.diag(H)) 83 | diag = torch.arange(self.columns, device=self.dev) 84 | H[diag, diag] += damp 85 | Hinv = torch.linalg.cholesky(H) 86 | Hinv = torch.cholesky_inverse(Hinv) 87 | Hinv = torch.linalg.cholesky(Hinv, upper=True) 88 | 89 | # scale it by alpha due to collection of dXXT axnd H 90 | P = alpha * ((self.dXXT @ Hinv.T).triu_(diagonal=1)) @ Hinv 91 | del self.dXXT 92 | 93 | for i1 in range(0, self.columns, blocksize): 94 | i2 = min(i1 + blocksize, self.columns) 95 | count = i2 - i1 96 | 97 | W1 = W[:, i1:i2].clone() 98 | Q1 = torch.zeros_like(W1) 99 | Err1 = torch.zeros_like(W1) 100 | Losses1 = torch.zeros_like(W1) 101 | Hinv1 = Hinv[i1:i2, i1:i2] 102 | P1 = P[i1:i2, i1:i2] 103 | 104 | for i in range(count): 105 | w = W1[:, i] 106 | d = Hinv1[i, i] 107 | 108 | if groupsize != -1: 109 | if not static_groups: 110 | if (i1 + i) % groupsize == 0: 111 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)]) 112 | else: 113 | idx = i1 + i 114 | if actorder: 115 | idx = perm[idx] 116 | self.quantizer = groups[idx // groupsize] 117 | 118 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten() 119 | Q1[:, i] = q 120 | Losses1[:, i] = (w - q) ** 2 / d ** 2 121 | 122 | err1 = (w - q) / d 123 | W1[:, i:] -= 
err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0)) 124 | Err1[:, i] = err1 125 | 126 | Q[:, i1:i2] = Q1 127 | Losses[:, i1:i2] = Losses1 / 2 128 | 129 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:]) 130 | 131 | torch.cuda.synchronize() 132 | 133 | if actorder: 134 | Q = Q[:, invperm] 135 | 136 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 137 | if torch.any(torch.isnan(self.layer.weight.data)): 138 | logging.warning('NaN in weights') 139 | import pprint 140 | pprint.pprint(self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point) 141 | raise ValueError('NaN in weights') 142 | 143 | def free(self): 144 | self.H = None 145 | self.Losses = None 146 | self.Trace = None 147 | self.dXXT = None 148 | torch.cuda.empty_cache() 149 | utils.cleanup_memory(verbos=False) 150 | 151 | 152 | @torch.no_grad() 153 | def gptaq_fwrd(model, dataloader, dev, args): 154 | ''' 155 | From GPTQ repo 156 | TODO: Make this function general to support both OPT and LLaMA models 157 | ''' 158 | logging.info('-----GPTAQ Quantization-----') 159 | 160 | use_cache = model.config.use_cache 161 | model.config.use_cache = False 162 | layers = model.model.layers 163 | 164 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 165 | model.model.norm = model.model.norm.to(dev) 166 | # model.model.rotary_emb = model.model.rotary_emb.to(dev) 167 | 168 | layers[0] = layers[0].to(dev) 169 | 170 | dtype = next(iter(model.parameters())).dtype 171 | inps = torch.zeros( 172 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 173 | ) 174 | 175 | cache = {'i': 0, 'attention_mask': None} 176 | 177 | class Catcher(nn.Module): 178 | def __init__(self, module): 179 | super().__init__() 180 | self.module = module 181 | 182 | def forward(self, inp, **kwargs): 183 | inps[cache['i']] = inp 184 | cache['i'] += 1 185 | cache['attention_mask'] = 
kwargs['attention_mask'] 186 | cache['position_ids'] = kwargs['position_ids'] 187 | raise ValueError 188 | 189 | layers[0] = Catcher(layers[0]) 190 | for batch in dataloader: 191 | try: 192 | model(batch[0].to(dev)) 193 | except ValueError: 194 | pass 195 | layers[0] = layers[0].module 196 | 197 | layers[0] = layers[0].cpu() 198 | model.model.embed_tokens = model.model.embed_tokens.cpu() 199 | model.model.norm = model.model.norm.cpu() 200 | torch.cuda.empty_cache() 201 | 202 | outs = torch.zeros_like(inps) 203 | 204 | attention_mask = cache['attention_mask'] 205 | position_ids = cache['position_ids'] 206 | 207 | quantizers = {} 208 | sequential = [ 209 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'], 210 | ['self_attn.o_proj.module'], 211 | ['mlp.up_proj.module', 'mlp.gate_proj.module'], 212 | ['mlp.down_proj.module'] 213 | ] 214 | 215 | fp_inputs_cache = model_utils.FPInputsCache(sequential) 216 | fp_inps = inps.clone() 217 | 218 | for i in range(len(layers)): 219 | print(f'\nLayer {i}:', flush=True, end=' ') 220 | layer = layers[i].to(dev) 221 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear]) 222 | 223 | bits_config = quant_utils.disable_act_quant(layer) 224 | fp_inputs_cache.add_hook(full) 225 | 226 | for j in range(args.nsamples): 227 | fp_inps[j] = layer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 228 | fp_inputs_cache.clear_hook() 229 | quant_utils.enable_act_quant(layer, bits_config) 230 | 231 | for names in sequential: 232 | subset = {n: full[n] for n in names} 233 | 234 | gptq = {} 235 | for name in subset: 236 | print(f'{name}', end=' ', flush=True) 237 | layer_weight_bits = args.w_bits 238 | layer_weight_sym = not (args.w_asym) 239 | if 'lm_head' in name: 240 | layer_weight_bits = 16 241 | continue 242 | if args.int8_down_proj and 'down_proj' in name: 243 | layer_weight_bits = 8 244 | gptq[name] = GPTAQ(subset[name]) 245 | gptq[name].quantizer = 
quant_utils.WeightQuantizer() 246 | gptq[name].quantizer.configure( 247 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip 248 | ) 249 | gptq[name].fp_inp = fp_inputs_cache.fp_cache[name] 250 | 251 | def add_batch(name): 252 | def tmp(_, inp, out): 253 | gptq[name].add_batch(inp[0].data, out.data) 254 | 255 | return tmp 256 | 257 | first_module_name = list(subset.keys())[0] 258 | handle = subset[first_module_name].register_forward_hook(add_batch(first_module_name)) 259 | 260 | for j in range(args.nsamples): 261 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 262 | handle.remove() 263 | 264 | # copy H and dXXT 265 | for name in subset: 266 | if name != first_module_name: 267 | gptq[name].H = gptq[first_module_name].H 268 | gptq[name].dXXT = gptq[first_module_name].dXXT 269 | 270 | for name in subset: 271 | layer_w_groupsize = args.w_groupsize 272 | gptq[name].fasterquant( 273 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, 274 | static_groups=args.static_groups 275 | ) 276 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 277 | gptq[name].free() 278 | 279 | for j in range(args.nsamples): 280 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 281 | 282 | fp_inputs_cache.clear_cache() 283 | layers[i] = layer.cpu() 284 | del layer 285 | del gptq 286 | torch.cuda.empty_cache() 287 | 288 | inps, outs = outs, inps 289 | 290 | model.config.use_cache = use_cache 291 | utils.cleanup_memory(verbos=True) 292 | logging.info('-----GPTAQ Quantization Done-----\n') 293 | 294 | return quantizers 295 | -------------------------------------------------------------------------------- /fake_quant/gptq_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | import 
quant_utils 8 | import logging 9 | 10 | torch.backends.cuda.matmul.allow_tf32 = False 11 | torch.backends.cudnn.allow_tf32 = False 12 | 13 | 14 | class GPTQ: 15 | 16 | def __init__(self, layer): 17 | self.layer = layer 18 | self.dev = self.layer.weight.device 19 | W = layer.weight.data.clone() 20 | self.rows = W.shape[0] 21 | self.columns = W.shape[1] 22 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 23 | self.nsamples = 0 24 | 25 | def add_batch(self, inp, out): 26 | 27 | if len(inp.shape) == 2: 28 | inp = inp.unsqueeze(0) 29 | tmp = inp.shape[0] 30 | if len(inp.shape) == 3: 31 | inp = inp.reshape((-1, inp.shape[-1])) 32 | inp = inp.t() 33 | self.H *= self.nsamples / (self.nsamples + tmp) 34 | self.nsamples += tmp 35 | # inp = inp.float() 36 | inp = math.sqrt(2 / self.nsamples) * inp.float() 37 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 38 | self.H += inp.matmul(inp.t()) 39 | 40 | def fasterquant( 41 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 42 | ): 43 | W = self.layer.weight.data.clone() 44 | W = W.float() 45 | 46 | tick = time.time() 47 | 48 | if not self.quantizer.ready(): 49 | self.quantizer.find_params(W) 50 | 51 | H = self.H 52 | del self.H 53 | dead = torch.diag(H) == 0 54 | H[dead, dead] = 1 55 | W[:, dead] = 0 56 | 57 | if static_groups: 58 | import copy 59 | groups = [] 60 | for i in range(0, self.columns, groupsize): 61 | quantizer = copy.deepcopy(self.quantizer) 62 | quantizer.find_params(W[:, i:(i + groupsize)]) 63 | groups.append(quantizer) 64 | 65 | if actorder: 66 | perm = torch.argsort(torch.diag(H), descending=True) 67 | W = W[:, perm] 68 | H = H[perm][:, perm] 69 | invperm = torch.argsort(perm) 70 | 71 | Losses = torch.zeros_like(W) 72 | Q = torch.zeros_like(W) 73 | 74 | damp = percdamp * torch.mean(torch.diag(H)) 75 | diag = torch.arange(self.columns, device=self.dev) 76 | H[diag, diag] += damp 77 | H = torch.linalg.cholesky(H) 78 | H = torch.cholesky_inverse(H) 79 | 
H = torch.linalg.cholesky(H, upper=True) 80 | Hinv = H 81 | 82 | for i1 in range(0, self.columns, blocksize): 83 | i2 = min(i1 + blocksize, self.columns) 84 | count = i2 - i1 85 | 86 | W1 = W[:, i1:i2].clone() 87 | Q1 = torch.zeros_like(W1) 88 | Err1 = torch.zeros_like(W1) 89 | Losses1 = torch.zeros_like(W1) 90 | Hinv1 = Hinv[i1:i2, i1:i2] 91 | 92 | for i in range(count): 93 | w = W1[:, i] 94 | d = Hinv1[i, i] 95 | 96 | if groupsize != -1: 97 | if not static_groups: 98 | if (i1 + i) % groupsize == 0: 99 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)]) 100 | else: 101 | idx = i1 + i 102 | if actorder: 103 | idx = perm[idx] 104 | self.quantizer = groups[idx // groupsize] 105 | 106 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten() 107 | Q1[:, i] = q 108 | Losses1[:, i] = (w - q) ** 2 / d ** 2 109 | 110 | err1 = (w - q) / d 111 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 112 | Err1[:, i] = err1 113 | 114 | Q[:, i1:i2] = Q1 115 | Losses[:, i1:i2] = Losses1 / 2 116 | 117 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 118 | 119 | torch.cuda.synchronize() 120 | 121 | if actorder: 122 | Q = Q[:, invperm] 123 | 124 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 125 | if torch.any(torch.isnan(self.layer.weight.data)): 126 | logging.warning('NaN in weights') 127 | import pprint 128 | pprint.pprint(self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point) 129 | raise ValueError('NaN in weights') 130 | 131 | def free(self): 132 | self.H = None 133 | self.Losses = None 134 | self.Trace = None 135 | torch.cuda.empty_cache() 136 | utils.cleanup_memory(verbos=False) 137 | 138 | 139 | @torch.no_grad() 140 | def gptq_fwrd(model, dataloader, dev, args): 141 | ''' 142 | From GPTQ repo 143 | TODO: Make this function general to support both OPT and LLaMA models 144 | ''' 145 | logging.info('-----GPTQ Quantization-----') 146 | 147 | use_cache = model.config.use_cache 148 | 
model.config.use_cache = False 149 | layers = model.model.layers 150 | 151 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 152 | model.model.norm = model.model.norm.to(dev) 153 | layers[0] = layers[0].to(dev) 154 | 155 | dtype = next(iter(model.parameters())).dtype 156 | inps = torch.zeros( 157 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 158 | ) 159 | cache = {'i': 0, 'attention_mask': None} 160 | 161 | class Catcher(nn.Module): 162 | def __init__(self, module): 163 | super().__init__() 164 | self.module = module 165 | def forward(self, inp, **kwargs): 166 | inps[cache['i']] = inp 167 | cache['i'] += 1 168 | cache['attention_mask'] = kwargs['attention_mask'] 169 | cache['position_ids'] = kwargs['position_ids'] 170 | raise ValueError 171 | 172 | layers[0] = Catcher(layers[0]) 173 | for batch in dataloader: 174 | try: 175 | model(batch[0].to(dev)) 176 | except ValueError: 177 | pass 178 | layers[0] = layers[0].module 179 | 180 | layers[0] = layers[0].cpu() 181 | model.model.embed_tokens = model.model.embed_tokens.cpu() 182 | model.model.norm = model.model.norm.cpu() 183 | torch.cuda.empty_cache() 184 | 185 | outs = torch.zeros_like(inps) 186 | attention_mask = cache['attention_mask'] 187 | position_ids = cache['position_ids'] 188 | 189 | quantizers = {} 190 | sequential = [ 191 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'], 192 | ['self_attn.o_proj.module'], 193 | ['mlp.up_proj.module', 'mlp.gate_proj.module'], 194 | ['mlp.down_proj.module'] 195 | ] 196 | for i in range(len(layers)): 197 | print(f'\nLayer {i}:', flush=True, end=' ') 198 | layer = layers[i].to(dev) 199 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear]) 200 | for names in sequential: 201 | subset = {n: full[n] for n in names} 202 | 203 | gptq = {} 204 | for name in subset: 205 | print(f'{name}', end=' ', flush=True) 206 | layer_weight_bits = args.w_bits 207 | layer_weight_sym = not(args.w_asym) 
208 | if 'lm_head' in name: 209 | layer_weight_bits = 16 210 | continue 211 | if args.int8_down_proj and 'down_proj' in name: 212 | layer_weight_bits = 8 213 | gptq[name] = GPTQ(subset[name]) 214 | gptq[name].quantizer = quant_utils.WeightQuantizer() 215 | gptq[name].quantizer.configure( 216 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip 217 | ) 218 | 219 | def add_batch(name): 220 | def tmp(_, inp, out): 221 | gptq[name].add_batch(inp[0].data, out.data) 222 | return tmp 223 | handles = [] 224 | for name in subset: 225 | handles.append(subset[name].register_forward_hook(add_batch(name))) 226 | for j in range(args.nsamples): 227 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 228 | for h in handles: 229 | h.remove() 230 | 231 | for name in subset: 232 | layer_w_groupsize = args.w_groupsize 233 | gptq[name].fasterquant( 234 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False 235 | ) 236 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 237 | gptq[name].free() 238 | 239 | for j in range(args.nsamples): 240 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 241 | 242 | layers[i] = layer.cpu() 243 | del layer 244 | del gptq 245 | torch.cuda.empty_cache() 246 | 247 | inps, outs = outs, inps 248 | 249 | model.config.use_cache = use_cache 250 | utils.cleanup_memory(verbos=True) 251 | logging.info('-----GPTQ Quantization Done-----\n') 252 | return quantizers 253 | 254 | 255 | 256 | 257 | @torch.no_grad() 258 | def rtn_fwrd(model, dev, args): 259 | ''' 260 | From GPTQ repo 261 | TODO: Make this function general to support both OPT and LLaMA models 262 | ''' 263 | assert args.w_groupsize ==-1, "Groupsize not supported in RTN!" 
264 | layers = model.model.layers 265 | torch.cuda.empty_cache() 266 | 267 | quantizers = {} 268 | 269 | for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"): 270 | layer = layers[i].to(dev) 271 | 272 | subset = quant_utils.find_qlayers(layer, 273 | layers=[torch.nn.Linear]) 274 | 275 | for name in subset: 276 | layer_weight_bits = args.w_bits 277 | if 'lm_head' in name: 278 | layer_weight_bits = 16 279 | continue 280 | if args.int8_down_proj and 'down_proj' in name: 281 | layer_weight_bits = 8 282 | 283 | quantizer = quant_utils.WeightQuantizer() 284 | quantizer.configure( 285 | layer_weight_bits, perchannel=True, sym=not(args.w_asym), mse=args.w_clip 286 | ) 287 | W = subset[name].weight.data 288 | quantizer.find_params(W) 289 | subset[name].weight.data = quantizer.quantize(W).to( 290 | next(iter(layer.parameters())).dtype) 291 | quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu() 292 | layers[i] = layer.cpu() 293 | torch.cuda.empty_cache() 294 | del layer 295 | 296 | utils.cleanup_memory(verbos=True) 297 | return quantizers 298 | -------------------------------------------------------------------------------- /fake_quant/main.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import torch 3 | import model_utils 4 | import data_utils 5 | import transformers 6 | import quant_utils 7 | import rotation_utils 8 | import gptq_utils 9 | import gptaq_utils 10 | import eval_utils 11 | import hadamard_utils 12 | 13 | 14 | def add_aq(model, args): 15 | # Add Input Quantization 16 | if args.a_bits < 16 or args.v_bits < 16: 17 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper]) 18 | down_proj_groupsize = -1 19 | if args.a_groupsize > 0 and "llama" in args.model: 20 | down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize) 21 | 22 | for name in qlayers: 23 | layer_input_bits = args.a_bits 24 | layer_groupsize = args.a_groupsize 25 | layer_a_sym 
= not(args.a_asym) 26 | layer_a_clip = args.a_clip_ratio 27 | 28 | if 'v_proj' in name and args.v_bits < 16: #Set the v_proj precision 29 | qlayers[name].out_quantizer.configure(bits=args.v_bits, 30 | groupsize=args.v_groupsize, 31 | sym=not(args.v_asym), 32 | clip_ratio=args.v_clip_ratio) 33 | 34 | if 'lm_head' in name: #Skip lm_head quantization 35 | layer_input_bits = 16 36 | 37 | if 'down_proj' in name: #Set the down_proj precision 38 | if args.int8_down_proj: 39 | layer_input_bits = 8 40 | layer_groupsize = down_proj_groupsize 41 | 42 | qlayers[name].quantizer.configure(bits=layer_input_bits, 43 | groupsize=layer_groupsize, 44 | sym=layer_a_sym, 45 | clip_ratio=layer_a_clip) 46 | 47 | if args.k_bits < 16: 48 | if args.k_pre_rope: 49 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!") 50 | else: 51 | rope_function_name = model_utils.get_rope_function_name(model) 52 | layers = model_utils.get_layers(model) 53 | k_quant_config = {'k_bits':args.k_bits, "k_groupsize": args.k_groupsize, 54 | "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio} 55 | for layer in layers: 56 | rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward( 57 | layer.self_attn, 58 | rope_function_name, 59 | config=model.config, 60 | **k_quant_config) 61 | 62 | 63 | def main(): 64 | args = utils.parser_gen() 65 | if args.wandb: 66 | import wandb 67 | wandb.init(project=args.wandb_project, entity=args.wandb_id) 68 | wandb.config.update(args) 69 | 70 | transformers.set_seed(args.seed) 71 | model = model_utils.get_model(args.model, args.hf_token) 72 | model.eval() 73 | 74 | # Rotate the weights 75 | if args.rotate: 76 | rotation_utils.fuse_layer_norms(model) 77 | rotation_utils.rotate_model(model, args) 78 | utils.cleanup_memory(verbos=True) 79 | 80 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model 81 | qlayers = quant_utils.find_qlayers(model) 82 | for name in qlayers: 83 | if 'down_proj' in name: 84 | had_K, K = 
hadamard_utils.get_hadK(model.config.intermediate_size) 85 | qlayers[name].online_full_had = True 86 | qlayers[name].had_K = had_K 87 | qlayers[name].K = K 88 | qlayers[name].fp32_had = args.fp32_had 89 | if 'o_proj' in name: 90 | had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads) 91 | qlayers[name].online_partial_had = True 92 | qlayers[name].had_K = had_K 93 | qlayers[name].K = K 94 | qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads 95 | qlayers[name].fp32_had = args.fp32_had 96 | else: 97 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model as the rest of the code assumes it is present 98 | 99 | if args.enable_aq_calibration: 100 | add_aq(model, args) 101 | 102 | if args.w_bits < 16: 103 | save_dict = {} 104 | if args.load_qmodel_path: # Load Quantized Rotated Model 105 | assert args.rotate, "Model should be rotated to load a quantized model!" 106 | assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!" 107 | print("Load quantized model from ", args.load_qmodel_path) 108 | save_dict = torch.load(args.load_qmodel_path) 109 | model.load_state_dict(save_dict["model"], strict=False) 110 | 111 | elif not args.w_rtn: # GPTQ Weight Quantization 112 | assert "llama" in args.model, "Only llama is supported for GPTQ!" 
113 | 114 | trainloader = data_utils.get_loaders( 115 | args.cal_dataset, nsamples=args.nsamples, 116 | seed=args.seed, model=args.model, 117 | seqlen=model.seqlen, eval_mode=False 118 | ) 119 | if args.asym_calibrate: 120 | quantizers = gptaq_utils.gptaq_fwrd(model, trainloader, utils.DEV, args) 121 | save_dict["w_quantizers"] = quantizers 122 | else: 123 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, utils.DEV, args) 124 | save_dict["w_quantizers"] = quantizers 125 | else: # RTN Weight Quantization 126 | quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args) 127 | save_dict["w_quantizers"] = quantizers 128 | 129 | if args.save_qmodel_path: 130 | save_dict["model"] = model.state_dict() 131 | torch.save(save_dict, args.save_qmodel_path) 132 | 133 | if not args.enable_aq_calibration: 134 | add_aq(model, args) 135 | 136 | # Evaluating on dataset 137 | testloader = data_utils.get_loaders( 138 | args.eval_dataset, 139 | seed=args.seed, 140 | model=args.model, 141 | seqlen=model.seqlen, 142 | hf_token=args.hf_token, 143 | eval_mode=True 144 | ) 145 | 146 | dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args) 147 | 148 | if args.wandb: 149 | wandb.log({'ppl/{}'.format(args.eval_dataset.upper()): dataset_ppl}) 150 | 151 | if not args.lm_eval: 152 | return 153 | else: 154 | # Import lm_eval utils 155 | import lm_eval 156 | from lm_eval import utils as lm_eval_utils 157 | from lm_eval.api.registry import ALL_TASKS 158 | from lm_eval.models.huggingface import HFLM 159 | 160 | if args.distribute: 161 | utils.distribute_model(model) 162 | else: 163 | model.to(utils.DEV) 164 | 165 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token) 166 | hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size) 167 | 168 | # commenting out this line as it will include two lambda sub-tasks 169 | # task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS) 170 | task_names 
= args.tasks 171 | results = lm_eval.simple_evaluate(hflm, tasks=task_names, batch_size=args.lm_eval_batch_size)['results'] 172 | 173 | metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()} 174 | metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals.values()), 4) 175 | print(metric_vals) 176 | 177 | if args.wandb: 178 | wandb.log(metric_vals) 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /fake_quant/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | import transformers 4 | import utils 5 | import os 6 | import logging 7 | import functools 8 | 9 | OPT_MODEL = transformers.models.opt.modeling_opt.OPTForCausalLM 10 | OPT_LAYER = transformers.models.opt.modeling_opt.OPTDecoderLayer 11 | LLAMA_MODEL = transformers.models.llama.modeling_llama.LlamaForCausalLM 12 | LLAMA_LAYER = transformers.models.llama.modeling_llama.LlamaDecoderLayer 13 | 14 | 15 | def model_type_extractor(model): 16 | if isinstance(model, LLAMA_MODEL): 17 | return LLAMA_MODEL 18 | elif isinstance(model, OPT_MODEL): 19 | return OPT_MODEL 20 | else: 21 | raise ValueError(f'Unknown model type {model}') 22 | 23 | def skip(*args, **kwargs): 24 | # This is a helper function to save time during the initialization! 
25 | pass 26 | 27 | def get_rope_function_name(model): 28 | if isinstance(model, LLAMA_MODEL): 29 | return "apply_rotary_pos_emb" 30 | raise NotImplementedError 31 | 32 | 33 | def get_layers(model): 34 | if isinstance(model, OPT_MODEL): 35 | return model.model.decoder.layers 36 | if isinstance(model, LLAMA_MODEL): 37 | return model.model.layers 38 | raise NotImplementedError 39 | 40 | 41 | def get_llama(model_name, hf_token): 42 | torch.nn.init.kaiming_uniform_ = skip 43 | torch.nn.init.uniform_ = skip 44 | torch.nn.init.normal_ = skip 45 | model = transformers.LlamaForCausalLM.from_pretrained(model_name, torch_dtype='auto', 46 | use_auth_token=hf_token, 47 | low_cpu_mem_usage=True) 48 | model.seqlen = 2048 49 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen)) 50 | return model 51 | 52 | 53 | 54 | def get_opt(model_name): 55 | torch.nn.init.kaiming_uniform_ = skip 56 | torch.nn.init.uniform_ = skip 57 | torch.nn.init.normal_ = skip 58 | model = transformers.OPTForCausalLM.from_pretrained(model_name, torch_dtype='auto', 59 | low_cpu_mem_usage=True) 60 | model.seqlen = model.config.max_position_embeddings 61 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen)) 62 | return model 63 | 64 | 65 | def get_model( 66 | model_name, hf_token=None 67 | ): 68 | if 'llama' in model_name: 69 | return get_llama(model_name, hf_token) 70 | elif 'opt' in model_name: 71 | return get_opt(model_name) 72 | else: 73 | raise ValueError(f'Unknown model {model_name}') 74 | 75 | 76 | def get_model_type(model): 77 | if isinstance(model, OPT_MODEL): 78 | model_type = OPT_MODEL 79 | elif isinstance(model, LLAMA_MODEL): 80 | model_type = LLAMA_MODEL 81 | else: 82 | raise ValueError(f'Unknown model type {model}') 83 | return model_type 84 | 85 | def get_embeddings(model, model_type) -> list[torch.nn.Module]: 86 | if model_type == LLAMA_MODEL: 87 | return [model.model.embed_tokens] 88 | elif model_type == OPT_MODEL: 89 | 
return [model.model.decoder.embed_tokens, model.model.decoder.embed_positions] 90 | else: 91 | raise ValueError(f'Unknown model type {model_type}') 92 | 93 | 94 | def get_transformer_layers(model, model_type): 95 | if model_type == LLAMA_MODEL: 96 | return [layer for layer in model.model.layers] 97 | elif model_type == OPT_MODEL: 98 | return [layer for layer in model.model.decoder.layers] 99 | else: 100 | raise ValueError(f'Unknown model type {model_type}') 101 | 102 | 103 | def get_lm_head(model, model_type): 104 | if model_type == LLAMA_MODEL: 105 | return model.lm_head 106 | elif model_type == OPT_MODEL: 107 | return model.lm_head 108 | else: 109 | raise ValueError(f'Unknown model type {model_type}') 110 | 111 | def get_pre_head_layernorm(model, model_type): 112 | if model_type == LLAMA_MODEL: 113 | pre_head_layernorm = model.model.norm 114 | assert isinstance(pre_head_layernorm, 115 | transformers.models.llama.modeling_llama.LlamaRMSNorm) 116 | elif model_type == OPT_MODEL: 117 | pre_head_layernorm = model.model.decoder.final_layer_norm 118 | assert pre_head_layernorm is not None 119 | else: 120 | raise ValueError(f'Unknown model type {model_type}') 121 | return pre_head_layernorm 122 | 123 | def get_mlp_bottleneck_size(model): 124 | model_type = get_model_type(model) 125 | if model_type == LLAMA_MODEL: 126 | return model.config.intermediate_size 127 | elif model_type == OPT_MODEL: 128 | return model.config.ffn_dim 129 | else: 130 | raise ValueError(f'Unknown model type {model_type}') 131 | 132 | def replace_modules( 133 | root: torch.nn.Module, 134 | type_to_replace, 135 | new_module_factory, 136 | replace_layers: bool, 137 | ) -> None: 138 | """Replace modules of given type using the supplied module factory. 139 | 140 | Perform a depth-first search of a module hierarchy starting at root 141 | and replace all instances of type_to_replace with modules created by 142 | new_module_factory. Children of replaced modules are not processed. 
143 | 144 | Args: 145 | root: the root of the module hierarchy where modules should be replaced 146 | type_to_replace: a type instances of which will be replaced 147 | new_module_factory: a function that given a module that should be replaced 148 | produces a module to replace it with. 149 | """ 150 | for name, module in root.named_children(): 151 | new_module = None 152 | if isinstance(module, type_to_replace): 153 | if replace_layers: # layernorm_fusion.replace_layers case where transformer layers are replaced 154 | new_module = new_module_factory(module, int(name)) 155 | else: # layernorm_fusion.fuse_modules case where layernorms are fused 156 | new_module = new_module_factory(module) 157 | elif len(list(module.children())) > 0: 158 | replace_modules(module, type_to_replace, new_module_factory, replace_layers) 159 | 160 | if new_module is not None: 161 | setattr(root, name, new_module) 162 | 163 | 164 | class RMSN(torch.nn.Module): 165 | """ 166 | This class implements the Root Mean Square Normalization (RMSN) layer. 167 | We use the implementation from LLAMARMSNorm here: 168 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75 169 | """ 170 | 171 | def __init__(self, mean_dim: int, eps=1e-5): 172 | super().__init__() 173 | self.eps = eps 174 | self.mean_dim = mean_dim 175 | self.weight = torch.nn.Parameter(torch.zeros(1)) 176 | 177 | def forward(self, x: torch.Tensor) -> torch.Tensor: 178 | input_dtype = x.dtype 179 | if x.dtype == torch.float16: 180 | x = x.to(torch.float32) 181 | variance = x.pow(2).sum(-1, keepdim=True) / self.mean_dim 182 | x = x * torch.rsqrt(variance + self.eps) 183 | return x.to(input_dtype) 184 | 185 | 186 | class FPInputsCache: 187 | """ 188 | class for saving the full-precision output in each layer. 
189 | """ 190 | def __init__(self, sequential): 191 | self.fp_cache = {} 192 | self.names = sequential[0]+sequential[1]+sequential[2]+sequential[3] 193 | for name in self.names: 194 | self.fp_cache[name] = [] 195 | self.handles = [] 196 | 197 | def cache_fp_input(self, m, inp, out, name): 198 | inp = inp[0].detach() 199 | if len(inp.shape) == 3: 200 | inp = inp.reshape((-1, inp.shape[-1])) 201 | self.fp_cache[name] += [inp.t()] 202 | 203 | def add_hook(self, full): 204 | for name in self.names: 205 | self.handles.append( 206 | full[name].register_forward_hook( 207 | functools.partial(self.cache_fp_input, name=name) 208 | ) 209 | ) 210 | 211 | def clear_hook(self): 212 | for h in self.handles: 213 | h.remove() 214 | self.handles = [] 215 | torch.cuda.empty_cache() 216 | 217 | def clear_cache(self): 218 | for name in self.names: 219 | self.fp_cache[name] = [] 220 | 221 | 222 | def get_layer_io_save_path(args): 223 | return os.path.join(args.save_path, 'layer_io', f'{args.layer_idx:03d}.pt') 224 | 225 | 226 | def capture_layer_io(model_type, layer, layer_input): 227 | def hook_factory(module_name, captured_vals, is_input): 228 | def hook(module, input, output): 229 | if is_input: 230 | captured_vals[module_name].append(input[0].detach().cpu()) 231 | else: 232 | captured_vals[module_name].append(output.detach().cpu()) 233 | return hook 234 | 235 | handles = [] 236 | 237 | if model_type == LLAMA_MODEL: 238 | captured_inputs = { 239 | 'k_proj': [], # q_proj, v_proj has the same input as k_proj 240 | 'o_proj': [], 241 | 'gate_proj': [], # up_proj has the same input as gate_proj 242 | 'down_proj': [] 243 | } 244 | 245 | captured_outputs = { 246 | 'v_proj': [], 247 | } 248 | 249 | for name in captured_inputs.keys(): 250 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None) 251 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True))) 252 | 253 | for name in captured_outputs.keys(): 254 | module = 
getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None) 255 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False))) 256 | 257 | elif model_type == OPT_MODEL: 258 | captured_inputs = { 259 | 'k_proj': [], # q_proj, v_proj has the same input as k_proj 260 | 'out_proj': [], 261 | 'fc1': [], 262 | 'fc2': [] 263 | } 264 | captured_outputs = { 265 | 'v_proj': [], 266 | } 267 | for name in captured_inputs.keys(): 268 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer 269 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None) 270 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True))) 271 | 272 | for name in captured_outputs.keys(): 273 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer 274 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None) 275 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False))) 276 | else: 277 | raise ValueError(f'Unknown model type {model_type}') 278 | 279 | # Process each sequence in the batch one by one to avoid OOM. 280 | for seq_idx in range(layer_input.shape[0]): 281 | # Extract the current sequence across all dimensions. 282 | seq = layer_input[seq_idx:seq_idx + 1].to(utils.DEV) 283 | # Perform a forward pass for the current sequence. 284 | layer(seq) 285 | 286 | # After processing all sequences, concatenate the accumulated inputs for each sub-layer across the batch. 287 | for module_name in captured_inputs: 288 | captured_inputs[module_name] = torch.cat(captured_inputs[module_name], dim=0) 289 | for module_name in captured_outputs: 290 | captured_outputs[module_name] = torch.cat(captured_outputs[module_name], dim=0) 291 | 292 | # Cleanup. 
293 | for h in handles: 294 | h.remove() 295 | 296 | return { 297 | 'input': captured_inputs, 298 | 'output': captured_outputs 299 | } 300 | -------------------------------------------------------------------------------- /fake_quant/monkeypatch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import types 4 | 5 | def copy_func_with_new_globals(f, globals=None): 6 | """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" 7 | if globals is None: 8 | globals = f.__globals__ 9 | g = types.FunctionType(f.__code__, globals, name=f.__name__, 10 | argdefs=f.__defaults__, closure=f.__closure__) 11 | g = functools.update_wrapper(g, f) 12 | g.__module__ = f.__module__ 13 | g.__kwdefaults__ = copy.copy(f.__kwdefaults__) 14 | return g 15 | 16 | def add_wrapper_after_function_call_in_method(module, method_name, function_name, wrapper_fn): 17 | ''' 18 | This function adds a wrapper after the output of a function call in the method named `method_name`. 19 | Only calls directly in the method are affected. Calls by other functions called in the method are not affected. 
20 | ''' 21 | 22 | original_method = getattr(module, method_name).__func__ 23 | method_globals = dict(original_method.__globals__) 24 | wrapper = wrapper_fn(method_globals[function_name]) 25 | method_globals[function_name] = wrapper 26 | new_method = copy_func_with_new_globals(original_method, globals=method_globals) 27 | setattr(module, method_name, new_method.__get__(module)) 28 | return wrapper 29 | 30 | -------------------------------------------------------------------------------- /fake_quant/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.38.0 2 | torch==2.2.1 3 | sentencepiece==0.2.0 4 | wandb==0.16.3 5 | huggingface-hub==0.20.3 6 | accelerate==0.27.2 7 | datasets==2.17.1 8 | lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@9b0b15b1ccace3534ffbd13298c569869ce8eaf3 -------------------------------------------------------------------------------- /fake_quant/run_llama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | export CUDA_VISIBLE_DEVICES=$gpu_id 5 | 6 | python main.py --model meta-llama/Meta-Llama-3-8B \ 7 | --w_bits 4 \ 8 | --w_groupsize -1 \ 9 | --w_clip \ 10 | --a_bits 4 \ 11 | --v_bits 16 \ 12 | --k_bits 16 \ 13 | --k_asym \ 14 | --v_asym \ 15 | --w_asym \ 16 | --a_asym \ 17 | --a_clip_ratio 0.9 \ 18 | --k_clip_ratio 0.95 \ 19 | --v_clip_ratio 0.95 \ 20 | --asym_calibrate \ 21 | --enable_aq_calibration \ 22 | --rotate -------------------------------------------------------------------------------- /img/readme_intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Computing-Lab-Yale/GPTAQ/883afc64f40ac061a59ef8cbe5bd794c01377e37/img/readme_intro.png -------------------------------------------------------------------------------- /spinquant/CODE_OF_CONDUCT.md: 
-------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /spinquant/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SpinQuant 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to SpinQuant, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /spinquant/README.md: -------------------------------------------------------------------------------- 1 | 2 | # SpinQuant 3 | 4 | 5 | 6 | This repository contains the code of "[SpinQuant: LLM Quantization with Learned Rotations](https://arxiv.org/pdf/2405.16406)" 7 | 8 | 9 | 10 | ## Run (following the original repo) 11 | 12 | 13 | 14 | ### 1. Requirements: 15 | 16 | * python 3.9, pytorch >= 2.0 17 | 18 | * install pytorch with cuda from https://pytorch.org/get-started/locally/; it is a prerequisite for the fast-hadamard-transform package. 19 | 20 | * pip install -r requirement.txt 21 | 22 | * git clone https://github.com/Dao-AILab/fast-hadamard-transform.git 23 | 24 | * cd fast-hadamard-transform 25 | 26 | * pip install . 27 | 28 | ### 2. Steps to run: 29 | 30 | **We directly use the pretrained SpinQuant rotation matrices to run our experiments.** 31 | 32 | You can download the optimized rotation matrices [here](https://drive.google.com/drive/folders/1R2zix4qeXBjcmgnJN1rny93cguJ4rEE8?usp=sharing). 33 | 34 | **Note that to perform GPTQ/GPTQv2, you need to download the W16A4 rotation matrix.** 35 | 36 | After downloading the optimized rotation, set optimized_rotation_path to the downloaded rotation matrix file for evaluation. 37 | 38 | * `bash scripts/2_eval_ptq.sh` 39 | 40 | **For reasons we have not been able to identify, we cannot exactly reproduce SpinQuant's reported results on LLaMA3; if you find the cause, please let us know.** For example, on LLaMA3-8B we were unable to reproduce the reported 7.10 perplexity. 41 | 42 | 43 | | Method | Bits | Perplexity | 44 | |-------------------------|------|------------| 45 | | SpinQuant+GPTQ (Paper) | W4A4 | 7.1 | 46 | | SpinQuant+GPTQ (Ours) | W4A4 | 7.26 | 47 | | SpinQuant+GPTQv2 (Ours) | W4A4 | 7.19 | 48 | 49 | 50 | 51 | ### 3.
Export to ExecuTorch (same as the original code) 52 | 53 | We also support exporting the quantized model to ExecuTorch, which allows us to use its quantization kernels and achieve real-time speedups. For more details on the kernel implementation, please see [ExecuTorch](https://pytorch.org/executorch/stable/index.html) and [ExecuTorch with SpinQuant](https://github.com/pytorch/executorch/tree/main/examples/models/llama#spinquant). We currently support 4-bit weight quantization (set the group size to 256 for the 8B model and to 32 for smaller models) and 8-bit dynamic activation quantization. 54 | 55 | 56 | To obtain ExecuTorch-compatible quantized models, you can use the following scripts: 57 | 58 | 59 | * `bash scripts/31_optimize_rotation_executorch.sh $model_name` 60 | 61 | * `bash scripts/32_eval_ptq_executorch.sh $model_name` 62 | 63 | 64 | ### Arguments 65 | 66 | 67 | 68 | - `--input_model`: The model name (or path to the weights) 69 | 70 | - `--output_rotation_path`: The local path where the optimized rotation matrix is stored 71 | 72 | - `--per_device_train_batch_size`: The batch size for rotation optimization 73 | 74 | - `--per_device_eval_batch_size`: The batch size for PPL evaluation 75 | 76 | - `--a_bits`: The number of bits for activation quantization 77 | 78 | - `--w_bits`: The number of bits for weight quantization 79 | 80 | - `--v_bits`: The number of bits for value quantization 81 | 82 | - `--k_bits`: The number of bits for key quantization 83 | 84 | - `--w_clip`: Whether to use grid search to find the best weight clipping range 85 | 86 | - `--w_rtn`: Whether to use round-to-nearest (RTN) quantization. Without `--w_rtn`, GPTQ quantization is used. 87 | 88 | - `--w_groupsize`: The group size for group-wise weight quantization.
89 | 90 | - `--rotate`: Whether we want to rotate the model 91 | 92 | - `--optimized_rotation_path`: The checkpoint path of optimized rotation; Use random rotation if path is not given 93 | 94 | 95 | -------------------------------------------------------------------------------- /spinquant/SpinQuant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Computing-Lab-Yale/GPTAQ/883afc64f40ac061a59ef8cbe5bd794c01377e37/spinquant/SpinQuant.png -------------------------------------------------------------------------------- /spinquant/eval_utils/gptaq_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import logging 7 | import functools 8 | 9 | from utils import quant_utils, utils 10 | 11 | 12 | torch.backends.cuda.matmul.allow_tf32 = False 13 | torch.backends.cudnn.allow_tf32 = False 14 | 15 | class GPTAQ: 16 | 17 | def __init__(self, layer): 18 | self.layer = layer 19 | self.dev = self.layer.weight.device 20 | W = layer.weight.data.clone() 21 | self.rows = W.shape[0] 22 | self.columns = W.shape[1] 23 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 24 | self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev) 25 | self.nsamples = 0 26 | self.fp_inp = [] 27 | 28 | def add_batch(self, inp, out): 29 | 30 | if len(inp.shape) == 2: 31 | inp = inp.unsqueeze(0) 32 | tmp = inp.shape[0] 33 | if len(inp.shape) == 3: 34 | inp = inp.reshape((-1, inp.shape[-1])) 35 | 36 | inp = inp.t() 37 | 38 | self.H *= self.nsamples / (self.nsamples + tmp) 39 | self.dXXT *= self.nsamples / (self.nsamples + tmp) 40 | self.nsamples += tmp 41 | inp = math.sqrt(2 / self.nsamples) * inp.float() 42 | self.H += inp.matmul(inp.t()) 43 | dX = self.fp_inp[0].float() * math.sqrt(2 / self.nsamples) - inp 44 | self.dXXT += dX.matmul(inp.t()) 45 | 46 | del 
self.fp_inp[0] 47 | 48 | def fasterquant( 49 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 50 | ): 51 | W = self.layer.weight.data.clone() 52 | W = W.float() 53 | 54 | if not self.quantizer.ready(): 55 | self.quantizer.find_params(W) 56 | 57 | H = self.H 58 | del self.H 59 | dead = torch.diag(H) == 0 60 | H[dead, dead] = 1 61 | W[:, dead] = 0 62 | self.dXXT[:, dead] = 0 63 | 64 | if static_groups: 65 | import copy 66 | groups = [] 67 | for i in range(0, self.columns, groupsize): 68 | quantizer = copy.deepcopy(self.quantizer) 69 | quantizer.find_params(W[:, i:(i + groupsize)]) 70 | groups.append(quantizer) 71 | 72 | if actorder: 73 | perm = torch.argsort(torch.diag(H), descending=True) 74 | W = W[:, perm] 75 | H = H[perm][:, perm] 76 | self.dXXT = self.dXXT[perm][:, perm] 77 | invperm = torch.argsort(perm) 78 | 79 | Losses = torch.zeros_like(W) 80 | Q = torch.zeros_like(W) 81 | 82 | damp = percdamp * torch.mean(torch.diag(H)) 83 | diag = torch.arange(self.columns, device=self.dev) 84 | H[diag, diag] += damp 85 | Hinv = torch.linalg.cholesky(H) 86 | Hinv = torch.cholesky_inverse(Hinv) 87 | Hinv = torch.linalg.cholesky(Hinv, upper=True) 88 | 89 | alpha = 0.25 90 | P = alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv 91 | del self.dXXT 92 | 93 | for i1 in range(0, self.columns, blocksize): 94 | i2 = min(i1 + blocksize, self.columns) 95 | count = i2 - i1 96 | 97 | W1 = W[:, i1:i2].clone() 98 | Q1 = torch.zeros_like(W1) 99 | Err1 = torch.zeros_like(W1) 100 | Losses1 = torch.zeros_like(W1) 101 | Hinv1 = Hinv[i1:i2, i1:i2] 102 | P1 = P[i1:i2, i1:i2] 103 | 104 | for i in range(count): 105 | w = W1[:, i] 106 | d = Hinv1[i, i] 107 | 108 | if groupsize != -1: 109 | if not static_groups: 110 | if (i1 + i) % groupsize == 0: 111 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)]) 112 | else: 113 | idx = i1 + i 114 | if actorder: 115 | idx = perm[idx] 116 | self.quantizer = groups[idx // groupsize] 117 | 118 | q = 
self.quantizer.quantize(w.unsqueeze(1)).flatten() 119 | Q1[:, i] = q 120 | Losses1[:, i] = (w - q) ** 2 / d ** 2 121 | 122 | err1 = (w - q) / d 123 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0)) 124 | Err1[:, i] = err1 125 | 126 | Q[:, i1:i2] = Q1 127 | Losses[:, i1:i2] = Losses1 / 2 128 | 129 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:]) 130 | 131 | torch.cuda.synchronize() 132 | 133 | if actorder: 134 | Q = Q[:, invperm] 135 | 136 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 137 | if torch.any(torch.isnan(self.layer.weight.data)): 138 | logging.warning('NaN in weights') 139 | import pprint 140 | pprint.pprint(self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point) 141 | raise ValueError('NaN in weights') 142 | 143 | def free(self): 144 | self.H = None 145 | self.Losses = None 146 | self.Trace = None 147 | self.dXXT = None 148 | torch.cuda.empty_cache() 149 | utils.cleanup_memory(verbos=False) 150 | 151 | 152 | class FPInputsCache: 153 | """ 154 | class for saving the full-precision output in each layer. 
155 | """ 156 | def __init__(self, sequential): 157 | self.fp_cache = {} 158 | self.names = sequential[0]+sequential[1]+sequential[2]+sequential[3] 159 | print(self.names) 160 | for name in self.names: 161 | self.fp_cache[name] = [] 162 | self.handles = [] 163 | 164 | def cache_fp_input(self, m, inp, out, name): 165 | inp = inp[0].detach() 166 | if len(inp.shape) == 3: 167 | inp = inp.reshape((-1, inp.shape[-1])) 168 | self.fp_cache[name] += [inp.t()] 169 | 170 | def add_hook(self, full): 171 | for name in self.names: 172 | self.handles.append( 173 | full[name].register_forward_hook( 174 | functools.partial(self.cache_fp_input, name=name) 175 | ) 176 | ) 177 | 178 | def clear_hook(self): 179 | for h in self.handles: 180 | h.remove() 181 | self.handles = [] 182 | torch.cuda.empty_cache() 183 | 184 | def clear_cache(self): 185 | for name in self.names: 186 | self.fp_cache[name] = [] 187 | 188 | 189 | @torch.no_grad() 190 | def gptaq_fwrd(model, dataloader, dev, args): 191 | ''' 192 | From GPTQ repo 193 | TODO: Make this function general to support both OPT and LLaMA models 194 | ''' 195 | print('-----GPTQv2 Quantization-----') 196 | 197 | use_cache = model.config.use_cache 198 | model.config.use_cache = False 199 | layers = model.model.layers 200 | 201 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 202 | model.model.norm = model.model.norm.to(dev) 203 | model.model.rotary_emb = model.model.rotary_emb.to(dev) 204 | 205 | layers[0] = layers[0].to(dev) 206 | 207 | dtype = next(iter(model.parameters())).dtype 208 | inps = torch.zeros( 209 | (args.nsamples, 2048, model.config.hidden_size), dtype=dtype, device=dev 210 | ) 211 | cache = {'i': 0, 'attention_mask': None} 212 | 213 | class Catcher(nn.Module): 214 | def __init__(self, module): 215 | super().__init__() 216 | self.module = module 217 | 218 | def forward(self, inp, **kwargs): 219 | inps[cache['i']] = inp 220 | cache['i'] += 1 221 | cache['attention_mask'] = kwargs['attention_mask'] 222 | 
cache['position_ids'] = kwargs['position_ids'] 223 | raise ValueError 224 | 225 | layers[0] = Catcher(layers[0]) 226 | for batch in dataloader: 227 | try: 228 | model(batch[0].to(dev)) 229 | except ValueError: 230 | pass 231 | layers[0] = layers[0].module 232 | layers[0] = layers[0].cpu() 233 | 234 | model.model.embed_tokens = model.model.embed_tokens.cpu() 235 | model.model.norm = model.model.norm.cpu() 236 | model.model.rotary_emb = model.model.rotary_emb.cpu() 237 | torch.cuda.empty_cache() 238 | 239 | outs = torch.zeros_like(inps) 240 | 241 | attention_mask = cache['attention_mask'] 242 | position_ids = cache['position_ids'] 243 | 244 | quantizers = {} 245 | sequential = [ 246 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'], 247 | ['self_attn.o_proj.module'], 248 | ['mlp.up_proj.module', 'mlp.gate_proj.module'], 249 | ['mlp.down_proj.module'] 250 | ] 251 | 252 | fp_inputs_cache = FPInputsCache(sequential) 253 | fp_inps = inps.clone() 254 | 255 | for i in range(len(layers)): 256 | print(f'\nLayer {i}:', flush=True, end=' ') 257 | layer = layers[i].to(dev) 258 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear]) 259 | bits_config = quant_utils.disable_act_quant(layer) 260 | fp_inputs_cache.add_hook(full) 261 | 262 | for j in range(args.nsamples): 263 | fp_inps[j] = layer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 264 | fp_inputs_cache.clear_hook() 265 | quant_utils.enable_act_quant(layer, bits_config) 266 | 267 | for names in sequential: 268 | subset = {n: full[n] for n in names} 269 | 270 | gptq = {} 271 | for name in subset: 272 | print(f'{name}', end=' ', flush=True) 273 | layer_weight_bits = args.w_bits 274 | layer_weight_sym = not (args.w_asym) 275 | if 'lm_head' in name: 276 | layer_weight_bits = 16 277 | continue 278 | if args.int8_down_proj and 'down_proj' in name: 279 | layer_weight_bits = 8 280 | gptq[name] = GPTAQ(subset[name]) 281 | gptq[name].quantizer = 
quant_utils.WeightQuantizer() 282 | gptq[name].quantizer.configure( 283 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip 284 | ) 285 | gptq[name].fp_inp = fp_inputs_cache.fp_cache[name] 286 | 287 | def add_batch(name): 288 | def tmp(_, inp, out): 289 | gptq[name].add_batch(inp[0].data, out.data) 290 | 291 | return tmp 292 | 293 | first_module_name = list(subset.keys())[0] 294 | handle = subset[first_module_name].register_forward_hook(add_batch(first_module_name)) 295 | 296 | for j in range(args.nsamples): 297 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 298 | handle.remove() 299 | 300 | # copy H and dXXT 301 | for name in subset: 302 | if name != first_module_name: 303 | gptq[name].H = gptq[first_module_name].H 304 | gptq[name].dXXT = gptq[first_module_name].dXXT 305 | 306 | for name in subset: 307 | layer_w_groupsize = args.w_groupsize 308 | gptq[name].fasterquant( 309 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False 310 | ) 311 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 312 | gptq[name].free() 313 | 314 | for j in range(args.nsamples): 315 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 316 | 317 | fp_inputs_cache.clear_cache() 318 | layers[i] = layer.cpu() 319 | del layer 320 | del gptq 321 | torch.cuda.empty_cache() 322 | 323 | inps, outs = outs, inps 324 | 325 | model.config.use_cache = use_cache 326 | utils.cleanup_memory(verbos=True) 327 | print('-----GPTQv2 Quantization Done-----\n') 328 | return quantizers 329 | -------------------------------------------------------------------------------- /spinquant/eval_utils/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot). 9 | # Licensed under Apache License 2.0. 10 | 11 | import torch 12 | import transformers 13 | 14 | from eval_utils import gptq_utils, gptaq_utils, rotation_utils 15 | from utils import data_utils, fuse_norm_utils, hadamard_utils, quant_utils, utils 16 | from utils.convert_to_executorch import ( 17 | sanitize_checkpoint_from_spinquant, 18 | write_model_llama, 19 | ) 20 | 21 | 22 | def add_input_quantization(model, args): 23 | # Add Input Quantization 24 | if args.a_bits < 16 or args.v_bits < 16: 25 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper]) 26 | down_proj_groupsize = -1 27 | if args.a_groupsize > 0: 28 | down_proj_groupsize = utils.llama_down_proj_groupsize( 29 | model, args.a_groupsize 30 | ) 31 | 32 | for name in qlayers: 33 | layer_input_bits = args.a_bits 34 | layer_groupsize = args.a_groupsize 35 | layer_a_sym = not (args.a_asym) 36 | layer_a_clip = args.a_clip_ratio 37 | 38 | num_heads = model.config.num_attention_heads 39 | model_dim = model.config.hidden_size 40 | head_dim = model_dim // num_heads 41 | 42 | if "v_proj" in name and args.v_bits < 16: # Set the v_proj precision 43 | v_groupsize = head_dim 44 | qlayers[name].out_quantizer.configure( 45 | bits=args.v_bits, 46 | groupsize=v_groupsize, 47 | sym=not (args.v_asym), 48 | clip_ratio=args.v_clip_ratio, 49 | ) 50 | 51 | if "o_proj" in name: 52 | layer_groupsize = head_dim 53 | 54 | if "lm_head" in name: # Skip lm_head quantization 55 | layer_input_bits = 16 56 | 57 | if "down_proj" in name: # Set the down_proj precision 58 | if args.int8_down_proj: 59 | layer_input_bits = 8 60 | layer_groupsize = down_proj_groupsize 61 | 62 | qlayers[name].quantizer.configure( 63 | bits=layer_input_bits, 64 | groupsize=layer_groupsize, 65 | sym=layer_a_sym, 66 | 
clip_ratio=layer_a_clip, 67 | ) 68 | 69 | if args.k_bits < 16: 70 | if args.k_pre_rope: 71 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!") 72 | else: 73 | rope_function_name = "apply_rotary_pos_emb" 74 | layers = model.model.layers 75 | k_quant_config = { 76 | "k_bits": args.k_bits, 77 | "k_groupsize": args.k_groupsize, 78 | "k_sym": not (args.k_asym), 79 | "k_clip_ratio": args.k_clip_ratio, 80 | } 81 | for layer in layers: 82 | rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward( 83 | layer.self_attn, 84 | rope_function_name, 85 | config=model.config, 86 | **k_quant_config, 87 | ) 88 | 89 | 90 | def ptq_model(args, model, model_args=None): 91 | transformers.set_seed(args.seed) 92 | model.eval() 93 | 94 | # Rotate the weights 95 | if args.rotate: 96 | fuse_norm_utils.fuse_layer_norms(model) 97 | rotation_utils.rotate_model(model, args) 98 | utils.cleanup_memory(verbos=True) 99 | 100 | quant_utils.add_actquant(model) # Add Activation Wrapper to the model 101 | qlayers = quant_utils.find_qlayers(model) 102 | for name in qlayers: 103 | if "down_proj" in name: 104 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size) 105 | qlayers[name].online_full_had = True 106 | qlayers[name].had_K = had_K 107 | qlayers[name].K = K 108 | qlayers[name].fp32_had = args.fp32_had 109 | else: 110 | quant_utils.add_actquant( 111 | model 112 | ) # Add Activation Wrapper to the model as the rest of the code assumes it is present 113 | 114 | if args.enable_aq_calibration: 115 | add_input_quantization(model, args) 116 | 117 | if args.w_bits < 16: 118 | save_dict = {} 119 | if args.load_qmodel_path: # Load Quantized Rotated Model 120 | assert args.rotate, "Model should be rotated to load a quantized model!" 121 | assert ( 122 | not args.save_qmodel_path 123 | ), "Cannot save a quantized model if it is already loaded!" 
124 | print("Load quantized model from ", args.load_qmodel_path) 125 | save_dict = torch.load(args.load_qmodel_path) 126 | model.load_state_dict(save_dict["model"], strict=False) 127 | 128 | elif not args.w_rtn: # GPTQ Weight Quantization 129 | trainloader = data_utils.get_wikitext2( 130 | nsamples=args.nsamples, 131 | seed=args.seed, 132 | model=model_args.input_model, 133 | seqlen=2048, 134 | eval_mode=False, 135 | ) 136 | if args.export_to_et: 137 | # quantize lm_head and embedding with 8bit per-channel quantization with rtn for executorch 138 | quantizers = gptq_utils.rtn_fwrd( 139 | model, 140 | "cuda", 141 | args, 142 | custom_layers=[model.model.embed_tokens, model.lm_head], 143 | ) 144 | # quantize other layers with gptq 145 | if args.asym_calibrate: 146 | quantizers = gptaq_utils.gptaq_fwrd(model, trainloader, "cuda", args) 147 | save_dict["w_quantizers"] = quantizers 148 | else: 149 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, "cuda", args) 150 | save_dict["w_quantizers"] = quantizers 151 | else: # RTN Weight Quantization 152 | quantizers = gptq_utils.rtn_fwrd(model, "cuda", args) 153 | save_dict["w_quantizers"] = quantizers 154 | 155 | if args.save_qmodel_path: 156 | save_dict["model"] = model.state_dict() 157 | if args.export_to_et: 158 | save_dict = write_model_llama( 159 | model.state_dict(), model.config, num_shards=1 160 | )[0] # Export num_shards == 1 for executorch 161 | save_dict = sanitize_checkpoint_from_spinquant( 162 | save_dict, group_size=args.w_groupsize 163 | ) 164 | torch.save(save_dict, args.save_qmodel_path) 165 | 166 | if not args.enable_aq_calibration: 167 | add_input_quantization(model, args) 168 | 169 | return model 170 | -------------------------------------------------------------------------------- /spinquant/eval_utils/rotation_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot). 9 | # Licensed under Apache License 2.0. 10 | 11 | import functools 12 | import math 13 | 14 | import torch 15 | import tqdm 16 | 17 | from utils import monkeypatch, quant_utils, utils 18 | from utils.hadamard_utils import ( 19 | apply_exact_had_to_linear, 20 | is_pow2, 21 | random_hadamard_matrix, 22 | ) 23 | from utils.utils import HadamardTransform 24 | 25 | 26 | def random_orthogonal_matrix(size, device): 27 | """ 28 | Generate a random orthogonal matrix of the specified size. 29 | First, we generate a random matrix with entries drawn from a standard normal distribution (torch.randn). 30 | Then, we use QR decomposition to obtain an orthogonal matrix. 31 | Finally, we multiply each column of q by the sign of the matching diagonal entry of r, which removes the sign ambiguity of the QR factorization. 32 | 33 | Args: 34 | size (int): The size of the matrix (size x size). device: Device the matrix is moved to; it is first sampled on the CPU in float64. 35 | 36 | Returns: 37 | torch.Tensor: A float64 orthogonal matrix of the specified size, located on `device`. 38 | """ 39 | torch.cuda.empty_cache() 40 | random_matrix = torch.randn(size, size, dtype=torch.float64).to(device) 41 | q, r = torch.linalg.qr(random_matrix) 42 | q *= torch.sign(torch.diag(r)).unsqueeze(0) 43 | return q 44 | 45 | 46 | def get_orthogonal_matrix(size, mode, device="cuda"): 47 | if mode == "random": 48 | return random_orthogonal_matrix(size, device) 49 | elif mode == "hadamard": 50 | return random_hadamard_matrix(size, device) 51 | else: 52 | raise ValueError(f"Unknown mode {mode}") 53 | 54 | 55 | def rotate_embeddings(model, R1: torch.Tensor) -> None: 56 | # Rotate the embeddings.
57 | for W in [model.model.embed_tokens]: 58 | dtype = W.weight.data.dtype 59 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64) 60 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) 61 | 62 | 63 | def rotate_attention_inputs(layer, R1) -> None: 64 | # Rotate the WQ, WK and WV matrices of the self-attention layer. 65 | for W in [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]: 66 | dtype = W.weight.dtype 67 | W_ = W.weight.to(device="cuda", dtype=torch.float64) 68 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) 69 | 70 | 71 | def rotate_attention_output(layer, R1) -> None: 72 | # Rotate output matrix of the self-attention layer. 73 | W = layer.self_attn.o_proj 74 | 75 | dtype = W.weight.data.dtype 76 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64) 77 | W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype) 78 | if W.bias is not None: 79 | b = W.bias.data.to(device="cuda", dtype=torch.float64) 80 | W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype) 81 | 82 | 83 | def rotate_mlp_input(layer, R1): 84 | # Rotate the MLP input weights. 85 | mlp_inputs = [layer.mlp.up_proj, layer.mlp.gate_proj] 86 | for W in mlp_inputs: 87 | dtype = W.weight.dtype 88 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64) 89 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) 90 | 91 | 92 | def rotate_mlp_output(layer, R1): 93 | # Rotate the MLP output weights and bias. 
94 | W = layer.mlp.down_proj 95 | dtype = W.weight.data.dtype 96 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64) 97 | W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype) 98 | apply_exact_had_to_linear( 99 | W, had_dim=-1, output=False 100 | ) # apply exact (inverse) hadamard on the weights of mlp output 101 | if W.bias is not None: 102 | b = W.bias.data.to(device="cuda", dtype=torch.float64) 103 | W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype) 104 | 105 | 106 | def rotate_head(model, R1: torch.Tensor) -> None: 107 | # Rotate the head. 108 | W = model.lm_head 109 | dtype = W.weight.data.dtype 110 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64) 111 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) 112 | 113 | 114 | def rotate_ov_proj(layer, head_num, head_dim, R2=None): 115 | v_proj = layer.self_attn.v_proj 116 | o_proj = layer.self_attn.o_proj 117 | 118 | apply_exact_had_to_linear(v_proj, had_dim=head_dim, output=True, R2=R2) 119 | apply_exact_had_to_linear(o_proj, had_dim=head_dim, output=False, R2=R2) 120 | 121 | 122 | @torch.inference_mode() 123 | def rotate_model(model, args): 124 | R1 = get_orthogonal_matrix(model.config.hidden_size, args.rotate_mode) 125 | if args.optimized_rotation_path is not None: 126 | R_cpk = args.optimized_rotation_path 127 | R1 = torch.load(R_cpk)["R1"].cuda().to(torch.float64) 128 | config = model.config 129 | num_heads = config.num_attention_heads 130 | model_dim = config.hidden_size 131 | head_dim = model_dim // num_heads 132 | 133 | rotate_embeddings(model, R1) 134 | rotate_head(model, R1) 135 | utils.cleanup_memory() 136 | layers = [layer for layer in model.model.layers] 137 | for idx, layer in enumerate(tqdm.tqdm(layers, unit="layer", desc="Rotating")): 138 | if args.optimized_rotation_path is not None: 139 | key = f"model.layers.{idx}.self_attn.R2" 140 | R2 = torch.load(R_cpk)[key].cuda().to(torch.float64) 141 | else: 142 | R2 = 
get_orthogonal_matrix(head_dim, args.rotate_mode) 143 | rotate_attention_inputs(layers[idx], R1) 144 | rotate_attention_output(layers[idx], R1) 145 | rotate_mlp_input(layers[idx], R1) 146 | rotate_mlp_output(layers[idx], R1) 147 | rotate_ov_proj(layers[idx], num_heads, head_dim, R2=R2) 148 | 149 | 150 | class QKRotationWrapper(torch.nn.Module): 151 | def __init__(self, func, config, *args, **kwargs): 152 | super().__init__() 153 | self.config = config 154 | num_heads = config.num_attention_heads 155 | model_dim = config.hidden_size 156 | head_dim = model_dim // num_heads 157 | assert is_pow2( 158 | head_dim 159 | ), f"Only power of 2 head_dim is supported for K-cache Quantization!" 160 | self.func = func 161 | self.k_quantizer = quant_utils.ActQuantizer() 162 | self.k_bits = 16 163 | if kwargs is not None: # NOTE(review): **kwargs is always a dict, so this branch always runs and k_bits/k_groupsize/k_sym/k_clip_ratio must be supplied 164 | assert kwargs["k_groupsize"] in [ 165 | -1, 166 | head_dim, 167 | ], f"Only token-wise/{head_dim}g quantization is supported for K-cache" 168 | self.k_bits = kwargs["k_bits"] 169 | self.k_groupsize = kwargs["k_groupsize"] 170 | self.k_sym = kwargs["k_sym"] 171 | self.k_clip_ratio = kwargs["k_clip_ratio"] 172 | self.k_quantizer.configure( 173 | bits=self.k_bits, 174 | groupsize=-1, # -1 means token-wise quantization here; head-wise grouping is done in forward() by reshaping k to (-1, head_dim) 175 | sym=self.k_sym, 176 | clip_ratio=self.k_clip_ratio, 177 | ) 178 | 179 | def forward(self, *args, **kwargs): 180 | q, k = self.func(*args, **kwargs) 181 | dtype = q.dtype 182 | q = (HadamardTransform.apply(q.float()) / math.sqrt(q.shape[-1])).to(dtype) 183 | k = (HadamardTransform.apply(k.float()) / math.sqrt(k.shape[-1])).to(dtype) 184 | (bsz, num_heads, seq_len, head_dim) = k.shape 185 | 186 | if self.k_groupsize == -1: # token-wise quantization 187 | token_wise_k = k.transpose(1, 2).reshape(-1, num_heads * head_dim) 188 | self.k_quantizer.find_params(token_wise_k) 189 | k = ( 190 | self.k_quantizer(token_wise_k) 191 | .reshape((bsz, seq_len, num_heads, head_dim)) 192 | .transpose(1, 2)
193 | .to(q) 194 | ) 195 | else: # head-wise quantization 196 | per_head_k = k.view(-1, head_dim) 197 | self.k_quantizer.find_params(per_head_k) 198 | k = ( 199 | self.k_quantizer(per_head_k) 200 | .reshape((bsz, num_heads, seq_len, head_dim)) 201 | .to(q) 202 | ) 203 | 204 | self.k_quantizer.free() 205 | 206 | return q, k 207 | 208 | 209 | def add_qk_rotation_wrapper_after_function_call_in_forward( 210 | module, 211 | function_name, 212 | *args, 213 | **kwargs, 214 | ): 215 | """ 216 | This function adds a rotation wrapper after the output of a function call in forward. 217 | Only calls directly in the forward function are affected. calls by other functions called in forward are not affected. 218 | """ 219 | 220 | attr_name = f"{function_name}_qk_rotation_wrapper" 221 | assert not hasattr(module, attr_name) 222 | wrapper = monkeypatch.add_wrapper_after_function_call_in_method( 223 | module, 224 | "forward", 225 | function_name, 226 | functools.partial(QKRotationWrapper, *args, **kwargs), 227 | ) 228 | setattr(module, attr_name, wrapper) 229 | -------------------------------------------------------------------------------- /spinquant/optimize_rotation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import datetime 9 | import os 10 | from logging import Logger 11 | 12 | import datasets 13 | import torch 14 | import torch.distributed as dist 15 | from torch import nn 16 | from transformers import LlamaTokenizerFast, Trainer, default_data_collator 17 | import transformers 18 | from train_utils.fsdp_trainer import FSDPTrainer 19 | from train_utils.main import prepare_model 20 | from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant 21 | from train_utils.optimizer import SGDG 22 | from utils.data_utils import CustomJsonDataset 23 | from utils.hadamard_utils import random_hadamard_matrix 24 | from utils.process_args import process_args_ptq 25 | from utils.utils import get_local_rank, get_logger, pt_fsdp_state_dict 26 | 27 | log: Logger = get_logger("spinquant") 28 | 29 | 30 | class RotateModule(nn.Module): 31 | def __init__(self, R_init): 32 | super(RotateModule, self).__init__() 33 | self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda"))) 34 | 35 | def forward(self, x, transpose=False): 36 | if transpose: 37 | return x @ self.weight 38 | else: 39 | return self.weight @ x 40 | 41 | 42 | def train() -> None: 43 | dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8)) 44 | model_args, training_args, ptq_args = process_args_ptq() 45 | local_rank = get_local_rank() 46 | 47 | log.info("the rank is {}".format(local_rank)) 48 | torch.distributed.barrier() 49 | 50 | config = transformers.AutoConfig.from_pretrained( 51 | model_args.input_model, token=model_args.access_token 52 | ) 53 | 54 | # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings, so untie them and clone lm_head from embed_tokens 55 | process_word_embeddings = False 56 | if config.tie_word_embeddings: 57 | config.tie_word_embeddings = False 58 | process_word_embeddings = True 59 | dtype = torch.bfloat16 if training_args.bf16 else torch.float16 60 | model = LlamaForCausalLMQuant.from_pretrained( 61 | 
pretrained_model_name_or_path=model_args.input_model, 62 | config=config, 63 | torch_dtype=dtype, 64 | token=model_args.access_token, 65 | ) 66 | if process_word_embeddings: 67 | model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone() 68 | 69 | model = prepare_model(ptq_args, model) 70 | for param in model.parameters(): 71 | param.requires_grad = False 72 | R1 = random_hadamard_matrix(model.config.hidden_size, "cuda") 73 | model.R1 = RotateModule(R1) 74 | for i in range(model.config.num_hidden_layers): 75 | # Each head dim = 128 for Llama model 76 | R2 = random_hadamard_matrix( 77 | model.config.hidden_size // model.config.num_attention_heads, "cuda" 78 | ) 79 | model.model.layers[i].self_attn.R2 = RotateModule(R2) 80 | if local_rank == 0: 81 | log.info("Model init completed for training {}".format(model)) 82 | log.info("Start to load tokenizer...") 83 | tokenizer = LlamaTokenizerFast.from_pretrained( 84 | pretrained_model_name_or_path=model_args.input_model, 85 | cache_dir=training_args.cache_dir, 86 | model_max_length=training_args.model_max_length, 87 | padding_side="right", 88 | use_fast=True, 89 | add_eos_token=False, 90 | add_bos_token=False, 91 | token=model_args.access_token, 92 | ) 93 | log.info("Complete tokenizer loading...") 94 | model.config.use_cache = False 95 | calibration_datasets = datasets.load_dataset( 96 | "Salesforce/wikitext", "wikitext-2-raw-v1" 97 | ) 98 | train_data = CustomJsonDataset( 99 | calibration_datasets["train"], 100 | tokenizer, 101 | block_size=min(training_args.model_max_length, 2048), 102 | ) 103 | 104 | trainable_parameters = [model.R1.weight] + [ 105 | model.model.layers[i].self_attn.R2.weight 106 | for i in range(model.config.num_hidden_layers) 107 | ] 108 | model.seqlen = training_args.model_max_length 109 | optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True) 110 | MyTrainer = Trainer 111 | # Use FSDP for 70B rotation training 112 | if training_args.fsdp != "" and 
training_args.fsdp != []: 113 | MyTrainer = FSDPTrainer 114 | 115 | trainer = MyTrainer( 116 | model=model, 117 | tokenizer=tokenizer, 118 | args=training_args, 119 | train_dataset=train_data, 120 | eval_dataset=None, 121 | data_collator=default_data_collator, 122 | optimizers=(optimizer, None), 123 | ) 124 | torch.distributed.barrier() 125 | 126 | trainer.train() 127 | if training_args.fsdp != "" and training_args.fsdp != []: 128 | cpu_state = pt_fsdp_state_dict(trainer.model) 129 | else: 130 | cpu_state = trainer.model.state_dict() 131 | 132 | R_dict = { 133 | key.replace(".weight", ""): value 134 | for key, value in cpu_state.items() 135 | if "R1.weight" in key or "self_attn.R2" in key 136 | } 137 | if local_rank == 0: 138 | os.makedirs(model_args.output_rotation_path, exist_ok=True) 139 | path = os.path.join(model_args.output_rotation_path, "R.bin") 140 | torch.save( 141 | R_dict, 142 | path, 143 | ) 144 | dist.barrier() 145 | 146 | 147 | if __name__ == "__main__": 148 | train() 149 | -------------------------------------------------------------------------------- /spinquant/ptq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
import datetime
from logging import Logger

import torch
import torch.distributed as dist
from transformers import LlamaTokenizerFast
import transformers
from eval_utils.main import ptq_model
from eval_utils.modeling_llama import LlamaForCausalLM
from utils import data_utils, eval_utils, utils
from utils.process_args import process_args_ptq

log: Logger = utils.get_logger("spinquant")


def _task_accuracy(result):
    """Return a task's normalized accuracy if reported, else its plain accuracy.

    Checked explicitly (not via ``dict.get`` with a default) so that a task that
    only reports ``acc_norm,none`` does not fail on the missing ``acc,none`` key.
    """
    if "acc_norm,none" in result:
        return result["acc_norm,none"]
    return result["acc,none"]


def train() -> None:
    """Quantize a Llama model with ``ptq_model`` and evaluate it.

    Always logs WikiText-2 perplexity (eval split, 2048-token sequences). When
    ``ptq_args.lm_eval`` is set, it also runs lm-eval on ``ptq_args.tasks``. It
    then prints each task's accuracy (normalized accuracy when available)
    rounded to 4 decimals, plus their mean as ``acc_avg``.
    """
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    model_args, training_args, ptq_args = process_args_ptq()
    local_rank = utils.get_local_rank()

    log.info("the rank is {}".format(local_rank))
    torch.distributed.barrier()

    config = transformers.AutoConfig.from_pretrained(
        model_args.input_model, token=model_args.access_token
    )
    # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
    # so untie them here and clone lm_head from embed_tokens after loading.
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True
    dtype = torch.bfloat16 if training_args.bf16 else torch.float16
    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        token=model_args.access_token,
    )
    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model = ptq_model(ptq_args, model, model_args)
    model.seqlen = training_args.model_max_length
    if local_rank == 0:
        log.info("Start to load tokenizer...")
    tokenizer = LlamaTokenizerFast.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        add_eos_token=False,
        add_bos_token=False,
        token=model_args.access_token,
    )
    log.info("Complete tokenizer loading...")
    model.config.use_cache = False

    testloader = data_utils.get_wikitext2(
        seed=ptq_args.seed,
        seqlen=2048,
        tokenizer=tokenizer,
        eval_mode=True,
    )

    dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, ptq_args)
    log.info("wiki2 ppl is: {}".format(dataset_ppl))

    if not ptq_args.lm_eval:
        return

    # lm_eval is an optional dependency, so it is imported only when requested.
    # Only the names actually used are imported: the unused
    # ``lm_eval.api.registry.ALL_TASKS`` import fails on newer lm-eval releases.
    import lm_eval
    from lm_eval.models.huggingface import HFLM

    if ptq_args.distribute:
        utils.distribute_model(model)
    else:
        model.to(utils.DEV)

    # Pass the access token, as the first tokenizer load does, so gated models
    # (e.g. meta-llama/*) also load here.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.input_model, use_fast=False, token=model_args.access_token
    )
    hflm = HFLM(
        pretrained=model, tokenizer=tokenizer, batch_size=ptq_args.lm_eval_batch_size
    )

    # Task names are passed through verbatim: pattern-matching them against the
    # registry would expand e.g. lambada into two sub-tasks.
    task_names = ptq_args.tasks
    results = lm_eval.simple_evaluate(
        hflm, tasks=task_names, batch_size=ptq_args.lm_eval_batch_size
    )["results"]

    metric_vals = {
        task: round(_task_accuracy(result), 4) for task, result in results.items()
    }
    # Guard the mean so an empty result set does not divide by zero.
    if metric_vals:
        metric_vals["acc_avg"] = round(sum(metric_vals.values()) / len(metric_vals), 4)
    print(metric_vals)

    dist.barrier()


if __name__ == "__main__":
    train()

# --------------------------------------------------------------------------------
# /spinquant/requirement.txt:
# --------------------------------------------------------------------------------
# transformers==4.44.2
# accelerate==0.31.0
# datasets==2.20.0
# sentencepiece
# tensorboardX

# --------------------------------------------------------------------------------
# /spinquant/scripts/10_optimize_rotation.sh:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Usage: 10_optimize_rotation.sh <input_model> <w_bits> <a_bits> <kv_bits>
# nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
# nproc_per_node indicates the number of GPUs per node to employ.
if [ "$#" -lt 4 ]; then
  echo "Usage: $0 <input_model> <w_bits> <a_bits> <kv_bits>" >&2
  exit 1
fi
torchrun --nnodes=1 --nproc_per_node=8 optimize_rotation.py \
  --input_model "$1" \
  --output_rotation_path "your_path" \
  --output_dir "your_output_path/" \
  --logging_dir "your_log_path/" \
  --model_max_length 2048 \
  --fp16 False \
  --bf16 True \
  --log_on_each_node False \
  --per_device_train_batch_size 1 \
  --logging_steps 1 \
  --learning_rate 1.5 \
  --weight_decay 0. \
  --lr_scheduler_type "cosine" \
  --gradient_checkpointing True \
  --save_safetensors False \
  --max_steps 100 \
  --w_bits "$2" \
  --a_bits "$3" \
  --k_bits "$4" \
  --v_bits "$4" \
  --w_clip \
  --a_asym \
  --k_asym \
  --v_asym \
  --k_groupsize 128 \
  --v_groupsize 128

# --------------------------------------------------------------------------------
# /spinquant/scripts/11_optimize_rotation_fsdp.sh:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Usage: 11_optimize_rotation_fsdp.sh <input_model> <w_bits> <a_bits> <kv_bits>
# nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
# nproc_per_node indicates the number of GPUs per node to employ.
if [ "$#" -lt 4 ]; then
  echo "Usage: $0 <input_model> <w_bits> <a_bits> <kv_bits>" >&2
  exit 1
fi
torchrun --nnodes=1 --nproc_per_node=8 optimize_rotation.py \
  --input_model "$1" \
  --output_rotation_path "your_path" \
  --output_dir "your_output_path/" \
  --logging_dir "your_log_path/" \
  --model_max_length 2048 \
  --fp16 False \
  --bf16 True \
  --log_on_each_node False \
  --per_device_train_batch_size 1 \
  --logging_steps 1 \
  --learning_rate 1.5 \
  --weight_decay 0. \
  --lr_scheduler_type "cosine" \
  --gradient_checkpointing True \
  --max_steps 100 \
  --w_bits "$2" \
  --a_bits "$3" \
  --k_bits "$4" \
  --v_bits "$4" \
  --w_clip \
  --a_asym \
  --k_asym \
  --v_asym \
  --k_groupsize 128 \
  --v_groupsize 128 \
  --fsdp "full_shard auto_wrap" \
  --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'

# --------------------------------------------------------------------------------
# /spinquant/scripts/2_eval_ptq.sh:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
# nproc_per_node indicates the number of GPUs per node to employ.


torchrun --nnodes=1 --nproc_per_node=1 ptq.py \
  --input_model "meta-llama/Meta-Llama-3-8B" \
  --do_train False \
  --do_eval True \
  --per_device_eval_batch_size 16 \
  --model_max_length 2048 \
  --fp16 False \
  --bf16 True \
  --save_safetensors False \
  --w_bits 4 \
  --a_bits 4 \
  --k_bits 16 \
  --v_bits 16 \
  --w_clip \
  --a_asym \
  --k_asym \
  --v_asym \
  --k_groupsize 128 \
  --v_groupsize 128 \
  --rotate \
  --optimized_rotation_path "ckpts/8B_W16A4KV16_lr_1.5_seed_0/R.bin" \
  --asym_calibrate \
  --enable_ap_calibration

# --------------------------------------------------------------------------------
# /spinquant/scripts/31_optimize_rotation_executorch.sh:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Usage: 31_optimize_rotation_executorch.sh <input_model>
# nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
# nproc_per_node indicates the number of GPUs per node to employ.
if [ "$#" -lt 1 ]; then
  echo "Usage: $0 <input_model>" >&2
  exit 1
fi
torchrun --nnodes=1 --nproc_per_node=8 optimize_rotation.py \
  --input_model "$1" \
  --output_rotation_path "your_path" \
  --output_dir "your_output_path/" \
  --logging_dir "your_log_path/" \
  --model_max_length 2048 \
  --fp16 False \
  --bf16 True \
  --log_on_each_node False \
  --per_device_train_batch_size 1 \
  --logging_steps 1 \
  --learning_rate 1.5 \
  --weight_decay 0. \
  --lr_scheduler_type "cosine" \
  --gradient_checkpointing True \
  --save_safetensors False \
  --max_steps 100 \
  --w_bits 16 \
  --a_bits 8 \
  --w_clip \
  --a_asym \
  --w_groupsize 32

# --------------------------------------------------------------------------------
# /spinquant/scripts/32_eval_ptq_executorch.sh:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Usage: 32_eval_ptq_executorch.sh <input_model>
# nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
# nproc_per_node indicates the number of GPUs per node to employ.
if [ "$#" -lt 1 ]; then
  echo "Usage: $0 <input_model>" >&2
  exit 1
fi
torchrun --nnodes=1 --nproc_per_node=1 ptq.py \
  --input_model "$1" \
  --do_train False \
  --do_eval True \
  --per_device_eval_batch_size 4 \
  --model_max_length 2048 \
  --fp16 False \
  --bf16 True \
  --save_safetensors False \
  --w_bits 4 \
  --a_bits 8 \
  --w_clip \
  --w_groupsize 32 \
  --a_asym \
  --rotate \
  --optimized_rotation_path "your_path/R.bin" \
  --save_qmodel_path "./your_output_model_path/consolidated.00.pth" \
  --export_to_et

# --------------------------------------------------------------------------------
# /spinquant/train_utils/apply_r3_r4.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
10 | 11 | import math 12 | 13 | import torch 14 | import tqdm 15 | 16 | from utils import quant_utils, utils 17 | from utils.hadamard_utils import ( 18 | apply_exact_had_to_linear, 19 | is_pow2, 20 | ) 21 | from utils.utils import HadamardTransform 22 | 23 | 24 | def R4_rotate_down_proj_weights(layer): 25 | # Rotate the MLP output weights and bias. 26 | W = layer.mlp.down_proj 27 | apply_exact_had_to_linear( 28 | W, had_dim=-1, output=False 29 | ) # apply exact (inverse) hadamard on the weights of mlp output 30 | 31 | 32 | @torch.inference_mode() 33 | def rotate_model(model, args): 34 | config = model.config 35 | num_heads = config.num_attention_heads 36 | model_dim = config.hidden_size 37 | head_dim = model_dim // num_heads 38 | 39 | utils.cleanup_memory() 40 | layers = [layer for layer in model.model.layers] 41 | for idx, layer in enumerate( 42 | tqdm.tqdm(layers, unit="layer", desc="Applying R4 rotation to W_down") 43 | ): 44 | R4_rotate_down_proj_weights(layers[idx]) 45 | 46 | 47 | class QKRotationWrapper(torch.nn.Module): 48 | def __init__(self, func, config, *args, **kwargs): 49 | super().__init__() 50 | self.config = config 51 | num_heads = config.num_attention_heads 52 | model_dim = config.hidden_size 53 | head_dim = model_dim // num_heads 54 | assert is_pow2( 55 | head_dim 56 | ), f"Only power of 2 head_dim is supported for K-cache Quantization!" 
57 | self.func = func 58 | self.k_quantizer = quant_utils.ActQuantizer() 59 | self.k_bits = 16 60 | if kwargs is not None: 61 | assert kwargs["k_groupsize"] in [ 62 | -1, 63 | head_dim, 64 | ], f"Only token-wise/{head_dim}g quantization is supported for K-cache" 65 | self.k_bits = kwargs["k_bits"] 66 | self.k_groupsize = kwargs["k_groupsize"] 67 | self.k_sym = kwargs["k_sym"] 68 | self.k_clip_ratio = kwargs["k_clip_ratio"] 69 | self.k_quantizer.configure( 70 | bits=self.k_bits, 71 | groupsize=-1, # we put -1 to be toke-wise quantization and handle head-wise quantization by ourself 72 | sym=self.k_sym, 73 | clip_ratio=self.k_clip_ratio, 74 | ) 75 | 76 | def forward(self, *args, **kwargs): 77 | q, k = self.func(*args, **kwargs) 78 | dtype = q.dtype 79 | q = (HadamardTransform.apply(q.float()) / math.sqrt(q.shape[-1])).to(dtype) 80 | k = (HadamardTransform.apply(k.float()) / math.sqrt(k.shape[-1])).to(dtype) 81 | (bsz, num_heads, seq_len, head_dim) = k.shape 82 | 83 | if self.k_groupsize == -1: # token-wise quantization 84 | token_wise_k = k.transpose(1, 2).reshape(-1, num_heads * head_dim) 85 | self.k_quantizer.find_params(token_wise_k) 86 | k = ( 87 | self.k_quantizer(token_wise_k) 88 | .reshape((bsz, seq_len, num_heads, head_dim)) 89 | .transpose(1, 2) 90 | .to(q) 91 | ) 92 | else: # head-wise quantization 93 | per_head_k = k.view(-1, head_dim) 94 | self.k_quantizer.find_params(per_head_k) 95 | k = ( 96 | self.k_quantizer(per_head_k) 97 | .reshape((bsz, num_heads, seq_len, head_dim)) 98 | .to(q) 99 | ) 100 | 101 | self.k_quantizer.free() 102 | 103 | return q, k 104 | 105 | 106 | def add_qk_rotation_wrapper_after_function_call_in_forward( 107 | module, 108 | function_name, 109 | *args, 110 | **kwargs, 111 | ): 112 | """ 113 | This function adds a rotation wrapper after the output of a function call in forward. 114 | Only calls directly in the forward function are affected. calls by other functions called in forward are not affected. 
115 | """ 116 | import functools 117 | 118 | from utils import monkeypatch 119 | 120 | attr_name = f"{function_name}_qk_rotation_wrapper" 121 | assert not hasattr(module, attr_name) 122 | wrapper = monkeypatch.add_wrapper_after_function_call_in_method( 123 | module, 124 | "forward", 125 | function_name, 126 | functools.partial(QKRotationWrapper, *args, **kwargs), 127 | ) 128 | setattr(module, attr_name, wrapper) 129 | -------------------------------------------------------------------------------- /spinquant/train_utils/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot). 9 | # Licensed under Apache License 2.0. 10 | 11 | import transformers 12 | 13 | from train_utils import apply_r3_r4, rtn_utils 14 | from utils import fuse_norm_utils, hadamard_utils, quant_utils, utils 15 | 16 | 17 | def prepare_model(args, model): 18 | transformers.set_seed(args.seed) 19 | model.eval() 20 | 21 | # Rotate the weights 22 | fuse_norm_utils.fuse_layer_norms(model) 23 | apply_r3_r4.rotate_model(model, args) 24 | utils.cleanup_memory(verbos=True) 25 | 26 | quant_utils.add_actquant(model) # Add Activation Wrapper to the model 27 | qlayers = quant_utils.find_qlayers(model) 28 | for name in qlayers: 29 | if "down_proj" in name: 30 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size) 31 | qlayers[name].online_full_had = True 32 | qlayers[name].had_K = had_K 33 | qlayers[name].K = K 34 | qlayers[name].fp32_had = args.fp32_had 35 | 36 | if args.w_bits < 16: 37 | quantizers = rtn_utils.rtn_fwrd(model, "cuda", args) 38 | 39 | # Add Input Quantization 40 | if args.a_bits < 16 or args.v_bits < 16: 41 | qlayers = 
quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper]) 42 | down_proj_groupsize = -1 43 | if args.a_groupsize > 0: 44 | down_proj_groupsize = utils.llama_down_proj_groupsize( 45 | model, args.a_groupsize 46 | ) 47 | 48 | for name in qlayers: 49 | layer_input_bits = args.a_bits 50 | layer_groupsize = args.a_groupsize 51 | layer_a_sym = not (args.a_asym) 52 | layer_a_clip = args.a_clip_ratio 53 | 54 | num_heads = model.config.num_attention_heads 55 | model_dim = model.config.hidden_size 56 | head_dim = model_dim // num_heads 57 | 58 | if "v_proj" in name and args.v_bits < 16: # Set the v_proj precision 59 | v_groupsize = head_dim 60 | qlayers[name].out_quantizer.configure( 61 | bits=args.v_bits, 62 | groupsize=v_groupsize, 63 | sym=not (args.v_asym), 64 | clip_ratio=args.v_clip_ratio, 65 | ) 66 | 67 | if "o_proj" in name: 68 | layer_groupsize = head_dim 69 | 70 | if "lm_head" in name: # Skip lm_head quantization 71 | layer_input_bits = 16 72 | 73 | if "down_proj" in name: # Set the down_proj precision 74 | if args.int8_down_proj: 75 | layer_input_bits = 8 76 | layer_groupsize = down_proj_groupsize 77 | 78 | qlayers[name].quantizer.configure( 79 | bits=layer_input_bits, 80 | groupsize=layer_groupsize, 81 | sym=layer_a_sym, 82 | clip_ratio=layer_a_clip, 83 | ) 84 | 85 | if args.k_bits < 16: 86 | if args.k_pre_rope: 87 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!") 88 | else: 89 | rope_function_name = "apply_rotary_pos_emb" 90 | layers = model.model.layers 91 | k_quant_config = { 92 | "k_bits": args.k_bits, 93 | "k_groupsize": args.k_groupsize, 94 | "k_sym": not (args.k_asym), 95 | "k_clip_ratio": args.k_clip_ratio, 96 | } 97 | for layer in layers: 98 | apply_r3_r4.add_qk_rotation_wrapper_after_function_call_in_forward( 99 | layer.self_attn, 100 | rope_function_name, 101 | config=model.config, 102 | **k_quant_config, 103 | ) 104 | 105 | return model 106 | 
# --------------------------------------------------------------------------------
# /spinquant/train_utils/optimizer.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is originally from: https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py

import random

import torch
from torch.optim.optimizer import Optimizer


def unit(v, dim: int = 1, eps: float = 1e-8):
    """Return ``(v / (||v|| + eps), ||v||)`` with the L2 norm taken along ``dim``."""
    vnorm = norm(v, dim)
    return v / vnorm.add(eps), vnorm


def norm(v, dim: int = 1):
    """Return the L2 norm of the 2-D tensor ``v`` along ``dim`` (keepdim=True)."""
    assert len(v.size()) == 2
    return v.norm(p=2, dim=dim, keepdim=True)


def matrix_norm_one(W):
    """Return the matrix 1-norm of ``W``: the maximum absolute column sum."""
    return torch.abs(W).sum(dim=0).max()


def Cayley_loop(X, W, tan_vec, t):
    """Approximate the Cayley transform with 5 fixed-point iterations; return the transpose.

    Starts from ``Y = X + t * tan_vec`` and iterates
    ``Y = X + t * W @ ((X + Y) / 2)``.
    """
    Y = X + t * tan_vec
    for _ in range(5):
        Y = X + t * torch.matmul(W, 0.5 * (X + Y))

    return Y.t()


def qr_retraction(tan_vec):  # tan_vec, p-by-n, p <= n
    """Re-orthonormalize the rows of ``tan_vec`` via QR.

    Signs are fixed so that R has a non-negative diagonal. The input tensor is
    left untouched.
    """
    q, r = torch.linalg.qr(tan_vec.t())
    ph = torch.diag(r, 0).sign()
    q = q * ph.expand_as(q)
    return q.t()


episilon = 1e-8


class SGDG(Optimizer):
    r"""This optimizer updates variables with two different routines
    based on the boolean variable 'stiefel'.

    If stiefel is True, the variables will be updated by SGD-G proposed
    as decorrelated weight matrix.

    If stiefel is False, the variables will be updated by SGD.
    This routine was taken from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py.

    A parameter whose flattened view has more rows than columns always uses the
    plain SGD routine, even when stiefel is True.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups

        -- common parameters
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        stiefel (bool, optional): whether to use SGD-G (default: False)

        -- parameters in case stiefel is False
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

        -- parameters in case stiefel is True
        omega (float, optional): orthogonality regularization factor (default: 0)
        grad_clip (float, optional): threshold for gradient norm clipping (default: None)

    Note: ``omega`` and ``grad_clip`` are stored in the parameter groups but are
    not read by :meth:`step`.
    """

    def __init__(
        self,
        params,
        lr,
        momentum: int = 0,
        dampening: int = 0,
        weight_decay: int = 0,
        nesterov: bool = False,
        stiefel: bool = False,
        omega: int = 0,
        grad_clip=None,
    ) -> None:
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            stiefel=stiefel,
            omega=omega,  # previously hard-coded to 0, ignoring the argument
            grad_clip=grad_clip,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDG, self).__init__(params, defaults)

    def __setstate__(self, state) -> None:
        super(SGDG, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    def _stiefel_step(self, p, unity, momentum, lr):
        """SGD-G update: a Cayley step on the Stiefel manifold for one parameter."""
        # Occasionally (about 1 in 101 steps) re-orthonormalize to limit drift.
        rand_num = random.randint(1, 101)
        if rand_num == 1:
            unity = qr_retraction(unity)

        g = p.grad.data.view(p.size()[0], -1)

        param_state = self.state[p]
        if "momentum_buffer" not in param_state:
            param_state["momentum_buffer"] = torch.zeros(
                g.t().size(), dtype=g.dtype, device=p.device
            )
        buf = param_state["momentum_buffer"]

        V = momentum * buf - g.t()
        MX = torch.mm(V, unity)
        XMX = torch.mm(unity, MX)
        XXMX = torch.mm(unity.t(), XMX)
        W_hat = MX - 0.5 * XXMX
        W = W_hat - W_hat.t()  # skew-symmetric generator
        t = 0.5 * 2 / (matrix_norm_one(W) + episilon)
        alpha = min(t, lr)

        p_new = Cayley_loop(unity.t(), W, V, alpha)
        V_new = torch.mm(W, unity.t())  # n-by-p
        p.data.copy_(p_new.view(p.size()))
        # Persist the momentum in the state buffer. Copying into the rebound
        # local ``V`` would discard it.
        buf.copy_(V_new)

    def _sgd_step(self, p, lr, momentum, dampening, weight_decay, nesterov):
        """Plain (optionally Nesterov) SGD update for one parameter, as in torch.optim.SGD."""
        d_p = p.grad.data
        if weight_decay != 0:
            d_p = d_p.add(p.data, alpha=weight_decay)
        if momentum != 0:
            param_state = self.state[p]
            if "momentum_buffer" not in param_state:
                buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
            else:
                buf = param_state["momentum_buffer"]
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        p.data.add_(d_p, alpha=-lr)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # Read every hyperparameter up front so both update routines see
            # them. They were previously bound only in the Stiefel branch.
            momentum = group["momentum"]
            stiefel = group["stiefel"]
            weight_decay = group["weight_decay"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                unity, _ = unit(p.data.view(p.size()[0], -1))
                if stiefel and unity.size()[0] <= unity.size()[1]:
                    self._stiefel_step(p, unity, momentum, lr)
                else:
                    self._sgd_step(p, lr, momentum, dampening, weight_decay, nesterov)

        return loss

# --------------------------------------------------------------------------------
# /spinquant/train_utils/quant_linear.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch._tensor import Tensor


class QuantizeLinear(nn.Linear):
    """``nn.Linear`` whose weight can be rotated and fake-quantized on every forward."""

    def forward(
        self,
        input: Tensor,
        R1=None,
        R2=None,
        transpose=False,
    ) -> Tensor:
        """Rotate the weight (R1, then optionally R2), quantize it, and apply the layer.

        Rotations are computed in float64 and cast back to the weight dtype.
        ``R2`` is only applied when ``R1`` is also given; it acts block-wise on
        chunks of size ``R2.shape[0]``. If a ``quantizer`` attribute is present,
        its parameters are re-estimated from the current weight on every call.
        """
        if R1 is not None:
            dtype = self.weight.dtype
            if not transpose:
                weight = (self.weight.to(torch.float64) @ R1.to(torch.float64)).to(
                    dtype
                )
            else:
                weight = (R1.T.to(torch.float64) @ self.weight.to(torch.float64)).to(
                    dtype
                )
            if R2 is not None:
                # Each head dim = 128 for Llama model
                had_dim = R2.shape[0]
                dtype = weight.dtype
                if transpose:
                    W_ = weight
                    init_shape = W_.shape
                    temp = W_.reshape(-1, init_shape[-1] // had_dim, had_dim)
                    temp = temp.to(torch.float64) @ R2.to(torch.float64)
                    weight = temp.reshape(init_shape)
                else:
                    W_ = weight.t()
                    transposed_shape = W_.shape
                    temp = W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim)
                    temp = temp.to(torch.float64) @ R2.to(torch.float64)
                    weight = temp.reshape(transposed_shape).t()
                weight = weight.to(dtype)
        else:
            weight = self.weight
        if hasattr(self, "quantizer"):
            dtype = weight.dtype
            self.quantizer.find_params(weight.data)
            weight = self.quantizer.quantize(weight).to(dtype)

        return nn.functional.linear(input, weight, self.bias)

# --------------------------------------------------------------------------------
# /spinquant/train_utils/rtn_utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.

import torch
import tqdm

from train_utils.quant_linear import QuantizeLinear
from utils import quant_utils, utils


@torch.no_grad()
def rtn_fwrd(model, dev, args):
    """
    From GPTQ repo

    Attach a configured ``WeightQuantizer`` to every ``nn.Linear`` /
    ``QuantizeLinear`` in each decoder layer.

    - Layers named ``lm_head`` are skipped.
    - ``down_proj`` uses 8 bits when ``args.int8_down_proj`` is set.
    - Each layer is moved to ``dev`` while processed and then back to CPU.

    Returns a dict mapping ``"model.layers.<i>.<name>"`` to its quantizer.
    """
    # assert args.w_groupsize == -1, "Groupsize not supported in RTN!"
    layers = model.model.layers
    torch.cuda.empty_cache()

    quantizers = {}

    for i in tqdm.tqdm(range(len(layers)), desc="Inserting weight quantizer"):
        layer = layers[i].to(dev)

        subset = quant_utils.find_qlayers(
            layer, layers=[torch.nn.Linear, QuantizeLinear]
        )

        for name in subset:
            if "lm_head" in name:  # lm_head stays in full precision
                continue
            layer_weight_bits = args.w_bits
            if args.int8_down_proj and "down_proj" in name:
                layer_weight_bits = 8

            quantizer = quant_utils.WeightQuantizer()
            quantizer.configure(
                layer_weight_bits,
                perchannel=True,
                sym=not (args.w_asym),
                mse=args.w_clip,
                weight_groupsize=args.w_groupsize,
            )
            subset[name].quantizer = quantizer

            quantizers["model.layers.%d.%s" % (i, name)] = quantizer.cpu()
        layers[i] = layer.cpu()
        torch.cuda.empty_cache()
        del layer

    utils.cleanup_memory(verbos=True)
    return quantizers

# --------------------------------------------------------------------------------
# /spinquant/utils/convert_to_executorch.py:
# --------------------------------------------------------------------------------
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adopt from https://fburl.com/code/b4jqkgir

import math

from typing import Any, Tuple

import torch
from torch._tensor import Tensor


def compute_intermediate_size(n) -> int:
    """Return ``ceil(8 * n / 3)`` rounded up to the next multiple of 256."""
    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256


def shard_tensor(tensor: Tensor, dim: int, num_shards: int) -> Tuple[Tensor, ...]:
    """Split ``tensor`` into exactly ``num_shards`` equal views along ``dim``.

    Args:
        tensor: tensor to split.
        dim: dimension to split along.
        num_shards: number of pieces; must be positive.

    Returns:
        A tuple of ``num_shards`` views (from ``torch.split``).

    Raises:
        ValueError: if ``num_shards`` is not positive or ``tensor.shape[dim]``
            is not divisible by ``num_shards``.
    """
    total_size = tensor.shape[dim]
    # A non-divisible size would make torch.split return an extra remainder
    # chunk, which callers (indexing only [0, num_shards)) would silently drop.
    if num_shards <= 0 or total_size % num_shards != 0:
        raise ValueError(
            f"Cannot split size {total_size} along dim {dim} "
            f"into {num_shards} equal shards"
        )
    return torch.split(tensor, total_size // num_shards, dim)


def write_model_llama(
    hf_state_dict,
    config,
    num_shards: int,
):
    """Convert a quantized HF Llama state dict into ``num_shards`` ExecuTorch dicts.

    Per decoder layer, keys are inserted in this order: the two norm weights
    (cloned into every shard), then every linear's ``int_weight`` (cast to
    int8) and finally every linear's ``scale``.  q/v/k, gate and up
    projections are split along dim 0; o and down along dim 1.  q and k
    tensors (weights and scales) go through ``un_permute`` first.  After the
    layers come ``norm.weight`` (replicated), the int8 embedding / LM-head
    weights and their scales cast to float32.

    Args:
        hf_state_dict: mapping from HF parameter names, e.g.
            ``model.layers.{i}.self_attn.q_proj.module.int_weight``, to tensors.
        config: HF config with ``num_hidden_layers``, ``num_attention_heads``,
            ``hidden_size`` and ``num_key_value_heads``.
        num_shards: number of output shards.

    Returns:
        A list of ``num_shards`` dicts mapping ExecuTorch names to tensors.

    Raises:
        AssertionError: if ``num_attention_heads`` is not divisible by
            ``num_shards`` or ``hidden_size`` by ``2 * num_attention_heads``.
        ValueError: from ``shard_tensor`` when a sharded dim is not divisible.
        KeyError: if an expected key is missing from ``hf_state_dict``.
    """
    n_layers = config.num_hidden_layers
    n_heads = config.num_attention_heads
    dim = config.hidden_size
    num_key_value_heads = config.num_key_value_heads
    assert n_heads % num_shards == 0
    assert dim % (n_heads * 2) == 0

    def un_permute(w, is_query=True):
        # Presumably undoes the HF q/k row permutation used for rotary
        # embeddings (swaps the two half-head row groups) — TODO confirm.
        return (
            w.view(
                n_heads if is_query else num_key_value_heads,
                2,
                dim // n_heads // 2,
                dim,
            )
            .transpose(1, 2)
            .reshape(-1, dim)
        )

    # (HF sub-module, ExecuTorch module, shard dim, un_permute mode)
    # un_permute mode: "q" -> query heads, "kv" -> key/value heads, None -> skip.
    linear_specs = (
        ("self_attn.q_proj", "attention.wq", 0, "q"),
        ("self_attn.k_proj", "attention.wk", 0, "kv"),
        ("self_attn.v_proj", "attention.wv", 0, None),
        ("self_attn.o_proj", "attention.wo", 1, None),
        ("mlp.gate_proj", "feed_forward.w1", 0, None),
        ("mlp.down_proj", "feed_forward.w2", 1, None),
        ("mlp.up_proj", "feed_forward.w3", 0, None),
    )
    # (HF tensor suffix, ExecuTorch suffix, dtype cast or None)
    tensor_kinds = (
        ("int_weight", "weight", torch.int8),
        ("scale", "scale", None),
    )

    model_shard_dicts = [{} for _ in range(num_shards)]

    def put_replicated(key, value):
        # Store an independent clone of ``value`` in every shard.
        for shard in model_shard_dicts:
            shard[key] = value.clone()

    def put_sharded(key, value, shard_dim, cast=None):
        # Split ``value`` along ``shard_dim`` and store one clone per shard.
        pieces = shard_tensor(value, shard_dim, num_shards)
        for shard, piece in zip(model_shard_dicts, pieces):
            piece = piece.clone()
            shard[key] = piece if cast is None else piece.to(cast)

    for layer_i in range(n_layers):
        hf_prefix = f"model.layers.{layer_i}."
        et_prefix = f"layers.{layer_i}."

        ## store the same in every shard
        put_replicated(
            f"{et_prefix}attention_norm.weight",
            hf_state_dict[f"{hf_prefix}input_layernorm.weight"],
        )
        put_replicated(
            f"{et_prefix}ffn_norm.weight",
            hf_state_dict[f"{hf_prefix}post_attention_layernorm.weight"],
        )

        # All int weights of the layer first, then all scales.
        for hf_suffix, et_suffix, cast in tensor_kinds:
            for hf_module, et_module, shard_dim, permute in linear_specs:
                tensor = hf_state_dict[f"{hf_prefix}{hf_module}.module.{hf_suffix}"]
                if permute is not None:
                    tensor = un_permute(tensor, is_query=(permute == "q"))
                put_sharded(f"{et_prefix}{et_module}.{et_suffix}", tensor, shard_dim, cast)

    put_replicated("norm.weight", hf_state_dict["model.norm.weight"])

    # Embedding table is split along dim 1, the LM head along dim 0.
    put_sharded(
        "tok_embeddings.weight", hf_state_dict["model.embed_tokens.int_weight"], 1, torch.int8
    )
    put_sharded("output.weight", hf_state_dict["lm_head.module.int_weight"], 0, torch.int8)
    put_sharded(
        "tok_embeddings.scale", hf_state_dict["model.embed_tokens.scale"], 1, torch.float
    )
    put_sharded("output.scale", hf_state_dict["lm_head.module.scale"], 0, torch.float)

    return model_shard_dicts


def sanitize_checkpoint_from_spinquant(
    checkpoint: Any,
    group_size: int,
):
    """
    Sanitize the SpinQuant checkpoint in place and return it.

    - Renames every key ending in ``.scale`` to ``.scales``.
    - When ``group_size != -1``, keeps every ``group_size``-th column of each
      scale tensor (``[:, ::group_size]``); this presumably assumes scales are
      stored expanded to one value per weight column — TODO confirm.
    - Converts all tensors to contiguous format.
    """
    scale_keys = [k for k in checkpoint if k.endswith(".scale")]
    for old_key in scale_keys:
        old_val = checkpoint.pop(old_key)
        checkpoint[old_key + "s"] = (
            old_val if group_size == -1 else old_val[:, ::group_size]
        )
    for k, v in checkpoint.items():
        checkpoint[k] = v.contiguous()
    return checkpoint


# --------------------------------------------------------------------------------
# /spinquant/utils/data_utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.
import random
from typing import Any, Dict

import datasets
import torch
import transformers


def get_wikitext2(nsamples=128, seed=0, seqlen=2048, model="", tokenizer=None, eval_mode=False):
    """Load WikiText-2 (raw, v1) from the ``Salesforce/wikitext`` dataset.

    The selected split's ``text`` column is joined with blank lines and
    tokenized in a single call.

    Args:
        nsamples: number of calibration windows to draw (train mode only).
        seed: seed given to the global ``random`` module before sampling.
        seqlen: window length in tokens (train mode only).
        model: name/path used to load a slow tokenizer when ``tokenizer`` is None.
        tokenizer: ready-made tokenizer to use instead of loading one.
        eval_mode: True -> tokenized test split; False -> train-split samples.

    Returns:
        eval_mode=True: the tokenizer output (``return_tensors="pt"``).
        eval_mode=False: a list of ``(input_ids, target)`` pairs, each a
        ``seqlen``-token slice; ``target`` is -100 everywhere except the
        last position.
    """
    if tokenizer is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)

    split_name = "test" if eval_mode else "train"
    split = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")[split_name]
    encoded = tokenizer("\n\n".join(split["text"]), return_tensors="pt")
    if eval_mode:
        return encoded

    random.seed(seed)
    all_ids = encoded.input_ids
    max_start = all_ids.shape[1] - seqlen - 1
    samples = []
    for _ in range(nsamples):
        start = random.randint(0, max_start)
        window = all_ids[:, start : start + seqlen]
        target = window.clone()
        # Only the final position keeps its label; the rest are ignored (-100).
        target[:, :-1] = -100
        samples.append((window, target))
    return samples


class CustomJsonDataset(torch.utils.data.IterableDataset):
    """Tokenize text records and regroup them into fixed-size blocks.

    Each record of ``dataset`` must expose a ``"text"`` entry.  All token
    fields are concatenated and cut into ``block_size`` chunks; ``labels``
    is a copy of ``input_ids``.
    """

    def __init__(self, dataset, tokenizer, block_size: int = 1024) -> None:
        self.tokenizer = tokenizer
        self.block_size = block_size
        tokenized = [self.tokenize_function(record) for record in dataset]

        grouped = self.group_texts(tokenized)
        self.input_ids = grouped["input_ids"]
        self.labels = grouped["labels"]
        self.data = [
            dict(input_ids=ids, labels=lbls)
            for ids, lbls in zip(self.input_ids, self.labels)
        ]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i) -> Dict[str, Any]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __iter__(self):
        return iter(self.data)

    def tokenize_function(self, examples):
        return self.tokenizer(examples["text"])

    def group_texts(self, examples):
        """Concatenate every field across ``examples`` and chunk it by ``block_size``."""
        merged = {}
        for example in examples:
            for key in example.keys():
                merged.setdefault(key, []).extend(example[key])

        total_length = len(merged["input_ids"])
        # Drop the tail that does not fill a whole block (only when at least
        # one full block exists; otherwise keep a single short chunk).
        if total_length >= self.block_size:
            total_length -= total_length % self.block_size

        starts = range(0, total_length, self.block_size)
        result = {
            key: [seq[s : s + self.block_size] for s in starts]
            for key, seq in merged.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result


# --------------------------------------------------------------------------------
# /spinquant/utils/eval_utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.
import logging
import os

import torch
from tqdm import tqdm

from utils import model_utils


@torch.no_grad()
def evaluator(model, testenc, dev, args):
    """Compute WikiText-2 perplexity one decoder layer at a time.

    Only one decoder layer (plus the embedding at the start and norm/LM head
    at the end) is moved to ``dev`` at a time; finished layers go back to CPU.

    Args:
        model: causal LM exposing ``model.model.layers``, ``embed_tokens``,
            ``rotary_emb``, ``norm``, ``model.lm_head`` and ``model.seqlen``.
        testenc: tokenizer output whose ``input_ids`` is ``(1, text_len)``.
        dev: device used for computation.
        args: namespace providing ``bsz``, ``capture_layer_io``, ``layer_idx``
            and whatever ``model_utils.get_layer_io_save_path`` reads.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: if the text contains fewer than ``model.seqlen`` tokens.
    """
    # Convert the whole text of evaluation dataset into batches of sequences.
    input_ids = testenc.input_ids  # (1, text_len)
    nsamples = input_ids.numel() // model.seqlen  # The tail is truncated.
    if nsamples == 0:
        raise ValueError(
            f"Evaluation text has {input_ids.numel()} tokens, "
            f"fewer than model.seqlen={model.seqlen}"
        )
    input_ids = (
        input_ids[:, : nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev)
    )  # (nsamples, seqlen)

    batch_size = args.bsz
    input_ids = [input_ids[i : i + batch_size] for i in range(0, nsamples, batch_size)]
    nbatches = len(input_ids)

    model.eval()
    use_cache = model.config.use_cache
    model.config.use_cache = False

    layers = model.model.layers
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    # Per-batch hidden states entering the current decoder layer (filled by
    # Catcher for layer 0). A list suffices; no dense buffer is needed.
    inps = [None] * nbatches
    cache = {"i": 0, "attention_mask": None}

    class _StopForward(Exception):
        """Raised by Catcher to abort the forward pass once inputs are recorded."""

    class Catcher(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            cache["position_embeddings"] = kwargs["position_embeddings"]
            # Dedicated exception type so genuine errors raised earlier in the
            # model (e.g. a ValueError) are not silently swallowed.
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in input_ids:
            try:
                model(batch)
            except _StopForward:
                pass
    finally:
        # Always unwrap layer 0, even if a forward pass failed.
        layers[0] = layers[0].module
    layers[0] = layers[0].cpu()

    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.rotary_emb = model.model.rotary_emb.cpu()
    position_embeddings = cache["position_embeddings"]
    attention_mask = cache["attention_mask"]

    torch.cuda.empty_cache()
    outs = [None] * nbatches

    for i in tqdm(range(len(layers)), desc="(Eval) Layers"):
        layer = layers[i].to(dev)

        # Dump the layer input and output
        if args.capture_layer_io and args.layer_idx == i:
            # capture_layer_io slices its input along dim 0, so stack the
            # per-batch tensors first; the layer also needs the rotary
            # position embeddings, exactly as in the evaluation loop below.
            captured_io = model_utils.capture_layer_io(
                layer,
                torch.cat(inps, dim=0),
                position_embeddings=position_embeddings,
            )
            save_path = model_utils.get_layer_io_save_path(args)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(captured_io, save_path)
            logging.info(f"Dumped layer input and output to: {save_path}")

        for j in range(nbatches):
            outs[j] = layer(
                inps[j],
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
            )[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps

    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)

    model.lm_head = model.lm_head.to(dev)
    nlls = []
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    for hidden_states, batch_ids in zip(inps, input_ids):
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Predict token t+1 from position t.
        shift_logits = lm_logits[:, :-1, :]
        shift_labels = batch_ids[:, 1:]
        loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels)
        nlls.append(loss.float().mean(dim=1))  # mean NLL per sequence
    ppl = torch.exp(torch.cat(nlls).mean())
    model.config.use_cache = use_cache
    logging.info(f"\n WikiText2 PPL: {ppl.item():.3f}")
    return ppl.item()


# --------------------------------------------------------------------------------
# /spinquant/utils/fuse_norm_utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.

import typing
import torch


def fuse_ln_linear(
    layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
) -> None:
    """
    Fuse the linear operations in Layernorm into the adjacent linear blocks.

    Each linear weight is multiplied (broadcast over its last dim) by
    ``layernorm.weight``.  If the norm has a non-None ``bias``, ``W @ bias`` is
    added to the linear's bias, creating a zero bias (on the weight's device)
    when the linear has none.  Arithmetic is done in float64 and cast back to
    the linear's original dtype.
    """
    ln_weight = layernorm.weight.data.double()
    # RMSNorm has no ``bias`` attribute; LayerNorm(bias=False) has bias=None.
    ln_bias = getattr(layernorm, "bias", None)
    for linear in linear_layers:
        linear_dtype = linear.weight.dtype

        # Calculating new weight and bias
        W_ = linear.weight.data.double()
        linear.weight.data = (W_ * ln_weight).to(linear_dtype)

        if ln_bias is not None:
            if linear.bias is None:
                linear.bias = torch.nn.Parameter(
                    torch.zeros(
                        linear.out_features, dtype=torch.float64, device=W_.device
                    )
                )
            linear.bias.data = (
                linear.bias.data.double() + torch.matmul(W_, ln_bias.data.double())
            ).to(linear_dtype)


def fuse_layer_norms(model):
    """Fold the norm scales of a Llama-style model into the adjacent linears.

    - Subtracts each row's mean (last dim) from ``model.model.embed_tokens.weight``.
    - Per decoder layer: folds ``post_attention_layernorm`` into up/gate
      projections and ``input_layernorm`` into q/k/v projections, then resets
      both norm weights to ones.
    - Folds ``model.model.norm`` into ``model.lm_head`` and resets it to ones.

    Modifies ``model`` in place.
    """
    # Embedding fusion (computed in float64, cast back to the original dtype).
    embed = model.model.embed_tokens
    W_ = embed.weight.data.double()
    embed.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(embed.weight.data.dtype)

    # Fuse the linear operations in Layernorm into the adjacent linear blocks.
    for layer in model.model.layers:
        # fuse the input layernorms into the linear layers
        fuse_ln_linear(
            layer.post_attention_layernorm, [layer.mlp.up_proj, layer.mlp.gate_proj]
        )
        fuse_ln_linear(
            layer.input_layernorm,
            [
                layer.self_attn.q_proj,
                layer.self_attn.k_proj,
                layer.self_attn.v_proj,
            ],
        )

        layer.post_attention_layernorm.weight.data = torch.ones_like(
            layer.post_attention_layernorm.weight.data
        )
        layer.input_layernorm.weight.data = torch.ones_like(
            layer.input_layernorm.weight.data
        )

    fuse_ln_linear(model.model.norm, [model.lm_head])
    model.model.norm.weight.data = torch.ones_like(model.model.norm.weight.data)


# --------------------------------------------------------------------------------
# /spinquant/utils/model_utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.

import os
import torch


def skip(*args, **kwargs):
    # This is a helper function to save time during the initialization!
    pass


def get_layer_io_save_path(args):
    """Return ``{args.save_path}/layer_io/{args.layer_idx:03d}.pt``."""
    return os.path.join(args.save_path, "layer_io", f"{args.layer_idx:03d}.pt")


def capture_layer_io(layer, layer_input, **layer_kwargs):
    """Run ``layer`` one sequence at a time and record selected projection I/O.

    Inputs of ``k_proj``, ``o_proj``, ``gate_proj`` and ``down_proj`` and the
    output of ``v_proj`` are captured via forward hooks, detached and moved to
    CPU, then concatenated along dim 0.

    Args:
        layer: decoder layer with ``self_attn`` and ``mlp`` sub-modules.
        layer_input: tensor whose dim 0 indexes sequences; each slice is moved
            to ``"cuda"`` before the forward pass.
        **layer_kwargs: extra keyword arguments forwarded to every ``layer``
            call (e.g. ``position_embeddings``); none by default.

    Returns:
        ``{"input": {name: tensor}, "output": {"v_proj": tensor}}``.
    """

    def hook_factory(module_name, captured_vals, is_input):
        def hook(module, inputs, output):
            value = inputs[0] if is_input else output
            captured_vals[module_name].append(value.detach().cpu())

        return hook

    def find_proj(name):
        # Projections live under either self_attn or mlp.
        module = getattr(layer.self_attn, name, None)
        return module if module is not None else getattr(layer.mlp, name)

    captured_inputs = {
        "k_proj": [],  # q_proj, v_proj has the same input as k_proj
        "o_proj": [],
        "gate_proj": [],  # up_proj has the same input as gate_proj
        "down_proj": [],
    }
    captured_outputs = {
        "v_proj": [],
    }

    handles = []
    # Hooks are removed even if registration or a forward pass fails, so they
    # never stay attached to the layer.
    try:
        for name in captured_inputs:
            handles.append(
                find_proj(name).register_forward_hook(
                    hook_factory(name, captured_inputs, True)
                )
            )
        for name in captured_outputs:
            handles.append(
                find_proj(name).register_forward_hook(
                    hook_factory(name, captured_outputs, False)
                )
            )

        # Process each sequence in the batch one by one to avoid OOM.
        for seq_idx in range(layer_input.shape[0]):
            seq = layer_input[seq_idx : seq_idx + 1].to("cuda")
            layer(seq, **layer_kwargs)
    finally:
        for h in handles:
            h.remove()

    # Concatenate the accumulated per-sequence tensors for each sub-layer.
    captured_inputs = {k: torch.cat(v, dim=0) for k, v in captured_inputs.items()}
    captured_outputs = {k: torch.cat(v, dim=0) for k, v in captured_outputs.items()}

    return {"input": captured_inputs, "output": captured_outputs}


# --------------------------------------------------------------------------------
# /spinquant/utils/monkeypatch.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.

import copy
import functools
import types


def copy_func_with_new_globals(f, globals=None):
    """Return a copy of function ``f`` whose global-name lookups use ``globals``.

    Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)

    The copy shares ``f``'s code object, positional defaults and closure;
    ``__kwdefaults__`` is shallow-copied and other metadata is copied with
    ``functools.update_wrapper``.  ``globals`` defaults to ``f.__globals__``.
    """
    if globals is None:
        globals = f.__globals__
    g = types.FunctionType(
        f.__code__,
        globals,
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    g = functools.update_wrapper(g, f)
    g.__module__ = f.__module__
    g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
    return g


def add_wrapper_after_function_call_in_method(
    module,
    method_name,
    function_name,
    wrapper_fn,
):
    """
    This function adds a wrapper after the output of a function call in the method named `method_name`.
    Only calls directly in the method are affected. Calls by other functions called in the method are not affected.

    The method's function is copied with a private copy of its globals in
    which ``function_name`` maps to ``wrapper_fn(original_function)``; the
    copy is bound to ``module`` and stored as an attribute of that object.

    Returns:
        The wrapper installed in place of ``function_name``.

    Raises:
        KeyError: if ``function_name`` is not a global of the method's module.
    """
    original_method = getattr(module, method_name).__func__
    method_globals = dict(original_method.__globals__)
    wrapper = wrapper_fn(method_globals[function_name])
    method_globals[function_name] = wrapper
    new_method = copy_func_with_new_globals(original_method, globals=method_globals)
    setattr(module, method_name, new_method.__get__(module))
    return wrapper


# --------------------------------------------------------------------------------
# /spinquant/utils/process_args.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.
from dataclasses import dataclass, field
from typing import Optional, Tuple

import argparse
import transformers


@dataclass
class ModelArguments:
    """Model/rotation paths parsed by ``transformers.HfArgumentParser``."""

    input_model: Optional[str] = field(
        default="test-input", metadata={"help": "Input model"}
    )
    output_rotation_path: Optional[str] = field(
        default="test-output", metadata={"help": "Output rotation checkpoint path"}
    )
    optimized_rotation_path: Optional[str] = field(
        default=None, metadata={"help": "Optimized rotation checkpoint path"}
    )
    access_token: Optional[str] = field(
        default=None,
        metadata={"help": "Huggingface access token to access gated repo like Llama"},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` with project-specific defaults."""

    cache_dir: Optional[str] = field(default=None)
    output_dir: Optional[str] = field(default="/tmp/output/")
    model_max_length: Optional[int] = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)"
        },
    )


def parser_gen():
    """Parse the PTQ command-line flags.

    Uses ``parse_known_args`` so flags belonging to the HF dataclasses are
    returned untouched for a second parser.

    Returns:
        ``(args, unknown)``: the parsed namespace and the unrecognized args.

    Exits (via ``parser.error``) if ``--k_pre_rope`` is set, which is not
    supported yet.
    """
    parser = argparse.ArgumentParser()

    def add_flag(name, help_text):
        # --name / --no-name boolean flag defaulting to False.
        parser.add_argument(
            name,
            action=argparse.BooleanOptionalAction,
            default=False,
            help=help_text,
        )

    parser.add_argument(
        "--seed", type=int, default=0, help="Random Seed for HuggingFace and PyTorch"
    )

    # Rotation Arguments
    add_flag(
        "--rotate",
        "Rotate the model. This will include online rotation for down-projection and "
        "out-projection. Note that this does not apply rotation to the K/Q and they "
        "will be rotated if we want to quantize the Keys",
    )
    parser.add_argument(
        "--rotate_mode", type=str, default="hadamard", choices=["hadamard", "random"]
    )
    parser.add_argument(
        "--rotation_seed",
        type=int,
        default=-1,
        help="Random Seed for generating random matrix!!",
    )
    add_flag("--fp32_had", "Apply Hadamard rotation in FP32 (default: False)")

    # Activation Quantization Arguments
    parser.add_argument(
        "--a_bits",
        type=int,
        default=16,
        help="Number of bits for inputs of the Linear layers. This will be for all "
        "the linear layers in the model (including down-projection and out-projection)",
    )
    parser.add_argument(
        "--a_groupsize",
        type=int,
        default=-1,
        help="Groupsize for activation quantization. Note that this should be the same as w_groupsize",
    )
    add_flag("--a_asym", "Asymmetric Activation quantization (default: False)")
    parser.add_argument(
        "--a_clip_ratio",
        type=float,
        default=1.0,
        help="Clip ratio for activation quantization. new_max = max * clip_ratio",
    )

    # Weight Quantization Arguments
    parser.add_argument(
        "--w_bits",
        type=int,
        default=16,
        help="Number of bits for weights of the Linear layers",
    )
    parser.add_argument(
        "--w_groupsize",
        type=int,
        default=-1,
        help="Groupsize for weight quantization. Note that this should be the same as a_groupsize",
    )
    add_flag("--w_asym", "Asymmetric weight quantization (default: False)")
    add_flag(
        "--w_rtn",
        "Quantize the weights using RtN. If the w_bits < 16 and this flag is not set, we use GPTQ",
    )
    add_flag(
        "--w_clip",
        "Clipping the weight quantization! We do not support arguments for clipping "
        "and we find the best clip ratio during the weight quantization",
    )
    parser.add_argument(
        "--nsamples",
        type=int,
        default=128,
        help="Number of calibration data samples for GPTQ.",
    )
    parser.add_argument(
        "--percdamp",
        type=float,
        default=0.01,
        help="Percent of the average Hessian diagonal to use for dampening.",
    )
    add_flag("--act_order", "act-order in GPTQ")
    add_flag("--asym_calibrate", "Whether to use GPTAQ")

    # General Quantization Arguments
    add_flag(
        "--int8_down_proj",
        "Use INT8 for Down Projection! If this is set, both weights and activations "
        "of this layer will be in INT8",
    )

    # KV-Cache Quantization Arguments
    parser.add_argument(
        "--v_bits",
        type=int,
        default=16,
        help="Number of bits for V-cache quantization. "
        "Note that quantizing the V-cache does not need any other rotation",
    )
    parser.add_argument("--v_groupsize", type=int, default=-1)
    add_flag("--v_asym", "Asymmetric V-cache quantization")
    parser.add_argument(
        "--v_clip_ratio",
        type=float,
        default=1.0,
        help="Clip ratio for v-cache quantization. new_max = max * clip_ratio",
    )

    parser.add_argument(
        "--k_bits",
        type=int,
        default=16,
        help="Number of bits for K-cache quantization. "
        "Note that quantizing the K-cache needs another rotation for the keys/queries",
    )
    parser.add_argument("--k_groupsize", type=int, default=-1)
    add_flag("--k_asym", "Asymmetric K-cache quantization")
    add_flag("--k_pre_rope", "Pre-RoPE quantization for K-cache (not Supported yet!)")
    parser.add_argument(
        "--k_clip_ratio",
        type=float,
        default=1.0,
        help="Clip ratio for k-cache quantization. new_max = max * clip_ratio",
    )

    # Save/Load Quantized Model Arguments
    parser.add_argument(
        "--load_qmodel_path",
        type=str,
        default=None,
        help="Load the quantized model from the specified path!",
    )
    parser.add_argument(
        "--save_qmodel_path",
        type=str,
        default=None,
        help="Save the quantized model to the specified path!",
    )
    add_flag(
        "--export_to_et",
        "Export the quantized model to executorch and save in save_qmodel_path",
    )
    add_flag("--distribute", "Distribute the model on multiple GPUs for evaluation")

    # Experiments Arguments
    add_flag(
        "--capture_layer_io",
        "Capture the input and output of the specified decoder layer and dump into a file",
    )
    parser.add_argument(
        "--layer_idx", type=int, default=10, help="Which decoder layer to capture"
    )
    add_flag(
        "--enable_aq_calibration",
        "Whether to enable activation quantization during GPTQ (default: False)",
    )

    # LM Eval Arguments
    parser.add_argument(
        "--lm_eval", action="store_true", help="Evaluate the model on LM Eval tasks."
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=["piqa", "hellaswag", "arc_easy", "arc_challenge", "winogrande", "boolq"],
    )
    parser.add_argument(
        "--lm_eval_batch_size",
        type=int,
        default=128,
        help="Batch size for evaluating with lm eval harness.",
    )

    args, unknown = parser.parse_known_args()

    # NOTE: a_groupsize == w_groupsize is recommended (see the help texts) but
    # intentionally not enforced.
    # Use parser.error rather than assert so the check survives `python -O`.
    if args.k_pre_rope:
        parser.error("Pre-RoPE quantization is not supported yet!")

    return args, unknown


def process_args_ptq():
    """Parse PTQ flags plus the HF ``ModelArguments``/``TrainingArguments``.

    Flags not recognized by ``parser_gen`` are handed to
    ``transformers.HfArgumentParser``.  ``ptq_args`` additionally receives
    ``optimized_rotation_path`` (from ``model_args``) and ``bsz`` (from
    ``training_args.per_device_eval_batch_size``).

    Returns:
        ``(model_args, training_args, ptq_args)``.
    """
    ptq_args, unknown_args = parser_gen()

    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses(args=unknown_args)
    ptq_args.optimized_rotation_path = model_args.optimized_rotation_path
    ptq_args.bsz = training_args.per_device_eval_batch_size

    return model_args, training_args, ptq_args


# --------------------------------------------------------------------------------
# /spinquant/utils/utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.
10 | 11 | import logging 12 | import os 13 | import random 14 | from typing import Optional 15 | 16 | import numpy as np 17 | import torch 18 | from fast_hadamard_transform import hadamard_transform 19 | from torch.distributed.fsdp import ( 20 | FullStateDictConfig, 21 | ) 22 | from torch.distributed.fsdp import ( 23 | FullyShardedDataParallel as PT_FSDP, 24 | ) 25 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 26 | 27 | 28 | from accelerate import dispatch_model, infer_auto_device_map 29 | from accelerate.utils import get_balanced_memory 30 | 31 | # These flags disable using TensorFloat-32 tensor cores (to avoid numerical issues) 32 | # torch.backends.cuda.matmul.allow_tf32 = False 33 | # torch.backends.cudnn.allow_tf32 = False 34 | DEV = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 35 | 36 | 37 | def pt_fsdp_state_dict(model: torch.nn.Module): 38 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 39 | with PT_FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): 40 | return model.state_dict() 41 | 42 | 43 | class HadamardTransform(torch.autograd.Function): 44 | """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))""" 45 | 46 | @staticmethod 47 | def forward(ctx, u): 48 | return hadamard_transform(u) 49 | 50 | @staticmethod 51 | def backward(ctx, grad): 52 | return hadamard_transform(grad) 53 | 54 | 55 | def llama_down_proj_groupsize(model, groupsize): 56 | assert groupsize > 1, "groupsize should be greater than 1!" 57 | 58 | if model.config.intermediate_size % groupsize == 0: 59 | logging.info(f"(Act.) Groupsiz = Down_proj Groupsize: {groupsize}") 60 | return groupsize 61 | 62 | group_num = int(model.config.hidden_size / groupsize) 63 | assert ( 64 | groupsize * group_num == model.config.hidden_size 65 | ), "Invalid groupsize for llama!" 
66 | 67 | down_proj_groupsize = model.config.intermediate_size // group_num 68 | assert ( 69 | down_proj_groupsize * group_num == model.config.intermediate_size 70 | ), "Invalid groupsize for down_proj!" 71 | logging.info( 72 | f"(Act.) Groupsize: {groupsize}, Down_proj Groupsize: {down_proj_groupsize}" 73 | ) 74 | return down_proj_groupsize 75 | 76 | 77 | def set_seed(seed): 78 | np.random.seed(seed) 79 | torch.random.manual_seed(seed) 80 | random.seed(seed) 81 | 82 | 83 | # Dump the log both to console and a log file. 84 | def config_logging(log_file, level=logging.INFO): 85 | class LogFormatter(logging.Formatter): 86 | def format(self, record): 87 | if record.levelno == logging.INFO: 88 | self._style._fmt = "%(message)s" 89 | else: 90 | self._style._fmt = "%(levelname)s: %(message)s" 91 | return super().format(record) 92 | 93 | console_handler = logging.StreamHandler() 94 | console_handler.setFormatter(LogFormatter()) 95 | 96 | file_handler = logging.FileHandler(log_file) 97 | file_handler.setFormatter(LogFormatter()) 98 | 99 | logging.basicConfig(level=level, handlers=[console_handler, file_handler]) 100 | 101 | 102 | def cleanup_memory(verbos=True) -> None: 103 | """Run GC and clear GPU memory.""" 104 | import gc 105 | import inspect 106 | 107 | caller_name = "" 108 | try: 109 | caller_name = f" (from {inspect.stack()[1].function})" 110 | except (ValueError, KeyError): 111 | pass 112 | 113 | def total_reserved_mem() -> int: 114 | return sum( 115 | torch.cuda.memory_reserved(device=i) 116 | for i in range(torch.cuda.device_count()) 117 | ) 118 | 119 | memory_before = total_reserved_mem() 120 | 121 | # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed 122 | gc.collect() 123 | 124 | if torch.cuda.is_available(): 125 | torch.cuda.empty_cache() 126 | memory_after = total_reserved_mem() 127 | if verbos: 128 | logging.info( 129 | f"GPU memory{caller_name}: {memory_before / (1024 ** 3):.2f} -> {memory_after / (1024 ** 3):.2f} 
GB" 130 | f" ({(memory_after - memory_before) / (1024 ** 3):.2f} GB)" 131 | ) 132 | 133 | 134 | # Define a utility method for setting the logging parameters of a logger 135 | def get_logger(logger_name: Optional[str]) -> logging.Logger: 136 | # Get the logger with the specified name 137 | logger = logging.getLogger(logger_name) 138 | 139 | # Set the logging level of the logger to INFO 140 | logger.setLevel(logging.INFO) 141 | 142 | # Define a formatter for the log messages 143 | formatter = logging.Formatter( 144 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 145 | ) 146 | 147 | # Create a console handler for outputting log messages to the console 148 | console_handler = logging.StreamHandler() 149 | console_handler.setFormatter(formatter) 150 | 151 | # Add the console handler to the logger 152 | logger.addHandler(console_handler) 153 | 154 | return logger 155 | 156 | 157 | def get_local_rank() -> int: 158 | if os.environ.get("LOCAL_RANK"): 159 | return int(os.environ["LOCAL_RANK"]) 160 | else: 161 | logging.warning( 162 | "LOCAL_RANK from os.environ is None, fall back to get rank from torch distributed" 163 | ) 164 | return torch.distributed.get_rank() 165 | 166 | 167 | def get_global_rank() -> int: 168 | """ 169 | Get rank using torch.distributed if available. Otherwise, the RANK env var instead if initialized. 170 | Returns 0 if neither condition is met. 171 | """ 172 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 173 | return torch.distributed.get_rank() 174 | 175 | environ_rank = os.environ.get("RANK", "") 176 | if environ_rank.isdecimal(): 177 | return int(os.environ["RANK"]) 178 | 179 | return 0 180 | 181 | 182 | def distribute_model(model) -> None: 183 | """Distribute the model across available GPUs. 
NB: only implemented for Llama-2.""" 184 | no_split_module_classes = ['LlamaDecoderLayer'] 185 | max_memory = get_balanced_memory( 186 | model, 187 | no_split_module_classes=no_split_module_classes, 188 | ) 189 | 190 | device_map = infer_auto_device_map( 191 | model, max_memory=max_memory, no_split_module_classes=no_split_module_classes 192 | ) 193 | 194 | print(device_map) 195 | dispatch_model( 196 | model, 197 | device_map=device_map, 198 | offload_buffers=True, 199 | offload_dir="offload", 200 | state_dict=model.state_dict(), 201 | ) 202 | 203 | cleanup_memory() -------------------------------------------------------------------------------- /vit_quant/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Fake Quantization for Vision Transformers 4 | 5 | This is code is developed based on our [`fake_quant`](../fake_quant) 6 | 7 | ## Installation 8 | 9 | We can use the same envrionment with QuaRot. 10 | 11 | Additionally, install the `timm` package 12 | 13 | ## ImageNet Evaluations 14 | 15 | Currently, this code supports **EVA-02 and DeiT** transformers. The arguments are the same with `fake_quant` experiments. 16 | 17 | **Before you start, kindly modify the imagenet datapath in [data_utils.py](./data_utils.py) ** 18 | 19 | 20 | We provide a script `run.sh` to reproduce the results. 
21 | 22 | -------------------------------------------------------------------------------- /vit_quant/data_utils.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | from timm.data import resolve_data_config 4 | from timm.data.transforms_factory import create_transform 5 | from torchvision.datasets import ImageFolder,DatasetFolder 6 | 7 | 8 | def get_validation_loader(dataset_name, model, batch_size, num_workers=8): 9 | if dataset_name == 'imagenet': 10 | data_path = '/gpfs/gibbs/project/panda/shared/imagenet' 11 | else: 12 | raise NotImplementedError 13 | config = resolve_data_config({}, model=model) 14 | test_transform = create_transform(**config) 15 | val_set = ImageFolder(data_path, test_transform) 16 | val_loader = torch.utils.data.DataLoader(val_set, 17 | batch_size=batch_size, 18 | num_workers=num_workers, 19 | shuffle=False) 20 | return val_loader 21 | 22 | 23 | def get_calibration_loader(dataset_name, model, num_data, num_workers=8): 24 | if dataset_name == 'imagenet': 25 | data_path = '/gpfs/gibbs/project/panda/shared/imagenet_2012/train' 26 | else: 27 | raise NotImplementedError 28 | config = resolve_data_config({}, model=model) 29 | train_transform = create_transform(**config, is_training=False) 30 | calib_set = ImageFolder(data_path, train_transform) 31 | calib_loader = torch.utils.data.DataLoader(calib_set, 32 | batch_size=num_data, 33 | num_workers=num_workers, 34 | shuffle=True) 35 | 36 | return calib_loader 37 | 38 | 39 | def get_calibration_data(dataset_name, model, num_data=128): 40 | calib_loader = get_calibration_loader(dataset_name, model, num_data) 41 | calib_loader = iter(calib_loader) 42 | image, label = next(calib_loader) 43 | image = image.cuda() 44 | return image 45 | -------------------------------------------------------------------------------- /vit_quant/eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | import data_utils 3 | from tqdm import tqdm 4 | 5 | 6 | def test(args, model): 7 | model.cuda() 8 | test_loader = data_utils.get_validation_loader( 9 | args.eval_dataset, model, args.bsz 10 | ) 11 | 12 | pos = 0 13 | tot = 0 14 | i = 0 15 | max_iteration = len(test_loader) 16 | with torch.no_grad(): 17 | q = tqdm(test_loader) 18 | for inp, target in q: 19 | i += 1 20 | inp = inp.cuda() 21 | target = target.cuda() 22 | out = model(inp) 23 | pos_num = torch.sum(out.argmax(1) == target).item() 24 | pos += pos_num 25 | tot += inp.size(0) 26 | q.set_postfix({"acc": pos / tot}) 27 | if i >= max_iteration: 28 | break 29 | print('ImageNet accuracy: {}%'.format(100 * pos / tot)) -------------------------------------------------------------------------------- /vit_quant/gptaq_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | import quant_utils 8 | import timm 9 | import logging 10 | import functools 11 | 12 | torch.backends.cuda.matmul.allow_tf32 = False 13 | torch.backends.cudnn.allow_tf32 = False 14 | 15 | 16 | class GPTAQ: 17 | 18 | def __init__(self, layer, cls_token=0): 19 | self.layer = layer 20 | self.dev = self.layer.weight.device 21 | W = layer.weight.data.clone() 22 | self.rows = W.shape[0] 23 | self.columns = W.shape[1] 24 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 25 | self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev) 26 | self.nsamples = 0 27 | self.fp_inp = [] 28 | self.cls_token = cls_token 29 | 30 | def add_batch(self, inp, out): 31 | if len(inp.shape) == 2: 32 | inp = inp.unsqueeze(0) 33 | tmp = inp.shape[0] 34 | if len(inp.shape) == 3: 35 | inp = inp[:, self.cls_token:, :] 36 | inp = inp.reshape((-1, inp.shape[-1])) 37 | 38 | inp = inp.t() 39 | 40 | self.H *= self.nsamples / (self.nsamples + tmp) 41 | self.dXXT *= self.nsamples / (self.nsamples + tmp) 42 | 
self.nsamples += tmp 43 | # inp = inp.float() 44 | inp = math.sqrt(2 / self.nsamples) * inp.float() 45 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 46 | self.H += inp.matmul(inp.t()) 47 | dX = self.fp_inp[0].float().to(self.dev) * math.sqrt(2 / self.nsamples) - inp 48 | self.dXXT += dX.matmul(inp.t()) 49 | 50 | del self.fp_inp[0] 51 | 52 | def fasterquant( 53 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 54 | ): 55 | W = self.layer.weight.data.clone() 56 | W = W.float() 57 | 58 | if not self.quantizer.ready(): 59 | self.quantizer.find_params(W) 60 | 61 | H = self.H 62 | del self.H 63 | dead = torch.diag(H) == 0 64 | H[dead, dead] = 1 65 | W[:, dead] = 0 66 | self.dXXT[:, dead] = 0 67 | 68 | if static_groups: 69 | import copy 70 | groups = [] 71 | for i in range(0, self.columns, groupsize): 72 | quantizer = copy.deepcopy(self.quantizer) 73 | quantizer.find_params(W[:, i:(i + groupsize)]) 74 | groups.append(quantizer) 75 | 76 | if actorder: 77 | perm = torch.argsort(torch.diag(H), descending=True) 78 | W = W[:, perm] 79 | H = H[perm][:, perm] 80 | self.dXXT = self.dXXT[perm][:, perm] 81 | invperm = torch.argsort(perm) 82 | 83 | Losses = torch.zeros_like(W) 84 | Q = torch.zeros_like(W) 85 | 86 | damp = percdamp * torch.mean(torch.diag(H)) 87 | diag = torch.arange(self.columns, device=self.dev) 88 | H[diag, diag] += damp 89 | Hinv = torch.linalg.cholesky(H) 90 | Hinv = torch.cholesky_inverse(Hinv) 91 | Hinv = torch.linalg.cholesky(Hinv, upper=True) 92 | 93 | alpha = 0.25 94 | P = alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv 95 | del self.dXXT 96 | 97 | for i1 in range(0, self.columns, blocksize): 98 | i2 = min(i1 + blocksize, self.columns) 99 | count = i2 - i1 100 | 101 | W1 = W[:, i1:i2].clone() 102 | Q1 = torch.zeros_like(W1) 103 | Err1 = torch.zeros_like(W1) 104 | Losses1 = torch.zeros_like(W1) 105 | Hinv1 = Hinv[i1:i2, i1:i2] 106 | P1 = P[i1:i2, i1:i2] 107 | 108 | for i in range(count): 109 | w = W1[:, i] 
110 | d = Hinv1[i, i] 111 | 112 | if groupsize != -1: 113 | if not static_groups: 114 | if (i1 + i) % groupsize == 0: 115 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)]) 116 | else: 117 | idx = i1 + i 118 | if actorder: 119 | idx = perm[idx] 120 | self.quantizer = groups[idx // groupsize] 121 | 122 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten() 123 | Q1[:, i] = q 124 | Losses1[:, i] = (w - q) ** 2 / d ** 2 125 | 126 | err1 = (w - q) / d 127 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0)) 128 | Err1[:, i] = err1 129 | 130 | Q[:, i1:i2] = Q1 131 | Losses[:, i1:i2] = Losses1 / 2 132 | 133 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:]) 134 | 135 | torch.cuda.synchronize() 136 | 137 | if actorder: 138 | Q = Q[:, invperm] 139 | 140 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 141 | if torch.any(torch.isnan(self.layer.weight.data)): 142 | logging.warning('NaN in weights') 143 | import pprint 144 | pprint.pprint(self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point) 145 | raise ValueError('NaN in weights') 146 | 147 | def free(self): 148 | self.H = None 149 | self.Losses = None 150 | self.Trace = None 151 | torch.cuda.empty_cache() 152 | utils.cleanup_memory(verbos=False) 153 | 154 | 155 | class FPInputsCache: 156 | """ 157 | class for saving the full-precision output in each layer. 
158 | """ 159 | def __init__(self, sequential, cls_token=0): 160 | self.fp_cache = {} 161 | self.names = sequential[0]+sequential[1]+sequential[2]+sequential[3] 162 | print(self.names) 163 | for name in self.names: 164 | self.fp_cache[name] = [] 165 | self.handles = [] 166 | self.cls_token = cls_token 167 | 168 | def cache_fp_input(self, m, inp, out, name): 169 | inp = inp[0].detach().cpu() 170 | if len(inp.shape) == 3: 171 | inp = inp[:, self.cls_token:, :] 172 | inp = inp.reshape((-1, inp.shape[-1])) 173 | self.fp_cache[name] += [inp.t()] 174 | 175 | def add_hook(self, full): 176 | for name in self.names: 177 | self.handles.append( 178 | full[name].register_forward_hook( 179 | functools.partial(self.cache_fp_input, name=name) 180 | ) 181 | ) 182 | 183 | def clear_hook(self): 184 | for h in self.handles: 185 | h.remove() 186 | self.handles = [] 187 | torch.cuda.empty_cache() 188 | 189 | def clear_cache(self): 190 | for name in self.names: 191 | self.fp_cache[name] = [] 192 | 193 | 194 | @torch.no_grad() 195 | def gptaq_fwrd(model, calib_data, dev, args): 196 | ''' 197 | From GPTQ repo 198 | TODO: Make this function general to support both OPT and LLaMA models 199 | ''' 200 | print('-----GPTQv2 Quantization-----') 201 | 202 | layers = model.blocks 203 | 204 | model = model.cuda() 205 | 206 | class Catcher(nn.Module): 207 | def __init__(self, module): 208 | super().__init__() 209 | self.module = module 210 | self.fwd_kwargs = {} 211 | self.inps = None 212 | 213 | def forward(self, inp, **kwargs): 214 | self.inps = inp.data.clone() 215 | self.fwd_kwargs = kwargs 216 | print('data collected') 217 | raise ValueError 218 | 219 | layers[0] = Catcher(layers[0]) 220 | try: 221 | model(calib_data.to(dev)) 222 | except ValueError: 223 | pass 224 | 225 | inps = layers[0].inps 226 | fwd_kwargs = layers[0].fwd_kwargs 227 | 228 | layers[0] = layers[0].module 229 | layers[0] = layers[0].cpu() 230 | model = model.cpu() 231 | 232 | del calib_data 233 | torch.cuda.empty_cache() 234 
| 235 | outs = torch.zeros_like(inps) 236 | 237 | quantizers = {} 238 | if isinstance(model, timm.models.Eva): 239 | sequential = [ 240 | ['attn.q_proj.module', 'attn.k_proj.module', 'attn.v_proj.module'], 241 | ['attn.proj.module'], 242 | ['mlp.fc1_g.module', 'mlp.fc1_x.module'], 243 | ['mlp.fc2.module'] 244 | ] 245 | elif isinstance(model, timm.models.VisionTransformer): 246 | sequential = [ 247 | ['attn.qkv.module'], 248 | ['attn.proj.module'], 249 | ['mlp.fc1.module'], 250 | ['mlp.fc2.module'] 251 | ] 252 | else: 253 | raise NotImplementedError 254 | 255 | fp_inputs_cache = FPInputsCache(sequential, cls_token=0) 256 | fp_inps = inps.clone() 257 | 258 | for i in range(len(layers)): 259 | print(f'\nLayer {i}:', flush=True, end=' ') 260 | layer = layers[i].to(dev) 261 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear]) 262 | 263 | bits_config = quant_utils.disable_act_quant(layer) 264 | fp_inputs_cache.add_hook(full) 265 | fp_inps = layer(fp_inps, **fwd_kwargs) 266 | fp_inputs_cache.clear_hook() 267 | quant_utils.enable_act_quant(layer, bits_config) 268 | 269 | for names in sequential: 270 | subset = {n: full[n] for n in names} 271 | 272 | gptq = {} 273 | for name in subset: 274 | print(f'{name}', end=' ', flush=True) 275 | layer_weight_bits = args.w_bits 276 | layer_weight_sym = not (args.w_asym) 277 | if 'head' in name: 278 | layer_weight_bits = 16 279 | continue 280 | gptq[name] = GPTAQ(subset[name], cls_token=0) 281 | gptq[name].quantizer = quant_utils.WeightQuantizer() 282 | gptq[name].quantizer.configure( 283 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip 284 | ) 285 | gptq[name].fp_inp = fp_inputs_cache.fp_cache[name] 286 | 287 | def add_batch(name): 288 | def tmp(_, inp, out): 289 | gptq[name].add_batch(inp[0].data, out.data) 290 | 291 | return tmp 292 | 293 | handles = [] 294 | for name in subset: 295 | handles.append(subset[name].register_forward_hook(add_batch(name))) 296 | outs = layer(inps, **fwd_kwargs) # 
forward calibration data 297 | for h in handles: 298 | h.remove() 299 | 300 | for name in subset: 301 | layer_w_groupsize = args.w_groupsize 302 | gptq[name].fasterquant( 303 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False 304 | ) 305 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 306 | gptq[name].free() 307 | 308 | outs = layer(inps, **fwd_kwargs) # forward calibration data 309 | 310 | fp_inputs_cache.clear_cache() 311 | layers[i] = layer.cpu() 312 | del layer 313 | del gptq 314 | torch.cuda.empty_cache() 315 | 316 | inps, outs = outs, inps 317 | 318 | utils.cleanup_memory(verbos=True) 319 | print('-----GPTQv2 Quantization Done-----\n') 320 | return quantizers 321 | -------------------------------------------------------------------------------- /vit_quant/gptq_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import timm.models 5 | import tqdm 6 | import torch 7 | import torch.nn as nn 8 | import utils 9 | import quant_utils 10 | import logging 11 | 12 | torch.backends.cuda.matmul.allow_tf32 = False 13 | torch.backends.cudnn.allow_tf32 = False 14 | 15 | 16 | class GPTQ: 17 | 18 | def __init__(self, layer, cls_token=0): 19 | self.layer = layer 20 | self.dev = self.layer.weight.device 21 | W = layer.weight.data.clone() 22 | self.rows = W.shape[0] 23 | self.columns = W.shape[1] 24 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 25 | self.nsamples = 0 26 | self.cls_token = cls_token 27 | 28 | def add_batch(self, inp, out): 29 | 30 | if len(inp.shape) == 2: 31 | inp = inp.unsqueeze(0) 32 | tmp = inp.shape[0] 33 | if len(inp.shape) == 3: 34 | inp = inp[:, self.cls_token:, :] 35 | inp = inp.reshape((-1, inp.shape[-1])) 36 | inp = inp.t() 37 | self.H *= self.nsamples / (self.nsamples + tmp) 38 | self.nsamples += tmp 39 | # inp = inp.float() 40 | inp = math.sqrt(2 / self.nsamples) * inp.float() 41 
| # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 42 | self.H += inp.matmul(inp.t()) 43 | 44 | def fasterquant( 45 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 46 | ): 47 | W = self.layer.weight.data.clone() 48 | W = W.float() 49 | 50 | tick = time.time() 51 | 52 | if not self.quantizer.ready(): 53 | self.quantizer.find_params(W) 54 | 55 | H = self.H 56 | del self.H 57 | dead = torch.diag(H) == 0 58 | H[dead, dead] = 1 59 | W[:, dead] = 0 60 | 61 | if static_groups: 62 | import copy 63 | groups = [] 64 | for i in range(0, self.columns, groupsize): 65 | quantizer = copy.deepcopy(self.quantizer) 66 | quantizer.find_params(W[:, i:(i + groupsize)]) 67 | groups.append(quantizer) 68 | 69 | if actorder: 70 | perm = torch.argsort(torch.diag(H), descending=True) 71 | W = W[:, perm] 72 | H = H[perm][:, perm] 73 | invperm = torch.argsort(perm) 74 | 75 | Losses = torch.zeros_like(W) 76 | Q = torch.zeros_like(W) 77 | 78 | damp = percdamp * torch.mean(torch.diag(H)) 79 | diag = torch.arange(self.columns, device=self.dev) 80 | H[diag, diag] += damp 81 | H = torch.linalg.cholesky(H) 82 | H = torch.cholesky_inverse(H) 83 | H = torch.linalg.cholesky(H, upper=True) 84 | Hinv = H 85 | 86 | for i1 in range(0, self.columns, blocksize): 87 | i2 = min(i1 + blocksize, self.columns) 88 | count = i2 - i1 89 | 90 | W1 = W[:, i1:i2].clone() 91 | Q1 = torch.zeros_like(W1) 92 | Err1 = torch.zeros_like(W1) 93 | Losses1 = torch.zeros_like(W1) 94 | Hinv1 = Hinv[i1:i2, i1:i2] 95 | 96 | for i in range(count): 97 | w = W1[:, i] 98 | d = Hinv1[i, i] 99 | 100 | if groupsize != -1: 101 | if not static_groups: 102 | if (i1 + i) % groupsize == 0: 103 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)]) 104 | else: 105 | idx = i1 + i 106 | if actorder: 107 | idx = perm[idx] 108 | self.quantizer = groups[idx // groupsize] 109 | 110 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten() 111 | Q1[:, i] = q 112 | Losses1[:, i] = (w - q) ** 2 / d ** 2 
113 | 114 | err1 = (w - q) / d 115 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 116 | Err1[:, i] = err1 117 | 118 | Q[:, i1:i2] = Q1 119 | Losses[:, i1:i2] = Losses1 / 2 120 | 121 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 122 | 123 | torch.cuda.synchronize() 124 | 125 | if actorder: 126 | Q = Q[:, invperm] 127 | 128 | # print('change ratio: {}'.format(Q.norm()/self.layer.weight.data.norm())) 129 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 130 | if torch.any(torch.isnan(self.layer.weight.data)): 131 | logging.warning('NaN in weights') 132 | import pprint 133 | pprint.pprint(self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point) 134 | raise ValueError('NaN in weights') 135 | 136 | def free(self): 137 | self.H = None 138 | self.Losses = None 139 | self.Trace = None 140 | torch.cuda.empty_cache() 141 | utils.cleanup_memory(verbos=False) 142 | 143 | 144 | @torch.no_grad() 145 | def gptq_fwrd(model, calib_data, dev, args): 146 | ''' 147 | From GPTQ repo 148 | TODO: Make this function general to support both OPT and LLaMA models 149 | ''' 150 | logging.info('-----GPTQ Quantization-----') 151 | 152 | layers = model.blocks 153 | 154 | model = model.cuda() 155 | 156 | class Catcher(nn.Module): 157 | def __init__(self, module): 158 | super().__init__() 159 | self.module = module 160 | self.fwd_kwargs = {} 161 | self.inps = None 162 | 163 | def forward(self, inp, **kwargs): 164 | self.inps = inp.data.clone() 165 | self.fwd_kwargs = kwargs 166 | print('data collected') 167 | raise ValueError 168 | 169 | layers[0] = Catcher(layers[0]) 170 | try: 171 | model(calib_data.to(dev)) 172 | except ValueError: 173 | pass 174 | 175 | inps = layers[0].inps 176 | fwd_kwargs = layers[0].fwd_kwargs 177 | 178 | layers[0] = layers[0].module 179 | layers[0] = layers[0].cpu() 180 | model = model.cpu() 181 | 182 | del calib_data 183 | torch.cuda.empty_cache() 184 | 185 | outs = torch.zeros_like(inps) 186 | 
187 | quantizers = {} 188 | if isinstance(model, timm.models.Eva): 189 | sequential = [ 190 | ['attn.q_proj.module', 'attn.k_proj.module', 'attn.v_proj.module'], 191 | ['attn.proj.module'], 192 | ['mlp.fc1_g.module', 'mlp.fc1_x.module'], 193 | ['mlp.fc2.module'] 194 | ] 195 | elif isinstance(model, timm.models.VisionTransformer): 196 | sequential = [ 197 | ['attn.qkv.module'], 198 | ['attn.proj.module'], 199 | ['mlp.fc1.module'], 200 | ['mlp.fc2.module'] 201 | ] 202 | else: 203 | raise NotImplementedError 204 | 205 | for i in range(len(layers)): 206 | print(f'\nLayer {i}:', flush=True, end=' ') 207 | layer = layers[i].to(dev) 208 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear]) 209 | for names in sequential: 210 | subset = {n: full[n] for n in names} 211 | 212 | gptq = {} 213 | for name in subset: 214 | print(f'{name}', end=' ', flush=True) 215 | layer_weight_bits = args.w_bits 216 | layer_weight_sym = not (args.w_asym) 217 | if 'head' in name: 218 | layer_weight_bits = 16 219 | continue 220 | gptq[name] = GPTQ(subset[name], cls_token=0) 221 | gptq[name].quantizer = quant_utils.WeightQuantizer() 222 | gptq[name].quantizer.configure( 223 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip 224 | ) 225 | 226 | def add_batch(name): 227 | def tmp(_, inp, out): 228 | gptq[name].add_batch(inp[0].data, out.data) 229 | 230 | return tmp 231 | 232 | handles = [] 233 | for name in subset: 234 | handles.append(subset[name].register_forward_hook(add_batch(name))) 235 | outs = layer(inps, **fwd_kwargs) # forward calibration data 236 | for h in handles: 237 | h.remove() 238 | 239 | for name in subset: 240 | layer_w_groupsize = args.w_groupsize 241 | gptq[name].fasterquant( 242 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False 243 | ) 244 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 245 | gptq[name].free() 246 | 247 | outs = layer(inps, **fwd_kwargs) 248 | 249 | 
layers[i] = layer.cpu() 250 | del layer 251 | del gptq 252 | torch.cuda.empty_cache() 253 | 254 | inps, outs = outs, inps 255 | 256 | utils.cleanup_memory(verbos=True) 257 | logging.info('-----GPTQ Quantization Done-----\n') 258 | return quantizers 259 | 260 | 261 | @torch.no_grad() 262 | def rtn_fwrd(model, dev, args): 263 | ''' 264 | From GPTQ repo 265 | TODO: Make this function general to support both OPT and LLaMA models 266 | ''' 267 | assert args.w_groupsize == -1, "Groupsize not supported in RTN!" 268 | layers = model.blocks 269 | torch.cuda.empty_cache() 270 | 271 | quantizers = {} 272 | 273 | for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"): 274 | layer = layers[i].to(dev) 275 | 276 | subset = quant_utils.find_qlayers(layer, 277 | layers=[torch.nn.Linear]) 278 | 279 | for name in subset: 280 | layer_weight_bits = args.w_bits 281 | if 'head' in name: 282 | layer_weight_bits = 16 283 | continue 284 | 285 | quantizer = quant_utils.WeightQuantizer() 286 | quantizer.configure( 287 | layer_weight_bits, perchannel=True, sym=not (args.w_asym), mse=args.w_clip 288 | ) 289 | W = subset[name].weight.data 290 | quantizer.find_params(W) 291 | subset[name].weight.data = quantizer.quantize(W).to( 292 | next(iter(layer.parameters())).dtype) 293 | quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu() 294 | layers[i] = layer.cpu() 295 | torch.cuda.empty_cache() 296 | del layer 297 | 298 | utils.cleanup_memory(verbos=True) 299 | return quantizers 300 | -------------------------------------------------------------------------------- /vit_quant/main.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import torch 3 | import model_utils 4 | import data_utils 5 | import eval_utils 6 | import quant_utils 7 | import gptq_utils 8 | import gptaq_utils 9 | 10 | 11 | def configure_act_quantizer(model, args): 12 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper]) 13 | 
down_proj_groupsize = -1 14 | if args.a_groupsize > 0 and "llama" in args.model: 15 | down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize) 16 | 17 | for name in qlayers: 18 | layer_input_bits = args.a_bits 19 | layer_groupsize = args.a_groupsize 20 | layer_a_sym = not (args.a_asym) 21 | layer_a_clip = args.a_clip_ratio 22 | 23 | if 'head' in name: # Skip lm_head quantization 24 | layer_input_bits = 16 25 | 26 | if 'down_proj' in name: # Set the down_proj precision 27 | if args.int8_down_proj: 28 | layer_input_bits = 8 29 | layer_groupsize = down_proj_groupsize 30 | 31 | qlayers[name].quantizer.configure(bits=layer_input_bits, 32 | groupsize=layer_groupsize, 33 | sym=layer_a_sym, 34 | clip_ratio=layer_a_clip) 35 | 36 | 37 | def main(): 38 | args = utils.parser_gen() 39 | if args.wandb: 40 | import wandb 41 | wandb.init(project=args.wandb_project, entity=args.wandb_id) 42 | wandb.config.update(args) 43 | 44 | torch.manual_seed(args.seed) 45 | model = model_utils.get_model(args.model) 46 | if args.eval_fp: 47 | eval_utils.test(args, model) 48 | quant_utils.add_actquant(model) # Add Activation Wrapper to the model as the rest of the code assumes it is present 49 | 50 | # Add Input Quantization 51 | if args.a_bits < 16 and args.enable_aq_calibration: 52 | configure_act_quantizer(model, args) 53 | 54 | if args.w_bits < 16: 55 | save_dict = {} 56 | if args.load_qmodel_path: # Load Quantized Rotated Model 57 | assert args.rotate, "Model should be rotated to load a quantized model!" 58 | assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!" 
59 | print("Load quantized model from ", args.load_qmodel_path) 60 | save_dict = torch.load(args.load_qmodel_path) 61 | model.load_state_dict(save_dict["model"]) 62 | 63 | elif not args.w_rtn: # GPTQ Weight Quantization 64 | 65 | calib_data = data_utils.get_calibration_data( 66 | args.cal_dataset, num_data=args.nsamples, model=model, 67 | ) 68 | 69 | if args.asym_calibrate: 70 | quantizers = gptaq_utils.gptaq_fwrd(model, calib_data, utils.DEV, args) 71 | save_dict["w_quantizers"] = quantizers 72 | else: 73 | quantizers = gptq_utils.gptq_fwrd(model, calib_data, utils.DEV, args) 74 | save_dict["w_quantizers"] = quantizers 75 | else: # RTN Weight Quantization 76 | quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args) 77 | save_dict["w_quantizers"] = quantizers 78 | 79 | if args.save_qmodel_path: 80 | save_dict["model"] = model.state_dict() 81 | torch.save(save_dict, args.save_qmodel_path) 82 | 83 | if args.a_bits < 16 and not args.enable_aq_calibration: 84 | configure_act_quantizer(model, args) 85 | 86 | # Evaluating on dataset 87 | eval_utils.test(args, model) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /vit_quant/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | # import utils 4 | import timm 5 | import os 6 | import logging 7 | 8 | 9 | def get_model(name, ): 10 | """ 11 | Get a vision transformer model. 12 | This will replace matrix multiplication operations with matmul modules in the model. 
13 | 14 | Currently support almost all models in timm.models.transformers, including: 15 | - vit_tiny/small/base/large_patch16/patch32_224/384, 16 | - deit_tiny/small/base(_distilled)_patch16_224, 17 | - deit_base(_distilled)_patch16_384, 18 | - swin_tiny/small/base/large_patch4_window7_224, 19 | - swin_base/large_patch4_window12_384 20 | 21 | These models are finetuned on imagenet-1k and should use ViTImageNetLoaderGenerator 22 | for calibration and testing. 23 | """ 24 | net = timm.create_model(name, pretrained=True) 25 | 26 | net.cuda() 27 | net.eval() 28 | return net 29 | 30 | 31 | 32 | def stem_layer_forward(model, calib_data, dev): 33 | if isinstance(model, timm.models.VisionTransformer): 34 | model.patch_embed = model.patch_embed.to(dev) 35 | model.pos_drop = model.pos_drop.to(dev) 36 | 37 | inps = model.patch_embed(calib_data) 38 | inps = model._pos_embed(inps) 39 | inps = model.patch_drop(inps) 40 | inps = model.norm_pre(inps) 41 | 42 | dtype = next(iter(model.parameters())).dtype 43 | 44 | model.patch_embed = model.patch_embed.cpu() 45 | model.pos_drop = model.pos_drop.cpu() 46 | kwargs = {} 47 | elif isinstance(model, timm.models.Eva): 48 | model.patch_embed = model.patch_embed.to(dev) 49 | model._pos_embed = model._pos_embed.to(dev) 50 | 51 | inps = model.patch_embed(calib_data) 52 | inps, rot_pos_embed = model._pos_embed(inps) 53 | inps = model.norm_pre(inps) 54 | 55 | dtype = next(iter(model.parameters())).dtype 56 | 57 | model.patch_embed = model.patch_embed.cpu() 58 | model.pos_drop = model.pos_drop.cpu() 59 | 60 | return model, inps, kwargs -------------------------------------------------------------------------------- /vit_quant/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | export CUDA_VISIBLE_DEVICES=$gpu_id 5 | 6 | 7 | python main.py --model eva02_large_patch14_448.mim_m38m_ft_in22k_in1k \ 8 | --w_bits 4 \ 9 | --w_groupsize -1 \ 10 | --w_clip \ 11 | --a_bits 4 \ 12 
| --nsamples 128 \ 13 | --a_asym \ 14 | --w_asym \ 15 | --percdamp 0.1 \ 16 | --act_order \ 17 | --bsz 256 \ 18 | --asym_calibrate \ 19 | --enable_aq_calibration \ 20 | -------------------------------------------------------------------------------- /vit_quant/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pprint 3 | import torch 4 | import random 5 | import numpy as np 6 | import os 7 | from datetime import datetime 8 | import logging 9 | 10 | 11 | supported_models = [ 12 | 'deit_base_patch16_224', 13 | 'deit_small_patch16_224', 14 | 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 15 | ] 16 | supported_datasets = ['imagenet'] 17 | 18 | # These flags disable using TensorFloat-32 tensor cores (to avoid numerical issues) 19 | torch.backends.cuda.matmul.allow_tf32 = False 20 | torch.backends.cudnn.allow_tf32 = False 21 | DEV = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 22 | 23 | 24 | def llama_down_proj_groupsize(model, groupsize): 25 | assert groupsize > 1, 'groupsize should be greater than 1!' 26 | 27 | if model.config.intermediate_size % groupsize == 0: 28 | logging.info(f'(Act.) Groupsiz = Down_proj Groupsize: {groupsize}') 29 | return groupsize 30 | 31 | group_num = int(model.config.hidden_size / groupsize) 32 | assert groupsize * group_num == model.config.hidden_size, 'Invalid groupsize for llama!' 33 | 34 | down_proj_groupsize = model.config.intermediate_size // group_num 35 | assert down_proj_groupsize * group_num == model.config.intermediate_size, 'Invalid groupsize for down_proj!' 36 | logging.info(f'(Act.) Groupsize: {groupsize}, Down_proj Groupsize: {down_proj_groupsize}') 37 | return down_proj_groupsize 38 | 39 | 40 | def set_seed(seed): 41 | np.random.seed(seed) 42 | torch.random.manual_seed(seed) 43 | random.seed(seed) 44 | 45 | 46 | # Dump the log both to console and a log file. 
47 | def config_logging(log_file: str, level: int = logging.INFO) -> None:  # Send root logging to the console and to log_file; INFO records print the bare message, other levels get a "LEVEL: " prefix.
48 |     class LogFormatter(logging.Formatter):
49 |         def format(self, record: logging.LogRecord) -> str:
50 |             if record.levelno == logging.INFO:  # the format string is swapped per record, on this formatter instance
51 |                 self._style._fmt = "%(message)s"
52 |             else:
53 |                 self._style._fmt = "%(levelname)s: %(message)s"
54 |             return super().format(record)
55 | 
56 |     console_handler = logging.StreamHandler()
57 |     console_handler.setFormatter(LogFormatter())
58 | 
59 |     file_handler = logging.FileHandler(log_file)
60 |     file_handler.setFormatter(LogFormatter())
61 | 
62 |     logging.basicConfig(level=level, handlers=[console_handler, file_handler])  # NOTE: per stdlib docs, basicConfig does nothing if the root logger already has handlers, so only the first call takes effect
63 | 
64 | 
65 | def parser_gen() -> argparse.Namespace:  # Parse CLI args, create <dir of this file>/experiments/<model>/<save_name>/, log there, validate and return the args.
66 |     parser = argparse.ArgumentParser()
67 | 
68 |     # General Arguments
69 |     parser.add_argument('--model', type=str, default='deit_base_patch16_224',
70 |                         help='Model to load;', choices=supported_models)
71 |     parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch')
72 |     parser.add_argument('--eval_dataset', type=str, default='imagenet',
73 |                         help='Dataset for Evaluation (default: imagenet)', choices=supported_datasets, )
74 |     parser.add_argument('--bsz', type=int, default=256,
75 |                         help='Batch-size for PPL evaluation (default:256)')  # NOTE(review): help mentions PPL, but the only eval dataset offered here is imagenet
76 |     parser.add_argument('--eval_fp', action=argparse.BooleanOptionalAction, default=False,
77 |                         help='whether to evaluate the FP16 model (default: False)')
78 | 
79 |     # Activation Quantization Arguments
80 |     parser.add_argument('--a_bits', type=int, default=16,
81 |                         help='''Number of bits for inputs of the Linear layers. This will be
82 |                         for all the linear layers in the model (including down-projection and out-projection)''')
83 |     parser.add_argument('--a_groupsize', type=int, default=-1,
84 |                         help='Groupsize for activation quantization. Note that this should be the same as w_groupsize')
85 |     parser.add_argument('--a_asym', action=argparse.BooleanOptionalAction, default=False,
86 |                         help='ASymmetric Activation quantization (default: False)')
87 |     parser.add_argument('--a_clip_ratio', type=float, default=1.0,
88 |                         help='Clip ratio for activation quantization. new_max = max * clip_ratio')
89 |     parser.add_argument('--enable_aq_calibration', action=argparse.BooleanOptionalAction, default=False,
90 |                         help='Activation quantization during GPTQ (default: False)')
91 | 
92 |     # Weight Quantization Arguments
93 |     parser.add_argument('--w_bits', type=int, default=16,
94 |                         help='Number of bits for weights of the Linear layers')
95 |     parser.add_argument('--w_groupsize', type=int, default=-1,
96 |                         help='Groupsize for weight quantization. Note that this should be the same as a_groupsize')
97 |     parser.add_argument('--w_asym', action=argparse.BooleanOptionalAction, default=False,
98 |                         help='ASymmetric weight quantization (default: False)')
99 |     parser.add_argument('--w_rtn', action=argparse.BooleanOptionalAction, default=False,
100 |                         help='Quantize the weights using RtN. If the w_bits < 16 and this flag is not set, we use GPTQ')
101 |     parser.add_argument('--w_clip', action=argparse.BooleanOptionalAction, default=False,
102 |                         help='''Clipping the weight quantization!
103 |                         We do not support arguments for clipping and we find the best clip ratio during the weight quantization''')
104 |     parser.add_argument('--nsamples', type=int, default=128,
105 |                         help='Number of calibration data samples for GPTQ.')
106 |     parser.add_argument('--cal_dataset', type=str, default='imagenet',
107 |                         help='calibration data samples for GPTQ.', choices=supported_datasets)
108 |     parser.add_argument('--percdamp', type=float, default=.01,
109 |                         help='Percent of the average Hessian diagonal to use for dampening.')
110 |     parser.add_argument('--act_order', action=argparse.BooleanOptionalAction, default=False,
111 |                         help='act-order in GPTQ')
112 |     parser.add_argument('--asym_calibrate', action=argparse.BooleanOptionalAction, default=False,
113 |                         help='whether to use GPTAQ')
114 | 
115 |     # Save/Load Quantized Model Arguments
116 |     parser.add_argument('--load_qmodel_path', type=str, default=None,
117 |                         help='Load the quantized model from the specified path!')
118 |     parser.add_argument('--save_qmodel_path', type=str, default=None,
119 |                         help='Save the quantized model to the specified path!')
120 | 
121 |     # WandB Arguments
122 |     parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False)
123 |     parser.add_argument('--wandb_id', type=str, default=None)
124 |     parser.add_argument('--wandb_project', type=str, default=None)
125 | 
126 |     # Experiments Arguments
127 |     parser.add_argument('--save_name', type=str, default=None, help='The path to save experiment data, '
128 |                                                                     'including quantized models, dumped layer inputs, etc. The data will be saved in experiments/[model]/save_name. Default: [datetime].')
129 |     parser.add_argument('--capture_layer_io', action=argparse.BooleanOptionalAction, default=False,
130 |                         help='Capture the input and output of the specified decoder layer and dump into a file')
131 |     parser.add_argument('--layer_idx', type=int, default=10, help='Which decoder layer to capture')
132 | 
133 |     parser.add_argument(
134 |         "--distribute",
135 |         action="store_true",  # unlike the other switches above (BooleanOptionalAction), this has no --no-distribute form
136 |         help="Distribute the model on multiple GPUs for evaluation.",
137 |     )
138 | 
139 |     args = parser.parse_args()
140 | 
141 |     if args.save_name is None:  # default run name is the current timestamp
142 |         args.save_name = datetime.now().strftime("%Y%m%d_%H%M%S")
143 |     setattr(args, 'save_path',  # <directory of this file>/experiments/<model>/<save_name>
144 |             os.path.join(os.path.dirname(os.path.abspath(__file__)), 'experiments', args.model, args.save_name))
145 |     os.makedirs(args.save_path, exist_ok=True)
146 | 
147 |     config_logging(os.path.join(args.save_path, f'{args.save_name}.log'))
148 | 
149 |     assert args.a_groupsize == args.w_groupsize, 'a_groupsize should be the same as w_groupsize!'  # NOTE(review): assert-based checks are stripped under python -O
150 | 
151 |     if args.wandb:
152 |         assert args.wandb_id is not None and args.wandb_project is not None, 'WandB ID/project is not provided!'
153 | 
154 |     logging.info('Arguments: ')
155 |     logging.info(pprint.pformat(vars(args)))
156 |     logging.info('--' * 30)
157 |     return args
158 | 
159 | 
160 | def cleanup_memory(verbos: bool = True) -> None:
161 |     """Run gc.collect(); if CUDA is available, also empty its cache and, when verbos is true, log the total reserved GPU memory before/after."""
162 |     import gc
163 |     import inspect
164 |     caller_name = ''
165 |     try:
166 |         caller_name = f' (from {inspect.stack()[1].function})'  # name of the calling function, used only in the log line
167 |     except (ValueError, KeyError):
168 |         pass
169 | 
170 |     def total_reserved_mem() -> int:  # bytes reserved by the CUDA caching allocator, summed over all visible devices (0 without CUDA)
171 |         return sum(torch.cuda.memory_reserved(device=i) for i in range(torch.cuda.device_count()))
172 | 
173 |     memory_before = total_reserved_mem()
174 | 
175 |     # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed
176 |     gc.collect()
177 | 
178 |     if torch.cuda.is_available():
179 |         torch.cuda.empty_cache()
180 |         memory_after = total_reserved_mem()
181 |         if verbos:
182 |             logging.info(
183 |                 f"GPU memory{caller_name}: {memory_before / (1024 ** 3):.2f} -> {memory_after / (1024 ** 3):.2f} GB"
184 |                 f" ({(memory_after - memory_before) / (1024 ** 3):.2f} GB)"
185 |             )
186 | 
187 | 
188 | def distribute_model(model) -> None:
189 |     """Distribute the model across available GPUs. NB: only implemented for Llama-2 (never splits a 'LlamaDecoderLayer'). NOTE(review): get_balanced_memory, infer_auto_device_map and dispatch_model are not among this module's visible imports (argparse, pprint, torch, random, numpy, os, datetime, logging), so this raises NameError unless they are defined elsewhere - presumably they should come from accelerate."""
190 |     no_split_module_classes = ['LlamaDecoderLayer']  # NOTE(review): Llama-specific; the timm ViT/Eva models used in this project presumably contain no such class - confirm
191 |     max_memory = get_balanced_memory(
192 |         model,
193 |         no_split_module_classes=no_split_module_classes,
194 |     )
195 | 
196 |     device_map = infer_auto_device_map(
197 |         model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
198 |     )
199 | 
200 |     dispatch_model(
201 |         model,
202 |         device_map=device_map,
203 |         offload_buffers=True,
204 |         offload_dir="offload",
205 |         state_dict=model.state_dict(),
206 |     )
207 | 
208 |     cleanup_memory()
--------------------------------------------------------------------------------