├── .gitignore
├── .style.yapf
├── LICENSE.txt
├── README.md
├── convert_llama_weights_to_hf.py
├── gptq.py
├── llama.py
├── llama_inference.py
├── llama_inference_offload.py
├── neox.py
├── opt.py
├── quant
│   ├── __init__.py
│   ├── custom_autotune.py
│   ├── fused_attn.py
│   ├── fused_mlp.py
│   ├── quant_linear.py
│   ├── quantizer.py
│   └── triton_norm.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── datautils.py
    ├── export.py
    └── modelutils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = pep8
3 | column_limit = 200
4 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPTQ-for-LLaMA
2 |
3 | **I am currently focusing on [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) and recommend using [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) instead of GPTQ for Llama.**
4 |
5 |
6 |
7 | 4 bits quantization of [LLaMA](https://arxiv.org/abs/2302.13971) using [GPTQ](https://arxiv.org/abs/2210.17323)
8 |
9 | GPTQ is a SOTA one-shot weight quantization method.
10 |
11 | **It can be used universally, but it is not the [fastest](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/old-cuda) and it only supports Linux.**
12 |
13 | **Triton only supports Linux, so if you are a Windows user, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install).**
14 |
15 | ## News or Update
16 | **AutoGPTQ-triton, a packaged version of GPTQ with triton, has been integrated into [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ).**
17 | ## Result
18 |
19 | LLaMA-7B
20 |
21 | | [LLaMA-7B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
22 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
23 | | FP16 | 16 | - | 13940 | 5.68 | 12.5 |
24 | | RTN | 4 | - | - | 6.29 | - |
25 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 4740 | 6.09 | 3.5 |
26 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 4891 | 5.85 | 3.6 |
27 | | RTN | 3 | - | - | 25.54 | - |
28 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 3852 | 8.07 | 2.7 |
29 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 4116 | 6.61 | 3.0 |
30 |
31 |
32 |
33 |
34 | LLaMA-13B
35 |
36 | | [LLaMA-13B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
37 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
38 | | FP16 | 16 | - | OOM | 5.09 | 24.2 |
39 | | RTN | 4 | - | - | 5.53 | - |
40 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 8410 | 5.36 | 6.5 |
41 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 8747 | 5.20 | 6.7 |
42 | | RTN | 3 | - | - | 11.40 | - |
43 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 6870 | 6.63 | 5.1 |
44 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 7277 | 5.62 | 5.4 |
45 |
46 |
47 |
48 |
49 | LLaMA-33B
50 |
51 | | [LLaMA-33B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
52 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
53 | | FP16 | 16 | - | OOM | 4.10 | 60.5 |
54 | | RTN | 4 | - | - | 4.54 | - |
55 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 19493 | 4.45 | 15.7 |
56 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 20570 | 4.23 | 16.3 |
57 | | RTN | 3 | - | - | 14.89 | - |
58 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 15493 | 5.69 | 12.0 |
59 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 16566 | 4.80 | 13.0 |
60 |
61 |
62 |
63 |
64 | LLaMA-65B
65 |
66 | | [LLaMA-65B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
67 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
68 | | FP16 | 16 | - | OOM | 3.53 | 121.0 |
69 | | RTN | 4 | - | - | 3.92 | - |
70 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | OOM | 3.84 | 31.1 |
71 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | OOM | 3.65 | 32.3 |
72 | | RTN | 3 | - | - | 10.59 | - |
73 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | OOM | 5.04 | 23.6 |
74 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | OOM | 4.17 | 25.6 |
75 |
76 |
77 | Quantization requires a large amount of CPU memory. However, the memory required can be reduced by using swap memory.
78 |
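Before starting a quantization run, it can help to confirm how much free CPU memory and swap is actually available. A minimal, standard-library-only Python sketch (Linux only; values are approximate):

```
# Rough check of free CPU memory and swap before a quantization run
# (Linux only; reads /proc/meminfo directly and reports GiB).
def meminfo_gib(key):
    with open("/proc/meminfo") as f:
        for line in f:
            if line.startswith(key + ":"):
                return int(line.split()[1]) / 1024**2  # /proc/meminfo reports kB
    return 0.0

print(f"MemAvailable: {meminfo_gib('MemAvailable'):.1f} GiB")
print(f"SwapFree:     {meminfo_gib('SwapFree'):.1f} GiB")
```
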
79 | Depending on the GPUs/drivers, there may be a difference in performance, which decreases as the model size increases (see [IST-DASLab/gptq#1](https://github.com/IST-DASLab/gptq/issues/1)).
80 |
81 | According to the [GPTQ paper](https://arxiv.org/abs/2210.17323), the difference in performance between FP16 and GPTQ decreases as the size of the model increases.
82 |
83 | ## GPTQ vs [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
84 |
85 |
86 | LLaMA-7B
87 |
88 | | [LLaMA-7B(seqlen=2048)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) |
89 | | --------------------------------------------------------------- | ------------------- | ----------- | --------- |
90 | | FP16 | 16 | 13948 | 5.22 |
91 | | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 4781 | 5.30 |
92 | | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 4804 | 5.30 |
93 | | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 5102 | 5.30 |
94 | | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 5102 | 5.33 |
95 |
96 |
97 |
98 |
99 | LLaMA-13B
100 |
101 | | [LLaMA-13B(seqlen=2048)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) |
102 | | ---------------------------------------------------------------- | ------------------- | ----------- | --------- |
103 | | FP16 | 16 | OOM | - |
104 | | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 8589 | 5.02 |
105 | | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 8581 | 5.04 |
106 | | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 9170 | 5.04 |
107 | | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 9170 | 5.11 |
108 |
109 |
110 |
111 | LLaMA-33B
112 |
113 | | [LLaMA-33B(seqlen=1024)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) |
114 | | ---------------------------------------------------------------- | ------------------- | ----------- | --------- |
115 | | FP16 | 16 | OOM | - |
116 | | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 18441 | 3.71 |
117 | | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 18313 | 3.76 |
118 | | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 19729 | 3.75 |
119 | | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 19729 | 3.75 |
120 |
121 |
122 |
123 | ## Installation
124 | If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first.
125 | ```
126 | conda create --name gptq python=3.9 -y
127 | conda activate gptq
128 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
129 | # Or, if you're having trouble with conda, use pip with python3.9:
130 | # pip3 install torch torchvision torchaudio
131 |
132 | git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa
133 | cd GPTQ-for-LLaMa
134 | pip install -r requirements.txt
135 | ```
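
As a quick sanity check after installation, you can verify that PyTorch sees the GPU and that Triton is importable (a minimal sketch; it assumes a CUDA-enabled PyTorch build and the packages from `requirements.txt`, including `triton`, are installed):

```
# Verify the quantization environment: CUDA-enabled torch plus triton.
import torch
import triton

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("triton:", triton.__version__)
```
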
136 | ## Dependencies
137 |
138 | * `torch`: tested on v2.0.0+cu117
139 | * `transformers`: tested on v4.28.0.dev0
140 | * `datasets`: tested on v2.10.1
141 | * `safetensors`: tested on v0.3.0
142 |
143 | All experiments were run on a single NVIDIA RTX3090.
144 |
145 | # Language Generation
146 | ## LLaMA
147 |
148 | ```
149 | # Convert LLaMA weights to HF format
150 | python convert_llama_weights_to_hf.py --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir ./llama-hf
151 |
152 | # Benchmark language generation with 4-bit LLaMA-7B:
153 |
154 | # Save compressed model
155 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save llama7b-4bit-128g.pt
156 |
157 | # Or save compressed `.safetensors` model
158 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save_safetensors llama7b-4bit-128g.safetensors
159 |
160 | # Benchmark generating a 2048 token sequence with the saved model
161 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --benchmark 2048 --check
162 |
163 | # Benchmark FP16 baseline, note that the model will be split across all listed GPUs
164 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python llama.py ${MODEL_DIR} c4 --benchmark 2048 --check
165 |
166 | # Model inference with the saved model
167 | CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama"
168 |
169 | # Model inference with the saved model, using safetensors loaded directly to the GPU
170 | CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.safetensors --text "this is llama" --device=0
171 |
172 | # Model inference with the saved model using CPU offloading (this is very slow).
173 | CUDA_VISIBLE_DEVICES=0 python llama_inference_offload.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama" --pre_layer 16
174 | # It takes about 180 seconds to generate 45 tokens (5->50 tokens) of LLaMA-65B on a single RTX3090 with pre_layer set to 50.
175 | ```
176 | In general, 4-bit quantization with a groupsize of 128 is recommended.
177 |
178 | You can also export the quantization parameters in toml+numpy format.
179 | ```
180 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --quant-directory ${TOML_DIR}
181 | ```
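
The saved checkpoint can also be loaded programmatically. Below is a minimal sketch that reuses `load_quant` from `llama_inference.py` in this repo; the model path and checkpoint name are placeholders, and the `wbits`/`groupsize` arguments must match the values used when the checkpoint was saved:

```
import torch
from transformers import AutoTokenizer
from llama_inference import load_quant  # from this repository

MODEL_DIR = "./llama-hf/llama-7b"             # placeholder HF model path
CHECKPOINT = "llama7b-4bit-128g.safetensors"  # placeholder checkpoint name

# Build the quantized model skeleton, load the packed weights, move to GPU.
model = load_quant(MODEL_DIR, CHECKPOINT, wbits=4, groupsize=128).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=False)

input_ids = tokenizer("this is llama", return_tensors="pt").input_ids.to("cuda:0")
with torch.no_grad():
    output_ids = model.generate(input_ids, do_sample=True, max_length=50, top_p=0.95, temperature=0.8)
print(tokenizer.decode(output_ids[0]))
```
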
182 |
183 | # Acknowledgements
184 | This code is based on [GPTQ](https://github.com/IST-DASLab/gptq)
185 |
186 | Thanks to Meta AI for releasing [LLaMA](https://arxiv.org/abs/2302.13971), a powerful LLM.
187 |
188 | Triton GPTQ kernel code is based on [GPTQ-triton](https://github.com/fpgaminer/GPTQ-triton)
189 |
--------------------------------------------------------------------------------
/convert_llama_weights_to_hf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from transformers.models.llama.convert_llama_weights_to_hf import write_model, write_tokenizer
4 |
5 |
6 | def main():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument(
9 | "--input_dir",
10 | help="Location of LLaMA weights, which contains tokenizer.model and model folders",
11 | )
12 | parser.add_argument(
13 | "--model_size",
14 | choices=["7B", "13B", "30B", "65B", "tokenizer_only"],
15 | )
16 | parser.add_argument(
17 | "--output_dir",
18 | help="Location to write HF model and tokenizer",
19 | )
20 | args = parser.parse_args()
21 | if args.model_size != "tokenizer_only":
22 | write_model(
23 | model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
24 | input_base_path=os.path.join(args.input_dir, args.model_size),
25 | model_size=args.model_size,
26 | )
27 | write_tokenizer(
28 | tokenizer_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
29 | input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
30 | )
31 |
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
/gptq.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 | import quant
8 | from texttable import Texttable
9 | from utils import torch_snr_error
10 |
11 | torch.backends.cuda.matmul.allow_tf32 = False
12 | torch.backends.cudnn.allow_tf32 = False
13 |
14 |
15 | class Observer:
16 |
17 | def __init__(self, topk=32):
18 | self.loss_list = []
19 | self.topk = topk
20 |
21 | def submit(self, name: str, layerid: int, gptq, error: float):
22 |
23 | item = (name, layerid, {'gptq': gptq, 'error': error})
24 |
25 | if len(self.loss_list) < self.topk:
26 | self.loss_list.append(item)
27 | return
28 |
29 | min_error = error
30 | min_idx = -1
31 | for idx, data in enumerate(self.loss_list):
32 | if min_error > data[2]['error']:
33 | min_idx = idx
34 | min_error = data[2]['error']
35 |
36 | if min_idx >= 0:
37 | self.loss_list[min_idx] = item
38 |
39 | def print(self):
40 | self.loss_list = sorted(self.loss_list, key=lambda s: s[2]['error'], reverse=True)
41 |
42 | table = Texttable()
43 |
44 | table.header(['name', 'error'])
45 | table.set_cols_dtype(['t', 'f'])
46 |
47 | for item in self.loss_list:
48 | table.add_row([f"{item[0]}.{item[1]}", item[2]['error']])
49 | print(table.draw())
50 | print('\n')
51 |
52 | def items(self):
53 | return self.loss_list
54 |
55 |
56 | class GPTQ:
57 |
58 | def __init__(self, layer, observe=False):
59 | self.layer = layer
60 | self.dev = self.layer.weight.device
61 | W = layer.weight.data.clone()
62 | if isinstance(self.layer, nn.Conv2d):
63 | W = W.flatten(1)
64 | if isinstance(self.layer, transformers.Conv1D):
65 | W = W.t()
66 | self.rows = W.shape[0]
67 | self.columns = W.shape[1]
68 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
69 | self.nsamples = 0
70 | self.quantizer = quant.Quantizer()
71 | self.observe = observe
72 |
73 | def add_batch(self, inp, out):
74 |         # Hessian H = 2 X X^T + λ I
75 | if self.observe:
76 | self.inp1 = inp
77 | self.out1 = out
78 | else:
79 | self.inp1 = None
80 | self.out1 = None
81 |
82 | if len(inp.shape) == 2:
83 | inp = inp.unsqueeze(0)
84 | tmp = inp.shape[0]
85 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
86 | if len(inp.shape) == 3:
87 | inp = inp.reshape((-1, inp.shape[-1]))
88 | inp = inp.t()
89 | if isinstance(self.layer, nn.Conv2d):
90 | unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride)
91 | inp = unfold(inp)
92 | inp = inp.permute([1, 0, 2])
93 | inp = inp.flatten(1)
94 | self.H *= self.nsamples / (self.nsamples + tmp)
95 | self.nsamples += tmp
96 | # inp = inp.float()
97 | inp = math.sqrt(2 / self.nsamples) * inp.float()
98 | # self.H += 2 / self.nsamples * inp.matmul(inp.t())
99 | self.H += inp.matmul(inp.t())
100 |
101 | def print_loss(self, name, q_weight, weight_error, timecost):
102 | table = Texttable()
103 | name += ' ' * (16 - len(name))
104 |
105 | table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time'])
106 |
107 | # assign weight
108 | self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
109 |
110 | if self.inp1 is not None:
111 | # quantize input to int8
112 | quantizer = quant.Quantizer()
113 | quantizer.configure(8, perchannel=False, sym=True, mse=False)
114 | quantizer.find_params(self.inp1)
115 | q_in = quantizer.quantize(self.inp1).type(torch.float16)
116 | q_out = self.layer(q_in)
117 |
118 | # get kinds of SNR
119 | q_SNR = torch_snr_error(q_out, self.out1).item()
120 | fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
121 | else:
122 | q_SNR = '-'
123 | fp_SNR = '-'
124 |
125 | table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
126 | print(table.draw().split('\n')[-2])
127 |
128 | def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''):
129 | self.layer.to(self.dev)
130 |
131 | W = self.layer.weight.data.clone()
132 | if isinstance(self.layer, nn.Conv2d):
133 | W = W.flatten(1)
134 | if isinstance(self.layer, transformers.Conv1D):
135 | W = W.t()
136 | W = W.float()
137 |
138 | tick = time.time()
139 |
140 | if not self.quantizer.ready():
141 | self.quantizer.find_params(W, weight=True)
142 |
143 | H = self.H
144 | if not self.observe:
145 | del self.H
146 | dead = torch.diag(H) == 0
147 | H[dead, dead] = 1
148 | W[:, dead] = 0
149 |
150 | if actorder:
151 | perm = torch.argsort(torch.diag(H), descending=True)
152 | W = W[:, perm]
153 | H = H[perm][:, perm]
154 |
155 | Losses = torch.zeros_like(W)
156 | Q = torch.zeros_like(W)
157 |
158 | damp = percdamp * torch.mean(torch.diag(H))
159 | diag = torch.arange(self.columns, device=self.dev)
160 | H[diag, diag] += damp
161 | H = torch.linalg.cholesky(H)
162 | H = torch.cholesky_inverse(H)
163 | H = torch.linalg.cholesky(H, upper=True)
164 | Hinv = H
165 |
166 | g_idx = []
167 | scale = []
168 | zero = []
169 | now_idx = 1
170 |
171 | for i1 in range(0, self.columns, blocksize):
172 | i2 = min(i1 + blocksize, self.columns)
173 | count = i2 - i1
174 |
175 | W1 = W[:, i1:i2].clone()
176 | Q1 = torch.zeros_like(W1)
177 | Err1 = torch.zeros_like(W1)
178 | Losses1 = torch.zeros_like(W1)
179 | Hinv1 = Hinv[i1:i2, i1:i2]
180 |
181 | for i in range(count):
182 | w = W1[:, i]
183 | d = Hinv1[i, i]
184 |
185 | if groupsize != -1:
186 | if (i1 + i) % groupsize == 0:
187 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
188 |
189 | if ((i1 + i) // groupsize) - now_idx == -1:
190 | scale.append(self.quantizer.scale)
191 | zero.append(self.quantizer.zero)
192 | now_idx += 1
193 |
194 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
195 | Q1[:, i] = q
196 | Losses1[:, i] = (w - q)**2 / d**2
197 |
198 | err1 = (w - q) / d
199 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
200 | Err1[:, i] = err1
201 |
202 | Q[:, i1:i2] = Q1
203 | Losses[:, i1:i2] = Losses1 / 2
204 |
205 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
206 |
207 | torch.cuda.synchronize()
208 | error = torch.sum(Losses).item()
209 |
210 | groupsize = groupsize if groupsize != -1 else self.columns
211 | g_idx = [i // groupsize for i in range(self.columns)]
212 | g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
213 | if actorder:
214 | invperm = torch.argsort(perm)
215 | Q = Q[:, invperm]
216 | g_idx = g_idx[invperm]
217 |
218 | if isinstance(self.layer, transformers.Conv1D):
219 | Q = Q.t()
220 |
221 | self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick))
222 |
223 | if scale == []:
224 | scale.append(self.quantizer.scale)
225 | zero.append(self.quantizer.zero)
226 | scale = torch.cat(scale, dim=1)
227 | zero = torch.cat(zero, dim=1)
228 | return scale, zero, g_idx, error
229 |
230 | def free(self):
231 | self.inp1 = None
232 | self.out1 = None
233 | self.H = None
234 | self.Losses = None
235 | self.Trace = None
236 | torch.cuda.empty_cache()
237 |
--------------------------------------------------------------------------------
/llama.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import quant
7 |
8 | from gptq import GPTQ, Observer
9 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
10 | from texttable import Texttable
11 |
12 |
13 | def get_llama(model):
14 |
15 | def skip(*args, **kwargs):
16 | pass
17 |
18 | torch.nn.init.kaiming_uniform_ = skip
19 | torch.nn.init.uniform_ = skip
20 | torch.nn.init.normal_ = skip
21 | from transformers import LlamaForCausalLM
22 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
23 | model.seqlen = 2048
24 | return model
25 |
26 |
27 | @torch.no_grad()
28 | def llama_sequential(model, dataloader, dev):
29 | print('Starting ...')
30 |
31 | use_cache = model.config.use_cache
32 | model.config.use_cache = False
33 | layers = model.model.layers
34 |
35 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
36 | model.model.norm = model.model.norm.to(dev)
37 | layers[0] = layers[0].to(dev)
38 |
39 | dtype = next(iter(model.parameters())).dtype
40 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
41 | cache = {'i': 0, 'attention_mask': None}
42 |
43 | class Catcher(nn.Module):
44 |
45 | def __init__(self, module):
46 | super().__init__()
47 | self.module = module
48 |
49 | def forward(self, inp, **kwargs):
50 | inps[cache['i']] = inp
51 | cache['i'] += 1
52 | cache['attention_mask'] = kwargs['attention_mask']
53 | cache['position_ids'] = kwargs['position_ids']
54 | raise ValueError
55 |
56 | layers[0] = Catcher(layers[0])
57 | for batch in dataloader:
58 | try:
59 | model(batch[0].to(dev))
60 | except ValueError:
61 | pass
62 | layers[0] = layers[0].module
63 |
64 | layers[0] = layers[0].cpu()
65 | model.model.embed_tokens = model.model.embed_tokens.cpu()
66 | model.model.norm = model.model.norm.cpu()
67 | torch.cuda.empty_cache()
68 |
69 | outs = torch.zeros_like(inps)
70 | attention_mask = cache['attention_mask']
71 | position_ids = cache['position_ids']
72 |
73 | print('Ready.')
74 |
75 | quantizers = {}
76 | observer = Observer()
77 | for i in range(len(layers)):
78 |
79 | print(f'Quantizing layer {i+1}/{len(layers)}..')
80 | print('+------------------+--------------+------------+-----------+-------+')
81 | print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
82 | print('+==================+==============+============+===========+=======+')
83 |
84 | layer = layers[i].to(dev)
85 | full = find_layers(layer)
86 | if args.true_sequential:
87 | sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']]
88 | else:
89 | sequential = [list(full.keys())]
90 |
91 | for names in sequential:
92 | subset = {n: full[n] for n in names}
93 | gptq = {}
94 | for name in subset:
95 | gptq[name] = GPTQ(subset[name], observe=args.observe)
96 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
97 |
98 | def add_batch(name):
99 |
100 | def tmp(_, inp, out):
101 | gptq[name].add_batch(inp[0].data, out.data)
102 |
103 | return tmp
104 |
105 | handles = []
106 | for name in subset:
107 | handles.append(subset[name].register_forward_hook(add_batch(name)))
108 | for j in range(args.nsamples):
109 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
110 | for h in handles:
111 | h.remove()
112 |
113 | for name in subset:
114 | scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name)
115 | quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize)
116 |
117 | if args.observe:
118 | observer.submit(name=name, layerid=i, gptq=gptq[name], error=error)
119 | else:
120 | gptq[name].free()
121 |
122 | for j in range(args.nsamples):
123 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
124 |
125 | layers[i] = layer.cpu()
126 | del layer
127 | del gptq
128 | torch.cuda.empty_cache()
129 |
130 | inps, outs = outs, inps
131 | print('+------------------+--------------+------------+-----------+-------+')
132 | print('\n')
133 |
134 | if args.observe:
135 | observer.print()
136 | conditions = gen_conditions(args.wbits, args.groupsize)
137 | for item in observer.items():
138 | name = item[0]
139 | layerid = item[1]
140 | gptq = item[2]['gptq']
141 | error = item[2]['error']
142 | target = error / 2
143 |
144 | table = Texttable()
145 | table.header(['wbits', 'groupsize', 'error'])
146 | table.set_cols_dtype(['i', 'i', 'f'])
147 | table.add_row([args.wbits, args.groupsize, error])
148 |
149 | print('Optimizing {} {} ..'.format(name, layerid))
150 | for wbits, groupsize in conditions:
151 |
152 | if error < target:
153 | # if error dropped 50%, skip
154 | break
155 |
156 | gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False)
157 |
158 | scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)
159 |
160 | table.add_row([wbits, groupsize, error])
161 | quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize)
162 |
163 | print(table.draw())
164 | print('\n')
165 | gptq.layer.to('cpu')
166 | gptq.free()
167 |
168 | model.config.use_cache = use_cache
169 |
170 | return quantizers
171 |
172 |
173 | @torch.no_grad()
174 | def llama_eval(model, testenc, dev):
175 | print('Evaluating ...')
176 |
177 | testenc = testenc.input_ids
178 | nsamples = testenc.numel() // model.seqlen
179 |
180 | use_cache = model.config.use_cache
181 | model.config.use_cache = False
182 | layers = model.model.layers
183 |
184 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
185 | layers[0] = layers[0].to(dev)
186 |
187 | dtype = next(iter(model.parameters())).dtype
188 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
189 | cache = {'i': 0, 'attention_mask': None}
190 |
191 | class Catcher(nn.Module):
192 |
193 | def __init__(self, module):
194 | super().__init__()
195 | self.module = module
196 |
197 | def forward(self, inp, **kwargs):
198 | inps[cache['i']] = inp
199 | cache['i'] += 1
200 | cache['attention_mask'] = kwargs['attention_mask']
201 | cache['position_ids'] = kwargs['position_ids']
202 | raise ValueError
203 |
204 | layers[0] = Catcher(layers[0])
205 | for i in range(nsamples):
206 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
207 | try:
208 | model(batch)
209 | except ValueError:
210 | pass
211 | layers[0] = layers[0].module
212 |
213 | layers[0] = layers[0].cpu()
214 | model.model.embed_tokens = model.model.embed_tokens.cpu()
215 | torch.cuda.empty_cache()
216 |
217 | outs = torch.zeros_like(inps)
218 | attention_mask = cache['attention_mask']
219 | position_ids = cache['position_ids']
220 |
221 | for i in range(len(layers)):
222 | print(i)
223 | layer = layers[i].to(dev)
224 |
225 | if args.nearest:
226 | subset = find_layers(layer)
227 | for name in subset:
228 | quantizer = quant.Quantizer()
229 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
230 | W = subset[name].weight.data
231 | quantizer.find_params(W, weight=True)
232 | subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
233 |
234 | for j in range(nsamples):
235 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
236 | layers[i] = layer.cpu()
237 | del layer
238 | torch.cuda.empty_cache()
239 | inps, outs = outs, inps
240 |
241 | if model.model.norm is not None:
242 | model.model.norm = model.model.norm.to(dev)
243 | model.lm_head = model.lm_head.to(dev)
244 |
245 | testenc = testenc.to(dev)
246 | nlls = []
247 | for i in range(nsamples):
248 | hidden_states = inps[i].unsqueeze(0)
249 | if model.model.norm is not None:
250 | hidden_states = model.model.norm(hidden_states)
251 | lm_logits = model.lm_head(hidden_states)
252 | shift_logits = lm_logits[:, :-1, :].contiguous()
253 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
254 | loss_fct = nn.CrossEntropyLoss()
255 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
256 | neg_log_likelihood = loss.float() * model.seqlen
257 | nlls.append(neg_log_likelihood)
258 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
259 | print(ppl.item())
260 |
261 | model.config.use_cache = use_cache
262 |
263 |
264 | # TODO: perform packing on GPU
265 | def llama_pack(model, quantizers, wbits, groupsize):
266 | layers = find_layers(model)
267 | layers = {n: layers[n] for n in quantizers}
268 | quant.make_quant_linear(model, quantizers, wbits, groupsize)
269 | qlayers = find_layers(model, [quant.QuantLinear])
270 | print('Packing ...')
271 | for name in qlayers:
272 | print(name)
273 | quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
274 | qlayers[name].pack(layers[name], scale, zero, g_idx)
275 | print('Done.')
276 | return model
277 |
278 |
279 | def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
280 | from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
281 | config = LlamaConfig.from_pretrained(model)
282 |
283 | def noop(*args, **kwargs):
284 | pass
285 |
286 | torch.nn.init.kaiming_uniform_ = noop
287 | torch.nn.init.uniform_ = noop
288 | torch.nn.init.normal_ = noop
289 |
290 | torch.set_default_dtype(torch.half)
291 | modeling_utils._init_weights = False
292 | torch.set_default_dtype(torch.half)
293 | model = LlamaForCausalLM(config)
294 | torch.set_default_dtype(torch.float)
295 | if eval:
296 | model = model.eval()
297 | layers = find_layers(model)
298 | for name in ['lm_head']:
299 | if name in layers:
300 | del layers[name]
301 | quant.make_quant_linear(model, layers, wbits, groupsize)
302 |
303 | del layers
304 |
305 | print('Loading model ...')
306 | if checkpoint.endswith('.safetensors'):
307 | from safetensors.torch import load_file as safe_load
308 | model.load_state_dict(safe_load(checkpoint))
309 | else:
310 | model.load_state_dict(torch.load(checkpoint))
311 |
312 | if eval:
313 | quant.make_quant_attn(model)
314 | quant.make_quant_norm(model)
315 | if fused_mlp:
316 | quant.make_fused_mlp(model)
317 |
318 | if warmup_autotune:
319 | quant.autotune_warmup_linear(model, transpose=not (eval))
320 | if eval and fused_mlp:
321 | quant.autotune_warmup_fused(model)
322 | model.seqlen = 2048
323 | print('Done.')
324 |
325 | return model
326 |
327 |
328 | def llama_multigpu(model, gpus, gpu_dist):
329 | model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
330 | if hasattr(model.model, 'norm') and model.model.norm:
331 | model.model.norm = model.model.norm.to(gpus[0])
332 | import copy
333 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
334 |
335 | cache = {'mask': None, 'position_ids': None}
336 |
337 | class MoveModule(nn.Module):
338 |
339 | def __init__(self, module, invalidate_cache):
340 | super().__init__()
341 | self.module = module
342 | self.dev = next(iter(self.module.parameters())).device
343 | self.invalidate_cache=invalidate_cache
344 |
345 | def forward(self, *inp, **kwargs):
346 | inp = list(inp)
347 | if inp[0].device != self.dev:
348 | inp[0] = inp[0].to(self.dev)
349 |
350 | if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
351 | cache['mask'] = kwargs['attention_mask'].to(self.dev)
352 | kwargs['attention_mask'] = cache['mask']
353 |
354 | if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
355 | cache['position_ids'] = kwargs['position_ids'].to(self.dev)
356 | kwargs['position_ids'] = cache['position_ids']
357 |
358 | tmp = self.module(*inp, **kwargs)
359 | return tmp
360 |
361 | layers = model.model.layers
362 | from math import ceil
363 | if not gpu_dist:
364 | pergpu = ceil(len(layers) / len(gpus))
365 | for i in range(len(layers)):
366 | layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) -1 else gpus[(i-1) // pergpu]), i==0)
367 | else:
368 | assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0."
369 | assigned_gpus = [0] * (gpu_dist[0]-1)
370 | for i in range(1, len(gpu_dist)):
371 | assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
372 |
373 | remaining_assignments = len(layers)-len(assigned_gpus) - 1
374 | if remaining_assignments > 0:
375 | assigned_gpus = assigned_gpus + [-1] * remaining_assignments
376 |
377 | assigned_gpus = assigned_gpus + [0]
378 |
379 | for i in range(len(layers)):
380 | layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
381 |
382 | model.gpus = gpus
383 |
384 |
385 | def benchmark(model, input_ids, check=False):
386 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
387 | torch.cuda.synchronize()
388 |
389 | cache = {'past': None}
390 |
391 | def clear_past(i):
392 |
393 | def tmp(layer, inp, out):
394 | if cache['past']:
395 | cache['past'][i] = None
396 |
397 | return tmp
398 |
399 | for i, layer in enumerate(model.model.layers):
400 | layer.register_forward_hook(clear_past(i))
401 |
402 | print('Benchmarking ...')
403 |
404 | if check:
405 | loss = nn.CrossEntropyLoss()
406 | tot = 0.
407 |
408 | def sync():
409 | if hasattr(model, 'gpus'):
410 | for gpu in model.gpus:
411 | torch.cuda.synchronize(gpu)
412 | else:
413 | torch.cuda.synchronize()
414 |
415 | max_memory = 0
416 | with torch.no_grad():
417 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
418 | times = []
419 | for i in range(input_ids.numel()):
420 | tick = time.time()
421 | out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
422 | sync()
423 | times.append(time.time() - tick)
424 | print(i, times[-1])
425 | if hasattr(model, 'gpus'):
426 | mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
427 | else:
428 | mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
429 | max_memory = max(max_memory, mem_allocated)
430 | if check and i != input_ids.numel() - 1:
431 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
432 | cache['past'] = list(out.past_key_values)
433 | del out
434 | sync()
435 | print('Median:', np.median(times))
436 | if check:
437 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
438 | print('max memory(MiB):', max_memory)
439 |
440 |
441 | if __name__ == '__main__':
442 |
443 | parser = argparse.ArgumentParser()
444 |
445 | parser.add_argument('model', type=str, help='llama model to load')
446 | parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
447 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
448 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
449 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
450 | parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
451 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
452 | parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
453 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
454 |     parser.add_argument('--eval', action='store_true', help='Evaluate the quantized model.')
455 |     parser.add_argument('--test-generation', action='store_true', help='Test generation.')
456 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
457 | parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
458 | parser.add_argument('--load', type=str, default='', help='Load quantized model.')
459 | parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
460 | parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
461 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
462 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
463 |     parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential mode.')
464 | parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
465 | parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.')
466 | parser.add_argument('--observe',
467 | action='store_true',
468 |                         help='Automatically upgrade layer precision to a higher precision, for example int2 to int4 or groupsize 128 to 64. \
469 |                         When this feature is enabled, `--save` or `--save_safetensors` will be disabled.')
470 |     parser.add_argument('--quant-directory', type=str, default=None, help='Directory to export quantization parameters to in toml format. `None` (the default) means no export.')
471 |
472 | args = parser.parse_args()
473 |
474 | if args.layers_dist:
475 | gpu_dist = [int(x) for x in args.layers_dist.split(':')]
476 | else:
477 | gpu_dist = []
478 |
479 | if type(args.load) is not str:
480 | args.load = args.load.as_posix()
481 |
482 | if args.load:
483 | model = load_quant(args.model, args.load, args.wbits, args.groupsize)
484 | else:
485 | model = get_llama(args.model)
486 | model.eval()
487 |
488 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
489 |
490 | if not args.load and args.wbits < 16 and not args.nearest:
491 | tick = time.time()
492 | quantizers = llama_sequential(model, dataloader, DEV)
493 | print(time.time() - tick)
494 |
495 | if args.benchmark:
496 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
497 | if len(gpus) > 1:
498 | llama_multigpu(model, gpus, gpu_dist)
499 | else:
500 | model = model.to(DEV)
501 | if args.benchmark:
502 | input_ids = next(iter(dataloader))[0][:, :args.benchmark]
503 | benchmark(model, input_ids, check=args.check)
504 |
505 | if args.eval:
506 | datasets = ['wikitext2', 'ptb', 'c4']
507 | if args.new_eval:
508 | datasets = ['wikitext2', 'ptb-new', 'c4-new']
509 | for dataset in datasets:
510 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
511 | print(dataset)
512 | llama_eval(model, testloader, DEV)
513 |
514 | if args.test_generation:
515 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
516 | if len(gpus) > 1:
517 | llama_multigpu(model, gpus, gpu_dist)
518 | else:
519 | model = model.to(DEV)
520 |
521 | from transformers import LlamaTokenizer, TextStreamer
522 | tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False)
523 | input_ids = tokenizer(["The capital of New Mexico is"], return_tensors="pt").input_ids.to(gpus[0])
524 | streamer = TextStreamer(tokenizer)
525 | with torch.no_grad():
526 | generated_ids = model.generate(input_ids, streamer=streamer)
527 |
528 |
529 |
530 | if args.quant_directory is not None:
531 | export_quant_table(quantizers, args.quant_directory)
532 |
533 | if not args.observe and args.save:
534 | llama_pack(model, quantizers, args.wbits, args.groupsize)
535 | torch.save(model.state_dict(), args.save)
536 |
537 | if not args.observe and args.save_safetensors:
538 | llama_pack(model, quantizers, args.wbits, args.groupsize)
539 | from safetensors.torch import save_file as safe_save
540 | state_dict = model.state_dict()
541 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
542 | safe_save(state_dict, args.save_safetensors)
543 |
--------------------------------------------------------------------------------
/llama_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.nn as nn
5 | import quant
6 |
7 | from gptq import GPTQ
8 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
9 | import transformers
10 | from transformers import AutoTokenizer
11 |
12 |
13 | def get_llama(model):
14 |
15 | def skip(*args, **kwargs):
16 | pass
17 |
18 | torch.nn.init.kaiming_uniform_ = skip
19 | torch.nn.init.uniform_ = skip
20 | torch.nn.init.normal_ = skip
21 | from transformers import LlamaForCausalLM
22 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
23 | model.seqlen = 2048
24 | return model
25 |
26 |
27 | def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
28 | from transformers import LlamaConfig, LlamaForCausalLM
29 | config = LlamaConfig.from_pretrained(model)
30 |
31 | def noop(*args, **kwargs):
32 | pass
33 |
34 | torch.nn.init.kaiming_uniform_ = noop
35 | torch.nn.init.uniform_ = noop
36 | torch.nn.init.normal_ = noop
37 |
38 | torch.set_default_dtype(torch.half)
39 | transformers.modeling_utils._init_weights = False
40 | torch.set_default_dtype(torch.half)
41 | model = LlamaForCausalLM(config)
42 | torch.set_default_dtype(torch.float)
43 | if eval:
44 | model = model.eval()
45 | layers = find_layers(model)
46 | for name in ['lm_head']:
47 | if name in layers:
48 | del layers[name]
49 | quant.make_quant_linear(model, layers, wbits, groupsize)
50 |
51 | del layers
52 |
53 | print('Loading model ...')
54 | if checkpoint.endswith('.safetensors'):
55 | from safetensors.torch import load_file as safe_load
56 | model.load_state_dict(safe_load(checkpoint), strict=False)
57 | else:
58 | model.load_state_dict(torch.load(checkpoint), strict=False)
59 |
60 | if eval:
61 | quant.make_quant_attn(model)
62 | quant.make_quant_norm(model)
63 | if fused_mlp:
64 | quant.make_fused_mlp(model)
65 | if warmup_autotune:
66 | quant.autotune_warmup_linear(model, transpose=not (eval))
67 | if eval and fused_mlp:
68 | quant.autotune_warmup_fused(model)
69 | model.seqlen = 2048
70 | print('Done.')
71 |
72 | return model
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 | parser = argparse.ArgumentParser()
78 |
79 | parser.add_argument('model', type=str, help='llama model to load')
80 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
81 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
82 | parser.add_argument('--load', type=str, default='', help='Load quantized model.')
83 |
84 | parser.add_argument('--text', type=str, help='input text')
85 |
86 | parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')
87 |
88 | parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')
89 |
90 | parser.add_argument('--top_p',
91 | type=float,
92 | default=0.95,
93 | help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')
94 |
95 | parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.')
96 |
97 |     parser.add_argument('--device', type=int, default=-1, help='The device used to load the model when using safetensors. The default is "cpu"; specify 0, 1, 2, 3, ... for a GPU device.')
98 |
99 |     # Fused MLP sometimes does not work with safetensors; --no_fused_mlp sets fused_mlp to False (default: True).
100 | parser.add_argument('--fused_mlp', action='store_true')
101 | parser.add_argument('--no_fused_mlp', dest='fused_mlp', action='store_false')
102 | parser.set_defaults(fused_mlp=True)
103 |
104 | args = parser.parse_args()
105 |
106 | if type(args.load) is not str:
107 | args.load = args.load.as_posix()
108 |
109 | if args.load:
110 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, fused_mlp=args.fused_mlp)
111 | else:
112 | model = get_llama(args.model)
113 | model.eval()
114 |
115 | model.to(DEV)
116 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
117 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)
118 |
119 | with torch.no_grad():
120 | generated_ids = model.generate(
121 | input_ids,
122 | do_sample=True,
123 | min_length=args.min_length,
124 | max_length=args.max_length,
125 | top_p=args.top_p,
126 | temperature=args.temperature,
127 | )
128 | print(tokenizer.decode([el.item() for el in generated_ids[0]]))
129 |
--------------------------------------------------------------------------------
/llama_inference_offload.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from gptq import GPTQ
5 | import argparse
6 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
7 | import quant
8 |
9 | import transformers
10 | from transformers import AutoTokenizer
11 | from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
12 | from transformers.modeling_outputs import BaseModelOutputWithPast
13 | from typing import List, Optional, Tuple, Union
14 | from accelerate import cpu_offload_with_hook, load_checkpoint_in_model
15 |
16 |
17 | class Offload_LlamaModel(LlamaModel):
18 |
19 | def __init__(self, config: LlamaConfig):
20 | super().__init__(config)
21 |
22 | def cpu_offload(self, preload):
23 | hook = None
24 | for cpu_offloaded_model in self.layers[preload:]:
25 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, DEV, prev_module_hook=hook)
26 |
27 | def forward(
28 | self,
29 | input_ids: torch.LongTensor = None,
30 | attention_mask: Optional[torch.Tensor] = None,
31 | position_ids: Optional[torch.LongTensor] = None,
32 | past_key_values: Optional[List[torch.FloatTensor]] = None,
33 | inputs_embeds: Optional[torch.FloatTensor] = None,
34 | use_cache: Optional[bool] = None,
35 | output_attentions: Optional[bool] = None,
36 | output_hidden_states: Optional[bool] = None,
37 | return_dict: Optional[bool] = None,
38 | ) -> Union[Tuple, BaseModelOutputWithPast]:
39 | r"""
40 | Args:
41 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
42 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
43 | provide it.
44 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
45 | [`PreTrainedTokenizer.__call__`] for details.
46 | [What are input IDs?](../glossary#input-ids)
47 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
48 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
49 | - 1 for tokens that are **not masked**,
50 | - 0 for tokens that are **masked**.
51 | [What are attention masks?](../glossary#attention-mask)
52 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
53 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
54 | `[0, config.n_positions - 1]`.
55 | [What are position IDs?](../glossary#position-ids)
56 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
57 |                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
58 |                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
59 |                 Contains pre-computed hidden-states (key and value tensors of the self-attention blocks) that can
60 |                 be used (see `past_key_values` input) to speed up sequential decoding.
61 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
62 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
63 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
64 | use_cache (`bool`, *optional*):
65 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
66 | (see `past_key_values`).
67 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
68 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
69 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors
70 | than the model's internal embedding lookup matrix.
71 | output_attentions (`bool`, *optional*):
72 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
73 | returned tensors for more detail.
74 | output_hidden_states (`bool`, *optional*):
75 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
76 | for more detail.
77 | return_dict (`bool`, *optional*):
78 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
79 | """
80 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
81 | output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
82 | use_cache = use_cache if use_cache is not None else self.config.use_cache
83 |
84 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
85 |
86 | # retrieve input_ids and inputs_embeds
87 | if input_ids is not None and inputs_embeds is not None:
88 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
89 | elif input_ids is not None:
90 | batch_size, seq_length = input_ids.shape
91 | elif inputs_embeds is not None:
92 | batch_size, seq_length, _ = inputs_embeds.shape
93 | else:
94 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
95 | seq_length_with_past = seq_length
96 | past_key_values_length = 0
97 | if past_key_values is not None:
98 | past_key_values_length = past_key_values[0][0].shape[2]
99 | seq_length_with_past = seq_length_with_past + past_key_values_length
100 |
101 | if position_ids is None:
102 | device = input_ids.device if input_ids is not None else inputs_embeds.device
103 | position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
104 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
105 | else:
106 | position_ids = position_ids.view(-1, seq_length).long()
107 |
108 | if inputs_embeds is None:
109 | inputs_embeds = self.embed_tokens(input_ids)
110 |
111 | # embed positions
112 | if attention_mask is None:
113 | attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
114 | attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)
115 |
116 | hidden_states = inputs_embeds
117 |
118 | if self.gradient_checkpointing and self.training:
119 | if use_cache:
120 |                 print("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
121 | use_cache = False
122 |
123 | # decoder layers
124 | all_hidden_states = () if output_hidden_states else None
125 | all_self_attns = () if output_attentions else None
126 | next_decoder_cache = () if use_cache else None
127 |
128 | for idx in range(len(self.layers)):
129 | decoder_layer = self.layers[idx]
130 |
131 | if output_hidden_states:
132 | all_hidden_states += (hidden_states, )
133 |
134 | past_key_value = past_key_values[idx] if past_key_values is not None else None
135 |
136 | if self.gradient_checkpointing and self.training:
137 |
138 | def create_custom_forward(module):
139 |
140 | def custom_forward(*inputs):
141 | # None for past_key_value
142 | return module(*inputs, output_attentions, None)
143 |
144 | return custom_forward
145 |
146 | layer_outputs = torch.utils.checkpoint.checkpoint(
147 | create_custom_forward(decoder_layer),
148 | hidden_states,
149 | attention_mask,
150 | position_ids,
151 | None,
152 | )
153 | else:
154 | layer_outputs = decoder_layer(
155 | hidden_states,
156 | attention_mask=attention_mask,
157 | position_ids=position_ids,
158 | past_key_value=past_key_value,
159 | output_attentions=output_attentions,
160 | use_cache=use_cache,
161 | )
162 |
163 | hidden_states = layer_outputs[0]
164 |
165 | if use_cache:
166 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1], )
167 |
168 | if output_attentions:
169 | all_self_attns += (layer_outputs[1], )
170 |
171 | hidden_states = self.norm(hidden_states)
172 |
173 | # add hidden states from the last decoder layer
174 | if output_hidden_states:
175 | all_hidden_states += (hidden_states, )
176 |
177 | next_cache = next_decoder_cache if use_cache else None
178 | if not return_dict:
179 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
180 | return BaseModelOutputWithPast(
181 | last_hidden_state=hidden_states,
182 | past_key_values=next_cache,
183 | hidden_states=all_hidden_states,
184 | attentions=all_self_attns,
185 | )
186 |
187 |
188 | def load_quant(model, checkpoint, wbits, groupsize, pre_layer, fused_mlp=True, warmup_autotune=True):
189 | transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel
190 | from transformers import LlamaConfig, LlamaForCausalLM
191 | config = LlamaConfig.from_pretrained(model)
192 |
193 | def noop(*args, **kwargs):
194 | pass
195 |
196 | torch.nn.init.kaiming_uniform_ = noop
197 | torch.nn.init.uniform_ = noop
198 | torch.nn.init.normal_ = noop
199 |
200 |     # build the fp16 model skeleton without running the (patched-out) weight initializers
201 |     torch.set_default_dtype(torch.half)
202 |     transformers.modeling_utils._init_weights = False
203 | model = LlamaForCausalLM(config)
204 | torch.set_default_dtype(torch.float)
205 | model = model.eval()
206 | layers = find_layers(model)
207 | for name in ['lm_head']:
208 | if name in layers:
209 | del layers[name]
210 | quant.make_quant_linear(model, layers, wbits, groupsize)
211 |
212 | print('Loading model ...')
213 | load_checkpoint_in_model(model, checkpoint, dtype='float16')
214 | model.seqlen = 2048
215 |
216 |     # this script is inference-only, so the fused attention and norm kernels are always applied
217 |     quant.make_quant_attn(model)
218 |     quant.make_quant_norm(model)
219 |     if fused_mlp:
220 |         quant.make_fused_mlp(model)
221 |
222 |
223 | if warmup_autotune:
224 | quant.autotune_warmup_linear(model)
225 | if fused_mlp:
226 | quant.autotune_warmup_fused(model)
227 |
228 | for i in range(pre_layer):
229 | model.model.layers[i].to(DEV)
230 | model.model.embed_tokens.to(DEV)
231 | model.model.norm.to(DEV)
232 | model.lm_head.to(DEV)
233 | model.model.cpu_offload(pre_layer)
234 | print('Done.')
235 | return model
236 |
237 |
238 | if __name__ == '__main__':
239 | parser = argparse.ArgumentParser()
240 |
241 | parser.add_argument('model', type=str, help='llama model to load')
242 | parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8], help='#bits to use for quantization')
243 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
244 | parser.add_argument('--load', type=str, default='', help='Load quantized model.')
245 | parser.add_argument('--text', type=str, help='input text')
246 |
247 | parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')
248 |
249 | parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')
250 |
251 | parser.add_argument('--top_p',
252 | type=float,
253 | default=0.95,
254 | help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')
255 |
256 | parser.add_argument('--temperature', type=float, default=0.8, help='The value used to modulate the next-token probabilities.')
257 |
258 | parser.add_argument('--pre_layer', type=int, default=50, help='Number of leading decoder layers to keep on the GPU; the remaining layers are offloaded to the CPU.')
259 |
260 | args = parser.parse_args()
261 |
262 | if type(args.load) is not str:
263 | args.load = args.load.as_posix()
264 |
265 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer)
266 |
267 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
268 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)
269 |
270 | with torch.no_grad():
271 | generated_ids = model.generate(
272 | input_ids,
273 | do_sample=True,
274 | min_length=args.min_length,
275 | max_length=args.max_length,
276 | top_p=args.top_p,
277 | temperature=args.temperature,
278 | )
279 | print(tokenizer.decode([el.item() for el in generated_ids[0]]))
280 |
--------------------------------------------------------------------------------
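
The interesting part of this file is `Offload_LlamaModel.cpu_offload`, which keeps the first `pre_layer` decoder layers resident on the GPU and lets accelerate shuttle the remaining layers in and out on demand. A minimal sketch of that chained-hook pattern, using toy modules and an illustrative device string (not code from this repository):

import torch.nn as nn
from accelerate import cpu_offload_with_hook

blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))   # stand-in for decoder layers
preload = 2                                                   # analogous to --pre_layer
hook = None
for block in blocks[preload:]:
    # Each hooked block is copied to the GPU right before its forward pass; chaining via
    # prev_module_hook moves the previous block back to the CPU as the next one runs.
    _, hook = cpu_offload_with_hook(block, 'cuda:0', prev_module_hook=hook)
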
/neox.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import quant
7 |
8 | from gptq import GPTQ, Observer
9 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
10 | from texttable import Texttable
11 |
12 |
13 | def get_neox(model, seqlen=-1):
14 |
15 | def skip(*args, **kwargs):
16 | pass
17 |
18 | torch.nn.init.kaiming_uniform_ = skip
19 | torch.nn.init.uniform_ = skip
20 | torch.nn.init.normal_ = skip
21 | from transformers import GPTNeoXForCausalLM
22 | model = GPTNeoXForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
23 | model.seqlen = seqlen if seqlen != -1 else model.config.max_position_embeddings
24 | return model
25 |
26 |
27 | @torch.no_grad()
28 | def neox_sequential(model, dataloader, dev):
29 | print('Starting ...')
30 |
31 | use_cache = model.config.use_cache
32 | model.config.use_cache = False
33 | layers = model.gpt_neox.layers
34 |
35 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev)
36 | layers[0] = layers[0].to(dev)
37 |
38 | dtype = next(iter(model.parameters())).dtype
39 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
40 | cache = {'i': 0, 'attention_mask': None}
41 |
42 | class Catcher(nn.Module):
43 |
44 | def __init__(self, module):
45 | super().__init__()
46 | self.module = module
47 |
48 | def forward(self, inp, **kwargs):
49 | inps[cache['i']] = inp
50 | cache['i'] += 1
51 | cache['attention_mask'] = kwargs['attention_mask']
52 | cache['position_ids'] = kwargs['position_ids']
53 | raise ValueError
54 |
55 | layers[0] = Catcher(layers[0])
56 | for batch in dataloader:
57 | try:
58 | model(batch[0].to(dev))
59 | except ValueError:
60 | pass
61 | layers[0] = layers[0].module
62 |
63 | layers[0] = layers[0].cpu()
64 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu()
65 | torch.cuda.empty_cache()
66 |
67 | outs = torch.zeros_like(inps)
68 | attention_mask = cache['attention_mask']
69 | position_ids = cache['position_ids']
70 |
71 | print('Ready.')
72 |
73 | quantizers = {}
74 | observer = Observer()
75 | for i in range(len(layers)):
76 |
77 | print(f'Quantizing layer {i+1}/{len(layers)}..')
78 | print('+------------------+--------------+------------+-----------+-------+')
79 | print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |')
80 | print('+==================+==============+============+===========+=======+')
81 |
82 | layer = layers[i].to(dev)
83 | full = find_layers(layer)
84 | sequential = [list(full.keys())]
85 |
86 | for names in sequential:
87 | subset = {n: full[n] for n in names}
88 | gptq = {}
89 | for name in subset:
90 | gptq[name] = GPTQ(subset[name], observe=False)
91 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
92 |
93 | def add_batch(name):
94 |
95 | def tmp(_, inp, out):
96 | gptq[name].add_batch(inp[0].data, out.data)
97 |
98 | return tmp
99 |
100 | handles = []
101 | for name in subset:
102 | handles.append(subset[name].register_forward_hook(add_batch(name)))
103 | for j in range(args.nsamples):
104 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
105 | for h in handles:
106 | h.remove()
107 |
108 | for name in subset:
109 | scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name)
110 | quantizers['gpt_neox.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize)
111 | gptq[name].free()
112 |
113 | for j in range(args.nsamples):
114 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
115 |
116 | layers[i] = layer.cpu()
117 | del layer
118 | del gptq
119 | torch.cuda.empty_cache()
120 |
121 | inps, outs = outs, inps
122 | print('+------------------+--------------+------------+-----------+-------+')
123 | print('\n')
124 |
125 | model.config.use_cache = use_cache
126 |
127 | return quantizers
128 |
129 |
130 | @torch.no_grad()
131 | def neox_eval(model, testenc, dev):
132 | print('Evaluating ...')
133 |
134 | testenc = testenc.input_ids
135 | nsamples = testenc.numel() // model.seqlen
136 |
137 | use_cache = model.config.use_cache
138 | model.config.use_cache = False
139 | layers = model.gpt_neox.layers
140 |
141 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev)
142 | layers[0] = layers[0].to(dev)
143 |
144 | dtype = next(iter(model.parameters())).dtype
145 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
146 | cache = {'i': 0, 'attention_mask': None}
147 |
148 | class Catcher(nn.Module):
149 |
150 | def __init__(self, module):
151 | super().__init__()
152 | self.module = module
153 |
154 | def forward(self, inp, **kwargs):
155 | inps[cache['i']] = inp
156 | cache['i'] += 1
157 | cache['attention_mask'] = kwargs['attention_mask']
158 | cache['position_ids'] = kwargs['position_ids']
159 | raise ValueError
160 |
161 | layers[0] = Catcher(layers[0])
162 | for i in range(nsamples):
163 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
164 | try:
165 | model(batch)
166 | except ValueError:
167 | pass
168 | layers[0] = layers[0].module
169 |
170 | layers[0] = layers[0].cpu()
171 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu()
172 | torch.cuda.empty_cache()
173 |
174 | outs = torch.zeros_like(inps)
175 | attention_mask = cache['attention_mask']
176 | position_ids = cache['position_ids']
177 |
178 | for i in range(len(layers)):
179 | print(i)
180 | layer = layers[i].to(dev)
181 |
182 | if args.nearest:
183 | subset = find_layers(layer)
184 | for name in subset:
185 | quantizer = quant.Quantizer()
186 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
187 | W = subset[name].weight.data
188 | quantizer.find_params(W, weight=True)
189 | subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
190 |
191 | for j in range(nsamples):
192 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
193 | layers[i] = layer.cpu()
194 | del layer
195 | torch.cuda.empty_cache()
196 | inps, outs = outs, inps
197 |
198 | model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(dev)
199 | model.embed_out = model.embed_out.to(dev)
200 |
201 | testenc = testenc.to(dev)
202 | nlls = []
203 | for i in range(nsamples):
204 | hidden_states = inps[i].unsqueeze(0)
205 | hidden_states = model.gpt_neox.final_layer_norm(hidden_states)
206 | lm_logits = model.embed_out(hidden_states)
207 | shift_logits = lm_logits[:, :-1, :].contiguous()
208 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
209 | loss_fct = nn.CrossEntropyLoss()
210 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
211 | neg_log_likelihood = loss.float() * model.seqlen
212 | nlls.append(neg_log_likelihood)
213 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
214 | print(ppl.item())
215 |
216 | model.config.use_cache = use_cache
217 |
218 |
219 | # TODO: perform packing on GPU
220 | def neox_pack(model, quantizers, wbits, groupsize):
221 | layers = find_layers(model)
222 | layers = {n: layers[n] for n in quantizers}
223 | quant.make_quant_linear(model, quantizers, wbits, groupsize)
224 | qlayers = find_layers(model, [quant.QuantLinear])
225 | print('Packing ...')
226 | for name in qlayers:
227 | print(name)
228 | quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
229 | qlayers[name].pack(layers[name], scale, zero, g_idx)
230 | print('Done.')
231 | return model
232 |
233 |
234 | def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
235 | from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, modeling_utils
236 | config = GPTNeoXConfig.from_pretrained(model)
237 |
238 | def noop(*args, **kwargs):
239 | pass
240 |
241 | torch.nn.init.kaiming_uniform_ = noop
242 | torch.nn.init.uniform_ = noop
243 | torch.nn.init.normal_ = noop
244 |
245 |     # build the fp16 model skeleton without running the (patched-out) weight initializers
246 |     torch.set_default_dtype(torch.half)
247 |     modeling_utils._init_weights = False
248 | model = GPTNeoXForCausalLM(config)
249 | torch.set_default_dtype(torch.float)
250 | if eval:
251 | model = model.eval()
252 | layers = find_layers(model)
253 | for name in ['embed_in','embed_out']:
254 | if name in layers:
255 | del layers[name]
256 | quant.make_quant_linear(model, layers, wbits, groupsize)
257 |
258 | del layers
259 |
260 | print('Loading model ...')
261 | if checkpoint.endswith('.safetensors'):
262 | from safetensors.torch import load_file as safe_load
263 | model.load_state_dict(safe_load(checkpoint))
264 | else:
265 | model.load_state_dict(torch.load(checkpoint))
266 |
267 | if warmup_autotune:
268 | quant.autotune_warmup_linear(model, transpose=not (eval))
269 |
270 | model.seqlen = model.config.max_position_embeddings
271 | print('Done.')
272 |
273 | return model
274 |
275 |
276 | def neox_multigpu(model, gpus):
277 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(gpus[0])
278 | model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(gpus[-1])
279 | import copy
280 | model.embed_out = copy.deepcopy(model.embed_out).to(gpus[-1])
281 |
282 | cache = {'mask': None}
283 |
284 | class MoveModule(nn.Module):
285 |
286 | def __init__(self, module):
287 | super().__init__()
288 | self.module = module
289 | self.dev = next(iter(self.module.parameters())).device
290 |
291 | def forward(self, *inp, **kwargs):
292 | inp = list(inp)
293 | if inp[0].device != self.dev:
294 | inp[0] = inp[0].to(self.dev)
295 | if cache['mask'] is None or cache['mask'].device != self.dev:
296 | cache['mask'] = kwargs['attention_mask'].to(self.dev)
297 | kwargs['attention_mask'] = cache['mask']
298 | tmp = self.module(*inp, **kwargs)
299 | return tmp
300 |
301 | layers = model.gpt_neox.layers
302 |     pergpu = (len(layers) + len(gpus) - 1) // len(gpus)  # ceiling division; the math module is not imported in this file
303 | for i in range(len(layers)):
304 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
305 |
306 | model.gpus = gpus
307 |
308 |
309 | def benchmark(model, input_ids, check=False):
310 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
311 | torch.cuda.synchronize()
312 |
313 | cache = {'past': None}
314 |
315 | def clear_past(i):
316 |
317 | def tmp(layer, inp, out):
318 | if cache['past']:
319 | cache['past'][i] = None
320 |
321 | return tmp
322 |
323 | for i, layer in enumerate(model.gpt_neox.layers):
324 | layer.register_forward_hook(clear_past(i))
325 |
326 | print('Benchmarking ...')
327 |
328 | if check:
329 | loss = nn.CrossEntropyLoss()
330 | tot = 0.
331 |
332 | def sync():
333 | if hasattr(model, 'gpus'):
334 | for gpu in model.gpus:
335 | torch.cuda.synchronize(gpu)
336 | else:
337 | torch.cuda.synchronize()
338 |
339 | max_memory = 0
340 | with torch.no_grad():
341 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
342 | times = []
343 | for i in range(input_ids.numel()):
344 | tick = time.time()
345 | out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
346 | sync()
347 | times.append(time.time() - tick)
348 | print(i, times[-1])
349 | max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024)
350 | if check and i != input_ids.numel() - 1:
351 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
352 | cache['past'] = list(out.past_key_values)
353 | del out
354 | sync()
355 | print('Median:', np.median(times))
356 | if check:
357 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
358 | print('max memory(MiB):', max_memory)
359 |
360 |
361 | if __name__ == '__main__':
362 |
363 | parser = argparse.ArgumentParser()
364 |
365 | parser.add_argument('model', type=str, help='GPT-NeoX model to load')
366 | parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
367 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
368 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
369 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
370 | parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
371 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='bits to use for quantization; use 16 for evaluating base model.')
372 | parser.add_argument('--seqlen', type=int, default=-1, help='Sequence length to use for calibration; -1 (default) uses the full model context (max_position_embeddings).')
373 | parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
374 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
375 | parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
376 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
377 | parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
378 | parser.add_argument('--load', type=str, default='', help='Load quantized model.')
379 | parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
380 | parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
381 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
382 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
383 | parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
384 | args = parser.parse_args()
385 |
386 | if type(args.load) is not str:
387 | args.load = args.load.as_posix()
388 |
389 | if args.load:
390 | model = load_quant(args.model, args.load, args.wbits, args.groupsize)
391 | else:
392 |     model = get_neox(args.model, args.seqlen)
393 | model.eval()
394 |
395 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
396 |
397 | if not args.load and args.wbits < 16 and not args.nearest:
398 | tick = time.time()
399 | quantizers = neox_sequential(model, dataloader, DEV)
400 | print(time.time() - tick)
401 |
402 | if args.benchmark:
403 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
404 | if len(gpus) > 1:
405 | neox_multigpu(model, gpus)
406 | else:
407 | model = model.to(DEV)
408 | if args.benchmark:
409 | input_ids = next(iter(dataloader))[0][:, :args.benchmark]
410 | benchmark(model, input_ids, check=args.check)
411 |
412 | if args.eval:
413 | datasets = ['wikitext2', 'ptb', 'c4']
414 | if args.new_eval:
415 | datasets = ['wikitext2', 'ptb-new', 'c4-new']
416 | for dataset in datasets:
417 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
418 | print(dataset)
419 | neox_eval(model, testloader, DEV)
420 |
421 | if args.save:
422 | neox_pack(model, quantizers, args.wbits, args.groupsize)
423 | torch.save(model.state_dict(), args.save)
424 |
425 | if args.save_safetensors:
426 | neox_pack(model, quantizers, args.wbits, args.groupsize)
427 | from safetensors.torch import save_file as safe_save
428 | state_dict = model.state_dict()
429 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
430 | safe_save(state_dict, args.save_safetensors)
431 |
--------------------------------------------------------------------------------
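
The perplexity reported by `neox_eval` (and by `opt_eval` below) is the exponential of the average per-token negative log-likelihood over all evaluation windows. A tiny numeric sketch of that aggregation, with made-up per-window losses:

import torch

seqlen = 2048
window_losses = torch.tensor([2.10, 2.05, 2.20])   # mean cross-entropy per window (illustrative values)
nlls = window_losses * seqlen                      # matches: neg_log_likelihood = loss.float() * model.seqlen
ppl = torch.exp(nlls.sum() / (len(window_losses) * seqlen))
print(ppl.item())                                  # exp(mean loss), roughly 8.3 for these numbers
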
/opt.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn as nn
5 | import argparse
6 |
7 | import transformers
8 | from gptq import GPTQ
9 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
10 | import quant
11 |
12 |
13 | def get_opt(model):
14 | import torch
15 |
16 | def skip(*args, **kwargs):
17 | pass
18 |
19 | torch.nn.init.kaiming_uniform_ = skip
20 | torch.nn.init.uniform_ = skip
21 | torch.nn.init.normal_ = skip
22 | from transformers import OPTForCausalLM
23 | model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
24 | model.seqlen = model.config.max_position_embeddings
25 | return model
26 |
27 |
28 | @torch.no_grad()
29 | def opt_sequential(model, dataloader, dev):
30 | print('Starting ...')
31 |
32 | use_cache = model.config.use_cache
33 | model.config.use_cache = False
34 | layers = model.model.decoder.layers
35 |
36 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
37 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
38 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
39 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
40 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
41 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
42 | layers[0] = layers[0].to(dev)
43 |
44 | dtype = next(iter(model.parameters())).dtype
45 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
46 | cache = {'i': 0, 'attention_mask': None}
47 |
48 | class Catcher(nn.Module):
49 |
50 | def __init__(self, module):
51 | super().__init__()
52 | self.module = module
53 |
54 | def forward(self, inp, **kwargs):
55 | inps[cache['i']] = inp
56 | cache['i'] += 1
57 | cache['attention_mask'] = kwargs['attention_mask']
58 | raise ValueError
59 |
60 | layers[0] = Catcher(layers[0])
61 | for batch in dataloader:
62 | try:
63 | model(batch[0].to(dev))
64 | except ValueError:
65 | pass
66 | layers[0] = layers[0].module
67 |
68 | layers[0] = layers[0].cpu()
69 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
70 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
71 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
72 | model.model.decoder.project_out = model.model.decoder.project_out.cpu()
73 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
74 | model.model.decoder.project_in = model.model.decoder.project_in.cpu()
75 | torch.cuda.empty_cache()
76 |
77 | outs = torch.zeros_like(inps)
78 | attention_mask = cache['attention_mask']
79 |
80 | print('Ready.')
81 |
82 | quantizers = {}
83 | for i in range(len(layers)):
84 | layer = layers[i].to(dev)
85 |
86 | subset = find_layers(layer)
87 | gptq = {}
88 | for name in subset:
89 | gptq[name] = GPTQ(subset[name])
90 | gptq[name].quantizer = quant.Quantizer()
91 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits)
92 |
93 | def add_batch(name):
94 |
95 | def tmp(_, inp, out):
96 | gptq[name].add_batch(inp[0].data, out.data)
97 |
98 | return tmp
99 |
100 | handles = []
101 | for name in subset:
102 | handles.append(subset[name].register_forward_hook(add_batch(name)))
103 |
104 | for j in range(args.nsamples):
105 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
106 |
107 | for h in handles:
108 | h.remove()
109 |
110 | for name in subset:
111 | print(f'Quantizing {name} in layer {i+1}/{len(layers)}...')
112 | scale, zero, g_idx, _ = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
113 | quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
114 | gptq[name].free()
115 |
116 | for j in range(args.nsamples):
117 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
118 |
119 | layers[i] = layer.cpu()
120 | del layer
121 | del gptq
122 | torch.cuda.empty_cache()
123 |
124 | inps, outs = outs, inps
125 |
126 | model.config.use_cache = use_cache
127 |
128 | return quantizers
129 |
130 |
131 | @torch.no_grad()
132 | def opt_eval(model, testenc, dev):
133 | print('Evaluating ...')
134 |
135 | testenc = testenc.input_ids
136 | nsamples = testenc.numel() // model.seqlen
137 |
138 | use_cache = model.config.use_cache
139 | model.config.use_cache = False
140 | layers = model.model.decoder.layers
141 |
142 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
143 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
144 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
145 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
146 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
147 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
148 | layers[0] = layers[0].to(dev)
149 |
150 | dtype = next(iter(model.parameters())).dtype
151 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
152 | cache = {'i': 0, 'attention_mask': None}
153 |
154 | class Catcher(nn.Module):
155 |
156 | def __init__(self, module):
157 | super().__init__()
158 | self.module = module
159 |
160 | def forward(self, inp, **kwargs):
161 | inps[cache['i']] = inp
162 | cache['i'] += 1
163 | cache['attention_mask'] = kwargs['attention_mask']
164 | raise ValueError
165 |
166 | layers[0] = Catcher(layers[0])
167 | for i in range(nsamples):
168 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
169 | try:
170 | model(batch)
171 | except ValueError:
172 | pass
173 | layers[0] = layers[0].module
174 |
175 | layers[0] = layers[0].cpu()
176 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
177 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
178 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
179 | model.model.decoder.project_out = model.model.decoder.project_out.cpu()
180 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
181 | model.model.decoder.project_in = model.model.decoder.project_in.cpu()
182 | torch.cuda.empty_cache()
183 |
184 | outs = torch.zeros_like(inps)
185 | attention_mask = cache['attention_mask']
186 |
187 | for i in range(len(layers)):
188 | print(i)
189 | layer = layers[i].to(dev)
190 |
191 | if args.nearest:
192 | subset = find_layers(layer)
193 | for name in subset:
194 | quantizer = quant.Quantizer()
195 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
196 | W = subset[name].weight.data
197 | quantizer.find_params(W, weight=True)
198 | subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
199 |
200 | for j in range(nsamples):
201 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
202 | layers[i] = layer.cpu()
203 | del layer
204 | torch.cuda.empty_cache()
205 | inps, outs = outs, inps
206 |
207 | if model.model.decoder.final_layer_norm is not None:
208 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
209 | if model.model.decoder.project_out is not None:
210 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
211 | model.lm_head = model.lm_head.to(dev)
212 |
213 | testenc = testenc.to(dev)
214 | nlls = []
215 | for i in range(nsamples):
216 | hidden_states = inps[i].unsqueeze(0)
217 | if model.model.decoder.final_layer_norm is not None:
218 | hidden_states = model.model.decoder.final_layer_norm(hidden_states)
219 | if model.model.decoder.project_out is not None:
220 | hidden_states = model.model.decoder.project_out(hidden_states)
221 | lm_logits = model.lm_head(hidden_states)
222 | shift_logits = lm_logits[:, :-1, :].contiguous()
223 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
224 | loss_fct = nn.CrossEntropyLoss()
225 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
226 | neg_log_likelihood = loss.float() * model.seqlen
227 | nlls.append(neg_log_likelihood)
228 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
229 | print(ppl.item())
230 |
231 | model.config.use_cache = use_cache
232 |
233 |
234 | # TODO: perform packing on GPU
235 | def opt_pack(model, quantizers, wbits, groupsize):
236 | layers = find_layers(model)
237 | layers = {n: layers[n] for n in quantizers}
238 | quant.make_quant_linear(model, quantizers, wbits, groupsize)
239 | qlayers = find_layers(model, [quant.QuantLinear])
240 | print('Packing ...')
241 | for name in qlayers:
242 | print(name)
243 | quantizers[name], scale, zero, g_idx = quantizers[name]
244 | qlayers[name].pack(layers[name], scale, zero, g_idx)
245 | print('Done.')
246 | return model
247 |
248 |
249 | def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
250 | from transformers import OPTConfig, OPTForCausalLM
251 | config = OPTConfig.from_pretrained(model)
252 |
253 | def noop(*args, **kwargs):
254 | pass
255 |
256 | torch.nn.init.kaiming_uniform_ = noop
257 | torch.nn.init.uniform_ = noop
258 | torch.nn.init.normal_ = noop
259 |
260 |     # build the fp16 model skeleton without running the (patched-out) weight initializers
261 |     torch.set_default_dtype(torch.half)
262 |     transformers.modeling_utils._init_weights = False
263 | model = OPTForCausalLM(config)
264 | torch.set_default_dtype(torch.float)
265 | model = model.eval()
266 | layers = find_layers(model)
267 | for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']:
268 | if name in layers:
269 | del layers[name]
270 | quant.make_quant_linear(model, layers, wbits, groupsize)
271 |
272 | del layers
273 |
274 | print('Loading model ...')
275 | if checkpoint.endswith('.safetensors'):
276 | from safetensors.torch import load_file as safe_load
277 | model.load_state_dict(safe_load(checkpoint))
278 | else:
279 | model.load_state_dict(torch.load(checkpoint))
280 |
281 | if warmup_autotune:
282 | quant.autotune_warmup_linear(model, transpose=not (eval))
283 | model.seqlen = model.config.max_position_embeddings
284 | print('Done.')
285 | return model
286 |
287 |
288 | def opt_multigpu(model, gpus):
289 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0])
290 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0])
291 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
292 | model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0])
293 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
294 | model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1])
295 | if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm:
296 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1])
297 | import copy
298 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1])
299 |
300 | cache = {'mask': None}
301 |
302 | class MoveModule(nn.Module):
303 |
304 | def __init__(self, module):
305 | super().__init__()
306 | self.module = module
307 | self.dev = next(iter(self.module.parameters())).device
308 |
309 | def forward(self, *inp, **kwargs):
310 | inp = list(inp)
311 | if inp[0].device != self.dev:
312 | inp[0] = inp[0].to(self.dev)
313 | if cache['mask'] is None or cache['mask'].device != self.dev:
314 | cache['mask'] = kwargs['attention_mask'].to(self.dev)
315 | kwargs['attention_mask'] = cache['mask']
316 | tmp = self.module(*inp, **kwargs)
317 | return tmp
318 |
319 | layers = model.model.decoder.layers
320 |     pergpu = (len(layers) + len(gpus) - 1) // len(gpus)  # ceiling division; the math module is not imported in this file
321 | for i in range(len(layers)):
322 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))
323 |
324 | model.gpus = gpus
325 |
326 |
327 | def benchmark(model, input_ids, check=False):
328 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
329 | torch.cuda.synchronize()
330 |
331 | cache = {'past': None}
332 |
333 | def clear_past(i):
334 |
335 | def tmp(layer, inp, out):
336 | if cache['past']:
337 | cache['past'][i] = None
338 |
339 | return tmp
340 |
341 | for i, layer in enumerate(model.model.decoder.layers):
342 | layer.register_forward_hook(clear_past(i))
343 |
344 | print('Benchmarking ...')
345 |
346 | if check:
347 | loss = nn.CrossEntropyLoss()
348 | tot = 0.
349 |
350 | def sync():
351 | if hasattr(model, 'gpus'):
352 | for gpu in model.gpus:
353 | torch.cuda.synchronize(gpu)
354 | else:
355 | torch.cuda.synchronize()
356 |
357 | with torch.no_grad():
358 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
359 | times = []
360 | for i in range(input_ids.numel()):
361 | tick = time.time()
362 | out = model(input_ids[:, i].reshape(-1), past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
363 | sync()
364 | times.append(time.time() - tick)
365 | print(i, times[-1])
366 | if check and i != input_ids.numel() - 1:
367 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
368 | cache['past'] = list(out.past_key_values)
369 | del out
370 | sync()
371 | import numpy as np
372 | print('Median:', np.median(times))
373 | if check:
374 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
375 |
376 |
377 | if __name__ == '__main__':
378 |
379 | parser = argparse.ArgumentParser()
380 |
381 | parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
382 | parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
383 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
384 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
385 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
386 | parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
387 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
388 | parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
389 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
390 | parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
391 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
392 | parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
393 | parser.add_argument('--load', type=str, default='', help='Load quantized model.')
394 | parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
395 | parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
396 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
397 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
398 | parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
399 |
400 | args = parser.parse_args()
401 |
402 | if type(args.load) is not str:
403 | args.load = args.load.as_posix()
404 |
405 | if args.load:
406 | model = load_quant(args.model, args.load, args.wbits, args.groupsize)
407 | else:
408 | model = get_opt(args.model)
409 | model.eval()
410 |
411 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)
412 |
413 | if not args.load and args.wbits < 16 and not args.nearest:
414 | tick = time.time()
415 | quantizers = opt_sequential(model, dataloader, DEV)
416 | print(time.time() - tick)
417 |
418 | if args.benchmark:
419 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
420 | if len(gpus) > 1:
421 | opt_multigpu(model, gpus)
422 | else:
423 | model = model.to(DEV)
424 | if args.benchmark:
425 | input_ids = next(iter(dataloader))[0][:, :args.benchmark]
426 | benchmark(model, input_ids, check=args.check)
427 |
428 | if args.eval:
429 | datasets = ['wikitext2', 'ptb', 'c4']
430 | if args.new_eval:
431 | datasets = ['wikitext2', 'ptb-new', 'c4-new']
432 | for dataset in datasets:
433 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
434 | print(dataset)
435 | opt_eval(model, testloader, DEV)
436 |
437 | if args.save:
438 | opt_pack(model, quantizers, args.wbits, args.groupsize)
439 | torch.save(model.state_dict(), args.save)
440 |
441 | if args.save_safetensors:
442 | opt_pack(model, quantizers, args.wbits, args.groupsize)
443 | from safetensors.torch import save_file as safe_save
444 | state_dict = model.state_dict()
445 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
446 | safe_save(state_dict, args.save_safetensors)
447 |
--------------------------------------------------------------------------------
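
Both `opt_sequential`/`opt_eval` above and their NeoX counterparts rely on the same "Catcher" trick to collect calibration inputs: layer 0 is temporarily wrapped so that its inputs are recorded and a `ValueError` aborts the rest of the forward pass. A stripped-down sketch of that pattern on a toy module stack (not code from this repository):

import torch
import torch.nn as nn

inps, cache = [], {'kwargs': None}

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        inps.append(inp)            # record the hidden states entering layer 0
        cache['kwargs'] = kwargs    # and the keyword args (attention mask, position ids, ...)
        raise ValueError            # abort the forward pass; only layer-0 inputs are needed

stack = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))   # stand-in for the decoder layers
stack[0] = Catcher(stack[0])
try:
    stack(torch.randn(1, 8))
except ValueError:
    pass
stack[0] = stack[0].module          # unwrap again, exactly as the scripts above do
print(len(inps), inps[0].shape)
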
/quant/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantizer import Quantizer
2 | from .fused_attn import QuantLlamaAttention, make_quant_attn
3 | from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused
4 | from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear
5 | from .triton_norm import TritonLlamaRMSNorm, make_quant_norm
6 |
--------------------------------------------------------------------------------
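
These exports are applied by the various `load_quant` helpers in a fixed order (the attention/norm/MLP fusions are specific to the LLaMA models). A sketch of that pipeline, wrapped in a hypothetical helper so the ordering is explicit; checkpoint loading is elided and `find_layers` comes from this repository's `utils` package:

import quant
from utils import find_layers

def convert_for_quant_inference(model, wbits, groupsize, fused_mlp=True):
    """Hypothetical helper mirroring the order used by the load_quant functions above."""
    layers = find_layers(model)
    layers.pop('lm_head', None)                      # the output head stays in fp16
    quant.make_quant_linear(model, layers, wbits, groupsize)
    # ... load the packed checkpoint (torch or safetensors) into the QuantLinear buffers here ...
    quant.make_quant_attn(model)                     # fuse q/k/v projections into QuantLlamaAttention
    quant.make_quant_norm(model)                     # replace LlamaRMSNorm with TritonLlamaRMSNorm
    if fused_mlp:
        quant.make_fused_mlp(model)                  # swap in QuantLlamaMLP
    quant.autotune_warmup_linear(model)              # pre-tune the Triton kernels once up front
    if fused_mlp:
        quant.autotune_warmup_fused(model)
    return model
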
/quant/custom_autotune.py:
--------------------------------------------------------------------------------
1 | #https://github.com/fpgaminer/GPTQ-triton
2 | """
3 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
4 | """
5 |
6 | import builtins
7 | import math
8 | import time
9 | from typing import Dict
10 |
11 | import triton
12 |
13 |
14 | class Autotuner(triton.KernelInterface):
15 |
16 | def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
17 | '''
18 | :param prune_configs_by: a dict of functions that are used to prune configs, fields:
19 |             'perf_model': performance model used to predict running time with different configs, returns running time
20 |             'top_k': number of configs to bench
21 |             'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as input and returns pruned configs.
22 |         :param nearest_power_of_two: whether to round key arguments to the nearest power of two when caching tuning results
23 | '''
24 | if not configs:
25 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
26 | else:
27 | self.configs = configs
28 | self.key_idx = [arg_names.index(k) for k in key]
29 | self.nearest_power_of_two = nearest_power_of_two
30 | self.cache = {}
31 | # hook to reset all required tensor to zeros before relaunching a kernel
32 | self.hook = lambda args: 0
33 | if reset_to_zero is not None:
34 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
35 |
36 | def _hook(args):
37 | for i in self.reset_idx:
38 | args[i].zero_()
39 |
40 | self.hook = _hook
41 | self.arg_names = arg_names
42 | # prune configs
43 | if prune_configs_by:
44 |             perf_model, top_k = prune_configs_by.get('perf_model'), prune_configs_by.get('top_k')
45 |             # .get() falls back to None for any pruning hook that is not supplied
46 |             early_config_prune = prune_configs_by.get('early_config_prune')
47 | else:
48 | perf_model, top_k, early_config_prune = None, None, None
49 | self.perf_model, self.configs_top_k = perf_model, top_k
50 | self.early_config_prune = early_config_prune
51 | self.fn = fn
52 |
53 | def _bench(self, *args, config, **meta):
54 | # check for conflicts, i.e. meta-parameters both provided
55 | # as kwargs and by the autotuner
56 | conflicts = meta.keys() & config.kwargs.keys()
57 | if conflicts:
58 | raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
59 | " Make sure that you don't re-define auto-tuned symbols.")
60 | # augment meta-parameters with tunable ones
61 | current = dict(meta, **config.kwargs)
62 |
63 | def kernel_call():
64 | if config.pre_hook:
65 | config.pre_hook(self.nargs)
66 | self.hook(args)
67 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
68 |
69 | try:
70 |             # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
71 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
72 | return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
73 | except triton.compiler.OutOfResources:
74 | return (float('inf'), float('inf'), float('inf'))
75 |
76 | def run(self, *args, **kwargs):
77 | self.nargs = dict(zip(self.arg_names, args))
78 | if len(self.configs) > 1:
79 | key = tuple(args[i] for i in self.key_idx)
80 |
81 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two
82 | # In my testing this gives decent results, and greatly reduces the amount of tuning required
83 | if self.nearest_power_of_two:
84 | key = tuple([2**int(math.log2(x) + 0.5) for x in key])
85 |
86 | if key not in self.cache:
87 | # prune configs
88 | pruned_configs = self.prune_configs(kwargs)
89 | bench_start = time.time()
90 | timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
91 | bench_end = time.time()
92 | self.bench_time = bench_end - bench_start
93 | self.cache[key] = builtins.min(timings, key=timings.get)
94 | self.hook(args)
95 | self.configs_timings = timings
96 | config = self.cache[key]
97 | else:
98 | config = self.configs[0]
99 | self.best_config = config
100 | if config.pre_hook is not None:
101 | config.pre_hook(self.nargs)
102 | return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
103 |
104 | def prune_configs(self, kwargs):
105 | pruned_configs = self.configs
106 | if self.early_config_prune:
107 | pruned_configs = self.early_config_prune(self.configs, self.nargs)
108 | if self.perf_model:
109 | top_k = self.configs_top_k
110 | if isinstance(top_k, float) and top_k <= 1.0:
111 | top_k = int(len(self.configs) * top_k)
112 | if len(pruned_configs) > top_k:
113 | est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
114 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
115 | return pruned_configs
116 |
117 | def warmup(self, *args, **kwargs):
118 | self.nargs = dict(zip(self.arg_names, args))
119 | for config in self.prune_configs(kwargs):
120 | self.fn.warmup(
121 | *args,
122 | num_warps=config.num_warps,
123 | num_stages=config.num_stages,
124 | **kwargs,
125 | **config.kwargs,
126 | )
127 | self.nargs = None
128 |
129 |
130 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
131 | """
132 | Decorator for auto-tuning a :code:`triton.jit`'d function.
133 | .. highlight:: python
134 | .. code-block:: python
135 | @triton.autotune(configs=[
136 | triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
137 | triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
138 | ],
139 | key=['x_size'] # the two above configs will be evaluated anytime
140 | # the value of x_size changes
141 | )
142 | @triton.jit
143 | def kernel(x_ptr, x_size, **META):
144 | BLOCK_SIZE = META['BLOCK_SIZE']
145 |     :note: When all the configurations are evaluated, the kernel will run multiple times.
146 |         This means that whatever value the kernel updates will be updated multiple times.
147 |         To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
148 |         resets the value of the provided tensor to `zero` before running any configuration.
149 | :param configs: a list of :code:`triton.Config` objects
150 | :type configs: list[triton.Config]
151 | :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
152 | :type key: list[str]
153 | :param prune_configs_by: a dict of functions that are used to prune configs, fields:
154 |         'perf_model': performance model used to predict running time with different configs, returns running time
155 |         'top_k': number of configs to bench
156 |         'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as input and returns pruned configs.
157 | :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
158 | :type reset_to_zero: list[str]
159 | """
160 |
161 | def decorator(fn):
162 | return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
163 |
164 | return decorator
165 |
166 |
167 | def matmul248_kernel_config_pruner(configs, nargs):
168 | """
169 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
170 | """
171 | m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
172 | n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
173 | k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
174 |
175 | used = set()
176 | for config in configs:
177 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
178 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
179 | block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
180 | group_size_m = config.kwargs['GROUP_SIZE_M']
181 |
182 | if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
183 | continue
184 |
185 | used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
186 | yield triton.Config({
187 | 'BLOCK_SIZE_M': block_size_m,
188 | 'BLOCK_SIZE_N': block_size_n,
189 | 'BLOCK_SIZE_K': block_size_k,
190 | 'GROUP_SIZE_M': group_size_m
191 | },
192 | num_stages=config.num_stages,
193 | num_warps=config.num_warps)
194 |
--------------------------------------------------------------------------------
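
A minimal sketch of how `autotune` and `matmul248_kernel_config_pruner` are meant to be wired onto a Triton kernel; the kernel below is a placeholder, and the config values are illustrative rather than taken from the real quantized-matmul kernels in the `quant` package:

import triton
import triton.language as tl
from quant.custom_autotune import autotune, matmul248_kernel_config_pruner

@autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    ],
    key=['M', 'N', 'K'],                 # retune whenever the (rounded) problem size changes
    prune_configs_by={'perf_model': None, 'top_k': None, 'early_config_prune': matmul248_kernel_config_pruner},
    nearest_power_of_two=True,
)
@triton.jit
def toy_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                      BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    pass  # placeholder body; a real kernel would implement the packed-weight matmul here
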
/quant/fused_attn.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from transformers.models.llama.modeling_llama import LlamaAttention
3 | from .quant_linear import *
4 | import triton
5 | import triton.language as tl
6 |
7 |
8 | @triton.jit
9 | def rotate_half_kernel(
10 | qk_seq_ptr,
11 | position_ids_ptr,
12 | qk_seq_stride,
13 | position_ids_batch_stride,
14 | seq_len,
15 | HEAD_DIM: tl.constexpr,
16 | BLOCK_HEIGHT: tl.constexpr,
17 | BLOCK_WIDTH: tl.constexpr,
18 | INV_BASE: tl.constexpr
19 | ):
20 | # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.
21 | # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.
22 |
23 | HALF_HEAD: tl.constexpr = HEAD_DIM // 2
24 | STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH
25 |
26 | batch_seq = tl.program_id(axis=0)
27 | row_blk_x_col_blk = tl.program_id(axis=1)
28 |
29 | row_blk = row_blk_x_col_blk // STEPS_PER_ROW
30 | row = row_blk * BLOCK_HEIGHT
31 | if BLOCK_WIDTH < HALF_HEAD:
32 | col_blk = row_blk_x_col_blk % STEPS_PER_ROW
33 | col = col_blk * BLOCK_WIDTH
34 | else:
35 | col: tl.constexpr = 0
36 |
37 | # A block will never cross a sequence boundary, which simplifies things a lot.
38 | batch = batch_seq // seq_len
39 | seq = batch_seq % seq_len
40 | position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)
41 | # As sometimes happens, just calculating this on the fly is faster than loading it from memory.
42 | # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.
43 | freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id
44 | cos = tl.cos(freq).to(tl.float32)
45 | sin = tl.sin(freq).to(tl.float32)
46 |
47 | col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)
48 | embed_offsets = (row * HEAD_DIM + col) + col_offsets
49 | x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets
50 |
51 | for k in range(0, BLOCK_HEIGHT):
52 | x = tl.load(x_ptrs).to(tl.float32)
53 | y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)
54 | out_x = x * cos - y * sin
55 | tl.store(x_ptrs, out_x)
56 | out_y = x * sin + y * cos
57 | tl.store(x_ptrs + HALF_HEAD, out_y)
58 | x_ptrs += HEAD_DIM
59 |
60 |
61 | def triton_rotate_half_(qk, position_ids, config=None):
62 | with torch.cuda.device(qk.device):
63 | batch_size, seq_len, qandk, num_heads, head_dim = qk.shape
64 |
65 | # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.
66 | config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1}
67 | config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads)
68 |
69 | assert qk.stride(3) == head_dim
70 | assert qk.stride(4) == 1
71 | assert position_ids.shape == (batch_size, seq_len)
72 | assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension'
73 | assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}'
74 | assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}'
75 |
76 | qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)
77 | grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH']))
78 |
79 | # Must be the same as the theta of the frequencies used to train the model.
80 | BASE = 10000.0
81 |
82 | rotate_half_kernel[grid](
83 | qk_by_seq,
84 | position_ids,
85 | qk_by_seq.stride(0),
86 | position_ids.stride(0),
87 | seq_len,
88 | HEAD_DIM=head_dim,
89 | BLOCK_HEIGHT=config['BLOCK_HEIGHT'],
90 | BLOCK_WIDTH=config['BLOCK_WIDTH'],
91 | INV_BASE=-2.0 * math.log(BASE) / head_dim,
92 | num_warps=config['num_warps']
93 | )
94 |
95 |
96 | class QuantLlamaAttention(nn.Module):
97 | """Multi-headed attention from 'Attention Is All You Need' paper"""
98 |
99 | def __init__(
100 | self,
101 | hidden_size,
102 | num_heads,
103 | qkv_proj,
104 | o_proj
105 | ):
106 | super().__init__()
107 | self.hidden_size = hidden_size
108 | self.num_heads = num_heads
109 | self.head_dim = hidden_size // num_heads
110 |
111 | if (self.head_dim * num_heads) != self.hidden_size:
112 | raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
113 | f" and `num_heads`: {num_heads}).")
114 | self.qkv_proj = qkv_proj
115 | self.o_proj = o_proj
116 |
117 | def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
118 | """Input shape: Batch x Time x Channel"""
119 |
120 | bsz, q_len, _ = hidden_states.size()
121 |
122 | qkv_states = self.qkv_proj(hidden_states)
123 | qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
124 |
125 | # This updates the query and key states in-place, saving VRAM.
126 | triton_rotate_half_(qkv_states[:, :, :2], position_ids)
127 |
128 | query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
129 | del qkv_states
130 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
131 | key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
132 | value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
133 |
134 | is_causal = past_key_value is None
135 |
136 | kv_seq_len = q_len
137 | if past_key_value is not None:
138 | kv_seq_len += past_key_value[0].shape[-2]
139 |
140 | if past_key_value is not None:
141 | # reuse k, v, self_attention
142 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
143 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
144 |
145 | if use_cache:
146 |             # Since qkv_proj is fused, query_states etc. will hold a reference to the original qkv_states tensor,
147 |             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to work around this.
148 | key_states = key_states.contiguous()
149 | value_states = value_states.contiguous()
150 | query_states = query_states.contiguous()
151 |
152 | past_key_value = (key_states, value_states) if use_cache else None
153 |
154 | with torch.backends.cuda.sdp_kernel(enable_math=False):
155 | attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
156 | del query_states, key_states, value_states
157 |
158 | attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
159 | attn_output = self.o_proj(attn_output)
160 |
161 | return attn_output, None, past_key_value
162 |
163 |
164 | def make_quant_attn(model):
165 | """
166 | Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
167 | """
168 |
169 | for name, m in model.named_modules():
170 | if not isinstance(m, LlamaAttention):
171 | continue
172 |
173 | q_proj = m.q_proj
174 | k_proj = m.k_proj
175 | v_proj = m.v_proj
176 |
177 | qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
178 | qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
179 | scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
180 | g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
181 | bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
182 |
183 | qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False)
184 | qkv_layer.qweight = qweights
185 | qkv_layer.qzeros = qzeros
186 | qkv_layer.scales = scales
187 | qkv_layer.g_idx = g_idx
188 | qkv_layer.bias = bias
189 | # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.
190 |
191 | attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj)
192 |
193 | if '.' in name:
194 | parent_name = name.rsplit('.', 1)[0]
195 | child_name = name[len(parent_name) + 1:]
196 | parent = model.get_submodule(parent_name)
197 | else:
198 | parent_name = ''
199 | parent = model
200 | child_name = name
201 |
202 | #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
203 |
204 | setattr(parent, child_name, attn)
205 |
--------------------------------------------------------------------------------
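A minimal usage sketch for make_quant_attn, not part of the repository: it assumes `model` is a transformers LLaMA model whose attention projections have already been replaced by QuantLinear modules and loaded with quantized weights. The helper name fuse_attention is illustrative; make_quant_attn itself performs the in-place replacement, after which rotary embeddings are applied by the Triton kernel above rather than by m.rotary_emb.

from transformers.models.llama.modeling_llama import LlamaAttention
from quant.fused_attn import QuantLlamaAttention, make_quant_attn

def fuse_attention(model):
    """Fuse q/k/v projections in-place and report how many modules were swapped."""
    before = sum(isinstance(m, LlamaAttention) for m in model.modules())
    make_quant_attn(model)
    after = sum(isinstance(m, QuantLlamaAttention) for m in model.modules())
    print(f'replaced {before} LlamaAttention modules with {after} QuantLlamaAttention modules')
    return model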
/quant/fused_mlp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.cuda.amp import custom_bwd, custom_fwd
5 | from transformers.models.llama.modeling_llama import LlamaMLP
6 |
7 | try:
8 | import triton
9 | import triton.language as tl
10 | from . import custom_autotune
11 |
12 | # code based https://github.com/fpgaminer/GPTQ-triton
13 | @custom_autotune.autotune(
14 | configs=[
15 | triton.Config({
16 | 'BLOCK_SIZE_M': 256,
17 | 'BLOCK_SIZE_N': 64,
18 | 'BLOCK_SIZE_K': 32,
19 | 'GROUP_SIZE_M': 8
20 | }, num_stages=4, num_warps=4),
21 | triton.Config({
22 | 'BLOCK_SIZE_M': 64,
23 | 'BLOCK_SIZE_N': 256,
24 | 'BLOCK_SIZE_K': 32,
25 | 'GROUP_SIZE_M': 8
26 | }, num_stages=4, num_warps=4),
27 | triton.Config({
28 | 'BLOCK_SIZE_M': 128,
29 | 'BLOCK_SIZE_N': 128,
30 | 'BLOCK_SIZE_K': 32,
31 | 'GROUP_SIZE_M': 8
32 | }, num_stages=4, num_warps=4),
33 | triton.Config({
34 | 'BLOCK_SIZE_M': 128,
35 | 'BLOCK_SIZE_N': 64,
36 | 'BLOCK_SIZE_K': 32,
37 | 'GROUP_SIZE_M': 8
38 | }, num_stages=4, num_warps=4),
39 | triton.Config({
40 | 'BLOCK_SIZE_M': 64,
41 | 'BLOCK_SIZE_N': 128,
42 | 'BLOCK_SIZE_K': 32,
43 | 'GROUP_SIZE_M': 8
44 | }, num_stages=4, num_warps=4),
45 | triton.Config({
46 | 'BLOCK_SIZE_M': 128,
47 | 'BLOCK_SIZE_N': 32,
48 | 'BLOCK_SIZE_K': 32,
49 | 'GROUP_SIZE_M': 8
50 | }, num_stages=4, num_warps=4), # 3090
51 | triton.Config({
52 | 'BLOCK_SIZE_M': 128,
53 | 'BLOCK_SIZE_N': 16,
54 | 'BLOCK_SIZE_K': 32,
55 | 'GROUP_SIZE_M': 8
56 | }, num_stages=4, num_warps=4), # 3090
57 | triton.Config({
58 | 'BLOCK_SIZE_M': 32,
59 | 'BLOCK_SIZE_N': 32,
60 | 'BLOCK_SIZE_K': 128,
61 | 'GROUP_SIZE_M': 8
62 | }, num_stages=2, num_warps=4), # 3090
63 | triton.Config({
64 | 'BLOCK_SIZE_M': 64,
65 | 'BLOCK_SIZE_N': 16,
66 | 'BLOCK_SIZE_K': 64,
67 | 'GROUP_SIZE_M': 8
68 | }, num_stages=4, num_warps=4), # 3090
69 | triton.Config({
70 | 'BLOCK_SIZE_M': 64,
71 | 'BLOCK_SIZE_N': 32,
72 | 'BLOCK_SIZE_K': 64,
73 | 'GROUP_SIZE_M': 8
74 | }, num_stages=4, num_warps=4), # 3090
75 | ],
76 | key=['M', 'N', 'K'],
77 | nearest_power_of_two=True,
78 | prune_configs_by={
79 | 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
80 | 'perf_model': None,
81 | 'top_k': None,
82 | },
83 | )
84 | @triton.jit
85 | def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,
86 | stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
87 | """
88 | Computes: C = silu(A * B1) * (A * B2)
89 | A is of shape (M, K) float16
90 |         B1 and B2 are of shape (K//8, N) int32
91 |         C is of shape (M, N) float16
92 |         scales1 and scales2 are of shape (G, N) float16
93 |         zeros1 and zeros2 are of shape (G, N//8) int32
94 | """
95 | infearure_per_bits = 32 // bits
96 |
97 | pid = tl.program_id(axis=0)
98 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
99 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
100 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
101 | num_pid_in_group = GROUP_SIZE_M * num_pid_n
102 | group_id = pid // num_pid_in_group
103 | first_pid_m = group_id * GROUP_SIZE_M
104 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
105 | pid_m = first_pid_m + (pid % group_size_m)
106 | pid_n = (pid % num_pid_in_group) // group_size_m
107 |
108 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
110 | offs_k = tl.arange(0, BLOCK_SIZE_K)
111 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
112 | a_mask = (offs_am[:, None] < M)
113 | # b_ptrs is set up such that it repeats elements along the K axis 8 times
114 | b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
115 | b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
116 | g1_ptrs = g1_ptr + offs_k
117 | g2_ptrs = g2_ptr + offs_k
118 | # shifter is used to extract the N bits of each element in the 32-bit word from B
119 | scales1_ptrs = scales1_ptr + offs_bn[None, :]
120 | scales2_ptrs = scales2_ptr + offs_bn[None, :]
121 | zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
122 | zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)
123 |
124 | shifter = (offs_k % infearure_per_bits) * bits
125 | zeros_shifter = (offs_bn % infearure_per_bits) * bits
126 | accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
127 | accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
128 | for k in range(0, num_pid_k):
129 | g1_idx = tl.load(g1_ptrs)
130 | g2_idx = tl.load(g2_ptrs)
131 |
132 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
133 | scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
134 | scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
135 |
136 | zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
137 | zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
138 | zeros1 = (zeros1 + 1)
139 |
140 | zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
141 | zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
142 | zeros2 = (zeros2 + 1)
143 |
144 | a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
145 | b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
146 | b2 = tl.load(b2_ptrs)
147 |
148 | # Now we need to unpack b (which is N-bit values) into 32-bit values
149 | b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values
150 | b1 = (b1 - zeros1) * scales1 # Scale and shift
151 | accumulator1 += tl.dot(a, b1)
152 |
153 | b2 = (b2 >> shifter[:, None]) & maxq
154 | b2 = (b2 - zeros2) * scales2
155 | accumulator2 += tl.dot(a, b2)
156 |
157 | a_ptrs += BLOCK_SIZE_K
158 | b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
159 | b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
160 | g1_ptrs += BLOCK_SIZE_K
161 | g2_ptrs += BLOCK_SIZE_K
162 |
163 | accumulator1 = silu(accumulator1)
164 | c = accumulator1 * accumulator2
165 | c = c.to(tl.float16)
166 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
167 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
168 | tl.store(c_ptrs, c, mask=c_mask)
169 |
170 | @triton.jit
171 | def silu(x):
172 | return x * tl.sigmoid(x)
173 | except ImportError:
174 |     print('triton not installed.')
175 |
176 |
177 | class QuantLlamaMLP(nn.Module):
178 |
179 | def __init__(
180 | self,
181 | gate_proj,
182 | down_proj,
183 | up_proj,
184 | ):
185 | super().__init__()
186 | self.register_buffer('gate_proj_qweight', gate_proj.qweight)
187 | self.register_buffer('gate_proj_scales', gate_proj.scales)
188 | self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
189 | self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
190 | self.register_buffer('up_proj_qweight', up_proj.qweight)
191 | self.register_buffer('up_proj_scales', up_proj.scales)
192 | self.register_buffer('up_proj_qzeros', up_proj.qzeros)
193 | self.register_buffer('up_proj_g_idx', up_proj.g_idx)
194 |
195 | self.infeatures = gate_proj.infeatures
196 | self.intermediate_size = gate_proj.outfeatures
197 | self.outfeatures = down_proj.outfeatures
198 | self.bits = gate_proj.bits
199 | self.maxq = gate_proj.maxq
200 |
201 | self.down_proj = down_proj
202 |
203 | def forward(self, x):
204 | return self.down_proj(self.triton_llama_mlp(x))
205 |
206 | def triton_llama_mlp(self, x):
207 | with torch.cuda.device(x.device):
208 | out_shape = x.shape[:-1] + (self.intermediate_size, )
209 | x = x.reshape(-1, x.shape[-1])
210 | M, K = x.shape
211 | N = self.intermediate_size
212 | c = torch.empty((M, N), device=x.device, dtype=torch.float16)
213 | grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
214 | fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,
215 | self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),
216 | self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))
217 | c = c.reshape(out_shape)
218 | return c
219 |
220 | def fused2cuda(self):
221 | self.gate_proj_qweight = self.gate_proj_qweight.cuda()
222 | self.gate_proj_scales = self.gate_proj_scales.cuda()
223 | self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
224 | self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
225 | self.up_proj_qweight = self.up_proj_qweight.cuda()
226 | self.up_proj_scales = self.up_proj_scales.cuda()
227 | self.up_proj_qzeros = self.up_proj_qzeros.cuda()
228 | self.up_proj_g_idx = self.up_proj_g_idx.cuda()
229 |
230 | def fused2cpu(self):
231 | self.gate_proj_qweight = self.gate_proj_qweight.cpu()
232 | self.gate_proj_scales = self.gate_proj_scales.cpu()
233 | self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
234 | self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
235 | self.up_proj_qweight = self.up_proj_qweight.cpu()
236 | self.up_proj_scales = self.up_proj_scales.cpu()
237 | self.up_proj_qzeros = self.up_proj_qzeros.cpu()
238 | self.up_proj_g_idx = self.up_proj_g_idx.cpu()
239 |
240 |
241 | def make_fused_mlp(m, parent_name=''):
242 | """
243 | Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
244 | """
245 | if isinstance(m, LlamaMLP):
246 | return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
247 |
248 | for name, child in m.named_children():
249 | child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
250 |
251 | if isinstance(child, QuantLlamaMLP):
252 | setattr(m, name, child)
253 | return m
254 |
255 |
256 | def autotune_warmup_fused(model):
257 | """
258 | Pre-tunes the quantized kernel
259 | """
260 | from tqdm import tqdm
261 |
262 | kn_values = {}
263 |
264 | for _, m in model.named_modules():
265 | if not isinstance(m, QuantLlamaMLP):
266 | continue
267 |
268 | k = m.infeatures
269 | n = m.intermediate_size
270 |
271 | m.fused2cuda()
272 | if (k, n) not in kn_values:
273 | kn_values[(k, n)] = m
274 |
275 | print(f'Found {len(kn_values)} unique fused mlp KN values.')
276 |
277 | print('Warming up autotune cache ...')
278 | with torch.no_grad():
279 | for m in tqdm(range(0, 12)):
280 | m = 2**m # [1, 2048]
281 | for (k, n), (modules) in kn_values.items():
282 | a = torch.randn(m, k, dtype=torch.float16, device='cuda')
283 | modules.triton_llama_mlp(a)
284 |
285 | for (k, n), (modules) in kn_values.items():
286 |         # Move the fused weights back off the GPU once warmup is done.
287 | modules.fused2cpu()
288 | del kn_values
289 |
--------------------------------------------------------------------------------
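For reference, the fused kernel above evaluates the standard LLaMA MLP, silu(x @ W_gate) * (x @ W_up), with down_proj applied afterwards by QuantLlamaMLP.forward. The sketch below is not part of the repository; it restates that math in plain PyTorch with unquantized weights, and the names w_gate/w_up/w_down are illustrative. It can be handy for sanity-checking the quantized output.

import torch
import torch.nn.functional as F

def llama_mlp_reference(x, w_gate, w_up, w_down):
    # x: (..., hidden); w_*: nn.Linear-style weight matrices of shape (out_features, in_features)
    gate = F.silu(x @ w_gate.t())    # what accumulator1 + silu() compute in the kernel
    up = x @ w_up.t()                # accumulator2
    return (gate * up) @ w_down.t()  # down_proj, applied by QuantLlamaMLP.forward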
/quant/quant_linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.cuda.amp import custom_bwd, custom_fwd
6 |
7 | try:
8 | import triton
9 | import triton.language as tl
10 | from . import custom_autotune
11 |
12 | # code based https://github.com/fpgaminer/GPTQ-triton
13 | @custom_autotune.autotune(
14 | configs=[
15 | triton.Config({
16 | 'BLOCK_SIZE_M': 64,
17 | 'BLOCK_SIZE_N': 256,
18 | 'BLOCK_SIZE_K': 32,
19 | 'GROUP_SIZE_M': 8
20 | }, num_stages=4, num_warps=4),
21 | triton.Config({
22 | 'BLOCK_SIZE_M': 128,
23 | 'BLOCK_SIZE_N': 128,
24 | 'BLOCK_SIZE_K': 32,
25 | 'GROUP_SIZE_M': 8
26 | }, num_stages=4, num_warps=4),
27 | triton.Config({
28 | 'BLOCK_SIZE_M': 64,
29 | 'BLOCK_SIZE_N': 128,
30 | 'BLOCK_SIZE_K': 32,
31 | 'GROUP_SIZE_M': 8
32 | }, num_stages=4, num_warps=4),
33 | triton.Config({
34 | 'BLOCK_SIZE_M': 128,
35 | 'BLOCK_SIZE_N': 32,
36 | 'BLOCK_SIZE_K': 32,
37 | 'GROUP_SIZE_M': 8
38 | }, num_stages=4, num_warps=4),
39 | triton.Config({
40 | 'BLOCK_SIZE_M': 64,
41 | 'BLOCK_SIZE_N': 64,
42 | 'BLOCK_SIZE_K': 32,
43 | 'GROUP_SIZE_M': 8
44 | }, num_stages=4, num_warps=4),
45 | triton.Config({
46 | 'BLOCK_SIZE_M': 64,
47 | 'BLOCK_SIZE_N': 128,
48 | 'BLOCK_SIZE_K': 32,
49 | 'GROUP_SIZE_M': 8
50 | }, num_stages=2, num_warps=8),
51 | triton.Config({
52 | 'BLOCK_SIZE_M': 64,
53 | 'BLOCK_SIZE_N': 64,
54 | 'BLOCK_SIZE_K': 64,
55 | 'GROUP_SIZE_M': 8
56 | }, num_stages=3, num_warps=8),
57 | triton.Config({
58 | 'BLOCK_SIZE_M': 32,
59 | 'BLOCK_SIZE_N': 32,
60 | 'BLOCK_SIZE_K': 128,
61 | 'GROUP_SIZE_M': 8
62 | }, num_stages=2, num_warps=4),
63 | ],
64 | key=['M', 'N', 'K'],
65 | nearest_power_of_two=True,
66 | prune_configs_by={
67 | 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
68 | 'perf_model': None,
69 | 'top_k': None,
70 | },
71 | )
72 | @triton.jit
73 | def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
74 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
75 | """
76 | Compute the matrix multiplication C = A x B.
77 | A is of shape (M, K) float16
78 | B is of shape (K//8, N) int32
79 | C is of shape (M, N) float16
80 | scales is of shape (G, N) float16
81 |         zeros is of shape (G, N//8) int32
82 | g_ptr is of shape (K) int32
83 | """
84 | infearure_per_bits = 32 // bits
85 |
86 | pid = tl.program_id(axis=0)
87 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
88 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
89 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
90 | num_pid_in_group = GROUP_SIZE_M * num_pid_n
91 | group_id = pid // num_pid_in_group
92 | first_pid_m = group_id * GROUP_SIZE_M
93 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
94 | pid_m = first_pid_m + (pid % group_size_m)
95 | pid_n = (pid % num_pid_in_group) // group_size_m
96 |
97 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
98 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
99 | offs_k = tl.arange(0, BLOCK_SIZE_K)
100 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
101 | a_mask = (offs_am[:, None] < M)
102 | # b_ptrs is set up such that it repeats elements along the K axis 8 times
103 | b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
104 | g_ptrs = g_ptr + offs_k
105 | # shifter is used to extract the N bits of each element in the 32-bit word from B
106 | scales_ptrs = scales_ptr + offs_bn[None, :]
107 | zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
108 |
109 | shifter = (offs_k % infearure_per_bits) * bits
110 | zeros_shifter = (offs_bn % infearure_per_bits) * bits
111 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
112 |
113 | for k in range(0, num_pid_k):
114 | g_idx = tl.load(g_ptrs)
115 |
116 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
117 | scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
118 | zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
119 |
120 | zeros = (zeros >> zeros_shifter[None, :]) & maxq
121 | zeros = (zeros + 1)
122 |
123 | a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
124 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
125 |
126 | # Now we need to unpack b (which is N-bit values) into 32-bit values
127 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
128 | b = (b - zeros) * scales # Scale and shift
129 |
130 | accumulator += tl.dot(a, b)
131 | a_ptrs += BLOCK_SIZE_K
132 | b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
133 | g_ptrs += BLOCK_SIZE_K
134 |
135 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
136 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
137 | tl.store(c_ptrs, accumulator, mask=c_mask)
138 |
139 | @custom_autotune.autotune(configs=[
140 | triton.Config({
141 | 'BLOCK_SIZE_M': 64,
142 | 'BLOCK_SIZE_N': 32,
143 | 'BLOCK_SIZE_K': 256,
144 | 'GROUP_SIZE_M': 8
145 | }, num_stages=4, num_warps=4),
146 | triton.Config({
147 | 'BLOCK_SIZE_M': 128,
148 | 'BLOCK_SIZE_N': 32,
149 | 'BLOCK_SIZE_K': 128,
150 | 'GROUP_SIZE_M': 8
151 | }, num_stages=4, num_warps=4),
152 | triton.Config({
153 | 'BLOCK_SIZE_M': 64,
154 | 'BLOCK_SIZE_N': 32,
155 | 'BLOCK_SIZE_K': 128,
156 | 'GROUP_SIZE_M': 8
157 | }, num_stages=4, num_warps=4),
158 | triton.Config({
159 | 'BLOCK_SIZE_M': 128,
160 | 'BLOCK_SIZE_N': 32,
161 | 'BLOCK_SIZE_K': 32,
162 | 'GROUP_SIZE_M': 8
163 | }, num_stages=4, num_warps=4),
164 | triton.Config({
165 | 'BLOCK_SIZE_M': 64,
166 | 'BLOCK_SIZE_N': 32,
167 | 'BLOCK_SIZE_K': 64,
168 | 'GROUP_SIZE_M': 8
169 | }, num_stages=4, num_warps=4),
170 | triton.Config({
171 | 'BLOCK_SIZE_M': 64,
172 | 'BLOCK_SIZE_N': 32,
173 | 'BLOCK_SIZE_K': 128,
174 | 'GROUP_SIZE_M': 8
175 | }, num_stages=2, num_warps=8),
176 | triton.Config({
177 | 'BLOCK_SIZE_M': 64,
178 | 'BLOCK_SIZE_N': 64,
179 | 'BLOCK_SIZE_K': 64,
180 | 'GROUP_SIZE_M': 8
181 | }, num_stages=3, num_warps=8),
182 | triton.Config({
183 | 'BLOCK_SIZE_M': 32,
184 | 'BLOCK_SIZE_N': 128,
185 | 'BLOCK_SIZE_K': 32,
186 | 'GROUP_SIZE_M': 8
187 | }, num_stages=2, num_warps=4),
188 | ],
189 | key=['M', 'N', 'K'],
190 | nearest_power_of_two=True)
191 | @triton.jit
192 | def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
193 | stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
194 | """
195 | Compute the matrix multiplication C = A x B.
196 | A is of shape (M, N) float16
197 | B is of shape (K//8, N) int32
198 | C is of shape (M, K) float16
199 | scales is of shape (G, N) float16
200 |         zeros is of shape (G, N//8) int32
201 | g_ptr is of shape (K) int32
202 | """
203 | infearure_per_bits = 32 // bits
204 |
205 | pid = tl.program_id(axis=0)
206 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
207 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
208 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
209 | num_pid_in_group = GROUP_SIZE_M * num_pid_k
210 | group_id = pid // num_pid_in_group
211 | first_pid_m = group_id * GROUP_SIZE_M
212 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
213 | pid_m = first_pid_m + (pid % group_size_m)
214 | pid_k = (pid % num_pid_in_group) // group_size_m
215 |
216 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
217 | offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
218 | offs_n = tl.arange(0, BLOCK_SIZE_N)
219 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
220 | a_mask = (offs_am[:, None] < M)
221 | # b_ptrs is set up such that it repeats elements along the K axis 8 times
222 | b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
223 | g_ptrs = g_ptr + offs_bk
224 | g_idx = tl.load(g_ptrs)
225 |
226 | # shifter is used to extract the N bits of each element in the 32-bit word from B
227 | scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
228 | zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
229 |
230 | shifter = (offs_bk % infearure_per_bits) * bits
231 | zeros_shifter = (offs_n % infearure_per_bits) * bits
232 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
233 |
234 | for n in range(0, num_pid_n):
235 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
236 | scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
237 | zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
238 |
239 | zeros = (zeros >> zeros_shifter[None, :]) & maxq
240 | zeros = (zeros + 1)
241 |
242 | a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
243 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
244 |
245 | # Now we need to unpack b (which is N-bit values) into 32-bit values
246 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
247 | b = (b - zeros) * scales # Scale and shift
248 | b = tl.trans(b)
249 |
250 | accumulator += tl.dot(a, b)
251 | a_ptrs += BLOCK_SIZE_N
252 | b_ptrs += BLOCK_SIZE_N
253 | scales_ptrs += BLOCK_SIZE_N
254 | zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
255 |
256 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
257 | c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
258 | tl.store(c_ptrs, accumulator, mask=c_mask)
259 | except ImportError:
260 |     print('triton not installed.')
261 |
262 |
263 | def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
264 | with torch.cuda.device(input.device):
265 | output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
266 | grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
267 | matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
268 | qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
269 | return output
270 |
271 |
272 | def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
273 | with torch.cuda.device(input.device):
274 | output_dim = (qweight.shape[0] * 32) // bits
275 | output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
276 | grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
277 | transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
278 | qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
279 | return output
280 |
281 |
282 | class QuantLinearFunction(torch.autograd.Function):
283 |
284 | @staticmethod
285 | @custom_fwd(cast_inputs=torch.float16)
286 | def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
287 | output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
288 | ctx.save_for_backward(qweight, scales, qzeros, g_idx)
289 | ctx.bits, ctx.maxq = bits, maxq
290 | return output
291 |
292 | @staticmethod
293 | @custom_bwd
294 | def backward(ctx, grad_output):
295 | qweight, scales, qzeros, g_idx = ctx.saved_tensors
296 | bits, maxq = ctx.bits, ctx.maxq
297 | grad_input = None
298 |
299 | if ctx.needs_input_grad[0]:
300 | grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
301 | return grad_input, None, None, None, None, None, None
302 |
303 |
304 | class QuantLinear(nn.Module):
305 |
306 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
307 | super().__init__()
308 | if bits not in [2, 4, 8]:
309 | raise NotImplementedError("Only 2,4,8 bits are supported.")
310 | self.infeatures = infeatures
311 | self.outfeatures = outfeatures
312 | self.bits = bits
313 | self.maxq = 2**self.bits - 1
314 | self.groupsize = groupsize if groupsize != -1 else infeatures
315 |
316 | self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
317 | self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
318 | self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
319 | self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
320 | if bias:
321 | self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
322 | else:
323 | self.bias = None
324 |
325 | def pack(self, linear, scales, zeros, g_idx=None):
326 | self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
327 |
328 | scales = scales.t().contiguous()
329 | zeros = zeros.t().contiguous()
330 | scale_zeros = zeros * scales
331 | self.scales = scales.clone().half()
332 | if linear.bias is not None:
333 | self.bias = linear.bias.clone().half()
334 |
335 | intweight = []
336 | for idx in range(self.infeatures):
337 | intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
338 | intweight = torch.cat(intweight, dim=1)
339 | intweight = intweight.t().contiguous()
340 | intweight = intweight.numpy().astype(np.uint32)
341 | qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
342 | i = 0
343 | row = 0
344 | while row < qweight.shape[0]:
345 | if self.bits in [2, 4, 8]:
346 | for j in range(i, i + (32 // self.bits)):
347 | qweight[row] |= intweight[j] << (self.bits * (j - i))
348 | i += 32 // self.bits
349 | row += 1
350 | else:
351 | raise NotImplementedError("Only 2,4,8 bits are supported.")
352 |
353 | qweight = qweight.astype(np.int32)
354 | self.qweight = torch.from_numpy(qweight)
355 |
356 | zeros -= 1
357 | zeros = zeros.numpy().astype(np.uint32)
358 | qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
359 | i = 0
360 | col = 0
361 | while col < qzeros.shape[1]:
362 | if self.bits in [2, 4, 8]:
363 | for j in range(i, i + (32 // self.bits)):
364 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
365 | i += 32 // self.bits
366 | col += 1
367 | else:
368 | raise NotImplementedError("Only 2,4,8 bits are supported.")
369 |
370 | qzeros = qzeros.astype(np.int32)
371 | self.qzeros = torch.from_numpy(qzeros)
372 |
373 | def forward(self, x):
374 | out_shape = x.shape[:-1] + (self.outfeatures, )
375 | out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
376 | out = out + self.bias if self.bias is not None else out
377 | return out.reshape(out_shape)
378 |
379 |
380 | def make_quant_linear(module, names, bits, groupsize, name=''):
381 | if isinstance(module, QuantLinear):
382 | return
383 | for attr in dir(module):
384 | tmp = getattr(module, attr)
385 | name1 = name + '.' + attr if name != '' else attr
386 | if name1 in names:
387 | delattr(module, attr)
388 | setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
389 | for name1, child in module.named_children():
390 | make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
391 |
392 |
393 | def autotune_warmup_linear(model, transpose=False):
394 | """
395 | Pre-tunes the quantized kernel
396 | """
397 | from tqdm import tqdm
398 |
399 | kn_values = {}
400 |
401 | for _, m in model.named_modules():
402 | if not isinstance(m, QuantLinear):
403 | continue
404 |
405 | k = m.infeatures
406 | n = m.outfeatures
407 |
408 | if (k, n) not in kn_values:
409 | kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
410 |
411 | print(f'Found {len(kn_values)} unique KN Linear values.')
412 |
413 | print('Warming up autotune cache ...')
414 | with torch.no_grad():
415 | for m in tqdm(range(0, 12)):
416 | m = 2**m # [1, 2048]
417 | for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
418 | a = torch.randn(m, k, dtype=torch.float16, device='cuda')
419 | matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
420 | if transpose:
421 | a = torch.randn(m, n, dtype=torch.float16, device='cuda')
422 | transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
423 | del kn_values
424 |
--------------------------------------------------------------------------------
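The packing convention used by QuantLinear.pack and unpacked by matmul_248_kernel is worth spelling out: 32 // bits quantized values share one int32, value j of a group sits at bit offset bits * j, and the kernel recovers it as (word >> shift) & maxq before dequantizing with scale * (q - zero) (pack stores zeros - 1 and the kernel adds the 1 back). A small stand-alone sketch of the round trip for 4-bit values, not part of the repository:

bits = 4
maxq = 2**bits - 1
vals = list(range(8))                  # eight 4-bit values: 0..7

packed = 0
for j, v in enumerate(vals):
    packed |= v << (bits * j)          # same shift as in QuantLinear.pack

unpacked = [(packed >> (bits * j)) & maxq for j in range(32 // bits)]
assert unpacked == vals
print(hex(packed))                     # 0x76543210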
/quant/quantizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class Quantizer(nn.Module):
8 |
9 | def __init__(self, shape=1):
10 | super(Quantizer, self).__init__()
11 | self.register_buffer('maxq', torch.tensor(0))
12 | self.register_buffer('scale', torch.zeros(shape))
13 | self.register_buffer('zero', torch.zeros(shape))
14 |
15 | def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
16 |
17 | self.maxq = torch.tensor(2**bits - 1)
18 | self.perchannel = perchannel
19 | self.sym = sym
20 | self.mse = mse
21 | self.norm = norm
22 | self.grid = grid
23 | self.maxshrink = maxshrink
24 | if trits:
25 | self.maxq = torch.tensor(-1)
26 | self.scale = torch.zeros_like(self.scale)
27 |
28 | def _quantize(self, x, scale, zero, maxq):
29 | if maxq < 0:
30 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
31 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
32 | return scale * (q - zero)
33 |
34 | def find_params(self, x, weight=False):
35 | dev = x.device
36 | self.maxq = self.maxq.to(dev)
37 |
38 | shape = x.shape
39 | if self.perchannel:
40 | if weight:
41 | x = x.flatten(1)
42 | else:
43 | if len(shape) == 4:
44 | x = x.permute([1, 0, 2, 3])
45 | x = x.flatten(1)
46 | if len(shape) == 3:
47 | x = x.reshape((-1, shape[-1])).t()
48 | if len(shape) == 2:
49 | x = x.t()
50 | else:
51 | x = x.flatten().unsqueeze(0)
52 |
53 | tmp = torch.zeros(x.shape[0], device=dev)
54 | xmin = torch.minimum(x.min(1)[0], tmp)
55 | xmax = torch.maximum(x.max(1)[0], tmp)
56 |
57 | if self.sym:
58 | xmax = torch.maximum(torch.abs(xmin), xmax)
59 | tmp = xmin < 0
60 | if torch.any(tmp):
61 | xmin[tmp] = -xmax[tmp]
62 | tmp = (xmin == 0) & (xmax == 0)
63 | xmin[tmp] = -1
64 | xmax[tmp] = +1
65 |
66 | if self.maxq < 0:
67 | self.scale = xmax
68 | self.zero = xmin
69 | else:
70 | self.scale = (xmax - xmin) / self.maxq
71 | if self.sym:
72 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
73 | else:
74 | self.zero = torch.round(-xmin / self.scale)
75 |
76 | if self.mse:
77 | best = torch.full([x.shape[0]], float('inf'), device=dev)
78 | for i in range(int(self.maxshrink * self.grid)):
79 | p = 1 - i / self.grid
80 | xmin1 = p * xmin
81 | xmax1 = p * xmax
82 | scale1 = (xmax1 - xmin1) / self.maxq
83 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
84 | q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
85 | q -= x
86 | q.abs_()
87 | q.pow_(self.norm)
88 | err = torch.sum(q, 1)
89 | tmp = err < best
90 | if torch.any(tmp):
91 | best[tmp] = err[tmp]
92 | self.scale[tmp] = scale1[tmp]
93 | self.zero[tmp] = zero1[tmp]
94 | if not self.perchannel:
95 | if weight:
96 | tmp = shape[0]
97 | else:
98 | tmp = shape[1] if len(shape) != 3 else shape[2]
99 | self.scale = self.scale.repeat(tmp)
100 | self.zero = self.zero.repeat(tmp)
101 |
102 | if weight:
103 | shape = [-1] + [1] * (len(shape) - 1)
104 | self.scale = self.scale.reshape(shape)
105 | self.zero = self.zero.reshape(shape)
106 | return
107 | if len(shape) == 4:
108 | self.scale = self.scale.reshape((1, -1, 1, 1))
109 | self.zero = self.zero.reshape((1, -1, 1, 1))
110 | if len(shape) == 3:
111 | self.scale = self.scale.reshape((1, 1, -1))
112 | self.zero = self.zero.reshape((1, 1, -1))
113 | if len(shape) == 2:
114 | self.scale = self.scale.unsqueeze(0)
115 | self.zero = self.zero.unsqueeze(0)
116 |
117 | def quantize(self, x):
118 | if self.ready():
119 | return self._quantize(x, self.scale, self.zero, self.maxq)
120 |
121 | return x
122 |
123 | def enabled(self):
124 | return self.maxq > 0
125 |
126 | def ready(self):
127 | return torch.all(self.scale != 0)
128 |
--------------------------------------------------------------------------------
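A short usage sketch for the Quantizer above, not part of the repository and assuming the repo root is on PYTHONPATH: 4-bit asymmetric per-channel weight quantization of a random matrix, followed by a check of the round-trip error. quantize() both quantizes and dequantizes, so the result stays in floating point.

import torch
from quant.quantizer import Quantizer

torch.manual_seed(0)
W = torch.randn(128, 64)               # (out_features, in_features)

q = Quantizer()
q.configure(bits=4, perchannel=True, sym=False, mse=False)
q.find_params(W, weight=True)          # per-row scale and zero point
W_deq = q.quantize(W)                  # quantize + dequantize

print((W - W_deq).abs().max().item())  # worst-case rounding error, roughly scale / 2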
/quant/triton_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import triton
4 | import triton.language as tl
5 | from transformers.models.llama.modeling_llama import LlamaRMSNorm
6 |
7 | @triton.jit
8 | def rms_norm_fwd_fused(
9 | X, # pointer to the input
10 | Y, # pointer to the output
11 | W, # pointer to the weights
12 | stride, # how much to increase the pointer when moving by 1 row
13 | N, # number of columns in X
14 | eps, # epsilon to avoid division by zero
15 | BLOCK_SIZE: tl.constexpr,
16 | ):
17 | # Map the program id to the row of X and Y it should compute.
18 | row = tl.program_id(0)
19 | Y += row * stride
20 | X += row * stride
21 | # Compute variance
22 | _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
23 | for off in range(0, N, BLOCK_SIZE):
24 | cols = off + tl.arange(0, BLOCK_SIZE)
25 | x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
26 | x = tl.where(cols < N, x, 0.)
27 | _var += x * x
28 | var = tl.sum(_var, axis=0) / N
29 | rstd = 1 / tl.sqrt(var + eps)
30 | # Normalize and apply linear transformation
31 | for off in range(0, N, BLOCK_SIZE):
32 | cols = off + tl.arange(0, BLOCK_SIZE)
33 | mask = cols < N
34 | w = tl.load(W + cols, mask=mask)
35 | x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
36 | x_hat = x * rstd
37 | y = x_hat * w
38 | # Write output
39 | tl.store(Y + cols, y, mask=mask)
40 |
41 | class TritonLlamaRMSNorm(nn.Module):
42 | def __init__(self, weight, eps=1e-6):
43 | """
44 | LlamaRMSNorm is equivalent to T5LayerNorm
45 | """
46 | super().__init__()
47 | self.weight = weight
48 | self.variance_epsilon = eps
49 |
50 | def forward(self, x):
51 | with torch.cuda.device(x.device):
52 | y = torch.empty_like(x)
53 | # reshape input data into 2D tensor
54 | x_arg = x.reshape(-1, x.shape[-1])
55 | M, N = x_arg.shape
56 | # Less than 64KB per feature: enqueue fused kernel
57 | MAX_FUSED_SIZE = 65536 // x.element_size()
58 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
59 | if N > BLOCK_SIZE:
60 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
61 | # heuristics for number of warps
62 | num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
63 | # enqueue kernel
64 | rms_norm_fwd_fused[(M,)](x_arg, y, self.weight,
65 | x_arg.stride(0), N, self.variance_epsilon,
66 | BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
67 | return y
68 |
69 |
70 | def make_quant_norm(model):
71 | """
72 | Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules
73 | """
74 |
75 | for name, m in model.named_modules():
76 | if not isinstance(m, LlamaRMSNorm):
77 | continue
78 |
79 | norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon)
80 |
81 | if '.' in name:
82 | parent_name = name.rsplit('.', 1)[0]
83 | child_name = name[len(parent_name) + 1:]
84 | parent = model.get_submodule(parent_name)
85 | else:
86 | parent_name = ''
87 | parent = model
88 | child_name = name
89 |
90 |         #print(f"Replacing {name} with TritonLlamaRMSNorm; parent: {parent_name}, child's name: {child_name}")
91 |
92 | setattr(parent, child_name, norm)
93 |
--------------------------------------------------------------------------------
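For reference, the eager-mode computation that rms_norm_fwd_fused replaces is RMS normalization (no mean subtraction) followed by a learned per-channel scale. The sketch below is not part of the repository and is only meant as a numerical baseline for TritonLlamaRMSNorm.

import torch

def rms_norm_reference(x, weight, eps=1e-6):
    # Matches the kernel: the mean square is computed over the last dimension in
    # float32, and the result is cast back to the input dtype on store.
    x32 = x.to(torch.float32)
    rstd = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (x32 * rstd * weight).to(x.dtype)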
/requirements.txt:
--------------------------------------------------------------------------------
1 | safetensors==0.3.1
2 | datasets==2.10.1
3 | sentencepiece
4 | git+https://github.com/huggingface/transformers
5 | accelerate==0.20.3
6 | triton==2.0.0
7 | texttable
8 | toml
9 | numpy
10 | protobuf==3.20.2
11 |
12 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .modelutils import DEV, find_layers, gen_conditions, torch_snr_error
2 | from .datautils import set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
3 | from .export import export_quant_table
4 |
--------------------------------------------------------------------------------
/utils/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def set_seed(seed):
6 | np.random.seed(seed)
7 | torch.random.manual_seed(seed)
8 |
9 |
10 | def get_wikitext2(nsamples, seed, seqlen, model):
11 | from datasets import load_dataset
12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
14 |
15 | from transformers import AutoTokenizer
16 | try:
17 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
18 | except:
19 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
20 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
21 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
22 |
23 | import random
24 | random.seed(seed)
25 | trainloader = []
26 | for _ in range(nsamples):
27 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
28 | j = i + seqlen
29 | inp = trainenc.input_ids[:, i:j]
30 | tar = inp.clone()
31 | tar[:, :-1] = -100
32 | trainloader.append((inp, tar))
33 | return trainloader, testenc
34 |
35 |
36 | def get_ptb(nsamples, seed, seqlen, model):
37 | from datasets import load_dataset
38 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
39 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
40 |
41 | from transformers import AutoTokenizer
42 | try:
43 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
44 | except:
45 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
46 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
47 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
48 |
49 | import random
50 | random.seed(seed)
51 | trainloader = []
52 | for _ in range(nsamples):
53 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
54 | j = i + seqlen
55 | inp = trainenc.input_ids[:, i:j]
56 | tar = inp.clone()
57 | tar[:, :-1] = -100
58 | trainloader.append((inp, tar))
59 | return trainloader, testenc
60 |
61 |
62 | def get_c4(nsamples, seed, seqlen, model):
63 | from datasets import load_dataset
64 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
65 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False)
66 |
67 | from transformers import AutoTokenizer
68 | try:
69 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
70 | except:
71 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
72 |
73 | import random
74 | random.seed(seed)
75 | trainloader = []
76 | for _ in range(nsamples):
77 | while True:
78 | i = random.randint(0, len(traindata) - 1)
79 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
80 | if trainenc.input_ids.shape[1] >= seqlen:
81 | break
82 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
83 | j = i + seqlen
84 | inp = trainenc.input_ids[:, i:j]
85 | tar = inp.clone()
86 | tar[:, :-1] = -100
87 | trainloader.append((inp, tar))
88 |
89 | import random
90 | random.seed(0)
91 | valenc = []
92 | for _ in range(256):
93 | while True:
94 | i = random.randint(0, len(valdata) - 1)
95 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
96 | if tmp.input_ids.shape[1] >= seqlen:
97 | break
98 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
99 | j = i + seqlen
100 | valenc.append(tmp.input_ids[:, i:j])
101 | valenc = torch.hstack(valenc)
102 |
103 | class TokenizerWrapper:
104 |
105 | def __init__(self, input_ids):
106 | self.input_ids = input_ids
107 |
108 | valenc = TokenizerWrapper(valenc)
109 |
110 | return trainloader, valenc
111 |
112 |
113 | def get_ptb_new(nsamples, seed, seqlen, model):
114 | from datasets import load_dataset
115 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
116 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
117 |
118 | from transformers import AutoTokenizer
119 | try:
120 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
121 | except:
122 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
123 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
124 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
125 |
126 | import random
127 | random.seed(seed)
128 | trainloader = []
129 | for _ in range(nsamples):
130 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
131 | j = i + seqlen
132 | inp = trainenc.input_ids[:, i:j]
133 | tar = inp.clone()
134 | tar[:, :-1] = -100
135 | trainloader.append((inp, tar))
136 | return trainloader, testenc
137 |
138 |
139 | def get_c4_new(nsamples, seed, seqlen, model):
140 | from datasets import load_dataset
141 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
142 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
143 |
144 | from transformers import AutoTokenizer
145 | try:
146 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
147 | except:
148 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
149 |
150 | import random
151 | random.seed(seed)
152 | trainloader = []
153 | for _ in range(nsamples):
154 | while True:
155 | i = random.randint(0, len(traindata) - 1)
156 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
157 | if trainenc.input_ids.shape[1] >= seqlen:
158 | break
159 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
160 | j = i + seqlen
161 | inp = trainenc.input_ids[:, i:j]
162 | tar = inp.clone()
163 | tar[:, :-1] = -100
164 | trainloader.append((inp, tar))
165 |
166 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
167 | valenc = valenc.input_ids[:, :(256 * seqlen)]
168 |
169 | class TokenizerWrapper:
170 |
171 | def __init__(self, input_ids):
172 | self.input_ids = input_ids
173 |
174 | valenc = TokenizerWrapper(valenc)
175 |
176 | return trainloader, valenc
177 |
178 |
179 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
180 | if 'wikitext2' in name:
181 | return get_wikitext2(nsamples, seed, seqlen, model)
182 | if 'ptb' in name:
183 | if 'new' in name:
184 | return get_ptb_new(nsamples, seed, seqlen, model)
185 | return get_ptb(nsamples, seed, seqlen, model)
186 | if 'c4' in name:
187 | if 'new' in name:
188 | return get_c4_new(nsamples, seed, seqlen, model)
189 | return get_c4(nsamples, seed, seqlen, model)
190 |
--------------------------------------------------------------------------------
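A usage sketch for get_loaders, not part of the repository: fetch 128 calibration windows of 2048 tokens from wikitext2. The model string is a placeholder; it is only used to load the matching tokenizer. Each training sample is an (input_ids, targets) pair in which every target token except the last is masked with -100.

from utils.datautils import get_loaders

trainloader, testenc = get_loaders('wikitext2', nsamples=128, seed=0, seqlen=2048,
                                   model='decapoda-research/llama-7b-hf')  # placeholder model path

inp, tar = trainloader[0]
print(inp.shape)                # torch.Size([1, 2048])
print(testenc.input_ids.shape)  # full tokenized test split, used for perplexity evaluation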
/utils/export.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import toml
3 | import os
4 |
5 |
6 | def export_quant_table(quantizers: dict, quant_dir: str, format: str = 'toml'):
7 |
8 |     os.makedirs(quant_dir, exist_ok=True)  # create the output dir before save_tensor writes .npy files into it
9 |     table = {}
10 | def save_tensor(name: str, tensor):
11 | np.save(os.path.join(quant_dir, name), tensor.numpy())
12 | return '{}.npy'.format(name)
13 |
14 | for key, value in quantizers.items():
15 | quantizer = value[0]
16 |
17 | dump = dict()
18 |
19 | sym = quantizer.sym
20 | if not sym:
21 | dump['zero'] = save_tensor(name=key + '.zero', tensor=value[2])
22 | dump['scale'] = save_tensor(name=key + '.scale', tensor=value[1])
23 | dump['wbits'] = value[4]
24 | dump['groupsize'] = value[5]
25 | if value[5] > 0:
26 | dump['group_ids'] = save_tensor(name=key + '.group_ids', tensor=value[3])
27 |
28 | dump['sym'] = sym
29 | dump['perchannel'] = quantizer.perchannel
30 |
31 | table[key] = dump
32 |
33 | if not os.path.exists(quant_dir):
34 | os.mkdir(quant_dir)
35 |
36 | with open(os.path.join(quant_dir, 'quant.toml'), 'w') as f:
37 | toml.dump(table, f)
38 |
--------------------------------------------------------------------------------
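A usage sketch for export_quant_table, not part of the repository. The per-layer tuple layout is inferred from the indexing above: (quantizer, scales, zeros, group_ids, wbits, groupsize); the layer name and output directory are placeholders, and the directory is created explicitly here so the .npy files have somewhere to go.

import os
import torch
from quant.quantizer import Quantizer
from utils.export import export_quant_table

q = Quantizer()
q.configure(bits=4, perchannel=True, sym=False)

quantizers = {
    'model.layers.0.self_attn.q_proj': (          # placeholder layer name
        q,                                        # the fitted Quantizer
        torch.ones(32, 4096),                     # scales, one row per group
        torch.zeros(32, 4096),                    # zero points
        torch.zeros(4096, dtype=torch.int32),     # group ids (g_idx)
        4,                                        # wbits
        128,                                      # groupsize
    ),
}

os.makedirs('quant-table', exist_ok=True)
export_quant_table(quantizers, 'quant-table')     # writes quant-table/quant.toml plus the .npy tensors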
/utils/modelutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | DEV = torch.device('cuda:0')
5 |
6 |
7 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
8 | if type(module) in layers:
9 | return {name: module}
10 | res = {}
11 | for name1, child in module.named_children():
12 | res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
13 | return res
14 |
15 |
16 | def gen_conditions(_wbits, _groupsize):
17 | wbits = _wbits
18 | groupsize = _groupsize
19 | conditions = []
20 | while True:
21 | if wbits >= 8:
22 | if groupsize == -1 or groupsize == 32:
23 | break
24 |
25 | if groupsize > 32:
26 | groupsize /= 2
27 | else:
28 | wbits *= 2
29 | groupsize = _groupsize
30 |
31 | conditions.append((int(wbits), int(groupsize)))
32 | return conditions
33 |
34 |
35 | # copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
36 | def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
37 | """
38 |     Compute the SNR error between y_pred (tensor) and y_real (tensor).
39 |
40 |     SNR can be calculated with the following equation:
41 |
42 |         SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
43 |
44 |     If x and y are matrices, the SNR error over the matrix is the mean of the per-element SNR errors:
45 |
46 |         SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
47 |     Args:
48 |         y_pred (torch.Tensor): predicted (e.g. quantized) tensor.
49 |         y_real (torch.Tensor): reference tensor.
50 |         reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.
51 |     Raises:
52 |         ValueError: if y_pred and y_real have different shapes.
53 |         ValueError: if the reduction method is unsupported.
54 |     Returns:
55 |         torch.Tensor: the SNR error.
56 | """
57 | y_pred = y_pred.type(torch.float32)
58 | y_real = y_real.type(torch.float32)
59 |
60 | if y_pred.shape != y_real.shape:
61 | raise ValueError(f'Can not compute snr loss for tensors with different shape. '
62 | f'({y_pred.shape} and {y_real.shape})')
63 | reduction = str(reduction).lower()
64 |
65 | if y_pred.ndim == 1:
66 | y_pred = y_pred.unsqueeze(0)
67 | y_real = y_real.unsqueeze(0)
68 |
69 | y_pred = y_pred.flatten(start_dim=1)
70 | y_real = y_real.flatten(start_dim=1)
71 |
72 | noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
73 | signal_power = torch.pow(y_real, 2).sum(dim=-1)
74 | snr = (noise_power) / (signal_power + 1e-7)
75 |
76 | if reduction == 'mean':
77 | return torch.mean(snr)
78 | elif reduction == 'sum':
79 | return torch.sum(snr)
80 | elif reduction == 'none':
81 | return snr
82 | else:
83 |         raise ValueError(f'Unsupported reduction method: {reduction}.')
84 |
--------------------------------------------------------------------------------
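Finally, a short sketch (not part of the repository) tying find_layers and torch_snr_error together: collect the Linear layers of a toy module, perturb one weight matrix as a stand-in for quantization error, and measure the resulting SNR error.

import torch
import torch.nn as nn
from utils.modelutils import find_layers, torch_snr_error

block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
layers = find_layers(block)
print(list(layers.keys()))                # ['0', '2'] -- only Conv2d/Linear modules are returned

w = layers['0'].weight.data
w_noisy = w + 0.01 * torch.randn_like(w)  # stand-in for quantization error
print(torch_snr_error(w_noisy, w).item()) # mean noise-to-signal ratio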