├── .gitignore ├── LICENSE ├── README.md ├── config ├── dataset2maxlen.json ├── dataset2prompt.json ├── model2maxlen.json └── model2path.json ├── docs └── long_bench.md ├── eval_long_bench.py ├── example.py ├── img ├── algo.png └── quant_scheme.png ├── long_context_example.py ├── mem_spd_test.py ├── metrics.py ├── models ├── __init__.py ├── llama_kivi.py ├── mistral_kivi.py └── utils_quant.py ├── passkey_examples.jsonl ├── pred_long_bench.py ├── pyproject.toml ├── quant ├── __init__.py ├── csrc │ ├── gemv_cuda.cu │ ├── gemv_cuda.h │ ├── gemv_cuda_backup.cu │ └── pybind.cpp ├── gemv.py ├── matmul.py ├── new_pack.py ├── qmodule.py ├── setup.py ├── test.py └── timeit_v2.py ├── requirements.txt ├── scripts └── long_test.sh ├── utils ├── data.py ├── metrics.py └── process_args.py └── vis └── vis.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | outputs 3 | logs 4 | cached_models 5 | *.egg-info 6 | build 7 | *.so 8 | out_* 9 | *.sh 10 | *.jsonl 11 | *.pt 12 | notebook 13 | dataset 14 | output_samples 15 | *.pdf 16 | *.json 17 | *.jsonl 18 | third_party 19 | speed_logs 20 | pred 21 | pred_e -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 jiayi yuan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache 2 | 3 | Implementation of [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) 4 | 5 | ## Updates 6 | - [2025.01.18]:We add KIVI implementation with GQA and compatiable with transformers 4.43. Now it supports LLama3 family. Please reinstall KIVI. 7 | - [2024.06.07]:🎉 KIVI largely inspires the [HuggingFace Transformers KV Cache quantization](https://huggingface.co/docs/transformers/main/en/kv_cache) 8 | - [2024.06.06]:(Beta) We extensively optimize the codebase in [branch develop](https://github.com/jy-yuan/KIVI/tree/develop) to reduce the latency of KIVI. Note that **you need to reinstall our CUDA implementation** under the ```quant``` folder. We will release a blog soon about the detailed optimization. 
9 | - [2024.05.01]:🎉 KIVI has been accepted by ICML 2024! See you in Vienna! 10 | - [2024.04.12]: We add support for the Mistral model family. The performance of LongChat-7b-v1.5-32K and Mistral-7B-Instruct-v0.2 on 15 tasks from LongBench can be found in [long_bench.md](./docs/long_bench.md). 11 | 12 | - [2024.04.05]: We release the code for reproducing our CoQA/TruthfulQA/GSM8K results using LM-Eval. Please check the [README of branch lmeval](https://github.com/jy-yuan/KIVI/tree/lmeval). 13 | 14 | - [2024.04.04]: 🔥🔥We add a new 5-digit [passkey example](./long_context_example.py) with 12k context length to show the performance of 2bit KIVI under the long context scenario. 15 | 16 | - [2024.04.04]: (Beta) We add flash-attention support for KIVI during the prefill phase. 17 | 18 | - [2024.04.03]: We add a new [5-shot GSM8K example.py](./example.py) to show the performance of 2/4 bit KIVI with 32 full-precision tokens. 19 | 20 | - [2024.02.05]: KIVI ver. 2 is released on [arXiv](https://arxiv.org/abs/2402.02750). 21 | 22 | - [2024.02.03]: KIVI code is released. 23 | 24 | - [2023.12.29]: KIVI ver. 1 is released on [ResearchGate](https://www.researchgate.net/publication/376831635_KIVI_Plug-and-play_2bit_KV_Cache_Quantization_with_Streaming_Asymmetric_Quantization). 25 | 26 | ## Overview 27 | 28 | KIVI is a new plug-and-play 2bit KV cache quantization algorithm that requires no fine-tuning. It reduces memory usage by quantizing the key cache per-channel and the value cache per-token to 2bit. KIVI's hardware-friendly design allows LLMs like Llama-2, Falcon, and Mistral to maintain comparable quality while reducing peak memory usage by 2.6x. This enables up to 4x larger batch sizes and increases throughput by 2.35x to 3.47x in real LLM inference workloads, effectively addressing the speed and memory bottlenecks. 29 | 30 | Illustration of KIVI quantization scheme: key cache per-channel and value cache per-token. 31 |

32 | ![KIVI quantization scheme](./img/quant_scheme.png)
33 | 
34 | 
35 | Illustration of KIVI algorithm during inference prefill and decoding phase:
36 | 
37 | ![KIVI algorithm](./img/algo.png)
38 | 
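To make the scheme in the figures concrete, below is a minimal, self-contained PyTorch sketch of the simulated quantization step: 2bit asymmetric (min-max) rounding applied to the value cache in per-token groups and to the key cache in per-channel groups. It is illustrative only; it uses plain tensor ops rather than the fused CUDA kernels under `quant/`, and it omits the full-precision residual of recent tokens that KIVI keeps.

```python
# Illustrative sketch (not the repository implementation): simulated 2-bit
# asymmetric quantization, grouping V per token and K per channel.
import torch

def asym_quant_dequant(x, num_bits=2):
    # Min-max (asymmetric) round-to-nearest over the last dim (one group).
    mn = x.min(dim=-1, keepdim=True)[0]
    mx = x.max(dim=-1, keepdim=True)[0]
    scale = ((mx - mn) / (2 ** num_bits - 1)).clamp_min(1e-6)
    code = ((x - mn) / scale).round_().clamp_(0, 2 ** num_bits - 1)
    return code * scale + mn  # dequantized values

bsz, heads, seqlen, head_dim, group = 1, 8, 256, 128, 32
k = torch.randn(bsz, heads, seqlen, head_dim)
v = torch.randn(bsz, heads, seqlen, head_dim)

# Value cache: per-token -> group contiguous channels within each token.
v_hat = asym_quant_dequant(v.reshape(bsz, heads, seqlen, head_dim // group, group))
v_hat = v_hat.reshape_as(v)

# Key cache: per-channel -> group contiguous tokens within each channel.
k_t = k.transpose(2, 3)  # (bsz, heads, head_dim, seqlen)
k_hat = asym_quant_dequant(k_t.reshape(bsz, heads, head_dim, seqlen // group, group))
k_hat = k_hat.reshape(bsz, heads, head_dim, seqlen).transpose(2, 3)

print("V reconstruction MAE:", (v - v_hat).abs().mean().item())
print("K reconstruction MAE:", (k - k_hat).abs().mean().item())
```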
39 | 40 | ## How to use KIVI 41 | 42 | ### Setup 43 | 44 | To install the required packages: 45 | 46 | ```bash 47 | conda create -n kivi python=3.10 48 | conda activate kivi 49 | pip install --upgrade pip # enable PEP 660 support 50 | pip install -e . 51 | ``` 52 | 53 | Then install our CUDA implementation: 54 | 55 | ```bash 56 | cd quant && pip install -e . 57 | ``` 58 | 59 | ### Example 60 | 61 | Load model with KIVI: (e.g., Llama-2-7b) 62 | 63 | ```python 64 | # LLaMA model with KIVI 65 | import torch 66 | import os 67 | from models.llama_kivi import LlamaForCausalLM_KIVI 68 | from transformers import LlamaConfig, AutoTokenizer 69 | config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf") 70 | 71 | config.k_bits = K_BITS # current support 2/4 bit for KV Cache 72 | config.v_bits = V_BITS # current support 2/4 bit for KV Cache 73 | config.group_size = GROUP_SIZE 74 | config.residual_length = RESIDUAL_LENGTH # the number of recent fp16 tokens 75 | CACHE_DIR = PATH_TO_YOUR_SAVE_DIR 76 | 77 | model = LlamaForCausalLM_KIVI.from_pretrained( 78 | pretrained_model_name_or_path='meta-llama/Llama-2-7b-hf', 79 | config=config, 80 | cache_dir=CACHE_DIR, 81 | torch_dtype=torch.float16, 82 | low_cpu_mem_usage=True, 83 | device_map="auto", 84 | ) 85 | 86 | tokenizer = AutoTokenizer.from_pretrained( 87 | 'meta-llama/Llama-2-7b-hf', 88 | use_fast=False, 89 | trust_remote_code=True, 90 | tokenizer_type='llama') 91 | 92 | # Inference 93 | # e.g., model.generate(...) 94 | ``` 95 | 96 | #### GSM8K example 97 | We use GSM8K as an example to show how to use KIVI. You can check [example.py](./example.py): 98 | 99 | ```bash 100 | python example.py 101 | ``` 102 | 103 | #### Passkey retrieval example 104 | 105 | Passkey retrieval with KIVI. You can check [long_context_example.py](./long_context_example.py): 106 | 107 | ```bash 108 | python long_context_example.py 109 | ``` 110 | 111 | #### Evaluate KIVI on LongBench 112 | 113 | We currently support Llama and Mistral family of models. We recently test KIVI on Mistral-7B-Instruct-v0.2 and Longchat-7b-v1.5-32k. Please check [long_bench.md](./docs/long_bench.md) for more details. 114 | ```bash 115 | bash scripts/long_test.sh {GPU_ID} {K_BITS} {V_BITS} {GROUP_LENGTH} {RESIDUAL_LENGTH} {MODEL_NAME} 116 | python eval_long_bench.py --model {MODEL} # MODEL is the dir name under pred/ Currently it support Llama family model and Mistral model. 117 | ``` 118 | 119 | ## Citation 120 | 121 | If you find our method useful, please kindly cite our paper. 122 | 123 | ```bibtex 124 | @article{liu2024kivi, 125 | title={KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache}, 126 | author={Liu, Zirui and Yuan, Jiayi and Jin, Hongye and Zhong, Shaochen and Xu, Zhaozhuo and Braverman, Vladimir and Chen, Beidi and Hu, Xia}, 127 | journal={arXiv preprint arXiv:2402.02750}, 128 | year={2024} 129 | } 130 | ``` 131 | 132 | ## Contributing 133 | We welcome contributions from the research community to improve KIVI. If you have any idea or would like to report a bug, please open an issue or submit a pull request. 134 | 135 | ## License 136 | The code is released under the MIT License. 
137 | -------------------------------------------------------------------------------- /config/dataset2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "multi_news": 512, 13 | "vcsum": 512, 14 | "trec": 64, 15 | "triviaqa": 32, 16 | "samsum": 128, 17 | "lsht": 64, 18 | "passage_count": 32, 19 | "passage_retrieval_en": 32, 20 | "passage_retrieval_zh": 32, 21 | "lcc": 64, 22 | "repobench-p": 64 23 | } -------------------------------------------------------------------------------- /config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. 
Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. 
\n{context}{input}Next line of code:\n" 23 | } -------------------------------------------------------------------------------- /config/model2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": 3500, 3 | "longchat-v1.5-7b-32k": 31500, 4 | "xgen-7b-8k": 7500, 5 | "internlm-7b-8k": 7500, 6 | "chatglm2-6b": 31500, 7 | "chatglm2-6b-32k": 31500, 8 | "vicuna-v1.5-7b-16k": 15500, 9 | "LLaMA-2-7B-32K": 7500, 10 | "llama-7b": 4096, 11 | "Llama-2-7b-chat-hf": 4096, 12 | "Llama-2-7b-hf": 4096, 13 | "llama-13b": 4096, 14 | "Llama-2-13b-chat-hf": 4096, 15 | "Llama-2-13b-hf": 4096, 16 | "falcon-7b": 4096, 17 | "Mistral-7B-v0.1": 8192, 18 | "longchat-7b-v1.5-32k": 31500, 19 | "Mistral-7B-Instruct-v0.2": 31500 20 | } -------------------------------------------------------------------------------- /config/model2path.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf", 3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k", 4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst", 5 | "internlm-7b-8k": "internlm/internlm-chat-7b-8k", 6 | "chatglm2-6b": "THUDM/chatglm2-6b", 7 | "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", 8 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k" 9 | } -------------------------------------------------------------------------------- /docs/long_bench.md: -------------------------------------------------------------------------------- 1 | # More results on LongBench 2 | 3 | From the results, for vanilla multihead attention models, we recommend using KIVI-2, which can maintain the performance of the full-precision model while offering the best efficiency. For multiquery attention or group query attention, since the keys and values are already compressed, we recommend using KIVI-4, which can maintain the performance of the full-precision model in these cases. 4 | 5 | ### Table 1: Performance of LongChat-7b-v1.5-32K 6 | 7 | The results of LongChat-7b-v1.5-32K on 15 tasks from LongBench. The model has 32K context length. We use a 32 group size and 128 residual length for both KIVI-2 and KIVI-4. The baseline is of full precision. 8 | 9 | | Task | LongChat-7b-v1.5-32K | w./ KIVI-2 | w./ KIVI-4 | 10 | |------------------|:--------------------:|:------------:|:-------------:| 11 | | NarrativeQA | 20.65 | 20.79 | 20.49 | 12 | | Qasper | 29.42 | 28.69 | 28.90 | 13 | | MultiFieldQA | 43.15 | 41.02 | 43.24 | 14 | | HotpotQA | 33.05 | 32.91 | 33.07 | 15 | | MuSiQue | 14.66 | 13.82 | 14.66 | 16 | | 2WikiMultihopQA | 24.14 | 23.00 | 24.86 | 17 | | GovReport | 30.85 | 30.47 | 31.40 | 18 | | QMSum | 22.84 | 22.59 | 22.84 | 19 | | MultiNews | 26.55 | 26.28 | 26.52 | 20 | | LCC | 54.83 | 54.11 | 54.06 | 21 | | RepoBench-P | 58.94 | 57.62 | 58.77 | 22 | | TriviaQA | 83.99 | 83.19 | 83.88 | 23 | | SAMSum | 40.75 | 41.28 | 40.62 | 24 | | TRec | 66.50 | 66.50 | 67.00 | 25 | | PassageRetrieval | 30.50 | 32.25 | 31.50 | 26 | | **Average** | **38.72** | **38.30** | **38.79** | 27 | 28 | ### Table 2: Performance of Mistral-7B-Instruct-v0.2 29 | 30 | The results of Mistral-7B-Instruct-v0.2 on 15 tasks from LongBench. The model has 32K context length and applies group query attention, which uses 8 heads for KV Cache, as opposed to the full 32 heads. We use a 32 group size and 128 residual length for both KIVI-2 and KIVI-4. The baseline is of full precision. 
31 | 32 | | Task | Mistral-7B-Instruct-v0.2 | w./ KIVI-2 | w./ KIVI-4 | 33 | |------------------|:------------------------:|:------------:|:-------------:| 34 | | NarrativeQA | 21.02 | 20.61 | 20.97 | 35 | | Qasper | 29.41 | 28.73 | 29.41 | 36 | | MultiFieldQA | 47.13 | 44.88 | 46.52 | 37 | | HotpotQA | 36.53 | 35.47 | 36.25 | 38 | | MuSiQue | 19.13 | 17.95 | 19.53 | 39 | | 2WikiMultihopQA | 21.76 | 20.68 | 21.66 | 40 | | GovReport | 32.59 | 32.55 | 32.97 | 41 | | QMSum | 23.99 | 23.65 | 24.06 | 42 | | MultiNews | 27.09 | 26.54 | 26.89 | 43 | | LCC | 53.49 | 53.03 | 53.33 | 44 | | RepoBench-P | 51.40 | 51.16 | 51.41 | 45 | | TriviaQA | 86.23 | 86.00 | 86.23 | 46 | | SAMSum | 43.04 | 43.34 | 43.34 | 47 | | TRec | 71.00 | 71.00 | 71.00 | 48 | | PassageRetrieval | 89.33 | 80.83 | 89.42 | 49 | | **Average** | **43.54** | **42.43** | **43.53** | 50 | -------------------------------------------------------------------------------- /eval_long_bench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | from metrics import ( 7 | qa_f1_score, 8 | rouge_zh_score, 9 | qa_f1_zh_score, 10 | rouge_score, 11 | classification_score, 12 | retrieval_score, 13 | retrieval_zh_score, 14 | count_score, 15 | code_sim_score, 16 | ) 17 | 18 | dataset2metric = { 19 | "narrativeqa": qa_f1_score, 20 | "qasper": qa_f1_score, 21 | "multifieldqa_en": qa_f1_score, 22 | "multifieldqa_zh": qa_f1_zh_score, 23 | "hotpotqa": qa_f1_score, 24 | "2wikimqa": qa_f1_score, 25 | "musique": qa_f1_score, 26 | "dureader": rouge_zh_score, 27 | "gov_report": rouge_score, 28 | "qmsum": rouge_score, 29 | "multi_news": rouge_score, 30 | "vcsum": rouge_zh_score, 31 | "trec": classification_score, 32 | "triviaqa": qa_f1_score, 33 | "samsum": rouge_score, 34 | "lsht": classification_score, 35 | "passage_retrieval_en": retrieval_score, 36 | "passage_count": count_score, 37 | "passage_retrieval_zh": retrieval_zh_score, 38 | "lcc": code_sim_score, 39 | "repobench-p": code_sim_score, 40 | } 41 | 42 | def parse_args(args=None): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--model', type=str, default=None) 45 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 46 | return parser.parse_args(args) 47 | 48 | def scorer_e(dataset, predictions, answers, lengths, all_classes): 49 | scores = {"0-4k": [], "4-8k": [], "8k+": []} 50 | for (prediction, ground_truths, length) in zip(predictions, answers, lengths): 51 | score = 0. 52 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 53 | prediction = prediction.lstrip('\n').split('\n')[0] 54 | for ground_truth in ground_truths: 55 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 56 | if length < 4000: 57 | scores["0-4k"].append(score) 58 | elif length < 8000: 59 | scores["4-8k"].append(score) 60 | else: 61 | scores["8k+"].append(score) 62 | for key in scores.keys(): 63 | scores[key] = round(100 * np.mean(scores[key]), 2) 64 | return scores 65 | 66 | def scorer(dataset, predictions, answers, all_classes): 67 | total_score = 0. 68 | for (prediction, ground_truths) in zip(predictions, answers): 69 | score = 0. 
70 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 71 | prediction = prediction.lstrip('\n').split('\n')[0] 72 | for ground_truth in ground_truths: 73 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 74 | total_score += score 75 | return round(100 * total_score / len(predictions), 2) 76 | 77 | if __name__ == '__main__': 78 | args = parse_args() 79 | scores = dict() 80 | if args.e: 81 | path = f"pred_e/{args.model}/" 82 | else: 83 | path = f"pred/{args.model}/" 84 | all_files = os.listdir(path) 85 | print("Evaluating on:", all_files) 86 | for filename in all_files: 87 | if not filename.endswith("jsonl"): 88 | continue 89 | predictions, answers, lengths = [], [], [] 90 | dataset = filename.split('.')[0] 91 | with open(f"{path}{filename}", "r", encoding="utf-8") as f: 92 | for line in f: 93 | data = json.loads(line) 94 | predictions.append(data["pred"]) 95 | answers.append(data["answers"]) 96 | all_classes = data["all_classes"] 97 | if "length" in data: 98 | lengths.append(data["length"]) 99 | if args.e: 100 | score = scorer_e(dataset, predictions, answers, lengths, all_classes) 101 | else: 102 | score = scorer(dataset, predictions, answers, all_classes) 103 | scores[dataset] = score 104 | if args.e: 105 | out_path = f"pred_e/{args.model}/result.json" 106 | else: 107 | out_path = f"pred/{args.model}/result.json" 108 | with open(out_path, "w") as f: 109 | json.dump(scores, f, ensure_ascii=False, indent=4) -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # LLaMA model with KIVI 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import torch 5 | import random 6 | from models.llama_kivi import LlamaForCausalLM_KIVI 7 | from transformers import LlamaConfig, AutoTokenizer 8 | from datasets import load_dataset 9 | 10 | # For reproducibility 11 | random.seed(0) 12 | torch.manual_seed(0) 13 | 14 | config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") 15 | 16 | config.k_bits = 2 # KiVi currently support 2/4 K/V bits 17 | config.v_bits = 2 18 | config.group_size = 32 19 | config.residual_length = 32 # corresponding to the number of recent fp16 tokens 20 | config.use_flash = True 21 | 22 | model = LlamaForCausalLM_KIVI.from_pretrained( 23 | # pretrained_model_name_or_path='meta-llama/Llama-2-7b-hf', 24 | pretrained_model_name_or_path='meta-llama/Llama-3.1-8B-Instruct', 25 | config=config, 26 | low_cpu_mem_usage=True, 27 | torch_dtype=torch.float16, 28 | ).cuda() 29 | 30 | enc = AutoTokenizer.from_pretrained( 31 | 'meta-llama/Llama-3.1-8B-Instruct', 32 | use_fast=False, 33 | trust_remote_code=True) 34 | 35 | dataset = load_dataset('gsm8k', 'main') 36 | 37 | prompt = '' 38 | for i in range(5): 39 | prompt += 'Question: ' + dataset['train'][i]['question'] + '\nAnswer: ' + dataset['train'][i]['answer'] + '\n' 40 | prompt += "Question: John takes care of 10 dogs. Each dog takes .5 hours a day to walk and take care of their business. How many hours a week does he spend taking care of dogs?" 
41 | inputs = enc(prompt, return_tensors="pt").input_ids.cuda() 42 | 43 | output = model.generate(inputs, max_new_tokens=96) 44 | config_str = f"# prompt tokens: {inputs.shape[1]}, K bit: {config.k_bits}, v_bits: {config.v_bits}, group_size: {config.group_size}, residual_length: {config.residual_length}" 45 | 46 | print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nKiVi Output:") 47 | print(enc.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True)) -------------------------------------------------------------------------------- /img/algo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/img/algo.png -------------------------------------------------------------------------------- /img/quant_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/img/quant_scheme.png -------------------------------------------------------------------------------- /long_context_example.py: -------------------------------------------------------------------------------- 1 | # LLaMA model with KIVI 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import torch 5 | import json 6 | from models.llama_kivi import LlamaForCausalLM_KIVI 7 | from transformers import LlamaConfig, AutoTokenizer 8 | from datasets import load_dataset 9 | 10 | config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") 11 | config.k_bits = 2 # KiVi currently support 2/4 K/V bits 12 | config.v_bits = 2 13 | config.group_size = 32 14 | config.residual_length = 32 # corresponding to the number of recent fp16 tokens 15 | config.use_flash = True # use flash-attention with KiVi for long context inference 16 | CACHE_DIR = "/scratch/cached_model" 17 | 18 | model = LlamaForCausalLM_KIVI.from_pretrained( 19 | pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct", 20 | config=config, 21 | # cache_dir=CACHE_DIR, 22 | low_cpu_mem_usage=True, 23 | torch_dtype=torch.float16, 24 | ).cuda() 25 | 26 | enc = AutoTokenizer.from_pretrained( 27 | "meta-llama/Llama-3.1-8B-Instruct", 28 | use_fast=False, 29 | trust_remote_code=True,) 30 | 31 | model.eval() 32 | file_name = "passkey_examples.jsonl" 33 | method_name = f"K{config.k_bits}V{config.v_bits} KiVi" 34 | print("=========="*2 + f"**{method_name}**" + "=========="*2) 35 | for line in open(file_name, "r"): 36 | example = json.loads(line) 37 | prompt_postfix = "What is the pass key? 
The pass key is " 38 | prompt = example["input"] + prompt_postfix 39 | input_ids = enc(prompt, return_tensors="pt").input_ids.cuda() 40 | print( "-----------------------------------" ) 41 | print( f"#Tokens of Prompt:", input_ids.shape[1], end=" " ) 42 | print( "Passkey target:", example["target"] ) 43 | 44 | tokens = model.generate(input_ids, max_new_tokens=len(example["target"])) 45 | answer = prompt_postfix + enc.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True) 46 | answer = answer.replace("\n", "\\n") 47 | answer= f"{method_name}:\n [ {answer} ]" 48 | print( answer ) 49 | print( "-----------------------------------\n" ) 50 | -------------------------------------------------------------------------------- /mem_spd_test.py: -------------------------------------------------------------------------------- 1 | # LLaMA model with KIVI 2 | import torch 3 | import os 4 | from models.llama_kivi import LlamaForCausalLM_KIVI 5 | from transformers import LlamaConfig, AutoTokenizer 6 | import time 7 | 8 | K_BITS = 2 9 | V_BITS = 2 10 | GROUP_SIZE = 32 11 | RESIDUAL_LENGTH = 128 12 | BATCH_SIZE = 96 13 | PATH_TO_YOUR_SAVE_DIR = './cached_models' 14 | 15 | model_name_or_path = 'meta-llama/Llama-2-7b-hf' 16 | config = LlamaConfig.from_pretrained(model_name_or_path) 17 | config.k_bits = K_BITS # current support 2/4 bit for KV Cache 18 | config.v_bits = V_BITS # current support 2/4 bit for KV Cache 19 | config.group_size = GROUP_SIZE 20 | config.residual_length = RESIDUAL_LENGTH # the number of recent fp16 tokens 21 | CACHE_DIR = PATH_TO_YOUR_SAVE_DIR 22 | 23 | if K_BITS < 16 and V_BITS < 16: 24 | model = LlamaForCausalLM_KIVI.from_pretrained( 25 | pretrained_model_name_or_path=model_name_or_path, 26 | config=config, 27 | cache_dir=CACHE_DIR, 28 | torch_dtype=torch.float16, 29 | low_cpu_mem_usage=True, 30 | device_map="auto", 31 | ) 32 | else: 33 | from transformers import LlamaForCausalLM 34 | model = LlamaForCausalLM.from_pretrained( 35 | pretrained_model_name_or_path=model_name_or_path, 36 | config=config, 37 | cache_dir=CACHE_DIR, 38 | torch_dtype=torch.float16, 39 | low_cpu_mem_usage=True, 40 | device_map="auto", 41 | ) 42 | 43 | tokenizer = AutoTokenizer.from_pretrained( 44 | model_name_or_path, 45 | use_fast=False, 46 | trust_remote_code=True, 47 | tokenizer_type='llama') 48 | 49 | model.cuda().eval() 50 | 51 | context = [] 52 | batch_size = BATCH_SIZE 53 | prompt_lenth = 160 54 | output_length = 338 55 | num_repeats = 3 56 | for _ in range(batch_size): 57 | string = 't,' * (prompt_lenth // 2) 58 | context.append(string[:-1]) 59 | inputs = tokenizer(context, return_tensors="pt").to('cuda') 60 | input_ids = inputs['input_ids'] 61 | print(f"bs: {batch_size}, seqlen: {input_ids.shape[1]}+{output_length}\nmodel:{model_name_or_path}") 62 | torch.cuda.reset_peak_memory_stats() 63 | with torch.no_grad(): 64 | torch.cuda.synchronize() 65 | st = time.time() 66 | for i in range(num_repeats): 67 | outputs = model.generate(**inputs, max_new_tokens=output_length) 68 | torch.cuda.synchronize() 69 | print(f'used time: {(time.time() - st) / num_repeats * 1000} ms') 70 | used_mem = torch.cuda.max_memory_allocated() 71 | print(f'peak mem: {used_mem / 1024 ** 3} GB') -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import jieba 5 | from fuzzywuzzy import fuzz 6 | import difflib 7 | 8 | from typing import List 9 | from 
collections import Counter 10 | from rouge import Rouge 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | 15 | def remove_articles(text): 16 | return re.sub(r"\b(a|an|the)\b", " ", text) 17 | 18 | def white_space_fix(text): 19 | return " ".join(text.split()) 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return "".join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 29 | 30 | 31 | def normalize_zh_answer(s): 32 | """Lower text and remove punctuation, extra whitespace.""" 33 | 34 | def white_space_fix(text): 35 | return "".join(text.split()) 36 | 37 | def remove_punc(text): 38 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 39 | all_punctuation = set(string.punctuation + cn_punctuation) 40 | return "".join(ch for ch in text if ch not in all_punctuation) 41 | 42 | def lower(text): 43 | return text.lower() 44 | 45 | return white_space_fix(remove_punc(lower(s))) 46 | 47 | def count_score(prediction, ground_truth, **kwargs): 48 | numbers = re.findall(r"\d+", prediction) 49 | right_num = 0 50 | for number in numbers: 51 | if str(number) == str(ground_truth): 52 | right_num += 1 53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 54 | return float(final_score) 55 | 56 | def retrieval_score(prediction, ground_truth, **kwargs): 57 | pattern = r'Paragraph (\d+)' 58 | matches = re.findall(pattern, ground_truth) 59 | ground_truth_id = matches[0] 60 | numbers = re.findall(r"\d+", prediction) 61 | right_num = 0 62 | for number in numbers: 63 | if str(number) == str(ground_truth_id): 64 | right_num += 1 65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 66 | return float(final_score) 67 | 68 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 69 | pattern = r'段落(\d+)' 70 | matches = re.findall(pattern, ground_truth) 71 | ground_truth_id = matches[0] 72 | numbers = re.findall(r"\d+", prediction) 73 | right_num = 0 74 | for number in numbers: 75 | if str(number) == str(ground_truth_id): 76 | right_num += 1 77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 78 | return float(final_score) 79 | 80 | def code_sim_score(prediction, ground_truth, **kwargs): 81 | all_lines = prediction.lstrip('\n').split('\n') 82 | prediction = "" 83 | for line in all_lines: 84 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 85 | prediction = line 86 | break 87 | return (fuzz.ratio(prediction, ground_truth) / 100) 88 | 89 | def classification_score(prediction, ground_truth, **kwargs): 90 | em_match_list = [] 91 | all_classes = kwargs["all_classes"] 92 | for class_name in all_classes: 93 | if class_name in prediction: 94 | em_match_list.append(class_name) 95 | for match_term in em_match_list: 96 | if match_term in ground_truth and match_term != ground_truth: 97 | em_match_list.remove(match_term) 98 | if em_match_list != 0: 99 | if ground_truth in em_match_list: 100 | score = (1.0 / len(em_match_list)) 101 | else: 102 | score = 0.0 103 | else: 104 | best_match = None 105 | highest_similarity = 0 106 | for string in all_classes: 107 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio() 108 | if similarity > highest_similarity: 109 | highest_similarity = similarity 110 | best_match = string 111 | score = float(best_match == ground_truth) 112 | return score 113 
| 114 | def rouge_score(prediction, ground_truth, **kwargs): 115 | rouge = Rouge() 116 | try: 117 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 118 | except: 119 | return 0.0 120 | return scores["rouge-l"]["f"] 121 | 122 | def rouge_zh_score(prediction, ground_truth, **kwargs): 123 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 124 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 125 | score = rouge_score(prediction, ground_truth) 126 | return score 127 | 128 | def f1_score(prediction, ground_truth, **kwargs): 129 | common = Counter(prediction) & Counter(ground_truth) 130 | num_same = sum(common.values()) 131 | if num_same == 0: 132 | return 0 133 | precision = 1.0 * num_same / len(prediction) 134 | recall = 1.0 * num_same / len(ground_truth) 135 | f1 = (2 * precision * recall) / (precision + recall) 136 | return f1 137 | 138 | def qa_f1_score(prediction, ground_truth, **kwargs): 139 | normalized_prediction = normalize_answer(prediction) 140 | normalized_ground_truth = normalize_answer(ground_truth) 141 | 142 | prediction_tokens = normalized_prediction.split() 143 | ground_truth_tokens = normalized_ground_truth.split() 144 | return f1_score(prediction_tokens, ground_truth_tokens) 145 | 146 | 147 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 148 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 149 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 150 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 151 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 152 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 153 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 154 | return f1_score(prediction_tokens, ground_truth_tokens) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/models/__init__.py -------------------------------------------------------------------------------- /models/utils_quant.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # 2023.07.05 - Modified weight quantization 9 | # Meta Platforms, Inc. 10 | # 11 | # Copyright 2021 Huawei Technologies Co., Ltd. 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 
24 | 25 | import math 26 | 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.nn as nn 30 | import numpy as np 31 | 32 | 33 | class SymQuantizer(torch.autograd.Function): 34 | """ 35 | uniform quantization 36 | """ 37 | 38 | @staticmethod 39 | def forward(ctx, input, clip_val, num_bits, layerwise): 40 | """ 41 | :param ctx: 42 | :param input: tensor to be quantized 43 | :param clip_val: clip the tensor before quantization 44 | :param quant_bits: number of bits 45 | :return: quantized tensor 46 | """ 47 | ctx.save_for_backward(input, clip_val) 48 | # input = torch.clamp(input, clip_val[0], clip_val[1]) 49 | # input = torch.where(input < clip_val[1], input, clip_val[1]) 50 | # input = torch.where(input > clip_val[0], input, clip_val[0]) 51 | # NOTE: dynamic scaling (max_input). 52 | if layerwise: 53 | max_input = torch.max(torch.abs(input)).expand_as(input) 54 | else: 55 | if input.ndimension() <= 3: 56 | # weight & hidden layer 57 | max_input = ( 58 | torch.max(torch.abs(input), dim=-1, keepdim=True)[0] 59 | .expand_as(input) 60 | .detach() 61 | ) 62 | elif input.ndimension() == 4: 63 | # TODO: attention score matrix, calculate alpha / beta per head 64 | tmp = input.view(input.shape[0], input.shape[1], -1) 65 | max_input = ( 66 | torch.max(torch.abs(tmp), dim=-1, keepdim=True)[0] 67 | .unsqueeze(-1) 68 | .expand_as(input) 69 | .detach() 70 | ) 71 | else: 72 | raise ValueError 73 | s = (2 ** (num_bits - 1) - 1) / (max_input + 1e-6) 74 | output = torch.round(input * s).div(s + 1e-6) 75 | 76 | return output 77 | 78 | @staticmethod 79 | def backward(ctx, grad_output): 80 | """ 81 | :param ctx: saved non-clipped full-precision tensor and clip_val 82 | :param grad_output: gradient ert the quantized tensor 83 | :return: estimated gradient wrt the full-precision tensor 84 | """ 85 | input, clip_val = ctx.saved_tensors # unclipped input 86 | grad_input = grad_output.clone() 87 | grad_input[input.ge(clip_val[1])] = 0 88 | grad_input[input.le(clip_val[0])] = 0 89 | return grad_input, None, None, None 90 | 91 | 92 | class AsymQuantizer(torch.autograd.Function): 93 | """ 94 | min-max quantization 95 | """ 96 | 97 | @staticmethod 98 | def forward(ctx, input, clip_val, num_bits, layerwise): 99 | """ 100 | :param ctx: 101 | :param input: tensor to be quantized 102 | :param clip_val: clip the tensor before quantization 103 | :param quant_bits: number of bits 104 | :return: quantized tensor 105 | """ 106 | ctx.save_for_backward(input, clip_val) 107 | 108 | # input = torch.where(input < clip_val[1], input, clip_val[1]) 109 | # input = torch.where(input > clip_val[0], input, clip_val[0]) 110 | # input = torch.clamp(input, clip_val[0], clip_val[1]) 111 | # NOTE: dynamic scaling gives better performance than static 112 | if layerwise: 113 | alpha = (input.max() - input.min()).detach() 114 | beta = input.min().detach() 115 | else: 116 | if input.ndimension() <= 3: 117 | # weight & hidden layer 118 | alpha = ( 119 | ( 120 | input.max(dim=-1, keepdim=True)[0] 121 | - input.min(dim=-1, keepdim=True)[0] 122 | ) 123 | .expand_as(input) 124 | .detach() 125 | ) 126 | beta = input.min(dim=-1, keepdim=True)[0].expand_as(input).detach() 127 | elif input.ndimension() == 4: 128 | # TODO: attention score matrix, calculate alpha / beta per head 129 | tmp = input.view(input.shape[0], input.shape[1], -1) 130 | alpha = ( 131 | ( 132 | tmp.max(dim=-1, keepdim=True)[0].unsqueeze(-1) 133 | - tmp.min(dim=-1, keepdim=True)[0].unsqueeze(-1) 134 | ) 135 | .expand_as(input) 136 | .detach() 137 | ) 138 | beta = ( 139 
| tmp.min(dim=-1, keepdim=True)[0] 140 | .unsqueeze(-1) 141 | .expand_as(input) 142 | .detach() 143 | ) 144 | else: 145 | raise ValueError 146 | input_normalized = (input - beta) / (alpha + 1e-8) 147 | s = 2**num_bits - 1 148 | quant_input = torch.round(input_normalized * s).div(s) 149 | output = quant_input * (alpha + 1e-8) + beta 150 | 151 | return output 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | """ 156 | :param ctx: saved non-clipped full-precision tensor and clip_val 157 | :param grad_output: gradient ert the quantized tensor 158 | :return: estimated gradient wrt the full-precision tensor 159 | """ 160 | input, clip_val = ctx.saved_tensors # unclipped input 161 | grad_input = grad_output.clone() 162 | grad_input[input.ge(clip_val[1])] = 0 163 | grad_input[input.le(clip_val[0])] = 0 164 | return grad_input, None, None, None 165 | 166 | 167 | class AsymGroupedQuantizer(torch.autograd.Function): 168 | @staticmethod 169 | def forward(ctx, input, clip_val, num_bits, group_size, prec_map_indices=None): 170 | ctx.save_for_backward(input, clip_val) 171 | # input = torch.clamp(input, clip_val[0], clip_val[1]) 172 | # input = torch.where(input < clip_val[1], input, clip_val[1]) 173 | # input = torch.where(input > clip_val[0], input, clip_val[0]) 174 | # NOTE: dynamic scaling (max_input). 175 | 176 | bs, seqlen, d = input.shape 177 | num_groups = d // group_size 178 | if num_groups * group_size != input.shape[-1]: 179 | raise ValueError("group_size should be a factor of the last dimension size") 180 | 181 | 182 | input_in_groups = input.view(bs, seqlen, num_groups, group_size) 183 | 184 | ##### 185 | # input_in_groups_cpy = input_in_groups.clone().detach() 186 | ##### 187 | 188 | mx, mn = input_in_groups.max(dim=-1)[0], input_in_groups.min(dim=-1)[0] 189 | mx, mn = mx.unsqueeze(-1), mn.unsqueeze(-1) 190 | 191 | scale = (mx - mn) / (2 ** num_bits - 1) 192 | input_in_groups = (input_in_groups - mn) / scale 193 | input_in_groups = F.relu(input_in_groups) 194 | rounded_input_in_groups = input_in_groups.round_() 195 | dequantized_input_in_groups = rounded_input_in_groups * scale + mn 196 | 197 | ##### 198 | # if prec_map_indices is not None: 199 | # _, num_heads, _ = prec_map_indices.shape 200 | # for i in range(bs): 201 | # for j in range(num_heads): 202 | # for k in prec_map_indices[i, j]: 203 | # dequantized_input_in_groups[i, k, j, :] = input_in_groups_cpy[i, k, j, :] 204 | ##### 205 | 206 | dequantized_input = dequantized_input_in_groups.view(bs, seqlen, -1) 207 | return dequantized_input 208 | 209 | @staticmethod 210 | def backward(ctx, grad_output): 211 | input, clip_val = ctx.saved_tensors 212 | grad_input = grad_output 213 | 214 | # clip version 215 | # grad_input[input.ge(clip_val[1])] = 0 216 | # grad_input[input.le(clip_val[0])] = 0 217 | return grad_input, None, None, None, None 218 | 219 | ### group by channel 220 | class AsymGroupedQuantizerByChannel(torch.autograd.Function): 221 | @staticmethod 222 | def forward(ctx, input, clip_val, num_bits, group_size, prec_map_indices=None): 223 | ctx.save_for_backward(input, clip_val) 224 | # input = torch.clamp(input, clip_val[0], clip_val[1]) 225 | # input = torch.where(input < clip_val[1], input, clip_val[1]) 226 | # input = torch.where(input > clip_val[0], input, clip_val[0]) 227 | bs, seqlen, d = input.shape 228 | mx, mn = input.max(dim=-2)[0], input.min(dim=-2)[0] 229 | mx, mn = mx.unsqueeze(-2), mn.unsqueeze(-2) 230 | scale = (mx - mn) / (2 ** num_bits - 1) 231 | input = (input - mn) / scale 232 | input = 
F.relu(input) 233 | rounded_input = input.round_() 234 | dequantized_input = rounded_input * scale + mn 235 | 236 | assert dequantized_input.shape == input.shape 237 | 238 | return dequantized_input 239 | 240 | @staticmethod 241 | def backward(ctx, grad_output): 242 | input, clip_val = ctx.saved_tensors 243 | grad_input = grad_output 244 | 245 | # clip version 246 | # grad_input[input.ge(clip_val[1])] = 0 247 | # grad_input[input.le(clip_val[0])] = 0 248 | return grad_input, None, None, None, None 249 | 250 | class QuantizeLinear(nn.Linear): 251 | def __init__( 252 | self, 253 | *kargs, 254 | symmetric=True, 255 | bias=False, 256 | w_bits=32, 257 | a_bits=32, 258 | act_layerwise=False, 259 | weight_layerwise=False, 260 | ): 261 | super(QuantizeLinear, self).__init__(*kargs, bias=False) 262 | self.w_bits = w_bits 263 | self.a_bits = a_bits 264 | self.act_layerwise = act_layerwise 265 | self.weight_layerwise = weight_layerwise 266 | # params for weight quant 267 | # if self.w_bits < 32: 268 | # self.weight_clip_val = Parameter(torch.tensor([-2.0, 2.0]), requires_grad=False) 269 | if self.a_bits < 32 and self.a_bits > 2: 270 | if symmetric: 271 | self.act_quantizer = SymQuantizer 272 | else: 273 | self.act_quantizer = AsymQuantizer 274 | 275 | def forward(self, input_): 276 | # quantize weight 277 | assert len(self.weight.size()) == 2 278 | real_weights = self.weight 279 | 280 | if self.w_bits >= 32: 281 | weight = self.weight 282 | elif self.w_bits >= 3: 283 | weight_clip_val = torch.tensor([-2.0, 2.0]) 284 | weight = SymQuantizer.apply( 285 | real_weights, weight_clip_val, self.w_bits, self.weight_layerwise 286 | ) 287 | else: 288 | if self.w_bits == 1: 289 | if self.weight_layerwise: 290 | scaling_factor = torch.mean(abs(real_weights)).detach() 291 | else: 292 | scaling_factor = torch.mean( 293 | abs(real_weights), dim=1, keepdim=True 294 | ).detach() 295 | quan_weights_no_grad = scaling_factor * ( 296 | torch.sign(real_weights / scaling_factor) 297 | ) 298 | # elif self.w_bits == 2: 299 | # scaling_factor = 4/3 * torch.mean(abs(real_weights), dim=1, keepdim=True).detach() 300 | # quan_weights_no_grad = scaling_factor * (torch.round(torch.clamp(real_weights/scaling_factor, -1, 1))) 301 | else: 302 | num_bits = 2 ** (self.w_bits - 1) 303 | clip_val = 1 - 1e-2 304 | if self.weight_layerwise: 305 | scaling_factor = 2 * torch.mean(abs(real_weights)).detach() 306 | else: 307 | scaling_factor = ( 308 | 2 * torch.mean(abs(real_weights), dim=1, keepdim=True).detach() 309 | ) 310 | quan_weights_no_grad = ( 311 | scaling_factor 312 | * ( 313 | torch.round( 314 | torch.clamp( 315 | real_weights / scaling_factor, -clip_val, clip_val 316 | ) 317 | * num_bits 318 | - 0.5 319 | ) 320 | + 0.5 321 | ) 322 | / num_bits 323 | ) 324 | 325 | weight = ( 326 | quan_weights_no_grad.detach() - real_weights.detach() + real_weights 327 | ) 328 | # Quantize inputs 329 | if self.a_bits < 32 and self.a_bits > 2: 330 | act_clip_val = torch.tensor([-2.0, 2.0]) 331 | input_ = self.act_quantizer.apply( 332 | input_, act_clip_val, self.a_bits, self.act_layerwise 333 | ) 334 | 335 | out = nn.functional.linear(input_, weight) 336 | if self.bias is not None: 337 | out += self.bias.view(1, -1).expand_as(out) 338 | 339 | return out 340 | 341 | 342 | def test_group_quantize(): 343 | input = torch.randn((4, 16, 1024), dtype=torch.float16, device='cuda') 344 | clip_val = torch.tensor([-2.0, 2.0]) 345 | for num_bits, group_size in [ (2, 64), (4, 64), (8, 64), \ 346 | (2, 128), (4, 128), (8, 128), \ 347 | (2, 256), (4, 256), (8, 
256)]: 348 | output = AsymGroupedQuantizer.apply(input, clip_val, num_bits, group_size) 349 | err = torch.mean(torch.abs(input - output)).item() 350 | print(num_bits, group_size, err) 351 | # print(input[0,0,100:150]) 352 | # print(output[0,0,100:150]) 353 | 354 | 355 | def process_input(input, group_size): 356 | N = input.shape[0] 357 | input_flatten = input.reshape(N, -1) 358 | num_features = input_flatten.shape[1] 359 | 360 | # Compute min, max by groups 361 | if num_features % group_size != 0: 362 | # Padding 363 | new_num_features = (num_features // group_size + 1) * group_size 364 | delta = new_num_features - num_features 365 | input_flatten = torch.cat([input_flatten, 366 | torch.zeros([N, delta], dtype=input.dtype, device=input.device)], 1) 367 | 368 | input_groups = input_flatten.reshape(-1, group_size) 369 | mn, mx = torch.min(input_groups, 1)[0], torch.max(input_groups, 1)[0] 370 | return input_groups.view(N, -1, group_size), mn.view(N, -1), mx.view(N, -1) 371 | 372 | 373 | def quantize_and_pack(data, group_size, bits, simulate=False): 374 | data, mn, mx = process_input(data, group_size) 375 | data = data.transpose(0, 1) 376 | mn = mn.t() 377 | mx = mx.t() 378 | if simulate: 379 | mn, mx = mn.unsqueeze(-1), mx.unsqueeze(-1) 380 | N = data.shape[0] 381 | output = data # N, groups, group_dim 382 | if isinstance(bits, int): 383 | bits = torch.ones(N, dtype=torch.int32, device='cuda') * bits 384 | 385 | B = (2 ** bits - 1).view(N, 1, 1) 386 | mn = mn - 1e-6 387 | mx = mx + 1e-6 388 | scale = B / (mx - mn) # N, groups, 1 389 | output = (output - mn) * scale 390 | output = F.relu(output) 391 | output = torch.min(output, B.float()).round_().int() 392 | else: 393 | # data.shape == B, ng, gz 394 | # mn.shape == B, ng 395 | # mx.shape == B, ng 396 | # import ipdb; ipdb.set_trace() 397 | output, scale = dequant_cuda.pack_single_precision(data, mn, mx, bits, False) 398 | scale = scale.squeeze(-1) 399 | return output, scale, mn 400 | 401 | 402 | def dequantize_and_unpack(data, group_size, shape, bits, scale, mn, simulate=False): 403 | if simulate: 404 | scale, mn = scale.unsqueeze(-1), mn.unsqueeze(-1) 405 | data = data / scale + mn 406 | else: 407 | # Pad to group_size 408 | N = shape[0] 409 | num_features = int(np.prod(shape[1:])) 410 | num_features = (num_features + (group_size - num_features % group_size) % group_size) 411 | 412 | # Unpack bitstream 413 | data = dequant_cuda.unpack_single_precision(data, bits, scale, mn, N, num_features // group_size, group_size) 414 | data = data.view(shape) 415 | return data 416 | 417 | 418 | def process_input_by_channel(input, group_size): 419 | num_features = input.shape[-1] 420 | # input_flatten: [num_feats, bs * seqlen] 421 | input_flatten = input.view(-1, num_features).transpose(0, 1) 422 | num_instances = input_flatten.shape[-1] 423 | # Compute min, max by groups 424 | if num_instances % group_size != 0: 425 | # Padding 426 | new_num_instances = (num_instances // group_size + 1) * group_size 427 | delta = new_num_instances - num_instances 428 | input_flatten = torch.cat([input_flatten, 429 | torch.zeros([num_features, delta], dtype=input.dtype, device=input.device)], 1) 430 | input_groups = input_flatten.reshape(-1, group_size) 431 | mn, mx = torch.min(input_groups, 1)[0], torch.max(input_groups, 1)[0] 432 | return input_groups.view(num_features, -1, group_size), mn.view(num_features, -1), mx.view(num_features, -1) 433 | 434 | 435 | def quantize_by_channel_and_pack(input, group_size, num_bits, simulate=False): 436 | assert len(input.shape) == 3 
437 | shape = input.shape 438 | ori_num_instances = shape[0] * shape[1] 439 | input_groups, mn, mx = process_input_by_channel(input, group_size) 440 | if simulate: 441 | mn, mx = mn.unsqueeze(-1), mx.unsqueeze(-1) 442 | scale = (mx - mn) / (2 ** num_bits - 1) 443 | input_groups = (input_groups - mn) / scale 444 | input_groups = F.relu(input_groups) 445 | rounded_input = input_groups.round_() 446 | return rounded_input, scale, mn 447 | # dequantized_input = rounded_input * scale + mn 448 | # dequantized_input = dequantized_input.view(input.shape[-1], -1) 449 | # if ori_num_instances != dequantized_input.shape[1]: 450 | # dequantized_input = dequantized_input[:, 0:ori_num_instances] 451 | # dequantized_input = dequantized_input.transpose(0, 1).view(shape) 452 | # assert dequantized_input.shape == shape 453 | # return dequantized_input, scale, mn 454 | else: 455 | output, scale = dequant_cuda.pack_single_precision(input_groups, mn, mx, num_bits, False) 456 | assert len(scale.shape) >= 2 and len(mn.shape) >= 2 457 | if len(scale.shape) == 3: 458 | scale = scale.squeeze(-1) 459 | if len(mn.shape) == 3: 460 | mn = mn.squeeze(-1) 461 | return output, scale, mn 462 | 463 | 464 | def dequantize_by_channel_and_unpack(data, group_size, shape, bits, scale, mn, simulate=False): 465 | num_feats = shape[-1] 466 | ori_num_instances = shape[0] * shape[1] 467 | if simulate: 468 | # import ipdb; ipdb.set_trace() 469 | data = data * scale + mn 470 | else: 471 | # Pad to group_size 472 | tot_num_instances = (ori_num_instances + (group_size - ori_num_instances % group_size) % group_size) 473 | 474 | # Unpack bitstream 475 | data = dequant_cuda.unpack_single_precision(data, bits, scale, mn, num_feats, tot_num_instances // group_size, group_size) 476 | dequantized_input = data.view(shape[-1], -1) 477 | if ori_num_instances != dequantized_input.shape[1]: 478 | dequantized_input = dequantized_input[:, 0:ori_num_instances] 479 | data = dequantized_input.transpose(0, 1).view(shape) 480 | return data 481 | 482 | 483 | def cal_tensor_size(x): 484 | if isinstance(x, list): 485 | return np.sum([cal_tensor_size(x_) for x_ in x]) 486 | elif isinstance(x, torch.Tensor): 487 | num_params = np.prod(x.shape) 488 | if x.dtype == torch.int32: 489 | return num_params * 4 490 | elif x.dtype in [torch.bfloat16, torch.float16]: 491 | return num_params * 2 492 | else: 493 | raise NotImplementedError 494 | else: 495 | raise NotImplementedError 496 | 497 | 498 | def quantize_by_channel_and_pack_cache(input, group_size, num_bits, simulate=False): 499 | ## convert kv_cache shape (bsz, head, seq_len, dim_head) to zirui shape (bsz, seq_len, dim) 500 | assert len(input.shape) == 4 501 | bsz, _, seq_len, _ = input.shape 502 | input = input.transpose(1, 2).reshape(bsz, seq_len, -1) 503 | ## 504 | 505 | shape = input.shape 506 | ori_num_instances = shape[0] * shape[1] 507 | input_groups, mn, mx = process_input_by_channel(input, group_size) 508 | if simulate: 509 | mn, mx = mn.unsqueeze(-1), mx.unsqueeze(-1) 510 | scale = (mx - mn) / (2 ** num_bits - 1) 511 | input_groups = (input_groups - mn) / scale 512 | input_groups = F.relu(input_groups) 513 | rounded_input = input_groups.round_() 514 | return rounded_input, scale, mn 515 | # dequantized_input = rounded_input * scale + mn 516 | # dequantized_input = dequantized_input.view(input.shape[-1], -1) 517 | # if ori_num_instances != dequantized_input.shape[1]: 518 | # dequantized_input = dequantized_input[:, 0:ori_num_instances] 519 | # dequantized_input = dequantized_input.transpose(0, 
1).view(shape) 520 | # assert dequantized_input.shape == shape 521 | # return dequantized_input, scale, mn 522 | else: 523 | # import ipdb; ipdb.set_trace() 524 | output, scale = dequant_cuda.pack_single_precision(input_groups, mn, mx, num_bits, False) 525 | assert len(scale.shape) >= 2 and len(mn.shape) >= 2 526 | if len(scale.shape) == 3: 527 | scale = scale.squeeze(-1) 528 | if len(mn.shape) == 3: 529 | mn = mn.squeeze(-1) 530 | return output, scale, mn 531 | 532 | 533 | def dequantize_by_channel_and_unpack_cache(data, group_size, shape, bits, scale, mn, simulate=False): 534 | ## the input shape is not zirui shape (bsz, seq_len, dim), but kv_cache shape (bsz, head, seq_len, dim_head) 535 | # original variables 536 | # num_feats = shape[-1] 537 | # ori_num_instances = shape[0] * shape[1] 538 | assert len(shape) == 4 539 | num_feats = shape[1] * shape[3] 540 | ori_num_instances = shape[0] * shape[2] 541 | ## 542 | 543 | if simulate: 544 | # import ipdb; ipdb.set_trace() 545 | data = data * scale + mn 546 | else: 547 | # Pad to group_size 548 | tot_num_instances = (ori_num_instances + (group_size - ori_num_instances % group_size) % group_size) 549 | 550 | # Unpack bitstream 551 | data = dequant_cuda.unpack_single_precision(data, bits, scale, mn, num_feats, tot_num_instances // group_size, group_size) 552 | dequantized_input = data.view(num_feats, -1) 553 | if ori_num_instances != dequantized_input.shape[1]: 554 | dequantized_input = dequantized_input[:, 0:ori_num_instances] 555 | 556 | ## convert zirui shape (bsz, seq_len, dim) to kv_cache shape (bsz, head, seq_len, dim_head) 557 | # this is last step (this shape is zirui shape) data = dequantized_input.transpose(0, 1).view(shape) 558 | data = dequantized_input.transpose(0, 1).view(shape[0], -1, num_feats) 559 | data = data.view(shape[0], shape[2], shape[1], -1).transpose(1, 2) 560 | ## 561 | assert data.shape == shape 562 | 563 | return data 564 | 565 | def test_channel_quantize(): 566 | input = torch.randn((112, 334, 4096), dtype=torch.float16, device='cuda') 567 | shape = input.shape 568 | # for num_bits, group_size in [ (2, 64), (4, 64), (8, 64), \ 569 | # (2, 128), (4, 128), (8, 128), \ 570 | # (2, 256), (4, 256), (8, 256)]: 571 | for num_bits, group_size in [(2, 128), (4, 128)]: 572 | # fake_code, scale, mn = quantize_by_channel_and_pack(input, group_size, num_bits, True) 573 | # output_fake = dequantize_by_channel_and_unpack(fake_code, group_size, shape, num_bits, scale, mn, True) 574 | # err = torch.mean(torch.abs(input - output_fake)).item() 575 | # print(num_bits, group_size, err) 576 | real_code, scale, mn = quantize_by_channel_and_pack(input, group_size, num_bits, False) 577 | output_real = dequantize_by_channel_and_unpack(real_code, group_size, shape, num_bits, scale, mn, False) 578 | err = torch.mean(torch.abs(input - output_real)).item() 579 | print(num_bits, group_size, err) 580 | 581 | def test_quantize(): 582 | input = torch.randn((1, 32, 340, 128), dtype=torch.float16, device='cuda') 583 | shape = input.shape 584 | quantized_v, scale, mn = quantize_and_pack(input, 128, 4, False) 585 | dequantized_v = dequantize_and_unpack(quantized_v, 128, shape, 4, scale, mn, False) 586 | 587 | quantized_v, scale, mn = quantize_by_channel_and_pack_cache(input, 128, 4, False) 588 | dequantized_v = dequantize_by_channel_and_unpack_cache(quantized_v, 128, shape, 4, scale, mn, False) 589 | 590 | 591 | if __name__ == '__main__': 592 | # test_group_quantize() 593 | # test_channel_quantize() 594 | test_quantize() 
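As a rough sanity check on the storage layout used above (low-bit integer codes plus a per-group scale and zero-point, with the most recent tokens kept in fp16), the back-of-envelope estimate below shows how the KV-cache footprint shrinks under KIVI-2-style settings. The model geometry, batch size, and sequence length are assumptions chosen for illustration, not measurements from this repository; end-to-end peak-memory savings are smaller than the KV-cache-only ratio because the fp16 weights are unchanged.

```python
# Back-of-envelope KV-cache size estimate (all numbers are illustrative assumptions).
layers, heads, head_dim = 32, 32, 128      # Llama-2-7B-like geometry (assumed)
batch, seqlen = 4, 4096                    # assumed workload
group, bits, residual = 32, 2, 128         # KIVI-2-style settings

elems = 2 * layers * heads * head_dim * seqlen * batch   # K and V elements

fp16_bytes = elems * 2
# Quantized part: `bits`-wide codes plus one fp16 scale and one fp16 zero-point
# per group; the most recent `residual` tokens stay in fp16.
quant_elems = 2 * layers * heads * head_dim * (seqlen - residual) * batch
resid_elems = elems - quant_elems
quant_bytes = quant_elems * bits / 8 + (quant_elems / group) * 2 * 2
kivi_bytes = quant_bytes + resid_elems * 2

print(f"fp16 KV cache   : {fp16_bytes / 2**30:.2f} GiB")
print(f"KIVI-2 estimate : {kivi_bytes / 2**30:.2f} GiB")
print(f"KV-only ratio   : {fp16_bytes / kivi_bytes:.1f}x")
```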
-------------------------------------------------------------------------------- /pred_long_bench.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import load_dataset 3 | import torch 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | import random 8 | import argparse 9 | os.environ["WANDB_DISABLED"] = "true" 10 | 11 | from utils.process_args import process_args 12 | from transformers import LlamaConfig, MistralConfig, AutoTokenizer 13 | 14 | 15 | # This is the customized building prompt for chat models 16 | def build_chat(tokenizer, prompt, model_name): 17 | # For results in KIVI paper (Llama, Llama-Chat, Mistral-7B-v0.1), we do not apply any special treatment to the prompt. 18 | # For lmsys/longchat-7b-v1.5-32k and mistralai/Mistral-7B-Instruct-v0.2, we need to rewrite the prompt a little bit. 19 | # Update: we add the template for the new llama-3-instruct model 20 | if "llama-3" in model_name.lower() and "instruct" in model_name.lower(): 21 | messages = [ 22 | {"role": "user", "content": prompt}, 23 | ] 24 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 25 | elif "longchat" in model_name.lower(): 26 | from fastchat.model import get_conversation_template 27 | conv = get_conversation_template("vicuna") 28 | conv.append_message(conv.roles[0], prompt) 29 | conv.append_message(conv.roles[1], None) 30 | prompt = conv.get_prompt() 31 | elif "mistral-v0.2-instruct" in model_name.lower(): 32 | messages = [ 33 | { 34 | "role": "user", 35 | "content": prompt 36 | } 37 | ] 38 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 39 | return prompt 40 | 41 | def post_process(response, model_name): 42 | if "xgen" in model_name: 43 | response = response.strip().replace("Assistant:", "") 44 | elif "internlm" in model_name: 45 | response = response.split("")[0] 46 | return response 47 | 48 | def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name): 49 | preds = [] 50 | for json_obj in tqdm(data): 51 | prompt = prompt_format.format(**json_obj) 52 | # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) 53 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] 54 | # if "chatglm3" in model: 55 | # tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0] 56 | if len(tokenized_prompt) > max_length: 57 | half = int(max_length/2) 58 | prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) 59 | if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks 60 | prompt = build_chat(tokenizer, prompt, model_name) 61 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 62 | context_length = input.input_ids.shape[-1] 63 | if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue 64 | output = model.generate( 65 | **input, 66 | max_new_tokens=max_gen, 67 | num_beams=1, 68 | do_sample=False, 69 | temperature=1.0, 70 | min_length=context_length+1, 71 | eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], 72 | )[0] 73 | else: 74 | output 
= model.generate( 75 | **input, 76 | max_new_tokens=max_gen, 77 | num_beams=1, 78 | do_sample=False, 79 | temperature=1.0, 80 | )[0] 81 | pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) 82 | pred = post_process(pred, model_name) 83 | preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}) 84 | return preds 85 | 86 | def seed_everything(seed): 87 | torch.manual_seed(seed) 88 | torch.cuda.manual_seed(seed) 89 | np.random.seed(seed) 90 | random.seed(seed) 91 | torch.backends.cudnn.benchmark = False 92 | torch.backends.cudnn.deterministic = True 93 | torch.cuda.manual_seed_all(seed) 94 | 95 | if __name__ == '__main__': 96 | seed_everything(42) 97 | # args = parse_args() 98 | model2path = json.load(open("config/model2path.json", "r")) 99 | model2maxlen = json.load(open("config/model2maxlen.json", "r")) 100 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 101 | # model_name = args.model 102 | 103 | # define your model 104 | model_args, data_args, training_args = process_args() 105 | # print(model_args, data_args, training_args) 106 | model_name = model_args.model_name_or_path.split("/")[-1] 107 | # dtype = torch.bfloat16 if training_args.bf16 else torch.float 108 | dtype = torch.float16 109 | 110 | if 'llama' in model_args.model_name_or_path.lower() or 'longchat' in model_args.model_name_or_path.lower(): 111 | config = LlamaConfig.from_pretrained(model_args.model_name_or_path) 112 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, 113 | use_fast=False, 114 | trust_remote_code=True, 115 | tokenizer_type='llama') 116 | # model_max_length=training_args.model_max_length) 117 | elif 'mistral' in model_args.model_name_or_path.lower(): 118 | config = MistralConfig.from_pretrained(model_args.model_name_or_path) 119 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, 120 | use_fast=False, 121 | trust_remote_code=True) 122 | else: 123 | raise NotImplementedError 124 | 125 | if 'llama' in model_args.model_name_or_path.lower() or 'longchat' in model_args.model_name_or_path.lower(): 126 | if model_args.k_bits < 16 and model_args.v_bits < 16: 127 | from models.llama_kivi import LlamaForCausalLM_KIVI 128 | config.k_bits = model_args.k_bits 129 | config.v_bits = model_args.v_bits 130 | config.group_size = model_args.group_size 131 | config.residual_length = model_args.residual_length 132 | config.use_flash = True # Note: We activate the flashattention to speed up the inference 133 | model = LlamaForCausalLM_KIVI.from_pretrained( 134 | pretrained_model_name_or_path=model_args.model_name_or_path, 135 | config=config, 136 | cache_dir=training_args.cache_dir, 137 | torch_dtype=dtype, 138 | low_cpu_mem_usage=True, 139 | device_map="auto", 140 | ) 141 | else: 142 | from transformers import LlamaForCausalLM 143 | model = LlamaForCausalLM.from_pretrained( 144 | pretrained_model_name_or_path=model_args.model_name_or_path, 145 | config=config, 146 | cache_dir=training_args.cache_dir, 147 | torch_dtype=dtype, 148 | low_cpu_mem_usage=True, 149 | use_flash_attention_2=True, 150 | device_map="auto", 151 | ) 152 | 153 | elif 'mistral' in model_args.model_name_or_path.lower(): 154 | if model_args.k_bits < 16 and model_args.v_bits < 16: 155 | from models.mistral_kivi import MistralForCausalLM_KIVI 156 | config.k_bits = model_args.k_bits 157 | config.v_bits = model_args.v_bits 158 | config.group_size = model_args.group_size 159 | config.residual_length = 
model_args.residual_length 160 | config.use_flash = True 161 | model = MistralForCausalLM_KIVI.from_pretrained( 162 | pretrained_model_name_or_path=model_args.model_name_or_path, 163 | config=config, 164 | cache_dir=training_args.cache_dir, 165 | torch_dtype=dtype, 166 | low_cpu_mem_usage=True, 167 | device_map="auto", 168 | ) 169 | else: 170 | from transformers import MistralForCausalLM 171 | model = MistralForCausalLM.from_pretrained( 172 | pretrained_model_name_or_path=model_args.model_name_or_path, 173 | config=config, 174 | cache_dir=training_args.cache_dir, 175 | torch_dtype=dtype, 176 | low_cpu_mem_usage=True, 177 | use_flash_attention_2=True, 178 | device_map="auto", 179 | ) 180 | 181 | else: 182 | raise NotImplementedError 183 | 184 | # 185 | # Load model directly 186 | # tokenizer = AutoTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K") 187 | # model = AutoModelForCausalLM.from_pretrained("togethercomputer/LLaMA-2-7B-32K") 188 | 189 | model.eval() 190 | max_length = model2maxlen[model_name] 191 | if data_args.e: 192 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", 193 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] 194 | else: 195 | datasets = ["triviaqa", "qasper", "trec", "samsum", "lcc", "repobench-p", "qmsum", "multi_news"] 196 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output 197 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r")) 198 | dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r")) 199 | # predict on each dataset 200 | if not os.path.exists("pred"): 201 | os.makedirs("pred") 202 | if not os.path.exists("pred_e"): 203 | os.makedirs("pred_e") 204 | for dataset in datasets: 205 | if data_args.e: 206 | data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') 207 | if not os.path.exists(f"pred_e/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}"): 208 | os.makedirs(f"pred_e/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}") 209 | out_path = f"pred_e/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}/{dataset}.jsonl" 210 | else: 211 | data = load_dataset('THUDM/LongBench', dataset, split='test') 212 | if not os.path.exists(f"pred/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}"): 213 | os.makedirs(f"pred/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}") 214 | out_path = f"pred/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}/{dataset}.jsonl" 215 | prompt_format = dataset2prompt[dataset] 216 | max_gen = dataset2maxlen[dataset] 217 | preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name) 218 | with open(out_path, "w", encoding="utf-8") as f: 219 | for pred in preds: 220 | json.dump(pred, f, ensure_ascii=False) 221 | f.write('\n') -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 
[project] 6 | name = "kivi" 7 | version = "0.1.0" 8 | description = "An tuning-free 2/4/8 bit KV Cache quantization method for LLMs." 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "packaging==24.0", "sentencepiece", "tokenizers>=0.15", 17 | "torch==2.4.1", "ipdb", 18 | "transformers==4.43.1", 19 | "toml", "attributedict", 20 | "accelerate", 21 | "fastchat", 22 | "protobuf", 23 | "flash-attn", 24 | "datasets" 25 | ] 26 | 27 | [tool.setuptools.packages.find] 28 | exclude = ["results*", "scripts*", "examples*"] 29 | 30 | [tool.wheel] 31 | exclude = ["results*", "scripts*", "examples*"] 32 | -------------------------------------------------------------------------------- /quant/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/quant/__init__.py -------------------------------------------------------------------------------- /quant/csrc/gemv_cuda.cu: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/ankan-ban/llama_cu_awq 2 | // and the official implementation of AWQ 3 | /* 4 | 5 | @article{lin2023awq, 6 | title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, 7 | author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, 8 | journal={arXiv}, 9 | year={2023} 10 | } 11 | 12 | */ 13 | 14 | #include 15 | #include 16 | #include 17 | #include "gemv_cuda.h" 18 | #define VECTORIZE_FACTOR 8 19 | #define Q_VECTORIZE_FACTOR 8 20 | #define PACK_FACTOR 8 21 | #define WARP_SIZE 32 22 | 23 | 24 | // Reduce sum within the warp using the tree reduction algorithm. 25 | __device__ __forceinline__ float warp_reduce_sum(float sum) { 26 | #pragma unroll 27 | for(int i = 4; i >= 0; i--){ 28 | sum += __shfl_down_sync(0xffffffff, sum, 1< 64 numbers -> 1 group; 1 warp = 16 groups. 
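      // Note on naming: the `zeros` tensor passed to these kernels effectively holds the
      // per-group minimum (mn) produced by the quantizer (quant/gemv.py and quant/new_pack.py
      // pass mn in the `zeros` slot), not a conventional integer zero-point. Dequantization
      // below is therefore w_hat = scaling_factor * q + zeros (i.e. scale * q + mn), matching
      // q = round((w - mn) / scale) on the packing side.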
85 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); 86 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); 87 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; 88 | const float4* inputs_ptr = inputs + inputs_ptr_delta; 89 | // multiply 32 weights with 32 inputs 90 | #pragma unroll 91 | for (int ic_0 = 0; ic_0 < 4; ic_0++){ 92 | // iterate over different uint32_t packed_weights in this loop 93 | uint32_t current_packed_weight = packed_weights[ic_0]; 94 | half packed_inputs[PACK_FACTOR]; 95 | // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) 96 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { 97 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0); 98 | #pragma unroll 99 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ 100 | // iterate over 8 numbers packed within each uint32_t number 101 | float current_single_weight_fp = (float)(current_packed_weight & 0xF); 102 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros; 103 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); 104 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]); 105 | current_packed_weight = current_packed_weight >> 4; 106 | } 107 | } 108 | } 109 | } 110 | psum = warp_reduce_sum(psum); 111 | if (threadIdx.x == 0) { 112 | outputs[oc_idx] = __float2half(psum); 113 | } 114 | } 115 | 116 | 117 | /* 118 | Computes GEMV (group_size = 128). 119 | 120 | Args: 121 | inputs: vector of shape [batch_size, IC]; 122 | weight: matrix of shape [OC, IC / 8]; 123 | output: vector of shape [OC]; 124 | zeros: matrix of shape [OC, IC / group_size / 8]; 125 | scaling_factors: matrix of shape [OC, IC / group_size]; 126 | 127 | Notes: 128 | One cannot infer group_size from the shape of scaling factors. 129 | the second dimension is rounded up to a multiple of PACK_FACTOR. 130 | */ 131 | __global__ void gemv_kernel_g128( 132 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs, 133 | const int IC, const int OC){ 134 | const int group_size = 128; 135 | float psum = 0; 136 | const int batch_idx = blockIdx.z; 137 | const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 138 | const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; 139 | half* outputs = _outputs + batch_idx * OC; 140 | const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); 141 | const int weight_w = IC / PACK_FACTOR; 142 | // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address 143 | const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); 144 | // consistent with input shape 145 | const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; 146 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); 147 | // tile size: 4 OC x 1024 IC per iter 148 | for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ 149 | // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. 
150 | uint32_t packed_weights[4]; 151 | // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) 152 | *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); 153 | // load scaling factors 154 | // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. 155 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); 156 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); 157 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; 158 | const float4* inputs_ptr = inputs + inputs_ptr_delta; 159 | // multiply 32 weights with 32 inputs 160 | #pragma unroll 161 | for (int ic_0 = 0; ic_0 < 4; ic_0++){ 162 | // iterate over different uint32_t packed_weights in this loop 163 | uint32_t current_packed_weight = packed_weights[ic_0]; 164 | half packed_inputs[PACK_FACTOR]; 165 | // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) 166 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { 167 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0); 168 | #pragma unroll 169 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ 170 | // iterate over 8 numbers packed within each uint32_t number 171 | float current_single_weight_fp = (float)(current_packed_weight & 0xF); 172 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros; 173 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); 174 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]); 175 | current_packed_weight = current_packed_weight >> 4; 176 | } 177 | } 178 | } 179 | } 180 | psum = warp_reduce_sum(psum); 181 | if (threadIdx.x == 0) { 182 | outputs[oc_idx] = __float2half(psum); 183 | } 184 | } 185 | 186 | 187 | /* 188 | Computes GEMV (PyTorch interface). 
189 | 190 | Args: 191 | _in_feats: tensor of shape [B, IC]; 192 | _kernel: int tensor of shape [OC, IC // 8]; 193 | _zeros: int tensor of shape [OC, IC // G // 8]; 194 | _scaling_factors: tensor of shape [OC, IC // G]; 195 | blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; 196 | blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; 197 | 198 | Returns: 199 | out_feats: tensor of shape [B, OC]; 200 | */ 201 | torch::Tensor gemv_forward_cuda( 202 | torch::Tensor _in_feats, 203 | torch::Tensor _kernel, 204 | torch::Tensor _scaling_factors, 205 | torch::Tensor _zeros, 206 | const int bit, 207 | const int group_size) 208 | { 209 | int num_in_feats = _in_feats.size(0); 210 | int num_in_channels = _in_feats.size(1); 211 | // int kernel_volume = _out_in_map.size(1); 212 | auto in_feats = reinterpret_cast(_in_feats.data_ptr()); 213 | auto kernel = reinterpret_cast(_kernel.data_ptr()); 214 | auto zeros = reinterpret_cast(_zeros.data_ptr()); 215 | auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); 216 | // auto out_in_map = _out_in_map.data_ptr(); 217 | auto options = 218 | torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); 219 | // kernel is [OC, IC] 220 | at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options); 221 | int num_out_feats = _out_feats.size(-2); 222 | int num_out_channels = _out_feats.size(-1); 223 | auto out_feats = reinterpret_cast(_out_feats.data_ptr()); 224 | int blockDim_z = num_out_feats; 225 | dim3 num_blocks(1, num_out_channels / 4, num_out_feats); 226 | dim3 num_threads(32, 4); 227 | if (group_size == 64) 228 | { 229 | gemv_kernel_g64<<>>( 230 | // pointers 231 | in_feats, kernel, zeros, scaling_factors, out_feats, 232 | // constants 233 | num_in_channels, num_out_channels 234 | ); 235 | } 236 | else if (group_size == 128) 237 | { 238 | gemv_kernel_g128<<>>( 239 | // pointers 240 | in_feats, kernel, zeros, scaling_factors, out_feats, 241 | // constants 242 | num_in_channels, num_out_channels 243 | ); 244 | } 245 | return _out_feats; 246 | ;} 247 | 248 | 249 | 250 | 251 | /* 252 | Computes Batched 4-bit GEMV (group_size = 64). 253 | 254 | Args: 255 | inputs: vector of shape [BS, 1, IC]; 256 | weight: matrix of shape [BS, OC // PACK_FACTOR, IC]; 257 | output: vector of shape [BS, 1, OC]; 258 | zeros: matrix of shape [BS, OC // group_size, IC]; 259 | scaling_factors: matrix of shape [BS, OC // group_size, IC]; 260 | 261 | Notes: 262 | One cannot infer group_size from the shape of scaling factors. 263 | the second dimension is rounded up to a multiple of PACK_FACTOR. 
264 | */ 265 | __global__ void bgemv4_kernel_outer_dim( 266 | const half* _inputs, const uint32_t* _weight, const half* _zeros, const half* _scale, half* _outputs, 267 | const int IC, const int OC, const int group_size, const int nh, const int nh_kv){ 268 | const int bit = 4; 269 | const int pack_factor = 8; 270 | const int batch_idx = blockIdx.x; 271 | const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 272 | const int oc_start_idx = packed_oc_idx * pack_factor; 273 | const int group_idx = oc_start_idx / group_size; 274 | const half* inputs = _inputs + batch_idx * IC; 275 | half* outputs = _outputs + batch_idx * OC; 276 | const int ratio = nh / nh_kv; 277 | int _batch_idx = batch_idx / ratio; 278 | const uint32_t* weight = _weight + _batch_idx * OC * IC / pack_factor; 279 | const half* scaling_factors = _scale + _batch_idx * OC * IC / group_size; 280 | const half* zeros = _zeros + _batch_idx * OC * IC / group_size; 281 | const int TILE_DIM = 128; 282 | const int num = 0xFF >> (8-bit); 283 | const int ICR = IC; 284 | // 1float4 == 8 half number 285 | float psum[pack_factor]{}; 286 | for (int k=0; k < (IC + TILE_DIM - 1) / TILE_DIM; k++){ 287 | uint32_t qw[4]{}; 288 | half cscale[4]{}; 289 | half czero[4]{}; 290 | half inp[4]{}; 291 | // each thread load 32 int4 number 292 | int weight_offset = packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4; 293 | int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4; 294 | int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4; 295 | for (int i=0; i<4; i++){ 296 | if (weight_offset + i < OC * ICR / pack_factor) 297 | qw[i] = *(weight + weight_offset + i); 298 | if (scale_mn_offset + i < OC * ICR / group_size){ 299 | cscale[i] = *(scaling_factors + scale_mn_offset + i); 300 | czero[i] = *(zeros + scale_mn_offset + i);} 301 | if (inputs_ptr_delta + i < ICR) 302 | inp[i] = *(inputs + inputs_ptr_delta + i); 303 | } 304 | // each thread load 32 int4 number 305 | // int weight_offset = packed_oc_idx * IC + k * TILE_DIM + threadIdx.x*4; 306 | // if (weight_offset < OC * IC / pack_factor) 307 | // *((float4*)(qw)) = *((float4*)(weight + packed_oc_idx * IC + k * TILE_DIM + threadIdx.x*4)); 308 | // int scale_mn_offset = group_idx * IC + k * TILE_DIM + threadIdx.x*4; 309 | // if (scale_mn_offset < OC * IC / group_size){ 310 | // *((float2*)(cscale)) = *((float2*)(scaling_factors + scale_mn_offset)); 311 | // *((float2*)(czero)) = *((float2*)(zeros + scale_mn_offset)); 312 | // } 313 | // int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4; 314 | // if (inputs_ptr_delta < IC){ 315 | // const half* inputs_ptr = inputs + inputs_ptr_delta; 316 | // *((float2*)(inp)) = *((float2*)(inputs_ptr)); 317 | // } 318 | // multiply 32 weights with 32 inputs 319 | #pragma unroll 320 | for (int ic_0 = 0; ic_0 < 4; ic_0++){ 321 | uint32_t cur_packed_weight = qw[ic_0]; 322 | float cur_inp = __half2float(inp[ic_0]); 323 | float cur_scale = __half2float(cscale[ic_0]); 324 | float cur_zero = __half2float(czero[ic_0]); 325 | for (int ic_1 = 0; ic_1 < pack_factor; ic_1++){ 326 | int oc_idx = oc_start_idx + ic_1; 327 | if (oc_idx < OC){ 328 | float cur_single_weight_fp = (float)(cur_packed_weight & num); 329 | float dequantized_weight = cur_scale * cur_single_weight_fp + cur_zero; 330 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && k == 1) printf("%d %d %d %f %f %f %f %f\n", k, ic_0, ic_1, dequantized_weight, cur_single_weight_fp, cur_scale, cur_zero, cur_inp); 331 | cur_packed_weight = cur_packed_weight >> bit; 332 | 
psum[ic_1] += dequantized_weight * cur_inp; 333 | } 334 | } 335 | } 336 | } 337 | for (int i=0; i < pack_factor; i++){ 338 | int oc_idx = oc_start_idx + i; 339 | if (oc_idx < OC){ 340 | psum[i] = warp_reduce_sum(psum[i]); 341 | if (threadIdx.x == 0) 342 | outputs[oc_idx] = __float2half(psum[i]); 343 | } 344 | } 345 | } 346 | 347 | 348 | __global__ void bgemv2_kernel_outer_dim( 349 | const half* _inputs, const uint32_t* _weight, const half* _zeros, const half* _scale, half* _outputs, 350 | const int IC, const int OC, const int group_size, const int nh, const int nh_kv){ 351 | // const int group_size = 64; 352 | const int bit = 2; 353 | const int pack_factor = 16; 354 | const int batch_idx = blockIdx.x; 355 | const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 356 | const int oc_start_idx = packed_oc_idx * pack_factor; 357 | const int group_idx = oc_start_idx / group_size; 358 | const int ICR = IC; 359 | const half* inputs = _inputs + batch_idx * ICR; 360 | half* outputs = _outputs + batch_idx * OC; 361 | const int ratio = nh / nh_kv; 362 | int _batch_idx = batch_idx / ratio; 363 | const uint32_t* weight = _weight + _batch_idx * OC * IC / pack_factor; 364 | const half* scaling_factors = _scale + _batch_idx * OC * IC / group_size; 365 | const half* zeros = _zeros + _batch_idx * OC * IC / group_size; 366 | const int TILE_DIM = 128; 367 | const int num = 0xFF >> (8-bit); 368 | // 1float4 == 8 half number 369 | float psum[pack_factor]{}; 370 | for (int k=0; k < (ICR + TILE_DIM - 1) / TILE_DIM; k++){ 371 | uint32_t qw[4]{}; 372 | half cscale[4]{}; 373 | half czero[4]{}; 374 | half inp[4]{}; 375 | // each thread load 32 int4 number 376 | int weight_offset = packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4; 377 | int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4; 378 | int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4; 379 | for (int i=0; i<4; i++){ 380 | if (weight_offset + i < OC * ICR / pack_factor) 381 | qw[i] = *(weight + weight_offset + i); 382 | if (scale_mn_offset + i < OC * ICR / group_size){ 383 | cscale[i] = *(scaling_factors + scale_mn_offset + i); 384 | czero[i] = *(zeros + scale_mn_offset + i);} 385 | if (inputs_ptr_delta + i < ICR) 386 | inp[i] = *(inputs + inputs_ptr_delta + i); 387 | } 388 | // if (weight_offset < OC * ICR / pack_factor) 389 | // *((float4*)(qw)) = *((float4*)(weight + packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4)); 390 | // int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4; 391 | // if (scale_mn_offset < OC * ICR / group_size){ 392 | // *((float2*)(cscale)) = *((float2*)(scaling_factors + scale_mn_offset)); 393 | // *((float2*)(czero)) = *((float2*)(zeros + scale_mn_offset)); 394 | // } 395 | // int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4; 396 | // if (inputs_ptr_delta < ICR){ 397 | // const half* inputs_ptr = inputs + inputs_ptr_delta; 398 | // *((float2*)(inp)) = *((float2*)(inputs_ptr)); 399 | // } 400 | // multiply 32 weights with 32 inputs 401 | #pragma unroll 402 | for (int ic_0 = 0; ic_0 < 4; ic_0++){ 403 | uint32_t cur_packed_weight = qw[ic_0]; 404 | float cur_inp = __half2float(inp[ic_0]); 405 | float cur_scale = __half2float(cscale[ic_0]); 406 | float cur_zero = __half2float(czero[ic_0]); 407 | for (int ic_1 = 0; ic_1 < pack_factor; ic_1++){ 408 | int oc_idx = oc_start_idx + ic_1; 409 | if (oc_idx < OC){ 410 | float cur_single_weight_fp = (float)(cur_packed_weight & num); 411 | float dequantized_weight = cur_scale * cur_single_weight_fp + cur_zero; 412 | // if(blockIdx.x == 0 && 
blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && k == 1) printf("%d %d %d %f %f %f %f %f\n", k, ic_0, ic_1, dequantized_weight, cur_single_weight_fp, cur_scale, cur_zero, cur_inp); 413 | cur_packed_weight = cur_packed_weight >> bit; 414 | psum[ic_1] += dequantized_weight * cur_inp; 415 | } 416 | } 417 | } 418 | } 419 | for (int i=0; i < pack_factor; i++){ 420 | int oc_idx = oc_start_idx + i; 421 | if (oc_idx < OC){ 422 | psum[i] = warp_reduce_sum(psum[i]); 423 | if (threadIdx.x == 0) 424 | outputs[oc_idx] = __float2half(psum[i]); 425 | } 426 | } 427 | } 428 | 429 | // __global__ void bgemv2_kernel_g64_outer_dim( 430 | // const half* _inputs, const uint32_t* _weight, const half* _zeros, const half* _scale, half* _outputs, 431 | // const int IC, const int OC){ 432 | // const int group_size = 64; 433 | // const int bit = 2; 434 | // const int pack_factor = 16; 435 | // const int batch_idx = blockIdx.x; 436 | // const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 437 | // const int oc_start_idx = packed_oc_idx * pack_factor; 438 | // const int group_idx = oc_start_idx / group_size; 439 | // const int ICR = IC; 440 | // const half* inputs = _inputs + batch_idx * ICR; 441 | // half* outputs = _outputs + batch_idx * OC; 442 | // const uint32_t* weight = _weight + batch_idx * OC * IC / pack_factor; 443 | // const half* scaling_factors = _scale + batch_idx * OC * IC / group_size; 444 | // const half* zeros = _zeros + batch_idx * OC * IC / group_size; 445 | // const int TILE_DIM = 128; 446 | // const int num = 0xFF >> (8-bit); 447 | // // 1float4 == 8 half number 448 | // float psum[pack_factor]{}; 449 | // for (int k=0; k < (ICR + TILE_DIM - 1) / TILE_DIM; k++){ 450 | // uint32_t qw[4]{}; 451 | // half cscale[4]{}; 452 | // half czero[4]{}; 453 | // half inp[4]{}; 454 | // // each thread load 32 int4 number 455 | // int weight_offset = packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4; 456 | // if (weight_offset < OC * ICR / pack_factor) 457 | // *((float4*)(qw)) = *((float4*)(weight + packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4)); 458 | // int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4; 459 | // if (scale_mn_offset < OC * ICR / group_size){ 460 | // *((float2*)(cscale)) = *((float2*)(scaling_factors + scale_mn_offset)); 461 | // *((float2*)(czero)) = *((float2*)(zeros + scale_mn_offset)); 462 | // } 463 | // int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4; 464 | // if (inputs_ptr_delta < ICR){ 465 | // const half* inputs_ptr = inputs + inputs_ptr_delta; 466 | // *((float2*)(inp)) = *((float2*)(inputs_ptr)); 467 | // } 468 | // // multiply 32 weights with 32 inputs 469 | // #pragma unroll 470 | // for (int ic_0 = 0; ic_0 < 4; ic_0++){ 471 | // uint32_t cur_packed_weight = qw[ic_0]; 472 | // float cur_inp = __half2float(inp[ic_0]); 473 | // float cur_scale = __half2float(cscale[ic_0]); 474 | // float cur_zero = __half2float(czero[ic_0]); 475 | // for (int ic_1 = 0; ic_1 < pack_factor; ic_1++){ 476 | // int oc_idx = oc_start_idx + ic_1; 477 | // if (oc_idx < OC){ 478 | // float cur_single_weight_fp = (float)(cur_packed_weight & num); 479 | // float dequantized_weight = cur_scale * cur_single_weight_fp + cur_zero; 480 | // // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && k == 1) printf("%d %d %d %f %f %f %f %f\n", k, ic_0, ic_1, dequantized_weight, cur_single_weight_fp, cur_scale, cur_zero, cur_inp); 481 | // cur_packed_weight = cur_packed_weight >> bit; 482 | // psum[ic_1] += dequantized_weight * cur_inp; 
483 | // } 484 | // } 485 | // } 486 | // } 487 | // for (int i=0; i < pack_factor; i++){ 488 | // int oc_idx = oc_start_idx + i; 489 | // if (oc_idx < OC){ 490 | // psum[i] = warp_reduce_sum(psum[i]); 491 | // if (threadIdx.x == 0) 492 | // outputs[oc_idx] = __float2half(psum[i]); 493 | // } 494 | // } 495 | // } 496 | 497 | 498 | /* 499 | Computes GEMV (PyTorch interface). 500 | 501 | Args: 502 | _in_feats: tensor of shape [B, IC]; 503 | _kernel: int tensor of shape [OC // PACK_Factor, IC]; 504 | _zeros: int tensor of shape [OC // G, IC]; 505 | _scaling_factors: tensor of shape [OC // G, IC]; 506 | blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; 507 | blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; 508 | Returns: 509 | out_feats: tensor of shape [B, OC]; 510 | */ 511 | torch::Tensor gemv_forward_cuda_outer_dim( 512 | torch::Tensor _in_feats, 513 | torch::Tensor _kernel, 514 | torch::Tensor _scaling_factors, 515 | torch::Tensor _zeros, 516 | const int bit, 517 | const int group_size, 518 | const int nh, 519 | const int nh_kv) 520 | { 521 | int BS = _in_feats.size(0); 522 | int num_in_feats = _in_feats.size(1); 523 | int num_in_channels = _in_feats.size(2); 524 | int num_out_channels = _zeros.size(1) * group_size; 525 | // int kernel_volume = _out_in_map.size(1); 526 | auto in_feats = reinterpret_cast(_in_feats.data_ptr()); 527 | auto kernel = reinterpret_cast(_kernel.data_ptr()); 528 | auto zeros = reinterpret_cast(_zeros.data_ptr()); 529 | auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); 530 | // auto out_in_map = _out_in_map.data_ptr(); 531 | auto options = 532 | torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); 533 | // kernel is [OC, IC] 534 | at::Tensor _out_feats = torch::empty({BS, num_in_feats, num_out_channels}, options); 535 | int num_out_feats = _out_feats.size(-2); 536 | auto out_feats = reinterpret_cast(_out_feats.data_ptr()); 537 | int pack_factor = 32 / bit; 538 | dim3 num_blocks(BS, (num_out_channels / pack_factor + 3) / 4, num_out_feats); 539 | dim3 num_threads(32, 4); 540 | if (bit == 4){ 541 | bgemv4_kernel_outer_dim<<>>( 542 | // pointers 543 | in_feats, kernel, zeros, scaling_factors, out_feats, 544 | // constants 545 | num_in_channels, num_out_channels, group_size, nh, nh_kv 546 | );} 547 | else{ 548 | // note: in this case, pack factor == 16 549 | bgemv2_kernel_outer_dim<<>>( 550 | // pointers 551 | in_feats, kernel, zeros, scaling_factors, out_feats, 552 | // constants 553 | num_in_channels, num_out_channels, group_size, nh, nh_kv 554 | ); 555 | } 556 | return _out_feats; 557 | ;} 558 | -------------------------------------------------------------------------------- /quant/csrc/gemv_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | torch::Tensor gemv_forward_cuda( 5 | torch::Tensor _in_feats, 6 | torch::Tensor _kernel, 7 | torch::Tensor _scaling_factors, 8 | torch::Tensor _zeros, 9 | const int bit, 10 | const int group_size); 11 | 12 | 13 | torch::Tensor gemv_forward_cuda_outer_dim( 14 | torch::Tensor _in_feats, 15 | torch::Tensor _kernel, 16 | torch::Tensor _scaling_factors, 17 | torch::Tensor _zeros, 18 | const int bit, 19 | const int group_size, 20 | const int nh, 21 | const int nh_kv); -------------------------------------------------------------------------------- /quant/csrc/gemv_cuda_backup.cu: 
-------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/ankan-ban/llama_cu_awq 2 | /* 3 | 4 | @article{lin2023awq, 5 | title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, 6 | author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, 7 | journal={arXiv}, 8 | year={2023} 9 | } 10 | 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include "gemv_cuda.h" 17 | #define VECTORIZE_FACTOR 8 18 | #define Q_VECTORIZE_FACTOR 8 19 | #define WARP_SIZE 32 20 | 21 | 22 | // Reduce sum within the warp using the tree reduction algorithm. 23 | __device__ __forceinline__ float warp_reduce_sum(float sum) { 24 | #pragma unroll 25 | for(int i = 4; i >= 0; i--){ 26 | sum += __shfl_down_sync(0xffffffff, sum, 1< 59 | __global__ void gemv_kernel_g64( 60 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs, 61 | const int IC, const int OC){ 62 | const int group_size = 64; 63 | float psum = 0; 64 | const int batch_idx = blockIdx.z; 65 | const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 66 | const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; 67 | half* outputs = _outputs + batch_idx * OC; 68 | const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2; 69 | const int weight_w = IC / PACK_FACTOR; 70 | // consistent with input shape 71 | const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR; 72 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w); 73 | int elem_per_th = 128 / bit; 74 | int ng_per_warp = 32 * elem_per_th / 64; 75 | // tile size: 4 OC x (128 * PACK_FACTOR) IC per iter 76 | for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){ 77 | uint32_t packed_weights[4]; 78 | // use float4 to load weights, each thread load (64,32,16) int-(2,4,8) numbers (1 x float4) 79 | *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); 80 | // load scaling factors 81 | // 1 warp == 32 threads 82 | // g64: 1 threads -> 64,32,16 numbers -> 1,.5,0.25 group; 1 warp = 32,16,8 groups. 83 | // TODO: from here 84 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && packed_group_idx == 0) printf("%d %d\n", elem_per_th, ng_per_warp); 85 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * ng_per_warp + (threadIdx.x*ng_per_warp/32)]); 86 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * ng_per_warp + (threadIdx.x*ng_per_warp/32)]); 87 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; 88 | const float4* inputs_ptr = inputs + inputs_ptr_delta; 89 | const int num = 0xFF >> (8-bit); 90 | // multiply (64,32,16) weights with (64,32,16) inputs 91 | #pragma unroll 92 | for (int ic_0 = 0; ic_0 < 4; ic_0++){ 93 | // iterate over different uint32_t packed_weights in this loop 94 | uint32_t current_packed_weight = packed_weights[ic_0]; 95 | half packed_inputs[PACK_FACTOR]; 96 | // each thread load (16,8,4) inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) 97 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { 98 | // TODO: bug is here!! 
for 4-bit, one float4 == 8 half number == packed_inputs[8] 99 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0); 100 | #pragma unroll 101 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ 102 | // iterate over (16,8,4) numbers packed within each uint32_t number 103 | float current_single_weight_fp = (float)(current_packed_weight & num); 104 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros; 105 | if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 1 && packed_group_idx == 0) printf("%f %f %f %f %f\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, __half2float(packed_inputs[ic_1])); 106 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]); 107 | current_packed_weight = current_packed_weight >> bit; 108 | } 109 | } 110 | } 111 | } 112 | psum = warp_reduce_sum(psum); 113 | if (threadIdx.x == 0) { 114 | outputs[oc_idx] = __float2half(psum); 115 | } 116 | } 117 | 118 | 119 | /* 120 | Computes GEMV (group_size = 128). 121 | 122 | Args: 123 | inputs: vector of shape [batch_size, IC]; 124 | weight: matrix of shape [OC, IC / 8]; 125 | output: vector of shape [OC]; 126 | zeros: matrix of shape [OC, IC / group_size / 8]; 127 | scaling_factors: matrix of shape [OC, IC / group_size]; 128 | 129 | Notes: 130 | One cannot infer group_size from the shape of scaling factors. 131 | the second dimension is rounded up to a multiple of PACK_FACTOR. 132 | */ 133 | template 134 | __global__ void gemv_kernel_g128( 135 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs, 136 | const int IC, const int OC){ 137 | const int group_size = 128; 138 | float psum = 0; 139 | const int batch_idx = blockIdx.z; 140 | const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 141 | const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; 142 | half* outputs = _outputs + batch_idx * OC; 143 | const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); 144 | const int weight_w = IC / PACK_FACTOR; 145 | // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address 146 | const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); 147 | // consistent with input shape 148 | const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; 149 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); 150 | // tile size: 4 OC x 1024 IC per iter 151 | for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ 152 | // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. 153 | uint32_t packed_weights[4]; 154 | // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) 155 | *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); 156 | // load scaling factors 157 | // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. 
158 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); 159 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); 160 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; 161 | const float4* inputs_ptr = inputs + inputs_ptr_delta; 162 | const int num = 0xFF >> (8-bit); 163 | // multiply 32 weights with 32 inputs 164 | #pragma unroll 165 | for (int ic_0 = 0; ic_0 < 4; ic_0++){ 166 | // iterate over different uint32_t packed_weights in this loop 167 | uint32_t current_packed_weight = packed_weights[ic_0]; 168 | half packed_inputs[PACK_FACTOR]; 169 | // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) 170 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { 171 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0); 172 | #pragma unroll 173 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ 174 | // iterate over 8 numbers packed within each uint32_t number 175 | float current_single_weight_fp = (float)(current_packed_weight & num); 176 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros; 177 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); 178 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]); 179 | current_packed_weight = current_packed_weight >> bit; 180 | } 181 | } 182 | } 183 | } 184 | psum = warp_reduce_sum(psum); 185 | if (threadIdx.x == 0) { 186 | outputs[oc_idx] = __float2half(psum); 187 | } 188 | } 189 | 190 | 191 | /* 192 | Computes GEMV (PyTorch interface). 
193 | 194 | Args: 195 | _in_feats: tensor of shape [B, IC]; 196 | _kernel: int tensor of shape [OC, IC // 8]; 197 | _zeros: int tensor of shape [OC, IC // G // 8]; 198 | _scaling_factors: tensor of shape [OC, IC // G]; 199 | blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; 200 | blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; 201 | 202 | Returns: 203 | out_feats: tensor of shape [B, OC]; 204 | */ 205 | torch::Tensor gemv_forward_cuda( 206 | torch::Tensor _in_feats, 207 | torch::Tensor _kernel, 208 | torch::Tensor _scaling_factors, 209 | torch::Tensor _zeros, 210 | const int bit, 211 | const int group_size) 212 | { 213 | int num_in_feats = _in_feats.size(0); 214 | int num_in_channels = _in_feats.size(1); 215 | // int kernel_volume = _out_in_map.size(1); 216 | auto in_feats = reinterpret_cast(_in_feats.data_ptr()); 217 | auto kernel = reinterpret_cast(_kernel.data_ptr()); 218 | auto zeros = reinterpret_cast(_zeros.data_ptr()); 219 | auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); 220 | // auto out_in_map = _out_in_map.data_ptr(); 221 | auto options = 222 | torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); 223 | // kernel is [OC, IC] 224 | at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options); 225 | int num_out_feats = _out_feats.size(-2); 226 | int num_out_channels = _out_feats.size(-1); 227 | auto out_feats = reinterpret_cast(_out_feats.data_ptr()); 228 | int blockDim_z = num_out_feats; 229 | dim3 num_blocks(1, num_out_channels / 4, num_out_feats); 230 | dim3 num_threads(32, 4); 231 | if (bit == 2){ 232 | if (group_size == 64) 233 | { 234 | gemv_kernel_g64<2, 16><<>>( 235 | // pointers 236 | in_feats, kernel, zeros, scaling_factors, out_feats, 237 | // constants 238 | num_in_channels, num_out_channels 239 | ); 240 | } 241 | else if (group_size == 128) 242 | { 243 | gemv_kernel_g128<2, 16><<>>( 244 | // pointers 245 | in_feats, kernel, zeros, scaling_factors, out_feats, 246 | // constants 247 | num_in_channels, num_out_channels 248 | );} 249 | }else if (bit == 4){ 250 | if (group_size == 64) 251 | { 252 | gemv_kernel_g64<4, 8><<>>( 253 | // pointers 254 | in_feats, kernel, zeros, scaling_factors, out_feats, 255 | // constants 256 | num_in_channels, num_out_channels 257 | ); 258 | } 259 | else if (group_size == 128) 260 | { 261 | gemv_kernel_g128<4, 8><<>>( 262 | // pointers 263 | in_feats, kernel, zeros, scaling_factors, out_feats, 264 | // constants 265 | num_in_channels, num_out_channels 266 | ); 267 | };} 268 | else{ 269 | if (group_size == 64) 270 | { 271 | gemv_kernel_g64<8, 4><<>>( 272 | // pointers 273 | in_feats, kernel, zeros, scaling_factors, out_feats, 274 | // constants 275 | num_in_channels, num_out_channels 276 | ); 277 | } 278 | else if (group_size == 128) 279 | { 280 | gemv_kernel_g128<8, 4><<>>( 281 | // pointers 282 | in_feats, kernel, zeros, scaling_factors, out_feats, 283 | // constants 284 | num_in_channels, num_out_channels 285 | ); 286 | }} 287 | return _out_feats; 288 | } 289 | 290 | /* 291 | Computes GEMV (group_size = 64). 292 | 293 | Args: 294 | inputs: vector of shape [batch_size, IC]; 295 | weight: matrix of shape [OC // PACK_FACTOR, IC; 296 | output: vector of shape [OC]; 297 | zeros: matrix of shape [OC // group_size, IC]; 298 | scaling_factors: matrix of shape [OC // group_size, IC]; 299 | 300 | Notes: 301 | One cannot infer group_size from the shape of scaling factors. 
302 | the second dimension is rounded up to a multiple of PACK_FACTOR. 303 | */ 304 | __global__ void gemv_kernel_g64_outer_dim( 305 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs, 306 | const int IC, const int OC){ 307 | const int group_size = 64; 308 | float psum = 0; 309 | const int pack_factor = 8; 310 | const int batch_idx = blockIdx.z; 311 | const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y; 312 | const int group_idx = packed_oc_idx * pack_factor / group_size; 313 | const float4* inputs = _inputs + batch_idx * IC; 314 | half* outputs = _outputs + batch_idx * OC; 315 | const int TILE_DIM = 32; 316 | extern __shared__ uint32_t packed_weight_shared[TILE_DIM]; 317 | extern __shared__ float scale_shared[TILE_DIM]; 318 | extern __shared__ float mn_shared[TILE_DIM]; 319 | for (int k=0; k < (IC + TILE_DIM - 1) / TILE_DIM; k++){ 320 | if (packed_oc_idx * pack_factor < OC && k*TILE_DIM+threadIdx.x < IC) 321 | packed_weight_shared[threadIdx.x] = weight[packed_oc_idx * IC + k * TILE_DIM + threadIdx.x]; 322 | else 323 | packed_weight_shared[threadIdx.x] = 0; 324 | if (group_idx * group_size < OC && k*TILE_DIM+threadIdx.x < IC){ 325 | scale_shared[threadIdx.x] = __half2float(scaling_factors[oc_idx / group_size * IC + k * TILE_DIM + threadIdx.x]); 326 | mn_shared[threadIdx.x] = __half2float(zeros[oc_idx / group_size * IC + k * TILE_DIM + threadIdx.x]); 327 | } 328 | else{ 329 | scale_shared[threadIdx.x] = 0.0; 330 | mn_shared[threadIdx.x] = 0.0; 331 | } 332 | __syncthreads(); 333 | } 334 | } 335 | -------------------------------------------------------------------------------- /quant/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "gemv_cuda.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 6 | { 7 | m.def("gemv_forward_cuda", &gemv_forward_cuda); 8 | m.def("gemv_forward_cuda_outer_dim", &gemv_forward_cuda_outer_dim); 9 | } -------------------------------------------------------------------------------- /quant/gemv.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="2" 3 | os.environ["CUDA_LAUNCH_BLOCKING"]="1" 4 | import numpy as np 5 | import torch 6 | # import ipdb 7 | import random 8 | import triton 9 | import triton.language as tl 10 | from new_pack import pack_tensor 11 | from timeit_v2 import py_benchmark 12 | import kivi_gemv 13 | 14 | B, nh, IC, OC = 8, 32, 739, 128 15 | 16 | @triton.jit 17 | def gemv_kernel_g64(inputs_ptr, qw_ptr, mn_ptr, 18 | scale_ptr, output_ptr, 19 | IC: tl.constexpr, OC: tl.constexpr, bit: tl.constexpr, 20 | OC_PER_PH: tl.constexpr, PACK_FACTOR: tl.constexpr, BLOCK_SIZE): 21 | """ 22 | Computes GEMV (group_size = 64). 23 | 24 | Args: 25 | inputs: vector of shape [batch_size, IC]; 26 | qw: matrix of shape [OC, IC / 8]; 27 | output: vector of shape [OC]; 28 | mn: matrix of shape [OC, NG]; 29 | scale: matrix of shape [OC, NG]; 30 | 31 | Notes: 32 | One cannot infer group_size from the shape of scaling factors. 33 | the second dimension is rounded up to a multiple of PACK_FACTOR. 
34 | """ 35 | group_size = 64 36 | oc_idx = tl.program_id(axis=0) * OC_PER_PH + tl.arange(0, OC_PER_PH) 37 | batch_idx = tl.program_id(axis=1) 38 | num_groups = IC // group_size 39 | num_groups_packed = tl.cdiv(num_groups, PACK_FACTOR) 40 | # tl.store(output_ptr, num_groups_packed) 41 | weight_w = IC // PACK_FACTOR 42 | num = 0xFF >> (8-bit) 43 | accumulator = tl.zeros((OC_PER_PH,), dtype=tl.float32) 44 | for group_idx in range(0, num_groups): 45 | # load scaling factors 46 | # each time we load 4 OC x 1 G 47 | scale = tl.load(scale_ptr + oc_idx[:, None] * num_groups + group_idx) 48 | mn = tl.load(mn_ptr + oc_idx[:, None] * num_groups + group_idx) 49 | # 1 G -> 64 numbers -> 64 // PACK_FACTOR packed numbers 50 | cur_qw_ptr = qw_ptr + oc_idx[:, None] * weight_w + group_idx * (64 // PACK_FACTOR) + tl.arange(0, 64 // PACK_FACTOR)[None, :] 51 | qw = tl.load(cur_qw_ptr) 52 | for i in range(PACK_FACTOR): 53 | w_fp = qw & num 54 | # load 4 OC x 55 | w_fp = w_fp * scale + mn 56 | qw = qw >> bit 57 | cur_inp_ptr = inputs_ptr + batch_idx * IC + group_idx * 64 + i + tl.arange(0, 64 // PACK_FACTOR)[None, :] * PACK_FACTOR 58 | cur_input = tl.load(cur_inp_ptr) 59 | accumulator += tl.sum(cur_input * w_fp, 1) 60 | ptr = output_ptr + oc_idx + batch_idx * OC 61 | tl.store(ptr, accumulator) 62 | 63 | 64 | def dequant_weight(w, scale, mn, gs): 65 | w_fp = w.half().view(w.shape[0], w.shape[1]//gs, gs) 66 | w_fp = w_fp * scale.unsqueeze(-1) + mn.unsqueeze(-1) 67 | return w_fp.view(w.shape) 68 | 69 | 70 | def dequant_weight_outer(w, scale, mn, gs): 71 | # ipdb.set_trace() 72 | w_fp = w.half().view(w.shape[0], w.shape[1], w.shape[2]//gs, gs) 73 | w_fp = w_fp * scale.unsqueeze(-1) + mn.unsqueeze(-1) 74 | return w_fp.view(w.shape) 75 | 76 | 77 | def gemv_fwd(bit, group_size, inp, qweight, mn, scale): 78 | B, IC = inp.shape 79 | OC = qweight.shape[0] 80 | BLOCK_SIZE = 32 81 | OC_PER_PH = 32 82 | PACK_FACTOR = 32 // bit 83 | assert group_size == 64 84 | output = torch.empty((B, OC), device=inp.device, dtype=torch.float16) 85 | grid = lambda META: ( 86 | triton.cdiv(OC, META['OC_PER_PH']), B 87 | ) 88 | gemv_kernel_g64[grid](inp, qweight, mn, scale, output, 89 | IC, OC, bit, OC_PER_PH, PACK_FACTOR, BLOCK_SIZE) 90 | return output 91 | 92 | 93 | def test_bgemv_outer_correct_mha(): 94 | flatten_B = B * nh 95 | inp = torch.randn((flatten_B, 1, IC), device='cuda', dtype=torch.float16) 96 | ori_weight = torch.randn((flatten_B, IC, OC), device='cuda', dtype=torch.float16) 97 | GS = 32 98 | for BIT in [2, 4]: 99 | weight = ori_weight 100 | PACK_FACTOR = 32 // BIT 101 | assert OC % GS == 0 and OC % PACK_FACTOR == 0 102 | NG = OC // GS 103 | weight = weight.view(flatten_B, IC, NG, GS) 104 | mx = torch.max(weight, dim=-1, keepdim=False)[0] 105 | mn = torch.min(weight, dim=-1, keepdim=False)[0] 106 | maxq = 2 ** BIT - 1 107 | scale = (mx - mn) / maxq 108 | weight = weight - mn.unsqueeze(-1) 109 | weight.div_(scale.unsqueeze(-1)) 110 | weight = weight.clamp_(0, maxq).round_().to(torch.int32) 111 | weight = weight.view(flatten_B, IC, OC) 112 | qweight = pack_tensor(weight, BIT, 2) 113 | weight = weight.transpose(1, 2).contiguous() 114 | qweight = qweight.transpose(1, 2).contiguous() 115 | scale = scale.transpose(1, 2).contiguous() 116 | mn = mn.transpose(1, 2).contiguous() 117 | output = kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, mn, BIT, GS, nh, False) 118 | deq_w = dequant_weight_outer(weight.transpose(1, 2), 119 | scale.transpose(1, 2), 120 | mn.transpose(1, 2), GS) 121 | # rel_error = torch.abs((deq_w - 
ori_weight).float() / (ori_weight + 1e-5).float()).mean() 122 | # print(f'bit {BIT} avg rel weight quant error: {rel_error}') 123 | output_ref = inp @ deq_w 124 | error = output_ref - output 125 | rel_out_error = torch.abs(error.float() / (torch.abs(output_ref).float()+1e-5)).mean() 126 | print(f'mha bit {BIT} avg rel out quant error: {rel_out_error}') 127 | 128 | 129 | def test_bgemv_outer_correct_mqa(): 130 | flatten_B = B * nh 131 | inp = torch.randn((flatten_B, 1, IC), device='cuda', dtype=torch.float16) 132 | ori_weight = torch.randn((B, IC, OC), device='cuda', dtype=torch.float16) 133 | GS = 32 134 | for BIT in [2, 4]: 135 | weight = ori_weight 136 | PACK_FACTOR = 32 // BIT 137 | assert OC % GS == 0 and OC % PACK_FACTOR == 0 138 | NG = OC // GS 139 | weight = weight.view(B, IC, NG, GS) 140 | mx = torch.max(weight, dim=-1, keepdim=False)[0] 141 | mn = torch.min(weight, dim=-1, keepdim=False)[0] 142 | maxq = 2 ** BIT - 1 143 | scale = (mx - mn) / maxq 144 | weight = weight - mn.unsqueeze(-1) 145 | weight.div_(scale.unsqueeze(-1)) 146 | weight = weight.clamp_(0, maxq).round_().to(torch.int32) 147 | weight = weight.view(B, IC, OC) 148 | qweight = pack_tensor(weight, BIT, 2) 149 | inp = inp.contiguous() 150 | weight = weight.transpose(1, 2).contiguous() 151 | qweight = qweight.transpose(1, 2).contiguous() 152 | scale = scale.transpose(1, 2).contiguous() 153 | mn = mn.transpose(1, 2).contiguous() 154 | output = kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, mn, BIT, GS, nh, True) 155 | deq_w = dequant_weight_outer(weight.transpose(1, 2), 156 | scale.transpose(1, 2), 157 | mn.transpose(1, 2), GS) 158 | # rel_error = torch.abs((deq_w - ori_weight).float() / (ori_weight + 1e-5).float()).mean() 159 | # print(f'bit {BIT} avg rel weight quant error: {rel_error}') 160 | output_ref = inp.view(B, nh, 1, IC) @ deq_w.view(B, 1, IC, OC) 161 | output_ref = output_ref.view(flatten_B, 1, OC) 162 | error = output_ref - output 163 | # ipdb.set_trace() 164 | rel_out_error = torch.abs(error.float() / (torch.abs(output_ref).float()+1e-5)).mean() 165 | print(f'mqa bit {BIT} avg rel out quant error: {rel_out_error}') 166 | 167 | 168 | def test_gemv_correct(): 169 | inp = torch.randn((B, IC), device='cuda', dtype=torch.float16) 170 | ori_weight = torch.randn((OC, IC), device='cuda', dtype=torch.float16) 171 | GS = 64 172 | for BIT in [4]: 173 | weight = ori_weight 174 | PACK_FACTOR = 32 // BIT 175 | assert IC % GS == 0 and IC % PACK_FACTOR == 0 176 | NG = IC // GS 177 | weight = weight.view(OC, NG, GS) 178 | mx = torch.max(weight, dim=2, keepdim=False)[0] 179 | mn = torch.min(weight, dim=2, keepdim=False)[0] 180 | maxq = 2 ** BIT - 1 181 | scale = (mx - mn) / maxq 182 | weight = weight - mn.unsqueeze(-1) 183 | weight.div_(scale.unsqueeze(-1)) 184 | weight = weight.clamp_(0, maxq).round_().to(torch.int32) 185 | weight = weight.view(OC, IC) 186 | qweight = pack_tensor(weight, BIT, 1) 187 | # output = gemv_fwd(BIT, GS, inp, qweight, mn, scale) 188 | output = kivi_gemv.gemv_forward_cuda(inp, qweight, scale, mn, BIT, GS) 189 | deq_w = dequant_weight(weight, scale, mn, GS) 190 | rel_error = torch.abs((deq_w - ori_weight).float() / (ori_weight + 1e-5).float()).mean() 191 | # print(f'bit {BIT} avg rel weight quant error: {rel_error}') 192 | output_ref = inp @ deq_w.T 193 | error = output_ref - output 194 | rel_out_error = torch.abs(error.float() / (output_ref + 1e-5).float()).mean() 195 | print(f'bit {BIT} avg rel out quant error: {rel_out_error}') 196 | 197 | 198 | def test_gemv_speed(): 199 | inp = 
torch.randn((B, IC), device='cuda', dtype=torch.float16) 200 | ori_weight = torch.randn((OC, IC), device='cuda', dtype=torch.float16) 201 | weight = ori_weight 202 | BIT = 4 203 | GS = 64 204 | PACK_FACTOR = 32 // BIT 205 | assert IC % GS == 0 and IC % PACK_FACTOR == 0 206 | NG = IC // GS 207 | weight = weight.view(OC, NG, GS) 208 | mx = torch.max(weight, dim=2, keepdim=False)[0] 209 | mn = torch.min(weight, dim=2, keepdim=False)[0] 210 | maxq = 2 ** BIT - 1 211 | scale = (mx - mn) / maxq 212 | weight = weight - mn.unsqueeze(-1) 213 | weight.div_(scale.unsqueeze(-1)) 214 | weight = weight.clamp_(0, maxq).round_().to(torch.int32) 215 | weight = weight.view(OC, IC) 216 | qweight = pack_tensor(weight, BIT, 1) 217 | output = gemv_fwd(BIT, GS, inp, qweight, mn, scale) 218 | deq_w = dequant_weight(weight, scale, mn, GS) 219 | stmt = "inp @ deq_w.T" 220 | t_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1, 221 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 222 | # stmt = "gemv_fwd(BIT, GS, inp, qweight, mn, scale)" 223 | # t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1, 224 | # setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 225 | stmt = "kivi_gemv.gemv_forward_cuda(inp, qweight, scale, mn, BIT, GS)" 226 | t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1, 227 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 228 | print(f'vanilla pytorch gemv: {t_ref * 1000} ms') 229 | print(f'awq fused IC {IC} OC {OC} {BIT}-bit gemv: {t_our * 1000} ms') 230 | 231 | 232 | def test_bgemv_outer_speed(): 233 | inp = torch.randn((B, 1, IC), device='cuda', dtype=torch.float16) 234 | ori_weight = torch.randn((B, IC, OC), device='cuda', dtype=torch.float16) 235 | GS = 64 236 | for BIT in [2]: 237 | weight = ori_weight 238 | PACK_FACTOR = 32 // BIT 239 | assert OC % GS == 0 and OC % PACK_FACTOR == 0 240 | NG = OC // GS 241 | weight = weight.view(B, IC, NG, GS) 242 | mx = torch.max(weight, dim=-1, keepdim=False)[0] 243 | mn = torch.min(weight, dim=-1, keepdim=False)[0] 244 | maxq = 2 ** BIT - 1 245 | scale = (mx - mn) / maxq 246 | weight = weight - mn.unsqueeze(-1) 247 | weight.div_(scale.unsqueeze(-1)) 248 | weight = weight.clamp_(0, maxq).round_().to(torch.int32) 249 | weight = weight.view(B, IC, OC) 250 | qweight = pack_tensor(weight, BIT, 2) 251 | weight = weight.transpose(1, 2).contiguous() 252 | qweight = qweight.transpose(1, 2).contiguous() 253 | scale = scale.transpose(1, 2).contiguous() 254 | mn = mn.transpose(1, 2).contiguous() 255 | deq_w = dequant_weight_outer(weight.transpose(1, 2), 256 | scale.transpose(1, 2), 257 | mn.transpose(1, 2), GS) 258 | stmt = "inp @ deq_w" 259 | t_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1, 260 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 261 | # stmt = "gemv_fwd(BIT, GS, inp, qweight, mn, scale)" 262 | # t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1, 263 | # setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 264 | stmt = "kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, mn, BIT, GS)" 265 | t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1, 266 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 267 | print(f'BS {B} IC {IC} OC {OC} pytorch batched gemv: {t_ref * 1000} ms') 268 | print(f'our fused BS {B} IC {IC} OC {OC} {BIT}-bit outer-dim batched gemv: {t_our * 1000} ms') 269 | 270 | if 
__name__ == "__main__": 271 | torch.manual_seed(0) 272 | np.random.seed(0) 273 | random.seed(0) 274 | # test_gemv_correct() 275 | test_bgemv_outer_correct_mha() 276 | test_bgemv_outer_correct_mqa() 277 | # test_gemv_speed() 278 | # test_bgemv_outer_speed() 279 | -------------------------------------------------------------------------------- /quant/matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import ipdb 3 | import random 4 | import triton 5 | import triton.language as tl 6 | import kivi_gemv 7 | 8 | 9 | @triton.jit 10 | def qbvm_kernel( 11 | bits, 12 | a_ptr, b_ptr, c_ptr, 13 | scales_ptr, zeros_ptr, 14 | M, N, K, 15 | stride_abatch, stride_am, stride_ak, 16 | stride_bbatch, stride_bk, stride_bn, 17 | stride_cbatch, stride_cm, stride_cn, 18 | stride_scales_b, stride_scales_k, stride_scales_g, 19 | stride_zeros_b, stride_zeros_k, stride_zeros_g, 20 | groupsize, 21 | BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 22 | ): 23 | """ 24 | Compute the batch matrix multiplication C = A x B. 25 | A is of shape (B, 1, K) float16 26 | B is of shape (B, K, N//feat_per_int) int32 27 | C is of shape (B, 1, N) float16 28 | scales is of shape (B, K, G) float16 29 | zeros is of shape (B, K, G) float16 30 | groupsize is an int specifying the size of groups for scales and zeros. 31 | G is N // groupsize. 32 | Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient. 33 | 34 | WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. 35 | WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. 36 | WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. 37 | """ 38 | pid_batch = tl.program_id(axis=0) 39 | pid = tl.program_id(axis=1) 40 | feat_per_int = 32 // bits 41 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 42 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 43 | pid_n = pid % num_pid_n 44 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) 45 | offs_k = tl.arange(0, BLOCK_SIZE_K) 46 | a_batch_offset = (pid_batch * stride_abatch) 47 | b_batch_offset = (pid_batch * stride_bbatch) 48 | c_batch_offset = (pid_batch * stride_cbatch) 49 | a_ptr = a_ptr + a_batch_offset 50 | b_ptr = b_ptr + b_batch_offset 51 | c_ptr = c_ptr + c_batch_offset 52 | a_ptrs = a_ptr + (offs_k[:, None] * stride_ak) # (BLOCK_SIZE_K, 1) 53 | # a_mask = (offs_am[:, None] < M) 54 | # b_ptrs is set up such that it repeats elements along the N axis feat_per_int times 55 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :]//feat_per_int) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 56 | # shifter is used to extract the # bits bits of each element in the 32-bit word from B 57 | shifter = (offs_bn % feat_per_int) * bits 58 | scales_ptr = scales_ptr + pid_batch*stride_scales_b + ((offs_bn[None, :] // groupsize)) * stride_scales_g # (BLOCK_SIZE_N,) 59 | zeros_ptr = zeros_ptr + pid_batch*stride_zeros_b + ((offs_bn[None, :] // groupsize)) * stride_zeros_g # (BLOCK_SIZE_N,) 60 | 61 | # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) 62 | # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension 63 | # So this loop is along the infeatures dimension (K) 64 | # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel 65 | # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 66 | accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) 67 | num = 0xFF >> 
(8-bits) 68 | for pid_k in range(0, num_pid_k): 69 | offs_bk = (offs_k[:, None] + pid_k * BLOCK_SIZE_K) 70 | # offs_k[None, :] < K - pid_k * BLOCK_SIZE_K 71 | a = tl.load(a_ptrs, mask=offs_bk < K, other=0.) # (1, BLOCK_SIZE_K) 72 | b = tl.load(b_ptrs, mask=offs_bk < K, other=0.) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 73 | ptr = scales_ptr + offs_bk * stride_scales_k 74 | scales = tl.load(ptr, mask=offs_bk < K, other=0.) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 75 | ptr = zeros_ptr + offs_bk * stride_zeros_k 76 | zeros = tl.load(ptr, mask=offs_bk < K, other=0.) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 77 | # Now we need to unpack b into 32-bit values 78 | # tl.device_print("scale ",scales.dtype) 79 | # tl.device_print("zeros ",zeros.dtype) 80 | b = (b >> shifter[None, :]) & num # For 4-bit values, bit_op_num is 0xF 81 | b = b * scales + zeros # Scale and shift 82 | accumulator += tl.sum(a * b, 0) # tl.dot(a, b) 83 | # if pid_m == 0 and pid_n == 0: 84 | # tl.device_print("hello ", tl.dot(a, b).shape) 85 | a_ptrs += BLOCK_SIZE_K * stride_ak 86 | b_ptrs += BLOCK_SIZE_K * stride_bk 87 | c = accumulator # .to(tl.float16) 88 | # c = accumulator 89 | # Store the result 90 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 91 | c_ptrs = c_ptr + stride_cn * offs_cn 92 | c_mask = (offs_cn < N) 93 | tl.store(c_ptrs, c, mask=c_mask) 94 | 95 | 96 | def understand_code(): 97 | M, N, K = 512, 256, 256 98 | BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M = 64, 64, 4 99 | total_program_id = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N) 100 | for pid in range(0, total_program_id): 101 | num_pid_m = triton.cdiv(M, BLOCK_SIZE_M) 102 | num_pid_n = triton.cdiv(N, BLOCK_SIZE_N) 103 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 104 | group_id = pid // num_pid_in_group 105 | first_pid_m = group_id * GROUP_SIZE_M 106 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 107 | pid_m = first_pid_m + (pid % group_size_m) 108 | pid_n = (pid % num_pid_in_group) // group_size_m 109 | print(f"pid={pid}, pid_m={pid_m}, pid_n={pid_n}") 110 | 111 | 112 | def triton_bmm_fA_qB_outer(group_size: int, 113 | fA: torch.FloatTensor, 114 | qB: torch.IntTensor, 115 | scales: torch.FloatTensor, 116 | zeros: torch.FloatTensor, 117 | bits: int) -> torch.FloatTensor: 118 | """ 119 | Compute the matrix multiplication C = query x key. 120 | Where key is quantized into 2-bit values. 121 | 122 | fA is of shape (B, nh, M, K) float16 123 | qB is of shape (B, nh, K, N // feat_per_int) int32 124 | scales is of shape (B, nh, K, G) float16 125 | zeros is of shape (B, nh, K, G) float16 126 | 127 | groupsize is the number of outer dimensions in each group. 
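    Illustrative example (assumed shapes, 2-bit key cache): with B = 1, nh = 32,
    M = 1, K = 128 (head dim), sequence length N = 4096 and groupsize = 64,
    feat_per_int = 32 // 2 = 16, so
        fA:     (1, 32, 1, 128)    float16
        qB:     (1, 32, 128, 256)  int32   (4096 / 16 packed values per row)
        scales: (1, 32, 128, 64)   float16 (4096 / 64 groups per row)
        zeros:  (1, 32, 128, 64)   float16
        C:      (1, 32, 1, 4096)   float16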
128 | G = N // groupsize 129 | 130 | Returns C of shape (B, nh, M, N) float16 131 | """ 132 | assert len(fA.shape) == 4 and len(qB.shape) == 4 133 | B, nh, M, K = fA.shape 134 | feat_per_int = 32 // bits 135 | # flatten to a 3D tensor 136 | fA = fA.view(-1, M, K) 137 | N = qB.shape[-1] * feat_per_int 138 | qB = qB.reshape(-1, K, qB.shape[-1]) 139 | # This is based on the possible BLOCK_SIZE_Ks 140 | # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" 141 | # This is based on the possible BLOCK_SIZE_Ns 142 | assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" 143 | # This is based on the possible BLOCK_SIZE_Ks 144 | assert group_size % 64 == 0, "groupsize must be a multiple of 64, and 128" 145 | flatten_B = B * nh 146 | c = torch.empty((flatten_B, M, N), device='cuda', dtype=torch.float16) 147 | # print(f'M {M} N {N} K {K}') 148 | grid = lambda META: ( 149 | flatten_B, triton.cdiv(N, META['BLOCK_SIZE_N']), 150 | ) 151 | scales = scales.view(flatten_B, scales.shape[-2], scales.shape[-1]) 152 | zeros = zeros.view(flatten_B, zeros.shape[-2], zeros.shape[-1]) 153 | if N > K: 154 | BLOCK_SIZE_N = 128 155 | BLOCK_SIZE_K = 32 156 | num_warps=4 # 157 | else: 158 | BLOCK_SIZE_N = 32 159 | BLOCK_SIZE_K = 128 160 | num_warps = 2 161 | num_stages= 7 if K > 64 else 3 # 162 | qbvm_kernel[grid]( 163 | bits, 164 | fA, qB, c, 165 | scales, zeros, 166 | M, N, K, 167 | fA.stride(0), fA.stride(1), fA.stride(2), 168 | qB.stride(0), qB.stride(1), qB.stride(2), 169 | c.stride(0), c.stride(1), c.stride(2), 170 | scales.stride(0), scales.stride(1), scales.stride(2), 171 | zeros.stride(0), zeros.stride(1), scales.stride(2), 172 | group_size, BLOCK_SIZE_N, BLOCK_SIZE_K, 173 | num_warps=num_warps, num_stages=num_stages 174 | ) 175 | return c.view(B, nh, c.shape[-2], c.shape[-1]) 176 | 177 | 178 | def cuda_bmm_fA_qB_outer(group_size: int, 179 | fA: torch.FloatTensor, 180 | qB: torch.IntTensor, 181 | scales: torch.FloatTensor, 182 | zeros: torch.FloatTensor, 183 | bits: int) -> torch.FloatTensor: 184 | """ 185 | Compute the matrix multiplication C = query x key. 186 | Where key is quantized into 2-bit values. 187 | 188 | fA is of shape (B, nh, M, K) float16 189 | qB is of shape (B, nh, K, N // feat_per_int) int32 190 | scales is of shape (B, nh, K, G) float16 191 | zeros is of shape (B, nh, K, G) float16 192 | 193 | groupsize is the number of outer dimensions in each group. 
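    Illustrative GQA example (assumed shapes): queries with nh = 32 heads over a
    2-bit key/value cache with nh_kv = 8 heads (nh % nh_kv == 0 is required),
    K = 128, N = 4096, groupsize = 64:
        fA:     (1, 32, 1, 128)   float16
        qB:     (1, 8, 128, 256)  int32
        scales: (1, 8, 128, 64)   float16
        zeros:  (1, 8, 128, 64)   float16
        C:      (1, 32, 1, 4096)  float16, each group of 32 / 8 = 4 query heads
                reading the same quantized kv head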
194 | G = N // groupsize 195 | 196 | Returns C of shape (B, nh, M, N) float16 197 | """ 198 | assert len(fA.shape) == 4 and len(qB.shape) == 4 199 | B, nh, M, K = fA.shape 200 | nh_kv = qB.shape[1] 201 | feat_per_int = 32 // bits 202 | # flatten to a 3D tensor 203 | fA = fA.view(-1, M, K).contiguous() 204 | N = qB.shape[-1] * feat_per_int 205 | qB = qB.reshape(-1, K, qB.shape[-1]).transpose(1, 2).contiguous() 206 | # This is based on the possible BLOCK_SIZE_Ks 207 | # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" 208 | # This is based on the possible BLOCK_SIZE_Ns 209 | # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" 210 | # This is based on the possible BLOCK_SIZE_Ks 211 | # assert group_size % 64 == 0, "groupsize must be a multiple of 64, and 128" 212 | flatten_B = B * nh_kv 213 | scales = scales.view(flatten_B, scales.shape[-2], scales.shape[-1]).transpose(1, 2).contiguous() 214 | zeros = zeros.view(flatten_B, zeros.shape[-2], zeros.shape[-1]).transpose(1, 2).contiguous() 215 | assert bits in [2, 4] 216 | assert nh % nh_kv == 0 217 | c = kivi_gemv.gemv_forward_cuda_outer_dim(fA, qB, scales, zeros, bits, group_size, nh, nh_kv) 218 | c = c.view(B, nh, c.shape[-2], c.shape[-1]) 219 | return c 220 | -------------------------------------------------------------------------------- /quant/new_pack.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import random 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def quant_and_pack_kcache(k: torch.FloatTensor, group_size: int, bits: int): 9 | assert len(k.shape) == 4 10 | shape = k.shape 11 | B, nh, T, D = shape 12 | # ================== Get Scale & Zeros =============== 13 | assert T % group_size == 0 14 | num_groups = T // group_size 15 | new_shape = (B, nh, num_groups, group_size, D) 16 | # Quantize 17 | max_int = 2 ** bits - 1 18 | data = k.view(new_shape) 19 | mn = torch.min(data, dim=-2, keepdim=True)[0] 20 | mx = torch.max(data, dim=-2, keepdim=True)[0] 21 | scale = (mx - mn) / max_int 22 | data = data - mn 23 | data.div_(scale) 24 | data = data.clamp_(0, max_int).round_().to(torch.int32) 25 | data = data.view(shape) 26 | code = pack_tensor(data, bits, pack_dim=2) 27 | return code, scale, mn 28 | 29 | 30 | def quant_and_pack_vcache(v: torch.FloatTensor, group_size: int, bits: int): 31 | shape = v.shape 32 | assert len(shape) == 4 33 | assert v.shape[-1] % group_size == 0 34 | num_groups = shape[-1] // group_size 35 | new_shape = (shape[:-1] + (num_groups, group_size)) 36 | # Quantize 37 | max_int = 2 ** bits - 1 38 | data = v.view(new_shape) 39 | mn = torch.min(data, dim=-1, keepdim=True)[0] 40 | mx = torch.max(data, dim=-1, keepdim=True)[0] 41 | scale = (mx - mn) / max_int 42 | data = data - mn 43 | data.div_(scale) 44 | data = data.clamp_(0, max_int).round_().to(torch.int32) 45 | data = data.view(shape) 46 | # Pack 47 | code = pack_tensor(data, bits, pack_dim=3) 48 | return code, scale, mn 49 | 50 | 51 | def unpack_and_dequant_kcache(k_code: torch.FloatTensor, 52 | scale: torch.FloatTensor, 53 | mn: torch.FloatTensor, 54 | group_size: int, 55 | bits: int, 56 | ): 57 | pack_dim = 2 58 | assert bits in [2, 4, 8] 59 | assert len(k_code.shape) == 4 60 | data = unpack_tensor(k_code, bits, pack_dim=pack_dim) 61 | shape = data.shape 62 | num_groups = shape[pack_dim] // group_size 63 | data = data.view(shape[:pack_dim] + 
(num_groups, group_size,) + shape[pack_dim+1:]) 64 | data = data.to(torch.float16) 65 | data = data * scale + mn 66 | return data.view(shape) 67 | 68 | 69 | def unpack_and_dequant_vcache(v_code: torch.FloatTensor, 70 | scale: torch.FloatTensor, 71 | mn: torch.FloatTensor, 72 | group_size: int, 73 | bits: int, 74 | ): 75 | assert bits in [2, 4, 8] 76 | assert len(v_code.shape) == 4 77 | data = unpack_tensor(v_code, bits, pack_dim=3) 78 | shape = data.shape 79 | num_groups = shape[-1] // group_size 80 | data = data.view(shape[:-1] + (num_groups, group_size,)) 81 | data = data.to(torch.float16) 82 | data = data * scale + mn 83 | return data.view(shape) 84 | 85 | 86 | def pack_tensor(data, bits, pack_dim): 87 | # Pack 88 | shape = data.shape 89 | feat_per_int = 32 // bits 90 | assert bits in [2,4,8], "Only 2, 4, 8 bits are supported" 91 | assert shape[pack_dim] % feat_per_int == 0, "Dimension length must be divisible by number of features per int" 92 | # BS, nh, T, nd // 16 # 16 is for 2bit 93 | code = torch.zeros(shape[:pack_dim] + (shape[pack_dim] // feat_per_int,)+shape[pack_dim+1:], 94 | dtype=torch.int32, 95 | device=data.device) 96 | i = 0 97 | row = 0 98 | unpacked_indices = [slice(None)] * len(data.shape) 99 | packed_indices = [slice(None)] * len(data.shape) 100 | while row < code.shape[pack_dim]: 101 | packed_indices[pack_dim] = row 102 | for j in range(i, i + (32 // bits)): 103 | unpacked_indices[pack_dim] = j 104 | code[packed_indices] |= data[unpacked_indices] << (bits * (j - i)) 105 | i += 32 // bits 106 | row += 1 107 | return code 108 | 109 | 110 | def unpack_tensor(v_code: torch.FloatTensor, 111 | bits: int, 112 | pack_dim: int): 113 | assert bits in [2,4,8] 114 | shape = v_code.shape 115 | feat_per_int = 32 // bits 116 | new_shape = shape[:pack_dim] + (shape[pack_dim] * feat_per_int,) + shape[pack_dim+1:] 117 | unpacked_v_code = torch.zeros(new_shape, dtype=torch.int8, device=v_code.device) 118 | i = torch.arange(new_shape[pack_dim], device=v_code.device) // feat_per_int 119 | j = torch.arange(new_shape[pack_dim], device=v_code.device) % feat_per_int 120 | num = 0xFF >> (8 - bits) 121 | packed_indices = [slice(None)] * len(new_shape) 122 | packed_indices[pack_dim] = i 123 | if pack_dim == 2: 124 | unpacked_v_code = ((v_code[packed_indices] >> (j * bits)[None, None, :, None]).to(torch.int16)) & num 125 | elif pack_dim == 3: 126 | unpacked_v_code = ((v_code[packed_indices] >> (j * bits)).to(torch.int16)) & num 127 | else: 128 | raise NotImplementedError 129 | return unpacked_v_code 130 | 131 | 132 | @triton.jit 133 | def _pack_along_last_dim( 134 | bits: tl.constexpr, 135 | intensor_ptr, 136 | code_ptr, 137 | N, 138 | num_feats: tl.constexpr, 139 | feat_per_int: tl.constexpr, 140 | BLOCK_SIZE_N: tl.constexpr 141 | ): 142 | num_int_per_y_dim = num_feats // feat_per_int 143 | bid = tl.program_id(axis=0) 144 | yid = tl.program_id(axis=1) 145 | offs_N = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 146 | block_start = intensor_ptr + offs_N * num_feats + yid * feat_per_int # offset of the first element at current tile 147 | packed = tl.zeros((BLOCK_SIZE_N,), dtype=tl.int32) 148 | for i in range(feat_per_int): 149 | ptr = block_start + i 150 | element = tl.load(ptr, mask=offs_N= 128: 12 | size_multiplier = 1 13 | elif group_size == 64: 14 | size_multiplier = 2 15 | elif group_size == 32: 16 | size_multiplier = 4 17 | else: 18 | raise NotImplementedError 19 | 20 | base_width = make_divisible(in_features // group_size, pack_num) 21 | base_width = make_divisible(base_width, 
size_multiplier) * size_multiplier 22 | return base_width 23 | 24 | 25 | def dequantize_weight(qweight, d_out, d_in, w_bit, scales, zeros, group_size): 26 | data = qweight.reshape(-1) 27 | N, num_features = d_out, d_in 28 | weight_fp = dequant_cuda.unpack_single_precision(data, w_bit, scales, zeros, N, 29 | num_features // group_size, group_size) 30 | return weight_fp.view(d_out, d_in) 31 | 32 | 33 | class MatMul4Bit(torch.autograd.Function): 34 | # forward is the same, but we added the fallback for pre-turing GPUs 35 | # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") 36 | 37 | @staticmethod 38 | def forward(ctx, A, qweight, bias, d_out, d_in, w_bit, scales, zeros, group_size): 39 | # default of pytorch behavior if inputs are empty 40 | # 1. Dequantize 41 | # 2. MatmulnN 42 | weight_fp = dequantize_weight(qweight, d_out, d_in, w_bit, scales, zeros, group_size) 43 | output = torch.nn.functional.linear(A, weight_fp.to(A.dtype), bias) 44 | # 3. Save state 45 | ctx.state = (d_out, d_in, w_bit, scales, zeros, group_size) 46 | ctx.tensors = qweight 47 | return output 48 | 49 | 50 | @staticmethod 51 | def backward(ctx, grad_output): 52 | req_gradA, _, req_gradBias = ctx.needs_input_grad[:3] 53 | qweight = ctx.tensors 54 | d_out, d_in, w_bit, scales, zeros, group_size = ctx.state 55 | 56 | grad_A, grad_bias = None, None 57 | 58 | if req_gradBias: 59 | # compute grad_bias first before changing grad_output dtype 60 | grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) 61 | 62 | # not supported by PyTorch. TODO: create work-around 63 | #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) 64 | if req_gradA: 65 | weight_fp = dequantize_weight(qweight, d_out, d_in, w_bit, scales, zeros, group_size) 66 | grad_A = torch.matmul(grad_output, weight_fp.to(grad_output.dtype)) 67 | if grad_A.isnan().any(): 68 | import ipdb; ipdb.set_trace() 69 | # print(grad_A.norm()) 70 | return grad_A, None, grad_bias, None, None, None, None, None, None 71 | 72 | 73 | class WQLinearForTrain(torch.nn.Module): 74 | def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): 75 | super().__init__() 76 | 77 | if w_bit not in [4]: 78 | raise NotImplementedError("Only 4-bit are supported for now.") 79 | 80 | self.in_features = in_features 81 | self.out_features = out_features 82 | self.w_bit = w_bit 83 | self.group_size = group_size if group_size != -1 else in_features 84 | # quick sanity check (make sure aligment) 85 | assert self.in_features % self.group_size == 0 86 | assert out_features % (32 // self.w_bit) == 0 87 | pack_num = (32 // self.w_bit) 88 | self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev)) 89 | self.register_buffer('zeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev)) 90 | self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev)) 91 | if bias: 92 | self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev)) 93 | else: 94 | self.bias = None 95 | 96 | 97 | def forward(self, x): 98 | # weight_fp = self.dequantize_weight().half() 99 | # out = torch.matmul(x, weight_fp.T) 100 | # out = out + self.bias if self.bias is not None else out 101 | 102 | out = MatMul4Bit.apply(x, self.qweight, self.bias, 103 | self.out_features, self.in_features, 104 | self.w_bit, self.scales, 105 | self.zeros, 
self.group_size) 106 | return out 107 | 108 | def dequantize_weight(self): 109 | data = self.qweight.reshape(-1) 110 | N, num_features = self.out_features, self.in_features 111 | weight_fp = dequant_cuda.unpack_single_precision(data, self.w_bit, self.scales, self.zeros, N, 112 | num_features // self.group_size, self.group_size) 113 | return weight_fp.view(self.out_features, self.in_features) 114 | 115 | 116 | @classmethod 117 | def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None): 118 | q_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device) 119 | if init_only: # just prepare for loading sd 120 | return q_linear 121 | quantized, scales, mn = quantize_and_pack(linear.weight, group_size, w_bit, simulate=False) 122 | q_linear.qweight = quantized 123 | q_linear.scales = scales 124 | q_linear.zeros = mn 125 | return q_linear -------------------------------------------------------------------------------- /quant/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 3 | 4 | 5 | extra_compile_args = { 6 | "cxx": [ 7 | "-g", 8 | "-O3", 9 | "-fopenmp", 10 | "-lgomp", 11 | "-std=c++17", 12 | "-DENABLE_BF16" 13 | ], 14 | "nvcc": [ 15 | "-O3", 16 | "-std=c++17", 17 | "-DENABLE_BF16", # TODO 18 | "-U__CUDA_NO_HALF_OPERATORS__", 19 | "-U__CUDA_NO_HALF_CONVERSIONS__", 20 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 21 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 22 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 23 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 24 | "--expt-relaxed-constexpr", 25 | "--expt-extended-lambda", 26 | "--use_fast_math", 27 | "--threads=8" 28 | ], 29 | } 30 | 31 | setup( 32 | name="kivi_gemv", 33 | packages=find_packages(), 34 | ext_modules=[ 35 | CUDAExtension( 36 | name="kivi_gemv", 37 | sources=[ 38 | "csrc/pybind.cpp", 39 | "csrc/gemv_cuda.cu" 40 | ], 41 | extra_compile_args=extra_compile_args, 42 | ), 43 | ], 44 | cmdclass={"build_ext": BuildExtension}, 45 | install_requires=["torch"], 46 | ) -------------------------------------------------------------------------------- /quant/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"]="2" 4 | import numpy as np 5 | import random 6 | # import ipdb 7 | import math 8 | import os 9 | import triton 10 | from new_pack import quant_and_pack_vcache, unpack_and_dequant_kcache, triton_quantize_and_pack_along_last_dim, unpack_and_dequant_vcache, quant_and_pack_kcache 11 | from matmul import triton_bmm_fA_qB_outer 12 | from timeit_v2 import py_benchmark 13 | 14 | 15 | def set_seed(seed): 16 | np.random.seed(seed) 17 | torch.random.manual_seed(seed) 18 | random.seed(seed) 19 | 20 | 21 | def test_vcache(): 22 | torch.manual_seed(0) 23 | np.random.seed(0) 24 | random.seed(0) 25 | B, nh, T, hd = 555, 32, 433, 128 26 | v = torch.randn((B, nh, T, hd), device='cuda', dtype=torch.float16) 27 | group_size = 64 28 | for bits in [2, 4, 8]: 29 | code, scale, mn = triton_quantize_and_pack_along_last_dim(v, group_size, bits) 30 | # print(f'bit {bits}, scale.shape: {scale.shape}') 31 | # print(f'bit {bits}, code.shape: {code.shape}') 32 | dequant_v = unpack_and_dequant_vcache(code, scale.unsqueeze(-1), mn.unsqueeze(-1), group_size, bits) 33 | assert not dequant_v.isnan().any() 34 | gap = (dequant_v - v) / v 
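        # element-wise relative error of the quantize -> pack -> unpack -> dequantize
        # round trip; the nan_to_num below replaces any inf/nan caused by zero entries
        # of v before the mean is taken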
35 | gap = torch.nan_to_num(gap) 36 | print(f'bit {bits}, mean v rel arr: {torch.mean(torch.abs(gap))}') 37 | 38 | 39 | def test_kcache(): 40 | torch.manual_seed(0) 41 | np.random.seed(0) 42 | random.seed(0) 43 | BS, nh, T, D = 11, 32, 4096, 128 44 | k = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16) 45 | group_size = 64 46 | for bits in [2, 4, 8]: 47 | code, scale, mn = triton_quantize_and_pack_along_last_dim(k.transpose(2, 3).contiguous(), 48 | group_size, 49 | bits) 50 | dequant_k = unpack_and_dequant_vcache(code, scale.unsqueeze(-1), mn.unsqueeze(-1), group_size, bits) 51 | assert not dequant_k.isnan().any() 52 | gap = (dequant_k.transpose(2, 3) - k) / k 53 | gap = torch.nan_to_num(gap) 54 | print(f'bit {bits}, k mean rel arr: {torch.mean(torch.abs(gap))}') 55 | 56 | 57 | def test_bmm_speed(): 58 | BS, nh, T, D = 64, 32, 512, 128 59 | bits = 2 60 | key_state = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16) 61 | val_state = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16) 62 | group_size = 64 63 | query_len = 1 64 | query_state = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16) 65 | 66 | # quantiles = [0.5, 0.2, 0.8] 67 | # ms, min_ms, max_ms = triton.testing.do_bench( 68 | # lambda: triton_quantize_and_pack_along_last_dim(key_state.transpose(2,3).contiguous(), 69 | # group_size, bits), quantiles=quantiles) 70 | # print(f'batch size {BS} nh {nh} seqlen {T} quant and pack pytorch impl: {ms * 1000: .2f} ms') 71 | code, scale, mn = triton_quantize_and_pack_along_last_dim( 72 | key_state.transpose(2,3).contiguous(), group_size, bits) 73 | code = code.contiguous() 74 | scale = scale.contiguous() 75 | mn = mn.contiguous() 76 | 77 | stmt = "triton_quantize_and_pack_along_last_dim(key_state.transpose(2,3).contiguous(), group_size, bits)" 78 | t_triton_quant = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3, 79 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 80 | print(f'our triton quant & pack impl: {t_triton_quant * 1000} ms') 81 | stmt = "quant_and_pack_kcache(key_state, group_size, bits)" 82 | t_quant = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3, 83 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 84 | print(f'vanilla pytorch quant & pack impl: {t_quant * 1000} ms') 85 | stmt = 'triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits)' 86 | t_qk = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3, 87 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 88 | print(f'batch size {BS} seqlen {T} our fused batch qk impl: {t_qk * 1000: .2f} ms') 89 | stmt = 'torch.matmul(query_state, key_state.transpose(2, 3))' 90 | t_qk_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3, 91 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 92 | print(f'batch size {BS} seqlen {T} pytorch batch qk impl: {t_qk_ref * 1000: .2f} ms') 93 | attn_weight = torch.randn((BS, nh, query_len, T), device='cuda', dtype=torch.float16) 94 | code, scale, mn = triton_quantize_and_pack_along_last_dim( 95 | val_state, group_size, bits) 96 | stmt = 'triton_bmm_fA_qB_outer(group_size, attn_weight, code, scale, mn, bits)' 97 | t_av = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3, 98 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 99 | print(f'batch size {BS} seqlen {T} our fused batch av impl: {t_av * 1000: .2f} ms') 100 | stmt = 
'torch.matmul(attn_weight, val_state)' 101 | t_av_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3, 102 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()") 103 | print(f'batch size {BS} seqlen {T} pytorch batch av impl: {t_av_ref * 1000: .2f} ms') 104 | 105 | # _code, _scale, _mn = quant_and_pack_kcache( 106 | # key_state, group_size, bits) 107 | # _code = _code.transpose(2, 3) 108 | # _scale = _scale.squeeze(-2).transpose(2,3) 109 | # _mn = _mn.squeeze(-2).transpose(2,3) 110 | # print(_code.shape, code.shape, _code.dtype, code.dtype) 111 | # print(_scale.shape, scale.shape, _scale.dtype, scale.dtype) 112 | 113 | # our_out = triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits) 114 | # ref_out = torch.matmul(query_state, key_state.transpose(2, 3)) 115 | # gap = (our_out - ref_out) / ref_out 116 | # gap = torch.nan_to_num(gap) 117 | # err = torch.mean(torch.abs(gap)).item() 118 | # print(f'bits {bits}, err: {err}') 119 | # ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits), quantiles=quantiles) 120 | # print(f'batch size {BS} seqlen {T} our fused batch matmul impl: {ms * 1000: .2f} ms') 121 | # ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(query_state, key_state.transpose(2, 3)), quantiles=quantiles) 122 | # print(f'batch size {BS} seqlen {T} pytorch batch matmul impl: {ms * 1000: .2f} ms') 123 | 124 | 125 | def test_streaming_kvcache(): 126 | BS, nh, T, D = 1, 32, 340, 128 127 | our_attn_output = None 128 | group_size = 64 129 | query_len = 1 130 | bits = 2 131 | key_states = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16) 132 | value_states = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16) 133 | key_states_quant = key_states[:, :, :-(key_states.shape[-2] % group_size), :].contiguous() 134 | key_states_full = key_states[:, :, -(key_states.shape[-2] % group_size):, :].contiguous() 135 | value_states_quant, value_scale, value_mn = triton_quantize_and_pack_along_last_dim(value_states, 136 | group_size, 137 | bits) 138 | key_states_quant_trans, key_scale_trans, key_mn_trans = triton_quantize_and_pack_along_last_dim(key_states_quant.transpose(2, 3).contiguous(), 139 | group_size, bits) 140 | for i in range(16): 141 | if our_attn_output is None: 142 | query_states = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16) 143 | else: 144 | query_states = our_attn_output 145 | key_states_new = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16) 146 | value_states_new = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16) 147 | att_qkquant = triton_bmm_fA_qB_outer(group_size, query_states, key_states_quant_trans, 148 | key_scale_trans, key_mn_trans, bits) 149 | key_states_full = torch.cat([key_states_full, key_states_new], dim=2) 150 | att_qkfull = torch.matmul(query_states, key_states_full.transpose(2, 3)) 151 | our_att_weights = torch.cat([att_qkquant, att_qkfull], dim=-1) / math.sqrt(D) 152 | our_att_weights = torch.softmax(our_att_weights, dim=-1) 153 | value_states_quant_new, scale, mn = triton_quantize_and_pack_along_last_dim(value_states_new, 154 | group_size, 155 | bits) 156 | value_states_quant = torch.cat([value_states_quant, value_states_quant_new], dim=2) 157 | value_scale = torch.cat([value_scale, scale], dim=2) 158 | value_mn = torch.cat([value_mn, mn], dim=2) 159 | our_attn_output = triton_bmm_fA_qB_outer(group_size, our_att_weights, 
value_states_quant, 160 | value_scale, value_mn, bits) 161 | # === 162 | key_states = torch.cat([key_states, key_states_new], dim=2) 163 | value_states = torch.cat([value_states, value_states_new], dim=2) 164 | ref_att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(D) 165 | ref_att_weights = torch.softmax(ref_att_weights, dim=-1) 166 | ref_att_out = torch.matmul(ref_att_weights, value_states) 167 | att_weight_gap = (ref_att_weights - our_att_weights) / ref_att_weights 168 | print(f'i {i} bit {bits}, mean att weight rel arr: {torch.mean(torch.abs(att_weight_gap))}') 169 | att_out_gap = (ref_att_out - our_attn_output) / ref_att_out 170 | print(f'i {i} bit {bits}, mean att out rel arr: {torch.mean(torch.abs(att_out_gap))}') 171 | 172 | 173 | def test_4d_qmatmul(): 174 | torch.manual_seed(0) 175 | np.random.seed(0) 176 | random.seed(0) 177 | query_len = 1 178 | BS, nh, T, D = 16, 32, 1024, 128 179 | group_size = 64 180 | # k = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16) 181 | # query_state = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16) 182 | k = torch.randint(10, (BS, nh, T, D), device='cuda').to(torch.float16) 183 | query_state = torch.randint(5, (BS, nh, query_len, D), device='cuda').to(torch.float16) 184 | for bits in [8, 4, 2]: 185 | # code.shape == BS, nh, T // feat_per_int, D 186 | # scale, mn.shape == BS, nh, ng, 1, D 187 | code, scale, mn = quant_and_pack_kcache(k, group_size, bits) 188 | dequant_k = unpack_and_dequant_kcache(code, scale, mn, group_size, bits) 189 | # BS, nh, D, T // feat_per_int 190 | code = code.transpose(2, 3) 191 | # BS, nh, D, T // group_size 192 | scale = scale.view(BS, nh, -1, D).transpose(2, 3) 193 | mn = mn.view(BS, nh, -1, D).transpose(2, 3) 194 | our_out = triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits) 195 | ref_out = torch.matmul(query_state, k.transpose(2, 3)) 196 | # ref_out = torch.matmul(query_state, k.transpose(2, 3)) 197 | assert not our_out.isnan().any() 198 | assert not ref_out.isnan().any() 199 | gap = (our_out - ref_out) / ref_out 200 | gap = torch.nan_to_num(gap) 201 | err = torch.mean(torch.abs(gap)).item() 202 | print(f'bits {bits}, err: {err}') 203 | 204 | 205 | if __name__ == '__main__': 206 | set_seed(114514) 207 | # test_kcache() 208 | # test_vcache() 209 | # test_4d_qmatmul() 210 | # test_streaming_kvcache() 211 | test_bmm_speed() -------------------------------------------------------------------------------- /quant/timeit_v2.py: -------------------------------------------------------------------------------- 1 | # timeit_v2.py: Copied from the default library with the following two modifiations 2 | # 1. Add 'finish' argument to timeit for calling cuda synchronization. 3 | # 2. Add accurate measurment utility function py_benchmark 4 | 5 | """Tool for measuring execution time of small code snippets. 6 | 7 | This module avoids a number of common traps for measuring execution 8 | times. See also Tim Peters' introduction to the Algorithms chapter in 9 | the Python Cookbook, published by O'Reilly. 10 | 11 | Library usage: see the Timer class. 12 | 13 | Command line usage: 14 | python timeit.py [-n N] [-r N] [-s S] [-p] [-h] [--] [statement] 15 | 16 | Options: 17 | -n/--number N: how many times to execute 'statement' (default: see below) 18 | -r/--repeat N: how many times to repeat the timer (default 5) 19 | -s/--setup S: statement to be executed once initially (default 'pass'). 20 | Execution time of this setup statement is NOT timed. 
21 | -p/--process: use time.process_time() (default is time.perf_counter()) 22 | -v/--verbose: print raw timing results; repeat for more digits precision 23 | -u/--unit: set the output time unit (nsec, usec, msec, or sec) 24 | -h/--help: print this usage message and exit 25 | --: separate options from statement, use when statement starts with - 26 | statement: statement to be timed (default 'pass') 27 | 28 | A multi-line statement may be given by specifying each line as a 29 | separate argument; indented lines are possible by enclosing an 30 | argument in quotes and using leading spaces. Multiple -s options are 31 | treated similarly. 32 | 33 | If -n is not given, a suitable number of loops is calculated by trying 34 | successive powers of 10 until the total time is at least 0.2 seconds. 35 | 36 | Note: there is a certain baseline overhead associated with executing a 37 | pass statement. It differs between versions. The code here doesn't try 38 | to hide it, but you should be aware of it. The baseline overhead can be 39 | measured by invoking the program without arguments. 40 | 41 | Classes: 42 | 43 | Timer 44 | 45 | Functions: 46 | 47 | timeit(string, string) -> float 48 | repeat(string, string) -> list 49 | default_timer() -> float 50 | """ 51 | 52 | import gc 53 | import sys 54 | import time 55 | import itertools 56 | 57 | __all__ = ["Timer", "timeit", "repeat", "default_timer"] 58 | 59 | dummy_src_name = "" 60 | default_number = 1000000 61 | default_repeat = 5 62 | default_timer = time.perf_counter 63 | 64 | _globals = globals 65 | 66 | # Don't change the indentation of the template; the reindent() calls 67 | # in Timer.__init__() depend on setup being indented 4 spaces and stmt 68 | # being indented 8 spaces. 69 | template = """ 70 | def inner(_it, _timer{init}): 71 | {setup} 72 | _t0 = _timer() 73 | for _i in _it: 74 | {stmt} 75 | {finish} 76 | _t1 = _timer() 77 | return _t1 - _t0 78 | """ 79 | 80 | def reindent(src, indent): 81 | """Helper to reindent a multi-line statement.""" 82 | return src.replace("\n", "\n" + " "*indent) 83 | 84 | class Timer: 85 | """Class for timing execution speed of small code snippets. 86 | 87 | The constructor takes a statement to be timed, an additional 88 | statement used for setup, and a timer function. Both statements 89 | default to 'pass'; the timer function is platform-dependent (see 90 | module doc string). If 'globals' is specified, the code will be 91 | executed within that namespace (as opposed to inside timeit's 92 | namespace). 93 | 94 | To measure the execution time of the first statement, use the 95 | timeit() method. The repeat() method is a convenience to call 96 | timeit() multiple times and return a list of results. 97 | 98 | The statements may contain newlines, as long as they don't contain 99 | multi-line string literals. 100 | """ 101 | 102 | def __init__(self, stmt="pass", setup="pass", finish='pass', timer=default_timer, 103 | globals=None): 104 | """Constructor. 
See class doc string.""" 105 | self.timer = timer 106 | local_ns = {} 107 | global_ns = _globals() if globals is None else globals 108 | init = '' 109 | if isinstance(setup, str): 110 | # Check that the code can be compiled outside a function 111 | compile(setup, dummy_src_name, "exec") 112 | stmtprefix = setup + '\n' 113 | setup = reindent(setup, 4) 114 | elif callable(setup): 115 | local_ns['_setup'] = setup 116 | init += ', _setup=_setup' 117 | stmtprefix = '' 118 | setup = '_setup()' 119 | else: 120 | raise ValueError("setup is neither a string nor callable") 121 | if isinstance(stmt, str): 122 | # Check that the code can be compiled outside a function 123 | compile(stmtprefix + stmt, dummy_src_name, "exec") 124 | stmt = reindent(stmt, 8) 125 | elif callable(stmt): 126 | local_ns['_stmt'] = stmt 127 | init += ', _stmt=_stmt' 128 | stmt = '_stmt()' 129 | else: 130 | raise ValueError("stmt is neither a string nor callable") 131 | 132 | assert isinstance(finish, str) 133 | compile(setup + '\n' + stmt + '\n' + finish, dummy_src_name, 'exec') 134 | finish = reindent(finish, 4) 135 | 136 | src = template.format(stmt=stmt, setup=setup, init=init, finish=finish) 137 | self.src = src # Save for traceback display 138 | code = compile(src, dummy_src_name, "exec") 139 | exec(code, global_ns, local_ns) 140 | self.inner = local_ns["inner"] 141 | 142 | def print_exc(self, file=None): 143 | """Helper to print a traceback from the timed code. 144 | 145 | Typical use: 146 | 147 | t = Timer(...) # outside the try/except 148 | try: 149 | t.timeit(...) # or t.repeat(...) 150 | except: 151 | t.print_exc() 152 | 153 | The advantage over the standard traceback is that source lines 154 | in the compiled template will be displayed. 155 | 156 | The optional file argument directs where the traceback is 157 | sent; it defaults to sys.stderr. 158 | """ 159 | import linecache, traceback 160 | if self.src is not None: 161 | linecache.cache[dummy_src_name] = (len(self.src), 162 | None, 163 | self.src.split("\n"), 164 | dummy_src_name) 165 | # else the source is already stored somewhere else 166 | 167 | traceback.print_exc(file=file) 168 | 169 | def timeit(self, number=default_number): 170 | """Time 'number' executions of the main statement. 171 | 172 | To be precise, this executes the setup statement once, and 173 | then returns the time it takes to execute the main statement 174 | a number of times, as a float measured in seconds. The 175 | argument is the number of times through the loop, defaulting 176 | to one million. The main statement, the setup statement and 177 | the timer function to be used are passed to the constructor. 178 | """ 179 | it = itertools.repeat(None, number) 180 | gcold = gc.isenabled() 181 | gc.disable() 182 | try: 183 | timing = self.inner(it, self.timer) 184 | finally: 185 | if gcold: 186 | gc.enable() 187 | return timing 188 | 189 | def repeat(self, repeat=default_repeat, number=default_number): 190 | """Call timeit() a few times. 191 | 192 | This is a convenience function that calls the timeit() 193 | repeatedly, returning a list of results. The first argument 194 | specifies how many times to call timeit(), defaulting to 5; 195 | the second argument specifies the timer argument, defaulting 196 | to one million. 197 | 198 | Note: it's tempting to calculate mean and standard deviation 199 | from the result vector and report these. However, this is not 200 | very useful. 
In a typical case, the lowest value gives a 201 | lower bound for how fast your machine can run the given code 202 | snippet; higher values in the result vector are typically not 203 | caused by variability in Python's speed, but by other 204 | processes interfering with your timing accuracy. So the min() 205 | of the result is probably the only number you should be 206 | interested in. After that, you should look at the entire 207 | vector and apply common sense rather than statistics. 208 | """ 209 | r = [] 210 | for i in range(repeat): 211 | t = self.timeit(number) 212 | r.append(t) 213 | return r 214 | 215 | def autorange(self, callback=None): 216 | """Return the number of loops and time taken so that total time >= 0.2. 217 | 218 | Calls the timeit method with increasing numbers from the sequence 219 | 1, 2, 5, 10, 20, 50, ... until the time taken is at least 0.2 220 | second. Returns (number, time_taken). 221 | 222 | If *callback* is given and is not None, it will be called after 223 | each trial with two arguments: ``callback(number, time_taken)``. 224 | """ 225 | i = 1 226 | while True: 227 | for j in 1, 2, 5: 228 | number = i * j 229 | time_taken = self.timeit(number) 230 | if callback: 231 | callback(number, time_taken) 232 | if time_taken >= 0.2: 233 | return (number, time_taken) 234 | i *= 10 235 | 236 | def timeit(stmt="pass", setup="pass", finish='pass', timer=default_timer, 237 | number=default_number, globals=None): 238 | """Convenience function to create Timer object and call timeit method.""" 239 | return Timer(stmt, setup, finish, timer, globals).timeit(number) 240 | 241 | def repeat(stmt="pass", setup="pass", finish='pass', timer=default_timer, 242 | repeat=default_repeat, number=default_number, globals=None): 243 | """Convenience function to create Timer object and call repeat method.""" 244 | return Timer(stmt, setup, finish, timer, globals).repeat(repeat, number) 245 | 246 | def py_benchmark(stmt, context, min_repeat_second=1, setup='pass', finish='pass'): 247 | total_time = 0 248 | number = 10 249 | 250 | eval(stmt, context) # warmup 251 | total_time = timeit(stmt=stmt, setup=setup, finish=finish, number=number, globals=context) 252 | while total_time < min_repeat_second: 253 | number = int(number * (min_repeat_second / total_time)) + 1 254 | total_time = timeit(stmt=stmt, setup=setup, finish=finish, number=number, globals=context) 255 | 256 | return total_time / number 257 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2 2 | packaging==24.0 3 | absl-py==2.0.0 4 | accelerate==0.25.0 5 | aiofiles==23.2.1 6 | aiohttp==3.9.1 7 | aiosignal==1.3.1 8 | altair==5.2.0 9 | annotated-types==0.6.0 10 | anyio==4.3.0 11 | asttokens==2.4.1 12 | async-timeout==4.0.3 13 | attributedict==0.3.0 14 | attrs==23.2.0 15 | bitsandbytes==0.43.0 16 | blessings==1.7 17 | cachetools==5.3.2 18 | certifi==2023.11.17 19 | chardet==5.2.0 20 | charset-normalizer==3.3.2 21 | click==8.1.7 22 | codecov==2.1.13 23 | colorama==0.4.6 24 | coloredlogs==15.0.1 25 | colour-runner==0.1.1 26 | contourpy==1.2.0 27 | coverage==7.4.0 28 | cycler==0.12.1 29 | DataProperty==1.0.1 30 | datasets==2.16.1 31 | decorator==5.1.1 32 | deepdiff==6.7.1 33 | deepspeed==0.12.6 34 | dill==0.3.7 35 | distlib==0.3.8 36 | distro==1.9.0 37 | einops==0.7.0 38 | evaluate==0.4.1 39 | exceptiongroup==1.2.0 40 | executing==2.0.1 41 | fastapi==0.110.0 42 | ffmpy==0.3.2 43 | 
filelock==3.13.1 44 | flash-attn==2.5.6 45 | fonttools==4.50.0 46 | frozenlist==1.4.1 47 | fsspec==2023.10.0 48 | fuzzywuzzy==0.18.0 49 | gradio==5.0.0 50 | gradio_client==0.2.9 51 | h11==0.14.0 52 | hjson==3.1.0 53 | httpcore==1.0.4 54 | httpx==0.27.0 55 | huggingface-hub==0.20.2 56 | humanfriendly==10.0 57 | idna==3.6 58 | inspecta==0.1.3 59 | ipdb==0.13.13 60 | ipython==8.19.0 61 | jedi==0.19.1 62 | jieba==0.42.1 63 | Jinja2==3.1.2 64 | joblib==1.3.2 65 | jsonlines==4.0.0 66 | jsonschema==4.21.1 67 | jsonschema-specifications==2023.12.1 68 | kiwisolver==1.4.5 69 | linkify-it-py==2.0.3 70 | -e git+https://github.com/EleutherAI/lm-evaluation-harness.git@c9bbec6e7de418b9082379da82797522eb173054#egg=lm_eval 71 | lxml==5.0.1 72 | markdown-it-py==2.2.0 73 | MarkupSafe==2.1.3 74 | matplotlib==3.8.3 75 | matplotlib-inline==0.1.6 76 | mbstrdecoder==1.1.3 77 | mdit-py-plugins==0.3.3 78 | mdurl==0.1.2 79 | mpmath==1.3.0 80 | multidict==6.0.4 81 | multiprocess==0.70.15 82 | networkx==3.2.1 83 | ninja==1.11.1.1 84 | nltk==3.8.1 85 | numexpr==2.8.8 86 | numpy==1.26.3 87 | nvidia-cublas-cu12==12.1.3.1 88 | nvidia-cuda-cupti-cu12==12.1.105 89 | nvidia-cuda-nvrtc-cu12==12.1.105 90 | nvidia-cuda-runtime-cu12==12.1.105 91 | nvidia-cudnn-cu12==8.9.2.26 92 | nvidia-cufft-cu12==11.0.2.54 93 | nvidia-curand-cu12==10.3.2.106 94 | nvidia-cusolver-cu12==11.4.5.107 95 | nvidia-cusparse-cu12==12.1.0.106 96 | nvidia-nccl-cu12==2.18.1 97 | nvidia-nvjitlink-cu12==12.3.101 98 | nvidia-nvtx-cu12==12.1.105 99 | openai==1.14.2 100 | ordered-set==4.1.0 101 | orjson==3.9.15 102 | packaging==23.2 103 | pandas==2.1.4 104 | parso==0.8.3 105 | pathvalidate==3.2.0 106 | peft==0.7.1 107 | pexpect==4.9.0 108 | pillow==10.2.0 109 | platformdirs==4.1.0 110 | pluggy==1.3.0 111 | portalocker==2.8.2 112 | prompt-toolkit==3.0.43 113 | protobuf==4.25.1 114 | psutil==5.9.7 115 | ptyprocess==0.7.0 116 | pure-eval==0.2.2 117 | py-cpuinfo==9.0.0 118 | pyarrow==14.0.2 119 | pyarrow-hotfix==0.6 120 | pybind11==2.11.1 121 | pycountry==23.12.11 122 | pydantic==1.10.14 123 | pydantic_core==2.14.6 124 | pydub==0.25.1 125 | Pygments==2.17.2 126 | pynvml==11.5.0 127 | pyparsing==3.1.2 128 | pyproject-api==1.6.1 129 | pytablewriter==1.2.0 130 | python-dateutil==2.8.2 131 | python-multipart==0.0.9 132 | pytz==2023.3.post1 133 | PyYAML==6.0.1 134 | referencing==0.34.0 135 | regex==2023.12.25 136 | requests==2.31.0 137 | responses==0.18.0 138 | rootpath==0.1.1 139 | rouge==1.0.1 140 | rouge-score==0.1.2 141 | rpds-py==0.18.0 142 | sacrebleu==1.5.0 143 | safetensors==0.4.1 144 | scikit-learn==1.3.2 145 | scipy==1.11.4 146 | semantic-version==2.10.0 147 | sentencepiece==0.1.99 148 | six==1.16.0 149 | sniffio==1.3.1 150 | sqlitedict==2.1.0 151 | stack-data==0.6.3 152 | starlette==0.36.3 153 | sympy==1.12 154 | tabledata==1.3.3 155 | tabulate==0.9.0 156 | tcolorpy==0.1.4 157 | termcolor==2.4.0 158 | texttable==1.7.0 159 | threadpoolctl==3.2.0 160 | tokenizers==0.15.0 161 | toml==0.10.2 162 | tomli==2.0.1 163 | toolz==0.12.1 164 | torchvision==0.16.2 165 | tox==4.11.4 166 | tqdm==4.66.1 167 | tqdm-multiprocess==0.0.11 168 | traitlets==5.14.1 169 | transformers==4.36.2 170 | triton==2.1.0 171 | typepy==1.3.2 172 | typing_extensions==4.9.0 173 | tzdata==2023.4 174 | uc-micro-py==1.0.3 175 | urllib3==2.1.0 176 | uvicorn==0.29.0 177 | virtualenv==20.25.0 178 | wcwidth==0.2.13 179 | websockets==12.0 180 | xxhash==3.4.1 181 | yarl==1.9.4 182 | zstandard==0.22.0 183 | -------------------------------------------------------------------------------- 
/scripts/long_test.sh: -------------------------------------------------------------------------------- 1 | # model e.g.: meta-llama/Llama-2-7b-hf 2 | 3 | gpuid=$1 4 | k_bits=$2 5 | v_bits=$3 6 | group_size=$4 7 | residual_length=$5 8 | model=$6 9 | e=0 10 | 11 | CUDA_VISIBLE_DEVICES=$gpuid python pred_long_bench.py --model_name_or_path $model \ 12 | --cache_dir ./cached_models \ 13 | --k_bits $k_bits \ 14 | --v_bits $v_bits \ 15 | --group_size $group_size \ 16 | --residual_length $residual_length \ 17 | --e ${e} -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | from torch.utils.data import Dataset 5 | from datasets import load_dataset 6 | from torch.utils.data import DataLoader 7 | 8 | class TextDataset(torch.utils.data.IterableDataset): 9 | def __init__(self, data, tokenizer, seqlen, col_key, cutoff=1000): 10 | self.tokenizer = tokenizer 11 | self.col_key = col_key 12 | self.cutoff = cutoff 13 | self.block_size = seqlen 14 | if cutoff is None: 15 | cutoff = len(data) 16 | tokenized_datasets = [self.tokenizer(data[i][col_key]) for i in range(cutoff)] 17 | grouped_dataset = self.group_texts(tokenized_datasets) 18 | self.input_ids = grouped_dataset["input_ids"] 19 | self.labels = grouped_dataset["labels"] 20 | self.data = [ 21 | dict(input_ids=self.input_ids[i], labels=self.labels[i]) 22 | for i in range(len(self.input_ids)) 23 | ] 24 | 25 | def __len__(self): 26 | return len(self.input_ids) 27 | 28 | def __getitem__(self, i): 29 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 30 | 31 | def __iter__(self): 32 | return iter(self.data) 33 | 34 | def group_texts(self, examples): 35 | # Concatenate all texts. 36 | # Initialize an empty dictionary 37 | concatenated_examples = {} 38 | 39 | # Loop through the list of dictionaries 40 | for d in examples: 41 | # Loop through the keys in each dictionary 42 | for key in d.keys(): 43 | # If the key is not already a key in the dict_of_lists, create a new list 44 | if key not in concatenated_examples: 45 | concatenated_examples[key] = [] 46 | # Append the value to the list associated with the key in dict_of_lists 47 | concatenated_examples[key].extend(d[key]) 48 | total_length = len(concatenated_examples["input_ids"]) 49 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 50 | # customize this part to your needs. 51 | if total_length >= self.block_size: 52 | total_length = (total_length // self.block_size) * self.block_size 53 | # Split by chunks of max_len. 
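        # e.g. (illustrative numbers) with block_size = 2048 and 5000 concatenated
        # tokens, total_length is truncated to 4096, yielding two 2048-token blocks;
        # the trailing 904 tokens are dropped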
54 | result = { 55 | k: [ 56 | t[i : i + self.block_size] 57 | for i in range(0, total_length, self.block_size) 58 | ] 59 | for k, t in concatenated_examples.items() 60 | } 61 | result["labels"] = result["input_ids"].copy() 62 | return result 63 | 64 | 65 | # class TextDataset(Dataset): 66 | # def __init__(self, data, tokenizer, seqlen, col_key, per_device_train_batch_size, mode="train", cutoff=None): 67 | # self.tokenizer = tokenizer 68 | # self.mode = mode 69 | # self.col_key = col_key 70 | # self.cutoff = cutoff 71 | # self.seqlen = seqlen 72 | # self.per_device_train_batch_size = per_device_train_batch_size 73 | 74 | # if self.mode == "train": 75 | # self.encodings = self.process_data(data) 76 | # else: 77 | # self.encodings = self.process_data(data, is_val=True) 78 | 79 | # def process_data(self, data, is_val=False): 80 | # seqlen = self.seqlen 81 | # if is_val: 82 | # if self.cutoff is None: 83 | # enc = self.tokenizer(" ".join(data[self.col_key]), return_tensors='pt') 84 | # else: 85 | # enc = self.tokenizer(" ".join(data[:self.cutoff][self.col_key]), return_tensors='pt') 86 | # tot_num_seq = enc['input_ids'].size(1) // seqlen 87 | # enc['input_ids'] = enc['input_ids'][..., :tot_num_seq*seqlen] 88 | # else: 89 | # if self.cutoff is None: 90 | # enc = self.tokenizer(" ".join(data[self.col_key]), return_tensors='pt') 91 | # else: 92 | # enc = self.tokenizer(" ".join(data[:self.cutoff][self.col_key]), return_tensors='pt') 93 | # tot_num_seq = enc['input_ids'].size(1) // (seqlen*self.per_device_train_batch_size) 94 | # enc['input_ids'] = enc['input_ids'][..., :tot_num_seq*seqlen*self.per_device_train_batch_size] 95 | 96 | # return enc 97 | 98 | # def __getitem__(self, idx): 99 | # input_ids = self.encodings['input_ids'][0, idx*self.seqlen:(idx+1)*self.seqlen] 100 | # return input_ids 101 | 102 | # def __len__(self): 103 | # return self.encodings['input_ids'].size(1) // self.seqlen 104 | 105 | 106 | def set_seed(seed): 107 | np.random.seed(seed) 108 | torch.random.manual_seed(seed) 109 | random.seed(seed) 110 | 111 | 112 | def get_c4(n_train_samples, n_eval_samples, seqlen, tokenizer): 113 | # raw_tra_data = load_dataset("c4", split="train") 114 | raw_tra_data = load_dataset('allenai/c4', 'allenai--c4', 115 | data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, 116 | split='train') 117 | # raw_val_data = load_dataset("c4", split="validation") 118 | raw_val_data = load_dataset('allenai/c4', 'allenai--c4', 119 | data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, 120 | split='validation') 121 | train_dataset = TextDataset(raw_tra_data, tokenizer, 122 | col_key='text', 123 | cutoff=n_train_samples, 124 | seqlen=seqlen) 125 | val_dataset = TextDataset(raw_val_data, tokenizer, 126 | col_key='text', 127 | cutoff=n_eval_samples, # todo: change to 1100 128 | seqlen=seqlen) 129 | return train_dataset, val_dataset 130 | 131 | 132 | def get_loaders( 133 | name, enc, n_train_samples=128, n_eval_samples=1024, seqlen=2048): 134 | if 'c4' in name: 135 | return get_c4(n_train_samples, n_eval_samples, seqlen, enc) 136 | else: 137 | raise NotImplementedError -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import jieba 5 | from fuzzywuzzy import fuzz 6 | import difflib 7 | 8 | from typing import List 9 | from collections import Counter 10 | from rouge import Rouge 11 | 12 | def normalize_answer(s): 13 | """Lower 
text and remove punctuation, articles and extra whitespace."""
14 | 
15 |     def remove_articles(text):
16 |         return re.sub(r"\b(a|an|the)\b", " ", text)
17 | 
18 |     def white_space_fix(text):
19 |         return " ".join(text.split())
20 | 
21 |     def remove_punc(text):
22 |         exclude = set(string.punctuation)
23 |         return "".join(ch for ch in text if ch not in exclude)
24 | 
25 |     def lower(text):
26 |         return text.lower()
27 | 
28 |     return white_space_fix(remove_articles(remove_punc(lower(s))))
29 | 
30 | 
31 | def normalize_zh_answer(s):
32 |     """Lower text and remove punctuation, extra whitespace."""
33 | 
34 |     def white_space_fix(text):
35 |         return "".join(text.split())
36 | 
37 |     def remove_punc(text):
38 |         cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
39 |         all_punctuation = set(string.punctuation + cn_punctuation)
40 |         return "".join(ch for ch in text if ch not in all_punctuation)
41 | 
42 |     def lower(text):
43 |         return text.lower()
44 | 
45 |     return white_space_fix(remove_punc(lower(s)))
46 | 
47 | def count_score(prediction, ground_truth, **kwargs):
48 |     numbers = re.findall(r"\d+", prediction)
49 |     right_num = 0
50 |     for number in numbers:
51 |         if str(number) == str(ground_truth):
52 |             right_num += 1
53 |     final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
54 |     return float(final_score)
55 | 
56 | def retrieval_score(prediction, ground_truth, **kwargs):
57 |     pattern = r'Paragraph (\d+)'
58 |     matches = re.findall(pattern, ground_truth)
59 |     ground_truth_id = matches[0]
60 |     numbers = re.findall(r"\d+", prediction)
61 |     right_num = 0
62 |     for number in numbers:
63 |         if str(number) == str(ground_truth_id):
64 |             right_num += 1
65 |     final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
66 |     return float(final_score)
67 | 
68 | def retrieval_zh_score(prediction, ground_truth, **kwargs):
69 |     pattern = r'段落(\d+)'
70 |     matches = re.findall(pattern, ground_truth)
71 |     ground_truth_id = matches[0]
72 |     numbers = re.findall(r"\d+", prediction)
73 |     right_num = 0
74 |     for number in numbers:
75 |         if str(number) == str(ground_truth_id):
76 |             right_num += 1
77 |     final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
78 |     return float(final_score)
79 | 
80 | def code_sim_score(prediction, ground_truth, **kwargs):
81 |     all_lines = prediction.lstrip('\n').split('\n')
82 |     prediction = ""
83 |     for line in all_lines:
84 |         if ('`' not in line) and ('#' not in line) and ('//' not in line):
85 |             prediction = line
86 |             break
87 |     return (fuzz.ratio(prediction, ground_truth) / 100)
88 | 
89 | def classification_score(prediction, ground_truth, **kwargs):
90 |     em_match_list = []
91 |     all_classes = kwargs["all_classes"]
92 |     for class_name in all_classes:
93 |         if class_name in prediction:
94 |             em_match_list.append(class_name)
95 |     for match_term in list(em_match_list):  # iterate over a copy so removals do not skip items
96 |         if match_term in ground_truth and match_term != ground_truth:
97 |             em_match_list.remove(match_term)
98 |     if len(em_match_list) != 0:
99 |         if ground_truth in em_match_list:
100 |             score = (1.0 / len(em_match_list))
101 |         else:
102 |             score = 0.0
103 |     else:
104 |         best_match = None
105 |         highest_similarity = 0
106 |         for string in all_classes:
107 |             similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
108 |             if similarity > highest_similarity:
109 |                 highest_similarity = similarity
110 |                 best_match = string
111 |         score = float(best_match == ground_truth)
112 |     return score
113 | 
114 | def rouge_score(prediction, ground_truth, **kwargs):
115 |     rouge = Rouge()
116 |     try:
117 | 
scores = rouge.get_scores([prediction], [ground_truth], avg=True) 118 | except: 119 | return 0.0 120 | return scores["rouge-l"]["f"] 121 | 122 | def rouge_zh_score(prediction, ground_truth, **kwargs): 123 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 124 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 125 | score = rouge_score(prediction, ground_truth) 126 | return score 127 | 128 | def f1_score(prediction, ground_truth, **kwargs): 129 | common = Counter(prediction) & Counter(ground_truth) 130 | num_same = sum(common.values()) 131 | if num_same == 0: 132 | return 0 133 | precision = 1.0 * num_same / len(prediction) 134 | recall = 1.0 * num_same / len(ground_truth) 135 | f1 = (2 * precision * recall) / (precision + recall) 136 | return f1 137 | 138 | def qa_f1_score(prediction, ground_truth, **kwargs): 139 | normalized_prediction = normalize_answer(prediction) 140 | normalized_ground_truth = normalize_answer(ground_truth) 141 | 142 | prediction_tokens = normalized_prediction.split() 143 | ground_truth_tokens = normalized_ground_truth.split() 144 | return f1_score(prediction_tokens, ground_truth_tokens) 145 | 146 | 147 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 148 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 149 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 150 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 151 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 152 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 153 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 154 | return f1_score(prediction_tokens, ground_truth_tokens) -------------------------------------------------------------------------------- /utils/process_args.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 
8 | import os
9 | from dataclasses import dataclass, field
10 | from typing import Optional
11 | 
12 | import transformers
13 | 
14 | 
15 | @dataclass
16 | class ModelArguments:
17 |     model_name_or_path: str = field(
18 |         default=None, metadata={"help": "Model name or local path of the pretrained model."}
19 |     )
20 |     k_bits: Optional[int] = field(
21 |         default=2,
22 |         metadata={"help": "Key cache quantization bits."},
23 |     )
24 |     v_bits: Optional[int] = field(
25 |         default=2,
26 |         metadata={"help": "Value cache quantization bits."},
27 |     )
28 |     k_quant_dim: Optional[str] = field(
29 |         default='token',
30 |         metadata={"help": "Key cache quantization dimension, e.g. 'token' or 'channel'."},
31 |     )
32 |     v_quant_dim: Optional[str] = field(
33 |         default='token',
34 |         metadata={"help": "Value cache quantization dimension, e.g. 'token' or 'channel'."},
35 |     )
36 |     group_size: Optional[int] = field(
37 |         default=128,
38 |         metadata={"help": "KV cache quantization group size."},
39 |     )
40 |     residual_length: Optional[int] = field(
41 |         default=128,
42 |         metadata={"help": "KV cache residual length (number of most recent tokens kept in full precision)."},
43 |     )
44 |     output_model_filename: Optional[str] = field(
45 |         default="test-output", metadata={"help": "Output model relative path"}
46 |     )
47 |     load_quant: Optional[str] = field(
48 |         default=None,
49 |         metadata={"help": "The path to a quantized model"},
50 |     )
51 |     w_bit: Optional[int] = field(
52 |         default=4,
53 |         metadata={"help": "The model weight bit width."},
54 |     )
55 |     lora: Optional[bool] = field(
56 |         default=False,
57 |         metadata={"help": "Whether to use LoRA"},
58 |     )
59 |     lora_mode: Optional[str] = field(
60 |         default="q",
61 |         metadata={"help": "LoRA mode"},
62 |     )
63 |     lora_r: Optional[int] = field(
64 |         default=1,
65 |         metadata={"help": "LoRA rank r"},
66 |     )
67 |     lora_alpha: Optional[float] = field(
68 |         default=1.,
69 |         metadata={"help": "LoRA alpha"},
70 |     )
71 |     lora_dropout: Optional[float] = field(
72 |         default=0.,
73 |         metadata={"help": "LoRA dropout"},
74 |     )
75 | 
76 | 
77 | 
78 | @dataclass
79 | class DataArguments:
80 |     dataset: Optional[str] = field(
81 |         default='c4',
82 |         metadata={"help": "The dataset used for fine-tuning the model."},
83 |     )
84 |     eval_tasks: Optional[str] = field(
85 |         default='wikitext',
86 |         metadata={"help": "The dataset used for evaluation."},
87 |     )
88 |     tasks: Optional[str] = field(
89 |         default='wikitext',
90 |         metadata={"help": "The dataset used for evaluation."},
91 |     )
92 |     batch_size: Optional[int] = field(
93 |         default=1,
94 |         metadata={"help": "The batch size."},
95 |     )
96 |     num_fewshot: Optional[int] = field(
97 |         default=0,
98 |         metadata={"help": "The number of few-shot examples."},
99 |     )
100 |     output_path: Optional[str] = field(
101 |         default='./outputs',
102 |         metadata={"help": "The output path."},
103 |     )
104 |     e: Optional[bool] = field(
105 |         default=False,
106 |         metadata={"help": "Evaluate on LongBench-E."},
107 |     )
108 |     use_our_imp: Optional[bool] = field(
109 |         default=True,
110 |         metadata={"help": "Whether to use our KV cache quantization implementation."},
111 |     )
112 | 
113 | 
114 | 
115 | @dataclass
116 | class TrainingArguments(transformers.TrainingArguments):
117 |     cache_dir: Optional[str] = field(default=None)
118 |     optim: Optional[str] = field(default="adamw_torch")
119 |     output_dir: Optional[str] = field(default="./outputs")
120 |     model_max_length: Optional[int] = field(
121 |         default=512,
122 |         metadata={
123 |             "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated). 
512 or 1024" 124 | }, 125 | ) 126 | num_train_epochs: Optional[int] = field(default=1) 127 | n_train_samples: Optional[int] = field(default=None) 128 | n_eval_samples: Optional[int] = field(default=None) 129 | qat: Optional[bool] = field(default=False) 130 | exp_name: Optional[str] = field(default="test") 131 | 132 | 133 | def process_args(): 134 | parser = transformers.HfArgumentParser( 135 | (ModelArguments, DataArguments, TrainingArguments) 136 | ) 137 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 138 | os.makedirs(training_args.output_dir, exist_ok=True) 139 | 140 | model_args.output_model_local_path = os.path.join( 141 | training_args.output_dir, "models", str(model_args.output_model_filename) 142 | ) 143 | 144 | return model_args, data_args, training_args 145 | -------------------------------------------------------------------------------- /vis/vis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import os\n", 11 | "import matplotlib\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "from mpl_toolkits.mplot3d import Axes3D\n", 15 | "from matplotlib.colors import Normalize\n", 16 | "\n", 17 | "FONTSIZE = 16\n", 18 | "\n", 19 | "font_config = {'font.size': FONTSIZE, 'font.family': 'DejaVu Math TeX Gyre'}\n", 20 | "plt.rcParams.update(font_config)\n", 21 | "plt.rcParams[\"figure.figsize\"] = (4, 4.5)\n", 22 | "\n", 23 | "# generate kv cache and attention\n", 24 | "# inputs = enc(sample, return_tensors='pt').to('cuda')\n", 25 | "# outputs = model(inputs['input_ids'], use_cache=True, output_attentions=True)\n", 26 | "# past_key_values = outputs.past_key_values\n", 27 | "# attentions = outputs.attentions\n", 28 | "# torch.save(past_key_values, f'./{model}_kvcache.pt')\n", 29 | "# torch.save(attentions, f'./{model}_attention.pt')\n", 30 | "\n", 31 | "model = 'Llama-2-7b-hf' # replace with your model name\n", 32 | "kv_filename = f'./{model}_kvcache.pt'\n", 33 | "attn_filename = f'./{model}_attention.pt'\n", 34 | "kvcache = torch.load(kv_filename, map_location='cpu')\n", 35 | "attentions = torch.load(attn_filename, map_location='cpu')" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "for layer_id in [3, 8, 14, 16, 18, 20, 31]: # replace with your layer ids\n", 45 | " head_id = 0\n", 46 | " k, v = kvcache[layer_id][0].squeeze(0), kvcache[layer_id][1].squeeze(0)\n", 47 | "\n", 48 | " k = k.transpose(0, 1).abs().detach().numpy()\n", 49 | " v = v.transpose(0, 1).abs().detach().numpy()\n", 50 | " k, v = k[:, head_id, :], v[:, head_id, :]\n", 51 | "\n", 52 | " # Sample 2D tensor (replace this with your actual tensor)\n", 53 | " for idx, tensor in enumerate([k, v]):\n", 54 | " # Creating a meshgrid\n", 55 | " tokens, channels = tensor.shape\n", 56 | " x = np.arange(tokens)\n", 57 | " y = np.arange(channels)\n", 58 | " X, Y = np.meshgrid(x, y)\n", 59 | " # Creating a figure and a 3D subplot\n", 60 | " fig = plt.figure()\n", 61 | " ax = fig.add_subplot(111, projection='3d')\n", 62 | " # Plotting the surface\n", 63 | " surf = ax.plot_surface(X, Y, tensor.T, cmap='coolwarm')\n", 64 | "\n", 65 | " ax.xaxis.set_tick_params(pad=-5)\n", 66 | " ax.yaxis.set_tick_params(pad=-3)\n", 67 | " ax.zaxis.set_tick_params(pad=-130)\n", 68 | "\n", 69 | " # Adding 
labels\n", 70 | " ax.set_xlabel('Token', labelpad=-5)\n", 71 | " ax.set_ylabel('Column', labelpad=-1)\n", 72 | " if layer_id in [3, 16]:\n", 73 | " ax.zaxis.set_rotate_label(False) \n", 74 | " if idx == 0:\n", 75 | " save_filename = f'./saved_figs/{model}_layer{layer_id}_head{head_id}_k.pdf'\n", 76 | " else:\n", 77 | " save_filename = f'./saved_figs/{model}_layer{layer_id}_head{head_id}_v.pdf'\n", 78 | " plt.savefig(save_filename, bbox_inches='tight')\n", 79 | " plt.clf()" 80 | ] 81 | } 82 | ], 83 | "metadata": { 84 | "language_info": { 85 | "name": "python" 86 | } 87 | }, 88 | "nbformat": 4, 89 | "nbformat_minor": 2 90 | } 91 | --------------------------------------------------------------------------------