├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── convert_llama_weights_to_hf.py ├── datautils.py ├── fused_attn.py ├── gptq.py ├── llama.py ├── llama_inference.py ├── llama_inference_dmapauto.py ├── llama_inference_offload.py ├── modelutils.py ├── opt.py ├── quant.py ├── quant_cuda.cpp ├── quant_cuda_kernel.cu ├── requirements.txt ├── santacoder.py ├── santacoder_inference.py ├── scripts ├── santacoder-16bit.sh ├── santacoder-32bit.sh ├── santacoder-4bit.sh ├── santacoder-8bit.sh ├── starcoder-16bit.sh ├── starcoder-32bit.sh ├── starcoder-4bit.sh ├── starcoder-8bit.sh ├── starcoderbase-16bit.sh ├── starcoderbase-32bit.sh ├── starcoderbase-4bit.sh └── starcoderbase-8bit.sh ├── setup_cuda.py ├── share_tensors_across_processes.py └── test_kernel.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/isort 3 | rev: 5.12.0 4 | hooks: 5 | - id: isort 6 | name: isort (python) 7 | - repo: https://github.com/psf/black 8 | rev: 23.3.0 9 | hooks: 10 | - id: black 11 | args: [--line-length=119,--target-version=py39,--force-exclude=instructions/] 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPTQ-for-SantaCoder-and-StarCoder 2 | Quantization of [SantaCoder](https://arxiv.org/abs/2301.03988), StarCoder, and StarCoderBase using [GPTQ](https://arxiv.org/abs/2210.17323) 3 | 4 | GPTQ is a SOTA one-shot weight quantization method. 5 | 6 | **This code is based on [GPTQ](https://github.com/IST-DASLab/gptq)** 7 | 8 | Changed to support the new features proposed by [GPTQ](https://github.com/IST-DASLab/gptq#new-features): 9 | 10 | * Slightly adjusted preprocessing of C4 and PTB for more realistic evaluations (used in our updated results); can be activated via the flag --new-eval. 11 | * Two new tricks: --act-order (quantizing columns in order of decreasing activation size) and --true-sequential (performing sequential quantization even within a single Transformer block); see the sketch below. These fix GPTQ's strangely bad performance on the 7B model (from 7.15 to 6.09 Wiki2 PPL) and lead to slight improvements on most models/settings in general.
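Both flags come from the GPTQ code included in this repository (`gptq.py`, `llama.py`). As a minimal illustration of what `--act-order` does, the sketch below mirrors the relevant lines of `fasterquant` in `gptq.py`: columns are quantized in order of decreasing Hessian diagonal (a proxy for activation size), so the columns whose activations matter most are handled first and their rounding error can still be compensated by the many columns quantized later; the permutation is undone afterwards. The helper names are illustrative only, not part of the codebase.

```python
import torch

def act_order_permute(W: torch.Tensor, H: torch.Tensor):
    """Reorder weight columns (and the Hessian accumulator H) by decreasing
    diag(H), which grows with the average squared activation of each column."""
    perm = torch.argsort(torch.diag(H), descending=True)
    return W[:, perm], H[perm][:, perm], perm

def act_order_restore(Q: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Undo the permutation on the quantized weights after GPTQ has run."""
    invperm = torch.argsort(perm)
    return Q[:, invperm]
```

`--true-sequential` needs no extra machinery: instead of calibrating every linear layer of a Transformer block on the same inputs, the block is quantized group by group (attention projections, then the output projection, then the MLP), re-running the block so that each group is calibrated on the outputs of the already-quantized groups (see the `sequential` list in `llama_sequential` in `llama.py`).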
12 | 13 | **--act-order is supported, but it is very slow.** 14 | 15 | ## SantaCoder Results 16 | | [SantaCoder](https://arxiv.org/abs/2301.03988) | Bits | group-size | memory(MiB) | wikitext2 | ptb | c4 | stack | checkpoint size(MB) | 17 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ---------- | ---------- | ---------- | ------------------- | 18 | | FP32 | 32 | - | 4344.722 | 24.927 | 38.574 | 27.779 | 2.619 | 4394 | 19 | | BF16 | 16 | - | 2173.680 | 24.960 | 38.597 | 27.794 | 2.621 | 2195 | 20 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 8 | -1 | 1396.548 | 24.936 | 38.592 | 27.785 | 2.619 | 1411 | 21 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | -1 | 911.384 | 26.581 | 40.717 | 29.232 | 2.658 | 913 | 22 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | -1 | - | 11761.473 | 7273.338 | 9124.941 | 2485.844 | 789 | 23 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 2 | -1 | - | 67976.797 | 68994.484 | 73294.438 | 45370.488 | 649 | 24 | 25 | ## StarCoder Results 26 | | StarCoder | Bits | group-size | memory(MiB) | wikitext2 | ptb | c4 | stack | checkpoint size(MB) | 27 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ---------- | ---------- | ---------- | ------------------- | 28 | | FP32 | 32 | - | | 10.801 | 16.425 | 13.402 | 1.738 | 59195 | 29 | | BF16 | 16 | - | | 10.807 | 16.424 | 13.408 | 1.739 | 29597 | 30 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 8 | 128 | | 10.805 | 15.453 | 13.408 | 1.739 | 16163 | 31 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | | 10.989 | 16.839 | 13.676 | 1.757 | 8877 | 32 | 33 | ## StarCoderBase Results 34 | | StarCoderBase | Bits | group-size | memory(MiB) | wikitext2 | ptb | c4 | stack | checkpoint size(MB) | 35 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ---------- | ---------- | ---------- | ------------------- | 36 | | FP32 | 32 | - | | 10.172 | 15.756 | 12.736 | 1.692 | 59195 | 37 | | BF16 | 16 | - | | 10.173 | 15.765 | 12.745 | 1.692 | 29597 | 38 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 8 | 128 | | 10.174 | 15.767 | 12.739 | 1.692 | 16163 | 39 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | | 10.387 | 16.056 | 13.005 | 1.708 | 8877 | 40 | 41 | Quantization requires a large amount of CPU memory. However, the memory required can be reduced by using swap memory. 42 | 43 | Depending on the GPUs/drivers, there may be a difference in performance, which decreases as the model size increases (see https://github.com/IST-DASLab/gptq/issues/1). 44 | 45 | According to the [GPTQ paper](https://arxiv.org/abs/2210.17323), the performance gap between FP16 and GPTQ decreases as the model size increases. 46 | 47 | ## Installation 48 | If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first. 49 | ```shell 50 | conda create --name gptq python=3.9 -y 51 | conda activate gptq 52 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia 53 | # Or, if you're having trouble with conda, use pip with python3.9: 54 | # pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 55 | 56 | pip install -r requirements.txt 57 | python setup_cuda.py install 58 | ``` 59 | 60 | All experiments were run on a single NVIDIA RTX 3090. 61 | 62 | # Language Generation 63 | ## SantaCoder 64 | Visit [mayank31398/santacoder-GPTQ-4bit-128g](https://huggingface.co/mayank31398/santacoder-GPTQ-4bit-128g) for the 4-bit weights.
65 | Visit [mayank31398/santacoder-GPTQ-8bit-128g](https://huggingface.co/mayank31398/santacoder-GPTQ-8bit-128g) for the 8-bit weights. 66 | ```shell 67 | # 4-bit 68 | git clone https://huggingface.co/mayank31398/santacoder-GPTQ-4bit-128g 69 | # 8-bit 70 | git clone https://huggingface.co/mayank31398/santacoder-GPTQ-8bit-128g 71 | ``` 72 | Alternatively, you can also use the [scripts](scripts/) to get the quantized models and save them to disk. 73 | 74 | For generation use: 75 | ```shell 76 | # fp32 77 | python -m santacoder_inference bigcode/gpt_bigcode-santacoder --wbits 32 78 | # bf16 79 | python -m santacoder_inference bigcode/gpt_bigcode-santacoder --wbits 16 80 | 81 | # GPTQ int8 82 | python -m santacoder_inference bigcode/gpt_bigcode-santacoder --wbits 8 --load santacoder-GPTQ-8bit-128g/model.pt 83 | # GPTQ int4 84 | python -m santacoder_inference bigcode/gpt_bigcode-santacoder --wbits 4 --load santacoder-GPTQ-4bit-128g/model.pt 85 | ``` 86 | 87 | ## StarCoder 88 | Visit [mayank31398/starcoder-GPTQ-4bit-128g](https://huggingface.co/mayank31398/starcoder-GPTQ-4bit-128g) for the 4-bit weights. 89 | Visit [mayank31398/starcoder-GPTQ-8bit-128g](https://huggingface.co/mayank31398/starcoder-GPTQ-8bit-128g) for the 8-bit weights. 90 | ```shell 91 | # 4-bit 92 | git clone https://huggingface.co/mayank31398/starcoder-GPTQ-4bit-128g 93 | # 8-bit 94 | git clone https://huggingface.co/mayank31398/starcoder-GPTQ-8bit-128g 95 | ``` 96 | Alternatively, you can also use the [scripts](scripts/) to get the quantized models and save them to disk. 97 | 98 | For generation use: 99 | ```shell 100 | # fp32 101 | python -m santacoder_inference bigcode/starcoder --wbits 32 102 | # bf16 103 | python -m santacoder_inference bigcode/starcoder --wbits 16 104 | 105 | # GPTQ int8 106 | python -m santacoder_inference bigcode/starcoder --wbits 8 --groupsize 128 --load starcoder-GPTQ-8bit-128g/model.pt 107 | # GPTQ int4 108 | python -m santacoder_inference bigcode/starcoder --wbits 4 --groupsize 128 --load starcoder-GPTQ-4bit-128g/model.pt 109 | ``` 110 | 111 | ## StarCoderBase 112 | Visit [mayank31398/starcoderbase-GPTQ-4bit-128g](https://huggingface.co/mayank31398/starcoderbase-GPTQ-4bit-128g) for the 4-bit weights. 113 | Visit [mayank31398/starcoderbase-GPTQ-8bit-128g](https://huggingface.co/mayank31398/starcoderbase-GPTQ-8bit-128g) for the 8-bit weights. 114 | ```shell 115 | # 4-bit 116 | git clone https://huggingface.co/mayank31398/starcoderbase-GPTQ-4bit-128g 117 | # 8-bit 118 | git clone https://huggingface.co/mayank31398/starcoderbase-GPTQ-8bit-128g 119 | ``` 120 | Alternatively, you can also use the [scripts](scripts/) to get the quantized models and save them to disk. 
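A note not in the upstream README: the multi-GB `model.pt` checkpoints in the Hugging Face repositories linked above are stored via Git LFS, so a plain `git clone` without LFS set up may leave you with small pointer files instead of the actual weights. A typical setup sequence (shown for StarCoderBase; the same applies to the SantaCoder and StarCoder repositories) is:

```shell
# One-time Git LFS setup so the large model.pt files are actually downloaded
git lfs install
git clone https://huggingface.co/mayank31398/starcoderbase-GPTQ-4bit-128g
# Sanity check: model.pt should be several GB, not a tiny LFS pointer file
ls -lh starcoderbase-GPTQ-4bit-128g/model.pt
```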
121 | 122 | For generation use: 123 | ```shell 124 | # fp32 125 | python -m santacoder_inference bigcode/starcoderbase --wbits 32 126 | # bf16 127 | python -m santacoder_inference bigcode/starcoderbase --wbits 16 128 | 129 | # GPTQ int8 130 | python -m santacoder_inference bigcode/starcoderbase --wbits 8 --groupsize 128 --load starcoderbase-GPTQ-8bit-128g/model.pt 131 | # GPTQ int4 132 | python -m santacoder_inference bigcode/starcoderbase --wbits 4 --groupsize 128 --load starcoderbase-GPTQ-4bit-128g/model.pt 133 | ``` 134 | 135 | # Acknowledgements 136 | This code is based on [GPTQ](https://github.com/IST-DASLab/gptq) 137 | 138 | Triton GPTQ kernel code is based on [GPTQ-triton](https://github.com/fpgaminer/GPTQ-triton) 139 | -------------------------------------------------------------------------------- /convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from transformers.models.llama.convert_llama_weights_to_hf import write_model, write_tokenizer 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--input_dir", 11 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 12 | ) 13 | parser.add_argument( 14 | "--model_size", 15 | choices=["7B", "13B", "30B", "65B", "tokenizer_only"], 16 | ) 17 | parser.add_argument( 18 | "--output_dir", 19 | help="Location to write HF model and tokenizer", 20 | ) 21 | args = parser.parse_args() 22 | if args.model_size != "tokenizer_only": 23 | write_model( 24 | model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()), 25 | input_base_path=os.path.join(args.input_dir, args.model_size), 26 | model_size=args.model_size, 27 | ) 28 | write_tokenizer( 29 | tokenizer_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()), 30 | input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"), 31 | ) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from datasets import Dataset, load_dataset 8 | from transformers import AutoTokenizer 9 | 10 | 11 | def set_seed(seed): 12 | np.random.seed(seed) 13 | torch.random.manual_seed(seed) 14 | 15 | 16 | def get_wikitext2(nsamples, seed, seqlen, model): 17 | traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") 18 | testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 19 | 20 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 21 | trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") 22 | testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") 23 | 24 | random.seed(seed) 25 | trainloader = [] 26 | for _ in range(nsamples): 27 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 28 | j = i + seqlen 29 | inp = trainenc.input_ids[:, i:j] 30 | tar = inp.clone() 31 | tar[:, :-1] = -100 32 | trainloader.append((inp, tar)) 33 | return trainloader, testenc 34 | 35 | 36 | def get_ptb(nsamples, seed, seqlen, model): 37 | traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") 38 | valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") 39 | 40 | tokenizer =
AutoTokenizer.from_pretrained(model, use_fast=False) 41 | trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") 42 | testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") 43 | 44 | random.seed(seed) 45 | trainloader = [] 46 | for _ in range(nsamples): 47 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 48 | j = i + seqlen 49 | inp = trainenc.input_ids[:, i:j] 50 | tar = inp.clone() 51 | tar[:, :-1] = -100 52 | trainloader.append((inp, tar)) 53 | return trainloader, testenc 54 | 55 | 56 | def get_c4(nsamples, seed, seqlen, model): 57 | traindata = load_dataset( 58 | "allenai/c4", 59 | "allenai--c4", 60 | data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, 61 | split="train", 62 | use_auth_token=False, 63 | ) 64 | valdata = load_dataset( 65 | "allenai/c4", 66 | "allenai--c4", 67 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, 68 | split="validation", 69 | use_auth_token=False, 70 | ) 71 | 72 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 73 | 74 | random.seed(seed) 75 | trainloader = [] 76 | for _ in range(nsamples): 77 | while True: 78 | i = random.randint(0, len(traindata) - 1) 79 | trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") 80 | if trainenc.input_ids.shape[1] >= seqlen + 1: 81 | break 82 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 83 | j = i + seqlen 84 | inp = trainenc.input_ids[:, i:j] 85 | tar = inp.clone() 86 | tar[:, :-1] = -100 87 | trainloader.append((inp, tar)) 88 | 89 | random.seed(0) 90 | valenc = [] 91 | for _ in range(256): 92 | while True: 93 | i = random.randint(0, len(valdata) - 1) 94 | tmp = tokenizer(valdata[i]["text"], return_tensors="pt") 95 | if tmp.input_ids.shape[1] >= seqlen: 96 | break 97 | if tmp.input_ids.shape[1] - seqlen - 1 < 0: 98 | continue 99 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 100 | j = i + seqlen 101 | valenc.append(tmp.input_ids[:, i:j]) 102 | valenc = torch.hstack(valenc) 103 | 104 | class TokenizerWrapper: 105 | def __init__(self, input_ids): 106 | self.input_ids = input_ids 107 | 108 | valenc = TokenizerWrapper(valenc) 109 | 110 | return trainloader, valenc 111 | 112 | 113 | def get_ptb_new(nsamples, seed, seqlen, model): 114 | traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") 115 | testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") 116 | 117 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 118 | trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") 119 | testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") 120 | 121 | random.seed(seed) 122 | trainloader = [] 123 | for _ in range(nsamples): 124 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 125 | j = i + seqlen 126 | inp = trainenc.input_ids[:, i:j] 127 | tar = inp.clone() 128 | tar[:, :-1] = -100 129 | trainloader.append((inp, tar)) 130 | return trainloader, testenc 131 | 132 | 133 | def get_c4_new(nsamples, seed, seqlen, model): 134 | traindata = load_dataset( 135 | "allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train" 136 | ) 137 | valdata = load_dataset( 138 | "allenai/c4", 139 | "allenai--c4", 140 | data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, 141 | split="validation", 142 | ) 143 | 144 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 145 | 146 | random.seed(seed) 147 | trainloader = [] 148 | for _ in 
range(nsamples): 149 | while True: 150 | i = random.randint(0, len(traindata) - 1) 151 | trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") 152 | if trainenc.input_ids.shape[1] >= seqlen: 153 | break 154 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 155 | j = i + seqlen 156 | inp = trainenc.input_ids[:, i:j] 157 | tar = inp.clone() 158 | tar[:, :-1] = -100 159 | trainloader.append((inp, tar)) 160 | 161 | valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") 162 | valenc = valenc.input_ids[:, : (256 * seqlen)] 163 | 164 | class TokenizerWrapper: 165 | def __init__(self, input_ids): 166 | self.input_ids = input_ids 167 | 168 | valenc = TokenizerWrapper(valenc) 169 | 170 | return trainloader, valenc 171 | 172 | 173 | def get_stack(nsamples, seed, seqlen, model): 174 | languages = ["c++", "java", "javascript", "python"][:1] 175 | 176 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 177 | 178 | seed = 0 179 | size = 1000 180 | test_samples = 150 181 | 182 | trainloader = [] 183 | testloader = [] 184 | 185 | for language in languages: 186 | thestack = load_dataset( 187 | "bigcode/the-stack", 188 | split="train", 189 | streaming=True, 190 | data_files=[f"data/{language}/*"], 191 | ) 192 | # print(f"subset {language} loaded") 193 | ds = thestack.shuffle(seed=seed) 194 | 195 | # 10k subset of random samples from ds 196 | small_ds = list(ds.take(size)) 197 | # convert to Datasets 198 | small_ds = Dataset.from_pandas(pd.DataFrame(data=small_ds)) 199 | 200 | for i in range(len(small_ds)): 201 | trainenc = tokenizer(small_ds[i]["content"], return_tensors="pt")["input_ids"] 202 | if trainenc.shape[1] < seqlen + 1: 203 | continue 204 | 205 | i = random.randint(0, trainenc.shape[1] - seqlen - 1) 206 | j = i + seqlen 207 | inp = trainenc[:, i:j] 208 | 209 | if len(trainloader) < nsamples: 210 | tar = inp.clone() 211 | tar[:, :-1] = -100 212 | trainloader.append((inp, tar)) 213 | elif len(testloader) < test_samples: 214 | testloader.append(inp) 215 | else: 216 | break 217 | 218 | return trainloader, torch.cat(testloader, dim=-1) 219 | 220 | 221 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=""): 222 | if "wikitext2" in name: 223 | return get_wikitext2(nsamples, seed, seqlen, model) 224 | elif "ptb" in name: 225 | if "new" in name: 226 | return get_ptb_new(nsamples, seed, seqlen, model) 227 | return get_ptb(nsamples, seed, seqlen, model) 228 | elif "c4" in name: 229 | if "new" in name: 230 | return get_c4_new(nsamples, seed, seqlen, model) 231 | return get_c4(nsamples, seed, seqlen, model) 232 | elif "stack" in name: 233 | return get_stack(nsamples, seed, seqlen, model) 234 | -------------------------------------------------------------------------------- /fused_attn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.cuda.amp import custom_bwd, custom_fwd 5 | from torch.nn import functional as F 6 | from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb 7 | 8 | from quant import * 9 | 10 | 11 | class QuantLlamaAttention(nn.Module): 12 | """Multi-headed attention from 'Attention Is All You Need' paper""" 13 | 14 | def __init__( 15 | self, 16 | hidden_size, 17 | num_heads, 18 | qkv_proj, 19 | o_proj, 20 | rotary_emb, 21 | ): 22 | super().__init__() 23 | self.hidden_size = hidden_size 24 | self.num_heads = num_heads 25 | self.head_dim = hidden_size // num_heads 26 | 27 | if 
(self.head_dim * num_heads) != self.hidden_size: 28 | raise ValueError( 29 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 30 | f" and `num_heads`: {num_heads})." 31 | ) 32 | self.qkv_proj = qkv_proj 33 | self.o_proj = o_proj 34 | self.rotary_emb = rotary_emb 35 | 36 | def _shape(self, tensor, seq_len, bsz): 37 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 38 | 39 | def forward( 40 | self, 41 | hidden_states, 42 | past_key_value=None, 43 | attention_mask=None, 44 | position_ids=None, 45 | output_attentions=False, 46 | use_cache=False, 47 | ): 48 | """Input shape: Batch x Time x Channel""" 49 | 50 | bsz, q_len, _ = hidden_states.size() 51 | 52 | qkv_states = self.qkv_proj(hidden_states) 53 | query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2) 54 | 55 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 56 | key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 57 | value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 58 | 59 | kv_seq_len = key_states.shape[-2] 60 | if past_key_value is not None: 61 | kv_seq_len += past_key_value[0].shape[-2] 62 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 63 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 64 | # [bsz, nh, t, hd] 65 | 66 | is_causal = past_key_value is None 67 | if past_key_value is not None: 68 | # reuse k, v, self_attention 69 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 70 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 71 | 72 | if use_cache: 73 | # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor 74 | # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. 75 | query_states = query_states.contiguous() 76 | key_states = key_states.contiguous() 77 | value_states = value_states.contiguous() 78 | 79 | past_key_value = (key_states, value_states) if use_cache else None 80 | 81 | with torch.backends.cuda.sdp_kernel(enable_math=False): 82 | attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) 83 | 84 | attn_output = attn_output.transpose(1, 2) 85 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 86 | 87 | attn_output = self.o_proj(attn_output) 88 | 89 | if not output_attentions: 90 | attn_weights = None 91 | 92 | return attn_output, attn_weights, past_key_value 93 | 94 | 95 | def make_quant_attn(model): 96 | """ 97 | Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. 
98 | """ 99 | for name, m in model.named_modules(): 100 | if not isinstance(m, LlamaAttention): 101 | continue 102 | 103 | q_proj = m.q_proj 104 | k_proj = m.k_proj 105 | v_proj = m.v_proj 106 | 107 | qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) 108 | qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) 109 | scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) 110 | g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) 111 | bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None 112 | 113 | qkv_layer = QuantLinear( 114 | q_proj.bits, 115 | q_proj.groupsize, 116 | q_proj.infeatures, 117 | q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, 118 | True if q_proj.bias is not None else False, 119 | ) 120 | qkv_layer.qweight = qweights 121 | qkv_layer.qzeros = qzeros 122 | qkv_layer.scales = scales 123 | qkv_layer.g_idx = g_idx 124 | qkv_layer.bias = bias 125 | 126 | attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb) 127 | 128 | if "." in name: 129 | parent_name = name.rsplit(".", 1)[0] 130 | child_name = name[len(parent_name) + 1 :] 131 | parent = model.get_submodule(parent_name) 132 | else: 133 | parent_name = "" 134 | parent = model 135 | child_name = name 136 | 137 | # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") 138 | 139 | setattr(parent, child_name, attn) 140 | -------------------------------------------------------------------------------- /gptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from quant import * 9 | 10 | DEBUG = False 11 | 12 | torch.backends.cuda.matmul.allow_tf32 = False 13 | torch.backends.cudnn.allow_tf32 = False 14 | 15 | 16 | class GPTQ: 17 | def __init__(self, layer): 18 | self.layer = layer 19 | self.dev = self.layer.weight.device 20 | W = layer.weight.data.clone() 21 | if isinstance(self.layer, nn.Conv2d): 22 | W = W.flatten(1) 23 | if isinstance(self.layer, transformers.Conv1D): 24 | W = W.t() 25 | self.rows = W.shape[0] 26 | self.columns = W.shape[1] 27 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 28 | self.nsamples = 0 29 | 30 | def add_batch(self, inp, out): 31 | if DEBUG: 32 | self.inp1 = inp 33 | self.out1 = out 34 | if len(inp.shape) == 2: 35 | inp = inp.unsqueeze(0) 36 | tmp = inp.shape[0] 37 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 38 | if len(inp.shape) == 3: 39 | inp = inp.reshape((-1, inp.shape[-1])) 40 | inp = inp.t() 41 | if isinstance(self.layer, nn.Conv2d): 42 | unfold = nn.Unfold( 43 | self.layer.kernel_size, 44 | dilation=self.layer.dilation, 45 | padding=self.layer.padding, 46 | stride=self.layer.stride, 47 | ) 48 | inp = unfold(inp) 49 | inp = inp.permute([1, 0, 2]) 50 | inp = inp.flatten(1) 51 | self.H *= self.nsamples / (self.nsamples + tmp) 52 | self.nsamples += tmp 53 | # inp = inp.float() 54 | inp = math.sqrt(2 / self.nsamples) * inp.float() 55 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 56 | self.H += inp.matmul(inp.t()) 57 | 58 | def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False): 59 | W = self.layer.weight.data.clone() 60 | if isinstance(self.layer, nn.Conv2d): 61 | W = W.flatten(1) 62 | if isinstance(self.layer, transformers.Conv1D): 63 | W = W.t() 
64 | W = W.float() 65 | 66 | tick = time.time() 67 | 68 | if not self.quantizer.ready(): 69 | self.quantizer.find_params(W, weight=True) 70 | 71 | H = self.H 72 | del self.H 73 | dead = torch.diag(H) == 0 74 | H[dead, dead] = 1 75 | W[:, dead] = 0 76 | 77 | if actorder: 78 | perm = torch.argsort(torch.diag(H), descending=True) 79 | W = W[:, perm] 80 | H = H[perm][:, perm] 81 | 82 | Losses = torch.zeros_like(W) 83 | Q = torch.zeros_like(W) 84 | 85 | damp = percdamp * torch.mean(torch.diag(H)) 86 | diag = torch.arange(self.columns, device=self.dev) 87 | H[diag, diag] += damp 88 | H = torch.linalg.cholesky(H) 89 | H = torch.cholesky_inverse(H) 90 | H = torch.linalg.cholesky(H, upper=True) 91 | Hinv = H 92 | 93 | g_idx = [] 94 | scale = [] 95 | zero = [] 96 | now_idx = 1 97 | 98 | for i1 in range(0, self.columns, blocksize): 99 | i2 = min(i1 + blocksize, self.columns) 100 | count = i2 - i1 101 | 102 | W1 = W[:, i1:i2].clone() 103 | Q1 = torch.zeros_like(W1) 104 | Err1 = torch.zeros_like(W1) 105 | Losses1 = torch.zeros_like(W1) 106 | Hinv1 = Hinv[i1:i2, i1:i2] 107 | 108 | for i in range(count): 109 | w = W1[:, i] 110 | d = Hinv1[i, i] 111 | 112 | if groupsize != -1: 113 | if (i1 + i) % groupsize == 0: 114 | self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True) 115 | 116 | if ((i1 + i) // groupsize) - now_idx == -1: 117 | scale.append(self.quantizer.scale) 118 | zero.append(self.quantizer.zero) 119 | now_idx += 1 120 | 121 | q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten() 122 | Q1[:, i] = q 123 | Losses1[:, i] = (w - q) ** 2 / d**2 124 | 125 | err1 = (w - q) / d 126 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 127 | Err1[:, i] = err1 128 | 129 | Q[:, i1:i2] = Q1 130 | Losses[:, i1:i2] = Losses1 / 2 131 | 132 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 133 | 134 | if DEBUG: 135 | self.layer.weight.data[:, :i2] = Q[:, :i2] 136 | self.layer.weight.data[:, i2:] = W[:, i2:] 137 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 138 | print(torch.sum(Losses)) 139 | 140 | torch.cuda.synchronize() 141 | print("time %.2f" % (time.time() - tick)) 142 | print("error", torch.sum(Losses).item()) 143 | 144 | groupsize = groupsize if groupsize != -1 else self.columns 145 | g_idx = [i // groupsize for i in range(self.columns)] 146 | g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) 147 | if actorder: 148 | invperm = torch.argsort(perm) 149 | Q = Q[:, invperm] 150 | g_idx = g_idx[invperm] 151 | 152 | if isinstance(self.layer, transformers.Conv1D): 153 | Q = Q.t() 154 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 155 | if DEBUG: 156 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 157 | 158 | if scale == []: 159 | scale.append(self.quantizer.scale) 160 | zero.append(self.quantizer.zero) 161 | scale = torch.cat(scale, dim=1) 162 | zero = torch.cat(zero, dim=1) 163 | return scale, zero, g_idx 164 | 165 | def free(self): 166 | if DEBUG: 167 | self.inp1 = None 168 | self.out1 = None 169 | self.H = None 170 | self.Losses = None 171 | self.Trace = None 172 | torch.cuda.empty_cache() 173 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from fused_attn import * 7 | from gptq import * 8 | from modelutils import * 9 | from quant import * 10 | 11 
| 12 | def get_llama(model): 13 | import torch 14 | 15 | def skip(*args, **kwargs): 16 | pass 17 | 18 | torch.nn.init.kaiming_uniform_ = skip 19 | torch.nn.init.uniform_ = skip 20 | torch.nn.init.normal_ = skip 21 | from transformers import LlamaForCausalLM 22 | 23 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype="auto") 24 | model.seqlen = 2048 25 | return model 26 | 27 | 28 | @torch.no_grad() 29 | def llama_sequential(model, dataloader, dev): 30 | print("Starting ...") 31 | 32 | use_cache = model.config.use_cache 33 | model.config.use_cache = False 34 | layers = model.model.layers 35 | 36 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 37 | model.model.norm = model.model.norm.to(dev) 38 | layers[0] = layers[0].to(dev) 39 | 40 | dtype = next(iter(model.parameters())).dtype 41 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 42 | cache = {"i": 0, "attention_mask": None} 43 | 44 | class Catcher(nn.Module): 45 | def __init__(self, module): 46 | super().__init__() 47 | self.module = module 48 | 49 | def forward(self, inp, **kwargs): 50 | inps[cache["i"]] = inp 51 | cache["i"] += 1 52 | cache["attention_mask"] = kwargs["attention_mask"] 53 | cache["position_ids"] = kwargs["position_ids"] 54 | raise ValueError 55 | 56 | layers[0] = Catcher(layers[0]) 57 | for batch in dataloader: 58 | try: 59 | model(batch[0].to(dev)) 60 | except ValueError: 61 | pass 62 | layers[0] = layers[0].module 63 | 64 | layers[0] = layers[0].cpu() 65 | model.model.embed_tokens = model.model.embed_tokens.cpu() 66 | model.model.norm = model.model.norm.cpu() 67 | torch.cuda.empty_cache() 68 | 69 | outs = torch.zeros_like(inps) 70 | attention_mask = cache["attention_mask"] 71 | position_ids = cache["position_ids"] 72 | 73 | print("Ready.") 74 | 75 | quantizers = {} 76 | for i in range(len(layers)): 77 | layer = layers[i].to(dev) 78 | full = find_layers(layer) 79 | if args.true_sequential: 80 | sequential = [ 81 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 82 | ["self_attn.o_proj"], 83 | ["mlp.up_proj", "mlp.gate_proj"], 84 | ["mlp.down_proj"], 85 | ] 86 | else: 87 | sequential = [list(full.keys())] 88 | 89 | for names in sequential: 90 | subset = {n: full[n] for n in names} 91 | gptq = {} 92 | for name in subset: 93 | gptq[name] = GPTQ(subset[name]) 94 | gptq[name].quantizer = Quantizer() 95 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 96 | 97 | def add_batch(name): 98 | def tmp(_, inp, out): 99 | gptq[name].add_batch(inp[0].data, out.data) 100 | 101 | return tmp 102 | 103 | handles = [] 104 | for name in subset: 105 | handles.append(subset[name].register_forward_hook(add_batch(name))) 106 | for j in range(args.nsamples): 107 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 108 | for h in handles: 109 | h.remove() 110 | 111 | for name in subset: 112 | print(f"Quantizing {name} in layer {i+1}/{len(layers)}...") 113 | scale, zero, g_idx = gptq[name].fasterquant( 114 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order 115 | ) 116 | quantizers["model.layers.%d.%s" % (i, name)] = ( 117 | gptq[name].quantizer.cpu(), 118 | scale.cpu(), 119 | zero.cpu(), 120 | g_idx.cpu(), 121 | ) 122 | gptq[name].free() 123 | 124 | for j in range(args.nsamples): 125 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 126 | 127 | layers[i] = layer.cpu() 128 | del layer 129 | del gptq 
130 | torch.cuda.empty_cache() 131 | 132 | inps, outs = outs, inps 133 | 134 | model.config.use_cache = use_cache 135 | 136 | return quantizers 137 | 138 | 139 | @torch.no_grad() 140 | def llama_eval(model, testenc, dev): 141 | print("Evaluating ...") 142 | 143 | testenc = testenc.input_ids 144 | nsamples = testenc.numel() // model.seqlen 145 | 146 | use_cache = model.config.use_cache 147 | model.config.use_cache = False 148 | layers = model.model.layers 149 | 150 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 151 | layers[0] = layers[0].to(dev) 152 | 153 | dtype = next(iter(model.parameters())).dtype 154 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 155 | cache = {"i": 0, "attention_mask": None} 156 | 157 | class Catcher(nn.Module): 158 | def __init__(self, module): 159 | super().__init__() 160 | self.module = module 161 | 162 | def forward(self, inp, **kwargs): 163 | inps[cache["i"]] = inp 164 | cache["i"] += 1 165 | cache["attention_mask"] = kwargs["attention_mask"] 166 | cache["position_ids"] = kwargs["position_ids"] 167 | raise ValueError 168 | 169 | layers[0] = Catcher(layers[0]) 170 | for i in range(nsamples): 171 | batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev) 172 | try: 173 | model(batch) 174 | except ValueError: 175 | pass 176 | layers[0] = layers[0].module 177 | 178 | layers[0] = layers[0].cpu() 179 | model.model.embed_tokens = model.model.embed_tokens.cpu() 180 | torch.cuda.empty_cache() 181 | 182 | outs = torch.zeros_like(inps) 183 | attention_mask = cache["attention_mask"] 184 | position_ids = cache["position_ids"] 185 | 186 | for i in range(len(layers)): 187 | print(i) 188 | layer = layers[i].to(dev) 189 | 190 | if args.nearest: 191 | subset = find_layers(layer) 192 | for name in subset: 193 | quantizer = Quantizer() 194 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 195 | W = subset[name].weight.data 196 | quantizer.find_params(W, weight=True) 197 | subset[name].weight.data = quantize(W, quantizer.scale, quantizer.zero, quantizer.maxq).to( 198 | next(iter(layer.parameters())).dtype 199 | ) 200 | 201 | for j in range(nsamples): 202 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 203 | layers[i] = layer.cpu() 204 | del layer 205 | torch.cuda.empty_cache() 206 | inps, outs = outs, inps 207 | 208 | if model.model.norm is not None: 209 | model.model.norm = model.model.norm.to(dev) 210 | model.lm_head = model.lm_head.to(dev) 211 | testenc = testenc.to(dev) 212 | nlls = [] 213 | for i in range(nsamples): 214 | hidden_states = inps[i].unsqueeze(0) 215 | if model.model.norm is not None: 216 | hidden_states = model.model.norm(hidden_states) 217 | lm_logits = model.lm_head(hidden_states) 218 | shift_logits = lm_logits[:, :-1, :].contiguous() 219 | shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] 220 | loss_fct = nn.CrossEntropyLoss() 221 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 222 | neg_log_likelihood = loss.float() * model.seqlen 223 | nlls.append(neg_log_likelihood) 224 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 225 | print(ppl.item()) 226 | 227 | model.config.use_cache = use_cache 228 | 229 | 230 | # TODO: perform packing on GPU 231 | def llama_pack(model, quantizers, wbits, groupsize): 232 | layers = find_layers(model) 233 | layers = {n: layers[n] for n in quantizers} 234 | make_quant(model, quantizers, 
wbits, groupsize) 235 | qlayers = find_layers(model, [QuantLinear]) 236 | print("Packing ...") 237 | for name in qlayers: 238 | print(name) 239 | quantizers[name], scale, zero, g_idx = quantizers[name] 240 | qlayers[name].pack(layers[name], scale, zero, g_idx) 241 | print("Done.") 242 | return model 243 | 244 | 245 | def load_quant(model, checkpoint, wbits, groupsize=-1): 246 | from transformers import LlamaConfig, LlamaForCausalLM 247 | 248 | config = LlamaConfig.from_pretrained(model) 249 | 250 | def noop(*args, **kwargs): 251 | pass 252 | 253 | torch.nn.init.kaiming_uniform_ = noop 254 | torch.nn.init.uniform_ = noop 255 | torch.nn.init.normal_ = noop 256 | 257 | torch.set_default_dtype(torch.half) 258 | transformers.modeling_utils._init_weights = False 259 | torch.set_default_dtype(torch.half) 260 | model = LlamaForCausalLM(config) 261 | torch.set_default_dtype(torch.float) 262 | model = model.eval() 263 | layers = find_layers(model) 264 | for name in ["lm_head"]: 265 | if name in layers: 266 | del layers[name] 267 | make_quant(model, layers, wbits, groupsize) 268 | 269 | del layers 270 | 271 | print("Loading model ...") 272 | if checkpoint.endswith(".safetensors"): 273 | from safetensors.torch import load_file as safe_load 274 | 275 | model.load_state_dict(safe_load(checkpoint), strict=False) 276 | else: 277 | model.load_state_dict(torch.load(checkpoint), strict=False) 278 | 279 | make_quant_attn(model) 280 | model.seqlen = 2048 281 | print("Done.") 282 | 283 | return model 284 | 285 | 286 | def llama_multigpu(model, gpus): 287 | model.model.embed_tokens = model.model.embed_tokens.to(gpus[0]) 288 | if hasattr(model.model, "norm") and model.model.norm: 289 | model.model.norm = model.model.norm.to(gpus[-1]) 290 | import copy 291 | 292 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) 293 | 294 | cache = {"mask": None} 295 | 296 | class MoveModule(nn.Module): 297 | def __init__(self, module): 298 | super().__init__() 299 | self.module = module 300 | self.dev = next(iter(self.module.parameters())).device 301 | 302 | def forward(self, *inp, **kwargs): 303 | inp = list(inp) 304 | if inp[0].device != self.dev: 305 | inp[0] = inp[0].to(self.dev) 306 | if cache["mask"] is None or cache["mask"].device != self.dev: 307 | cache["mask"] = kwargs["attention_mask"].to(self.dev) 308 | kwargs["attention_mask"] = cache["mask"] 309 | tmp = self.module(*inp, **kwargs) 310 | return tmp 311 | 312 | layers = model.model.layers 313 | pergpu = math.ceil(len(layers) / len(gpus)) 314 | for i in range(len(layers)): 315 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) 316 | 317 | model.gpus = gpus 318 | 319 | 320 | def benchmark(model, input_ids, check=False): 321 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, "gpus") else DEV) 322 | torch.cuda.synchronize() 323 | 324 | cache = {"past": None} 325 | 326 | def clear_past(i): 327 | def tmp(layer, inp, out): 328 | if cache["past"]: 329 | cache["past"][i] = None 330 | 331 | return tmp 332 | 333 | for i, layer in enumerate(model.model.layers): 334 | layer.register_forward_hook(clear_past(i)) 335 | 336 | print("Benchmarking ...") 337 | 338 | if check: 339 | loss = nn.CrossEntropyLoss() 340 | tot = 0.0 341 | 342 | def sync(): 343 | if hasattr(model, "gpus"): 344 | for gpu in model.gpus: 345 | torch.cuda.synchronize(gpu) 346 | else: 347 | torch.cuda.synchronize() 348 | 349 | max_memory = 0 350 | with torch.no_grad(): 351 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 352 | times = [] 353 | for i in 
range(input_ids.numel()): 354 | tick = time.time() 355 | out = model( 356 | input_ids[:, i : i + 1], 357 | past_key_values=cache["past"], 358 | attention_mask=attention_mask[:, : (i + 1)].reshape((1, -1)), 359 | ) 360 | sync() 361 | times.append(time.time() - tick) 362 | print(i, times[-1]) 363 | max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024) 364 | if check and i != input_ids.numel() - 1: 365 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 366 | cache["past"] = list(out.past_key_values) 367 | del out 368 | sync() 369 | import numpy as np 370 | 371 | print("Median:", np.median(times)) 372 | if check: 373 | print("PPL:", torch.exp(tot / (input_ids.numel() - 1)).item()) 374 | print("max memory(MiB):", max_memory) 375 | 376 | 377 | if __name__ == "__main__": 378 | import argparse 379 | 380 | from datautils import * 381 | 382 | parser = argparse.ArgumentParser() 383 | 384 | parser.add_argument("model", type=str, help="llama model to load") 385 | parser.add_argument( 386 | "dataset", type=str, choices=["wikitext2", "ptb", "c4"], help="Where to extract calibration data from." 387 | ) 388 | parser.add_argument("--seed", type=int, default=0, help="Seed for sampling the calibration data.") 389 | parser.add_argument("--nsamples", type=int, default=128, help="Number of calibration data samples.") 390 | parser.add_argument( 391 | "--percdamp", type=float, default=0.01, help="Percent of the average Hessian diagonal to use for dampening." 392 | ) 393 | parser.add_argument("--nearest", action="store_true", help="Whether to run the RTN baseline.") 394 | parser.add_argument( 395 | "--wbits", 396 | type=int, 397 | default=16, 398 | choices=[2, 3, 4, 8, 16], 399 | help="#bits to use for quantization; use 16 for evaluating base model.", 400 | ) 401 | parser.add_argument("--trits", action="store_true", help="Whether to use trits for quantization.") 402 | parser.add_argument( 403 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 404 | ) 405 | parser.add_argument("--eval", action="store_true", help="evaluate quantized model.") 406 | parser.add_argument("--save", type=str, default="", help="Save quantized checkpoint under this name.") 407 | parser.add_argument( 408 | "--save_safetensors", type=str, default="", help="Save quantized `.safetensors` checkpoint under this name." 409 | ) 410 | parser.add_argument("--load", type=str, default="", help="Load quantized model.") 411 | parser.add_argument("--benchmark", type=int, default=0, help="Number of tokens to use for benchmarking.") 412 | parser.add_argument( 413 | "--check", action="store_true", help="Whether to compute perplexity during benchmarking for verification." 
414 | ) 415 | parser.add_argument("--sym", action="store_true", help="Whether to perform symmetric quantization.") 416 | parser.add_argument( 417 | "--act-order", action="store_true", help="Whether to apply the activation order GPTQ heuristic" 418 | ) 419 | parser.add_argument("--true-sequential", action="store_true", help="Whether to run in true sequential model.") 420 | parser.add_argument("--new-eval", action="store_true", help="Whether to use the new PTB and C4 eval") 421 | 422 | args = parser.parse_args() 423 | 424 | if type(args.load) is not str: 425 | args.load = args.load.as_posix() 426 | 427 | if args.load: 428 | model = load_quant(args.model, args.load, args.wbits, args.groupsize) 429 | else: 430 | model = get_llama(args.model) 431 | model.eval() 432 | 433 | dataloader, testloader = get_loaders( 434 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 435 | ) 436 | 437 | if not args.load and args.wbits < 16 and not args.nearest: 438 | tick = time.time() 439 | quantizers = llama_sequential(model, dataloader, DEV) 440 | print(time.time() - tick) 441 | 442 | if args.benchmark: 443 | gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())] 444 | if len(gpus) > 1: 445 | llama_multigpu(model, gpus) 446 | else: 447 | model = model.to(DEV) 448 | if args.benchmark: 449 | input_ids = next(iter(dataloader))[0][:, : args.benchmark] 450 | benchmark(model, input_ids, check=args.check) 451 | 452 | if args.load: 453 | exit() 454 | 455 | if args.eval: 456 | datasets = ["wikitext2", "ptb", "c4"] 457 | if args.new_eval: 458 | datasets = ["wikitext2", "ptb-new", "c4-new"] 459 | for dataset in datasets: 460 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) 461 | print(dataset) 462 | llama_eval(model, testloader, DEV) 463 | 464 | if args.save: 465 | llama_pack(model, quantizers, args.wbits, args.groupsize) 466 | torch.save(model.state_dict(), args.save) 467 | 468 | if args.save_safetensors: 469 | llama_pack(model, quantizers, args.wbits, args.groupsize) 470 | from safetensors.torch import save_file as safe_save 471 | 472 | state_dict = model.state_dict() 473 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} 474 | safe_save(state_dict, args.save_safetensors) 475 | -------------------------------------------------------------------------------- /llama_inference.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import AutoTokenizer 6 | 7 | from gptq import * 8 | from modelutils import * 9 | from quant import * 10 | 11 | DEV = torch.device("cuda:0") 12 | 13 | 14 | def get_llama(model): 15 | import torch 16 | 17 | def skip(*args, **kwargs): 18 | pass 19 | 20 | torch.nn.init.kaiming_uniform_ = skip 21 | torch.nn.init.uniform_ = skip 22 | torch.nn.init.normal_ = skip 23 | from transformers import LlamaForCausalLM 24 | 25 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype="auto") 26 | model.seqlen = 2048 27 | return model 28 | 29 | 30 | def load_quant(model, checkpoint, wbits, groupsize, device): 31 | from transformers import LlamaConfig, LlamaForCausalLM 32 | 33 | config = LlamaConfig.from_pretrained(model) 34 | 35 | def noop(*args, **kwargs): 36 | pass 37 | 38 | torch.nn.init.kaiming_uniform_ = noop 39 | torch.nn.init.uniform_ = noop 40 | torch.nn.init.normal_ = noop 41 | 42 | torch.set_default_dtype(torch.half) 43 | 
transformers.modeling_utils._init_weights = False 44 | torch.set_default_dtype(torch.half) 45 | model = LlamaForCausalLM(config) 46 | torch.set_default_dtype(torch.float) 47 | model = model.eval() 48 | layers = find_layers(model) 49 | for name in ["lm_head"]: 50 | if name in layers: 51 | del layers[name] 52 | make_quant(model, layers, wbits, groupsize) 53 | 54 | print("Loading model ...") 55 | if checkpoint.endswith(".safetensors"): 56 | from safetensors.torch import load_file as safe_load 57 | 58 | if device == -1: 59 | device = "cpu" 60 | model.load_state_dict(safe_load(checkpoint, device)) 61 | else: 62 | model.load_state_dict(torch.load(checkpoint)) 63 | model.seqlen = 2048 64 | print("Done.") 65 | 66 | return model 67 | 68 | 69 | if __name__ == "__main__": 70 | import argparse 71 | 72 | from datautils import * 73 | 74 | parser = argparse.ArgumentParser() 75 | 76 | parser.add_argument("model", type=str, help="llama model to load") 77 | parser.add_argument( 78 | "--wbits", 79 | type=int, 80 | default=16, 81 | choices=[2, 3, 4, 8, 16], 82 | help="#bits to use for quantization; use 16 for evaluating base model.", 83 | ) 84 | parser.add_argument( 85 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 86 | ) 87 | parser.add_argument("--load", type=str, default="", help="Load quantized model.") 88 | 89 | parser.add_argument("--text", type=str, help="input text") 90 | 91 | parser.add_argument( 92 | "--min_length", type=int, default=10, help="The minimum length of the sequence to be generated." 93 | ) 94 | 95 | parser.add_argument( 96 | "--max_length", type=int, default=50, help="The maximum length of the sequence to be generated." 97 | ) 98 | 99 | parser.add_argument( 100 | "--top_p", 101 | type=float, 102 | default=0.95, 103 | help="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.", 104 | ) 105 | 106 | parser.add_argument( 107 | "--temperature", type=float, default=0.8, help="The value used to module the next token probabilities." 108 | ) 109 | 110 | parser.add_argument( 111 | "--device", 112 | type=int, 113 | default=-1, 114 | help='The device used to load the model when using safetensors. Default device is "cpu" or specify, 0,1,2,3,... 
for GPU device.', 115 | ) 116 | 117 | args = parser.parse_args() 118 | 119 | if type(args.load) is not str: 120 | args.load = args.load.as_posix() 121 | 122 | if args.load: 123 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.device) 124 | else: 125 | model = get_llama(args.model) 126 | model.eval() 127 | 128 | model.to(DEV) 129 | tokenizer = AutoTokenizer.from_pretrained(args.model) 130 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV) 131 | 132 | with torch.no_grad(): 133 | generated_ids = model.generate( 134 | input_ids, 135 | do_sample=True, 136 | min_length=args.min_length, 137 | max_length=args.max_length, 138 | top_p=args.top_p, 139 | temperature=args.temperature, 140 | ) 141 | print(tokenizer.decode([el.item() for el in generated_ids[0]])) 142 | -------------------------------------------------------------------------------- /llama_inference_dmapauto.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import accelerate 4 | import torch 5 | import torch.nn as nn 6 | from transformers import AutoTokenizer 7 | 8 | import share_tensors_across_processes 9 | from gptq import * 10 | from modelutils import * 11 | from quant import * 12 | 13 | 14 | def get_llama(model): 15 | import torch 16 | 17 | def skip(*args, **kwargs): 18 | pass 19 | 20 | torch.nn.init.kaiming_uniform_ = skip 21 | torch.nn.init.uniform_ = skip 22 | torch.nn.init.normal_ = skip 23 | from transformers import LlamaForCausalLM 24 | 25 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map="auto") 26 | model.seqlen = 2048 27 | return model 28 | 29 | 30 | def load_quant(model, checkpoint, wbits, groupsize, device_map): 31 | from transformers import LlamaConfig, LlamaForCausalLM 32 | 33 | config = LlamaConfig.from_pretrained(model) 34 | 35 | def noop(*args, **kwargs): 36 | pass 37 | 38 | torch.nn.init.kaiming_uniform_ = noop 39 | torch.nn.init.uniform_ = noop 40 | torch.nn.init.normal_ = noop 41 | 42 | torch.set_default_dtype(torch.half) 43 | transformers.modeling_utils._init_weights = False 44 | torch.set_default_dtype(torch.half) 45 | with accelerate.init_empty_weights(): 46 | model = LlamaForCausalLM(config) 47 | torch.set_default_dtype(torch.float) 48 | model = model.eval() 49 | layers = find_layers(model) 50 | for name in ["lm_head"]: 51 | if name in layers: 52 | del layers[name] 53 | make_quant(model, layers, wbits, groupsize) 54 | 55 | print("Loading model ...") 56 | model = accelerate.load_checkpoint_and_dispatch( 57 | model, checkpoint, device_map=device_map, no_split_module_classes=["LlamaDecoderLayer"] 58 | ) 59 | model.seqlen = 2048 60 | print("Done.") 61 | 62 | return model 63 | 64 | 65 | if __name__ == "__main__": 66 | import argparse 67 | 68 | from datautils import * 69 | 70 | parser = argparse.ArgumentParser() 71 | 72 | parser.add_argument("model", type=str, help="llama model to load") 73 | parser.add_argument( 74 | "--wbits", 75 | type=int, 76 | default=16, 77 | choices=[2, 3, 4, 8, 16], 78 | help="#bits to use for quantization; use 16 for evaluating base model.", 79 | ) 80 | parser.add_argument( 81 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 
82 | ) 83 | parser.add_argument("--load", type=str, default="", help="Load quantized model.") 84 | 85 | parser.add_argument("--text", type=str, help="input text") 86 | 87 | parser.add_argument( 88 | "--min_length", type=int, default=10, help="The minimum length of the sequence to be generated." 89 | ) 90 | 91 | parser.add_argument( 92 | "--max_length", type=int, default=50, help="The maximum length of the sequence to be generated." 93 | ) 94 | 95 | parser.add_argument( 96 | "--top_p", 97 | type=float, 98 | default=0.95, 99 | help="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.", 100 | ) 101 | 102 | parser.add_argument( 103 | "--temperature", type=float, default=0.8, help="The value used to module the next token probabilities." 104 | ) 105 | 106 | parser.add_argument( 107 | "--do_sample", action="store_true", help="Perform multinomial sampling (slow) to produce varied output." 108 | ) 109 | 110 | parser.add_argument( 111 | "--enable_eos_token", 112 | action="store_true", 113 | help="Check for the completion token every forward pass: https://github.com/huggingface/transformers/pull/22875", 114 | ) 115 | 116 | parser.add_argument( 117 | "--device_map", 118 | type=str, 119 | default="auto", 120 | help='The device_map used to load the model when using accelerate. Default is "auto".', 121 | ) 122 | 123 | parser.add_argument( 124 | "--keep_alive", action="store_true", help="Keep the process alive so others can share the loaded model." 125 | ) 126 | 127 | args = parser.parse_args() 128 | 129 | if type(args.load) is not str: 130 | args.load = args.load.as_posix() 131 | 132 | if args.load: 133 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.device_map) 134 | else: 135 | model = get_llama(args.model) 136 | model.eval() 137 | 138 | if args.text is not None: 139 | DEV = next(model.parameters()).device 140 | tokenizer = AutoTokenizer.from_pretrained(args.model) 141 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV) 142 | 143 | if not args.enable_eos_token: 144 | model.config.eos_token_id = None 145 | with torch.no_grad(): 146 | generated_ids = model.generate( 147 | input_ids, 148 | do_sample=args.do_sample, 149 | min_length=args.min_length, 150 | max_length=args.max_length, 151 | top_p=args.top_p, 152 | temperature=args.temperature, 153 | ) 154 | print(tokenizer.decode([el.item() for el in generated_ids[0]])) 155 | 156 | if args.keep_alive: 157 | print("Keeping process alive to reference model memory.") 158 | print("Further processes should launch faster.") 159 | while True: 160 | time.sleep(60) 161 | -------------------------------------------------------------------------------- /llama_inference_offload.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import AutoTokenizer 6 | 7 | from gptq import * 8 | from modelutils import * 9 | from quant import * 10 | 11 | DEV = torch.device("cuda:0") 12 | import copy 13 | import time 14 | from typing import List, Optional, Tuple, Union 15 | 16 | from transformers.modeling_outputs import BaseModelOutputWithPast 17 | from transformers.models.llama.modeling_llama import LlamaConfig, LlamaModel 18 | 19 | 20 | class Offload_LlamaModel(LlamaModel): 21 | def __init__(self, config: LlamaConfig): 22 | super().__init__(config) 23 | 24 | def forward( 25 | self, 26 | input_ids: torch.LongTensor = None, 27 | 
attention_mask: Optional[torch.Tensor] = None, 28 | position_ids: Optional[torch.LongTensor] = None, 29 | past_key_values: Optional[List[torch.FloatTensor]] = None, 30 | inputs_embeds: Optional[torch.FloatTensor] = None, 31 | use_cache: Optional[bool] = None, 32 | output_attentions: Optional[bool] = None, 33 | output_hidden_states: Optional[bool] = None, 34 | return_dict: Optional[bool] = None, 35 | ) -> Union[Tuple, BaseModelOutputWithPast]: 36 | r""" 37 | Args: 38 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 39 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 40 | provide it. 41 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 42 | [`PreTrainedTokenizer.__call__`] for details. 43 | [What are input IDs?](../glossary#input-ids) 44 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 45 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 46 | - 1 for tokens that are **not masked**, 47 | - 0 for tokens that are **masked**. 48 | [What are attention masks?](../glossary#attention-mask) 49 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 50 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range 51 | `[0, config.n_positions - 1]`. 52 | [What are position IDs?](../glossary#position-ids) 53 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 54 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 55 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 56 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 57 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 58 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 59 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 60 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 61 | use_cache (`bool`, *optional*): 62 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 63 | (see `past_key_values`). 64 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 65 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 66 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 67 | than the model's internal embedding lookup matrix. 68 | output_attentions (`bool`, *optional*): 69 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 70 | returned tensors for more detail. 71 | output_hidden_states (`bool`, *optional*): 72 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 73 | for more detail. 74 | return_dict (`bool`, *optional*): 75 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
76 | """ 77 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 78 | output_hidden_states = ( 79 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 80 | ) 81 | use_cache = use_cache if use_cache is not None else self.config.use_cache 82 | 83 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 84 | 85 | # retrieve input_ids and inputs_embeds 86 | if input_ids is not None and inputs_embeds is not None: 87 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 88 | elif input_ids is not None: 89 | batch_size, seq_length = input_ids.shape 90 | elif inputs_embeds is not None: 91 | batch_size, seq_length, _ = inputs_embeds.shape 92 | else: 93 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 94 | seq_length_with_past = seq_length 95 | past_key_values_length = 0 96 | if past_key_values is not None: 97 | past_key_values_length = past_key_values[0][0].shape[2] 98 | seq_length_with_past = seq_length_with_past + past_key_values_length 99 | 100 | if position_ids is None: 101 | device = input_ids.device if input_ids is not None else inputs_embeds.device 102 | position_ids = torch.arange( 103 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 104 | ) 105 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 106 | else: 107 | position_ids = position_ids.view(-1, seq_length).long() 108 | 109 | if inputs_embeds is None: 110 | inputs_embeds = self.embed_tokens(input_ids) 111 | 112 | # embed positions 113 | if attention_mask is None: 114 | attention_mask = torch.ones( 115 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 116 | ) 117 | attention_mask = self._prepare_decoder_attention_mask( 118 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 119 | ) 120 | 121 | hidden_states = inputs_embeds 122 | 123 | if self.gradient_checkpointing and self.training: 124 | if use_cache: 125 | logger.warning_once( 126 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
127 | ) 128 | use_cache = False 129 | 130 | # decoder layers 131 | all_hidden_states = () if output_hidden_states else None 132 | all_self_attns = () if output_attentions else None 133 | next_decoder_cache = () if use_cache else None 134 | 135 | for idx in range(len(self.layers)): 136 | if idx <= (self.preload - 1): 137 | decoder_layer = self.layers[idx] 138 | else: 139 | decoder_layer = self.layers[idx].to(DEV) 140 | 141 | if output_hidden_states: 142 | all_hidden_states += (hidden_states,) 143 | 144 | past_key_value = past_key_values[idx] if past_key_values is not None else None 145 | 146 | if self.gradient_checkpointing and self.training: 147 | 148 | def create_custom_forward(module): 149 | def custom_forward(*inputs): 150 | # None for past_key_value 151 | return module(*inputs, output_attentions, None) 152 | 153 | return custom_forward 154 | 155 | layer_outputs = torch.utils.checkpoint.checkpoint( 156 | create_custom_forward(decoder_layer), 157 | hidden_states, 158 | attention_mask, 159 | position_ids, 160 | None, 161 | ) 162 | else: 163 | layer_outputs = decoder_layer( 164 | hidden_states, 165 | attention_mask=attention_mask, 166 | position_ids=position_ids, 167 | past_key_value=past_key_value, 168 | output_attentions=output_attentions, 169 | use_cache=use_cache, 170 | ) 171 | 172 | hidden_states = layer_outputs[0] 173 | 174 | if idx > (self.preload - 1): 175 | self.layers[idx] = decoder_layer.cpu() 176 | del decoder_layer 177 | torch.cuda.empty_cache() 178 | 179 | if use_cache: 180 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 181 | 182 | if output_attentions: 183 | all_self_attns += (layer_outputs[1],) 184 | 185 | hidden_states = self.norm(hidden_states) 186 | 187 | # add hidden states from the last decoder layer 188 | if output_hidden_states: 189 | all_hidden_states += (hidden_states,) 190 | 191 | next_cache = next_decoder_cache if use_cache else None 192 | if not return_dict: 193 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 194 | return BaseModelOutputWithPast( 195 | last_hidden_state=hidden_states, 196 | past_key_values=next_cache, 197 | hidden_states=all_hidden_states, 198 | attentions=all_self_attns, 199 | ) 200 | 201 | 202 | def load_quant(model, checkpoint, wbits, groupsize, pre_layer): 203 | transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel 204 | from transformers import LlamaConfig, LlamaForCausalLM 205 | 206 | config = LlamaConfig.from_pretrained(model) 207 | 208 | def noop(*args, **kwargs): 209 | pass 210 | 211 | torch.nn.init.kaiming_uniform_ = noop 212 | torch.nn.init.uniform_ = noop 213 | torch.nn.init.normal_ = noop 214 | 215 | torch.set_default_dtype(torch.half) 216 | transformers.modeling_utils._init_weights = False 217 | torch.set_default_dtype(torch.half) 218 | model = LlamaForCausalLM(config) 219 | torch.set_default_dtype(torch.float) 220 | model = model.eval() 221 | layers = find_layers(model) 222 | for name in ["lm_head"]: 223 | if name in layers: 224 | del layers[name] 225 | make_quant(model, layers, wbits, groupsize) 226 | 227 | print("Loading model ...") 228 | if checkpoint.endswith(".safetensors"): 229 | from safetensors.torch import load_file as safe_load 230 | 231 | model.load_state_dict(safe_load(checkpoint)) 232 | else: 233 | model.load_state_dict(torch.load(checkpoint)) 234 | model.seqlen = 2048 235 | 236 | for i in range(pre_layer): 237 | model.model.layers[i].to(DEV) 238 | model.model.embed_tokens.to(DEV) 239 | model.model.norm.to(DEV) 240 | 
model.lm_head.to(DEV) 241 | model.model.preload = pre_layer 242 | print("Done.") 243 | return model 244 | 245 | 246 | if __name__ == "__main__": 247 | import argparse 248 | 249 | from datautils import * 250 | 251 | parser = argparse.ArgumentParser() 252 | 253 | parser.add_argument("model", type=str, help="llama model to load") 254 | parser.add_argument("--wbits", type=int, default=4, choices=[2, 3, 4, 8], help="#bits to use for quantization") 255 | parser.add_argument( 256 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 257 | ) 258 | parser.add_argument("--load", type=str, default="", help="Load quantized model.") 259 | parser.add_argument("--text", type=str, help="input text") 260 | 261 | parser.add_argument( 262 | "--min_length", type=int, default=10, help="The minimum length of the sequence to be generated." 263 | ) 264 | 265 | parser.add_argument( 266 | "--max_length", type=int, default=50, help="The maximum length of the sequence to be generated." 267 | ) 268 | 269 | parser.add_argument( 270 | "--top_p", 271 | type=float, 272 | default=0.95, 273 | help="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.", 274 | ) 275 | 276 | parser.add_argument( 277 | "--temperature", type=float, default=0.8, help="The value used to module the next token probabilities." 278 | ) 279 | 280 | parser.add_argument("--pre_layer", type=int, default=50, help="The number of layers to preload") 281 | 282 | args = parser.parse_args() 283 | 284 | if type(args.load) is not str: 285 | args.load = args.load.as_posix() 286 | 287 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer) 288 | 289 | tokenizer = AutoTokenizer.from_pretrained(args.model) 290 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV) 291 | 292 | with torch.no_grad(): 293 | generated_ids = model.generate( 294 | input_ids, 295 | do_sample=True, 296 | min_length=args.min_length, 297 | max_length=args.max_length, 298 | top_p=args.top_p, 299 | temperature=args.temperature, 300 | ) 301 | print(tokenizer.decode([el.item() for el in generated_ids[0]])) 302 | -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | DEV = torch.device("cuda:0") 5 | 6 | 7 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): 8 | if type(module) in layers: 9 | return {name: module} 10 | res = {} 11 | for name1, child in module.named_children(): 12 | res.update(find_layers(child, layers=layers, name=name + "." 
+ name1 if name != "" else name1)) 13 | return res 14 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import transformers 6 | 7 | from gptq import * 8 | from modelutils import * 9 | from quant import * 10 | 11 | 12 | def get_opt(model): 13 | import torch 14 | 15 | def skip(*args, **kwargs): 16 | pass 17 | 18 | torch.nn.init.kaiming_uniform_ = skip 19 | torch.nn.init.uniform_ = skip 20 | torch.nn.init.normal_ = skip 21 | from transformers import OPTForCausalLM 22 | 23 | model = OPTForCausalLM.from_pretrained(model, torch_dtype="auto") 24 | model.seqlen = model.config.max_position_embeddings 25 | return model 26 | 27 | 28 | @torch.no_grad() 29 | def opt_sequential(model, dataloader, dev): 30 | print("Starting ...") 31 | 32 | use_cache = model.config.use_cache 33 | model.config.use_cache = False 34 | layers = model.model.decoder.layers 35 | 36 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 37 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 38 | if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out: 39 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 40 | if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in: 41 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 42 | layers[0] = layers[0].to(dev) 43 | 44 | dtype = next(iter(model.parameters())).dtype 45 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 46 | cache = {"i": 0, "attention_mask": None} 47 | 48 | class Catcher(nn.Module): 49 | def __init__(self, module): 50 | super().__init__() 51 | self.module = module 52 | 53 | def forward(self, inp, **kwargs): 54 | inps[cache["i"]] = inp 55 | cache["i"] += 1 56 | cache["attention_mask"] = kwargs["attention_mask"] 57 | raise ValueError 58 | 59 | layers[0] = Catcher(layers[0]) 60 | for batch in dataloader: 61 | try: 62 | model(batch[0].to(dev)) 63 | except ValueError: 64 | pass 65 | layers[0] = layers[0].module 66 | 67 | layers[0] = layers[0].cpu() 68 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 69 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 70 | if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out: 71 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 72 | if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in: 73 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 74 | torch.cuda.empty_cache() 75 | 76 | outs = torch.zeros_like(inps) 77 | attention_mask = cache["attention_mask"] 78 | 79 | print("Ready.") 80 | 81 | quantizers = {} 82 | for i in range(len(layers)): 83 | layer = layers[i].to(dev) 84 | 85 | subset = find_layers(layer) 86 | gptq = {} 87 | for name in subset: 88 | gptq[name] = GPTQ(subset[name]) 89 | gptq[name].quantizer = Quantizer() 90 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits) 91 | 92 | def add_batch(name): 93 | def tmp(_, inp, out): 94 | gptq[name].add_batch(inp[0].data, out.data) 95 | 96 | return tmp 97 | 98 | handles = [] 99 | for name in subset: 100 | handles.append(subset[name].register_forward_hook(add_batch(name))) 101 | 102 | for j in 
range(args.nsamples): 103 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 104 | 105 | for h in handles: 106 | h.remove() 107 | 108 | for name in subset: 109 | print(f"Quantizing {name} in layer {i+1}/{len(layers)}...") 110 | scale, zero, g_idx = gptq[name].fasterquant( 111 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order 112 | ) 113 | quantizers["model.decoder.layers.%d.%s" % (i, name)] = ( 114 | gptq[name].quantizer.cpu(), 115 | scale.cpu(), 116 | zero.cpu(), 117 | g_idx.cpu(), 118 | ) 119 | gptq[name].free() 120 | 121 | for j in range(args.nsamples): 122 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 123 | 124 | layers[i] = layer.cpu() 125 | del layer 126 | del gptq 127 | torch.cuda.empty_cache() 128 | 129 | inps, outs = outs, inps 130 | 131 | model.config.use_cache = use_cache 132 | 133 | return quantizers 134 | 135 | 136 | @torch.no_grad() 137 | def opt_eval(model, testenc, dev): 138 | print("Evaluating ...") 139 | 140 | testenc = testenc.input_ids 141 | nsamples = testenc.numel() // model.seqlen 142 | 143 | use_cache = model.config.use_cache 144 | model.config.use_cache = False 145 | layers = model.model.decoder.layers 146 | 147 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 148 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 149 | if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out: 150 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 151 | if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in: 152 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 153 | layers[0] = layers[0].to(dev) 154 | 155 | dtype = next(iter(model.parameters())).dtype 156 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 157 | cache = {"i": 0, "attention_mask": None} 158 | 159 | class Catcher(nn.Module): 160 | def __init__(self, module): 161 | super().__init__() 162 | self.module = module 163 | 164 | def forward(self, inp, **kwargs): 165 | inps[cache["i"]] = inp 166 | cache["i"] += 1 167 | cache["attention_mask"] = kwargs["attention_mask"] 168 | raise ValueError 169 | 170 | layers[0] = Catcher(layers[0]) 171 | for i in range(nsamples): 172 | batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev) 173 | try: 174 | model(batch) 175 | except ValueError: 176 | pass 177 | layers[0] = layers[0].module 178 | 179 | layers[0] = layers[0].cpu() 180 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 181 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 182 | if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out: 183 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 184 | if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in: 185 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 186 | torch.cuda.empty_cache() 187 | 188 | outs = torch.zeros_like(inps) 189 | attention_mask = cache["attention_mask"] 190 | 191 | for i in range(len(layers)): 192 | print(i) 193 | layer = layers[i].to(dev) 194 | 195 | if args.nearest: 196 | subset = find_layers(layer) 197 | for name in subset: 198 | quantizer = Quantizer() 199 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 200 | W = subset[name].weight.data 201 | quantizer.find_params(W, 
weight=True) 202 | subset[name].weight.data = quantize(W, quantizer.scale, quantizer.zero, quantizer.maxq).to( 203 | next(iter(layer.parameters())).dtype 204 | ) 205 | 206 | for j in range(nsamples): 207 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 208 | layers[i] = layer.cpu() 209 | del layer 210 | torch.cuda.empty_cache() 211 | inps, outs = outs, inps 212 | 213 | if model.model.decoder.final_layer_norm is not None: 214 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 215 | if model.model.decoder.project_out is not None: 216 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 217 | model.lm_head = model.lm_head.to(dev) 218 | 219 | testenc = testenc.to(dev) 220 | nlls = [] 221 | for i in range(nsamples): 222 | hidden_states = inps[i].unsqueeze(0) 223 | if model.model.decoder.final_layer_norm is not None: 224 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 225 | if model.model.decoder.project_out is not None: 226 | hidden_states = model.model.decoder.project_out(hidden_states) 227 | lm_logits = model.lm_head(hidden_states) 228 | shift_logits = lm_logits[:, :-1, :].contiguous() 229 | shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] 230 | loss_fct = nn.CrossEntropyLoss() 231 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 232 | neg_log_likelihood = loss.float() * model.seqlen 233 | nlls.append(neg_log_likelihood) 234 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 235 | print(ppl.item()) 236 | 237 | model.config.use_cache = use_cache 238 | 239 | 240 | # TODO: perform packing on GPU 241 | def opt_pack(model, quantizers, wbits, groupsize): 242 | layers = find_layers(model) 243 | layers = {n: layers[n] for n in quantizers} 244 | make_quant(model, quantizers, wbits, groupsize) 245 | qlayers = find_layers(model, [QuantLinear]) 246 | print("Packing ...") 247 | for name in qlayers: 248 | print(name) 249 | quantizers[name], scale, zero, g_idx = quantizers[name] 250 | qlayers[name].pack(layers[name], scale, zero, g_idx) 251 | print("Done.") 252 | return model 253 | 254 | 255 | def load_quant(model, checkpoint, wbits, groupsize): 256 | from transformers import OPTConfig, OPTForCausalLM 257 | 258 | config = OPTConfig.from_pretrained(model) 259 | 260 | def noop(*args, **kwargs): 261 | pass 262 | 263 | torch.nn.init.kaiming_uniform_ = noop 264 | torch.nn.init.uniform_ = noop 265 | torch.nn.init.normal_ = noop 266 | 267 | torch.set_default_dtype(torch.half) 268 | transformers.modeling_utils._init_weights = False 269 | torch.set_default_dtype(torch.half) 270 | model = OPTForCausalLM(config) 271 | torch.set_default_dtype(torch.float) 272 | model = model.eval() 273 | layers = find_layers(model) 274 | for name in ["model.decoder.project_out", "model.decoder.project_in", "lm_head"]: 275 | if name in layers: 276 | del layers[name] 277 | make_quant(model, layers, wbits, groupsize) 278 | 279 | del layers 280 | 281 | print("Loading model ...") 282 | if checkpoint.endswith(".safetensors"): 283 | from safetensors.torch import load_file as safe_load 284 | 285 | model.load_state_dict(safe_load(checkpoint)) 286 | else: 287 | model.load_state_dict(torch.load(checkpoint)) 288 | model.seqlen = model.config.max_position_embeddings 289 | print("Done.") 290 | return model 291 | 292 | 293 | def opt_multigpu(model, gpus): 294 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0]) 295 | 
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0]) 296 | if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in: 297 | model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0]) 298 | if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out: 299 | model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1]) 300 | if hasattr(model.model.decoder, "final_layer_norm") and model.model.decoder.final_layer_norm: 301 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1]) 302 | import copy 303 | 304 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) 305 | 306 | cache = {"mask": None} 307 | 308 | class MoveModule(nn.Module): 309 | def __init__(self, module): 310 | super().__init__() 311 | self.module = module 312 | self.dev = next(iter(self.module.parameters())).device 313 | 314 | def forward(self, *inp, **kwargs): 315 | inp = list(inp) 316 | if inp[0].device != self.dev: 317 | inp[0] = inp[0].to(self.dev) 318 | if cache["mask"] is None or cache["mask"].device != self.dev: 319 | cache["mask"] = kwargs["attention_mask"].to(self.dev) 320 | kwargs["attention_mask"] = cache["mask"] 321 | tmp = self.module(*inp, **kwargs) 322 | return tmp 323 | 324 | layers = model.model.decoder.layers 325 | pergpu = math.ceil(len(layers) / len(gpus)) 326 | for i in range(len(layers)): 327 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) 328 | 329 | model.gpus = gpus 330 | 331 | 332 | def benchmark(model, input_ids, check=False): 333 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, "gpus") else DEV) 334 | torch.cuda.synchronize() 335 | 336 | cache = {"past": None} 337 | 338 | def clear_past(i): 339 | def tmp(layer, inp, out): 340 | if cache["past"]: 341 | cache["past"][i] = None 342 | 343 | return tmp 344 | 345 | for i, layer in enumerate(model.model.decoder.layers): 346 | layer.register_forward_hook(clear_past(i)) 347 | 348 | print("Benchmarking ...") 349 | 350 | if check: 351 | loss = nn.CrossEntropyLoss() 352 | tot = 0.0 353 | 354 | def sync(): 355 | if hasattr(model, "gpus"): 356 | for gpu in model.gpus: 357 | torch.cuda.synchronize(gpu) 358 | else: 359 | torch.cuda.synchronize() 360 | 361 | with torch.no_grad(): 362 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 363 | times = [] 364 | for i in range(input_ids.numel()): 365 | tick = time.time() 366 | out = model( 367 | input_ids[:, i].reshape(-1), 368 | past_key_values=cache["past"], 369 | attention_mask=attention_mask[:, : (i + 1)].reshape((1, -1)), 370 | ) 371 | sync() 372 | times.append(time.time() - tick) 373 | print(i, times[-1]) 374 | if check and i != input_ids.numel() - 1: 375 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 376 | cache["past"] = list(out.past_key_values) 377 | del out 378 | sync() 379 | import numpy as np 380 | 381 | print("Median:", np.median(times)) 382 | if check: 383 | print("PPL:", torch.exp(tot / (input_ids.numel() - 1)).item()) 384 | 385 | 386 | if __name__ == "__main__": 387 | import argparse 388 | 389 | from datautils import * 390 | 391 | parser = argparse.ArgumentParser() 392 | 393 | parser.add_argument("model", type=str, help="OPT model to load; pass `facebook/opt-X`.") 394 | parser.add_argument( 395 | "dataset", type=str, choices=["wikitext2", "ptb", "c4"], help="Where to extract calibration data from." 
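# Illustrative invocation assembled from the arguments defined in this block (the model name,
# dataset, and output filename are placeholders, not commands taken from this repository):
#   python opt.py facebook/opt-1.3b c4 --wbits 4 --groupsize 128 --save opt-1.3b-4bit-128g.pt
# `model` and `dataset` are the two positional arguments; the quantization flags are parsed
# further down in this block.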
396 | ) 397 | parser.add_argument("--seed", type=int, default=0, help="Seed for sampling the calibration data.") 398 | parser.add_argument("--nsamples", type=int, default=128, help="Number of calibration data samples.") 399 | parser.add_argument( 400 | "--percdamp", type=float, default=0.01, help="Percent of the average Hessian diagonal to use for dampening." 401 | ) 402 | parser.add_argument("--nearest", action="store_true", help="Whether to run the RTN baseline.") 403 | parser.add_argument( 404 | "--wbits", 405 | type=int, 406 | default=16, 407 | choices=[2, 3, 4, 8, 16], 408 | help="#bits to use for quantization; use 16 for evaluating base model.", 409 | ) 410 | parser.add_argument("--trits", action="store_true", help="Whether to use trits for quantization.") 411 | parser.add_argument( 412 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 413 | ) 414 | parser.add_argument("--eval", action="store_true", help="evaluate quantized model.") 415 | parser.add_argument("--save", type=str, default="", help="Save quantized checkpoint under this name.") 416 | parser.add_argument( 417 | "--save_safetensors", type=str, default="", help="Save quantized `.safetensors` checkpoint under this name." 418 | ) 419 | parser.add_argument("--load", type=str, default="", help="Load quantized model.") 420 | parser.add_argument("--benchmark", type=int, default=0, help="Number of tokens to use for benchmarking.") 421 | parser.add_argument( 422 | "--check", action="store_true", help="Whether to compute perplexity during benchmarking for verification." 423 | ) 424 | parser.add_argument("--sym", action="store_true", help="Whether to perform symmetric quantization.") 425 | parser.add_argument( 426 | "--act-order", action="store_true", help="Whether to apply the activation order GPTQ heuristic" 427 | ) 428 | parser.add_argument("--new-eval", action="store_true", help="Whether to use the new PTB and C4 eval") 429 | 430 | args = parser.parse_args() 431 | 432 | if type(args.load) is not str: 433 | args.load = args.load.as_posix() 434 | 435 | if args.load: 436 | model = load_quant(args.model, args.load, args.wbits, args.groupsize) 437 | else: 438 | model = get_opt(args.model) 439 | model.eval() 440 | 441 | dataloader, testloader = get_loaders( 442 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 443 | ) 444 | 445 | if not args.load and args.wbits < 16 and not args.nearest: 446 | tick = time.time() 447 | quantizers = opt_sequential(model, dataloader, DEV) 448 | print(time.time() - tick) 449 | 450 | if args.benchmark: 451 | gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())] 452 | if len(gpus) > 1: 453 | opt_multigpu(model, gpus) 454 | else: 455 | model = model.to(DEV) 456 | if args.benchmark: 457 | input_ids = next(iter(dataloader))[0][:, : args.benchmark] 458 | benchmark(model, input_ids, check=args.check) 459 | 460 | if args.load: 461 | exit() 462 | 463 | if args.eval: 464 | datasets = ["wikitext2", "ptb", "c4"] 465 | if args.new_eval: 466 | datasets = ["wikitext2", "ptb-new", "c4-new"] 467 | for dataset in datasets: 468 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) 469 | print(dataset) 470 | opt_eval(model, testloader, DEV) 471 | 472 | if args.save: 473 | opt_pack(model, quantizers, args.wbits, args.groupsize) 474 | torch.save(model.state_dict(), args.save) 475 | 476 | if args.save_safetensors: 477 | opt_pack(model, quantizers, args.wbits, 
args.groupsize) 478 | from safetensors.torch import save_file as safe_save 479 | 480 | state_dict = model.state_dict() 481 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} 482 | safe_save(state_dict, args.save_safetensors) 483 | -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def quantize(x, scale, zero, maxq): 9 | if maxq < 0: 10 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 11 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 12 | return scale * (q - zero) 13 | 14 | 15 | class Quantizer(nn.Module): 16 | def __init__(self, shape=1): 17 | super(Quantizer, self).__init__() 18 | self.register_buffer("maxq", torch.tensor(0)) 19 | self.register_buffer("scale", torch.zeros(shape)) 20 | self.register_buffer("zero", torch.zeros(shape)) 21 | 22 | def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8, trits=False): 23 | self.maxq = torch.tensor(2**bits - 1) 24 | self.perchannel = perchannel 25 | self.sym = sym 26 | self.mse = mse 27 | self.norm = norm 28 | self.grid = grid 29 | self.maxshrink = maxshrink 30 | if trits: 31 | self.maxq = torch.tensor(-1) 32 | 33 | def find_params(self, x, weight=False): 34 | dev = x.device 35 | self.maxq = self.maxq.to(dev) 36 | 37 | shape = x.shape 38 | if self.perchannel: 39 | if weight: 40 | x = x.flatten(1) 41 | else: 42 | if len(shape) == 4: 43 | x = x.permute([1, 0, 2, 3]) 44 | x = x.flatten(1) 45 | if len(shape) == 3: 46 | x = x.reshape((-1, shape[-1])).t() 47 | if len(shape) == 2: 48 | x = x.t() 49 | else: 50 | x = x.flatten().unsqueeze(0) 51 | 52 | tmp = torch.zeros(x.shape[0], device=dev) 53 | xmin = torch.minimum(x.min(1)[0], tmp) 54 | xmax = torch.maximum(x.max(1)[0], tmp) 55 | 56 | if self.sym: 57 | xmax = torch.maximum(torch.abs(xmin), xmax) 58 | tmp = xmin < 0 59 | if torch.any(tmp): 60 | xmin[tmp] = -xmax[tmp] 61 | tmp = (xmin == 0) & (xmax == 0) 62 | xmin[tmp] = -1 63 | xmax[tmp] = +1 64 | 65 | if self.maxq < 0: 66 | self.scale = xmax 67 | self.zero = xmin 68 | else: 69 | self.scale = (xmax - xmin) / self.maxq 70 | if self.sym: 71 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 72 | else: 73 | self.zero = torch.round(-xmin / self.scale) 74 | 75 | if self.mse: 76 | best = torch.full([x.shape[0]], float("inf"), device=dev) 77 | for i in range(int(self.maxshrink * self.grid)): 78 | p = 1 - i / self.grid 79 | xmin1 = p * xmin 80 | xmax1 = p * xmax 81 | scale1 = (xmax1 - xmin1) / self.maxq 82 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 83 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 84 | q -= x 85 | q.abs_() 86 | q.pow_(self.norm) 87 | err = torch.sum(q, 1) 88 | tmp = err < best 89 | if torch.any(tmp): 90 | best[tmp] = err[tmp] 91 | self.scale[tmp] = scale1[tmp] 92 | self.zero[tmp] = zero1[tmp] 93 | if not self.perchannel: 94 | if weight: 95 | tmp = shape[0] 96 | else: 97 | tmp = shape[1] if len(shape) != 3 else shape[2] 98 | self.scale = self.scale.repeat(tmp) 99 | self.zero = self.zero.repeat(tmp) 100 | 101 | if weight: 102 | shape = [-1] + [1] * (len(shape) - 1) 103 | self.scale = self.scale.reshape(shape) 104 | self.zero = self.zero.reshape(shape) 105 | return 106 | if len(shape) == 4: 107 | self.scale = self.scale.reshape((1, -1, 1, 1)) 108 | self.zero = 
self.zero.reshape((1, -1, 1, 1)) 109 | if len(shape) == 3: 110 | self.scale = self.scale.reshape((1, 1, -1)) 111 | self.zero = self.zero.reshape((1, 1, -1)) 112 | if len(shape) == 2: 113 | self.scale = self.scale.unsqueeze(0) 114 | self.zero = self.zero.unsqueeze(0) 115 | 116 | def quantize(self, x): 117 | if self.ready(): 118 | return quantize(x, self.scale, self.zero, self.maxq) 119 | return x 120 | 121 | def enabled(self): 122 | return self.maxq > 0 123 | 124 | def ready(self): 125 | return torch.all(self.scale != 0) 126 | 127 | 128 | try: 129 | import quant_cuda 130 | 131 | is_cuda = True 132 | except: 133 | print("CUDA extension not installed.") 134 | is_cuda = False 135 | 136 | 137 | def make_quant(module, names, bits, groupsize, name=""): 138 | if isinstance(module, QuantLinear): 139 | return 140 | for attr in dir(module): 141 | tmp = getattr(module, attr) 142 | name1 = name + "." + attr if name != "" else attr 143 | if name1 in names: 144 | delattr(module, attr) 145 | setattr( 146 | module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None) 147 | ) 148 | for name1, child in module.named_children(): 149 | make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1) 150 | 151 | 152 | class QuantLinear(nn.Module): 153 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda): 154 | super().__init__() 155 | if bits not in [2, 3, 4, 8]: 156 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 157 | self.infeatures = infeatures 158 | self.outfeatures = outfeatures 159 | self.bits = bits 160 | self.groupsize = groupsize if groupsize != -1 else infeatures 161 | self.maxq = 2**self.bits - 1 162 | 163 | self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) 164 | self.register_buffer( 165 | "qzeros", 166 | torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), 167 | ) 168 | self.register_buffer( 169 | "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16) 170 | ) 171 | self.register_buffer( 172 | "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) 173 | ) 174 | if bias: 175 | self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) 176 | else: 177 | self.bias = None 178 | 179 | # is performed by unpacking the weights and using torch.matmul 180 | if self.bits in [2, 4, 8]: 181 | self.register_buffer( 182 | "wf", torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0), persistent=False 183 | ) 184 | elif self.bits == 3: 185 | self.register_buffer( 186 | "wf", 187 | torch.tensor( 188 | [ 189 | [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], 190 | [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], 191 | [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], 192 | ], 193 | dtype=torch.int32, 194 | ).reshape(1, 3, 12), 195 | persistent=False, 196 | ) 197 | 198 | self.kernel_switch_threshold = kernel_switch_threshold 199 | self.is_cuda = is_cuda 200 | 201 | def pack(self, linear, scales, zeros, g_idx=None): 202 | self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx 203 | 204 | scales = scales.t().contiguous() 205 | zeros = zeros.t().contiguous() 206 | scale_zeros = zeros * scales 207 | self.scales = scales.clone().half() 208 | if linear.bias is not None: 209 | self.bias = linear.bias.clone().half() 210 | 211 | intweight = [] 212 | for idx in 
range(self.infeatures): 213 | intweight.append( 214 | torch.round( 215 | (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]] 216 | ).to(torch.int)[:, None] 217 | ) 218 | intweight = torch.cat(intweight, dim=1) 219 | intweight = intweight.t().contiguous() 220 | intweight = intweight.numpy().astype(np.uint32) 221 | qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) 222 | i = 0 223 | row = 0 224 | while row < qweight.shape[0]: 225 | if self.bits in [2, 4, 8]: 226 | for j in range(i, i + (32 // self.bits)): 227 | qweight[row] |= intweight[j] << (self.bits * (j - i)) 228 | i += 32 // self.bits 229 | row += 1 230 | elif self.bits == 3: 231 | for j in range(i, i + 10): 232 | qweight[row] |= intweight[j] << (3 * (j - i)) 233 | i += 10 234 | qweight[row] |= intweight[i] << 30 235 | row += 1 236 | qweight[row] |= (intweight[i] >> 2) & 1 237 | i += 1 238 | for j in range(i, i + 10): 239 | qweight[row] |= intweight[j] << (3 * (j - i) + 1) 240 | i += 10 241 | qweight[row] |= intweight[i] << 31 242 | row += 1 243 | qweight[row] |= (intweight[i] >> 1) & 0x3 244 | i += 1 245 | for j in range(i, i + 10): 246 | qweight[row] |= intweight[j] << (3 * (j - i) + 2) 247 | i += 10 248 | row += 1 249 | else: 250 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 251 | 252 | qweight = qweight.astype(np.int32) 253 | self.qweight = torch.from_numpy(qweight) 254 | 255 | zeros -= 1 256 | zeros = zeros.numpy().astype(np.uint32) 257 | qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) 258 | i = 0 259 | col = 0 260 | while col < qzeros.shape[1]: 261 | if self.bits in [2, 4, 8]: 262 | for j in range(i, i + (32 // self.bits)): 263 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) 264 | i += 32 // self.bits 265 | col += 1 266 | elif self.bits == 3: 267 | for j in range(i, i + 10): 268 | qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) 269 | i += 10 270 | qzeros[:, col] |= zeros[:, i] << 30 271 | col += 1 272 | qzeros[:, col] |= (zeros[:, i] >> 2) & 1 273 | i += 1 274 | for j in range(i, i + 10): 275 | qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) 276 | i += 10 277 | qzeros[:, col] |= zeros[:, i] << 31 278 | col += 1 279 | qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 280 | i += 1 281 | for j in range(i, i + 10): 282 | qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) 283 | i += 10 284 | col += 1 285 | else: 286 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 287 | 288 | qzeros = qzeros.astype(np.int32) 289 | self.qzeros = torch.from_numpy(qzeros) 290 | 291 | def forward(self, x): 292 | out_shape = x.shape[:-1] + (self.outfeatures,) 293 | x = x.reshape(-1, x.shape[-1]) 294 | if self.is_cuda is True and ( 295 | self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold 296 | ): 297 | out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) 298 | if self.bits == 2: 299 | quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) 300 | elif self.bits == 3: 301 | quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) 302 | elif self.bits == 4: 303 | quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) 304 | elif self.bits == 8: 305 | quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) 306 | out = out.half() 307 | else: 308 | if 
self.bits in [2, 4, 8]: 309 | zeros = torch.bitwise_right_shift( 310 | torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0) 311 | ).to(torch.int16 if self.bits == 8 else torch.int8) 312 | torch.bitwise_and(zeros, (2**self.bits) - 1, out=zeros) 313 | 314 | zeros = zeros + 1 315 | zeros = zeros.reshape(self.scales.shape) 316 | 317 | weight = torch.bitwise_right_shift( 318 | torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1) 319 | ).to(torch.int16 if self.bits == 8 else torch.int8) 320 | torch.bitwise_and(weight, (2**self.bits) - 1, out=weight) 321 | elif self.bits == 3: 322 | zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand( 323 | -1, -1, -1, 12 324 | ) 325 | zeros = zeros >> self.wf.unsqueeze(0) 326 | zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) 327 | zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) 328 | zeros = zeros & 0x7 329 | zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2) 330 | 331 | zeros = zeros + 1 332 | zeros = zeros.reshape(self.scales.shape) 333 | 334 | weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand( 335 | -1, -1, 12, -1 336 | ) 337 | weight = (weight >> self.wf.unsqueeze(-1)) & 0x7 338 | weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) 339 | weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) 340 | weight = weight & 0x7 341 | weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) 342 | 343 | weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) 344 | num_itr = self.g_idx.shape[0] // x.shape[-1] 345 | if num_itr == 1: 346 | weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]) 347 | else: 348 | num_dim = self.g_idx.shape[0] // num_itr 349 | weights = [] 350 | for i in range(num_itr): 351 | scale_i = self.scales[:, i * num_dim : (i + 1) * num_dim] 352 | weight_i = weight[:, i * num_dim : (i + 1) * num_dim] 353 | zeros_i = zeros[:, i * num_dim : (i + 1) * num_dim] 354 | g_idx_i = self.g_idx[i * num_dim : (i + 1) * num_dim] 355 | weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) 356 | weights = torch.cat(weights, dim=1) 357 | out = torch.matmul(x.half(), weights) 358 | out = out.reshape(out_shape) 359 | out = out + self.bias if self.bias is not None else out 360 | return out 361 | -------------------------------------------------------------------------------- /quant_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | void vecquant2matmul_cuda( 6 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 7 | torch::Tensor scales, torch::Tensor zeros, 8 | torch::Tensor g_idx 9 | ); 10 | 11 | void vecquant2matmul( 12 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 13 | torch::Tensor scales, torch::Tensor zeros, 14 | torch::Tensor g_idx 15 | ) { 16 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 17 | vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx); 18 | } 19 | 20 | void vecquant3matmul_cuda( 21 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 22 | torch::Tensor scales, torch::Tensor zeros, 23 | torch::Tensor g_idx 24 | ); 25 | 26 | void vecquant3matmul( 27 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 28 | 
torch::Tensor scales, torch::Tensor zeros, 29 | torch::Tensor g_idx 30 | ) { 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 32 | vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx); 33 | } 34 | 35 | void vecquant4matmul_cuda( 36 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 37 | torch::Tensor scales, torch::Tensor zeros, 38 | torch::Tensor g_idx 39 | ); 40 | 41 | void vecquant4matmul( 42 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 43 | torch::Tensor scales, torch::Tensor zeros, 44 | torch::Tensor g_idx 45 | ) { 46 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 47 | vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); 48 | } 49 | 50 | void vecquant8matmul_cuda( 51 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 52 | torch::Tensor scales, torch::Tensor zeros, 53 | torch::Tensor g_idx 54 | ); 55 | 56 | void vecquant8matmul( 57 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 58 | torch::Tensor scales, torch::Tensor zeros, 59 | torch::Tensor g_idx 60 | ) { 61 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 62 | vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); 63 | } 64 | 65 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 66 | m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)"); 67 | m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)"); 68 | m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)"); 69 | m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)"); 70 | } 71 | -------------------------------------------------------------------------------- /quant_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // atomicAdd for double-precision floating-point numbers on hardware with 8 | // compute capability < 6.0 from: 9 | // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions 10 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 11 | __device__ double atomicAdd( 12 | double* address, 13 | double val 14 | ) { 15 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 16 | unsigned long long int old = *address_as_ull, assumed; 17 | 18 | do { 19 | assumed = old; 20 | old = atomicCAS( 21 | address_as_ull, 22 | assumed, 23 | __double_as_longlong(val + __longlong_as_double(assumed)) 24 | ); 25 | 26 | // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) 27 | } while (assumed != old); 28 | 29 | return __longlong_as_double(old); 30 | } 31 | #endif 32 | 33 | template 34 | __global__ void VecQuant2MatMulKernel( 35 | const scalar_t* __restrict__ vec, 36 | const int* __restrict__ mat, 37 | scalar_t* __restrict__ mul, 38 | const scalar_t* __restrict__ scales, 39 | const int* __restrict__ zeros, 40 | const int* __restrict__ g_idx, 41 | int batch, 42 | int vec_height, 43 | int height, 44 | int width, 45 | int zero_width 46 | ); 47 | 48 | template 49 | __global__ void VecQuant3MatMulKernel( 50 | const scalar_t* __restrict__ vec, 51 | const int* __restrict__ mat, 52 | scalar_t* __restrict__ mul, 53 | const scalar_t* __restrict__ scales, 54 | const int* __restrict__ zeros, 55 | const int* __restrict__ g_idx, 56 | int batch, 57 | int vec_height, 58 | int height, 59 | int width, 60 | int zero_width 61 | ); 62 | 63 | template 
64 | __global__ void VecQuant4MatMulKernel( 65 | const scalar_t* __restrict__ vec, 66 | const int* __restrict__ mat, 67 | scalar_t* __restrict__ mul, 68 | const scalar_t* __restrict__ scales, 69 | const int* __restrict__ zeros, 70 | const int* __restrict__ g_idx, 71 | int batch, 72 | int vec_height, 73 | int height, 74 | int width, 75 | int zero_width 76 | ); 77 | 78 | template 79 | __global__ void VecQuant8MatMulKernel( 80 | const scalar_t* __restrict__ vec, 81 | const int* __restrict__ mat, 82 | scalar_t* __restrict__ mul, 83 | const scalar_t* __restrict__ scales, 84 | const int* __restrict__ zeros, 85 | const int* __restrict__ g_idx, 86 | int batch, 87 | int vec_height, 88 | int height, 89 | int width, 90 | int zero_width 91 | ); 92 | 93 | const int BLOCKWIDTH = 256; 94 | const int BLOCKHEIGHT2 = 16; 95 | const int BLOCKHEIGHT3 = 24; 96 | const int BLOCKHEIGHT4 = 32; 97 | const int BLOCKHEIGHT8 = 64; 98 | 99 | __device__ inline unsigned int as_unsigned(int i) { 100 | return *reinterpret_cast(&i); 101 | } 102 | 103 | __device__ inline int as_int(int i) { 104 | return *reinterpret_cast(&i); 105 | } 106 | 107 | 108 | void vecquant2matmul_cuda( 109 | torch::Tensor vec, 110 | torch::Tensor mat, 111 | torch::Tensor mul, 112 | torch::Tensor scales, 113 | torch::Tensor zeros, 114 | torch::Tensor g_idx 115 | ) { 116 | int batch = vec.size(0); 117 | int vec_height = vec.size(1); 118 | int height = mat.size(0); 119 | int width = mat.size(1); 120 | int zero_width = zeros.size(1); 121 | 122 | dim3 blocks( 123 | (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, 124 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 125 | ); 126 | dim3 threads(BLOCKWIDTH); 127 | 128 | AT_DISPATCH_FLOATING_TYPES( 129 | vec.type(), "vecquant2matmul_cuda", ([&] { 130 | VecQuant2MatMulKernel<<>>( 131 | vec.data(), mat.data(), mul.data(), 132 | scales.data(), zeros.data(), g_idx.data(), 133 | batch, vec_height, height, width, zero_width 134 | ); 135 | }) 136 | ); 137 | } 138 | 139 | template 140 | __global__ void VecQuant2MatMulKernel( 141 | const scalar_t* __restrict__ vec, 142 | const int* __restrict__ mat, 143 | scalar_t* __restrict__ mul, 144 | const scalar_t* __restrict__ scales, 145 | const int* __restrict__ zeros, 146 | const int* __restrict__ g_idx, 147 | int batch, 148 | int vec_height, 149 | int height, 150 | int width, 151 | int zero_width 152 | ) { 153 | int h = BLOCKHEIGHT2 * blockIdx.x; 154 | int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; 155 | 156 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 157 | int i = width * h + w; 158 | int g_h = h * 16; 159 | int k; 160 | unsigned int g; 161 | scalar_t w_tmp; 162 | 163 | int z_w = w / 16; 164 | int z_mod = (w % 16) * 2; 165 | 166 | float weight[BLOCKWIDTH]; 167 | 168 | for (k = 0; k < BLOCKWIDTH; ++k){ 169 | int k_w = (k / 16); 170 | int k_bit = (k % 16) * 2; 171 | 172 | g = as_int(g_idx[g_h + k]); 173 | scalar_t scale = scales[g * width + w]; 174 | scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); 175 | 176 | w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); 177 | 178 | weight[k] = scale * (w_tmp - zero); 179 | } 180 | 181 | scalar_t res; 182 | for (int b = 0; b < batch; ++b){ 183 | res = 0; 184 | 185 | blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; 186 | __syncthreads(); 187 | for (k = 0; k < BLOCKWIDTH; ++k){ 188 | res += weight[k] * blockvec[k]; 189 | } 190 | atomicAdd(&mul[b * width + w], res); 191 | __syncthreads(); 192 | } 193 | } 194 | 195 | void vecquant3matmul_cuda( 196 | torch::Tensor 
vec, 197 | torch::Tensor mat, 198 | torch::Tensor mul, 199 | torch::Tensor scales, 200 | torch::Tensor zeros, 201 | torch::Tensor g_idx 202 | ) { 203 | int batch = vec.size(0); 204 | int vec_height = vec.size(1); 205 | int height = mat.size(0); 206 | int width = mat.size(1); 207 | int zero_width = zeros.size(1); 208 | 209 | dim3 blocks( 210 | (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, 211 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 212 | ); 213 | dim3 threads(BLOCKWIDTH); 214 | 215 | AT_DISPATCH_FLOATING_TYPES( 216 | vec.type(), "vecquant3matmul_cuda", ([&] { 217 | VecQuant3MatMulKernel<<>>( 218 | vec.data(), mat.data(), mul.data(), 219 | scales.data(), zeros.data(), g_idx.data(), 220 | batch, vec_height, height, width, zero_width 221 | ); 222 | }) 223 | ); 224 | } 225 | 226 | template 227 | __global__ void VecQuant3MatMulKernel( 228 | const scalar_t* __restrict__ vec, 229 | const int* __restrict__ mat, 230 | scalar_t* __restrict__ mul, 231 | const scalar_t* __restrict__ scales, 232 | const int* __restrict__ zeros, 233 | const int* __restrict__ g_idx, 234 | int batch, 235 | int vec_height, 236 | int height, 237 | int width, 238 | int zero_width 239 | ) { 240 | int h = BLOCKHEIGHT3 * blockIdx.x; 241 | int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; 242 | 243 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 244 | int i = width * h + w; 245 | int g_h = (h / 3) * 32; 246 | int k; 247 | unsigned int g; 248 | scalar_t w_tmp; 249 | 250 | int z_w = (w / 32) * 3; 251 | int z_mod = w % 32; 252 | int z_bit; 253 | unsigned int z_tmp; 254 | if (z_mod != 10){ 255 | if (z_mod != 21){ 256 | z_bit = z_mod; 257 | if (z_bit > 21){ 258 | z_bit -= 22; 259 | z_bit *= 3; 260 | z_bit += 2; 261 | z_w += 2; 262 | } else if (z_bit > 10){ 263 | z_bit -= 11; 264 | z_bit *= 3; 265 | z_bit += 1; 266 | z_w += 1; 267 | } else { 268 | z_bit *= 3; 269 | } 270 | } else { 271 | z_w += 1; 272 | } 273 | } 274 | 275 | float weight[BLOCKWIDTH]; 276 | 277 | for (k = 0; k < BLOCKWIDTH; ++k){ 278 | int k_w = (k / 32) * 3; 279 | int k_mod = k % 32; 280 | int k_bit; 281 | 282 | if (k_mod != 10){ 283 | if (k_mod != 21){ 284 | k_bit = k_mod; 285 | if (k_bit > 21){ 286 | k_bit -= 22; 287 | k_bit *= 3; 288 | k_bit += 2; 289 | k_w += 2; 290 | } else if (k_bit > 10){ 291 | k_bit -= 11; 292 | k_bit *= 3; 293 | k_bit += 1; 294 | k_w += 1; 295 | } else { 296 | k_bit *= 3; 297 | } 298 | } else { 299 | k_w += 1; 300 | } 301 | } 302 | 303 | g = as_int(g_idx[g_h + k]); 304 | scalar_t scale = scales[g * width + w]; 305 | scalar_t zero; 306 | if (z_mod == 10) { 307 | z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); 308 | zero = scalar_t((z_tmp) + 1); 309 | } else if (z_mod == 21){ 310 | z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); 311 | zero = scalar_t((z_tmp) + 1); 312 | } else { 313 | zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); 314 | } 315 | 316 | if (k_mod == 10) { 317 | w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); 318 | } else if (k_mod == 21){ 319 | w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6); 320 | } else { 321 | w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7); 322 | } 323 | weight[k] = scale * (w_tmp - zero); 324 | } 325 | 326 | scalar_t res; 327 | for (int b = 0; b < batch; ++b){ 328 | res = 0; 329 | 330 | 
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; 331 | __syncthreads(); 332 | for (k = 0; k < BLOCKWIDTH; ++k){ 333 | res += weight[k] * blockvec[k]; 334 | } 335 | atomicAdd(&mul[b * width + w], res); 336 | __syncthreads(); 337 | } 338 | } 339 | 340 | void vecquant4matmul_cuda( 341 | torch::Tensor vec, 342 | torch::Tensor mat, 343 | torch::Tensor mul, 344 | torch::Tensor scales, 345 | torch::Tensor zeros, 346 | torch::Tensor g_idx 347 | ) { 348 | int batch = vec.size(0); 349 | int vec_height = vec.size(1); 350 | int height = mat.size(0); 351 | int width = mat.size(1); 352 | int zero_width = zeros.size(1); 353 | 354 | dim3 blocks( 355 | (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, 356 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 357 | ); 358 | dim3 threads(BLOCKWIDTH); 359 | 360 | AT_DISPATCH_FLOATING_TYPES( 361 | vec.type(), "vecquant4matmul_cuda", ([&] { 362 | VecQuant4MatMulKernel<<>>( 363 | vec.data(), mat.data(), mul.data(), 364 | scales.data(), zeros.data(), g_idx.data(), 365 | batch, vec_height, height, width, zero_width 366 | ); 367 | }) 368 | ); 369 | } 370 | 371 | template 372 | __global__ void VecQuant4MatMulKernel( 373 | const scalar_t* __restrict__ vec, 374 | const int* __restrict__ mat, 375 | scalar_t* __restrict__ mul, 376 | const scalar_t* __restrict__ scales, 377 | const int* __restrict__ zeros, 378 | const int* __restrict__ g_idx, 379 | int batch, 380 | int vec_height, 381 | int height, 382 | int width, 383 | int zero_width 384 | ) { 385 | int h = BLOCKHEIGHT4 * blockIdx.x; 386 | int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; 387 | 388 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 389 | int i = width * h + w; 390 | int g_h = h * 8; 391 | int k; 392 | unsigned int g; 393 | scalar_t w_tmp; 394 | 395 | 396 | int z_w = w / 8; 397 | int z_mod = (w % 8) * 4; 398 | 399 | float weight[BLOCKWIDTH]; 400 | 401 | for (k = 0; k < BLOCKWIDTH; ++k){ 402 | int k_w = (k / 8); 403 | int k_bit = (k % 8) * 4; 404 | 405 | g = as_int(g_idx[g_h + k]); 406 | scalar_t scale = scales[g * width + w]; 407 | scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); 408 | 409 | w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); 410 | 411 | weight[k] = scale * (w_tmp - zero); 412 | } 413 | 414 | scalar_t res; 415 | for (int b = 0; b < batch; ++b){ 416 | res = 0; 417 | 418 | blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; 419 | __syncthreads(); 420 | for (k = 0; k < BLOCKWIDTH; ++k){ 421 | res += weight[k] * blockvec[k]; 422 | } 423 | atomicAdd(&mul[b * width + w], res); 424 | __syncthreads(); 425 | } 426 | } 427 | 428 | void vecquant8matmul_cuda( 429 | torch::Tensor vec, 430 | torch::Tensor mat, 431 | torch::Tensor mul, 432 | torch::Tensor scales, 433 | torch::Tensor zeros, 434 | torch::Tensor g_idx 435 | ) { 436 | int batch = vec.size(0); 437 | int vec_height = vec.size(1); 438 | int height = mat.size(0); 439 | int width = mat.size(1); 440 | int zero_width = zeros.size(1); 441 | 442 | dim3 blocks( 443 | (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, 444 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 445 | ); 446 | dim3 threads(BLOCKWIDTH); 447 | 448 | AT_DISPATCH_FLOATING_TYPES( 449 | vec.type(), "vecquant8matmul_cuda", ([&] { 450 | VecQuant8MatMulKernel<<>>( 451 | vec.data(), mat.data(), mul.data(), 452 | scales.data(), zeros.data(), g_idx.data(), 453 | batch, vec_height, height, width, zero_width 454 | ); 455 | }) 456 | ); 457 | } 458 | 459 | template 460 | __global__ void VecQuant8MatMulKernel( 461 | 
const scalar_t* __restrict__ vec, 462 | const int* __restrict__ mat, 463 | scalar_t* __restrict__ mul, 464 | const scalar_t* __restrict__ scales, 465 | const int* __restrict__ zeros, 466 | const int* __restrict__ g_idx, 467 | int batch, 468 | int vec_height, 469 | int height, 470 | int width, 471 | int zero_width 472 | ) { 473 | int h = BLOCKHEIGHT8 * blockIdx.x; 474 | int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; 475 | 476 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 477 | int i = width * h + w; 478 | int g_h = h * 4; 479 | int k; 480 | unsigned int g; 481 | scalar_t w_tmp; 482 | 483 | int z_w = w / 4; 484 | int z_mod = (w % 4) * 8; 485 | 486 | float weight[BLOCKWIDTH]; 487 | 488 | for (k = 0; k < BLOCKWIDTH; ++k){ 489 | int k_w = (k / 4); 490 | int k_bit = (k % 4) * 8; 491 | 492 | g = as_int(g_idx[g_h + k]); 493 | scalar_t scale = scales[g * width + w]; 494 | scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); 495 | 496 | w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); 497 | 498 | weight[k] = scale * (w_tmp - zero); 499 | } 500 | 501 | scalar_t res; 502 | for (int b = 0; b < batch; ++b){ 503 | res = 0; 504 | 505 | blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; 506 | __syncthreads(); 507 | for (k = 0; k < BLOCKWIDTH; ++k){ 508 | res += weight[k] * blockvec[k]; 509 | } 510 | atomicAdd(&mul[b * width + w], res); 511 | __syncthreads(); 512 | } 513 | } 514 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | safetensors==0.3.0 2 | datasets==2.10.1 3 | sentencepiece 4 | git+https://github.com/huggingface/transformers 5 | accelerate==0.17.1 6 | -------------------------------------------------------------------------------- /santacoder.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from transformers import AutoConfig, AutoModelForCausalLM 7 | 8 | from gptq import * 9 | from modelutils import * 10 | from quant import * 11 | 12 | torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None 13 | torch.nn.init.uniform_ = lambda *args, **kwargs: None 14 | torch.nn.init.normal_ = lambda *args, **kwargs: None 15 | 16 | 17 | def get_santacoder(model, wbits): 18 | if wbits == 16: 19 | torch_dtype = torch.bfloat16 20 | else: 21 | torch_dtype = torch.float32 22 | 23 | model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch_dtype) 24 | model.seqlen = 2048 25 | return model 26 | 27 | 28 | def setup(nsamples, model, batch_iterator, dev): 29 | model.config.use_cache = False 30 | layers = model.transformer.h 31 | 32 | model.transformer.wte = model.transformer.wte.to(dev) 33 | model.transformer.wpe = model.transformer.wpe.to(dev) 34 | model.transformer.ln_f = model.transformer.ln_f.to(dev) 35 | layers[0] = layers[0].to(dev) 36 | 37 | dtype = next(iter(model.parameters())).dtype 38 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 39 | cache = {"i": 0, "attention_mask": None} 40 | 41 | class Catcher(nn.Module): 42 | def __init__(self, module): 43 | super().__init__() 44 | self.module = module 45 | 46 | def forward(self, inp, **kwargs): 47 | inps[cache["i"]] = inp 48 | cache["i"] += 1 49 | cache["attention_mask"] = kwargs["attention_mask"] 50 | raise ValueError 51 | 52 | layers[0] = 
Catcher(layers[0]) 53 | 54 | for batch in batch_iterator: 55 | try: 56 | model(batch.to(dev)) 57 | except ValueError: 58 | pass 59 | 60 | layers[0] = layers[0].module 61 | 62 | model.transformer.wte = model.transformer.wte.cpu() 63 | model.transformer.wpe = model.transformer.wpe.cpu() 64 | model.transformer.ln_f = model.transformer.ln_f.cpu() 65 | layers[0] = layers[0].cpu() 66 | 67 | torch.cuda.empty_cache() 68 | 69 | outs = torch.zeros_like(inps) 70 | attention_mask = cache["attention_mask"].to(dev) 71 | 72 | return layers, inps, outs, attention_mask 73 | 74 | 75 | @torch.no_grad() 76 | def santacoder_sequential(model, dataloader, dev, level): 77 | def get_batch_iterator(data, nsamples): 78 | for batch in data: 79 | yield batch[0] 80 | 81 | use_cache = model.config.use_cache 82 | 83 | print("Starting ...") 84 | layers, inps, outs, attention_mask = setup( 85 | args.nsamples, model, get_batch_iterator(dataloader, args.nsamples), dev 86 | ) 87 | print("Ready.") 88 | 89 | quantizers = {} 90 | for i in tqdm(range(len(layers))): 91 | print(f"layer {i}") 92 | 93 | layer = layers[i].to(dev) 94 | full = find_layers(layer) 95 | if args.true_sequential: 96 | sequential = [ 97 | ["attn.c_attn", "attn.c_proj"], 98 | ["mlp.c_fc"], 99 | ["mlp.c_proj"], 100 | ] 101 | if level >= 0: 102 | sequential = sequential[:level] 103 | else: 104 | sequential = sequential[level:] 105 | print("quantization target =", sequential) 106 | else: 107 | sequential = [list(full.keys())] 108 | 109 | for names in sequential: 110 | subset = {n: full[n] for n in names} 111 | gptq = {} 112 | for name in subset: 113 | gptq[name] = GPTQ(subset[name]) 114 | gptq[name].quantizer = Quantizer() 115 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 116 | 117 | def add_batch(name): 118 | def tmp(_, inp, out): 119 | gptq[name].add_batch(inp[0].data, out.data) 120 | 121 | return tmp 122 | 123 | handles = [] 124 | for name in subset: 125 | handles.append(subset[name].register_forward_hook(add_batch(name))) 126 | for j in range(args.nsamples): 127 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 128 | for h in handles: 129 | h.remove() 130 | 131 | for name in subset: 132 | print(f"Quantizing {name} in layer {i+1}/{len(layers)}...") 133 | scale, zero, g_idx = gptq[name].fasterquant( 134 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order 135 | ) 136 | quantizers["transformer.h.%d.%s" % (i, name)] = ( 137 | gptq[name].quantizer.cpu(), 138 | scale.cpu(), 139 | zero.cpu(), 140 | g_idx.cpu(), 141 | ) 142 | gptq[name].free() 143 | 144 | for j in range(args.nsamples): 145 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 146 | 147 | layers[i] = layer.cpu() 148 | del layer 149 | del gptq 150 | torch.cuda.empty_cache() 151 | 152 | inps, outs = outs, inps 153 | 154 | model.config.use_cache = use_cache 155 | 156 | return quantizers 157 | 158 | 159 | @torch.no_grad() 160 | def santacoder_eval(model, testenc, dev, dataset_name): 161 | def get_batch_iterator(data, nsamples): 162 | for i in range(nsamples): 163 | yield data[:, (i * model.seqlen) : ((i + 1) * model.seqlen)] 164 | 165 | print("Evaluating ...") 166 | 167 | if dataset_name != "stack": 168 | testenc = testenc.input_ids 169 | nsamples = testenc.numel() // model.seqlen 170 | 171 | use_cache = model.config.use_cache 172 | 173 | layers, inps, outs, attention_mask = setup(nsamples, model, get_batch_iterator(testenc, nsamples), dev) 174 | 175 | for i in range(len(layers)): 176 | layer = 
layers[i].to(dev) 177 | 178 | if args.nearest: 179 | subset = find_layers(layer) 180 | for name in subset: 181 | quantizer = Quantizer() 182 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 183 | W = subset[name].weight.data 184 | quantizer.find_params(W, weight=True) 185 | subset[name].weight.data = quantize(W, quantizer.scale, quantizer.zero, quantizer.maxq).to( 186 | next(iter(layer.parameters())).dtype 187 | ) 188 | 189 | for j in range(nsamples): 190 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 191 | layers[i] = layer.cpu() 192 | del layer 193 | torch.cuda.empty_cache() 194 | inps, outs = outs, inps 195 | 196 | if model.transformer.ln_f is not None: 197 | model.transformer.ln_f = model.transformer.ln_f.to(dev) 198 | model.lm_head = model.lm_head.to(dev) 199 | 200 | testenc = testenc.to(dev) 201 | nlls = [] 202 | for i in range(nsamples): 203 | hidden_states = inps[i].unsqueeze(0) 204 | if model.transformer.ln_f is not None: 205 | hidden_states = model.transformer.ln_f(hidden_states) 206 | lm_logits = model.lm_head(hidden_states) 207 | shift_logits = lm_logits[:, :-1, :].contiguous() 208 | shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] 209 | loss_fct = nn.CrossEntropyLoss() 210 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 211 | neg_log_likelihood = loss.float() * model.seqlen 212 | nlls.append(neg_log_likelihood) 213 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 214 | print(ppl.item()) 215 | 216 | model.config.use_cache = use_cache 217 | 218 | 219 | # TODO: perform packing on GPU 220 | def santacoder_pack(model, quantizers, wbits, groupsize): 221 | layers = find_layers(model) 222 | layers = {n: layers[n] for n in quantizers} 223 | make_quant(model, quantizers, wbits, groupsize) 224 | qlayers = find_layers(model, [QuantLinear]) 225 | print("Packing ...") 226 | for name in qlayers: 227 | print(name) 228 | quantizers[name], scale, zero, g_idx = quantizers[name] 229 | qlayers[name].pack(layers[name], scale, zero, g_idx) 230 | print("Done.") 231 | return model 232 | 233 | 234 | def load_quant(model, checkpoint, wbits, groupsize=-1): 235 | config = AutoConfig.from_pretrained(model) 236 | 237 | torch.set_default_dtype(torch.half) 238 | transformers.modeling_utils._init_weights = False 239 | torch.set_default_dtype(torch.half) 240 | model = AutoModelForCausalLM.from_config(config) 241 | torch.set_default_dtype(torch.float) 242 | model = model.eval() 243 | layers = find_layers(model) 244 | for name in ["lm_head"]: 245 | if name in layers: 246 | del layers[name] 247 | make_quant(model, layers, wbits, groupsize) 248 | 249 | del layers 250 | 251 | print("Loading model ...") 252 | if checkpoint.endswith(".safetensors"): 253 | from safetensors.torch import load_file as safe_load 254 | 255 | model.load_state_dict(safe_load(checkpoint), strict=False) 256 | else: 257 | model.load_state_dict(torch.load(checkpoint), strict=False) 258 | model.seqlen = 2048 259 | print("Done.") 260 | 261 | return model 262 | 263 | 264 | def benchmark(model, input_ids, check=False): 265 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, "gpus") else DEV) 266 | torch.cuda.synchronize() 267 | 268 | cache = {"past": None} 269 | 270 | def clear_past(i): 271 | def tmp(layer, inp, out): 272 | if cache["past"]: 273 | cache["past"][i] = None 274 | 275 | return tmp 276 | 277 | for i, layer in enumerate(model.transformer.h): 278 | layer.register_forward_hook(clear_past(i)) 
279 | 280 | print("Benchmarking ...") 281 | 282 | if check: 283 | loss = nn.CrossEntropyLoss() 284 | tot = 0.0 285 | 286 | def sync(): 287 | if hasattr(model, "gpus"): 288 | for gpu in model.gpus: 289 | torch.cuda.synchronize(gpu) 290 | else: 291 | torch.cuda.synchronize() 292 | 293 | max_memory = 0 294 | with torch.no_grad(): 295 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 296 | times = [] 297 | for i in range(input_ids.numel()): 298 | tick = time.time() 299 | out = model( 300 | input_ids[:, i : i + 1], 301 | past_key_values=cache["past"], 302 | attention_mask=attention_mask[:, : (i + 1)].reshape((1, -1)), 303 | ) 304 | sync() 305 | times.append(time.time() - tick) 306 | print(i, times[-1]) 307 | max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024) 308 | if check and i != input_ids.numel() - 1: 309 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 310 | cache["past"] = list(out.past_key_values) 311 | del out 312 | sync() 313 | import numpy as np 314 | 315 | print("Median:", np.median(times)) 316 | if check: 317 | print("PPL:", torch.exp(tot / (input_ids.numel() - 1)).item()) 318 | print("max memory(MiB):", max_memory) 319 | 320 | 321 | if __name__ == "__main__": 322 | import argparse 323 | 324 | from datautils import * 325 | 326 | parser = argparse.ArgumentParser() 327 | 328 | parser.add_argument("model", type=str, help="model to load") 329 | parser.add_argument( 330 | "dataset", 331 | type=str, 332 | choices=["wikitext2", "ptb", "c4", "stack"], 333 | help="Where to extract calibration data from.", 334 | ) 335 | parser.add_argument("--seed", type=int, default=0, help="Seed for sampling the calibration data.") 336 | parser.add_argument("--nsamples", type=int, default=128, help="Number of calibration data samples.") 337 | parser.add_argument( 338 | "--percdamp", type=float, default=0.01, help="Percent of the average Hessian diagonal to use for dampening." 339 | ) 340 | parser.add_argument("--nearest", action="store_true", help="Whether to run the RTN baseline.") 341 | parser.add_argument( 342 | "--wbits", 343 | type=int, 344 | default=32, 345 | choices=[2, 3, 4, 8, 16, 32], 346 | help="#bits to use for quantization; use 16 for evaluating base model.", 347 | ) 348 | parser.add_argument("--trits", action="store_true", help="Whether to use trits for quantization.") 349 | parser.add_argument( 350 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 351 | ) 352 | parser.add_argument("--eval", action="store_true", help="evaluate quantized model.") 353 | parser.add_argument("--save", type=str, default="", help="Save quantized checkpoint under this name.") 354 | parser.add_argument( 355 | "--save_safetensors", type=str, default="", help="Save quantized `.safetensors` checkpoint under this name." 356 | ) 357 | parser.add_argument("--load", type=str, default="", help="Load quantized model.") 358 | parser.add_argument("--benchmark", type=int, default=0, help="Number of tokens to use for benchmarking.") 359 | parser.add_argument( 360 | "--check", action="store_true", help="Whether to compute perplexity during benchmarking for verification." 
361 | ) 362 | parser.add_argument("--sym", action="store_true", help="Whether to perform symmetric quantization.") 363 | parser.add_argument( 364 | "--act-order", action="store_true", help="Whether to apply the activation order GPTQ heuristic." 365 | ) 366 | parser.add_argument("--true-sequential", action="store_true", help="Whether to run in true sequential mode.") 367 | parser.add_argument( 368 | "--optimization-level", 369 | type=int, 370 | choices=[-2, -1, 1, 2, 3], 371 | help="With --true-sequential, which sub-layer groups to quantize: positive N quantizes only the first N groups (attention first), negative N only the last |N| groups (MLP last).", 372 | ) 373 | parser.add_argument("--new-eval", action="store_true", help="Whether to use the new PTB and C4 eval.") 374 | 375 | args = parser.parse_args() 376 | 377 | if type(args.load) is not str: 378 | args.load = args.load.as_posix() 379 | 380 | if args.load: 381 | model = load_quant(args.model, args.load, args.wbits, args.groupsize) 382 | else: 383 | model = get_santacoder(args.model, args.wbits) 384 | model.eval() 385 | 386 | dataloader, testloader = get_loaders( 387 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 388 | ) 389 | 390 | if not args.load and args.wbits < 16 and not args.nearest: 391 | tick = time.time() 392 | quantizers = santacoder_sequential(model, dataloader, DEV, args.optimization_level) 393 | print(time.time() - tick) 394 | santacoder_pack(model, quantizers, args.wbits, args.groupsize) 395 | 396 | if args.benchmark: 397 | model = model.to(DEV) 398 | if args.benchmark: 399 | input_ids = next(iter(dataloader))[0][:, : args.benchmark] 400 | benchmark(model, input_ids, check=args.check) 401 | 402 | if args.eval: 403 | datasets = ["wikitext2", "ptb", "c4", "stack"] 404 | if args.new_eval: 405 | datasets = ["wikitext2", "ptb-new", "c4-new"] 406 | for dataset in datasets: 407 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) 408 | print(dataset) 409 | santacoder_eval(model, testloader, DEV, dataset) 410 | 411 | if args.save: 412 | torch.save(model.state_dict(), args.save) 413 | 414 | if args.save_safetensors: 415 | from safetensors.torch import save_file as safe_save 416 | 417 | safe_save(model.state_dict(), args.save_safetensors) 418 | -------------------------------------------------------------------------------- /santacoder_inference.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | 4 | import termcolor 5 | import torch 6 | import transformers 7 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 8 | 9 | from gptq import * 10 | from modelutils import * 11 | from quant import * 12 | 13 | 14 | def disable_torch_init(): 15 | def noop(*args, **kwargs): 16 | pass 17 | 18 | torch.nn.init.kaiming_uniform_ = noop 19 | torch.nn.init.uniform_ = noop 20 | torch.nn.init.normal_ = noop 21 | transformers.modeling_utils._init_weights = False 22 | 23 | 24 | def get_santacoder(model, checkpoint, wbits, groupsize): 25 | if wbits == 16: 26 | torch_dtype = torch.bfloat16 27 | else: 28 | torch_dtype = torch.float32 29 | 30 | if checkpoint is None: 31 | # Load the full model with weights 32 | model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch_dtype) 33 | else: 34 | # Load only the model structure, without weights 35 | config = AutoConfig.from_pretrained(model) 36 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype) 37 | model = model.eval() 38 | 39 | if wbits < 16: 40 | layers = find_layers(model) 41 | for name in ["lm_head"]: 42 | if
name in layers: 43 | del layers[name] 44 | make_quant(model, layers, wbits, groupsize) 45 | 46 | model.load_state_dict(torch.load(checkpoint)) 47 | 48 | model.seqlen = 2048 49 | model = model.cuda() 50 | return model 51 | 52 | 53 | def simple_generation_test(tokenizer, model, prompt): 54 | batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) 55 | batch = {k: v.cuda() for k, v in batch.items()} 56 | 57 | for _ in range(2): 58 | print("generating...") 59 | t1 = time.time() 60 | generated = model.generate(batch["input_ids"], do_sample=False, min_new_tokens=100, max_new_tokens=100) 61 | t2 = time.time() 62 | print(termcolor.colored(tokenizer.decode(generated[0]), "yellow")) 63 | print("generated in %0.2fms" % ((t2 - t1) * 1000)) 64 | 65 | print("prompt tokens", len(batch["input_ids"][0])) 66 | print("all tokens", len(generated[0])) 67 | 68 | generated_tokens = len(generated[0]) - len(batch["input_ids"][0]) 69 | print("%0.1fms per token" % (((t2 - t1) * 1000) / generated_tokens)) 70 | 71 | 72 | def main(): 73 | parser = ArgumentParser() 74 | parser.add_argument("model", type=str, help="model to load, such as bigcode/gpt_bigcode-santacoder") 75 | parser.add_argument("--load", type=str, help="load a quantized checkpoint, use normal model if not specified") 76 | parser.add_argument("--wbits", type=int, default=16, help="bits in quantization checkpoint") 77 | parser.add_argument( 78 | "--groupsize", type=int, default=-1, help="Groupsize to use for quantization; default uses full row." 79 | ) 80 | parser.add_argument("--prompt", type=str, default="pygame example\n\n```", help="prompt the model") 81 | args = parser.parse_args() 82 | 83 | disable_torch_init() 84 | 85 | t1 = time.time() 86 | model = get_santacoder(args.model, args.load, args.wbits, args.groupsize) 87 | t2 = time.time() 88 | print("model load time %0.1fms" % ((t2 - t1) * 1000)) 89 | 90 | tokenizer = AutoTokenizer.from_pretrained(args.model) 91 | 92 | simple_generation_test(tokenizer, model, args.prompt) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /scripts/santacoder-16bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | bits=16 4 | 5 | # evaluate perplexity of bf16 6 | python santacoder.py bigcode/gpt_bigcode-santacoder $dataset --nsamples $nsamples --eval --wbits $bits 7 | -------------------------------------------------------------------------------- /scripts/santacoder-32bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | bits=32 4 | 5 | # evaluate perplexity of fp32 6 | python santacoder.py bigcode/gpt_bigcode-santacoder $dataset --nsamples $nsamples --eval --wbits $bits 7 | -------------------------------------------------------------------------------- /scripts/santacoder-4bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | opt_level=3 4 | groupsize=-1 5 | bits=4 6 | 7 | mkdir -p models/$bits-bit 8 | 9 | # remove --eval if you dont want to evaluate perplexity of the model 10 | python santacoder.py bigcode/gpt_bigcode-santacoder $dataset --nsamples $nsamples --eval --wbits $bits --act-order --groupsize $groupsize --optimization-level $opt_level --true-sequential --save models/$bits-bit/model.pt 11 | -------------------------------------------------------------------------------- 
/scripts/santacoder-8bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | opt_level=3 4 | groupsize=-1 5 | bits=8 6 | 7 | mkdir -p models/$bits-bit 8 | 9 | # remove --eval if you dont want to evaluate perplexity of the model 10 | python santacoder.py bigcode/gpt_bigcode-santacoder $dataset --nsamples $nsamples --eval --wbits $bits --act-order --groupsize $groupsize --optimization-level $opt_level --true-sequential --save models/$bits-bit/model.pt 11 | -------------------------------------------------------------------------------- /scripts/starcoder-16bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | bits=16 4 | 5 | # evaluate perplexity of bf16 6 | python santacoder.py bigcode/starcoder $dataset --nsamples $nsamples --eval --wbits $bits 7 | -------------------------------------------------------------------------------- /scripts/starcoder-32bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | bits=32 4 | 5 | # evaluate perplexity of fp32 6 | python santacoder.py bigcode/starcoder $dataset --nsamples $nsamples --eval --wbits $bits 7 | -------------------------------------------------------------------------------- /scripts/starcoder-4bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=150 3 | opt_level=3 4 | groupsize=128 5 | bits=4 6 | 7 | mkdir -p models/$bits-bit 8 | 9 | # remove --eval if you dont want to evaluate perplexity of the model 10 | python santacoder.py bigcode/starcoder $dataset --nsamples $nsamples --eval --wbits $bits --act-order --groupsize $groupsize --optimization-level $opt_level --true-sequential --save models/$bits-bit/model.pt 11 | -------------------------------------------------------------------------------- /scripts/starcoder-8bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | opt_level=3 4 | groupsize=128 5 | bits=8 6 | 7 | mkdir -p models/$bits-bit 8 | 9 | # remove --eval if you dont want to evaluate perplexity of the model 10 | python santacoder.py bigcode/starcoder $dataset --nsamples $nsamples --eval --wbits $bits --act-order --groupsize $groupsize --optimization-level $opt_level --true-sequential --save models/$bits-bit/model.pt 11 | -------------------------------------------------------------------------------- /scripts/starcoderbase-16bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | bits=16 4 | 5 | # evaluate perplexity of bf16 6 | python santacoder.py bigcode/starcoderbase $dataset --nsamples $nsamples --eval --wbits $bits 7 | -------------------------------------------------------------------------------- /scripts/starcoderbase-32bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | bits=32 4 | 5 | # evaluate perplexity of fp32 6 | python santacoder.py bigcode/starcoderbase $dataset --nsamples $nsamples --eval --wbits $bits 7 | -------------------------------------------------------------------------------- /scripts/starcoderbase-4bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=150 3 | opt_level=3 4 | groupsize=128 5 | bits=4 6 | 7 | mkdir -p models/$bits-bit 8 | 9 
| # remove --eval if you dont want to evaluate perplexity of the model 10 | python santacoder.py bigcode/starcoderbase $dataset --nsamples $nsamples --eval --wbits $bits --act-order --groupsize $groupsize --optimization-level $opt_level --true-sequential --save models/$bits-bit/model.pt 11 | -------------------------------------------------------------------------------- /scripts/starcoderbase-8bit.sh: -------------------------------------------------------------------------------- 1 | dataset=stack 2 | nsamples=128 3 | opt_level=3 4 | groupsize=128 5 | bits=8 6 | 7 | mkdir -p models/$bits-bit 8 | 9 | # remove --eval if you dont want to evaluate perplexity of the model 10 | python santacoder.py bigcode/starcoderbase $dataset --nsamples $nsamples --eval --wbits $bits --act-order --groupsize $groupsize --optimization-level $opt_level --true-sequential --save models/$bits-bit/model.pt 11 | -------------------------------------------------------------------------------- /setup_cuda.py: -------------------------------------------------------------------------------- 1 | from setuptools import Extension, setup 2 | from torch.utils import cpp_extension 3 | 4 | setup( 5 | name="quant_cuda", 6 | ext_modules=[cpp_extension.CUDAExtension("quant_cuda", ["quant_cuda.cpp", "quant_cuda_kernel.cu"])], 7 | cmdclass={"build_ext": cpp_extension.BuildExtension}, 8 | ) 9 | -------------------------------------------------------------------------------- /share_tensors_across_processes.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import pickle 4 | 5 | import accelerate 6 | import torch 7 | import tqdm 8 | 9 | accelerate_load_checkpoint_and_dispatch = accelerate.load_checkpoint_and_dispatch 10 | 11 | 12 | def load_checkpoint_shared_and_dispatch( 13 | model, checkpoint, device_map="auto", max_memory=None, no_split_module_classes=None 14 | ): # **kwparams): 15 | try: 16 | with open(f"{checkpoint}.shared", "rb") as shared: 17 | shared = pickle.load(shared) 18 | with tqdm.tqdm( 19 | shared["state_dict"].items(), 20 | desc="Importing shared tensors ...", 21 | unit="w", 22 | leave=False, 23 | ) as shared_tensors: 24 | os.kill(shared["pid"], 0) # raises if pid does not exist 25 | device_map = shared["device_map"] 26 | # offload_dir = shared['offload_folder'] 27 | # offload_buffers = shared['offload_buyffers'] 28 | # preload_module_classes = shared['preload_module_classes'] 29 | state_dict = { 30 | name: rebuild_tensor(*tensor_params) for name, (rebuild_tensor, tensor_params) in shared_tensors 31 | } 32 | model.load_state_dict(state_dict, strict=False) 33 | for param_name, param in state_dict.items(): 34 | module_name = param_name 35 | 36 | while len(module_name) > 0 and module_name not in device_map: 37 | module_name = ".".join(module_name.split(".")[:-1]) 38 | param_device = device_map[module_name] 39 | 40 | # if param_device == 'disk': 41 | # 42 | accelerate.utils.modeling.set_module_tensor_to_device( 43 | model, param_name, param_device, value=param 44 | ) # , **kwparams) 45 | except (FileNotFoundError, EOFError, KeyError, ProcessLookupError, RuntimeError): 46 | if device_map != "sequential": 47 | max_memory = accelerate.utils.get_balanced_memory( 48 | model, 49 | max_memory=max_memory, 50 | no_split_module_classes=no_split_module_classes, 51 | low_zero=(device_map == "balanced_low_0"), 52 | # **kwparams, 53 | ) 54 | if isinstance(device_map, str): 55 | device_map = accelerate.infer_auto_device_map( 56 | model, max_memory=max_memory, 
no_split_module_classes=no_split_module_classes # , **kwparams 57 | ) 58 | # if not kwparams.get('offload_state_dict') and device_map is not None and 'disk' in device_map.values(): 59 | # offload_state_dict = True 60 | accelerate.load_checkpoint_in_model(model, checkpoint, device_map=device_map) # , **kwparams) 61 | state_dict = model.state_dict() 62 | with open(f"{checkpoint}.shared", "wb") as shared, tqdm.tqdm( 63 | state_dict.items(), 64 | desc="Exporting shared tensors ...", 65 | unit="w", 66 | ) as state_dict_items: 67 | pickle.dump( 68 | { 69 | "pid": os.getpid(), 70 | "device_map": device_map, 71 | "state_dict": { 72 | name: torch.multiprocessing.reductions.reduce_tensor(tensor.share_memory_()) 73 | for name, tensor in state_dict_items 74 | }, 75 | }, 76 | shared, 77 | ) 78 | 79 | del state_dict 80 | gc.collect() 81 | 82 | if device_map is not None: 83 | model = accelerate.big_modeling.dispatch_model(model, device_map=device_map) # , **kwparams) 84 | return model 85 | 86 | 87 | accelerate.load_checkpoint_and_dispatch = load_checkpoint_shared_and_dispatch 88 | -------------------------------------------------------------------------------- /test_kernel.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import quant_cuda 4 | import torch 5 | import torch.nn as nn 6 | 7 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 8 | 9 | torch.backends.cuda.matmul.allow_tf32 = False 10 | torch.backends.cudnn.allow_tf32 = False 11 | 12 | print("Benchmarking LLaMa-7B FC2 matvec ...") 13 | 14 | DEV = torch.device("cuda:0") 15 | 16 | B = 4 17 | L = 512 18 | M = 4096 19 | N = 11008 20 | 21 | DTYPE = torch.half 22 | mat = torch.randn((M, N), device=DEV, dtype=DTYPE) 23 | vec = torch.randn((B, M), device=DEV, dtype=DTYPE) 24 | mul = torch.zeros((B, N), device=DEV, dtype=DTYPE) 25 | 26 | COUNT = 1000 27 | import time 28 | 29 | tick = time.time() 30 | for _ in range(COUNT): 31 | torch.matmul(vec, mat, out=mul) 32 | torch.cuda.synchronize() 33 | print("FP16:", (time.time() - tick) / COUNT) 34 | 35 | DTYPE = torch.float 36 | mat = mat.to(DTYPE) 37 | vec = vec.to(DTYPE) 38 | mul = mul.to(DTYPE) 39 | 40 | mat = torch.randint(-1000000000, 1000000000, (M // 32 * 2, N), device=DEV, dtype=torch.int) 41 | scales = torch.randn(N, device=DEV, dtype=DTYPE) 42 | zeros = torch.randint(-1000000000, 1000000000, (1, N // 32 * 2), device=DEV, dtype=torch.int32) 43 | g_idx = torch.zeros(M, device=DEV, dtype=torch.int32) 44 | COUNT = 1000 45 | import time 46 | 47 | vec = vec.float() 48 | tick = time.time() 49 | for _ in range(COUNT): 50 | quant_cuda.vecquant2matmul(vec, mat, mul, scales, zeros, g_idx) 51 | torch.cuda.synchronize() 52 | print("2bit:", (time.time() - tick) / COUNT) 53 | 54 | mat = torch.randint(-1000000000, 1000000000, (M // 32 * 3, N), device=DEV, dtype=torch.int) 55 | scales = torch.randn(N, device=DEV, dtype=DTYPE) 56 | zeros = torch.randint(-1000000000, 1000000000, (1, N // 32 * 3), device=DEV, dtype=torch.int32) 57 | 58 | vec = vec.float() 59 | tick = time.time() 60 | for _ in range(COUNT): 61 | quant_cuda.vecquant3matmul(vec, mat, mul, scales, zeros, g_idx) 62 | torch.cuda.synchronize() 63 | print("3bit:", (time.time() - tick) / COUNT) 64 | 65 | mat = torch.randint(-1000000000, 1000000000, (M // 32 * 4, N), device=DEV, dtype=torch.int) 66 | scales = torch.randn(N, device=DEV, dtype=DTYPE) 67 | zeros = torch.randint(-1000000000, 1000000000, (1, N // 32 * 4), device=DEV, dtype=torch.int32) 68 | 69 | vec = vec.float() 70 | tick = time.time() 71 | for _ in 
range(COUNT): 72 | quant_cuda.vecquant4matmul(vec, mat, mul, scales, zeros, g_idx) 73 | torch.cuda.synchronize() 74 | print("4bit:", (time.time() - tick) / COUNT) 75 | 76 | mat = torch.randint(-1000000000, 1000000000, (M // 32 * 8, N), device=DEV, dtype=torch.int) 77 | scales = torch.randn(N, device=DEV, dtype=DTYPE) 78 | zeros = torch.randint(-1000000000, 1000000000, (1, N // 32 * 8), device=DEV, dtype=torch.int32) 79 | 80 | vec = vec.float() 81 | tick = time.time() 82 | for _ in range(COUNT): 83 | quant_cuda.vecquant8matmul(vec, mat, mul, scales, zeros, g_idx) 84 | torch.cuda.synchronize() 85 | print("8bit:", (time.time() - tick) / COUNT) 86 | print("Verifying kernel correctness ...") 87 | 88 | M = 4096 89 | N = 11008 90 | 91 | from quant import * 92 | 93 | layer = nn.Linear(M, N) 94 | vec = torch.randn(B, L, M).to(DEV).half() 95 | 96 | quantizer = Quantizer() 97 | quantizer.configure(2, perchannel=True, sym=False, mse=False) 98 | quantizer.find_params(layer.weight.data, weight=True) 99 | layer.weight.data = quantize(layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq) 100 | 101 | qlayer = QuantLinear( 102 | 2, -1, layer.in_features, layer.out_features, layer.bias is not None, kernel_switch_threshold=False, is_cuda=True 103 | ) 104 | qlayer.pack(layer, quantizer.scale, quantizer.zero) 105 | 106 | qlayer = qlayer.to(DEV) 107 | layer = layer.to(DEV).half() 108 | 109 | with torch.no_grad(): 110 | print("2bit Simu:", layer(vec)) 111 | print("2bit Kern:", qlayer(vec)) 112 | 113 | layer = nn.Linear(M, N) 114 | vec = torch.randn(B, L, M).to(DEV).half() 115 | 116 | quantizer = Quantizer() 117 | quantizer.configure(3, perchannel=True, sym=False, mse=False) 118 | quantizer.find_params(layer.weight.data, weight=True) 119 | layer.weight.data = quantize(layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq) 120 | 121 | qlayer = QuantLinear( 122 | 3, -1, layer.in_features, layer.out_features, layer.bias is not None, kernel_switch_threshold=False, is_cuda=True 123 | ) 124 | qlayer.pack(layer, quantizer.scale, quantizer.zero) 125 | 126 | qlayer = qlayer.to(DEV) 127 | layer = layer.to(DEV).half() 128 | 129 | with torch.no_grad(): 130 | print("3bit Simu:", layer(vec)) 131 | print("3bit Kern:", qlayer(vec)) 132 | 133 | layer = nn.Linear(M, N) 134 | vec = torch.randn(B, L, M).to(DEV).half() 135 | 136 | quantizer = Quantizer() 137 | quantizer.configure(4, perchannel=True, sym=False, mse=False) 138 | quantizer.find_params(layer.weight.data, weight=True) 139 | layer.weight.data = quantize(layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq) 140 | 141 | qlayer = QuantLinear( 142 | 4, -1, layer.in_features, layer.out_features, layer.bias is not None, kernel_switch_threshold=False, is_cuda=True 143 | ) 144 | qlayer.pack(layer, quantizer.scale, quantizer.zero) 145 | 146 | qlayer = qlayer.to(DEV) 147 | layer = layer.to(DEV).half() 148 | 149 | with torch.no_grad(): 150 | print("4bit Simu:", layer(vec)) 151 | print("4bit Kern:", qlayer(vec)) 152 | 153 | layer = nn.Linear(M, N) 154 | vec = torch.randn(B, L, M).to(DEV).half() 155 | 156 | quantizer = Quantizer() 157 | quantizer.configure(8, perchannel=True, sym=False, mse=False) 158 | quantizer.find_params(layer.weight.data, weight=True) 159 | layer.weight.data = quantize(layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq) 160 | 161 | qlayer = QuantLinear( 162 | 8, -1, layer.in_features, layer.out_features, layer.bias is not None, kernel_switch_threshold=False, is_cuda=True 163 | ) 164 | qlayer.pack(layer,
quantizer.scale, quantizer.zero) 165 | 166 | qlayer = qlayer.to(DEV) 167 | layer = layer.to(DEV).half() 168 | 169 | with torch.no_grad(): 170 | print("8bit Simu:", layer(vec)) 171 | print("8bit Kern:", qlayer(vec)) 172 | --------------------------------------------------------------------------------