├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── benchmarks ├── hadamard_benchmark.py ├── qattention_benchmark.py └── qlinear_benchmark.py ├── e2e ├── __init__.py ├── benchmark.py ├── benchmark_layer.py ├── checkpoint_utils │ ├── __init__.py │ ├── data_utils.py │ ├── gptq_utils.py │ ├── quantize_llama_checkpoint.py │ └── rotation_utils.py └── quantized_llama │ ├── __init__.py │ └── modeling_llama.py ├── fake_quant ├── README.md ├── data_utils.py ├── eval_utils.py ├── gptq_utils.py ├── hadamard_utils.py ├── main.py ├── model_utils.py ├── monkeypatch.py ├── quant_utils.py ├── rotation_utils.py └── utils.py ├── img ├── carrot.png └── fig1.png ├── quarot ├── __init__.py ├── functional │ ├── __init__.py │ ├── hadamard.py │ └── quantization.py ├── kernels │ ├── bindings.cpp │ ├── flashinfer.cu │ ├── gemm.cu │ ├── include │ │ ├── common.h │ │ ├── flashinfer.h │ │ ├── flashinfer │ │ │ ├── cp_async.cuh │ │ │ ├── decode.cuh │ │ │ ├── layout.cuh │ │ │ ├── math.cuh │ │ │ ├── mma.cuh │ │ │ ├── page.cuh │ │ │ ├── permuted_smem.cuh │ │ │ ├── prefill.cuh │ │ │ ├── quantization.cuh │ │ │ ├── rope.cuh │ │ │ ├── state.cuh │ │ │ ├── utils.cuh │ │ │ └── vec_dtypes.cuh │ │ ├── gemm.h │ │ ├── int4.h │ │ ├── quant.h │ │ └── util.h │ └── quant.cu ├── nn │ ├── __init__.py │ ├── hadamard.py │ ├── linear.py │ ├── normalization.py │ └── quantization.py └── transformers │ ├── __init__.py │ └── kv_cache.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ninja_deps 2 | *.o 3 | *.json 4 | *.log 5 | *.pyc 6 | *.so 7 | build/* 8 | quarot.egg-info/* 9 | *.cmake 10 | *.in 11 | CMakeFiles/* 12 | CMakeCache.txt 13 | *.egg 14 | Makefile 15 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | 2 | [submodule "third-party/cutlass"] 3 | path = third-party/cutlass 4 | url = https://github.com/NVIDIA/cutlass.git 5 | [submodule "third-party/nvbench"] 6 | path = third-party/nvbench 7 | url = https://github.com/NVIDIA/nvbench 8 | [submodule "third-party/fast-hadamard-transform"] 9 | path = third-party/fast-hadamard-transform 10 | url = https://github.com/Dao-AILab/fast-hadamard-transform.git 11 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.11) 2 | project(quarot LANGUAGES CXX) 3 | 4 | find_package(Git REQUIRED) 5 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 6 | message(STATUS "Populating Git submodule.") 7 | execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive 8 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 9 | RESULT_VARIABLE GIT_SUBMOD_RESULT) 10 | if(NOT GIT_SUBMOD_RESULT EQUAL "0") 11 | message(FATAL_ERROR 12 | "git submodule updata --init --recursive failed with ${GIT_SUBMOD_RESULT}.") 13 | endif() 14 | endif() 15 | 16 | 17 | set(_saved_CMAKE_MESSAGE_LOG_LEVEL ${CMAKE_MESSAGE_LOG_LEVEL}) 18 | set(CMAKE_MESSAGE_LOG_LEVEL ERROR) 19 | add_subdirectory(third-party/cutlass) 20 | set(CMAKE_MESSAGE_LOG_LEVEL ${_saved_CMAKE_MESSAGE_LOG_LEVEL}) 21 | 22 | include_directories("${CMAKE_SOURCE_DIR}") 23 | include_directories(third-party/cutlass/tools/util/include) 24 | include_directories(third-party/cutlass/include) 25 | include_directories(quarot/kernels/include) 26 | 27 | get_property(dirs DIRECTORY 
${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) 28 | foreach(dir ${dirs}) 29 | message(STATUS "dir='${dir}'") 30 | endforeach() 31 | 32 | # add_subdirectory(quarot/kernels) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Your ImageQuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs 3 | This repository contains the code for [**QuaRot**: Outlier-Free 4-Bit Inference in Rotated LLMs](https://arxiv.org/abs/2404.00456). 
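Once the kernels are built (see Usage below), the `quarot.nn` building blocks can be composed as in `benchmarks/qlinear_benchmark.py`. The sketch below is illustrative only: the 4096×4096 layer size and the all-ones weight scales are placeholders, not values produced by the real quantization pipeline.

```python
import torch
from quarot.nn import Linear4bit, OnlineHadamard, Quantizer

# Placeholder FP16 layer standing in for e.g. a LLaMA-7B v_proj (4096 x 4096).
fp16_linear = torch.nn.Linear(4096, 4096, bias=False).cuda().half()

# Per-output-channel weight scales; in the real pipeline these come from GPTQ/RTN.
weight_scales = torch.ones((4096, 1), dtype=torch.float16, device="cuda")

int4_linear = torch.nn.Sequential(
    OnlineHadamard(fp16_linear.in_features, force_fp32=False),  # online Hadamard rotation of the activations
    Quantizer(input_clip_ratio=1.0),                             # 4-bit activation quantization
    Linear4bit.from_float(fp16_linear, weight_scales=weight_scales),
).cuda()

x = torch.rand(1, 2048, 4096, dtype=torch.float16, device="cuda")
y = int4_linear(x)
```

In the full pipeline the weight scales are taken from the GPTQ/RTN quantizers (see `e2e/checkpoint_utils/quantize_llama_checkpoint.py`).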
4 | 5 | 6 | 7 | ## Abstract 8 | We introduce QuaRot, a new **Qua**ntization scheme based on **Rot**ations, which is able to quantize LLMs end-to-end, including all weights, activations, and KV cache in 4 bits. QuaRot rotates LLMs in a way that removes outliers from the hidden state without changing the output, making quantization easier. This *computational invariance* is applied to the hidden state (residual) of the LLM, as well as to the activations of the feed-forward components, aspects of the attention mechanism and to the KV cache. The result is a quantized model where all matrix multiplications are performed in 4-bits, without any channels identified for retention in higher precision. Our quantized **LLaMa2-70B** model has losses of at most **0.29 WikiText perplexity** and retains **99% of the zero-shot** performance. 9 | 10 | ![Your Image](img/fig1.png) 11 | 12 | ## Usage 13 | 14 | 15 | Compile the QuaRot kernels using the following commands: 16 | 17 | ```bash 18 | git clone https://github.com/spcl/QuaRot.git 19 | cd QuaRot 20 | pip install -e . # or pip install . 21 | ``` 22 | 23 | For simulation results, check [fake_quant](https://github.com/spcl/QuaRot/tree/main/fake_quant) directory. 24 | 25 | 26 | 27 | ### Star History 28 | 29 | [![Star History Chart](https://api.star-history.com/svg?repos=spcl/QuaRot&type=Date)](https://star-history.com/#spcl/QuaRot&Date) 30 | 31 | 32 | ## Citation 33 | 34 | The full citation is 35 | 36 | ``` 37 | @article{ashkboos2024quarot, 38 | title={QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs}, 39 | author={Ashkboos, Saleh and Mohtashami, Amirkeivan and Croci, Maximilian L and Li, Bo and Jaggi, Martin and Alistarh, Dan and Hoefler, Torsten and Hensman, James}, 40 | journal={arXiv preprint arXiv:2404.00456}, 41 | year={2024} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /benchmarks/hadamard_benchmark.py: -------------------------------------------------------------------------------- 1 | import fast_hadamard_transform 2 | import torch 3 | import time 4 | for i in [1024, 2048, 4096, 4096*2, 4096*3]: 5 | x = torch.rand(i, i).cuda().to(torch.float16) 6 | torch.cuda.synchronize() 7 | fp32_time = 0 8 | fp16_time = 0 9 | 10 | for j in range(10): 11 | timer = time.time() 12 | y_had_float = fast_hadamard_transform.hadamard_transform(x.float()).half() 13 | torch.cuda.synchronize() 14 | fp32_time += time.time() - timer 15 | torch.cuda.synchronize() 16 | print(fp32_time) 17 | 18 | for j in range(10): 19 | timer = time.time() 20 | y_had = fast_hadamard_transform.hadamard_transform(x) 21 | torch.cuda.synchronize() 22 | fp16_time += time.time() - timer 23 | torch.cuda.synchronize() 24 | print(fp16_time) 25 | print(torch.allclose(y_had, y_had_float, atol=1e-7)) -------------------------------------------------------------------------------- /benchmarks/qattention_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pprint 3 | import numpy as np 4 | import torch 5 | import time 6 | 7 | from quarot.transformers.kv_cache import MultiLayerPagedKVCache4Bit 8 | 9 | model_sizes = [ 10 | (32, 32, 128), #llama-7b 11 | (40, 40, 128), #llama-13b 12 | (80, 64, 128) #llama-70b 13 | ] 14 | 15 | benchmark_dtypes = ["int4", torch.float16] 16 | num_warmup_steps = 5 17 | num_bench_steps = 100 18 | 19 | def module_benchmark(module): 20 | # warmup 21 | for i in range(num_warmup_steps): 22 | out = module() 23 | torch.cuda.synchronize() 24 | 25 | 
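    # Reset the peak-memory counter after warmup so max_memory_allocated() below reports the peak over the timed steps only.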
torch.cuda.reset_max_memory_allocated() 26 | start_time = time.perf_counter() 27 | for i in range(num_bench_steps): 28 | out = module() 29 | torch.cuda.synchronize() 30 | memory_usage = torch.cuda.max_memory_allocated() 31 | 32 | end_time = time.perf_counter() 33 | 34 | 35 | return (end_time - start_time) * 1000 / num_bench_steps, memory_usage 36 | 37 | def quantized_kv_cache_decode( 38 | n_layers, num_heads, head_dim, 39 | batch_size, dtype, seq_len, 40 | hadamard_dtype=torch.float16): 41 | device = torch.device("cuda:0") 42 | cache = MultiLayerPagedKVCache4Bit( 43 | batch_size=batch_size, 44 | page_size=seq_len, 45 | max_seq_len=seq_len, 46 | device=device, 47 | n_layers=n_layers, # Ignornig n_layers as it does not affect speed 48 | num_heads=num_heads, 49 | head_dim=head_dim, 50 | disable_quant=dtype == torch.float16, 51 | hadamard_dtype=hadamard_dtype, 52 | ) 53 | query_states = torch.rand((batch_size, 1, num_heads, head_dim), device=device, dtype=torch.float16) 54 | key_states = torch.rand((batch_size, 1, num_heads, head_dim), device=device, dtype=torch.float16) 55 | value_states = torch.rand((batch_size, 1, num_heads, head_dim), device=device, dtype=torch.float16) 56 | def _fake_prefill_and_decode(): 57 | cache._needs_init = [False] * len(cache._needs_init) 58 | cache.length = seq_len - 1 59 | forward_func = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={}) 60 | attn_out = forward_func(query_states) 61 | 62 | times = [] 63 | for i in range(10): 64 | times.append(module_benchmark(_fake_prefill_and_decode)) 65 | return zip(*times) 66 | 67 | 68 | def qattention_benchmark(args): 69 | 70 | for n_layers, num_heads, head_dim in model_sizes: 71 | time_fp16, memory_fp16 = quantized_kv_cache_decode( 72 | n_layers=n_layers, 73 | num_heads=num_heads, 74 | head_dim=head_dim, 75 | batch_size=args.batch_size, 76 | dtype=torch.float16, 77 | seq_len=args.seq_len, 78 | hadamard_dtype=None 79 | ) 80 | 81 | time_int4, memory_int4 = quantized_kv_cache_decode( 82 | n_layers=n_layers, 83 | num_heads=num_heads, 84 | head_dim=head_dim, 85 | batch_size=args.batch_size, 86 | dtype="int4", 87 | seq_len=args.seq_len, 88 | hadamard_dtype=None 89 | ) 90 | time_int4_hadfp16, _ = quantized_kv_cache_decode( 91 | n_layers=n_layers, 92 | num_heads=num_heads, 93 | head_dim=head_dim, 94 | batch_size=args.batch_size, 95 | dtype="int4", 96 | seq_len=args.seq_len, 97 | hadamard_dtype=torch.float16 98 | ) 99 | time_int4_hadfp32, _ = quantized_kv_cache_decode( 100 | n_layers=n_layers, 101 | num_heads=num_heads, 102 | head_dim=head_dim, 103 | batch_size=args.batch_size, 104 | dtype="int4", 105 | seq_len=args.seq_len, 106 | hadamard_dtype=torch.float32 107 | ) 108 | 109 | print(f"Int4 time: {np.mean(time_int4):.3f} +- {1.96 * np.std(time_int4):.3f}ms") 110 | 111 | print(f"Int4 (+FP16had) time: {np.mean(time_int4_hadfp16):.3f} +- {1.96 * np.std(time_int4_hadfp16):.3f}ms") 112 | 113 | print(f"Int4 (+FP32had) time: {np.mean(time_int4_hadfp32):.3f} +- {1.96 * np.std(time_int4_hadfp32):.3f}ms") 114 | 115 | print(f"FP16 time: {np.mean(time_fp16):.3f} +- {1.96 * np.std(time_fp16):.3f}ms") 116 | 117 | print(f"Speedup: {np.mean(time_fp16) / np.mean(time_int4_hadfp16):.3f}x") 118 | 119 | print(f"Int4 memory: {np.mean(memory_int4):.3f} +- {1.96 * np.std(memory_int4):.3f}ms") 120 | print(f"FP16 memory: {np.mean(memory_fp16):.3f} +- {1.96 * np.std(memory_fp16):.3f}ms") 121 | print(f"Memory Saving: {np.mean(memory_fp16) / np.mean(memory_int4):.3f}x") 122 | 123 | # table-style output 124 | 
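        # LaTeX row: (layers x heads x head_dim) & batch & FP16 & Int4 & Int4+FP32had & Int4+FP16had, all in ms per step.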
print(f'{n_layers}x{num_heads}x{head_dim} & {args.batch_size} & {np.mean(time_fp16):.3f} & {np.mean(time_int4):.3f} & {np.mean(time_int4_hadfp32):.3f} & {np.mean(time_int4_hadfp16):.3f}\\\\') 125 | print('--------------') 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | 130 | parser.add_argument( 131 | '--batch_size', type=int, 132 | help='Batch size', 133 | default=1, 134 | ) 135 | parser.add_argument( 136 | '--seq_len', type=int, 137 | help='Size of the input sequence', 138 | default=2048, 139 | ) 140 | 141 | args = parser.parse_args() 142 | pprint.pprint(vars(args)) 143 | qattention_benchmark(args) 144 | -------------------------------------------------------------------------------- /benchmarks/qlinear_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from quarot.nn import Linear4bit, Quantizer, OnlineHadamard 3 | import time 4 | import argparse 5 | import numpy as np 6 | import pprint 7 | 8 | model_sizes = [ 9 | (4096, 4096), #llama-7b 10 | (5120, 5120), #llama-13b 11 | (8192, 8192) #llama-70b 12 | ] 13 | 14 | mlp_sizes = [ 15 | (4096, 11008), #llama-7b 16 | (5120, 13824), #llama-13b 17 | (8192, 28672) #llama-70b 18 | ] 19 | benchmark_dtypes = [torch.float16] 20 | num_warmup_steps = 5 21 | num_bench_steps = 100 22 | 23 | 24 | def module_benchmark(module, x): 25 | x = x.cuda() 26 | 27 | # warmup 28 | for i in range(num_warmup_steps): 29 | out = module(x) 30 | torch.cuda.synchronize() 31 | 32 | start_time = time.perf_counter() 33 | for i in range(num_bench_steps): 34 | out = module(x) 35 | torch.cuda.synchronize() 36 | 37 | end_time = time.perf_counter() 38 | 39 | 40 | return (end_time - start_time) * 1000 / num_bench_steps 41 | 42 | def linear4bit_benchmark(args): 43 | 44 | bsz = args.bsz 45 | seq_len = args.seq_len 46 | 47 | if args.layer_type == 'v_proj': 48 | layer_size = model_sizes 49 | else: 50 | layer_size = mlp_sizes 51 | 52 | 53 | for (feature_dim_in, feature_dim_out) in layer_size: 54 | for dtype in benchmark_dtypes: 55 | 56 | x = torch.rand((bsz, 57 | seq_len, 58 | feature_dim_in)).cuda().to(dtype) 59 | 60 | baseline_mod = torch.nn.Linear(feature_dim_in, 61 | feature_dim_out, 62 | bias=False).cuda().to(dtype) 63 | 64 | baseline_mod.weight.data = torch.randint_like(baseline_mod.weight.data, 65 | low=-8, high=7).to(dtype) 66 | 67 | s_w = torch.ones((feature_dim_out, 1), dtype=torch.float16, device='cuda') 68 | int4_mod = torch.nn.Sequential( 69 | Quantizer(input_clip_ratio=1.0), 70 | Linear4bit.from_float(baseline_mod, weight_scales=s_w) 71 | ).cuda() 72 | int4_mod_had = torch.nn.Sequential( 73 | OnlineHadamard(baseline_mod.in_features, force_fp32=True), 74 | Quantizer(input_clip_ratio=1.0), 75 | Linear4bit.from_float(baseline_mod, weight_scales=s_w), 76 | ).cuda() 77 | #int4_mod_had.online_full_had = True 78 | #int4_mod.fp32_had = True 79 | 80 | int4_mod_fp16had = torch.nn.Sequential( 81 | OnlineHadamard(baseline_mod.in_features, force_fp32=False), 82 | Quantizer(input_clip_ratio=1.0), 83 | Linear4bit.from_float(baseline_mod, weight_scales=s_w), 84 | ).cuda() 85 | 86 | 87 | 88 | print(f"{dtype}. 
Sizes: {baseline_mod.weight.shape}") 89 | times_4bit = [] 90 | for i in range(10): 91 | times_4bit.append(module_benchmark(int4_mod, x)) 92 | print(f"Int4 time: {np.mean(times_4bit):.3f} +- {1.96 * np.std(times_4bit):.3f}ms") 93 | 94 | times_4bit_had = [] 95 | for i in range(10): 96 | times_4bit_had.append(module_benchmark(int4_mod_had, x)) 97 | print(f"Int4 (+FP32had) time: {np.mean(times_4bit_had):.3f} +- {1.96 * np.std(times_4bit_had):.3f}ms") 98 | 99 | times_4bit_fp16had = [] 100 | for i in range(10): 101 | times_4bit_fp16had.append(module_benchmark(int4_mod_fp16had, x)) 102 | print(f"Int4 (+FP16had) time: {np.mean(times_4bit_fp16had):.3f} +- {1.96 * np.std(times_4bit_fp16had):.3f}ms") 103 | 104 | 105 | times_baseline = [] 106 | for i in range(10): 107 | times_baseline.append(module_benchmark(baseline_mod, x)) 108 | print(f"FP16 time: {np.mean(times_baseline):.3f} +- {1.96 * np.std(times_baseline):.3f}ms") 109 | 110 | print(f"Speedup: {np.mean(times_baseline) / np.mean(times_4bit):.3f}x") 111 | 112 | # table-style output 113 | print(f'{feature_dim_in}x{feature_dim_out} & {args.bsz} & {np.mean(times_baseline):.3f} & {np.mean(times_4bit):.3f} & {np.mean(times_4bit_had):.3f} & {np.mean(times_4bit_fp16had):.3f}\\\\') 114 | print('--------------') 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | 120 | parser.add_argument( 121 | '--bsz', type=int, 122 | help='Batch size', 123 | default=1, 124 | ) 125 | parser.add_argument( 126 | '--seq_len', type=int, 127 | help='Size of the input sequence', 128 | default=2048, 129 | ) 130 | parser.add_argument( 131 | '--layer_type', type=str, 132 | help='Type of the layer in the model (v_proj [default], down_proj)', 133 | default='v_proj', 134 | choices=['v_proj', 'down_proj'] 135 | ) 136 | 137 | args = parser.parse_args() 138 | pprint.pprint(vars(args)) 139 | linear4bit_benchmark(args) 140 | -------------------------------------------------------------------------------- /e2e/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantized_llama.modeling_llama import QuarotLlamaForCausalLM, QuarotLlamaConfig 2 | -------------------------------------------------------------------------------- /e2e/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import pprint 4 | import numpy as np 5 | import torch 6 | import time 7 | 8 | from e2e.quantized_llama import modeling_llama 9 | import torch 10 | import transformers 11 | 12 | model_configs = [ 13 | "meta-llama/Llama-2-7b-hf", 14 | # "meta-llama/Llama-2-13b-hf", 15 | # "meta-llama/Llama-2-70b-hf", 16 | ] 17 | 18 | benchmark_dtypes = ["int4", torch.float16] 19 | num_warmup_steps = 0 20 | num_bench_steps = 1 21 | 22 | def repeated_run(num_repeats=10): 23 | def func(module): 24 | def _f(*args, **kwargs): 25 | times = [] 26 | for i in range(num_repeats): 27 | times.append(module(*args, **kwargs)) 28 | return tuple(zip(*times)) 29 | return _f 30 | return func 31 | 32 | def _cleanup(): 33 | gc.collect() 34 | torch.cuda.empty_cache() 35 | 36 | @repeated_run() 37 | def module_benchmark(module): 38 | # warmup 39 | for i in range(num_warmup_steps): 40 | out = module() 41 | torch.cuda.synchronize() 42 | 43 | _cleanup() 44 | torch.cuda.reset_max_memory_allocated() 45 | start_time = time.perf_counter() 46 | 47 | 48 | for i in range(num_bench_steps): 49 | out = module() 50 | torch.cuda.synchronize() 51 | peak_memory = torch.cuda.max_memory_allocated() 52 | 
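    # Stop the timer only after the synchronize and the peak-memory read, so the reported time covers all num_bench_steps calls.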
53 | end_time = time.perf_counter() 54 | 55 | return (end_time - start_time) * 1000 / num_bench_steps, peak_memory 56 | 57 | 58 | def get_model_quantized(config_name): 59 | config = transformers.AutoConfig.from_pretrained( 60 | config_name, 61 | attn_implementation="flash_attention_2" 62 | ) 63 | dtype_old = torch.get_default_dtype() 64 | torch.set_default_dtype(torch.float16) 65 | with transformers.modeling_utils.no_init_weights(): 66 | model = modeling_llama.QuarotLlamaForCausalLM(config=config) 67 | torch.set_default_dtype(dtype_old) 68 | return model 69 | 70 | 71 | def get_model_hf(config_name): 72 | return transformers.LlamaForCausalLM.from_pretrained( 73 | config_name, 74 | torch_dtype=torch.float16, 75 | attn_implementation="flash_attention_2" 76 | ) 77 | 78 | def get_model_fp16(config_name): 79 | return modeling_llama.QuarotFP16LlamaForCausalLM.from_pretrained( 80 | config_name, 81 | torch_dtype=torch.float16, 82 | attn_implementation="flash_attention_2" 83 | ) 84 | 85 | 86 | def run_prefill(model, bsz, prefill_length): 87 | device = model.device 88 | test_input = torch.randint(100, 200, (bsz, prefill_length), dtype=torch.int32, device=device) 89 | return module_benchmark(lambda: model(test_input)) 90 | 91 | 92 | def run_decode(model, bsz, prefill_length, decode_steps): 93 | device = model.device 94 | test_input = torch.randint(100, 200, (bsz, prefill_length), dtype=torch.int32, device=device) 95 | model._expected_max_length = prefill_length + decode_steps 96 | out = model(test_input) 97 | past_key_values = out.past_key_values 98 | del out 99 | _cleanup() 100 | next_input = torch.tensor([[100] for _ in range (bsz)], dtype=torch.int32, device=device) 101 | def _decode_for_multiple_steps(): 102 | past_key_values.length = prefill_length 103 | for _ in range(decode_steps): 104 | model(next_input, past_key_values=past_key_values) 105 | return module_benchmark(_decode_for_multiple_steps) 106 | 107 | 108 | def run_e2e(model, bsz, prefill_length, decode_steps): 109 | device = model.device 110 | test_input = torch.randint(100, 200, (bsz, prefill_length), dtype=torch.int32, device=device) 111 | next_input = torch.tensor([[100] for _ in range (bsz)], dtype=torch.int32, device=device) 112 | def _prefill_and_decode_for_multiple_steps(): 113 | model._expected_max_length = prefill_length + decode_steps 114 | out = model(test_input) 115 | for _ in range(decode_steps): 116 | model(next_input, past_key_values=out.past_key_values) 117 | return module_benchmark(_prefill_and_decode_for_multiple_steps) 118 | 119 | 120 | def _wait_for_input(): 121 | print("Press enter") 122 | input() 123 | 124 | @torch.no_grad 125 | def run_all_for_model(model, bsz, prefill, decode): 126 | model.eval() 127 | model = model.cuda() 128 | time_prefill, _ = run_prefill(model, bsz, prefill) 129 | _cleanup() 130 | if decode is not None: 131 | time_decode, memory_decode = run_decode(model, bsz, prefill, decode) 132 | _cleanup() 133 | time_e2e, _ = run_e2e(model, bsz, prefill, decode) 134 | _cleanup() 135 | else: 136 | time_decode = time_e2e = None 137 | return time_prefill, time_decode, time_e2e, memory_decode 138 | 139 | def benchmark(args): 140 | 141 | for config_name in model_configs: 142 | model = get_model_quantized(config_name) 143 | time_prefill_i4, time_decode_i4, time_e2e_i4, mem_i4 = run_all_for_model( 144 | model, args.batch_size, args.prefill_seq_len, args.decode_steps) 145 | del model 146 | _cleanup() 147 | model = get_model_fp16(config_name) 148 | time_prefill_f16, time_decode_f16, time_e2e_f16, mem_f16 = 
run_all_for_model( 149 | model, args.batch_size, args.prefill_seq_len, args.decode_steps) 150 | del model 151 | _cleanup() 152 | 153 | print(f"Prefill Int4 time: {np.mean(time_prefill_i4):.3f} +- {1.96 * np.std(time_prefill_i4):.3f}ms") 154 | print(f"Prefill FP16 time: {np.mean(time_prefill_f16):.3f} +- {1.96 * np.std(time_prefill_f16):.3f}ms") 155 | print(f"Speedup: {np.mean(time_prefill_f16) / np.mean(time_prefill_i4):.3f}x") 156 | print(f'Prefill & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {np.mean(time_prefill_f16):.3f} & {np.mean(time_prefill_i4):.3f}\\\\') 157 | 158 | if args.decode_steps is not None: 159 | print(f"Decode Int4 time: {np.mean(time_decode_i4):.3f} +- {1.96 * np.std(time_decode_i4):.3f}ms") 160 | print(f"Decode FP16 time: {np.mean(time_decode_f16):.3f} +- {1.96 * np.std(time_decode_f16):.3f}ms") 161 | print(f"Speedup: {np.mean(time_decode_f16) / np.mean(time_decode_i4):.3f}x") 162 | print(f'Decode & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_decode_f16):.3f} & {np.mean(time_decode_i4):.3f}\\\\') 163 | 164 | print(f"E2E Int4 time: {np.mean(time_e2e_i4):.3f} +- {1.96 * np.std(time_e2e_i4):.3f}ms") 165 | print(f"E2E FP16 time: {np.mean(time_e2e_f16):.3f} +- {1.96 * np.std(time_e2e_f16):.3f}ms") 166 | print(f"Speedup: {np.mean(time_e2e_f16) / np.mean(time_e2e_i4):.3f}x") 167 | print(f'E2E & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_e2e_f16):.3f} & {np.mean(time_e2e_i4):.3f}\\\\') 168 | 169 | # table-style output 170 | 171 | print(f"Int4 memory: {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_i4):.3f}") 172 | print(f"FP16 memory: {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_f16):.3f}") 173 | print(f"Memory saving: {np.mean(mem_f16) / np.mean(mem_i4):.3f}x") 174 | print(f'Memory saving & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB & {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB\\\\') 175 | 176 | print('--------------') 177 | 178 | if __name__ == '__main__': 179 | parser = argparse.ArgumentParser() 180 | 181 | parser.add_argument( 182 | '--batch_size', type=int, 183 | help='Batch size', 184 | default=1, 185 | ) 186 | parser.add_argument( 187 | '--prefill_seq_len', type=int, 188 | help='Size of the input sequence', 189 | default=2048, 190 | ) 191 | parser.add_argument( 192 | '--decode_steps', type=int, 193 | help='Decode steps', 194 | required=False, 195 | default=None, 196 | ) 197 | 198 | args = parser.parse_args() 199 | pprint.pprint(vars(args)) 200 | benchmark(args) 201 | -------------------------------------------------------------------------------- /e2e/benchmark_layer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import functools 4 | import pprint 5 | import numpy as np 6 | import torch 7 | import time 8 | 9 | import quarot 10 | from e2e.quantized_llama import modeling_llama 11 | import torch 12 | import transformers 13 | 14 | model_configs = [ 15 | "meta-llama/Llama-2-7b-hf", 16 | #"meta-llama/Llama-2-13b-hf", 17 | # "meta-llama/Llama-2-70b-hf", 18 | ] 19 | 20 | benchmark_dtypes = ["int4", torch.float16] 21 | num_warmup_steps = 3 22 | num_bench_steps = 10 23 | 24 | def repeated_run(num_repeats=10): 25 | def func(module): 26 | def _f(*args, **kwargs): 27 | times = [] 28 | for i in range(num_repeats): 29 | times.append(module(*args, 
**kwargs)) 30 | return tuple(zip(*times)) 31 | return _f 32 | return func 33 | 34 | def _cleanup(): 35 | gc.collect() 36 | torch.cuda.empty_cache() 37 | 38 | @repeated_run() 39 | def module_benchmark(module): 40 | # warmup 41 | for i in range(num_warmup_steps): 42 | out = module() 43 | torch.cuda.synchronize() 44 | 45 | start_time = time.perf_counter() 46 | torch.cuda.reset_max_memory_allocated() 47 | 48 | for i in range(num_bench_steps): 49 | out = module() 50 | torch.cuda.synchronize() 51 | peak_memory = torch.cuda.max_memory_allocated() 52 | 53 | end_time = time.perf_counter() 54 | 55 | return (end_time - start_time) * 1000 / num_bench_steps, peak_memory 56 | 57 | 58 | def _build_cache(batch_size, length, layer, disable_quant, num_key_value_heads, hidden_size, device): 59 | num_heads = num_key_value_heads 60 | model_dim = hidden_size 61 | head_dim = model_dim // num_heads 62 | return quarot.transformers.MultiLayerPagedKVCache4Bit( 63 | batch_size=batch_size, 64 | page_size=length, 65 | max_seq_len=length, 66 | device=device, 67 | n_layers=1, 68 | num_heads=num_heads, 69 | head_dim=head_dim, 70 | disable_quant=disable_quant, 71 | hadamard_dtype=None if disable_quant else torch.float16 72 | ) 73 | 74 | def get_model_quantized(config_name): 75 | config = transformers.AutoConfig.from_pretrained( 76 | config_name, 77 | attn_implementation="flash_attention_2" 78 | ) 79 | torch.set_default_dtype(torch.float16) 80 | with transformers.modeling_utils.no_init_weights(): 81 | model = modeling_llama.QuarotLlamaForCausalLM(config=config) 82 | 83 | return model, functools.partial( 84 | _build_cache, 85 | disable_quant=False, 86 | device=torch.device("cuda:0"), 87 | num_key_value_heads=model.config.num_key_value_heads, 88 | hidden_size=model.config.hidden_size,), model.config.hidden_size 89 | 90 | 91 | def get_model_hf(config_name): 92 | return transformers.LlamaForCausalLM.from_pretrained( 93 | config_name, 94 | torch_dtype=torch.float16, 95 | attn_implementation="flash_attention_2" 96 | ), None, model.config.hidden_size 97 | 98 | def get_model_fp16(config_name): 99 | model = modeling_llama.QuarotFP16LlamaForCausalLM.from_pretrained( 100 | config_name, 101 | torch_dtype=torch.float16, 102 | attn_implementation="flash_attention_2" 103 | ) 104 | return model, functools.partial( 105 | _build_cache, 106 | disable_quant=True, 107 | device=torch.device("cuda:0"), 108 | num_key_value_heads=model.config.num_key_value_heads, 109 | hidden_size=model.config.hidden_size, 110 | ), model.config.hidden_size 111 | 112 | 113 | def run_prefill(layer, cache_builder, bsz, prefill_length, hidden_size): 114 | device = layer.self_attn.v_proj.weight.device 115 | test_input = torch.rand((bsz, prefill_length, hidden_size), dtype=torch.float16, device=device) 116 | if cache_builder is None: 117 | def _prefill(): 118 | layer(test_input) 119 | else: 120 | past_key_values = cache_builder(bsz, prefill_length, layer) 121 | def _prefill(): 122 | past_key_values.length = 0 123 | past_key_values._needs_init[0] = True 124 | layer(test_input, past_key_value=past_key_values) 125 | return module_benchmark(_prefill) 126 | 127 | 128 | def run_decode(layer, cache_builder, bsz, prefill_length, decode_steps, hidden_size): 129 | device = layer.self_attn.v_proj.weight.device 130 | test_input = torch.rand((bsz, prefill_length, hidden_size), dtype=torch.float16, device=device) 131 | next_input = torch.rand((bsz, 1, hidden_size), dtype=torch.float16, device=device) 132 | assert cache_builder is not None 133 | past_key_values = cache_builder(bsz, 
prefill_length + decode_steps, layer) 134 | layer(test_input, past_key_value=past_key_values) 135 | def _decode_for_multiple_steps(): 136 | past_key_values.length = prefill_length 137 | for i in range(decode_steps): 138 | layer(next_input, past_key_value=past_key_values, 139 | position_ids=torch.tensor([[prefill_length + i]] * bsz, device=past_key_values.device, dtype=torch.int32)) 140 | return module_benchmark(_decode_for_multiple_steps) 141 | 142 | 143 | def run_e2e(layer, cache_builder, bsz, prefill_length, decode_steps, hidden_size): 144 | device = layer.self_attn.v_proj.weight.device 145 | test_input = torch.rand((bsz, prefill_length, hidden_size), dtype=torch.float16, device=device) 146 | next_input = torch.rand((bsz, 1, hidden_size), dtype=torch.float16, device=device) 147 | assert cache_builder is not None 148 | past_key_values = cache_builder(bsz, prefill_length + decode_steps, layer) 149 | def _prefill_and_decode_for_multiple_steps(): 150 | past_key_values.length = 0 151 | past_key_values._needs_init[0] = True 152 | layer(test_input, past_key_value=past_key_values) 153 | for i in range(decode_steps): 154 | layer(next_input, past_key_value=past_key_values, 155 | position_ids=torch.tensor([[prefill_length + i]] * bsz, device=device, dtype=torch.int32)) 156 | return module_benchmark(_prefill_and_decode_for_multiple_steps) 157 | 158 | 159 | def _wait_for_input(): 160 | print("Press enter") 161 | input() 162 | 163 | @torch.no_grad 164 | def run_all_for_model(layer, cache_builder, bsz, prefill, decode, hidden_size): 165 | layer = layer.cuda() 166 | layer.eval() 167 | time_prefill, _ = run_prefill(layer, cache_builder, bsz, prefill, hidden_size) 168 | 169 | _cleanup() 170 | if decode is not None: 171 | time_decode, memory_decode = run_decode(layer, cache_builder, bsz, prefill, decode, hidden_size) 172 | _cleanup() 173 | time_e2e, _ = run_e2e(layer, cache_builder, bsz, prefill, decode, hidden_size) 174 | _cleanup() 175 | else: 176 | time_decode = time_e2e = None 177 | memory_decode = None 178 | return time_prefill, time_decode, time_e2e, memory_decode 179 | 180 | def benchmark(args): 181 | 182 | for config_name in model_configs: 183 | model, cache_builder, hidden_size = get_model_quantized(config_name) 184 | layer = model.model.layers[0] 185 | del model 186 | _cleanup() 187 | time_prefill_i4, time_decode_i4, time_e2e_i4, mem_i4 = run_all_for_model( 188 | layer, cache_builder, args.batch_size, args.prefill_seq_len, args.decode_steps, hidden_size) 189 | del layer 190 | _cleanup() 191 | model, cache_builder, hidden_size = get_model_fp16(config_name) 192 | layer = model.model.layers[0] 193 | del model 194 | _cleanup() 195 | time_prefill_f16, time_decode_f16, time_e2e_f16, mem_f16 = run_all_for_model( 196 | layer, cache_builder, args.batch_size, args.prefill_seq_len, args.decode_steps, hidden_size) 197 | del layer 198 | _cleanup() 199 | 200 | print(f"Prefill Int4 time: {np.mean(time_prefill_i4):.3f} +- {1.96 * np.std(time_prefill_i4):.3f}ms") 201 | print(f"Prefill FP16 time: {np.mean(time_prefill_f16):.3f} +- {1.96 * np.std(time_prefill_f16):.3f}ms") 202 | print(f"Speedup: {np.mean(time_prefill_f16) / np.mean(time_prefill_i4):.3f}x") 203 | print(f'Prefill & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {np.mean(time_prefill_f16):.3f} & {np.mean(time_prefill_i4):.3f}\\\\') 204 | 205 | if args.decode_steps is not None: 206 | print(f"Decode Int4 time: {np.mean(time_decode_i4):.3f} +- {1.96 * np.std(time_decode_i4):.3f}ms") 207 | print(f"Decode FP16 time: 
{np.mean(time_decode_f16):.3f} +- {1.96 * np.std(time_decode_f16):.3f}ms") 208 | print(f"Speedup: {np.mean(time_decode_f16) / np.mean(time_decode_i4):.3f}x") 209 | print(f'Decode & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_decode_f16):.3f} & {np.mean(time_decode_i4):.3f}\\\\') 210 | 211 | print(f"E2E Int4 time: {np.mean(time_e2e_i4):.3f} +- {1.96 * np.std(time_e2e_i4):.3f}ms") 212 | print(f"E2E FP16 time: {np.mean(time_e2e_f16):.3f} +- {1.96 * np.std(time_e2e_f16):.3f}ms") 213 | print(f"Speedup: {np.mean(time_e2e_f16) / np.mean(time_e2e_i4):.3f}x") 214 | print(f'E2E & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(time_e2e_f16):.3f} & {np.mean(time_e2e_i4):.3f}\\\\') 215 | 216 | # table-style output 217 | 218 | print(f"Int4 memory: {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_i4):.3f}") 219 | print(f"FP16 memory: {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB +- {1.96 * np.std(mem_f16):.3f}") 220 | print(f"Memory saving: {np.mean(mem_f16) / np.mean(mem_i4):.3f}x") 221 | print(f'Memory saving & {config_name} & {args.batch_size} & {args.prefill_seq_len} & {args.decode_steps} & {np.mean(mem_i4) / (1024 * 1024 * 1024):.3f}GB & {np.mean(mem_f16) / (1024 * 1024 * 1024):.3f}GB\\\\') 222 | 223 | print('--------------') 224 | 225 | if __name__ == '__main__': 226 | parser = argparse.ArgumentParser() 227 | 228 | parser.add_argument( 229 | '--batch_size', type=int, 230 | help='Batch size', 231 | default=1, 232 | ) 233 | parser.add_argument( 234 | '--prefill_seq_len', type=int, 235 | help='Size of the input sequence', 236 | default=2048, 237 | ) 238 | parser.add_argument( 239 | '--decode_steps', type=int, 240 | help='Decode steps', 241 | required=False, 242 | default=None, 243 | ) 244 | 245 | args = parser.parse_args() 246 | pprint.pprint(vars(args)) 247 | benchmark(args) 248 | -------------------------------------------------------------------------------- /e2e/checkpoint_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/e2e/checkpoint_utils/__init__.py -------------------------------------------------------------------------------- /e2e/checkpoint_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import random 3 | import transformers 4 | 5 | def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False): 6 | 7 | if hf_token is None: 8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 9 | else: 10 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 11 | 12 | if eval_mode: 13 | testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 15 | return testenc 16 | else: 17 | traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 18 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 19 | random.seed(seed) 20 | trainloader = [] 21 | for _ in range(nsamples): 22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 23 | j = i + seqlen 24 | inp = trainenc.input_ids[:, i:j] 25 | tar = inp.clone() 26 | tar[:, :-1] = -100 27 | trainloader.append((inp, tar)) 28 | return trainloader 29 | 30 | def get_c4_new(nsamples, seed, seqlen, 
model, hf_token=None, eval_mode=False): 31 | 32 | if hf_token is None: 33 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 34 | else: 35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 36 | 37 | if eval_mode: 38 | valdata = datasets.load_dataset( 39 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 40 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 41 | valenc = valenc.input_ids[:, :(256 * seqlen)] 42 | class TokenizerWrapper: 43 | def __init__(self, input_ids): 44 | self.input_ids = input_ids 45 | valenc = TokenizerWrapper(valenc) 46 | return valenc 47 | else: 48 | traindata = datasets.load_dataset( 49 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 50 | 51 | random.seed(seed) 52 | trainloader = [] 53 | for _ in range(nsamples): 54 | while True: 55 | i = random.randint(0, len(traindata) - 1) 56 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 57 | if trainenc.input_ids.shape[1] >= seqlen: 58 | break 59 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 60 | j = i + seqlen 61 | inp = trainenc.input_ids[:, i:j] 62 | tar = inp.clone() 63 | tar[:, :-1] = -100 64 | trainloader.append((inp, tar)) 65 | return trainloader 66 | 67 | 68 | 69 | 70 | def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False): 71 | 72 | 73 | if hf_token is None: 74 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 75 | else: 76 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 77 | 78 | if eval_mode: 79 | testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test') 80 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 81 | return testenc 82 | else: 83 | traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train') 84 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 85 | random.seed(seed) 86 | trainloader = [] 87 | for _ in range(nsamples): 88 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 89 | j = i + seqlen 90 | inp = trainenc.input_ids[:, i:j] 91 | tar = inp.clone() 92 | tar[:, :-1] = -100 93 | trainloader.append((inp, tar)) 94 | return trainloader 95 | 96 | 97 | def get_loaders( 98 | name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False 99 | ): 100 | if 'wikitext2' in name: 101 | return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode) 102 | if 'ptb' in name: 103 | return get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode) 104 | if 'c4' in name: 105 | return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode) 106 | -------------------------------------------------------------------------------- /e2e/checkpoint_utils/quantize_llama_checkpoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import transformers 3 | import torch 4 | import shutil 5 | import json 6 | 7 | from e2e.quantized_llama import modeling_llama 8 | from e2e.checkpoint_utils import data_utils, gptq_utils, rotation_utils 9 | from quarot.functional import pack_i4 10 | 11 | def main(args): 12 | model = transformers.LlamaForCausalLM.from_pretrained(args.pretraiend_path_or_name) 13 | 14 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 15 | 
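    # Checkpoint pipeline: fuse the RMSNorm scales into the adjacent linear layers, rotate the model with a random Hadamard matrix, then quantize the weights with GPTQ (or round-to-nearest when --w_rtn is set).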
model.seqlen = 2048 16 | rotation_utils.fuse_layer_norms(model) 17 | rotation_utils.rotate_model(model) 18 | if not args.w_rtn: 19 | trainloader = data_utils.get_loaders( 20 | args.cal_dataset, nsamples=args.nsamples, 21 | seed=args.seed, model=args.pretraiend_path_or_name, 22 | seqlen=model.seqlen, eval_mode=False 23 | ) 24 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, device, args) 25 | else: 26 | quantizers = gptq_utils.rtn_fwrd(model, device, args) 27 | 28 | 29 | old_dict = model.state_dict() 30 | key_maps = { 31 | "mlp.down_proj": "mlp.down_proj.2", 32 | "self_attn.o_proj": "self_attn.o_proj.1" 33 | } 34 | bad_key_names = { 35 | "post_attention_layernorm.weight", 36 | "input_layernorm.weight" 37 | } 38 | def _get_new_key(key): 39 | new_key = key 40 | for old_name, new_name in key_maps.items(): 41 | new_key = new_key.replace(old_name, new_name) 42 | return new_key 43 | 44 | def _keep_key(key): 45 | return all(bad_name not in key for bad_name in bad_key_names) 46 | 47 | new_dict = {_get_new_key(key): value for key, value in old_dict.items() if _keep_key(key)} 48 | for key, value in quantizers.items(): 49 | new_key = _get_new_key(key) 50 | weight_scales = value.scale 51 | new_dict[f"{new_key}.weight_scales"] = weight_scales 52 | weight_matrix = new_dict[f"{new_key}.weight"] 53 | int_rounded_weight = (weight_matrix/weight_scales).round() 54 | new_dict[f"{new_key}.weight"] = pack_i4(int_rounded_weight.to(torch.int8)) 55 | 56 | config = modeling_llama.QuarotLlamaConfig.from_pretrained( 57 | args.pretraiend_path_or_name, 58 | attn_implementation="flash_attention_2" 59 | ) 60 | torch.set_default_dtype(torch.float16) 61 | with transformers.modeling_utils.no_init_weights(): 62 | new_model = modeling_llama.QuarotLlamaForCausalLM(config=config) 63 | 64 | result = new_model.load_state_dict(new_dict, strict=False) 65 | assert all("had_rem_dim" in key for key in result.missing_keys), result 66 | assert len(result.unexpected_keys) == 0, result 67 | 68 | new_model = new_model.cpu() 69 | 70 | new_model.save_pretrained(args.save_path) 71 | with open(f"{args.save_path}/config.json") as f: 72 | config = json.load(f) 73 | config["auto_map"] = { 74 | "AutoConfig": "quarot.LlamaConfig", 75 | "AutoModelForCausalLM": "quarot.QuarotLlamaForCausalLM" 76 | } 77 | config["model_type"] = "llama_quarot" 78 | with open(f"{args.save_path}/config.json", "w") as f: 79 | json.dump(config, f) 80 | 81 | shutil.copy("e2e/quantized_llama/modeling_llama.py", f"{args.save_path}/quarot.py") 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | 87 | supported_models = [ 88 | 'meta-llama/Llama-2-7b-hf', 89 | 'meta-llama/Llama-2-13b-hf', 90 | 'meta-llama/Llama-2-70b-hf', 91 | ] 92 | 93 | supported_datasets = ['wikitext2', 'ptb', 'c4'] 94 | 95 | # General Arguments 96 | parser.add_argument('--pretraiend_path_or_name', type=str, default='meta-llama/Llama-2-7b-hf', 97 | help='Model to load;', choices=supported_models) 98 | parser.add_argument('--save_path', type=str, required=True) 99 | parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch') 100 | parser.add_argument('--eval_dataset', type=str, default='wikitext2', 101 | help='Dataset for Evaluation (default: wikitext2)', choices=supported_datasets,) 102 | 103 | 104 | parser.add_argument('--w_groupsize', type=int, default=-1, 105 | help='Groupsize for weight quantization. 
Note that this should be the same as a_groupsize') 106 | parser.add_argument('--w_asym', action=argparse.BooleanOptionalAction, default=False, 107 | help='ASymmetric weight quantization (default: False)') 108 | parser.add_argument('--w_rtn', action=argparse.BooleanOptionalAction, default=False, 109 | help='Quantize the weights using RtN. If the w_bits < 16 and this flag is not set, we use GPTQ') 110 | parser.add_argument('--w_clip', action=argparse.BooleanOptionalAction, default=False, 111 | help='''Clipping the weight quantization! 112 | We do not support arguments for clipping and we find the best clip ratio during the weight quantization''') 113 | parser.add_argument('--nsamples', type=int, default=128, 114 | help='Number of calibration data samples for GPTQ.') 115 | parser.add_argument('--cal_dataset', type=str, default='wikitext2', 116 | help='calibration data samples for GPTQ.', choices=supported_datasets) 117 | parser.add_argument('--percdamp', type=float, default=.01, 118 | help='Percent of the average Hessian diagonal to use for dampening.') 119 | parser.add_argument('--act_order', action=argparse.BooleanOptionalAction, default=False, 120 | help='act-order in GPTQ') 121 | 122 | args = parser.parse_args() 123 | 124 | args.w_bits = 4 125 | main(args) 126 | -------------------------------------------------------------------------------- /e2e/checkpoint_utils/rotation_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import typing 4 | import transformers 5 | import tqdm, math 6 | from quarot.functional import random_hadamard_matrix, apply_exact_had_to_linear 7 | 8 | def fuse_ln_linear(layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]) -> None: 9 | """ 10 | fuse the linear operations in Layernorm into the adjacent linear blocks. 11 | """ 12 | for linear in linear_layers: 13 | linear_dtype = linear.weight.dtype 14 | 15 | # Calculating new weight and bias 16 | W_ = linear.weight.data.double() 17 | linear.weight.data = (W_ * layernorm.weight.double()).to(linear_dtype) 18 | 19 | if hasattr(layernorm, 'bias'): 20 | if linear.bias is None: 21 | linear.bias = torch.nn.Parameter(torch.zeros(linear.out_features, dtype=torch.float64)) 22 | linear.bias.data = linear.bias.data.double() + torch.matmul(W_, layernorm.bias.double()) 23 | linear.bias.data = linear.bias.data.to(linear_dtype) 24 | 25 | def fuse_layer_norms(model): 26 | 27 | # Embedding fusion 28 | W = model.model.embed_tokens 29 | W_ = W.weight.data.double() 30 | W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) 31 | 32 | layers = model.model.layers 33 | 34 | # Fuse the linear operations in Layernorm into the adjacent linear blocks. 35 | for layer in layers: 36 | # fuse the input layernorms into the linear layers 37 | fuse_ln_linear(layer.post_attention_layernorm, [layer.mlp.up_proj, layer.mlp.gate_proj]) 38 | fuse_ln_linear(layer.input_layernorm, [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]) 39 | 40 | 41 | fuse_ln_linear(model.model.norm, [model.lm_head]) 42 | 43 | 44 | 45 | def rotate_embeddings(model, Q: torch.Tensor) -> None: 46 | # Rotate the embeddings. 47 | W = model.model.embed_tokens 48 | dtype = W.weight.data.dtype 49 | W_ = W.weight.data.to(dtype=torch.float64) 50 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype) 51 | 52 | 53 | def rotate_attention_inputs(layer, Q) -> None: 54 | # Rotate the WQ, WK and WV matrices of the self-attention layer. 
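    # Right-multiplying each projection weight by Q rotates its input dimension; this cancels the Q applied to the embeddings/residual stream, so the attention output is unchanged.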
55 | for W in [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj]: 56 | dtype = W.weight.dtype 57 | W_ = W.weight.to(dtype=torch.float64) 58 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype) 59 | 60 | def rotate_attention_output(layer, Q) -> None: 61 | # Rotate output matrix of the self-attention layer. 62 | W = layer.self_attn.o_proj 63 | dtype = W.weight.data.dtype 64 | W_ = W.weight.data.to(dtype=torch.float64) 65 | W.weight.data = torch.matmul(Q.T, W_).to(device="cpu", dtype=dtype) 66 | if W.bias is not None: 67 | b = W.bias.data.to(dtype=torch.float64) 68 | W.bias.data = torch.matmul(Q.T, b).to(device="cpu", dtype=dtype) 69 | 70 | def rotate_mlp_input(layer, Q): 71 | # Rotate the MLP input weights. 72 | mlp_inputs = [layer.mlp.up_proj, layer.mlp.gate_proj] 73 | for W in mlp_inputs: 74 | dtype = W.weight.dtype 75 | W_ = W.weight.data.to(dtype=torch.float64) 76 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype) 77 | 78 | def rotate_mlp_output(layer, Q): 79 | # Rotate the MLP output weights and bias. 80 | W = layer.mlp.down_proj 81 | dtype = W.weight.data.dtype 82 | W_ = W.weight.data.to(dtype=torch.float64) 83 | W.weight.data = torch.matmul(Q.T, W_).to(device="cpu", dtype=dtype) 84 | apply_exact_had_to_linear(W, had_dim=-1, output=False) #apply exact (inverse) hadamard on the weights of mlp output 85 | if W.bias is not None: 86 | b = W.bias.data.to(dtype=torch.float64) 87 | W.bias.data = torch.matmul(Q.T, b).to(device="cpu", dtype=dtype) 88 | 89 | def rotate_head(model, Q: torch.Tensor) -> None: 90 | # Rotate the head. 91 | W = model.lm_head 92 | dtype = W.weight.data.dtype 93 | W_ = W.weight.data.to(dtype=torch.float64) 94 | W.weight.data = torch.matmul(W_, Q).to(device="cpu", dtype=dtype) 95 | 96 | def rotate_ov_proj(layer, head_num, head_dim): 97 | v_proj = layer.self_attn.v_proj 98 | o_proj = layer.self_attn.o_proj 99 | 100 | apply_exact_had_to_linear(v_proj, had_dim=head_dim, output=True) 101 | apply_exact_had_to_linear(o_proj, had_dim=-1, output=False) 102 | 103 | 104 | @torch.inference_mode() 105 | def rotate_model(model): 106 | Q = random_hadamard_matrix(model.config.hidden_size, model.device) 107 | config = model.config 108 | num_heads = config.num_attention_heads 109 | model_dim = config.hidden_size 110 | head_dim = model_dim // num_heads 111 | 112 | 113 | rotate_embeddings(model, Q) 114 | rotate_head(model, Q) 115 | gc.collect() 116 | torch.cuda.empty_cache() 117 | layers = model.model.layers 118 | for idx, layer in enumerate(tqdm.tqdm(layers, unit="layer", desc="Rotating")): 119 | rotate_attention_inputs(layers[idx], Q) 120 | rotate_attention_output(layers[idx], Q) 121 | rotate_mlp_input(layers[idx], Q) 122 | rotate_mlp_output(layers[idx], Q) 123 | rotate_ov_proj(layers[idx], num_heads, head_dim) 124 | -------------------------------------------------------------------------------- /e2e/quantized_llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/e2e/quantized_llama/__init__.py -------------------------------------------------------------------------------- /e2e/quantized_llama/modeling_llama.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import quarot 3 | import quarot.transformers 4 | import torch 5 | from transformers import LlamaConfig 6 | from transformers.models.llama.modeling_llama import LlamaAttention, \ 7 | 
LlamaFlashAttention2, LlamaForCausalLM, apply_rotary_pos_emb, LlamaMLP 8 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 9 | from typing import Optional, Tuple 10 | from transformers import Cache 11 | 12 | 13 | ALL_LAYERNORM_LAYERS.append(quarot.nn.RMSNorm) 14 | 15 | class QuarotLlamaConfig(LlamaConfig): 16 | model_type = "llama_quarot" 17 | 18 | class QuarotFP16LlamaAttention(LlamaFlashAttention2): 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.quantizer = torch.nn.Identity() 23 | self.o_proj_hadamard = torch.nn.Identity() 24 | 25 | def forward( 26 | self, 27 | hidden_states: torch.Tensor, 28 | attention_mask: Optional[torch.LongTensor] = None, 29 | position_ids: Optional[torch.LongTensor] = None, 30 | past_key_value: Optional[Cache] = None, 31 | output_attentions: bool = False, 32 | use_cache: bool = False, 33 | cache_position: Optional[torch.LongTensor] = None, 34 | **kwargs, 35 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 36 | output_attentions = False 37 | 38 | bsz, q_len, _ = hidden_states.size() 39 | 40 | hidden_states = self.quantizer(hidden_states) 41 | 42 | query_states = self.q_proj(hidden_states) 43 | key_states = self.k_proj(hidden_states) 44 | value_states = self.v_proj(hidden_states) 45 | 46 | # Flash attention requires the input to have the shape 47 | # batch_size x seq_length x head_dim x hidden_dim 48 | # therefore we just need to keep the original shape 49 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) 50 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) 51 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) 52 | 53 | kv_seq_len = key_states.shape[1] 54 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 56 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2) 57 | 58 | past_key_value = getattr(self, "past_key_value", past_key_value) 59 | assert past_key_value is not None 60 | # sin and cos are specific to RoPE models; position_ids needed for the static cache 61 | 62 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "attention_mask": attention_mask} 63 | cache_out = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 64 | 65 | 66 | dropout_rate = self.attention_dropout if self.training else 0.0 67 | 68 | assert self.is_causal 69 | 70 | if isinstance(cache_out, tuple): 71 | key_states, value_states = cache_out 72 | attn_output = self._flash_attention_forward( 73 | query_states, 74 | key_states, 75 | value_states, 76 | query_length=q_len, 77 | attention_mask=attention_mask 78 | ) 79 | else: 80 | attn_output = cache_out(query_states) 81 | 82 | attn_output = self.o_proj_hadamard(attn_output.transpose(-1, -2)).transpose(-1, -2) 83 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 84 | attn_output = self.o_proj(attn_output) 85 | 86 | if not output_attentions: 87 | attn_weights = None 88 | 89 | return attn_output, attn_weights, past_key_value 90 | 91 | class QuarotLlamaAttention(QuarotFP16LlamaAttention): 92 | 93 | def __init__(self, *args, **kwargs): 94 | super().__init__(*args, **kwargs) 95 | self.quantizer = quarot.nn.Quantizer() 96 | self.q_proj = quarot.nn.Linear4bit.from_float(self.q_proj) 97 | self.k_proj = 
quarot.nn.Linear4bit.from_float(self.k_proj) 98 | self.v_proj = quarot.nn.Linear4bit.from_float(self.v_proj) 99 | self.o_proj_hadamard = quarot.nn.OnlineHadamard(self.num_heads) 100 | self.o_proj = torch.nn.Sequential( 101 | quarot.nn.Quantizer(), 102 | quarot.nn.Linear4bit.from_float(self.o_proj) 103 | ) 104 | 105 | class QuarotLlamaMLP(LlamaMLP): 106 | def __init__(self, *args, **kwargs): 107 | super().__init__(*args, **kwargs) 108 | self.quantizer = quarot.nn.Quantizer() 109 | self.up_proj = quarot.nn.Linear4bit.from_float(self.up_proj) 110 | self.gate_proj = quarot.nn.Linear4bit.from_float(self.gate_proj) 111 | self.down_proj = torch.nn.Sequential( 112 | quarot.nn.OnlineHadamard(self.intermediate_size), 113 | quarot.nn.Quantizer(), 114 | quarot.nn.Linear4bit.from_float(self.down_proj) 115 | ) 116 | 117 | def forward(self, x): 118 | x = self.quantizer(x) 119 | return super().forward(x) 120 | 121 | 122 | class QuarotFP16LlamaForCausalLM(LlamaForCausalLM): 123 | def __init__(self, config): 124 | super().__init__(config) 125 | assert config._attn_implementation == "flash_attention_2" 126 | for layer_idx, layer in enumerate(self.model.layers): 127 | layer.self_attn = QuarotFP16LlamaAttention(config=config, layer_idx=layer_idx) 128 | self.cache_dtype = "float16" 129 | self._expected_max_length = None 130 | 131 | 132 | def build_cache(self, batch_size, page_size, max_length): 133 | device = self.model.layers[0].self_attn.v_proj.weight.device 134 | dtype = self.cache_dtype or self.model.layers[0].self_attn.v_proj.weight.dtype 135 | 136 | num_heads = self.config.num_key_value_heads 137 | model_dim = self.config.hidden_size 138 | head_dim = model_dim // num_heads 139 | disable_quant = self.cache_dtype == "float16" 140 | return quarot.transformers.MultiLayerPagedKVCache4Bit( 141 | batch_size=batch_size, 142 | page_size=page_size, 143 | max_seq_len=max_length, 144 | device=device, 145 | n_layers=len(self.model.layers), 146 | num_heads=num_heads, 147 | head_dim=head_dim, 148 | disable_quant=disable_quant, 149 | hadamard_dtype=None if disable_quant else torch.float16 150 | ) 151 | 152 | def _get_logits_processor(self, generation_config, *args, **kwargs): 153 | # This is a hack to get the max length from generation_config. 154 | # Doing it here because max_length might not be set before this 155 | # method is called. 156 | self._expected_max_length = generation_config.max_length # This value will be reset at the next forward call 157 | return super()._get_logits_processor(generation_config, *args, **kwargs) 158 | 159 | 160 | def forward(self, input_ids, *args, past_key_values=None, **kwargs): 161 | if past_key_values is None: 162 | max_length = self._expected_max_length or input_ids.shape[1] 163 | self._expected_max_length = None # Reset this value. 164 | past_key_values = self.build_cache( 165 | input_ids.shape[0], 166 | page_size=max_length, # For now working with single page per batch. 
167 | max_length=max_length) 168 | out = super().forward(input_ids, *args, past_key_values=past_key_values, **kwargs) 169 | return out 170 | 171 | 172 | 173 | class QuarotLlamaForCausalLM(QuarotFP16LlamaForCausalLM): 174 | def __init__(self, config): 175 | super().__init__(config) 176 | assert config._attn_implementation == "flash_attention_2" 177 | self.norm = quarot.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 178 | for layer_idx, layer in enumerate(self.model.layers): 179 | layer.self_attn = QuarotLlamaAttention(config=config, layer_idx=layer_idx) 180 | layer.input_layernorm = quarot.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 181 | layer.post_attention_layernorm = quarot.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 182 | layer.mlp = QuarotLlamaMLP(config=config) 183 | self.cache_dtype = "int4" 184 | -------------------------------------------------------------------------------- /fake_quant/README.md: -------------------------------------------------------------------------------- 1 | # Fake Quantization in QuaRot 2 | 3 | 4 | In this directory, we provide the torch scripts for the experiments in QuaRot. 5 | 6 | 7 | ## Language Generation and Zero-Shot Evaluations 8 | 9 | Currently, we only support **LLaMa-2** models. You can simply run the `main.py` to reproduce the results in the paper. The most important arguments are: 10 | 11 | - `--model`: the model name (or path to the weights) 12 | - `--bsz`: the batch size for PPL evaluation 13 | - `--rotate`: whether we want to rotate the model 14 | - `--lm_eval`: whether we want to run LM-Eval for Zero-Shot tasks 15 | - `--tasks`: the tasks for LM-Eval 16 | - `--cal_dataset`: the calibration dataset for GPTQ quantization 17 | - `--a_bits`: the number of bits for activation quantization 18 | - `--w_bits`: the number of bits for weight quantization 19 | - `--v_bits`: the number of bits for value quantization 20 | - `--k_bits`: the number of bits for key quantization 21 | - `--w_clip`: Whether we want to clip the weights 22 | - `--a_clip_ratio`: The ratio of clipping for activation 23 | - `--k_clip_ratio`: The ratio of clipping for key 24 | - `--v_clip_ratio`: The ratio of clipping for value 25 | - `--w_asym`: Whether we want to use asymmetric quantization for weights 26 | - `--a_asym`: Whether we want to use asymmetric quantization for activation 27 | - `--v_asym`: Whether we want to use asymmetric quantization for value 28 | - `--k_asym`: Whether we want to use asymmetric quantization for key 29 | - `--a_groupsize`: The group size for activation quantization 30 | - `--w_groupsize`: The group size for weight quantization 31 | - `--v_groupsize`: The group size for value quantization 32 | - `--k_groupsize`: The group size for key quantization 33 | 34 | For example, to run the perplexity of `LLaMA2-7B` model with quantizing all weights and activations, you can run the following command: 35 | 36 | ```bash 37 | /bin/python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip 38 | ``` 39 | -------------------------------------------------------------------------------- /fake_quant/data_utils.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import random 3 | import transformers 4 | 5 | def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False): 6 | 7 | if hf_token is None: 8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 9 | else: 10 | tokenizer = 
transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 11 | 12 | if eval_mode: 13 | testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 15 | return testenc 16 | else: 17 | traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 18 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 19 | random.seed(seed) 20 | trainloader = [] 21 | for _ in range(nsamples): 22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 23 | j = i + seqlen 24 | inp = trainenc.input_ids[:, i:j] 25 | tar = inp.clone() 26 | tar[:, :-1] = -100 27 | trainloader.append((inp, tar)) 28 | return trainloader 29 | 30 | def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False): 31 | 32 | if hf_token is None: 33 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 34 | else: 35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 36 | 37 | if eval_mode: 38 | valdata = datasets.load_dataset( 39 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 40 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 41 | valenc = valenc.input_ids[:, :(256 * seqlen)] 42 | class TokenizerWrapper: 43 | def __init__(self, input_ids): 44 | self.input_ids = input_ids 45 | valenc = TokenizerWrapper(valenc) 46 | return valenc 47 | else: 48 | traindata = datasets.load_dataset( 49 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 50 | 51 | random.seed(seed) 52 | trainloader = [] 53 | for _ in range(nsamples): 54 | while True: 55 | i = random.randint(0, len(traindata) - 1) 56 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 57 | if trainenc.input_ids.shape[1] >= seqlen: 58 | break 59 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 60 | j = i + seqlen 61 | inp = trainenc.input_ids[:, i:j] 62 | tar = inp.clone() 63 | tar[:, :-1] = -100 64 | trainloader.append((inp, tar)) 65 | return trainloader 66 | 67 | 68 | 69 | 70 | def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False): 71 | 72 | 73 | if hf_token is None: 74 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) 75 | else: 76 | tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) 77 | 78 | if eval_mode: 79 | testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test') 80 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 81 | return testenc 82 | else: 83 | traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train') 84 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 85 | random.seed(seed) 86 | trainloader = [] 87 | for _ in range(nsamples): 88 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 89 | j = i + seqlen 90 | inp = trainenc.input_ids[:, i:j] 91 | tar = inp.clone() 92 | tar[:, :-1] = -100 93 | trainloader.append((inp, tar)) 94 | return trainloader 95 | 96 | 97 | def get_loaders( 98 | name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False 99 | ): 100 | if 'wikitext2' in name: 101 | return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode) 102 | if 'ptb' in name: 103 | return get_ptb_new(nsamples, seed, 
seqlen, model, hf_token, eval_mode) 104 | if 'c4' in name: 105 | return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode) -------------------------------------------------------------------------------- /fake_quant/eval_utils.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import model_utils 3 | import quant_utils 4 | import torch 5 | import os 6 | import logging 7 | from tqdm import tqdm 8 | 9 | 10 | @torch.no_grad() 11 | def evaluator(model, testenc, dev, args): 12 | 13 | model.eval() 14 | 15 | if 'opt' in args.model: 16 | opt_type = True 17 | llama_type = False 18 | elif 'meta' in args.model: 19 | llama_type = True 20 | opt_type = False 21 | else: 22 | raise ValueError(f'Unknown model {args.model}') 23 | 24 | 25 | use_cache = model.config.use_cache 26 | model.config.use_cache = False 27 | 28 | if opt_type: 29 | layers = model.model.decoder.layers 30 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 31 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 32 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 33 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 34 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 35 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 36 | 37 | elif llama_type: 38 | layers = model.model.layers 39 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 40 | 41 | layers[0] = layers[0].to(dev) 42 | 43 | # Convert the whole text of evaluation dataset into batches of sequences. 44 | input_ids = testenc.input_ids # (1, text_len) 45 | nsamples = input_ids.numel() // model.seqlen # The tail is truncated. 46 | input_ids = input_ids[:, :nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev) # (nsamples, seqlen) 47 | 48 | batch_size = args.bsz 49 | input_ids = [input_ids[i:i + batch_size] for i in range(0, nsamples, batch_size)] 50 | nbatches = len(input_ids) 51 | 52 | dtype = next(iter(model.parameters())).dtype 53 | # The input of the first decoder layer. 
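Further down, each batch of hidden states is pushed through the remaining decoder layers and the dataset perplexity is reported as the exponential of the mean shifted negative log-likelihood. A compact standalone sketch of that final bookkeeping, with hypothetical `logits`/`input_ids` shapes mirroring the shift-by-one labelling used below:

```python
import torch

def sequence_ppl(logits: torch.Tensor, input_ids: torch.Tensor) -> float:
    """Perplexity of next-token predictions.

    logits:    (batch, seqlen, vocab)
    input_ids: (batch, seqlen), used as shifted targets
    """
    shift_logits = logits[:, :-1, :]              # position t predicts token t + 1
    shift_labels = input_ids[:, 1:]
    nll = torch.nn.functional.cross_entropy(
        shift_logits.permute(0, 2, 1),            # (batch, vocab, seqlen - 1)
        shift_labels,
        reduction="none",
    )                                             # per-token losses
    return torch.exp(nll.mean()).item()

logits = torch.randn(2, 8, 32000)                 # toy example with random logits
input_ids = torch.randint(0, 32000, (2, 8))
print(sequence_ppl(logits, input_ids))            # on the order of the vocab size
```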
54 | inps = torch.zeros( 55 | (nbatches, batch_size, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 56 | ) 57 | inps = [0] * nbatches 58 | cache = {'i': 0, 'attention_mask': None} 59 | class Catcher(torch.nn.Module): 60 | def __init__(self, module): 61 | super().__init__() 62 | self.module = module 63 | def forward(self, inp, **kwargs): 64 | inps[cache['i']] = inp 65 | cache['i'] += 1 66 | cache['attention_mask'] = kwargs['attention_mask'] 67 | if llama_type: 68 | cache['position_ids'] = kwargs['position_ids'] 69 | raise ValueError 70 | layers[0] = Catcher(layers[0]) 71 | 72 | for i in range(nbatches): 73 | batch = input_ids[i] 74 | try: 75 | model(batch) 76 | except ValueError: 77 | pass 78 | layers[0] = layers[0].module 79 | layers[0] = layers[0].cpu() 80 | 81 | if opt_type: 82 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 83 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 84 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 85 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 86 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 87 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 88 | elif llama_type: 89 | model.model.embed_tokens = model.model.embed_tokens.cpu() 90 | position_ids = cache['position_ids'] 91 | 92 | torch.cuda.empty_cache() 93 | outs = [0] * nbatches 94 | attention_mask = cache['attention_mask'] 95 | 96 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"): 97 | layer = layers[i].to(dev) 98 | 99 | # Dump the layer input and output 100 | if args.capture_layer_io and args.layer_idx == i: 101 | captured_io = model_utils.capture_layer_io(model_utils.get_model_type(model), layer, inps) 102 | save_path = model_utils.get_layer_io_save_path(args) 103 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 104 | torch.save(captured_io, save_path) 105 | logging.info(f'Dumped layer input and output to: {save_path}') 106 | 107 | for j in range(nbatches): 108 | if opt_type: 109 | outs[j] = layer(inps[j], attention_mask=attention_mask)[0] 110 | elif llama_type: 111 | outs[j] = layer(inps[j], attention_mask=attention_mask, position_ids=position_ids)[0] 112 | layers[i] = layer.cpu() 113 | del layer 114 | torch.cuda.empty_cache() 115 | inps, outs = outs, inps 116 | 117 | if opt_type: 118 | if model.model.decoder.final_layer_norm is not None: 119 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 120 | if model.model.decoder.project_out is not None: 121 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 122 | 123 | elif llama_type: 124 | if model.model.norm is not None: 125 | model.model.norm = model.model.norm.to(dev) 126 | 127 | model.lm_head = model.lm_head.to(dev) 128 | nlls = [] 129 | loss_fct = torch.nn.CrossEntropyLoss(reduction = "none") 130 | for i in range(nbatches): 131 | hidden_states = inps[i] 132 | if opt_type: 133 | if model.model.decoder.final_layer_norm is not None: 134 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 135 | if model.model.decoder.project_out is not None: 136 | hidden_states = model.model.decoder.project_out(hidden_states) 137 | elif llama_type: 138 | if model.model.norm is not None: 139 | hidden_states = model.model.norm(hidden_states) 140 | lm_logits = model.lm_head(hidden_states) 141 | shift_logits = lm_logits[:, :-1, :] 142 | shift_labels = input_ids[i][:, 1:] 143 | loss = 
loss_fct(shift_logits.permute(0, 2, 1), shift_labels) 144 | neg_log_likelihood = loss.float().mean(dim=1) 145 | nlls.append(neg_log_likelihood) 146 | nlls_tensor = torch.cat(nlls) 147 | ppl = torch.exp(nlls_tensor.mean()) 148 | model.config.use_cache = use_cache 149 | logging.info(f'\n{args.eval_dataset.upper()} PPL: {ppl.item():.3f}') 150 | return ppl.item() 151 | -------------------------------------------------------------------------------- /fake_quant/gptq_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | import quant_utils 8 | import logging 9 | 10 | torch.backends.cuda.matmul.allow_tf32 = False 11 | torch.backends.cudnn.allow_tf32 = False 12 | 13 | 14 | class GPTQ: 15 | 16 | def __init__(self, layer): 17 | self.layer = layer 18 | self.dev = self.layer.weight.device 19 | W = layer.weight.data.clone() 20 | self.rows = W.shape[0] 21 | self.columns = W.shape[1] 22 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 23 | self.nsamples = 0 24 | 25 | def add_batch(self, inp, out): 26 | 27 | if len(inp.shape) == 2: 28 | inp = inp.unsqueeze(0) 29 | tmp = inp.shape[0] 30 | if len(inp.shape) == 3: 31 | inp = inp.reshape((-1, inp.shape[-1])) 32 | inp = inp.t() 33 | self.H *= self.nsamples / (self.nsamples + tmp) 34 | self.nsamples += tmp 35 | # inp = inp.float() 36 | inp = math.sqrt(2 / self.nsamples) * inp.float() 37 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 38 | self.H += inp.matmul(inp.t()) 39 | 40 | def fasterquant( 41 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 42 | ): 43 | W = self.layer.weight.data.clone() 44 | W = W.float() 45 | 46 | tick = time.time() 47 | 48 | if not self.quantizer.ready(): 49 | self.quantizer.find_params(W) 50 | 51 | H = self.H 52 | del self.H 53 | dead = torch.diag(H) == 0 54 | H[dead, dead] = 1 55 | W[:, dead] = 0 56 | 57 | if static_groups: 58 | import copy 59 | groups = [] 60 | for i in range(0, self.columns, groupsize): 61 | quantizer = copy.deepcopy(self.quantizer) 62 | quantizer.find_params(W[:, i:(i + groupsize)]) 63 | groups.append(quantizer) 64 | 65 | if actorder: 66 | perm = torch.argsort(torch.diag(H), descending=True) 67 | W = W[:, perm] 68 | H = H[perm][:, perm] 69 | invperm = torch.argsort(perm) 70 | 71 | Losses = torch.zeros_like(W) 72 | Q = torch.zeros_like(W) 73 | 74 | damp = percdamp * torch.mean(torch.diag(H)) 75 | diag = torch.arange(self.columns, device=self.dev) 76 | H[diag, diag] += damp 77 | H = torch.linalg.cholesky(H) 78 | H = torch.cholesky_inverse(H) 79 | H = torch.linalg.cholesky(H, upper=True) 80 | Hinv = H 81 | 82 | for i1 in range(0, self.columns, blocksize): 83 | i2 = min(i1 + blocksize, self.columns) 84 | count = i2 - i1 85 | 86 | W1 = W[:, i1:i2].clone() 87 | Q1 = torch.zeros_like(W1) 88 | Err1 = torch.zeros_like(W1) 89 | Losses1 = torch.zeros_like(W1) 90 | Hinv1 = Hinv[i1:i2, i1:i2] 91 | 92 | for i in range(count): 93 | w = W1[:, i] 94 | d = Hinv1[i, i] 95 | 96 | if groupsize != -1: 97 | if not static_groups: 98 | if (i1 + i) % groupsize == 0: 99 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)]) 100 | else: 101 | idx = i1 + i 102 | if actorder: 103 | idx = perm[idx] 104 | self.quantizer = groups[idx // groupsize] 105 | 106 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten() 107 | Q1[:, i] = q 108 | Losses1[:, i] = (w - q) ** 2 / d ** 2 109 | 110 | err1 = (w - q) / d 111 | W1[:, i:] -= 
err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 112 | Err1[:, i] = err1 113 | 114 | Q[:, i1:i2] = Q1 115 | Losses[:, i1:i2] = Losses1 / 2 116 | 117 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 118 | 119 | torch.cuda.synchronize() 120 | 121 | if actorder: 122 | Q = Q[:, invperm] 123 | 124 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 125 | if torch.any(torch.isnan(self.layer.weight.data)): 126 | logging.warning('NaN in weights') 127 | import pprint 128 | pprint.pprint(self.quantizer.bits, self.quantizer.scale, self.quantizer.zero_point) 129 | raise ValueError('NaN in weights') 130 | 131 | def free(self): 132 | self.H = None 133 | self.Losses = None 134 | self.Trace = None 135 | torch.cuda.empty_cache() 136 | utils.cleanup_memory(verbos=False) 137 | 138 | 139 | @torch.no_grad() 140 | def gptq_fwrd(model, dataloader, dev, args): 141 | ''' 142 | From GPTQ repo 143 | TODO: Make this function general to support both OPT and LLaMA models 144 | ''' 145 | logging.info('-----GPTQ Quantization-----') 146 | 147 | use_cache = model.config.use_cache 148 | model.config.use_cache = False 149 | layers = model.model.layers 150 | 151 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 152 | model.model.norm = model.model.norm.to(dev) 153 | layers[0] = layers[0].to(dev) 154 | 155 | dtype = next(iter(model.parameters())).dtype 156 | inps = torch.zeros( 157 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 158 | ) 159 | cache = {'i': 0, 'attention_mask': None} 160 | 161 | class Catcher(nn.Module): 162 | def __init__(self, module): 163 | super().__init__() 164 | self.module = module 165 | def forward(self, inp, **kwargs): 166 | inps[cache['i']] = inp 167 | cache['i'] += 1 168 | cache['attention_mask'] = kwargs['attention_mask'] 169 | cache['position_ids'] = kwargs['position_ids'] 170 | raise ValueError 171 | layers[0] = Catcher(layers[0]) 172 | for batch in dataloader: 173 | try: 174 | model(batch[0].to(dev)) 175 | except ValueError: 176 | pass 177 | layers[0] = layers[0].module 178 | 179 | layers[0] = layers[0].cpu() 180 | model.model.embed_tokens = model.model.embed_tokens.cpu() 181 | model.model.norm = model.model.norm.cpu() 182 | torch.cuda.empty_cache() 183 | 184 | outs = torch.zeros_like(inps) 185 | attention_mask = cache['attention_mask'] 186 | position_ids = cache['position_ids'] 187 | 188 | quantizers = {} 189 | sequential = [ 190 | ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'], 191 | ['self_attn.o_proj.module'], 192 | ['mlp.up_proj.module', 'mlp.gate_proj.module'], 193 | ['mlp.down_proj.module'] 194 | ] 195 | for i in range(len(layers)): 196 | print(f'\nLayer {i}:', flush=True, end=' ') 197 | layer = layers[i].to(dev) 198 | full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear]) 199 | for names in sequential: 200 | subset = {n: full[n] for n in names} 201 | 202 | gptq = {} 203 | for name in subset: 204 | print(f'{name}', end=' ', flush=True) 205 | layer_weight_bits = args.w_bits 206 | layer_weight_sym = not(args.w_asym) 207 | if 'lm_head' in name: 208 | layer_weight_bits = 16 209 | continue 210 | if args.int8_down_proj and 'down_proj' in name: 211 | layer_weight_bits = 8 212 | gptq[name] = GPTQ(subset[name]) 213 | gptq[name].quantizer = quant_utils.WeightQuantizer() 214 | gptq[name].quantizer.configure( 215 | layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip 216 | ) 217 | 218 | def add_batch(name): 219 | def tmp(_, inp, out): 220 | 
gptq[name].add_batch(inp[0].data, out.data) 221 | return tmp 222 | handles = [] 223 | for name in subset: 224 | handles.append(subset[name].register_forward_hook(add_batch(name))) 225 | for j in range(args.nsamples): 226 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 227 | for h in handles: 228 | h.remove() 229 | 230 | for name in subset: 231 | layer_w_groupsize = args.w_groupsize 232 | gptq[name].fasterquant( 233 | percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order, static_groups=False 234 | ) 235 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 236 | gptq[name].free() 237 | 238 | for j in range(args.nsamples): 239 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 240 | 241 | layers[i] = layer.cpu() 242 | del layer 243 | del gptq 244 | torch.cuda.empty_cache() 245 | 246 | inps, outs = outs, inps 247 | 248 | model.config.use_cache = use_cache 249 | utils.cleanup_memory(verbos=True) 250 | logging.info('-----GPTQ Quantization Done-----\n') 251 | return quantizers 252 | 253 | 254 | 255 | 256 | @torch.no_grad() 257 | def rtn_fwrd(model, dev, args): 258 | ''' 259 | From GPTQ repo 260 | TODO: Make this function general to support both OPT and LLaMA models 261 | ''' 262 | assert args.w_groupsize ==-1, "Groupsize not supported in RTN!" 263 | layers = model.model.layers 264 | torch.cuda.empty_cache() 265 | 266 | quantizers = {} 267 | 268 | for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"): 269 | layer = layers[i].to(dev) 270 | 271 | subset = quant_utils.find_qlayers(layer, 272 | layers=[torch.nn.Linear]) 273 | 274 | for name in subset: 275 | layer_weight_bits = args.w_bits 276 | if 'lm_head' in name: 277 | layer_weight_bits = 16 278 | continue 279 | if args.int8_down_proj and 'down_proj' in name: 280 | layer_weight_bits = 8 281 | 282 | quantizer = quant_utils.WeightQuantizer() 283 | quantizer.configure( 284 | layer_weight_bits, perchannel=True, sym=not(args.w_asym), mse=args.w_clip 285 | ) 286 | W = subset[name].weight.data 287 | quantizer.find_params(W) 288 | subset[name].weight.data = quantizer.quantize(W).to( 289 | next(iter(layer.parameters())).dtype) 290 | quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu() 291 | layers[i] = layer.cpu() 292 | torch.cuda.empty_cache() 293 | del layer 294 | 295 | utils.cleanup_memory(verbos=True) 296 | return quantizers 297 | -------------------------------------------------------------------------------- /fake_quant/main.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import torch 3 | import model_utils 4 | import data_utils 5 | import transformers 6 | import quant_utils 7 | import rotation_utils 8 | import gptq_utils 9 | import eval_utils 10 | import hadamard_utils 11 | 12 | def main(): 13 | args = utils.parser_gen() 14 | if args.wandb: 15 | import wandb 16 | wandb.init(project=args.wandb_project, entity=args.wandb_id) 17 | wandb.config.update(args) 18 | 19 | transformers.set_seed(args.seed) 20 | model = model_utils.get_model(args.model, args.hf_token) 21 | model.eval() 22 | 23 | 24 | # Rotate the weights 25 | if args.rotate: 26 | rotation_utils.fuse_layer_norms(model) 27 | rotation_utils.rotate_model(model, args) 28 | utils.cleanup_memory(verbos=True) 29 | 30 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model 31 | qlayers = quant_utils.find_qlayers(model) 32 | for name in qlayers: 33 | if 'down_proj' in 
name: 34 | had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size) 35 | qlayers[name].online_full_had = True 36 | qlayers[name].had_K = had_K 37 | qlayers[name].K = K 38 | qlayers[name].fp32_had = args.fp32_had 39 | if 'o_proj' in name: 40 | had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads) 41 | qlayers[name].online_partial_had = True 42 | qlayers[name].had_K = had_K 43 | qlayers[name].K = K 44 | qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads 45 | qlayers[name].fp32_had = args.fp32_had 46 | else: 47 | quant_utils.add_actquant(model) #Add Activation Wrapper to the model as the rest of the code assumes it is present 48 | 49 | 50 | if args.w_bits < 16: 51 | save_dict = {} 52 | if args.load_qmodel_path: # Load Quantized Rotated Model 53 | assert args.rotate, "Model should be rotated to load a quantized model!" 54 | assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!" 55 | print("Load quantized model from ", args.load_qmodel_path) 56 | save_dict = torch.load(args.load_qmodel_path) 57 | model.load_state_dict(save_dict["model"]) 58 | 59 | elif not args.w_rtn: # GPTQ Weight Quantization 60 | assert "llama" in args.model, "Only llama is supported for GPTQ!" 61 | 62 | trainloader = data_utils.get_loaders( 63 | args.cal_dataset, nsamples=args.nsamples, 64 | seed=args.seed, model=args.model, 65 | seqlen=model.seqlen, eval_mode=False 66 | ) 67 | quantizers = gptq_utils.gptq_fwrd(model, trainloader, utils.DEV, args) 68 | save_dict["w_quantizers"] = quantizers 69 | else: # RTN Weight Quantization 70 | quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args) 71 | save_dict["w_quantizers"] = quantizers 72 | 73 | if args.save_qmodel_path: 74 | save_dict["model"] = model.state_dict() 75 | torch.save(save_dict, args.save_qmodel_path) 76 | 77 | 78 | # Add Input Quantization 79 | if args.a_bits < 16 or args.v_bits < 16: 80 | qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper]) 81 | down_proj_groupsize = -1 82 | if args.a_groupsize > 0 and "llama" in args.model: 83 | down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize) 84 | 85 | for name in qlayers: 86 | layer_input_bits = args.a_bits 87 | layer_groupsize = args.a_groupsize 88 | layer_a_sym = not(args.a_asym) 89 | layer_a_clip = args.a_clip_ratio 90 | 91 | if 'v_proj' in name and args.v_bits < 16: #Set the v_proj precision 92 | qlayers[name].out_quantizer.configure(bits=args.v_bits, 93 | groupsize=args.v_groupsize, 94 | sym=not(args.v_asym), 95 | clip_ratio=args.v_clip_ratio) 96 | 97 | if 'lm_head' in name: #Skip lm_head quantization 98 | layer_input_bits = 16 99 | 100 | if 'down_proj' in name: #Set the down_proj precision 101 | if args.int8_down_proj: 102 | layer_input_bits = 8 103 | layer_groupsize = down_proj_groupsize 104 | 105 | 106 | qlayers[name].quantizer.configure(bits=layer_input_bits, 107 | groupsize=layer_groupsize, 108 | sym=layer_a_sym, 109 | clip_ratio=layer_a_clip) 110 | 111 | if args.k_bits < 16: 112 | if args.k_pre_rope: 113 | raise NotImplementedError("Pre-RoPE quantization is not supported yet!") 114 | else: 115 | rope_function_name = model_utils.get_rope_function_name(model) 116 | layers = model_utils.get_layers(model) 117 | k_quant_config = {'k_bits':args.k_bits, "k_groupsize": args.k_groupsize, 118 | "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio} 119 | for layer in layers: 120 | rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward( 121 | layer.self_attn, 
122 | rope_function_name, 123 | config=model.config, 124 | **k_quant_config) 125 | 126 | # Evaluating on dataset 127 | testloader = data_utils.get_loaders( 128 | args.eval_dataset, 129 | seed=args.seed, 130 | model=args.model, 131 | seqlen=model.seqlen, 132 | hf_token=args.hf_token, 133 | eval_mode=True 134 | ) 135 | 136 | 137 | dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args) 138 | if args.wandb: 139 | wandb.log({'ppl/{}'.format(args.eval_dataset.upper()): dataset_ppl}) 140 | 141 | if not args.lm_eval: 142 | return 143 | else: 144 | # Import lm_eval utils 145 | import lm_eval 146 | from lm_eval import utils as lm_eval_utils 147 | from lm_eval.api.registry import ALL_TASKS 148 | from lm_eval.models.huggingface import HFLM 149 | 150 | 151 | 152 | if args.distribute: 153 | utils.distribute_model(model) 154 | else: 155 | model.to(utils.DEV) 156 | 157 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token) 158 | hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size) 159 | 160 | task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS) 161 | results = lm_eval.simple_evaluate(hflm, tasks=task_names, batch_size=args.lm_eval_batch_size)['results'] 162 | 163 | metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()} 164 | metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals.values()), 4) 165 | print(metric_vals) 166 | 167 | if args.wandb: 168 | wandb.log(metric_vals) 169 | 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /fake_quant/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | import transformers 4 | import utils 5 | import os 6 | import logging 7 | 8 | OPT_MODEL = transformers.models.opt.modeling_opt.OPTForCausalLM 9 | OPT_LAYER = transformers.models.opt.modeling_opt.OPTDecoderLayer 10 | LLAMA_MODEL = transformers.models.llama.modeling_llama.LlamaForCausalLM 11 | LLAMA_LAYER = transformers.models.llama.modeling_llama.LlamaDecoderLayer 12 | 13 | 14 | def model_type_extractor(model): 15 | if isinstance(model, LLAMA_MODEL): 16 | return LLAMA_MODEL 17 | elif isinstance(model, OPT_MODEL): 18 | return OPT_MODEL 19 | else: 20 | raise ValueError(f'Unknown model type {model}') 21 | 22 | def skip(*args, **kwargs): 23 | # This is a helper function to save time during the initialization! 
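Replacing the `torch.nn.init` routines with a no-op before `from_pretrained` skips the random initialisation that the checkpoint weights would immediately overwrite. A minimal sketch of the same trick, wrapped here in a hypothetical context manager that also restores the original initialisers (the loader below simply leaves them patched):

```python
import contextlib
import torch

@contextlib.contextmanager
def skip_weight_init():
    """Temporarily turn the common initialisers into no-ops."""
    def _skip(*args, **kwargs):
        pass

    originals = (torch.nn.init.kaiming_uniform_,
                 torch.nn.init.uniform_,
                 torch.nn.init.normal_)
    torch.nn.init.kaiming_uniform_ = _skip
    torch.nn.init.uniform_ = _skip
    torch.nn.init.normal_ = _skip
    try:
        yield
    finally:
        (torch.nn.init.kaiming_uniform_,
         torch.nn.init.uniform_,
         torch.nn.init.normal_) = originals

with skip_weight_init():
    layer = torch.nn.Linear(4096, 4096)   # allocated, but left uninitialised
```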
24 | pass 25 | 26 | def get_rope_function_name(model): 27 | if isinstance(model, LLAMA_MODEL): 28 | return "apply_rotary_pos_emb" 29 | raise NotImplementedError 30 | 31 | 32 | def get_layers(model): 33 | if isinstance(model, OPT_MODEL): 34 | return model.model.decoder.layers 35 | if isinstance(model, LLAMA_MODEL): 36 | return model.model.layers 37 | raise NotImplementedError 38 | 39 | 40 | def get_llama(model_name, hf_token): 41 | torch.nn.init.kaiming_uniform_ = skip 42 | torch.nn.init.uniform_ = skip 43 | torch.nn.init.normal_ = skip 44 | model = transformers.LlamaForCausalLM.from_pretrained(model_name, torch_dtype='auto', 45 | use_auth_token=hf_token, 46 | low_cpu_mem_usage=True) 47 | model.seqlen = 2048 48 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen)) 49 | return model 50 | 51 | 52 | 53 | def get_opt(model_name): 54 | torch.nn.init.kaiming_uniform_ = skip 55 | torch.nn.init.uniform_ = skip 56 | torch.nn.init.normal_ = skip 57 | model = transformers.OPTForCausalLM.from_pretrained(model_name, torch_dtype='auto', 58 | low_cpu_mem_usage=True) 59 | model.seqlen = model.config.max_position_embeddings 60 | logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen)) 61 | return model 62 | 63 | 64 | def get_model( 65 | model_name, hf_token=None 66 | ): 67 | if 'llama' in model_name: 68 | return get_llama(model_name, hf_token) 69 | elif 'opt' in model_name: 70 | return get_opt(model_name) 71 | else: 72 | raise ValueError(f'Unknown model {model_name}') 73 | 74 | 75 | def get_model_type(model): 76 | if isinstance(model, OPT_MODEL): 77 | model_type = OPT_MODEL 78 | elif isinstance(model, LLAMA_MODEL): 79 | model_type = LLAMA_MODEL 80 | else: 81 | raise ValueError(f'Unknown model type {model}') 82 | return model_type 83 | 84 | def get_embeddings(model, model_type) -> list[torch.nn.Module]: 85 | if model_type == LLAMA_MODEL: 86 | return [model.model.embed_tokens] 87 | elif model_type == OPT_MODEL: 88 | return [model.model.decoder.embed_tokens, model.model.decoder.embed_positions] 89 | else: 90 | raise ValueError(f'Unknown model type {model_type}') 91 | 92 | 93 | def get_transformer_layers(model, model_type): 94 | if model_type == LLAMA_MODEL: 95 | return [layer for layer in model.model.layers] 96 | elif model_type == OPT_MODEL: 97 | return [layer for layer in model.model.decoder.layers] 98 | else: 99 | raise ValueError(f'Unknown model type {model_type}') 100 | 101 | 102 | def get_lm_head(model, model_type): 103 | if model_type == LLAMA_MODEL: 104 | return model.lm_head 105 | elif model_type == OPT_MODEL: 106 | return model.lm_head 107 | else: 108 | raise ValueError(f'Unknown model type {model_type}') 109 | 110 | def get_pre_head_layernorm(model, model_type): 111 | if model_type == LLAMA_MODEL: 112 | pre_head_layernorm = model.model.norm 113 | assert isinstance(pre_head_layernorm, 114 | transformers.models.llama.modeling_llama.LlamaRMSNorm) 115 | elif model_type == OPT_MODEL: 116 | pre_head_layernorm = model.model.decoder.final_layer_norm 117 | assert pre_head_layernorm is not None 118 | else: 119 | raise ValueError(f'Unknown model type {model_type}') 120 | return pre_head_layernorm 121 | 122 | def get_mlp_bottleneck_size(model): 123 | model_type = get_model_type(model) 124 | if model_type == LLAMA_MODEL: 125 | return model.config.intermediate_size 126 | elif model_type == OPT_MODEL: 127 | return model.config.ffn_dim 128 | else: 129 | raise ValueError(f'Unknown model type {model_type}') 130 | 131 | def replace_modules( 132 | root: 
torch.nn.Module, 133 | type_to_replace, 134 | new_module_factory, 135 | replace_layers: bool, 136 | ) -> None: 137 | """Replace modules of given type using the supplied module factory. 138 | 139 | Perform a depth-first search of a module hierarchy starting at root 140 | and replace all instances of type_to_replace with modules created by 141 | new_module_factory. Children of replaced modules are not processed. 142 | 143 | Args: 144 | root: the root of the module hierarchy where modules should be replaced 145 | type_to_replace: a type instances of which will be replaced 146 | new_module_factory: a function that given a module that should be replaced 147 | produces a module to replace it with. 148 | """ 149 | for name, module in root.named_children(): 150 | new_module = None 151 | if isinstance(module, type_to_replace): 152 | if replace_layers: # layernorm_fusion.replace_layers case where transformer layers are replaced 153 | new_module = new_module_factory(module, int(name)) 154 | else: # layernorm_fusion.fuse_modules case where layernorms are fused 155 | new_module = new_module_factory(module) 156 | elif len(list(module.children())) > 0: 157 | replace_modules(module, type_to_replace, new_module_factory, replace_layers) 158 | 159 | if new_module is not None: 160 | setattr(root, name, new_module) 161 | 162 | 163 | class RMSN(torch.nn.Module): 164 | """ 165 | This class implements the Root Mean Square Normalization (RMSN) layer. 166 | We use the implementation from LLAMARMSNorm here: 167 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75 168 | """ 169 | 170 | def __init__(self, mean_dim: int, eps=1e-5): 171 | super().__init__() 172 | self.eps = eps 173 | self.mean_dim = mean_dim 174 | self.weight = torch.nn.Parameter(torch.zeros(1)) 175 | 176 | def forward(self, x: torch.Tensor) -> torch.Tensor: 177 | input_dtype = x.dtype 178 | if x.dtype == torch.float16: 179 | x = x.to(torch.float32) 180 | variance = x.pow(2).sum(-1, keepdim=True) / self.mean_dim 181 | x = x * torch.rsqrt(variance + self.eps) 182 | return x.to(input_dtype) 183 | 184 | 185 | def get_layer_io_save_path(args): 186 | return os.path.join(args.save_path, 'layer_io', f'{args.layer_idx:03d}.pt') 187 | 188 | def capture_layer_io(model_type, layer, layer_input): 189 | def hook_factory(module_name, captured_vals, is_input): 190 | def hook(module, input, output): 191 | if is_input: 192 | captured_vals[module_name].append(input[0].detach().cpu()) 193 | else: 194 | captured_vals[module_name].append(output.detach().cpu()) 195 | return hook 196 | 197 | handles = [] 198 | 199 | if model_type == LLAMA_MODEL: 200 | captured_inputs = { 201 | 'k_proj': [], # q_proj, v_proj has the same input as k_proj 202 | 'o_proj': [], 203 | 'gate_proj': [], # up_proj has the same input as gate_proj 204 | 'down_proj': [] 205 | } 206 | 207 | captured_outputs = { 208 | 'v_proj': [], 209 | } 210 | 211 | for name in captured_inputs.keys(): 212 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None) 213 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True))) 214 | 215 | for name in captured_outputs.keys(): 216 | module = getattr(layer.self_attn, name, None) or getattr(layer.mlp, name, None) 217 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False))) 218 | 219 | elif model_type == OPT_MODEL: 220 | captured_inputs = { 221 | 'k_proj': [], # q_proj, v_proj has the same input as k_proj 222 | 'out_proj': [], 223 | 
'fc1': [], 224 | 'fc2': [] 225 | } 226 | captured_outputs = { 227 | 'v_proj': [], 228 | } 229 | for name in captured_inputs.keys(): 230 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer 231 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None) 232 | handles.append(module.register_forward_hook(hook_factory(name, captured_inputs, True))) 233 | 234 | for name in captured_outputs.keys(): 235 | # In OPT, fc1 and fc2 are directly contained in OPTDecoderLayer 236 | module = getattr(layer.self_attn, name, None) or getattr(layer, name, None) 237 | handles.append(module.register_forward_hook(hook_factory(name, captured_outputs, False))) 238 | else: 239 | raise ValueError(f'Unknown model type {model_type}') 240 | 241 | # Process each sequence in the batch one by one to avoid OOM. 242 | for seq_idx in range(layer_input.shape[0]): 243 | # Extract the current sequence across all dimensions. 244 | seq = layer_input[seq_idx:seq_idx + 1].to(utils.DEV) 245 | # Perform a forward pass for the current sequence. 246 | layer(seq) 247 | 248 | # After processing all sequences, concatenate the accumulated inputs for each sub-layer across the batch. 249 | for module_name in captured_inputs: 250 | captured_inputs[module_name] = torch.cat(captured_inputs[module_name], dim=0) 251 | for module_name in captured_outputs: 252 | captured_outputs[module_name] = torch.cat(captured_outputs[module_name], dim=0) 253 | 254 | # Cleanup. 255 | for h in handles: 256 | h.remove() 257 | 258 | return { 259 | 'input': captured_inputs, 260 | 'output': captured_outputs 261 | } 262 | -------------------------------------------------------------------------------- /fake_quant/monkeypatch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import types 4 | 5 | def copy_func_with_new_globals(f, globals=None): 6 | """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" 7 | if globals is None: 8 | globals = f.__globals__ 9 | g = types.FunctionType(f.__code__, globals, name=f.__name__, 10 | argdefs=f.__defaults__, closure=f.__closure__) 11 | g = functools.update_wrapper(g, f) 12 | g.__module__ = f.__module__ 13 | g.__kwdefaults__ = copy.copy(f.__kwdefaults__) 14 | return g 15 | 16 | def add_wrapper_after_function_call_in_method(module, method_name, function_name, wrapper_fn): 17 | ''' 18 | This function adds a wrapper after the output of a function call in the method named `method_name`. 19 | Only calls directly in the method are affected. Calls by other functions called in the method are not affected. 
20 | ''' 21 | 22 | original_method = getattr(module, method_name).__func__ 23 | method_globals = dict(original_method.__globals__) 24 | wrapper = wrapper_fn(method_globals[function_name]) 25 | method_globals[function_name] = wrapper 26 | new_method = copy_func_with_new_globals(original_method, globals=method_globals) 27 | setattr(module, method_name, new_method.__get__(module)) 28 | return wrapper 29 | 30 | -------------------------------------------------------------------------------- /img/carrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/img/carrot.png -------------------------------------------------------------------------------- /img/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spcl/QuaRot/5008669b08c1f11f9b64d52d16fddd47ca754c5a/img/fig1.png -------------------------------------------------------------------------------- /quarot/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import nn 3 | from . import functional 4 | 5 | 6 | import quarot._CUDA 7 | 8 | 9 | __all__ = [ 10 | "matmul", #int-4 matmul 11 | "sym_quant", "sym_dequant", "PackedQuantizedTensor", # Quantization 12 | ] 13 | 14 | class ShapeHandler: 15 | def __init__(self, x: torch.Tensor): 16 | self.size_excl_last = x.numel()//x.shape[-1] 17 | self.shape_excl_last = tuple(x.shape[:-1]) 18 | 19 | # Keep the last dim unchanged, flatten all previous dims 20 | def flatten(self, x: torch.Tensor): 21 | return x.view(self.size_excl_last, -1) 22 | 23 | # Recover back to the original shape. 24 | def unflatten(self, x: torch.Tensor): 25 | return x.view(self.shape_excl_last + (-1,)) 26 | 27 | def unflatten_scale(self, x: torch.Tensor): 28 | return x.view(self.shape_excl_last) 29 | 30 | 31 | def flatten_last_dim_and_return_shape(x: torch.Tensor): 32 | shape_excl_last = x.shape[:-1] 33 | x = x.view(-1, x.shape[-1]) 34 | return x, shape_excl_last 35 | 36 | 37 | def matmul(A, B): 38 | assert A.shape[-1] % 32 == 0, "A.shape[-1]: {} must be multiplication of 32".format(A.shape[-1]) 39 | A, A_shape_excl_last = flatten_last_dim_and_return_shape(A) 40 | B, B_shape_excl_last = flatten_last_dim_and_return_shape(B) 41 | return quarot._CUDA.matmul(A, B).view(*A_shape_excl_last, *B_shape_excl_last) 42 | 43 | def sym_quant(x, scale): 44 | assert x.dtype == scale.dtype == torch.float16 45 | x, x_shape_excl_last = flatten_last_dim_and_return_shape(x) 46 | return quarot._CUDA.sym_quant(x, scale.view(-1)).view(*x_shape_excl_last, -1) 47 | 48 | def sym_dequant(q, scale_row, scale_col, bits=32): 49 | assert q.dtype == torch.int32 50 | assert scale_row.dtype == scale_col.dtype == torch.float16 51 | q, q_shape_excl_last = flatten_last_dim_and_return_shape(q) 52 | return quarot._CUDA.sym_dequant(q, scale_row.view(-1), scale_col, bits).view(*q_shape_excl_last, -1) 53 | 54 | 55 | class PackedQuantizedTensor: 56 | def __init__(self, 57 | quantized_x: torch.Tensor, 58 | scales_x: torch.Tensor): 59 | self.quantized_x = quantized_x 60 | self.scales_x = scales_x 61 | 62 | def size(self): 63 | return self.quantized_x.size() 64 | 65 | @property 66 | def device(self): 67 | return self.quantized_x.device 68 | 69 | @property 70 | def dtype(self): 71 | return self.quantized_x.dtype 72 | -------------------------------------------------------------------------------- 
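The Python-side wrappers above are mostly shape bookkeeping: inputs are flattened to two dimensions for the CUDA kernels and the results are viewed back using the leading dimensions that were stripped off. A standalone sketch of that bookkeeping, with a plain fp32 `A @ B.T` standing in for `quarot._CUDA.matmul` (which, judging by the final `view`, likewise contracts over the shared last dimension):

```python
import torch

def flatten_last_dim_and_return_shape(x: torch.Tensor):
    """Collapse all leading dimensions, remembering them for later."""
    return x.reshape(-1, x.shape[-1]), x.shape[:-1]

def matmul_like(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Same shape handling as quarot.matmul, minus the packed int4 kernel.
    A2, A_shape = flatten_last_dim_and_return_shape(A)
    B2, B_shape = flatten_last_dim_and_return_shape(B)
    return (A2 @ B2.T).view(*A_shape, *B_shape)

A = torch.randn(2, 5, 64)       # (batch, seq, hidden)
B = torch.randn(128, 64)        # (out_features, hidden)
print(matmul_like(A, B).shape)  # torch.Size([2, 5, 128])
```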
/quarot/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantization import pack_i4, unpack_i4, asym_quant_dequant, sym_quant_dequant 2 | from .hadamard import ( 3 | matmul_hadU_cuda, 4 | random_hadamard_matrix, 5 | apply_exact_had_to_linear) 6 | -------------------------------------------------------------------------------- /quarot/functional/quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def two_compl(x, bits: int): 5 | return torch.where(x < 0, 2 ** bits + x, x) 6 | 7 | def get_minq_maxq(bits: int, sym: bool): 8 | if sym: 9 | maxq = torch.tensor(2**(bits-1)-1) 10 | minq = torch.tensor(-maxq -1) 11 | else: 12 | maxq = torch.tensor(2**bits - 1) 13 | minq = torch.tensor(0) 14 | 15 | return minq, maxq 16 | 17 | def asym_quant(x, scale, zero, maxq): 18 | scale = scale.to(x.device) 19 | zero = zero.to(x.device) 20 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 21 | return q, scale, zero 22 | 23 | def asym_dequant(q, scale, zero): 24 | return scale * (q - zero) 25 | 26 | def asym_quant_dequant(x, scale, zero, maxq): 27 | return asym_dequant(*asym_quant(x, scale, zero, maxq)) 28 | 29 | def sym_quant(x, scale, maxq): 30 | scale = scale.to(x.device) 31 | q = torch.clamp(torch.round(x / scale), -(maxq+1), maxq) 32 | return q, scale 33 | def sym_dequant(q, scale): 34 | return scale * q 35 | 36 | def sym_quant_dequant(x, scale, maxq): 37 | return sym_dequant(*sym_quant(x, scale, maxq)) 38 | 39 | 40 | 41 | # Pack the int tensor. Each uint8 stores two int4 value. 42 | def pack_i4(q): 43 | assert torch.is_signed(q), 'The tensor to be packed should be signed int' 44 | minq, maxq = get_minq_maxq(4, True) 45 | assert torch.all(torch.logical_and(q >= minq, q <= maxq)) 46 | 47 | q_i8 = two_compl(q.to(dtype=torch.int8), 4).to(torch.uint8) 48 | q_i4 = q_i8[:, 0::2] | (q_i8[:, 1::2] << 4) 49 | return q_i4 50 | 51 | # Unpack the quantized int4 tensor (stored in uint8) into int32 tensor. 
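`pack_i4` stores the even-indexed value of each pair in the low nibble and the odd-indexed value in the high nibble, after mapping each signed int4 to its 4-bit two's-complement pattern; `unpack_i4` below reverses this. A small worked example of the nibble arithmetic (plain Python, independent of the tensor code):

```python
def pack_pair(lo: int, hi: int) -> int:
    """Pack two signed int4 values (-8..7) into one byte: lo goes to the low nibble."""
    return (lo & 0xF) | ((hi & 0xF) << 4)     # & 0xF == 4-bit two's complement

def unpack_nibble(n: int) -> int:
    return n - 16 if n >= 8 else n            # back to a signed int4

b = pack_pair(-3, 5)                          # 0b0101_1101
assert b == 0x5D
assert (unpack_nibble(b & 0x0F), unpack_nibble(b >> 4)) == (-3, 5)
```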
52 | def unpack_i4(x: torch.Tensor): 53 | assert x.dtype == torch.uint8, 'The tensor to be unpacked should be stored in uint8' 54 | 55 | out_shape = list(x.shape) 56 | out_shape[-1] *= 2 # Each uint8 packs two numbers 57 | 58 | # Low 4 bits 59 | x0 = (x & 0x0f).to(torch.int8) 60 | x0[x0>=8] -= 16 61 | x0 = x0.view(-1, x0.shape[-1]) 62 | 63 | # High 4 bits 64 | x1 = ((x & 0xf0) >> 4).to(torch.int8) 65 | x1[x1>=8] -= 16 66 | x1 = x1.view(-1, x1.shape[-1]) 67 | 68 | out = torch.empty(out_shape, device=x.device, dtype=torch.int32) 69 | out = out.view(-1, out.shape[-1]) 70 | # Interleaving 71 | out[:, 0::2] = x0 72 | out[:, 1::2] = x1 73 | 74 | return out.view(out_shape) 75 | -------------------------------------------------------------------------------- /quarot/kernels/flashinfer.cu: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | template 10 | void FlashInferBatchDecodeKernel_i4(nv_half* o, nv_half* q, void* kv_data, 11 | nv_half2* kv_param, int32_t* kv_indptr, 12 | int32_t* kv_indicies, 13 | int32_t* last_page_offset, int num_layers, 14 | int layer_idx, int num_heads, int page_size, 15 | int batch_size) { 16 | using DTypeIn = flashinfer::quant::__precision__s4; 17 | using DTypeInQ = nv_half; 18 | using DTypeOut = nv_half; 19 | 20 | flashinfer::paged_kv_t paged_kv( 21 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size, 22 | (DTypeIn*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset); 23 | 24 | const float rope_scale = 1.f; 25 | const float rope_theta = 1e4; 26 | const float sm_scale = 1.f / std::sqrt(float(head_dim)); 27 | const float rope_inv_scale = 1.f / rope_scale; 28 | const float rope_inv_theta = 1.f / rope_theta; 29 | 30 | constexpr bool norm_on_the_fly = false; 31 | constexpr auto rotary_mode = flashinfer::RotaryMode::kNone; 32 | constexpr size_t FoldFactor = 2; 33 | constexpr size_t vec_size = std::max( 34 | static_cast(16 / flashinfer::quant::size_of_type() / 35 | FoldFactor), 36 | static_cast(head_dim / 32)); 37 | constexpr size_t bdx = head_dim / vec_size; 38 | constexpr size_t bdy = 128 / bdx; 39 | dim3 nblks(paged_kv.batch_size, paged_kv.num_heads); 40 | dim3 nthrs(bdx, bdy); 41 | 42 | flashinfer::BatchDecodeWithPagedKVCacheKernel< 43 | rotary_mode, norm_on_the_fly, vec_size, bdx, bdy, FoldFactor, DTypeInQ, 44 | DTypeIn, DTypeOut, int32_t><<>>( 45 | q, paged_kv, o, sm_scale, rope_inv_scale, rope_inv_theta); 46 | } 47 | 48 | template 49 | void FlashInferInitKvKernel_i4(void* kv_data, nv_half2* kv_param, 50 | int32_t* kv_indptr, int32_t* kv_indicies, 51 | int32_t* last_page_offset, void* key, 52 | void* value, nv_half2* key_param, 53 | nv_half2* value_param, int32_t* seqlen_indptr, 54 | int num_layers, int layer_idx, int num_heads, 55 | int page_size, int batch_size) { 56 | using T = flashinfer::quant::__precision__s4; 57 | flashinfer::paged_kv_t paged_kv( 58 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size, 59 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset); 60 | 61 | constexpr size_t vec_size = 62 | std::max(static_cast(16 / flashinfer::quant::size_of_type()), 63 | static_cast(head_dim / 32)); 64 | constexpr size_t bdx = head_dim / vec_size; 65 | constexpr size_t bdy = 128 / bdx; 66 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy)); 67 | dim3 nthrs(bdx, bdy); 68 | flashinfer::AppendPagedKVCachePrefillKernel<<>>( 70 | paged_kv, (T*)key, (T*)value, key_param, 
value_param, seqlen_indptr); 71 | } 72 | 73 | template 74 | void FlashInferAppendKvKernel_i4(void* kv_data, nv_half2* kv_param, 75 | int32_t* kv_indptr, int32_t* kv_indicies, 76 | int32_t* last_page_offset, void* key, 77 | void* value, nv_half2* key_param, 78 | nv_half2* value_param, int num_layers, 79 | int layer_idx, int num_heads, int page_size, 80 | int batch_size) { 81 | using T = flashinfer::quant::__precision__s4; 82 | flashinfer::paged_kv_t paged_kv( 83 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size, 84 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset); 85 | 86 | constexpr size_t vec_size = 87 | std::max(static_cast(16 / flashinfer::quant::size_of_type()), 88 | static_cast(head_dim / 32)); 89 | constexpr size_t bdx = head_dim / vec_size; 90 | constexpr size_t bdy = 128 / bdx; 91 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy)); 92 | dim3 nthrs(bdx, bdy); 93 | flashinfer::AppendPagedKVCacheDecodeKernel 95 | <<>>(paged_kv, (T*)key, (T*)value, key_param, value_param); 96 | } 97 | 98 | 99 | template 100 | void FlashInferBatchDecodeKernel_f16(nv_half* o, nv_half* q, void* kv_data, 101 | nv_half2* kv_param, int32_t* kv_indptr, 102 | int32_t* kv_indicies, 103 | int32_t* last_page_offset, int num_layers, 104 | int layer_idx, int num_heads, int page_size, 105 | int batch_size) { 106 | using DTypeIn = nv_half; 107 | using DTypeInQ = nv_half; 108 | using DTypeOut = nv_half; 109 | 110 | flashinfer::paged_kv_t paged_kv( 111 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size, 112 | (DTypeIn*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset); 113 | 114 | const float rope_scale = 1.f; 115 | const float rope_theta = 1e4; 116 | const float sm_scale = 1.f / std::sqrt(float(head_dim)); 117 | const float rope_inv_scale = 1.f / rope_scale; 118 | const float rope_inv_theta = 1.f / rope_theta; 119 | 120 | constexpr bool norm_on_the_fly = false; 121 | constexpr auto rotary_mode = flashinfer::RotaryMode::kNone; 122 | constexpr size_t FoldFactor = 1; 123 | constexpr size_t vec_size = std::max( 124 | static_cast(16 / flashinfer::quant::size_of_type() / 125 | FoldFactor), 126 | static_cast(head_dim / 32)); 127 | constexpr size_t bdx = head_dim / vec_size; 128 | constexpr size_t bdy = 128 / bdx; 129 | dim3 nblks(paged_kv.batch_size, paged_kv.num_heads); 130 | dim3 nthrs(bdx, bdy); 131 | 132 | flashinfer::BatchDecodeWithPagedKVCacheKernel< 133 | rotary_mode, norm_on_the_fly, vec_size, bdx, bdy, FoldFactor, DTypeInQ, 134 | DTypeIn, DTypeOut, int32_t><<>>( 135 | q, paged_kv, o, sm_scale, rope_inv_scale, rope_inv_theta); 136 | } 137 | 138 | template 139 | void FlashInferInitKvKernel_f16(void* kv_data, nv_half2* kv_param, 140 | int32_t* kv_indptr, int32_t* kv_indicies, 141 | int32_t* last_page_offset, void* key, 142 | void* value, nv_half2* key_param, 143 | nv_half2* value_param, int32_t* seqlen_indptr, 144 | int num_layers, int layer_idx, int num_heads, 145 | int page_size, int batch_size) { 146 | using T = nv_half; 147 | flashinfer::paged_kv_t paged_kv( 148 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size, 149 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset); 150 | 151 | constexpr size_t vec_size = 152 | std::max(static_cast(16 / flashinfer::quant::size_of_type()), 153 | static_cast(head_dim / 32)); 154 | constexpr size_t bdx = head_dim / vec_size; 155 | constexpr size_t bdy = 128 / bdx; 156 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy)); 157 | dim3 nthrs(bdx, 
bdy); 158 | flashinfer::AppendPagedKVCachePrefillKernel<<>>( 160 | paged_kv, (T*)key, (T*)value, key_param, value_param, seqlen_indptr); 161 | } 162 | 163 | template 164 | void FlashInferAppendKvKernel_f16(void* kv_data, nv_half2* kv_param, 165 | int32_t* kv_indptr, int32_t* kv_indicies, 166 | int32_t* last_page_offset, void* key, 167 | void* value, nv_half2* key_param, 168 | nv_half2* value_param, int num_layers, 169 | int layer_idx, int num_heads, int page_size, 170 | int batch_size) { 171 | using T = nv_half; 172 | flashinfer::paged_kv_t paged_kv( 173 | num_layers, layer_idx, num_heads, page_size, head_dim, batch_size, 174 | (T*)kv_data, kv_param, kv_indptr, kv_indicies, last_page_offset); 175 | 176 | constexpr size_t vec_size = 177 | std::max(static_cast(16 / flashinfer::quant::size_of_type()), 178 | static_cast(head_dim / 32)); 179 | constexpr size_t bdx = head_dim / vec_size; 180 | constexpr size_t bdy = 128 / bdx; 181 | dim3 nblks(paged_kv.batch_size * ((paged_kv.num_heads + bdy - 1) / bdy)); 182 | dim3 nthrs(bdx, bdy); 183 | flashinfer::AppendPagedKVCacheDecodeKernel 185 | <<>>(paged_kv, (T*)key, (T*)value, key_param, value_param); 186 | } 187 | 188 | 189 | template void FlashInferBatchDecodeKernel_i4<128>( 190 | nv_half* o, nv_half* q, void* kv_data, nv_half2* kv_param, 191 | int32_t* kv_indptr, int32_t* kv_indicies, int32_t* last_page_offset, 192 | int num_layers, int layer_idx, int num_heads, int page_size, 193 | int batch_size); 194 | 195 | template void FlashInferInitKvKernel_i4<128>( 196 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies, 197 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param, 198 | nv_half2* value_param, int32_t* seqlen_indptr, int num_layers, 199 | int layer_idx, int num_heads, int page_size, int batch_size); 200 | 201 | template void FlashInferAppendKvKernel_i4<128>( 202 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies, 203 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param, 204 | nv_half2* value_param, int num_layers, int layer_idx, int num_heads, 205 | int page_size, int batch_size); 206 | 207 | 208 | template void FlashInferBatchDecodeKernel_f16<128>( 209 | nv_half* o, nv_half* q, void* kv_data, nv_half2* kv_param, 210 | int32_t* kv_indptr, int32_t* kv_indicies, int32_t* last_page_offset, 211 | int num_layers, int layer_idx, int num_heads, int page_size, 212 | int batch_size); 213 | 214 | template void FlashInferInitKvKernel_f16<128>( 215 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies, 216 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param, 217 | nv_half2* value_param, int32_t* seqlen_indptr, int num_layers, 218 | int layer_idx, int num_heads, int page_size, int batch_size); 219 | 220 | template void FlashInferAppendKvKernel_f16<128>( 221 | void* kv_data, nv_half2* kv_param, int32_t* kv_indptr, int32_t* kv_indicies, 222 | int32_t* last_page_offset, void* key, void* value, nv_half2* key_param, 223 | nv_half2* value_param, int num_layers, int layer_idx, int num_heads, 224 | int page_size, int batch_size); 225 | -------------------------------------------------------------------------------- /quarot/kernels/gemm.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 5 | 6 | 7 | void matmul_host( 8 | const Int4Storage *A, 9 | const Int4Storage *B, 10 | uint32_t M, 11 | uint32_t N, 12 | uint32_t K, 13 | int32_t *C 14 | ) 15 | { 16 | using Gemm = 
cutlass::gemm::device::Gemm< 17 | cutlass::int4b_t, // ElementA 18 | cutlass::layout::RowMajor, // LayoutA 19 | cutlass::int4b_t, // ElementB 20 | cutlass::layout::ColumnMajor, // LayoutB 21 | int32_t, // ElementOutput 22 | cutlass::layout::RowMajor, // LayoutOutput 23 | int32_t, // ElementAccumulator 24 | cutlass::arch::OpClassTensorOp, // tag indicating Tensor Cores 25 | cutlass::arch::Sm80 // tag indicating target GPU compute architecture // TODO: This is just for compiling on my laptop temporarily. Should be higher when doing benchmarking. 26 | >; 27 | 28 | Gemm gemmOp; 29 | 30 | using GemmCoord = cutlass::gemm::GemmCoord; 31 | 32 | typename Gemm::Arguments arguments{ 33 | {static_cast(M), static_cast(N), static_cast(K)}, 34 | {(cutlass::int4b_t *) A, K}, 35 | {(cutlass::int4b_t *) B, K}, 36 | {C, N}, 37 | {C, N}, 38 | {1, 0} 39 | }; 40 | 41 | auto status = gemmOp(arguments); 42 | 43 | ensure(status == cutlass::Status::kSuccess, 44 | cutlassGetStatusString(status)); 45 | 46 | } 47 | -------------------------------------------------------------------------------- /quarot/kernels/include/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | //#include 10 | #include 11 | 12 | #define HOST_DEVICE __forceinline__ __host__ __device__ 13 | #define DEVICE __forceinline__ __device__ 14 | #define HOST __forceinline__ __host__ 15 | 16 | HOST void ensure(bool condition, const std::string& msg) { 17 | if (!condition) { 18 | std::cerr << "Assertion failed: " << msg << '\n'; 19 | // Choose the appropriate action: throw an exception, abort, etc. 20 | // For example, throwing an exception: 21 | throw std::runtime_error(msg); 22 | } 23 | } 24 | 25 | template 26 | HOST_DEVICE T mymax(T a, T b) 27 | { 28 | return a > b ? a : b; 29 | } 30 | 31 | template 32 | HOST_DEVICE T mymin(T a, T b) 33 | { 34 | return a < b ? a : b; 35 | } 36 | 37 | template 38 | HOST_DEVICE T cdiv(T a, T b) { return (a + b - 1) / b; } 39 | 40 | template 41 | HOST_DEVICE T clamp(T x, T a, T b) { return mymax(a, mymin(b, x)); } 42 | 43 | template 44 | HOST_DEVICE T myabs(T x) { return x < (T) 0 ? 
-x : x; } 45 | 46 | template 47 | DEVICE T sqr(T x) 48 | { 49 | return x * x; 50 | } 51 | 52 | constexpr int qmin = -8; 53 | constexpr int qmax = 7; 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | template 7 | void FlashInferBatchDecodeKernel_i4(nv_half* o, nv_half* q, void* kv_data, 8 | nv_half2* kv_param, int32_t* kv_indptr, 9 | int32_t* kv_indicies, 10 | int32_t* last_page_offset, int num_layers, 11 | int layer_idx, int num_heads, int page_size, 12 | int batch_size); 13 | 14 | template 15 | void FlashInferInitKvKernel_i4(void* kv_data, nv_half2* kv_param, 16 | int32_t* kv_indptr, int32_t* kv_indicies, 17 | int32_t* last_page_offset, void* key, 18 | void* value, nv_half2* key_param, 19 | nv_half2* value_param, int32_t* seqlen_indptr, 20 | int num_layers, int layer_idx, int num_heads, 21 | int page_size, int batch_size); 22 | 23 | template 24 | void FlashInferAppendKvKernel_i4(void* kv_data, nv_half2* kv_param, 25 | int32_t* kv_indptr, int32_t* kv_indicies, 26 | int32_t* last_page_offset, void* key, 27 | void* value, nv_half2* key_param, 28 | nv_half2* value_param, int num_layers, 29 | int layer_idx, int num_heads, int page_size, 30 | int batch_size); 31 | 32 | 33 | template 34 | void FlashInferBatchDecodeKernel_f16(nv_half* o, nv_half* q, void* kv_data, 35 | nv_half2* kv_param, int32_t* kv_indptr, 36 | int32_t* kv_indicies, 37 | int32_t* last_page_offset, int num_layers, 38 | int layer_idx, int num_heads, int page_size, 39 | int batch_size); 40 | 41 | template 42 | void FlashInferInitKvKernel_f16(void* kv_data, nv_half2* kv_param, 43 | int32_t* kv_indptr, int32_t* kv_indicies, 44 | int32_t* last_page_offset, void* key, 45 | void* value, nv_half2* key_param, 46 | nv_half2* value_param, int32_t* seqlen_indptr, 47 | int num_layers, int layer_idx, int num_heads, 48 | int page_size, int batch_size); 49 | 50 | template 51 | void FlashInferAppendKvKernel_f16(void* kv_data, nv_half2* kv_param, 52 | int32_t* kv_indptr, int32_t* kv_indicies, 53 | int32_t* last_page_offset, void* key, 54 | void* value, nv_half2* key_param, 55 | nv_half2* value_param, int num_layers, 56 | int layer_idx, int num_heads, int page_size, 57 | int batch_size); 58 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/cp_async.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_CP_ASYNC_CUH_ 2 | #define FLASHINFER_CP_ASYNC_CUH_ 3 | 4 | #include 5 | #include "quantization.cuh" 6 | 7 | namespace flashinfer { 8 | 9 | namespace cp_async { 10 | 11 | __device__ __forceinline__ void commit_group() { 12 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) 13 | asm volatile("cp.async.commit_group;\n" ::); 14 | #endif 15 | } 16 | 17 | template 18 | __device__ __forceinline__ void wait_group() { 19 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) 20 | asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); 21 | #endif 22 | } 23 | 24 | template 25 | __device__ __forceinline__ void load_128(T* smem_ptr, const T* gmem_ptr) { 26 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) 27 | 28 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 29 | if constexpr (prefetch) { 
30 | asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), 31 | "l"(gmem_ptr), "n"(16), "r"(16)); 32 | } else { 33 | asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(smem_int_ptr), "l"(gmem_ptr), 34 | "n"(16)); 35 | } 36 | #else 37 | *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); 38 | #endif 39 | } 40 | 41 | template 42 | __device__ __forceinline__ void pred_load_128(T* smem_ptr, const T* gmem_ptr, bool predicate) { 43 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) 44 | 45 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 46 | if constexpr (prefetch) { 47 | asm volatile( 48 | "{\n" 49 | " .reg .pred p;\n" 50 | " setp.ne.b32 p, %0, 0;\n" 51 | " @p cp.async.cg.shared.global [%1], [%2], %3;\n" 52 | "}\n" ::"r"((int)predicate), 53 | "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); 54 | } else { 55 | asm volatile( 56 | "{\n" 57 | " .reg .pred p;\n" 58 | " setp.ne.b32 p, %0, 0;\n" 59 | " @p cp.async.cg.shared.global [%1], [%2], %3;\n" 60 | "}\n" ::"r"((int)predicate), 61 | "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); 62 | } 63 | #else 64 | if (predicate) { 65 | *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); 66 | } 67 | #endif 68 | } 69 | 70 | template 71 | __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { 72 | static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256"); 73 | if constexpr (num_bits == 128) { 74 | load_128(smem_ptr, gmem_ptr); 75 | } else { 76 | load_128(smem_ptr, gmem_ptr); 77 | load_128(smem_ptr + static_cast(16 / quant::size_of_type()), gmem_ptr + static_cast(16 / quant::size_of_type())); 78 | } 79 | } 80 | 81 | template 82 | __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) { 83 | static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256"); 84 | if constexpr (num_bits == 128) { 85 | pred_load_128(smem_ptr, gmem_ptr, predicate); 86 | } else { 87 | pred_load_128(smem_ptr, gmem_ptr, predicate); 88 | pred_load_128(smem_ptr + static_cast(16 / quant::size_of_type()), gmem_ptr + static_cast(16 / quant::size_of_type()), predicate); 89 | } 90 | } 91 | 92 | } // namespace cp_async 93 | 94 | } // namespace flashinfer 95 | 96 | #endif // FLASHINFER_CP_ASYNC_CUH_ 97 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/layout.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_LAYOUT_CUH_ 2 | #define FLASHINFER_LAYOUT_CUH_ 3 | 4 | #include 5 | 6 | namespace flashinfer { 7 | 8 | /*! 
9 | * \brief The Layout of QKV matrices 10 | */ 11 | enum class QKVLayout { 12 | // [seq_len, num_heads, head_dim] 13 | kNHD = 0U, 14 | // [num_heads, head_dim, seq_len] 15 | kHND = 1U, 16 | }; 17 | 18 | template 19 | __host__ __device__ __forceinline__ size_t get_elem_offset_impl(size_t elem_idx, size_t head_idx, 20 | size_t feat_idx, size_t seq_len, 21 | size_t num_heads, size_t head_dim) { 22 | if constexpr (layout == QKVLayout::kHND) { 23 | return (head_idx * seq_len + elem_idx) * head_dim + feat_idx; 24 | } else { 25 | return (elem_idx * num_heads + head_idx) * head_dim + feat_idx; 26 | } 27 | } 28 | 29 | template 30 | struct tensor_info_t { 31 | size_t qo_len; 32 | size_t kv_len; 33 | size_t num_heads; 34 | size_t head_dim; 35 | __host__ __device__ __forceinline__ tensor_info_t(size_t qo_len, size_t kv_len, size_t num_heads, 36 | size_t head_dim) 37 | : qo_len(qo_len), kv_len(kv_len), num_heads(num_heads), head_dim(head_dim) {} 38 | 39 | __host__ __device__ __forceinline__ size_t get_qo_elem_offset(size_t query_idx, size_t head_idx, 40 | size_t feat_idx) const { 41 | return get_elem_offset_impl(query_idx, head_idx, feat_idx, qo_len, num_heads, head_dim); 42 | } 43 | 44 | __host__ __device__ __forceinline__ size_t get_kv_elem_offset(size_t kv_idx, size_t head_idx, 45 | size_t feat_idx) const { 46 | return get_elem_offset_impl(kv_idx, head_idx, feat_idx, kv_len, num_heads, head_dim); 47 | } 48 | }; 49 | 50 | /*! 51 | * \brief Convert QKVLayout to string 52 | * \param qkv_layout The QKVLayout to convert 53 | */ 54 | inline std::string QKVLayoutToString(const QKVLayout &qkv_layout) { 55 | switch (qkv_layout) { 56 | case QKVLayout::kNHD: 57 | return "NHD"; 58 | case QKVLayout::kHND: 59 | return "HND"; 60 | default: 61 | return "Unknown"; 62 | } 63 | } 64 | 65 | } // namespace flashinfer 66 | #endif // FLASHINFER_LAYOUT_CUH_ -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/math.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_MATH_CUH_ 2 | #define FLASHINFER_MATH_CUH_ 3 | 4 | #include 5 | #include 6 | 7 | namespace flashinfer { 8 | namespace math { 9 | 10 | constexpr float log2e = M_LOG2E; // 1.44269504088896340736f; 11 | 12 | __forceinline__ __device__ float ptx_exp2(float x) { 13 | float y = exp2f(x); 14 | // asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); 15 | return y; 16 | } 17 | 18 | __forceinline__ __device__ float shfl_xor_sync(float x, int delta) { 19 | float y; 20 | asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;" : "=f"(y) : "f"(x), "r"(delta)); 21 | return y; 22 | } 23 | 24 | } // namespace math 25 | } // namespace flashinfer 26 | #endif // FLASHINFER_MATH_CUH_ 27 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/mma.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_MMA_CUH_ 2 | #define FLASHINFER_MMA_CUH_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace flashinfer { 11 | 12 | namespace mma { 13 | 14 | constexpr size_t frag_size = 16; 15 | 16 | template 17 | __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t *R, T *smem_ptr) { 18 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 19 | asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" 20 | : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) 21 | : 
"r"(smem_int_ptr)); 22 | } 23 | 24 | template 25 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t *R, T *smem_ptr) { 26 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 27 | asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" 28 | : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) 29 | : "r"(smem_int_ptr)); 30 | } 31 | 32 | template 33 | __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t *R, T *smem_ptr) { 34 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 11) 35 | uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); 36 | asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" 37 | : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3])); 38 | #else 39 | // NOTE(Zihao): Not implemented yet. 40 | #endif 41 | } 42 | 43 | template 44 | __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, 45 | uint32_t *B) { 46 | if constexpr (std::is_same::value) { 47 | asm volatile( 48 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 49 | "{%0, %1, %2, %3}," 50 | "{%4, %5, %6, %7}," 51 | "{%8, %9}," 52 | "{%10, %11, %12, %13};\n" 53 | : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) 54 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), 55 | "f"(C[2]), "f"(C[3])); 56 | asm volatile( 57 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 58 | "{%0, %1, %2, %3}," 59 | "{%4, %5, %6, %7}," 60 | "{%8, %9}," 61 | "{%10, %11, %12, %13};\n" 62 | : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) 63 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), 64 | "f"(C[6]), "f"(C[7])); 65 | } else { 66 | asm volatile( 67 | "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " 68 | "{%0, %1, %2, %3}," 69 | "{%4, %5, %6, %7}," 70 | "{%8, %9}," 71 | "{%10, %11, %12, %13};\n" 72 | : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) 73 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), 74 | "f"(C[2]), "f"(C[3])); 75 | asm volatile( 76 | "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " 77 | "{%0, %1, %2, %3}," 78 | "{%4, %5, %6, %7}," 79 | "{%8, %9}," 80 | "{%10, %11, %12, %13};\n" 81 | : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) 82 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), 83 | "f"(C[6]), "f"(C[7])); 84 | } 85 | } 86 | 87 | } // namespace mma 88 | 89 | } // namespace flashinfer 90 | 91 | #endif // FLASHINFER_MMA_CUH_ 92 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/page.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_PAGE_CUH_ 2 | #define FLASHINFER_PAGE_CUH_ 3 | 4 | #include "layout.cuh" 5 | #include "utils.cuh" 6 | #include "vec_dtypes.cuh" 7 | #include "quantization.cuh" 8 | 9 | namespace flashinfer { 10 | 11 | /*! 12 | * \brief Paged key-value cache 13 | * \tparam DType The data type of the key-value cache 14 | * \tparam IdType The index data type of the kv-cache 15 | * \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim] 16 | * \note This layout is kind of HND, which is memory-friendly for the self-attn kernel's tile block. 
17 | */ 18 | template 19 | struct paged_kv_t { 20 | size_t num_layers; 21 | size_t layer_idx; 22 | size_t num_heads; 23 | size_t page_size; 24 | size_t head_dim; 25 | size_t batch_size; 26 | // [max_num_pages * num_layers * 2 * num_heads * page_size * head_dim] 27 | // The flattened key-value cache 28 | DType* data; 29 | // [max_num_pages * num_layers * 2 * num_heads * page_size * 1] 30 | // The flattened key-value quantization parameter cache 31 | half2* param; 32 | // [batch_size + 1] The page indptr array, with the first element 0 33 | IdType* indptr; 34 | // [nnz_pages] The page indices array 35 | IdType* indices; 36 | // [batch_size] The offset of the last page for each request in the batch 37 | IdType* last_page_offset; 38 | /*! 39 | * \brief Construct a paged key-value cache 40 | * \param num_layers The number of layers 41 | * \param layer_idx The index of the layer 42 | * \param num_heads The number of heads 43 | * \param page_size The size of each page 44 | * \param head_dim The dimension of each head 45 | * \param batch_size The batch size 46 | * \param data The flattened key-value cache 47 | * \param indptr The page indptr array 48 | * \param indices The page indices array 49 | * \param last_page_offset The offset of the last page for each request in the batch 50 | */ 51 | __host__ __device__ __forceinline__ paged_kv_t( 52 | size_t num_layers, 53 | size_t layer_idx, 54 | size_t num_heads, 55 | size_t page_size, 56 | size_t head_dim, 57 | size_t batch_size, 58 | DType* data, 59 | half2* param, 60 | IdType* indptr, 61 | IdType* indices, 62 | IdType* last_page_offset 63 | ): num_layers(num_layers), 64 | layer_idx(layer_idx), 65 | num_heads(num_heads), 66 | page_size(page_size), 67 | head_dim(head_dim), 68 | batch_size(batch_size), 69 | data(data), 70 | param(param), 71 | indptr(indptr), 72 | indices(indices), 73 | last_page_offset(last_page_offset) {} 74 | 75 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim] 76 | __host__ __device__ __forceinline__ size_t get_k_elem_offset(size_t page_idx, size_t head_idx, 77 | size_t entry_idx, size_t feat_idx) { 78 | return (((page_idx * num_layers + layer_idx) * 2 * num_heads + head_idx) * page_size + 79 | entry_idx) * 80 | head_dim + 81 | feat_idx; 82 | } 83 | 84 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim] 85 | __host__ __device__ __forceinline__ size_t get_v_elem_offset(size_t page_idx, size_t head_idx, 86 | size_t entry_idx, size_t feat_idx) { 87 | return ((((page_idx * num_layers + layer_idx) * 2 + 1) * num_heads + head_idx) * page_size + 88 | entry_idx) * 89 | head_dim + 90 | feat_idx; 91 | } 92 | 93 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, 1] 94 | __host__ __device__ __forceinline__ size_t get_param_k_elem_offset(size_t page_idx, size_t head_idx, 95 | size_t entry_idx) { 96 | return ((page_idx * num_layers + layer_idx) * 2 * num_heads + head_idx) * page_size + entry_idx; 97 | } 98 | 99 | // \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, 1] 100 | __host__ __device__ __forceinline__ size_t get_param_v_elem_offset(size_t page_idx, size_t head_idx, 101 | size_t entry_idx) { 102 | return (((page_idx * num_layers + layer_idx) * 2 + 1) * num_heads + head_idx) * page_size + entry_idx; 103 | } 104 | 105 | __host__ __device__ __forceinline__ size_t get_valid_page_size(size_t batch_idx, size_t page_iter) { 106 | if (page_iter == indptr[batch_idx + 1] - 1) { 107 | return last_page_offset[batch_idx]; 108 | } else { 109 | return 
page_size; 110 | } 111 | } 112 | }; 113 | 114 | /*! 115 | * \brief: Append a single token to the existing kv cache. 116 | * \note: Layout of key: [batch_size, num_heads, head_dim] 117 | * \note: this layout is the natural output of the previous dense layer, so no transpose is needed. 118 | */ 119 | template 120 | __global__ void AppendPagedKVCacheDecodeKernel( 121 | paged_kv_t paged_kv, 122 | DType* __restrict__ key, 123 | DType* __restrict__ value, 124 | half2* __restrict__ key_param, 125 | half2* __restrict__ value_param 126 | ) { 127 | size_t tx = threadIdx.x, ty = threadIdx.y; 128 | size_t num_heads = paged_kv.num_heads; 129 | size_t batch_idx = blockIdx.x / ((num_heads + bdy - 1) / bdy); 130 | size_t head_idx = (blockIdx.x % ((num_heads + bdy - 1) / bdy)) * bdy + ty; 131 | if (head_idx < num_heads) { 132 | // Enough space has been pre-allocated for the last page 133 | // seq_len includes the token being appended 134 | size_t seq_len = 135 | (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size + 136 | paged_kv.last_page_offset[batch_idx]; 137 | 138 | size_t page_idx = 139 | paged_kv.indices[paged_kv.indptr[batch_idx] + (seq_len - 1) / paged_kv.page_size]; 140 | size_t entry_idx = (seq_len - 1) % paged_kv.page_size; 141 | 142 | vec_t::memcpy( 143 | quant::get_ptr(paged_kv.data, paged_kv.get_k_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)), 144 | quant::get_ptr(key, (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size)); 145 | 146 | vec_t::memcpy( 147 | quant::get_ptr(paged_kv.data, paged_kv.get_v_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)), 148 | quant::get_ptr(value, (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size)); 149 | 150 | // Copy the quantization parameters 151 | // Each group copies them only once 152 | if(tx == 0){ 153 | quant::get_ptr( 154 | paged_kv.param, 155 | paged_kv.get_param_k_elem_offset(page_idx, head_idx, entry_idx) 156 | )[0] = key_param[batch_idx * num_heads + head_idx]; 157 | 158 | quant::get_ptr( 159 | paged_kv.param, 160 | paged_kv.get_param_v_elem_offset(page_idx, head_idx, entry_idx) 161 | )[0] = value_param[batch_idx * num_heads + head_idx]; 162 | } 163 | } 164 | } 165 | 166 | template 167 | __global__ void AppendPagedKVCachePrefillKernel( 168 | paged_kv_t paged_kv, 169 | DType* __restrict__ key, 170 | DType* __restrict__ value, 171 | half2* __restrict__ key_param, 172 | half2* __restrict__ value_param, 173 | IdType* __restrict__ append_indptr 174 | ) { 175 | size_t tx = threadIdx.x, ty = threadIdx.y; 176 | size_t num_heads = paged_kv.num_heads; 177 | size_t batch_idx = blockIdx.x / ((num_heads + bdy - 1) / bdy); 178 | size_t head_idx = (blockIdx.x % ((num_heads + bdy - 1) / bdy)) * bdy + ty; 179 | if (head_idx < num_heads) { 180 | 181 | // Pre-filled seq_len 182 | size_t seq_len = 183 | (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size + 184 | paged_kv.last_page_offset[batch_idx]; 185 | // Compute the to-be-appended seq_len 186 | size_t append_seq_len = append_indptr[batch_idx + 1] - append_indptr[batch_idx]; 187 | size_t append_start = seq_len - append_seq_len; 188 | 189 | #pragma unroll 2 190 | for (size_t j = 0; j < append_seq_len; ++j) { 191 | size_t page_seq_idx = j + append_start; 192 | size_t page_idx = 193 | paged_kv.indices[paged_kv.indptr[batch_idx] + page_seq_idx / paged_kv.page_size]; 194 | size_t entry_idx = page_seq_idx % paged_kv.page_size; 195 | 196 | vec_t::memcpy( 197 | quant::get_ptr(paged_kv.data, paged_kv.get_k_elem_offset(page_idx, head_idx, entry_idx,
tx * vec_size)), 198 | quant::get_ptr(key, ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size)); 199 | 200 | vec_t::memcpy( 201 | quant::get_ptr(paged_kv.data, paged_kv.get_v_elem_offset(page_idx, head_idx, entry_idx, tx * vec_size)), 202 | quant::get_ptr(value, ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size)); 203 | 204 | // Copy the quantization parameters 205 | // One group only copies once 206 | if(tx == 0){ 207 | quant::get_ptr( 208 | paged_kv.param, 209 | paged_kv.get_param_k_elem_offset(page_idx, head_idx, entry_idx) 210 | )[0] = key_param[(append_indptr[batch_idx] + j) * num_heads + head_idx]; 211 | 212 | quant::get_ptr( 213 | paged_kv.param, 214 | paged_kv.get_param_v_elem_offset(page_idx, head_idx, entry_idx) 215 | )[0] = value_param[(append_indptr[batch_idx] + j) * num_heads + head_idx]; 216 | } 217 | } 218 | } 219 | } 220 | 221 | template 222 | cudaError_t AppendPagedKVCacheDecode( 223 | paged_kv_t paged_kv, 224 | DType* key, 225 | DType* value, 226 | half2* key_param, 227 | half2* value_param, 228 | cudaStream_t stream = nullptr, 229 | size_t dev_id = 0 230 | ) { 231 | FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); 232 | size_t head_dim = paged_kv.head_dim; 233 | size_t batch_size = paged_kv.batch_size; 234 | size_t num_heads = paged_kv.num_heads; 235 | SWITCH_HEAD_DIM(head_dim, HEAD_DIM, { 236 | constexpr size_t vec_size = std::max(static_cast(16 / quant::size_of_type()), HEAD_DIM / 32); 237 | constexpr size_t bdx = HEAD_DIM / vec_size; 238 | constexpr size_t bdy = 128 / bdx; 239 | assert(num_heads % bdy == 0); 240 | dim3 nblks(batch_size * num_heads / bdy); 241 | dim3 nthrs(bdx, bdy); 242 | auto kernel = AppendPagedKVCacheDecodeKernel; 243 | void* args[] = { 244 | (void*)&paged_kv, 245 | (void*)&key, 246 | (void*)&value, 247 | (void*)&key_param, 248 | (void*)&value_param 249 | }; 250 | FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); 251 | }); 252 | return cudaSuccess; 253 | } 254 | 255 | template 256 | cudaError_t AppendPagedKVCachePrefill( 257 | paged_kv_t paged_kv, 258 | DType* key, 259 | DType* value, 260 | half2* key_param, 261 | half2* value_param, 262 | IdType* append_indptr, 263 | cudaStream_t stream = nullptr, 264 | size_t dev_id = 0 265 | ) { 266 | FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); 267 | size_t head_dim = paged_kv.head_dim; 268 | size_t batch_size = paged_kv.batch_size; 269 | size_t num_heads = paged_kv.num_heads; 270 | SWITCH_HEAD_DIM(head_dim, HEAD_DIM, { 271 | constexpr size_t vec_size = std::max(static_cast(16 / quant::size_of_type()), HEAD_DIM / 32); 272 | constexpr size_t bdx = HEAD_DIM / vec_size; 273 | constexpr size_t bdy = 128 / bdx; 274 | assert(num_heads % bdy == 0); 275 | dim3 nblks(batch_size * num_heads / bdy); 276 | dim3 nthrs(bdx, bdy); 277 | auto kernel = AppendPagedKVCachePrefillKernel; 278 | void* args[] = { 279 | (void*)&paged_kv, 280 | (void*)&key, 281 | (void*)&value, 282 | (void*)&key_param, 283 | (void*)&value_param, 284 | (void*)&append_indptr 285 | }; 286 | FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); 287 | }); 288 | return cudaSuccess; 289 | } 290 | 291 | } // namespace flashinfer 292 | 293 | #endif // FLAHSINFER_PAGE_CUH_ 294 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/permuted_smem.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_PERMUTED_SMEM_CUH_ 2 | 
#define FLASHINFER_PERMUTED_SMEM_CUH_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | #include "cp_async.cuh" 11 | #include "mma.cuh" 12 | 13 | namespace flashinfer { 14 | 15 | // Each bank is 4 bytes. 16 | using bank_t = uint4; 17 | 18 | template 19 | constexpr __host__ __device__ __forceinline__ size_t bank_capacity() { 20 | return sizeof(bank_t) / sizeof(T); 21 | } 22 | 23 | namespace permuted_smem_impl { 24 | 25 | /*! 26 | * \brief Compute the address of the element at (i, j) in the permuted shared memory. 27 | * \tparam stride The number of banks per row in the permuted shared memory. 28 | * \param smem_base The base address of the permuted shared memory. 29 | * \param i The row index. 30 | * \param j The column (bank) index. 31 | * \note The permuted shared memory maps 8x4 block in logical space to 4x8 block in physical space. 32 | * \see GTC 2020: Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100. 33 | */ 34 | template 35 | __host__ __device__ __forceinline__ bank_t *get_smem_ptr(T *smem_base, size_t i, size_t j) { 36 | return ((bank_t *)smem_base) + (i / 2) * stride * 2 + (j / 4) * 8 + (i % 2) * 4 + 37 | ((j % 4) ^ ((i / 2) % 4)); 38 | } 39 | 40 | template 41 | __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t *R, T *smem_base, size_t i, size_t j) { 42 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j); 43 | mma::ldmatrix_m8n8x4(R, smem_ptr); 44 | } 45 | 46 | template 47 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t *R, T *smem_base, size_t i, 48 | size_t j) { 49 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j); 50 | mma::ldmatrix_m8n8x4_trans(R, smem_ptr); 51 | } 52 | 53 | template 54 | __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t *R, T *smem_base, size_t i, size_t j) { 55 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j); 56 | mma::stmatrix_m8n8x4(R, smem_ptr); 57 | } 58 | 59 | template 60 | __device__ __forceinline__ void load_bank(T *smem_base, size_t i, size_t j, const T *gptr) { 61 | *get_smem_ptr(smem_base, i, j) = *reinterpret_cast(gptr); 62 | } 63 | 64 | template 65 | __device__ __forceinline__ void store_bank(T *smem_base, size_t i, size_t j, T *gptr) { 66 | *reinterpret_cast(gptr) = *get_smem_ptr(smem_base, i, j); 67 | } 68 | 69 | template 70 | __device__ __forceinline__ void load_bank_async(T *smem_base, size_t i, size_t j, const T *gptr, 71 | bool predicate) { 72 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j); 73 | cp_async::pred_load_128(smem_ptr, reinterpret_cast(gptr), predicate); 74 | } 75 | 76 | template 77 | __device__ __forceinline__ void load_bank_async(T *smem_base, size_t i, size_t j, const T *gptr) { 78 | bank_t *smem_ptr = get_smem_ptr(smem_base, i, j); 79 | cp_async::load_128(smem_ptr, reinterpret_cast(gptr)); 80 | } 81 | 82 | } // namespace permuted_smem_impl 83 | 84 | template 85 | struct permuted_smem_t { 86 | T __align__(16) * base; 87 | __device__ __forceinline__ permuted_smem_t() : base(nullptr) {} 88 | __device__ __forceinline__ permuted_smem_t(T *base) : base(base) {} 89 | __device__ __forceinline__ bank_t *get_ptr(size_t i, size_t j) { 90 | return permuted_smem_impl::get_smem_ptr(base, i, j); 91 | } 92 | __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t *R, size_t i, size_t j) { 93 | permuted_smem_impl::ldmatrix_m8n8x4(R, base, i, j); 94 | } 95 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t *R, size_t i, size_t j) { 96 | permuted_smem_impl::ldmatrix_m8n8x4_trans(R, base, i, j); 97 | } 98 | __device__ __forceinline__ void 
stmatrix_m8n8x4(uint32_t *R, size_t i, size_t j) { 99 | permuted_smem_impl::stmatrix_m8n8x4(R, base, i, j); 100 | } 101 | __device__ __forceinline__ void load_bank(size_t i, size_t j, const T *gptr) { 102 | permuted_smem_impl::load_bank(base, i, j, gptr); 103 | } 104 | __device__ __forceinline__ void load_bank_async(size_t i, size_t j, const T *gptr, 105 | bool predicate) { 106 | permuted_smem_impl::load_bank_async(base, i, j, gptr, predicate); 107 | } 108 | __device__ __forceinline__ void load_bank_async(size_t i, size_t j, const T *gptr) { 109 | permuted_smem_impl::load_bank_async(base, i, j, gptr); 110 | } 111 | __device__ __forceinline__ void store_bank(size_t i, size_t j, T *gptr) { 112 | permuted_smem_impl::store_bank(base, i, j, gptr); 113 | } 114 | }; 115 | 116 | } // namespace flashinfer 117 | 118 | #endif // FLASHINFER_PERMUTED_SMEM_CUH_ 119 | -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/quantization.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_QUANTIZATION_CUH_ 2 | #define FLASHINFER_QUANTIZATION_CUH_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "vec_dtypes.cuh" 9 | 10 | namespace flashinfer{ 11 | namespace quant{ 12 | 13 | /*! 14 | * \brief Identifier used for 4bit quantization, 15 | * for data size calculation. 16 | */ 17 | struct __precision__s4{}; 18 | 19 | /*! 20 | * \brief Similar to sizeof 21 | * \tparam T Data type whose size is queried 22 | */ 23 | template 24 | FLASHINFER_INLINE constexpr float size_of_type(){ 25 | if constexpr (std::is_same::value){ 26 | return 0.5f; 27 | }else{ 28 | return sizeof(T); 29 | } 30 | } 31 | 32 | /*! 33 | * \brief Used to get the pointer at a given element offset. 34 | * \tparam T A template parameter indicating the data type 35 | * \param ptr Pointer to the data 36 | * \param offset Offset from the pointer, in elements 37 | */ 38 | template 39 | FLASHINFER_INLINE T* get_ptr(T* ptr, const size_t offset){ 40 | if constexpr (std::is_same::value){ 41 | return reinterpret_cast(reinterpret_cast(ptr) + offset / 2); 42 | }else if constexpr (std::is_same::value){ 43 | // Patch for const qualifiers 44 | return reinterpret_cast(reinterpret_cast(ptr) + offset / 2); 45 | }else{ 46 | return ptr + offset; 47 | } 48 | } 49 | 50 | /*!
51 | * \brief Dequantize the input into vec_t 52 | * \tparam src_float_t A template indicates the quantization data type 53 | * \tparam vec_size A template integer indicates the vector size 54 | * \param src Const input data 55 | * \param tgt Output data 56 | * \param scale Quantization parameter 57 | * \param zero_point Quantization parameter 58 | */ 59 | template 60 | FLASHINFER_INLINE void dequantize_impl( 61 | const vec_t &src, 62 | vec_t &tgt, 63 | float scale, 64 | float zero_point 65 | ){ 66 | if constexpr (std::is_same::value){ 67 | // 4bit asymmetric quantization 68 | static_assert(vec_size % 8 == 0, "32bits pack 8 u4 elements."); 69 | // 8 x s4 in int32_t register 70 | constexpr size_t PACK_NUM = 8; 71 | #pragma unroll 72 | for(int i = 0;i < vec_size / PACK_NUM;++i){ 73 | uint32_t packedValue = src.at(i); 74 | #pragma unroll 75 | for(int j = 0; j < PACK_NUM;++j){ 76 | float unpackedValue = static_cast(packedValue & 0xf) * scale - zero_point; 77 | tgt[i * PACK_NUM + j] = unpackedValue; 78 | packedValue >>= 4; 79 | } 80 | } 81 | }else{ 82 | // Not implemented 83 | } 84 | } 85 | } // namespace quant 86 | } // namespace flashinfer 87 | 88 | #endif // FLASHINFER_QUANTIZATION_CUH_ -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/rope.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_ROPE_CUH_ 2 | #define FLASHINFER_ROPE_CUH_ 3 | 4 | #include 5 | 6 | namespace flashinfer { 7 | 8 | /*! 9 | * \brief An enumeration class that defines different modes for applying RoPE 10 | * (Rotary Positional Embeddings). 11 | */ 12 | enum class RotaryMode { 13 | // No rotary positional embeddings 14 | kNone = 0U, 15 | // Apply Llama-style rope. 16 | kLlama = 1U, 17 | }; 18 | 19 | /*! 20 | * \brief Convert RotaryMode to string 21 | * \param rotary_mode A RotaryMode value 22 | */ 23 | inline std::string RotaryModeToString(const RotaryMode &rotary_mode) { 24 | switch (rotary_mode) { 25 | case RotaryMode::kNone: 26 | return "None"; 27 | case RotaryMode::kLlama: 28 | return "Llama"; 29 | default: 30 | return "Unknown"; 31 | } 32 | } 33 | 34 | } // namespace flashinfer 35 | 36 | #endif // FLASHINFER_ROPE_CUH_ -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/state.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_STATE_CUH_ 2 | #define FLASHINFER_STATE_CUH_ 3 | 4 | #include "math.cuh" 5 | #include "vec_dtypes.cuh" 6 | 7 | namespace flashinfer { 8 | 9 | /*! 10 | * \brief The flashattention state. 11 | * \tparam vec_size The size of the vector used in o. 12 | * \tparam norm_on_the_fly Whether to normalize the state on the fly. If true, the state will be 13 | * normalized when merge() is called. If false, the state will be normalized when normalize() is 14 | * called. 15 | */ 16 | template 17 | struct state_t { 18 | vec_t o; /* the weighted sum of v: exp(pre-softmax logit - m) * v / d */ 19 | float m; /* maximum value of pre-softmax logits */ 20 | float d; /* sum of exp(pre-softmax logits - m) */ 21 | 22 | __device__ __forceinline__ void init() { 23 | o.fill(0.f); 24 | m = -5e4; 25 | d = 0.f; 26 | } 27 | 28 | __device__ __forceinline__ state_t() { init(); } 29 | 30 | /*! 31 | * \brief Merge the state with another state. 32 | * \param other_m The maximum value of pre-softmax logits of the other state. 
33 | * \param other_d The sum of exp(pre-softmax logits - m) of the other state. 34 | * \param other_o The weighted sum of v of the other state. 35 | */ 36 | __device__ __forceinline__ void merge(const vec_t &other_o, float other_m, 37 | float other_d) { 38 | float m_prev = m, d_prev = d; 39 | m = max(m_prev, other_m); 40 | d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m); 41 | if constexpr (norm_on_the_fly) { 42 | #pragma unroll 43 | for (size_t i = 0; i < vec_size; ++i) { 44 | o[i] = o[i] * math::ptx_exp2(m_prev - m) * (d_prev / d) + 45 | other_o[i] * math::ptx_exp2(other_m - m) * (other_d / d); 46 | } 47 | } else { 48 | #pragma unroll 49 | for (size_t i = 0; i < vec_size; ++i) { 50 | o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m); 51 | } 52 | } 53 | } 54 | 55 | /*! 56 | * \brief Merge the state with another state. 57 | * \param other The other state. 58 | */ 59 | __device__ __forceinline__ void merge(const state_t &other) { 60 | merge(other.o, other.m, other.d); 61 | } 62 | 63 | /*! 64 | * \brief Merge the state with a single pre-softmax logit and value vector. 65 | * \param x The pre-softmax logit. 66 | * \param v The value vector. 67 | */ 68 | __device__ __forceinline__ void merge(const vec_t &other_o, float x) { 69 | float m_prev = m, d_prev = d; 70 | m = max(m_prev, x); 71 | d = d * math::ptx_exp2(m_prev - m) + math::ptx_exp2(x - m); 72 | if constexpr (norm_on_the_fly) { 73 | #pragma unroll 74 | for (size_t i = 0; i < vec_size; ++i) { 75 | o[i] = o[i] * (math::ptx_exp2(m_prev - m) * d_prev / d) + 76 | other_o[i] * (math::ptx_exp2(x - m) / d); 77 | } 78 | } else { 79 | #pragma unroll 80 | for (size_t i = 0; i < vec_size; ++i) { 81 | o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(x - m); 82 | } 83 | } 84 | } 85 | 86 | __device__ __forceinline__ void normalize() { 87 | if constexpr (!norm_on_the_fly) { 88 | // only normalize by d when not normalized on the fly 89 | #pragma unroll 90 | for (size_t i = 0; i < vec_size; ++i) { 91 | o[i] = __fdividef(o[i], d); 92 | } 93 | } 94 | } 95 | }; 96 | 97 | } // namespace flashinfer 98 | 99 | #endif // FLASHINFER_STATE_CUH_ -------------------------------------------------------------------------------- /quarot/kernels/include/flashinfer/utils.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLASHINFER_UTILS_CUH_ 2 | #define FLASHINFER_UTILS_CUH_ 3 | #include 4 | 5 | #include "layout.cuh" 6 | #include "rope.cuh" 7 | 8 | #define FLASHINFER_CUDA_CALL(func, ...) \ 9 | { \ 10 | cudaError_t e = (func); \ 11 | if (e != cudaSuccess) { \ 12 | return e; \ 13 | } \ 14 | } 15 | 16 | #define SWITCH_LAYOUT(layout, LAYOUT, ...) \ 17 | switch (layout) { \ 18 | case QKVLayout::kNHD: { \ 19 | constexpr QKVLayout LAYOUT = QKVLayout::kNHD; \ 20 | __VA_ARGS__ \ 21 | break; \ 22 | } \ 23 | case QKVLayout::kHND: { \ 24 | constexpr QKVLayout LAYOUT = QKVLayout::kHND; \ 25 | __VA_ARGS__ \ 26 | break; \ 27 | } \ 28 | default: { \ 29 | std::cerr << "Unsupported qkv_layout: " << int(layout) << std::endl; \ 30 | abort(); \ 31 | } \ 32 | } 33 | 34 | #define SWITCH_HEAD_DIM(head_dim, HEAD_DIM, ...) 
\ 35 | switch (head_dim) { \ 36 | case 64: { \ 37 | constexpr size_t HEAD_DIM = 64; \ 38 | __VA_ARGS__ \ 39 | break; \ 40 | } \ 41 | case 128: { \ 42 | constexpr size_t HEAD_DIM = 128; \ 43 | __VA_ARGS__ \ 44 | break; \ 45 | } \ 46 | case 256: { \ 47 | constexpr size_t HEAD_DIM = 256; \ 48 | __VA_ARGS__ \ 49 | break; \ 50 | } \ 51 | default: { \ 52 | std::cerr << "Unsupported head_dim: " << head_dim << std::endl; \ 53 | abort(); \ 54 | } \ 55 | } 56 | 57 | #define SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \ 58 | switch (rotary_mode) { \ 59 | case RotaryMode::kNone: { \ 60 | constexpr RotaryMode ROTARY_MODE = RotaryMode::kNone; \ 61 | __VA_ARGS__ \ 62 | break; \ 63 | } \ 64 | case RotaryMode::kLlama: { \ 65 | constexpr RotaryMode ROTARY_MODE = RotaryMode::kLlama; \ 66 | __VA_ARGS__ \ 67 | break; \ 68 | } \ 69 | default: { \ 70 | std::cerr << "Unsupported rotary_mode: " << int(rotary_mode) << std::endl; \ 71 | abort(); \ 72 | } \ 73 | } 74 | 75 | #endif // FLASHINFER_UTILS_CUH_ -------------------------------------------------------------------------------- /quarot/kernels/include/gemm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | 6 | void matmul_host( 7 | const Int4Storage *A, 8 | const Int4Storage *B, 9 | uint32_t M, 10 | uint32_t N, 11 | uint32_t K, 12 | int32_t *C 13 | ); -------------------------------------------------------------------------------- /quarot/kernels/include/int4.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace cutlass { 5 | 6 | template 10 | class MySubbyteReference { 11 | public: 12 | using Element = Element_; 13 | using Storage = Storage_; 14 | using StoragePointer = Storage *; 15 | 16 | static_assert(sizeof_bits::value <= sizeof_bits::value, 17 | "Size of Element must not be greater than Storage."); 18 | 19 | static_assert(!(sizeof_bits::value % sizeof_bits::value), 20 | "Storage must be divisible by Element"); 21 | 22 | constexpr static int const kElementsPerVector = 23 | sizeof_bits::value / sizeof_bits::value; 24 | 25 | private: 26 | ///! Number of elements per storage vector 27 | 28 | ///! Bit mask 29 | Storage const kMask = 30 | ((sizeof_bits::value < sizeof_bits::value) 31 | ? (Storage(1) << sizeof_bits::value) - Storage(1) 32 | : ~Storage(0)); 33 | 34 | private: 35 | /// Pointer to array containing element 36 | StoragePointer ptr_; 37 | 38 | /// Offset (in units of elements) from pointer. 
39 | /// 40 | /// Invariant: must always be in range [0, kElementsPerVector) 41 | int offset_; 42 | 43 | public: 44 | CUTLASS_HOST_DEVICE 45 | MySubbyteReference() : ptr_(nullptr), offset_(0) {} 46 | 47 | /// Constructor 48 | CUTLASS_HOST_DEVICE 49 | MySubbyteReference(Element *ptr, /// pointer to memory 50 | int64_t offset /// logical offset in units of Element 51 | ) 52 | : ptr_(reinterpret_cast(ptr)), offset_(0) { 53 | int64_t offset_in_vectors = offset / kElementsPerVector; 54 | int64_t offset_in_elements = offset % kElementsPerVector; 55 | 56 | ptr_ += offset_in_vectors; 57 | offset_ = int(offset_in_elements); 58 | } 59 | 60 | /// Constructor 61 | CUTLASS_HOST_DEVICE 62 | MySubbyteReference(Element *ptr = nullptr) : MySubbyteReference(ptr, 0) {} 63 | 64 | /// Gets storage pointer 65 | CUTLASS_HOST_DEVICE 66 | StoragePointer storage_pointer() const { return ptr_; } 67 | 68 | /// Gets storage pointer 69 | CUTLASS_HOST_DEVICE 70 | Element *operator&() const { return reinterpret_cast(ptr_); } 71 | 72 | /// Gets element offset within storage vector 73 | CUTLASS_HOST_DEVICE 74 | int element_offset() const { return offset_; } 75 | 76 | /// Unpacks an element from memory 77 | CUTLASS_HOST_DEVICE 78 | Element get() const { 79 | Storage item = 80 | Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); 81 | return reinterpret_cast(item); 82 | } 83 | 84 | /// Stores an element to memory 85 | CUTLASS_HOST_DEVICE 86 | MySubbyteReference &set(Element const &x) { 87 | Storage item = (reinterpret_cast(x) & kMask); 88 | Storage kUpdateMask = 89 | Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); 90 | Storage new_bits = 91 | Storage(item << (offset_ * cutlass::sizeof_bits::value)); 92 | 93 | Storage original = (*ptr_); 94 | Storage updated = Storage((original & kUpdateMask) | new_bits); 95 | *ptr_ = updated; 96 | 97 | return *this; 98 | } 99 | 100 | //// 101 | 102 | /// Unpacks an element from memory 103 | CUTLASS_HOST_DEVICE 104 | operator Element() const { return get(); } 105 | 106 | /// Stores an element to memory 107 | CUTLASS_HOST_DEVICE 108 | MySubbyteReference &operator=(Element const &x) { return set(x); } 109 | 110 | /// Stores an element to memory 111 | CUTLASS_HOST_DEVICE 112 | MySubbyteReference &operator=(MySubbyteReference const &x) { 113 | return set(x.get()); 114 | } 115 | 116 | /// Stores an element to memory 117 | CUTLASS_HOST_DEVICE 118 | MySubbyteReference &operator=( 119 | ConstSubbyteReference const &x) { 120 | return set(x.get()); 121 | } 122 | 123 | /// Adds an offset in units of elements to the reference 124 | CUTLASS_HOST_DEVICE 125 | MySubbyteReference &operator+=(int offset) { 126 | offset += offset_; 127 | 128 | int offset_in_vectors = offset / kElementsPerVector; 129 | int offset_in_elements = offset % kElementsPerVector; 130 | 131 | ptr_ += offset_in_vectors; 132 | offset_ = offset_in_elements; 133 | 134 | return *this; 135 | } 136 | 137 | /// Adds an offset in units of elements to the reference 138 | CUTLASS_HOST_DEVICE 139 | MySubbyteReference &operator+=(long long offset) { 140 | offset += offset_; 141 | 142 | long long offset_in_vectors = offset / kElementsPerVector; 143 | int offset_in_elements = int(offset % kElementsPerVector); 144 | 145 | ptr_ += offset_in_vectors; 146 | offset_ = offset_in_elements; 147 | 148 | return *this; 149 | } 150 | 151 | /// Adds an offset in units of elements to the reference 152 | CUTLASS_HOST_DEVICE 153 | MySubbyteReference &operator-=(int offset) { 154 | int offset_in_vectors = offset / kElementsPerVector; 155 | int 
offset_in_elements = offset % kElementsPerVector; 156 | 157 | ptr_ -= offset_in_vectors; 158 | offset_ -= offset_in_elements; 159 | 160 | if (offset_ < 0) { 161 | offset_ += kElementsPerVector; 162 | --ptr_; 163 | } 164 | 165 | return *this; 166 | } 167 | 168 | /// Adds an offset in units of elements to the reference 169 | CUTLASS_HOST_DEVICE 170 | MySubbyteReference &operator-=(long long offset) { 171 | long long offset_in_vectors = offset / kElementsPerVector; 172 | int offset_in_elements = int(offset % kElementsPerVector); 173 | 174 | ptr_ -= offset_in_vectors; 175 | offset_ -= offset_in_elements; 176 | 177 | if (offset_ < 0) { 178 | offset_ += kElementsPerVector; 179 | --ptr_; 180 | } 181 | 182 | return *this; 183 | } 184 | 185 | /// Returns a reference to an element with a given offset from the current 186 | /// reference 187 | CUTLASS_HOST_DEVICE 188 | MySubbyteReference operator+(int offset) const { 189 | MySubbyteReference ref(ptr_, offset_); 190 | ref += offset; 191 | 192 | return ref; 193 | } 194 | 195 | /// Returns a reference to an element with a given offset from the current 196 | /// reference 197 | CUTLASS_HOST_DEVICE 198 | MySubbyteReference operator+(long long offset) const { 199 | MySubbyteReference ref(ptr_, offset_); 200 | ref += offset; 201 | 202 | return ref; 203 | } 204 | 205 | /// Returns a reference to an element with a given offset from the current 206 | /// reference 207 | CUTLASS_HOST_DEVICE 208 | MySubbyteReference operator-(int offset) const { 209 | MySubbyteReference ref(ptr_, offset_); 210 | ref -= offset; 211 | 212 | return ref; 213 | } 214 | 215 | /// Returns a reference to an element with a given offset from the current 216 | /// reference 217 | CUTLASS_HOST_DEVICE 218 | MySubbyteReference operator-=(long long offset) const { 219 | MySubbyteReference ref(ptr_, offset_); 220 | ref -= offset; 221 | 222 | return ref; 223 | } 224 | 225 | /// Computes the difference in elements between references 226 | CUTLASS_HOST_DEVICE 227 | ptrdiff_t operator-(MySubbyteReference ref) const { 228 | return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); 229 | } 230 | 231 | /// Explicit cast to int 232 | CUTLASS_HOST_DEVICE 233 | explicit operator int() const { return int(get()); } 234 | 235 | /// Explicit cast to signed 64-bit integer 236 | CUTLASS_HOST_DEVICE 237 | explicit operator int64_t() const { return int64_t(get()); } 238 | 239 | /// Explicit cast to unsigned 64-bit integer 240 | CUTLASS_HOST_DEVICE 241 | explicit operator uint64_t() const { return uint64_t(get()); } 242 | 243 | /// Explicit cast to float 244 | CUTLASS_HOST_DEVICE 245 | explicit operator float() const { return float(get()); } 246 | 247 | /// Explicit cast to double 248 | CUTLASS_HOST_DEVICE 249 | explicit operator double() const { return double(get()); } 250 | }; 251 | 252 | } // namespace cutlass 253 | 254 | using Int4Subbyte = cutlass::MySubbyteReference; 255 | using Int4Storage = Int4Subbyte::Storage; 256 | constexpr const uint32_t kElementsPerVector = 257 | cutlass::sizeof_bits::value / 258 | cutlass::sizeof_bits::value; 259 | -------------------------------------------------------------------------------- /quarot/kernels/include/quant.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | 6 | void sym_quant_host( 7 | const half *x, 8 | const half *scale, 9 | uint32_t rows, 10 | uint32_t colsSrc, 11 | uint32_t colsDst, 12 | Int4Storage *q 13 | ); 14 | 15 | 16 | void sym_dequant_host( 17 | const int32_t *q, 18 
| const half *scale_row, 19 | const half *scale_col, 20 | uint32_t rows, 21 | uint32_t cols, 22 | half *x 23 | ); -------------------------------------------------------------------------------- /quarot/kernels/include/util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | /* TODO: This file can be safely discarded. Kept in case needed in the future. 3 | 4 | // #include 5 | // #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | 12 | 13 | #define _BITS_STDINT_UINTN_H 1 14 | #define _BITS_STDINT_INTN_H 1 15 | #include 16 | typedef __int8_t int8_t; 17 | typedef __int16_t int16_t; 18 | typedef __int32_t int32_t; 19 | typedef __int64_t int64_t; 20 | typedef __uint8_t uint8_t; 21 | typedef __uint16_t uint16_t; 22 | typedef __uint32_t uint32_t; 23 | typedef __uint64_t uint64_t; 24 | 25 | 26 | 27 | template 28 | struct TorchDtypeDispatcher; 29 | 30 | template <> 31 | struct TorchDtypeDispatcher { 32 | constexpr static const auto value = torch::kUInt8; 33 | }; 34 | 35 | template <> 36 | struct TorchDtypeDispatcher { 37 | constexpr static const auto value = torch::kInt8; 38 | }; 39 | 40 | template <> 41 | struct TorchDtypeDispatcher { 42 | constexpr static const auto value = torch::kInt32; 43 | }; 44 | 45 | template <> 46 | struct TorchDtypeDispatcher { 47 | constexpr static const auto value = torch::kFloat16; 48 | }; 49 | 50 | template 51 | struct DtypeTorchDispatcher; 52 | 53 | template <> 54 | struct DtypeTorchDispatcher { 55 | using value = __half; 56 | }; 57 | 58 | template <> 59 | struct DtypeTorchDispatcher { 60 | using value = __nv_bfloat16; 61 | }; 62 | 63 | template 64 | __device__ inline int type2int_rn(T a) { 65 | return static_cast(a); 66 | } 67 | 68 | template <> 69 | __device__ inline int type2int_rn<__half>(__half input) { 70 | return __half2int_rn(input); 71 | } 72 | 73 | // template <> 74 | // __device__ inline int type2int_rn<__nv_bfloat16>(__nv_bfloat16 input) { 75 | // return __bfloat162int_rn(input); 76 | // } 77 | 78 | template 79 | __device__ inline float type2float(T a) { 80 | return static_cast(a); 81 | } 82 | 83 | template <> 84 | __device__ inline float type2float<__half>(__half input) { 85 | return __half2float(input); 86 | } 87 | 88 | template <> 89 | __device__ inline float type2float<__nv_bfloat16>(__nv_bfloat16 input) { 90 | return __bfloat162float(input); 91 | } 92 | 93 | template 94 | __device__ inline T float2type(float a) { 95 | return static_cast(a); 96 | } 97 | 98 | template <> 99 | __device__ inline __half float2type<__half>(float input) { 100 | return __float2half(input); 101 | } 102 | 103 | template <> 104 | __device__ inline __nv_bfloat16 float2type<__nv_bfloat16>(float input) { 105 | return __float2bfloat16_rn(input); 106 | } 107 | 108 | template 109 | struct DtypeDtype2Dispatcher; 110 | 111 | template <> 112 | struct DtypeDtype2Dispatcher<__half> { 113 | using value = __half2; 114 | }; 115 | 116 | template <> 117 | struct DtypeDtype2Dispatcher<__nv_bfloat16> { 118 | using value = __nv_bfloat162; 119 | }; 120 | 121 | __device__ inline __half2 type2type2(__half input, __half input2) { 122 | return __halves2half2(input, input2); 123 | } 124 | 125 | // __device__ inline __nv_bfloat162 type2type2(__nv_bfloat16 input, 126 | // __nv_bfloat16 input2) { 127 | // return __halves2bfloat162(input, input2); 128 | // } 129 | 130 | // template 131 | // T div(T a, T b) { 132 | // return a / b; 133 | // } 134 | // 135 | // template <> 136 | //__half div(__half a, __half b) { 137 | // return 
__hdiv(a, b); 138 | // } 139 | // 140 | // template <> 141 | //__nv_bfloat16 div(__nv_bfloat16 a, __nv_bfloat16 b) { 142 | // return __hdiv(a, b); 143 | // } 144 | 145 | */ -------------------------------------------------------------------------------- /quarot/kernels/quant.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | template <typename T> 5 | __device__ __half int_to_half(T value) 6 | { 7 | return __int2half_rn(static_cast<int>(value)); 8 | } 9 | 10 | 11 | __global__ 12 | void sym_quantize_f16_i4_kernel( 13 | const half *__restrict__ x, 14 | const half *__restrict__ scale, 15 | uint32_t rows, 16 | uint32_t colsSrc, 17 | uint32_t colsDst, 18 | Int4Storage *__restrict__ q 19 | ) 20 | { 21 | uint32_t row = threadIdx.y + blockIdx.y * blockDim.y; 22 | uint32_t colDst = threadIdx.x + blockIdx.x * blockDim.x; 23 | if (row >= rows || colDst * kElementsPerVector >= colsSrc) 24 | { 25 | return; 26 | } 27 | Int4Storage storage; 28 | memset(&storage, 0, sizeof(storage)); 29 | uint32_t id = colDst * kElementsPerVector + row * colsSrc; 30 | #pragma unroll 31 | for (int i = 0; i < kElementsPerVector; ++i) 32 | { 33 | bool safe = (colDst * kElementsPerVector + i) < colsSrc; 34 | if (safe) 35 | { 36 | half data = __hdiv(x[id + i], scale[row]); 37 | 38 | int qval = clamp(__half2int_rn(data), qmin, qmax); 39 | Int4Subbyte{reinterpret_cast<cutlass::int4b_t *>(&storage), i}.set( 40 | qval); 41 | } 42 | } 43 | 44 | q[colDst + row * colsDst] = storage; 45 | } 46 | 47 | 48 | void sym_quant_host( 49 | const half *x, 50 | const half *scale, 51 | uint32_t rows, 52 | uint32_t colsSrc, 53 | uint32_t colsDst, 54 | Int4Storage *q 55 | ) 56 | { 57 | 58 | dim3 block{std::min<uint32_t>(colsDst, 32), std::min<uint32_t>(rows, 16)}; 59 | dim3 grid{cdiv(colsDst, block.x), cdiv(rows, block.y)}; 60 | sym_quantize_f16_i4_kernel<<<grid, block>>>(x, scale, rows, colsSrc, colsDst, q); 61 | } 62 | 63 | 64 | __global__ void sym_dequantize_i32_f16_kernel( 65 | const int32_t *__restrict__ q, 66 | const half *__restrict__ scale_row, 67 | const half *__restrict__ scale_col, 68 | uint32_t rows, uint32_t cols, 69 | half *__restrict__ x) 70 | { 71 | uint32_t row = threadIdx.y + blockIdx.y * blockDim.y; 72 | uint32_t col = threadIdx.x + blockIdx.x * blockDim.x; 73 | 74 | if (col >= cols || row >= rows) 75 | { 76 | return; 77 | } 78 | 79 | half xElement = int_to_half(q[col + row * cols]); 80 | x[col + row * cols] = scale_row[row] * scale_col[col] * xElement; 81 | } 82 | 83 | void sym_dequant_host(const int32_t *q, 84 | const half *scale_row, 85 | const half *scale_col, 86 | uint32_t rows, 87 | uint32_t cols, 88 | half *x 89 | ) 90 | { 91 | dim3 block{std::min<uint32_t>(cols, 16), std::min<uint32_t>(rows, 16)}; 92 | dim3 grid{cdiv(cols, block.x), cdiv(rows, block.y)}; 93 | sym_dequantize_i32_f16_kernel<<<grid, block>>>( 94 | q, 95 | scale_row, scale_col, 96 | rows, cols, x); 97 | } 98 | -------------------------------------------------------------------------------- /quarot/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import Linear4bit 2 | from .normalization import RMSNorm 3 | from .quantization import Quantizer 4 | from .hadamard import OnlineHadamard 5 | -------------------------------------------------------------------------------- /quarot/nn/hadamard.py: -------------------------------------------------------------------------------- 1 | import quarot 2 | import torch 3 | 4 | 5 | class OnlineHadamard(torch.nn.Module): 6 | def __init__(self, hadamard_dim, force_fp32=False): 7 | super().__init__() 8 |
self.fp32_had = force_fp32 9 | had_rem_dim, self.rem_dim = quarot.functional.hadamard.get_hadK(hadamard_dim) 10 | if had_rem_dim is not None: 11 | self.register_buffer("had_rem_dim", had_rem_dim) 12 | if not self.fp32_had: 13 | self.had_rem_dim = self.had_rem_dim.to(torch.float16) 14 | else: 15 | self.had_rem_dim = None 16 | 17 | def forward(self, x): 18 | x_dtype = x.dtype 19 | if self.fp32_had: 20 | x = x.float() 21 | x = quarot.functional.matmul_hadU_cuda(x, self.had_rem_dim, self.rem_dim) 22 | x = x.to(x_dtype) 23 | return x 24 | -------------------------------------------------------------------------------- /quarot/nn/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import quarot 4 | import fast_hadamard_transform 5 | 6 | 7 | class ShapeHandler: 8 | def __init__(self, x: torch.Tensor): 9 | self.size_excl_last = x.numel()//x.shape[-1] 10 | self.shape_excl_last = tuple(x.shape[:-1]) 11 | 12 | # Keep the last dim unchanged, flatten all previous dims 13 | def flatten(self, x: torch.Tensor): 14 | return x.view(self.size_excl_last, -1) 15 | 16 | # Recover back to the original shape. 17 | def unflatten(self, x: torch.Tensor): 18 | return x.view(self.shape_excl_last + (-1,)) 19 | 20 | def unflatten_scale(self, x: torch.Tensor): 21 | return x.view(self.shape_excl_last) 22 | 23 | 24 | class Linear4bit(torch.nn.Module): 25 | def __init__(self, in_features, out_features, bias=False, dtype=torch.float16): 26 | ''' 27 | Symmetric 4-bit Linear Layer. 28 | ''' 29 | super().__init__() 30 | self.in_features = in_features 31 | self.out_features = out_features 32 | self.register_buffer('weight_scales', 33 | torch.zeros((self.out_features, 1), requires_grad=False)) 34 | self.register_buffer('weight', (torch.randint(1, 7, (self.out_features, self.in_features // 2), 35 | # SubByte weight 36 | dtype=torch.uint8, requires_grad=False))) 37 | if bias: 38 | self.register_buffer('bias', torch.zeros((self.out_features), dtype=dtype)) 39 | else: 40 | self.bias = None 41 | 42 | def forward(self, x): 43 | #if torch.cuda.current_device() != x.device: 44 | # torch.cuda.set_device(x.device) 45 | 46 | assert type(x) == quarot.PackedQuantizedTensor #Quantized input is given 47 | x, scales_x = x.quantized_x, x.scales_x 48 | #shape_handler = ShapeHandler(quantized_x) 49 | #quantized_x = shape_handler.flatten(quantized_x) 50 | x = quarot.matmul(x, self.weight) 51 | #out = shape_handler.unflatten( 52 | # quarot.sym_dequant(int_result, scales_x, self.weight_scales)) 53 | if self.bias is not None: 54 | return quarot.sym_dequant(x, scales_x, self.weight_scales) + self.bias 55 | else: 56 | return quarot.sym_dequant(x, scales_x, self.weight_scales) 57 | 58 | @staticmethod 59 | def from_float(module: torch.nn.Linear, weight_scales=None,): 60 | ''' 61 | Generate a new Linear4bit module from a FP16 Linear module. 62 | The weight matrix should have the same shape as the weight matrix of the FP16 Linear module and rounded using torch.round() 63 | routine. We will convert it to subByte representation and save it in the int_weight buffer. 
64 | ''' 65 | weight_matrix = module.weight.data 66 | 67 | 68 | int_module = Linear4bit(module.in_features, module.out_features, bias=module.bias is not None, dtype=weight_matrix.dtype).to(weight_matrix.dtype) 69 | if weight_scales is not None: 70 | assert weight_scales.shape == (module.out_features, 1), 'weight_scales should have shape (out_features, 1)' 71 | weight_matrix = weight_matrix.cuda() 72 | int_module.weight_scales.copy_(weight_scales.to(weight_matrix.dtype)) 73 | int_rounded_weight = (weight_matrix/weight_scales.cuda()).round() 74 | int_module.weight.copy_(quarot.functional.pack_i4(int_rounded_weight.to(torch.int8)).cpu()) 75 | 76 | if module.bias is not None: 77 | int_module.bias.copy_(module.bias) 78 | 79 | return int_module 80 | -------------------------------------------------------------------------------- /quarot/nn/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RMSNorm(torch.nn.Module): 4 | """ 5 | This class implements the Root Mean Square Normalization (RMSN) layer. 6 | We use the implementation from LLAMARMSNorm here: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75 8 | """ 9 | 10 | def __init__(self, mean_dim: int, eps=1e-5): 11 | super().__init__() 12 | self.eps = eps 13 | self.mean_dim = mean_dim 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | input_dtype = x.dtype 17 | if x.dtype == torch.float16: 18 | x = x.to(torch.float32) 19 | variance = x.pow(2).sum(-1, keepdim=True) / self.mean_dim 20 | x = x * torch.rsqrt(variance + self.eps) 21 | return x.to(input_dtype) 22 | -------------------------------------------------------------------------------- /quarot/nn/quantization.py: -------------------------------------------------------------------------------- 1 | import quarot 2 | import torch 3 | 4 | class Quantizer(torch.nn.Module): 5 | def __init__(self, input_clip_ratio=1.0): 6 | super().__init__() 7 | self.input_clip_ratio = input_clip_ratio 8 | 9 | def forward(self, x): 10 | scales_x = (torch.max(torch.abs(x), dim=-1)[0].unsqueeze(1)/7).to(torch.float16) * self.input_clip_ratio 11 | quantized_x = quarot.sym_quant(x, scales_x) 12 | packed_tensor = quarot.PackedQuantizedTensor(quantized_x, scales_x) 13 | return packed_tensor 14 | -------------------------------------------------------------------------------- /quarot/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .kv_cache import MultiLayerPagedKVCache4Bit 2 | -------------------------------------------------------------------------------- /quarot/transformers/kv_cache.py: -------------------------------------------------------------------------------- 1 | from transformers.cache_utils import Cache 2 | from typing import Optional, Tuple, Dict, Any 3 | import math 4 | import torch 5 | from .. 
import _CUDA 6 | import functools 7 | from fast_hadamard_transform import hadamard_transform 8 | from quarot.functional.quantization import get_minq_maxq 9 | 10 | @torch.jit.script 11 | def asym_quantize_and_pack_i4(x: torch.Tensor): 12 | minq, maxq = get_minq_maxq(bits=4, sym=False) 13 | xmax = torch.amax(x, dim=-1, keepdim=True) 14 | xmin = torch.amin(x, dim=-1, keepdim=True) 15 | scale = ((xmax - xmin).clamp(min=1e-5) / maxq) 16 | zero = -xmin 17 | q = torch.clamp(torch.round((x + zero) / scale), 0, maxq) 18 | 19 | # pack int4 20 | q = q.to(dtype=torch.uint8) 21 | q = q[..., 0::2] | (q[..., 1::2] << 4) 22 | return q, scale, zero 23 | 24 | def unpack_i4_and_asym_dequantize(q, scale, zero): 25 | #unpack int4 26 | assert q.dtype == torch.uint8 27 | q = torch.stack((q & 0x0f, (q >> 4) & 0x0f), dim=-1).view(*q.shape[:-1], q.shape[-1] * 2) 28 | return q * scale - zero 29 | 30 | def matmul_had_cuda(X, dtype): 31 | n = X.shape[-1] 32 | input = hadamard_transform(X.to(dtype).contiguous(), scale=1/math.sqrt(n)) 33 | return input.to(X.dtype).view(X.shape) 34 | 35 | 36 | def init_kv_i4(kv_data, kv_param, 37 | kv_indptr, kv_indices, 38 | last_page_offset, k, 39 | v, k_param, v_param, 40 | seqlen_indptr, layer_idx): 41 | return _CUDA.init_kv_i4( 42 | kv_data, kv_param, 43 | kv_indptr, kv_indices, 44 | last_page_offset, k, 45 | v, k_param, v_param, 46 | seqlen_indptr, layer_idx) 47 | 48 | 49 | def append_kv_i4(kv_data, kv_param, 50 | kv_indptr, kv_indices, 51 | last_page_offset, k, 52 | v, k_param, v_param, 53 | layer_idx): 54 | return _CUDA.append_kv_i4( 55 | kv_data, kv_param, 56 | kv_indptr, kv_indices, 57 | last_page_offset, k, 58 | v, k_param, v_param, 59 | layer_idx) 60 | 61 | def batch_decode_i4(o, q, kv_data, kv_param, 62 | kv_indptr, kv_indices, 63 | last_page_offset, layer_idx): 64 | return _CUDA.batch_decode_i4( 65 | o, q, kv_data, kv_param, 66 | kv_indptr, kv_indices, 67 | last_page_offset, layer_idx) 68 | 69 | 70 | def init_kv_f16(kv_data, kv_param, 71 | kv_indptr, kv_indices, 72 | last_page_offset, k, 73 | v, k_param, v_param, 74 | seqlen_indptr, layer_idx): 75 | return _CUDA.init_kv_f16( 76 | kv_data, kv_param, 77 | kv_indptr, kv_indices, 78 | last_page_offset, k, 79 | v, k_param, v_param, 80 | seqlen_indptr, layer_idx) 81 | 82 | 83 | def append_kv_f16(kv_data, kv_param, 84 | kv_indptr, kv_indices, 85 | last_page_offset, k, 86 | v, k_param, v_param, 87 | layer_idx): 88 | return _CUDA.append_kv_f16( 89 | kv_data, kv_param, 90 | kv_indptr, kv_indices, 91 | last_page_offset, k, 92 | v, k_param, v_param, 93 | layer_idx) 94 | 95 | def batch_decode_f16(o, q, kv_data, kv_param, 96 | kv_indptr, kv_indices, 97 | last_page_offset, layer_idx): 98 | return _CUDA.batch_decode_f16( 99 | o, q, kv_data, kv_param, 100 | kv_indptr, kv_indices, 101 | last_page_offset, layer_idx) 102 | 103 | 104 | class _AttentionStub(object): 105 | def __init__(self, cache_page_size, device, n_layers, disable_quant, hadamard_dtype): 106 | self.cache_page_size = cache_page_size 107 | self.n_layers = n_layers 108 | self.disable_quant = disable_quant 109 | self.hadamard_dtype = hadamard_dtype 110 | 111 | def forward(self, q, num_kv_heads, attention_kwargs, layer_idx): 112 | batch_size, q_len, num_qo_heads, head_dim = q.shape 113 | assert q_len == 1 114 | q = q.view(batch_size, num_qo_heads, head_dim) 115 | if self.hadamard_dtype is not None: 116 | q = matmul_had_cuda(q, dtype=self.hadamard_dtype) 117 | attn_output = torch.empty_like(q) 118 | if self.disable_quant: 119 | batch_decode = batch_decode_f16 120 | else: 121 | 
batch_decode = batch_decode_i4 122 | batch_decode( 123 | attn_output, q, 124 | **attention_kwargs, layer_idx=layer_idx 125 | ) 126 | attn_output = attn_output.unsqueeze(1) 127 | return attn_output 128 | 129 | 130 | class MultiLayerPagedKVCache4Bit(Cache): 131 | def __init__( 132 | self, batch_size, page_size, max_seq_len, 133 | device, n_layers, num_heads, head_dim, 134 | disable_quant=False, hadamard_dtype=torch.float16 ): 135 | self.page_size = page_size 136 | self.batch_size = batch_size 137 | max_page_cnt = self.page_cnt_from_length(max_seq_len) 138 | self.disable_quant = disable_quant 139 | self.pages = torch.empty( 140 | ( 141 | max_page_cnt * batch_size, 142 | n_layers, 143 | 2, 144 | num_heads, 145 | page_size, 146 | head_dim if disable_quant else head_dim // 2 147 | ), 148 | dtype=torch.float16 if disable_quant else torch.uint8, device=device) 149 | 150 | self.scales = torch.empty((max_page_cnt * batch_size, n_layers, 2, num_heads, page_size, 2), dtype=torch.float16, device=device) 151 | self.page_size = page_size 152 | self.max_seq_len = max_seq_len 153 | self._needs_init = [True] * n_layers 154 | self.length = 0 155 | self.device = device 156 | self.hadamard_dtype = hadamard_dtype 157 | self._stub = _AttentionStub( 158 | self.page_size, device, n_layers, 159 | disable_quant=self.disable_quant, 160 | hadamard_dtype=self.hadamard_dtype) 161 | 162 | def page_cnt_from_length(self, length): 163 | return (length + self.page_size - 1) // self.page_size 164 | 165 | def _ensure_page_cnt_per_batch(self, expected_page_cnt_per_batch): 166 | expected_page_cnt = expected_page_cnt_per_batch * self.batch_size 167 | if expected_page_cnt <= self.pages.shape[0]: 168 | return 169 | raise NotImplementedError 170 | 171 | @property 172 | def seen_tokens(self): 173 | return self.length 174 | 175 | def update( 176 | self, 177 | key_states: torch.Tensor, 178 | value_states: torch.Tensor, 179 | layer_idx: int, 180 | cache_kwargs: Optional[Dict[str, Any]] = None, 181 | ): 182 | 183 | b_sz, added_length, num_heads, head_dim = key_states.shape 184 | 185 | orig_key_states = key_states 186 | orig_value_states = value_states 187 | 188 | if self.hadamard_dtype is not None: 189 | key_states = matmul_had_cuda(key_states, dtype=self.hadamard_dtype) 190 | 191 | if self.disable_quant: 192 | k_scale = key_states.new_ones((b_sz, added_length, num_heads, 1)) 193 | k_zero = key_states.new_zeros((b_sz, added_length, num_heads, 1)) 194 | v_scale = value_states.new_ones((b_sz, added_length, num_heads, 1)) 195 | v_zero = value_states.new_zeros((b_sz, added_length, num_heads, 1)) 196 | else: 197 | key_states, k_scale, k_zero = asym_quantize_and_pack_i4(key_states) 198 | value_states, v_scale, v_zero = asym_quantize_and_pack_i4(value_states) 199 | 200 | k_param = torch.cat([k_scale, k_zero], dim=-1).view(self.batch_size * added_length, num_heads, 2) 201 | v_param = torch.cat([v_scale, v_zero], dim=-1).view(self.batch_size * added_length, num_heads, 2) 202 | 203 | quantized_head_dim = self.pages.shape[-1] 204 | 205 | assert b_sz == self.batch_size 206 | if layer_idx == 0: 207 | current_length = self.length 208 | new_length = current_length + added_length 209 | self._ensure_page_cnt_per_batch(self.page_cnt_from_length(new_length)) 210 | self.length = new_length 211 | attention_mask = cache_kwargs.get("attention_mask") 212 | if self._needs_init[layer_idx]: 213 | self._needs_init[layer_idx] = False 214 | if attention_mask is not None: 215 | nonzero_indices = torch.nonzero(attention_mask.flatten(), 
as_tuple=False).flatten().view(-1, 1) 216 | key_states = key_states.view(self.batch_size * added_length, num_heads * quantized_head_dim) 217 | value_states = value_states.view(self.batch_size * added_length, num_heads * quantized_head_dim) 218 | key_states = torch.gather(key_states, 0, nonzero_indices.expand(-1, num_heads * quantized_head_dim)) 219 | value_states = torch.gather(value_states, 0, nonzero_indices.expand(-1, num_heads * quantized_head_dim)) 220 | 221 | k_param = k_param.view(self.batch_size * added_length, num_heads * 2) 222 | v_param = v_param.view(self.batch_size * added_length, num_heads * 2) 223 | k_param = torch.gather(k_param, 0, nonzero_indices.expand(-1, num_heads * 2)) 224 | v_param = torch.gather(v_param, 0, nonzero_indices.expand(-1, num_heads * 2)) 225 | 226 | seqlens_in_batch = torch.nn.functional.pad(torch.cumsum(attention_mask.sum(dim=-1, dtype=torch.int32), dim=0, dtype=torch.int32), (1, 0)) 227 | else: 228 | seqlens_in_batch = torch.arange(self.batch_size + 1, device=self.device, dtype=torch.int) * added_length 229 | 230 | init_kv = init_kv_f16 if self.disable_quant else init_kv_i4 231 | init_kv( 232 | **self.get_cache_specs_for_flash_infer(attention_mask), 233 | k=key_states.view(-1, num_heads, quantized_head_dim), 234 | v=value_states.view(-1, num_heads, quantized_head_dim), 235 | k_param=k_param.view(-1, num_heads, 2), 236 | v_param=v_param.view(-1, num_heads, 2), 237 | seqlen_indptr=seqlens_in_batch, 238 | layer_idx=layer_idx 239 | ) 240 | return orig_key_states, orig_value_states 241 | else: 242 | assert added_length == 1 243 | append_kv = append_kv_f16 if self.disable_quant else append_kv_i4 244 | append_kv( 245 | **self.get_cache_specs_for_flash_infer(attention_mask), 246 | k=key_states.view(self.batch_size, num_heads, quantized_head_dim), 247 | v=value_states.view(self.batch_size, num_heads, quantized_head_dim), 248 | k_param=k_param.view(-1, num_heads, 2), 249 | v_param=v_param.view(-1, num_heads, 2), 250 | layer_idx=layer_idx, 251 | ) 252 | return functools.partial( 253 | self._stub.forward, 254 | num_kv_heads=num_heads, 255 | attention_kwargs=self.get_cache_specs_for_flash_infer(attention_mask), 256 | layer_idx=layer_idx, 257 | ) 258 | 259 | def get_cache_specs_for_flash_infer(self, attention_mask): 260 | if attention_mask is not None: 261 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 262 | else: 263 | seqlens_in_batch = torch.tensor([self.length], dtype=torch.int32, device=self.device).expand(self.batch_size) 264 | page_cnt = self.page_cnt_from_length(seqlens_in_batch) 265 | if (page_cnt[0] != page_cnt).any(): 266 | raise NotImplementedError("Current implementation does not support the case where batches have different number of pages") 267 | page_cnt = page_cnt[0] 268 | page_ptr = seqlens_in_batch % self.page_size 269 | page_ptr = torch.where((seqlens_in_batch != 0) & (page_ptr == 0), self.page_size, page_ptr) 270 | return { 271 | f"kv_data": self.pages, 272 | f"kv_indptr": torch.arange(0, self.batch_size + 1, device=self.device, dtype=torch.int) * page_cnt, 273 | f"kv_indices": ( 274 | (torch.arange(page_cnt, device=self.device, dtype=torch.int) * self.batch_size).unsqueeze(0) + 275 | torch.arange(self.batch_size, device=self.device, dtype=torch.int).unsqueeze(1)).view(-1), 276 | f"last_page_offset": page_ptr, #torch.full((self.batch_size, ), page_ptr, device=self.device, dtype=torch.int), 277 | f"kv_param": self.scales, 278 | } 279 | 280 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 281 | """Returns the 
sequence length of the cached states. A layer index can be optionally passed.""" 282 | return self.length 283 | 284 | def get_max_length(self) -> Optional[int]: 285 | """Returns the maximum sequence length of the cached states, if there is any.""" 286 | return None 287 | 288 | def to_legacy_cache(self): 289 | return self 290 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.38.0 2 | torch==2.2.1 3 | sentencepiece==0.2.0 4 | wandb==0.16.3 5 | huggingface-hub==0.20.3 6 | accelerate==0.27.2 7 | datasets==2.17.1 8 | lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@9b0b15b1ccace3534ffbd13298c569869ce8eaf3 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import torch.utils.cpp_extension as torch_cpp_ext 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import os 5 | import pathlib 6 | setup_dir = os.path.dirname(os.path.realpath(__file__)) 7 | HERE = pathlib.Path(__file__).absolute().parent 8 | 9 | def remove_unwanted_pytorch_nvcc_flags(): 10 | REMOVE_NVCC_FLAGS = [ 11 | '-D__CUDA_NO_HALF_OPERATORS__', 12 | '-D__CUDA_NO_HALF_CONVERSIONS__', 13 | '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', 14 | '-D__CUDA_NO_HALF2_OPERATORS__', 15 | ] 16 | for flag in REMOVE_NVCC_FLAGS: 17 | try: 18 | torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) 19 | except ValueError: 20 | pass 21 | 22 | def get_cuda_arch_flags(): 23 | return [ 24 | '-gencode', 'arch=compute_75,code=sm_75', # Turing 25 | '-gencode', 'arch=compute_80,code=sm_80', # Ampere 26 | '-gencode', 'arch=compute_86,code=sm_86', # Ampere 27 | ] 28 | 29 | def third_party_cmake(): 30 | import subprocess, sys, shutil 31 | 32 | cmake = shutil.which('cmake') 33 | if cmake is None: 34 | raise RuntimeError('Cannot find CMake executable.') 35 | 36 | retcode = subprocess.call([cmake, HERE]) 37 | if retcode != 0: 38 | sys.stderr.write("Error: CMake configuration failed.\n") 39 | sys.exit(1) 40 | 41 | # install fast hadamard transform 42 | hadamard_dir = os.path.join(HERE, 'third-party/fast-hadamard-transform') 43 | pip = shutil.which('pip') 44 | retcode = subprocess.call([pip, 'install', '-e', hadamard_dir]) 45 | 46 | if __name__ == '__main__': 47 | third_party_cmake() 48 | remove_unwanted_pytorch_nvcc_flags() 49 | setup( 50 | name='quarot', 51 | ext_modules=[ 52 | CUDAExtension( 53 | name='quarot._CUDA', 54 | sources=[ 55 | 'quarot/kernels/bindings.cpp', 56 | 'quarot/kernels/gemm.cu', 57 | 'quarot/kernels/quant.cu', 58 | 'quarot/kernels/flashinfer.cu', 59 | ], 60 | include_dirs=[ 61 | os.path.join(setup_dir, 'quarot/kernels/include'), 62 | os.path.join(setup_dir, 'third-party/cutlass/include'), 63 | os.path.join(setup_dir, 'third-party/cutlass/tools/util/include') 64 | ], 65 | extra_compile_args={ 66 | 'cxx': [], 67 | 'nvcc': get_cuda_arch_flags(), 68 | } 69 | ) 70 | ], 71 | cmdclass={ 72 | 'build_ext': BuildExtension 73 | } 74 | ) 75 | --------------------------------------------------------------------------------
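Usage sketch (illustrative, not a file from the repository): the snippet below wires together the modules shown above, using Quantizer to produce a PackedQuantizedTensor and Linear4bit.from_float to pack an FP16 torch.nn.Linear into int4. It assumes the quarot._CUDA extension has been built via setup.py and that a CUDA device is available; the 4096 feature size and the max|w|/7 per-row scale rule are illustrative choices that mirror the /7 default in quarot/nn/quantization.py, not values fixed by the package, and the top-level helpers (quarot.sym_quant, quarot.matmul, quarot.sym_dequant, quarot.functional.pack_i4) are assumed to be exposed exactly as quarot/nn/linear.py and quarot/nn/quantization.py use them.

# Illustrative sketch: quantize one FP16 linear layer with quarot.nn and run a forward pass.
# Assumes the quarot._CUDA extension is built and a CUDA device is available.
import torch
from quarot.nn import Linear4bit, Quantizer

torch.manual_seed(0)
fp16_linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16).cuda()

# Per-output-channel symmetric scales: map each row's max |w| onto the int4 limit 7,
# mirroring the /7 rule used by Quantizer; from_float expects shape (out_features, 1).
weight_scales = fp16_linear.weight.detach().abs().max(dim=-1, keepdim=True).values / 7

int4_linear = Linear4bit.from_float(fp16_linear, weight_scales=weight_scales).cuda()
quantizer = Quantizer(input_clip_ratio=1.0)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")   # 16 tokens
y = int4_linear(quantizer(x))   # PackedQuantizedTensor in, dequantized FP16 activations out
print(y.shape, y.dtype)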