├── .gitignore ├── TritonBench.png ├── setup.py ├── pyproject.toml ├── CHANGELOG.md ├── src └── gptq_triton │ ├── utils.py │ ├── __init__.py │ ├── fused_attention.py │ ├── custom_autotune.py │ ├── quant_linear.py │ └── fused_mlp.py ├── setup.cfg ├── datautils.py ├── generate.py ├── ppl.py ├── benchmark_generate.py ├── README.md ├── gptq.py ├── LICENSE ├── quantize.py ├── Verify.ipynb └── original_quant.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.json 3 | *.jsonl 4 | /build 5 | /dist 6 | *.egg-info -------------------------------------------------------------------------------- /TritonBench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fpgaminer/GPTQ-triton/HEAD/TritonBench.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import setuptools 3 | 4 | if __name__ == "__main__": 5 | setuptools.setup() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 48", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [0.0.4] - 2023-04-29 4 | 5 | - Fixed an interaction between the fused QKV projection and the key-value cache that caused excessive memory usage. 6 | 7 | 8 | ## [0.0.3] - 2023-04-28 9 | 10 | - Disabled cache in `ppl.py`; isn't used and saves memory. 11 | - Added more benchmarks to README. 12 | - Fixed bug in `generate.py`; generated sequence length was not calculated correctly. 13 | 14 | 15 | ## [0.0.2] - 2023-04-19 16 | 17 | - Added support for groupsize. 18 | - Note: fuse_mlp is not recommended for groupsize != -1. It is now disabled automatically during loading if the model has grouping, unless fuse_mlp is explicitly set to True. This is a result of the current kernel implementation being slower than the naive implementation for groupsize != -1. 19 | - Added a warning if `act_order` and `groupsize` are used together. They are not compatible. -------------------------------------------------------------------------------- /src/gptq_triton/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import triton 4 | 5 | 6 | def matmul4_kernel_config_pruner(configs, nargs): 7 | """ 8 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
9 | """ 10 | m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) 11 | n = max(2 ** int(math.ceil(math.log2(nargs['N']))), 16) 12 | k = max(2 ** int(math.ceil(math.log2(nargs['K']))), 16) 13 | 14 | used = set() 15 | for config in configs: 16 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) 17 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) 18 | block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) 19 | group_size_m = config.kwargs['GROUP_SIZE_M'] 20 | 21 | if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: 22 | continue 23 | 24 | used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) 25 | yield triton.Config({'BLOCK_SIZE_M': block_size_m, 'BLOCK_SIZE_N': block_size_n, 'BLOCK_SIZE_K': block_size_k, 'GROUP_SIZE_M': group_size_m}, num_stages=config.num_stages, num_warps=config.num_warps) -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = gptq_triton 3 | version = 0.0.3 4 | author = fpgaminer 5 | author_email = fpgaminer@bitcoin-mining.com 6 | description = Fast GPTQ kernels written in Triton 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown; charset=UTF-8 9 | url = https://github.com/fpgaminer/GPTQ-triton 10 | keywords = gptq, triton, torch, cuda, gpu, quantization, quantize, quantized, inference, deep learning, machine learning 11 | license = Apache License 2.0 12 | license_file = LICENSE 13 | classifiers = 14 | Development Status :: 3 - Alpha 15 | License :: OSI Approved :: Apache Software License 16 | Intended Audience :: Developers 17 | Programming Language :: Python :: 3 18 | Programming Language :: Python :: 3.6 19 | Programming Language :: Python :: 3.7 20 | Programming Language :: Python :: 3.8 21 | Programming Language :: Python :: 3.9 22 | Programming Language :: Python :: 3.10 23 | Topic :: Scientific/Engineering :: Artificial Intelligence 24 | Topic :: Software Development :: Libraries :: Python Modules 25 | 26 | [options] 27 | zip_safe = False 28 | include_package_data = False 29 | package_dir = 30 | = src 31 | packages = find: 32 | python_requires = >=3.6 33 | install_requires = 34 | triton >= 2.0.0 35 | torch >= 2.0.0 36 | transformers 37 | 38 | [options.packages.find] 39 | where = src -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from datasets import load_dataset 4 | 5 | 6 | def get_dataset(dataset_name: str, tokenizer, nsamples: int, seed: int, seqlen: int): 7 | if dataset_name == "wikitext-2": 8 | return get_wikitext2(nsamples, seed, seqlen, tokenizer) 9 | elif dataset_name == 'ptb': 10 | return get_ptb(nsamples, seed, seqlen, tokenizer, jointext='\n\n') 11 | elif dataset_name == 'ptb-new': 12 | return get_ptb(nsamples, seed, seqlen, tokenizer, jointext=' ') 13 | elif dataset_name == 'c4': 14 | return get_c4(nsamples, seed, seqlen, tokenizer) 15 | else: 16 | raise ValueError(f"Unknown dataset {dataset_name}") 17 | 18 | 19 | def get_wikitext2(nsamples: int, seed: int, seqlen: int, tokenizer, jointext: str = '\n\n'): 20 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 21 | 22 | trainenc = tokenizer(jointext.join(traindata['text']), return_tensors='pt') 23 | 24 | rng = random.Random(seed) 25 | 
trainloader = (rng.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) for _ in range(nsamples)) 26 | trainloader = [trainenc.input_ids[:, i:i+seqlen] for i in trainloader] 27 | 28 | return trainloader 29 | 30 | 31 | def get_ptb(nsamples: int, seed: int, seqlen: int, tokenizer, jointext: str): 32 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 33 | 34 | trainenc = tokenizer(jointext.join(traindata['sentence']), return_tensors='pt') 35 | 36 | rng = random.Random(seed) 37 | trainloader = (rng.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) for _ in range(nsamples)) 38 | trainloader = [trainenc.input_ids[:, i:i+seqlen] for i in trainloader] 39 | 40 | return trainloader 41 | 42 | 43 | def get_c4(nsamples: int, seed: int, seqlen: int, tokenizer): 44 | # WARNING: Many of the files in the allenai/c4 repo are marked as "Unsafe" by HuggingFace, possibly containing a virus. This particular file is not, and I doubt it's an issue, but worth noting. 45 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 46 | 47 | rng = random.Random(seed) 48 | 49 | trainloader = [] 50 | for _ in range(nsamples): 51 | while True: 52 | i = rng.randint(0, len(traindata) - 1) 53 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 54 | if trainenc.input_ids.shape[1] >= seqlen: 55 | break 56 | 57 | i = rng.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 58 | inp = trainenc.input_ids[:, i:i + seqlen] 59 | trainloader.append(inp) 60 | 61 | return trainloader -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example of how to use the quantized model to generate text. 
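Typical invocation (a sketch using the argparse flags defined below; <path-to-model> is a placeholder for your HF or quantized model directory): ./generate.py --model <path-to-model> --quant --prompt "Once upon a time there was a duck" --temperature 0.6 --top-p 0.6 --repetition-penalty 1.1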
4 | """ 5 | import argparse 6 | import time 7 | 8 | import torch 9 | from gptq_triton import load_quant 10 | from transformers import AutoTokenizer, LlamaForCausalLM 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model', type=str, help='Path to model, either a HuggingFace model or a quantized model') 15 | parser.add_argument('--quant', action='store_true', help='Whether the model is quantized') 16 | parser.add_argument('--prompt', type=str, default='The quick brown fox', help='Prompt to use for generation') 17 | parser.add_argument('--max-length', type=int, default=2048, help='Maximum length of generated text') 18 | parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for generation') 19 | parser.add_argument('--top-k', type=int, default=0, help='Top-k for generation') 20 | parser.add_argument('--top-p', type=float, default=0.0, help='Top-p for generation') 21 | parser.add_argument('--repetition-penalty', type=float, default=1.0, help='Repetition penalty for generation') 22 | 23 | 24 | def main(): 25 | args = parser.parse_args() 26 | 27 | if not args.quant: 28 | model = get_llama(args.model) 29 | model.eval() 30 | model.to('cuda') 31 | else: 32 | model = load_quant(args.model) 33 | model.eval() 34 | model.to('cuda') 35 | 36 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 37 | 38 | encoded_prompt = tokenizer.encode(args.prompt, add_special_tokens=False, return_tensors='pt').to('cuda') 39 | 40 | start_time = time.time() 41 | output_sequences = model.generate( 42 | input_ids=encoded_prompt, 43 | max_length=args.max_length + len(encoded_prompt[0]), 44 | temperature=args.temperature, 45 | top_k=args.top_k, 46 | top_p=args.top_p, 47 | repetition_penalty=args.repetition_penalty, 48 | do_sample=True, 49 | num_return_sequences=1, 50 | ) 51 | end_time = time.time() 52 | 53 | if len(output_sequences.shape) > 2: 54 | output_sequences.squeeze_() 55 | 56 | total_tokens_generated = 0 57 | 58 | for generated_sequence in output_sequences: 59 | generated_sequence = generated_sequence.tolist() 60 | total_tokens_generated += len(generated_sequence) - len(encoded_prompt[0]) 61 | 62 | text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) 63 | 64 | total_sequence = ( 65 | args.prompt + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):] 66 | ) 67 | 68 | print(total_sequence) 69 | 70 | print() 71 | print(f'Generation took {end_time - start_time:.2f} seconds') 72 | print(f'Total tokens generated: {total_tokens_generated}') 73 | print(f'Average generation speed: {total_tokens_generated / (end_time - start_time):.2f} tokens per second') 74 | 75 | 76 | def get_llama(model: str): 77 | """ 78 | Load a pretrained Llama model 79 | """ 80 | def skip(*args, **kwargs): 81 | pass 82 | # NOTE: This is a nasty hack, but it speeds up model building by a huge amount 83 | old_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) 84 | torch.nn.init.kaiming_uniform_ = skip 85 | torch.nn.init.uniform_ = skip 86 | torch.nn.init.normal_ = skip 87 | 88 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 89 | model.seqlen = 2048 90 | 91 | # Restore the old initializers 92 | torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = old_inits 93 | 94 | return model 95 | 96 | 97 | if __name__ == '__main__': 98 | main() -------------------------------------------------------------------------------- /src/gptq_triton/__init__.py: 
-------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import torch 7 | import transformers 8 | from transformers import LlamaConfig, LlamaForCausalLM 9 | 10 | from . import fused_mlp, quant_linear 11 | from .fused_attention import QuantLlamaAttention, make_quant_attn 12 | from .fused_mlp import QuantLlamaMLP, make_fused_mlp 13 | from .quant_linear import QuantLinear, make_quant, triton_matmul4 14 | 15 | 16 | def load_quant(checkpoint: str, warmup_autotune: bool = True, device: Optional[str] = 'cuda', fuse_mlp: Optional[bool] = None): 17 | """ 18 | Load a quantized model from a checkpoint. 19 | Args: 20 | checkpoint: Path to the checkpoint directory. 21 | warmup_autotune: If True, run a warmup autotune pass. Otherwise autotune will run during forward passes. 22 | device: Device to run the model on; needed if warmup_autotune is True. 23 | fuse_mlp: If True, replace the MLP layers with fused versions. If None, will apply fuse_mlp if the model's groupsize is -1, otherwise fuse_mlp will be disabled (it's slower when using grouping). 24 | Returns: 25 | The loaded model. 26 | """ 27 | quant_config = json.load(open(Path(checkpoint) / 'quant_config.json')) 28 | wbits = quant_config['wbits'] 29 | groupsize = quant_config['groupsize'] 30 | 31 | # Load the model config 32 | config = LlamaConfig.from_pretrained(checkpoint) 33 | def noop(*args, **kwargs): 34 | pass 35 | # NOTE: This is a nasty hack, but it speeds up creation of the model by a huge amount. 36 | old_init = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) 37 | torch.nn.init.kaiming_uniform_ = noop 38 | torch.nn.init.uniform_ = noop 39 | torch.nn.init.normal_ = noop 40 | 41 | # Build the model 42 | # TODO: Is this needed? 
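# (The half default dtype below makes the freshly constructed, uninitialized LlamaForCausalLM allocate its parameters directly in fp16, which matches the quantized checkpoint loaded further down and avoids a transient fp32 copy of the weights.)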
43 | torch.set_default_dtype(torch.half) 44 | old_init_weights = transformers.modeling_utils._init_weights 45 | transformers.modeling_utils._init_weights = False 46 | torch.set_default_dtype(torch.half) 47 | model = LlamaForCausalLM(config) 48 | torch.set_default_dtype(torch.float) 49 | 50 | # Restore the original init functions 51 | (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) = old_init 52 | transformers.modeling_utils._init_weights = old_init_weights 53 | 54 | # Swap out linear layers for quantized ones 55 | make_quant(model, wbits, groupsize) 56 | 57 | # Load the quantized checkpoint 58 | print('Loading model ...') 59 | if (Path(checkpoint) / 'model.safetensors').exists(): 60 | from safetensors.torch import load_file as safe_load 61 | model.load_state_dict(safe_load(Path(checkpoint) / 'model.safetensors')) 62 | elif (Path(checkpoint) / 'model.pt').exists(): 63 | model.load_state_dict(torch.load(Path(checkpoint) / 'model.pt'), strict=False) 64 | else: 65 | raise FileNotFoundError(f"Could not find model checkpoint at {checkpoint}; please ensure that the path is correct and contains a `model.pt` or `model.safetensors` file.") 66 | 67 | # Go through all the QuantLinear layers and if their bias is all zeros, set it to None 68 | for name, m in model.named_modules(): 69 | if isinstance(m, QuantLinear): 70 | if m.bias is not None and (m.bias == 0).all(): 71 | m.bias = None 72 | #print(f"Removed bias from {name}") 73 | 74 | make_quant_attn(model) 75 | 76 | if fuse_mlp == True or (fuse_mlp is None and groupsize == -1): 77 | make_fused_mlp(model) 78 | 79 | # Move the model to the correct device 80 | if device is not None: 81 | model = model.to(device) 82 | 83 | # Warm up the autotune cache 84 | if warmup_autotune: 85 | if device is None: 86 | raise ValueError("You must specify a device when warmup_autotune is True.") 87 | 88 | autotune_warmup(model) 89 | 90 | model.seqlen = 2048 91 | print('Done.') 92 | 93 | return model 94 | 95 | 96 | def autotune_warmup(model): 97 | """ 98 | The Triton kernels autotune themselves for specific input sizes. But this takes time. 99 | This function collects information on all possible input sizes for the different kernels 100 | and then runs them through the autotuner. 101 | The intended use is to run this on startup so the autotuner doesn't have to run during 102 | actual inference. 
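The loop below exercises M values at powers of two from 1 to 2048, which lines up with the power-of-two rounding that the autotuner applies to its cache keys.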
103 | """ 104 | from tqdm import tqdm 105 | 106 | warmups = itertools.chain(quant_linear.autotune_warmup(model), fused_mlp.autotune_warmup(model)) 107 | warmups = list(warmups) 108 | 109 | print('Warming up autotune cache ...') 110 | with torch.no_grad(): 111 | for m in tqdm(range(0, 12)): 112 | m = 2 ** m 113 | for func in warmups: 114 | func(m) -------------------------------------------------------------------------------- /ppl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | from datasets import load_dataset 7 | from gptq_triton import load_quant 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, LlamaForCausalLM 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--model', type=str, help='Path to model, either a HuggingFace model or a quantized model') 15 | parser.add_argument('--quant', action='store_true', help='Whether the model is quantized') 16 | parser.add_argument('--stride', type=int, default=512, help='Stride for calculating perplexity') 17 | parser.add_argument('--context-length', type=int, default=2048, help='Length of context to use') 18 | 19 | 20 | def main(): 21 | args = parser.parse_args() 22 | 23 | if not args.quant: 24 | model = get_llama(args.model) 25 | model.eval() 26 | model.to('cuda') 27 | else: 28 | model = load_quant(args.model) 29 | model.eval() 30 | model.to('cuda') 31 | 32 | # NOTE: Setting use_fast=False for now, as the alternative was an order of magnitude slower on a recent `transformers` commit 33 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 34 | context_length = model.seqlen if args.context_length is None else args.context_length 35 | 36 | for dataset in ['wikitext-2', 'ptb', 'c4']: 37 | ppl = calculate_perplexity(model, tokenizer, dataset, max_length=context_length, stride=args.stride) 38 | print(f"{dataset} perplexity: {ppl}") 39 | 40 | 41 | def get_llama(model: str): 42 | """ 43 | Load a pretrained Llama model 44 | """ 45 | def skip(*args, **kwargs): 46 | pass 47 | # NOTE: This is a nasty hack, but it speeds up model building by a huge amount 48 | old_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) 49 | torch.nn.init.kaiming_uniform_ = skip 50 | torch.nn.init.uniform_ = skip 51 | torch.nn.init.normal_ = skip 52 | 53 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 54 | model.seqlen = 2048 55 | 56 | # Restore the old initializers 57 | torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = old_inits 58 | 59 | return model 60 | 61 | 62 | def get_dataset(dataset_name: str, tokenizer) -> torch.Tensor: 63 | if dataset_name == "wikitext-2": 64 | test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 65 | encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids 66 | elif dataset_name == 'ptb': 67 | test = load_dataset("ptb_text_only", 'penn_treebank', split="validation") 68 | encodings = tokenizer("\n\n".join(test["sentence"]), return_tensors="pt").input_ids 69 | elif dataset_name == 'c4': 70 | # WARNING: Many of the files in the allenai/c4 repo are marked as "Unsafe" by HuggingFace, possibly containing a virus. This particular file is not, and I doubt it's an issue, but worth noting. 
71 | test = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 72 | encodings = [tokenizer(x, return_tensors="pt").input_ids for x in test['text'][:1000]] 73 | encodings = torch.cat(encodings, dim=1) 74 | else: 75 | raise ValueError(f"Unknown dataset {dataset_name}") 76 | 77 | return encodings 78 | 79 | 80 | def calculate_perplexity(model, tokenizer, dataset: str, max_length: int, stride: int = 512) -> float: 81 | print("Loading dataset...") 82 | encodings = get_dataset(dataset, tokenizer) 83 | seq_len = encodings.size(1) 84 | 85 | print("Calculating perplexity...") 86 | print(f"Sequence length: {seq_len}") 87 | print(f"Max length: {max_length}") 88 | print(f"Stride: {stride}") 89 | 90 | nlls = [] 91 | prev_end_loc = 0 92 | 93 | for begin_loc in (pbar := tqdm(range(0, seq_len - 1, stride))): 94 | end_loc = min(seq_len - 1, begin_loc + max_length) 95 | trg_len = end_loc - prev_end_loc # How many tokens we want to predict 96 | input_ids = encodings[:, begin_loc:end_loc+1].to('cuda') # +1 for the labels 97 | 98 | with torch.no_grad(): 99 | # Ask the model for logits 100 | # NOTE: Instead of calling HF's model wrapper, we call the model directly to hopefully cut down on some memory overhead 101 | outputs = model.model(input_ids[:, :-1], use_cache=False) 102 | logits = model.lm_head(outputs[0][..., -trg_len:, :]) 103 | 104 | # The last trg_len tokens are the labels 105 | labels = input_ids[:, -trg_len:].contiguous() 106 | 107 | # Compute the NLL for this batch using flattened logits and labels 108 | loss_fct = nn.CrossEntropyLoss() 109 | loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 110 | 111 | nlls.append(loss.to('cpu').to(torch.float32)) 112 | ppl = torch.exp(torch.stack(nlls).mean()) 113 | pbar.set_description(f"Perplexity: {ppl:.2f}") 114 | 115 | prev_end_loc = end_loc 116 | if end_loc == (seq_len - 1): 117 | break 118 | 119 | ppl = torch.exp(torch.stack(nlls).mean()) 120 | 121 | return ppl 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /src/gptq_triton/fused_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .quant_linear import QuantLinear 7 | from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaConfig 8 | 9 | 10 | def make_quant_attn(model): 11 | """ 12 | Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. 13 | """ 14 | for name, m in model.named_modules(): 15 | if not isinstance(m, LlamaAttention): 16 | continue 17 | 18 | q_proj = m.q_proj 19 | k_proj = m.k_proj 20 | v_proj = m.v_proj 21 | 22 | qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) 23 | qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) 24 | scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) 25 | 26 | qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, bias=False) 27 | qkv_layer.qweight = qweights 28 | qkv_layer.qzeros = qzeros 29 | qkv_layer.scales = scales 30 | qkv_layer.bias = None 31 | 32 | attn = QuantLlamaAttention(m.config, qkv_layer, m.o_proj, m.rotary_emb) 33 | 34 | if '.' 
in name: 35 | parent_name = name.rsplit('.', 1)[0] 36 | child_name = name[len(parent_name) + 1:] 37 | parent = model.get_submodule(parent_name) 38 | else: 39 | parent_name = '' 40 | parent = model 41 | child_name = name 42 | 43 | #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") 44 | 45 | setattr(parent, child_name, attn) 46 | 47 | 48 | class QuantLlamaAttention(nn.Module): 49 | """ 50 | Modified version of LlamaAttention that fuses the q, k, v projections. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | config: LlamaConfig, 56 | qkv_proj, 57 | o_proj, 58 | rotary_emb, 59 | ): 60 | super().__init__() 61 | self.config = config 62 | self.hidden_size = config.hidden_size 63 | self.num_heads = config.num_attention_heads 64 | self.head_dim = self.hidden_size // self.num_heads 65 | self.max_position_embeddings = config.max_position_embeddings 66 | 67 | if (self.head_dim * self.num_heads) != self.hidden_size: 68 | raise ValueError( 69 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 70 | f" and `num_heads`: {self.num_heads})." 71 | ) 72 | self.qkv_proj = qkv_proj 73 | self.o_proj = o_proj 74 | self.rotary_emb = rotary_emb 75 | 76 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 77 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 78 | 79 | def forward( 80 | self, 81 | hidden_states: torch.Tensor, 82 | attention_mask: Optional[torch.Tensor] = None, 83 | position_ids: Optional[torch.LongTensor] = None, 84 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 85 | output_attentions: bool = False, 86 | use_cache: bool = False, 87 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 88 | """Input shape: Batch x Time x Channel""" 89 | bsz, q_len, _ = hidden_states.size() 90 | 91 | qkv_states = self.qkv_proj(hidden_states) 92 | query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2) 93 | 94 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 95 | key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 96 | value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 97 | 98 | kv_seq_len = key_states.shape[-2] 99 | if past_key_value is not None: 100 | kv_seq_len += past_key_value[0].shape[-2] 101 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 102 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 103 | # [bsz, nh, t, hd] 104 | 105 | if past_key_value is not None: 106 | # reuse k, v, self_attention 107 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 108 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 109 | 110 | if use_cache: 111 | # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor 112 | # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. 
113 | query_states = query_states.contiguous() 114 | key_states = key_states.contiguous() 115 | value_states = value_states.contiguous() 116 | 117 | past_key_value = (key_states, value_states) if use_cache else None 118 | 119 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 120 | 121 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 122 | raise ValueError( 123 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 124 | f" {attn_weights.size()}" 125 | ) 126 | 127 | if attention_mask is not None: 128 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 129 | raise ValueError( 130 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 131 | ) 132 | attn_weights = attn_weights + attention_mask 133 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) 134 | 135 | # upcast attention to fp32 136 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 137 | attn_output = torch.matmul(attn_weights, value_states) 138 | 139 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 140 | raise ValueError( 141 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 142 | f" {attn_output.size()}" 143 | ) 144 | 145 | attn_output = attn_output.transpose(1, 2) 146 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 147 | 148 | attn_output = self.o_proj(attn_output) 149 | 150 | if not output_attentions: 151 | attn_weights = None 152 | 153 | return attn_output, attn_weights, past_key_value -------------------------------------------------------------------------------- /benchmark_generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Benchmarks the generation speed of a model. While Benchmark.ipynb provides nice detailed performance data, it measures the kernels in isolation. 4 | This script measures "real world" performance by running the whole model in generation mode. 5 | It tests a grid of prompt lengths and generation lengths, and saves the timing results to `results.jsonl`. 
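Results are appended to results.jsonl as they complete, so an interrupted run can be resumed; (prompt_length, max_length) pairs that already appear in the file are skipped.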
6 | """ 7 | import argparse 8 | import itertools 9 | import json 10 | import os 11 | import random 12 | import time 13 | 14 | import original_quant 15 | import torch 16 | import transformers 17 | from gptq_triton import load_quant 18 | from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--model', type=str, help='Path to model, either a HuggingFace model or a quantized model') 23 | parser.add_argument('--quant', action='store_true', help='Whether the model is quantized') 24 | parser.add_argument('--cuda', type=str, help='Whether to use the old CUDA kernel and format; this must be set to the path to the CUDA quantized model, and --model must be set to a HF model') 25 | parser.add_argument('--average', type=int, default=10, help='Number of times to run each test to get an average') 26 | 27 | 28 | def main(): 29 | args = parser.parse_args() 30 | 31 | if args.cuda: 32 | model = load_cuda_quant(args.model, args.cuda, 4, -1) 33 | model.eval() 34 | model.to('cuda') 35 | elif not args.quant: 36 | model = get_llama(args.model) 37 | model.eval() 38 | model.to('cuda') 39 | else: 40 | model = load_quant(args.model) 41 | model.eval() 42 | model.to('cuda') 43 | 44 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 45 | 46 | prompt_lengths = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] 47 | max_lengths = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] 48 | 49 | lengths = set(itertools.product(prompt_lengths, max_lengths)) 50 | 51 | # Remove lengths that we've already tested 52 | if os.path.exists('results.jsonl'): 53 | with open('results.jsonl', 'r') as f: 54 | for line in f: 55 | line = json.loads(line) 56 | key = (line['prompt_length'], line['max_length']) 57 | if key in lengths: 58 | lengths.remove(key) 59 | 60 | # Shuffle the lengths so that we don't always test in the same order and get caching effects 61 | lengths = list(lengths) 62 | random.shuffle(lengths) 63 | 64 | # TODO: For some reason the first run is always slow, so we run it once before the benchmark to warm things up 65 | encoded_prompt = tokenizer.encode("TODO", add_special_tokens=False, return_tensors='pt').to('cuda') 66 | _ = model.generate( 67 | input_ids=encoded_prompt, 68 | max_length=8, 69 | do_sample=True, 70 | num_return_sequences=1, 71 | suppress_tokens=[model.generation_config.eos_token_id], 72 | ) 73 | 74 | # Run the remaining benchmarks 75 | with open('results.jsonl', 'a') as f: 76 | for prompt_length, max_length in lengths: 77 | print(f'Prompt length: {prompt_length}, max length: {max_length}') 78 | 79 | results = [] 80 | 81 | for _ in range(args.average): 82 | # Generate a long random string 83 | # We do this every time to avoid caching effects 84 | prompt = ''.join(random.choice('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,;:!?') for _ in range(2048 * 10)) 85 | 86 | # Encode and crop down 87 | encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt') 88 | encoded_prompt = encoded_prompt[:, :prompt_length] 89 | encoded_prompt = encoded_prompt.to('cuda') 90 | 91 | start_time = time.time() 92 | _ = model.generate( 93 | input_ids=encoded_prompt, 94 | max_length=max_length + prompt_length, 95 | do_sample=True, 96 | num_return_sequences=1, 97 | suppress_tokens=[model.generation_config.eos_token_id], # This prevents the sampler from ending early; it must generate max_length tokens 98 | ) 99 | end_time = time.time() 100 | 101 | gen_time = end_time - 
start_time 102 | speed = max_length / gen_time 103 | 104 | results.append((gen_time, speed)) 105 | 106 | # Compute the average 107 | avg_time = sum(t for t, _ in results) / len(results) 108 | avg_speed = (max_length * len(results)) / sum(t for t, _ in results) 109 | 110 | print(f'Average generation time: {avg_time:.2f} seconds') 111 | print(f'Average generation speed: {avg_speed:.2f} tokens per second') 112 | print() 113 | 114 | f.write(json.dumps({ 115 | 'prompt_length': prompt_length, 116 | 'max_length': max_length, 117 | 'average_time': avg_time, 118 | 'average_speed': avg_speed, 119 | 'runs': results, 120 | })) 121 | f.write("\n") 122 | f.flush() 123 | 124 | 125 | def get_llama(model: str): 126 | """ 127 | Load a pretrained Llama model 128 | """ 129 | def skip(*args, **kwargs): 130 | pass 131 | # NOTE: This is a nasty hack, but it speeds up model building by a huge amount 132 | old_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) 133 | torch.nn.init.kaiming_uniform_ = skip 134 | torch.nn.init.uniform_ = skip 135 | torch.nn.init.normal_ = skip 136 | 137 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 138 | model.seqlen = 2048 139 | 140 | # Restore the old initializers 141 | torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = old_inits 142 | 143 | return model 144 | 145 | 146 | def load_cuda_quant(model, checkpoint, wbits, groupsize): 147 | """ 148 | Load a quantized model using the old CUDA kernel 149 | """ 150 | config = LlamaConfig.from_pretrained(model) 151 | def noop(*args, **kwargs): 152 | pass 153 | original_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) 154 | torch.nn.init.kaiming_uniform_ = noop 155 | torch.nn.init.uniform_ = noop 156 | torch.nn.init.normal_ = noop 157 | 158 | torch.set_default_dtype(torch.half) 159 | original_init_weights = transformers.modeling_utils._init_weights 160 | transformers.modeling_utils._init_weights = False 161 | torch.set_default_dtype(torch.half) 162 | model = LlamaForCausalLM(config) 163 | torch.set_default_dtype(torch.float) 164 | 165 | transformers.modeling_utils._init_weights = original_init_weights 166 | torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = original_inits 167 | 168 | model = model.eval() 169 | layers = original_quant.find_layers(model) 170 | for name in ['lm_head']: 171 | if name in layers: 172 | del layers[name] 173 | original_quant.make_quant(model, layers, wbits, groupsize, faster=False) 174 | 175 | del layers 176 | 177 | print('Loading model ...') 178 | if checkpoint.endswith('.safetensors'): 179 | from safetensors.torch import load_file as safe_load 180 | model.load_state_dict(safe_load(checkpoint)) 181 | else: 182 | model.load_state_dict(torch.load(checkpoint)) 183 | model.seqlen = 2048 184 | print('Done.') 185 | 186 | return model 187 | 188 | 189 | if __name__ == '__main__': 190 | main() -------------------------------------------------------------------------------- /src/gptq_triton/custom_autotune.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 
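Other differences: tuning keys can optionally be rounded to the nearest power of two before caching (the `nearest_power_of_two` argument), configs can be pruned before benchmarking via `prune_configs_by`, and per-config timings can be recorded by setting `record_detailed_timings`.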
3 | """ 4 | import builtins 5 | import math 6 | import time 7 | from typing import Dict, List, Optional 8 | 9 | import triton 10 | 11 | 12 | class Autotuner(triton.KernelInterface): 13 | def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: Optional[List[str]] = None): 14 | ''' 15 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 16 | 'perf_model': performance model used to predicate running time with different configs, returns running time 17 | 'top_k': number of configs to bench 18 | 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 19 | 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results, and which ones 20 | ''' 21 | if not configs: 22 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)] 23 | else: 24 | self.configs = configs 25 | self.key_idx = [arg_names.index(k) for k in key] 26 | self.nearest_power_of_two = set(nearest_power_of_two) if nearest_power_of_two is not None else set() 27 | self.nearest_power_of_two = [i for i in range(len(self.key_idx)) if key[i] in self.nearest_power_of_two] 28 | self.cache = {} 29 | # hook to reset all required tensor to zeros before relaunching a kernel 30 | self.hook = lambda args: 0 31 | if reset_to_zero is not None: 32 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero] 33 | 34 | def _hook(args): 35 | for i in self.reset_idx: 36 | args[i].zero_() 37 | self.hook = _hook 38 | self.arg_names = arg_names 39 | # prune configs 40 | if prune_configs_by: 41 | perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] 42 | if 'early_config_prune' in prune_configs_by: 43 | early_config_prune = prune_configs_by['early_config_prune'] 44 | else: 45 | perf_model, top_k, early_config_prune = None, None, None 46 | self.perf_model, self.configs_top_k = perf_model, top_k 47 | self.early_config_prune = early_config_prune 48 | self.fn = fn 49 | self.record_detailed_timings = False 50 | self.detailed_timings = {} 51 | 52 | def _bench(self, *args, config, **meta): 53 | # check for conflicts, i.e. meta-parameters both provided 54 | # as kwargs and by the autotuner 55 | conflicts = meta.keys() & config.kwargs.keys() 56 | if conflicts: 57 | raise ValueError( 58 | f"Conflicting meta-parameters: {', '.join(conflicts)}." 59 | " Make sure that you don't re-define auto-tuned symbols." 
60 | ) 61 | # augment meta-parameters with tunable ones 62 | current = dict(meta, **config.kwargs) 63 | 64 | def kernel_call(): 65 | if config.pre_hook: 66 | config.pre_hook(self.nargs) 67 | self.hook(args) 68 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) 69 | try: 70 | # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses 71 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default 72 | return triton.testing.do_bench(kernel_call, rep=40) 73 | except triton.OutOfResources: 74 | return float('inf') 75 | 76 | def run(self, *args, **kwargs): 77 | self.nargs = dict(zip(self.arg_names, args)) 78 | if len(self.configs) > 1: 79 | key = list(args[i] for i in self.key_idx) 80 | 81 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two 82 | # In my testing this gives decent results, and greatly reduces the amount of tuning required 83 | for i in self.nearest_power_of_two: 84 | key[i] = 2 ** int(math.log2(key[i]) + 0.5) 85 | key = tuple(key) 86 | 87 | if key not in self.cache: 88 | # prune configs 89 | pruned_configs = self.prune_configs(kwargs) 90 | bench_start = time.time() 91 | timings = {config: self._bench(*args, config=config, **kwargs) 92 | for config in pruned_configs} 93 | bench_end = time.time() 94 | self.bench_time = bench_end - bench_start 95 | self.cache[key] = builtins.min(timings, key=timings.get) 96 | self.hook(args) 97 | self.configs_timings = timings 98 | 99 | if self.record_detailed_timings: 100 | self.detailed_timings[key] = timings 101 | config = self.cache[key] 102 | else: 103 | config = self.configs[0] 104 | self.best_config = config 105 | if config.pre_hook is not None: 106 | config.pre_hook(self.nargs) 107 | return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) 108 | 109 | def prune_configs(self, kwargs): 110 | pruned_configs = self.configs 111 | if self.early_config_prune: 112 | pruned_configs = self.early_config_prune(self.configs, self.nargs) 113 | if self.perf_model: 114 | top_k = self.configs_top_k 115 | if isinstance(top_k, float) and top_k <= 1.0: 116 | top_k = int(len(self.configs) * top_k) 117 | if len(pruned_configs) > top_k: 118 | est_timing = { 119 | config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, 120 | num_warps=config.num_warps) 121 | for config in pruned_configs 122 | } 123 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] 124 | return pruned_configs 125 | 126 | def warmup(self, *args, **kwargs): 127 | self.nargs = dict(zip(self.arg_names, args)) 128 | for config in self.prune_configs(kwargs): 129 | self.fn.warmup( 130 | *args, 131 | num_warps=config.num_warps, 132 | num_stages=config.num_stages, 133 | **kwargs, 134 | **config.kwargs, 135 | ) 136 | self.nargs = None 137 | 138 | 139 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): 140 | """ 141 | Decorator for auto-tuning a :code:`triton.jit`'d function. 142 | .. highlight:: python 143 | .. 
code-block:: python 144 | @triton.autotune(configs=[ 145 | triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), 146 | triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), 147 | ], 148 | key=['x_size'] # the two above configs will be evaluated anytime 149 | # the value of x_size changes 150 | ) 151 | @triton.jit 152 | def kernel(x_ptr, x_size, **META): 153 | BLOCK_SIZE = META['BLOCK_SIZE'] 154 | :note: When all the configurations are evaluated, the kernel will run multiple times. 155 | This means that whatever value the kernel updates will be updated multiple times. 156 | To avoid this undesired behavior, you can use the `reset_to_zero` argument, which 157 | resets the value of the provided tensor to `zero` before running any configuration. 158 | :param configs: a list of :code:`triton.Config` objects 159 | :type configs: list[triton.Config] 160 | :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 161 | :type key: list[str] 162 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 163 | 'perf_model': performance model used to predict running time with different configs, returns running time 164 | 'top_k': number of configs to bench 165 | 'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs: List[Config] as its input, and returns pruned configs. 166 | :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 167 | :type reset_to_zero: list[str] 168 | """ 169 | def decorator(fn): 170 | return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) 171 | 172 | return decorator -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPTQ-triton 2 | 3 | This is my attempt at implementing a Triton kernel for GPTQ inference. This code is based on the [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa) codebase, which is itself based on the [GPTQ](https://github.com/IST-DASLab/gptq) codebase. 4 | 5 | ``` 6 | @article{frantar-gptq, 7 | title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, 8 | author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh}, 9 | year={2022}, 10 | journal={arXiv preprint arXiv:2210.17323} 11 | } 12 | ``` 13 | 14 | ## Installation 15 | 16 | `pip install .` 17 | 18 | 19 | ## Motivation 20 | 21 | As of today (2023-03-27) the CUDA kernels in the aforementioned codebases do not scale well with context length, running up to 10x slower when the context is large versus the equivalent FP16 model. To solve this I'm implementing the inference kernel in Triton, which should allow for much better scaling. 22 | 23 | The implementation is based around the matmul tutorial from the Triton documentation. The main difference is decoding the quantized weights before performing each sub-block of the matrix multiplication. 24 | 25 | Fusing of the FF layers and the QKV matrix is also applied. 26 | 27 | 28 | ## Performance 29 | 30 | This benchmark was run on a 3090 using the `benchmark_generate.py` script. 31 | 32 | ![Triton benchmark graph](TritonBench.png) 33 | 34 | 35 | ## Accuracy (PPL) 36 | 37 | The following results were obtained using the `ppl.py` script with a stride of 512 and a context length of 2048. 
38 | For the 4bit CUDA results, a custom version of `ppl.py` was used, as the current script is dedicated to the Triton kernel convensions. 39 | it/s numbers are from a 3090. 40 | 41 | 42 | | [LLaMA-7B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | it/s | Wikitext2 | PTB | C4 | 43 | | -------------------------------------------------- | ---- | ---------- | ----------- | ---- | --------- | ----- | ---- | 44 | | FP16 | 16 | - | 17373 | 1.64 | 5.04 | 7.85 | 6.99 | 45 | | GPTQ CUDA | 4 | -1 | 8805 | 0.11 | 5.44 | 8.24 | - | 46 | | GPTQ Triton | 4 | -1 | 6323 | 1.70 | 5.44 | 8.24 | 7.48 | 47 | 48 | 49 | | [LLaMA-13B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | it/s | Wikitext2 | PTB | C4 | 50 | | -------------------------------------------------- | ---- | ---------- | ----------- | ---- | --------- | ----- | ---- | 51 | | FP16 | 16 | - | 31633 | - | 4.52 | 7.19 | 6.66 | 52 | | GPTQ Triton | 4 | -1 | 10325 | 0.95 | 4.74 | 7.49 | 7.00 | 53 | | GPTQ Triton | 4 | 128 | 9547 | 0.92 | 4.67 | 7.38 | 6.99 | 54 | 55 | 56 | | [LLaMA-30B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | it/s | Wikitext2 | PTB | C4 | 57 | | -------------------------------------------------- | ---- | ---------- | ----------- | ---- | --------- | ----- | ---- | 58 | | FP16 | 16 | - | 72491 | - | 3.61 | 6.50 | 6.07 | 59 | | GPTQ Triton | 4 | -1 | 19989 | 0.40 | 3.89 | 6.71 | 6.44 | 60 | | GPTQ Triton | 4 | 128 | 20055 | 0.39 | 3.79 | 6.70 | 6.34 | 61 | | GPTQ Triton | 4 | 512 | 19547 | 0.39 | 3.78 | 6.62 | 6.30 | 62 | 63 | 64 | ## Requirements 65 | 66 | See `setup.cfg`, but note that a nightly `transformers` is preferred right now. v4.28.1 might work. Known working `transformers` commit is `28f26c107b4a1c5c7e32ed4d9575622da0627a40`. 67 | 68 | 69 | ## Quantizing a model 70 | 71 | The `quantize.py` script is used to quantize a HuggingFace model. Example usage: 72 | 73 | `./quantize.py --model --dataset c4 --wbits 4 --groupsize -1 --act-order --true-sequential --save ` 74 | 75 | Arguments: 76 | 77 | * `--model`: Path to a HF FP16 model 78 | * `--dataset`: Dataset to use for calibration. Can be `wikitext-2`, `ptb`, `ptb-new` or `c4`. 79 | * `--seed`: Seed for sampling the calibration data. 80 | * `--nsamples`: Number of calibration data samples. 81 | * `--percdamp`: Percent of the average Hessian diagonal to use for dampening (default 0.01). 82 | * `--wbits`: Number of bits to use for quantization. 83 | * `--groupsize`: Groupsize to use for quantization; default (-1) uses full row. 84 | * `--save`: Save quantized result to this folder. 85 | * `--safetensors`: Save using the safetensors format. 86 | * `--act-order`: Use activation order quantization. 87 | * `--true-sequential`: Use true sequential quantization. 88 | 89 | **NOTE:** The Triton kernel is currently only implemented for 4-bits. 90 | 91 | ### Explanation of `groupsize` 92 | 93 | The GPTQ quantization algorithm gets applied to `nn.Linear`, `nn.Conv2d`, and `transformers.Conv1d` layers. (NOTE: `quantize.py` currently only supports LLaMA like models, and thus only `nn.Linear` layers are quantized, and `lm_head` is skipped.) Each matrix is quantized into a quantized weight matrix, quantized zeros, and float16 scale (bias is not quantized). During matmul, the weights are decoded using the formula `w = (w - z - 1) * s`. 94 | 95 | Scales and zeros are per-outfeature, so when there is no grouping, scales and zeros would be `1xOutfeatures`. That means that each row of the matrix (i.e. 
along the infeatures dimension) is quantized using the same scalar scale and zero. When grouping is used, each row is split into `groupsize` values, and each group is quantized using its own scalar scale and zero. This means that the scales and zeros are `(Infeatures//groupsize)xOutfeatures`. A short illustrative sketch of this grouped dequantization appears at the end of this README. 96 | 97 | Groupsize provides a tradeoff. Lower groupsizes offer more granularity to the quantization and thus less loss of accuracy, but decrease the memory savings offered by quantization. 98 | 99 | ### Explanation of `nsamples` and `dataset` 100 | nsamples and dataset affect the calibration data. This input data is fed through the network during quantization to calibrate the algorithm. See the GPTQ paper for more detail. 101 | 102 | ### Explanation of `true-sequential` 103 | Models are quantized sequentially, one "layer" at a time. For example, LLaMA 7B has 32 layers, starting after the input embedding, followed by the head. Each of these layers has many different Linear modules that can be quantized. Without the `true-sequential` flag, these Linear modules will be quantized in an arbitrary order. With the `true-sequential` flag, the Linear modules will be quantized in the order they would be encountered during a forward pass. This can provide an accuracy boost. 104 | 105 | ### Explanation of `act-order` 106 | 107 | I don't know. Looking at the code (in `gptq.py`), it seems to re-order the matrix before quantization based on the `argsort` of the estimated `H`. The order in which the columns of the matrix are quantized might have an impact on final accuracy. `act-order` was introduced by the GPTQ authors to improve accuracy when quantizing "small" models like LLaMA 7B. 108 | 109 | 110 | ## Files 111 | 112 | * `benchmark_generate.py` - A script for benchmarking generation speed at different prompt lengths and generation lengths. 113 | 114 | * `Benchmark.ipynb` - A notebook for benchmarking the Triton kernel against the CUDA kernel and FP16. 115 | 116 | * `quantize.py` - A script for quantizing a model. 117 | 118 | * `generate.py` - An example script for generating text from a model. Example usage: `./generate.py --model --quant --prompt "Write a story about a duck: Once upon a time there was a duck" --temperature 0.6 --top-p 0.6 --repetition-penalty 1.1` 119 | 120 | * `ppl.py` - A script for calculating the perplexity of a model against wikitext2, PTB, and C4. This is useful for verifying correctness of the Triton kernel, comparing it to the CUDA kernel and the original FP16 model. 121 | 122 | * `Verify.ipynb` - A notebook for verifying the correctness of the Triton kernel. 
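## Illustrative sketch of grouped dequantization

Below is a minimal sketch of the grouped dequantization described in "Explanation of `groupsize`" above. It assumes the 4-bit weights have already been unpacked into an ordinary integer tensor; the tensor names and shapes here are purely illustrative and do not reproduce the packed storage layout (`qweight`, `qzeros`, `scales`) used by the actual `QuantLinear` kernel.

```python
import torch

# Hypothetical, already-unpacked 4-bit weights for a layer with
# infeatures=8, outfeatures=4 and groupsize=4 (so 8 // 4 = 2 groups).
groupsize = 4
w_int = torch.randint(0, 16, (8, 4))            # (infeatures, outfeatures), values 0..15
scales = torch.rand(2, 4, dtype=torch.float16)  # (infeatures // groupsize, outfeatures)
zeros = torch.randint(0, 16, (2, 4))            # same shape as scales

# Row i of the weight matrix uses the scale/zero of group i // groupsize.
group_idx = torch.arange(8) // groupsize        # (infeatures,)
s = scales[group_idx]                           # gathered up to (infeatures, outfeatures)
z = zeros[group_idx]

# Decode with the formula from above: w = (w - z - 1) * s
w_fp16 = (w_int - z - 1).half() * s
print(w_fp16.shape)  # torch.Size([8, 4])
```

With `groupsize = -1` (no grouping) there is effectively a single group: `scales` and `zeros` collapse to `1 x Outfeatures` and every row of the matrix shares the same scale and zero.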
123 | -------------------------------------------------------------------------------- /gptq.py: -------------------------------------------------------------------------------- 1 | # Copied from: https://github.com/IST-DASLab/gptq 2 | import math 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import transformers 8 | 9 | 10 | DEBUG = False 11 | 12 | torch.backends.cuda.matmul.allow_tf32 = False 13 | torch.backends.cudnn.allow_tf32 = False 14 | 15 | 16 | class GPTQ: 17 | def __init__(self, layer): 18 | self.layer = layer 19 | self.dev = self.layer.weight.device 20 | W = layer.weight.data.clone() 21 | if isinstance(self.layer, nn.Conv2d): 22 | W = W.flatten(1) 23 | if isinstance(self.layer, transformers.Conv1D): 24 | W = W.t() 25 | self.rows = W.shape[0] 26 | self.columns = W.shape[1] 27 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 28 | self.nsamples = 0 29 | 30 | def add_batch(self, inp, out): 31 | if DEBUG: 32 | self.inp1 = inp 33 | self.out1 = out 34 | if len(inp.shape) == 2: 35 | inp = inp.unsqueeze(0) 36 | tmp = inp.shape[0] 37 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 38 | if len(inp.shape) == 3: 39 | inp = inp.reshape((-1, inp.shape[-1])) 40 | inp = inp.t() 41 | if isinstance(self.layer, nn.Conv2d): 42 | unfold = nn.Unfold( 43 | self.layer.kernel_size, 44 | dilation=self.layer.dilation, 45 | padding=self.layer.padding, 46 | stride=self.layer.stride 47 | ) 48 | inp = unfold(inp) 49 | inp = inp.permute([1, 0, 2]) 50 | inp = inp.flatten(1) 51 | self.H *= self.nsamples / (self.nsamples + tmp) 52 | self.nsamples += tmp 53 | # inp = inp.float() 54 | inp = math.sqrt(2 / self.nsamples) * inp.float() 55 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 56 | self.H += inp.matmul(inp.t()) 57 | 58 | def fasterquant( 59 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False 60 | ): 61 | W = self.layer.weight.data.clone() 62 | if isinstance(self.layer, nn.Conv2d): 63 | W = W.flatten(1) 64 | if isinstance(self.layer, transformers.Conv1D): 65 | W = W.t() 66 | W = W.float() 67 | 68 | tick = time.time() 69 | 70 | if not self.quantizer.ready(): 71 | self.quantizer.find_params(W, weight=True) 72 | 73 | H = self.H 74 | del self.H 75 | dead = torch.diag(H) == 0 76 | H[dead, dead] = 1 77 | W[:, dead] = 0 78 | 79 | if actorder: 80 | perm = torch.argsort(torch.diag(H), descending=True) 81 | W = W[:, perm] 82 | H = H[perm][:, perm] 83 | 84 | Losses = torch.zeros_like(W) 85 | Q = torch.zeros_like(W) 86 | 87 | damp = percdamp * torch.mean(torch.diag(H)) 88 | diag = torch.arange(self.columns, device=self.dev) 89 | H[diag, diag] += damp 90 | H = torch.linalg.cholesky(H) 91 | H = torch.cholesky_inverse(H) 92 | H = torch.linalg.cholesky(H, upper=True) 93 | Hinv = H 94 | 95 | scale = [] 96 | zero = [] 97 | now_idx = 1 98 | 99 | for i1 in range(0, self.columns, blocksize): 100 | i2 = min(i1 + blocksize, self.columns) 101 | count = i2 - i1 102 | 103 | W1 = W[:, i1:i2].clone() 104 | Q1 = torch.zeros_like(W1) 105 | Err1 = torch.zeros_like(W1) 106 | Losses1 = torch.zeros_like(W1) 107 | Hinv1 = Hinv[i1:i2, i1:i2] 108 | 109 | for i in range(count): 110 | w = W1[:, i] 111 | d = Hinv1[i, i] 112 | 113 | if groupsize != -1: 114 | if (i1 + i) % groupsize == 0: 115 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 116 | 117 | if ((i1 + i) // groupsize) - now_idx == -1: 118 | scale.append(self.quantizer.scale) 119 | zero.append(self.quantizer.zero) 120 | now_idx += 1 121 | 122 | q = quantize( 123 | 
w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 124 | ).flatten() 125 | Q1[:, i] = q 126 | Losses1[:, i] = (w - q) ** 2 / d ** 2 127 | 128 | err1 = (w - q) / d 129 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 130 | Err1[:, i] = err1 131 | 132 | Q[:, i1:i2] = Q1 133 | Losses[:, i1:i2] = Losses1 / 2 134 | 135 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 136 | 137 | if DEBUG: 138 | self.layer.weight.data[:, :i2] = Q[:, :i2] 139 | self.layer.weight.data[:, i2:] = W[:, i2:] 140 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 141 | print(torch.sum(Losses)) 142 | 143 | torch.cuda.synchronize() 144 | print('time %.2f' % (time.time() - tick)) 145 | print('error', torch.sum(Losses).item()) 146 | 147 | if actorder: 148 | invperm = torch.argsort(perm) 149 | Q = Q[:, invperm] 150 | 151 | if isinstance(self.layer, transformers.Conv1D): 152 | Q = Q.t() 153 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 154 | if DEBUG: 155 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 156 | 157 | if scale == []: 158 | scale.append(self.quantizer.scale) 159 | zero.append(self.quantizer.zero) 160 | scale = torch.cat(scale,dim=1) 161 | zero = torch.cat(zero,dim=1) 162 | return scale,zero 163 | 164 | def free(self): 165 | if DEBUG: 166 | self.inp1 = None 167 | self.out1 = None 168 | self.H = None 169 | self.Losses = None 170 | self.Trace = None 171 | torch.cuda.empty_cache() 172 | 173 | 174 | def quantize(x, scale, zero, maxq): 175 | if maxq < 0: 176 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 177 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 178 | return scale * (q - zero) 179 | 180 | 181 | class Quantizer(nn.Module): 182 | def __init__(self, shape=1): 183 | super(Quantizer, self).__init__() 184 | self.register_buffer('maxq', torch.tensor(0)) 185 | self.register_buffer('scale', torch.zeros(shape)) 186 | self.register_buffer('zero', torch.zeros(shape)) 187 | 188 | def configure( 189 | self, 190 | bits, perchannel=False, sym=True, 191 | mse=False, norm=2.4, grid=100, maxshrink=.8, 192 | trits=False 193 | ): 194 | 195 | self.maxq = torch.tensor(2 ** bits - 1) 196 | self.perchannel = perchannel 197 | self.sym = sym 198 | self.mse = mse 199 | self.norm = norm 200 | self.grid = grid 201 | self.maxshrink = maxshrink 202 | if trits: 203 | self.maxq = torch.tensor(-1) 204 | 205 | def find_params(self, x, weight=False): 206 | dev = x.device 207 | self.maxq = self.maxq.to(dev) 208 | 209 | shape = x.shape 210 | if self.perchannel: 211 | if weight: 212 | x = x.flatten(1) 213 | else: 214 | if len(shape) == 4: 215 | x = x.permute([1, 0, 2, 3]) 216 | x = x.flatten(1) 217 | if len(shape) == 3: 218 | x = x.reshape((-1, shape[-1])).t() 219 | if len(shape) == 2: 220 | x = x.t() 221 | else: 222 | x = x.flatten().unsqueeze(0) 223 | 224 | tmp = torch.zeros(x.shape[0], device=dev) 225 | xmin = torch.minimum(x.min(1)[0], tmp) 226 | xmax = torch.maximum(x.max(1)[0], tmp) 227 | 228 | if self.sym: 229 | xmax = torch.maximum(torch.abs(xmin), xmax) 230 | tmp = xmin < 0 231 | if torch.any(tmp): 232 | xmin[tmp] = -xmax[tmp] 233 | tmp = (xmin == 0) & (xmax == 0) 234 | xmin[tmp] = -1 235 | xmax[tmp] = +1 236 | 237 | if self.maxq < 0: 238 | self.scale = xmax 239 | self.zero = xmin 240 | else: 241 | self.scale = (xmax - xmin) / self.maxq 242 | if self.sym: 243 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 244 | else: 245 | self.zero = torch.round(-xmin / self.scale) 246 | 247 
| if self.mse: 248 | best = torch.full([x.shape[0]], float('inf'), device=dev) 249 | for i in range(int(self.maxshrink * self.grid)): 250 | p = 1 - i / self.grid 251 | xmin1 = p * xmin 252 | xmax1 = p * xmax 253 | scale1 = (xmax1 - xmin1) / self.maxq 254 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 255 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 256 | q -= x 257 | q.abs_() 258 | q.pow_(self.norm) 259 | err = torch.sum(q, 1) 260 | tmp = err < best 261 | if torch.any(tmp): 262 | best[tmp] = err[tmp] 263 | self.scale[tmp] = scale1[tmp] 264 | self.zero[tmp] = zero1[tmp] 265 | if not self.perchannel: 266 | if weight: 267 | tmp = shape[0] 268 | else: 269 | tmp = shape[1] if len(shape) != 3 else shape[2] 270 | self.scale = self.scale.repeat(tmp) 271 | self.zero = self.zero.repeat(tmp) 272 | 273 | if weight: 274 | shape = [-1] + [1] * (len(shape) - 1) 275 | self.scale = self.scale.reshape(shape) 276 | self.zero = self.zero.reshape(shape) 277 | return 278 | if len(shape) == 4: 279 | self.scale = self.scale.reshape((1, -1, 1, 1)) 280 | self.zero = self.zero.reshape((1, -1, 1, 1)) 281 | if len(shape) == 3: 282 | self.scale = self.scale.reshape((1, 1, -1)) 283 | self.zero = self.zero.reshape((1, 1, -1)) 284 | if len(shape) == 2: 285 | self.scale = self.scale.unsqueeze(0) 286 | self.zero = self.zero.unsqueeze(0) 287 | 288 | def quantize(self, x): 289 | if self.ready(): 290 | return quantize(x, self.scale, self.zero, self.maxq) 291 | return x 292 | 293 | def enabled(self): 294 | return self.maxq > 0 295 | 296 | def ready(self): 297 | return torch.all(self.scale != 0) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /quantize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Based on https://github.com/IST-DASLab/gptq 3 | # Quantize a model using the GPTQ algorithm. 
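# Example invocation (a sketch; paths are placeholders, and the flags mirror the argparse
# options defined below in this file):
#   python quantize.py --model <path-to-llama-checkpoint> --dataset c4 --nsamples 128 \
#       --wbits 4 --groupsize 128 --save <output-dir> --safetensors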
4 | import argparse 5 | import json 6 | from pathlib import Path 7 | import shutil 8 | import time 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | from datautils import get_dataset 14 | from gptq import GPTQ, Quantizer 15 | import gptq 16 | from gptq_triton import QuantLinear, quant_linear 17 | from tqdm import tqdm 18 | from transformers import AutoTokenizer, LlamaForCausalLM 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--model', type=str, required=True, help='llama model to load') 24 | parser.add_argument('--dataset', type=str, choices=['wikitext-2', 'ptb', 'ptb-new', 'c4'], required=True, help='Where to extract calibration data from.') 25 | parser.add_argument('--seed',type=int, default=0, help='Seed for sampling the calibration data.') 26 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.') 27 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') 28 | parser.add_argument('--wbits', type=int, required=True, choices=[2, 4, 8], help='#bits to use for quantization.') 29 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') 30 | parser.add_argument('--save', type=str, required=True, help='Save quantized result to this folder.') 31 | parser.add_argument('--safetensors', action='store_true', help='Whether to save tensors in safetensors format.') 32 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') 33 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') 34 | parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.') 35 | 36 | 37 | def main(): 38 | args = parser.parse_args() 39 | args.save = Path(args.save) 40 | 41 | if args.act_order and args.groupsize != -1: 42 | raise ValueError('Cannot use act_order and groupsize together') 43 | 44 | print('Loading model...') 45 | model = get_llama(args.model) 46 | model.eval() 47 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 48 | 49 | print('Loading data...') 50 | dataloader = get_dataset(args.dataset, tokenizer, nsamples=args.nsamples, seed=args.seed, seqlen=model.seqlen) 51 | 52 | print('Quantizing...') 53 | tick = time.time() 54 | quantizers = llama_sequential(model, dataloader, device='cuda', wbits=args.wbits, nsamples=args.nsamples, true_sequential=args.true_sequential, sym=args.sym, percdamp=args.percdamp, groupsize=args.groupsize, act_order=args.act_order) 55 | print(f"Total time: {time.time() - tick:.2f}s") 56 | 57 | print('Packing...') 58 | llama_pack(model, quantizers, args.wbits, args.groupsize) 59 | 60 | print('Saving...') 61 | args.save.mkdir(parents=True, exist_ok=True) 62 | 63 | # Save the model 64 | if args.safetensors: 65 | from safetensors.torch import save_file as safe_save 66 | safe_save(model.state_dict(), args.save / 'model.safetensors') 67 | else: 68 | torch.save(model.state_dict(), args.save / 'model.pt') 69 | 70 | # Write quant_config.json 71 | with open(args.save / 'quant_config.json', 'w') as f: 72 | f.write(json.dumps({ 73 | 'wbits': args.wbits, 74 | 'groupsize': args.groupsize, 75 | })) 76 | 77 | # Copy the config 78 | for file in ['config.json', 'generation_config.json', 'special_tokens_map.json', 'tokenizer_config.json', 'tokenizer.model']: 79 | 
shutil.copy(args.model + '/' + file, args.save / file) 80 | 81 | print('Done.') 82 | 83 | 84 | def get_llama(model): 85 | def skip(*args, **kwargs): 86 | pass 87 | 88 | # NOTE: This is a nasty hack, but it speeds up model building by a huge amount 89 | torch.nn.init.kaiming_uniform_ = skip 90 | torch.nn.init.uniform_ = skip 91 | torch.nn.init.normal_ = skip 92 | 93 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 94 | model.seqlen = 2048 95 | 96 | return model 97 | 98 | 99 | @torch.no_grad() 100 | def llama_sequential(model, dataloader, device, wbits: int, nsamples: int, true_sequential: bool, sym: bool, percdamp: float, groupsize: int, act_order: bool): 101 | # Disable caching 102 | use_cache = model.config.use_cache 103 | model.config.use_cache = False 104 | 105 | # Prepare 106 | layers = model.model.layers 107 | dtype = next(iter(model.parameters())).dtype 108 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device) 109 | outs = torch.zeros_like(inps) 110 | 111 | # Move the first layer to GPU 112 | model.model.embed_tokens = model.model.embed_tokens.to(device) 113 | model.model.norm = model.model.norm.to(device) 114 | layers[0] = layers[0].to(device) 115 | 116 | # Create a dummy layer that catches the input and attention mask, and then bails 117 | # This allows us to capture all the inputs to the first layer for the calibration data 118 | cache = {'i': 0, 'attention_mask': None, 'position_ids': None} 119 | class Catcher(nn.Module): 120 | def __init__(self, module): 121 | super().__init__() 122 | self.module = module 123 | 124 | def forward(self, inp, **kwargs): 125 | inps[cache['i']] = inp 126 | cache['i'] += 1 127 | if cache['attention_mask'] is not None: 128 | assert torch.all(cache['attention_mask'] == kwargs['attention_mask']) 129 | cache['attention_mask'] = kwargs['attention_mask'] 130 | if cache['position_ids'] is not None: 131 | assert torch.all(cache['position_ids'] == kwargs['position_ids']) 132 | cache['position_ids'] = kwargs['position_ids'] 133 | raise ValueError 134 | 135 | layers[0] = Catcher(layers[0]) 136 | for batch in dataloader: 137 | try: 138 | model(batch.to(device)) 139 | except ValueError: 140 | pass 141 | layers[0] = layers[0].module 142 | 143 | # Move things back to the CPU (but not the first layer, since we'll just move it back to GPU immediately below) 144 | model.model.embed_tokens = model.model.embed_tokens.cpu() 145 | model.model.norm = model.model.norm.cpu() 146 | torch.cuda.empty_cache() 147 | 148 | attention_mask = cache['attention_mask'] 149 | position_ids = cache['position_ids'] 150 | quantizers = {} 151 | 152 | # Layers are quantized in order, and only one layer lives on the GPU at a time to save memory 153 | # Otherwise quantizing large models would be impossible (NOTE for future readers: are you enjoying your 1TB VRAM?) 
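# For each layer: move it to the GPU, hook its Linear modules so add_batch collects calibration
# statistics, run GPTQ on each module, then re-run the layer so its quantized outputs become the
# inputs for the next layer.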
154 | for i, layer in tqdm(enumerate(layers), total=len(layers)): 155 | layer = layer.to(device) 156 | full = {name: m for name, m in layer.named_modules() if isinstance(m, nn.Linear)} 157 | 158 | if true_sequential: 159 | sequential = [ 160 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 161 | ['self_attn.o_proj'], 162 | ['mlp.up_proj', 'mlp.gate_proj'], 163 | ['mlp.down_proj'] 164 | ] 165 | else: 166 | sequential = [list(full.keys())] 167 | 168 | # For each subset of linear layers 169 | for names in sequential: 170 | subset = {n: full[n] for n in names} 171 | gptq = {} 172 | 173 | # Prepare a quantizer for each linear layer 174 | for name in subset: 175 | gptq[name] = GPTQ(subset[name]) 176 | gptq[name].quantizer = Quantizer() 177 | gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False) 178 | 179 | # Feed data to the quantizer, and save outs 180 | def add_batch(name): 181 | def tmp(_, inp, out): 182 | gptq[name].add_batch(inp[0].data, out.data) 183 | return tmp 184 | 185 | handles = [] 186 | for name in subset: 187 | handles.append(subset[name].register_forward_hook(add_batch(name))) 188 | for j in range(nsamples): 189 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] # TODO: Saving outs doesn't seem needed here? 190 | for h in handles: 191 | h.remove() 192 | 193 | # With the data collected, quantize the layers 194 | for name in subset: 195 | print(i, name) 196 | scale, zero = gptq[name].fasterquant(percdamp=percdamp, groupsize=groupsize, actorder=act_order) 197 | quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer, scale, zero) 198 | gptq[name].free() 199 | 200 | # Save outputs of the layer after quantization, so we can feed them into the next layer 201 | for j in range(nsamples): 202 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 203 | 204 | # Move the layer back to the CPU, and free up memory 205 | layers[i] = layer.cpu() 206 | del layer 207 | del gptq 208 | torch.cuda.empty_cache() 209 | 210 | # Swap buffers 211 | inps, outs = outs, inps 212 | 213 | # Restore settings 214 | model.config.use_cache = use_cache 215 | 216 | return quantizers 217 | 218 | 219 | def llama_pack(model, quantizers, wbits: int, groupsize: int): 220 | # Find all the quantized layers 221 | layers = {name: m for name, m in model.named_modules() if isinstance(m, nn.Linear)} 222 | layers = {n: layers[n] for n in quantizers} 223 | 224 | # Replace all applicable instances of Linear with QuantLinear in the model 225 | quant_linear.make_quant(model, wbits, groupsize) 226 | 227 | for name, m in tqdm(model.named_modules(), total=len(list(model.named_modules()))): 228 | if not isinstance(m, QuantLinear): 229 | continue 230 | 231 | quantizer, scale, zero = quantizers[name] 232 | quantizer, scale, zero = quantizer.cpu(), scale.cpu(), zero.cpu() 233 | pack_linear(m, layers[name].weight.data, scale, zero, m.bias) 234 | 235 | 236 | def pack_linear(quant, weights: torch.FloatTensor, scales: torch.FloatTensor, zeros, bias: Optional[torch.FloatTensor]): 237 | """ 238 | Packs the quantized weights, scales, and zero points into a QuantLinear layer 239 | """ 240 | scales = scales.t().contiguous() 241 | zeros = zeros.t().contiguous() 242 | scale_zeros = zeros * scales 243 | 244 | quant.scales = scales.clone().to(torch.float16) 245 | 246 | if quant.bias is not None: 247 | quant.bias = bias.clone().to(torch.float16) 248 | 249 | # Round weights to nearest integer based on scale and 
zero point 250 | # Each weight will be one int, but should not exceed quant.bits 251 | intweight = [] 252 | for idx in range(quant.infeatures): 253 | g_idx = idx // quant.groupsize 254 | # TODO: This is oddly complex. The `gptq.quantize` function does `return scale * (q - zero)`, so shouldn't 255 | # this just be `q = torch.round((weights[:,idx] / scales[g_idx]) + zero[g_idx])`? 256 | q = torch.round((weights[:,idx] + scale_zeros[g_idx]) / scales[g_idx]).to(torch.int32) 257 | intweight.append(q[:,None]) 258 | intweight = torch.cat(intweight,dim=1) 259 | intweight = intweight.t().contiguous() 260 | 261 | # Now pack the weights into uint32's 262 | #qweight = torch.zeros((intweight.shape[0] // 32 * quant.bits, intweight.shape[1]), dtype=torch.int32) 263 | quant.qweight.zero_() 264 | i = 0 265 | row = 0 266 | while row < quant.qweight.shape[0]: 267 | if quant.bits in [2,4,8]: 268 | for j in range(i, i + (32 // quant.bits)): 269 | quant.qweight[row] |= intweight[j] << (quant.bits * (j - i)) 270 | i += 32 // quant.bits 271 | row += 1 272 | else: 273 | raise NotImplementedError("Only 2,4,8 bits are supported.") 274 | 275 | # Subtract 1 from the zero point 276 | zeros = zeros - 1 277 | 278 | # Pack the zero points into uint32's 279 | zeros = zeros.to(torch.int32) 280 | #qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32) 281 | quant.qzeros.zero_() 282 | i = 0 283 | col = 0 284 | while col < quant.qzeros.shape[1]: 285 | if quant.bits in [2,4,8]: 286 | for j in range(i, i + (32 // quant.bits)): 287 | quant.qzeros[:, col] |= zeros[:, j] << (quant.bits * (j - i)) 288 | i += 32 // quant.bits 289 | col += 1 290 | else: 291 | raise NotImplementedError("Only 2,4,8 bits are supported.") 292 | 293 | 294 | @torch.no_grad() 295 | def dumbquant(layer, bits: int, groupsize: int = -1, perchannel: bool = True, sym: bool = False, mse: bool = False): 296 | """ 297 | Used to generate test data by performing a dumb quantization on the weights of a layer. 298 | Layer is modified in place. 299 | """ 300 | assert isinstance(layer, nn.Linear) 301 | quantizer = Quantizer() 302 | quantizer.configure(bits, perchannel=perchannel, sym=sym, mse=mse) 303 | 304 | W = layer.weight.data.clone() 305 | W = W.float() 306 | 307 | quantizer.find_params(W, weight=True) # TODO: Is this needed? 308 | 309 | groupsize = W.shape[1] if groupsize == -1 else groupsize 310 | scale = [] 311 | zero = [] 312 | 313 | for i in range(0, W.shape[1]): 314 | w = W[:, i] 315 | 316 | if i % groupsize == 0: 317 | quantizer.find_params(W[:, (i):(i + groupsize)], weight=True) 318 | scale.append(quantizer.scale) 319 | zero.append(quantizer.zero) 320 | 321 | q = gptq.quantize( 322 | w.unsqueeze(1), quantizer.scale, quantizer.zero, quantizer.maxq 323 | ).flatten() 324 | layer.weight.data[:, i] = q.to(layer.weight.data.dtype) 325 | 326 | scale = torch.cat(scale, dim=1) 327 | zero = torch.cat(zero, dim=1) 328 | 329 | return scale, zero 330 | 331 | 332 | if __name__ == '__main__': 333 | main() -------------------------------------------------------------------------------- /src/gptq_triton/quant_linear.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | from typing import Optional 4 | 5 | from . 
import custom_autotune 6 | import torch 7 | import torch.nn as nn 8 | import triton 9 | import triton.language as tl 10 | from .utils import matmul4_kernel_config_pruner 11 | 12 | 13 | def make_quant(model, bits, groupsize): 14 | """ 15 | Replace all linear layers in a model with quantized ones. 16 | Except for the lm_head, which is not quantized. 17 | """ 18 | for name, m in model.named_modules(): 19 | if not isinstance(m, nn.Linear): 20 | continue 21 | 22 | if name == 'lm_head': 23 | continue 24 | 25 | # Replace the linear layer with a quantized one 26 | qlayer = QuantLinear(bits, groupsize, m.in_features, m.out_features, m.bias is not None) 27 | parent_name = name.rsplit('.', 1)[0] 28 | parent = model.get_submodule(parent_name) 29 | 30 | #print(f"Replacing {name} with quant; parent: {parent_name}, child's name: {name[len(parent_name) + 1:]}") 31 | 32 | setattr(parent, name[len(parent_name) + 1:], qlayer) 33 | 34 | 35 | def autotune_warmup(model): 36 | # Find all the QuantLinear layers 37 | modules = (m for m in model.modules() if isinstance(m, QuantLinear)) 38 | kn_values = {(m.infeatures, m.outfeatures): (m.qweight, m.scales, m.qzeros, m.groupsize) for m in modules} 39 | 40 | print(f'QuantLinear Warmup: Found {len(kn_values)} unique KN values.') 41 | 42 | def func(m, k, qweight, scales, qzeros, groupsize): 43 | a = torch.randn(1, m, k, dtype=torch.float16, device='cuda') 44 | triton_matmul4(groupsize, a, qweight, scales, qzeros) 45 | 46 | return (functools.partial(func, k=k, qweight=qweight, scales=scales, qzeros=qzeros, groupsize=groupsize) for (k, n), (qweight, scales, qzeros, groupsize) in kn_values.items()) 47 | 48 | 49 | class QuantLinear(nn.Module): 50 | def __init__(self, bits: int, groupsize: int, infeatures: int, outfeatures: int, bias: bool): 51 | super().__init__() 52 | 53 | if bits not in [4]: 54 | raise NotImplementedError("Only 4 bits are supported.") 55 | 56 | groupsize = infeatures if groupsize == -1 else groupsize 57 | 58 | self.infeatures = infeatures 59 | self.outfeatures = outfeatures 60 | self.bits = bits 61 | self.groupsize = groupsize 62 | 63 | features_per_int = 32 // bits 64 | 65 | assert outfeatures % features_per_int == 0, "outfeatures must be a multiple of features_per_int" 66 | 67 | self.register_buffer('qweight', torch.empty((infeatures // features_per_int, outfeatures), dtype=torch.int32)) 68 | self.register_buffer('qzeros', torch.empty((math.ceil(infeatures / groupsize), outfeatures // features_per_int), dtype=torch.int32)) 69 | self.register_buffer('scales', torch.empty((math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16)) 70 | if bias: 71 | self.register_buffer('bias', torch.empty(outfeatures, dtype=torch.float16)) 72 | else: 73 | self.register_parameter('bias', None) 74 | 75 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 76 | y = triton_matmul4(self.groupsize, x, self.qweight, self.scales, self.qzeros, self.bias) 77 | return y 78 | 79 | 80 | # This Triton kernel is adapted from the Triton matmul example 81 | # It unpacks the quantized weights and then performs the matmul like usual 82 | # It operates in FP16 mode 83 | @custom_autotune.autotune( 84 | configs=[ 85 | # These weren't useful, at least on a 3090 86 | #triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), 87 | #triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), 88 | #triton.Config({'BLOCK_SIZE_M': 256, 
'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), 89 | #triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), 90 | #triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), 91 | 92 | #triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 93 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 94 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 95 | #triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 96 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 97 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 98 | 99 | # These provided a benefit on a 3090 100 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 101 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), 102 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), 103 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), 104 | 105 | # From PyTorch Inductor 106 | #triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), 107 | #triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), 108 | #triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), 109 | #triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), 110 | 111 | #triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), 112 | #triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=8), 113 | #triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=8), 114 | 115 | #triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2), 116 | ], 117 | key=['M', 'N', 'K', 'NO_GROUPS'], 118 | nearest_power_of_two=['M', 'N', 'K'], 119 | prune_configs_by={ 120 | 'early_config_prune': matmul4_kernel_config_pruner, 121 | 'perf_model': None, 122 | 'top_k': None, 123 | }, 124 | ) 125 | @triton.jit 126 | def matmul4_kernel( 127 | a_ptr, b_ptr, c_ptr, 128 | scales_ptr, zeros_ptr, 129 | M, N, K, 130 | stride_am, stride_ak, 131 | stride_bk, stride_bn, 132 | stride_cm, stride_cn, 133 | stride_scales_g, stride_scales_n, 134 | stride_zeros_g, stride_zeros_n, 135 | groupsize, NO_GROUPS: tl.constexpr, 136 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 137 | GROUP_SIZE_M: tl.constexpr, 138 | ): 139 | """ 140 | Compute the matrix multiplication C = A x B. 
141 | A is of shape (M, K) float16 142 | B is of shape (K//8, N) int32 143 | C is of shape (M, N) float16 144 | scales is of shape (G, N) float16 145 | zeros is of shape (G, N//8) int32 146 | groupsize is an int specifying the size of groups for scales and zeros. 147 | G is K // groupsize. 148 | Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient. 149 | 150 | WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. 151 | WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. 152 | WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. 153 | """ 154 | pid = tl.program_id(axis=0) 155 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 156 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 157 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 158 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 159 | group_id = pid // num_pid_in_group 160 | first_pid_m = group_id * GROUP_SIZE_M 161 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 162 | pid_m = first_pid_m + (pid % group_size_m) 163 | pid_n = (pid % num_pid_in_group) // group_size_m 164 | 165 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 166 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 167 | offs_k = tl.arange(0, BLOCK_SIZE_K) 168 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 169 | a_mask = (offs_am[:, None] < M) 170 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 171 | b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 172 | scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) 173 | # zeros_ptrs is set up such that it repeats elements along the N axis 8 times 174 | zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) # (BLOCK_SIZE_N,) 175 | 176 | # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros 177 | shifter = (offs_k % 8) * 4 178 | zeros_shifter = (offs_bn % 8) * 4 179 | 180 | # If G == 1, scales and zeros are the same for all K, so we can load them once 181 | if NO_GROUPS: 182 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 183 | scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) 184 | zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 185 | 186 | # Unpack zeros 187 | zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 188 | zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16 189 | 190 | # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) 191 | # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension 192 | # So this loop is along the infeatures dimension (K) 193 | # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel 194 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 195 | for k in range(0, num_pid_k): 196 | a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) 197 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 198 | 199 | if not NO_GROUPS: 200 | g_id = k // (groupsize // BLOCK_SIZE_K) 201 | ptr = scales_ptrs + g_id * stride_scales_g 202 | scales = tl.load(ptr) # (BLOCK_SIZE_N,) 203 | ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,) 204 | zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 205 | 206 | # Unpack zeros 207 | zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 208 | zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16 209 | 210 | # Now we need to unpack b (which is 4-bit values) into 32-bit values 211 | b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values 212 | b = b * scales[None, :] - zeros[None, :] # Scale and shift 213 | 214 | accumulator += tl.dot(a, b) 215 | a_ptrs += BLOCK_SIZE_K * stride_ak 216 | b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk 217 | 218 | c = accumulator.to(tl.float16) 219 | 220 | # Store the result 221 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 222 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 223 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 224 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 225 | tl.store(c_ptrs, accumulator, mask=c_mask) 226 | 227 | 228 | def triton_matmul4(groupsize: int, a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: 229 | """ 230 | Compute the matrix multiplication C = A x B + bias. 231 | Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. 232 | 233 | A is of shape (..., K) float16 234 | qweight is of shape (K//8, N) int32 235 | scales is of shape (G, N) float16 236 | qzeros is of shape (G, N//8) int32 237 | bias is of shape (1, N) float16 238 | 239 | groupsize is the number of infeatures in each group. 
240 | G = K // groupsize 241 | 242 | Returns C of shape (..., N) float16 243 | """ 244 | assert a.shape[-1] == (qweight.shape[0] * 8), "A must be a multiple of 8 in the last dimension" 245 | assert a.is_contiguous(), "A must be contiguous" 246 | 247 | # Flatten a into (-1, K) 248 | x = a.view(-1, a.shape[-1]) 249 | 250 | M, K = x.shape 251 | N = qweight.shape[1] 252 | # This is based on the possible BLOCK_SIZE_Ks 253 | assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" 254 | # This is based on the possible BLOCK_SIZE_Ns 255 | assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" 256 | # This is based on the possible BLOCK_SIZE_Ks 257 | assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" 258 | 259 | c = torch.empty((M, N), device='cuda', dtype=torch.float16) 260 | 261 | grid = lambda META: ( 262 | triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), 263 | ) 264 | matmul4_kernel[grid]( 265 | x, qweight, c, 266 | scales, qzeros, 267 | M, N, K, 268 | x.stride(0), x.stride(1), 269 | qweight.stride(0), qweight.stride(1), 270 | c.stride(0), c.stride(1), 271 | scales.stride(0), scales.stride(1), 272 | qzeros.stride(0), qzeros.stride(1), 273 | groupsize, groupsize == K, 274 | ) 275 | 276 | # Reshape c 277 | c = c.view(a.shape[:-1] + (N,)) # (..., N) 278 | 279 | # Add bias 280 | if bias is not None: 281 | c = c + bias 282 | 283 | return c -------------------------------------------------------------------------------- /src/gptq_triton/fused_mlp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import triton 6 | import triton.language as tl 7 | from transformers.models.llama.modeling_llama import LlamaMLP 8 | 9 | from . import custom_autotune 10 | from .utils import matmul4_kernel_config_pruner 11 | 12 | 13 | def make_fused_mlp(m, parent_name=''): 14 | """ 15 | Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. 
16 | """ 17 | if isinstance(m, LlamaMLP): 18 | return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj) 19 | 20 | for name, child in m.named_children(): 21 | child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") 22 | 23 | if isinstance(child, QuantLlamaMLP): 24 | setattr(m, name, child) 25 | #print(f"Replacing {name} with fused_mlp; parent: {parent_name}") 26 | 27 | return m 28 | 29 | 30 | def autotune_warmup(model): 31 | # Find all the QuantLlamaMLP layers 32 | modules = (m for m in model.modules() if isinstance(m, QuantLlamaMLP)) 33 | k_values = {m.infeatures: { 34 | 'gate_proj_qweight': m.gate_proj_qweight, 35 | 'gate_proj_scales': m.gate_proj_scales, 36 | 'gate_proj_qzeros': m.gate_proj_qzeros, 37 | 'up_proj_qweight': m.up_proj_qweight, 38 | 'up_proj_scales': m.up_proj_scales, 39 | 'up_proj_qzeros': m.up_proj_qzeros, 40 | 'groupsize': m.groupsize, 41 | } for m in modules} 42 | 43 | print(f'FusedMLP Warmup: Found {len(k_values)} unique K values.') 44 | 45 | def func(m, k, gate_proj_qweight, gate_proj_scales, gate_proj_qzeros, up_proj_qweight, up_proj_scales, up_proj_qzeros, groupsize): 46 | a = torch.randn(1, m, k, dtype=torch.float16, device='cuda') 47 | triton_llama_mlp_4(groupsize, a, gate_proj_qweight, gate_proj_scales, gate_proj_qzeros, up_proj_qweight, up_proj_scales, up_proj_qzeros) 48 | 49 | return (functools.partial(func, k=k, **v) for k, v in k_values.items()) 50 | 51 | 52 | class QuantLlamaMLP(nn.Module): 53 | def __init__( 54 | self, 55 | gate_proj, 56 | down_proj, 57 | up_proj, 58 | ): 59 | super().__init__() 60 | 61 | assert gate_proj.groupsize == up_proj.groupsize 62 | 63 | # Only save the quantized weights, not the QuantLinear modules 64 | # This prevents the QuantLinear autotuning warmup from considering these modules 65 | self.register_buffer('gate_proj_qweight', gate_proj.qweight) 66 | self.register_buffer('gate_proj_scales', gate_proj.scales) 67 | self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) 68 | self.register_buffer('up_proj_qweight', up_proj.qweight) 69 | self.register_buffer('up_proj_scales', up_proj.scales) 70 | self.register_buffer('up_proj_qzeros', up_proj.qzeros) 71 | self.groupsize = gate_proj.groupsize 72 | 73 | self.infeatures = gate_proj.infeatures 74 | self.outfeatures = down_proj.outfeatures 75 | 76 | self.down_proj = down_proj 77 | 78 | def forward(self, x): 79 | return self.down_proj(triton_llama_mlp_4(self.groupsize, x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros)) 80 | 81 | 82 | # This Triton kernel fuses the gate_proj, up_proj, activation, and multiplication of LlamaMLP 83 | # It operates on quantized weights 84 | @custom_autotune.autotune( 85 | configs=[ 86 | triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 87 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 88 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 89 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 90 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), 91 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), # 3090 92 | 
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), # 3090 93 | 94 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), # 3090 95 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), # 3090 96 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), # 3090 97 | 98 | # This configuration provides a benefit to groupsize=128, but groupsize isn't recommended for fused mlp right now 99 | #triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), # 3090 100 | ], 101 | key=['M', 'N', 'K', 'NO_GROUPS'], 102 | nearest_power_of_two=['M', 'N', 'K'], 103 | prune_configs_by={ 104 | 'early_config_prune': matmul4_kernel_config_pruner, 105 | 'perf_model': None, 106 | 'top_k': None, 107 | }, 108 | ) 109 | @triton.jit 110 | def llama_mlp_fused_4_kernel( 111 | a_ptr, c_ptr, 112 | b1_ptr, scales1_ptr, zeros1_ptr, 113 | b2_ptr, scales2_ptr, zeros2_ptr, 114 | M, N, K, 115 | stride_am, stride_ak, 116 | stride_bk, stride_bn, 117 | stride_cm, stride_cn, 118 | stride_scales_g, stride_scales_n, 119 | stride_zeros_g, stride_zeros_n, 120 | groupsize, NO_GROUPS: tl.constexpr, 121 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 122 | GROUP_SIZE_M: tl.constexpr, 123 | ): 124 | """ 125 | Computes: C = silu(A * B1) * (A * B2) 126 | A is of shape (M, K) float16 127 | B is of shape (K//8, N) int32 128 | C is of shape (M, N) float16 129 | scales is of shape (G, N) float16 130 | zeros is of shape (G, N//8) int32 131 | groupsize is an int specifying the size of groups for scales and zeros. 132 | G is K // groupsize. 133 | Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient. 134 | 135 | WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. 136 | WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. 137 | WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. 
138 | """ 139 | pid = tl.program_id(axis=0) 140 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 141 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 142 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 143 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 144 | group_id = pid // num_pid_in_group 145 | first_pid_m = group_id * GROUP_SIZE_M 146 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 147 | pid_m = first_pid_m + (pid % group_size_m) 148 | pid_n = (pid % num_pid_in_group) // group_size_m 149 | 150 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 151 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 152 | offs_k = tl.arange(0, BLOCK_SIZE_K) 153 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 154 | a_mask = (offs_am[:, None] < M) 155 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 156 | b1_ptrs = b1_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 157 | b2_ptrs = b2_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 158 | scales1_ptrs = scales1_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) 159 | scales2_ptrs = scales2_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) 160 | # zeros_ptrs is set up such that it repeats elements along the N axis 8 times 161 | zeros1_ptrs = zeros1_ptr + (offs_bn // 8) * stride_zeros_n # (BLOCK_SIZE_N,) 162 | zeros2_ptrs = zeros2_ptr + (offs_bn // 8) * stride_zeros_n # (BLOCK_SIZE_N,) 163 | 164 | # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros 165 | shifter = (offs_k % 8) * 4 166 | zeros_shifter = (offs_bn % 8) * 4 167 | 168 | # If G == 1, scales and zeros are the same for all K, so we can load them once 169 | if NO_GROUPS: 170 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 171 | scales1 = tl.load(scales1_ptrs) # (BLOCK_SIZE_N,) 172 | scales2 = tl.load(scales2_ptrs) # (BLOCK_SIZE_N,) 173 | zeros1 = tl.load(zeros1_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 174 | zeros2 = tl.load(zeros2_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 175 | 176 | # Unpack zeros 177 | zeros1 = (zeros1 >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 178 | zeros1 = (zeros1 + 1) * scales1 # (BLOCK_SIZE_N,) float16 179 | zeros2 = (zeros2 >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 180 | zeros2 = (zeros2 + 1) * scales2 # (BLOCK_SIZE_N,) float16 181 | 182 | # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) 183 | # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension 184 | # So this loop is along the infeatures dimension (K) 185 | # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel 186 | accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 187 | accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 188 | for k in range(0, num_pid_k): 189 | a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) 190 | b = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 191 | 192 | if not NO_GROUPS: 193 | g_id = k // (groupsize // BLOCK_SIZE_K) 194 | scales1 = tl.load(scales1_ptrs + g_id * stride_scales_g) # (BLOCK_SIZE_N,) 195 | scales2 = tl.load(scales2_ptrs + g_id * stride_scales_g) # (BLOCK_SIZE_N,) 196 | zeros1 = tl.load(zeros1_ptrs + g_id * stride_zeros_g) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 197 | zeros2 = tl.load(zeros2_ptrs + g_id * stride_zeros_g) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 198 | 199 | # Unpack zeros 200 | zeros1 = (zeros1 >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 201 | zeros2 = (zeros2 >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 202 | zeros1 = (zeros1 + 1) * scales1 # (BLOCK_SIZE_N,) float16 203 | zeros2 = (zeros2 + 1) * scales2 # (BLOCK_SIZE_N,) float16 204 | 205 | # Now we need to unpack b (which is 4-bit values) into 32-bit values 206 | b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values 207 | b = b * scales1[None, :] - zeros1[None, :] # Scale and shift 208 | 209 | accumulator1 += tl.dot(a, b) 210 | 211 | b = tl.load(b2_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 212 | b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values 213 | b = b * scales2[None, :] - zeros2[None, :] # Scale and shift 214 | 215 | accumulator2 += tl.dot(a, b) 216 | 217 | a_ptrs += BLOCK_SIZE_K * stride_ak 218 | b1_ptrs += (BLOCK_SIZE_K // 8) * stride_bk 219 | b2_ptrs += (BLOCK_SIZE_K // 8) * stride_bk 220 | 221 | # Apply activation to accumulator1 222 | accumulator1 = silu(accumulator1) 223 | 224 | # Multiply accumulator1 and accumulator2 225 | c = accumulator1 * accumulator2 226 | #c = c.to(tl.float16) # Seems like Triton does this conversion automatically if c_ptrs is float16 227 | 228 | # Store the result 229 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 230 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 231 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 232 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 233 | tl.store(c_ptrs, c, mask=c_mask) 234 | 235 | 236 | @triton.jit 237 | def silu(x): 238 | return x * tl.sigmoid(x) 239 | 240 | 241 | def triton_llama_mlp_4( 242 | groupsize: int, 243 | a: torch.FloatTensor, 244 | gate_qweight: torch.IntTensor, 245 | gate_scales: torch.FloatTensor, 246 | gate_qzeros: torch.IntTensor, 247 | up_qweight: torch.IntTensor, 248 | up_scales: torch.FloatTensor, 249 | up_qzeros: torch.IntTensor, 250 | ) -> torch.FloatTensor: 251 | """ 252 | Computes: silu(gate(a)) * up(a) 253 | Where gate and up are quantized using GPTQ and groupsize = -1 into 4-bit values. 254 | 255 | A is of shape (..., K) float16 256 | *_qweight is of shape (K//8, N) int32 257 | *_scales is of shape (G, N) float16 258 | *_qzeros is of shape (G, N//8) int32 259 | 260 | groupsize is the number of infeatures in each group. 
261 | G = K // groupsize 262 | 263 | Returns C of shape (..., N) float16 264 | """ 265 | assert gate_qweight.shape == up_qweight.shape and gate_scales.shape == up_scales.shape and gate_qzeros.shape == up_qzeros.shape, "All weights must have the same shape" 266 | assert a.shape[-1] == (gate_qweight.shape[0] * 8), "A must be a multiple of 8 in the last dimension" 267 | assert a.is_contiguous(), "A must be contiguous" 268 | 269 | # Flatten a into (-1, K) 270 | x = a.view(-1, a.shape[-1]) 271 | 272 | M, K = x.shape 273 | N = gate_qweight.shape[1] 274 | # This is based on the possible BLOCK_SIZE_Ks 275 | assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" 276 | # This is based on the possible BLOCK_SIZE_Ns 277 | assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" 278 | # This is based on the possible BLOCK_SIZE_Ks 279 | assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" 280 | 281 | c = torch.empty((M, N), device='cuda', dtype=torch.float16) 282 | 283 | grid = lambda META: ( 284 | triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), 285 | ) 286 | llama_mlp_fused_4_kernel[grid]( 287 | x, c, 288 | gate_qweight, gate_scales, gate_qzeros, 289 | up_qweight, up_scales, up_qzeros, 290 | M, N, K, 291 | x.stride(0), x.stride(1), 292 | gate_qweight.stride(0), gate_qweight.stride(1), 293 | c.stride(0), c.stride(1), 294 | gate_scales.stride(0), gate_scales.stride(1), 295 | gate_qzeros.stride(0), gate_qzeros.stride(1), 296 | groupsize, groupsize == K, 297 | ) 298 | 299 | # Reshape c 300 | c = c.view(a.shape[:-1] + (N,)) # (..., N) 301 | 302 | return c -------------------------------------------------------------------------------- /Verify.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Verify Correctness of GPTQ-triton\n", 9 | "\n", 10 | "This notebook verifies the correctness of the Triton kernels and other modifications." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", 21 | "import itertools\n", 22 | "\n", 23 | "import original_quant\n", 24 | "import gptq_triton\n", 25 | "import torch\n", 26 | "import torch.nn as nn\n", 27 | "from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, LlamaConfig\n", 28 | "import gptq\n", 29 | "from quantize import dumbquant, pack_linear" 30 | ] 31 | }, 32 | { 33 | "attachments": {}, 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Verify QuantLinear" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "groupsize | M | N | K | cuda - ref | triton - ref | triton - cuda |\n", 50 | " -1 | 1 | 4096 | 4096 | 0.000977 | 0.001953 | 0.001953 | \n", 51 | " -1 | 1 | 4096 | 11008 | 0.000977 | 0.001953 | 0.001953 | \n", 52 | " -1 | 1 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 53 | " -1 | 1 | 11008 | 11008 | 0.000977 | 0.001953 | 0.000977 | \n", 54 | " -1 | 8 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 55 | " -1 | 8 | 4096 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 56 | " -1 | 8 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 57 | " -1 | 8 | 11008 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 58 | " -1 | 100 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 59 | " -1 | 100 | 4096 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 60 | " -1 | 100 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 61 | " -1 | 100 | 11008 | 11008 | 0.001953 | 0.002930 | 0.001953 | \n", 62 | " -1 | 256 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 63 | " -1 | 256 | 4096 | 11008 | 0.001953 | 0.002930 | 0.002930 | \n", 64 | " -1 | 256 | 11008 | 4096 | 0.001953 | 0.003906 | 0.001953 | \n", 65 | " -1 | 256 | 11008 | 11008 | 0.001953 | 0.003906 | 0.002930 | \n", 66 | " -1 | 2048 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 67 | " -1 | 2048 | 4096 | 11008 | 0.001953 | 0.002930 | 0.002930 | \n", 68 | " -1 | 2048 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 69 | " -1 | 2048 | 11008 | 11008 | 0.001953 | 0.003906 | 0.001953 | \n", 70 | " 128 | 1 | 4096 | 4096 | 0.000977 | 0.000977 | 0.000977 | \n", 71 | " 128 | 1 | 4096 | 11008 | 0.000977 | 0.000977 | 0.000977 | \n", 72 | " 128 | 1 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 73 | " 128 | 1 | 11008 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 74 | " 128 | 8 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 75 | " 128 | 8 | 4096 | 11008 | 0.000977 | 0.001953 | 0.001953 | \n", 76 | " 128 | 8 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 77 | " 128 | 8 | 11008 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 78 | " 128 | 100 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 79 | " 128 | 100 | 4096 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 80 | " 128 | 100 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 81 | " 128 | 100 | 11008 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 82 | " 128 | 256 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 83 | " 128 | 256 | 4096 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 84 | " 128 | 256 | 11008 | 4096 | 0.003906 | 0.003906 | 0.001953 | \n", 85 | " 128 | 256 | 11008 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 86 | " 128 | 2048 | 4096 | 4096 | 0.001953 | 0.001953 | 0.001953 | 
\n", 87 | " 128 | 2048 | 4096 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n", 88 | " 128 | 2048 | 11008 | 4096 | 0.001953 | 0.001953 | 0.001953 | \n", 89 | " 128 | 2048 | 11008 | 11008 | 0.001953 | 0.001953 | 0.001953 | \n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "# QuantLinear is compared against a reference and the CUDA kernel at various values of M, N and K\n", 95 | "# The reference is an FP16 simulation of the quantized weights\n", 96 | "torch.manual_seed(0)\n", 97 | "print(\"groupsize | M | N | K | cuda - ref | triton - ref | triton - cuda |\")\n", 98 | "\n", 99 | "for (groupsize, M, N, K) in itertools.product([-1, 128], [1, 8, 100, 256, 2048], [4096, 11008], [4096, 11008]):\n", 100 | "\tM = M # B * seq_len\n", 101 | "\tK = K # Input dimension\n", 102 | "\tN = N # Output dimension\n", 103 | "\n", 104 | "\tlayer = nn.Linear(K, N, bias=False) # Llama doesn't use bias\n", 105 | "\tvec = torch.randn(1, M, K, device='cuda', dtype=torch.float16)\n", 106 | "\n", 107 | "\tscales, zeros = dumbquant(layer, 4, groupsize=groupsize)\n", 108 | "\n", 109 | "\tcudalayer = original_quant.QuantLinear(4, groupsize, layer.in_features, layer.out_features)\n", 110 | "\tcudalayer.pack(layer, scales.clone(), zeros.clone())\n", 111 | "\n", 112 | "\ttritonlayer = gptq_triton.QuantLinear(4, groupsize, layer.in_features, layer.out_features, bias=False)\n", 113 | "\tpack_linear(tritonlayer, layer.weight.data, scales, zeros, None)\n", 114 | "\n", 115 | "\tlayer = layer.half()\n", 116 | "\n", 117 | "\tlayer = layer.to('cuda')\n", 118 | "\tcudalayer = cudalayer.to('cuda')\n", 119 | "\ttritonlayer = tritonlayer.to('cuda')\n", 120 | "\n", 121 | "\tref = layer(vec)\n", 122 | "\tcuda_out = cudalayer(vec)\n", 123 | "\ttriton_out = tritonlayer(vec)\n", 124 | "\n", 125 | "\t# Print results\n", 126 | "\tprint(f' {groupsize:5d}', end=' | ')\n", 127 | "\tprint(f'{M:5d}', end=' | ')\n", 128 | "\tprint(f'{N:5d}', end=' | ')\n", 129 | "\tprint(f'{K:5d}', end=' | ')\n", 130 | "\tprint(f' {(cuda_out - ref).abs().max():.6f}', end=' | ')\n", 131 | "\tprint(f' {(triton_out - ref).abs().max():.6f}', end=' | ')\n", 132 | "\tprint(f' {(triton_out - cuda_out).abs().max():.6f}', end=' | ')\n", 133 | "\n", 134 | "\tif (triton_out - ref).abs().max() > 0.004 or (triton_out - cuda_out).abs().max() > 0.004:\n", 135 | "\t\tprint(\" !!! WARNING: Error is too large !!! 
\")\n", 136 | "\telse:\n", 137 | "\t\tprint()" 138 | ] 139 | }, 140 | { 141 | "attachments": {}, 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "## Verify QKV Fusion" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 3, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "Max diff: 0.0\n", 158 | "Max diff: 0.0\n", 159 | "Max diff: 0.0\n", 160 | "Max diff: 0.0\n", 161 | "Max diff: 0.0\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "# Comparison to ensure that the QKV fusion is correct\n", 167 | "class TestModel(nn.Module):\n", 168 | "\tdef __init__(self):\n", 169 | "\t\tsuper().__init__()\n", 170 | "\t\tself.attn = LlamaAttention(LlamaConfig(hidden_size=4096, num_attention_heads=32))\n", 171 | "\n", 172 | "\tdef forward(self, x):\n", 173 | "\t\treturn self.attn(x,)\n", 174 | "\n", 175 | "model = TestModel()\n", 176 | "\n", 177 | "# Quantize the model\n", 178 | "for name, m in model.named_modules():\n", 179 | "\tif not isinstance(m, nn.Linear):\n", 180 | "\t\tcontinue\n", 181 | "\n", 182 | "\tscales, zeros = dumbquant(m, 4, groupsize=-1)\n", 183 | "\ttriton_layer = gptq_triton.QuantLinear(4, -1, m.in_features, m.out_features, bias=False)\n", 184 | "\tpack_linear(triton_layer, m.weight.data, scales, zeros, None)\n", 185 | "\n", 186 | "\t# Replace in model\n", 187 | "\tparent_name = name.rsplit('.', 1)[0]\n", 188 | "\tparent = model.get_submodule(parent_name)\n", 189 | "\n", 190 | "\tsetattr(parent, name[len(parent_name) + 1:], triton_layer)\n", 191 | "\n", 192 | "# Save the original attention layer\n", 193 | "original_attn = model.attn\n", 194 | "\n", 195 | "# Fuse\n", 196 | "gptq_triton.make_quant_attn(model)\n", 197 | "fused_attn = model.attn\n", 198 | "\n", 199 | "# Move to CUDA\n", 200 | "original_attn.to('cuda')\n", 201 | "fused_attn.to('cuda')\n", 202 | "\n", 203 | "# Compare\n", 204 | "for M in [1, 8, 100, 256, 2048]:\n", 205 | "\tx = torch.randn(1, M, 4096, device='cuda', dtype=torch.float16)\n", 206 | "\tposition_ids = torch.arange(0, M, dtype=torch.long, device='cuda')\n", 207 | "\tposition_ids = position_ids.unsqueeze(0).view(-1, M)\n", 208 | "\n", 209 | "\toriginal_out = original_attn(x, position_ids=position_ids)[0]\n", 210 | "\tfused_out = fused_attn(x, position_ids=position_ids)[0]\n", 211 | "\n", 212 | "\tdiff = (original_out - fused_out).abs().max()\n", 213 | "\tprint(f\"Max diff: {diff}\")\n", 214 | "\n", 215 | "\t# Assertions\n", 216 | "\tassert isinstance(fused_attn, gptq_triton.QuantLlamaAttention)\n", 217 | "\tassert diff == 0" 218 | ] 219 | }, 220 | { 221 | "attachments": {}, 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "## Verify Fused MLP" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 4, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Max diff: 0.000244140625\n", 238 | "Max diff: 0.000244140625\n", 239 | "Max diff: 0.000244140625\n", 240 | "Max diff: 0.000244140625\n", 241 | "Max diff: 0.000244140625\n", 242 | "Max diff: 0.000244140625\n", 243 | "Max diff: 0.0003662109375\n", 244 | "Max diff: 0.00025177001953125\n", 245 | "Max diff: 0.00048828125\n", 246 | "Max diff: 0.000274658203125\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "layer = LlamaMLP(4096, 11008, 'silu')\n", 252 | "layer = layer.half()\n", 253 | "layer_g128 = LlamaMLP(4096, 11008, 'silu')\n", 254 | "layer_g128 = 
layer_g128.half()\n", 255 | "layer_g128.load_state_dict(layer.state_dict())\n", 256 | "\n", 257 | "# Quantize\n", 258 | "for name, m in layer.named_modules():\n", 259 | "\tif not isinstance(m, nn.Linear):\n", 260 | "\t\tcontinue\n", 261 | "\n", 262 | "\tscales, zeros = dumbquant(m, 4, groupsize=-1)\n", 263 | "\ttriton_layer = gptq_triton.QuantLinear(4, -1, m.in_features, m.out_features, bias=False)\n", 264 | "\tpack_linear(triton_layer, m.weight.data, scales, zeros, None)\n", 265 | "\n", 266 | "\tsetattr(layer, name, triton_layer)\n", 267 | "\n", 268 | "for name, m in layer_g128.named_modules():\n", 269 | "\tif not isinstance(m, nn.Linear):\n", 270 | "\t\tcontinue\n", 271 | "\n", 272 | "\tscales, zeros = dumbquant(m, 4, groupsize=128)\n", 273 | "\ttriton_layer = gptq_triton.QuantLinear(4, 128, m.in_features, m.out_features, bias=False)\n", 274 | "\tpack_linear(triton_layer, m.weight.data, scales, zeros, None)\n", 275 | "\n", 276 | "\tsetattr(layer_g128, name, triton_layer)\n", 277 | "\n", 278 | "# Fuse\n", 279 | "fused_layer = gptq_triton.make_fused_mlp(layer)\n", 280 | "fused_layer_g128 = gptq_triton.make_fused_mlp(layer_g128)\n", 281 | "assert isinstance(fused_layer, gptq_triton.QuantLlamaMLP) and isinstance(fused_layer_g128, gptq_triton.QuantLlamaMLP)\n", 282 | "\n", 283 | "# Move to CUDA\n", 284 | "layer.to('cuda')\n", 285 | "layer_g128.to('cuda')\n", 286 | "fused_layer.to('cuda')\n", 287 | "fused_layer_g128.to('cuda')\n", 288 | "\n", 289 | "# Compare\n", 290 | "for M in [1, 8, 100, 256, 2048]:\n", 291 | "\tx = torch.randn(1, M, 4096, device='cuda', dtype=torch.float16)\n", 292 | "\n", 293 | "\toriginal_out = layer(x)\n", 294 | "\tfused_out = fused_layer(x)\n", 295 | "\n", 296 | "\tdiff = (original_out - fused_out).abs().max()\n", 297 | "\tprint(f\"Max diff: {diff}\")\n", 298 | "\n", 299 | "\t# There is a small difference because the fused MLP performs some calculations in float32, while the original MLP performs them in float16\n", 300 | "\tassert diff < 1e-3\n", 301 | "\n", 302 | "\toriginal_out = layer_g128(x)\n", 303 | "\tfused_out = fused_layer_g128(x)\n", 304 | "\tdiff = (original_out - fused_out).abs().max()\n", 305 | "\tprint(f\"Max diff: {diff}\")\n", 306 | "\tassert diff < 1e-3" 307 | ] 308 | } 309 | ], 310 | "metadata": { 311 | "kernelspec": { 312 | "display_name": "llama", 313 | "language": "python", 314 | "name": "python3" 315 | }, 316 | "language_info": { 317 | "codemirror_mode": { 318 | "name": "ipython", 319 | "version": 3 320 | }, 321 | "file_extension": ".py", 322 | "mimetype": "text/x-python", 323 | "name": "python", 324 | "nbconvert_exporter": "python", 325 | "pygments_lexer": "ipython3", 326 | "version": "3.10.10" 327 | }, 328 | "orig_nbformat": 4 329 | }, 330 | "nbformat": 4, 331 | "nbformat_minor": 2 332 | } 333 | -------------------------------------------------------------------------------- /original_quant.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa 2 | # Used for comparison with the Triton kernel 3 | # Some minor modifications were made to ease use 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import math 8 | from transformers import LlamaConfig, LlamaForCausalLM 9 | import transformers 10 | 11 | 12 | def quantize(x, scale, zero, maxq): 13 | if maxq < 0: 14 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 15 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 16 | return scale * (q - zero) 17 | 18 | 
class Quantizer(nn.Module): 19 | 20 | def __init__(self, shape=1): 21 | super(Quantizer, self).__init__() 22 | self.register_buffer('maxq', torch.tensor(0)) 23 | self.register_buffer('scale', torch.zeros(shape)) 24 | self.register_buffer('zero', torch.zeros(shape)) 25 | 26 | def configure( 27 | self, 28 | bits, perchannel=False, sym=True, 29 | mse=False, norm=2.4, grid=100, maxshrink=.8, 30 | trits=False 31 | ): 32 | 33 | self.maxq = torch.tensor(2 ** bits - 1) 34 | self.perchannel = perchannel 35 | self.sym = sym 36 | self.mse = mse 37 | self.norm = norm 38 | self.grid = grid 39 | self.maxshrink = maxshrink 40 | if trits: 41 | self.maxq = torch.tensor(-1) 42 | 43 | def find_params(self, x, weight=False): 44 | dev = x.device 45 | self.maxq = self.maxq.to(dev) 46 | 47 | shape = x.shape 48 | if self.perchannel: 49 | if weight: 50 | x = x.flatten(1) 51 | else: 52 | if len(shape) == 4: 53 | x = x.permute([1, 0, 2, 3]) 54 | x = x.flatten(1) 55 | if len(shape) == 3: 56 | x = x.reshape((-1, shape[-1])).t() 57 | if len(shape) == 2: 58 | x = x.t() 59 | else: 60 | x = x.flatten().unsqueeze(0) 61 | 62 | tmp = torch.zeros(x.shape[0], device=dev) 63 | xmin = torch.minimum(x.min(1)[0], tmp) 64 | xmax = torch.maximum(x.max(1)[0], tmp) 65 | 66 | if self.sym: 67 | xmax = torch.maximum(torch.abs(xmin), xmax) 68 | tmp = xmin < 0 69 | if torch.any(tmp): 70 | xmin[tmp] = -xmax[tmp] 71 | tmp = (xmin == 0) & (xmax == 0) 72 | xmin[tmp] = -1 73 | xmax[tmp] = +1 74 | 75 | if self.maxq < 0: 76 | self.scale = xmax 77 | self.zero = xmin 78 | else: 79 | self.scale = (xmax - xmin) / self.maxq 80 | if self.sym: 81 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 82 | else: 83 | self.zero = torch.round(-xmin / self.scale) 84 | 85 | if self.mse: 86 | best = torch.full([x.shape[0]], float('inf'), device=dev) 87 | for i in range(int(self.maxshrink * self.grid)): 88 | p = 1 - i / self.grid 89 | xmin1 = p * xmin 90 | xmax1 = p * xmax 91 | scale1 = (xmax1 - xmin1) / self.maxq 92 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 93 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 94 | q -= x 95 | q.abs_() 96 | q.pow_(self.norm) 97 | err = torch.sum(q, 1) 98 | tmp = err < best 99 | if torch.any(tmp): 100 | best[tmp] = err[tmp] 101 | self.scale[tmp] = scale1[tmp] 102 | self.zero[tmp] = zero1[tmp] 103 | if not self.perchannel: 104 | if weight: 105 | tmp = shape[0] 106 | else: 107 | tmp = shape[1] if len(shape) != 3 else shape[2] 108 | self.scale = self.scale.repeat(tmp) 109 | self.zero = self.zero.repeat(tmp) 110 | 111 | if weight: 112 | shape = [-1] + [1] * (len(shape) - 1) 113 | self.scale = self.scale.reshape(shape) 114 | self.zero = self.zero.reshape(shape) 115 | return 116 | if len(shape) == 4: 117 | self.scale = self.scale.reshape((1, -1, 1, 1)) 118 | self.zero = self.zero.reshape((1, -1, 1, 1)) 119 | if len(shape) == 3: 120 | self.scale = self.scale.reshape((1, 1, -1)) 121 | self.zero = self.zero.reshape((1, 1, -1)) 122 | if len(shape) == 2: 123 | self.scale = self.scale.unsqueeze(0) 124 | self.zero = self.zero.unsqueeze(0) 125 | 126 | def quantize(self, x): 127 | if self.ready(): 128 | return quantize(x, self.scale, self.zero, self.maxq) 129 | return x 130 | 131 | def enabled(self): 132 | return self.maxq > 0 133 | 134 | def ready(self): 135 | return torch.all(self.scale != 0) 136 | 137 | 138 | try: 139 | import quant_cuda 140 | except: 141 | print('CUDA extension not installed.') 142 | 143 | # Assumes layer is perfectly divisible into 256 * 256 blocks 144 | class 
QuantLinear(nn.Module): 145 | def __init__(self, bits, groupsize, infeatures, outfeatures, faster=False): 146 | super().__init__() 147 | if bits not in [2,3,4,8]: 148 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 149 | self.infeatures = infeatures 150 | self.outfeatures = outfeatures 151 | self.bits = bits 152 | if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2,int(math.log2(groupsize)))): 153 | raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)") 154 | groupsize = groupsize if groupsize != -1 else infeatures 155 | self.groupsize = groupsize 156 | self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures/groupsize),outfeatures // 256 * (bits * 8)), dtype=torch.int)) 157 | self.register_buffer('scales', torch.zeros((math.ceil(infeatures/groupsize),outfeatures))) 158 | self.register_buffer('bias', torch.zeros(outfeatures)) 159 | self.register_buffer( 160 | 'qweight', torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int) 161 | ) 162 | self.half_indim = self.infeatures // 2 163 | self._initialized_quant_state = False 164 | self.faster = faster 165 | 166 | def pack(self, linear, scales, zeros): 167 | scales = scales.t().contiguous() 168 | zeros = zeros.t().contiguous() 169 | scale_zeros = zeros * scales 170 | self.scales = scales.clone() 171 | if linear.bias is not None: 172 | self.bias = linear.bias.clone() 173 | 174 | intweight = [] 175 | for idx in range(self.infeatures): 176 | g_idx = idx // self.groupsize 177 | intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,None]) 178 | intweight = torch.cat(intweight,dim=1) 179 | intweight = intweight.t().contiguous() 180 | intweight = intweight.numpy().astype(np.uint32) 181 | qweight = np.zeros( 182 | (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 183 | ) 184 | i = 0 185 | row = 0 186 | while row < qweight.shape[0]: 187 | if self.bits in [2,4,8]: 188 | for j in range(i, i + (32//self.bits)): 189 | qweight[row] |= intweight[j] << (self.bits * (j - i)) 190 | i += 32//self.bits 191 | row += 1 192 | elif self.bits == 3: 193 | for j in range(i, i + 10): 194 | qweight[row] |= intweight[j] << (3 * (j - i)) 195 | i += 10 196 | qweight[row] |= intweight[i] << 30 197 | row += 1 198 | qweight[row] |= (intweight[i] >> 2) & 1 199 | i += 1 200 | for j in range(i, i + 10): 201 | qweight[row] |= intweight[j] << (3 * (j - i) + 1) 202 | i += 10 203 | qweight[row] |= intweight[i] << 31 204 | row += 1 205 | qweight[row] |= (intweight[i] >> 1) & 0x3 206 | i += 1 207 | for j in range(i, i + 10): 208 | qweight[row] |= intweight[j] << (3 * (j - i) + 2) 209 | i += 10 210 | row += 1 211 | else: 212 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 213 | 214 | qweight = qweight.astype(np.int32) 215 | self.qweight = torch.from_numpy(qweight) 216 | 217 | zeros -= 1; 218 | zeros = zeros.numpy().astype(np.uint32) 219 | qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32) 220 | i = 0 221 | col = 0 222 | while col < qzeros.shape[1]: 223 | if self.bits in [2,4,8]: 224 | for j in range(i, i + (32//self.bits)): 225 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) 226 | i += 32//self.bits 227 | col += 1 228 | elif self.bits == 3: 229 | for j in range(i, i + 10): 230 | qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) 231 | i += 10 232 | qzeros[:, col] |= zeros[:, i] << 30 233 | col += 1 234 | qzeros[:, col] |= (zeros[:, i] >> 2) & 
1 235 | i += 1 236 | for j in range(i, i + 10): 237 | qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) 238 | i += 10 239 | qzeros[:, col] |= zeros[:, i] << 31 240 | col += 1 241 | qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 242 | i += 1 243 | for j in range(i, i + 10): 244 | qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) 245 | i += 10 246 | col += 1 247 | else: 248 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 249 | 250 | qzeros = qzeros.astype(np.int32) 251 | self.qzeros = torch.from_numpy(qzeros) 252 | 253 | def forward(self, x): 254 | if not self._initialized_quant_state: 255 | # Do we even have a bias? Check for at least one non-zero element. 256 | if self.bias is not None and bool(torch.any(self.bias != 0)): 257 | # Then make sure it's the right type. 258 | self.bias.data = self.bias.data.to(torch.float32) 259 | else: 260 | self.bias = None 261 | 262 | outshape = list(x.shape) 263 | outshape[-1] = self.outfeatures 264 | x = x.reshape(-1, x.shape[-1]) 265 | if self.bias is None: 266 | y = torch.zeros(x.shape[0], outshape[-1], dtype=torch.float32, device=x.device) 267 | else: 268 | y = self.bias.clone().repeat(x.shape[0], 1) 269 | 270 | output_dtype = x.dtype 271 | if self.faster: 272 | x = x.half() 273 | if self.bits == 3: 274 | quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.qzeros, self.groupsize, self.half_indim) 275 | else: 276 | raise NotImplementedError("Only 3 bits are supported.") 277 | else: 278 | x = x.float() 279 | if self.bits == 2: 280 | quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) 281 | elif self.bits == 3: 282 | quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) 283 | elif self.bits == 4: 284 | quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) 285 | elif self.bits == 8: 286 | quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize) 287 | else: 288 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 289 | y = y.to(output_dtype) 290 | return y.reshape(outshape) 291 | 292 | def make_quant(module, names, bits, groupsize, faster=False, name=''): 293 | if isinstance(module, QuantLinear): 294 | return 295 | for attr in dir(module): 296 | tmp = getattr(module, attr) 297 | name1 = name + '.' + attr if name != '' else attr 298 | if name1 in names: 299 | delattr(module, attr) 300 | setattr( 301 | module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, faster=faster) 302 | ) 303 | for name1, child in module.named_children(): 304 | make_quant(child, names, bits, groupsize, faster, name + '.' + name1 if name != '' else name1) 305 | 306 | 307 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 308 | if type(module) in layers: 309 | return {name: module} 310 | res = {} 311 | for name1, child in module.named_children(): 312 | res.update(find_layers( 313 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 314 | )) 315 | return res 316 | 317 | 318 | def load_cuda_quant(model, checkpoint, wbits, groupsize): 319 | """ 320 | Load a quantized model using the old CUDA kernel 321 | """ 322 | config = LlamaConfig.from_pretrained(model) 323 | def noop(*args, **kwargs): 324 | pass 325 | original_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) 326 | torch.nn.init.kaiming_uniform_ = noop 327 | torch.nn.init.uniform_ = noop 328 | torch.nn.init.normal_ = noop 329 | 330 | torch.set_default_dtype(torch.half) 331 | original_init_weights = transformers.modeling_utils._init_weights 332 | transformers.modeling_utils._init_weights = False 333 | torch.set_default_dtype(torch.half) 334 | model = LlamaForCausalLM(config) 335 | torch.set_default_dtype(torch.float) 336 | 337 | transformers.modeling_utils._init_weights = original_init_weights 338 | torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = original_inits 339 | 340 | model = model.eval() 341 | layers = find_layers(model) 342 | for name in ['lm_head']: 343 | if name in layers: 344 | del layers[name] 345 | make_quant(model, layers, wbits, groupsize, faster=False) 346 | 347 | del layers 348 | 349 | print('Loading model ...') 350 | if checkpoint.endswith('.safetensors'): 351 | from safetensors.torch import load_file as safe_load 352 | model.load_state_dict(safe_load(checkpoint)) 353 | else: 354 | model.load_state_dict(torch.load(checkpoint)) 355 | model.seqlen = 2048 356 | print('Done.') 357 | 358 | return model --------------------------------------------------------------------------------
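
To tie the pieces together, here is a minimal usage sketch for the reference loader defined above in original_quant.py. It assumes a local LLaMA model directory, a 4-bit GPTQ checkpoint produced separately (e.g. by quantize.py), and that the quant_cuda extension from GPTQ-for-LLaMa is installed; the model directory and checkpoint paths are hypothetical placeholders, not files shipped with this repository.

# Hedged usage sketch (assumptions: placeholder paths, quant_cuda extension installed).
# Only load_cuda_quant comes from original_quant.py above; everything else is standard
# torch/transformers API.
import torch
from transformers import AutoTokenizer
from original_quant import load_cuda_quant

model_dir = 'path/to/llama-7b-hf'        # placeholder: a LLaMA model directory in HF format
checkpoint = 'path/to/llama-7b-4bit.pt'  # placeholder: a 4-bit GPTQ checkpoint

# Rebuild the model with the old CUDA QuantLinear layers and load the quantized weights.
model = load_cuda_quant(model_dir, checkpoint, wbits=4, groupsize=128).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)

with torch.no_grad():
    ids = tokenizer('The quick brown fox', return_tensors='pt').input_ids.to('cuda')
    logits = model(ids).logits
    print(logits.shape)  # (1, seq_len, vocab_size)

This CUDA path is the baseline that the Verify notebook and the perplexity/benchmark scripts compare the Triton kernels against.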