├── .gitignore ├── LICENSE ├── README.md ├── assets ├── marlin.png ├── models.png ├── peak.png └── sustained.png ├── bench.py ├── gptq ├── datautils.py ├── eval.py ├── gptq.py ├── llama2.py └── quant.py ├── marlin ├── __init__.py ├── marlin_cuda.cpp └── marlin_cuda_kernel.cu ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | profile.ncu-rep 4 | *.egg-info* 5 | backup 6 | marlin/__pycache__ 7 | _backup 8 | gptq/__pycache__ 9 | gptq/*.marlin.g128 10 | __pycache__ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | <img src="assets/marlin.png"/>
3 | 
4 | 
5 | # Marlin
6 | 
7 | This is Marlin, a **M**ixed **A**uto-**R**egressive **Lin**ear kernel (and the name of one of the planet's fastest fish): an extremely optimized FP16xINT4 matmul kernel aimed at LLM inference that can deliver close to ideal (4x)
8 | speedups up to batchsizes of 16-32 tokens (in contrast to the 1-2 tokens of prior work with comparable speedups). This makes Marlin well suited for larger-scale
9 | serving, speculative decoding, or advanced multi-inference schemes such as CoT-Majority.
10 | 
11 | ## Techniques:
12 | 
13 | Most modern GPUs feature FLOP-to-byte ratios of around 100-200.
14 | Hence, as long as we perform fewer than 25-50 (tensor core) multiply-accumulates per 4-bit quantized weight, it should (theoretically) be possible to maintain a near ideal 4x speedup over FP16 weights.
15 | This means that the full performance benefits of weight-only quantization should, in principle, extend to batchsizes 4-8x larger than what is currently achieved by existing kernels (a short worked version of this arithmetic is sketched below, right after the first benchmark plot).
16 | However, actually realizing this in practice is very challenging, since we essentially need to fully utilize all available GPU resources (global memory, L2 cache, shared memory, tensor cores, vector cores), *simultaneously*.
17 | Marlin accomplishes this through numerous techniques and optimizations, briefly sketched below:
18 | 
19 | * We organize computation so that all activations are essentially always fetched from L2 cache, and are further reused several times within registers, to make sure that repeated loading from shared memory does not become a bottleneck either.
20 | * We execute global weight loads asynchronously with respect to all compute operations as well as activation loads, using a cache policy that allows immediate eviction, so as not to unnecessarily pollute the L2 cache with values that are never reused.
21 | * We perform shared memory loads, whose footprint is quite significant due to the relatively large activations, via double buffering to overlap them with computation and global loads.
22 | * We carefully order dequantization and tensor core instructions to ensure that both GPU pipelines are well saturated and do not bottleneck each other.
23 | * In general, both quantized weights and group scales are reshuffled offline into a layout that gives ideal access patterns during execution, allowing, for instance, weights to be dequantized directly into tensor core organization.
24 | * We have multiple warps in a threadblock compute partial results of the same output tile, in order to achieve higher warp counts (maximizing compute and latency hiding) without increasing the output tile size, which would make good partitioning on realistic matrices difficult.
25 | * All loads use the maximum vector length for peak efficiency, and we also perform several layout transformations to guarantee that all shared memory reads and writes are conflict-free, in particular for matrix loading instructions, and that global reduction happens at minimal memory overhead.
26 | * We set up and unroll loops such that the majority of memory offsets are static, minimizing runtime index calculations.
27 | * We implement a "striped" partitioning scheme in which the segment of tiles processed by each SM may (partially) span multiple column "slices". This leads to good SM utilization on most matrix shapes, while minimizing the required global reduction steps.
28 | * Global reduction happens directly in the output buffer (temporarily downcasting FP32 accumulators to FP16), which is kept in L2 cache; the reduction operations are also optimized to avoid any unnecessary reads or writes.
29 | * Overall, the kernel's PTX assembly was extensively analyzed in NSight-Compute, and the CUDA code contains several constructions that look redundant or slightly suboptimal but that compile to faster PTX.
30 | 
31 | ## Benchmarks:
32 | 
33 | We first compare the performance of Marlin with other popular 4-bit inference kernels, on a large matrix that can be
34 | ideally partitioned on an NVIDIA A10 GPU. This allows all kernels to reach essentially their best possible performance.
35 | All kernels are executed at groupsize 128 (however, we note that the scale formats are not 100% identical).
36 | 
37 | 
38 | <img src="assets/peak.png"/>
39 | 
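To make these numbers concrete, here is a short back-of-the-envelope version of the storage and roofline arithmetic behind them. It is purely illustrative and not part of the repository; the 150 FLOP-per-byte figure is just a representative value from the 100-200 range mentioned in the Techniques section.

```
# Ideal speedup of 4-bit weights with FP16 group scales (groupsize 128) over FP16 weights:
bits_fp16 = 16
bits_int4 = 4 + 16 / 128                  # 4 bits per weight + 0.125 bits of scale overhead
print(bits_fp16 / bits_int4)              # ~3.88x, i.e. the "optimal 3.87x" referred to below

# Rough batch-size limit for remaining memory-bound on a GPU with ~150 FLOPs per byte:
flops_per_byte = 150
weight_bytes = 0.5                        # one 4-bit quantized weight
flops_per_weight_per_token = 2            # one multiply-accumulate per weight and token
print(flops_per_byte * weight_bytes / flops_per_weight_per_token)   # ~37 tokens, inside the 25-50 MAC budget
```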
40 | 
41 | While existing kernels achieve speedups relatively close to the optimal 3.87x (note the 0.125-bit storage overhead of the
42 | group scales) at batchsize 1, their performance degrades quickly as the number of inputs is increased. In
43 | contrast, Marlin delivers essentially ideal speedups at all batchsizes, enabling the maximum possible 3.87x speedup up
44 | to batchsizes of around 16-32.
45 | 
46 | Due to its striped partitioning scheme, Marlin also delivers strong performance on real (smaller) matrices and across various GPUs.
47 | This is demonstrated by the results below, where we benchmark, at batchsize 16, the overall runtime across all linear
48 | layers in the Transformer blocks of popular open-source models.
49 | 
50 | 
51 | <img src="assets/models.png"/>
52 | 
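For reference, the layer shapes behind the per-model results above are defined in the `MODELS` dictionary of `bench.py`. As an example, the Llama-7B entry consists of the four linear-layer shapes of a single Transformer block, with the QKV and gate/up projections presumably fused into one matmul each (the per-line interpretation below is ours):

```
# (in_features, out_features) pairs from the 'Llama7B' entry of bench.py
LLAMA7B_LAYERS = [
    (4096, 3 * 4096),    # fused QKV projection
    (4096, 4096),        # attention output projection
    (4096, 2 * 10752),   # fused gate + up projection
    (10752, 4096),       # down projection
]
```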
53 | 
54 | Finally, we also study what performance can be sustained over longer periods of time, with the GPU clock locked to its base value.
55 | Interestingly, we find that reduced clock speeds significantly harm the relative speedups of prior kernels, but they have no
56 | effect on Marlin's virtually optimal performance (relative to the lower clock setting).
57 | 
58 | 
59 | <img src="assets/sustained.png"/>
60 | 
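The Requirements section that follows lists the hard constraints on CUDA version and GPU architecture. A quick, purely illustrative way to sanity-check the GPU-side requirements from Python (not part of the repository):

```
import torch

assert torch.cuda.is_available(), 'Marlin requires an NVIDIA GPU'
major, minor = torch.cuda.get_device_capability(0)
assert (major, minor) >= (8, 0), 'Marlin requires compute capability >= 8.0 (Ampere or Ada)'
print(torch.__version__, torch.version.cuda)   # torch version and the CUDA version it was built against
```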
61 | 
62 | ## Requirements:
63 | 
64 | * CUDA >= 11.8 (in particular also for the `nvcc` compiler, whose version should match that of the installed torch build)
65 | * NVIDIA GPU with compute capability >= 8.0 (Ampere or Ada; Marlin is not yet optimized for Hopper)
66 | * `torch>=2.0.0`
67 | * `numpy`
68 | For running the quantization scripts, one additionally needs:
69 | * `transformers`
70 | * `datasets`
71 | * `sentencepiece`
72 | 
73 | ## Usage:
74 | 
75 | If all requirements are met, it should be possible to install Marlin by calling
76 | 
77 | ```
78 | pip install .
79 | ```
80 | 
81 | in the root folder of this repository.
82 | 
83 | Afterwards, the easiest way to use the Marlin kernel is via a `marlin.Layer`, a torch module representing a Marlin
84 | quantized layer. It allows converting a "fake-quantized" (dequantized values stored in FP16) `torch.nn.Linear` layer into
85 | the compressed Marlin format via `marlin.Layer.pack(linear, scales)`. Alternatively, the kernel can also be called
86 | directly through `marlin.mul(..)`, provided that weights and scales have already been appropriately preprocessed (see
87 | `marlin.Layer.pack(...)`). The kernel itself can be found in the self-contained `marlin/marlin_cuda_kernel.cu` file,
88 | which has no dependencies beyond base CUDA and should thus be easy to integrate into other lower-level
89 | frameworks.
90 | 
91 | Correctness tests can be executed via `python test.py` and benchmarks via `python bench.py`. Please note that in order
92 | to reproduce our "sustained performance" benchmarks, the GPU clocks need to be locked to their respective base values
93 | using:
94 | 
95 | ```
96 | sudo nvidia-smi --lock-gpu-clocks=BASE_GPU_CLOCK --lock-memory-clocks=BASE_MEM_CLOCK
97 | ```
98 | 
99 | Additionally, if ECC is enabled (e.g., on an A10), then the maximum achievable memory bandwidth will be 10-15% lower
100 | than in the official spec sheet, as every memory request carries checksum overhead. This can be disabled via
101 | 
102 | ```
103 | sudo nvidia-smi -e 0
104 | ```
105 | 
106 | which we do in our A10 benchmarks.
107 | 
108 | ## GPTQ Example:
109 | 
110 | In the `gptq` subfolder, we also provide a slightly improved version of the [GPTQ](https://github.com/IST-DASLab/gptq) algorithm, with better group grid clipping and non-uniform calibration sample lengths, that can produce Marlin-compatible 4-bit versions of Llama2 models.
111 | Additionally, there is a script to evaluate such compressed models (using Marlin kernels) in the popular [LLM eval harness](https://github.com/EleutherAI/lm-evaluation-harness).
112 | The script below was tested with `lm-eval-harness==0.4.0` and may not work with newer or older versions.
113 | Here are the corresponding sample commands (the `marlin`, `transformers` and `datasets` packages must be installed):
114 | 
115 | ```
116 | % Compress Llama2 model and export it in Marlin format (llama2.py appends `.marlin.g128` to the save name).
117 | python llama2.py LLAMA2_CHECKPOINT --wbits 4 --save checkpoint
118 | % Perform perplexity evaluation of the uncompressed model.
119 | python llama2.py LLAMA2_CHECKPOINT
120 | % Evaluate the compressed model (with Marlin kernels) in the eval harness.
121 | python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu \
122 |   --marlin_checkpoint checkpoint.marlin.g128
123 | % Evaluate the full precision baseline.
124 | python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu
125 | ```
126 | 
127 | We measure the following WikiText2 and RedPajama perplexities, as well as MMLU zero-shot accuracy, for 4-bit (group=128) Marlin models:
128 | 
129 | | Llama2 | Wiki2 (FP16) | Wiki2 (INT4) | RedPaj (FP16) | RedPaj (INT4) | MMLU (FP16) | MMLU (INT4) |
130 | |:---:|:----:|:----:|:----:|:----:|:-----:|:-----:|
131 | | 7B | 5.12 | 5.27 | 6.14 | 6.30 | 41.80 | 40.07 |
132 | | 13B | 4.57 | 4.67 | 5.67 | 5.79 | 52.10 | 51.13 |
133 | | 70B | 3.12 | 3.21 | 4.74 | 4.81 | 65.43 | 64.81 |
134 | 
135 | We note that this GPTQ example is currently intended mostly as a demonstration of how to produce accurate Marlin models and as an end-to-end validation of kernel correctness (rather than as a flexible compression tool).
136 | 
137 | ## Cite:
138 | 
139 | If you found this work useful, please consider citing:
140 | 
141 | ```
142 | @article{frantar2024marlin,
143 |   title={MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models},
144 |   author={Frantar, Elias and Castro, Roberto L and Chen, Jiale and Hoefler, Torsten and Alistarh, Dan},
145 |   journal={arXiv preprint arXiv:2408.11743},
146 |   year={2024}
147 | }
148 | ```
149 | 
--------------------------------------------------------------------------------
/assets/marlin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/marlin.png
--------------------------------------------------------------------------------
/assets/models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/models.png
--------------------------------------------------------------------------------
/assets/peak.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/peak.png
--------------------------------------------------------------------------------
/assets/sustained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/sustained.png
--------------------------------------------------------------------------------
/bench.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | import numpy as np
4 | import torch
5 | import marlin
6 | 
7 | import time
8 | 
9 | def benchmark(f, warmup=1, iter=10):
10 |     for i in range(warmup + iter):
11 |         f()
12 |         # We do not synchronize here in order to hide the kernel launch overhead during benchmarking, as this will also
13 |         # happen during realistic model inference, where many launches are submitted to the kernel queue.
14 |         if i == warmup - 1:
15 |             torch.cuda.synchronize()
16 |             tick = time.time()
17 |     torch.cuda.synchronize()
18 |     res = (time.time() - tick) / iter
19 |     # Make sure there is enough time to "cool down" the GPU in between benchmarks to avoid throttling for later runs when
20 |     # we execute many benchmarks consecutively.
21 |     time.sleep(1.)
22 | return res 23 | 24 | def get_problem(m, n, k, groupsize=-1): 25 | if groupsize == -1: 26 | groupsize = k 27 | dev = torch.device('cuda:0') 28 | A = torch.randn((m, k), dtype=torch.half, device=dev) 29 | B = torch.randint(low=-2**31, high=2**31, size=(k * n // 8,), device=dev) 30 | B_ref = torch.randn((k, n), dtype=torch.half, device=dev) 31 | C = torch.zeros((m, n), dtype=torch.half, device=dev) 32 | s = torch.zeros((k // groupsize, n), dtype=torch.half, device=dev) 33 | torch.cuda.synchronize() 34 | return A, B, C, B_ref, s 35 | 36 | def benchmark_dense(A, B, C): 37 | res = benchmark(lambda: torch.matmul(A, B, out=C)) 38 | return { 39 | 's': res, 40 | 'TFLOP/s': 2 * A.numel() * C.shape[1] / res / 10 ** 12, 41 | 'GB/s': (2 * A.numel() + 2 * B.numel() + 2 * C.numel()) / res / 10 ** 9 42 | } 43 | 44 | def benchmark_quant(A, B, C, s, thread_k, thread_n, sms): 45 | workspace = torch.zeros(C.shape[1] // 128 * 16, device=torch.device('cuda:0')) 46 | res = benchmark(lambda: marlin.mul(A, B, C, s, workspace, thread_k, thread_n, sms)) 47 | return { 48 | 's': res, 49 | 'TFLOP/s': 2 * A.numel() * C.shape[1] / res / 10 ** 12, 50 | 'GB/s': (2 * A.numel() + 4 * B.numel() + 2 * C.numel() + 2 * s.numel()) / res / 10 ** 9 51 | } 52 | 53 | # Pass the SM count for known GPUs to avoid the kernel having to query this information (this is very minor) 54 | gpu = torch.cuda.get_device_name(0) 55 | if 'A100' in gpu: 56 | SMS = 108 57 | elif 'A10' in gpu: 58 | SMS = 72 59 | elif '3090' in gpu: 60 | SMS = 82 61 | elif 'A6000' in gpu: 62 | SMS = 84 63 | else: 64 | SMS = -1 65 | 66 | MODELS = { 67 | 'ideal': [ 68 | (4 * 256 * SMS, 256 * SMS) 69 | ], 70 | 'Llama7B': [ 71 | (4096, 3 * 4096), 72 | (4096, 4096), 73 | (4096, 2 * 10752), 74 | (10752, 4096) 75 | ], 76 | 'Llama13B': [ 77 | (5120, 3 * 5120), 78 | (5120, 5120), 79 | (5120, 2 * 13568), 80 | (13568, 5120) 81 | ], 82 | 'Llama33B': [ 83 | (6656, 3 * 6656), 84 | (6656, 6656), 85 | (6656, 2 * 17664), 86 | (17664, 6656) 87 | ], 88 | 'Llama65B': [ 89 | (8192, 3 * 8192), 90 | (8192, 8192), 91 | (8192, 2 * 21760), 92 | (21760, 8192) 93 | ], 94 | 'Falcon180B': [ 95 | # Note that parallel attention and FC allows layer fusions 96 | (14848, 14848 * 5 + 1024), 97 | (14848 * 5, 14848) 98 | ] 99 | } 100 | 101 | # Set to true in order to run a more complete benchmark sweep; the default is reproduce README experiments 102 | ALL = False 103 | 104 | for groupsize in [-1, 128] if ALL else [128]: 105 | print('groupsize=%d' % groupsize) 106 | print() 107 | for model, layers in MODELS.items(): 108 | print(model) 109 | if ALL: 110 | batchsizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] 111 | else: 112 | batchsizes = [1, 2, 4, 8, 16, 32, 64, 128] 113 | for batch in batchsizes: 114 | if not ALL and model != 'ideal' and batch != 16: 115 | continue 116 | tot_q = {'s': 0, 'TFLOP/s': 0, 'GB/s': 0, 'speedup': 0} 117 | for layer in layers: 118 | A, B, C, B_ref, s = get_problem(batch, layer[1], layer[0], groupsize) 119 | res_d = benchmark_dense(A, B_ref, C) 120 | if model == 'ideal' and batch == 16: 121 | # This is a special case constructed to be optimal for a thread-shape different than the default one 122 | res_q = benchmark_quant(A, B, C, s, 64, 256, SMS) 123 | else: 124 | res_q = benchmark_quant(A, B, C, s, -1, -1, SMS) 125 | res_q['speedup'] = res_d['s'] / res_q['s'] 126 | tot_q['s'] += res_q['s'] 127 | for k in tot_q: 128 | if k != 's': 129 | tot_q[k] += res_q[k] * res_q['s'] 130 | for k in tot_q: 131 | if k != 's': 132 | tot_q[k] /= tot_q['s'] 133 | print('batch=%04d: 
s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f' % ( 134 | batch, 135 | tot_q['s'], 136 | tot_q['TFLOP/s'], 137 | tot_q['GB/s'], 138 | tot_q['speedup'] 139 | )) 140 | print() 141 | -------------------------------------------------------------------------------- /gptq/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | 15 | from transformers import AutoTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 17 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | 20 | import random 21 | random.seed(seed) 22 | trainloader = [] 23 | for _ in range(nsamples): 24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 25 | j = i + seqlen 26 | inp = trainenc.input_ids[:, i:j] 27 | tar = inp.clone() 28 | tar[:, :-1] = -100 29 | trainloader.append((inp, tar)) 30 | testloader = [] 31 | for i in range(0, testenc.input_ids.shape[1] - seqlen, seqlen): 32 | testloader.append(testenc.input_ids[:, i:(i + seqlen)]) 33 | 34 | return trainloader, testloader 35 | 36 | def get_red(nsamples, seed, seqlen, model): 37 | VALSAMPLES = 1024 38 | 39 | from datasets import load_dataset 40 | from transformers import AutoTokenizer 41 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 42 | traindata = load_dataset('togethercomputer/RedPajama-Data-1T-Sample', split='train') 43 | 44 | np.random.seed(0) 45 | perm = np.random.permutation(len(traindata)) 46 | 47 | dataloader = [] 48 | for i in perm: 49 | tokens = tokenizer(traindata[int(i)]['text'], return_tensors='pt').input_ids 50 | if not (1 < tokens.shape[1] <= seqlen): 51 | continue 52 | dataloader.append(tokens) 53 | if len(dataloader) == nsamples + VALSAMPLES: 54 | break 55 | trainloader = dataloader[VALSAMPLES:] 56 | testloader = dataloader[:VALSAMPLES] 57 | return trainloader, testloader 58 | 59 | 60 | def get_loaders( 61 | name, nsamples=256, seed=0, seqlen=2048, model='' 62 | ): 63 | if 'wikitext2' in name: 64 | return get_wikitext2(nsamples, seed, seqlen, model) 65 | return data, None 66 | if 'red' in name: 67 | return get_red(nsamples, seed, seqlen, model) 68 | 69 | -------------------------------------------------------------------------------- /gptq/eval.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/__main__.py 2 | # with minor modifications for Marlin checkpoint loading as I didn't find an easy way to call `lm_eval.cli_evaluate(...)` directly 3 | 4 | 5 | import marlin 6 | 7 | # Save checkpoint name here since passing around extra args seems to confuse the eval harness 8 | MARLIN_CHECKPOINT = '' 9 | 10 | def get_llama_marlin(name, *args, **kwargs): 11 | import torch 12 | def skip(*args, **kwargs): 13 | pass 14 | torch.nn.init.kaiming_uniform_ = skip 15 | torch.nn.init.uniform_ = skip 16 | torch.nn.init.normal_ = skip 17 | from transformers import LlamaForCausalLM 18 | model = LlamaForCausalLM.from_pretrained(name, torch_dtype='auto') 19 | # Not really 
sure why this is sometimes > 1, but it messes up quantized inference ... 20 | # Fortunately, just setting it to 1 doesn't seem to affect standard inference 21 | model.config.pretraining_tp = 1 22 | def name_filter(n): 23 | if 'q_proj' in n or 'k_proj' in n or 'v_proj' in n or 'o_proj' in n: 24 | return True 25 | if 'mlp.gate_proj' in n or 'mlp.up_proj' in n or 'mlp.down_proj' in n: 26 | return True 27 | return False 28 | groupsize = -1 if MARLIN_CHECKPOINT.endswith('marlin') else 128 29 | marlin.replace_linear(model, name_filter, groupsize=groupsize) 30 | model.load_state_dict(torch.load(MARLIN_CHECKPOINT)) 31 | return model 32 | 33 | 34 | import argparse 35 | import json 36 | import logging 37 | import os 38 | import re 39 | import sys 40 | from pathlib import Path 41 | from typing import Union 42 | 43 | import numpy as np 44 | 45 | from lm_eval import evaluator, utils 46 | from lm_eval.api.registry import ALL_TASKS 47 | from lm_eval.tasks import include_path, initialize_tasks 48 | from lm_eval.utils import make_table 49 | 50 | 51 | def _handle_non_serializable(o): 52 | if isinstance(o, np.int64) or isinstance(o, np.int32): 53 | return int(o) 54 | elif isinstance(o, set): 55 | return list(o) 56 | else: 57 | return str(o) 58 | 59 | 60 | def parse_eval_args() -> argparse.Namespace: 61 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 62 | parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`") 63 | parser.add_argument( 64 | "--tasks", 65 | "-t", 66 | default=None, 67 | metavar="task1,task2", 68 | help="To get full list of tasks, use the command lm-eval --tasks list", 69 | ) 70 | parser.add_argument( 71 | "--model_args", 72 | "-a", 73 | default="", 74 | help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`", 75 | ) 76 | parser.add_argument( 77 | "--num_fewshot", 78 | "-f", 79 | type=int, 80 | default=None, 81 | metavar="N", 82 | help="Number of examples in few-shot context", 83 | ) 84 | parser.add_argument( 85 | "--batch_size", 86 | "-b", 87 | type=str, 88 | default=1, 89 | metavar="auto|auto:N|N", 90 | help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.", 91 | ) 92 | parser.add_argument( 93 | "--max_batch_size", 94 | type=int, 95 | default=None, 96 | metavar="N", 97 | help="Maximal batch size to try with --batch_size auto.", 98 | ) 99 | parser.add_argument( 100 | "--device", 101 | type=str, 102 | default=None, 103 | help="Device to use (e.g. cuda, cuda:0, cpu).", 104 | ) 105 | parser.add_argument( 106 | "--output_path", 107 | "-o", 108 | default=None, 109 | type=str, 110 | metavar="DIR|DIR/file.json", 111 | help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. 
Else the parent directory will be used.", 112 | ) 113 | parser.add_argument( 114 | "--limit", 115 | "-L", 116 | type=float, 117 | default=None, 118 | metavar="N|0 None: 189 | if not args: 190 | # we allow for args to be passed externally, else we parse them ourselves 191 | args = parse_eval_args() 192 | if args.marlin_checkpoint: 193 | global MARLIN_CHECKPOINT 194 | MARLIN_CHECKPOINT = args.marlin_checkpoint 195 | del args.marlin_checkpoint 196 | # Overwrite model load with marlin load 197 | import transformers 198 | transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin) 199 | 200 | eval_logger = utils.eval_logger 201 | eval_logger.setLevel(getattr(logging, f"{args.verbosity}")) 202 | eval_logger.info(f"Verbosity set to {args.verbosity}") 203 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 204 | 205 | initialize_tasks(args.verbosity) 206 | 207 | if args.limit: 208 | eval_logger.warning( 209 | " --limit SHOULD ONLY BE USED FOR TESTING." 210 | "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." 211 | ) 212 | if args.include_path is not None: 213 | eval_logger.info(f"Including path: {args.include_path}") 214 | include_path(args.include_path) 215 | 216 | if args.tasks is None: 217 | task_names = ALL_TASKS 218 | elif args.tasks == "list": 219 | eval_logger.info( 220 | "Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS))) 221 | ) 222 | sys.exit() 223 | else: 224 | if os.path.isdir(args.tasks): 225 | import glob 226 | 227 | task_names = [] 228 | yaml_path = os.path.join(args.tasks, "*.yaml") 229 | for yaml_file in glob.glob(yaml_path): 230 | config = utils.load_yaml_config(yaml_file) 231 | task_names.append(config) 232 | else: 233 | tasks_list = args.tasks.split(",") 234 | task_names = utils.pattern_match(tasks_list, ALL_TASKS) 235 | for task in [task for task in tasks_list if task not in task_names]: 236 | if os.path.isfile(task): 237 | config = utils.load_yaml_config(task) 238 | task_names.append(config) 239 | task_missing = [ 240 | task 241 | for task in tasks_list 242 | if task not in task_names and "*" not in task 243 | ] # we don't want errors if a wildcard ("*") task name was used 244 | 245 | if task_missing: 246 | missing = ", ".join(task_missing) 247 | eval_logger.error( 248 | f"Tasks were not found: {missing}\n" 249 | f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", 250 | ) 251 | raise ValueError( 252 | f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues." 253 | ) 254 | 255 | if args.output_path: 256 | path = Path(args.output_path) 257 | # check if file or 'dir/results.json' exists 258 | if path.is_file() or Path(args.output_path).joinpath("results.json").is_file(): 259 | eval_logger.warning( 260 | f"File already exists at {path}. Results will be overwritten." 
261 | ) 262 | output_path_file = path.joinpath("results.json") 263 | assert not path.is_file(), "File already exists" 264 | # if path json then get parent dir 265 | elif path.suffix in (".json", ".jsonl"): 266 | output_path_file = path 267 | path.parent.mkdir(parents=True, exist_ok=True) 268 | path = path.parent 269 | else: 270 | path.mkdir(parents=True, exist_ok=True) 271 | output_path_file = path.joinpath("results.json") 272 | elif args.log_samples and not args.output_path: 273 | assert args.output_path, "Specify --output_path" 274 | 275 | eval_logger.info(f"Selected Tasks: {task_names}") 276 | 277 | results = evaluator.simple_evaluate( 278 | model=args.model, 279 | model_args=args.model_args, 280 | tasks=task_names, 281 | num_fewshot=args.num_fewshot, 282 | batch_size=args.batch_size, 283 | max_batch_size=args.max_batch_size, 284 | device=args.device, 285 | use_cache=args.use_cache, 286 | limit=args.limit, 287 | decontamination_ngrams_path=args.decontamination_ngrams_path, 288 | check_integrity=args.check_integrity, 289 | write_out=args.write_out, 290 | log_samples=args.log_samples, 291 | gen_kwargs=args.gen_kwargs, 292 | ) 293 | 294 | if results is not None: 295 | if args.log_samples: 296 | samples = results.pop("samples") 297 | dumped = json.dumps( 298 | results, indent=2, default=_handle_non_serializable, ensure_ascii=False 299 | ) 300 | if args.show_config: 301 | print(dumped) 302 | 303 | batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) 304 | 305 | if args.output_path: 306 | output_path_file.open("w").write(dumped) 307 | 308 | if args.log_samples: 309 | for task_name, config in results["configs"].items(): 310 | output_name = "{}_{}".format( 311 | re.sub("/|=", "__", args.model_args), task_name 312 | ) 313 | filename = path.joinpath(f"{output_name}.jsonl") 314 | samples_dumped = json.dumps( 315 | samples[task_name], 316 | indent=2, 317 | default=_handle_non_serializable, 318 | ensure_ascii=False, 319 | ) 320 | filename.open("w").write(samples_dumped) 321 | 322 | print( 323 | f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " 324 | f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" 325 | ) 326 | print(make_table(results)) 327 | if "groups" in results: 328 | print(make_table(results, "groups")) 329 | 330 | 331 | cli_evaluate() 332 | 333 | -------------------------------------------------------------------------------- /gptq/gptq.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import transformers 8 | 9 | from quant import * 10 | 11 | 12 | DEBUG = False 13 | 14 | torch.backends.cuda.matmul.allow_tf32 = False 15 | torch.backends.cudnn.allow_tf32 = False 16 | 17 | 18 | class GPTQ: 19 | 20 | def __init__(self, layer, stable=False): 21 | self.layer = layer 22 | self.dev = self.layer.weight.device 23 | W = layer.weight.data.clone() 24 | self.rows = W.shape[0] 25 | self.columns = W.shape[1] 26 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 27 | self.nsamples = 0 28 | 29 | self.stable = stable 30 | self.mean = torch.zeros((self.columns, 1), device=self.dev) 31 | 32 | def add_batch(self, inp, out): 33 | if DEBUG: 34 | self.inp1 = inp 35 | self.out1 = out 36 | if len(inp.shape) == 2: 37 | inp = inp.unsqueeze(0) 38 | tmp = inp.shape[0] 39 | if len(inp.shape) == 3: 40 | inp = inp.reshape((-1, inp.shape[-1])) 41 | inp = inp.t() 
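        # Note on the statistics accumulated below: `add_batch` builds up the layer-input Hessian proxy
        # H ~ 2 * E[x x^T] that GPTQ uses for error compensation. With `stable=False`, H is maintained
        # directly as a running average of 2/nsamples * sum_i x_i x_i^T; with `stable=True`, an unnormalized
        # FP32 accumulator plus a running mean are kept instead (a batched, Welford-style update for better
        # numerical behavior) and are only normalized and combined with the mean at the start of `fasterquant()`.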
42 | 43 | if self.stable: 44 | inp = inp.float() 45 | delta = torch.mean(inp, 1, keepdims=True) - self.mean 46 | self.H += inp.matmul(inp.t()) + delta.matmul(delta.t()) * self.nsamples * tmp / (self.nsamples + tmp) 47 | self.nsamples += tmp 48 | self.mean += delta * tmp / self.nsamples 49 | else: 50 | self.H *= self.nsamples / (self.nsamples + tmp) 51 | self.nsamples += tmp 52 | inp = math.sqrt(2 / self.nsamples) * inp.float() 53 | self.H += inp.matmul(inp.t()) 54 | 55 | def fasterquant( 56 | self, blocksize=128, percdamp=.1, groupsize=-1, clip=False, baseline=False 57 | ): 58 | W = self.layer.weight.data.clone() 59 | W = W.float() 60 | 61 | tick = time.time() 62 | 63 | if self.stable: 64 | self.H /= self.nsamples 65 | self.H += self.mean.matmul(self.mean.t()) 66 | self.H *= 2 67 | H = self.H 68 | del self.H 69 | 70 | Losses = torch.zeros_like(W) 71 | Q = torch.zeros_like(W) 72 | 73 | if not baseline: 74 | try: 75 | damp = percdamp * torch.mean(torch.diag(H)) 76 | diag = torch.arange(self.columns, device=self.dev) 77 | H[diag, diag] += damp 78 | H = torch.linalg.cholesky(H) 79 | H = torch.cholesky_inverse(H) 80 | H = torch.linalg.cholesky(H, upper=True) 81 | Hinv = H 82 | except: 83 | print('Singularity.') 84 | baseline = True 85 | if baseline: 86 | del H 87 | Hinv = torch.eye(self.columns, device=self.dev) 88 | 89 | if groupsize == -1: 90 | self.quantizer.find_params(W) 91 | groups = [] 92 | 93 | for i1 in range(0, self.columns, blocksize): 94 | i2 = min(i1 + blocksize, self.columns) 95 | count = i2 - i1 96 | 97 | W1 = W[:, i1:i2].clone() 98 | Q1 = torch.zeros_like(W1) 99 | Err1 = torch.zeros_like(W1) 100 | Losses1 = torch.zeros_like(W1) 101 | Hinv1 = Hinv[i1:i2, i1:i2] 102 | 103 | if groupsize != -1: 104 | self.quantizer.find_params(W1, solve=Hinv1 if clip else None) 105 | groups.append(copy.deepcopy(self.quantizer)) 106 | 107 | for i in range(count): 108 | w = W1[:, i] 109 | d = Hinv1[i, i] 110 | 111 | q = quantize( 112 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 113 | ).flatten() 114 | 115 | Q1[:, i] = q 116 | Losses1[:, i] = (w - q) ** 2 / d ** 2 117 | 118 | err1 = (w - q) / d 119 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 120 | Err1[:, i] = err1 121 | 122 | Q[:, i1:i2] = Q1 123 | Losses[:, i1:i2] = Losses1 / 2 124 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 125 | 126 | if DEBUG: 127 | self.layer.weight.data[:, :i2] = Q[:, :i2] 128 | self.layer.weight.data[:, i2:] = W[:, i2:] 129 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 130 | print(torch.sum(Losses)) 131 | 132 | torch.cuda.synchronize() 133 | print('time %.2f' % (time.time() - tick)) 134 | print('error', torch.sum(Losses).item()) 135 | 136 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 137 | if DEBUG: 138 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 139 | 140 | if groups: 141 | scale = torch.cat([q.scale for q in groups], dim=1) 142 | zero = torch.cat([q.zero for q in groups], dim=1) 143 | return scale, zero 144 | return self.quantizer.scale, self.quantizer.zero 145 | -------------------------------------------------------------------------------- /gptq/llama2.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gptq import * 7 | from quant import * 8 | import marlin 9 | 10 | 11 | DEV = torch.device('cuda:0') 12 | 13 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 14 
| if type(module) in layers: 15 | return {name: module} 16 | res = {} 17 | for name1, child in module.named_children(): 18 | res.update(find_layers( 19 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 20 | )) 21 | return res 22 | 23 | 24 | def get_llama(name): 25 | import torch 26 | def skip(*args, **kwargs): 27 | pass 28 | torch.nn.init.kaiming_uniform_ = skip 29 | torch.nn.init.uniform_ = skip 30 | torch.nn.init.normal_ = skip 31 | from transformers import LlamaForCausalLM 32 | model = LlamaForCausalLM.from_pretrained(name, torch_dtype='auto') 33 | model.config.pretraining_tp = 1 34 | model.seqlen = 4096 35 | return model 36 | 37 | @torch.no_grad() 38 | def llama_sequential(model, dataloader, dev): 39 | print('Starting ...') 40 | 41 | use_cache = model.config.use_cache 42 | model.config.use_cache = False 43 | layers = model.model.layers 44 | 45 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 46 | model.model.norm = model.model.norm.to(dev) 47 | layers[0] = layers[0].to(dev) 48 | 49 | dtype = next(iter(model.parameters())).dtype 50 | inps = [] 51 | attention_masks = [] 52 | position_ids = [] 53 | 54 | class Catcher(nn.Module): 55 | def __init__(self, module): 56 | super().__init__() 57 | self.module = module 58 | def forward(self, inp, **kwargs): 59 | inps.append(inp) 60 | attention_masks.append(kwargs['attention_mask']) 61 | position_ids.append(kwargs['position_ids']) 62 | raise ValueError 63 | layers[0] = Catcher(layers[0]) 64 | for batch in dataloader: 65 | try: 66 | model(batch.to(dev)) 67 | except ValueError: 68 | pass 69 | layers[0] = layers[0].module 70 | 71 | layers[0] = layers[0].cpu() 72 | model.model.embed_tokens = model.model.embed_tokens.cpu() 73 | model.model.norm = model.model.norm.cpu() 74 | torch.cuda.empty_cache() 75 | 76 | print('Ready.') 77 | 78 | quantizers = {} 79 | for i in range(len(layers)): 80 | layer = layers[i].to(dev) 81 | full = find_layers(layer) 82 | 83 | if args.true_sequential: 84 | sequential = [ 85 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 86 | ['self_attn.o_proj'], 87 | ['mlp.up_proj', 'mlp.gate_proj'], 88 | ['mlp.down_proj'] 89 | ] 90 | else: 91 | sequential = [list(full.keys())] 92 | 93 | for names in sequential: 94 | if model.config.num_attention_heads != model.config.num_key_value_heads and args.skip_gq: 95 | names.remove('self_attn.k_proj') 96 | names.remove('self_attn.v_proj') 97 | 98 | subset = {n: full[n] for n in names} 99 | 100 | gptq = {} 101 | for name in subset: 102 | gptq[name] = GPTQ(subset[name]) 103 | gptq[name].quantizer = Quantizer() 104 | gptq[name].quantizer.configure(args.wbits) 105 | 106 | def add_batch(name): 107 | def tmp(_, inp, out): 108 | gptq[name].add_batch(inp[0].data, out.data) 109 | return tmp 110 | handles = [] 111 | for name in subset: 112 | handles.append(subset[name].register_forward_hook(add_batch(name))) 113 | for j in range(args.nsamples): 114 | layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j]) 115 | for h in handles: 116 | h.remove() 117 | 118 | for name in subset: 119 | print(i, name) 120 | print('Quantizing ...') 121 | res = gptq[name].fasterquant( 122 | percdamp=args.percdamp, groupsize=args.groupsize, clip=not args.no_clip, baseline=args.nearest 123 | ) 124 | res = list(res) 125 | res[0] = res[0].cpu() 126 | res[1] = res[1].cpu() 127 | quantizers['model.layers.%d.%s' % (i, name)] = res 128 | 129 | for j in range(args.nsamples): 130 | inps[j] = layer(inps[j], attention_mask=attention_masks[j], 
position_ids=position_ids[j])[0] 131 | 132 | layers[i] = layer.cpu() 133 | del layer 134 | del gptq 135 | torch.cuda.empty_cache() 136 | 137 | model.config.use_cache = use_cache 138 | return quantizers 139 | 140 | @torch.no_grad() 141 | def llama_eval(model, dataloader, dev): 142 | print('Evaluating ...') 143 | 144 | nsamples = len(dataloader) 145 | 146 | use_cache = model.config.use_cache 147 | model.config.use_cache = False 148 | layers = model.model.layers 149 | 150 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 151 | layers[0] = layers[0].to(dev) 152 | 153 | dtype = next(iter(model.parameters())).dtype 154 | inps = [] 155 | attention_masks = [] 156 | position_ids = [] 157 | 158 | class Catcher(nn.Module): 159 | def __init__(self, module): 160 | super().__init__() 161 | self.module = module 162 | def forward(self, inp, **kwargs): 163 | inps.append(inp) 164 | attention_masks.append(kwargs['attention_mask']) 165 | position_ids.append(kwargs['position_ids']) 166 | raise ValueError 167 | layers[0] = Catcher(layers[0]) 168 | for batch in dataloader: 169 | try: 170 | model(batch.to(dev)) 171 | except ValueError: 172 | pass 173 | layers[0] = layers[0].module 174 | 175 | layers[0] = layers[0].cpu() 176 | model.model.embed_tokens = model.model.embed_tokens.cpu() 177 | torch.cuda.empty_cache() 178 | 179 | for i in range(len(layers)): 180 | print(i) 181 | layer = layers[i].to(dev) 182 | for j in range(nsamples): 183 | inps[j] = layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j])[0] 184 | layers[i] = layer.cpu() 185 | del layer 186 | torch.cuda.empty_cache() 187 | 188 | if model.model.norm is not None: 189 | model.model.norm = model.model.norm.to(dev) 190 | model.lm_head = model.lm_head.to(dev) 191 | 192 | nlls = [] 193 | for i in range(nsamples): 194 | hidden_states = inps[i] 195 | if model.model.norm is not None: 196 | hidden_states = model.model.norm(hidden_states) 197 | lm_logits = model.lm_head(hidden_states) 198 | shift_logits = lm_logits[:, :-1, :].contiguous() 199 | shift_labels = (dataloader[i].to(dev))[:, 1:] 200 | loss_fct = nn.CrossEntropyLoss() 201 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 202 | neg_log_likelihood = loss.float() * model.seqlen 203 | nlls.append(neg_log_likelihood) 204 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 205 | print(ppl.item()) 206 | 207 | model.config.use_cache = use_cache 208 | 209 | def llama_pack(model, quantizers): 210 | layers = find_layers(model) 211 | layers = {n: layers[n] for n in quantizers} 212 | marlin.replace_linear(model, lambda n: n in quantizers, groupsize=args.groupsize) 213 | qlayers = find_layers(model, [marlin.Layer]) 214 | print('Packing ...') 215 | for name in qlayers: 216 | print(name) 217 | qlayers[name].pack(layers[name].to(DEV), quantizers[name][0].to(DEV)) 218 | qlayers[name].cpu() 219 | quantizers[name][0].cpu() 220 | layers[name].cpu() 221 | print('Done.') 222 | return model 223 | 224 | 225 | if __name__ == '__main__': 226 | import argparse 227 | from datautils import * 228 | 229 | parser = argparse.ArgumentParser() 230 | 231 | parser.add_argument( 232 | 'model', type=str, 233 | help='LlaMa model to load; pass location of hugginface converted checkpoint.' 234 | ) 235 | parser.add_argument( 236 | '--dataset', type=str, default='red', choices=['red'], 237 | help='Where to extract calibration data from.' 
238 | ) 239 | parser.add_argument( 240 | '--seed', 241 | type=int, default=0, help='Seed for sampling the calibration data.' 242 | ) 243 | parser.add_argument( 244 | '--nsamples', type=int, default=256, 245 | help='Number of calibration data samples.' 246 | ) 247 | parser.add_argument( 248 | '--percdamp', type=float, default=.1, 249 | help='Percent of the average Hessian diagonal to use for dampening.' 250 | ) 251 | parser.add_argument( 252 | '--nearest', action='store_true', 253 | help='Whether to run the RTN baseline.' 254 | ) 255 | parser.add_argument( 256 | '--wbits', type=int, default=16, choices=[4, 16], 257 | help='#bits to use for quantization; use 16 for evaluating base model.' 258 | ) 259 | parser.add_argument( 260 | '--groupsize', type=int, default=128, choices=[-1, 128], 261 | help='Groupsize to use for quantization; default is 128.' 262 | ) 263 | parser.add_argument( 264 | '--true-sequential', action='store_true', 265 | help='Whether to run in true sequential model.' 266 | ) 267 | parser.add_argument( 268 | '--no_clip', action='store_true', 269 | help='Whether to skip hessian based grid clipping when using groups.' 270 | ) 271 | parser.add_argument( 272 | '--skip_gq', action='store_true', 273 | help='Whether to skip quantizing group keys and values for the 70B model with group-query attention.' 274 | ) 275 | parser.add_argument( 276 | '--save', type=str, default='', 277 | help='Whether and where to save the quantized model.' 278 | ) 279 | 280 | args = parser.parse_args() 281 | 282 | if args.nearest: 283 | args.nsamples = 0 284 | 285 | model = get_llama(args.model) 286 | model.eval() 287 | 288 | dataloader, testloader = get_loaders( 289 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 290 | ) 291 | 292 | if args.wbits < 16: 293 | tick = time.time() 294 | quantizers = llama_sequential(model, dataloader, DEV) 295 | print(time.time() - tick) 296 | 297 | datasets = ['wikitext2', 'red'] 298 | for dataset in datasets: 299 | dataloader, testloader = get_loaders( 300 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 301 | ) 302 | print(dataset) 303 | llama_eval(model, testloader, DEV) 304 | 305 | if args.save: 306 | args.save += '.marlin' 307 | if args.groupsize != -1: 308 | args.save += '.g%d' % args.groupsize 309 | llama_pack(model, quantizers) 310 | torch.save(model.state_dict(), args.save) 311 | 312 | -------------------------------------------------------------------------------- /gptq/quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def quantize(x, scale, zero, maxq): 7 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 8 | return scale * (q - zero) 9 | 10 | class Quantizer(nn.Module): 11 | 12 | def __init__(self, shape=1): 13 | super(Quantizer, self).__init__() 14 | self.register_buffer('maxq', torch.tensor(0)) 15 | self.register_buffer('scale', torch.zeros(shape)) 16 | self.register_buffer('zero', torch.zeros(shape)) 17 | 18 | def configure(self, bits, sym=True, grid=100, maxshrink=.75): 19 | self.maxq = torch.tensor(2 ** bits - 1) 20 | self.sym = sym 21 | self.grid = grid 22 | self.maxshrink = maxshrink 23 | 24 | def find_params(self, x, solve=None, scales=None): 25 | dev = x.device 26 | self.maxq = self.maxq.to(dev) 27 | 28 | shape = x.shape 29 | x = x.flatten(1) 30 | if scales is not None: 31 | x *= scales 32 | 33 | tmp = torch.zeros(x.shape[0], device=dev) 34 | xmin = 
torch.minimum(x.min(1)[0], tmp) 35 | xmax = torch.maximum(x.max(1)[0], tmp) 36 | 37 | if self.sym: 38 | xmax = torch.maximum(torch.abs(xmin), xmax) 39 | tmp = xmin < 0 40 | if torch.any(tmp): 41 | xmin[tmp] = -xmax[tmp] 42 | tmp = (xmin == 0) & (xmax == 0) 43 | xmin[tmp] = -1 44 | xmax[tmp] = +1 45 | 46 | self.scale = (xmax - xmin) / self.maxq 47 | if self.sym: 48 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 49 | else: 50 | self.zero = torch.round(-xmin / self.scale) 51 | 52 | if solve is not None: 53 | best = torch.full([x.shape[0]], float('inf'), device=dev) 54 | for i in range(int(self.maxshrink * self.grid) + 1): 55 | p = 1 - i / self.grid 56 | clip = p * torch.max(xmax, torch.abs(xmin)) 57 | xmax1 = torch.min(xmax, +clip) 58 | xmin1 = torch.max(xmin, -clip) 59 | scale1 = (xmax1 - xmin1) / self.maxq 60 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 61 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 62 | if scales is not None: 63 | q /= scales 64 | delta = q - x 65 | err = torch.sum(torch.linalg.solve_triangular(solve, delta, upper=True, left=False) ** 2, 1) 66 | tmp = err < best 67 | if torch.any(tmp): 68 | best[tmp] = err[tmp] 69 | self.scale[tmp] = scale1[tmp] 70 | self.zero[tmp] = zero1[tmp] 71 | 72 | shape = [-1] + [1] * (len(shape) - 1) 73 | self.scale = self.scale.reshape(shape) 74 | self.zero = self.zero.reshape(shape) 75 | 76 | def quantize(self, x): 77 | return quantize(x, self.scale, self.zero, self.maxq) 78 | 79 | -------------------------------------------------------------------------------- /marlin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | import marlin_cuda 22 | 23 | def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16): 24 | """Marlin FP16xINT4 multiply; can be used within `torch.compile`. 
25 | @A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout 26 | @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` 27 | @C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout 28 | @s: `torch.half` scales of shape `(m / groupsize, n)` 29 | @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero 30 | @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1) 31 | @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1) 32 | @sms: number of SMs to use for the kernel (can usually be left as auto -1) 33 | @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes 34 | """ 35 | marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par) 36 | 37 | 38 | # Precompute permutations for Marlin weight and scale shuffling 39 | 40 | def _get_perms(): 41 | perm = [] 42 | for i in range(32): 43 | perm1 = [] 44 | col = i // 4 45 | for block in [0, 1]: 46 | for row in [ 47 | 2 * (i % 4), 48 | 2 * (i % 4) + 1, 49 | 2 * (i % 4 + 4), 50 | 2 * (i % 4 + 4) + 1 51 | ]: 52 | perm1.append(16 * row + col + 8 * block) 53 | for j in range(4): 54 | perm.extend([p + 256 * j for p in perm1]) 55 | 56 | perm = np.array(perm) 57 | interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) 58 | perm = perm.reshape((-1, 8))[:, interleave].ravel() 59 | perm = torch.from_numpy(perm) 60 | scale_perm = [] 61 | for i in range(8): 62 | scale_perm.extend([i + 8 * j for j in range(8)]) 63 | scale_perm_single = [] 64 | for i in range(4): 65 | scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) 66 | return perm, scale_perm, scale_perm_single 67 | 68 | _perm, _scale_perm, _scale_perm_single = _get_perms() 69 | 70 | 71 | class Layer(nn.Module): 72 | """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias.""" 73 | 74 | def __init__(self, infeatures, outfeatures, groupsize=-1): 75 | """Create an empty Marlin layer. 76 | @infeatures: number of input features (must be divisible by 128) 77 | @outfeatures: number of output features (must be divisible by 256) 78 | @groupsize: quantization groupsize (must be -1 or 128) 79 | """ 80 | super().__init__() 81 | if groupsize not in [-1, 128]: 82 | raise ValueError('Only groupsize -1 and 128 are supported.') 83 | if infeatures % 128 != 0 or outfeatures % 256 != 0: 84 | raise ValueError('`infeatures` must be divisible by 128 and `outfeatures` by 256.') 85 | if groupsize == -1: 86 | groupsize = infeatures 87 | if infeatures % groupsize != 0: 88 | raise ValueError('`infeatures` must be divisible by `groupsize`.') 89 | self.k = infeatures 90 | self.n = outfeatures 91 | self.groupsize = groupsize 92 | self.register_buffer('B', torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int)) 93 | self.register_buffer('s', torch.empty((self.k // groupsize, self.n), dtype=torch.half)) 94 | # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par` 95 | self.register_buffer('workspace', torch.zeros(self.n // 128 * 16, dtype=torch.int), persistent=False) 96 | 97 | def forward(self, A): 98 | C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device) 99 | mul(A.view((-1, A.shape[-1])), self.B, C.view((-1, C.shape[-1])), self.s, self.workspace) 100 | return C 101 | 102 | def pack(self, linear, scales): 103 | """Pack a fake-quantized linear layer into this actual Marlin representation. 
104 | @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`) 105 | @scales: corresponding quantization scales of shape `(infeatures, groups)` 106 | """ 107 | if linear.weight.dtype != torch.half: 108 | raise ValueError('Only `torch.half` weights are supported.') 109 | tile = 16 110 | maxq = 2 ** 4 - 1 111 | s = scales.t() 112 | w = linear.weight.data.t() 113 | if self.groupsize != self.k: 114 | w = w.reshape((-1, self.groupsize, self.n)) 115 | w = w.permute(1, 0, 2) 116 | w = w.reshape((self.groupsize, -1)) 117 | s = s.reshape((1, -1)) 118 | w = torch.round(w / s).int() 119 | w += (maxq + 1) // 2 120 | w = torch.clamp(w, 0, maxq) 121 | if self.groupsize != self.k: 122 | w = w.reshape((self.groupsize, -1, self.n)) 123 | w = w.permute(1, 0, 2) 124 | w = w.reshape((self.k, self.n)).contiguous() 125 | s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] 126 | else: 127 | s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] 128 | s = s.reshape((-1, self.n)).contiguous() 129 | w = w.reshape((self.k // tile, tile, self.n // tile, tile)) 130 | w = w.permute((0, 2, 1, 3)) 131 | w = w.reshape((self.k // tile, self.n * tile)) 132 | res = w 133 | res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape) 134 | q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) 135 | res = res.cpu().numpy().astype(np.uint32) 136 | for i in range(8): 137 | q |= res[:, i::8] << 4 * i 138 | q = torch.from_numpy(q.astype(np.int32)).to(w.device) 139 | self.B[:, :] = q.to(self.B.device) 140 | self.s[:, :] = s.to(self.s.device) 141 | 142 | 143 | def replace_linear(module, name_filter=lambda n: True, groupsize=-1, name=''): 144 | """Recursively replace all `torch.nn.Linear` layers by empty Marlin layers. 145 | @module: top-level module in which to perform the replacement 146 | @name_filter: lambda indicating if a layer should be replaced 147 | @groupsize: marlin groupsize 148 | @name: root-level name 149 | """ 150 | if isinstance(module, Layer): 151 | return 152 | for attr in dir(module): 153 | tmp = getattr(module, attr) 154 | name1 = name + '.' + attr if name != '' else attr 155 | if isinstance(tmp, nn.Linear) and name_filter(name1): 156 | setattr( 157 | module, attr, Layer(tmp.in_features, tmp.out_features, groupsize=groupsize) 158 | ) 159 | for name1, child in module.named_children(): 160 | replace_linear(child, name_filter, groupsize=groupsize, name=name + '.' + name1 if name != '' else name1) 161 | -------------------------------------------------------------------------------- /marlin/marlin_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | 18 | #include <torch/all.h> 19 | #include <torch/python.h> 20 | #include <ATen/cuda/CUDAContext.h> 21 | #include <cuda_runtime.h> 22 | 23 | int marlin_cuda( 24 | const void* A, 25 | const void* B, 26 | void* C, 27 | void* s, 28 | int prob_m, 29 | int prob_n, 30 | int prob_k, 31 | void* workspace, 32 | int groupsize = -1, 33 | int dev = 0, 34 | cudaStream_t stream = 0, 35 | int thread_k = -1, 36 | int thread_n = -1, 37 | int sms = -1, 38 | int max_par = 16 39 | ); 40 | 41 | const int ERR_PROB_SHAPE = 1; 42 | const int ERR_KERN_SHAPE = 2; 43 | 44 | void mul( 45 | const torch::Tensor& A, 46 | const torch::Tensor& B, 47 | torch::Tensor& C, 48 | const torch::Tensor& s, 49 | torch::Tensor& workspace, 50 | int thread_k = -1, 51 | int thread_n = -1, 52 | int sms = -1, 53 | int max_par = 8 54 | ) { 55 | int prob_m = A.size(0); 56 | int prob_n = C.size(1); 57 | int prob_k = A.size(1); 58 | int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0); 59 | if (groupsize != -1 && groupsize * s.size(0) != prob_k) 60 | AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); 61 | if (workspace.numel() < prob_n / 128 * max_par) 62 | AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, "."); 63 | int dev = A.get_device(); 64 | int err = marlin_cuda( 65 | A.data_ptr(), 66 | B.data_ptr(), 67 | C.data_ptr(), 68 | s.data_ptr(), 69 | prob_m, prob_n, prob_k, 70 | workspace.data_ptr(), 71 | groupsize, 72 | dev, 73 | at::cuda::getCurrentCUDAStream(dev), 74 | thread_k, 75 | thread_n, 76 | sms, 77 | max_par 78 | ); 79 | if (err == ERR_PROB_SHAPE) { 80 | AT_ERROR( 81 | "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", 82 | " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." 83 | ); 84 | } else if (err == ERR_KERN_SHAPE) { 85 | AT_ERROR( 86 | "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." 87 | ); 88 | } 89 | } 90 | 91 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 92 | m.def("mul", &mul, "Marlin FP16xINT4 matmul."); 93 | } 94 | -------------------------------------------------------------------------------- /marlin/marlin_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | 18 | #ifndef MARLIN_CUDA_KERNEL_CUH 19 | #define MARLIN_CUDA_KERNEL_CUH 20 | 21 | 22 | #include <cuda.h> 23 | #include <cuda_fp16.h> 24 | #include <cuda_runtime.h> 25 | #include <iostream> 26 | 27 | 28 | constexpr int ceildiv(int a, int b) { 29 | return (a + b - 1) / b; 30 | } 31 | 32 | // Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core 33 | // operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we 34 | // extensively use `#pragma unroll` throughout the kernel code to guarantee this.
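// (For example, `FragA` below is a `Vec<half2, 4>`: the four 32-bit registers in which each thread holds its part of a 16x16 operand-A tile for one tensor core mma.)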
35 | template <typename T, int n> 36 | struct Vec { 37 | T elems[n]; 38 | __device__ T& operator[](int i) { 39 | return elems[i]; 40 | } 41 | }; 42 | 43 | using I4 = Vec<int, 4>; 44 | 45 | // Matrix fragments for tensor core instructions; their precise layout is documented here: 46 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type 47 | using FragA = Vec<half2, 4>; 48 | using FragB = Vec<half2, 2>; 49 | using FragC = Vec<float, 4>; 50 | using FragS = Vec<half2, 1>; // quantization scales 51 | 52 | // Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that 53 | // are not multiples of 16. 54 | __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { 55 | const int BYTES = 16; 56 | uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); 57 | asm volatile( 58 | "{\n" 59 | " .reg .pred p;\n" 60 | " setp.ne.b32 p, %0, 0;\n" 61 | " @p cp.async.cg.shared.global [%1], [%2], %3;\n" 62 | "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) 63 | ); 64 | } 65 | 66 | // Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for 67 | // quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need 68 | // for inputs A and outputs C. 69 | __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { 70 | const int BYTES = 16; 71 | uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); 72 | asm volatile( 73 | "{\n" 74 | " .reg .b64 p;\n" 75 | " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" 76 | " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" 77 | "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) 78 | ); 79 | } 80 | 81 | // Async copy fence. 82 | __device__ inline void cp_async_fence() { 83 | asm volatile("cp.async.commit_group;\n" ::); 84 | } 85 | 86 | // Wait until at most `n` async copy stages are still pending. 87 | template <int n> 88 | __device__ inline void cp_async_wait() { 89 | asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); 90 | } 91 | 92 | // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. 93 | __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { 94 | const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); 95 | const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); 96 | float* c = reinterpret_cast<float*>(&frag_c); 97 | asm volatile( 98 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 99 | "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" 100 | : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) 101 | : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), 102 | "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) 103 | ); 104 | } 105 | 106 | // Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. 107 | __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { 108 | uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a); 109 | uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); 110 | asm volatile( 111 | "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" 112 | : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) 113 | ); 114 | } 115 | 116 | // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to 117 | // automatically recognize it in all cases.
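// The `lut` immediate is the truth table of the desired boolean function evaluated on the operand masks a = 0xF0, b = 0xCC, c = 0xAA; e.g., `(0xf0 & 0xcc) | 0xaa` below encodes the operation `(a & b) | c`.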
118 | template <int lut> 119 | __device__ inline int lop3(int a, int b, int c) { 120 | int res; 121 | asm volatile( 122 | "lop3.b32 %0, %1, %2, %3, %4;\n" 123 | : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) 124 | ); 125 | return res; 126 | } 127 | 128 | // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. 129 | // We mostly follow the strategy in the link below, with some small changes: 130 | // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h 131 | __device__ inline FragB dequant(int q) { 132 | const int LO = 0x000f000f; 133 | const int HI = 0x00f000f0; 134 | const int EX = 0x64006400; 135 | // Guarantee that the `(a & b) | c` operations are LOP3s. 136 | int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); 137 | int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); 138 | // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 139 | const int SUB = 0x64086408; 140 | const int MUL = 0x2c002c00; 141 | const int ADD = 0xd480d480; 142 | FragB frag_b; 143 | frag_b[0] = __hsub2( 144 | *reinterpret_cast<half2*>(&lo), 145 | *reinterpret_cast<const half2*>(&SUB) 146 | ); 147 | frag_b[1] = __hfma2( 148 | *reinterpret_cast<half2*>(&hi), 149 | *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD) 150 | ); 151 | return frag_b; 152 | } 153 | 154 | // Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. 155 | __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { 156 | half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); 157 | frag_b[0] = __hmul2(frag_b[0], s); 158 | frag_b[1] = __hmul2(frag_b[1], s); 159 | } 160 | 161 | // Wait until barrier reaches `count`, then lock for current threadblock. 162 | __device__ inline void barrier_acquire(int* lock, int count) { 163 | if (threadIdx.x == 0) { 164 | int state = -1; 165 | do 166 | // Guarantee that subsequent writes by this threadblock will be visible globally. 167 | asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); 168 | while (state != count); 169 | } 170 | __syncthreads(); 171 | } 172 | 173 | // Release barrier and increment visitation count. 174 | __device__ inline void barrier_release(int* lock, bool reset = false) { 175 | __syncthreads(); 176 | if (threadIdx.x == 0) { 177 | if (reset) { 178 | lock[0] = 0; 179 | return; 180 | } 181 | int val = 1; 182 | // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
183 | asm volatile ("fence.acq_rel.gpu;\n"); 184 | asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); 185 | } 186 | } 187 | 188 | 189 | template < 190 | const int threads, // number of threads in a threadblock 191 | const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock 192 | const int thread_n_blocks, // same for n dimension (output) 193 | const int thread_k_blocks, // same for k dimension (reduction) 194 | const int stages, // number of stages for the async global->shared fetch pipeline 195 | const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale 196 | > 197 | __global__ void Marlin( 198 | const int4* __restrict__ A, // fp16 input matrix of shape mxk 199 | const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn 200 | int4* __restrict__ C, // fp16 output buffer of shape mxn 201 | const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn 202 | int prob_m, // batch dimension m 203 | int prob_n, // output dimension n 204 | int prob_k, // reduction dimension k 205 | int* locks // extra global storage for barrier synchronization 206 | ) { 207 | // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple 208 | // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: 209 | // 0 1 3 210 | // 0 2 3 211 | // 1 2 4 212 | // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs 213 | // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as 214 | // possible. 215 | 216 | // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions 217 | int parallel = 1; 218 | if (prob_m > 16 * thread_m_blocks) { 219 | parallel = prob_m / (16 * thread_m_blocks); 220 | prob_m = 16 * thread_m_blocks; 221 | } 222 | 223 | int k_tiles = prob_k / 16 / thread_k_blocks; 224 | int n_tiles = prob_n / 16 / thread_n_blocks; 225 | int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); 226 | // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case 227 | // where a stripe starts in the middle of group. 228 | if (group_blocks != -1) 229 | iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); 230 | 231 | int slice_row = (iters * blockIdx.x) % k_tiles; 232 | int slice_col_par = (iters * blockIdx.x) / k_tiles; 233 | int slice_col = slice_col_par; 234 | int slice_iters; // number of threadblock tiles in the current slice 235 | int slice_count = 0; // total number of active threadblocks in the current slice 236 | int slice_idx; // index of threadblock in current slice; numbered bottom to top 237 | 238 | // We can easily implement parallel problem execution by just remapping indices and advancing global pointers 239 | if (slice_col_par >= n_tiles) { 240 | A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; 241 | C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; 242 | locks += (slice_col_par / n_tiles) * n_tiles; 243 | slice_col = slice_col_par % n_tiles; 244 | } 245 | 246 | // Compute all information about the current slice which is required for synchronization. 
247 | auto init_slice = [&] () { 248 | slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); 249 | if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) 250 | slice_iters = 0; 251 | if (slice_iters == 0) 252 | return; 253 | if (slice_row + slice_iters > k_tiles) 254 | slice_iters = k_tiles - slice_row; 255 | slice_count = 1; 256 | slice_idx = 0; 257 | int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); 258 | if (col_first <= k_tiles * (slice_col_par + 1)) { 259 | int col_off = col_first - k_tiles * slice_col_par; 260 | slice_count = ceildiv(k_tiles - col_off, iters); 261 | if (col_off > 0) 262 | slice_count++; 263 | int delta_first = iters * blockIdx.x - col_first; 264 | if (delta_first < 0 || (col_off == 0 && delta_first == 0)) 265 | slice_idx = slice_count - 1; 266 | else { 267 | slice_idx = slice_count - 1 - delta_first / iters; 268 | if (col_off > 0) 269 | slice_idx--; 270 | } 271 | } 272 | if (slice_col == n_tiles) { 273 | A += 16 * thread_m_blocks * prob_k / 8; 274 | C += 16 * thread_m_blocks * prob_n / 8; 275 | locks += n_tiles; 276 | slice_col = 0; 277 | } 278 | }; 279 | init_slice(); 280 | 281 | int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory 282 | // We typically use `constexpr` to indicate that this value is a compile-time constant 283 | constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory 284 | constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory 285 | int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile 286 | constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes 287 | constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads 288 | constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile 289 | constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile 290 | constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile 291 | 292 | int b_gl_stride = 16 * prob_n / 32; 293 | constexpr int b_sh_stride = 32 * thread_n_blocks / 4; 294 | int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; 295 | int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); 296 | constexpr int b_sh_wr_delta = threads; 297 | constexpr int b_sh_rd_delta = threads; 298 | constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; 299 | constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; 300 | 301 | int s_gl_stride = prob_n / 8; 302 | constexpr int s_sh_stride = 16 * thread_n_blocks / 8; 303 | constexpr int s_sh_stage = s_sh_stride; 304 | int s_gl_rd_delta = s_gl_stride; 305 | 306 | // Global A read index of current thread. 307 | int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); 308 | a_gl_rd += a_gl_rd_delta_o * slice_row; 309 | // Shared write index of current thread. 310 | int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); 311 | // Shared read index. 
312 | int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; 313 | a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); 314 | 315 | int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); 316 | b_gl_rd += b_sh_stride * slice_col; 317 | b_gl_rd += b_gl_rd_delta_o * slice_row; 318 | int b_sh_wr = threadIdx.x; 319 | int b_sh_rd = threadIdx.x; 320 | 321 | int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; 322 | int s_sh_wr = threadIdx.x; 323 | int s_sh_rd; 324 | // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major 325 | // layout in the former and in row-major in the latter case. 326 | if (group_blocks != -1) 327 | s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; 328 | else 329 | s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; 330 | 331 | // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than 332 | // required for a certain tilesize or when the batchsize is not a multiple of 16. 333 | bool a_sh_wr_pred[a_sh_wr_iters]; 334 | #pragma unroll 335 | for (int i = 0; i < a_sh_wr_iters; i++) 336 | a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; 337 | bool s_sh_wr_pred = threadIdx.x < s_sh_stride; 338 | 339 | // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank 340 | // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of 341 | // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based 342 | // on NSight-Compute) that each warp must also write a consecutive memory segment? 343 | auto transform_a = [&] (int i) { 344 | int row = i / a_gl_rd_delta_o; 345 | return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; 346 | }; 347 | // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory 348 | // accesses are static, we simply precompute both transformed reads and writes. 349 | int a_sh_wr_trans[a_sh_wr_iters]; 350 | #pragma unroll 351 | for (int i = 0; i < a_sh_wr_iters; i++) 352 | a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); 353 | int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; 354 | #pragma unroll 355 | for (int i = 0; i < b_sh_wr_iters; i++) { 356 | #pragma unroll 357 | for (int j = 0; j < thread_m_blocks; j++) 358 | a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); 359 | } 360 | 361 | // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between 362 | // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. 363 | const int4* B_ptr[b_sh_wr_iters]; 364 | #pragma unroll 365 | for (int i = 0; i < b_sh_wr_iters; i++) 366 | B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; 367 | 368 | extern __shared__ int4 sh[]; 369 | // Shared memory storage for global fetch pipelines. 370 | int4* sh_a = sh; 371 | int4* sh_b = sh_a + (stages * a_sh_stage); 372 | int4* sh_s = sh_b + (stages * b_sh_stage); 373 | // Register storage for double buffer of shared memory reads. 
374 | FragA frag_a[2][thread_m_blocks]; 375 | I4 frag_b_quant[2]; 376 | FragC frag_c[thread_m_blocks][4][2]; 377 | FragS frag_s[2][4]; 378 | 379 | // Zero accumulators. 380 | auto zero_accums = [&] () { 381 | #pragma unroll 382 | for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) 383 | reinterpret_cast<int*>(frag_c)[i] = 0; 384 | }; 385 | 386 | // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. 387 | auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { 388 | if (pred) { 389 | int4* sh_a_stage = sh_a + a_sh_stage * pipe; 390 | #pragma unroll 391 | for (int i = 0; i < a_sh_wr_iters; i++) { 392 | cp_async4_pred( 393 | &sh_a_stage[a_sh_wr_trans[i]], 394 | &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], 395 | a_sh_wr_pred[i] 396 | ); 397 | } 398 | int4* sh_b_stage = sh_b + b_sh_stage * pipe; 399 | #pragma unroll 400 | for (int i = 0; i < b_sh_wr_iters; i++) { 401 | cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); 402 | B_ptr[i] += b_gl_rd_delta_o; 403 | } 404 | // Only fetch scales if this tile starts a new group 405 | if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { 406 | int4* sh_s_stage = sh_s + s_sh_stage * pipe; 407 | if (s_sh_wr_pred) 408 | cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); 409 | s_gl_rd += s_gl_rd_delta; 410 | } 411 | } 412 | // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. 413 | cp_async_fence(); 414 | }; 415 | 416 | // Wait until the next thread tile has been loaded to shared memory. 417 | auto wait_for_stage = [&] () { 418 | // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when 419 | // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). 420 | cp_async_wait<stages - 2>(); 421 | __syncthreads(); 422 | }; 423 | 424 | // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. 425 | auto fetch_to_registers = [&] (int k, int pipe) { 426 | // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a 427 | // significant bottleneck, while some theoretically better attempts have led to bad instruction ordering by the 428 | // compiler and correspondingly a noticeable drop in performance. 429 | if (group_blocks != -1) { 430 | int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); 431 | reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; 432 | } 433 | int4* sh_a_stage = sh_a + a_sh_stage * pipe; 434 | #pragma unroll 435 | for (int i = 0; i < thread_m_blocks; i++) 436 | ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); 437 | int4* sh_b_stage = sh_b + b_sh_stage * pipe; 438 | frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); 439 | }; 440 | 441 | // Execute the actual tensor core matmul of a sub-tile. 442 | auto matmul = [&] (int k) { 443 | // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. 444 | #pragma unroll 445 | for (int j = 0; j < 4; j++) { 446 | int b_quant = frag_b_quant[k % 2][j]; 447 | int b_quant_shift = b_quant >> 8; 448 | FragB frag_b0 = dequant(b_quant); 449 | // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
450 | if (group_blocks != -1) 451 | scale(frag_b0, frag_s[k % 2][j], 0); 452 | FragB frag_b1 = dequant(b_quant_shift); 453 | if (group_blocks != -1) 454 | scale(frag_b1, frag_s[k % 2][j], 1); 455 | #pragma unroll 456 | for (int i = 0; i < thread_m_blocks; i++) { 457 | mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); 458 | mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); 459 | } 460 | } 461 | }; 462 | 463 | // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n 464 | // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output 465 | // location, which we have to reduce over in the end. We do this in shared memory. 466 | auto thread_block_reduce = [&] () { 467 | constexpr int red_off = threads / b_sh_stride / 2; 468 | if (red_off >= 1) { 469 | int red_idx = threadIdx.x / b_sh_stride; 470 | constexpr int red_sh_stride = b_sh_stride * 4 * 2; 471 | constexpr int red_sh_delta = b_sh_stride; 472 | int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); 473 | 474 | // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, 475 | // e.g., for two warps we write only once by warp 1 and read only once by warp 0. 476 | 477 | #pragma unroll 478 | for (int m_block = 0; m_block < thread_m_blocks; m_block++) { 479 | #pragma unroll 480 | for (int i = red_off; i > 0; i /= 2) { 481 | if (i <= red_idx && red_idx < 2 * i) { 482 | #pragma unroll 483 | for (int j = 0; j < 4 * 2; j++) { 484 | int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); 485 | if (i < red_off) { 486 | float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]); 487 | float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); 488 | #pragma unroll 489 | for (int k = 0; k < 4; k++) 490 | reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; 491 | } 492 | sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; 493 | } 494 | } 495 | __syncthreads(); 496 | } 497 | if (red_idx == 0) { 498 | #pragma unroll 499 | for (int i = 0; i < 4 * 2; i++) { 500 | float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]); 501 | #pragma unroll 502 | for (int j = 0; j < 4; j++) 503 | reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; 504 | } 505 | } 506 | __syncthreads(); 507 | } 508 | } 509 | }; 510 | 511 | // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over 512 | // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather 513 | // small, we perform this reduction serially in L2 cache. 514 | auto global_reduce = [&] (bool first = false, bool last = false) { 515 | // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. 516 | // To do this, we write out results in FP16 (but still reduce with FP32 compute).
517 | constexpr int active_threads = 32 * thread_n_blocks / 4; 518 | if (threadIdx.x < active_threads) { 519 | int c_gl_stride = prob_n / 8; 520 | int c_gl_wr_delta_o = 8 * c_gl_stride; 521 | int c_gl_wr_delta_i = 4 * (active_threads / 32); 522 | int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; 523 | c_gl_wr += (2 * thread_n_blocks) * slice_col; 524 | constexpr int c_sh_wr_delta = active_threads; 525 | int c_sh_wr = threadIdx.x; 526 | 527 | int row = (threadIdx.x % 32) / 4; 528 | 529 | if (!first) { 530 | // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, 531 | // hence we also use async-copies even though these fetches are not actually asynchronous. 532 | #pragma unroll 533 | for (int i = 0; i < thread_m_blocks * 4; i++) { 534 | cp_async4_pred( 535 | &sh[c_sh_wr + c_sh_wr_delta * i], 536 | &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], 537 | i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m 538 | ); 539 | } 540 | cp_async_fence(); 541 | cp_async_wait<0>(); 542 | } 543 | 544 | #pragma unroll 545 | for (int i = 0; i < thread_m_blocks * 4; i++) { 546 | if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { 547 | if (!first) { 548 | int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; 549 | #pragma unroll 550 | for (int j = 0; j < 2 * 4; j++) { 551 | reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( 552 | reinterpret_cast<__half*>(&c_red)[j] 553 | ); 554 | } 555 | } 556 | if (!last) { 557 | int4 c; 558 | #pragma unroll 559 | for (int j = 0; j < 2 * 4; j++) { 560 | reinterpret_cast<__half*>(&c)[j] = __float2half( 561 | reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] 562 | ); 563 | } 564 | C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; 565 | } 566 | } 567 | } 568 | } 569 | }; 570 | 571 | // Write out the reduced final result in the correct layout. We only actually reshuffle matrix fragments in this step; 572 | // the reduction above is performed in fragment layout.
573 | auto write_result = [&] () { 574 | int c_gl_stride = prob_n / 8; 575 | constexpr int c_sh_stride = 2 * thread_n_blocks + 1; 576 | int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); 577 | constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); 578 | 579 | int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); 580 | c_gl_wr += (2 * thread_n_blocks) * slice_col; 581 | int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; 582 | c_sh_wr += 32 * (threadIdx.x / 32); 583 | int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); 584 | 585 | int c_gl_wr_end = c_gl_stride * prob_m; 586 | 587 | // We first reorder in shared memory to guarantee the most efficient final global write patterns 588 | auto write = [&] (int idx, float c0, float c1, FragS& s) { 589 | half2 res = __halves2half2(__float2half(c0), __float2half(c1)); 590 | if (group_blocks == -1) // for per-column quantization we finally apply the scale here 591 | res = __hmul2(res, s[0]); 592 | ((half2*) sh)[idx] = res; 593 | }; 594 | if (threadIdx.x / 32 < thread_n_blocks / 4) { 595 | #pragma unroll 596 | for (int i = 0; i < thread_m_blocks; i++) { 597 | #pragma unroll 598 | for (int j = 0; j < 4; j++) { 599 | int wr = c_sh_wr + 8 * j; 600 | write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); 601 | write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); 602 | write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); 603 | write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); 604 | } 605 | c_sh_wr += 16 * (4 * c_sh_stride); 606 | } 607 | } 608 | __syncthreads(); 609 | 610 | #pragma unroll 611 | for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { 612 | if (c_gl_wr < c_gl_wr_end) { 613 | C[c_gl_wr] = sh[c_sh_rd]; 614 | c_gl_wr += c_gl_wr_delta; 615 | c_sh_rd += c_sh_rd_delta; 616 | } 617 | } 618 | }; 619 | 620 | // Start global fetch and register load pipelines. 621 | auto start_pipes = [&] () { 622 | #pragma unroll 623 | for (int i = 0; i < stages - 1; i++) 624 | fetch_to_shared(i, i, i < slice_iters); 625 | zero_accums(); 626 | wait_for_stage(); 627 | fetch_to_registers(0, 0); 628 | a_gl_rd += a_gl_rd_delta_o * (stages - 1); 629 | }; 630 | start_pipes(); 631 | 632 | // Main loop. 633 | while (slice_iters) { 634 | // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are 635 | // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. 636 | #pragma unroll 637 | for (int pipe = 0; pipe < stages;) { 638 | #pragma unroll 639 | for (int k = 0; k < b_sh_wr_iters; k++) { 640 | fetch_to_registers(k + 1, pipe % stages); 641 | if (k == b_sh_wr_iters - 2) { 642 | fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); 643 | pipe++; 644 | wait_for_stage(); 645 | } 646 | matmul(k); 647 | } 648 | slice_iters--; 649 | if (slice_iters == 0) 650 | break; 651 | } 652 | a_gl_rd += a_gl_rd_delta_o * stages; 653 | 654 | // Process results and, if necessary, proceed to the next column slice. 
While this pattern may not be the most 655 | // readable, other ways of writing the loop seemed to lead to noticeably worse performance after compilation. 656 | if (slice_iters == 0) { 657 | cp_async_wait<0>(); 658 | bool last = slice_idx == slice_count - 1; 659 | // For per-column scales, we only fetch them here in the final step before write-out 660 | if (group_blocks == -1 && last) { 661 | if (s_sh_wr_pred) 662 | cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); 663 | cp_async_fence(); 664 | } 665 | thread_block_reduce(); 666 | if (group_blocks == -1 && last) { 667 | cp_async_wait<0>(); 668 | __syncthreads(); 669 | if (threadIdx.x / 32 < thread_n_blocks / 4) { 670 | reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; 671 | reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; 672 | } 673 | } 674 | if (slice_count > 1) { // only globally reduce if there is more than one block in a slice 675 | barrier_acquire(&locks[slice_col], slice_idx); 676 | global_reduce(slice_idx == 0, last); 677 | barrier_release(&locks[slice_col], last); 678 | } 679 | if (last) // only the last block in a slice actually writes the result 680 | write_result(); 681 | slice_row = 0; 682 | slice_col_par++; 683 | slice_col++; 684 | init_slice(); 685 | if (slice_iters) { 686 | a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); 687 | #pragma unroll 688 | for (int i = 0; i < b_sh_wr_iters; i++) 689 | B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; 690 | if (slice_col == 0) { 691 | #pragma unroll 692 | for (int i = 0; i < b_sh_wr_iters; i++) 693 | B_ptr[i] -= b_gl_stride; 694 | } 695 | s_gl_rd = s_sh_stride * slice_col + threadIdx.x; 696 | start_pipes(); 697 | } 698 | } 699 | } 700 | } 701 | 702 | 703 | // 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per scheduler allows some more 704 | // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
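// (THREADS = 256 below corresponds to exactly these 8 warps per threadblock, i.e. 256 / 32.)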
705 | const int THREADS = 256; 706 | const int STAGES = 4; // 4 pipeline stages fit into shared memory 707 | const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (smaller than on 8.0) 708 | 709 | #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ 710 | else if ( \ 711 | thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ 712 | group_blocks == GROUP_BLOCKS \ 713 | ) { \ 714 | cudaFuncSetAttribute( \ 715 | Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \ 716 | cudaFuncAttributeMaxDynamicSharedMemorySize, \ 717 | SHARED_MEM \ 718 | ); \ 719 | Marlin< \ 720 | THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \ 721 | ><<<blocks, THREADS, SHARED_MEM, stream>>>( \ 722 | A_ptr, B_ptr, C_ptr, s_ptr, \ 723 | prob_m, prob_n, prob_k, \ 724 | locks \ 725 | ); \ 726 | } 727 | 728 | const int ERR_PROB_SHAPE = 1; 729 | const int ERR_KERN_SHAPE = 2; 730 | 731 | int marlin_cuda( 732 | const void* A, 733 | const void* B, 734 | void* C, 735 | void* s, 736 | int prob_m, 737 | int prob_n, 738 | int prob_k, 739 | void* workspace, 740 | int groupsize = -1, 741 | int dev = 0, 742 | cudaStream_t stream = 0, 743 | int thread_k = -1, 744 | int thread_n = -1, 745 | int sms = -1, 746 | int max_par = 16 747 | ) { 748 | int tot_m = prob_m; 749 | int tot_m_blocks = ceildiv(tot_m, 16); 750 | int pad = 16 * tot_m_blocks - tot_m; 751 | 752 | if (sms == -1) 753 | cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); 754 | if (thread_k == -1 || thread_n == -1) { 755 | if (prob_m <= 16) { 756 | // For small batchsizes, better partitioning is slightly more important than better compute utilization 757 | thread_k = 128; 758 | thread_n = 128; 759 | } else { 760 | thread_k = 64; 761 | thread_n = 256; 762 | } 763 | } 764 | 765 | int thread_k_blocks = thread_k / 16; 766 | int thread_n_blocks = thread_n / 16; 767 | int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; 768 | int blocks = sms; 769 | 770 | if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) 771 | return ERR_PROB_SHAPE; 772 | if (prob_m == 0 || prob_n == 0 || prob_k == 0) 773 | return 0; 774 | 775 | const int4* A_ptr = (const int4*) A; 776 | const int4* B_ptr = (const int4*) B; 777 | int4* C_ptr = (int4*) C; 778 | const int4* s_ptr = (const int4*) s; 779 | 780 | int cols = prob_n / thread_n; 781 | int* locks = (int*) workspace; 782 | 783 | int ret = 0; 784 | for (int i = 0; i < tot_m_blocks; i += 4) { 785 | int thread_m_blocks = tot_m_blocks - i; 786 | prob_m = tot_m - 16 * i; 787 | int par = 1; 788 | if (thread_m_blocks > 4) { 789 | // Note that parallel > 1 currently only works for inputs without any padding 790 | par = (16 * thread_m_blocks - pad) / 64; 791 | if (par > max_par) 792 | par = max_par; 793 | prob_m = 64 * par; 794 | i += 4 * (par - 1); 795 | thread_m_blocks = 4; 796 | } 797 | 798 | // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) 799 | // in our testing, however many more are, in principle, possible.
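// For example, `CALL_IF(1, 8, 8, -1)` below selects a kernel for 16x128 output tiles (thread_m_blocks = 1, thread_n = 8 * 16) with thread_k = 8 * 16 = 128 and per-column scales, while a final argument of 8 instead selects groupsize 8 * 16 = 128.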
800 | if (false) {} 801 | CALL_IF(1, 8, 8, -1) 802 | CALL_IF(1, 8, 8, 8) 803 | CALL_IF(1, 16, 4, -1) 804 | CALL_IF(1, 16, 4, 8) 805 | CALL_IF(2, 16, 4, -1) 806 | CALL_IF(2, 16, 4, 8) 807 | CALL_IF(3, 16, 4, -1) 808 | CALL_IF(3, 16, 4, 8) 809 | CALL_IF(4, 16, 4, -1) 810 | CALL_IF(4, 16, 4, 8) 811 | else 812 | ret = ERR_KERN_SHAPE; 813 | 814 | A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; 815 | C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; 816 | } 817 | 818 | return ret; 819 | } 820 | 821 | 822 | #endif 823 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils import cpp_extension 3 | 4 | setup( 5 | name='marlin', 6 | version='0.1.1', 7 | author='Elias Frantar', 8 | author_email='elias.frantar@ist.ac.at', 9 | description='Highly optimized FP16xINT4 CUDA matmul kernel.', 10 | install_requires=['numpy', 'torch'], 11 | packages=['marlin'], 12 | ext_modules=[cpp_extension.CUDAExtension( 13 | 'marlin_cuda', ['marlin/marlin_cuda.cpp', 'marlin/marlin_cuda_kernel.cu'] 14 | )], 15 | cmdclass={'build_ext': cpp_extension.BuildExtension}, 16 | ) 17 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | import marlin 8 | 9 | 10 | seed = 0 11 | np.random.seed(seed) 12 | torch.random.manual_seed(seed) 13 | 14 | DEV = torch.device('cuda:0') 15 | 16 | 17 | def gen_quant4(m, n, groupsize=-1): 18 | tile = 16 19 | maxq = 2 ** 4 - 1 20 | w = torch.randn((m, n), dtype=torch.half, device=DEV) 21 | if groupsize != -1: 22 | w = w.reshape((-1, groupsize, n)) 23 | w = w.permute(1, 0, 2) 24 | w = w.reshape((groupsize, -1)) 25 | s = torch.max(torch.abs(w), 0, keepdim=True)[0] 26 | s *= 2 / maxq 27 | w = torch.round(w / s).int() 28 | w += (maxq + 1) // 2 29 | w = torch.clamp(w, 0, maxq) 30 | ref = (w - (maxq + 1) // 2).half() * s 31 | if groupsize != -1: 32 | def reshape(w): 33 | w = w.reshape((groupsize, -1, n)) 34 | w = w.permute(1, 0, 2) 35 | w = w.reshape((m, n)).contiguous() 36 | return w 37 | ref = reshape(ref) 38 | w = reshape(w) 39 | s = s.reshape((-1, n)).contiguous() 40 | linear = nn.Linear(m, n) 41 | linear.weight.data = ref.t() 42 | # Workaround to test some special cases that are forbidden by the API 43 | layer = marlin.Layer(256, 256, groupsize=groupsize) 44 | if groupsize == -1: 45 | groupsize = m 46 | layer.k = m 47 | layer.n = n 48 | layer.groupsize = groupsize 49 | layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=DEV) 50 | layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device=DEV) 51 | layer.pack(linear, s.t()) 52 | q = layer.B 53 | s = layer.s 54 | return ref, q, s 55 | 56 | class Test(unittest.TestCase): 57 | 58 | def run_problem(self, m, n, k, thread_k, thread_n, groupsize=-1): 59 | print('% 5d % 6d % 6d % 4d % 4d % 4d' % (m, n, k, thread_k, thread_n, groupsize)) 60 | A = torch.randn((m, k), dtype=torch.half, device=DEV) 61 | B_ref, B, s = gen_quant4(k, n, groupsize=groupsize) 62 | C = torch.zeros((m, n), dtype=torch.half, device=DEV) 63 | C_ref = torch.matmul(A, B_ref) 64 | workspace = torch.zeros(n // 128 * 16, device=DEV) 65 | marlin.mul(A, B, C, s, workspace, thread_k, thread_n, -1) 66 | torch.cuda.synchronize() 67 | self.assertLess(torch.mean(torch.abs(C - C_ref)) 
/ torch.mean(torch.abs(C_ref)), 0.001) 68 | 69 | def test_tiles(self): 70 | print() 71 | for m in [1, 2, 3, 4, 8, 12, 16, 24, 32, 48, 64, 118, 128, 152, 768, 1024]: 72 | for thread_k, thread_n in [(64, 256), (128, 128)]: 73 | if m > 16 and thread_k == 128: 74 | continue 75 | self.run_problem(m, 2 * 256, 1024, thread_k, thread_n) 76 | 77 | def test_k_stages_divisibility(self): 78 | print() 79 | for k in [3 * 64 + 64 * 4 * 2 + 64 * i for i in range(1, 4)]: 80 | self.run_problem(16, 2 * 256, k, 64, 256) 81 | 82 | def test_very_few_stages(self): 83 | print() 84 | for k in [64, 128, 192]: 85 | self.run_problem(16, 2 * 256, k, 64, 256) 86 | 87 | def test_llama_shapes(self): 88 | print() 89 | return 90 | MODELS = { 91 | ' 7B': [ 92 | (4096, 3 * 4096), 93 | (4096, 4096), 94 | (4096, 2 * 10752), 95 | (10752, 4096) 96 | ], 97 | '13B': [ 98 | (5120, 3 * 5120), 99 | (5120, 5120), 100 | (5120, 2 * 13568), 101 | (13568, 5120) 102 | ], 103 | '33B': [ 104 | (6656, 3 * 6656), 105 | (6656, 6656), 106 | (6656, 2 * 17664), 107 | (17664, 6656) 108 | ], 109 | '70B': [ 110 | (8192, 3 * 8192), 111 | (8192, 8192), 112 | (8192, 2 * 21760), 113 | (21760, 8192) 114 | ] 115 | } 116 | for _, layers in MODELS.items(): 117 | for layer in layers: 118 | for thread_k, thread_n in [(128, 128)]: 119 | for batch in [1, 16]: 120 | self.run_problem(batch, layer[1], layer[0], thread_k, thread_n) 121 | 122 | def test_errors(self): 123 | print() 124 | m, n, k = 16, 256, 64 125 | A = torch.randn((m, k), dtype=torch.half, device=DEV) 126 | B_ref, B, s = gen_quant4(k, n) 127 | C = torch.zeros((m, n), dtype=torch.half, device=DEV) 128 | workspace = torch.zeros(n // 128, device=DEV) 129 | err = False 130 | try: 131 | marlin.mul(A, B, C, s, workspace, 128, 128, -1) 132 | except: 133 | err = True 134 | self.assertTrue(err) 135 | err = False 136 | try: 137 | marlin.mul(A, B, C, s, workspace, 256, 256, -1) 138 | except: 139 | err = True 140 | self.assertTrue(err) 141 | s = torch.zeros((2, n), dtype=torch.half, device=DEV) 142 | err = False 143 | try: 144 | marlin.mul(A, B, C, s, workspace, 256, 256, -1) 145 | except: 146 | err = True 147 | self.assertTrue(err) 148 | 149 | def test_groups(self): 150 | print() 151 | for m in [16]: 152 | for groupsize in [128]: 153 | for n, k in [(256, 512), (256, 1024), (256 * 128, 1024)]: 154 | for thread_shape in [(128, 128), (64, 256)]: 155 | self.run_problem(m, n, k, *thread_shape, groupsize) 156 | 157 | 158 | if __name__ == '__main__': 159 | unittest.main() 160 | --------------------------------------------------------------------------------
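Example: minimal end-to-end usage of the Marlin layer. This is an illustrative sketch rather than part of the repository files above; it assumes the extension has been built via `setup.py` and an Ampere-class GPU is available, and it fakes the quantization step that `gptq/llama2.py` would normally perform (shapes and names are illustrative).

import torch
import torch.nn as nn

import marlin

# Shapes must satisfy the `Layer` constraints: infeatures % 128 == 0 and outfeatures % 256 == 0.
k, n, groupsize = 4096, 4096, 128
layer = marlin.Layer(k, n, groupsize=groupsize).cuda()

# Fake-quantize a half-precision linear layer so that `pack()` can recover the int4 grid exactly
# (in practice the fake-quantized weights and scales come from GPTQ or the RTN baseline).
linear = nn.Linear(k, n, bias=False).half().cuda()
w = linear.weight.data.t().reshape((-1, groupsize, n))   # (k // groupsize, groupsize, n)
s = w.abs().amax(dim=1) * 2 / (2 ** 4 - 1)               # symmetric per-group scales of shape (k // groupsize, n)
q = torch.clamp(torch.round(w / s.unsqueeze(1)), -8, 7)
linear.weight.data = (q * s.unsqueeze(1)).reshape((k, n)).t().contiguous()

layer.pack(linear, s.t())                                # transposed scales, as in test.py's gen_quant4()
A = torch.randn((16, k), dtype=torch.half, device='cuda')
C = layer(A)                                             # FP16xINT4 matmul via marlin.mul()

For a whole model, `replace_linear()` in `marlin/__init__.py` performs this swap for every matching `nn.Linear` before the packed weights are loaded.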