├── .gitignore ├── LICENSE ├── README.md ├── bloom.py ├── datautils.py ├── demo.ipynb ├── llama.py ├── modelutils.py ├── opt.py ├── quant.py └── sparsegpt.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SparseGPT 2 | 3 | This repository contains code to reproduce the key results of the paper [SparseGPT: Massive Language Models Can be Accurately Pruned in One-shot](https://arxiv.org/abs/2301.00774). 4 | 5 | Specifically, it provides scripts and implementations to: 6 | 7 | * Evaluate baseline and pruned models on raw-WikiText2, PTB and C4-subset. 
(`datautils.py`, `opt.py`, `bloom.py`) 8 | * Perform unstructured, n:m and sparse + quantized SparseGPT compression on OPT and BLOOM models. (`sparsegpt.py`, `opt.py`, `bloom.py`) 9 | 10 | We note that this SparseGPT implementation is based on our open-source [GPTQ code](https://github.com/IST-DASLab/gptq). 11 | 12 | ## Dependencies 13 | 14 | * `torch`: tested on v1.10.1+cu111 15 | * `transformers`: tested on v4.21.2 16 | * `datasets`: tested on v1.17.0 17 | 18 | ## Usage 19 | 20 | Here are some sample commands to run baselines and sparsification on OPT models, followed by perplexity evaluations on raw-WikiText2, PTB and C4. 21 | See also the command-line argument documentation (`--help` in each script). 22 | 23 | ``` 24 | # Run dense baseline 25 | python opt.py facebook/opt-125m c4 26 | 27 | # Run magnitude baseline 28 | python opt.py facebook/opt-125m c4 --sparsity .5 --gmp 29 | 30 | # Prune to 50% uniform sparsity with SparseGPT 31 | python opt.py facebook/opt-125m c4 --sparsity .5 32 | 33 | # Prune to full 2:4 sparsity with SparseGPT 34 | python opt.py facebook/opt-125m c4 --prunen 2 --prunem 4 35 | 36 | # Prune to 50% sparsity + 4-bit weights with SparseGPT 37 | python opt.py facebook/opt-125m c4 --sparsity .5 --wbits 4 38 | ``` 39 | 40 | To run on other OPT models, replace "facebook/opt-125m" with the HuggingFace name of the corresponding model. 41 | For the 175B model, access must first be requested from Meta and the checkpoint converted to HuggingFace format; its location can then be passed to this script as the model name. 42 | 43 | The BLOOM script `bloom.py` has a very similar interface; however, some features are currently only available for OPT. For example: 44 | 45 | ``` 46 | # Sparsify BLOOM-176B with SparseGPT 47 | python bloom.py bigscience/bloom c4 --sparsity .5 48 | ``` 49 | 50 | We also provide a LLaMA pruning script with the same interface: 51 | 52 | ``` 53 | # Sparsify LLaMA with SparseGPT 54 | python llama.py LLAMA_HF_WEIGHTS_LOCATION c4 --sparsity 0.5 55 | ``` 56 | 57 | To save the sparsified model, specify a checkpoint path via the `--save` flag (see the loading example below). 58 | 59 | One can optionally log evaluation results to W&B with `--log_wandb`. 60 | 61 | ## Demo 62 | 63 | One can try SparseGPT via the Colab demo, `demo.ipynb`.
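A model sparsified with `--save` is written out as a regular HuggingFace checkpoint, so it can be reloaded for inference with `transformers`. Below is a minimal sketch, assuming (as in the demo notebook) that a pruned OPT-125M was saved to `sparse_opt/opt-125m`; the path and prompt are only examples.

```
# Reload a SparseGPT-pruned OPT checkpoint saved via --save (example path) and generate a completion.
import torch
from transformers import AutoTokenizer, OPTForCausalLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tokenizer comes from the original dense model; weights come from the pruned checkpoint.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
model = OPTForCausalLM.from_pretrained('sparse_opt/opt-125m', torch_dtype='auto').to(device)

input_ids = tokenizer("It takes a great deal of bravery", return_tensors='pt').input_ids.to(device)
output_ids = model.generate(input_ids)
print(tokenizer.decode(output_ids[0].cpu(), skip_special_tokens=True))
```

Note that the pruned weights are simply stored as zeros, so this checks output quality rather than inference speed; no sparse kernels are involved.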
64 | 65 | ## Cite 66 | 67 | If you found this work useful, please consider citing: 68 | 69 | ``` 70 | @article{frantar-sparsegpt, 71 | title={{SparseGPT}: Massive Language Models Can Be Accurately Pruned in One-Shot}, 72 | author={Elias Frantar and Dan Alistarh}, 73 | year={2023}, 74 | journal={arXiv preprint arXiv:2301.00774} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /bloom.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from sparsegpt import * 9 | from modelutils import * 10 | 11 | try: 12 | import wandb 13 | has_wandb = True 14 | except: 15 | has_wandb = False 16 | 17 | 18 | def get_bloom(model): 19 | import torch 20 | def skip(*args, **kwargs): 21 | pass 22 | torch.nn.init.kaiming_uniform_ = skip 23 | torch.nn.init.uniform_ = skip 24 | torch.nn.init.normal_ = skip 25 | from transformers import BloomForCausalLM 26 | model = BloomForCausalLM.from_pretrained(model, torch_dtype='auto') 27 | model.seqlen = 2048 28 | return model 29 | 30 | @torch.no_grad() 31 | def bloom_sequential(model, dataloader, dev, means=None, stds=None): 32 | print('Starting ...') 33 | 34 | use_cache = model.config.use_cache 35 | model.config.use_cache = False 36 | layers = model.transformer.h 37 | 38 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(dev) 39 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(dev) 40 | layers[0] = layers[0].to(dev) 41 | 42 | dtype = next(iter(model.parameters())).dtype 43 | inps = torch.zeros( 44 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 45 | ) 46 | cache = {'i': 0, 'attention_mask': None, 'alibi': None} 47 | 48 | class Catcher(nn.Module): 49 | def __init__(self, module): 50 | super().__init__() 51 | self.module = module 52 | def forward(self, inp, **kwargs): 53 | inps[cache['i']] = inp 54 | cache['i'] += 1 55 | cache['attention_mask'] = kwargs['attention_mask'] 56 | cache['alibi'] = kwargs['alibi'] 57 | raise ValueError 58 | layers[0] = Catcher(layers[0]) 59 | for batch in dataloader: 60 | try: 61 | model(batch[0].to(dev)) 62 | except ValueError: 63 | pass 64 | layers[0] = layers[0].module 65 | 66 | layers[0] = layers[0].cpu() 67 | model.transformer.word_embeddings = model.transformer.word_embeddings.cpu() 68 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.cpu() 69 | torch.cuda.empty_cache() 70 | 71 | outs = torch.zeros_like(inps) 72 | attention_mask = cache['attention_mask'] 73 | alibi = cache['alibi'] 74 | 75 | print('Ready.') 76 | 77 | for i in range(len(layers)): 78 | layer = layers[i].to(dev) 79 | 80 | subset = find_layers(layer) 81 | gpts = {} 82 | for name in subset: 83 | if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert): 84 | continue 85 | gpts[name] = SparseGPT(subset[name]) 86 | 87 | def add_batch(name): 88 | def tmp(_, inp, out): 89 | gpts[name].add_batch(inp[0].data, out.data) 90 | return tmp 91 | handles = [] 92 | for name in gpts: 93 | handles.append(subset[name].register_forward_hook(add_batch(name))) 94 | for j in range(args.nsamples): 95 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 96 | for h in handles: 97 | h.remove() 98 | 99 | for name in gpts: 100 | print(i, name) 101 | print('pruning ...') 102 | gpts[name].fasterprune( 
103 | args.sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp 104 | ) 105 | for j in range(args.nsamples): 106 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 107 | 108 | layers[i] = layer.cpu() 109 | del gpts 110 | torch.cuda.empty_cache() 111 | 112 | inps, outs = outs, inps 113 | 114 | model.config.use_cache = use_cache 115 | 116 | @torch.no_grad() 117 | def bloom_eval(model, testenc, dev, dataset: str, log_wandb: bool = False): 118 | print('Evaluation...') 119 | 120 | testenc = testenc.input_ids 121 | nsamples = testenc.numel() // model.seqlen 122 | 123 | use_cache = model.config.use_cache 124 | model.config.use_cache = False 125 | layers = model.transformer.h 126 | 127 | model.transformer.word_embeddings = model.transformer.word_embeddings.to(dev) 128 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(dev) 129 | layers[0] = layers[0].to(dev) 130 | 131 | dtype = next(iter(model.parameters())).dtype 132 | inps = torch.zeros( 133 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 134 | ) 135 | cache = {'i': 0, 'attention_mask': None, 'alibi': None} 136 | 137 | class Catcher(nn.Module): 138 | def __init__(self, module): 139 | super().__init__() 140 | self.module = module 141 | def forward(self, inp, **kwargs): 142 | inps[cache['i']] = inp 143 | cache['i'] += 1 144 | cache['attention_mask'] = kwargs['attention_mask'] 145 | cache['alibi'] = kwargs['alibi'] 146 | raise ValueError 147 | layers[0] = Catcher(layers[0]) 148 | for i in range(nsamples): 149 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 150 | try: 151 | model(batch) 152 | except ValueError: 153 | pass 154 | layers[0] = layers[0].module 155 | 156 | layers[0] = layers[0].cpu() 157 | model.transformer.word_embeddings = model.transformer.word_embeddings.cpu() 158 | model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.cpu() 159 | torch.cuda.empty_cache() 160 | 161 | outs = torch.zeros_like(inps) 162 | attention_mask = cache['attention_mask'] 163 | alibi = cache['alibi'] 164 | 165 | for i in range(len(layers)): 166 | print(i) 167 | layer = layers[i].to(dev) 168 | 169 | if args.gmp: 170 | subset = find_layers(layer) 171 | for name in subset: 172 | W = subset[name].weight.data 173 | thresh = torch.sort(torch.abs(W.flatten()))[0][int(W.numel() * args.sparsity)] 174 | W.data[torch.abs(W.data) <= thresh] = 0 175 | 176 | for j in range(nsamples): 177 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, alibi=alibi)[0] 178 | layers[i] = layer.cpu() 179 | del layer 180 | torch.cuda.empty_cache() 181 | inps, outs = outs, inps 182 | 183 | model.transformer.ln_f = model.transformer.ln_f.to(dev) 184 | model.lm_head = model.lm_head.to(dev) 185 | 186 | testenc = testenc.to(dev) 187 | nlls = [] 188 | for i in range(nsamples): 189 | hidden_states = inps[i].unsqueeze(0) 190 | hidden_states = model.transformer.ln_f(hidden_states) 191 | lm_logits = model.lm_head(hidden_states) 192 | shift_logits = lm_logits[:, :-1, :].contiguous() 193 | shift_labels = testenc[ 194 | :, (i * model.seqlen):((i + 1) * model.seqlen) 195 | ][:, 1:] 196 | loss_fct = nn.CrossEntropyLoss() 197 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 198 | neg_log_likelihood = loss.float() * model.seqlen 199 | nlls.append(neg_log_likelihood) 200 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 201 | print(f"Perplexity: 
{ppl.item():3f}") 202 | if log_wandb: 203 | wandb.log({f'{dataset}/perplexity': ppl.item()}) 204 | 205 | model.config.use_cache = use_cache 206 | 207 | 208 | if __name__ == '__main__': 209 | import argparse 210 | from datautils import * 211 | 212 | parser = argparse.ArgumentParser() 213 | 214 | parser.add_argument( 215 | 'model', type=str, 216 | help='BLOOM model to load; pass `bigscience/bloom-X`.' 217 | ) 218 | parser.add_argument( 219 | 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], 220 | help='Where to extract calibration data from.' 221 | ) 222 | parser.add_argument( 223 | '--seed', 224 | type=int, default=0, help='Seed for sampling the calibration data.' 225 | ) 226 | parser.add_argument( 227 | '--nsamples', type=int, default=128, 228 | help='Number of calibration data samples.' 229 | ) 230 | parser.add_argument( 231 | '--percdamp', type=float, default=.01, 232 | help='Percent of the average Hessian diagonal to use for dampening.' 233 | ) 234 | parser.add_argument( 235 | '--sparsity', type=float, default=0, 236 | help='Target sparsity' 237 | ) 238 | parser.add_argument( 239 | '--prunen', type=int, default=0, 240 | help='N for N:M pruning.' 241 | ) 242 | parser.add_argument( 243 | '--prunem', type=int, default=0, 244 | help='M for N:M pruning.' 245 | ) 246 | parser.add_argument( 247 | '--gmp', action='store_true', 248 | help='Whether to run the GMP baseline.' 249 | ) 250 | parser.add_argument( 251 | '--minlayer', type=int, default=-1, 252 | help='Prune all layers with id >= this.' 253 | ) 254 | parser.add_argument( 255 | '--maxlayer', type=int, default=1000, 256 | help='Prune all layers with id < this.' 257 | ) 258 | parser.add_argument( 259 | '--prune_only', type=str, default='', 260 | help='Prune only layers that contain this text.' 261 | ) 262 | parser.add_argument( 263 | '--invert', action='store_true', 264 | help='Invert subset.' 265 | ) 266 | parser.add_argument( 267 | '--save', type=str, default='', 268 | help='Path to saved model.' 269 | ) 270 | parser.add_argument( 271 | '--log_wandb', action='store_true', 272 | help='Whether to log to wandb.' 
273 | ) 274 | 275 | args = parser.parse_args() 276 | 277 | # init W&B logging 278 | if args.log_wandb: 279 | assert has_wandb, "wandb not installed try `pip install wandb`" 280 | wandb.init(config=args) 281 | 282 | model = get_bloom(args.model) 283 | model.eval() 284 | 285 | dataloader, testloader = get_loaders( 286 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 287 | ) 288 | 289 | if (args.sparsity or args.prunen) and not args.gmp: 290 | tick = time.time() 291 | bloom_sequential(model, dataloader, DEV) 292 | for n, p in model.named_parameters(): 293 | print(n, torch.mean((p == 0).float())) 294 | if 'dense_4h_to_h' in n: 295 | break 296 | print(time.time() - tick) 297 | 298 | for dataset in ['wikitext2', 'ptb', 'c4']: 299 | dataloader, testloader = get_loaders( 300 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 301 | ) 302 | print("Dataset:", dataset) 303 | bloom_eval(model, testloader, DEV, dataset, args.log_wandb) 304 | 305 | if args.save: 306 | model.save_pretrained(args.save) 307 | -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from datasets import load_dataset 6 | from transformers import AutoTokenizer, LlamaTokenizer 7 | 8 | 9 | def set_seed(seed): 10 | np.random.seed(seed) 11 | torch.random.manual_seed(seed) 12 | 13 | def get_tokenizer(model): 14 | if "llama" in model.lower(): 15 | tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False) 16 | # fix for transformer 4.28.0.dev0 compatibility 17 | if tokenizer.bos_token_id != 1 or tokenizer.eos_token_id != 2: 18 | try: 19 | tokenizer.bos_token_id = 1 20 | tokenizer.eos_token_id = 2 21 | except AttributeError: 22 | pass 23 | else: 24 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 25 | return tokenizer 26 | 27 | def get_wikitext2(nsamples, seed, seqlen, model, tokenizer): 28 | 29 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 30 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 31 | 32 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt') 33 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 34 | 35 | random.seed(seed) 36 | trainloader = [] 37 | for _ in range(nsamples): 38 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 39 | j = i + seqlen 40 | inp = trainenc.input_ids[:, i:j] 41 | tar = inp.clone() 42 | tar[:, :-1] = -100 43 | trainloader.append((inp, tar)) 44 | return trainloader, testenc 45 | 46 | def get_ptb(nsamples, seed, seqlen, model, tokenizer): 47 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 48 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 49 | 50 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 51 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 52 | 53 | random.seed(seed) 54 | trainloader = [] 55 | for _ in range(nsamples): 56 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 57 | j = i + seqlen 58 | inp = trainenc.input_ids[:, i:j] 59 | tar = inp.clone() 60 | tar[:, :-1] = -100 61 | trainloader.append((inp, tar)) 62 | return trainloader, testenc 63 | 64 | def get_c4(nsamples, seed, seqlen, model, tokenizer): 65 | traindata = load_dataset( 66 | 'allenai/c4', data_files={'train': 
'en/c4-train.00000-of-01024.json.gz'}, split='train' 67 | ) 68 | valdata = load_dataset( 69 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 70 | ) 71 | 72 | random.seed(seed) 73 | trainloader = [] 74 | for _ in range(nsamples): 75 | while True: 76 | i = random.randint(0, len(traindata) - 1) 77 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 78 | if trainenc.input_ids.shape[1] > seqlen: 79 | break 80 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 81 | j = i + seqlen 82 | inp = trainenc.input_ids[:, i:j] 83 | tar = inp.clone() 84 | tar[:, :-1] = -100 85 | trainloader.append((inp, tar)) 86 | 87 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 88 | valenc = valenc.input_ids[:, :(256 * seqlen)] 89 | 90 | class TokenizerWrapper: 91 | def __init__(self, input_ids): 92 | self.input_ids = input_ids 93 | valenc = TokenizerWrapper(valenc) 94 | 95 | return trainloader, valenc 96 | 97 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): 98 | tokenizer = get_tokenizer(model) 99 | if 'wikitext2' in name: 100 | return get_wikitext2(nsamples, seed, seqlen, model, tokenizer) 101 | if 'ptb' in name: 102 | return get_ptb(nsamples, seed, seqlen, model, tokenizer) 103 | if 'c4' in name: 104 | return get_c4(nsamples, seed, seqlen, model, tokenizer) 105 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "QbRrWtX0PXVr" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/IST-DASLab/sparsegpt/blob/master/demo.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "mMUp4UrWjp-8" 16 | }, 17 | "source": [ 18 | "Install dependencies" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | }, 28 | "id": "VdbD9blm6j_r", 29 | "outputId": "47a5db11-b0a6-441e-e812-69a8e7676ded" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "!pip install -q datasets\n", 34 | "!pip install -q transformers" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "id": "rhSblKg_jter" 41 | }, 42 | "source": [ 43 | "Clone repository" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "colab": { 51 | "base_uri": "https://localhost:8080/" 52 | }, 53 | "id": "3nCz469NhV3c", 54 | "outputId": "39acb0f6-445b-401e-e854-6f2bce746d55" 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "!git clone https://github.com/IST-DASLab/sparsegpt" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "id": "mbM_bJODjyBg" 65 | }, 66 | "source": [ 67 | "### Pruning example\n", 68 | "---" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "Om0QSLnLj8JN" 75 | }, 76 | "source": [ 77 | "Below we will show an example of SparseGPT applied to OPT model." 
78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "colab": { 85 | "base_uri": "https://localhost:8080/" 86 | }, 87 | "id": "d9NTGmD4iVK7", 88 | "outputId": "8b92ddee-42a3-4d12-b8cb-904353518493" 89 | }, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "/content/sparsegpt\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "%cd sparsegpt" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "id": "HbiDyjx9j61I" 107 | }, 108 | "source": [ 109 | "Crerate directory to store prune model(s)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "pJ-jauI-iyvi" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "!mkdir -p sparse_opt" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": { 126 | "id": "OGLiExo5Ksc4" 127 | }, 128 | "source": [ 129 | "We will use `opt.py` script to prune the model.\n", 130 | "Select one of the following OPT versions to fit into colab (with `bitsandbytes` one should be able to use larger 6.7b and 13b models):\n", 131 | "* facebook/opt-125m\n", 132 | "* facebook/opt-350m\n", 133 | "* facebook/opt-1.3b\n", 134 | "\n", 135 | "To prune the model select dataset for calibration (`c4`, `ptb` or `wikitext`). The SparseGPT paper uses `c4` by default.\n", 136 | "\n", 137 | "One can prune model to uniform sparsity with SparseGPT either with unstructured pruning or semistructured `N:M` pattern.\n", 138 | "\n", 139 | "To apply unstructured pruning specify `--sparsity` - floating point number in `[0, 1]`.\n", 140 | "\n", 141 | "For semitstructured specify `--prunen` and `--prunem` arguments - integer numbers.\n", 142 | "\n", 143 | "To apply magnitude pruning instead of SparseGPT select `--gmp` option.\n", 144 | "\n", 145 | "To apply quantization on top of sparsity specify `--wbits`.\n", 146 | "\n", 147 | "In the example below we prune `facebook/opt-125m` to 0.5 unstructured sparsity via SparseGPT. Try different options.\n" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "colab": { 155 | "base_uri": "https://localhost:8080/" 156 | }, 157 | "id": "BxucjXmCibnI", 158 | "outputId": "7d9dd66e-5308-4ecd-f30e-e97861682f96" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "!python opt.py facebook/opt-125m c4 --sparsity 0.5 --save sparse_opt/opt-125m" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": { 168 | "id": "5mrOL92aO5xy" 169 | }, 170 | "source": [ 171 | "Code above prints perplexity on `wikitext2`, `ptb` and `c4` benchmarks in the end." 
172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": { 177 | "id": "AD9Zkgb-O21A" 178 | }, 179 | "source": [ 180 | "### Compare generations\n", 181 | "---" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": { 187 | "id": "nSJIGizLkPm8" 188 | }, 189 | "source": [ 190 | "Let us compare generations produced by the dense and sparse model" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "id": "-GzBUGsXic0o" 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "from transformers import AutoTokenizer, OPTForCausalLM" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": { 208 | "id": "Ub-69himlTpZ" 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "device = 'cuda'" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "id": "mQJtRPbekmXu" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "# load dense model\n", 224 | "model_dn = OPTForCausalLM.from_pretrained('facebook/opt-125m', torch_dtype='auto').to(device)\n", 225 | "# load sparse model\n", 226 | "model_sp = OPTForCausalLM.from_pretrained('sparse_opt/opt-125m', torch_dtype='auto').to(device)\n", 227 | "# init tokenizer\n", 228 | "tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "id": "Bqskug9-mXtR" 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "input_text = \"It takes a great deal of bravery\"" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": { 246 | "id": "fS7YWAAhnatI" 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to(device)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": { 256 | "id": "w61F2J0QoPTi" 257 | }, 258 | "source": [ 259 | "Completion by dense model:" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "colab": { 267 | "base_uri": "https://localhost:8080/" 268 | }, 269 | "id": "o_xY5fSSnK2I", 270 | "outputId": "8594b1ba-2438-445e-8b4e-264d3dc943dd" 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "output_ids = model_dn.generate(input_ids)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "colab": { 282 | "base_uri": "https://localhost:8080/" 283 | }, 284 | "id": "KRmGPG1tnoci", 285 | "outputId": "f7c1c71d-6141-47c2-90d8-4bb0d785c4a2" 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "print(tokenizer.decode(output_ids[0].cpu(), skip_special_tokens=True))" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": { 295 | "id": "_9Zk6UaQpP9C" 296 | }, 297 | "source": [ 298 | "Completion by sparse model:" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": { 305 | "id": "Ky5U9elZn-pL" 306 | }, 307 | "outputs": [], 308 | "source": [ 309 | "output_ids = model_sp.generate(input_ids)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": { 316 | "colab": { 317 | "base_uri": "https://localhost:8080/" 318 | }, 319 | "id": "azFoDFrxpSJJ", 320 | "outputId": "47cc4a8a-cb17-47b5-b825-c93966074599" 321 | }, 322 | "outputs": [], 323 | "source": [ 324 | "print(tokenizer.decode(output_ids[0].cpu(), 
skip_special_tokens=True))" 325 | ] 326 | } 327 | ], 328 | "metadata": { 329 | "accelerator": "GPU", 330 | "colab": { 331 | "gpuType": "T4", 332 | "provenance": [] 333 | }, 334 | "kernelspec": { 335 | "display_name": "Python 3", 336 | "name": "python3" 337 | }, 338 | "language_info": { 339 | "name": "python" 340 | } 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 0 344 | } 345 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from sparsegpt import * 7 | from modelutils import * 8 | from quant import * 9 | 10 | try: 11 | import wandb 12 | has_wandb = True 13 | except: 14 | has_wandb = False 15 | 16 | 17 | def get_llama(model): 18 | import torch 19 | def skip(*args, **kwargs): 20 | pass 21 | torch.nn.init.kaiming_uniform_ = skip 22 | torch.nn.init.uniform_ = skip 23 | torch.nn.init.normal_ = skip 24 | from transformers import LlamaForCausalLM 25 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') 26 | model.seqlen = 2048 27 | return model 28 | 29 | 30 | @torch.no_grad() 31 | def llama_sequential(model, dataloader, dev): 32 | print("Starting...") 33 | 34 | use_cache = model.config.use_cache 35 | model.config.use_cache = False 36 | layers = model.model.layers 37 | 38 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 39 | model.model.norm = model.model.norm.to(dev) 40 | layers[0] = layers[0].to(dev) 41 | 42 | dtype = next(iter(model.parameters())).dtype 43 | inps = torch.zeros( 44 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 45 | ) 46 | cache = {"i": 0, "attention_mask": None} 47 | 48 | class Catcher(nn.Module): 49 | def __init__(self, module): 50 | super().__init__() 51 | self.module = module 52 | 53 | def forward(self, inp, **kwargs): 54 | inps[cache["i"]] = inp 55 | cache["i"] += 1 56 | cache["attention_mask"] = kwargs["attention_mask"] 57 | raise ValueError 58 | 59 | layers[0] = Catcher(layers[0]) 60 | for batch in dataloader: 61 | try: 62 | model(batch[0].to(dev)) 63 | except ValueError: 64 | pass 65 | layers[0] = layers[0].module 66 | 67 | layers[0] = layers[0].cpu() 68 | model.model.embed_tokens = model.model.embed_tokens.cpu() 69 | model.model.norm = model.model.norm.cpu() 70 | torch.cuda.empty_cache() 71 | 72 | outs = torch.zeros_like(inps) 73 | attention_mask = cache["attention_mask"] 74 | 75 | print("Ready.") 76 | 77 | quantizers = {} 78 | for i in range(len(layers)): 79 | layer = layers[i].to(dev) 80 | full = find_layers(layer) 81 | 82 | if args.true_sequential: 83 | sequential = [ 84 | ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], 85 | ["self_attn.o_proj"], 86 | ["mlp.up_proj", "mlp.gate_proj"], 87 | ["mlp.down_proj"], 88 | ] 89 | else: 90 | sequential = [list(full.keys())] 91 | 92 | for names in sequential: 93 | subset = {n: full[n] for n in names} 94 | 95 | gpts = {} 96 | for name in subset: 97 | if ( 98 | not (args.minlayer <= i < args.maxlayer and args.prune_only in name) 99 | ) == (not args.invert): 100 | continue 101 | gpts[name] = SparseGPT(subset[name]) 102 | if args.wbits < 16: 103 | gpts[name].quantizer = Quantizer() 104 | gpts[name].quantizer.configure( 105 | args.wbits, perchannel=True, sym=False, mse=False 106 | ) 107 | 108 | def add_batch(name): 109 | def tmp(_, inp, out): 110 | gpts[name].add_batch(inp[0].data, out.data) 111 | 112 | return tmp 113 | 114 | handles = [] 115 | for name in 
subset: 116 | handles.append(subset[name].register_forward_hook(add_batch(name))) 117 | for j in range(args.nsamples): 118 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 119 | for h in handles: 120 | h.remove() 121 | 122 | for name in subset: 123 | print(i, name) 124 | print("Pruning ...") 125 | sparsity = args.sparsity 126 | gpts[name].fasterprune( 127 | sparsity, 128 | prunen=args.prunen, 129 | prunem=args.prunem, 130 | percdamp=args.percdamp, 131 | blocksize=args.blocksize, 132 | ) 133 | gpts[name].free() 134 | 135 | for j in range(args.nsamples): 136 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 137 | 138 | layers[i] = layer.cpu() 139 | del layer 140 | del gpts 141 | torch.cuda.empty_cache() 142 | 143 | inps, outs = outs, inps 144 | 145 | model.config.use_cache = use_cache 146 | 147 | return quantizers 148 | 149 | 150 | @torch.no_grad() 151 | def llama_eval(model, testenc, dev, dataset: str, log_wandb: bool = False): 152 | print("Evaluating ...") 153 | 154 | testenc = testenc.input_ids 155 | nsamples = testenc.numel() // model.seqlen 156 | 157 | use_cache = model.config.use_cache 158 | model.config.use_cache = False 159 | layers = model.model.layers 160 | 161 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 162 | layers[0] = layers[0].to(dev) 163 | 164 | dtype = next(iter(model.parameters())).dtype 165 | inps = torch.zeros( 166 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 167 | ) 168 | cache = {"i": 0, "attention_mask": None} 169 | 170 | class Catcher(nn.Module): 171 | def __init__(self, module): 172 | super().__init__() 173 | self.module = module 174 | 175 | def forward(self, inp, **kwargs): 176 | inps[cache["i"]] = inp 177 | cache["i"] += 1 178 | cache["attention_mask"] = kwargs["attention_mask"] 179 | raise ValueError 180 | 181 | layers[0] = Catcher(layers[0]) 182 | for i in range(nsamples): 183 | batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev) 184 | try: 185 | model(batch) 186 | except ValueError: 187 | pass 188 | layers[0] = layers[0].module 189 | 190 | layers[0] = layers[0].cpu() 191 | model.model.embed_tokens = model.model.embed_tokens.cpu() 192 | torch.cuda.empty_cache() 193 | 194 | outs = torch.zeros_like(inps) 195 | attention_mask = cache["attention_mask"] 196 | 197 | for i in range(len(layers)): 198 | print(i) 199 | layer = layers[i].to(dev) 200 | 201 | if args.gmp: 202 | subset = find_layers(layer) 203 | for name in subset: 204 | W = subset[name].weight.data 205 | thresh = torch.sort(torch.abs(W.flatten()))[0][ 206 | int(W.numel() * args.sparsity) 207 | ] 208 | W.data[torch.abs(W.data) <= thresh] = 0 209 | 210 | for j in range(nsamples): 211 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 212 | layers[i] = layer.cpu() 213 | del layer 214 | torch.cuda.empty_cache() 215 | inps, outs = outs, inps 216 | 217 | if model.model.norm is not None: 218 | model.model.norm = model.model.norm.to(dev) 219 | model.lm_head = model.lm_head.to(dev) 220 | 221 | testenc = testenc.to(dev) 222 | nlls = [] 223 | for i in range(nsamples): 224 | hidden_states = inps[i].unsqueeze(0) 225 | if model.model.norm is not None: 226 | hidden_states = model.model.norm(hidden_states) 227 | lm_logits = model.lm_head(hidden_states) 228 | shift_logits = lm_logits[:, :-1, :].contiguous() 229 | shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] 230 | loss_fct = nn.CrossEntropyLoss() 231 | loss = loss_fct( 232 | shift_logits.view(-1, 
shift_logits.size(-1)), shift_labels.view(-1) 233 | ) 234 | neg_log_likelihood = loss.float() * model.seqlen 235 | nlls.append(neg_log_likelihood) 236 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 237 | print(f"Perplexity: {ppl.item():3f}") 238 | if log_wandb: 239 | wandb.log({f"{dataset}/perplexity": ppl.item()}) 240 | 241 | model.config.use_cache = use_cache 242 | 243 | 244 | if __name__ == "__main__": 245 | import argparse 246 | from datautils import * 247 | 248 | parser = argparse.ArgumentParser() 249 | 250 | parser.add_argument("model", type=str, help="LLaMA model to load; pass the location of the HuggingFace-format weights.") 251 | parser.add_argument( 252 | "dataset", 253 | type=str, 254 | choices=["wikitext2", "ptb", "c4"], 255 | help="Where to extract calibration data from.", 256 | ) 257 | parser.add_argument( 258 | "--seed", type=int, default=0, help="Seed for sampling the calibration data." 259 | ) 260 | parser.add_argument( 261 | "--nsamples", type=int, default=128, help="Number of calibration data samples." 262 | ) 263 | parser.add_argument( 264 | "--percdamp", 265 | type=float, 266 | default=0.01, 267 | help="Percent of the average Hessian diagonal to use for dampening.", 268 | ) 269 | parser.add_argument("--sparsity", type=float, default=0, help="Target sparsity") 270 | parser.add_argument("--prunen", type=int, default=0, help="N for N:M pruning.") 271 | parser.add_argument("--prunem", type=int, default=0, help="M for N:M pruning.") 272 | parser.add_argument( 273 | "--blocksize", 274 | type=int, 275 | default=128, 276 | help="Blocksize to use for adaptive mask selection.", 277 | ) 278 | parser.add_argument( 279 | "--gmp", action="store_true", help="Whether to run the GMP baseline." 280 | ) 281 | parser.add_argument( 282 | "--wbits", type=int, default=16, help="Whether to quantize as well." 283 | ) 284 | parser.add_argument( 285 | "--minlayer", type=int, default=-1, help="Prune all layers with id >= this." 286 | ) 287 | parser.add_argument( 288 | "--maxlayer", type=int, default=1000, help="Prune all layers with id < this." 289 | ) 290 | parser.add_argument( 291 | "--prune_only", 292 | type=str, 293 | default="", 294 | help="Prune only layers that contain this text.", 295 | ) 296 | parser.add_argument("--invert", action="store_true", help="Invert subset.") 297 | parser.add_argument("--save", type=str, default="", help="Path to saved model.") 298 | parser.add_argument( 299 | "--true-sequential", 300 | action="store_true", 301 | help="Whether to run in true sequential mode.", 302 | ) 303 | parser.add_argument( 304 | "--log_wandb", action="store_true", help="Whether to log to wandb."
305 | ) 306 | 307 | args = parser.parse_args() 308 | 309 | # init W&B logging 310 | if args.log_wandb: 311 | assert has_wandb, "wandb not installed try `pip install wandb`" 312 | wandb.init(config=args) 313 | 314 | model = get_llama(args.model) 315 | model.eval() 316 | 317 | dataloader, testloader = get_loaders( 318 | args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 319 | ) 320 | 321 | if (args.sparsity or args.prunen) and not args.gmp: 322 | tick = time.time() 323 | llama_sequential(model, dataloader, DEV) 324 | for n, p in model.named_parameters(): 325 | print(n, torch.mean((p == 0).float())) 326 | if 'down_proj' in n: 327 | break 328 | print(time.time() - tick) 329 | 330 | for dataset in ["wikitext2", "ptb", "c4"]: 331 | dataloader, testloader = get_loaders( 332 | dataset, seed=args.seed, model=args.model, seqlen=model.seqlen 333 | ) 334 | print("Dataset:", dataset) 335 | llama_eval(model, testloader, DEV, dataset, args.log_wandb) 336 | 337 | if args.save: 338 | model.save_pretrained(args.save) 339 | -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | DEV = torch.device('cuda:0') 6 | 7 | 8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 9 | if type(module) in layers: 10 | return {name: module} 11 | res = {} 12 | for name1, child in module.named_children(): 13 | res.update(find_layers( 14 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 15 | )) 16 | return res 17 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from quant import * 7 | from sparsegpt import * 8 | from modelutils import * 9 | 10 | try: 11 | import wandb 12 | has_wandb = True 13 | except: 14 | has_wandb = False 15 | 16 | 17 | def get_opt(model): 18 | import torch 19 | def skip(*args, **kwargs): 20 | pass 21 | torch.nn.init.kaiming_uniform_ = skip 22 | torch.nn.init.uniform_ = skip 23 | torch.nn.init.normal_ = skip 24 | from transformers import OPTForCausalLM 25 | model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto') 26 | model.seqlen = model.config.max_position_embeddings 27 | return model 28 | 29 | @torch.no_grad() 30 | def opt_sequential(model, dataloader, dev): 31 | print('Starting ...') 32 | 33 | use_cache = model.config.use_cache 34 | model.config.use_cache = False 35 | layers = model.model.decoder.layers 36 | 37 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 38 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 39 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 40 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 41 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 42 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 43 | layers[0] = layers[0].to(dev) 44 | 45 | dtype = next(iter(model.parameters())).dtype 46 | inps = torch.zeros( 47 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 48 | ) 49 | cache = {'i': 0, 'attention_mask': None} 50 | 51 | class Catcher(nn.Module): 52 | def __init__(self, module): 53 | super().__init__() 54 | 
self.module = module 55 | def forward(self, inp, **kwargs): 56 | inps[cache['i']] = inp 57 | cache['i'] += 1 58 | cache['attention_mask'] = kwargs['attention_mask'] 59 | raise ValueError 60 | layers[0] = Catcher(layers[0]) 61 | for batch in dataloader: 62 | try: 63 | model(batch[0].to(dev)) 64 | except ValueError: 65 | pass 66 | layers[0] = layers[0].module 67 | 68 | layers[0] = layers[0].cpu() 69 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 70 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 71 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 72 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 73 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 74 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 75 | torch.cuda.empty_cache() 76 | 77 | outs = torch.zeros_like(inps) 78 | attention_mask = cache['attention_mask'] 79 | 80 | print('Ready.') 81 | 82 | for i in range(len(layers)): 83 | layer = layers[i].to(dev) 84 | 85 | subset = find_layers(layer) 86 | 87 | gpts = {} 88 | for name in subset: 89 | if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert): 90 | continue 91 | gpts[name] = SparseGPT(subset[name]) 92 | if args.wbits < 16: 93 | gpts[name].quantizer = Quantizer() 94 | gpts[name].quantizer.configure( 95 | args.wbits, perchannel=True, sym=False, mse=False 96 | ) 97 | 98 | def add_batch(name): 99 | def tmp(_, inp, out): 100 | gpts[name].add_batch(inp[0].data, out.data) 101 | return tmp 102 | handles = [] 103 | for name in gpts: 104 | handles.append(subset[name].register_forward_hook(add_batch(name))) 105 | for j in range(args.nsamples): 106 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 107 | for h in handles: 108 | h.remove() 109 | 110 | for name in gpts: 111 | print(i, name) 112 | print('Pruning ...') 113 | sparsity = args.sparsity 114 | gpts[name].fasterprune( 115 | sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize 116 | ) 117 | gpts[name].free() 118 | 119 | for j in range(args.nsamples): 120 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 121 | 122 | layers[i] = layer.cpu() 123 | del layer 124 | torch.cuda.empty_cache() 125 | 126 | inps, outs = outs, inps 127 | 128 | model.config.use_cache = use_cache 129 | 130 | @torch.no_grad() 131 | def opt_eval(model, testenc, dev, dataset: str, log_wandb: bool = False): 132 | print('Evaluating ...') 133 | 134 | testenc = testenc.input_ids 135 | nsamples = testenc.numel() // model.seqlen 136 | 137 | use_cache = model.config.use_cache 138 | model.config.use_cache = False 139 | layers = model.model.decoder.layers 140 | 141 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 142 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) 143 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 144 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 145 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 146 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 147 | layers[0] = layers[0].to(dev) 148 | 149 | dtype = next(iter(model.parameters())).dtype 150 | inps = torch.zeros( 151 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 152 | ) 153 | cache 
= {'i': 0, 'attention_mask': None} 154 | 155 | class Catcher(nn.Module): 156 | def __init__(self, module): 157 | super().__init__() 158 | self.module = module 159 | def forward(self, inp, **kwargs): 160 | inps[cache['i']] = inp 161 | cache['i'] += 1 162 | cache['attention_mask'] = kwargs['attention_mask'] 163 | raise ValueError 164 | layers[0] = Catcher(layers[0]) 165 | for i in range(nsamples): 166 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 167 | try: 168 | model(batch) 169 | except ValueError: 170 | pass 171 | layers[0] = layers[0].module 172 | 173 | layers[0] = layers[0].cpu() 174 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() 175 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() 176 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: 177 | model.model.decoder.project_out = model.model.decoder.project_out.cpu() 178 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: 179 | model.model.decoder.project_in = model.model.decoder.project_in.cpu() 180 | torch.cuda.empty_cache() 181 | 182 | outs = torch.zeros_like(inps) 183 | attention_mask = cache['attention_mask'] 184 | 185 | for i in range(len(layers)): 186 | print(i) 187 | layer = layers[i].to(dev) 188 | 189 | if args.gmp: 190 | subset = find_layers(layer) 191 | for name in subset: 192 | W = subset[name].weight.data 193 | thresh = torch.sort(torch.abs(W.flatten()))[0][int(W.numel() * args.sparsity)] 194 | W.data[torch.abs(W.data) <= thresh] = 0 195 | 196 | for j in range(nsamples): 197 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 198 | layers[i] = layer.cpu() 199 | del layer 200 | torch.cuda.empty_cache() 201 | inps, outs = outs, inps 202 | 203 | if model.model.decoder.final_layer_norm is not None: 204 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) 205 | if model.model.decoder.project_out is not None: 206 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 207 | model.lm_head = model.lm_head.to(dev) 208 | 209 | testenc = testenc.to(dev) 210 | nlls = [] 211 | for i in range(nsamples): 212 | hidden_states = inps[i].unsqueeze(0) 213 | if model.model.decoder.final_layer_norm is not None: 214 | hidden_states = model.model.decoder.final_layer_norm(hidden_states) 215 | if model.model.decoder.project_out is not None: 216 | hidden_states = model.model.decoder.project_out(hidden_states) 217 | lm_logits = model.lm_head(hidden_states) 218 | shift_logits = lm_logits[:, :-1, :].contiguous() 219 | shift_labels = testenc[ 220 | :, (i * model.seqlen):((i + 1) * model.seqlen) 221 | ][:, 1:] 222 | loss_fct = nn.CrossEntropyLoss() 223 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 224 | neg_log_likelihood = loss.float() * model.seqlen 225 | nlls.append(neg_log_likelihood) 226 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 227 | print(f"Perplexity: {ppl.item():3f}") 228 | if log_wandb: 229 | wandb.log({f'{dataset}/perplexity': ppl.item()}) 230 | 231 | model.config.use_cache = use_cache 232 | 233 | 234 | if __name__ == '__main__': 235 | import argparse 236 | from datautils import * 237 | 238 | parser = argparse.ArgumentParser() 239 | 240 | parser.add_argument( 241 | 'model', type=str, 242 | help='OPT model to load; pass `facebook/opt-X`.' 
243 |     )
244 |     parser.add_argument(
245 |         'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'],
246 |         help='Where to extract calibration data from.'
247 |     )
248 |     parser.add_argument(
249 |         '--seed',
250 |         type=int, default=0, help='Seed for sampling the calibration data.'
251 |     )
252 |     parser.add_argument(
253 |         '--nsamples', type=int, default=128,
254 |         help='Number of calibration data samples.'
255 |     )
256 |     parser.add_argument(
257 |         '--percdamp', type=float, default=.01,
258 |         help='Percent of the average Hessian diagonal to use for dampening.'
259 |     )
260 |     parser.add_argument(
261 |         '--sparsity', type=float, default=0,
262 |         help='Target sparsity (fraction of weights to prune).'
263 |     )
264 |     parser.add_argument(
265 |         '--prunen', type=int, default=0,
266 |         help='N for N:M pruning.'
267 |     )
268 |     parser.add_argument(
269 |         '--prunem', type=int, default=0,
270 |         help='M for N:M pruning.'
271 |     )
272 |     parser.add_argument(
273 |         '--blocksize', type=int, default=128,
274 |         help='Blocksize to use for adaptive mask selection.'
275 |     )
276 |     parser.add_argument(
277 |         '--gmp', action='store_true',
278 |         help='Whether to run the GMP baseline.'
279 |     )
280 |     parser.add_argument(
281 |         '--wbits', type=int, default=16,
282 |         help='Bit-width to quantize the remaining weights to; 16 means no quantization.'
283 |     )
284 |     parser.add_argument(
285 |         '--minlayer', type=int, default=-1,
286 |         help='Prune all layers with id >= this.'
287 |     )
288 |     parser.add_argument(
289 |         '--maxlayer', type=int, default=1000,
290 |         help='Prune all layers with id < this.'
291 |     )
292 |     parser.add_argument(
293 |         '--prune_only', type=str, default='',
294 |         help='Prune only layers whose name contains this text.'
295 |     )
296 |     parser.add_argument(
297 |         '--invert', action='store_true',
298 |         help='Invert subset.'
299 |     )
300 |     parser.add_argument(
301 |         '--save', type=str, default='',
302 |         help='Path to save the pruned model.'
303 |     )
304 |     parser.add_argument(
305 |         '--log_wandb', action='store_true',
306 |         help='Whether to log to wandb.'
307 |     )
308 |
309 |     args = parser.parse_args()
310 |
311 |     # init W&B logging
312 |     if args.log_wandb:
313 |         assert has_wandb, "wandb not installed; try `pip install wandb`"
314 |         wandb.init(config=args)
315 |
316 |     model = get_opt(args.model)
317 |     model.eval()
318 |
319 |     dataloader, testloader = get_loaders(
320 |         args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
321 |     )
322 |
323 |     if (args.sparsity or args.prunen) and not args.gmp:
324 |         tick = time.time()
325 |         opt_sequential(model, dataloader, DEV)
326 |         for n, p in model.named_parameters():
327 |             print(n, torch.mean((p == 0).float()))
328 |             if 'fc2' in n:
329 |                 break
330 |         print(time.time() - tick)
331 |
332 |     for dataset in ['wikitext2', 'ptb', 'c4']:
333 |         dataloader, testloader = get_loaders(
334 |             dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
335 |         )
336 |         print(dataset)
337 |         opt_eval(model, testloader, DEV, dataset, args.log_wandb)
338 |
339 |     if args.save:
340 |         model.save_pretrained(args.save)
341 |
--------------------------------------------------------------------------------
/quant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def quantize(x, scale, zero, maxq):
7 |     q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
8 |     return scale * (q - zero)
9 |
10 | class Quantizer(nn.Module):
11 |
12 |     def __init__(self, shape=1):
13 |         super(Quantizer, self).__init__()
14 |         self.register_buffer('maxq', torch.tensor(0))
15 |         self.register_buffer('scale', torch.zeros(shape))
16 |         self.register_buffer('zero', torch.zeros(shape))
17 |
18 |     def configure(
19 |         self,
20 |         bits, perchannel=False, sym=True,
21 |         mse=False, norm=2.4, grid=100, maxshrink=.8,
22 |         grouprows=1
23 |     ):
24 |         self.maxq = torch.tensor(2 ** bits - 1)
25 |         self.perchannel = perchannel
26 |         self.sym = sym
27 |         self.mse = mse
28 |         self.norm = norm
29 |         self.grid = grid
30 |         self.maxshrink = maxshrink
31 |         self.grouprows = grouprows
32 |
33 |     def find_params(self, x, weight=False):
34 |         dev = x.device
35 |         self.maxq = self.maxq.to(dev)
36 |
37 |         shape = x.shape
38 |         if self.perchannel:
39 |             if weight:
40 |                 x = x.flatten(1)
41 |                 if self.grouprows > 1:
42 |                     x = x.reshape((x.shape[0] // self.grouprows, -1))
43 |             else:
44 |                 if len(shape) == 4:
45 |                     x = x.permute([1, 0, 2, 3])
46 |                     x = x.flatten(1)
47 |                 if len(shape) == 3:
48 |                     x = x.reshape((-1, shape[-1])).t()
49 |                 if len(shape) == 2:
50 |                     x = x.t()
51 |         else:
52 |             x = x.flatten().unsqueeze(0)
53 |
54 |         tmp = torch.zeros(x.shape[0], device=dev)
55 |         xmin = torch.minimum(x.min(1)[0], tmp)
56 |         xmax = torch.maximum(x.max(1)[0], tmp)
57 |
58 |         if self.sym:
59 |             xmax = torch.maximum(torch.abs(xmin), xmax)
60 |             tmp = xmin < 0
61 |             if torch.any(tmp):
62 |                 xmin[tmp] = -xmax[tmp]
63 |         tmp = (xmin == 0) & (xmax == 0)
64 |         xmin[tmp] = -1
65 |         xmax[tmp] = +1
66 |
67 |         self.scale = (xmax - xmin) / self.maxq
68 |         if self.sym:
69 |             self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
70 |         else:
71 |             self.zero = torch.round(-xmin / self.scale)
72 |
73 |         if self.mse:
74 |             best = torch.full([x.shape[0]], float('inf'), device=dev)
75 |             for i in range(int(self.maxshrink * self.grid)):
76 |                 p = 1 - i / self.grid
77 |                 xmin1 = p * xmin
78 |                 xmax1 = p * xmax
79 |                 scale1 = (xmax1 - xmin1) / self.maxq
80 |                 zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
81 |                 q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
82 |                 q -= x
83 |                 q.abs_()
84 |                 q.pow_(self.norm)
85 |                 err = torch.sum(q, 1)
86 |                 tmp = err < best
87 |                 if torch.any(tmp):
88 |                     best[tmp] = err[tmp]
89 |                     self.scale[tmp] = scale1[tmp]
90 |                     self.zero[tmp] = zero1[tmp]
91 |         if not self.perchannel:
92 |             if weight:
93 |                 tmp = shape[0]
94 |             else:
95 |                 tmp = shape[1] if len(shape) != 3 else shape[2]
96 |             self.scale = self.scale.repeat(tmp)
97 |             self.zero = self.zero.repeat(tmp)
98 |
99 |         if weight:
100 |             if self.grouprows > 1:
101 |                 self.scale = self.scale.unsqueeze(1).repeat(1, self.grouprows)
102 |                 self.zero = self.zero.unsqueeze(1).repeat(1, self.grouprows)
103 |             shape = [-1] + [1] * (len(shape) - 1)
104 |             self.scale = self.scale.reshape(shape)
105 |             self.zero = self.zero.reshape(shape)
106 |             return
107 |         if len(shape) == 4:
108 |             self.scale = self.scale.reshape((1, -1, 1, 1))
109 |             self.zero = self.zero.reshape((1, -1, 1, 1))
110 |         if len(shape) == 3:
111 |             self.scale = self.scale.reshape((1, 1, -1))
112 |             self.zero = self.zero.reshape((1, 1, -1))
113 |         if len(shape) == 2:
114 |             self.scale = self.scale.unsqueeze(0)
115 |             self.zero = self.zero.unsqueeze(0)
116 |
117 |     def quantize(self, x):
118 |         if self.ready():
119 |             return quantize(x, self.scale, self.zero, self.maxq)
120 |         return x
121 |
122 |     def enabled(self):
123 |         return self.maxq > 0
124 |
125 |     def ready(self):
126 |         return torch.all(self.scale != 0)
127 |
--------------------------------------------------------------------------------
/sparsegpt.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | from quant import *
9 |
10 |
11 | DEBUG = False
12 |
13 | torch.backends.cuda.matmul.allow_tf32 = False
14 | torch.backends.cudnn.allow_tf32 = False
15 |
16 |
17 | class SparseGPT:
18 |
19 |     def __init__(self, layer):
20 |         self.layer = layer
21 |         self.dev = self.layer.weight.device
22 |         W = layer.weight.data.clone()
23 |         if isinstance(self.layer, nn.Conv2d):
24 |             W = W.flatten(1)
25 |         if isinstance(self.layer, transformers.Conv1D):
26 |             W = W.t()
27 |         self.rows = W.shape[0]
28 |         self.columns = W.shape[1]
29 |         self.H = torch.zeros((self.columns, self.columns), device=self.dev)
30 |         self.nsamples = 0
31 |
32 |     def add_batch(self, inp, out, blocksize=1024):
33 |         if DEBUG:
34 |             self.inp1 = inp
35 |             self.out1 = out
36 |         if len(inp.shape) == 2:
37 |             inp = inp.unsqueeze(0)
38 |         tmp = inp.shape[0]
39 |         if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
40 |             if len(inp.shape) == 3:
41 |                 inp = inp.reshape((-1, inp.shape[-1]))
42 |             inp = inp.t()
43 |         self.H *= self.nsamples / (self.nsamples + tmp)
44 |         self.nsamples += tmp
45 |         inp = math.sqrt(2 / self.nsamples) * inp.float()
46 |         self.H += inp.matmul(inp.t())
47 |
48 |     def fasterprune(
49 |         self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01
50 |     ):
51 |         W = self.layer.weight.data.clone()
52 |         if isinstance(self.layer, nn.Conv2d):
53 |             W = W.flatten(1)
54 |         if isinstance(self.layer, transformers.Conv1D):
55 |             W = W.t()
56 |         W = W.float()
57 |
58 |         if hasattr(self, 'quantizer'):
59 |             if not self.quantizer.ready():
60 |                 self.quantizer.find_params(W, weight=True)
61 |
62 |         tick = time.time()
63 |
64 |         H = self.H
65 |         del self.H
66 |         dead = torch.diag(H) == 0
67 |         H[dead, dead] = 1
68 |         W[:, dead] = 0
69 |
70 |         Losses = torch.zeros(self.rows, device=self.dev)
71 |
72 |         damp = percdamp * torch.mean(torch.diag(H))
73 |         diag = torch.arange(self.columns, device=self.dev)
74 |         H[diag, diag] += damp
75 |         H = torch.linalg.cholesky(H)
76 |         H = torch.cholesky_inverse(H)
77 |         H = torch.linalg.cholesky(H, upper=True)
78 |         Hinv = H
79 |
80 |         mask = None
81 |
82 |         for i1 in range(0, self.columns, blocksize):
83 |             i2 = min(i1 + blocksize, self.columns)
84 |             count = i2 - i1
85 |
86 |             W1 = W[:, i1:i2].clone()
87 |             Q1 = torch.zeros_like(W1)
88 |             Err1 = torch.zeros_like(W1)
89 |             Losses1 = torch.zeros_like(W1)
90 |             Hinv1 = Hinv[i1:i2, i1:i2]
91 |
92 |             if prunen == 0:
93 |                 if mask is not None:
94 |                     mask1 = mask[:, i1:i2]
95 |                 else:
96 |                     tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
97 |                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
98 |                     mask1 = tmp <= thresh
99 |             else:
100 |                 mask1 = torch.zeros_like(W1) == 1
101 |
102 |             for i in range(count):
103 |                 w = W1[:, i]
104 |                 d = Hinv1[i, i]
105 |
106 |                 if prunen != 0 and i % prunem == 0:
107 |                     tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
108 |                     mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)
109 |
110 |                 q = w.clone()
111 |                 q[mask1[:, i]] = 0
112 |
113 |                 if hasattr(self, 'quantizer'):
114 |                     q = quantize(
115 |                         q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
116 |                     ).flatten()
117 |
118 |                 Q1[:, i] = q
119 |                 Losses1[:, i] = (w - q) ** 2 / d ** 2
120 |
121 |                 err1 = (w - q) / d
122 |                 W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
123 |                 Err1[:, i] = err1
124 |
125 |             W[:, i1:i2] = Q1
126 |             Losses += torch.sum(Losses1, 1) / 2
127 |
128 |             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
129 |
130 |             if DEBUG:
131 |                 self.layer.weight.data[:, :i2] = W[:, :i2]
132 |                 self.layer.weight.data[:, i2:] = W[:, i2:]
133 |                 print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
134 |                 print(torch.sum(Losses))
135 |
136 |         torch.cuda.synchronize()
137 |         print('time %.2f' % (time.time() - tick))
138 |         print('error', torch.sum(Losses).item())
139 |
140 |         if isinstance(self.layer, transformers.Conv1D):
141 |             W = W.t()
142 |         self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
143 |         if DEBUG:
144 |             print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
145 |
146 |     def free(self):
147 |         if DEBUG:
148 |             self.inp1 = None
149 |             self.out1 = None
150 |         self.H = None
151 |         torch.cuda.empty_cache()
152 |
--------------------------------------------------------------------------------
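The loop in opt.py above drives SparseGPT and Quantizer layer by layer; the standalone sketch below (not part of the repository) shows the same add_batch -> fasterprune -> free workflow on a single toy nn.Linear layer. The layer size, random calibration batches, and 4-bit setting are invented for illustration, and a CUDA-enabled PyTorch build is assumed because fasterprune() calls torch.cuda.synchronize().

import torch
import torch.nn as nn

from quant import Quantizer
from sparsegpt import SparseGPT

torch.manual_seed(0)

# Toy stand-in for one OPT projection layer; sizes are arbitrary.
layer = nn.Linear(256, 256)
gpt = SparseGPT(layer)

# Optional joint quantization, mirroring the `args.wbits < 16` branch in opt.py.
gpt.quantizer = Quantizer()
gpt.quantizer.configure(4, perchannel=True, sym=False, mse=False)

# Accumulate the Hessian estimate from (fake) calibration inputs.
for _ in range(8):
    inp = torch.randn(4, 256)
    gpt.add_batch(inp, layer(inp))

# Prune to 50% unstructured sparsity (use prunen=2, prunem=4 for 2:4 sparsity).
gpt.fasterprune(0.5, prunen=0, prunem=0, percdamp=0.01, blocksize=128)
gpt.free()

print('zero fraction:', (layer.weight.data == 0).float().mean().item())

The full pipeline is instead run from the command line with something like `python opt.py facebook/opt-125m c4 --sparsity 0.5`, which prunes every decoder layer sequentially and then reports perplexity on wikitext2, ptb, and c4.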