├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── LICENSE
├── README.md
├── benchmarks
│   ├── hadamard_benchmark.py
│   ├── qattention_benchmark.py
│   └── qlinear_benchmark.py
├── e2e
│   ├── __init__.py
│   ├── benchmark.py
│   ├── benchmark_layer.py
│   ├── checkpoint_utils
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── gptq_utils.py
│   │   ├── quantize_llama_checkpoint.py
│   │   └── rotation_utils.py
│   └── quantized_llama
│       ├── __init__.py
│       └── modeling_llama.py
├── fake_quant
│   ├── README.md
│   ├── data_utils.py
│   ├── eval_utils.py
│   ├── gptq_utils.py
│   ├── hadamard_utils.py
│   ├── main.py
│   ├── model_utils.py
│   ├── monkeypatch.py
│   ├── quant_utils.py
│   ├── rotation_utils.py
│   └── utils.py
├── img
│   ├── carrot.png
│   └── fig1.png
├── quarot
│   ├── __init__.py
│   ├── functional
│   │   ├── __init__.py
│   │   ├── hadamard.py
│   │   └── quantization.py
│   ├── kernels
│   │   ├── bindings.cpp
│   │   ├── flashinfer.cu
│   │   ├── gemm.cu
│   │   ├── include
│   │   │   ├── common.h
│   │   │   ├── flashinfer.h
│   │   │   ├── flashinfer
│   │   │   │   ├── cp_async.cuh
│   │   │   │   ├── decode.cuh
│   │   │   │   ├── layout.cuh
│   │   │   │   ├── math.cuh
│   │   │   │   ├── mma.cuh
│   │   │   │   ├── page.cuh
│   │   │   │   ├── permuted_smem.cuh
│   │   │   │   ├── prefill.cuh
│   │   │   │   ├── quantization.cuh
│   │   │   │   ├── rope.cuh
│   │   │   │   ├── state.cuh
│   │   │   │   ├── utils.cuh
│   │   │   │   └── vec_dtypes.cuh
│   │   │   ├── gemm.h
│   │   │   ├── int4.h
│   │   │   ├── quant.h
│   │   │   └── util.h
│   │   └── quant.cu
│   ├── nn
│   │   ├── __init__.py
│   │   ├── hadamard.py
│   │   ├── linear.py
│   │   ├── normalization.py
│   │   └── quantization.py
│   └── transformers
│       ├── __init__.py
│       └── kv_cache.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ninja_deps
2 | *.o
3 | *.json
4 | *.log
5 | *.pyc
6 | *.so
7 | build/*
8 | quarot.egg-info/*
9 | *.cmake
10 | *.in
11 | CMakeFiles/*
12 | CMakeCache.txt
13 | *.egg
14 | Makefile
15 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 |
2 | [submodule "third-party/cutlass"]
3 | path = third-party/cutlass
4 | url = https://github.com/NVIDIA/cutlass.git
5 | [submodule "third-party/nvbench"]
6 | path = third-party/nvbench
7 | url = https://github.com/NVIDIA/nvbench
8 | [submodule "third-party/fast-hadamard-transform"]
9 | path = third-party/fast-hadamard-transform
10 | url = https://github.com/Dao-AILab/fast-hadamard-transform.git
11 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.11)
2 | project(quarot LANGUAGES CXX)
3 |
4 | find_package(Git REQUIRED)
5 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git")
6 | message(STATUS "Populating Git submodule.")
7 | execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
8 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
9 | RESULT_VARIABLE GIT_SUBMOD_RESULT)
10 | if(NOT GIT_SUBMOD_RESULT EQUAL "0")
11 | message(FATAL_ERROR
12 |             "git submodule update --init --recursive failed with ${GIT_SUBMOD_RESULT}.")
13 | endif()
14 | endif()
15 |
16 |
17 | set(_saved_CMAKE_MESSAGE_LOG_LEVEL ${CMAKE_MESSAGE_LOG_LEVEL})
18 | set(CMAKE_MESSAGE_LOG_LEVEL ERROR)
19 | add_subdirectory(third-party/cutlass)
20 | set(CMAKE_MESSAGE_LOG_LEVEL ${_saved_CMAKE_MESSAGE_LOG_LEVEL})
21 |
22 | include_directories("${CMAKE_SOURCE_DIR}")
23 | include_directories(third-party/cutlass/tools/util/include)
24 | include_directories(third-party/cutlass/include)
25 | include_directories(quarot/kernels/include)
26 |
27 | get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
28 | foreach(dir ${dirs})
29 | message(STATUS "dir='${dir}'")
30 | endforeach()
31 |
32 | # add_subdirectory(quarot/kernels)
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs
3 | This repository contains the code for [**QuaRot**: Outlier-Free 4-Bit Inference in Rotated LLMs](https://arxiv.org/abs/2404.00456).
4 |
5 |
6 |
7 | ## Abstract
8 | We introduce QuaRot, a new **Qua**ntization scheme based on **Rot**ations, which is able to quantize LLMs end-to-end, including all weights, activations, and KV cache in 4 bits. QuaRot rotates LLMs in a way that removes outliers from the hidden state without changing the output, making quantization easier. This *computational invariance* is applied to the hidden state (residual) of the LLM, as well as to the activations of the feed-forward components, aspects of the attention mechanism and to the KV cache. The result is a quantized model where all matrix multiplications are performed in 4-bits, without any channels identified for retention in higher precision. Our quantized **LLaMa2-70B** model has losses of at most **0.29 WikiText perplexity** and retains **99% of the zero-shot** performance.
9 |
10 | 
11 |
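The *computational invariance* described above can be sketched in a few lines of PyTorch: multiplying the activations by an orthogonal (here Hadamard-based) matrix `Q` and the weights by `Q.T` leaves every matrix product unchanged, while spreading per-channel outliers across all channels. The snippet below is only an illustrative toy (the dimensions, random data, and the `hadamard` helper are invented for this demo; it is not the fused kernel shipped in `quarot/`):

```python
import torch

torch.manual_seed(0)
d = 128                       # toy hidden dimension (power of two)
x = torch.randn(4, d)         # toy activations
x[:, 0] *= 50                 # inject a single-channel outlier
w = torch.randn(d, d)         # toy weight matrix

def hadamard(n: int) -> torch.Tensor:
    """Normalized Hadamard matrix of size n (n must be a power of two)."""
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / n ** 0.5

q = hadamard(d)               # orthogonal: q @ q.T is (numerically) the identity

y_ref = x @ w                 # original computation
y_rot = (x @ q) @ (q.T @ w)   # rotated activations, counter-rotated weights
print(torch.allclose(y_ref, y_rot, atol=1e-2))  # True, up to float32 rounding

# The rotation flattens the outlier channel, which is what makes 4-bit quantization easier:
print((x.abs().max() / x.abs().mean()).item())              # large max/mean ratio before rotation
print(((x @ q).abs().max() / (x @ q).abs().mean()).item())  # much smaller ratio after rotation
```
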
12 | ## Usage
13 |
14 |
15 | Compile the QuaRot kernels using the following commands:
16 |
17 | ```bash
18 | git clone https://github.com/spcl/QuaRot.git
19 | cd QuaRot
20 | pip install -e . # or pip install .
21 | ```
22 |
23 | For simulation (fake quantization) results, see the [fake_quant](https://github.com/spcl/QuaRot/tree/main/fake_quant) directory.
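
Once the package is installed, the micro-benchmarks in `benchmarks/` and the end-to-end benchmark in `e2e/` can be run from the repository root. The flags below are the ones defined in each script; the concrete values are just examples:

```bash
# Kernel-level benchmarks: 4-bit linear layers and the quantized KV cache
python benchmarks/qlinear_benchmark.py --bsz 1 --seq_len 2048 --layer_type v_proj
python benchmarks/qattention_benchmark.py --batch_size 1 --seq_len 2048

# End-to-end prefill/decode benchmark (fetches the Llama-2 config/weights from Hugging Face)
python -m e2e.benchmark --batch_size 1 --prefill_seq_len 2048 --decode_steps 32
```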
24 |
25 |
26 |
27 | ### Star History
28 |
29 | [](https://star-history.com/#spcl/QuaRot&Date)
30 |
31 |
32 | ## Citation
33 |
34 | The full citation is
35 |
36 | ```
37 | @article{ashkboos2024quarot,
38 | title={QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs},
39 | author={Ashkboos, Saleh and Mohtashami, Amirkeivan and Croci, Maximilian L and Li, Bo and Jaggi, Martin and Alistarh, Dan and Hoefler, Torsten and Hensman, James},
40 | journal={arXiv preprint arXiv:2404.00456},
41 | year={2024}
42 | }
43 | ```
44 |
--------------------------------------------------------------------------------
/benchmarks/hadamard_benchmark.py:
--------------------------------------------------------------------------------
1 | import fast_hadamard_transform
2 | import torch
3 | import time
4 | for i in [1024, 2048, 4096, 4096*2, 4096*3]:
5 | x = torch.rand(i, i).cuda().to(torch.float16)
6 | torch.cuda.synchronize()
7 | fp32_time = 0
8 | fp16_time = 0
9 |
10 | for j in range(10):
11 | timer = time.time()
12 | y_had_float = fast_hadamard_transform.hadamard_transform(x.float()).half()
13 | torch.cuda.synchronize()
14 | fp32_time += time.time() - timer
15 | torch.cuda.synchronize()
16 | print(fp32_time)
17 |
18 | for j in range(10):
19 | timer = time.time()
20 | y_had = fast_hadamard_transform.hadamard_transform(x)
21 | torch.cuda.synchronize()
22 | fp16_time += time.time() - timer
23 | torch.cuda.synchronize()
24 | print(fp16_time)
25 |     print(torch.allclose(y_had, y_had_float, rtol=1e-2, atol=1e-1))  # compare fp16 result with the fp32 reference using tolerances appropriate for half precision
--------------------------------------------------------------------------------
/benchmarks/qattention_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pprint
3 | import numpy as np
4 | import torch
5 | import time
6 |
7 | from quarot.transformers.kv_cache import MultiLayerPagedKVCache4Bit
8 |
9 | model_sizes = [
10 | (32, 32, 128), #llama-7b
11 | (40, 40, 128), #llama-13b
12 | (80, 64, 128) #llama-70b
13 | ]
14 |
15 | benchmark_dtypes = ["int4", torch.float16]
16 | num_warmup_steps = 5
17 | num_bench_steps = 100
18 |
19 | def module_benchmark(module):
20 | # warmup
21 | for i in range(num_warmup_steps):
22 | out = module()
23 | torch.cuda.synchronize()
24 |
25 | torch.cuda.reset_max_memory_allocated()
26 | start_time = time.perf_counter()
27 | for i in range(num_bench_steps):
28 | out = module()
29 | torch.cuda.synchronize()
30 | memory_usage = torch.cuda.max_memory_allocated()
31 |
32 | end_time = time.perf_counter()
33 |
34 |
35 | return (end_time - start_time) * 1000 / num_bench_steps, memory_usage
36 |
37 | def quantized_kv_cache_decode(
38 | n_layers, num_heads, head_dim,
39 | batch_size, dtype, seq_len,
40 | hadamard_dtype=torch.float16):
41 | device = torch.device("cuda:0")
42 | cache = MultiLayerPagedKVCache4Bit(
43 | batch_size=batch_size,
44 | page_size=seq_len,
45 | max_seq_len=seq_len,
46 | device=device,
47 |         n_layers=n_layers,  # Ignoring n_layers as it does not affect speed
48 | num_heads=num_heads,
49 | head_dim=head_dim,
50 | disable_quant=dtype == torch.float16,
51 | hadamard_dtype=hadamard_dtype,
52 | )
53 | query_states = torch.rand((batch_size, 1, num_heads, head_dim), device=device, dtype=torch.float16)
54 | key_states = torch.rand((batch_size, 1, num_heads, head_dim), device=device, dtype=torch.float16)
55 | value_states = torch.rand((batch_size, 1, num_heads, head_dim), device=device, dtype=torch.float16)
56 | def _fake_prefill_and_decode():
57 | cache._needs_init = [False] * len(cache._needs_init)
58 | cache.length = seq_len - 1
59 | forward_func = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={})
60 | attn_out = forward_func(query_states)
61 |
62 | times = []
63 | for i in range(10):
64 | times.append(module_benchmark(_fake_prefill_and_decode))
65 | return zip(*times)
66 |
67 |
68 | def qattention_benchmark(args):
69 |
70 | for n_layers, num_heads, head_dim in model_sizes:
71 | time_fp16, memory_fp16 = quantized_kv_cache_decode(
72 | n_layers=n_layers,
73 | num_heads=num_heads,
74 | head_dim=head_dim,
75 | batch_size=args.batch_size,
76 | dtype=torch.float16,
77 | seq_len=args.seq_len,
78 | hadamard_dtype=None
79 | )
80 |
81 | time_int4, memory_int4 = quantized_kv_cache_decode(
82 | n_layers=n_layers,
83 | num_heads=num_heads,
84 | head_dim=head_dim,
85 | batch_size=args.batch_size,
86 | dtype="int4",
87 | seq_len=args.seq_len,
88 | hadamard_dtype=None
89 | )
90 | time_int4_hadfp16, _ = quantized_kv_cache_decode(
91 | n_layers=n_layers,
92 | num_heads=num_heads,
93 | head_dim=head_dim,
94 | batch_size=args.batch_size,
95 | dtype="int4",
96 | seq_len=args.seq_len,
97 | hadamard_dtype=torch.float16
98 | )
99 | time_int4_hadfp32, _ = quantized_kv_cache_decode(
100 | n_layers=n_layers,
101 | num_heads=num_heads,
102 | head_dim=head_dim,
103 | batch_size=args.batch_size,
104 | dtype="int4",
105 | seq_len=args.seq_len,
106 | hadamard_dtype=torch.float32
107 | )
108 |
109 | print(f"Int4 time: {np.mean(time_int4):.3f} +- {1.96 * np.std(time_int4):.3f}ms")
110 |
111 | print(f"Int4 (+FP16had) time: {np.mean(time_int4_hadfp16):.3f} +- {1.96 * np.std(time_int4_hadfp16):.3f}ms")
112 |
113 | print(f"Int4 (+FP32had) time: {np.mean(time_int4_hadfp32):.3f} +- {1.96 * np.std(time_int4_hadfp32):.3f}ms")
114 |
115 | print(f"FP16 time: {np.mean(time_fp16):.3f} +- {1.96 * np.std(time_fp16):.3f}ms")
116 |
117 | print(f"Speedup: {np.mean(time_fp16) / np.mean(time_int4_hadfp16):.3f}x")
118 |
119 |         print(f"Int4 memory: {np.mean(memory_int4):.3f} +- {1.96 * np.std(memory_int4):.3f} bytes")
120 |         print(f"FP16 memory: {np.mean(memory_fp16):.3f} +- {1.96 * np.std(memory_fp16):.3f} bytes")
121 | print(f"Memory Saving: {np.mean(memory_fp16) / np.mean(memory_int4):.3f}x")
122 |
123 | # table-style output
124 | print(f'{n_layers}x{num_heads}x{head_dim} & {args.batch_size} & {np.mean(time_fp16):.3f} & {np.mean(time_int4):.3f} & {np.mean(time_int4_hadfp32):.3f} & {np.mean(time_int4_hadfp16):.3f}\\\\')
125 | print('--------------')
126 |
127 | if __name__ == '__main__':
128 | parser = argparse.ArgumentParser()
129 |
130 | parser.add_argument(
131 | '--batch_size', type=int,
132 | help='Batch size',
133 | default=1,
134 | )
135 | parser.add_argument(
136 | '--seq_len', type=int,
137 | help='Size of the input sequence',
138 | default=2048,
139 | )
140 |
141 | args = parser.parse_args()
142 | pprint.pprint(vars(args))
143 | qattention_benchmark(args)
144 |
--------------------------------------------------------------------------------
/benchmarks/qlinear_benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from quarot.nn import Linear4bit, Quantizer, OnlineHadamard
3 | import time
4 | import argparse
5 | import numpy as np
6 | import pprint
7 |
8 | model_sizes = [
9 | (4096, 4096), #llama-7b
10 | (5120, 5120), #llama-13b
11 | (8192, 8192) #llama-70b
12 | ]
13 |
14 | mlp_sizes = [
15 | (4096, 11008), #llama-7b
16 | (5120, 13824), #llama-13b
17 | (8192, 28672) #llama-70b
18 | ]
19 | benchmark_dtypes = [torch.float16]
20 | num_warmup_steps = 5
21 | num_bench_steps = 100
22 |
23 |
24 | def module_benchmark(module, x):
25 | x = x.cuda()
26 |
27 | # warmup
28 | for i in range(num_warmup_steps):
29 | out = module(x)
30 | torch.cuda.synchronize()
31 |
32 | start_time = time.perf_counter()
33 | for i in range(num_bench_steps):
34 | out = module(x)
35 | torch.cuda.synchronize()
36 |
37 | end_time = time.perf_counter()
38 |
39 |
40 | return (end_time - start_time) * 1000 / num_bench_steps
41 |
42 | def linear4bit_benchmark(args):
43 |
44 | bsz = args.bsz
45 | seq_len = args.seq_len
46 |
47 | if args.layer_type == 'v_proj':
48 | layer_size = model_sizes
49 | else:
50 | layer_size = mlp_sizes
51 |
52 |
53 | for (feature_dim_in, feature_dim_out) in layer_size:
54 | for dtype in benchmark_dtypes:
55 |
56 | x = torch.rand((bsz,
57 | seq_len,
58 | feature_dim_in)).cuda().to(dtype)
59 |
60 | baseline_mod = torch.nn.Linear(feature_dim_in,
61 | feature_dim_out,
62 | bias=False).cuda().to(dtype)
63 |
64 |             baseline_mod.weight.data = torch.randint_like(baseline_mod.weight.data,
65 |                                                           low=-8, high=8).to(dtype)  # int4 range [-8, 7]; randint's high bound is exclusive
66 |
67 | s_w = torch.ones((feature_dim_out, 1), dtype=torch.float16, device='cuda')
68 | int4_mod = torch.nn.Sequential(
69 | Quantizer(input_clip_ratio=1.0),
70 | Linear4bit.from_float(baseline_mod, weight_scales=s_w)
71 | ).cuda()
72 | int4_mod_had = torch.nn.Sequential(
73 | OnlineHadamard(baseline_mod.in_features, force_fp32=True),
74 | Quantizer(input_clip_ratio=1.0),
75 | Linear4bit.from_float(baseline_mod, weight_scales=s_w),
76 | ).cuda()
77 | #int4_mod_had.online_full_had = True
78 | #int4_mod.fp32_had = True
79 |
80 | int4_mod_fp16had = torch.nn.Sequential(
81 | OnlineHadamard(baseline_mod.in_features, force_fp32=False),
82 | Quantizer(input_clip_ratio=1.0),
83 | Linear4bit.from_float(baseline_mod, weight_scales=s_w),
84 | ).cuda()
85 |
86 |
87 |
88 | print(f"{dtype}. Sizes: {baseline_mod.weight.shape}")
89 | times_4bit = []
90 | for i in range(10):
91 | times_4bit.append(module_benchmark(int4_mod, x))
92 | print(f"Int4 time: {np.mean(times_4bit):.3f} +- {1.96 * np.std(times_4bit):.3f}ms")
93 |
94 | times_4bit_had = []
95 | for i in range(10):
96 | times_4bit_had.append(module_benchmark(int4_mod_had, x))
97 | print(f"Int4 (+FP32had) time: {np.mean(times_4bit_had):.3f} +- {1.96 * np.std(times_4bit_had):.3f}ms")
98 |
99 | times_4bit_fp16had = []
100 | for i in range(10):
101 | times_4bit_fp16had.append(module_benchmark(int4_mod_fp16had, x))
102 | print(f"Int4 (+FP16had) time: {np.mean(times_4bit_fp16had):.3f} +- {1.96 * np.std(times_4bit_fp16had):.3f}ms")
103 |
104 |
105 | times_baseline = []
106 | for i in range(10):
107 | times_baseline.append(module_benchmark(baseline_mod, x))
108 | print(f"FP16 time: {np.mean(times_baseline):.3f} +- {1.96 * np.std(times_baseline):.3f}ms")
109 |
110 | print(f"Speedup: {np.mean(times_baseline) / np.mean(times_4bit):.3f}x")
111 |
112 | # table-style output
113 | print(f'{feature_dim_in}x{feature_dim_out} & {args.bsz} & {np.mean(times_baseline):.3f} & {np.mean(times_4bit):.3f} & {np.mean(times_4bit_had):.3f} & {np.mean(times_4bit_fp16had):.3f}\\\\')
114 | print('--------------')
115 |
116 |
117 | if __name__ == '__main__':
118 | parser = argparse.ArgumentParser()
119 |
120 | parser.add_argument(
121 | '--bsz', type=int,
122 | help='Batch size',
123 | default=1,
124 | )
125 | parser.add_argument(
126 | '--seq_len', type=int,
127 | help='Size of the input sequence',
128 | default=2048,
129 | )
130 | parser.add_argument(
131 | '--layer_type', type=str,
132 | help='Type of the layer in the model (v_proj [default], down_proj)',
133 | default='v_proj',
134 | choices=['v_proj', 'down_proj']
135 | )
136 |
137 | args = parser.parse_args()
138 | pprint.pprint(vars(args))
139 | linear4bit_benchmark(args)
140 |
--------------------------------------------------------------------------------
/e2e/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantized_llama.modeling_llama import QuarotLlamaForCausalLM, QuarotLlamaConfig
2 |
--------------------------------------------------------------------------------
/e2e/benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import pprint
4 | import numpy as np
5 | import torch
6 | import time
7 |
8 | from e2e.quantized_llama import modeling_llama
9 | import torch
10 | import transformers
11 |
12 | model_configs = [
13 | "meta-llama/Llama-2-7b-hf",
14 | # "meta-llama/Llama-2-13b-hf",
15 | # "meta-llama/Llama-2-70b-hf",
16 | ]
17 |
18 | benchmark_dtypes = ["int4", torch.float16]
19 | num_warmup_steps = 0
20 | num_bench_steps = 1
21 |
22 | def repeated_run(num_repeats=10):
23 | def func(module):
24 | def _f(*args, **kwargs):
25 | times = []
26 | for i in range(num_repeats):
27 | times.append(module(*args, **kwargs))
28 | return tuple(zip(*times))
29 | return _f
30 | return func
31 |
32 | def _cleanup():
33 | gc.collect()
34 | torch.cuda.empty_cache()
35 |
36 | @repeated_run()
37 | def module_benchmark(module):
38 | # warmup
39 | for i in range(num_warmup_steps):
40 | out = module()
41 | torch.cuda.synchronize()
42 |
43 | _cleanup()
44 | torch.cuda.reset_max_memory_allocated()
45 | start_time = time.perf_counter()
46 |
47 |
48 | for i in range(num_bench_steps):
49 | out = module()
50 | torch.cuda.synchronize()
51 | peak_memory = torch.cuda.max_memory_allocated()
52 |
53 | end_time = time.perf_counter()
54 |
55 | return (end_time - start_time) * 1000 / num_bench_steps, peak_memory
56 |
57 |
58 | def get_model_quantized(config_name):
59 | config = transformers.AutoConfig.from_pretrained(
60 | config_name,
61 | attn_implementation="flash_attention_2"
62 | )
63 | dtype_old = torch.get_default_dtype()
64 | torch.set_default_dtype(torch.float16)
65 | with transformers.modeling_utils.no_init_weights():
66 | model = modeling_llama.QuarotLlamaForCausalLM(config=config)
67 | torch.set_default_dtype(dtype_old)
68 | return model
69 |
70 |
71 | def get_model_hf(config_name):
72 | return transformers.LlamaForCausalLM.from_pretrained(
73 | config_name,
74 | torch_dtype=torch.float16,
75 | attn_implementation="flash_attention_2"
76 | )
77 |
78 | def get_model_fp16(config_name):
79 | return modeling_llama.QuarotFP16LlamaForCausalLM.from_pretrained(
80 | config_name,
81 | torch_dtype=torch.float16,
82 | attn_implementation="flash_attention_2"
83 | )
84 |
85 |
86 | def run_prefill(model, bsz, prefill_length):
87 | device = model.device
88 | test_input = torch.randint(100, 200, (bsz, prefill_length), dtype=torch.int32, device=device)
89 | return module_benchmark(lambda: model(test_input))
90 |
91 |
92 | def run_decode(model, bsz, prefill_length, decode_steps):
93 | device = model.device
94 | test_input = torch.randint(100, 200, (bsz, prefill_length), dtype=torch.int32, device=device)
95 | model._expected_max_length = prefill_length + decode_steps
96 | out = model(test_input)
97 | past_key_values = out.past_key_values
98 | del out
99 | _cleanup()
100 | next_input = torch.tensor([[100] for _ in range (bsz)], dtype=torch.int32, device=device)
101 | def _decode_for_multiple_steps():
102 | past_key_values.length = prefill_length
103 | for _ in range(decode_steps):
104 | model(next_input, past_key_values=past_key_values)
105 | return module_benchmark(_decode_for_multiple_steps)
106 |
107 |
108 | def run_e2e(model, bsz, prefill_length, decode_steps):
109 | device = model.device
110 | test_input = torch.randint(100, 200, (bsz, prefill_length), dtype=torch.int32, device=device)
111 | next_input = torch.tensor([[100] for _ in range (bsz)], dtype=torch.int32, device=device)
112 | def _prefill_and_decode_for_multiple_steps():
113 | model._expected_max_length = prefill_length + decode_steps
114 | out = model(test_input)
115 | for _ in range(decode_steps):
116 | model(next_input, past_key_values=out.past_key_values)
117 | return module_benchmark(_prefill_and_decode_for_multiple_steps)
118 |
119 |
120 | def _wait_for_input():
121 | print("Press enter")
122 | input()
123 |
124 | @torch.no_grad
125 | def run_all_for_model(model, bsz, prefill, decode):
126 | model.eval()
127 | model = model.cuda()
128 | time_prefill, _ = run_prefill(model, bsz, prefill)
129 | _cleanup()
130 | if decode is not None:
131 | time_decode, memory_decode = run_decode(model, bsz, prefill, decode)
132 | _cleanup()
133 | time_e2e, _ = run_e2e(model, bsz, prefill, decode)
134 | _cleanup()
135 | else:
136 |         time_decode = time_e2e = memory_decode = None
137 | return time_prefill, time_decode, time_e2e, memory_decode
138 |
139 | def benchmark(args):
140 |
141 | for config_name in model_configs:
142 | model = get_model_quantized(config_name)
143 | time_prefill_i4, time_decode_i4, time_e2e_i4, mem_i4 = run_all_for_model(
144 | model, args.batch_size, args.prefill_seq_len, args.decode_steps)
145 | del model
146 | _cleanup()
147 | model = get_model_fp16(config_name)
148 | time_prefill_f16, time_decode_f16, time_e2e_f16, mem_f16 = run_all_for_model(
149 | model, args.batch_size, args.prefill_seq_len, args.decode_steps)
150 | del model
151 | _cleanup()
152 |
153 | print(f"Prefill Int4 time: {np.mean(time_prefill_i4):.3f} +- {1.96 * np.std(time_prefill_i4):.3f}ms")
154 | print(f"Prefill FP16 time: {np.mean(time_prefill_f16):.3f} +- {1.96 * np.std(time_prefill_f16):.3f}ms")
155 | print(f"Speedup: {np.mean(time_prefill_f16) / np.mean(time_prefill_i4):.3f}x")
156 | print(f'Prefill & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {np.mean(time_prefill_f16):.3f} & {np.mean(time_prefill_i4):.3f}\\\\')
157 |
158 | if args.decode_steps is not None:
159 | print(f"Decode Int4 time: {np.mean(time_decode_i4):.3f} +- {1.96 * np.std(time_decode_i4):.3f}ms")
160 | print(f"Decode FP16 time: {np.mean(time_decode_f16):.3f} +- {1.96 * np.std(time_decode_f16):.3f}ms")
161 | print(f"Speedup: {np.mean(time_decode_f16) / np.mean(time_decode_i4):.3f}x")
162 | print(f'Decode & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_decode_f16):.3f} & {np.mean(time_decode_i4):.3f}\\\\')
163 |
164 | print(f"E2E Int4 time: {np.mean(time_e2e_i4):.3f} +- {1.96 * np.std(time_e2e_i4):.3f}ms")
165 | print(f"E2E FP16 time: {np.mean(time_e2e_f16):.3f} +- {1.96 * np.std(time_e2e_f16):.3f}ms")
166 | print(f"Speedup: {np.mean(time_e2e_f16) / np.mean(time_e2e_i4):.3f}x")
167 | print(f'E2E & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_e2e_f16):.3f} & {np.mean(time_e2e_i4):.3f}\\\\')
168 |
169 | # table-style output
170 |
171 |             print(f"Int4 memory: {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_i4) / (1024 * 1024 * 1024):.3f}GB")
172 |             print(f"FP16 memory: {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_f16) / (1024 * 1024 * 1024):.3f}GB")
173 | print(f"Memory saving: {np.mean(mem_f16) / np.mean(mem_i4):.3f}x")
174 | print(f'Memory saving & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB & {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB\\\\')
175 |
176 | print('--------------')
177 |
178 | if __name__ == '__main__':
179 | parser = argparse.ArgumentParser()
180 |
181 | parser.add_argument(
182 | '--batch_size', type=int,
183 | help='Batch size',
184 | default=1,
185 | )
186 | parser.add_argument(
187 | '--prefill_seq_len', type=int,
188 | help='Size of the input sequence',
189 | default=2048,
190 | )
191 | parser.add_argument(
192 | '--decode_steps', type=int,
193 | help='Decode steps',
194 | required=False,
195 | default=None,
196 | )
197 |
198 | args = parser.parse_args()
199 | pprint.pprint(vars(args))
200 | benchmark(args)
201 |
--------------------------------------------------------------------------------
/e2e/benchmark_layer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import functools
4 | import pprint
5 | import numpy as np
6 | import torch
7 | import time
8 |
9 | import quarot
10 | from e2e.quantized_llama import modeling_llama
11 | import torch
12 | import transformers
13 |
14 | model_configs = [
15 | "meta-llama/Llama-2-7b-hf",
16 | #"meta-llama/Llama-2-13b-hf",
17 | # "meta-llama/Llama-2-70b-hf",
18 | ]
19 |
20 | benchmark_dtypes = ["int4", torch.float16]
21 | num_warmup_steps = 3
22 | num_bench_steps = 10
23 |
24 | def repeated_run(num_repeats=10):
25 | def func(module):
26 | def _f(*args, **kwargs):
27 | times = []
28 | for i in range(num_repeats):
29 | times.append(module(*args, **kwargs))
30 | return tuple(zip(*times))
31 | return _f
32 | return func
33 |
34 | def _cleanup():
35 | gc.collect()
36 | torch.cuda.empty_cache()
37 |
38 | @repeated_run()
39 | def module_benchmark(module):
40 | # warmup
41 | for i in range(num_warmup_steps):
42 | out = module()
43 | torch.cuda.synchronize()
44 |
45 | start_time = time.perf_counter()
46 | torch.cuda.reset_max_memory_allocated()
47 |
48 | for i in range(num_bench_steps):
49 | out = module()
50 | torch.cuda.synchronize()
51 | peak_memory = torch.cuda.max_memory_allocated()
52 |
53 | end_time = time.perf_counter()
54 |
55 | return (end_time - start_time) * 1000 / num_bench_steps, peak_memory
56 |
57 |
58 | def _build_cache(batch_size, length, layer, disable_quant, num_key_value_heads, hidden_size, device):
59 | num_heads = num_key_value_heads
60 | model_dim = hidden_size
61 | head_dim = model_dim // num_heads
62 | return quarot.transformers.MultiLayerPagedKVCache4Bit(
63 | batch_size=batch_size,
64 | page_size=length,
65 | max_seq_len=length,
66 | device=device,
67 | n_layers=1,
68 | num_heads=num_heads,
69 | head_dim=head_dim,
70 | disable_quant=disable_quant,
71 | hadamard_dtype=None if disable_quant else torch.float16
72 | )
73 |
74 | def get_model_quantized(config_name):
75 | config = transformers.AutoConfig.from_pretrained(
76 | config_name,
77 | attn_implementation="flash_attention_2"
78 | )
79 | torch.set_default_dtype(torch.float16)
80 | with transformers.modeling_utils.no_init_weights():
81 | model = modeling_llama.QuarotLlamaForCausalLM(config=config)
82 |
83 | return model, functools.partial(
84 | _build_cache,
85 | disable_quant=False,
86 | device=torch.device("cuda:0"),
87 | num_key_value_heads=model.config.num_key_value_heads,
88 | hidden_size=model.config.hidden_size,), model.config.hidden_size
89 |
90 |
91 | def get_model_hf(config_name):
92 |     model = transformers.LlamaForCausalLM.from_pretrained(
93 |         config_name,
94 |         torch_dtype=torch.float16,
95 |         attn_implementation="flash_attention_2")
96 |     return model, None, model.config.hidden_size
97 |
98 | def get_model_fp16(config_name):
99 | model = modeling_llama.QuarotFP16LlamaForCausalLM.from_pretrained(
100 | config_name,
101 | torch_dtype=torch.float16,
102 | attn_implementation="flash_attention_2"
103 | )
104 | return model, functools.partial(
105 | _build_cache,
106 | disable_quant=True,
107 | device=torch.device("cuda:0"),
108 | num_key_value_heads=model.config.num_key_value_heads,
109 | hidden_size=model.config.hidden_size,
110 | ), model.config.hidden_size
111 |
112 |
113 | def run_prefill(layer, cache_builder, bsz, prefill_length, hidden_size):
114 | device = layer.self_attn.v_proj.weight.device
115 | test_input = torch.rand((bsz, prefill_length, hidden_size), dtype=torch.float16, device=device)
116 | if cache_builder is None:
117 | def _prefill():
118 | layer(test_input)
119 | else:
120 | past_key_values = cache_builder(bsz, prefill_length, layer)
121 | def _prefill():
122 | past_key_values.length = 0
123 | past_key_values._needs_init[0] = True
124 | layer(test_input, past_key_value=past_key_values)
125 | return module_benchmark(_prefill)
126 |
127 |
128 | def run_decode(layer, cache_builder, bsz, prefill_length, decode_steps, hidden_size):
129 | device = layer.self_attn.v_proj.weight.device
130 | test_input = torch.rand((bsz, prefill_length, hidden_size), dtype=torch.float16, device=device)
131 | next_input = torch.rand((bsz, 1, hidden_size), dtype=torch.float16, device=device)
132 | assert cache_builder is not None
133 | past_key_values = cache_builder(bsz, prefill_length + decode_steps, layer)
134 | layer(test_input, past_key_value=past_key_values)
135 | def _decode_for_multiple_steps():
136 | past_key_values.length = prefill_length
137 | for i in range(decode_steps):
138 | layer(next_input, past_key_value=past_key_values,
139 | position_ids=torch.tensor([[prefill_length + i]] * bsz, device=past_key_values.device, dtype=torch.int32))
140 | return module_benchmark(_decode_for_multiple_steps)
141 |
142 |
143 | def run_e2e(layer, cache_builder, bsz, prefill_length, decode_steps, hidden_size):
144 | device = layer.self_attn.v_proj.weight.device
145 | test_input = torch.rand((bsz, prefill_length, hidden_size), dtype=torch.float16, device=device)
146 | next_input = torch.rand((bsz, 1, hidden_size), dtype=torch.float16, device=device)
147 | assert cache_builder is not None
148 | past_key_values = cache_builder(bsz, prefill_length + decode_steps, layer)
149 | def _prefill_and_decode_for_multiple_steps():
150 | past_key_values.length = 0
151 | past_key_values._needs_init[0] = True
152 | layer(test_input, past_key_value=past_key_values)
153 | for i in range(decode_steps):
154 | layer(next_input, past_key_value=past_key_values,
155 | position_ids=torch.tensor([[prefill_length + i]] * bsz, device=device, dtype=torch.int32))
156 | return module_benchmark(_prefill_and_decode_for_multiple_steps)
157 |
158 |
159 | def _wait_for_input():
160 | print("Press enter")
161 | input()
162 |
163 | @torch.no_grad
164 | def run_all_for_model(layer, cache_builder, bsz, prefill, decode, hidden_size):
165 | layer = layer.cuda()
166 | layer.eval()
167 | time_prefill, _ = run_prefill(layer, cache_builder, bsz, prefill, hidden_size)
168 |
169 | _cleanup()
170 | if decode is not None:
171 | time_decode, memory_decode = run_decode(layer, cache_builder, bsz, prefill, decode, hidden_size)
172 | _cleanup()
173 | time_e2e, _ = run_e2e(layer, cache_builder, bsz, prefill, decode, hidden_size)
174 | _cleanup()
175 | else:
176 | time_decode = time_e2e = None
177 | memory_decode = None
178 | return time_prefill, time_decode, time_e2e, memory_decode
179 |
180 | def benchmark(args):
181 |
182 | for config_name in model_configs:
183 | model, cache_builder, hidden_size = get_model_quantized(config_name)
184 | layer = model.model.layers[0]
185 | del model
186 | _cleanup()
187 | time_prefill_i4, time_decode_i4, time_e2e_i4, mem_i4 = run_all_for_model(
188 | layer, cache_builder, args.batch_size, args.prefill_seq_len, args.decode_steps, hidden_size)
189 | del layer
190 | _cleanup()
191 | model, cache_builder, hidden_size = get_model_fp16(config_name)
192 | layer = model.model.layers[0]
193 | del model
194 | _cleanup()
195 | time_prefill_f16, time_decode_f16, time_e2e_f16, mem_f16 = run_all_for_model(
196 | layer, cache_builder, args.batch_size, args.prefill_seq_len, args.decode_steps, hidden_size)
197 | del layer
198 | _cleanup()
199 |
200 | print(f"Prefill Int4 time: {np.mean(time_prefill_i4):.3f} +- {1.96 * np.std(time_prefill_i4):.3f}ms")
201 | print(f"Prefill FP16 time: {np.mean(time_prefill_f16):.3f} +- {1.96 * np.std(time_prefill_f16):.3f}ms")
202 | print(f"Speedup: {np.mean(time_prefill_f16) / np.mean(time_prefill_i4):.3f}x")
203 | print(f'Prefill & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {np.mean(time_prefill_f16):.3f} & {np.mean(time_prefill_i4):.3f}\\\\')
204 |
205 | if args.decode_steps is not None:
206 | print(f"Decode Int4 time: {np.mean(time_decode_i4):.3f} +- {1.96 * np.std(time_decode_i4):.3f}ms")
207 | print(f"Decode FP16 time: {np.mean(time_decode_f16):.3f} +- {1.96 * np.std(time_decode_f16):.3f}ms")
208 | print(f"Speedup: {np.mean(time_decode_f16) / np.mean(time_decode_i4):.3f}x")
209 | print(f'Decode & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_decode_f16):.3f} & {np.mean(time_decode_i4):.3f}\\\\')
210 |
211 | print(f"E2E Int4 time: {np.mean(time_e2e_i4):.3f} +- {1.96 * np.std(time_e2e_i4):.3f}ms")
212 | print(f"E2E FP16 time: {np.mean(time_e2e_f16):.3f} +- {1.96 * np.std(time_e2e_f16):.3f}ms")
213 | print(f"Speedup: {np.mean(time_e2e_f16) / np.mean(time_e2e_i4):.3f}x")
214 | print(f'E2E & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_e2e_f16):.3f} & {np.mean(time_e2e_i4):.3f}\\\\')
215 |
216 | # table-style output
217 |
218 |             print(f"Int4 memory: {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_i4) / (1024 * 1024 * 1024):.3f}GB")
219 |             print(f"FP16 memory: {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_f16) / (1024 * 1024 * 1024):.3f}GB")
220 | print(f"Memory saving: {np.mean(mem_f16) / np.mean(mem_i4):.3f}x")
221 | print(f'Memory saving & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB & {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB\\\\')
222 |
223 | print('--------------')
224 |
225 | if __name__ == '__main__':
226 | parser = argparse.ArgumentParser()
227 |
228 | parser.add_argument(
229 | '--batch_size', type=int,
230 | help='Batch size',
231 | default=1,
232 | )
233 | parser.add_argument(
234 | '--prefill_seq_len', type=int,
235 | help='Size of the input sequence',
236 | default=2048,
237 | )
238 | parser.add_argument(
239 | '--decode_steps', type=int,
240 | help='Decode steps',
241 | required=False,
242 | default=None,
243 | )
244 |
245 | args = parser.parse_args()
246 | pprint.pprint(vars(args))
247 | benchmark(args)
248 |
--------------------------------------------------------------------------------
/e2e/checkpoint_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/e2e/checkpoint_utils/__init__.py
--------------------------------------------------------------------------------
/e2e/checkpoint_utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import random
3 | import transformers
4 |
5 | def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False):
6 |
7 | if hf_token is None:
8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
9 | else:
10 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
11 |
12 | if eval_mode:
13 | testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
14 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
15 | return testenc
16 | else:
17 | traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
18 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
19 | random.seed(seed)
20 | trainloader = []
21 | for _ in range(nsamples):
22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
23 | j = i + seqlen
24 | inp = trainenc.input_ids[:, i:j]
25 | tar = inp.clone()
26 | tar[:, :-1] = -100
27 | trainloader.append((inp, tar))
28 | return trainloader
29 |
30 | def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False):
31 |
32 | if hf_token is None:
33 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
34 | else:
35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
36 |
37 | if eval_mode:
38 | valdata = datasets.load_dataset(
39 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
40 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
41 | valenc = valenc.input_ids[:, :(256 * seqlen)]
42 | class TokenizerWrapper:
43 | def __init__(self, input_ids):
44 | self.input_ids = input_ids
45 | valenc = TokenizerWrapper(valenc)
46 | return valenc
47 | else:
48 | traindata = datasets.load_dataset(
49 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
50 |
51 | random.seed(seed)
52 | trainloader = []
53 | for _ in range(nsamples):
54 | while True:
55 | i = random.randint(0, len(traindata) - 1)
56 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
57 | if trainenc.input_ids.shape[1] >= seqlen:
58 | break
59 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
60 | j = i + seqlen
61 | inp = trainenc.input_ids[:, i:j]
62 | tar = inp.clone()
63 | tar[:, :-1] = -100
64 | trainloader.append((inp, tar))
65 | return trainloader
66 |
67 |
68 |
69 |
70 | def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False):
71 |
72 |
73 | if hf_token is None:
74 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
75 | else:
76 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
77 |
78 | if eval_mode:
79 | testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test')
80 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
81 | return testenc
82 | else:
83 | traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train')
84 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
85 | random.seed(seed)
86 | trainloader = []
87 | for _ in range(nsamples):
88 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
89 | j = i + seqlen
90 | inp = trainenc.input_ids[:, i:j]
91 | tar = inp.clone()
92 | tar[:, :-1] = -100
93 | trainloader.append((inp, tar))
94 | return trainloader
95 |
96 |
97 | def get_loaders(
98 | name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False
99 | ):
100 | if 'wikitext2' in name:
101 | return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode)
102 | if 'ptb' in name:
103 | return get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
104 | if 'c4' in name:
105 | return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
106 |
--------------------------------------------------------------------------------
/e2e/checkpoint_utils/quantize_llama_checkpoint.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import transformers
3 | import torch
4 | import shutil
5 | import json
6 |
7 | from e2e.quantized_llama import modeling_llama
8 | from e2e.checkpoint_utils import data_utils, gptq_utils, rotation_utils
9 | from quarot.functional import pack_i4
10 |
11 | def main(args):
12 | model = transformers.LlamaForCausalLM.from_pretrained(args.pretraiend_path_or_name)
13 |
14 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
15 | model.seqlen = 2048
16 | rotation_utils.fuse_layer_norms(model)
17 | rotation_utils.rotate_model(model)
18 | if not args.w_rtn:
19 | trainloader = data_utils.get_loaders(
20 | args.cal_dataset, nsamples=args.nsamples,
21 | seed=args.seed, model=args.pretraiend_path_or_name,
22 | seqlen=model.seqlen, eval_mode=False
23 | )
24 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, device, args)
25 | else:
26 | quantizers = gptq_utils.rtn_fwrd(model, device, args)
27 |
28 |
29 | old_dict = model.state_dict()
30 | key_maps = {
31 | "mlp.down_proj": "mlp.down_proj.2",
32 | "self_attn.o_proj": "self_attn.o_proj.1"
33 | }
34 | bad_key_names = {
35 | "post_attention_layernorm.weight",
36 | "input_layernorm.weight"
37 | }
38 | def _get_new_key(key):
39 | new_key = key
40 | for old_name, new_name in key_maps.items():
41 | new_key = new_key.replace(old_name, new_name)
42 | return new_key
43 |
44 | def _keep_key(key):
45 | return all(bad_name not in key for bad_name in bad_key_names)
46 |
47 | new_dict = {_get_new_key(key): value for key, value in old_dict.items() if _keep_key(key)}
48 | for key, value in quantizers.items():
49 | new_key = _get_new_key(key)
50 | weight_scales = value.scale
51 | new_dict[f"{new_key}.weight_scales"] = weight_scales
52 | weight_matrix = new_dict[f"{new_key}.weight"]
53 | int_rounded_weight = (weight_matrix/weight_scales).round()
54 | new_dict[f"{new_key}.weight"] = pack_i4(int_rounded_weight.to(torch.int8))
55 |
56 | config = modeling_llama.QuarotLlamaConfig.from_pretrained(
57 | args.pretraiend_path_or_name,
58 | attn_implementation="flash_attention_2"
59 | )
60 | torch.set_default_dtype(torch.float16)
61 | with transformers.modeling_utils.no_init_weights():
62 | new_model = modeling_llama.QuarotLlamaForCausalLM(config=config)
63 |
64 | result = new_model.load_state_dict(new_dict, strict=False)
65 | assert all("had_rem_dim" in key for key in result.missing_keys), result
66 | assert len(result.unexpected_keys) == 0, result
67 |
68 | new_model = new_model.cpu()
69 |
70 | new_model.save_pretrained(args.save_path)
71 | with open(f"{args.save_path}/config.json") as f:
72 | config = json.load(f)
73 | config["auto_map"] = {
74 | "AutoConfig": "quarot.LlamaConfig",
75 | "AutoModelForCausalLM": "quarot.QuarotLlamaForCausalLM"
76 | }
77 | config["model_type"] = "llama_quarot"
78 | with open(f"{args.save_path}/config.json", "w") as f:
79 | json.dump(config, f)
80 |
81 | shutil.copy("e2e/quantized_llama/modeling_llama.py", f"{args.save_path}/quarot.py")
82 |
83 |
84 | if __name__ == "__main__":
85 | parser = argparse.ArgumentParser()
86 |
87 | supported_models = [
88 | 'meta-llama/Llama-2-7b-hf',
89 | 'meta-llama/Llama-2-13b-hf',
90 | 'meta-llama/Llama-2-70b-hf',
91 | ]
92 |
93 | supported_datasets = ['wikitext2', 'ptb', 'c4']
94 |
95 | # General Arguments
96 | parser.add_argument('--pretraiend_path_or_name', type=str, default='meta-llama/Llama-2-7b-hf',
97 | help='Model to load;', choices=supported_models)
98 | parser.add_argument('--save_path', type=str, required=True)
99 | parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch')
100 | parser.add_argument('--eval_dataset', type=str, default='wikitext2',
101 | help='Dataset for Evaluation (default: wikitext2)', choices=supported_datasets,)
102 |
103 |
104 | parser.add_argument('--w_groupsize', type=int, default=-1,
105 | help='Groupsize for weight quantization. Note that this should be the same as a_groupsize')
106 | parser.add_argument('--w_asym', action=argparse.BooleanOptionalAction, default=False,
107 |                         help='Asymmetric weight quantization (default: False)')
108 | parser.add_argument('--w_rtn', action=argparse.BooleanOptionalAction, default=False,
109 | help='Quantize the weights using RtN. If the w_bits < 16 and this flag is not set, we use GPTQ')
110 | parser.add_argument('--w_clip', action=argparse.BooleanOptionalAction, default=False,
111 |                         help='''Clipping during weight quantization.
112 |                         We do not take a clip ratio as an argument; the best clip ratio is found during weight quantization.''')
113 | parser.add_argument('--nsamples', type=int, default=128,
114 | help='Number of calibration data samples for GPTQ.')
115 | parser.add_argument('--cal_dataset', type=str, default='wikitext2',
116 |                         help='Calibration dataset for GPTQ.', choices=supported_datasets)
117 | parser.add_argument('--percdamp', type=float, default=.01,
118 | help='Percent of the average Hessian diagonal to use for dampening.')
119 | parser.add_argument('--act_order', action=argparse.BooleanOptionalAction, default=False,
120 | help='act-order in GPTQ')
121 |
122 | args = parser.parse_args()
123 |
124 | args.w_bits = 4
125 | main(args)
126 |
--------------------------------------------------------------------------------
/e2e/checkpoint_utils/rotation_utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import torch
3 | import typing
4 | import transformers
5 | import tqdm, math
6 | from quarot.functional import random_hadamard_matrix, apply_exact_had_to_linear
7 |
8 | def fuse_ln_linear(layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]) -> None:
9 | """
10 |     Fuse the linear operations in the LayerNorm into the adjacent linear blocks.
11 | """
12 | for linear in linear_layers:
13 | linear_dtype = linear.weight.dtype
14 |
15 | # Calculating new weight and bias
16 | W_ = linear.weight.data.double()
17 | linear.weight.data = (W_ * layernorm.weight.double()).to(linear_dtype)
18 |
19 | if hasattr(layernorm, 'bias'):
20 | if linear.bias is None:
21 | linear.bias = torch.nn.Parameter(torch.zeros(linear.out_features, dtype=torch.float64))
22 | linear.bias.data = linear.bias.data.double() + torch.matmul(W_, layernorm.bias.double())
23 | linear.bias.data = linear.bias.data.to(linear_dtype)
24 |
25 | def fuse_layer_norms(model):
26 |
27 | # Embedding fusion
28 | W = model.model.embed_tokens
29 | W_ = W.weight.data.double()
30 | W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
31 |
32 | layers = model.model.layers
33 |
34 | # Fuse the linear operations in Layernorm into the adjacent linear blocks.
35 | for layer in layers:
36 | # fuse the input layernorms into the linear layers
37 | fuse_ln_linear(layer.post_attention_layernorm, [layer.mlp.up_proj, layer.mlp.gate_proj])
38 | fuse_ln_linear(layer.input_layernorm, [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj])
39 |
40 |
41 | fuse_ln_linear(model.model.norm, [model.lm_head])
42 |
43 |
44 |
45 | def rotate_embeddings(model, Q: torch.Tensor) -> None:
46 | # Rotate the embeddings.
47 | W = model.model.embed_tokens
48 | dtype = W.weight.data.dtype
49 | W_ = W.weight.data.to(dtype=torch.float64)
50 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype)
51 |
52 |
53 | def rotate_attention_inputs(layer, Q) -> None:
54 | # Rotate the WQ, WK and WV matrices of the self-attention layer.
55 | for W in [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]:
56 | dtype = W.weight.dtype
57 | W_ = W.weight.to(dtype=torch.float64)
58 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype)
59 |
60 | def rotate_attention_output(layer, Q) -> None:
61 | # Rotate output matrix of the self-attention layer.
62 | W = layer.self_attn.o_proj
63 | dtype = W.weight.data.dtype
64 | W_ = W.weight.data.to(dtype=torch.float64)
65 | W.weight.data = torch.matmul(Q.T, W_).to(device="cpu", dtype=dtype)
66 | if W.bias is not None:
67 | b = W.bias.data.to(dtype=torch.float64)
68 | W.bias.data = torch.matmul(Q.T, b).to(device="cpu", dtype=dtype)
69 |
70 | def rotate_mlp_input(layer, Q):
71 | # Rotate the MLP input weights.
72 | mlp_inputs = [layer.mlp.up_proj, layer.mlp.gate_proj]
73 | for W in mlp_inputs:
74 | dtype = W.weight.dtype
75 | W_ = W.weight.data.to(dtype=torch.float64)
76 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype)
77 |
78 | def rotate_mlp_output(layer, Q):
79 | # Rotate the MLP output weights and bias.
80 | W = layer.mlp.down_proj
81 | dtype = W.weight.data.dtype
82 | W_ = W.weight.data.to(dtype=torch.float64)
83 | W.weight.data = torch.matmul(Q.T, W_).to(device="cpu", dtype=dtype)
84 | apply_exact_had_to_linear(W, had_dim=-1, output=False) #apply exact (inverse) hadamard on the weights of mlp output
85 | if W.bias is not None:
86 | b = W.bias.data.to(dtype=torch.float64)
87 | W.bias.data = torch.matmul(Q.T, b).to(device="cpu", dtype=dtype)
88 |
89 | def rotate_head(model, Q: torch.Tensor) -> None:
90 | # Rotate the head.
91 | W = model.lm_head
92 | dtype = W.weight.data.dtype
93 | W_ = W.weight.data.to(dtype=torch.float64)
94 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype)
95 |
96 | def rotate_ov_proj(layer, head_num, head_dim):
97 | v_proj = layer.self_attn.v_proj
98 | o_proj = layer.self_attn.o_proj
99 |
100 | apply_exact_had_to_linear(v_proj, had_dim=head_dim, output=True)
101 | apply_exact_had_to_linear(o_proj, had_dim=-1, output=False)
102 |
103 |
104 | @torch.inference_mode()
105 | def rotate_model(model):
106 | Q = random_hadamard_matrix(model.config.hidden_size, model.device)
107 | config = model.config
108 | num_heads = config.num_attention_heads
109 | model_dim = config.hidden_size
110 | head_dim = model_dim // num_heads
111 |
112 |
113 | rotate_embeddings(model, Q)
114 | rotate_head(model, Q)
115 | gc.collect()
116 | torch.cuda.empty_cache()
117 | layers = model.model.layers
118 | for idx, layer in enumerate(tqdm.tqdm(layers, unit="layer", desc="Rotating")):
119 | rotate_attention_inputs(layers[idx], Q)
120 | rotate_attention_output(layers[idx], Q)
121 | rotate_mlp_input(layers[idx], Q)
122 | rotate_mlp_output(layers[idx], Q)
123 | rotate_ov_proj(layers[idx], num_heads, head_dim)
124 |
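125 | # A minimal usage sketch (assuming `model` is a HuggingFace Llama-2 model loaded in fp16):
126 | # layer norms are fused first, so that only scale-free RMS normalization remains and the
127 | # Hadamard rotation applied afterwards commutes with it.
128 | #
129 | #   fuse_layer_norms(model)
130 | #   rotate_model(model)
131 | #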
--------------------------------------------------------------------------------
/e2e/quantized_llama/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/e2e/quantized_llama/__init__.py
--------------------------------------------------------------------------------
/e2e/quantized_llama/modeling_llama.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import quarot
3 | import quarot.transformers
4 | import torch
5 | from transformers import LlamaConfig
6 | from transformers.models.llama.modeling_llama import LlamaAttention, \
7 | LlamaFlashAttention2, LlamaForCausalLM, apply_rotary_pos_emb, LlamaMLP
8 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
9 | from typing import Optional, Tuple
10 | from transformers import Cache
11 |
12 |
13 | ALL_LAYERNORM_LAYERS.append(quarot.nn.RMSNorm)
14 |
15 | class QuarotLlamaConfig(LlamaConfig):
16 | model_type = "llama_quarot"
17 |
18 | class QuarotFP16LlamaAttention(LlamaFlashAttention2):
19 |
20 | def __init__(self, *args, **kwargs):
21 | super().__init__(*args, **kwargs)
22 | self.quantizer = torch.nn.Identity()
23 | self.o_proj_hadamard = torch.nn.Identity()
24 |
25 | def forward(
26 | self,
27 | hidden_states: torch.Tensor,
28 | attention_mask: Optional[torch.LongTensor] = None,
29 | position_ids: Optional[torch.LongTensor] = None,
30 | past_key_value: Optional[Cache] = None,
31 | output_attentions: bool = False,
32 | use_cache: bool = False,
33 | cache_position: Optional[torch.LongTensor] = None,
34 | **kwargs,
35 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
36 | output_attentions = False
37 |
38 | bsz, q_len, _ = hidden_states.size()
39 |
40 | hidden_states = self.quantizer(hidden_states)
41 |
42 | query_states = self.q_proj(hidden_states)
43 | key_states = self.k_proj(hidden_states)
44 | value_states = self.v_proj(hidden_states)
45 |
46 | # Flash attention requires the input to have the shape
47 | # batch_size x seq_length x num_heads x head_dim
48 | # therefore we just need to keep the original shape
49 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
50 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
51 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
52 |
53 | kv_seq_len = key_states.shape[1]
54 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
56 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2)
57 |
58 | past_key_value = getattr(self, "past_key_value", past_key_value)
59 | assert past_key_value is not None
60 | # sin and cos are specific to RoPE models; position_ids needed for the static cache
61 |
62 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "attention_mask": attention_mask}
63 | cache_out = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
64 |
65 |
66 | dropout_rate = self.attention_dropout if self.training else 0.0
67 |
68 | assert self.is_causal
69 |
70 | if isinstance(cache_out, tuple):
71 | key_states, value_states = cache_out
72 | attn_output = self._flash_attention_forward(
73 | query_states,
74 | key_states,
75 | value_states,
76 | query_length=q_len,
77 | attention_mask=attention_mask
78 | )
79 | else:
80 | attn_output = cache_out(query_states)
81 |
82 | attn_output = self.o_proj_hadamard(attn_output.transpose(-1, -2)).transpose(-1, -2)
83 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
84 | attn_output = self.o_proj(attn_output)
85 |
86 | if not output_attentions:
87 | attn_weights = None
88 |
89 | return attn_output, attn_weights, past_key_value
90 |
91 | class QuarotLlamaAttention(QuarotFP16LlamaAttention):
92 |
93 | def __init__(self, *args, **kwargs):
94 | super().__init__(*args, **kwargs)
95 | self.quantizer = quarot.nn.Quantizer()
96 | self.q_proj = quarot.nn.Linear4bit.from_float(self.q_proj)
97 | self.k_proj = quarot.nn.Linear4bit.from_float(self.k_proj)
98 | self.v_proj = quarot.nn.Linear4bit.from_float(self.v_proj)
99 | self.o_proj_hadamard = quarot.nn.OnlineHadamard(self.num_heads)
100 | self.o_proj = torch.nn.Sequential(
101 | quarot.nn.Quantizer(),
102 | quarot.nn.Linear4bit.from_float(self.o_proj)
103 | )
104 |
105 | class QuarotLlamaMLP(LlamaMLP):
106 | def __init__(self, *args, **kwargs):
107 | super().__init__(*args, **kwargs)
108 | self.quantizer = quarot.nn.Quantizer()
109 | self.up_proj = quarot.nn.Linear4bit.from_float(self.up_proj)
110 | self.gate_proj = quarot.nn.Linear4bit.from_float(self.gate_proj)
111 | self.down_proj = torch.nn.Sequential(
112 | quarot.nn.OnlineHadamard(self.intermediate_size),
113 | quarot.nn.Quantizer(),
114 | quarot.nn.Linear4bit.from_float(self.down_proj)
115 | )
116 |
117 | def forward(self, x):
118 | x = self.quantizer(x)
119 | return super().forward(x)
120 |
121 |
122 | class QuarotFP16LlamaForCausalLM(LlamaForCausalLM):
123 | def __init__(self, config):
124 | super().__init__(config)
125 | assert config._attn_implementation == "flash_attention_2"
126 | for layer_idx, layer in enumerate(self.model.layers):
127 | layer.self_attn = QuarotFP16LlamaAttention(config=config, layer_idx=layer_idx)
128 | self.cache_dtype = "float16"
129 | self._expected_max_length = None
130 |
131 |
132 | def build_cache(self, batch_size, page_size, max_length):
133 | device = self.model.layers[0].self_attn.v_proj.weight.device
134 | dtype = self.cache_dtype or self.model.layers[0].self_attn.v_proj.weight.dtype
135 |
136 | num_heads = self.config.num_key_value_heads
137 | model_dim = self.config.hidden_size
138 | head_dim = model_dim // num_heads
139 | disable_quant = self.cache_dtype == "float16"
140 | return quarot.transformers.MultiLayerPagedKVCache4Bit(
141 | batch_size=batch_size,
142 | page_size=page_size,
143 | max_seq_len=max_length,
144 | device=device,
145 | n_layers=len(self.model.layers),
146 | num_heads=num_heads,
147 | head_dim=head_dim,
148 | disable_quant=disable_quant,
149 | hadamard_dtype=None if disable_quant else torch.float16
150 | )
151 |
152 | def _get_logits_processor(self, generation_config, *args, **kwargs):
153 | # This is a hack to get the max length from generation_config.
154 | # Doing it here because max_length might not be set before this
155 | # method is called.
156 | self._expected_max_length = generation_config.max_length # This value will be reset at the next forward call
157 | return super()._get_logits_processor(generation_config, *args, **kwargs)
158 |
159 |
160 | def forward(self, input_ids, *args, past_key_values=None, **kwargs):
161 | if past_key_values is None:
162 | max_length = self._expected_max_length or input_ids.shape[1]
163 | self._expected_max_length = None # Reset this value.
164 | past_key_values = self.build_cache(
165 | input_ids.shape[0],
166 | page_size=max_length, # For now working with single page per batch.
167 | max_length=max_length)
168 | out = super().forward(input_ids, *args, past_key_values=past_key_values, **kwargs)
169 | return out
170 |
171 |
172 |
173 | class QuarotLlamaForCausalLM(QuarotFP16LlamaForCausalLM):
174 | def __init__(self, config):
175 | super().__init__(config)
176 | assert config._attn_implementation == "flash_attention_2"
177 | self.norm = quarot.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
178 | for layer_idx, layer in enumerate(self.model.layers):
179 | layer.self_attn = QuarotLlamaAttention(config=config, layer_idx=layer_idx)
180 | layer.input_layernorm = quarot.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
181 | layer.post_attention_layernorm = quarot.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
182 | layer.mlp = QuarotLlamaMLP(config=config)
183 | self.cache_dtype = "int4"
184 |
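185 | # A minimal construction sketch (path, dtype and loading route are illustrative, not
186 | # prescribed by this file); the attention classes above assert flash_attention_2:
187 | #
188 | #   model = QuarotLlamaForCausalLM.from_pretrained(
189 | #       "path/to/rotated-llama-checkpoint",
190 | #       attn_implementation="flash_attention_2",
191 | #       torch_dtype=torch.float16)
192 | #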
--------------------------------------------------------------------------------
/fake_quant/README.md:
--------------------------------------------------------------------------------
1 | # Fake Quantization in QuaRot
2 |
3 |
4 | In this directory, we provide the PyTorch scripts for the experiments in QuaRot.
5 |
6 |
7 | ## Language Generation and Zero-Shot Evaluations
8 |
9 | Currently, we only support **LLaMa-2** models. You can simply run the `main.py` to reproduce the results in the paper. The most important arguments are:
10 |
11 | - `--model`: the model name (or path to the weights)
12 | - `--bsz`: the batch size for PPL evaluation
13 | - `--rotate`: whether we want to rotate the model
14 | - `--lm_eval`: whether we want to run LM-Eval for Zero-Shot tasks
15 | - `--tasks`: the tasks for LM-Eval
16 | - `--cal_dataset`: the calibration dataset for GPTQ quantization
17 | - `--a_bits`: the number of bits for activation quantization
18 | - `--w_bits`: the number of bits for weight quantization
19 | - `--v_bits`: the number of bits for value quantization
20 | - `--k_bits`: the number of bits for key quantization
21 | - `--w_clip`: Whether we want to clip the weights
22 | - `--a_clip_ratio`: The ratio of clipping for activation
23 | - `--k_clip_ratio`: The ratio of clipping for key
24 | - `--v_clip_ratio`: The ratio of clipping for value
25 | - `--w_asym`: Whether we want to use asymmetric quantization for weights
26 | - `--a_asym`: Whether we want to use asymmetric quantization for activation
27 | - `--v_asym`: Whether we want to use asymmetric quantization for value
28 | - `--k_asym`: Whether we want to use asymmetric quantization for key
29 | - `--a_groupsize`: The group size for activation quantization
30 | - `--w_groupsize`: The group size for weight quantization
31 | - `--v_groupsize`: The group size for value quantization
32 | - `--k_groupsize`: The group size for key quantization
33 |
34 | For example, to evaluate the perplexity of the `LLaMA2-7B` model with all weights and activations quantized, you can run the following command:
35 |
36 | ```bash
37 | python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip
38 | ```
39 |
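40 | To additionally run the zero-shot (LM-Eval) evaluation, pass `--lm_eval` together with `--tasks`. The command below is only an illustration; use whatever tasks your installed `lm_eval` version supports:
41 |
42 | ```bash
43 | python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip --lm_eval --tasks piqa
44 | ```
45 |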
--------------------------------------------------------------------------------
/fake_quant/data_utils.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import random
3 | import transformers
4 |
5 | def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False):
6 |
7 | if hf_token is None:
8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
9 | else:
10 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
11 |
12 | if eval_mode:
13 | testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
14 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
15 | return testenc
16 | else:
17 | traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
18 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
19 | random.seed(seed)
20 | trainloader = []
21 | for _ in range(nsamples):
22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
23 | j = i + seqlen
24 | inp = trainenc.input_ids[:, i:j]
25 | tar = inp.clone()
26 | tar[:, :-1] = -100
27 | trainloader.append((inp, tar))
28 | return trainloader
29 |
30 | def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False):
31 |
32 | if hf_token is None:
33 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
34 | else:
35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
36 |
37 | if eval_mode:
38 | valdata = datasets.load_dataset(
39 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
40 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
41 | valenc = valenc.input_ids[:, :(256 * seqlen)]
42 | class TokenizerWrapper:
43 | def __init__(self, input_ids):
44 | self.input_ids = input_ids
45 | valenc = TokenizerWrapper(valenc)
46 | return valenc
47 | else:
48 | traindata = datasets.load_dataset(
49 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
50 |
51 | random.seed(seed)
52 | trainloader = []
53 | for _ in range(nsamples):
54 | while True:
55 | i = random.randint(0, len(traindata) - 1)
56 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
57 | if trainenc.input_ids.shape[1] >= seqlen:
58 | break
59 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
60 | j = i + seqlen
61 | inp = trainenc.input_ids[:, i:j]
62 | tar = inp.clone()
63 | tar[:, :-1] = -100
64 | trainloader.append((inp, tar))
65 | return trainloader
66 |
67 |
68 |
69 |
70 | def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False):
71 |
72 |
73 | if hf_token is None:
74 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
75 | else:
76 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)
77 |
78 | if eval_mode:
79 | testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test')
80 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
81 | return testenc
82 | else:
83 | traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train')
84 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
85 | random.seed(seed)
86 | trainloader = []
87 | for _ in range(nsamples):
88 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
89 | j = i + seqlen
90 | inp = trainenc.input_ids[:, i:j]
91 | tar = inp.clone()
92 | tar[:, :-1] = -100
93 | trainloader.append((inp, tar))
94 | return trainloader
95 |
96 |
97 | def get_loaders(
98 | name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False
99 | ):
100 | if 'wikitext2' in name:
101 | return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode)
102 | if 'ptb' in name:
103 | return get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
104 | if 'c4' in name:
105 | return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
--------------------------------------------------------------------------------
/fake_quant/eval_utils.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import model_utils
3 | import quant_utils
4 | import torch
5 | import os
6 | import logging
7 | from tqdm import tqdm
8 |
9 |
10 | @torch.no_grad()
11 | def evaluator(model, testenc, dev, args):
12 |
13 | model.eval()
14 |
15 | if 'opt' in args.model:
16 | opt_type = True
17 | llama_type = False
18 | elif 'meta' in args.model:
19 | llama_type = True
20 | opt_type = False
21 | else:
22 | raise ValueError(f'Unknown model {args.model}')
23 |
24 |
25 | use_cache = model.config.use_cache
26 | model.config.use_cache = False
27 |
28 | if opt_type:
29 | layers = model.model.decoder.layers
30 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
31 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
32 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
33 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
34 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
35 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
36 |
37 | elif llama_type:
38 | layers = model.model.layers
39 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
40 |
41 | layers[0] = layers[0].to(dev)
42 |
43 | # Convert the whole text of evaluation dataset into batches of sequences.
44 | input_ids = testenc.input_ids # (1, text_len)
45 | nsamples = input_ids.numel() // model.seqlen # The tail is truncated.
46 | input_ids = input_ids[:, :nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev) # (nsamples, seqlen)
47 |
48 | batch_size = args.bsz
49 | input_ids = [input_ids[i:i + batch_size] for i in range(0, nsamples, batch_size)]
50 | nbatches = len(input_ids)
51 |
52 | dtype = next(iter(model.parameters())).dtype
53 | # The inputs of the first decoder layer, one entry per batch.
54 | # They are collected by the Catcher module below, so no dense tensor of shape
55 | # (nbatches, batch_size, seqlen, hidden_size) needs to be preallocated here.
56 | inps = [0] * nbatches
57 |
58 | cache = {'i': 0, 'attention_mask': None}
59 | class Catcher(torch.nn.Module):
60 | def __init__(self, module):
61 | super().__init__()
62 | self.module = module
63 | def forward(self, inp, **kwargs):
64 | inps[cache['i']] = inp
65 | cache['i'] += 1
66 | cache['attention_mask'] = kwargs['attention_mask']
67 | if llama_type:
68 | cache['position_ids'] = kwargs['position_ids']
69 | raise ValueError
70 | layers[0] = Catcher(layers[0])
71 |
72 | for i in range(nbatches):
73 | batch = input_ids[i]
74 | try:
75 | model(batch)
76 | except ValueError:
77 | pass
78 | layers[0] = layers[0].module
79 | layers[0] = layers[0].cpu()
80 |
81 | if opt_type:
82 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
83 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
84 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
85 | model.model.decoder.project_out = model.model.decoder.project_out.cpu()
86 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
87 | model.model.decoder.project_in = model.model.decoder.project_in.cpu()
88 | elif llama_type:
89 | model.model.embed_tokens = model.model.embed_tokens.cpu()
90 | position_ids = cache['position_ids']
91 |
92 | torch.cuda.empty_cache()
93 | outs = [0] * nbatches
94 | attention_mask = cache['attention_mask']
95 |
96 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"):
97 | layer = layers[i].to(dev)
98 |
99 | # Dump the layer input and output
100 | if args.capture_layer_io and args.layer_idx == i:
101 | captured_io = model_utils.capture_layer_io(model_utils.get_model_type(model), layer, inps)
102 | save_path = model_utils.get_layer_io_save_path(args)
103 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
104 | torch.save(captured_io, save_path)
105 | logging.info(f'Dumped layer input and output to: {save_path}')
106 |
107 | for j in range(nbatches):
108 | if opt_type:
109 | outs[j] = layer(inps[j], attention_mask=attention_mask)[0]
110 | elif llama_type:
111 | outs[j] = layer(inps[j], attention_mask=attention_mask, position_ids=position_ids)[0]
112 | layers[i] = layer.cpu()
113 | del layer
114 | torch.cuda.empty_cache()
115 | inps, outs = outs, inps
116 |
117 | if opt_type:
118 | if model.model.decoder.final_layer_norm is not None:
119 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
120 | if model.model.decoder.project_out is not None:
121 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
122 |
123 | elif llama_type:
124 | if model.model.norm is not None:
125 | model.model.norm = model.model.norm.to(dev)
126 |
127 | model.lm_head = model.lm_head.to(dev)
128 | nlls = []
129 | loss_fct = torch.nn.CrossEntropyLoss(reduction = "none")
130 | for i in range(nbatches):
131 | hidden_states = inps[i]
132 | if opt_type:
133 | if model.model.decoder.final_layer_norm is not None:
134 | hidden_states = model.model.decoder.final_layer_norm(hidden_states)
135 | if model.model.decoder.project_out is not None:
136 | hidden_states = model.model.decoder.project_out(hidden_states)
137 | elif llama_type:
138 | if model.model.norm is not None:
139 | hidden_states = model.model.norm(hidden_states)
140 | lm_logits = model.lm_head(hidden_states)
141 | shift_logits = lm_logits[:, :-1, :]
142 | shift_labels = input_ids[i][:, 1:]
143 | loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels)
144 | neg_log_likelihood = loss.float().mean(dim=1)
145 | nlls.append(neg_log_likelihood)
146 | nlls_tensor = torch.cat(nlls)
147 | ppl = torch.exp(nlls_tensor.mean())
148 | model.config.use_cache = use_cache
149 | logging.info(f'\n{args.eval_dataset.upper()} PPL: {ppl.item():.3f}')
150 | return ppl.item()
151 |
--------------------------------------------------------------------------------
/fake_quant/gptq_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import utils
7 | import quant_utils
8 | import logging
9 |
10 | torch.backends.cuda.matmul.allow_tf32 = False
11 | torch.backends.cudnn.allow_tf32 = False
12 |
13 |
14 | class GPTQ:
15 |
16 | def __init__(self, layer):
17 | self.layer = layer
18 | self.dev = self.layer.weight.device
19 | W = layer.weight.data.clone()
20 | self.rows = W.shape[0]
21 | self.columns = W.shape[1]
22 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
23 | self.nsamples = 0
24 |
25 | def add_batch(self, inp, out):
26 |
27 | if len(inp.shape) == 2:
28 | inp = inp.unsqueeze(0)
29 | tmp = inp.shape[0]
30 | if len(inp.shape) == 3:
31 | inp = inp.reshape((-1, inp.shape[-1]))
32 | inp = inp.t()
33 | self.H *= self.nsamples / (self.nsamples + tmp)
34 | self.nsamples += tmp
35 | # inp = inp.float()
36 | inp = math.sqrt(2 / self.nsamples) * inp.float()
37 | # self.H += 2 / self.nsamples * inp.matmul(inp.t())
38 | self.H += inp.matmul(inp.t())
39 |
40 | def fasterquant(
41 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
42 | ):
43 | W = self.layer.weight.data.clone()
44 | W = W.float()
45 |
46 | tick = time.time()
47 |
48 | if not self.quantizer.ready():
49 | self.quantizer.find_params(W)
50 |
51 | H = self.H
52 | del self.H
53 | dead = torch.diag(H) == 0
54 | H[dead, dead] = 1
55 | W[:, dead] = 0
56 |
57 | if static_groups:
58 | import copy
59 | groups = []
60 | for i in range(0, self.columns, groupsize):
61 | quantizer = copy.deepcopy(self.quantizer)
62 | quantizer.find_params(W[:, i:(i + groupsize)])
63 | groups.append(quantizer)
64 |
65 | if actorder:
66 | perm = torch.argsort(torch.diag(H), descending=True)
67 | W = W[:, perm]
68 | H = H[perm][:, perm]
69 | invperm = torch.argsort(perm)
70 |
71 | Losses = torch.zeros_like(W)
72 | Q = torch.zeros_like(W)
73 |
74 | damp = percdamp * torch.mean(torch.diag(H))
75 | diag = torch.arange(self.columns, device=self.dev)
76 | H[diag, diag] += damp
77 | H = torch.linalg.cholesky(H)
78 | H = torch.cholesky_inverse(H)
79 | H = torch.linalg.cholesky(H, upper=True)
80 | Hinv = H
81 |
82 | for i1 in range(0, self.columns, blocksize):
83 | i2 = min(i1 + blocksize, self.columns)
84 | count = i2 - i1
85 |
86 | W1 = W[:, i1:i2].clone()
87 | Q1 = torch.zeros_like(W1)
88 | Err1 = torch.zeros_like(W1)
89 | Losses1 = torch.zeros_like(W1)
90 | Hinv1 = Hinv[i1:i2, i1:i2]
91 |
92 | for i in range(count):
93 | w = W1[:, i]
94 | d = Hinv1[i, i]
95 |
96 | if groupsize != -1:
97 | if not static_groups:
98 | if (i1 + i) % groupsize == 0:
99 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
100 | else:
101 | idx = i1 + i
102 | if actorder:
103 | idx = perm[idx]
104 | self.quantizer = groups[idx // groupsize]
105 |
106 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
107 | Q1[:, i] = q
108 | Losses1[:, i] = (w - q) ** 2 / d ** 2
109 |
110 | err1 = (w - q) / d
111 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
112 | Err1[:, i] = err1
113 |
114 | Q[:, i1:i2] = Q1
115 | Losses[:, i1:i2] = Losses1 / 2
116 |
117 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
118 |
119 | torch.cuda.synchronize()
120 |
121 | if actorder:
122 | Q = Q[:, invperm]
123 |
124 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
125 | if torch.any(torch.isnan(self.layer.weight.data)):
126 | logging.warning('NaN in weights')
127 | import pprint
128 | pprint.pprint((self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point))
129 | raise ValueError('NaN in weights')
130 |
131 | def free(self):
132 | self.H = None
133 | self.Losses = None
134 | self.Trace = None
135 | torch.cuda.empty_cache()
136 | utils.cleanup_memory(verbos=False)
137 |
138 |
139 | @torch.no_grad()
140 | def gptq_fwrd(model, dataloader, dev, args):
141 | '''
142 | From GPTQ repo
143 | TODO: Make this function general to support both OPT and LLaMA models
144 | '''
145 | logging.info('-----GPTQ Quantization-----')
146 |
147 | use_cache = model.config.use_cache
148 | model.config.use_cache = False
149 | layers = model.model.layers
150 |
151 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
152 | model.model.norm = model.model.norm.to(dev)
153 | layers[0] = layers[0].to(dev)
154 |
155 | dtype = next(iter(model.parameters())).dtype
156 | inps = torch.zeros(
157 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
158 | )
159 | cache = {'i': 0, 'attention_mask': None}
160 |
161 | class Catcher(nn.Module):
162 | def __init__(self, module):
163 | super().__init__()
164 | self.module = module
165 | def forward(self, inp, **kwargs):
166 | inps[cache['i']] = inp
167 | cache['i'] += 1
168 | cache['attention_mask'] = kwargs['attention_mask']
169 | cache['position_ids'] = kwargs['position_ids']
170 | raise ValueError
171 | layers[0] = Catcher(layers[0])
172 | for batch in dataloader:
173 | try:
174 | model(batch[0].to(dev))
175 | except ValueError:
176 | pass
177 | layers[0] = layers[0].module
178 |
179 | layers[0] = layers[0].cpu()
180 | model.model.embed_tokens = model.model.embed_tokens.cpu()
181 | model.model.norm = model.model.norm.cpu()
182 | torch.cuda.empty_cache()
183 |
184 | outs = torch.zeros_like(inps)
185 | attention_mask = cache['attention_mask']
186 | position_ids = cache['position_ids']
187 |
188 | quantizers = {}
189 | sequential = [
190 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
191 | ['self_attn.o_proj.module'],
192 | ['mlp.up_proj.module', 'mlp.gate_proj.module'],
193 | ['mlp.down_proj.module']
194 | ]
195 | for i in range(len(layers)):
196 | print(f'\nLayer {i}:', flush=True, end=' ')
197 | layer = layers[i].to(dev)
198 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
199 | for names in sequential:
200 | subset = {n: full[n] for n in names}
201 |
202 | gptq = {}
203 | for name in subset:
204 | print(f'{name}', end=' ', flush=True)
205 | layer_weight_bits = args.w_bits
206 | layer_weight_sym = not(args.w_asym)
207 | if 'lm_head' in name:
208 | layer_weight_bits = 16
209 | continue
210 | if args.int8_down_proj and 'down_proj' in name:
211 | layer_weight_bits = 8
212 | gptq[name] = GPTQ(subset[name])
213 | gptq[name].quantizer = quant_utils.WeightQuantizer()
214 | gptq[name].quantizer.configure(
215 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
216 | )
217 |
218 | def add_batch(name):
219 | def tmp(_, inp, out):
220 | gptq[name].add_batch(inp[0].data, out.data)
221 | return tmp
222 | handles = []
223 | for name in subset:
224 | handles.append(subset[name].register_forward_hook(add_batch(name)))
225 | for j in range(args.nsamples):
226 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
227 | for h in handles:
228 | h.remove()
229 |
230 | for name in subset:
231 | layer_w_groupsize = args.w_groupsize
232 | gptq[name].fasterquant(
233 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False
234 | )
235 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
236 | gptq[name].free()
237 |
238 | for j in range(args.nsamples):
239 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
240 |
241 | layers[i] = layer.cpu()
242 | del layer
243 | del gptq
244 | torch.cuda.empty_cache()
245 |
246 | inps, outs = outs, inps
247 |
248 | model.config.use_cache = use_cache
249 | utils.cleanup_memory(verbos=True)
250 | logging.info('-----GPTQ Quantization Done-----\n')
251 | return quantizers
252 |
253 |
254 |
255 |
256 | @torch.no_grad()
257 | def rtn_fwrd(model, dev, args):
258 | '''
259 | From GPTQ repo
260 | TODO: Make this function general to support both OPT and LLaMA models
261 | '''
262 | assert args.w_groupsize ==-1, "Groupsize not supported in RTN!"
263 | layers = model.model.layers
264 | torch.cuda.empty_cache()
265 |
266 | quantizers = {}
267 |
268 | for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"):
269 | layer = layers[i].to(dev)
270 |
271 | subset = quant_utils.find_qlayers(layer,
272 | layers=[torch.nn.Linear])
273 |
274 | for name in subset:
275 | layer_weight_bits = args.w_bits
276 | if 'lm_head' in name:
277 | layer_weight_bits = 16
278 | continue
279 | if args.int8_down_proj and 'down_proj' in name:
280 | layer_weight_bits = 8
281 |
282 | quantizer = quant_utils.WeightQuantizer()
283 | quantizer.configure(
284 | layer_weight_bits, perchannel=True, sym=not(args.w_asym), mse=args.w_clip
285 | )
286 | W = subset[name].weight.data
287 | quantizer.find_params(W)
288 | subset[name].weight.data = quantizer.quantize(W).to(
289 | next(iter(layer.parameters())).dtype)
290 | quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu()
291 | layers[i] = layer.cpu()
292 | torch.cuda.empty_cache()
293 | del layer
294 |
295 | utils.cleanup_memory(verbos=True)
296 | return quantizers
297 |
--------------------------------------------------------------------------------
/fake_quant/main.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import torch
3 | import model_utils
4 | import data_utils
5 | import transformers
6 | import quant_utils
7 | import rotation_utils
8 | import gptq_utils
9 | import eval_utils
10 | import hadamard_utils
11 |
12 | def main():
13 | args = utils.parser_gen()
14 | if args.wandb:
15 | import wandb
16 | wandb.init(project=args.wandb_project, entity=args.wandb_id)
17 | wandb.config.update(args)
18 |
19 | transformers.set_seed(args.seed)
20 | model = model_utils.get_model(args.model, args.hf_token)
21 | model.eval()
22 |
23 |
24 | # Rotate the weights
25 | if args.rotate:
26 | rotation_utils.fuse_layer_norms(model)
27 | rotation_utils.rotate_model(model, args)
28 | utils.cleanup_memory(verbos=True)
29 |
30 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model
31 | qlayers = quant_utils.find_qlayers(model)
32 | for name in qlayers:
33 | if 'down_proj' in name:
34 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
35 | qlayers[name].online_full_had = True
36 | qlayers[name].had_K = had_K
37 | qlayers[name].K = K
38 | qlayers[name].fp32_had = args.fp32_had
39 | if 'o_proj' in name:
40 | had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
41 | qlayers[name].online_partial_had = True
42 | qlayers[name].had_K = had_K
43 | qlayers[name].K = K
44 | qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads
45 | qlayers[name].fp32_had = args.fp32_had
46 | else:
47 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model as the rest of the code assumes it is present
48 |
49 |
50 | if args.w_bits < 16:
51 | save_dict = {}
52 | if args.load_qmodel_path: # Load Quantized Rotated Model
53 | assert args.rotate, "Model should be rotated to load a quantized model!"
54 | assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
55 | print("Load quantized model from ", args.load_qmodel_path)
56 | save_dict = torch.load(args.load_qmodel_path)
57 | model.load_state_dict(save_dict["model"])
58 |
59 | elif not args.w_rtn: # GPTQ Weight Quantization
60 | assert "llama" in args.model, "Only llama is supported for GPTQ!"
61 |
62 | trainloader = data_utils.get_loaders(
63 | args.cal_dataset, nsamples=args.nsamples,
64 | seed=args.seed, model=args.model,
65 | seqlen=model.seqlen, eval_mode=False
66 | )
67 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, utils.DEV, args)
68 | save_dict["w_quantizers"] = quantizers
69 | else: # RTN Weight Quantization
70 | quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args)
71 | save_dict["w_quantizers"] = quantizers
72 |
73 | if args.save_qmodel_path:
74 | save_dict["model"] = model.state_dict()
75 | torch.save(save_dict, args.save_qmodel_path)
76 |
77 |
78 | # Add Input Quantization
79 | if args.a_bits < 16 or args.v_bits < 16:
80 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
81 | down_proj_groupsize = -1
82 | if args.a_groupsize > 0 and "llama" in args.model:
83 | down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)
84 |
85 | for name in qlayers:
86 | layer_input_bits = args.a_bits
87 | layer_groupsize = args.a_groupsize
88 | layer_a_sym = not(args.a_asym)
89 | layer_a_clip = args.a_clip_ratio
90 |
91 | if 'v_proj' in name and args.v_bits < 16: #Set the v_proj precision
92 | qlayers[name].out_quantizer.configure(bits=args.v_bits,
93 | groupsize=args.v_groupsize,
94 | sym=not(args.v_asym),
95 | clip_ratio=args.v_clip_ratio)
96 |
97 | if 'lm_head' in name: #Skip lm_head quantization
98 | layer_input_bits = 16
99 |
100 | if 'down_proj' in name: #Set the down_proj precision
101 | if args.int8_down_proj:
102 | layer_input_bits = 8
103 | layer_groupsize = down_proj_groupsize
104 |
105 |
106 | qlayers[name].quantizer.configure(bits=layer_input_bits,
107 | groupsize=layer_groupsize,
108 | sym=layer_a_sym,
109 | clip_ratio=layer_a_clip)
110 |
111 | if args.k_bits < 16:
112 | if args.k_pre_rope:
113 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!")
114 | else:
115 | rope_function_name = model_utils.get_rope_function_name(model)
116 | layers = model_utils.get_layers(model)
117 | k_quant_config = {'k_bits':args.k_bits, "k_groupsize": args.k_groupsize,
118 | "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio}
119 | for layer in layers:
120 | rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
121 | layer.self_attn,
122 | rope_function_name,
123 | config=model.config,
124 | **k_quant_config)
125 |
126 | # Evaluating on dataset
127 | testloader = data_utils.get_loaders(
128 | args.eval_dataset,
129 | seed=args.seed,
130 | model=args.model,
131 | seqlen=model.seqlen,
132 | hf_token=args.hf_token,
133 | eval_mode=True
134 | )
135 |
136 |
137 | dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)
138 | if args.wandb:
139 | wandb.log({'ppl/{}'.format(args.eval_dataset.upper()): dataset_ppl})
140 |
141 | if not args.lm_eval:
142 | return
143 | else:
144 | # Import lm_eval utils
145 | import lm_eval
146 | from lm_eval import utils as lm_eval_utils
147 | from lm_eval.api.registry import ALL_TASKS
148 | from lm_eval.models.huggingface import HFLM
149 |
150 |
151 |
152 | if args.distribute:
153 | utils.distribute_model(model)
154 | else:
155 | model.to(utils.DEV)
156 |
157 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token)
158 | hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size)
159 |
160 | task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS)
161 | results = lm_eval.simple_evaluate(hflm, tasks=task_names, batch_size=args.lm_eval_batch_size)['results']
162 |
163 | metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()}
164 | metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals.values()), 4)
165 | print(metric_vals)
166 |
167 | if args.wandb:
168 | wandb.log(metric_vals)
169 |
170 |
171 | if __name__ == '__main__':
172 | main()
173 |
--------------------------------------------------------------------------------
/fake_quant/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import typing
3 | import transformers
4 | import utils
5 | import os
6 | import logging
7 |
8 | OPT_MODEL = transformers.models.opt.modeling_opt.OPTForCausalLM
9 | OPT_LAYER = transformers.models.opt.modeling_opt.OPTDecoderLayer
10 | LLAMA_MODEL = transformers.models.llama.modeling_llama.LlamaForCausalLM
11 | LLAMA_LAYER = transformers.models.llama.modeling_llama.LlamaDecoderLayer
12 |
13 |
14 | def model_type_extractor(model):
15 | if isinstance(model, LLAMA_MODEL):
16 | return LLAMA_MODEL
17 | elif isinstance(model, OPT_MODEL):
18 | return OPT_MODEL
19 | else:
20 | raise ValueError(f'Unknown model type {model}')
21 |
22 | def skip(*args, **kwargs):
23 | # This is a helper function to save time during the initialization!
24 | pass
25 |
26 | def get_rope_function_name(model):
27 | if isinstance(model, LLAMA_MODEL):
28 | return "apply_rotary_pos_emb"
29 | raise NotImplementedError
30 |
31 |
32 | def get_layers(model):
33 | if isinstance(model, OPT_MODEL):
34 | return model.model.decoder.layers
35 | if isinstance(model, LLAMA_MODEL):
36 | return model.model.layers
37 | raise NotImplementedError
38 |
39 |
40 | def get_llama(model_name, hf_token):
41 | torch.nn.init.kaiming_uniform_ = skip
42 | torch.nn.init.uniform_ = skip
43 | torch.nn.init.normal_ = skip
44 | model = transformers.LlamaForCausalLM.from_pretrained(model_name, torch_dtype='auto',
45 | use_auth_token=hf_token,
46 | low_cpu_mem_usage=True)
47 | model.seqlen = 2048
48 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))
49 | return model
50 |
51 |
52 |
53 | def get_opt(model_name):
54 | torch.nn.init.kaiming_uniform_ = skip
55 | torch.nn.init.uniform_ = skip
56 | torch.nn.init.normal_ = skip
57 | model = transformers.OPTForCausalLM.from_pretrained(model_name, torch_dtype='auto',
58 | low_cpu_mem_usage=True)
59 | model.seqlen = model.config.max_position_embeddings
60 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))
61 | return model
62 |
63 |
64 | def get_model(
65 | model_name, hf_token=None
66 | ):
67 | if 'llama' in model_name:
68 | return get_llama(model_name, hf_token)
69 | elif 'opt' in model_name:
70 | return get_opt(model_name)
71 | else:
72 | raise ValueError(f'Unknown model {model_name}')
73 |
74 |
75 | def get_model_type(model):
76 | if isinstance(model, OPT_MODEL):
77 | model_type = OPT_MODEL
78 | elif isinstance(model, LLAMA_MODEL):
79 | model_type = LLAMA_MODEL
80 | else:
81 | raise ValueError(f'Unknown model type {model}')
82 | return model_type
83 |
84 | def get_embeddings(model, model_type) -> list[torch.nn.Module]:
85 | if model_type == LLAMA_MODEL:
86 | return [model.model.embed_tokens]
87 | elif model_type == OPT_MODEL:
88 | return [model.model.decoder.embed_tokens, model.model.decoder.embed_positions]
89 | else:
90 | raise ValueError(f'Unknown model type {model_type}')
91 |
92 |
93 | def get_transformer_layers(model, model_type):
94 | if model_type == LLAMA_MODEL:
95 | return [layer for layer in model.model.layers]
96 | elif model_type == OPT_MODEL:
97 | return [layer for layer in model.model.decoder.layers]
98 | else:
99 | raise ValueError(f'Unknown model type {model_type}')
100 |
101 |
102 | def get_lm_head(model, model_type):
103 | if model_type == LLAMA_MODEL:
104 | return model.lm_head
105 | elif model_type == OPT_MODEL:
106 | return model.lm_head
107 | else:
108 | raise ValueError(f'Unknown model type {model_type}')
109 |
110 | def get_pre_head_layernorm(model, model_type):
111 | if model_type == LLAMA_MODEL:
112 | pre_head_layernorm = model.model.norm
113 | assert isinstance(pre_head_layernorm,
114 | transformers.models.llama.modeling_llama.LlamaRMSNorm)
115 | elif model_type == OPT_MODEL:
116 | pre_head_layernorm = model.model.decoder.final_layer_norm
117 | assert pre_head_layernorm is not None
118 | else:
119 | raise ValueError(f'Unknown model type {model_type}')
120 | return pre_head_layernorm
121 |
122 | def get_mlp_bottleneck_size(model):
123 | model_type = get_model_type(model)
124 | if model_type == LLAMA_MODEL:
125 | return model.config.intermediate_size
126 | elif model_type == OPT_MODEL:
127 | return model.config.ffn_dim
128 | else:
129 | raise ValueError(f'Unknown model type {model_type}')
130 |
131 | def replace_modules(
132 | root: torch.nn.Module,
133 | type_to_replace,
134 | new_module_factory,
135 | replace_layers: bool,
136 | ) -> None:
137 | """Replace modules of given type using the supplied module factory.
138 |
139 | Perform a depth-first search of a module hierarchy starting at root
140 | and replace all instances of type_to_replace with modules created by
141 | new_module_factory. Children of replaced modules are not processed.
142 |
143 | Args:
144 | root: the root of the module hierarchy where modules should be replaced
145 | type_to_replace: a type instances of which will be replaced
146 | new_module_factory: a function that given a module that should be replaced
147 | produces a module to replace it with.
148 | """
149 | for name, module in root.named_children():
150 | new_module = None
151 | if isinstance(module, type_to_replace):
152 | if replace_layers: # layernorm_fusion.replace_layers case where transformer layers are replaced
153 | new_module = new_module_factory(module, int(name))
154 | else: # layernorm_fusion.fuse_modules case where layernorms are fused
155 | new_module = new_module_factory(module)
156 | elif len(list(module.children())) > 0:
157 | replace_modules(module, type_to_replace, new_module_factory, replace_layers)
158 |
159 | if new_module is not None:
160 | setattr(root, name, new_module)
161 |
162 |
163 | class RMSN(torch.nn.Module):
164 | """
165 | This class implements the Root Mean Square Normalization (RMSN) layer.
166 | We use the implementation from LLAMARMSNorm here:
167 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75
168 | """
169 |
170 | def __init__(self, mean_dim: int, eps=1e-5):
171 | super().__init__()
172 | self.eps = eps
173 | self.mean_dim = mean_dim
174 | self.weight = torch.nn.Parameter(torch.zeros(1))
175 |
176 | def forward(self, x: torch.Tensor) -> torch.Tensor:
177 | input_dtype = x.dtype
178 | if x.dtype == torch.float16:
179 | x = x.to(torch.float32)
180 | variance = x.pow(2).sum(-1, keepdim=True) / self.mean_dim
181 | x = x * torch.rsqrt(variance + self.eps)
182 | return x.to(input_dtype)
183 |
184 |
185 | def get_layer_io_save_path(args):
186 | return os.path.join(args.save_path, 'layer_io', f'{args.layer_idx:03d}.pt')
187 |
188 | def capture_layer_io(model_type, layer, layer_input):
189 | def hook_factory(module_name, captured_vals, is_input):
190 | def hook(module, input, output):
191 | if is_input:
192 | captured_vals[module_name].append(input[0].detach().cpu())
193 | else:
194 | captured_vals[module_name].append(output.detach().cpu())
195 | return hook
196 |
197 | handles = []
198 |
199 | if model_type == LLAMA_MODEL:
200 | captured_inputs = {
201 | 'k_proj': [], # q_proj and v_proj have the same input as k_proj
202 | 'o_proj': [],
203 | 'gate_proj': [], # up_proj has the same input as gate_proj
204 | 'down_proj': []
205 | }
206 |
207 | captured_outputs = {
208 | 'v_proj': [],
209 | }
210 |
211 | for name in captured_inputs.keys():
212 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None)
213 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True)))
214 |
215 | for name in captured_outputs.keys():
216 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None)
217 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False)))
218 |
219 | elif model_type == OPT_MODEL:
220 | captured_inputs = {
221 | 'k_proj': [], # q_proj and v_proj have the same input as k_proj
222 | 'out_proj': [],
223 | 'fc1': [],
224 | 'fc2': []
225 | }
226 | captured_outputs = {
227 | 'v_proj': [],
228 | }
229 | for name in captured_inputs.keys():
230 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer
231 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None)
232 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True)))
233 |
234 | for name in captured_outputs.keys():
235 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer
236 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None)
237 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False)))
238 | else:
239 | raise ValueError(f'Unknown model type {model_type}')
240 |
241 | # Process each sequence in the batch one by one to avoid OOM.
242 | for seq_idx in range(layer_input.shape[0]):
243 | # Extract the current sequence across all dimensions.
244 | seq = layer_input[seq_idx:seq_idx + 1].to(utils.DEV)
245 | # Perform a forward pass for the current sequence.
246 | layer(seq)
247 |
248 | # After processing all sequences, concatenate the accumulated inputs for each sub-layer across the batch.
249 | for module_name in captured_inputs:
250 | captured_inputs[module_name] = torch.cat(captured_inputs[module_name], dim=0)
251 | for module_name in captured_outputs:
252 | captured_outputs[module_name] = torch.cat(captured_outputs[module_name], dim=0)
253 |
254 | # Cleanup.
255 | for h in handles:
256 | h.remove()
257 |
258 | return {
259 | 'input': captured_inputs,
260 | 'output': captured_outputs
261 | }
262 |
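263 | # A minimal usage sketch for replace_modules above (mirroring the layernorm-fusion case
264 | # mentioned in its comments): swap every LlamaRMSNorm for the scale-free RMSN defined above.
265 | #
266 | #   replace_modules(
267 | #       model,
268 | #       transformers.models.llama.modeling_llama.LlamaRMSNorm,
269 | #       lambda _: RMSN(model.config.hidden_size),
270 | #       replace_layers=False,
271 | #   )
272 | #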
--------------------------------------------------------------------------------
/fake_quant/monkeypatch.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import types
4 |
5 | def copy_func_with_new_globals(f, globals=None):
6 | """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
7 | if globals is None:
8 | globals = f.__globals__
9 | g = types.FunctionType(f.__code__, globals, name=f.__name__,
10 | argdefs=f.__defaults__, closure=f.__closure__)
11 | g = functools.update_wrapper(g, f)
12 | g.__module__ = f.__module__
13 | g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
14 | return g
15 |
16 | def add_wrapper_after_function_call_in_method(module, method_name, function_name, wrapper_fn):
17 | '''
18 | Wrap the function named `function_name`, as referenced inside the method named `method_name`, so that a wrapper can post-process its output.
19 | Only calls made directly in that method are affected; calls made by other functions that the method calls are not.
20 | '''
21 |
22 | original_method = getattr(module, method_name).__func__
23 | method_globals = dict(original_method.__globals__)
24 | wrapper = wrapper_fn(method_globals[function_name])
25 | method_globals[function_name] = wrapper
26 | new_method = copy_func_with_new_globals(original_method, globals=method_globals)
27 | setattr(module, method_name, new_method.__get__(module))
28 | return wrapper
29 |
30 |
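31 | # A minimal usage sketch (illustrative names): post-process the output of
32 | # `apply_rotary_pos_emb` inside an attention module's forward, which is how the
33 | # fake-quant pipeline quantizes keys right after RoPE is applied.
34 | #
35 | #   def make_wrapper(orig_fn):
36 | #       def wrapped(*args, **kwargs):
37 | #           out = orig_fn(*args, **kwargs)
38 | #           return out  # e.g. rotate / quantize the returned q, k here
39 | #       return wrapped
40 | #
41 | #   add_wrapper_after_function_call_in_method(
42 | #       layer.self_attn, "forward", "apply_rotary_pos_emb", make_wrapper)
43 | #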
--------------------------------------------------------------------------------
/img/carrot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/img/carrot.png
--------------------------------------------------------------------------------
/img/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/img/fig1.png
--------------------------------------------------------------------------------
/quarot/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from . import nn
3 | from . import functional
4 |
5 |
6 | import quarot._CUDA
7 |
8 |
9 | __all__ = [
10 | "matmul", #int-4 matmul
11 | "sym_quant", "sym_dequant", "PackedQuantizedTensor", # Quantization
12 | ]
13 |
14 | class ShapeHandler:
15 | def __init__(self, x: torch.Tensor):
16 | self.size_excl_last = x.numel()//x.shape[-1]
17 | self.shape_excl_last = tuple(x.shape[:-1])
18 |
19 | # Keep the last dim unchanged, flatten all previous dims
20 | def flatten(self, x: torch.Tensor):
21 | return x.view(self.size_excl_last, -1)
22 |
23 | # Recover back to the original shape.
24 | def unflatten(self, x: torch.Tensor):
25 | return x.view(self.shape_excl_last + (-1,))
26 |
27 | def unflatten_scale(self, x: torch.Tensor):
28 | return x.view(self.shape_excl_last)
29 |
30 |
31 | def flatten_last_dim_and_return_shape(x: torch.Tensor):
32 | shape_excl_last = x.shape[:-1]
33 | x = x.view(-1, x.shape[-1])
34 | return x, shape_excl_last
35 |
36 |
37 | def matmul(A, B):
38 | assert A.shape[-1] % 32 == 0, "A.shape[-1]: {} must be a multiple of 32".format(A.shape[-1])
39 | A, A_shape_excl_last = flatten_last_dim_and_return_shape(A)
40 | B, B_shape_excl_last = flatten_last_dim_and_return_shape(B)
41 | return quarot._CUDA.matmul(A, B).view(*A_shape_excl_last, *B_shape_excl_last)
42 |
43 | def sym_quant(x, scale):
44 | assert x.dtype == scale.dtype == torch.float16
45 | x, x_shape_excl_last = flatten_last_dim_and_return_shape(x)
46 | return quarot._CUDA.sym_quant(x, scale.view(-1)).view(*x_shape_excl_last, -1)
47 |
48 | def sym_dequant(q, scale_row, scale_col, bits=32):
49 | assert q.dtype == torch.int32
50 | assert scale_row.dtype == scale_col.dtype == torch.float16
51 | q, q_shape_excl_last = flatten_last_dim_and_return_shape(q)
52 | return quarot._CUDA.sym_dequant(q, scale_row.view(-1), scale_col, bits).view(*q_shape_excl_last, -1)
53 |
54 |
55 | class PackedQuantizedTensor:
56 | def __init__(self,
57 | quantized_x: torch.Tensor,
58 | scales_x: torch.Tensor):
59 | self.quantized_x = quantized_x
60 | self.scales_x = scales_x
61 |
62 | def size(self):
63 | return self.quantized_x.size()
64 |
65 | @property
66 | def device(self):
67 | return self.quantized_x.device
68 |
69 | @property
70 | def dtype(self):
71 | return self.quantized_x.dtype
72 |
--------------------------------------------------------------------------------
/quarot/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantization import pack_i4, unpack_i4, asym_quant_dequant, sym_quant_dequant
2 | from .hadamard import (
3 | matmul_hadU_cuda,
4 | random_hadamard_matrix,
5 | apply_exact_had_to_linear)
6 |
--------------------------------------------------------------------------------
/quarot/functional/quantization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def two_compl(x, bits: int):
5 | return torch.where(x < 0, 2 ** bits + x, x)
6 |
7 | def get_minq_maxq(bits: int, sym: bool):
8 | if sym:
9 | maxq = torch.tensor(2**(bits-1)-1)
10 | minq = torch.tensor(-maxq -1)
11 | else:
12 | maxq = torch.tensor(2**bits - 1)
13 | minq = torch.tensor(0)
14 |
15 | return minq, maxq
16 |
17 | def asym_quant(x, scale, zero, maxq):
18 | scale = scale.to(x.device)
19 | zero = zero.to(x.device)
20 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
21 | return q, scale, zero
22 |
23 | def asym_dequant(q, scale, zero):
24 | return scale * (q - zero)
25 |
26 | def asym_quant_dequant(x, scale, zero, maxq):
27 | return asym_dequant(*asym_quant(x, scale, zero, maxq))
28 |
29 | def sym_quant(x, scale, maxq):
30 | scale = scale.to(x.device)
31 | q = torch.clamp(torch.round(x / scale), -(maxq+1), maxq)
32 | return q, scale
33 | def sym_dequant(q, scale):
34 | return scale * q
35 |
36 | def sym_quant_dequant(x, scale, maxq):
37 | return sym_dequant(*sym_quant(x, scale, maxq))
38 |
39 |
40 |
41 | # Pack the int tensor. Each uint8 stores two int4 values.
42 | def pack_i4(q):
43 | assert torch.is_signed(q), 'The tensor to be packed should be signed int'
44 | minq, maxq = get_minq_maxq(4, True)
45 | assert torch.all(torch.logical_and(q >= minq, q <= maxq))
46 |
47 | q_i8 = two_compl(q.to(dtype=torch.int8), 4).to(torch.uint8)
48 | q_i4 = q_i8[:, 0::2] | (q_i8[:, 1::2] << 4)
49 | return q_i4
50 |
51 | # Unpack the quantized int4 tensor (stored in uint8) into int32 tensor.
52 | def unpack_i4(x: torch.Tensor):
53 | assert x.dtype == torch.uint8, 'The tensor to be unpacked should be stored in uint8'
54 |
55 | out_shape = list(x.shape)
56 | out_shape[-1] *= 2 # Each uint8 packs two numbers
57 |
58 | # Low 4 bits
59 | x0 = (x & 0x0f).to(torch.int8)
60 | x0[x0>=8] -= 16
61 | x0 = x0.view(-1, x0.shape[-1])
62 |
63 | # High 4 bits
64 | x1 = ((x & 0xf0) >> 4).to(torch.int8)
65 | x1[x1>=8] -= 16
66 | x1 = x1.view(-1, x1.shape[-1])
67 |
68 | out = torch.empty(out_shape, device=x.device, dtype=torch.int32)
69 | out = out.view(-1, out.shape[-1])
70 | # Interleaving
71 | out[:, 0::2] = x0
72 | out[:, 1::2] = x1
73 |
74 | return out.view(out_shape)
75 |
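76 | # A small round-trip sketch for the two packing helpers above (shapes are illustrative;
77 | # the last dimension must be even so that each uint8 can hold a pair of int4 values):
78 | #
79 | #   q = torch.randint(-8, 8, (4, 8), dtype=torch.int8)   # values in the signed int4 range
80 | #   packed = pack_i4(q)                                   # uint8, last dim halved to 4
81 | #   assert torch.equal(unpack_i4(packed), q.to(torch.int32))
82 | #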
--------------------------------------------------------------------------------
/quarot/kernels/flashinfer.cu:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | #include
6 | #include
7 | #include
8 |
9 | template <int head_dim>
10 | void FlashInferBatchDecodeKernel_i4(nv_half* o, nv_half* q, void* kv_data,
11 | nv_half2* kv_param, int32_t* kv_indptr,
12 | int32_t* kv_indicies,
13 | int32_t* last_page_offset, int num_layers,
14 | int layer_idx, int num_heads, int page_size,
15 | int batch_size) {
16 | using DTypeIn = flashinfer::quant::__precision__s4;
17 | using DTypeInQ = nv_half;
18 | using DTypeOut = nv_half;
19 |
20 | flashinfer::paged_kv_t<DTypeIn> paged_kv(
21 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size,
22 | (DTypeIn*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset);
23 |
24 | const float rope_scale = 1.f;
25 | const float rope_theta = 1e4;
26 | const float sm_scale = 1.f / std::sqrt(float(head_dim));
27 | const float rope_inv_scale = 1.f / rope_scale;
28 | const float rope_inv_theta = 1.f / rope_theta;
29 |
30 | constexpr bool norm_on_the_fly = false;
31 | constexpr auto rotary_mode = flashinfer::RotaryMode::kNone;
32 | constexpr size_t FoldFactor = 2;
33 | constexpr size_t vec_size = std::max(
34 | static_cast<size_t>(16 / flashinfer::quant::size_of_type<DTypeIn>() /
35 | FoldFactor),
36 | static_cast<size_t>(head_dim / 32));
37 | constexpr size_t bdx = head_dim / vec_size;
38 | constexpr size_t bdy = 128 / bdx;
39 | dim3 nblks(paged_kv.batch_size, paged_kv.num_heads);
40 | dim3 nthrs(bdx, bdy);
41 |
42 | flashinfer::BatchDecodeWithPagedKVCacheKernel<
43 | rotary_mode, norm_on_the_fly, vec_size, bdx, bdy, FoldFactor, DTypeInQ,
44 | DTypeIn, DTypeOut, int32_t><<<nblks, nthrs>>>(
45 | q, paged_kv, o, sm_scale, rope_inv_scale, rope_inv_theta);
46 | }
47 |
48 | template <int head_dim>
49 | void FlashInferInitKvKernel_i4(void* kv_data, nv_half2* kv_param,
50 | int32_t* kv_indptr, int32_t* kv_indicies,
51 | int32_t* last_page_offset, void* key,
52 | void* value, nv_half2* key_param,
53 | nv_half2* value_param, int32_t* seqlen_indptr,
54 | int num_layers, int layer_idx, int num_heads,
55 | int page_size, int batch_size) {
56 | using T = flashinfer::quant::__precision__s4;
57 | flashinfer::paged_kv_t<T, int32_t> paged_kv(
58 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size,
59 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset);
60 |
61 | constexpr size_t vec_size =
62 | std::max(static_cast<size_t>(16 / flashinfer::quant::size_of_type<T>()),
63 | static_cast<size_t>(head_dim / 32));
64 | constexpr size_t bdx = head_dim / vec_size;
65 | constexpr size_t bdy = 128 / bdx;
66 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy));
67 | dim3 nthrs(bdx, bdy);
68 | flashinfer::AppendPagedKVCachePrefillKernel<head_dim, vec_size, bdx, bdy, T,
69 | int32_t><<<nblks, nthrs>>>(
70 | paged_kv, (T*)key, (T*)value, key_param, value_param, seqlen_indptr);
71 | }
72 |
73 | template <int head_dim>
74 | void FlashInferAppendKvKernel_i4(void* kv_data, nv_half2* kv_param,
75 | int32_t* kv_indptr, int32_t* kv_indicies,
76 | int32_t* last_page_offset, void* key,
77 | void* value, nv_half2* key_param,
78 | nv_half2* value_param, int num_layers,
79 | int layer_idx, int num_heads, int page_size,
80 | int batch_size) {
81 | using T = flashinfer::quant::__precision__s4;
82 | flashinfer::paged_kv_t<T, int32_t> paged_kv(
83 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size,
84 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset);
85 |
86 | constexpr size_t vec_size =
87 | std::max(static_cast<size_t>(16 / flashinfer::quant::size_of_type<T>()),
88 | static_cast<size_t>(head_dim / 32));
89 | constexpr size_t bdx = head_dim / vec_size;
90 | constexpr size_t bdy = 128 / bdx;
91 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy));
92 | dim3 nthrs(bdx, bdy);
93 | flashinfer::AppendPagedKVCacheDecodeKernel<head_dim, vec_size, bdx, bdy, T,
94 | int32_t>
95 | <<<nblks, nthrs>>>(paged_kv, (T*)key, (T*)value, key_param, value_param);
96 | }
97 |
98 |
99 | template <int head_dim>
100 | void FlashInferBatchDecodeKernel_f16(nv_half* o, nv_half* q, void* kv_data,
101 | nv_half2* kv_param, int32_t* kv_indptr,
102 | int32_t* kv_indicies,
103 | int32_t* last_page_offset, int num_layers,
104 | int layer_idx, int num_heads, int page_size,
105 | int batch_size) {
106 | using DTypeIn = nv_half;
107 | using DTypeInQ = nv_half;
108 | using DTypeOut = nv_half;
109 |
110 | flashinfer::paged_kv_t<DTypeIn, int32_t> paged_kv(
111 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size,
112 | (DTypeIn*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset);
113 |
114 | const float rope_scale = 1.f;
115 | const float rope_theta = 1e4;
116 | const float sm_scale = 1.f / std::sqrt(float(head_dim));
117 | const float rope_inv_scale = 1.f / rope_scale;
118 | const float rope_inv_theta = 1.f / rope_theta;
119 |
120 | constexpr bool norm_on_the_fly = false;
121 | constexpr auto rotary_mode = flashinfer::RotaryMode::kNone;
122 | constexpr size_t FoldFactor = 1;
123 | constexpr size_t vec_size = std::max(
124 | static_cast<size_t>(16 / flashinfer::quant::size_of_type<DTypeIn>() /
125 | FoldFactor),
126 | static_cast<size_t>(head_dim / 32));
127 | constexpr size_t bdx = head_dim / vec_size;
128 | constexpr size_t bdy = 128 / bdx;
129 | dim3 nblks(paged_kv.batch_size, paged_kv.num_heads);
130 | dim3 nthrs(bdx, bdy);
131 |
132 | flashinfer::BatchDecodeWithPagedKVCacheKernel<
133 | rotary_mode, norm_on_the_fly, vec_size, bdx, bdy, FoldFactor, DTypeInQ,
134 | DTypeIn, DTypeOut, int32_t><<<nblks, nthrs>>>(
135 | q, paged_kv, o, sm_scale, rope_inv_scale, rope_inv_theta);
136 | }
137 |
138 | template <int head_dim>
139 | void FlashInferInitKvKernel_f16(void* kv_data, nv_half2* kv_param,
140 | int32_t* kv_indptr, int32_t* kv_indicies,
141 | int32_t* last_page_offset, void* key,
142 | void* value, nv_half2* key_param,
143 | nv_half2* value_param, int32_t* seqlen_indptr,
144 | int num_layers, int layer_idx, int num_heads,
145 | int page_size, int batch_size) {
146 | using T = nv_half;
147 | flashinfer::paged_kv_t<T, int32_t> paged_kv(
148 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size,
149 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset);
150 |
151 | constexpr size_t vec_size =
152 | std::max(static_cast<size_t>(16 / flashinfer::quant::size_of_type<T>()),
153 | static_cast<size_t>(head_dim / 32));
154 | constexpr size_t bdx = head_dim / vec_size;
155 | constexpr size_t bdy = 128 / bdx;
156 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy));
157 | dim3 nthrs(bdx, bdy);
158 | flashinfer::AppendPagedKVCachePrefillKernel<head_dim, vec_size, bdx, bdy, T,
159 | int32_t><<<nblks, nthrs>>>(
160 | paged_kv, (T*)key, (T*)value, key_param, value_param, seqlen_indptr);
161 | }
162 |
163 | template <int head_dim>
164 | void FlashInferAppendKvKernel_f16(void* kv_data, nv_half2* kv_param,
165 | int32_t* kv_indptr, int32_t* kv_indicies,
166 | int32_t* last_page_offset, void* key,
167 | void* value, nv_half2* key_param,
168 | nv_half2* value_param, int num_layers,
169 | int layer_idx, int num_heads, int page_size,
170 | int batch_size) {
171 | using T = nv_half;
172 | flashinfer::paged_kv_t<T, int32_t> paged_kv(
173 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size,
174 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset);
175 |
176 | constexpr size_t vec_size =
177 | std::max(static_cast<size_t>(16 / flashinfer::quant::size_of_type<T>()),
178 | static_cast<size_t>(head_dim / 32));
179 | constexpr size_t bdx = head_dim / vec_size;
180 | constexpr size_t bdy = 128 / bdx;
181 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy));
182 | dim3 nthrs(bdx, bdy);
183 | flashinfer::AppendPagedKVCacheDecodeKernel<head_dim, vec_size, bdx, bdy, T,
184 | int32_t>
185 | <<<nblks, nthrs>>>(paged_kv, (T*)key, (T*)value, key_param, value_param);
186 | }
187 |
188 |
189 | template void FlashInferBatchDecodeKernel_i4<128>(
190 | nv_half* o, nv_half* q, void* kv_data, nv_half2* kv_param,
191 | int32_t* kv_indptr, int32_t* kv_indicies, int32_t* last_page_offset,
192 | int num_layers, int layer_idx, int num_heads, int page_size,
193 | int batch_size);
194 |
195 | template void FlashInferInitKvKernel_i4<128>(
196 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies,
197 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param,
198 | nv_half2* value_param, int32_t* seqlen_indptr, int num_layers,
199 | int layer_idx, int num_heads, int page_size, int batch_size);
200 |
201 | template void FlashInferAppendKvKernel_i4<128>(
202 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies,
203 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param,
204 | nv_half2* value_param, int num_layers, int layer_idx, int num_heads,
205 | int page_size, int batch_size);
206 |
207 |
208 | template void FlashInferBatchDecodeKernel_f16<128>(
209 | nv_half* o, nv_half* q, void* kv_data, nv_half2* kv_param,
210 | int32_t* kv_indptr, int32_t* kv_indicies, int32_t* last_page_offset,
211 | int num_layers, int layer_idx, int num_heads, int page_size,
212 | int batch_size);
213 |
214 | template void FlashInferInitKvKernel_f16<128>(
215 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies,
216 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param,
217 | nv_half2* value_param, int32_t* seqlen_indptr, int num_layers,
218 | int layer_idx, int num_heads, int page_size, int batch_size);
219 |
220 | template void FlashInferAppendKvKernel_f16<128>(
221 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies,
222 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param,
223 | nv_half2* value_param, int num_layers, int layer_idx, int num_heads,
224 | int page_size, int batch_size);
225 |
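The launch geometry in these wrappers is plain integer arithmetic over the element size: each thread handles vec_size elements of one head_dim row, bdx threads cover the row, and bdy rows share a 128-thread block. A small illustrative recomputation (assuming size_of_type is 0.5 bytes for the int4 path and 2 bytes for fp16, as in quantization.cuh):

    def launch_geometry(head_dim, elem_bytes, fold_factor):
        # mirrors: vec_size = max(16 / size_of_type / FoldFactor, head_dim / 32)
        vec_size = max(int(16 / elem_bytes / fold_factor), head_dim // 32)
        bdx = head_dim // vec_size   # threads cooperating on one head_dim row
        bdy = 128 // bdx             # rows (heads) handled per 128-thread block
        return vec_size, bdx, bdy

    print(launch_geometry(128, elem_bytes=0.5, fold_factor=2))  # int4 decode path -> (16, 8, 16)
    print(launch_geometry(128, elem_bytes=2.0, fold_factor=1))  # fp16 decode path -> (8, 16, 8)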
--------------------------------------------------------------------------------
/quarot/kernels/gemm.cu:
--------------------------------------------------------------------------------
1 | #include <gemm.h>
2 | #include <cutlass/gemm/device/gemm.h>
3 |
4 |
5 |
6 |
7 | void matmul_host(
8 | const Int4Storage *A,
9 | const Int4Storage *B,
10 | uint32_t M,
11 | uint32_t N,
12 | uint32_t K,
13 | int32_t *C
14 | )
15 | {
16 | using Gemm = cutlass::gemm::device::Gemm<
17 | cutlass::int4b_t, // ElementA
18 | cutlass::layout::RowMajor, // LayoutA
19 | cutlass::int4b_t, // ElementB
20 | cutlass::layout::ColumnMajor, // LayoutB
21 | int32_t, // ElementOutput
22 | cutlass::layout::RowMajor, // LayoutOutput
23 | int32_t, // ElementAccumulator
24 | cutlass::arch::OpClassTensorOp, // tag indicating Tensor Cores
25 | cutlass::arch::Sm80 // tag indicating target GPU compute architecture // TODO: This is just for compiling on my laptop temporarily. Should be higher when doing benchmarking.
26 | >;
27 |
28 | Gemm gemmOp;
29 |
30 | using GemmCoord = cutlass::gemm::GemmCoord;
31 |
32 | typename Gemm::Arguments arguments{
33 | {static_cast(M), static_cast(N), static_cast(K)},
34 | {(cutlass::int4b_t *) A, K},
35 | {(cutlass::int4b_t *) B, K},
36 | {C, N},
37 | {C, N},
38 | {1, 0}
39 | };
40 |
41 | auto status = gemmOp(arguments);
42 |
43 | ensure(status == cutlass::Status::kSuccess,
44 | cutlassGetStatusString(status));
45 |
46 | }
47 |
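matmul_host configures CUTLASS for a row-major int4 A, a column-major int4 B, and a row-major int32 C with alpha=1 and beta=0, which numerically amounts to C = A @ B^T with int32 accumulation. A reference sketch of that contract (illustrative only, not the kernel itself):

    import numpy as np

    def matmul_reference(A, B):
        # A: (M, K) int4 values in [-8, 7]; B: (N, K) int4 values; returns (M, N) int32.
        assert A.min() >= -8 and A.max() <= 7 and B.min() >= -8 and B.max() <= 7
        return A.astype(np.int32) @ B.astype(np.int32).T

    A = np.random.randint(-8, 8, size=(4, 32))
    B = np.random.randint(-8, 8, size=(6, 32))
    C = matmul_reference(A, B)   # what the CUTLASS GEMM computes, up to int4 packing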
--------------------------------------------------------------------------------
/quarot/kernels/include/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <cuda.h>
3 | #include <cuda_fp16.h>
4 | #include <cuda_runtime.h>
5 |
6 | #include <iostream>
7 | #include <stdexcept>
8 |
9 | //#include
10 | #include <int4.h>
11 |
12 | #define HOST_DEVICE __forceinline__ __host__ __device__
13 | #define DEVICE __forceinline__ __device__
14 | #define HOST __forceinline__ __host__
15 |
16 | HOST void ensure(bool condition, const std::string& msg) {
17 | if (!condition) {
18 | std::cerr << "Assertion failed: " << msg << '\n';
19 | // Choose the appropriate action: throw an exception, abort, etc.
20 | // For example, throwing an exception:
21 | throw std::runtime_error(msg);
22 | }
23 | }
24 |
25 | template
26 | HOST_DEVICE T mymax(T a, T b)
27 | {
28 | return a > b ? a : b;
29 | }
30 |
31 | template
32 | HOST_DEVICE T mymin(T a, T b)
33 | {
34 | return a < b ? a : b;
35 | }
36 |
37 | template
38 | HOST_DEVICE T cdiv(T a, T b) { return (a + b - 1) / b; }
39 |
40 | template
41 | HOST_DEVICE T clamp(T x, T a, T b) { return mymax(a, mymin(b, x)); }
42 |
43 | template
44 | HOST_DEVICE T myabs(T x) { return x < (T) 0 ? -x : x; }
45 |
46 | template
47 | DEVICE T sqr(T x)
48 | {
49 | return x * x;
50 | }
51 |
52 | constexpr int qmin = -8;
53 | constexpr int qmax = 7;
54 |
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <cuda_fp16.h>
3 |
4 | #include <cstdint>
5 |
6 | template <int head_dim>
7 | void FlashInferBatchDecodeKernel_i4(nv_half* o, nv_half* q, void* kv_data,
8 | nv_half2* kv_param, int32_t* kv_indptr,
9 | int32_t* kv_indicies,
10 | int32_t* last_page_offset, int num_layers,
11 | int layer_idx, int num_heads, int page_size,
12 | int batch_size);
13 |
14 | template <int head_dim>
15 | void FlashInferInitKvKernel_i4(void* kv_data, nv_half2* kv_param,
16 | int32_t* kv_indptr, int32_t* kv_indicies,
17 | int32_t* last_page_offset, void* key,
18 | void* value, nv_half2* key_param,
19 | nv_half2* value_param, int32_t* seqlen_indptr,
20 | int num_layers, int layer_idx, int num_heads,
21 | int page_size, int batch_size);
22 |
23 | template <int head_dim>
24 | void FlashInferAppendKvKernel_i4(void* kv_data, nv_half2* kv_param,
25 | int32_t* kv_indptr, int32_t* kv_indicies,
26 | int32_t* last_page_offset, void* key,
27 | void* value, nv_half2* key_param,
28 | nv_half2* value_param, int num_layers,
29 | int layer_idx, int num_heads, int page_size,
30 | int batch_size);
31 |
32 |
33 | template <int head_dim>
34 | void FlashInferBatchDecodeKernel_f16(nv_half* o, nv_half* q, void* kv_data,
35 | nv_half2* kv_param, int32_t* kv_indptr,
36 | int32_t* kv_indicies,
37 | int32_t* last_page_offset, int num_layers,
38 | int layer_idx, int num_heads, int page_size,
39 | int batch_size);
40 |
41 | template <int head_dim>
42 | void FlashInferInitKvKernel_f16(void* kv_data, nv_half2* kv_param,
43 | int32_t* kv_indptr, int32_t* kv_indicies,
44 | int32_t* last_page_offset, void* key,
45 | void* value, nv_half2* key_param,
46 | nv_half2* value_param, int32_t* seqlen_indptr,
47 | int num_layers, int layer_idx, int num_heads,
48 | int page_size, int batch_size);
49 |
50 | template <int head_dim>
51 | void FlashInferAppendKvKernel_f16(void* kv_data, nv_half2* kv_param,
52 | int32_t* kv_indptr, int32_t* kv_indicies,
53 | int32_t* last_page_offset, void* key,
54 | void* value, nv_half2* key_param,
55 | nv_half2* value_param, int num_layers,
56 | int layer_idx, int num_heads, int page_size,
57 | int batch_size);
58 |
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/cp_async.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_CP_ASYNC_CUH_
2 | #define FLASHINFER_CP_ASYNC_CUH_
3 |
4 | #include <cuda_runtime.h>
5 | #include "quantization.cuh"
6 |
7 | namespace flashinfer {
8 |
9 | namespace cp_async {
10 |
11 | __device__ __forceinline__ void commit_group() {
12 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
13 | asm volatile("cp.async.commit_group;\n" ::);
14 | #endif
15 | }
16 |
17 | template <size_t n>
18 | __device__ __forceinline__ void wait_group() {
19 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
20 | asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
21 | #endif
22 | }
23 |
24 | template <bool prefetch, typename T>
25 | __device__ __forceinline__ void load_128(T* smem_ptr, const T* gmem_ptr) {
26 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
27 |
28 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr));
29 | if constexpr (prefetch) {
30 | asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
31 | "l"(gmem_ptr), "n"(16), "r"(16));
32 | } else {
33 | asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(smem_int_ptr), "l"(gmem_ptr),
34 | "n"(16));
35 | }
36 | #else
37 | *((uint4*)smem_ptr) = *((uint4*)gmem_ptr);
38 | #endif
39 | }
40 |
41 | template <bool prefetch, typename T>
42 | __device__ __forceinline__ void pred_load_128(T* smem_ptr, const T* gmem_ptr, bool predicate) {
43 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
44 |
45 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr));
46 | if constexpr (prefetch) {
47 | asm volatile(
48 | "{\n"
49 | " .reg .pred p;\n"
50 | " setp.ne.b32 p, %0, 0;\n"
51 | " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
52 | "}\n" ::"r"((int)predicate),
53 | "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16));
54 | } else {
55 | asm volatile(
56 | "{\n"
57 | " .reg .pred p;\n"
58 | " setp.ne.b32 p, %0, 0;\n"
59 | " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
60 | "}\n" ::"r"((int)predicate),
61 | "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16));
62 | }
63 | #else
64 | if (predicate) {
65 | *((uint4*)smem_ptr) = *((uint4*)gmem_ptr);
66 | }
67 | #endif
68 | }
69 |
70 | template <size_t num_bits, bool prefetch, typename T>
71 | __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) {
72 | static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256");
73 | if constexpr (num_bits == 128) {
74 | load_128<prefetch>(smem_ptr, gmem_ptr);
75 | } else {
76 | load_128<prefetch>(smem_ptr, gmem_ptr);
77 | load_128<prefetch>(smem_ptr + static_cast<size_t>(16 / quant::size_of_type<T>()), gmem_ptr + static_cast<size_t>(16 / quant::size_of_type<T>()));
78 | }
79 | }
80 |
81 | template <size_t num_bits, bool prefetch, typename T>
82 | __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) {
83 | static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256");
84 | if constexpr (num_bits == 128) {
85 | pred_load_128<prefetch>(smem_ptr, gmem_ptr, predicate);
86 | } else {
87 | pred_load_128<prefetch>(smem_ptr, gmem_ptr, predicate);
88 | pred_load_128<prefetch>(smem_ptr + static_cast<size_t>(16 / quant::size_of_type<T>()), gmem_ptr + static_cast<size_t>(16 / quant::size_of_type<T>()), predicate);
89 | }
90 | }
91 |
92 | } // namespace cp_async
93 |
94 | } // namespace flashinfer
95 |
96 | #endif // FLASHINFER_CP_ASYNC_CUH_
97 |
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/layout.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_LAYOUT_CUH_
2 | #define FLASHINFER_LAYOUT_CUH_
3 |
4 | #include <string>
5 |
6 | namespace flashinfer {
7 |
8 | /*!
9 | * \brief The Layout of QKV matrices
10 | */
11 | enum class QKVLayout {
12 | // [seq_len, num_heads, head_dim]
13 | kNHD = 0U,
14 | // [num_heads, seq_len, head_dim]
15 | kHND = 1U,
16 | };
17 |
18 | template <QKVLayout layout>
19 | __host__ __device__ __forceinline__ size_t get_elem_offset_impl(size_t elem_idx, size_t head_idx,
20 | size_t feat_idx, size_t seq_len,
21 | size_t num_heads, size_t head_dim) {
22 | if constexpr (layout == QKVLayout::kHND) {
23 | return (head_idx * seq_len + elem_idx) * head_dim + feat_idx;
24 | } else {
25 | return (elem_idx * num_heads + head_idx) * head_dim + feat_idx;
26 | }
27 | }
28 |
29 | template <QKVLayout layout>
30 | struct tensor_info_t {
31 | size_t qo_len;
32 | size_t kv_len;
33 | size_t num_heads;
34 | size_t head_dim;
35 | __host__ __device__ __forceinline__ tensor_info_t(size_t qo_len, size_t kv_len, size_t num_heads,
36 | size_t head_dim)
37 | : qo_len(qo_len), kv_len(kv_len), num_heads(num_heads), head_dim(head_dim) {}
38 |
39 | __host__ __device__ __forceinline__ size_t get_qo_elem_offset(size_t query_idx, size_t head_idx,
40 | size_t feat_idx) const {
41 | return get_elem_offset_impl<layout>(query_idx, head_idx, feat_idx, qo_len, num_heads, head_dim);
42 | }
43 |
44 | __host__ __device__ __forceinline__ size_t get_kv_elem_offset(size_t kv_idx, size_t head_idx,
45 | size_t feat_idx) const {
46 | return get_elem_offset_impl<layout>(kv_idx, head_idx, feat_idx, kv_len, num_heads, head_dim);
47 | }
48 | };
49 |
50 | /*!
51 | * \brief Convert QKVLayout to string
52 | * \param qkv_layout The QKVLayout to convert
53 | */
54 | inline std::string QKVLayoutToString(const QKVLayout &qkv_layout) {
55 | switch (qkv_layout) {
56 | case QKVLayout::kNHD:
57 | return "NHD";
58 | case QKVLayout::kHND:
59 | return "HND";
60 | default:
61 | return "Unknown";
62 | }
63 | }
64 |
65 | } // namespace flashinfer
66 | #endif // FLASHINFER_LAYOUT_CUH_
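The two branches of get_elem_offset_impl are ordinary row-major index arithmetic: kNHD flattens a [seq_len, num_heads, head_dim] tensor and kHND flattens a [num_heads, seq_len, head_dim] tensor. An illustrative check of both formulas:

    import numpy as np

    seq_len, num_heads, head_dim = 5, 3, 4
    flat = np.arange(seq_len * num_heads * head_dim)

    def offset_nhd(e, h, f):  # [seq_len, num_heads, head_dim]
        return (e * num_heads + h) * head_dim + f

    def offset_hnd(e, h, f):  # [num_heads, seq_len, head_dim]
        return (h * seq_len + e) * head_dim + f

    nhd = flat.reshape(seq_len, num_heads, head_dim)
    hnd = flat.reshape(num_heads, seq_len, head_dim)
    assert flat[offset_nhd(2, 1, 3)] == nhd[2, 1, 3]
    assert flat[offset_hnd(2, 1, 3)] == hnd[1, 2, 3]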
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/math.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_MATH_CUH_
2 | #define FLASHINFER_MATH_CUH_
3 |
4 | #include <cuda_runtime.h>
5 | #include <cmath>
6 |
7 | namespace flashinfer {
8 | namespace math {
9 |
10 | constexpr float log2e = M_LOG2E; // 1.44269504088896340736f;
11 |
12 | __forceinline__ __device__ float ptx_exp2(float x) {
13 | float y = exp2f(x);
14 | // asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
15 | return y;
16 | }
17 |
18 | __forceinline__ __device__ float shfl_xor_sync(float x, int delta) {
19 | float y;
20 | asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;" : "=f"(y) : "f"(x), "r"(delta));
21 | return y;
22 | }
23 |
24 | } // namespace math
25 | } // namespace flashinfer
26 | #endif // FLASHINFER_MATH_CUH_
27 |
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/mma.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_MMA_CUH_
2 | #define FLASHINFER_MMA_CUH_
3 |
4 | #include <cuda_bf16.h>
5 | #include <cuda_fp16.h>
6 | #include <cuda_runtime.h>
7 |
8 | #include <type_traits>
9 |
10 | namespace flashinfer {
11 |
12 | namespace mma {
13 |
14 | constexpr size_t frag_size = 16;
15 |
16 | template <typename T>
17 | __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t *R, T *smem_ptr) {
18 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr));
19 | asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
20 | : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
21 | : "r"(smem_int_ptr));
22 | }
23 |
24 | template <typename T>
25 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t *R, T *smem_ptr) {
26 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr));
27 | asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
28 | : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
29 | : "r"(smem_int_ptr));
30 | }
31 |
32 | template <typename T>
33 | __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t *R, T *smem_ptr) {
34 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 11)
35 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr));
36 | asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
37 | : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3]));
38 | #else
39 | // NOTE(Zihao): Not implemented yet.
40 | #endif
41 | }
42 |
43 | template <typename T>
44 | __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A,
45 | uint32_t *B) {
46 | if constexpr (std::is_same<T, half>::value) {
47 | asm volatile(
48 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
49 | "{%0, %1, %2, %3},"
50 | "{%4, %5, %6, %7},"
51 | "{%8, %9},"
52 | "{%10, %11, %12, %13};\n"
53 | : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
54 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
55 | "f"(C[2]), "f"(C[3]));
56 | asm volatile(
57 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
58 | "{%0, %1, %2, %3},"
59 | "{%4, %5, %6, %7},"
60 | "{%8, %9},"
61 | "{%10, %11, %12, %13};\n"
62 | : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
63 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
64 | "f"(C[6]), "f"(C[7]));
65 | } else {
66 | asm volatile(
67 | "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
68 | "{%0, %1, %2, %3},"
69 | "{%4, %5, %6, %7},"
70 | "{%8, %9},"
71 | "{%10, %11, %12, %13};\n"
72 | : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
73 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
74 | "f"(C[2]), "f"(C[3]));
75 | asm volatile(
76 | "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
77 | "{%0, %1, %2, %3},"
78 | "{%4, %5, %6, %7},"
79 | "{%8, %9},"
80 | "{%10, %11, %12, %13};\n"
81 | : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
82 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
83 | "f"(C[6]), "f"(C[7]));
84 | }
85 | }
86 |
87 | } // namespace mma
88 |
89 | } // namespace flashinfer
90 |
91 | #endif // FLASHINFER_MMA_CUH_
92 |
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/page.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_PAGE_CUH_
2 | #define FLASHINFER_PAGE_CUH_
3 |
4 | #include "layout.cuh"
5 | #include "utils.cuh"
6 | #include "vec_dtypes.cuh"
7 | #include "quantization.cuh"
8 |
9 | namespace flashinfer {
10 |
11 | /*!
12 | * \brief Paged key-value cache
13 | * \tparam DType The data type of the key-value cache
14 | * \tparam IdType The index data type of the kv-cache
15 | * \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim]
16 | * \note This layout is kind of HND, which is memory-friendly for the self-attn kernel's tile block.
17 | */
18 | template <typename DType, typename IdType>
19 | struct paged_kv_t {
20 | size_t num_layers;
21 | size_t layer_idx;
22 | size_t num_heads;
23 | size_t page_size;
24 | size_t head_dim;
25 | size_t batch_size;
26 | // [max_num_pages * num_layers * 2 * num_heads * page_size * head_dim]
27 | // The flattened key-value cache
28 | DType* data;
29 | // [max_num_pages * num_layers * 2 * num_heads * page_size * 1]
30 | // The flattened key-value quantization parameter cache
31 | half2* param;
32 | // [batch_size + 1] The page indptr array, with the first element 0
33 | IdType* indptr;
34 | // [nnz_pages] The page indices array
35 | IdType* indices;
36 | // [batch_size] The offset of the last page for each request in the batch
37 | IdType* last_page_offset;
38 | /*!
39 | * \brief Construct a paged key-value cache
40 | * \param num_layers The number of layers
41 | * \param layer_idx The index of the layer
42 | * \param num_heads The number of heads
43 | * \param page_size The size of each page
44 | * \param head_dim The dimension of each head
45 | * \param batch_size The batch size
46 | * \param data The flattened key-value cache
47 | * \param indptr The page indptr array
48 | * \param indices The page indices array
49 | * \param last_page_offset The offset of the last page for each request in the batch
50 | */
51 | __host__ __device__ __forceinline__ paged_kv_t(
52 | size_t num_layers,
53 | size_t layer_idx,
54 | size_t num_heads,
55 | size_t page_size,
56 | size_t head_dim,
57 | size_t batch_size,
58 | DType* data,
59 | half2* param,
60 | IdType* indptr,
61 | IdType* indices,
62 | IdType* last_page_offset
63 | ): num_layers(num_layers),
64 | layer_idx(layer_idx),
65 | num_heads(num_heads),
66 | page_size(page_size),
67 | head_dim(head_dim),
68 | batch_size(batch_size),
69 | data(data),
70 | param(param),
71 | indptr(indptr),
72 | indices(indices),
73 | last_page_offset(last_page_offset) {}
74 |
75 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim]
76 | __host__ __device__ __forceinline__ size_t get_k_elem_offset(size_t page_idx, size_t head_idx,
77 | size_t entry_idx, size_t feat_idx) {
78 | return (((page_idx * num_layers + layer_idx) * 2 * num_heads + head_idx) * page_size +
79 | entry_idx) *
80 | head_dim +
81 | feat_idx;
82 | }
83 |
84 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim]
85 | __host__ __device__ __forceinline__ size_t get_v_elem_offset(size_t page_idx, size_t head_idx,
86 | size_t entry_idx, size_t feat_idx) {
87 | return ((((page_idx * num_layers + layer_idx) * 2 + 1) * num_heads + head_idx) * page_size +
88 | entry_idx) *
89 | head_dim +
90 | feat_idx;
91 | }
92 |
93 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, 1]
94 | __host__ __device__ __forceinline__ size_t get_param_k_elem_offset(size_t page_idx, size_t head_idx,
95 | size_t entry_idx) {
96 | return ((page_idx * num_layers + layer_idx) * 2 * num_heads + head_idx) * page_size + entry_idx;
97 | }
98 |
99 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, 1]
100 | __host__ __device__ __forceinline__ size_t get_param_v_elem_offset(size_t page_idx, size_t head_idx,
101 | size_t entry_idx) {
102 | return (((page_idx * num_layers + layer_idx) * 2 + 1) * num_heads + head_idx) * page_size + entry_idx;
103 | }
104 |
105 | __host__ __device__ __forceinline__ size_t get_valid_page_size(size_t batch_idx, size_t page_iter) {
106 | if (page_iter == indptr[batch_idx + 1] - 1) {
107 | return last_page_offset[batch_idx];
108 | } else {
109 | return page_size;
110 | }
111 | }
112 | };
113 |
114 | /*!
115 |  * \brief: Append a single token to the existing kv cache.
116 |  * \note: Layout of key: [batch_size, num_heads, head_dim]
117 |  * \note: This layout is the natural output of the previous dense layer, so no transpose is needed.
118 | */
119 | template <size_t head_dim, size_t vec_size, size_t bdx, size_t bdy, typename DType, typename IdType>
120 | __global__ void AppendPagedKVCacheDecodeKernel(
121 | paged_kv_t<DType, IdType> paged_kv,
122 | DType* __restrict__ key,
123 | DType* __restrict__ value,
124 | half2* __restrict__ key_param,
125 | half2* __restrict__ value_param
126 | ) {
127 | size_t tx = threadIdx.x, ty = threadIdx.y;
128 | size_t num_heads = paged_kv.num_heads;
129 | size_t batch_idx = blockIdx.x / ((num_heads + bdy - 1) / bdy);
130 | size_t head_idx = (blockIdx.x % ((num_heads + bdy - 1) / bdy)) * bdy + ty;
131 | if (head_idx < num_heads) {
132 | // Enough space for the last page has been pre-allocated
133 | // seq_len includes the token being appended
134 | size_t seq_len =
135 | (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size +
136 | paged_kv.last_page_offset[batch_idx];
137 |
138 | size_t page_idx =
139 | paged_kv.indices[paged_kv.indptr[batch_idx] + (seq_len - 1) / paged_kv.page_size];
140 | size_t entry_idx = (seq_len - 1) % paged_kv.page_size;
141 |
142 | vec_t<DType, vec_size>::memcpy(
143 | quant::get_ptr(paged_kv.data, paged_kv.get_k_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)),
144 | quant::get_ptr(key, (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size));
145 |
146 | vec_t<DType, vec_size>::memcpy(
147 | quant::get_ptr(paged_kv.data, paged_kv.get_v_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)),
148 | quant::get_ptr(value, (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size));
149 |
150 | // Copy the quantization parameters
151 | // One group only copies once
152 | if(tx == 0){
153 | quant::get_ptr(
154 | paged_kv.param,
155 | paged_kv.get_param_k_elem_offset(page_idx, head_idx, entry_idx)
156 | )[0] = key_param[batch_idx * num_heads + head_idx];
157 |
158 | quant::get_ptr(
159 | paged_kv.param,
160 | paged_kv.get_param_v_elem_offset(page_idx, head_idx, entry_idx)
161 | )[0] = value_param[batch_idx * num_heads + head_idx];
162 | }
163 | }
164 | }
165 |
166 | template <size_t head_dim, size_t vec_size, size_t bdx, size_t bdy, typename DType, typename IdType>
167 | __global__ void AppendPagedKVCachePrefillKernel(
168 | paged_kv_t<DType, IdType> paged_kv,
169 | DType* __restrict__ key,
170 | DType* __restrict__ value,
171 | half2* __restrict__ key_param,
172 | half2* __restrict__ value_param,
173 | IdType* __restrict__ append_indptr
174 | ) {
175 | size_t tx = threadIdx.x, ty = threadIdx.y;
176 | size_t num_heads = paged_kv.num_heads;
177 | size_t batch_idx = blockIdx.x / ((num_heads + bdy - 1) / bdy);
178 | size_t head_idx = (blockIdx.x % ((num_heads + bdy - 1) / bdy)) * bdy + ty;
179 | if (head_idx < num_heads) {
180 |
181 | // Pre-filled seq_len
182 | size_t seq_len =
183 | (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size +
184 | paged_kv.last_page_offset[batch_idx];
185 | // Calculated to-be-filled seq_len
186 | size_t append_seq_len = append_indptr[batch_idx + 1] - append_indptr[batch_idx];
187 | size_t append_start = seq_len - append_seq_len;
188 |
189 | #pragma unroll 2
190 | for (size_t j = 0; j < append_seq_len; ++j) {
191 | size_t page_seq_idx = j + append_start;
192 | size_t page_idx =
193 | paged_kv.indices[paged_kv.indptr[batch_idx] + page_seq_idx / paged_kv.page_size];
194 | size_t entry_idx = page_seq_idx % paged_kv.page_size;
195 |
196 | vec_t<DType, vec_size>::memcpy(
197 | quant::get_ptr(paged_kv.data, paged_kv.get_k_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)),
198 | quant::get_ptr(key, ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size));
199 |
200 | vec_t<DType, vec_size>::memcpy(
201 | quant::get_ptr(paged_kv.data, paged_kv.get_v_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)),
202 | quant::get_ptr(value, ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size));
203 |
204 | // Copy the quantization parameters
205 | // One group only copies once
206 | if(tx == 0){
207 | quant::get_ptr(
208 | paged_kv.param,
209 | paged_kv.get_param_k_elem_offset(page_idx, head_idx, entry_idx)
210 | )[0] = key_param[(append_indptr[batch_idx] + j) * num_heads + head_idx];
211 |
212 | quant::get_ptr(
213 | paged_kv.param,
214 | paged_kv.get_param_v_elem_offset(page_idx, head_idx, entry_idx)
215 | )[0] = value_param[(append_indptr[batch_idx] + j) * num_heads + head_idx];
216 | }
217 | }
218 | }
219 | }
220 |
221 | template <typename DType, typename IdType>
222 | cudaError_t AppendPagedKVCacheDecode(
223 | paged_kv_t<DType, IdType> paged_kv,
224 | DType* key,
225 | DType* value,
226 | half2* key_param,
227 | half2* value_param,
228 | cudaStream_t stream = nullptr,
229 | size_t dev_id = 0
230 | ) {
231 | FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id));
232 | size_t head_dim = paged_kv.head_dim;
233 | size_t batch_size = paged_kv.batch_size;
234 | size_t num_heads = paged_kv.num_heads;
235 | SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
236 | constexpr size_t vec_size = std::max(static_cast<size_t>(16 / quant::size_of_type<DType>()), HEAD_DIM / 32);
237 | constexpr size_t bdx = HEAD_DIM / vec_size;
238 | constexpr size_t bdy = 128 / bdx;
239 | assert(num_heads % bdy == 0);
240 | dim3 nblks(batch_size * num_heads / bdy);
241 | dim3 nthrs(bdx, bdy);
242 | auto kernel = AppendPagedKVCacheDecodeKernel<HEAD_DIM, vec_size, bdx, bdy, DType, IdType>;
243 | void* args[] = {
244 | (void*)&paged_kv,
245 | (void*)&key,
246 | (void*)&value,
247 | (void*)&key_param,
248 | (void*)&value_param
249 | };
250 | FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
251 | });
252 | return cudaSuccess;
253 | }
254 |
255 | template <typename DType, typename IdType>
256 | cudaError_t AppendPagedKVCachePrefill(
257 | paged_kv_t<DType, IdType> paged_kv,
258 | DType* key,
259 | DType* value,
260 | half2* key_param,
261 | half2* value_param,
262 | IdType* append_indptr,
263 | cudaStream_t stream = nullptr,
264 | size_t dev_id = 0
265 | ) {
266 | FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id));
267 | size_t head_dim = paged_kv.head_dim;
268 | size_t batch_size = paged_kv.batch_size;
269 | size_t num_heads = paged_kv.num_heads;
270 | SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
271 | constexpr size_t vec_size = std::max(static_cast<size_t>(16 / quant::size_of_type<DType>()), HEAD_DIM / 32);
272 | constexpr size_t bdx = HEAD_DIM / vec_size;
273 | constexpr size_t bdy = 128 / bdx;
274 | assert(num_heads % bdy == 0);
275 | dim3 nblks(batch_size * num_heads / bdy);
276 | dim3 nthrs(bdx, bdy);
277 | auto kernel = AppendPagedKVCachePrefillKernel<HEAD_DIM, vec_size, bdx, bdy, DType, IdType>;
278 | void* args[] = {
279 | (void*)&paged_kv,
280 | (void*)&key,
281 | (void*)&value,
282 | (void*)&key_param,
283 | (void*)&value_param,
284 | (void*)&append_indptr
285 | };
286 | FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
287 | });
288 | return cudaSuccess;
289 | }
290 |
291 | } // namespace flashinfer
292 |
293 | #endif // FLASHINFER_PAGE_CUH_
294 |
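get_k_elem_offset and get_v_elem_offset are row-major indexing into the flat [max_num_pages, num_layers, 2, num_heads, page_size, head_dim] buffer, with keys at index 0 and values at index 1 of the K/V axis. An illustrative consistency check of the two formulas (sizes are hypothetical):

    import numpy as np

    max_pages, num_layers, num_heads, page_size, head_dim = 3, 2, 4, 8, 16
    layer_idx = 1
    flat = np.arange(max_pages * num_layers * 2 * num_heads * page_size * head_dim)
    cache = flat.reshape(max_pages, num_layers, 2, num_heads, page_size, head_dim)

    def k_off(page, head, entry, feat):
        return (((page * num_layers + layer_idx) * 2 * num_heads + head) * page_size
                + entry) * head_dim + feat

    def v_off(page, head, entry, feat):
        return ((((page * num_layers + layer_idx) * 2 + 1) * num_heads + head) * page_size
                + entry) * head_dim + feat

    assert flat[k_off(2, 3, 5, 7)] == cache[2, layer_idx, 0, 3, 5, 7]
    assert flat[v_off(2, 3, 5, 7)] == cache[2, layer_idx, 1, 3, 5, 7]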
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/permuted_smem.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_PERMUTED_SMEM_CUH_
2 | #define FLASHINFER_PERMUTED_SMEM_CUH_
3 |
4 | #include
5 | #include
6 | #include
7 |
8 | #include
9 |
10 | #include "cp_async.cuh"
11 | #include "mma.cuh"
12 |
13 | namespace flashinfer {
14 |
15 | // Each bank is 4 bytes.
16 | using bank_t = uint4;
17 |
18 | template
19 | constexpr __host__ __device__ __forceinline__ size_t bank_capacity() {
20 | return sizeof(bank_t) / sizeof(T);
21 | }
22 |
23 | namespace permuted_smem_impl {
24 |
25 | /*!
26 | * \brief Compute the address of the element at (i, j) in the permuted shared memory.
27 | * \tparam stride The number of banks per row in the permuted shared memory.
28 | * \param smem_base The base address of the permuted shared memory.
29 | * \param i The row index.
30 | * \param j The column (bank) index.
31 | * \note The permuted shared memory maps 8x4 block in logical space to 4x8 block in physical space.
32 | * \see GTC 2020: Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100.
33 | */
34 | template
35 | __host__ __device__ __forceinline__ bank_t *get_smem_ptr(T *smem_base, size_t i, size_t j) {
36 | return ((bank_t *)smem_base) + (i / 2) * stride * 2 + (j / 4) * 8 + (i % 2) * 4 +
37 | ((j % 4) ^ ((i / 2) % 4));
38 | }
39 |
40 | template
41 | __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t *R, T *smem_base, size_t i, size_t j) {
42 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j);
43 | mma::ldmatrix_m8n8x4(R, smem_ptr);
44 | }
45 |
46 | template
47 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t *R, T *smem_base, size_t i,
48 | size_t j) {
49 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j);
50 | mma::ldmatrix_m8n8x4_trans(R, smem_ptr);
51 | }
52 |
53 | template
54 | __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t *R, T *smem_base, size_t i, size_t j) {
55 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j);
56 | mma::stmatrix_m8n8x4(R, smem_ptr);
57 | }
58 |
59 | template
60 | __device__ __forceinline__ void load_bank(T *smem_base, size_t i, size_t j, const T *gptr) {
61 | *get_smem_ptr(smem_base, i, j) = *reinterpret_cast(gptr);
62 | }
63 |
64 | template
65 | __device__ __forceinline__ void store_bank(T *smem_base, size_t i, size_t j, T *gptr) {
66 | *reinterpret_cast(gptr) = *get_smem_ptr(smem_base, i, j);
67 | }
68 |
69 | template
70 | __device__ __forceinline__ void load_bank_async(T *smem_base, size_t i, size_t j, const T *gptr,
71 | bool predicate) {
72 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j);
73 | cp_async::pred_load_128(smem_ptr, reinterpret_cast(gptr), predicate);
74 | }
75 |
76 | template
77 | __device__ __forceinline__ void load_bank_async(T *smem_base, size_t i, size_t j, const T *gptr) {
78 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j);
79 | cp_async::load_128(smem_ptr, reinterpret_cast(gptr));
80 | }
81 |
82 | } // namespace permuted_smem_impl
83 |
84 | template
85 | struct permuted_smem_t {
86 | T __align__(16) * base;
87 | __device__ __forceinline__ permuted_smem_t() : base(nullptr) {}
88 | __device__ __forceinline__ permuted_smem_t(T *base) : base(base) {}
89 | __device__ __forceinline__ bank_t *get_ptr(size_t i, size_t j) {
90 | return permuted_smem_impl::get_smem_ptr(base, i, j);
91 | }
92 | __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t *R, size_t i, size_t j) {
93 | permuted_smem_impl::ldmatrix_m8n8x4(R, base, i, j);
94 | }
95 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t *R, size_t i, size_t j) {
96 | permuted_smem_impl::ldmatrix_m8n8x4_trans(R, base, i, j);
97 | }
98 | __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t *R, size_t i, size_t j) {
99 | permuted_smem_impl::stmatrix_m8n8x4(R, base, i, j);
100 | }
101 | __device__ __forceinline__ void load_bank(size_t i, size_t j, const T *gptr) {
102 | permuted_smem_impl::load_bank(base, i, j, gptr);
103 | }
104 | __device__ __forceinline__ void load_bank_async(size_t i, size_t j, const T *gptr,
105 | bool predicate) {
106 | permuted_smem_impl::load_bank_async(base, i, j, gptr, predicate);
107 | }
108 | __device__ __forceinline__ void load_bank_async(size_t i, size_t j, const T *gptr) {
109 | permuted_smem_impl::load_bank_async(base, i, j, gptr);
110 | }
111 | __device__ __forceinline__ void store_bank(size_t i, size_t j, T *gptr) {
112 | permuted_smem_impl::store_bank(base, i, j, gptr);
113 | }
114 | };
115 |
116 | } // namespace flashinfer
117 |
118 | #endif // FLASHINFER_PERMUTED_SMEM_CUH_
119 |
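get_smem_ptr implements the permuted layout described in the comment above: a logical 8x4 tile of 16-byte banks is remapped to a 4x8 physical tile through the XOR term, and the mapping is a bijection, so no two logical slots collide. A quick illustrative check of the index formula (assuming stride = 4 banks per logical row):

    def smem_bank(i, j, stride=4):
        # mirrors get_smem_ptr: offset measured in 16-byte bank_t units
        return ((i // 2) * stride * 2 + (j // 4) * 8 + (i % 2) * 4
                + ((j % 4) ^ ((i // 2) % 4)))

    offsets = {smem_bank(i, j) for i in range(8) for j in range(4)}
    assert offsets == set(range(32))   # the 8x4 logical tile maps onto 32 distinct slots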
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/quantization.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_QUANTIZATION_CUH_
2 | #define FLASHINFER_QUANTIZATION_CUH_
3 |
4 | #include <cuda_fp16.h>
5 | #include <cuda_runtime.h>
6 | #include <type_traits>
7 |
8 | #include "vec_dtypes.cuh"
9 |
10 | namespace flashinfer{
11 | namespace quant{
12 |
13 | /*!
14 | * \brief Identifier used for 4bit quantization,
15 |  *   for data size calculation.
16 | */
17 | struct __precision__s4{};
18 |
19 | /*!
20 |  * \brief Similar to sizeof, but returns the element size in bytes as a float (0.5 for 4-bit types)
21 |  * \tparam T Data type whose size is queried
22 | */
23 | template <typename T>
24 | FLASHINFER_INLINE constexpr float size_of_type(){
25 | if constexpr (std::is_same<T, __precision__s4>::value){
26 | return 0.5f;
27 | }else{
28 | return sizeof(T);
29 | }
30 | }
31 |
32 | /*!
33 | * \brief Used to get the pointer by offset.
34 | * \tparam T A template indicates the data type
35 | * \param ptr Pointer to the data
36 | * \param offset Offset to the pointer
37 | */
38 | template <typename T>
39 | FLASHINFER_INLINE T* get_ptr(T* ptr, const size_t offset){
40 | if constexpr (std::is_same<T, __precision__s4>::value){
41 | return reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(ptr) + offset / 2);
42 | }else if constexpr (std::is_same<T, const __precision__s4>::value){
43 | // Patch for const qualifiers
44 | return reinterpret_cast<T*>(reinterpret_cast<const uint8_t*>(ptr) + offset / 2);
45 | }else{
46 | return ptr + offset;
47 | }
48 | }
49 |
50 | /*!
51 | * \brief Dequantize the input into vec_t
52 | * \tparam src_float_t A template indicates the quantization data type
53 | * \tparam vec_size A template integer indicates the vector size
54 | * \param src Const input data
55 | * \param tgt Output data
56 | * \param scale Quantization parameter
57 | * \param zero_point Quantization parameter
58 | */
59 | template
60 | FLASHINFER_INLINE void dequantize_impl(
61 | const vec_t &src,
62 | vec_t &tgt,
63 | float scale,
64 | float zero_point
65 | ){
66 | if constexpr (std::is_same<src_float_t, __precision__s4>::value){
67 | // 4bit asymmetric quantization
68 | static_assert(vec_size % 8 == 0, "32bits pack 8 u4 elements.");
69 | // 8 x s4 in int32_t register
70 | constexpr size_t PACK_NUM = 8;
71 | #pragma unroll
72 | for(int i = 0;i < vec_size / PACK_NUM;++i){
73 | uint32_t packedValue = src.at(i);
74 | #pragma unroll
75 | for(int j = 0; j < PACK_NUM;++j){
76 | float unpackedValue = static_cast<float>(packedValue & 0xf) * scale - zero_point;
77 | tgt[i * PACK_NUM + j] = unpackedValue;
78 | packedValue >>= 4;
79 | }
80 | }
81 | }else{
82 | // Not implemented
83 | }
84 | }
85 | } // namespace quant
86 | } // namespace flashinfer
87 |
88 | #endif // FLASHINFER_QUANTIZATION_CUH_
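dequantize_impl walks each packed 32-bit word, peels off eight 4-bit fields from the low end, and maps every field through value * scale - zero_point. The same inner loop in Python (illustrative; fields are treated as unsigned 0..15 exactly as the `& 0xf` above does, with any signed offset folded into zero_point):

    def dequant_s4_word(packed, scale, zero_point):
        # packed: one 32-bit word holding 8 x 4-bit fields, lowest nibble first
        out = []
        for _ in range(8):
            out.append(float(packed & 0xF) * scale - zero_point)
            packed >>= 4
        return out

    word = sum(v << (4 * k) for k, v in enumerate(range(8)))  # nibbles 0..7
    print(dequant_s4_word(word, scale=0.1, zero_point=0.8))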
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/rope.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_ROPE_CUH_
2 | #define FLASHINFER_ROPE_CUH_
3 |
4 | #include <string>
5 |
6 | namespace flashinfer {
7 |
8 | /*!
9 | * \brief An enumeration class that defines different modes for applying RoPE
10 | * (Rotary Positional Embeddings).
11 | */
12 | enum class RotaryMode {
13 | // No rotary positional embeddings
14 | kNone = 0U,
15 | // Apply Llama-style rope.
16 | kLlama = 1U,
17 | };
18 |
19 | /*!
20 | * \brief Convert RotaryMode to string
21 | * \param rotary_mode A RotaryMode value
22 | */
23 | inline std::string RotaryModeToString(const RotaryMode &rotary_mode) {
24 | switch (rotary_mode) {
25 | case RotaryMode::kNone:
26 | return "None";
27 | case RotaryMode::kLlama:
28 | return "Llama";
29 | default:
30 | return "Unknown";
31 | }
32 | }
33 |
34 | } // namespace flashinfer
35 |
36 | #endif // FLASHINFER_ROPE_CUH_
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/state.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_STATE_CUH_
2 | #define FLASHINFER_STATE_CUH_
3 |
4 | #include "math.cuh"
5 | #include "vec_dtypes.cuh"
6 |
7 | namespace flashinfer {
8 |
9 | /*!
10 | * \brief The flashattention state.
11 | * \tparam vec_size The size of the vector used in o.
12 | * \tparam norm_on_the_fly Whether to normalize the state on the fly. If true, the state will be
13 | * normalized when merge() is called. If false, the state will be normalized when normalize() is
14 | * called.
15 | */
16 | template <size_t vec_size, bool norm_on_the_fly>
17 | struct state_t {
18 | vec_t<float, vec_size> o; /* the weighted sum of v: exp(pre-softmax logit - m) * v / d */
19 | float m; /* maximum value of pre-softmax logits */
20 | float d; /* sum of exp(pre-softmax logits - m) */
21 |
22 | __device__ __forceinline__ void init() {
23 | o.fill(0.f);
24 | m = -5e4;
25 | d = 0.f;
26 | }
27 |
28 | __device__ __forceinline__ state_t() { init(); }
29 |
30 | /*!
31 | * \brief Merge the state with another state.
32 | * \param other_m The maximum value of pre-softmax logits of the other state.
33 | * \param other_d The sum of exp(pre-softmax logits - m) of the other state.
34 | * \param other_o The weighted sum of v of the other state.
35 | */
36 | __device__ __forceinline__ void merge(const vec_t<float, vec_size> &other_o, float other_m,
37 | float other_d) {
38 | float m_prev = m, d_prev = d;
39 | m = max(m_prev, other_m);
40 | d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m);
41 | if constexpr (norm_on_the_fly) {
42 | #pragma unroll
43 | for (size_t i = 0; i < vec_size; ++i) {
44 | o[i] = o[i] * math::ptx_exp2(m_prev - m) * (d_prev / d) +
45 | other_o[i] * math::ptx_exp2(other_m - m) * (other_d / d);
46 | }
47 | } else {
48 | #pragma unroll
49 | for (size_t i = 0; i < vec_size; ++i) {
50 | o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m);
51 | }
52 | }
53 | }
54 |
55 | /*!
56 | * \brief Merge the state with another state.
57 | * \param other The other state.
58 | */
59 | __device__ __forceinline__ void merge(const state_t &other) {
60 | merge(other.o, other.m, other.d);
61 | }
62 |
63 | /*!
64 | * \brief Merge the state with a single pre-softmax logit and value vector.
65 | * \param x The pre-softmax logit.
66 |  * \param other_o The value vector.
67 | */
68 | __device__ __forceinline__ void merge(const vec_t<float, vec_size> &other_o, float x) {
69 | float m_prev = m, d_prev = d;
70 | m = max(m_prev, x);
71 | d = d * math::ptx_exp2(m_prev - m) + math::ptx_exp2(x - m);
72 | if constexpr (norm_on_the_fly) {
73 | #pragma unroll
74 | for (size_t i = 0; i < vec_size; ++i) {
75 | o[i] = o[i] * (math::ptx_exp2(m_prev - m) * d_prev / d) +
76 | other_o[i] * (math::ptx_exp2(x - m) / d);
77 | }
78 | } else {
79 | #pragma unroll
80 | for (size_t i = 0; i < vec_size; ++i) {
81 | o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(x - m);
82 | }
83 | }
84 | }
85 |
86 | __device__ __forceinline__ void normalize() {
87 | if constexpr (!norm_on_the_fly) {
88 | // only normalize by d when not normalized on the fly
89 | #pragma unroll
90 | for (size_t i = 0; i < vec_size; ++i) {
91 | o[i] = __fdividef(o[i], d);
92 | }
93 | }
94 | }
95 | };
96 |
97 | } // namespace flashinfer
98 |
99 | #endif // FLASHINFER_STATE_CUH_
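merge() is the standard online-softmax update: keep the running maximum m, the running denominator d, and the running (un-normalized) weighted sum o, rescaling both operands by exp(old_max - new_max) before adding. An illustrative NumPy version of the two-state merge, using natural exp instead of the base-2 ptx_exp2 shortcut (equivalent up to the log2(e) scaling applied to the logits elsewhere):

    import numpy as np

    def merge(o1, m1, d1, o2, m2, d2):
        m = max(m1, m2)
        d = d1 * np.exp(m1 - m) + d2 * np.exp(m2 - m)
        o = o1 * np.exp(m1 - m) + o2 * np.exp(m2 - m)   # normalize by d at the end
        return o, m, d

    def partial_state(logits, values):
        m = logits.max()
        w = np.exp(logits - m)
        return w @ values, m, w.sum()

    logits, values = np.random.randn(6), np.random.randn(6, 4)
    o, m, d = merge(*partial_state(logits[:3], values[:3]),
                    *partial_state(logits[3:], values[3:]))
    w_full = np.exp(logits - logits.max())
    assert np.allclose(o / d, (w_full @ values) / w_full.sum())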
--------------------------------------------------------------------------------
/quarot/kernels/include/flashinfer/utils.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FLASHINFER_UTILS_CUH_
2 | #define FLASHINFER_UTILS_CUH_
3 | #include <iostream>
4 |
5 | #include "layout.cuh"
6 | #include "rope.cuh"
7 |
8 | #define FLASHINFER_CUDA_CALL(func, ...) \
9 | { \
10 | cudaError_t e = (func); \
11 | if (e != cudaSuccess) { \
12 | return e; \
13 | } \
14 | }
15 |
16 | #define SWITCH_LAYOUT(layout, LAYOUT, ...) \
17 | switch (layout) { \
18 | case QKVLayout::kNHD: { \
19 | constexpr QKVLayout LAYOUT = QKVLayout::kNHD; \
20 | __VA_ARGS__ \
21 | break; \
22 | } \
23 | case QKVLayout::kHND: { \
24 | constexpr QKVLayout LAYOUT = QKVLayout::kHND; \
25 | __VA_ARGS__ \
26 | break; \
27 | } \
28 | default: { \
29 | std::cerr << "Unsupported qkv_layout: " << int(layout) << std::endl; \
30 | abort(); \
31 | } \
32 | }
33 |
34 | #define SWITCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
35 | switch (head_dim) { \
36 | case 64: { \
37 | constexpr size_t HEAD_DIM = 64; \
38 | __VA_ARGS__ \
39 | break; \
40 | } \
41 | case 128: { \
42 | constexpr size_t HEAD_DIM = 128; \
43 | __VA_ARGS__ \
44 | break; \
45 | } \
46 | case 256: { \
47 | constexpr size_t HEAD_DIM = 256; \
48 | __VA_ARGS__ \
49 | break; \
50 | } \
51 | default: { \
52 | std::cerr << "Unsupported head_dim: " << head_dim << std::endl; \
53 | abort(); \
54 | } \
55 | }
56 |
57 | #define SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \
58 | switch (rotary_mode) { \
59 | case RotaryMode::kNone: { \
60 | constexpr RotaryMode ROTARY_MODE = RotaryMode::kNone; \
61 | __VA_ARGS__ \
62 | break; \
63 | } \
64 | case RotaryMode::kLlama: { \
65 | constexpr RotaryMode ROTARY_MODE = RotaryMode::kLlama; \
66 | __VA_ARGS__ \
67 | break; \
68 | } \
69 | default: { \
70 | std::cerr << "Unsupported rotary_mode: " << int(rotary_mode) << std::endl; \
71 | abort(); \
72 | } \
73 | }
74 |
75 | #endif // FLASHINFER_UTILS_CUH_
--------------------------------------------------------------------------------
/quarot/kernels/include/gemm.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <common.h>
4 |
5 |
6 | void matmul_host(
7 | const Int4Storage *A,
8 | const Int4Storage *B,
9 | uint32_t M,
10 | uint32_t N,
11 | uint32_t K,
12 | int32_t *C
13 | );
--------------------------------------------------------------------------------
/quarot/kernels/include/int4.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <cutlass/numeric_types.h>
3 |
4 | namespace cutlass {
5 |
6 | template <
7 |     typename Element_,           /// CUTLASS numeric element type
8 |     typename Storage_ = uint8_t  /// Underlying basic storage type
9 | >
10 | class MySubbyteReference {
11 | public:
12 | using Element = Element_;
13 | using Storage = Storage_;
14 | using StoragePointer = Storage *;
15 |
16 | static_assert(sizeof_bits<Element>::value <= sizeof_bits<Storage>::value,
17 | "Size of Element must not be greater than Storage.");
18 |
19 | static_assert(!(sizeof_bits<Storage>::value % sizeof_bits<Element>::value),
20 | "Storage must be divisible by Element");
21 |
22 | constexpr static int const kElementsPerVector =
23 | sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
24 |
25 | private:
26 | ///! Number of elements per storage vector
27 |
28 | ///! Bit mask
29 | Storage const kMask =
30 | ((sizeof_bits<Element>::value < sizeof_bits<Storage>::value)
31 | ? (Storage(1) << sizeof_bits<Element>::value) - Storage(1)
32 | : ~Storage(0));
33 |
34 | private:
35 | /// Pointer to array containing element
36 | StoragePointer ptr_;
37 |
38 | /// Offset (in units of elements) from pointer.
39 | ///
40 | /// Invariant: must always be in range [0, kElementsPerVector)
41 | int offset_;
42 |
43 | public:
44 | CUTLASS_HOST_DEVICE
45 | MySubbyteReference() : ptr_(nullptr), offset_(0) {}
46 |
47 | /// Constructor
48 | CUTLASS_HOST_DEVICE
49 | MySubbyteReference(Element *ptr, /// pointer to memory
50 | int64_t offset /// logical offset in units of Element
51 | )
52 | : ptr_(reinterpret_cast(ptr)), offset_(0) {
53 | int64_t offset_in_vectors = offset / kElementsPerVector;
54 | int64_t offset_in_elements = offset % kElementsPerVector;
55 |
56 | ptr_ += offset_in_vectors;
57 | offset_ = int(offset_in_elements);
58 | }
59 |
60 | /// Constructor
61 | CUTLASS_HOST_DEVICE
62 | MySubbyteReference(Element *ptr = nullptr) : MySubbyteReference(ptr, 0) {}
63 |
64 | /// Gets storage pointer
65 | CUTLASS_HOST_DEVICE
66 | StoragePointer storage_pointer() const { return ptr_; }
67 |
68 | /// Gets storage pointer
69 | CUTLASS_HOST_DEVICE
70 | Element *operator&() const { return reinterpret_cast(ptr_); }
71 |
72 | /// Gets element offset within storage vector
73 | CUTLASS_HOST_DEVICE
74 | int element_offset() const { return offset_; }
75 |
76 | /// Unpacks an element from memory
77 | CUTLASS_HOST_DEVICE
78 | Element get() const {
79 | Storage item =
80 | Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask);
81 | return reinterpret_cast(item);
82 | }
83 |
84 | /// Stores an element to memory
85 | CUTLASS_HOST_DEVICE
86 | MySubbyteReference &set(Element const &x) {
87 | Storage item = (reinterpret_cast(x) & kMask);
88 | Storage kUpdateMask =
89 | Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value)));
90 | Storage new_bits =
91 | Storage(item << (offset_ * cutlass::sizeof_bits::value));
92 |
93 | Storage original = (*ptr_);
94 | Storage updated = Storage((original & kUpdateMask) | new_bits);
95 | *ptr_ = updated;
96 |
97 | return *this;
98 | }
99 |
100 | ////
101 |
102 | /// Unpacks an element from memory
103 | CUTLASS_HOST_DEVICE
104 | operator Element() const { return get(); }
105 |
106 | /// Stores an element to memory
107 | CUTLASS_HOST_DEVICE
108 | MySubbyteReference &operator=(Element const &x) { return set(x); }
109 |
110 | /// Stores an element to memory
111 | CUTLASS_HOST_DEVICE
112 | MySubbyteReference &operator=(MySubbyteReference const &x) {
113 | return set(x.get());
114 | }
115 |
116 | /// Stores an element to memory
117 | CUTLASS_HOST_DEVICE
118 | MySubbyteReference &operator=(
119 | ConstSubbyteReference const &x) {
120 | return set(x.get());
121 | }
122 |
123 | /// Adds an offset in units of elements to the reference
124 | CUTLASS_HOST_DEVICE
125 | MySubbyteReference &operator+=(int offset) {
126 | offset += offset_;
127 |
128 | int offset_in_vectors = offset / kElementsPerVector;
129 | int offset_in_elements = offset % kElementsPerVector;
130 |
131 | ptr_ += offset_in_vectors;
132 | offset_ = offset_in_elements;
133 |
134 | return *this;
135 | }
136 |
137 | /// Adds an offset in units of elements to the reference
138 | CUTLASS_HOST_DEVICE
139 | MySubbyteReference &operator+=(long long offset) {
140 | offset += offset_;
141 |
142 | long long offset_in_vectors = offset / kElementsPerVector;
143 | int offset_in_elements = int(offset % kElementsPerVector);
144 |
145 | ptr_ += offset_in_vectors;
146 | offset_ = offset_in_elements;
147 |
148 | return *this;
149 | }
150 |
151 | /// Adds an offset in units of elements to the reference
152 | CUTLASS_HOST_DEVICE
153 | MySubbyteReference &operator-=(int offset) {
154 | int offset_in_vectors = offset / kElementsPerVector;
155 | int offset_in_elements = offset % kElementsPerVector;
156 |
157 | ptr_ -= offset_in_vectors;
158 | offset_ -= offset_in_elements;
159 |
160 | if (offset_ < 0) {
161 | offset_ += kElementsPerVector;
162 | --ptr_;
163 | }
164 |
165 | return *this;
166 | }
167 |
168 | /// Adds an offset in units of elements to the reference
169 | CUTLASS_HOST_DEVICE
170 | MySubbyteReference &operator-=(long long offset) {
171 | long long offset_in_vectors = offset / kElementsPerVector;
172 | int offset_in_elements = int(offset % kElementsPerVector);
173 |
174 | ptr_ -= offset_in_vectors;
175 | offset_ -= offset_in_elements;
176 |
177 | if (offset_ < 0) {
178 | offset_ += kElementsPerVector;
179 | --ptr_;
180 | }
181 |
182 | return *this;
183 | }
184 |
185 | /// Returns a reference to an element with a given offset from the current
186 | /// reference
187 | CUTLASS_HOST_DEVICE
188 | MySubbyteReference operator+(int offset) const {
189 | MySubbyteReference ref(ptr_, offset_);
190 | ref += offset;
191 |
192 | return ref;
193 | }
194 |
195 | /// Returns a reference to an element with a given offset from the current
196 | /// reference
197 | CUTLASS_HOST_DEVICE
198 | MySubbyteReference operator+(long long offset) const {
199 | MySubbyteReference ref(ptr_, offset_);
200 | ref += offset;
201 |
202 | return ref;
203 | }
204 |
205 | /// Returns a reference to an element with a given offset from the current
206 | /// reference
207 | CUTLASS_HOST_DEVICE
208 | MySubbyteReference operator-(int offset) const {
209 | MySubbyteReference ref(ptr_, offset_);
210 | ref -= offset;
211 |
212 | return ref;
213 | }
214 |
215 | /// Returns a reference to an element with a given offset from the current
216 | /// reference
217 | CUTLASS_HOST_DEVICE
218 | MySubbyteReference operator-=(long long offset) const {
219 | MySubbyteReference ref(ptr_, offset_);
220 | ref -= offset;
221 |
222 | return ref;
223 | }
224 |
225 | /// Computes the difference in elements between references
226 | CUTLASS_HOST_DEVICE
227 | ptrdiff_t operator-(MySubbyteReference ref) const {
228 | return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
229 | }
230 |
231 | /// Explicit cast to int
232 | CUTLASS_HOST_DEVICE
233 | explicit operator int() const { return int(get()); }
234 |
235 | /// Explicit cast to signed 64-bit integer
236 | CUTLASS_HOST_DEVICE
237 | explicit operator int64_t() const { return int64_t(get()); }
238 |
239 | /// Explicit cast to unsigned 64-bit integer
240 | CUTLASS_HOST_DEVICE
241 | explicit operator uint64_t() const { return uint64_t(get()); }
242 |
243 | /// Explicit cast to float
244 | CUTLASS_HOST_DEVICE
245 | explicit operator float() const { return float(get()); }
246 |
247 | /// Explicit cast to double
248 | CUTLASS_HOST_DEVICE
249 | explicit operator double() const { return double(get()); }
250 | };
251 |
252 | } // namespace cutlass
253 |
254 | using Int4Subbyte = cutlass::MySubbyteReference<cutlass::int4b_t, uint8_t>;
255 | using Int4Storage = Int4Subbyte::Storage;
256 | constexpr const uint32_t kElementsPerVector =
257 | cutlass::sizeof_bits::value /
258 | cutlass::sizeof_bits::value;
259 |
--------------------------------------------------------------------------------
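
Note on the packing layout used throughout the kernels: with Int4Storage being a single byte, kElementsPerVector works out to 2, i.e. two signed 4-bit values share one storage byte (low nibble first). The snippet below is a minimal pure-PyTorch sketch of that pack/unpack convention, for illustration only; the helper names do not exist in this repository.

    import torch

    def pack_int4_pairs(q: torch.Tensor) -> torch.Tensor:
        # q holds signed int4 values in [-8, 7]; two values per byte, low nibble first
        assert q.shape[-1] % 2 == 0
        q = (q & 0xF).to(torch.uint8)                 # keep the two's-complement nibble
        return q[..., 0::2] | (q[..., 1::2] << 4)

    def unpack_int4_pairs(p: torch.Tensor) -> torch.Tensor:
        lo = (p & 0x0F).to(torch.int8)
        hi = ((p >> 4) & 0x0F).to(torch.int8)
        q = torch.stack((lo, hi), dim=-1).view(*p.shape[:-1], p.shape[-1] * 2)
        return torch.where(q >= 8, q - 16, q)         # sign-extend the nibble

    x = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
    assert torch.equal(unpack_int4_pairs(pack_int4_pairs(x)), x)
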
/quarot/kernels/include/quant.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <common.h>
4 |
5 |
6 | void sym_quant_host(
7 | const half *x,
8 | const half *scale,
9 | uint32_t rows,
10 | uint32_t colsSrc,
11 | uint32_t colsDst,
12 | Int4Storage *q
13 | );
14 |
15 |
16 | void sym_dequant_host(
17 | const int32_t *q,
18 | const half *scale_row,
19 | const half *scale_col,
20 | uint32_t rows,
21 | uint32_t cols,
22 | half *x
23 | );
--------------------------------------------------------------------------------
/quarot/kernels/include/util.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | /* TODO: This file can be safely discarded. Kept in case needed in the future.
3 |
4 | // #include
5 | // #include
6 | #include
7 | #include
8 | #include
9 | #include
10 |
11 |
12 |
13 | #define _BITS_STDINT_UINTN_H 1
14 | #define _BITS_STDINT_INTN_H 1
15 | #include
16 | typedef __int8_t int8_t;
17 | typedef __int16_t int16_t;
18 | typedef __int32_t int32_t;
19 | typedef __int64_t int64_t;
20 | typedef __uint8_t uint8_t;
21 | typedef __uint16_t uint16_t;
22 | typedef __uint32_t uint32_t;
23 | typedef __uint64_t uint64_t;
24 |
25 |
26 |
27 | template <typename T>
28 | struct TorchDtypeDispatcher;
29 |
30 | template <>
31 | struct TorchDtypeDispatcher<uint8_t> {
32 | constexpr static const auto value = torch::kUInt8;
33 | };
34 |
35 | template <>
36 | struct TorchDtypeDispatcher<int8_t> {
37 | constexpr static const auto value = torch::kInt8;
38 | };
39 |
40 | template <>
41 | struct TorchDtypeDispatcher<int32_t> {
42 | constexpr static const auto value = torch::kInt32;
43 | };
44 |
45 | template <>
46 | struct TorchDtypeDispatcher<__half> {
47 | constexpr static const auto value = torch::kFloat16;
48 | };
49 |
50 | template <c10::ScalarType dtype>
51 | struct DtypeTorchDispatcher;
52 |
53 | template <>
54 | struct DtypeTorchDispatcher<torch::kFloat16> {
55 | using value = __half;
56 | };
57 |
58 | template <>
59 | struct DtypeTorchDispatcher<torch::kBFloat16> {
60 | using value = __nv_bfloat16;
61 | };
62 |
63 | template <typename T>
64 | __device__ inline int type2int_rn(T a) {
65 |   return static_cast<int>(a);
66 | }
67 |
68 | template <>
69 | __device__ inline int type2int_rn<__half>(__half input) {
70 | return __half2int_rn(input);
71 | }
72 |
73 | // template <>
74 | // __device__ inline int type2int_rn<__nv_bfloat16>(__nv_bfloat16 input) {
75 | // return __bfloat162int_rn(input);
76 | // }
77 |
78 | template <typename T>
79 | __device__ inline float type2float(T a) {
80 |   return static_cast<float>(a);
81 | }
82 |
83 | template <>
84 | __device__ inline float type2float<__half>(__half input) {
85 | return __half2float(input);
86 | }
87 |
88 | template <>
89 | __device__ inline float type2float<__nv_bfloat16>(__nv_bfloat16 input) {
90 | return __bfloat162float(input);
91 | }
92 |
93 | template <typename T>
94 | __device__ inline T float2type(float a) {
95 |   return static_cast<T>(a);
96 | }
97 |
98 | template <>
99 | __device__ inline __half float2type<__half>(float input) {
100 | return __float2half(input);
101 | }
102 |
103 | template <>
104 | __device__ inline __nv_bfloat16 float2type<__nv_bfloat16>(float input) {
105 | return __float2bfloat16_rn(input);
106 | }
107 |
108 | template <typename T>
109 | struct DtypeDtype2Dispatcher;
110 |
111 | template <>
112 | struct DtypeDtype2Dispatcher<__half> {
113 | using value = __half2;
114 | };
115 |
116 | template <>
117 | struct DtypeDtype2Dispatcher<__nv_bfloat16> {
118 | using value = __nv_bfloat162;
119 | };
120 |
121 | __device__ inline __half2 type2type2(__half input, __half input2) {
122 | return __halves2half2(input, input2);
123 | }
124 |
125 | // __device__ inline __nv_bfloat162 type2type2(__nv_bfloat16 input,
126 | // __nv_bfloat16 input2) {
127 | // return __halves2bfloat162(input, input2);
128 | // }
129 |
130 | // template <typename T>
131 | // T div(T a, T b) {
132 | // return a / b;
133 | // }
134 | //
135 | // template <>
136 | //__half div(__half a, __half b) {
137 | // return __hdiv(a, b);
138 | // }
139 | //
140 | // template <>
141 | //__nv_bfloat16 div(__nv_bfloat16 a, __nv_bfloat16 b) {
142 | // return __hdiv(a, b);
143 | // }
144 |
145 | */
--------------------------------------------------------------------------------
/quarot/kernels/quant.cu:
--------------------------------------------------------------------------------
1 | #include <quant.h>
2 |
3 |
4 | template <typename T>
5 | __device__ __half int_to_half(T value)
6 | {
7 |     return __int2half_rn(static_cast<int>(value));
8 | }
9 |
10 |
11 | __global__
12 | void sym_quantize_f16_i4_kernel(
13 | const half *__restrict__ x,
14 | const half *__restrict__ scale,
15 | uint32_t rows,
16 | uint32_t colsSrc,
17 | uint32_t colsDst,
18 | Int4Storage *__restrict__ q
19 | )
20 | {
21 | uint32_t row = threadIdx.y + blockIdx.y * blockDim.y;
22 | uint32_t colDst = threadIdx.x + blockIdx.x * blockDim.x;
23 | if (row >= rows || colDst * kElementsPerVector >= colsSrc)
24 | {
25 | return;
26 | }
27 | Int4Storage storage;
28 | memset(&storage, 0, sizeof(storage));
29 | uint32_t id = colDst * kElementsPerVector + row * colsSrc;
30 | #pragma unroll
31 | for (int i = 0; i < kElementsPerVector; ++i)
32 | {
33 | bool safe = (colDst * kElementsPerVector + i) < colsSrc;
34 | if (safe)
35 | {
36 | half data = __hdiv(x[id + i], scale[row]);
37 |
38 | int qval = clamp(__half2int_rn(data), qmin, qmax);
39 |             Int4Subbyte{reinterpret_cast<cutlass::int4b_t *>(&storage), i}.set(
40 |                 qval);
41 | }
42 | }
43 |
44 | q[colDst + row * colsDst] = storage;
45 | }
46 |
47 |
48 | void sym_quant_host(
49 | const half *x,
50 | const half *scale,
51 | uint32_t rows,
52 | uint32_t colsSrc,
53 | uint32_t colsDst,
54 | Int4Storage *q
55 | )
56 | {
57 |
58 |     dim3 block{std::min<uint32_t>(colsDst, 32), std::min<uint32_t>(rows, 16)};
59 | dim3 grid{cdiv(colsDst, block.x), cdiv(rows, block.y)};
60 |     sym_quantize_f16_i4_kernel<<<grid, block>>>(x, scale, rows, colsSrc, colsDst, q);
61 | }
62 |
63 |
64 | __global__ void sym_dequantize_i32_f16_kernel(
65 | const int32_t *__restrict__ q,
66 | const half *__restrict__ scale_row,
67 | const half *__restrict__ scale_col,
68 | uint32_t rows, uint32_t cols,
69 | half *__restrict__ x)
70 | {
71 | uint32_t row = threadIdx.y + blockIdx.y * blockDim.y;
72 | uint32_t col = threadIdx.x + blockIdx.x * blockDim.x;
73 |
74 | if (col >= cols || row >= rows)
75 | {
76 | return;
77 | }
78 |
79 | half xElement = int_to_half(q[col + row * cols]);
80 | x[col + row * cols] = scale_row[row] * scale_col[col] * xElement;
81 | }
82 |
83 | void sym_dequant_host(const int32_t *q,
84 | const half *scale_row,
85 | const half *scale_col,
86 | uint32_t rows,
87 | uint32_t cols,
88 | half *x
89 | )
90 | {
91 |     dim3 block{std::min<uint32_t>(cols, 16), std::min<uint32_t>(rows, 16)};
92 | dim3 grid{cdiv(cols, block.x), cdiv(rows, block.y)};
93 |     sym_dequantize_i32_f16_kernel<<<grid, block>>>(
94 | q,
95 | scale_row, scale_col,
96 | rows, cols, x);
97 | }
98 |
--------------------------------------------------------------------------------
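
For orientation, the arithmetic implemented by the two kernels above can be written in a few lines of PyTorch. This is only a sketch (per-row symmetric quantization, assumed here to clamp to the int4 range [-8, 7], followed by dequantization with a row scale and a column scale); it ignores the packed int4 storage that the CUDA code writes, and the function names are invented for illustration.

    import torch

    def sym_quant_ref(x, scale):
        # per-row symmetric quantization of the activations
        return torch.round(x / scale).clamp_(-8, 7).to(torch.int8)

    def sym_dequant_ref(q, scale_row, scale_col):
        # q plays the role of the int32 GEMM accumulator; scales broadcast over rows / columns
        return scale_row * scale_col * q.to(scale_row.dtype)

    x = torch.randn(4, 16)
    scale = x.abs().amax(dim=-1, keepdim=True) / 7
    q = sym_quant_ref(x, scale)
    x_hat = sym_dequant_ref(q.to(torch.int32), scale, torch.ones(1, 16))
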
/quarot/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear import Linear4bit
2 | from .normalization import RMSNorm
3 | from .quantization import Quantizer
4 | from .hadamard import OnlineHadamard
5 |
--------------------------------------------------------------------------------
/quarot/nn/hadamard.py:
--------------------------------------------------------------------------------
1 | import quarot
2 | import torch
3 |
4 |
5 | class OnlineHadamard(torch.nn.Module):
6 | def __init__(self, hadamard_dim, force_fp32=False):
7 | super().__init__()
8 | self.fp32_had = force_fp32
9 | had_rem_dim, self.rem_dim = quarot.functional.hadamard.get_hadK(hadamard_dim)
10 | if had_rem_dim is not None:
11 | self.register_buffer("had_rem_dim", had_rem_dim)
12 | if not self.fp32_had:
13 | self.had_rem_dim = self.had_rem_dim.to(torch.float16)
14 | else:
15 | self.had_rem_dim = None
16 |
17 | def forward(self, x):
18 | x_dtype = x.dtype
19 | if self.fp32_had:
20 | x = x.float()
21 | x = quarot.functional.matmul_hadU_cuda(x, self.had_rem_dim, self.rem_dim)
22 | x = x.to(x_dtype)
23 | return x
24 |
--------------------------------------------------------------------------------
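
OnlineHadamard applies a normalized Walsh-Hadamard transform at runtime; for a power-of-two dimension n the matrix H_n / sqrt(n) is orthogonal, which is what makes the rotation lossless in exact arithmetic, while the had_rem_dim path handles dimensions whose size is not a power of two. A quick self-contained check of the orthogonality property, independent of the CUDA path:

    import torch

    def hadamard_matrix(n: int) -> torch.Tensor:
        # Sylvester construction; n must be a power of two
        H = torch.ones(1, 1)
        while H.shape[0] < n:
            H = torch.cat([torch.cat([H, H], dim=1),
                           torch.cat([H, -H], dim=1)], dim=0)
        return H

    n = 64
    H = hadamard_matrix(n) / n ** 0.5
    assert torch.allclose(H @ H.T, torch.eye(n), atol=1e-5)   # orthogonal
    x = torch.randn(8, n)
    assert torch.allclose((x @ H) @ H.T, x, atol=1e-4)        # the rotation is invertible
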
/quarot/nn/linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import quarot
4 | import fast_hadamard_transform
5 |
6 |
7 | class ShapeHandler:
8 | def __init__(self, x: torch.Tensor):
9 | self.size_excl_last = x.numel()//x.shape[-1]
10 | self.shape_excl_last = tuple(x.shape[:-1])
11 |
12 | # Keep the last dim unchanged, flatten all previous dims
13 | def flatten(self, x: torch.Tensor):
14 | return x.view(self.size_excl_last, -1)
15 |
16 | # Recover back to the original shape.
17 | def unflatten(self, x: torch.Tensor):
18 | return x.view(self.shape_excl_last + (-1,))
19 |
20 | def unflatten_scale(self, x: torch.Tensor):
21 | return x.view(self.shape_excl_last)
22 |
23 |
24 | class Linear4bit(torch.nn.Module):
25 | def __init__(self, in_features, out_features, bias=False, dtype=torch.float16):
26 | '''
27 | Symmetric 4-bit Linear Layer.
28 | '''
29 | super().__init__()
30 | self.in_features = in_features
31 | self.out_features = out_features
32 | self.register_buffer('weight_scales',
33 | torch.zeros((self.out_features, 1), requires_grad=False))
34 | self.register_buffer('weight', (torch.randint(1, 7, (self.out_features, self.in_features // 2),
35 | # SubByte weight
36 | dtype=torch.uint8, requires_grad=False)))
37 | if bias:
38 | self.register_buffer('bias', torch.zeros((self.out_features), dtype=dtype))
39 | else:
40 | self.bias = None
41 |
42 | def forward(self, x):
43 | #if torch.cuda.current_device() != x.device:
44 | # torch.cuda.set_device(x.device)
45 |
46 | assert type(x) == quarot.PackedQuantizedTensor #Quantized input is given
47 | x, scales_x = x.quantized_x, x.scales_x
48 | #shape_handler = ShapeHandler(quantized_x)
49 | #quantized_x = shape_handler.flatten(quantized_x)
50 | x = quarot.matmul(x, self.weight)
51 | #out = shape_handler.unflatten(
52 | # quarot.sym_dequant(int_result, scales_x, self.weight_scales))
53 | if self.bias is not None:
54 | return quarot.sym_dequant(x, scales_x, self.weight_scales) + self.bias
55 | else:
56 | return quarot.sym_dequant(x, scales_x, self.weight_scales)
57 |
58 | @staticmethod
59 |     def from_float(module: torch.nn.Linear, weight_scales=None):
60 |         '''
61 |         Generate a new Linear4bit module from an FP16 Linear module.
62 |         weight_scales must have shape (out_features, 1); the FP16 weights are divided by these scales,
63 |         rounded with torch.round(), packed into a sub-byte (int4) representation, and stored in the 'weight' buffer.
64 |         '''
65 | weight_matrix = module.weight.data
66 |
67 |
68 | int_module = Linear4bit(module.in_features, module.out_features, bias=module.bias is not None, dtype=weight_matrix.dtype).to(weight_matrix.dtype)
69 | if weight_scales is not None:
70 | assert weight_scales.shape == (module.out_features, 1), 'weight_scales should have shape (out_features, 1)'
71 | weight_matrix = weight_matrix.cuda()
72 | int_module.weight_scales.copy_(weight_scales.to(weight_matrix.dtype))
73 | int_rounded_weight = (weight_matrix/weight_scales.cuda()).round()
74 | int_module.weight.copy_(quarot.functional.pack_i4(int_rounded_weight.to(torch.int8)).cpu())
75 |
76 | if module.bias is not None:
77 | int_module.bias.copy_(module.bias)
78 |
79 | return int_module
80 |
--------------------------------------------------------------------------------
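
A usage sketch for Linear4bit.from_float. The per-output-channel scale below (max-abs divided by 7) is one reasonable choice rather than something this file prescribes, and running the forward pass requires a CUDA device with the quarot._CUDA extension built.

    import torch
    from quarot.nn import Linear4bit, Quantizer

    lin = torch.nn.Linear(4096, 4096, bias=False).half().cuda()
    # illustrative per-output-channel symmetric scales, shape (out_features, 1)
    weight_scales = lin.weight.data.abs().amax(dim=1, keepdim=True) / 7
    qlin = Linear4bit.from_float(lin, weight_scales=weight_scales).cuda()

    quantizer = Quantizer()
    x = torch.randn(1, 16, 4096, dtype=torch.float16, device="cuda")
    y = qlin(quantizer(x))   # Quantizer wraps x in a PackedQuantizedTensor first
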
/quarot/nn/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class RMSNorm(torch.nn.Module):
4 | """
5 | This class implements the Root Mean Square Normalization (RMSN) layer.
7 |     We use the implementation from LlamaRMSNorm here:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75
8 | """
9 |
10 | def __init__(self, mean_dim: int, eps=1e-5):
11 | super().__init__()
12 | self.eps = eps
13 | self.mean_dim = mean_dim
14 |
15 | def forward(self, x: torch.Tensor) -> torch.Tensor:
16 | input_dtype = x.dtype
17 | if x.dtype == torch.float16:
18 | x = x.to(torch.float32)
19 | variance = x.pow(2).sum(-1, keepdim=True) / self.mean_dim
20 | x = x * torch.rsqrt(variance + self.eps)
21 | return x.to(input_dtype)
22 |
--------------------------------------------------------------------------------
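
In formula form, the layer computes y = x * rsqrt(mean(x^2) + eps) over the last dimension (mean_dim is the size of that dimension), up-casting fp16 inputs to fp32 first; there is no learned scale because that scale is expected to have been folded into neighbouring layers. A minimal numeric check of the identity used in forward():

    import torch

    x = torch.randn(2, 8, 4096, dtype=torch.float16).float()
    eps, mean_dim = 1e-5, x.shape[-1]
    variance = x.pow(2).sum(-1, keepdim=True) / mean_dim      # same as x.pow(2).mean(-1)
    y = x * torch.rsqrt(variance + eps)
    assert torch.allclose(y, x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
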
/quarot/nn/quantization.py:
--------------------------------------------------------------------------------
1 | import quarot
2 | import torch
3 |
4 | class Quantizer(torch.nn.Module):
5 | def __init__(self, input_clip_ratio=1.0):
6 | super().__init__()
7 | self.input_clip_ratio = input_clip_ratio
8 |
9 | def forward(self, x):
10 | scales_x = (torch.max(torch.abs(x), dim=-1)[0].unsqueeze(1)/7).to(torch.float16) * self.input_clip_ratio
11 | quantized_x = quarot.sym_quant(x, scales_x)
12 | packed_tensor = quarot.PackedQuantizedTensor(quantized_x, scales_x)
13 | return packed_tensor
14 |
--------------------------------------------------------------------------------
/quarot/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | from .kv_cache import MultiLayerPagedKVCache4Bit
2 |
--------------------------------------------------------------------------------
/quarot/transformers/kv_cache.py:
--------------------------------------------------------------------------------
1 | from transformers.cache_utils import Cache
2 | from typing import Optional, Tuple, Dict, Any
3 | import math
4 | import torch
5 | from .. import _CUDA
6 | import functools
7 | from fast_hadamard_transform import hadamard_transform
8 | from quarot.functional.quantization import get_minq_maxq
9 |
10 | @torch.jit.script
11 | def asym_quantize_and_pack_i4(x: torch.Tensor):
12 | minq, maxq = get_minq_maxq(bits=4, sym=False)
13 | xmax = torch.amax(x, dim=-1, keepdim=True)
14 | xmin = torch.amin(x, dim=-1, keepdim=True)
15 | scale = ((xmax - xmin).clamp(min=1e-5) / maxq)
16 | zero = -xmin
17 | q = torch.clamp(torch.round((x + zero) / scale), 0, maxq)
18 |
19 | # pack int4
20 | q = q.to(dtype=torch.uint8)
21 | q = q[..., 0::2] | (q[..., 1::2] << 4)
22 | return q, scale, zero
23 |
24 | def unpack_i4_and_asym_dequantize(q, scale, zero):
25 | #unpack int4
26 | assert q.dtype == torch.uint8
27 | q = torch.stack((q & 0x0f, (q >> 4) & 0x0f), dim=-1).view(*q.shape[:-1], q.shape[-1] * 2)
28 | return q * scale - zero
29 |
30 | def matmul_had_cuda(X, dtype):
31 | n = X.shape[-1]
32 | input = hadamard_transform(X.to(dtype).contiguous(), scale=1/math.sqrt(n))
33 | return input.to(X.dtype).view(X.shape)
34 |
35 |
36 | def init_kv_i4(kv_data, kv_param,
37 | kv_indptr, kv_indices,
38 | last_page_offset, k,
39 | v, k_param, v_param,
40 | seqlen_indptr, layer_idx):
41 | return _CUDA.init_kv_i4(
42 | kv_data, kv_param,
43 | kv_indptr, kv_indices,
44 | last_page_offset, k,
45 | v, k_param, v_param,
46 | seqlen_indptr, layer_idx)
47 |
48 |
49 | def append_kv_i4(kv_data, kv_param,
50 | kv_indptr, kv_indices,
51 | last_page_offset, k,
52 | v, k_param, v_param,
53 | layer_idx):
54 | return _CUDA.append_kv_i4(
55 | kv_data, kv_param,
56 | kv_indptr, kv_indices,
57 | last_page_offset, k,
58 | v, k_param, v_param,
59 | layer_idx)
60 |
61 | def batch_decode_i4(o, q, kv_data, kv_param,
62 | kv_indptr, kv_indices,
63 | last_page_offset, layer_idx):
64 | return _CUDA.batch_decode_i4(
65 | o, q, kv_data, kv_param,
66 | kv_indptr, kv_indices,
67 | last_page_offset, layer_idx)
68 |
69 |
70 | def init_kv_f16(kv_data, kv_param,
71 | kv_indptr, kv_indices,
72 | last_page_offset, k,
73 | v, k_param, v_param,
74 | seqlen_indptr, layer_idx):
75 | return _CUDA.init_kv_f16(
76 | kv_data, kv_param,
77 | kv_indptr, kv_indices,
78 | last_page_offset, k,
79 | v, k_param, v_param,
80 | seqlen_indptr, layer_idx)
81 |
82 |
83 | def append_kv_f16(kv_data, kv_param,
84 | kv_indptr, kv_indices,
85 | last_page_offset, k,
86 | v, k_param, v_param,
87 | layer_idx):
88 | return _CUDA.append_kv_f16(
89 | kv_data, kv_param,
90 | kv_indptr, kv_indices,
91 | last_page_offset, k,
92 | v, k_param, v_param,
93 | layer_idx)
94 |
95 | def batch_decode_f16(o, q, kv_data, kv_param,
96 | kv_indptr, kv_indices,
97 | last_page_offset, layer_idx):
98 | return _CUDA.batch_decode_f16(
99 | o, q, kv_data, kv_param,
100 | kv_indptr, kv_indices,
101 | last_page_offset, layer_idx)
102 |
103 |
104 | class _AttentionStub(object):
105 | def __init__(self, cache_page_size, device, n_layers, disable_quant, hadamard_dtype):
106 | self.cache_page_size = cache_page_size
107 | self.n_layers = n_layers
108 | self.disable_quant = disable_quant
109 | self.hadamard_dtype = hadamard_dtype
110 |
111 | def forward(self, q, num_kv_heads, attention_kwargs, layer_idx):
112 | batch_size, q_len, num_qo_heads, head_dim = q.shape
113 | assert q_len == 1
114 | q = q.view(batch_size, num_qo_heads, head_dim)
115 | if self.hadamard_dtype is not None:
116 | q = matmul_had_cuda(q, dtype=self.hadamard_dtype)
117 | attn_output = torch.empty_like(q)
118 | if self.disable_quant:
119 | batch_decode = batch_decode_f16
120 | else:
121 | batch_decode = batch_decode_i4
122 | batch_decode(
123 | attn_output, q,
124 | **attention_kwargs, layer_idx=layer_idx
125 | )
126 | attn_output = attn_output.unsqueeze(1)
127 | return attn_output
128 |
129 |
130 | class MultiLayerPagedKVCache4Bit(Cache):
131 | def __init__(
132 | self, batch_size, page_size, max_seq_len,
133 | device, n_layers, num_heads, head_dim,
134 | disable_quant=False, hadamard_dtype=torch.float16 ):
135 | self.page_size = page_size
136 | self.batch_size = batch_size
137 | max_page_cnt = self.page_cnt_from_length(max_seq_len)
138 | self.disable_quant = disable_quant
139 | self.pages = torch.empty(
140 | (
141 | max_page_cnt * batch_size,
142 | n_layers,
143 | 2,
144 | num_heads,
145 | page_size,
146 | head_dim if disable_quant else head_dim // 2
147 | ),
148 | dtype=torch.float16 if disable_quant else torch.uint8, device=device)
149 |
150 | self.scales = torch.empty((max_page_cnt * batch_size, n_layers, 2, num_heads, page_size, 2), dtype=torch.float16, device=device)
151 | self.page_size = page_size
152 | self.max_seq_len = max_seq_len
153 | self._needs_init = [True] * n_layers
154 | self.length = 0
155 | self.device = device
156 | self.hadamard_dtype = hadamard_dtype
157 | self._stub = _AttentionStub(
158 | self.page_size, device, n_layers,
159 | disable_quant=self.disable_quant,
160 | hadamard_dtype=self.hadamard_dtype)
161 |
162 | def page_cnt_from_length(self, length):
163 | return (length + self.page_size - 1) // self.page_size
164 |
165 | def _ensure_page_cnt_per_batch(self, expected_page_cnt_per_batch):
166 | expected_page_cnt = expected_page_cnt_per_batch * self.batch_size
167 | if expected_page_cnt <= self.pages.shape[0]:
168 | return
169 | raise NotImplementedError
170 |
171 | @property
172 | def seen_tokens(self):
173 | return self.length
174 |
175 | def update(
176 | self,
177 | key_states: torch.Tensor,
178 | value_states: torch.Tensor,
179 | layer_idx: int,
180 | cache_kwargs: Optional[Dict[str, Any]] = None,
181 | ):
182 |
183 | b_sz, added_length, num_heads, head_dim = key_states.shape
184 |
185 | orig_key_states = key_states
186 | orig_value_states = value_states
187 |
188 | if self.hadamard_dtype is not None:
189 | key_states = matmul_had_cuda(key_states, dtype=self.hadamard_dtype)
190 |
191 | if self.disable_quant:
192 | k_scale = key_states.new_ones((b_sz, added_length, num_heads, 1))
193 | k_zero = key_states.new_zeros((b_sz, added_length, num_heads, 1))
194 | v_scale = value_states.new_ones((b_sz, added_length, num_heads, 1))
195 | v_zero = value_states.new_zeros((b_sz, added_length, num_heads, 1))
196 | else:
197 | key_states, k_scale, k_zero = asym_quantize_and_pack_i4(key_states)
198 | value_states, v_scale, v_zero = asym_quantize_and_pack_i4(value_states)
199 |
200 | k_param = torch.cat([k_scale, k_zero], dim=-1).view(self.batch_size * added_length, num_heads, 2)
201 | v_param = torch.cat([v_scale, v_zero], dim=-1).view(self.batch_size * added_length, num_heads, 2)
202 |
203 | quantized_head_dim = self.pages.shape[-1]
204 |
205 | assert b_sz == self.batch_size
206 | if layer_idx == 0:
207 | current_length = self.length
208 | new_length = current_length + added_length
209 | self._ensure_page_cnt_per_batch(self.page_cnt_from_length(new_length))
210 | self.length = new_length
211 | attention_mask = cache_kwargs.get("attention_mask")
212 | if self._needs_init[layer_idx]:
213 | self._needs_init[layer_idx] = False
214 | if attention_mask is not None:
215 | nonzero_indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten().view(-1, 1)
216 | key_states = key_states.view(self.batch_size * added_length, num_heads * quantized_head_dim)
217 | value_states = value_states.view(self.batch_size * added_length, num_heads * quantized_head_dim)
218 | key_states = torch.gather(key_states, 0, nonzero_indices.expand(-1, num_heads * quantized_head_dim))
219 | value_states = torch.gather(value_states, 0, nonzero_indices.expand(-1, num_heads * quantized_head_dim))
220 |
221 | k_param = k_param.view(self.batch_size * added_length, num_heads * 2)
222 | v_param = v_param.view(self.batch_size * added_length, num_heads * 2)
223 | k_param = torch.gather(k_param, 0, nonzero_indices.expand(-1, num_heads * 2))
224 | v_param = torch.gather(v_param, 0, nonzero_indices.expand(-1, num_heads * 2))
225 |
226 | seqlens_in_batch = torch.nn.functional.pad(torch.cumsum(attention_mask.sum(dim=-1, dtype=torch.int32), dim=0, dtype=torch.int32), (1, 0))
227 | else:
228 | seqlens_in_batch = torch.arange(self.batch_size + 1, device=self.device, dtype=torch.int) * added_length
229 |
230 | init_kv = init_kv_f16 if self.disable_quant else init_kv_i4
231 | init_kv(
232 | **self.get_cache_specs_for_flash_infer(attention_mask),
233 | k=key_states.view(-1, num_heads, quantized_head_dim),
234 | v=value_states.view(-1, num_heads, quantized_head_dim),
235 | k_param=k_param.view(-1, num_heads, 2),
236 | v_param=v_param.view(-1, num_heads, 2),
237 | seqlen_indptr=seqlens_in_batch,
238 | layer_idx=layer_idx
239 | )
240 | return orig_key_states, orig_value_states
241 | else:
242 | assert added_length == 1
243 | append_kv = append_kv_f16 if self.disable_quant else append_kv_i4
244 | append_kv(
245 | **self.get_cache_specs_for_flash_infer(attention_mask),
246 | k=key_states.view(self.batch_size, num_heads, quantized_head_dim),
247 | v=value_states.view(self.batch_size, num_heads, quantized_head_dim),
248 | k_param=k_param.view(-1, num_heads, 2),
249 | v_param=v_param.view(-1, num_heads, 2),
250 | layer_idx=layer_idx,
251 | )
252 | return functools.partial(
253 | self._stub.forward,
254 | num_kv_heads=num_heads,
255 | attention_kwargs=self.get_cache_specs_for_flash_infer(attention_mask),
256 | layer_idx=layer_idx,
257 | )
258 |
259 | def get_cache_specs_for_flash_infer(self, attention_mask):
260 | if attention_mask is not None:
261 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
262 | else:
263 | seqlens_in_batch = torch.tensor([self.length], dtype=torch.int32, device=self.device).expand(self.batch_size)
264 | page_cnt = self.page_cnt_from_length(seqlens_in_batch)
265 | if (page_cnt[0] != page_cnt).any():
266 | raise NotImplementedError("Current implementation does not support the case where batches have different number of pages")
267 | page_cnt = page_cnt[0]
268 | page_ptr = seqlens_in_batch % self.page_size
269 | page_ptr = torch.where((seqlens_in_batch != 0) & (page_ptr == 0), self.page_size, page_ptr)
270 | return {
271 |             "kv_data": self.pages,
272 |             "kv_indptr": torch.arange(0, self.batch_size + 1, device=self.device, dtype=torch.int) * page_cnt,
273 |             "kv_indices": (
274 |                 (torch.arange(page_cnt, device=self.device, dtype=torch.int) * self.batch_size).unsqueeze(0) +
275 |                 torch.arange(self.batch_size, device=self.device, dtype=torch.int).unsqueeze(1)).view(-1),
276 |             "last_page_offset": page_ptr,  # torch.full((self.batch_size, ), page_ptr, device=self.device, dtype=torch.int),
277 |             "kv_param": self.scales,
278 | }
279 |
280 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
281 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
282 | return self.length
283 |
284 | def get_max_length(self) -> Optional[int]:
285 | """Returns the maximum sequence length of the cached states, if there is any."""
286 | return None
287 |
288 | def to_legacy_cache(self):
289 | return self
290 |
--------------------------------------------------------------------------------
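
The asymmetric 4-bit helpers at the top of kv_cache.py are plain PyTorch and can be exercised on CPU, assuming the quarot package (and its compiled extension) imports cleanly. A round-trip looks like this:

    import torch
    from quarot.transformers.kv_cache import (
        asym_quantize_and_pack_i4, unpack_i4_and_asym_dequantize)

    k = torch.randn(1, 4, 32, 128)                  # (batch, seq, kv_heads, head_dim)
    q, scale, zero = asym_quantize_and_pack_i4(k)   # packed uint8, last dim halved
    assert q.dtype == torch.uint8 and q.shape[-1] == k.shape[-1] // 2
    k_hat = unpack_i4_and_asym_dequantize(q, scale, zero)
    print((k - k_hat).abs().max())                  # worst-case error is about scale / 2
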
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.38.0
2 | torch==2.2.1
3 | sentencepiece==0.2.0
4 | wandb==0.16.3
5 | huggingface-hub==0.20.3
6 | accelerate==0.27.2
7 | datasets==2.17.1
8 | lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@9b0b15b1ccace3534ffbd13298c569869ce8eaf3
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import torch.utils.cpp_extension as torch_cpp_ext
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 | import os
5 | import pathlib
6 | setup_dir = os.path.dirname(os.path.realpath(__file__))
7 | HERE = pathlib.Path(__file__).absolute().parent
8 |
9 | def remove_unwanted_pytorch_nvcc_flags():
10 | REMOVE_NVCC_FLAGS = [
11 | '-D__CUDA_NO_HALF_OPERATORS__',
12 | '-D__CUDA_NO_HALF_CONVERSIONS__',
13 | '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
14 | '-D__CUDA_NO_HALF2_OPERATORS__',
15 | ]
16 | for flag in REMOVE_NVCC_FLAGS:
17 | try:
18 | torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
19 | except ValueError:
20 | pass
21 |
22 | def get_cuda_arch_flags():
23 | return [
24 | '-gencode', 'arch=compute_75,code=sm_75', # Turing
25 | '-gencode', 'arch=compute_80,code=sm_80', # Ampere
26 | '-gencode', 'arch=compute_86,code=sm_86', # Ampere
27 | ]
28 |
29 | def third_party_cmake():
30 | import subprocess, sys, shutil
31 |
32 | cmake = shutil.which('cmake')
33 | if cmake is None:
34 | raise RuntimeError('Cannot find CMake executable.')
35 |
36 | retcode = subprocess.call([cmake, HERE])
37 | if retcode != 0:
38 | sys.stderr.write("Error: CMake configuration failed.\n")
39 | sys.exit(1)
40 |
41 | # install fast hadamard transform
42 | hadamard_dir = os.path.join(HERE, 'third-party/fast-hadamard-transform')
43 | pip = shutil.which('pip')
44 |     subprocess.check_call([pip, 'install', '-e', hadamard_dir])
45 |
46 | if __name__ == '__main__':
47 | third_party_cmake()
48 | remove_unwanted_pytorch_nvcc_flags()
49 | setup(
50 | name='quarot',
51 | ext_modules=[
52 | CUDAExtension(
53 | name='quarot._CUDA',
54 | sources=[
55 | 'quarot/kernels/bindings.cpp',
56 | 'quarot/kernels/gemm.cu',
57 | 'quarot/kernels/quant.cu',
58 | 'quarot/kernels/flashinfer.cu',
59 | ],
60 | include_dirs=[
61 | os.path.join(setup_dir, 'quarot/kernels/include'),
62 | os.path.join(setup_dir, 'third-party/cutlass/include'),
63 | os.path.join(setup_dir, 'third-party/cutlass/tools/util/include')
64 | ],
65 | extra_compile_args={
66 | 'cxx': [],
67 | 'nvcc': get_cuda_arch_flags(),
68 | }
69 | )
70 | ],
71 | cmdclass={
72 | 'build_ext': BuildExtension
73 | }
74 | )
75 |
--------------------------------------------------------------------------------