├── .gitignore ├── requirements.txt ├── images ├── request.png └── result.png ├── src ├── model │ ├── __init__.py │ ├── fused_base.py │ ├── utils.py │ ├── QuantLinear.py │ ├── custom_autotune.py │ ├── fused_llama_attn.py │ ├── fused_llama_mlp.py │ ├── kernels.py │ └── LlamaGPTQ.py ├── __init__.py ├── globals.py ├── utils.py ├── speculative_sampling.py └── kvcache_model.py ├── make_quant.sh ├── start_server.sh ├── scripts ├── quantize.py └── serving.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | triton==2.1.0 2 | auto_gptq==0.7.0 3 | transformers==4.37.2 -------------------------------------------------------------------------------- /images/request.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hjchen-thu/codebear/HEAD/images/request.png -------------------------------------------------------------------------------- /images/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hjchen-thu/codebear/HEAD/images/result.png -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.LlamaGPTQ import LlamaGPTQ 2 | 3 | __all__ = ["LlamaGPTQ", ] -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from src.speculative_sampling import speculative_sampling 2 | from src.model import LlamaGPTQ 3 | 4 | __all__ = ["speculative_sampling", "LlamaGPTQ"] -------------------------------------------------------------------------------- /make_quant.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | f_value="" 4 | q_value="" 5 | 6 | while getopts ":f:q:" opt; do 7 | case ${opt} in 8 | f ) 9 | f_value=$OPTARG 10 | ;; 11 | q ) 12 | q_value=$OPTARG 13 | ;; 14 | \? ) 15 | echo "Invalid option: $OPTARG" 1>&2 16 | exit 1 17 | ;; 18 | : ) 19 | echo "Invalid option: $OPTARG requires an argument" 1>&2 20 | exit 1 21 | ;; 22 | esac 23 | done 24 | shift $((OPTIND -1)) 25 | 26 | if [ -z "$f_value" ] || [ -z "$q_value" ]; then 27 | echo "Both -f and -q parameters are required." 28 | exit 1 29 | fi 30 | 31 | python scripts/quantize.py -f "$f_value" -q "$q_value" -------------------------------------------------------------------------------- /start_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | s_value="" 5 | l_value="" 6 | t_value="" 7 | 8 | 9 | while getopts ":s:l:t:" opt; do 10 | case ${opt} in 11 | s ) 12 | s_value=$OPTARG 13 | ;; 14 | l ) 15 | l_value=$OPTARG 16 | ;; 17 | t ) 18 | t_value=$OPTARG 19 | ;; 20 | \? ) 21 | echo "Invalid option: $OPTARG" 1>&2 22 | exit 1 23 | ;; 24 | : ) 25 | echo "Invalid option: $OPTARG requires an argument" 1>&2 26 | exit 1 27 | ;; 28 | esac 29 | done 30 | shift $((OPTIND -1)) 31 | 32 | 33 | if [ -z "$s_value" ] || [ -z "$l_value" ] || [ -z "$t_value" ]; then 34 | echo "Parameters -s, -l, and -t are required." 
35 | exit 1 36 | fi 37 | 38 | 39 | python scripts/serving.py -s "$s_value" -l "$l_value" -t "$t_value" 40 | -------------------------------------------------------------------------------- /src/globals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # reference: https://github.com/feifeibear/LLMSpeculativeSampling 4 | class Singleton(type): 5 | _instances = {} 6 | 7 | def __call__(cls, *args, **kwargs): 8 | if cls not in cls._instances: 9 | cls._instances[cls] = super().__call__(*args, **kwargs) 10 | return cls._instances[cls] 11 | 12 | class Decoder(metaclass=Singleton): 13 | def __init__(self): 14 | self.tokenizer = None 15 | 16 | def set_tokenizer(self, tokenizer): 17 | self.tokenizer = tokenizer 18 | 19 | def encode(self, s: str, return_tensors='pt') -> torch.Tensor: 20 | return self.tokenizer.encode(s, return_tensors=return_tensors) 21 | 22 | def decode(self, t: torch.Tensor) -> str: 23 | return self.tokenizer.decode(t[0], skip_special_tokens=True) 24 | -------------------------------------------------------------------------------- /src/model/fused_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from abc import abstractmethod 4 | from logging import getLogger 5 | 6 | 7 | logger = getLogger(__name__) 8 | 9 | class TritonModuleMixin: 10 | @classmethod 11 | def warmup(cls, model, transpose=False, seqlen=2048): 12 | pass 13 | 14 | class FusedBaseModule(nn.Module, TritonModuleMixin): 15 | @classmethod 16 | @abstractmethod 17 | def inject_to_model(cls, *args, **kwargs): 18 | raise NotImplementedError() 19 | 20 | 21 | class FusedBaseAttentionModule(FusedBaseModule): 22 | @classmethod 23 | @abstractmethod 24 | def inject_to_model( 25 | cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, trainable=False, **kwargs 26 | ): 27 | raise NotImplementedError() 28 | 29 | @classmethod 30 | def warmup(cls, model, transpose=False, seqlen=2048): 31 | pass 32 | 33 | 34 | class FusedBaseMLPModule(FusedBaseModule): 35 | @classmethod 36 | @abstractmethod 37 | def inject_to_model(cls, model, use_triton=False, **kwargs): 38 | raise NotImplementedError() 39 | -------------------------------------------------------------------------------- /scripts/quantize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | 4 | from transformers import AutoTokenizer 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='serving arg pars') 9 | parser.add_argument('--float_model_path', '-f', type=str, default="/raid/chenhj/CodeLlama-7b-Python-hf") 10 | parser.add_argument('--quant_model_path', '-q', type=str, default="/raid/chenhj/test") 11 | 12 | parser.add_argument('--bits', '-b', type=int, default=4, help='bits num(recommended 4bit)') 13 | parser.add_argument('--group_size', '-g', type=int, default=128, help='group size(recommended 128)') 14 | 15 | args = parser.parse_args() 16 | 17 | logging.basicConfig( 18 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" 19 | ) 20 | 21 | tokenizer = AutoTokenizer.from_pretrained(args.float_model_path, use_fast=True) 22 | examples = [ 23 | tokenizer( 24 | "AI, or Artificial Intelligence, can perform a wide range of tasks, often mimicking human behaviors and capabilities but at a scale or speed that 
humans cannot match. " 25 | ) 26 | ] 27 | 28 | quantize_config = BaseQuantizeConfig( 29 | bits=args.bits, 30 | group_size=args.group_size, 31 | desc_act=False, 32 | ) 33 | 34 | model = AutoGPTQForCausalLM.from_pretrained(args.float_model_path, quantize_config) 35 | 36 | model.quantize(examples) 37 | model.save_quantized(args.quant_model_path) 38 | model.save_quantized(args.quant_model_path, use_safetensors=True) 39 | 40 | # load quantized model to the first GPU 41 | model = AutoGPTQForCausalLM.from_quantized(args.quant_model_path, device="cuda:0") 42 | 43 | # inference with model.generate 44 | print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0])) 45 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | # copy from https://github.com/LeeSinLiang/microGPT/blob/ed40cf9780dbeb180adfe94c227d4aa97e69250e/gpt.py 5 | def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0): 6 | """ 7 | 8 | Args: 9 | logits (torch.Tensorpe_): 2D tensor with shape (batch, vocab) 10 | top_k (int, optional): top_k. Defaults to 0. 11 | top_p (float, optional): top_p. Defaults to 0.0. 12 | 13 | Returns: 14 | torch.Tensor: a renormalized logits 15 | """ 16 | if top_k > 0: 17 | filter = torch.topk(logits, min(top_k, logits.size(-1)))[0] 18 | logits[logits < filter[:, [-1]]] = float('-inf') 19 | if top_p > 0.0: 20 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 21 | cumulative_probs = torch.cumsum( 22 | F.softmax(sorted_logits, dim=-1), dim=-1) 23 | filter = cumulative_probs > top_p 24 | filter[..., 1:] = filter[..., :-1].clone() 25 | filter[..., 0] = 0 26 | indices_to_remove = filter.scatter(1, sorted_indices, filter) 27 | logits[indices_to_remove] = float('-inf') 28 | return logits 29 | 30 | 31 | def norm_logits(logits : torch.Tensor, temperature : float, top_k : float, top_p : float) -> torch.Tensor: 32 | """ 33 | 34 | Args: 35 | logits (torch.Tensor): shape (1, vocab) 36 | temperature (float): temperature 37 | top_k (float): top_k 38 | top_p (float): top_p 39 | 40 | Returns: 41 | torch.Tensor: next token with shape as (batch, 1) 42 | """ 43 | assert logits.dim() == 2 44 | logits = logits / temperature 45 | logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p) 46 | probs = F.softmax(logits, dim=1) 47 | return probs 48 | 49 | 50 | def sample(probs : torch.Tensor, num_samples: int = 1): 51 | idx_next = torch.multinomial(probs, num_samples=num_samples) 52 | if (idx_next.item() == 0): 53 | raise RuntimeError 54 | return idx_next 55 | 56 | 57 | def max_fn(x): 58 | """ 59 | norm(max (x, 0)) 60 | """ 61 | x_max = torch.where(x > 0, x, torch.zeros_like(x)) 62 | x_max_sum = torch.sum(x_max, dim=1, keepdim=True) 63 | return x_max / x_max_sum 64 | -------------------------------------------------------------------------------- /scripts/serving.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request 2 | from transformers import AutoTokenizer 3 | import torch 4 | import logging 5 | import os 6 | import sys 7 | import json 8 | import time 9 | import argparse 10 | 11 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 12 | from src import speculative_sampling, LlamaGPTQ 13 | 14 | app = Flask(__name__) 15 | GLOBAL_SERVER = None 16 | 17 | # 
reference: https://github.com/feifeibear/LLMSpeculativeSampling 18 | class Server: 19 | def __init__(self, small_model, large_model, tokenizer_model, max_tokens, top_k, top_p) -> None: 20 | self._device = "cuda:0" 21 | self._small_model = LlamaGPTQ.from_quantized(small_model, device="cuda:0", use_triton=True, warmup_triton=False) 22 | self._large_model = LlamaGPTQ.from_quantized(large_model, device="cuda:0", use_triton=True, warmup_triton=False, inject_fused_attention=False) 23 | self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_model) 24 | 25 | self.num_tokens = max_tokens 26 | self.top_k = top_k 27 | self.top_p = top_p 28 | 29 | def process_request(self, request : str) -> torch.Tensor: 30 | input_str = request['prompt'] 31 | logging.info(f"receive request {input_str}") 32 | input_ids = self._tokenizer.encode(input_str, return_tensors='pt').to(self._device) 33 | output = speculative_sampling(input_ids, 34 | self._small_model, 35 | self._large_model, self.num_tokens, 36 | top_k = self.top_k, 37 | top_p = self.top_p) 38 | generated_text = self._tokenizer.decode(output[0], skip_special_tokens=True) 39 | return generated_text 40 | 41 | # Set up a route to listen for inference requests 42 | @app.route('/codebear', methods=['POST']) 43 | def codebear(): 44 | # Check the content type of the request 45 | # if request.headers['Content-Type'] != 'application/json': 46 | # return jsonify({'error': 'Invalid content type'}) 47 | 48 | # Get the request data 49 | request_data = request.get_json() 50 | print(request_data) 51 | 52 | # Perform inference 53 | start_time = time.time() 54 | result = GLOBAL_SERVER.process_request(request_data) 55 | end_time = time.time() 56 | 57 | lines = result.splitlines() 58 | indented_lines = [' ' + line for line in lines] 59 | # indented_lines[0] = ' '+indented_lines[0] 60 | indented_text = '\n'.join(indented_lines) 61 | 62 | max_len = 200 63 | 64 | output_data = { 65 | "response": indented_text, 66 | "number": max_len, 67 | "time": "{:.2f}".format(end_time - start_time), 68 | "tokensps": "{:.2f}".format(max_len / (end_time - start_time)), 69 | } 70 | output_json = json.dumps(output_data) 71 | 72 | print(indented_text) 73 | 74 | return output_json, 200 75 | 76 | # Return the inference results 77 | # return jsonify(result) 78 | 79 | if __name__ == '__main__': 80 | parser = argparse.ArgumentParser(description='serving arg pars') 81 | parser.add_argument('--small_model', '-s', type=str, default="/raid/chenhj/CodeLlama-7b-4bit") 82 | parser.add_argument('--large_model', '-l', type=str, default="/raid/chenhj/CodeLlama-34b-4bit") 83 | parser.add_argument('--tokenizer_model', '-t', type=str, default="/raid/chenhj/CodeLlama-7b-Python-hf") 84 | parser.add_argument('--max_tokens', '-M', type=int, default=200, help='Max tokens generated') 85 | parser.add_argument('--top_k', '-k', type=int, default=10, help='top_k') 86 | parser.add_argument('--top_p', '-p', type=float, default=0.9, help='top_p') 87 | args = parser.parse_args() 88 | 89 | GLOBAL_SERVER = Server( 90 | small_model=args.small_model, 91 | large_model=args.large_model, 92 | tokenizer_model=args.tokenizer_model, 93 | max_tokens=args.max_tokens, 94 | top_k=args.top_k, 95 | top_p=args.top_p 96 | ) 97 | # Start the Flask service 98 | app.run(host='0.0.0.0', port=5000) 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Codebear 2 | This repository combines **GPTQ 4-bit quantization** and 
**Speculative Decoding** to accelerate Large Language Model (LLM) inference for code completion tasks in **personal usage scenarios** (where GPU resources are limited, yet there is still a desire for the better quality of larger models at a reasonable speed). 3 | 4 | [GPTQ](https://arxiv.org/abs/2210.17323) is a one-shot weight quantization method based on approximate second-order information that is both highly accurate and highly efficient. [Speculative Decoding](https://arxiv.org/abs/2302.01318) is an innovative sampling strategy that uses a small approximation model to propose sequences of tokens, which are then verified by a larger model. 5 | 6 | By combining these two techniques, one can deploy multiple LLMs on a single GPU with limited HBM usage. While benefiting from the improved quality brought by larger models, it also accelerates inference speed to some extent. 7 | 8 | The following figures were measured on a single V100 (32GB) running [CodeLlama-34B](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) and [CodeLlama-7B](https://huggingface.co/codellama/CodeLlama-7b-Python-hf), with the Triton-based QuantLinear backend. 9 | 10 | | | 3 prefill + 200th decoding | 11 | | :----: | :----: | 12 | | Memory Usage (GB) | 27.7 | 13 | 14 | 15 | | 3 prefill + 200th decoding | CodeLlama 7B (FP16) | CodeLlama 7B (4Bit) | CodeLlama 34B (4Bit) | Speculative 7B+34B (4Bit) | 16 | | :----: | :----: |:----: |:----: |:----: | 17 | | Inference Speed (Tokens/sec) | 14.3 | 34.1 | 7.9 | 9.4 | 18 | 19 | ![alt text](images/result.png) 20 | 21 | 22 | ## Update 23 | | Date | Content | 24 | | :----: | :----: | 25 | | 2024-03-10 | fused MLP Triton kernel | 26 | 27 | ## Acknowledgement 28 | 29 | - Special thanks to [feifeibear](https://github.com/feifeibear) for releasing the implementation of speculative decoding in both Google's and DeepMind's versions ([LLMSpeculativeSampling](https://github.com/feifeibear/LLMSpeculativeSampling)). 30 | - Special thanks to the [AutoGPTQ team](https://github.com/AutoGPTQ/) for implementing the GPTQ algorithm and open-sourcing the code. 31 | 32 | ## Quick Tour 33 | ### Requirements 34 | ``` 35 | triton==2.1.0 36 | auto_gptq==0.7.0 37 | transformers==4.37.2 38 | ``` 39 | 40 | ### Step 1: Quantize 41 | Download the float models from the official repositories ([CodeLlama-7B](https://huggingface.co/codellama/) and [CodeLlama-34B](https://huggingface.co/codellama/CodeLlama-34b-Python-hf)), then quantize them. 42 | ```bash 43 | # quantize the 7b model 44 | ./make_quant.sh -f /PATH/TO/7B/FLOAT/MODEL -q /PATH/TO/7B/QUANT/MODEL 45 | # quantize the 34b model 46 | ./make_quant.sh -f /PATH/TO/34B/FLOAT/MODEL -q /PATH/TO/34B/QUANT/MODEL 47 | ``` 48 | Alternatively, you can download the 4-bit quantized models from my Hugging Face ([CodeLlama-7B-4bit](https://huggingface.co/guaguabear/codebear-7b-4bit) and [CodeLlama-34B-4bit](https://huggingface.co/guaguabear/codebear-34b-4bit)). 49 | 50 | The basic quantization config is bits = 4, group_size = 128 (can be changed in ./scripts/quantize.py). 51 | 52 | ### Step 2: Serving 53 | Start the server: 54 | ```bash 55 | ./start_server.sh -s /PATH/TO/7B/QUANT/MODEL -l /PATH/TO/34B/QUANT/MODEL -t /PATH/TO/7B/FLOAT/MODEL 56 | ``` 57 | Default sampling params are set to max_tokens = 200, top_k = 10, top_p = 0.9 (can be changed in ./scripts/serving.py). 
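The server exposes a single `/codebear` endpoint: it accepts a JSON body with a `prompt` field and returns a JSON object with `response`, `number`, `time`, and `tokensps` fields (see `scripts/serving.py`). Besides the curl command shown below, here is a minimal Python client sketch (assuming the server is reachable at `http://127.0.0.1:5000`):

```python
import requests

# Illustrative client for the /codebear endpoint started by ./start_server.sh;
# the request/response field names mirror scripts/serving.py.
resp = requests.post(
    "http://127.0.0.1:5000/codebear",
    json={"prompt": "def quicksort("},
    timeout=300,
)
data = resp.json()
print(data["response"])                                   # generated completion
print(f'{data["time"]} s, {data["tokensps"]} tokens/sec')  # timing reported by the server
```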
58 | 59 | Send request (**the model is specially trained for code completion with python**) 60 | ``` 61 | curl -X POST -H "Content-Type: application/json" -d '{"prompt": "def quicksort("}' http://127.0.0.1:5000/codebear 62 | ``` 63 | ![alt text](images/request.png) 64 | 65 | 66 | 67 | ## Future Plans 68 | 69 | | | Progress | 70 | | :----: | :----: | 71 | | fused_flash_attn_MHA triton implemention| todo | 72 | | fused_flash_attn_GQA triton implemention| todo | 73 | | INT8 KV cache| todo | 74 | 75 | 76 | ## References 77 | ``` 78 | @article{frantar-gptq, 79 | title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, 80 | author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh}, 81 | year={2022}, 82 | journal={arXiv preprint arXiv:2210.17323} 83 | } 84 | 85 | @inproceedings{leviathan2023fast, 86 | title={Fast inference from transformers via speculative decoding}, 87 | author={Leviathan, Yaniv and Kalman, Matan and Matias, Yossi}, 88 | booktitle={International Conference on Machine Learning}, 89 | pages={19274--19286}, 90 | year={2023}, 91 | organization={PMLR} 92 | } 93 | 94 | ``` -------------------------------------------------------------------------------- /src/speculative_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.kvcache_model import KVCacheModel 4 | from src.utils import sample, max_fn 5 | from src.globals import Decoder 6 | 7 | # reference: https://github.com/feifeibear/LLMSpeculativeSampling 8 | @torch.no_grad() 9 | def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module, 10 | max_len : int , gamma : int = 4, 11 | temperature : float = 1, top_k : int = 0, top_p : float = 0, verbose : bool = False, random_seed : int = None) -> torch.Tensor: 12 | """ 13 | Google version Speculative Sampling. 14 | https://arxiv.org/pdf/2211.17192.pdf 15 | 16 | Adapted with KV Cache Optimization. 17 | 18 | Args: 19 | x (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now. 20 | approx_model (torch.nn.Module): approx model, the small one 21 | target_model (torch.nn.Module): target model, the large one 22 | max_len (int): the max overall generated tokens number. 23 | gamma (int): $\gamma$, the token number small model guesses. 24 | temperature (float, optional): Defaults to 1. 25 | top_k (int, optional): Defaults to 0. 26 | top_p (float, optional): Defaults to 0. 
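verbose (bool, optional): if True, print per-token accept/resample information. Defaults to False.
random_seed (int, optional): manual seed applied before drawing the acceptance random number. Defaults to None.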
27 | 28 | Returns: 29 | torch.Tensor: generated tokens (batch, target_seqlen) 30 | """ 31 | seq_len = prefix.shape[1] 32 | T = seq_len + max_len 33 | 34 | assert prefix.shape[0] == 1, "input batch size must be 1" 35 | 36 | assert approx_model.device == target_model.device 37 | 38 | device = target_model.device 39 | 40 | approx_model_cache = KVCacheModel(approx_model, temperature, top_k, top_p) 41 | target_model_cache = KVCacheModel(target_model, temperature, top_k, top_p) 42 | 43 | resample_count = 0 44 | target_sample_count = 0 45 | accepted_count = 0 46 | 47 | while prefix.shape[1] < T: 48 | prefix_len = prefix.shape[1] 49 | 50 | x = approx_model_cache.generate(prefix, gamma) 51 | _ = target_model_cache.generate(x, 1) 52 | 53 | n = prefix_len + gamma - 1 54 | 55 | 56 | for i in range(gamma): 57 | if random_seed: 58 | torch.manual_seed(random_seed) 59 | r = torch.rand(1, device = device) 60 | j = x[:, prefix_len + i] 61 | 62 | if r > (target_model_cache._prob_history[:, prefix_len + i - 1, j]) / (approx_model_cache._prob_history[:, prefix_len + i - 1, j]): 63 | # reject 64 | n = prefix_len + i - 1 65 | break 66 | 67 | if verbose: 68 | print(f"approx guess accepted {j[0]}: \033[31m{Decoder().decode(torch.tensor([j]))}\033[0m") 69 | 70 | accepted_count += 1 71 | 72 | # print(f"n : {n}, i : {i}, prefix_len + gamma - 1: {prefix_len + gamma - 1}") 73 | assert n >= prefix_len - 1, f"n {n}, prefix_len {prefix_len}" 74 | prefix = x[:, :n + 1] 75 | 76 | approx_model_cache.rollback(n+1) 77 | 78 | assert approx_model_cache._prob_history.shape[-2] <= n + 1, f"approx_model prob list shape {approx_model_cache._prob_history.shape}, n {n}" 79 | 80 | if n < prefix_len + gamma - 1: 81 | # reject someone, sample from the pos n 82 | t = sample(max_fn(target_model_cache._prob_history[:, n, :] - approx_model_cache._prob_history[:, n, :])) 83 | if verbose: 84 | print(f"target resamples at position {n}: \033[34m{Decoder().decode(t)}\033[0m") 85 | resample_count += 1 86 | target_model_cache.rollback(n+1) 87 | else: 88 | # all approx model decoding accepted 89 | assert n == target_model_cache._prob_history.shape[1] - 1 90 | t = sample(target_model_cache._prob_history[:, -1, :]) 91 | if verbose: 92 | print(f"target samples {n}: \033[35m{Decoder().decode(t)}\033[0m") 93 | target_sample_count += 1 94 | target_model_cache.rollback(n+2) 95 | 96 | 97 | prefix = torch.cat((prefix, t), dim=1) 98 | 99 | if verbose: 100 | print(f"generated tokens numbers {prefix.shape[-1] - seq_len}, accepted_count {accepted_count}, target_sample_count {target_sample_count}, resample_count {resample_count}") 101 | return prefix 102 | -------------------------------------------------------------------------------- /src/kvcache_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.utils import norm_logits, sample 4 | 5 | # reference: https://github.com/feifeibear/LLMSpeculativeSampling 6 | def _debug_show_kvcache(past_key_values): 7 | if past_key_values is None: 8 | return 9 | for elem in past_key_values: 10 | k, v = elem 11 | print(f"kv cache: k shape {k.shape}, v shape {v.shape}") 12 | break 13 | 14 | class KVCacheModel(): 15 | def __init__(self, model : torch.nn.Module, temperature : float = 1, top_k : int = 0, top_p : float = 0) -> None: 16 | self._model = model 17 | self._past_key_values = None 18 | self._prob_history = None 19 | 20 | self._temperature = temperature 21 | self._top_k = top_k 22 | self._top_p = top_p 23 | 24 | def _forward_with_kvcache(self, input_ids : 
torch.Tensor, use_debug = True) -> torch.Tensor: 25 | if self._past_key_values is None: 26 | assert self._prob_history is None, f"{self._prob_history.shape}" 27 | # the first forward (prefill) returns the prompt's logits 28 | outputs = self._model(input_ids) 29 | self._prob_history = outputs.logits 30 | for i in range(self._prob_history.shape[-2]): 31 | self._prob_history[:, i, :] = norm_logits(self._prob_history[:, i, :], self._temperature, self._top_k, self._top_p) 32 | self._past_key_values = outputs.past_key_values 33 | last_q = self._prob_history[:, -1, :] 34 | else: 35 | # return the last token's logits 36 | cached_len = 0 37 | for kv in self._past_key_values: 38 | k, v = kv 39 | cached_len = k.shape[2] 40 | 41 | last_input_id = input_ids[:, cached_len:] 42 | if last_input_id.dim() == 1: 43 | last_input_id = torch.unsqueeze(last_input_id, 0) 44 | 45 | if use_debug: 46 | print(f"last_input_id shape {last_input_id.shape}") 47 | _debug_show_kvcache(self._past_key_values) 48 | 49 | outputs = self._model(last_input_id, past_key_values=self._past_key_values, use_cache=True) 50 | 51 | not_cached_q = outputs.logits 52 | if not_cached_q.dim() == 2: 53 | not_cached_q = torch.unsqueeze(not_cached_q, 0) 54 | 55 | for i in range(not_cached_q.shape[-2]): 56 | not_cached_q[:, i, :] = norm_logits(not_cached_q[:, i, :], self._temperature, self._top_k, self._top_p) 57 | 58 | self._prob_history = torch.cat([self._prob_history, not_cached_q], dim=1) 59 | 60 | last_q = not_cached_q[:, -1, :] 61 | self._past_key_values = outputs.past_key_values 62 | 63 | return last_q 64 | 65 | 66 | def _generate_with_kvcache(self, prefix : torch.Tensor, 67 | gamma : int, 68 | use_debug = False) -> torch.Tensor: 69 | """ forward the model gamma times 70 | 71 | Args: 72 | prefix (torch.Tensor): the prefix 73 | gamma (int): how many times approx guesses 74 | 75 | Returns: 76 | Torch.Tensor: prefix+generated tokens 77 | """ 78 | x = prefix 79 | 80 | for _ in range(gamma): 81 | q = self._forward_with_kvcache(x, use_debug) 82 | next_tok = sample(q) 83 | x = torch.cat((x, next_tok), dim=1) 84 | return x 85 | 86 | @torch.no_grad() 87 | def generate(self, input : torch.Tensor, gamma : int) -> torch.Tensor: 88 | output = self._generate_with_kvcache(input, gamma) 89 | return output 90 | 91 | @torch.no_grad() 92 | def rollback(self, end_pos : int): 93 | past_key_values_trimmed = [] 94 | assert self._past_key_values 95 | for kv in self._past_key_values: 96 | k, v = kv 97 | # NOTE() the indexing is specific for bloom. 
This won't work for other models 98 | # For example llama k, v should be (batch, num_head, seq_len, hidden_dim) 99 | 100 | # k, v (batch, head, seq, hidden_dim) 101 | k = k[:, :, :end_pos, :] 102 | v = v[:, :, :end_pos, :] 103 | kv_trimmed = (k, v) 104 | past_key_values_trimmed.append(kv_trimmed) 105 | 106 | self._past_key_values = past_key_values_trimmed 107 | self._prob_history = self._prob_history[:, :end_pos, :] 108 | 109 | -------------------------------------------------------------------------------- /src/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import transformers 4 | import accelerate 5 | 6 | from typing import Union 7 | from src.model.QuantLinear import QuantLinear 8 | 9 | def get_device(obj: Union[torch.Tensor, nn.Module]): 10 | if isinstance(obj, torch.Tensor): 11 | return obj.device 12 | return next(obj.parameters()).device 13 | 14 | def find_layers(module, layers=None, name=""): 15 | if not layers: 16 | layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear] 17 | for layer in layers: 18 | if isinstance(module, layer): 19 | return {name: module} 20 | res = {} 21 | for name1, child in module.named_children(): 22 | res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) 23 | return res 24 | 25 | def get_module_by_name_suffix(model, module_name: str): 26 | for name, module in model.named_modules(): 27 | if name.endswith(module_name): 28 | return module 29 | 30 | def simple_dispatch_model(model, device_map): 31 | from accelerate.hooks import AlignDevicesHook, add_hook_to_module 32 | 33 | if "" in device_map: 34 | d = device_map[""] 35 | model = model.to(torch.device(d)) 36 | model.hf_device_map = device_map 37 | return model 38 | 39 | tied_params = accelerate.utils.modeling.find_tied_parameters(model) 40 | if set(device_map.values()) == {"cpu"} or set(device_map.values()) == { 41 | "cpu", 42 | "disk", 43 | }: 44 | main_device = "cpu" 45 | else: 46 | main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] 47 | 48 | cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"] 49 | prev_hook = None 50 | for idx, (n, d) in enumerate(cpu_offload_group): 51 | m = get_module_by_name_suffix(model, n) 52 | _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook) 53 | # set first cpu offload module's prev_module_hook to the last cpu offload module's hook 54 | if len(cpu_offload_group) > 1: 55 | get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook 56 | 57 | for n, d in device_map.items(): 58 | m = get_module_by_name_suffix(model, n) 59 | if d != "cpu": 60 | d = torch.device(d) 61 | hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) 62 | add_hook_to_module(m, hook) 63 | accelerate.utils.modeling.retie_parameters(model, tied_params) 64 | model.hf_device_map = device_map 65 | 66 | return model 67 | 68 | def make_quant( 69 | module, 70 | names, 71 | bits, 72 | group_size, 73 | name="", 74 | use_triton: bool = True, 75 | use_cuda_fp16: bool = True, 76 | desc_act: bool = False, 77 | ): 78 | for attr in dir(module): 79 | tmp = getattr(module, attr) 80 | name1 = name + "." 
+ attr if name != "" else attr 81 | if name1 in names: 82 | ori_layer_device = get_device(getattr(module, attr)) 83 | delattr(module, attr) 84 | if isinstance(tmp, nn.Linear): 85 | in_features = tmp.in_features 86 | out_features = tmp.out_features 87 | elif isinstance(tmp, nn.Conv2d): 88 | in_features = tmp.in_channels 89 | out_features = tmp.out_channels 90 | elif isinstance(tmp, transformers.pytorch_utils.Conv1D): 91 | in_features = tmp.weight.shape[0] 92 | out_features = tmp.weight.shape[1] 93 | if (not (desc_act) or group_size == -1) and not use_triton: 94 | new_layer = QuantLinear( 95 | bits, 96 | group_size, 97 | in_features, 98 | out_features, 99 | True, 100 | use_cuda_fp16=use_cuda_fp16, 101 | weight_dtype=tmp.weight.dtype, 102 | ) 103 | else: 104 | new_layer = QuantLinear( 105 | bits, 106 | group_size, 107 | in_features, 108 | out_features, 109 | True, 110 | weight_dtype=tmp.weight.dtype, 111 | ) 112 | new_layer.device = ori_layer_device 113 | setattr(module, attr, new_layer.to(ori_layer_device)) 114 | for name1, child in module.named_children(): 115 | make_quant( 116 | child, 117 | names, 118 | bits, 119 | group_size, 120 | name + "." + name1 if name != "" else name1, 121 | use_triton=use_triton, 122 | use_cuda_fp16=use_cuda_fp16, 123 | desc_act=desc_act, 124 | ) 125 | -------------------------------------------------------------------------------- /src/model/QuantLinear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import transformers 6 | 7 | from logging import getLogger 8 | from torch.cuda.amp import custom_fwd 9 | from src.model.kernels import quant_matmul_248, transpose_quant_matmul_248, quant_matmul_inference_only_248 10 | 11 | logger = getLogger(__name__) 12 | 13 | class QuantLinearInferenceOnlyFunction(torch.autograd.Function): 14 | @staticmethod 15 | @custom_fwd(cast_inputs=torch.float16) 16 | def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): 17 | output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq) 18 | return output 19 | 20 | class TritonModuleMixin: 21 | @classmethod 22 | def warmup(cls, model, transpose=False, seqlen=2048): 23 | pass 24 | 25 | class QuantLinear(nn.Module, TritonModuleMixin): 26 | QUANT_TYPE = "triton" 27 | 28 | def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs): 29 | super().__init__() 30 | if bits not in [2, 4, 8]: 31 | raise NotImplementedError("Only 2,4,8 bits are supported.") 32 | if infeatures % 32 != 0 or outfeatures % 32 != 0: 33 | raise NotImplementedError("in_feature and out_feature must be divisible by 32.") 34 | self.infeatures = infeatures 35 | self.outfeatures = outfeatures 36 | self.bits = bits 37 | self.group_size = group_size if group_size != -1 else infeatures 38 | self.maxq = 2**self.bits - 1 39 | 40 | self.register_buffer( 41 | "qweight", 42 | torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), 43 | ) 44 | self.register_buffer( 45 | "qzeros", 46 | torch.zeros( 47 | ( 48 | math.ceil(infeatures / self.group_size), 49 | outfeatures // 32 * self.bits, 50 | ), 51 | dtype=torch.int32, 52 | ), 53 | ) 54 | self.register_buffer( 55 | "scales", 56 | torch.zeros( 57 | (math.ceil(infeatures / self.group_size), outfeatures), 58 | dtype=torch.float16, 59 | ), 60 | ) 61 | self.register_buffer( 62 | "g_idx", 63 | torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), 64 | ) 65 | if bias: 66 | 
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) 67 | else: 68 | self.bias = None 69 | 70 | def post_init(self): 71 | pass 72 | 73 | def pack(self, linear, scales, zeros, g_idx=None): 74 | W = linear.weight.data.clone() 75 | if isinstance(linear, nn.Conv2d): 76 | W = W.flatten(1) 77 | if isinstance(linear, transformers.pytorch_utils.Conv1D): 78 | W = W.t() 79 | 80 | self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx 81 | 82 | scales = scales.t().contiguous() 83 | zeros = zeros.t().contiguous() 84 | scale_zeros = zeros * scales 85 | self.scales = scales.clone().half() 86 | if linear.bias is not None: 87 | self.bias = linear.bias.clone().half() 88 | 89 | intweight = [] 90 | for idx in range(self.infeatures): 91 | intweight.append( 92 | torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[ 93 | :, None 94 | ] 95 | ) 96 | intweight = torch.cat(intweight, dim=1) 97 | intweight = intweight.t().contiguous() 98 | intweight = intweight.numpy().astype(np.uint32) 99 | 100 | i = 0 101 | row = 0 102 | qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) 103 | while row < qweight.shape[0]: 104 | if self.bits in [2, 4, 8]: 105 | for j in range(i, i + (32 // self.bits)): 106 | qweight[row] |= intweight[j] << (self.bits * (j - i)) 107 | i += 32 // self.bits 108 | row += 1 109 | else: 110 | raise NotImplementedError("Only 2,4,8 bits are supported.") 111 | 112 | qweight = qweight.astype(np.int32) 113 | self.qweight = torch.from_numpy(qweight) 114 | 115 | zeros -= 1 116 | zeros = zeros.numpy().astype(np.uint32) 117 | qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) 118 | i = 0 119 | col = 0 120 | while col < qzeros.shape[1]: 121 | if self.bits in [2, 4, 8]: 122 | for j in range(i, i + (32 // self.bits)): 123 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) 124 | i += 32 // self.bits 125 | col += 1 126 | else: 127 | raise NotImplementedError("Only 2,4,8 bits are supported.") 128 | 129 | qzeros = qzeros.astype(np.int32) 130 | self.qzeros = torch.from_numpy(qzeros) 131 | 132 | def forward(self, x): 133 | out_shape = x.shape[:-1] + (self.outfeatures,) 134 | quant_linear_fn = QuantLinearInferenceOnlyFunction 135 | out = quant_linear_fn.apply( 136 | x.reshape(-1, x.shape[-1]), 137 | self.qweight, 138 | self.scales, 139 | self.qzeros, 140 | self.g_idx, 141 | self.bits, 142 | self.maxq, 143 | ) 144 | out = out.half().reshape(out_shape) 145 | out = out + self.bias if self.bias is not None else out 146 | return out 147 | 148 | @classmethod 149 | def warmup(cls, model, transpose=False, seqlen=2048): 150 | """ 151 | Pre-tunes the quantized kernel 152 | """ 153 | from tqdm import tqdm 154 | 155 | kn_values = {} 156 | 157 | for _, m in model.named_modules(): 158 | if not isinstance(m, cls): 159 | continue 160 | 161 | k = m.infeatures 162 | n = m.outfeatures 163 | 164 | if (k, n) not in kn_values: 165 | kn_values[(k, n)] = ( 166 | m.qweight, 167 | m.scales, 168 | m.qzeros, 169 | m.g_idx, 170 | m.bits, 171 | m.maxq, 172 | ) 173 | 174 | logger.info(f"Found {len(kn_values)} unique KN Linear values.") 175 | logger.info("Warming up autotune cache ...") 176 | with torch.no_grad(): 177 | for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)): 178 | m = 2**m 179 | for (k, n), ( 180 | qweight, 181 | scales, 182 | qzeros, 183 | g_idx, 184 | bits, 185 | maxq, 186 | ) in kn_values.items(): 187 | if transpose: 188 | a = torch.randn(m, k, dtype=torch.float16, 
device=model.device) 189 | quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) 190 | a = torch.randn(m, n, dtype=torch.float16, device=model.device) 191 | transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) 192 | else: 193 | a = torch.randn(m, k, dtype=torch.float16, device=model.device) 194 | quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq) 195 | del kn_values 196 | 197 | -------------------------------------------------------------------------------- /src/model/custom_autotune.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import math 3 | import time 4 | from typing import Dict 5 | 6 | import triton 7 | 8 | 9 | # code based https://github.com/fpgaminer/GPTQ-triton 10 | """ 11 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 12 | """ 13 | 14 | 15 | class CustomizedTritonAutoTuner(triton.KernelInterface): 16 | def __init__( 17 | self, 18 | fn, 19 | arg_names, 20 | configs, 21 | key, 22 | reset_to_zero, 23 | prune_configs_by: Dict = None, 24 | nearest_power_of_two: bool = False, 25 | ): 26 | if not configs: 27 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)] 28 | else: 29 | self.configs = configs 30 | self.key_idx = [arg_names.index(k) for k in key] 31 | self.nearest_power_of_two = nearest_power_of_two 32 | self.cache = {} 33 | # hook to reset all required tensor to zeros before relaunching a kernel 34 | self.hook = lambda args: 0 35 | if reset_to_zero is not None: 36 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero] 37 | 38 | def _hook(args): 39 | for i in self.reset_idx: 40 | args[i].zero_() 41 | 42 | self.hook = _hook 43 | self.arg_names = arg_names 44 | # prune configs 45 | if prune_configs_by: 46 | perf_model, top_k = ( 47 | prune_configs_by["perf_model"], 48 | prune_configs_by["top_k"], 49 | ) 50 | if "early_config_prune" in prune_configs_by: 51 | early_config_prune = prune_configs_by["early_config_prune"] 52 | else: 53 | perf_model, top_k, early_config_prune = None, None, None 54 | self.perf_model, self.configs_top_k = perf_model, top_k 55 | self.early_config_prune = early_config_prune 56 | self.fn = fn 57 | 58 | def _bench(self, *args, config, **meta): 59 | # check for conflicts, i.e. meta-parameters both provided 60 | # as kwargs and by the autotuner 61 | conflicts = meta.keys() & config.kwargs.keys() 62 | if conflicts: 63 | raise ValueError( 64 | f"Conflicting meta-parameters: {', '.join(conflicts)}." 65 | " Make sure that you don't re-define auto-tuned symbols." 
66 | ) 67 | # augment meta-parameters with tunable ones 68 | current = dict(meta, **config.kwargs) 69 | 70 | def kernel_call(): 71 | if config.pre_hook: 72 | config.pre_hook(self.nargs) 73 | self.hook(args) 74 | self.fn.run( 75 | *args, 76 | num_warps=config.num_warps, 77 | num_stages=config.num_stages, 78 | **current, 79 | ) 80 | 81 | try: 82 | # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses 83 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default 84 | return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40) 85 | except triton.OutOfResources: 86 | return (float("inf"), float("inf"), float("inf")) 87 | 88 | def run(self, *args, **kwargs): 89 | self.nargs = dict(zip(self.arg_names, args)) 90 | if len(self.configs) > 1: 91 | key = tuple(args[i] for i in self.key_idx) 92 | 93 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two 94 | # In my testing this gives decent results, and greatly reduces the amount of tuning required 95 | if self.nearest_power_of_two: 96 | key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) 97 | 98 | if key not in self.cache: 99 | # prune configs 100 | pruned_configs = self.prune_configs(kwargs) 101 | bench_start = time.time() 102 | timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} 103 | bench_end = time.time() 104 | self.bench_time = bench_end - bench_start 105 | self.cache[key] = builtins.min(timings, key=timings.get) 106 | self.hook(args) 107 | self.configs_timings = timings 108 | config = self.cache[key] 109 | else: 110 | config = self.configs[0] 111 | self.best_config = config 112 | if config.pre_hook is not None: 113 | config.pre_hook(self.nargs) 114 | return self.fn.run( 115 | *args, 116 | num_warps=config.num_warps, 117 | num_stages=config.num_stages, 118 | **kwargs, 119 | **config.kwargs, 120 | ) 121 | 122 | def prune_configs(self, kwargs): 123 | pruned_configs = self.configs 124 | if self.early_config_prune: 125 | pruned_configs = self.early_config_prune(self.configs, self.nargs) 126 | if self.perf_model: 127 | top_k = self.configs_top_k 128 | if isinstance(top_k, float) and top_k <= 1.0: 129 | top_k = int(len(self.configs) * top_k) 130 | if len(pruned_configs) > top_k: 131 | est_timing = { 132 | config: self.perf_model( 133 | **self.nargs, 134 | **kwargs, 135 | **config.kwargs, 136 | num_stages=config.num_stages, 137 | num_warps=config.num_warps, 138 | ) 139 | for config in pruned_configs 140 | } 141 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] 142 | return pruned_configs 143 | 144 | def warmup(self, *args, **kwargs): 145 | self.nargs = dict(zip(self.arg_names, args)) 146 | for config in self.prune_configs(kwargs): 147 | self.fn.warmup( 148 | *args, 149 | num_warps=config.num_warps, 150 | num_stages=config.num_stages, 151 | **kwargs, 152 | **config.kwargs, 153 | ) 154 | self.nargs = None 155 | 156 | 157 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): 158 | def decorator(fn): 159 | return CustomizedTritonAutoTuner( 160 | fn, 161 | fn.arg_names, 162 | configs, 163 | key, 164 | reset_to_zero, 165 | prune_configs_by, 166 | nearest_power_of_two, 167 | ) 168 | 169 | return decorator 170 | 171 | 172 | def matmul248_kernel_config_pruner(configs, nargs): 173 | """ 174 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
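For example, with M=24 the pruner computes m = max(2**ceil(log2(24)), 16) = 32, so each config's BLOCK_SIZE_M is clamped to min(32, BLOCK_SIZE_M), and configs that become identical after this clamping are emitted only once.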
175 | """ 176 | m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) 177 | n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) 178 | k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) 179 | 180 | used = set() 181 | for config in configs: 182 | block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) 183 | block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) 184 | block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) 185 | group_size_m = config.kwargs["GROUP_SIZE_M"] 186 | 187 | if ( 188 | block_size_m, 189 | block_size_n, 190 | block_size_k, 191 | group_size_m, 192 | config.num_stages, 193 | config.num_warps, 194 | ) in used: 195 | continue 196 | 197 | used.add( 198 | ( 199 | block_size_m, 200 | block_size_n, 201 | block_size_k, 202 | group_size_m, 203 | config.num_stages, 204 | config.num_warps, 205 | ) 206 | ) 207 | yield triton.Config( 208 | { 209 | "BLOCK_SIZE_M": block_size_m, 210 | "BLOCK_SIZE_N": block_size_n, 211 | "BLOCK_SIZE_K": block_size_k, 212 | "GROUP_SIZE_M": group_size_m, 213 | }, 214 | num_stages=config.num_stages, 215 | num_warps=config.num_warps, 216 | ) 217 | 218 | 219 | __all__ = ["autotune"] 220 | -------------------------------------------------------------------------------- /src/model/fused_llama_attn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from transformers.models.llama.modeling_llama import ( 7 | LlamaAttention, 8 | apply_rotary_pos_emb, 9 | ) 10 | 11 | from packaging.version import parse as parse_version 12 | from src.model.fused_base import FusedBaseAttentionModule 13 | from src.model.QuantLinear import QuantLinear 14 | 15 | def compare_pytorch_version(version: str = "v2.0.0", op: str = "eq"): 16 | assert op in ["eq", "lt", "le", "gt", "ge"] 17 | 18 | from torch import __version__ 19 | 20 | return getattr(parse_version(__version__), f"__{op}__")(parse_version(version)) 21 | 22 | class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule): 23 | """Multi-headed attention from 'Attention Is All You Need' paper""" 24 | 25 | def __init__( 26 | self, 27 | hidden_size, 28 | num_heads, 29 | qkv_proj, 30 | o_proj, 31 | rotary_emb, 32 | layer_idx, 33 | ): 34 | super().__init__() 35 | self.hidden_size = hidden_size 36 | self.num_heads = num_heads 37 | self.head_dim = hidden_size // num_heads 38 | self.layer_idx = layer_idx 39 | 40 | if self.head_dim * num_heads != self.hidden_size: 41 | raise ValueError( 42 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 43 | f" and `num_heads`: {num_heads})." 
44 | ) 45 | self.qkv_proj = qkv_proj 46 | self.o_proj = o_proj 47 | self.rotary_emb = rotary_emb 48 | 49 | def _shape(self, tensor, seq_len, bsz): 50 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 51 | 52 | def forward( 53 | self, 54 | hidden_states, 55 | past_key_value=None, 56 | attention_mask=None, 57 | position_ids=None, 58 | output_attentions=False, 59 | use_cache=False, 60 | **kwargs, 61 | ): 62 | """Input shape: Batch x Time x Channel""" 63 | 64 | bsz, q_len, _ = hidden_states.size() 65 | 66 | qkv_states = self.qkv_proj(hidden_states) 67 | query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2) 68 | 69 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 70 | key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 71 | value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 72 | 73 | kv_seq_len = key_states.shape[-2] 74 | if past_key_value is not None: 75 | if self.layer_idx is None: 76 | raise ValueError( 77 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 78 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 79 | "with a layer index. Please open an issue in AutoGPTQ if you hit this." 80 | ) 81 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 82 | 83 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 84 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 85 | # [bsz, nh, t, hd] 86 | 87 | if past_key_value is not None: 88 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 89 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 90 | 91 | if use_cache: 92 | # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor 93 | # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. 
94 | query_states = query_states.contiguous() 95 | key_states = key_states.contiguous() 96 | value_states = value_states.contiguous() 97 | 98 | if compare_pytorch_version("v2.0.0", op="ge"): 99 | attn_output = F.scaled_dot_product_attention( 100 | query_states, 101 | key_states, 102 | value_states, 103 | attn_mask=attention_mask, 104 | is_causal=attention_mask is None and q_len > 1, 105 | ) 106 | attn_weights = None 107 | else: 108 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 109 | 110 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 111 | raise ValueError( 112 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 113 | f" {attn_weights.size()}" 114 | ) 115 | 116 | if attention_mask is not None: 117 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 118 | raise ValueError( 119 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 120 | ) 121 | attn_weights = attn_weights + attention_mask 122 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) 123 | 124 | # upcast attention to fp32 125 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 126 | attn_output = torch.matmul(attn_weights, value_states) 127 | 128 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 129 | raise ValueError( 130 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 131 | f" {attn_output.size()}" 132 | ) 133 | 134 | attn_output = attn_output.transpose(1, 2) 135 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 136 | 137 | attn_output = self.o_proj(attn_output) 138 | 139 | if not output_attentions: 140 | attn_weights = None 141 | 142 | return attn_output, attn_weights, past_key_value 143 | 144 | @classmethod 145 | def inject_to_model( 146 | cls, 147 | model, 148 | use_triton=False, 149 | group_size=-1, 150 | use_cuda_fp16=True, 151 | desc_act=False, 152 | trainable=False, 153 | bits: int = 4, 154 | disable_exllama=True, 155 | disable_exllamav2=False, 156 | **kwargs, 157 | ): 158 | """ 159 | Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. 160 | """ 161 | for name, m in model.named_modules(): 162 | if not isinstance(m, LlamaAttention): 163 | continue 164 | 165 | q_proj = m.q_proj 166 | k_proj = m.k_proj 167 | v_proj = m.v_proj 168 | 169 | qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) 170 | qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) 171 | scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) 172 | 173 | if QuantLinear.QUANT_TYPE == "exllama": 174 | if desc_act: 175 | # TODO: support it. The issue lies maybe in the line: 176 | # int groups = qzeros.size(0); 177 | # in exllama_ext.cpp 178 | raise ValueError( 179 | "Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True." 
180 | ) 181 | else: 182 | g_idx = None 183 | else: 184 | g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) 185 | 186 | bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None 187 | 188 | qlinear_args = ( 189 | q_proj.bits, 190 | q_proj.group_size, 191 | q_proj.infeatures, 192 | q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, 193 | True if q_proj.bias is not None else False, 194 | ) 195 | qlinear_kwargs = {"trainable": trainable} 196 | if (not desc_act or group_size == -1) and not use_triton: 197 | qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16 198 | qlinear_kwargs["weight_dtype"] = q_proj.scales.dtype 199 | 200 | qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs) 201 | qkv_layer.qweight = qweights 202 | qkv_layer.qzeros = qzeros 203 | qkv_layer.scales = scales 204 | qkv_layer.g_idx = g_idx 205 | qkv_layer.bias = bias 206 | 207 | # Introduced in Transformers 4.36 208 | layer_idx = None 209 | if hasattr(m, "layer_idx"): 210 | layer_idx = m.layer_idx 211 | attn = cls( 212 | m.hidden_size, 213 | m.num_heads, 214 | qkv_layer, 215 | m.o_proj, 216 | m.rotary_emb, 217 | layer_idx=layer_idx, 218 | ) 219 | 220 | if "." in name: 221 | parent_name = name.rsplit(".", 1)[0] 222 | child_name = name[len(parent_name) + 1 :] 223 | parent = model.get_submodule(parent_name) 224 | else: 225 | parent_name = "" 226 | parent = model 227 | child_name = name 228 | 229 | setattr(parent, child_name, attn) 230 | 231 | 232 | __all__ = ["FusedLlamaAttentionForQuantizedModel"] 233 | 234 | -------------------------------------------------------------------------------- /src/model/fused_llama_mlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | from logging import getLogger 3 | from abc import abstractmethod 4 | 5 | import torch 6 | from transformers.models.llama.modeling_llama import LlamaMLP 7 | 8 | from src.model.fused_base import FusedBaseMLPModule, FusedBaseModule 9 | 10 | try: 11 | import triton 12 | TRITON_AVAILABLE = True 13 | except ImportError: 14 | TRITON_AVAILABLE = False 15 | 16 | 17 | logger = getLogger(__name__) 18 | 19 | if TRITON_AVAILABLE: 20 | import triton 21 | import triton.language as tl 22 | 23 | from . 
import custom_autotune 24 | from src.model.kernels import silu 25 | 26 | @custom_autotune.autotune( 27 | configs=[ 28 | triton.Config( 29 | { 30 | "BLOCK_SIZE_M": 256, 31 | "BLOCK_SIZE_N": 64, 32 | "BLOCK_SIZE_K": 32, 33 | "GROUP_SIZE_M": 8, 34 | }, 35 | num_stages=4, 36 | num_warps=4, 37 | ), 38 | triton.Config( 39 | { 40 | "BLOCK_SIZE_M": 64, 41 | "BLOCK_SIZE_N": 256, 42 | "BLOCK_SIZE_K": 32, 43 | "GROUP_SIZE_M": 8, 44 | }, 45 | num_stages=4, 46 | num_warps=4, 47 | ), 48 | triton.Config( 49 | { 50 | "BLOCK_SIZE_M": 128, 51 | "BLOCK_SIZE_N": 128, 52 | "BLOCK_SIZE_K": 32, 53 | "GROUP_SIZE_M": 8, 54 | }, 55 | num_stages=4, 56 | num_warps=4, 57 | ), 58 | triton.Config( 59 | { 60 | "BLOCK_SIZE_M": 128, 61 | "BLOCK_SIZE_N": 64, 62 | "BLOCK_SIZE_K": 32, 63 | "GROUP_SIZE_M": 8, 64 | }, 65 | num_stages=4, 66 | num_warps=4, 67 | ), 68 | triton.Config( 69 | { 70 | "BLOCK_SIZE_M": 64, 71 | "BLOCK_SIZE_N": 128, 72 | "BLOCK_SIZE_K": 32, 73 | "GROUP_SIZE_M": 8, 74 | }, 75 | num_stages=4, 76 | num_warps=4, 77 | ), 78 | triton.Config( 79 | { 80 | "BLOCK_SIZE_M": 128, 81 | "BLOCK_SIZE_N": 32, 82 | "BLOCK_SIZE_K": 32, 83 | "GROUP_SIZE_M": 8, 84 | }, 85 | num_stages=4, 86 | num_warps=4, 87 | ), # 3090 88 | triton.Config( 89 | { 90 | "BLOCK_SIZE_M": 128, 91 | "BLOCK_SIZE_N": 16, 92 | "BLOCK_SIZE_K": 32, 93 | "GROUP_SIZE_M": 8, 94 | }, 95 | num_stages=4, 96 | num_warps=4, 97 | ), # 3090 98 | triton.Config( 99 | { 100 | "BLOCK_SIZE_M": 32, 101 | "BLOCK_SIZE_N": 32, 102 | "BLOCK_SIZE_K": 128, 103 | "GROUP_SIZE_M": 8, 104 | }, 105 | num_stages=2, 106 | num_warps=4, 107 | ), # 3090 108 | triton.Config( 109 | { 110 | "BLOCK_SIZE_M": 64, 111 | "BLOCK_SIZE_N": 16, 112 | "BLOCK_SIZE_K": 64, 113 | "GROUP_SIZE_M": 8, 114 | }, 115 | num_stages=4, 116 | num_warps=4, 117 | ), # 3090 118 | triton.Config( 119 | { 120 | "BLOCK_SIZE_M": 64, 121 | "BLOCK_SIZE_N": 32, 122 | "BLOCK_SIZE_K": 64, 123 | "GROUP_SIZE_M": 8, 124 | }, 125 | num_stages=4, 126 | num_warps=4, 127 | ), # 3090 128 | ], 129 | key=["M", "N", "K"], 130 | nearest_power_of_two=True, 131 | prune_configs_by={ 132 | "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, 133 | "perf_model": None, 134 | "top_k": None, 135 | }, 136 | ) 137 | @triton.jit 138 | def quant_fused_matmul_248_kernel( 139 | a_ptr, 140 | c_ptr, 141 | b1_ptr, 142 | scales1_ptr, 143 | zeros1_ptr, 144 | g1_ptr, 145 | b2_ptr, 146 | scales2_ptr, 147 | zeros2_ptr, 148 | g2_ptr, 149 | M, 150 | N, 151 | K, 152 | bits, 153 | maxq, 154 | stride_am, 155 | stride_ak, 156 | stride_bk, 157 | stride_bn, 158 | stride_cm, 159 | stride_cn, 160 | stride_scales, 161 | stride_zeros, 162 | BLOCK_SIZE_M: tl.constexpr, 163 | BLOCK_SIZE_N: tl.constexpr, 164 | BLOCK_SIZE_K: tl.constexpr, 165 | GROUP_SIZE_M: tl.constexpr, 166 | ): 167 | """ 168 | Computes: C = silu(A * B1) * (A * B2) 169 | A is of shape (M, K) float16 170 | B is of shape (K//8, N) int32 171 | C is of shape (M, N) float16 172 | scales is of shape (1, N) float16 173 | zeros is of shape (1, N//8) int32 174 | """ 175 | infearure_per_bits = 32 // bits 176 | 177 | pid = tl.program_id(axis=0) 178 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 179 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 180 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 181 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 182 | group_id = pid // num_pid_in_group 183 | first_pid_m = group_id * GROUP_SIZE_M 184 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 185 | pid_m = first_pid_m + (pid % group_size_m) 186 | pid_n = (pid % num_pid_in_group) // group_size_m 187 | 188 | offs_am = pid_m * 
BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 189 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 190 | offs_k = tl.arange(0, BLOCK_SIZE_K) 191 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 192 | a_mask = offs_am[:, None] < M 193 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 194 | b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) 195 | b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) 196 | g1_ptrs = g1_ptr + offs_k 197 | g2_ptrs = g2_ptr + offs_k 198 | # shifter is used to extract the N bits of each element in the 32-bit word from B 199 | scales1_ptrs = scales1_ptr + offs_bn[None, :] 200 | scales2_ptrs = scales2_ptr + offs_bn[None, :] 201 | zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits) 202 | zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits) 203 | 204 | shifter = (offs_k % infearure_per_bits) * bits 205 | zeros_shifter = (offs_bn % infearure_per_bits) * bits 206 | accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 207 | accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 208 | for k in range(0, num_pid_k): 209 | g1_idx = tl.load(g1_ptrs) 210 | g2_idx = tl.load(g2_ptrs) 211 | 212 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 213 | scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 214 | scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales) 215 | 216 | zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 217 | zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq 218 | zeros1 = zeros1 + 1 219 | 220 | zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 221 | zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq 222 | zeros2 = zeros2 + 1 223 | 224 | a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 225 | b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 226 | b2 = tl.load(b2_ptrs) 227 | 228 | # Now we need to unpack b (which is N-bit values) into 32-bit values 229 | b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values 230 | b1 = (b1 - zeros1) * scales1 # Scale and shift 231 | accumulator1 += tl.dot(a, b1) 232 | 233 | b2 = (b2 >> shifter[:, None]) & maxq 234 | b2 = (b2 - zeros2) * scales2 235 | accumulator2 += tl.dot(a, b2) 236 | 237 | a_ptrs += BLOCK_SIZE_K 238 | b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 239 | b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 240 | g1_ptrs += BLOCK_SIZE_K 241 | g2_ptrs += BLOCK_SIZE_K 242 | 243 | accumulator1 = silu(accumulator1) 244 | c = accumulator1 * accumulator2 245 | c = c.to(tl.float16) 246 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 247 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) 248 | tl.store(c_ptrs, c, mask=c_mask) 249 | 250 | else: 251 | quant_fused_matmul_248_kernel = None 252 | 253 | 254 | class FusedBaseMLPModule(FusedBaseModule): 255 | @classmethod 256 | @abstractmethod 257 | def inject_to_model(cls, model, use_triton=False, **kwargs): 258 | raise NotImplementedError() 259 | 260 | 261 | class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule): 262 | def __init__( 263 | self, 264 | gate_proj, 265 | down_proj, 266 | up_proj, 267 | ): 268 | 
super().__init__() 269 | 270 | self.infeatures = gate_proj.infeatures 271 | self.intermediate_size = gate_proj.outfeatures 272 | self.outfeatures = down_proj.outfeatures 273 | self.bits = gate_proj.bits 274 | self.maxq = gate_proj.maxq 275 | 276 | self.gate_proj = gate_proj 277 | self.up_proj = up_proj 278 | self.down_proj = down_proj 279 | 280 | def forward(self, x): 281 | return self.down_proj(self.triton_llama_mlp(x)) 282 | 283 | def triton_llama_mlp(self, x): 284 | with torch.cuda.device(x.device): 285 | out_shape = x.shape[:-1] + (self.intermediate_size,) 286 | x = x.reshape(-1, x.shape[-1]) 287 | M, K = x.shape 288 | N = self.intermediate_size 289 | c = torch.empty((M, N), device=x.device, dtype=torch.float16) 290 | grid = lambda META: ( # noqa: E731 291 | triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), 292 | ) 293 | quant_fused_matmul_248_kernel[grid]( 294 | x, 295 | c, 296 | self.gate_proj.qweight, 297 | self.gate_proj.scales, 298 | self.gate_proj.qzeros, 299 | self.gate_proj.g_idx, 300 | self.up_proj.qweight, 301 | self.up_proj.scales, 302 | self.up_proj.qzeros, 303 | self.up_proj.g_idx, 304 | M, 305 | N, 306 | K, 307 | self.bits, 308 | self.maxq, 309 | x.stride(0), 310 | x.stride(1), 311 | self.gate_proj.qweight.stride(0), 312 | self.gate_proj.qweight.stride(1), 313 | c.stride(0), 314 | c.stride(1), 315 | self.gate_proj.scales.stride(0), 316 | self.gate_proj.qzeros.stride(0), 317 | ) 318 | c = c.reshape(out_shape) 319 | return c 320 | 321 | @classmethod 322 | def inject_to_model(cls, model, use_triton=False, **kwargs): 323 | if not use_triton: 324 | logger.warning( 325 | f"Skipping module injection for {cls.__name__} as currently not supported with use_triton=False." 326 | ) 327 | return 328 | elif not TRITON_AVAILABLE: 329 | logger.warning( 330 | f"Skipping module injection for {cls.__name__} as Triton is not available. Please check your installation." 331 | ) 332 | return 333 | 334 | for name, m in model.named_modules(): 335 | if not isinstance(m, LlamaMLP): 336 | continue 337 | 338 | # import pdb;pdb.set_trace() 339 | mlp = cls(m.gate_proj, m.down_proj, m.up_proj) 340 | 341 | if "." in name: 342 | parent_name = name.rsplit(".", 1)[0] 343 | child_name = name[len(parent_name) + 1 :] 344 | parent = model.get_submodule(parent_name) 345 | else: 346 | parent_name = "" 347 | parent = model 348 | child_name = name 349 | 350 | setattr(parent, child_name, mlp) 351 | 352 | @classmethod 353 | def warmup(cls, model, transpose=False, seqlen=2048): 354 | from tqdm import tqdm 355 | 356 | kn_values = {} 357 | 358 | for _, m in model.named_modules(): 359 | if not isinstance(m, cls): 360 | continue 361 | 362 | k = m.infeatures 363 | n = m.intermediate_size 364 | 365 | if (k, n) not in kn_values: 366 | kn_values[(k, n)] = m 367 | 368 | logger.info(f"Found {len(kn_values)} unique fused mlp KN values.") 369 | logger.info("Warming up autotune cache ...") 370 | with torch.no_grad(): 371 | for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)): 372 | m = 2**m 373 | for (k, n), (modules) in kn_values.items(): 374 | a = torch.randn(m, k, dtype=torch.float16, device=model.device) 375 | modules.triton_llama_mlp(a) 376 | del kn_values 377 | 378 | 379 | __all__ = ["FusedLlamaMLPForQuantizedModel"] 380 | -------------------------------------------------------------------------------- /src/model/kernels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | from . 
import custom_autotune 6 | 7 | 8 | @custom_autotune.autotune( 9 | configs=[ 10 | triton.Config( 11 | { 12 | "BLOCK_SIZE_M": 64, 13 | "BLOCK_SIZE_N": 256, 14 | "BLOCK_SIZE_K": 32, 15 | "GROUP_SIZE_M": 8, 16 | }, 17 | num_stages=4, 18 | num_warps=4, 19 | ), 20 | triton.Config( 21 | { 22 | "BLOCK_SIZE_M": 128, 23 | "BLOCK_SIZE_N": 128, 24 | "BLOCK_SIZE_K": 32, 25 | "GROUP_SIZE_M": 8, 26 | }, 27 | num_stages=4, 28 | num_warps=4, 29 | ), 30 | triton.Config( 31 | { 32 | "BLOCK_SIZE_M": 64, 33 | "BLOCK_SIZE_N": 128, 34 | "BLOCK_SIZE_K": 32, 35 | "GROUP_SIZE_M": 8, 36 | }, 37 | num_stages=4, 38 | num_warps=4, 39 | ), 40 | triton.Config( 41 | { 42 | "BLOCK_SIZE_M": 128, 43 | "BLOCK_SIZE_N": 32, 44 | "BLOCK_SIZE_K": 32, 45 | "GROUP_SIZE_M": 8, 46 | }, 47 | num_stages=4, 48 | num_warps=4, 49 | ), 50 | triton.Config( 51 | { 52 | "BLOCK_SIZE_M": 64, 53 | "BLOCK_SIZE_N": 64, 54 | "BLOCK_SIZE_K": 32, 55 | "GROUP_SIZE_M": 8, 56 | }, 57 | num_stages=4, 58 | num_warps=4, 59 | ), 60 | triton.Config( 61 | { 62 | "BLOCK_SIZE_M": 64, 63 | "BLOCK_SIZE_N": 128, 64 | "BLOCK_SIZE_K": 32, 65 | "GROUP_SIZE_M": 8, 66 | }, 67 | num_stages=2, 68 | num_warps=8, 69 | ), 70 | ], 71 | key=["M", "N", "K"], 72 | nearest_power_of_two=True, 73 | prune_configs_by={ 74 | "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, 75 | "perf_model": None, 76 | "top_k": None, 77 | }, 78 | ) 79 | @triton.jit 80 | def quant_matmul_248_kernel( 81 | a_ptr, 82 | b_ptr, 83 | c_ptr, 84 | scales_ptr, 85 | zeros_ptr, 86 | g_ptr, 87 | M, 88 | N, 89 | K, 90 | bits, 91 | maxq, 92 | stride_am, 93 | stride_ak, 94 | stride_bk, 95 | stride_bn, 96 | stride_cm, 97 | stride_cn, 98 | stride_scales, 99 | stride_zeros, 100 | BLOCK_SIZE_M: tl.constexpr, 101 | BLOCK_SIZE_N: tl.constexpr, 102 | BLOCK_SIZE_K: tl.constexpr, 103 | GROUP_SIZE_M: tl.constexpr, 104 | ): 105 | """ 106 | Compute the matrix multiplication C = A x B. 
107 | A is of shape (M, K) float16 108 | B is of shape (K//8, N) int32 109 | C is of shape (M, N) float16 110 | scales is of shape (G, N) float16 111 | zeros is of shape (G, N) float16 112 | g_ptr is of shape (K) int32 113 | """ 114 | infearure_per_bits = 32 // bits 115 | 116 | pid = tl.program_id(axis=0) 117 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 118 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 119 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 120 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 121 | group_id = pid // num_pid_in_group 122 | first_pid_m = group_id * GROUP_SIZE_M 123 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 124 | pid_m = first_pid_m + (pid % group_size_m) 125 | pid_n = (pid % num_pid_in_group) // group_size_m 126 | 127 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 128 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 129 | offs_k = tl.arange(0, BLOCK_SIZE_K) 130 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 131 | a_mask = offs_am[:, None] < M 132 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 133 | b_ptrs = b_ptr + ( 134 | (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn 135 | ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 136 | g_ptrs = g_ptr + offs_k 137 | # shifter is used to extract the N bits of each element in the 32-bit word from B 138 | scales_ptrs = scales_ptr + offs_bn[None, :] 139 | zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) 140 | 141 | shifter = (offs_k % infearure_per_bits) * bits 142 | zeros_shifter = (offs_bn % infearure_per_bits) * bits 143 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 144 | 145 | for k in range(0, num_pid_k): 146 | g_idx = tl.load(g_ptrs) 147 | 148 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 149 | scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 150 | zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 151 | 152 | zeros = (zeros >> zeros_shifter[None, :]) & maxq 153 | zeros = zeros + 1 154 | 155 | a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 156 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 157 | 158 | # Now we need to unpack b (which is N-bit values) into 32-bit values 159 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values 160 | b = (b - zeros) * scales # Scale and shift 161 | 162 | accumulator += tl.dot(a, b) 163 | a_ptrs += BLOCK_SIZE_K 164 | b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 165 | g_ptrs += BLOCK_SIZE_K 166 | 167 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 168 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) 169 | tl.store(c_ptrs, accumulator, mask=c_mask) 170 | 171 | 172 | @custom_autotune.autotune( 173 | configs=[ 174 | triton.Config( 175 | { 176 | "BLOCK_SIZE_M": 64, 177 | "BLOCK_SIZE_N": 32, 178 | "BLOCK_SIZE_K": 256, 179 | "GROUP_SIZE_M": 8, 180 | }, 181 | num_stages=4, 182 | num_warps=4, 183 | ), 184 | triton.Config( 185 | { 186 | "BLOCK_SIZE_M": 128, 187 | "BLOCK_SIZE_N": 32, 188 | "BLOCK_SIZE_K": 128, 189 | "GROUP_SIZE_M": 8, 190 | }, 191 | num_stages=4, 192 | num_warps=4, 193 | ), 194 | triton.Config( 195 | { 196 | "BLOCK_SIZE_M": 64, 197 | "BLOCK_SIZE_N": 32, 198 | "BLOCK_SIZE_K": 128, 199 | "GROUP_SIZE_M": 8, 200 | }, 201 | num_stages=4, 202 | num_warps=4, 203 | ), 204 | 
triton.Config( 205 | { 206 | "BLOCK_SIZE_M": 128, 207 | "BLOCK_SIZE_N": 32, 208 | "BLOCK_SIZE_K": 32, 209 | "GROUP_SIZE_M": 8, 210 | }, 211 | num_stages=4, 212 | num_warps=4, 213 | ), 214 | triton.Config( 215 | { 216 | "BLOCK_SIZE_M": 64, 217 | "BLOCK_SIZE_N": 32, 218 | "BLOCK_SIZE_K": 64, 219 | "GROUP_SIZE_M": 8, 220 | }, 221 | num_stages=4, 222 | num_warps=4, 223 | ), 224 | triton.Config( 225 | { 226 | "BLOCK_SIZE_M": 64, 227 | "BLOCK_SIZE_N": 32, 228 | "BLOCK_SIZE_K": 128, 229 | "GROUP_SIZE_M": 8, 230 | }, 231 | num_stages=2, 232 | num_warps=8, 233 | ), 234 | ], 235 | key=["M", "N", "K"], 236 | nearest_power_of_two=True, 237 | ) 238 | @triton.jit 239 | def transpose_quant_matmul_248_kernel( 240 | a_ptr, 241 | b_ptr, 242 | c_ptr, 243 | scales_ptr, 244 | zeros_ptr, 245 | g_ptr, 246 | M, 247 | N, 248 | K, 249 | bits, 250 | maxq, 251 | stride_am, 252 | stride_ak, 253 | stride_bk, 254 | stride_bn, 255 | stride_cm, 256 | stride_cn, 257 | stride_scales, 258 | stride_zeros, 259 | BLOCK_SIZE_M: tl.constexpr, 260 | BLOCK_SIZE_N: tl.constexpr, 261 | BLOCK_SIZE_K: tl.constexpr, 262 | GROUP_SIZE_M: tl.constexpr, 263 | ): 264 | """ 265 | Compute the matrix multiplication C = A x B. 266 | A is of shape (M, N) float16 267 | B is of shape (K//8, N) int32 268 | C is of shape (M, K) float16 269 | scales is of shape (G, N) float16 270 | zeros is of shape (G, N) float16 271 | g_ptr is of shape (K) int32 272 | """ 273 | infearure_per_bits = 32 // bits 274 | 275 | pid = tl.program_id(axis=0) 276 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 277 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 278 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 279 | num_pid_in_group = GROUP_SIZE_M * num_pid_k 280 | group_id = pid // num_pid_in_group 281 | first_pid_m = group_id * GROUP_SIZE_M 282 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 283 | pid_m = first_pid_m + (pid % group_size_m) 284 | pid_k = (pid % num_pid_in_group) // group_size_m 285 | 286 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 287 | offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 288 | offs_n = tl.arange(0, BLOCK_SIZE_N) 289 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) 290 | a_mask = offs_am[:, None] < M 291 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 292 | b_ptrs = b_ptr + ( 293 | (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn 294 | ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 295 | g_ptrs = g_ptr + offs_bk 296 | g_idx = tl.load(g_ptrs) 297 | 298 | # shifter is used to extract the N bits of each element in the 32-bit word from B 299 | scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales 300 | zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros 301 | 302 | shifter = (offs_bk % infearure_per_bits) * bits 303 | zeros_shifter = (offs_n % infearure_per_bits) * bits 304 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 305 | 306 | for k in range(0, num_pid_n): 307 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 308 | scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 309 | zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 310 | 311 | zeros = (zeros >> zeros_shifter[None, :]) & maxq 312 | zeros = zeros + 1 313 | 314 | a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N) 315 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 316 | 317 | # 
Now we need to unpack b (which is N-bit values) into 32-bit values 318 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values 319 | b = (b - zeros) * scales # Scale and shift 320 | b = tl.trans(b) 321 | 322 | accumulator += tl.dot(a, b) 323 | a_ptrs += BLOCK_SIZE_N 324 | b_ptrs += BLOCK_SIZE_N 325 | scales_ptrs += BLOCK_SIZE_N 326 | zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits 327 | 328 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] 329 | c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) 330 | tl.store(c_ptrs, accumulator, mask=c_mask) 331 | 332 | @triton.jit 333 | def silu(x): 334 | return x * tl.sigmoid(x) 335 | 336 | def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq): 337 | with torch.cuda.device(input.device): 338 | output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype) 339 | grid = lambda META: ( # noqa: E731 340 | triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), 341 | ) 342 | quant_matmul_248_kernel[grid]( 343 | input, 344 | qweight, 345 | output, 346 | scales.to(input.dtype), 347 | qzeros, 348 | g_idx, 349 | input.shape[0], 350 | qweight.shape[1], 351 | input.shape[1], 352 | bits, 353 | maxq, 354 | input.stride(0), 355 | input.stride(1), 356 | qweight.stride(0), 357 | qweight.stride(1), 358 | output.stride(0), 359 | output.stride(1), 360 | scales.stride(0), 361 | qzeros.stride(0), 362 | ) 363 | return output 364 | 365 | def transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq): 366 | with torch.cuda.device(input.device): 367 | output_dim = (qweight.shape[0] * 32) // bits 368 | output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype) 369 | grid = lambda META: ( # noqa: E731 370 | triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(output_dim, META["BLOCK_SIZE_K"]), 371 | ) 372 | transpose_quant_matmul_248_kernel[grid]( 373 | input, 374 | qweight, 375 | output, 376 | scales.to(input.dtype), 377 | qzeros, 378 | g_idx, 379 | input.shape[0], 380 | qweight.shape[1], 381 | output_dim, 382 | bits, 383 | maxq, 384 | input.stride(0), 385 | input.stride(1), 386 | qweight.stride(0), 387 | qweight.stride(1), 388 | output.stride(0), 389 | output.stride(1), 390 | scales.stride(0), 391 | qzeros.stride(0), 392 | ) 393 | return output 394 | 395 | def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq): 396 | with torch.cuda.device(input.device): 397 | output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) 398 | grid = lambda META: ( # noqa: E731 399 | triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), 400 | ) 401 | quant_matmul_248_kernel[grid]( 402 | input, 403 | qweight, 404 | output, 405 | scales, 406 | qzeros, 407 | g_idx, 408 | input.shape[0], 409 | qweight.shape[1], 410 | input.shape[1], 411 | bits, 412 | maxq, 413 | input.stride(0), 414 | input.stride(1), 415 | qweight.stride(0), 416 | qweight.stride(1), 417 | output.stride(0), 418 | output.stride(1), 419 | scales.stride(0), 420 | qzeros.stride(0), 421 | ) 422 | return output 423 | 424 | -------------------------------------------------------------------------------- /src/model/LlamaGPTQ.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from dataclasses import dataclass, field, fields 5 | from 
os.path import isfile, join 6 | from typing import Dict, Optional, Union 7 | 8 | import accelerate 9 | import torch 10 | import torch.nn as nn 11 | import transformers 12 | 13 | from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel 14 | from transformers.modeling_utils import no_init_weights 15 | from transformers.utils.generic import ContextManagers 16 | from transformers.utils.hub import ( 17 | PushToHubMixin, 18 | cached_file 19 | ) 20 | 21 | from src.model.utils import find_layers, simple_dispatch_model, make_quant 22 | from src.model.fused_llama_attn import FusedLlamaAttentionForQuantizedModel 23 | from src.model.fused_llama_mlp import FusedLlamaMLPForQuantizedModel 24 | 25 | logger = logging.getLogger(__name__) 26 | handler = logging.StreamHandler() 27 | formatter = logging.Formatter("%(levelname)s - %(message)s") 28 | handler.setFormatter(formatter) 29 | logger.addHandler(handler) 30 | logger.setLevel(logging.INFO) 31 | 32 | SYNONYMS = { 33 | "w_bit": "bits", 34 | "q_group_size": "group_size", 35 | } 36 | 37 | @dataclass 38 | class BaseQuantizeConfig(PushToHubMixin): 39 | bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]}) 40 | group_size: int = field(default=-1) 41 | damp_percent: float = field(default=0.01) 42 | desc_act: bool = field(default=True) 43 | static_groups: bool = field(default=False) 44 | sym: bool = field(default=True) 45 | true_sequential: bool = field(default=True) 46 | is_marlin_format: bool = field(default=False) 47 | model_name_or_path: Optional[str] = field(default=None) 48 | model_file_base_name: Optional[str] = field(default=None) 49 | awq_gemm_checkpoint: Optional[bool] = field(default=False) 50 | 51 | def __post_init__(self): 52 | fields_info = fields(self) 53 | 54 | if self.bits not in fields_info[0].metadata["choices"]: 55 | raise ValueError(f"quantization is only supported for {fields_info[0].metadata['choices']} bits.") 56 | if self.group_size != -1 and self.group_size <= 0: 57 | raise ValueError("unless equal to -1, group_size must be greater than 0.") 58 | if not (0 < self.damp_percent < 1): 59 | raise ValueError("damp_percent must be between 0 and 1.") 60 | 61 | def save_pretrained(self, save_dir: str, **kwargs): 62 | with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f: 63 | json.dump(self.to_dict(), f, indent=2) 64 | 65 | @classmethod 66 | def from_pretrained(cls, save_dir: str, **kwargs): 67 | # Parameters related to loading from Hugging Face Hub 68 | cache_dir = kwargs.pop("cache_dir", None) 69 | force_download = kwargs.pop("force_download", False) 70 | resume_download = kwargs.pop("resume_download", False) 71 | proxies = kwargs.pop("proxies", None) 72 | local_files_only = kwargs.pop("local_files_only", False) 73 | use_auth_token = kwargs.pop("use_auth_token", None) 74 | revision = kwargs.pop("revision", None) 75 | subfolder = kwargs.pop("subfolder", None) 76 | commit_hash = kwargs.pop("_commit_hash", None) 77 | 78 | transformers_config = False 79 | for quantize_config_filename in [ 80 | "quantize_config.json", 81 | "quant_config.json", 82 | "config.json", 83 | ]: 84 | if os.path.isdir(save_dir): # Local 85 | resolved_config_file = join(save_dir, quantize_config_filename) 86 | else: # Remote 87 | resolved_config_file = cached_file( 88 | save_dir, 89 | quantize_config_filename, 90 | cache_dir=cache_dir, 91 | force_download=force_download, 92 | resume_download=resume_download, 93 | proxies=proxies, 94 | use_auth_token=use_auth_token, 95 | revision=revision, 96 | local_files_only=local_files_only, 97 | 
subfolder=subfolder, 98 | _raise_exceptions_for_missing_entries=False, 99 | _raise_exceptions_for_connection_errors=False, 100 | _commit_hash=commit_hash, 101 | ) 102 | if resolved_config_file is not None: 103 | if quantize_config_filename == "config.json": 104 | transformers_config = True 105 | break 106 | 107 | if resolved_config_file is None: 108 | raise ValueError( 109 | "No quantize_config.json, quant_config.json or config.json file was found in the model repository." 110 | ) 111 | 112 | field_names = [field.name for field in fields(cls)] 113 | with open(resolved_config_file, "r", encoding="utf-8") as f: 114 | args_from_json = json.load(f) 115 | 116 | if transformers_config: 117 | args_from_json = args_from_json["quantization_config"] 118 | 119 | filtered_args = {"awq_gemm_checkpoint": False} 120 | for key, val in args_from_json.items(): 121 | if key == "version" and val == "GEMM": 122 | filtered_args["awq_gemm_checkpoint"] = True 123 | elif key in field_names: 124 | filtered_args[key] = val 125 | elif key in SYNONYMS and SYNONYMS[key] in field_names: 126 | filtered_args[SYNONYMS[key]] = val 127 | else: 128 | logger.warning(f"ignoring unknown parameter in {quantize_config_filename}: {key}.") 129 | 130 | if filtered_args["awq_gemm_checkpoint"]: 131 | # AWQ does not reorder the rows. 132 | filtered_args["desc_act"] = False 133 | 134 | if "sym" not in args_from_json: 135 | logger.warning( 136 | f"The quantization configuration {quantize_config_filename} does not contain an entry `sym` (symmetric quantization). This may result in silent errors." 137 | ) 138 | 139 | return cls(**filtered_args) 140 | 141 | def to_dict(self): 142 | return { 143 | "bits": self.bits, 144 | "group_size": self.group_size, 145 | "damp_percent": self.damp_percent, 146 | "desc_act": self.desc_act, 147 | "static_groups": self.static_groups, 148 | "sym": self.sym, 149 | "true_sequential": self.true_sequential, 150 | "model_name_or_path": self.model_name_or_path, 151 | "model_file_base_name": self.model_file_base_name, 152 | "is_marlin_format": self.is_marlin_format, 153 | "quant_method": "gptq", 154 | } 155 | 156 | class LlamaGPTQ(nn.Module, PushToHubMixin): 157 | layer_type = "LlamaDecoderLayer" 158 | layers_block_name = "model.layers" 159 | outside_layer_modules = ["model.embed_tokens", "model.norm"] 160 | inside_layer_modules = [ 161 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 162 | ["self_attn.o_proj"], 163 | ["mlp.up_proj", "mlp.gate_proj"], 164 | ["mlp.down_proj"], 165 | ] 166 | 167 | fused_attn_module_type = FusedLlamaAttentionForQuantizedModel 168 | fused_mlp_module_type = FusedLlamaMLPForQuantizedModel 169 | lm_head_name: str = "lm_head" 170 | 171 | def __init__( 172 | self, 173 | model: PreTrainedModel, 174 | quantized: bool, 175 | quantize_config: BaseQuantizeConfig, 176 | is_triton_backend: bool = False, 177 | injected_fused_attention: bool = False, 178 | injected_fused_mlp: bool = False, 179 | ): 180 | super().__init__() 181 | 182 | self.model = model 183 | self.model_type = self.model.config.model_type 184 | self._quantized = quantized 185 | self.quantize_config = quantize_config 186 | self.config = self.model.config 187 | 188 | self.is_triton_backend = is_triton_backend 189 | self.injected_fused_attention = injected_fused_attention 190 | self.injected_fused_mlp = injected_fused_mlp 191 | 192 | @property 193 | def quantized(self): 194 | return self._quantized 195 | 196 | @property 197 | def hf_device_map(self): 198 | return getattr(self.model, "hf_device_map", None) 199 | 200 | 
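# Illustrative usage sketch (the checkpoint path and generation arguments are placeholders, and input_ids is assumed to come from the model's tokenizer). `from_quantized` expects a directory containing a *.safetensors checkpoint plus its quantize_config.json:
#
#     model = LlamaGPTQ.from_quantized(
#         "/path/to/quantized-model-dir",
#         device="cuda:0",
#         use_triton=True,
#         inject_fused_attention=True,
#         inject_fused_mlp=True,
#     )
#     output_ids = model.generate(input_ids=input_ids, max_new_tokens=128)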
@property 201 | def device(self): 202 | if not self.hf_device_map: 203 | return self.model.device 204 | else: 205 | device = [d for d in self.hf_device_map.values() if d not in {"disk"}][0] 206 | return torch.device(device) 207 | 208 | def to(self, device: Union[str, torch.device]): 209 | self.model.to(device) 210 | return self 211 | 212 | def forward(self, *args, **kwargs): 213 | return self.model(*args, **kwargs) 214 | 215 | def generate(self, **kwargs): 216 | """shortcut for model.generate""" 217 | with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type): 218 | return self.model.generate(**kwargs) 219 | 220 | @classmethod 221 | def from_quantized( 222 | cls, 223 | model_name_or_path: Optional[str], 224 | device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, 225 | max_memory: Optional[dict] = None, 226 | device: Optional[Union[str, int]] = None, 227 | use_triton: bool = True, 228 | torch_dtype: Optional[torch.dtype] = None, 229 | inject_fused_attention: bool = True, 230 | inject_fused_mlp: bool = True, 231 | use_cuda_fp16: bool = True, 232 | quantize_config: Optional[BaseQuantizeConfig] = None, 233 | model_basename: Optional[str] = None, 234 | trust_remote_code: bool = False, 235 | warmup_triton: bool = False, 236 | **kwargs, 237 | ): 238 | # == step1: prepare configs and file names == # 239 | config = AutoConfig.from_pretrained( 240 | model_name_or_path, 241 | trust_remote_code=trust_remote_code 242 | ) 243 | 244 | if config.model_type not in ['llama']: 245 | raise TypeError(f"{config.model_type} isn't supported yet.") 246 | 247 | if quantize_config is None: 248 | quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **kwargs) 249 | 250 | if model_basename is None: 251 | if quantize_config.model_file_base_name: 252 | possible_model_basenames = [quantize_config.model_file_base_name] 253 | else: 254 | possible_model_basenames = [ 255 | f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g", 256 | "model", 257 | ] 258 | else: 259 | possible_model_basenames = [model_basename] 260 | 261 | quantize_config.model_name_or_path = model_name_or_path 262 | 263 | extensions = [".safetensors",] 264 | model_name_or_path = str(model_name_or_path) 265 | 266 | resolved_archive_file = None 267 | true_model_basename = None 268 | searched_files = [] 269 | 270 | for ext in extensions: 271 | for possible_model_basename in possible_model_basenames: 272 | model_save_name = join(model_name_or_path, possible_model_basename) 273 | searched_files.append(possible_model_basename + ext) 274 | if isfile(model_save_name + ext): 275 | resolved_archive_file = model_save_name + ext 276 | true_model_basename = possible_model_basename 277 | break 278 | 279 | quantize_config.model_file_base_name = true_model_basename 280 | if resolved_archive_file is None: 281 | raise FileNotFoundError( 282 | f"Could not find a model in {model_name_or_path} with a name in {', '.join(searched_files)}. Please specify the argument model_basename to use a custom file name." 
283 | ) 284 | 285 | model_save_name = resolved_archive_file 286 | 287 | # == step2: convert model to gptq-model (replace Linear with QuantLinear) == # 288 | def skip(*args, **kwargs): 289 | pass 290 | 291 | torch_dtype = torch.float16 292 | 293 | torch.nn.init.kaiming_uniform_ = skip 294 | torch.nn.init.uniform_ = skip 295 | torch.nn.init.normal_ = skip 296 | 297 | transformers.modeling_utils._init_weights = False 298 | 299 | init_contexts = [no_init_weights()] 300 | with ContextManagers(init_contexts): 301 | model = AutoModelForCausalLM.from_config( 302 | config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype 303 | ) 304 | 305 | layers = find_layers(model) 306 | ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules 307 | for name in list(layers.keys()): 308 | if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all( 309 | not name.endswith(ignore_layer) 310 | for sublist in cls.inside_layer_modules 311 | for ignore_layer in sublist 312 | ): 313 | logger.info(f"The layer {name} is not quantized.") 314 | del layers[name] 315 | 316 | make_quant( 317 | model, 318 | layers, 319 | quantize_config.bits, 320 | quantize_config.group_size, 321 | use_triton=use_triton, 322 | use_cuda_fp16=use_cuda_fp16, 323 | desc_act=quantize_config.desc_act, 324 | ) 325 | model.tie_weights() 326 | 327 | # == step3: load checkpoint and dispatch == # 328 | if isinstance(device_map, str) and device_map not in [ 329 | "auto", 330 | "balanced", 331 | "balanced_low_0", 332 | "sequential", 333 | ]: 334 | raise ValueError( 335 | "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " 336 | "'sequential'." 337 | ) 338 | 339 | device = torch.device(device) 340 | if not max_memory and not device_map: 341 | device_map = {"": device.index if device.type == "cuda" else device.type} 342 | 343 | 344 | accelerate.utils.modeling.load_checkpoint_in_model( 345 | model, 346 | dtype=torch_dtype, 347 | checkpoint=model_save_name, 348 | device_map=device_map, 349 | offload_state_dict=True, 350 | offload_buffers=True, 351 | ) 352 | 353 | # TODO: Why are we using this custom function and not dispatch_model? 
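# Note: simple_dispatch_model comes from src/model/utils.py; it is assumed to place each submodule directly on the device named in device_map rather than going through accelerate.dispatch_model's full offloading machinery.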
354 | model = simple_dispatch_model(model, device_map) 355 | 356 | # == step4: set seqlen == # 357 | model_config = model.config.to_dict() 358 | seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] 359 | if any(k in model_config for k in seq_len_keys): 360 | for key in seq_len_keys: 361 | if key in model_config: 362 | model.seqlen = model_config[key] 363 | break 364 | else: 365 | logger.warning("could not determine the model's sequence length from its config; defaulting to 4096.") 366 | model.seqlen = 4096 367 | 368 | # == step5: (optional) inject optimized module == # 369 | if inject_fused_attention: 370 | if cls.fused_attn_module_type is None: 371 | inject_fused_attention = False 372 | logger.warning(f"{cls.__name__} has no fused attention module yet; skipping fused attention injection.") 373 | else: 374 | cls.fused_attn_module_type.inject_to_model( 375 | model, 376 | use_triton=use_triton, 377 | group_size=quantize_config.group_size, 378 | use_cuda_fp16=use_cuda_fp16, 379 | desc_act=quantize_config.desc_act, 380 | trainable=False, 381 | bits=quantize_config.bits, 382 | disable_exllama=True, 383 | disable_exllamav2=True, 384 | ) 385 | if inject_fused_mlp: 386 | if cls.fused_mlp_module_type is None: 387 | inject_fused_mlp = False 388 | logger.warning(f"{cls.__name__} has no fused MLP module yet; skipping fused MLP injection.") 389 | else: 390 | cls.fused_mlp_module_type.inject_to_model(model, use_triton=use_triton) 391 | 392 | torch.cuda.empty_cache() 393 | model.eval() 394 | 395 | # == step6: (optional) warmup triton == # 396 | if use_triton and warmup_triton: 397 | from src.model.QuantLinear import QuantLinear 398 | QuantLinear.warmup(model, seqlen=model.seqlen) 399 | 400 | if inject_fused_mlp and cls.fused_mlp_module_type is not None: 401 | cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen) 402 | 403 | # import pdb; pdb.set_trace() 404 | return cls( 405 | model, 406 | True, 407 | quantize_config, 408 | is_triton_backend=use_triton, 409 | injected_fused_attention=inject_fused_attention, 410 | injected_fused_mlp=inject_fused_mlp and use_triton, 411 | ) 412 | 413 | # def warmup_triton(self, enabled: bool = True): 414 | # if not enabled: 415 | # return 416 | # if not TRITON_AVAILABLE: 417 | # logger.warning("triton is not available, skip warmup stage directly.") 418 | # return 419 | 420 | # from ..nn_modules.qlinear.qlinear_triton import QuantLinear 421 | 422 | # QuantLinear.warmup(self.model, seqlen=self.model.seqlen) 423 | 424 | # if self.fused_mlp_module_type is not None: 425 | # self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen) 426 | 427 | # def enable_trainable_mode(self, enabled: bool = True): 428 | # if not self.is_triton_backend and enabled: 429 | # raise NotImplementedError("For now, trainable mode only supports triton backend.") 430 | # for n, m in self.model.named_modules(): 431 | # if hasattr(m, "trainable"): 432 | # setattr(m, "trainable", enabled) 433 | 434 | # def disable_trainable_mode(self): 435 | # self.enable_trainable_mode(enabled=False) 436 | 437 | 438 | # def __getattr__(self, item): 439 | # try: 440 | # return super().__getattr__(item) 441 | # except Exception: 442 | # return getattr(self.model, item) 443 | 444 | 445 | __all__ = ["LlamaGPTQ",] 446 | --------------------------------------------------------------------------------