├── .gitignore ├── datautils.py ├── README.md ├── bertsquad ├── trainer_qa.py ├── utils_qa.py └── __init__.py ├── quant.py ├── main_trueobs.py ├── modelutils.py ├── spdy.py ├── postproc.py ├── database.py └── trueobs.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | bertsquad/checkpoint* 3 | bertsquad/tmp 4 | bertsquad/__pycache__ 5 | -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('yolov5') 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch.utils.data import Dataset, DataLoader, Subset 9 | import torchvision.datasets as datasets 10 | import torchvision.transforms as transforms 11 | 12 | 13 | def set_seed(seed): 14 | np.random.seed(seed) 15 | torch.random.manual_seed(seed) 16 | 17 | def random_subset(data, nsamples, seed): 18 | set_seed(seed) 19 | idx = np.arange(len(data)) 20 | np.random.shuffle(idx) 21 | return Subset(data, idx[:nsamples]) 22 | 23 | 24 | _IMAGENET_RGB_MEANS = (0.485, 0.456, 0.406) 25 | _IMAGENET_RGB_STDS = (0.229, 0.224, 0.225) 26 | 27 | def get_imagenet(path, noaug=False): 28 | img_size = 224 # standard 29 | train_transform = transforms.Compose([ 30 | transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | transforms.Normalize(mean=_IMAGENET_RGB_MEANS, std=_IMAGENET_RGB_STDS), 34 | ]) 35 | non_rand_resize_scale = 256.0 / 224.0 # standard 36 | test_transform = transforms.Compose([ 37 | transforms.Resize(round(non_rand_resize_scale * img_size)), 38 | transforms.CenterCrop(img_size), 39 | transforms.ToTensor(), 40 | transforms.Normalize(mean=_IMAGENET_RGB_MEANS, std=_IMAGENET_RGB_STDS), 41 | ]) 42 | 43 | train_dir = os.path.join(os.path.expanduser(path), 'train') 44 | test_dir = os.path.join(os.path.expanduser(path), 'val') 45 | 46 | if noaug: 47 | train_dataset = datasets.ImageFolder(train_dir, test_transform) 48 | else: 49 | train_dataset = datasets.ImageFolder(train_dir, train_transform) 50 | test_dataset = datasets.ImageFolder(test_dir, test_transform) 51 | 52 | return train_dataset, test_dataset 53 | 54 | class YOLOv5Wrapper(Dataset): 55 | def __init__(self, original): 56 | self.original = original 57 | def __len__(self): 58 | return len(self.original) 59 | def __getitem__(self, idx): 60 | tmp = list(self.original[idx]) 61 | tmp[0] = tmp[0].float() / 255 62 | return tmp 63 | 64 | def get_coco(path, batchsize): 65 | from yolov5.utils.datasets import LoadImagesAndLabels 66 | train_data = LoadImagesAndLabels( 67 | os.path.join(path, 'images/calib'), batch_size=batchsize 68 | ) 69 | train_data = YOLOv5Wrapper(train_data) 70 | train_data.collate_fn = LoadImagesAndLabels.collate_fn 71 | test_data = LoadImagesAndLabels( 72 | os.path.join(path, 'images/val2017'), batch_size=batchsize, pad=.5 73 | ) 74 | test_data = YOLOv5Wrapper(test_data) 75 | test_data.collate_fn = LoadImagesAndLabels.collate_fn 76 | return train_data, test_data 77 | 78 | 79 | DEFAULT_PATHS = { 80 | 'imagenet': [ 81 | '../imagenet' 82 | ], 83 | 'coco': [ 84 | '../coco' 85 | ] 86 | } 87 | 88 | def get_loaders( 89 | name, path='', batchsize=-1, workers=8, nsamples=1024, seed=0, 90 | noaug=False 91 | ): 92 | if name == 'squad': 93 | if batchsize == -1: 94 | batchsize = 16 95 | import bertsquad 96 | set_seed(seed) 97 | return bertsquad.get_dataloader(batchsize, nsamples), None 98 | 99 
 99 |     if not path:
100 |         for path in DEFAULT_PATHS[name]:
101 |             if os.path.exists(path):
102 |                 break
103 | 
104 |     if name == 'imagenet':
105 |         if batchsize == -1:
106 |             batchsize = 128
107 |         train_data, test_data = get_imagenet(path, noaug=noaug)
108 |         train_data = random_subset(train_data, nsamples, seed)
109 |     if name == 'coco':
110 |         if batchsize == -1:
111 |             batchsize = 16
112 |         train_data, test_data = get_coco(path, batchsize)
113 | 
114 |     collate_fn = train_data.collate_fn if hasattr(train_data, 'collate_fn') else None
115 |     trainloader = DataLoader(
116 |         train_data, batch_size=batchsize, num_workers=workers, pin_memory=True, shuffle=True,
117 |         collate_fn=collate_fn
118 |     )
119 |     collate_fn = test_data.collate_fn if hasattr(test_data, 'collate_fn') else None
120 |     testloader = DataLoader(
121 |         test_data, batch_size=batchsize, num_workers=workers, pin_memory=True, shuffle=False,
122 |         collate_fn=collate_fn
123 |     )
124 | 
125 |     return trainloader, testloader
126 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Optimal Brain Compression
  2 | 
  3 | This repository contains efficient implementations of ExactOBS for quantization,
  4 | unstructured, block and N:M pruning, introduced in the NeurIPS 2022 paper
  5 | "Optimal Brain Compression: A Framework for Accurate Post-Training Quantization
  6 | and Pruning".
  7 | 
  8 | ## Files
  9 | 
 10 | * `trueobs.py`: efficient implementations of ExactOBS for all compression types
 11 | * `main_trueobs.py`: code to run ExactOBS
 12 | * `postproc.py`: post-processing operations like statistics corrections
 13 | * `database.py`: generating databases for non-uniform compression
 14 | * `spdy.py`: implementation of the DP algorithm for finding non-uniform
 15 |   compression configurations; adapted from code provided by the authors of SPDY [9]
 16 | * `modelutils.py`: model utilities
 17 | * `datautils.py`: data utilities
 18 | * `quant.py`: quantization utilities
 19 | 
 20 | NOTE: The code as provided here only fully supports torchvision ResNet variants
 21 | (the full integration of the YOLO and BERT models is omitted due to their many
 22 | complex dependencies).
 23 | 
 24 | ## Usage
 25 | 
 26 | First, make sure ImageNet is located at or linked to `../imagenet` (alternatively,
 27 | you can specify the `--datapath` argument for all commands).
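For reference, the calibration and test loaders used by all commands below can also be built directly from Python via `get_loaders` in `datautils.py`. A minimal sketch (the path and sample count are illustrative and match the defaults assumed by `main_trueobs.py`):

```
from datautils import get_loaders

# 1024 calibration samples and ImageNet at ../imagenet are also the defaults
# used by main_trueobs.py; batchsize defaults to 128 for ImageNet.
trainloader, testloader = get_loaders(
    'imagenet', path='../imagenet', nsamples=1024, seed=0
)
```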
 28 | 
 29 | ### Applying OBC
 30 | 
 31 | ```
 32 | # Quantize weights and activations
 33 | python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --save rn18_4w4a.pth
 34 | 
 35 | # Prune to the N:M pattern
 36 | python main_trueobs.py rn18 imagenet nmprune --prunen 2 --prunem 4 --save rn18_24.pth
 37 | 
 38 | # Generate an unstructured pruning database
 39 | mkdir models_unstr
 40 | python main_trueobs.py rn18 imagenet unstr --sparse-dir models_unstr
 41 | 
 42 | # Generate a 4-block pruning database
 43 | mkdir models_4block
 44 | python main_trueobs.py rn18 imagenet blocked --sparse-dir models_4block
 45 | 
 46 | # Quantize a 2:4 pruned model
 47 | python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --load rn18_24.pth --save rn18_24_4w4a.pth
 48 | ```
 49 | 
 50 | ### Statistics Corrections
 51 | 
 52 | ```
 53 | # Batchnorm tuning
 54 | python postproc.py rn18 imagenet rn18_24.pth --bnt
 55 | 
 56 | # Statistics correction
 57 | python postproc.py rn18 imagenet rn18_24.pth --statcorr --statcorr-samples 1024
 58 | ```
 59 | 
 60 | ### Non-Uniform Compression
 61 | 
 62 | ```
 63 | mkdir scores
 64 | 
 65 | # Unstructured pruning
 66 | 
 67 | # Setup database
 68 | mkdir models_unstr
 69 | python main_trueobs.py rn18 imagenet unstr --sparse-dir models_unstr
 70 | # Compute corresponding losses
 71 | python database.py rn18 imagenet unstr loss
 72 | # Run DP algorithm to determine per-layer compression targets
 73 | python spdy.py rn18 imagenet 2 unstr --dp
 74 | # Stitch profile, apply batchnorm resetting and compute validation accuracy
 75 | python postproc.py rn18 imagenet rn18_unstr_200x_dp.txt --database unstr --bnt
 76 | 
 77 | # Mixed quantization + 2:4 pruning
 78 | 
 79 | mkdir models_nm
 80 | mkdir models_quant
 81 | mkdir models_nm_quant
 82 | python main_trueobs.py rn18 imagenet nmprune --save models_nm/rn18_24.pth
 83 | python main_trueobs.py rn18 imagenet quant --wbits 8 --abits 8 --save models_quant/rn18_8w8a.pth
 84 | python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --save models_quant/rn18_4w4a.pth
 85 | python main_trueobs.py rn18 imagenet quant --wbits 8 --abits 8 --load models_nm/rn18_24.pth --save models_nm_quant/rn18_24_8w8a.pth
 86 | python main_trueobs.py rn18 imagenet quant --wbits 4 --abits 4 --load models_nm/rn18_24.pth --save models_nm_quant/rn18_24_4w4a.pth
 87 | python database.py rn18 imagenet mixed loss
 88 | python spdy.py rn18 imagenet 8 mixed --dp
 89 | python postproc.py rn18 imagenet rn18_mixed_800x_dp.txt --database mixed --bnt
 90 | ```
 91 | 
 92 | ### BERT
 93 | 
 94 | Before using our BERT integration, please download our [pretrained checkpoints](https://seafile.ist.ac.at/d/c155c45712ad4bcb9341/) and move them to the `bertsquad` folder.
 95 | Then you should be able to use most of the features described above by passing `bertsquad` (or `bertsquad6` for smaller variants) as the model name and `squad` as the dataset name.
 96 | The code was tested with `transformers==4.21.2` and `datasets==1.17.0`.
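A checkpoint produced by the commands above can also be loaded and evaluated directly from Python with the utilities in `modelutils.py`, `quant.py` and `datautils.py`. A minimal sketch (file and model names are illustrative; a CUDA device is assumed, since `modelutils.DEV` is `cuda:0`):

```
import torch

from datautils import get_loaders
from modelutils import get_functions
from quant import add_actquant

get_model, test, run = get_functions('rn18')
model = get_model()
sd = torch.load('rn18_4w4a.pth')
# Checkpoints with quantized activations contain quantizer buffers
# ('scale', 'zero', 'maxq'), so the layers must be wrapped before loading.
if any('scale' in k for k in sd):
    add_actquant(model)
model.load_state_dict(sd)

_, testloader = get_loaders('imagenet')
test(model, testloader)  # prints top-1 accuracy in percent
```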
97 | 98 | # BibTex 99 | 100 | ``` 101 | @article{frantar2022obc, 102 | title={{Optimal Brain Compression:} A Framework for Accurate Post-Training Quantization and Pruning}, 103 | author={Frantar, Elias and Singh, Sidak Pal and Alistarh, Dan}, 104 | journal={Advances in Neural Information Processing Systems}, 105 | volume={36}, 106 | year={2022} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /bertsquad/trainer_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | A subclass of `Trainer` specific to Question-Answering tasks 17 | """ 18 | 19 | from transformers import Trainer, is_torch_tpu_available 20 | from transformers.trainer_utils import PredictionOutput 21 | 22 | 23 | if is_torch_tpu_available(): 24 | import torch_xla.core.xla_model as xm 25 | import torch_xla.debug.metrics as met 26 | 27 | 28 | class QuestionAnsweringTrainer(Trainer): 29 | def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | self.eval_examples = eval_examples 32 | self.post_process_function = post_process_function 33 | 34 | def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): 35 | eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset 36 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 37 | eval_examples = self.eval_examples if eval_examples is None else eval_examples 38 | 39 | # Temporarily disable metric computation, we will do it in the loop here. 40 | compute_metrics = self.compute_metrics 41 | self.compute_metrics = None 42 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 43 | try: 44 | output = eval_loop( 45 | eval_dataloader, 46 | description="Evaluation", 47 | # No point gathering the predictions if there are no metrics, otherwise we defer to 48 | # self.args.prediction_loss_only 49 | prediction_loss_only=True if compute_metrics is None else None, 50 | ignore_keys=ignore_keys, 51 | ) 52 | finally: 53 | self.compute_metrics = compute_metrics 54 | 55 | if self.post_process_function is not None and self.compute_metrics is not None: 56 | eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) 57 | metrics = self.compute_metrics(eval_preds) 58 | 59 | # Prefix all keys with metric_key_prefix + '_' 60 | for key in list(metrics.keys()): 61 | if not key.startswith(f"{metric_key_prefix}_"): 62 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 63 | 64 | self.log(metrics) 65 | else: 66 | metrics = {} 67 | 68 | if self.args.tpu_metrics_debug or self.args.debug: 69 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 
70 | xm.master_print(met.metrics_report()) 71 | 72 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) 73 | return metrics 74 | 75 | def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): 76 | predict_dataloader = self.get_test_dataloader(predict_dataset) 77 | 78 | # Temporarily disable metric computation, we will do it in the loop here. 79 | compute_metrics = self.compute_metrics 80 | self.compute_metrics = None 81 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 82 | try: 83 | output = eval_loop( 84 | predict_dataloader, 85 | description="Prediction", 86 | # No point gathering the predictions if there are no metrics, otherwise we defer to 87 | # self.args.prediction_loss_only 88 | prediction_loss_only=True if compute_metrics is None else None, 89 | ignore_keys=ignore_keys, 90 | ) 91 | finally: 92 | self.compute_metrics = compute_metrics 93 | 94 | if self.post_process_function is None or self.compute_metrics is None: 95 | return output 96 | 97 | predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") 98 | metrics = self.compute_metrics(predictions) 99 | 100 | # Prefix all keys with metric_key_prefix + '_' 101 | for key in list(metrics.keys()): 102 | if not key.startswith(f"{metric_key_prefix}_"): 103 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 104 | 105 | return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) 106 | -------------------------------------------------------------------------------- /quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def quantize(x, scale, zero, maxq): 6 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 7 | return scale * (q - zero) 8 | 9 | class Quantizer(nn.Module): 10 | 11 | def __init__(self, shape=1): 12 | super(Quantizer, self).__init__() 13 | self.register_buffer('maxq', torch.tensor(0)) 14 | self.register_buffer('scale', torch.zeros(shape)) 15 | self.register_buffer('zero', torch.zeros(shape)) 16 | 17 | def configure( 18 | self, 19 | bits, perchannel=False, sym=True, 20 | mse=False, norm=2.4, grid=100, maxshrink=.8 21 | ): 22 | self.maxq = torch.tensor(2 ** bits - 1) 23 | self.perchannel = perchannel 24 | self.sym = sym 25 | self.mse = mse 26 | self.norm = norm 27 | self.grid = grid 28 | self.maxshrink = maxshrink 29 | 30 | def find_params(self, x, weight=False): 31 | dev = x.device 32 | self.maxq = self.maxq.to(dev) 33 | 34 | shape = x.shape 35 | if self.perchannel: 36 | if weight: 37 | x = x.flatten(1) 38 | else: 39 | if len(shape) == 4: 40 | x = x.permute([1, 0, 2, 3]) 41 | x = x.flatten(1) 42 | if len(shape) == 3: 43 | x = x.reshape((-1, shape[-1])).t() 44 | if len(shape) == 2: 45 | x = x.t() 46 | else: 47 | x = x.flatten().unsqueeze(0) 48 | 49 | tmp = torch.zeros(x.shape[0], device=dev) 50 | xmin = torch.minimum(x.min(1)[0], tmp) 51 | xmax = torch.maximum(x.max(1)[0], tmp) 52 | 53 | if self.sym: 54 | xmax = torch.maximum(torch.abs(xmin), xmax) 55 | tmp = xmin < 0 56 | if torch.any(tmp): 57 | xmin[tmp] = -xmax[tmp] 58 | tmp = (xmin == 0) & (xmax == 0) 59 | xmin[tmp] = -1 60 | xmax[tmp] = +1 61 | 62 | self.scale = (xmax - xmin) / self.maxq 63 | if self.sym: 64 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 65 | else: 66 | self.zero = torch.round(-xmin / self.scale) 67 
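# Illustrative example of the parameters computed above, plugged into quantize()
# from the top of this file: with bits=4, maxq = 2**4 - 1 = 15; a row with
# xmin = -1.0 and xmax = 0.5 in the asymmetric case gets
# scale = (0.5 - (-1.0)) / 15 = 0.1 and zero = round(1.0 / 0.1) = 10, so
# quantize(0.23) = clamp(round(0.23 / 0.1) + 10, 0, 15) = 12, which dequantizes
# to scale * (12 - zero) = 0.2.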
| 68 | if self.mse: 69 | best = torch.full([x.shape[0]], float('inf'), device=dev) 70 | for i in range(int(self.maxshrink * self.grid)): 71 | p = 1 - i / self.grid 72 | xmin1 = p * xmin 73 | xmax1 = p * xmax 74 | scale1 = (xmax1 - xmin1) / self.maxq 75 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 76 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 77 | q -= x 78 | q.abs_() 79 | q.pow_(self.norm) 80 | err = torch.sum(q, 1) 81 | tmp = err < best 82 | if torch.any(tmp): 83 | best[tmp] = err[tmp] 84 | self.scale[tmp] = scale1[tmp] 85 | self.zero[tmp] = zero1[tmp] 86 | if not self.perchannel: 87 | if weight: 88 | tmp = shape[0] 89 | else: 90 | tmp = shape[1] if len(shape) != 3 else shape[2] 91 | self.scale = self.scale.repeat(tmp) 92 | self.zero = self.zero.repeat(tmp) 93 | 94 | if weight: 95 | shape = [-1] + [1] * (len(shape) - 1) 96 | # self.scale = self.scale.unsqueeze(1) 97 | # self.zero = self.zero.unsqueeze(1) 98 | self.scale = self.scale.reshape(shape) 99 | self.zero = self.zero.reshape(shape) 100 | return 101 | if len(shape) == 4: 102 | self.scale = self.scale.reshape((1, -1, 1, 1)) 103 | self.zero = self.zero.reshape((1, -1, 1, 1)) 104 | if len(shape) == 3: 105 | self.scale = self.scale.reshape((1, 1, -1)) 106 | self.zero = self.zero.reshape((1, 1, -1)) 107 | if len(shape) == 2: 108 | self.scale = self.scale.unsqueeze(0) 109 | self.zero = self.zero.unsqueeze(0) 110 | 111 | def quantize(self, x): 112 | if self.ready(): 113 | return quantize(x, self.scale, self.zero, self.maxq) 114 | return x 115 | 116 | def enabled(self): 117 | return self.maxq > 0 118 | 119 | def ready(self): 120 | return torch.all(self.scale != 0) 121 | 122 | class ActQuantWrapper(nn.Module): 123 | 124 | def __init__(self, module): 125 | super(ActQuantWrapper, self).__init__() 126 | self.module = module 127 | shape = [1] * len(self.module.weight.shape) 128 | if len(shape) == 4: 129 | shape[1] = self.module.weight.shape[1] 130 | if len(shape) == 3: 131 | shape[2] = self.module.weight.shape[2] 132 | if len(shape) == 2: 133 | shape[1] = self.module.weight.shape[1] 134 | self.quantizer = Quantizer(shape=shape) 135 | 136 | def forward(self, x): 137 | return self.module(self.quantizer.quantize(x)) 138 | 139 | def add_actquant(module, name='', layers=[nn.Conv2d, nn.Linear]): 140 | if isinstance(module, ActQuantWrapper): 141 | return 142 | for attr in dir(module): 143 | tmp = getattr(module, attr) 144 | if type(tmp) in layers: 145 | setattr(module, attr, ActQuantWrapper(tmp)) 146 | if type(tmp) == nn.Sequential: 147 | replaced = [] 148 | for i, child in enumerate(tmp.children()): 149 | if type(child) in layers: 150 | replaced.append(ActQuantWrapper(child)) 151 | else: 152 | replaced.append(child) 153 | setattr(module, attr, nn.Sequential(*replaced)) 154 | if type(tmp) == torch.nn.ModuleList: 155 | replaced = [] 156 | for i, child in enumerate(tmp.children()): 157 | if type(child) in layers: 158 | replaced.append(ActQuantWrapper(child)) 159 | else: 160 | replaced.append(child) 161 | setattr(module, attr, nn.ModuleList(replaced)) 162 | for name1, child in module.named_children(): 163 | add_actquant(child, name + '.' 
+ name1 if name != '' else name1, layers) 164 | -------------------------------------------------------------------------------- /main_trueobs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from datautils import * 9 | from modelutils import * 10 | from quant import * 11 | from trueobs import * 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('model', type=str) 17 | parser.add_argument('dataset', type=str) 18 | parser.add_argument( 19 | 'compress', type=str, choices=['quant', 'nmprune', 'unstr', 'struct', 'blocked'] 20 | ) 21 | parser.add_argument('--load', type=str, default='') 22 | parser.add_argument('--datapath', type=str, default='') 23 | parser.add_argument('--seed', type=int, default=0) 24 | parser.add_argument('--save', type=str, default='') 25 | 26 | parser.add_argument('--nsamples', type=int, default=1024) 27 | parser.add_argument('--batchsize', type=int, default=-1) 28 | parser.add_argument('--workers', type=int, default=8) 29 | parser.add_argument('--nrounds', type=int, default=-1) 30 | parser.add_argument('--noaug', action='store_true') 31 | 32 | parser.add_argument('--wbits', type=int, default=32) 33 | parser.add_argument('--abits', type=int, default=32) 34 | parser.add_argument('--wperweight', action='store_true') 35 | parser.add_argument('--wasym', action='store_true') 36 | parser.add_argument('--wminmax', action='store_true') 37 | parser.add_argument('--asym', action='store_true') 38 | parser.add_argument('--aminmax', action='store_true') 39 | parser.add_argument('--rel-damp', type=float, default=0) 40 | 41 | parser.add_argument('--prunen', type=int, default=2) 42 | parser.add_argument('--prunem', type=int, default=4) 43 | parser.add_argument('--blocked_size', type=int, default=4) 44 | parser.add_argument('--min-sparsity', type=float, default=0) 45 | parser.add_argument('--max-sparsity', type=float, default=0) 46 | parser.add_argument('--delta-sparse', type=float, default=0) 47 | parser.add_argument('--sparse-dir', type=str, default='') 48 | 49 | args = parser.parse_args() 50 | 51 | dataloader, testloader = get_loaders( 52 | args.dataset, path=args.datapath, 53 | batchsize=args.batchsize, workers=args.workers, 54 | nsamples=args.nsamples, seed=args.seed, 55 | noaug=args.noaug 56 | ) 57 | if args.nrounds == -1: 58 | args.nrounds = 1 if 'yolo' in args.model or 'bert' in args.model else 10 59 | if args.noaug: 60 | args.nrounds = 1 61 | get_model, test, run = get_functions(args.model) 62 | 63 | aquant = args.compress == 'quant' and args.abits < 32 64 | wquant = args.compress == 'quant' and args.wbits < 32 65 | 66 | modelp = get_model() 67 | if args.compress == 'quant' and args.load: 68 | modelp.load_state_dict(torch.load(args.load)) 69 | if aquant: 70 | add_actquant(modelp) 71 | modeld = get_model() 72 | layersp = find_layers(modelp) 73 | layersd = find_layers(modeld) 74 | 75 | SPARSE_DEFAULTS = { 76 | 'unstr': (0, .99, .1), 77 | 'struct': (0, .9, .05), 78 | 'blocked': (0, .95, .1) 79 | } 80 | sparse = args.compress in SPARSE_DEFAULTS 81 | if sparse: 82 | if args.min_sparsity == 0 and args.max_sparsity == 0: 83 | defaults = SPARSE_DEFAULTS[args.compress] 84 | args.min_sparsity, args.max_sparsity, args.delta_sparse = defaults 85 | sparsities = [] 86 | density = 1 - args.min_sparsity 87 | while density > 1 - args.max_sparsity: 88 | sparsities.append(1 - density) 89 | density *= 1 - args.delta_sparse 90 | 
sparsities.append(args.max_sparsity) 91 | sds = {s: copy.deepcopy(modelp).cpu().state_dict() for s in sparsities} 92 | 93 | trueobs = {} 94 | for name in layersp: 95 | layer = layersp[name] 96 | if isinstance(layer, ActQuantWrapper): 97 | layer = layer.module 98 | trueobs[name] = TrueOBS(layer, rel_damp=args.rel_damp) 99 | if aquant: 100 | layersp[name].quantizer.configure( 101 | args.abits, sym=args.asym, mse=not args.aminmax 102 | ) 103 | if wquant: 104 | trueobs[name].quantizer = Quantizer() 105 | trueobs[name].quantizer.configure( 106 | args.wbits, perchannel=not args.wperweight, sym=not args.wasym, mse=not args.wminmax 107 | ) 108 | 109 | if not (args.compress == 'quant' and not wquant): 110 | cache = {} 111 | def add_batch(name): 112 | def tmp(layer, inp, out): 113 | trueobs[name].add_batch(inp[0].data, out.data) 114 | return tmp 115 | handles = [] 116 | for name in trueobs: 117 | handles.append(layersd[name].register_forward_hook(add_batch(name))) 118 | for i in range(args.nrounds): 119 | for j, batch in enumerate(dataloader): 120 | print(i, j) 121 | with torch.no_grad(): 122 | run(modeld, batch) 123 | for h in handles: 124 | h.remove() 125 | for name in trueobs: 126 | print(name) 127 | if args.compress == 'quant': 128 | print('Quantizing ...') 129 | trueobs[name].quantize() 130 | if args.compress == 'nmprune': 131 | if trueobs[name].columns % args.prunem == 0: 132 | print('N:M pruning ...') 133 | trueobs[name].nmprune(args.prunen, args.prunem) 134 | if sparse: 135 | Ws = None 136 | if args.compress == 'unstr': 137 | print('Unstructured pruning ...') 138 | trueobs[name].prepare_unstr() 139 | Ws = trueobs[name].prune_unstr(sparsities) 140 | if args.compress == 'struct': 141 | if not isinstance(trueobs[name].layer, nn.Conv2d): 142 | size = 1 143 | else: 144 | tmp = trueobs[name].layer.kernel_size 145 | size = tmp[0] * tmp[1] 146 | if trueobs[name].columns / size > 3: 147 | print('Structured pruning ...') 148 | Ws = trueobs[name].prune_struct(sparsities, size=size) 149 | if args.compress == 'blocked': 150 | if trueobs[name].columns % args.blocked_size == 0: 151 | print('Blocked pruning ...') 152 | trueobs[name].prepare_blocked(args.blocked_size) 153 | Ws = trueobs[name].prune_blocked(sparsities) 154 | if Ws: 155 | for sparsity, W in zip(sparsities, Ws): 156 | sds[sparsity][name + '.weight'] = W.reshape(sds[sparsity][name + '.weight'].shape).cpu() 157 | trueobs[name].free() 158 | 159 | if sparse: 160 | if args.sparse_dir: 161 | for sparsity in sparsities: 162 | name = '%s_%04d.pth' % (args.model, int(sparsity * 10000)) 163 | torch.save(sds[sparsity], os.path.join(args.sparse_dir, name)) 164 | exit() 165 | 166 | if aquant: 167 | print('Quantizing activations ...') 168 | def init_actquant(name): 169 | def tmp(layer, inp, out): 170 | layersp[name].quantizer.find_params(inp[0].data) 171 | return tmp 172 | handles = [] 173 | for name in layersd: 174 | handles.append(layersd[name].register_forward_hook(init_actquant(name))) 175 | with torch.no_grad(): 176 | run(modeld, next(iter(dataloader))) 177 | for h in handles: 178 | h.remove() 179 | 180 | if args.save: 181 | torch.save(modelp.state_dict(), args.save) 182 | 183 | test(modelp, testloader) 184 | -------------------------------------------------------------------------------- /modelutils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('yolov5') 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from quant import * 9 | 10 | 11 | DEV = 
torch.device('cuda:0') 12 | 13 | 14 | def find_layers(module, layers=[nn.Conv2d, nn.Linear, ActQuantWrapper], name=''): 15 | if type(module) in layers: 16 | return {name: module} 17 | res = {} 18 | for name1, child in module.named_children(): 19 | res.update(find_layers( 20 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 21 | )) 22 | return res 23 | 24 | 25 | @torch.no_grad() 26 | def test(model, dataloader): 27 | train = model.training 28 | model.eval() 29 | print('Evaluating ...') 30 | dev = next(iter(model.parameters())).device 31 | preds = [] 32 | ys = [] 33 | for x, y in dataloader: 34 | preds.append(torch.argmax(model(x.to(dev)), 1)) 35 | ys.append(y.to(dev)) 36 | acc = torch.mean((torch.cat(preds) == torch.cat(ys)).float()).item() 37 | acc *= 100 38 | print('%.2f' % acc) 39 | if model.training: 40 | model.train() 41 | 42 | @torch.no_grad() 43 | def test_yolo(model, dataloader): 44 | import json 45 | from pathlib import Path 46 | from pycocotools.coco import COCO 47 | from pycocotools.cocoeval import COCOeval 48 | from yolov5.utils.general import coco80_to_coco91_class, non_max_suppression, scale_coords, xywh2xyxy 49 | from yolov5.utils.metrics import ap_per_class 50 | from yolov5.val import process_batch, save_one_json 51 | 52 | train = model.training 53 | model.eval() 54 | print('Evaluating ...') 55 | dev = next(iter(model.parameters())).device 56 | 57 | conf_thres = .001 58 | iou_thres = .65 59 | 60 | iouv = torch.Tensor([.5]) 61 | niou = iouv.numel() 62 | class_map = coco80_to_coco91_class() 63 | jdict = [] 64 | names = {k: v for k, v in enumerate(model.names)} 65 | 66 | for i, (im, targets, paths, shapes) in enumerate(dataloader): 67 | im = im.to(dev) 68 | targets = targets.to(dev) 69 | out, _ = model(im) 70 | nb, _, height, width = im.shape 71 | targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(dev) 72 | out = non_max_suppression( 73 | out, conf_thres, iou_thres, labels=[], multi_label=True, agnostic=False 74 | ) 75 | for si, pred in enumerate(out): 76 | labels = targets[targets[:, 0] == si, 1:] 77 | nl = len(labels) 78 | tcls = labels[:, 0].tolist() if nl else [] 79 | path, shape = Path(paths[si]), shapes[si][0] 80 | if len(pred) == 0: 81 | continue 82 | predn = pred.clone() 83 | scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) 84 | save_one_json(predn, jdict, path, class_map) 85 | 86 | anno_json = dataloader.dataset.original.path.replace( 87 | 'images/val2017', 'annotations/instances_val2017.json' 88 | ) 89 | import random 90 | pred_json = 'yolo-preds-for-eval-%d.json' % random.randint(0, 10 ** 6) 91 | with open(pred_json, 'w') as f: 92 | json.dump(jdict, f) 93 | 94 | anno = COCO(anno_json) 95 | pred = anno.loadRes(pred_json) 96 | eval = COCOeval(anno, pred, 'bbox') 97 | eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.original.img_files] 98 | eval.evaluate() 99 | eval.accumulate() 100 | eval.summarize() 101 | map, map50 = eval.stats[:2] 102 | print(100 * map50) 103 | os.remove(pred_json) 104 | 105 | if train: 106 | model.train() 107 | 108 | @torch.no_grad() 109 | def test_bertsquad(model, _): 110 | import bertsquad 111 | bertsquad.test(model) 112 | 113 | def get_test(name): 114 | if 'yolo' in name: 115 | return test_yolo 116 | if 'bertsquad' in name: 117 | return test_bertsquad 118 | return test 119 | 120 | 121 | def run(model, batch, loss=False, retmoved=False): 122 | dev = next(iter(model.parameters())).device 123 | if retmoved: 124 | return (batch[0].to(dev), batch[1].to(dev)) 125 | out = 
model(batch[0].to(dev)) 126 | if loss: 127 | return nn.functional.cross_entropy(out, batch[1].to(dev)).item() * batch[0].shape[0] 128 | return out 129 | 130 | def run_yolo(model, batch, loss=False, retmoved=False): 131 | dev = next(iter(model.parameters())).device 132 | if retmoved: 133 | return (batch[0].to(dev), batch[1].to(dev)) 134 | out = model(batch[0].to(dev)) 135 | if not model.training: 136 | out = out[1] 137 | if loss: 138 | return model.computeloss(out, batch[1].to(dev))[0].item() 139 | return torch.cat([o.flatten() for o in out]) 140 | 141 | def run_bert(model, batch, loss=False, retmoved=False): 142 | dev = next(iter(model.parameters())).device 143 | for k, v in batch.items(): 144 | batch[k] = v.to(DEV) 145 | if retmoved: 146 | return batch 147 | out = model(**batch) 148 | if loss: 149 | return out['loss'].item() * batch[k].shape[0] 150 | return torch.cat([out['start_logits'], out['end_logits']]) 151 | 152 | def get_run(model): 153 | if 'yolo' in model: 154 | return run_yolo 155 | if 'bert' in model: 156 | return run_bert 157 | return run 158 | 159 | 160 | def get_yolo(var): 161 | from yolov5.models.yolo import Model 162 | from yolov5.utils.downloads import attempt_download 163 | weights = attempt_download(var + '.pt') 164 | ckpt = torch.load(weights, map_location=DEV) 165 | model = Model(ckpt['model'].yaml) 166 | csd = ckpt['model'].float().state_dict() 167 | model.load_state_dict(csd, strict=False) 168 | from yolov5.utils.loss import ComputeLoss 169 | model.hyp = { 170 | 'box': .05, 'cls': .5, 'cls_pw': 1., 'obj': 1., 'obj_pw': 1., 'fl_gamma': 0., 'anchor_t': 4 171 | } 172 | model = model.to(DEV) 173 | model.computeloss = ComputeLoss(model) 174 | return model 175 | 176 | def get_bertsquad(layers=12): 177 | import bertsquad 178 | return bertsquad.get_model(layers=layers) 179 | 180 | from torchvision.models import resnet18, resnet34, resnet50, resnet101 181 | 182 | get_models = { 183 | 'rn18': lambda: resnet18(pretrained=True), 184 | 'rn34': lambda: resnet34(pretrained=True), 185 | 'rn50': lambda: resnet50(pretrained=True), 186 | 'yolov5s': lambda: get_yolo('yolov5s'), 187 | 'yolov5m': lambda: get_yolo('yolov5m'), 188 | 'yolov5l': lambda: get_yolo('yolov5l'), 189 | 'bertsquad': lambda: get_bertsquad(), 190 | 'bertsquad6': lambda: get_bertsquad(6), 191 | 'bertsquad3': lambda: get_bertsquad(3) 192 | } 193 | 194 | def get_model(model): 195 | model = get_models[model]() 196 | model = model.to(DEV) 197 | model.eval() 198 | return model 199 | 200 | 201 | def get_functions(model): 202 | return lambda: get_model(model), get_test(model), get_run(model) 203 | 204 | 205 | def firstlast_names(model): 206 | if 'rn' in model: 207 | return ['conv1', 'fc'] 208 | if 'bertsquad' in model: 209 | return [ 210 | 'bert.embeddings.word_embeddings', 211 | 'bert.embeddings.token_type_embeddings', 212 | 'qa_outputs' 213 | ] 214 | if 'yolo' in model: 215 | lastidx = {'n': 24}[model[6]] 216 | return ['model.0.conv'] + ['model.%d.m.%d' % (lastidx, i) for i in range(3)] 217 | -------------------------------------------------------------------------------- /spdy.py: -------------------------------------------------------------------------------- 1 | # Adapted from code provided by the authors of SPDY [9] 2 | 3 | import argparse 4 | import math 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from database import * 12 | from datautils import * 13 | from modelutils import * 14 | from quant import * 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | 
parser.add_argument('model') 19 | parser.add_argument('dataset') 20 | parser.add_argument('target', type=float) 21 | parser.add_argument('database', choices=['mixed', '4block', 'unstr', '4block_8w8a']) 22 | parser.add_argument('--errors', choices=['', 'squared', 'loss'], default='') 23 | parser.add_argument('--constr', choices=['', 'bits', 'bops', 'flops', 'timingsq'], default='') 24 | parser.add_argument('--nobatchnorm', action='store_true') 25 | parser.add_argument('--statcorr', action='store_true') 26 | parser.add_argument('--dpbuckets', type=int, default=10000) 27 | parser.add_argument('--dp', action='store_true') 28 | parser.add_argument('--score', type=str, default='') 29 | 30 | parser.add_argument('--prefix', type=str, default='') 31 | parser.add_argument('--datapath', type=str, default='') 32 | parser.add_argument('--seed', type=int, default=0) 33 | parser.add_argument('--nsamples', type=int, default=1024) 34 | parser.add_argument('--batchsize', type=int, default=-1) 35 | parser.add_argument('--workers', type=int, default=8) 36 | parser.add_argument('--nrounds', type=int, default=-1) 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | get_model, test, run = get_functions(args.model) 42 | dataloader, testloader = get_loaders( 43 | args.dataset, path=args.datapath, 44 | batchsize=args.batchsize, workers=args.workers, 45 | nsamples=args.nsamples, seed=args.seed, 46 | noaug=True 47 | ) 48 | 49 | modelp = get_model() 50 | if args.database in ['mixed', '4block_8w8a']: 51 | add_actquant(modelp) 52 | layersp = find_layers(modelp) 53 | 54 | batches = [] 55 | for batch in dataloader: 56 | batches.append(run(modelp, batch, retmoved=True)) 57 | 58 | if args.database == 'mixed': 59 | db = QuantNMDatabase(args.model, prefix=args.prefix) 60 | if args.database in ['4block', 'unstr', '4block_8w8a']: 61 | db = SparsityDatabase(args.database, args.model, prefix=args.prefix) 62 | 63 | DEFAULT_CONSTR = { 64 | 'mixed': 'bops', 65 | '4block': 'flops', 66 | 'unstr': 'flops', 67 | '4block_8w8a': 'timingsq' 68 | } 69 | if not args.constr: 70 | args.constr = DEFAULT_CONSTR[args.database] 71 | if not args.errors: 72 | args.errors = 'loss' if args.dp else 'spdy' 73 | 74 | errors = db.load_errors(args.errors) 75 | baseline_constr = None 76 | if args.constr == 'bits': 77 | constr = db.get_bits(layersp) 78 | if args.constr == 'bops': 79 | constr = db.get_bops(layersp, modelp, batches[0], run) 80 | if args.constr == 'flops': 81 | constr = db.get_flops(layersp, modelp, batches[0], run) 82 | if args.constr == 'timingsq': 83 | baseline_constr, constr = db.get_timingsq() 84 | 85 | 86 | modelp.train() 87 | if args.nobatchnorm or args.statcorr: 88 | batchnorms = find_layers(modelp, [nn.BatchNorm2d]) 89 | for bn in batchnorms.values(): 90 | bn.eval() 91 | if args.statcorr: 92 | batch = batches[0] 93 | batches = [batch] 94 | args.nsamples = args.batchsize 95 | 96 | modeld = get_model() 97 | layersd = find_layers(modeld, layers=[nn.BatchNorm2d, nn.LayerNorm]) 98 | layersp1 = find_layers(modelp, layers=[nn.BatchNorm2d, nn.LayerNorm]) 99 | 100 | REDUCE = { 101 | 2: [0], 102 | 3: [0, 1], 103 | 4: [0, 2, 3] 104 | } 105 | 106 | meansd = {} 107 | stdsd = {} 108 | def hookd(name): 109 | def tmp(layer, inp, out): 110 | red = REDUCE[len(out.shape)] 111 | meansd[name] = torch.mean(out.data, red, keepdim=True) 112 | stdsd[name] = torch.std(out.data, red, keepdim=True) 113 | return tmp 114 | def hookp(name): 115 | def tmp(layer, inp, out): 116 | red = REDUCE[len(out.shape)] 117 | meanp = torch.mean(out.data, red, keepdim=True) 118 | 
stdp = torch.std(out.data, red, keepdim=True) 119 | out.data -= meanp 120 | out.data *= stdsd[name] / (stdp + 1e-9) 121 | out.data += meansd[name] 122 | return tmp 123 | for name in layersd: 124 | layersd[name].register_forward_hook(hookd(name)) 125 | with torch.no_grad(): 126 | run(modeld, batch) 127 | for name in layersp1: 128 | layersp1[name].register_forward_hook(hookp(name)) 129 | 130 | 131 | layers = list(layersp.keys()) 132 | sparsities = list(errors[layers[0]].keys()) 133 | costs = [[errors[l][s] for s in sparsities] for l in layers] 134 | timings = [[constr[l][s] for s in sparsities] for l in layers] 135 | costs = np.array(costs) 136 | 137 | prunabletime = sum(max(c) for c in timings) 138 | if baseline_constr is None: 139 | baseline_constr = prunabletime 140 | target_constr = baseline_constr / args.target - (baseline_constr - prunabletime) 141 | best = sum(min(c) for c in timings) 142 | print('Max target:', baseline_constr / (best + baseline_constr - prunabletime)) 143 | bucketsize = target_constr / args.dpbuckets 144 | 145 | for row in timings: 146 | for i in range(len(row)): 147 | row[i] = int(round(row[i] / bucketsize)) 148 | 149 | def dp(costs): 150 | DP = np.full((len(layers), args.dpbuckets + 1), float('inf')) 151 | PD = np.full((len(layers), args.dpbuckets + 1), -1) 152 | 153 | for sparsity in range(len(sparsities)): 154 | if costs[0][sparsity] < DP[0][timings[0][sparsity]]: 155 | DP[0][timings[0][sparsity]] = costs[0][sparsity] 156 | PD[0][timings[0][sparsity]] = sparsity 157 | for layer in range(1, len(DP)): 158 | for sparsity in range(len(sparsities)): 159 | timing = timings[layer][sparsity] 160 | if timing == 0 and layer == len(DP) - 1: 161 | DP[layer] = DP[layer - 1] 162 | PD[layer] = 0 163 | continue 164 | if timing == 0 and layer == len(DP) - 1: 165 | DP[layer] = DP[layer - 1] 166 | PD[layer] = 0 167 | continue 168 | if timing < 1 or timing > args.dpbuckets: 169 | continue 170 | score = costs[layer][sparsity] 171 | tmp = DP[layer - 1][:-timing] + score 172 | better = tmp < DP[layer][timing:] 173 | if np.sum(better): 174 | DP[layer][timing:][better] = tmp[better] 175 | PD[layer][timing:][better] = sparsity 176 | 177 | score = np.min(DP[-1, :]) 178 | timing = np.argmin(DP[-1, :]) 179 | 180 | solution = [] 181 | for layer in range(len(DP) - 1, -1, -1): 182 | solution.append(PD[layer][timing]) 183 | timing -= timings[layer][solution[-1]] 184 | solution.reverse() 185 | return solution 186 | 187 | def gen_costs(coefs): 188 | return costs * coefs.reshape((-1, 1)) 189 | 190 | def stitch_model(solution): 191 | config = {n: sparsities[s] for n, s in zip(layers, solution)} 192 | db.stitch(layersp, config) 193 | return modelp 194 | 195 | @torch.no_grad() 196 | def get_loss(model): 197 | loss = 0 198 | for batch in batches: 199 | loss += run(modelp, batch, loss=True) 200 | return loss / args.nsamples 201 | 202 | def get_score(coefs): 203 | costs = gen_costs(coefs) 204 | solution = dp(costs) 205 | model = stitch_model(solution) 206 | return get_loss(model) 207 | 208 | 209 | if args.score: 210 | with open(args.score, 'r') as f: 211 | solution = [] 212 | for l in f.readlines(): 213 | splits = l.split(' ') 214 | sparsity = splits[0] 215 | name = splits[1][:-1] 216 | i = sparsities.index(sparsity) 217 | solution.append(i) 218 | print(baseline_constr / (baseline_constr - prunabletime + sum(t[s] for s, t in zip(solution, timings)) * bucketsize)) 219 | print(get_loss(stitch_model(solution))) 220 | exit() 221 | 222 | def save_profile(coefs, name=''): 223 | solution = dp(gen_costs(coefs)) 
224 | if name: 225 | with open(name, 'w') as f: 226 | for s, n in zip(solution, layers): 227 | f.write('%s %s\n' % (sparsities[s], n)) 228 | else: 229 | for s, n in zip(solution, layers): 230 | print('%s %s' % (sparsities[s], n)) 231 | 232 | print('Base:', get_loss(modelp)) 233 | 234 | name = '%s_%s_%dx_spdy' % (args.model, args.database, int(args.target * 100)) 235 | name = os.path.join(args.prefix, name) 236 | 237 | if args.dp: 238 | name = name.replace('spdy', 'dp') 239 | coefs = np.ones(len(layers)) 240 | print(get_score(np.ones(len(layers)))) 241 | save_profile(coefs) 242 | save_profile(coefs, name + '.txt') 243 | exit() 244 | 245 | evals = 0 246 | print('Finding init ...') 247 | coefs = None 248 | score = float('inf') 249 | for _ in range(100): 250 | coefs1 = np.random.uniform(0, 1, size=len(layers)) 251 | score1 = get_score(coefs1) 252 | evals += 1 253 | print(evals) 254 | if score1 < score: 255 | print(score1) 256 | score = score1 257 | coefs = coefs1 258 | print('Running local search ...') 259 | for resamplings in range(int(.1 * len(layers)), 0, -1): 260 | print('Trying %d resamplings ...' % resamplings) 261 | improved = True 262 | while improved: 263 | improved = False 264 | for _ in range(100): 265 | coefs1 = coefs.copy() 266 | for _ in range(resamplings): 267 | coefs1[random.randint(0, len(layers) - 1)] = np.random.uniform(0, 1) 268 | score1 = get_score(coefs1) 269 | evals += 1 270 | print(evals) 271 | if score1 < score: 272 | print(score1) 273 | score = score1 274 | coefs = coefs1 275 | improved = True 276 | break 277 | 278 | print(coefs) 279 | save_profile(coefs) 280 | save_profile(coefs, name + '.txt') 281 | -------------------------------------------------------------------------------- /postproc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from datautils import * 7 | from database import * 8 | from modelutils import * 9 | from quant import * 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('model', type=str) 15 | parser.add_argument('dataset', type=str) 16 | parser.add_argument('load', type=str) 17 | parser.add_argument('--database', choices=['', 'mixed', '4block', 'unstr', '4block_8w8a'], default='') 18 | parser.add_argument('--prefix', type=str, default='') 19 | 20 | parser.add_argument('--datapath', type=str, default='') 21 | parser.add_argument('--seed', type=int, default=0) 22 | 23 | parser.add_argument('--nsamples', type=int, default=1024) 24 | parser.add_argument('--batchsize', type=int, default=-1) 25 | parser.add_argument('--workers', type=int, default=8) 26 | parser.add_argument('--nrounds', type=int, default=-1) 27 | parser.add_argument('--noaug', action='store_true') 28 | 29 | parser.add_argument('--skip-firstlast', action='store_true') 30 | 31 | parser.add_argument('--bnt', action='store_true') 32 | parser.add_argument('--bnt-batches', type=int, default=100) 33 | parser.add_argument('--lintune', action='store_true') 34 | parser.add_argument('--lintune-loss', action='store_true') 35 | parser.add_argument('--lintune-epochs', type=int, default=100) 36 | parser.add_argument('--lintune-lr', type=float, default=1e-4) 37 | parser.add_argument('--gap', action='store_true') 38 | parser.add_argument('--gap-epochs', type=int, default=100) 39 | parser.add_argument('--gap-lr', type=float, default=1e-5) 40 | parser.add_argument('--finetune', action='store_true') 41 | parser.add_argument('--finetune-mse', action='store_true') 42 | 
parser.add_argument('--finetune-epochs', type=int, default=2) 43 | parser.add_argument('--finetune-lr', type=float, default=1e-5) 44 | parser.add_argument('--statcorr', action='store_true') 45 | parser.add_argument('--statcorr-samples', type=int, default=-1) 46 | parser.add_argument('--save', type=str) 47 | 48 | 49 | args = parser.parse_args() 50 | 51 | dataloader, testloader = get_loaders( 52 | args.dataset, path=args.datapath, 53 | batchsize=args.batchsize, workers=args.workers, 54 | nsamples=args.nsamples, seed=args.seed, 55 | noaug=args.noaug 56 | ) 57 | get_model, test, run = get_functions(args.model) 58 | 59 | modelp = get_model() 60 | if args.load.endswith('.pth'): 61 | tmp = torch.load(args.load) 62 | if any('scale' in k for k in tmp): 63 | add_actquant(modelp) 64 | if args.skip_firstlast: 65 | for l in firstlast_names(args.model): 66 | if any('scale' in k for k in tmp): 67 | tmp[l + '.quantizer.scale'][:] = 0 68 | l += '.module' 69 | l += '.weight' 70 | tmp[l] = modelp.state_dict()[l] 71 | modelp.load_state_dict(tmp) 72 | modelp = modelp.to(DEV) 73 | 74 | if args.database != '': 75 | if args.database == 'mixed': 76 | print('Stitching ...') 77 | db = QuantNMDatabase(args.model, prefix=args.prefix) 78 | if args.database in ['4block', 'unstr', '4block_8w8a']: 79 | db = SparsityDatabase(args.database, args.model, prefix=args.prefix, dev='cpu') 80 | if args.database in ['mixed', '4block_8w8a']: 81 | add_actquant(modelp) 82 | modelp = modelp.to('cpu') 83 | layersp = find_layers(modelp) 84 | with open(args.load, 'r') as f: 85 | config = {} 86 | for l in f.readlines(): 87 | level, name = l.strip().split(' ') 88 | config[name] = level 89 | db.stitch(layersp, config) 90 | modelp = modelp.to(DEV) 91 | layersp = find_layers(modelp) 92 | if args.save: 93 | torch.save(modelp.state_dict(), args.save) 94 | exit() 95 | 96 | 97 | if args.bnt: 98 | print('Batchnorm tuning ...') 99 | 100 | loss = 0 101 | for batch in dataloader: 102 | loss += run(modelp, batch, loss=True) 103 | print(loss / args.nsamples) 104 | 105 | batchnorms = find_layers(modelp, [nn.BatchNorm2d]) 106 | for bn in batchnorms.values(): 107 | bn.reset_running_stats() 108 | bn.momentum = .1 109 | modelp.train() 110 | with torch.no_grad(): 111 | i = 0 112 | while i < args.bnt_batches: 113 | for batch in dataloader: 114 | if i == args.bnt_batches: 115 | break 116 | print('%03d' % i) 117 | run(modelp, batch) 118 | i += 1 119 | modelp.eval() 120 | 121 | loss = 0 122 | for batch in dataloader: 123 | loss += run(modelp, batch, loss=True) 124 | print(loss / args.nsamples) 125 | 126 | if args.lintune: 127 | print('Linear tuning ...') 128 | modeld = get_model() 129 | params = [] 130 | for n, p in modelp.named_parameters(): 131 | if len(p.shape) == 1: 132 | params.append(p) 133 | else: 134 | p.requires_grad = False 135 | optim = torch.optim.Adam(params, lr=args.lintune_lr) 136 | criterion = nn.MSELoss() 137 | for i in range(args.lintune_epochs): 138 | cumloss = 0 139 | for batch in dataloader: 140 | if args.lintune_loss: 141 | loss = run(modelp, batch, loss=True) 142 | else: 143 | with torch.no_grad(): 144 | y = run(modeld, batch) 145 | loss = criterion(run(modelp, batch), y) 146 | cumloss += loss.item() 147 | loss.backward() 148 | optim.step() 149 | optim.zero_grad() 150 | print('%02d %.4f' % (i, cumloss / len(dataloader))) 151 | 152 | if args.gap: 153 | modeld = get_model() 154 | layersp = find_layers(modelp) 155 | layersd = find_layers(modeld) 156 | 157 | masks = {n: l.weight.data == 0 for n, l in layersp.items()} 158 | 159 | def 
cache_output(name, outputs): 160 | def tmp(layer, inp, out): 161 | outputs[name] = out 162 | return tmp 163 | outputsp = {} 164 | handlesp = [] 165 | for name in layersp: 166 | handlesp.append( 167 | layersp[name].register_forward_hook(cache_output(name, outputsp)) 168 | ) 169 | outputsd = {} 170 | handlesd = [] 171 | for name in layersd: 172 | handlesd.append( 173 | layersd[name].register_forward_hook(cache_output(name, outputsd)) 174 | ) 175 | 176 | criterion = nn.MSELoss(reduction='sum') 177 | optim = torch.optim.Adam(modelp.parameters(), lr=args.gap_lr) 178 | 179 | for i in range(args.gap_epochs): 180 | cumloss = 0 181 | for batch in dataloader: 182 | with torch.no_grad(): 183 | run(modeld, batch) 184 | run(modelp, batch) 185 | loss = 0 186 | for name in outputsd: 187 | norm = torch.norm(outputsd[name].data).item() ** 2 188 | loss += criterion(outputsp[name], outputsd[name].data) / norm 189 | cumloss += loss.item() 190 | loss.backward() 191 | optim.step() 192 | optim.zero_grad() 193 | for name, mask in masks.items(): 194 | layersp[name].weight.data[mask] = 0 195 | print('%05d: %.6f' % (i, cumloss / len(dataloader))) 196 | 197 | for h in handlesp: 198 | h.remove() 199 | for h in handlesd: 200 | h.remove() 201 | 202 | if args.finetune: 203 | print('Finetuning ...') 204 | modeld = get_model() 205 | masks = {n: p == 0 for n, p in modelp.named_parameters()} 206 | optim = torch.optim.Adam(modelp.parameters(), lr=args.finetune_lr) 207 | criterion = nn.MSELoss() 208 | for i in range(args.finetune_epochs): 209 | cumloss = 0 210 | for batch in dataloader: 211 | if args.finetune_mse: 212 | with torch.no_grad(): 213 | y = run(modeld, batch) 214 | loss = criterion(run(modelp, batch), y) 215 | else: 216 | loss = run(modelp, batch, loss=True) 217 | cumloss += loss.item() 218 | loss.backward() 219 | optim.step() 220 | optim.zero_grad() 221 | for n, p in modelp.named_parameters(): 222 | p.data[masks[n]] = 0 223 | print('%02d %.4f' % (i, cumloss / len(dataloader))) 224 | 225 | if args.statcorr: 226 | print('Stat correction ...') 227 | 228 | if args.statcorr_samples == -1: 229 | args.statcorr_samples = args.nsamples 230 | trainloader, testloader = get_loaders( 231 | args.dataset, batchsize=args.statcorr_samples, noaug=True 232 | ) 233 | batch = next(iter(trainloader)) 234 | 235 | modeld = get_model() 236 | layersd = find_layers(modeld, layers=[nn.BatchNorm2d, nn.LayerNorm]) 237 | layersp = find_layers(modelp, layers=[nn.BatchNorm2d, nn.LayerNorm]) 238 | 239 | REDUCE = { 240 | 2: [0], 241 | 3: [0, 1], 242 | 4: [0, 2, 3] 243 | } 244 | 245 | meansd = {} 246 | stdsd = {} 247 | def hookd(name): 248 | def tmp(layer, inp, out): 249 | red = REDUCE[len(out.shape)] 250 | meansd[name] = torch.mean(out.data, red, keepdim=True) 251 | stdsd[name] = torch.std(out.data, red, keepdim=True) 252 | return tmp 253 | meansp = {} 254 | stdsp = {} 255 | def hookp(name): 256 | def tmp(layer, inp, out): 257 | red = REDUCE[len(out.shape)] 258 | meansp[name] = torch.mean(out.data, red, keepdim=True) 259 | stdsp[name] = torch.std(out.data, red, keepdim=True) 260 | out.data -= meansp[name] 261 | out.data *= stdsd[name] / (stdsp[name] + 1e-9) 262 | out.data += meansd[name] 263 | return tmp 264 | handles = [] 265 | for name in layersd: 266 | handles.append(layersd[name].register_forward_hook(hookd(name))) 267 | with torch.no_grad(): 268 | run(modeld, batch) 269 | for h in handles: 270 | h.remove() 271 | handles = [] 272 | for name in layersp: 273 | handles.append(layersp[name].register_forward_hook(hookp(name))) 274 | with 
torch.no_grad(): 275 | run(modelp, batch) 276 | for h in handles: 277 | h.remove() 278 | 279 | def hook(name): 280 | def tmp(layer, inp, out): 281 | out.data -= meansp[name] 282 | out.data *= stdsd[name] / (stdsp[name] + 1e-9) 283 | out.data += meansd[name] 284 | return tmp 285 | for name in layersp: 286 | layersp[name].register_forward_hook(hook(name)) 287 | 288 | test(modelp, testloader) 289 | -------------------------------------------------------------------------------- /database.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from datautils import * 5 | from modelutils import * 6 | from quant import * 7 | 8 | 9 | def get_flops(layers, model, sample, run): 10 | flops = {} 11 | def record_flops(name): 12 | def tmp(layer, inp, out): 13 | inp = inp[0] 14 | if isinstance(layer, nn.Conv2d): 15 | flops[name] = inp.shape[2] * inp.shape[3] 16 | flops[name] *= layer.weight.numel() 17 | stride = list(layer.stride) 18 | flops[name] //= stride[0] * stride[1] 19 | if isinstance(layer, nn.Linear): 20 | flops[name] = layer.weight.numel() 21 | return tmp 22 | handles = [] 23 | for name, layer in layers.items(): 24 | if hasattr(layer, 'module'): 25 | layer.module.register_forward_hook(record_flops(name)) 26 | else: 27 | layer.register_forward_hook(record_flops(name)) 28 | with torch.no_grad(): 29 | run(model, sample) 30 | for h in handles: 31 | h.remove() 32 | return flops 33 | 34 | def load_errors(sds, path, norm=False): 35 | errors = {} 36 | with open(path, 'r') as f: 37 | lines = f.readlines() 38 | i = 0 39 | while i < len(lines): 40 | name = lines[i].strip() 41 | errors[name] = {} 42 | i += 1 43 | for _ in range(len(sds)): 44 | err, level = lines[i].strip().split(' ') 45 | errors[name][level] = float(err) 46 | i += 1 47 | if norm: 48 | for name in errors: 49 | norm = max(errors[name].values()) 50 | if norm > 0: 51 | for level in errors[name]: 52 | errors[name][level] /= norm 53 | return errors 54 | 55 | 56 | class SparsityDatabase: 57 | 58 | def __init__(self, sparsetype, model, prefix='', dev=DEV): 59 | self.sds = {} 60 | path = os.path.join(prefix, 'models_' + sparsetype) 61 | for f in os.listdir(path): 62 | if not (f.startswith(model + '_') and f.endswith('.pth')): 63 | continue 64 | sparsity = '0.' 
+ f.split('.')[0].split('_')[1] 65 | self.sds[sparsity] = torch.load(os.path.join(path, f), map_location=dev) 66 | self.sparsetype = sparsetype 67 | self.model = model 68 | self.prefix = prefix 69 | 70 | def load(self, layers, name, config='', sd=None): 71 | if not sd: 72 | sd = self.sds[config] 73 | if '8w8a' in self.sparsetype: 74 | layers[name].module.weight.data = sd[name + '.module.weight'] 75 | layers[name].quantizer.maxq.data = sd[name + '.quantizer.maxq'] 76 | layers[name].quantizer.scale.data = sd[name + '.quantizer.scale'] 77 | layers[name].quantizer.zero.data = sd[name + '.quantizer.zero'] 78 | else: 79 | layers[name].weight.data = sd[name + '.weight'] 80 | 81 | def stitch(self, layers, config): 82 | for name, layer in layers.items(): 83 | self.load(layers, name, config[name]) 84 | 85 | def load_errors(self, name): 86 | path = os.path.join( 87 | self.prefix, 'scores/%s_%s_%s.txt' % (self.model, self.sparsetype, name) 88 | ) 89 | return load_errors(self.sds, path, norm=name == 'squared') 90 | 91 | def get_params(self, layers): 92 | res = {} 93 | for name in layers: 94 | res[name] = {} 95 | for sparsity in self.sds: 96 | res[name][sparsity] = torch.sum( 97 | (self.sds[sparsity][name + '.weight'] != 0).float() 98 | ).item() 99 | return res 100 | 101 | def get_flops(self, layers, model, sample, run): 102 | flops = get_flops(layers, model, sample, run) 103 | res = {} 104 | for name in layers: 105 | res[name] = {} 106 | for sparsity in self.sds: 107 | res[name][sparsity] = flops[name] * torch.mean( 108 | (self.sds[sparsity][name + '.weight'] != 0).float() 109 | ).item() 110 | return res 111 | 112 | def get_timingsq(self): 113 | timings = {} 114 | with open('timings/%sq.txt' % self.model, 'r') as f: 115 | lines = f.readlines() 116 | baselinetime = float(lines[0]) 117 | i = 1 118 | while i < len(lines): 119 | name = lines[i].strip() 120 | timings[name] = {} 121 | i += 1 122 | for _ in range(len(self.sds)): 123 | time, level = lines[i].strip().split(' ') 124 | timings[name][level] = float(time) 125 | i += 1 126 | return baselinetime, timings 127 | 128 | 129 | class QuantNMDatabase: 130 | 131 | def __init__(self, model, prefix=''): 132 | self.sds = {} 133 | for path in ['models_quant', 'models_nm_quant']: 134 | for f in os.listdir(os.path.join(prefix, path)): 135 | if not (f.startswith(model + '_') and f.endswith('.pth')): 136 | continue 137 | config = '_'.join(f.split('.')[0].split('_')[1:]) 138 | self.sds[config] = torch.load(os.path.join(prefix, path, f), map_location=DEV) 139 | self.model = model 140 | self.prefix = prefix 141 | 142 | def load(self, layers, name, config='', sd=None): 143 | if not sd: 144 | sd = self.sds[config] 145 | layers[name].module.weight.data = sd[name + '.module.weight'] 146 | layers[name].quantizer.maxq.data = sd[name + '.quantizer.maxq'] 147 | layers[name].quantizer.scale.data = sd[name + '.quantizer.scale'] 148 | layers[name].quantizer.zero.data = sd[name + '.quantizer.zero'] 149 | 150 | def stitch(self, layers, config): 151 | for name, layer in layers.items(): 152 | self.load(layers, name, config[name]) 153 | 154 | def load_errors(self, name): 155 | path = os.path.join(self.prefix, 'scores/%s_mixed_%s.txt' % (self.model, name)) 156 | return load_errors(self.sds, path, norm=name == 'squared') 157 | 158 | def get_bits(self, layers): 159 | res = {} 160 | for name, layer in layers.items(): 161 | paramcount = layer.module.weight.numel() 162 | res[name] = { 163 | # '24_4w4a': paramcount * 5, 164 | # '24_8w8a': paramcount * 9, 165 | '24_4w4a': paramcount * 4, 166 
| '24_8w8a': paramcount * 8, 167 | '4w4a': paramcount * 4, 168 | '8w8a': paramcount * 8 169 | } 170 | return res 171 | 172 | def get_bops(self, layers, model, sample, run): 173 | flops = get_flops(layers, model, sample, run) 174 | res = {} 175 | for name, layer in layers.items(): 176 | res[name] = { 177 | '24_4w4a': flops[name] * 32 // 2 // 8, 178 | '24_8w8a': flops[name] * 32 // 2 // 4, 179 | '4w4a': flops[name] * 32 // 8, 180 | '8w8a': flops[name] * 32 // 4 181 | } 182 | if (layers[name].module.weight.numel() // layers[name].module.weight.shape[0]) % 4 != 0: 183 | res[name]['24_4w4a'] *= 2 184 | res[name]['24_8w8a'] *= 2 185 | return res 186 | 187 | 188 | if __name__ == '__main__': 189 | import argparse 190 | 191 | parser = argparse.ArgumentParser() 192 | 193 | parser.add_argument('model', type=str) 194 | parser.add_argument('dataset', type=str) 195 | parser.add_argument('database', choices=['mixed', '4block', 'unstr', '4block_8w8a']) 196 | parser.add_argument('mode', choices=['loss', 'squared', 'spdy', 'stitch', 'eval']) 197 | parser.add_argument('--prefix', type=str, default='') 198 | parser.add_argument('--profile', type=str, default='') 199 | parser.add_argument('--score_path', type=str, default='scores') 200 | 201 | parser.add_argument('--datapath', type=str, default='') 202 | parser.add_argument('--seed', type=int, default=0) 203 | parser.add_argument('--nsamples', type=int, default=1024) 204 | parser.add_argument('--batchsize', type=int, default=-1) 205 | parser.add_argument('--workers', type=int, default=8) 206 | parser.add_argument('--nrounds', type=int, default=-1) 207 | 208 | args = parser.parse_args() 209 | 210 | get_model, test, run = get_functions(args.model) 211 | dataloader, testloader = get_loaders( 212 | args.dataset, path=args.datapath, 213 | batchsize=args.batchsize, workers=args.workers, 214 | nsamples=args.nsamples, seed=args.seed, 215 | noaug=args.mode == 'loss' 216 | ) 217 | if args.nrounds == -1: 218 | args.nrounds = 1 if 'yolo' in args.model or 'bert' in args.model else 10 219 | if args.mode == 'loss': 220 | args.nrounds = 1 221 | 222 | filepath = os.path.join(args.prefix, args.score_path, '%s_%s_%s.txt' % (args.model, args.database, args.mode)) 223 | 224 | modelp = get_model() 225 | if args.database == 'mixed': 226 | db = QuantNMDatabase(args.model, prefix=args.prefix) 227 | if args.database in ['4block', 'unstr', '4block_8w8a']: 228 | db = SparsityDatabase(args.database, args.model, prefix=args.prefix) 229 | if args.database in ['mixed', '4block_8w8a']: 230 | add_actquant(modelp) 231 | layersp = find_layers(modelp) 232 | 233 | for i in range(layersp['fc'].weight.shape[0]): 234 | print(i) 235 | W = layersp['fc'].weight.data 236 | thresh = torch.sort(torch.abs(W[i, :]), descending=True)[0][9] 237 | W[i, torch.abs(W[i, :]) < thresh] = 0 238 | print(torch.mean((W[i, :] == 0).float())) 239 | test(modelp, testloader) 240 | exit() 241 | 242 | config = {n: '0.0000' for n in layersp} 243 | config['fc'] = '0.9797' # '0.9900' 244 | db.stitch(layersp, config) 245 | with torch.no_grad(): 246 | print(run(modelp, next(iter(dataloader)), loss=True) / args.nsamples) 247 | test(modelp, testloader) 248 | exit() 249 | 250 | if args.mode == 'stitch': 251 | with open(args.profile, 'r') as f: 252 | config = {} 253 | for l in f.readlines(): 254 | level, name = l.strip().split(' ') 255 | config[name] = '24_8w8a' # level 256 | db.stitch(layersp, config) 257 | test(modelp, testloader) 258 | exit() 259 | 260 | if args.mode == 'eval': 261 | for s in sorted(db.sds): 262 | 
db.stitch(layersp, {n: s for n in layersp}) 263 | print(s) 264 | test(modelp, testloader) 265 | exit() 266 | 267 | if args.mode == 'spdy': 268 | layersp = find_layers(modelp) 269 | tmp = (np.arange(len(db.sds)) / (len(db.sds) - 1)) ** 2 270 | print(len(db.sds)) 271 | print(len(tmp)) 272 | with open(filepath, 'w') as f: 273 | for layer in layersp: 274 | print(layer) 275 | f.write(layer + '\n') 276 | for i, name in enumerate(sorted(db.sds)): 277 | f.write('%.6f %s\n' % (tmp[i], name)) 278 | exit() 279 | 280 | if args.mode == 'squared': 281 | modeld = get_model() 282 | layersd = find_layers(modeld) 283 | 284 | errs = {n: {} for n in layersp} 285 | def accumerrs(name): 286 | def tmp(layer, inp, out): 287 | errs[name]['dense'] = errs[name].get('dense', 0) + torch.sum(out.data ** 2).item() 288 | for config in sorted(db.sds): 289 | db.load(layersp, name, config) 290 | errs[name][config] = errs[name].get(config, 0) + torch.sum((layersp[name](inp[0].data) - out.data) ** 2).item() 291 | return tmp 292 | for name in layersd: 293 | layersd[name].register_forward_hook(accumerrs(name)) 294 | 295 | with torch.no_grad(): 296 | for _ in range(args.nrounds): 297 | for i, batch in enumerate(dataloader): 298 | print(i) 299 | run(modeld, batch) 300 | 301 | with open(filepath, 'w') as f: 302 | for name in errs: 303 | f.write(name + '\n') 304 | for config in sorted(errs[name]): 305 | if config != 'dense': 306 | f.write('%.6f %s\n' % (errs[name][config] / errs[name]['dense'], config)) 307 | exit() 308 | 309 | if args.mode == 'loss': 310 | sd = modelp.state_dict() 311 | errs = {n: {} for n in layersp} 312 | baseloss = 0 313 | 314 | for _ in range(args.nrounds): 315 | for i, batch in enumerate(dataloader): 316 | print(i) 317 | with torch.no_grad(): 318 | baseloss += run(modelp, batch, loss=True) 319 | for name in layersp: 320 | print(name) 321 | for config in sorted(db.sds): 322 | db.load(layersp, name, config) 323 | errs[name][config] = errs[name].get(config, 0) + run(modelp, batch, loss=True) 324 | db.load(layersp, name, sd=sd) 325 | baseloss /= len(dataloader) * args.nrounds 326 | for name in errs: 327 | for config in errs[name]: 328 | errs[name][config] /= len(dataloader) * args.nrounds 329 | 330 | with open(filepath, 'w') as f: 331 | for name in errs: 332 | f.write(name + '\n') 333 | for config in sorted(errs[name]): 334 | f.write('%+.6f %s\n' % (errs[name][config] - baseloss, config)) 335 | exit() 336 | -------------------------------------------------------------------------------- /trueobs.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from quant import * 8 | 9 | 10 | torch.backends.cuda.matmul.allow_tf32 = False 11 | torch.backends.cudnn.allow_tf32 = False 12 | 13 | DEBUG = False 14 | 15 | 16 | class TrueOBS: 17 | 18 | def __init__(self, layer, rel_damp=0): 19 | self.layer = layer 20 | self.dev = self.layer.weight.device 21 | W = layer.weight.data.clone() 22 | if isinstance(self.layer, nn.Conv2d): 23 | W = W.flatten(1) 24 | self.rows = W.shape[0] 25 | self.columns = W.shape[1] 26 | # Accumulate in double precision 27 | self.H = torch.zeros((self.columns, self.columns), device=self.dev, dtype=torch.double) 28 | self.nsamples = 0 29 | self.rel_damp = rel_damp 30 | 31 | def add_batch(self, inp, out): 32 | if DEBUG: 33 | self.inp1 = inp 34 | self.out1 = out 35 | tmp = inp.shape[0] 36 | if isinstance(self.layer, nn.Linear): 37 | if len(inp.shape) == 3: 38 | inp = inp.reshape((-1, 
inp.shape[-1])) 39 | inp = inp.t() 40 | if isinstance(self.layer, nn.Conv2d): 41 | unfold = nn.Unfold( 42 | self.layer.kernel_size, 43 | dilation=self.layer.dilation, 44 | padding=self.layer.padding, 45 | stride=self.layer.stride 46 | ) 47 | inp = unfold(inp) 48 | inp = inp.permute([1, 0, 2]) 49 | inp = inp.flatten(1) 50 | self.H *= self.nsamples / (self.nsamples + tmp) 51 | self.nsamples += tmp 52 | self.H += 2 / self.nsamples * (inp.matmul(inp.t())).double() 53 | 54 | def invert(self, H): 55 | try: 56 | Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) 57 | except RuntimeError: 58 | print('Hessian not full rank.') 59 | tmp = 1 * torch.eye(self.columns, device=self.dev) 60 | Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H + tmp)) 61 | return Hinv 62 | 63 | def prepare(self, columnslast=False): 64 | if columnslast: 65 | perm = torch.arange(self.columns, device=self.dev) 66 | if len(self.layer.weight.shape) == 4: 67 | perm = perm.reshape(list(self.layer.weight.shape)[1:]) 68 | perm = perm.permute([1, 2, 0]) 69 | perm = perm.flatten() 70 | W = self.layer.weight.data.clone() 71 | if isinstance(self.layer, nn.Conv2d): 72 | W = W.flatten(1) 73 | H = self.H.float() 74 | if self.rel_damp > 0: 75 | damp = self.rel_damp * torch.diag(H).mean() 76 | H += damp * torch.eye(H.shape[0], device=self.dev) 77 | dead = torch.diag(H) == 0 78 | H[dead, dead] = 1 79 | W[:, dead] = 0 80 | if columnslast: 81 | H = H[perm, :][:, perm] 82 | W = W[:, perm] 83 | Hinv = self.invert(H) 84 | Losses = torch.zeros([self.rows, self.columns + 1], device=self.dev) 85 | if columnslast: 86 | return W, H, Hinv, Losses, perm 87 | return W, H, Hinv, Losses 88 | 89 | def prepare_iter(self, i1, parallel, W, Hinv1): 90 | i2 = min(i1 + parallel, self.rows) 91 | count = i2 - i1 92 | w = W[i1:i2, :] 93 | Hinv = Hinv1.unsqueeze(0).repeat((count, 1, 1)) 94 | mask = torch.zeros_like(w).bool() 95 | rangecount = torch.arange(count, device=self.dev) 96 | idxcount = rangecount + i1 97 | return i2, count, w, Hinv, mask, rangecount, idxcount 98 | 99 | def prepare_sparse(self, w, mask, Hinv, H): 100 | start = int(torch.min(torch.sum((w == 0).float(), 1)).item()) + 1 101 | for i in range(w.shape[0]): 102 | tmp = w[i] == 0 103 | H1 = H.clone() 104 | H1[tmp, :] = 0 105 | H1[:, tmp] = 0 106 | H1[tmp, tmp] = 1 107 | Hinv[i] = self.invert(H1) 108 | mask[i, torch.nonzero(tmp, as_tuple=True)[0][:(start - 1)]] = True 109 | return start 110 | 111 | def quantize(self, parallel=32): 112 | W, H, Hinv1, Losses = self.prepare() 113 | 114 | Q = torch.zeros_like(W) 115 | self.quantizer.find_params(W, weight=True) 116 | 117 | for i1 in range(0, self.rows, parallel): 118 | i2, count, w, Hinv, mask, rangecount, idxcount = self.prepare_iter(i1, parallel, W, Hinv1) 119 | start = self.prepare_sparse(w, mask, Hinv, H) 120 | 121 | outlier = .25 * (self.quantizer.scale ** 2)[i1:i2, :] 122 | scale = self.quantizer.scale[i1:i2, :] 123 | zero = self.quantizer.zero[i1:i2, :] 124 | 125 | tick = time.time() 126 | 127 | for quant in range(start, self.columns + 1): 128 | q = quantize(w, scale, zero, self.quantizer.maxq) 129 | err = (w - q) ** 2 130 | diag = torch.diagonal(Hinv, dim1=1, dim2=2) 131 | scores = err / diag 132 | scores[mask] = float('inf') 133 | err[mask] = 0 134 | j = torch.argmin(scores, 1) 135 | sel = torch.any(err > outlier, 1) 136 | sel &= w[rangecount, j] != 0 137 | if torch.any(sel): 138 | j[sel] = torch.argmax(err[sel, :], 1) 139 | Losses[i1:i2, quant] = scores[rangecount, j] 140 | q1 = q[rangecount, j] 141 | Q[idxcount, j] = q1 142 | row = 
Hinv[rangecount, j, :] 143 | d = diag[rangecount, j] 144 | w -= row * ((w[rangecount, j] - q1) / d).unsqueeze(1) 145 | mask[rangecount, j] = True 146 | if quant == self.columns: 147 | break 148 | row /= torch.sqrt(d).unsqueeze(1) 149 | Hinv -= torch.bmm(row.unsqueeze(2), row.unsqueeze(1)) 150 | Losses[i1:i2, :] /= 2 151 | 152 | torch.cuda.synchronize() 153 | print('%04d %04d time %.2f' % (i1, i2, time.time() - tick)) 154 | 155 | print('error', torch.sum(Losses).item()) 156 | self.layer.weight.data = Q.reshape(self.layer.weight.shape) 157 | if DEBUG: 158 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2) / 128) 159 | 160 | def nmprune(self, n=2, m=4, parallel=32): 161 | W, H, Hinv1, Losses, perm = self.prepare(columnslast=True) 162 | 163 | for i1 in range(0, self.rows, parallel): 164 | i2, count, w, Hinv, mask, rangecount, idxcount = self.prepare_iter(i1, parallel, W, Hinv1) 165 | 166 | buckets = torch.zeros((count, self.columns // m, 1), device=self.dev) 167 | 168 | tick = time.time() 169 | 170 | for zeros in range(1, self.columns + 1): 171 | diag = torch.diagonal(Hinv, dim1=1, dim2=2) 172 | scores = w ** 2 / diag 173 | tmp = (buckets >= n).repeat((1, 1, m)).flatten(1) 174 | scores[mask | tmp] = float('inf') 175 | j = torch.argmin(scores, 1) 176 | Losses[i1:i2, zeros] = scores[rangecount, j] 177 | row = Hinv[rangecount, j, :] 178 | d = diag[rangecount, j] 179 | w -= row * (w[rangecount, j] / d).unsqueeze(1) 180 | mask[rangecount, j] = True 181 | buckets[rangecount, torch.div(j, m, rounding_mode='floor'), :] += 1 182 | if zeros == self.columns * n / m: 183 | break 184 | row /= torch.sqrt(d).unsqueeze(1) 185 | Hinv -= torch.bmm(row.unsqueeze(2), row.unsqueeze(1)) 186 | Losses[i1:i2, :] /= 2 187 | w[mask] = 0 188 | W[i1:i2, :] = w 189 | 190 | torch.cuda.synchronize() 191 | print('%04d %04d time %.2f' % (i1, i2, time.time() - tick)) 192 | 193 | print('error', torch.sum(Losses).item()) 194 | W = W[:, torch.argsort(perm)] 195 | self.layer.weight.data = W.reshape(self.layer.weight.shape) 196 | if DEBUG: 197 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2) / 128) 198 | 199 | def prepare_unstr(self, parallel=32): 200 | W, H, Hinv1, Losses = self.prepare() 201 | 202 | self.Losses = Losses 203 | self.Traces = [] 204 | 205 | for i1 in range(0, self.rows, parallel): 206 | i2, count, w, Hinv, mask, rangecount, idxcount = self.prepare_iter(i1, parallel, W, Hinv1) 207 | start = self.prepare_sparse(w, mask, Hinv, H) 208 | 209 | Trace = torch.zeros((self.columns + 1, count, self.columns), device=self.dev) 210 | Trace[0, :, :] = w 211 | Trace[:start, :, :] = w 212 | 213 | tick = time.time() 214 | 215 | for zeros in range(start, self.columns + 1): 216 | diag = torch.diagonal(Hinv, dim1=1, dim2=2) 217 | scores = (w ** 2) / diag 218 | scores[mask] = float('inf') 219 | j = torch.argmin(scores, 1) 220 | self.Losses[i1:i2, zeros] = scores[rangecount, j] 221 | row = Hinv[rangecount, j, :] 222 | d = diag[rangecount, j] 223 | w -= row * (w[rangecount, j] / d).unsqueeze(1) 224 | mask[rangecount, j] = True 225 | w[mask] = 0 226 | Trace[zeros, :, :] = w 227 | if zeros == self.columns: 228 | break 229 | row /= torch.sqrt(d).unsqueeze(1) 230 | Hinv -= torch.bmm(row.unsqueeze(2), row.unsqueeze(1)) 231 | self.Losses[i1:i2, :] /= 2 232 | self.Traces.append(Trace.cpu()) 233 | 234 | torch.cuda.synchronize() 235 | print('%04d %04d time %.2f' % (i1, i2, time.time() - tick)) 236 | 237 | def prune_unstr(self, sparsities): 238 | return self.prune_blocked(sparsities) 239 | 240 | def prepare_blocked(self, size=4, 
parallel=32): 241 | W, H, Hinv1, Losses, perm = self.prepare(columnslast=True) 242 | 243 | self.Traces = [] 244 | blockcount = self.columns // size 245 | self.Losses = torch.zeros((self.rows, blockcount + 1), device=self.dev) 246 | rangeblockcount = torch.arange(blockcount, device=self.dev) 247 | rangecolumns = torch.arange(self.columns, device=self.dev) 248 | 249 | for i1 in range(0, self.rows, parallel): 250 | i2, count, w, Hinv, _, rangecount, _ = self.prepare_iter(i1, parallel, W, Hinv1) 251 | 252 | mask = torch.zeros((count, blockcount), device=self.dev).bool() 253 | mask1 = torch.zeros((count, blockcount, size), device=self.dev).bool() 254 | Trace = torch.zeros((blockcount + 1, count, self.columns), device=self.dev) 255 | Trace[0, :, :] = w 256 | rangeblockunroll = torch.arange(count * blockcount, device=self.dev) 257 | blockdiagidx = rangeblockcount.repeat(count) 258 | rangeunroll = torch.arange(self.columns * count, device=self.dev) 259 | diagidx = rangecolumns.repeat(count) 260 | paroffset = blockcount * rangecount 261 | expandrows = torch.arange(size, device=self.dev).unsqueeze(0).repeat(count, 1) 262 | expandrows += self.columns * rangecount.unsqueeze(1) 263 | 264 | tick = time.time() 265 | 266 | for dropped in range(1, blockcount + 1): 267 | blocks = Hinv.reshape(count * blockcount, size, blockcount, size) 268 | blocks = blocks[rangeblockunroll, :, blockdiagidx, :] 269 | invblocks = torch.cholesky_inverse(torch.linalg.cholesky(blocks)) 270 | w1 = w.reshape((count * blockcount, 1, size)) 271 | lambd = torch.bmm(w1, invblocks) 272 | scores = torch.sum(lambd * w1, (1, 2)) 273 | scores = scores.reshape((count, blockcount)) 274 | scores[mask] = float('inf') 275 | j = torch.argmin(scores, 1) 276 | self.Losses[i1:i2, dropped] = scores[rangecount, j] 277 | 278 | tmp = (expandrows + size * j.unsqueeze(1)).flatten() 279 | rows = Hinv.reshape((-1, self.columns))[tmp] 280 | rows = rows.reshape((count, size, self.columns)) 281 | tmp = paroffset + j 282 | d = invblocks[tmp] 283 | 284 | w -= torch.bmm(lambd[tmp], rows).squeeze(1) 285 | mask[rangecount, j] = True 286 | mask1[mask] = True 287 | tmp = mask1.flatten(1) 288 | w[mask1.flatten(1)] = 0 289 | Trace[dropped, :, :] = w 290 | 291 | if dropped == self.columns: 292 | break 293 | Hinv -= torch.bmm(rows.transpose(1, 2), torch.bmm(d, rows)) 294 | Hinv = Hinv.reshape((count * self.columns, self.columns)) 295 | tmp = mask1.flatten() 296 | Hinv[rangeunroll[tmp], diagidx[tmp]] = 1 297 | Hinv = Hinv.reshape((count, self.columns, self.columns)) 298 | self.Losses[i1:i2, :] /= 2 299 | Trace = Trace[:, :, torch.argsort(perm)] 300 | self.Traces.append(Trace.cpu()) 301 | 302 | torch.cuda.synchronize() 303 | print('%04d %04d time %.2f' % (i1, i2, time.time() - tick)) 304 | 305 | def prune_blocked(self, sparsities): 306 | parallel = self.Traces[0].shape[1] 307 | blockcount = self.Traces[0].shape[0] - 1 308 | losses = self.Losses[:, 1:].reshape(-1) 309 | order = torch.argsort(losses) 310 | Ws = [torch.zeros((self.rows, self.columns), device=self.dev) for _ in sparsities] 311 | losses = [0] * len(sparsities) 312 | for i in range(self.rows): 313 | if i % parallel == 0: 314 | Trace = self.Traces[i // parallel].to(self.dev) 315 | for j, sparsity in enumerate(sparsities): 316 | count = int(math.ceil(self.rows * blockcount * sparsity)) 317 | perrow = torch.sum( 318 | torch.div(order[:count], blockcount, rounding_mode='trunc') == i 319 | ).item() 320 | losses[j] += torch.sum(self.Losses[i, :(perrow + 1)]).item() 321 | Ws[j][i, :] = Trace[perrow, i % parallel, :] 
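        # Report the total pruning loss accumulated at each requested sparsity level;
        # the number of blocks dropped per row was fixed by the global ordering of the
        # per-block OBS losses above, and the matching intermediate weights were read
        # back from Trace. With DEBUG enabled, each stitched weight matrix is also
        # checked against the cached calibration batch.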
322 | for sparsity, loss in zip(sparsities, losses): 323 | print('%.4f error' % sparsity, loss) 324 | if DEBUG: 325 | tmp = self.layer.weight.data.clone() 326 | self.layer.weight.data = Ws[sparsities.index(sparsity)].reshape(self.layer.weight.shape) 327 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2) / 128) 328 | self.layer.weight.data = tmp 329 | return Ws 330 | 331 | def free(self): 332 | if DEBUG: 333 | self.inp1 = None 334 | self.out1 = None 335 | self.H = None 336 | self.Losses = None 337 | self.Trace = None 338 | torch.cuda.empty_cache() 339 | -------------------------------------------------------------------------------- /bertsquad/utils_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Post-processing utilities for question answering. 17 | """ 18 | import collections 19 | import json 20 | import logging 21 | import os 22 | from typing import Optional, Tuple 23 | 24 | import numpy as np 25 | from tqdm.auto import tqdm 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def postprocess_qa_predictions( 32 | examples, 33 | features, 34 | predictions: Tuple[np.ndarray, np.ndarray], 35 | version_2_with_negative: bool = False, 36 | n_best_size: int = 20, 37 | max_answer_length: int = 30, 38 | null_score_diff_threshold: float = 0.0, 39 | output_dir: Optional[str] = None, 40 | prefix: Optional[str] = None, 41 | log_level: Optional[int] = logging.CRITICAL, # logging.WARNING, 42 | ): 43 | """ 44 | Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the 45 | original contexts. This is the base postprocessing functions for models that only return start and end logits. 46 | Args: 47 | examples: The non-preprocessed dataset (see the main script for more information). 48 | features: The processed dataset (see the main script for more information). 49 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 50 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 51 | first dimension must match the number of elements of :obj:`features`. 52 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 53 | Whether or not the underlying dataset contains examples with no answers. 54 | n_best_size (:obj:`int`, `optional`, defaults to 20): 55 | The total number of n-best predictions to generate when looking for an answer. 56 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 57 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 58 | are not conditioned on one another. 
59 | null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): 60 | The threshold used to select the null answer: if the best answer has a score that is less than the score of 61 | the null answer minus this threshold, the null answer is selected for this example (note that the score of 62 | the null answer for an example giving several features is the minimum of the scores for the null answer on 63 | each feature: all features must be aligned on the fact they `want` to predict a null answer). 64 | Only useful when :obj:`version_2_with_negative` is :obj:`True`. 65 | output_dir (:obj:`str`, `optional`): 66 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 67 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 68 | answers, are saved in `output_dir`. 69 | prefix (:obj:`str`, `optional`): 70 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 71 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 72 | ``logging`` log level (e.g., ``logging.WARNING``) 73 | """ 74 | if len(predictions) != 2: 75 | raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") 76 | all_start_logits, all_end_logits = predictions 77 | 78 | if len(predictions[0]) != len(features): 79 | raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") 80 | 81 | # Build a map example to its corresponding features. 82 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 83 | features_per_example = collections.defaultdict(list) 84 | for i, feature in enumerate(features): 85 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 86 | 87 | # The dictionaries we have to fill. 88 | all_predictions = collections.OrderedDict() 89 | all_nbest_json = collections.OrderedDict() 90 | if version_2_with_negative: 91 | scores_diff_json = collections.OrderedDict() 92 | 93 | # Logging. 94 | logger.setLevel(log_level) 95 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") 96 | 97 | # Let's loop over all the examples! 98 | for example_index, example in enumerate(tqdm(examples)): 99 | # Those are the indices of the features associated to the current example. 100 | feature_indices = features_per_example[example_index] 101 | 102 | min_null_prediction = None 103 | prelim_predictions = [] 104 | 105 | # Looping through all the features associated to the current example. 106 | for feature_index in feature_indices: 107 | # We grab the predictions of the model for this feature. 108 | start_logits = all_start_logits[feature_index] 109 | end_logits = all_end_logits[feature_index] 110 | # This is what will allow us to map some the positions in our logits to span of texts in the original 111 | # context. 112 | offset_mapping = features[feature_index]["offset_mapping"] 113 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 114 | # available in the current feature. 115 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 116 | 117 | # Update minimum null prediction. 
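            # The null ("no answer") candidate is scored at the CLS position:
            # feature_null_score = start_logits[0] + end_logits[0]. Across all features
            # of the current example we keep the lowest such score, which is later
            # compared against the best non-null span.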
118 | feature_null_score = start_logits[0] + end_logits[0] 119 | if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: 120 | min_null_prediction = { 121 | "offsets": (0, 0), 122 | "score": feature_null_score, 123 | "start_logit": start_logits[0], 124 | "end_logit": end_logits[0], 125 | } 126 | 127 | # Go through all possibilities for the `n_best_size` greater start and end logits. 128 | start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() 129 | end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() 130 | for start_index in start_indexes: 131 | for end_index in end_indexes: 132 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond 133 | # to part of the input_ids that are not in the context. 134 | if ( 135 | start_index >= len(offset_mapping) 136 | or end_index >= len(offset_mapping) 137 | or offset_mapping[start_index] is None 138 | or offset_mapping[end_index] is None 139 | ): 140 | continue 141 | # Don't consider answers with a length that is either < 0 or > max_answer_length. 142 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 143 | continue 144 | # Don't consider answer that don't have the maximum context available (if such information is 145 | # provided). 146 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 147 | continue 148 | prelim_predictions.append( 149 | { 150 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 151 | "score": start_logits[start_index] + end_logits[end_index], 152 | "start_logit": start_logits[start_index], 153 | "end_logit": end_logits[end_index], 154 | } 155 | ) 156 | if version_2_with_negative: 157 | # Add the minimum null prediction 158 | prelim_predictions.append(min_null_prediction) 159 | null_score = min_null_prediction["score"] 160 | 161 | # Only keep the best `n_best_size` predictions. 162 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 163 | 164 | # Add back the minimum null prediction if it was removed because of its low score. 165 | if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): 166 | predictions.append(min_null_prediction) 167 | 168 | # Use the offsets to gather the answer text in the original context. 169 | context = example["context"] 170 | for pred in predictions: 171 | offsets = pred.pop("offsets") 172 | pred["text"] = context[offsets[0] : offsets[1]] 173 | 174 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 175 | # failure. 176 | if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): 177 | predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) 178 | 179 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 180 | # the LogSumExp trick). 181 | scores = np.array([pred.pop("score") for pred in predictions]) 182 | exp_scores = np.exp(scores - np.max(scores)) 183 | probs = exp_scores / exp_scores.sum() 184 | 185 | # Include the probabilities in our predictions. 186 | for prob, pred in zip(probs, predictions): 187 | pred["probability"] = prob 188 | 189 | # Pick the best prediction. If the null answer is not possible, this is easy. 
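        # For datasets with unanswerable questions the decision below is: predict the
        # empty string whenever null_score - (best start_logit + best end_logit) exceeds
        # null_score_diff_threshold, otherwise keep the best non-empty span.
        # Hypothetical example: with threshold 0.0, null_score = 3.1 and a best span
        # scoring 4.5, the difference is -1.4, so the span text is kept.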
190 | if not version_2_with_negative: 191 | all_predictions[example["id"]] = predictions[0]["text"] 192 | else: 193 | # Otherwise we first need to find the best non-empty prediction. 194 | i = 0 195 | while predictions[i]["text"] == "": 196 | i += 1 197 | best_non_null_pred = predictions[i] 198 | 199 | # Then we compare to the null prediction using the threshold. 200 | score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] 201 | scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. 202 | if score_diff > null_score_diff_threshold: 203 | all_predictions[example["id"]] = "" 204 | else: 205 | all_predictions[example["id"]] = best_non_null_pred["text"] 206 | 207 | # Make `predictions` JSON-serializable by casting np.float back to float. 208 | all_nbest_json[example["id"]] = [ 209 | {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} 210 | for pred in predictions 211 | ] 212 | 213 | # If we have an output_dir, let's save all those dicts. 214 | if output_dir is not None: 215 | if not os.path.isdir(output_dir): 216 | raise EnvironmentError(f"{output_dir} is not a directory.") 217 | 218 | prediction_file = os.path.join( 219 | output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" 220 | ) 221 | nbest_file = os.path.join( 222 | output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" 223 | ) 224 | if version_2_with_negative: 225 | null_odds_file = os.path.join( 226 | output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" 227 | ) 228 | 229 | logger.info(f"Saving predictions to {prediction_file}.") 230 | with open(prediction_file, "w") as writer: 231 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 232 | logger.info(f"Saving nbest_preds to {nbest_file}.") 233 | with open(nbest_file, "w") as writer: 234 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 235 | if version_2_with_negative: 236 | logger.info(f"Saving null_odds to {null_odds_file}.") 237 | with open(null_odds_file, "w") as writer: 238 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 239 | 240 | return all_predictions 241 | 242 | 243 | def postprocess_qa_predictions_with_beam_search( 244 | examples, 245 | features, 246 | predictions: Tuple[np.ndarray, np.ndarray], 247 | version_2_with_negative: bool = False, 248 | n_best_size: int = 20, 249 | max_answer_length: int = 30, 250 | start_n_top: int = 5, 251 | end_n_top: int = 5, 252 | output_dir: Optional[str] = None, 253 | prefix: Optional[str] = None, 254 | log_level: Optional[int] = logging.WARNING, 255 | ): 256 | """ 257 | Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the 258 | original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as 259 | cls token predictions. 260 | Args: 261 | examples: The non-preprocessed dataset (see the main script for more information). 262 | features: The processed dataset (see the main script for more information). 263 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 264 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 265 | first dimension must match the number of elements of :obj:`features`. 
266 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 267 | Whether or not the underlying dataset contains examples with no answers. 268 | n_best_size (:obj:`int`, `optional`, defaults to 20): 269 | The total number of n-best predictions to generate when looking for an answer. 270 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 271 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 272 | are not conditioned on one another. 273 | start_n_top (:obj:`int`, `optional`, defaults to 5): 274 | The number of top start logits too keep when searching for the :obj:`n_best_size` predictions. 275 | end_n_top (:obj:`int`, `optional`, defaults to 5): 276 | The number of top end logits too keep when searching for the :obj:`n_best_size` predictions. 277 | output_dir (:obj:`str`, `optional`): 278 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 279 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 280 | answers, are saved in `output_dir`. 281 | prefix (:obj:`str`, `optional`): 282 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 283 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 284 | ``logging`` log level (e.g., ``logging.WARNING``) 285 | """ 286 | if len(predictions) != 5: 287 | raise ValueError("`predictions` should be a tuple with five elements.") 288 | start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions 289 | 290 | if len(predictions[0]) != len(features): 291 | raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") 292 | 293 | # Build a map example to its corresponding features. 294 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 295 | features_per_example = collections.defaultdict(list) 296 | for i, feature in enumerate(features): 297 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 298 | 299 | # The dictionaries we have to fill. 300 | all_predictions = collections.OrderedDict() 301 | all_nbest_json = collections.OrderedDict() 302 | scores_diff_json = collections.OrderedDict() if version_2_with_negative else None 303 | 304 | # Logging. 305 | logger.setLevel(log_level) 306 | logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") 307 | 308 | # Let's loop over all the examples! 309 | for example_index, example in enumerate(tqdm(examples)): 310 | # Those are the indices of the features associated to the current example. 311 | feature_indices = features_per_example[example_index] 312 | 313 | min_null_score = None 314 | prelim_predictions = [] 315 | 316 | # Looping through all the features associated to the current example. 317 | for feature_index in feature_indices: 318 | # We grab the predictions of the model for this feature. 319 | start_log_prob = start_top_log_probs[feature_index] 320 | start_indexes = start_top_index[feature_index] 321 | end_log_prob = end_top_log_probs[feature_index] 322 | end_indexes = end_top_index[feature_index] 323 | feature_null_score = cls_logits[feature_index] 324 | # This is what will allow us to map some the positions in our logits to span of texts in the original 325 | # context. 
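            # offset_mapping[k] is the (start_char, end_char) span of token k in the
            # original context; positions belonging to the question or to special tokens
            # were set to None during preprocessing, which the bounds checks below rely on.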
326 | offset_mapping = features[feature_index]["offset_mapping"] 327 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 328 | # available in the current feature. 329 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 330 | 331 | # Update minimum null prediction 332 | if min_null_score is None or feature_null_score < min_null_score: 333 | min_null_score = feature_null_score 334 | 335 | # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. 336 | for i in range(start_n_top): 337 | for j in range(end_n_top): 338 | start_index = int(start_indexes[i]) 339 | j_index = i * end_n_top + j 340 | end_index = int(end_indexes[j_index]) 341 | # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the 342 | # p_mask but let's not take any risk) 343 | if ( 344 | start_index >= len(offset_mapping) 345 | or end_index >= len(offset_mapping) 346 | or offset_mapping[start_index] is None 347 | or offset_mapping[end_index] is None 348 | ): 349 | continue 350 | # Don't consider answers with a length negative or > max_answer_length. 351 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 352 | continue 353 | # Don't consider answer that don't have the maximum context available (if such information is 354 | # provided). 355 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 356 | continue 357 | prelim_predictions.append( 358 | { 359 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 360 | "score": start_log_prob[i] + end_log_prob[j_index], 361 | "start_log_prob": start_log_prob[i], 362 | "end_log_prob": end_log_prob[j_index], 363 | } 364 | ) 365 | 366 | # Only keep the best `n_best_size` predictions. 367 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 368 | 369 | # Use the offsets to gather the answer text in the original context. 370 | context = example["context"] 371 | for pred in predictions: 372 | offsets = pred.pop("offsets") 373 | pred["text"] = context[offsets[0] : offsets[1]] 374 | 375 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 376 | # failure. 377 | if len(predictions) == 0: 378 | predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6}) 379 | 380 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 381 | # the LogSumExp trick). 382 | scores = np.array([pred.pop("score") for pred in predictions]) 383 | exp_scores = np.exp(scores - np.max(scores)) 384 | probs = exp_scores / exp_scores.sum() 385 | 386 | # Include the probabilities in our predictions. 387 | for prob, pred in zip(probs, predictions): 388 | pred["probability"] = prob 389 | 390 | # Pick the best prediction and set the probability for the null answer. 391 | all_predictions[example["id"]] = predictions[0]["text"] 392 | if version_2_with_negative: 393 | scores_diff_json[example["id"]] = float(min_null_score) 394 | 395 | # Make `predictions` JSON-serializable by casting np.float back to float. 396 | all_nbest_json[example["id"]] = [ 397 | {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} 398 | for pred in predictions 399 | ] 400 | 401 | # If we have an output_dir, let's save all those dicts. 
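    # Up to three files are written: predictions.json with the final answers,
    # nbest_predictions.json with the scored candidates, and, when
    # version_2_with_negative is set, null_odds.json with the per-example null-answer
    # scores; a prefix is prepended to each filename if one was given.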
402 | if output_dir is not None: 403 | if not os.path.isdir(output_dir): 404 | raise EnvironmentError(f"{output_dir} is not a directory.") 405 | 406 | prediction_file = os.path.join( 407 | output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" 408 | ) 409 | nbest_file = os.path.join( 410 | output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" 411 | ) 412 | if version_2_with_negative: 413 | null_odds_file = os.path.join( 414 | output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" 415 | ) 416 | 417 | logger.info(f"Saving predictions to {prediction_file}.") 418 | with open(prediction_file, "w") as writer: 419 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 420 | logger.info(f"Saving nbest_preds to {nbest_file}.") 421 | with open(nbest_file, "w") as writer: 422 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 423 | if version_2_with_negative: 424 | logger.info(f"Saving null_odds to {null_odds_file}.") 425 | with open(null_odds_file, "w") as writer: 426 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 427 | 428 | return all_predictions, scores_diff_json 429 | -------------------------------------------------------------------------------- /bertsquad/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for question answering using a slightly adapted version of the 🤗 Trainer. 18 | """ 19 | # You can also adapt this script on your own question answering task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | from datasets import load_dataset, load_metric 29 | 30 | import transformers 31 | from .trainer_qa import QuestionAnsweringTrainer 32 | from transformers import ( 33 | AutoConfig, 34 | AutoModelForQuestionAnswering, 35 | AutoTokenizer, 36 | DataCollatorWithPadding, 37 | EvalPrediction, 38 | HfArgumentParser, 39 | PreTrainedTokenizerFast, 40 | TrainingArguments, 41 | default_data_collator, 42 | set_seed, 43 | ) 44 | from transformers.trainer_utils import get_last_checkpoint 45 | from transformers.utils import check_min_version 46 | from transformers.utils.versions import require_version 47 | from .utils_qa import postprocess_qa_predictions 48 | 49 | 50 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
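# The upstream minimum-version check is left commented out in this adapted copy;
# only the datasets>=1.8.0 requirement below is still enforced.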
51 | # check_min_version("4.16.0.dev0") 52 | 53 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") 54 | 55 | # To control logging level for various modules used in the application: 56 | import re 57 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 58 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') 59 | for name in logging.root.manager.loggerDict: 60 | if re.match(prefix_re, name): 61 | logging.getLogger(name).setLevel(level) 62 | import tqdm 63 | set_global_logging_level() 64 | def nop(it, *a, **k): 65 | return it 66 | tqdm.tqdm = nop 67 | 68 | datasets.logging.get_verbosity = lambda: logging.NOTSET 69 | logger = logging.getLogger(__name__) 70 | 71 | 72 | @dataclass 73 | class ModelArguments: 74 | """ 75 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 76 | """ 77 | 78 | model_name_or_path: str = field( 79 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 80 | ) 81 | config_name: Optional[str] = field( 82 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 83 | ) 84 | tokenizer_name: Optional[str] = field( 85 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 86 | ) 87 | cache_dir: Optional[str] = field( 88 | default=None, 89 | metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"}, 90 | ) 91 | model_revision: str = field( 92 | default="main", 93 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 94 | ) 95 | use_auth_token: bool = field( 96 | default=False, 97 | metadata={ 98 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 99 | "with private models)." 100 | }, 101 | ) 102 | 103 | 104 | @dataclass 105 | class DataTrainingArguments: 106 | """ 107 | Arguments pertaining to what data we are going to input our model for training and eval. 108 | """ 109 | 110 | dataset_name: Optional[str] = field( 111 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 112 | ) 113 | dataset_config_name: Optional[str] = field( 114 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 115 | ) 116 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 117 | validation_file: Optional[str] = field( 118 | default=None, 119 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 120 | ) 121 | test_file: Optional[str] = field( 122 | default=None, 123 | metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, 124 | ) 125 | overwrite_cache: bool = field( 126 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 127 | ) 128 | preprocessing_num_workers: Optional[int] = field( 129 | default=None, 130 | metadata={"help": "The number of processes to use for the preprocessing."}, 131 | ) 132 | max_seq_length: int = field( 133 | default=384, 134 | metadata={ 135 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 136 | "than this will be truncated, sequences shorter will be padded." 
137 | }, 138 | ) 139 | pad_to_max_length: bool = field( 140 | default=True, 141 | metadata={ 142 | "help": "Whether to pad all samples to `max_seq_length`. " 143 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " 144 | "be faster on GPU but will be slower on TPU)." 145 | }, 146 | ) 147 | max_train_samples: Optional[int] = field( 148 | default=None, 149 | metadata={ 150 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 151 | "value if set." 152 | }, 153 | ) 154 | max_eval_samples: Optional[int] = field( 155 | default=None, 156 | metadata={ 157 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 158 | "value if set." 159 | }, 160 | ) 161 | max_predict_samples: Optional[int] = field( 162 | default=None, 163 | metadata={ 164 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 165 | "value if set." 166 | }, 167 | ) 168 | version_2_with_negative: bool = field( 169 | default=False, metadata={"help": "If true, some of the examples do not have an answer."} 170 | ) 171 | null_score_diff_threshold: float = field( 172 | default=0.0, 173 | metadata={ 174 | "help": "The threshold used to select the null answer: if the best answer has a score that is less than " 175 | "the score of the null answer minus this threshold, the null answer is selected for this example. " 176 | "Only useful when `version_2_with_negative=True`." 177 | }, 178 | ) 179 | doc_stride: int = field( 180 | default=128, 181 | metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, 182 | ) 183 | n_best_size: int = field( 184 | default=20, 185 | metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, 186 | ) 187 | max_answer_length: int = field( 188 | default=30, 189 | metadata={ 190 | "help": "The maximum length of an answer that can be generated. This is needed because the start " 191 | "and end predictions are not conditioned on one another." 192 | }, 193 | ) 194 | 195 | def __post_init__(self): 196 | if ( 197 | self.dataset_name is None 198 | and self.train_file is None 199 | and self.validation_file is None 200 | and self.test_file is None 201 | ): 202 | raise ValueError("Need either a dataset name or a training/validation file/test_file.") 203 | else: 204 | if self.train_file is not None: 205 | extension = self.train_file.split(".")[-1] 206 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 207 | if self.validation_file is not None: 208 | extension = self.validation_file.split(".")[-1] 209 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 210 | if self.test_file is not None: 211 | extension = self.test_file.split(".")[-1] 212 | assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." 213 | 214 | 215 | # See all possible arguments in src/transformers/training_args.py 216 | # or by passing the --help flag to this script. 217 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
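# Unlike the upstream script, this module parses its arguments at import time from the
# fixed file bertsquad/checkpoint/args.json (see the BEGIN/END block below) instead of
# the command line. parse_json_file maps top-level JSON keys onto the three dataclasses,
# so a minimal file could look like (hypothetical values):
#   {"model_name_or_path": "bert-base-uncased", "dataset_name": "squad",
#    "output_dir": "bertsquad/checkpoint", "do_eval": true}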
218 | 219 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 220 | # BEGIN 221 | model_args, data_args, training_args = parser.parse_json_file(json_file='bertsquad/checkpoint/args.json') 222 | # END 223 | # if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 224 | # If we pass only one argument to the script and it's the path to a json file, 225 | # let's parse it to get our arguments. 226 | # model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 227 | # else: 228 | # model_args, data_args, training_args = parser.parse_args_into_dataclasses() 229 | 230 | # Setup logging 231 | logging.basicConfig( 232 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 233 | datefmt="%m/%d/%Y %H:%M:%S", 234 | handlers=[logging.StreamHandler(sys.stdout)], 235 | ) 236 | 237 | # log_level = training_args.get_process_log_level() 238 | log_level = logging.CRITICAL 239 | logger.setLevel(log_level) 240 | datasets.utils.logging.set_verbosity(log_level) 241 | transformers.utils.logging.set_verbosity(log_level) 242 | transformers.utils.logging.enable_default_handler() 243 | transformers.utils.logging.enable_explicit_format() 244 | 245 | # Log on each process the small summary: 246 | logger.warning( 247 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 248 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 249 | ) 250 | logger.info(f"Training/evaluation parameters {training_args}") 251 | 252 | # Detecting last checkpoint. 253 | last_checkpoint = None 254 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 255 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 256 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 257 | raise ValueError( 258 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 259 | "Use --overwrite_output_dir to overcome." 260 | ) 261 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 262 | logger.info( 263 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 264 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 265 | ) 266 | 267 | # Set seed before initializing model. 268 | set_seed(training_args.seed) 269 | 270 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 271 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 272 | # (the dataset will be downloaded automatically from the datasets Hub). 273 | # 274 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 275 | # 'text' is found. You can easily tweak this behavior (see below). 276 | # 277 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 278 | # download the dataset. 279 | if data_args.dataset_name is not None: 280 | # Downloading and loading a dataset from the hub. 
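    # load_dataset returns a DatasetDict keyed by split; the id / question / context /
    # answers columns of the raw splits are what the preprocessing functions below expect.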
281 | raw_datasets = load_dataset( 282 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 283 | ) 284 | else: 285 | data_files = {} 286 | if data_args.train_file is not None: 287 | data_files["train"] = data_args.train_file 288 | extension = data_args.train_file.split(".")[-1] 289 | 290 | if data_args.validation_file is not None: 291 | data_files["validation"] = data_args.validation_file 292 | extension = data_args.validation_file.split(".")[-1] 293 | if data_args.test_file is not None: 294 | data_files["test"] = data_args.test_file 295 | extension = data_args.test_file.split(".")[-1] 296 | raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir) 297 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 298 | # https://huggingface.co/docs/datasets/loading_datasets.html. 299 | 300 | # Load pretrained model and tokenizer 301 | # 302 | # Distributed training: 303 | # The .from_pretrained methods guarantee that only one local process can concurrently 304 | # download model & vocab. 305 | config = AutoConfig.from_pretrained( 306 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 307 | cache_dir=model_args.cache_dir, 308 | revision=model_args.model_revision, 309 | use_auth_token=True if model_args.use_auth_token else None, 310 | ) 311 | tokenizer = AutoTokenizer.from_pretrained( 312 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 313 | cache_dir=model_args.cache_dir, 314 | use_fast=True, 315 | revision=model_args.model_revision, 316 | use_auth_token=True if model_args.use_auth_token else None, 317 | ) 318 | model = AutoModelForQuestionAnswering.from_pretrained( 319 | model_args.model_name_or_path, 320 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 321 | config=config, 322 | cache_dir=model_args.cache_dir, 323 | revision=model_args.model_revision, 324 | use_auth_token=True if model_args.use_auth_token else None, 325 | ) 326 | 327 | # Tokenizer check: this script requires a fast tokenizer. 328 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 329 | raise ValueError( 330 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models " 331 | "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " 332 | "requirement" 333 | ) 334 | 335 | # Preprocessing the datasets. 336 | # Preprocessing is slighlty different for training and evaluation. 337 | if training_args.do_train: 338 | column_names = raw_datasets["train"].column_names 339 | elif training_args.do_eval: 340 | column_names = raw_datasets["validation"].column_names 341 | else: 342 | column_names = raw_datasets["test"].column_names 343 | question_column_name = "question" if "question" in column_names else column_names[0] 344 | context_column_name = "context" if "context" in column_names else column_names[1] 345 | answer_column_name = "answers" if "answers" in column_names else column_names[2] 346 | 347 | # Padding side determines if we do (question|context) or (context|question). 348 | pad_on_right = tokenizer.padding_side == "right" 349 | 350 | if data_args.max_seq_length > tokenizer.model_max_length: 351 | logger.warning( 352 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 353 | f"model ({tokenizer.model_max_length}). 
Using max_seq_length={tokenizer.model_max_length}." 354 | ) 355 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 356 | 357 | # Training preprocessing 358 | def prepare_train_features(examples): 359 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the 360 | # truncation of the context fail (the tokenized question will take a lots of space). So we remove that 361 | # left whitespace 362 | examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] 363 | 364 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results 365 | # in one example possible giving several features when a context is long, each of those features having a 366 | # context that overlaps a bit the context of the previous feature. 367 | tokenized_examples = tokenizer( 368 | examples[question_column_name if pad_on_right else context_column_name], 369 | examples[context_column_name if pad_on_right else question_column_name], 370 | truncation="only_second" if pad_on_right else "only_first", 371 | max_length=max_seq_length, 372 | stride=data_args.doc_stride, 373 | return_overflowing_tokens=True, 374 | return_offsets_mapping=True, 375 | padding="max_length" if data_args.pad_to_max_length else False, 376 | ) 377 | 378 | # Since one example might give us several features if it has a long context, we need a map from a feature to 379 | # its corresponding example. This key gives us just that. 380 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 381 | # The offset mappings will give us a map from token to character position in the original context. This will 382 | # help us compute the start_positions and end_positions. 383 | offset_mapping = tokenized_examples.pop("offset_mapping") 384 | 385 | # Let's label those examples! 386 | tokenized_examples["start_positions"] = [] 387 | tokenized_examples["end_positions"] = [] 388 | 389 | for i, offsets in enumerate(offset_mapping): 390 | # We will label impossible answers with the index of the CLS token. 391 | input_ids = tokenized_examples["input_ids"][i] 392 | cls_index = input_ids.index(tokenizer.cls_token_id) 393 | 394 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 395 | sequence_ids = tokenized_examples.sequence_ids(i) 396 | 397 | # One example can give several spans, this is the index of the example containing this span of text. 398 | sample_index = sample_mapping[i] 399 | answers = examples[answer_column_name][sample_index] 400 | # If no answers are given, set the cls_index as answer. 401 | if len(answers["answer_start"]) == 0: 402 | tokenized_examples["start_positions"].append(cls_index) 403 | tokenized_examples["end_positions"].append(cls_index) 404 | else: 405 | # Start/end character index of the answer in the text. 406 | start_char = answers["answer_start"][0] 407 | end_char = start_char + len(answers["text"][0]) 408 | 409 | # Start token index of the current span in the text. 410 | token_start_index = 0 411 | while sequence_ids[token_start_index] != (1 if pad_on_right else 0): 412 | token_start_index += 1 413 | 414 | # End token index of the current span in the text. 415 | token_end_index = len(input_ids) - 1 416 | while sequence_ids[token_end_index] != (1 if pad_on_right else 0): 417 | token_end_index -= 1 418 | 419 | # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). 
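            # The current span contains the answer only if its first context token starts
            # at or before start_char and its last context token ends at or after end_char.
            # Hypothetical example: an answer at characters 57-63 lies outside a feature
            # whose context tokens cover only characters 0-40, so that feature gets the
            # CLS label.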
420 |                 if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
421 |                     tokenized_examples["start_positions"].append(cls_index)
422 |                     tokenized_examples["end_positions"].append(cls_index)
423 |                 else:
424 |                     # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
425 |                     # Note: we could go after the last offset if the answer is the last word (edge case).
426 |                     while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
427 |                         token_start_index += 1
428 |                     tokenized_examples["start_positions"].append(token_start_index - 1)
429 |                     while offsets[token_end_index][1] >= end_char:
430 |                         token_end_index -= 1
431 |                     tokenized_examples["end_positions"].append(token_end_index + 1)
432 | 
433 |         return tokenized_examples
434 | 
435 |     if training_args.do_train:
436 |         if "train" not in raw_datasets:
437 |             raise ValueError("--do_train requires a train dataset")
438 |         train_dataset = raw_datasets["train"]
439 |         # train_dataset = load_dataset(
440 |         #     data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, split='train[:1024]'
441 |         # )
442 |         if data_args.max_train_samples is not None:
443 |             # We will select samples from the whole dataset if the argument is specified
444 |             train_dataset = train_dataset.select(range(data_args.max_train_samples))
445 |         # Create train features from the dataset
446 |         with training_args.main_process_first(desc="train dataset map pre-processing"):
447 |             train_dataset = train_dataset.map(
448 |                 prepare_train_features,
449 |                 batched=True,
450 |                 num_proc=data_args.preprocessing_num_workers,
451 |                 remove_columns=column_names,
452 |                 load_from_cache_file=not data_args.overwrite_cache,
453 |                 desc="Running tokenizer on train dataset",
454 |             )
455 |         if data_args.max_train_samples is not None:
456 |             # The number of samples might increase during feature creation, so we select only the specified max samples
457 |             train_dataset = train_dataset.select(range(data_args.max_train_samples))
458 | 
459 |     # Validation preprocessing
460 |     def prepare_validation_features(examples):
461 |         # Some of the questions have lots of whitespace on the left, which is not useful and will make the
462 |         # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
463 |         # left whitespace.
464 |         examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
465 | 
466 |         # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
467 |         # in one example possibly giving several features when a context is long, each of those features having a
468 |         # context that overlaps a bit with the context of the previous feature.
469 |         tokenized_examples = tokenizer(
470 |             examples[question_column_name if pad_on_right else context_column_name],
471 |             examples[context_column_name if pad_on_right else question_column_name],
472 |             truncation="only_second" if pad_on_right else "only_first",
473 |             max_length=max_seq_length,
474 |             stride=data_args.doc_stride,
475 |             return_overflowing_tokens=True,
476 |             return_offsets_mapping=True,
477 |             padding="max_length" if data_args.pad_to_max_length else False,
478 |         )
479 | 
480 |         # Since one example might give us several features if it has a long context, we need a map from a feature to
481 |         # its corresponding example. This key gives us just that.
482 |         sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
483 | 
484 |         # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
485 |         # corresponding example_id and we will store the offset mappings.
486 |         tokenized_examples["example_id"] = []
487 | 
488 |         for i in range(len(tokenized_examples["input_ids"])):
489 |             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
490 |             sequence_ids = tokenized_examples.sequence_ids(i)
491 |             context_index = 1 if pad_on_right else 0
492 | 
493 |             # One example can give several spans; this is the index of the example containing this span of text.
494 |             sample_index = sample_mapping[i]
495 |             tokenized_examples["example_id"].append(examples["id"][sample_index])
496 | 
497 |             # Set to None the offset_mapping entries that are not part of the context, so it's easy to determine if a token
498 |             # position is part of the context or not.
499 |             tokenized_examples["offset_mapping"][i] = [
500 |                 (o if sequence_ids[k] == context_index else None)
501 |                 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
502 |             ]
503 | 
504 |         return tokenized_examples
505 | 
506 |     if training_args.do_eval:
507 |         if "validation" not in raw_datasets:
508 |             raise ValueError("--do_eval requires a validation dataset")
509 |         eval_examples = raw_datasets["validation"]
510 |         if data_args.max_eval_samples is not None:
511 |             # We will select samples from the whole dataset
512 |             eval_examples = eval_examples.select(range(data_args.max_eval_samples))
513 |         # Validation Feature Creation
514 |         with training_args.main_process_first(desc="validation dataset map pre-processing"):
515 |             eval_dataset = eval_examples.map(
516 |                 prepare_validation_features,
517 |                 batched=True,
518 |                 num_proc=data_args.preprocessing_num_workers,
519 |                 remove_columns=column_names,
520 |                 load_from_cache_file=not data_args.overwrite_cache,
521 |                 desc="Running tokenizer on validation dataset",
522 |             )
523 |         if data_args.max_eval_samples is not None:
524 |             # During feature creation the number of samples might increase, so we select the required samples again
525 |             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
526 | 
527 |     if training_args.do_predict:
528 |         if "test" not in raw_datasets:
529 |             raise ValueError("--do_predict requires a test dataset")
530 |         predict_examples = raw_datasets["test"]
531 |         if data_args.max_predict_samples is not None:
532 |             # We will select samples from the whole dataset
533 |             predict_examples = predict_examples.select(range(data_args.max_predict_samples))
534 |         # Predict Feature Creation
535 |         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
536 |             predict_dataset = predict_examples.map(
537 |                 prepare_validation_features,
538 |                 batched=True,
539 |                 num_proc=data_args.preprocessing_num_workers,
540 |                 remove_columns=column_names,
541 |                 load_from_cache_file=not data_args.overwrite_cache,
542 |                 desc="Running tokenizer on prediction dataset",
543 |             )
544 |         if data_args.max_predict_samples is not None:
545 |             # During feature creation the number of samples might increase, so we select the required samples again
546 |             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
547 | 
548 |     # Data collator
549 |     # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
550 |     # collator.
551 |     data_collator = (
552 |         default_data_collator
553 |         if data_args.pad_to_max_length
554 |         else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
555 |     )
556 | 
557 |     # Post-processing:
558 |     def post_processing_function(examples, features, predictions, stage="eval"):
559 |         # Post-processing: we match the start logits and end logits to answers in the original context.
560 |         predictions = postprocess_qa_predictions(
561 |             examples=examples,
562 |             features=features,
563 |             predictions=predictions,
564 |             version_2_with_negative=data_args.version_2_with_negative,
565 |             n_best_size=data_args.n_best_size,
566 |             max_answer_length=data_args.max_answer_length,
567 |             null_score_diff_threshold=data_args.null_score_diff_threshold,
568 |             output_dir=training_args.output_dir,
569 |             log_level=log_level,
570 |             prefix=stage,
571 |         )
572 |         # Format the results into the format the metric expects.
573 |         if data_args.version_2_with_negative:
574 |             formatted_predictions = [
575 |                 {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
576 |             ]
577 |         else:
578 |             formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
579 | 
580 |         references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
581 |         return EvalPrediction(predictions=formatted_predictions, label_ids=references)
582 | 
583 |     metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
584 | 
585 |     def compute_metrics(p: EvalPrediction):
586 |         return metric.compute(predictions=p.predictions, references=p.label_ids)
587 | 
588 |     def drop_layers(model, layers_to_keep):
589 |         import copy
590 |         import torch.nn as nn
591 |         layer_drop_matching = {
592 |             3: [0, 1, 2],
593 |             6: [0, 1, 2, 3, 4, 5],
594 |         }
595 |         encoder_layers = model.bert.encoder.layer
596 |         trimmed_encoder_layers = nn.ModuleList()
597 |         for i in layer_drop_matching[layers_to_keep]:
598 |             trimmed_encoder_layers.append(encoder_layers[i])
599 |         trimmed_model = copy.deepcopy(model)
600 |         trimmed_model.bert.encoder.layer = trimmed_encoder_layers
601 |         return trimmed_model
602 | 
603 |     def get_model(path='', layers=12):
604 |         if layers < 12:
605 |             path = model_args.model_name_or_path + str(layers)
606 |         model = AutoModelForQuestionAnswering.from_pretrained(
607 |             model_args.model_name_or_path if not path else path,
608 |             from_tf=bool(".ckpt" in (model_args.model_name_or_path if not path else path)),
609 |             config=config,
610 |             cache_dir=model_args.cache_dir,
611 |             revision=model_args.model_revision,
612 |             use_auth_token=True if model_args.use_auth_token else None,
613 |         )
614 |         if layers < 12:
615 |             model = drop_layers(model, layers)
616 |         return model
617 | 
618 |     def test(model):
619 |         print('Evaluating ...')
620 |         trainer = QuestionAnsweringTrainer(
621 |             model=model,
622 |             args=training_args,
623 |             train_dataset=train_dataset if training_args.do_train else None,
624 |             eval_dataset=eval_dataset if training_args.do_eval else None,
625 |             eval_examples=eval_examples if training_args.do_eval else None,
626 |             tokenizer=tokenizer,
627 |             data_collator=data_collator,
628 |             post_process_function=post_processing_function,
629 |             compute_metrics=compute_metrics,
630 |         )
631 |         f1 = trainer.evaluate()['eval_f1']
632 |         print(f1)
633 |         return f1
634 | 
635 |     def get_dataloader(batchsize, nsamples, seed=0):
636 |         if nsamples == -1:
637 |             train_dataset1 = train_dataset
638 |         else:
639 |             import numpy as np
640 |             np.random.seed(seed)
641 |             perm = np.random.permutation(len(train_dataset))
642 |             # perm = np.arange(len(train_dataset))  # TODO: remove
643 |             train_dataset1 = train_dataset.select(perm[:nsamples])
644 |         trainer = QuestionAnsweringTrainer(
645 |             model=model,
646 |             args=training_args,
647 |             train_dataset=train_dataset1 if training_args.do_train else None,
648 |             eval_dataset=eval_dataset if training_args.do_eval else None,
649 |             eval_examples=eval_examples if training_args.do_eval else None,
650 |             tokenizer=tokenizer,
651 |             data_collator=data_collator,
652 |             post_process_function=post_processing_function,
653 |             compute_metrics=compute_metrics,
654 |         )
655 |         from torch.utils.data import DataLoader
656 |         return DataLoader(
657 |             trainer.train_dataset,
658 |             batch_size=batchsize,
659 |             sampler=trainer._get_train_sampler(),
660 |             collate_fn=trainer.data_collator,
661 |             drop_last=trainer.args.dataloader_drop_last,
662 |             num_workers=trainer.args.dataloader_num_workers,
663 |             pin_memory=trainer.args.dataloader_pin_memory,
664 |         )
665 | 
666 | 
667 |     if False:
668 | 
669 |         # Initialize our Trainer
670 |         trainer = QuestionAnsweringTrainer(
671 |             model=model,
672 |             args=training_args,
673 |             train_dataset=train_dataset if training_args.do_train else None,
674 |             eval_dataset=eval_dataset if training_args.do_eval else None,
675 |             eval_examples=eval_examples if training_args.do_eval else None,
676 |             tokenizer=tokenizer,
677 |             data_collator=data_collator,
678 |             post_process_function=post_processing_function,
679 |             compute_metrics=compute_metrics,
680 |         )
681 | 
682 |         # Training
683 |         if training_args.do_train:
684 |             checkpoint = None
685 |             if training_args.resume_from_checkpoint is not None:
686 |                 checkpoint = training_args.resume_from_checkpoint
687 |             elif last_checkpoint is not None:
688 |                 checkpoint = last_checkpoint
689 |             train_result = trainer.train(resume_from_checkpoint=checkpoint)
690 |             trainer.save_model()  # Saves the tokenizer too for easy upload
691 | 
692 |             metrics = train_result.metrics
693 |             max_train_samples = (
694 |                 data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
695 |             )
696 |             metrics["train_samples"] = min(max_train_samples, len(train_dataset))
697 | 
698 |             trainer.log_metrics("train", metrics)
699 |             trainer.save_metrics("train", metrics)
700 |             trainer.save_state()
701 | 
702 |         # Evaluation
703 |         if training_args.do_eval:
704 |             logger.info("*** Evaluate ***")
705 |             metrics = trainer.evaluate()
706 | 
707 |             max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
708 |             metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
709 | 
710 |             trainer.log_metrics("eval", metrics)
711 |             trainer.save_metrics("eval", metrics)
712 | 
713 |         # Prediction
714 |         if training_args.do_predict:
715 |             logger.info("*** Predict ***")
716 |             results = trainer.predict(predict_dataset, predict_examples)
717 |             metrics = results.metrics
718 | 
719 |             max_predict_samples = (
720 |                 data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
721 |             )
722 |             metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
723 | 
724 |             trainer.log_metrics("predict", metrics)
725 |             trainer.save_metrics("predict", metrics)
726 | 
727 |         kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
728 |         if data_args.dataset_name is not None:
729 |             kwargs["dataset_tags"] = data_args.dataset_name
730 |         if data_args.dataset_config_name is not None:
731 |             kwargs["dataset_args"] = data_args.dataset_config_name
732 |             kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
733 |         else:
734 |             kwargs["dataset"] = data_args.dataset_name
735 | 
736 |         if training_args.push_to_hub:
737 |             trainer.push_to_hub(**kwargs)
738 |         else:
739 |             trainer.create_model_card(**kwargs)
740 | 
--------------------------------------------------------------------------------
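Note (not part of the repository): a minimal sketch of the overflow/stride behaviour that prepare_train_features and prepare_validation_features rely on. The checkpoint name, question/context strings, and the max_length/stride values below are illustrative assumptions, not values taken from the script's arguments.

# Sketch: one long context is split into several overlapping features.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)  # assumed checkpoint
question = ["What is being discussed?"]
context = ["word " * 600]  # deliberately longer than max_length

enc = tok(
    question,
    context,
    truncation="only_second",       # only the context gets truncated
    max_length=128,
    stride=32,                      # overlap between consecutive windows
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(len(enc["input_ids"]))              # > 1: several features for the single example
print(enc["overflow_to_sample_mapping"])  # e.g. [0, 0, 0, ...]: every feature maps back to example 0

Each feature also carries an offset_mapping from tokens back to character positions, which is what the labeling loop above uses to place start_positions and end_positions.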
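Note (not part of the repository): the prediction/reference format that post_processing_function produces and that the "squad" metric loaded above consumes. The id, answer text, and answer_start values below are made up for illustration.

# Sketch: the dict format expected by metric.compute for the SQuAD v1 metric.
from datasets import load_metric

metric = load_metric("squad")
preds = [{"id": "abc123", "prediction_text": "Denver Broncos"}]
refs = [{"id": "abc123", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]
print(metric.compute(predictions=preds, references=refs))  # {'exact_match': 100.0, 'f1': 100.0}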
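Note (not part of the repository): a sketch of what drop_layers does when called from inside the script, where the helper is in scope. layer_drop_matching only covers 3 and 6 kept layers, and "bert-base-uncased" is an assumed 12-layer checkpoint used purely for illustration.

# Sketch: trimming a 12-block BERT QA model down to its first 6 encoder blocks.
from transformers import AutoModelForQuestionAnswering

full = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")  # assumed checkpoint
print(len(full.bert.encoder.layer))   # 12 transformer blocks

# Keep only blocks 0-5; drop_layers returns a copy whose encoder holds the
# selected blocks, while `full` itself keeps all 12.
small = drop_layers(full, 6)
print(len(small.bert.encoder.layer))  # 6 transformer blocks

In the script this is wrapped by get_model(layers=...), which additionally appends the layer count to model_name_or_path when loading a pre-trimmed checkpoint from disk.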