├── AdaPrune
│   ├── LICENSE
│   ├── create_calib_folder.py
│   ├── data.py
│   ├── evaluate.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── modules
│   │   │   ├── batch_norm.py
│   │   │   ├── birelu.py
│   │   │   ├── bwn.py
│   │   │   ├── checkpoint.py
│   │   │   ├── evolved_modules.py
│   │   │   ├── fixed_proj.py
│   │   │   ├── fixup.py
│   │   │   ├── lp_norm.py
│   │   │   ├── quantize.py
│   │   │   └── se.py
│   │   └── resnet.py
│   ├── preprocess.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── adaprune_dense_bnt.sh
│   │   └── adaprune_sparse.sh
│   ├── trainer.py
│   └── utils
│       ├── LICENSE
│       ├── absorb_bn.py
│       ├── adaprune.py
│       ├── cross_entropy.py
│       ├── dataset.py
│       ├── functions.py
│       ├── log.py
│       ├── meters.py
│       ├── misc.py
│       ├── mixup.py
│       ├── optim.py
│       ├── param_filter.py
│       ├── regime.py
│       └── regularization.py
├── README.md
├── common
│   ├── flatten_object.py
│   ├── json_utils.py
│   └── timer.py
├── dynamic_TNM
│   ├── scripts
│   │   ├── clone_and_copy.sh
│   │   ├── run_R18.sh
│   │   └── run_R50.sh
│   ├── src
│   │   ├── configs
│   │   │   ├── config_resnet18_4by8_transpose.yaml
│   │   │   ├── config_resnet50_4by8_transpose.yaml
│   │   │   └── config_resnext50_4by8_transpose.yaml
│   │   ├── dist_utils.py
│   │   ├── resnet.py
│   │   ├── sparse_ops.py
│   │   ├── sparse_ops_init.py
│   │   ├── train_imagenet.py
│   │   ├── train_val.sh
│   │   └── utils.py
│   └── train-20210211_125543.log
├── prune
│   ├── prune.py
│   ├── pruning_method_based_mask.py
│   ├── pruning_method_transposable_block_l1.py
│   ├── pruning_method_transposable_block_l1_graphs.py
│   ├── pruning_method_utils.py
│   └── sparsity_freezer.py
├── static_TNM
│   ├── scripts
│   │   └── prune_pretrained_R50.sh
│   └── src
│       └── prune_pretrained_model.py
└── vision
    ├── LICENSE
    ├── autoaugment.py
    ├── data.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── alexnet.py
    │   ├── modules
    │   │   ├── activations.py
    │   │   ├── checkpoint.py
    │   │   └── se.py
    │   └── resnet.py
    ├── preprocess.py
    ├── trainer.py
    └── utils
        ├── LICENSE
        ├── absorb_bn.py
        ├── cross_entropy.py
        ├── dataset.py
        ├── log.py
        ├── meters.py
        ├── misc.py
        ├── mixup.py
        ├── optim.py
        ├── param_filter.py
        ├── regime.py
        └── regularization.py

/AdaPrune/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Elad Hoffer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /AdaPrune/create_calib_folder.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import shutil 4 | 5 | basepath = '/home/Datasets/imagenet/train/' 6 | basepath_calib = '/home/Datasets/imagenet/calib/' 7 | 8 | directory = os.fsencode(basepath) 9 | os.mkdir(basepath_calib) 10 | for d in os.listdir(directory): 11 | dir_name = os.fsdecode(d) 12 | dir_path = os.path.join(basepath,dir_name) 13 | dir_copy_path = os.path.join(basepath_calib,dir_name) 14 | os.mkdir(dir_copy_path) 15 | for f in os.listdir(dir_path): 16 | file_path = os.path.join(dir_path,f) 17 | copy_file_path = os.path.join(dir_copy_path,f) 18 | shutil.copyfile(file_path, copy_file_path) 19 | break -------------------------------------------------------------------------------- /AdaPrune/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import models 12 | import torch.distributed as dist 13 | from data import DataRegime 14 | from utils.log import setup_logging, ResultsLog, save_checkpoint 15 | from utils.optim import OptimRegime 16 | from utils.cross_entropy import CrossEntropyLoss 17 | from utils.misc import torch_dtypes 18 | from utils.param_filter import FilterModules, is_bn 19 | from datetime import datetime 20 | from ast import literal_eval 21 | from trainer import Trainer 22 | 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Evaluation') 28 | parser.add_argument('evaluate', type=str, 29 | help='evaluate model FILE on validation set') 30 | parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results', 31 | help='results dir') 32 | parser.add_argument('--save', metavar='SAVE', default='', 33 | help='saved folder') 34 | parser.add_argument('--datasets-dir', metavar='DATASETS_DIR', default='~/Datasets', 35 | help='datasets dir') 36 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 37 | help='dataset name or folder') 38 | parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet', 39 | choices=model_names, 40 | help='model architecture: ' + 41 | ' | '.join(model_names) + 42 | ' (default: alexnet)') 43 | parser.add_argument('--input-size', type=int, default=None, 44 | help='image input size') 45 | parser.add_argument('--model-config', default='', 46 | help='additional architecture configuration') 47 | parser.add_argument('--dtype', default='float', 48 | help='type of tensor: ' + 49 | ' | '.join(torch_dtypes.keys()) + 50 | ' (default: float)') 51 | parser.add_argument('--device', default='cuda', 52 | help='device assignment ("cpu" or "cuda")') 53 | parser.add_argument('--device-ids', default=[0], type=int, nargs='+', 54 | help='device ids assignment (e.g 0 1 2 3') 55 | parser.add_argument('--world-size', default=-1, type=int, 56 | help='number of distributed processes') 57 | parser.add_argument('--local_rank', default=-1, type=int, 58 | help='rank of distributed processes') 59 | parser.add_argument('--dist-init', default='env://', type=str, 60 | help='init used to set up distributed 
training')
61 | parser.add_argument('--dist-backend', default='nccl', type=str,
62 |                     help='distributed backend')
63 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
64 |                     help='number of data loading workers (default: 8)')
65 | parser.add_argument('-b', '--batch-size', default=256, type=int,
66 |                     metavar='N', help='mini-batch size (default: 256)')
67 | parser.add_argument('--label-smoothing', default=0, type=float,
68 |                     help='label smoothing coefficient - default 0')
69 | parser.add_argument('--mixup', default=None, type=float,
70 |                     help='mixup alpha coefficient - default None')
71 | parser.add_argument('--duplicates', default=1, type=int,
72 |                     help='number of augmentations over a single example')
73 | parser.add_argument('--chunk-batch', default=1, type=int,
74 |                     help='chunk batch size for multiple passes (training)')
75 | parser.add_argument('--augment', action='store_true', default=False,
76 |                     help='perform augmentations')
77 | parser.add_argument('--cutout', action='store_true', default=False,
78 |                     help='cutout augmentations')
79 | parser.add_argument('--autoaugment', action='store_true', default=False,
80 |                     help='use autoaugment policies')
81 | parser.add_argument('--avg-out', action='store_true', default=False,
82 |                     help='average outputs')
83 | parser.add_argument('--print-freq', '-p', default=10, type=int,
84 |                     metavar='N', help='print frequency (default: 10)')
85 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
86 |                     help='path to latest checkpoint (default: none)')
87 |
88 | parser.add_argument('--seed', default=123, type=int,
89 |                     help='random seed (default: 123)')
90 |
91 |
92 | def main():
93 |     args = parser.parse_args()
94 |     main_worker(args)
95 |
96 |
97 | def main_worker(args):
98 |     global best_prec1, dtype
99 |     best_prec1 = 0
100 |     dtype = torch_dtypes.get(args.dtype)
101 |     torch.manual_seed(args.seed)
102 |     time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
103 |     if args.evaluate:
104 |         args.results_dir = '/tmp'
105 |     if args.save == '':
106 |         args.save = time_stamp
107 |     save_path = os.path.join(args.results_dir, args.save)
108 |
109 |     args.distributed = args.local_rank >= 0 or args.world_size > 1
110 |
111 |     if not os.path.exists(save_path) and not (args.distributed and args.local_rank > 0):
112 |         os.makedirs(save_path)
113 |
114 |     setup_logging(os.path.join(save_path, 'log.txt'),
115 |                   resume=args.resume != '',
116 |                   dummy=args.distributed and args.local_rank > 0)
117 |
118 |     results_path = os.path.join(save_path, 'results')
119 |     results = ResultsLog(
120 |         results_path, title='Training Results - %s' % args.save)
121 |
122 |     if 'cuda' in args.device and torch.cuda.is_available():
123 |         torch.cuda.manual_seed_all(args.seed)
124 |         torch.cuda.set_device(args.device_ids[0])
125 |         cudnn.benchmark = True
126 |     else:
127 |         args.device_ids = None
128 |
129 |     if not os.path.isfile(args.evaluate):
130 |         parser.error('invalid checkpoint: {}'.format(args.evaluate))
131 |     checkpoint = torch.load(args.evaluate, map_location="cpu")
132 |     # Override configuration with checkpoint info
133 |     args.model = checkpoint.get('model', args.model)
134 |     args.model_config = checkpoint.get('config', args.model_config)
135 |
136 |     logging.info("saving to %s", save_path)
137 |     logging.debug("run arguments: %s", args)
138 |     logging.info("creating model %s", args.model)
139 |
140 |     # create model
141 |     model = models.__dict__[args.model]
142 |     model_config = {'dataset': args.dataset}
143 |
144 |     if args.model_config != '':
145 |
model_config = dict(model_config, **literal_eval(args.model_config)) 146 | 147 | model = model(**model_config) 148 | logging.info("created model with configuration: %s", model_config) 149 | num_parameters = sum([l.nelement() for l in model.parameters()]) 150 | logging.info("number of parameters: %d", num_parameters) 151 | 152 | # load checkpoint 153 | model.load_state_dict(checkpoint['state_dict']) 154 | logging.info("loaded checkpoint '%s' (epoch %s)", 155 | args.evaluate, checkpoint['epoch']) 156 | 157 | # define loss function (criterion) and optimizer 158 | loss_params = {} 159 | if args.label_smoothing > 0: 160 | loss_params['smooth_eps'] = args.label_smoothing 161 | criterion = getattr(model, 'criterion', nn.NLLLoss)(**loss_params) 162 | criterion.to(args.device, dtype) 163 | model.to(args.device, dtype) 164 | 165 | # Batch-norm should always be done in float 166 | if 'half' in args.dtype: 167 | FilterModules(model, module=is_bn).to(dtype=torch.float) 168 | 169 | trainer = Trainer(model, criterion, 170 | device_ids=args.device_ids, device=args.device, dtype=dtype, 171 | mixup=args.mixup, print_freq=args.print_freq) 172 | 173 | # Evaluation Data loading code 174 | val_data = DataRegime(getattr(model, 'data_eval_regime', None), 175 | defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val', 'augment': args.augment, 176 | 'input_size': args.input_size, 'batch_size': args.batch_size, 'shuffle': False, 'duplicates': args.duplicates, 'autoaugment': args.autoaugment, 177 | 'cutout': {'holes': 1, 'length': 16} if args.cutout else None, 'num_workers': args.workers, 'pin_memory': True, 'drop_last': False}) 178 | 179 | results = trainer.validate(val_data.get_loader(), 180 | duplicates=val_data.get('duplicates'), 181 | average_output=args.avg_out) 182 | logging.info(results) 183 | return results 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /AdaPrune/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/batch_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import BatchNorm1d as _BatchNorm1d 4 | from torch.nn import BatchNorm2d as _BatchNorm2d 5 | from torch.nn import BatchNorm3d as _BatchNorm3d 6 | 7 | """ 8 | BatchNorm variants that can be disabled by removing all parameters and running stats 9 | """ 10 | 11 | 12 | def has_running_stats(m): 13 | return getattr(m, 'running_mean', None) is not None\ 14 | or getattr(m, 'running_var', None) is not None 15 | 16 | 17 | def has_parameters(m): 18 | return getattr(m, 'weight', None) is not None\ 19 | or getattr(m, 'bias', None) is not None 20 | 21 | 22 | class BatchNorm1d(_BatchNorm1d): 23 | def forward(self, inputs): 24 | if not (has_parameters(self) or has_running_stats(self)): 25 | return inputs 26 | return super(BatchNorm1d, self).forward(inputs) 27 | 28 | 29 | class BatchNorm2d(_BatchNorm2d): 30 | def forward(self, inputs): 31 | if not (has_parameters(self) or has_running_stats(self)): 32 | return inputs 33 | return super(BatchNorm2d, self).forward(inputs) 34 | 35 | 36 | class BatchNorm3d(_BatchNorm3d): 37 | def forward(self, inputs): 38 | if not (has_parameters(self) or has_running_stats(self)): 39 | return inputs 40 | return super(BatchNorm3d, 
self).forward(inputs) 41 | 42 | 43 | class MeanBatchNorm2d(nn.BatchNorm2d): 44 | """BatchNorm with mean-only normalization""" 45 | 46 | def __init__(self, num_features, momentum=0.1, bias=True): 47 | nn.Module.__init__(self) 48 | self.register_buffer('running_mean', torch.zeros(num_features)) 49 | self.momentum = momentum 50 | self.num_features = num_features 51 | if bias: 52 | self.bias = nn.Parameter(torch.zeros(num_features)) 53 | else: 54 | self.register_parameter('bias', None) 55 | 56 | def forward(self, x): 57 | if not (has_parameters(self) or has_running_stats(self)): 58 | return x 59 | if self.training: 60 | numel = x.size(0) * x.size(2) * x.size(3) 61 | mean = x.sum((0, 2, 3)) / numel 62 | with torch.no_grad(): 63 | self.running_mean.mul_(self.momentum)\ 64 | .add_(1 - self.momentum, mean) 65 | else: 66 | mean = self.running_mean 67 | if self.bias is not None: 68 | mean = mean - self.bias 69 | return x - mean.view(1, -1, 1, 1) 70 | 71 | def extra_repr(self): 72 | return '{num_features}, momentum={momentum}, bias={has_bias}'.format( 73 | has_bias=self.bias is not None, **self.__dict__) 74 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/birelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import InplaceFunction 3 | import torch.nn as nn 4 | 5 | 6 | class BiReLUFunction(InplaceFunction): 7 | 8 | @staticmethod 9 | def forward(ctx, input, inplace=False): 10 | if input.size(1) % 2 != 0: 11 | raise RuntimeError("dimension 1 of input must be multiple of 2, " 12 | "but got {}".format(input.size(1))) 13 | ctx.inplace = inplace 14 | 15 | if ctx.inplace: 16 | ctx.mark_dirty(input) 17 | output = input 18 | else: 19 | output = input.clone() 20 | 21 | pos, neg = output.chunk(2, dim=1) 22 | pos.clamp_(min=0) 23 | neg.clamp_(max=0) 24 | ctx.save_for_backward(output) 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | output, = ctx.saved_variables 30 | grad_input = grad_output.masked_fill(output.eq(0), 0) 31 | return grad_input, None 32 | 33 | 34 | def birelu(x, inplace=False): 35 | return BiReLUFunction().apply(x, inplace) 36 | 37 | 38 | class BiReLU(nn.Module): 39 | """docstring for BiReLU.""" 40 | 41 | def __init__(self, inplace=False): 42 | super(BiReLU, self).__init__() 43 | self.inplace = inplace 44 | 45 | def forward(self, inputs): 46 | return birelu(inputs, inplace=self.inplace) 47 | 48 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/bwn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Weight Normalization from https://arxiv.org/abs/1602.07868 3 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py 4 | """ 5 | import torch 6 | from torch.nn.parameter import Parameter 7 | from torch.autograd import Function 8 | import torch.nn as nn 9 | 10 | 11 | def _norm(x, dim, p=2): 12 | """Computes the norm over all dimensions except dim""" 13 | if p == -1: 14 | def func(x, dim): return x.max(dim=dim)[0] - x.min(dim=dim)[0] 15 | elif p == float('inf'): 16 | def func(x, dim): return x.max(dim=dim)[0] 17 | else: 18 | def func(x, dim): return torch.norm(x, dim=dim, p=p) 19 | if dim is None: 20 | return x.norm(p=p) 21 | elif dim == 0: 22 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 23 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 24 | elif dim 
== x.dim() - 1: 25 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 26 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 27 | else: 28 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 29 | 30 | 31 | def _mean(p, dim): 32 | """Computes the mean over all dimensions except dim""" 33 | if dim is None: 34 | return p.mean() 35 | elif dim == 0: 36 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 37 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 38 | elif dim == p.dim() - 1: 39 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 40 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 41 | else: 42 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 43 | 44 | 45 | class BoundedWeightNorm(object): 46 | 47 | def __init__(self, name, dim, p): 48 | self.name = name 49 | self.dim = dim 50 | 51 | def compute_weight(self, module): 52 | 53 | v = getattr(module, self.name + '_v') 54 | v.data.div_(_norm(v, self.dim)) 55 | init_norm = getattr(module, self.name + '_init_norm') 56 | return v * (init_norm / _norm(v, self.dim)) 57 | 58 | @staticmethod 59 | def apply(module, name, dim, p): 60 | fn = BoundedWeightNorm(name, dim, p) 61 | 62 | weight = getattr(module, name) 63 | 64 | # remove w from parameter list 65 | del module._parameters[name] 66 | module.register_buffer( 67 | name + '_init_norm', torch.Tensor([_norm(weight, dim, p=p).data.mean()])) 68 | module.register_parameter(name + '_v', Parameter(weight.data)) 69 | setattr(module, name, fn.compute_weight(module)) 70 | 71 | # recompute weight before every forward() 72 | module.register_forward_pre_hook(fn) 73 | return fn 74 | 75 | def remove(self, module): 76 | weight = self.compute_weight(module) 77 | delattr(module, self.name) 78 | del module._parameters[self.name + '_v'] 79 | module.register_parameter(self.name, Parameter(weight.data)) 80 | 81 | def __call__(self, module, inputs): 82 | setattr(module, self.name, self.compute_weight(module)) 83 | 84 | 85 | def weight_norm(module, name='weight', dim=0, p=2): 86 | r"""Applies weight normalization to a parameter in the given module. 87 | 88 | .. math:: 89 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} 90 | 91 | Weight normalization is a reparameterization that decouples the magnitude 92 | of a weight tensor from its direction. This replaces the parameter specified 93 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude 94 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). 95 | Weight normalization is implemented via a hook that recomputes the weight 96 | tensor from the magnitude and direction before every :meth:`~Module.forward` 97 | call. 98 | 99 | By default, with `dim=0`, the norm is computed independently per output 100 | channel/plane. To compute a norm over the entire weight tensor, use 101 | `dim=None`. 
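
    Note: unlike the standard ``torch.nn.utils.weight_norm``, this "bounded"
    variant does not learn a magnitude ``weight_g``; the norm is frozen at its
    initial value (the ``weight_init_norm`` buffer registered in
    ``BoundedWeightNorm.apply``), so only the direction ``weight_v`` is trained.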
102 | 103 | See https://arxiv.org/abs/1602.07868 104 | 105 | Args: 106 | module (nn.Module): containing module 107 | name (str, optional): name of weight parameter 108 | dim (int, optional): dimension over which to compute the norm 109 | 110 | Returns: 111 | The original module with the weight norm hook 112 | 113 | Example:: 114 | 115 | >>> m = weight_norm(nn.Linear(20, 40), name='weight') 116 | Linear (20 -> 40) 117 | >>> m.weight_g.size() 118 | torch.Size([40, 1]) 119 | >>> m.weight_v.size() 120 | torch.Size([40, 20]) 121 | 122 | """ 123 | BoundedWeightNorm.apply(module, name, dim, p) 124 | return module 125 | 126 | 127 | def remove_weight_norm(module, name='weight'): 128 | r"""Removes the weight normalization reparameterization from a module. 129 | 130 | Args: 131 | module (nn.Module): containing module 132 | name (str, optional): name of weight parameter 133 | 134 | Example: 135 | >>> m = weight_norm(nn.Linear(20, 40)) 136 | >>> remove_weight_norm(m) 137 | """ 138 | for k, hook in module._forward_pre_hooks.items(): 139 | if isinstance(hook, BoundedWeightNorm) and hook.name == name: 140 | hook.remove(module) 141 | del module._forward_pre_hooks[k] 142 | return module 143 | 144 | raise ValueError("weight_norm of '{}' not found in {}" 145 | .format(name, module)) 146 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 4 | 5 | 6 | class CheckpointModule(nn.Module): 7 | def __init__(self, module, num_segments=1): 8 | super(CheckpointModule, self).__init__() 9 | assert num_segments == 1 or isinstance(module, nn.Sequential) 10 | self.module = module 11 | self.num_segments = num_segments 12 | 13 | def forward(self, *inputs): 14 | if self.num_segments > 1: 15 | return checkpoint_sequential(self.module, self.num_segments, *inputs) 16 | else: 17 | return checkpoint(self.module, *inputs) 18 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/evolved_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from https://github.com/quark0/darts 3 | """ 4 | from collections import namedtuple 5 | import torch 6 | import torch.nn as nn 7 | 8 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 9 | 10 | OPS = { 11 | 'avg_pool_3x3': lambda channels, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 12 | 'max_pool_3x3': lambda channels, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 13 | 'skip_connect': lambda channels, stride, affine: Identity() if stride == 1 else FactorizedReduce(channels, channels, affine=affine), 14 | 'sep_conv_3x3': lambda channels, stride, affine: SepConv(channels, channels, 3, stride, 1, affine=affine), 15 | 'sep_conv_5x5': lambda channels, stride, affine: SepConv(channels, channels, 5, stride, 2, affine=affine), 16 | 'sep_conv_7x7': lambda channels, stride, affine: SepConv(channels, channels, 7, stride, 3, affine=affine), 17 | 'dil_conv_3x3': lambda channels, stride, affine: DilConv(channels, channels, 3, stride, 2, 2, affine=affine), 18 | 'dil_conv_5x5': lambda channels, stride, affine: DilConv(channels, channels, 5, stride, 4, 2, affine=affine), 19 | 'conv_7x1_1x7': lambda channels, stride, affine: nn.Sequential( 20 | 
nn.ReLU(inplace=False), 21 | nn.Conv2d(channels, channels, (1, 7), stride=(1, stride), 22 | padding=(0, 3), bias=False), 23 | nn.Conv2d(channels, channels, (7, 1), stride=(stride, 1), 24 | padding=(3, 0), bias=False), 25 | nn.BatchNorm2d(channels, affine=affine) 26 | ), 27 | } 28 | 29 | 30 | # genotypes 31 | GENOTYPES = dict( 32 | NASNet=Genotype( 33 | normal=[ 34 | ('sep_conv_5x5', 1), 35 | ('sep_conv_3x3', 0), 36 | ('sep_conv_5x5', 0), 37 | ('sep_conv_3x3', 0), 38 | ('avg_pool_3x3', 1), 39 | ('skip_connect', 0), 40 | ('avg_pool_3x3', 0), 41 | ('avg_pool_3x3', 0), 42 | ('sep_conv_3x3', 1), 43 | ('skip_connect', 1), 44 | ], 45 | normal_concat=[2, 3, 4, 5, 6], 46 | reduce=[ 47 | ('sep_conv_5x5', 1), 48 | ('sep_conv_7x7', 0), 49 | ('max_pool_3x3', 1), 50 | ('sep_conv_7x7', 0), 51 | ('avg_pool_3x3', 1), 52 | ('sep_conv_5x5', 0), 53 | ('skip_connect', 3), 54 | ('avg_pool_3x3', 2), 55 | ('sep_conv_3x3', 2), 56 | ('max_pool_3x3', 1), 57 | ], 58 | reduce_concat=[4, 5, 6], 59 | ), 60 | 61 | AmoebaNet=Genotype( 62 | normal=[ 63 | ('avg_pool_3x3', 0), 64 | ('max_pool_3x3', 1), 65 | ('sep_conv_3x3', 0), 66 | ('sep_conv_5x5', 2), 67 | ('sep_conv_3x3', 0), 68 | ('avg_pool_3x3', 3), 69 | ('sep_conv_3x3', 1), 70 | ('skip_connect', 1), 71 | ('skip_connect', 0), 72 | ('avg_pool_3x3', 1), 73 | ], 74 | normal_concat=[4, 5, 6], 75 | reduce=[ 76 | ('avg_pool_3x3', 0), 77 | ('sep_conv_3x3', 1), 78 | ('max_pool_3x3', 0), 79 | ('sep_conv_7x7', 2), 80 | ('sep_conv_7x7', 0), 81 | ('avg_pool_3x3', 1), 82 | ('max_pool_3x3', 0), 83 | ('max_pool_3x3', 1), 84 | ('conv_7x1_1x7', 0), 85 | ('sep_conv_3x3', 5), 86 | ], 87 | reduce_concat=[3, 4, 6] 88 | ), 89 | 90 | DARTS_V1=Genotype( 91 | normal=[ 92 | ('sep_conv_3x3', 1), 93 | ('sep_conv_3x3', 0), 94 | ('skip_connect', 0), 95 | ('sep_conv_3x3', 1), 96 | ('skip_connect', 0), 97 | ('sep_conv_3x3', 1), 98 | ('sep_conv_3x3', 0), 99 | ('skip_connect', 2)], 100 | normal_concat=[2, 3, 4, 5], 101 | reduce=[('max_pool_3x3', 0), 102 | ('max_pool_3x3', 1), 103 | ('skip_connect', 2), 104 | ('max_pool_3x3', 0), 105 | ('max_pool_3x3', 0), 106 | ('skip_connect', 2), 107 | ('skip_connect', 2), 108 | ('avg_pool_3x3', 0)], 109 | reduce_concat=[2, 3, 4, 5]), 110 | DARTS=Genotype(normal=[('sep_conv_3x3', 0), 111 | ('sep_conv_3x3', 1), 112 | ('sep_conv_3x3', 0), 113 | ('sep_conv_3x3', 1), 114 | ('sep_conv_3x3', 1), 115 | ('skip_connect', 0), 116 | ('skip_connect', 0), 117 | ('dil_conv_3x3', 2)], 118 | normal_concat=[2, 3, 4, 5], 119 | reduce=[('max_pool_3x3', 0), 120 | ('max_pool_3x3', 1), 121 | ('skip_connect', 2), 122 | ('max_pool_3x3', 1), 123 | ('max_pool_3x3', 0), 124 | ('skip_connect', 2), 125 | ('skip_connect', 2), 126 | ('max_pool_3x3', 1)], 127 | reduce_concat=[2, 3, 4, 5]), 128 | ) 129 | 130 | 131 | class ReLUConvBN(nn.Module): 132 | 133 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 134 | super(ReLUConvBN, self).__init__() 135 | self.op = nn.Sequential( 136 | nn.ReLU(inplace=False), 137 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, 138 | padding=padding, bias=False), 139 | nn.BatchNorm2d(C_out, affine=affine) 140 | ) 141 | 142 | def forward(self, x): 143 | return self.op(x) 144 | 145 | 146 | class DilConv(nn.Module): 147 | 148 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 149 | super(DilConv, self).__init__() 150 | self.op = nn.Sequential( 151 | nn.ReLU(inplace=False), 152 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, 153 | padding=padding, dilation=dilation, groups=C_in, bias=False), 
154 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 155 | nn.BatchNorm2d(C_out, affine=affine), 156 | ) 157 | 158 | def forward(self, x): 159 | return self.op(x) 160 | 161 | 162 | class SepConv(nn.Module): 163 | 164 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 165 | super(SepConv, self).__init__() 166 | self.op = nn.Sequential( 167 | nn.ReLU(inplace=False), 168 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, 169 | padding=padding, groups=C_in, bias=False), 170 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 171 | nn.BatchNorm2d(C_in, affine=affine), 172 | nn.ReLU(inplace=False), 173 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, 174 | padding=padding, groups=C_in, bias=False), 175 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 176 | nn.BatchNorm2d(C_out, affine=affine), 177 | ) 178 | 179 | def forward(self, x): 180 | return self.op(x) 181 | 182 | 183 | class Identity(nn.Module): 184 | 185 | def __init__(self): 186 | super(Identity, self).__init__() 187 | 188 | def forward(self, x): 189 | return x 190 | 191 | 192 | class FactorizedReduce(nn.Module): 193 | 194 | def __init__(self, C_in, C_out, affine=True): 195 | super(FactorizedReduce, self).__init__() 196 | assert C_out % 2 == 0 197 | self.relu = nn.ReLU(inplace=False) 198 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, 199 | stride=2, padding=0, bias=False) 200 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, 201 | stride=2, padding=0, bias=False) 202 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 203 | 204 | def forward(self, x): 205 | x = self.relu(x) 206 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) 207 | out = self.bn(out) 208 | return out 209 | 210 | 211 | def drop_path(x, drop_prob): 212 | if drop_prob > 0.: 213 | keep_prob = 1.-drop_prob 214 | mask = x.new(x.size(0), 1, 1, 1).bernoulli_(keep_prob) 215 | x.div_(keep_prob) 216 | x.mul_(mask) 217 | return x 218 | 219 | 220 | class Cell(nn.Module): 221 | 222 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): 223 | super(Cell, self).__init__() 224 | if reduction_prev: 225 | self.preprocess0 = FactorizedReduce(C_prev_prev, C) 226 | else: 227 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) 228 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) 229 | 230 | if reduction: 231 | op_names, indices = zip(*genotype.reduce) 232 | concat = genotype.reduce_concat 233 | else: 234 | op_names, indices = zip(*genotype.normal) 235 | concat = genotype.normal_concat 236 | self._compile(C, op_names, indices, concat, reduction) 237 | 238 | def _compile(self, C, op_names, indices, concat, reduction): 239 | assert len(op_names) == len(indices) 240 | self._steps = len(op_names) // 2 241 | self._concat = concat 242 | self.multiplier = len(concat) 243 | 244 | self._ops = nn.ModuleList() 245 | for name, index in zip(op_names, indices): 246 | stride = 2 if reduction and index < 2 else 1 247 | op = OPS[name](C, stride, True) 248 | self._ops += [op] 249 | self._indices = indices 250 | 251 | def forward(self, s0, s1, drop_prob): 252 | s0 = self.preprocess0(s0) 253 | s1 = self.preprocess1(s1) 254 | 255 | states = [s0, s1] 256 | for i in range(self._steps): 257 | h1 = states[self._indices[2*i]] 258 | h2 = states[self._indices[2*i+1]] 259 | op1 = self._ops[2*i] 260 | op2 = self._ops[2*i+1] 261 | h1 = op1(h1) 262 | h2 = op2(h2) 263 | if self.training and drop_prob > 0.: 264 | if not isinstance(op1, Identity): 265 | h1 = drop_path(h1, drop_prob) 
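# (note) drop_path above implements stochastic depth at the granularity of a
# candidate branch: with probability drop_prob the whole branch output is
# zeroed for a given sample, and survivors are rescaled by 1/keep_prob so the
# expected magnitude of h1 is preserved; Identity ops are exempt, so plain
# skip-connections are never dropped.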
266 | if not isinstance(op2, Identity): 267 | h2 = drop_path(h2, drop_prob) 268 | s = h1 + h2 269 | states += [s] 270 | return torch.cat([states[i] for i in self._concat], dim=1) 271 | 272 | 273 | class NasNetCell(Cell): 274 | def __init__(self, *kargs, **kwargs): 275 | super(NasNetCell, self).__init__(GENOTYPES['NASNet'], *kargs, **kwargs) 276 | 277 | 278 | class AmoebaNetCell(Cell): 279 | def __init__(self, *kargs, **kwargs): 280 | super(AmoebaNetCell, self).__init__( 281 | GENOTYPES['AmoebaNet'], *kargs, **kwargs) 282 | 283 | 284 | class DARTSCell(Cell): 285 | def __init__(self, *kargs, **kwargs): 286 | super(DARTSCell, self).__init__(GENOTYPES['DARTS'], *kargs, **kwargs) 287 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/fixed_proj.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | from torch.autograd import Variable 5 | from scipy.linalg import hadamard 6 | 7 | class HadamardProj(nn.Module): 8 | 9 | def __init__(self, input_size, output_size, bias=True, fixed_weights=True, fixed_scale=None): 10 | super(HadamardProj, self).__init__() 11 | self.output_size = output_size 12 | self.input_size = input_size 13 | sz = 2 ** int(math.ceil(math.log(max(input_size, output_size), 2))) 14 | mat = torch.from_numpy(hadamard(sz)) 15 | if fixed_weights: 16 | self.proj = Variable(mat, requires_grad=False) 17 | else: 18 | self.proj = nn.Parameter(mat) 19 | 20 | init_scale = 1. / math.sqrt(self.output_size) 21 | 22 | if fixed_scale is not None: 23 | self.scale = Variable(torch.Tensor( 24 | [fixed_scale]), requires_grad=False) 25 | else: 26 | self.scale = nn.Parameter(torch.Tensor([init_scale])) 27 | 28 | if bias: 29 | self.bias = nn.Parameter(torch.Tensor( 30 | output_size).uniform_(-init_scale, init_scale)) 31 | else: 32 | self.register_parameter('bias', None) 33 | 34 | self.eps = 1e-8 35 | 36 | def forward(self, x): 37 | if not isinstance(self.scale, nn.Parameter): 38 | self.scale = self.scale.type_as(x) 39 | x = x / (x.norm(2, -1, keepdim=True) + self.eps) 40 | w = self.proj.type_as(x) 41 | 42 | out = -self.scale * \ 43 | nn.functional.linear(x, w[:self.output_size, :self.input_size]) 44 | if self.bias is not None: 45 | out = out + self.bias.view(1, -1) 46 | return out 47 | 48 | 49 | class Proj(nn.Module): 50 | 51 | def __init__(self, input_size, output_size, bias=True, init_scale=10): 52 | super(Proj, self).__init__() 53 | if init_scale is not None: 54 | self.weight = nn.Parameter(torch.Tensor(1).fill_(init_scale)) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(output_size).fill_(0)) 57 | self.proj = Variable(torch.Tensor( 58 | output_size, input_size), requires_grad=False) 59 | torch.manual_seed(123) 60 | nn.init.orthogonal(self.proj) 61 | 62 | def forward(self, x): 63 | w = self.proj.type_as(x) 64 | x = x / x.norm(2, -1, keepdim=True) 65 | out = nn.functional.linear(x, w) 66 | if hasattr(self, 'weight'): 67 | out = out * self.weight 68 | if hasattr(self, 'bias'): 69 | out = out + self.bias.view(1, -1) 70 | return out 71 | 72 | class LinearFixed(nn.Linear): 73 | 74 | def __init__(self, input_size, output_size, bias=True, init_scale=10): 75 | super(LinearFixed, self).__init__(input_size, output_size, bias) 76 | self.scale = nn.Parameter(torch.Tensor(1).fill_(init_scale)) 77 | 78 | def forward(self, x): 79 | w = self.weight / self.weight.norm(2, -1, keepdim=True) 80 | x = x / x.norm(2, -1, keepdim=True) 81 | out = 
nn.functional.linear(x, w, self.bias) 82 | return out 83 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/fixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def _sum_tensor_scalar(tensor, scalar, expand_size): 6 | if scalar is not None: 7 | scalar = scalar.expand(expand_size).contiguous() 8 | else: 9 | return tensor 10 | if tensor is None: 11 | return scalar 12 | return tensor + scalar 13 | 14 | 15 | class ZIConv2d(nn.Conv2d): 16 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 17 | padding=0, dilation=1, groups=1, bias=False, 18 | multiplier=False, pre_bias=True, post_bias=True): 19 | super(ZIConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 20 | padding, dilation, groups, bias) 21 | if pre_bias: 22 | self.pre_bias = nn.Parameter(torch.tensor([0.])) 23 | else: 24 | self.register_parameter('pre_bias', None) 25 | if post_bias: 26 | self.post_bias = nn.Parameter(torch.tensor([0.])) 27 | else: 28 | self.register_parameter('post_bias', None) 29 | if multiplier: 30 | self.multiplier = nn.Parameter(torch.tensor([1.])) 31 | else: 32 | self.register_parameter('multiplier', None) 33 | 34 | def forward(self, x): 35 | if self.pre_bias is not None: 36 | x = x + self.pre_bias 37 | weight = self.weight if self.multiplier is None\ 38 | else self.weight * self.multiplier 39 | bias = _sum_tensor_scalar(self.bias, self.post_bias, self.out_channels) 40 | return nn.functional.conv2d(x, weight, bias, self.stride, 41 | self.padding, self.dilation, self.groups) 42 | 43 | 44 | class ZILinear(nn.Linear): 45 | def __init__(self, in_features, out_features, bias=False, 46 | multiplier=False, pre_bias=True, post_bias=True): 47 | super(ZILinear, self).__init__(in_features, out_features, bias) 48 | if pre_bias: 49 | self.pre_bias = nn.Parameter(torch.tensor([0.])) 50 | else: 51 | self.register_parameter('pre_bias', None) 52 | if post_bias: 53 | self.post_bias = nn.Parameter(torch.tensor([0.])) 54 | else: 55 | self.register_parameter('post_bias', None) 56 | if multiplier: 57 | self.multiplier = nn.Parameter(torch.tensor([1.])) 58 | else: 59 | self.register_parameter('multiplier', None) 60 | 61 | def forward(self, x): 62 | if self.pre_bias is not None: 63 | x = x + self.pre_bias 64 | weight = self.weight if self.multiplier is None\ 65 | else self.weight * self.multiplier 66 | bias = _sum_tensor_scalar(self.bias, self.post_bias, self.out_features) 67 | return nn.functional.linear(x, weight, bias) 68 | -------------------------------------------------------------------------------- /AdaPrune/models/modules/se.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SEBlock(nn.Module): 5 | def __init__(self, in_channels, out_channels=None, ratio=16): 6 | super(SEBlock, self).__init__() 7 | self.in_channels = in_channels 8 | if out_channels is None: 9 | out_channels = in_channels 10 | self.ratio = ratio 11 | self.relu = nn.ReLU(True) 12 | self.global_pool = nn.AdaptiveAvgPool2d(1) 13 | self.transform = nn.Sequential( 14 | nn.Linear(in_channels, in_channels // ratio), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(in_channels // ratio, out_channels), 17 | nn.Sigmoid() 18 | ) 19 | 20 | def forward(self, x): 21 | x_avg = self.global_pool(x).view(x.size(0), -1) 22 | mask = self.transform(x_avg) 23 | return x * mask.view(x.size(0), -1, 1, 1) 24 | 25 | 
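A minimal usage sketch for the SEBlock above (the channel count 64 and the input shape are illustrative assumptions, not values taken from this repo): the block pools each channel to a scalar, passes it through a bottleneck MLP (64 -> 4 -> 64 for ratio=16) ending in a sigmoid, and rescales the input channel-wise with the resulting gate.

import torch
from models.modules.se import SEBlock  # import path per the tree above, assuming the repo root is on PYTHONPATH

se = SEBlock(in_channels=64, ratio=16)
x = torch.randn(2, 64, 56, 56)   # (batch, channels, height, width)
y = se(x)                        # each channel scaled by a gate in (0, 1)
assert y.shape == x.shape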
--------------------------------------------------------------------------------
/AdaPrune/preprocess.py:
1 | import torch
2 | import numpy as np
3 | import torchvision.transforms as transforms
4 | import random
5 | import PIL
6 |
7 |
8 | _IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406],
9 |                    'std': [0.229, 0.224, 0.225]}
10 |
11 | _IMAGENET_PCA = {
12 |     'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 |     'eigvec': torch.Tensor([
14 |         [-0.5675, 0.7192, 0.4009],
15 |         [-0.5808, -0.0045, -0.8140],
16 |         [-0.5836, -0.6948, 0.4203],
17 |     ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, num_crops=1, normalize=_IMAGENET_STATS):
22 |     assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}"
23 |     convert_tensor = transforms.Compose([transforms.ToTensor(),
24 |                                          transforms.Normalize(**normalize)])
25 |     if num_crops == 1:
26 |         t_list = [
27 |             transforms.CenterCrop(input_size),
28 |             convert_tensor
29 |         ]
30 |     else:
31 |         if num_crops == 5:
32 |             t_list = [transforms.FiveCrop(input_size)]
33 |         elif num_crops == 10:
34 |             t_list = [transforms.TenCrop(input_size)]
35 |         # returns a 4D tensor
36 |         t_list.append(transforms.Lambda(lambda crops:
37 |                                         torch.stack([convert_tensor(crop) for crop in crops])))
38 |
39 |     if scale_size != input_size:
40 |         t_list = [transforms.Resize(scale_size)] + t_list
41 |
42 |     return transforms.Compose(t_list)
43 |
44 |
45 | def scale_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
46 |     t_list = [
47 |         transforms.RandomCrop(input_size),
48 |         transforms.ToTensor(),
49 |         transforms.Normalize(**normalize),
50 |     ]
51 |     if scale_size != input_size:
52 |         t_list = [transforms.Resize(scale_size)] + t_list
53 |
54 |     return transforms.Compose(t_list)
55 |
56 |
57 | def pad_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
58 |     padding = int((scale_size - input_size) / 2)
59 |     return transforms.Compose([
60 |         transforms.RandomCrop(input_size, padding=padding),
61 |         transforms.RandomHorizontalFlip(),
62 |         transforms.ToTensor(),
63 |         transforms.Normalize(**normalize),
64 |     ])
65 |
66 |
67 |
68 | def inception_preproccess(input_size, normalize=_IMAGENET_STATS):
69 |     return transforms.Compose([
70 |         transforms.RandomResizedCrop(input_size),
71 |         transforms.RandomHorizontalFlip(),
72 |         transforms.ToTensor(),
73 |         transforms.Normalize(**normalize)
74 |     ])
75 |
76 |
77 | def inception_color_preproccess(input_size, normalize=_IMAGENET_STATS):
78 |     return transforms.Compose([
79 |         transforms.RandomResizedCrop(input_size),
80 |         transforms.RandomHorizontalFlip(),
81 |         transforms.ColorJitter(
82 |             brightness=0.4,
83 |             contrast=0.4,
84 |             saturation=0.4,
85 |         ),
86 |         transforms.ToTensor(),
87 |         Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
88 |         transforms.Normalize(**normalize)
89 |     ])
90 |
91 |
92 | def multi_transform(transform_fn, duplicates=1, dim=0):
93 |     """performs multiple transforms, useful to implement inference time augmentation or
94 |     "batch augmentation" from https://openreview.net/forum?id=H1V4QhAqYQ&noteId=BylUSs_3Y7
95 |     """
96 |     if duplicates > 1:
97 |         return transforms.Lambda(lambda x: torch.stack([transform_fn(x) for _ in range(duplicates)], dim=dim))
98 |     else:
99 |         return transform_fn
100 |
101 |
102 | def get_transform(transform_name='imagenet', input_size=None, scale_size=None,
103 |                   normalize=None, augment=True, cutout=None, autoaugment=False,
104 |                   duplicates=1, num_crops=1):
105 |     normalize = normalize or _IMAGENET_STATS
106 |
transform_fn = None 107 | 108 | if 'imagenet' in transform_name: # inception augmentation is default for imagenet 109 | scale_size = scale_size or 256 110 | input_size = input_size or 224 111 | if augment: 112 | transform_fn = inception_preproccess(input_size, 113 | normalize=normalize) 114 | else: 115 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 116 | num_crops=num_crops, normalize=normalize) 117 | elif 'cifar' in transform_name: # resnet augmentation is default for imagenet 118 | input_size = input_size or 32 119 | if augment: 120 | scale_size = scale_size or 40 121 | if autoaugment: 122 | transform_fn = cifar_autoaugment(input_size, scale_size=scale_size, 123 | normalize=normalize) 124 | else: 125 | transform_fn = pad_random_crop(input_size, scale_size=scale_size, 126 | normalize=normalize) 127 | else: 128 | scale_size = scale_size or 32 129 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 130 | num_crops=num_crops, normalize=normalize) 131 | elif transform_name == 'mnist': 132 | normalize = {'mean': [0.5], 'std': [0.5]} 133 | input_size = input_size or 28 134 | if augment: 135 | scale_size = scale_size or 32 136 | transform_fn = pad_random_crop(input_size, scale_size=scale_size, 137 | normalize=normalize) 138 | else: 139 | scale_size = scale_size or 32 140 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 141 | num_crops=num_crops, normalize=normalize) 142 | if cutout is not None: 143 | transform_fn.transforms.append(Cutout(**cutout)) 144 | return multi_transform(transform_fn, duplicates) 145 | 146 | 147 | class Lighting(object): 148 | """Lighting noise(AlexNet - style PCA - based noise)""" 149 | 150 | def __init__(self, alphastd, eigval, eigvec): 151 | self.alphastd = alphastd 152 | self.eigval = eigval 153 | self.eigvec = eigvec 154 | 155 | def __call__(self, img): 156 | if self.alphastd == 0: 157 | return img 158 | 159 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 160 | rgb = self.eigvec.type_as(img).clone()\ 161 | .mul(alpha.view(1, 3).expand(3, 3))\ 162 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 163 | .sum(1).squeeze() 164 | 165 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 166 | 167 | 168 | class Cutout(object): 169 | """ 170 | Randomly mask out one or more patches from an image. 171 | taken from https://github.com/uoguelph-mlrg/Cutout 172 | 173 | 174 | Args: 175 | holes (int): Number of patches to cut out of each image. 176 | length (int): The length (in pixels) of each square patch. 177 | """ 178 | 179 | def __init__(self, holes, length): 180 | self.holes = holes 181 | self.length = length 182 | 183 | def __call__(self, img): 184 | """ 185 | Args: 186 | img (Tensor): Tensor image of size (C, H, W). 187 | Returns: 188 | Tensor: Image with holes of dimension length x length cut out of it. 189 | """ 190 | h = img.size(1) 191 | w = img.size(2) 192 | 193 | mask = np.ones((h, w), np.float32) 194 | 195 | for n in range(self.holes): 196 | y = np.random.randint(h) 197 | x = np.random.randint(w) 198 | 199 | y1 = np.clip(y - self.length // 2, 0, h) 200 | y2 = np.clip(y + self.length // 2, 0, h) 201 | x1 = np.clip(x - self.length // 2, 0, w) 202 | x2 = np.clip(x + self.length // 2, 0, w) 203 | 204 | mask[y1: y2, x1: x2] = 0. 
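# (note) because y1/y2 and x1/x2 are clipped to the image bounds above, a hole
# centered near a border cuts out fewer than length x length pixels.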
205 | 206 | mask = torch.from_numpy(mask) 207 | mask = mask.expand_as(img) 208 | img = img * mask 209 | 210 | return img 211 | -------------------------------------------------------------------------------- /AdaPrune/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | bokeh 4 | pandas 5 | -------------------------------------------------------------------------------- /AdaPrune/scripts/adaprune_dense_bnt.sh: -------------------------------------------------------------------------------- 1 | export datasets_dir=/home/Datasets 2 | export model=${1:-"resnet"} 3 | export model_vis=${2:-"resnet50"} 4 | export depth=${3:-50} 5 | export adaprune_suffix='' 6 | if [ "$5" = True ]; then 7 | export adaprune_suffix='.adaprune' 8 | fi 9 | export workdir='dense_'${model_vis}$adaprune_suffix 10 | export perC=True 11 | 12 | 13 | echo ./results/$workdir/resnet 14 | #Download and absorb_bn resnet50 and 15 | python main.py --model $model --save $workdir -b 128 -lfv $model_vis --model-config "{'batch_norm': False,'depth':$depth}" --device-id 1 16 | 17 | # Run adaprune to minimize MSE of the output with respect to a perturations in parameters 18 | python main.py --optimize-weights --model $model -b 200 --evaluate results/$workdir/$model.absorb_bn --model-config "{'batch_norm': False,'depth':$depth}" --dataset imagenet_calib --datasets-dir $datasets_dir --adaprune --prune_bs 8 --prune_topk 4 --device-id 0 --keep_first_last #--unstructured --sparsity_level 0.5 19 | python main.py --batch-norn-tuning --model $model -lfv $model_vis -b 200 --evaluate results/$workdir/$model.absorb_bn.adaprune --model-config "{'batch_norm': False,'depth':$depth}" --dataset imagenet_calib --datasets-dir $datasets_dir --device-id 0 20 | 21 | -------------------------------------------------------------------------------- /AdaPrune/scripts/adaprune_sparse.sh: -------------------------------------------------------------------------------- 1 | export datasets_dir=/home/Datasets 2 | export model=${1:-"resnet"} 3 | export model_vis=${2:-"resnet50"} 4 | export depth=${3:-50} 5 | export adaprune_suffix='.adaprune' 6 | 7 | export workdir='sparse_'${model_vis}$adaprune_suffix 8 | mkdir ./results/$workdir 9 | echo ./results/$workdir/resnet 10 | 11 | #copy sparse model to workdir 12 | cp ./results/resnet50/model_best.pth.tar ./results/$workdir/resnet 13 | 14 | # Run adaprune to minimize MSE of the output with respect to a small perturations in parameters 15 | python main.py --optimize-weights --model $model -b 200 --evaluate results/$workdir/$model --model-config "{'batch_norm': True,'depth':$depth}" --dataset imagenet_calib --datasets-dir $datasets_dir --adaprune --prune_bs 4 --prune_topk 2 16 | 17 | -------------------------------------------------------------------------------- /AdaPrune/utils/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Elad Hoffer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/AdaPrune/utils/absorb_bn.py:
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 | # from efficientnet_pytorch.utils import Conv2dSamePadding
5 |
6 | def remove_bn_params(bn_module):
7 |     bn_module.register_buffer('running_mean', None)
8 |     bn_module.register_buffer('running_var', None)
9 |     bn_module.register_parameter('weight', None)
10 |     bn_module.register_parameter('bias', None)
11 |
12 |
13 | def init_bn_params(bn_module):
14 |     bn_module.running_mean.fill_(0)
15 |     bn_module.running_var.fill_(1)
16 |     if bn_module.affine:
17 |         bn_module.weight.fill_(1)
18 |         bn_module.bias.fill_(0)
19 |
20 |
21 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False):
22 |     with torch.no_grad():
23 |         w = module.weight
24 |         if module.bias is None:
25 |             zeros = torch.zeros(module.out_channels,
26 |                                 dtype=w.dtype, device=w.device)
27 |             bias = nn.Parameter(zeros)
28 |             module.register_parameter('bias', bias)
29 |         b = module.bias
30 |
31 |         if hasattr(bn_module, 'running_mean'):
32 |             b.add_(-bn_module.running_mean)
33 |         if hasattr(bn_module, 'running_var'):
34 |             invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)  # 1/sqrt(var + eps): the BN folding scale
35 |             w.mul_(invstd.view(w.size(0), 1, 1, 1))
36 |             b.mul_(invstd)
37 |             if hasattr(module, 'quantize_weight'):
38 |                 module.quantize_weight.running_range.mul_(invstd.view(w.size(0), 1, 1, 1))
39 |                 module.quantize_weight.running_zero_point.mul_(invstd.view(w.size(0), 1, 1, 1))
40 |
41 |         if hasattr(bn_module, 'weight'):
42 |             w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
43 |             b.mul_(bn_module.weight)
44 |             module.register_parameter('gamma', nn.Parameter(bn_module.weight.data.clone()))
45 |             if hasattr(module, 'quantize_weight'):
46 |                 module.quantize_weight.running_range.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
47 |                 module.quantize_weight.running_zero_point.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
48 |         if hasattr(bn_module, 'bias'):
49 |             b.add_(bn_module.bias)
50 |             module.register_parameter('beta', nn.Parameter(bn_module.bias.data.clone()))
51 |
52 |         if remove_bn:
53 |             remove_bn_params(bn_module)
54 |         else:
55 |             init_bn_params(bn_module)
56 |
57 |         if verbose:
58 |             logging.info('BN module %s was absorbed into layer %s' %
59 |                          (bn_module, module))
60 |
61 |
62 | def is_bn(m):
63 |     return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
64 |
65 |
66 | def is_absorbing(m):
67 |     return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear)
68 |
69 |
70 | def search_absorbe_bn(model, prev=None, remove_bn=True, verbose=False):
71 |     with torch.no_grad():
72 |         for m in model.children():
73 |             if is_bn(m) and is_absorbing(prev):
74 |                 # print(prev,m)
75 |                 absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose)
76 |             search_absorbe_bn(m, remove_bn=remove_bn, verbose=verbose)
77 |             prev = m
78 |
79 |
80
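# The helpers below handle models whose BatchNorm layers were swapped for a
# Lambda placeholder ("fake BN"): absorb_fake_bn only materializes a zero bias
# on the conv; add_bn re-attaches a real BatchNorm2d initialized from the
# stored gamma/beta so it can be re-tuned on calibration data; and
# search_absorbe_tuning_bn folds the tuned BN back into the conv and restores
# the original forward.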
| def absorb_fake_bn(module, bn_module, verbose=False): 81 | with torch.no_grad(): 82 | w = module.weight 83 | if module.bias is None: 84 | zeros = torch.zeros(module.out_channels, 85 | dtype=w.dtype, device=w.device) 86 | bias = nn.Parameter(zeros) 87 | module.register_parameter('bias', bias) 88 | 89 | if verbose: 90 | logging.info('BN module %s was asborbed into layer %s' % 91 | (bn_module, module)) 92 | 93 | 94 | def is_fake_bn(m): 95 | from models.resnet import Lambda 96 | return isinstance(m, Lambda) 97 | 98 | 99 | def search_absorbe_fake_bn(model, prev=None, remove_bn=True, verbose=False): 100 | with torch.no_grad(): 101 | for m in model.children(): 102 | if is_fake_bn(m) and is_absorbing(prev): 103 | # print(prev,m) 104 | absorb_fake_bn(prev, m, verbose=verbose) 105 | search_absorbe_fake_bn(m, remove_bn=remove_bn, verbose=verbose) 106 | prev = m 107 | 108 | 109 | def add_bn(module, bn_module, verbose=False): 110 | bn = nn.BatchNorm2d(module.out_channels) 111 | 112 | def bn_forward(bn, x): 113 | res = bn(x) 114 | return res 115 | 116 | bn_module.forward_orig = bn_module.forward 117 | bn_module.forward = lambda x: bn_forward(bn, x) 118 | bn.to(module.weight.device) 119 | 120 | bn.register_buffer('running_var', module.gamma**2) 121 | bn.register_buffer('running_mean', module.beta.clone()) 122 | bn.register_parameter('weight', nn.Parameter(torch.sqrt(bn.running_var + bn.eps))) 123 | bn.register_parameter('bias', nn.Parameter(bn.running_mean.clone())) 124 | 125 | bn_module.bn = bn 126 | 127 | 128 | def need_tuning(module): 129 | return hasattr(module, 'num_bits') #and module.groups == 1 130 | 131 | 132 | def search_add_bn(model, prev=None, remove_bn=True, verbose=False): 133 | with torch.no_grad(): 134 | for m in model.children(): 135 | if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev): 136 | # print(prev,m) 137 | add_bn(prev, m, verbose=verbose) 138 | search_add_bn(m, remove_bn=remove_bn, verbose=verbose) 139 | prev = m 140 | 141 | 142 | def search_absorbe_tuning_bn(model, prev=None, remove_bn=True, verbose=False): 143 | with torch.no_grad(): 144 | for m in model.children(): 145 | if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev): 146 | # print(prev,m) 147 | absorb_bn(prev, m.bn, remove_bn=remove_bn, verbose=verbose) 148 | m.forward = m.forward_orig 149 | m.bn = None 150 | search_absorbe_tuning_bn(m, remove_bn=remove_bn, verbose=verbose) 151 | prev = m 152 | 153 | 154 | def copy_bn_params(module, bn_module, remove_bn=True, verbose=False): 155 | with torch.no_grad(): 156 | if hasattr(bn_module, 'weight'): 157 | module.register_parameter('gamma', nn.Parameter(bn_module.weight.data.clone())) 158 | 159 | if hasattr(bn_module, 'bias'): 160 | module.register_parameter('beta', nn.Parameter(bn_module.bias.data.clone())) 161 | 162 | 163 | def search_copy_bn_params(model, prev=None, remove_bn=True, verbose=False): 164 | with torch.no_grad(): 165 | for m in model.children(): 166 | if is_bn(m) and is_absorbing(prev): 167 | # print(prev,m) 168 | copy_bn_params(prev, m, remove_bn=remove_bn, verbose=verbose) 169 | search_copy_bn_params(m, remove_bn=remove_bn, verbose=verbose) 170 | prev = m 171 | 172 | 173 | # def recalibrate_bn(module, bn_module, verbose=False): 174 | # bn = bn_module.bn 175 | # bn.register_parameter('weight', nn.Parameter(torch.sqrt(bn.running_var + bn.eps))) 176 | # bn.register_parameter('bias', nn.Parameter(bn.running_mean.clone())) 177 | # 178 | # 179 | # def search_bn_recalibrate(model, prev=None, remove_bn=True, verbose=False): 180 | # with 
torch.no_grad(): 181 | # for m in model.children(): 182 | # if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev): 183 | # recalibrate_bn(prev, m, verbose=verbose) 184 | # search_bn_recalibrate(m, remove_bn=remove_bn, verbose=verbose) 185 | # prev = m 186 | -------------------------------------------------------------------------------- /AdaPrune/utils/adaprune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | import scipy.optimize as opt 7 | import math 8 | 9 | 10 | 11 | def adaprune(layer, mask, cached_inps, cached_outs, test_inp, test_out, lr1=1e-4, lr2=1e-2, iters=1000, progress=True, batch_size=50,relu=False,bs=8,no_optimization=False,keep_first_last=True): 12 | print("\nRun adaprune") 13 | test_inp = test_inp.to(layer.weight.device) 14 | test_out = test_out.to(layer.weight.device) 15 | layer.quantize=False 16 | if keep_first_last and (layer.weight.dim()==2 or layer.weight.shape[1]==3): 17 | return 0.1, 0.1 18 | with torch.no_grad(): 19 | layer.weight.data = absorb_mean_to_nz(layer.weight,mask,bs=bs) 20 | layer.weight.mul_(mask.to(layer.weight.device)) 21 | mse_before = F.mse_loss(layer(test_inp), test_out) 22 | if no_optimization: 23 | return mse_before.item(),mse_before.item() 24 | 25 | lr_w = 1e-3 26 | lr_b = 1e-2 27 | 28 | opt_w = torch.optim.Adam([layer.weight], lr=lr_w) 29 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias = torch.optim.Adam([layer.bias], lr=lr_b) 30 | 31 | losses = [] 32 | 33 | for j in (tqdm(range(iters)) if progress else range(iters)): 34 | idx = torch.randperm(cached_inps.size(0))[:batch_size] 35 | 36 | train_inp = cached_inps[idx].to(layer.weight.device) 37 | train_out = cached_outs[idx].to(layer.weight.device) 38 | qout = layer(train_inp) 39 | if relu: 40 | loss = F.mse_loss(F.relu(qout), F.relu(train_out)) 41 | else: 42 | loss = F.mse_loss(qout, train_out) 43 | 44 | losses.append(loss.item()) 45 | opt_w.zero_grad() 46 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias.zero_grad() 47 | loss.backward() 48 | opt_w.step() 49 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias.step() 50 | with torch.no_grad(): 51 | layer.weight.mul_(mask.to(layer.weight.device)) 52 | 53 | mse_after = F.mse_loss(layer(test_inp), test_out) 54 | return mse_before.item(), mse_after.item() 55 | 56 | def absorb_mean_to_nz(weight,mask,bs=8): 57 | """Prunes the weights with smallest magnitude.""" 58 | if weight.dim()>2: 59 | Co,Ci,k1,k2=weight.shape 60 | pad_size=bs-(Ci*k1*k2)%bs if bs>1 else 0 61 | weight_pad = torch.cat((weight.permute(0,2,3,1).contiguous().view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1) 62 | mask_pad = torch.cat((mask.permute(0,2,3,1).contiguous().view(Co,-1).float(),torch.ones(Co,pad_size).to(weight.data).float()),1) 63 | else: 64 | Co,Ci=weight.shape 65 | pad_size=bs-Ci%bs if bs>1 else 0 66 | weight_pad = torch.cat((weight.view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1) 67 | mask_pad = torch.cat((mask.view(Co,-1).float(),torch.ones(Co,pad_size).to(weight.data).float()),1) 68 | 69 | weight_pad = weight_pad.view(Co,-1,bs)+weight_pad.view(Co,-1,bs).mul(1-mask_pad.view(Co,-1,bs)).sum(2,keepdim=True).div(mask_pad.view(Co,-1,bs).sum(2,keepdim=True)) 70 | weight_pad.mul_(mask_pad.view(Co,-1,bs)) 71 | if weight.dim()>2: 72 | weight_pad = weight_pad.view(Co,-1)[:,:Ci*k1*k2] 73 | weight_pad = weight_pad.view(Co,k1,k2,Ci).permute(0,3,1,2) 74 | 
else: 75 | weight_pad = weight_pad.view(Co,-1)[:,:Ci] 76 | return weight_pad 77 | 78 | def create_block_magnitude_mask(weight, bs=2, topk=1): 79 | """Creates a mask that keeps the topk largest-magnitude weights in every block of bs weights.""" 80 | if weight.dim()>2: 81 | Co,Ci,k1,k2 = weight.shape 82 | pad_size = bs-(Ci*k1*k2)%bs if bs>1 else 0 83 | weight_pad = torch.cat((weight.permute(0,2,3,1).contiguous().view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1) 84 | else: 85 | Co,Ci = weight.shape 86 | pad_size = bs-Ci%bs if bs>1 else 0 87 | weight_pad = torch.cat((weight.view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1) 88 | 89 | block_weight = weight_pad.data.abs().view(Co,-1,bs).topk(k=topk,dim=2,sorted=False)[1].reshape(Co,-1,topk) 90 | block_masks = torch.zeros_like(weight_pad).reshape(Co, -1, bs).scatter_(2, block_weight, torch.ones(block_weight.shape).to(weight)) 91 | 92 | if weight.dim()>2: 93 | block_masks = block_masks.view(Co,-1)[:,:Ci*k1*k2] 94 | block_masks = block_masks.view(Co,k1,k2,Ci).permute(0,3,1,2) 95 | else: 96 | block_masks = block_masks.view(Co,-1)[:,:Ci] 97 | return block_masks 98 | 99 | def create_global_unstructured_magnitude_mask(param, global_val): 100 | eps = 0.1 if param.shape[1]==3 else 0 101 | return param.abs().gt(global_val-eps) 102 | 103 | def create_unstructured_magnitude_mask(param, sparsity_level, absorb_mean=True): 104 | topk = int(param.numel()*sparsity_level) 105 | val = param.view(-1).abs().topk(topk,sorted=True)[0][-1] 106 | mask = param.abs().gt(val) 107 | if absorb_mean: 108 | with torch.no_grad(): 109 | mean_val = param[~mask].mean() 110 | aa = param+mask*mean_val 111 | param.copy_(aa) 112 | print('unstructured mask created with %f density'%(mask.sum().float()/mask.numel())) 113 | return mask 114 | 115 | def extract_topk(param, bs, global_val, conf_level=0.95): 116 | if global_val is not None: 117 | param = create_global_unstructured_magnitude_mask(param, global_val) 118 | p = (1 - param.ne(0).float().sum() / param.numel()).item() 119 | n = bs 120 | P = [] 121 | B = param.numel()/n 122 | for k in range(n): 123 | S = 0 124 | for i in range(k,n+1): 125 | C = math.factorial(n)/(math.factorial(i)*math.factorial(n-i)) 126 | S = min(S + C*(p**i)*(1-p)**(n-i),1.0) 127 | P.append(S) 128 | RSD = [math.sqrt((1-pp)/(B*pp)) for pp in P] 129 | P_RSD = np.array(P) #- np.array(RSD)*5 130 | aa = [i for i,p in enumerate(P_RSD) if p>conf_level] 131 | if len(aa)>0: 132 | topk = n-[i for i,p in enumerate(P_RSD) if p>conf_level][-1] 133 | else: 134 | topk = n 135 | return topk 136 | 137 | def create_mask(layer, bs=8, topk=4, prune_extract_topk=False, unstructured=True, sparsity_level=0.5, global_val=None, conf_level=0.95): 138 | if unstructured: 139 | if global_val is not None and not prune_extract_topk: 140 | print('Creating unstructured mask for layer %s'%(layer.name)) 141 | return create_global_unstructured_magnitude_mask(layer.weight, global_val) 142 | else: 143 | return create_unstructured_magnitude_mask(layer.weight, sparsity_level=sparsity_level) 144 | if prune_extract_topk: topk = extract_topk(layer.weight, bs, global_val, conf_level=conf_level) 145 | print('Creating mask for layer %s with bs %d, topk %d'%(layer.name,bs,topk)) 146 | return create_block_magnitude_mask(layer.weight, bs=bs, topk=topk) 147 | 148 | def optimize_layer(layer, in_out, optimize_weights=False, bs=4, topk=2, extract_topk=False, unstructured=False, sparsity_level=0.5, global_val=None, conf_level=0.95): 149 | batch_size = 100 150 | 151 | cached_inps = torch.cat([x[0] for x in in_out]) 152 | cached_outs = torch.cat([x[1] for x in in_out]) 153 | 154 |
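# The random batch drawn below serves as a fixed evaluation set: the MSE printed
# before/after optimization is measured on it (training batches are later sampled
# from the same activation cache, so this is a diagnostic, not a strict held-out set).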
idx = torch.randperm(cached_inps.size(0))[:batch_size] 155 | 156 | test_inp = cached_inps[idx] 157 | test_out = cached_outs[idx] 158 | 159 | if optimize_weights: 160 | mask = create_mask(layer,bs=bs,topk=topk,prune_extract_topk=extract_topk,unstructured=unstructured,sparsity_level=sparsity_level,global_val=global_val,conf_level=conf_level) 161 | if 'conv1' in layer.name or 'conv2' in layer.name: 162 | mse_before, mse_after = adaprune(layer, mask, cached_inps, cached_outs, test_inp, test_out, iters=1000, lr1=1e-5, lr2=1e-4, relu=False, bs=bs) 163 | else: 164 | mse_before, mse_after = adaprune(layer, mask, cached_inps, cached_outs, test_inp, test_out, iters=1000, lr1=1e-5, lr2=1e-4, relu=False, bs=bs) 165 | 166 | mse_before_opt = mse_before 167 | print("MSE before adaprune (opt weight): {}".format(mse_before)) 168 | print("MSE after adaprune (opt weight): {}".format(mse_after)) 169 | torch.cuda.empty_cache() 170 | else: 171 | mse_before, mse_after = optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out) 172 | mask = None  # no pruning mask is created on the qparams-only path 173 | mse_before_opt = mse_before 174 | print("MSE before qparams: {}".format(mse_before)) 175 | print("MSE after qparams: {}".format(mse_after)) 176 | 177 | mse_after_opt = mse_after 178 | 179 | with torch.no_grad(): 180 | N = test_out.numel() 181 | snr_before = (1/math.sqrt(N)) * math.sqrt(N * mse_before_opt) / torch.norm(test_out).item() 182 | snr_after = (1/math.sqrt(N)) * math.sqrt(N * mse_after_opt) / torch.norm(test_out).item() 183 | 184 | kurt_in = kurtosis(test_inp).item() 185 | kurt_w = kurtosis(layer.weight).item() 186 | 187 | del cached_inps 188 | del cached_outs 189 | torch.cuda.empty_cache() 190 | 191 | return mse_before_opt, mse_after_opt, snr_before, snr_after, kurt_in, kurt_w, mask 192 | 193 | 194 | def kurtosis(x): 195 | var = torch.mean((x - x.mean())**2) 196 | return torch.mean((x - x.mean())**4 / var**2) 197 | 198 | 199 | def dump(model_name, layer, in_out): 200 | path = os.path.join("dump", model_name, layer.name) 201 | if os.path.exists(path): 202 | shutil.rmtree(path) 203 | os.makedirs(path) 204 | 205 | if hasattr(layer, 'groups'): 206 | f = open(os.path.join(path, "groups_{}".format(layer.groups)), 'x') 207 | f.close() 208 | 209 | cached_inps = torch.cat([x[0] for x in in_out]) 210 | cached_outs = torch.cat([x[1] for x in in_out]) 211 | torch.save(cached_inps, os.path.join(path, "input.pt")) 212 | torch.save(cached_outs, os.path.join(path, "output.pt")) 213 | torch.save(layer.weight, os.path.join(path, 'weight.pt')) 214 | if layer.bias is not None: 215 | torch.save(layer.bias, os.path.join(path, 'bias.pt')) 216 | 217 | 218 | -------------------------------------------------------------------------------- /AdaPrune/utils/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .misc import onehot 6 | 7 | 8 | def _is_long(x): 9 | if hasattr(x, 'data'): 10 | x = x.data 11 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor) 12 | 13 | 14 | def cross_entropy(logits, target, weight=None, ignore_index=-100, reduction='mean', 15 | smooth_eps=None, smooth_dist=None): 16 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567""" 17 | smooth_eps = smooth_eps or 0 18 | 19 | # ordinary log-likelihood - use cross_entropy from nn 20 | if _is_long(target) and smooth_eps == 0: 21 | return F.cross_entropy(logits, target, weight,
ignore_index=ignore_index, reduction=reduction) 22 | 23 | masked_indices = None 24 | num_classes = logits.size(-1) 25 | 26 | if _is_long(target) and ignore_index >= 0: 27 | masked_indices = target.eq(ignore_index) 28 | 29 | if smooth_eps > 0 and smooth_dist is not None: 30 | if _is_long(target): 31 | target = onehot(target, num_classes).type_as(logits) 32 | if smooth_dist.dim() < target.dim(): 33 | smooth_dist = smooth_dist.unsqueeze(0) 34 | target.lerp_(smooth_dist, smooth_eps) 35 | 36 | # log-softmax of logits 37 | lsm = F.log_softmax(logits, dim=-1) 38 | 39 | if weight is not None: 40 | lsm = lsm * weight.unsqueeze(0) 41 | 42 | if _is_long(target): 43 | eps = smooth_eps / (num_classes - 1) 44 | nll = -lsm.gather(dim=-1, index=target.unsqueeze(-1)) 45 | loss = (1. - 2 * eps) * nll - eps * lsm.sum(-1) 46 | else: 47 | loss = -(target * lsm).sum(-1) 48 | 49 | if masked_indices is not None: 50 | loss.masked_fill_(masked_indices, 0) 51 | 52 | if reduction == 'sum': 53 | loss = loss.sum() 54 | elif reduction == 'mean': 55 | if masked_indices is None: 56 | loss = loss.mean() 57 | else: 58 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum()) 59 | 60 | return loss 61 | 62 | 63 | class CrossEntropyLoss(nn.CrossEntropyLoss): 64 | """CrossEntropyLoss - with ability to receive distributions as targets, and optional label smoothing""" 65 | 66 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None): 67 | super(CrossEntropyLoss, self).__init__(weight=weight, 68 | ignore_index=ignore_index, reduction=reduction) 69 | self.smooth_eps = smooth_eps 70 | self.smooth_dist = smooth_dist 71 | 72 | def forward(self, input, target, smooth_dist=None): 73 | if smooth_dist is None: 74 | smooth_dist = self.smooth_dist 75 | return cross_entropy(input, target, self.weight, self.ignore_index, self.reduction, self.smooth_eps, smooth_dist) 76 | -------------------------------------------------------------------------------- /AdaPrune/utils/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import pickle 3 | import PIL 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.sampler import Sampler, RandomSampler, BatchSampler, _int_classes 7 | from numpy.random import choice 8 | 9 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler): 10 | """Samples elements randomly, with replacement.
11 | Arguments: 12 | data_source (Dataset): dataset to sample from 13 | """ 14 | 15 | def __init__(self, data_source): 16 | self.num_samples = len(data_source) 17 | 18 | def __iter__(self): 19 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True))) 20 | 21 | def __len__(self): 22 | return self.num_samples 23 | 24 | 25 | class LimitDataset(Dataset): 26 | 27 | def __init__(self, dset, max_len): 28 | self.dset = dset 29 | self.max_len = max_len 30 | 31 | def __len__(self): 32 | return min(len(self.dset), self.max_len) 33 | 34 | def __getitem__(self, index): 35 | return self.dset[index] 36 | 37 | class ByClassDataset(Dataset): 38 | 39 | def __init__(self, ds): 40 | self.dataset = ds 41 | self.idx_by_class = {} 42 | for idx, (_, c) in enumerate(ds): 43 | self.idx_by_class.setdefault(c, []) 44 | self.idx_by_class[c].append(idx) 45 | 46 | def __len__(self): 47 | return min([len(d) for d in self.idx_by_class.values()]) 48 | 49 | def __getitem__(self, idx): 50 | idx_per_class = [self.idx_by_class[c][idx] 51 | for c in range(len(self.idx_by_class))] 52 | labels = torch.LongTensor([self.dataset[i][1] 53 | for i in idx_per_class]) 54 | items = [self.dataset[i][0] for i in idx_per_class] 55 | if torch.is_tensor(items[0]): 56 | items = torch.stack(items) 57 | 58 | return (items, labels) 59 | 60 | 61 | class IdxDataset(Dataset): 62 | """docstring for IdxDataset.""" 63 | 64 | def __init__(self, dset): 65 | super(IdxDataset, self).__init__() 66 | self.dset = dset 67 | self.idxs = range(len(self.dset)) 68 | 69 | def __getitem__(self, idx): 70 | data, labels = self.dset[self.idxs[idx]] 71 | return (idx, data, labels) 72 | 73 | def __len__(self): 74 | return len(self.idxs) 75 | 76 | 77 | def image_loader(imagebytes): 78 | img = PIL.Image.open(BytesIO(imagebytes)) 79 | return img.convert('RGB') 80 | 81 | 82 | class IndexedFileDataset(Dataset): 83 | """ A dataset that consists of an indexed file (with sample offsets in 84 | another file). For example, a .tar that contains image files. 85 | The dataset does not extract the samples, but works with the indexed 86 | file directly. 87 | NOTE: The index file is assumed to be a pickled list of 3-tuples: 88 | (name, offset, size). 
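For example, a hypothetical entry: ('n01440764/sample.JPEG', 1024, 53812).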
89 | """ 90 | def __init__(self, filename, index_filename=None, extract_target_fn=None, 91 | transform=None, target_transform=None, loader=image_loader): 92 | super(IndexedFileDataset, self).__init__() 93 | 94 | # Defaults 95 | if index_filename is None: 96 | index_filename = filename + '.index' 97 | if extract_target_fn is None: 98 | extract_target_fn = lambda *args: args 99 | 100 | # Read index 101 | with open(index_filename, 'rb') as index_fp: 102 | sample_list = pickle.load(index_fp) 103 | 104 | # Collect unique targets (sorted by name) 105 | targetset = set(extract_target_fn(target) for target, _, _ in sample_list) 106 | targetmap = {target: i for i, target in enumerate(sorted(targetset))} 107 | 108 | self.samples = [(targetmap[extract_target_fn(target)], offset, size) 109 | for target, offset, size in sample_list] 110 | self.filename = filename 111 | 112 | self.loader = loader 113 | self.transform = transform 114 | self.target_transform = target_transform 115 | 116 | def _get_sample(self, fp, idx): 117 | target, offset, size = self.samples[idx] 118 | fp.seek(offset) 119 | sample = self.loader(fp.read(size)) 120 | 121 | if self.transform is not None: 122 | sample = self.transform(sample) 123 | if self.target_transform is not None: 124 | target = self.target_transform(target) 125 | 126 | return sample, target 127 | 128 | def __getitem__(self, index): 129 | with open(self.filename, 'rb') as fp: 130 | # Handle slices 131 | if isinstance(index, slice): 132 | return [self._get_sample(fp, subidx) for subidx in 133 | range(index.start or 0, index.stop or len(self), 134 | index.step or 1)] 135 | 136 | return self._get_sample(fp, index) 137 | 138 | def __len__(self): 139 | return len(self.samples) 140 | 141 | 142 | class DuplicateBatchSampler(Sampler): 143 | def __init__(self, sampler, batch_size, duplicates, drop_last): 144 | if not isinstance(sampler, Sampler): 145 | raise ValueError("sampler should be an instance of " 146 | "torch.utils.data.Sampler, but got sampler={}" 147 | .format(sampler)) 148 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 149 | batch_size <= 0: 150 | raise ValueError("batch_size should be a positive integer value, " 151 | "but got batch_size={}".format(batch_size)) 152 | if not isinstance(drop_last, bool): 153 | raise ValueError("drop_last should be a boolean value, but got " 154 | "drop_last={}".format(drop_last)) 155 | self.sampler = sampler 156 | self.batch_size = batch_size 157 | self.drop_last = drop_last 158 | self.duplicates = duplicates 159 | 160 | def __iter__(self): 161 | batch = [] 162 | for idx in self.sampler: 163 | batch.append(idx) 164 | if len(batch) == self.batch_size: 165 | yield batch * self.duplicates 166 | batch = [] 167 | if len(batch) > 0 and not self.drop_last: 168 | yield batch * self.duplicates 169 | 170 | def __len__(self): 171 | if self.drop_last: 172 | return len(self.sampler) // self.batch_size 173 | else: 174 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 175 | -------------------------------------------------------------------------------- /AdaPrune/utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function 3 | 4 | class ScaleGrad(Function): 5 | 6 | @staticmethod 7 | def forward(ctx, input, scale): 8 | ctx.scale = scale 9 | return input 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): 13 | grad_input = ctx.scale * grad_output 14 | return grad_input, None 15 |
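# Note: ScaleGrad is the identity in the forward pass and multiplies the incoming
# gradient by `scale` in the backward pass; e.g. y = scale_grad(x, 0.5) returns x
# unchanged while dL/dx = 0.5 * dL/dy. negate_grad below (scale = -1) is the usual
# gradient-reversal trick.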
16 | 17 | def scale_grad(x, scale): 18 | return ScaleGrad.apply(x, scale) 19 | 20 | def negate_grad(x): 21 | return scale_grad(x, -1) 22 | -------------------------------------------------------------------------------- /AdaPrune/utils/log.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from itertools import cycle 4 | import torch 5 | import logging.config 6 | from datetime import datetime 7 | import json 8 | 9 | import pandas as pd 10 | from bokeh.io import output_file, save, show 11 | from bokeh.plotting import figure 12 | from bokeh.layouts import column 13 | from bokeh.models import Div 14 | 15 | try: 16 | import hyperdash 17 | HYPERDASH_AVAILABLE = True 18 | except ImportError: 19 | HYPERDASH_AVAILABLE = False 20 | 21 | 22 | def export_args_namespace(args, filename): 23 | """ 24 | args: argparse.Namespace 25 | arguments to save 26 | filename: string 27 | filename to save at 28 | """ 29 | with open(filename, 'w') as fp: 30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) 31 | 32 | 33 | def setup_logging(log_file='log.txt', resume=False, dummy=False): 34 | """ 35 | Setup logging configuration 36 | """ 37 | if dummy: 38 | logging.getLogger('dummy') 39 | else: 40 | if os.path.isfile(log_file) and resume: 41 | file_mode = 'a' 42 | else: 43 | file_mode = 'w' 44 | 45 | root_logger = logging.getLogger() 46 | if root_logger.handlers: 47 | root_logger.removeHandler(root_logger.handlers[0]) 48 | logging.basicConfig(level=logging.DEBUG, 49 | format="%(asctime)s - %(levelname)s - %(message)s", 50 | datefmt="%Y-%m-%d %H:%M:%S", 51 | filename=log_file, 52 | filemode=file_mode) 53 | console = logging.StreamHandler() 54 | console.setLevel(logging.INFO) 55 | formatter = logging.Formatter('%(message)s') 56 | console.setFormatter(formatter) 57 | logging.getLogger('').addHandler(console) 58 | 59 | 60 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None, 61 | x_axis_type='linear', y_axis_type='linear', 62 | width=800, height=400, line_width=2, 63 | colors=['red', 'green', 'blue', 'orange', 64 | 'black', 'purple', 'brown'], 65 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save', 66 | append_figure=None): 67 | """ 68 | creates a new plot figure 69 | example: 70 | plot_figure(x='epoch', y=['train_loss', 'val_loss'], 71 | title='Loss', ylabel='loss') 72 | """ 73 | if not isinstance(y, list): 74 | y = [y] 75 | xlabel = xlabel or x 76 | legend = legend or y 77 | assert len(legend) == len(y) 78 | if append_figure is not None: 79 | f = append_figure 80 | else: 81 | f = figure(title=title, tools=tools, 82 | width=width, height=height, 83 | x_axis_label=xlabel or x, 84 | y_axis_label=ylabel or '', 85 | x_axis_type=x_axis_type, 86 | y_axis_type=y_axis_type) 87 | colors = cycle(colors) 88 | for i, yi in enumerate(y): 89 | f.line(data[x], data[yi], 90 | line_width=line_width, 91 | line_color=next(colors), legend=legend[i]) 92 | f.legend.click_policy = "hide" 93 | return f 94 | 95 | 96 | class ResultsLog(object): 97 | 98 | supported_data_formats = ['csv', 'json'] 99 | 100 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'): 101 | """ 102 | Parameters 103 | ---------- 104 | path: string 105 | path to directory to save data files 106 | plot_path: string 107 | path to directory to save plot files 108 | title: string 109 | title of HTML file 110 | params: Namespace 111 | optionally save parameters for results 112 | resume: bool 113 | resume previous logging
114 | data_format: str('csv'|'json') 115 | which file format to use to save the data 116 | """ 117 | if data_format not in ResultsLog.supported_data_formats: 118 | raise ValueError('data_format must be one of the following: ' + 119 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats])) 120 | 121 | if data_format == 'json': 122 | self.data_path = '{}.json'.format(path) 123 | else: 124 | self.data_path = '{}.csv'.format(path) 125 | if params is not None: 126 | export_args_namespace(params, '{}.json'.format(path)) 127 | self.plot_path = '{}.html'.format(path) 128 | self.results = None 129 | self.clear() 130 | self.first_save = True 131 | if os.path.isfile(self.data_path): 132 | if resume: 133 | self.load(self.data_path) 134 | self.first_save = False 135 | else: 136 | os.remove(self.data_path) 137 | self.results = pd.DataFrame() 138 | else: 139 | self.results = pd.DataFrame() 140 | 141 | self.title = title 142 | self.data_format = data_format 143 | 144 | if HYPERDASH_AVAILABLE: 145 | name = self.title if title != '' else path 146 | self.hd_experiment = hyperdash.Experiment(name) 147 | if params is not None: 148 | for k, v in params._get_kwargs(): 149 | self.hd_experiment.param(k, v, log=False) 150 | 151 | def clear(self): 152 | self.figures = [] 153 | 154 | def add(self, **kwargs): 155 | """Add a new row to the dataframe 156 | example: 157 | resultsLog.add(epoch=epoch_num, train_loss=loss, 158 | test_loss=test_loss) 159 | """ 160 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 161 | self.results = self.results.append(df, ignore_index=True) 162 | if hasattr(self, 'hd_experiment'): 163 | for k, v in kwargs.items(): 164 | self.hd_experiment.metric(k, v, log=False) 165 | 166 | def smooth(self, column_name, window): 167 | """Select an entry to smooth over time""" 168 | # TODO: smooth only new data 169 | smoothed_column = self.results[column_name].rolling( 170 | window=window, center=False).mean() 171 | self.results[column_name + '_smoothed'] = smoothed_column 172 | 173 | def save(self, title=None): 174 | """save the data file and render the HTML plot file. 175 | Parameters 176 | ---------- 177 | title: string 178 | title of the HTML file 179 | """ 180 | title = title or self.title 181 | if len(self.figures) > 0: 182 | if os.path.isfile(self.plot_path): 183 | os.remove(self.plot_path) 184 | if self.first_save: 185 | self.first_save = False 186 | logging.info('Plot file saved at: {}'.format( 187 | os.path.abspath(self.plot_path))) 188 | 189 | output_file(self.plot_path, title=title) 190 | plot = column( 191 | Div(text='

<h1 align="center">{}</h1>

'.format(title)), *self.figures) 192 | save(plot) 193 | self.clear() 194 | 195 | if self.data_format == 'json': 196 | self.results.to_json(self.data_path, orient='records', lines=True) 197 | else: 198 | self.results.to_csv(self.data_path, index=False, index_label=False) 199 | 200 | def load(self, path=None): 201 | """load the data file 202 | Parameters 203 | ---------- 204 | path: 205 | path to load the json|csv file from 206 | """ 207 | path = path or self.data_path 208 | if os.path.isfile(path): 209 | if self.data_format == 'json': 210 | self.results = pd.read_json(path, orient='records', lines=True) 211 | else: 212 | self.results = pd.read_csv(path) 213 | else: 214 | raise ValueError("{} isn't a file".format(path)) 215 | 216 | def show(self, title=None): 217 | title = title or self.title 218 | if len(self.figures) > 0: 219 | plot = column( 220 | Div(text='

<h1 align="center">{}</h1>

'.format(title)), *self.figures) 221 | show(plot) 222 | 223 | def plot(self, *kargs, **kwargs): 224 | """ 225 | add a new plot to the HTML file 226 | example: 227 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 228 | title='Loss', ylabel='loss') 229 | """ 230 | f = plot_figure(self.results, *kargs, **kwargs) 231 | self.figures.append(f) 232 | 233 | def image(self, *kargs, **kwargs): 234 | fig = figure() 235 | fig.image(*kargs, **kwargs) 236 | self.figures.append(fig) 237 | 238 | def end(self): 239 | if hasattr(self, 'hd_experiment'): 240 | self.hd_experiment.end() 241 | 242 | 243 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 244 | filename = os.path.join(path, filename) 245 | torch.save(state, filename) 246 | if is_best: 247 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 248 | if save_all: 249 | shutil.copyfile(filename, os.path.join( 250 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 251 | -------------------------------------------------------------------------------- /AdaPrune/utils/meters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class OnlineMeter(object): 24 | """Computes and stores the average and variance/std values of tensor""" 25 | 26 | def __init__(self): 27 | self.mean = torch.FloatTensor(1).fill_(-1) 28 | self.M2 = torch.FloatTensor(1).zero_() 29 | self.count = 0. 30 | self.needs_init = True 31 | 32 | def reset(self, x): 33 | self.mean = x.new(x.size()).zero_() 34 | self.M2 = x.new(x.size()).zero_() 35 | self.count = 0.
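# update() below implements Welford's online algorithm: `mean` holds the running mean
# and `M2` the running sum of squared deviations, so var = M2 / (count - 1).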
36 | self.needs_init = False 37 | 38 | def update(self, x): 39 | self.val = x 40 | if self.needs_init: 41 | self.reset(x) 42 | self.count += 1 43 | delta = x - self.mean 44 | self.mean.add_(delta / self.count) 45 | delta2 = x - self.mean 46 | self.M2.add_(delta * delta2) 47 | 48 | @property 49 | def var(self): 50 | if self.count < 2: 51 | return self.M2.clone().zero_() 52 | return self.M2 / (self.count - 1) 53 | 54 | @property 55 | def std(self): 56 | return self.var.sqrt() 57 | 58 | 59 | def accuracy(output, target, topk=(1,)): 60 | """Computes the precision@k for the specified values of k""" 61 | maxk = max(topk) 62 | batch_size = target.size(0) 63 | 64 | _, pred = output.topk(maxk, 1, True, True) 65 | pred = pred.t().type_as(target) 66 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 67 | 68 | res = [] 69 | for k in topk: 70 | correct_k = correct[:k].view(-1).float().sum(0) 71 | res.append(correct_k.mul_(100.0 / batch_size)) 72 | return res 73 | 74 | 75 | class AccuracyMeter(object): 76 | """Computes and stores the average and current topk accuracy""" 77 | 78 | def __init__(self, topk=(1,)): 79 | self.topk = topk 80 | self.reset() 81 | 82 | def reset(self): 83 | self._meters = {} 84 | for k in self.topk: 85 | self._meters[k] = AverageMeter() 86 | 87 | def update(self, output, target): 88 | n = target.nelement() 89 | acc_vals = accuracy(output, target, self.topk) 90 | for i, k in enumerate(self.topk): 91 | self._meters[k].update(acc_vals[i], n) 92 | 93 | @property 94 | def val(self): 95 | return {n: meter.val for (n, meter) in self._meters.items()} 96 | 97 | @property 98 | def avg(self): 99 | return {n: meter.avg for (n, meter) in self._meters.items()} 100 | 101 | @property 102 | def avg_error(self): 103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()} 104 | -------------------------------------------------------------------------------- /AdaPrune/utils/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 6 | 7 | torch_dtypes = { 8 | 'float': torch.float, 9 | 'float32': torch.float32, 10 | 'float64': torch.float64, 11 | 'double': torch.double, 12 | 'float16': torch.float16, 13 | 'half': torch.half, 14 | 'uint8': torch.uint8, 15 | 'int8': torch.int8, 16 | 'int16': torch.int16, 17 | 'short': torch.short, 18 | 'int32': torch.int32, 19 | 'int': torch.int, 20 | 'int64': torch.int64, 21 | 'long': torch.long 22 | } 23 | 24 | 25 | def onehot(indexes, N=None, ignore_index=None): 26 | """ 27 | Creates a one-hot representation of indexes with N possible entries; 28 | if N is not specified, it is inferred from the maximum index.
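Example: onehot(torch.tensor([0, 2]), N=3) returns [[1, 0, 0], [0, 0, 1]] (dtype uint8).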
29 | indexes is a long-tensor of indexes 30 | ignore_index will be zero in onehot representation 31 | """ 32 | if N is None: 33 | N = indexes.max() + 1 34 | sz = list(indexes.size()) 35 | output = indexes.new().byte().resize_(*sz, N).zero_() 36 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 37 | if ignore_index is not None and ignore_index >= 0: 38 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 39 | return output 40 | 41 | 42 | def set_global_seeds(i): 43 | try: 44 | import torch 45 | except ImportError: 46 | pass 47 | else: 48 | torch.manual_seed(i) 49 | if torch.cuda.is_available(): 50 | torch.cuda.manual_seed_all(i) 51 | np.random.seed(i) 52 | random.seed(i) 53 | 54 | 55 | class CheckpointModule(nn.Module): 56 | def __init__(self, module, num_segments=1): 57 | super(CheckpointModule, self).__init__() 58 | assert num_segments == 1 or isinstance(module, nn.Sequential) 59 | self.module = module 60 | self.num_segments = num_segments 61 | 62 | def forward(self, x): 63 | if self.num_segments > 1: 64 | return checkpoint_sequential(self.module, self.num_segments, x) 65 | else: 66 | return checkpoint(self.module, x) 67 | 68 | 69 | def normalize_module_name(layer_name): 70 | """Normalize a module's name. 71 | 72 | PyTorch lets you parallelize the computation of a model by wrapping it with a 73 | DataParallel module. Unfortunately, this changes the fully-qualified name of a module, 74 | even though the actual functionality of the module doesn't change. 75 | Many times, when we search for modules by name, we are indifferent to the DataParallel 76 | module and want to use the same module name whether the module is parallel or not. 77 | We call this module name normalization, and this is implemented here. 78 | """ 79 | modules = layer_name.split('.') 80 | try: 81 | idx = modules.index('module') 82 | except ValueError: 83 | return layer_name 84 | del modules[idx] 85 | return '.'.join(modules) 86 | -------------------------------------------------------------------------------- /AdaPrune/utils/mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy.random import beta 4 | from .misc import onehot 5 | 6 | 7 | class MixUp(nn.Module): 8 | def __init__(self, batch_dim=0): 9 | super(MixUp, self).__init__() 10 | self.batch_dim = batch_dim 11 | self.reset() 12 | 13 | def reset(self): 14 | self.enabled = False 15 | self.mix_values = None 16 | self.mix_index = None 17 | 18 | def mix(self, x1, x2): 19 | if not torch.is_tensor(self.mix_values): # scalar 20 | return x2.lerp(x1, self.mix_values) 21 | else: 22 | view = [1] * int(x1.dim()) 23 | view[self.batch_dim] = -1 24 | mix_val = self.mix_values.to(device=x1.device).view(*view) 25 | return mix_val * x1 + (1.-mix_val) * x2 26 | 27 | def sample(self, alpha, batch_size, sample_batch=False): 28 | self.mix_index = torch.randperm(batch_size) 29 | if sample_batch: 30 | values = beta(alpha, alpha, size=batch_size) 31 | self.mix_values = torch.tensor(values, dtype=torch.float) 32 | else: 33 | self.mix_values = torch.tensor([beta(alpha, alpha)], 34 | dtype=torch.float) 35 | 36 | def mix_target(self, y, n_class): 37 | if not self.training or \ 38 | self.mix_values is None or\ 39 | self.mix_index is None: 40 | return y 41 | y = onehot(y, n_class).to(dtype=torch.float) 42 | idx = self.mix_index.to(device=y.device) 43 | y_mix = y.index_select(self.batch_dim, idx) 44 | return self.mix(y, y_mix) 45 | 46 | def forward(self, x): 47 | if not self.training or \ 48 |
self.mix_values is None or\ 49 | self.mix_index is None: 50 | return x 51 | idx = self.mix_index.to(device=x.device) 52 | x_mix = x.index_select(self.batch_dim, idx) 53 | return self.mix(x, x_mix) 54 | -------------------------------------------------------------------------------- /AdaPrune/utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging.config 3 | from copy import deepcopy 4 | from six import string_types 5 | from .regime import Regime 6 | from .param_filter import FilterParameters 7 | from . import regularization 8 | import torch.nn as nn 9 | 10 | _OPTIMIZERS = {name: func for name, func in torch.optim.__dict__.items()} 11 | 12 | try: 13 | from adabound import AdaBound 14 | _OPTIMIZERS['AdaBound'] = AdaBound 15 | except ImportError: 16 | pass 17 | 18 | 19 | def copy_params(param_target, param_src): 20 | with torch.no_grad(): 21 | for p_src, p_target in zip(param_src, param_target): 22 | p_target.copy_(p_src) 23 | 24 | 25 | def copy_params_grad(param_target, param_src): 26 | for p_src, p_target in zip(param_src, param_target): 27 | if p_target.grad is None: 28 | p_target.backward(p_src.grad.to(dtype=p_target.dtype)) 29 | else: 30 | p_target.grad.detach().copy_(p_src.grad) 31 | 32 | 33 | class ModuleFloatShadow(nn.Module): 34 | def __init__(self, module): 35 | super(ModuleFloatShadow, self).__init__() 36 | self.original_module = module 37 | self.float_module = deepcopy(module) 38 | self.float_module.to(dtype=torch.float) 39 | 40 | def parameters(self, *kargs, **kwargs): 41 | return self.float_module.parameters(*kargs, **kwargs) 42 | 43 | def named_parameters(self, *kargs, **kwargs): 44 | return self.float_module.named_parameters(*kargs, **kwargs) 45 | 46 | def modules(self, *kargs, **kwargs): 47 | return self.float_module.modules(*kargs, **kwargs) 48 | 49 | def named_modules(self, *kargs, **kwargs): 50 | return self.float_module.named_modules(*kargs, **kwargs) 51 | 52 | def original_parameters(self, *kargs, **kwargs): 53 | return self.original_module.parameters(*kargs, **kwargs) 54 | 55 | def original_named_parameters(self, *kargs, **kwargs): 56 | return self.original_module.named_parameters(*kargs, **kwargs) 57 | 58 | def original_modules(self, *kargs, **kwargs): 59 | return self.original_module.modules(*kargs, **kwargs) 60 | 61 | def original_named_modules(self, *kargs, **kwargs): 62 | return self.original_module.named_modules(*kargs, **kwargs) 63 | 64 | 65 | class OptimRegime(Regime): 66 | """ 67 | Reconfigures the optimizer according to a setting list.
68 | Exposes optimizer methods - state, step, zero_grad, add_param_group 69 | 70 | Examples for regime: 71 | 72 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3}, 73 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 74 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 75 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 76 | ]" 77 | 2) 78 | "[{'step_lambda': 79 | "lambda t: { 80 | 'optimizer': 'Adam', 81 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 82 | 'betas': (0.9, 0.98), 'eps':1e-9} 83 | }]" 84 | """ 85 | 86 | def __init__(self, model, regime, defaults={}, filter=None, use_float_copy=False): 87 | super(OptimRegime, self).__init__(regime, defaults) 88 | if filter is not None: 89 | model = FilterParameters(model, **filter) 90 | if use_float_copy: 91 | model = ModuleFloatShadow(model) 92 | self._original_parameters = list(model.original_parameters()) 93 | 94 | self.parameters = list(model.parameters()) 95 | self.optimizer = torch.optim.SGD(self.parameters, lr=0) 96 | self.regularizer = regularization.Regularizer(model) 97 | self.use_float_copy = use_float_copy 98 | 99 | def update(self, epoch=None, train_steps=None): 100 | """adjusts optimizer according to current epoch or steps and training regime. 101 | """ 102 | if super(OptimRegime, self).update(epoch, train_steps): 103 | self.adjust(self.setting) 104 | return True 105 | else: 106 | return False 107 | 108 | def adjust(self, setting): 109 | """adjusts optimizer according to a setting dict. 110 | e.g: setting={optimizer': 'Adam', 'lr': 5e-4} 111 | """ 112 | if 'optimizer' in setting: 113 | optim_method = _OPTIMIZERS[setting['optimizer']] 114 | if not isinstance(self.optimizer, optim_method): 115 | self.optimizer = optim_method(self.optimizer.param_groups) 116 | logging.debug('OPTIMIZER - setting method = %s' % 117 | setting['optimizer']) 118 | for param_group in self.optimizer.param_groups: 119 | for key in param_group.keys(): 120 | if key in setting: 121 | new_val = setting[key] 122 | if new_val != param_group[key]: 123 | logging.debug('OPTIMIZER - setting %s = %s' % 124 | (key, setting[key])) 125 | param_group[key] = setting[key] 126 | # fix for AdaBound 127 | if key == 'lr' and hasattr(self.optimizer, 'base_lrs'): 128 | self.optimizer.base_lrs = list( 129 | map(lambda group: group['lr'], self.optimizer.param_groups)) 130 | 131 | if 'regularizer' in setting: 132 | reg_list = deepcopy(setting['regularizer']) 133 | if not (isinstance(reg_list, list) or isinstance(reg_list, tuple)): 134 | reg_list = (reg_list,) 135 | regularizers = [] 136 | for reg in reg_list: 137 | if isinstance(reg, dict): 138 | logging.debug('OPTIMIZER - Regularization - %s' % reg) 139 | name = reg.pop('name') 140 | regularizers.append((regularization.__dict__[name], reg)) 141 | elif isinstance(reg, regularization.Regularizer): 142 | regularizers.append(reg) 143 | else: # callable on model 144 | regularizers.append(reg(self.regularizer._model)) 145 | self.regularizer = regularization.RegularizerList(self.regularizer._model, 146 | regularizers) 147 | 148 | def __getstate__(self): 149 | return { 150 | 'optimizer_state': self.optimizer.__getstate__(), 151 | 'regime': self.regime, 152 | } 153 | 154 | def __setstate__(self, state): 155 | self.regime = state.get('regime') 156 | self.optimizer.__setstate__(state.get('optimizer_state')) 157 | 158 | def state_dict(self): 159 | """Returns the state of the optimizer as a :class:`dict`. 
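The returned dict holds the wrapped optimizer's state under 'optimizer_state' together with the current 'regime'.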
160 | """ 161 | return { 162 | 'optimizer_state': self.optimizer.state_dict(), 163 | 'regime': self.regime, 164 | } 165 | 166 | def load_state_dict(self, state_dict): 167 | """Loads the optimizer state. 168 | 169 | Arguments: 170 | state_dict (dict): optimizer state. Should be an object returned 171 | from a call to :meth:`state_dict`. 172 | """ 173 | # deepcopy, to be consistent with module API 174 | optimizer_state_dict = state_dict['optimizer_state'] 175 | 176 | self.__setstate__({'optimizer_state': optimizer_state_dict, 177 | 'regime': state_dict['regime']}) 178 | 179 | def zero_grad(self): 180 | """Clears the gradients of all optimized :class:`Variable` s.""" 181 | self.optimizer.zero_grad() 182 | if self.use_float_copy: 183 | for p in self._original_parameters: 184 | if p.grad is not None: 185 | p.grad.detach().zero_() 186 | 187 | def step(self, closure=None): 188 | """Performs a single optimization step (parameter update). 189 | 190 | Arguments: 191 | closure (callable): A closure that reevaluates the model and 192 | returns the loss. Optional for most optimizers. 193 | """ 194 | if self.use_float_copy: 195 | copy_params_grad(self.parameters, self._original_parameters) 196 | self.regularizer.pre_step() 197 | self.optimizer.step(closure) 198 | self.regularizer.post_step() 199 | if self.use_float_copy: 200 | copy_params(self._original_parameters, self.parameters) 201 | 202 | def pre_forward(self): 203 | """ allows modification pre-forward pass - e.g for regularization 204 | """ 205 | self.regularizer.pre_forward() 206 | 207 | def pre_backward(self): 208 | """ allows modification post-forward pass and pre-backward - e.g for regularization 209 | """ 210 | self.regularizer.pre_backward() 211 | 212 | 213 | class MultiOptimRegime(OptimRegime): 214 | 215 | def __init__(self, *optim_regime_list): 216 | self.optim_regime_list = [] 217 | for optim_regime in optim_regime_list: 218 | assert isinstance(optim_regime, OptimRegime) 219 | self.optim_regime_list.append(optim_regime) 220 | 221 | def update(self, epoch=None, train_steps=None): 222 | """adjusts optimizer according to current epoch or steps and training regime. 223 | """ 224 | updated = False 225 | for i, optim in enumerate(self.optim_regime_list): 226 | current_updated = optim.update(epoch, train_steps) 227 | if current_updated: 228 | logging.debug('OPTIMIZER #%s was updated' % i) 229 | updated = updated or current_updated 230 | return updated 231 | 232 | def zero_grad(self): 233 | """Clears the gradients of all optimized :class:`Variable` s.""" 234 | for optim in self.optim_regime_list: 235 | optim.zero_grad() 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step (parameter update). 239 | 240 | Arguments: 241 | closure (callable): A closure that reevaluates the model and 242 | returns the loss. Optional for most optimizers. 
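Each wrapped OptimRegime is stepped in turn; e.g. (illustrative) MultiOptimRegime(regime_a, regime_b).step() steps both regimes.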
243 | """ 244 | for optim in self.optim_regime_list: 245 | optim.step(closure) 246 | -------------------------------------------------------------------------------- /AdaPrune/utils/param_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def is_not_bias(name): 6 | return not name.endswith('bias') 7 | 8 | 9 | def is_bn(module): 10 | return isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) 11 | 12 | 13 | def is_not_bn(module): 14 | return not is_bn(module) 15 | 16 | 17 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None): 18 | if memo is None: 19 | memo = set() 20 | 21 | for module_name, module in model.named_modules(): 22 | if module_fn is not None and not module_fn(module): 23 | continue 24 | if module_name_fn is not None and not module_name_fn(module_name): 25 | continue 26 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False): 27 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name): 28 | continue 29 | if param not in memo: 30 | memo.add(param) 31 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)} 32 | 33 | 34 | class FilterParameters(object): 35 | def __init__(self, source, module=None, module_name=None, parameter_name=None): 36 | if isinstance(source, FilterParameters): 37 | self._filtered_parameter_info = list(source.filter( 38 | module=module, 39 | module_name=module_name, 40 | parameter_name=parameter_name)) 41 | elif isinstance(source, torch.nn.Module): # source is a model 42 | self._filtered_parameter_info = list(filtered_parameter_info(source, 43 | module_fn=module, 44 | module_name_fn=module_name, 45 | parameter_name_fn=parameter_name)) 46 | 47 | def named_parameters(self): 48 | for p in self._filtered_parameter_info: 49 | yield p['named_parameter'] 50 | 51 | def parameters(self): 52 | for _, p in self.named_parameters(): 53 | yield p 54 | 55 | def filter(self, module=None, module_name=None, parameter_name=None): 56 | for p_info in self._filtered_parameter_info: 57 | if ((module is None or module(p_info['named_module'][1])) 58 | and (module_name is None or module_name(p_info['named_module'][0])) 59 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))): 60 | yield p_info 61 | 62 | def named_modules(self): 63 | for m in self._filtered_parameter_info: 64 | yield m['named_module'] 65 | 66 | def modules(self): 67 | for _, m in self.named_modules(): 68 | yield m 69 | 70 | def to(self, *kargs, **kwargs): 71 | for m in self.modules(): 72 | m.to(*kargs, **kwargs) 73 | 74 | 75 | class FilterModules(FilterParameters): 76 | pass 77 | 78 | if __name__ == '__main__': 79 | from torchvision.models import resnet50 80 | model = resnet50() 81 | filtered_params = FilterParameters(model, 82 | module=lambda m: isinstance( 83 | m, torch.nn.Linear), 84 | parameter_name=lambda n: 'bias' in n) 85 | -------------------------------------------------------------------------------- /AdaPrune/utils/regime.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | from six import string_types 4 | 5 | 6 | def eval_func(f, x): 7 | if isinstance(f, string_types): 8 | f = eval(f) 9 | return f(x) 10 | 11 | 12 | class Regime(object): 13 | """ 14 | Examples for regime: 15 | 16 | 1) "[{'epoch': 0, 'optimizer':
'Adam', 'lr': 1e-3}, 17 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 18 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 19 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 20 | ]" 21 | 2) 22 | "[{'step_lambda': 23 | "lambda t: { 24 | 'optimizer': 'Adam', 25 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 26 | 'betas': (0.9, 0.98), 'eps':1e-9} 27 | }]" 28 | """ 29 | 30 | def __init__(self, regime, defaults={}): 31 | self.regime = regime 32 | self.current_regime_phase = None 33 | self.setting = defaults 34 | 35 | def update(self, epoch=None, train_steps=None): 36 | """adjusts according to current epoch or steps and regime. 37 | """ 38 | if self.regime is None: 39 | return False 40 | epoch = -1 if epoch is None else epoch 41 | train_steps = -1 if train_steps is None else train_steps 42 | setting = deepcopy(self.setting) 43 | if self.current_regime_phase is None: 44 | # Find the first entry whose epoch is smaller than the current one 45 | for regime_phase, regime_setting in enumerate(self.regime): 46 | start_epoch = regime_setting.get('epoch', 0) 47 | start_step = regime_setting.get('step', 0) 48 | if epoch >= start_epoch or train_steps >= start_step: 49 | self.current_regime_phase = regime_phase 50 | break 51 | # each entry is updated from previous 52 | setting.update(regime_setting) 53 | if len(self.regime) > self.current_regime_phase + 1: 54 | next_phase = self.current_regime_phase + 1 55 | # Any more regime steps? 56 | start_epoch = self.regime[next_phase].get('epoch', float('inf')) 57 | start_step = self.regime[next_phase].get('step', float('inf')) 58 | if epoch >= start_epoch or train_steps >= start_step: 59 | self.current_regime_phase = next_phase 60 | setting.update(self.regime[self.current_regime_phase]) 61 | 62 | if 'lr_decay_rate' in setting and 'lr' in setting: 63 | decay_steps = setting.pop('lr_decay_steps', 100) 64 | if train_steps % decay_steps == 0: 65 | decay_rate = setting.pop('lr_decay_rate') 66 | setting['lr'] *= decay_rate ** (train_steps / decay_steps) 67 | elif 'step_lambda' in setting: 68 | setting.update(eval_func(setting.pop('step_lambda'), train_steps)) 69 | elif 'epoch_lambda' in setting: 70 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch)) 71 | 72 | if 'execute' in setting: 73 | setting.pop('execute')() 74 | 75 | if 'execute_once' in setting: 76 | setting.pop('execute_once')() 77 | # remove from regime, so won't happen again 78 | self.regime[self.current_regime_phase].pop('execute_once', None) 79 | 80 | if setting == self.setting: 81 | return False 82 | else: 83 | self.setting = setting 84 | return True 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accelerated Sparse Neural Training: A Provable and Efficient Method to Find N:M Transposable Masks 2 | Recently, researchers proposed pruning deep neural network (DNN) weights using an $N:M$ fine-grained block sparsity mask. In this mask, for each block of M weights, we have at least N zeros. In contrast to unstructured sparsity, N:M fine-grained block sparsity allows acceleration in actual modern hardware. Previously suggested solutions enabled DNN acceleration at the inference phase. To also allow such acceleration in the training phase, we suggest a novel transposable fine-grained sparsity mask where the same mask can be used for both forward and backward passes.
Our transposable mask ensures that both the weight matrix and its transpose follow the same sparsity pattern; thus the matrix multiplication required for passing the error backward can also be accelerated. We discuss the transposable constraint and devise a new measure for mask constraints, called mask-diversity (MD), which correlates with their expected accuracy. Lastly, we formulate the problem of finding the optimal transposable mask as a minimum-cost-flow problem and suggest a fast linear approximation that can be used when the masks dynamically change while training. Our experiments suggest a 2x speed-up with no accuracy degradation over vision and language models. A reference implementation is available in the supplementary material. 3 | ## Reproducing the results 4 | 5 | This repository is partially based on the [convNet.pytorch](https://github.com/eladhoffer/convNet.pytorch) repo. Please ensure that you are using PyTorch 1.7+. 6 | Reproducing AdaPrune results: 7 | ```bash 8 | cd AdaPrune 9 | sh scripts/adaprune_dense_bnt.sh 10 | sh scripts/adaprune_sparse.sh 11 | ``` 12 | Reproducing static NM-transposable starting from dense pre-trained model: 13 | ```bash 14 | cd static_TNM 15 | sh scripts/prune_pretrained_R50.sh 16 | ``` 17 | Reproducing dynamic NM-transposable from scratch: 18 | ```bash 19 | cd dynamic_TNM 20 | sh scripts/clone_and_copy.sh 21 | sh scripts/run_R18.sh 22 | sh scripts/run_R50.sh 23 | ``` 24 | -------------------------------------------------------------------------------- /common/flatten_object.py: -------------------------------------------------------------------------------- 1 | def flatten_object(obj, delimiter='.', prefix=''): 2 | def flatten(x, name=prefix): 3 | if isinstance(x, dict): 4 | for a in x: 5 | flatten(x[a], name + a + delimiter) 6 | elif isinstance(x, list) or isinstance(x, tuple): 7 | for i, a in enumerate(x): 8 | flatten(a, name + str(i) + delimiter) 9 | else: 10 | out[name[:-1]] = x 11 | 12 | out = {} 13 | flatten(obj) 14 | return out 15 | -------------------------------------------------------------------------------- /common/json_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def json_is_serializable(obj): 5 | serializable = True 6 | try: 7 | json.dumps(obj) 8 | except TypeError: 9 | serializable = False 10 | return serializable 11 | 12 | 13 | def json_force_serializable(obj): 14 | if isinstance(obj, dict): 15 | for k, v in obj.items(): 16 | obj[k] = json_force_serializable(v) 17 | elif isinstance(obj, list): 18 | for i, v in enumerate(obj): 19 | obj[i] = json_force_serializable(v) 20 | elif isinstance(obj, tuple): 21 | obj = list(obj) 22 | for i, v in enumerate(obj): 23 | obj[i] = json_force_serializable(v) 24 | obj = tuple(obj) 25 | elif not json_is_serializable(obj): 26 | obj = 'filtered by json' 27 | return obj 28 | -------------------------------------------------------------------------------- /common/timer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | class Timer: 5 | def __init__(self): 6 | self.start = None 7 | self.final = None 8 | 9 | def __enter__(self): 10 | self.start = datetime.datetime.now() 11 | return self 12 | 13 | def __exit__(self, exception_type, exception_value, traceback): 14 | self.final = datetime.datetime.now() - self.start 15 | 16 | def total(self): 17 | if self.final is None: 18 | raise RuntimeError('Timer total called before exit; start={}'.format(self.start)) 19 | return
self.final 20 | 21 | def elapsed(self): 22 | if self.start is None: 23 | raise RuntimeError('Timer elapsed called before start') 24 | return datetime.datetime.now() - self.start 25 | -------------------------------------------------------------------------------- /dynamic_TNM/scripts/clone_and_copy.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/NM-sparsity/NM-sparsity.git 2 | cd NM-sparsity 3 | git checkout d8419d99ad84ae47e3581db0125ed375ee416bb3 4 | cd .. 5 | cp src/dist_utils.py NM-sparsity/devkit/core/ 6 | cp src/sparse_ops.py NM-sparsity/devkit/sparse_ops/ 7 | cp src/train_imagenet.py NM-sparsity/classification/train_imagenet.py 8 | cp src/resnet.py NM-sparsity/classification/models/ 9 | cp src/train_val.sh NM-sparsity/classification 10 | cp src/sparse_ops_init.py NM-sparsity/devkit/sparse_ops/__init__.py 11 | cp src/utils.py NM-sparsity/devkit/core/ 12 | -------------------------------------------------------------------------------- /dynamic_TNM/scripts/run_R18.sh: -------------------------------------------------------------------------------- 1 | cd NM-sparsity/classification/ 2 | sh train_val.sh ../../src/configs/config_resnet18_4by8_transpose.yaml 3 | -------------------------------------------------------------------------------- /dynamic_TNM/scripts/run_R50.sh: -------------------------------------------------------------------------------- 1 | cd NM-sparsity/classification/ 2 | sh train_val.sh ../../src/configs/config_resnet50_4by8_transpose.yaml 3 | -------------------------------------------------------------------------------- /dynamic_TNM/src/configs/config_resnet18_4by8_transpose.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | model: resnet18 3 | N: 4 4 | M: 8 5 | sparse_optimizer: 1 6 | load_mask: True 7 | init_mask: False 8 | save_mask: False 9 | mask_path: 'path_to_masks/' 10 | 11 | # gpu: 3 12 | 13 | workers: 3 14 | batch_size: 512 15 | epochs: 120 16 | 17 | lr_mode : cosine 18 | base_lr: 0.2 19 | warmup_epochs: 5 20 | warmup_lr: 0.0 21 | targetlr : 0.0 22 | 23 | momentum: 0.9 24 | weight_decay: 0.00005 25 | 26 | 27 | print_freq: 100 28 | model_dir: checkpoint/resnet18_4by8 29 | 30 | train_root: /path_to/imagenet/train 31 | val_root: /path_to/imagenet/val 32 | 33 | 34 | 35 | 36 | TEST: 37 | checkpoint_path : data/pretrained_model/ 38 | -------------------------------------------------------------------------------- /dynamic_TNM/src/configs/config_resnet50_4by8_transpose.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | model: resnet50 3 | N: 4 4 | M: 8 5 | sparse_optimizer: 1 6 | load_mask: True 7 | init_mask: False 8 | save_mask: False 9 | mask_path: 'path_tomasks/' 10 | 11 | # gpu: 3 12 | 13 | 14 | workers: 3 15 | batch_size: 512 16 | epochs: 120 17 | 18 | lr_mode : cosine 19 | base_lr: 0.2 20 | warmup_epochs: 5 21 | warmup_lr: 0.0 22 | targetlr : 0.0 23 | 24 | momentum: 0.9 25 | weight_decay: 0.00005 26 | 27 | 28 | print_freq: 100 29 | model_dir: checkpoint/resnet50_4by8 30 | 31 | train_root: /path_to/imagenet/train 32 | val_root: /path_to/imagenet/val 33 | 34 | 35 | 36 | 37 | TEST: 38 | checkpoint_path : data/pretrained_model/ 39 | -------------------------------------------------------------------------------- /dynamic_TNM/src/configs/config_resnext50_4by8_transpose.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | model: resnext50_32x4d 
3 | N: 4 4 | M: 8 5 | sparse_optimizer: 1 6 | load_mask: False 7 | init_mask: False 8 | save_mask: True 9 | mask_path: 'path_to_masks/' 10 | 11 | # gpu: 3 12 | 13 | 14 | workers: 3 15 | batch_size: 512 16 | epochs: 120 17 | 18 | lr_mode : cosine 19 | base_lr: 0.2 20 | warmup_epochs: 5 21 | warmup_lr: 0.0 22 | targetlr : 0.0 23 | 24 | momentum: 0.9 25 | weight_decay: 0.00005 26 | 27 | 28 | print_freq: 100 29 | model_dir: checkpoint/resnet50_4by8 30 | 31 | train_root: /path_to/imagenet/train 32 | val_root: /path_to/imagenet/val 33 | 34 | 35 | 36 | 37 | TEST: 38 | checkpoint_path : data/pretrained_model/ 39 | -------------------------------------------------------------------------------- /dynamic_TNM/src/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.multiprocessing as mp 4 | import torch.distributed as dist 5 | 6 | __all__ = [ 7 | 'init_dist', 'broadcast_params','average_gradients'] 8 | 9 | def init_dist(backend='nccl', 10 | master_ip='127.0.0.1', 11 | port=29500): 12 | #if mp.get_start_method(allow_none=True) is None: 13 | # mp.set_start_method('spawn') 14 | #os.environ['MASTER_ADDR'] = master_ip 15 | #os.environ['MASTER_PORT'] = str(port) 16 | rank = int(os.environ['RANK']) 17 | world_size = int(os.environ['WORLD_SIZE']) 18 | num_gpus = torch.cuda.device_count() 19 | #import pdb; pdb.set_trace() 20 | #torch.cuda.set_device(rank % num_gpus) 21 | #dist.init_process_group(backend=backend,init_method='tcp://127.0.0.1:6320',world_size=world_size,rank=rank) 22 | print('INIT') 23 | return rank, world_size,num_gpus,backend 24 | 25 | def average_gradients(model): 26 | for param in model.parameters(): 27 | if param.requires_grad and not (param.grad is None): 28 | dist.all_reduce(param.grad.data) 29 | 30 | def broadcast_params(model): 31 | for p in model.state_dict().values(): 32 | dist.broadcast(p, 0) 33 | 34 | -------------------------------------------------------------------------------- /dynamic_TNM/src/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import sys 4 | import os.path as osp 5 | sys.path.append(osp.abspath(osp.join(__file__, '../../../'))) 6 | #from devkit.ops import SyncBatchNorm2d 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import autograd 10 | from torch.nn.modules.utils import _pair as pair 11 | from torch.nn import init 12 | #from devkit.sparse_ops import SparseConv 13 | from devkit.sparse_ops import SparseConvTranspose as SparseConv 14 | 15 | 16 | 17 | __all__ = ['ResNetV1', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 18 | 'resnet152'] 19 | 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, N=2, M=4): 23 | """3x3 convolution with padding""" 24 | return SparseConv(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False, N=N, M=M) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, N=2, M=4): 32 | super(BasicBlock, self).__init__() 33 | 34 | self.conv1 = conv3x3(inplanes, planes, stride, N=N, M=M) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes, N=N, M=M) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 
48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, N=2, M=4): 64 | super(Bottleneck, self).__init__() 65 | 66 | self.conv1 = SparseConv(inplanes, planes, kernel_size=1, bias=False, N=N, M=M) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = SparseConv(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False, N=N, M=M) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = SparseConv(planes, planes * 4, kernel_size=1, bias=False, N=N, M=M) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | class ResNetV1(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, N=2, M=4): 102 | super(ResNetV1, self).__init__() 103 | 104 | 105 | self.N = N 106 | self.M = M 107 | 108 | self.inplanes = 64 109 | self.conv1 = SparseConv(3, 64, kernel_size=7, stride=2, padding=3, 110 | bias=False, N=self.N, M=self.M) 111 | self.bn1 = nn.BatchNorm2d(64) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(block, 64, layers[0], N = self.N, M = self.M) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, N = self.N, M = self.M) 116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, N = self.N, M = self.M) 117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, N = self.N, M = self.M) 118 | self.avgpool = nn.AvgPool2d(7, stride=1) 119 | self.fc = nn.Linear(512 * block.expansion, num_classes) 120 | 121 | for m in self.modules(): 122 | if isinstance(m, SparseConv): 123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 124 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1, N = 2, M = 4): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | SparseConv(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False, N=N, M=M), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample, N=N, M=M)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append(block(self.inplanes, planes, N=N, M=M)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | x = self.relu(x) 148 | x = self.maxpool(x) 149 | 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | 155 | x = self.avgpool(x) 156 | x = x.view(x.size(0), -1) 157 | x = self.fc(x) 158 | 159 | return x 160 | 161 | 162 | def resnet18(**kwargs): 163 | model = ResNetV1(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | return model 165 | 166 | 167 | def resnet34(**kwargs): 168 | model = ResNetV1(BasicBlock, [3, 4, 6, 3], **kwargs) 169 | return model 170 | 171 | 172 | def resnet50(**kwargs): 173 | model = ResNetV1(Bottleneck, [3, 4, 6, 3], **kwargs) 174 | return model 175 | 176 | 177 | def resnet101(**kwargs): 178 | model = ResNetV1(Bottleneck, [3, 4, 23, 3], **kwargs) 179 | return model 180 | 181 | 182 | def resnet152(**kwargs): 183 | model = ResNetV1(Bottleneck, [3, 8, 36, 3], **kwargs) 184 | return model 185 | -------------------------------------------------------------------------------- /dynamic_TNM/src/sparse_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autograd, nn 3 | import torch.nn.functional as F 4 | 5 | from itertools import repeat 6 | from torch._six import container_abcs 7 | import time 8 | from prune.pruning_method_transposable_block_l1 import PruningMethodTransposableBlockL1 9 | 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, container_abcs.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | _single = _ntuple(1) 18 | _pair = _ntuple(2) 19 | _triple = _ntuple(3) 20 | _quadruple = _ntuple(4) 21 | 22 | def update_mask_approx2(data, mask, topk=4,BS=8): 23 | mask.fill_(0) 24 | Co = data.shape[0] 25 | #topk=BS//2 26 | _,idx_sort = data.sort(1,descending=True); #block x 64 27 | for k in range(BS**2): 28 | if k 0 70 | mask[mask_update] = new_mask[mask_update] 71 | if sum(mask_update) == 0: 72 | break 73 | return mask 74 | 75 | class SparseTranspose(autograd.Function): 76 | """ Prune the unimportant edges in the forward pass but pass the gradient to the dense weight using STE in the backward pass """ 77 | 78 | @staticmethod 79 | def forward(ctx, weight, N, M, counter, freq, absorb_mean): 80 | weight.mask = weight.mask.to(weight) 81 | output = weight.clone() 82 | if counter%freq==0: 83 | weight_temp = weight.detach().abs().reshape(-1, M*M) 84 | weight_mask = weight.mask.detach().reshape(-1, M*M) 85 | #weight_mask = update_mask(weight_temp,weight_mask,BS=M) 86 | weight_mask = update_mask_approx2(weight_temp,weight_mask,BS=M) 87 | if absorb_mean: 88 | output = output.reshape(-1, M*M).clone() 89 | output+=output.mul(1-weight_mask).mean(1,keepdim=True) # keepdim needed to broadcast over the block dim (cf. SparseTransposeV2 below) 90 | output=output.reshape(weight.shape) 91 | weight.mask=weight_mask.reshape(weight.shape) 92 | return output*weight.mask, weight.mask 93 | 
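# Straight-through estimator (STE): the forward pass above returns the masked
# weight W * mask, while backward() below passes grad_output to the dense
# weight unchanged:
#
#     forward:  W_s = W * mask        backward:  dL/dW := dL/dW_s
#
# so pruned weights keep receiving gradients and can be revived the next time
# the mask is recomputed (every `freq` training steps).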
94 | @staticmethod 95 | def backward(ctx, grad_output, _): 96 | return grad_output, None, None, None, None, None 97 | 98 | 99 | class Sparse(autograd.Function): 100 | """ Prune the unimportant edges in the forward pass but pass the gradient to the dense weight using STE in the backward pass """ 101 | 102 | @staticmethod 103 | def forward(ctx, weight, N, M): 104 | 105 | output = weight.clone() 106 | length = weight.numel() 107 | group = int(length/M) 108 | 109 | weight_temp = weight.detach().abs().reshape(group, M) 110 | index = torch.argsort(weight_temp, dim=1)[:, :int(M-N)] 111 | 112 | w_b = torch.ones(weight_temp.shape, device=weight_temp.device) 113 | w_b = w_b.scatter_(dim=1, index=index, value=0).reshape(weight.shape) 114 | 115 | return output*w_b, w_b 116 | 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output, _): 120 | return grad_output, None, None 121 | 
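# A minimal usage sketch of the plain (non-transposable) N:M path (illustrative
# only; SparseConv is defined further down in this file):
#
#     conv = SparseConv(64, 64, kernel_size=3, padding=1, bias=False, N=2, M=4)
#     y = conv(torch.randn(1, 64, 8, 8))
#     # after a forward pass, every group of M=4 weights has exactly N=2 survivors:
#     assert (conv.weight.mask.reshape(-1, 4).sum(1) == 2).all()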
122 | class SparseTransposeV2(autograd.Function): 123 | """ Prune the unimportant edges in the forward pass but pass the gradient to the dense weight using STE in the backward pass """ 124 | 125 | @staticmethod 126 | def forward(ctx, weight, N, M, counter): 127 | weight.mask = weight.mask.to(weight) 128 | output = weight.reshape(-1, M*M).clone() 129 | weight_mask = weight.mask.reshape(-1, M*M) 130 | output+=torch.mean(output.mul(1-weight_mask),dim=1,keepdim=True) 131 | weight.mask=weight_mask.reshape(weight.shape) 132 | output=output.reshape(weight.shape) 133 | return output*weight.mask, weight.mask 134 | 135 | @staticmethod 136 | def backward(ctx, grad_output, _): 137 | return grad_output, None, None, None 138 | 139 | class SparseConvTranspose(nn.Conv2d): 140 | 141 | 142 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', N=2, M=4, **kwargs): 143 | self.N = N 144 | self.M = M 145 | self.counter = 0 146 | self.freq = 1 147 | self.absorb_mean = False 148 | super(SparseConvTranspose, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs) 149 | 150 | 151 | def get_sparse_weights(self): 152 | return SparseTranspose.apply(self.weight, self.N, self.M, self.counter, self.freq, self.absorb_mean) 153 | 154 | 155 | 156 | def forward(self, x): 157 | if self.training: 158 | self.counter+=1 159 | self.freq = 40 #min(self.freq+self.counter//100,100) 160 | w, mask = self.get_sparse_weights() 161 | setattr(self.weight, "mask", mask) 162 | else: 163 | w = self.weight * self.weight.mask 164 | x = F.conv2d( 165 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups 166 | ) 167 | return x 168 | 169 | class SparseConvTransposeV2(nn.Conv2d): 170 | 171 | 172 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', N=2, M=4, **kwargs): 173 | self.N = N 174 | self.M = M 175 | self.counter = 0 176 | self.freq = 1 177 | self.rerun_ip = 0.01 178 | self.ipClass = PruningMethodTransposableBlockL1(block_size=self.M, topk=self.N) 179 | super(SparseConvTransposeV2, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs) 180 | 181 | 182 | def get_sparse_weights(self): 183 | with torch.no_grad(): 184 | weight_temp = self.weight.detach().abs().reshape(-1, self.M*self.M) 185 | weight_mask = self.weight.mask.detach().reshape(-1, self.M*self.M) 186 | num_samples_ip = int(self.rerun_ip*weight_temp.shape[0]) 187 | idx = torch.randperm(weight_temp.shape[0])[:num_samples_ip] 188 | sample_weight = weight_temp[idx] 189 | mask_new = self.ipClass.compute_mask(sample_weight,torch.ones_like(sample_weight)) 190 | weight_mask = weight_mask.to(self.weight.device) 191 | weight_mask[idx]=mask_new.to(self.weight.device) 192 | return SparseTransposeV2.apply(self.weight, self.N, self.M, self.counter) 193 | 194 | def forward(self, x): 195 | # self.counter+=1 196 | # self.freq = min(self.freq+self.counter//100,100) 197 | w, mask = self.get_sparse_weights() 198 | setattr(self.weight, "mask", mask) 199 | x = F.conv2d( 200 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups 201 | ) 202 | return x 203 | 204 | class SparseConv(nn.Conv2d): 205 | 206 | 207 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', N=2, M=4, **kwargs): 208 | self.N = N 209 | self.M = M 210 | super(SparseConv, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs) 211 | 212 | 213 | def get_sparse_weights(self): 214 | 215 | return Sparse.apply(self.weight, self.N, self.M) 216 | 217 | 218 | 219 | def forward(self, x): 220 | 221 | w, mask = self.get_sparse_weights() 222 | setattr(self.weight, "mask", mask) 223 | x = F.conv2d( 224 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups 225 | ) 226 | return x 227 | 228 | class SparseLinear(nn.Linear): 229 | def __init__(self, in_features, out_features, bias=True, N=2, M=4): 230 | self.N = N 231 | self.M = M 232 | super(SparseLinear, self).__init__(in_features, out_features, bias) 233 | 234 | def forward(self, x): 235 | w, _ = Sparse.apply(self.weight, self.N, self.M) 236 | return F.linear(x, w, self.bias) 237 | class SparseLinearTranspose(nn.Linear): 238 | 239 | def __init__(self, in_channels, out_channels, bias=True, N=2, M=4, **kwargs): 240 | self.N = N 241 | self.M = M 242 | self.counter = 0 243 | self.freq = 10 244 | super(SparseLinearTranspose, self).__init__(in_channels, out_channels, bias,) 245 | 246 | def get_sparse_weights(self): 247 | return SparseTranspose.apply(self.weight, self.N, self.M, self.counter, self.freq, False) 248 | 249 | def forward(self, x): 250 | if self.training: 251 | self.counter += 1 252 | self.freq = 40 # min(self.freq+self.counter//100,100) 253 | w, mask = self.get_sparse_weights() 254 | setattr(self.weight, "mask", mask) 255 | else: 256 | w = self.weight * self.weight.mask 257 | x = F.linear( 258 | x, w, self.bias 259 | ) 260 | return x 261 | -------------------------------------------------------------------------------- /dynamic_TNM/src/sparse_ops_init.py: -------------------------------------------------------------------------------- 1 | from .syncbn_layer import SyncBatchNorm2d 2 | from .sparse_ops import SparseConv, SparseConvTranspose, SparseLinearTranspose 3 | -------------------------------------------------------------------------------- /dynamic_TNM/src/train_val.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | export RANK=0 3 | export WORLD_SIZE=8 4 | export PYTHONPATH="path_to_TNM_repo" 5 | python train_imagenet.py \ 6 | --config $1 2>&1|tee train-$now.log 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /dynamic_TNM/src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | from devkit.sparse_ops import SparseConvTranspose,SparseLinearTranspose 5 | 6 | 7 | def save_checkpoint(model_dir, state, is_best): 8 | epoch = state['epoch'] 9 | path = os.path.join(model_dir, 'model.pth-' + 
str(epoch)) 10 | torch.save(state, path) 11 | checkpoint_file = os.path.join(model_dir, 'checkpoint') 12 | checkpoint = open(checkpoint_file, 'w+') 13 | checkpoint.write('model_checkpoint_path:%s\n' % path) 14 | checkpoint.close() 15 | if is_best: 16 | shutil.copyfile(path, os.path.join(model_dir, 'model-best.pth')) 17 | 18 | 19 | def load_state(model_dir, model, optimizer=None): 20 | if not os.path.exists(model_dir + '/checkpoint'): 21 | print("=> no checkpoint found at '{}', train from scratch".format(model_dir)) 22 | return 0, 0 23 | else: 24 | ckpt = open(model_dir + '/checkpoint') 25 | model_path = ckpt.readlines()[0].split(':')[1].strip('\n') 26 | checkpoint = torch.load(model_path,map_location='cuda:{}'.format(torch.cuda.current_device())) 27 | model.load_state_dict(checkpoint['state_dict'], strict=False) 28 | ckpt_keys = set(checkpoint['state_dict'].keys()) 29 | own_keys = set(model.state_dict().keys()) 30 | missing_keys = own_keys - ckpt_keys 31 | for k in missing_keys: 32 | print('missing keys from checkpoint {}: {}'.format(model_dir, k)) 33 | 34 | print("=> loaded model from checkpoint '{}'".format(model_dir)) 35 | if optimizer != None: 36 | best_prec1 = 0 37 | if 'best_prec1' in checkpoint.keys(): 38 | best_prec1 = checkpoint['best_prec1'] 39 | start_epoch = checkpoint['epoch'] 40 | optimizer.load_state_dict(checkpoint['optimizer']) 41 | print("=> also loaded optimizer from checkpoint '{}' (epoch {})" 42 | .format(model_dir, start_epoch)) 43 | return best_prec1, start_epoch 44 | 45 | 46 | def load_state_epoch(model_dir, model, epoch): 47 | model_path = model_dir + '/model.pth-' + str(epoch) 48 | checkpoint = torch.load(model_path,map_location='cuda:{}'.format(torch.cuda.current_device())) 49 | 50 | model.load_state_dict(checkpoint['state_dict'], strict=False) 51 | ckpt_keys = set(checkpoint['state_dict'].keys()) 52 | own_keys = set(model.state_dict().keys()) 53 | missing_keys = own_keys - ckpt_keys 54 | for k in missing_keys: 55 | print('missing keys from checkpoint {}: {}'.format(model_dir, k)) 56 | 57 | print("=> loaded model from checkpoint '{}'".format(model_dir)) 58 | 59 | 60 | def load_state_ckpt(model_path, model): 61 | checkpoint = torch.load(model_path, map_location='cuda:{}'.format(torch.cuda.current_device())) 62 | model.load_state_dict(checkpoint['state_dict'], strict=False) 63 | ckpt_keys = set(checkpoint['state_dict'].keys()) 64 | own_keys = set(model.state_dict().keys()) 65 | missing_keys = own_keys - ckpt_keys 66 | for k in missing_keys: 67 | print('missing keys from checkpoint {}: {}'.format(model_path, k)) 68 | 69 | print("=> loaded model from checkpoint '{}'".format(model_path)) 70 | 71 | def save_masks(model,args): 72 | masks = {} 73 | for n, m in model.named_modules(): 74 | if isinstance(m, SparseConvTranspose) or isinstance(m,SparseLinearTranspose): 75 | masks[n] = m.weight.mask.cpu() 76 | masks['state_dict'] = model.state_dict() 77 | torch.save(masks, args.mask_path + args.model + '_' + str(args.N) + '_' + str(args.M)) 78 | 79 | def load_state_and_masks(model, args): 80 | masks = torch.load(args.mask_path + args.model + '_' + str(args.N) + '_' + str(args.M)) 81 | 82 | #load weights 83 | model.load_state_dict(masks['state_dict'], strict=False) 84 | ckpt_keys = set(masks['state_dict'].keys()) 85 | own_keys = set(model.state_dict().keys()) 86 | missing_keys = own_keys - ckpt_keys 87 | for k in missing_keys: 88 | print('missing keys from checkpoint {}'.format( k)) 89 | 90 | #load_masks 91 | for n, m in model.named_modules(): 92 | if isinstance(m, 
SparseConvTranspose) or isinstance(m,SparseLinearTranspose): 93 | 94 | # m.maskBuff.data = masks[n] 95 | setattr(m.weight, "mask", masks[n]) 96 | 97 | 98 | -------------------------------------------------------------------------------- /dynamic_TNM/train-20210211_125543.log: -------------------------------------------------------------------------------- 1 | python: can't open file 'train_imagenet.py': [Errno 2] No such file or directory 2 | -------------------------------------------------------------------------------- /prune/prune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from torch import load as torch_load, save as torch_save, ones_like as torch_ones_like 5 | from common.timer import Timer 6 | from prune.pruning_method_utils import permute_to_nhwc, pad_inner_dims 7 | from prune.pruning_method_transposable_block_l1 import PruningMethodTransposableBlockL1 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser(description='Pruner') 12 | parser.add_argument('--checkpoint', required=True, type=str, help='path to checkpoint') 13 | parser.add_argument('--n-workers', default=None, type=int, help='number of processes') 14 | parser.add_argument('--save', default=None, type=str, help='path to pruned checkpoint') 15 | parser.add_argument('--bs', default=8, type=int, help='block size') 16 | parser.add_argument('--topk', default=4, type=int, help='topk') 17 | parser.add_argument('--sd-key', default='state_dict', type=str, help='state dict key in checkpoint') 18 | parser.add_argument('--optimize-transposed', action='store_true', default=False, 19 | help='if true, transposable pruning method will optimize for (block + block.T)') 20 | parser.add_argument('--include', nargs='*', default=None, 21 | help='list of layers that will be included in pruning') 22 | parser.add_argument('--exclude', nargs='*', default=None, 23 | help='list of layers that will be excluded from pruning') 24 | parser.add_argument('--debug-key', default=None, type=str, help='variable key to print first block') 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def load_checkpoint(filename): 30 | if not os.path.isfile(filename): 31 | raise FileNotFoundError('Checkpoint {} not found'.format(filename)) 32 | 33 | checkpoint = torch_load(filename, map_location='cpu') 34 | return checkpoint 35 | 36 | 37 | def load_sd_from_checkpoint(filename, sd_key): 38 | checkpoint = load_checkpoint(filename) 39 | sd = checkpoint[sd_key] if sd_key is not None else checkpoint.copy() 40 | del checkpoint 41 | return sd 42 | 43 | 44 | def load_var_from_checkpoint(filename, name, sd_key): 45 | sd = load_sd_from_checkpoint(filename, sd_key) 46 | if name not in sd: 47 | raise RuntimeError('Variable {} not found in {}'.format(name, filename)) 48 | v = sd[name] 49 | del sd 50 | return v 51 | 52 | 53 | def save_var_to_checkpoint(filename, name, mask, sd_key): 54 | checkpoint = load_checkpoint(filename) 55 | sd = checkpoint[sd_key] if sd_key is not None else checkpoint 56 | if name not in sd: 57 | raise RuntimeError('Variable {} not found in {}'.format(name, filename)) 58 | sd[name] = sd[name] * mask 59 | torch_save(checkpoint, filename) 60 | del checkpoint 61 | 62 | 63 | def prune(checkpoint, save, sd_key, bs=8, topk=4, optimize_transposed=False, 64 | include=None, exclude=None, n_workers=None, debug_key=None): 65 | 66 | with Timer() as t: 67 | sd = load_sd_from_checkpoint(checkpoint, sd_key) 68 | print('Loading checkpoint, 
elapsed={}'.format(t.total())) 69 | 70 | save = checkpoint + '.pruned' if save is None else save 71 | shutil.copyfile(checkpoint, save) 72 | 73 | prune_method = PruningMethodTransposableBlockL1(block_size=bs, topk=topk, 74 | optimize_transposed=optimize_transposed, 75 | n_workers=n_workers, with_tqdm=True) 76 | 77 | keys = [k for k in sd.keys() if sd[k].dim() > 1 and 'bias' not in k and 'running' not in k] 78 | 79 | if include: 80 | invalid_keys = [k for k in include if k not in keys] 81 | assert not invalid_keys, 'Requested params to include={} not in model'.format(invalid_keys) 82 | print('Including {}'.format(include)) 83 | keys = include 84 | 85 | if exclude: 86 | invalid_keys = [k for k in exclude if k not in keys] 87 | assert not invalid_keys, 'Requested params to exclude={} not in model'.format(invalid_keys) 88 | print('Excluding {}'.format(exclude)) 89 | keys = [k for k in keys if k not in exclude] 90 | 91 | del sd 92 | 93 | with Timer() as t: 94 | for key in keys: 95 | v = load_var_from_checkpoint(checkpoint, key, sd_key) 96 | print('Pruning ' + key) 97 | prune_weight_mask = prune_method.compute_mask(v, torch_ones_like(v)) 98 | save_var_to_checkpoint(save, key, prune_weight_mask, sd_key) 99 | print('Total elapsed time: {}'.format(t.total())) 100 | 101 | if debug_key: 102 | 103 | sd = load_sd_from_checkpoint(save, sd_key) 104 | v = sd[debug_key] 105 | 106 | # print first block 107 | permuted_mask = permute_to_nhwc(v) 108 | permuted_mask = pad_inner_dims(permuted_mask, bs * bs) 109 | permuted_mask = permuted_mask.reshape(-1, (bs * bs)) 110 | print('first block=\n{}'.format(permuted_mask.numpy()[0, :].reshape(1, -1, bs, bs))) 111 | 112 | 113 | def main(): 114 | args = get_args() 115 | prune(checkpoint=args.checkpoint, save=args.save, sd_key=args.sd_key, bs=args.bs, topk=args.topk, 116 | optimize_transposed=args.optimize_transposed, include=args.include, exclude=args.exclude, 117 | n_workers=args.n_workers, debug_key=args.debug_key) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /prune/pruning_method_based_mask.py: -------------------------------------------------------------------------------- 1 | from prune.pruning_method_utils import * 2 | import torch.nn.utils.prune as prune 3 | 4 | 5 | class PruningMethodBasedMask(prune.BasePruningMethod): 6 | """ Pruning based on fixed mask """ 7 | 8 | PRUNING_TYPE = 'unstructured' # pruning type "structured" refers to channels 9 | 10 | def __init__(self, mask=None): 11 | super(PruningMethodBasedMask, self).__init__() 12 | self.mask = mask 13 | 14 | def compute_mask(self, t, default_mask): 15 | validate_tensor_shape_2d_4d(t) 16 | mask = self.mask.detach().mul(default_mask) # mul (not mul_): detach() shares storage, so in-place would mutate the stored mask 17 | return mask.byte() 18 | 19 | def apply_like_self(self, module, name, **kwargs): 20 | assert 'mask' in kwargs 21 | cls = self.__class__ 22 | return super(PruningMethodBasedMask, cls).apply(module, name, kwargs['mask']) 23 | -------------------------------------------------------------------------------- /prune/pruning_method_transposable_block_l1.py: -------------------------------------------------------------------------------- 1 | from pulp import * 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | from common.timer import Timer 5 | from prune.pruning_method_utils import * 6 | import numpy as np 7 | import torch.nn.utils.prune as prune 8 | import re  # used by ip_transpose, rather than relying on pulp's star import 9 | 10 | class PruningMethodTransposableBlockL1(prune.BasePruningMethod): 11 | 12 | PRUNING_TYPE = 
'unstructured' # pruning type "structured" refers to channels 13 | 14 | RUN_SPEED_TEST = False 15 | 16 | def __init__(self, block_size, topk, optimize_transposed=False, n_workers=None, with_tqdm=True): 17 | super(PruningMethodTransposableBlockL1, self).__init__() 18 | assert topk <= block_size 19 | assert n_workers is None or n_workers > 0 20 | self.bs = block_size 21 | self.topk = topk 22 | self.optimize_transposed = optimize_transposed 23 | self.n_workers = n_workers 24 | self.with_tqdm = with_tqdm 25 | # used for multiprocess in order to avoid serialize/deserialize tensors etc. 26 | self.mp_tensor, self.mp_mask = None, None 27 | 28 | def ip_transpose(self, data): 29 | prob = LpProblem('TransposableMask', LpMaximize) 30 | combinations = [] 31 | magnitude_loss = {} 32 | indicators = {} 33 | bs = self.bs 34 | for r in range(bs): 35 | for c in range(bs): 36 | combinations.append('ind' + '_{}r_{}c'.format(r, c)) 37 | magnitude_loss['ind' + '_{}r_{}c'.format(r, c)] = abs(data[r, c]) 38 | indicators['ind' + '_{}r_{}c'.format(r, c)] = \ 39 | LpVariable('ind' + '_{}r_{}c'.format(r, c), 0, 1, LpInteger) 40 | 41 | prob += lpSum([indicators[ind] * magnitude_loss[ind] for ind in magnitude_loss.keys()]) 42 | 43 | for r in range(bs): 44 | prob += lpSum([indicators[key] for key in combinations if '_{}r'.format(r) in key]) == self.topk 45 | for c in range(bs): 46 | prob += lpSum([indicators[key] for key in combinations if '_{}c'.format(c) in key]) == self.topk 47 | 48 | solver = LpSolverDefault 49 | solver.msg = False 50 | prob.solve(solver) 51 | assert prob.status != -1, 'Infeasible' 52 | mask = np.zeros([self.bs, self.bs]) 53 | for v in prob.variables(): 54 | if 'ind' in v.name: 55 | rc = re.findall(r'\d+', v.name) 56 | mask[int(rc[0]), int(rc[1])] = v.varValue 57 | return mask 58 | 59 | def get_mask_iter(self, c): 60 | co, inners = self.mp_tensor.shape 61 | block_numel = self.bs ** 2 62 | n_blocks = inners // block_numel 63 | for j in range(n_blocks): 64 | offset = j * block_numel 65 | w_block = self.mp_tensor[c, offset:offset + block_numel].reshape(self.bs, self.bs) 66 | w_block = w_block + w_block.T if self.optimize_transposed else w_block 67 | mask_block = self.ip_transpose(w_block).reshape(-1) 68 | self.mp_mask[c, offset:offset + block_numel] = torch.from_numpy(mask_block) 69 | 70 | def get_mask(self, t): 71 | self.mp_tensor = t 72 | self.mp_mask = torch.zeros_like(t) 73 | 74 | co, inners = t.shape 75 | n_blocks = inners // (self.bs ** 2) 76 | 77 | if self.RUN_SPEED_TEST: 78 | self.RUN_SPEED_TEST = False 79 | with Timer() as t: 80 | self.get_mask_iter(0) 81 | elapsed = t.total().total_seconds() 82 | print('Single core speed test: blocks={} secs={} block-time={}'.format(n_blocks, elapsed, elapsed/n_blocks)) 83 | 84 | p = Pool(self.n_workers) 85 | n_iterations = co 86 | bar = tqdm(total=n_iterations, ncols=80) if self.with_tqdm else None 87 | bar.set_postfix_str('n_processes={}, blocks/iter={}'.format(p._processes, n_blocks)) if self.with_tqdm else None 88 | block_indexes = range(co) 89 | for _ in p.imap_unordered(self.get_mask_iter, block_indexes): 90 | bar.update(1) if self.with_tqdm else None 91 | bar.close() if self.with_tqdm else None 92 | p.close() 93 | 94 | return self.mp_mask 95 | 96 | def compute_mask(self, t, default_mask): 97 | # permute and pad 98 | validate_tensor_shape_2d_4d(t) 99 | t_masked = t.clone().detach().mul_(default_mask) 100 | t_permuted = permute_to_nhwc(t_masked) 101 | pad_to = self.bs ** 2 102 | t_padded = pad_inner_dims(t_permuted, pad_to) 103 | t = 
t_padded.data.abs().to(t) 104 | 105 | # compute mask 106 | mask = self.get_mask(t) 107 | 108 | # restore to original shape 109 | block_mask = clip_padding(mask, t_permuted.shape).reshape(t_permuted.shape) 110 | block_mask = permute_to_nchw(block_mask) 111 | return block_mask 112 | -------------------------------------------------------------------------------- /prune/pruning_method_transposable_block_l1_graphs.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | from common.timer import Timer 5 | from prune.pruning_method_utils import * 6 | import numpy as np 7 | import torch.nn.utils.prune as prune 8 | 9 | 10 | class PruningMethodTransposableBlockL1Graphs(prune.BasePruningMethod): 11 | 12 | PRUNING_TYPE = 'unstructured' # pruning type "structured" refers to channels 13 | 14 | RUN_SPEED_TEST = False 15 | 16 | def __init__(self, block_size, topk, optimize_transposed=False, n_workers=None, with_tqdm=True): 17 | super(PruningMethodTransposableBlockL1Graphs, self).__init__() 18 | assert topk <= block_size 19 | assert n_workers is None or n_workers > 0 20 | self.bs = block_size 21 | self.topk = topk 22 | self.optimize_transposed = optimize_transposed 23 | self.n_workers = n_workers 24 | self.with_tqdm = with_tqdm 25 | # used for multiprocess in order to avoid serialize/deserialize tensors etc. 26 | self.mp_tensor, self.mp_mask = None, None 27 | 28 | def nxGraph(self, data): 29 | bs = data.shape[0] 30 | G = nx.DiGraph() 31 | G.add_node('s', demand=-int(bs ** 2 / 2)) 32 | G.add_node('t', demand=int(bs ** 2 / 2)) 33 | names = [] 34 | for i in range(bs): 35 | G.add_edge('s', 'row' + str(i), capacity=self.topk, weight=0) 36 | G.add_edge('col' + str(i), 't', capacity=self.topk, weight=0) 37 | for j in range(bs): 38 | G.add_edge('row' + str(i), 'col' + str(j), capacity=1, weight=data[i, j].numpy()) 39 | names.append('row' + str(i)) 40 | dictMinFLow = nx.min_cost_flow(G) 41 | mask = [] 42 | for w in names: 43 | mask.append(list(dictMinFLow[w].values())) 44 | return np.array(mask) 45 | 46 | def get_mask_iter(self, c): 47 | co, inners = self.mp_tensor.shape 48 | block_numel = self.bs ** 2 49 | n_blocks = inners // block_numel 50 | for j in range(n_blocks): 51 | offset = j * block_numel 52 | w_block = self.mp_tensor[c, offset:offset + block_numel].reshape(self.bs, self.bs) 53 | w_block = w_block + w_block.T if self.optimize_transposed else w_block 54 | mask_block = self.nxGraph(-1 * w_block).reshape(-1) #max flow to min flow 55 | self.mp_mask[c, offset:offset + block_numel] = torch.from_numpy(mask_block) 56 | 57 | def get_mask(self, t): 58 | self.mp_tensor = t 59 | self.mp_mask = torch.zeros_like(t) 60 | 61 | co, inners = t.shape 62 | n_blocks = inners // (self.bs ** 2) 63 | 64 | if self.RUN_SPEED_TEST: 65 | self.RUN_SPEED_TEST = False 66 | with Timer() as t: 67 | self.get_mask_iter(0) 68 | elapsed = t.total().total_seconds() 69 | print('Single core speed test: blocks={} secs={} block-time={}'.format(n_blocks, elapsed, elapsed/n_blocks)) 70 | 71 | p = Pool(self.n_workers) 72 | n_iterations = co 73 | bar = tqdm(total=n_iterations, ncols=80) if self.with_tqdm else None 74 | bar.set_postfix_str('n_processes={}, blocks/iter={}'.format(p._processes, n_blocks)) if self.with_tqdm else None 75 | block_indexes = range(co) 76 | for _ in p.imap_unordered(self.get_mask_iter, block_indexes): 77 | bar.update(1) if self.with_tqdm else None 78 | bar.close() if self.with_tqdm else None 79 | 
p.close() 80 | 81 | return self.mp_mask 82 | 83 | def compute_mask(self, t, default_mask): 84 | # permute and pad 85 | validate_tensor_shape_2d_4d(t) 86 | t_masked = t.clone().detach().mul_(default_mask) 87 | t_permuted = permute_to_nhwc(t_masked) 88 | pad_to = self.bs ** 2 89 | t_padded = pad_inner_dims(t_permuted, pad_to) 90 | t = t_padded.data.abs().to(t) 91 | 92 | # compute mask 93 | mask = self.get_mask(t) 94 | 95 | # restore to original shape 96 | block_mask = clip_padding(mask, t_permuted.shape).reshape(t_permuted.shape) 97 | block_mask = permute_to_nchw(block_mask) 98 | return block_mask 99 | -------------------------------------------------------------------------------- /prune/pruning_method_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def validate_tensor_shape_2d_4d(t): 5 | shape = t.shape 6 | if len(shape) not in (2, 4): 7 | raise ValueError( 8 | "Only 2D and 4D tensor shapes are supported. " 9 | "Found tensor of shape {} with {} dims".format(shape, len(shape)) 10 | ) 11 | 12 | 13 | def pad_inner_dims(t, pad_to): 14 | """ return padded-to-block tensor """ 15 | inner_flattened = t.view(t.shape[0], -1) 16 | co, inners = inner_flattened.shape 17 | pad_required = pad_to > 1 and inners % pad_to != 0 18 | pad_size = pad_to - inners % pad_to if pad_required else 0 19 | pad = torch.zeros(co, pad_size).to(inner_flattened.data) 20 | t_padded = torch.cat((inner_flattened, pad), 1) 21 | return t_padded 22 | 23 | 24 | def clip_padding(t, orig_shape): 25 | """ return tensor with clipped padding """ 26 | co = orig_shape[0] 27 | inners = 1 28 | for s in orig_shape[1:]: 29 | inners *= s 30 | t_clipped = t.view(co, -1)[:, :inners] 31 | return t_clipped 32 | 33 | 34 | def permute_to_nhwc(t): 35 | """ for 4D tensors, convert data layout from NCHW to NHWC """ 36 | res = t.permute(0, 2, 3, 1).contiguous() if t.dim() == 4 else t 37 | return res 38 | 39 | 40 | def permute_to_nchw(t): 41 | """ for 4D tensors, convert data layout from NHWC to NCHW """ 42 | res = t.permute(0, 3, 1, 2).contiguous() if t.dim() == 4 else t 43 | return res 44 | -------------------------------------------------------------------------------- /prune/sparsity_freezer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .pruning_method_based_mask import PruningMethodBasedMask 3 | 4 | 5 | class SparsityFreezer: 6 | """ Keeps sparsity level in model by applying pruning based on current zeros of parameters """ 7 | @staticmethod 8 | def freeze(model): 9 | with torch.no_grad(): 10 | params = SparsityFreezer._get_model_params(model) 11 | SparsityFreezer._enforce_mask_based_on_zeros(params) 12 | 13 | @staticmethod 14 | def _enforce_mask_based_on_zeros(params): 15 | prune_method = PruningMethodBasedMask() 16 | for param_info in params.values(): 17 | module, name = param_info 18 | param = getattr(module, name) 19 | mask = param.ne(0).float() 20 | prune_method.apply_like_self(module=module, name=name, mask=mask) 21 | 22 | @staticmethod 23 | def _get_model_params(model): 24 | params = {} 25 | for m_info in list(model.named_modules()): 26 | module_name, module = m_info 27 | for p_info in list(module.named_parameters(recurse=False)): 28 | param_name, param = p_info 29 | key = module_name + '.' 
+ param_name 30 | if param.dim() > 1 and 'bias' not in key and 'running' not in key: 31 | params[key] = (module, param_name) 32 | 33 | # a shared parameter will only appear once in model.named_parameters() 34 | # therefore, filter to get only parameters that appear in model.named_parameters() 35 | model_named_params = set([name for name, _ in model.named_parameters()]) 36 | params = {p: v for p, v in params.items() if p in model_named_params} 37 | return params 38 | -------------------------------------------------------------------------------- /static_TNM/scripts/prune_pretrained_R50.sh: -------------------------------------------------------------------------------- 1 | export datasets_dir=/datasets 2 | export dataset=imagenet 3 | export workdir='./results/static_TNM' 4 | 5 | echo $workdir 6 | cd .. 7 | python -m static_TNM.src.prune_pretrained_model -a resnet50 --save $workdir/resnet50-pruned.pth 8 | cp $workdir/resnet50-pruned.pth $workdir/resnet50.pth 9 | python -m vision.main --model resnet --resume $workdir/resnet50.pth --save $workdir --sparsity-freezer -b 256 --device-ids 0 1 2 3 --dataset $dataset --datasets-dir $datasets_dir 10 | -------------------------------------------------------------------------------- /static_TNM/src/prune_pretrained_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import torchvision.models as models 5 | from torch.hub import get_dir 6 | from glob import glob 7 | from prune.prune import prune 8 | 9 | 10 | def main(): 11 | # get supported models 12 | model_names = sorted(name for name in models.__dict__ 13 | if name.islower() and not name.startswith("__") 14 | and callable(models.__dict__[name])) 15 | 16 | # get arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=model_names, 19 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)') 20 | parser.add_argument('--save', default=None, type=str, help='pruned checkpoint') 21 | args = parser.parse_args() 22 | 23 | # if required, download pre-trained model 24 | models.__dict__[args.arch](pretrained=True) 25 | 26 | # get pre-trained checkpoint 27 | checkpoint_path = os.path.join(get_dir(), 'checkpoints') 28 | files = glob(os.path.join(checkpoint_path, '{}-*.pth').format(args.arch)) 29 | assert len(files) == 1 30 | checkpoint_file = files[0] 31 | save = args.save or checkpoint_file + '.pruned'  # fall back to prune()'s default so torch.load below never receives None 32 | # prune and save checkpoint 33 | prune(checkpoint=checkpoint_file, save=save, sd_key=None, bs=8, topk=4) 34 | 35 | # add expected fields to checkpoint 36 | sd = torch.load(save) 37 | checkpoint = {'state_dict': sd, 'epoch': 0, 'best_prec1': 0} 38 | torch.save(checkpoint, save) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /vision/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Elad Hoffer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright 
notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /vision/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | from torch.utils.data.distributed import DistributedSampler 5 | from torch.utils.data import Subset 6 | from torch._utils import _accumulate 7 | from vision.utils.regime import Regime 8 | from vision.utils.dataset import IndexedFileDataset 9 | from vision.preprocess import get_transform 10 | from itertools import chain 11 | from copy import deepcopy 12 | import warnings 13 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 14 | 15 | 16 | def get_dataset(name, split='train', transform=None, 17 | target_transform=None, download=True, datasets_path='~/Datasets'): 18 | train = (split == 'train') 19 | root = os.path.join(os.path.expanduser(datasets_path), name) 20 | if name == 'cifar10': 21 | return datasets.CIFAR10(root=root, 22 | train=train, 23 | transform=transform, 24 | target_transform=target_transform, 25 | download=download) 26 | elif name == 'cifar100': 27 | return datasets.CIFAR100(root=root, 28 | train=train, 29 | transform=transform, 30 | target_transform=target_transform, 31 | download=download) 32 | elif name == 'mnist': 33 | return datasets.MNIST(root=root, 34 | train=train, 35 | transform=transform, 36 | target_transform=target_transform, 37 | download=download) 38 | elif name == 'stl10': 39 | return datasets.STL10(root=root, 40 | split=split, 41 | transform=transform, 42 | target_transform=target_transform, 43 | download=download) 44 | elif name == 'imagenet': 45 | root = os.path.join(root, split) 46 | return datasets.ImageFolder(root=root, 47 | transform=transform, 48 | target_transform=target_transform) 49 | elif name == 'imagenet_tar': 50 | if train: 51 | root = os.path.join(root, 'imagenet_train.tar') 52 | else: 53 | root = os.path.join(root, 'imagenet_validation.tar') 54 | return IndexedFileDataset(root, extract_target_fn=( 55 | lambda fname: fname.split('/')[0]), 56 | transform=transform, 57 | target_transform=target_transform) 58 | 59 | 60 | _DATA_ARGS = {'name', 'split', 'transform', 61 | 'target_transform', 'download', 'datasets_path'} 62 | _DATALOADER_ARGS = {'batch_size', 'shuffle', 'sampler', 'batch_sampler', 63 | 'num_workers', 'collate_fn', 'pin_memory', 'drop_last', 64 | 'timeout', 'worker_init_fn'} 65 | _TRANSFORM_ARGS = {'transform_name', 'input_size', 'scale_size', 'normalize', 'augment', 66 | 'cutout', 'duplicates', 'num_crops', 'autoaugment'} 67 | _OTHER_ARGS = {'distributed'} 68 | 69 | 70 | class DataRegime(object): 71 | def __init__(self, regime, defaults={}): 72 | self.regime = Regime(regime, deepcopy(defaults)) 73 | self.epoch = 0 74 | self.steps = None 75 | self.get_loader(True) 76 | 77 | def get_setting(self): 78 | setting = self.regime.setting 79 | 
loader_setting = {k: v for k, 80 | v in setting.items() if k in _DATALOADER_ARGS} 81 | data_setting = {k: v for k, v in setting.items() if k in _DATA_ARGS} 82 | transform_setting = { 83 | k: v for k, v in setting.items() if k in _TRANSFORM_ARGS} 84 | other_setting = {k: v for k, v in setting.items() if k in _OTHER_ARGS} 85 | transform_setting.setdefault('transform_name', data_setting['name']) 86 | return {'data': data_setting, 'loader': loader_setting, 87 | 'transform': transform_setting, 'other': other_setting} 88 | 89 | def get(self, key, default=None): 90 | return self.regime.setting.get(key, default) 91 | 92 | def get_loader(self, force_update=False, override_settings=None, subset_indices=None): 93 | if force_update or self.regime.update(self.epoch, self.steps): 94 | setting = self.get_setting() 95 | if override_settings is not None: 96 | setting.update(override_settings) 97 | self._transform = get_transform(**setting['transform']) 98 | setting['data'].setdefault('transform', self._transform) 99 | self._data = get_dataset(**setting['data']) 100 | if subset_indices is not None: 101 | self._data = Subset(self._data, subset_indices) 102 | if setting['other'].get('distributed', False): 103 | setting['loader']['sampler'] = DistributedSampler(self._data) 104 | setting['loader']['shuffle'] = None 105 | # pin-memory currently broken for distributed 106 | setting['loader']['pin_memory'] = False 107 | self._sampler = setting['loader'].get('sampler', None) 108 | self._loader = torch.utils.data.DataLoader( 109 | self._data, **setting['loader']) 110 | return self._loader 111 | 112 | def set_epoch(self, epoch): 113 | self.epoch = epoch 114 | if self._sampler is not None and hasattr(self._sampler, 'set_epoch'): 115 | self._sampler.set_epoch(epoch) 116 | 117 | def __len__(self): 118 | return len(self._data) 119 | 120 | def __repr__(self): 121 | return str(self.regime) 122 | 123 | 124 | class SampledDataLoader(object): 125 | def __init__(self, dl_list): 126 | self.dl_list = dl_list 127 | self.epoch = 0 128 | 129 | def generate_order(self): 130 | 131 | order = [[idx]*len(dl) for idx, dl in enumerate(self.dl_list)] 132 | order = list(chain(*order)) 133 | g = torch.Generator() 134 | g.manual_seed(self.epoch) 135 | return torch.tensor(order)[torch.randperm(len(order), generator=g)].tolist() 136 | 137 | def __len__(self): 138 | return sum([len(dl) for dl in self.dl_list]) 139 | 140 | def __iter__(self): 141 | order = self.generate_order() 142 | 143 | iterators = [iter(dl) for dl in self.dl_list] 144 | for idx in order: 145 | yield next(iterators[idx]) 146 | return 147 | 148 | 149 | class SampledDataRegime(DataRegime): 150 | def __init__(self, data_regime_list, probs, split_data=True): 151 | self.probs = probs 152 | self.data_regime_list = data_regime_list 153 | self.split_data = split_data 154 | 155 | def get_setting(self): 156 | return [data_regime.get_setting() for data_regime in self.data_regime_list] 157 | 158 | def get(self, key, default=None): 159 | return [data_regime.get(key, default) for data_regime in self.data_regime_list] 160 | 161 | def get_loader(self, force_update=False): 162 | settings = self.get_setting() 163 | if self.split_data: 164 | dset_sizes = [len(get_dataset(**s['data'])) for s in settings] 165 | assert len(set(dset_sizes)) == 1, \ 166 | "all datasets should be same size" 167 | dset_size = dset_sizes[0] 168 | lengths = [int(prob * dset_size) for prob in self.probs] 169 | lengths[-1] = dset_size - sum(lengths[:-1]) 170 | indices = torch.randperm(dset_size).tolist() 171 | 
indices_split = [indices[offset - length:offset] 172 | for offset, length in zip(_accumulate(lengths), lengths)] 173 | loaders = [data_regime.get_loader(force_update=True, subset_indices=indices_split[i]) 174 | for i, data_regime in enumerate(self.data_regime_list)] 175 | else: 176 | loaders = [data_regime.get_loader( 177 | force_update=force_update) for data_regime in self.data_regime_list] 178 | self._loader = SampledDataLoader(loaders) 179 | self._loader.epoch = self.epoch 180 | 181 | return self._loader 182 | 183 | def set_epoch(self, epoch): 184 | self.epoch = epoch 185 | if hasattr(self, '_loader'): 186 | self._loader.epoch = epoch 187 | for data_regime in self.data_regime_list: 188 | if data_regime._sampler is not None and hasattr(data_regime._sampler, 'set_epoch'): 189 | data_regime._sampler.set_epoch(epoch) 190 | 191 | def __len__(self): 192 | return sum([len(data_regime._data) 193 | for data_regime in self.data_regime_list]) 194 | 195 | def __repr__(self): 196 | print_str = 'Sampled Data Regime:\n' 197 | for p, config in zip(self.probs, self.data_regime_list): 198 | print_str += 'w.p. %s: %s\n' % (p, config) 199 | return print_str 200 | 201 | 202 | if __name__ == '__main__': 203 | reg1 = DataRegime(None, {'name': 'imagenet', 'batch_size': 16}) 204 | reg2 = DataRegime(None, {'name': 'imagenet', 'batch_size': 32}) 205 | reg1.set_epoch(0) 206 | reg2.set_epoch(0) 207 | mreg = SampledDataRegime([reg1, reg2], [0.5, 0.5])  # sampling probabilities are a required argument 208 | 209 | for x, _ in mreg.get_loader(): 210 | print(x.shape) 211 | -------------------------------------------------------------------------------- /vision/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import * 2 | from .resnet import * 3 | -------------------------------------------------------------------------------- /vision/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | 4 | __all__ = ['alexnet'] 5 | 6 | 7 | class AlexNetOWT_BN(nn.Module): 8 | 9 | def __init__(self, num_classes=1000): 10 | super(AlexNetOWT_BN, self).__init__() 11 | self.features = nn.Sequential( 12 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, 13 | bias=False), 14 | nn.MaxPool2d(kernel_size=3, stride=2), 15 | nn.BatchNorm2d(64), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(64, 192, kernel_size=5, padding=2, bias=False), 18 | nn.MaxPool2d(kernel_size=3, stride=2), 19 | nn.ReLU(inplace=True), 20 | nn.BatchNorm2d(192), 21 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False), 22 | nn.ReLU(inplace=True), 23 | nn.BatchNorm2d(384), 24 | nn.Conv2d(384, 256, kernel_size=3, padding=1, bias=False), 25 | nn.ReLU(inplace=True), 26 | nn.BatchNorm2d(256), 27 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 28 | nn.MaxPool2d(kernel_size=3, stride=2), 29 | nn.ReLU(inplace=True), 30 | nn.BatchNorm2d(256) 31 | ) 32 | self.classifier = nn.Sequential( 33 | nn.Linear(256 * 6 * 6, 4096, bias=False), 34 | nn.BatchNorm1d(4096), 35 | nn.ReLU(inplace=True), 36 | nn.Dropout(0.5), 37 | nn.Linear(4096, 4096, bias=False), 38 | nn.BatchNorm1d(4096), 39 | nn.ReLU(inplace=True), 40 | nn.Dropout(0.5), 41 | nn.Linear(4096, num_classes) 42 | ) 43 | 44 | self.regime = [ 45 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-2, 46 | 'weight_decay': 5e-4, 'momentum': 0.9}, 47 | {'epoch': 10, 'lr': 5e-3}, 48 | {'epoch': 15, 'lr': 1e-3, 'weight_decay': 0}, 49 | {'epoch': 20, 'lr': 5e-4}, 50 | {'epoch': 25, 'lr': 1e-4} 51 | ] 52 | 
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | self.data_regime = [{ 55 | 'transform': transforms.Compose([ 56 | transforms.Resize(256), 57 | transforms.RandomCrop(224), 58 | transforms.RandomHorizontalFlip(), 59 | transforms.ToTensor(), 60 | normalize]) 61 | }] 62 | self.data_eval_regime = [{ 63 | 'transform': transforms.Compose([ 64 | transforms.Resize(256), 65 | transforms.CenterCrop(224), 66 | transforms.ToTensor(), 67 | normalize]) 68 | }] 69 | def forward(self, x): 70 | x = self.features(x) 71 | x = x.view(-1, 256 * 6 * 6) 72 | x = self.classifier(x) 73 | return x 74 | 75 | 76 | def alexnet(**kwargs): 77 | num_classes = kwargs.get('num_classes', 1000)  # kwargs is a plain dict, so dict.get, not getattr 78 | return AlexNetOWT_BN(num_classes) 79 | -------------------------------------------------------------------------------- /vision/models/modules/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | @torch.jit.script 7 | def swish(x): 8 | # type: (Tensor) -> Tensor 9 | return x * x.sigmoid() 10 | 11 | 12 | @torch.jit.script 13 | def hard_sigmoid(x): 14 | # type: (Tensor) -> Tensor 15 | return F.relu6(x+3).div_(6) 16 | 17 | 18 | @torch.jit.script 19 | def hard_swish(x): 20 | # type: (Tensor) -> Tensor 21 | return x * hard_sigmoid(x) 22 | 23 | 24 | class Swish(nn.Module): 25 | def __init__(self): 26 | super(Swish, self).__init__() 27 | 28 | def forward(self, x): 29 | return swish(x) 30 | 31 | 32 | class HardSigmoid(nn.Module): 33 | def __init__(self): 34 | super(HardSigmoid, self).__init__() 35 | 36 | def forward(self, x): 37 | return hard_sigmoid(x) 38 | 39 | 40 | class HardSwish(nn.Module): 41 | def __init__(self): 42 | super(HardSwish, self).__init__() 43 | 44 | def forward(self, x): 45 | return hard_swish(x) 46 | -------------------------------------------------------------------------------- /vision/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 4 | 5 | 6 | class CheckpointModule(nn.Module): 7 | def __init__(self, module, num_segments=1): 8 | super(CheckpointModule, self).__init__() 9 | assert num_segments == 1 or isinstance(module, nn.Sequential) 10 | self.module = module 11 | self.num_segments = num_segments 12 | 13 | def forward(self, *inputs): 14 | if self.num_segments > 1: 15 | return checkpoint_sequential(self.module, self.num_segments, *inputs) 16 | else: 17 | return checkpoint(self.module, *inputs) 18 | -------------------------------------------------------------------------------- /vision/models/modules/se.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .activations import Swish, HardSwish, HardSigmoid 4 | 5 | 6 | class SEBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels=None, ratio=16): 8 | super(SEBlock, self).__init__() 9 | self.in_channels = in_channels 10 | if out_channels is None: 11 | out_channels = in_channels 12 | self.ratio = ratio 13 | self.relu = nn.ReLU(True) 14 | self.global_pool = nn.AdaptiveAvgPool2d(1) 15 | self.transform = nn.Sequential( 16 | nn.Linear(in_channels, in_channels // ratio), 17 | nn.ReLU(inplace=True), 18 | nn.Linear(in_channels // ratio, out_channels), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | x_avg = self.global_pool(x).flatten(1, -1) 24 | mask = self.transform(x_avg) 25 | return x * mask.unsqueeze(-1).unsqueeze(-1) 26 | 27 | 
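# Squeeze-and-excitation in one formula: channels are globally average-pooled
# ("squeeze"), passed through a bottleneck MLP gated by a sigmoid
# ("excitation"), and the result rescales the feature map. For x of shape
# (B, C, H, W), SEBlock above computes
#
#     s = sigmoid(W2 @ relu(W1 @ mean_{H,W}(x)))      # s has shape (B, C)
#     out = x * s[..., None, None]
#
# SESwishBlock below is the MBConv-style variant with (Hard)Swish activations.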
28 | class SESwishBlock(nn.Module): 29 | """ squeeze-excite block for MBConv """ 30 | 31 | def __init__(self, in_channels, out_channels=None, interm_channels=None, ratio=None, hard_act=False): 32 | super(SESwishBlock, self).__init__() 33 | assert not (interm_channels is None and ratio is None) 34 | interm_channels = interm_channels or in_channels // ratio 35 | self.in_channels = in_channels 36 | if out_channels is None: 37 | out_channels = in_channels 38 | self.ratio = ratio 39 | self.activation = HardSwish() if hard_act else Swish()  # trailing comma removed: it turned this attribute into a 1-tuple 40 | self.global_pool = nn.AdaptiveAvgPool2d(1) 41 | self.transform = nn.Sequential( 42 | nn.Linear(in_channels, interm_channels), 43 | HardSwish() if hard_act else Swish(), 44 | nn.Linear(interm_channels, out_channels), 45 | HardSigmoid() if hard_act else nn.Sigmoid() 46 | ) 47 | 48 | def forward(self, x): 49 | x_avg = self.global_pool(x).flatten(1, -1) 50 | mask = self.transform(x_avg) 51 | return x * mask.unsqueeze(-1).unsqueeze(-1) 52 | -------------------------------------------------------------------------------- /vision/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.transforms as transforms 4 | from vision.autoaugment import ImageNetPolicy, CIFAR10Policy 5 | 6 | 7 | _IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406], 8 | 'std': [0.229, 0.224, 0.225]} 9 | 10 | 11 | _IMAGENET_PCA = { 12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 13 | 'eigvec': torch.Tensor([ 14 | [-0.5675, 0.7192, 0.4009], 15 | [-0.5808, -0.0045, -0.8140], 16 | [-0.5836, -0.6948, 0.4203], 17 | ]) 18 | } 19 | 20 | 21 | def scale_crop(input_size, scale_size=None, num_crops=1, normalize=_IMAGENET_STATS): 22 | assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}" 23 | convert_tensor = transforms.Compose([transforms.ToTensor(), 24 | transforms.Normalize(**normalize)]) 25 | if num_crops == 1: 26 | t_list = [ 27 | transforms.CenterCrop(input_size), 28 | convert_tensor 29 | ] 30 | else: 31 | if num_crops == 5: 32 | t_list = [transforms.FiveCrop(input_size)] 33 | elif num_crops == 10: 34 | t_list = [transforms.TenCrop(input_size)] 35 | # returns a 4D tensor 36 | t_list.append(transforms.Lambda(lambda crops: 37 | torch.stack([convert_tensor(crop) for crop in crops]))) 38 | 39 | if scale_size != input_size: 40 | t_list = [transforms.Resize(scale_size)] + t_list 41 | 42 | return transforms.Compose(t_list) 43 | 44 | 45 | def random_crop(input_size, scale_size=None, padding=None, normalize=_IMAGENET_STATS): 46 | scale_size = scale_size or input_size 47 | T = transforms.Compose([ 48 | transforms.RandomCrop(scale_size, padding=padding), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize(**normalize), 52 | ]) 53 | if input_size != scale_size: 54 | T.transforms.insert(1, transforms.Resize(input_size)) 55 | return T 56 | 57 | 58 | def pad_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS): 59 | padding = int((scale_size - input_size) / 2) 60 | return transforms.Compose([ 61 | transforms.RandomCrop(input_size, padding=padding), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | transforms.Normalize(**normalize), 65 | ]) 66 | 67 | 
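# Evaluation-time multi-crop sketch (illustrative use of scale_crop above):
# with num_crops=10 the transform returns a 4D tensor per image, so a batch
# collates to (B, 10, C, H, W) and predictions are typically averaged over the
# crop dimension:
#
#     tf = scale_crop(input_size=224, scale_size=256, num_crops=10)
#     crops = tf(img)              # -> (10, 3, 224, 224)
#     logits = model(crops).mean(0)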
68 | def cifar_autoaugment(input_size, scale_size=None, padding=None, normalize=_IMAGENET_STATS): 69 | scale_size = scale_size or input_size 70 | T = transforms.Compose([ 71 | transforms.RandomCrop(scale_size, padding=padding), 72 | transforms.RandomHorizontalFlip(), 73 | CIFAR10Policy(fillcolor=(128, 128, 128)), 74 | transforms.ToTensor(), 75 | transforms.Normalize(**normalize), 76 | ]) 77 | if input_size != scale_size: 78 | T.transforms.insert(1, transforms.Resize(input_size)) 79 | return T 80 | 81 | 82 | def inception_preprocess(input_size, normalize=_IMAGENET_STATS): 83 | return transforms.Compose([ 84 | transforms.RandomResizedCrop(input_size), 85 | transforms.RandomHorizontalFlip(), 86 | transforms.ToTensor(), 87 | transforms.Normalize(**normalize) 88 | ]) 89 | 90 | 91 | def inception_autoaugment_preprocess(input_size, normalize=_IMAGENET_STATS): 92 | return transforms.Compose([ 93 | transforms.RandomResizedCrop(input_size), 94 | transforms.RandomHorizontalFlip(), 95 | ImageNetPolicy(fillcolor=(128, 128, 128)), 96 | transforms.ToTensor(), 97 | transforms.Normalize(**normalize) 98 | ]) 99 | 100 | 101 | def inception_color_preprocess(input_size, normalize=_IMAGENET_STATS): 102 | return transforms.Compose([ 103 | transforms.RandomResizedCrop(input_size), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ColorJitter( 106 | brightness=0.4, 107 | contrast=0.4, 108 | saturation=0.4, 109 | ), 110 | transforms.ToTensor(), 111 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), 112 | transforms.Normalize(**normalize) 113 | ]) 114 | 115 | 116 | def multi_transform(transform_fn, duplicates=1, dim=0): 117 | """performs multiple transforms, useful to implement inference time augmentation or 118 | "batch augmentation" from https://openreview.net/forum?id=H1V4QhAqYQ&noteId=BylUSs_3Y7 119 | """ 120 | if duplicates > 1: 121 | return transforms.Lambda(lambda x: torch.stack([transform_fn(x) for _ in range(duplicates)], dim=dim)) 122 | else: 123 | return transform_fn 124 | 125 | 126 | def get_transform(transform_name='imagenet', input_size=None, scale_size=None, 127 | normalize=None, augment=True, cutout=None, autoaugment=False, 128 | padding=None, duplicates=1, num_crops=1): 129 | normalize = normalize or _IMAGENET_STATS 130 | transform_fn = None 131 | if 'imagenet' in transform_name: # inception augmentation is default for imagenet 132 | input_size = input_size or 224 133 | scale_size = scale_size or int(input_size * 8/7) 134 | if augment: 135 | if autoaugment: 136 | transform_fn = inception_autoaugment_preprocess(input_size, 137 | normalize=normalize) 138 | else: 139 | transform_fn = inception_preprocess(input_size, 140 | normalize=normalize) 141 | else: 142 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 143 | num_crops=num_crops, normalize=normalize) 144 | elif 'cifar' in transform_name: # resnet-style augmentation is default for cifar 145 | input_size = input_size or 32 146 | if augment: 147 | scale_size = scale_size or 32 148 | padding = padding or 4 149 | if autoaugment: 150 | transform_fn = cifar_autoaugment(input_size, scale_size=scale_size, 151 | padding=padding, normalize=normalize) 152 | else: 153 | transform_fn = random_crop(input_size, scale_size=scale_size, 154 | padding=padding, normalize=normalize) 155 | else: 156 | scale_size = scale_size or 32 157 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 158 | num_crops=num_crops, normalize=normalize) 159 | elif transform_name == 'mnist': 160 | normalize = {'mean': [0.5], 'std': [0.5]} 161 | input_size = input_size or 28 162 | if augment: 163 | scale_size = scale_size or 32 164 | transform_fn = 
pad_random_crop(input_size, scale_size=scale_size, 165 | normalize=normalize) 166 | else: 167 | scale_size = scale_size or 32 168 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size, 169 | num_crops=num_crops, normalize=normalize) 170 | if cutout is not None: 171 | transform_fn.transforms.append(Cutout(**cutout)) 172 | return multi_transform(transform_fn, duplicates) 173 | 174 | 175 | class Lighting(object): 176 | """Lighting noise (AlexNet-style PCA-based noise)""" 177 | 178 | def __init__(self, alphastd, eigval, eigvec): 179 | self.alphastd = alphastd 180 | self.eigval = eigval 181 | self.eigvec = eigvec 182 | 183 | def __call__(self, img): 184 | if self.alphastd == 0: 185 | return img 186 | 187 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 188 | rgb = self.eigvec.type_as(img).clone()\ 189 | .mul(alpha.view(1, 3).expand(3, 3))\ 190 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 191 | .sum(1).squeeze() 192 | 193 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 194 | 195 | 196 | class Cutout(object): 197 | """ 198 | Randomly mask out one or more patches from an image. 199 | taken from https://github.com/uoguelph-mlrg/Cutout 200 | 201 | 202 | Args: 203 | holes (int): Number of patches to cut out of each image. 204 | length (int): The length (in pixels) of each square patch. 205 | """ 206 | 207 | def __init__(self, holes, length): 208 | self.holes = holes 209 | self.length = length 210 | 211 | def __call__(self, img): 212 | """ 213 | Args: 214 | img (Tensor): Tensor image of size (C, H, W). 215 | Returns: 216 | Tensor: Image with holes of dimension length x length cut out of it. 217 | """ 218 | h = img.size(1) 219 | w = img.size(2) 220 | 221 | mask = np.ones((h, w), np.float32) 222 | 223 | for n in range(self.holes): 224 | y = np.random.randint(h) 225 | x = np.random.randint(w) 226 | 227 | y1 = np.clip(y - self.length // 2, 0, h) 228 | y2 = np.clip(y + self.length // 2, 0, h) 229 | x1 = np.clip(x - self.length // 2, 0, w) 230 | x2 = np.clip(x + self.length // 2, 0, w) 231 | 232 | mask[y1: y2, x1: x2] = 0. 233 | 234 | mask = torch.from_numpy(mask) 235 | mask = mask.expand_as(img) 236 | img = img * mask 237 | 238 | return img 239 | -------------------------------------------------------------------------------- /vision/utils/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Elad Hoffer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /vision/utils/absorb_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | 5 | 6 | def remove_bn_params(bn_module): 7 | bn_module.register_buffer('running_mean', None) 8 | bn_module.register_buffer('running_var', None) 9 | bn_module.register_parameter('weight', None) 10 | bn_module.register_parameter('bias', None) 11 | 12 | 13 | def init_bn_params(bn_module): 14 | bn_module.running_mean.fill_(0) 15 | bn_module.running_var.fill_(1) 16 | 17 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False): 18 | with torch.no_grad(): 19 | w = module.weight 20 | if module.bias is None: 21 | zeros = torch.zeros(module.out_channels, 22 | dtype=w.dtype, device=w.device) 23 | bias = nn.Parameter(zeros) 24 | module.register_parameter('bias', bias) 25 | b = module.bias 26 | 27 | if hasattr(bn_module, 'running_mean'): 28 | b.add_(-bn_module.running_mean) 29 | if hasattr(bn_module, 'running_var'): 30 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5) 31 | w.mul_(invstd.view(w.size(0), 1, 1, 1)) 32 | b.mul_(invstd) 33 | 34 | if remove_bn: 35 | if hasattr(bn_module, 'weight'): 36 | w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1)) 37 | b.mul_(bn_module.weight) 38 | if hasattr(bn_module, 'bias'): 39 | b.add_(bn_module.bias) 40 | remove_bn_params(bn_module) 41 | else: 42 | init_bn_params(bn_module) 43 | 44 | if verbose: 45 | logging.info('BN module %s was absorbed into layer %s' % 46 | (bn_module, module)) 47 | 48 | 49 | def is_bn(m): 50 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 51 | 52 | 53 | def is_absorbing(m): 54 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) 55 | 56 | 57 | def search_absorb_bn(model, prev=None, remove_bn=True, verbose=False): 58 | with torch.no_grad(): 59 | for m in model.children(): 60 | if is_bn(m) and is_absorbing(prev): 61 | absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose) 62 | search_absorb_bn(m, remove_bn=remove_bn, verbose=verbose) 63 | prev = m 64 | -------------------------------------------------------------------------------- /vision/utils/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .misc import onehot 5 | 6 | 7 | def _is_long(x): 8 | if hasattr(x, 'data'): 9 | x = x.data 10 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor) 11 | 12 | 13 | def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean', 14 | smooth_eps=None, smooth_dist=None, from_logits=True): 15 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567""" 16 | smooth_eps = smooth_eps or 0 17 | 18 | # ordinary log-likelihood - use cross_entropy from nn 19 | if _is_long(target) and smooth_eps == 0: 20 | if from_logits: 21 | return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction) 22 | else: 23 | return F.nll_loss(inputs, target, weight, ignore_index=ignore_index, reduction=reduction) 24 | 25 | if from_logits: 26 | # log-softmax of inputs 27 | lsm = F.log_softmax(inputs, dim=-1) 28 | else: 29 | lsm = inputs 30 | 31 | masked_indices = None 32 | num_classes = inputs.size(-1) 33 | 34 | if _is_long(target) and ignore_index >= 0: 35 | masked_indices = target.eq(ignore_index)
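# if a smoothing distribution is given, blend it into the (one-hot) target: target <- (1 - eps) * target + eps * smooth_dist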
36 | 37 | if smooth_eps > 0 and smooth_dist is not None: 38 | if _is_long(target): 39 | target = onehot(target, num_classes).type_as(inputs) 40 | if smooth_dist.dim() < target.dim(): 41 | smooth_dist = smooth_dist.unsqueeze(0) 42 | target.lerp_(smooth_dist, smooth_eps) 43 | 44 | if weight is not None: 45 | lsm = lsm * weight.unsqueeze(0) 46 | 47 | if _is_long(target): 48 | eps_sum = smooth_eps / num_classes 49 | eps_nll = 1. - eps_sum - smooth_eps 50 | likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) 51 | loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1)) 52 | else: 53 | loss = -(target * lsm).sum(-1) 54 | 55 | if masked_indices is not None: 56 | loss.masked_fill_(masked_indices, 0) 57 | 58 | if reduction == 'sum': 59 | loss = loss.sum() 60 | elif reduction == 'mean': 61 | if masked_indices is None: 62 | loss = loss.mean() 63 | else: 64 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum()) 65 | 66 | return loss 67 | 68 | 69 | class CrossEntropyLoss(nn.CrossEntropyLoss): 70 | """CrossEntropyLoss - with ability to receive distribution as targets, and optional label smoothing""" 71 | 72 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None, from_logits=True): 73 | super(CrossEntropyLoss, self).__init__(weight=weight, 74 | ignore_index=ignore_index, reduction=reduction) 75 | self.smooth_eps = smooth_eps 76 | self.smooth_dist = smooth_dist 77 | self.from_logits = from_logits 78 | 79 | def forward(self, input, target, smooth_dist=None): 80 | if smooth_dist is None: 81 | smooth_dist = self.smooth_dist 82 | return cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, 83 | reduction=self.reduction, smooth_eps=self.smooth_eps, 84 | smooth_dist=smooth_dist, from_logits=self.from_logits) 85 | 86 | 87 | def binary_cross_entropy(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=False): 88 | """cross entropy loss, with support for label smoothing https://arxiv.org/abs/1512.00567""" 89 | smooth_eps = smooth_eps or 0 90 | if smooth_eps > 0: 91 | target = target.float() 92 | target.add_(smooth_eps).div_(2.)
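# with logits, defer to the fused sigmoid + BCE form, which is more numerically stable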
93 | if from_logits: 94 | return F.binary_cross_entropy_with_logits(inputs, target, weight=weight, reduction=reduction) 95 | else: 96 | return F.binary_cross_entropy(inputs, target, weight=weight, reduction=reduction) 97 | 98 | 99 | def binary_cross_entropy_with_logits(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=True): 100 | return binary_cross_entropy(inputs, target, weight, reduction, smooth_eps, from_logits) 101 | 102 | 103 | class BCELoss(nn.BCELoss): 104 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=False): 105 | super(BCELoss, self).__init__(weight, size_average, reduce, reduction) 106 | self.smooth_eps = smooth_eps 107 | self.from_logits = from_logits 108 | 109 | def forward(self, input, target): 110 | return binary_cross_entropy(input, target, 111 | weight=self.weight, reduction=self.reduction, 112 | smooth_eps=self.smooth_eps, from_logits=self.from_logits) 113 | 114 | 115 | class BCEWithLogitsLoss(BCELoss): 116 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=True): 117 | super(BCEWithLogitsLoss, self).__init__(weight, size_average, 118 | reduce, reduction, smooth_eps=smooth_eps, from_logits=from_logits) 119 | -------------------------------------------------------------------------------- /vision/utils/log.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from itertools import cycle 4 | import torch 5 | import logging.config 6 | import json 7 | 8 | import pandas as pd 9 | from bokeh.io import output_file, save, show 10 | from bokeh.plotting import figure 11 | from bokeh.layouts import column 12 | from bokeh.models import Div 13 | 14 | try: 15 | import hyperdash 16 | HYPERDASH_AVAILABLE = True 17 | except ImportError: 18 | HYPERDASH_AVAILABLE = False 19 | 20 | 21 | def export_args_namespace(args, filename): 22 | """ 23 | args: argparse.Namespace 24 | arguments to save 25 | filename: string 26 | filename to save at 27 | """ 28 | with open(filename, 'w') as fp: 29 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) 30 | 31 | 32 | def setup_logging(log_file='log.txt', resume=False, dummy=False): 33 | """ 34 | Setup logging configuration 35 | """ 36 | if dummy: 37 | logging.getLogger('dummy') 38 | else: 39 | if os.path.isfile(log_file) and resume: 40 | file_mode = 'a' 41 | else: 42 | file_mode = 'w' 43 | 44 | root_logger = logging.getLogger() 45 | if root_logger.handlers: 46 | root_logger.removeHandler(root_logger.handlers[0]) 47 | logging.basicConfig(level=logging.INFO, 48 | format="%(asctime)s - %(levelname)s - %(message)s", 49 | datefmt="%Y-%m-%d %H:%M:%S", 50 | filename=log_file, 51 | filemode=file_mode) 52 | console = logging.StreamHandler() 53 | console.setLevel(logging.INFO) 54 | formatter = logging.Formatter('%(message)s') 55 | console.setFormatter(formatter) 56 | logging.getLogger('').addHandler(console) 57 | 58 | 59 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None, 60 | x_axis_type='linear', y_axis_type='linear', 61 | width=800, height=400, line_width=2, 62 | colors=['red', 'green', 'blue', 'orange', 63 | 'black', 'purple', 'brown'], 64 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save', 65 | append_figure=None): 66 | """ 67 | creates a new plot figure 68 | example: 69 | plot_figure(x='epoch', y=['train_loss', 'val_loss'], 70 | title='Loss', ylabel='loss') 71 | """ 72 | if not
isinstance(y, list): 73 | y = [y] 74 | xlabel = xlabel or x 75 | legend = legend or y 76 | assert len(legend) == len(y) 77 | if append_figure is not None: 78 | f = append_figure 79 | else: 80 | f = figure(title=title, tools=tools, 81 | width=width, height=height, 82 | x_axis_label=xlabel or x, 83 | y_axis_label=ylabel or '', 84 | x_axis_type=x_axis_type, 85 | y_axis_type=y_axis_type) 86 | colors = cycle(colors) 87 | for i, yi in enumerate(y): 88 | f.line(data[x], data[yi], 89 | line_width=line_width, 90 | line_color=next(colors), legend_label=legend[i]) 91 | f.legend.click_policy = "hide" 92 | return f 93 | 94 | 95 | class ResultsLog(object): 96 | 97 | supported_data_formats = ['csv', 'json'] 98 | 99 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'): 100 | """ 101 | Parameters 102 | ---------- 103 | path: string 104 | path to directory to save data files 105 | plot_path: string 106 | path to directory to save plot files 107 | title: string 108 | title of HTML file 109 | params: Namespace 110 | optionally save parameters for results 111 | resume: bool 112 | resume previous logging 113 | data_format: str('csv'|'json') 114 | which file format to use to save the data 115 | """ 116 | if data_format not in ResultsLog.supported_data_formats: 117 | raise ValueError('data_format must be one of the following: ' + 118 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats])) 119 | 120 | if data_format == 'json': 121 | self.data_path = '{}.json'.format(path) 122 | else: 123 | self.data_path = '{}.csv'.format(path) 124 | if params is not None: 125 | export_args_namespace(params, '{}.json'.format(path)) 126 | self.plot_path = '{}.html'.format(path) 127 | self.results = None 128 | self.clear() 129 | self.first_save = True 130 | if os.path.isfile(self.data_path): 131 | if resume: 132 | self.load(self.data_path) 133 | self.first_save = False 134 | else: 135 | os.remove(self.data_path) 136 | self.results = pd.DataFrame() 137 | else: 138 | self.results = pd.DataFrame() 139 | 140 | self.title = title 141 | self.data_format = data_format 142 | 143 | if HYPERDASH_AVAILABLE: 144 | name = self.title if title != '' else path 145 | self.hd_experiment = hyperdash.Experiment(name) 146 | if params is not None: 147 | for k, v in params._get_kwargs(): 148 | self.hd_experiment.param(k, v, log=False) 149 | 150 | def clear(self): 151 | self.figures = [] 152 | 153 | def add(self, **kwargs): 154 | """Add a new row to the dataframe 155 | example: 156 | resultsLog.add(epoch=epoch_num, train_loss=loss, 157 | test_loss=test_loss) 158 | """ 159 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 160 | self.results = pd.concat([self.results, df], ignore_index=True) 161 | if hasattr(self, 'hd_experiment'): 162 | for k, v in kwargs.items(): 163 | self.hd_experiment.metric(k, v, log=False) 164 | 165 | def smooth(self, column_name, window): 166 | """Select an entry to smooth over time""" 167 | # TODO: smooth only new data 168 | smoothed_column = self.results[column_name].rolling( 169 | window=window, center=False).mean() 170 | self.results[column_name + '_smoothed'] = smoothed_column 171 | 172 | def save(self, title=None): 173 | """save the accumulated data (csv|json) and any pending plots.
174 | Parameters 175 | ---------- 176 | title: string 177 | title of the HTML file 178 | """ 179 | title = title or self.title 180 | if len(self.figures) > 0: 181 | if os.path.isfile(self.plot_path): 182 | os.remove(self.plot_path) 183 | if self.first_save: 184 | self.first_save = False 185 | logging.info('Plot file saved at: {}'.format( 186 | os.path.abspath(self.plot_path))) 187 | 188 | output_file(self.plot_path, title=title) 189 | plot = column( 190 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures) 191 | save(plot) 192 | self.clear() 193 | 194 | if self.data_format == 'json': 195 | self.results.to_json(self.data_path, orient='records', lines=True) 196 | else: 197 | self.results.to_csv(self.data_path, index=False, index_label=False) 198 | 199 | def load(self, path=None): 200 | """load the data file 201 | Parameters 202 | ---------- 203 | path: 204 | path to load the json|csv file from 205 | """ 206 | path = path or self.data_path 207 | if os.path.isfile(path): 208 | if self.data_format == 'json': 209 | self.results = pd.read_json(path, orient='records', lines=True) 210 | else: 211 | self.results = pd.read_csv(path) 212 | else: 213 | raise ValueError("{} isn't a file".format(path)) 214 | 215 | def show(self, title=None): 216 | title = title or self.title 217 | if len(self.figures) > 0: 218 | plot = column( 219 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures) 220 | show(plot) 221 | 222 | def plot(self, *kargs, **kwargs): 223 | """ 224 | add a new plot to the HTML file 225 | example: 226 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 227 | title='Loss', ylabel='loss') 228 | """ 229 | f = plot_figure(self.results, *kargs, **kwargs) 230 | self.figures.append(f) 231 | 232 | def image(self, *kargs, **kwargs): 233 | fig = figure() 234 | fig.image(*kargs, **kwargs) 235 | self.figures.append(fig) 236 | 237 | def end(self): 238 | if hasattr(self, 'hd_experiment'): 239 | self.hd_experiment.end() 240 | 241 | 242 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 243 | filename = os.path.join(path, filename) 244 | torch.save(state, filename) 245 | if is_best: 246 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 247 | if save_all: 248 | shutil.copyfile(filename, os.path.join( 249 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 250 | -------------------------------------------------------------------------------- /vision/utils/meters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class OnlineMeter(object): 24 | """Computes and stores the average and variance/std values of tensor""" 25 | 26 | def __init__(self): 27 | self.mean = torch.FloatTensor(1).fill_(-1) 28 | self.M2 = torch.FloatTensor(1).zero_() 29 | self.count = 0. 30 | self.needs_init = True 31 | 32 | def reset(self, x): 33 | self.mean = x.new(x.size()).zero_() 34 | self.M2 = x.new(x.size()).zero_() 35 | self.count = 0.
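# buffers now match the incoming tensor's shape; update() maintains Welford-style running mean and M2 (sum of squared deviations)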
36 | self.needs_init = False 37 | 38 | def update(self, x): 39 | self.val = x 40 | if self.needs_init: 41 | self.reset(x) 42 | self.count += 1 43 | delta = x - self.mean 44 | self.mean.add_(delta / self.count) 45 | delta2 = x - self.mean 46 | self.M2.add_(delta * delta2) 47 | 48 | @property 49 | def var(self): 50 | if self.count < 2: 51 | return self.M2.clone().zero_() 52 | return self.M2 / (self.count - 1) 53 | 54 | @property 55 | def std(self): 56 | return self.var.sqrt() 57 | 58 | 59 | def accuracy(output, target, topk=(1,)): 60 | """Computes the precision@k for the specified values of k""" 61 | maxk = max(topk) 62 | batch_size = target.size(0) 63 | 64 | _, pred = output.topk(maxk, 1, True, True) 65 | pred = pred.t().type_as(target) 66 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 67 | 68 | res = [] 69 | for k in topk: 70 | correct_k = correct[:k].reshape(-1).float().sum(0) 71 | res.append(correct_k.mul_(100.0 / batch_size)) 72 | return res 73 | 74 | 75 | class AccuracyMeter(object): 76 | """Computes and stores the average and current topk accuracy""" 77 | 78 | def __init__(self, topk=(1,)): 79 | self.topk = topk 80 | self.reset() 81 | 82 | def reset(self): 83 | self._meters = {} 84 | for k in self.topk: 85 | self._meters[k] = AverageMeter() 86 | 87 | def update(self, output, target): 88 | n = target.nelement() 89 | acc_vals = accuracy(output, target, self.topk) 90 | for i, k in enumerate(self.topk): 91 | self._meters[k].update(acc_vals[i], n) 92 | 93 | @property 94 | def val(self): 95 | return {n: meter.val for (n, meter) in self._meters.items()} 96 | 97 | @property 98 | def avg(self): 99 | return {n: meter.avg for (n, meter) in self._meters.items()} 100 | 101 | @property 102 | def avg_error(self): 103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()} 104 | -------------------------------------------------------------------------------- /vision/utils/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 6 | 7 | torch_dtypes = { 8 | 'float': torch.float, 9 | 'float32': torch.float32, 10 | 'float64': torch.float64, 11 | 'double': torch.double, 12 | 'float16': torch.float16, 13 | 'half': torch.half, 14 | 'uint8': torch.uint8, 15 | 'int8': torch.int8, 16 | 'int16': torch.int16, 17 | 'short': torch.short, 18 | 'int32': torch.int32, 19 | 'int': torch.int, 20 | 'int64': torch.int64, 21 | 'long': torch.long 22 | } 23 | 24 | 25 | def onehot(indexes, N=None, ignore_index=None): 26 | """ 27 | Creates a one-hot representation of indexes with N possible entries 28 | if N is not specified, it will suit the maximum index appearing.
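e.g. (hypothetical) onehot(torch.tensor([0, 2]), N=3) -> [[1, 0, 0], [0, 0, 1]]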
29 | indexes is a long-tensor of indexes 30 | ignore_index will be zero in onehot representation 31 | """ 32 | if N is None: 33 | N = indexes.max() + 1 34 | sz = list(indexes.size()) 35 | output = indexes.new().byte().resize_(*sz, N).zero_() 36 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 37 | if ignore_index is not None and ignore_index >= 0: 38 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 39 | return output 40 | 41 | 42 | def set_global_seeds(i): 43 | try: 44 | import torch 45 | except ImportError: 46 | pass 47 | else: 48 | torch.manual_seed(i) 49 | if torch.cuda.is_available(): 50 | torch.cuda.manual_seed_all(i) 51 | np.random.seed(i) 52 | random.seed(i) 53 | 54 | 55 | class CheckpointModule(nn.Module): 56 | def __init__(self, module, num_segments=1): 57 | super(CheckpointModule, self).__init__() 58 | assert num_segments == 1 or isinstance(module, nn.Sequential) 59 | self.module = module 60 | self.num_segments = num_segments 61 | 62 | def forward(self, x): 63 | if self.num_segments > 1: 64 | return checkpoint_sequential(self.module, self.num_segments, x) 65 | else: 66 | return checkpoint(self.module, x) 67 | -------------------------------------------------------------------------------- /vision/utils/mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from numpy.random import beta 5 | from torch.nn.functional import one_hot 6 | 7 | 8 | class MixUp(nn.Module): 9 | def __init__(self, batch_dim=0): 10 | super(MixUp, self).__init__() 11 | self.batch_dim = batch_dim 12 | self.reset() 13 | 14 | def reset(self): 15 | self.enabled = False 16 | self.mix_values = None 17 | self.mix_index = None 18 | 19 | def mix(self, x1, x2): 20 | if not torch.is_tensor(self.mix_values): # scalar 21 | return x2.lerp(x1, self.mix_values) 22 | else: 23 | view = [1] * int(x1.dim()) 24 | view[self.batch_dim] = -1 25 | mix_val = self.mix_values.to(device=x1.device).view(*view) 26 | return mix_val * x1 + (1.-mix_val) * x2 27 | 28 | def sample(self, alpha, batch_size, sample_batch=False): 29 | self.mix_index = torch.randperm(batch_size) 30 | if sample_batch: 31 | values = beta(alpha, alpha, size=batch_size) 32 | self.mix_values = torch.tensor(values, dtype=torch.float) 33 | else: 34 | self.mix_values = torch.tensor([beta(alpha, alpha)], 35 | dtype=torch.float) 36 | 37 | def mix_target(self, y, n_class): 38 | if not self.training or \ 39 | self.mix_values is None or\ 40 | self.mix_index is None: 41 | return y 42 | y = one_hot(y, n_class).to(dtype=torch.float) 43 | idx = self.mix_index.to(device=y.device) 44 | y_mix = y.index_select(self.batch_dim, idx) 45 | return self.mix(y, y_mix) 46 | 47 | def forward(self, x): 48 | if not self.training or \ 49 | self.mix_values is None or\ 50 | self.mix_index is None: 51 | return x 52 | idx = self.mix_index.to(device=x.device) 53 | x_mix = x.index_select(self.batch_dim, idx) 54 | return self.mix(x, x_mix) 55 | 56 | 57 | def rand_bbox(size, lam): 58 | W = size[2] 59 | H = size[3] 60 | cut_rat = np.sqrt(1.
- lam) 61 | cut_w = int(W * cut_rat) 62 | cut_h = int(H * cut_rat) 63 | 64 | # uniform 65 | cx = np.random.randint(W) 66 | cy = np.random.randint(H) 67 | 68 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 69 | bby1 = np.clip(cy - cut_h // 2, 0, H) 70 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 71 | bby2 = np.clip(cy + cut_h // 2, 0, H) 72 | 73 | return bbx1, bby1, bbx2, bby2 74 | 75 | 76 | class CutMix(MixUp): 77 | def __init__(self, batch_dim=0): 78 | super(CutMix, self).__init__(batch_dim) 79 | 80 | def mix_image(self, x1, x2): 81 | assert not torch.is_tensor(self.mix_values) or \ 82 | self.mix_values.nelement() == 1 83 | lam = float(self.mix_values) 84 | bbx1, bby1, bbx2, bby2 = rand_bbox(x1.size(), lam) 85 | x1[:, :, bbx1:bbx2, bby1:bby2] = x2[:, :, bbx1:bbx2, bby1:bby2] 86 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / 87 | (x1.size()[-1] * x1.size()[-2])) 88 | self.mix_values.fill_(lam) 89 | return x1 90 | 91 | def sample(self, alpha, batch_size, sample_batch=False): 92 | assert not sample_batch 93 | super(CutMix, self).sample(alpha, batch_size, sample_batch) 94 | 95 | def forward(self, x): 96 | if not self.training or \ 97 | self.mix_values is None or\ 98 | self.mix_index is None: 99 | return x 100 | idx = self.mix_index.to(device=x.device) 101 | x_mix = x.index_select(self.batch_dim, idx) 102 | return self.mix_image(x, x_mix) 103 | -------------------------------------------------------------------------------- /vision/utils/param_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def is_not_bias(name): 6 | return not name.endswith('bias') 7 | 8 | 9 | def is_bn(module): 10 | return isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) 11 | 12 | 13 | def is_not_bn(module): 14 | return not is_bn(module) 15 | 16 | 17 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None): 18 | if memo is None: 19 | memo = set() 20 | 21 | for module_name, module in model.named_modules(): 22 | if module_fn is not None and not module_fn(module): 23 | continue 24 | if module_name_fn is not None and not module_name_fn(module_name): 25 | continue 26 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False): 27 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name): 28 | continue 29 | if param not in memo: 30 | memo.add(param) 31 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)} 32 | 33 | 34 | class FilterParameters(object): 35 | def __init__(self, source, module=None, module_name=None, parameter_name=None): 36 | if isinstance(source, FilterParameters): 37 | self._filtered_parameter_info = list(source.filter( 38 | module=module, 39 | module_name=module_name, 40 | parameter_name=parameter_name)) 41 | elif isinstance(source, torch.nn.Module): # source is a model 42 | self._filtered_parameter_info = list(filtered_parameter_info(source, 43 | module_fn=module, 44 | module_name_fn=module_name, 45 | parameter_name_fn=parameter_name)) 46 | 47 | def named_parameters(self): 48 | for p in self._filtered_parameter_info: 49 | yield p['named_parameter'] 50 | 51 | def parameters(self): 52 | for _, p in self.named_parameters(): 53 | yield p 54 | 55 | def filter(self, module=None, module_name=None, parameter_name=None): 56 | for p_info in self._filtered_parameter_info: 57 | if (module is None or module(p_info['named_module'][1]) 58 |
and (module_name is None or module_name(p_info['named_module'][0])) 59 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))): 60 | yield p_info 61 | 62 | def named_modules(self): 63 | for m in self._filtered_parameter_info: 64 | yield m['named_module'] 65 | 66 | def modules(self): 67 | for _, m in self.named_modules(): 68 | yield m 69 | 70 | def to(self, *kargs, **kwargs): 71 | for m in self.modules(): 72 | m.to(*kargs, **kwargs) 73 | 74 | 75 | class FilterModules(FilterParameters): 76 | pass 77 | 78 | if __name__ == '__main__': 79 | from torchvision.models import resnet50 80 | model = resnet50() 81 | filtered_params = FilterParameters(model, 82 | module=lambda m: isinstance( 83 | m, torch.nn.Linear), 84 | parameter_name=lambda n: 'bias' in n) 85 | -------------------------------------------------------------------------------- /vision/utils/regime.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from six import string_types 3 | 4 | 5 | def eval_func(f, x): 6 | if isinstance(f, string_types): 7 | f = eval(f) 8 | return f(x) 9 | 10 | 11 | class Regime(object): 12 | """ 13 | Examples for regime: 14 | 15 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3}, 16 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 17 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 18 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 19 | ]" 20 | 2) 21 | "[{'step_lambda': 22 | "lambda t: { 23 | 'optimizer': 'Adam', 24 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 25 | 'betas': (0.9, 0.98), 'eps':1e-9} 26 | }]" 27 | """ 28 | 29 | def __init__(self, regime, defaults={}): 30 | self.regime = regime 31 | self.current_regime_phase = None 32 | self.setting = defaults 33 | 34 | def update(self, epoch=None, train_steps=None): 35 | """adjusts according to current epoch or steps and regime. 36 | """ 37 | if self.regime is None: 38 | return False 39 | epoch = -1 if epoch is None else epoch 40 | train_steps = -1 if train_steps is None else train_steps 41 | setting = deepcopy(self.setting) 42 | if self.current_regime_phase is None: 43 | # Find the first entry whose start epoch or step has already been reached 44 | for regime_phase, regime_setting in enumerate(self.regime): 45 | start_epoch = regime_setting.get('epoch', 0) 46 | start_step = regime_setting.get('step', 0) 47 | if epoch >= start_epoch or train_steps >= start_step: 48 | self.current_regime_phase = regime_phase 49 | break 50 | # each entry is updated from previous 51 | setting.update(regime_setting) 52 | if len(self.regime) > self.current_regime_phase + 1: 53 | next_phase = self.current_regime_phase + 1 54 | # Any more regime steps?
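# peek at the next phase's start epoch/step and advance once either threshold is reached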
55 | start_epoch = self.regime[next_phase].get('epoch', float('inf')) 56 | start_step = self.regime[next_phase].get('step', float('inf')) 57 | if epoch >= start_epoch or train_steps >= start_step: 58 | self.current_regime_phase = next_phase 59 | setting.update(self.regime[self.current_regime_phase]) 60 | 61 | if 'lr_decay_rate' in setting and 'lr' in setting: 62 | decay_steps = setting.pop('lr_decay_steps', 100) 63 | if train_steps % decay_steps == 0: 64 | decay_rate = setting.pop('lr_decay_rate') 65 | setting['lr'] *= decay_rate ** (train_steps / decay_steps) 66 | elif 'step_lambda' in setting: 67 | setting.update(eval_func(setting.pop('step_lambda'), train_steps)) 68 | elif 'epoch_lambda' in setting: 69 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch)) 70 | 71 | if 'execute' in setting: 72 | setting.pop('execute')() 73 | 74 | if 'execute_once' in setting: 75 | setting.pop('execute_once')() 76 | # remove from regime, so won't happen again 77 | self.regime[self.current_regime_phase].pop('execute_once', None) 78 | 79 | if setting == self.setting: 80 | return False 81 | else: 82 | self.setting = setting 83 | return True 84 | 85 | def __repr__(self): 86 | return 'Current: %s\n Regime:%s' % (self.setting, self.regime) 87 | --------------------------------------------------------------------------------
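For orientation, a minimal usage sketch of the `Regime` class above (not part of the repository; it assumes `vision.utils.regime` is importable and prints phase changes instead of driving a real optimizer):

from vision.utils.regime import Regime

# three phases keyed by epoch; on each phase change, update() merges the new
# entry into the previous setting, so unspecified keys are inherited
regime = Regime([{'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1},
                 {'epoch': 30, 'lr': 1e-2},
                 {'epoch': 60, 'lr': 1e-3}])

for epoch in range(90):
    if regime.update(epoch=epoch):    # True only when the merged setting changed
        print(epoch, regime.setting)  # e.g. apply regime.setting['lr'] to an optimizer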