├── .gitignore ├── LICENSE ├── README.md ├── bloom.py ├── datautils.py ├── gptq.py ├── llama.py ├── modelutils.py ├── opt.py ├── quant.py ├── quant_cuda.cpp ├── quant_cuda_kernel.cu ├── setup_cuda.py ├── test_kernel.py └── zeroShot ├── LICENSE.md ├── README.md ├── datautils.py ├── evaluator.py ├── main.py ├── metrics.py ├── models ├── __init__.py ├── bloom.py ├── fast_trueobs.py ├── gptq.py ├── models_utils.py ├── opt.py └── quant.py ├── tasks ├── __init__.py ├── arc.py ├── glue.py ├── lambada.py ├── local_datasets │ ├── __init__.py │ ├── lambada │ │ ├── __init__.py │ │ ├── dataset_infos.json │ │ └── lambada.py │ └── storyCloze2018 │ │ └── cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv ├── piqa.py ├── storycloze.py ├── superglue.py └── tasks_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | build 3 | dist 4 | opt175b 5 | *.txt 6 | *.pt 7 | *egg-info* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPTQ 2 | 3 | This repository contains the code for the ICLR 2023 paper [GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers](https://arxiv.org/abs/2210.17323). 
4 | The current release includes the following features: 5 | 6 | * An efficient implementation of the GPTQ algorithm: `gptq.py` 7 | * Compressing all models from the OPT and BLOOM families to 2/3/4 bits, including weight grouping: `opt.py`, `bloom.py`, `zeroShot/` 8 | * Evaluating the perplexity of quantized models on several language generation tasks: `opt.py`, `bloom.py` 9 | * Evaluating the performance of quantized models on several ZeroShot tasks: `zeroShot/` 10 | * A CUDA kernel for multiplying a 3-bit quantized matrix with a full-precision vector: `quant_cuda_kernel.cu`, `quant_cuda.cpp`, `setup_cuda.py` 11 | * Benchmarking code for individual matrix-vector products and for language generation with quantized models: `test_kernel.py`, `opt.py` 12 | 13 | ## New Features 14 | 15 | Update July 2023: 16 | 17 | * Added the `--static-groups` option, which determines all group grids in advance rather than dynamically during quantization; as a result, `--act-order` does not require any inference changes (which may cause a slowdown) when used together with this option. 18 | 19 | Together with the camera-ready version of the paper we have added several updates to this repository: 20 | 21 | * Slightly adjusted preprocessing of C4 and PTB for more realistic evaluations (used in our updated results); can be activated via the flag `--new-eval`. 22 | * Optimized 3-bit kernels, which are considerably faster, especially on the A100, e.g. 1.9x -> 3.25x generation speedup for OPT-175B; can be activated via `--faster-kernel`. 23 | * A minimal LlaMa integration (for more complete features see the [GPTQ-for-LLaMA](https://github.com/qwopqwop200/GPTQ-for-LLaMa) repository), which demonstrates two new tricks: `--act-order` (quantizing columns in order of decreasing activation size) and `--true-sequential` (performing sequential quantization even within a single Transformer block). Those fix GPTQ's strangely bad performance on the 7B model (from 7.15 to 6.09 Wiki2 PPL) and lead to slight improvements on most models/settings in general. 24 | 25 | Here is a summary of LLaMa results: 26 | 27 | | Wiki2 PPL | FP16 | 4bit-RTN | 4bit-GPTQ | 3bit-RTN | 3bit-GPTQ | 3g128-GPTQ | 28 | |:---------:|:----:|:--------:|:---------:|:--------:|:---------:|:----------:| 29 | | LLaMa-7B | 5.68 | 6.29 | **6.09** | 25.54 | **8.07** | 6.61 | 30 | | LLaMa-13B | 5.09 | 5.53 | **5.36** | 11.40 | **6.63** | 5.62 | 31 | | LLaMa-30B | 4.10 | 4.54 | **4.45** | 14.89 | **5.69** | 4.80 | 32 | | LLaMa-65B | 3.53 | 3.92 | **3.84** | 10.59 | **5.04** | 4.17 | 33 | 34 | Here is a sample command: 35 | 36 | ``` 37 | python llama.py LLAMA_HF_FOLDER c4 --wbits 4 --true-sequential --act-order --new-eval 38 | ``` 39 | 40 | The `--act-order` heuristic also dramatically improves accuracy on the OPT-66B outlier model: 9.55 to 9.34 and 14.16 to 9.95 PPL on Wiki2 for 4bit and 3bit, respectively. 41 | 42 | ## Dependencies 43 | 44 | * `torch`: tested on v1.10.1+cu111 45 | * `transformers`: tested on v4.21.2 (the LLaMa integration currently requires installing `transformers` from source (main branch) and `sentencepiece`) 46 | * `datasets`: tested on v1.17.0 47 | * (to run 3-bit kernels: setup for compiling PyTorch CUDA extensions, see also https://pytorch.org/tutorials/advanced/cpp_extension.html, tested on CUDA 11.4) 48 | 49 | All experiments were run on a single 80GB NVIDIA A100. However, most experiments will work on a GPU with a lot less memory as well.
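The scripts below drive everything from the command line, but the GPTQ solver in `gptq.py` can also be used directly. The following is a minimal, illustrative sketch (not part of the repository): the layer size, bit-width, group size, and random calibration data are made up, and a CUDA device is assumed since `fasterquant` synchronizes the GPU.

```
# Minimal sketch of driving the GPTQ class from gptq.py on a single nn.Linear.
# Layer shape, bit-width, and calibration data are purely illustrative.
import torch
import torch.nn as nn

from gptq import GPTQ
from quant import Quantizer

torch.manual_seed(0)

layer = nn.Linear(1024, 1024).half().cuda()        # toy layer (assumption)
gptq = GPTQ(layer)
gptq.quantizer = Quantizer()
gptq.quantizer.configure(4, perchannel=True, sym=False, mse=False)  # 4-bit, asymmetric

# Accumulate the Hessian statistics from a few random "calibration" batches.
with torch.no_grad():
    for _ in range(16):
        inp = torch.randn(8, 1024, dtype=torch.half, device='cuda')
        gptq.add_batch(inp, layer(inp))

# Solve for the quantized weights; this overwrites layer.weight in place.
gptq.fasterquant(percdamp=0.01, groupsize=128)
gptq.free()
```

After `fasterquant` returns, `layer.weight` holds the dequantized values of the quantized weights, exactly as in the per-layer loops of `opt.py`, `bloom.py`, and `llama.py` (which feed calibration data via forward hooks instead of calling `add_batch` directly).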
50 | 51 | ## Language Generation 52 | 53 | ### OPT 54 | 55 | ``` 56 | # Compute full precision (FP16) results 57 | CUDA_VISIBLE_DEVICES=0 python opt.py facebook/opt-125m c4 58 | # Run RTN baseline and compute results 59 | CUDA_VISIBLE_DEVICES=0 python opt.py facebook/opt-125m c4 --wbits 4 --nearest 60 | # Run GPTQ and compute results 61 | CUDA_VISIBLE_DEVICES=0 python opt.py facebook/opt-125m c4 --wbits 4 [--groupsize 1024] 62 | ``` 63 | 64 | To run other OPT models, replace `opt-125m` with one of: `opt-350m`, `opt-1.3b`, `opt-2.7b`, `opt-6.7b`, `opt-13b`, `opt-66b`. 65 | For the 175B-parameter model, you have to request access from Meta and then convert it to a local HuggingFace checkpoint using their scripts in `metaseq`. 66 | Once you have such a checkpoint, simply pass its path instead of `facebook/opt-125m`. 67 | 68 | ### BLOOM 69 | 70 | ``` 71 | # Compute full precision (FP16) results 72 | CUDA_VISIBLE_DEVICES=0 python bloom.py bigscience/bloom-560m c4 73 | # Run RTN baseline and compute results 74 | CUDA_VISIBLE_DEVICES=0 python bloom.py bigscience/bloom-560m c4 --wbits 4 --nearest 75 | # Run GPTQ and compute results 76 | CUDA_VISIBLE_DEVICES=0 python bloom.py bigscience/bloom-560m c4 --wbits 4 [--groupsize 1024] 77 | ``` 78 | 79 | To run other BLOOM models, replace `bloom-560m` with one of: `bloom-1b1`, `bloom-1b7`, `bloom-3b`, `bloom-7b1`, `bloom`. 80 | 81 | ## ZeroShot 82 | 83 | See the `zeroShot/` folder. 84 | 85 | ## 3-bit CUDA Kernels 86 | 87 | ``` 88 | # Install kernels 89 | python setup_cuda.py install 90 | 91 | # Benchmark performance for the FC2 layer of OPT-175B 92 | CUDA_VISIBLE_DEVICES=0 python test_kernel.py 93 | 94 | # Benchmark language generation with 3-bit OPT-175B: 95 | # OPT175B denotes the name of the folder with the HuggingFace OPT-175B checkpoint (see above) 96 | 97 | # Save compressed model 98 | CUDA_VISIBLE_DEVICES=0 python opt.py OPT175B c4 --wbits 3 --save opt175b-3bit.pt 99 | # Benchmark generating a 128-token sequence with the saved model 100 | CUDA_VISIBLE_DEVICES=0 python opt.py OPT175B c4 --load opt175b-3bit.pt --benchmark 128 101 | # Benchmark FP16 baseline; note that the model will be split across all listed GPUs 102 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python opt.py OPT175B c4 --benchmark 128 103 | ``` 104 | 105 | Please note that our 3-bit kernels are currently only optimized for OPT-175B running on 1xA100 or 2xA6000 and may thus yield suboptimal performance on smaller models or on other GPUs.
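For programmatic use, the steps behind `--load` are roughly as follows. This is a condensed, illustrative sketch of `load_quant3()` from `opt.py`; the model name and checkpoint path (`opt125m-3bit.pt`) are placeholders, and it assumes the kernels above are installed and a matching 3-bit checkpoint was produced with `--save`.

```
# Condensed sketch of what load_quant3() in opt.py does for a packed 3-bit
# OPT checkpoint; model name and checkpoint file are placeholders.
import torch
import transformers
from transformers import OPTConfig, OPTForCausalLM

from modelutils import find_layers
from quant import make_quant3

# Build the model skeleton in fp16 without loading the original weights.
config = OPTConfig.from_pretrained('facebook/opt-125m')
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
model = OPTForCausalLM(config)
torch.set_default_dtype(torch.float)
model.eval()

# Swap every quantized Linear for a Quant3Linear so the packed weights fit.
layers = find_layers(model)
for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']:
    layers.pop(name, None)
make_quant3(model, layers)

# Load the packed 3-bit weights saved earlier with `--save`.
model.load_state_dict(torch.load('opt125m-3bit.pt'))
model.seqlen = model.config.max_position_embeddings
```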
106 | 107 | ## Cite 108 | 109 | If you found this work useful, please consider citing: 110 | 111 | ``` 112 | @article{frantar-gptq, 113 | title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, 114 | author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh}, 115 | year={2022}, 116 | journal={arXiv preprint arXiv:2210.17323} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /bloom.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from gptq import * 9 | from modelutils import * 10 | from quant import * 11 | 12 | 13 | def get_bloom(model): 14 | import torch 15 | def skip(*args, **kwargs): 16 | pass 17 | torch.nn.init.kaiming_uniform_ = skip 18 | torch.nn.init.uniform_ = skip 19 | torch.nn.init.normal_ = skip 20 | from transformers import BloomForCausalLM 21 | model = BloomForCausalLM.from_pretrained(model, torch_dtype='auto') 22 | model.seqlen = 2048 23 | return model 24 | 25 | @torch.no_grad() 26 | def bloom_sequential(model, dataloader, dev, means=None, stds=None): 27 | print('Starting ...') 28 | 29 | use_cache = model.config.use_cache 30 | model.config.use_cache = False 31 | layers = model.transformer.h 32 | 33 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(dev) 34 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(dev) 35 | layers[0] = layers[0].to(dev) 36 | 37 | dtype = next(iter(model.parameters())).dtype 38 | inps = torch.zeros( 39 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 40 | ) 41 | cache = {'i': 0, 'attention_mask': None, 'alibi': None} 42 | 43 | class Catcher(nn.Module): 44 | def __init__(self, module): 45 | super().__init__() 46 | self.module = module 47 | def forward(self, inp, **kwargs): 48 | inps[cache['i']] = inp 49 | cache['i'] += 1 50 | cache['attention_mask'] = kwargs['attention_mask'] 51 | cache['alibi'] = kwargs['alibi'] 52 | raise ValueError 53 | layers[0] = Catcher(layers[0]) 54 | for batch in dataloader: 55 | try: 56 | model(batch[0].to(dev)) 57 | except ValueError: 58 | pass 59 | layers[0] = layers[0].module 60 | 61 | layers[0] = layers[0].cpu() 62 | model.transformer.word_embeddings = model.transformer.word_embeddings.cpu() 63 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.cpu() 64 | torch.cuda.empty_cache() 65 | 66 | outs = torch.zeros_like(inps) 67 | attention_mask = cache['attention_mask'] 68 | alibi = cache['alibi'] 69 | 70 | print('Ready.') 71 | 72 | quantizers = {} 73 | for i in range(len(layers)): 74 | layer = layers[i].to(dev) 75 | 76 | subset = find_layers(layer) 77 | gptq = {} 78 | for name in subset: 79 | gptq[name] = GPTQ(subset[name]) 80 | gptq[name].quantizer = Quantizer() 81 | gptq[name].quantizer.configure( 82 | args.wbits, perchannel=True, sym=args.sym, mse=False 83 | ) 84 | 85 | def add_batch(name): 86 | def tmp(_, inp, out): 87 | gptq[name].add_batch(inp[0].data, out.data) 88 | return tmp 89 | handles = [] 90 | for name in subset: 91 | handles.append(subset[name].register_forward_hook(add_batch(name))) 92 | for j in range(args.nsamples): 93 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 94 | for h in handles: 95 | h.remove() 96 | 97 | for name in subset: 98 | print(i, name) 99 | print('Quantizing ...') 
100 | gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize) 101 | quantizers['transformer.h.%d.%s' % (i, name)] = gptq[name].quantizer 102 | for j in range(args.nsamples): 103 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 104 | 105 | layers[i] = layer.cpu() 106 | del gptq 107 | torch.cuda.empty_cache() 108 | 109 | inps, outs = outs, inps 110 | 111 | model.config.use_cache = use_cache 112 | 113 | return quantizers 114 | 115 | @torch.no_grad() 116 | def bloom_eval(model, testenc, dev): 117 | print('Evaluation...') 118 | 119 | testenc = testenc.input_ids 120 | nsamples = testenc.numel() // model.seqlen 121 | 122 | use_cache = model.config.use_cache 123 | model.config.use_cache = False 124 | layers = model.transformer.h 125 | 126 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(dev) 127 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(dev) 128 | layers[0] = layers[0].to(dev) 129 | 130 | dtype = next(iter(model.parameters())).dtype 131 | inps = torch.zeros( 132 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 133 | ) 134 | cache = {'i': 0, 'attention_mask': None, 'alibi': None} 135 | 136 | class Catcher(nn.Module): 137 | def __init__(self, module): 138 | super().__init__() 139 | self.module = module 140 | def forward(self, inp, **kwargs): 141 | inps[cache['i']] = inp 142 | cache['i'] += 1 143 | cache['attention_mask'] = kwargs['attention_mask'] 144 | cache['alibi'] = kwargs['alibi'] 145 | raise ValueError 146 | layers[0] = Catcher(layers[0]) 147 | for i in range(nsamples): 148 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 149 | try: 150 | model(batch) 151 | except ValueError: 152 | pass 153 | layers[0] = layers[0].module 154 | 155 | layers[0] = layers[0].cpu() 156 | model.transformer.word_embeddings = model.transformer.word_embeddings.cpu() 157 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.cpu() 158 | torch.cuda.empty_cache() 159 | 160 | outs = torch.zeros_like(inps) 161 | attention_mask = cache['attention_mask'] 162 | alibi = cache['alibi'] 163 | 164 | for i in range(len(layers)): 165 | print(i) 166 | layer = layers[i].to(dev) 167 | 168 | if args.nearest: 169 | subset = find_layers(layer) 170 | for name in subset: 171 | quantizer = Quantizer() 172 | quantizer.configure( 173 | args.wbits, perchannel=True, sym=args.sym, mse=False 174 | ) 175 | W = subset[name].weight.data 176 | quantizer.find_params(W, weight=True) 177 | subset[name].weight.data = quantize( 178 | W, quantizer.scale, quantizer.zero, quantizer.maxq 179 | ).to(next(iter(layer.parameters())).dtype) 180 | 181 | for j in range(nsamples): 182 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 183 | layers[i] = layer.cpu() 184 | del layer 185 | torch.cuda.empty_cache() 186 | inps, outs = outs, inps 187 | 188 | model.transformer.ln_f = model.transformer.ln_f.to(dev) 189 | model.lm_head = model.lm_head.to(dev) 190 | 191 | testenc = testenc.to(dev) 192 | nlls = [] 193 | for i in range(nsamples): 194 | hidden_states = inps[i].unsqueeze(0) 195 | hidden_states = model.transformer.ln_f(hidden_states) 196 | lm_logits = model.lm_head(hidden_states) 197 | shift_logits = lm_logits[:, :-1, :].contiguous() 198 | shift_labels = testenc[ 199 | :, (i * model.seqlen):((i + 1) * model.seqlen) 200 | ][:, 1:] 201 | loss_fct = nn.CrossEntropyLoss() 202 | loss = loss_fct(shift_logits.view(-1, 
shift_logits.size(-1)), shift_labels.view(-1)) 203 | neg_log_likelihood = loss.float() * model.seqlen 204 | nlls.append(neg_log_likelihood) 205 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 206 | print(ppl.item()) 207 | 208 | model.config.use_cache = use_cache 209 | 210 | 211 | def bloom_pack3(model, quantizers): 212 | layers = find_layers(model) 213 | layers = {n: layers[n] for n in quantizers} 214 | make_quant3(model, quantizers) 215 | qlayers = find_layers(model, [Quant3Linear]) 216 | print('Packing ...') 217 | for name in qlayers: 218 | print(name) 219 | quantizers[name] = quantizers[name].cpu() 220 | qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) 221 | print('Done.') 222 | return model 223 | 224 | 225 | if __name__ == '__main__': 226 | import argparse 227 | from datautils import * 228 | 229 | parser = argparse.ArgumentParser() 230 | 231 | parser.add_argument( 232 | 'model', type=str, 233 | help='BLOOM model to load; pass `bigscience/bloom-X`.' 234 | ) 235 | parser.add_argument( 236 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], 237 | help='Where to extract calibration data from.' 238 | ) 239 | parser.add_argument( 240 | '--seed', 241 | type=int, default=0, help='Seed for sampling the calibration data.' 242 | ) 243 | parser.add_argument( 244 | '--nsamples', type=int, default=128, 245 | help='Number of calibration data samples.' 246 | ) 247 | parser.add_argument( 248 | '--percdamp', type=float, default=.01, 249 | help='Percent of the average Hessian diagonal to use for dampening.' 250 | ) 251 | parser.add_argument( 252 | '--nearest', action='store_true', 253 | help='Whether to run the RTN baseline.' 254 | ) 255 | parser.add_argument( 256 | '--wbits', type=int, default=16, choices=[2, 3, 4, 16], 257 | help='#bits to use for quantization; use 16 for evaluating base model.' 258 | ) 259 | parser.add_argument( 260 | '--groupsize', type=int, default=-1, 261 | help='Groupsize to use for quantization; default uses full row.' 262 | ) 263 | parser.add_argument( 264 | '--sym', action='store_true', 265 | help='Whether to perform symmetric quantization.' 266 | ) 267 | parser.add_argument( 268 | '--save', type=str, default='', 269 | help='Save quantized checkpoint under this name.' 
270 | ) 271 | parser.add_argument( 272 | '--new-eval', action='store_true', 273 | help='Whether to use the new PTB and C4 eval' 274 | ) 275 | 276 | 277 | args = parser.parse_args() 278 | 279 | model = get_bloom(args.model) 280 | model.eval() 281 | 282 | dataloader, testloader = get_loaders( 283 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 284 | ) 285 | 286 | if args.wbits < 16 and not args.nearest: 287 | tick = time.time() 288 | quantizers = bloom_sequential(model, dataloader, DEV) 289 | print(time.time() - tick) 290 | 291 | datasets = ['wikitext2', 'ptb', 'c4'] 292 | if args.new_eval: 293 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 294 | for dataset in datasets: 295 | dataloader, testloader = get_loaders( 296 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 297 | ) 298 | print(dataset) 299 | bloom_eval(model, testloader, DEV) 300 | 301 | if args.save: 302 | bloom_pack3(model, quantizers) 303 | torch.save(model.state_dict(), args.save) 304 | -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | 15 | from transformers import AutoTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 17 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | 20 | import random 21 | random.seed(seed) 22 | trainloader = [] 23 | for _ in range(nsamples): 24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 25 | j = i + seqlen 26 | inp = trainenc.input_ids[:, i:j] 27 | tar = inp.clone() 28 | tar[:, :-1] = -100 29 | trainloader.append((inp, tar)) 30 | return trainloader, testenc 31 | 32 | def get_ptb(nsamples, seed, seqlen, model): 33 | from datasets import load_dataset 34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 35 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 36 | 37 | from transformers import AutoTokenizer 38 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 39 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 40 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 41 | 42 | import random 43 | random.seed(seed) 44 | trainloader = [] 45 | for _ in range(nsamples): 46 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 47 | j = i + seqlen 48 | inp = trainenc.input_ids[:, i:j] 49 | tar = inp.clone() 50 | tar[:, :-1] = -100 51 | trainloader.append((inp, tar)) 52 | return trainloader, testenc 53 | 54 | def get_c4(nsamples, seed, seqlen, model): 55 | from datasets import load_dataset 56 | traindata = load_dataset( 57 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' 58 | ) 59 | valdata = load_dataset( 60 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 61 | ) 62 | 63 | from transformers import AutoTokenizer 64 | tokenizer = 
AutoTokenizer.from_pretrained(model, use_fast=False) 65 | 66 | import random 67 | random.seed(seed) 68 | trainloader = [] 69 | for _ in range(nsamples): 70 | while True: 71 | i = random.randint(0, len(traindata) - 1) 72 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 73 | if trainenc.input_ids.shape[1] >= seqlen: 74 | break 75 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 76 | j = i + seqlen 77 | inp = trainenc.input_ids[:, i:j] 78 | tar = inp.clone() 79 | tar[:, :-1] = -100 80 | trainloader.append((inp, tar)) 81 | 82 | import random 83 | random.seed(0) 84 | valenc = [] 85 | for _ in range(256): 86 | while True: 87 | i = random.randint(0, len(valdata) - 1) 88 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 89 | if tmp.input_ids.shape[1] >= seqlen: 90 | break 91 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 92 | j = i + seqlen 93 | valenc.append(tmp.input_ids[:, i:j]) 94 | valenc = torch.hstack(valenc) 95 | class TokenizerWrapper: 96 | def __init__(self, input_ids): 97 | self.input_ids = input_ids 98 | valenc = TokenizerWrapper(valenc) 99 | 100 | return trainloader, valenc 101 | 102 | def get_ptb_new(nsamples, seed, seqlen, model): 103 | from datasets import load_dataset 104 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 105 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 106 | 107 | from transformers import AutoTokenizer 108 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 109 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 110 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 111 | 112 | import random 113 | random.seed(seed) 114 | trainloader = [] 115 | for _ in range(nsamples): 116 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 117 | j = i + seqlen 118 | inp = trainenc.input_ids[:, i:j] 119 | tar = inp.clone() 120 | tar[:, :-1] = -100 121 | trainloader.append((inp, tar)) 122 | return trainloader, testenc 123 | 124 | def get_c4_new(nsamples, seed, seqlen, model): 125 | from datasets import load_dataset 126 | traindata = load_dataset( 127 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' 128 | ) 129 | valdata = load_dataset( 130 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 131 | ) 132 | 133 | from transformers import AutoTokenizer 134 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 135 | 136 | import random 137 | random.seed(seed) 138 | trainloader = [] 139 | for _ in range(nsamples): 140 | while True: 141 | i = random.randint(0, len(traindata) - 1) 142 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 143 | if trainenc.input_ids.shape[1] >= seqlen: 144 | break 145 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 146 | j = i + seqlen 147 | inp = trainenc.input_ids[:, i:j] 148 | tar = inp.clone() 149 | tar[:, :-1] = -100 150 | trainloader.append((inp, tar)) 151 | 152 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 153 | valenc = valenc.input_ids[:, :(256 * seqlen)] 154 | 155 | class TokenizerWrapper: 156 | def __init__(self, input_ids): 157 | self.input_ids = input_ids 158 | valenc = TokenizerWrapper(valenc) 159 | 160 | return trainloader, valenc 161 | 162 | 163 | def get_loaders( 164 | name, nsamples=128, seed=0, seqlen=2048, model='' 165 | ): 166 | if 'wikitext2' in name: 167 
| return get_wikitext2(nsamples, seed, seqlen, model) 168 | if 'ptb' in name: 169 | if 'new' in name: 170 | return get_ptb_new(nsamples, seed, seqlen, model) 171 | return get_ptb(nsamples, seed, seqlen, model) 172 | if 'c4' in name: 173 | if 'new' in name: 174 | return get_c4_new(nsamples, seed, seqlen, model) 175 | return get_c4(nsamples, seed, seqlen, model) 176 | -------------------------------------------------------------------------------- /gptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from quant import * 9 | 10 | 11 | DEBUG = False 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | 16 | 17 | class GPTQ: 18 | 19 | def __init__(self, layer): 20 | self.layer = layer 21 | self.dev = self.layer.weight.device 22 | W = layer.weight.data.clone() 23 | if isinstance(self.layer, nn.Conv2d): 24 | W = W.flatten(1) 25 | if isinstance(self.layer, transformers.Conv1D): 26 | W = W.t() 27 | self.rows = W.shape[0] 28 | self.columns = W.shape[1] 29 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 30 | self.nsamples = 0 31 | 32 | def add_batch(self, inp, out): 33 | if DEBUG: 34 | self.inp1 = inp 35 | self.out1 = out 36 | if len(inp.shape) == 2: 37 | inp = inp.unsqueeze(0) 38 | tmp = inp.shape[0] 39 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 40 | if len(inp.shape) == 3: 41 | inp = inp.reshape((-1, inp.shape[-1])) 42 | inp = inp.t() 43 | if isinstance(self.layer, nn.Conv2d): 44 | unfold = nn.Unfold( 45 | self.layer.kernel_size, 46 | dilation=self.layer.dilation, 47 | padding=self.layer.padding, 48 | stride=self.layer.stride 49 | ) 50 | inp = unfold(inp) 51 | inp = inp.permute([1, 0, 2]) 52 | inp = inp.flatten(1) 53 | self.H *= self.nsamples / (self.nsamples + tmp) 54 | self.nsamples += tmp 55 | # inp = inp.float() 56 | inp = math.sqrt(2 / self.nsamples) * inp.float() 57 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 58 | self.H += inp.matmul(inp.t()) 59 | 60 | def fasterquant( 61 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 62 | ): 63 | W = self.layer.weight.data.clone() 64 | if isinstance(self.layer, nn.Conv2d): 65 | W = W.flatten(1) 66 | if isinstance(self.layer, transformers.Conv1D): 67 | W = W.t() 68 | W = W.float() 69 | 70 | tick = time.time() 71 | 72 | if not self.quantizer.ready(): 73 | self.quantizer.find_params(W, weight=True) 74 | 75 | H = self.H 76 | del self.H 77 | dead = torch.diag(H) == 0 78 | H[dead, dead] = 1 79 | W[:, dead] = 0 80 | 81 | if static_groups: 82 | import copy 83 | groups = [] 84 | for i in range(0, self.columns, groupsize): 85 | quantizer = copy.deepcopy(self.quantizer) 86 | quantizer.find_params(W[:, i:(i + groupsize)], weight=True) 87 | groups.append(quantizer) 88 | 89 | if actorder: 90 | perm = torch.argsort(torch.diag(H), descending=True) 91 | W = W[:, perm] 92 | H = H[perm][:, perm] 93 | invperm = torch.argsort(perm) 94 | 95 | Losses = torch.zeros_like(W) 96 | Q = torch.zeros_like(W) 97 | 98 | damp = percdamp * torch.mean(torch.diag(H)) 99 | diag = torch.arange(self.columns, device=self.dev) 100 | H[diag, diag] += damp 101 | H = torch.linalg.cholesky(H) 102 | H = torch.cholesky_inverse(H) 103 | H = torch.linalg.cholesky(H, upper=True) 104 | Hinv = H 105 | 106 | for i1 in range(0, self.columns, blocksize): 107 | i2 = min(i1 + blocksize, self.columns) 108 
| count = i2 - i1 109 | 110 | W1 = W[:, i1:i2].clone() 111 | Q1 = torch.zeros_like(W1) 112 | Err1 = torch.zeros_like(W1) 113 | Losses1 = torch.zeros_like(W1) 114 | Hinv1 = Hinv[i1:i2, i1:i2] 115 | 116 | for i in range(count): 117 | w = W1[:, i] 118 | d = Hinv1[i, i] 119 | 120 | if groupsize != -1: 121 | if not static_groups: 122 | if (i1 + i) % groupsize == 0: 123 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 124 | else: 125 | idx = i1 + i 126 | if actorder: 127 | idx = perm[idx] 128 | self.quantizer = groups[idx // groupsize] 129 | 130 | q = quantize( 131 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 132 | ).flatten() 133 | Q1[:, i] = q 134 | Losses1[:, i] = (w - q) ** 2 / d ** 2 135 | 136 | err1 = (w - q) / d 137 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 138 | Err1[:, i] = err1 139 | 140 | Q[:, i1:i2] = Q1 141 | Losses[:, i1:i2] = Losses1 / 2 142 | 143 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 144 | 145 | if DEBUG: 146 | self.layer.weight.data[:, :i2] = Q[:, :i2] 147 | self.layer.weight.data[:, i2:] = W[:, i2:] 148 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 149 | print(torch.sum(Losses)) 150 | 151 | torch.cuda.synchronize() 152 | print('time %.2f' % (time.time() - tick)) 153 | print('error', torch.sum(Losses).item()) 154 | 155 | if actorder: 156 | Q = Q[:, invperm] 157 | 158 | if isinstance(self.layer, transformers.Conv1D): 159 | Q = Q.t() 160 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 161 | if DEBUG: 162 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 163 | 164 | def free(self): 165 | if DEBUG: 166 | self.inp1 = None 167 | self.out1 = None 168 | self.H = None 169 | self.Losses = None 170 | self.Trace = None 171 | torch.cuda.empty_cache() 172 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gptq import * 7 | from modelutils import * 8 | from quant import * 9 | 10 | 11 | def get_llama(model): 12 | import torch 13 | def skip(*args, **kwargs): 14 | pass 15 | torch.nn.init.kaiming_uniform_ = skip 16 | torch.nn.init.uniform_ = skip 17 | torch.nn.init.normal_ = skip 18 | from transformers import LlamaForCausalLM 19 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 20 | model.seqlen = 2048 21 | return model 22 | 23 | @torch.no_grad() 24 | def llama_sequential(model, dataloader, dev): 25 | print('Starting ...') 26 | 27 | use_cache = model.config.use_cache 28 | model.config.use_cache = False 29 | layers = model.model.layers 30 | 31 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 32 | model.model.norm = model.model.norm.to(dev) 33 | layers[0] = layers[0].to(dev) 34 | 35 | dtype = next(iter(model.parameters())).dtype 36 | inps = torch.zeros( 37 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 38 | ) 39 | cache = {'i': 0, 'attention_mask': None} 40 | 41 | class Catcher(nn.Module): 42 | def __init__(self, module): 43 | super().__init__() 44 | self.module = module 45 | def forward(self, inp, **kwargs): 46 | inps[cache['i']] = inp 47 | cache['i'] += 1 48 | cache['attention_mask'] = kwargs['attention_mask'] 49 | cache['position_ids'] = kwargs['position_ids'] 50 | raise ValueError 51 | layers[0] = Catcher(layers[0]) 52 | for batch in dataloader: 53 | try: 54 | 
model(batch[0].to(dev)) 55 | except ValueError: 56 | pass 57 | layers[0] = layers[0].module 58 | 59 | layers[0] = layers[0].cpu() 60 | model.model.embed_tokens = model.model.embed_tokens.cpu() 61 | model.model.norm = model.model.norm.cpu() 62 | torch.cuda.empty_cache() 63 | 64 | outs = torch.zeros_like(inps) 65 | attention_mask = cache['attention_mask'] 66 | position_ids = cache['position_ids'] 67 | 68 | print('Ready.') 69 | 70 | quantizers = {} 71 | for i in range(len(layers)): 72 | layer = layers[i].to(dev) 73 | full = find_layers(layer) 74 | 75 | if args.true_sequential: 76 | sequential = [ 77 | ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], 78 | ['self_attn.o_proj'], 79 | ['mlp.up_proj', 'mlp.gate_proj'], 80 | ['mlp.down_proj'] 81 | ] 82 | else: 83 | sequential = [list(full.keys())] 84 | 85 | for names in sequential: 86 | subset = {n: full[n] for n in names} 87 | 88 | gptq = {} 89 | for name in subset: 90 | gptq[name] = GPTQ(subset[name]) 91 | gptq[name].quantizer = Quantizer() 92 | gptq[name].quantizer.configure( 93 | args.wbits, perchannel=True, sym=args.sym, mse=False 94 | ) 95 | 96 | def add_batch(name): 97 | def tmp(_, inp, out): 98 | gptq[name].add_batch(inp[0].data, out.data) 99 | return tmp 100 | handles = [] 101 | for name in subset: 102 | handles.append(subset[name].register_forward_hook(add_batch(name))) 103 | for j in range(args.nsamples): 104 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 105 | for h in handles: 106 | h.remove() 107 | 108 | for name in subset: 109 | print(i, name) 110 | print('Quantizing ...') 111 | gptq[name].fasterquant( 112 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups 113 | ) 114 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 115 | gptq[name].free() 116 | 117 | for j in range(args.nsamples): 118 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 119 | 120 | layers[i] = layer.cpu() 121 | del layer 122 | del gptq 123 | torch.cuda.empty_cache() 124 | 125 | inps, outs = outs, inps 126 | 127 | model.config.use_cache = use_cache 128 | 129 | return quantizers 130 | 131 | @torch.no_grad() 132 | def llama_eval(model, testenc, dev): 133 | print('Evaluating ...') 134 | 135 | testenc = testenc.input_ids 136 | nsamples = testenc.numel() // model.seqlen 137 | 138 | use_cache = model.config.use_cache 139 | model.config.use_cache = False 140 | layers = model.model.layers 141 | 142 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 143 | layers[0] = layers[0].to(dev) 144 | 145 | dtype = next(iter(model.parameters())).dtype 146 | inps = torch.zeros( 147 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 148 | ) 149 | cache = {'i': 0, 'attention_mask': None} 150 | 151 | class Catcher(nn.Module): 152 | def __init__(self, module): 153 | super().__init__() 154 | self.module = module 155 | def forward(self, inp, **kwargs): 156 | inps[cache['i']] = inp 157 | cache['i'] += 1 158 | cache['attention_mask'] = kwargs['attention_mask'] 159 | cache['position_ids'] = kwargs['position_ids'] 160 | raise ValueError 161 | layers[0] = Catcher(layers[0]) 162 | for i in range(nsamples): 163 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 164 | try: 165 | model(batch) 166 | except ValueError: 167 | pass 168 | layers[0] = layers[0].module 169 | 170 | layers[0] = layers[0].cpu() 171 | model.model.embed_tokens = 
model.model.embed_tokens.cpu() 172 | torch.cuda.empty_cache() 173 | 174 | outs = torch.zeros_like(inps) 175 | attention_mask = cache['attention_mask'] 176 | position_ids = cache['position_ids'] 177 | 178 | for i in range(len(layers)): 179 | print(i) 180 | layer = layers[i].to(dev) 181 | 182 | if args.nearest: 183 | subset = find_layers(layer) 184 | for name in subset: 185 | quantizer = Quantizer() 186 | quantizer.configure( 187 | args.wbits, perchannel=True, sym=False, mse=False 188 | ) 189 | W = subset[name].weight.data 190 | quantizer.find_params(W, weight=True) 191 | subset[name].weight.data = quantize( 192 | W, quantizer.scale, quantizer.zero, quantizer.maxq 193 | ).to(next(iter(layer.parameters())).dtype) 194 | 195 | for j in range(nsamples): 196 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 197 | layers[i] = layer.cpu() 198 | del layer 199 | torch.cuda.empty_cache() 200 | inps, outs = outs, inps 201 | 202 | if model.model.norm is not None: 203 | model.model.norm = model.model.norm.to(dev) 204 | model.lm_head = model.lm_head.to(dev) 205 | 206 | testenc = testenc.to(dev) 207 | nlls = [] 208 | for i in range(nsamples): 209 | hidden_states = inps[i].unsqueeze(0) 210 | if model.model.norm is not None: 211 | hidden_states = model.model.norm(hidden_states) 212 | lm_logits = model.lm_head(hidden_states) 213 | shift_logits = lm_logits[:, :-1, :].contiguous() 214 | shift_labels = testenc[ 215 | :, (i * model.seqlen):((i + 1) * model.seqlen) 216 | ][:, 1:] 217 | loss_fct = nn.CrossEntropyLoss() 218 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 219 | neg_log_likelihood = loss.float() * model.seqlen 220 | nlls.append(neg_log_likelihood) 221 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 222 | print(ppl.item()) 223 | 224 | model.config.use_cache = use_cache 225 | 226 | def llama_pack3(model, quantizers): 227 | layers = find_layers(model) 228 | layers = {n: layers[n] for n in quantizers} 229 | make_quant3(model, quantizers) 230 | qlayers = find_layers(model, [Quant3Linear]) 231 | print('Packing ...') 232 | for name in qlayers: 233 | print(name) 234 | quantizers[name] = quantizers[name].cpu() 235 | qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) 236 | print('Done.') 237 | return model 238 | 239 | 240 | if __name__ == '__main__': 241 | import argparse 242 | from datautils import * 243 | 244 | parser = argparse.ArgumentParser() 245 | 246 | parser.add_argument( 247 | 'model', type=str, 248 | help='LlaMa model to load; pass location of hugginface converted checkpoint.' 249 | ) 250 | parser.add_argument( 251 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], 252 | help='Where to extract calibration data from.' 253 | ) 254 | parser.add_argument( 255 | '--seed', 256 | type=int, default=0, help='Seed for sampling the calibration data.' 257 | ) 258 | parser.add_argument( 259 | '--nsamples', type=int, default=128, 260 | help='Number of calibration data samples.' 261 | ) 262 | parser.add_argument( 263 | '--percdamp', type=float, default=.01, 264 | help='Percent of the average Hessian diagonal to use for dampening.' 265 | ) 266 | parser.add_argument( 267 | '--nearest', action='store_true', 268 | help='Whether to run the RTN baseline.' 269 | ) 270 | parser.add_argument( 271 | '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], 272 | help='#bits to use for quantization; use 16 for evaluating base model.' 
273 | ) 274 | parser.add_argument( 275 | '--groupsize', type=int, default=-1, 276 | help='Groupsize to use for quantization; default uses full row.' 277 | ) 278 | parser.add_argument( 279 | '--sym', action='store_true', 280 | help='Whether to perform symmetric quantization.' 281 | ) 282 | parser.add_argument( 283 | '--save', type=str, default='', 284 | help='Save quantized checkpoint under this name.' 285 | ) 286 | parser.add_argument( 287 | '--new-eval', action='store_true', 288 | help='Whether to use the new PTB and C4 eval.' 289 | ) 290 | parser.add_argument( 291 | '--act-order', action='store_true', 292 | help='Whether to apply the activation order GPTQ heuristic.' 293 | ) 294 | parser.add_argument( 295 | '--true-sequential', action='store_true', 296 | help='Whether to run in true sequential mode.' 297 | ) 298 | parser.add_argument( 299 | '--static-groups', action='store_true', 300 | help='Whether to use static groups; recommended when using `--act-order` for more efficient inference.' 301 | ) 302 | 303 | args = parser.parse_args() 304 | 305 | model = get_llama(args.model) 306 | model.eval() 307 | 308 | dataloader, testloader = get_loaders( 309 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 310 | ) 311 | 312 | if args.wbits < 16 and not args.nearest: 313 | tick = time.time() 314 | quantizers = llama_sequential(model, dataloader, DEV) 315 | print(time.time() - tick) 316 | 317 | datasets = ['wikitext2', 'ptb', 'c4'] 318 | if args.new_eval: 319 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 320 | for dataset in datasets: 321 | dataloader, testloader = get_loaders( 322 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 323 | ) 324 | print(dataset) 325 | llama_eval(model, testloader, DEV) 326 | 327 | if args.save: 328 | llama_pack3(model, quantizers) 329 | torch.save(model.state_dict(), args.save) 330 | 331 | -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | DEV = torch.device('cuda:0') 6 | 7 | 8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 9 | if type(module) in layers: 10 | return {name: module} 11 | res = {} 12 | for name1, child in module.named_children(): 13 | res.update(find_layers( 14 | child, layers=layers, name=name + '.'
+ name1 if name != '' else name1 15 | )) 16 | return res 17 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gptq import * 7 | from modelutils import * 8 | from quant import * 9 | 10 | 11 | def get_opt(model): 12 | import torch 13 | def skip(*args, **kwargs): 14 | pass 15 | torch.nn.init.kaiming_uniform_ = skip 16 | torch.nn.init.uniform_ = skip 17 | torch.nn.init.normal_ = skip 18 | from transformers import OPTForCausalLM 19 | model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') 20 | model.seqlen = model.config.max_position_embeddings 21 | return model 22 | 23 | @torch.no_grad() 24 | def opt_sequential(model, dataloader, dev): 25 | print('Starting ...') 26 | 27 | use_cache = model.config.use_cache 28 | model.config.use_cache = False 29 | layers = model.model.decoder.layers 30 | 31 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 32 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 33 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 34 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 35 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 36 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 37 | layers[0] = layers[0].to(dev) 38 | 39 | dtype = next(iter(model.parameters())).dtype 40 | inps = torch.zeros( 41 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 42 | ) 43 | cache = {'i': 0, 'attention_mask': None} 44 | 45 | class Catcher(nn.Module): 46 | def __init__(self, module): 47 | super().__init__() 48 | self.module = module 49 | def forward(self, inp, **kwargs): 50 | inps[cache['i']] = inp 51 | cache['i'] += 1 52 | cache['attention_mask'] = kwargs['attention_mask'] 53 | raise ValueError 54 | layers[0] = Catcher(layers[0]) 55 | for batch in dataloader: 56 | try: 57 | model(batch[0].to(dev)) 58 | except ValueError: 59 | pass 60 | layers[0] = layers[0].module 61 | 62 | layers[0] = layers[0].cpu() 63 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 64 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 65 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 66 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 67 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 68 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 69 | torch.cuda.empty_cache() 70 | 71 | outs = torch.zeros_like(inps) 72 | attention_mask = cache['attention_mask'] 73 | 74 | print('Ready.') 75 | 76 | quantizers = {} 77 | for i in range(len(layers)): 78 | layer = layers[i].to(dev) 79 | 80 | subset = find_layers(layer) 81 | gptq = {} 82 | for name in subset: 83 | gptq[name] = GPTQ(subset[name]) 84 | gptq[name].quantizer = Quantizer() 85 | gptq[name].quantizer.configure( 86 | args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits 87 | ) 88 | 89 | def add_batch(name): 90 | def tmp(_, inp, out): 91 | gptq[name].add_batch(inp[0].data, out.data) 92 | return tmp 93 | handles = [] 94 | for name in subset: 95 | handles.append(subset[name].register_forward_hook(add_batch(name))) 96 | for j in range(args.nsamples): 97 | outs[j] = 
layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 98 | for h in handles: 99 | h.remove() 100 | 101 | for name in subset: 102 | print(i, name) 103 | print('Quantizing ...') 104 | gptq[name].fasterquant( 105 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups 106 | ) 107 | quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer 108 | gptq[name].free() 109 | for j in range(args.nsamples): 110 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 111 | 112 | layers[i] = layer.cpu() 113 | del layer 114 | del gptq 115 | torch.cuda.empty_cache() 116 | 117 | inps, outs = outs, inps 118 | 119 | model.config.use_cache = use_cache 120 | 121 | return quantizers 122 | 123 | @torch.no_grad() 124 | def opt_eval(model, testenc, dev): 125 | print('Evaluating ...') 126 | 127 | testenc = testenc.input_ids 128 | nsamples = testenc.numel() // model.seqlen 129 | 130 | use_cache = model.config.use_cache 131 | model.config.use_cache = False 132 | layers = model.model.decoder.layers 133 | 134 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 135 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 136 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 137 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 138 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 139 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 140 | layers[0] = layers[0].to(dev) 141 | 142 | dtype = next(iter(model.parameters())).dtype 143 | inps = torch.zeros( 144 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 145 | ) 146 | cache = {'i': 0, 'attention_mask': None} 147 | 148 | class Catcher(nn.Module): 149 | def __init__(self, module): 150 | super().__init__() 151 | self.module = module 152 | def forward(self, inp, **kwargs): 153 | inps[cache['i']] = inp 154 | cache['i'] += 1 155 | cache['attention_mask'] = kwargs['attention_mask'] 156 | raise ValueError 157 | layers[0] = Catcher(layers[0]) 158 | for i in range(nsamples): 159 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 160 | try: 161 | model(batch) 162 | except ValueError: 163 | pass 164 | layers[0] = layers[0].module 165 | 166 | layers[0] = layers[0].cpu() 167 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 168 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 169 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 170 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 171 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 172 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 173 | torch.cuda.empty_cache() 174 | 175 | outs = torch.zeros_like(inps) 176 | attention_mask = cache['attention_mask'] 177 | 178 | for i in range(len(layers)): 179 | print(i) 180 | layer = layers[i].to(dev) 181 | 182 | if args.nearest: 183 | subset = find_layers(layer) 184 | for name in subset: 185 | quantizer = Quantizer() 186 | quantizer.configure( 187 | args.wbits, perchannel=True, sym=args.sym, mse=False 188 | ) 189 | W = subset[name].weight.data 190 | quantizer.find_params(W, weight=True) 191 | subset[name].weight.data = quantize( 192 | W, quantizer.scale, quantizer.zero, quantizer.maxq 193 | 
).to(next(iter(layer.parameters())).dtype) 194 | 195 | for j in range(nsamples): 196 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 197 | layers[i] = layer.cpu() 198 | del layer 199 | torch.cuda.empty_cache() 200 | inps, outs = outs, inps 201 | 202 | if model.model.decoder.final_layer_norm is not None: 203 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 204 | if model.model.decoder.project_out is not None: 205 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 206 | model.lm_head = model.lm_head.to(dev) 207 | 208 | testenc = testenc.to(dev) 209 | nlls = [] 210 | for i in range(nsamples): 211 | hidden_states = inps[i].unsqueeze(0) 212 | if model.model.decoder.final_layer_norm is not None: 213 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 214 | if model.model.decoder.project_out is not None: 215 | hidden_states = model.model.decoder.project_out(hidden_states) 216 | lm_logits = model.lm_head(hidden_states) 217 | shift_logits = lm_logits[:, :-1, :].contiguous() 218 | shift_labels = testenc[ 219 | :, (i * model.seqlen):((i + 1) * model.seqlen) 220 | ][:, 1:] 221 | loss_fct = nn.CrossEntropyLoss() 222 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 223 | neg_log_likelihood = loss.float() * model.seqlen 224 | nlls.append(neg_log_likelihood) 225 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 226 | print(ppl.item()) 227 | 228 | model.config.use_cache = use_cache 229 | 230 | # TODO: perform packing on GPU 231 | def opt_pack3(model, quantizers): 232 | layers = find_layers(model) 233 | layers = {n: layers[n] for n in quantizers} 234 | make_quant3(model, quantizers, faster=args.faster_kernel) 235 | qlayers = find_layers(model, [Quant3Linear]) 236 | print('Packing ...') 237 | for name in qlayers: 238 | print(name) 239 | quantizers[name] = quantizers[name].cpu() 240 | qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) 241 | print('Done.') 242 | return model 243 | 244 | def load_quant3(model, checkpoint): 245 | from transformers import OPTConfig, OPTForCausalLM 246 | config = OPTConfig.from_pretrained(model) 247 | def noop(*args, **kwargs): 248 | pass 249 | torch.nn.init.kaiming_uniform_ = noop 250 | torch.nn.init.uniform_ = noop 251 | torch.nn.init.normal_ = noop 252 | 253 | torch.set_default_dtype(torch.half) 254 | transformers.modeling_utils._init_weights = False 255 | torch.set_default_dtype(torch.half) 256 | model = OPTForCausalLM(config) 257 | torch.set_default_dtype(torch.float) 258 | model = model.eval() 259 | layers = find_layers(model) 260 | for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']: 261 | if name in layers: 262 | del layers[name] 263 | make_quant3(model, layers, faster=args.faster_kernel) 264 | 265 | print('Loading model ...') 266 | model.load_state_dict(torch.load(checkpoint)) 267 | model.seqlen = model.config.max_position_embeddings 268 | print('Done.') 269 | 270 | return model 271 | 272 | def opt_multigpu(model, gpus): 273 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0]) 274 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0]) 275 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 276 | model.model.decoder.project_in = model.model.decoder.project_in.to(gpus[0]) 277 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 
278 | model.model.decoder.project_out = model.model.decoder.project_out.to(gpus[-1]) 279 | if hasattr(model.model.decoder, 'final_layer_norm') and model.model.decoder.final_layer_norm: 280 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(gpus[-1]) 281 | import copy 282 | model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) 283 | 284 | cache = {'mask': None} 285 | 286 | class MoveModule(nn.Module): 287 | def __init__(self, module): 288 | super().__init__() 289 | self.module = module 290 | self.dev = next(iter(self.module.parameters())).device 291 | def forward(self, *inp, **kwargs): 292 | inp = list(inp) 293 | if inp[0].device != self.dev: 294 | inp[0] = inp[0].to(self.dev) 295 | if cache['mask'] is None or cache['mask'].device != self.dev: 296 | cache['mask'] = kwargs['attention_mask'].to(self.dev) 297 | kwargs['attention_mask'] = cache['mask'] 298 | tmp = self.module(*inp, **kwargs) 299 | return tmp 300 | 301 | layers = model.model.decoder.layers 302 | pergpu = math.ceil(len(layers) / len(gpus)) 303 | for i in range(len(layers)): 304 | layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) 305 | 306 | model.gpus = gpus 307 | 308 | def benchmark(model, input_ids, check=False): 309 | input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) 310 | torch.cuda.synchronize() 311 | 312 | cache = {'past': None} 313 | def clear_past(i): 314 | def tmp(layer, inp, out): 315 | if cache['past']: 316 | cache['past'][i] = None 317 | return tmp 318 | for i, layer in enumerate(model.model.decoder.layers): 319 | layer.register_forward_hook(clear_past(i)) 320 | 321 | print('Benchmarking ...') 322 | 323 | if check: 324 | loss = nn.CrossEntropyLoss() 325 | tot = 0. 326 | 327 | def sync(): 328 | if hasattr(model, 'gpus'): 329 | for gpu in model.gpus: 330 | torch.cuda.synchronize(gpu) 331 | else: 332 | torch.cuda.synchronize() 333 | with torch.no_grad(): 334 | attention_mask = torch.ones((1, input_ids.numel()), device=DEV) 335 | times = [] 336 | for i in range(input_ids.numel()): 337 | tick = time.time() 338 | out = model( 339 | input_ids[:, i].reshape((1,-1)), 340 | past_key_values=cache['past'], 341 | attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)) 342 | ) 343 | sync() 344 | times.append(time.time() - tick) 345 | print(i, times[-1]) 346 | if check and i != input_ids.numel() - 1: 347 | tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() 348 | cache['past'] = list(out.past_key_values) 349 | del out 350 | sync() 351 | import numpy as np 352 | print('Median:', np.median(times)) 353 | if check: 354 | print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) 355 | 356 | 357 | if __name__ == '__main__': 358 | import argparse 359 | from datautils import * 360 | 361 | parser = argparse.ArgumentParser() 362 | 363 | parser.add_argument( 364 | 'model', type=str, 365 | help='OPT model to load; pass `facebook/opt-X`.' 366 | ) 367 | parser.add_argument( 368 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], 369 | help='Where to extract calibration data from.' 370 | ) 371 | parser.add_argument( 372 | '--seed', 373 | type=int, default=0, help='Seed for sampling the calibration data.' 374 | ) 375 | parser.add_argument( 376 | '--nsamples', type=int, default=128, 377 | help='Number of calibration data samples.' 378 | ) 379 | parser.add_argument( 380 | '--percdamp', type=float, default=.01, 381 | help='Percent of the average Hessian diagonal to use for dampening.' 
382 | ) 383 | parser.add_argument( 384 | '--nearest', action='store_true', 385 | help='Whether to run the RTN baseline.' 386 | ) 387 | parser.add_argument( 388 | '--wbits', type=int, default=16, choices=[2, 3, 4, 16], 389 | help='#bits to use for quantization; use 16 for evaluating base model.' 390 | ) 391 | parser.add_argument( 392 | '--trits', action='store_true', 393 | help='Whether to use trits for quantization.' 394 | ) 395 | parser.add_argument( 396 | '--groupsize', type=int, default=-1, 397 | help='Groupsize to use for quantization; default uses full row.' 398 | ) 399 | parser.add_argument( 400 | '--sym', action='store_true', 401 | help='Whether to perform symmetric quantization.' 402 | ) 403 | parser.add_argument( 404 | '--save', type=str, default='', 405 | help='Save quantized checkpoint under this name.' 406 | ) 407 | parser.add_argument( 408 | '--load', type=str, default='', 409 | help='Load quantized model.' 410 | ) 411 | parser.add_argument( 412 | '--benchmark', type=int, default=0, 413 | help='Number of tokens to use for benchmarking.' 414 | ) 415 | parser.add_argument( 416 | '--check', action='store_true', 417 | help='Whether to compute perplexity during benchmarking for verification.' 418 | ) 419 | parser.add_argument( 420 | '--new-eval', action='store_true', 421 | help='Whether to use the new PTB and C4 eval.' 422 | ) 423 | parser.add_argument( 424 | '--faster-kernel', action='store_true', 425 | help='Whether to use the new faster kernel for benchmarking.' 426 | ) 427 | parser.add_argument( 428 | '--act-order', action='store_true', 429 | help='Whether to apply the activation order GPTQ heuristic' 430 | ) 431 | parser.add_argument( 432 | '--static-groups', action='store_true', 433 | help='Whether to use static groups; recommended when using `--actorder` for more efficient inference.' 
434 | ) 435 | 436 | args = parser.parse_args() 437 | 438 | if args.load: 439 | model = load_quant3(args.model, args.load) 440 | else: 441 | model = get_opt(args.model) 442 | model.eval() 443 | 444 | dataloader, testloader = get_loaders( 445 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 446 | ) 447 | 448 | if args.wbits < 16 and not args.nearest: 449 | tick = time.time() 450 | quantizers = opt_sequential(model, dataloader, DEV) 451 | print(time.time() - tick) 452 | 453 | if args.benchmark: 454 | gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] 455 | if len(gpus) > 1: 456 | opt_multigpu(model, gpus) 457 | else: 458 | model = model.to(DEV) 459 | if args.benchmark: 460 | input_ids = next(iter(dataloader))[0][:, :args.benchmark] 461 | benchmark(model, input_ids, check=args.check) 462 | if args.load: 463 | exit() 464 | 465 | datasets = ['wikitext2', 'ptb', 'c4'] 466 | if args.new_eval: 467 | datasets = ['wikitext2', 'ptb-new', 'c4-new'] 468 | for dataset in datasets: 469 | dataloader, testloader = get_loaders( 470 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 471 | ) 472 | print(dataset) 473 | opt_eval(model, testloader, DEV) 474 | 475 | if args.save: 476 | opt_pack3(model, quantizers) 477 | torch.save(model.state_dict(), args.save) 478 | -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def quantize(x, scale, zero, maxq): 7 | if maxq < 0: 8 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 9 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 10 | return scale * (q - zero) 11 | 12 | class Quantizer(nn.Module): 13 | 14 | def __init__(self, shape=1): 15 | super(Quantizer, self).__init__() 16 | self.register_buffer('maxq', torch.tensor(0)) 17 | self.register_buffer('scale', torch.zeros(shape)) 18 | self.register_buffer('zero', torch.zeros(shape)) 19 | 20 | def configure( 21 | self, 22 | bits, perchannel=False, sym=True, 23 | mse=False, norm=2.4, grid=100, maxshrink=.8, 24 | trits=False 25 | ): 26 | self.maxq = torch.tensor(2 ** bits - 1) 27 | self.perchannel = perchannel 28 | self.sym = sym 29 | self.mse = mse 30 | self.norm = norm 31 | self.grid = grid 32 | self.maxshrink = maxshrink 33 | if trits: 34 | self.maxq = torch.tensor(-1) 35 | 36 | def find_params(self, x, weight=False): 37 | dev = x.device 38 | self.maxq = self.maxq.to(dev) 39 | 40 | shape = x.shape 41 | if self.perchannel: 42 | if weight: 43 | x = x.flatten(1) 44 | else: 45 | if len(shape) == 4: 46 | x = x.permute([1, 0, 2, 3]) 47 | x = x.flatten(1) 48 | if len(shape) == 3: 49 | x = x.reshape((-1, shape[-1])).t() 50 | if len(shape) == 2: 51 | x = x.t() 52 | else: 53 | x = x.flatten().unsqueeze(0) 54 | 55 | tmp = torch.zeros(x.shape[0], device=dev) 56 | xmin = torch.minimum(x.min(1)[0], tmp) 57 | xmax = torch.maximum(x.max(1)[0], tmp) 58 | 59 | if self.sym: 60 | xmax = torch.maximum(torch.abs(xmin), xmax) 61 | tmp = xmin < 0 62 | if torch.any(tmp): 63 | xmin[tmp] = -xmax[tmp] 64 | tmp = (xmin == 0) & (xmax == 0) 65 | xmin[tmp] = -1 66 | xmax[tmp] = +1 67 | 68 | if self.maxq < 0: 69 | self.scale = xmax 70 | self.zero = xmin 71 | else: 72 | self.scale = (xmax - xmin) / self.maxq 73 | if self.sym: 74 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 75 | else: 76 | self.zero = torch.round(-xmin / 
self.scale) 77 | 78 | if self.mse: 79 | best = torch.full([x.shape[0]], float('inf'), device=dev) 80 | for i in range(int(self.maxshrink * self.grid)): 81 | p = 1 - i / self.grid 82 | xmin1 = p * xmin 83 | xmax1 = p * xmax 84 | scale1 = (xmax1 - xmin1) / self.maxq 85 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 86 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 87 | q -= x 88 | q.abs_() 89 | q.pow_(self.norm) 90 | err = torch.sum(q, 1) 91 | tmp = err < best 92 | if torch.any(tmp): 93 | best[tmp] = err[tmp] 94 | self.scale[tmp] = scale1[tmp] 95 | self.zero[tmp] = zero1[tmp] 96 | if not self.perchannel: 97 | if weight: 98 | tmp = shape[0] 99 | else: 100 | tmp = shape[1] if len(shape) != 3 else shape[2] 101 | self.scale = self.scale.repeat(tmp) 102 | self.zero = self.zero.repeat(tmp) 103 | 104 | if weight: 105 | shape = [-1] + [1] * (len(shape) - 1) 106 | self.scale = self.scale.reshape(shape) 107 | self.zero = self.zero.reshape(shape) 108 | return 109 | if len(shape) == 4: 110 | self.scale = self.scale.reshape((1, -1, 1, 1)) 111 | self.zero = self.zero.reshape((1, -1, 1, 1)) 112 | if len(shape) == 3: 113 | self.scale = self.scale.reshape((1, 1, -1)) 114 | self.zero = self.zero.reshape((1, 1, -1)) 115 | if len(shape) == 2: 116 | self.scale = self.scale.unsqueeze(0) 117 | self.zero = self.zero.unsqueeze(0) 118 | 119 | def quantize(self, x): 120 | if self.ready(): 121 | return quantize(x, self.scale, self.zero, self.maxq) 122 | return x 123 | 124 | def enabled(self): 125 | return self.maxq > 0 126 | 127 | def ready(self): 128 | return torch.all(self.scale != 0) 129 | 130 | 131 | try: 132 | import quant_cuda 133 | except: 134 | print('CUDA extension not installed.') 135 | 136 | # Assumes layer is perfectly divisible into 1024 * 1024 blocks 137 | class Quant3Linear(nn.Module): 138 | 139 | def __init__(self, infeatures, outfeatures, faster=False): 140 | super().__init__() 141 | self.register_buffer('zeros', torch.zeros((outfeatures, 1))) 142 | self.register_buffer('scales', torch.zeros((outfeatures, 1))) 143 | self.register_buffer('bias', torch.zeros(outfeatures)) 144 | self.register_buffer( 145 | 'qweight', torch.zeros((infeatures // 32 * 3, outfeatures), dtype=torch.int) 146 | ) 147 | self.faster = faster 148 | 149 | def pack(self, linear, scales, zeros): 150 | self.zeros = zeros * scales 151 | self.scales = scales.clone() 152 | if linear.bias is not None: 153 | self.bias = linear.bias.clone() 154 | 155 | intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int) 156 | intweight = intweight.t().contiguous() 157 | intweight = intweight.numpy().astype(np.uint32) 158 | qweight = np.zeros( 159 | (intweight.shape[0] // 32 * 3, intweight.shape[1]), dtype=np.uint32 160 | ) 161 | i = 0 162 | row = 0 163 | while row < qweight.shape[0]: 164 | for j in range(i, i + 10): 165 | qweight[row] |= intweight[j] << (3 * (j - i)) 166 | i += 10 167 | qweight[row] |= intweight[i] << 30 168 | row += 1 169 | qweight[row] |= (intweight[i] >> 2) & 1 170 | i += 1 171 | for j in range(i, i + 10): 172 | qweight[row] |= intweight[j] << (3 * (j - i) + 1) 173 | i += 10 174 | qweight[row] |= intweight[i] << 31 175 | row += 1 176 | qweight[row] |= (intweight[i] >> 1) & 0x3 177 | i += 1 178 | for j in range(i, i + 10): 179 | qweight[row] |= intweight[j] << (3 * (j - i) + 2) 180 | i += 10 181 | row += 1 182 | 183 | qweight = qweight.astype(np.int32) 184 | self.qweight = torch.from_numpy(qweight) 185 | 186 | def forward(self, x): 187 | if x.shape[-1] == 
x.numel(): 188 | outshape = list(x.shape) 189 | y = self.bias.clone() 190 | outshape[-1] = self.bias.numel() 191 | dtype = x.dtype 192 | if self.faster: 193 | x = x.half() 194 | quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros) 195 | else: 196 | x = x.float() 197 | quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros) 198 | y = y.to(dtype) 199 | return y.reshape(outshape) 200 | raise ValueError('Only supports a single token currently.') 201 | 202 | def make_quant3(module, names, name='', faster=False): 203 | if isinstance(module, Quant3Linear): 204 | return 205 | for attr in dir(module): 206 | tmp = getattr(module, attr) 207 | name1 = name + '.' + attr if name != '' else attr 208 | if name1 in names: 209 | setattr( 210 | module, attr, Quant3Linear(tmp.in_features, tmp.out_features, faster=faster) 211 | ) 212 | for name1, child in module.named_children(): 213 | make_quant3(child, names, name + '.' + name1 if name != '' else name1, faster=faster) 214 | -------------------------------------------------------------------------------- /quant_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/all.h> 2 | #include <torch/python.h> 3 | #include <c10/cuda/CUDAGuard.h> 4 | 5 | void vecquant3matmul_cuda( 6 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 7 | torch::Tensor scales, torch::Tensor zeros 8 | ); 9 | 10 | void vecquant3matmul_faster_cuda( 11 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 12 | torch::Tensor scales, torch::Tensor zeros 13 | ); 14 | 15 | void vecquant3matmul( 16 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 17 | torch::Tensor scales, torch::Tensor zeros 18 | ) { 19 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 20 | vecquant3matmul_cuda(vec, mat, mul, scales, zeros); 21 | } 22 | 23 | void vecquant3matmul_faster( 24 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 25 | torch::Tensor scales, torch::Tensor zeros 26 | ) { 27 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 28 | vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros); 29 | } 30 | 31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 32 | m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)"); 33 | m.def("vecquant3matmul_faster", &vecquant3matmul_faster, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version"); 34 | } 35 | -------------------------------------------------------------------------------- /quant_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/all.h> 2 | #include <torch/python.h> 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | #include <cuda_fp16.h> 6 | 7 | template <typename scalar_t> 8 | __global__ void VecQuant3MatMulKernel( 9 | const scalar_t* __restrict__ vec, 10 | const int* __restrict__ mat, 11 | scalar_t* __restrict__ mul, 12 | const scalar_t* __restrict__ scales, 13 | const scalar_t* __restrict__ zeros, 14 | int height, 15 | int width 16 | ); 17 | 18 | __global__ void VecQuant3MatMulKernelFaster( 19 | const half2* __restrict__ vec, 20 | const int* __restrict__ mat, 21 | float* __restrict__ mul, 22 | const float* __restrict__ scales, 23 | const float* __restrict__ zeros, 24 | int height, 25 | int width 26 | ); 27 | 28 | const int BLOCKWIDTH = 256; 29 | const int BLOCKHEIGHT = 24; // 24 packed int32 rows = 768 bits = 256 3-bit weights, i.e. one per BLOCKWIDTH column 30 | 31 | void vecquant3matmul_cuda( 32 | torch::Tensor vec, 33 | torch::Tensor mat, 34 | torch::Tensor mul, 35 | torch::Tensor scales, 36 | torch::Tensor zeros 37 | ) { 38 | int height = mat.size(0); 39 | int width = mat.size(1); 40 | 41 | dim3
blocks( 42 | (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT, 43 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 44 | ); 45 | dim3 threads(BLOCKWIDTH); 46 | 47 | AT_DISPATCH_FLOATING_TYPES( 48 | vec.type(), "vecquant3matmul_cuda", ([&] { 49 | VecQuant3MatMulKernel<<<blocks, threads>>>( 50 | vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(), 51 | scales.data<scalar_t>(), zeros.data<scalar_t>(), 52 | height, width 53 | ); 54 | }) 55 | ); 56 | } 57 | 58 | void vecquant3matmul_faster_cuda( 59 | torch::Tensor vec, 60 | torch::Tensor mat, 61 | torch::Tensor mul, 62 | torch::Tensor scales, 63 | torch::Tensor zeros 64 | ) { 65 | int height = mat.size(0); 66 | int width = mat.size(1); 67 | 68 | dim3 blocks( 69 | (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT, 70 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 71 | ); 72 | dim3 threads(BLOCKWIDTH); 73 | 74 | VecQuant3MatMulKernelFaster<<<blocks, threads>>>( 75 | (half2*) vec.data_ptr(), 76 | mat.data_ptr<int>(), 77 | mul.data_ptr<float>(), 78 | scales.data_ptr<float>(), 79 | zeros.data_ptr<float>(), 80 | height, width 81 | ); 82 | } 83 | 84 | __device__ inline unsigned int as_unsigned(int i) { 85 | return *reinterpret_cast<unsigned int*>(&i); 86 | } 87 | 88 | template <typename scalar_t> 89 | __global__ void VecQuant3MatMulKernel( 90 | const scalar_t* __restrict__ vec, 91 | const int* __restrict__ mat, 92 | scalar_t* __restrict__ mul, 93 | const scalar_t* __restrict__ scales, 94 | const scalar_t* __restrict__ zeros, 95 | int height, 96 | int width 97 | ) { 98 | int row = BLOCKHEIGHT * blockIdx.x; 99 | int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; 100 | 101 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 102 | blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * BLOCKWIDTH + threadIdx.x]; 103 | __syncthreads(); 104 | 105 | scalar_t scale = scales[col]; 106 | scalar_t zero = zeros[col]; 107 | 108 | scalar_t res = 0; 109 | int i = width * row + col; 110 | int k = 0; 111 | 112 | unsigned int tmp1; 113 | unsigned int tmp2; 114 | unsigned int tmp; 115 | 116 | while (k < BLOCKWIDTH) { 117 | tmp1 = as_unsigned(mat[i]); 118 | res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; 119 | res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; 120 | res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; 121 | res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; 122 | res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; 123 | res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; 124 | res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; 125 | res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; 126 | res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; 127 | res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; 128 | i += width; 129 | tmp2 = as_unsigned(mat[i]); 130 | tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); 131 | tmp2 >>= 1; 132 | res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; 133 | k += 11; 134 | res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; 135 | res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; 136 | res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; 137 | res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3]; 138 | res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4]; 139 | res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5]; 140 | res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6]; 141 | res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; 142 | res += (scale * scalar_t((tmp2
>> 24) & 0x7) - zero) * blockvec[k + 8]; 143 | res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; 144 | i += width; 145 | tmp1 = as_unsigned(mat[i]); 146 | tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); 147 | tmp1 >>= 2; 148 | res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; 149 | k += 11; 150 | res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; 151 | res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; 152 | res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; 153 | res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; 154 | res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; 155 | res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; 156 | res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; 157 | res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; 158 | res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; 159 | res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; 160 | i += width; 161 | k += 10; 162 | } 163 | 164 | atomicAdd(&mul[col], res); 165 | } 166 | 167 | __global__ void VecQuant3MatMulKernelFaster( 168 | const half2* __restrict__ vec, 169 | const int* __restrict__ mat, 170 | float* __restrict__ mul, 171 | const float* __restrict__ scales, 172 | const float* __restrict__ zeros, 173 | int height, 174 | int width 175 | ) { 176 | const int blockwidth2 = BLOCKWIDTH / 2; 177 | 178 | int row = BLOCKHEIGHT * blockIdx.x; 179 | int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; 180 | 181 | __shared__ half2 blockvec[blockwidth2]; 182 | if (threadIdx.x < blockwidth2) 183 | blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * blockwidth2 + threadIdx.x]; 184 | 185 | __shared__ half2 deq2[64][32]; 186 | int val = threadIdx.x / 32; 187 | int off = threadIdx.x % 32; 188 | for (; val < 64; val += BLOCKWIDTH / 32) { 189 | deq2[val][off] = __halves2half2( 190 | __int2half_rn(val & 0x7), __int2half_rn(val >> 3) 191 | ); 192 | } 193 | 194 | half2 scale = __float2half2_rn(scales[col]); 195 | half2 zero = __float2half2_rn(-zeros[col]); 196 | 197 | int i = width * row + col; 198 | int k = 0; 199 | 200 | float res = 0; 201 | half2 res2; 202 | 203 | unsigned int tmp1; 204 | unsigned int tmp2; 205 | unsigned int tmp; 206 | 207 | __syncthreads(); 208 | 209 | while (k < blockwidth2) { 210 | res2 = {}; 211 | tmp1 = as_unsigned(mat[i]); 212 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); 213 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); 214 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); 215 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); 216 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); 217 | i += width; 218 | tmp2 = as_unsigned(mat[i]); 219 | tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c); 220 | res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2); 221 | tmp2 >>= 4; 222 | k += 6; 223 | res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); 224 | res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); 225 | res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); 226 | res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 
3], res2); 227 | i += width; 228 | tmp1 = as_unsigned(mat[i]); 229 | tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30); 230 | res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2); 231 | tmp1 >>= 2; 232 | k += 5; 233 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); 234 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2); 235 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2); 236 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2); 237 | res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2); 238 | i += width; 239 | k += 5; 240 | res += __half2float(res2.x) + __half2float(res2.y); 241 | } 242 | 243 | atomicAdd(&mul[col], res); 244 | } 245 | -------------------------------------------------------------------------------- /setup_cuda.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from torch.utils import cpp_extension 3 | 4 | setup( 5 | name='quant_cuda', 6 | ext_modules=[cpp_extension.CUDAExtension( 7 | 'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu'] 8 | )], 9 | cmdclass={'build_ext': cpp_extension.BuildExtension} 10 | ) 11 | -------------------------------------------------------------------------------- /test_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import quant_cuda 5 | 6 | torch.backends.cuda.matmul.allow_tf32 = False 7 | torch.backends.cudnn.allow_tf32 = False 8 | 9 | print('Benchmarking OPT-175B FC2 matvec ...') 10 | 11 | DEV = torch.device('cuda:0') 12 | 13 | M = 12288 * 4 14 | N = 12288 15 | 16 | DTYPE = torch.half 17 | mat = torch.randn((M, N), device=DEV, dtype=DTYPE) 18 | vec = torch.randn((1, M), device=DEV, dtype=DTYPE) 19 | mul = torch.zeros((1, N), device=DEV, dtype=DTYPE) 20 | 21 | COUNT = 1000 22 | import time 23 | tick = time.time() 24 | for _ in range(COUNT): 25 | torch.matmul(vec, mat, out=mul) 26 | torch.cuda.synchronize() 27 | print('FP16:', (time.time() - tick) / COUNT) 28 | 29 | DTYPE = torch.float 30 | mat = mat.to(DTYPE) 31 | vec = vec.to(DTYPE) 32 | mul = mul.to(DTYPE) 33 | 34 | mat = torch.randint(-1000000000, 1000000000, (M // 1024 * 96, N), device=DEV, dtype=torch.int) 35 | scales = torch.randn(N, device=DEV, dtype=DTYPE) 36 | zeros = torch.randn(N, device=DEV, dtype=DTYPE) 37 | 38 | COUNT = 1000 39 | import time 40 | tick = time.time() 41 | for _ in range(COUNT): 42 | quant_cuda.vecquant3matmul(vec, mat, mul, scales, zeros) 43 | torch.cuda.synchronize() 44 | print('3bit:', (time.time() - tick) / COUNT) 45 | 46 | COUNT = 1000 47 | import time 48 | tick = time.time() 49 | for _ in range(COUNT): 50 | quant_cuda.vecquant3matmul_faster(vec, mat, mul, scales, zeros) 51 | torch.cuda.synchronize() 52 | print('3bit:', (time.time() - tick) / COUNT, '(faster)') 53 | 54 | print('Verifiying kernel correctness ...') 55 | 56 | M = 4 * 4096 57 | N = 4096 58 | 59 | layer = nn.Linear(M, N) 60 | vec = torch.randn(M).to(DEV) 61 | 62 | from quant import * 63 | quantizer = Quantizer() 64 | quantizer.configure(3, perchannel=True, sym=False, mse=False) 65 | quantizer.find_params(layer.weight.data, weight=True) 66 | layer.weight.data = quantize( 67 | layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq 68 | ) 69 | 70 | qlayer = Quant3Linear(layer.in_features, 
layer.out_features) 71 | qlayer.pack(layer, quantizer.scale, quantizer.zero) 72 | 73 | qlayer = qlayer.to(DEV) 74 | layer = layer.to(DEV) 75 | 76 | with torch.no_grad(): 77 | print('Simu:', layer.to(DEV)(vec)) 78 | print('Kern:', qlayer(vec)) 79 | qlayer.faster = True 80 | print('Kern:', qlayer(vec.half()), '(faster)') 81 | -------------------------------------------------------------------------------- /zeroShot/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /zeroShot/README.md: -------------------------------------------------------------------------------- 1 | This folder contains code to reproduce the few-shot tasks. We follow the structure of 2 | [this](https://github.com/EleutherAI/lm-evaluation-harness) repository for implementing 3 | our tasks and the evaluation framework. 4 | 5 | We implement the following tasks: 6 | - [x] LAMBADA 7 | - [x] PIQA 8 | - [x] ARC-easy 9 | - [x] ARC-challenge 10 | - [x] COPA 11 | - [x] WSC 12 | - [x] RTE 13 | - [x] CB 14 | - [x] StoryCloze-2018 15 | 16 | 17 | To add new tasks, please follow [these](https://github.com/EleutherAI/lm-evaluation-harness#code-structure) 18 | instructions. 19 | 20 | ## Dependencies 21 | 22 | * `torch`: tested on v1.10.1+cu111 23 | * `transformers`: tested on v4.21.2 24 | * `datasets`: tested on v1.17.0 25 | * `sacrebleu`: tested on v2.3.1 26 | * `scikit-learn`: tested on v1.0.2 27 | 28 | All experiments were run on a single 80GB NVIDIA A100. However, most experiments will also work on a GPU with much less memory. 29 | 30 | # Usage 31 | 32 | To use the code, simply run the following command: 33 | 34 | ```bash 35 | python3 main.py <model> <calibration_dataset> --task <task_name> --num_fewshot <num_fewshot> 36 | ``` 37 | 38 | ### Example: PIQA 39 | 40 | To run `OPT` on the PIQA task, run the following commands: 41 | ``` 42 | # Compute full precision (FP16) results 43 | CUDA_VISIBLE_DEVICES=0 python main.py facebook/opt-125m c4 --task piqa 44 | # Run RTN baseline and compute results 45 | CUDA_VISIBLE_DEVICES=0 python main.py facebook/opt-125m c4 --wbits 4 --nearest --task piqa 46 | # Run GPTQ and compute results 47 | CUDA_VISIBLE_DEVICES=0 python main.py facebook/opt-125m c4 --wbits 4 --task piqa 48 | ``` 49 | 50 | To run other OPT models replace `opt-125m` with one of: `opt-350m`, `opt-1.3b`, `opt-2.7b`, `opt-6.7b`, `opt-13b`, `opt-66b`.
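The same command pattern also composes across tasks. A hedged sketch only, under these assumptions: the task names must match entries in `tasks.ALL_TASKS` (against which `main.py` pattern-matches the comma-separated `--task` argument), and `--output_path`, which `main.py` uses to dump the results JSON, is assumed to be exposed under exactly that flag name by `utils.parse_args`:

```
# Assumed example: 4-bit GPTQ for a larger OPT model on two tasks at once, saving the results to JSON
CUDA_VISIBLE_DEVICES=0 python main.py facebook/opt-1.3b c4 --wbits 4 --task piqa,lambada --output_path opt-1.3b-w4.json
```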
51 | For 175B you must request access from Meta and then convert it to a local HuggingFace checkpoint using their scripts in `metaseq`. 52 | Once you have such a checkpoint, simply pass its path instead of `facebook/opt-125m`. 53 | 54 | 55 | To run `BLOOM` models, you need to run the following command: 56 | 57 | ``` 58 | # Compute full precision (FP16) results 59 | CUDA_VISIBLE_DEVICES=0 python main.py bigscience/bloom-560m c4 --task piqa 60 | # Run RTN baseline and compute results 61 | CUDA_VISIBLE_DEVICES=0 python main.py bigscience/bloom-560m c4 --wbits 4 --nearest --task piqa 62 | # Run GPTQ and compute results 63 | CUDA_VISIBLE_DEVICES=0 python main.py bigscience/bloom-560m c4 --wbits 4 --task piqa 64 | ```` 65 | 66 | To run other BLOOM models replace `bloom-560m` with one of: `bloom-1b1`, `bloom-1b7`, `bloom-3b`, `bloom-7b1`, `bloom`. 67 | 68 | -------------------------------------------------------------------------------- /zeroShot/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def set_seed(seed): 6 | np.random.seed(seed) 7 | torch.random.manual_seed(seed) 8 | 9 | 10 | def get_wikitext2(nsamples, seed, seqlen, model): 11 | from datasets import load_dataset 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | 15 | from transformers import AutoTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 17 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | 20 | import random 21 | random.seed(seed) 22 | trainloader = [] 23 | for _ in range(nsamples): 24 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 25 | j = i + seqlen 26 | inp = trainenc.input_ids[:, i:j] 27 | tar = inp.clone() 28 | tar[:, :-1] = -100 29 | trainloader.append((inp, tar)) 30 | return trainloader, testenc 31 | 32 | def get_ptb(nsamples, seed, seqlen, model): 33 | from datasets import load_dataset 34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 35 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 36 | 37 | from transformers import AutoTokenizer 38 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 39 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 40 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 41 | 42 | import random 43 | random.seed(seed) 44 | trainloader = [] 45 | for _ in range(nsamples): 46 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 47 | j = i + seqlen 48 | inp = trainenc.input_ids[:, i:j] 49 | tar = inp.clone() 50 | tar[:, :-1] = -100 51 | trainloader.append((inp, tar)) 52 | return trainloader, testenc 53 | 54 | def get_c4(nsamples, seed, seqlen, model): 55 | from datasets import load_dataset 56 | traindata = load_dataset( 57 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' 58 | ) 59 | valdata = load_dataset( 60 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 61 | ) 62 | 63 | from transformers import AutoTokenizer 64 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 65 | 66 | import random 67 | random.seed(seed) 68 | trainloader = [] 69 | for _ in range(nsamples): 70 | 
while True: 71 | i = random.randint(0, len(traindata) - 1) 72 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 73 | if trainenc.input_ids.shape[1] >= seqlen: 74 | break 75 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 76 | j = i + seqlen 77 | inp = trainenc.input_ids[:, i:j] 78 | tar = inp.clone() 79 | tar[:, :-1] = -100 80 | trainloader.append((inp, tar)) 81 | 82 | import random 83 | random.seed(0) 84 | valenc = [] 85 | for _ in range(256): 86 | while True: 87 | i = random.randint(0, len(valdata) - 1) 88 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 89 | if tmp.input_ids.shape[1] >= seqlen: 90 | break 91 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 92 | j = i + seqlen 93 | valenc.append(tmp.input_ids[:, i:j]) 94 | valenc = torch.hstack(valenc) 95 | class TokenizerWrapper: 96 | def __init__(self, input_ids): 97 | self.input_ids = input_ids 98 | valenc = TokenizerWrapper(valenc) 99 | 100 | return trainloader, valenc 101 | 102 | 103 | def get_loaders( 104 | name, nsamples=128, seed=0, seqlen=2048, model='' 105 | ): 106 | if 'wikitext2' in name: 107 | return get_wikitext2(nsamples, seed, seqlen, model) 108 | if 'ptb' in name: 109 | return get_ptb(nsamples, seed, seqlen, model) 110 | if 'c4' in name: 111 | return get_c4(nsamples, seed, seqlen, model) 112 | -------------------------------------------------------------------------------- /zeroShot/evaluator.py: -------------------------------------------------------------------------------- 1 | 2 | from utils import positional_deprecated 3 | import random 4 | import numpy as np 5 | import models 6 | import models.models_utils 7 | import tasks 8 | import collections 9 | import itertools 10 | import metrics 11 | import torch 12 | import time 13 | 14 | from datautils import get_loaders 15 | 16 | @positional_deprecated 17 | def simple_evaluate( 18 | # model, 19 | args, 20 | tasks_list=[] 21 | ): 22 | 23 | """Instantiate and evaluate a model on a list of tasks. 24 | :param args: Optional[str] 25 | args for the zeroShot tasks 26 | :param tasks: list[Union[str, Task]] 27 | List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. 
28 | :return 29 | Dictionary of results 30 | """ 31 | random.seed(args.seed) 32 | np.random.seed(args.seed) 33 | torch.manual_seed(args.seed) 34 | torch.cuda.manual_seed(args.seed) 35 | 36 | assert tasks_list != [], "No tasks specified" 37 | 38 | lm = models.get_model(args.model).create_from_arg_string({"args": args}) 39 | 40 | if args.load: 41 | print('Loading checkpoint from {}...'.format(args.load)) 42 | lm.model.load_state_dict(torch.load(args.load)) 43 | 44 | if args.wbits < 16 and not args.nearest: 45 | 46 | tick = time.time() 47 | dataloader, testloader = get_loaders( 48 | args.dataset, seed=args.seed, model=args.model, seqlen=lm.seqlen 49 | ) 50 | if 'opt' in args.model: 51 | quantizers = lm.opt_sequential(dataloader) 52 | else: 53 | quantizers = lm.bloom_sequential(dataloader) 54 | print(time.time() - tick) 55 | 56 | task_dict = tasks.get_task_dict(tasks_list) 57 | 58 | results = evaluate( 59 | lm=lm, 60 | task_dict=task_dict, 61 | seed=args.seed, 62 | num_fewshot=args.num_fewshot, 63 | ) 64 | 65 | # add info about the model and few shot config 66 | results["config"] = { 67 | "model": args.model, 68 | "num_fewshot": args.num_fewshot, 69 | "batch_size": args.batch_size, 70 | "bootstrap_iters": 1000, 71 | } 72 | 73 | return results 74 | 75 | @positional_deprecated 76 | def evaluate( 77 | lm, 78 | task_dict, 79 | seed=0, 80 | num_fewshot=0, 81 | ): 82 | """Instantiate and evaluate a model on a list of tasks. 83 | 84 | :param lm: obj 85 | Language Model 86 | :param task_dict: dict[str, Task] 87 | Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. 88 | :param provide_description: bool 89 | Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method 90 | :param num_fewshot: int 91 | Number of examples in few-shot context 92 | :return 93 | Dictionary of results 94 | """ 95 | 96 | task_dict_items = [ 97 | (name, task) 98 | for name, task in task_dict.items() 99 | if (task.has_validation_docs() or task.has_test_docs()) 100 | ] 101 | 102 | results = collections.defaultdict(dict) 103 | versions = collections.defaultdict(dict) 104 | 105 | requests = collections.defaultdict(list) 106 | requests_origin = collections.defaultdict(list) 107 | 108 | overlaps = collections.defaultdict(list) # {task_name: contaminated_docs} 109 | 110 | # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger 111 | # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because 112 | # over-engineering is bad (or we could make it write the requests to disk and then read them back out again 113 | # - probably using an sqlite db because of all the moving parts we have 114 | 115 | # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable 116 | docs = {} 117 | 118 | docs_for_decontamination = collections.defaultdict(list) 119 | 120 | # get lists of each type of request 121 | for task_name, task in task_dict_items: 122 | versions[task_name] = task.VERSION 123 | # default to test doc, fall back to val doc if validation unavailable 124 | # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point 125 | if task.has_test_docs(): 126 | task_doc_func = task.test_docs 127 | task_set = "test" # Required for caching in the decontamination 128 | elif task.has_validation_docs(): 129 | task_set = "val" # 
Required for caching in the decontamination 130 | task_doc_func = task.validation_docs 131 | else: 132 | raise RuntimeError("Task has neither test_docs nor validation_docs") 133 | 134 | # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order 135 | task_docs = list(task_doc_func()) 136 | rnd = random.Random() 137 | rnd.seed(seed) 138 | # rnd.shuffle(task_docs) 139 | 140 | description = "" 141 | 142 | for doc_id, doc in enumerate(itertools.islice(task_docs, 0, None)): 143 | 144 | docs[(task_name, doc_id)] = doc 145 | ctx = task.fewshot_context( 146 | doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description 147 | ) 148 | reqs = task.construct_requests(doc, ctx) 149 | if not isinstance(reqs, (list, tuple)): 150 | reqs = [reqs] 151 | for i, req in enumerate(reqs): 152 | requests[req.request_type].append(req) 153 | # i: index in requests for a single task instance 154 | # doc_id: unique id that we can get back to a doc using `docs` 155 | requests_origin[req.request_type].append((i, task_name, doc, doc_id)) 156 | 157 | # all responses for each (task, doc) 158 | process_res_queue = collections.defaultdict(list) 159 | 160 | 161 | # execute each type of request 162 | for reqtype, reqs in requests.items(): 163 | 164 | # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing 165 | # only in index. We could implement some kind of caching, but that would be more of a band-aid 166 | # solution. we could also implement some kind of auto-grouping here; 167 | # they should end up next to each other. 168 | 169 | print("Running", reqtype, "requests") 170 | resps = getattr(lm, reqtype)([req.args for req in reqs]) 171 | resps = [ 172 | x if req.index is None else x[req.index] for x, req in zip(resps, reqs) 173 | ] 174 | for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): 175 | process_res_queue[(task_name, doc_id)].append((i, resp)) 176 | 177 | vals = collections.defaultdict(list) 178 | 179 | # unpack results and sort back in order and return control to Task 180 | for (task_name, doc_id), requests in process_res_queue.items(): 181 | requests.sort(key=lambda x: x[0]) 182 | requests = [x[1] for x in requests] 183 | 184 | task = task_dict[task_name] 185 | doc = docs[(task_name, doc_id)] 186 | 187 | metrics_dict = task.process_results(doc, requests) 188 | for metric, value in metrics_dict.items(): 189 | vals[(task_name, metric)].append(value) 190 | 191 | # aggregate results 192 | for (task_name, metric), items in vals.items(): 193 | task = task_dict[task_name] 194 | real_metric = metric # key when looking up the metric with task.aggregation 195 | if metric.endswith(decontaminate_suffix): 196 | real_metric = metric.replace( 197 | decontaminate_suffix, "" 198 | ) # decontaminated still uses the same metric 199 | results[task_name][metric] = task.aggregation()[real_metric](items) 200 | 201 | # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap 202 | # so we run them less iterations. 
still looking for a cleaner way to do this 203 | 204 | stderr = metrics.stderr_for_metric( 205 | metric=task.aggregation()[real_metric], 206 | bootstrap_iters=1000 207 | ) 208 | 209 | if stderr is not None: 210 | results[task_name][metric + "_stderr"] = stderr(items) 211 | 212 | return {"results": dict(results), "versions": dict(versions)} 213 | 214 | 215 | def make_table(result_dict): 216 | """Generate table of results.""" 217 | from pytablewriter import MarkdownTableWriter, LatexTableWriter 218 | 219 | md_writer = MarkdownTableWriter() 220 | latex_writer = LatexTableWriter() 221 | md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] 222 | latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] 223 | 224 | values = [] 225 | 226 | for k, dic in result_dict["results"].items(): 227 | version = result_dict["versions"][k] 228 | for m, v in dic.items(): 229 | if m.endswith("_stderr"): 230 | continue 231 | 232 | if m + "_stderr" in dic: 233 | se = dic[m + "_stderr"] 234 | values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se]) 235 | else: 236 | values.append([k, version, m, "%.4f" % v, "", ""]) 237 | k = "" 238 | version = "" 239 | md_writer.value_matrix = values 240 | latex_writer.value_matrix = values 241 | return md_writer.dumps() 242 | 243 | decontaminate_suffix = "_decontaminate" 244 | -------------------------------------------------------------------------------- /zeroShot/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | import evaluator 5 | import tasks 6 | from utils import parse_args, pattern_match 7 | 8 | 9 | def main(): 10 | args = parse_args() 11 | 12 | if args.tasks is None: 13 | raise ValueError("Please specify a task to run") 14 | else: 15 | task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS) 16 | 17 | print(f"Selected Tasks: {task_names}") 18 | 19 | results = evaluator.simple_evaluate( 20 | args=args, 21 | tasks_list=task_names, 22 | ) 23 | 24 | dumped = json.dumps(results, indent=2) 25 | print(dumped) 26 | 27 | if args.output_path: 28 | with open(args.output_path, "w") as f: 29 | f.write(dumped) 30 | 31 | print( 32 | f"{args.model}" 33 | f"num_fewshot: {args.num_fewshot}," 34 | f" batch_size: {args.batch_size}" 35 | ) 36 | if args.table_results: 37 | print(evaluator.make_table(results)) 38 | else: 39 | from pprint import pprint 40 | pprint(results) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /zeroShot/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Iterable 3 | 4 | import numpy as np 5 | import sacrebleu 6 | import sklearn.metrics 7 | import random 8 | 9 | 10 | def mean(arr): 11 | return sum(arr) / len(arr) 12 | 13 | 14 | def pop_stddev(arr): 15 | mu = mean(arr) 16 | return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr)) 17 | 18 | 19 | def sample_stddev(arr): 20 | mu = mean(arr) 21 | return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1)) 22 | 23 | 24 | def mean_stderr(arr): 25 | return sample_stddev(arr) / math.sqrt(len(arr)) 26 | 27 | 28 | def median(arr): 29 | return arr[len(arr) // 2] 30 | 31 | 32 | def matthews_corrcoef(items): 33 | unzipped_list = list(zip(*items)) 34 | golds = unzipped_list[0] 35 | preds = unzipped_list[1] 36 | return sklearn.metrics.matthews_corrcoef(golds, preds) 37 | 38 | 39 | def f1_score(items): 40 | 
unzipped_list = list(zip(*items)) 41 | golds = unzipped_list[0] 42 | preds = unzipped_list[1] 43 | fscore = sklearn.metrics.f1_score(golds, preds) 44 | 45 | return np.max(fscore) 46 | 47 | 48 | def acc_all(items): 49 | # Only count as correct if all answers are labeled correctly for each question 50 | question_scoring_dict = {} 51 | preds = list(zip(*items))[0] 52 | docs = list(zip(*items))[1] 53 | 54 | for doc, pred in zip(docs, preds): 55 | paragraph_id = doc["idx"]["paragraph"] 56 | question_id = doc["idx"]["question"] 57 | if (paragraph_id, question_id) not in question_scoring_dict: 58 | question_scoring_dict[(paragraph_id, question_id)] = [] 59 | 60 | gold_label = doc["label"] == 1 61 | 62 | question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred) 63 | acc = np.mean([int(all(x)) for x in question_scoring_dict.values()]) 64 | return acc 65 | 66 | 67 | def acc_all_stderr(items): 68 | # Only count as correct if all answers are labeled correctly for each question 69 | question_scoring_dict = {} 70 | preds = list(zip(*items))[0] 71 | docs = list(zip(*items))[1] 72 | 73 | for doc, pred in zip(docs, preds): 74 | question_id = doc["idx"]["question"] 75 | if question_id not in question_scoring_dict: 76 | question_scoring_dict[question_id] = [] 77 | 78 | gold_label = doc["label"] == 1 79 | question_scoring_dict[question_id].append(gold_label == pred) 80 | 81 | acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()]) 82 | return acc 83 | 84 | 85 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 86 | """Compute max metric between prediction and each ground truth.""" 87 | scores_for_ground_truths = [] 88 | for ground_truth in ground_truths: 89 | score = metric_fn(prediction, ground_truth) 90 | scores_for_ground_truths.append(score) 91 | return max(scores_for_ground_truths) 92 | 93 | 94 | def perplexity(items): 95 | return math.exp(-mean(items)) 96 | 97 | 98 | def weighted_mean(items): 99 | a, b = zip(*items) 100 | return sum(a) / sum(b) 101 | 102 | 103 | def weighted_perplexity(items): 104 | return math.exp(-weighted_mean(items)) 105 | 106 | 107 | def bits_per_byte(items): 108 | return -weighted_mean(items) / math.log(2) 109 | 110 | 111 | def bleu(items): 112 | """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric 113 | for evaluating a generated sentence to a reference sentence. It counts matching 114 | n-grams in the candidate translation to n-grams in the reference text, where 115 | 1-gram or unigram would be each token and a bigram comparison would be each 116 | word pair. The comparison is made regardless of word order 117 | Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/ 118 | Paper: https://www.aclweb.org/anthology/P02-1040/ 119 | 120 | Higher is better 121 | """ 122 | refs = list(zip(*items))[0] 123 | preds = list(zip(*items))[1] 124 | refs, preds = _sacreformat(refs, preds) 125 | return sacrebleu.corpus_bleu(preds, refs).score 126 | 127 | 128 | def chrf(items): 129 | """chrF++ is a tool for automatic evaluation of machine translation output 130 | based on character n-gram precision and recall enhanced with word n-grams. 
131 | Source: https://github.com/m-popovic/chrF 132 | Paper: https://www.aclweb.org/anthology/W15-3049.pdf 133 | 134 | Higher is better # TODO I think 135 | """ 136 | refs = list(zip(*items))[0] 137 | preds = list(zip(*items))[1] 138 | refs, preds = _sacreformat(refs, preds) 139 | return sacrebleu.corpus_chrf(preds, refs).score 140 | 141 | 142 | def ter(items): 143 | """Translation Error Rate is an error metric for machine translation that 144 | measures the number of edits required to change a system output into one 145 | of the references 146 | Source: http://www.cs.umd.edu/~snover/tercom/ 147 | Paper: http://mt-archive.info/AMTA-2006-Snover.pdf 148 | 149 | Lower is better 150 | """ 151 | refs = list(zip(*items))[0] 152 | preds = list(zip(*items))[1] 153 | refs, preds = _sacreformat(refs, preds) 154 | return sacrebleu.corpus_ter(preds, refs).score 155 | 156 | 157 | def is_non_str_iterable(obj): 158 | return isinstance(obj, Iterable) and not isinstance(obj, str) 159 | 160 | 161 | def _sacreformat(refs, preds): 162 | """Format refs and preds for sacrebleu corpus calculation. It is very particular""" 163 | # Sacrebleu expects (List[str], List[List[str]) 164 | # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...]) 165 | 166 | # Note [ref1_stream] is the first reference for each pred. 167 | # So lists are size N and (M, N) for N preds and M possible refs for each pred 168 | # This is a different order of dimensions that I would expect 169 | 170 | # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds 171 | # Must become List[List[str]] with the inner list corresponding to preds 172 | if not is_non_str_iterable(refs): 173 | refs = list(refs) 174 | if not is_non_str_iterable(refs[0]): 175 | refs = [[ref] for ref in refs] 176 | refs = list(zip(*refs)) 177 | # Note the number of refs in each ref list much match the number of preds 178 | 179 | # We expect preds to be List[str] or List[List[str]]. Must become List[str] 180 | if not is_non_str_iterable(preds): 181 | preds = list(preds) 182 | if is_non_str_iterable(preds[0]): 183 | assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}" 184 | preds = [pred[0] for pred in preds] 185 | 186 | return refs, preds 187 | 188 | 189 | # stderr stuff 190 | 191 | 192 | class _bootstrap_internal: 193 | def __init__(self, f, n): 194 | self.f = f 195 | self.n = n 196 | 197 | def __call__(self, v): 198 | i, xs = v 199 | rnd = random.Random() 200 | rnd.seed(i) 201 | res = [] 202 | for _ in range(self.n): 203 | res.append(self.f(rnd.choices(xs, k=len(xs)))) 204 | return res 205 | 206 | 207 | def bootstrap_stderr(f, xs, iters): 208 | import multiprocessing as mp 209 | 210 | pool = mp.Pool(mp.cpu_count()) 211 | # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something 212 | # equivalent to stderr calculated without Bessel's correction in the stddev. 
213 | # Unfortunately, I haven't been able to figure out what the right correction is 214 | # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but 215 | # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator) 216 | # Thankfully, shouldn't matter because our samples are pretty big usually anyways 217 | res = [] 218 | chunk_size = min(1000, iters) 219 | from tqdm import tqdm 220 | 221 | print("bootstrapping for stddev:", f.__name__) 222 | for bootstrap in tqdm( 223 | pool.imap( 224 | _bootstrap_internal(f, chunk_size), 225 | [(i, xs) for i in range(iters // chunk_size)], 226 | ), 227 | total=iters // chunk_size, 228 | ): 229 | # sample w replacement 230 | res.extend(bootstrap) 231 | 232 | pool.close() 233 | return sample_stddev(res) 234 | 235 | 236 | def stderr_for_metric(metric, bootstrap_iters): 237 | bootstrappable = [ 238 | median, 239 | matthews_corrcoef, 240 | f1_score, 241 | perplexity, 242 | bleu, 243 | chrf, 244 | ter, 245 | ] 246 | 247 | if metric in bootstrappable: 248 | return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters) 249 | 250 | stderr = {mean: mean_stderr, acc_all: acc_all_stderr} 251 | 252 | return stderr.get(metric, None) 253 | 254 | 255 | def yesno(x): 256 | if x: 257 | return "yes" 258 | else: 259 | return "no" 260 | -------------------------------------------------------------------------------- /zeroShot/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import opt 2 | from . import bloom 3 | 4 | MODEL_REGISTRY = { 5 | 'opt': opt.OPT, 6 | 'bloom': bloom.BLOOM 7 | } 8 | 9 | 10 | def get_model(model_name): 11 | if 'opt' in model_name: 12 | return MODEL_REGISTRY['opt'] 13 | elif 'bloom' in model_name: 14 | return MODEL_REGISTRY['bloom'] 15 | return MODEL_REGISTRY[model_name] 16 | -------------------------------------------------------------------------------- /zeroShot/models/bloom.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from .models_utils import BaseLM, find_layers 4 | from transformers import BloomForCausalLM, AutoTokenizer 5 | import torch.nn.functional as F 6 | from torch import nn 7 | import torch 8 | from tqdm import tqdm 9 | from .quant import * 10 | from .gptq import GPTQ 11 | 12 | 13 | class BLOOMClass(BaseLM): 14 | def __init__(self, args): 15 | 16 | super().__init__() 17 | 18 | self.args = args 19 | self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | self.model_name = args.model 21 | self.batch_size_per_gpu = args.batch_size 22 | 23 | self.model = BloomForCausalLM.from_pretrained(self.model_name, torch_dtype='auto') 24 | self.model.eval() 25 | self.seqlen = 2048 26 | 27 | # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2 28 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) 29 | 30 | self.vocab_size = self.tokenizer.vocab_size 31 | print('BLOOM vocab size: ', self.vocab_size) 32 | 33 | @property 34 | def eot_token_id(self): 35 | # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* 36 | return self.tokenizer.eos_token_id 37 | 38 | @property 39 | def max_length(self): 40 | return 2048 41 | @property 42 | def max_gen_toks(self): 43 | print('max_gen_toks fn') 44 | return 256 45 | 46 | @property 47 | def batch_size(self): 48 | # TODO: fix multi-gpu 49 | return self.batch_size_per_gpu # * gpus 50 | 51 | 
@property 52 | def device(self): 53 | # TODO: fix multi-gpu 54 | return self._device 55 | 56 | def tok_encode(self, string: str): 57 | return self.tokenizer.encode(string, add_special_tokens=False) 58 | 59 | def tok_decode(self, tokens): 60 | return self.tokenizer.decode(tokens) 61 | 62 | def _model_call(self, inps): 63 | """ 64 | inps: a torch tensor of shape [batch, sequence] 65 | the size of sequence may vary from call to call 66 | returns: a torch tensor of shape [batch, sequence, vocab] with the 67 | logits returned from the model 68 | """ 69 | with torch.no_grad(): 70 | return self.model(inps)[0][:, :, :250680] 71 | 72 | @torch.no_grad() 73 | def _model_logits_on_dataset(self, dataset_inps): 74 | dataset_logits = [] 75 | nsamples = len(dataset_inps) 76 | 77 | dev = self.device 78 | 79 | model = self.model 80 | 81 | print('Evaluation...') 82 | 83 | use_cache = model.config.use_cache 84 | model.config.use_cache = False 85 | layers = model.transformer.h 86 | 87 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(dev) 88 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(dev) 89 | layers[0] = layers[0].to(dev) 90 | 91 | dtype = next(iter(model.parameters())).dtype 92 | inps = [] 93 | outs = [] 94 | 95 | for batch_idx, batch in enumerate(dataset_inps): 96 | inps.append(torch.zeros( 97 | (batch.shape[1], self.model.config.hidden_size), dtype=dtype, 98 | )) 99 | outs.append(torch.zeros( 100 | (batch.shape[1], self.model.config.hidden_size), dtype=dtype, 101 | )) 102 | 103 | cache = {'i': 0, 'attention_masks': [], 'alibis': []} 104 | 105 | class Catcher(nn.Module): 106 | def __init__(self, module): 107 | super().__init__() 108 | self.module = module 109 | 110 | def forward(self, inp, **kwargs): 111 | inps[cache['i']] = inp 112 | cache['i'] += 1 113 | cache['attention_masks'].append(kwargs['attention_mask'].detach().cpu()) 114 | cache['alibis'].append(kwargs['alibi'].detach().cpu()) 115 | raise ValueError 116 | 117 | layers[0] = Catcher(layers[0]) 118 | for i in range(nsamples): 119 | batch = dataset_inps[i].to(dev) 120 | try: 121 | model(batch) 122 | except ValueError: 123 | pass 124 | layers[0] = layers[0].module 125 | 126 | layers[0] = layers[0].cpu() 127 | model.transformer.word_embeddings = model.transformer.word_embeddings.cpu() 128 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.cpu() 129 | torch.cuda.empty_cache() 130 | 131 | attention_masks = cache['attention_masks'] 132 | alibis = cache['alibis'] 133 | 134 | for i in range(len(layers)): 135 | print(i) 136 | layer = layers[i].to(dev) 137 | 138 | if self.args.nearest: 139 | subset = find_layers(layer) 140 | for name in subset: 141 | quantizer = Quantizer() 142 | quantizer.configure( 143 | self.args.wbits, perchannel=True, sym=False, mse=False 144 | ) 145 | W = subset[name].weight.data 146 | quantizer.find_params(W, weight=True) 147 | subset[name].weight.data = quantize( 148 | W, quantizer.scale, quantizer.zero, quantizer.maxq 149 | ).to(next(iter(layer.parameters())).dtype) 150 | 151 | for j in range(nsamples): 152 | outs[j] = layer(inps[j].to(self.device), 153 | attention_mask=attention_masks[j].to(self.device), 154 | alibi=alibis[j].to(self.device))[0].detach().cpu() 155 | 156 | layers[i] = layer.cpu() 157 | del layer 158 | torch.cuda.empty_cache() 159 | inps, outs = outs, inps 160 | 161 | model.transformer.ln_f = model.transformer.ln_f.to(dev) 162 | model.lm_head = model.lm_head.to(dev) 163 | 164 | for i in 
tqdm(range(nsamples), desc='Last Layer'): 165 | hidden_states = inps[i].unsqueeze(0).to(self.device) 166 | hidden_states = self.model.transformer.ln_f(hidden_states) 167 | batch_logits = F.log_softmax(self.model.lm_head(hidden_states)[0][:, :, :250680], dim=-1).cpu() 168 | dataset_logits.append(batch_logits) 169 | 170 | model.config.use_cache = use_cache 171 | return dataset_logits 172 | 173 | @torch.no_grad() 174 | def _model_logits_on_dataset2(self, dataset_inps): 175 | dataset_logits = [] 176 | nbatches = len(dataset_inps) 177 | 178 | use_cache = self.model.config.use_cache 179 | self.model.config.use_cache = False 180 | layers = self.model.transformer.h 181 | 182 | self.model.transformer.word_embeddings = self.model.transformer.word_embeddings.to(self.device) 183 | self.model.transformer.word_embeddings_layernorm = self.model.transformer.word_embeddings_layernorm.to( 184 | self.device) 185 | layers[0] = layers[0].to(self.device) 186 | 187 | dtype = next(iter(self.model.parameters())).dtype 188 | 189 | 190 | inps = [] 191 | outs = [] 192 | for batch_idx, batch in enumerate(dataset_inps): 193 | inps.append(torch.zeros( 194 | (batch.shape[1], self.model.config.hidden_size), dtype=dtype, 195 | )) 196 | outs.append(torch.zeros( 197 | (batch.shape[1], self.model.config.hidden_size), dtype=dtype, 198 | )) 199 | 200 | cache = {'i': 0, 'attention_masks': [], 'alibi': []} 201 | 202 | class Catcher(nn.Module): 203 | def __init__(self, module): 204 | super().__init__() 205 | self.module = module 206 | 207 | def forward(self, inp, **kwargs): 208 | inps[cache['i']] = inp.cpu() 209 | cache['i'] += 1 210 | cache['attention_masks'].append(kwargs['attention_mask'].detach().cpu()) 211 | cache['alibi'].append(kwargs['alibi'].detach().cpu()) 212 | raise ValueError 213 | 214 | layers[0] = Catcher(layers[0]) 215 | for i in range(nbatches): 216 | batch = dataset_inps[i].to(self.device) 217 | try: 218 | self.model(batch) 219 | except ValueError: 220 | pass 221 | layers[0] = layers[0].module 222 | 223 | layers[0] = layers[0].cpu() 224 | self.model.transformer.word_embeddings = self.model.transformer.word_embeddings.cpu() 225 | self.model.transformer.word_embeddings_layernorm = self.model.transformer.word_embeddings_layernorm.cpu() 226 | torch.cuda.empty_cache() # TODO: maybe we don't need this? 
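        # Added note: the Catcher wrapper above hijacked the first transformer
        # block to record every batch's hidden states, attention mask and ALiBi
        # tensor, then raised ValueError to cut the forward pass short.  The
        # code below replays those cached inputs layer by layer: each block is
        # moved to the GPU, optionally quantized with round-to-nearest when
        # args.nearest is set (and args.wbits < 32), run over all batches, and
        # moved back to the CPU, so only one block occupies the device at a time.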
227 | 228 | attention_masks = cache['attention_masks'] 229 | alibis = cache['alibi'] 230 | 231 | for i in range(len(layers)): 232 | print('layer: ', i) 233 | layer = layers[i].to(self.device) 234 | 235 | if self.args.wbits < 32 and self.args.nearest: 236 | subset = find_layers(layer) 237 | for name in subset: 238 | if 'lm_head' in name: 239 | continue 240 | quantizer = Quantizer() 241 | quantizer.configure( 242 | self.args.wbits, 243 | perchannel=True, sym=False, mse=False, norm=2.4 244 | ) 245 | W = subset[name].weight.data 246 | quantizer.find_params(W, weight=True) 247 | subset[name].weight.data = quantize( 248 | W, quantizer.scale, quantizer.zero, quantizer.maxq 249 | ).to(next(iter(layer.parameters())).dtype) 250 | 251 | 252 | for j in range(nbatches): 253 | outs[j] = layer(inps[j].to(self.device), 254 | attention_mask=attention_masks[j].to(self.device), 255 | alibi=alibis[j].to(self.device))[0].detach().cpu() 256 | layers[i] = layer.cpu() 257 | del layer 258 | torch.cuda.empty_cache() 259 | inps, outs = outs, inps 260 | 261 | self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device) 262 | self.model.lm_head = self.model.lm_head.to(self.device) 263 | 264 | for i in tqdm(range(nbatches), desc='Last Layer'): 265 | hidden_states = inps[i].unsqueeze(0).to(self.device) 266 | hidden_states = self.model.transformer.ln_f(hidden_states) 267 | batch_logits = F.log_softmax(self.model.lm_head(hidden_states)[0][:, :, :250680], dim=-1).cpu() 268 | dataset_logits.append(batch_logits) 269 | 270 | return dataset_logits 271 | 272 | def _model_logits_on_dataset_2(self, inps): 273 | # import pdb;pdb.set_trace() 274 | self.model = self.model.to(self.device) 275 | dataset_logits = [] 276 | for batch in inps: 277 | multi_logits = F.log_softmax( 278 | self._model_call(batch), dim=-1 279 | ).cpu() # [batch, padding_length, vocab] 280 | dataset_logits.append(multi_logits) 281 | return dataset_logits 282 | 283 | 284 | def _model_generate(self, context, max_length, eos_token_id): 285 | return self.model.generate( 286 | context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False 287 | ) 288 | 289 | @torch.no_grad() 290 | def bloom_sequential(self, dataloader): 291 | print('Starting ...') 292 | 293 | model = self.model 294 | dev = self.device 295 | 296 | use_cache = model.config.use_cache 297 | model.config.use_cache = False 298 | layers = model.transformer.h 299 | 300 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(dev) 301 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(dev) 302 | layers[0] = layers[0].to(dev) 303 | 304 | dtype = next(iter(model.parameters())).dtype 305 | inps = torch.zeros( 306 | (self.args.nsamples, self.seqlen, model.config.hidden_size), dtype=dtype, device=dev 307 | ) 308 | cache = {'i': 0, 'attention_mask': None, 'alibi': None} 309 | 310 | class Catcher(nn.Module): 311 | def __init__(self, module): 312 | super().__init__() 313 | self.module = module 314 | 315 | def forward(self, inp, **kwargs): 316 | inps[cache['i']] = inp 317 | cache['i'] += 1 318 | cache['attention_mask'] = kwargs['attention_mask'] 319 | cache['alibi'] = kwargs['alibi'] 320 | raise ValueError 321 | 322 | layers[0] = Catcher(layers[0]) 323 | for batch in dataloader: 324 | try: 325 | model(batch[0].to(dev)) 326 | except ValueError: 327 | pass 328 | layers[0] = layers[0].module 329 | 330 | layers[0] = layers[0].cpu() 331 | model.transformer.word_embeddings = model.transformer.word_embeddings.cpu() 332 | 
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.cpu() 333 | torch.cuda.empty_cache() 334 | 335 | outs = torch.zeros_like(inps) 336 | attention_mask = cache['attention_mask'] 337 | alibi = cache['alibi'] 338 | 339 | print('Ready.') 340 | 341 | for i in range(len(layers)): 342 | layer = layers[i].to(dev) 343 | 344 | subset = find_layers(layer) 345 | gptq = {} 346 | for name in subset: 347 | gptq[name] = GPTQ(subset[name]) 348 | gptq[name].quantizer = Quantizer() 349 | gptq[name].quantizer.configure( 350 | self.args.wbits, perchannel=True, sym=False, mse=False 351 | ) 352 | 353 | def add_batch(name): 354 | def tmp(_, inp, out): 355 | gptq[name].add_batch(inp[0].data, out.data) 356 | 357 | return tmp 358 | 359 | handles = [] 360 | for name in subset: 361 | handles.append(subset[name].register_forward_hook(add_batch(name))) 362 | for j in range(self.args.nsamples): 363 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 364 | for h in handles: 365 | h.remove() 366 | 367 | for name in subset: 368 | print(i, name) 369 | print('Quantizing ...') 370 | gptq[name].fasterquant(percdamp=self.args.percdamp, groupsize=self.args.groupsize) 371 | for j in range(self.args.nsamples): 372 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 373 | 374 | layers[i] = layer.cpu() 375 | del gptq 376 | torch.cuda.empty_cache() 377 | 378 | inps, outs = outs, inps 379 | 380 | model.config.use_cache = use_cache 381 | 382 | 383 | # for backwards compatibility 384 | BLOOM = BLOOMClass -------------------------------------------------------------------------------- /zeroShot/models/fast_trueobs.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from quant import * 9 | 10 | 11 | DEBUG = False 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | 16 | 17 | class TrueOBS: 18 | 19 | def __init__(self, layer): 20 | self.layer = layer 21 | self.dev = self.layer.weight.device 22 | W = layer.weight.data.clone() 23 | if isinstance(self.layer, nn.Conv2d): 24 | W = W.flatten(1) 25 | if isinstance(self.layer, transformers.Conv1D): 26 | W = W.t() 27 | self.rows = W.shape[0] 28 | self.columns = W.shape[1] 29 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 30 | self.nsamples = 0 31 | 32 | def add_batch(self, inp, out): 33 | if DEBUG: 34 | self.inp1 = inp 35 | self.out1 = out 36 | if len(inp.shape) == 2: # TODO: may not work for convnets 37 | inp = inp.unsqueeze(0) 38 | tmp = inp.shape[0] 39 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 40 | if len(inp.shape) == 3: 41 | inp = inp.reshape((-1, inp.shape[-1])) 42 | inp = inp.t() 43 | if isinstance(self.layer, nn.Conv2d): 44 | unfold = nn.Unfold( 45 | self.layer.kernel_size, 46 | dilation=self.layer.dilation, 47 | padding=self.layer.padding, 48 | stride=self.layer.stride 49 | ) 50 | inp = unfold(inp) 51 | inp = inp.permute([1, 0, 2]) 52 | inp = inp.flatten(1) 53 | self.H *= self.nsamples / (self.nsamples + tmp) 54 | self.nsamples += tmp 55 | # inp = inp.float() 56 | inp = math.sqrt(2 / self.nsamples) * inp.float() 57 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 58 | self.H += inp.matmul(inp.t()) 59 | 60 | def fasterquant( 61 | self, blocksize=128, percdamp=.01, sparseout=False, nearest=False 62 | ): 63 | W = 
self.layer.weight.data.clone() 64 | if isinstance(self.layer, nn.Conv2d): 65 | W = W.flatten(1) 66 | if isinstance(self.layer, transformers.Conv1D): 67 | W = W.t() 68 | W = W.float() 69 | 70 | tick = time.time() 71 | 72 | if not self.quantizer.ready(): 73 | self.quantizer.find_params(W, weight=True) 74 | 75 | if False: 76 | H = self.H 77 | dead = torch.diag(H) == 0 78 | H[dead, dead] = 1 79 | W[:, dead] = 0 80 | 81 | Losses = torch.zeros_like(W) 82 | Q = torch.zeros_like(W) 83 | 84 | damp = percdamp * torch.mean(torch.diag(H)) 85 | # diag = torch.arange(self.columns, device=self.dev) 86 | # H[diag, diag] += damp 87 | H += damp * torch.eye(self.columns, device=self.dev) 88 | Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) 89 | Hinv = torch.linalg.cholesky(Hinv, upper=True) 90 | else: 91 | H = self.H 92 | del self.H 93 | dead = torch.diag(H) == 0 94 | H[dead, dead] = 1 95 | W[:, dead] = 0 96 | 97 | Losses = torch.zeros_like(W) 98 | Q = torch.zeros_like(W) 99 | 100 | damp = percdamp * torch.mean(torch.diag(H)) 101 | diag = torch.arange(self.columns, device=self.dev) 102 | H[diag, diag] += damp 103 | H = torch.linalg.cholesky(H) 104 | H = torch.cholesky_inverse(H) 105 | H = torch.linalg.cholesky(H, upper=True) 106 | Hinv = H 107 | 108 | outlier = .25 * (self.quantizer.scale ** 2).flatten() 109 | tot = 0 110 | 111 | for i1 in range(0, self.columns, blocksize): 112 | i2 = min(i1 + blocksize, self.columns) 113 | count = i2 - i1 114 | 115 | W1 = W[:, i1:i2].clone() 116 | Q1 = torch.zeros_like(W1) 117 | Err1 = torch.zeros_like(W1) 118 | Losses1 = torch.zeros_like(W1) 119 | Hinv1 = Hinv[i1:i2, i1:i2] 120 | 121 | for i in range(count): 122 | w = W1[:, i] 123 | d = Hinv1[i, i] 124 | 125 | # if (i1 + i) % 512 == 0: 126 | # self.quantizer.find_params(W[:, (i1 + i):(i1 + i + 512)], weight=True) 127 | 128 | q = quantize( 129 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 130 | ).flatten() 131 | Q1[:, i] = q 132 | Losses1[:, i] = (w - q) ** 2 / d ** 2 133 | 134 | if sparseout: 135 | sel = (w - q) ** 2 > outlier 136 | Losses1[sel, i] = 0 137 | q[sel] = w[sel] 138 | Q1[sel, i] = q[sel] 139 | tot += torch.sum(sel.int()).item() 140 | 141 | err1 = (w - q) / d 142 | if not nearest: 143 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 144 | Err1[:, i] = err1 145 | 146 | Q[:, i1:i2] = Q1 147 | Losses[:, i1:i2] = Losses1 / 2 148 | 149 | if not nearest: 150 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 151 | 152 | if DEBUG: 153 | self.layer.weight.data[:, :i2] = Q[:, :i2] 154 | self.layer.weight.data[:, i2:] = W[:, i2:] 155 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 156 | print(torch.sum(Losses)) 157 | 158 | torch.cuda.synchronize() 159 | print(tot / W.numel()) 160 | print('time %.2f' % (time.time() - tick)) 161 | print('error', torch.sum(Losses).item()) 162 | 163 | if isinstance(self.layer, transformers.Conv1D): 164 | Q = Q.t() 165 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 166 | if DEBUG: 167 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 168 | 169 | def free(self): 170 | if DEBUG: 171 | self.inp1 = None 172 | self.out1 = None 173 | self.H = None 174 | self.Losses = None 175 | self.Trace = None 176 | torch.cuda.empty_cache() 177 | 178 | 179 | def print_mem(): 180 | t = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3 181 | r = torch.cuda.memory_reserved(0) / 1024 ** 3 182 | a = torch.cuda.memory_allocated(0) / 1024 ** 3 183 | print(t, r, a) 184 | 
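# --------------------------------------------------------------------------
# Illustrative sketch (added; not part of the original fast_trueobs.py): a
# minimal round-to-nearest baseline using the same uniform quantization
# formula as zeroShot/models/quant.py, i.e. q = clamp(round(x / scale) + zero,
# 0, maxq) and dequant = scale * (q - zero), with per-row (per-channel) scale
# and zero-point taken from each row's min/max as Quantizer.find_params does
# for weights (perchannel=True, sym=False, mse=False).  The 4-bit setting and
# tensor shapes below are arbitrary choices for the demonstration.

import torch


def rtn_quantize_rows(W: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Quantize each row of a 2-D weight matrix to `bits` bits, then dequantize."""
    maxq = 2 ** bits - 1
    zeros = torch.zeros(W.shape[0], dtype=W.dtype)
    xmin = torch.minimum(W.min(dim=1).values, zeros)
    xmax = torch.maximum(W.max(dim=1).values, zeros)
    # guard against all-zero rows, mirroring find_params
    degenerate = (xmin == 0) & (xmax == 0)
    xmin[degenerate] = -1.0
    xmax[degenerate] = 1.0
    scale = ((xmax - xmin) / maxq).unsqueeze(1)
    zero = torch.round(-xmin.unsqueeze(1) / scale)
    q = torch.clamp(torch.round(W / scale) + zero, 0, maxq)
    return scale * (q - zero)


if __name__ == "__main__":
    torch.manual_seed(0)
    W = torch.randn(16, 64)
    W_hat = rtn_quantize_rows(W, bits=4)
    print("mean squared quantization error:", torch.mean((W - W_hat) ** 2).item())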
-------------------------------------------------------------------------------- /zeroShot/models/gptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from .quant import * 9 | 10 | 11 | DEBUG = False 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | 16 | 17 | class GPTQ: 18 | 19 | def __init__(self, layer): 20 | self.layer = layer 21 | self.dev = self.layer.weight.device 22 | W = layer.weight.data.clone() 23 | if isinstance(self.layer, nn.Conv2d): 24 | W = W.flatten(1) 25 | if isinstance(self.layer, transformers.Conv1D): 26 | W = W.t() 27 | self.rows = W.shape[0] 28 | self.columns = W.shape[1] 29 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 30 | self.nsamples = 0 31 | 32 | def add_batch(self, inp, out): 33 | if DEBUG: 34 | self.inp1 = inp 35 | self.out1 = out 36 | if len(inp.shape) == 2: 37 | inp = inp.unsqueeze(0) 38 | tmp = inp.shape[0] 39 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 40 | if len(inp.shape) == 3: 41 | inp = inp.reshape((-1, inp.shape[-1])) 42 | inp = inp.t() 43 | if isinstance(self.layer, nn.Conv2d): 44 | unfold = nn.Unfold( 45 | self.layer.kernel_size, 46 | dilation=self.layer.dilation, 47 | padding=self.layer.padding, 48 | stride=self.layer.stride 49 | ) 50 | inp = unfold(inp) 51 | inp = inp.permute([1, 0, 2]) 52 | inp = inp.flatten(1) 53 | self.H *= self.nsamples / (self.nsamples + tmp) 54 | self.nsamples += tmp 55 | # inp = inp.float() 56 | inp = math.sqrt(2 / self.nsamples) * inp.float() 57 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 58 | self.H += inp.matmul(inp.t()) 59 | 60 | def fasterquant( 61 | self, blocksize=128, percdamp=.01, groupsize=-1 62 | ): 63 | W = self.layer.weight.data.clone() 64 | if isinstance(self.layer, nn.Conv2d): 65 | W = W.flatten(1) 66 | if isinstance(self.layer, transformers.Conv1D): 67 | W = W.t() 68 | W = W.float() 69 | 70 | tick = time.time() 71 | 72 | if not self.quantizer.ready(): 73 | self.quantizer.find_params(W, weight=True) 74 | 75 | H = self.H 76 | del self.H 77 | dead = torch.diag(H) == 0 78 | H[dead, dead] = 1 79 | W[:, dead] = 0 80 | 81 | Losses = torch.zeros_like(W) 82 | Q = torch.zeros_like(W) 83 | 84 | damp = percdamp * torch.mean(torch.diag(H)) 85 | diag = torch.arange(self.columns, device=self.dev) 86 | H[diag, diag] += damp 87 | H = torch.linalg.cholesky(H) 88 | H = torch.cholesky_inverse(H) 89 | H = torch.linalg.cholesky(H, upper=True) 90 | Hinv = H 91 | 92 | for i1 in range(0, self.columns, blocksize): 93 | i2 = min(i1 + blocksize, self.columns) 94 | count = i2 - i1 95 | 96 | W1 = W[:, i1:i2].clone() 97 | Q1 = torch.zeros_like(W1) 98 | Err1 = torch.zeros_like(W1) 99 | Losses1 = torch.zeros_like(W1) 100 | Hinv1 = Hinv[i1:i2, i1:i2] 101 | 102 | for i in range(count): 103 | w = W1[:, i] 104 | d = Hinv1[i, i] 105 | 106 | if groupsize != -1: 107 | if (i1 + i) % groupsize == 0: 108 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 109 | 110 | q = quantize( 111 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 112 | ).flatten() 113 | Q1[:, i] = q 114 | Losses1[:, i] = (w - q) ** 2 / d ** 2 115 | 116 | err1 = (w - q) / d 117 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 118 | Err1[:, i] = err1 119 | 120 | Q[:, i1:i2] = Q1 121 | Losses[:, i1:i2] = Losses1 / 2 122 | 123 | 
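        # Added note: Err1 holds this block's per-column quantization errors
        # scaled by 1/d; the update below propagates them to all columns that
        # have not been quantized yet via the corresponding slice of the
        # inverse-Hessian Cholesky factor, letting later columns compensate
        # for the error introduced here (GPTQ's lazy batch update).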
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 124 | 125 | if DEBUG: 126 | self.layer.weight.data[:, :i2] = Q[:, :i2] 127 | self.layer.weight.data[:, i2:] = W[:, i2:] 128 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 129 | print(torch.sum(Losses)) 130 | 131 | torch.cuda.synchronize() 132 | print('time %.2f' % (time.time() - tick)) 133 | print('error', torch.sum(Losses).item()) 134 | 135 | if isinstance(self.layer, transformers.Conv1D): 136 | Q = Q.t() 137 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 138 | if DEBUG: 139 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 140 | 141 | def free(self): 142 | if DEBUG: 143 | self.inp1 = None 144 | self.out1 = None 145 | self.H = None 146 | self.Losses = None 147 | self.Trace = None 148 | torch.cuda.empty_cache() 149 | -------------------------------------------------------------------------------- /zeroShot/models/opt.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from .models_utils import BaseLM, find_layers 4 | from transformers import OPTForCausalLM, AutoTokenizer 5 | import torch.nn.functional as F 6 | from torch import nn 7 | import torch 8 | from tqdm import tqdm 9 | from .quant import * 10 | from .gptq import GPTQ 11 | 12 | 13 | class OPTClass(BaseLM): 14 | def __init__(self, args): 15 | 16 | super().__init__() 17 | 18 | self.args = args 19 | self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | self.model_name = args.model 21 | self.batch_size_per_gpu = args.batch_size 22 | 23 | self.model = OPTForCausalLM.from_pretrained(self.model_name, torch_dtype='auto') 24 | self.seqlen = self.model.config.max_position_embeddings 25 | self.model.eval() 26 | 27 | # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2 28 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) 29 | self.vocab_size = self.tokenizer.vocab_size 30 | print('OPT vocab size: ', self.vocab_size) 31 | 32 | @property 33 | def eot_token_id(self): 34 | # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* 35 | return self.tokenizer.eos_token_id 36 | 37 | @property 38 | def max_length(self): 39 | try: 40 | return self.gpt2.config.n_ctx 41 | except AttributeError: 42 | # gptneoconfig doesn't have n_ctx apparently 43 | return self.model.config.max_position_embeddings 44 | @property 45 | def max_gen_toks(self): 46 | print('max_gen_toks fn') 47 | return 256 48 | 49 | @property 50 | def batch_size(self): 51 | # TODO: fix multi-gpu 52 | return self.batch_size_per_gpu # * gpus 53 | 54 | @property 55 | def device(self): 56 | # TODO: fix multi-gpu 57 | return self._device 58 | 59 | def tok_encode(self, string: str): 60 | return self.tokenizer.encode(string, add_special_tokens=False) 61 | 62 | def tok_decode(self, tokens): 63 | return self.tokenizer.decode(tokens) 64 | 65 | def _model_call(self, inps): 66 | """ 67 | inps: a torch tensor of shape [batch, sequence] 68 | the size of sequence may vary from call to call 69 | returns: a torch tensor of shape [batch, sequence, vocab] with the 70 | logits returned from the model 71 | """ 72 | with torch.no_grad(): 73 | return self.model(inps)[0][:, :, :50272] 74 | 75 | @torch.no_grad() 76 | def _model_logits_on_dataset(self, dataset_inps): 77 | print('Evaluating ...') 78 | 79 | nsamples = len(dataset_inps) 80 | 81 | model = self.model 82 | dev = self.device 83 | 84 | use_cache = 
model.config.use_cache 85 | model.config.use_cache = False 86 | layers = model.model.decoder.layers 87 | 88 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 89 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 90 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 91 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 92 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 93 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 94 | layers[0] = layers[0].to(dev) 95 | 96 | dtype = next(iter(model.parameters())).dtype 97 | inps = [] 98 | outs = [] 99 | for batch_idx, batch in enumerate(dataset_inps): 100 | inps.append(torch.zeros( 101 | (batch.shape[1], self.model.config.hidden_size), dtype=dtype, 102 | )) 103 | outs.append(torch.zeros( 104 | (batch.shape[1], self.model.config.hidden_size), dtype=dtype, 105 | )) 106 | 107 | cache = {'i': 0, 'attention_masks': []} 108 | 109 | class Catcher(nn.Module): 110 | def __init__(self, module): 111 | super().__init__() 112 | self.module = module 113 | 114 | def forward(self, inp, **kwargs): 115 | inps[cache['i']] = inp 116 | cache['i'] += 1 117 | cache['attention_masks'].append(kwargs['attention_mask'].detach().cpu()) 118 | raise ValueError 119 | 120 | layers[0] = Catcher(layers[0]) 121 | for i in range(nsamples): 122 | batch = dataset_inps[i].to(dev) 123 | try: 124 | model(batch) 125 | except ValueError: 126 | pass 127 | layers[0] = layers[0].module 128 | 129 | layers[0] = layers[0].cpu() 130 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 131 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 132 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 133 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 134 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 135 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 136 | torch.cuda.empty_cache() 137 | 138 | attention_masks = cache['attention_masks'] 139 | 140 | for i in range(len(layers)): 141 | print(i) 142 | layer = layers[i].to(dev) 143 | 144 | if self.args.nearest: 145 | subset = find_layers(layer) 146 | for name in subset: 147 | quantizer = Quantizer() 148 | quantizer.configure( 149 | self.args.wbits, perchannel=True, sym=False, mse=False 150 | ) 151 | W = subset[name].weight.data 152 | quantizer.find_params(W, weight=True) 153 | subset[name].weight.data = quantize( 154 | W, quantizer.scale, quantizer.zero, quantizer.maxq 155 | ).to(next(iter(layer.parameters())).dtype) 156 | 157 | for j in range(nsamples): 158 | outs[j] = layer(inps[j].to(self.device), attention_mask=attention_masks[j].to(self.device))[0].detach().cpu() 159 | 160 | layers[i] = layer.cpu() 161 | del layer 162 | torch.cuda.empty_cache() 163 | inps, outs = outs, inps 164 | 165 | if model.model.decoder.final_layer_norm is not None: 166 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 167 | if model.model.decoder.project_out is not None: 168 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 169 | model.lm_head = model.lm_head.to(dev) 170 | 171 | 172 | if self.model.model.decoder.final_layer_norm is not None: 173 | self.model.model.decoder.final_layer_norm = self.model.model.decoder.final_layer_norm.to(self.device) 174 | if 
self.model.model.decoder.project_out is not None: 175 | self.model.model.decoder.project_out = self.model.model.decoder.project_out.to(self.device) 176 | self.model.lm_head = self.model.lm_head.to(self.device) 177 | 178 | dataset_logits = [] 179 | 180 | for i in tqdm(range(nsamples), desc='Last Layer'): 181 | hidden_states = inps[i].unsqueeze(0).to(self.device) 182 | if self.model.model.decoder.final_layer_norm is not None: 183 | hidden_states = self.model.model.decoder.final_layer_norm(hidden_states) 184 | if self.model.model.decoder.project_out is not None: 185 | hidden_states = self.model.model.decoder.project_out(hidden_states) 186 | batch_logits = F.log_softmax(self.model.lm_head(hidden_states)[0][:, :, :50272], dim=-1).cpu() 187 | dataset_logits.append(batch_logits) 188 | model.config.use_cache = use_cache 189 | return dataset_logits 190 | 191 | 192 | def model_batched_set(self, inps): 193 | import pdb;pdb.set_trace() 194 | dataset_logits = [] 195 | for batch in inps: 196 | multi_logits = F.log_softmax( 197 | self._model_call(batch), dim=-1 198 | ).cpu() # [batch, padding_length, vocab] 199 | dataset_logits.append(multi_logits) 200 | return dataset_logits 201 | 202 | 203 | def _model_generate(self, context, max_length, eos_token_id): 204 | return self.model.generate( 205 | context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False 206 | ) 207 | 208 | @torch.no_grad() 209 | def opt_sequential(self, dataloader): 210 | print('Starting ...') 211 | 212 | model = self.model 213 | dev = self.device 214 | 215 | use_cache = model.config.use_cache 216 | model.config.use_cache = False 217 | layers = model.model.decoder.layers 218 | 219 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 220 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 221 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 222 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 223 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 224 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 225 | layers[0] = layers[0].to(dev) 226 | 227 | dtype = next(iter(model.parameters())).dtype 228 | inps = torch.zeros( 229 | (self.args.nsamples, self.seqlen, model.config.hidden_size), dtype=dtype, device=dev 230 | ) 231 | cache = {'i': 0, 'attention_mask': None} 232 | 233 | class Catcher(nn.Module): 234 | def __init__(self, module): 235 | super().__init__() 236 | self.module = module 237 | 238 | def forward(self, inp, **kwargs): 239 | inps[cache['i']] = inp 240 | cache['i'] += 1 241 | cache['attention_mask'] = kwargs['attention_mask'] 242 | raise ValueError 243 | 244 | layers[0] = Catcher(layers[0]) 245 | for batch in dataloader: 246 | try: 247 | model(batch[0].to(dev)) 248 | except ValueError: 249 | pass 250 | layers[0] = layers[0].module 251 | 252 | layers[0] = layers[0].cpu() 253 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 254 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 255 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 256 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 257 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 258 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 259 | torch.cuda.empty_cache() 260 | 261 | outs = torch.zeros_like(inps) 262 | 
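        # Added note: inps now holds the calibration activations captured by
        # the Catcher for every sample.  The loop below streams one decoder
        # layer at a time onto the GPU, lets the registered forward hooks feed
        # that layer's inputs into GPTQ.add_batch to build the Hessian,
        # quantizes each sublayer with fasterquant, then re-runs the layer so
        # outs becomes the input of the next layer.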
attention_mask = cache['attention_mask'] 263 | 264 | print('Ready.') 265 | 266 | quantizers = {} 267 | for i in range(len(layers)): 268 | layer = layers[i].to(dev) 269 | 270 | subset = find_layers(layer) 271 | gptq = {} 272 | for name in subset: 273 | gptq[name] = GPTQ(subset[name]) 274 | gptq[name].quantizer = Quantizer() 275 | gptq[name].quantizer.configure( 276 | self.args.wbits, perchannel=True, sym=False, mse=False 277 | ) 278 | 279 | def add_batch(name): 280 | def tmp(_, inp, out): 281 | gptq[name].add_batch(inp[0].data, out.data) 282 | 283 | return tmp 284 | 285 | handles = [] 286 | for name in subset: 287 | handles.append(subset[name].register_forward_hook(add_batch(name))) 288 | for j in range(self.args.nsamples): 289 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 290 | for h in handles: 291 | h.remove() 292 | 293 | for name in subset: 294 | print(i, name) 295 | print('Quantizing ...') 296 | gptq[name].fasterquant(percdamp=self.args.percdamp, groupsize=self.args.groupsize) 297 | quantizers['model.decoder.layers.%d.%s' % (i, name)] = gptq[name].quantizer 298 | gptq[name].free() 299 | for j in range(self.args.nsamples): 300 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 301 | 302 | layers[i] = layer.cpu() 303 | del layer 304 | del gptq 305 | torch.cuda.empty_cache() 306 | 307 | inps, outs = outs, inps 308 | 309 | model.config.use_cache = use_cache 310 | 311 | return quantizers 312 | 313 | 314 | # for backwards compatibility 315 | OPT = OPTClass -------------------------------------------------------------------------------- /zeroShot/models/quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | import quant_cuda 6 | except: 7 | print('CUDA extension not installed.') 8 | 9 | 10 | def quantize(x, scale, zero, maxq): 11 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 12 | return scale * (q - zero) 13 | 14 | class Quantizer(nn.Module): 15 | 16 | def __init__(self, shape=1): 17 | super(Quantizer, self).__init__() 18 | self.register_buffer('maxq', torch.tensor(0)) 19 | self.register_buffer('scale', torch.zeros(shape)) 20 | self.register_buffer('zero', torch.zeros(shape)) 21 | 22 | def configure( 23 | self, 24 | bits, perchannel=False, sym=True, 25 | mse=False, norm=2.4, grid=100, maxshrink=.8 26 | ): 27 | self.maxq = torch.tensor(2 ** bits - 1) 28 | self.perchannel = perchannel 29 | self.sym = sym 30 | self.mse = mse 31 | self.norm = norm 32 | self.grid = grid 33 | self.maxshrink = maxshrink 34 | 35 | def find_params(self, x, weight=False): 36 | dev = x.device 37 | self.maxq = self.maxq.to(dev) 38 | 39 | shape = x.shape 40 | if self.perchannel: 41 | if weight: 42 | x = x.flatten(1) 43 | else: 44 | if len(shape) == 4: 45 | x = x.permute([1, 0, 2, 3]) 46 | x = x.flatten(1) 47 | if len(shape) == 3: 48 | x = x.reshape((-1, shape[-1])).t() 49 | if len(shape) == 2: 50 | x = x.t() 51 | else: 52 | x = x.flatten().unsqueeze(0) 53 | 54 | tmp = torch.zeros(x.shape[0], device=dev) 55 | xmin = torch.minimum(x.min(1)[0], tmp) 56 | xmax = torch.maximum(x.max(1)[0], tmp) 57 | 58 | if self.sym: 59 | xmax = torch.maximum(torch.abs(xmin), xmax) 60 | tmp = xmin < 0 61 | if torch.any(tmp): 62 | xmin[tmp] = -xmax[tmp] 63 | tmp = (xmin == 0) & (xmax == 0) 64 | xmin[tmp] = -1 65 | xmax[tmp] = +1 66 | 67 | self.scale = (xmax - xmin) / self.maxq 68 | if self.sym: 69 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 70 | else: 71 | self.zero = 
torch.round(-xmin / self.scale) 72 | 73 | if self.mse: 74 | best = torch.full([x.shape[0]], float('inf'), device=dev) 75 | for i in range(int(self.maxshrink * self.grid)): 76 | p = 1 - i / self.grid 77 | xmin1 = p * xmin 78 | xmax1 = p * xmax 79 | scale1 = (xmax1 - xmin1) / self.maxq 80 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 81 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 82 | q -= x 83 | q.abs_() 84 | q.pow_(self.norm) 85 | err = torch.sum(q, 1) 86 | tmp = err < best 87 | if torch.any(tmp): 88 | best[tmp] = err[tmp] 89 | self.scale[tmp] = scale1[tmp] 90 | self.zero[tmp] = zero1[tmp] 91 | if not self.perchannel: 92 | if weight: 93 | tmp = shape[0] 94 | else: 95 | tmp = shape[1] if len(shape) != 3 else shape[2] 96 | self.scale = self.scale.repeat(tmp) 97 | self.zero = self.zero.repeat(tmp) 98 | 99 | if weight: 100 | shape = [-1] + [1] * (len(shape) - 1) 101 | # self.scale = self.scale.unsqueeze(1) 102 | # self.zero = self.zero.unsqueeze(1) 103 | self.scale = self.scale.reshape(shape) 104 | self.zero = self.zero.reshape(shape) 105 | return 106 | if len(shape) == 4: 107 | self.scale = self.scale.reshape((1, -1, 1, 1)) 108 | self.zero = self.zero.reshape((1, -1, 1, 1)) 109 | if len(shape) == 3: 110 | self.scale = self.scale.reshape((1, 1, -1)) 111 | self.zero = self.zero.reshape((1, 1, -1)) 112 | if len(shape) == 2: 113 | self.scale = self.scale.unsqueeze(0) 114 | self.zero = self.zero.unsqueeze(0) 115 | 116 | def quantize(self, x): 117 | if self.ready(): 118 | return quantize(x, self.scale, self.zero, self.maxq) 119 | return x 120 | 121 | def enabled(self): 122 | return self.maxq > 0 123 | 124 | def ready(self): 125 | return torch.all(self.scale != 0) 126 | 127 | class ActQuantWrapper(nn.Module): 128 | 129 | def __init__(self, module): 130 | super(ActQuantWrapper, self).__init__() 131 | self.module = module 132 | shape = [1] * len(self.module.weight.shape) 133 | if len(shape) == 4: 134 | shape[1] = self.module.weight.shape[1] 135 | if len(shape) == 3: 136 | shape[2] = self.module.weight.shape[2] 137 | if len(shape) == 2: 138 | shape[1] = self.module.weight.shape[1] 139 | self.quantizer = Quantizer(shape=shape) 140 | 141 | def forward(self, x): 142 | return self.module(self.quantizer.quantize(x)) 143 | 144 | def add_actquant(module, name='', layers=[nn.Conv2d, nn.Linear]): 145 | if isinstance(module, ActQuantWrapper): 146 | return 147 | for attr in dir(module): 148 | tmp = getattr(module, attr) 149 | if type(tmp) in layers: 150 | setattr(module, attr, ActQuantWrapper(tmp)) 151 | if type(tmp) == nn.Sequential: 152 | replaced = [] 153 | for i, child in enumerate(tmp.children()): 154 | if type(child) in layers: 155 | replaced.append(ActQuantWrapper(child)) 156 | else: 157 | replaced.append(child) 158 | setattr(module, attr, nn.Sequential(*replaced)) 159 | if type(tmp) == torch.nn.ModuleList: 160 | replaced = [] 161 | for i, child in enumerate(tmp.children()): 162 | if type(child) in layers: 163 | replaced.append(ActQuantWrapper(child)) 164 | else: 165 | replaced.append(child) 166 | setattr(module, attr, nn.ModuleList(replaced)) 167 | for name1, child in module.named_children(): 168 | add_actquant(child, name + '.' 
+ name1 if name != '' else name1, layers) 169 | 170 | import time 171 | 172 | class Quant4Linear(nn.Module): 173 | 174 | def __init__(self, linear, scales, zeros): 175 | super().__init__() 176 | self.register_buffer('zeros', zeros.clone() * scales) 177 | self.register_buffer('scales', scales) 178 | self.register_buffer('bias', linear.bias.data) 179 | intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int) 180 | intweight = intweight.t().contiguous() 181 | self.register_buffer('qweight', torch.zeros( 182 | (intweight.shape[0] // 8, intweight.shape[1]), dtype=torch.int, device=self.bias.device 183 | )) 184 | for i in range(intweight.shape[0]): 185 | self.qweight[i // 8] |= intweight[i] << (4 * (i % 8)) 186 | # self.linear = linear.to(torch.device('cuda:0')) 187 | 188 | def forward(self, x): 189 | if x.shape[-1] == x.numel(): 190 | outshape = list(x.shape) 191 | y = self.bias.clone() 192 | outshape[-1] = self.bias.numel() 193 | quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.zeros) 194 | # y = self.linear(x) 195 | return y.reshape(outshape) 196 | print(x.shape) 197 | raise ValueError('Only supports a single token currently.') 198 | 199 | def make_quant4(module, quantizers, name=''): 200 | if isinstance(module, Quant4Linear): 201 | return 202 | for attr in dir(module): 203 | tmp = getattr(module, attr) 204 | name1 = name + '.' + attr if name != '' else attr 205 | if name1 in quantizers: 206 | setattr( 207 | module, attr, 208 | Quant4Linear(tmp, quantizers[name1].scale, quantizers[name1].zero) 209 | ) 210 | for name1, child in module.named_children(): 211 | make_quant4(child, quantizers, name + '.' + name1 if name != '' else name1) 212 | -------------------------------------------------------------------------------- /zeroShot/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | from typing import List, Union 3 | from .tasks_utils import Task 4 | from . import piqa 5 | from . import arc 6 | from . import superglue 7 | from .local_datasets import lambada as lambada_dataset 8 | from .lambada import LAMBADA 9 | from . import glue 10 | from . import storycloze 11 | 12 | # TODO: Add the rest of the results! 
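# Added usage note (illustrative; task names are the string keys registered below):
#   get_task_dict(["piqa", "arc_easy"])  ->  {"piqa": PiQA(), "arc_easy": ARCEasy()}
# get_task() prints the registry and raises KeyError for unknown task names.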
13 | ######################################## 14 | # All tasks 15 | ######################################## 16 | 17 | 18 | TASK_REGISTRY = { 19 | "lambada": LAMBADA, 20 | "piqa": piqa.PiQA, 21 | "arc_easy": arc.ARCEasy, 22 | "arc_challenge": arc.ARCChallenge, 23 | "boolq": superglue.BoolQ, 24 | "cb": superglue.CommitmentBank, 25 | "copa": superglue.Copa, 26 | "wic": superglue.WordsInContext, 27 | "multirc": superglue.MultiRC, 28 | "rte": glue.RTE, 29 | "record": superglue.ReCoRD, 30 | "wsc": superglue.SGWinogradSchemaChallenge, 31 | "storycloze": storycloze.StoryCloze2018 32 | } 33 | 34 | ALL_TASKS = sorted(list(TASK_REGISTRY)) 35 | 36 | 37 | def get_task(task_name): 38 | try: 39 | return TASK_REGISTRY[task_name] 40 | except KeyError: 41 | print("Available tasks:") 42 | pprint(TASK_REGISTRY) 43 | raise KeyError(f"Missing task {task_name}") 44 | 45 | 46 | def get_task_name_from_object(task_object): 47 | for name, class_ in TASK_REGISTRY.items(): 48 | if class_ is task_object: 49 | return name 50 | 51 | # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting 52 | return ( 53 | task_object.EVAL_HARNESS_NAME 54 | if hasattr(task_object, "EVAL_HARNESS_NAME") 55 | else type(task_object).__name__ 56 | ) 57 | 58 | 59 | def get_task_dict(task_name_list: List[Union[str, Task]]): 60 | task_name_dict = { 61 | task_name: get_task(task_name)() 62 | for task_name in task_name_list 63 | if isinstance(task_name, str) 64 | } 65 | task_name_from_object_dict = { 66 | get_task_name_from_object(task_object): task_object 67 | for task_object in task_name_list 68 | if not isinstance(task_object, str) 69 | } 70 | assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) 71 | return {**task_name_dict, **task_name_from_object_dict} 72 | -------------------------------------------------------------------------------- /zeroShot/tasks/arc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge 3 | https://arxiv.org/pdf/1803.05457.pdf 4 | 5 | The ARC dataset consists of 7,787 science exam questions drawn from a variety 6 | of sources, including science questions provided under license by a research 7 | partner affiliated with AI2. These are text-only, English language exam questions 8 | that span several grade levels as indicated in the files. Each question has a 9 | multiple choice structure (typically 4 answer options). The questions are sorted 10 | into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and 11 | a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions. 12 | 13 | Homepage: https://allenai.org/data/arc 14 | """ 15 | from .tasks_utils import MultipleChoiceTask 16 | 17 | 18 | _CITATION = """ 19 | @article{Clark2018ThinkYH, 20 | title={Think you have Solved Question Answering? 
Try ARC, the AI2 Reasoning Challenge}, 21 | author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord}, 22 | journal={ArXiv}, 23 | year={2018}, 24 | volume={abs/1803.05457} 25 | } 26 | """ 27 | 28 | 29 | class ARCEasy(MultipleChoiceTask): 30 | VERSION = 0 31 | DATASET_PATH = "ai2_arc" 32 | DATASET_NAME = "ARC-Easy" 33 | 34 | def has_training_docs(self): 35 | return True 36 | 37 | def has_validation_docs(self): 38 | return True 39 | 40 | def has_test_docs(self): 41 | return True 42 | 43 | def training_docs(self): 44 | if self._training_docs is None: 45 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 46 | return self._training_docs 47 | 48 | def validation_docs(self): 49 | return map(self._process_doc, self.dataset["validation"]) 50 | 51 | def test_docs(self): 52 | return map(self._process_doc, self.dataset["test"]) 53 | 54 | def _process_doc(self, doc): 55 | # NOTE: Some `doc["answerKey"]`s are in numeric string format being one 56 | # of {'1', '2', '3', '4', '5'}. We map them back to letters. 57 | num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"} 58 | doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"]) 59 | out_doc = { 60 | "id": doc["id"], 61 | "query": "Question: " + doc["question"] + "\nAnswer:", 62 | "choices": doc["choices"]["text"], 63 | "gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]), 64 | } 65 | return out_doc 66 | 67 | def doc_to_text(self, doc): 68 | return doc["query"] 69 | 70 | def should_decontaminate(self): 71 | return True 72 | 73 | def doc_to_decontamination_query(self, doc): 74 | return doc["query"] 75 | 76 | 77 | class ARCChallenge(ARCEasy): 78 | DATASET_PATH = "ai2_arc" 79 | DATASET_NAME = "ARC-Challenge" 80 | -------------------------------------------------------------------------------- /zeroShot/tasks/lambada.py: -------------------------------------------------------------------------------- 1 | """ 2 | The LAMBADA dataset: Word prediction requiring a broad discourse context∗ 3 | https://arxiv.org/pdf/1606.06031.pdf 4 | 5 | LAMBADA is a dataset to evaluate the capabilities of computational models for text 6 | understanding by means of a word prediction task. LAMBADA is a collection of narrative 7 | passages sharing the characteristic that human subjects are able to guess their last 8 | word if they are exposed to the whole passage, but not if they only see the last 9 | sentence preceding the target word. To succeed on LAMBADA, computational models 10 | cannot simply rely on local context, but must be able to keep track of information 11 | in the broader discourse. 
12 | 13 | Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI 14 | """ 15 | import inspect 16 | from .tasks_utils import Task, rf 17 | from .tasks_utils import mean, perplexity 18 | from .local_datasets import lambada 19 | 20 | _CITATION = """ 21 | @misc{ 22 | author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel}, 23 | title={The LAMBADA dataset}, 24 | DOI={10.5281/zenodo.2630551}, 25 | publisher={Zenodo}, 26 | year={2016}, 27 | month={Aug} 28 | } 29 | """ 30 | 31 | 32 | 33 | def preprocess(text): 34 | text = text.replace("“", '"') 35 | text = text.replace("”", '"') 36 | text = text.replace("''", '"') 37 | text = text.replace("``", '"') 38 | return '\n' + text.strip() 39 | 40 | 41 | class LAMBADA(Task): 42 | VERSION = 0 43 | DATASET_PATH = inspect.getfile(lambada.Lambada) 44 | 45 | 46 | def has_training_docs(self): 47 | return False 48 | 49 | def has_validation_docs(self): 50 | return True 51 | 52 | def has_test_docs(self): 53 | return False 54 | 55 | def training_docs(self): 56 | pass 57 | 58 | def validation_docs(self): 59 | return self.dataset["validation"] 60 | 61 | def test_docs(self): 62 | pass 63 | 64 | def doc_to_text(self, doc): 65 | return preprocess(doc["text"].strip()).rsplit(" ", 1)[0] 66 | 67 | def should_decontaminate(self): 68 | return True 69 | 70 | def doc_to_decontamination_query(self, doc): 71 | return doc["text"] 72 | 73 | def doc_to_target(self, doc): 74 | return " " + doc["text"].rsplit(" ", 1)[1] 75 | 76 | def construct_requests(self, doc, ctx): 77 | ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc)) 78 | 79 | return ll, is_greedy 80 | 81 | def process_results(self, doc, results): 82 | ll, is_greedy = results 83 | 84 | return {"ppl": ll, "acc": int(is_greedy)} 85 | 86 | def aggregation(self): 87 | return {"ppl": perplexity, "acc": mean} 88 | 89 | def higher_is_better(self): 90 | return {"ppl": False, "acc": True} 91 | -------------------------------------------------------------------------------- /zeroShot/tasks/local_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .lambada import lambada -------------------------------------------------------------------------------- /zeroShot/tasks/local_datasets/lambada/__init__.py: -------------------------------------------------------------------------------- 1 | from .lambada import Lambada -------------------------------------------------------------------------------- /zeroShot/tasks/local_datasets/lambada/dataset_infos.json: -------------------------------------------------------------------------------- 1 | {"original": {"description": "LAMBADA is a dataset to evaluate the capabilities of computational models for text\nunderstanding by means of a word prediction task. LAMBADA is a collection of narrative\ntexts sharing the characteristic that human subjects are able to guess their last\nword if they are exposed to the whole text, but not if they only see the last\nsentence preceding the target word. 
To succeed on LAMBADA, computational models\ncannot simply rely on local context, but must be able to keep track of information\nin the broader discourse.\n\nThe LAMBADA dataset", "citation": "@misc{\n author={Paperno, Denis and Kruszewski, Germ\u00e1n and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fern\u00e1ndez, Raquel}, \n title={The LAMBADA dataset},\n DOI={10.5281/zenodo.2630551},\n publisher={Zenodo},\n year={2016},\n month={Aug}\n}\n", "homepage": "https://zenodo.org/record/2630551#.X4Xzn5NKjUI", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "lambada", "config_name": "original", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"validation": {"name": "validation", "num_bytes": 1709449, "num_examples": 5153, "dataset_name": "lambada"}}, "download_checksums": {"http://eaidata.bmk.sh/data/lambada_test.jsonl": {"num_bytes": 1819752, "checksum": "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"}}, "download_size": 1819752, "post_processing_size": null, "dataset_size": 1709449, "size_in_bytes": 3529201}, "en": {"description": "LAMBADA is a dataset to evaluate the capabilities of computational models for text\nunderstanding by means of a word prediction task. LAMBADA is a collection of narrative\ntexts sharing the characteristic that human subjects are able to guess their last\nword if they are exposed to the whole text, but not if they only see the last\nsentence preceding the target word. To succeed on LAMBADA, computational models\ncannot simply rely on local context, but must be able to keep track of information\nin the broader discourse.\n\nThe English translated LAMBADA dataset", "citation": "@misc{\n author={Paperno, Denis and Kruszewski, Germ\u00e1n and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fern\u00e1ndez, Raquel}, \n title={The LAMBADA dataset},\n DOI={10.5281/zenodo.2630551},\n publisher={Zenodo},\n year={2016},\n month={Aug}\n}\n", "homepage": "https://zenodo.org/record/2630551#.X4Xzn5NKjUI", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "lambada", "config_name": "en", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"validation": {"name": "validation", "num_bytes": 1709449, "num_examples": 5153, "dataset_name": "lambada"}}, "download_checksums": {"http://eaidata.bmk.sh/data/lambada_test_en.jsonl": {"num_bytes": 1819752, "checksum": "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"}}, "download_size": 1819752, "post_processing_size": null, "dataset_size": 1709449, "size_in_bytes": 3529201}, "fr": {"description": "LAMBADA is a dataset to evaluate the capabilities of computational models for text\nunderstanding by means of a word prediction task. LAMBADA is a collection of narrative\ntexts sharing the characteristic that human subjects are able to guess their last\nword if they are exposed to the whole text, but not if they only see the last\nsentence preceding the target word. 
To succeed on LAMBADA, computational models\ncannot simply rely on local context, but must be able to keep track of information\nin the broader discourse.\n\nThe French translated LAMBADA dataset", "citation": "@misc{\n author={Paperno, Denis and Kruszewski, Germ\u00e1n and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fern\u00e1ndez, Raquel}, \n title={The LAMBADA dataset},\n DOI={10.5281/zenodo.2630551},\n publisher={Zenodo},\n year={2016},\n month={Aug}\n}\n", "homepage": "https://zenodo.org/record/2630551#.X4Xzn5NKjUI", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "lambada", "config_name": "fr", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"validation": {"name": "validation", "num_bytes": 1948795, "num_examples": 5153, "dataset_name": "lambada"}}, "download_checksums": {"http://eaidata.bmk.sh/data/lambada_test_fr.jsonl": {"num_bytes": 2028703, "checksum": "941ec6a73dba7dc91c860bf493eb66a527cd430148827a4753a4535a046bf362"}}, "download_size": 2028703, "post_processing_size": null, "dataset_size": 1948795, "size_in_bytes": 3977498}, "de": {"description": "LAMBADA is a dataset to evaluate the capabilities of computational models for text\nunderstanding by means of a word prediction task. LAMBADA is a collection of narrative\ntexts sharing the characteristic that human subjects are able to guess their last\nword if they are exposed to the whole text, but not if they only see the last\nsentence preceding the target word. To succeed on LAMBADA, computational models\ncannot simply rely on local context, but must be able to keep track of information\nin the broader discourse.\n\nThe German translated LAMBADA dataset", "citation": "@misc{\n author={Paperno, Denis and Kruszewski, Germ\u00e1n and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fern\u00e1ndez, Raquel}, \n title={The LAMBADA dataset},\n DOI={10.5281/zenodo.2630551},\n publisher={Zenodo},\n year={2016},\n month={Aug}\n}\n", "homepage": "https://zenodo.org/record/2630551#.X4Xzn5NKjUI", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "lambada", "config_name": "de", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"validation": {"name": "validation", "num_bytes": 1904576, "num_examples": 5153, "dataset_name": "lambada"}}, "download_checksums": {"http://eaidata.bmk.sh/data/lambada_test_de.jsonl": {"num_bytes": 1985231, "checksum": "51c6c1795894c46e88e4c104b5667f488efe79081fb34d746b82b8caa663865e"}}, "download_size": 1985231, "post_processing_size": null, "dataset_size": 1904576, "size_in_bytes": 3889807}, "it": {"description": "LAMBADA is a dataset to evaluate the capabilities of computational models for text\nunderstanding by means of a word prediction task. LAMBADA is a collection of narrative\ntexts sharing the characteristic that human subjects are able to guess their last\nword if they are exposed to the whole text, but not if they only see the last\nsentence preceding the target word. 
To succeed on LAMBADA, computational models\ncannot simply rely on local context, but must be able to keep track of information\nin the broader discourse.\n\nThe Italian translated LAMBADA dataset", "citation": "@misc{\n author={Paperno, Denis and Kruszewski, Germ\u00e1n and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fern\u00e1ndez, Raquel}, \n title={The LAMBADA dataset},\n DOI={10.5281/zenodo.2630551},\n publisher={Zenodo},\n year={2016},\n month={Aug}\n}\n", "homepage": "https://zenodo.org/record/2630551#.X4Xzn5NKjUI", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "lambada", "config_name": "it", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"validation": {"name": "validation", "num_bytes": 1813420, "num_examples": 5153, "dataset_name": "lambada"}}, "download_checksums": {"http://eaidata.bmk.sh/data/lambada_test_it.jsonl": {"num_bytes": 1894613, "checksum": "86654237716702ab74f42855ae5a78455c1b0e50054a4593fb9c6fcf7fad0850"}}, "download_size": 1894613, "post_processing_size": null, "dataset_size": 1813420, "size_in_bytes": 3708033}, "es": {"description": "LAMBADA is a dataset to evaluate the capabilities of computational models for text\nunderstanding by means of a word prediction task. LAMBADA is a collection of narrative\ntexts sharing the characteristic that human subjects are able to guess their last\nword if they are exposed to the whole text, but not if they only see the last\nsentence preceding the target word. To succeed on LAMBADA, computational models\ncannot simply rely on local context, but must be able to keep track of information\nin the broader discourse.\n\nThe Spanish translated LAMBADA dataset", "citation": "@misc{\n author={Paperno, Denis and Kruszewski, Germ\u00e1n and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fern\u00e1ndez, Raquel}, \n title={The LAMBADA dataset},\n DOI={10.5281/zenodo.2630551},\n publisher={Zenodo},\n year={2016},\n month={Aug}\n}\n", "homepage": "https://zenodo.org/record/2630551#.X4Xzn5NKjUI", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "lambada", "config_name": "es", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"validation": {"name": "validation", "num_bytes": 1821735, "num_examples": 5153, "dataset_name": "lambada"}}, "download_checksums": {"http://eaidata.bmk.sh/data/lambada_test_es.jsonl": {"num_bytes": 1902349, "checksum": "ffd760026c647fb43c67ce1bc56fd527937304b348712dce33190ea6caba6f9c"}}, "download_size": 1902349, "post_processing_size": null, "dataset_size": 1821735, "size_in_bytes": 3724084}} 2 | -------------------------------------------------------------------------------- /zeroShot/tasks/local_datasets/lambada/lambada.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # TODO: Address all TODOs and remove all explanatory comments 15 | """LAMBADA dataset.""" 16 | 17 | 18 | import json 19 | 20 | import datasets 21 | 22 | 23 | _CITATION = """\ 24 | @misc{ 25 | author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel}, 26 | title={The LAMBADA dataset}, 27 | DOI={10.5281/zenodo.2630551}, 28 | publisher={Zenodo}, 29 | year={2016}, 30 | month={Aug} 31 | } 32 | """ 33 | 34 | _DESCRIPTION = """\ 35 | LAMBADA is a dataset to evaluate the capabilities of computational models for text 36 | understanding by means of a word prediction task. LAMBADA is a collection of narrative 37 | texts sharing the characteristic that human subjects are able to guess their last 38 | word if they are exposed to the whole text, but not if they only see the last 39 | sentence preceding the target word. To succeed on LAMBADA, computational models 40 | cannot simply rely on local context, but must be able to keep track of information 41 | in the broader discourse. 42 | """ 43 | 44 | _HOMEPAGE = "https://zenodo.org/record/2630551#.X4Xzn5NKjUI" 45 | 46 | # TODO: Add the licence for the dataset here if you can find it 47 | _LICENSE = "" 48 | 49 | _URLS = { 50 | "original": "http://eaidata.bmk.sh/data/lambada_test.jsonl", 51 | "en": "http://eaidata.bmk.sh/data/lambada_test_en.jsonl", 52 | "fr": "http://eaidata.bmk.sh/data/lambada_test_fr.jsonl", 53 | "de": "http://eaidata.bmk.sh/data/lambada_test_de.jsonl", 54 | "it": "http://eaidata.bmk.sh/data/lambada_test_it.jsonl", 55 | "es": "http://eaidata.bmk.sh/data/lambada_test_es.jsonl", 56 | } 57 | 58 | 59 | class Lambada(datasets.GeneratorBasedBuilder): 60 | """LAMBADA is a dataset to evaluate the capabilities of computational models for text understanding by means of a word prediction task.""" 61 | 62 | VERSION = datasets.Version("0.0.1") 63 | 64 | BUILDER_CONFIGS = [ 65 | datasets.BuilderConfig( 66 | name="original", version=VERSION, description="The LAMBADA dataset" 67 | ), 68 | datasets.BuilderConfig( 69 | name="en", 70 | version=VERSION, 71 | description="The English translated LAMBADA dataset", 72 | ), 73 | datasets.BuilderConfig( 74 | name="fr", 75 | version=VERSION, 76 | description="The French translated LAMBADA dataset", 77 | ), 78 | datasets.BuilderConfig( 79 | name="de", 80 | version=VERSION, 81 | description="The German translated LAMBADA dataset", 82 | ), 83 | datasets.BuilderConfig( 84 | name="it", 85 | version=VERSION, 86 | description="The Italian translated LAMBADA dataset", 87 | ), 88 | datasets.BuilderConfig( 89 | name="es", 90 | version=VERSION, 91 | description="The Spanish translated LAMBADA dataset", 92 | ), 93 | ] 94 | 95 | DEFAULT_CONFIG_NAME = "original" 96 | 97 | def _info(self): 98 | features = datasets.Features( 99 | { 100 | "text": datasets.Value("string"), 101 | } 102 | ) 103 | return datasets.DatasetInfo( 104 | description=f"{_DESCRIPTION}\n{self.config.description}", 105 | features=features, 106 | homepage=_HOMEPAGE, 107 | 
license=_LICENSE, 108 | citation=_CITATION, 109 | ) 110 | 111 | def _split_generators(self, dl_manager): 112 | urls = _URLS[self.config.name] 113 | data_dir = dl_manager.download_and_extract(urls) 114 | return [ 115 | datasets.SplitGenerator( 116 | name=datasets.Split.VALIDATION, 117 | # These kwargs will be passed to _generate_examples 118 | gen_kwargs={ 119 | "filepath": data_dir, 120 | "split": "validation", 121 | }, 122 | ), 123 | ] 124 | 125 | # method parameters are unpacked from `gen_kwargs` as given in `_split_generators` 126 | def _generate_examples(self, filepath, split): 127 | with open(filepath, encoding="utf-8") as f: 128 | for key, row in enumerate(f): 129 | data = json.loads(row) 130 | yield key, {"text": data["text"]} 131 | -------------------------------------------------------------------------------- /zeroShot/tasks/piqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | PIQA: Reasoning about Physical Commonsense in Natural Language 3 | https://arxiv.org/pdf/1911.11641.pdf 4 | 5 | Physical Interaction: Question Answering (PIQA) is a physical commonsense 6 | reasoning and a corresponding benchmark dataset. PIQA was designed to investigate 7 | the physical knowledge of existing models. To what extent are current approaches 8 | actually learning about the world? 9 | 10 | Homepage: https://yonatanbisk.com/piqa/ 11 | """ 12 | from .tasks_utils import MultipleChoiceTask 13 | 14 | 15 | _CITATION = """ 16 | @inproceedings{Bisk2020, 17 | author = {Yonatan Bisk and Rowan Zellers and 18 | Ronan Le Bras and Jianfeng Gao 19 | and Yejin Choi}, 20 | title = {PIQA: Reasoning about Physical Commonsense in 21 | Natural Language}, 22 | booktitle = {Thirty-Fourth AAAI Conference on 23 | Artificial Intelligence}, 24 | year = {2020}, 25 | } 26 | """ 27 | 28 | 29 | class PiQA(MultipleChoiceTask): 30 | VERSION = 0 31 | DATASET_PATH = "piqa" 32 | DATASET_NAME = None 33 | 34 | def has_training_docs(self): 35 | return True 36 | 37 | def has_validation_docs(self): 38 | return True 39 | 40 | def has_test_docs(self): 41 | return False 42 | 43 | def training_docs(self): 44 | if self._training_docs is None: 45 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 46 | return self._training_docs 47 | 48 | def validation_docs(self): 49 | return map(self._process_doc, self.dataset["validation"]) 50 | 51 | def _process_doc(self, doc): 52 | out_doc = { 53 | "goal": doc["goal"], 54 | "choices": [doc["sol1"], doc["sol2"]], 55 | "gold": doc["label"], 56 | } 57 | return out_doc 58 | 59 | def doc_to_text(self, doc): 60 | return "Question: " + doc["goal"] + "\nAnswer:" 61 | 62 | def should_decontaminate(self): 63 | return True 64 | 65 | def doc_to_decontamination_query(self, doc): 66 | return doc["goal"] 67 | -------------------------------------------------------------------------------- /zeroShot/tasks/storycloze.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Corpus and Cloze Evaluation for Deeper Understanding of Commonsense Stories 3 | https://arxiv.org/pdf/1604.01696.pdf 4 | 'Story Cloze Test' (2018) is a commonsense reasoning framework for evaluating story 5 | understanding, story generation, and script learning. This test requires a system 6 | to choose the correct ending to a four-sentence story. 
7 | Homepage: https://cs.rochester.edu/nlp/rocstories/ 8 | """ 9 | import numpy as np 10 | from .tasks_utils import Task, rf 11 | from .tasks_utils import mean 12 | 13 | _CITATION = """ 14 | @inproceedings{sharma-etal-2018-tackling, 15 | title = "Tackling the Story Ending Biases in The Story Cloze Test", 16 | author = "Sharma, Rishi and 17 | Allen, James and 18 | Bakhshandeh, Omid and 19 | Mostafazadeh, Nasrin", 20 | booktitle = "Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)", 21 | month = jul, 22 | year = "2018", 23 | address = "Melbourne, Australia", 24 | publisher = "Association for Computational Linguistics", 25 | url = "https://aclanthology.org/P18-2119", 26 | doi = "10.18653/v1/P18-2119", 27 | pages = "752--757", 28 | abstract = "The Story Cloze Test (SCT) is a recent framework for evaluating story comprehension and script learning. There have been a variety of models tackling the SCT so far. Although the original goal behind the SCT was to require systems to perform deep language understanding and commonsense reasoning for successful narrative understanding, some recent models could perform significantly better than the initial baselines by leveraging human-authorship biases discovered in the SCT dataset. In order to shed some light on this issue, we have performed various data analysis and analyzed a variety of top performing models presented for this task. Given the statistics we have aggregated, we have designed a new crowdsourcing scheme that creates a new SCT dataset, which overcomes some of the biases. We benchmark a few models on the new dataset and show that the top-performing model on the original SCT dataset fails to keep up its performance. Our findings further signify the importance of benchmarking NLP systems on various evolving test sets.", 29 | } 30 | """ 31 | 32 | 33 | class StoryCloze(Task): 34 | VERSION = 0 35 | DATASET_PATH = "story_cloze" 36 | DATASET_NAME = None 37 | 38 | def __init__(self, data_dir: str='tasks/local_datasets/storyCloze2018'): 39 | 40 | """ 41 | StoryCloze is not publicly available. You must download the data by 42 | following https://cs.rochester.edu/nlp/rocstories/ and pass the folder 43 | path into the `data_dir` arg. 44 | """ 45 | print("PLEASE MAKE SURE TO FILL THIS FORM BEFORE USING THE DATASET: https://cs.rochester.edu/nlp/rocstories/") 46 | super().__init__(data_dir=data_dir) 47 | 48 | def has_training_docs(self): 49 | return False 50 | 51 | def has_validation_docs(self): 52 | return True 53 | 54 | def has_test_docs(self): 55 | return False 56 | 57 | def training_docs(self): 58 | pass 59 | 60 | def validation_docs(self): 61 | return self.dataset["validation"] 62 | 63 | def test_docs(self): 64 | return self.dataset["test"] 65 | 66 | def doc_to_text(self, doc): 67 | return " ".join( 68 | [ 69 | doc["input_sentence_1"], 70 | doc["input_sentence_2"], 71 | doc["input_sentence_3"], 72 | doc["input_sentence_4"], 73 | ] 74 | ) 75 | 76 | def should_decontaminate(self): 77 | return True 78 | 79 | def doc_to_decontamination_query(self, doc): 80 | return " ".join( 81 | [ 82 | doc["input_sentence_1"], 83 | doc["input_sentence_2"], 84 | doc["input_sentence_3"], 85 | doc["input_sentence_4"], 86 | ] 87 | ) 88 | 89 | def doc_to_target(self, doc): 90 | clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]] 91 | # `- 1` because the `answer_right_ending` index is 1-based. 
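# (hypothetical illustration: answer_right_ending == 2 selects clozes[1], i.e. sentence_quiz2)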
92 | return " " + clozes[doc["answer_right_ending"] - 1] 93 | 94 | def construct_requests(self, doc, ctx): 95 | """Uses RequestFactory to construct Requests and returns an iterable of 96 | Requests which will be sent to the LM. 97 | :param doc: 98 | The document as returned from training_docs, validation_docs, or test_docs. 99 | :param ctx: str 100 | The context string, generated by fewshot_context. This includes the natural 101 | language description, as well as the few shot examples, and the question 102 | part of the document for `doc`. 103 | """ 104 | clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]] 105 | lls = [rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in clozes] 106 | return lls 107 | 108 | def process_results(self, doc, results): 109 | """Take a single document and the LM results and evaluates, returning a 110 | dict where keys are the names of submetrics and values are the values of 111 | the metric for that one document 112 | :param doc: 113 | The document as returned from training_docs, validation_docs, or test_docs. 114 | :param results: 115 | The results of the requests created in construct_requests. 116 | """ 117 | gold = doc["answer_right_ending"] - 1 118 | acc = 1.0 if np.argmax(results) == gold else 0.0 119 | return {"acc": acc} 120 | 121 | def aggregation(self): 122 | """ 123 | :returns: {str: [float] -> float} 124 | A dictionary where keys are the names of submetrics and values are 125 | functions that aggregate a list of metrics 126 | """ 127 | return {"acc": mean} 128 | 129 | def higher_is_better(self): 130 | """ 131 | :returns: {str: bool} 132 | A dictionary where keys are the names of submetrics and values are 133 | whether a higher value of the submetric is better 134 | """ 135 | return {"acc": True} 136 | 137 | 138 | class StoryCloze2016(StoryCloze): 139 | DATASET_NAME = "2016" 140 | 141 | 142 | class StoryCloze2018(StoryCloze): 143 | DATASET_NAME = "2018" -------------------------------------------------------------------------------- /zeroShot/tasks/superglue.py: -------------------------------------------------------------------------------- 1 | """ 2 | SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems 3 | https://w4ngatang.github.io/static/papers/superglue.pdf 4 | 5 | SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language 6 | understanding tasks. 7 | 8 | Homepage: https://super.gluebenchmark.com/ 9 | 10 | TODO: WSC requires free-form generation. 11 | """ 12 | import numpy as np 13 | import sklearn 14 | import re 15 | import transformers.data.metrics.squad_metrics as squad_metrics 16 | from .tasks_utils import Task, rf 17 | from .tasks_utils import mean, acc_all, metric_max_over_ground_truths, yesno 18 | 19 | 20 | def general_detokenize(string): 21 | string = string.replace(" n't", "n't") 22 | string = string.replace(" )", ")") 23 | string = string.replace("( ", "(") 24 | string = string.replace('" ', '"') 25 | string = string.replace(' "', '"') 26 | string = re.sub(r" (['.,])", r"\1", string) 27 | return string 28 | 29 | _CITATION = """ 30 | @inproceedings{NEURIPS2019_4496bf24, 31 | author = {Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel}, 32 | booktitle = {Advances in Neural Information Processing Systems}, 33 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. 
Garnett}, 34 | pages = {}, 35 | publisher = {Curran Associates, Inc.}, 36 | title = {SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems}, 37 | url = {https://proceedings.neurips.cc/paper/2019/file/4496bf24afe7fab6f046bf4923da8de6-Paper.pdf}, 38 | volume = {32}, 39 | year = {2019} 40 | } 41 | """ 42 | 43 | 44 | class BoolQ(Task): 45 | VERSION = 1 46 | DATASET_PATH = "super_glue" 47 | DATASET_NAME = "boolq" 48 | 49 | def has_training_docs(self): 50 | return True 51 | 52 | def has_validation_docs(self): 53 | return True 54 | 55 | def has_test_docs(self): 56 | return False 57 | 58 | def training_docs(self): 59 | if self._training_docs is None: 60 | self._training_docs = list(self.dataset["train"]) 61 | return self._training_docs 62 | 63 | def validation_docs(self): 64 | return self.dataset["validation"] 65 | 66 | def doc_to_text(self, doc): 67 | return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:" 68 | 69 | def should_decontaminate(self): 70 | return True 71 | 72 | def doc_to_decontamination_query(self, doc): 73 | return doc["passage"] 74 | 75 | def doc_to_target(self, doc): 76 | return " " + yesno(doc["label"]) 77 | 78 | def construct_requests(self, doc, ctx): 79 | 80 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 81 | ll_no, _ = rf.loglikelihood(ctx, " no") 82 | 83 | return ll_yes, ll_no 84 | 85 | def process_results(self, doc, results): 86 | ll_yes, ll_no = results 87 | gold = doc["label"] 88 | 89 | acc = 1.0 if (ll_yes > ll_no) == gold else 0.0 90 | 91 | return {"acc": acc} 92 | 93 | def higher_is_better(self): 94 | return {"acc": True} 95 | 96 | def aggregation(self): 97 | return {"acc": mean} 98 | 99 | 100 | class CommitmentBank(Task): 101 | VERSION = 1 102 | DATASET_PATH = "super_glue" 103 | DATASET_NAME = "cb" 104 | 105 | def has_training_docs(self): 106 | return True 107 | 108 | def has_validation_docs(self): 109 | return True 110 | 111 | def has_test_docs(self): 112 | return False 113 | 114 | def training_docs(self): 115 | if self._training_docs is None: 116 | self._training_docs = list(self.dataset["train"]) 117 | return self._training_docs 118 | 119 | def validation_docs(self): 120 | return self.dataset["validation"] 121 | 122 | def doc_to_text(self, doc): 123 | return "{}\nQuestion: {}. 
True, False or Neither?\nAnswer:".format( 124 | doc["premise"], 125 | doc["hypothesis"], 126 | ) 127 | 128 | def doc_to_target(self, doc): 129 | # True = entailment 130 | # False = contradiction 131 | # Neither = neutral 132 | return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]]) 133 | 134 | def construct_requests(self, doc, ctx): 135 | ll_true, _ = rf.loglikelihood(ctx, " True") 136 | ll_false, _ = rf.loglikelihood(ctx, " False") 137 | ll_neither, _ = rf.loglikelihood(ctx, " Neither") 138 | 139 | return ll_true, ll_false, ll_neither 140 | 141 | def process_results(self, doc, results): 142 | gold = doc["label"] 143 | pred = np.argmax(results) 144 | acc = 1.0 if pred == gold else 0.0 145 | 146 | return {"acc": acc, "f1": (pred, gold)} 147 | 148 | def higher_is_better(self): 149 | return {"acc": True, "f1": True} 150 | 151 | @classmethod 152 | def cb_multi_fi(cls, items): 153 | preds, golds = zip(*items) 154 | preds = np.array(preds) 155 | golds = np.array(golds) 156 | f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0) 157 | f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1) 158 | f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2) 159 | avg_f1 = mean([f11, f12, f13]) 160 | return avg_f1 161 | 162 | def aggregation(self): 163 | return { 164 | "acc": mean, 165 | "f1": self.cb_multi_fi, 166 | } 167 | 168 | 169 | class Copa(Task): 170 | VERSION = 0 171 | DATASET_PATH = "super_glue" 172 | DATASET_NAME = "copa" 173 | 174 | def has_training_docs(self): 175 | return True 176 | 177 | def has_validation_docs(self): 178 | return True 179 | 180 | def has_test_docs(self): 181 | return False 182 | 183 | def training_docs(self): 184 | if self._training_docs is None: 185 | self._training_docs = list(self.dataset["train"]) 186 | return self._training_docs 187 | 188 | def validation_docs(self): 189 | return self.dataset["validation"] 190 | 191 | def doc_to_text(self, doc): 192 | # Drop the period 193 | connector = { 194 | "cause": "because", 195 | "effect": "therefore", 196 | }[doc["question"]] 197 | return doc["premise"].strip()[:-1] + f" {connector}" 198 | 199 | def doc_to_target(self, doc): 200 | correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"] 201 | # Connect the sentences 202 | return " " + self.convert_choice(correct_choice) 203 | 204 | def construct_requests(self, doc, ctx): 205 | choice1 = " " + self.convert_choice(doc["choice1"]) 206 | choice2 = " " + self.convert_choice(doc["choice2"]) 207 | 208 | ll_choice1, _ = rf.loglikelihood(ctx, choice1) 209 | ll_choice2, _ = rf.loglikelihood(ctx, choice2) 210 | 211 | return ll_choice1, ll_choice2 212 | 213 | def process_results(self, doc, results): 214 | gold = doc["label"] 215 | pred = np.argmax(results) 216 | acc = 1.0 if pred == gold else 0.0 217 | 218 | return {"acc": acc} 219 | 220 | def higher_is_better(self): 221 | return {"acc": True} 222 | 223 | def aggregation(self): 224 | return {"acc": mean} 225 | 226 | @staticmethod 227 | def convert_choice(choice): 228 | return choice[0].lower() + choice[1:] 229 | 230 | 231 | class MultiRC(Task): 232 | VERSION = 1 233 | DATASET_PATH = "super_glue" 234 | DATASET_NAME = "multirc" 235 | 236 | def has_training_docs(self): 237 | return True 238 | 239 | def has_validation_docs(self): 240 | return True 241 | 242 | def has_test_docs(self): 243 | return False 244 | 245 | def training_docs(self): 246 | if self._training_docs is None: 247 | self._training_docs = list(self.dataset["train"]) 248 | return self._training_docs 249 | 
250 | def validation_docs(self): 251 | return self.dataset["validation"] 252 | 253 | def doc_to_text(self, doc): 254 | return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:" 255 | 256 | def doc_to_target(self, doc): 257 | return " " + self.format_answer(answer=doc["answer"], label=doc["label"]) 258 | 259 | @staticmethod 260 | def format_answer(answer, label): 261 | label_str = "yes" if label else "no" 262 | return f"{answer}\nIs the answer correct? {label_str}" 263 | 264 | def construct_requests(self, doc, ctx): 265 | true_choice = self.format_answer(answer=doc["answer"], label=True) 266 | false_choice = self.format_answer(answer=doc["answer"], label=False) 267 | 268 | ll_true_choice, _ = rf.loglikelihood(ctx, f" {true_choice}") 269 | ll_false_choice, _ = rf.loglikelihood(ctx, f" {false_choice}") 270 | 271 | return ll_true_choice, ll_false_choice 272 | 273 | def process_results(self, doc, results): 274 | ll_true_choice, ll_false_choice = results 275 | pred = ll_true_choice > ll_false_choice 276 | return {"acc": (pred, doc)} 277 | 278 | def higher_is_better(self): 279 | return {"acc": True} 280 | 281 | def aggregation(self): 282 | return {"acc": acc_all} 283 | 284 | 285 | class ReCoRD(Task): 286 | VERSION = 0 287 | DATASET_PATH = "super_glue" 288 | DATASET_NAME = "record" 289 | 290 | def has_training_docs(self): 291 | return True 292 | 293 | def has_validation_docs(self): 294 | return True 295 | 296 | def has_test_docs(self): 297 | return False 298 | 299 | def training_docs(self): 300 | # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing. 301 | # Each doc consists of multiple answer candidates, each of which is scored yes/no. 302 | if self._training_docs is None: 303 | self._training_docs = [] 304 | for doc in self.dataset["train"]: 305 | self._training_docs.append(self._process_doc(doc)) 306 | return self._training_docs 307 | 308 | def validation_docs(self): 309 | # See: training_docs 310 | for doc in self.dataset["validation"]: 311 | yield self._process_doc(doc) 312 | 313 | @classmethod 314 | def _process_doc(cls, doc): 315 | return { 316 | "passage": doc["passage"], 317 | "query": doc["query"], 318 | "entities": sorted(list(set(doc["entities"]))), 319 | "answers": sorted(list(set(doc["answers"]))), 320 | } 321 | 322 | def doc_to_text(self, doc): 323 | initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") 324 | text = initial_text + "\n\n" 325 | for highlight in highlights: 326 | text += f" - {highlight}.\n" 327 | return text 328 | 329 | @classmethod 330 | def format_answer(cls, query, entity): 331 | return f" - {query}".replace("@placeholder", entity) 332 | 333 | def doc_to_target(self, doc): 334 | # We only output the first correct entity in a doc 335 | return self.format_answer(query=doc["query"], entity=doc["answers"][0]) 336 | 337 | def construct_requests(self, doc, ctx): 338 | requests = [ 339 | rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity)) 340 | for entity in doc["entities"] 341 | ] 342 | return requests 343 | 344 | def process_results(self, doc, results): 345 | # ReCoRD's evaluation is actually deceptively simple: 346 | # - Pick the maximum likelihood prediction entity 347 | # - Evaluate the accuracy and token F1 PER EXAMPLE 348 | # - Average over all examples 349 | max_idx = np.argmax(np.array([result[0] for result in results])) 350 | 351 | prediction = doc["entities"][max_idx] 352 | gold_label_set = doc["answers"] 353 | f1 = metric_max_over_ground_truths( 354 | 
squad_metrics.compute_f1, prediction, gold_label_set 355 | ) 356 | em = metric_max_over_ground_truths( 357 | squad_metrics.compute_exact, prediction, gold_label_set 358 | ) 359 | 360 | return { 361 | "f1": f1, 362 | "em": em, 363 | } 364 | 365 | def higher_is_better(self): 366 | return { 367 | "f1": True, 368 | "em": True, 369 | } 370 | 371 | def aggregation(self): 372 | return { 373 | "f1": mean, 374 | "em": mean, 375 | } 376 | 377 | 378 | class WordsInContext(Task): 379 | VERSION = 0 380 | DATASET_PATH = "super_glue" 381 | DATASET_NAME = "wic" 382 | 383 | def has_training_docs(self): 384 | return True 385 | 386 | def has_validation_docs(self): 387 | return True 388 | 389 | def has_test_docs(self): 390 | return False 391 | 392 | def training_docs(self): 393 | if self._training_docs is None: 394 | self._training_docs = list(self.dataset["train"]) 395 | return self._training_docs 396 | 397 | def validation_docs(self): 398 | return self.dataset["validation"] 399 | 400 | def doc_to_text(self, doc): 401 | return ( 402 | "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" 403 | " two sentences above?\nAnswer:".format( 404 | doc["sentence1"], 405 | doc["sentence2"], 406 | doc["sentence1"][doc["start1"] : doc["end1"]], 407 | ) 408 | ) 409 | 410 | def doc_to_target(self, doc): 411 | return " {}".format({0: "no", 1: "yes"}[doc["label"]]) 412 | 413 | def construct_requests(self, doc, ctx): 414 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 415 | ll_no, _ = rf.loglikelihood(ctx, " no") 416 | 417 | return ll_yes, ll_no 418 | 419 | def process_results(self, doc, results): 420 | ll_yes, ll_no = results 421 | gold = doc["label"] 422 | 423 | acc = 1.0 if (ll_yes > ll_no) == gold else 0.0 424 | 425 | return {"acc": acc} 426 | 427 | def higher_is_better(self): 428 | return {"acc": True} 429 | 430 | def aggregation(self): 431 | return {"acc": mean} 432 | 433 | 434 | class SGWinogradSchemaChallenge(Task): 435 | VERSION = 0 436 | # Note: This implementation differs from Fig G.32 because this is the SuperGLUE, 437 | # binary version of the task. 438 | DATASET_PATH = "super_glue" 439 | DATASET_NAME = "wsc" 440 | 441 | def has_training_docs(self): 442 | return True 443 | 444 | def has_validation_docs(self): 445 | return True 446 | 447 | def has_test_docs(self): 448 | return False 449 | 450 | def training_docs(self): 451 | if self.has_training_docs(): 452 | if self._training_docs is None: 453 | # GPT-3 Paper's format only uses positive examples for fewshot "training" 454 | self._training_docs = [ 455 | doc for doc in self.dataset["train"] if doc["label"] 456 | ] 457 | return self._training_docs 458 | 459 | def validation_docs(self): 460 | return self.dataset["validation"] 461 | 462 | def doc_to_text(self, doc): 463 | raw_passage = doc["text"] 464 | # NOTE: HuggingFace span indices are word-based not character-based. 
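# (hypothetical illustration: with span2_index == 3 and span2_text == "he", `pre` is the
# first three words of the passage and the pronoun is re-inserted after it wrapped in
# asterisks, i.e. pre + " *he*" + post, before detokenization)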
465 | pre = " ".join(raw_passage.split()[: doc["span2_index"]]) 466 | post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :] 467 | passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post) 468 | noun = doc["span1_text"] 469 | pronoun = doc["span2_text"] 470 | text = ( 471 | f"Passage: {passage}\n" 472 | + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n' 473 | + "Answer:" 474 | ) 475 | return text 476 | 477 | def doc_to_target(self, doc): 478 | return " " + yesno(doc["label"]) 479 | 480 | def construct_requests(self, doc, ctx): 481 | 482 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 483 | ll_no, _ = rf.loglikelihood(ctx, " no") 484 | 485 | return ll_yes, ll_no 486 | 487 | def process_results(self, doc, results): 488 | ll_yes, ll_no = results 489 | gold = doc["label"] 490 | 491 | acc = 1.0 if (ll_yes > ll_no) == gold else 0.0 492 | 493 | return {"acc": acc} 494 | 495 | def higher_is_better(self): 496 | return {"acc": True} 497 | 498 | def aggregation(self): 499 | return {"acc": mean} 500 | -------------------------------------------------------------------------------- /zeroShot/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import fnmatch 3 | import tasks 4 | import inspect 5 | import functools 6 | 7 | 8 | def positional_deprecated(fn): 9 | """ 10 | A decorator to nudge users into passing only keyword args (`kwargs`) to the 11 | wrapped function, `fn`. 12 | """ 13 | 14 | @functools.wraps(fn) 15 | def _wrapper(*args, **kwargs): 16 | if len(args) != 1 if inspect.ismethod(fn) else 0: 17 | print( 18 | f"WARNING: using {fn.__name__} with positional arguments is " 19 | "deprecated and will be disallowed in a future version of " 20 | "lm-evaluation-harness!" 21 | ) 22 | return fn(*args, **kwargs) 23 | 24 | return _wrapper 25 | 26 | 27 | class MultiChoice: 28 | def __init__(self, choices): 29 | self.choices = choices 30 | 31 | # Simple wildcard support (linux filename patterns) 32 | def __contains__(self, values): 33 | for value in values.split(","): 34 | if len(fnmatch.filter(self.choices, value)) == 0: 35 | return False 36 | 37 | return True 38 | 39 | def __iter__(self): 40 | for choice in self.choices: 41 | yield choice 42 | 43 | 44 | # Returns a list containing all values of the source_list that 45 | # match at least one of the patterns 46 | def pattern_match(patterns, source_list): 47 | task_names = set() 48 | for pattern in patterns: 49 | for matching in fnmatch.filter(source_list, pattern): 50 | task_names.add(matching) 51 | return list(task_names) 52 | 53 | 54 | def parse_args(): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument( 57 | 'model', type=str, 58 | help='For OPT model to load; pass `facebook/opt-X`.\\ BLOOM model to load; pass `bigscience/bloom-X`' 59 | ) 60 | parser.add_argument( 61 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], 62 | help='Where to extract calibration data from.' 63 | ) 64 | parser.add_argument( 65 | '--nsamples', type=int, default=128, 66 | help='Number of calibration data samples.' 67 | ) 68 | parser.add_argument( 69 | '--percdamp', type=float, default=.01, 70 | help='Percent of the average Hessian diagonal to use for dampening.' 71 | ) 72 | parser.add_argument( 73 | '--groupsize', type=int, default=-1, 74 | help='Groupsize to use for quantization; default uses full row.' 75 | ) 76 | parser.add_argument( 77 | '--seed', 78 | type=int, default=2, help='Seed for sampling the calibration data.' 
79 | ) 80 | parser.add_argument( 81 | '--table_results', action="store_true", help='Print results in a table.' 82 | ) 83 | 84 | parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS)) 85 | parser.add_argument("--num_fewshot", type=int, default=0) 86 | parser.add_argument("--output_path", default=None) 87 | parser.add_argument("--wbits", type=int, default=32) 88 | parser.add_argument("--nearest", action="store_true") 89 | parser.add_argument('--load', type=str, default='') 90 | 91 | args = parser.parse_args() 92 | args.batch_size = 1 # BS=1 is used for zeroShot tasks! 93 | 94 | return args 95 | --------------------------------------------------------------------------------
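Note: a minimal usage sketch, inferred from the argument parser in zeroShot/utils.py above (see zeroShot/README.md for the authoritative instructions). Assuming zeroShot/main.py consumes these arguments and that `piqa` is one of the task names registered in tasks.ALL_TASKS, a 4-bit zero-shot run over an OPT checkpoint might look like:

    python main.py facebook/opt-125m c4 --wbits 4 --tasks piqa --table_results

The two positional arguments are the Hugging Face model identifier and the calibration-data source (wikitext2, ptb, or c4); --wbits sets the quantization bit-width (default 32, i.e. full precision), and --table_results prints the scores as a table.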