├── .gitignore ├── README.md ├── llama.py ├── llama_model.py ├── llama_tokenizer.py ├── modelutils.py ├── opt.py ├── pythia.py ├── quant.py ├── retrain ├── prune_utils.py └── train.py ├── rwkv.py └── smart_compressors.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | build 3 | dist 4 | *.txt 5 | *.pt 6 | *egg-info* 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse + Quant LLMs [WIP] 2 | 3 | ## Do not use unless you really know what you are doing. Broken in multiple places. 4 | 5 | 6 | 7 | Based on SparseGPT, GPTQ, Optimal BERT Surgeon, and others. 8 | 9 | Run `opt.py` to prune and/or quantize an OPT model layer by layer: 10 | 11 | `python3 opt.py facebook/opt-X [--seed 0] [--nsamples 128] [--wbits 16] [--groupsize -1] [--save PATH_TO_SAVE] [--compression_type {quantizeonly, prunemaskonly, prunemaskreconstruction, prunemagnitudemask, quantizeprune, none}] [--amount_prune 0.5]` 12 | 13 | ## Requirements 14 | 15 | - torch == 1.13.1 16 | - transformers == 4.21.2 17 | - sentencepiece == 0.1.97 18 | 19 | -------------------------------------------------------------------------------- /llama.py: -------------------------------------------------------------------------------- 1 | """Compress LLaMa models.""" 2 | import os 3 | import random 4 | import time 5 | import json 6 | import copy 7 | 8 | import torch 9 | import torch.nn as nn 10 | from datasets import load_dataset 11 | import llama_model 12 | import llama_tokenizer 13 | 14 | import smart_compressors 15 | import quant 16 | from pathlib import Path 17 | 18 | DEVICE = torch.device('cpu') 19 | 20 | if torch.cuda.is_available(): 21 | DEVICE = torch.device('cuda:0') # pylint: disable=no-member 22 | 23 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): # pylint: disable=dangerous-default-value 24 | """Find linear and conv layers in a model.""" 25 | if type(module) in layers: 26 | return {name: module} 27 | res = {} 28 | for name1, child in module.named_children(): 29 | res.update(find_layers( 30 | child, layers=layers, name=name + '.'
+ name1 if name != '' else name1 31 | )) 32 | return res 33 | 34 | def get_wikitext2(nsamples, seed, seqlen, tokenizer_path): 35 | """For now we take nsamples datapoints from wikitext2 and tokenize them.""" 36 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 37 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 38 | 39 | tokenizer = llama_tokenizer.Tokenizer(tokenizer_path) 40 | trainenc = tokenizer.encode("\n\n".join(traindata['text']), bos=False, eos=False) 41 | testenc = tokenizer.encode("\n\n".join(testdata['text']), bos=False, eos=False) 42 | bos_id = tokenizer.sp_model.bos_id() 43 | 44 | random.seed(seed) 45 | trainloader = [] 46 | print('\n\n\n', len(trainenc), '\n\n', len(trainenc) - seqlen - 1) 47 | for _ in range(nsamples): 48 | # print('\n\n\n', '-----', '\n\n\n') 49 | i = random.randint(0, len(trainenc) - seqlen - 1) 50 | j = i + seqlen - 1 51 | inp = [bos_id] + trainenc[i:j] 52 | tar = copy.deepcopy(inp) 53 | tar = [-100] * (len(tar) - 1) + tar[-1:] 54 | trainloader.append((torch.LongTensor([inp]), torch.LongTensor([tar]))) 55 | _testloader = [] 56 | print('\n\n\n', len(testenc), '\n\n', len(testenc) - seqlen - 1) 57 | for _ in range(nsamples): 58 | # print('\n\n\n', '-----', '\n\n\n') 59 | i = random.randint(0, len(testenc) - seqlen - 1) 60 | j = i + seqlen - 1 61 | inp = [bos_id] + testenc[i:j] 62 | tar = copy.deepcopy(inp) 63 | tar = [-100] * (len(tar) - 1) + tar[-1:] 64 | _testloader.append((torch.LongTensor([inp]), torch.LongTensor([tar]))) 65 | return trainloader, _testloader 66 | 67 | def benchmark(model_to_be_benched, _dataloader): 68 | """Benchmark a model.""" 69 | current_device = next(model.parameters()).device 70 | model_to_be_benched = model_to_be_benched.to(DEVICE) 71 | data_iterator = iter(_dataloader) 72 | loss_fn = nn.CrossEntropyLoss() 73 | with torch.no_grad(): 74 | loss = 0.0 75 | for i in range(100): 76 | inputs = next(data_iterator) 77 | # print('adsas:', inputs, 'iter:', _dataloader) 78 | inputs = inputs[0].to(DEVICE) 79 | outputs = model_to_be_benched(inputs[:, :-1]) 80 | loss += loss_fn(outputs.permute(0, 2, 1), inputs[:, 1:]).item() 81 | if i % 10 == 5: 82 | print(i) 83 | model_to_be_benched = model_to_be_benched.to(current_device) 84 | return loss 85 | 86 | def get_model(model_dir): 87 | """Get llama model.""" 88 | def skip(*args, **kwargs): # pylint: disable=unused-argument, redefined-outer-name 89 | pass 90 | torch.nn.init.kaiming_uniform_ = skip 91 | torch.nn.init.uniform_ = skip 92 | torch.nn.init.normal_ = skip 93 | 94 | with open(Path(model_dir) / "params.json", "r") as f: 95 | params = json.loads(f.read()) 96 | 97 | checkpoint = torch.load(model_dir + "/consolidated.00.pth", map_location="cpu") 98 | if "consolidated.01.pth" in os.listdir(model_dir): 99 | checkpoint_2 = torch.load(model_dir + "consolidated.01.pth", map_location="cpu") 100 | # for every key concatenate the tensors in the two checkpoints 101 | i = 0 102 | for key in checkpoint.keys(): 103 | if 'norm.weight' in key: 104 | continue 105 | dim = 1 if any(x in key for x in [ 106 | '.feed_forward.w2.weight', '.attention.wo.weight', 'tok_embeddings.weight']) else 0 107 | checkpoint[key] = torch.cat((checkpoint[key], checkpoint_2[key]), dim=dim) 108 | if i <= 5: 109 | print(key, checkpoint[key].shape, checkpoint_2[key].shape) 110 | i += 1 111 | 112 | model_args = llama_model.ModelArgs(max_seq_len=2048, **params) 113 | model_loaded = llama_model.Transformer(model_args) 114 | model_loaded.seqlen = 2048 # We need this for the dataloader 
trimming. 115 | print(model_loaded.state_dict().keys()) 116 | print(checkpoint.keys()) 117 | model_loaded.load_state_dict({k: v for k, v in checkpoint.items() if 'rope.freqs' not in k}) 118 | model_loaded = model_loaded.cuda() 119 | 120 | # If device is CPU then we convert from fp16 to fp32 121 | if DEVICE.type == 'cpu': 122 | model_loaded = model_loaded.half().to(torch.float32) 123 | print("existing") 124 | return model_loaded 125 | 126 | @torch.no_grad() 127 | def model_sequential(model, dataloader, device, compressor_class): # pylint: disable=redefined-outer-name 128 | """Optimize model sequentially.""" 129 | print('Starting ...') 130 | device = "cuda" 131 | layers = model.layers 132 | 133 | # Transfer to device 134 | model = model.to(device) 135 | 136 | # Initialize inputs, cache 137 | dtype = next(iter(model.parameters())).dtype 138 | inps = torch.zeros( 139 | (args.nsamples, model.seqlen, model.params.dim), dtype=dtype, device=device 140 | ) 141 | cache = {'i': 0, 'attention_mask': None} 142 | print('\n\n\nseq\n\n\n') 143 | 144 | # Get input and attention mask after layer 0 145 | class Catcher(nn.Module): # pylint: disable=missing-class-docstring 146 | def __init__(self, module): 147 | super().__init__() 148 | self.module = module 149 | def forward(self, inp, mask, **kwargs): 150 | """Forward pass.""" 151 | # print('Catcher input:', inp) 152 | inps[cache['i']] = inp 153 | cache['i'] += 1 154 | cache['attention_mask'] = mask 155 | raise ValueError 156 | layers[0] = Catcher(layers[0]) 157 | for batch in dataloader: 158 | try: 159 | model(batch[0].to(device)) 160 | except ValueError: 161 | pass 162 | layers[0] = layers[0].module 163 | 164 | # Transfer back to CPU 165 | model = model.cpu() 166 | layers[0] = layers[0].cpu() 167 | torch.cuda.empty_cache() 168 | 169 | outs = torch.zeros_like(inps) # Store outputs after each layer # pylint: disable=no-member 170 | attention_mask = cache['attention_mask'] 171 | print('Ready.') 172 | 173 | all_compressors = {} # pylint: disable=redefined-outer-name 174 | for i in range(len(layers)): # pylint: disable=consider-using-enumerate 175 | layer = layers[i].to(device) 176 | 177 | # Find linear layers and initialize quantizer for it 178 | subset = find_layers(layer) 179 | # print(subset) 180 | single_layer_compressor = {} 181 | for name in subset: # pylint: disable=consider-using-dict-items 182 | single_layer_compressor[name] = compressor_class(subset[name], args.amount_prune) 183 | single_layer_compressor[name].quantizer = quant.Quantizer() 184 | single_layer_compressor[name].quantizer.configure( 185 | args.wbits, perchannel=True, sym=False, mse=False 186 | ) 187 | 188 | def add_batch(name): 189 | def tmp(_, inp, out): 190 | single_layer_compressor[name].add_batch(inp[0].data, out.data) 191 | return tmp 192 | handles = [] 193 | for name in subset: # pylint: disable=consider-using-dict-items 194 | handles.append(subset[name].register_forward_hook(add_batch(name))) 195 | for j in range(args.nsamples): 196 | # print('inps:', inps[j], inps[j].shape) 197 | outs[j] = layer(inps[j].unsqueeze(0), mask=attention_mask)[0] 198 | for hhh in handles: 199 | hhh.remove() 200 | 201 | for name in subset: 202 | print(i, name) 203 | print('Quantizing ...') 204 | single_layer_compressor[name].fasterquant( 205 | percdamp=args.percdamp, groupsize=args.groupsize) 206 | # all_compressors[ 207 | # 'model.decoder.layers.%d.%s' % (i, name)] = single_layer_compressor[name] # pylint: disable=consider-using-f-string 208 | single_layer_compressor[name].free() 209 | for j in 
range(args.nsamples): 210 | outs[j] = layer(inps[j].unsqueeze(0), mask=attention_mask)[0] 211 | 212 | layers[i] = layer.cpu() 213 | del layer 214 | del single_layer_compressor 215 | torch.cuda.empty_cache() 216 | 217 | inps, outs = outs, inps 218 | 219 | return all_compressors 220 | 221 | 222 | if __name__ == '__main__': 223 | import argparse 224 | 225 | # Parse the arguments 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument('model', type=str, help='LLaMa model path to load;') 228 | parser.add_argument('vocab', type=str, help='LLaMa vocab path to load;') 229 | parser.add_argument('--seed', type=int, default=0, 230 | help='Seed for sampling the calibration data.') 231 | parser.add_argument('--nsamples', type=int, default=128, 232 | help='Number of calibration data samples.') 233 | parser.add_argument('--percdamp', type=float, default=.01, 234 | help='Percent of the average Hessian diagonal to use for dampening.') 235 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 16], 236 | help='#bits to use for quantization; use 16 for evaluating base model.') 237 | parser.add_argument('--groupsize', type=int, default=-1, 238 | help='Groupsize to use for quantization/pruning; default uses full row.') 239 | parser.add_argument('--do_save', action='store_true', default=False, 240 | help='Whether to save or not') 241 | parser.add_argument('--savepath', type=str, default='', 242 | help='Save quantized/pruned checkpoint under this name.') 243 | parser.add_argument('--load', type=str, default='', 244 | help='Load quantized/pruned checkpoint under this name.') 245 | parser.add_argument('--compression_type', type=str, required=True, 246 | choices=['quantizeonly', 'prunemaskonly', 'prunemaskreconstruction', 247 | 'prunemagnitudemask', 'quantizeprune', 'none',# 'pruneonly' 248 | ], 249 | help='Type of compression to perform.') 250 | parser.add_argument('--amount_prune', type=float, default=0.5, 251 | help='Amount of pruning to perform.') 252 | args = parser.parse_args() 253 | if args.savepath != '': 254 | args.do_save = True 255 | if args.compression_type == 'none': 256 | args.do_save = None 257 | 258 | # If prune is to be done then args.amount_prune must be between 0 and 1 259 | if args.compression_type in ['pruneonly', 'quantizeprune', 'prunemaskonly', 260 | 'prunemaskreconstruction']: 261 | assert 0 <= args.amount_prune <= 1, 'Amount of pruning must be between 0 and 1' 262 | 263 | # Load model 264 | model = get_model(args.model) 265 | model.eval() 266 | if args.load: 267 | model.load_state_dict(torch.load(args.load)) 268 | if args.compression_type != 'quantizeonly': 269 | args.wbits = 16 270 | 271 | # Load data 272 | dataloader, testloader = get_wikitext2( 273 | nsamples=args.nsamples, seed=args.seed, seqlen=model.seqlen, tokenizer_path=args.vocab) 274 | 275 | # Perform compression 276 | if args.compression_type != None: 277 | compression_class = None # pylint: disable=invalid-name 278 | if args.compression_type == 'quantizeonly': 279 | assert args.wbits < 16, 'Quantize only works with 4-bit quantization' 280 | compression_class = smart_compressors.QuantizeOnly 281 | elif args.compression_type == 'prunemaskonly': 282 | compression_class = smart_compressors.PruneMaskOnly 283 | elif args.compression_type == 'prunemaskreconstruction': 284 | compression_class = smart_compressors.PruneMaskReconstruction 285 | elif args.compression_type == 'prunemagnitudemask': 286 | compression_class = smart_compressors.PruneMagnitudeMask 287 | elif args.compression_type == 'none': 288 | 
pass 289 | elif args.compression_type == 'quantizeprune': 290 | raise NotImplementedError 291 | else: 292 | raise ValueError('Unknown compression type: %s' % args.compression_type) 293 | 294 | if compression_class is not None: 295 | tick = time.time() 296 | computed_compressors = model_sequential(model, dataloader, DEVICE, compression_class) 297 | print("Total time taken: %.2f seconds" % (time.time() - tick)) # pylint: disable=consider-using-f-string 298 | 299 | savefolder = None 300 | if args.do_save: 301 | if args.savepath == '': 302 | raise ValueError('Must specify savepath if do_save is True') 303 | savefolder = os.path.join( 304 | os.path.dirname(args.savepath), 305 | f'Model-{args.model.replace("/", "_")}_Compression-{args.compression_type}_Prune-{args.amount_prune}_Bits-{args.wbits}_Group-{args.groupsize}.pth') 306 | 307 | # Save 308 | if args.do_save: 309 | torch.save(model.state_dict(), savefolder) 310 | 311 | # Do benchmark 312 | if args.compression_type in ["quantizeonly", "prunemaskonly", "prunemaskreconstruction", 313 | "none"]: 314 | model = model.to(DEVICE) 315 | score = benchmark(model, testloader) 316 | print(score, savefolder) 317 | if args.do_save: 318 | open(savefolder + ".score", 'w+').write(str(score) + '\n') 319 | print(score, savefolder) 320 | 321 | print("Done") 322 | print("\n" * 5) 323 | -------------------------------------------------------------------------------- /llama_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Optional, Tuple 5 | from dataclasses import dataclass 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | freqs_cis = None 13 | @dataclass 14 | class ModelArgs: 15 | dim: int = 512 16 | n_layers: int = 8 17 | n_heads: int = 8 18 | vocab_size: int = -1 # defined later by tokenizer 19 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 20 | norm_eps: float = 1e-5 21 | 22 | max_seq_len: int = 2048 23 | 24 | 25 | class RMSNorm(torch.nn.Module): 26 | def __init__(self, dim: int, eps: float = 1e-6): 27 | super().__init__() 28 | self.eps = eps 29 | self.weight = nn.Parameter(torch.ones(dim)) 30 | 31 | def _norm(self, x): 32 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 33 | 34 | def forward(self, x): 35 | output = self._norm(x.float()).type_as(x) 36 | return output * self.weight 37 | 38 | 39 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 40 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 41 | t = torch.arange(end, device=freqs.device) # type: ignore 42 | freqs = torch.outer(t, freqs).float() # type: ignore 43 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 44 | return freqs_cis 45 | 46 | 47 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 48 | # print('\n\n\n\n') 49 | # print(freqs_cis.shape, x.shape) 50 | # print('\n\n\n\n') 51 | ndim = x.ndim 52 | assert 0 <= 1 < ndim 53 | if freqs_cis.shape[0] == x.shape[1] + 1: 54 | freqs_cis = freqs_cis[:-1, :] 55 | # print(freqs_cis.shape, x.shape, 'p1') 56 | # else: 57 | # print(freqs_cis.shape, x.shape, 'p2') 58 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (freqs_cis.shape, x.shape) 59 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in 
enumerate(x.shape)] 60 | return freqs_cis.view(*shape) 61 | 62 | 63 | def apply_rotary_emb( 64 | xq: torch.Tensor, 65 | xk: torch.Tensor, 66 | freqs_cis: torch.Tensor, 67 | ) -> Tuple[torch.Tensor, torch.Tensor]: 68 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 69 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 70 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 71 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 72 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 73 | return xq_out.type_as(xq), xk_out.type_as(xk) 74 | 75 | 76 | class Attention(nn.Module): 77 | def __init__(self, args: ModelArgs): 78 | super().__init__() 79 | 80 | self.n_local_heads = args.n_heads 81 | self.head_dim = args.dim // args.n_heads 82 | 83 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 84 | self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 85 | self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 86 | 87 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) 88 | 89 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]): 90 | bsz, seqlen, _ = x.shape 91 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 92 | 93 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 94 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) 95 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) 96 | 97 | device = xq.device 98 | freqs_cis_current = freqs_cis.to(device) 99 | 100 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis_current) 101 | 102 | keys = xk 103 | values = xv 104 | 105 | xq = xq.transpose(1, 2) 106 | keys = keys.transpose(1, 2) 107 | values = values.transpose(1, 2) 108 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 109 | if mask is not None: 110 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) 111 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 112 | output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) 113 | output = output.transpose( 114 | 1, 2 115 | ).contiguous().view(bsz, seqlen, -1) 116 | 117 | return self.wo(output) 118 | 119 | 120 | class FeedForward(nn.Module): 121 | def __init__( 122 | self, 123 | dim: int, 124 | hidden_dim: int, 125 | multiple_of: int, 126 | ): 127 | super().__init__() 128 | hidden_dim = int(2 * hidden_dim / 3) 129 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 130 | 131 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 132 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 133 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 134 | 135 | def forward(self, x): 136 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 137 | 138 | 139 | class TransformerBlock(nn.Module): 140 | def __init__(self, layer_id: int, args: ModelArgs): 141 | super().__init__() 142 | self.n_heads = args.n_heads 143 | self.dim = args.dim 144 | self.head_dim = args.dim // args.n_heads 145 | self.attention = Attention(args) 146 | self.feed_forward = FeedForward( 147 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of 148 | ) 149 | self.layer_id = layer_id 150 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 151 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 152 | 153 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]): 154 | # print("Inside tformer") 155 | h = x + self.attention.forward(self.attention_norm(x), mask) 156 | out = h + 
self.feed_forward.forward(self.ffn_norm(h)) 157 | return out 158 | 159 | 160 | class Transformer(nn.Module): 161 | def __init__(self, params: ModelArgs): 162 | super().__init__() 163 | self.params = params 164 | self.vocab_size = params.vocab_size 165 | self.n_layers = params.n_layers 166 | 167 | self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) 168 | 169 | global freqs_cis 170 | freqs_cis = precompute_freqs_cis( 171 | self.params.dim // self.params.n_heads, self.params.max_seq_len 172 | ) 173 | self.layers = torch.nn.ModuleList() 174 | for layer_id in range(params.n_layers): 175 | self.layers.append(TransformerBlock(layer_id, params)) 176 | 177 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 178 | self.output = nn.Linear(params.dim, params.vocab_size, bias=False) 179 | 180 | 181 | # @torch.inference_mode() 182 | def forward(self, tokens: torch.Tensor, start_pos: int = 0): 183 | _bsz, seqlen = tokens.shape 184 | h = self.tok_embeddings(tokens) 185 | # print("tok_emb:", h) 186 | # self.freqs_cis = self.freqs_cis.to(h.device) 187 | # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] 188 | 189 | mask = None 190 | if seqlen > 1: 191 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device) 192 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) 193 | 194 | i = 0 195 | for layer in self.layers: 196 | # print("\n\n\n------", i, "------\n\n\n\n", layer, '\n\n') 197 | i += 1 198 | h = layer(h, mask) 199 | h = self.norm(h) 200 | output = self.output(h[:, :, :]) # only compute last logits 201 | return output.float() 202 | -------------------------------------------------------------------------------- /llama_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
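#
# A minimal usage sketch for the Tokenizer wrapper defined below (hedged; the path is
# illustrative and must point at a real SentencePiece file, e.g. the "tokenizer.model"
# shipped with a LLaMA checkpoint):
#
#     tok = Tokenizer(model_path="/path/to/tokenizer.model")
#     ids = tok.encode("Hello world", bos=True, eos=False)  # List[int], BOS id prepended
#     text = tok.decode(ids)                                # back to a plain string
#     print(tok.n_words, tok.bos_id, tok.eos_id)            # vocab size and special token ids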
3 | 4 | from sentencepiece import SentencePieceProcessor 5 | from logging import getLogger 6 | from typing import List 7 | import os 8 | 9 | logger = getLogger() 10 | 11 | class Tokenizer: 12 | def __init__(self, model_path: str): 13 | # reload tokenizer 14 | assert os.path.isfile(model_path), model_path 15 | self.sp_model = SentencePieceProcessor(model_file=model_path) 16 | logger.info(f"Reloaded SentencePiece model from {model_path}") 17 | 18 | # BOS / EOS token IDs 19 | self.n_words: int = self.sp_model.vocab_size() 20 | self.bos_id: int = self.sp_model.bos_id() 21 | self.eos_id: int = self.sp_model.eos_id() 22 | self.pad_id: int = self.sp_model.pad_id() 23 | logger.info( 24 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 25 | ) 26 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 27 | 28 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 29 | assert type(s) is str 30 | t = self.sp_model.encode(s) 31 | if bos: 32 | t = [self.bos_id] + t 33 | if eos: 34 | t = t + [self.eos_id] 35 | return t 36 | 37 | def decode(self, t: List[int]) -> str: 38 | return self.sp_model.decode(t) 39 | -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | DEV = torch.device('cuda:0') 6 | 7 | 8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 9 | if type(module) in layers: 10 | return {name: module} 11 | res = {} 12 | for name1, child in module.named_children(): 13 | res.update(find_layers( 14 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 15 | )) 16 | return res 17 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | """Compress OPT models.""" 2 | import random 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from datasets import load_dataset 8 | from transformers import AutoTokenizer 9 | from transformers import OPTForCausalLM 10 | 11 | import smart_compressors 12 | import quant 13 | 14 | DEVICE = torch.device('cpu') 15 | 16 | if torch.cuda.is_available(): 17 | DEVICE = torch.device('cuda:0') # pylint: disable=no-member 18 | 19 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): # pylint: disable=dangerous-default-value 20 | """Find linear and conv layers in a model.""" 21 | if type(module) in layers: 22 | return {name: module} 23 | res = {} 24 | for name1, child in module.named_children(): 25 | res.update(find_layers( 26 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 27 | )) 28 | return res 29 | 30 | def get_wikitext2(nsamples, seed, seqlen, model_card): 31 | """For now we take nsamples datapoints from wikitext2 and tokenize them.""" 32 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 33 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 34 | 35 | tokenizer = AutoTokenizer.from_pretrained(model_card, use_fast=False) 36 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 37 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 38 | 39 | random.seed(seed) 40 | trainloader = [] 41 | for _ in range(nsamples): 42 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 43 | j = i + seqlen 44 | inp = trainenc.input_ids[:, i:j] 45 | tar = inp.clone() 46 | tar[:, :-1] = -100 47 | trainloader.append((inp, tar)) 48 | return trainloader, testenc 49 | 50 | def benchmark(model_to_be_benched, _dataloader): 51 | """Benchmark a model.""" 52 | current_device = model_to_be_benched.device 53 | model_to_be_benched = model_to_be_benched.to(DEVICE) 54 | data_iterator = iter(_dataloader) 55 | loss_fn = nn.CrossEntropyLoss() 56 | with torch.no_grad(): 57 | loss = 0.0 58 | for i in range(100): 59 | inputs = next(data_iterator)[0].to(DEVICE) 60 | outputs = model_to_be_benched(inputs[:, :-1]) 61 | loss += loss_fn(outputs.logits.permute(0, 2, 1), inputs[:, 1:]).item() 62 | if i % 10 == 5: 63 | print(i) 64 | model_to_be_benched = model_to_be_benched.to(current_device) 65 | return loss 66 | 67 | def get_opt(model_name): 68 | """Get opt model.""" 69 | def skip(*args, **kwargs): # pylint: disable=unused-argument, redefined-outer-name 70 | pass 71 | torch.nn.init.kaiming_uniform_ = skip 72 | torch.nn.init.uniform_ = skip 73 | torch.nn.init.normal_ = skip 74 | model_loaded = OPTForCausalLM.from_pretrained(model_name, torch_dtype='auto') 75 | model_loaded.seqlen = model_loaded.config.max_position_embeddings # We need this for the dataloader trimming. 
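# Note: get_wikitext2 slices each calibration sample to `seqlen` tokens, so this caps the
# samples at the model's context window (config.max_position_embeddings, 2048 for the
# facebook/opt-* checkpoints).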
76 | # If device is CPU then we convert from fp16 to fp32 77 | if DEVICE.type == 'cpu': 78 | model_loaded = model_loaded.half().to(torch.float32) 79 | return model_loaded 80 | 81 | @torch.no_grad() 82 | def opt_sequential(model, dataloader, device, compressor_class): # pylint: disable=redefined-outer-name 83 | """Optimize model sequentially.""" 84 | print('Starting ...') 85 | 86 | use_cache = model.config.use_cache 87 | model.config.use_cache = False 88 | layers = model.model.decoder.layers 89 | 90 | # Transfer to device 91 | model = model.to(device) 92 | 93 | # Initialize inputs, cache 94 | dtype = next(iter(model.parameters())).dtype 95 | inps = torch.zeros( 96 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device 97 | ) 98 | cache = {'i': 0, 'attention_mask': None} 99 | 100 | # Get input and attention mask after layer 0 101 | class Catcher(nn.Module): # pylint: disable=missing-class-docstring 102 | def __init__(self, module): 103 | super().__init__() 104 | self.module = module 105 | def forward(self, inp, **kwargs): 106 | """Forward pass.""" 107 | inps[cache['i']] = inp 108 | cache['i'] += 1 109 | cache['attention_mask'] = kwargs['attention_mask'] 110 | raise ValueError 111 | layers[0] = Catcher(layers[0]) 112 | for batch in dataloader: 113 | try: 114 | model(batch[0].to(device)) 115 | except ValueError: 116 | pass 117 | layers[0] = layers[0].module 118 | 119 | # Transfer back to CPU 120 | model = model.cpu() 121 | layers[0] = layers[0].cpu() 122 | torch.cuda.empty_cache() 123 | 124 | outs = torch.zeros_like(inps) # Store outputs after each layer # pylint: disable=no-member 125 | attention_mask = cache['attention_mask'] 126 | print('Ready.') 127 | 128 | all_compressors = {} # pylint: disable=redefined-outer-name 129 | for i in range(len(layers)): # pylint: disable=consider-using-enumerate 130 | layer = layers[i].to(device) 131 | 132 | # Find linear layers and initialize quantizer for it 133 | subset = find_layers(layer) 134 | single_layer_compressor = {} 135 | for name in subset: # pylint: disable=consider-using-dict-items 136 | single_layer_compressor[name] = compressor_class(subset[name], args.amount_prune) 137 | single_layer_compressor[name].quantizer = quant.Quantizer() 138 | single_layer_compressor[name].quantizer.configure( 139 | args.wbits, perchannel=True, sym=False, mse=False 140 | ) 141 | 142 | def add_batch(name): 143 | def tmp(_, inp, out): 144 | single_layer_compressor[name].add_batch(inp[0].data, out.data) 145 | return tmp 146 | handles = [] 147 | for name in subset: # pylint: disable=consider-using-dict-items 148 | handles.append(subset[name].register_forward_hook(add_batch(name))) 149 | for j in range(args.nsamples): 150 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 151 | for hhh in handles: 152 | hhh.remove() 153 | 154 | for name in subset: 155 | print(i, name) 156 | print('Quantizing ...') 157 | single_layer_compressor[name].fasterquant( 158 | percdamp=args.percdamp, groupsize=args.groupsize) 159 | # all_compressors[ 160 | # 'model.decoder.layers.%d.%s' % (i, name)] = single_layer_compressor[name] # pylint: disable=consider-using-f-string 161 | single_layer_compressor[name].free() 162 | for j in range(args.nsamples): 163 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 164 | 165 | layers[i] = layer.cpu() 166 | del layer 167 | del single_layer_compressor 168 | torch.cuda.empty_cache() 169 | 170 | inps, outs = outs, inps 171 | 172 | model.config.use_cache = use_cache 173 | 174 | return 
all_compressors 175 | 176 | if __name__ == '__main__': 177 | import argparse 178 | 179 | # Parse the arguments 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.') 182 | parser.add_argument('--seed', type=int, default=0, 183 | help='Seed for sampling the calibration data.') 184 | parser.add_argument('--nsamples', type=int, default=128, 185 | help='Number of calibration data samples.') 186 | parser.add_argument('--percdamp', type=float, default=.01, 187 | help='Percent of the average Hessian diagonal to use for dampening.') 188 | parser.add_argument('--wbits', type=int, default=16, choices=[4, 16], 189 | help='#bits to use for quantization; use 16 for evaluating base model.') 190 | parser.add_argument('--groupsize', type=int, default=-1, 191 | help='Groupsize to use for quantization/pruning; default uses full row.') 192 | parser.add_argument('--save', type=str, default='', 193 | help='Save quantized/pruned checkpoint under this name.') 194 | parser.add_argument('--load', type=str, default='', 195 | help='Load quantized/pruned checkpoint under this name.') 196 | parser.add_argument('--compression_type', type=str, required=True, 197 | choices=['quantizeonly', 'prunemaskonly', 'prunemaskreconstruction', 198 | 'prunemagnitudemask', 'quantizeprune', 'none',# 'pruneonly' 199 | ], 200 | help='Type of compression to perform.') 201 | parser.add_argument('--amount_prune', type=float, default=0.5, 202 | help='Amount of pruning to perform.') 203 | args = parser.parse_args() 204 | 205 | # If prune is to be done then args.amount_prune must be between 0 and 1 206 | if args.compression_type in ['pruneonly', 'quantizeprune', 'prunemaskonly', 207 | 'prunemaskreconstruction']: 208 | assert 0 <= args.amount_prune <= 1, 'Amount of pruning must be between 0 and 1' 209 | 210 | # Load model 211 | model = get_opt(args.model) 212 | model.eval() 213 | if args.load: 214 | model.load_state_dict(torch.load(args.load)) 215 | 216 | # Load data 217 | dataloader, testloader = get_wikitext2( 218 | nsamples=args.nsamples, seed=args.seed, seqlen=model.seqlen, model_card=args.model) 219 | 220 | # Perform compression 221 | if args.wbits != 16: 222 | compression_class = None # pylint: disable=invalid-name 223 | if args.compression_type == 'quantizeonly': 224 | compression_class = smart_compressors.QuantizeOnly 225 | elif args.compression_type == 'prunemaskonly': 226 | compression_class = smart_compressors.PruneMaskOnly 227 | elif args.compression_type == 'prunemaskreconstruction': 228 | compression_class = smart_compressors.PruneMaskReconstruction 229 | elif args.compression_type == 'prunemagnitudemask': 230 | compression_class = smart_compressors.PruneMagnitudeMask 231 | elif args.compression_type == 'none': 232 | pass 233 | elif args.compression_type == 'quantizeprune': 234 | raise NotImplementedError 235 | else: 236 | raise ValueError('Unknown compression type: %s' % args.compression_type) 237 | 238 | if compression_class is not None: 239 | tick = time.time() 240 | computed_compressors = opt_sequential(model, dataloader, DEVICE, compression_class) 241 | print("Total time taken: %.2f seconds" % (time.time() - tick)) # pylint: disable=consider-using-f-string 242 | 243 | # Do benchmark 244 | if args.compression_type in ["quantizeonly", "prunemaskonly", "prunemaskreconstruction"]: 245 | model = model.to(DEVICE) 246 | print(benchmark(model, dataloader)) 247 | # elif args.compression_type == "pruneonly": 248 | # layer_names = ["self_attn.k_proj", 
"self_attn.v_proj", "self_attn.q_proj", 249 | # "self_attn.out_proj", "fc1", "fc2"] 250 | # all_compressor_keys = ['model.decoder.layers.%d.%s' % (i, name) # pylint: disable=consider-using-f-string 251 | # for i in range(len(model.model.decoder.layers)) 252 | # for name in layer_names] 253 | # model = model.to(DEVICE) 254 | # # First benchmark with no pruning 255 | # print("Benchmarking with no pruning...") 256 | # print(benchmark(model, dataloader)) 257 | # print("\n\n") 258 | # # Now prune with only mask out and no reconstruction 259 | # for key in all_compressor_keys: 260 | # computed_compressors[key].layer.weight.data *= computed_compressors[ 261 | # key].new_weight_with_mask[1] 262 | # print("Benchmarking with only mask out...") 263 | # print(benchmark(model, dataloader)) 264 | # print("\n\n") 265 | # # Now prune with masking and reconstruction 266 | # for key in all_compressor_keys: 267 | # print(key, torch.sum(computed_compressors[key].new_weight_with_mask[1] == 0).item() / 268 | # computed_compressors[key].new_weight_with_mask[1].numel()) 269 | # # # print percentage of new_weight_with_mask[1] that is 0 270 | # # print(torch.sum(computed_compressors[key].new_weight_with_mask[1] == 0).item() / 271 | # # computed_compressors[key].new_weight_with_mask[1].numel()) 272 | # computed_compressors[key].layer.weight.data = ( 273 | # computed_compressors[key].new_weight_with_mask[0] * computed_compressors[ 274 | # key].new_weight_with_mask[1]).half() 275 | # print("Benchmarking with mask out and reconstruction...") 276 | # print(benchmark(model, dataloader)) 277 | # print("\n\n") 278 | elif args.compression_type == "quantizeprune": 279 | raise NotImplementedError 280 | else: 281 | model = model.to(DEVICE) 282 | print(benchmark(model, dataloader)) 283 | 284 | # Save 285 | if args.save: 286 | torch.save(model.state_dict(), args.save) 287 | print("Done") 288 | print("\n" * 5) 289 | -------------------------------------------------------------------------------- /pythia.py: -------------------------------------------------------------------------------- 1 | """Compress OPT models.""" 2 | import os 3 | import random 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | from datasets import load_dataset 9 | from transformers import AutoTokenizer 10 | from transformers import AutoModelForCausalLM 11 | 12 | import smart_compressors 13 | import quant 14 | 15 | DEVICE = torch.device('cpu') 16 | 17 | if torch.cuda.is_available(): 18 | DEVICE = torch.device('cuda:0') # pylint: disable=no-member 19 | 20 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): # pylint: disable=dangerous-default-value 21 | """Find linear and conv layers in a model.""" 22 | if type(module) in layers: 23 | return {name: module} 24 | res = {} 25 | for name1, child in module.named_children(): 26 | res.update(find_layers( 27 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 28 | )) 29 | return res 30 | 31 | def get_wikitext2(nsamples, seed, seqlen, model_card): 32 | """For now we take nsamples datapoints from wikitext2 and tokenize them.""" 33 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 34 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 35 | 36 | tokenizer = AutoTokenizer.from_pretrained(model_card) 37 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 38 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 39 | 40 | random.seed(seed) 41 | trainloader = [] 42 | for _ in range(nsamples): 43 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 44 | j = i + seqlen 45 | inp = trainenc.input_ids[:, i:j] 46 | tar = inp.clone() 47 | tar[:, :-1] = -100 48 | trainloader.append((inp, tar)) 49 | test_loader = [] 50 | for _ in range(nsamples): 51 | i = random.randint(0, testenc.input_ids.shape[1] - seqlen - 1) 52 | j = i + seqlen 53 | inp = testenc.input_ids[:, i:j] 54 | tar = inp.clone() 55 | tar[:, :-1] = -100 56 | test_loader.append((inp, tar)) 57 | return trainloader, test_loader 58 | 59 | def benchmark(model_to_be_benched, _dataloader): 60 | """Benchmark a model.""" 61 | current_device = model_to_be_benched.device 62 | model_to_be_benched = model_to_be_benched.to(DEVICE) 63 | data_iterator = iter(_dataloader) 64 | loss_fn = nn.CrossEntropyLoss() 65 | with torch.no_grad(): 66 | loss = 0.0 67 | for i in range(100): 68 | inputs = next(data_iterator)[0].to(DEVICE) 69 | outputs = model_to_be_benched(inputs[:, :-1]) 70 | loss += loss_fn(outputs.logits.permute(0, 2, 1), inputs[:, 1:]).item() 71 | if i % 10 == 5: 72 | print(i) 73 | model_to_be_benched = model_to_be_benched.to(current_device) 74 | return loss 75 | 76 | def get_model(model_name): 77 | """Get model.""" 78 | def skip(*args, **kwargs): # pylint: disable=unused-argument, redefined-outer-name 79 | pass 80 | torch.nn.init.kaiming_uniform_ = skip 81 | torch.nn.init.uniform_ = skip 82 | torch.nn.init.normal_ = skip 83 | model_loaded = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto') 84 | model_loaded.seqlen = model_loaded.config.max_position_embeddings # We need this for the dataloader trimming. 
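# Note: as in opt.py, `seqlen` bounds the random wikitext2 windows drawn for both the
# calibration and test loaders; for Pythia/GPT-NeoX checkpoints config.max_position_embeddings
# is likewise 2048.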
85 | # If device is CPU then we convert from fp16 to fp32 86 | if DEVICE.type == 'cpu': 87 | model_loaded = model_loaded.half().to(torch.float32) 88 | return model_loaded 89 | 90 | @torch.no_grad() 91 | def model_sequential(model, dataloader, device, compressor_class): # pylint: disable=redefined-outer-name 92 | """Optimize model sequentially.""" 93 | print('Starting ...') 94 | 95 | use_cache = model.config.use_cache 96 | model.config.use_cache = False 97 | layers = model.gpt_neox.layers 98 | 99 | # Transfer to device 100 | model = model.to(device) 101 | 102 | # Initialize inputs, cache 103 | dtype = next(iter(model.parameters())).dtype 104 | inps = torch.zeros( 105 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device 106 | ) 107 | cache = {'i': 0, 'attention_mask': None} 108 | 109 | # Get input and attention mask after layer 0 110 | class Catcher(nn.Module): # pylint: disable=missing-class-docstring 111 | def __init__(self, module): 112 | super().__init__() 113 | self.module = module 114 | def forward(self, inp, **kwargs): 115 | """Forward pass.""" 116 | inps[cache['i']] = inp 117 | cache['i'] += 1 118 | cache['attention_mask'] = kwargs['attention_mask'] 119 | raise ValueError 120 | layers[0] = Catcher(layers[0]) 121 | for batch in dataloader: 122 | try: 123 | model(batch[0].to(device)) 124 | except ValueError: 125 | pass 126 | layers[0] = layers[0].module 127 | 128 | # Transfer back to CPU 129 | # model = model.cpu() 130 | # layers[0] = layers[0].cpu() 131 | torch.cuda.empty_cache() 132 | 133 | outs = torch.zeros_like(inps) # Store outputs after each layer # pylint: disable=no-member 134 | attention_mask = cache['attention_mask'] 135 | print('Ready.') 136 | 137 | all_compressors = {} # pylint: disable=redefined-outer-name 138 | for i in range(len(layers)): # pylint: disable=consider-using-enumerate 139 | layer = layers[i].to(device) 140 | 141 | # Find linear layers and initialize quantizer for it 142 | subset = find_layers(layer) 143 | single_layer_compressor = {} 144 | for name in subset: # pylint: disable=consider-using-dict-items 145 | single_layer_compressor[name] = compressor_class(subset[name], args.amount_prune) 146 | single_layer_compressor[name].quantizer = quant.Quantizer() 147 | single_layer_compressor[name].quantizer.configure( 148 | args.wbits, perchannel=True, sym=False, mse=False 149 | ) 150 | 151 | def add_batch(name): 152 | def tmp(_, inp, out): 153 | single_layer_compressor[name].add_batch(inp[0].data, out.data) 154 | return tmp 155 | handles = [] 156 | for name in subset: # pylint: disable=consider-using-dict-items 157 | handles.append(subset[name].register_forward_hook(add_batch(name))) 158 | for j in range(args.nsamples): 159 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 160 | for hhh in handles: 161 | hhh.remove() 162 | 163 | for name in subset: 164 | print(i, name) 165 | print('Quantizing ...') 166 | single_layer_compressor[name].fasterquant( 167 | percdamp=args.percdamp, groupsize=args.groupsize) 168 | # all_compressors[ 169 | # 'model.gpt_neox.layers.%d.%s' % (i, name)] = single_layer_compressor[name] # pylint: disable=consider-using-f-string 170 | single_layer_compressor[name].free() 171 | for j in range(args.nsamples): 172 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 173 | 174 | # layers[i] = layer.cpu() 175 | del layer 176 | del single_layer_compressor 177 | torch.cuda.empty_cache() 178 | 179 | inps, outs = outs, inps 180 | 181 | model.config.use_cache = use_cache 182 | 
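# Note: the assignments that would populate `all_compressors` are commented out above, so
# an empty dict is returned here; the compression itself has already been applied in place
# to each layer's weights by `fasterquant`.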
183 | return all_compressors 184 | 185 | if __name__ == '__main__': 186 | import argparse 187 | 188 | # Parse the arguments 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.') 191 | parser.add_argument('--seed', type=int, default=0, 192 | help='Seed for sampling the calibration data.') 193 | parser.add_argument('--nsamples', type=int, default=128, 194 | help='Number of calibration data samples.') 195 | parser.add_argument('--percdamp', type=float, default=.01, 196 | help='Percent of the average Hessian diagonal to use for dampening.') 197 | parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 16], 198 | help='#bits to use for quantization; use 16 for evaluating base model.') 199 | parser.add_argument('--groupsize', type=int, default=-1, 200 | help='Groupsize to use for quantization/pruning; default uses full row.') 201 | parser.add_argument('--do_save', action='store_true', default=False, 202 | help='Whether to save or not') 203 | parser.add_argument('--savepath', type=str, default='', 204 | help='Save quantized/pruned checkpoint under this name.') 205 | parser.add_argument('--load', type=str, default='', 206 | help='Load quantized/pruned checkpoint under this name.') 207 | parser.add_argument('--compression_type', type=str, required=True, 208 | choices=['quantizeonly', 'prunemaskonly', 'prunemaskreconstruction', 209 | 'prunemagnitudemask', 'quantizeprune', 'none',# 'pruneonly' 210 | ], 211 | help='Type of compression to perform.') 212 | parser.add_argument('--amount_prune', type=float, default=0.5, 213 | help='Amount of pruning to perform.') 214 | args = parser.parse_args() 215 | if args.savepath != '': 216 | args.do_save = True 217 | 218 | # If prune is to be done then args.amount_prune must be between 0 and 1 219 | if args.compression_type in ['pruneonly', 'quantizeprune', 'prunemaskonly', 220 | 'prunemaskreconstruction']: 221 | assert 0 <= args.amount_prune <= 1, 'Amount of pruning must be between 0 and 1' 222 | 223 | # Load model 224 | model = get_model(args.model) 225 | model.eval() 226 | if args.load: 227 | model.load_state_dict(torch.load(args.load)) 228 | if args.compression_type != 'quantizeonly': 229 | args.wbits = 16 230 | 231 | # Load data 232 | dataloader, testloader = get_wikitext2( 233 | nsamples=args.nsamples, seed=args.seed, seqlen=model.seqlen, model_card=args.model) 234 | 235 | # Perform compression 236 | if args.compression_type != None: 237 | compression_class = None # pylint: disable=invalid-name 238 | if args.compression_type == 'quantizeonly': 239 | assert args.wbits < 16, 'Quantize only works with 4-bit quantization' 240 | compression_class = smart_compressors.QuantizeOnly 241 | elif args.compression_type == 'prunemaskonly': 242 | compression_class = smart_compressors.PruneMaskOnly 243 | elif args.compression_type == 'prunemaskreconstruction': 244 | compression_class = smart_compressors.PruneMaskReconstruction 245 | elif args.compression_type == 'prunemagnitudemask': 246 | compression_class = smart_compressors.PruneMagnitudeMask 247 | elif args.compression_type == 'none': 248 | pass 249 | elif args.compression_type == 'quantizeprune': 250 | raise NotImplementedError 251 | else: 252 | raise ValueError('Unknown compression type: %s' % args.compression_type) 253 | 254 | if compression_class is not None: 255 | tick = time.time() 256 | computed_compressors = model_sequential(model, dataloader, DEVICE, compression_class) 257 | print("Total time taken: %.2f seconds" % 
(time.time() - tick)) # pylint: disable=consider-using-f-string 258 | 259 | savefolder = None 260 | if args.do_save: 261 | if args.savepath == '': 262 | raise ValueError('Must specify savepath if do_save is True') 263 | savefolder = os.path.join( 264 | os.path.dirname(args.savepath), 265 | f'Model-{args.model.replace("/", "_")}_Compression-{args.compression_type}_Prune-{args.amount_prune}_Bits-{args.wbits}_Group-{args.groupsize}.pth') 266 | 267 | # Do benchmark 268 | if args.compression_type in ["quantizeonly", "prunemaskonly", "prunemaskreconstruction", 269 | "none"]: 270 | model = model.to(DEVICE) 271 | score = benchmark(model, testloader) 272 | print(score, savefolder) 273 | if args.do_save: 274 | open(savefolder + ".score", 'w+').write(str(score) + '\n') 275 | print(score, savefolder) 276 | 277 | # Save 278 | if args.do_save: 279 | torch.save(model.state_dict(), savefolder) 280 | print("Done") 281 | print("\n" * 5) 282 | -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def quantize(x, scale, zero, maxq): 7 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 8 | return scale * (q - zero) 9 | 10 | class Quantizer(nn.Module): 11 | 12 | def __init__(self, shape=1): 13 | super(Quantizer, self).__init__() 14 | self.register_buffer('maxq', torch.tensor(0)) 15 | self.register_buffer('scale', torch.zeros(shape)) 16 | self.register_buffer('zero', torch.zeros(shape)) 17 | 18 | def configure( 19 | self, 20 | bits, perchannel=False, sym=True, 21 | mse=False, norm=2.4, grid=100, maxshrink=.8 22 | ): 23 | self.maxq = torch.tensor(2 ** bits - 1) 24 | self.perchannel = perchannel 25 | self.sym = sym 26 | self.mse = mse 27 | self.norm = norm 28 | self.grid = grid 29 | self.maxshrink = maxshrink 30 | 31 | def find_params(self, x, weight=False): 32 | dev = x.device 33 | self.maxq = self.maxq.to(dev) 34 | 35 | shape = x.shape 36 | if self.perchannel: 37 | if weight: 38 | x = x.flatten(1) 39 | else: 40 | if len(shape) == 4: 41 | x = x.permute([1, 0, 2, 3]) 42 | x = x.flatten(1) 43 | if len(shape) == 3: 44 | x = x.reshape((-1, shape[-1])).t() 45 | if len(shape) == 2: 46 | x = x.t() 47 | else: 48 | x = x.flatten().unsqueeze(0) 49 | 50 | tmp = torch.zeros(x.shape[0], device=dev) 51 | xmin = torch.minimum(x.min(1)[0], tmp) 52 | xmax = torch.maximum(x.max(1)[0], tmp) 53 | 54 | if self.sym: 55 | xmax = torch.maximum(torch.abs(xmin), xmax) 56 | tmp = xmin < 0 57 | if torch.any(tmp): 58 | xmin[tmp] = -xmax[tmp] 59 | tmp = (xmin == 0) & (xmax == 0) 60 | xmin[tmp] = -1 61 | xmax[tmp] = +1 62 | 63 | self.scale = (xmax - xmin) / self.maxq 64 | if self.sym: 65 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 66 | else: 67 | self.zero = torch.round(-xmin / self.scale) 68 | 69 | if self.mse: 70 | raise NotImplementedError 71 | best = torch.full([x.shape[0]], float('inf'), device=dev) 72 | for i in range(int(self.maxshrink * self.grid)): 73 | p = 1 - i / self.grid 74 | xmin1 = p * xmin 75 | xmax1 = p * xmax 76 | scale1 = (xmax1 - xmin1) / self.maxq 77 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 78 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 79 | q -= x 80 | q.abs_() 81 | q.pow_(self.norm) 82 | err = torch.sum(q, 1) 83 | tmp = err < best 84 | if torch.any(tmp): 85 | best[tmp] = err[tmp] 86 | self.scale[tmp] = scale1[tmp] 87 | self.zero[tmp] = zero1[tmp] 
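# Note: the `raise NotImplementedError` at the top of the `if self.mse:` branch makes the
# grid search above unreachable as written. The code below tiles per-tensor scale/zero
# values across channels when `perchannel` is disabled and reshapes them to broadcast
# against the weight or activation layout.
#
# A worked example of the affine scheme used throughout this module (hedged sketch with
# assumed values: bits=4 so maxq=15, and a row with xmin=-1.0, xmax=2.0):
#     scale = (2.0 - (-1.0)) / 15 = 0.2,   zero = round(1.0 / 0.2) = 5
#     quantize(0.47, scale, zero, maxq) rounds 0.47/0.2 to 2, adds zero -> q = 7,
#     and returns the dequantized value 0.2 * (7 - 5) = 0.4 (error 0.07 <= scale / 2).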
88 | if not self.perchannel: 89 | if weight: 90 | tmp = shape[0] 91 | else: 92 | tmp = shape[1] if len(shape) != 3 else shape[2] 93 | self.scale = self.scale.repeat(tmp) 94 | self.zero = self.zero.repeat(tmp) 95 | 96 | if weight: 97 | shape = [-1] + [1] * (len(shape) - 1) 98 | self.scale = self.scale.reshape(shape) 99 | self.zero = self.zero.reshape(shape) 100 | return 101 | if len(shape) == 4: 102 | self.scale = self.scale.reshape((1, -1, 1, 1)) 103 | self.zero = self.zero.reshape((1, -1, 1, 1)) 104 | if len(shape) == 3: 105 | self.scale = self.scale.reshape((1, 1, -1)) 106 | self.zero = self.zero.reshape((1, 1, -1)) 107 | if len(shape) == 2: 108 | self.scale = self.scale.unsqueeze(0) 109 | self.zero = self.zero.unsqueeze(0) 110 | 111 | def quantize(self, x): 112 | if self.ready(): 113 | return quantize(x, self.scale, self.zero, self.maxq) 114 | return x 115 | 116 | def enabled(self): 117 | return self.maxq > 0 118 | 119 | def ready(self): 120 | return torch.all(self.scale != 0) 121 | 122 | 123 | try: 124 | import quant_cuda 125 | except: 126 | print('CUDA extension not installed.') 127 | 128 | # Assumes layer is perfectly divisible into 1024 * 1024 blocks 129 | class Quant3Linear(nn.Module): 130 | 131 | def __init__(self, infeatures, outfeatures): 132 | super().__init__() 133 | self.register_buffer('zeros', torch.zeros((outfeatures, 1))) 134 | self.register_buffer('scales', torch.zeros((outfeatures, 1))) 135 | self.register_buffer('bias', torch.zeros(outfeatures)) 136 | self.register_buffer( 137 | 'qweight', torch.zeros((infeatures // 1024 * 96, outfeatures), dtype=torch.int) 138 | ) 139 | 140 | def pack(self, linear, scales, zeros): 141 | self.zeros = zeros * scales 142 | self.scales = scales.clone() 143 | self.bias = linear.bias.clone() 144 | 145 | intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int) 146 | intweight = intweight.t().contiguous() 147 | intweight = intweight.numpy().astype(np.uint32) 148 | qweight = np.zeros( 149 | (intweight.shape[0] // 1024 * 96, intweight.shape[1]), dtype=np.uint32 150 | ) 151 | i = 0 152 | row = 0 153 | while row < qweight.shape[0]: 154 | for j in range(i, i + 10): 155 | qweight[row] |= intweight[j] << (3 * (j - i)) 156 | i += 10 157 | qweight[row] |= intweight[i] << 30 158 | row += 1 159 | qweight[row] |= (intweight[i] >> 2) & 1 160 | i += 1 161 | for j in range(i, i + 10): 162 | qweight[row] |= intweight[j] << (3 * (j - i) + 1) 163 | i += 10 164 | qweight[row] |= intweight[i] << 31 165 | row += 1 166 | qweight[row] |= (intweight[i] >> 1) & 0x3 167 | i += 1 168 | for j in range(i, i + 10): 169 | qweight[row] |= intweight[j] << (3 * (j - i) + 2) 170 | i += 10 171 | row += 1 172 | 173 | qweight = qweight.astype(np.int32) 174 | self.qweight = torch.from_numpy(qweight) 175 | 176 | def forward(self, x): 177 | if x.shape[-1] == x.numel(): 178 | outshape = list(x.shape) 179 | y = self.bias.clone() 180 | outshape[-1] = self.bias.numel() 181 | dtype = x.dtype 182 | x = x.float() 183 | quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros) 184 | y = y.to(dtype) 185 | return y.reshape(outshape) 186 | raise ValueError('Only supports a single token currently.') 187 | 188 | def make_quant3(module, names, name=''): 189 | if isinstance(module, Quant3Linear): 190 | return 191 | for attr in dir(module): 192 | tmp = getattr(module, attr) 193 | name1 = name + '.' 
+ attr if name != '' else attr 194 | if name1 in names: 195 | setattr( 196 | module, attr, Quant3Linear(tmp.in_features, tmp.out_features) 197 | ) 198 | for name1, child in module.named_children(): 199 | make_quant3(child, names, name + '.' + name1 if name != '' else name1) 200 | -------------------------------------------------------------------------------- /retrain/prune_utils.py: -------------------------------------------------------------------------------- 1 | """Iteratively prune a model based on the magnitude of weights. 2 | 3 | PyTorch only supports x86/ARM for quantization. 4 | 5 | 6 | """ 7 | import collections # pylint: disable=syntax-error 8 | import copy 9 | 10 | from typing import Union, List # pylint: disable=syntax-error 11 | import torch 12 | from torch.nn.utils import prune as torch_prune # pylint: disable=wrong-import-position 13 | 14 | def find_layers(module, layers=[torch.nn.Conv2d, torch.nn.Linear], name=''): # pylint: disable=dangerous-default-value 15 | """Find linear and conv layers in a model.""" 16 | if type(module) in layers: 17 | return {name: module} 18 | res = {} 19 | for name1, child in module.named_children(): 20 | res.update(find_layers( 21 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 22 | )) 23 | return res 24 | 25 | def log_prune_statistics(parameters_to_prune): 26 | """Logs the prune statistics.""" 27 | total_num_pruned = 0 28 | total_num_params = 0 29 | for param_layer, _ in parameters_to_prune: 30 | this_layer_pruned = float(torch.sum(param_layer.weight == 0)) # pylint: disable=no-member 31 | this_num_params = float(param_layer.weight.nelement()) 32 | total_num_pruned += this_layer_pruned 33 | total_num_params += this_num_params 34 | 35 | sparsity_percent = round(100. * this_layer_pruned / this_num_params, 3) 36 | print(f"Sparsity in {param_layer}: {sparsity_percent}%") 37 | 38 | print(f"Global sparsity: {round(100.
* total_num_pruned / total_num_params, 3)}%") 39 | 40 | class PruneInitialize: 41 | def __init__(self, model) -> None: 42 | self.model = model 43 | self.initialize() 44 | 45 | def initialize(self): 46 | """For each prunable layer we mask out all the weights that are zero.""" 47 | layers = self.model.model.decoder.layers 48 | 49 | for i, layer in enumerate(layers): 50 | prunable_layers = find_layers(layer) 51 | for prunable_layer_name, prunable_layer in prunable_layers.items(): 52 | # Find percentage of weights that are zero 53 | num_zeros = torch.sum(prunable_layer.weight == 0) # pylint: disable=no-member 54 | total_num_weights = prunable_layer.weight.nelement() 55 | percentage_zeros = num_zeros / total_num_weights 56 | print(f"Percentage of zeros in {prunable_layer_name}: {percentage_zeros}") 57 | 58 | torch_prune.random_unstructured(prunable_layer, name="weight", amount=percentage_zeros) 59 | 60 | def remove_prune(self): 61 | layers = self.model.model.decoder.layers 62 | for i, layer in enumerate(layers): 63 | prunable_layers = find_layers(layer) 64 | for prunable_layer_name, prunable_layer in prunable_layers.items(): 65 | torch_prune.remove(prunable_layer, 'weight') 66 | 67 | -------------------------------------------------------------------------------- /retrain/train.py: -------------------------------------------------------------------------------- 1 | """Training the model.""" 2 | 3 | import copy 4 | from datetime import datetime # pylint: disable=syntax-error 5 | import json 6 | import os 7 | import random # pylint: disable=syntax-error 8 | import time 9 | 10 | from easydict import EasyDict as edict # pylint: disable=import-error 11 | from tqdm import tqdm 12 | import wandb # pylint: disable=syntax-error, import-error 13 | import yaml 14 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 15 | 16 | import torch # pylint: disable=wrong-import-position 17 | from torch import nn # pylint: disable=wrong-import-position 18 | import transformers as tf # pylint: disable=wrong-import-position 19 | 20 | import dataloader # pylint: disable=wrong-import-position, no-name-in-module 21 | from prune_utils import PruneInitialize # pylint: disable=wrong-import-position, wrong-import-order 22 | import params 23 | 24 | timeprint = lambda x: print(str(datetime.now()), x) # pylint: disable=C3001 25 | 26 | LOG_FREQUENCY = 50 27 | MODEL_CKPT_FREQUENCY = 2000 28 | SEED = 432 29 | 30 | torch.manual_seed(SEED) 31 | torch.backends.cudnn.enabled = False 32 | torch.backends.cudnn.benchmark = False 33 | torch.backends.cudnn.deterministic = True 34 | random.seed(SEED) 35 | 36 | scaler = torch.cuda.amp.GradScaler() 37 | 38 | 39 | def train_step(model, batch, subbatch_size, loss_fn, optimizer, device, teacher_model): 40 | """A single training step.""" 41 | model.train() 42 | correct, total_tokens, teacher_correct, total_loss = 0, 0, 0, 0 43 | 44 | input_ids = batch["input_ids"] 45 | label_ids = input_ids[:, 1:].contiguous().to(device) 46 | input_ids = input_ids[:, :-1].contiguous().to(device) 47 | attention_mask = batch["attention_mask"][:, 1:].to(device) 48 | 49 | with torch.cuda.amp.autocast(): 50 | # Split into subbatches 51 | for i in range(0, input_ids.shape[0], subbatch_size): 52 | subbatch_input_ids = input_ids[i:i + subbatch_size, :] 53 | subbatch_label_ids = label_ids[i:i + subbatch_size, :] 54 | subbatch_attention_mask = attention_mask[i:i + subbatch_size, :] 55 | 56 | logits = model(input_ids=subbatch_input_ids).logits 57 | if teacher_model is not None: 58 | with torch.no_grad(): 59 | teacher_logits = 
teacher_model( 60 | input_ids=subbatch_input_ids.to(teacher_model.device)).logits 61 | teacher_logits = teacher_logits.detach().to(logits.device) 62 | loss = loss_fn(logits, subbatch_label_ids, subbatch_attention_mask, teacher_logits) 63 | teacher_predicts = teacher_logits.argmax(-1) 64 | teacher_correct += (( 65 | teacher_predicts == subbatch_label_ids).float() * subbatch_attention_mask).sum() 66 | else: 67 | loss = loss_fn(logits, subbatch_label_ids, subbatch_attention_mask) 68 | 69 | scaler.scale(loss).backward() 70 | total_loss += loss.item() 71 | 72 | # Calculate accuracy 73 | predicts = logits.argmax(-1) 74 | correct += ((predicts == subbatch_label_ids).float() * subbatch_attention_mask).sum() 75 | total_tokens += subbatch_attention_mask.sum() 76 | 77 | scaler.step(optimizer)#.step() 78 | optimizer.zero_grad() 79 | model.zero_grad() 80 | scaler.update() 81 | 82 | accuracy = correct / total_tokens 83 | 84 | if teacher_model is not None: 85 | teacher_accuracy = teacher_correct / total_tokens 86 | return {"loss": total_loss, "accuracy": 100 * accuracy.item(), 87 | "teacher_accuracy": 100 * teacher_accuracy.item()} 88 | 89 | return {"loss": total_loss, "accuracy": 100 * accuracy.item()} 90 | 91 | 92 | def evaluate_model(model, dataiterator, subbatch_size, loss_fn, device, use_wandb): 93 | """Evaluate the model.""" 94 | model.eval() 95 | 96 | total_loss = 0 97 | total_accuracy = 0 98 | total_tokens = 0 99 | 100 | with torch.no_grad(): 101 | for batch in tqdm(dataiterator): 102 | input_ids = batch["input_ids"] 103 | label_ids = input_ids[:, 1:].contiguous().to(device) 104 | input_ids = input_ids[:, :-1].contiguous().to(device) 105 | attention_mask = batch["attention_mask"][:, 1:].to(device) 106 | 107 | # Split into subbatches 108 | for i in range(0, input_ids.shape[0], subbatch_size): 109 | subbatch_input_ids = input_ids[i:i + subbatch_size, :] 110 | subbatch_label_ids = label_ids[i:i + subbatch_size, :] 111 | subbatch_attention_mask = attention_mask[i:i + subbatch_size, :] 112 | 113 | subbatch_logits = model(input_ids=subbatch_input_ids).logits 114 | subbatch_loss = loss_fn( 115 | subbatch_logits, subbatch_label_ids, subbatch_attention_mask) 116 | 117 | # Calculate accuracy 118 | subbatch_predicts = subbatch_logits.argmax(-1) 119 | subbatch_correct = ( 120 | subbatch_predicts == subbatch_label_ids).float() * subbatch_attention_mask 121 | subbatch_accuracy = subbatch_correct.sum() / subbatch_attention_mask.sum() 122 | 123 | total_loss += subbatch_loss.item() * subbatch_attention_mask.sum() 124 | total_accuracy += subbatch_accuracy.item() * subbatch_attention_mask.sum() 125 | total_tokens += subbatch_attention_mask.sum() 126 | 127 | # Print and Log to wandb 128 | if use_wandb: 129 | wandb.log({"eval_loss": total_loss / total_tokens, 130 | "eval_accuracy": 100 * total_accuracy / total_tokens}) 131 | print("Eval Loss: ", total_loss / total_tokens) 132 | print("Eval Accuracy: ", 100 * total_accuracy / total_tokens) 133 | return 134 | 135 | 136 | def save_checkpoint(iter_idx, model, optimizer, path, device): 137 | """Save model checkpoint, while moving them to CPU. 
Returns time taken to save.""" 138 | start = time.time() 139 | optimizer_to(optimizer, 'cpu') 140 | 141 | if isinstance(model, nn.DataParallel): 142 | torch.save( 143 | { 144 | 'iter_idx': iter_idx, 145 | 'model_state_dict': model.cpu().module.state_dict(), 146 | 'optimizer_state_dict': optimizer.state_dict(), 147 | }, path) 148 | else: 149 | torch.save( 150 | { 151 | 'iter_idx': iter_idx, 152 | 'model_state_dict': model.cpu().state_dict(), 153 | 'optimizer_state_dict': optimizer.state_dict(), 154 | }, path) 155 | 156 | # Move to Device 157 | model = model.to(device) 158 | optimizer_to(optimizer, device) 159 | 160 | timeprint(f"Model saved at epoch {iter_idx}") 161 | return time.time() - start 162 | 163 | 164 | def load_checkpoint(path): 165 | """Load the model checkpoint with highest iterations cnt.""" 166 | if not os.path.isfile(path): 167 | possible_ckpt_paths = [] 168 | for single_filename in os.listdir(path): 169 | if single_filename.startswith('ckpt_'): 170 | numerical_path = single_filename.lstrip('ckpt_') 171 | if all(char in '0123456789' 172 | for char in numerical_path.split('.')[0]): 173 | possible_ckpt_paths.append(single_filename) 174 | print(os.listdir(path)) 175 | if not possible_ckpt_paths: 176 | raise ValueError(f"No saved checkpoint at {path}") 177 | 178 | latest_ckpt_file = sorted( 179 | possible_ckpt_paths, 180 | key=lambda x: int(x.lstrip('ckpt_').split('.')[0]))[-1] 181 | path = os.path.join(path, latest_ckpt_file) 182 | print(path) 183 | checkpoint = torch.load(path) 184 | if ['iter_idx'] in checkpoint: 185 | return (checkpoint['iter_idx'], 186 | checkpoint['model_state_dict'], 187 | checkpoint['optimizer_state_dict']), path 188 | return (0, checkpoint['model_state_dict'], None), path 189 | 190 | 191 | def optimizer_to(optim, device): 192 | """Move Optimizer to CPU/GPU.""" 193 | for param in optim.state.values(): 194 | # Not sure there are any global tensors in the state dict 195 | if isinstance(param, torch.Tensor): 196 | param.data = param.data.to(device) 197 | if param._grad is not None: # pylint: disable=W0212 198 | param._grad.data = param._grad.data.to(device) # pylint: disable=W0212 199 | elif isinstance(param, dict): 200 | for subparam in param.values(): 201 | if isinstance(subparam, torch.Tensor): 202 | subparam.data = subparam.data.to(device) 203 | if subparam._grad is not None: # pylint: disable=W0212 204 | subparam._grad.data = subparam._grad.data.to(device) # pylint: disable=W0212 205 | 206 | 207 | def log_metrics(logs, times, use_wandb, iter_idx, generated_samples): 208 | """Log the metrics and optionally to wandb.""" 209 | print() 210 | timeprint(f"Step {iter_idx}") 211 | 212 | for log_name, log_value in logs.items(): 213 | print(f"{log_name}: {log_value:.4f}", end=" | ") 214 | 215 | total_time = time.time() - times['Start'] 216 | print(f"Time ={round(total_time, 2)}sec", end=" ") 217 | print(f"(Train {round(100 * times['Train_step']/total_time, 1)}", end="% ") 218 | print(f"Save {round(100 * times['Save']/total_time, 1)}%)") 219 | # _ = [print(x) for x in generated_samples] 220 | 221 | if use_wandb: 222 | wandb_dict = {} 223 | for log_name, log_value in logs.items(): 224 | wandb_dict[log_name] = log_value 225 | 226 | wandb_dict["Train Time"] = 100 * times['Train_step'] / total_time 227 | wandb_dict["Save Time"] = 100 * times['Save'] / total_time 228 | wandb_dict["Examples"] = generated_samples 229 | wandb.log(wandb_dict, step=iter_idx) 230 | 231 | for key in logs: 232 | logs[key] = 0 233 | 234 | # print("Generated Samples:", generated_samples) 235 | 
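    # --- Hedged note (illustrative comment, not part of the original script) ---
    # Checkpoints are written by main() below as "ckpt_<iter_idx>.pt" inside the run
    # directory, and load_checkpoint() picks the file with the largest index when it
    # is given a directory rather than a file. optimizer_to() exists because moving
    # the model afterwards with model.to(device) does not relocate optimizer state
    # that was created or loaded while the weights were on the CPU (e.g. Adam's
    # exp_avg / exp_avg_sq buffers); main() therefore moves the model first and then
    # calls, for example:
    #
    #     model = model.to(device)
    #     optimizer_to(optimizer, 'cuda')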
236 | 237 | def main(args): 238 | """Initializes Dataloader, Model, Optimizer and then trains the model.""" 239 | # Prepeare save folder and store model config 240 | if not os.path.exists(args.save_dir): 241 | os.mkdir(args.save_dir) 242 | args.save_dir = os.path.join(args.save_dir, str(len(os.listdir(args.save_dir)))) 243 | if not os.path.exists(args.save_dir): 244 | os.mkdir(args.save_dir) 245 | json.dump(vars(args), open(os.path.join(args.save_dir, "params.json"), "w+")) # pylint: disable=unspecified-encoding 246 | timeprint(f"Save directory is {args.save_dir}") 247 | 248 | # Set Model, loss_fn and optimizer. 249 | model = tf.AutoModelForCausalLM.from_pretrained(args.model_card) 250 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 251 | 252 | def loss_fn_eval(logits, label_ids, pad_mask): 253 | losses = ce_loss(logits.permute(0, 2, 1), label_ids) 254 | return (losses * pad_mask).sum() / pad_mask.sum() 255 | 256 | timeprint("Model, Loss Fn and Optimizer classes set.") 257 | 258 | # Load saved model, optimizer if any. 259 | iter_idx, prev_iter_idx = 0, 0 260 | if args.load_dir.strip() != "": 261 | (prev_iter_idx, model_state_dict, optimizer_state_dict) = load_checkpoint( 262 | args.load_dir)[0] 263 | model.load_state_dict(model_state_dict) 264 | if optimizer_state_dict is not None and optimizer_state_dict: 265 | optimizer.load_state_dict(optimizer_state_dict) 266 | iter_idx = prev_iter_idx 267 | print("iter_idx", iter_idx) 268 | 269 | # Move to GPU 270 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # pylint: disable=no-member 271 | model = model.to(device) 272 | if device.type != "cpu": 273 | model = torch.nn.DataParallel(model) 274 | if args.wandb: 275 | wandb.watch(model) 276 | optimizer_to(optimizer, 'cuda') 277 | 278 | # If Knowledge Distillation is used, load the teacher model. 279 | if args.do_knowledge_distillation: 280 | # TODO: Try loss matching at each layer. 
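        # Descriptive note on the distillation objective defined further below:
        #     loss = CE(student_logits, labels)
        #            + kd_weight * KL( softmax(teacher_logits / T) || softmax(student_logits / T) )
        # with T = args.kd_temperature; both terms are multiplied by pad_mask and
        # normalised by the number of non-padding tokens. The teacher is a second
        # AutoModelForCausalLM built from args.model_card (optionally initialised from
        # a checkpoint at args.kd_teacher_model_path), or a deep copy of the current
        # model when no path is given, and it always runs in eval mode under
        # torch.no_grad().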
281 | if args.kd_teacher_model_path is not None: 282 | teacher_model = tf.AutoModelForCausalLM.from_pretrained(args.model_card) 283 | if args.kd_teacher_model_path.strip() != "": 284 | teacher_model.load_state_dict(load_checkpoint(args.kd_teacher_model_path)[0][3]) 285 | if torch.cuda.device_count() > 1: 286 | teacher_device = f"cuda:{torch.cuda.device_count() - 1}" 287 | else: 288 | teacher_device = "cpu" 289 | teacher_model = teacher_model.to(teacher_device) 290 | else: 291 | teacher_model = copy.deepcopy(model) 292 | teacher_model.eval() 293 | ce_loss = nn.CrossEntropyLoss(reduction='none') 294 | kd_loss = nn.KLDivLoss(reduction='none') 295 | def loss_fn(logits, label_ids, pad_mask, teacher_logits): 296 | # Compute losses againts target labels 297 | losses = ce_loss( 298 | logits.permute(0, 2, 1), label_ids) 299 | losses = (losses * pad_mask).sum() / pad_mask.sum() 300 | 301 | kd_losses = kd_loss( 302 | nn.functional.log_softmax(logits / args.kd_temperature, dim=-1), 303 | nn.functional.softmax(teacher_logits / args.kd_temperature, dim=-1)) 304 | kd_losses = (kd_losses * pad_mask.unsqueeze(-1)).sum() / pad_mask.sum() 305 | 306 | return losses + args.kd_weight * kd_losses 307 | logs = {key: 0 for key in ["Loss", "Accuracy", "Teacher_Accuracy"]} 308 | else: 309 | teacher_model = None 310 | loss_fn = loss_fn_eval 311 | ce_loss = nn.CrossEntropyLoss(reduction='none') 312 | logs = {key: 0 for key in ["Loss", "Accuracy"]} 313 | 314 | # Create the pruner model class 315 | if args.do_prune: 316 | iterative_pruner = prune_utils.Pruner(args.prune_recipe, iter_steps=iter_idx) 317 | # Initialize the model with initial pruning. 318 | iterative_pruner.init_prune(model=model) 319 | else: 320 | iterative_pruner = prune_utils.DummyPruner() 321 | 322 | # Initialize log dicts for training loop. 323 | times = {'Train_step': 0.0, 'Save': 0.0, 'Start': time.time()} 324 | 325 | def update_logs(train_log): 326 | for key in logs: 327 | logs[key] += train_log[key.lower()] 328 | 329 | print("Starting training") 330 | # Create the dataloaders. 331 | train_loader = dataloader.TextFileDataset( 332 | args.train_data_path, args.model_card, args.max_len, eval_mode=False, dummy_mode=args.dummy) 333 | # eval_loader = dataloader.TextFileDataset( 334 | # args.eval_data_path, args.model_card, args.max_len, eval_mode=True, dummy_mode=args.dummy) 335 | if parser_args.dummy: 336 | train_loader.data = train_loader.data[::len(train_loader.data)//500] 337 | 338 | timeprint("Data is loaded") 339 | 340 | train_dataiterator = torch.utils.data.DataLoader( 341 | dataset=train_loader, batch_size=args.batch_size, num_workers=1, shuffle=False, 342 | collate_fn=dataloader.collate) 343 | 344 | # # Run evaluation on the model before training. 
345 | # eval_dataloader = torch.utils.data.DataLoader( 346 | # dataset=eval_loader, batch_size=args.batch_size, num_workers=1, 347 | # collate_fn=dataloader.collate) 348 | # evaluate_model(model, eval_dataloader, 2*args.subbatch_size, loss_fn_eval, device, args.wandb) 349 | 350 | for batch in tqdm(train_dataiterator): 351 | if iter_idx and iter_idx % LOG_FREQUENCY == 0: 352 | model.eval() 353 | gen_fn = model.generate if device.type == "cpu" else model.module.generate 354 | batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v 355 | for k, v in batch.items()} 356 | predicts = gen_fn( 357 | input_ids=batch['input_ids'][:1, :-1]) 358 | 359 | generated_samples = [x.tolist()[0] for x in [predicts, batch['input_ids']]] 360 | generated_samples = train_loader.tokenizer.batch_decode( 361 | generated_samples, skip_special_tokens=True) 362 | 363 | start = time.time() 364 | train_logs = train_step( 365 | model, batch, args.subbatch_size, loss_fn, optimizer, device, teacher_model) 366 | times["Train_step"] += (time.time() - start) 367 | update_logs(train_logs) 368 | 369 | if iter_idx and iter_idx != prev_iter_idx and iter_idx % LOG_FREQUENCY == 0: 370 | # Normalize logs. 371 | logs = {log_key: log_val/LOG_FREQUENCY for log_key, log_val in logs.items()} 372 | # Log to wandb. 373 | if iter_idx - prev_iter_idx > LOG_FREQUENCY + 1: 374 | log_metrics(logs, times, args.wandb, iter_idx, generated_samples) 375 | logs = {log_key: 0.0 for log_key, _ in logs.items()} 376 | 377 | torch.cuda.empty_cache() 378 | 379 | save_path = os.path.join( 380 | args.save_dir, f"ckpt_{iter_idx}.pt") 381 | iterative_pruner.remove_prune(model=model) 382 | times["Save"] += save_checkpoint( 383 | iter_idx, model, optimizer, save_path, device) 384 | iterative_pruner.init_prune(model=model) 385 | 386 | iter_idx += 1 387 | 388 | 389 | 390 | if __name__ == '__main__': 391 | # Read parameters from command line and then load config. 392 | parser_args = params.parse_arguments() 393 | 394 | # Setup wandb. 395 | if parser_args.wandb: 396 | wandb.init(project="sparsify", name=parser_args.model_card, config=parser_args) 397 | 398 | # Prepare path to datasets. 
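    # Example invocation (hypothetical flag names; the actual flags are defined in
    # params.parse_arguments(), which is not shown here):
    #   python retrain/train.py --model_card facebook/opt-125m --train_data_path train.txt \
    #       --save_dir checkpoints --batch_size 8 --subbatch_size 2 --lr 1e-5 --max_len 1024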
399 | main(parser_args) 400 | -------------------------------------------------------------------------------- /rwkv.py: -------------------------------------------------------------------------------- 1 | """Compress RWKV models.""" 2 | import random 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from datasets import load_dataset 8 | from transformers import AutoTokenizer 9 | from transformers import OPTForCausalLM 10 | import torch 11 | from rwkvstic.load import RWKV 12 | 13 | import smart_compressors 14 | import quant 15 | 16 | DEVICE = torch.device('cpu') 17 | 18 | if torch.cuda.is_available(): 19 | DEVICE = torch.device('cuda:0') # pylint: disable=no-member 20 | 21 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): # pylint: disable=dangerous-default-value 22 | """Find linear and conv layers in a model.""" 23 | # RWKV note: the compressible weights are lists under model.model.key, .value, .receptance, 24 | # .outputvv, .key_ffn, .receptance_ffn and .value_ffn; they still need to be converted into nn modules and .eval()-ed. 25 | if type(module) in layers: 26 | return {name: module} 27 | res = {} 28 | for name1, child in module.named_children(): 29 | res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) 30 | return res 31 | 32 | def get_wikitext2(nsamples, seed, seqlen, model_card): 33 | """For now we take nsamples datapoints from wikitext2 and tokenize them.""" 34 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 35 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 36 | 37 | tokenizer = AutoTokenizer.from_pretrained(model_card, use_fast=False) 38 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 39 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 40 | 41 | random.seed(seed) 42 | trainloader = [] 43 | for _ in range(nsamples): 44 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 45 | j = i + seqlen 46 | inp = trainenc.input_ids[:, i:j] 47 | tar = inp.clone() 48 | tar[:, :-1] = -100 49 | trainloader.append((inp, tar)) 50 | return trainloader, testenc 51 | 52 | def benchmark(model_to_be_benched, _dataloader): 53 | """Benchmark a model.""" 54 | current_device = model_to_be_benched.device 55 | model_to_be_benched = model_to_be_benched.to(DEVICE) 56 | data_iterator = iter(_dataloader) 57 | loss_fn = nn.CrossEntropyLoss() 58 | with torch.no_grad(): 59 | loss = 0.0 60 | for i in range(100): 61 | inputs = next(data_iterator)[0].to(DEVICE) 62 | outputs = model_to_be_benched(inputs[:, :-1]) 63 | loss += loss_fn(outputs.logits.permute(0, 2, 1), inputs[:, 1:]).item() 64 | if i % 10 == 5: 65 | print(i) 66 | model_to_be_benched = model_to_be_benched.to(current_device) 67 | return loss 68 | 69 | class wrapIntoTorchNNModule(nn.Module): # pylint: disable=missing-class-docstring 70 | """Wrap a model into a torch.nn.Module. In forward we call the model.""" 71 | def __init__(self, model): 72 | super().__init__() 73 | self.model = model 74 | def forward(self, inp, **kwargs): 75 | return self.model(inp, **kwargs) 76 | 77 | def get_rwkv(path): 78 | """Get RWKV model.""" 79 | def skip(*args, **kwargs): # pylint: disable=unused-argument, redefined-outer-name 80 | pass 81 | torch.nn.init.kaiming_uniform_ = skip 82 | torch.nn.init.uniform_ = skip 83 | torch.nn.init.normal_ = skip 84 | model_loaded = RWKV(path) 85 | model_loaded.seqlen = 2048 # We need this for the dataloader trimming.
86 | 87 | # If device is CPU then we convert from fp16 to fp32 88 | if DEVICE.type == 'cpu': 89 | model_loaded = model_loaded.half().to(torch.float32) 90 | return model_loaded 91 | 92 | @torch.no_grad() 93 | def opt_sequential(model, dataloader, device, compressor_class): # pylint: disable=redefined-outer-name 94 | """Optimize model sequentially.""" 95 | print('Starting ...') 96 | 97 | use_cache = model.config.use_cache 98 | model.config.use_cache = False 99 | layers = model.model.decoder.layers 100 | 101 | # Transfer to device 102 | model = model.to(device) 103 | 104 | # Initialize inputs, cache 105 | dtype = next(iter(model.parameters())).dtype 106 | inps = torch.zeros( 107 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device 108 | ) 109 | cache = {'i': 0, 'attention_mask': None} 110 | 111 | # Get input and attention mask after layer 0 112 | class Catcher(nn.Module): # pylint: disable=missing-class-docstring 113 | def __init__(self, module): 114 | super().__init__() 115 | self.module = module 116 | def forward(self, inp, **kwargs): 117 | """Forward pass.""" 118 | inps[cache['i']] = inp 119 | cache['i'] += 1 120 | cache['attention_mask'] = kwargs['attention_mask'] 121 | raise ValueError 122 | layers[0] = Catcher(layers[0]) 123 | for batch in dataloader: 124 | try: 125 | model(batch[0].to(device)) 126 | except ValueError: 127 | pass 128 | layers[0] = layers[0].module 129 | 130 | # Transfer back to CPU 131 | model = model.cpu() 132 | layers[0] = layers[0].cpu() 133 | torch.cuda.empty_cache() 134 | 135 | outs = torch.zeros_like(inps) # Store outputs after each layer # pylint: disable=no-member 136 | attention_mask = cache['attention_mask'] 137 | print('Ready.') 138 | 139 | all_compressors = {} # pylint: disable=redefined-outer-name 140 | for i in range(len(layers)): # pylint: disable=consider-using-enumerate 141 | layer = layers[i].to(device) 142 | 143 | # Find linear layers and initialize quantizer for it 144 | subset = find_layers(layer) 145 | single_layer_compressor = {} 146 | for name in subset: # pylint: disable=consider-using-dict-items 147 | single_layer_compressor[name] = compressor_class(subset[name], args.amount_prune) 148 | single_layer_compressor[name].quantizer = quant.Quantizer() 149 | single_layer_compressor[name].quantizer.configure( 150 | args.wbits, perchannel=True, sym=False, mse=False 151 | ) 152 | 153 | def add_batch(name): 154 | def tmp(_, inp, out): 155 | single_layer_compressor[name].add_batch(inp[0].data, out.data) 156 | return tmp 157 | handles = [] 158 | for name in subset: # pylint: disable=consider-using-dict-items 159 | handles.append(subset[name].register_forward_hook(add_batch(name))) 160 | for j in range(args.nsamples): 161 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 162 | for hhh in handles: 163 | hhh.remove() 164 | 165 | for name in subset: 166 | print(i, name) 167 | print('Quantizing ...') 168 | single_layer_compressor[name].fasterquant( 169 | percdamp=args.percdamp, groupsize=args.groupsize) 170 | # all_compressors[ 171 | # 'model.decoder.layers.%d.%s' % (i, name)] = single_layer_compressor[name] # pylint: disable=consider-using-f-string 172 | single_layer_compressor[name].free() 173 | for j in range(args.nsamples): 174 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 175 | 176 | layers[i] = layer.cpu() 177 | del layer 178 | del single_layer_compressor 179 | torch.cuda.empty_cache() 180 | 181 | inps, outs = outs, inps 182 | 183 | model.config.use_cache = use_cache 184 
| 185 | return all_compressors 186 | 187 | if __name__ == '__main__': 188 | import argparse 189 | 190 | # Parse the arguments 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.') 193 | parser.add_argument('--seed', type=int, default=0, 194 | help='Seed for sampling the calibration data.') 195 | parser.add_argument('--nsamples', type=int, default=128, 196 | help='Number of calibration data samples.') 197 | parser.add_argument('--percdamp', type=float, default=.01, 198 | help='Percent of the average Hessian diagonal to use for dampening.') 199 | parser.add_argument('--wbits', type=int, default=16, choices=[4, 16], 200 | help='#bits to use for quantization; use 16 for evaluating base model.') 201 | parser.add_argument('--groupsize', type=int, default=-1, 202 | help='Groupsize to use for quantization/pruning; default uses full row.') 203 | parser.add_argument('--save', type=str, default='', 204 | help='Save quantized/pruned checkpoint under this name.') 205 | parser.add_argument('--load', type=str, default='', 206 | help='Load quantized/pruned checkpoint under this name.') 207 | parser.add_argument('--compression_type', type=str, required=True, 208 | choices=['quantizeonly', 'prunemaskonly', 'prunemaskreconstruction', 209 | 'prunemagnitudemask', 'quantizeprune', 'none',# 'pruneonly' 210 | ], 211 | help='Type of compression to perform.') 212 | parser.add_argument('--amount_prune', type=float, default=0.5, 213 | help='Amount of pruning to perform.') 214 | args = parser.parse_args() 215 | 216 | # If prune is to be done then args.amount_prune must be between 0 and 1 217 | if args.compression_type in ['pruneonly', 'quantizeprune', 'prunemaskonly', 218 | 'prunemaskreconstruction']: 219 | assert 0 <= args.amount_prune <= 1, 'Amount of pruning must be between 0 and 1' 220 | 221 | # Load model 222 | model = get_opt(args.model) 223 | model.eval() 224 | if args.load: 225 | model.load_state_dict(torch.load(args.load)) 226 | 227 | # Load data 228 | dataloader, testloader = get_wikitext2( 229 | nsamples=args.nsamples, seed=args.seed, seqlen=model.seqlen, model_card=args.model) 230 | 231 | # Perform compression 232 | if args.wbits != 16: 233 | compression_class = None # pylint: disable=invalid-name 234 | if args.compression_type == 'quantizeonly': 235 | compression_class = smart_compressors.QuantizeOnly 236 | elif args.compression_type == 'prunemaskonly': 237 | compression_class = smart_compressors.PruneMaskOnly 238 | elif args.compression_type == 'prunemaskreconstruction': 239 | compression_class = smart_compressors.PruneMaskReconstruction 240 | elif args.compression_type == 'prunemagnitudemask': 241 | compression_class = smart_compressors.PruneMagnitudeMask 242 | elif args.compression_type == 'none': 243 | pass 244 | elif args.compression_type == 'quantizeprune': 245 | raise NotImplementedError 246 | else: 247 | raise ValueError('Unknown compression type: %s' % args.compression_type) 248 | 249 | if compression_class is not None: 250 | tick = time.time() 251 | computed_compressors = opt_sequential(model, dataloader, DEVICE, compression_class) 252 | print("Total time taken: %.2f seconds" % (time.time() - tick)) # pylint: disable=consider-using-f-string 253 | 254 | # Do benchmark 255 | if args.compression_type in ["quantizeonly", "prunemaskonly", "prunemaskreconstruction"]: 256 | model = model.to(DEVICE) 257 | print(benchmark(model, dataloader)) 258 | # elif args.compression_type == "pruneonly": 259 | # layer_names = 
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", 260 | # "self_attn.out_proj", "fc1", "fc2"] 261 | # all_compressor_keys = ['model.decoder.layers.%d.%s' % (i, name) # pylint: disable=consider-using-f-string 262 | # for i in range(len(model.model.decoder.layers)) 263 | # for name in layer_names] 264 | # model = model.to(DEVICE) 265 | # # First benchmark with no pruning 266 | # print("Benchmarking with no pruning...") 267 | # print(benchmark(model, dataloader)) 268 | # print("\n\n") 269 | # # Now prune with only mask out and no reconstruction 270 | # for key in all_compressor_keys: 271 | # computed_compressors[key].layer.weight.data *= computed_compressors[ 272 | # key].new_weight_with_mask[1] 273 | # print("Benchmarking with only mask out...") 274 | # print(benchmark(model, dataloader)) 275 | # print("\n\n") 276 | # # Now prune with masking and reconstruction 277 | # for key in all_compressor_keys: 278 | # print(key, torch.sum(computed_compressors[key].new_weight_with_mask[1] == 0).item() / 279 | # computed_compressors[key].new_weight_with_mask[1].numel()) 280 | # # # print percentage of new_weight_with_mask[1] that is 0 281 | # # print(torch.sum(computed_compressors[key].new_weight_with_mask[1] == 0).item() / 282 | # # computed_compressors[key].new_weight_with_mask[1].numel()) 283 | # computed_compressors[key].layer.weight.data = ( 284 | # computed_compressors[key].new_weight_with_mask[0] * computed_compressors[ 285 | # key].new_weight_with_mask[1]).half() 286 | # print("Benchmarking with mask out and reconstruction...") 287 | # print(benchmark(model, dataloader)) 288 | # print("\n\n") 289 | elif args.compression_type == "quantizeprune": 290 | raise NotImplementedError 291 | else: 292 | model = model.to(DEVICE) 293 | print(benchmark(model, dataloader)) 294 | 295 | # Save 296 | if args.save: 297 | torch.save(model.state_dict(), args.save) 298 | print("Done") 299 | print("\n" * 5) 300 | -------------------------------------------------------------------------------- /smart_compressors.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.utils.prune as prune 7 | import transformers 8 | 9 | import quant 10 | 11 | DEBUG = False 12 | 13 | torch.backends.cuda.matmul.allow_tf32 = False 14 | torch.backends.cudnn.allow_tf32 = False 15 | 16 | DEVICE = torch.device('cpu') 17 | if torch.cuda.is_available(): 18 | DEVICE = torch.device('cuda:0') 19 | 20 | class QuantizeOnly: 21 | """Quantize only, no pruning.""" 22 | def __init__(self, layer, amount_prune=0.0): 23 | self.layer = layer 24 | self.dev = self.layer.weight.device 25 | W = layer.weight.data.clone() 26 | if isinstance(self.layer, nn.Conv2d): 27 | W = W.flatten(1) 28 | if isinstance(self.layer, transformers.Conv1D): 29 | W = W.t() 30 | self.rows = W.shape[0] 31 | self.columns = W.shape[1] 32 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 33 | self.nsamples = 0 34 | 35 | def add_batch(self, inp, out): 36 | if DEBUG: 37 | self.inp1 = inp 38 | self.out1 = out 39 | if len(inp.shape) == 2: 40 | inp = inp.unsqueeze(0) 41 | tmp = inp.shape[0] 42 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 43 | if len(inp.shape) == 3: 44 | inp = inp.reshape((-1, inp.shape[-1])) 45 | inp = inp.t() 46 | if isinstance(self.layer, nn.Conv2d): 47 | unfold = nn.Unfold( 48 | self.layer.kernel_size, 49 | dilation=self.layer.dilation, 50 | padding=self.layer.padding, 51 | 
stride=self.layer.stride 52 | ) 53 | inp = unfold(inp) 54 | inp = inp.permute([1, 0, 2]) 55 | inp = inp.flatten(1) 56 | self.H *= self.nsamples / (self.nsamples + tmp) 57 | self.nsamples += tmp 58 | # inp = inp.float() 59 | inp = math.sqrt(2 / self.nsamples) * inp.float() 60 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 61 | self.H += inp.matmul(inp.t()) 62 | 63 | def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1): 64 | W = self.layer.weight.data.clone() 65 | if isinstance(self.layer, nn.Conv2d): 66 | W = W.flatten(1) 67 | if isinstance(self.layer, transformers.Conv1D): 68 | W = W.t() 69 | W = W.float() 70 | 71 | tick = time.time() 72 | 73 | if not self.quantizer.ready(): 74 | self.quantizer.find_params(W, weight=True) 75 | 76 | H = self.H 77 | del self.H 78 | dead = torch.diag(H) == 0 79 | H[dead, dead] = 1 80 | W[:, dead] = 0 81 | 82 | Losses = torch.zeros_like(W) 83 | Q = torch.zeros_like(W) 84 | 85 | damp = percdamp * torch.mean(torch.diag(H)) 86 | diag = torch.arange(self.columns, device=self.dev) 87 | H[diag, diag] += damp 88 | H = torch.linalg.cholesky(H) 89 | H = torch.cholesky_inverse(H) 90 | H = torch.linalg.cholesky(H, upper=True) 91 | Hinv = H 92 | 93 | for i1 in range(0, self.columns, blocksize): 94 | i2 = min(i1 + blocksize, self.columns) 95 | count = i2 - i1 96 | 97 | W1 = W[:, i1:i2].clone() 98 | Q1 = torch.zeros_like(W1) 99 | Err1 = torch.zeros_like(W1) 100 | Losses1 = torch.zeros_like(W1) 101 | Hinv1 = Hinv[i1:i2, i1:i2] 102 | 103 | for i in range(count): 104 | w = W1[:, i] 105 | d = Hinv1[i, i] 106 | 107 | if groupsize != -1: 108 | if (i1 + i) % groupsize == 0: 109 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 110 | 111 | q = quant.quantize( 112 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 113 | ).flatten() 114 | Q1[:, i] = q 115 | Losses1[:, i] = (w - q) ** 2 / d ** 2 116 | 117 | err1 = (w - q) / d 118 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 119 | Err1[:, i] = err1 120 | 121 | Q[:, i1:i2] = Q1 122 | Losses[:, i1:i2] = Losses1 / 2 123 | 124 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 125 | 126 | if DEBUG: 127 | self.layer.weight.data[:, :i2] = Q[:, :i2] 128 | self.layer.weight.data[:, i2:] = W[:, i2:] 129 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 130 | print(torch.sum(Losses)) 131 | 132 | if DEVICE.type != 'cpu': 133 | torch.cuda.synchronize() 134 | print('time %.2f' % (time.time() - tick)) 135 | print('error', torch.sum(Losses).item()) 136 | 137 | if isinstance(self.layer, transformers.Conv1D): 138 | Q = Q.t() 139 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 140 | if DEBUG: 141 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 142 | 143 | def free(self): 144 | if DEBUG: 145 | self.inp1 = None 146 | self.out1 = None 147 | self.H = None 148 | self.Losses = None 149 | self.Trace = None 150 | torch.cuda.empty_cache() 151 | 152 | class PruneMaskOnly(QuantizeOnly): 153 | """Prunes `amount_prune` of the weights based on second order info without reconstruction.""" 154 | def __init__(self, layer, amount_prune): 155 | super().__init__(layer) 156 | self.amount_prune = amount_prune 157 | self.new_weight_with_mask = None 158 | 159 | def fasterquant(self, blocksize=256, percdamp=.01, groupsize=-1): 160 | W = self.layer.weight.data.clone() 161 | if isinstance(self.layer, nn.Conv2d): 162 | W = W.flatten(1) 163 | if isinstance(self.layer, transformers.Conv1D): 164 | W = W.t() 165 | W = 
W.float() # Shape: Out X In 166 | 167 | tick = time.time() 168 | 169 | H = self.H 170 | del self.H 171 | dead = torch.diag(H) == 0 172 | H[dead, dead] = 1 173 | W[:, dead] = 0 174 | 175 | M = torch.zeros_like(W) + 1 176 | E = (torch.zeros_like(W))[:, :blocksize] 177 | 178 | damp = percdamp * torch.mean(torch.diag(H)) 179 | diag = torch.arange(self.columns, device=self.dev) 180 | H[diag, diag] += damp 181 | H = torch.linalg.cholesky(H) 182 | H = torch.cholesky_inverse(H) 183 | H = torch.linalg.cholesky(H, upper=True) 184 | Hinv = H 185 | 186 | print("Starting pruning: ", end="") 187 | 188 | for i in range(0, self.columns, blocksize): 189 | i2 = min(i + blocksize, self.columns) 190 | count = i2 - i 191 | assert count == blocksize, ( 192 | count, blocksize, i, i2, self.layer.weight.data.shape) 193 | 194 | for ji in range(count): 195 | j = ji + i 196 | 197 | if ji == 0: 198 | # Determine the weights for pruning mask selection for the next column 199 | copy_linear = torch.nn.Linear(blocksize, W.shape[0]).to(W.device) 200 | copy_linear.weight.data = W[:, j:j+blocksize] ** 2 201 | copy_linear.weight.data /= torch.diag(Hinv).unsqueeze(0)[:, j:j+blocksize] 202 | 203 | prune.l1_unstructured(copy_linear, name='weight', amount=self.amount_prune) 204 | print(f"{j}:{j+blocksize}", end = ", ") 205 | M[:, j:j+blocksize] = copy_linear.weight_mask 206 | E[:, j-i] = (1 - M[:, j]) * (W[:, j] / Hinv[j, j]) 207 | W[:, j:i+blocksize] -= E[:, j-i].unsqueeze(1) * Hinv[j, j:i+blocksize].unsqueeze(0) 208 | 209 | W[:, i+blocksize:] -= E.matmul(Hinv[i:i+blocksize, i+blocksize:]) # Keep as is 210 | 211 | if DEVICE.type != "cpu": 212 | torch.cuda.synchronize() 213 | print('\nPrune time %.2f' % (time.time() - tick)) 214 | self.new_weight_with_mask = (W, M) 215 | self.layer.weight.data = (self.layer.weight.data * M).to(self.layer.weight.data.dtype) 216 | 217 | class PruneMaskReconstruction(PruneMaskOnly): 218 | """Prunes `amount_prune` of the weights based on second order info with reconstruction.""" 219 | def __init__(self, layer, amount_prune): 220 | super().__init__(layer, amount_prune) 221 | 222 | def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1): 223 | super().fasterquant(blocksize, percdamp, groupsize) 224 | print(self.new_weight_with_mask[0]) 225 | (W, M) = self.new_weight_with_mask 226 | self.layer.weight.data = (W * M.reshape(self.layer.weight.shape)).to( 227 | self.layer.weight.data.dtype) 228 | 229 | class PruneMagnitudeMask(PruneMaskOnly): 230 | """Prunes `amount_prune` of the weights with lower magnitude in given layer.""" 231 | def __init__(self, layer, amount_prune): 232 | super().__init__(layer, amount_prune) 233 | 234 | def add_batch(self, inp, out): 235 | """We don't need hessian for this.""" 236 | return 237 | 238 | def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1): 239 | W = self.layer.weight.data.clone() 240 | copy_linear = torch.nn.Linear(W.shape[1], W.shape[0]).to(W.device) 241 | copy_linear.weight.data = W.clone() 242 | prune.l1_unstructured(copy_linear, name='weight', amount=self.amount_prune) 243 | 244 | self.new_weight_with_mask = (copy_linear.weight.data, copy_linear.weight_mask) 245 | self.layer.weight.data = (copy_linear.weight.data * copy_linear.weight_mask).to( 246 | self.layer.weight.data.dtype) 247 | --------------------------------------------------------------------------------
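Appendix: a hedged usage sketch (not a file from this repository) of how the compressor classes above are driven on a single layer, mirroring the per-layer loop in rwkv.py/opt.py. Calibration activations flow into add_batch() through a forward hook so the Hessian proxy H ≈ (2/n) Σ x xᵀ is accumulated, and fasterquant() then rewrites the layer weights in place. The layer sizes and random calibration data below are purely illustrative; quant and smart_compressors are the repository's own modules.

import torch
import torch.nn as nn

import quant
import smart_compressors

# A toy layer standing in for one projection of a transformer block.
layer = nn.Linear(64, 32)

compressor = smart_compressors.QuantizeOnly(layer)
compressor.quantizer = quant.Quantizer()
compressor.quantizer.configure(4, perchannel=True, sym=False, mse=False)

# Collect calibration statistics the same way opt_sequential does: a forward hook
# hands every (input, output) pair of the layer to the compressor.
handle = layer.register_forward_hook(
    lambda _module, inp, out: compressor.add_batch(inp[0].data, out.data))
with torch.no_grad():
    for _ in range(8):
        layer(torch.randn(2, 16, 64))  # (batch, sequence, features) calibration batches
handle.remove()

# Quantize the weights in place; the prune variants take an amount_prune argument
# in their constructor but are driven through the same add_batch/fasterquant calls.
compressor.fasterquant(percdamp=0.01, groupsize=-1)
compressor.free()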