├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── marlin.png
│   ├── models.png
│   ├── peak.png
│   └── sustained.png
├── bench.py
├── gptq
│   ├── datautils.py
│   ├── eval.py
│   ├── gptq.py
│   ├── llama2.py
│   └── quant.py
├── marlin
│   ├── __init__.py
│   ├── marlin_cuda.cpp
│   └── marlin_cuda_kernel.cu
├── setup.py
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | profile.ncu-rep
4 | *.egg-info*
5 | backup
6 | marlin/__pycache__
7 | _backup
8 | gptq/__pycache__
9 | gptq/*.marlin.g128
10 | __pycache__
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
[figure: assets/marlin.png (Marlin logo)]
3 |
4 |
5 | # Marlin
6 |
7 | This is Marlin, a **M**ixed **A**uto-**R**egressive **Lin**ear kernel (and the name of one of the planet's fastest fish), an extremely optimized FP16xINT4 matmul kernel aimed at LLM inference that can deliver close to ideal (4x)
8 | speedups up to batchsizes of 16-32 tokens (in contrast to the 1-2 tokens of prior work with comparable speedup). This makes Marlin well suited for larger-scale
9 | serving, speculative decoding or advanced multi-inference schemes such as CoT-Majority.
10 |
11 | ## Techniques:
12 |
13 | Most modern GPUs feature FLOP to byte ratios of around 100-200.
14 | Hence, as long as we perform less than 25-50 (tensor core) multiply-accumulates per 4-bit quantized weight, it should (theoretically) be possible to maintain near ideal 4x speedup over FP16 weights.
15 | This means that the full performance benefits of weight-only quantization should, in principle, extend to batchsizes 4-8x larger than what is currently achieved by existing kernels.
16 | However, actually realizing this in practice is very challenging, since we essentially need to fully utilize all available GPU resources (global memory, L2 cache, shared memory, tensor cores, vector cores), *simultaneously*.
17 | Marlin accomplishes this through numerous techniques and optimizations, briefly sketched below:
18 |
19 | * We organize computation in such a way that all activations are essentially always fetched from L2 cache and are further reused several times within registers to make sure that repeated loading from shared memory does not become a bottleneck either.
20 | * We execute global weight loads asynchronously with respect to all compute operations and also to activation loads, using a cache policy that allows immediate eviction so as not to unnecessarily pollute the L2 cache with values that are never reused.
21 | * We perform shared memory loads, whose footprint is quite significant due to relatively large activations, via double buffering to overlap them with computation and global loads.
22 | * We carefully order dequantization and tensor core instructions to ensure that both GPU pipelines are well saturated and do not bottleneck each other.
23 | * In general, both quantized weights and group scales are reshuffled offline, into a layout that gives ideal access patterns during execution, allowing for instance directly dequantizing weights into tensor core organization.
24 | * We let multiple warps in a threadblock compute partial results of the same output tile, in order to achieve higher warp counts (maximizing compute and latency hiding) without increasing the output tile size, which would make good partitioning on realistic matrices difficult.
25 | * All loads use maximum vector length for peak efficiency and we also perform several layout transformations to guarantee that all shared memory reads and writes are conflict-free, in particular for matrix loading instructions, and that global reduction happens at minimal memory overhead.
26 | * We set up and unroll loops such that the majority of memory offsets are static, minimizing runtime index calculations.
27 | * We implement a "striped" partitioning scheme where the segment of tiles processed by each SM may (partially) span over multiple column "slices". This leads to good SM utilization on most matrix shapes, while minimizing the required global reduction steps.
28 | * Global reduction happens directly in the output buffer (temporarily downcasting FP32 accumulators to FP16) which is kept in L2 cache; reduction operations are generally optimized to avoid any unnecessary reads or writes as well.
29 | * Overall, the kernel's PTX assembly was extensively analyzed in Nsight Compute, and the CUDA code contains several seemingly redundant or slightly suboptimal constructions that nevertheless compile to faster PTX.
30 |
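As a back-of-the-envelope check of the arithmetic-intensity argument above, the short calculation below plugs in rough peak numbers for an A10-class GPU (the exact figures are only illustrative):

```
# Roofline-style estimate: how many tensor-core MACs can we afford per 4-bit weight
# before an FP16xINT4 matmul becomes compute-bound? (Rough A10-like peak numbers.)
peak_flops = 125e12   # FP16 tensor-core throughput, FLOP/s
peak_bw = 600e9       # memory bandwidth, byte/s

flop_per_byte = peak_flops / peak_bw        # ~208 FLOPs per byte moved
macs_per_weight = flop_per_byte * 0.5 / 2   # 0.5 byte/weight, 1 MAC = 2 FLOPs -> ~52

# Each weight takes part in one MAC per token in the batch, so weight loading stays the
# bottleneck (and the ~4x speedup is preserved) roughly up to this batchsize.
print('%.0f FLOP/byte, ~%.0f MACs per 4-bit weight' % (flop_per_byte, macs_per_weight))
```
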
31 | ## Benchmarks:
32 |
33 | We first compare the performance of Marlin with other popular 4-bit inference kernels, on a large matrix that can be
34 | ideally partitioned on an NVIDIA A10 GPU. This allows all kernels to reach essentially their best possible performance.
35 | All kernels are executed at groupsize 128 (however, we note that scale formats are not 100% identical).
36 |
37 |
38 |
[figure: assets/peak.png (kernel speedups over FP16 across batchsizes on a large, ideally partitionable matrix, A10)]
39 |
40 |
41 | While existing kernels achieve relatively close to the optimal 3.87x (note the 0.125 bits storage overhead of the
42 | group scales) speedup at batchsize 1, their performance degrades quickly as the number of inputs is increased. In
43 | contrast, Marlin delivers essentially ideal speedups at all batchsizes, enabling the maximum possible 3.87x speedup up
44 | to batchsizes around 16-32.
45 |
46 | Due to its striped partitioning scheme, Marlin also delivers strong performance on real (smaller) matrices and various GPUs.
47 | This is demonstrated by the below results, where we benchmark, at batchsize 16, the overall runtime across all linear
48 | layers in Transformer blocks of popular open-source models.
49 |
50 |
51 |
[figure: assets/models.png (runtime of all Transformer-block linear layers at batchsize 16 for popular open-source models)]
52 |
53 |
54 | Finally, we also study what performance can be sustained over longer periods of time, at locked base GPU clock.
55 | Interestingly, we find that reduced clock speeds significantly harm the relative speedups of prior kernels, but have no
56 | effect on Marlin's virtually optimal performance (relative to the lower clock setting).
57 |
58 |
59 |
[figure: assets/sustained.png (sustained performance at locked base GPU clock)]
60 |
61 |
62 | ## Requirements:
63 |
64 | * CUDA >= 11.8 (in particular also for the `nvcc` compiler, whose version should match the one used to build torch)
65 | * NVIDIA GPU with compute capability >= 8.0 (Ampere or Ada, Marlin is not yet optimized for Hopper)
66 | * `torch>=2.0.0`
67 | * `numpy`
68 | For running the quantization scripts, one additionally needs:
69 | * `transformers`
70 | * `datasets`
71 | * `sentencepiece`
72 |
73 | ## Usage:
74 |
75 | If all requirements are met, it should be possible to install Marlin by calling
76 |
77 | ```
78 | pip install .
79 | ```
80 |
81 | in the root folder of this repository.
82 |
83 | Afterwards, the easiest way to use the Marlin kernel is via a `marlin.Layer`, a torch-module representing a Marlin
84 | quantized layer. It allows converting a "fake-quantized" (dequantized values stored in FP16) `torch.nn.Linear` layer into
85 | the compressed Marlin format via `marlin.Layer.pack(linear, scales)`. Alternatively, the kernel can also be called
86 | directly through `marlin.mul(..)`, provided that weights and scales have already been appropriately preprocessed (see
87 | `marlin.Layer.pack(...)`). The kernel itself can be found in the self-contained `marlin/marlin_cuda_kernel.cu` file,
88 | which does not contain any dependencies beyond base-CUDA and should thus be easy to integrate into other lower-level
89 | frameworks.
90 |
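For illustration, the snippet below is a minimal sketch of calling the kernel directly, mirroring what `bench.py` does; the shapes are arbitrary and `B` is filled with random bits rather than actually packed weights (fine for trying out the call, not for correctness):

```
import torch
import marlin

m, k, n, groupsize = 16, 4096, 4096, 128   # example shapes
dev = torch.device('cuda:0')

A = torch.randn((m, k), dtype=torch.half, device=dev)                          # FP16 activations
B = torch.randint(-2**31, 2**31, (k * n // 8,), dtype=torch.int, device=dev)   # "packed" INT4 weights
s = torch.randn((k // groupsize, n), dtype=torch.half, device=dev)             # group scales
C = torch.zeros((m, n), dtype=torch.half, device=dev)                          # output buffer
workspace = torch.zeros(n // 128 * 16, dtype=torch.int, device=dev)            # must be all zeros

marlin.mul(A, B, C, s, workspace)   # result is written into C
```

For real models, `B` and `s` are produced by `marlin.Layer.pack(...)` from a fake-quantized `torch.nn.Linear` layer and its group scales.
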
91 | Correctness tests can be executed via `python test.py` and benchmarks via `python bench.py`. Please note that in order
92 | to reproduce our "sustainable performance" benchmarks, the GPU clocks need to be locked to their respective base values
93 | using:
94 |
95 | ```
96 | sudo nvidia-smi --lock-gpu-clocks=BASE_GPU_CLOCK --lock-memory-clocks=BASE_MEM_CLOCK
97 | ```
98 |
99 | Additionally, if ECC is enabled (e.g., on an A10), then the maximum achievable memory bandwidth will be 10-15% lower
100 | than in the official spec sheet, as every memory request carries checksum overhead. ECC can be disabled via
101 |
102 | ```
103 | sudo nvidia-smi -e 0
104 | ```
105 |
106 | which we do in our A10 benchmarks.
107 |
108 | ## GPTQ Example:
109 |
110 | In the `gptq` subfolder, we also provide a slightly improved version of the [GPTQ](https://github.com/IST-DASLab/gptq) algorithm, with better group grid clipping and non-uniform calibration sample length, that can produce Marlin-compatible 4-bit versions of Llama2 models.
111 | Additionally, there is a script to evaluate such compressed models (using Marlin kernels) in the popular [LLM eval harness](https://github.com/EleutherAI/lm-evaluation-harness).
112 | The script below was tested with `lm-eval-harness==0.4.0` and may not work with newer or older versions.
113 | Here are corresponding sample commands (`marlin`, `transformers` and `datasets` packages must be installed):
114 |
115 | ```
116 | # Compress the Llama2 model and export it in Marlin format.
117 | python llama2.py LLAMA2_CHECKPOINT --wbits 4 --save checkpoint
118 | # Perform perplexity evaluation of the uncompressed model.
119 | python llama2.py LLAMA2_CHECKPOINT
120 | # Evaluate the compressed model (with Marlin kernels) in the eval harness.
121 | python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu \
122 |     --marlin_checkpoint checkpoint.marlin.g128
123 | # Evaluate the full-precision baseline.
124 | python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu
125 | ```
126 |
127 | We measure the following WikiText and Red-Pajama perplexities, as well as MMLU zero-shot accuracy, for 4-bit (group=128) Marlin models:
128 |
129 | | Llama2 | Wiki2 (FP16) | Wiki2 (INT4) | RedPaj (FP16) | RedPaj (INT4) | MMLU (FP16) | MMLU (INT4) |
130 | |:---:|:----:|:----:|:----:|:----:|:-----:|:-----:|
131 | | 7B | 5.12 | 5.27 | 6.14 | 6.30 | 41.80 | 40.07 |
132 | | 13B | 4.57 | 4.67 | 5.67 | 5.79 | 52.10 | 51.13 |
133 | | 70B | 3.12 | 3.21 | 4.74 | 4.81 | 65.43 | 64.81 |
134 |
135 | We note that this GPTQ example is currently intended mostly as a demonstration of how to produce accurate Marlin models and as an end-to-end validation of kernel correctness (rather than to be a flexible compression tool).
136 |
137 | ## Cite:
138 |
139 | If you found this work useful, please consider citing:
140 |
141 | ```
142 | @article{frantar2024marlin,
143 | title={MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models},
144 | author={Frantar, Elias and Castro, Roberto L and Chen, Jiale and Hoefler, Torsten and Alistarh, Dan},
145 | journal={arXiv preprint arXiv:2408.11743},
146 | year={2024}
147 | }
148 | ```
149 |
--------------------------------------------------------------------------------
/assets/marlin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/marlin.png
--------------------------------------------------------------------------------
/assets/models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/models.png
--------------------------------------------------------------------------------
/assets/peak.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/peak.png
--------------------------------------------------------------------------------
/assets/sustained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IST-DASLab/marlin/1f25790bdd49fba53106164a24666dade68d7c90/assets/sustained.png
--------------------------------------------------------------------------------
/bench.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import numpy as np
4 | import torch
5 | import marlin
6 |
7 | import time
8 |
9 | def benchmark(f, warmup=1, iter=10):
10 | for i in range(warmup + iter):
11 | f()
12 |     # We do not synchronize after every call in order to hide the kernel launch overhead during benchmarking,
13 |     # since this overhead is also hidden during realistic model inference, where many launches are queued up.
14 | if i == warmup - 1:
15 | torch.cuda.synchronize()
16 | tick = time.time()
17 | torch.cuda.synchronize()
18 | res = (time.time() - tick) / iter
19 |     # Make sure there is enough time for the GPU to "cool down" between benchmarks, to avoid throttling of
20 |     # later runs when we execute many benchmarks consecutively.
21 | time.sleep(1.)
22 | return res
23 |
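# Construct random data for an (m, k) x (k, n) problem: `B` holds random bits standing in for
# k*n 4-bit weights packed 8-per-int32, which is sufficient for timing (though not for correctness checks).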
24 | def get_problem(m, n, k, groupsize=-1):
25 | if groupsize == -1:
26 | groupsize = k
27 | dev = torch.device('cuda:0')
28 | A = torch.randn((m, k), dtype=torch.half, device=dev)
29 | B = torch.randint(low=-2**31, high=2**31, size=(k * n // 8,), device=dev)
30 | B_ref = torch.randn((k, n), dtype=torch.half, device=dev)
31 | C = torch.zeros((m, n), dtype=torch.half, device=dev)
32 | s = torch.zeros((k // groupsize, n), dtype=torch.half, device=dev)
33 | torch.cuda.synchronize()
34 | return A, B, C, B_ref, s
35 |
36 | def benchmark_dense(A, B, C):
37 | res = benchmark(lambda: torch.matmul(A, B, out=C))
38 | return {
39 | 's': res,
40 | 'TFLOP/s': 2 * A.numel() * C.shape[1] / res / 10 ** 12,
41 | 'GB/s': (2 * A.numel() + 2 * B.numel() + 2 * C.numel()) / res / 10 ** 9
42 | }
43 |
44 | def benchmark_quant(A, B, C, s, thread_k, thread_n, sms):
45 | workspace = torch.zeros(C.shape[1] // 128 * 16, device=torch.device('cuda:0'))
46 | res = benchmark(lambda: marlin.mul(A, B, C, s, workspace, thread_k, thread_n, sms))
47 | return {
48 | 's': res,
49 | 'TFLOP/s': 2 * A.numel() * C.shape[1] / res / 10 ** 12,
50 | 'GB/s': (2 * A.numel() + 4 * B.numel() + 2 * C.numel() + 2 * s.numel()) / res / 10 ** 9
51 | }
52 |
53 | # Pass the SM count for known GPUs to avoid the kernel having to query this information (this is very minor)
54 | gpu = torch.cuda.get_device_name(0)
55 | if 'A100' in gpu:
56 | SMS = 108
57 | elif 'A10' in gpu:
58 | SMS = 72
59 | elif '3090' in gpu:
60 | SMS = 82
61 | elif 'A6000' in gpu:
62 | SMS = 84
63 | else:
64 | SMS = -1
65 |
66 | MODELS = {
67 | 'ideal': [
68 | (4 * 256 * SMS, 256 * SMS)
69 | ],
70 | 'Llama7B': [
71 | (4096, 3 * 4096),
72 | (4096, 4096),
73 | (4096, 2 * 10752),
74 | (10752, 4096)
75 | ],
76 | 'Llama13B': [
77 | (5120, 3 * 5120),
78 | (5120, 5120),
79 | (5120, 2 * 13568),
80 | (13568, 5120)
81 | ],
82 | 'Llama33B': [
83 | (6656, 3 * 6656),
84 | (6656, 6656),
85 | (6656, 2 * 17664),
86 | (17664, 6656)
87 | ],
88 | 'Llama65B': [
89 | (8192, 3 * 8192),
90 | (8192, 8192),
91 | (8192, 2 * 21760),
92 | (21760, 8192)
93 | ],
94 | 'Falcon180B': [
95 |         # Note that parallel attention and FC allow layer fusions
96 | (14848, 14848 * 5 + 1024),
97 | (14848 * 5, 14848)
98 | ]
99 | }
100 |
101 | # Set to True in order to run a more complete benchmark sweep; the default is to reproduce the README experiments
102 | ALL = False
103 |
104 | for groupsize in [-1, 128] if ALL else [128]:
105 | print('groupsize=%d' % groupsize)
106 | print()
107 | for model, layers in MODELS.items():
108 | print(model)
109 | if ALL:
110 | batchsizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
111 | else:
112 | batchsizes = [1, 2, 4, 8, 16, 32, 64, 128]
113 | for batch in batchsizes:
114 | if not ALL and model != 'ideal' and batch != 16:
115 | continue
116 | tot_q = {'s': 0, 'TFLOP/s': 0, 'GB/s': 0, 'speedup': 0}
117 | for layer in layers:
118 | A, B, C, B_ref, s = get_problem(batch, layer[1], layer[0], groupsize)
119 | res_d = benchmark_dense(A, B_ref, C)
120 | if model == 'ideal' and batch == 16:
121 | # This is a special case constructed to be optimal for a thread-shape different than the default one
122 | res_q = benchmark_quant(A, B, C, s, 64, 256, SMS)
123 | else:
124 | res_q = benchmark_quant(A, B, C, s, -1, -1, SMS)
125 | res_q['speedup'] = res_d['s'] / res_q['s']
126 | tot_q['s'] += res_q['s']
127 | for k in tot_q:
128 | if k != 's':
129 | tot_q[k] += res_q[k] * res_q['s']
130 | for k in tot_q:
131 | if k != 's':
132 | tot_q[k] /= tot_q['s']
133 | print('batch=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f' % (
134 | batch,
135 | tot_q['s'],
136 | tot_q['TFLOP/s'],
137 | tot_q['GB/s'],
138 | tot_q['speedup']
139 | ))
140 | print()
141 |
--------------------------------------------------------------------------------
/gptq/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def set_seed(seed):
6 | np.random.seed(seed)
7 | torch.random.manual_seed(seed)
8 |
9 |
10 | def get_wikitext2(nsamples, seed, seqlen, model):
11 | from datasets import load_dataset
12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
14 |
15 | from transformers import AutoTokenizer
16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
17 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
19 |
20 | import random
21 | random.seed(seed)
22 | trainloader = []
23 | for _ in range(nsamples):
24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
25 | j = i + seqlen
26 | inp = trainenc.input_ids[:, i:j]
27 | tar = inp.clone()
28 | tar[:, :-1] = -100
29 | trainloader.append((inp, tar))
30 | testloader = []
31 | for i in range(0, testenc.input_ids.shape[1] - seqlen, seqlen):
32 | testloader.append(testenc.input_ids[:, i:(i + seqlen)])
33 |
34 | return trainloader, testloader
35 |
36 | def get_red(nsamples, seed, seqlen, model):
37 | VALSAMPLES = 1024
38 |
39 | from datasets import load_dataset
40 | from transformers import AutoTokenizer
41 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
42 | traindata = load_dataset('togethercomputer/RedPajama-Data-1T-Sample', split='train')
43 |
44 | np.random.seed(0)
45 | perm = np.random.permutation(len(traindata))
46 |
47 | dataloader = []
48 | for i in perm:
49 | tokens = tokenizer(traindata[int(i)]['text'], return_tensors='pt').input_ids
50 | if not (1 < tokens.shape[1] <= seqlen):
51 | continue
52 | dataloader.append(tokens)
53 | if len(dataloader) == nsamples + VALSAMPLES:
54 | break
55 | trainloader = dataloader[VALSAMPLES:]
56 | testloader = dataloader[:VALSAMPLES]
57 | return trainloader, testloader
58 |
59 |
60 | def get_loaders(
61 | name, nsamples=256, seed=0, seqlen=2048, model=''
62 | ):
63 | if 'wikitext2' in name:
64 | return get_wikitext2(nsamples, seed, seqlen, model)
66 | if 'red' in name:
67 | return get_red(nsamples, seed, seqlen, model)
68 |
69 |
--------------------------------------------------------------------------------
/gptq/eval.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/__main__.py
2 | # with minor modifications for Marlin checkpoint loading as I didn't find an easy way to call `lm_eval.cli_evaluate(...)` directly
3 |
4 |
5 | import marlin
6 |
7 | # Save checkpoint name here since passing around extra args seems to confuse the eval harness
8 | MARLIN_CHECKPOINT = ''
9 |
10 | def get_llama_marlin(name, *args, **kwargs):
11 | import torch
12 | def skip(*args, **kwargs):
13 | pass
14 | torch.nn.init.kaiming_uniform_ = skip
15 | torch.nn.init.uniform_ = skip
16 | torch.nn.init.normal_ = skip
17 | from transformers import LlamaForCausalLM
18 | model = LlamaForCausalLM.from_pretrained(name, torch_dtype='auto')
19 | # Not really sure why this is sometimes > 1, but it messes up quantized inference ...
20 | # Fortunately, just setting it to 1 doesn't seem to affect standard inference
21 | model.config.pretraining_tp = 1
22 | def name_filter(n):
23 | if 'q_proj' in n or 'k_proj' in n or 'v_proj' in n or 'o_proj' in n:
24 | return True
25 | if 'mlp.gate_proj' in n or 'mlp.up_proj' in n or 'mlp.down_proj' in n:
26 | return True
27 | return False
28 | groupsize = -1 if MARLIN_CHECKPOINT.endswith('marlin') else 128
29 | marlin.replace_linear(model, name_filter, groupsize=groupsize)
30 | model.load_state_dict(torch.load(MARLIN_CHECKPOINT))
31 | return model
32 |
33 |
34 | import argparse
35 | import json
36 | import logging
37 | import os
38 | import re
39 | import sys
40 | from pathlib import Path
41 | from typing import Union
42 |
43 | import numpy as np
44 |
45 | from lm_eval import evaluator, utils
46 | from lm_eval.api.registry import ALL_TASKS
47 | from lm_eval.tasks import include_path, initialize_tasks
48 | from lm_eval.utils import make_table
49 |
50 |
51 | def _handle_non_serializable(o):
52 | if isinstance(o, np.int64) or isinstance(o, np.int32):
53 | return int(o)
54 | elif isinstance(o, set):
55 | return list(o)
56 | else:
57 | return str(o)
58 |
59 |
60 | def parse_eval_args() -> argparse.Namespace:
61 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
62 | parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
63 | parser.add_argument(
64 | "--tasks",
65 | "-t",
66 | default=None,
67 | metavar="task1,task2",
68 | help="To get full list of tasks, use the command lm-eval --tasks list",
69 | )
70 | parser.add_argument(
71 | "--model_args",
72 | "-a",
73 | default="",
74 | help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
75 | )
76 | parser.add_argument(
77 | "--num_fewshot",
78 | "-f",
79 | type=int,
80 | default=None,
81 | metavar="N",
82 | help="Number of examples in few-shot context",
83 | )
84 | parser.add_argument(
85 | "--batch_size",
86 | "-b",
87 | type=str,
88 | default=1,
89 | metavar="auto|auto:N|N",
90 | help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
91 | )
92 | parser.add_argument(
93 | "--max_batch_size",
94 | type=int,
95 | default=None,
96 | metavar="N",
97 | help="Maximal batch size to try with --batch_size auto.",
98 | )
99 | parser.add_argument(
100 | "--device",
101 | type=str,
102 | default=None,
103 | help="Device to use (e.g. cuda, cuda:0, cpu).",
104 | )
105 | parser.add_argument(
106 | "--output_path",
107 | "-o",
108 | default=None,
109 | type=str,
110 | metavar="DIR|DIR/file.json",
111 | help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
112 | )
113 | parser.add_argument(
114 | "--limit",
115 | "-L",
116 | type=float,
117 | default=None,
118 |         metavar="N|0<N<1",
    )
    # ... further standard lm-eval-harness CLI arguments (verbosity, use_cache, include_path,
    # log_samples, show_config, gen_kwargs, ...) and the custom `--marlin_checkpoint` flag
    # consumed below are defined here ...
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
189 | if not args:
190 | # we allow for args to be passed externally, else we parse them ourselves
191 | args = parse_eval_args()
192 | if args.marlin_checkpoint:
193 | global MARLIN_CHECKPOINT
194 | MARLIN_CHECKPOINT = args.marlin_checkpoint
195 | del args.marlin_checkpoint
196 | # Overwrite model load with marlin load
197 | import transformers
198 | transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin)
199 |
200 | eval_logger = utils.eval_logger
201 | eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
202 | eval_logger.info(f"Verbosity set to {args.verbosity}")
203 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
204 |
205 | initialize_tasks(args.verbosity)
206 |
207 | if args.limit:
208 | eval_logger.warning(
209 |             " --limit SHOULD ONLY BE USED FOR TESTING. "
210 | "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
211 | )
212 | if args.include_path is not None:
213 | eval_logger.info(f"Including path: {args.include_path}")
214 | include_path(args.include_path)
215 |
216 | if args.tasks is None:
217 | task_names = ALL_TASKS
218 | elif args.tasks == "list":
219 | eval_logger.info(
220 | "Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
221 | )
222 | sys.exit()
223 | else:
224 | if os.path.isdir(args.tasks):
225 | import glob
226 |
227 | task_names = []
228 | yaml_path = os.path.join(args.tasks, "*.yaml")
229 | for yaml_file in glob.glob(yaml_path):
230 | config = utils.load_yaml_config(yaml_file)
231 | task_names.append(config)
232 | else:
233 | tasks_list = args.tasks.split(",")
234 | task_names = utils.pattern_match(tasks_list, ALL_TASKS)
235 | for task in [task for task in tasks_list if task not in task_names]:
236 | if os.path.isfile(task):
237 | config = utils.load_yaml_config(task)
238 | task_names.append(config)
239 | task_missing = [
240 | task
241 | for task in tasks_list
242 | if task not in task_names and "*" not in task
243 | ] # we don't want errors if a wildcard ("*") task name was used
244 |
245 | if task_missing:
246 | missing = ", ".join(task_missing)
247 | eval_logger.error(
248 | f"Tasks were not found: {missing}\n"
249 | f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
250 | )
251 | raise ValueError(
252 | f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
253 | )
254 |
255 | if args.output_path:
256 | path = Path(args.output_path)
257 | # check if file or 'dir/results.json' exists
258 | if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
259 | eval_logger.warning(
260 | f"File already exists at {path}. Results will be overwritten."
261 | )
262 | output_path_file = path.joinpath("results.json")
263 | assert not path.is_file(), "File already exists"
264 | # if path json then get parent dir
265 | elif path.suffix in (".json", ".jsonl"):
266 | output_path_file = path
267 | path.parent.mkdir(parents=True, exist_ok=True)
268 | path = path.parent
269 | else:
270 | path.mkdir(parents=True, exist_ok=True)
271 | output_path_file = path.joinpath("results.json")
272 | elif args.log_samples and not args.output_path:
273 | assert args.output_path, "Specify --output_path"
274 |
275 | eval_logger.info(f"Selected Tasks: {task_names}")
276 |
277 | results = evaluator.simple_evaluate(
278 | model=args.model,
279 | model_args=args.model_args,
280 | tasks=task_names,
281 | num_fewshot=args.num_fewshot,
282 | batch_size=args.batch_size,
283 | max_batch_size=args.max_batch_size,
284 | device=args.device,
285 | use_cache=args.use_cache,
286 | limit=args.limit,
287 | decontamination_ngrams_path=args.decontamination_ngrams_path,
288 | check_integrity=args.check_integrity,
289 | write_out=args.write_out,
290 | log_samples=args.log_samples,
291 | gen_kwargs=args.gen_kwargs,
292 | )
293 |
294 | if results is not None:
295 | if args.log_samples:
296 | samples = results.pop("samples")
297 | dumped = json.dumps(
298 | results, indent=2, default=_handle_non_serializable, ensure_ascii=False
299 | )
300 | if args.show_config:
301 | print(dumped)
302 |
303 | batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
304 |
305 | if args.output_path:
306 | output_path_file.open("w").write(dumped)
307 |
308 | if args.log_samples:
309 | for task_name, config in results["configs"].items():
310 | output_name = "{}_{}".format(
311 | re.sub("/|=", "__", args.model_args), task_name
312 | )
313 | filename = path.joinpath(f"{output_name}.jsonl")
314 | samples_dumped = json.dumps(
315 | samples[task_name],
316 | indent=2,
317 | default=_handle_non_serializable,
318 | ensure_ascii=False,
319 | )
320 | filename.open("w").write(samples_dumped)
321 |
322 | print(
323 | f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
324 | f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
325 | )
326 | print(make_table(results))
327 | if "groups" in results:
328 | print(make_table(results, "groups"))
329 |
330 |
331 | cli_evaluate()
332 |
333 |
--------------------------------------------------------------------------------
/gptq/gptq.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import time
4 |
5 | import torch
6 | import torch.nn as nn
7 | import transformers
8 |
9 | from quant import *
10 |
11 |
12 | DEBUG = False
13 |
14 | torch.backends.cuda.matmul.allow_tf32 = False
15 | torch.backends.cudnn.allow_tf32 = False
16 |
17 |
18 | class GPTQ:
19 |
20 | def __init__(self, layer, stable=False):
21 | self.layer = layer
22 | self.dev = self.layer.weight.device
23 | W = layer.weight.data.clone()
24 | self.rows = W.shape[0]
25 | self.columns = W.shape[1]
26 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
27 | self.nsamples = 0
28 |
29 | self.stable = stable
30 | self.mean = torch.zeros((self.columns, 1), device=self.dev)
31 |
32 | def add_batch(self, inp, out):
33 | if DEBUG:
34 | self.inp1 = inp
35 | self.out1 = out
36 | if len(inp.shape) == 2:
37 | inp = inp.unsqueeze(0)
38 | tmp = inp.shape[0]
39 | if len(inp.shape) == 3:
40 | inp = inp.reshape((-1, inp.shape[-1]))
41 | inp = inp.t()
42 |
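        # "Stable" accumulation: keep a running mean and accumulate centered second moments via a
        # Chan-style parallel update; fasterquant() later renormalizes and adds back mean * mean^T.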
43 | if self.stable:
44 | inp = inp.float()
45 | delta = torch.mean(inp, 1, keepdims=True) - self.mean
46 | self.H += inp.matmul(inp.t()) + delta.matmul(delta.t()) * self.nsamples * tmp / (self.nsamples + tmp)
47 | self.nsamples += tmp
48 | self.mean += delta * tmp / self.nsamples
49 | else:
50 | self.H *= self.nsamples / (self.nsamples + tmp)
51 | self.nsamples += tmp
52 | inp = math.sqrt(2 / self.nsamples) * inp.float()
53 | self.H += inp.matmul(inp.t())
54 |
55 | def fasterquant(
56 | self, blocksize=128, percdamp=.1, groupsize=-1, clip=False, baseline=False
57 | ):
58 | W = self.layer.weight.data.clone()
59 | W = W.float()
60 |
61 | tick = time.time()
62 |
63 | if self.stable:
64 | self.H /= self.nsamples
65 | self.H += self.mean.matmul(self.mean.t())
66 | self.H *= 2
67 | H = self.H
68 | del self.H
69 |
70 | Losses = torch.zeros_like(W)
71 | Q = torch.zeros_like(W)
72 |
73 | if not baseline:
74 | try:
75 | damp = percdamp * torch.mean(torch.diag(H))
76 | diag = torch.arange(self.columns, device=self.dev)
77 | H[diag, diag] += damp
78 | H = torch.linalg.cholesky(H)
79 | H = torch.cholesky_inverse(H)
80 | H = torch.linalg.cholesky(H, upper=True)
81 | Hinv = H
82 | except:
83 | print('Singularity.')
84 | baseline = True
85 | if baseline:
86 | del H
87 | Hinv = torch.eye(self.columns, device=self.dev)
88 |
89 | if groupsize == -1:
90 | self.quantizer.find_params(W)
91 | groups = []
92 |
93 | for i1 in range(0, self.columns, blocksize):
94 | i2 = min(i1 + blocksize, self.columns)
95 | count = i2 - i1
96 |
97 | W1 = W[:, i1:i2].clone()
98 | Q1 = torch.zeros_like(W1)
99 | Err1 = torch.zeros_like(W1)
100 | Losses1 = torch.zeros_like(W1)
101 | Hinv1 = Hinv[i1:i2, i1:i2]
102 |
103 | if groupsize != -1:
104 | self.quantizer.find_params(W1, solve=Hinv1 if clip else None)
105 | groups.append(copy.deepcopy(self.quantizer))
106 |
107 | for i in range(count):
108 | w = W1[:, i]
109 | d = Hinv1[i, i]
110 |
111 | q = quantize(
112 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
113 | ).flatten()
114 |
115 | Q1[:, i] = q
116 | Losses1[:, i] = (w - q) ** 2 / d ** 2
117 |
118 | err1 = (w - q) / d
119 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
120 | Err1[:, i] = err1
121 |
122 | Q[:, i1:i2] = Q1
123 | Losses[:, i1:i2] = Losses1 / 2
124 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
125 |
126 | if DEBUG:
127 | self.layer.weight.data[:, :i2] = Q[:, :i2]
128 | self.layer.weight.data[:, i2:] = W[:, i2:]
129 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
130 | print(torch.sum(Losses))
131 |
132 | torch.cuda.synchronize()
133 | print('time %.2f' % (time.time() - tick))
134 | print('error', torch.sum(Losses).item())
135 |
136 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
137 | if DEBUG:
138 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
139 |
140 | if groups:
141 | scale = torch.cat([q.scale for q in groups], dim=1)
142 | zero = torch.cat([q.zero for q in groups], dim=1)
143 | return scale, zero
144 | return self.quantizer.scale, self.quantizer.zero
145 |
--------------------------------------------------------------------------------
/gptq/llama2.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from gptq import *
7 | from quant import *
8 | import marlin
9 |
10 |
11 | DEV = torch.device('cuda:0')
12 |
13 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
14 | if type(module) in layers:
15 | return {name: module}
16 | res = {}
17 | for name1, child in module.named_children():
18 | res.update(find_layers(
19 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
20 | ))
21 | return res
22 |
23 |
24 | def get_llama(name):
25 | import torch
26 | def skip(*args, **kwargs):
27 | pass
28 | torch.nn.init.kaiming_uniform_ = skip
29 | torch.nn.init.uniform_ = skip
30 | torch.nn.init.normal_ = skip
31 | from transformers import LlamaForCausalLM
32 | model = LlamaForCausalLM.from_pretrained(name, torch_dtype='auto')
33 | model.config.pretraining_tp = 1
34 | model.seqlen = 4096
35 | return model
36 |
37 | @torch.no_grad()
38 | def llama_sequential(model, dataloader, dev):
39 | print('Starting ...')
40 |
41 | use_cache = model.config.use_cache
42 | model.config.use_cache = False
43 | layers = model.model.layers
44 |
45 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
46 | model.model.norm = model.model.norm.to(dev)
47 | layers[0] = layers[0].to(dev)
48 |
49 | dtype = next(iter(model.parameters())).dtype
50 | inps = []
51 | attention_masks = []
52 | position_ids = []
53 |
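    # Wrap the first Transformer layer to record the calibration inputs (plus attention masks and
    # position ids), aborting each forward pass with a ValueError once they have been captured.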
54 | class Catcher(nn.Module):
55 | def __init__(self, module):
56 | super().__init__()
57 | self.module = module
58 | def forward(self, inp, **kwargs):
59 | inps.append(inp)
60 | attention_masks.append(kwargs['attention_mask'])
61 | position_ids.append(kwargs['position_ids'])
62 | raise ValueError
63 | layers[0] = Catcher(layers[0])
64 | for batch in dataloader:
65 | try:
66 | model(batch.to(dev))
67 | except ValueError:
68 | pass
69 | layers[0] = layers[0].module
70 |
71 | layers[0] = layers[0].cpu()
72 | model.model.embed_tokens = model.model.embed_tokens.cpu()
73 | model.model.norm = model.model.norm.cpu()
74 | torch.cuda.empty_cache()
75 |
76 | print('Ready.')
77 |
78 | quantizers = {}
79 | for i in range(len(layers)):
80 | layer = layers[i].to(dev)
81 | full = find_layers(layer)
82 |
83 | if args.true_sequential:
84 | sequential = [
85 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
86 | ['self_attn.o_proj'],
87 | ['mlp.up_proj', 'mlp.gate_proj'],
88 | ['mlp.down_proj']
89 | ]
90 | else:
91 | sequential = [list(full.keys())]
92 |
93 | for names in sequential:
94 | if model.config.num_attention_heads != model.config.num_key_value_heads and args.skip_gq:
95 | names.remove('self_attn.k_proj')
96 | names.remove('self_attn.v_proj')
97 |
98 | subset = {n: full[n] for n in names}
99 |
100 | gptq = {}
101 | for name in subset:
102 | gptq[name] = GPTQ(subset[name])
103 | gptq[name].quantizer = Quantizer()
104 | gptq[name].quantizer.configure(args.wbits)
105 |
106 | def add_batch(name):
107 | def tmp(_, inp, out):
108 | gptq[name].add_batch(inp[0].data, out.data)
109 | return tmp
110 | handles = []
111 | for name in subset:
112 | handles.append(subset[name].register_forward_hook(add_batch(name)))
113 | for j in range(args.nsamples):
114 | layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j])
115 | for h in handles:
116 | h.remove()
117 |
118 | for name in subset:
119 | print(i, name)
120 | print('Quantizing ...')
121 | res = gptq[name].fasterquant(
122 | percdamp=args.percdamp, groupsize=args.groupsize, clip=not args.no_clip, baseline=args.nearest
123 | )
124 | res = list(res)
125 | res[0] = res[0].cpu()
126 | res[1] = res[1].cpu()
127 | quantizers['model.layers.%d.%s' % (i, name)] = res
128 |
129 | for j in range(args.nsamples):
130 | inps[j] = layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j])[0]
131 |
132 | layers[i] = layer.cpu()
133 | del layer
134 | del gptq
135 | torch.cuda.empty_cache()
136 |
137 | model.config.use_cache = use_cache
138 | return quantizers
139 |
140 | @torch.no_grad()
141 | def llama_eval(model, dataloader, dev):
142 | print('Evaluating ...')
143 |
144 | nsamples = len(dataloader)
145 |
146 | use_cache = model.config.use_cache
147 | model.config.use_cache = False
148 | layers = model.model.layers
149 |
150 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
151 | layers[0] = layers[0].to(dev)
152 |
153 | dtype = next(iter(model.parameters())).dtype
154 | inps = []
155 | attention_masks = []
156 | position_ids = []
157 |
158 | class Catcher(nn.Module):
159 | def __init__(self, module):
160 | super().__init__()
161 | self.module = module
162 | def forward(self, inp, **kwargs):
163 | inps.append(inp)
164 | attention_masks.append(kwargs['attention_mask'])
165 | position_ids.append(kwargs['position_ids'])
166 | raise ValueError
167 | layers[0] = Catcher(layers[0])
168 | for batch in dataloader:
169 | try:
170 | model(batch.to(dev))
171 | except ValueError:
172 | pass
173 | layers[0] = layers[0].module
174 |
175 | layers[0] = layers[0].cpu()
176 | model.model.embed_tokens = model.model.embed_tokens.cpu()
177 | torch.cuda.empty_cache()
178 |
179 | for i in range(len(layers)):
180 | print(i)
181 | layer = layers[i].to(dev)
182 | for j in range(nsamples):
183 | inps[j] = layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j])[0]
184 | layers[i] = layer.cpu()
185 | del layer
186 | torch.cuda.empty_cache()
187 |
188 | if model.model.norm is not None:
189 | model.model.norm = model.model.norm.to(dev)
190 | model.lm_head = model.lm_head.to(dev)
191 |
192 | nlls = []
193 | for i in range(nsamples):
194 | hidden_states = inps[i]
195 | if model.model.norm is not None:
196 | hidden_states = model.model.norm(hidden_states)
197 | lm_logits = model.lm_head(hidden_states)
198 | shift_logits = lm_logits[:, :-1, :].contiguous()
199 | shift_labels = (dataloader[i].to(dev))[:, 1:]
200 | loss_fct = nn.CrossEntropyLoss()
201 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
202 | neg_log_likelihood = loss.float() * model.seqlen
203 | nlls.append(neg_log_likelihood)
204 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
205 | print(ppl.item())
206 |
207 | model.config.use_cache = use_cache
208 |
209 | def llama_pack(model, quantizers):
210 | layers = find_layers(model)
211 | layers = {n: layers[n] for n in quantizers}
212 | marlin.replace_linear(model, lambda n: n in quantizers, groupsize=args.groupsize)
213 | qlayers = find_layers(model, [marlin.Layer])
214 | print('Packing ...')
215 | for name in qlayers:
216 | print(name)
217 | qlayers[name].pack(layers[name].to(DEV), quantizers[name][0].to(DEV))
218 | qlayers[name].cpu()
219 | quantizers[name][0].cpu()
220 | layers[name].cpu()
221 | print('Done.')
222 | return model
223 |
224 |
225 | if __name__ == '__main__':
226 | import argparse
227 | from datautils import *
228 |
229 | parser = argparse.ArgumentParser()
230 |
231 | parser.add_argument(
232 | 'model', type=str,
233 |         help='Llama model to load; pass the location of a Hugging Face converted checkpoint.'
234 | )
235 | parser.add_argument(
236 | '--dataset', type=str, default='red', choices=['red'],
237 | help='Where to extract calibration data from.'
238 | )
239 | parser.add_argument(
240 | '--seed',
241 | type=int, default=0, help='Seed for sampling the calibration data.'
242 | )
243 | parser.add_argument(
244 | '--nsamples', type=int, default=256,
245 | help='Number of calibration data samples.'
246 | )
247 | parser.add_argument(
248 | '--percdamp', type=float, default=.1,
249 | help='Percent of the average Hessian diagonal to use for dampening.'
250 | )
251 | parser.add_argument(
252 | '--nearest', action='store_true',
253 | help='Whether to run the RTN baseline.'
254 | )
255 | parser.add_argument(
256 | '--wbits', type=int, default=16, choices=[4, 16],
257 | help='#bits to use for quantization; use 16 for evaluating base model.'
258 | )
259 | parser.add_argument(
260 | '--groupsize', type=int, default=128, choices=[-1, 128],
261 | help='Groupsize to use for quantization; default is 128.'
262 | )
263 | parser.add_argument(
264 | '--true-sequential', action='store_true',
265 |         help='Whether to run in true sequential mode.'
266 | )
267 | parser.add_argument(
268 | '--no_clip', action='store_true',
269 | help='Whether to skip hessian based grid clipping when using groups.'
270 | )
271 | parser.add_argument(
272 | '--skip_gq', action='store_true',
273 |         help='Whether to skip quantizing the key and value projections for the 70B model with grouped-query attention.'
274 | )
275 | parser.add_argument(
276 | '--save', type=str, default='',
277 | help='Whether and where to save the quantized model.'
278 | )
279 |
280 | args = parser.parse_args()
281 |
282 | if args.nearest:
283 | args.nsamples = 0
284 |
285 | model = get_llama(args.model)
286 | model.eval()
287 |
288 | dataloader, testloader = get_loaders(
289 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
290 | )
291 |
292 | if args.wbits < 16:
293 | tick = time.time()
294 | quantizers = llama_sequential(model, dataloader, DEV)
295 | print(time.time() - tick)
296 |
297 | datasets = ['wikitext2', 'red']
298 | for dataset in datasets:
299 | dataloader, testloader = get_loaders(
300 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
301 | )
302 | print(dataset)
303 | llama_eval(model, testloader, DEV)
304 |
305 | if args.save:
306 | args.save += '.marlin'
307 | if args.groupsize != -1:
308 | args.save += '.g%d' % args.groupsize
309 | llama_pack(model, quantizers)
310 | torch.save(model.state_dict(), args.save)
311 |
312 |
--------------------------------------------------------------------------------
/gptq/quant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def quantize(x, scale, zero, maxq):
7 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
8 | return scale * (q - zero)
9 |
10 | class Quantizer(nn.Module):
11 |
12 | def __init__(self, shape=1):
13 | super(Quantizer, self).__init__()
14 | self.register_buffer('maxq', torch.tensor(0))
15 | self.register_buffer('scale', torch.zeros(shape))
16 | self.register_buffer('zero', torch.zeros(shape))
17 |
18 | def configure(self, bits, sym=True, grid=100, maxshrink=.75):
19 | self.maxq = torch.tensor(2 ** bits - 1)
20 | self.sym = sym
21 | self.grid = grid
22 | self.maxshrink = maxshrink
23 |
24 | def find_params(self, x, solve=None, scales=None):
25 | dev = x.device
26 | self.maxq = self.maxq.to(dev)
27 |
28 | shape = x.shape
29 | x = x.flatten(1)
30 | if scales is not None:
31 | x *= scales
32 |
33 | tmp = torch.zeros(x.shape[0], device=dev)
34 | xmin = torch.minimum(x.min(1)[0], tmp)
35 | xmax = torch.maximum(x.max(1)[0], tmp)
36 |
37 | if self.sym:
38 | xmax = torch.maximum(torch.abs(xmin), xmax)
39 | tmp = xmin < 0
40 | if torch.any(tmp):
41 | xmin[tmp] = -xmax[tmp]
42 | tmp = (xmin == 0) & (xmax == 0)
43 | xmin[tmp] = -1
44 | xmax[tmp] = +1
45 |
46 | self.scale = (xmax - xmin) / self.maxq
47 | if self.sym:
48 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
49 | else:
50 | self.zero = torch.round(-xmin / self.scale)
51 |
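        # Hessian-aware grid clipping: try progressively shrunken clipping ranges and keep, per row,
        # the (scale, zero) pair with the lowest quantization error under the Hessian-based metric
        # induced by `solve` (the Cholesky factor passed in from GPTQ).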
52 | if solve is not None:
53 | best = torch.full([x.shape[0]], float('inf'), device=dev)
54 | for i in range(int(self.maxshrink * self.grid) + 1):
55 | p = 1 - i / self.grid
56 | clip = p * torch.max(xmax, torch.abs(xmin))
57 | xmax1 = torch.min(xmax, +clip)
58 | xmin1 = torch.max(xmin, -clip)
59 | scale1 = (xmax1 - xmin1) / self.maxq
60 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
61 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
62 | if scales is not None:
63 | q /= scales
64 | delta = q - x
65 | err = torch.sum(torch.linalg.solve_triangular(solve, delta, upper=True, left=False) ** 2, 1)
66 | tmp = err < best
67 | if torch.any(tmp):
68 | best[tmp] = err[tmp]
69 | self.scale[tmp] = scale1[tmp]
70 | self.zero[tmp] = zero1[tmp]
71 |
72 | shape = [-1] + [1] * (len(shape) - 1)
73 | self.scale = self.scale.reshape(shape)
74 | self.zero = self.zero.reshape(shape)
75 |
76 | def quantize(self, x):
77 | return quantize(x, self.scale, self.zero, self.maxq)
78 |
79 |
--------------------------------------------------------------------------------
/marlin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import numpy as np
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | import marlin_cuda
22 |
23 | def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
24 | """Marlin FP16xINT4 multiply; can be used within `torch.compile`.
25 | @A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
26 | @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
27 | @C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
28 | @s: `torch.half` scales of shape `(k / groupsize, n)`
29 | @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
30 | @thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
31 | @thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
32 | @sms: number of SMs to use for the kernel (can usually be left as auto -1)
33 | @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
34 | """
35 | marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)
36 |
37 |
38 | # Precompute permutations for Marlin weight and scale shuffling
39 |
40 | def _get_perms():
41 | perm = []
42 | for i in range(32):
43 | perm1 = []
44 | col = i // 4
45 | for block in [0, 1]:
46 | for row in [
47 | 2 * (i % 4),
48 | 2 * (i % 4) + 1,
49 | 2 * (i % 4 + 4),
50 | 2 * (i % 4 + 4) + 1
51 | ]:
52 | perm1.append(16 * row + col + 8 * block)
53 | for j in range(4):
54 | perm.extend([p + 256 * j for p in perm1])
55 |
56 | perm = np.array(perm)
57 | interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
58 | perm = perm.reshape((-1, 8))[:, interleave].ravel()
59 | perm = torch.from_numpy(perm)
60 | scale_perm = []
61 | for i in range(8):
62 | scale_perm.extend([i + 8 * j for j in range(8)])
63 | scale_perm_single = []
64 | for i in range(4):
65 | scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
66 | return perm, scale_perm, scale_perm_single
67 |
68 | _perm, _scale_perm, _scale_perm_single = _get_perms()
69 |
70 |
71 | class Layer(nn.Module):
72 | """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias."""
73 |
74 | def __init__(self, infeatures, outfeatures, groupsize=-1):
75 | """Create an empty Marlin layer.
76 | @infeatures: number of input features (must be divisible by 128)
77 | @outfeatures: number of output features (must be divisible by 256)
78 | @groupsize: quantization groupsize (must be -1 or 128)
79 | """
80 | super().__init__()
81 | if groupsize not in [-1, 128]:
82 | raise ValueError('Only groupsize -1 and 128 are supported.')
83 | if infeatures % 128 != 0 or outfeatures % 256 != 0:
84 | raise ValueError('`infeatures` must be divisible by 128 and `outfeatures` by 256.')
85 | if groupsize == -1:
86 | groupsize = infeatures
87 | if infeatures % groupsize != 0:
88 | raise ValueError('`infeatures` must be divisible by `groupsize`.')
89 | self.k = infeatures
90 | self.n = outfeatures
91 | self.groupsize = groupsize
92 | self.register_buffer('B', torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int))
93 | self.register_buffer('s', torch.empty((self.k // groupsize, self.n), dtype=torch.half))
94 | # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
95 | self.register_buffer('workspace', torch.zeros(self.n // 128 * 16, dtype=torch.int), persistent=False)
96 |
97 | def forward(self, A):
98 | C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
99 | mul(A.view((-1, A.shape[-1])), self.B, C.view((-1, C.shape[-1])), self.s, self.workspace)
100 | return C
101 |
102 | def pack(self, linear, scales):
103 | """Pack a fake-quantized linear layer into this actual Marlin representation.
104 | @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
105 | @scales: corresponding quantization scales of shape `(outfeatures, groups)`
106 | """
107 | if linear.weight.dtype != torch.half:
108 | raise ValueError('Only `torch.half` weights are supported.')
109 | tile = 16
110 | maxq = 2 ** 4 - 1
111 | s = scales.t()
112 | w = linear.weight.data.t()
113 | if self.groupsize != self.k:
114 | w = w.reshape((-1, self.groupsize, self.n))
115 | w = w.permute(1, 0, 2)
116 | w = w.reshape((self.groupsize, -1))
117 | s = s.reshape((1, -1))
118 | w = torch.round(w / s).int()
119 | w += (maxq + 1) // 2
120 | w = torch.clamp(w, 0, maxq)
121 | if self.groupsize != self.k:
122 | w = w.reshape((self.groupsize, -1, self.n))
123 | w = w.permute(1, 0, 2)
124 | w = w.reshape((self.k, self.n)).contiguous()
125 | s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
126 | else:
127 | s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
128 | s = s.reshape((-1, self.n)).contiguous()
129 | w = w.reshape((self.k // tile, tile, self.n // tile, tile))
130 | w = w.permute((0, 2, 1, 3))
131 | w = w.reshape((self.k // tile, self.n * tile))
132 | res = w
133 | res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
134 | q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
135 | res = res.cpu().numpy().astype(np.uint32)
136 | for i in range(8):
137 | q |= res[:, i::8] << 4 * i
138 | q = torch.from_numpy(q.astype(np.int32)).to(w.device)
139 | self.B[:, :] = q.to(self.B.device)
140 | self.s[:, :] = s.to(self.s.device)
141 |
142 |
143 | def replace_linear(module, name_filter=lambda n: True, groupsize=-1, name=''):
144 | """Recursively replace all `torch.nn.Linear` layers by empty Marlin layers.
145 | @module: top-level module in which to perform the replacement
146 | @name_filter: lambda indicating if a layer should be replaced
147 | @groupsize: marlin groupsize
148 | @name: root-level name
149 | """
150 | if isinstance(module, Layer):
151 | return
152 | for attr in dir(module):
153 | tmp = getattr(module, attr)
154 | name1 = name + '.' + attr if name != '' else attr
155 | if isinstance(tmp, nn.Linear) and name_filter(name1):
156 | setattr(
157 | module, attr, Layer(tmp.in_features, tmp.out_features, groupsize=groupsize)
158 | )
159 | for name1, child in module.named_children():
160 | replace_linear(child, name_filter, groupsize=groupsize, name=name + '.' + name1 if name != '' else name1)
161 |
--------------------------------------------------------------------------------
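
Illustrative note (not part of the repository): the intended flow for the wrapper above is to take an already fake-quantized `torch.nn.Linear`, `pack()` it into a `marlin.Layer`, and let `forward()` dispatch to `mul()`. A minimal sketch, assuming the `marlin_cuda` extension is built and a suitable CUDA device is available; the shapes respect the constraints checked in `Layer.__init__` (`infeatures % 128 == 0`, `outfeatures % 256 == 0`, groupsize in {-1, 128}):

    import torch
    import torch.nn as nn
    import marlin

    k, n, groupsize = 1024, 512, 128
    dev = torch.device('cuda:0')

    # In practice `linear` holds GPTQ fake-quantized weights and `scales` comes from the quantizer;
    # constant/random values are used here only to show the shapes involved.
    linear = nn.Linear(k, n, bias=False).half().to(dev)
    scales = torch.full((k // groupsize, n), 0.01, dtype=torch.half, device=dev)

    layer = marlin.Layer(k, n, groupsize=groupsize).to(dev)
    layer.pack(linear, scales.t())   # scales passed as (outfeatures, groups)

    x = torch.randn(16, k, dtype=torch.half, device=dev)
    y = layer(x)                     # (16, n), computed by the FP16xINT4 kernel
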
/marlin/marlin_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 |
18 | #include <torch/all.h>
19 | #include <torch/python.h>
20 | #include <ATen/cuda/CUDAContext.h>
21 | #include <cuda_runtime.h>
22 |
23 | int marlin_cuda(
24 | const void* A,
25 | const void* B,
26 | void* C,
27 | void* s,
28 | int prob_m,
29 | int prob_n,
30 | int prob_k,
31 | void* workspace,
32 | int groupsize = -1,
33 | int dev = 0,
34 | cudaStream_t stream = 0,
35 | int thread_k = -1,
36 | int thread_n = -1,
37 | int sms = -1,
38 | int max_par = 16
39 | );
40 |
41 | const int ERR_PROB_SHAPE = 1;
42 | const int ERR_KERN_SHAPE = 2;
43 |
44 | void mul(
45 | const torch::Tensor& A,
46 | const torch::Tensor& B,
47 | torch::Tensor& C,
48 | const torch::Tensor& s,
49 | torch::Tensor& workspace,
50 | int thread_k = -1,
51 | int thread_n = -1,
52 | int sms = -1,
53 | int max_par = 8
54 | ) {
55 | int prob_m = A.size(0);
56 | int prob_n = C.size(1);
57 | int prob_k = A.size(1);
58 | int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0);
59 | if (groupsize != -1 && groupsize * s.size(0) != prob_k)
60 | AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups.");
61 | if (workspace.numel() < prob_n / 128 * max_par)
62 | AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, ".");
63 | int dev = A.get_device();
64 | int err = marlin_cuda(
65 | A.data_ptr(),
66 | B.data_ptr(),
67 | C.data_ptr(),
68 | s.data_ptr(),
69 | prob_m, prob_n, prob_k,
70 | workspace.data_ptr(),
71 | groupsize,
72 | dev,
73 | at::cuda::getCurrentCUDAStream(dev),
74 | thread_k,
75 | thread_n,
76 | sms,
77 | max_par
78 | );
79 | if (err == ERR_PROB_SHAPE) {
80 | AT_ERROR(
81 | "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
82 | " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
83 | );
84 | } else if (err == ERR_KERN_SHAPE) {
85 | AT_ERROR(
86 | "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
87 | );
88 | }
89 | }
90 |
91 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
92 | m.def("mul", &mul, "Marlin FP16xINT4 matmul.");
93 | }
94 |
--------------------------------------------------------------------------------
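
Illustrative note (not part of the repository): `mul()` above infers the groupsize from the scale tensor (a single scale row means per-column quantization, i.e. groupsize -1) and requires a zero-initialized workspace of at least `n / 128 * max_par` entries for its lock array. A plain-Python restatement of those checks, for reference only (`check_mul_args` is a hypothetical helper, not part of the API):

    def check_mul_args(m, n, k, s_rows, workspace_numel, max_par=8):
        # Mirrors the shape logic in marlin_cuda.cpp::mul.
        groupsize = -1 if s_rows == 1 else k // s_rows
        if groupsize != -1 and groupsize * s_rows != k:
            raise ValueError(f'k={k} not compatible with {s_rows} groups.')
        if workspace_numel < n // 128 * max_par:
            raise ValueError(f'workspace must be of size at least {n // 128 * max_par}.')
        return groupsize

    check_mul_args(16, 512, 4096, s_rows=32, workspace_numel=64)   # -> groupsize 128
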
/marlin/marlin_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 |
18 | #ifndef MARLIN_CUDA_KERNEL_CUH
19 | #define MARLIN_CUDA_KERNEL_CUH
20 |
21 |
22 | #include <cuda.h>
23 | #include <cuda_fp16.h>
24 | #include <cuda_runtime.h>
25 | #include <iostream>
26 |
27 |
28 | constexpr int ceildiv(int a, int b) {
29 | return (a + b - 1) / b;
30 | }
31 |
32 | // Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
33 | // operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
34 | // extensively use `#pragma unroll` throughout the kernel code to guarantee this.
35 | template <typename T, int n>
36 | struct Vec {
37 | T elems[n];
38 | __device__ T& operator[](int i) {
39 | return elems[i];
40 | }
41 | };
42 |
43 | using I4 = Vec<int, 4>;
44 |
45 | // Matrix fragments for tensor core instructions; their precise layout is documented here:
46 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
47 | using FragA = Vec<half2, 4>;
48 | using FragB = Vec<half2, 2>;
49 | using FragC = Vec<float, 4>;
50 | using FragS = Vec<half2, 1>; // quantization scales
51 |
52 | // Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that
53 | // are not multiples of 16.
54 | __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
55 | const int BYTES = 16;
56 | uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
57 | asm volatile(
58 | "{\n"
59 | " .reg .pred p;\n"
60 | " setp.ne.b32 p, %0, 0;\n"
61 | " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
62 | "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)
63 | );
64 | }
65 |
66 | // Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for
67 | // quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
68 | // for inputs A and outputs C.
69 | __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
70 | const int BYTES = 16;
71 | uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
72 | asm volatile(
73 | "{\n"
74 | " .reg .b64 p;\n"
75 | " createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
76 | " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
77 | "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
78 | );
79 | }
80 |
81 | // Async copy fence.
82 | __device__ inline void cp_async_fence() {
83 | asm volatile("cp.async.commit_group;\n" ::);
84 | }
85 |
86 | // Wait until at most `n` async copy stages are still pending.
87 | template <int n>
88 | __device__ inline void cp_async_wait() {
89 | asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
90 | }
91 |
92 | // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
93 | __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
94 | const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
95 | const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
96 | float* c = reinterpret_cast<float*>(&frag_c);
97 | asm volatile(
98 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
99 | "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
100 | : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
101 | : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
102 | "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])
103 | );
104 | }
105 |
106 | // Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
107 | __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
108 | uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
109 | uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
110 | asm volatile(
111 | "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
112 | : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
113 | );
114 | }
115 |
116 | // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
117 | // automatically recognize it in all cases.
118 | template <int lut>
119 | __device__ inline int lop3(int a, int b, int c) {
120 | int res;
121 | asm volatile(
122 | "lop3.b32 %0, %1, %2, %3, %4;\n"
123 | : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
124 | );
125 | return res;
126 | }
127 |
128 | // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
129 | // We mostly follow the strategy in the link below, with some small changes:
130 | // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
131 | __device__ inline FragB dequant(int q) {
132 | const int LO = 0x000f000f;
133 | const int HI = 0x00f000f0;
134 | const int EX = 0x64006400;
135 | // Guarantee that the `(a & b) | c` operations are LOP3s.
136 | int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
137 | int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
138 | // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
139 | const int SUB = 0x64086408;
140 | const int MUL = 0x2c002c00;
141 | const int ADD = 0xd480d480;
142 | FragB frag_b;
143 | frag_b[0] = __hsub2(
144 | *reinterpret_cast<half2*>(&lo),
145 | *reinterpret_cast<const half2*>(&SUB)
146 | );
147 | frag_b[1] = __hfma2(
148 | *reinterpret_cast<half2*>(&hi),
149 | *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)
150 | );
151 | return frag_b;
152 | }
153 |
154 | // Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
155 | __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
156 | half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
157 | frag_b[0] = __hmul2(frag_b[0], s);
158 | frag_b[1] = __hmul2(frag_b[1], s);
159 | }
160 |
161 | // Wait until barrier reaches `count`, then lock for current threadblock.
162 | __device__ inline void barrier_acquire(int* lock, int count) {
163 | if (threadIdx.x == 0) {
164 | int state = -1;
165 | do
166 | // Guarantee that subsequent writes by this threadblock will be visible globally.
167 | asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
168 | while (state != count);
169 | }
170 | __syncthreads();
171 | }
172 |
173 | // Release barrier and increment visitation count.
174 | __device__ inline void barrier_release(int* lock, bool reset = false) {
175 | __syncthreads();
176 | if (threadIdx.x == 0) {
177 | if (reset) {
178 | lock[0] = 0;
179 | return;
180 | }
181 | int val = 1;
182 | // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
183 | asm volatile ("fence.acq_rel.gpu;\n");
184 | asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
185 | }
186 | }
187 |
188 |
189 | template <
190 | const int threads, // number of threads in a threadblock
191 | const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
192 | const int thread_n_blocks, // same for n dimension (output)
193 | const int thread_k_blocks, // same for k dimension (reduction)
194 | const int stages, // number of stages for the async global->shared fetch pipeline
195 | const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
196 | >
197 | __global__ void Marlin(
198 | const int4* __restrict__ A, // fp16 input matrix of shape mxk
199 | const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
200 | int4* __restrict__ C, // fp16 output buffer of shape mxn
201 | const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
202 | int prob_m, // batch dimension m
203 | int prob_n, // output dimension n
204 | int prob_k, // reduction dimension k
205 | int* locks // extra global storage for barrier synchronization
206 | ) {
207 | // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
208 | // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the following example of a 3x3 tile matrix with 5 SMs:
209 | // 0 1 3
210 | // 0 2 3
211 | // 1 2 4
212 | // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
213 | // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
214 | // possible.
215 |
216 | // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions
217 | int parallel = 1;
218 | if (prob_m > 16 * thread_m_blocks) {
219 | parallel = prob_m / (16 * thread_m_blocks);
220 | prob_m = 16 * thread_m_blocks;
221 | }
222 |
223 | int k_tiles = prob_k / 16 / thread_k_blocks;
224 | int n_tiles = prob_n / 16 / thread_n_blocks;
225 | int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
226 | // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
227 | // where a stripe starts in the middle of a group.
228 | if (group_blocks != -1)
229 | iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));
230 |
231 | int slice_row = (iters * blockIdx.x) % k_tiles;
232 | int slice_col_par = (iters * blockIdx.x) / k_tiles;
233 | int slice_col = slice_col_par;
234 | int slice_iters; // number of threadblock tiles in the current slice
235 | int slice_count = 0; // total number of active threadblocks in the current slice
236 | int slice_idx; // index of threadblock in current slice; numbered bottom to top
237 |
238 | // We can easily implement parallel problem execution by just remapping indices and advancing global pointers
239 | if (slice_col_par >= n_tiles) {
240 | A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
241 | C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
242 | locks += (slice_col_par / n_tiles) * n_tiles;
243 | slice_col = slice_col_par % n_tiles;
244 | }
245 |
246 | // Compute all information about the current slice which is required for synchronization.
247 | auto init_slice = [&] () {
248 | slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
249 | if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
250 | slice_iters = 0;
251 | if (slice_iters == 0)
252 | return;
253 | if (slice_row + slice_iters > k_tiles)
254 | slice_iters = k_tiles - slice_row;
255 | slice_count = 1;
256 | slice_idx = 0;
257 | int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
258 | if (col_first <= k_tiles * (slice_col_par + 1)) {
259 | int col_off = col_first - k_tiles * slice_col_par;
260 | slice_count = ceildiv(k_tiles - col_off, iters);
261 | if (col_off > 0)
262 | slice_count++;
263 | int delta_first = iters * blockIdx.x - col_first;
264 | if (delta_first < 0 || (col_off == 0 && delta_first == 0))
265 | slice_idx = slice_count - 1;
266 | else {
267 | slice_idx = slice_count - 1 - delta_first / iters;
268 | if (col_off > 0)
269 | slice_idx--;
270 | }
271 | }
272 | if (slice_col == n_tiles) {
273 | A += 16 * thread_m_blocks * prob_k / 8;
274 | C += 16 * thread_m_blocks * prob_n / 8;
275 | locks += n_tiles;
276 | slice_col = 0;
277 | }
278 | };
279 | init_slice();
280 |
281 | int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
282 | // We typically use `constexpr` to indicate that this value is a compile-time constant
283 | constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
284 | constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
285 | int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
286 | constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
287 | constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
288 | constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
289 | constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
290 | constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile
291 |
292 | int b_gl_stride = 16 * prob_n / 32;
293 | constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
294 | int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
295 | int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
296 | constexpr int b_sh_wr_delta = threads;
297 | constexpr int b_sh_rd_delta = threads;
298 | constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
299 | constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
300 |
301 | int s_gl_stride = prob_n / 8;
302 | constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
303 | constexpr int s_sh_stage = s_sh_stride;
304 | int s_gl_rd_delta = s_gl_stride;
305 |
306 | // Global A read index of current thread.
307 | int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
308 | a_gl_rd += a_gl_rd_delta_o * slice_row;
309 | // Shared write index of current thread.
310 | int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
311 | // Shared read index.
312 | int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
313 | a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
314 |
315 | int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
316 | b_gl_rd += b_sh_stride * slice_col;
317 | b_gl_rd += b_gl_rd_delta_o * slice_row;
318 | int b_sh_wr = threadIdx.x;
319 | int b_sh_rd = threadIdx.x;
320 |
321 | int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
322 | int s_sh_wr = threadIdx.x;
323 | int s_sh_rd;
324 | // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
325 | // layout in the former and in row-major in the latter case.
326 | if (group_blocks != -1)
327 | s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
328 | else
329 | s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
330 |
331 | // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
332 | // required for a certain tilesize or when the batchsize is not a multiple of 16.
333 | bool a_sh_wr_pred[a_sh_wr_iters];
334 | #pragma unroll
335 | for (int i = 0; i < a_sh_wr_iters; i++)
336 | a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
337 | bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
338 |
339 | // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
340 | // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
341 | // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
342 | // on NSight-Compute) that each warp must also write a consecutive memory segment?
343 | auto transform_a = [&] (int i) {
344 | int row = i / a_gl_rd_delta_o;
345 | return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
346 | };
347 | // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
348 | // accesses are static, we simply precompute both transformed reads and writes.
349 | int a_sh_wr_trans[a_sh_wr_iters];
350 | #pragma unroll
351 | for (int i = 0; i < a_sh_wr_iters; i++)
352 | a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
353 | int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
354 | #pragma unroll
355 | for (int i = 0; i < b_sh_wr_iters; i++) {
356 | #pragma unroll
357 | for (int j = 0; j < thread_m_blocks; j++)
358 | a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
359 | }
360 |
361 | // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between
362 | // subsequent accesses within a tile by maintaining multiple pointers (we have enough registers), a tiny optimization.
363 | const int4* B_ptr[b_sh_wr_iters];
364 | #pragma unroll
365 | for (int i = 0; i < b_sh_wr_iters; i++)
366 | B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
367 |
368 | extern __shared__ int4 sh[];
369 | // Shared memory storage for global fetch pipelines.
370 | int4* sh_a = sh;
371 | int4* sh_b = sh_a + (stages * a_sh_stage);
372 | int4* sh_s = sh_b + (stages * b_sh_stage);
373 | // Register storage for double buffer of shared memory reads.
374 | FragA frag_a[2][thread_m_blocks];
375 | I4 frag_b_quant[2];
376 | FragC frag_c[thread_m_blocks][4][2];
377 | FragS frag_s[2][4];
378 |
379 | // Zero accumulators.
380 | auto zero_accums = [&] () {
381 | #pragma unroll
382 | for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
383 | reinterpret_cast<float*>(frag_c)[i] = 0;
384 | };
385 |
386 | // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
387 | auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
388 | if (pred) {
389 | int4* sh_a_stage = sh_a + a_sh_stage * pipe;
390 | #pragma unroll
391 | for (int i = 0; i < a_sh_wr_iters; i++) {
392 | cp_async4_pred(
393 | &sh_a_stage[a_sh_wr_trans[i]],
394 | &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
395 | a_sh_wr_pred[i]
396 | );
397 | }
398 | int4* sh_b_stage = sh_b + b_sh_stage * pipe;
399 | #pragma unroll
400 | for (int i = 0; i < b_sh_wr_iters; i++) {
401 | cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
402 | B_ptr[i] += b_gl_rd_delta_o;
403 | }
404 | // Only fetch scales if this tile starts a new group
405 | if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
406 | int4* sh_s_stage = sh_s + s_sh_stage * pipe;
407 | if (s_sh_wr_pred)
408 | cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
409 | s_gl_rd += s_gl_rd_delta;
410 | }
411 | }
412 | // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
413 | cp_async_fence();
414 | };
415 |
416 | // Wait until the next thread tile has been loaded to shared memory.
417 | auto wait_for_stage = [&] () {
418 | // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
419 | // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
420 | cp_async_wait<stages - 2>();
421 | __syncthreads();
422 | };
423 |
424 | // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
425 | auto fetch_to_registers = [&] (int k, int pipe) {
426 | // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
427 | // significant bottleneck, while some theoretically better attempts have led to bad instruction ordering by the
428 | // compiler and correspondingly a noticeable drop in performance.
429 | if (group_blocks != -1) {
430 | int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
431 | reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
432 | }
433 | int4* sh_a_stage = sh_a + a_sh_stage * pipe;
434 | #pragma unroll
435 | for (int i = 0; i < thread_m_blocks; i++)
436 | ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
437 | int4* sh_b_stage = sh_b + b_sh_stage * pipe;
438 | frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
439 | };
440 |
441 | // Execute the actual tensor core matmul of a sub-tile.
442 | auto matmul = [&] (int k) {
443 | // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
444 | #pragma unroll
445 | for (int j = 0; j < 4; j++) {
446 | int b_quant = frag_b_quant[k % 2][j];
447 | int b_quant_shift = b_quant >> 8;
448 | FragB frag_b0 = dequant(b_quant);
449 | // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
450 | if (group_blocks != -1)
451 | scale(frag_b0, frag_s[k % 2][j], 0);
452 | FragB frag_b1 = dequant(b_quant_shift);
453 | if (group_blocks != -1)
454 | scale(frag_b1, frag_s[k % 2][j], 1);
455 | #pragma unroll
456 | for (int i = 0; i < thread_m_blocks; i++) {
457 | mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
458 | mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
459 | }
460 | }
461 | };
462 |
463 | // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
464 | // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
465 | // location, which we have to reduce over in the end. We do this in shared memory.
466 | auto thread_block_reduce = [&] () {
467 | constexpr int red_off = threads / b_sh_stride / 2;
468 | if (red_off >= 1) {
469 | int red_idx = threadIdx.x / b_sh_stride;
470 | constexpr int red_sh_stride = b_sh_stride * 4 * 2;
471 | constexpr int red_sh_delta = b_sh_stride;
472 | int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
473 |
474 | // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
475 | // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
476 |
477 | #pragma unroll
478 | for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
479 | #pragma unroll
480 | for (int i = red_off; i > 0; i /= 2) {
481 | if (i <= red_idx && red_idx < 2 * i) {
482 | #pragma unroll
483 | for (int j = 0; j < 4 * 2; j++) {
484 | int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
485 | if (i < red_off) {
486 | float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
487 | float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
488 | #pragma unroll
489 | for (int k = 0; k < 4; k++)
490 | reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
491 | }
492 | sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
493 | }
494 | }
495 | __syncthreads();
496 | }
497 | if (red_idx == 0) {
498 | #pragma unroll
499 | for (int i = 0; i < 4 * 2; i++) {
500 | float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
501 | #pragma unroll
502 | for (int j = 0; j < 4; j++)
503 | reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
504 | }
505 | }
506 | __syncthreads();
507 | }
508 | }
509 | };
510 |
511 | // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
512 | // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather
513 | // small, we perform this reduction serially in L2 cache.
514 | auto global_reduce = [&] (bool first = false, bool last = false) {
515 | // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
516 | // To do this, we write out results in FP16 (but still reduce with FP32 compute).
517 | constexpr int active_threads = 32 * thread_n_blocks / 4;
518 | if (threadIdx.x < active_threads) {
519 | int c_gl_stride = prob_n / 8;
520 | int c_gl_wr_delta_o = 8 * c_gl_stride;
521 | int c_gl_wr_delta_i = 4 * (active_threads / 32);
522 | int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
523 | c_gl_wr += (2 * thread_n_blocks) * slice_col;
524 | constexpr int c_sh_wr_delta = active_threads;
525 | int c_sh_wr = threadIdx.x;
526 |
527 | int row = (threadIdx.x % 32) / 4;
528 |
529 | if (!first) {
530 | // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
531 | // hence we also use async-copies even though these fetches are not actually asynchronous.
532 | #pragma unroll
533 | for (int i = 0; i < thread_m_blocks * 4; i++) {
534 | cp_async4_pred(
535 | &sh[c_sh_wr + c_sh_wr_delta * i],
536 | &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
537 | i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
538 | );
539 | }
540 | cp_async_fence();
541 | cp_async_wait<0>();
542 | }
543 |
544 | #pragma unroll
545 | for (int i = 0; i < thread_m_blocks * 4; i++) {
546 | if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
547 | if (!first) {
548 | int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
549 | #pragma unroll
550 | for (int j = 0; j < 2 * 4; j++) {
551 | reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
552 | reinterpret_cast<__half*>(&c_red)[j]
553 | );
554 | }
555 | }
556 | if (!last) {
557 | int4 c;
558 | #pragma unroll
559 | for (int j = 0; j < 2 * 4; j++) {
560 | reinterpret_cast<__half*>(&c)[j] = __float2half(
561 | reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
562 | );
563 | }
564 | C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
565 | }
566 | }
567 | }
568 | }
569 | };
570 |
571 | // Write out the reduced final result in the correct layout. We only actually reshuffle matrix fragments in this step;
572 | // the reduction above is performed in fragment layout.
573 | auto write_result = [&] () {
574 | int c_gl_stride = prob_n / 8;
575 | constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
576 | int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
577 | constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
578 |
579 | int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
580 | c_gl_wr += (2 * thread_n_blocks) * slice_col;
581 | int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
582 | c_sh_wr += 32 * (threadIdx.x / 32);
583 | int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
584 |
585 | int c_gl_wr_end = c_gl_stride * prob_m;
586 |
587 | // We first reorder in shared memory to guarantee the most efficient final global write patterns
588 | auto write = [&] (int idx, float c0, float c1, FragS& s) {
589 | half2 res = __halves2half2(__float2half(c0), __float2half(c1));
590 | if (group_blocks == -1) // for per-column quantization we finally apply the scale here
591 | res = __hmul2(res, s[0]);
592 | ((half2*) sh)[idx] = res;
593 | };
594 | if (threadIdx.x / 32 < thread_n_blocks / 4) {
595 | #pragma unroll
596 | for (int i = 0; i < thread_m_blocks; i++) {
597 | #pragma unroll
598 | for (int j = 0; j < 4; j++) {
599 | int wr = c_sh_wr + 8 * j;
600 | write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
601 | write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
602 | write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
603 | write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
604 | }
605 | c_sh_wr += 16 * (4 * c_sh_stride);
606 | }
607 | }
608 | __syncthreads();
609 |
610 | #pragma unroll
611 | for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
612 | if (c_gl_wr < c_gl_wr_end) {
613 | C[c_gl_wr] = sh[c_sh_rd];
614 | c_gl_wr += c_gl_wr_delta;
615 | c_sh_rd += c_sh_rd_delta;
616 | }
617 | }
618 | };
619 |
620 | // Start global fetch and register load pipelines.
621 | auto start_pipes = [&] () {
622 | #pragma unroll
623 | for (int i = 0; i < stages - 1; i++)
624 | fetch_to_shared(i, i, i < slice_iters);
625 | zero_accums();
626 | wait_for_stage();
627 | fetch_to_registers(0, 0);
628 | a_gl_rd += a_gl_rd_delta_o * (stages - 1);
629 | };
630 | start_pipes();
631 |
632 | // Main loop.
633 | while (slice_iters) {
634 | // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
635 | // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
636 | #pragma unroll
637 | for (int pipe = 0; pipe < stages;) {
638 | #pragma unroll
639 | for (int k = 0; k < b_sh_wr_iters; k++) {
640 | fetch_to_registers(k + 1, pipe % stages);
641 | if (k == b_sh_wr_iters - 2) {
642 | fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
643 | pipe++;
644 | wait_for_stage();
645 | }
646 | matmul(k);
647 | }
648 | slice_iters--;
649 | if (slice_iters == 0)
650 | break;
651 | }
652 | a_gl_rd += a_gl_rd_delta_o * stages;
653 |
654 | // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
655 | // readable, other ways of writing the loop seemed to noticeably worsen performance after compilation.
656 | if (slice_iters == 0) {
657 | cp_async_wait<0>();
658 | bool last = slice_idx == slice_count - 1;
659 | // For per-column scales, we only fetch them here in the final step before write-out
660 | if (group_blocks == -1 && last) {
661 | if (s_sh_wr_pred)
662 | cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
663 | cp_async_fence();
664 | }
665 | thread_block_reduce();
666 | if (group_blocks == -1 && last) {
667 | cp_async_wait<0>();
668 | __syncthreads();
669 | if (threadIdx.x / 32 < thread_n_blocks / 4) {
670 | reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
671 | reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
672 | }
673 | }
674 | if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
675 | barrier_acquire(&locks[slice_col], slice_idx);
676 | global_reduce(slice_idx == 0, last);
677 | barrier_release(&locks[slice_col], last);
678 | }
679 | if (last) // only the last block in a slice actually writes the result
680 | write_result();
681 | slice_row = 0;
682 | slice_col_par++;
683 | slice_col++;
684 | init_slice();
685 | if (slice_iters) {
686 | a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
687 | #pragma unroll
688 | for (int i = 0; i < b_sh_wr_iters; i++)
689 | B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
690 | if (slice_col == 0) {
691 | #pragma unroll
692 | for (int i = 0; i < b_sh_wr_iters; i++)
693 | B_ptr[i] -= b_gl_stride;
694 | }
695 | s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
696 | start_pipes();
697 | }
698 | }
699 | }
700 | }
701 |
702 |
703 | // 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per scheduler allows some more
704 | // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
705 | const int THREADS = 256;
706 | const int STAGES = 4; // 4 pipeline stages fit into shared memory
707 | const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
708 |
709 | #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
710 | else if ( \
711 | thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
712 | group_blocks == GROUP_BLOCKS \
713 | ) { \
714 | cudaFuncSetAttribute( \
715 | Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
716 | cudaFuncAttributeMaxDynamicSharedMemorySize, \
717 | SHARED_MEM \
718 | ); \
719 | Marlin< \
720 | THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
721 | ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
722 | A_ptr, B_ptr, C_ptr, s_ptr, \
723 | prob_m, prob_n, prob_k, \
724 | locks \
725 | ); \
726 | }
727 |
728 | const int ERR_PROB_SHAPE = 1;
729 | const int ERR_KERN_SHAPE = 2;
730 |
731 | int marlin_cuda(
732 | const void* A,
733 | const void* B,
734 | void* C,
735 | void* s,
736 | int prob_m,
737 | int prob_n,
738 | int prob_k,
739 | void* workspace,
740 | int groupsize = -1,
741 | int dev = 0,
742 | cudaStream_t stream = 0,
743 | int thread_k = -1,
744 | int thread_n = -1,
745 | int sms = -1,
746 | int max_par = 16
747 | ) {
748 | int tot_m = prob_m;
749 | int tot_m_blocks = ceildiv(tot_m, 16);
750 | int pad = 16 * tot_m_blocks - tot_m;
751 |
752 | if (sms == -1)
753 | cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
754 | if (thread_k == -1 || thread_n == -1) {
755 | if (prob_m <= 16) {
756 | // For small batchsizes, better partitioning is slightly more important than better compute utilization
757 | thread_k = 128;
758 | thread_n = 128;
759 | } else {
760 | thread_k = 64;
761 | thread_n = 256;
762 | }
763 | }
764 |
765 | int thread_k_blocks = thread_k / 16;
766 | int thread_n_blocks = thread_n / 16;
767 | int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
768 | int blocks = sms;
769 |
770 | if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
771 | return ERR_PROB_SHAPE;
772 | if (prob_m == 0 || prob_n == 0 || prob_k == 0)
773 | return 0;
774 |
775 | const int4* A_ptr = (const int4*) A;
776 | const int4* B_ptr = (const int4*) B;
777 | int4* C_ptr = (int4*) C;
778 | const int4* s_ptr = (const int4*) s;
779 |
780 | int cols = prob_n / thread_n;
781 | int* locks = (int*) workspace;
782 |
783 | int ret = 0;
784 | for (int i = 0; i < tot_m_blocks; i += 4) {
785 | int thread_m_blocks = tot_m_blocks - i;
786 | prob_m = tot_m - 16 * i;
787 | int par = 1;
788 | if (thread_m_blocks > 4) {
789 | // Note that parallel > 1 currently only works for inputs without any padding
790 | par = (16 * thread_m_blocks - pad) / 64;
791 | if (par > max_par)
792 | par = max_par;
793 | prob_m = 64 * par;
794 | i += 4 * (par - 1);
795 | thread_m_blocks = 4;
796 | }
797 |
798 | // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
799 | // in our testing; many more are, in principle, possible.
800 | if (false) {}
801 | CALL_IF(1, 8, 8, -1)
802 | CALL_IF(1, 8, 8, 8)
803 | CALL_IF(1, 16, 4, -1)
804 | CALL_IF(1, 16, 4, 8)
805 | CALL_IF(2, 16, 4, -1)
806 | CALL_IF(2, 16, 4, 8)
807 | CALL_IF(3, 16, 4, -1)
808 | CALL_IF(3, 16, 4, 8)
809 | CALL_IF(4, 16, 4, -1)
810 | CALL_IF(4, 16, 4, 8)
811 | else
812 | ret = ERR_KERN_SHAPE;
813 |
814 | A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
815 | C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
816 | }
817 |
818 | return ret;
819 | }
820 |
821 |
822 | #endif
823 |
--------------------------------------------------------------------------------
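
Illustrative note (not part of the repository): the comment at the top of `Marlin()` describes how tiles are assigned to threadblocks in "stripes" that walk down the k dimension first and may wrap into the next column slice. The sketch below (ignoring the group-alignment adjustment of `iters` and the batch-64 `parallel` replication) reproduces the 3x3-tile / 5-SM picture from that comment:

    def ceildiv(a, b):
        return (a + b - 1) // b

    def stripe_map(k_tiles, n_tiles, sms):
        # Which threadblock (SM) owns each (row, col) tile under striped partitioning.
        iters = ceildiv(k_tiles * n_tiles, sms)
        owner = [[None] * n_tiles for _ in range(k_tiles)]
        for block in range(sms):
            for t in range(iters * block, min(iters * (block + 1), k_tiles * n_tiles)):
                col, row = divmod(t, k_tiles)   # tiles are walked down k first, then across n
                owner[row][col] = block
        return owner

    for r in stripe_map(3, 3, 5):
        print(*r)   # prints: 0 1 3 / 0 2 3 / 1 2 4
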
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils import cpp_extension
3 |
4 | setup(
5 | name='marlin',
6 | version='0.1.1',
7 | author='Elias Frantar',
8 | author_email='elias.frantar@ist.ac.at',
9 | description='Highly optimized FP16xINT4 CUDA matmul kernel.',
10 | install_requires=['numpy', 'torch'],
11 | packages=['marlin'],
12 | ext_modules=[cpp_extension.CUDAExtension(
13 | 'marlin_cuda', ['marlin/marlin_cuda.cpp', 'marlin/marlin_cuda_kernel.cu']
14 | )],
15 | cmdclass={'build_ext': cpp_extension.BuildExtension},
16 | )
17 |
--------------------------------------------------------------------------------
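
Illustrative note (not part of the repository): `setup.py` builds the `marlin_cuda` extension through PyTorch's `cpp_extension` machinery, so `pip install .` from the repository root compiles `marlin_cuda.cpp` and `marlin_cuda_kernel.cu` against the local CUDA toolkit. A quick post-install sanity check, assuming a CUDA-enabled PyTorch build and an Ampere-class or newer GPU:

    import torch
    import marlin        # Python wrapper from marlin/__init__.py
    import marlin_cuda   # compiled extension declared in setup.py

    assert torch.cuda.is_available()
    print(marlin.Layer, marlin_cuda.mul)
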
/test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | import marlin
8 |
9 |
10 | seed = 0
11 | np.random.seed(seed)
12 | torch.random.manual_seed(seed)
13 |
14 | DEV = torch.device('cuda:0')
15 |
16 |
17 | def gen_quant4(m, n, groupsize=-1):
18 | tile = 16
19 | maxq = 2 ** 4 - 1
20 | w = torch.randn((m, n), dtype=torch.half, device=DEV)
21 | if groupsize != -1:
22 | w = w.reshape((-1, groupsize, n))
23 | w = w.permute(1, 0, 2)
24 | w = w.reshape((groupsize, -1))
25 | s = torch.max(torch.abs(w), 0, keepdim=True)[0]
26 | s *= 2 / maxq
27 | w = torch.round(w / s).int()
28 | w += (maxq + 1) // 2
29 | w = torch.clamp(w, 0, maxq)
30 | ref = (w - (maxq + 1) // 2).half() * s
31 | if groupsize != -1:
32 | def reshape(w):
33 | w = w.reshape((groupsize, -1, n))
34 | w = w.permute(1, 0, 2)
35 | w = w.reshape((m, n)).contiguous()
36 | return w
37 | ref = reshape(ref)
38 | w = reshape(w)
39 | s = s.reshape((-1, n)).contiguous()
40 | linear = nn.Linear(m, n)
41 | linear.weight.data = ref.t()
42 | # Workaround to test some special cases that are forbidden by the API
43 | layer = marlin.Layer(256, 256, groupsize=groupsize)
44 | if groupsize == -1:
45 | groupsize = m
46 | layer.k = m
47 | layer.n = n
48 | layer.groupsize = groupsize
49 | layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=DEV)
50 | layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device=DEV)
51 | layer.pack(linear, s.t())
52 | q = layer.B
53 | s = layer.s
54 | return ref, q, s
55 |
56 | class Test(unittest.TestCase):
57 |
58 | def run_problem(self, m, n, k, thread_k, thread_n, groupsize=-1):
59 | print('% 5d % 6d % 6d % 4d % 4d % 4d' % (m, n, k, thread_k, thread_n, groupsize))
60 | A = torch.randn((m, k), dtype=torch.half, device=DEV)
61 | B_ref, B, s = gen_quant4(k, n, groupsize=groupsize)
62 | C = torch.zeros((m, n), dtype=torch.half, device=DEV)
63 | C_ref = torch.matmul(A, B_ref)
64 | workspace = torch.zeros(n // 128 * 16, device=DEV)
65 | marlin.mul(A, B, C, s, workspace, thread_k, thread_n, -1)
66 | torch.cuda.synchronize()
67 | self.assertLess(torch.mean(torch.abs(C - C_ref)) / torch.mean(torch.abs(C_ref)), 0.001)
68 |
69 | def test_tiles(self):
70 | print()
71 | for m in [1, 2, 3, 4, 8, 12, 16, 24, 32, 48, 64, 118, 128, 152, 768, 1024]:
72 | for thread_k, thread_n in [(64, 256), (128, 128)]:
73 | if m > 16 and thread_k == 128:
74 | continue
75 | self.run_problem(m, 2 * 256, 1024, thread_k, thread_n)
76 |
77 | def test_k_stages_divisibility(self):
78 | print()
79 | for k in [3 * 64 + 64 * 4 * 2 + 64 * i for i in range(1, 4)]:
80 | self.run_problem(16, 2 * 256, k, 64, 256)
81 |
82 | def test_very_few_stages(self):
83 | print()
84 | for k in [64, 128, 192]:
85 | self.run_problem(16, 2 * 256, k, 64, 256)
86 |
87 | def test_llama_shapes(self):
88 | print()
89 | return
90 | MODELS = {
91 | ' 7B': [
92 | (4096, 3 * 4096),
93 | (4096, 4096),
94 | (4096, 2 * 10752),
95 | (10752, 4096)
96 | ],
97 | '13B': [
98 | (5120, 3 * 5120),
99 | (5120, 5120),
100 | (5120, 2 * 13568),
101 | (13568, 5120)
102 | ],
103 | '33B': [
104 | (6656, 3 * 6656),
105 | (6656, 6656),
106 | (6656, 2 * 17664),
107 | (17664, 6656)
108 | ],
109 | '70B': [
110 | (8192, 3 * 8192),
111 | (8192, 8192),
112 | (8192, 2 * 21760),
113 | (21760, 8192)
114 | ]
115 | }
116 | for _, layers in MODELS.items():
117 | for layer in layers:
118 | for thread_k, thread_n in [(128, 128)]:
119 | for batch in [1, 16]:
120 | self.run_problem(batch, layer[1], layer[0], thread_k, thread_n)
121 |
122 | def test_errors(self):
123 | print()
124 | m, n, k = 16, 256, 64
125 | A = torch.randn((m, k), dtype=torch.half, device=DEV)
126 | B_ref, B, s = gen_quant4(k, n)
127 | C = torch.zeros((m, n), dtype=torch.half, device=DEV)
128 | workspace = torch.zeros(n // 128, device=DEV)
129 | err = False
130 | try:
131 | marlin.mul(A, B, C, s, workspace, 128, 128, -1)
132 | except:
133 | err = True
134 | self.assertTrue(err)
135 | err = False
136 | try:
137 | marlin.mul(A, B, C, s, workspace, 256, 256, -1)
138 | except:
139 | err = True
140 | self.assertTrue(err)
141 | s = torch.zeros((2, n), dtype=torch.half, device=DEV)
142 | err = False
143 | try:
144 | marlin.mul(A, B, C, s, workspace, 256, 256, -1)
145 | except:
146 | err = True
147 | self.assertTrue(err)
148 |
149 | def test_groups(self):
150 | print()
151 | for m in [16]:
152 | for groupsize in [128]:
153 | for n, k in [(256, 512), (256, 1024), (256 * 128, 1024)]:
154 | for thread_shape in [(128, 128), (64, 256)]:
155 | self.run_problem(m, n, k, *thread_shape, groupsize)
156 |
157 |
158 | if __name__ == '__main__':
159 | unittest.main()
160 |
--------------------------------------------------------------------------------
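
Illustrative note (not part of the repository): the suite above can be run directly with `python test.py` on a machine with the compiled extension and a CUDA device; an equivalent programmatic invocation is sketched below:

    import unittest
    # Runs the Test case defined in test.py (requires marlin_cuda and a CUDA GPU).
    unittest.main(module='test', argv=['test'], exit=False, verbosity=2)
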