├── .gitignore ├── .style.yapf ├── LICENSE.txt ├── README.md ├── convert_llama_weights_to_hf.py ├── gptq.py ├── llama.py ├── llama_inference.py ├── llama_inference_offload.py ├── neox.py ├── opt.py ├── quant ├── __init__.py ├── custom_autotune.py ├── fused_attn.py ├── fused_mlp.py ├── quant_linear.py ├── quantizer.py └── triton_norm.py ├── requirements.txt └── utils ├── __init__.py ├── datautils.py ├── export.py └── modelutils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | column_limit = 200 4 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPTQ-for-LLaMA 2 | 3 | **I am currently focusing on [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) and recommend using [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) instead of GPTQ-for-LLaMa.** 4 | 5 | 6 | 7 | 4-bit quantization of [LLaMA](https://arxiv.org/abs/2302.13971) using [GPTQ](https://arxiv.org/abs/2210.17323) 8 | 9 | GPTQ is a SOTA one-shot weight quantization method. 10 | 11 | **It can be used universally, but it is not the [fastest](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/old-cuda) and it only supports Linux.** 12 | 13 | **Triton only supports Linux, so if you are a Windows user, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install).** 14 | 15 | ## News or Update 16 | **AutoGPTQ-triton, a packaged version of GPTQ with Triton, has been integrated into [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ).** 17 | ## Result 18 |
19 | LLaMA-7B(click me) 20 | 21 | | [LLaMA-7B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) | 22 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- | 23 | | FP16 | 16 | - | 13940 | 5.68 | 12.5 | 24 | | RTN | 4 | - | - | 6.29 | - | 25 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 4740 | 6.09 | 3.5 | 26 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 4891 | 5.85 | 3.6 | 27 | | RTN | 3 | - | - | 25.54 | - | 28 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 3852 | 8.07 | 2.7 | 29 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 4116 | 6.61 | 3.0 | 30 | 31 |
32 | 33 |
34 | LLaMA-13B 35 | 36 | | [LLaMA-13B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) | 37 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- | 38 | | FP16 | 16 | - | OOM | 5.09 | 24.2 | 39 | | RTN | 4 | - | - | 5.53 | - | 40 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 8410 | 5.36 | 6.5 | 41 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 8747 | 5.20 | 6.7 | 42 | | RTN | 3 | - | - | 11.40 | - | 43 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 6870 | 6.63 | 5.1 | 44 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 7277 | 5.62 | 5.4 | 45 | 46 |
47 | 48 |
49 | LLaMA-33B 50 | 51 | | [LLaMA-33B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) | 52 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- | 53 | | FP16 | 16 | - | OOM | 4.10 | 60.5 | 54 | | RTN | 4 | - | - | 4.54 | - | 55 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 19493 | 4.45 | 15.7 | 56 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 20570 | 4.23 | 16.3 | 57 | | RTN | 3 | - | - | 14.89 | - | 58 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 15493 | 5.69 | 12.0 | 59 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 16566 | 4.80 | 13.0 | 60 | 61 |
62 | 63 |
64 | LLaMA-65B 65 | 66 | | [LLaMA-65B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) | 67 | | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- | 68 | | FP16 | 16 | - | OOM | 3.53 | 121.0 | 69 | | RTN | 4 | - | - | 3.92 | - | 70 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | OOM | 3.84 | 31.1 | 71 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | OOM | 3.65 | 32.3 | 72 | | RTN | 3 | - | - | 10.59 | - | 73 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | OOM | 5.04 | 23.6 | 74 | | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | OOM | 4.17 | 25.6 | 75 |
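The Wikitext2 columns in the tables above are perplexities over non-overlapping 2048-token windows, following `llama_eval()` in `llama.py`. The snippet below is a simplified, hedged sketch of that evaluation using the full model directly rather than the layer-by-layer path in this repo; the model path is a placeholder and it assumes `transformers`/`datasets` are installed and enough GPU memory is available:

```python
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "./llama-hf/llama-7b"  # placeholder path to HF-format LLaMA weights
seqlen = 2048

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda().eval()

# Join the raw wikitext-2 test split into one long token stream.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
enc = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids

nlls = []
nsamples = enc.numel() // seqlen
for i in range(nsamples):
    batch = enc[:, i * seqlen:(i + 1) * seqlen].cuda()
    with torch.no_grad():
        logits = model(batch).logits
    # Shift so position t predicts token t+1, then accumulate the summed NLL of the window.
    shift_logits = logits[:, :-1, :].float()
    shift_labels = batch[:, 1:]
    loss = nn.CrossEntropyLoss()(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
    nlls.append(loss * seqlen)

ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
print(f"wikitext2 perplexity: {ppl.item():.2f}")
```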
76 | 77 | Quantization requires a large amount of CPU memory; however, the amount required can be reduced by using swap memory. 78 | 79 | Depending on the GPUs/drivers, there may be a difference in performance, and this difference shrinks as the model size increases (see https://github.com/IST-DASLab/gptq/issues/1). 80 | 81 | According to the [GPTQ paper](https://arxiv.org/abs/2210.17323), the performance gap between FP16 and GPTQ also decreases as the model size increases. 82 | 83 | ## GPTQ vs [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 84 | 85 |
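The nf4, nf4-double_quant, and fp4 rows in the tables below load the FP16 weights through bitsandbytes' 4-bit support rather than a GPTQ checkpoint. For reference, here is a minimal sketch of how such a baseline can be loaded. It is not part of this repository; it assumes a recent `transformers`/`bitsandbytes` install that provides `BitsAndBytesConfig`, and the model path is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "./llama-hf/llama-7b"  # placeholder path to HF-format LLaMA weights

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",        # use "fp4" for the fp4 rows
    bnb_4bit_use_double_quant=True,   # set False for the plain nf4/fp4 rows
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
```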
86 | LLaMA-7B(click me) 87 | 88 | | [LLaMA-7B(seqlen=2048)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) | 89 | | --------------------------------------------------------------- | ------------------- | ----------- | --------- | 90 | | FP16 | 16 | 13948 | 5.22 | 91 | | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 4781 | 5.30 | 92 | | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 4804 | 5.30 | 93 | | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 5102 | 5.30 | 94 | | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 5102 | 5.33 | 95 | 96 |
97 | 98 |
99 | LLaMA-13B 100 | 101 | | [LLaMA-13B(seqlen=2048)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) | 102 | | ---------------------------------------------------------------- | ------------------- | ----------- | --------- | 103 | | FP16 | 16 | OOM | - | 104 | | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 8589 | 5.02 | 105 | | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 8581 | 5.04 | 106 | | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 9170 | 5.04 | 107 | | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 9170 | 5.11 | 108 |
109 | 110 |
111 | LLaMA-33B 112 | 113 | | [LLaMA-33B(seqlen=1024)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) | 114 | | ---------------------------------------------------------------- | ------------------- | ----------- | --------- | 115 | | FP16 | 16 | OOM | - | 116 | | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 18441 | 3.71 | 117 | | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 18313 | 3.76 | 118 | | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 19729 | 3.75 | 119 | | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 19729 | 3.75 | 120 | 121 |
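The GPTQ-128g rows above correspond to 4-bit, groupsize-128 checkpoints of the kind produced by `llama.py` (see the Language Generation section below). A quantized checkpoint can also be loaded programmatically through `load_quant` from `llama_inference.py`; the following is a minimal sketch, assuming it is run from the repository root so that `llama_inference` and the `quant` package are importable, with placeholder paths:

```python
import torch
from transformers import AutoTokenizer

from llama_inference import load_quant  # defined in llama_inference.py

model_dir = "./llama-hf/llama-7b"             # HF-format LLaMA weights (placeholder)
checkpoint = "llama7b-4bit-128g.safetensors"  # quantized checkpoint (placeholder)

# Rebuild the model skeleton and load the packed 4-bit weights.
model = load_quant(model_dir, checkpoint, wbits=4, groupsize=128)
model.to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
input_ids = tokenizer.encode("this is llama", return_tensors="pt").to("cuda")

with torch.no_grad():
    output = model.generate(input_ids, do_sample=True, min_length=10, max_length=50, top_p=0.95, temperature=0.8)
print(tokenizer.decode(output[0]))
```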
122 | 123 | ## Installation 124 | If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first. 125 | ``` 126 | conda create --name gptq python=3.9 -y 127 | conda activate gptq 128 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia 129 | # Or, if you're having trouble with conda, use pip with python3.9: 130 | # pip3 install torch torchvision torchaudio 131 | 132 | git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa 133 | cd GPTQ-for-LLaMa 134 | pip install -r requirements.txt 135 | ``` 136 | ## Dependencies 137 | 138 | * `torch`: tested on v2.0.0+cu117 139 | * `transformers`: tested on v4.28.0.dev0 140 | * `datasets`: tested on v2.10.1 141 | * `safetensors`: tested on v0.3.0 142 | 143 | All experiments were run on a single NVIDIA RTX 3090. 144 | 145 | # Language Generation 146 | ## LLaMA 147 | 148 | ``` 149 | # Convert LLaMA weights to HF format 150 | python convert_llama_weights_to_hf.py --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir ./llama-hf 151 | 152 | # Benchmark language generation with 4-bit LLaMA-7B: 153 | 154 | # Save compressed model 155 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save llama7b-4bit-128g.pt 156 | 157 | # Or save compressed `.safetensors` model 158 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save_safetensors llama7b-4bit-128g.safetensors 159 | 160 | # Benchmark generating a 2048-token sequence with the saved model 161 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --benchmark 2048 --check 162 | 163 | # Benchmark the FP16 baseline; note that the model will be split across all listed GPUs 164 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python llama.py ${MODEL_DIR} c4 --benchmark 2048 --check 165 | 166 | # Model inference with the saved model 167 | CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama" 168 | 169 | # Model inference with the saved model, using safetensors loaded directly to the GPU 170 | CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.safetensors --text "this is llama" --device=0 171 | 172 | # Model inference with the saved model with CPU offload (this is very slow) 173 | CUDA_VISIBLE_DEVICES=0 python llama_inference_offload.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama" --pre_layer 16 174 | # With LLaMA-65B and --pre_layer 50, generating 45 tokens (5 -> 50 tokens) takes about 180 seconds on a single RTX 3090. 175 | ``` 176 | In general, 4-bit quantization with a groupsize of 128 is recommended. 177 | 178 | You can also export the quantization parameters in toml+numpy format. 179 | ``` 180 | CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --quant-directory ${TOML_DIR} 181 | ``` 182 | 183 | # Acknowledgements 184 | This code is based on [GPTQ](https://github.com/IST-DASLab/gptq). 185 | 186 | Thanks to Meta AI for releasing [LLaMA](https://arxiv.org/abs/2302.13971), a powerful LLM.
187 | 188 | Triton GPTQ kernel code is based on [GPTQ-triton](https://github.com/fpgaminer/GPTQ-triton) 189 | -------------------------------------------------------------------------------- /convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from transformers.models.llama.convert_llama_weights_to_hf import write_model, write_tokenizer 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--input_dir", 10 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 11 | ) 12 | parser.add_argument( 13 | "--model_size", 14 | choices=["7B", "13B", "30B", "65B", "tokenizer_only"], 15 | ) 16 | parser.add_argument( 17 | "--output_dir", 18 | help="Location to write HF model and tokenizer", 19 | ) 20 | args = parser.parse_args() 21 | if args.model_size != "tokenizer_only": 22 | write_model( 23 | model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()), 24 | input_base_path=os.path.join(args.input_dir, args.model_size), 25 | model_size=args.model_size, 26 | ) 27 | write_tokenizer( 28 | tokenizer_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()), 29 | input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"), 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /gptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | import quant 8 | from texttable import Texttable 9 | from utils import torch_snr_error 10 | 11 | torch.backends.cuda.matmul.allow_tf32 = False 12 | torch.backends.cudnn.allow_tf32 = False 13 | 14 | 15 | class Observer: 16 | 17 | def __init__(self, topk=32): 18 | self.loss_list = [] 19 | self.topk = topk 20 | 21 | def submit(self, name: str, layerid: int, gptq, error: float): 22 | 23 | item = (name, layerid, {'gptq': gptq, 'error': error}) 24 | 25 | if len(self.loss_list) < self.topk: 26 | self.loss_list.append(item) 27 | return 28 | 29 | min_error = error 30 | min_idx = -1 31 | for idx, data in enumerate(self.loss_list): 32 | if min_error > data[2]['error']: 33 | min_idx = idx 34 | min_error = data[2]['error'] 35 | 36 | if min_idx >= 0: 37 | self.loss_list[min_idx] = item 38 | 39 | def print(self): 40 | self.loss_list = sorted(self.loss_list, key=lambda s: s[2]['error'], reverse=True) 41 | 42 | table = Texttable() 43 | 44 | table.header(['name', 'error']) 45 | table.set_cols_dtype(['t', 'f']) 46 | 47 | for item in self.loss_list: 48 | table.add_row([f"{item[0]}.{item[1]}", item[2]['error']]) 49 | print(table.draw()) 50 | print('\n') 51 | 52 | def items(self): 53 | return self.loss_list 54 | 55 | 56 | class GPTQ: 57 | 58 | def __init__(self, layer, observe=False): 59 | self.layer = layer 60 | self.dev = self.layer.weight.device 61 | W = layer.weight.data.clone() 62 | if isinstance(self.layer, nn.Conv2d): 63 | W = W.flatten(1) 64 | if isinstance(self.layer, transformers.Conv1D): 65 | W = W.t() 66 | self.rows = W.shape[0] 67 | self.columns = W.shape[1] 68 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 69 | self.nsamples = 0 70 | self.quantizer = quant.Quantizer() 71 | self.observe = observe 72 | 73 | def add_batch(self, inp, out): 74 | # Hessian H = 2 X XT + λ I 75 | if self.observe: 76 | self.inp1 = inp 77 | 
self.out1 = out 78 | else: 79 | self.inp1 = None 80 | self.out1 = None 81 | 82 | if len(inp.shape) == 2: 83 | inp = inp.unsqueeze(0) 84 | tmp = inp.shape[0] 85 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 86 | if len(inp.shape) == 3: 87 | inp = inp.reshape((-1, inp.shape[-1])) 88 | inp = inp.t() 89 | if isinstance(self.layer, nn.Conv2d): 90 | unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride) 91 | inp = unfold(inp) 92 | inp = inp.permute([1, 0, 2]) 93 | inp = inp.flatten(1) 94 | self.H *= self.nsamples / (self.nsamples + tmp) 95 | self.nsamples += tmp 96 | # inp = inp.float() 97 | inp = math.sqrt(2 / self.nsamples) * inp.float() 98 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 99 | self.H += inp.matmul(inp.t()) 100 | 101 | def print_loss(self, name, q_weight, weight_error, timecost): 102 | table = Texttable() 103 | name += ' ' * (16 - len(name)) 104 | 105 | table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time']) 106 | 107 | # assign weight 108 | self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 109 | 110 | if self.inp1 is not None: 111 | # quantize input to int8 112 | quantizer = quant.Quantizer() 113 | quantizer.configure(8, perchannel=False, sym=True, mse=False) 114 | quantizer.find_params(self.inp1) 115 | q_in = quantizer.quantize(self.inp1).type(torch.float16) 116 | q_out = self.layer(q_in) 117 | 118 | # get kinds of SNR 119 | q_SNR = torch_snr_error(q_out, self.out1).item() 120 | fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() 121 | else: 122 | q_SNR = '-' 123 | fp_SNR = '-' 124 | 125 | table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) 126 | print(table.draw().split('\n')[-2]) 127 | 128 | def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''): 129 | self.layer.to(self.dev) 130 | 131 | W = self.layer.weight.data.clone() 132 | if isinstance(self.layer, nn.Conv2d): 133 | W = W.flatten(1) 134 | if isinstance(self.layer, transformers.Conv1D): 135 | W = W.t() 136 | W = W.float() 137 | 138 | tick = time.time() 139 | 140 | if not self.quantizer.ready(): 141 | self.quantizer.find_params(W, weight=True) 142 | 143 | H = self.H 144 | if not self.observe: 145 | del self.H 146 | dead = torch.diag(H) == 0 147 | H[dead, dead] = 1 148 | W[:, dead] = 0 149 | 150 | if actorder: 151 | perm = torch.argsort(torch.diag(H), descending=True) 152 | W = W[:, perm] 153 | H = H[perm][:, perm] 154 | 155 | Losses = torch.zeros_like(W) 156 | Q = torch.zeros_like(W) 157 | 158 | damp = percdamp * torch.mean(torch.diag(H)) 159 | diag = torch.arange(self.columns, device=self.dev) 160 | H[diag, diag] += damp 161 | H = torch.linalg.cholesky(H) 162 | H = torch.cholesky_inverse(H) 163 | H = torch.linalg.cholesky(H, upper=True) 164 | Hinv = H 165 | 166 | g_idx = [] 167 | scale = [] 168 | zero = [] 169 | now_idx = 1 170 | 171 | for i1 in range(0, self.columns, blocksize): 172 | i2 = min(i1 + blocksize, self.columns) 173 | count = i2 - i1 174 | 175 | W1 = W[:, i1:i2].clone() 176 | Q1 = torch.zeros_like(W1) 177 | Err1 = torch.zeros_like(W1) 178 | Losses1 = torch.zeros_like(W1) 179 | Hinv1 = Hinv[i1:i2, i1:i2] 180 | 181 | for i in range(count): 182 | w = W1[:, i] 183 | d = Hinv1[i, i] 184 | 185 | if groupsize != -1: 186 | if (i1 + i) % groupsize == 0: 187 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 188 | 189 | if ((i1 + i) // groupsize) - 
now_idx == -1: 190 | scale.append(self.quantizer.scale) 191 | zero.append(self.quantizer.zero) 192 | now_idx += 1 193 | 194 | q = self.quantizer.quantize(w.unsqueeze(1)).flatten() 195 | Q1[:, i] = q 196 | Losses1[:, i] = (w - q)**2 / d**2 197 | 198 | err1 = (w - q) / d 199 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 200 | Err1[:, i] = err1 201 | 202 | Q[:, i1:i2] = Q1 203 | Losses[:, i1:i2] = Losses1 / 2 204 | 205 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 206 | 207 | torch.cuda.synchronize() 208 | error = torch.sum(Losses).item() 209 | 210 | groupsize = groupsize if groupsize != -1 else self.columns 211 | g_idx = [i // groupsize for i in range(self.columns)] 212 | g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) 213 | if actorder: 214 | invperm = torch.argsort(perm) 215 | Q = Q[:, invperm] 216 | g_idx = g_idx[invperm] 217 | 218 | if isinstance(self.layer, transformers.Conv1D): 219 | Q = Q.t() 220 | 221 | self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) 222 | 223 | if scale == []: 224 | scale.append(self.quantizer.scale) 225 | zero.append(self.quantizer.zero) 226 | scale = torch.cat(scale, dim=1) 227 | zero = torch.cat(zero, dim=1) 228 | return scale, zero, g_idx, error 229 | 230 | def free(self): 231 | self.inp1 = None 232 | self.out1 = None 233 | self.H = None 234 | self.Losses = None 235 | self.Trace = None 236 | torch.cuda.empty_cache() 237 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import quant 7 | 8 | from gptq import GPTQ, Observer 9 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions 10 | from texttable import Texttable 11 | 12 | 13 | def get_llama(model): 14 | 15 | def skip(*args, **kwargs): 16 | pass 17 | 18 | torch.nn.init.kaiming_uniform_ = skip 19 | torch.nn.init.uniform_ = skip 20 | torch.nn.init.normal_ = skip 21 | from transformers import LlamaForCausalLM 22 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16) 23 | model.seqlen = 2048 24 | return model 25 | 26 | 27 | @torch.no_grad() 28 | def llama_sequential(model, dataloader, dev): 29 | print('Starting ...') 30 | 31 | use_cache = model.config.use_cache 32 | model.config.use_cache = False 33 | layers = model.model.layers 34 | 35 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 36 | model.model.norm = model.model.norm.to(dev) 37 | layers[0] = layers[0].to(dev) 38 | 39 | dtype = next(iter(model.parameters())).dtype 40 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 41 | cache = {'i': 0, 'attention_mask': None} 42 | 43 | class Catcher(nn.Module): 44 | 45 | def __init__(self, module): 46 | super().__init__() 47 | self.module = module 48 | 49 | def forward(self, inp, **kwargs): 50 | inps[cache['i']] = inp 51 | cache['i'] += 1 52 | cache['attention_mask'] = kwargs['attention_mask'] 53 | cache['position_ids'] = kwargs['position_ids'] 54 | raise ValueError 55 | 56 | layers[0] = Catcher(layers[0]) 57 | for batch in dataloader: 58 | try: 59 | model(batch[0].to(dev)) 60 | except ValueError: 61 | pass 62 | layers[0] = layers[0].module 63 | 64 | layers[0] = layers[0].cpu() 65 | model.model.embed_tokens = 
model.model.embed_tokens.cpu() 66 | model.model.norm = model.model.norm.cpu() 67 | torch.cuda.empty_cache() 68 | 69 | outs = torch.zeros_like(inps) 70 | attention_mask = cache['attention_mask'] 71 | position_ids = cache['position_ids'] 72 | 73 | print('Ready.') 74 | 75 | quantizers = {} 76 | observer = Observer() 77 | for i in range(len(layers)): 78 | 79 | print(f'Quantizing layer {i+1}/{len(layers)}..') 80 | print('+------------------+--------------+------------+-----------+-------+') 81 | print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') 82 | print('+==================+==============+============+===========+=======+') 83 | 84 | layer = layers[i].to(dev) 85 | full = find_layers(layer) 86 | if args.true_sequential: 87 | sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']] 88 | else: 89 | sequential = [list(full.keys())] 90 | 91 | for names in sequential: 92 | subset = {n: full[n] for n in names} 93 | gptq = {} 94 | for name in subset: 95 | gptq[name] = GPTQ(subset[name], observe=args.observe) 96 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 97 | 98 | def add_batch(name): 99 | 100 | def tmp(_, inp, out): 101 | gptq[name].add_batch(inp[0].data, out.data) 102 | 103 | return tmp 104 | 105 | handles = [] 106 | for name in subset: 107 | handles.append(subset[name].register_forward_hook(add_batch(name))) 108 | for j in range(args.nsamples): 109 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 110 | for h in handles: 111 | h.remove() 112 | 113 | for name in subset: 114 | scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name) 115 | quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize) 116 | 117 | if args.observe: 118 | observer.submit(name=name, layerid=i, gptq=gptq[name], error=error) 119 | else: 120 | gptq[name].free() 121 | 122 | for j in range(args.nsamples): 123 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 124 | 125 | layers[i] = layer.cpu() 126 | del layer 127 | del gptq 128 | torch.cuda.empty_cache() 129 | 130 | inps, outs = outs, inps 131 | print('+------------------+--------------+------------+-----------+-------+') 132 | print('\n') 133 | 134 | if args.observe: 135 | observer.print() 136 | conditions = gen_conditions(args.wbits, args.groupsize) 137 | for item in observer.items(): 138 | name = item[0] 139 | layerid = item[1] 140 | gptq = item[2]['gptq'] 141 | error = item[2]['error'] 142 | target = error / 2 143 | 144 | table = Texttable() 145 | table.header(['wbits', 'groupsize', 'error']) 146 | table.set_cols_dtype(['i', 'i', 'f']) 147 | table.add_row([args.wbits, args.groupsize, error]) 148 | 149 | print('Optimizing {} {} ..'.format(name, layerid)) 150 | for wbits, groupsize in conditions: 151 | 152 | if error < target: 153 | # if error dropped 50%, skip 154 | break 155 | 156 | gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) 157 | 158 | scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) 159 | 160 | table.add_row([wbits, groupsize, error]) 161 | quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, 
groupsize) 162 | 163 | print(table.draw()) 164 | print('\n') 165 | gptq.layer.to('cpu') 166 | gptq.free() 167 | 168 | model.config.use_cache = use_cache 169 | 170 | return quantizers 171 | 172 | 173 | @torch.no_grad() 174 | def llama_eval(model, testenc, dev): 175 | print('Evaluating ...') 176 | 177 | testenc = testenc.input_ids 178 | nsamples = testenc.numel() // model.seqlen 179 | 180 | use_cache = model.config.use_cache 181 | model.config.use_cache = False 182 | layers = model.model.layers 183 | 184 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 185 | layers[0] = layers[0].to(dev) 186 | 187 | dtype = next(iter(model.parameters())).dtype 188 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 189 | cache = {'i': 0, 'attention_mask': None} 190 | 191 | class Catcher(nn.Module): 192 | 193 | def __init__(self, module): 194 | super().__init__() 195 | self.module = module 196 | 197 | def forward(self, inp, **kwargs): 198 | inps[cache['i']] = inp 199 | cache['i'] += 1 200 | cache['attention_mask'] = kwargs['attention_mask'] 201 | cache['position_ids'] = kwargs['position_ids'] 202 | raise ValueError 203 | 204 | layers[0] = Catcher(layers[0]) 205 | for i in range(nsamples): 206 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 207 | try: 208 | model(batch) 209 | except ValueError: 210 | pass 211 | layers[0] = layers[0].module 212 | 213 | layers[0] = layers[0].cpu() 214 | model.model.embed_tokens = model.model.embed_tokens.cpu() 215 | torch.cuda.empty_cache() 216 | 217 | outs = torch.zeros_like(inps) 218 | attention_mask = cache['attention_mask'] 219 | position_ids = cache['position_ids'] 220 | 221 | for i in range(len(layers)): 222 | print(i) 223 | layer = layers[i].to(dev) 224 | 225 | if args.nearest: 226 | subset = find_layers(layer) 227 | for name in subset: 228 | quantizer = quant.Quantizer() 229 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 230 | W = subset[name].weight.data 231 | quantizer.find_params(W, weight=True) 232 | subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype) 233 | 234 | for j in range(nsamples): 235 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 236 | layers[i] = layer.cpu() 237 | del layer 238 | torch.cuda.empty_cache() 239 | inps, outs = outs, inps 240 | 241 | if model.model.norm is not None: 242 | model.model.norm = model.model.norm.to(dev) 243 | model.lm_head = model.lm_head.to(dev) 244 | 245 | testenc = testenc.to(dev) 246 | nlls = [] 247 | for i in range(nsamples): 248 | hidden_states = inps[i].unsqueeze(0) 249 | if model.model.norm is not None: 250 | hidden_states = model.model.norm(hidden_states) 251 | lm_logits = model.lm_head(hidden_states) 252 | shift_logits = lm_logits[:, :-1, :].contiguous() 253 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:] 254 | loss_fct = nn.CrossEntropyLoss() 255 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 256 | neg_log_likelihood = loss.float() * model.seqlen 257 | nlls.append(neg_log_likelihood) 258 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 259 | print(ppl.item()) 260 | 261 | model.config.use_cache = use_cache 262 | 263 | 264 | # TODO: perform packing on GPU 265 | def llama_pack(model, quantizers, wbits, groupsize): 266 | layers = find_layers(model) 267 | layers = {n: layers[n] for n in quantizers} 268 | 
quant.make_quant_linear(model, quantizers, wbits, groupsize) 269 | qlayers = find_layers(model, [quant.QuantLinear]) 270 | print('Packing ...') 271 | for name in qlayers: 272 | print(name) 273 | quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] 274 | qlayers[name].pack(layers[name], scale, zero, g_idx) 275 | print('Done.') 276 | return model 277 | 278 | 279 | def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): 280 | from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils 281 | config = LlamaConfig.from_pretrained(model) 282 | 283 | def noop(*args, **kwargs): 284 | pass 285 | 286 | torch.nn.init.kaiming_uniform_ = noop 287 | torch.nn.init.uniform_ = noop 288 | torch.nn.init.normal_ = noop 289 | 290 | torch.set_default_dtype(torch.half) 291 | modeling_utils._init_weights = False 292 | torch.set_default_dtype(torch.half) 293 | model = LlamaForCausalLM(config) 294 | torch.set_default_dtype(torch.float) 295 | if eval: 296 | model = model.eval() 297 | layers = find_layers(model) 298 | for name in ['lm_head']: 299 | if name in layers: 300 | del layers[name] 301 | quant.make_quant_linear(model, layers, wbits, groupsize) 302 | 303 | del layers 304 | 305 | print('Loading model ...') 306 | if checkpoint.endswith('.safetensors'): 307 | from safetensors.torch import load_file as safe_load 308 | model.load_state_dict(safe_load(checkpoint)) 309 | else: 310 | model.load_state_dict(torch.load(checkpoint)) 311 | 312 | if eval: 313 | quant.make_quant_attn(model) 314 | quant.make_quant_norm(model) 315 | if fused_mlp: 316 | quant.make_fused_mlp(model) 317 | 318 | if warmup_autotune: 319 | quant.autotune_warmup_linear(model, transpose=not (eval)) 320 | if eval and fused_mlp: 321 | quant.autotune_warmup_fused(model) 322 | model.seqlen = 2048 323 | print('Done.') 324 | 325 | return model 326 | 327 | 328 | def llama_multigpu(model, gpus, gpu_dist): 329 | model.model.embed_tokens = model.model.embed_tokens.to(gpus[0]) 330 | if hasattr(model.model, 'norm') and model.model.norm: 331 | model.model.norm = model.model.norm.to(gpus[0]) 332 | import copy 333 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0]) 334 | 335 | cache = {'mask': None, 'position_ids': None} 336 | 337 | class MoveModule(nn.Module): 338 | 339 | def __init__(self, module, invalidate_cache): 340 | super().__init__() 341 | self.module = module 342 | self.dev = next(iter(self.module.parameters())).device 343 | self.invalidate_cache=invalidate_cache 344 | 345 | def forward(self, *inp, **kwargs): 346 | inp = list(inp) 347 | if inp[0].device != self.dev: 348 | inp[0] = inp[0].to(self.dev) 349 | 350 | if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache: 351 | cache['mask'] = kwargs['attention_mask'].to(self.dev) 352 | kwargs['attention_mask'] = cache['mask'] 353 | 354 | if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache: 355 | cache['position_ids'] = kwargs['position_ids'].to(self.dev) 356 | kwargs['position_ids'] = cache['position_ids'] 357 | 358 | tmp = self.module(*inp, **kwargs) 359 | return tmp 360 | 361 | layers = model.model.layers 362 | from math import ceil 363 | if not gpu_dist: 364 | pergpu = ceil(len(layers) / len(gpus)) 365 | for i in range(len(layers)): 366 | layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) -1 else gpus[(i-1) // pergpu]), i==0) 367 | else: 368 | assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0." 
369 | assigned_gpus = [0] * (gpu_dist[0]-1) 370 | for i in range(1, len(gpu_dist)): 371 | assigned_gpus = assigned_gpus + [i] * gpu_dist[i] 372 | 373 | remaining_assignments = len(layers)-len(assigned_gpus) - 1 374 | if remaining_assignments > 0: 375 | assigned_gpus = assigned_gpus + [-1] * remaining_assignments 376 | 377 | assigned_gpus = assigned_gpus + [0] 378 | 379 | for i in range(len(layers)): 380 | layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0) 381 | 382 | model.gpus = gpus 383 | 384 | 385 | def benchmark(model, input_ids, check=False): 386 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) 387 | torch.cuda.synchronize() 388 | 389 | cache = {'past': None} 390 | 391 | def clear_past(i): 392 | 393 | def tmp(layer, inp, out): 394 | if cache['past']: 395 | cache['past'][i] = None 396 | 397 | return tmp 398 | 399 | for i, layer in enumerate(model.model.layers): 400 | layer.register_forward_hook(clear_past(i)) 401 | 402 | print('Benchmarking ...') 403 | 404 | if check: 405 | loss = nn.CrossEntropyLoss() 406 | tot = 0. 407 | 408 | def sync(): 409 | if hasattr(model, 'gpus'): 410 | for gpu in model.gpus: 411 | torch.cuda.synchronize(gpu) 412 | else: 413 | torch.cuda.synchronize() 414 | 415 | max_memory = 0 416 | with torch.no_grad(): 417 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 418 | times = [] 419 | for i in range(input_ids.numel()): 420 | tick = time.time() 421 | out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1))) 422 | sync() 423 | times.append(time.time() - tick) 424 | print(i, times[-1]) 425 | if hasattr(model, 'gpus'): 426 | mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024 427 | else: 428 | mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024 429 | max_memory = max(max_memory, mem_allocated) 430 | if check and i != input_ids.numel() - 1: 431 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 432 | cache['past'] = list(out.past_key_values) 433 | del out 434 | sync() 435 | print('Median:', np.median(times)) 436 | if check: 437 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) 438 | print('max memory(MiB):', max_memory) 439 | 440 | 441 | if __name__ == '__main__': 442 | 443 | parser = argparse.ArgumentParser() 444 | 445 | parser.add_argument('model', type=str, help='llama model to load') 446 | parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') 447 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') 448 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.') 449 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') 450 | parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') 451 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') 452 | parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') 453 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') 454 | parser.add_argument('--eval', action='store_true', help='evaluate quantized model.') 455 | 
parser.add_argument('--test-generation', action='store_true', help='test generation.') 456 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') 457 | parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') 458 | parser.add_argument('--load', type=str, default='', help='Load quantized model.') 459 | parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.') 460 | parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') 461 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') 462 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') 463 | parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.') 464 | parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval') 465 | parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.') 466 | parser.add_argument('--observe', 467 | action='store_true', 468 | help='Auto upgrade layer precision to higher precision, for example int2 to int4, groupsize 128 to 64. \ 469 | When this feature enabled, `--save` or `--save_safetensors` would be disable.') 470 | parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for export quantization parameters to toml format. `None` means no export by default.') 471 | 472 | args = parser.parse_args() 473 | 474 | if args.layers_dist: 475 | gpu_dist = [int(x) for x in args.layers_dist.split(':')] 476 | else: 477 | gpu_dist = [] 478 | 479 | if type(args.load) is not str: 480 | args.load = args.load.as_posix() 481 | 482 | if args.load: 483 | model = load_quant(args.model, args.load, args.wbits, args.groupsize) 484 | else: 485 | model = get_llama(args.model) 486 | model.eval() 487 | 488 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) 489 | 490 | if not args.load and args.wbits < 16 and not args.nearest: 491 | tick = time.time() 492 | quantizers = llama_sequential(model, dataloader, DEV) 493 | print(time.time() - tick) 494 | 495 | if args.benchmark: 496 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] 497 | if len(gpus) > 1: 498 | llama_multigpu(model, gpus, gpu_dist) 499 | else: 500 | model = model.to(DEV) 501 | if args.benchmark: 502 | input_ids = next(iter(dataloader))[0][:, :args.benchmark] 503 | benchmark(model, input_ids, check=args.check) 504 | 505 | if args.eval: 506 | datasets = ['wikitext2', 'ptb', 'c4'] 507 | if args.new_eval: 508 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 509 | for dataset in datasets: 510 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) 511 | print(dataset) 512 | llama_eval(model, testloader, DEV) 513 | 514 | if args.test_generation: 515 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] 516 | if len(gpus) > 1: 517 | llama_multigpu(model, gpus, gpu_dist) 518 | else: 519 | model = model.to(DEV) 520 | 521 | from transformers import LlamaTokenizer, 
TextStreamer 522 | tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False) 523 | input_ids = tokenizer(["The capital of New Mexico is"], return_tensors="pt").input_ids.to(gpus[0]) 524 | streamer = TextStreamer(tokenizer) 525 | with torch.no_grad(): 526 | generated_ids = model.generate(input_ids, streamer=streamer) 527 | 528 | 529 | 530 | if args.quant_directory is not None: 531 | export_quant_table(quantizers, args.quant_directory) 532 | 533 | if not args.observe and args.save: 534 | llama_pack(model, quantizers, args.wbits, args.groupsize) 535 | torch.save(model.state_dict(), args.save) 536 | 537 | if not args.observe and args.save_safetensors: 538 | llama_pack(model, quantizers, args.wbits, args.groupsize) 539 | from safetensors.torch import save_file as safe_save 540 | state_dict = model.state_dict() 541 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} 542 | safe_save(state_dict, args.save_safetensors) 543 | -------------------------------------------------------------------------------- /llama_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import quant 6 | 7 | from gptq import GPTQ 8 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders 9 | import transformers 10 | from transformers import AutoTokenizer 11 | 12 | 13 | def get_llama(model): 14 | 15 | def skip(*args, **kwargs): 16 | pass 17 | 18 | torch.nn.init.kaiming_uniform_ = skip 19 | torch.nn.init.uniform_ = skip 20 | torch.nn.init.normal_ = skip 21 | from transformers import LlamaForCausalLM 22 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 23 | model.seqlen = 2048 24 | return model 25 | 26 | 27 | def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): 28 | from transformers import LlamaConfig, LlamaForCausalLM 29 | config = LlamaConfig.from_pretrained(model) 30 | 31 | def noop(*args, **kwargs): 32 | pass 33 | 34 | torch.nn.init.kaiming_uniform_ = noop 35 | torch.nn.init.uniform_ = noop 36 | torch.nn.init.normal_ = noop 37 | 38 | torch.set_default_dtype(torch.half) 39 | transformers.modeling_utils._init_weights = False 40 | torch.set_default_dtype(torch.half) 41 | model = LlamaForCausalLM(config) 42 | torch.set_default_dtype(torch.float) 43 | if eval: 44 | model = model.eval() 45 | layers = find_layers(model) 46 | for name in ['lm_head']: 47 | if name in layers: 48 | del layers[name] 49 | quant.make_quant_linear(model, layers, wbits, groupsize) 50 | 51 | del layers 52 | 53 | print('Loading model ...') 54 | if checkpoint.endswith('.safetensors'): 55 | from safetensors.torch import load_file as safe_load 56 | model.load_state_dict(safe_load(checkpoint), strict=False) 57 | else: 58 | model.load_state_dict(torch.load(checkpoint), strict=False) 59 | 60 | if eval: 61 | quant.make_quant_attn(model) 62 | quant.make_quant_norm(model) 63 | if fused_mlp: 64 | quant.make_fused_mlp(model) 65 | if warmup_autotune: 66 | quant.autotune_warmup_linear(model, transpose=not (eval)) 67 | if eval and fused_mlp: 68 | quant.autotune_warmup_fused(model) 69 | model.seqlen = 2048 70 | print('Done.') 71 | 72 | return model 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | parser = argparse.ArgumentParser() 78 | 79 | parser.add_argument('model', type=str, help='llama model to load') 80 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], 
help='#bits to use for quantization; use 16 for evaluating base model.') 81 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') 82 | parser.add_argument('--load', type=str, default='', help='Load quantized model.') 83 | 84 | parser.add_argument('--text', type=str, help='input text') 85 | 86 | parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.') 87 | 88 | parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.') 89 | 90 | parser.add_argument('--top_p', 91 | type=float, 92 | default=0.95, 93 | help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.') 94 | 95 | parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.') 96 | 97 | parser.add_argument('--device', type=int, default=-1, help='The device used to load the model when using safetensors. Default device is "cpu" or specify, 0,1,2,3,... for GPU device.') 98 | 99 | # fused mlp is sometimes not working with safetensors, no_fused_mlp is used to set fused_mlp to False, default is true 100 | parser.add_argument('--fused_mlp', action='store_true') 101 | parser.add_argument('--no_fused_mlp', dest='fused_mlp', action='store_false') 102 | parser.set_defaults(fused_mlp=True) 103 | 104 | args = parser.parse_args() 105 | 106 | if type(args.load) is not str: 107 | args.load = args.load.as_posix() 108 | 109 | if args.load: 110 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, fused_mlp=args.fused_mlp) 111 | else: 112 | model = get_llama(args.model) 113 | model.eval() 114 | 115 | model.to(DEV) 116 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 117 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV) 118 | 119 | with torch.no_grad(): 120 | generated_ids = model.generate( 121 | input_ids, 122 | do_sample=True, 123 | min_length=args.min_length, 124 | max_length=args.max_length, 125 | top_p=args.top_p, 126 | temperature=args.temperature, 127 | ) 128 | print(tokenizer.decode([el.item() for el in generated_ids[0]])) 129 | -------------------------------------------------------------------------------- /llama_inference_offload.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from gptq import GPTQ 5 | import argparse 6 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders 7 | import quant 8 | 9 | import transformers 10 | from transformers import AutoTokenizer 11 | from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig 12 | from transformers.modeling_outputs import BaseModelOutputWithPast 13 | from typing import List, Optional, Tuple, Union 14 | from accelerate import cpu_offload_with_hook, load_checkpoint_in_model 15 | 16 | 17 | class Offload_LlamaModel(LlamaModel): 18 | 19 | def __init__(self, config: LlamaConfig): 20 | super().__init__(config) 21 | 22 | def cpu_offload(self, preload): 23 | hook = None 24 | for cpu_offloaded_model in self.layers[preload:]: 25 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, DEV, prev_module_hook=hook) 26 | 27 | def forward( 28 | self, 29 | input_ids: torch.LongTensor = None, 30 | attention_mask: Optional[torch.Tensor] = None, 31 | 
position_ids: Optional[torch.LongTensor] = None, 32 | past_key_values: Optional[List[torch.FloatTensor]] = None, 33 | inputs_embeds: Optional[torch.FloatTensor] = None, 34 | use_cache: Optional[bool] = None, 35 | output_attentions: Optional[bool] = None, 36 | output_hidden_states: Optional[bool] = None, 37 | return_dict: Optional[bool] = None, 38 | ) -> Union[Tuple, BaseModelOutputWithPast]: 39 | r""" 40 | Args: 41 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 42 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 43 | provide it. 44 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 45 | [`PreTrainedTokenizer.__call__`] for details. 46 | [What are input IDs?](../glossary#input-ids) 47 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 48 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 49 | - 1 for tokens that are **not masked**, 50 | - 0 for tokens that are **masked**. 51 | [What are attention masks?](../glossary#attention-mask) 52 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 53 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range 54 | `[0, config.n_positions - 1]`. 55 | [What are position IDs?](../glossary#position-ids) 56 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 57 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 58 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 59 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 60 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 61 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 62 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 63 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 64 | use_cache (`bool`, *optional*): 65 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 66 | (see `past_key_values`). 67 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 68 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 69 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 70 | than the model's internal embedding lookup matrix. 71 | output_attentions (`bool`, *optional*): 72 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 73 | returned tensors for more detail. 74 | output_hidden_states (`bool`, *optional*): 75 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 76 | for more detail. 77 | return_dict (`bool`, *optional*): 78 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
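                Returns:
                    [`~transformers.modeling_outputs.BaseModelOutputWithPast`] when `return_dict=True` (the default taken
                    from `config.use_return_dict`), otherwise a plain tuple containing the same fields.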
79 | """ 80 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 81 | output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) 82 | use_cache = use_cache if use_cache is not None else self.config.use_cache 83 | 84 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 85 | 86 | # retrieve input_ids and inputs_embeds 87 | if input_ids is not None and inputs_embeds is not None: 88 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 89 | elif input_ids is not None: 90 | batch_size, seq_length = input_ids.shape 91 | elif inputs_embeds is not None: 92 | batch_size, seq_length, _ = inputs_embeds.shape 93 | else: 94 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 95 | seq_length_with_past = seq_length 96 | past_key_values_length = 0 97 | if past_key_values is not None: 98 | past_key_values_length = past_key_values[0][0].shape[2] 99 | seq_length_with_past = seq_length_with_past + past_key_values_length 100 | 101 | if position_ids is None: 102 | device = input_ids.device if input_ids is not None else inputs_embeds.device 103 | position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device) 104 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 105 | else: 106 | position_ids = position_ids.view(-1, seq_length).long() 107 | 108 | if inputs_embeds is None: 109 | inputs_embeds = self.embed_tokens(input_ids) 110 | 111 | # embed positions 112 | if attention_mask is None: 113 | attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) 114 | attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length) 115 | 116 | hidden_states = inputs_embeds 117 | 118 | if self.gradient_checkpointing and self.training: 119 | if use_cache: 120 | logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") 121 | use_cache = False 122 | 123 | # decoder layers 124 | all_hidden_states = () if output_hidden_states else None 125 | all_self_attns = () if output_attentions else None 126 | next_decoder_cache = () if use_cache else None 127 | 128 | for idx in range(len(self.layers)): 129 | decoder_layer = self.layers[idx] 130 | 131 | if output_hidden_states: 132 | all_hidden_states += (hidden_states, ) 133 | 134 | past_key_value = past_key_values[idx] if past_key_values is not None else None 135 | 136 | if self.gradient_checkpointing and self.training: 137 | 138 | def create_custom_forward(module): 139 | 140 | def custom_forward(*inputs): 141 | # None for past_key_value 142 | return module(*inputs, output_attentions, None) 143 | 144 | return custom_forward 145 | 146 | layer_outputs = torch.utils.checkpoint.checkpoint( 147 | create_custom_forward(decoder_layer), 148 | hidden_states, 149 | attention_mask, 150 | position_ids, 151 | None, 152 | ) 153 | else: 154 | layer_outputs = decoder_layer( 155 | hidden_states, 156 | attention_mask=attention_mask, 157 | position_ids=position_ids, 158 | past_key_value=past_key_value, 159 | output_attentions=output_attentions, 160 | use_cache=use_cache, 161 | ) 162 | 163 | hidden_states = layer_outputs[0] 164 | 165 | if use_cache: 166 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1], ) 167 | 168 | if output_attentions: 169 | all_self_attns += (layer_outputs[1], ) 170 | 171 | hidden_states = self.norm(hidden_states) 172 | 173 | # add hidden states from the last decoder layer 174 | if output_hidden_states: 175 | all_hidden_states += (hidden_states, ) 176 | 177 | next_cache = next_decoder_cache if use_cache else None 178 | if not return_dict: 179 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 180 | return BaseModelOutputWithPast( 181 | last_hidden_state=hidden_states, 182 | past_key_values=next_cache, 183 | hidden_states=all_hidden_states, 184 | attentions=all_self_attns, 185 | ) 186 | 187 | 188 | def load_quant(model, checkpoint, wbits, groupsize, pre_layer, fused_mlp=True, warmup_autotune=True): 189 | transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel 190 | from transformers import LlamaConfig, LlamaForCausalLM 191 | config = LlamaConfig.from_pretrained(model) 192 | 193 | def noop(*args, **kwargs): 194 | pass 195 | 196 | torch.nn.init.kaiming_uniform_ = noop 197 | torch.nn.init.uniform_ = noop 198 | torch.nn.init.normal_ = noop 199 | 200 | torch.set_default_dtype(torch.half) 201 | transformers.modeling_utils._init_weights = False 202 | torch.set_default_dtype(torch.half) 203 | model = LlamaForCausalLM(config) 204 | torch.set_default_dtype(torch.float) 205 | model = model.eval() 206 | layers = find_layers(model) 207 | for name in ['lm_head']: 208 | if name in layers: 209 | del layers[name] 210 | quant.make_quant_linear(model, layers, wbits, groupsize) 211 | 212 | print('Loading model ...') 213 | load_checkpoint_in_model(model, checkpoint, dtype='float16') 214 | model.seqlen = 2048 215 | 216 | if eval: 217 | quant.make_quant_attn(model) 218 | quant.make_quant_norm(model) 219 | if fused_mlp: 220 | quant.make_fused_mlp(model) 221 | 222 | 223 | if warmup_autotune: 224 | quant.autotune_warmup_linear(model) 225 | if fused_mlp: 226 | quant.autotune_warmup_fused(model) 227 | 228 | for i in range(pre_layer): 229 | model.model.layers[i].to(DEV) 230 | model.model.embed_tokens.to(DEV) 231 | model.model.norm.to(DEV) 232 | model.lm_head.to(DEV) 233 | 
model.model.cpu_offload(pre_layer) 234 | print('Done.') 235 | return model 236 | 237 | 238 | if __name__ == '__main__': 239 | parser = argparse.ArgumentParser() 240 | 241 | parser.add_argument('model', type=str, help='llama model to load') 242 | parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8], help='#bits to use for quantization') 243 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') 244 | parser.add_argument('--load', type=str, default='', help='Load quantized model.') 245 | parser.add_argument('--text', type=str, help='input text') 246 | 247 | parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.') 248 | 249 | parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.') 250 | 251 | parser.add_argument('--top_p', 252 | type=float, 253 | default=0.95, 254 | help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.') 255 | 256 | parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.') 257 | 258 | parser.add_argument('--pre_layer', type=int, default=50, help='The number of layers to preload') 259 | 260 | args = parser.parse_args() 261 | 262 | if type(args.load) is not str: 263 | args.load = args.load.as_posix() 264 | 265 | model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer) 266 | 267 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 268 | input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV) 269 | 270 | with torch.no_grad(): 271 | generated_ids = model.generate( 272 | input_ids, 273 | do_sample=True, 274 | min_length=args.min_length, 275 | max_length=args.max_length, 276 | top_p=args.top_p, 277 | temperature=args.temperature, 278 | ) 279 | print(tokenizer.decode([el.item() for el in generated_ids[0]])) 280 | -------------------------------------------------------------------------------- /neox.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import quant 7 | 8 | from gptq import GPTQ, Observer 9 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions 10 | from texttable import Texttable 11 | 12 | 13 | def get_neox(model, seqlen=-1): 14 | 15 | def skip(*args, **kwargs): 16 | pass 17 | 18 | torch.nn.init.kaiming_uniform_ = skip 19 | torch.nn.init.uniform_ = skip 20 | torch.nn.init.normal_ = skip 21 | from transformers import GPTNeoXForCausalLM 22 | model = GPTNeoXForCausalLM.from_pretrained(model, torch_dtype=torch.float16) 23 | model.seqlen = seqlen if seqlen != -1 else model.config.max_position_embeddings 24 | return model 25 | 26 | 27 | @torch.no_grad() 28 | def neox_sequential(model, dataloader, dev): 29 | print('Starting ...') 30 | 31 | use_cache = model.config.use_cache 32 | model.config.use_cache = False 33 | layers = model.gpt_neox.layers 34 | 35 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev) 36 | layers[0] = layers[0].to(dev) 37 | 38 | dtype = next(iter(model.parameters())).dtype 39 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 40 | cache = 
{'i': 0, 'attention_mask': None} 41 | 42 | class Catcher(nn.Module): 43 | 44 | def __init__(self, module): 45 | super().__init__() 46 | self.module = module 47 | 48 | def forward(self, inp, **kwargs): 49 | inps[cache['i']] = inp 50 | cache['i'] += 1 51 | cache['attention_mask'] = kwargs['attention_mask'] 52 | cache['position_ids'] = kwargs['position_ids'] 53 | raise ValueError 54 | 55 | layers[0] = Catcher(layers[0]) 56 | for batch in dataloader: 57 | try: 58 | model(batch[0].to(dev)) 59 | except ValueError: 60 | pass 61 | layers[0] = layers[0].module 62 | 63 | layers[0] = layers[0].cpu() 64 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu() 65 | torch.cuda.empty_cache() 66 | 67 | outs = torch.zeros_like(inps) 68 | attention_mask = cache['attention_mask'] 69 | position_ids = cache['position_ids'] 70 | 71 | print('Ready.') 72 | 73 | quantizers = {} 74 | observer = Observer() 75 | for i in range(len(layers)): 76 | 77 | print(f'Quantizing layer {i+1}/{len(layers)}..') 78 | print('+------------------+--------------+------------+-----------+-------+') 79 | print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') 80 | print('+==================+==============+============+===========+=======+') 81 | 82 | layer = layers[i].to(dev) 83 | full = find_layers(layer) 84 | sequential = [list(full.keys())] 85 | 86 | for names in sequential: 87 | subset = {n: full[n] for n in names} 88 | gptq = {} 89 | for name in subset: 90 | gptq[name] = GPTQ(subset[name], observe=False) 91 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 92 | 93 | def add_batch(name): 94 | 95 | def tmp(_, inp, out): 96 | gptq[name].add_batch(inp[0].data, out.data) 97 | 98 | return tmp 99 | 100 | handles = [] 101 | for name in subset: 102 | handles.append(subset[name].register_forward_hook(add_batch(name))) 103 | for j in range(args.nsamples): 104 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 105 | for h in handles: 106 | h.remove() 107 | 108 | for name in subset: 109 | scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name) 110 | quantizers['gpt_neox.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize) 111 | gptq[name].free() 112 | 113 | for j in range(args.nsamples): 114 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 115 | 116 | layers[i] = layer.cpu() 117 | del layer 118 | del gptq 119 | torch.cuda.empty_cache() 120 | 121 | inps, outs = outs, inps 122 | print('+------------------+--------------+------------+-----------+-------+') 123 | print('\n') 124 | 125 | model.config.use_cache = use_cache 126 | 127 | return quantizers 128 | 129 | 130 | @torch.no_grad() 131 | def neox_eval(model, testenc, dev): 132 | print('Evaluating ...') 133 | 134 | testenc = testenc.input_ids 135 | nsamples = testenc.numel() // model.seqlen 136 | 137 | use_cache = model.config.use_cache 138 | model.config.use_cache = False 139 | layers = model.gpt_neox.layers 140 | 141 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev) 142 | layers[0] = layers[0].to(dev) 143 | 144 | dtype = next(iter(model.parameters())).dtype 145 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 146 | cache = {'i': 0, 'attention_mask': None} 147 | 148 | class Catcher(nn.Module): 149 | 150 | def __init__(self, module): 151 | 
super().__init__() 152 | self.module = module 153 | 154 | def forward(self, inp, **kwargs): 155 | inps[cache['i']] = inp 156 | cache['i'] += 1 157 | cache['attention_mask'] = kwargs['attention_mask'] 158 | cache['position_ids'] = kwargs['position_ids'] 159 | raise ValueError 160 | 161 | layers[0] = Catcher(layers[0]) 162 | for i in range(nsamples): 163 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 164 | try: 165 | model(batch) 166 | except ValueError: 167 | pass 168 | layers[0] = layers[0].module 169 | 170 | layers[0] = layers[0].cpu() 171 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu() 172 | torch.cuda.empty_cache() 173 | 174 | outs = torch.zeros_like(inps) 175 | attention_mask = cache['attention_mask'] 176 | position_ids = cache['position_ids'] 177 | 178 | for i in range(len(layers)): 179 | print(i) 180 | layer = layers[i].to(dev) 181 | 182 | if args.nearest: 183 | subset = find_layers(layer) 184 | for name in subset: 185 | quantizer = quant.Quantizer() 186 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 187 | W = subset[name].weight.data 188 | quantizer.find_params(W, weight=True) 189 | subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype) 190 | 191 | for j in range(nsamples): 192 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 193 | layers[i] = layer.cpu() 194 | del layer 195 | torch.cuda.empty_cache() 196 | inps, outs = outs, inps 197 | 198 | model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(dev) 199 | model.embed_out = model.embed_out.to(dev) 200 | 201 | testenc = testenc.to(dev) 202 | nlls = [] 203 | for i in range(nsamples): 204 | hidden_states = inps[i].unsqueeze(0) 205 | hidden_states = model.gpt_neox.final_layer_norm(hidden_states) 206 | lm_logits = model.embed_out(hidden_states) 207 | shift_logits = lm_logits[:, :-1, :].contiguous() 208 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:] 209 | loss_fct = nn.CrossEntropyLoss() 210 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 211 | neg_log_likelihood = loss.float() * model.seqlen 212 | nlls.append(neg_log_likelihood) 213 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 214 | print(ppl.item()) 215 | 216 | model.config.use_cache = use_cache 217 | 218 | 219 | # TODO: perform packing on GPU 220 | def neox_pack(model, quantizers, wbits, groupsize): 221 | layers = find_layers(model) 222 | layers = {n: layers[n] for n in quantizers} 223 | quant.make_quant_linear(model, quantizers, wbits, groupsize) 224 | qlayers = find_layers(model, [quant.QuantLinear]) 225 | print('Packing ...') 226 | for name in qlayers: 227 | print(name) 228 | quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] 229 | qlayers[name].pack(layers[name], scale, zero, g_idx) 230 | print('Done.') 231 | return model 232 | 233 | 234 | def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True): 235 | from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, modeling_utils 236 | config = GPTNeoXConfig.from_pretrained(model) 237 | 238 | def noop(*args, **kwargs): 239 | pass 240 | 241 | torch.nn.init.kaiming_uniform_ = noop 242 | torch.nn.init.uniform_ = noop 243 | torch.nn.init.normal_ = noop 244 | 245 | torch.set_default_dtype(torch.half) 246 | modeling_utils._init_weights = False 247 | torch.set_default_dtype(torch.half) 248 | model = GPTNeoXForCausalLM(config) 249 | 
torch.set_default_dtype(torch.float) 250 | if eval: 251 | model = model.eval() 252 | layers = find_layers(model) 253 | for name in ['embed_in','embed_out']: 254 | if name in layers: 255 | del layers[name] 256 | quant.make_quant_linear(model, layers, wbits, groupsize) 257 | 258 | del layers 259 | 260 | print('Loading model ...') 261 | if checkpoint.endswith('.safetensors'): 262 | from safetensors.torch import load_file as safe_load 263 | model.load_state_dict(safe_load(checkpoint)) 264 | else: 265 | model.load_state_dict(torch.load(checkpoint)) 266 | 267 | if warmup_autotune: 268 | quant.autotune_warmup_linear(model, transpose=not (eval)) 269 | 270 | model.seqlen = model.config.max_position_embeddings 271 | print('Done.') 272 | 273 | return model 274 | 275 | 276 | def neox_multigpu(model, gpus): 277 | model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(gpus[0]) 278 | model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(gpus[-1]) 279 | import copy 280 | model.embed_out = copy.deepcopy(model.embed_out).to(gpus[-1]) 281 | 282 | cache = {'mask': None} 283 | 284 | class MoveModule(nn.Module): 285 | 286 | def __init__(self, module): 287 | super().__init__() 288 | self.module = module 289 | self.dev = next(iter(self.module.parameters())).device 290 | 291 | def forward(self, *inp, **kwargs): 292 | inp = list(inp) 293 | if inp[0].device != self.dev: 294 | inp[0] = inp[0].to(self.dev) 295 | if cache['mask'] is None or cache['mask'].device != self.dev: 296 | cache['mask'] = kwargs['attention_mask'].to(self.dev) 297 | kwargs['attention_mask'] = cache['mask'] 298 | tmp = self.module(*inp, **kwargs) 299 | return tmp 300 | 301 | layers = model.gpt_neox.layers 302 | pergpu = math.ceil(len(layers) / len(gpus)) 303 | for i in range(len(layers)): 304 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) 305 | 306 | model.gpus = gpus 307 | 308 | 309 | def benchmark(model, input_ids, check=False): 310 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) 311 | torch.cuda.synchronize() 312 | 313 | cache = {'past': None} 314 | 315 | def clear_past(i): 316 | 317 | def tmp(layer, inp, out): 318 | if cache['past']: 319 | cache['past'][i] = None 320 | 321 | return tmp 322 | 323 | for i, layer in enumerate(model.gpt_neox.layers): 324 | layer.register_forward_hook(clear_past(i)) 325 | 326 | print('Benchmarking ...') 327 | 328 | if check: 329 | loss = nn.CrossEntropyLoss() 330 | tot = 0. 
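    # The loop below feeds the prompt one token at a time while reusing past_key_values, records the
    # per-token latency and peak CUDA memory, and (when check=True) accumulates the cross-entropy so a
    # perplexity over the benchmarked tokens can be reported at the end.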
331 | 332 | def sync(): 333 | if hasattr(model, 'gpus'): 334 | for gpu in model.gpus: 335 | torch.cuda.synchronize(gpu) 336 | else: 337 | torch.cuda.synchronize() 338 | 339 | max_memory = 0 340 | with torch.no_grad(): 341 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 342 | times = [] 343 | for i in range(input_ids.numel()): 344 | tick = time.time() 345 | out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1))) 346 | sync() 347 | times.append(time.time() - tick) 348 | print(i, times[-1]) 349 | max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024) 350 | if check and i != input_ids.numel() - 1: 351 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 352 | cache['past'] = list(out.past_key_values) 353 | del out 354 | sync() 355 | print('Median:', np.median(times)) 356 | if check: 357 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) 358 | print('max memory(MiB):', max_memory) 359 | 360 | 361 | if __name__ == '__main__': 362 | 363 | parser = argparse.ArgumentParser() 364 | 365 | parser.add_argument('model', type=str, help='llama model to load') 366 | parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') 367 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') 368 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.') 369 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') 370 | parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') 371 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='bits to use for quantization; use 16 for evaluating base model.') 372 | parser.add_argument('--seqlen', type=int, default=-1, help='seqlen to use for quantization; default uses full seqlen') 373 | parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') 374 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') 375 | parser.add_argument('--eval', action='store_true', help='evaluate quantized model.') 376 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') 377 | parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') 378 | parser.add_argument('--load', type=str, default='', help='Load quantized model.') 379 | parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.') 380 | parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') 381 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') 382 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') 383 | parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval') 384 | args = parser.parse_args() 385 | 386 | if type(args.load) is not str: 387 | args.load = args.load.as_posix() 388 | 389 | if args.load: 390 | model = load_quant(args.model, args.load, args.wbits, args.groupsize) 391 | else: 392 | model = 
get_neox(args.model) 393 | model.eval() 394 | 395 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) 396 | 397 | if not args.load and args.wbits < 16 and not args.nearest: 398 | tick = time.time() 399 | quantizers = neox_sequential(model, dataloader, DEV) 400 | print(time.time() - tick) 401 | 402 | if args.benchmark: 403 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] 404 | if len(gpus) > 1: 405 | neox_multigpu(model, gpus) 406 | else: 407 | model = model.to(DEV) 408 | if args.benchmark: 409 | input_ids = next(iter(dataloader))[0][:, :args.benchmark] 410 | benchmark(model, input_ids, check=args.check) 411 | 412 | if args.eval: 413 | datasets = ['wikitext2', 'ptb', 'c4'] 414 | if args.new_eval: 415 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 416 | for dataset in datasets: 417 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) 418 | print(dataset) 419 | neox_eval(model, testloader, DEV) 420 | 421 | if args.save: 422 | neox_pack(model, quantizers, args.wbits, args.groupsize) 423 | torch.save(model.state_dict(), args.save) 424 | 425 | if args.save_safetensors: 426 | neox_pack(model, quantizers, args.wbits, args.groupsize) 427 | from safetensors.torch import save_file as safe_save 428 | state_dict = model.state_dict() 429 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} 430 | safe_save(state_dict, args.save_safetensors) 431 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import argparse 6 | 7 | import transformers 8 | from gptq import GPTQ 9 | from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders 10 | import quant 11 | 12 | 13 | def get_opt(model): 14 | import torch 15 | 16 | def skip(*args, **kwargs): 17 | pass 18 | 19 | torch.nn.init.kaiming_uniform_ = skip 20 | torch.nn.init.uniform_ = skip 21 | torch.nn.init.normal_ = skip 22 | from transformers import OPTForCausalLM 23 | model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') 24 | model.seqlen = model.config.max_position_embeddings 25 | return model 26 | 27 | 28 | @torch.no_grad() 29 | def opt_sequential(model, dataloader, dev): 30 | print('Starting ...') 31 | 32 | use_cache = model.config.use_cache 33 | model.config.use_cache = False 34 | layers = model.model.decoder.layers 35 | 36 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 37 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 38 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 39 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 40 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 41 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 42 | layers[0] = layers[0].to(dev) 43 | 44 | dtype = next(iter(model.parameters())).dtype 45 | inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 46 | cache = {'i': 0, 'attention_mask': None} 47 | 48 | class Catcher(nn.Module): 49 | 50 | def __init__(self, module): 51 | super().__init__() 52 | self.module = module 53 | 54 | def forward(self, inp, **kwargs): 55 | 
inps[cache['i']] = inp 56 | cache['i'] += 1 57 | cache['attention_mask'] = kwargs['attention_mask'] 58 | raise ValueError 59 | 60 | layers[0] = Catcher(layers[0]) 61 | for batch in dataloader: 62 | try: 63 | model(batch[0].to(dev)) 64 | except ValueError: 65 | pass 66 | layers[0] = layers[0].module 67 | 68 | layers[0] = layers[0].cpu() 69 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 70 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 71 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 72 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 73 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 74 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 75 | torch.cuda.empty_cache() 76 | 77 | outs = torch.zeros_like(inps) 78 | attention_mask = cache['attention_mask'] 79 | 80 | print('Ready.') 81 | 82 | quantizers = {} 83 | for i in range(len(layers)): 84 | layer = layers[i].to(dev) 85 | 86 | subset = find_layers(layer) 87 | gptq = {} 88 | for name in subset: 89 | gptq[name] = GPTQ(subset[name]) 90 | gptq[name].quantizer = quant.Quantizer() 91 | gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits) 92 | 93 | def add_batch(name): 94 | 95 | def tmp(_, inp, out): 96 | gptq[name].add_batch(inp[0].data, out.data) 97 | 98 | return tmp 99 | 100 | handles = [] 101 | for name in subset: 102 | handles.append(subset[name].register_forward_hook(add_batch(name))) 103 | 104 | for j in range(args.nsamples): 105 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 106 | 107 | for h in handles: 108 | h.remove() 109 | 110 | for name in subset: 111 | print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') 112 | scale, zero, g_idx, _ = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) 113 | quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) 114 | gptq[name].free() 115 | 116 | for j in range(args.nsamples): 117 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 118 | 119 | layers[i] = layer.cpu() 120 | del layer 121 | del gptq 122 | torch.cuda.empty_cache() 123 | 124 | inps, outs = outs, inps 125 | 126 | model.config.use_cache = use_cache 127 | 128 | return quantizers 129 | 130 | 131 | @torch.no_grad() 132 | def opt_eval(model, testenc, dev): 133 | print('Evaluating ...') 134 | 135 | testenc = testenc.input_ids 136 | nsamples = testenc.numel() // model.seqlen 137 | 138 | use_cache = model.config.use_cache 139 | model.config.use_cache = False 140 | layers = model.model.decoder.layers 141 | 142 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 143 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 144 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 145 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 146 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 147 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 148 | layers[0] = layers[0].to(dev) 149 | 150 | dtype = next(iter(model.parameters())).dtype 151 | inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) 152 | cache = {'i': 0, 'attention_mask': None} 153 | 154 | 
class Catcher(nn.Module): 155 | 156 | def __init__(self, module): 157 | super().__init__() 158 | self.module = module 159 | 160 | def forward(self, inp, **kwargs): 161 | inps[cache['i']] = inp 162 | cache['i'] += 1 163 | cache['attention_mask'] = kwargs['attention_mask'] 164 | raise ValueError 165 | 166 | layers[0] = Catcher(layers[0]) 167 | for i in range(nsamples): 168 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 169 | try: 170 | model(batch) 171 | except ValueError: 172 | pass 173 | layers[0] = layers[0].module 174 | 175 | layers[0] = layers[0].cpu() 176 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 177 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 178 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 179 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 180 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 181 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 182 | torch.cuda.empty_cache() 183 | 184 | outs = torch.zeros_like(inps) 185 | attention_mask = cache['attention_mask'] 186 | 187 | for i in range(len(layers)): 188 | print(i) 189 | layer = layers[i].to(dev) 190 | 191 | if args.nearest: 192 | subset = find_layers(layer) 193 | for name in subset: 194 | quantizer = quant.Quantizer() 195 | quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) 196 | W = subset[name].weight.data 197 | quantizer.find_params(W, weight=True) 198 | subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype) 199 | 200 | for j in range(nsamples): 201 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 202 | layers[i] = layer.cpu() 203 | del layer 204 | torch.cuda.empty_cache() 205 | inps, outs = outs, inps 206 | 207 | if model.model.decoder.final_layer_norm is not None: 208 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 209 | if model.model.decoder.project_out is not None: 210 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 211 | model.lm_head = model.lm_head.to(dev) 212 | 213 | testenc = testenc.to(dev) 214 | nlls = [] 215 | for i in range(nsamples): 216 | hidden_states = inps[i].unsqueeze(0) 217 | if model.model.decoder.final_layer_norm is not None: 218 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 219 | if model.model.decoder.project_out is not None: 220 | hidden_states = model.model.decoder.project_out(hidden_states) 221 | lm_logits = model.lm_head(hidden_states) 222 | shift_logits = lm_logits[:, :-1, :].contiguous() 223 | shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:] 224 | loss_fct = nn.CrossEntropyLoss() 225 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 226 | neg_log_likelihood = loss.float() * model.seqlen 227 | nlls.append(neg_log_likelihood) 228 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 229 | print(ppl.item()) 230 | 231 | model.config.use_cache = use_cache 232 | 233 | 234 | # TODO: perform packing on GPU 235 | def opt_pack(model, quantizers, wbits, groupsize): 236 | layers = find_layers(model) 237 | layers = {n: layers[n] for n in quantizers} 238 | quant.make_quant_linear(model, quantizers, wbits, groupsize) 239 | qlayers = find_layers(model, [quant.QuantLinear]) 240 | print('Packing ...') 241 | for name in qlayers: 242 | 
print(name) 243 | quantizers[name], scale, zero, g_idx = quantizers[name] 244 | qlayers[name].pack(layers[name], scale, zero, g_idx) 245 | print('Done.') 246 | return model 247 | 248 | 249 | def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True): 250 | from transformers import OPTConfig, OPTForCausalLM 251 | config = OPTConfig.from_pretrained(model) 252 | 253 | def noop(*args, **kwargs): 254 | pass 255 | 256 | torch.nn.init.kaiming_uniform_ = noop 257 | torch.nn.init.uniform_ = noop 258 | torch.nn.init.normal_ = noop 259 | 260 | torch.set_default_dtype(torch.half) 261 | transformers.modeling_utils._init_weights = False 262 | torch.set_default_dtype(torch.half) 263 | model = OPTForCausalLM(config) 264 | torch.set_default_dtype(torch.float) 265 | model = model.eval() 266 | layers = find_layers(model) 267 | for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']: 268 | if name in layers: 269 | del layers[name] 270 | quant.make_quant_linear(model, layers, wbits, groupsize) 271 | 272 | del layers 273 | 274 | print('Loading model ...') 275 | if checkpoint.endswith('.safetensors'): 276 | from safetensors.torch import load_file as safe_load 277 | model.load_state_dict(safe_load(checkpoint)) 278 | else: 279 | model.load_state_dict(torch.load(checkpoint)) 280 | 281 | if warmup_autotune: 282 | quant.autotune_warmup_linear(model, transpose=not (eval)) 283 | model.seqlen = model.config.max_position_embeddings 284 | print('Done.') 285 | return model 286 | 287 | 288 | def opt_multigpu(model, gpus): 289 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0]) 290 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0]) 291 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 292 | model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0]) 293 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 294 | model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1]) 295 | if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm: 296 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1]) 297 | import copy 298 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) 299 | 300 | cache = {'mask': None} 301 | 302 | class MoveModule(nn.Module): 303 | 304 | def __init__(self, module): 305 | super().__init__() 306 | self.module = module 307 | self.dev = next(iter(self.module.parameters())).device 308 | 309 | def forward(self, *inp, **kwargs): 310 | inp = list(inp) 311 | if inp[0].device != self.dev: 312 | inp[0] = inp[0].to(self.dev) 313 | if cache['mask'] is None or cache['mask'].device != self.dev: 314 | cache['mask'] = kwargs['attention_mask'].to(self.dev) 315 | kwargs['attention_mask'] = cache['mask'] 316 | tmp = self.module(*inp, **kwargs) 317 | return tmp 318 | 319 | layers = model.model.decoder.layers 320 | pergpu = math.ceil(len(layers) / len(gpus)) 321 | for i in range(len(layers)): 322 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) 323 | 324 | model.gpus = gpus 325 | 326 | 327 | def benchmark(model, input_ids, check=False): 328 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) 329 | torch.cuda.synchronize() 330 | 331 | cache = {'past': None} 332 | 333 | def clear_past(i): 334 | 335 | def tmp(layer, inp, out): 336 | if cache['past']: 337 | cache['past'][i] = None 338 | 339 | return tmp 340 
| 341 | for i, layer in enumerate(model.model.decoder.layers): 342 | layer.register_forward_hook(clear_past(i)) 343 | 344 | print('Benchmarking ...') 345 | 346 | if check: 347 | loss = nn.CrossEntropyLoss() 348 | tot = 0. 349 | 350 | def sync(): 351 | if hasattr(model, 'gpus'): 352 | for gpu in model.gpus: 353 | torch.cuda.synchronize(gpu) 354 | else: 355 | torch.cuda.synchronize() 356 | 357 | with torch.no_grad(): 358 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 359 | times = [] 360 | for i in range(input_ids.numel()): 361 | tick = time.time() 362 | out = model(input_ids[:, i].reshape(-1), past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1))) 363 | sync() 364 | times.append(time.time() - tick) 365 | print(i, times[-1]) 366 | if check and i != input_ids.numel() - 1: 367 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 368 | cache['past'] = list(out.past_key_values) 369 | del out 370 | sync() 371 | import numpy as np 372 | print('Median:', np.median(times)) 373 | if check: 374 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) 375 | 376 | 377 | if __name__ == '__main__': 378 | 379 | parser = argparse.ArgumentParser() 380 | 381 | parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.') 382 | parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') 383 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') 384 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.') 385 | parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') 386 | parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') 387 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') 388 | parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') 389 | parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') 390 | parser.add_argument('--eval', action='store_true', help='evaluate quantized model.') 391 | parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') 392 | parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') 393 | parser.add_argument('--load', type=str, default='', help='Load quantized model.') 394 | parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.') 395 | parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') 396 | parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') 397 | parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') 398 | parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval') 399 | 400 | args = parser.parse_args() 401 | 402 | if type(args.load) is not str: 403 | args.load = args.load.as_posix() 404 | 405 | if args.load: 406 | model = load_quant(args.model, args.load, args.wbits, args.groupsize) 407 | else: 408 | 
model = get_opt(args.model) 409 | model.eval() 410 | 411 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) 412 | 413 | if not args.load and args.wbits < 16 and not args.nearest: 414 | tick = time.time() 415 | quantizers = opt_sequential(model, dataloader, DEV) 416 | print(time.time() - tick) 417 | 418 | if args.benchmark: 419 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] 420 | if len(gpus) > 1: 421 | opt_multigpu(model, gpus) 422 | else: 423 | model = model.to(DEV) 424 | if args.benchmark: 425 | input_ids = next(iter(dataloader))[0][:, :args.benchmark] 426 | benchmark(model, input_ids, check=args.check) 427 | 428 | if args.eval: 429 | datasets = ['wikitext2', 'ptb', 'c4'] 430 | if args.new_eval: 431 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 432 | for dataset in datasets: 433 | dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen) 434 | print(dataset) 435 | opt_eval(model, testloader, DEV) 436 | 437 | if args.save: 438 | opt_pack(model, quantizers, args.wbits, args.groupsize) 439 | torch.save(model.state_dict(), args.save) 440 | 441 | if args.save_safetensors: 442 | opt_pack(model, quantizers, args.wbits, args.groupsize) 443 | from safetensors.torch import save_file as safe_save 444 | state_dict = model.state_dict() 445 | state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} 446 | safe_save(state_dict, args.save_safetensors) 447 | -------------------------------------------------------------------------------- /quant/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantizer import Quantizer 2 | from .fused_attn import QuantLlamaAttention, make_quant_attn 3 | from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused 4 | from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear 5 | from .triton_norm import TritonLlamaRMSNorm, make_quant_norm 6 | -------------------------------------------------------------------------------- /quant/custom_autotune.py: -------------------------------------------------------------------------------- 1 | #https://github.com/fpgaminer/GPTQ-triton 2 | """ 3 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 4 | """ 5 | 6 | import builtins 7 | import math 8 | import time 9 | from typing import Dict 10 | 11 | import triton 12 | 13 | 14 | class Autotuner(triton.KernelInterface): 15 | 16 | def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): 17 | ''' 18 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 19 | 'perf_model': performance model used to predicate running time with different configs, returns running time 20 | 'top_k': number of configs to bench 21 | 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 
22 | 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results 23 | ''' 24 | if not configs: 25 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)] 26 | else: 27 | self.configs = configs 28 | self.key_idx = [arg_names.index(k) for k in key] 29 | self.nearest_power_of_two = nearest_power_of_two 30 | self.cache = {} 31 | # hook to reset all required tensor to zeros before relaunching a kernel 32 | self.hook = lambda args: 0 33 | if reset_to_zero is not None: 34 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero] 35 | 36 | def _hook(args): 37 | for i in self.reset_idx: 38 | args[i].zero_() 39 | 40 | self.hook = _hook 41 | self.arg_names = arg_names 42 | # prune configs 43 | if prune_configs_by: 44 | perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] 45 | if 'early_config_prune' in prune_configs_by: 46 | early_config_prune = prune_configs_by['early_config_prune'] 47 | else: 48 | perf_model, top_k, early_config_prune = None, None, None 49 | self.perf_model, self.configs_top_k = perf_model, top_k 50 | self.early_config_prune = early_config_prune 51 | self.fn = fn 52 | 53 | def _bench(self, *args, config, **meta): 54 | # check for conflicts, i.e. meta-parameters both provided 55 | # as kwargs and by the autotuner 56 | conflicts = meta.keys() & config.kwargs.keys() 57 | if conflicts: 58 | raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 59 | " Make sure that you don't re-define auto-tuned symbols.") 60 | # augment meta-parameters with tunable ones 61 | current = dict(meta, **config.kwargs) 62 | 63 | def kernel_call(): 64 | if config.pre_hook: 65 | config.pre_hook(self.nargs) 66 | self.hook(args) 67 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) 68 | 69 | try: 70 | # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses 71 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default 72 | return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) 73 | except triton.compiler.OutOfResources: 74 | return (float('inf'), float('inf'), float('inf')) 75 | 76 | def run(self, *args, **kwargs): 77 | self.nargs = dict(zip(self.arg_names, args)) 78 | if len(self.configs) > 1: 79 | key = tuple(args[i] for i in self.key_idx) 80 | 81 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two 82 | # In my testing this gives decent results, and greatly reduces the amount of tuning required 83 | if self.nearest_power_of_two: 84 | key = tuple([2**int(math.log2(x) + 0.5) for x in key]) 85 | 86 | if key not in self.cache: 87 | # prune configs 88 | pruned_configs = self.prune_configs(kwargs) 89 | bench_start = time.time() 90 | timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} 91 | bench_end = time.time() 92 | self.bench_time = bench_end - bench_start 93 | self.cache[key] = builtins.min(timings, key=timings.get) 94 | self.hook(args) 95 | self.configs_timings = timings 96 | config = self.cache[key] 97 | else: 98 | config = self.configs[0] 99 | self.best_config = config 100 | if config.pre_hook is not None: 101 | config.pre_hook(self.nargs) 102 | return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) 103 | 104 | def prune_configs(self, kwargs): 105 | pruned_configs = self.configs 106 | if self.early_config_prune: 107 | 
pruned_configs = self.early_config_prune(self.configs, self.nargs) 108 | if self.perf_model: 109 | top_k = self.configs_top_k 110 | if isinstance(top_k, float) and top_k <= 1.0: 111 | top_k = int(len(self.configs) * top_k) 112 | if len(pruned_configs) > top_k: 113 | est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} 114 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] 115 | return pruned_configs 116 | 117 | def warmup(self, *args, **kwargs): 118 | self.nargs = dict(zip(self.arg_names, args)) 119 | for config in self.prune_configs(kwargs): 120 | self.fn.warmup( 121 | *args, 122 | num_warps=config.num_warps, 123 | num_stages=config.num_stages, 124 | **kwargs, 125 | **config.kwargs, 126 | ) 127 | self.nargs = None 128 | 129 | 130 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): 131 | """ 132 | Decorator for auto-tuning a :code:`triton.jit`'d function. 133 | .. highlight:: python 134 | .. code-block:: python 135 | @triton.autotune(configs=[ 136 | triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), 137 | triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), 138 | ], 139 | key=['x_size'] # the two above configs will be evaluated anytime 140 | # the value of x_size changes 141 | ) 142 | @triton.jit 143 | def kernel(x_ptr, x_size, **META): 144 | BLOCK_SIZE = META['BLOCK_SIZE'] 145 | :note: When all the configurations are evaluated, the kernel will run multiple time. 146 | This means that whatever value the kernel updates will be updated multiple times. 147 | To avoid this undesired behavior, you can use the `reset_to_zero` argument, which 148 | reset the value of the provided tensor to `zero` before running any configuration. 149 | :param configs: a list of :code:`triton.Config` objects 150 | :type configs: list[triton.Config] 151 | :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 152 | :type key: list[str] 153 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 154 | 'perf_model': performance model used to predicate running time with different configs, returns running time 155 | 'top_k': number of configs to bench 156 | 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. 157 | :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 158 | :type reset_to_zero: list[str] 159 | """ 160 | 161 | def decorator(fn): 162 | return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) 163 | 164 | return decorator 165 | 166 | 167 | def matmul248_kernel_config_pruner(configs, nargs): 168 | """ 169 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
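    Configurations that collapse to the same (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M,
    num_stages, num_warps) tuple after clamping are emitted only once.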
170 | """ 171 | m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) 172 | n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) 173 | k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) 174 | 175 | used = set() 176 | for config in configs: 177 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) 178 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) 179 | block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) 180 | group_size_m = config.kwargs['GROUP_SIZE_M'] 181 | 182 | if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: 183 | continue 184 | 185 | used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) 186 | yield triton.Config({ 187 | 'BLOCK_SIZE_M': block_size_m, 188 | 'BLOCK_SIZE_N': block_size_n, 189 | 'BLOCK_SIZE_K': block_size_k, 190 | 'GROUP_SIZE_M': group_size_m 191 | }, 192 | num_stages=config.num_stages, 193 | num_warps=config.num_warps) 194 | -------------------------------------------------------------------------------- /quant/fused_attn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from transformers.models.llama.modeling_llama import LlamaAttention 3 | from .quant_linear import * 4 | import triton 5 | import triton.language as tl 6 | 7 | 8 | @triton.jit 9 | def rotate_half_kernel( 10 | qk_seq_ptr, 11 | position_ids_ptr, 12 | qk_seq_stride, 13 | position_ids_batch_stride, 14 | seq_len, 15 | HEAD_DIM: tl.constexpr, 16 | BLOCK_HEIGHT: tl.constexpr, 17 | BLOCK_WIDTH: tl.constexpr, 18 | INV_BASE: tl.constexpr 19 | ): 20 | # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension. 21 | # position ids: (bsz, seq_len) -- must be contiguous in the last dimension. 22 | 23 | HALF_HEAD: tl.constexpr = HEAD_DIM // 2 24 | STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH 25 | 26 | batch_seq = tl.program_id(axis=0) 27 | row_blk_x_col_blk = tl.program_id(axis=1) 28 | 29 | row_blk = row_blk_x_col_blk // STEPS_PER_ROW 30 | row = row_blk * BLOCK_HEIGHT 31 | if BLOCK_WIDTH < HALF_HEAD: 32 | col_blk = row_blk_x_col_blk % STEPS_PER_ROW 33 | col = col_blk * BLOCK_WIDTH 34 | else: 35 | col: tl.constexpr = 0 36 | 37 | # A block will never cross a sequence boundary, which simplifies things a lot. 38 | batch = batch_seq // seq_len 39 | seq = batch_seq % seq_len 40 | position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq) 41 | # As sometimes happens, just calculating this on the fly is faster than loading it from memory. 42 | # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate. 
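    # For column index j this evaluates the RoPE angle theta_j = position_id * BASE**(-2*j / HEAD_DIM),
    # since INV_BASE is passed in as -2 * ln(BASE) / HEAD_DIM (see triton_rotate_half_ below).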
43 | freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id 44 | cos = tl.cos(freq).to(tl.float32) 45 | sin = tl.sin(freq).to(tl.float32) 46 | 47 | col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH) 48 | embed_offsets = (row * HEAD_DIM + col) + col_offsets 49 | x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets 50 | 51 | for k in range(0, BLOCK_HEIGHT): 52 | x = tl.load(x_ptrs).to(tl.float32) 53 | y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32) 54 | out_x = x * cos - y * sin 55 | tl.store(x_ptrs, out_x) 56 | out_y = x * sin + y * cos 57 | tl.store(x_ptrs + HALF_HEAD, out_y) 58 | x_ptrs += HEAD_DIM 59 | 60 | 61 | def triton_rotate_half_(qk, position_ids, config=None): 62 | with torch.cuda.device(qk.device): 63 | batch_size, seq_len, qandk, num_heads, head_dim = qk.shape 64 | 65 | # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective. 66 | config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1} 67 | config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads) 68 | 69 | assert qk.stride(3) == head_dim 70 | assert qk.stride(4) == 1 71 | assert position_ids.shape == (batch_size, seq_len) 72 | assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension' 73 | assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}' 74 | assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}' 75 | 76 | qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim) 77 | grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH'])) 78 | 79 | # Must be the same as the theta of the frequencies used to train the model. 
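        # 10000.0 is the RoPE base used to train the original LLaMA checkpoints; a model trained with a
        # different base would need this constant (and hence INV_BASE) adjusted to match.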
80 | BASE = 10000.0 81 | 82 | rotate_half_kernel[grid]( 83 | qk_by_seq, 84 | position_ids, 85 | qk_by_seq.stride(0), 86 | position_ids.stride(0), 87 | seq_len, 88 | HEAD_DIM=head_dim, 89 | BLOCK_HEIGHT=config['BLOCK_HEIGHT'], 90 | BLOCK_WIDTH=config['BLOCK_WIDTH'], 91 | INV_BASE=-2.0 * math.log(BASE) / head_dim, 92 | num_warps=config['num_warps'] 93 | ) 94 | 95 | 96 | class QuantLlamaAttention(nn.Module): 97 | """Multi-headed attention from 'Attention Is All You Need' paper""" 98 | 99 | def __init__( 100 | self, 101 | hidden_size, 102 | num_heads, 103 | qkv_proj, 104 | o_proj 105 | ): 106 | super().__init__() 107 | self.hidden_size = hidden_size 108 | self.num_heads = num_heads 109 | self.head_dim = hidden_size // num_heads 110 | 111 | if (self.head_dim * num_heads) != self.hidden_size: 112 | raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 113 | f" and `num_heads`: {num_heads}).") 114 | self.qkv_proj = qkv_proj 115 | self.o_proj = o_proj 116 | 117 | def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False): 118 | """Input shape: Batch x Time x Channel""" 119 | 120 | bsz, q_len, _ = hidden_states.size() 121 | 122 | qkv_states = self.qkv_proj(hidden_states) 123 | qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim) 124 | 125 | # This updates the query and key states in-place, saving VRAM. 126 | triton_rotate_half_(qkv_states[:, :, :2], position_ids) 127 | 128 | query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2) 129 | del qkv_states 130 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 131 | key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 132 | value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 133 | 134 | is_causal = past_key_value is None 135 | 136 | kv_seq_len = q_len 137 | if past_key_value is not None: 138 | kv_seq_len += past_key_value[0].shape[-2] 139 | 140 | if past_key_value is not None: 141 | # reuse k, v, self_attention 142 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 143 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 144 | 145 | if use_cache: 146 | # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor 147 | # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. 148 | key_states = key_states.contiguous() 149 | value_states = value_states.contiguous() 150 | query_states = query_states.contiguous() 151 | 152 | past_key_value = (key_states, value_states) if use_cache else None 153 | 154 | with torch.backends.cuda.sdp_kernel(enable_math=False): 155 | attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) 156 | del query_states, key_states, value_states 157 | 158 | attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) 159 | attn_output = self.o_proj(attn_output) 160 | 161 | return attn_output, None, past_key_value 162 | 163 | 164 | def make_quant_attn(model): 165 | """ 166 | Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. 
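    The rotary embedding sub-module (m.rotary_emb) of each attention layer is dropped in
    the process; QuantLlamaAttention applies the rotary embedding itself via triton_rotate_half_.

    Minimal usage sketch (illustrative; `load_quant` is a hypothetical stand-in for
    whatever routine already produced a model whose projections are QuantLinear):

        model = load_quant(...)   # hypothetical loader, quantized weights already in place
        make_quant_attn(model)    # each LlamaAttention now uses one fused qkv QuantLinear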
167 | """ 168 | 169 | for name, m in model.named_modules(): 170 | if not isinstance(m, LlamaAttention): 171 | continue 172 | 173 | q_proj = m.q_proj 174 | k_proj = m.k_proj 175 | v_proj = m.v_proj 176 | 177 | qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) 178 | qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) 179 | scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) 180 | g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) 181 | bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None 182 | 183 | qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False) 184 | qkv_layer.qweight = qweights 185 | qkv_layer.qzeros = qzeros 186 | qkv_layer.scales = scales 187 | qkv_layer.g_idx = g_idx 188 | qkv_layer.bias = bias 189 | # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch. 190 | 191 | attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj) 192 | 193 | if '.' in name: 194 | parent_name = name.rsplit('.', 1)[0] 195 | child_name = name[len(parent_name) + 1:] 196 | parent = model.get_submodule(parent_name) 197 | else: 198 | parent_name = '' 199 | parent = model 200 | child_name = name 201 | 202 | #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") 203 | 204 | setattr(parent, child_name, attn) 205 | -------------------------------------------------------------------------------- /quant/fused_mlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.cuda.amp import custom_bwd, custom_fwd 5 | from transformers.models.llama.modeling_llama import LlamaMLP 6 | 7 | try: 8 | import triton 9 | import triton.language as tl 10 | from . 
import custom_autotune 11 | 12 | # code based https://github.com/fpgaminer/GPTQ-triton 13 | @custom_autotune.autotune( 14 | configs=[ 15 | triton.Config({ 16 | 'BLOCK_SIZE_M': 256, 17 | 'BLOCK_SIZE_N': 64, 18 | 'BLOCK_SIZE_K': 32, 19 | 'GROUP_SIZE_M': 8 20 | }, num_stages=4, num_warps=4), 21 | triton.Config({ 22 | 'BLOCK_SIZE_M': 64, 23 | 'BLOCK_SIZE_N': 256, 24 | 'BLOCK_SIZE_K': 32, 25 | 'GROUP_SIZE_M': 8 26 | }, num_stages=4, num_warps=4), 27 | triton.Config({ 28 | 'BLOCK_SIZE_M': 128, 29 | 'BLOCK_SIZE_N': 128, 30 | 'BLOCK_SIZE_K': 32, 31 | 'GROUP_SIZE_M': 8 32 | }, num_stages=4, num_warps=4), 33 | triton.Config({ 34 | 'BLOCK_SIZE_M': 128, 35 | 'BLOCK_SIZE_N': 64, 36 | 'BLOCK_SIZE_K': 32, 37 | 'GROUP_SIZE_M': 8 38 | }, num_stages=4, num_warps=4), 39 | triton.Config({ 40 | 'BLOCK_SIZE_M': 64, 41 | 'BLOCK_SIZE_N': 128, 42 | 'BLOCK_SIZE_K': 32, 43 | 'GROUP_SIZE_M': 8 44 | }, num_stages=4, num_warps=4), 45 | triton.Config({ 46 | 'BLOCK_SIZE_M': 128, 47 | 'BLOCK_SIZE_N': 32, 48 | 'BLOCK_SIZE_K': 32, 49 | 'GROUP_SIZE_M': 8 50 | }, num_stages=4, num_warps=4), # 3090 51 | triton.Config({ 52 | 'BLOCK_SIZE_M': 128, 53 | 'BLOCK_SIZE_N': 16, 54 | 'BLOCK_SIZE_K': 32, 55 | 'GROUP_SIZE_M': 8 56 | }, num_stages=4, num_warps=4), # 3090 57 | triton.Config({ 58 | 'BLOCK_SIZE_M': 32, 59 | 'BLOCK_SIZE_N': 32, 60 | 'BLOCK_SIZE_K': 128, 61 | 'GROUP_SIZE_M': 8 62 | }, num_stages=2, num_warps=4), # 3090 63 | triton.Config({ 64 | 'BLOCK_SIZE_M': 64, 65 | 'BLOCK_SIZE_N': 16, 66 | 'BLOCK_SIZE_K': 64, 67 | 'GROUP_SIZE_M': 8 68 | }, num_stages=4, num_warps=4), # 3090 69 | triton.Config({ 70 | 'BLOCK_SIZE_M': 64, 71 | 'BLOCK_SIZE_N': 32, 72 | 'BLOCK_SIZE_K': 64, 73 | 'GROUP_SIZE_M': 8 74 | }, num_stages=4, num_warps=4), # 3090 75 | ], 76 | key=['M', 'N', 'K'], 77 | nearest_power_of_two=True, 78 | prune_configs_by={ 79 | 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, 80 | 'perf_model': None, 81 | 'top_k': None, 82 | }, 83 | ) 84 | @triton.jit 85 | def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, 86 | stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): 87 | """ 88 | Computes: C = silu(A * B1) * (A * B2) 89 | A is of shape (M, K) float16 90 | B is of shape (K//8, N) int32 91 | C is of shape (M, N) float16 92 | scales is of shape (1, N) float16 93 | zeros is of shape (1, N//8) int32 94 | """ 95 | infearure_per_bits = 32 // bits 96 | 97 | pid = tl.program_id(axis=0) 98 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 99 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 100 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 101 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 102 | group_id = pid // num_pid_in_group 103 | first_pid_m = group_id * GROUP_SIZE_M 104 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 105 | pid_m = first_pid_m + (pid % group_size_m) 106 | pid_n = (pid % num_pid_in_group) // group_size_m 107 | 108 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 109 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 110 | offs_k = tl.arange(0, BLOCK_SIZE_K) 111 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 112 | a_mask = (offs_am[:, None] < M) 113 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 114 | b1_ptrs = b1_ptr + ((offs_k[:, None] // 
infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) 115 | b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) 116 | g1_ptrs = g1_ptr + offs_k 117 | g2_ptrs = g2_ptr + offs_k 118 | # shifter is used to extract the N bits of each element in the 32-bit word from B 119 | scales1_ptrs = scales1_ptr + offs_bn[None, :] 120 | scales2_ptrs = scales2_ptr + offs_bn[None, :] 121 | zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits) 122 | zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits) 123 | 124 | shifter = (offs_k % infearure_per_bits) * bits 125 | zeros_shifter = (offs_bn % infearure_per_bits) * bits 126 | accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 127 | accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 128 | for k in range(0, num_pid_k): 129 | g1_idx = tl.load(g1_ptrs) 130 | g2_idx = tl.load(g2_ptrs) 131 | 132 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 133 | scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 134 | scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales) 135 | 136 | zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 137 | zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq 138 | zeros1 = (zeros1 + 1) 139 | 140 | zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 141 | zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq 142 | zeros2 = (zeros2 + 1) 143 | 144 | a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 145 | b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 146 | b2 = tl.load(b2_ptrs) 147 | 148 | # Now we need to unpack b (which is N-bit values) into 32-bit values 149 | b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values 150 | b1 = (b1 - zeros1) * scales1 # Scale and shift 151 | accumulator1 += tl.dot(a, b1) 152 | 153 | b2 = (b2 >> shifter[:, None]) & maxq 154 | b2 = (b2 - zeros2) * scales2 155 | accumulator2 += tl.dot(a, b2) 156 | 157 | a_ptrs += BLOCK_SIZE_K 158 | b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 159 | b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 160 | g1_ptrs += BLOCK_SIZE_K 161 | g2_ptrs += BLOCK_SIZE_K 162 | 163 | accumulator1 = silu(accumulator1) 164 | c = accumulator1 * accumulator2 165 | c = c.to(tl.float16) 166 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 167 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) 168 | tl.store(c_ptrs, c, mask=c_mask) 169 | 170 | @triton.jit 171 | def silu(x): 172 | return x * tl.sigmoid(x) 173 | except: 174 | print('triton not installed.') 175 | 176 | 177 | class QuantLlamaMLP(nn.Module): 178 | 179 | def __init__( 180 | self, 181 | gate_proj, 182 | down_proj, 183 | up_proj, 184 | ): 185 | super().__init__() 186 | self.register_buffer('gate_proj_qweight', gate_proj.qweight) 187 | self.register_buffer('gate_proj_scales', gate_proj.scales) 188 | self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) 189 | self.register_buffer('gate_proj_g_idx', gate_proj.g_idx) 190 | self.register_buffer('up_proj_qweight', up_proj.qweight) 191 | self.register_buffer('up_proj_scales', up_proj.scales) 192 | self.register_buffer('up_proj_qzeros', up_proj.qzeros) 193 | self.register_buffer('up_proj_g_idx', up_proj.g_idx) 194 | 195 | self.infeatures = gate_proj.infeatures 196 | 
self.intermediate_size = gate_proj.outfeatures 197 | self.outfeatures = down_proj.outfeatures 198 | self.bits = gate_proj.bits 199 | self.maxq = gate_proj.maxq 200 | 201 | self.down_proj = down_proj 202 | 203 | def forward(self, x): 204 | return self.down_proj(self.triton_llama_mlp(x)) 205 | 206 | def triton_llama_mlp(self, x): 207 | with torch.cuda.device(x.device): 208 | out_shape = x.shape[:-1] + (self.intermediate_size, ) 209 | x = x.reshape(-1, x.shape[-1]) 210 | M, K = x.shape 211 | N = self.intermediate_size 212 | c = torch.empty((M, N), device=x.device, dtype=torch.float16) 213 | grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) 214 | fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales, 215 | self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0), 216 | self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0)) 217 | c = c.reshape(out_shape) 218 | return c 219 | 220 | def fused2cuda(self): 221 | self.gate_proj_qweight = self.gate_proj_qweight.cuda() 222 | self.gate_proj_scales = self.gate_proj_scales.cuda() 223 | self.gate_proj_qzeros = self.gate_proj_qzeros.cuda() 224 | self.gate_proj_g_idx = self.gate_proj_g_idx.cuda() 225 | self.up_proj_qweight = self.up_proj_qweight.cuda() 226 | self.up_proj_scales = self.up_proj_scales.cuda() 227 | self.up_proj_qzeros = self.up_proj_qzeros.cuda() 228 | self.up_proj_g_idx = self.up_proj_g_idx.cuda() 229 | 230 | def fused2cpu(self): 231 | self.gate_proj_qweight = self.gate_proj_qweight.cpu() 232 | self.gate_proj_scales = self.gate_proj_scales.cpu() 233 | self.gate_proj_qzeros = self.gate_proj_qzeros.cpu() 234 | self.gate_proj_g_idx = self.gate_proj_g_idx.cpu() 235 | self.up_proj_qweight = self.up_proj_qweight.cpu() 236 | self.up_proj_scales = self.up_proj_scales.cpu() 237 | self.up_proj_qzeros = self.up_proj_qzeros.cpu() 238 | self.up_proj_g_idx = self.up_proj_g_idx.cpu() 239 | 240 | 241 | def make_fused_mlp(m, parent_name=''): 242 | """ 243 | Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. 
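    Unlike make_quant_attn, this walks the module tree recursively and returns the
    (possibly replaced) module, so call it as `model = make_fused_mlp(model)`.
    The gate and up projections are evaluated by a single Triton kernel that also
    applies SiLU; only the down projection remains a separate QuantLinear call.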
244 | """ 245 | if isinstance(m, LlamaMLP): 246 | return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj) 247 | 248 | for name, child in m.named_children(): 249 | child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") 250 | 251 | if isinstance(child, QuantLlamaMLP): 252 | setattr(m, name, child) 253 | return m 254 | 255 | 256 | def autotune_warmup_fused(model): 257 | """ 258 | Pre-tunes the quantized kernel 259 | """ 260 | from tqdm import tqdm 261 | 262 | kn_values = {} 263 | 264 | for _, m in model.named_modules(): 265 | if not isinstance(m, QuantLlamaMLP): 266 | continue 267 | 268 | k = m.infeatures 269 | n = m.intermediate_size 270 | 271 | m.fused2cuda() 272 | if (k, n) not in kn_values: 273 | kn_values[(k, n)] = m 274 | 275 | print(f'Found {len(kn_values)} unique fused mlp KN values.') 276 | 277 | print('Warming up autotune cache ...') 278 | with torch.no_grad(): 279 | for m in tqdm(range(0, 12)): 280 | m = 2**m # [1, 2048] 281 | for (k, n), (modules) in kn_values.items(): 282 | a = torch.randn(m, k, dtype=torch.float16, device='cuda') 283 | modules.triton_llama_mlp(a) 284 | 285 | for (k, n), (modules) in kn_values.items(): 286 | a = torch.randn(m, k, dtype=torch.float16, device='cuda') 287 | modules.fused2cpu() 288 | del kn_values 289 | -------------------------------------------------------------------------------- /quant/quant_linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.cuda.amp import custom_bwd, custom_fwd 6 | 7 | try: 8 | import triton 9 | import triton.language as tl 10 | from . import custom_autotune 11 | 12 | # code based https://github.com/fpgaminer/GPTQ-triton 13 | @custom_autotune.autotune( 14 | configs=[ 15 | triton.Config({ 16 | 'BLOCK_SIZE_M': 64, 17 | 'BLOCK_SIZE_N': 256, 18 | 'BLOCK_SIZE_K': 32, 19 | 'GROUP_SIZE_M': 8 20 | }, num_stages=4, num_warps=4), 21 | triton.Config({ 22 | 'BLOCK_SIZE_M': 128, 23 | 'BLOCK_SIZE_N': 128, 24 | 'BLOCK_SIZE_K': 32, 25 | 'GROUP_SIZE_M': 8 26 | }, num_stages=4, num_warps=4), 27 | triton.Config({ 28 | 'BLOCK_SIZE_M': 64, 29 | 'BLOCK_SIZE_N': 128, 30 | 'BLOCK_SIZE_K': 32, 31 | 'GROUP_SIZE_M': 8 32 | }, num_stages=4, num_warps=4), 33 | triton.Config({ 34 | 'BLOCK_SIZE_M': 128, 35 | 'BLOCK_SIZE_N': 32, 36 | 'BLOCK_SIZE_K': 32, 37 | 'GROUP_SIZE_M': 8 38 | }, num_stages=4, num_warps=4), 39 | triton.Config({ 40 | 'BLOCK_SIZE_M': 64, 41 | 'BLOCK_SIZE_N': 64, 42 | 'BLOCK_SIZE_K': 32, 43 | 'GROUP_SIZE_M': 8 44 | }, num_stages=4, num_warps=4), 45 | triton.Config({ 46 | 'BLOCK_SIZE_M': 64, 47 | 'BLOCK_SIZE_N': 128, 48 | 'BLOCK_SIZE_K': 32, 49 | 'GROUP_SIZE_M': 8 50 | }, num_stages=2, num_warps=8), 51 | triton.Config({ 52 | 'BLOCK_SIZE_M': 64, 53 | 'BLOCK_SIZE_N': 64, 54 | 'BLOCK_SIZE_K': 64, 55 | 'GROUP_SIZE_M': 8 56 | }, num_stages=3, num_warps=8), 57 | triton.Config({ 58 | 'BLOCK_SIZE_M': 32, 59 | 'BLOCK_SIZE_N': 32, 60 | 'BLOCK_SIZE_K': 128, 61 | 'GROUP_SIZE_M': 8 62 | }, num_stages=2, num_warps=4), 63 | ], 64 | key=['M', 'N', 'K'], 65 | nearest_power_of_two=True, 66 | prune_configs_by={ 67 | 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, 68 | 'perf_model': None, 69 | 'top_k': None, 70 | }, 71 | ) 72 | @triton.jit 73 | def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, 74 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, 
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): 75 | """ 76 | Compute the matrix multiplication C = A x B. 77 | A is of shape (M, K) float16 78 | B is of shape (K//8, N) int32 79 | C is of shape (M, N) float16 80 | scales is of shape (G, N) float16 81 | zeros is of shape (G, N) float16 82 | g_ptr is of shape (K) int32 83 | """ 84 | infearure_per_bits = 32 // bits 85 | 86 | pid = tl.program_id(axis=0) 87 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 88 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 89 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 90 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 91 | group_id = pid // num_pid_in_group 92 | first_pid_m = group_id * GROUP_SIZE_M 93 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 94 | pid_m = first_pid_m + (pid % group_size_m) 95 | pid_n = (pid % num_pid_in_group) // group_size_m 96 | 97 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 98 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 99 | offs_k = tl.arange(0, BLOCK_SIZE_K) 100 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) 101 | a_mask = (offs_am[:, None] < M) 102 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 103 | b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 104 | g_ptrs = g_ptr + offs_k 105 | # shifter is used to extract the N bits of each element in the 32-bit word from B 106 | scales_ptrs = scales_ptr + offs_bn[None, :] 107 | zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) 108 | 109 | shifter = (offs_k % infearure_per_bits) * bits 110 | zeros_shifter = (offs_bn % infearure_per_bits) * bits 111 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 112 | 113 | for k in range(0, num_pid_k): 114 | g_idx = tl.load(g_ptrs) 115 | 116 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 117 | scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 118 | zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 119 | 120 | zeros = (zeros >> zeros_shifter[None, :]) & maxq 121 | zeros = (zeros + 1) 122 | 123 | a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) 124 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 125 | 126 | # Now we need to unpack b (which is N-bit values) into 32-bit values 127 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values 128 | b = (b - zeros) * scales # Scale and shift 129 | 130 | accumulator += tl.dot(a, b) 131 | a_ptrs += BLOCK_SIZE_K 132 | b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk 133 | g_ptrs += BLOCK_SIZE_K 134 | 135 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 136 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) 137 | tl.store(c_ptrs, accumulator, mask=c_mask) 138 | 139 | @custom_autotune.autotune(configs=[ 140 | triton.Config({ 141 | 'BLOCK_SIZE_M': 64, 142 | 'BLOCK_SIZE_N': 32, 143 | 'BLOCK_SIZE_K': 256, 144 | 'GROUP_SIZE_M': 8 145 | }, num_stages=4, num_warps=4), 146 | triton.Config({ 147 | 'BLOCK_SIZE_M': 128, 148 | 'BLOCK_SIZE_N': 32, 149 | 'BLOCK_SIZE_K': 128, 150 | 'GROUP_SIZE_M': 8 151 | }, num_stages=4, num_warps=4), 152 | triton.Config({ 153 | 'BLOCK_SIZE_M': 64, 154 | 'BLOCK_SIZE_N': 32, 155 | 'BLOCK_SIZE_K': 128, 156 | 'GROUP_SIZE_M': 8 157 | }, num_stages=4, num_warps=4), 158 | triton.Config({ 159 | 'BLOCK_SIZE_M': 128, 160 | 'BLOCK_SIZE_N': 32, 161 | 'BLOCK_SIZE_K': 32, 162 | 'GROUP_SIZE_M': 8 163 | }, num_stages=4, num_warps=4), 164 | triton.Config({ 165 | 'BLOCK_SIZE_M': 64, 166 | 'BLOCK_SIZE_N': 32, 167 | 'BLOCK_SIZE_K': 64, 168 | 'GROUP_SIZE_M': 8 169 | }, num_stages=4, num_warps=4), 170 | triton.Config({ 171 | 'BLOCK_SIZE_M': 64, 172 | 'BLOCK_SIZE_N': 32, 173 | 'BLOCK_SIZE_K': 128, 174 | 'GROUP_SIZE_M': 8 175 | }, num_stages=2, num_warps=8), 176 | triton.Config({ 177 | 'BLOCK_SIZE_M': 64, 178 | 'BLOCK_SIZE_N': 64, 179 | 'BLOCK_SIZE_K': 64, 180 | 'GROUP_SIZE_M': 8 181 | }, num_stages=3, num_warps=8), 182 | triton.Config({ 183 | 'BLOCK_SIZE_M': 32, 184 | 'BLOCK_SIZE_N': 128, 185 | 'BLOCK_SIZE_K': 32, 186 | 'GROUP_SIZE_M': 8 187 | }, num_stages=2, num_warps=4), 188 | ], 189 | key=['M', 'N', 'K'], 190 | nearest_power_of_two=True) 191 | @triton.jit 192 | def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, 193 | stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): 194 | """ 195 | Compute the matrix multiplication C = A x B. 
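    Here B is applied in transposed form, i.e. effectively C = A x B^T; this is the
    kernel QuantLinearFunction.backward uses to compute grad_input = grad_output @ W^T.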
196 | A is of shape (M, N) float16 197 | B is of shape (K//8, N) int32 198 | C is of shape (M, K) float16 199 | scales is of shape (G, N) float16 200 | zeros is of shape (G, N) float16 201 | g_ptr is of shape (K) int32 202 | """ 203 | infearure_per_bits = 32 // bits 204 | 205 | pid = tl.program_id(axis=0) 206 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 207 | num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) 208 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 209 | num_pid_in_group = GROUP_SIZE_M * num_pid_k 210 | group_id = pid // num_pid_in_group 211 | first_pid_m = group_id * GROUP_SIZE_M 212 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 213 | pid_m = first_pid_m + (pid % group_size_m) 214 | pid_k = (pid % num_pid_in_group) // group_size_m 215 | 216 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 217 | offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) 218 | offs_n = tl.arange(0, BLOCK_SIZE_N) 219 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) 220 | a_mask = (offs_am[:, None] < M) 221 | # b_ptrs is set up such that it repeats elements along the K axis 8 times 222 | b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) 223 | g_ptrs = g_ptr + offs_bk 224 | g_idx = tl.load(g_ptrs) 225 | 226 | # shifter is used to extract the N bits of each element in the 32-bit word from B 227 | scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales 228 | zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros 229 | 230 | shifter = (offs_bk % infearure_per_bits) * bits 231 | zeros_shifter = (offs_n % infearure_per_bits) * bits 232 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) 233 | 234 | for n in range(0, num_pid_n): 235 | # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop 236 | scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 237 | zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) 238 | 239 | zeros = (zeros >> zeros_shifter[None, :]) & maxq 240 | zeros = (zeros + 1) 241 | 242 | a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_N) 243 | b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated 244 | 245 | # Now we need to unpack b (which is N-bit values) into 32-bit values 246 | b = (b >> shifter[:, None]) & maxq # Extract the N-bit values 247 | b = (b - zeros) * scales # Scale and shift 248 | b = tl.trans(b) 249 | 250 | accumulator += tl.dot(a, b) 251 | a_ptrs += BLOCK_SIZE_N 252 | b_ptrs += BLOCK_SIZE_N 253 | scales_ptrs += BLOCK_SIZE_N 254 | zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) 255 | 256 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] 257 | c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) 258 | tl.store(c_ptrs, accumulator, mask=c_mask) 259 | except: 260 | print('triton not installed.') 261 | 262 | 263 | def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): 264 | with torch.cuda.device(input.device): 265 | output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) 266 | grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) 267 | matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), 268 | qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) 269 | return output 270 | 271 | 272 | def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): 273 | with torch.cuda.device(input.device): 274 | output_dim = (qweight.shape[0] * 32) // bits 275 | output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) 276 | grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) 277 | transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), 278 | qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) 279 | return output 280 | 281 | 282 | class QuantLinearFunction(torch.autograd.Function): 283 | 284 | @staticmethod 285 | @custom_fwd(cast_inputs=torch.float16) 286 | def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): 287 | output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) 288 | ctx.save_for_backward(qweight, scales, qzeros, g_idx) 289 | ctx.bits, ctx.maxq = bits, maxq 290 | return output 291 | 292 | @staticmethod 293 | @custom_bwd 294 | def backward(ctx, grad_output): 295 | qweight, scales, qzeros, g_idx = ctx.saved_tensors 296 | bits, maxq = ctx.bits, ctx.maxq 297 | grad_input = None 298 | 299 | if ctx.needs_input_grad[0]: 300 | grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) 301 | return grad_input, None, None, None, None, None, None 302 | 303 | 304 | class QuantLinear(nn.Module): 305 | 306 | def __init__(self, bits, groupsize, infeatures, outfeatures, bias): 307 | super().__init__() 308 | if bits not in [2, 4, 8]: 309 | raise NotImplementedError("Only 2,4,8 bits are supported.") 310 | self.infeatures = infeatures 311 | self.outfeatures = outfeatures 312 | self.bits = bits 313 | self.maxq = 2**self.bits - 1 314 | self.groupsize = groupsize if groupsize != -1 else infeatures 315 | 316 | self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) 317 | 
self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) 318 | self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) 319 | self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) 320 | if bias: 321 | self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) 322 | else: 323 | self.bias = None 324 | 325 | def pack(self, linear, scales, zeros, g_idx=None): 326 | self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx 327 | 328 | scales = scales.t().contiguous() 329 | zeros = zeros.t().contiguous() 330 | scale_zeros = zeros * scales 331 | self.scales = scales.clone().half() 332 | if linear.bias is not None: 333 | self.bias = linear.bias.clone().half() 334 | 335 | intweight = [] 336 | for idx in range(self.infeatures): 337 | intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None]) 338 | intweight = torch.cat(intweight, dim=1) 339 | intweight = intweight.t().contiguous() 340 | intweight = intweight.numpy().astype(np.uint32) 341 | qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) 342 | i = 0 343 | row = 0 344 | while row < qweight.shape[0]: 345 | if self.bits in [2, 4, 8]: 346 | for j in range(i, i + (32 // self.bits)): 347 | qweight[row] |= intweight[j] << (self.bits * (j - i)) 348 | i += 32 // self.bits 349 | row += 1 350 | else: 351 | raise NotImplementedError("Only 2,4,8 bits are supported.") 352 | 353 | qweight = qweight.astype(np.int32) 354 | self.qweight = torch.from_numpy(qweight) 355 | 356 | zeros -= 1 357 | zeros = zeros.numpy().astype(np.uint32) 358 | qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) 359 | i = 0 360 | col = 0 361 | while col < qzeros.shape[1]: 362 | if self.bits in [2, 4, 8]: 363 | for j in range(i, i + (32 // self.bits)): 364 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) 365 | i += 32 // self.bits 366 | col += 1 367 | else: 368 | raise NotImplementedError("Only 2,4,8 bits are supported.") 369 | 370 | qzeros = qzeros.astype(np.int32) 371 | self.qzeros = torch.from_numpy(qzeros) 372 | 373 | def forward(self, x): 374 | out_shape = x.shape[:-1] + (self.outfeatures, ) 375 | out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) 376 | out = out + self.bias if self.bias is not None else out 377 | return out.reshape(out_shape) 378 | 379 | 380 | def make_quant_linear(module, names, bits, groupsize, name=''): 381 | if isinstance(module, QuantLinear): 382 | return 383 | for attr in dir(module): 384 | tmp = getattr(module, attr) 385 | name1 = name + '.' + attr if name != '' else attr 386 | if name1 in names: 387 | delattr(module, attr) 388 | setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) 389 | for name1, child in module.named_children(): 390 | make_quant_linear(child, names, bits, groupsize, name + '.' 
+ name1 if name != '' else name1) 391 | 392 | 393 | def autotune_warmup_linear(model, transpose=False): 394 | """ 395 | Pre-tunes the quantized kernel 396 | """ 397 | from tqdm import tqdm 398 | 399 | kn_values = {} 400 | 401 | for _, m in model.named_modules(): 402 | if not isinstance(m, QuantLinear): 403 | continue 404 | 405 | k = m.infeatures 406 | n = m.outfeatures 407 | 408 | if (k, n) not in kn_values: 409 | kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq) 410 | 411 | print(f'Found {len(kn_values)} unique KN Linear values.') 412 | 413 | print('Warming up autotune cache ...') 414 | with torch.no_grad(): 415 | for m in tqdm(range(0, 12)): 416 | m = 2**m # [1, 2048] 417 | for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items(): 418 | a = torch.randn(m, k, dtype=torch.float16, device='cuda') 419 | matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) 420 | if transpose: 421 | a = torch.randn(m, n, dtype=torch.float16, device='cuda') 422 | transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) 423 | del kn_values 424 | -------------------------------------------------------------------------------- /quant/quantizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class Quantizer(nn.Module): 8 | 9 | def __init__(self, shape=1): 10 | super(Quantizer, self).__init__() 11 | self.register_buffer('maxq', torch.tensor(0)) 12 | self.register_buffer('scale', torch.zeros(shape)) 13 | self.register_buffer('zero', torch.zeros(shape)) 14 | 15 | def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): 16 | 17 | self.maxq = torch.tensor(2**bits - 1) 18 | self.perchannel = perchannel 19 | self.sym = sym 20 | self.mse = mse 21 | self.norm = norm 22 | self.grid = grid 23 | self.maxshrink = maxshrink 24 | if trits: 25 | self.maxq = torch.tensor(-1) 26 | self.scale = torch.zeros_like(self.scale) 27 | 28 | def _quantize(self, x, scale, zero, maxq): 29 | if maxq < 0: 30 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 31 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 32 | return scale * (q - zero) 33 | 34 | def find_params(self, x, weight=False): 35 | dev = x.device 36 | self.maxq = self.maxq.to(dev) 37 | 38 | shape = x.shape 39 | if self.perchannel: 40 | if weight: 41 | x = x.flatten(1) 42 | else: 43 | if len(shape) == 4: 44 | x = x.permute([1, 0, 2, 3]) 45 | x = x.flatten(1) 46 | if len(shape) == 3: 47 | x = x.reshape((-1, shape[-1])).t() 48 | if len(shape) == 2: 49 | x = x.t() 50 | else: 51 | x = x.flatten().unsqueeze(0) 52 | 53 | tmp = torch.zeros(x.shape[0], device=dev) 54 | xmin = torch.minimum(x.min(1)[0], tmp) 55 | xmax = torch.maximum(x.max(1)[0], tmp) 56 | 57 | if self.sym: 58 | xmax = torch.maximum(torch.abs(xmin), xmax) 59 | tmp = xmin < 0 60 | if torch.any(tmp): 61 | xmin[tmp] = -xmax[tmp] 62 | tmp = (xmin == 0) & (xmax == 0) 63 | xmin[tmp] = -1 64 | xmax[tmp] = +1 65 | 66 | if self.maxq < 0: 67 | self.scale = xmax 68 | self.zero = xmin 69 | else: 70 | self.scale = (xmax - xmin) / self.maxq 71 | if self.sym: 72 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 73 | else: 74 | self.zero = torch.round(-xmin / self.scale) 75 | 76 | if self.mse: 77 | best = torch.full([x.shape[0]], float('inf'), device=dev) 78 | for i in range(int(self.maxshrink * self.grid)): 79 | p = 1 - i / 
self.grid 80 | xmin1 = p * xmin 81 | xmax1 = p * xmax 82 | scale1 = (xmax1 - xmin1) / self.maxq 83 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 84 | q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 85 | q -= x 86 | q.abs_() 87 | q.pow_(self.norm) 88 | err = torch.sum(q, 1) 89 | tmp = err < best 90 | if torch.any(tmp): 91 | best[tmp] = err[tmp] 92 | self.scale[tmp] = scale1[tmp] 93 | self.zero[tmp] = zero1[tmp] 94 | if not self.perchannel: 95 | if weight: 96 | tmp = shape[0] 97 | else: 98 | tmp = shape[1] if len(shape) != 3 else shape[2] 99 | self.scale = self.scale.repeat(tmp) 100 | self.zero = self.zero.repeat(tmp) 101 | 102 | if weight: 103 | shape = [-1] + [1] * (len(shape) - 1) 104 | self.scale = self.scale.reshape(shape) 105 | self.zero = self.zero.reshape(shape) 106 | return 107 | if len(shape) == 4: 108 | self.scale = self.scale.reshape((1, -1, 1, 1)) 109 | self.zero = self.zero.reshape((1, -1, 1, 1)) 110 | if len(shape) == 3: 111 | self.scale = self.scale.reshape((1, 1, -1)) 112 | self.zero = self.zero.reshape((1, 1, -1)) 113 | if len(shape) == 2: 114 | self.scale = self.scale.unsqueeze(0) 115 | self.zero = self.zero.unsqueeze(0) 116 | 117 | def quantize(self, x): 118 | if self.ready(): 119 | return self._quantize(x, self.scale, self.zero, self.maxq) 120 | 121 | return x 122 | 123 | def enabled(self): 124 | return self.maxq > 0 125 | 126 | def ready(self): 127 | return torch.all(self.scale != 0) 128 | -------------------------------------------------------------------------------- /quant/triton_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import triton 4 | import triton.language as tl 5 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 6 | 7 | @triton.jit 8 | def rms_norm_fwd_fused( 9 | X, # pointer to the input 10 | Y, # pointer to the output 11 | W, # pointer to the weights 12 | stride, # how much to increase the pointer when moving by 1 row 13 | N, # number of columns in X 14 | eps, # epsilon to avoid division by zero 15 | BLOCK_SIZE: tl.constexpr, 16 | ): 17 | # Map the program id to the row of X and Y it should compute. 18 | row = tl.program_id(0) 19 | Y += row * stride 20 | X += row * stride 21 | # Compute variance 22 | _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 23 | for off in range(0, N, BLOCK_SIZE): 24 | cols = off + tl.arange(0, BLOCK_SIZE) 25 | x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) 26 | x = tl.where(cols < N, x, 0.) 
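        # Accumulate sum(x^2) block-wise; dividing by N below gives the mean square,
        # which is all RMSNorm needs (unlike LayerNorm, no mean is subtracted).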
27 | _var += x * x 28 | var = tl.sum(_var, axis=0) / N 29 | rstd = 1 / tl.sqrt(var + eps) 30 | # Normalize and apply linear transformation 31 | for off in range(0, N, BLOCK_SIZE): 32 | cols = off + tl.arange(0, BLOCK_SIZE) 33 | mask = cols < N 34 | w = tl.load(W + cols, mask=mask) 35 | x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) 36 | x_hat = x * rstd 37 | y = x_hat * w 38 | # Write output 39 | tl.store(Y + cols, y, mask=mask) 40 | 41 | class TritonLlamaRMSNorm(nn.Module): 42 | def __init__(self, weight, eps=1e-6): 43 | """ 44 | LlamaRMSNorm is equivalent to T5LayerNorm 45 | """ 46 | super().__init__() 47 | self.weight = weight 48 | self.variance_epsilon = eps 49 | 50 | def forward(self, x): 51 | with torch.cuda.device(x.device): 52 | y = torch.empty_like(x) 53 | # reshape input data into 2D tensor 54 | x_arg = x.reshape(-1, x.shape[-1]) 55 | M, N = x_arg.shape 56 | # Less than 64KB per feature: enqueue fused kernel 57 | MAX_FUSED_SIZE = 65536 // x.element_size() 58 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 59 | if N > BLOCK_SIZE: 60 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 61 | # heuristics for number of warps 62 | num_warps = min(max(BLOCK_SIZE // 256, 1), 8) 63 | # enqueue kernel 64 | rms_norm_fwd_fused[(M,)](x_arg, y, self.weight, 65 | x_arg.stride(0), N, self.variance_epsilon, 66 | BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) 67 | return y 68 | 69 | 70 | def make_quant_norm(model): 71 | """ 72 | Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules 73 | """ 74 | 75 | for name, m in model.named_modules(): 76 | if not isinstance(m, LlamaRMSNorm): 77 | continue 78 | 79 | norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon) 80 | 81 | if '.' in name: 82 | parent_name = name.rsplit('.', 1)[0] 83 | child_name = name[len(parent_name) + 1:] 84 | parent = model.get_submodule(parent_name) 85 | else: 86 | parent_name = '' 87 | parent = model 88 | child_name = name 89 | 90 | #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") 91 | 92 | setattr(parent, child_name, norm) 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | safetensors==0.3.1 2 | datasets==2.10.1 3 | sentencepiece 4 | git+https://github.com/huggingface/transformers 5 | accelerate==0.20.3 6 | triton==2.0.0 7 | texttable 8 | toml 9 | numpy 10 | protobuf==3.20.2 11 | 12 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelutils import DEV, find_layers, gen_conditions, torch_snr_error 2 | from .datautils import set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders 3 | from .export import export_quant_table 4 | -------------------------------------------------------------------------------- /utils/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | 15 | from transformers import 
AutoTokenizer 16 | try: 17 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 18 | except: 19 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) 20 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 21 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 22 | 23 | import random 24 | random.seed(seed) 25 | trainloader = [] 26 | for _ in range(nsamples): 27 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 28 | j = i + seqlen 29 | inp = trainenc.input_ids[:, i:j] 30 | tar = inp.clone() 31 | tar[:, :-1] = -100 32 | trainloader.append((inp, tar)) 33 | return trainloader, testenc 34 | 35 | 36 | def get_ptb(nsamples, seed, seqlen, model): 37 | from datasets import load_dataset 38 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 39 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 40 | 41 | from transformers import AutoTokenizer 42 | try: 43 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 44 | except: 45 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) 46 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 47 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 48 | 49 | import random 50 | random.seed(seed) 51 | trainloader = [] 52 | for _ in range(nsamples): 53 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 54 | j = i + seqlen 55 | inp = trainenc.input_ids[:, i:j] 56 | tar = inp.clone() 57 | tar[:, :-1] = -100 58 | trainloader.append((inp, tar)) 59 | return trainloader, testenc 60 | 61 | 62 | def get_c4(nsamples, seed, seqlen, model): 63 | from datasets import load_dataset 64 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False) 65 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False) 66 | 67 | from transformers import AutoTokenizer 68 | try: 69 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 70 | except: 71 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) 72 | 73 | import random 74 | random.seed(seed) 75 | trainloader = [] 76 | for _ in range(nsamples): 77 | while True: 78 | i = random.randint(0, len(traindata) - 1) 79 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 80 | if trainenc.input_ids.shape[1] >= seqlen: 81 | break 82 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 83 | j = i + seqlen 84 | inp = trainenc.input_ids[:, i:j] 85 | tar = inp.clone() 86 | tar[:, :-1] = -100 87 | trainloader.append((inp, tar)) 88 | 89 | import random 90 | random.seed(0) 91 | valenc = [] 92 | for _ in range(256): 93 | while True: 94 | i = random.randint(0, len(valdata) - 1) 95 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 96 | if tmp.input_ids.shape[1] >= seqlen: 97 | break 98 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 99 | j = i + seqlen 100 | valenc.append(tmp.input_ids[:, i:j]) 101 | valenc = torch.hstack(valenc) 102 | 103 | class TokenizerWrapper: 104 | 105 | def __init__(self, input_ids): 106 | self.input_ids = input_ids 107 | 108 | valenc = TokenizerWrapper(valenc) 109 | 110 | return trainloader, valenc 111 | 112 | 113 | def get_ptb_new(nsamples, seed, seqlen, model): 114 | from datasets import load_dataset 115 | traindata = 
load_dataset('ptb_text_only', 'penn_treebank', split='train') 116 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 117 | 118 | from transformers import AutoTokenizer 119 | try: 120 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 121 | except: 122 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) 123 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 124 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 125 | 126 | import random 127 | random.seed(seed) 128 | trainloader = [] 129 | for _ in range(nsamples): 130 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 131 | j = i + seqlen 132 | inp = trainenc.input_ids[:, i:j] 133 | tar = inp.clone() 134 | tar[:, :-1] = -100 135 | trainloader.append((inp, tar)) 136 | return trainloader, testenc 137 | 138 | 139 | def get_c4_new(nsamples, seed, seqlen, model): 140 | from datasets import load_dataset 141 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 142 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 143 | 144 | from transformers import AutoTokenizer 145 | try: 146 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 147 | except: 148 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) 149 | 150 | import random 151 | random.seed(seed) 152 | trainloader = [] 153 | for _ in range(nsamples): 154 | while True: 155 | i = random.randint(0, len(traindata) - 1) 156 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 157 | if trainenc.input_ids.shape[1] >= seqlen: 158 | break 159 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 160 | j = i + seqlen 161 | inp = trainenc.input_ids[:, i:j] 162 | tar = inp.clone() 163 | tar[:, :-1] = -100 164 | trainloader.append((inp, tar)) 165 | 166 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 167 | valenc = valenc.input_ids[:, :(256 * seqlen)] 168 | 169 | class TokenizerWrapper: 170 | 171 | def __init__(self, input_ids): 172 | self.input_ids = input_ids 173 | 174 | valenc = TokenizerWrapper(valenc) 175 | 176 | return trainloader, valenc 177 | 178 | 179 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): 180 | if 'wikitext2' in name: 181 | return get_wikitext2(nsamples, seed, seqlen, model) 182 | if 'ptb' in name: 183 | if 'new' in name: 184 | return get_ptb_new(nsamples, seed, seqlen, model) 185 | return get_ptb(nsamples, seed, seqlen, model) 186 | if 'c4' in name: 187 | if 'new' in name: 188 | return get_c4_new(nsamples, seed, seqlen, model) 189 | return get_c4(nsamples, seed, seqlen, model) 190 | -------------------------------------------------------------------------------- /utils/export.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import toml 3 | import os 4 | 5 | 6 | def export_quant_table(quantizers: dict, quant_dir: str, format: str = 'toml'): 7 | 8 | table = {} 9 | 10 | def save_tensor(name: str, tensor): 11 | np.save(os.path.join(quant_dir, name), tensor.numpy()) 12 | return '{}.npy'.format(name) 13 | 14 | for key, value in quantizers.items(): 15 | quantizer = value[0] 16 | 17 | dump = dict() 18 | 19 | sym = quantizer.sym 20 | if not sym: 21 | dump['zero'] = save_tensor(name=key + '.zero', tensor=value[2]) 22 | dump['scale'] = 
save_tensor(name=key + '.scale', tensor=value[1]) 23 | dump['wbits'] = value[4] 24 | dump['groupsize'] = value[5] 25 | if value[5] > 0: 26 | dump['group_ids'] = save_tensor(name=key + '.group_ids', tensor=value[3]) 27 | 28 | dump['sym'] = sym 29 | dump['perchannel'] = quantizer.perchannel 30 | 31 | table[key] = dump 32 | 33 | if not os.path.exists(quant_dir): 34 | os.mkdir(quant_dir) 35 | 36 | with open(os.path.join(quant_dir, 'quant.toml'), 'w') as f: 37 | toml.dump(table, f) 38 | -------------------------------------------------------------------------------- /utils/modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | DEV = torch.device('cuda:0') 5 | 6 | 7 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 8 | if type(module) in layers: 9 | return {name: module} 10 | res = {} 11 | for name1, child in module.named_children(): 12 | res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) 13 | return res 14 | 15 | 16 | def gen_conditions(_wbits, _groupsize): 17 | wbits = _wbits 18 | groupsize = _groupsize 19 | conditions = [] 20 | while True: 21 | if wbits >= 8: 22 | if groupsize == -1 or groupsize == 32: 23 | break 24 | 25 | if groupsize > 32: 26 | groupsize /= 2 27 | else: 28 | wbits *= 2 29 | groupsize = _groupsize 30 | 31 | conditions.append((int(wbits), int(groupsize))) 32 | return conditions 33 | 34 | 35 | # copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py 36 | def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: 37 | """ 38 | Compute SNR between y_pred(tensor) and y_real(tensor) 39 | 40 | SNR can be calcualted as following equation: 41 | 42 | SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 43 | 44 | if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. 45 | 46 | SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) 47 | Args: 48 | y_pred (torch.Tensor): _description_ 49 | y_real (torch.Tensor): _description_ 50 | reduction (str, optional): _description_. Defaults to 'mean'. 51 | Raises: 52 | ValueError: _description_ 53 | ValueError: _description_ 54 | Returns: 55 | torch.Tensor: _description_ 56 | """ 57 | y_pred = y_pred.type(torch.float32) 58 | y_real = y_real.type(torch.float32) 59 | 60 | if y_pred.shape != y_real.shape: 61 | raise ValueError(f'Can not compute snr loss for tensors with different shape. ' 62 | f'({y_pred.shape} and {y_real.shape})') 63 | reduction = str(reduction).lower() 64 | 65 | if y_pred.ndim == 1: 66 | y_pred = y_pred.unsqueeze(0) 67 | y_real = y_real.unsqueeze(0) 68 | 69 | y_pred = y_pred.flatten(start_dim=1) 70 | y_real = y_real.flatten(start_dim=1) 71 | 72 | noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) 73 | signal_power = torch.pow(y_real, 2).sum(dim=-1) 74 | snr = (noise_power) / (signal_power + 1e-7) 75 | 76 | if reduction == 'mean': 77 | return torch.mean(snr) 78 | elif reduction == 'sum': 79 | return torch.sum(snr) 80 | elif reduction == 'none': 81 | return snr 82 | else: 83 | raise ValueError(f'Unsupported reduction method.') 84 | --------------------------------------------------------------------------------
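Putting the pieces together: the sketch below shows one plausible way to run a GPTQ-quantized LLaMA checkpoint with the fused Triton kernels defined above. It is illustrative only: the model id, checkpoint path, bit-width and group size are assumptions, the loader here naively materializes the fp16 weights before replacing them, and the repository's own entry points (llama.py, llama_inference.py) should be preferred for real use.

# Hypothetical end-to-end sketch (not the repository's official entry point).
import torch
from transformers import LlamaForCausalLM, AutoTokenizer

from quant.quant_linear import make_quant_linear, autotune_warmup_linear
from quant.fused_attn import make_quant_attn
from quant.fused_mlp import make_fused_mlp, autotune_warmup_fused
from quant.triton_norm import make_quant_norm
from utils.modelutils import find_layers

MODEL = 'decapoda-research/llama-7b-hf'   # assumed HF model id
CHECKPOINT = 'llama7b-4bit-128g.pt'       # assumed GPTQ checkpoint path

# Build the model skeleton, then swap every Linear (except the LM head) for QuantLinear.
model = LlamaForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)
layers = find_layers(model)
if 'lm_head' in layers:
    del layers['lm_head']                 # keep the LM head in full precision
make_quant_linear(model, layers, bits=4, groupsize=128)
model.load_state_dict(torch.load(CHECKPOINT), strict=False)

# Fuse attention projections, the gate/up MLP matmuls, and the RMSNorm.
make_quant_attn(model)
model = make_fused_mlp(model)
make_quant_norm(model)

# Optional: pre-tune the Triton kernels before moving the model to the GPU
# (autotune_warmup_fused parks its buffers on the CPU when it finishes).
autotune_warmup_linear(model)
autotune_warmup_fused(model)

model = model.cuda().eval()

tok = AutoTokenizer.from_pretrained(MODEL, use_fast=False)
inp = tok("The capital of France is", return_tensors='pt').input_ids.cuda()
print(tok.decode(model.generate(inp, max_new_tokens=16)[0]))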