├── .gitignore ├── assets └── LeanLlama.jpg ├── leanquant ├── __init__.py └── leanquant_utils.py ├── requirements.txt ├── modelutils.py ├── setup.py ├── eval_quantized.py ├── quant.py ├── datautils.py ├── lean_quantizer.py ├── llama.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.log 3 | *.safetensors 4 | *egg-info* 5 | -------------------------------------------------------------------------------- /assets/LeanLlama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanModels/LeanQuant/HEAD/assets/LeanLlama.jpg -------------------------------------------------------------------------------- /leanquant/__init__.py: -------------------------------------------------------------------------------- 1 | from .leanquant_utils import LeanQuantModelForCausalLM, Sub4BitLinear, replace_with_quantizers -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.47.0 2 | numpy==1.26.4 3 | scikit-learn==1.5.2 4 | accelerate==0.34.2 5 | lm-eval==0.4.4 6 | safetensors==0.4.5 7 | tqdm==4.66.5 -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | DEV = torch.device('cuda:0') 6 | 7 | 8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 9 | if type(module) in layers: 10 | return {name: module} 11 | res = {} 12 | for name1, child in module.named_children(): 13 | res.update(find_layers( 14 | child, layers=layers, name=name + '.' 
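# note: the conditional binds as name=(name + '.' + name1) if name != '' else name1, since "if/else" has lower precedence than "+"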
+ name1 if name != '' else name1 15 | )) 16 | return res 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="leanquant", 5 | version="0.1.1", 6 | author="Tianyi Zhang", 7 | author_email="tonyzhang617@gmail.com", 8 | description="The inference kernels for LeanQuant models.", 9 | packages=find_packages(), 10 | install_requires=[ 11 | "transformers>=4.38", 12 | "accelerate", 13 | "safetensors", 14 | "torch", 15 | ], 16 | extras_require={ 17 | "cuda11": [ 18 | "cupy-cuda11x", 19 | ], 20 | "cuda12": [ 21 | "cupy-cuda12x", 22 | ], 23 | }, 24 | ) 25 | -------------------------------------------------------------------------------- /eval_quantized.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from transformers import AutoTokenizer 5 | import lm_eval 6 | 7 | from leanquant import LeanQuantModelForCausalLM 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | 12 | parser.add_argument("--base_model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf") 13 | parser.add_argument("--leanquant_path", type=str, default="models/llama-2-7b_b3_e3_d0.1.safetensors") 14 | parser.add_argument("--bits", type=int, default=4) 15 | parser.add_argument("--use_bf16", action="store_true") 16 | parser.add_argument("--tasks", nargs='+', type=str, default=["mmlu"]) 17 | parser.add_argument("--eval_batch_size", type=int, default=4) 18 | args = parser.parse_args() 19 | print(args) 20 | 21 | tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path) 22 | model = LeanQuantModelForCausalLM.from_pretrained( 23 | args.base_model_name_or_path, args.leanquant_path, 24 | args.bits, torch_dtype=torch.bfloat16 if args.use_bf16 else torch.float16, 25 | device_map="auto", 26 | ) 27 | 28 | model_eval = lm_eval.models.huggingface.HFLM(model, tokenizer=tokenizer, device=model.device, batch_size=args.eval_batch_size, trust_remote_code=True) 29 | results = lm_eval.simple_evaluate(model=model_eval, tasks=args.tasks) 30 | print(results['results']) 31 | -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def quantize(x, scale, zero, maxq): 7 | if maxq < 0: 8 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 9 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 10 | return scale * (q - zero) 11 | 12 | class Quantizer(nn.Module): 13 | 14 | def __init__(self, shape=1): 15 | super(Quantizer, self).__init__() 16 | self.register_buffer('maxq', torch.tensor(0)) 17 | self.register_buffer('scale', torch.zeros(shape)) 18 | self.register_buffer('zero', torch.zeros(shape)) 19 | 20 | def configure( 21 | self, 22 | bits, perchannel=False, sym=True, 23 | mse=False, norm=2.4, grid=100, maxshrink=.8, 24 | trits=False 25 | ): 26 | self.maxq = torch.tensor(2 ** bits - 1) 27 | self.perchannel = perchannel 28 | self.sym = sym 29 | self.mse = mse 30 | self.norm = norm 31 | self.grid = grid 32 | self.maxshrink = maxshrink 33 | if trits: 34 | self.maxq = torch.tensor(-1) 35 | 36 | def find_params(self, x, weight=False): 37 | dev = x.device 38 | self.maxq = self.maxq.to(dev) 39 | 40 | shape = x.shape 41 | if 
self.perchannel: 42 | if weight: 43 | x = x.flatten(1) 44 | else: 45 | if len(shape) == 4: 46 | x = x.permute([1, 0, 2, 3]) 47 | x = x.flatten(1) 48 | if len(shape) == 3: 49 | x = x.reshape((-1, shape[-1])).t() 50 | if len(shape) == 2: 51 | x = x.t() 52 | else: 53 | x = x.flatten().unsqueeze(0) 54 | 55 | tmp = torch.zeros(x.shape[0], device=dev) 56 | xmin = torch.minimum(x.min(1)[0], tmp) 57 | xmax = torch.maximum(x.max(1)[0], tmp) 58 | 59 | if self.sym: 60 | xmax = torch.maximum(torch.abs(xmin), xmax) 61 | tmp = xmin < 0 62 | if torch.any(tmp): 63 | xmin[tmp] = -xmax[tmp] 64 | tmp = (xmin == 0) & (xmax == 0) 65 | xmin[tmp] = -1 66 | xmax[tmp] = +1 67 | 68 | if self.maxq < 0: 69 | self.scale = xmax 70 | self.zero = xmin 71 | else: 72 | self.scale = (xmax - xmin) / self.maxq 73 | if self.sym: 74 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 75 | else: 76 | self.zero = torch.round(-xmin / self.scale) 77 | 78 | if self.mse: 79 | best = torch.full([x.shape[0]], float('inf'), device=dev) 80 | for i in range(int(self.maxshrink * self.grid)): 81 | p = 1 - i / self.grid 82 | xmin1 = p * xmin 83 | xmax1 = p * xmax 84 | scale1 = (xmax1 - xmin1) / self.maxq 85 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 86 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 87 | q -= x 88 | q.abs_() 89 | q.pow_(self.norm) 90 | err = torch.sum(q, 1) 91 | tmp = err < best 92 | if torch.any(tmp): 93 | best[tmp] = err[tmp] 94 | self.scale[tmp] = scale1[tmp] 95 | self.zero[tmp] = zero1[tmp] 96 | if not self.perchannel: 97 | if weight: 98 | tmp = shape[0] 99 | else: 100 | tmp = shape[1] if len(shape) != 3 else shape[2] 101 | self.scale = self.scale.repeat(tmp) 102 | self.zero = self.zero.repeat(tmp) 103 | 104 | if weight: 105 | shape = [-1] + [1] * (len(shape) - 1) 106 | self.scale = self.scale.reshape(shape) 107 | self.zero = self.zero.reshape(shape) 108 | return 109 | if len(shape) == 4: 110 | self.scale = self.scale.reshape((1, -1, 1, 1)) 111 | self.zero = self.zero.reshape((1, -1, 1, 1)) 112 | if len(shape) == 3: 113 | self.scale = self.scale.reshape((1, 1, -1)) 114 | self.zero = self.zero.reshape((1, 1, -1)) 115 | if len(shape) == 2: 116 | self.scale = self.scale.unsqueeze(0) 117 | self.zero = self.zero.unsqueeze(0) 118 | 119 | def quantize(self, x): 120 | if self.ready(): 121 | return quantize(x, self.scale, self.zero, self.maxq) 122 | return x 123 | 124 | def enabled(self): 125 | return self.maxq > 0 126 | 127 | def ready(self): 128 | return torch.all(self.scale != 0) 129 | -------------------------------------------------------------------------------- /leanquant/leanquant_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union, Dict 3 | 4 | import torch 5 | from torch import nn 6 | import cupy as cp 7 | 8 | from transformers.modeling_utils import no_init_weights 9 | from transformers import AutoConfig, AutoModelForCausalLM 10 | from accelerate import infer_auto_device_map, dispatch_model 11 | from accelerate.utils import get_balanced_memory 12 | from safetensors import safe_open 13 | 14 | kernel_code = '''typedef unsigned char uint8_t; 15 | typedef unsigned short uint16_t; 16 | 17 | extern "C" 18 | __global__ void gather_sub4bit( 19 | const uint16_t* __restrict__ src, // nrows x 2^bits 20 | const uint8_t* __restrict__ codes, // nrows x (ncols / 2) 21 | uint16_t* __restrict__ dst, // nrows x ncols 22 | int bits, int ncols 23 | ) { 24 | extern __shared__ volatile 
uint16_t cache[]; // 2^bits 25 | 26 | const int row_id = blockIdx.x; 27 | const int thread_id = threadIdx.x; 28 | const int n_threads = blockDim.x; 29 | 30 | const int n_floats = 1 << bits; 31 | 32 | #pragma unroll 33 | for (int i = thread_id; i < n_floats; i += n_threads) { 34 | cache[i] = src[row_id * n_floats + i]; 35 | } 36 | __syncthreads(); 37 | 38 | for (int i = thread_id; i < ncols / 2; i += n_threads) { 39 | uint8_t code = codes[row_id * ncols / 2 + i]; 40 | dst[row_id * ncols + i * 2] = cache[code >> 4]; 41 | dst[row_id * ncols + i * 2 + 1] = cache[code & 0xf]; 42 | } 43 | }''' 44 | _gather_sub4bit = cp.RawKernel( 45 | kernel_code, 46 | 'gather_sub4bit' 47 | ) 48 | 49 | class Sub4BitLinear(nn.Module): 50 | def __init__(self, orig_weight, bits=4, quant_grid=None, weight_codes=None, dtype=torch.float16): 51 | super().__init__() 52 | 53 | if isinstance(orig_weight, torch.Tensor): 54 | rows, cols = orig_weight.shape 55 | self.quant_grid = nn.Parameter(torch.empty(rows, 2 ** bits, dtype=orig_weight.dtype)) 56 | weight_codes = torch.empty(rows, cols // 2, dtype=torch.uint8) 57 | self.register_buffer('weight_codes', weight_codes) 58 | elif isinstance(quant_grid, torch.Tensor) and isinstance(weight_codes, torch.Tensor): 59 | assert dtype == torch.float16 or dtype == torch.bfloat16 60 | assert weight_codes.dtype == torch.uint8 61 | 62 | quant_grid = quant_grid.squeeze() 63 | weight_codes = weight_codes.squeeze() 64 | 65 | self.quant_grid = nn.Parameter(quant_grid.to(dtype)) 66 | self.register_buffer('weight_codes', weight_codes) 67 | 68 | assert self.quant_grid.shape[0] == self.weight_codes.shape[0] 69 | bits = int(math.log2(self.quant_grid.shape[1])) 70 | assert bits <= 4 71 | assert (2 ** bits) == self.quant_grid.shape[1] 72 | else: 73 | assert False, "This function can be initialized using either an `orig_weight` tensor or a pair of `quant_grid` and `weight_codes` tensors." 74 | 75 | def forward(self, x): 76 | assert x.device.index == self.quant_grid.device.index 77 | 78 | rows, cols_div2 = self.weight_codes.shape 79 | cols = cols_div2 * 2 80 | bits = int(math.log2(self.quant_grid.shape[1])) 81 | 82 | W = torch.empty(rows, cols, dtype=self.quant_grid.dtype, device=x.device) 83 | blocks_per_grid = (rows, ) 84 | threads_per_block = (512, ) 85 | 86 | with cp.cuda.Device(x.device.index): 87 | _gather_sub4bit(grid=blocks_per_grid, block=threads_per_block, shared_mem=2 ** bits * 2, args=[ 88 | self.quant_grid.data_ptr(), self.weight_codes.data_ptr(), W.data_ptr(), bits, cols, 89 | ]) 90 | 91 | return torch.matmul(x, W.t()) 92 | 93 | def replace_with_quantizers(module, quantizers, name=''): 94 | if isinstance(module, Sub4BitLinear): 95 | return 96 | 97 | for attr in dir(module): 98 | tmp = getattr(module, attr) 99 | name1 = name + '.' + attr if name != '' else attr 100 | if name1 in quantizers.keys(): 101 | delattr(module, attr) 102 | setattr(module, attr, Sub4BitLinear( 103 | None, None, 104 | quantizers[name1][0], 105 | quantizers[name1][1], 106 | next(tmp.parameters()).dtype 107 | )) 108 | print(f'{name1}: weights replaced') 109 | del tmp 110 | 111 | for name1, child in module.named_children(): 112 | replace_with_quantizers(child, quantizers, name + '.' 
+ name1 if name != '' else name1) 113 | 114 | def replace_layers(module, bits=4, dtype=torch.float16, name=''): 115 | """Recursively replace all nn.Linear layers with Sub4BitLinear in the model.""" 116 | for attr_name in dir(module): 117 | sub_module = getattr(module, attr_name) 118 | 119 | if isinstance(sub_module, nn.Linear): 120 | delattr(module, attr_name) 121 | setattr(module, attr_name, Sub4BitLinear( 122 | sub_module.weight.data, bits=bits, dtype=dtype, 123 | )) 124 | del sub_module 125 | 126 | for child_name, child in module.named_children(): 127 | replace_layers(child, bits, dtype, f"{name}.{child_name}" if name else child_name) 128 | 129 | 130 | class LeanQuantModelForCausalLM: 131 | @classmethod 132 | def from_pretrained( 133 | cls, base_model_name_or_path: str, quantized_model_path: str, bits=4, torch_dtype=torch.float16, 134 | device: Optional[str] = None, device_map: str = 'auto', 135 | max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, 136 | ): 137 | """Load a pre-trained model and apply quantized layers.""" 138 | with no_init_weights(): 139 | config = AutoConfig.from_pretrained(base_model_name_or_path) 140 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype) 141 | 142 | replace_layers(model.model, bits=bits, dtype=torch_dtype) 143 | 144 | state_dict = {} 145 | with safe_open(quantized_model_path, framework='pt', device='cpu') as f: 146 | for k in f.keys(): 147 | state_dict[k] = f.get_tensor(k) 148 | 149 | model.load_state_dict(state_dict) 150 | 151 | if isinstance(device, str): 152 | model = model.to(device) 153 | else: 154 | assert device_map == 'auto', "device_map should be 'auto' if no specific device is provided." 155 | no_split_classes = [type(model.model.layers[0]).__name__] 156 | max_memory = get_balanced_memory(model, max_memory=max_memory, no_split_module_classes=no_split_classes, dtype=torch_dtype) 157 | device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_classes) 158 | model = dispatch_model(model, device_map) 159 | 160 | # Check if any parts of the model are on CPU and warn the user 161 | if any(param.device.type == 'cpu' for param in model.parameters()): 162 | print("Warning: Some model layers are on CPU. 
For inference, ensure the model is fully loaded onto CUDA-compatible GPUs.") 163 | 164 | return model -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', trust_remote_code=True) 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', trust_remote_code=True) 14 | 15 | from transformers import AutoTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 17 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | 20 | import random 21 | random.seed(seed) 22 | trainloader = [] 23 | for _ in range(nsamples): 24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 25 | j = i + seqlen 26 | inp = trainenc.input_ids[:, i:j] 27 | tar = inp.clone() 28 | tar[:, :-1] = -100 29 | trainloader.append((inp, tar)) 30 | return trainloader, testenc 31 | 32 | def get_ptb(nsamples, seed, seqlen, model): 33 | from datasets import load_dataset 34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', trust_remote_code=True) 35 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation', trust_remote_code=True) 36 | 37 | from transformers import AutoTokenizer 38 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 39 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 40 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 41 | 42 | import random 43 | random.seed(seed) 44 | trainloader = [] 45 | for _ in range(nsamples): 46 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 47 | j = i + seqlen 48 | inp = trainenc.input_ids[:, i:j] 49 | tar = inp.clone() 50 | tar[:, :-1] = -100 51 | trainloader.append((inp, tar)) 52 | return trainloader, testenc 53 | 54 | def get_c4(nsamples, seed, seqlen, model): 55 | from datasets import load_dataset 56 | traindata = load_dataset( 57 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', trust_remote_code=True 58 | ) 59 | valdata = load_dataset( 60 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', trust_remote_code=True 61 | ) 62 | 63 | from transformers import AutoTokenizer 64 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 65 | 66 | import random 67 | random.seed(seed) 68 | trainloader = [] 69 | for _ in range(nsamples): 70 | while True: 71 | i = random.randint(0, len(traindata) - 1) 72 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 73 | if trainenc.input_ids.shape[1] >= seqlen: 74 | break 75 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 76 | j = i + seqlen 77 | inp = trainenc.input_ids[:, i:j] 78 | tar = inp.clone() 79 | tar[:, :-1] = -100 80 | trainloader.append((inp, tar)) 81 | 82 | import random 83 | random.seed(0) 84 | valenc = [] 85 | for _ in range(256): 86 | while True: 87 | i = random.randint(0, len(valdata) - 1) 88 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 89 | if tmp.input_ids.shape[1] >= seqlen: 90 | 
break 91 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 92 | j = i + seqlen 93 | valenc.append(tmp.input_ids[:, i:j]) 94 | valenc = torch.hstack(valenc) 95 | class TokenizerWrapper: 96 | def __init__(self, input_ids): 97 | self.input_ids = input_ids 98 | valenc = TokenizerWrapper(valenc) 99 | 100 | return trainloader, valenc 101 | 102 | def get_ptb_new(nsamples, seed, seqlen, model): 103 | from datasets import load_dataset 104 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', trust_remote_code=True) 105 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test', trust_remote_code=True) 106 | 107 | from transformers import AutoTokenizer 108 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 109 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 110 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 111 | 112 | import random 113 | random.seed(seed) 114 | trainloader = [] 115 | for _ in range(nsamples): 116 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 117 | j = i + seqlen 118 | inp = trainenc.input_ids[:, i:j] 119 | tar = inp.clone() 120 | tar[:, :-1] = -100 121 | trainloader.append((inp, tar)) 122 | return trainloader, testenc 123 | 124 | def get_c4_new(nsamples, seed, seqlen, model): 125 | from datasets import load_dataset 126 | traindata = load_dataset( 127 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', trust_remote_code=True 128 | ) 129 | valdata = load_dataset( 130 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', trust_remote_code=True 131 | ) 132 | 133 | from transformers import AutoTokenizer 134 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 135 | 136 | import random 137 | random.seed(seed) 138 | trainloader = [] 139 | for _ in range(nsamples): 140 | while True: 141 | i = random.randint(0, len(traindata) - 1) 142 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 143 | if trainenc.input_ids.shape[1] >= seqlen: 144 | break 145 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 146 | j = i + seqlen 147 | inp = trainenc.input_ids[:, i:j] 148 | tar = inp.clone() 149 | tar[:, :-1] = -100 150 | trainloader.append((inp, tar)) 151 | 152 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 153 | valenc = valenc.input_ids[:, :(256 * seqlen)] 154 | 155 | class TokenizerWrapper: 156 | def __init__(self, input_ids): 157 | self.input_ids = input_ids 158 | valenc = TokenizerWrapper(valenc) 159 | 160 | return trainloader, valenc 161 | 162 | def get_c4_full(nsamples, seed, seqlen, model): 163 | from datasets import load_dataset 164 | traindata = load_dataset( 165 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', trust_remote_code=True 166 | ) 167 | valdata = load_dataset( 168 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', trust_remote_code=True 169 | ) 170 | 171 | from transformers import AutoTokenizer 172 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 173 | 174 | np.random.seed(seed) 175 | idx_perm = np.random.permutation(np.arange(len(traindata))).tolist() 176 | 177 | trainloader = [] 178 | for i in idx_perm: 179 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 180 | if trainenc.input_ids.shape[1] >= seqlen: 181 | inp = trainenc.input_ids[:, :seqlen] 182 | 
tar = inp.clone() 183 | tar[:, :-1] = -100 184 | trainloader.append((inp, tar)) 185 | if len(trainloader) >= nsamples: 186 | break 187 | 188 | print(f'Collected {len(trainloader)} calibration samples.') 189 | 190 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 191 | valenc = valenc.input_ids[:, :(256 * seqlen)] 192 | 193 | class TokenizerWrapper: 194 | def __init__(self, input_ids): 195 | self.input_ids = input_ids 196 | valenc = TokenizerWrapper(valenc) 197 | 198 | return trainloader, valenc 199 | 200 | 201 | def get_loaders( 202 | name, nsamples=128, seed=0, seqlen=2048, model='' 203 | ): 204 | if 'wikitext2' in name: 205 | return get_wikitext2(nsamples, seed, seqlen, model) 206 | if 'ptb' in name: 207 | if 'new' in name: 208 | return get_ptb_new(nsamples, seed, seqlen, model) 209 | return get_ptb(nsamples, seed, seqlen, model) 210 | if 'c4' in name: 211 | if 'new' in name: 212 | return get_c4_new(nsamples, seed, seqlen, model) 213 | if 'full' in name: 214 | return get_c4_full(nsamples, seed, seqlen, model) 215 | return get_c4(nsamples, seed, seqlen, model) 216 | -------------------------------------------------------------------------------- /lean_quantizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | 5 | import numpy as np 6 | from sklearn.cluster import KMeans 7 | from multiprocessing import Pool 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn as nn 12 | import transformers 13 | 14 | from quant import * 15 | 16 | 17 | DEBUG = False 18 | 19 | torch.backends.cuda.matmul.allow_tf32 = False 20 | torch.backends.cudnn.allow_tf32 = False 21 | 22 | def kmeans_fit(row_data): 23 | weights_np, sample_weight, n_cluster, random_seed = row_data 24 | kmeans = KMeans( 25 | n_clusters=n_cluster, 26 | init=np.linspace(weights_np.min(), weights_np.max(), num=n_cluster)[:, None] if n_cluster <= 8 else 'k-means++', 27 | n_init='auto', 28 | random_state=random_seed, 29 | max_iter=100, 30 | tol=1e-6, 31 | ).fit(weights_np, sample_weight=sample_weight) 32 | return kmeans.cluster_centers_.reshape(-1) 33 | 34 | pool = Pool(len(os.sched_getaffinity(0))) 35 | 36 | class LeanQuant: 37 | 38 | def __init__(self, layer): 39 | self.layer = layer 40 | self.dev = self.layer.weight.device 41 | W = layer.weight.data.clone() 42 | if isinstance(self.layer, nn.Conv2d): 43 | W = W.flatten(1) 44 | if isinstance(self.layer, transformers.Conv1D): 45 | W = W.t() 46 | self.rows = W.shape[0] 47 | self.columns = W.shape[1] 48 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 49 | self.nsamples = 0 50 | self.lut = None 51 | 52 | def add_batch(self, inp, out): 53 | if DEBUG: 54 | self.inp1 = inp 55 | self.out1 = out 56 | if len(inp.shape) == 2: 57 | inp = inp.unsqueeze(0) 58 | tmp = inp.shape[0] 59 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 60 | if len(inp.shape) == 3: 61 | inp = inp.reshape((-1, inp.shape[-1])) 62 | inp = inp.t() 63 | if isinstance(self.layer, nn.Conv2d): 64 | unfold = nn.Unfold( 65 | self.layer.kernel_size, 66 | dilation=self.layer.dilation, 67 | padding=self.layer.padding, 68 | stride=self.layer.stride 69 | ) 70 | inp = unfold(inp) 71 | inp = inp.permute([1, 0, 2]) 72 | inp = inp.flatten(1) 73 | self.H *= self.nsamples / (self.nsamples + tmp) 74 | self.nsamples += tmp 75 | # inp = inp.float() 76 | inp = math.sqrt(2 / self.nsamples) * inp.float() 77 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 78 | self.H 
+= inp.matmul(inp.t()) 79 | 80 | def fasterquant( 81 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False, args=None, 82 | ): 83 | W = self.layer.weight.data.clone() 84 | if isinstance(self.layer, nn.Conv2d): 85 | W = W.flatten(1) 86 | if isinstance(self.layer, transformers.Conv1D): 87 | W = W.t() 88 | W = W.float() 89 | 90 | tick = time.time() 91 | 92 | if not self.quantizer.ready(): 93 | self.quantizer.find_params(W, weight=True) 94 | 95 | H = self.H 96 | del self.H 97 | dead = torch.diag(H) == 0 98 | H[dead, dead] = 1 99 | W[:, dead] = 0 100 | 101 | if static_groups: 102 | import copy 103 | groups = [] 104 | for i in range(0, self.columns, groupsize): 105 | quantizer = copy.deepcopy(self.quantizer) 106 | quantizer.find_params(W[:, i:(i + groupsize)], weight=True) 107 | groups.append(quantizer) 108 | 109 | if H.shape[0] >= args.offload_threshold: 110 | secondary_device = torch.device('cuda:1') 111 | H = H.to(secondary_device) 112 | 113 | if actorder: 114 | perm_H = torch.argsort(torch.diag(H), descending=True) 115 | perm = perm_H.to(W.device) 116 | W = W[:, perm] 117 | H = H[perm_H][:, perm_H] 118 | invperm = torch.argsort(perm) 119 | 120 | damp = percdamp * torch.mean(torch.diag(H)) 121 | diag = torch.arange(self.columns, device=H.device) 122 | H[diag, diag] += damp 123 | 124 | H = torch.linalg.cholesky(H) 125 | H = torch.cholesky_inverse(H) 126 | H = torch.linalg.cholesky(H, upper=True) 127 | 128 | if H.shape[0] >= args.offload_threshold: 129 | H = H.to(self.dev) 130 | 131 | Losses = torch.zeros_like(W) 132 | Q = torch.zeros_like(W) 133 | Q_codes = Q.to(torch.uint8).cpu() 134 | Hinv = H 135 | torch.cuda.empty_cache() 136 | 137 | if isinstance(args.exponent, float): 138 | kmeans_tasks = [] 139 | W_np = W.cpu().numpy() 140 | Hinv_diagonal_np = (torch.diagonal(Hinv) ** (-args.exponent)).cpu().numpy() 141 | for j in range(W_np.shape[0]): 142 | kmeans_tasks.append((W_np[j, :, None], Hinv_diagonal_np, 2 ** args.wbits, args.kmeans_seed)) 143 | kmeans_results = list(tqdm(pool.imap(kmeans_fit, kmeans_tasks), total=len(kmeans_tasks))) 144 | centroids = torch.from_numpy(np.stack(kmeans_results)).reshape(W.shape[0], 2 ** args.wbits).to(W.device) 145 | else: 146 | centroids = None 147 | 148 | for i1 in range(0, self.columns, blocksize): 149 | i2 = min(i1 + blocksize, self.columns) 150 | count = i2 - i1 151 | 152 | W1 = W[:, i1:i2].clone() 153 | Q1 = torch.zeros_like(W1) 154 | Err1 = torch.zeros_like(W1) 155 | Losses1 = torch.zeros_like(W1) 156 | Hinv1 = Hinv[i1:i2, i1:i2] 157 | 158 | for i in range(count): 159 | w = W1[:, i] 160 | d = Hinv1[i, i] 161 | 162 | if groupsize != -1: 163 | if not static_groups: 164 | if (i1 + i) % groupsize == 0: 165 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 166 | else: 167 | idx = i1 + i 168 | if actorder: 169 | idx = perm[idx] 170 | self.quantizer = groups[idx // groupsize] 171 | 172 | if isinstance(centroids, torch.Tensor): 173 | codes = torch.argmin((centroids - w[:, None]).abs(), dim=1, keepdim=True) 174 | Q_codes[:, i1+i] = codes.flatten().to(torch.uint8).cpu() 175 | q = torch.gather(centroids, 1, codes).flatten() 176 | else: 177 | q = quantize( 178 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 179 | ).flatten() 180 | Q1[:, i] = q 181 | Losses1[:, i] = (w - q) ** 2 / d ** 2 182 | 183 | err1 = (w - q) / d 184 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 185 | Err1[:, i] = err1 186 | 187 | Q[:, i1:i2] = Q1 188 | Losses[:, i1:i2] = Losses1 / 
2 189 | 190 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 191 | 192 | if DEBUG: 193 | self.layer.weight.data[:, :i2] = Q[:, :i2] 194 | self.layer.weight.data[:, i2:] = W[:, i2:] 195 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 196 | print(torch.sum(Losses)) 197 | 198 | torch.cuda.synchronize() 199 | print('time %.2f' % (time.time() - tick)) 200 | print('error', torch.sum(Losses).item()) 201 | 202 | if actorder: 203 | Q = Q[:, invperm] 204 | Q_codes = Q_codes[:, invperm.cpu()] 205 | 206 | if isinstance(args.save_path, str) and isinstance(centroids, torch.Tensor): 207 | nrows, ncols = Q_codes.shape 208 | idx = torch.arange(0, ncols, 2)[None, :].repeat(nrows, 1).to(Q_codes.device) 209 | self.quantized_codes = torch.bitwise_or(torch.bitwise_left_shift(Q_codes.gather(1, idx), 4), Q_codes.gather(1, idx+1)) 210 | self.quant_grid = centroids.cpu() 211 | 212 | if isinstance(self.layer, transformers.Conv1D): 213 | Q = Q.t() 214 | print('norm of difference', torch.norm(self.layer.weight.data - Q).item()) 215 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 216 | if DEBUG: 217 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 218 | 219 | def free(self): 220 | if DEBUG: 221 | self.inp1 = None 222 | self.out1 = None 223 | self.H = None 224 | self.Losses = None 225 | self.Trace = None 226 | torch.cuda.empty_cache() 227 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OMP_NUM_THREADS"] = "1" # this is necessary to parallelize the kmeans 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoModelForCausalLM 8 | from safetensors.torch import save_file 9 | 10 | from lean_quantizer import * 11 | from modelutils import * 12 | from quant import * 13 | from leanquant import replace_with_quantizers 14 | 15 | 16 | def get_llama(model): 17 | import torch 18 | def skip(*args, **kwargs): 19 | pass 20 | torch.nn.init.kaiming_uniform_ = skip 21 | torch.nn.init.uniform_ = skip 22 | torch.nn.init.normal_ = skip 23 | model = AutoModelForCausalLM.from_pretrained(model, torch_dtype='auto') 24 | model.seqlen = 2048 25 | return model 26 | 27 | @torch.no_grad() 28 | def llama_sequential(model, dataloader, dev): 29 | print('Starting ...') 30 | 31 | use_cache = model.config.use_cache 32 | model.config.use_cache = False 33 | layers = model.model.layers 34 | 35 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 36 | model.model.norm = model.model.norm.to(dev) 37 | if hasattr(model.model, 'rotary_emb'): 38 | model.model.rotary_emb = model.model.rotary_emb.to(dev) 39 | layers[0] = layers[0].to(dev) 40 | 41 | dtype = next(iter(model.parameters())).dtype 42 | inps = torch.zeros( 43 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 44 | ) 45 | cache = {'i': 0, 'attention_mask': None} 46 | 47 | class Catcher(nn.Module): 48 | def __init__(self, module): 49 | super().__init__() 50 | self.module = module 51 | def forward(self, inp, **kwargs): 52 | inps[cache['i']] = inp 53 | cache['i'] += 1 54 | cache['attention_mask'] = kwargs['attention_mask'] 55 | cache['position_ids'] = kwargs['position_ids'] 56 | raise ValueError 57 | layers[0] = Catcher(layers[0]) 58 | for batch in dataloader: 59 | try: 60 | model(batch[0].to(dev)) 61 | except ValueError: 62 | pass 63 | layers[0] = layers[0].module 64 | 65 | layers[0] = layers[0].cpu() 66 | 
model.model.embed_tokens = model.model.embed_tokens.cpu() 67 | model.model.norm = model.model.norm.cpu() 68 | torch.cuda.empty_cache() 69 | 70 | outs = torch.zeros_like(inps) 71 | attention_mask = cache['attention_mask'] 72 | position_ids = cache['position_ids'] 73 | 74 | print('Ready.') 75 | 76 | for i in range(args.n_layers if isinstance(args.n_layers, int) else len(layers)): 77 | layer = layers[i].to(dev) 78 | quantizers = {} 79 | full = find_layers(layer) 80 | 81 | if args.true_sequential: 82 | sequential = [ 83 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 84 | ['self_attn.o_proj'], 85 | ['mlp.up_proj', 'mlp.gate_proj'], 86 | ['mlp.down_proj'] 87 | ] 88 | else: 89 | sequential = [list(full.keys())] 90 | 91 | for names in sequential: 92 | subset = {n: full[n] for n in names} 93 | 94 | leanquant = {} 95 | for name in subset: 96 | leanquant[name] = LeanQuant(subset[name]) 97 | leanquant[name].quantizer = Quantizer() 98 | leanquant[name].quantizer.configure( 99 | args.wbits, perchannel=True, sym=args.sym, mse=False 100 | ) 101 | 102 | def add_batch(name): 103 | def tmp(_, inp, out): 104 | leanquant[name].add_batch(inp[0].data, out.data) 105 | return tmp 106 | handles = [] 107 | for name in subset: 108 | handles.append(subset[name].register_forward_hook(add_batch(name))) 109 | for j in range(args.nsamples): 110 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 111 | for h in handles: 112 | h.remove() 113 | 114 | for name in subset: 115 | print(i, name) 116 | print('Quantizing ...') 117 | 118 | leanquant[name].fasterquant( 119 | blocksize=args.block_size, percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups, args=args, 120 | ) 121 | if isinstance(args.exponent, float): 122 | quantizers[name] = (leanquant[name].quant_grid, leanquant[name].quantized_codes) 123 | leanquant[name].free() 124 | 125 | for j in range(args.nsamples): 126 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 127 | 128 | layer = layer.cpu() 129 | replace_with_quantizers(layer, quantizers) 130 | layers[i] = layer 131 | del layer 132 | del leanquant 133 | torch.cuda.empty_cache() 134 | 135 | inps, outs = outs, inps 136 | 137 | if isinstance(args.save_path, str): 138 | save_file(model.state_dict(), args.save_path) 139 | 140 | model.config.use_cache = use_cache 141 | 142 | return quantizers 143 | 144 | @torch.no_grad() 145 | def llama_eval(model, testenc, dev): 146 | print('Evaluating ...') 147 | 148 | testenc = testenc.input_ids 149 | nsamples = testenc.numel() // model.seqlen 150 | 151 | use_cache = model.config.use_cache 152 | model.config.use_cache = False 153 | layers = model.model.layers 154 | 155 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 156 | model.model.norm = model.model.norm.to(dev) 157 | if hasattr(model.model, 'rotary_emb'): 158 | model.model.rotary_emb = model.model.rotary_emb.to(dev) 159 | layers[0] = layers[0].to(dev) 160 | 161 | dtype = next(iter(model.parameters())).dtype 162 | inps = torch.zeros( 163 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 164 | ) 165 | cache = {'i': 0, 'attention_mask': None} 166 | 167 | class Catcher(nn.Module): 168 | def __init__(self, module): 169 | super().__init__() 170 | self.module = module 171 | def forward(self, inp, **kwargs): 172 | inps[cache['i']] = inp 173 | cache['i'] += 1 174 | cache['attention_mask'] = kwargs['attention_mask'] 175 | 
cache['position_ids'] = kwargs['position_ids'] 176 | raise ValueError 177 | layers[0] = Catcher(layers[0]) 178 | for i in range(nsamples): 179 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 180 | try: 181 | model(batch) 182 | except ValueError: 183 | pass 184 | layers[0] = layers[0].module 185 | 186 | layers[0] = layers[0].cpu() 187 | model.model.embed_tokens = model.model.embed_tokens.cpu() 188 | torch.cuda.empty_cache() 189 | 190 | outs = torch.zeros_like(inps) 191 | attention_mask = cache['attention_mask'] 192 | position_ids = cache['position_ids'] 193 | 194 | for i in range(len(layers)): 195 | print(i) 196 | layer = layers[i].to(dev) 197 | 198 | if args.nearest: 199 | subset = find_layers(layer) 200 | for name in subset: 201 | quantizer = Quantizer() 202 | quantizer.configure( 203 | args.wbits, perchannel=True, sym=False, mse=False 204 | ) 205 | W = subset[name].weight.data 206 | quantizer.find_params(W, weight=True) 207 | subset[name].weight.data = quantize( 208 | W, quantizer.scale, quantizer.zero, quantizer.maxq 209 | ).to(next(iter(layer.parameters())).dtype) 210 | 211 | for j in range(nsamples): 212 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 213 | layers[i] = layer.cpu() 214 | del layer 215 | torch.cuda.empty_cache() 216 | inps, outs = outs, inps 217 | 218 | if model.model.norm is not None: 219 | model.model.norm = model.model.norm.to(dev) 220 | model.lm_head = model.lm_head.to(dev) 221 | 222 | testenc = testenc.to(dev) 223 | nlls = [] 224 | for i in range(nsamples): 225 | hidden_states = inps[i].unsqueeze(0) 226 | if model.model.norm is not None: 227 | hidden_states = model.model.norm(hidden_states) 228 | lm_logits = model.lm_head(hidden_states) 229 | shift_logits = lm_logits[:, :-1, :].contiguous() 230 | shift_labels = testenc[ 231 | :, (i * model.seqlen):((i + 1) * model.seqlen) 232 | ][:, 1:] 233 | loss_fct = nn.CrossEntropyLoss() 234 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 235 | neg_log_likelihood = loss.float() * model.seqlen 236 | nlls.append(neg_log_likelihood) 237 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 238 | print(ppl.item()) 239 | 240 | model.config.use_cache = use_cache 241 | 242 | 243 | if __name__ == '__main__': 244 | import argparse 245 | from datautils import * 246 | 247 | parser = argparse.ArgumentParser() 248 | 249 | parser.add_argument( 250 | 'model', type=str, 251 | help='LlaMa model to load; pass location of hugginface converted checkpoint.' 252 | ) 253 | parser.add_argument( 254 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4', 'ptb-new', 'c4-new', 'c4-full'], 255 | help='Where to extract calibration data from.' 256 | ) 257 | parser.add_argument( 258 | '--seed', 259 | type=int, default=0, help='Seed for sampling the calibration data.' 260 | ) 261 | parser.add_argument( 262 | '--nsamples', type=int, default=128, 263 | help='Number of calibration data samples.' 264 | ) 265 | parser.add_argument( 266 | '--percdamp', type=float, default=.01, 267 | help='Percent of the average Hessian diagonal to use for dampening.' 268 | ) 269 | parser.add_argument( 270 | '--nearest', action='store_true', 271 | help='Whether to run the RTN baseline.' 272 | ) 273 | parser.add_argument( 274 | '--wbits', type=int, default=16, choices=[2, 3, 4, 16], 275 | help='#bits to use for quantization; use 16 for evaluating base model.' 
276 | ) 277 | parser.add_argument( 278 | '--groupsize', type=int, default=-1, 279 | help='Groupsize to use for quantization; default uses full row.' 280 | ) 281 | parser.add_argument( 282 | '--sym', action='store_true', 283 | help='Whether to perform symmetric quantization.' 284 | ) 285 | parser.add_argument( 286 | '--new-eval', action='store_true', 287 | help='Whether to use the new PTB and C4 eval.' 288 | ) 289 | parser.add_argument( 290 | '--act-order', action='store_true', 291 | help='Whether to apply the activation order GPTQ heuristic' 292 | ) 293 | parser.add_argument( 294 | '--true-sequential', action='store_true', 295 | help='Whether to run in true sequential model.' 296 | ) 297 | parser.add_argument( 298 | '--static-groups', action='store_true', 299 | help='Whether to use static groups; recommended when using `--actorder` for more efficient inference.' 300 | ) 301 | parser.add_argument( 302 | '--n_layers', type=int, default=None, 303 | ) 304 | parser.add_argument( 305 | '--block_size', type=int, default=128, 306 | ) 307 | parser.add_argument( 308 | '--exponent', type=float, default=None, 309 | ) 310 | parser.add_argument( 311 | '--kmeans_seed', type=int, default=0, 312 | ) 313 | parser.add_argument( 314 | '--offload_threshold', type=int, default=53248, 315 | ) 316 | parser.add_argument( 317 | '--save_path', type=str, default=None, 318 | ) 319 | 320 | args = parser.parse_args() 321 | print(args) 322 | 323 | model = get_llama(args.model) 324 | model.eval() 325 | 326 | dataloader, testloader = get_loaders( 327 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 328 | ) 329 | 330 | if args.wbits < 16 and not args.nearest: 331 | tick = time.time() 332 | quantizers = llama_sequential(model, dataloader, DEV) 333 | print(f'quant_time={time.time() - tick}') 334 | 335 | datasets = ['wikitext2', 'ptb', 'c4'] 336 | if args.new_eval: 337 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 338 | for dataset in datasets: 339 | dataloader, testloader = get_loaders( 340 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 341 | ) 342 | print(dataset) 343 | llama_eval(model, testloader, DEV) 344 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
LeanQuant
2 | 
3 | A lean and mean llama.
4 | 
5 | 
6 | ICLR 2025 | Accurate and Scalable LLM Quantization with Loss-error-aware Grid
7 | 
8 | 
9 | 🚀 Quantizes a 70B model on a single 24GB GPU in 4 hours
10 | ⚡ Quantizes a 405B model on two 48GB GPUs in 24 hours
11 | 
12 | 13 | --- 14 | 15 | ## 🔍 What is LeanQuant? 16 | 17 | **LeanQuant** is an efficient large language model (LLM) quantization framework that minimizes quality loss while maximizing computational and memory efficiency. It introduces a **loss-error-aware quantization grid** that preserves outliers in the inverse Hessian—achieving superior model quality without extra storage or inference overhead. 18 | 19 | 📄 Read the full paper: [arXiv](https://arxiv.org/abs/2407.10032) 20 | 21 | --- 22 | 23 | ## 🚀 Why LeanQuant? 24 | 25 | ✅ **Scalable Quantization** – Handles ultra-large models with one or two GPUs 26 | 27 | ✅ **Efficient Inference** – Optimized 4-bit CUDA kernels for fast and memory-efficient execution 28 | 29 | ✅ **High Accuracy** – Compares favorably against state-of-the-art quantization methods 30 | 31 | ✅ **Versatile** – Supports non-uniform and affine quantization formats 32 | 33 | ✅ **Minimal Dependencies** – Easy installation and setup 34 | 35 | ✅ **Broad Compatibility** – Works on most CUDA GPUs, and supports multi-GPU distributed inference 36 | 37 | --- 38 | 39 | ## 🛠️ Quick Start 40 | 41 | 1. Make sure you have a Linux environment, a CUDA-enabled GPU, and Python and PyTorch installed. 42 | 2. Install our pip package. 43 | ```bash 44 | # For CUDA 11.x 45 | pip install leanquant[cuda11] 46 | 47 | # For CUDA 12.x 48 | pip install leanquant[cuda12] 49 | ``` 50 | 3. Download a LeanQuant model from the Model Zoo below or from [our HuggingFace page](https://huggingface.co/LeanQuant). Each downloaded model is a `.safetensors` file. For example, download the 4-bit `Llama-3.1-8B-Instruct` from [this link](https://huggingface.co/LeanQuant/Llama-3.1-8B-Instruct-nu-4bit/resolve/main/model.safetensors) or with the command `wget https://huggingface.co/LeanQuant/Llama-3.1-8B-Instruct-nu-4bit/resolve/main/model.safetensors`. 51 | 4. 
The model can now be loaded for inference using the following script: 52 | ```python 53 | from leanquant import LeanQuantModelForCausalLM 54 | 55 | model = LeanQuantModelForCausalLM.from_pretrained( 56 | "", 57 | "", 58 | bits=, 59 | device_map="auto" 60 | ) 61 | ``` 62 | 63 | **A Complete Example:** The following script shows how to run inference with a 4-bit `Llama-3.1-8B-Instruct` (with the model downloaded to `./model.safetensors`): 64 | ```python 65 | import torch 66 | from leanquant import LeanQuantModelForCausalLM 67 | from transformers import AutoTokenizer 68 | 69 | ### Load model and tokenizer 70 | base_model_name = "meta-llama/Llama-3.1-8B-Instruct" 71 | model = LeanQuantModelForCausalLM.from_pretrained( 72 | base_model_name, 73 | "./model.safetensors", 74 | bits=4, 75 | device_map="auto" 76 | ) 77 | model.eval() 78 | tokenizer = AutoTokenizer.from_pretrained(base_model_name) 79 | 80 | ### Tokenize prompt 81 | prompt = [ 82 | {"role": "system", "content": "You are a helpful assistant, that responds as a pirate."}, 83 | {"role": "user", "content": "What is quantization for deep learning models?"}, 84 | ] 85 | inputs = tokenizer.apply_chat_template( 86 | prompt, 87 | tokenize=True, 88 | add_generation_prompt=True, 89 | return_tensors="pt", 90 | return_dict=True, 91 | ).to(model.device) 92 | 93 | ### Run generation and decode generated tokens 94 | with torch.no_grad(): 95 | output = model.generate(**inputs, do_sample=True, max_new_tokens=256) 96 | 97 | generated_text = tokenizer.decode(output[0], skip_special_tokens=False) 98 | print(generated_text) 99 | ``` 100 | 101 | ## 🦁 Model Zoo 102 | 103 | Explore our collection of pre-quantized models for efficient deployment. 104 | 105 | | Base Model Name | Quantized Bits| Download Link | 106 | |-------------------------------------------------|---------------|----------------------------------------------------------------| 107 | | meta-llama/Meta-Llama-3-8B | 4-bit | [Download](https://huggingface.co/LeanQuant/Meta-Llama-3-8B-nu-4bit/resolve/main/model.safetensors) | 108 | | meta-llama/Meta-Llama-3-8B | 3-bit | [Download](https://huggingface.co/LeanQuant/Meta-Llama-3-8B-nu-3bit/resolve/main/model.safetensors) | 109 | | meta-llama/Meta-Llama-3-8B | 2-bit | [Download](https://huggingface.co/LeanQuant/Meta-Llama-3-8B-nu-2bit/resolve/main/model.safetensors) | 110 | | meta-llama/Llama-2-7b-hf | 4-bit | [Download](https://huggingface.co/LeanQuant/Llama-2-7b-nu-4bit/resolve/main/model.safetensors) | 111 | | meta-llama/Llama-2-7b-hf | 3-bit | [Download](https://huggingface.co/LeanQuant/Llama-2-7b-nu-3bit/resolve/main/model.safetensors) | 112 | | meta-llama/Llama-2-7b-hf | 2-bit | [Download](https://huggingface.co/LeanQuant/Llama-2-7b-nu-2bit/resolve/main/model.safetensors) | 113 | | meta-llama/Llama-2-13b-hf | 4-bit | [Download](https://huggingface.co/LeanQuant/Llama-2-13b-nu-4bit/resolve/main/model.safetensors) | 114 | | meta-llama/Llama-2-13b-hf | 3-bit | [Download](https://huggingface.co/LeanQuant/Llama-2-13b-nu-3bit/resolve/main/model.safetensors) | 115 | | meta-llama/Llama-2-13b-hf | 2-bit | [Download](https://huggingface.co/LeanQuant/Llama-2-13b-nu-2bit/resolve/main/model.safetensors) | 116 | | mistralai/Mistral-7B-v0.1 | 4-bit | [Download](https://huggingface.co/LeanQuant/Mistral-7B-v0.1-nu-4bit/resolve/main/model.safetensors) | 117 | | mistralai/Mistral-7B-v0.1 | 3-bit | [Download](https://huggingface.co/LeanQuant/Mistral-7B-v0.1-nu-3bit/resolve/main/model.safetensors) | 118 | | mistralai/Mistral-7B-v0.1 | 2-bit | 
[Download](https://huggingface.co/LeanQuant/Mistral-7B-v0.1-nu-2bit/resolve/main/model.safetensors) | 119 | | huggyllama/llama-13b | 4-bit | [Download](https://huggingface.co/LeanQuant/llama-13b-nu-4bit/resolve/main/model.safetensors) | 120 | | huggyllama/llama-13b | 3-bit | [Download](https://huggingface.co/LeanQuant/llama-13b-nu-3bit/resolve/main/model.safetensors) | 121 | | huggyllama/llama-13b | 2-bit | [Download](https://huggingface.co/LeanQuant/llama-13b-nu-2bit/resolve/main/model.safetensors) | 122 | | meta-llama/Meta-Llama-3-8B-Instruct | 4-bit | [Download](https://huggingface.co/LeanQuant/Meta-Llama-3-8B-Instruct-nu-4bit/resolve/main/model.safetensors) | 123 | | meta-llama/Llama-3.1-8B | 4-bit | [Download](https://huggingface.co/LeanQuant/Llama-3.1-8B-nu-4bit/resolve/main/model.safetensors) | 124 | | meta-llama/Llama-3.1-8B-Instruct | 4-bit | [Download](https://huggingface.co/LeanQuant/Llama-3.1-8B-Instruct-nu-4bit/resolve/main/model.safetensors) | 125 | | meta-llama/Llama-3.1-70B | 4-bit | [Download](https://huggingface.co/LeanQuant/Llama-3.1-70B-nu-4bit/resolve/main/model.safetensors) | 126 | | meta-llama/Llama-3.3-70B-Instruct | 4-bit | [Download](https://huggingface.co/LeanQuant/Llama-3.3-70B-Instruct-nu-4bit/resolve/main/model.safetensors) | 127 | 128 | 🚀 More models coming soon! 129 | 130 | ## 📌 How to Quantize and Evaluate a Model 131 | 132 | Follow these steps to quantize and evaluate a large language model. 133 | 134 | ### Requirements 135 | 136 | - At least **one CUDA-enabled GPU** is required for quantization and evaluation. 137 | - A **Linux environment** is recommended. 138 | 139 | ### Setup 140 | 141 | 1. **Clone the Repository** 142 | ```bash 143 | git clone https://github.com/LeanModels/LeanQuant.git 144 | cd LeanQuant 145 | ``` 146 | 147 | 2. **[Optional] Create a Conda Environment** 148 | ```bash 149 | conda create -n leanquant python=3.10 150 | conda activate leanquant 151 | ``` 152 | 153 | 3. **Install Dependencies** 154 | 155 | - Install [PyTorch](https://pytorch.org/get-started/locally/). 156 | - Install [CuPy](https://docs.cupy.dev/en/stable/install.html) based on your CUDA version. 157 | ```bash 158 | # For CUDA 11.x 159 | pip install cupy-cuda11x 160 | 161 | # For CUDA 12.x 162 | pip install cupy-cuda12x 163 | ``` 164 | 4. **Install Additional Requirements** 165 | 166 | ```bash 167 | pip install -r requirements.txt 168 | ``` 169 | 170 | ### Quantizing Models 171 | 172 | We currently support the Llama and Mistral family models (non-VLM, non-MOE). To quantize a model using **LeanQuant**, run the following command: 173 | 174 | ```bash 175 | python llama.py \ 176 | --new-eval \ 177 | --wbits 4 \ 178 | --nsamples 128 \ 179 | --true-sequential --act-order \ 180 | --percdamp 0.1 \ 181 | --exponent 4 \ 182 | --save_path .safetensors 183 | ``` 184 | 185 | **Parameter Explanation:** 186 | 187 | | Parameter | Description | 188 | |--------------------------------|-------------| 189 | | `` | The HuggingFace model name or local model path to quantize. Example: `meta-llama/Llama-3.1-8B-Instruct`. | 190 | | `` | Calibration dataset for quantization. Choices: `wikitext2`, `ptb`, `c4`, `c4-new` (recommended: `c4-new`). | 191 | | `--new-eval` | Enables new evaluation mode for perplexity testing. | 192 | | `--wbits` | Bit-width for quantization. Choices: `4`, `3`, or `2`. | 193 | | `--nsamples` | Number of calibration samples. Recommended: `128`, `256`, or `512`. | 194 | | `--true-sequential` & `--act-order` | Improves quantized model quality. Recommended to enable. 
| 195 | | `--percdamp` | Dampening applied to the Hessian. Recommended: `0.1` or `0.01`. | 196 | | `--exponent` | Strength parameter for preserving outliers in quantization (`p` in the paper). Recommended: `3` or `4`. | 197 | | `--save_path` | Path and filename to save the quantized model. | 198 | 199 | **Example:** 200 | 201 | To quantize `meta-llama/Llama-3.1-8B-Instruct` to 4-bit precision, run: 202 | 203 | ```bash 204 | python llama.py meta-llama/Llama-3.1-8B-Instruct c4-new \ 205 | --new-eval \ 206 | --wbits 4 \ 207 | --nsamples 128 \ 208 | --true-sequential --act-order \ 209 | --percdamp 0.1 \ 210 | --exponent 4 \ 211 | --save_path Llama-3.1-8B-Instruct-4bit.safetensors 212 | ``` 213 | 214 | ### Quantizing Very Large Models 215 | 216 | LeanQuant enables efficient quantization of Llama-3.1-405B using either two 48GB GPUs or a single 80GB GPU. Use the following command to quantize the 405B model. 217 | ```bash 218 | PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True;CUDA_VISIBLE_DEVICES=0,1 python llama.py meta-llama/Llama-3.1-405B-Instruct \ 219 | c4-new --new-eval \ 220 | --wbits 4 --nsamples 64 \ 221 | --true-sequential --act-order \ 222 | --percdamp 0.1 \ 223 | --exponent 4.0 \ 224 | --offload_threshold 53248 \ 225 | --save_path Llama-3.1-405B-Instruct-4bit.safetensors 226 | ``` 227 | The `--offload_threshold` argument helps prevent out-of-memory errors by offloading Hessian matrix computation to the second GPU. `--offload_threshold 53248` offloads Hessian computation to the second GPU if the matrix dimension is greater than or equal to 53248. If your GPU has enough memory (80GB or above), you can disable Hessian offloading and use a single GPU with `--offload_threshold 1000000`. 228 | 229 | ### Evaluating Quantized Models 230 | 231 | To evaluate a quantized model using **LeanQuant**, run the following command: 232 | 233 | ```bash 234 | python eval_quantized.py \ 235 | --base_model_name_or_path \ 236 | --leanquant_path \ 237 | --bits 4 \ 238 | --tasks mmlu ai2_arc lambada hellaswag winogrande piqa \ 239 | --eval_batch_size 4 240 | ``` 241 | 242 | **Parameter Explanation:** 243 | 244 | | Parameter | Description | 245 | |--------------------------------|-------------| 246 | | `--base_model_name_or_path` | Name or path of the original Hugging Face model that was quantized. | 247 | | `--leanquant_path` | Path to the `.safetensors` file of the quantized model. | 248 | | `--bits` | Bit-width of the quantized model. Choices: `4`, `3`, `2`. | 249 | | `--tasks` | Benchmark tasks from `lm-eval`. | 250 | | `--eval_batch_size` | Batch size for evaluation. Adjust based on available GPU memory. | 251 | 252 | **Example:** 253 | 254 | To evaluate a 4-bit quantized `meta-llama/Llama-3.1-8B-Instruct` on `mmlu` and `hellaswag`, run: 255 | 256 | ```bash 257 | python eval_quantized.py \ 258 | --base_model_name_or_path meta-llama/Llama-3.1-8B-Instruct \ 259 | --leanquant_path Llama-3.1-8B-Instruct-4bit.safetensors \ 260 | --bits 4 \ 261 | --tasks mmlu hellaswag \ 262 | --eval_batch_size 4 263 | ``` 264 | 265 | # Acknowledgements 266 | 267 | This code repository is based on [GPTQ](https://github.com/IST-DASLab/gptq). We thank the authors for their wonderful work. 
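
Where LeanQuant departs from GPTQ is the quantization grid: rather than a uniform grid, every output row of a weight matrix gets its own set of 2^bits k-means centroids, with per-column sample weights taken from the inverse-Hessian diagonal raised to the power `-p` (the `--exponent` option above). The sketch below is a simplified illustration of this idea, not the exact pipeline: the real implementation (`kmeans_fit` and `fasterquant` in `lean_quantizer.py`) derives the weights from a dampened, Cholesky-factored Hessian and applies the grid inside GPTQ's error-compensation loop. The helper name `loss_error_aware_grid` and the random stand-in Hessian are ours, for illustration only.

```python
import numpy as np
import torch
from sklearn.cluster import KMeans

def loss_error_aware_grid(W, hinv_diag, bits=4, p=4.0):
    """Toy per-row quantization grid: k-means on each output row of W,
    weighted by the inverse-Hessian diagonal raised to the power -p."""
    n_centroids = 2 ** bits
    # larger weight where quantization error hurts more (small inverse-Hessian diagonal)
    sample_weight = (hinv_diag ** (-p)).cpu().numpy()
    grids = []
    for row in W.cpu().numpy():
        km = KMeans(
            n_clusters=n_centroids,
            init=np.linspace(row.min(), row.max(), num=n_centroids)[:, None],
            n_init=1, max_iter=100, tol=1e-6, random_state=0,
        ).fit(row[:, None], sample_weight=sample_weight)
        grids.append(np.sort(km.cluster_centers_.reshape(-1)))
    return torch.from_numpy(np.stack(grids))  # shape: (out_features, 2 ** bits)

# Toy usage with a random layer and a stand-in Hessian built from random "activations".
W = torch.randn(8, 64)
X = torch.randn(256, 64)
H = 2.0 / X.shape[0] * X.T @ X + 0.01 * torch.eye(64)   # dampened proxy Hessian
hinv_diag = torch.diagonal(torch.linalg.inv(H))
grid = loss_error_aware_grid(W, hinv_diag, bits=4, p=4.0)
print(grid.shape)  # torch.Size([8, 16])
```

At inference time, only the per-row grid (16 half-precision values per row at 4 bits) and the packed 4-bit codes are stored; the `Sub4BitLinear` kernel in `leanquant/leanquant_utils.py` gathers them back into a full weight matrix on the fly.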
268 | 269 | # Citation 270 | 271 | If you found our work useful or interesting, please kindly cite us: 272 | ``` 273 | @inproceedings{ 274 | zhang2025leanquant, 275 | title={LeanQuant: Accurate and Scalable Large Language Model Quantization with Loss-error-aware Grid}, 276 | author={Tianyi Zhang and Anshumali Shrivastava}, 277 | booktitle={The Thirteenth International Conference on Learning Representations}, 278 | year={2025}, 279 | url={https://openreview.net/forum?id=ISqx8giekS} 280 | } 281 | ``` --------------------------------------------------------------------------------