├── .gitignore
├── README.md
├── fake_quant
│   ├── README.md
│   ├── data_utils.py
│   ├── eval_utils.py
│   ├── gptaq_utils.py
│   ├── gptq_utils.py
│   ├── hadamard_utils.py
│   ├── main.py
│   ├── model_utils.py
│   ├── monkeypatch.py
│   ├── quant_utils.py
│   ├── requirements.txt
│   ├── rotation_utils.py
│   ├── run_llama.sh
│   └── utils.py
├── img
│   └── readme_intro.png
├── spinquant
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── SpinQuant.png
│   ├── eval_utils
│   │   ├── gptaq_utils.py
│   │   ├── gptq_utils.py
│   │   ├── main.py
│   │   ├── modeling_llama.py
│   │   └── rotation_utils.py
│   ├── optimize_rotation.py
│   ├── ptq.py
│   ├── requirement.txt
│   ├── scripts
│   │   ├── 10_optimize_rotation.sh
│   │   ├── 11_optimize_rotation_fsdp.sh
│   │   ├── 2_eval_ptq.sh
│   │   ├── 31_optimize_rotation_executorch.sh
│   │   └── 32_eval_ptq_executorch.sh
│   ├── train_utils
│   │   ├── apply_r3_r4.py
│   │   ├── fsdp_trainer.py
│   │   ├── main.py
│   │   ├── modeling_llama_quant.py
│   │   ├── optimizer.py
│   │   ├── quant_linear.py
│   │   └── rtn_utils.py
│   └── utils
│       ├── convert_to_executorch.py
│       ├── data_utils.py
│       ├── eval_utils.py
│       ├── fuse_norm_utils.py
│       ├── hadamard_utils.py
│       ├── model_utils.py
│       ├── monkeypatch.py
│       ├── process_args.py
│       ├── quant_utils.py
│       └── utils.py
└── vit_quant
    ├── README.md
    ├── data_utils.py
    ├── eval_utils.py
    ├── gptaq_utils.py
    ├── gptq_utils.py
    ├── main.py
    ├── model_utils.py
    ├── quant_utils.py
    ├── run.sh
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled source #
2 | ###################
3 | *.com
4 | *.class
5 | *.dll
6 | *.exe
7 | *.o
8 | *.so
9 |
10 | # Packages #
11 | ############
12 | # it's better to unpack these files and commit the raw source
13 | # git has its own built in compression methods
14 | *.7z
15 | *.dmg
16 | *.gz
17 | *.iso
18 | *.jar
19 | *.rar
20 |
21 | # Logs and databases #
22 | ######################
23 | *.log
24 | *.sql
25 | *.sqlite
26 |
27 | # OS generated files #
28 | ######################
29 | .DS_Store
30 | .DS_Store?
31 | ._*
32 | *.idea*
33 | .Spotlight-V100
34 | .Trashes
35 | *.xml
36 | *.iml
37 | *.pyc
38 | ehthumbs.db
39 | Thumbs.db
40 |
41 | # Ignore ImageNet Pre-trained Models #
42 | ######################################
43 | *ImageNet/checkpoint/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # GPTAQ: Efficient Finetuning-Free Quantization with Asymmetric Calibration [ICML 2025]
3 |
10 | The official PyTorch implementation of GPTAQ.
11 |
12 | Unlike the previous GPTQ method, which independently calibrates each layer, we always match the quantized layer's output to the exact output of the full-precision model, a scheme we call asymmetric calibration. Such a scheme can effectively reduce the quantization error accumulated in previous layers. We analyze this problem through the lens of optimal brain compression and derive a closed-form solution that explicitly minimizes both the quantization error and the accumulated asymmetry error. Furthermore, we parallelize the computation of the solution with several techniques, including channel parallelization, neuron decomposition, and a Cholesky reformulation for matrix fusion. As a result, GPTAQ is easy to implement, requiring only about 20 more lines of code than GPTQ, yet it improves performance under low-bit quantization.
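
To make the objective concrete, here is a rough sketch in our own notation (a simplification, not a verbatim statement of the paper's formulation). For a linear layer with full-precision weights $W$, GPTQ calibrates against the quantized model's own layer input $\tilde{X}$, while GPTAQ matches the layer output of the full-precision model, whose input is $X$:

$$
\text{GPTQ:}\ \min_{\widehat{W}} \lVert \widehat{W}\tilde{X} - W\tilde{X} \rVert_F^2
\qquad\qquad
\text{GPTAQ:}\ \min_{\widehat{W}} \lVert \widehat{W}\tilde{X} - WX \rVert_F^2
$$

Expanding the asymmetric objective shows that the closed-form update depends on the usual Hessian $H = \tilde{X}\tilde{X}^{\top}$ plus a cross term $(X - \tilde{X})\tilde{X}^{\top}$; these are exactly the `H` and `dXXT` statistics accumulated in `fake_quant/gptaq_utils.py`.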
13 |
14 | ## Update: Name change to GPTAQ
15 |
16 | We are updating our code to use the new name *`GPTAQ`*.
17 |
18 | ## Update: GPTQv2 is integrated into GPTQModel
19 |
20 | The GPTQv2 method is integrated into the [GPTQModel](https://github.com/ModelCloud/GPTQModel/tree/main) library and can be enabled with a single argument (`v2=True`).
21 |
22 | You can install GPTQModel:
23 |
24 | ```shell
25 | pip install -v gptqmodel --no-build-isolation
26 | ```
27 |
28 | Quantize LLaMA-3.1-8B-Instruct:
29 |
30 | ```python
31 | import tempfile
32 |
33 | from datasets import load_dataset
34 | from gptqmodel import GPTQModel, QuantizeConfig
35 | from gptqmodel.quantization import FORMAT
36 | from gptqmodel.utils.eval import EVAL
37 | from logbar import LogBar
38 |
39 | log = LogBar.shared()
40 |
41 | MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
42 | CFG_BITS = 4
43 | CFG_GROUPSIZE = 128
44 | CFG_V2 = True
45 | INPUTS_MAX_LENGTH = 2048 # in tokens
46 | QUANT_SAVE_PATH = f"/your_path/gptq_v2_{CFG_V2}_bit_{CFG_BITS}_gpsize_{CFG_GROUPSIZE}_llama_3.1_8B_Instruct"
47 |
48 | def get_calib_data(tokenizer, rows: int):
49 |
50 | calibration_dataset = load_dataset(
51 | "json",
52 | data_files="/your_path/dataset/c4-train.00000-of-01024.json.gz",
53 | split="train")
54 |
55 | datas = []
56 | for index, sample in enumerate(calibration_dataset):
57 | tokenized = tokenizer(sample["text"])
58 | if len(tokenized.data['input_ids']) <= INPUTS_MAX_LENGTH:
59 | datas.append(tokenized)
60 | if len(datas) >= rows:
61 | break
62 |
63 | return datas
64 |
65 | quant_config = QuantizeConfig(
66 | bits=CFG_BITS,
67 | group_size=CFG_GROUPSIZE,
68 | format=FORMAT.GPTQ,
69 | desc_act=True,
70 | sym=True,
71 | v2=CFG_V2,
72 | )
73 |
74 | log.info(f"QuantConfig: {quant_config}")
75 | log.info(f"Save Path: {QUANT_SAVE_PATH}")
76 |
77 | # load un-quantized native model
78 | model = GPTQModel.load(MODEL_ID, quant_config)
79 |
80 | # load calibration data
81 | calibration_dataset = get_calib_data(tokenizer=model.tokenizer, rows=256)
82 |
83 | model.quantize(calibration_dataset, batch_size=1)
84 |
85 | model.save(QUANT_SAVE_PATH)
86 | log.info(f"Quant Model Saved to: {QUANT_SAVE_PATH}")
87 | ```
88 |
89 | Evaluation on ARC-Challenge and GSM8K (Platinum, CoT):
90 |
91 | ```python
92 | # eval
93 | from lm_eval.tasks import TaskManager
94 | from lm_eval.utils import make_table
95 |
96 | with tempfile.TemporaryDirectory() as tmp_dir:
97 | results = GPTQModel.eval(
98 | QUANT_SAVE_PATH,
99 | tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.GSM8K_PLATINUM_COT],
100 | apply_chat_template=True,
101 | random_seed=898,
102 | output_path=tmp_dir,
103 | )
104 |
105 | print(make_table(results))
106 | if "groups" in results:
107 | print(make_table(results, "groups"))
108 | ```
109 |
110 |
111 | Performance comparison (GPTQv2 outperforms GPTQ on GSM8K using 1 fewer bit):
112 |
113 |
114 | v1 ([checkpoints](https://huggingface.co/ModelCloud/GPTQ-v1-Llama-3.1-8B-Instruct)):
116 |
117 |
118 | | Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
119 | |------------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
120 | |arc_challenge| 1|none | 0|acc |↑ |0.5000|± |0.0146|
121 | | | |none | 0|acc_norm|↑ |0.5128|± |0.0146|
122 | |gsm8k_platinum_cot| 3|flexible-extract| 8|exact_match|↑ |0.3995|± |0.0141|
123 | | | |strict-match | 8|exact_match|↑ |0.2548|± |0.0125|
124 |
125 |
126 | v2 ([checkpoints](https://huggingface.co/ModelCloud/GPTQ-v2-Llama-3.1-8B-Instruct)):
127 |
128 | | Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
129 | |------------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
130 | |arc_challenge| 1|none | 0|acc |↑ |0.5034|± |0.0146|
131 | | | |none | 0|acc_norm|↑ |0.5068|± |0.0146|
132 | |gsm8k_platinum_cot| 3|flexible-extract| 8|exact_match|↑ |0.7601|± |0.0123|
133 | | | |strict-match | 8|exact_match|↑ |0.5211|± |0.0144|
134 |
135 |
136 |
137 |
138 |
139 | ## Code Structure
140 |
141 | We provide several directories to reproduce the paper results.
142 |
143 | 1. [**fake_quant**](./fake_quant) for reproducing QuaRot+GPTQ/GPTAQ
144 | 2. [**spinquant**](./spinquant) for reproducing SpinQuant+GPTQ/GPTAQ
145 | 3. [**vit_quant**](./vit_quant) for reproducing vision transformer quantization results
146 |
147 | [//]: # (4. **GPTQModel**, a forked version of GPTQModel to support GPTQv2 to deploy weight-only quantization model )
148 |
149 | We recommend using separate environments for the different experiments to ensure the results match.
150 |
151 |
152 | ## Acknowledgement
153 |
154 | Our code is built upon several repositories:
155 |
156 | [https://github.com/IST-DASLab/gptq](https://github.com/IST-DASLab/gptq)
157 |
158 | [https://github.com/spcl/QuaRot](https://github.com/spcl/QuaRot)
159 |
160 | [https://github.com/facebookresearch/SpinQuant/tree/main](https://github.com/facebookresearch/SpinQuant/tree/main)
161 |
162 |
163 | ## Star History
164 | [](https://star-history.com/#Intelligent-Computing-Lab-Yale/GPTQv2)
165 |
166 | ## Contact
167 |
168 | Yuhang Li (*yuhang.li@yale.edu*)
169 |
170 | ## Citations
171 |
172 | If you find our work useful, please consider giving a star and citation:
173 | ```bibtex
174 | @article{li2025gptqv2,
175 | title={GPTAQ: Efficient Finetuning-Free Quantization for Asymmetric Calibration},
176 | author={Yuhang Li and Ruokai Yin and Donghyun Lee and Shiting Xiao and Priyadarshini Panda},
177 | year={2025},
178 | journal={arXiv preprint arXiv:2504.02692},
179 | }
180 | ```
--------------------------------------------------------------------------------
/fake_quant/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Fake Quantization with QuaRot
4 |
5 | This code is developed based on [QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs](https://github.com/spcl/QuaRot).
6 |
7 | ## Installation
8 |
9 | We recommend installing the Python environment with the original codebase's requirements:
10 |
11 | ```bash
12 | conda create -n quarot python=3.9
13 | conda activate quarot
14 | pip install -r requirements.txt
15 | ```
16 | Additionally, to apply the Hadamard transformation, build the [fast-hadamard-transform](https://github.com/Dao-AILab/fast-hadamard-transform) package from source.
17 |
18 |
19 | ## Language Generation and Zero-Shot Evaluations
20 |
21 | Currently, this code supports **LLaMA-2 and LLaMA-3** models (we did not test OPT and LLaMA-1).
22 | You can simply run `main.py` to reproduce the results in the paper. The important arguments are:
23 |
24 | - `--model`: the model name (or path to the weights)
25 | - `--bsz`: the batch size for PPL evaluation
26 | - `--rotate`: whether we want to rotate the model
27 | - `--lm_eval`: whether we want to run LM-Eval for Zero-Shot tasks
28 | - `--tasks`: the tasks for LM-Eval
29 | - `--cal_dataset`: the calibration dataset for GPTQ quantization
30 | - `--a_bits`: the number of bits for activation quantization
31 | - `--w_bits`: the number of bits for weight quantization
32 | - `--v_bits`: the number of bits for value quantization
33 | - `--k_bits`: the number of bits for key quantization
34 | - `--w_clip`: Whether we want to clip the weights
35 | - `--a_clip_ratio`: The ratio of clipping for activation
36 | - `--k_clip_ratio`: The ratio of clipping for key
37 | - `--v_clip_ratio`: The ratio of clipping for value
38 | - `--w_asym`: Whether we want to use asymmetric quantization for weights
39 | - `--a_asym`: Whether we want to use asymmetric quantization for activation
40 | - `--v_asym`: Whether we want to use asymmetric quantization for value
41 | - `--k_asym`: Whether we want to use asymmetric quantization for key
42 | - `--a_groupsize`: The group size for activation quantization
43 | - `--w_groupsize`: The group size for weight quantization
44 | - `--v_groupsize`: The group size for value quantization
45 | - `--k_groupsize`: The group size for key quantization
46 | - `--use_v2`: Turn on GPTQv2 quantization (recommended)
47 | - `--enable_aq_calibration`: Enable activation quantization during calibration (recommended)
48 |
49 |
50 | We provide a script `run_llama.sh` to reproduce the results.
51 |
52 |
--------------------------------------------------------------------------------
/fake_quant/data_utils.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import random
3 | import transformers
4 |
5 | def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False):
6 |
7 | if hf_token is None:
8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
9 | else:
10 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
11 |
12 | if eval_mode:
13 | testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
14 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
15 | return testenc
16 | else:
17 | traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
18 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
19 | random.seed(seed)
20 | trainloader = []
21 | for _ in range(nsamples):
22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
23 | j = i + seqlen
24 | inp = trainenc.input_ids[:, i:j]
25 | tar = inp.clone()
26 | tar[:, :-1] = -100
27 | trainloader.append((inp, tar))
28 | return trainloader
29 |
30 | def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False):
31 |
32 | if hf_token is None:
33 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
34 | else:
35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
36 |
37 | if eval_mode:
38 | valdata = datasets.load_dataset(
39 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
40 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
41 | valenc = valenc.input_ids[:, :(256 * seqlen)]
42 | class TokenizerWrapper:
43 | def __init__(self, input_ids):
44 | self.input_ids = input_ids
45 | valenc = TokenizerWrapper(valenc)
46 | return valenc
47 | else:
48 | traindata = datasets.load_dataset(
49 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
50 |
51 | random.seed(seed)
52 | trainloader = []
53 | for _ in range(nsamples):
54 | while True:
55 | i = random.randint(0, len(traindata) - 1)
56 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
57 | if trainenc.input_ids.shape[1] >= seqlen:
58 | break
59 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
60 | j = i + seqlen
61 | inp = trainenc.input_ids[:, i:j]
62 | tar = inp.clone()
63 | tar[:, :-1] = -100
64 | trainloader.append((inp, tar))
65 | return trainloader
66 |
67 |
68 |
69 | def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False):
70 |
71 |
72 | if hf_token is None:
73 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
74 | else:
75 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
76 |
77 | if eval_mode:
78 | testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test')
79 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
80 | return testenc
81 | else:
82 | traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train')
83 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
84 | random.seed(seed)
85 | trainloader = []
86 | for _ in range(nsamples):
87 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
88 | j = i + seqlen
89 | inp = trainenc.input_ids[:, i:j]
90 | tar = inp.clone()
91 | tar[:, :-1] = -100
92 | trainloader.append((inp, tar))
93 | return trainloader
94 |
95 |
96 | def get_loaders(
97 | name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False
98 | ):
99 | if 'wikitext2' in name:
100 | return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode)
101 | if 'ptb' in name:
102 | return get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
103 | if 'c4' in name:
104 | return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
--------------------------------------------------------------------------------
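
A minimal usage sketch of `get_loaders` (the model name, sample count, and sequence length below are illustrative placeholders, not values fixed by this file):

```python
# Hypothetical usage of data_utils.get_loaders; model name and sizes are placeholders.
import data_utils

# Calibration mode: a list of (input_ids, target) pairs, each of shape (1, seqlen).
trainloader = data_utils.get_loaders(
    'wikitext2', nsamples=128, seed=0, seqlen=2048,
    model='meta-llama/Llama-2-7b-hf', eval_mode=False,
)
inp, tar = trainloader[0]
print(inp.shape)  # torch.Size([1, 2048])

# Evaluation mode: the whole test split tokenized as one long stream.
testenc = data_utils.get_loaders(
    'wikitext2', seqlen=2048, model='meta-llama/Llama-2-7b-hf', eval_mode=True,
)
print(testenc.input_ids.shape)  # (1, total_test_tokens)
```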
/fake_quant/eval_utils.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import model_utils
3 | import quant_utils
4 | import torch
5 | import os
6 | import logging
7 | from tqdm import tqdm
8 |
9 |
10 | @torch.no_grad()
11 | def evaluator(model, testenc, dev, args):
12 |
13 | model.eval()
14 |
15 | if 'opt' in args.model:
16 | opt_type = True
17 | llama_type = False
18 | elif 'meta' in args.model:
19 | llama_type = True
20 | opt_type = False
21 | else:
22 | raise ValueError(f'Unknown model {args.model}')
23 |
24 | use_cache = model.config.use_cache
25 | model.config.use_cache = False
26 |
27 | if opt_type:
28 | layers = model.model.decoder.layers
29 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
30 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
31 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
32 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
33 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
34 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
35 |
36 | elif llama_type:
37 | layers = model.model.layers
38 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
39 |
40 | layers[0] = layers[0].to(dev)
41 |
42 | # Convert the whole text of evaluation dataset into batches of sequences.
43 | input_ids = testenc.input_ids # (1, text_len)
44 | nsamples = input_ids.numel() // model.seqlen # The tail is truncated.
45 | input_ids = input_ids[:, :nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev) # (nsamples, seqlen)
46 |
47 | batch_size = args.bsz
48 | input_ids = [input_ids[i:i + batch_size] for i in range(0, nsamples, batch_size)]
49 | nbatches = len(input_ids)
50 |
51 | # The input of the first decoder layer, one entry per batch (filled by the Catcher below).
52 | inps = [0] * nbatches
57 | cache = {'i': 0, 'attention_mask': None}
58 | class Catcher(torch.nn.Module):
59 | def __init__(self, module):
60 | super().__init__()
61 | self.module = module
62 | def forward(self, inp, **kwargs):
63 | inps[cache['i']] = inp
64 | cache['i'] += 1
65 | cache['attention_mask'] = kwargs['attention_mask']
66 | if llama_type:
67 | cache['position_ids'] = kwargs['position_ids']
68 | raise ValueError
69 | layers[0] = Catcher(layers[0])
70 |
71 | for i in range(nbatches):
72 | batch = input_ids[i]
73 | try:
74 | model(batch)
75 | except ValueError:
76 | pass
77 | layers[0] = layers[0].module
78 | layers[0] = layers[0].cpu()
79 |
80 | if opt_type:
81 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
82 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
83 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
84 | model.model.decoder.project_out = model.model.decoder.project_out.cpu()
85 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
86 | model.model.decoder.project_in = model.model.decoder.project_in.cpu()
87 | elif llama_type:
88 | model.model.embed_tokens = model.model.embed_tokens.cpu()
89 | position_ids = cache['position_ids']
90 |
91 | torch.cuda.empty_cache()
92 | outs = [0] * nbatches
93 | attention_mask = cache['attention_mask']
94 |
95 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"):
96 | layer = layers[i].to(dev)
97 |
98 | # Dump the layer input and output
99 | if args.capture_layer_io and args.layer_idx == i:
100 | captured_io = model_utils.capture_layer_io(model_utils.get_model_type(model), layer, inps)
101 | save_path = model_utils.get_layer_io_save_path(args)
102 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
103 | torch.save(captured_io, save_path)
104 | logging.info(f'Dumped layer input and output to: {save_path}')
105 |
106 | for j in range(nbatches):
107 | if opt_type:
108 | outs[j] = layer(inps[j], attention_mask=attention_mask)[0]
109 | elif llama_type:
110 | outs[j] = layer(inps[j], attention_mask=attention_mask, position_ids=position_ids)[0]
111 | layers[i] = layer.cpu()
112 | del layer
113 | torch.cuda.empty_cache()
114 | inps, outs = outs, inps
115 |
116 | if opt_type:
117 | if model.model.decoder.final_layer_norm is not None:
118 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
119 | if model.model.decoder.project_out is not None:
120 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
121 |
122 | elif llama_type:
123 | if model.model.norm is not None:
124 | model.model.norm = model.model.norm.to(dev)
125 |
126 | model.lm_head = model.lm_head.to(dev)
127 | nlls = []
128 | loss_fct = torch.nn.CrossEntropyLoss(reduction = "none")
129 | for i in range(nbatches):
130 | hidden_states = inps[i]
131 | if opt_type:
132 | if model.model.decoder.final_layer_norm is not None:
133 | hidden_states = model.model.decoder.final_layer_norm(hidden_states)
134 | if model.model.decoder.project_out is not None:
135 | hidden_states = model.model.decoder.project_out(hidden_states)
136 | elif llama_type:
137 | if model.model.norm is not None:
138 | hidden_states = model.model.norm(hidden_states)
139 | lm_logits = model.lm_head(hidden_states)
140 | shift_logits = lm_logits[:, :-1, :]
141 | shift_labels = input_ids[i][:, 1:]
142 | loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels)
143 | neg_log_likelihood = loss.float().mean(dim=1)
144 | nlls.append(neg_log_likelihood)
145 | nlls_tensor = torch.cat(nlls)
146 | ppl = torch.exp(nlls_tensor.mean())
147 | model.config.use_cache = use_cache
148 | logging.info(f'\n{args.eval_dataset.upper()} PPL: {ppl.item():.3f}')
149 | return ppl.item()
150 |
--------------------------------------------------------------------------------
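
For reference, the value returned by `evaluator` is the standard token-level perplexity: with $\mathrm{NLL}_i$ the per-sequence mean negative log-likelihood collected in `nlls`,

$$
\mathrm{PPL} = \exp\!\Big(\tfrac{1}{N}\sum_{i=1}^{N} \mathrm{NLL}_i\Big),
$$

which is what `torch.exp(nlls_tensor.mean())` computes (all sequences have length `model.seqlen`, so this equals the exponential of the average per-token loss).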
/fake_quant/gptaq_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import utils
7 | import quant_utils
8 | import model_utils
9 | import logging
10 | import functools
11 |
12 | torch.backends.cuda.matmul.allow_tf32 = False
13 | torch.backends.cudnn.allow_tf32 = False
14 |
15 | class GPTAQ:
16 |
17 | def __init__(self, layer):
18 | self.layer = layer
19 | self.dev = self.layer.weight.device
20 | W = layer.weight.data.clone()
21 | self.rows = W.shape[0]
22 | self.columns = W.shape[1]
23 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
24 | self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
25 | self.nsamples = 0
26 | self.fp_inp = []
27 |
28 | def add_batch(self, inp, out):
29 |
30 | if len(inp.shape) == 2:
31 | inp = inp.unsqueeze(0)
32 | tmp = inp.shape[0]
33 | if len(inp.shape) == 3:
34 | inp = inp.reshape((-1, inp.shape[-1]))
35 |
36 | inp = inp.t()
37 |
38 | self.H *= self.nsamples / (self.nsamples + tmp)
39 | self.dXXT *= self.nsamples / (self.nsamples + tmp)
40 | self.nsamples += tmp
41 | inp = math.sqrt(2 / self.nsamples) * inp.float()
42 | self.H += inp.matmul(inp.t())
43 | dX = self.fp_inp[0].float() * math.sqrt(2 / self.nsamples) - inp
44 | self.dXXT += dX.matmul(inp.t())
45 |
46 | del self.fp_inp[0]
47 |
48 | def fasterquant(
49 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False, alpha=0.25
50 | ):
51 | W = self.layer.weight.data.clone()
52 | W = W.float()
53 |
54 | if not self.quantizer.ready():
55 | self.quantizer.find_params(W)
56 |
57 | H = self.H
58 | del self.H
59 | dead = torch.diag(H) == 0
60 | H[dead, dead] = 1
61 | W[:, dead] = 0
62 | self.dXXT[:, dead] = 0
63 |
64 | if static_groups:
65 | import copy
66 | groups = []
67 | for i in range(0, self.columns, groupsize):
68 | quantizer = copy.deepcopy(self.quantizer)
69 | quantizer.find_params(W[:, i:(i + groupsize)])
70 | groups.append(quantizer)
71 |
72 | if actorder:
73 | perm = torch.argsort(torch.diag(H), descending=True)
74 | W = W[:, perm]
75 | H = H[perm][:, perm]
76 | self.dXXT = self.dXXT[perm][:, perm]
77 | invperm = torch.argsort(perm)
78 |
79 | Losses = torch.zeros_like(W)
80 | Q = torch.zeros_like(W)
81 |
82 | damp = percdamp * torch.mean(torch.diag(H))
83 | diag = torch.arange(self.columns, device=self.dev)
84 | H[diag, diag] += damp
85 | Hinv = torch.linalg.cholesky(H)
86 | Hinv = torch.cholesky_inverse(Hinv)
87 | Hinv = torch.linalg.cholesky(Hinv, upper=True)
88 |
89 | # scale by alpha due to the joint collection of dXXT and H
90 | P = alpha * ((self.dXXT @ Hinv.T).triu_(diagonal=1)) @ Hinv
91 | del self.dXXT
92 |
93 | for i1 in range(0, self.columns, blocksize):
94 | i2 = min(i1 + blocksize, self.columns)
95 | count = i2 - i1
96 |
97 | W1 = W[:, i1:i2].clone()
98 | Q1 = torch.zeros_like(W1)
99 | Err1 = torch.zeros_like(W1)
100 | Losses1 = torch.zeros_like(W1)
101 | Hinv1 = Hinv[i1:i2, i1:i2]
102 | P1 = P[i1:i2, i1:i2]
103 |
104 | for i in range(count):
105 | w = W1[:, i]
106 | d = Hinv1[i, i]
107 |
108 | if groupsize != -1:
109 | if not static_groups:
110 | if (i1 + i) % groupsize == 0:
111 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
112 | else:
113 | idx = i1 + i
114 | if actorder:
115 | idx = perm[idx]
116 | self.quantizer = groups[idx // groupsize]
117 |
118 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
119 | Q1[:, i] = q
120 | Losses1[:, i] = (w - q) ** 2 / d ** 2
121 |
122 | err1 = (w - q) / d
123 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))
124 | Err1[:, i] = err1
125 |
126 | Q[:, i1:i2] = Q1
127 | Losses[:, i1:i2] = Losses1 / 2
128 |
129 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:])
130 |
131 | torch.cuda.synchronize()
132 |
133 | if actorder:
134 | Q = Q[:, invperm]
135 |
136 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
137 | if torch.any(torch.isnan(self.layer.weight.data)):
138 | logging.warning('NaN in weights')
139 | import pprint
140 | pprint.pprint((self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point))
141 | raise ValueError('NaN in weights')
142 |
143 | def free(self):
144 | self.H = None
145 | self.Losses = None
146 | self.Trace = None
147 | self.dXXT = None
148 | torch.cuda.empty_cache()
149 | utils.cleanup_memory(verbos=False)
150 |
151 |
152 | @torch.no_grad()
153 | def gptaq_fwrd(model, dataloader, dev, args):
154 | '''
155 | From GPTQ repo
156 | TODO: Make this function general to support both OPT and LLaMA models
157 | '''
158 | logging.info('-----GPTAQ Quantization-----')
159 |
160 | use_cache = model.config.use_cache
161 | model.config.use_cache = False
162 | layers = model.model.layers
163 |
164 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
165 | model.model.norm = model.model.norm.to(dev)
166 | # model.model.rotary_emb = model.model.rotary_emb.to(dev)
167 |
168 | layers[0] = layers[0].to(dev)
169 |
170 | dtype = next(iter(model.parameters())).dtype
171 | inps = torch.zeros(
172 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
173 | )
174 |
175 | cache = {'i': 0, 'attention_mask': None}
176 |
177 | class Catcher(nn.Module):
178 | def __init__(self, module):
179 | super().__init__()
180 | self.module = module
181 |
182 | def forward(self, inp, **kwargs):
183 | inps[cache['i']] = inp
184 | cache['i'] += 1
185 | cache['attention_mask'] = kwargs['attention_mask']
186 | cache['position_ids'] = kwargs['position_ids']
187 | raise ValueError
188 |
189 | layers[0] = Catcher(layers[0])
190 | for batch in dataloader:
191 | try:
192 | model(batch[0].to(dev))
193 | except ValueError:
194 | pass
195 | layers[0] = layers[0].module
196 |
197 | layers[0] = layers[0].cpu()
198 | model.model.embed_tokens = model.model.embed_tokens.cpu()
199 | model.model.norm = model.model.norm.cpu()
200 | torch.cuda.empty_cache()
201 |
202 | outs = torch.zeros_like(inps)
203 |
204 | attention_mask = cache['attention_mask']
205 | position_ids = cache['position_ids']
206 |
207 | quantizers = {}
208 | sequential = [
209 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
210 | ['self_attn.o_proj.module'],
211 | ['mlp.up_proj.module', 'mlp.gate_proj.module'],
212 | ['mlp.down_proj.module']
213 | ]
214 |
215 | fp_inputs_cache = model_utils.FPInputsCache(sequential)
216 | fp_inps = inps.clone()
217 |
218 | for i in range(len(layers)):
219 | print(f'\nLayer {i}:', flush=True, end=' ')
220 | layer = layers[i].to(dev)
221 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
222 |
223 | bits_config = quant_utils.disable_act_quant(layer)
224 | fp_inputs_cache.add_hook(full)
225 |
226 | for j in range(args.nsamples):
227 | fp_inps[j] = layer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
228 | fp_inputs_cache.clear_hook()
229 | quant_utils.enable_act_quant(layer, bits_config)
230 |
231 | for names in sequential:
232 | subset = {n: full[n] for n in names}
233 |
234 | gptq = {}
235 | for name in subset:
236 | print(f'{name}', end=' ', flush=True)
237 | layer_weight_bits = args.w_bits
238 | layer_weight_sym = not (args.w_asym)
239 | if 'lm_head' in name:
240 | layer_weight_bits = 16
241 | continue
242 | if args.int8_down_proj and 'down_proj' in name:
243 | layer_weight_bits = 8
244 | gptq[name] = GPTAQ(subset[name])
245 | gptq[name].quantizer = quant_utils.WeightQuantizer()
246 | gptq[name].quantizer.configure(
247 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
248 | )
249 | gptq[name].fp_inp = fp_inputs_cache.fp_cache[name]
250 |
251 | def add_batch(name):
252 | def tmp(_, inp, out):
253 | gptq[name].add_batch(inp[0].data, out.data)
254 |
255 | return tmp
256 |
257 | first_module_name = list(subset.keys())[0]
258 | handle = subset[first_module_name].register_forward_hook(add_batch(first_module_name))
259 |
260 | for j in range(args.nsamples):
261 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
262 | handle.remove()
263 |
264 | # copy H and dXXT
265 | for name in subset:
266 | if name != first_module_name:
267 | gptq[name].H = gptq[first_module_name].H
268 | gptq[name].dXXT = gptq[first_module_name].dXXT
269 |
270 | for name in subset:
271 | layer_w_groupsize = args.w_groupsize
272 | gptq[name].fasterquant(
273 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order,
274 | static_groups=args.static_groups
275 | )
276 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
277 | gptq[name].free()
278 |
279 | for j in range(args.nsamples):
280 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
281 |
282 | fp_inputs_cache.clear_cache()
283 | layers[i] = layer.cpu()
284 | del layer
285 | del gptq
286 | torch.cuda.empty_cache()
287 |
288 | inps, outs = outs, inps
289 |
290 | model.config.use_cache = use_cache
291 | utils.cleanup_memory(verbos=True)
292 | logging.info('-----GPTAQ Quantization Done-----\n')
293 |
294 | return quantizers
295 |
--------------------------------------------------------------------------------
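
Outside of `gptaq_fwrd`, the `GPTAQ` class can be exercised on a single linear layer. The following is a minimal sketch with toy shapes and illustrative quantizer settings (not part of the repo); it shows the expected calling order: attach a `WeightQuantizer`, provide the full-precision inputs through `fp_inp`, accumulate statistics with `add_batch`, then run `fasterquant`.

```python
# Toy single-layer walkthrough of the GPTAQ class (illustrative settings, not the full pipeline).
import torch
import quant_utils
from gptaq_utils import GPTAQ

torch.manual_seed(0)
dev = 'cuda'  # fasterquant calls torch.cuda.synchronize(), so a GPU is assumed
layer = torch.nn.Linear(64, 64, bias=False).to(dev)

gptaq = GPTAQ(layer)
gptaq.quantizer = quant_utils.WeightQuantizer()
gptaq.quantizer.configure(4, perchannel=True, sym=True, mse=False)

# x_fp: the full-precision model's input to this layer; x_q: the input seen on the quantized path.
x_fp = torch.randn(8, 128, 64, device=dev)
x_q = x_fp + 0.01 * torch.randn_like(x_fp)

# GPTAQ expects the full-precision inputs pre-transposed to (in_features, tokens),
# exactly how model_utils.FPInputsCache stores them.
gptaq.fp_inp = [x_fp.reshape(-1, 64).t()]
gptaq.add_batch(x_q, layer(x_q))  # accumulates H and dXXT, consuming fp_inp[0]

gptaq.fasterquant(blocksize=32, percdamp=0.01, groupsize=-1)
print('quantized weight sample:', layer.weight.data[0, :8])
gptaq.free()
```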
/fake_quant/gptq_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import utils
7 | import quant_utils
8 | import logging
9 |
10 | torch.backends.cuda.matmul.allow_tf32 = False
11 | torch.backends.cudnn.allow_tf32 = False
12 |
13 |
14 | class GPTQ:
15 |
16 | def __init__(self, layer):
17 | self.layer = layer
18 | self.dev = self.layer.weight.device
19 | W = layer.weight.data.clone()
20 | self.rows = W.shape[0]
21 | self.columns = W.shape[1]
22 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
23 | self.nsamples = 0
24 |
25 | def add_batch(self, inp, out):
26 |
27 | if len(inp.shape) == 2:
28 | inp = inp.unsqueeze(0)
29 | tmp = inp.shape[0]
30 | if len(inp.shape) == 3:
31 | inp = inp.reshape((-1, inp.shape[-1]))
32 | inp = inp.t()
33 | self.H *= self.nsamples / (self.nsamples + tmp)
34 | self.nsamples += tmp
35 | # inp = inp.float()
36 | inp = math.sqrt(2 / self.nsamples) * inp.float()
37 | # self.H += 2 / self.nsamples * inp.matmul(inp.t())
38 | self.H += inp.matmul(inp.t())
39 |
40 | def fasterquant(
41 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
42 | ):
43 | W = self.layer.weight.data.clone()
44 | W = W.float()
45 |
46 | tick = time.time()
47 |
48 | if not self.quantizer.ready():
49 | self.quantizer.find_params(W)
50 |
51 | H = self.H
52 | del self.H
53 | dead = torch.diag(H) == 0
54 | H[dead, dead] = 1
55 | W[:, dead] = 0
56 |
57 | if static_groups:
58 | import copy
59 | groups = []
60 | for i in range(0, self.columns, groupsize):
61 | quantizer = copy.deepcopy(self.quantizer)
62 | quantizer.find_params(W[:, i:(i + groupsize)])
63 | groups.append(quantizer)
64 |
65 | if actorder:
66 | perm = torch.argsort(torch.diag(H), descending=True)
67 | W = W[:, perm]
68 | H = H[perm][:, perm]
69 | invperm = torch.argsort(perm)
70 |
71 | Losses = torch.zeros_like(W)
72 | Q = torch.zeros_like(W)
73 |
74 | damp = percdamp * torch.mean(torch.diag(H))
75 | diag = torch.arange(self.columns, device=self.dev)
76 | H[diag, diag] += damp
77 | H = torch.linalg.cholesky(H)
78 | H = torch.cholesky_inverse(H)
79 | H = torch.linalg.cholesky(H, upper=True)
80 | Hinv = H
81 |
82 | for i1 in range(0, self.columns, blocksize):
83 | i2 = min(i1 + blocksize, self.columns)
84 | count = i2 - i1
85 |
86 | W1 = W[:, i1:i2].clone()
87 | Q1 = torch.zeros_like(W1)
88 | Err1 = torch.zeros_like(W1)
89 | Losses1 = torch.zeros_like(W1)
90 | Hinv1 = Hinv[i1:i2, i1:i2]
91 |
92 | for i in range(count):
93 | w = W1[:, i]
94 | d = Hinv1[i, i]
95 |
96 | if groupsize != -1:
97 | if not static_groups:
98 | if (i1 + i) % groupsize == 0:
99 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
100 | else:
101 | idx = i1 + i
102 | if actorder:
103 | idx = perm[idx]
104 | self.quantizer = groups[idx // groupsize]
105 |
106 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
107 | Q1[:, i] = q
108 | Losses1[:, i] = (w - q) ** 2 / d ** 2
109 |
110 | err1 = (w - q) / d
111 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
112 | Err1[:, i] = err1
113 |
114 | Q[:, i1:i2] = Q1
115 | Losses[:, i1:i2] = Losses1 / 2
116 |
117 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
118 |
119 | torch.cuda.synchronize()
120 |
121 | if actorder:
122 | Q = Q[:, invperm]
123 |
124 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
125 | if torch.any(torch.isnan(self.layer.weight.data)):
126 | logging.warning('NaN in weights')
127 | import pprint
128 | pprint.pprint((self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point))
129 | raise ValueError('NaN in weights')
130 |
131 | def free(self):
132 | self.H = None
133 | self.Losses = None
134 | self.Trace = None
135 | torch.cuda.empty_cache()
136 | utils.cleanup_memory(verbos=False)
137 |
138 |
139 | @torch.no_grad()
140 | def gptq_fwrd(model, dataloader, dev, args):
141 | '''
142 | From GPTQ repo
143 | TODO: Make this function general to support both OPT and LLaMA models
144 | '''
145 | logging.info('-----GPTQ Quantization-----')
146 |
147 | use_cache = model.config.use_cache
148 | model.config.use_cache = False
149 | layers = model.model.layers
150 |
151 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
152 | model.model.norm = model.model.norm.to(dev)
153 | layers[0] = layers[0].to(dev)
154 |
155 | dtype = next(iter(model.parameters())).dtype
156 | inps = torch.zeros(
157 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
158 | )
159 | cache = {'i': 0, 'attention_mask': None}
160 |
161 | class Catcher(nn.Module):
162 | def __init__(self, module):
163 | super().__init__()
164 | self.module = module
165 | def forward(self, inp, **kwargs):
166 | inps[cache['i']] = inp
167 | cache['i'] += 1
168 | cache['attention_mask'] = kwargs['attention_mask']
169 | cache['position_ids'] = kwargs['position_ids']
170 | raise ValueError
171 |
172 | layers[0] = Catcher(layers[0])
173 | for batch in dataloader:
174 | try:
175 | model(batch[0].to(dev))
176 | except ValueError:
177 | pass
178 | layers[0] = layers[0].module
179 |
180 | layers[0] = layers[0].cpu()
181 | model.model.embed_tokens = model.model.embed_tokens.cpu()
182 | model.model.norm = model.model.norm.cpu()
183 | torch.cuda.empty_cache()
184 |
185 | outs = torch.zeros_like(inps)
186 | attention_mask = cache['attention_mask']
187 | position_ids = cache['position_ids']
188 |
189 | quantizers = {}
190 | sequential = [
191 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
192 | ['self_attn.o_proj.module'],
193 | ['mlp.up_proj.module', 'mlp.gate_proj.module'],
194 | ['mlp.down_proj.module']
195 | ]
196 | for i in range(len(layers)):
197 | print(f'\nLayer {i}:', flush=True, end=' ')
198 | layer = layers[i].to(dev)
199 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
200 | for names in sequential:
201 | subset = {n: full[n] for n in names}
202 |
203 | gptq = {}
204 | for name in subset:
205 | print(f'{name}', end=' ', flush=True)
206 | layer_weight_bits = args.w_bits
207 | layer_weight_sym = not(args.w_asym)
208 | if 'lm_head' in name:
209 | layer_weight_bits = 16
210 | continue
211 | if args.int8_down_proj and 'down_proj' in name:
212 | layer_weight_bits = 8
213 | gptq[name] = GPTQ(subset[name])
214 | gptq[name].quantizer = quant_utils.WeightQuantizer()
215 | gptq[name].quantizer.configure(
216 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
217 | )
218 |
219 | def add_batch(name):
220 | def tmp(_, inp, out):
221 | gptq[name].add_batch(inp[0].data, out.data)
222 | return tmp
223 | handles = []
224 | for name in subset:
225 | handles.append(subset[name].register_forward_hook(add_batch(name)))
226 | for j in range(args.nsamples):
227 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
228 | for h in handles:
229 | h.remove()
230 |
231 | for name in subset:
232 | layer_w_groupsize = args.w_groupsize
233 | gptq[name].fasterquant(
234 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False
235 | )
236 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
237 | gptq[name].free()
238 |
239 | for j in range(args.nsamples):
240 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
241 |
242 | layers[i] = layer.cpu()
243 | del layer
244 | del gptq
245 | torch.cuda.empty_cache()
246 |
247 | inps, outs = outs, inps
248 |
249 | model.config.use_cache = use_cache
250 | utils.cleanup_memory(verbos=True)
251 | logging.info('-----GPTQ Quantization Done-----\n')
252 | return quantizers
253 |
254 |
255 |
256 |
257 | @torch.no_grad()
258 | def rtn_fwrd(model, dev, args):
259 | '''
260 | From GPTQ repo
261 | TODO: Make this function general to support both OPT and LLaMA models
262 | '''
263 | assert args.w_groupsize ==-1, "Groupsize not supported in RTN!"
264 | layers = model.model.layers
265 | torch.cuda.empty_cache()
266 |
267 | quantizers = {}
268 |
269 | for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"):
270 | layer = layers[i].to(dev)
271 |
272 | subset = quant_utils.find_qlayers(layer,
273 | layers=[torch.nn.Linear])
274 |
275 | for name in subset:
276 | layer_weight_bits = args.w_bits
277 | if 'lm_head' in name:
278 | layer_weight_bits = 16
279 | continue
280 | if args.int8_down_proj and 'down_proj' in name:
281 | layer_weight_bits = 8
282 |
283 | quantizer = quant_utils.WeightQuantizer()
284 | quantizer.configure(
285 | layer_weight_bits, perchannel=True, sym=not(args.w_asym), mse=args.w_clip
286 | )
287 | W = subset[name].weight.data
288 | quantizer.find_params(W)
289 | subset[name].weight.data = quantizer.quantize(W).to(
290 | next(iter(layer.parameters())).dtype)
291 | quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu()
292 | layers[i] = layer.cpu()
293 | torch.cuda.empty_cache()
294 | del layer
295 |
296 | utils.cleanup_memory(verbos=True)
297 | return quantizers
298 |
--------------------------------------------------------------------------------
/fake_quant/main.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import torch
3 | import model_utils
4 | import data_utils
5 | import transformers
6 | import quant_utils
7 | import rotation_utils
8 | import gptq_utils
9 | import gptaq_utils
10 | import eval_utils
11 | import hadamard_utils
12 |
13 |
14 | def add_aq(model, args):
15 | # Add Input Quantization
16 | if args.a_bits < 16 or args.v_bits < 16:
17 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
18 | down_proj_groupsize = -1
19 | if args.a_groupsize > 0 and "llama" in args.model:
20 | down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)
21 |
22 | for name in qlayers:
23 | layer_input_bits = args.a_bits
24 | layer_groupsize = args.a_groupsize
25 | layer_a_sym = not(args.a_asym)
26 | layer_a_clip = args.a_clip_ratio
27 |
28 | if 'v_proj' in name and args.v_bits < 16: #Set the v_proj precision
29 | qlayers[name].out_quantizer.configure(bits=args.v_bits,
30 | groupsize=args.v_groupsize,
31 | sym=not(args.v_asym),
32 | clip_ratio=args.v_clip_ratio)
33 |
34 | if 'lm_head' in name: #Skip lm_head quantization
35 | layer_input_bits = 16
36 |
37 | if 'down_proj' in name: #Set the down_proj precision
38 | if args.int8_down_proj:
39 | layer_input_bits = 8
40 | layer_groupsize = down_proj_groupsize
41 |
42 | qlayers[name].quantizer.configure(bits=layer_input_bits,
43 | groupsize=layer_groupsize,
44 | sym=layer_a_sym,
45 | clip_ratio=layer_a_clip)
46 |
47 | if args.k_bits < 16:
48 | if args.k_pre_rope:
49 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!")
50 | else:
51 | rope_function_name = model_utils.get_rope_function_name(model)
52 | layers = model_utils.get_layers(model)
53 | k_quant_config = {'k_bits':args.k_bits, "k_groupsize": args.k_groupsize,
54 | "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio}
55 | for layer in layers:
56 | rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
57 | layer.self_attn,
58 | rope_function_name,
59 | config=model.config,
60 | **k_quant_config)
61 |
62 |
63 | def main():
64 | args = utils.parser_gen()
65 | if args.wandb:
66 | import wandb
67 | wandb.init(project=args.wandb_project, entity=args.wandb_id)
68 | wandb.config.update(args)
69 |
70 | transformers.set_seed(args.seed)
71 | model = model_utils.get_model(args.model, args.hf_token)
72 | model.eval()
73 |
74 | # Rotate the weights
75 | if args.rotate:
76 | rotation_utils.fuse_layer_norms(model)
77 | rotation_utils.rotate_model(model, args)
78 | utils.cleanup_memory(verbos=True)
79 |
80 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model
81 | qlayers = quant_utils.find_qlayers(model)
82 | for name in qlayers:
83 | if 'down_proj' in name:
84 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
85 | qlayers[name].online_full_had = True
86 | qlayers[name].had_K = had_K
87 | qlayers[name].K = K
88 | qlayers[name].fp32_had = args.fp32_had
89 | if 'o_proj' in name:
90 | had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
91 | qlayers[name].online_partial_had = True
92 | qlayers[name].had_K = had_K
93 | qlayers[name].K = K
94 | qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads
95 | qlayers[name].fp32_had = args.fp32_had
96 | else:
97 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model as the rest of the code assumes it is present
98 |
99 | if args.enable_aq_calibration:
100 | add_aq(model, args)
101 |
102 | if args.w_bits < 16:
103 | save_dict = {}
104 | if args.load_qmodel_path: # Load Quantized Rotated Model
105 | assert args.rotate, "Model should be rotated to load a quantized model!"
106 | assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
107 | print("Load quantized model from ", args.load_qmodel_path)
108 | save_dict = torch.load(args.load_qmodel_path)
109 | model.load_state_dict(save_dict["model"], strict=False)
110 |
111 | elif not args.w_rtn: # GPTQ Weight Quantization
112 | assert "llama" in args.model, "Only llama is supported for GPTQ!"
113 |
114 | trainloader = data_utils.get_loaders(
115 | args.cal_dataset, nsamples=args.nsamples,
116 | seed=args.seed, model=args.model,
117 | seqlen=model.seqlen, eval_mode=False
118 | )
119 | if args.asym_calibrate:
120 | quantizers = gptaq_utils.gptaq_fwrd(model, trainloader, utils.DEV, args)
121 | save_dict["w_quantizers"] = quantizers
122 | else:
123 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, utils.DEV, args)
124 | save_dict["w_quantizers"] = quantizers
125 | else: # RTN Weight Quantization
126 | quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args)
127 | save_dict["w_quantizers"] = quantizers
128 |
129 | if args.save_qmodel_path:
130 | save_dict["model"] = model.state_dict()
131 | torch.save(save_dict, args.save_qmodel_path)
132 |
133 | if not args.enable_aq_calibration:
134 | add_aq(model, args)
135 |
136 | # Evaluating on dataset
137 | testloader = data_utils.get_loaders(
138 | args.eval_dataset,
139 | seed=args.seed,
140 | model=args.model,
141 | seqlen=model.seqlen,
142 | hf_token=args.hf_token,
143 | eval_mode=True
144 | )
145 |
146 | dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)
147 |
148 | if args.wandb:
149 | wandb.log({'ppl/{}'.format(args.eval_dataset.upper()): dataset_ppl})
150 |
151 | if not args.lm_eval:
152 | return
153 | else:
154 | # Import lm_eval utils
155 | import lm_eval
156 | from lm_eval import utils as lm_eval_utils
157 | from lm_eval.api.registry import ALL_TASKS
158 | from lm_eval.models.huggingface import HFLM
159 |
160 | if args.distribute:
161 | utils.distribute_model(model)
162 | else:
163 | model.to(utils.DEV)
164 |
165 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token)
166 | hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size)
167 |
168 | # commenting out this line as it will include two lambda sub-tasks
169 | # task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS)
170 | task_names = args.tasks
171 | results = lm_eval.simple_evaluate(hflm, tasks=task_names, batch_size=args.lm_eval_batch_size)['results']
172 |
173 | metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()}
174 | metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals.values()), 4)
175 | print(metric_vals)
176 |
177 | if args.wandb:
178 | wandb.log(metric_vals)
179 |
180 |
181 | if __name__ == '__main__':
182 | main()
183 |
--------------------------------------------------------------------------------
/fake_quant/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import typing
3 | import transformers
4 | import utils
5 | import os
6 | import logging
7 | import functools
8 |
9 | OPT_MODEL = transformers.models.opt.modeling_opt.OPTForCausalLM
10 | OPT_LAYER = transformers.models.opt.modeling_opt.OPTDecoderLayer
11 | LLAMA_MODEL = transformers.models.llama.modeling_llama.LlamaForCausalLM
12 | LLAMA_LAYER = transformers.models.llama.modeling_llama.LlamaDecoderLayer
13 |
14 |
15 | def model_type_extractor(model):
16 | if isinstance(model, LLAMA_MODEL):
17 | return LLAMA_MODEL
18 | elif isinstance(model, OPT_MODEL):
19 | return OPT_MODEL
20 | else:
21 | raise ValueError(f'Unknown model type {model}')
22 |
23 | def skip(*args, **kwargs):
24 | # This is a helper function to save time during the initialization!
25 | pass
26 |
27 | def get_rope_function_name(model):
28 | if isinstance(model, LLAMA_MODEL):
29 | return "apply_rotary_pos_emb"
30 | raise NotImplementedError
31 |
32 |
33 | def get_layers(model):
34 | if isinstance(model, OPT_MODEL):
35 | return model.model.decoder.layers
36 | if isinstance(model, LLAMA_MODEL):
37 | return model.model.layers
38 | raise NotImplementedError
39 |
40 |
41 | def get_llama(model_name, hf_token):
42 | torch.nn.init.kaiming_uniform_ = skip
43 | torch.nn.init.uniform_ = skip
44 | torch.nn.init.normal_ = skip
45 | model = transformers.LlamaForCausalLM.from_pretrained(model_name, torch_dtype='auto',
46 | use_auth_token=hf_token,
47 | low_cpu_mem_usage=True)
48 | model.seqlen = 2048
49 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))
50 | return model
51 |
52 |
53 |
54 | def get_opt(model_name):
55 | torch.nn.init.kaiming_uniform_ = skip
56 | torch.nn.init.uniform_ = skip
57 | torch.nn.init.normal_ = skip
58 | model = transformers.OPTForCausalLM.from_pretrained(model_name, torch_dtype='auto',
59 | low_cpu_mem_usage=True)
60 | model.seqlen = model.config.max_position_embeddings
61 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))
62 | return model
63 |
64 |
65 | def get_model(
66 | model_name, hf_token=None
67 | ):
68 | if 'llama' in model_name:
69 | return get_llama(model_name, hf_token)
70 | elif 'opt' in model_name:
71 | return get_opt(model_name)
72 | else:
73 | raise ValueError(f'Unknown model {model_name}')
74 |
75 |
76 | def get_model_type(model):
77 | if isinstance(model, OPT_MODEL):
78 | model_type = OPT_MODEL
79 | elif isinstance(model, LLAMA_MODEL):
80 | model_type = LLAMA_MODEL
81 | else:
82 | raise ValueError(f'Unknown model type {model}')
83 | return model_type
84 |
85 | def get_embeddings(model, model_type) -> list[torch.nn.Module]:
86 | if model_type == LLAMA_MODEL:
87 | return [model.model.embed_tokens]
88 | elif model_type == OPT_MODEL:
89 | return [model.model.decoder.embed_tokens, model.model.decoder.embed_positions]
90 | else:
91 | raise ValueError(f'Unknown model type {model_type}')
92 |
93 |
94 | def get_transformer_layers(model, model_type):
95 | if model_type == LLAMA_MODEL:
96 | return [layer for layer in model.model.layers]
97 | elif model_type == OPT_MODEL:
98 | return [layer for layer in model.model.decoder.layers]
99 | else:
100 | raise ValueError(f'Unknown model type {model_type}')
101 |
102 |
103 | def get_lm_head(model, model_type):
104 | if model_type == LLAMA_MODEL:
105 | return model.lm_head
106 | elif model_type == OPT_MODEL:
107 | return model.lm_head
108 | else:
109 | raise ValueError(f'Unknown model type {model_type}')
110 |
111 | def get_pre_head_layernorm(model, model_type):
112 | if model_type == LLAMA_MODEL:
113 | pre_head_layernorm = model.model.norm
114 | assert isinstance(pre_head_layernorm,
115 | transformers.models.llama.modeling_llama.LlamaRMSNorm)
116 | elif model_type == OPT_MODEL:
117 | pre_head_layernorm = model.model.decoder.final_layer_norm
118 | assert pre_head_layernorm is not None
119 | else:
120 | raise ValueError(f'Unknown model type {model_type}')
121 | return pre_head_layernorm
122 |
123 | def get_mlp_bottleneck_size(model):
124 | model_type = get_model_type(model)
125 | if model_type == LLAMA_MODEL:
126 | return model.config.intermediate_size
127 | elif model_type == OPT_MODEL:
128 | return model.config.ffn_dim
129 | else:
130 | raise ValueError(f'Unknown model type {model_type}')
131 |
132 | def replace_modules(
133 | root: torch.nn.Module,
134 | type_to_replace,
135 | new_module_factory,
136 | replace_layers: bool,
137 | ) -> None:
138 | """Replace modules of given type using the supplied module factory.
139 |
140 | Perform a depth-first search of a module hierarchy starting at root
141 | and replace all instances of type_to_replace with modules created by
142 | new_module_factory. Children of replaced modules are not processed.
143 |
144 | Args:
145 | root: the root of the module hierarchy where modules should be replaced
146 | type_to_replace: a type instances of which will be replaced
147 | new_module_factory: a function that, given a module that should be replaced
148 | (and, when replace_layers is True, its integer index), produces a module to replace it with.
149 | replace_layers: whether new_module_factory is also passed the layer index parsed from the child's name.
150 | """
150 | for name, module in root.named_children():
151 | new_module = None
152 | if isinstance(module, type_to_replace):
153 | if replace_layers: # layernorm_fusion.replace_layers case where transformer layers are replaced
154 | new_module = new_module_factory(module, int(name))
155 | else: # layernorm_fusion.fuse_modules case where layernorms are fused
156 | new_module = new_module_factory(module)
157 | elif len(list(module.children())) > 0:
158 | replace_modules(module, type_to_replace, new_module_factory, replace_layers)
159 |
160 | if new_module is not None:
161 | setattr(root, name, new_module)
162 |
163 |
164 | class RMSN(torch.nn.Module):
165 | """
166 | This class implements the Root Mean Square Normalization (RMSN) layer.
167 | We use the implementation from LLAMARMSNorm here:
168 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75
169 | """
170 |
171 | def __init__(self, mean_dim: int, eps=1e-5):
172 | super().__init__()
173 | self.eps = eps
174 | self.mean_dim = mean_dim
175 | self.weight = torch.nn.Parameter(torch.zeros(1))
176 |
177 | def forward(self, x: torch.Tensor) -> torch.Tensor:
178 | input_dtype = x.dtype
179 | if x.dtype == torch.float16:
180 | x = x.to(torch.float32)
181 | variance = x.pow(2).sum(-1, keepdim=True) / self.mean_dim
182 | x = x * torch.rsqrt(variance + self.eps)
183 | return x.to(input_dtype)
184 |
185 |
186 | class FPInputsCache:
187 | """
188 | class for saving the full-precision output in each layer.
189 | """
190 | def __init__(self, sequential):
191 | self.fp_cache = {}
192 | self.names = sequential[0]+sequential[1]+sequential[2]+sequential[3]
193 | for name in self.names:
194 | self.fp_cache[name] = []
195 | self.handles = []
196 |
197 | def cache_fp_input(self, m, inp, out, name):
198 | inp = inp[0].detach()
199 | if len(inp.shape) == 3:
200 | inp = inp.reshape((-1, inp.shape[-1]))
201 | self.fp_cache[name] += [inp.t()]
202 |
203 | def add_hook(self, full):
204 | for name in self.names:
205 | self.handles.append(
206 | full[name].register_forward_hook(
207 | functools.partial(self.cache_fp_input, name=name)
208 | )
209 | )
210 |
211 | def clear_hook(self):
212 | for h in self.handles:
213 | h.remove()
214 | self.handles = []
215 | torch.cuda.empty_cache()
216 |
217 | def clear_cache(self):
218 | for name in self.names:
219 | self.fp_cache[name] = []
220 |
221 |
222 | def get_layer_io_save_path(args):
223 | return os.path.join(args.save_path, 'layer_io', f'{args.layer_idx:03d}.pt')
224 |
225 |
226 | def capture_layer_io(model_type, layer, layer_input):
227 | def hook_factory(module_name, captured_vals, is_input):
228 | def hook(module, input, output):
229 | if is_input:
230 | captured_vals[module_name].append(input[0].detach().cpu())
231 | else:
232 | captured_vals[module_name].append(output.detach().cpu())
233 | return hook
234 |
235 | handles = []
236 |
237 | if model_type == LLAMA_MODEL:
238 | captured_inputs = {
239 | 'k_proj': [], # q_proj, v_proj has the same input as k_proj
240 | 'o_proj': [],
241 | 'gate_proj': [], # up_proj has the same input as gate_proj
242 | 'down_proj': []
243 | }
244 |
245 | captured_outputs = {
246 | 'v_proj': [],
247 | }
248 |
249 | for name in captured_inputs.keys():
250 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None)
251 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True)))
252 |
253 | for name in captured_outputs.keys():
254 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None)
255 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False)))
256 |
257 | elif model_type == OPT_MODEL:
258 | captured_inputs = {
259 | 'k_proj': [], # q_proj, v_proj has the same input as k_proj
260 | 'out_proj': [],
261 | 'fc1': [],
262 | 'fc2': []
263 | }
264 | captured_outputs = {
265 | 'v_proj': [],
266 | }
267 | for name in captured_inputs.keys():
268 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer
269 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None)
270 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True)))
271 |
272 | for name in captured_outputs.keys():
273 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer
274 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None)
275 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False)))
276 | else:
277 | raise ValueError(f'Unknown model type {model_type}')
278 |
279 | # Process each sequence in the batch one by one to avoid OOM.
280 | for seq_idx in range(layer_input.shape[0]):
281 | # Extract the current sequence across all dimensions.
282 | seq = layer_input[seq_idx:seq_idx + 1].to(utils.DEV)
283 | # Perform a forward pass for the current sequence.
284 | layer(seq)
285 |
286 | # After processing all sequences, concatenate the accumulated inputs for each sub-layer across the batch.
287 | for module_name in captured_inputs:
288 | captured_inputs[module_name] = torch.cat(captured_inputs[module_name], dim=0)
289 | for module_name in captured_outputs:
290 | captured_outputs[module_name] = torch.cat(captured_outputs[module_name], dim=0)
291 |
292 | # Cleanup.
293 | for h in handles:
294 | h.remove()
295 |
296 | return {
297 | 'input': captured_inputs,
298 | 'output': captured_outputs
299 | }
300 |
--------------------------------------------------------------------------------
/fake_quant/monkeypatch.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import types
4 |
5 | def copy_func_with_new_globals(f, globals=None):
6 | """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
7 | if globals is None:
8 | globals = f.__globals__
9 | g = types.FunctionType(f.__code__, globals, name=f.__name__,
10 | argdefs=f.__defaults__, closure=f.__closure__)
11 | g = functools.update_wrapper(g, f)
12 | g.__module__ = f.__module__
13 | g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
14 | return g
15 |
16 | def add_wrapper_after_function_call_in_method(module, method_name, function_name, wrapper_fn):
17 | '''
18 |     This function adds a wrapper after the output of calls to `function_name` in the method named `method_name`.
19 | Only calls directly in the method are affected. Calls by other functions called in the method are not affected.
20 | '''
21 |
22 | original_method = getattr(module, method_name).__func__
23 | method_globals = dict(original_method.__globals__)
24 | wrapper = wrapper_fn(method_globals[function_name])
25 | method_globals[function_name] = wrapper
26 | new_method = copy_func_with_new_globals(original_method, globals=method_globals)
27 | setattr(module, method_name, new_method.__get__(module))
28 | return wrapper
29 |
30 |
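# A hypothetical usage sketch (not part of the original module): wrap the `helper`
# call made inside `Foo.forward` so that its result is doubled. Only the call made
# directly in `forward` is rewired; `helper` itself is untouched.
#
#   def helper(x):
#       return x + 1
#
#   class Foo:
#       def forward(self, x):
#           return helper(x)
#
#   foo = Foo()
#   add_wrapper_after_function_call_in_method(
#       foo, "forward", "helper",
#       wrapper_fn=lambda fn: (lambda *a, **kw: 2 * fn(*a, **kw)),
#   )
#   assert foo.forward(1) == 4   # helper(1) == 2, then doubled by the wrapper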
--------------------------------------------------------------------------------
/fake_quant/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.38.0
2 | torch==2.2.1
3 | sentencepiece==0.2.0
4 | wandb==0.16.3
5 | huggingface-hub==0.20.3
6 | accelerate==0.27.2
7 | datasets==2.17.1
8 | lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@9b0b15b1ccace3534ffbd13298c569869ce8eaf3
--------------------------------------------------------------------------------
/fake_quant/run_llama.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | export CUDA_VISIBLE_DEVICES=$gpu_id
5 |
6 | python main.py --model meta-llama/Meta-Llama-3-8B \
7 | --w_bits 4 \
8 | --w_groupsize -1 \
9 | --w_clip \
10 | --a_bits 4 \
11 | --v_bits 16 \
12 | --k_bits 16 \
13 | --k_asym \
14 | --v_asym \
15 | --w_asym \
16 | --a_asym \
17 | --a_clip_ratio 0.9 \
18 | --k_clip_ratio 0.95 \
19 | --v_clip_ratio 0.95 \
20 | --asym_calibrate \
21 | --enable_aq_calibration \
22 | --rotate
--------------------------------------------------------------------------------
/img/readme_intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Computing-Lab-Yale/GPTAQ/883afc64f40ac061a59ef8cbe5bd794c01377e37/img/readme_intro.png
--------------------------------------------------------------------------------
/spinquant/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/spinquant/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to SpinQuant
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Facebook's open source projects.
18 |
19 | Complete your CLA here:
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to SpinQuant, you agree that your contributions will be licensed
31 | under the LICENSE file in the root directory of this source tree.
32 |
--------------------------------------------------------------------------------
/spinquant/README.md:
--------------------------------------------------------------------------------
1 |
2 | # SpinQuant
3 |
4 |
5 |
6 | This repository contains the code of "[SpinQuant: LLM Quantization with Learned Rotations](https://arxiv.org/pdf/2405.16406)"
7 |
8 |
9 |
10 | ## Run (Following original repo)
11 |
12 |
13 |
14 | ### 1. Requirements:
15 |
16 | * python 3.9, pytorch >= 2.0
17 |
18 | * install pytorch with cuda from https://pytorch.org/get-started/locally/; it is a prerequisite for the fast-hadamard-transform package.
19 |
20 | * pip install -r requirement.txt
21 |
22 | * git clone https://github.com/Dao-AILab/fast-hadamard-transform.git
23 |
24 | * cd fast-hadamard-transform
25 |
26 | * pip install .
27 |
28 | ### 2. Steps to run:
29 |
30 | **We directly use the pretrained rotation matrices of SpinQuant to run our experiments.**
31 |
32 | You can download the optimized rotation matrices [here](https://drive.google.com/drive/folders/1R2zix4qeXBjcmgnJN1rny93cguJ4rEE8?usp=sharing).
33 |
34 | **Note that to perform GPTQ/GPTQv2, we need to download the W16A4 rotation matrix.**
35 |
36 | After obtaining the optimized rotation, point `--optimized_rotation_path` at the downloaded rotation matrix file for evaluation (a quick sanity check of the checkpoint contents is shown below).
37 |
38 | * `bash scripts/2_eval_ptq.sh`
39 |
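As a quick, optional sanity check, the downloaded checkpoint should contain one `R1` matrix plus one `model.layers.{i}.self_attn.R2` matrix per layer (the path below is only an example):

```python
import torch

R = torch.load("ckpts/8B_W16A4KV16_lr_1.5_seed_0/R.bin")  # example path used in 2_eval_ptq.sh
print(R["R1"].shape)                                      # hidden_size x hidden_size
print(R["model.layers.0.self_attn.R2"].shape)             # head_dim x head_dim
```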
40 | **For reasons we have not identified, we cannot exactly reproduce the SpinQuant results on LLaMA3; if you figure out why, please let us know.** For example, on LLaMA3-8B we were unable to reproduce the reported 7.10 perplexity.
41 |
42 |
43 | | Method | Bits | Perplexity |
44 | |-------------------------|------|------------|
45 | | SpinQuant+GPTQ (Paper) | W4A4 | 7.1 |
46 | | SpinQuant+GPTQ (Ours) | W4A4 | 7.26 |
47 | | SpinQuant+GPTQv2 (Ours) | W4A4 | 7.19 |
48 |
49 |
50 |
51 | ### 3. Export to ExecuTorch (same as the original code)
52 |
53 | We also support exporting the quantized model to ExecuTorch, which allows us to utilize its quantization kernels and achieve real speedups. For more information on kernel implementation details, please see [ExecuTorch](https://pytorch.org/executorch/stable/index.html) and [ExecuTorch with SpinQuant](https://github.com/pytorch/executorch/tree/main/examples/models/llama#spinquant). We currently support 4-bit weight quantization (set the group size to 256 for the 8B model and to 32 for smaller models) and 8-bit dynamic activation quantization.
54 |
55 |
56 | To obtain ExecuTorch-compatible quantized models, you can use the following scripts:
57 |
58 |
59 | * `bash scripts/31_optimize_rotation_executorch.sh $model_name`
60 |
61 | * `bash scripts/32_eval_ptq_executorch.sh $model_name`
62 |
63 |
64 | ### Arguments
65 |
66 |
67 |
68 | - `--input_model`: The model name (or path to the weights)
69 |
70 | - `--output_rotation_path`: The local path where we store the optimized rotation matrix
71 |
72 | - `--per_device_train_batch_size`: The batch size for rotation optimization
73 |
74 | - `--per_device_eval_batch_size`: The batch size for PPL evaluation
75 |
76 | - `--a_bits`: The number of bits for activation quantization
77 |
78 | - `--w_bits`: The number of bits for weight quantization
79 |
80 | - `--v_bits`: The number of bits for value quantization
81 |
82 | - `--k_bits`: The number of bits for key quantization
83 |
84 | - `--w_clip`: Whether to use grid search to find the best weight clipping range
85 | 
86 | - `--w_rtn`: Whether to use round-to-nearest quantization. Without `--w_rtn`, GPTQ quantization is used.
87 | 
88 | - `--w_groupsize`: The group size for group-wise weight quantization.
89 | 
90 | - `--rotate`: Whether to rotate the model
91 | 
92 | - `--optimized_rotation_path`: The checkpoint path of the optimized rotation; a random rotation is used if no path is given
93 |
94 |
95 |
--------------------------------------------------------------------------------
/spinquant/SpinQuant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Computing-Lab-Yale/GPTAQ/883afc64f40ac061a59ef8cbe5bd794c01377e37/spinquant/SpinQuant.png
--------------------------------------------------------------------------------
/spinquant/eval_utils/gptaq_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import logging
7 | import functools
8 |
9 | from utils import quant_utils, utils
10 |
11 |
12 | torch.backends.cuda.matmul.allow_tf32 = False
13 | torch.backends.cudnn.allow_tf32 = False
14 |
15 | class GPTAQ:
16 |
17 | def __init__(self, layer):
18 | self.layer = layer
19 | self.dev = self.layer.weight.device
20 | W = layer.weight.data.clone()
21 | self.rows = W.shape[0]
22 | self.columns = W.shape[1]
23 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
24 | self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
25 | self.nsamples = 0
26 | self.fp_inp = []
27 |
28 | def add_batch(self, inp, out):
29 |
30 | if len(inp.shape) == 2:
31 | inp = inp.unsqueeze(0)
32 | tmp = inp.shape[0]
33 | if len(inp.shape) == 3:
34 | inp = inp.reshape((-1, inp.shape[-1]))
35 |
36 | inp = inp.t()
37 |
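        # Running statistics (rescaled so they stay ~ 2/nsamples * sum over seen tokens):
        #   H    accumulates X_q X_q^T, the usual GPTQ Hessian of the quantized-model inputs;
        #   dXXT accumulates (X_fp - X_q) X_q^T, the asymmetry term between the cached
        #        full-precision inputs (self.fp_inp) and the quantized-model inputs.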
38 | self.H *= self.nsamples / (self.nsamples + tmp)
39 | self.dXXT *= self.nsamples / (self.nsamples + tmp)
40 | self.nsamples += tmp
41 | inp = math.sqrt(2 / self.nsamples) * inp.float()
42 | self.H += inp.matmul(inp.t())
43 | dX = self.fp_inp[0].float() * math.sqrt(2 / self.nsamples) - inp
44 | self.dXXT += dX.matmul(inp.t())
45 |
46 | del self.fp_inp[0]
47 |
48 | def fasterquant(
49 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
50 | ):
51 | W = self.layer.weight.data.clone()
52 | W = W.float()
53 |
54 | if not self.quantizer.ready():
55 | self.quantizer.find_params(W)
56 |
57 | H = self.H
58 | del self.H
59 | dead = torch.diag(H) == 0
60 | H[dead, dead] = 1
61 | W[:, dead] = 0
62 | self.dXXT[:, dead] = 0
63 |
64 | if static_groups:
65 | import copy
66 | groups = []
67 | for i in range(0, self.columns, groupsize):
68 | quantizer = copy.deepcopy(self.quantizer)
69 | quantizer.find_params(W[:, i:(i + groupsize)])
70 | groups.append(quantizer)
71 |
72 | if actorder:
73 | perm = torch.argsort(torch.diag(H), descending=True)
74 | W = W[:, perm]
75 | H = H[perm][:, perm]
76 | self.dXXT = self.dXXT[perm][:, perm]
77 | invperm = torch.argsort(perm)
78 |
79 | Losses = torch.zeros_like(W)
80 | Q = torch.zeros_like(W)
81 |
82 | damp = percdamp * torch.mean(torch.diag(H))
83 | diag = torch.arange(self.columns, device=self.dev)
84 | H[diag, diag] += damp
85 | Hinv = torch.linalg.cholesky(H)
86 | Hinv = torch.cholesky_inverse(Hinv)
87 | Hinv = torch.linalg.cholesky(Hinv, upper=True)
88 |
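        # Asymmetric-calibration correction (the GPTAQ addition on top of GPTQ): propagate
        # the accumulated input mismatch dXXT through the Cholesky factors of H^-1 to form
        # P, which enters the column updates below alongside the usual GPTQ error feedback.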
89 | alpha = 0.25
90 | P = alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv
91 | del self.dXXT
92 |
93 | for i1 in range(0, self.columns, blocksize):
94 | i2 = min(i1 + blocksize, self.columns)
95 | count = i2 - i1
96 |
97 | W1 = W[:, i1:i2].clone()
98 | Q1 = torch.zeros_like(W1)
99 | Err1 = torch.zeros_like(W1)
100 | Losses1 = torch.zeros_like(W1)
101 | Hinv1 = Hinv[i1:i2, i1:i2]
102 | P1 = P[i1:i2, i1:i2]
103 |
104 | for i in range(count):
105 | w = W1[:, i]
106 | d = Hinv1[i, i]
107 |
108 | if groupsize != -1:
109 | if not static_groups:
110 | if (i1 + i) % groupsize == 0:
111 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
112 | else:
113 | idx = i1 + i
114 | if actorder:
115 | idx = perm[idx]
116 | self.quantizer = groups[idx // groupsize]
117 |
118 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
119 | Q1[:, i] = q
120 | Losses1[:, i] = (w - q) ** 2 / d ** 2
121 |
122 | err1 = (w - q) / d
123 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))
124 | Err1[:, i] = err1
125 |
126 | Q[:, i1:i2] = Q1
127 | Losses[:, i1:i2] = Losses1 / 2
128 |
129 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:])
130 |
131 | torch.cuda.synchronize()
132 |
133 | if actorder:
134 | Q = Q[:, invperm]
135 |
136 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
137 | if torch.any(torch.isnan(self.layer.weight.data)):
138 | logging.warning('NaN in weights')
139 | import pprint
140 |             pprint.pprint((self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point))
141 | raise ValueError('NaN in weights')
142 |
143 | def free(self):
144 | self.H = None
145 | self.Losses = None
146 | self.Trace = None
147 | self.dXXT = None
148 | torch.cuda.empty_cache()
149 | utils.cleanup_memory(verbos=False)
150 |
151 |
152 | class FPInputsCache:
153 | """
154 |     Cache of the full-precision inputs to each quantized sub-layer, used for asymmetric calibration.
155 | """
156 | def __init__(self, sequential):
157 | self.fp_cache = {}
158 | self.names = sequential[0]+sequential[1]+sequential[2]+sequential[3]
159 | print(self.names)
160 | for name in self.names:
161 | self.fp_cache[name] = []
162 | self.handles = []
163 |
164 | def cache_fp_input(self, m, inp, out, name):
165 | inp = inp[0].detach()
166 | if len(inp.shape) == 3:
167 | inp = inp.reshape((-1, inp.shape[-1]))
168 | self.fp_cache[name] += [inp.t()]
169 |
170 | def add_hook(self, full):
171 | for name in self.names:
172 | self.handles.append(
173 | full[name].register_forward_hook(
174 | functools.partial(self.cache_fp_input, name=name)
175 | )
176 | )
177 |
178 | def clear_hook(self):
179 | for h in self.handles:
180 | h.remove()
181 | self.handles = []
182 | torch.cuda.empty_cache()
183 |
184 | def clear_cache(self):
185 | for name in self.names:
186 | self.fp_cache[name] = []
187 |
188 |
189 | @torch.no_grad()
190 | def gptaq_fwrd(model, dataloader, dev, args):
191 | '''
192 | From GPTQ repo
193 | TODO: Make this function general to support both OPT and LLaMA models
194 | '''
195 | print('-----GPTQv2 Quantization-----')
196 |
197 | use_cache = model.config.use_cache
198 | model.config.use_cache = False
199 | layers = model.model.layers
200 |
201 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
202 | model.model.norm = model.model.norm.to(dev)
203 | model.model.rotary_emb = model.model.rotary_emb.to(dev)
204 |
205 | layers[0] = layers[0].to(dev)
206 |
207 | dtype = next(iter(model.parameters())).dtype
208 | inps = torch.zeros(
209 | (args.nsamples, 2048, model.config.hidden_size), dtype=dtype, device=dev
210 | )
211 | cache = {'i': 0, 'attention_mask': None}
212 |
213 | class Catcher(nn.Module):
214 | def __init__(self, module):
215 | super().__init__()
216 | self.module = module
217 |
218 | def forward(self, inp, **kwargs):
219 | inps[cache['i']] = inp
220 | cache['i'] += 1
221 | cache['attention_mask'] = kwargs['attention_mask']
222 | cache['position_ids'] = kwargs['position_ids']
223 | raise ValueError
224 |
225 | layers[0] = Catcher(layers[0])
226 | for batch in dataloader:
227 | try:
228 | model(batch[0].to(dev))
229 | except ValueError:
230 | pass
231 | layers[0] = layers[0].module
232 | layers[0] = layers[0].cpu()
233 |
234 | model.model.embed_tokens = model.model.embed_tokens.cpu()
235 | model.model.norm = model.model.norm.cpu()
236 | model.model.rotary_emb = model.model.rotary_emb.cpu()
237 | torch.cuda.empty_cache()
238 |
239 | outs = torch.zeros_like(inps)
240 |
241 | attention_mask = cache['attention_mask']
242 | position_ids = cache['position_ids']
243 |
244 | quantizers = {}
245 | sequential = [
246 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
247 | ['self_attn.o_proj.module'],
248 | ['mlp.up_proj.module', 'mlp.gate_proj.module'],
249 | ['mlp.down_proj.module']
250 | ]
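    # Each inner list groups sub-layers that share the same input, so H/dXXT are
    # accumulated once per group (via the hook on the first module) and copied to the
    # remaining modules below ("copy H and dXXT").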
251 |
252 | fp_inputs_cache = FPInputsCache(sequential)
253 | fp_inps = inps.clone()
254 |
255 | for i in range(len(layers)):
256 | print(f'\nLayer {i}:', flush=True, end=' ')
257 | layer = layers[i].to(dev)
258 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
259 | bits_config = quant_utils.disable_act_quant(layer)
260 | fp_inputs_cache.add_hook(full)
261 |
262 | for j in range(args.nsamples):
263 | fp_inps[j] = layer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
264 | fp_inputs_cache.clear_hook()
265 | quant_utils.enable_act_quant(layer, bits_config)
266 |
267 | for names in sequential:
268 | subset = {n: full[n] for n in names}
269 |
270 | gptq = {}
271 | for name in subset:
272 | print(f'{name}', end=' ', flush=True)
273 | layer_weight_bits = args.w_bits
274 | layer_weight_sym = not (args.w_asym)
275 | if 'lm_head' in name:
276 | layer_weight_bits = 16
277 | continue
278 | if args.int8_down_proj and 'down_proj' in name:
279 | layer_weight_bits = 8
280 | gptq[name] = GPTAQ(subset[name])
281 | gptq[name].quantizer = quant_utils.WeightQuantizer()
282 | gptq[name].quantizer.configure(
283 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
284 | )
285 | gptq[name].fp_inp = fp_inputs_cache.fp_cache[name]
286 |
287 | def add_batch(name):
288 | def tmp(_, inp, out):
289 | gptq[name].add_batch(inp[0].data, out.data)
290 |
291 | return tmp
292 |
293 | first_module_name = list(subset.keys())[0]
294 | handle = subset[first_module_name].register_forward_hook(add_batch(first_module_name))
295 |
296 | for j in range(args.nsamples):
297 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
298 | handle.remove()
299 |
300 | # copy H and dXXT
301 | for name in subset:
302 | if name != first_module_name:
303 | gptq[name].H = gptq[first_module_name].H
304 | gptq[name].dXXT = gptq[first_module_name].dXXT
305 |
306 | for name in subset:
307 | layer_w_groupsize = args.w_groupsize
308 | gptq[name].fasterquant(
309 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False
310 | )
311 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
312 | gptq[name].free()
313 |
314 | for j in range(args.nsamples):
315 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
316 |
317 | fp_inputs_cache.clear_cache()
318 | layers[i] = layer.cpu()
319 | del layer
320 | del gptq
321 | torch.cuda.empty_cache()
322 |
323 | inps, outs = outs, inps
324 |
325 | model.config.use_cache = use_cache
326 | utils.cleanup_memory(verbos=True)
327 | print('-----GPTQv2 Quantization Done-----\n')
328 | return quantizers
329 |
--------------------------------------------------------------------------------
/spinquant/eval_utils/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import torch
12 | import transformers
13 |
14 | from eval_utils import gptq_utils, gptaq_utils, rotation_utils
15 | from utils import data_utils, fuse_norm_utils, hadamard_utils, quant_utils, utils
16 | from utils.convert_to_executorch import (
17 | sanitize_checkpoint_from_spinquant,
18 | write_model_llama,
19 | )
20 |
21 |
22 | def add_input_quantization(model, args):
23 | # Add Input Quantization
24 | if args.a_bits < 16 or args.v_bits < 16:
25 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
26 | down_proj_groupsize = -1
27 | if args.a_groupsize > 0:
28 | down_proj_groupsize = utils.llama_down_proj_groupsize(
29 | model, args.a_groupsize
30 | )
31 |
32 | for name in qlayers:
33 | layer_input_bits = args.a_bits
34 | layer_groupsize = args.a_groupsize
35 | layer_a_sym = not (args.a_asym)
36 | layer_a_clip = args.a_clip_ratio
37 |
38 | num_heads = model.config.num_attention_heads
39 | model_dim = model.config.hidden_size
40 | head_dim = model_dim // num_heads
41 |
42 | if "v_proj" in name and args.v_bits < 16: # Set the v_proj precision
43 | v_groupsize = head_dim
44 | qlayers[name].out_quantizer.configure(
45 | bits=args.v_bits,
46 | groupsize=v_groupsize,
47 | sym=not (args.v_asym),
48 | clip_ratio=args.v_clip_ratio,
49 | )
50 |
51 | if "o_proj" in name:
52 | layer_groupsize = head_dim
53 |
54 | if "lm_head" in name: # Skip lm_head quantization
55 | layer_input_bits = 16
56 |
57 | if "down_proj" in name: # Set the down_proj precision
58 | if args.int8_down_proj:
59 | layer_input_bits = 8
60 | layer_groupsize = down_proj_groupsize
61 |
62 | qlayers[name].quantizer.configure(
63 | bits=layer_input_bits,
64 | groupsize=layer_groupsize,
65 | sym=layer_a_sym,
66 | clip_ratio=layer_a_clip,
67 | )
68 |
69 | if args.k_bits < 16:
70 | if args.k_pre_rope:
71 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!")
72 | else:
73 | rope_function_name = "apply_rotary_pos_emb"
74 | layers = model.model.layers
75 | k_quant_config = {
76 | "k_bits": args.k_bits,
77 | "k_groupsize": args.k_groupsize,
78 | "k_sym": not (args.k_asym),
79 | "k_clip_ratio": args.k_clip_ratio,
80 | }
81 | for layer in layers:
82 | rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
83 | layer.self_attn,
84 | rope_function_name,
85 | config=model.config,
86 | **k_quant_config,
87 | )
88 |
89 |
90 | def ptq_model(args, model, model_args=None):
91 | transformers.set_seed(args.seed)
92 | model.eval()
93 |
94 | # Rotate the weights
95 | if args.rotate:
96 | fuse_norm_utils.fuse_layer_norms(model)
97 | rotation_utils.rotate_model(model, args)
98 | utils.cleanup_memory(verbos=True)
99 |
100 | quant_utils.add_actquant(model) # Add Activation Wrapper to the model
101 | qlayers = quant_utils.find_qlayers(model)
102 | for name in qlayers:
103 | if "down_proj" in name:
104 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
105 | qlayers[name].online_full_had = True
106 | qlayers[name].had_K = had_K
107 | qlayers[name].K = K
108 | qlayers[name].fp32_had = args.fp32_had
109 | else:
110 | quant_utils.add_actquant(
111 | model
112 | ) # Add Activation Wrapper to the model as the rest of the code assumes it is present
113 |
114 | if args.enable_aq_calibration:
115 | add_input_quantization(model, args)
116 |
117 | if args.w_bits < 16:
118 | save_dict = {}
119 | if args.load_qmodel_path: # Load Quantized Rotated Model
120 | assert args.rotate, "Model should be rotated to load a quantized model!"
121 | assert (
122 | not args.save_qmodel_path
123 | ), "Cannot save a quantized model if it is already loaded!"
124 | print("Load quantized model from ", args.load_qmodel_path)
125 | save_dict = torch.load(args.load_qmodel_path)
126 | model.load_state_dict(save_dict["model"], strict=False)
127 |
128 | elif not args.w_rtn: # GPTQ Weight Quantization
129 | trainloader = data_utils.get_wikitext2(
130 | nsamples=args.nsamples,
131 | seed=args.seed,
132 | model=model_args.input_model,
133 | seqlen=2048,
134 | eval_mode=False,
135 | )
136 | if args.export_to_et:
137 | # quantize lm_head and embedding with 8bit per-channel quantization with rtn for executorch
138 | quantizers = gptq_utils.rtn_fwrd(
139 | model,
140 | "cuda",
141 | args,
142 | custom_layers=[model.model.embed_tokens, model.lm_head],
143 | )
144 | # quantize other layers with gptq
145 | if args.asym_calibrate:
146 | quantizers = gptaq_utils.gptaq_fwrd(model, trainloader, "cuda", args)
147 | save_dict["w_quantizers"] = quantizers
148 | else:
149 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, "cuda", args)
150 | save_dict["w_quantizers"] = quantizers
151 | else: # RTN Weight Quantization
152 | quantizers = gptq_utils.rtn_fwrd(model, "cuda", args)
153 | save_dict["w_quantizers"] = quantizers
154 |
155 | if args.save_qmodel_path:
156 | save_dict["model"] = model.state_dict()
157 | if args.export_to_et:
158 | save_dict = write_model_llama(
159 | model.state_dict(), model.config, num_shards=1
160 | )[0] # Export num_shards == 1 for executorch
161 | save_dict = sanitize_checkpoint_from_spinquant(
162 | save_dict, group_size=args.w_groupsize
163 | )
164 | torch.save(save_dict, args.save_qmodel_path)
165 |
166 | if not args.enable_aq_calibration:
167 | add_input_quantization(model, args)
168 |
169 | return model
170 |
--------------------------------------------------------------------------------
/spinquant/eval_utils/rotation_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import functools
12 | import math
13 |
14 | import torch
15 | import tqdm
16 |
17 | from utils import monkeypatch, quant_utils, utils
18 | from utils.hadamard_utils import (
19 | apply_exact_had_to_linear,
20 | is_pow2,
21 | random_hadamard_matrix,
22 | )
23 | from utils.utils import HadamardTransform
24 |
25 |
26 | def random_orthogonal_matrix(size, device):
27 | """
28 | Generate a random orthogonal matrix of the specified size.
29 |     First, we generate a random matrix with entries drawn from a standard normal distribution.
30 | Then, we use QR decomposition to obtain an orthogonal matrix.
31 | Finally, we multiply by a diagonal matrix with diag r to adjust the signs.
32 |
33 | Args:
34 | size (int): The size of the matrix (size x size).
35 |
36 | Returns:
37 | torch.Tensor: An orthogonal matrix of the specified size.
38 | """
39 | torch.cuda.empty_cache()
40 | random_matrix = torch.randn(size, size, dtype=torch.float64).to(device)
41 | q, r = torch.linalg.qr(random_matrix)
42 | q *= torch.sign(torch.diag(r)).unsqueeze(0)
43 | return q
44 |
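# Hypothetical sanity check (not part of the original file): the returned matrix should
# be orthogonal up to float64 round-off.
#
#   Q = random_orthogonal_matrix(128, "cuda")
#   I = torch.eye(128, dtype=torch.float64, device=Q.device)
#   assert torch.allclose(Q @ Q.T, I, atol=1e-8)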
45 |
46 | def get_orthogonal_matrix(size, mode, device="cuda"):
47 | if mode == "random":
48 | return random_orthogonal_matrix(size, device)
49 | elif mode == "hadamard":
50 | return random_hadamard_matrix(size, device)
51 | else:
52 | raise ValueError(f"Unknown mode {mode}")
53 |
54 |
55 | def rotate_embeddings(model, R1: torch.Tensor) -> None:
56 | # Rotate the embeddings.
57 | for W in [model.model.embed_tokens]:
58 | dtype = W.weight.data.dtype
59 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64)
60 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)
61 |
62 |
63 | def rotate_attention_inputs(layer, R1) -> None:
64 | # Rotate the WQ, WK and WV matrices of the self-attention layer.
65 | for W in [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]:
66 | dtype = W.weight.dtype
67 | W_ = W.weight.to(device="cuda", dtype=torch.float64)
68 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)
69 |
70 |
71 | def rotate_attention_output(layer, R1) -> None:
72 | # Rotate output matrix of the self-attention layer.
73 | W = layer.self_attn.o_proj
74 |
75 | dtype = W.weight.data.dtype
76 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64)
77 | W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)
78 | if W.bias is not None:
79 | b = W.bias.data.to(device="cuda", dtype=torch.float64)
80 | W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)
81 |
82 |
83 | def rotate_mlp_input(layer, R1):
84 | # Rotate the MLP input weights.
85 | mlp_inputs = [layer.mlp.up_proj, layer.mlp.gate_proj]
86 | for W in mlp_inputs:
87 | dtype = W.weight.dtype
88 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64)
89 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)
90 |
91 |
92 | def rotate_mlp_output(layer, R1):
93 | # Rotate the MLP output weights and bias.
94 | W = layer.mlp.down_proj
95 | dtype = W.weight.data.dtype
96 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64)
97 | W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)
98 | apply_exact_had_to_linear(
99 | W, had_dim=-1, output=False
100 | ) # apply exact (inverse) hadamard on the weights of mlp output
101 | if W.bias is not None:
102 | b = W.bias.data.to(device="cuda", dtype=torch.float64)
103 | W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)
104 |
105 |
106 | def rotate_head(model, R1: torch.Tensor) -> None:
107 | # Rotate the head.
108 | W = model.lm_head
109 | dtype = W.weight.data.dtype
110 | W_ = W.weight.data.to(device="cuda", dtype=torch.float64)
111 | W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)
112 |
113 |
114 | def rotate_ov_proj(layer, head_num, head_dim, R2=None):
115 | v_proj = layer.self_attn.v_proj
116 | o_proj = layer.self_attn.o_proj
117 |
118 | apply_exact_had_to_linear(v_proj, had_dim=head_dim, output=True, R2=R2)
119 | apply_exact_had_to_linear(o_proj, had_dim=head_dim, output=False, R2=R2)
120 |
121 |
122 | @torch.inference_mode()
123 | def rotate_model(model, args):
124 | R1 = get_orthogonal_matrix(model.config.hidden_size, args.rotate_mode)
125 | if args.optimized_rotation_path is not None:
126 | R_cpk = args.optimized_rotation_path
127 | R1 = torch.load(R_cpk)["R1"].cuda().to(torch.float64)
128 | config = model.config
129 | num_heads = config.num_attention_heads
130 | model_dim = config.hidden_size
131 | head_dim = model_dim // num_heads
132 |
133 | rotate_embeddings(model, R1)
134 | rotate_head(model, R1)
135 | utils.cleanup_memory()
136 | layers = [layer for layer in model.model.layers]
137 | for idx, layer in enumerate(tqdm.tqdm(layers, unit="layer", desc="Rotating")):
138 | if args.optimized_rotation_path is not None:
139 | key = f"model.layers.{idx}.self_attn.R2"
140 | R2 = torch.load(R_cpk)[key].cuda().to(torch.float64)
141 | else:
142 | R2 = get_orthogonal_matrix(head_dim, args.rotate_mode)
143 | rotate_attention_inputs(layers[idx], R1)
144 | rotate_attention_output(layers[idx], R1)
145 | rotate_mlp_input(layers[idx], R1)
146 | rotate_mlp_output(layers[idx], R1)
147 | rotate_ov_proj(layers[idx], num_heads, head_dim, R2=R2)
148 |
149 |
150 | class QKRotationWrapper(torch.nn.Module):
151 | def __init__(self, func, config, *args, **kwargs):
152 | super().__init__()
153 | self.config = config
154 | num_heads = config.num_attention_heads
155 | model_dim = config.hidden_size
156 | head_dim = model_dim // num_heads
157 | assert is_pow2(
158 | head_dim
159 | ), f"Only power of 2 head_dim is supported for K-cache Quantization!"
160 | self.func = func
161 | self.k_quantizer = quant_utils.ActQuantizer()
162 | self.k_bits = 16
163 | if kwargs is not None:
164 | assert kwargs["k_groupsize"] in [
165 | -1,
166 | head_dim,
167 | ], f"Only token-wise/{head_dim}g quantization is supported for K-cache"
168 | self.k_bits = kwargs["k_bits"]
169 | self.k_groupsize = kwargs["k_groupsize"]
170 | self.k_sym = kwargs["k_sym"]
171 | self.k_clip_ratio = kwargs["k_clip_ratio"]
172 | self.k_quantizer.configure(
173 | bits=self.k_bits,
174 |                 groupsize=-1,  # pass -1 for token-wise quantization; head-wise grouping is handled manually in forward
175 | sym=self.k_sym,
176 | clip_ratio=self.k_clip_ratio,
177 | )
178 |
179 | def forward(self, *args, **kwargs):
180 | q, k = self.func(*args, **kwargs)
181 | dtype = q.dtype
182 | q = (HadamardTransform.apply(q.float()) / math.sqrt(q.shape[-1])).to(dtype)
183 | k = (HadamardTransform.apply(k.float()) / math.sqrt(k.shape[-1])).to(dtype)
184 | (bsz, num_heads, seq_len, head_dim) = k.shape
185 |
186 | if self.k_groupsize == -1: # token-wise quantization
187 | token_wise_k = k.transpose(1, 2).reshape(-1, num_heads * head_dim)
188 | self.k_quantizer.find_params(token_wise_k)
189 | k = (
190 | self.k_quantizer(token_wise_k)
191 | .reshape((bsz, seq_len, num_heads, head_dim))
192 | .transpose(1, 2)
193 | .to(q)
194 | )
195 | else: # head-wise quantization
196 | per_head_k = k.view(-1, head_dim)
197 | self.k_quantizer.find_params(per_head_k)
198 | k = (
199 | self.k_quantizer(per_head_k)
200 | .reshape((bsz, num_heads, seq_len, head_dim))
201 | .to(q)
202 | )
203 |
204 | self.k_quantizer.free()
205 |
206 | return q, k
207 |
208 |
209 | def add_qk_rotation_wrapper_after_function_call_in_forward(
210 | module,
211 | function_name,
212 | *args,
213 | **kwargs,
214 | ):
215 | """
216 | This function adds a rotation wrapper after the output of a function call in forward.
217 |     Only calls made directly in the forward function are affected; calls made by other functions invoked from forward are not affected.
218 | """
219 |
220 | attr_name = f"{function_name}_qk_rotation_wrapper"
221 | assert not hasattr(module, attr_name)
222 | wrapper = monkeypatch.add_wrapper_after_function_call_in_method(
223 | module,
224 | "forward",
225 | function_name,
226 | functools.partial(QKRotationWrapper, *args, **kwargs),
227 | )
228 | setattr(module, attr_name, wrapper)
229 |
--------------------------------------------------------------------------------
/spinquant/optimize_rotation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import datetime
9 | import os
10 | from logging import Logger
11 |
12 | import datasets
13 | import torch
14 | import torch.distributed as dist
15 | from torch import nn
16 | from transformers import LlamaTokenizerFast, Trainer, default_data_collator
17 | import transformers
18 | from train_utils.fsdp_trainer import FSDPTrainer
19 | from train_utils.main import prepare_model
20 | from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
21 | from train_utils.optimizer import SGDG
22 | from utils.data_utils import CustomJsonDataset
23 | from utils.hadamard_utils import random_hadamard_matrix
24 | from utils.process_args import process_args_ptq
25 | from utils.utils import get_local_rank, get_logger, pt_fsdp_state_dict
26 |
27 | log: Logger = get_logger("spinquant")
28 |
29 |
30 | class RotateModule(nn.Module):
31 | def __init__(self, R_init):
32 | super(RotateModule, self).__init__()
33 | self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda")))
34 |
35 | def forward(self, x, transpose=False):
36 | if transpose:
37 | return x @ self.weight
38 | else:
39 | return self.weight @ x
40 |
41 |
42 | def train() -> None:
43 | dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
44 | model_args, training_args, ptq_args = process_args_ptq()
45 | local_rank = get_local_rank()
46 |
47 | log.info("the rank is {}".format(local_rank))
48 | torch.distributed.barrier()
49 |
50 | config = transformers.AutoConfig.from_pretrained(
51 | model_args.input_model, token=model_args.access_token
52 | )
53 |
54 |     # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings, so clone lm_head from embed_tokens
55 | process_word_embeddings = False
56 | if config.tie_word_embeddings:
57 | config.tie_word_embeddings = False
58 | process_word_embeddings = True
59 | dtype = torch.bfloat16 if training_args.bf16 else torch.float16
60 | model = LlamaForCausalLMQuant.from_pretrained(
61 | pretrained_model_name_or_path=model_args.input_model,
62 | config=config,
63 | torch_dtype=dtype,
64 | token=model_args.access_token,
65 | )
66 | if process_word_embeddings:
67 | model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
68 |
69 | model = prepare_model(ptq_args, model)
70 | for param in model.parameters():
71 | param.requires_grad = False
72 | R1 = random_hadamard_matrix(model.config.hidden_size, "cuda")
73 | model.R1 = RotateModule(R1)
74 | for i in range(model.config.num_hidden_layers):
75 | # Each head dim = 128 for Llama model
76 | R2 = random_hadamard_matrix(
77 | model.config.hidden_size // model.config.num_attention_heads, "cuda"
78 | )
79 | model.model.layers[i].self_attn.R2 = RotateModule(R2)
80 | if local_rank == 0:
81 | log.info("Model init completed for training {}".format(model))
82 | log.info("Start to load tokenizer...")
83 | tokenizer = LlamaTokenizerFast.from_pretrained(
84 | pretrained_model_name_or_path=model_args.input_model,
85 | cache_dir=training_args.cache_dir,
86 | model_max_length=training_args.model_max_length,
87 | padding_side="right",
88 | use_fast=True,
89 | add_eos_token=False,
90 | add_bos_token=False,
91 | token=model_args.access_token,
92 | )
93 | log.info("Complete tokenizer loading...")
94 | model.config.use_cache = False
95 | calibration_datasets = datasets.load_dataset(
96 | "Salesforce/wikitext", "wikitext-2-raw-v1"
97 | )
98 | train_data = CustomJsonDataset(
99 | calibration_datasets["train"],
100 | tokenizer,
101 | block_size=min(training_args.model_max_length, 2048),
102 | )
103 |
104 | trainable_parameters = [model.R1.weight] + [
105 | model.model.layers[i].self_attn.R2.weight
106 | for i in range(model.config.num_hidden_layers)
107 | ]
108 | model.seqlen = training_args.model_max_length
109 | optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True)
110 | MyTrainer = Trainer
111 | # Use FSDP for 70B rotation training
112 | if training_args.fsdp != "" and training_args.fsdp != []:
113 | MyTrainer = FSDPTrainer
114 |
115 | trainer = MyTrainer(
116 | model=model,
117 | tokenizer=tokenizer,
118 | args=training_args,
119 | train_dataset=train_data,
120 | eval_dataset=None,
121 | data_collator=default_data_collator,
122 | optimizers=(optimizer, None),
123 | )
124 | torch.distributed.barrier()
125 |
126 | trainer.train()
127 | if training_args.fsdp != "" and training_args.fsdp != []:
128 | cpu_state = pt_fsdp_state_dict(trainer.model)
129 | else:
130 | cpu_state = trainer.model.state_dict()
131 |
132 | R_dict = {
133 | key.replace(".weight", ""): value
134 | for key, value in cpu_state.items()
135 | if "R1.weight" in key or "self_attn.R2" in key
136 | }
137 | if local_rank == 0:
138 | os.makedirs(model_args.output_rotation_path, exist_ok=True)
139 | path = os.path.join(model_args.output_rotation_path, "R.bin")
140 | torch.save(
141 | R_dict,
142 | path,
143 | )
144 | dist.barrier()
145 |
146 |
147 | if __name__ == "__main__":
148 | train()
149 |
--------------------------------------------------------------------------------
/spinquant/ptq.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import datetime
9 | from logging import Logger
10 |
11 | import torch
12 | import torch.distributed as dist
13 | from transformers import LlamaTokenizerFast
14 | import transformers
15 | from eval_utils.main import ptq_model
16 | from eval_utils.modeling_llama import LlamaForCausalLM
17 | from utils import data_utils, eval_utils, utils
18 | from utils.process_args import process_args_ptq
19 |
20 | log: Logger = utils.get_logger("spinquant")
21 |
22 |
23 | def train() -> None:
24 | dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
25 | model_args, training_args, ptq_args = process_args_ptq()
26 | local_rank = utils.get_local_rank()
27 |
28 | log.info("the rank is {}".format(local_rank))
29 | torch.distributed.barrier()
30 |
31 | config = transformers.AutoConfig.from_pretrained(
32 | model_args.input_model, token=model_args.access_token
33 | )
34 |     # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings, so clone lm_head from embed_tokens
35 | process_word_embeddings = False
36 | if config.tie_word_embeddings:
37 | config.tie_word_embeddings = False
38 | process_word_embeddings = True
39 | dtype = torch.bfloat16 if training_args.bf16 else torch.float16
40 | model = LlamaForCausalLM.from_pretrained(
41 | pretrained_model_name_or_path=model_args.input_model,
42 | config=config,
43 | torch_dtype=dtype,
44 | token=model_args.access_token,
45 | )
46 | if process_word_embeddings:
47 | model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
48 | # model.cuda()
49 |
50 | model = ptq_model(ptq_args, model, model_args)
51 | model.seqlen = training_args.model_max_length
52 | if local_rank == 0:
53 | # log.info("Model PTQ completed {}".format(model))
54 | log.info("Start to load tokenizer...")
55 | tokenizer = LlamaTokenizerFast.from_pretrained(
56 | pretrained_model_name_or_path=model_args.input_model,
57 | cache_dir=training_args.cache_dir,
58 | model_max_length=training_args.model_max_length,
59 | padding_side="right",
60 | use_fast=True,
61 | add_eos_token=False,
62 | add_bos_token=False,
63 | token=model_args.access_token,
64 | )
65 | log.info("Complete tokenizer loading...")
66 | model.config.use_cache = False
67 |
68 | testloader = data_utils.get_wikitext2(
69 | seed=ptq_args.seed,
70 | seqlen=2048,
71 | tokenizer=tokenizer,
72 | eval_mode=True,
73 | )
74 |
75 | dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, ptq_args)
76 | log.info("wiki2 ppl is: {}".format(dataset_ppl))
77 |
78 | if not ptq_args.lm_eval:
79 | return
80 | else:
81 | # Import lm_eval utils
82 | import lm_eval
83 | from lm_eval import utils as lm_eval_utils
84 | from lm_eval.api.registry import ALL_TASKS
85 | from lm_eval.models.huggingface import HFLM
86 |
87 | if ptq_args.distribute:
88 | utils.distribute_model(model)
89 | else:
90 | model.to(utils.DEV)
91 |
92 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.input_model, use_fast=False, use_auth_token=None)
93 | hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=ptq_args.lm_eval_batch_size)
94 |
95 | # commenting out this line as it will include two lambda sub-tasks
96 | # task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS)
97 | task_names = ptq_args.tasks
98 | results = lm_eval.simple_evaluate(hflm, tasks=task_names, batch_size=ptq_args.lm_eval_batch_size)['results']
99 |
100 | metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()}
101 | metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals.values()), 4)
102 | print(metric_vals)
103 |
104 |
105 | dist.barrier()
106 |
107 |
108 | if __name__ == "__main__":
109 | train()
110 |
--------------------------------------------------------------------------------
/spinquant/requirement.txt:
--------------------------------------------------------------------------------
1 | transformers==4.44.2
2 | accelerate==0.31.0
3 | datasets==2.20.0
4 | sentencepiece
5 | tensorboardX
6 |
--------------------------------------------------------------------------------
/spinquant/scripts/10_optimize_rotation.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
9 | # nproc_per_node indicates the number of GPUs per node to employ.
10 | torchrun --nnodes=1 --nproc_per_node=8 optimize_rotation.py \
11 | --input_model $1 \
12 | --output_rotation_path "your_path" \
13 | --output_dir "your_output_path/" \
14 | --logging_dir "your_log_path/" \
15 | --model_max_length 2048 \
16 | --fp16 False \
17 | --bf16 True \
18 | --log_on_each_node False \
19 | --per_device_train_batch_size 1 \
20 | --logging_steps 1 \
21 | --learning_rate 1.5 \
22 | --weight_decay 0. \
23 | --lr_scheduler_type "cosine" \
24 | --gradient_checkpointing True \
25 | --save_safetensors False \
26 | --max_steps 100 \
27 | --w_bits $2 \
28 | --a_bits $3 \
29 | --k_bits $4 \
30 | --v_bits $4 \
31 | --w_clip \
32 | --a_asym \
33 | --k_asym \
34 | --v_asym \
35 | --k_groupsize 128 \
36 | --v_groupsize 128 \
37 |
--------------------------------------------------------------------------------
/spinquant/scripts/11_optimize_rotation_fsdp.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
9 | # nproc_per_node indicates the number of GPUs per node to employ.
10 | torchrun --nnodes=1 --nproc_per_node=8 optimize_rotation.py \
11 | --input_model $1 \
12 | --output_rotation_path "your_path" \
13 | --output_dir "your_output_path/" \
14 | --logging_dir "your_log_path/" \
15 | --model_max_length 2048 \
16 | --fp16 False \
17 | --bf16 True \
18 | --log_on_each_node False \
19 | --per_device_train_batch_size 1 \
20 | --logging_steps 1 \
21 | --learning_rate 1.5 \
22 | --weight_decay 0. \
23 | --lr_scheduler_type "cosine" \
24 | --gradient_checkpointing True \
25 | --max_steps 100 \
26 | --w_bits $2 \
27 | --a_bits $3 \
28 | --k_bits $4 \
29 | --v_bits $4 \
30 | --w_clip \
31 | --a_asym \
32 | --k_asym \
33 | --v_asym \
34 | --k_groupsize 128 \
35 | --v_groupsize 128 \
36 | --fsdp "full_shard auto_wrap" \
37 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'
38 |
--------------------------------------------------------------------------------
/spinquant/scripts/2_eval_ptq.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
9 | # nproc_per_node indicates the number of GPUs per node to employ.
10 |
11 |
12 | torchrun --nnodes=1 --nproc_per_node=1 ptq.py \
13 | --input_model "meta-llama/Meta-Llama-3-8B" \
14 | --do_train False \
15 | --do_eval True \
16 | --per_device_eval_batch_size 16 \
17 | --model_max_length 2048 \
18 | --fp16 False \
19 | --bf16 True \
20 | --save_safetensors False \
21 | --w_bits 4 \
22 | --a_bits 4 \
23 | --k_bits 16 \
24 | --v_bits 16 \
25 | --w_clip \
26 | --a_asym \
27 | --k_asym \
28 | --v_asym \
29 | --k_groupsize 128 \
30 | --v_groupsize 128 \
31 | --rotate \
32 | --optimized_rotation_path "ckpts/8B_W16A4KV16_lr_1.5_seed_0/R.bin" \
33 | --asym_calibrate \
34 | --enable_aq_calibration \
35 |
--------------------------------------------------------------------------------
/spinquant/scripts/31_optimize_rotation_executorch.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
9 | # nproc_per_node indicates the number of GPUs per node to employ.
10 | torchrun --nnodes=1 --nproc_per_node=8 optimize_rotation.py \
11 | --input_model $1 \
12 | --output_rotation_path "your_path" \
13 | --output_dir "your_output_path/" \
14 | --logging_dir "your_log_path/" \
15 | --model_max_length 2048 \
16 | --fp16 False \
17 | --bf16 True \
18 | --log_on_each_node False \
19 | --per_device_train_batch_size 1 \
20 | --logging_steps 1 \
21 | --learning_rate 1.5 \
22 | --weight_decay 0. \
23 | --lr_scheduler_type "cosine" \
24 | --gradient_checkpointing True \
25 | --save_safetensors False \
26 | --max_steps 100 \
27 | --w_bits 16 \
28 | --a_bits 8 \
29 | --w_clip \
30 | --a_asym \
31 | --w_groupsize 32
32 |
--------------------------------------------------------------------------------
/spinquant/scripts/32_eval_ptq_executorch.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # nnodes determines the number of GPU nodes to utilize (usually 1 for an 8 GPU node)
9 | # nproc_per_node indicates the number of GPUs per node to employ.
10 | torchrun --nnodes=1 --nproc_per_node=1 ptq.py \
11 | --input_model $1 \
12 | --do_train False \
13 | --do_eval True \
14 | --per_device_eval_batch_size 4 \
15 | --model_max_length 2048 \
16 | --fp16 False \
17 | --bf16 True \
18 | --save_safetensors False \
19 | --w_bits 4 \
20 | --a_bits 8 \
21 | --w_clip \
22 | --w_groupsize 32 \
23 | --a_asym \
24 | --rotate \
25 | --optimized_rotation_path "your_path/R.bin" \
26 | --save_qmodel_path "./your_output_model_path/consolidated.00.pth" \
27 | --export_to_et
28 |
--------------------------------------------------------------------------------
/spinquant/train_utils/apply_r3_r4.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import math
12 |
13 | import torch
14 | import tqdm
15 |
16 | from utils import quant_utils, utils
17 | from utils.hadamard_utils import (
18 | apply_exact_had_to_linear,
19 | is_pow2,
20 | )
21 | from utils.utils import HadamardTransform
22 |
23 |
24 | def R4_rotate_down_proj_weights(layer):
25 | # Rotate the MLP output weights and bias.
26 | W = layer.mlp.down_proj
27 | apply_exact_had_to_linear(
28 | W, had_dim=-1, output=False
29 | ) # apply exact (inverse) hadamard on the weights of mlp output
30 |
31 |
32 | @torch.inference_mode()
33 | def rotate_model(model, args):
34 | config = model.config
35 | num_heads = config.num_attention_heads
36 | model_dim = config.hidden_size
37 | head_dim = model_dim // num_heads
38 |
39 | utils.cleanup_memory()
40 | layers = [layer for layer in model.model.layers]
41 | for idx, layer in enumerate(
42 | tqdm.tqdm(layers, unit="layer", desc="Applying R4 rotation to W_down")
43 | ):
44 | R4_rotate_down_proj_weights(layers[idx])
45 |
46 |
47 | class QKRotationWrapper(torch.nn.Module):
48 | def __init__(self, func, config, *args, **kwargs):
49 | super().__init__()
50 | self.config = config
51 | num_heads = config.num_attention_heads
52 | model_dim = config.hidden_size
53 | head_dim = model_dim // num_heads
54 | assert is_pow2(
55 | head_dim
56 | ), f"Only power of 2 head_dim is supported for K-cache Quantization!"
57 | self.func = func
58 | self.k_quantizer = quant_utils.ActQuantizer()
59 | self.k_bits = 16
60 | if kwargs is not None:
61 | assert kwargs["k_groupsize"] in [
62 | -1,
63 | head_dim,
64 | ], f"Only token-wise/{head_dim}g quantization is supported for K-cache"
65 | self.k_bits = kwargs["k_bits"]
66 | self.k_groupsize = kwargs["k_groupsize"]
67 | self.k_sym = kwargs["k_sym"]
68 | self.k_clip_ratio = kwargs["k_clip_ratio"]
69 | self.k_quantizer.configure(
70 | bits=self.k_bits,
71 |                 groupsize=-1,  # pass -1 for token-wise quantization; head-wise grouping is handled manually in forward
72 | sym=self.k_sym,
73 | clip_ratio=self.k_clip_ratio,
74 | )
75 |
76 | def forward(self, *args, **kwargs):
77 | q, k = self.func(*args, **kwargs)
78 | dtype = q.dtype
79 | q = (HadamardTransform.apply(q.float()) / math.sqrt(q.shape[-1])).to(dtype)
80 | k = (HadamardTransform.apply(k.float()) / math.sqrt(k.shape[-1])).to(dtype)
81 | (bsz, num_heads, seq_len, head_dim) = k.shape
82 |
83 | if self.k_groupsize == -1: # token-wise quantization
84 | token_wise_k = k.transpose(1, 2).reshape(-1, num_heads * head_dim)
85 | self.k_quantizer.find_params(token_wise_k)
86 | k = (
87 | self.k_quantizer(token_wise_k)
88 | .reshape((bsz, seq_len, num_heads, head_dim))
89 | .transpose(1, 2)
90 | .to(q)
91 | )
92 | else: # head-wise quantization
93 | per_head_k = k.view(-1, head_dim)
94 | self.k_quantizer.find_params(per_head_k)
95 | k = (
96 | self.k_quantizer(per_head_k)
97 | .reshape((bsz, num_heads, seq_len, head_dim))
98 | .to(q)
99 | )
100 |
101 | self.k_quantizer.free()
102 |
103 | return q, k
104 |
105 |
106 | def add_qk_rotation_wrapper_after_function_call_in_forward(
107 | module,
108 | function_name,
109 | *args,
110 | **kwargs,
111 | ):
112 | """
113 | This function adds a rotation wrapper after the output of a function call in forward.
114 | Only calls made directly in the forward function are affected; calls made by functions that forward itself calls are not.
115 | """
116 | import functools
117 |
118 | from utils import monkeypatch
119 |
120 | attr_name = f"{function_name}_qk_rotation_wrapper"
121 | assert not hasattr(module, attr_name)
122 | wrapper = monkeypatch.add_wrapper_after_function_call_in_method(
123 | module,
124 | "forward",
125 | function_name,
126 | functools.partial(QKRotationWrapper, *args, **kwargs),
127 | )
128 | setattr(module, attr_name, wrapper)
129 |
--------------------------------------------------------------------------------
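The `R4_rotate_down_proj_weights` pass above folds an exact Hadamard rotation into the `down_proj` weight offline; the matching online Hadamard transform of the activation (enabled via `online_full_had` in `train_utils/main.py`) cancels it at run time. A minimal sketch of that cancellation, using a plain orthonormal Hadamard matrix from SciPy rather than the repo's `hadamard_utils` (sizes are illustrative):

```python
import numpy as np
from scipy.linalg import hadamard

n = 64                                    # assume a power-of-two activation dimension
rng = np.random.default_rng(0)
W = rng.standard_normal((16, n))          # stand-in for down_proj.weight (out_features, in_features)
x = rng.standard_normal(n)                # stand-in for the MLP activation entering down_proj

H = hadamard(n) / np.sqrt(n)              # orthonormal Hadamard matrix: H @ H.T == I
W_rot = W @ H.T                           # offline rotation of the weight (R4)
x_rot = H @ x                             # online rotation of the activation

assert np.allclose(W @ x, W_rot @ x_rot)  # layer output is unchanged
```
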
/spinquant/train_utils/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import transformers
12 |
13 | from train_utils import apply_r3_r4, rtn_utils
14 | from utils import fuse_norm_utils, hadamard_utils, quant_utils, utils
15 |
16 |
17 | def prepare_model(args, model):
18 | transformers.set_seed(args.seed)
19 | model.eval()
20 |
21 | # Rotate the weights
22 | fuse_norm_utils.fuse_layer_norms(model)
23 | apply_r3_r4.rotate_model(model, args)
24 | utils.cleanup_memory(verbos=True)
25 |
26 | quant_utils.add_actquant(model) # Add Activation Wrapper to the model
27 | qlayers = quant_utils.find_qlayers(model)
28 | for name in qlayers:
29 | if "down_proj" in name:
30 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
31 | qlayers[name].online_full_had = True
32 | qlayers[name].had_K = had_K
33 | qlayers[name].K = K
34 | qlayers[name].fp32_had = args.fp32_had
35 |
36 | if args.w_bits < 16:
37 | quantizers = rtn_utils.rtn_fwrd(model, "cuda", args)
38 |
39 | # Add Input Quantization
40 | if args.a_bits < 16 or args.v_bits < 16:
41 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
42 | down_proj_groupsize = -1
43 | if args.a_groupsize > 0:
44 | down_proj_groupsize = utils.llama_down_proj_groupsize(
45 | model, args.a_groupsize
46 | )
47 |
48 | for name in qlayers:
49 | layer_input_bits = args.a_bits
50 | layer_groupsize = args.a_groupsize
51 | layer_a_sym = not (args.a_asym)
52 | layer_a_clip = args.a_clip_ratio
53 |
54 | num_heads = model.config.num_attention_heads
55 | model_dim = model.config.hidden_size
56 | head_dim = model_dim // num_heads
57 |
58 | if "v_proj" in name and args.v_bits < 16: # Set the v_proj precision
59 | v_groupsize = head_dim
60 | qlayers[name].out_quantizer.configure(
61 | bits=args.v_bits,
62 | groupsize=v_groupsize,
63 | sym=not (args.v_asym),
64 | clip_ratio=args.v_clip_ratio,
65 | )
66 |
67 | if "o_proj" in name:
68 | layer_groupsize = head_dim
69 |
70 | if "lm_head" in name: # Skip lm_head quantization
71 | layer_input_bits = 16
72 |
73 | if "down_proj" in name: # Set the down_proj precision
74 | if args.int8_down_proj:
75 | layer_input_bits = 8
76 | layer_groupsize = down_proj_groupsize
77 |
78 | qlayers[name].quantizer.configure(
79 | bits=layer_input_bits,
80 | groupsize=layer_groupsize,
81 | sym=layer_a_sym,
82 | clip_ratio=layer_a_clip,
83 | )
84 |
85 | if args.k_bits < 16:
86 | if args.k_pre_rope:
87 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!")
88 | else:
89 | rope_function_name = "apply_rotary_pos_emb"
90 | layers = model.model.layers
91 | k_quant_config = {
92 | "k_bits": args.k_bits,
93 | "k_groupsize": args.k_groupsize,
94 | "k_sym": not (args.k_asym),
95 | "k_clip_ratio": args.k_clip_ratio,
96 | }
97 | for layer in layers:
98 | apply_r3_r4.add_qk_rotation_wrapper_after_function_call_in_forward(
99 | layer.self_attn,
100 | rope_function_name,
101 | config=model.config,
102 | **k_quant_config,
103 | )
104 |
105 | return model
106 |
--------------------------------------------------------------------------------
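`prepare_model` above only configures the quantizers (`bits`, `groupsize`, `sym`, `clip_ratio`); the fake quantization itself lives in `quant_utils`. As a rough mental model, here is a minimal stand-alone sketch of symmetric token-wise activation fake quantization (an illustration only, not the repo's `ActQuantizer`):

```python
import torch

def fake_quant_sym(x: torch.Tensor, bits: int = 8, clip_ratio: float = 1.0) -> torch.Tensor:
    # Symmetric token-wise fake quantization: per-row scale, round, clamp, dequantize.
    qmax = 2 ** (bits - 1) - 1
    scale = (x.abs().amax(dim=-1, keepdim=True) * clip_ratio).clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

x = torch.randn(4, 128)
print((x - fake_quant_sym(x, bits=8)).abs().max())   # small quantization error
```
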
/spinquant/train_utils/optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is originally from: https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py
9 |
10 | import random
11 |
12 | import torch
13 | from torch.optim.optimizer import Optimizer
14 |
15 |
16 | def unit(v, dim: int = 1, eps: float = 1e-8):
17 | vnorm = norm(v, dim)
18 | return v / vnorm.add(eps), vnorm
19 |
20 |
21 | def norm(v, dim: int = 1):
22 | assert len(v.size()) == 2
23 | return v.norm(p=2, dim=dim, keepdim=True)
24 |
25 |
26 | def matrix_norm_one(W):
27 | out = torch.abs(W)
28 | out = torch.sum(out, dim=0)
29 | out = torch.max(out)
30 | return out
31 |
32 |
33 | def Cayley_loop(X, W, tan_vec, t): #
34 | [n, p] = X.size()
35 | Y = X + t * tan_vec
36 | for i in range(5):
37 | Y = X + t * torch.matmul(W, 0.5 * (X + Y))
38 |
39 | return Y.t()
40 |
41 |
42 | def qr_retraction(tan_vec): # tan_vec, p-by-n, p <= n
43 | [p, n] = tan_vec.size()
44 | tan_vec.t_()
45 | q, r = torch.linalg.qr(tan_vec)
46 | d = torch.diag(r, 0)
47 | ph = d.sign()
48 | q *= ph.expand_as(q)
49 | q.t_()
50 |
51 | return q
52 |
53 |
54 | epsilon = 1e-8
55 |
56 |
57 | class SGDG(Optimizer):
58 | r"""This optimizer updates variables with two different routines
59 | based on the boolean variable 'stiefel'.
60 |
61 | If stiefel is True, the variables will be updated by SGD-G, which performs
62 | Riemannian SGD on the Stiefel manifold (via a Cayley transform) to keep the weight matrix orthonormal.
63 |
64 | If stiefel is False, the variables will be updated by SGD.
65 | This routine was taken from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py.
66 |
67 | Args:
68 | params (iterable): iterable of parameters to optimize or dicts defining
69 | parameter groups
70 |
71 | -- common parameters
72 | lr (float): learning rate
73 | momentum (float, optional): momentum factor (default: 0)
74 | stiefel (bool, optional): whether to use SGD-G (default: False)
75 |
76 | -- parameters in case stiefel is False
77 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
78 | dampening (float, optional): dampening for momentum (default: 0)
79 | nesterov (bool, optional): enables Nesterov momentum (default: False)
80 |
81 | -- parameters in case stiefel is True
82 | omega (float, optional): orthogonality regularization factor (default: 0)
83 | grad_clip (float, optional): threshold for gradient norm clipping (default: None)
84 | """
85 |
86 | def __init__(
87 | self,
88 | params,
89 | lr,
90 | momentum: int = 0,
91 | dampening: int = 0,
92 | weight_decay: int = 0,
93 | nesterov: bool = False,
94 | stiefel: bool = False,
95 | omega: int = 0,
96 | grad_clip=None,
97 | ) -> None:
98 | defaults = dict(
99 | lr=lr,
100 | momentum=momentum,
101 | dampening=dampening,
102 | weight_decay=weight_decay,
103 | nesterov=nesterov,
104 | stiefel=stiefel,
105 | omega=0,
106 | grad_clip=grad_clip,
107 | )
108 | if nesterov and (momentum <= 0 or dampening != 0):
109 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
110 | super(SGDG, self).__init__(params, defaults)
111 |
112 | def __setstate__(self, state) -> None:
113 | super(SGDG, self).__setstate__(state)
114 | for group in self.param_groups:
115 | group.setdefault("nesterov", False)
116 |
117 | def step(self, closure=None):
118 | """Performs a single optimization step.
119 |
120 | Arguments:
121 | closure (callable, optional): A closure that reevaluates the model
122 | and returns the loss.
123 | """
124 | loss = None
125 | if closure is not None:
126 | loss = closure()
127 |
128 | for group in self.param_groups:
129 | momentum = group["momentum"]
130 | stiefel = group["stiefel"]
131 |
132 | for p in group["params"]:
133 | if p.grad is None:
134 | continue
135 |
136 | unity, _ = unit(p.data.view(p.size()[0], -1))
137 | if stiefel and unity.size()[0] <= unity.size()[1]:
138 | weight_decay = group["weight_decay"]
139 | dampening = group["dampening"]
140 | nesterov = group["nesterov"]
141 |
142 | rand_num = random.randint(1, 101)
143 | if rand_num == 1:
144 | unity = qr_retraction(unity)
145 |
146 | g = p.grad.data.view(p.size()[0], -1)
147 |
148 | lr = group["lr"]
149 |
150 | param_state = self.state[p]
151 | if "momentum_buffer" not in param_state:
152 | param_state["momentum_buffer"] = torch.zeros(g.t().size())
153 | if p.is_cuda:
154 | param_state["momentum_buffer"] = param_state[
155 | "momentum_buffer"
156 | ].cuda()
157 |
158 | V = param_state["momentum_buffer"]
159 | V = momentum * V - g.t()
160 | MX = torch.mm(V, unity)
161 | XMX = torch.mm(unity, MX)
162 | XXMX = torch.mm(unity.t(), XMX)
163 | W_hat = MX - 0.5 * XXMX
164 | W = W_hat - W_hat.t()
165 | t = 0.5 * 2 / (matrix_norm_one(W) + epsilon)
166 | alpha = min(t, lr)
167 |
168 | p_new = Cayley_loop(unity.t(), W, V, alpha)
169 | V_new = torch.mm(W, unity.t()) # n-by-p
170 | # check_identity(p_new.t())
171 | p.data.copy_(p_new.view(p.size()))
172 | V.copy_(V_new)
173 |
174 | else:
175 | # Plain SGD branch (see the PyTorch SGD reference in the class docstring).
176 | weight_decay = group["weight_decay"]
177 | dampening = group["dampening"]
178 | nesterov = group["nesterov"]
179 | d_p = p.grad.data
180 | if weight_decay != 0:
181 | d_p.add_(p.data, alpha=weight_decay)
182 |
183 | if momentum != 0:
184 | param_state = self.state[p]
185 | if "momentum_buffer" not in param_state:
186 | buf = param_state["momentum_buffer"] = d_p.clone()
187 | else:
188 | buf = param_state["momentum_buffer"]
189 | # Momentum with dampening.
190 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
191 |
192 | if nesterov:
193 | d_p = d_p.add(buf, alpha=momentum)
194 | else:
195 | d_p = buf
196 |
197 | p.data.add_(d_p, alpha=-group["lr"])
198 |
199 | return loss
200 |
--------------------------------------------------------------------------------
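`SGDG.step` keeps rotation parameters orthogonal by moving on the Stiefel manifold with a Cayley transform, approximated by the fixed-point iteration in `Cayley_loop`. The property it relies on is that the Cayley transform of a skew-symmetric matrix is orthogonal; a closed-form sketch of just that fact (illustrative, not the iterative update used above):

```python
import torch

n = 8
A = torch.randn(n, n, dtype=torch.float64)
W = A - A.T                                          # skew-symmetric, like W_hat - W_hat.t() in step()
I = torch.eye(n, dtype=torch.float64)
Q = torch.linalg.solve(I - 0.5 * W, I + 0.5 * W)     # Cayley transform of W
print(torch.allclose(Q @ Q.T, I, atol=1e-10))        # True: Q is orthogonal
```
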
/spinquant/train_utils/quant_linear.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch._tensor import Tensor
11 |
12 |
13 | class QuantizeLinear(nn.Linear):
14 | def forward(
15 | self,
16 | input: Tensor,
17 | R1=None,
18 | R2=None,
19 | transpose=False,
20 | ) -> Tensor:
21 | # quantize weight
22 | if R1 is not None:
23 | dtype = self.weight.dtype
24 | if not transpose:
25 | weight = (self.weight.to(torch.float64) @ R1.to(torch.float64)).to(
26 | dtype
27 | )
28 | else:
29 | weight = (R1.T.to(torch.float64) @ self.weight.to(torch.float64)).to(
30 | dtype
31 | )
32 | if R2 is not None:
33 | # Each head dim = 128 for Llama model
34 | had_dim = R2.shape[0]
35 | dtype = weight.dtype
36 | if transpose:
37 | W_ = weight
38 | init_shape = W_.shape
39 | temp = W_.reshape(-1, init_shape[-1] // had_dim, had_dim)
40 | temp = temp.to(torch.float64) @ R2.to(torch.float64)
41 | weight = temp.reshape(init_shape)
42 | else:
43 | W_ = weight.t()
44 | transposed_shape = W_.shape
45 | temp = W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim)
46 | temp = temp.to(torch.float64) @ R2.to(torch.float64)
47 | weight = temp.reshape(transposed_shape).t()
48 | weight = weight.to(dtype)
49 | else:
50 | weight = self.weight
51 | if hasattr(self, "quantizer"):
52 | dtype = weight.dtype
53 | self.quantizer.find_params(weight.data)
54 | weight = self.quantizer.quantize(weight).to(dtype)
55 |
56 | return nn.functional.linear(input, weight, self.bias)
57 |
--------------------------------------------------------------------------------
/spinquant/train_utils/rtn_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import torch
12 | import tqdm
13 |
14 | from train_utils.quant_linear import QuantizeLinear
15 | from utils import quant_utils, utils
16 |
17 |
18 | @torch.no_grad()
19 | def rtn_fwrd(model, dev, args):
20 | """
21 | From GPTQ repo
22 | """
23 | # assert args.w_groupsize == -1, "Groupsize not supported in RTN!"
24 | layers = model.model.layers
25 | torch.cuda.empty_cache()
26 |
27 | quantizers = {}
28 |
29 | for i in tqdm.tqdm(range(len(layers)), desc="Inserting weight quantizer"):
30 | layer = layers[i].to(dev)
31 |
32 | subset = quant_utils.find_qlayers(
33 | layer, layers=[torch.nn.Linear, QuantizeLinear]
34 | )
35 |
36 | for name in subset:
37 | layer_weight_bits = args.w_bits
38 | if "lm_head" in name:
39 | layer_weight_bits = 16
40 | continue
41 | if args.int8_down_proj and "down_proj" in name:
42 | layer_weight_bits = 8
43 |
44 | quantizer = quant_utils.WeightQuantizer()
45 | quantizer.configure(
46 | layer_weight_bits,
47 | perchannel=True,
48 | sym=not (args.w_asym),
49 | mse=args.w_clip,
50 | weight_groupsize=args.w_groupsize,
51 | )
52 | subset[name].quantizer = quantizer
53 |
54 | quantizers["model.layers.%d.%s" % (i, name)] = quantizer.cpu()
55 | layers[i] = layer.cpu()
56 | torch.cuda.empty_cache()
57 | del layer
58 |
59 | utils.cleanup_memory(verbos=True)
60 | return quantizers
61 |
--------------------------------------------------------------------------------
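`rtn_fwrd` above only attaches configured `WeightQuantizer` objects; the actual rounding happens later inside `QuantizeLinear.forward`. As a rough sketch of what symmetric per-channel round-to-nearest (RTN) weight quantization does (an illustration, not the repo's `WeightQuantizer`):

```python
import torch

def rtn_per_channel(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Symmetric per-output-channel round-to-nearest fake quantization.
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale

w = torch.randn(256, 512)
print((w - rtn_per_channel(w, bits=4)).abs().max())   # worst-case rounding error
```
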
/spinquant/utils/convert_to_executorch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Adapted from https://fburl.com/code/b4jqkgir
16 |
17 | import math
18 |
19 | from typing import Any, Tuple
20 |
21 | import torch
22 | from torch._tensor import Tensor
23 |
24 |
25 | def compute_intermediate_size(n) -> int:
26 | return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
27 |
28 |
29 | def shard_tensor(tensor: Tensor, dim: int, num_shards: int) -> Tuple[Tensor, ...]:
30 | total_size = tensor.shape[dim]
31 | n_size_per_shard = total_size // num_shards
32 |
33 | ret_tensors = torch.split(tensor, n_size_per_shard, dim)
34 | return ret_tensors
35 |
36 |
37 | def write_model_llama(
38 | hf_state_dict,
39 | config,
40 | num_shards: int,
41 | ):
42 | n_layers = config.num_hidden_layers
43 | n_heads = config.num_attention_heads
44 | dim = config.hidden_size
45 | num_key_value_heads = config.num_key_value_heads
46 | assert n_heads % num_shards == 0
47 | assert dim % (n_heads * 2) == 0
48 |
49 | def un_permute(w, is_query=True):
50 | return (
51 | w.view(
52 | n_heads if is_query else num_key_value_heads,
53 | 2,
54 | dim // n_heads // 2,
55 | dim,
56 | )
57 | .transpose(1, 2)
58 | .reshape(-1, dim)
59 | )
60 |
61 | model_shard_dicts = [{} for _ in range(num_shards)]
62 | for layer_i in range(n_layers):
63 | ## store the same in every shard
64 | for shard_i in range(num_shards):
65 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention_norm.weight"] = (
66 | hf_state_dict[f"model.layers.{layer_i}.input_layernorm.weight"].clone()
67 | )
68 |
69 | for shard_i in range(num_shards):
70 | model_shard_dicts[shard_i][f"layers.{layer_i}.ffn_norm.weight"] = (
71 | hf_state_dict[
72 | f"model.layers.{layer_i}.post_attention_layernorm.weight"
73 | ].clone()
74 | )
75 |
76 | ### int weight
77 | self_attn_q_proj_weight = hf_state_dict[
78 | f"model.layers.{layer_i}.self_attn.q_proj.module.int_weight"
79 | ]
80 | self_attn_q_proj_weight = un_permute(self_attn_q_proj_weight)
81 | list_self_attn_q_proj_weight = shard_tensor(
82 | self_attn_q_proj_weight, 0, num_shards
83 | )
84 | for shard_i in range(num_shards):
85 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wq.weight"] = (
86 | list_self_attn_q_proj_weight[shard_i].clone().to(torch.int8)
87 | )
88 |
89 | ###
90 | self_attn_k_proj_weight = hf_state_dict[
91 | f"model.layers.{layer_i}.self_attn.k_proj.module.int_weight"
92 | ]
93 | self_attn_k_proj_weight = un_permute(self_attn_k_proj_weight, is_query=False)
94 | list_self_attn_k_proj_weight = shard_tensor(
95 | self_attn_k_proj_weight, 0, num_shards
96 | )
97 | for shard_i in range(num_shards):
98 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wk.weight"] = (
99 | list_self_attn_k_proj_weight[shard_i].clone().to(torch.int8)
100 | )
101 |
102 | ###
103 | self_attn_v_proj_weight = hf_state_dict[
104 | f"model.layers.{layer_i}.self_attn.v_proj.module.int_weight"
105 | ]
106 | list_self_attn_v_proj_weight = shard_tensor(
107 | self_attn_v_proj_weight, 0, num_shards
108 | )
109 | for shard_i in range(num_shards):
110 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wv.weight"] = (
111 | list_self_attn_v_proj_weight[shard_i].clone().to(torch.int8)
112 | )
113 |
114 | ###
115 | self_attn_o_proj_weight = hf_state_dict[
116 | f"model.layers.{layer_i}.self_attn.o_proj.module.int_weight"
117 | ]
118 | list_self_attn_o_proj_weight = shard_tensor(
119 | self_attn_o_proj_weight, 1, num_shards
120 | )
121 | for shard_i in range(num_shards):
122 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wo.weight"] = (
123 | list_self_attn_o_proj_weight[shard_i].clone().to(torch.int8)
124 | )
125 |
126 | ###
127 | mlp_gate_proj_weight = hf_state_dict[
128 | f"model.layers.{layer_i}.mlp.gate_proj.module.int_weight"
129 | ]
130 | list_mlp_gate_proj_weight = shard_tensor(mlp_gate_proj_weight, 0, num_shards)
131 | for shard_i in range(num_shards):
132 | model_shard_dicts[shard_i][f"layers.{layer_i}.feed_forward.w1.weight"] = (
133 | list_mlp_gate_proj_weight[shard_i].clone().to(torch.int8)
134 | )
135 |
136 | ###
137 | mlp_down_proj_weight = hf_state_dict[
138 | f"model.layers.{layer_i}.mlp.down_proj.module.int_weight"
139 | ]
140 | list_mlp_down_proj_weight = shard_tensor(mlp_down_proj_weight, 1, num_shards)
141 | for shard_i in range(num_shards):
142 | model_shard_dicts[shard_i][f"layers.{layer_i}.feed_forward.w2.weight"] = (
143 | list_mlp_down_proj_weight[shard_i].clone().to(torch.int8)
144 | )
145 |
146 | ###
147 | mlp_up_proj_weight = hf_state_dict[
148 | f"model.layers.{layer_i}.mlp.up_proj.module.int_weight"
149 | ]
150 | list_mlp_up_proj_weight = shard_tensor(mlp_up_proj_weight, 0, num_shards)
151 | for shard_i in range(num_shards):
152 | model_shard_dicts[shard_i][f"layers.{layer_i}.feed_forward.w3.weight"] = (
153 | list_mlp_up_proj_weight[shard_i].clone().to(torch.int8)
154 | )
155 |
156 | ### scale
157 | self_attn_q_proj_weight = hf_state_dict[
158 | f"model.layers.{layer_i}.self_attn.q_proj.module.scale"
159 | ]
160 | self_attn_q_proj_weight = un_permute(self_attn_q_proj_weight)
161 | list_self_attn_q_proj_weight = shard_tensor(
162 | self_attn_q_proj_weight, 0, num_shards
163 | )
164 | for shard_i in range(num_shards):
165 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wq.scale"] = (
166 | list_self_attn_q_proj_weight[shard_i].clone()
167 | )
168 |
169 | ###
170 | self_attn_k_proj_weight = hf_state_dict[
171 | f"model.layers.{layer_i}.self_attn.k_proj.module.scale"
172 | ]
173 | self_attn_k_proj_weight = un_permute(self_attn_k_proj_weight, is_query=False)
174 | list_self_attn_k_proj_weight = shard_tensor(
175 | self_attn_k_proj_weight, 0, num_shards
176 | )
177 | for shard_i in range(num_shards):
178 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wk.scale"] = (
179 | list_self_attn_k_proj_weight[shard_i].clone()
180 | )
181 |
182 | ###
183 | self_attn_v_proj_weight = hf_state_dict[
184 | f"model.layers.{layer_i}.self_attn.v_proj.module.scale"
185 | ]
186 | list_self_attn_v_proj_weight = shard_tensor(
187 | self_attn_v_proj_weight, 0, num_shards
188 | )
189 | for shard_i in range(num_shards):
190 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wv.scale"] = (
191 | list_self_attn_v_proj_weight[shard_i].clone()
192 | )
193 |
194 | ###
195 | self_attn_o_proj_weight = hf_state_dict[
196 | f"model.layers.{layer_i}.self_attn.o_proj.module.scale"
197 | ]
198 | list_self_attn_o_proj_weight = shard_tensor(
199 | self_attn_o_proj_weight, 1, num_shards
200 | )
201 | for shard_i in range(num_shards):
202 | model_shard_dicts[shard_i][f"layers.{layer_i}.attention.wo.scale"] = (
203 | list_self_attn_o_proj_weight[shard_i].clone()
204 | )
205 |
206 | ###
207 | mlp_gate_proj_weight = hf_state_dict[
208 | f"model.layers.{layer_i}.mlp.gate_proj.module.scale"
209 | ]
210 | list_mlp_gate_proj_weight = shard_tensor(mlp_gate_proj_weight, 0, num_shards)
211 | for shard_i in range(num_shards):
212 | model_shard_dicts[shard_i][f"layers.{layer_i}.feed_forward.w1.scale"] = (
213 | list_mlp_gate_proj_weight[shard_i].clone()
214 | )
215 |
216 | ###
217 | mlp_down_proj_weight = hf_state_dict[
218 | f"model.layers.{layer_i}.mlp.down_proj.module.scale"
219 | ]
220 | list_mlp_down_proj_weight = shard_tensor(mlp_down_proj_weight, 1, num_shards)
221 | for shard_i in range(num_shards):
222 | model_shard_dicts[shard_i][f"layers.{layer_i}.feed_forward.w2.scale"] = (
223 | list_mlp_down_proj_weight[shard_i].clone()
224 | )
225 |
226 | ###
227 | mlp_up_proj_weight = hf_state_dict[
228 | f"model.layers.{layer_i}.mlp.up_proj.module.scale"
229 | ]
230 | list_mlp_up_proj_weight = shard_tensor(mlp_up_proj_weight, 0, num_shards)
231 | for shard_i in range(num_shards):
232 | model_shard_dicts[shard_i][f"layers.{layer_i}.feed_forward.w3.scale"] = (
233 | list_mlp_up_proj_weight[shard_i].clone()
234 | )
235 |
236 | ##
237 | for shard_i in range(num_shards):
238 | model_shard_dicts[shard_i]["norm.weight"] = hf_state_dict[
239 | "model.norm.weight"
240 | ].clone()
241 |
242 | list_embed_tokens_weight = shard_tensor(
243 | hf_state_dict["model.embed_tokens.int_weight"], 1, num_shards
244 | )
245 | list_lm_head_weight = shard_tensor(
246 | hf_state_dict["lm_head.module.int_weight"], 0, num_shards
247 | )
248 | for shard_i in range(num_shards):
249 | model_shard_dicts[shard_i]["tok_embeddings.weight"] = (
250 | list_embed_tokens_weight[shard_i].clone().to(torch.int8)
251 | )
252 | model_shard_dicts[shard_i]["output.weight"] = (
253 | list_lm_head_weight[shard_i].clone().to(torch.int8)
254 | )
255 |
256 | list_embed_tokens_weight = shard_tensor(
257 | hf_state_dict["model.embed_tokens.scale"], 1, num_shards
258 | )
259 | list_lm_head_weight = shard_tensor(
260 | hf_state_dict["lm_head.module.scale"], 0, num_shards
261 | )
262 | for shard_i in range(num_shards):
263 | model_shard_dicts[shard_i]["tok_embeddings.scale"] = (
264 | list_embed_tokens_weight[shard_i].clone().to(torch.float)
265 | )
266 | model_shard_dicts[shard_i]["output.scale"] = (
267 | list_lm_head_weight[shard_i].clone().to(torch.float)
268 | )
269 |
270 | return model_shard_dicts
271 |
272 |
273 | def sanitize_checkpoint_from_spinquant(
274 | checkpoint: Any,
275 | group_size: int,
276 | ):
277 | """
278 | Sanitize the SpinQuant checkpoint.
279 | - Renames 'scale' to 'scales'
280 | - Groups scales
281 | - Removes 'o_weight'
282 | - Converts all tensors to contiguous format
283 | """
284 | keys_to_rename = []
285 | keys_to_remove = []
286 | for k, _ in checkpoint.items():
287 | if k.endswith(".scale"):
288 | new_key = k + "s"
289 | keys_to_rename.append((k, new_key))
290 |
291 | for old_key, new_key in keys_to_rename:
292 | old_val = checkpoint.pop(old_key)
293 | checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
294 | for k in keys_to_remove:
295 | checkpoint.pop(k)
296 | for k, v in checkpoint.items():
297 | checkpoint[k] = v.contiguous()
298 | return checkpoint
299 |
--------------------------------------------------------------------------------
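`write_model_llama` splits every projection evenly across shards with `shard_tensor` (along dim 0 for most projections, along dim 1 for `o_proj` and `down_proj`). A quick illustration of its behavior, assuming the `spinquant` root is on `PYTHONPATH`:

```python
import torch
from utils.convert_to_executorch import shard_tensor

w = torch.arange(24).reshape(6, 4)
shards = shard_tensor(w, 0, 2)            # two shards of shape (3, 4), split along dim 0
print([tuple(s.shape) for s in shards])   # [(3, 4), (3, 4)]
```
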
/spinquant/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import random
12 | from typing import Any, Dict
13 |
14 | import datasets
15 | import torch
16 | import transformers
17 |
18 |
19 | def get_wikitext2(nsamples=128, seed=0, seqlen=2048, model="", tokenizer=None, eval_mode=False):
20 | if tokenizer is None:
21 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
22 |
23 | if eval_mode:
24 | testdata = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")[
25 | "test"
26 | ]
27 | testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
28 | return testenc
29 | else:
30 | traindata = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")[
31 | "train"
32 | ]
33 | trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
34 | random.seed(seed)
35 | trainloader = []
36 | for _ in range(nsamples):
37 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
38 | j = i + seqlen
39 | inp = trainenc.input_ids[:, i:j]
40 | tar = inp.clone()
41 | tar[:, :-1] = -100
42 | trainloader.append((inp, tar))
43 | return trainloader
44 |
45 |
46 | class CustomJsonDataset(torch.utils.data.IterableDataset):
47 | def __init__(self, dataset, tokenizer, block_size: int = 1024) -> None:
48 | raw_data = dataset
49 | self.tokenizer = tokenizer
50 | self.block_size = block_size
51 | tokenized_datasets = []
52 | for d in raw_data:
53 | tokenized_datasets.append(self.tokenize_function(d))
54 |
55 | grouped_dataset = self.group_texts(tokenized_datasets)
56 | self.input_ids = grouped_dataset["input_ids"]
57 | self.labels = grouped_dataset["labels"]
58 | self.data = [
59 | dict(input_ids=self.input_ids[i], labels=self.labels[i])
60 | for i in range(len(self.input_ids))
61 | ]
62 |
63 | def __len__(self) -> int:
64 | return len(self.data)
65 |
66 | def __getitem__(self, i) -> Dict[str, Any]:
67 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
68 |
69 | def __iter__(self):
70 | return iter(self.data)
71 |
72 | def tokenize_function(self, examples):
73 | return self.tokenizer(examples["text"])
74 |
75 | def group_texts(self, examples):
76 | # Concatenate all texts.
77 | # Initialize an empty dictionary
78 | concatenated_examples = {}
79 |
80 | # Loop through the list of dictionaries
81 | for d in examples:
82 | # Loop through the keys in each dictionary
83 | for key in d.keys():
84 | # If the key is not already a key in the dict_of_lists, create a new list
85 | if key not in concatenated_examples:
86 | concatenated_examples[key] = []
87 | # Append the value to the list associated with the key in dict_of_lists
88 | concatenated_examples[key].extend(d[key])
89 | total_length = len(concatenated_examples["input_ids"])
90 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
91 | # customize this part to your needs.
92 | if total_length >= self.block_size:
93 | total_length = (total_length // self.block_size) * self.block_size
94 | # Split by chunks of max_len.
95 | result = {
96 | k: [
97 | t[i : i + self.block_size]
98 | for i in range(0, total_length, self.block_size)
99 | ]
100 | for k, t in concatenated_examples.items()
101 | }
102 | result["labels"] = result["input_ids"].copy()
103 | return result
104 |
--------------------------------------------------------------------------------
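In eval mode, `get_wikitext2` returns the tokenized WikiText-2 test split; otherwise it returns `nsamples` random `(input, target)` windows whose labels are masked to -100 everywhere except the final position. A hedged usage sketch (assumes the `spinquant` root on `PYTHONPATH`, network access to the Hugging Face hub, and an arbitrary slow tokenizer; `gpt2` below is only an example):

```python
import transformers
from utils.data_utils import get_wikitext2

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False)  # example tokenizer only
calib = get_wikitext2(nsamples=4, seqlen=512, tokenizer=tokenizer, eval_mode=False)
inp, tar = calib[0]
print(inp.shape)                                                   # torch.Size([1, 512])
print((tar[:, :-1] == -100).all().item(), (tar[:, -1] != -100).item())  # True True: only the last label is kept
```
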
/spinquant/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import logging
12 | import os
13 |
14 | import torch
15 | from tqdm import tqdm
16 |
17 | from utils import model_utils
18 |
19 |
20 | @torch.no_grad()
21 | def evaluator(model, testenc, dev, args):
22 | model.eval()
23 |
24 | use_cache = model.config.use_cache
25 | model.config.use_cache = False
26 |
27 | layers = model.model.layers
28 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
29 | model.model.rotary_emb = model.model.rotary_emb.to(dev)
30 |
31 | layers[0] = layers[0].to(dev)
32 |
33 | # Convert the whole text of evaluation dataset into batches of sequences.
34 | input_ids = testenc.input_ids # (1, text_len)
35 | nsamples = input_ids.numel() // model.seqlen # The tail is truncated.
36 | input_ids = (
37 | input_ids[:, : nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev)
38 | ) # (nsamples, seqlen)
39 |
40 | batch_size = args.bsz
41 | input_ids = [input_ids[i : i + batch_size] for i in range(0, nsamples, batch_size)]
42 | nbatches = len(input_ids)
43 |
44 | dtype = next(iter(model.parameters())).dtype
45 | # The input of the first decoder layer.
46 | inps = torch.zeros(
47 | (nbatches, batch_size, model.seqlen, model.config.hidden_size),
48 | dtype=dtype,
49 | device=dev,
50 | )
51 | inps = [0] * nbatches
52 | cache = {"i": 0, "attention_mask": None}
53 |
54 | class Catcher(torch.nn.Module):
55 | def __init__(self, module):
56 | super().__init__()
57 | self.module = module
58 |
59 | def forward(self, inp, **kwargs):
60 | inps[cache["i"]] = inp
61 | cache["i"] += 1
62 | cache["attention_mask"] = kwargs["attention_mask"]
63 | cache["position_embeddings"] = kwargs["position_embeddings"]
64 | raise ValueError
65 |
66 | layers[0] = Catcher(layers[0])
67 |
68 | for i in range(nbatches):
69 | batch = input_ids[i]
70 | try:
71 | model(batch)
72 | except ValueError:
73 | pass
74 | layers[0] = layers[0].module
75 | layers[0] = layers[0].cpu()
76 |
77 | model.model.embed_tokens = model.model.embed_tokens.cpu()
78 | model.model.rotary_emb = model.model.rotary_emb.cpu()
79 | position_embeddings = cache["position_embeddings"]
80 |
81 | torch.cuda.empty_cache()
82 | outs = [0] * nbatches
83 | attention_mask = cache["attention_mask"]
84 |
85 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"):
86 | layer = layers[i].to(dev)
87 |
88 | # Dump the layer input and output
89 | if args.capture_layer_io and args.layer_idx == i:
90 | captured_io = model_utils.capture_layer_io(layer, inps)
91 | save_path = model_utils.get_layer_io_save_path(args)
92 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
93 | torch.save(captured_io, save_path)
94 | logging.info(f"Dumped layer input and output to: {save_path}")
95 |
96 | for j in range(nbatches):
97 | outs[j] = layer(
98 | inps[j],
99 | attention_mask=attention_mask,
100 | position_embeddings=position_embeddings,
101 | )[0]
102 | layers[i] = layer.cpu()
103 | del layer
104 | torch.cuda.empty_cache()
105 | inps, outs = outs, inps
106 |
107 | if model.model.norm is not None:
108 | model.model.norm = model.model.norm.to(dev)
109 |
110 | model.lm_head = model.lm_head.to(dev)
111 | nlls = []
112 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
113 | for i in range(nbatches):
114 | hidden_states = inps[i]
115 | if model.model.norm is not None:
116 | hidden_states = model.model.norm(hidden_states)
117 | lm_logits = model.lm_head(hidden_states)
118 | shift_logits = lm_logits[:, :-1, :]
119 | shift_labels = input_ids[i][:, 1:]
120 | loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels)
121 | neg_log_likelihood = loss.float().mean(dim=1)
122 | nlls.append(neg_log_likelihood)
123 | nlls_tensor = torch.cat(nlls)
124 | ppl = torch.exp(nlls_tensor.mean())
125 | model.config.use_cache = use_cache
126 | logging.info(f"\n WikiText2 PPL: {ppl.item():.3f}")
127 | return ppl.item()
128 |
--------------------------------------------------------------------------------
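The perplexity reported by `evaluator` is the standard shifted-label cross-entropy: logits at position t are scored against the token at t+1, and PPL is the exponential of the mean NLL. A toy illustration of just that final step:

```python
import torch

logits = torch.randn(2, 6, 100)                 # (batch, seqlen, vocab)
labels = torch.randint(0, 100, (2, 6))          # token ids
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(logits[:, :-1, :].permute(0, 2, 1), labels[:, 1:])   # (batch, seqlen - 1)
ppl = torch.exp(loss.float().mean())
print(ppl)                                      # large here, since the logits are random
```
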
/spinquant/utils/fuse_norm_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import typing
12 | import torch
13 |
14 |
15 | def fuse_ln_linear(
16 | layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
17 | ) -> None:
18 | """
19 | fuse the linear operations in Layernorm into the adjacent linear blocks.
20 | """
21 | for linear in linear_layers:
22 | linear_dtype = linear.weight.dtype
23 |
24 | # Calculating new weight and bias
25 | W_ = linear.weight.data.double()
26 | linear.weight.data = (W_ * layernorm.weight.double()).to(linear_dtype)
27 |
28 | if hasattr(layernorm, "bias"):
29 | if linear.bias is None:
30 | linear.bias = torch.nn.Parameter(
31 | torch.zeros(linear.out_features, dtype=torch.float64)
32 | )
33 | linear.bias.data = linear.bias.data.double() + torch.matmul(
34 | W_, layernorm.bias.double()
35 | )
36 | linear.bias.data = linear.bias.data.to(linear_dtype)
37 |
38 |
39 | def fuse_layer_norms(model):
40 | kwargs = {"model": model}
41 |
42 | # Embedding fusion
43 | for W in [model.model.embed_tokens]:
44 | W_ = W.weight.data.double()
45 | W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
46 |
47 | layers = [layer for layer in model.model.layers]
48 |
49 | # Fuse the linear operations in Layernorm into the adjacent linear blocks.
50 | for layer in layers:
51 | # fuse the input layernorms into the linear layers
52 | fuse_ln_linear(
53 | layer.post_attention_layernorm, [layer.mlp.up_proj, layer.mlp.gate_proj]
54 | )
55 | fuse_ln_linear(
56 | layer.input_layernorm,
57 | [
58 | layer.self_attn.q_proj,
59 | layer.self_attn.k_proj,
60 | layer.self_attn.v_proj,
61 | ],
62 | )
63 |
64 | W_norm = layer.post_attention_layernorm.weight.data
65 | layer.post_attention_layernorm.weight.data = torch.ones_like(W_norm)
66 | W_norm = layer.input_layernorm.weight.data
67 | layer.input_layernorm.weight.data = torch.ones_like(W_norm)
68 |
69 | fuse_ln_linear(
70 | model.model.norm,
71 | [model.lm_head],
72 | )
73 | W_norm = model.model.norm.weight.data
74 | model.model.norm.weight.data = torch.ones_like(W_norm)
75 |
--------------------------------------------------------------------------------
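`fuse_ln_linear` folds the norm's scale into the following linear layer, which is why the norm weights can then be reset to all ones. A small numeric check of that identity (a sketch; `x_normed` stands in for the already-normalized activation):

```python
import torch

d, out = 16, 8
g = torch.randn(d)                 # norm scale (layernorm.weight)
W = torch.randn(out, d)            # following linear weight
x_normed = torch.randn(4, d)       # normalized activation, before scaling by g

ref = (x_normed * g) @ W.T         # original: scale by g, then apply the linear layer
fused = x_normed @ (W * g).T       # fused: the scale is folded into the weight columns
print(torch.allclose(ref, fused, atol=1e-6))   # True
```
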
/spinquant/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import os
12 | import torch
13 |
14 |
15 | def skip(*args, **kwargs):
16 | # This is a helper function to save time during the initialization!
17 | pass
18 |
19 |
20 | def get_layer_io_save_path(args):
21 | return os.path.join(args.save_path, "layer_io", f"{args.layer_idx:03d}.pt")
22 |
23 |
24 | def capture_layer_io(layer, layer_input):
25 | def hook_factory(module_name, captured_vals, is_input):
26 | def hook(module, input, output):
27 | if is_input:
28 | captured_vals[module_name].append(input[0].detach().cpu())
29 | else:
30 | captured_vals[module_name].append(output.detach().cpu())
31 |
32 | return hook
33 |
34 | handles = []
35 |
36 | captured_inputs = {
37 | "k_proj": [], # q_proj, v_proj has the same input as k_proj
38 | "o_proj": [],
39 | "gate_proj": [], # up_proj has the same input as gate_proj
40 | "down_proj": [],
41 | }
42 |
43 | captured_outputs = {
44 | "v_proj": [],
45 | }
46 |
47 | for name in captured_inputs.keys():
48 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None)
49 | handles.append(
50 | module.register_forward_hook(hook_factory(name, captured_inputs, True))
51 | )
52 |
53 | for name in captured_outputs.keys():
54 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None)
55 | handles.append(
56 | module.register_forward_hook(hook_factory(name, captured_outputs, False))
57 | )
58 |
59 | # Process each sequence in the batch one by one to avoid OOM.
60 | for seq_idx in range(layer_input.shape[0]):
61 | # Extract the current sequence across all dimensions.
62 | seq = layer_input[seq_idx : seq_idx + 1].to("cuda")
63 | # Perform a forward pass for the current sequence.
64 | layer(seq)
65 |
66 | # After processing all sequences, concatenate the accumulated inputs for each sub-layer across the batch.
67 | for module_name in captured_inputs:
68 | captured_inputs[module_name] = torch.cat(captured_inputs[module_name], dim=0)
69 | for module_name in captured_outputs:
70 | captured_outputs[module_name] = torch.cat(captured_outputs[module_name], dim=0)
71 |
72 | # Cleanup.
73 | for h in handles:
74 | h.remove()
75 |
76 | return {"input": captured_inputs, "output": captured_outputs}
77 |
--------------------------------------------------------------------------------
/spinquant/utils/monkeypatch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import copy
12 | import functools
13 | import types
14 |
15 |
16 | def copy_func_with_new_globals(f, globals=None):
17 | """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
18 | if globals is None:
19 | globals = f.__globals__
20 | g = types.FunctionType(
21 | f.__code__,
22 | globals,
23 | name=f.__name__,
24 | argdefs=f.__defaults__,
25 | closure=f.__closure__,
26 | )
27 | g = functools.update_wrapper(g, f)
28 | g.__module__ = f.__module__
29 | g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
30 | return g
31 |
32 |
33 | def add_wrapper_after_function_call_in_method(
34 | module,
35 | method_name,
36 | function_name,
37 | wrapper_fn,
38 | ):
39 | """
40 | This function adds a wrapper after the output of a function call in the method named `method_name`.
41 | Only calls directly in the method are affected. Calls by other functions called in the method are not affected.
42 | """
43 |
44 | original_method = getattr(module, method_name).__func__
45 | method_globals = dict(original_method.__globals__)
46 | wrapper = wrapper_fn(method_globals[function_name])
47 | method_globals[function_name] = wrapper
48 | new_method = copy_func_with_new_globals(original_method, globals=method_globals)
49 | setattr(module, method_name, new_method.__get__(module))
50 | return wrapper
51 |
--------------------------------------------------------------------------------
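A hedged usage sketch of `add_wrapper_after_function_call_in_method`: it wraps a function that `forward` calls directly, for one module instance only, without editing the class. Everything below except the import is a hypothetical toy example (the import assumes the `spinquant` root on `PYTHONPATH`):

```python
import torch
from utils.monkeypatch import add_wrapper_after_function_call_in_method

def scale(x):                           # module-level function called directly inside forward
    return 2 * x

class Toy(torch.nn.Module):
    def forward(self, x):
        return scale(x) + 1

class PrintWrapper(torch.nn.Module):    # the wrapper receives the original function, like QKRotationWrapper
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        out = self.func(*args, **kwargs)
        print("intercepted scale(), got", out)
        return out

m = Toy()
add_wrapper_after_function_call_in_method(m, "forward", "scale", PrintWrapper)
print(m(torch.ones(2)))                 # prints the intercepted value, then tensor([3., 3.])
```
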
/spinquant/utils/process_args.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | from dataclasses import dataclass, field
12 | from typing import Optional, Tuple
13 |
14 | import argparse
15 | import transformers
16 |
17 |
18 | @dataclass
19 | class ModelArguments:
20 | input_model: Optional[str] = field(
21 | default="test-input", metadata={"help": "Input model"}
22 | )
23 | output_rotation_path: Optional[str] = field(
24 | default="test-output", metadata={"help": "Output rotation checkpoint path"}
25 | )
26 | optimized_rotation_path: Optional[str] = field(
27 | default=None, metadata={"help": "Optimized rotation checkpoint path"}
28 | )
29 | access_token: Optional[str] = field(
30 | default=None,
31 | metadata={"help": "Huggingface access token to access gated repo like Llama"},
32 | )
33 |
34 |
35 | @dataclass
36 | class TrainingArguments(transformers.TrainingArguments):
37 | cache_dir: Optional[str] = field(default=None)
38 | output_dir: Optional[str] = field(default="/tmp/output/")
39 | model_max_length: Optional[int] = field(
40 | default=2048,
41 | metadata={
42 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)"
43 | },
44 | )
45 |
46 |
47 | def parser_gen():
48 | parser = argparse.ArgumentParser()
49 |
50 | parser.add_argument(
51 | "--seed", type=int, default=0, help="Random Seed for HuggingFace and PyTorch"
52 | )
53 |
54 | # Rotation Arguments
55 | parser.add_argument(
56 | "--rotate",
57 | action=argparse.BooleanOptionalAction,
58 | default=False,
59 | help="""Rotate the moodel. This will include online rotation for down-projection and
60 | out-projection. Note that this does not apply rotation to the K/Q and they will be rotated
61 | if we want to quantize the Keys""",
62 | )
63 | parser.add_argument(
64 | "--rotate_mode", type=str, default="hadamard", choices=["hadamard", "random"]
65 | )
66 | parser.add_argument(
67 | "--rotation_seed",
68 | type=int,
69 | default=-1,
70 | help="Random Seed for generating random matrix!!",
71 | )
72 | parser.add_argument(
73 | "--fp32_had",
74 | action=argparse.BooleanOptionalAction,
75 | default=False,
76 | help="Apply Hadamard rotation in FP32 (default: False)",
77 | )
78 |
79 | # Activation Quantization Arguments
80 | parser.add_argument(
81 | "--a_bits",
82 | type=int,
83 | default=16,
84 | help="""Number of bits for inputs of the Linear layers. This will be
85 | for all the linear layers in the model (including down-projection and out-projection)""",
86 | )
87 | parser.add_argument(
88 | "--a_groupsize",
89 | type=int,
90 | default=-1,
91 | help="Groupsize for activation quantization. Note that this should be the same as w_groupsize",
92 | )
93 | parser.add_argument(
94 | "--a_asym",
95 | action=argparse.BooleanOptionalAction,
96 | default=False,
97 | help="ASymmetric Activation quantization (default: False)",
98 | )
99 | parser.add_argument(
100 | "--a_clip_ratio",
101 | type=float,
102 | default=1.0,
103 | help="Clip ratio for activation quantization. new_max = max * clip_ratio",
104 | )
105 |
106 | # Weight Quantization Arguments
107 | parser.add_argument(
108 | "--w_bits",
109 | type=int,
110 | default=16,
111 | help="Number of bits for weights of the Linear layers",
112 | )
113 | parser.add_argument(
114 | "--w_groupsize",
115 | type=int,
116 | default=-1,
117 | help="Groupsize for weight quantization. Note that this should be the same as a_groupsize",
118 | )
119 | parser.add_argument(
120 | "--w_asym",
121 | action=argparse.BooleanOptionalAction,
122 | default=False,
123 | help="ASymmetric weight quantization (default: False)",
124 | )
125 | parser.add_argument(
126 | "--w_rtn",
127 | action=argparse.BooleanOptionalAction,
128 | default=False,
129 | help="Quantize the weights using RtN. If the w_bits < 16 and this flag is not set, we use GPTQ",
130 | )
131 | parser.add_argument(
132 | "--w_clip",
133 | action=argparse.BooleanOptionalAction,
134 | default=False,
135 | help="""Clipping the weight quantization!
136 | We do not support arguments for clipping and we find the best clip ratio during the weight quantization""",
137 | )
138 | parser.add_argument(
139 | "--nsamples",
140 | type=int,
141 | default=128,
142 | help="Number of calibration data samples for GPTQ.",
143 | )
144 | parser.add_argument(
145 | "--percdamp",
146 | type=float,
147 | default=0.01,
148 | help="Percent of the average Hessian diagonal to use for dampening.",
149 | )
150 | parser.add_argument(
151 | "--act_order",
152 | action=argparse.BooleanOptionalAction,
153 | default=False,
154 | help="act-order in GPTQ",
155 | )
156 | parser.add_argument(
157 | "--asym_calibrate",
158 | action=argparse.BooleanOptionalAction,
159 | default=False,
160 | help="Whether to use GPTAQ",
161 | )
162 |
163 |
164 | # General Quantization Arguments
165 | parser.add_argument(
166 | "--int8_down_proj",
167 | action=argparse.BooleanOptionalAction,
168 | default=False,
169 | help="Use INT8 for Down Projection! If this set, both weights and activations of this layer will be in INT8",
170 | )
171 |
172 | # KV-Cache Quantization Arguments
173 | parser.add_argument(
174 | "--v_bits",
175 | type=int,
176 | default=16,
177 | help="""Number of bits for V-cache quantization.
178 | Note that quantizing the V-cache does not need any other rotation""",
179 | )
180 | parser.add_argument("--v_groupsize", type=int, default=-1)
181 | parser.add_argument(
182 | "--v_asym",
183 | action=argparse.BooleanOptionalAction,
184 | default=False,
185 | help="ASymmetric V-cache quantization",
186 | )
187 | parser.add_argument(
188 | "--v_clip_ratio",
189 | type=float,
190 | default=1.0,
191 | help="Clip ratio for v-cache quantization. new_max = max * clip_ratio",
192 | )
193 |
194 | parser.add_argument(
195 | "--k_bits",
196 | type=int,
197 | default=16,
198 | help="""Number of bits for K-cache quantization.
199 | Note that quantizing the K-cache needs another rotation for the keys/queries""",
200 | )
201 | parser.add_argument("--k_groupsize", type=int, default=-1)
202 | parser.add_argument(
203 | "--k_asym",
204 | action=argparse.BooleanOptionalAction,
205 | default=False,
206 | help="ASymmetric K-cache quantization",
207 | )
208 | parser.add_argument(
209 | "--k_pre_rope",
210 | action=argparse.BooleanOptionalAction,
211 | default=False,
212 | help="Pre-RoPE quantization for K-cache (not Supported yet!)",
213 | )
214 | parser.add_argument(
215 | "--k_clip_ratio",
216 | type=float,
217 | default=1.0,
218 | help="Clip ratio for k-cache quantization. new_max = max * clip_ratio",
219 | )
220 |
221 | # Save/Load Quantized Model Arguments
222 | parser.add_argument(
223 | "--load_qmodel_path",
224 | type=str,
225 | default=None,
226 | help="Load the quantized model from the specified path!",
227 | )
228 | parser.add_argument(
229 | "--save_qmodel_path",
230 | type=str,
231 | default=None,
232 | help="Save the quantized model to the specified path!",
233 | )
234 | parser.add_argument(
235 | "--export_to_et",
236 | action=argparse.BooleanOptionalAction,
237 | default=False,
238 | help="Export the quantized model to executorch and save in save_qmodel_path",
239 | )
240 | parser.add_argument(
241 | "--distribute",
242 | action=argparse.BooleanOptionalAction,
243 | default=False,
244 | help="Distribute the model on multiple GPUs for evaluatione",
245 | )
246 |
247 | # Experiments Arguments
248 | parser.add_argument(
249 | "--capture_layer_io",
250 | action=argparse.BooleanOptionalAction,
251 | default=False,
252 | help="Capture the input and output of the specified decoder layer and dump into a file",
253 | )
254 | parser.add_argument(
255 | "--layer_idx", type=int, default=10, help="Which decoder layer to capture"
256 | )
257 | parser.add_argument(
258 | "--enable_aq_calibration",
259 | action=argparse.BooleanOptionalAction,
260 | default=False,
261 | help="Whether to enable activation quantization during GPTQ (default: False)",
262 | )
263 | # LM Eval Arguments
264 | parser.add_argument("--lm_eval", action="store_true", help="Evaluate the model on LM Eval tasks.")
265 | parser.add_argument(
266 | '--tasks',
267 | nargs='+',
268 | default=["piqa", "hellaswag", "arc_easy", "arc_challenge", "winogrande", 'boolq'],
269 | )
270 | parser.add_argument('--lm_eval_batch_size', type=int, default=128, help='Batch size for evaluating with lm eval harness.')
271 |
272 | args, unknown = parser.parse_known_args()
273 |
274 | # assert (
275 | # args.a_groupsize == args.w_groupsize
276 | # ), "a_groupsize should be the same as w_groupsize!"
277 | assert args.k_pre_rope is False, "Pre-RoPE quantization is not supported yet!"
278 |
279 | return args, unknown
280 |
281 |
282 | def process_args_ptq():
283 | ptq_args = None
284 |
285 | ptq_args, unknown_args = parser_gen()
286 |
287 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
288 | model_args, training_args = parser.parse_args_into_dataclasses(args=unknown_args)
289 | if model_args.optimized_rotation_path is not None:
290 | ptq_args.optimized_rotation_path = model_args.optimized_rotation_path
291 | else:
292 | ptq_args.optimized_rotation_path = None
293 | ptq_args.bsz = training_args.per_device_eval_batch_size
294 |
295 | return model_args, training_args, ptq_args
296 |
--------------------------------------------------------------------------------
/spinquant/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9 | # Licensed under Apache License 2.0.
10 |
11 | import logging
12 | import os
13 | import random
14 | from typing import Optional
15 |
16 | import numpy as np
17 | import torch
18 | from fast_hadamard_transform import hadamard_transform
19 | from torch.distributed.fsdp import (
20 | FullStateDictConfig,
21 | )
22 | from torch.distributed.fsdp import (
23 | FullyShardedDataParallel as PT_FSDP,
24 | )
25 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
26 |
27 |
28 | from accelerate import dispatch_model, infer_auto_device_map
29 | from accelerate.utils import get_balanced_memory
30 |
31 | # These flags disable using TensorFloat-32 tensor cores (to avoid numerical issues)
32 | # torch.backends.cuda.matmul.allow_tf32 = False
33 | # torch.backends.cudnn.allow_tf32 = False
34 | DEV = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
35 |
36 |
37 | def pt_fsdp_state_dict(model: torch.nn.Module):
38 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
39 | with PT_FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
40 | return model.state_dict()
41 |
42 |
43 | class HadamardTransform(torch.autograd.Function):
44 | """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))"""
45 |
46 | @staticmethod
47 | def forward(ctx, u):
48 | return hadamard_transform(u)
49 |
50 | @staticmethod
51 | def backward(ctx, grad):
52 | return hadamard_transform(grad)
53 |
54 |
55 | def llama_down_proj_groupsize(model, groupsize):
56 | assert groupsize > 1, "groupsize should be greater than 1!"
57 |
58 | if model.config.intermediate_size % groupsize == 0:
59 | logging.info(f"(Act.) Groupsiz = Down_proj Groupsize: {groupsize}")
60 | return groupsize
61 |
62 | group_num = int(model.config.hidden_size / groupsize)
63 | assert (
64 | groupsize * group_num == model.config.hidden_size
65 | ), "Invalid groupsize for llama!"
66 |
67 | down_proj_groupsize = model.config.intermediate_size // group_num
68 | assert (
69 | down_proj_groupsize * group_num == model.config.intermediate_size
70 | ), "Invalid groupsize for down_proj!"
71 | logging.info(
72 | f"(Act.) Groupsize: {groupsize}, Down_proj Groupsize: {down_proj_groupsize}"
73 | )
74 | return down_proj_groupsize
75 |
76 |
77 | def set_seed(seed):
78 | np.random.seed(seed)
79 | torch.random.manual_seed(seed)
80 | random.seed(seed)
81 |
82 |
83 | # Dump the log both to console and a log file.
84 | def config_logging(log_file, level=logging.INFO):
85 | class LogFormatter(logging.Formatter):
86 | def format(self, record):
87 | if record.levelno == logging.INFO:
88 | self._style._fmt = "%(message)s"
89 | else:
90 | self._style._fmt = "%(levelname)s: %(message)s"
91 | return super().format(record)
92 |
93 | console_handler = logging.StreamHandler()
94 | console_handler.setFormatter(LogFormatter())
95 |
96 | file_handler = logging.FileHandler(log_file)
97 | file_handler.setFormatter(LogFormatter())
98 |
99 | logging.basicConfig(level=level, handlers=[console_handler, file_handler])
100 |
101 |
102 | def cleanup_memory(verbos=True) -> None:
103 | """Run GC and clear GPU memory."""
104 | import gc
105 | import inspect
106 |
107 | caller_name = ""
108 | try:
109 | caller_name = f" (from {inspect.stack()[1].function})"
110 | except (ValueError, KeyError):
111 | pass
112 |
113 | def total_reserved_mem() -> int:
114 | return sum(
115 | torch.cuda.memory_reserved(device=i)
116 | for i in range(torch.cuda.device_count())
117 | )
118 |
119 | memory_before = total_reserved_mem()
120 |
121 | # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed
122 | gc.collect()
123 |
124 | if torch.cuda.is_available():
125 | torch.cuda.empty_cache()
126 | memory_after = total_reserved_mem()
127 | if verbos:
128 | logging.info(
129 | f"GPU memory{caller_name}: {memory_before / (1024 ** 3):.2f} -> {memory_after / (1024 ** 3):.2f} GB"
130 | f" ({(memory_after - memory_before) / (1024 ** 3):.2f} GB)"
131 | )
132 |
133 |
134 | # Define a utility method for setting the logging parameters of a logger
135 | def get_logger(logger_name: Optional[str]) -> logging.Logger:
136 | # Get the logger with the specified name
137 | logger = logging.getLogger(logger_name)
138 |
139 | # Set the logging level of the logger to INFO
140 | logger.setLevel(logging.INFO)
141 |
142 | # Define a formatter for the log messages
143 | formatter = logging.Formatter(
144 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
145 | )
146 |
147 | # Create a console handler for outputting log messages to the console
148 | console_handler = logging.StreamHandler()
149 | console_handler.setFormatter(formatter)
150 |
151 | # Add the console handler to the logger
152 | logger.addHandler(console_handler)
153 |
154 | return logger
155 |
156 |
157 | def get_local_rank() -> int:
158 | if os.environ.get("LOCAL_RANK"):
159 | return int(os.environ["LOCAL_RANK"])
160 | else:
161 | logging.warning(
162 | "LOCAL_RANK from os.environ is None, fall back to get rank from torch distributed"
163 | )
164 | return torch.distributed.get_rank()
165 |
166 |
167 | def get_global_rank() -> int:
168 | """
169 | Get the rank from torch.distributed if it is available and initialized; otherwise use the RANK env var.
170 | Returns 0 if neither condition is met.
171 | """
172 | if torch.distributed.is_available() and torch.distributed.is_initialized():
173 | return torch.distributed.get_rank()
174 |
175 | environ_rank = os.environ.get("RANK", "")
176 | if environ_rank.isdecimal():
177 | return int(os.environ["RANK"])
178 |
179 | return 0
180 |
181 |
182 | def distribute_model(model) -> None:
183 | """Distribute the model across available GPUs. NB: only implemented for Llama-2."""
184 | no_split_module_classes = ['LlamaDecoderLayer']
185 | max_memory = get_balanced_memory(
186 | model,
187 | no_split_module_classes=no_split_module_classes,
188 | )
189 |
190 | device_map = infer_auto_device_map(
191 | model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
192 | )
193 |
194 | print(device_map)
195 | dispatch_model(
196 | model,
197 | device_map=device_map,
198 | offload_buffers=True,
199 | offload_dir="offload",
200 | state_dict=model.state_dict(),
201 | )
202 |
203 | cleanup_memory()
--------------------------------------------------------------------------------
/vit_quant/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Fake Quantization for Vision Transformers
4 |
5 | This code is developed based on our [`fake_quant`](../fake_quant) implementation.
6 |
7 | ## Installation
8 |
9 | We can use the same environment as QuaRot.
10 | 
11 | Additionally, install the `timm` package (e.g., `pip install timm`).
12 |
13 | ## ImageNet Evaluations
14 |
15 | Currently, this code supports **EVA-02 and DeiT** transformers. The arguments are the same as in the `fake_quant` experiments.
16 |
17 | **Before you start, modify the ImageNet data path in [data_utils.py](./data_utils.py).**
18 |
19 |
20 | We provide a script `run.sh` to reproduce the results; an example command is shown below.
21 |
22 |
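23 | For example, a 4-bit weight and activation run on DeiT-Base with GPTAQ could look like the command below (a sketch based on the arguments in `utils.py`, not a tuned recipe):
24 | 
25 | ```bash
26 | python main.py --model deit_base_patch16_224 \
27 |     --w_bits 4 --w_clip --w_asym \
28 |     --a_bits 4 --a_asym \
29 |     --nsamples 128 --percdamp 0.1 --act_order \
30 |     --bsz 256 --asym_calibrate --enable_aq_calibration
31 | ```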
--------------------------------------------------------------------------------
/vit_quant/data_utils.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import torch
3 | from timm.data import resolve_data_config
4 | from timm.data.transforms_factory import create_transform
5 | from torchvision.datasets import ImageFolder
6 |
7 |
8 | def get_validation_loader(dataset_name, model, batch_size, num_workers=8):
9 | if dataset_name == 'imagenet':
10 | data_path = '/gpfs/gibbs/project/panda/shared/imagenet'
11 | else:
12 | raise NotImplementedError
13 | config = resolve_data_config({}, model=model)
14 | test_transform = create_transform(**config)
15 | val_set = ImageFolder(data_path, test_transform)
16 | val_loader = torch.utils.data.DataLoader(val_set,
17 | batch_size=batch_size,
18 | num_workers=num_workers,
19 | shuffle=False)
20 | return val_loader
21 |
22 |
23 | def get_calibration_loader(dataset_name, model, num_data, num_workers=8):
24 | if dataset_name == 'imagenet':
25 | data_path = '/gpfs/gibbs/project/panda/shared/imagenet_2012/train'
26 | else:
27 | raise NotImplementedError
28 | config = resolve_data_config({}, model=model)
29 | train_transform = create_transform(**config, is_training=False)
30 | calib_set = ImageFolder(data_path, train_transform)
31 | calib_loader = torch.utils.data.DataLoader(calib_set,
32 | batch_size=num_data,
33 | num_workers=num_workers,
34 | shuffle=True)
35 |
36 | return calib_loader
37 |
38 |
39 | def get_calibration_data(dataset_name, model, num_data=128):
40 | calib_loader = get_calibration_loader(dataset_name, model, num_data)
41 | calib_loader = iter(calib_loader)
42 | image, label = next(calib_loader)
43 | image = image.cuda()
44 | return image
45 |
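46 | 
47 | if __name__ == '__main__':
48 |     # Quick check (a sketch): build a small calibration batch for a timm model.
49 |     # Assumes the hard-coded ImageNet paths above have been edited to point to a
50 |     # local copy of the dataset (see the README) and that a GPU is available.
51 |     net = timm.create_model('deit_base_patch16_224', pretrained=True)
52 |     images = get_calibration_data('imagenet', net, num_data=8)
53 |     print('calibration batch shape:', tuple(images.shape))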
--------------------------------------------------------------------------------
/vit_quant/eval_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data_utils
3 | from tqdm import tqdm
4 |
5 |
6 | def test(args, model):
7 | model.cuda()
8 | test_loader = data_utils.get_validation_loader(
9 | args.eval_dataset, model, args.bsz
10 | )
11 |
12 | pos = 0
13 | tot = 0
14 | i = 0
15 | max_iteration = len(test_loader)
16 | with torch.no_grad():
17 | q = tqdm(test_loader)
18 | for inp, target in q:
19 | i += 1
20 | inp = inp.cuda()
21 | target = target.cuda()
22 | out = model(inp)
23 | pos_num = torch.sum(out.argmax(1) == target).item()
24 | pos += pos_num
25 | tot += inp.size(0)
26 | q.set_postfix({"acc": pos / tot})
27 | if i >= max_iteration:
28 | break
29 | print('ImageNet accuracy: {}%'.format(100 * pos / tot))
--------------------------------------------------------------------------------
/vit_quant/gptaq_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import utils
7 | import quant_utils
8 | import timm
9 | import logging
10 | import functools
11 |
12 | torch.backends.cuda.matmul.allow_tf32 = False
13 | torch.backends.cudnn.allow_tf32 = False
14 |
15 |
16 | class GPTAQ:
17 |
18 | def __init__(self, layer, cls_token=0):
19 | self.layer = layer
20 | self.dev = self.layer.weight.device
21 | W = layer.weight.data.clone()
22 | self.rows = W.shape[0]
23 | self.columns = W.shape[1]
24 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
25 | self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
26 | self.nsamples = 0
27 | self.fp_inp = []
28 | self.cls_token = cls_token
29 |
30 | def add_batch(self, inp, out):
31 | if len(inp.shape) == 2:
32 | inp = inp.unsqueeze(0)
33 | tmp = inp.shape[0]
34 | if len(inp.shape) == 3:
35 | inp = inp[:, self.cls_token:, :]
36 | inp = inp.reshape((-1, inp.shape[-1]))
37 |
38 | inp = inp.t()
39 |
40 | self.H *= self.nsamples / (self.nsamples + tmp)
41 | self.dXXT *= self.nsamples / (self.nsamples + tmp)
42 | self.nsamples += tmp
43 | # inp = inp.float()
44 | inp = math.sqrt(2 / self.nsamples) * inp.float()
45 | # self.H += 2 / self.nsamples * inp.matmul(inp.t())
46 | self.H += inp.matmul(inp.t())
47 | dX = self.fp_inp[0].float().to(self.dev) * math.sqrt(2 / self.nsamples) - inp
48 | self.dXXT += dX.matmul(inp.t())
49 |
50 | del self.fp_inp[0]
51 |
52 | def fasterquant(
53 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
54 | ):
55 | W = self.layer.weight.data.clone()
56 | W = W.float()
57 |
58 | if not self.quantizer.ready():
59 | self.quantizer.find_params(W)
60 |
61 | H = self.H
62 | del self.H
63 | dead = torch.diag(H) == 0
64 | H[dead, dead] = 1
65 | W[:, dead] = 0
66 | self.dXXT[:, dead] = 0
67 |
68 | if static_groups:
69 | import copy
70 | groups = []
71 | for i in range(0, self.columns, groupsize):
72 | quantizer = copy.deepcopy(self.quantizer)
73 | quantizer.find_params(W[:, i:(i + groupsize)])
74 | groups.append(quantizer)
75 |
76 | if actorder:
77 | perm = torch.argsort(torch.diag(H), descending=True)
78 | W = W[:, perm]
79 | H = H[perm][:, perm]
80 | self.dXXT = self.dXXT[perm][:, perm]
81 | invperm = torch.argsort(perm)
82 |
83 | Losses = torch.zeros_like(W)
84 | Q = torch.zeros_like(W)
85 |
86 | damp = percdamp * torch.mean(torch.diag(H))
87 | diag = torch.arange(self.columns, device=self.dev)
88 | H[diag, diag] += damp
89 | Hinv = torch.linalg.cholesky(H)
90 | Hinv = torch.cholesky_inverse(Hinv)
91 | Hinv = torch.linalg.cholesky(Hinv, upper=True)
92 |
93 | alpha = 0.25
94 | P = alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv
95 | del self.dXXT
96 |
97 | for i1 in range(0, self.columns, blocksize):
98 | i2 = min(i1 + blocksize, self.columns)
99 | count = i2 - i1
100 |
101 | W1 = W[:, i1:i2].clone()
102 | Q1 = torch.zeros_like(W1)
103 | Err1 = torch.zeros_like(W1)
104 | Losses1 = torch.zeros_like(W1)
105 | Hinv1 = Hinv[i1:i2, i1:i2]
106 | P1 = P[i1:i2, i1:i2]
107 |
108 | for i in range(count):
109 | w = W1[:, i]
110 | d = Hinv1[i, i]
111 |
112 | if groupsize != -1:
113 | if not static_groups:
114 | if (i1 + i) % groupsize == 0:
115 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
116 | else:
117 | idx = i1 + i
118 | if actorder:
119 | idx = perm[idx]
120 | self.quantizer = groups[idx // groupsize]
121 |
122 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
123 | Q1[:, i] = q
124 | Losses1[:, i] = (w - q) ** 2 / d ** 2
125 |
126 | err1 = (w - q) / d
127 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))
128 | Err1[:, i] = err1
129 |
130 | Q[:, i1:i2] = Q1
131 | Losses[:, i1:i2] = Losses1 / 2
132 |
133 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:])
134 |
135 | torch.cuda.synchronize()
136 |
137 | if actorder:
138 | Q = Q[:, invperm]
139 |
140 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
141 | if torch.any(torch.isnan(self.layer.weight.data)):
142 | logging.warning('NaN in weights')
143 | import pprint
144 |             pprint.pprint((self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point))
145 | raise ValueError('NaN in weights')
146 |
147 | def free(self):
148 | self.H = None
149 | self.Losses = None
150 | self.Trace = None
151 | torch.cuda.empty_cache()
152 | utils.cleanup_memory(verbos=False)
153 |
154 |
155 | class FPInputsCache:
156 | """
157 |     Caches the full-precision inputs of each target linear layer via forward hooks.
158 | """
159 | def __init__(self, sequential, cls_token=0):
160 | self.fp_cache = {}
161 | self.names = sequential[0]+sequential[1]+sequential[2]+sequential[3]
162 | print(self.names)
163 | for name in self.names:
164 | self.fp_cache[name] = []
165 | self.handles = []
166 | self.cls_token = cls_token
167 |
168 | def cache_fp_input(self, m, inp, out, name):
169 | inp = inp[0].detach().cpu()
170 | if len(inp.shape) == 3:
171 | inp = inp[:, self.cls_token:, :]
172 | inp = inp.reshape((-1, inp.shape[-1]))
173 | self.fp_cache[name] += [inp.t()]
174 |
175 | def add_hook(self, full):
176 | for name in self.names:
177 | self.handles.append(
178 | full[name].register_forward_hook(
179 | functools.partial(self.cache_fp_input, name=name)
180 | )
181 | )
182 |
183 | def clear_hook(self):
184 | for h in self.handles:
185 | h.remove()
186 | self.handles = []
187 | torch.cuda.empty_cache()
188 |
189 | def clear_cache(self):
190 | for name in self.names:
191 | self.fp_cache[name] = []
192 |
193 |
194 | @torch.no_grad()
195 | def gptaq_fwrd(model, calib_data, dev, args):
196 |     '''
197 |     Adapted from the GPTQ repo.
198 |     Sequentially quantizes each ViT block with asymmetric calibration (GPTAQ).
199 |     '''
200 |     print('-----GPTAQ Quantization-----')
201 |
202 | layers = model.blocks
203 |
204 | model = model.cuda()
205 |
206 | class Catcher(nn.Module):
207 | def __init__(self, module):
208 | super().__init__()
209 | self.module = module
210 | self.fwd_kwargs = {}
211 | self.inps = None
212 |
213 | def forward(self, inp, **kwargs):
214 | self.inps = inp.data.clone()
215 | self.fwd_kwargs = kwargs
216 | print('data collected')
217 | raise ValueError
218 |
219 | layers[0] = Catcher(layers[0])
220 | try:
221 | model(calib_data.to(dev))
222 | except ValueError:
223 | pass
224 |
225 | inps = layers[0].inps
226 | fwd_kwargs = layers[0].fwd_kwargs
227 |
228 | layers[0] = layers[0].module
229 | layers[0] = layers[0].cpu()
230 | model = model.cpu()
231 |
232 | del calib_data
233 | torch.cuda.empty_cache()
234 |
235 | outs = torch.zeros_like(inps)
236 |
237 | quantizers = {}
238 | if isinstance(model, timm.models.Eva):
239 | sequential = [
240 | ['attn.q_proj.module', 'attn.k_proj.module', 'attn.v_proj.module'],
241 | ['attn.proj.module'],
242 | ['mlp.fc1_g.module', 'mlp.fc1_x.module'],
243 | ['mlp.fc2.module']
244 | ]
245 | elif isinstance(model, timm.models.VisionTransformer):
246 | sequential = [
247 | ['attn.qkv.module'],
248 | ['attn.proj.module'],
249 | ['mlp.fc1.module'],
250 | ['mlp.fc2.module']
251 | ]
252 | else:
253 | raise NotImplementedError
254 |
255 | fp_inputs_cache = FPInputsCache(sequential, cls_token=0)
256 | fp_inps = inps.clone()
257 |
258 | for i in range(len(layers)):
259 | print(f'\nLayer {i}:', flush=True, end=' ')
260 | layer = layers[i].to(dev)
261 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
262 |
263 | bits_config = quant_utils.disable_act_quant(layer)
264 | fp_inputs_cache.add_hook(full)
265 | fp_inps = layer(fp_inps, **fwd_kwargs)
266 | fp_inputs_cache.clear_hook()
267 | quant_utils.enable_act_quant(layer, bits_config)
268 |
269 | for names in sequential:
270 | subset = {n: full[n] for n in names}
271 |
272 | gptq = {}
273 | for name in subset:
274 | print(f'{name}', end=' ', flush=True)
275 | layer_weight_bits = args.w_bits
276 | layer_weight_sym = not (args.w_asym)
277 | if 'head' in name:
278 | layer_weight_bits = 16
279 | continue
280 | gptq[name] = GPTAQ(subset[name], cls_token=0)
281 | gptq[name].quantizer = quant_utils.WeightQuantizer()
282 | gptq[name].quantizer.configure(
283 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
284 | )
285 | gptq[name].fp_inp = fp_inputs_cache.fp_cache[name]
286 |
287 | def add_batch(name):
288 | def tmp(_, inp, out):
289 | gptq[name].add_batch(inp[0].data, out.data)
290 |
291 | return tmp
292 |
293 | handles = []
294 | for name in subset:
295 | handles.append(subset[name].register_forward_hook(add_batch(name)))
296 | outs = layer(inps, **fwd_kwargs) # forward calibration data
297 | for h in handles:
298 | h.remove()
299 |
300 | for name in subset:
301 | layer_w_groupsize = args.w_groupsize
302 | gptq[name].fasterquant(
303 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False
304 | )
305 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
306 | gptq[name].free()
307 |
308 | outs = layer(inps, **fwd_kwargs) # forward calibration data
309 |
310 | fp_inputs_cache.clear_cache()
311 | layers[i] = layer.cpu()
312 | del layer
313 | del gptq
314 | torch.cuda.empty_cache()
315 |
316 | inps, outs = outs, inps
317 |
318 | utils.cleanup_memory(verbos=True)
319 |     print('-----GPTAQ Quantization Done-----\n')
320 | return quantizers
321 |
--------------------------------------------------------------------------------
/vit_quant/gptq_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import timm.models
5 | import tqdm
6 | import torch
7 | import torch.nn as nn
8 | import utils
9 | import quant_utils
10 | import logging
11 |
12 | torch.backends.cuda.matmul.allow_tf32 = False
13 | torch.backends.cudnn.allow_tf32 = False
14 |
15 |
16 | class GPTQ:
17 |
18 | def __init__(self, layer, cls_token=0):
19 | self.layer = layer
20 | self.dev = self.layer.weight.device
21 | W = layer.weight.data.clone()
22 | self.rows = W.shape[0]
23 | self.columns = W.shape[1]
24 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
25 | self.nsamples = 0
26 | self.cls_token = cls_token
27 |
28 | def add_batch(self, inp, out):
29 |
30 | if len(inp.shape) == 2:
31 | inp = inp.unsqueeze(0)
32 | tmp = inp.shape[0]
33 | if len(inp.shape) == 3:
34 | inp = inp[:, self.cls_token:, :]
35 | inp = inp.reshape((-1, inp.shape[-1]))
36 | inp = inp.t()
37 | self.H *= self.nsamples / (self.nsamples + tmp)
38 | self.nsamples += tmp
39 | # inp = inp.float()
40 | inp = math.sqrt(2 / self.nsamples) * inp.float()
41 | # self.H += 2 / self.nsamples * inp.matmul(inp.t())
42 | self.H += inp.matmul(inp.t())
43 |
44 | def fasterquant(
45 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
46 | ):
47 | W = self.layer.weight.data.clone()
48 | W = W.float()
49 |
50 | tick = time.time()
51 |
52 | if not self.quantizer.ready():
53 | self.quantizer.find_params(W)
54 |
55 | H = self.H
56 | del self.H
57 | dead = torch.diag(H) == 0
58 | H[dead, dead] = 1
59 | W[:, dead] = 0
60 |
61 | if static_groups:
62 | import copy
63 | groups = []
64 | for i in range(0, self.columns, groupsize):
65 | quantizer = copy.deepcopy(self.quantizer)
66 | quantizer.find_params(W[:, i:(i + groupsize)])
67 | groups.append(quantizer)
68 |
69 | if actorder:
70 | perm = torch.argsort(torch.diag(H), descending=True)
71 | W = W[:, perm]
72 | H = H[perm][:, perm]
73 | invperm = torch.argsort(perm)
74 |
75 | Losses = torch.zeros_like(W)
76 | Q = torch.zeros_like(W)
77 |
78 | damp = percdamp * torch.mean(torch.diag(H))
79 | diag = torch.arange(self.columns, device=self.dev)
80 | H[diag, diag] += damp
81 | H = torch.linalg.cholesky(H)
82 | H = torch.cholesky_inverse(H)
83 | H = torch.linalg.cholesky(H, upper=True)
84 | Hinv = H
85 |
86 | for i1 in range(0, self.columns, blocksize):
87 | i2 = min(i1 + blocksize, self.columns)
88 | count = i2 - i1
89 |
90 | W1 = W[:, i1:i2].clone()
91 | Q1 = torch.zeros_like(W1)
92 | Err1 = torch.zeros_like(W1)
93 | Losses1 = torch.zeros_like(W1)
94 | Hinv1 = Hinv[i1:i2, i1:i2]
95 |
96 | for i in range(count):
97 | w = W1[:, i]
98 | d = Hinv1[i, i]
99 |
100 | if groupsize != -1:
101 | if not static_groups:
102 | if (i1 + i) % groupsize == 0:
103 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
104 | else:
105 | idx = i1 + i
106 | if actorder:
107 | idx = perm[idx]
108 | self.quantizer = groups[idx // groupsize]
109 |
110 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
111 | Q1[:, i] = q
112 | Losses1[:, i] = (w - q) ** 2 / d ** 2
113 |
114 | err1 = (w - q) / d
115 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
116 | Err1[:, i] = err1
117 |
118 | Q[:, i1:i2] = Q1
119 | Losses[:, i1:i2] = Losses1 / 2
120 |
121 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
122 |
123 | torch.cuda.synchronize()
124 |
125 | if actorder:
126 | Q = Q[:, invperm]
127 |
128 | # print('change ratio: {}'.format(Q.norm()/self.layer.weight.data.norm()))
129 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
130 | if torch.any(torch.isnan(self.layer.weight.data)):
131 | logging.warning('NaN in weights')
132 | import pprint
133 |             pprint.pprint((self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point))
134 | raise ValueError('NaN in weights')
135 |
136 | def free(self):
137 | self.H = None
138 | self.Losses = None
139 | self.Trace = None
140 | torch.cuda.empty_cache()
141 | utils.cleanup_memory(verbos=False)
142 |
143 |
144 | @torch.no_grad()
145 | def gptq_fwrd(model, calib_data, dev, args):
146 |     '''
147 |     Adapted from the GPTQ repo.
148 |     Sequentially quantizes the linear layers of each ViT block with GPTQ.
149 |     '''
150 | logging.info('-----GPTQ Quantization-----')
151 |
152 | layers = model.blocks
153 |
154 | model = model.cuda()
155 |
156 | class Catcher(nn.Module):
157 | def __init__(self, module):
158 | super().__init__()
159 | self.module = module
160 | self.fwd_kwargs = {}
161 | self.inps = None
162 |
163 | def forward(self, inp, **kwargs):
164 | self.inps = inp.data.clone()
165 | self.fwd_kwargs = kwargs
166 | print('data collected')
167 | raise ValueError
168 |
169 | layers[0] = Catcher(layers[0])
170 | try:
171 | model(calib_data.to(dev))
172 | except ValueError:
173 | pass
174 |
175 | inps = layers[0].inps
176 | fwd_kwargs = layers[0].fwd_kwargs
177 |
178 | layers[0] = layers[0].module
179 | layers[0] = layers[0].cpu()
180 | model = model.cpu()
181 |
182 | del calib_data
183 | torch.cuda.empty_cache()
184 |
185 | outs = torch.zeros_like(inps)
186 |
187 | quantizers = {}
188 | if isinstance(model, timm.models.Eva):
189 | sequential = [
190 | ['attn.q_proj.module', 'attn.k_proj.module', 'attn.v_proj.module'],
191 | ['attn.proj.module'],
192 | ['mlp.fc1_g.module', 'mlp.fc1_x.module'],
193 | ['mlp.fc2.module']
194 | ]
195 | elif isinstance(model, timm.models.VisionTransformer):
196 | sequential = [
197 | ['attn.qkv.module'],
198 | ['attn.proj.module'],
199 | ['mlp.fc1.module'],
200 | ['mlp.fc2.module']
201 | ]
202 | else:
203 | raise NotImplementedError
204 |
205 | for i in range(len(layers)):
206 | print(f'\nLayer {i}:', flush=True, end=' ')
207 | layer = layers[i].to(dev)
208 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
209 | for names in sequential:
210 | subset = {n: full[n] for n in names}
211 |
212 | gptq = {}
213 | for name in subset:
214 | print(f'{name}', end=' ', flush=True)
215 | layer_weight_bits = args.w_bits
216 | layer_weight_sym = not (args.w_asym)
217 | if 'head' in name:
218 | layer_weight_bits = 16
219 | continue
220 | gptq[name] = GPTQ(subset[name], cls_token=0)
221 | gptq[name].quantizer = quant_utils.WeightQuantizer()
222 | gptq[name].quantizer.configure(
223 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
224 | )
225 |
226 | def add_batch(name):
227 | def tmp(_, inp, out):
228 | gptq[name].add_batch(inp[0].data, out.data)
229 |
230 | return tmp
231 |
232 | handles = []
233 | for name in subset:
234 | handles.append(subset[name].register_forward_hook(add_batch(name)))
235 | outs = layer(inps, **fwd_kwargs) # forward calibration data
236 | for h in handles:
237 | h.remove()
238 |
239 | for name in subset:
240 | layer_w_groupsize = args.w_groupsize
241 | gptq[name].fasterquant(
242 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False
243 | )
244 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
245 | gptq[name].free()
246 |
247 | outs = layer(inps, **fwd_kwargs)
248 |
249 | layers[i] = layer.cpu()
250 | del layer
251 | del gptq
252 | torch.cuda.empty_cache()
253 |
254 | inps, outs = outs, inps
255 |
256 | utils.cleanup_memory(verbos=True)
257 | logging.info('-----GPTQ Quantization Done-----\n')
258 | return quantizers
259 |
260 |
261 | @torch.no_grad()
262 | def rtn_fwrd(model, dev, args):
263 |     '''
264 |     Adapted from the GPTQ repo.
265 |     Round-to-nearest (RtN) weight quantization of the linear layers in each ViT block.
266 |     '''
267 | assert args.w_groupsize == -1, "Groupsize not supported in RTN!"
268 | layers = model.blocks
269 | torch.cuda.empty_cache()
270 |
271 | quantizers = {}
272 |
273 | for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"):
274 | layer = layers[i].to(dev)
275 |
276 | subset = quant_utils.find_qlayers(layer,
277 | layers=[torch.nn.Linear])
278 |
279 | for name in subset:
280 | layer_weight_bits = args.w_bits
281 | if 'head' in name:
282 | layer_weight_bits = 16
283 | continue
284 |
285 | quantizer = quant_utils.WeightQuantizer()
286 | quantizer.configure(
287 | layer_weight_bits, perchannel=True, sym=not (args.w_asym), mse=args.w_clip
288 | )
289 | W = subset[name].weight.data
290 | quantizer.find_params(W)
291 | subset[name].weight.data = quantizer.quantize(W).to(
292 | next(iter(layer.parameters())).dtype)
293 | quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu()
294 | layers[i] = layer.cpu()
295 | torch.cuda.empty_cache()
296 | del layer
297 |
298 | utils.cleanup_memory(verbos=True)
299 | return quantizers
300 |
--------------------------------------------------------------------------------
/vit_quant/main.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import torch
3 | import model_utils
4 | import data_utils
5 | import eval_utils
6 | import quant_utils
7 | import gptq_utils
8 | import gptaq_utils
9 |
10 |
11 | def configure_act_quantizer(model, args):
12 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
13 | down_proj_groupsize = -1
14 | if args.a_groupsize > 0 and "llama" in args.model:
15 | down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)
16 |
17 | for name in qlayers:
18 | layer_input_bits = args.a_bits
19 | layer_groupsize = args.a_groupsize
20 | layer_a_sym = not (args.a_asym)
21 | layer_a_clip = args.a_clip_ratio
22 |
23 | if 'head' in name: # Skip lm_head quantization
24 | layer_input_bits = 16
25 |
26 | if 'down_proj' in name: # Set the down_proj precision
27 | if args.int8_down_proj:
28 | layer_input_bits = 8
29 | layer_groupsize = down_proj_groupsize
30 |
31 | qlayers[name].quantizer.configure(bits=layer_input_bits,
32 | groupsize=layer_groupsize,
33 | sym=layer_a_sym,
34 | clip_ratio=layer_a_clip)
35 |
36 |
37 | def main():
38 | args = utils.parser_gen()
39 | if args.wandb:
40 | import wandb
41 | wandb.init(project=args.wandb_project, entity=args.wandb_id)
42 | wandb.config.update(args)
43 |
44 | torch.manual_seed(args.seed)
45 | model = model_utils.get_model(args.model)
46 | if args.eval_fp:
47 | eval_utils.test(args, model)
48 | quant_utils.add_actquant(model) # Add Activation Wrapper to the model as the rest of the code assumes it is present
49 |
50 | # Add Input Quantization
51 | if args.a_bits < 16 and args.enable_aq_calibration:
52 | configure_act_quantizer(model, args)
53 |
54 | if args.w_bits < 16:
55 | save_dict = {}
56 |         if args.load_qmodel_path: # Load a previously quantized model
57 |             # No rotation is used in the ViT pipeline, so only guard against save/load conflicts.
58 |             assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
59 | print("Load quantized model from ", args.load_qmodel_path)
60 | save_dict = torch.load(args.load_qmodel_path)
61 | model.load_state_dict(save_dict["model"])
62 |
63 | elif not args.w_rtn: # GPTQ Weight Quantization
64 |
65 | calib_data = data_utils.get_calibration_data(
66 | args.cal_dataset, num_data=args.nsamples, model=model,
67 | )
68 |
69 | if args.asym_calibrate:
70 | quantizers = gptaq_utils.gptaq_fwrd(model, calib_data, utils.DEV, args)
71 | save_dict["w_quantizers"] = quantizers
72 | else:
73 | quantizers = gptq_utils.gptq_fwrd(model, calib_data, utils.DEV, args)
74 | save_dict["w_quantizers"] = quantizers
75 | else: # RTN Weight Quantization
76 | quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args)
77 | save_dict["w_quantizers"] = quantizers
78 |
79 | if args.save_qmodel_path:
80 | save_dict["model"] = model.state_dict()
81 | torch.save(save_dict, args.save_qmodel_path)
82 |
83 | if args.a_bits < 16 and not args.enable_aq_calibration:
84 | configure_act_quantizer(model, args)
85 |
86 | # Evaluating on dataset
87 | eval_utils.test(args, model)
88 |
89 |
90 | if __name__ == '__main__':
91 | main()
92 |
--------------------------------------------------------------------------------
/vit_quant/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import typing
3 | # import utils
4 | import timm
5 | import os
6 | import logging
7 |
8 |
9 | def get_model(name, ):
10 |     """
11 |     Get a pretrained vision transformer model from timm.
12 |     The returned model is moved to the GPU and set to eval mode.
13 | 
14 |     Currently supports almost all vision transformer models in timm, including:
15 |     - vit_tiny/small/base/large_patch16/patch32_224/384,
16 |     - deit_tiny/small/base(_distilled)_patch16_224,
17 |     - deit_base(_distilled)_patch16_384,
18 |     - swin_tiny/small/base/large_patch4_window7_224,
19 |     - swin_base/large_patch4_window12_384
20 | 
21 |     These models are finetuned on ImageNet-1k and should use the loaders in
22 |     data_utils.py for calibration and testing.
23 |     """
24 | net = timm.create_model(name, pretrained=True)
25 |
26 | net.cuda()
27 | net.eval()
28 | return net
29 |
30 |
31 |
32 | def stem_layer_forward(model, calib_data, dev):
33 | if isinstance(model, timm.models.VisionTransformer):
34 | model.patch_embed = model.patch_embed.to(dev)
35 | model.pos_drop = model.pos_drop.to(dev)
36 |
37 | inps = model.patch_embed(calib_data)
38 | inps = model._pos_embed(inps)
39 | inps = model.patch_drop(inps)
40 | inps = model.norm_pre(inps)
41 |
42 | dtype = next(iter(model.parameters())).dtype
43 |
44 | model.patch_embed = model.patch_embed.cpu()
45 | model.pos_drop = model.pos_drop.cpu()
46 | kwargs = {}
47 |     elif isinstance(model, timm.models.Eva):
48 |         model.patch_embed = model.patch_embed.to(dev)
49 |         model.pos_drop = model.pos_drop.to(dev)
50 | 
51 |         inps = model.patch_embed(calib_data)
52 |         # Eva._pos_embed is a method that returns the tokens plus the rotary position embedding
53 |         inps, rot_pos_embed = model._pos_embed(inps)
54 |         inps = model.norm_pre(inps)
55 | 
56 |         dtype = next(iter(model.parameters())).dtype
57 | 
58 |         model.patch_embed = model.patch_embed.cpu()
59 |         model.pos_drop = model.pos_drop.cpu()
60 |         # timm's Eva blocks take the rotary embedding via the `rope` keyword argument
61 |         kwargs = dict(rope=rot_pos_embed)
62 |     else:
63 |         raise NotImplementedError
64 | 
65 |     return model, inps, kwargs
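66 | 
67 | 
68 | if __name__ == '__main__':
69 |     # Minimal sanity check (a sketch, not part of the quantization pipeline):
70 |     # load one of the backbones listed in utils.supported_models and report its
71 |     # parameter count. Assumes timm is installed and a CUDA device is available,
72 |     # since get_model() moves the network to the GPU.
73 |     demo = get_model('deit_base_patch16_224')
74 |     n_params = sum(p.numel() for p in demo.parameters())
75 |     print(f'deit_base_patch16_224: {n_params / 1e6:.1f}M parameters')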
--------------------------------------------------------------------------------
/vit_quant/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | export CUDA_VISIBLE_DEVICES=$gpu_id
5 |
6 |
7 | python main.py --model eva02_large_patch14_448.mim_m38m_ft_in22k_in1k \
8 | --w_bits 4 \
9 | --w_groupsize -1 \
10 | --w_clip \
11 | --a_bits 4 \
12 | --nsamples 128 \
13 | --a_asym \
14 | --w_asym \
15 | --percdamp 0.1 \
16 | --act_order \
17 | --bsz 256 \
18 | --asym_calibrate \
19 |     --enable_aq_calibration
20 |
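21 | # Optional baseline (a sketch, not part of the released recipe): round-to-nearest
22 | # weight-only quantization of DeiT-Base, which skips GPTQ/GPTAQ calibration entirely.
23 | # Uncomment to run it after the main experiment above.
24 | # python main.py --model deit_base_patch16_224 \
25 | #     --w_bits 4 \
26 | #     --w_rtn \
27 | #     --w_clip \
28 | #     --bsz 256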
--------------------------------------------------------------------------------
/vit_quant/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pprint
3 | import torch
4 | import random
5 | import numpy as np
6 | import os
7 | from datetime import datetime
8 | import logging
9 |
10 |
11 | supported_models = [
12 | 'deit_base_patch16_224',
13 | 'deit_small_patch16_224',
14 | 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k',
15 | ]
16 | supported_datasets = ['imagenet']
17 |
18 | # These flags disable using TensorFloat-32 tensor cores (to avoid numerical issues)
19 | torch.backends.cuda.matmul.allow_tf32 = False
20 | torch.backends.cudnn.allow_tf32 = False
21 | DEV = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
22 |
23 |
24 | def llama_down_proj_groupsize(model, groupsize):
25 | assert groupsize > 1, 'groupsize should be greater than 1!'
26 |
27 | if model.config.intermediate_size % groupsize == 0:
28 |         logging.info(f'(Act.) Groupsize = Down_proj Groupsize: {groupsize}')
29 | return groupsize
30 |
31 | group_num = int(model.config.hidden_size / groupsize)
32 | assert groupsize * group_num == model.config.hidden_size, 'Invalid groupsize for llama!'
33 |
34 | down_proj_groupsize = model.config.intermediate_size // group_num
35 | assert down_proj_groupsize * group_num == model.config.intermediate_size, 'Invalid groupsize for down_proj!'
36 | logging.info(f'(Act.) Groupsize: {groupsize}, Down_proj Groupsize: {down_proj_groupsize}')
37 | return down_proj_groupsize
38 |
39 |
40 | def set_seed(seed):
41 | np.random.seed(seed)
42 | torch.random.manual_seed(seed)
43 | random.seed(seed)
44 |
45 |
46 | # Dump the log both to console and a log file.
47 | def config_logging(log_file, level=logging.INFO):
48 | class LogFormatter(logging.Formatter):
49 | def format(self, record):
50 | if record.levelno == logging.INFO:
51 | self._style._fmt = "%(message)s"
52 | else:
53 | self._style._fmt = "%(levelname)s: %(message)s"
54 | return super().format(record)
55 |
56 | console_handler = logging.StreamHandler()
57 | console_handler.setFormatter(LogFormatter())
58 |
59 | file_handler = logging.FileHandler(log_file)
60 | file_handler.setFormatter(LogFormatter())
61 |
62 | logging.basicConfig(level=level, handlers=[console_handler, file_handler])
63 |
64 |
65 | def parser_gen():
66 | parser = argparse.ArgumentParser()
67 |
68 | # General Arguments
69 | parser.add_argument('--model', type=str, default='deit_base_patch16_224',
70 | help='Model to load;', choices=supported_models)
71 | parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch')
72 | parser.add_argument('--eval_dataset', type=str, default='imagenet',
73 | help='Dataset for Evaluation (default: imagenet)', choices=supported_datasets, )
74 | parser.add_argument('--bsz', type=int, default=256,
75 |                         help='Batch-size for ImageNet evaluation (default: 256)')
76 | parser.add_argument('--eval_fp', action=argparse.BooleanOptionalAction, default=False,
77 | help='whether to evaluate the FP16 model (default: False)')
78 |
79 | # Activation Quantization Arguments
80 | parser.add_argument('--a_bits', type=int, default=16,
81 | help='''Number of bits for inputs of the Linear layers. This will be
82 | for all the linear layers in the model (including down-projection and out-projection)''')
83 | parser.add_argument('--a_groupsize', type=int, default=-1,
84 | help='Groupsize for activation quantization. Note that this should be the same as w_groupsize')
85 | parser.add_argument('--a_asym', action=argparse.BooleanOptionalAction, default=False,
86 | help='ASymmetric Activation quantization (default: False)')
87 | parser.add_argument('--a_clip_ratio', type=float, default=1.0,
88 | help='Clip ratio for activation quantization. new_max = max * clip_ratio')
89 | parser.add_argument('--enable_aq_calibration', action=argparse.BooleanOptionalAction, default=False,
90 | help='Activation quantization during GPTQ (default: False)')
91 |
92 | # Weight Quantization Arguments
93 | parser.add_argument('--w_bits', type=int, default=16,
94 | help='Number of bits for weights of the Linear layers')
95 | parser.add_argument('--w_groupsize', type=int, default=-1,
96 | help='Groupsize for weight quantization. Note that this should be the same as a_groupsize')
97 | parser.add_argument('--w_asym', action=argparse.BooleanOptionalAction, default=False,
98 | help='ASymmetric weight quantization (default: False)')
99 | parser.add_argument('--w_rtn', action=argparse.BooleanOptionalAction, default=False,
100 | help='Quantize the weights using RtN. If the w_bits < 16 and this flag is not set, we use GPTQ')
101 | parser.add_argument('--w_clip', action=argparse.BooleanOptionalAction, default=False,
102 | help='''Clipping the weight quantization!
103 | We do not support arguments for clipping and we find the best clip ratio during the weight quantization''')
104 | parser.add_argument('--nsamples', type=int, default=128,
105 | help='Number of calibration data samples for GPTQ.')
106 | parser.add_argument('--cal_dataset', type=str, default='imagenet',
107 | help='calibration data samples for GPTQ.', choices=supported_datasets)
108 | parser.add_argument('--percdamp', type=float, default=.01,
109 | help='Percent of the average Hessian diagonal to use for dampening.')
110 | parser.add_argument('--act_order', action=argparse.BooleanOptionalAction, default=False,
111 | help='act-order in GPTQ')
112 | parser.add_argument('--asym_calibrate', action=argparse.BooleanOptionalAction, default=False,
113 | help='whether to use GPTAQ')
114 |
115 | # Save/Load Quantized Model Arguments
116 | parser.add_argument('--load_qmodel_path', type=str, default=None,
117 | help='Load the quantized model from the specified path!')
118 | parser.add_argument('--save_qmodel_path', type=str, default=None,
119 | help='Save the quantized model to the specified path!')
120 |
121 | # WandB Arguments
122 | parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False)
123 | parser.add_argument('--wandb_id', type=str, default=None)
124 | parser.add_argument('--wandb_project', type=str, default=None)
125 |
126 | # Experiments Arguments
127 | parser.add_argument('--save_name', type=str, default=None, help='The path to save experiment data, '
128 | 'including quantized models, dumped layer inputs, etc. The data will be saved in experiments/[model]/save_name. Default: [datetime].')
129 | parser.add_argument('--capture_layer_io', action=argparse.BooleanOptionalAction, default=False,
130 | help='Capture the input and output of the specified decoder layer and dump into a file')
131 | parser.add_argument('--layer_idx', type=int, default=10, help='Which decoder layer to capture')
132 |
133 | parser.add_argument(
134 | "--distribute",
135 | action="store_true",
136 | help="Distribute the model on multiple GPUs for evaluation.",
137 | )
138 |
139 | args = parser.parse_args()
140 |
141 | if args.save_name is None:
142 | args.save_name = datetime.now().strftime("%Y%m%d_%H%M%S")
143 | setattr(args, 'save_path',
144 | os.path.join(os.path.dirname(os.path.abspath(__file__)), 'experiments', args.model, args.save_name))
145 | os.makedirs(args.save_path, exist_ok=True)
146 |
147 | config_logging(os.path.join(args.save_path, f'{args.save_name}.log'))
148 |
149 | assert args.a_groupsize == args.w_groupsize, 'a_groupsize should be the same as w_groupsize!'
150 |
151 | if args.wandb:
152 | assert args.wandb_id is not None and args.wandb_project is not None, 'WandB ID/project is not provided!'
153 |
154 | logging.info('Arguments: ')
155 | logging.info(pprint.pformat(vars(args)))
156 | logging.info('--' * 30)
157 | return args
158 |
159 |
160 | def cleanup_memory(verbos=True) -> None:
161 | """Run GC and clear GPU memory."""
162 | import gc
163 | import inspect
164 | caller_name = ''
165 | try:
166 | caller_name = f' (from {inspect.stack()[1].function})'
167 | except (ValueError, KeyError):
168 | pass
169 |
170 | def total_reserved_mem() -> int:
171 | return sum(torch.cuda.memory_reserved(device=i) for i in range(torch.cuda.device_count()))
172 |
173 | memory_before = total_reserved_mem()
174 |
175 | # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed
176 | gc.collect()
177 |
178 | if torch.cuda.is_available():
179 | torch.cuda.empty_cache()
180 | memory_after = total_reserved_mem()
181 | if verbos:
182 | logging.info(
183 | f"GPU memory{caller_name}: {memory_before / (1024 ** 3):.2f} -> {memory_after / (1024 ** 3):.2f} GB"
184 | f" ({(memory_after - memory_before) / (1024 ** 3):.2f} GB)"
185 | )
186 |
187 |
188 | def distribute_model(model) -> None:
189 |     """Distribute the model across available GPUs. NB: only implemented for Llama-2."""
190 |     from accelerate import dispatch_model, infer_auto_device_map
191 |     from accelerate.utils import get_balanced_memory
192 | 
193 |     no_split_module_classes = ['LlamaDecoderLayer']
194 |     max_memory = get_balanced_memory(
195 |         model,
196 |         no_split_module_classes=no_split_module_classes,
197 |     )
198 | 
199 |     device_map = infer_auto_device_map(
200 |         model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
201 |     )
202 | 
203 |     dispatch_model(
204 |         model,
205 |         device_map=device_map,
206 |         offload_buffers=True,
207 |         offload_dir="offload",
208 |         state_dict=model.state_dict(),
209 |     )
210 | 
211 |     cleanup_memory()
--------------------------------------------------------------------------------