├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── dataset2maxlen.json
│   ├── dataset2prompt.json
│   ├── model2maxlen.json
│   └── model2path.json
├── docs
│   └── long_bench.md
├── eval_long_bench.py
├── example.py
├── img
│   ├── algo.png
│   └── quant_scheme.png
├── long_context_example.py
├── mem_spd_test.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── llama_kivi.py
│   ├── mistral_kivi.py
│   └── utils_quant.py
├── passkey_examples.jsonl
├── pred_long_bench.py
├── pyproject.toml
├── quant
│   ├── __init__.py
│   ├── csrc
│   │   ├── gemv_cuda.cu
│   │   ├── gemv_cuda.h
│   │   ├── gemv_cuda_backup.cu
│   │   └── pybind.cpp
│   ├── gemv.py
│   ├── matmul.py
│   ├── new_pack.py
│   ├── qmodule.py
│   ├── setup.py
│   ├── test.py
│   └── timeit_v2.py
├── requirements.txt
├── scripts
│   └── long_test.sh
├── utils
│   ├── data.py
│   ├── metrics.py
│   └── process_args.py
└── vis
    └── vis.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | outputs
3 | logs
4 | cached_models
5 | *.egg-info
6 | build
7 | *.so
8 | out_*
9 | *.sh
10 | *.jsonl
11 | *.pt
12 | notebook
13 | dataset
14 | output_samples
15 | *.pdf
16 | *.json
17 | *.jsonl
18 | third_party
19 | speed_logs
20 | pred
21 | pred_e
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 jiayi yuan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache
2 |
3 | Implementation of [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750)
4 |
5 | ## Updates
6 | - [2025.01.18]: We add a KIVI implementation with GQA support, compatible with transformers 4.43. It now supports the Llama-3 family. Please reinstall KIVI.
7 | - [2024.06.07]: 🎉 KIVI largely inspired the [HuggingFace Transformers KV Cache quantization](https://huggingface.co/docs/transformers/main/en/kv_cache).
8 | - [2024.06.06]: (Beta) We extensively optimize the codebase in the [develop branch](https://github.com/jy-yuan/KIVI/tree/develop) to reduce the latency of KIVI. Note that **you need to reinstall our CUDA implementation** under the ```quant``` folder. We will release a blog post about the detailed optimizations soon.
9 | - [2024.05.01]:🎉 KIVI has been accepted by ICML 2024! See you in Vienna!
10 | - [2024.04.12]: We add support for the Mistral model family. The performance of LongChat-7b-v1.5-32K and Mistral-7B-Instruct-v0.2 on 15 tasks from LongBench can be found in [long_bench.md](./docs/long_bench.md).
11 |
12 | - [2024.04.05]: We release the code for reproducing our CoQA/TruthfulQA/GSM8K results using LM-Eval. Please check the [README of branch lmeval](https://github.com/jy-yuan/KIVI/tree/lmeval).
13 |
14 | - [2024.04.04]: 🔥🔥We add a new 5-digit [passkey example](./long_context_example.py) with 12k context length to show the performance of 2-bit KIVI in the long-context scenario.
15 |
16 | - [2024.04.04]: (Beta) We add flash-attention support for KIVI during the prefill phase.
17 |
18 | - [2024.04.03]: We add a new [5-shot GSM8K example.py](./example.py) to show the performance of 2/4 bit KIVI with 32 full precision tokens.
19 |
20 | - [2024.02.05]: KIVI ver. 2 is released on [arXiv](https://arxiv.org/abs/2402.02750).
21 |
22 | - [2024.02.03]: KIVI code is released.
23 |
24 | - [2023.12.29]: KIVI ver. 1 is released on [researchgate](https://www.researchgate.net/publication/376831635_KIVI_Plug-and-play_2bit_KV_Cache_Quantization_with_Streaming_Asymmetric_Quantization).
25 |
26 | ## Overview
27 |
28 | KIVI is a plug-and-play 2-bit KV cache quantization algorithm that requires no fine-tuning. It reduces memory usage by quantizing the key cache per-channel and the value cache per-token to 2 bits. KIVI's hardware-friendly design allows LLMs like Llama-2, Falcon, and Mistral to maintain comparable quality while reducing peak memory usage by 2.6×. This enables up to 4× larger batch sizes and increases throughput by 2.35×-3.47× in real LLM inference workloads, effectively addressing the speed and memory bottlenecks.
29 |
30 | Illustration of KIVI quantization scheme: key cache per-channel and value cache per-token.
31 |
32 |
33 |
34 |
35 | Illustration of KIVI algorithm during inference prefill and decoding phase:
36 |
37 |
38 |
39 |
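A minimal, illustrative sketch of this scheme in plain PyTorch (fake quantization only; the actual implementation additionally groups channels/tokens into `group_size` blocks, keeps the most recent `residual_length` tokens in fp16, and stores packed 2-bit codes via the CUDA kernels under `quant/`):

```python
import torch

def asym_quant(x, dim, n_bits=2, eps=1e-6):
    # Asymmetric (min/max) quantization with statistics reduced along `dim`.
    mn = x.amin(dim=dim, keepdim=True)
    mx = x.amax(dim=dim, keepdim=True)
    scale = (mx - mn) / (2 ** n_bits - 1) + eps
    q = ((x - mn) / scale).round().clamp(0, 2 ** n_bits - 1)
    return q * scale + mn  # dequantized view of the 2-bit codes

K = torch.randn(1, 8, 128, 64)  # (batch, kv_heads, tokens, head_dim)
V = torch.randn(1, 8, 128, 64)
K_hat = asym_quant(K, dim=-2)   # key cache: statistics per channel (across tokens)
V_hat = asym_quant(V, dim=-1)   # value cache: statistics per token (across channels)
```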
40 | ## How to use KIVI
41 |
42 | ### Setup
43 |
44 | To install the required packages:
45 |
46 | ```bash
47 | conda create -n kivi python=3.10
48 | conda activate kivi
49 | pip install --upgrade pip # enable PEP 660 support
50 | pip install -e .
51 | ```
52 |
53 | Then install our CUDA implementation:
54 |
55 | ```bash
56 | cd quant && pip install -e .
57 | ```
58 |
59 | ### Example
60 |
61 | Load a model with KIVI (e.g., Llama-2-7b):
62 |
63 | ```python
64 | # LLaMA model with KIVI
65 | import torch
66 | import os
67 | from models.llama_kivi import LlamaForCausalLM_KIVI
68 | from transformers import LlamaConfig, AutoTokenizer
69 | config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
70 |
71 | config.k_bits = K_BITS # current support 2/4 bit for KV Cache
72 | config.v_bits = V_BITS # current support 2/4 bit for KV Cache
73 | config.group_size = GROUP_SIZE
74 | config.residual_length = RESIDUAL_LENGTH # the number of recent fp16 tokens
75 | CACHE_DIR = PATH_TO_YOUR_SAVE_DIR
76 |
77 | model = LlamaForCausalLM_KIVI.from_pretrained(
78 | pretrained_model_name_or_path='meta-llama/Llama-2-7b-hf',
79 | config=config,
80 | cache_dir=CACHE_DIR,
81 | torch_dtype=torch.float16,
82 | low_cpu_mem_usage=True,
83 | device_map="auto",
84 | )
85 |
86 | tokenizer = AutoTokenizer.from_pretrained(
87 | 'meta-llama/Llama-2-7b-hf',
88 | use_fast=False,
89 | trust_remote_code=True,
90 | tokenizer_type='llama')
91 |
92 | # Inference
93 | # e.g., model.generate(...)
94 | ```
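A minimal generation call following the snippet above (the prompt text and `max_new_tokens` value are arbitrary):

```python
prompt = "What is the capital of France? Answer:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
# Decode only the newly generated tokens.
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```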
95 |
96 | #### GSM8K example
97 | We use GSM8K as an example to show how to use KIVI. You can check [example.py](./example.py):
98 |
99 | ```bash
100 | python example.py
101 | ```
102 |
103 | #### Passkey retrieval example
104 |
105 | Passkey retrieval with KIVI. You can check [long_context_example.py](./long_context_example.py):
106 |
107 | ```bash
108 | python long_context_example.py
109 | ```
110 |
111 | #### Evaluate KIVI on LongBench
112 |
113 | We currently support the Llama and Mistral model families. We recently tested KIVI on Mistral-7B-Instruct-v0.2 and LongChat-7b-v1.5-32k. Please check [long_bench.md](./docs/long_bench.md) for more details.
114 | ```bash
115 | bash scripts/long_test.sh {GPU_ID} {K_BITS} {V_BITS} {GROUP_LENGTH} {RESIDUAL_LENGTH} {MODEL_NAME}
116 | python eval_long_bench.py --model {MODEL} # MODEL is the directory name under pred/. Currently the Llama and Mistral model families are supported.
117 | ```
118 |
119 | ## Citation
120 |
121 | If you find our method useful, please kindly cite our paper.
122 |
123 | ```bibtex
124 | @article{liu2024kivi,
125 | title={KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache},
126 | author={Liu, Zirui and Yuan, Jiayi and Jin, Hongye and Zhong, Shaochen and Xu, Zhaozhuo and Braverman, Vladimir and Chen, Beidi and Hu, Xia},
127 | journal={arXiv preprint arXiv:2402.02750},
128 | year={2024}
129 | }
130 | ```
131 |
132 | ## Contributing
133 | We welcome contributions from the research community to improve KIVI. If you have an idea or would like to report a bug, please open an issue or submit a pull request.
134 |
135 | ## License
136 | The code is released under the MIT License.
137 |
--------------------------------------------------------------------------------
/config/dataset2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": 128,
3 | "qasper": 128,
4 | "multifieldqa_en": 64,
5 | "multifieldqa_zh": 64,
6 | "hotpotqa": 32,
7 | "2wikimqa": 32,
8 | "musique": 32,
9 | "dureader": 128,
10 | "gov_report": 512,
11 | "qmsum": 512,
12 | "multi_news": 512,
13 | "vcsum": 512,
14 | "trec": 64,
15 | "triviaqa": 32,
16 | "samsum": 128,
17 | "lsht": 64,
18 | "passage_count": 32,
19 | "passage_retrieval_en": 32,
20 | "passage_retrieval_zh": 32,
21 | "lcc": 64,
22 | "repobench-p": 64
23 | }
--------------------------------------------------------------------------------
/config/dataset2prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
10 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
22 | "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
23 | }
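For illustration, a hedged sketch of how these prompt templates and the per-dataset generation caps in `config/dataset2maxlen.json` might be consumed; the `{context}`/`{input}` fields follow the LongBench data format, and the sample below is made up:

```python
import json

with open("config/dataset2prompt.json") as f:
    dataset2prompt = json.load(f)
with open("config/dataset2maxlen.json") as f:
    dataset2maxlen = json.load(f)

dataset = "qasper"
sample = {"context": "The article text goes here.", "input": "What dataset was used?"}
prompt = dataset2prompt[dataset].format(**sample)  # fill the template
max_gen = dataset2maxlen[dataset]                  # cap on newly generated tokens
```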
--------------------------------------------------------------------------------
/config/model2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama2-7b-chat-4k": 3500,
3 | "longchat-v1.5-7b-32k": 31500,
4 | "xgen-7b-8k": 7500,
5 | "internlm-7b-8k": 7500,
6 | "chatglm2-6b": 31500,
7 | "chatglm2-6b-32k": 31500,
8 | "vicuna-v1.5-7b-16k": 15500,
9 | "LLaMA-2-7B-32K": 7500,
10 | "llama-7b": 4096,
11 | "Llama-2-7b-chat-hf": 4096,
12 | "Llama-2-7b-hf": 4096,
13 | "llama-13b": 4096,
14 | "Llama-2-13b-chat-hf": 4096,
15 | "Llama-2-13b-hf": 4096,
16 | "falcon-7b": 4096,
17 | "Mistral-7B-v0.1": 8192,
18 | "longchat-7b-v1.5-32k": 31500,
19 | "Mistral-7B-Instruct-v0.2": 31500
20 | }
--------------------------------------------------------------------------------
/config/model2path.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf",
3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k",
4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst",
5 | "internlm-7b-8k": "internlm/internlm-chat-7b-8k",
6 | "chatglm2-6b": "THUDM/chatglm2-6b",
7 | "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
8 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k"
9 | }
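A sketch of how `config/model2path.json` and `config/model2maxlen.json` might be used together: resolve the Hugging Face checkpoint and cap the prompt length (the exact truncation strategy for over-long prompts is left to the caller):

```python
import json
from transformers import AutoTokenizer

with open("config/model2path.json") as f:
    model2path = json.load(f)
with open("config/model2maxlen.json") as f:
    model2maxlen = json.load(f)

model_name = "longchat-v1.5-7b-32k"
tokenizer = AutoTokenizer.from_pretrained(model2path[model_name])
max_length = model2maxlen[model_name]
input_ids = tokenizer("a very long prompt ...",
                      truncation=True, max_length=max_length).input_ids
```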
--------------------------------------------------------------------------------
/docs/long_bench.md:
--------------------------------------------------------------------------------
1 | # More results on LongBench
2 |
3 | Based on these results, we recommend KIVI-2 for vanilla multi-head attention models: it maintains the performance of the full-precision model while offering the best efficiency. For multi-query or grouped-query attention, where the keys and values are already compressed, we recommend KIVI-4, which likewise maintains full-precision performance.
4 |
5 | ### Table 1: Performance of LongChat-7b-v1.5-32K
6 |
7 | The results of LongChat-7b-v1.5-32K on 15 tasks from LongBench. The model has a 32K context length. We use a group size of 32 and a residual length of 128 for both KIVI-2 and KIVI-4. The baseline is full precision.
8 |
9 | | Task | LongChat-7b-v1.5-32K | w./ KIVI-2 | w./ KIVI-4 |
10 | |------------------|:--------------------:|:------------:|:-------------:|
11 | | NarrativeQA | 20.65 | 20.79 | 20.49 |
12 | | Qasper | 29.42 | 28.69 | 28.90 |
13 | | MultiFieldQA | 43.15 | 41.02 | 43.24 |
14 | | HotpotQA | 33.05 | 32.91 | 33.07 |
15 | | MuSiQue | 14.66 | 13.82 | 14.66 |
16 | | 2WikiMultihopQA | 24.14 | 23.00 | 24.86 |
17 | | GovReport | 30.85 | 30.47 | 31.40 |
18 | | QMSum | 22.84 | 22.59 | 22.84 |
19 | | MultiNews | 26.55 | 26.28 | 26.52 |
20 | | LCC | 54.83 | 54.11 | 54.06 |
21 | | RepoBench-P | 58.94 | 57.62 | 58.77 |
22 | | TriviaQA | 83.99 | 83.19 | 83.88 |
23 | | SAMSum | 40.75 | 41.28 | 40.62 |
24 | | TRec | 66.50 | 66.50 | 67.00 |
25 | | PassageRetrieval | 30.50 | 32.25 | 31.50 |
26 | | **Average** | **38.72** | **38.30** | **38.79** |
27 |
28 | ### Table 2: Performance of Mistral-7B-Instruct-v0.2
29 |
30 | The results of Mistral-7B-Instruct-v0.2 on 15 tasks from LongBench. The model has a 32K context length and uses grouped-query attention with 8 KV heads instead of the full 32. We use a group size of 32 and a residual length of 128 for both KIVI-2 and KIVI-4. The baseline is full precision.
31 |
32 | | Task | Mistral-7B-Instruct-v0.2 | w./ KIVI-2 | w./ KIVI-4 |
33 | |------------------|:------------------------:|:------------:|:-------------:|
34 | | NarrativeQA | 21.02 | 20.61 | 20.97 |
35 | | Qasper | 29.41 | 28.73 | 29.41 |
36 | | MultiFieldQA | 47.13 | 44.88 | 46.52 |
37 | | HotpotQA | 36.53 | 35.47 | 36.25 |
38 | | MuSiQue | 19.13 | 17.95 | 19.53 |
39 | | 2WikiMultihopQA | 21.76 | 20.68 | 21.66 |
40 | | GovReport | 32.59 | 32.55 | 32.97 |
41 | | QMSum | 23.99 | 23.65 | 24.06 |
42 | | MultiNews | 27.09 | 26.54 | 26.89 |
43 | | LCC | 53.49 | 53.03 | 53.33 |
44 | | RepoBench-P | 51.40 | 51.16 | 51.41 |
45 | | TriviaQA | 86.23 | 86.00 | 86.23 |
46 | | SAMSum | 43.04 | 43.34 | 43.34 |
47 | | TRec | 71.00 | 71.00 | 71.00 |
48 | | PassageRetrieval | 89.33 | 80.83 | 89.42 |
49 | | **Average** | **43.54** | **42.43** | **43.53** |
50 |
--------------------------------------------------------------------------------
/eval_long_bench.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 |
6 | from metrics import (
7 | qa_f1_score,
8 | rouge_zh_score,
9 | qa_f1_zh_score,
10 | rouge_score,
11 | classification_score,
12 | retrieval_score,
13 | retrieval_zh_score,
14 | count_score,
15 | code_sim_score,
16 | )
17 |
18 | dataset2metric = {
19 | "narrativeqa": qa_f1_score,
20 | "qasper": qa_f1_score,
21 | "multifieldqa_en": qa_f1_score,
22 | "multifieldqa_zh": qa_f1_zh_score,
23 | "hotpotqa": qa_f1_score,
24 | "2wikimqa": qa_f1_score,
25 | "musique": qa_f1_score,
26 | "dureader": rouge_zh_score,
27 | "gov_report": rouge_score,
28 | "qmsum": rouge_score,
29 | "multi_news": rouge_score,
30 | "vcsum": rouge_zh_score,
31 | "trec": classification_score,
32 | "triviaqa": qa_f1_score,
33 | "samsum": rouge_score,
34 | "lsht": classification_score,
35 | "passage_retrieval_en": retrieval_score,
36 | "passage_count": count_score,
37 | "passage_retrieval_zh": retrieval_zh_score,
38 | "lcc": code_sim_score,
39 | "repobench-p": code_sim_score,
40 | }
41 |
42 | def parse_args(args=None):
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('--model', type=str, default=None)
45 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
46 | return parser.parse_args(args)
47 |
48 | def scorer_e(dataset, predictions, answers, lengths, all_classes):
49 | scores = {"0-4k": [], "4-8k": [], "8k+": []}
50 | for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
51 | score = 0.
52 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
53 | prediction = prediction.lstrip('\n').split('\n')[0]
54 | for ground_truth in ground_truths:
55 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
56 | if length < 4000:
57 | scores["0-4k"].append(score)
58 | elif length < 8000:
59 | scores["4-8k"].append(score)
60 | else:
61 | scores["8k+"].append(score)
62 | for key in scores.keys():
63 | scores[key] = round(100 * np.mean(scores[key]), 2)
64 | return scores
65 |
66 | def scorer(dataset, predictions, answers, all_classes):
67 | total_score = 0.
68 | for (prediction, ground_truths) in zip(predictions, answers):
69 | score = 0.
70 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
71 | prediction = prediction.lstrip('\n').split('\n')[0]
72 | for ground_truth in ground_truths:
73 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
74 | total_score += score
75 | return round(100 * total_score / len(predictions), 2)
76 |
77 | if __name__ == '__main__':
78 | args = parse_args()
79 | scores = dict()
80 | if args.e:
81 | path = f"pred_e/{args.model}/"
82 | else:
83 | path = f"pred/{args.model}/"
84 | all_files = os.listdir(path)
85 | print("Evaluating on:", all_files)
86 | for filename in all_files:
87 | if not filename.endswith("jsonl"):
88 | continue
89 | predictions, answers, lengths = [], [], []
90 | dataset = filename.split('.')[0]
91 | with open(f"{path}{filename}", "r", encoding="utf-8") as f:
92 | for line in f:
93 | data = json.loads(line)
94 | predictions.append(data["pred"])
95 | answers.append(data["answers"])
96 | all_classes = data["all_classes"]
97 | if "length" in data:
98 | lengths.append(data["length"])
99 | if args.e:
100 | score = scorer_e(dataset, predictions, answers, lengths, all_classes)
101 | else:
102 | score = scorer(dataset, predictions, answers, all_classes)
103 | scores[dataset] = score
104 | if args.e:
105 | out_path = f"pred_e/{args.model}/result.json"
106 | else:
107 | out_path = f"pred/{args.model}/result.json"
108 | with open(out_path, "w") as f:
109 | json.dump(scores, f, ensure_ascii=False, indent=4)
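For reference, the `scorer` above can be exercised directly on toy data (made-up inputs, not LongBench predictions); a minimal sketch:

```python
# One exact match ("Paris") and one miss ("Berlin" vs. "Munich") -> 50.0
preds = ["Paris", "Berlin"]
answers = [["Paris"], ["Munich"]]
print(scorer("narrativeqa", preds, answers, all_classes=None))
```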
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | # LLaMA model with KIVI
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import torch
5 | import random
6 | from models.llama_kivi import LlamaForCausalLM_KIVI
7 | from transformers import LlamaConfig, AutoTokenizer
8 | from datasets import load_dataset
9 |
10 | # For reproducibility
11 | random.seed(0)
12 | torch.manual_seed(0)
13 |
14 | config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
15 |
16 | config.k_bits = 2 # KIVI currently supports 2/4-bit K/V cache
17 | config.v_bits = 2
18 | config.group_size = 32
19 | config.residual_length = 32 # corresponding to the number of recent fp16 tokens
20 | config.use_flash = True
21 |
22 | model = LlamaForCausalLM_KIVI.from_pretrained(
23 | # pretrained_model_name_or_path='meta-llama/Llama-2-7b-hf',
24 | pretrained_model_name_or_path='meta-llama/Llama-3.1-8B-Instruct',
25 | config=config,
26 | low_cpu_mem_usage=True,
27 | torch_dtype=torch.float16,
28 | ).cuda()
29 |
30 | enc = AutoTokenizer.from_pretrained(
31 | 'meta-llama/Llama-3.1-8B-Instruct',
32 | use_fast=False,
33 | trust_remote_code=True)
34 |
35 | dataset = load_dataset('gsm8k', 'main')
36 |
37 | prompt = ''
38 | for i in range(5):
39 | prompt += 'Question: ' + dataset['train'][i]['question'] + '\nAnswer: ' + dataset['train'][i]['answer'] + '\n'
40 | prompt += "Question: John takes care of 10 dogs. Each dog takes .5 hours a day to walk and take care of their business. How many hours a week does he spend taking care of dogs?"
41 | inputs = enc(prompt, return_tensors="pt").input_ids.cuda()
42 |
43 | output = model.generate(inputs, max_new_tokens=96)
44 | config_str = f"# prompt tokens: {inputs.shape[1]}, k_bits: {config.k_bits}, v_bits: {config.v_bits}, group_size: {config.group_size}, residual_length: {config.residual_length}"
45 |
46 | print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nKiVi Output:")
47 | print(enc.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True))
--------------------------------------------------------------------------------
/img/algo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/img/algo.png
--------------------------------------------------------------------------------
/img/quant_scheme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/img/quant_scheme.png
--------------------------------------------------------------------------------
/long_context_example.py:
--------------------------------------------------------------------------------
1 | # LLaMA model with KIVI
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import torch
5 | import json
6 | from models.llama_kivi import LlamaForCausalLM_KIVI
7 | from transformers import LlamaConfig, AutoTokenizer
8 | from datasets import load_dataset
9 |
10 | config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
11 | config.k_bits = 2 # KIVI currently supports 2/4-bit K/V cache
12 | config.v_bits = 2
13 | config.group_size = 32
14 | config.residual_length = 32 # corresponding to the number of recent fp16 tokens
15 | config.use_flash = True # use flash-attention with KiVi for long context inference
16 | CACHE_DIR = "/scratch/cached_model"
17 |
18 | model = LlamaForCausalLM_KIVI.from_pretrained(
19 | pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
20 | config=config,
21 | # cache_dir=CACHE_DIR,
22 | low_cpu_mem_usage=True,
23 | torch_dtype=torch.float16,
24 | ).cuda()
25 |
26 | enc = AutoTokenizer.from_pretrained(
27 | "meta-llama/Llama-3.1-8B-Instruct",
28 | use_fast=False,
29 | trust_remote_code=True,)
30 |
31 | model.eval()
32 | file_name = "passkey_examples.jsonl"
33 | method_name = f"K{config.k_bits}V{config.v_bits} KiVi"
34 | print("=========="*2 + f"**{method_name}**" + "=========="*2)
35 | for line in open(file_name, "r"):
36 | example = json.loads(line)
37 | prompt_postfix = "What is the pass key? The pass key is "
38 | prompt = example["input"] + prompt_postfix
39 | input_ids = enc(prompt, return_tensors="pt").input_ids.cuda()
40 | print( "-----------------------------------" )
41 | print( f"#Tokens of Prompt:", input_ids.shape[1], end=" " )
42 | print( "Passkey target:", example["target"] )
43 |
44 | tokens = model.generate(input_ids, max_new_tokens=len(example["target"]))
45 | answer = prompt_postfix + enc.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
46 | answer = answer.replace("\n", "\\n")
47 | answer= f"{method_name}:\n [ {answer} ]"
48 | print( answer )
49 | print( "-----------------------------------\n" )
50 |
--------------------------------------------------------------------------------
/mem_spd_test.py:
--------------------------------------------------------------------------------
1 | # LLaMA model with KIVI
2 | import torch
3 | import os
4 | from models.llama_kivi import LlamaForCausalLM_KIVI
5 | from transformers import LlamaConfig, AutoTokenizer
6 | import time
7 |
8 | K_BITS = 2
9 | V_BITS = 2
10 | GROUP_SIZE = 32
11 | RESIDUAL_LENGTH = 128
12 | BATCH_SIZE = 96
13 | PATH_TO_YOUR_SAVE_DIR = './cached_models'
14 |
15 | model_name_or_path = 'meta-llama/Llama-2-7b-hf'
16 | config = LlamaConfig.from_pretrained(model_name_or_path)
17 | config.k_bits = K_BITS # current support 2/4 bit for KV Cache
18 | config.v_bits = V_BITS # current support 2/4 bit for KV Cache
19 | config.group_size = GROUP_SIZE
20 | config.residual_length = RESIDUAL_LENGTH # the number of recent fp16 tokens
21 | CACHE_DIR = PATH_TO_YOUR_SAVE_DIR
22 |
23 | if K_BITS < 16 and V_BITS < 16:
24 | model = LlamaForCausalLM_KIVI.from_pretrained(
25 | pretrained_model_name_or_path=model_name_or_path,
26 | config=config,
27 | cache_dir=CACHE_DIR,
28 | torch_dtype=torch.float16,
29 | low_cpu_mem_usage=True,
30 | device_map="auto",
31 | )
32 | else:
33 | from transformers import LlamaForCausalLM
34 | model = LlamaForCausalLM.from_pretrained(
35 | pretrained_model_name_or_path=model_name_or_path,
36 | config=config,
37 | cache_dir=CACHE_DIR,
38 | torch_dtype=torch.float16,
39 | low_cpu_mem_usage=True,
40 | device_map="auto",
41 | )
42 |
43 | tokenizer = AutoTokenizer.from_pretrained(
44 | model_name_or_path,
45 | use_fast=False,
46 | trust_remote_code=True,
47 | tokenizer_type='llama')
48 |
49 | model.cuda().eval()
50 |
51 | context = []
52 | batch_size = BATCH_SIZE
53 | prompt_length = 160
54 | output_length = 338
55 | num_repeats = 3
56 | for _ in range(batch_size):
57 |     string = 't,' * (prompt_length // 2)
58 | context.append(string[:-1])
59 | inputs = tokenizer(context, return_tensors="pt").to('cuda')
60 | input_ids = inputs['input_ids']
61 | print(f"bs: {batch_size}, seqlen: {input_ids.shape[1]}+{output_length}\nmodel:{model_name_or_path}")
62 | torch.cuda.reset_peak_memory_stats()
63 | with torch.no_grad():
64 | torch.cuda.synchronize()
65 | st = time.time()
66 | for i in range(num_repeats):
67 | outputs = model.generate(**inputs, max_new_tokens=output_length)
68 | torch.cuda.synchronize()
69 | print(f'used time: {(time.time() - st) / num_repeats * 1000} ms')
70 | used_mem = torch.cuda.max_memory_allocated()
71 | print(f'peak mem: {used_mem / 1024 ** 3} GB')
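A throughput figure makes the timing above easier to compare across batch sizes and bit widths; a minimal sketch, assuming it is appended right after the timing loop in the script above:

```python
# `st`, `num_repeats`, `batch_size`, and `output_length` are defined in the script above.
elapsed_s = (time.time() - st) / num_repeats   # average seconds per generate() call
tokens_per_run = batch_size * output_length    # new tokens produced per call
print(f'throughput: {tokens_per_run / elapsed_s:.1f} tokens/s')
```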
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | import jieba
5 | from fuzzywuzzy import fuzz
6 | import difflib
7 |
8 | from typing import List
9 | from collections import Counter
10 | from rouge import Rouge
11 |
12 | def normalize_answer(s):
13 | """Lower text and remove punctuation, articles and extra whitespace."""
14 |
15 | def remove_articles(text):
16 | return re.sub(r"\b(a|an|the)\b", " ", text)
17 |
18 | def white_space_fix(text):
19 | return " ".join(text.split())
20 |
21 | def remove_punc(text):
22 | exclude = set(string.punctuation)
23 | return "".join(ch for ch in text if ch not in exclude)
24 |
25 | def lower(text):
26 | return text.lower()
27 |
28 | return white_space_fix(remove_articles(remove_punc(lower(s))))
29 |
30 |
31 | def normalize_zh_answer(s):
32 | """Lower text and remove punctuation, extra whitespace."""
33 |
34 | def white_space_fix(text):
35 | return "".join(text.split())
36 |
37 | def remove_punc(text):
38 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
39 | all_punctuation = set(string.punctuation + cn_punctuation)
40 | return "".join(ch for ch in text if ch not in all_punctuation)
41 |
42 | def lower(text):
43 | return text.lower()
44 |
45 | return white_space_fix(remove_punc(lower(s)))
46 |
47 | def count_score(prediction, ground_truth, **kwargs):
48 | numbers = re.findall(r"\d+", prediction)
49 | right_num = 0
50 | for number in numbers:
51 | if str(number) == str(ground_truth):
52 | right_num += 1
53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
54 | return float(final_score)
55 |
56 | def retrieval_score(prediction, ground_truth, **kwargs):
57 | pattern = r'Paragraph (\d+)'
58 | matches = re.findall(pattern, ground_truth)
59 | ground_truth_id = matches[0]
60 | numbers = re.findall(r"\d+", prediction)
61 | right_num = 0
62 | for number in numbers:
63 | if str(number) == str(ground_truth_id):
64 | right_num += 1
65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
66 | return float(final_score)
67 |
68 | def retrieval_zh_score(prediction, ground_truth, **kwargs):
69 | pattern = r'段落(\d+)'
70 | matches = re.findall(pattern, ground_truth)
71 | ground_truth_id = matches[0]
72 | numbers = re.findall(r"\d+", prediction)
73 | right_num = 0
74 | for number in numbers:
75 | if str(number) == str(ground_truth_id):
76 | right_num += 1
77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
78 | return float(final_score)
79 |
80 | def code_sim_score(prediction, ground_truth, **kwargs):
81 | all_lines = prediction.lstrip('\n').split('\n')
82 | prediction = ""
83 | for line in all_lines:
84 | if ('`' not in line) and ('#' not in line) and ('//' not in line):
85 | prediction = line
86 | break
87 | return (fuzz.ratio(prediction, ground_truth) / 100)
88 |
89 | def classification_score(prediction, ground_truth, **kwargs):
90 | em_match_list = []
91 | all_classes = kwargs["all_classes"]
92 | for class_name in all_classes:
93 | if class_name in prediction:
94 | em_match_list.append(class_name)
95 | for match_term in em_match_list:
96 | if match_term in ground_truth and match_term != ground_truth:
97 | em_match_list.remove(match_term)
98 |     if len(em_match_list) != 0:
99 | if ground_truth in em_match_list:
100 | score = (1.0 / len(em_match_list))
101 | else:
102 | score = 0.0
103 | else:
104 | best_match = None
105 | highest_similarity = 0
106 | for string in all_classes:
107 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
108 | if similarity > highest_similarity:
109 | highest_similarity = similarity
110 | best_match = string
111 | score = float(best_match == ground_truth)
112 | return score
113 |
114 | def rouge_score(prediction, ground_truth, **kwargs):
115 | rouge = Rouge()
116 | try:
117 | scores = rouge.get_scores([prediction], [ground_truth], avg=True)
118 | except:
119 | return 0.0
120 | return scores["rouge-l"]["f"]
121 |
122 | def rouge_zh_score(prediction, ground_truth, **kwargs):
123 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
124 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
125 | score = rouge_score(prediction, ground_truth)
126 | return score
127 |
128 | def f1_score(prediction, ground_truth, **kwargs):
129 | common = Counter(prediction) & Counter(ground_truth)
130 | num_same = sum(common.values())
131 | if num_same == 0:
132 | return 0
133 | precision = 1.0 * num_same / len(prediction)
134 | recall = 1.0 * num_same / len(ground_truth)
135 | f1 = (2 * precision * recall) / (precision + recall)
136 | return f1
137 |
138 | def qa_f1_score(prediction, ground_truth, **kwargs):
139 | normalized_prediction = normalize_answer(prediction)
140 | normalized_ground_truth = normalize_answer(ground_truth)
141 |
142 | prediction_tokens = normalized_prediction.split()
143 | ground_truth_tokens = normalized_ground_truth.split()
144 | return f1_score(prediction_tokens, ground_truth_tokens)
145 |
146 |
147 | def qa_f1_zh_score(prediction, ground_truth, **kwargs):
148 | prediction_tokens = list(jieba.cut(prediction, cut_all=False))
149 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
150 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
151 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
152 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
153 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
154 | return f1_score(prediction_tokens, ground_truth_tokens)
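A few toy calls can sanity-check the scorers above; a minimal sketch with made-up inputs (not LongBench data):

```python
# QA F1: normalization lowercases and drops articles, so these match exactly.
print(qa_f1_score("The Eiffel Tower", "eiffel tower"))                  # -> 1.0
# Count score: the single number in the prediction equals the ground truth.
print(count_score("I count 3 unique paragraphs", "3"))                  # -> 1.0
# Classification: exactly one known class name appears in the prediction.
print(classification_score("This is about Sports", "Sports",
                           all_classes=["Sports", "Politics"]))         # -> 1.0
```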
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/models/__init__.py
--------------------------------------------------------------------------------
/models/utils_quant.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | # 2023.07.05 - Modified weight quantization
9 | # Meta Platforms, Inc.
10 | #
11 | # Copyright 2021 Huawei Technologies Co., Ltd.
12 | #
13 | # Licensed under the Apache License, Version 2.0 (the "License");
14 | # you may not use this file except in compliance with the License.
15 | # You may obtain a copy of the License at
16 | #
17 | # http://www.apache.org/licenses/LICENSE-2.0
18 | #
19 | # Unless required by applicable law or agreed to in writing, software
20 | # distributed under the License is distributed on an "AS IS" BASIS,
21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 | # See the License for the specific language governing permissions and
23 | # limitations under the License.
24 |
25 | import math
26 |
27 | import torch
28 | import torch.nn.functional as F
29 | import torch.nn as nn
30 | import numpy as np
31 |
32 |
33 | class SymQuantizer(torch.autograd.Function):
34 | """
35 | uniform quantization
36 | """
37 |
38 | @staticmethod
39 | def forward(ctx, input, clip_val, num_bits, layerwise):
40 | """
41 | :param ctx:
42 | :param input: tensor to be quantized
43 | :param clip_val: clip the tensor before quantization
44 |         :param num_bits: number of bits
45 | :return: quantized tensor
46 | """
47 | ctx.save_for_backward(input, clip_val)
48 | # input = torch.clamp(input, clip_val[0], clip_val[1])
49 | # input = torch.where(input < clip_val[1], input, clip_val[1])
50 | # input = torch.where(input > clip_val[0], input, clip_val[0])
51 | # NOTE: dynamic scaling (max_input).
52 | if layerwise:
53 | max_input = torch.max(torch.abs(input)).expand_as(input)
54 | else:
55 | if input.ndimension() <= 3:
56 | # weight & hidden layer
57 | max_input = (
58 | torch.max(torch.abs(input), dim=-1, keepdim=True)[0]
59 | .expand_as(input)
60 | .detach()
61 | )
62 | elif input.ndimension() == 4:
63 | # TODO: attention score matrix, calculate alpha / beta per head
64 | tmp = input.view(input.shape[0], input.shape[1], -1)
65 | max_input = (
66 | torch.max(torch.abs(tmp), dim=-1, keepdim=True)[0]
67 | .unsqueeze(-1)
68 | .expand_as(input)
69 | .detach()
70 | )
71 | else:
72 | raise ValueError
73 | s = (2 ** (num_bits - 1) - 1) / (max_input + 1e-6)
74 | output = torch.round(input * s).div(s + 1e-6)
75 |
76 | return output
77 |
78 | @staticmethod
79 | def backward(ctx, grad_output):
80 | """
81 | :param ctx: saved non-clipped full-precision tensor and clip_val
82 |         :param grad_output: gradient wrt the quantized tensor
83 | :return: estimated gradient wrt the full-precision tensor
84 | """
85 | input, clip_val = ctx.saved_tensors # unclipped input
86 | grad_input = grad_output.clone()
87 | grad_input[input.ge(clip_val[1])] = 0
88 | grad_input[input.le(clip_val[0])] = 0
89 | return grad_input, None, None, None
90 |
91 |
92 | class AsymQuantizer(torch.autograd.Function):
93 | """
94 | min-max quantization
95 | """
96 |
97 | @staticmethod
98 | def forward(ctx, input, clip_val, num_bits, layerwise):
99 | """
100 | :param ctx:
101 | :param input: tensor to be quantized
102 | :param clip_val: clip the tensor before quantization
103 |         :param num_bits: number of bits
104 | :return: quantized tensor
105 | """
106 | ctx.save_for_backward(input, clip_val)
107 |
108 | # input = torch.where(input < clip_val[1], input, clip_val[1])
109 | # input = torch.where(input > clip_val[0], input, clip_val[0])
110 | # input = torch.clamp(input, clip_val[0], clip_val[1])
111 | # NOTE: dynamic scaling gives better performance than static
112 | if layerwise:
113 | alpha = (input.max() - input.min()).detach()
114 | beta = input.min().detach()
115 | else:
116 | if input.ndimension() <= 3:
117 | # weight & hidden layer
118 | alpha = (
119 | (
120 | input.max(dim=-1, keepdim=True)[0]
121 | - input.min(dim=-1, keepdim=True)[0]
122 | )
123 | .expand_as(input)
124 | .detach()
125 | )
126 | beta = input.min(dim=-1, keepdim=True)[0].expand_as(input).detach()
127 | elif input.ndimension() == 4:
128 | # TODO: attention score matrix, calculate alpha / beta per head
129 | tmp = input.view(input.shape[0], input.shape[1], -1)
130 | alpha = (
131 | (
132 | tmp.max(dim=-1, keepdim=True)[0].unsqueeze(-1)
133 | - tmp.min(dim=-1, keepdim=True)[0].unsqueeze(-1)
134 | )
135 | .expand_as(input)
136 | .detach()
137 | )
138 | beta = (
139 | tmp.min(dim=-1, keepdim=True)[0]
140 | .unsqueeze(-1)
141 | .expand_as(input)
142 | .detach()
143 | )
144 | else:
145 | raise ValueError
146 | input_normalized = (input - beta) / (alpha + 1e-8)
147 | s = 2**num_bits - 1
148 | quant_input = torch.round(input_normalized * s).div(s)
149 | output = quant_input * (alpha + 1e-8) + beta
150 |
151 | return output
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | """
156 | :param ctx: saved non-clipped full-precision tensor and clip_val
157 |         :param grad_output: gradient wrt the quantized tensor
158 | :return: estimated gradient wrt the full-precision tensor
159 | """
160 | input, clip_val = ctx.saved_tensors # unclipped input
161 | grad_input = grad_output.clone()
162 | grad_input[input.ge(clip_val[1])] = 0
163 | grad_input[input.le(clip_val[0])] = 0
164 | return grad_input, None, None, None
165 |
166 |
167 | class AsymGroupedQuantizer(torch.autograd.Function):
168 | @staticmethod
169 | def forward(ctx, input, clip_val, num_bits, group_size, prec_map_indices=None):
170 | ctx.save_for_backward(input, clip_val)
171 | # input = torch.clamp(input, clip_val[0], clip_val[1])
172 | # input = torch.where(input < clip_val[1], input, clip_val[1])
173 | # input = torch.where(input > clip_val[0], input, clip_val[0])
174 | # NOTE: dynamic scaling (max_input).
175 |
176 | bs, seqlen, d = input.shape
177 | num_groups = d // group_size
178 | if num_groups * group_size != input.shape[-1]:
179 | raise ValueError("group_size should be a factor of the last dimension size")
180 |
181 |
182 | input_in_groups = input.view(bs, seqlen, num_groups, group_size)
183 |
184 | #####
185 | # input_in_groups_cpy = input_in_groups.clone().detach()
186 | #####
187 |
188 | mx, mn = input_in_groups.max(dim=-1)[0], input_in_groups.min(dim=-1)[0]
189 | mx, mn = mx.unsqueeze(-1), mn.unsqueeze(-1)
190 |
191 | scale = (mx - mn) / (2 ** num_bits - 1)
192 | input_in_groups = (input_in_groups - mn) / scale
193 | input_in_groups = F.relu(input_in_groups)
194 | rounded_input_in_groups = input_in_groups.round_()
195 | dequantized_input_in_groups = rounded_input_in_groups * scale + mn
196 |
197 | #####
198 | # if prec_map_indices is not None:
199 | # _, num_heads, _ = prec_map_indices.shape
200 | # for i in range(bs):
201 | # for j in range(num_heads):
202 | # for k in prec_map_indices[i, j]:
203 | # dequantized_input_in_groups[i, k, j, :] = input_in_groups_cpy[i, k, j, :]
204 | #####
205 |
206 | dequantized_input = dequantized_input_in_groups.view(bs, seqlen, -1)
207 | return dequantized_input
208 |
209 | @staticmethod
210 | def backward(ctx, grad_output):
211 | input, clip_val = ctx.saved_tensors
212 | grad_input = grad_output
213 |
214 | # clip version
215 | # grad_input[input.ge(clip_val[1])] = 0
216 | # grad_input[input.le(clip_val[0])] = 0
217 | return grad_input, None, None, None, None
218 |
219 | ### group by channel
220 | class AsymGroupedQuantizerByChannel(torch.autograd.Function):
221 | @staticmethod
222 | def forward(ctx, input, clip_val, num_bits, group_size, prec_map_indices=None):
223 | ctx.save_for_backward(input, clip_val)
224 | # input = torch.clamp(input, clip_val[0], clip_val[1])
225 | # input = torch.where(input < clip_val[1], input, clip_val[1])
226 | # input = torch.where(input > clip_val[0], input, clip_val[0])
227 | bs, seqlen, d = input.shape
228 | mx, mn = input.max(dim=-2)[0], input.min(dim=-2)[0]
229 | mx, mn = mx.unsqueeze(-2), mn.unsqueeze(-2)
230 | scale = (mx - mn) / (2 ** num_bits - 1)
231 | input = (input - mn) / scale
232 | input = F.relu(input)
233 | rounded_input = input.round_()
234 | dequantized_input = rounded_input * scale + mn
235 |
236 | assert dequantized_input.shape == input.shape
237 |
238 | return dequantized_input
239 |
240 | @staticmethod
241 | def backward(ctx, grad_output):
242 | input, clip_val = ctx.saved_tensors
243 | grad_input = grad_output
244 |
245 | # clip version
246 | # grad_input[input.ge(clip_val[1])] = 0
247 | # grad_input[input.le(clip_val[0])] = 0
248 | return grad_input, None, None, None, None
249 |
250 | class QuantizeLinear(nn.Linear):
251 | def __init__(
252 | self,
253 | *kargs,
254 | symmetric=True,
255 | bias=False,
256 | w_bits=32,
257 | a_bits=32,
258 | act_layerwise=False,
259 | weight_layerwise=False,
260 | ):
261 | super(QuantizeLinear, self).__init__(*kargs, bias=False)
262 | self.w_bits = w_bits
263 | self.a_bits = a_bits
264 | self.act_layerwise = act_layerwise
265 | self.weight_layerwise = weight_layerwise
266 | # params for weight quant
267 | # if self.w_bits < 32:
268 | # self.weight_clip_val = Parameter(torch.tensor([-2.0, 2.0]), requires_grad=False)
269 | if self.a_bits < 32 and self.a_bits > 2:
270 | if symmetric:
271 | self.act_quantizer = SymQuantizer
272 | else:
273 | self.act_quantizer = AsymQuantizer
274 |
275 | def forward(self, input_):
276 | # quantize weight
277 | assert len(self.weight.size()) == 2
278 | real_weights = self.weight
279 |
280 | if self.w_bits >= 32:
281 | weight = self.weight
282 | elif self.w_bits >= 3:
283 | weight_clip_val = torch.tensor([-2.0, 2.0])
284 | weight = SymQuantizer.apply(
285 | real_weights, weight_clip_val, self.w_bits, self.weight_layerwise
286 | )
287 | else:
288 | if self.w_bits == 1:
289 | if self.weight_layerwise:
290 | scaling_factor = torch.mean(abs(real_weights)).detach()
291 | else:
292 | scaling_factor = torch.mean(
293 | abs(real_weights), dim=1, keepdim=True
294 | ).detach()
295 | quan_weights_no_grad = scaling_factor * (
296 | torch.sign(real_weights / scaling_factor)
297 | )
298 | # elif self.w_bits == 2:
299 | # scaling_factor = 4/3 * torch.mean(abs(real_weights), dim=1, keepdim=True).detach()
300 | # quan_weights_no_grad = scaling_factor * (torch.round(torch.clamp(real_weights/scaling_factor, -1, 1)))
301 | else:
302 | num_bits = 2 ** (self.w_bits - 1)
303 | clip_val = 1 - 1e-2
304 | if self.weight_layerwise:
305 | scaling_factor = 2 * torch.mean(abs(real_weights)).detach()
306 | else:
307 | scaling_factor = (
308 | 2 * torch.mean(abs(real_weights), dim=1, keepdim=True).detach()
309 | )
310 | quan_weights_no_grad = (
311 | scaling_factor
312 | * (
313 | torch.round(
314 | torch.clamp(
315 | real_weights / scaling_factor, -clip_val, clip_val
316 | )
317 | * num_bits
318 | - 0.5
319 | )
320 | + 0.5
321 | )
322 | / num_bits
323 | )
324 |
325 | weight = (
326 | quan_weights_no_grad.detach() - real_weights.detach() + real_weights
327 | )
328 | # Quantize inputs
329 | if self.a_bits < 32 and self.a_bits > 2:
330 | act_clip_val = torch.tensor([-2.0, 2.0])
331 | input_ = self.act_quantizer.apply(
332 | input_, act_clip_val, self.a_bits, self.act_layerwise
333 | )
334 |
335 | out = nn.functional.linear(input_, weight)
336 | if self.bias is not None:
337 | out += self.bias.view(1, -1).expand_as(out)
338 |
339 | return out
340 |
341 |
342 | def test_group_quantize():
343 | input = torch.randn((4, 16, 1024), dtype=torch.float16, device='cuda')
344 | clip_val = torch.tensor([-2.0, 2.0])
345 | for num_bits, group_size in [ (2, 64), (4, 64), (8, 64), \
346 | (2, 128), (4, 128), (8, 128), \
347 | (2, 256), (4, 256), (8, 256)]:
348 | output = AsymGroupedQuantizer.apply(input, clip_val, num_bits, group_size)
349 | err = torch.mean(torch.abs(input - output)).item()
350 | print(num_bits, group_size, err)
351 | # print(input[0,0,100:150])
352 | # print(output[0,0,100:150])
353 |
354 |
355 | def process_input(input, group_size):
356 | N = input.shape[0]
357 | input_flatten = input.reshape(N, -1)
358 | num_features = input_flatten.shape[1]
359 |
360 | # Compute min, max by groups
361 | if num_features % group_size != 0:
362 | # Padding
363 | new_num_features = (num_features // group_size + 1) * group_size
364 | delta = new_num_features - num_features
365 | input_flatten = torch.cat([input_flatten,
366 | torch.zeros([N, delta], dtype=input.dtype, device=input.device)], 1)
367 |
368 | input_groups = input_flatten.reshape(-1, group_size)
369 | mn, mx = torch.min(input_groups, 1)[0], torch.max(input_groups, 1)[0]
370 | return input_groups.view(N, -1, group_size), mn.view(N, -1), mx.view(N, -1)
371 |
372 |
373 | def quantize_and_pack(data, group_size, bits, simulate=False):
374 | data, mn, mx = process_input(data, group_size)
375 | data = data.transpose(0, 1)
376 | mn = mn.t()
377 | mx = mx.t()
378 | if simulate:
379 | mn, mx = mn.unsqueeze(-1), mx.unsqueeze(-1)
380 | N = data.shape[0]
381 | output = data # N, groups, group_dim
382 | if isinstance(bits, int):
383 | bits = torch.ones(N, dtype=torch.int32, device='cuda') * bits
384 |
385 | B = (2 ** bits - 1).view(N, 1, 1)
386 | mn = mn - 1e-6
387 | mx = mx + 1e-6
388 | scale = B / (mx - mn) # N, groups, 1
389 | output = (output - mn) * scale
390 | output = F.relu(output)
391 | output = torch.min(output, B.float()).round_().int()
392 | else:
393 | # data.shape == B, ng, gz
394 | # mn.shape == B, ng
395 | # mx.shape == B, ng
396 | # import ipdb; ipdb.set_trace()
397 | output, scale = dequant_cuda.pack_single_precision(data, mn, mx, bits, False)
398 | scale = scale.squeeze(-1)
399 | return output, scale, mn
400 |
401 |
402 | def dequantize_and_unpack(data, group_size, shape, bits, scale, mn, simulate=False):
403 | if simulate:
404 | scale, mn = scale.unsqueeze(-1), mn.unsqueeze(-1)
405 | data = data / scale + mn
406 | else:
407 | # Pad to group_size
408 | N = shape[0]
409 | num_features = int(np.prod(shape[1:]))
410 | num_features = (num_features + (group_size - num_features % group_size) % group_size)
411 |
412 | # Unpack bitstream
413 | data = dequant_cuda.unpack_single_precision(data, bits, scale, mn, N, num_features // group_size, group_size)
414 | data = data.view(shape)
415 | return data
416 |
417 |
418 | def process_input_by_channel(input, group_size):
419 | num_features = input.shape[-1]
420 | # input_flatten: [num_feats, bs * seqlen]
421 | input_flatten = input.view(-1, num_features).transpose(0, 1)
422 | num_instances = input_flatten.shape[-1]
423 | # Compute min, max by groups
424 | if num_instances % group_size != 0:
425 | # Padding
426 | new_num_instances = (num_instances // group_size + 1) * group_size
427 | delta = new_num_instances - num_instances
428 | input_flatten = torch.cat([input_flatten,
429 | torch.zeros([num_features, delta], dtype=input.dtype, device=input.device)], 1)
430 | input_groups = input_flatten.reshape(-1, group_size)
431 | mn, mx = torch.min(input_groups, 1)[0], torch.max(input_groups, 1)[0]
432 | return input_groups.view(num_features, -1, group_size), mn.view(num_features, -1), mx.view(num_features, -1)
433 |
434 |
435 | def quantize_by_channel_and_pack(input, group_size, num_bits, simulate=False):
436 | assert len(input.shape) == 3
437 | shape = input.shape
438 | ori_num_instances = shape[0] * shape[1]
439 | input_groups, mn, mx = process_input_by_channel(input, group_size)
440 | if simulate:
441 | mn, mx = mn.unsqueeze(-1), mx.unsqueeze(-1)
442 | scale = (mx - mn) / (2 ** num_bits - 1)
443 | input_groups = (input_groups - mn) / scale
444 | input_groups = F.relu(input_groups)
445 | rounded_input = input_groups.round_()
446 | return rounded_input, scale, mn
447 | # dequantized_input = rounded_input * scale + mn
448 | # dequantized_input = dequantized_input.view(input.shape[-1], -1)
449 | # if ori_num_instances != dequantized_input.shape[1]:
450 | # dequantized_input = dequantized_input[:, 0:ori_num_instances]
451 | # dequantized_input = dequantized_input.transpose(0, 1).view(shape)
452 | # assert dequantized_input.shape == shape
453 | # return dequantized_input, scale, mn
454 | else:
455 | output, scale = dequant_cuda.pack_single_precision(input_groups, mn, mx, num_bits, False)
456 | assert len(scale.shape) >= 2 and len(mn.shape) >= 2
457 | if len(scale.shape) == 3:
458 | scale = scale.squeeze(-1)
459 | if len(mn.shape) == 3:
460 | mn = mn.squeeze(-1)
461 | return output, scale, mn
462 |
463 |
464 | def dequantize_by_channel_and_unpack(data, group_size, shape, bits, scale, mn, simulate=False):
465 | num_feats = shape[-1]
466 | ori_num_instances = shape[0] * shape[1]
467 | if simulate:
468 | # import ipdb; ipdb.set_trace()
469 | data = data * scale + mn
470 | else:
471 | # Pad to group_size
472 | tot_num_instances = (ori_num_instances + (group_size - ori_num_instances % group_size) % group_size)
473 |
474 | # Unpack bitstream
475 | data = dequant_cuda.unpack_single_precision(data, bits, scale, mn, num_feats, tot_num_instances // group_size, group_size)
476 | dequantized_input = data.view(shape[-1], -1)
477 | if ori_num_instances != dequantized_input.shape[1]:
478 | dequantized_input = dequantized_input[:, 0:ori_num_instances]
479 | data = dequantized_input.transpose(0, 1).view(shape)
480 | return data
481 |
482 |
483 | def cal_tensor_size(x):
484 | if isinstance(x, list):
485 | return np.sum([cal_tensor_size(x_) for x_ in x])
486 | elif isinstance(x, torch.Tensor):
487 | num_params = np.prod(x.shape)
488 | if x.dtype == torch.int32:
489 | return num_params * 4
490 | elif x.dtype in [torch.bfloat16, torch.float16]:
491 | return num_params * 2
492 | else:
493 | raise NotImplementedError
494 | else:
495 | raise NotImplementedError
496 |
497 |
498 | def quantize_by_channel_and_pack_cache(input, group_size, num_bits, simulate=False):
499 | ## convert kv_cache shape (bsz, head, seq_len, dim_head) to zirui shape (bsz, seq_len, dim)
500 | assert len(input.shape) == 4
501 | bsz, _, seq_len, _ = input.shape
502 | input = input.transpose(1, 2).reshape(bsz, seq_len, -1)
503 | ##
504 |
505 | shape = input.shape
506 | ori_num_instances = shape[0] * shape[1]
507 | input_groups, mn, mx = process_input_by_channel(input, group_size)
508 | if simulate:
509 | mn, mx = mn.unsqueeze(-1), mx.unsqueeze(-1)
510 | scale = (mx - mn) / (2 ** num_bits - 1)
511 | input_groups = (input_groups - mn) / scale
512 | input_groups = F.relu(input_groups)
513 | rounded_input = input_groups.round_()
514 | return rounded_input, scale, mn
515 | # dequantized_input = rounded_input * scale + mn
516 | # dequantized_input = dequantized_input.view(input.shape[-1], -1)
517 | # if ori_num_instances != dequantized_input.shape[1]:
518 | # dequantized_input = dequantized_input[:, 0:ori_num_instances]
519 | # dequantized_input = dequantized_input.transpose(0, 1).view(shape)
520 | # assert dequantized_input.shape == shape
521 | # return dequantized_input, scale, mn
522 | else:
523 | # import ipdb; ipdb.set_trace()
524 | output, scale = dequant_cuda.pack_single_precision(input_groups, mn, mx, num_bits, False)
525 | assert len(scale.shape) >= 2 and len(mn.shape) >= 2
526 | if len(scale.shape) == 3:
527 | scale = scale.squeeze(-1)
528 | if len(mn.shape) == 3:
529 | mn = mn.squeeze(-1)
530 | return output, scale, mn
531 |
532 |
533 | def dequantize_by_channel_and_unpack_cache(data, group_size, shape, bits, scale, mn, simulate=False):
534 | ## the input shape is not zirui shape (bsz, seq_len, dim), but kv_cache shape (bsz, head, seq_len, dim_head)
535 | # original variables
536 | # num_feats = shape[-1]
537 | # ori_num_instances = shape[0] * shape[1]
538 | assert len(shape) == 4
539 | num_feats = shape[1] * shape[3]
540 | ori_num_instances = shape[0] * shape[2]
541 | ##
542 |
543 | if simulate:
544 | # import ipdb; ipdb.set_trace()
545 | data = data * scale + mn
546 | else:
547 | # Pad to group_size
548 | tot_num_instances = (ori_num_instances + (group_size - ori_num_instances % group_size) % group_size)
549 |
550 | # Unpack bitstream
551 | data = dequant_cuda.unpack_single_precision(data, bits, scale, mn, num_feats, tot_num_instances // group_size, group_size)
552 | dequantized_input = data.view(num_feats, -1)
553 | if ori_num_instances != dequantized_input.shape[1]:
554 | dequantized_input = dequantized_input[:, 0:ori_num_instances]
555 |
556 | ## convert zirui shape (bsz, seq_len, dim) to kv_cache shape (bsz, head, seq_len, dim_head)
557 | # this is last step (this shape is zirui shape) data = dequantized_input.transpose(0, 1).view(shape)
558 | data = dequantized_input.transpose(0, 1).view(shape[0], -1, num_feats)
559 | data = data.view(shape[0], shape[2], shape[1], -1).transpose(1, 2)
560 | ##
561 | assert data.shape == shape
562 |
563 | return data
564 |
565 | def test_channel_quantize():
566 | input = torch.randn((112, 334, 4096), dtype=torch.float16, device='cuda')
567 | shape = input.shape
568 | # for num_bits, group_size in [ (2, 64), (4, 64), (8, 64), \
569 | # (2, 128), (4, 128), (8, 128), \
570 | # (2, 256), (4, 256), (8, 256)]:
571 | for num_bits, group_size in [(2, 128), (4, 128)]:
572 | # fake_code, scale, mn = quantize_by_channel_and_pack(input, group_size, num_bits, True)
573 | # output_fake = dequantize_by_channel_and_unpack(fake_code, group_size, shape, num_bits, scale, mn, True)
574 | # err = torch.mean(torch.abs(input - output_fake)).item()
575 | # print(num_bits, group_size, err)
576 | real_code, scale, mn = quantize_by_channel_and_pack(input, group_size, num_bits, False)
577 | output_real = dequantize_by_channel_and_unpack(real_code, group_size, shape, num_bits, scale, mn, False)
578 | err = torch.mean(torch.abs(input - output_real)).item()
579 | print(num_bits, group_size, err)
580 |
581 | def test_quantize():
582 | input = torch.randn((1, 32, 340, 128), dtype=torch.float16, device='cuda')
583 | shape = input.shape
584 | quantized_v, scale, mn = quantize_and_pack(input, 128, 4, False)
585 | dequantized_v = dequantize_and_unpack(quantized_v, 128, shape, 4, scale, mn, False)
586 |
587 | quantized_v, scale, mn = quantize_by_channel_and_pack_cache(input, 128, 4, False)
588 | dequantized_v = dequantize_by_channel_and_unpack_cache(quantized_v, 128, shape, 4, scale, mn, False)
589 |
590 |
591 | if __name__ == '__main__':
592 | # test_group_quantize()
593 | # test_channel_quantize()
594 | test_quantize()
--------------------------------------------------------------------------------
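The `simulate=True` branch of `quantize_by_channel_and_pack_cache` above is plain asymmetric fake quantization: each group is mapped onto `[0, 2**num_bits - 1]` with `scale = (max - min) / (2**num_bits - 1)` and recovered as `q * scale + mn`. The sketch below reproduces that round-trip in PyTorch for clarity; it groups along the last dimension for simplicity (the cache path above groups per channel via `process_input_by_channel`), and the helper name and epsilon guard are illustrative additions, not the repo's API.

```python
import torch

def fake_quant_groupwise(x: torch.Tensor, group_size: int, num_bits: int) -> torch.Tensor:
    """Asymmetric fake-quantization round-trip, grouping along the last dimension."""
    orig_shape = x.shape
    g = x.reshape(-1, group_size)                                 # (num_groups, group_size)
    mn = g.min(dim=-1, keepdim=True).values
    mx = g.max(dim=-1, keepdim=True).values
    scale = ((mx - mn) / (2 ** num_bits - 1)).clamp_min(1e-8)     # eps guard (not in the repo's code)
    q = ((g - mn) / scale).clamp_(min=0).round_()                 # integer codes, kept in floating point
    return (q * scale + mn).reshape(orig_shape)                   # dequantize

x = torch.randn(1, 32, 340, 128, dtype=torch.float32)
for bits in (2, 4):
    err = (x - fake_quant_groupwise(x, group_size=128, num_bits=bits)).abs().mean()
    print(f"{bits}-bit mean abs error: {err.item():.4f}")
```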
/pred_long_bench.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datasets import load_dataset
3 | import torch
4 | import json
5 | from tqdm import tqdm
6 | import numpy as np
7 | import random
8 | import argparse
9 | os.environ["WANDB_DISABLED"] = "true"
10 |
11 | from utils.process_args import process_args
12 | from transformers import LlamaConfig, MistralConfig, AutoTokenizer
13 |
14 |
15 | # Build the customized prompt for chat models
16 | def build_chat(tokenizer, prompt, model_name):
17 | # For results in KIVI paper (Llama, Llama-Chat, Mistral-7B-v0.1), we do not apply any special treatment to the prompt.
18 | # For lmsys/longchat-7b-v1.5-32k and mistralai/Mistral-7B-Instruct-v0.2, we need to rewrite the prompt a little bit.
19 | # Update: we add the template for the new llama-3-instruct model
20 | if "llama-3" in model_name.lower() and "instruct" in model_name.lower():
21 | messages = [
22 | {"role": "user", "content": prompt},
23 | ]
24 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
25 | elif "longchat" in model_name.lower():
26 | from fastchat.model import get_conversation_template
27 | conv = get_conversation_template("vicuna")
28 | conv.append_message(conv.roles[0], prompt)
29 | conv.append_message(conv.roles[1], None)
30 | prompt = conv.get_prompt()
31 | elif "mistral-v0.2-instruct" in model_name.lower():
32 | messages = [
33 | {
34 | "role": "user",
35 | "content": prompt
36 | }
37 | ]
38 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
39 | return prompt
40 |
41 | def post_process(response, model_name):
42 | if "xgen" in model_name:
43 | response = response.strip().replace("Assistant:", "")
44 | elif "internlm" in model_name:
45 | response = response.split("<eoa>")[0]
46 | return response
47 |
48 | def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name):
49 | preds = []
50 | for json_obj in tqdm(data):
51 | prompt = prompt_format.format(**json_obj)
52 | # truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
53 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
54 | # if "chatglm3" in model:
55 | # tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
56 | if len(tokenized_prompt) > max_length:
57 | half = int(max_length/2)
58 | prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
59 | if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
60 | prompt = build_chat(tokenizer, prompt, model_name)
61 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
62 | context_length = input.input_ids.shape[-1]
63 | if dataset == "samsum": # prevent illegal output on samsum (the model endlessly repeats "\nDialogue"); likely a prompting issue
64 | output = model.generate(
65 | **input,
66 | max_new_tokens=max_gen,
67 | num_beams=1,
68 | do_sample=False,
69 | temperature=1.0,
70 | min_length=context_length+1,
71 | eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
72 | )[0]
73 | else:
74 | output = model.generate(
75 | **input,
76 | max_new_tokens=max_gen,
77 | num_beams=1,
78 | do_sample=False,
79 | temperature=1.0,
80 | )[0]
81 | pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
82 | pred = post_process(pred, model_name)
83 | preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]})
84 | return preds
85 |
86 | def seed_everything(seed):
87 | torch.manual_seed(seed)
88 | torch.cuda.manual_seed(seed)
89 | np.random.seed(seed)
90 | random.seed(seed)
91 | torch.backends.cudnn.benchmark = False
92 | torch.backends.cudnn.deterministic = True
93 | torch.cuda.manual_seed_all(seed)
94 |
95 | if __name__ == '__main__':
96 | seed_everything(42)
97 | # args = parse_args()
98 | model2path = json.load(open("config/model2path.json", "r"))
99 | model2maxlen = json.load(open("config/model2maxlen.json", "r"))
100 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101 | # model_name = args.model
102 |
103 | # define your model
104 | model_args, data_args, training_args = process_args()
105 | # print(model_args, data_args, training_args)
106 | model_name = model_args.model_name_or_path.split("/")[-1]
107 | # dtype = torch.bfloat16 if training_args.bf16 else torch.float
108 | dtype = torch.float16
109 |
110 | if 'llama' in model_args.model_name_or_path.lower() or 'longchat' in model_args.model_name_or_path.lower():
111 | config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
112 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path,
113 | use_fast=False,
114 | trust_remote_code=True,
115 | tokenizer_type='llama')
116 | # model_max_length=training_args.model_max_length)
117 | elif 'mistral' in model_args.model_name_or_path.lower():
118 | config = MistralConfig.from_pretrained(model_args.model_name_or_path)
119 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path,
120 | use_fast=False,
121 | trust_remote_code=True)
122 | else:
123 | raise NotImplementedError
124 |
125 | if 'llama' in model_args.model_name_or_path.lower() or 'longchat' in model_args.model_name_or_path.lower():
126 | if model_args.k_bits < 16 and model_args.v_bits < 16:
127 | from models.llama_kivi import LlamaForCausalLM_KIVI
128 | config.k_bits = model_args.k_bits
129 | config.v_bits = model_args.v_bits
130 | config.group_size = model_args.group_size
131 | config.residual_length = model_args.residual_length
132 | config.use_flash = True # Note: we enable FlashAttention to speed up inference
133 | model = LlamaForCausalLM_KIVI.from_pretrained(
134 | pretrained_model_name_or_path=model_args.model_name_or_path,
135 | config=config,
136 | cache_dir=training_args.cache_dir,
137 | torch_dtype=dtype,
138 | low_cpu_mem_usage=True,
139 | device_map="auto",
140 | )
141 | else:
142 | from transformers import LlamaForCausalLM
143 | model = LlamaForCausalLM.from_pretrained(
144 | pretrained_model_name_or_path=model_args.model_name_or_path,
145 | config=config,
146 | cache_dir=training_args.cache_dir,
147 | torch_dtype=dtype,
148 | low_cpu_mem_usage=True,
149 | use_flash_attention_2=True,
150 | device_map="auto",
151 | )
152 |
153 | elif 'mistral' in model_args.model_name_or_path.lower():
154 | if model_args.k_bits < 16 and model_args.v_bits < 16:
155 | from models.mistral_kivi import MistralForCausalLM_KIVI
156 | config.k_bits = model_args.k_bits
157 | config.v_bits = model_args.v_bits
158 | config.group_size = model_args.group_size
159 | config.residual_length = model_args.residual_length
160 | config.use_flash = True
161 | model = MistralForCausalLM_KIVI.from_pretrained(
162 | pretrained_model_name_or_path=model_args.model_name_or_path,
163 | config=config,
164 | cache_dir=training_args.cache_dir,
165 | torch_dtype=dtype,
166 | low_cpu_mem_usage=True,
167 | device_map="auto",
168 | )
169 | else:
170 | from transformers import MistralForCausalLM
171 | model = MistralForCausalLM.from_pretrained(
172 | pretrained_model_name_or_path=model_args.model_name_or_path,
173 | config=config,
174 | cache_dir=training_args.cache_dir,
175 | torch_dtype=dtype,
176 | low_cpu_mem_usage=True,
177 | use_flash_attention_2=True,
178 | device_map="auto",
179 | )
180 |
181 | else:
182 | raise NotImplementedError
183 |
184 | #
185 | # Load model directly
186 | # tokenizer = AutoTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K")
187 | # model = AutoModelForCausalLM.from_pretrained("togethercomputer/LLaMA-2-7B-32K")
188 |
189 | model.eval()
190 | max_length = model2maxlen[model_name]
191 | if data_args.e:
192 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news",
193 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
194 | else:
195 | datasets = ["triviaqa", "qasper", "trec", "samsum", "lcc", "repobench-p", "qmsum", "multi_news"]
196 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output
197 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r"))
198 | dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r"))
199 | # predict on each dataset
200 | if not os.path.exists("pred"):
201 | os.makedirs("pred")
202 | if not os.path.exists("pred_e"):
203 | os.makedirs("pred_e")
204 | for dataset in datasets:
205 | if data_args.e:
206 | data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
207 | if not os.path.exists(f"pred_e/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}"):
208 | os.makedirs(f"pred_e/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}")
209 | out_path = f"pred_e/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}/{dataset}.jsonl"
210 | else:
211 | data = load_dataset('THUDM/LongBench', dataset, split='test')
212 | if not os.path.exists(f"pred/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}"):
213 | os.makedirs(f"pred/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}")
214 | out_path = f"pred/{model_name}_{max_length}_{model_args.k_bits}bits_group{model_args.group_size}_residual{model_args.residual_length}/{dataset}.jsonl"
215 | prompt_format = dataset2prompt[dataset]
216 | max_gen = dataset2maxlen[dataset]
217 | preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name)
218 | with open(out_path, "w", encoding="utf-8") as f:
219 | for pred in preds:
220 | json.dump(pred, f, ensure_ascii=False)
221 | f.write('\n')
--------------------------------------------------------------------------------
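When a tokenized prompt exceeds the model budget, `get_pred` above keeps the first and last `max_length // 2` tokens and drops the middle, since LongBench instructions tend to sit at both ends of the context. A minimal, self-contained sketch of that truncation step (the checkpoint in the usage comment is only an example):

```python
from transformers import AutoTokenizer

def truncate_middle(prompt: str, tokenizer, max_length: int) -> str:
    """Middle truncation as in get_pred: keep the head and tail halves of the token stream."""
    ids = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
    if len(ids) <= max_length:
        return prompt
    half = max_length // 2
    head = tokenizer.decode(ids[:half], skip_special_tokens=True)
    tail = tokenizer.decode(ids[-half:], skip_special_tokens=True)
    return head + tail

# Usage (any Hugging Face checkpoint works; this one is illustrative):
# tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
# short_prompt = truncate_middle(long_prompt, tok, max_length=4096)
```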
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "kivi"
7 | version = "0.1.0"
8 | description = "A tuning-free 2/4/8-bit KV Cache quantization method for LLMs."
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "packaging==24.0", "sentencepiece", "tokenizers>=0.15",
17 | "torch==2.4.1", "ipdb",
18 | "transformers==4.43.1",
19 | "toml", "attributedict",
20 | "accelerate",
21 | "fastchat",
22 | "protobuf",
23 | "flash-attn",
24 | "datasets"
25 | ]
26 |
27 | [tool.setuptools.packages.find]
28 | exclude = ["results*", "scripts*", "examples*"]
29 |
30 | [tool.wheel]
31 | exclude = ["results*", "scripts*", "examples*"]
32 |
--------------------------------------------------------------------------------
/quant/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jy-yuan/KIVI/6de8e6e50547f9e240b28b17296affe8a2f034b7/quant/__init__.py
--------------------------------------------------------------------------------
/quant/csrc/gemv_cuda.cu:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/ankan-ban/llama_cu_awq
2 | // and the official implementation of AWQ
3 | /*
4 |
5 | @article{lin2023awq,
6 | title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
7 | author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
8 | journal={arXiv},
9 | year={2023}
10 | }
11 |
12 | */
13 |
14 | #include <torch/extension.h>
15 | #include <cuda_fp16.h>
16 | #include <cstdio>
17 | #include "gemv_cuda.h"
18 | #define VECTORIZE_FACTOR 8
19 | #define Q_VECTORIZE_FACTOR 8
20 | #define PACK_FACTOR 8
21 | #define WARP_SIZE 32
22 |
23 |
24 | // Reduce sum within the warp using the tree reduction algorithm.
25 | __device__ __forceinline__ float warp_reduce_sum(float sum) {
26 | #pragma unroll
27 | for(int i = 4; i >= 0; i--){
27 | sum += __shfl_down_sync(0xffffffff, sum, 1<<i);
84 | // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
85 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
86 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
87 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
88 | const float4* inputs_ptr = inputs + inputs_ptr_delta;
89 | // multiply 32 weights with 32 inputs
90 | #pragma unroll
91 | for (int ic_0 = 0; ic_0 < 4; ic_0++){
92 | // iterate over different uint32_t packed_weights in this loop
93 | uint32_t current_packed_weight = packed_weights[ic_0];
94 | half packed_inputs[PACK_FACTOR];
95 | // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
96 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
97 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
98 | #pragma unroll
99 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
100 | // iterate over 8 numbers packed within each uint32_t number
101 | float current_single_weight_fp = (float)(current_packed_weight & 0xF);
102 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros;
103 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
104 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
105 | current_packed_weight = current_packed_weight >> 4;
106 | }
107 | }
108 | }
109 | }
110 | psum = warp_reduce_sum(psum);
111 | if (threadIdx.x == 0) {
112 | outputs[oc_idx] = __float2half(psum);
113 | }
114 | }
115 |
116 |
117 | /*
118 | Computes GEMV (group_size = 128).
119 |
120 | Args:
121 | inputs: vector of shape [batch_size, IC];
122 | weight: matrix of shape [OC, IC / 8];
123 | output: vector of shape [OC];
124 | zeros: matrix of shape [OC, IC / group_size / 8];
125 | scaling_factors: matrix of shape [OC, IC / group_size];
126 |
127 | Notes:
128 | One cannot infer group_size from the shape of scaling factors.
129 | the second dimension is rounded up to a multiple of PACK_FACTOR.
130 | */
131 | __global__ void gemv_kernel_g128(
132 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs,
133 | const int IC, const int OC){
134 | const int group_size = 128;
135 | float psum = 0;
136 | const int batch_idx = blockIdx.z;
137 | const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
138 | const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
139 | half* outputs = _outputs + batch_idx * OC;
140 | const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
141 | const int weight_w = IC / PACK_FACTOR;
142 | // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
143 | const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
144 | // consistent with input shape
145 | const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
146 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
147 | // tile size: 4 OC x 1024 IC per iter
148 | for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
149 | // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
150 | uint32_t packed_weights[4];
151 | // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
152 | *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
153 | // load scaling factors
154 | // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
155 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
156 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
157 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
158 | const float4* inputs_ptr = inputs + inputs_ptr_delta;
159 | // multiply 32 weights with 32 inputs
160 | #pragma unroll
161 | for (int ic_0 = 0; ic_0 < 4; ic_0++){
162 | // iterate over different uint32_t packed_weights in this loop
163 | uint32_t current_packed_weight = packed_weights[ic_0];
164 | half packed_inputs[PACK_FACTOR];
165 | // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
166 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
167 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
168 | #pragma unroll
169 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
170 | // iterate over 8 numbers packed within each uint32_t number
171 | float current_single_weight_fp = (float)(current_packed_weight & 0xF);
172 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros;
173 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
174 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
175 | current_packed_weight = current_packed_weight >> 4;
176 | }
177 | }
178 | }
179 | }
180 | psum = warp_reduce_sum(psum);
181 | if (threadIdx.x == 0) {
182 | outputs[oc_idx] = __float2half(psum);
183 | }
184 | }
185 |
186 |
187 | /*
188 | Computes GEMV (PyTorch interface).
189 |
190 | Args:
191 | _in_feats: tensor of shape [B, IC];
192 | _kernel: int tensor of shape [OC, IC // 8];
193 | _zeros: int tensor of shape [OC, IC // G // 8];
194 | _scaling_factors: tensor of shape [OC, IC // G];
195 | blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
196 | blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
197 |
198 | Returns:
199 | out_feats: tensor of shape [B, OC];
200 | */
201 | torch::Tensor gemv_forward_cuda(
202 | torch::Tensor _in_feats,
203 | torch::Tensor _kernel,
204 | torch::Tensor _scaling_factors,
205 | torch::Tensor _zeros,
206 | const int bit,
207 | const int group_size)
208 | {
209 | int num_in_feats = _in_feats.size(0);
210 | int num_in_channels = _in_feats.size(1);
211 | // int kernel_volume = _out_in_map.size(1);
212 | auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr());
213 | auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
214 | auto zeros = reinterpret_cast<half*>(_zeros.data_ptr());
215 | auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr());
216 | // auto out_in_map = _out_in_map.data_ptr();
217 | auto options =
218 | torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
219 | // kernel is [OC, IC]
220 | at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
221 | int num_out_feats = _out_feats.size(-2);
222 | int num_out_channels = _out_feats.size(-1);
223 | auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());
224 | int blockDim_z = num_out_feats;
225 | dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
226 | dim3 num_threads(32, 4);
227 | if (group_size == 64)
228 | {
229 | gemv_kernel_g64<<>>(
230 | // pointers
231 | in_feats, kernel, zeros, scaling_factors, out_feats,
232 | // constants
233 | num_in_channels, num_out_channels
234 | );
235 | }
236 | else if (group_size == 128)
237 | {
238 | gemv_kernel_g128<<>>(
239 | // pointers
240 | in_feats, kernel, zeros, scaling_factors, out_feats,
241 | // constants
242 | num_in_channels, num_out_channels
243 | );
244 | }
245 | return _out_feats;
246 | ;}
247 |
248 |
249 |
250 |
251 | /*
252 | Computes Batched 4-bit GEMV (group_size = 64).
253 |
254 | Args:
255 | inputs: vector of shape [BS, 1, IC];
256 | weight: matrix of shape [BS, OC // PACK_FACTOR, IC];
257 | output: vector of shape [BS, 1, OC];
258 | zeros: matrix of shape [BS, OC // group_size, IC];
259 | scaling_factors: matrix of shape [BS, OC // group_size, IC];
260 |
261 | Notes:
262 | One cannot infer group_size from the shape of scaling factors.
263 | the second dimension is rounded up to a multiple of PACK_FACTOR.
264 | */
265 | __global__ void bgemv4_kernel_outer_dim(
266 | const half* _inputs, const uint32_t* _weight, const half* _zeros, const half* _scale, half* _outputs,
267 | const int IC, const int OC, const int group_size, const int nh, const int nh_kv){
268 | const int bit = 4;
269 | const int pack_factor = 8;
270 | const int batch_idx = blockIdx.x;
271 | const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
272 | const int oc_start_idx = packed_oc_idx * pack_factor;
273 | const int group_idx = oc_start_idx / group_size;
274 | const half* inputs = _inputs + batch_idx * IC;
275 | half* outputs = _outputs + batch_idx * OC;
276 | const int ratio = nh / nh_kv;
277 | int _batch_idx = batch_idx / ratio;
278 | const uint32_t* weight = _weight + _batch_idx * OC * IC / pack_factor;
279 | const half* scaling_factors = _scale + _batch_idx * OC * IC / group_size;
280 | const half* zeros = _zeros + _batch_idx * OC * IC / group_size;
281 | const int TILE_DIM = 128;
282 | const int num = 0xFF >> (8-bit);
283 | const int ICR = IC;
284 | // 1float4 == 8 half number
285 | float psum[pack_factor]{};
286 | for (int k=0; k < (IC + TILE_DIM - 1) / TILE_DIM; k++){
287 | uint32_t qw[4]{};
288 | half cscale[4]{};
289 | half czero[4]{};
290 | half inp[4]{};
291 | // each thread load 32 int4 number
292 | int weight_offset = packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4;
293 | int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4;
294 | int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4;
295 | for (int i=0; i<4; i++){
296 | if (weight_offset + i < OC * ICR / pack_factor)
297 | qw[i] = *(weight + weight_offset + i);
298 | if (scale_mn_offset + i < OC * ICR / group_size){
299 | cscale[i] = *(scaling_factors + scale_mn_offset + i);
300 | czero[i] = *(zeros + scale_mn_offset + i);}
301 | if (inputs_ptr_delta + i < ICR)
302 | inp[i] = *(inputs + inputs_ptr_delta + i);
303 | }
304 | // each thread load 32 int4 number
305 | // int weight_offset = packed_oc_idx * IC + k * TILE_DIM + threadIdx.x*4;
306 | // if (weight_offset < OC * IC / pack_factor)
307 | // *((float4*)(qw)) = *((float4*)(weight + packed_oc_idx * IC + k * TILE_DIM + threadIdx.x*4));
308 | // int scale_mn_offset = group_idx * IC + k * TILE_DIM + threadIdx.x*4;
309 | // if (scale_mn_offset < OC * IC / group_size){
310 | // *((float2*)(cscale)) = *((float2*)(scaling_factors + scale_mn_offset));
311 | // *((float2*)(czero)) = *((float2*)(zeros + scale_mn_offset));
312 | // }
313 | // int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4;
314 | // if (inputs_ptr_delta < IC){
315 | // const half* inputs_ptr = inputs + inputs_ptr_delta;
316 | // *((float2*)(inp)) = *((float2*)(inputs_ptr));
317 | // }
318 | // multiply 32 weights with 32 inputs
319 | #pragma unroll
320 | for (int ic_0 = 0; ic_0 < 4; ic_0++){
321 | uint32_t cur_packed_weight = qw[ic_0];
322 | float cur_inp = __half2float(inp[ic_0]);
323 | float cur_scale = __half2float(cscale[ic_0]);
324 | float cur_zero = __half2float(czero[ic_0]);
325 | for (int ic_1 = 0; ic_1 < pack_factor; ic_1++){
326 | int oc_idx = oc_start_idx + ic_1;
327 | if (oc_idx < OC){
328 | float cur_single_weight_fp = (float)(cur_packed_weight & num);
329 | float dequantized_weight = cur_scale * cur_single_weight_fp + cur_zero;
330 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && k == 1) printf("%d %d %d %f %f %f %f %f\n", k, ic_0, ic_1, dequantized_weight, cur_single_weight_fp, cur_scale, cur_zero, cur_inp);
331 | cur_packed_weight = cur_packed_weight >> bit;
332 | psum[ic_1] += dequantized_weight * cur_inp;
333 | }
334 | }
335 | }
336 | }
337 | for (int i=0; i < pack_factor; i++){
338 | int oc_idx = oc_start_idx + i;
339 | if (oc_idx < OC){
340 | psum[i] = warp_reduce_sum(psum[i]);
341 | if (threadIdx.x == 0)
342 | outputs[oc_idx] = __float2half(psum[i]);
343 | }
344 | }
345 | }
346 |
347 |
348 | __global__ void bgemv2_kernel_outer_dim(
349 | const half* _inputs, const uint32_t* _weight, const half* _zeros, const half* _scale, half* _outputs,
350 | const int IC, const int OC, const int group_size, const int nh, const int nh_kv){
351 | // const int group_size = 64;
352 | const int bit = 2;
353 | const int pack_factor = 16;
354 | const int batch_idx = blockIdx.x;
355 | const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
356 | const int oc_start_idx = packed_oc_idx * pack_factor;
357 | const int group_idx = oc_start_idx / group_size;
358 | const int ICR = IC;
359 | const half* inputs = _inputs + batch_idx * ICR;
360 | half* outputs = _outputs + batch_idx * OC;
361 | const int ratio = nh / nh_kv;
362 | int _batch_idx = batch_idx / ratio;
363 | const uint32_t* weight = _weight + _batch_idx * OC * IC / pack_factor;
364 | const half* scaling_factors = _scale + _batch_idx * OC * IC / group_size;
365 | const half* zeros = _zeros + _batch_idx * OC * IC / group_size;
366 | const int TILE_DIM = 128;
367 | const int num = 0xFF >> (8-bit);
368 | // 1float4 == 8 half number
369 | float psum[pack_factor]{};
370 | for (int k=0; k < (ICR + TILE_DIM - 1) / TILE_DIM; k++){
371 | uint32_t qw[4]{};
372 | half cscale[4]{};
373 | half czero[4]{};
374 | half inp[4]{};
375 | // each thread load 32 int4 number
376 | int weight_offset = packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4;
377 | int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4;
378 | int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4;
379 | for (int i=0; i<4; i++){
380 | if (weight_offset + i < OC * ICR / pack_factor)
381 | qw[i] = *(weight + weight_offset + i);
382 | if (scale_mn_offset + i < OC * ICR / group_size){
383 | cscale[i] = *(scaling_factors + scale_mn_offset + i);
384 | czero[i] = *(zeros + scale_mn_offset + i);}
385 | if (inputs_ptr_delta + i < ICR)
386 | inp[i] = *(inputs + inputs_ptr_delta + i);
387 | }
388 | // if (weight_offset < OC * ICR / pack_factor)
389 | // *((float4*)(qw)) = *((float4*)(weight + packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4));
390 | // int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4;
391 | // if (scale_mn_offset < OC * ICR / group_size){
392 | // *((float2*)(cscale)) = *((float2*)(scaling_factors + scale_mn_offset));
393 | // *((float2*)(czero)) = *((float2*)(zeros + scale_mn_offset));
394 | // }
395 | // int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4;
396 | // if (inputs_ptr_delta < ICR){
397 | // const half* inputs_ptr = inputs + inputs_ptr_delta;
398 | // *((float2*)(inp)) = *((float2*)(inputs_ptr));
399 | // }
400 | // multiply 32 weights with 32 inputs
401 | #pragma unroll
402 | for (int ic_0 = 0; ic_0 < 4; ic_0++){
403 | uint32_t cur_packed_weight = qw[ic_0];
404 | float cur_inp = __half2float(inp[ic_0]);
405 | float cur_scale = __half2float(cscale[ic_0]);
406 | float cur_zero = __half2float(czero[ic_0]);
407 | for (int ic_1 = 0; ic_1 < pack_factor; ic_1++){
408 | int oc_idx = oc_start_idx + ic_1;
409 | if (oc_idx < OC){
410 | float cur_single_weight_fp = (float)(cur_packed_weight & num);
411 | float dequantized_weight = cur_scale * cur_single_weight_fp + cur_zero;
412 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && k == 1) printf("%d %d %d %f %f %f %f %f\n", k, ic_0, ic_1, dequantized_weight, cur_single_weight_fp, cur_scale, cur_zero, cur_inp);
413 | cur_packed_weight = cur_packed_weight >> bit;
414 | psum[ic_1] += dequantized_weight * cur_inp;
415 | }
416 | }
417 | }
418 | }
419 | for (int i=0; i < pack_factor; i++){
420 | int oc_idx = oc_start_idx + i;
421 | if (oc_idx < OC){
422 | psum[i] = warp_reduce_sum(psum[i]);
423 | if (threadIdx.x == 0)
424 | outputs[oc_idx] = __float2half(psum[i]);
425 | }
426 | }
427 | }
428 |
429 | // __global__ void bgemv2_kernel_g64_outer_dim(
430 | // const half* _inputs, const uint32_t* _weight, const half* _zeros, const half* _scale, half* _outputs,
431 | // const int IC, const int OC){
432 | // const int group_size = 64;
433 | // const int bit = 2;
434 | // const int pack_factor = 16;
435 | // const int batch_idx = blockIdx.x;
436 | // const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
437 | // const int oc_start_idx = packed_oc_idx * pack_factor;
438 | // const int group_idx = oc_start_idx / group_size;
439 | // const int ICR = IC;
440 | // const half* inputs = _inputs + batch_idx * ICR;
441 | // half* outputs = _outputs + batch_idx * OC;
442 | // const uint32_t* weight = _weight + batch_idx * OC * IC / pack_factor;
443 | // const half* scaling_factors = _scale + batch_idx * OC * IC / group_size;
444 | // const half* zeros = _zeros + batch_idx * OC * IC / group_size;
445 | // const int TILE_DIM = 128;
446 | // const int num = 0xFF >> (8-bit);
447 | // // 1float4 == 8 half number
448 | // float psum[pack_factor]{};
449 | // for (int k=0; k < (ICR + TILE_DIM - 1) / TILE_DIM; k++){
450 | // uint32_t qw[4]{};
451 | // half cscale[4]{};
452 | // half czero[4]{};
453 | // half inp[4]{};
454 | // // each thread load 32 int4 number
455 | // int weight_offset = packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4;
456 | // if (weight_offset < OC * ICR / pack_factor)
457 | // *((float4*)(qw)) = *((float4*)(weight + packed_oc_idx * ICR + k * TILE_DIM + threadIdx.x*4));
458 | // int scale_mn_offset = group_idx * ICR + k * TILE_DIM + threadIdx.x*4;
459 | // if (scale_mn_offset < OC * ICR / group_size){
460 | // *((float2*)(cscale)) = *((float2*)(scaling_factors + scale_mn_offset));
461 | // *((float2*)(czero)) = *((float2*)(zeros + scale_mn_offset));
462 | // }
463 | // int inputs_ptr_delta = k * TILE_DIM + threadIdx.x * 4;
464 | // if (inputs_ptr_delta < ICR){
465 | // const half* inputs_ptr = inputs + inputs_ptr_delta;
466 | // *((float2*)(inp)) = *((float2*)(inputs_ptr));
467 | // }
468 | // // multiply 32 weights with 32 inputs
469 | // #pragma unroll
470 | // for (int ic_0 = 0; ic_0 < 4; ic_0++){
471 | // uint32_t cur_packed_weight = qw[ic_0];
472 | // float cur_inp = __half2float(inp[ic_0]);
473 | // float cur_scale = __half2float(cscale[ic_0]);
474 | // float cur_zero = __half2float(czero[ic_0]);
475 | // for (int ic_1 = 0; ic_1 < pack_factor; ic_1++){
476 | // int oc_idx = oc_start_idx + ic_1;
477 | // if (oc_idx < OC){
478 | // float cur_single_weight_fp = (float)(cur_packed_weight & num);
479 | // float dequantized_weight = cur_scale * cur_single_weight_fp + cur_zero;
480 | // // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && k == 1) printf("%d %d %d %f %f %f %f %f\n", k, ic_0, ic_1, dequantized_weight, cur_single_weight_fp, cur_scale, cur_zero, cur_inp);
481 | // cur_packed_weight = cur_packed_weight >> bit;
482 | // psum[ic_1] += dequantized_weight * cur_inp;
483 | // }
484 | // }
485 | // }
486 | // }
487 | // for (int i=0; i < pack_factor; i++){
488 | // int oc_idx = oc_start_idx + i;
489 | // if (oc_idx < OC){
490 | // psum[i] = warp_reduce_sum(psum[i]);
491 | // if (threadIdx.x == 0)
492 | // outputs[oc_idx] = __float2half(psum[i]);
493 | // }
494 | // }
495 | // }
496 |
497 |
498 | /*
499 | Computes GEMV (PyTorch interface).
500 |
501 | Args:
502 | _in_feats: tensor of shape [B, IC];
503 | _kernel: int tensor of shape [OC // PACK_Factor, IC];
504 | _zeros: int tensor of shape [OC // G, IC];
505 | _scaling_factors: tensor of shape [OC // G, IC];
506 | blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
507 | blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
508 | Returns:
509 | out_feats: tensor of shape [B, OC];
510 | */
511 | torch::Tensor gemv_forward_cuda_outer_dim(
512 | torch::Tensor _in_feats,
513 | torch::Tensor _kernel,
514 | torch::Tensor _scaling_factors,
515 | torch::Tensor _zeros,
516 | const int bit,
517 | const int group_size,
518 | const int nh,
519 | const int nh_kv)
520 | {
521 | int BS = _in_feats.size(0);
522 | int num_in_feats = _in_feats.size(1);
523 | int num_in_channels = _in_feats.size(2);
524 | int num_out_channels = _zeros.size(1) * group_size;
525 | // int kernel_volume = _out_in_map.size(1);
526 | auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr());
527 | auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
528 | auto zeros = reinterpret_cast<half*>(_zeros.data_ptr());
529 | auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr());
530 | // auto out_in_map = _out_in_map.data_ptr();
531 | auto options =
532 | torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
533 | // kernel is [OC, IC]
534 | at::Tensor _out_feats = torch::empty({BS, num_in_feats, num_out_channels}, options);
535 | int num_out_feats = _out_feats.size(-2);
536 | auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());
537 | int pack_factor = 32 / bit;
538 | dim3 num_blocks(BS, (num_out_channels / pack_factor + 3) / 4, num_out_feats);
539 | dim3 num_threads(32, 4);
540 | if (bit == 4){
541 | bgemv4_kernel_outer_dim<<>>(
542 | // pointers
543 | in_feats, kernel, zeros, scaling_factors, out_feats,
544 | // constants
545 | num_in_channels, num_out_channels, group_size, nh, nh_kv
546 | );}
547 | else{
548 | // note: in this case, pack factor == 16
549 | bgemv2_kernel_outer_dim<<>>(
550 | // pointers
551 | in_feats, kernel, zeros, scaling_factors, out_feats,
552 | // constants
553 | num_in_channels, num_out_channels, group_size, nh, nh_kv
554 | );
555 | }
556 | return _out_feats;
557 | ;}
558 |
--------------------------------------------------------------------------------
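The batched `bgemv4_kernel_outer_dim` / `bgemv2_kernel_outer_dim` kernels above fuse dequantization with the matrix-vector product: each `uint32` packed along the output-channel dimension holds `32 / bit` codes (lowest bits map to the lowest output channel), every code is rescaled with its per-group `scale` and `zero`, and the result is multiplied with the fp16 input vector. The slow PyTorch reference below sketches the same computation; it assumes `OC` is divisible by both the pack factor and `group_size`, and the function names are illustrative rather than the repo's API.

```python
import torch

def unpack_outer_dim(qweight: torch.Tensor, bit: int) -> torch.Tensor:
    """Unpack int32 words carrying 32 // bit codes along the output-channel dim.
    qweight: (BS, OC // pack_factor, IC) int32 -> (BS, OC, IC) integer codes."""
    pack_factor = 32 // bit
    mask = (1 << bit) - 1
    shifts = (torch.arange(pack_factor, device=qweight.device) * bit).view(1, 1, -1, 1)
    codes = (qweight.unsqueeze(2) >> shifts) & mask        # lowest bits -> lowest output channel
    bs, packed_oc, pf, ic = codes.shape
    return codes.reshape(bs, packed_oc * pf, ic)

def bgemv_outer_dim_reference(inp, qweight, scale, zeros, bit, group_size):
    """inp: (BS, 1, IC) fp16; scale / zeros: (BS, OC // group_size, IC) fp16."""
    w = unpack_outer_dim(qweight, bit).float()                     # (BS, OC, IC)
    scale = scale.float().repeat_interleave(group_size, dim=1)     # per-group -> per-channel
    zeros = zeros.float().repeat_interleave(group_size, dim=1)
    w = w * scale + zeros                                          # dequantize
    return (inp.float() @ w.transpose(1, 2)).to(inp.dtype)         # (BS, 1, OC)
```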
/quant/csrc/gemv_cuda.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | torch::Tensor gemv_forward_cuda(
5 | torch::Tensor _in_feats,
6 | torch::Tensor _kernel,
7 | torch::Tensor _scaling_factors,
8 | torch::Tensor _zeros,
9 | const int bit,
10 | const int group_size);
11 |
12 |
13 | torch::Tensor gemv_forward_cuda_outer_dim(
14 | torch::Tensor _in_feats,
15 | torch::Tensor _kernel,
16 | torch::Tensor _scaling_factors,
17 | torch::Tensor _zeros,
18 | const int bit,
19 | const int group_size,
20 | const int nh,
21 | const int nh_kv);
--------------------------------------------------------------------------------
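On the Python side, both declarations are exposed through the `kivi_gemv` extension module (registered in `quant/csrc/pybind.cpp` and used in `quant/gemv.py`). Below is a quick shape check for the batched outer-dim entry point, assuming the extension has already been built; the tensor contents are random placeholders, so only the output shape is meaningful.

```python
import torch
import kivi_gemv  # CUDA extension built from quant/csrc

BS, IC, OC = 32, 128, 1024            # flattened batch (bsz * n_heads), input dim, output dim
bit, group_size, nh, nh_kv = 2, 32, 32, 32
pack_factor = 32 // bit

inp = torch.randn(BS, 1, IC, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (BS, OC // pack_factor, IC),
                        dtype=torch.int32, device="cuda")
scale = torch.rand(BS, OC // group_size, IC, dtype=torch.float16, device="cuda")
zeros = torch.rand(BS, OC // group_size, IC, dtype=torch.float16, device="cuda")

out = kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, zeros,
                                            bit, group_size, nh, nh_kv)
print(out.shape)  # expected: torch.Size([32, 1, 1024])
```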
/quant/csrc/gemv_cuda_backup.cu:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/ankan-ban/llama_cu_awq
2 | /*
3 |
4 | @article{lin2023awq,
5 | title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
6 | author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
7 | journal={arXiv},
8 | year={2023}
9 | }
10 |
11 | */
12 |
13 | #include <torch/extension.h>
14 | #include <cuda_fp16.h>
15 | #include <cstdio>
16 | #include "gemv_cuda.h"
17 | #define VECTORIZE_FACTOR 8
18 | #define Q_VECTORIZE_FACTOR 8
19 | #define WARP_SIZE 32
20 |
21 |
22 | // Reduce sum within the warp using the tree reduction algorithm.
23 | __device__ __forceinline__ float warp_reduce_sum(float sum) {
24 | #pragma unroll
25 | for(int i = 4; i >= 0; i--){
26 | sum += __shfl_down_sync(0xffffffff, sum, 1<<i);
58 | template <int bit, int PACK_FACTOR>
59 | __global__ void gemv_kernel_g64(
60 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs,
61 | const int IC, const int OC){
62 | const int group_size = 64;
63 | float psum = 0;
64 | const int batch_idx = blockIdx.z;
65 | const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
66 | const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
67 | half* outputs = _outputs + batch_idx * OC;
68 | const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
69 | const int weight_w = IC / PACK_FACTOR;
70 | // consistent with input shape
71 | const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
72 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
73 | int elem_per_th = 128 / bit;
74 | int ng_per_warp = 32 * elem_per_th / 64;
75 | // tile size: 4 OC x (128 * PACK_FACTOR) IC per iter
76 | for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
77 | uint32_t packed_weights[4];
78 | // use float4 to load weights, each thread load (64,32,16) int-(2,4,8) numbers (1 x float4)
79 | *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
80 | // load scaling factors
81 | // 1 warp == 32 threads
82 | // g64: 1 threads -> 64,32,16 numbers -> 1,.5,0.25 group; 1 warp = 32,16,8 groups.
83 | // TODO: from here
84 | // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && packed_group_idx == 0) printf("%d %d\n", elem_per_th, ng_per_warp);
85 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * ng_per_warp + (threadIdx.x*ng_per_warp/32)]);
86 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * ng_per_warp + (threadIdx.x*ng_per_warp/32)]);
87 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
88 | const float4* inputs_ptr = inputs + inputs_ptr_delta;
89 | const int num = 0xFF >> (8-bit);
90 | // multiply (64,32,16) weights with (64,32,16) inputs
91 | #pragma unroll
92 | for (int ic_0 = 0; ic_0 < 4; ic_0++){
93 | // iterate over different uint32_t packed_weights in this loop
94 | uint32_t current_packed_weight = packed_weights[ic_0];
95 | half packed_inputs[PACK_FACTOR];
96 | // each thread load (16,8,4) inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
97 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
98 | // TODO: bug is here!! for 4-bit, one float4 == 8 half number == packed_inputs[8]
99 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
100 | #pragma unroll
101 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
102 | // iterate over (16,8,4) numbers packed within each uint32_t number
103 | float current_single_weight_fp = (float)(current_packed_weight & num);
104 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros;
105 | if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 1 && packed_group_idx == 0) printf("%f %f %f %f %f\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, __half2float(packed_inputs[ic_1]));
106 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
107 | current_packed_weight = current_packed_weight >> bit;
108 | }
109 | }
110 | }
111 | }
112 | psum = warp_reduce_sum(psum);
113 | if (threadIdx.x == 0) {
114 | outputs[oc_idx] = __float2half(psum);
115 | }
116 | }
117 |
118 |
119 | /*
120 | Computes GEMV (group_size = 128).
121 |
122 | Args:
123 | inputs: vector of shape [batch_size, IC];
124 | weight: matrix of shape [OC, IC / 8];
125 | output: vector of shape [OC];
126 | zeros: matrix of shape [OC, IC / group_size / 8];
127 | scaling_factors: matrix of shape [OC, IC / group_size];
128 |
129 | Notes:
130 | One cannot infer group_size from the shape of scaling factors.
131 | the second dimension is rounded up to a multiple of PACK_FACTOR.
132 | */
133 | template <int bit, int PACK_FACTOR>
134 | __global__ void gemv_kernel_g128(
135 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs,
136 | const int IC, const int OC){
137 | const int group_size = 128;
138 | float psum = 0;
139 | const int batch_idx = blockIdx.z;
140 | const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
141 | const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
142 | half* outputs = _outputs + batch_idx * OC;
143 | const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
144 | const int weight_w = IC / PACK_FACTOR;
145 | // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
146 | const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
147 | // consistent with input shape
148 | const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
149 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
150 | // tile size: 4 OC x 1024 IC per iter
151 | for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
152 | // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
153 | uint32_t packed_weights[4];
154 | // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
155 | *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
156 | // load scaling factors
157 | // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
158 | float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
159 | float current_zeros = __half2float(zeros[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
160 | int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
161 | const float4* inputs_ptr = inputs + inputs_ptr_delta;
162 | const int num = 0xFF >> (8-bit);
163 | // multiply 32 weights with 32 inputs
164 | #pragma unroll
165 | for (int ic_0 = 0; ic_0 < 4; ic_0++){
166 | // iterate over different uint32_t packed_weights in this loop
167 | uint32_t current_packed_weight = packed_weights[ic_0];
168 | half packed_inputs[PACK_FACTOR];
169 | // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
170 | if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
171 | *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
172 | #pragma unroll
173 | for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
174 | // iterate over 8 numbers packed within each uint32_t number
175 | float current_single_weight_fp = (float)(current_packed_weight & num);
176 | float dequantized_weight = scaling_factor * current_single_weight_fp + current_zeros;
177 | //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
178 | psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
179 | current_packed_weight = current_packed_weight >> bit;
180 | }
181 | }
182 | }
183 | }
184 | psum = warp_reduce_sum(psum);
185 | if (threadIdx.x == 0) {
186 | outputs[oc_idx] = __float2half(psum);
187 | }
188 | }
189 |
190 |
191 | /*
192 | Computes GEMV (PyTorch interface).
193 |
194 | Args:
195 | _in_feats: tensor of shape [B, IC];
196 | _kernel: int tensor of shape [OC, IC // 8];
197 | _zeros: int tensor of shape [OC, IC // G // 8];
198 | _scaling_factors: tensor of shape [OC, IC // G];
199 | blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
200 | blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
201 |
202 | Returns:
203 | out_feats: tensor of shape [B, OC];
204 | */
205 | torch::Tensor gemv_forward_cuda(
206 | torch::Tensor _in_feats,
207 | torch::Tensor _kernel,
208 | torch::Tensor _scaling_factors,
209 | torch::Tensor _zeros,
210 | const int bit,
211 | const int group_size)
212 | {
213 | int num_in_feats = _in_feats.size(0);
214 | int num_in_channels = _in_feats.size(1);
215 | // int kernel_volume = _out_in_map.size(1);
216 | auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr());
217 | auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
218 | auto zeros = reinterpret_cast<half*>(_zeros.data_ptr());
219 | auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr());
220 | // auto out_in_map = _out_in_map.data_ptr();
221 | auto options =
222 | torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
223 | // kernel is [OC, IC]
224 | at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
225 | int num_out_feats = _out_feats.size(-2);
226 | int num_out_channels = _out_feats.size(-1);
227 | auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());
228 | int blockDim_z = num_out_feats;
229 | dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
230 | dim3 num_threads(32, 4);
231 | if (bit == 2){
232 | if (group_size == 64)
233 | {
234 | gemv_kernel_g64<2, 16><<>>(
235 | // pointers
236 | in_feats, kernel, zeros, scaling_factors, out_feats,
237 | // constants
238 | num_in_channels, num_out_channels
239 | );
240 | }
241 | else if (group_size == 128)
242 | {
243 | gemv_kernel_g128<2, 16><<>>(
244 | // pointers
245 | in_feats, kernel, zeros, scaling_factors, out_feats,
246 | // constants
247 | num_in_channels, num_out_channels
248 | );}
249 | }else if (bit == 4){
250 | if (group_size == 64)
251 | {
252 | gemv_kernel_g64<4, 8><<>>(
253 | // pointers
254 | in_feats, kernel, zeros, scaling_factors, out_feats,
255 | // constants
256 | num_in_channels, num_out_channels
257 | );
258 | }
259 | else if (group_size == 128)
260 | {
261 | gemv_kernel_g128<4, 8><<>>(
262 | // pointers
263 | in_feats, kernel, zeros, scaling_factors, out_feats,
264 | // constants
265 | num_in_channels, num_out_channels
266 | );
267 | };}
268 | else{
269 | if (group_size == 64)
270 | {
271 | gemv_kernel_g64<8, 4><<>>(
272 | // pointers
273 | in_feats, kernel, zeros, scaling_factors, out_feats,
274 | // constants
275 | num_in_channels, num_out_channels
276 | );
277 | }
278 | else if (group_size == 128)
279 | {
280 | gemv_kernel_g128<8, 4><<>>(
281 | // pointers
282 | in_feats, kernel, zeros, scaling_factors, out_feats,
283 | // constants
284 | num_in_channels, num_out_channels
285 | );
286 | }}
287 | return _out_feats;
288 | }
289 |
290 | /*
291 | Computes GEMV (group_size = 64).
292 |
293 | Args:
294 | inputs: vector of shape [batch_size, IC];
295 | weight: matrix of shape [OC // PACK_FACTOR, IC;
296 | output: vector of shape [OC];
297 | zeros: matrix of shape [OC // group_size, IC];
298 | scaling_factors: matrix of shape [OC // group_size, IC];
299 |
300 | Notes:
301 | One cannot infer group_size from the shape of scaling factors.
302 | the second dimension is rounded up to a multiple of PACK_FACTOR.
303 | */
304 | __global__ void gemv_kernel_g64_outer_dim(
305 | const float4* _inputs, const uint32_t* weight, const half* zeros, const half* scaling_factors, half* _outputs,
306 | const int IC, const int OC){
307 | const int group_size = 64;
308 | float psum = 0;
309 | const int pack_factor = 8;
310 | const int batch_idx = blockIdx.z;
311 | const int packed_oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
312 | const int group_idx = packed_oc_idx * pack_factor / group_size;
313 | const float4* inputs = _inputs + batch_idx * IC;
314 | half* outputs = _outputs + batch_idx * OC;
315 | const int TILE_DIM = 32;
316 | extern __shared__ uint32_t packed_weight_shared[TILE_DIM];
317 | extern __shared__ float scale_shared[TILE_DIM];
318 | extern __shared__ float mn_shared[TILE_DIM];
319 | for (int k=0; k < (IC + TILE_DIM - 1) / TILE_DIM; k++){
320 | if (packed_oc_idx * pack_factor < OC && k*TILE_DIM+threadIdx.x < IC)
321 | packed_weight_shared[threadIdx.x] = weight[packed_oc_idx * IC + k * TILE_DIM + threadIdx.x];
322 | else
323 | packed_weight_shared[threadIdx.x] = 0;
324 | if (group_idx * group_size < OC && k*TILE_DIM+threadIdx.x < IC){
325 | scale_shared[threadIdx.x] = __half2float(scaling_factors[oc_idx / group_size * IC + k * TILE_DIM + threadIdx.x]);
326 | mn_shared[threadIdx.x] = __half2float(zeros[oc_idx / group_size * IC + k * TILE_DIM + threadIdx.x]);
327 | }
328 | else{
329 | scale_shared[threadIdx.x] = 0.0;
330 | mn_shared[threadIdx.x] = 0.0;
331 | }
332 | __syncthreads();
333 | }
334 | }
335 |
--------------------------------------------------------------------------------
/quant/csrc/pybind.cpp:
--------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h>
2 | #include <torch/extension.h>
3 | #include "gemv_cuda.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
6 | {
7 | m.def("gemv_forward_cuda", &gemv_forward_cuda);
8 | m.def("gemv_forward_cuda_outer_dim", &gemv_forward_cuda_outer_dim);
9 | }
--------------------------------------------------------------------------------
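`pybind.cpp` only registers the two CUDA entry points. The repo builds them ahead of time through `quant/setup.py`; purely as an assumption for quick experiments, the same sources can also be JIT-compiled with `torch.utils.cpp_extension.load`, keeping the module name `kivi_gemv` so it matches the import in `quant/gemv.py`:

```python
from torch.utils.cpp_extension import load

# JIT-build the extension from the csrc sources (paths relative to the repo root).
kivi_gemv = load(
    name="kivi_gemv",
    sources=["quant/csrc/pybind.cpp", "quant/csrc/gemv_cuda.cu"],
    verbose=True,
)

print(hasattr(kivi_gemv, "gemv_forward_cuda"))            # True
print(hasattr(kivi_gemv, "gemv_forward_cuda_outer_dim"))  # True
```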
/quant/gemv.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="2"
3 | os.environ["CUDA_LAUNCH_BLOCKING"]="1"
4 | import numpy as np
5 | import torch
6 | # import ipdb
7 | import random
8 | import triton
9 | import triton.language as tl
10 | from new_pack import pack_tensor
11 | from timeit_v2 import py_benchmark
12 | import kivi_gemv
13 |
14 | B, nh, IC, OC = 8, 32, 739, 128
15 |
16 | @triton.jit
17 | def gemv_kernel_g64(inputs_ptr, qw_ptr, mn_ptr,
18 | scale_ptr, output_ptr,
19 | IC: tl.constexpr, OC: tl.constexpr, bit: tl.constexpr,
20 | OC_PER_PH: tl.constexpr, PACK_FACTOR: tl.constexpr, BLOCK_SIZE):
21 | """
22 | Computes GEMV (group_size = 64).
23 |
24 | Args:
25 | inputs: vector of shape [batch_size, IC];
26 | qw: matrix of shape [OC, IC / 8];
27 | output: vector of shape [OC];
28 | mn: matrix of shape [OC, NG];
29 | scale: matrix of shape [OC, NG];
30 |
31 | Notes:
32 | One cannot infer group_size from the shape of scaling factors.
33 | the second dimension is rounded up to a multiple of PACK_FACTOR.
34 | """
35 | group_size = 64
36 | oc_idx = tl.program_id(axis=0) * OC_PER_PH + tl.arange(0, OC_PER_PH)
37 | batch_idx = tl.program_id(axis=1)
38 | num_groups = IC // group_size
39 | num_groups_packed = tl.cdiv(num_groups, PACK_FACTOR)
40 | # tl.store(output_ptr, num_groups_packed)
41 | weight_w = IC // PACK_FACTOR
42 | num = 0xFF >> (8-bit)
43 | accumulator = tl.zeros((OC_PER_PH,), dtype=tl.float32)
44 | for group_idx in range(0, num_groups):
45 | # load scaling factors
46 | # each time we load 4 OC x 1 G
47 | scale = tl.load(scale_ptr + oc_idx[:, None] * num_groups + group_idx)
48 | mn = tl.load(mn_ptr + oc_idx[:, None] * num_groups + group_idx)
49 | # 1 G -> 64 numbers -> 64 // PACK_FACTOR packed numbers
50 | cur_qw_ptr = qw_ptr + oc_idx[:, None] * weight_w + group_idx * (64 // PACK_FACTOR) + tl.arange(0, 64 // PACK_FACTOR)[None, :]
51 | qw = tl.load(cur_qw_ptr)
52 | for i in range(PACK_FACTOR):
53 | w_fp = qw & num
54 | # load 4 OC x
55 | w_fp = w_fp * scale + mn
56 | qw = qw >> bit
57 | cur_inp_ptr = inputs_ptr + batch_idx * IC + group_idx * 64 + i + tl.arange(0, 64 // PACK_FACTOR)[None, :] * PACK_FACTOR
58 | cur_input = tl.load(cur_inp_ptr)
59 | accumulator += tl.sum(cur_input * w_fp, 1)
60 | ptr = output_ptr + oc_idx + batch_idx * OC
61 | tl.store(ptr, accumulator)
62 |
63 |
64 | def dequant_weight(w, scale, mn, gs):
65 | w_fp = w.half().view(w.shape[0], w.shape[1]//gs, gs)
66 | w_fp = w_fp * scale.unsqueeze(-1) + mn.unsqueeze(-1)
67 | return w_fp.view(w.shape)
68 |
69 |
70 | def dequant_weight_outer(w, scale, mn, gs):
71 | # ipdb.set_trace()
72 | w_fp = w.half().view(w.shape[0], w.shape[1], w.shape[2]//gs, gs)
73 | w_fp = w_fp * scale.unsqueeze(-1) + mn.unsqueeze(-1)
74 | return w_fp.view(w.shape)
75 |
76 |
77 | def gemv_fwd(bit, group_size, inp, qweight, mn, scale):
78 | B, IC = inp.shape
79 | OC = qweight.shape[0]
80 | BLOCK_SIZE = 32
81 | OC_PER_PH = 32
82 | PACK_FACTOR = 32 // bit
83 | assert group_size == 64
84 | output = torch.empty((B, OC), device=inp.device, dtype=torch.float16)
85 | grid = lambda META: (
86 | triton.cdiv(OC, META['OC_PER_PH']), B
87 | )
88 | gemv_kernel_g64[grid](inp, qweight, mn, scale, output,
89 | IC, OC, bit, OC_PER_PH, PACK_FACTOR, BLOCK_SIZE)
90 | return output
91 |
92 |
93 | def test_bgemv_outer_correct_mha():
94 | flatten_B = B * nh
95 | inp = torch.randn((flatten_B, 1, IC), device='cuda', dtype=torch.float16)
96 | ori_weight = torch.randn((flatten_B, IC, OC), device='cuda', dtype=torch.float16)
97 | GS = 32
98 | for BIT in [2, 4]:
99 | weight = ori_weight
100 | PACK_FACTOR = 32 // BIT
101 | assert OC % GS == 0 and OC % PACK_FACTOR == 0
102 | NG = OC // GS
103 | weight = weight.view(flatten_B, IC, NG, GS)
104 | mx = torch.max(weight, dim=-1, keepdim=False)[0]
105 | mn = torch.min(weight, dim=-1, keepdim=False)[0]
106 | maxq = 2 ** BIT - 1
107 | scale = (mx - mn) / maxq
108 | weight = weight - mn.unsqueeze(-1)
109 | weight.div_(scale.unsqueeze(-1))
110 | weight = weight.clamp_(0, maxq).round_().to(torch.int32)
111 | weight = weight.view(flatten_B, IC, OC)
112 | qweight = pack_tensor(weight, BIT, 2)
113 | weight = weight.transpose(1, 2).contiguous()
114 | qweight = qweight.transpose(1, 2).contiguous()
115 | scale = scale.transpose(1, 2).contiguous()
116 | mn = mn.transpose(1, 2).contiguous()
117 | output = kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, mn, BIT, GS, nh, nh)  # MHA: nh_kv == nh
118 | deq_w = dequant_weight_outer(weight.transpose(1, 2),
119 | scale.transpose(1, 2),
120 | mn.transpose(1, 2), GS)
121 | # rel_error = torch.abs((deq_w - ori_weight).float() / (ori_weight + 1e-5).float()).mean()
122 | # print(f'bit {BIT} avg rel weight quant error: {rel_error}')
123 | output_ref = inp @ deq_w
124 | error = output_ref - output
125 | rel_out_error = torch.abs(error.float() / (torch.abs(output_ref).float()+1e-5)).mean()
126 | print(f'mha bit {BIT} avg rel out quant error: {rel_out_error}')
127 |
128 |
129 | def test_bgemv_outer_correct_mqa():
130 | flatten_B = B * nh
131 | inp = torch.randn((flatten_B, 1, IC), device='cuda', dtype=torch.float16)
132 | ori_weight = torch.randn((B, IC, OC), device='cuda', dtype=torch.float16)
133 | GS = 32
134 | for BIT in [2, 4]:
135 | weight = ori_weight
136 | PACK_FACTOR = 32 // BIT
137 | assert OC % GS == 0 and OC % PACK_FACTOR == 0
138 | NG = OC // GS
139 | weight = weight.view(B, IC, NG, GS)
140 | mx = torch.max(weight, dim=-1, keepdim=False)[0]
141 | mn = torch.min(weight, dim=-1, keepdim=False)[0]
142 | maxq = 2 ** BIT - 1
143 | scale = (mx - mn) / maxq
144 | weight = weight - mn.unsqueeze(-1)
145 | weight.div_(scale.unsqueeze(-1))
146 | weight = weight.clamp_(0, maxq).round_().to(torch.int32)
147 | weight = weight.view(B, IC, OC)
148 | qweight = pack_tensor(weight, BIT, 2)
149 | inp = inp.contiguous()
150 | weight = weight.transpose(1, 2).contiguous()
151 | qweight = qweight.transpose(1, 2).contiguous()
152 | scale = scale.transpose(1, 2).contiguous()
153 | mn = mn.transpose(1, 2).contiguous()
154 | output = kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, mn, BIT, GS, nh, True)
155 | deq_w = dequant_weight_outer(weight.transpose(1, 2),
156 | scale.transpose(1, 2),
157 | mn.transpose(1, 2), GS)
158 | # rel_error = torch.abs((deq_w - ori_weight).float() / (ori_weight + 1e-5).float()).mean()
159 | # print(f'bit {BIT} avg rel weight quant error: {rel_error}')
160 | output_ref = inp.view(B, nh, 1, IC) @ deq_w.view(B, 1, IC, OC)
161 | output_ref = output_ref.view(flatten_B, 1, OC)
162 | error = output_ref - output
163 | # ipdb.set_trace()
164 | rel_out_error = torch.abs(error.float() / (torch.abs(output_ref).float()+1e-5)).mean()
165 | print(f'mqa bit {BIT} avg rel out quant error: {rel_out_error}')
166 |
167 |
168 | def test_gemv_correct():
169 | inp = torch.randn((B, IC), device='cuda', dtype=torch.float16)
170 | ori_weight = torch.randn((OC, IC), device='cuda', dtype=torch.float16)
171 | GS = 64
172 | for BIT in [4]:
173 | weight = ori_weight
174 | PACK_FACTOR = 32 // BIT
175 | assert IC % GS == 0 and IC % PACK_FACTOR == 0
176 | NG = IC // GS
177 | weight = weight.view(OC, NG, GS)
178 | mx = torch.max(weight, dim=2, keepdim=False)[0]
179 | mn = torch.min(weight, dim=2, keepdim=False)[0]
180 | maxq = 2 ** BIT - 1
181 | scale = (mx - mn) / maxq
182 | weight = weight - mn.unsqueeze(-1)
183 | weight.div_(scale.unsqueeze(-1))
184 | weight = weight.clamp_(0, maxq).round_().to(torch.int32)
185 | weight = weight.view(OC, IC)
186 | qweight = pack_tensor(weight, BIT, 1)
187 | # output = gemv_fwd(BIT, GS, inp, qweight, mn, scale)
188 | output = kivi_gemv.gemv_forward_cuda(inp, qweight, scale, mn, BIT, GS)
189 | deq_w = dequant_weight(weight, scale, mn, GS)
190 | rel_error = torch.abs((deq_w - ori_weight).float() / (ori_weight + 1e-5).float()).mean()
191 | # print(f'bit {BIT} avg rel weight quant error: {rel_error}')
192 | output_ref = inp @ deq_w.T
193 | error = output_ref - output
194 |         rel_out_error = torch.abs(error.float() / (torch.abs(output_ref).float() + 1e-5)).mean()
195 | print(f'bit {BIT} avg rel out quant error: {rel_out_error}')
196 |
197 |
198 | def test_gemv_speed():
199 | inp = torch.randn((B, IC), device='cuda', dtype=torch.float16)
200 | ori_weight = torch.randn((OC, IC), device='cuda', dtype=torch.float16)
201 | weight = ori_weight
202 | BIT = 4
203 | GS = 64
204 | PACK_FACTOR = 32 // BIT
205 | assert IC % GS == 0 and IC % PACK_FACTOR == 0
206 | NG = IC // GS
207 | weight = weight.view(OC, NG, GS)
208 | mx = torch.max(weight, dim=2, keepdim=False)[0]
209 | mn = torch.min(weight, dim=2, keepdim=False)[0]
210 | maxq = 2 ** BIT - 1
211 | scale = (mx - mn) / maxq
212 | weight = weight - mn.unsqueeze(-1)
213 | weight.div_(scale.unsqueeze(-1))
214 | weight = weight.clamp_(0, maxq).round_().to(torch.int32)
215 | weight = weight.view(OC, IC)
216 | qweight = pack_tensor(weight, BIT, 1)
217 | output = gemv_fwd(BIT, GS, inp, qweight, mn, scale)
218 | deq_w = dequant_weight(weight, scale, mn, GS)
219 | stmt = "inp @ deq_w.T"
220 | t_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
221 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
222 | # stmt = "gemv_fwd(BIT, GS, inp, qweight, mn, scale)"
223 | # t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
224 | # setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
225 | stmt = "kivi_gemv.gemv_forward_cuda(inp, qweight, scale, mn, BIT, GS)"
226 | t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
227 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
228 | print(f'vanilla pytorch gemv: {t_ref * 1000} ms')
229 | print(f'awq fused IC {IC} OC {OC} {BIT}-bit gemv: {t_our * 1000} ms')
230 |
231 |
232 | def test_bgemv_outer_speed():
233 | inp = torch.randn((B, 1, IC), device='cuda', dtype=torch.float16)
234 | ori_weight = torch.randn((B, IC, OC), device='cuda', dtype=torch.float16)
235 | GS = 64
236 | for BIT in [2]:
237 | weight = ori_weight
238 | PACK_FACTOR = 32 // BIT
239 | assert OC % GS == 0 and OC % PACK_FACTOR == 0
240 | NG = OC // GS
241 | weight = weight.view(B, IC, NG, GS)
242 | mx = torch.max(weight, dim=-1, keepdim=False)[0]
243 | mn = torch.min(weight, dim=-1, keepdim=False)[0]
244 | maxq = 2 ** BIT - 1
245 | scale = (mx - mn) / maxq
246 | weight = weight - mn.unsqueeze(-1)
247 | weight.div_(scale.unsqueeze(-1))
248 | weight = weight.clamp_(0, maxq).round_().to(torch.int32)
249 | weight = weight.view(B, IC, OC)
250 | qweight = pack_tensor(weight, BIT, 2)
251 | weight = weight.transpose(1, 2).contiguous()
252 | qweight = qweight.transpose(1, 2).contiguous()
253 | scale = scale.transpose(1, 2).contiguous()
254 | mn = mn.transpose(1, 2).contiguous()
255 | deq_w = dequant_weight_outer(weight.transpose(1, 2),
256 | scale.transpose(1, 2),
257 | mn.transpose(1, 2), GS)
258 | stmt = "inp @ deq_w"
259 | t_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
260 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
261 | # stmt = "gemv_fwd(BIT, GS, inp, qweight, mn, scale)"
262 | # t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
263 | # setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
264 | stmt = "kivi_gemv.gemv_forward_cuda_outer_dim(inp, qweight, scale, mn, BIT, GS)"
265 | t_our = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
266 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
267 | print(f'BS {B} IC {IC} OC {OC} pytorch batched gemv: {t_ref * 1000} ms')
268 | print(f'our fused BS {B} IC {IC} OC {OC} {BIT}-bit outer-dim batched gemv: {t_our * 1000} ms')
269 |
270 | if __name__ == "__main__":
271 | torch.manual_seed(0)
272 | np.random.seed(0)
273 | random.seed(0)
274 | # test_gemv_correct()
275 | test_bgemv_outer_correct_mha()
276 | test_bgemv_outer_correct_mqa()
277 | # test_gemv_speed()
278 | # test_bgemv_outer_speed()
279 |
--------------------------------------------------------------------------------
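
A note on the recipe shared by the correctness tests above: every variant quantizes per group with an asymmetric min/max scheme, scale = (max - min) / (2^bit - 1) and code = round((w - min) / scale), then dequantizes with code * scale + min before comparing against the fused kernel. Below is a minimal, self-contained sketch of that round trip in pure PyTorch (no bit packing, no custom kernels; the tensor shapes are illustrative, not taken from the tests):

```python
import torch

def fake_quant_groupwise(w: torch.Tensor, group_size: int, bits: int) -> torch.Tensor:
    """Asymmetric per-group quantize -> dequantize along the last dim (no packing)."""
    orig_shape = w.shape
    g = w.view(*w.shape[:-1], w.shape[-1] // group_size, group_size)
    mn = g.min(dim=-1, keepdim=True).values
    mx = g.max(dim=-1, keepdim=True).values
    scale = (mx - mn) / (2 ** bits - 1)
    q = ((g - mn) / scale).clamp_(0, 2 ** bits - 1).round_()   # integer codes in [0, 2^bits - 1]
    return (q * scale + mn).view(orig_shape)                   # dequantized approximation

w = torch.randn(4, 128, dtype=torch.float16)
for bits in (2, 4):
    w_hat = fake_quant_groupwise(w, group_size=64, bits=bits)
    rel_err = ((w_hat - w).abs() / (w.abs() + 1e-5)).mean().item()
    print(f"{bits}-bit group-wise round-trip rel err: {rel_err:.4f}")
```
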
/quant/matmul.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import ipdb
3 | import random
4 | import triton
5 | import triton.language as tl
6 | import kivi_gemv
7 |
8 |
9 | @triton.jit
10 | def qbvm_kernel(
11 | bits,
12 | a_ptr, b_ptr, c_ptr,
13 | scales_ptr, zeros_ptr,
14 | M, N, K,
15 | stride_abatch, stride_am, stride_ak,
16 | stride_bbatch, stride_bk, stride_bn,
17 | stride_cbatch, stride_cm, stride_cn,
18 | stride_scales_b, stride_scales_k, stride_scales_g,
19 | stride_zeros_b, stride_zeros_k, stride_zeros_g,
20 | groupsize,
21 | BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
22 | ):
23 | """
24 | Compute the batch matrix multiplication C = A x B.
25 | A is of shape (B, 1, K) float16
26 | B is of shape (B, K, N//feat_per_int) int32
27 | C is of shape (B, 1, N) float16
28 | scales is of shape (B, K, G) float16
29 | zeros is of shape (B, K, G) float16
30 | groupsize is an int specifying the size of groups for scales and zeros.
31 | G is N // groupsize.
32 |     When groupsize == N, G = 1 and each (batch, k) row uses a single scale/zero pair.
33 |
34 | WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K.
35 | WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N.
36 | WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K.
37 | """
38 | pid_batch = tl.program_id(axis=0)
39 | pid = tl.program_id(axis=1)
40 | feat_per_int = 32 // bits
41 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
42 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
43 | pid_n = pid % num_pid_n
44 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
45 | offs_k = tl.arange(0, BLOCK_SIZE_K)
46 | a_batch_offset = (pid_batch * stride_abatch)
47 | b_batch_offset = (pid_batch * stride_bbatch)
48 | c_batch_offset = (pid_batch * stride_cbatch)
49 | a_ptr = a_ptr + a_batch_offset
50 | b_ptr = b_ptr + b_batch_offset
51 | c_ptr = c_ptr + c_batch_offset
52 | a_ptrs = a_ptr + (offs_k[:, None] * stride_ak) # (BLOCK_SIZE_K, 1)
53 | # a_mask = (offs_am[:, None] < M)
54 | # b_ptrs is set up such that it repeats elements along the N axis feat_per_int times
55 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + (offs_bn[None, :]//feat_per_int) * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
56 |     # shifter is used to extract the `bits`-bit code of each element from the packed 32-bit words of B
57 | shifter = (offs_bn % feat_per_int) * bits
58 | scales_ptr = scales_ptr + pid_batch*stride_scales_b + ((offs_bn[None, :] // groupsize)) * stride_scales_g # (BLOCK_SIZE_N,)
59 | zeros_ptr = zeros_ptr + pid_batch*stride_zeros_b + ((offs_bn[None, :] // groupsize)) * stride_zeros_g # (BLOCK_SIZE_N,)
60 |
61 | # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N)
62 | # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension
63 | # So this loop is along the infeatures dimension (K)
64 | # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel
65 | # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
66 | accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
67 | num = 0xFF >> (8-bits)
68 | for pid_k in range(0, num_pid_k):
69 | offs_bk = (offs_k[:, None] + pid_k * BLOCK_SIZE_K)
70 | # offs_k[None, :] < K - pid_k * BLOCK_SIZE_K
71 | a = tl.load(a_ptrs, mask=offs_bk < K, other=0.) # (1, BLOCK_SIZE_K)
72 | b = tl.load(b_ptrs, mask=offs_bk < K, other=0.) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
73 | ptr = scales_ptr + offs_bk * stride_scales_k
74 | scales = tl.load(ptr, mask=offs_bk < K, other=0.) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
75 | ptr = zeros_ptr + offs_bk * stride_zeros_k
76 | zeros = tl.load(ptr, mask=offs_bk < K, other=0.) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
77 | # Now we need to unpack b into 32-bit values
78 | # tl.device_print("scale ",scales.dtype)
79 | # tl.device_print("zeros ",zeros.dtype)
80 |         b = (b >> shifter[None, :]) & num  # for 4-bit values, num == 0xF
81 | b = b * scales + zeros # Scale and shift
82 | accumulator += tl.sum(a * b, 0) # tl.dot(a, b)
83 | # if pid_m == 0 and pid_n == 0:
84 | # tl.device_print("hello ", tl.dot(a, b).shape)
85 | a_ptrs += BLOCK_SIZE_K * stride_ak
86 | b_ptrs += BLOCK_SIZE_K * stride_bk
87 | c = accumulator # .to(tl.float16)
88 | # c = accumulator
89 | # Store the result
90 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
91 | c_ptrs = c_ptr + stride_cn * offs_cn
92 | c_mask = (offs_cn < N)
93 | tl.store(c_ptrs, c, mask=c_mask)
94 |
95 |
96 | def understand_code():
97 | M, N, K = 512, 256, 256
98 | BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M = 64, 64, 4
99 | total_program_id = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)
100 | for pid in range(0, total_program_id):
101 | num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
102 | num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
103 | num_pid_in_group = GROUP_SIZE_M * num_pid_n
104 | group_id = pid // num_pid_in_group
105 | first_pid_m = group_id * GROUP_SIZE_M
106 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
107 | pid_m = first_pid_m + (pid % group_size_m)
108 | pid_n = (pid % num_pid_in_group) // group_size_m
109 | print(f"pid={pid}, pid_m={pid_m}, pid_n={pid_n}")
110 |
111 |
112 | def triton_bmm_fA_qB_outer(group_size: int,
113 | fA: torch.FloatTensor,
114 | qB: torch.IntTensor,
115 | scales: torch.FloatTensor,
116 | zeros: torch.FloatTensor,
117 | bits: int) -> torch.FloatTensor:
118 | """
119 |     Compute the batched matrix multiplication C = query x key (or attn_weights x value),
120 |     where the second operand is stored as packed low-bit integers (`bits` per element).
121 |
122 | fA is of shape (B, nh, M, K) float16
123 | qB is of shape (B, nh, K, N // feat_per_int) int32
124 | scales is of shape (B, nh, K, G) float16
125 | zeros is of shape (B, nh, K, G) float16
126 |
127 |     group_size is the number of consecutive elements along the outer (N) dimension that share one scale/zero pair.
128 | G = N // groupsize
129 |
130 | Returns C of shape (B, nh, M, N) float16
131 | """
132 | assert len(fA.shape) == 4 and len(qB.shape) == 4
133 | B, nh, M, K = fA.shape
134 | feat_per_int = 32 // bits
135 | # flatten to a 3D tensor
136 | fA = fA.view(-1, M, K)
137 | N = qB.shape[-1] * feat_per_int
138 | qB = qB.reshape(-1, K, qB.shape[-1])
139 | # This is based on the possible BLOCK_SIZE_Ks
140 | # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128"
141 | # This is based on the possible BLOCK_SIZE_Ns
142 |     assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0, "N must be a multiple of 64"
143 | # This is based on the possible BLOCK_SIZE_Ks
144 |     assert group_size % 64 == 0, "group_size must be a multiple of 64"
145 | flatten_B = B * nh
146 | c = torch.empty((flatten_B, M, N), device='cuda', dtype=torch.float16)
147 | # print(f'M {M} N {N} K {K}')
148 | grid = lambda META: (
149 | flatten_B, triton.cdiv(N, META['BLOCK_SIZE_N']),
150 | )
151 | scales = scales.view(flatten_B, scales.shape[-2], scales.shape[-1])
152 | zeros = zeros.view(flatten_B, zeros.shape[-2], zeros.shape[-1])
153 | if N > K:
154 | BLOCK_SIZE_N = 128
155 | BLOCK_SIZE_K = 32
156 | num_warps=4 #
157 | else:
158 | BLOCK_SIZE_N = 32
159 | BLOCK_SIZE_K = 128
160 | num_warps = 2
161 | num_stages= 7 if K > 64 else 3 #
162 | qbvm_kernel[grid](
163 | bits,
164 | fA, qB, c,
165 | scales, zeros,
166 | M, N, K,
167 | fA.stride(0), fA.stride(1), fA.stride(2),
168 | qB.stride(0), qB.stride(1), qB.stride(2),
169 | c.stride(0), c.stride(1), c.stride(2),
170 | scales.stride(0), scales.stride(1), scales.stride(2),
171 |         zeros.stride(0), zeros.stride(1), zeros.stride(2),
172 | group_size, BLOCK_SIZE_N, BLOCK_SIZE_K,
173 | num_warps=num_warps, num_stages=num_stages
174 | )
175 | return c.view(B, nh, c.shape[-2], c.shape[-1])
176 |
177 |
178 | def cuda_bmm_fA_qB_outer(group_size: int,
179 | fA: torch.FloatTensor,
180 | qB: torch.IntTensor,
181 | scales: torch.FloatTensor,
182 | zeros: torch.FloatTensor,
183 | bits: int) -> torch.FloatTensor:
184 | """
185 |     Compute the batched matrix multiplication C = query x key (or attn_weights x value),
186 |     where the second operand is stored as packed low-bit integers (`bits` per element).
187 |
188 | fA is of shape (B, nh, M, K) float16
189 | qB is of shape (B, nh, K, N // feat_per_int) int32
190 | scales is of shape (B, nh, K, G) float16
191 | zeros is of shape (B, nh, K, G) float16
192 |
193 |     group_size is the number of consecutive elements along the outer (N) dimension that share one scale/zero pair.
194 | G = N // groupsize
195 |
196 | Returns C of shape (B, nh, M, N) float16
197 | """
198 | assert len(fA.shape) == 4 and len(qB.shape) == 4
199 | B, nh, M, K = fA.shape
200 | nh_kv = qB.shape[1]
201 | feat_per_int = 32 // bits
202 | # flatten to a 3D tensor
203 | fA = fA.view(-1, M, K).contiguous()
204 | N = qB.shape[-1] * feat_per_int
205 | qB = qB.reshape(-1, K, qB.shape[-1]).transpose(1, 2).contiguous()
206 | # This is based on the possible BLOCK_SIZE_Ks
207 | # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128"
208 | # This is based on the possible BLOCK_SIZE_Ns
209 | # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0, "N must be a multiple of 16, 32, 64, 128, and 256"
210 | # This is based on the possible BLOCK_SIZE_Ks
211 | # assert group_size % 64 == 0, "groupsize must be a multiple of 64, and 128"
212 | flatten_B = B * nh_kv
213 | scales = scales.view(flatten_B, scales.shape[-2], scales.shape[-1]).transpose(1, 2).contiguous()
214 | zeros = zeros.view(flatten_B, zeros.shape[-2], zeros.shape[-1]).transpose(1, 2).contiguous()
215 | assert bits in [2, 4]
216 | assert nh % nh_kv == 0
217 | c = kivi_gemv.gemv_forward_cuda_outer_dim(fA, qB, scales, zeros, bits, group_size, nh, nh_kv)
218 | c = c.view(B, nh, c.shape[-2], c.shape[-1])
219 | return c
220 |
--------------------------------------------------------------------------------
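
matmul.py above is the consumer side of the packed KV cache: quantize the keys (or values) along their last dimension with the helpers from new_pack.py, then contract the fp16 queries against the packed codes with triton_bmm_fA_qB_outer (or cuda_bmm_fA_qB_outer once the kivi_gemv extension is built). Below is a minimal sketch of the query x quantized-key product, mirroring the call pattern used in quant/test.py; it assumes a CUDA device and is run from inside quant/ so the flat imports resolve (shapes are illustrative):

```python
import torch
from new_pack import triton_quantize_and_pack_along_last_dim
from matmul import triton_bmm_fA_qB_outer

B, nh, T, D = 1, 8, 256, 128          # batch, heads, cached tokens, head dim (illustrative)
group_size, bits = 64, 2

q = torch.randn(B, nh, 1, D, device="cuda", dtype=torch.float16)   # single decode-step query
k = torch.randn(B, nh, T, D, device="cuda", dtype=torch.float16)   # key cache

# Keys are grouped and packed along the token axis, so quantize K^T of shape (B, nh, D, T).
k_code, k_scale, k_mn = triton_quantize_and_pack_along_last_dim(
    k.transpose(2, 3).contiguous(), group_size, bits)

attn = triton_bmm_fA_qB_outer(group_size, q, k_code, k_scale, k_mn, bits)   # (B, nh, 1, T)
ref = torch.matmul(q, k.transpose(2, 3))
print("mean abs error vs fp16:", (attn - ref).abs().mean().item())
```
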
/quant/new_pack.py:
--------------------------------------------------------------------------------
1 | import triton
2 | import triton.language as tl
3 | import random
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def quant_and_pack_kcache(k: torch.FloatTensor, group_size: int, bits: int):
9 | assert len(k.shape) == 4
10 | shape = k.shape
11 | B, nh, T, D = shape
12 | # ================== Get Scale & Zeros ===============
13 | assert T % group_size == 0
14 | num_groups = T // group_size
15 | new_shape = (B, nh, num_groups, group_size, D)
16 | # Quantize
17 | max_int = 2 ** bits - 1
18 | data = k.view(new_shape)
19 | mn = torch.min(data, dim=-2, keepdim=True)[0]
20 | mx = torch.max(data, dim=-2, keepdim=True)[0]
21 | scale = (mx - mn) / max_int
22 | data = data - mn
23 | data.div_(scale)
24 | data = data.clamp_(0, max_int).round_().to(torch.int32)
25 | data = data.view(shape)
26 | code = pack_tensor(data, bits, pack_dim=2)
27 | return code, scale, mn
28 |
29 |
30 | def quant_and_pack_vcache(v: torch.FloatTensor, group_size: int, bits: int):
31 | shape = v.shape
32 | assert len(shape) == 4
33 | assert v.shape[-1] % group_size == 0
34 | num_groups = shape[-1] // group_size
35 | new_shape = (shape[:-1] + (num_groups, group_size))
36 | # Quantize
37 | max_int = 2 ** bits - 1
38 | data = v.view(new_shape)
39 | mn = torch.min(data, dim=-1, keepdim=True)[0]
40 | mx = torch.max(data, dim=-1, keepdim=True)[0]
41 | scale = (mx - mn) / max_int
42 | data = data - mn
43 | data.div_(scale)
44 | data = data.clamp_(0, max_int).round_().to(torch.int32)
45 | data = data.view(shape)
46 | # Pack
47 | code = pack_tensor(data, bits, pack_dim=3)
48 | return code, scale, mn
49 |
50 |
51 | def unpack_and_dequant_kcache(k_code: torch.FloatTensor,
52 | scale: torch.FloatTensor,
53 | mn: torch.FloatTensor,
54 | group_size: int,
55 | bits: int,
56 | ):
57 | pack_dim = 2
58 | assert bits in [2, 4, 8]
59 | assert len(k_code.shape) == 4
60 | data = unpack_tensor(k_code, bits, pack_dim=pack_dim)
61 | shape = data.shape
62 | num_groups = shape[pack_dim] // group_size
63 | data = data.view(shape[:pack_dim] + (num_groups, group_size,) + shape[pack_dim+1:])
64 | data = data.to(torch.float16)
65 | data = data * scale + mn
66 | return data.view(shape)
67 |
68 |
69 | def unpack_and_dequant_vcache(v_code: torch.FloatTensor,
70 | scale: torch.FloatTensor,
71 | mn: torch.FloatTensor,
72 | group_size: int,
73 | bits: int,
74 | ):
75 | assert bits in [2, 4, 8]
76 | assert len(v_code.shape) == 4
77 | data = unpack_tensor(v_code, bits, pack_dim=3)
78 | shape = data.shape
79 | num_groups = shape[-1] // group_size
80 | data = data.view(shape[:-1] + (num_groups, group_size,))
81 | data = data.to(torch.float16)
82 | data = data * scale + mn
83 | return data.view(shape)
84 |
85 |
86 | def pack_tensor(data, bits, pack_dim):
87 | # Pack
88 | shape = data.shape
89 | feat_per_int = 32 // bits
90 | assert bits in [2,4,8], "Only 2, 4, 8 bits are supported"
91 | assert shape[pack_dim] % feat_per_int == 0, "Dimension length must be divisible by number of features per int"
92 | # BS, nh, T, nd // 16 # 16 is for 2bit
93 | code = torch.zeros(shape[:pack_dim] + (shape[pack_dim] // feat_per_int,)+shape[pack_dim+1:],
94 | dtype=torch.int32,
95 | device=data.device)
96 | i = 0
97 | row = 0
98 | unpacked_indices = [slice(None)] * len(data.shape)
99 | packed_indices = [slice(None)] * len(data.shape)
100 | while row < code.shape[pack_dim]:
101 | packed_indices[pack_dim] = row
102 | for j in range(i, i + (32 // bits)):
103 | unpacked_indices[pack_dim] = j
104 | code[packed_indices] |= data[unpacked_indices] << (bits * (j - i))
105 | i += 32 // bits
106 | row += 1
107 | return code
108 |
109 |
110 | def unpack_tensor(v_code: torch.FloatTensor,
111 | bits: int,
112 | pack_dim: int):
113 | assert bits in [2,4,8]
114 | shape = v_code.shape
115 | feat_per_int = 32 // bits
116 | new_shape = shape[:pack_dim] + (shape[pack_dim] * feat_per_int,) + shape[pack_dim+1:]
117 | unpacked_v_code = torch.zeros(new_shape, dtype=torch.int8, device=v_code.device)
118 | i = torch.arange(new_shape[pack_dim], device=v_code.device) // feat_per_int
119 | j = torch.arange(new_shape[pack_dim], device=v_code.device) % feat_per_int
120 | num = 0xFF >> (8 - bits)
121 | packed_indices = [slice(None)] * len(new_shape)
122 | packed_indices[pack_dim] = i
123 | if pack_dim == 2:
124 | unpacked_v_code = ((v_code[packed_indices] >> (j * bits)[None, None, :, None]).to(torch.int16)) & num
125 | elif pack_dim == 3:
126 | unpacked_v_code = ((v_code[packed_indices] >> (j * bits)).to(torch.int16)) & num
127 | else:
128 | raise NotImplementedError
129 | return unpacked_v_code
130 |
131 |
132 | @triton.jit
133 | def _pack_along_last_dim(
134 | bits: tl.constexpr,
135 | intensor_ptr,
136 | code_ptr,
137 | N,
138 | num_feats: tl.constexpr,
139 | feat_per_int: tl.constexpr,
140 | BLOCK_SIZE_N: tl.constexpr
141 | ):
142 | num_int_per_y_dim = num_feats // feat_per_int
143 | bid = tl.program_id(axis=0)
144 | yid = tl.program_id(axis=1)
145 | offs_N = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
146 | block_start = intensor_ptr + offs_N * num_feats + yid * feat_per_int # offset of the first element at current tile
147 | packed = tl.zeros((BLOCK_SIZE_N,), dtype=tl.int32)
148 | for i in range(feat_per_int):
149 | ptr = block_start + i
150 |         element = tl.load(ptr, mask=offs_N < N, other=0.)
--------------------------------------------------------------------------------
/quant/qmodule.py:
--------------------------------------------------------------------------------
11 |     if group_size >= 128:
12 | size_multiplier = 1
13 | elif group_size == 64:
14 | size_multiplier = 2
15 | elif group_size == 32:
16 | size_multiplier = 4
17 | else:
18 | raise NotImplementedError
19 |
20 | base_width = make_divisible(in_features // group_size, pack_num)
21 | base_width = make_divisible(base_width, size_multiplier) * size_multiplier
22 | return base_width
23 |
24 |
25 | def dequantize_weight(qweight, d_out, d_in, w_bit, scales, zeros, group_size):
26 | data = qweight.reshape(-1)
27 | N, num_features = d_out, d_in
28 | weight_fp = dequant_cuda.unpack_single_precision(data, w_bit, scales, zeros, N,
29 | num_features // group_size, group_size)
30 | return weight_fp.view(d_out, d_in)
31 |
32 |
33 | class MatMul4Bit(torch.autograd.Function):
34 | # forward is the same, but we added the fallback for pre-turing GPUs
35 | # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
36 |
37 | @staticmethod
38 | def forward(ctx, A, qweight, bias, d_out, d_in, w_bit, scales, zeros, group_size):
39 | # default of pytorch behavior if inputs are empty
40 | # 1. Dequantize
41 | # 2. MatmulnN
42 | weight_fp = dequantize_weight(qweight, d_out, d_in, w_bit, scales, zeros, group_size)
43 | output = torch.nn.functional.linear(A, weight_fp.to(A.dtype), bias)
44 | # 3. Save state
45 | ctx.state = (d_out, d_in, w_bit, scales, zeros, group_size)
46 | ctx.tensors = qweight
47 | return output
48 |
49 |
50 | @staticmethod
51 | def backward(ctx, grad_output):
52 | req_gradA, _, req_gradBias = ctx.needs_input_grad[:3]
53 | qweight = ctx.tensors
54 | d_out, d_in, w_bit, scales, zeros, group_size = ctx.state
55 |
56 | grad_A, grad_bias = None, None
57 |
58 | if req_gradBias:
59 | # compute grad_bias first before changing grad_output dtype
60 |             grad_bias = grad_output.sum(0)  # forward() never stores ctx.dtype_bias; keep grad_output's dtype
61 |
62 | # not supported by PyTorch. TODO: create work-around
63 | #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
64 | if req_gradA:
65 | weight_fp = dequantize_weight(qweight, d_out, d_in, w_bit, scales, zeros, group_size)
66 | grad_A = torch.matmul(grad_output, weight_fp.to(grad_output.dtype))
67 | if grad_A.isnan().any():
68 | import ipdb; ipdb.set_trace()
69 | # print(grad_A.norm())
70 | return grad_A, None, grad_bias, None, None, None, None, None, None
71 |
72 |
73 | class WQLinearForTrain(torch.nn.Module):
74 | def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
75 | super().__init__()
76 |
77 | if w_bit not in [4]:
78 | raise NotImplementedError("Only 4-bit are supported for now.")
79 |
80 | self.in_features = in_features
81 | self.out_features = out_features
82 | self.w_bit = w_bit
83 | self.group_size = group_size if group_size != -1 else in_features
84 |         # quick sanity check (make sure alignment)
85 | assert self.in_features % self.group_size == 0
86 | assert out_features % (32 // self.w_bit) == 0
87 | pack_num = (32 // self.w_bit)
88 | self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev))
89 | self.register_buffer('zeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev))
90 | self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev))
91 | if bias:
92 | self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
93 | else:
94 | self.bias = None
95 |
96 |
97 | def forward(self, x):
98 | # weight_fp = self.dequantize_weight().half()
99 | # out = torch.matmul(x, weight_fp.T)
100 | # out = out + self.bias if self.bias is not None else out
101 |
102 | out = MatMul4Bit.apply(x, self.qweight, self.bias,
103 | self.out_features, self.in_features,
104 | self.w_bit, self.scales,
105 | self.zeros, self.group_size)
106 | return out
107 |
108 | def dequantize_weight(self):
109 | data = self.qweight.reshape(-1)
110 | N, num_features = self.out_features, self.in_features
111 | weight_fp = dequant_cuda.unpack_single_precision(data, self.w_bit, self.scales, self.zeros, N,
112 | num_features // self.group_size, self.group_size)
113 | return weight_fp.view(self.out_features, self.in_features)
114 |
115 |
116 | @classmethod
117 | def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
118 | q_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
119 | if init_only: # just prepare for loading sd
120 | return q_linear
121 | quantized, scales, mn = quantize_and_pack(linear.weight, group_size, w_bit, simulate=False)
122 | q_linear.qweight = quantized
123 | q_linear.scales = scales
124 | q_linear.zeros = mn
125 | return q_linear
--------------------------------------------------------------------------------
/quant/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
3 |
4 |
5 | extra_compile_args = {
6 | "cxx": [
7 | "-g",
8 | "-O3",
9 | "-fopenmp",
10 | "-lgomp",
11 | "-std=c++17",
12 | "-DENABLE_BF16"
13 | ],
14 | "nvcc": [
15 | "-O3",
16 | "-std=c++17",
17 | "-DENABLE_BF16", # TODO
18 | "-U__CUDA_NO_HALF_OPERATORS__",
19 | "-U__CUDA_NO_HALF_CONVERSIONS__",
20 | "-U__CUDA_NO_BFLOAT16_OPERATORS__",
21 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
22 | "-U__CUDA_NO_BFLOAT162_OPERATORS__",
23 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
24 | "--expt-relaxed-constexpr",
25 | "--expt-extended-lambda",
26 | "--use_fast_math",
27 | "--threads=8"
28 | ],
29 | }
30 |
31 | setup(
32 | name="kivi_gemv",
33 | packages=find_packages(),
34 | ext_modules=[
35 | CUDAExtension(
36 | name="kivi_gemv",
37 | sources=[
38 | "csrc/pybind.cpp",
39 | "csrc/gemv_cuda.cu"
40 | ],
41 | extra_compile_args=extra_compile_args,
42 | ),
43 | ],
44 | cmdclass={"build_ext": BuildExtension},
45 | install_requires=["torch"],
46 | )
--------------------------------------------------------------------------------
/quant/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"]="2"
4 | import numpy as np
5 | import random
6 | # import ipdb
7 | import math
8 | import os
9 | import triton
10 | from new_pack import quant_and_pack_vcache, unpack_and_dequant_kcache, triton_quantize_and_pack_along_last_dim, unpack_and_dequant_vcache, quant_and_pack_kcache
11 | from matmul import triton_bmm_fA_qB_outer
12 | from timeit_v2 import py_benchmark
13 |
14 |
15 | def set_seed(seed):
16 | np.random.seed(seed)
17 | torch.random.manual_seed(seed)
18 | random.seed(seed)
19 |
20 |
21 | def test_vcache():
22 | torch.manual_seed(0)
23 | np.random.seed(0)
24 | random.seed(0)
25 | B, nh, T, hd = 555, 32, 433, 128
26 | v = torch.randn((B, nh, T, hd), device='cuda', dtype=torch.float16)
27 | group_size = 64
28 | for bits in [2, 4, 8]:
29 | code, scale, mn = triton_quantize_and_pack_along_last_dim(v, group_size, bits)
30 | # print(f'bit {bits}, scale.shape: {scale.shape}')
31 | # print(f'bit {bits}, code.shape: {code.shape}')
32 | dequant_v = unpack_and_dequant_vcache(code, scale.unsqueeze(-1), mn.unsqueeze(-1), group_size, bits)
33 | assert not dequant_v.isnan().any()
34 | gap = (dequant_v - v) / v
35 | gap = torch.nan_to_num(gap)
36 |         print(f'bit {bits}, mean v rel err: {torch.mean(torch.abs(gap))}')
37 |
38 |
39 | def test_kcache():
40 | torch.manual_seed(0)
41 | np.random.seed(0)
42 | random.seed(0)
43 | BS, nh, T, D = 11, 32, 4096, 128
44 | k = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16)
45 | group_size = 64
46 | for bits in [2, 4, 8]:
47 | code, scale, mn = triton_quantize_and_pack_along_last_dim(k.transpose(2, 3).contiguous(),
48 | group_size,
49 | bits)
50 | dequant_k = unpack_and_dequant_vcache(code, scale.unsqueeze(-1), mn.unsqueeze(-1), group_size, bits)
51 | assert not dequant_k.isnan().any()
52 | gap = (dequant_k.transpose(2, 3) - k) / k
53 | gap = torch.nan_to_num(gap)
54 |         print(f'bit {bits}, k mean rel err: {torch.mean(torch.abs(gap))}')
55 |
56 |
57 | def test_bmm_speed():
58 | BS, nh, T, D = 64, 32, 512, 128
59 | bits = 2
60 | key_state = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16)
61 | val_state = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16)
62 | group_size = 64
63 | query_len = 1
64 | query_state = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16)
65 |
66 | # quantiles = [0.5, 0.2, 0.8]
67 | # ms, min_ms, max_ms = triton.testing.do_bench(
68 | # lambda: triton_quantize_and_pack_along_last_dim(key_state.transpose(2,3).contiguous(),
69 | # group_size, bits), quantiles=quantiles)
70 | # print(f'batch size {BS} nh {nh} seqlen {T} quant and pack pytorch impl: {ms * 1000: .2f} ms')
71 | code, scale, mn = triton_quantize_and_pack_along_last_dim(
72 | key_state.transpose(2,3).contiguous(), group_size, bits)
73 | code = code.contiguous()
74 | scale = scale.contiguous()
75 | mn = mn.contiguous()
76 |
77 | stmt = "triton_quantize_and_pack_along_last_dim(key_state.transpose(2,3).contiguous(), group_size, bits)"
78 | t_triton_quant = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3,
79 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
80 | print(f'our triton quant & pack impl: {t_triton_quant * 1000} ms')
81 | stmt = "quant_and_pack_kcache(key_state, group_size, bits)"
82 | t_quant = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3,
83 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
84 | print(f'vanilla pytorch quant & pack impl: {t_quant * 1000} ms')
85 | stmt = 'triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits)'
86 | t_qk = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3,
87 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
88 | print(f'batch size {BS} seqlen {T} our fused batch qk impl: {t_qk * 1000: .2f} ms')
89 | stmt = 'torch.matmul(query_state, key_state.transpose(2, 3))'
90 | t_qk_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3,
91 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
92 | print(f'batch size {BS} seqlen {T} pytorch batch qk impl: {t_qk_ref * 1000: .2f} ms')
93 | attn_weight = torch.randn((BS, nh, query_len, T), device='cuda', dtype=torch.float16)
94 | code, scale, mn = triton_quantize_and_pack_along_last_dim(
95 | val_state, group_size, bits)
96 | stmt = 'triton_bmm_fA_qB_outer(group_size, attn_weight, code, scale, mn, bits)'
97 | t_av = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3,
98 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
99 | print(f'batch size {BS} seqlen {T} our fused batch av impl: {t_av * 1000: .2f} ms')
100 | stmt = 'torch.matmul(attn_weight, val_state)'
101 | t_av_ref = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=3,
102 | setup="torch.cuda.synchronize()", finish="torch.cuda.synchronize()")
103 | print(f'batch size {BS} seqlen {T} pytorch batch av impl: {t_av_ref * 1000: .2f} ms')
104 |
105 | # _code, _scale, _mn = quant_and_pack_kcache(
106 | # key_state, group_size, bits)
107 | # _code = _code.transpose(2, 3)
108 | # _scale = _scale.squeeze(-2).transpose(2,3)
109 | # _mn = _mn.squeeze(-2).transpose(2,3)
110 | # print(_code.shape, code.shape, _code.dtype, code.dtype)
111 | # print(_scale.shape, scale.shape, _scale.dtype, scale.dtype)
112 |
113 | # our_out = triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits)
114 | # ref_out = torch.matmul(query_state, key_state.transpose(2, 3))
115 | # gap = (our_out - ref_out) / ref_out
116 | # gap = torch.nan_to_num(gap)
117 | # err = torch.mean(torch.abs(gap)).item()
118 | # print(f'bits {bits}, err: {err}')
119 | # ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits), quantiles=quantiles)
120 | # print(f'batch size {BS} seqlen {T} our fused batch matmul impl: {ms * 1000: .2f} ms')
121 | # ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(query_state, key_state.transpose(2, 3)), quantiles=quantiles)
122 | # print(f'batch size {BS} seqlen {T} pytorch batch matmul impl: {ms * 1000: .2f} ms')
123 |
124 |
125 | def test_streaming_kvcache():
126 | BS, nh, T, D = 1, 32, 340, 128
127 | our_attn_output = None
128 | group_size = 64
129 | query_len = 1
130 | bits = 2
131 | key_states = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16)
132 | value_states = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16)
133 | key_states_quant = key_states[:, :, :-(key_states.shape[-2] % group_size), :].contiguous()
134 | key_states_full = key_states[:, :, -(key_states.shape[-2] % group_size):, :].contiguous()
135 | value_states_quant, value_scale, value_mn = triton_quantize_and_pack_along_last_dim(value_states,
136 | group_size,
137 | bits)
138 | key_states_quant_trans, key_scale_trans, key_mn_trans = triton_quantize_and_pack_along_last_dim(key_states_quant.transpose(2, 3).contiguous(),
139 | group_size, bits)
140 | for i in range(16):
141 | if our_attn_output is None:
142 | query_states = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16)
143 | else:
144 | query_states = our_attn_output
145 | key_states_new = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16)
146 | value_states_new = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16)
147 | att_qkquant = triton_bmm_fA_qB_outer(group_size, query_states, key_states_quant_trans,
148 | key_scale_trans, key_mn_trans, bits)
149 | key_states_full = torch.cat([key_states_full, key_states_new], dim=2)
150 | att_qkfull = torch.matmul(query_states, key_states_full.transpose(2, 3))
151 | our_att_weights = torch.cat([att_qkquant, att_qkfull], dim=-1) / math.sqrt(D)
152 | our_att_weights = torch.softmax(our_att_weights, dim=-1)
153 | value_states_quant_new, scale, mn = triton_quantize_and_pack_along_last_dim(value_states_new,
154 | group_size,
155 | bits)
156 | value_states_quant = torch.cat([value_states_quant, value_states_quant_new], dim=2)
157 | value_scale = torch.cat([value_scale, scale], dim=2)
158 | value_mn = torch.cat([value_mn, mn], dim=2)
159 | our_attn_output = triton_bmm_fA_qB_outer(group_size, our_att_weights, value_states_quant,
160 | value_scale, value_mn, bits)
161 | # ===
162 | key_states = torch.cat([key_states, key_states_new], dim=2)
163 | value_states = torch.cat([value_states, value_states_new], dim=2)
164 | ref_att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(D)
165 | ref_att_weights = torch.softmax(ref_att_weights, dim=-1)
166 | ref_att_out = torch.matmul(ref_att_weights, value_states)
167 | att_weight_gap = (ref_att_weights - our_att_weights) / ref_att_weights
168 |         print(f'i {i} bit {bits}, mean att weight rel err: {torch.mean(torch.abs(att_weight_gap))}')
169 | att_out_gap = (ref_att_out - our_attn_output) / ref_att_out
170 |         print(f'i {i} bit {bits}, mean att out rel err: {torch.mean(torch.abs(att_out_gap))}')
171 |
172 |
173 | def test_4d_qmatmul():
174 | torch.manual_seed(0)
175 | np.random.seed(0)
176 | random.seed(0)
177 | query_len = 1
178 | BS, nh, T, D = 16, 32, 1024, 128
179 | group_size = 64
180 | # k = torch.randn((BS, nh, T, D), device='cuda', dtype=torch.float16)
181 | # query_state = torch.randn((BS, nh, query_len, D), device='cuda', dtype=torch.float16)
182 | k = torch.randint(10, (BS, nh, T, D), device='cuda').to(torch.float16)
183 | query_state = torch.randint(5, (BS, nh, query_len, D), device='cuda').to(torch.float16)
184 | for bits in [8, 4, 2]:
185 | # code.shape == BS, nh, T // feat_per_int, D
186 | # scale, mn.shape == BS, nh, ng, 1, D
187 | code, scale, mn = quant_and_pack_kcache(k, group_size, bits)
188 | dequant_k = unpack_and_dequant_kcache(code, scale, mn, group_size, bits)
189 | # BS, nh, D, T // feat_per_int
190 | code = code.transpose(2, 3)
191 | # BS, nh, D, T // group_size
192 | scale = scale.view(BS, nh, -1, D).transpose(2, 3)
193 | mn = mn.view(BS, nh, -1, D).transpose(2, 3)
194 | our_out = triton_bmm_fA_qB_outer(group_size, query_state, code, scale, mn, bits)
195 | ref_out = torch.matmul(query_state, k.transpose(2, 3))
196 | # ref_out = torch.matmul(query_state, k.transpose(2, 3))
197 | assert not our_out.isnan().any()
198 | assert not ref_out.isnan().any()
199 | gap = (our_out - ref_out) / ref_out
200 | gap = torch.nan_to_num(gap)
201 | err = torch.mean(torch.abs(gap)).item()
202 | print(f'bits {bits}, err: {err}')
203 |
204 |
205 | if __name__ == '__main__':
206 | set_seed(114514)
207 | # test_kcache()
208 | # test_vcache()
209 | # test_4d_qmatmul()
210 | # test_streaming_kvcache()
211 | test_bmm_speed()
--------------------------------------------------------------------------------
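
For orientation, the shape comments inside these tests follow directly from the packing factor: an int32 word holds feat_per_int = 32 // bits codes, so a (BS, nh, T, D) cache packed along the token axis shrinks that axis by feat_per_int, and every group_size tokens share one fp16 scale and one fp16 zero-point. A small bookkeeping sketch (pure arithmetic, no GPU; the sizes are illustrative and ignore the full-precision residual tokens KIVI keeps around):

```python
# Packed-code and metadata shapes for a KV cache quantized along its last dimension,
# mirroring the shape comments in quant/test.py. Values are illustrative.
BS, nh, T, D = 16, 32, 1024, 128
group_size = 64

for bits in (2, 4, 8):
    feat_per_int = 32 // bits              # codes stored per int32 word
    packed_T = T // feat_per_int           # packed length of the token axis
    num_groups = T // group_size           # scale/zero pairs per (head, channel)
    code_shape = (BS, nh, D, packed_T)     # after quantizing K^T along tokens
    scale_shape = (BS, nh, D, num_groups)  # mn has the same shape
    fp16_bytes = BS * nh * T * D * 2
    quant_bytes = BS * nh * D * (packed_T * 4 + num_groups * 2 * 2)
    print(f"{bits}-bit: code {code_shape}, scale/mn {scale_shape}, "
          f"~{fp16_bytes / quant_bytes:.1f}x smaller than fp16")
```
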
/quant/timeit_v2.py:
--------------------------------------------------------------------------------
1 | # timeit_v2.py: Copied from the standard library with the following two modifications:
2 | # 1. Add a 'finish' argument to timeit for calling CUDA synchronization.
3 | # 2. Add the accurate-measurement utility function py_benchmark.
4 |
5 | """Tool for measuring execution time of small code snippets.
6 |
7 | This module avoids a number of common traps for measuring execution
8 | times. See also Tim Peters' introduction to the Algorithms chapter in
9 | the Python Cookbook, published by O'Reilly.
10 |
11 | Library usage: see the Timer class.
12 |
13 | Command line usage:
14 | python timeit.py [-n N] [-r N] [-s S] [-p] [-h] [--] [statement]
15 |
16 | Options:
17 | -n/--number N: how many times to execute 'statement' (default: see below)
18 | -r/--repeat N: how many times to repeat the timer (default 5)
19 | -s/--setup S: statement to be executed once initially (default 'pass').
20 | Execution time of this setup statement is NOT timed.
21 | -p/--process: use time.process_time() (default is time.perf_counter())
22 | -v/--verbose: print raw timing results; repeat for more digits precision
23 | -u/--unit: set the output time unit (nsec, usec, msec, or sec)
24 | -h/--help: print this usage message and exit
25 | --: separate options from statement, use when statement starts with -
26 | statement: statement to be timed (default 'pass')
27 |
28 | A multi-line statement may be given by specifying each line as a
29 | separate argument; indented lines are possible by enclosing an
30 | argument in quotes and using leading spaces. Multiple -s options are
31 | treated similarly.
32 |
33 | If -n is not given, a suitable number of loops is calculated by trying
34 | successive powers of 10 until the total time is at least 0.2 seconds.
35 |
36 | Note: there is a certain baseline overhead associated with executing a
37 | pass statement. It differs between versions. The code here doesn't try
38 | to hide it, but you should be aware of it. The baseline overhead can be
39 | measured by invoking the program without arguments.
40 |
41 | Classes:
42 |
43 | Timer
44 |
45 | Functions:
46 |
47 | timeit(string, string) -> float
48 | repeat(string, string) -> list
49 | default_timer() -> float
50 | """
51 |
52 | import gc
53 | import sys
54 | import time
55 | import itertools
56 |
57 | __all__ = ["Timer", "timeit", "repeat", "default_timer"]
58 |
59 | dummy_src_name = "<timeit-src>"
60 | default_number = 1000000
61 | default_repeat = 5
62 | default_timer = time.perf_counter
63 |
64 | _globals = globals
65 |
66 | # Don't change the indentation of the template; the reindent() calls
67 | # in Timer.__init__() depend on setup being indented 4 spaces and stmt
68 | # being indented 8 spaces.
69 | template = """
70 | def inner(_it, _timer{init}):
71 | {setup}
72 | _t0 = _timer()
73 | for _i in _it:
74 | {stmt}
75 | {finish}
76 | _t1 = _timer()
77 | return _t1 - _t0
78 | """
79 |
80 | def reindent(src, indent):
81 | """Helper to reindent a multi-line statement."""
82 | return src.replace("\n", "\n" + " "*indent)
83 |
84 | class Timer:
85 | """Class for timing execution speed of small code snippets.
86 |
87 | The constructor takes a statement to be timed, an additional
88 | statement used for setup, and a timer function. Both statements
89 | default to 'pass'; the timer function is platform-dependent (see
90 | module doc string). If 'globals' is specified, the code will be
91 | executed within that namespace (as opposed to inside timeit's
92 | namespace).
93 |
94 | To measure the execution time of the first statement, use the
95 | timeit() method. The repeat() method is a convenience to call
96 | timeit() multiple times and return a list of results.
97 |
98 | The statements may contain newlines, as long as they don't contain
99 | multi-line string literals.
100 | """
101 |
102 | def __init__(self, stmt="pass", setup="pass", finish='pass', timer=default_timer,
103 | globals=None):
104 | """Constructor. See class doc string."""
105 | self.timer = timer
106 | local_ns = {}
107 | global_ns = _globals() if globals is None else globals
108 | init = ''
109 | if isinstance(setup, str):
110 | # Check that the code can be compiled outside a function
111 | compile(setup, dummy_src_name, "exec")
112 | stmtprefix = setup + '\n'
113 | setup = reindent(setup, 4)
114 | elif callable(setup):
115 | local_ns['_setup'] = setup
116 | init += ', _setup=_setup'
117 | stmtprefix = ''
118 | setup = '_setup()'
119 | else:
120 | raise ValueError("setup is neither a string nor callable")
121 | if isinstance(stmt, str):
122 | # Check that the code can be compiled outside a function
123 | compile(stmtprefix + stmt, dummy_src_name, "exec")
124 | stmt = reindent(stmt, 8)
125 | elif callable(stmt):
126 | local_ns['_stmt'] = stmt
127 | init += ', _stmt=_stmt'
128 | stmt = '_stmt()'
129 | else:
130 | raise ValueError("stmt is neither a string nor callable")
131 |
132 | assert isinstance(finish, str)
133 | compile(setup + '\n' + stmt + '\n' + finish, dummy_src_name, 'exec')
134 | finish = reindent(finish, 4)
135 |
136 | src = template.format(stmt=stmt, setup=setup, init=init, finish=finish)
137 | self.src = src # Save for traceback display
138 | code = compile(src, dummy_src_name, "exec")
139 | exec(code, global_ns, local_ns)
140 | self.inner = local_ns["inner"]
141 |
142 | def print_exc(self, file=None):
143 | """Helper to print a traceback from the timed code.
144 |
145 | Typical use:
146 |
147 | t = Timer(...) # outside the try/except
148 | try:
149 | t.timeit(...) # or t.repeat(...)
150 | except:
151 | t.print_exc()
152 |
153 | The advantage over the standard traceback is that source lines
154 | in the compiled template will be displayed.
155 |
156 | The optional file argument directs where the traceback is
157 | sent; it defaults to sys.stderr.
158 | """
159 | import linecache, traceback
160 | if self.src is not None:
161 | linecache.cache[dummy_src_name] = (len(self.src),
162 | None,
163 | self.src.split("\n"),
164 | dummy_src_name)
165 | # else the source is already stored somewhere else
166 |
167 | traceback.print_exc(file=file)
168 |
169 | def timeit(self, number=default_number):
170 | """Time 'number' executions of the main statement.
171 |
172 | To be precise, this executes the setup statement once, and
173 | then returns the time it takes to execute the main statement
174 | a number of times, as a float measured in seconds. The
175 | argument is the number of times through the loop, defaulting
176 | to one million. The main statement, the setup statement and
177 | the timer function to be used are passed to the constructor.
178 | """
179 | it = itertools.repeat(None, number)
180 | gcold = gc.isenabled()
181 | gc.disable()
182 | try:
183 | timing = self.inner(it, self.timer)
184 | finally:
185 | if gcold:
186 | gc.enable()
187 | return timing
188 |
189 | def repeat(self, repeat=default_repeat, number=default_number):
190 | """Call timeit() a few times.
191 |
192 | This is a convenience function that calls the timeit()
193 | repeatedly, returning a list of results. The first argument
194 | specifies how many times to call timeit(), defaulting to 5;
195 |         the second argument specifies the number argument, defaulting
196 | to one million.
197 |
198 | Note: it's tempting to calculate mean and standard deviation
199 | from the result vector and report these. However, this is not
200 | very useful. In a typical case, the lowest value gives a
201 | lower bound for how fast your machine can run the given code
202 | snippet; higher values in the result vector are typically not
203 | caused by variability in Python's speed, but by other
204 | processes interfering with your timing accuracy. So the min()
205 | of the result is probably the only number you should be
206 | interested in. After that, you should look at the entire
207 | vector and apply common sense rather than statistics.
208 | """
209 | r = []
210 | for i in range(repeat):
211 | t = self.timeit(number)
212 | r.append(t)
213 | return r
214 |
215 | def autorange(self, callback=None):
216 | """Return the number of loops and time taken so that total time >= 0.2.
217 |
218 | Calls the timeit method with increasing numbers from the sequence
219 | 1, 2, 5, 10, 20, 50, ... until the time taken is at least 0.2
220 | second. Returns (number, time_taken).
221 |
222 | If *callback* is given and is not None, it will be called after
223 | each trial with two arguments: ``callback(number, time_taken)``.
224 | """
225 | i = 1
226 | while True:
227 | for j in 1, 2, 5:
228 | number = i * j
229 | time_taken = self.timeit(number)
230 | if callback:
231 | callback(number, time_taken)
232 | if time_taken >= 0.2:
233 | return (number, time_taken)
234 | i *= 10
235 |
236 | def timeit(stmt="pass", setup="pass", finish='pass', timer=default_timer,
237 | number=default_number, globals=None):
238 | """Convenience function to create Timer object and call timeit method."""
239 | return Timer(stmt, setup, finish, timer, globals).timeit(number)
240 |
241 | def repeat(stmt="pass", setup="pass", finish='pass', timer=default_timer,
242 | repeat=default_repeat, number=default_number, globals=None):
243 | """Convenience function to create Timer object and call repeat method."""
244 | return Timer(stmt, setup, finish, timer, globals).repeat(repeat, number)
245 |
246 | def py_benchmark(stmt, context, min_repeat_second=1, setup='pass', finish='pass'):
247 | total_time = 0
248 | number = 10
249 |
250 | eval(stmt, context) # warmup
251 | total_time = timeit(stmt=stmt, setup=setup, finish=finish, number=number, globals=context)
252 | while total_time < min_repeat_second:
253 | number = int(number * (min_repeat_second / total_time)) + 1
254 | total_time = timeit(stmt=stmt, setup=setup, finish=finish, number=number, globals=context)
255 |
256 | return total_time / number
257 |
--------------------------------------------------------------------------------
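
py_benchmark above is the timing entry point used by every speed test in this repo: it evaluates the statement once as a warm-up, then grows `number` until the measured window reaches min_repeat_second, and returns seconds per iteration; the added 'finish' hook exists so a torch.cuda.synchronize() can run inside the timed region after the loop. A minimal usage sketch in the same style as the tests (the matrix sizes are illustrative):

```python
import torch
from timeit_v2 import py_benchmark   # run from inside quant/, like the tests

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

stmt = "a @ b"
t = py_benchmark(stmt, {**globals(), **locals()}, min_repeat_second=1,
                 setup="torch.cuda.synchronize()",
                 finish="torch.cuda.synchronize()")
print(f"fp16 4096x4096 matmul: {t * 1000:.3f} ms / iter")
```
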
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.2
2 | packaging==24.0
3 | absl-py==2.0.0
4 | accelerate==0.25.0
5 | aiofiles==23.2.1
6 | aiohttp==3.9.1
7 | aiosignal==1.3.1
8 | altair==5.2.0
9 | annotated-types==0.6.0
10 | anyio==4.3.0
11 | asttokens==2.4.1
12 | async-timeout==4.0.3
13 | attributedict==0.3.0
14 | attrs==23.2.0
15 | bitsandbytes==0.43.0
16 | blessings==1.7
17 | cachetools==5.3.2
18 | certifi==2023.11.17
19 | chardet==5.2.0
20 | charset-normalizer==3.3.2
21 | click==8.1.7
22 | codecov==2.1.13
23 | colorama==0.4.6
24 | coloredlogs==15.0.1
25 | colour-runner==0.1.1
26 | contourpy==1.2.0
27 | coverage==7.4.0
28 | cycler==0.12.1
29 | DataProperty==1.0.1
30 | datasets==2.16.1
31 | decorator==5.1.1
32 | deepdiff==6.7.1
33 | deepspeed==0.12.6
34 | dill==0.3.7
35 | distlib==0.3.8
36 | distro==1.9.0
37 | einops==0.7.0
38 | evaluate==0.4.1
39 | exceptiongroup==1.2.0
40 | executing==2.0.1
41 | fastapi==0.110.0
42 | ffmpy==0.3.2
43 | filelock==3.13.1
44 | flash-attn==2.5.6
45 | fonttools==4.50.0
46 | frozenlist==1.4.1
47 | fsspec==2023.10.0
48 | fuzzywuzzy==0.18.0
49 | gradio==5.0.0
50 | gradio_client==0.2.9
51 | h11==0.14.0
52 | hjson==3.1.0
53 | httpcore==1.0.4
54 | httpx==0.27.0
55 | huggingface-hub==0.20.2
56 | humanfriendly==10.0
57 | idna==3.6
58 | inspecta==0.1.3
59 | ipdb==0.13.13
60 | ipython==8.19.0
61 | jedi==0.19.1
62 | jieba==0.42.1
63 | Jinja2==3.1.2
64 | joblib==1.3.2
65 | jsonlines==4.0.0
66 | jsonschema==4.21.1
67 | jsonschema-specifications==2023.12.1
68 | kiwisolver==1.4.5
69 | linkify-it-py==2.0.3
70 | -e git+https://github.com/EleutherAI/lm-evaluation-harness.git@c9bbec6e7de418b9082379da82797522eb173054#egg=lm_eval
71 | lxml==5.0.1
72 | markdown-it-py==2.2.0
73 | MarkupSafe==2.1.3
74 | matplotlib==3.8.3
75 | matplotlib-inline==0.1.6
76 | mbstrdecoder==1.1.3
77 | mdit-py-plugins==0.3.3
78 | mdurl==0.1.2
79 | mpmath==1.3.0
80 | multidict==6.0.4
81 | multiprocess==0.70.15
82 | networkx==3.2.1
83 | ninja==1.11.1.1
84 | nltk==3.8.1
85 | numexpr==2.8.8
86 | numpy==1.26.3
87 | nvidia-cublas-cu12==12.1.3.1
88 | nvidia-cuda-cupti-cu12==12.1.105
89 | nvidia-cuda-nvrtc-cu12==12.1.105
90 | nvidia-cuda-runtime-cu12==12.1.105
91 | nvidia-cudnn-cu12==8.9.2.26
92 | nvidia-cufft-cu12==11.0.2.54
93 | nvidia-curand-cu12==10.3.2.106
94 | nvidia-cusolver-cu12==11.4.5.107
95 | nvidia-cusparse-cu12==12.1.0.106
96 | nvidia-nccl-cu12==2.18.1
97 | nvidia-nvjitlink-cu12==12.3.101
98 | nvidia-nvtx-cu12==12.1.105
99 | openai==1.14.2
100 | ordered-set==4.1.0
101 | orjson==3.9.15
102 | packaging==23.2
103 | pandas==2.1.4
104 | parso==0.8.3
105 | pathvalidate==3.2.0
106 | peft==0.7.1
107 | pexpect==4.9.0
108 | pillow==10.2.0
109 | platformdirs==4.1.0
110 | pluggy==1.3.0
111 | portalocker==2.8.2
112 | prompt-toolkit==3.0.43
113 | protobuf==4.25.1
114 | psutil==5.9.7
115 | ptyprocess==0.7.0
116 | pure-eval==0.2.2
117 | py-cpuinfo==9.0.0
118 | pyarrow==14.0.2
119 | pyarrow-hotfix==0.6
120 | pybind11==2.11.1
121 | pycountry==23.12.11
122 | pydantic==1.10.14
123 | pydantic_core==2.14.6
124 | pydub==0.25.1
125 | Pygments==2.17.2
126 | pynvml==11.5.0
127 | pyparsing==3.1.2
128 | pyproject-api==1.6.1
129 | pytablewriter==1.2.0
130 | python-dateutil==2.8.2
131 | python-multipart==0.0.9
132 | pytz==2023.3.post1
133 | PyYAML==6.0.1
134 | referencing==0.34.0
135 | regex==2023.12.25
136 | requests==2.31.0
137 | responses==0.18.0
138 | rootpath==0.1.1
139 | rouge==1.0.1
140 | rouge-score==0.1.2
141 | rpds-py==0.18.0
142 | sacrebleu==1.5.0
143 | safetensors==0.4.1
144 | scikit-learn==1.3.2
145 | scipy==1.11.4
146 | semantic-version==2.10.0
147 | sentencepiece==0.1.99
148 | six==1.16.0
149 | sniffio==1.3.1
150 | sqlitedict==2.1.0
151 | stack-data==0.6.3
152 | starlette==0.36.3
153 | sympy==1.12
154 | tabledata==1.3.3
155 | tabulate==0.9.0
156 | tcolorpy==0.1.4
157 | termcolor==2.4.0
158 | texttable==1.7.0
159 | threadpoolctl==3.2.0
160 | tokenizers==0.15.0
161 | toml==0.10.2
162 | tomli==2.0.1
163 | toolz==0.12.1
164 | torchvision==0.16.2
165 | tox==4.11.4
166 | tqdm==4.66.1
167 | tqdm-multiprocess==0.0.11
168 | traitlets==5.14.1
169 | transformers==4.36.2
170 | triton==2.1.0
171 | typepy==1.3.2
172 | typing_extensions==4.9.0
173 | tzdata==2023.4
174 | uc-micro-py==1.0.3
175 | urllib3==2.1.0
176 | uvicorn==0.29.0
177 | virtualenv==20.25.0
178 | wcwidth==0.2.13
179 | websockets==12.0
180 | xxhash==3.4.1
181 | yarl==1.9.4
182 | zstandard==0.22.0
183 |
--------------------------------------------------------------------------------
/scripts/long_test.sh:
--------------------------------------------------------------------------------
1 | # Usage: bash scripts/long_test.sh <gpuid> <k_bits> <v_bits> <group_size> <residual_length> <model>
2 | # model e.g.: meta-llama/Llama-2-7b-hf
3 | gpuid=$1
4 | k_bits=$2
5 | v_bits=$3
6 | group_size=$4
7 | residual_length=$5
8 | model=$6
9 | e=0
10 |
11 | CUDA_VISIBLE_DEVICES=$gpuid python pred_long_bench.py --model_name_or_path $model \
12 | --cache_dir ./cached_models \
13 | --k_bits $k_bits \
14 | --v_bits $v_bits \
15 | --group_size $group_size \
16 | --residual_length $residual_length \
17 | --e ${e}
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | from torch.utils.data import Dataset
5 | from datasets import load_dataset
6 | from torch.utils.data import DataLoader
7 |
8 | class TextDataset(torch.utils.data.IterableDataset):
9 | def __init__(self, data, tokenizer, seqlen, col_key, cutoff=1000):
10 | self.tokenizer = tokenizer
11 | self.col_key = col_key
12 | self.cutoff = cutoff
13 | self.block_size = seqlen
14 | if cutoff is None:
15 | cutoff = len(data)
16 | tokenized_datasets = [self.tokenizer(data[i][col_key]) for i in range(cutoff)]
17 | grouped_dataset = self.group_texts(tokenized_datasets)
18 | self.input_ids = grouped_dataset["input_ids"]
19 | self.labels = grouped_dataset["labels"]
20 | self.data = [
21 | dict(input_ids=self.input_ids[i], labels=self.labels[i])
22 | for i in range(len(self.input_ids))
23 | ]
24 |
25 | def __len__(self):
26 | return len(self.input_ids)
27 |
28 | def __getitem__(self, i):
29 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
30 |
31 | def __iter__(self):
32 | return iter(self.data)
33 |
34 | def group_texts(self, examples):
35 | # Concatenate all texts.
36 | # Initialize an empty dictionary
37 | concatenated_examples = {}
38 |
39 | # Loop through the list of dictionaries
40 | for d in examples:
41 | # Loop through the keys in each dictionary
42 | for key in d.keys():
43 | # If the key is not already a key in the dict_of_lists, create a new list
44 | if key not in concatenated_examples:
45 | concatenated_examples[key] = []
46 | # Append the value to the list associated with the key in dict_of_lists
47 | concatenated_examples[key].extend(d[key])
48 | total_length = len(concatenated_examples["input_ids"])
49 |         # We drop the small remainder; we could pad instead if the model supported it.
50 |         # You can customize this part to your needs.
51 | if total_length >= self.block_size:
52 | total_length = (total_length // self.block_size) * self.block_size
53 | # Split by chunks of max_len.
54 | result = {
55 | k: [
56 | t[i : i + self.block_size]
57 | for i in range(0, total_length, self.block_size)
58 | ]
59 | for k, t in concatenated_examples.items()
60 | }
61 | result["labels"] = result["input_ids"].copy()
62 | return result
63 |
64 |
65 | # class TextDataset(Dataset):
66 | # def __init__(self, data, tokenizer, seqlen, col_key, per_device_train_batch_size, mode="train", cutoff=None):
67 | # self.tokenizer = tokenizer
68 | # self.mode = mode
69 | # self.col_key = col_key
70 | # self.cutoff = cutoff
71 | # self.seqlen = seqlen
72 | # self.per_device_train_batch_size = per_device_train_batch_size
73 |
74 | # if self.mode == "train":
75 | # self.encodings = self.process_data(data)
76 | # else:
77 | # self.encodings = self.process_data(data, is_val=True)
78 |
79 | # def process_data(self, data, is_val=False):
80 | # seqlen = self.seqlen
81 | # if is_val:
82 | # if self.cutoff is None:
83 | # enc = self.tokenizer(" ".join(data[self.col_key]), return_tensors='pt')
84 | # else:
85 | # enc = self.tokenizer(" ".join(data[:self.cutoff][self.col_key]), return_tensors='pt')
86 | # tot_num_seq = enc['input_ids'].size(1) // seqlen
87 | # enc['input_ids'] = enc['input_ids'][..., :tot_num_seq*seqlen]
88 | # else:
89 | # if self.cutoff is None:
90 | # enc = self.tokenizer(" ".join(data[self.col_key]), return_tensors='pt')
91 | # else:
92 | # enc = self.tokenizer(" ".join(data[:self.cutoff][self.col_key]), return_tensors='pt')
93 | # tot_num_seq = enc['input_ids'].size(1) // (seqlen*self.per_device_train_batch_size)
94 | # enc['input_ids'] = enc['input_ids'][..., :tot_num_seq*seqlen*self.per_device_train_batch_size]
95 |
96 | # return enc
97 |
98 | # def __getitem__(self, idx):
99 | # input_ids = self.encodings['input_ids'][0, idx*self.seqlen:(idx+1)*self.seqlen]
100 | # return input_ids
101 |
102 | # def __len__(self):
103 | # return self.encodings['input_ids'].size(1) // self.seqlen
104 |
105 |
106 | def set_seed(seed):
107 | np.random.seed(seed)
108 | torch.random.manual_seed(seed)
109 | random.seed(seed)
110 |
111 |
112 | def get_c4(n_train_samples, n_eval_samples, seqlen, tokenizer):
113 | # raw_tra_data = load_dataset("c4", split="train")
114 |     raw_tra_data = load_dataset('allenai/c4', 'allenai--c4',  # note: newer `datasets` releases may reject the 'allenai--c4' config name
115 | data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
116 | split='train')
117 | # raw_val_data = load_dataset("c4", split="validation")
118 | raw_val_data = load_dataset('allenai/c4', 'allenai--c4',
119 | data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
120 | split='validation')
121 | train_dataset = TextDataset(raw_tra_data, tokenizer,
122 | col_key='text',
123 | cutoff=n_train_samples,
124 | seqlen=seqlen)
125 | val_dataset = TextDataset(raw_val_data, tokenizer,
126 | col_key='text',
127 | cutoff=n_eval_samples, # todo: change to 1100
128 | seqlen=seqlen)
129 | return train_dataset, val_dataset
130 |
131 |
132 | def get_loaders(
133 | name, enc, n_train_samples=128, n_eval_samples=1024, seqlen=2048):
134 | if 'c4' in name:
135 | return get_c4(n_train_samples, n_eval_samples, seqlen, enc)
136 | else:
137 | raise NotImplementedError
--------------------------------------------------------------------------------
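A minimal usage sketch of `get_loaders` from `utils/data.py`; the tokenizer checkpoint and sample counts below are illustrative placeholders, not values taken from the repository:

```python
# Illustrative only: build C4 calibration/eval datasets chunked into
# fixed-length blocks via utils/data.py (checkpoint name is a placeholder).
from transformers import AutoTokenizer
from utils.data import get_loaders, set_seed

set_seed(42)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Downloads one shard of allenai/c4 and groups the tokenized text into
# contiguous 2048-token blocks (labels are a copy of input_ids).
train_set, val_set = get_loaders("c4", tokenizer,
                                 n_train_samples=128,
                                 n_eval_samples=256,
                                 seqlen=2048)
print(len(train_set), "train blocks,", len(val_set), "eval blocks")
```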
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | import jieba
5 | from fuzzywuzzy import fuzz
6 | import difflib
7 |
8 | from typing import List
9 | from collections import Counter
10 | from rouge import Rouge
11 |
12 | def normalize_answer(s):
13 | """Lower text and remove punctuation, articles and extra whitespace."""
14 |
15 | def remove_articles(text):
16 | return re.sub(r"\b(a|an|the)\b", " ", text)
17 |
18 | def white_space_fix(text):
19 | return " ".join(text.split())
20 |
21 | def remove_punc(text):
22 | exclude = set(string.punctuation)
23 | return "".join(ch for ch in text if ch not in exclude)
24 |
25 | def lower(text):
26 | return text.lower()
27 |
28 | return white_space_fix(remove_articles(remove_punc(lower(s))))
29 |
30 |
31 | def normalize_zh_answer(s):
32 | """Lower text and remove punctuation, extra whitespace."""
33 |
34 | def white_space_fix(text):
35 | return "".join(text.split())
36 |
37 | def remove_punc(text):
38 |         cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
39 | all_punctuation = set(string.punctuation + cn_punctuation)
40 | return "".join(ch for ch in text if ch not in all_punctuation)
41 |
42 | def lower(text):
43 | return text.lower()
44 |
45 | return white_space_fix(remove_punc(lower(s)))
46 |
47 | def count_score(prediction, ground_truth, **kwargs):
48 | numbers = re.findall(r"\d+", prediction)
49 | right_num = 0
50 | for number in numbers:
51 | if str(number) == str(ground_truth):
52 | right_num += 1
53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
54 | return float(final_score)
55 |
56 | def retrieval_score(prediction, ground_truth, **kwargs):
57 | pattern = r'Paragraph (\d+)'
58 | matches = re.findall(pattern, ground_truth)
59 | ground_truth_id = matches[0]
60 | numbers = re.findall(r"\d+", prediction)
61 | right_num = 0
62 | for number in numbers:
63 | if str(number) == str(ground_truth_id):
64 | right_num += 1
65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
66 | return float(final_score)
67 |
68 | def retrieval_zh_score(prediction, ground_truth, **kwargs):
69 | pattern = r'段落(\d+)'
70 | matches = re.findall(pattern, ground_truth)
71 | ground_truth_id = matches[0]
72 | numbers = re.findall(r"\d+", prediction)
73 | right_num = 0
74 | for number in numbers:
75 | if str(number) == str(ground_truth_id):
76 | right_num += 1
77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
78 | return float(final_score)
79 |
80 | def code_sim_score(prediction, ground_truth, **kwargs):
81 | all_lines = prediction.lstrip('\n').split('\n')
82 | prediction = ""
83 | for line in all_lines:
84 | if ('`' not in line) and ('#' not in line) and ('//' not in line):
85 | prediction = line
86 | break
87 | return (fuzz.ratio(prediction, ground_truth) / 100)
88 |
89 | def classification_score(prediction, ground_truth, **kwargs):
90 | em_match_list = []
91 | all_classes = kwargs["all_classes"]
92 | for class_name in all_classes:
93 | if class_name in prediction:
94 | em_match_list.append(class_name)
95 |     for match_term in list(em_match_list):  # iterate over a copy; the loop removes items
96 |         if match_term in ground_truth and match_term != ground_truth:
97 |             em_match_list.remove(match_term)
98 |     if len(em_match_list) != 0:
99 | if ground_truth in em_match_list:
100 | score = (1.0 / len(em_match_list))
101 | else:
102 | score = 0.0
103 | else:
104 | best_match = None
105 | highest_similarity = 0
106 | for string in all_classes:
107 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
108 | if similarity > highest_similarity:
109 | highest_similarity = similarity
110 | best_match = string
111 | score = float(best_match == ground_truth)
112 | return score
113 |
114 | def rouge_score(prediction, ground_truth, **kwargs):
115 | rouge = Rouge()
116 | try:
117 | scores = rouge.get_scores([prediction], [ground_truth], avg=True)
118 |     except Exception:
119 | return 0.0
120 | return scores["rouge-l"]["f"]
121 |
122 | def rouge_zh_score(prediction, ground_truth, **kwargs):
123 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
124 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
125 | score = rouge_score(prediction, ground_truth)
126 | return score
127 |
128 | def f1_score(prediction, ground_truth, **kwargs):
129 | common = Counter(prediction) & Counter(ground_truth)
130 | num_same = sum(common.values())
131 | if num_same == 0:
132 | return 0
133 | precision = 1.0 * num_same / len(prediction)
134 | recall = 1.0 * num_same / len(ground_truth)
135 | f1 = (2 * precision * recall) / (precision + recall)
136 | return f1
137 |
138 | def qa_f1_score(prediction, ground_truth, **kwargs):
139 | normalized_prediction = normalize_answer(prediction)
140 | normalized_ground_truth = normalize_answer(ground_truth)
141 |
142 | prediction_tokens = normalized_prediction.split()
143 | ground_truth_tokens = normalized_ground_truth.split()
144 | return f1_score(prediction_tokens, ground_truth_tokens)
145 |
146 |
147 | def qa_f1_zh_score(prediction, ground_truth, **kwargs):
148 | prediction_tokens = list(jieba.cut(prediction, cut_all=False))
149 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
150 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
151 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
152 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
153 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
154 | return f1_score(prediction_tokens, ground_truth_tokens)
--------------------------------------------------------------------------------
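For orientation, a small sketch of how these LongBench-style metrics are called; the prediction and reference strings are made up for demonstration:

```python
# Illustrative calls only; the strings are fabricated examples.
from utils.metrics import qa_f1_score, rouge_score, classification_score

pred = "The agreement was signed in 1848."
gold = "It was signed in 1848."
print(qa_f1_score(pred, gold))   # token-level F1 after lowercasing / punctuation removal
print(rouge_score(pred, gold))   # ROUGE-L F1 via the `rouge` package

# classification_score expects the candidate label set through the all_classes kwarg.
print(classification_score("This looks like sports coverage",
                           "sports",
                           all_classes=["sports", "politics", "finance"]))
```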
/utils/process_args.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 | from dataclasses import dataclass, field
10 | from typing import Optional
11 |
12 | import transformers
13 |
14 |
15 | @dataclass
16 | class ModelArguments:
17 | model_name_or_path: str = field(
18 |         default=None, metadata={"help": "Model name or local path of the model to evaluate."}
19 | )
20 | k_bits: Optional[int] = field(
21 | default=2,
22 |         metadata={"help": "Key cache quantization bit width."},
23 | )
24 | v_bits: Optional[int] = field(
25 | default=2,
26 |         metadata={"help": "Value cache quantization bit width."},
27 | )
28 | k_quant_dim: Optional[str] = field(
29 | default='token',
30 |         metadata={"help": "Key cache quantization dimension."},
31 | )
32 | v_quant_dim: Optional[str] = field(
33 | default='token',
34 |         metadata={"help": "Value cache quantization dimension."},
35 | )
36 | group_size: Optional[int] = field(
37 | default=128,
38 | metadata={"help": "KV_cache quantization group size."},
39 | )
40 | residual_length: Optional[int] = field(
41 | default=128,
42 | metadata={"help": "KV_cache residual length."},
43 | )
44 | output_model_filename: Optional[str] = field(
45 | default="test-output", metadata={"help": "Output model relative manifold path"}
46 | )
47 | load_quant: Optional[str] = field(
48 | default=None,
49 | metadata={"help": "The path to a quantized model"},
50 | )
51 | w_bit: Optional[int] = field(
52 | default=4,
53 | metadata={"help": "The model weight bit width."},
54 | )
55 | lora: Optional[bool] = field(
56 | default=False,
57 | metadata={"help": "Whether to use LoRA"},
58 | )
59 | lora_mode: Optional[str] = field(
60 | default="q",
61 | metadata={"help": "LoRA mode"},
62 | )
63 | lora_r: Optional[int] = field(
64 | default=1,
65 | metadata={"help": "LoRA r"},
66 | )
67 | lora_alpha: Optional[float] = field(
68 | default=1.,
69 | metadata={"help": "LoRA alpha"},
70 | )
71 | lora_dropout: Optional[float] = field(
72 | default=0.,
73 | metadata={"help": "LoRA dropout"},
74 | )
75 |
76 |
77 |
78 | @dataclass
79 | class DataArguments:
80 | dataset: Optional[str] = field(
81 | default='c4',
82 | metadata={"help": "The dataset used for fine-tuning the model."},
83 | )
84 | eval_tasks: Optional[str] = field(
85 | default='wikitext',
86 | metadata={"help": "The dataset used for evaluation."},
87 | )
88 | tasks: Optional[str] = field(
89 | default='wikitext',
90 | metadata={"help": "The dataset used for evaluation."},
91 | )
92 | batch_size: Optional[int] = field(
93 | default=1,
94 | metadata={"help": "The batch size."},
95 | )
96 | num_fewshot: Optional[int] = field(
97 | default=0,
98 | metadata={"help": "The number of fewshot examples."},
99 | )
100 | output_path: Optional[str] = field(
101 | default='./outputs',
102 | metadata={"help": "The output path."},
103 | )
104 | e: Optional[bool] = field(
105 | default=False,
106 | metadata={"help": "Evaluate on LongBench-E."},
107 | )
108 | use_our_imp: Optional[bool] = field(
109 | default=True,
110 | metadata={"help": "Whether to use our KV cache quantization implementation."},
111 | )
112 |
113 |
114 |
115 | @dataclass
116 | class TrainingArguments(transformers.TrainingArguments):
117 | cache_dir: Optional[str] = field(default=None)
118 | optim: Optional[str] = field(default="adamw_torch")
119 | output_dir: Optional[str] = field(default="./outputs")
120 | model_max_length: Optional[int] = field(
121 | default=512,
122 | metadata={
123 |             "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated). Typical values: 512 or 1024."
124 | },
125 | )
126 | num_train_epochs: Optional[int] = field(default=1)
127 | n_train_samples: Optional[int] = field(default=None)
128 | n_eval_samples: Optional[int] = field(default=None)
129 | qat: Optional[bool] = field(default=False)
130 | exp_name: Optional[str] = field(default="test")
131 |
132 |
133 | def process_args():
134 | parser = transformers.HfArgumentParser(
135 | (ModelArguments, DataArguments, TrainingArguments)
136 | )
137 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
138 | os.makedirs(training_args.output_dir, exist_ok=True)
139 |
140 | model_args.output_model_local_path = os.path.join(
141 | training_args.output_dir, "models", str(model_args.output_model_filename)
142 | )
143 |
144 | return model_args, data_args, training_args
145 |
--------------------------------------------------------------------------------
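A sketch of how a driver script might consume these three dataclasses; the flags mirror the fields defined above, but the command line shown is only an illustration:

```python
# Illustrative only: parse the argument groups defined in utils/process_args.py.
# Invoked e.g. as:
#   python your_script.py --model_name_or_path <hf-model-or-path> \
#       --k_bits 2 --v_bits 2 --group_size 32 --residual_length 32 \
#       --output_dir ./outputs
from utils.process_args import process_args

model_args, data_args, training_args = process_args()
print("K/V bits:", model_args.k_bits, model_args.v_bits)
print("group_size / residual_length:", model_args.group_size, model_args.residual_length)
print("dataset:", data_args.dataset, "| output_dir:", training_args.output_dir)
```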
/vis/vis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import os\n",
11 | "import matplotlib\n",
12 | "import numpy as np\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "from mpl_toolkits.mplot3d import Axes3D\n",
15 | "from matplotlib.colors import Normalize\n",
16 | "\n",
17 | "FONTSIZE = 16\n",
18 | "\n",
19 | "font_config = {'font.size': FONTSIZE, 'font.family': 'DejaVu Math TeX Gyre'}\n",
20 | "plt.rcParams.update(font_config)\n",
21 | "plt.rcParams[\"figure.figsize\"] = (4, 4.5)\n",
22 | "\n",
23 | "# generate kv cache and attention\n",
24 | "# inputs = enc(sample, return_tensors='pt').to('cuda')\n",
25 | "# outputs = model(inputs['input_ids'], use_cache=True, output_attentions=True)\n",
26 | "# past_key_values = outputs.past_key_values\n",
27 | "# attentions = outputs.attentions\n",
28 | "# torch.save(past_key_values, f'./{model}_kvcache.pt')\n",
29 | "# torch.save(attentions, f'./{model}_attention.pt')\n",
30 | "\n",
31 | "model = 'Llama-2-7b-hf' # replace with your model name\n",
32 | "kv_filename = f'./{model}_kvcache.pt'\n",
33 | "attn_filename = f'./{model}_attention.pt'\n",
34 | "kvcache = torch.load(kv_filename, map_location='cpu')\n",
35 | "attentions = torch.load(attn_filename, map_location='cpu')"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "for layer_id in [3, 8, 14, 16, 18, 20, 31]: # replace with your layer ids\n",
45 | " head_id = 0\n",
46 | " k, v = kvcache[layer_id][0].squeeze(0), kvcache[layer_id][1].squeeze(0)\n",
47 | "\n",
48 | " k = k.transpose(0, 1).abs().detach().numpy()\n",
49 | " v = v.transpose(0, 1).abs().detach().numpy()\n",
50 | " k, v = k[:, head_id, :], v[:, head_id, :]\n",
51 | "\n",
52 | " # Sample 2D tensor (replace this with your actual tensor)\n",
53 | " for idx, tensor in enumerate([k, v]):\n",
54 | " # Creating a meshgrid\n",
55 | " tokens, channels = tensor.shape\n",
56 | " x = np.arange(tokens)\n",
57 | " y = np.arange(channels)\n",
58 | " X, Y = np.meshgrid(x, y)\n",
59 | " # Creating a figure and a 3D subplot\n",
60 | " fig = plt.figure()\n",
61 | " ax = fig.add_subplot(111, projection='3d')\n",
62 | " # Plotting the surface\n",
63 | " surf = ax.plot_surface(X, Y, tensor.T, cmap='coolwarm')\n",
64 | "\n",
65 | " ax.xaxis.set_tick_params(pad=-5)\n",
66 | " ax.yaxis.set_tick_params(pad=-3)\n",
67 | " ax.zaxis.set_tick_params(pad=-130)\n",
68 | "\n",
69 | " # Adding labels\n",
70 | " ax.set_xlabel('Token', labelpad=-5)\n",
71 | " ax.set_ylabel('Column', labelpad=-1)\n",
72 | " if layer_id in [3, 16]:\n",
73 | " ax.zaxis.set_rotate_label(False) \n",
74 | " if idx == 0:\n",
75 | " save_filename = f'./saved_figs/{model}_layer{layer_id}_head{head_id}_k.pdf'\n",
76 | " else:\n",
77 | " save_filename = f'./saved_figs/{model}_layer{layer_id}_head{head_id}_v.pdf'\n",
78 | " plt.savefig(save_filename, bbox_inches='tight')\n",
79 | " plt.clf()"
80 | ]
81 | }
82 | ],
83 | "metadata": {
84 | "language_info": {
85 | "name": "python"
86 | }
87 | },
88 | "nbformat": 4,
89 | "nbformat_minor": 2
90 | }
91 |
--------------------------------------------------------------------------------
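If you need to regenerate the `*_kvcache.pt` / `*_attention.pt` files that the notebook loads, the commented-out lines in its first cell expand to roughly the sketch below. The model name and prompt are placeholders; with newer transformers versions you may need to convert the returned cache object to the legacy tuple format before saving.

```python
# Rough sketch, not part of the repo: dump the KV cache and attention maps
# for a single prompt so the notebook above can visualize them.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"   # placeholder checkpoint
enc = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16).cuda().eval()

sample = "A reasonably long prompt whose KV cache you want to inspect."
inputs = enc(sample, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(inputs["input_ids"], use_cache=True, output_attentions=True)

tag = model_name.split("/")[-1]           # e.g. "Llama-2-7b-hf", as in the notebook
torch.save(outputs.past_key_values, f"./{tag}_kvcache.pt")
torch.save(outputs.attentions, f"./{tag}_attention.pt")
```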