├── .gitignore ├── assets └── LeanLlama.jpg ├── leanquant ├── __init__.py └── leanquant_utils.py ├── requirements.txt ├── modelutils.py ├── setup.py ├── eval_quantized.py ├── quant.py ├── datautils.py ├── lean_quantizer.py ├── llama.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.log 3 | *.safetensors 4 | *egg-info* 5 | -------------------------------------------------------------------------------- /assets/LeanLlama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanModels/LeanQuant/HEAD/assets/LeanLlama.jpg -------------------------------------------------------------------------------- /leanquant/__init__.py: -------------------------------------------------------------------------------- 1 | from .leanquant_utils import LeanQuantModelForCausalLM, Sub4BitLinear, replace_with_quantizers -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.47.0 2 | numpy==1.26.4 3 | scikit-learn==1.5.2 4 | accelerate==0.34.2 5 | lm-eval==0.4.4 6 | safetensors==0.4.5 7 | tqdm==4.66.5 -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | DEV = torch.device('cuda:0') 6 | 7 | 8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 9 | if type(module) in layers: 10 | return {name: module} 11 | res = {} 12 | for name1, child in module.named_children(): 13 | res.update(find_layers( 14 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 15 | )) 16 | return res 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="leanquant", 5 | version="0.1.1", 6 | author="Tianyi Zhang", 7 | author_email="tonyzhang617@gmail.com", 8 | description="The inference kernels for LeanQuant models.", 9 | packages=find_packages(), 10 | install_requires=[ 11 | "transformers>=4.38", 12 | "accelerate", 13 | "safetensors", 14 | "torch", 15 | ], 16 | extras_require={ 17 | "cuda11": [ 18 | "cupy-cuda11x", 19 | ], 20 | "cuda12": [ 21 | "cupy-cuda12x", 22 | ], 23 | }, 24 | ) 25 | -------------------------------------------------------------------------------- /eval_quantized.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from transformers import AutoTokenizer 5 | import lm_eval 6 | 7 | from leanquant import LeanQuantModelForCausalLM 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | 12 | parser.add_argument("--base_model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf") 13 | parser.add_argument("--leanquant_path", type=str, default="models/llama-2-7b_b3_e3_d0.1.safetensors") 14 | parser.add_argument("--bits", type=int, default=4) 15 | parser.add_argument("--use_bf16", action="store_true") 16 | parser.add_argument("--tasks", nargs='+', type=str, default=["mmlu"]) 17 | parser.add_argument("--eval_batch_size", type=int, default=4) 18 | args = parser.parse_args() 19 | print(args) 20 | 21 | tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path) 22 | model = LeanQuantModelForCausalLM.from_pretrained( 23 | args.base_model_name_or_path, args.leanquant_path, 24 | args.bits, torch_dtype=torch.bfloat16 if args.use_bf16 else torch.float16, 25 | device_map="auto", 26 | ) 27 | 28 | model_eval = lm_eval.models.huggingface.HFLM(model, tokenizer=tokenizer, device=model.device, batch_size=args.eval_batch_size, trust_remote_code=True) 29 | results = lm_eval.simple_evaluate(model=model_eval, tasks=args.tasks) 30 | print(results['results']) 31 | -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def quantize(x, scale, zero, maxq): 7 | if maxq < 0: 8 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 9 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 10 | return scale * (q - zero) 11 | 12 | class Quantizer(nn.Module): 13 | 14 | def __init__(self, shape=1): 15 | super(Quantizer, self).__init__() 16 | self.register_buffer('maxq', torch.tensor(0)) 17 | self.register_buffer('scale', torch.zeros(shape)) 18 | self.register_buffer('zero', torch.zeros(shape)) 19 | 20 | def configure( 21 | self, 22 | bits, perchannel=False, sym=True, 23 | mse=False, norm=2.4, grid=100, maxshrink=.8, 24 | trits=False 25 | ): 26 | self.maxq = torch.tensor(2 ** bits - 1) 27 | self.perchannel = perchannel 28 | self.sym = sym 29 | self.mse = mse 30 | self.norm = norm 31 | self.grid = grid 32 | self.maxshrink = maxshrink 33 | if trits: 34 | self.maxq = torch.tensor(-1) 35 | 36 | def find_params(self, x, weight=False): 37 | dev = x.device 38 | self.maxq = self.maxq.to(dev) 39 | 40 | shape = x.shape 41 | if 
self.perchannel: 42 | if weight: 43 | x = x.flatten(1) 44 | else: 45 | if len(shape) == 4: 46 | x = x.permute([1, 0, 2, 3]) 47 | x = x.flatten(1) 48 | if len(shape) == 3: 49 | x = x.reshape((-1, shape[-1])).t() 50 | if len(shape) == 2: 51 | x = x.t() 52 | else: 53 | x = x.flatten().unsqueeze(0) 54 | 55 | tmp = torch.zeros(x.shape[0], device=dev) 56 | xmin = torch.minimum(x.min(1)[0], tmp) 57 | xmax = torch.maximum(x.max(1)[0], tmp) 58 | 59 | if self.sym: 60 | xmax = torch.maximum(torch.abs(xmin), xmax) 61 | tmp = xmin < 0 62 | if torch.any(tmp): 63 | xmin[tmp] = -xmax[tmp] 64 | tmp = (xmin == 0) & (xmax == 0) 65 | xmin[tmp] = -1 66 | xmax[tmp] = +1 67 | 68 | if self.maxq < 0: 69 | self.scale = xmax 70 | self.zero = xmin 71 | else: 72 | self.scale = (xmax - xmin) / self.maxq 73 | if self.sym: 74 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 75 | else: 76 | self.zero = torch.round(-xmin / self.scale) 77 | 78 | if self.mse: 79 | best = torch.full([x.shape[0]], float('inf'), device=dev) 80 | for i in range(int(self.maxshrink * self.grid)): 81 | p = 1 - i / self.grid 82 | xmin1 = p * xmin 83 | xmax1 = p * xmax 84 | scale1 = (xmax1 - xmin1) / self.maxq 85 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 86 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 87 | q -= x 88 | q.abs_() 89 | q.pow_(self.norm) 90 | err = torch.sum(q, 1) 91 | tmp = err < best 92 | if torch.any(tmp): 93 | best[tmp] = err[tmp] 94 | self.scale[tmp] = scale1[tmp] 95 | self.zero[tmp] = zero1[tmp] 96 | if not self.perchannel: 97 | if weight: 98 | tmp = shape[0] 99 | else: 100 | tmp = shape[1] if len(shape) != 3 else shape[2] 101 | self.scale = self.scale.repeat(tmp) 102 | self.zero = self.zero.repeat(tmp) 103 | 104 | if weight: 105 | shape = [-1] + [1] * (len(shape) - 1) 106 | self.scale = self.scale.reshape(shape) 107 | self.zero = self.zero.reshape(shape) 108 | return 109 | if len(shape) == 4: 110 | self.scale = self.scale.reshape((1, -1, 1, 1)) 111 | self.zero = self.zero.reshape((1, -1, 1, 1)) 112 | if len(shape) == 3: 113 | self.scale = self.scale.reshape((1, 1, -1)) 114 | self.zero = self.zero.reshape((1, 1, -1)) 115 | if len(shape) == 2: 116 | self.scale = self.scale.unsqueeze(0) 117 | self.zero = self.zero.unsqueeze(0) 118 | 119 | def quantize(self, x): 120 | if self.ready(): 121 | return quantize(x, self.scale, self.zero, self.maxq) 122 | return x 123 | 124 | def enabled(self): 125 | return self.maxq > 0 126 | 127 | def ready(self): 128 | return torch.all(self.scale != 0) 129 | -------------------------------------------------------------------------------- /leanquant/leanquant_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union, Dict 3 | 4 | import torch 5 | from torch import nn 6 | import cupy as cp 7 | 8 | from transformers.modeling_utils import no_init_weights 9 | from transformers import AutoConfig, AutoModelForCausalLM 10 | from accelerate import infer_auto_device_map, dispatch_model 11 | from accelerate.utils import get_balanced_memory 12 | from safetensors import safe_open 13 | 14 | kernel_code = '''typedef unsigned char uint8_t; 15 | typedef unsigned short uint16_t; 16 | 17 | extern "C" 18 | __global__ void gather_sub4bit( 19 | const uint16_t* __restrict__ src, // nrows x 2^bits 20 | const uint8_t* __restrict__ codes, // nrows x (ncols / 2) 21 | uint16_t* __restrict__ dst, // nrows x ncols 22 | int bits, int ncols 23 | ) { 24 | extern __shared__ volatile 
uint16_t cache[]; // 2^bits 25 | 26 | const int row_id = blockIdx.x; 27 | const int thread_id = threadIdx.x; 28 | const int n_threads = blockDim.x; 29 | 30 | const int n_floats = 1 << bits; 31 | 32 | #pragma unroll 33 | for (int i = thread_id; i < n_floats; i += n_threads) { 34 | cache[i] = src[row_id * n_floats + i]; 35 | } 36 | __syncthreads(); 37 | 38 | for (int i = thread_id; i < ncols / 2; i += n_threads) { 39 | uint8_t code = codes[row_id * ncols / 2 + i]; 40 | dst[row_id * ncols + i * 2] = cache[code >> 4]; 41 | dst[row_id * ncols + i * 2 + 1] = cache[code & 0xf]; 42 | } 43 | }''' 44 | _gather_sub4bit = cp.RawKernel( 45 | kernel_code, 46 | 'gather_sub4bit' 47 | ) 48 | 49 | class Sub4BitLinear(nn.Module): 50 | def __init__(self, orig_weight, bits=4, quant_grid=None, weight_codes=None, dtype=torch.float16): 51 | super().__init__() 52 | 53 | if isinstance(orig_weight, torch.Tensor): 54 | rows, cols = orig_weight.shape 55 | self.quant_grid = nn.Parameter(torch.empty(rows, 2 ** bits, dtype=orig_weight.dtype)) 56 | weight_codes = torch.empty(rows, cols // 2, dtype=torch.uint8) 57 | self.register_buffer('weight_codes', weight_codes) 58 | elif isinstance(quant_grid, torch.Tensor) and isinstance(weight_codes, torch.Tensor): 59 | assert dtype == torch.float16 or dtype == torch.bfloat16 60 | assert weight_codes.dtype == torch.uint8 61 | 62 | quant_grid = quant_grid.squeeze() 63 | weight_codes = weight_codes.squeeze() 64 | 65 | self.quant_grid = nn.Parameter(quant_grid.to(dtype)) 66 | self.register_buffer('weight_codes', weight_codes) 67 | 68 | assert self.quant_grid.shape[0] == self.weight_codes.shape[0] 69 | bits = int(math.log2(self.quant_grid.shape[1])) 70 | assert bits <= 4 71 | assert (2 ** bits) == self.quant_grid.shape[1] 72 | else: 73 | assert False, "This function can be initialized using either an `orig_weight` tensor or a pair of `quant_grid` and `weight_codes` tensors." 74 | 75 | def forward(self, x): 76 | assert x.device.index == self.quant_grid.device.index 77 | 78 | rows, cols_div2 = self.weight_codes.shape 79 | cols = cols_div2 * 2 80 | bits = int(math.log2(self.quant_grid.shape[1])) 81 | 82 | W = torch.empty(rows, cols, dtype=self.quant_grid.dtype, device=x.device) 83 | blocks_per_grid = (rows, ) 84 | threads_per_block = (512, ) 85 | 86 | with cp.cuda.Device(x.device.index): 87 | _gather_sub4bit(grid=blocks_per_grid, block=threads_per_block, shared_mem=2 ** bits * 2, args=[ 88 | self.quant_grid.data_ptr(), self.weight_codes.data_ptr(), W.data_ptr(), bits, cols, 89 | ]) 90 | 91 | return torch.matmul(x, W.t()) 92 | 93 | def replace_with_quantizers(module, quantizers, name=''): 94 | if isinstance(module, Sub4BitLinear): 95 | return 96 | 97 | for attr in dir(module): 98 | tmp = getattr(module, attr) 99 | name1 = name + '.' + attr if name != '' else attr 100 | if name1 in quantizers.keys(): 101 | delattr(module, attr) 102 | setattr(module, attr, Sub4BitLinear( 103 | None, None, 104 | quantizers[name1][0], 105 | quantizers[name1][1], 106 | next(tmp.parameters()).dtype 107 | )) 108 | print(f'{name1}: weights replaced') 109 | del tmp 110 | 111 | for name1, child in module.named_children(): 112 | replace_with_quantizers(child, quantizers, name + '.' 
+ name1 if name != '' else name1) 113 | 114 | def replace_layers(module, bits=4, dtype=torch.float16, name=''): 115 | """Recursively replace all nn.Linear layers with Sub4BitLinear in the model.""" 116 | for attr_name in dir(module): 117 | sub_module = getattr(module, attr_name) 118 | 119 | if isinstance(sub_module, nn.Linear): 120 | delattr(module, attr_name) 121 | setattr(module, attr_name, Sub4BitLinear( 122 | sub_module.weight.data, bits=bits, dtype=dtype, 123 | )) 124 | del sub_module 125 | 126 | for child_name, child in module.named_children(): 127 | replace_layers(child, bits, dtype, f"{name}.{child_name}" if name else child_name) 128 | 129 | 130 | class LeanQuantModelForCausalLM: 131 | @classmethod 132 | def from_pretrained( 133 | cls, base_model_name_or_path: str, quantized_model_path: str, bits=4, torch_dtype=torch.float16, 134 | device: Optional[str] = None, device_map: str = 'auto', 135 | max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, 136 | ): 137 | """Load a pre-trained model and apply quantized layers.""" 138 | with no_init_weights(): 139 | config = AutoConfig.from_pretrained(base_model_name_or_path) 140 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype) 141 | 142 | replace_layers(model.model, bits=bits, dtype=torch_dtype) 143 | 144 | state_dict = {} 145 | with safe_open(quantized_model_path, framework='pt', device='cpu') as f: 146 | for k in f.keys(): 147 | state_dict[k] = f.get_tensor(k) 148 | 149 | model.load_state_dict(state_dict) 150 | 151 | if isinstance(device, str): 152 | model = model.to(device) 153 | else: 154 | assert device_map == 'auto', "device_map should be 'auto' if no specific device is provided." 155 | no_split_classes = [type(model.model.layers[0]).__name__] 156 | max_memory = get_balanced_memory(model, max_memory=max_memory, no_split_module_classes=no_split_classes, dtype=torch_dtype) 157 | device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_classes) 158 | model = dispatch_model(model, device_map) 159 | 160 | # Check if any parts of the model are on CPU and warn the user 161 | if any(param.device.type == 'cpu' for param in model.parameters()): 162 | print("Warning: Some model layers are on CPU. 
For inference, ensure the model is fully loaded onto CUDA-compatible GPUs.") 163 | 164 | return model -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', trust_remote_code=True) 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', trust_remote_code=True) 14 | 15 | from transformers import AutoTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 17 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | 20 | import random 21 | random.seed(seed) 22 | trainloader = [] 23 | for _ in range(nsamples): 24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 25 | j = i + seqlen 26 | inp = trainenc.input_ids[:, i:j] 27 | tar = inp.clone() 28 | tar[:, :-1] = -100 29 | trainloader.append((inp, tar)) 30 | return trainloader, testenc 31 | 32 | def get_ptb(nsamples, seed, seqlen, model): 33 | from datasets import load_dataset 34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', trust_remote_code=True) 35 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation', trust_remote_code=True) 36 | 37 | from transformers import AutoTokenizer 38 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 39 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 40 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 41 | 42 | import random 43 | random.seed(seed) 44 | trainloader = [] 45 | for _ in range(nsamples): 46 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 47 | j = i + seqlen 48 | inp = trainenc.input_ids[:, i:j] 49 | tar = inp.clone() 50 | tar[:, :-1] = -100 51 | trainloader.append((inp, tar)) 52 | return trainloader, testenc 53 | 54 | def get_c4(nsamples, seed, seqlen, model): 55 | from datasets import load_dataset 56 | traindata = load_dataset( 57 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', trust_remote_code=True 58 | ) 59 | valdata = load_dataset( 60 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', trust_remote_code=True 61 | ) 62 | 63 | from transformers import AutoTokenizer 64 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 65 | 66 | import random 67 | random.seed(seed) 68 | trainloader = [] 69 | for _ in range(nsamples): 70 | while True: 71 | i = random.randint(0, len(traindata) - 1) 72 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 73 | if trainenc.input_ids.shape[1] >= seqlen: 74 | break 75 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 76 | j = i + seqlen 77 | inp = trainenc.input_ids[:, i:j] 78 | tar = inp.clone() 79 | tar[:, :-1] = -100 80 | trainloader.append((inp, tar)) 81 | 82 | import random 83 | random.seed(0) 84 | valenc = [] 85 | for _ in range(256): 86 | while True: 87 | i = random.randint(0, len(valdata) - 1) 88 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 89 | if tmp.input_ids.shape[1] >= seqlen: 90 | 
break 91 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 92 | j = i + seqlen 93 | valenc.append(tmp.input_ids[:, i:j]) 94 | valenc = torch.hstack(valenc) 95 | class TokenizerWrapper: 96 | def __init__(self, input_ids): 97 | self.input_ids = input_ids 98 | valenc = TokenizerWrapper(valenc) 99 | 100 | return trainloader, valenc 101 | 102 | def get_ptb_new(nsamples, seed, seqlen, model): 103 | from datasets import load_dataset 104 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', trust_remote_code=True) 105 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test', trust_remote_code=True) 106 | 107 | from transformers import AutoTokenizer 108 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 109 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 110 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 111 | 112 | import random 113 | random.seed(seed) 114 | trainloader = [] 115 | for _ in range(nsamples): 116 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 117 | j = i + seqlen 118 | inp = trainenc.input_ids[:, i:j] 119 | tar = inp.clone() 120 | tar[:, :-1] = -100 121 | trainloader.append((inp, tar)) 122 | return trainloader, testenc 123 | 124 | def get_c4_new(nsamples, seed, seqlen, model): 125 | from datasets import load_dataset 126 | traindata = load_dataset( 127 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', trust_remote_code=True 128 | ) 129 | valdata = load_dataset( 130 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', trust_remote_code=True 131 | ) 132 | 133 | from transformers import AutoTokenizer 134 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 135 | 136 | import random 137 | random.seed(seed) 138 | trainloader = [] 139 | for _ in range(nsamples): 140 | while True: 141 | i = random.randint(0, len(traindata) - 1) 142 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 143 | if trainenc.input_ids.shape[1] >= seqlen: 144 | break 145 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 146 | j = i + seqlen 147 | inp = trainenc.input_ids[:, i:j] 148 | tar = inp.clone() 149 | tar[:, :-1] = -100 150 | trainloader.append((inp, tar)) 151 | 152 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 153 | valenc = valenc.input_ids[:, :(256 * seqlen)] 154 | 155 | class TokenizerWrapper: 156 | def __init__(self, input_ids): 157 | self.input_ids = input_ids 158 | valenc = TokenizerWrapper(valenc) 159 | 160 | return trainloader, valenc 161 | 162 | def get_c4_full(nsamples, seed, seqlen, model): 163 | from datasets import load_dataset 164 | traindata = load_dataset( 165 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', trust_remote_code=True 166 | ) 167 | valdata = load_dataset( 168 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', trust_remote_code=True 169 | ) 170 | 171 | from transformers import AutoTokenizer 172 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 173 | 174 | np.random.seed(seed) 175 | idx_perm = np.random.permutation(np.arange(len(traindata))).tolist() 176 | 177 | trainloader = [] 178 | for i in idx_perm: 179 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 180 | if trainenc.input_ids.shape[1] >= seqlen: 181 | inp = trainenc.input_ids[:, :seqlen] 182 | 
tar = inp.clone() 183 | tar[:, :-1] = -100 184 | trainloader.append((inp, tar)) 185 | if len(trainloader) >= nsamples: 186 | break 187 | 188 | print(f'Collected {len(trainloader)} calibration samples.') 189 | 190 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 191 | valenc = valenc.input_ids[:, :(256 * seqlen)] 192 | 193 | class TokenizerWrapper: 194 | def __init__(self, input_ids): 195 | self.input_ids = input_ids 196 | valenc = TokenizerWrapper(valenc) 197 | 198 | return trainloader, valenc 199 | 200 | 201 | def get_loaders( 202 | name, nsamples=128, seed=0, seqlen=2048, model='' 203 | ): 204 | if 'wikitext2' in name: 205 | return get_wikitext2(nsamples, seed, seqlen, model) 206 | if 'ptb' in name: 207 | if 'new' in name: 208 | return get_ptb_new(nsamples, seed, seqlen, model) 209 | return get_ptb(nsamples, seed, seqlen, model) 210 | if 'c4' in name: 211 | if 'new' in name: 212 | return get_c4_new(nsamples, seed, seqlen, model) 213 | if 'full' in name: 214 | return get_c4_full(nsamples, seed, seqlen, model) 215 | return get_c4(nsamples, seed, seqlen, model) 216 | -------------------------------------------------------------------------------- /lean_quantizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | 5 | import numpy as np 6 | from sklearn.cluster import KMeans 7 | from multiprocessing import Pool 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn as nn 12 | import transformers 13 | 14 | from quant import * 15 | 16 | 17 | DEBUG = False 18 | 19 | torch.backends.cuda.matmul.allow_tf32 = False 20 | torch.backends.cudnn.allow_tf32 = False 21 | 22 | def kmeans_fit(row_data): 23 | weights_np, sample_weight, n_cluster, random_seed = row_data 24 | kmeans = KMeans( 25 | n_clusters=n_cluster, 26 | init=np.linspace(weights_np.min(), weights_np.max(), num=n_cluster)[:, None] if n_cluster <= 8 else 'k-means++', 27 | n_init='auto', 28 | random_state=random_seed, 29 | max_iter=100, 30 | tol=1e-6, 31 | ).fit(weights_np, sample_weight=sample_weight) 32 | return kmeans.cluster_centers_.reshape(-1) 33 | 34 | pool = Pool(len(os.sched_getaffinity(0))) 35 | 36 | class LeanQuant: 37 | 38 | def __init__(self, layer): 39 | self.layer = layer 40 | self.dev = self.layer.weight.device 41 | W = layer.weight.data.clone() 42 | if isinstance(self.layer, nn.Conv2d): 43 | W = W.flatten(1) 44 | if isinstance(self.layer, transformers.Conv1D): 45 | W = W.t() 46 | self.rows = W.shape[0] 47 | self.columns = W.shape[1] 48 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 49 | self.nsamples = 0 50 | self.lut = None 51 | 52 | def add_batch(self, inp, out): 53 | if DEBUG: 54 | self.inp1 = inp 55 | self.out1 = out 56 | if len(inp.shape) == 2: 57 | inp = inp.unsqueeze(0) 58 | tmp = inp.shape[0] 59 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 60 | if len(inp.shape) == 3: 61 | inp = inp.reshape((-1, inp.shape[-1])) 62 | inp = inp.t() 63 | if isinstance(self.layer, nn.Conv2d): 64 | unfold = nn.Unfold( 65 | self.layer.kernel_size, 66 | dilation=self.layer.dilation, 67 | padding=self.layer.padding, 68 | stride=self.layer.stride 69 | ) 70 | inp = unfold(inp) 71 | inp = inp.permute([1, 0, 2]) 72 | inp = inp.flatten(1) 73 | self.H *= self.nsamples / (self.nsamples + tmp) 74 | self.nsamples += tmp 75 | # inp = inp.float() 76 | inp = math.sqrt(2 / self.nsamples) * inp.float() 77 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 78 | self.H 
+= inp.matmul(inp.t()) 79 | 80 | def fasterquant( 81 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False, args=None, 82 | ): 83 | W = self.layer.weight.data.clone() 84 | if isinstance(self.layer, nn.Conv2d): 85 | W = W.flatten(1) 86 | if isinstance(self.layer, transformers.Conv1D): 87 | W = W.t() 88 | W = W.float() 89 | 90 | tick = time.time() 91 | 92 | if not self.quantizer.ready(): 93 | self.quantizer.find_params(W, weight=True) 94 | 95 | H = self.H 96 | del self.H 97 | dead = torch.diag(H) == 0 98 | H[dead, dead] = 1 99 | W[:, dead] = 0 100 | 101 | if static_groups: 102 | import copy 103 | groups = [] 104 | for i in range(0, self.columns, groupsize): 105 | quantizer = copy.deepcopy(self.quantizer) 106 | quantizer.find_params(W[:, i:(i + groupsize)], weight=True) 107 | groups.append(quantizer) 108 | 109 | if H.shape[0] >= args.offload_threshold: 110 | secondary_device = torch.device('cuda:1') 111 | H = H.to(secondary_device) 112 | 113 | if actorder: 114 | perm_H = torch.argsort(torch.diag(H), descending=True) 115 | perm = perm_H.to(W.device) 116 | W = W[:, perm] 117 | H = H[perm_H][:, perm_H] 118 | invperm = torch.argsort(perm) 119 | 120 | damp = percdamp * torch.mean(torch.diag(H)) 121 | diag = torch.arange(self.columns, device=H.device) 122 | H[diag, diag] += damp 123 | 124 | H = torch.linalg.cholesky(H) 125 | H = torch.cholesky_inverse(H) 126 | H = torch.linalg.cholesky(H, upper=True) 127 | 128 | if H.shape[0] >= args.offload_threshold: 129 | H = H.to(self.dev) 130 | 131 | Losses = torch.zeros_like(W) 132 | Q = torch.zeros_like(W) 133 | Q_codes = Q.to(torch.uint8).cpu() 134 | Hinv = H 135 | torch.cuda.empty_cache() 136 | 137 | if isinstance(args.exponent, float): 138 | kmeans_tasks = [] 139 | W_np = W.cpu().numpy() 140 | Hinv_diagonal_np = (torch.diagonal(Hinv) ** (-args.exponent)).cpu().numpy() 141 | for j in range(W_np.shape[0]): 142 | kmeans_tasks.append((W_np[j, :, None], Hinv_diagonal_np, 2 ** args.wbits, args.kmeans_seed)) 143 | kmeans_results = list(tqdm(pool.imap(kmeans_fit, kmeans_tasks), total=len(kmeans_tasks))) 144 | centroids = torch.from_numpy(np.stack(kmeans_results)).reshape(W.shape[0], 2 ** args.wbits).to(W.device) 145 | else: 146 | centroids = None 147 | 148 | for i1 in range(0, self.columns, blocksize): 149 | i2 = min(i1 + blocksize, self.columns) 150 | count = i2 - i1 151 | 152 | W1 = W[:, i1:i2].clone() 153 | Q1 = torch.zeros_like(W1) 154 | Err1 = torch.zeros_like(W1) 155 | Losses1 = torch.zeros_like(W1) 156 | Hinv1 = Hinv[i1:i2, i1:i2] 157 | 158 | for i in range(count): 159 | w = W1[:, i] 160 | d = Hinv1[i, i] 161 | 162 | if groupsize != -1: 163 | if not static_groups: 164 | if (i1 + i) % groupsize == 0: 165 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 166 | else: 167 | idx = i1 + i 168 | if actorder: 169 | idx = perm[idx] 170 | self.quantizer = groups[idx // groupsize] 171 | 172 | if isinstance(centroids, torch.Tensor): 173 | codes = torch.argmin((centroids - w[:, None]).abs(), dim=1, keepdim=True) 174 | Q_codes[:, i1+i] = codes.flatten().to(torch.uint8).cpu() 175 | q = torch.gather(centroids, 1, codes).flatten() 176 | else: 177 | q = quantize( 178 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 179 | ).flatten() 180 | Q1[:, i] = q 181 | Losses1[:, i] = (w - q) ** 2 / d ** 2 182 | 183 | err1 = (w - q) / d 184 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 185 | Err1[:, i] = err1 186 | 187 | Q[:, i1:i2] = Q1 188 | Losses[:, i1:i2] = Losses1 / 
2 189 | 190 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 191 | 192 | if DEBUG: 193 | self.layer.weight.data[:, :i2] = Q[:, :i2] 194 | self.layer.weight.data[:, i2:] = W[:, i2:] 195 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 196 | print(torch.sum(Losses)) 197 | 198 | torch.cuda.synchronize() 199 | print('time %.2f' % (time.time() - tick)) 200 | print('error', torch.sum(Losses).item()) 201 | 202 | if actorder: 203 | Q = Q[:, invperm] 204 | Q_codes = Q_codes[:, invperm.cpu()] 205 | 206 | if isinstance(args.save_path, str) and isinstance(centroids, torch.Tensor): 207 | nrows, ncols = Q_codes.shape 208 | idx = torch.arange(0, ncols, 2)[None, :].repeat(nrows, 1).to(Q_codes.device) 209 | self.quantized_codes = torch.bitwise_or(torch.bitwise_left_shift(Q_codes.gather(1, idx), 4), Q_codes.gather(1, idx+1)) 210 | self.quant_grid = centroids.cpu() 211 | 212 | if isinstance(self.layer, transformers.Conv1D): 213 | Q = Q.t() 214 | print('norm of difference', torch.norm(self.layer.weight.data - Q).item()) 215 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 216 | if DEBUG: 217 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 218 | 219 | def free(self): 220 | if DEBUG: 221 | self.inp1 = None 222 | self.out1 = None 223 | self.H = None 224 | self.Losses = None 225 | self.Trace = None 226 | torch.cuda.empty_cache() 227 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OMP_NUM_THREADS"] = "1" # this is necessary to parallelize the kmeans 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoModelForCausalLM 8 | from safetensors.torch import save_file 9 | 10 | from lean_quantizer import * 11 | from modelutils import * 12 | from quant import * 13 | from leanquant import replace_with_quantizers 14 | 15 | 16 | def get_llama(model): 17 | import torch 18 | def skip(*args, **kwargs): 19 | pass 20 | torch.nn.init.kaiming_uniform_ = skip 21 | torch.nn.init.uniform_ = skip 22 | torch.nn.init.normal_ = skip 23 | model = AutoModelForCausalLM.from_pretrained(model, torch_dtype='auto') 24 | model.seqlen = 2048 25 | return model 26 | 27 | @torch.no_grad() 28 | def llama_sequential(model, dataloader, dev): 29 | print('Starting ...') 30 | 31 | use_cache = model.config.use_cache 32 | model.config.use_cache = False 33 | layers = model.model.layers 34 | 35 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 36 | model.model.norm = model.model.norm.to(dev) 37 | if hasattr(model.model, 'rotary_emb'): 38 | model.model.rotary_emb = model.model.rotary_emb.to(dev) 39 | layers[0] = layers[0].to(dev) 40 | 41 | dtype = next(iter(model.parameters())).dtype 42 | inps = torch.zeros( 43 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 44 | ) 45 | cache = {'i': 0, 'attention_mask': None} 46 | 47 | class Catcher(nn.Module): 48 | def __init__(self, module): 49 | super().__init__() 50 | self.module = module 51 | def forward(self, inp, **kwargs): 52 | inps[cache['i']] = inp 53 | cache['i'] += 1 54 | cache['attention_mask'] = kwargs['attention_mask'] 55 | cache['position_ids'] = kwargs['position_ids'] 56 | raise ValueError 57 | layers[0] = Catcher(layers[0]) 58 | for batch in dataloader: 59 | try: 60 | model(batch[0].to(dev)) 61 | except ValueError: 62 | pass 63 | layers[0] = layers[0].module 64 | 65 | layers[0] = layers[0].cpu() 66 | 
model.model.embed_tokens = model.model.embed_tokens.cpu() 67 | model.model.norm = model.model.norm.cpu() 68 | torch.cuda.empty_cache() 69 | 70 | outs = torch.zeros_like(inps) 71 | attention_mask = cache['attention_mask'] 72 | position_ids = cache['position_ids'] 73 | 74 | print('Ready.') 75 | 76 | for i in range(args.n_layers if isinstance(args.n_layers, int) else len(layers)): 77 | layer = layers[i].to(dev) 78 | quantizers = {} 79 | full = find_layers(layer) 80 | 81 | if args.true_sequential: 82 | sequential = [ 83 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 84 | ['self_attn.o_proj'], 85 | ['mlp.up_proj', 'mlp.gate_proj'], 86 | ['mlp.down_proj'] 87 | ] 88 | else: 89 | sequential = [list(full.keys())] 90 | 91 | for names in sequential: 92 | subset = {n: full[n] for n in names} 93 | 94 | leanquant = {} 95 | for name in subset: 96 | leanquant[name] = LeanQuant(subset[name]) 97 | leanquant[name].quantizer = Quantizer() 98 | leanquant[name].quantizer.configure( 99 | args.wbits, perchannel=True, sym=args.sym, mse=False 100 | ) 101 | 102 | def add_batch(name): 103 | def tmp(_, inp, out): 104 | leanquant[name].add_batch(inp[0].data, out.data) 105 | return tmp 106 | handles = [] 107 | for name in subset: 108 | handles.append(subset[name].register_forward_hook(add_batch(name))) 109 | for j in range(args.nsamples): 110 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 111 | for h in handles: 112 | h.remove() 113 | 114 | for name in subset: 115 | print(i, name) 116 | print('Quantizing ...') 117 | 118 | leanquant[name].fasterquant( 119 | blocksize=args.block_size, percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups, args=args, 120 | ) 121 | if isinstance(args.exponent, float): 122 | quantizers[name] = (leanquant[name].quant_grid, leanquant[name].quantized_codes) 123 | leanquant[name].free() 124 | 125 | for j in range(args.nsamples): 126 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 127 | 128 | layer = layer.cpu() 129 | replace_with_quantizers(layer, quantizers) 130 | layers[i] = layer 131 | del layer 132 | del leanquant 133 | torch.cuda.empty_cache() 134 | 135 | inps, outs = outs, inps 136 | 137 | if isinstance(args.save_path, str): 138 | save_file(model.state_dict(), args.save_path) 139 | 140 | model.config.use_cache = use_cache 141 | 142 | return quantizers 143 | 144 | @torch.no_grad() 145 | def llama_eval(model, testenc, dev): 146 | print('Evaluating ...') 147 | 148 | testenc = testenc.input_ids 149 | nsamples = testenc.numel() // model.seqlen 150 | 151 | use_cache = model.config.use_cache 152 | model.config.use_cache = False 153 | layers = model.model.layers 154 | 155 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 156 | model.model.norm = model.model.norm.to(dev) 157 | if hasattr(model.model, 'rotary_emb'): 158 | model.model.rotary_emb = model.model.rotary_emb.to(dev) 159 | layers[0] = layers[0].to(dev) 160 | 161 | dtype = next(iter(model.parameters())).dtype 162 | inps = torch.zeros( 163 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 164 | ) 165 | cache = {'i': 0, 'attention_mask': None} 166 | 167 | class Catcher(nn.Module): 168 | def __init__(self, module): 169 | super().__init__() 170 | self.module = module 171 | def forward(self, inp, **kwargs): 172 | inps[cache['i']] = inp 173 | cache['i'] += 1 174 | cache['attention_mask'] = kwargs['attention_mask'] 175 | 
cache['position_ids'] = kwargs['position_ids'] 176 | raise ValueError 177 | layers[0] = Catcher(layers[0]) 178 | for i in range(nsamples): 179 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 180 | try: 181 | model(batch) 182 | except ValueError: 183 | pass 184 | layers[0] = layers[0].module 185 | 186 | layers[0] = layers[0].cpu() 187 | model.model.embed_tokens = model.model.embed_tokens.cpu() 188 | torch.cuda.empty_cache() 189 | 190 | outs = torch.zeros_like(inps) 191 | attention_mask = cache['attention_mask'] 192 | position_ids = cache['position_ids'] 193 | 194 | for i in range(len(layers)): 195 | print(i) 196 | layer = layers[i].to(dev) 197 | 198 | if args.nearest: 199 | subset = find_layers(layer) 200 | for name in subset: 201 | quantizer = Quantizer() 202 | quantizer.configure( 203 | args.wbits, perchannel=True, sym=False, mse=False 204 | ) 205 | W = subset[name].weight.data 206 | quantizer.find_params(W, weight=True) 207 | subset[name].weight.data = quantize( 208 | W, quantizer.scale, quantizer.zero, quantizer.maxq 209 | ).to(next(iter(layer.parameters())).dtype) 210 | 211 | for j in range(nsamples): 212 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 213 | layers[i] = layer.cpu() 214 | del layer 215 | torch.cuda.empty_cache() 216 | inps, outs = outs, inps 217 | 218 | if model.model.norm is not None: 219 | model.model.norm = model.model.norm.to(dev) 220 | model.lm_head = model.lm_head.to(dev) 221 | 222 | testenc = testenc.to(dev) 223 | nlls = [] 224 | for i in range(nsamples): 225 | hidden_states = inps[i].unsqueeze(0) 226 | if model.model.norm is not None: 227 | hidden_states = model.model.norm(hidden_states) 228 | lm_logits = model.lm_head(hidden_states) 229 | shift_logits = lm_logits[:, :-1, :].contiguous() 230 | shift_labels = testenc[ 231 | :, (i * model.seqlen):((i + 1) * model.seqlen) 232 | ][:, 1:] 233 | loss_fct = nn.CrossEntropyLoss() 234 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 235 | neg_log_likelihood = loss.float() * model.seqlen 236 | nlls.append(neg_log_likelihood) 237 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 238 | print(ppl.item()) 239 | 240 | model.config.use_cache = use_cache 241 | 242 | 243 | if __name__ == '__main__': 244 | import argparse 245 | from datautils import * 246 | 247 | parser = argparse.ArgumentParser() 248 | 249 | parser.add_argument( 250 | 'model', type=str, 251 | help='LlaMa model to load; pass location of hugginface converted checkpoint.' 252 | ) 253 | parser.add_argument( 254 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4', 'ptb-new', 'c4-new', 'c4-full'], 255 | help='Where to extract calibration data from.' 256 | ) 257 | parser.add_argument( 258 | '--seed', 259 | type=int, default=0, help='Seed for sampling the calibration data.' 260 | ) 261 | parser.add_argument( 262 | '--nsamples', type=int, default=128, 263 | help='Number of calibration data samples.' 264 | ) 265 | parser.add_argument( 266 | '--percdamp', type=float, default=.01, 267 | help='Percent of the average Hessian diagonal to use for dampening.' 268 | ) 269 | parser.add_argument( 270 | '--nearest', action='store_true', 271 | help='Whether to run the RTN baseline.' 272 | ) 273 | parser.add_argument( 274 | '--wbits', type=int, default=16, choices=[2, 3, 4, 16], 275 | help='#bits to use for quantization; use 16 for evaluating base model.' 
276 | ) 277 | parser.add_argument( 278 | '--groupsize', type=int, default=-1, 279 | help='Groupsize to use for quantization; default uses full row.' 280 | ) 281 | parser.add_argument( 282 | '--sym', action='store_true', 283 | help='Whether to perform symmetric quantization.' 284 | ) 285 | parser.add_argument( 286 | '--new-eval', action='store_true', 287 | help='Whether to use the new PTB and C4 eval.' 288 | ) 289 | parser.add_argument( 290 | '--act-order', action='store_true', 291 | help='Whether to apply the activation order GPTQ heuristic' 292 | ) 293 | parser.add_argument( 294 | '--true-sequential', action='store_true', 295 | help='Whether to run in true sequential model.' 296 | ) 297 | parser.add_argument( 298 | '--static-groups', action='store_true', 299 | help='Whether to use static groups; recommended when using `--actorder` for more efficient inference.' 300 | ) 301 | parser.add_argument( 302 | '--n_layers', type=int, default=None, 303 | ) 304 | parser.add_argument( 305 | '--block_size', type=int, default=128, 306 | ) 307 | parser.add_argument( 308 | '--exponent', type=float, default=None, 309 | ) 310 | parser.add_argument( 311 | '--kmeans_seed', type=int, default=0, 312 | ) 313 | parser.add_argument( 314 | '--offload_threshold', type=int, default=53248, 315 | ) 316 | parser.add_argument( 317 | '--save_path', type=str, default=None, 318 | ) 319 | 320 | args = parser.parse_args() 321 | print(args) 322 | 323 | model = get_llama(args.model) 324 | model.eval() 325 | 326 | dataloader, testloader = get_loaders( 327 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 328 | ) 329 | 330 | if args.wbits < 16 and not args.nearest: 331 | tick = time.time() 332 | quantizers = llama_sequential(model, dataloader, DEV) 333 | print(f'quant_time={time.time() - tick}') 334 | 335 | datasets = ['wikitext2', 'ptb', 'c4'] 336 | if args.new_eval: 337 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 338 | for dataset in datasets: 339 | dataloader, testloader = get_loaders( 340 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 341 | ) 342 | print(dataset) 343 | llama_eval(model, testloader, DEV) 344 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | ![LeanLlama](assets/LeanLlama.jpg)
3 |
4 |
6 | ICLR 2025 | Accurate and Scalable LLM Quantization with Loss-error-aware Grid 7 |
8 |
9 | 🚀 Quantizes a 70B model on a single 24GB GPU in 4 hours
10 | ⚡ Quantizes a 405B model on two 48GB GPUs in 24 hours
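Quantized checkpoints are produced with `llama.py` and evaluated with `eval_quantized.py`. For example, `python llama.py meta-llama/Llama-2-7b-hf c4 --wbits 4 --exponent 3.0 --save_path models/llama-2-7b_4bit.safetensors` quantizes Llama-2-7B to 4 bits using C4 calibration data; `--exponent` must be a float for the loss-error-aware grids to be computed and saved, and the output path and hyperparameter values here are only illustrative. The sketch below shows one way such a checkpoint could then be loaded for generation. It follows the same `LeanQuantModelForCausalLM.from_pretrained` call used in `eval_quantized.py`; the model name, file path, and prompt are placeholders, and the `leanquant` package additionally requires CuPy and a CUDA GPU (see `setup.py` extras).

```python
# Minimal usage sketch (model name, path, and prompt are placeholders).
# Requires a CUDA GPU and CuPy, since Sub4BitLinear reconstructs weights
# with a custom CUDA kernel at inference time.
import torch
from transformers import AutoTokenizer

from leanquant import LeanQuantModelForCausalLM

base_model = "meta-llama/Llama-2-7b-hf"                 # original FP16 model
quantized_path = "models/llama-2-7b_4bit.safetensors"   # file written by llama.py --save_path

tokenizer = AutoTokenizer.from_pretrained(base_model)
model = LeanQuantModelForCausalLM.from_pretrained(
    base_model, quantized_path,
    bits=4,                        # must match the --wbits used during quantization
    torch_dtype=torch.float16,
    device_map="auto",
)

inputs = tokenizer("LeanQuant is", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

For benchmark evaluation, `eval_quantized.py` wraps the same loaded model in `lm_eval`'s `HFLM` and calls `lm_eval.simple_evaluate` on the requested tasks.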
11 |
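For reference, a LeanQuant checkpoint stores each quantized linear layer as implemented in `leanquant/leanquant_utils.py`: every output row keeps a small grid of `2**bits` learned values (`quant_grid`), and `weight_codes` packs two sub-4-bit row-local indices per byte (high nibble = even column, low nibble = odd column). At inference, the `gather_sub4bit` CUDA kernel expands these codes back into a dense weight matrix before the matmul. The function below is not part of the package; it is only an illustrative pure-PyTorch re-implementation of that gather, useful for inspecting a layer on CPU.

```python
import torch

def dequantize_sub4bit(quant_grid: torch.Tensor, weight_codes: torch.Tensor) -> torch.Tensor:
    """Illustrative (non-official) PyTorch equivalent of the gather_sub4bit kernel.

    quant_grid:   (rows, 2**bits) per-row quantization grid, fp16/bf16
    weight_codes: (rows, cols // 2) uint8, two 4-bit indices packed per byte
    returns:      (rows, cols) dense weight matrix
    """
    hi = (weight_codes >> 4).long()    # indices for even columns 0, 2, 4, ...
    lo = (weight_codes & 0xF).long()   # indices for odd columns  1, 3, 5, ...
    rows, half_cols = weight_codes.shape
    W = torch.empty(rows, half_cols * 2, dtype=quant_grid.dtype, device=quant_grid.device)
    W[:, 0::2] = torch.gather(quant_grid, 1, hi)
    W[:, 1::2] = torch.gather(quant_grid, 1, lo)
    return W
```

This mirrors the packing done at quantization time in `lean_quantizer.py`, where the code for an even column and its odd neighbor are combined as `(even << 4) | odd`.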