├── AdaPrune
│   ├── LICENSE
│   ├── create_calib_folder.py
│   ├── data.py
│   ├── evaluate.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── modules
│   │   │   ├── batch_norm.py
│   │   │   ├── birelu.py
│   │   │   ├── bwn.py
│   │   │   ├── checkpoint.py
│   │   │   ├── evolved_modules.py
│   │   │   ├── fixed_proj.py
│   │   │   ├── fixup.py
│   │   │   ├── lp_norm.py
│   │   │   ├── quantize.py
│   │   │   └── se.py
│   │   └── resnet.py
│   ├── preprocess.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── adaprune_dense_bnt.sh
│   │   └── adaprune_sparse.sh
│   ├── trainer.py
│   └── utils
│       ├── LICENSE
│       ├── absorb_bn.py
│       ├── adaprune.py
│       ├── cross_entropy.py
│       ├── dataset.py
│       ├── functions.py
│       ├── log.py
│       ├── meters.py
│       ├── misc.py
│       ├── mixup.py
│       ├── optim.py
│       ├── param_filter.py
│       ├── regime.py
│       └── regularization.py
├── README.md
├── common
│   ├── flatten_object.py
│   ├── json_utils.py
│   └── timer.py
├── dynamic_TNM
│   ├── scripts
│   │   ├── clone_and_copy.sh
│   │   ├── run_R18.sh
│   │   └── run_R50.sh
│   ├── src
│   │   ├── configs
│   │   │   ├── config_resnet18_4by8_transpose.yaml
│   │   │   ├── config_resnet50_4by8_transpose.yaml
│   │   │   └── config_resnext50_4by8_transpose.yaml
│   │   ├── dist_utils.py
│   │   ├── resnet.py
│   │   ├── sparse_ops.py
│   │   ├── sparse_ops_init.py
│   │   ├── train_imagenet.py
│   │   ├── train_val.sh
│   │   └── utils.py
│   └── train-20210211_125543.log
├── prune
│   ├── prune.py
│   ├── pruning_method_based_mask.py
│   ├── pruning_method_transposable_block_l1.py
│   ├── pruning_method_transposable_block_l1_graphs.py
│   ├── pruning_method_utils.py
│   └── sparsity_freezer.py
├── static_TNM
│   ├── scripts
│   │   └── prune_pretrained_R50.sh
│   └── src
│       └── prune_pretrained_model.py
└── vision
    ├── LICENSE
    ├── autoaugment.py
    ├── data.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── alexnet.py
    │   ├── modules
    │   │   ├── activations.py
    │   │   ├── checkpoint.py
    │   │   └── se.py
    │   └── resnet.py
    ├── preprocess.py
    ├── trainer.py
    └── utils
        ├── LICENSE
        ├── absorb_bn.py
        ├── cross_entropy.py
        ├── dataset.py
        ├── log.py
        ├── meters.py
        ├── misc.py
        ├── mixup.py
        ├── optim.py
        ├── param_filter.py
        ├── regime.py
        └── regularization.py
/AdaPrune/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Elad Hoffer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/AdaPrune/create_calib_folder.py:
--------------------------------------------------------------------------------
1 | """Build a calibration set by copying a single image from every training class."""
2 | import os
3 | import shutil
4 |
5 | basepath = '/home/Datasets/imagenet/train/'
6 | basepath_calib = '/home/Datasets/imagenet/calib/'
7 |
8 | directory = os.fsencode(basepath)
9 | os.mkdir(basepath_calib)
10 | for d in os.listdir(directory):
11 | dir_name = os.fsdecode(d)
12 | dir_path = os.path.join(basepath,dir_name)
13 | dir_copy_path = os.path.join(basepath_calib,dir_name)
14 | os.mkdir(dir_copy_path)
15 | for f in os.listdir(dir_path):
16 | file_path = os.path.join(dir_path,f)
17 | copy_file_path = os.path.join(dir_copy_path,f)
18 | shutil.copyfile(file_path, copy_file_path)
19 | break
--------------------------------------------------------------------------------
/AdaPrune/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import logging
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim
10 | import torch.utils.data
11 | import models
12 | import torch.distributed as dist
13 | from data import DataRegime
14 | from utils.log import setup_logging, ResultsLog, save_checkpoint
15 | from utils.optim import OptimRegime
16 | from utils.cross_entropy import CrossEntropyLoss
17 | from utils.misc import torch_dtypes
18 | from utils.param_filter import FilterModules, is_bn
19 | from datetime import datetime
20 | from ast import literal_eval
21 | from trainer import Trainer
22 |
23 | model_names = sorted(name for name in models.__dict__
24 | if name.islower() and not name.startswith("__")
25 | and callable(models.__dict__[name]))
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Evaluation')
28 | parser.add_argument('evaluate', type=str,
29 | help='evaluate model FILE on validation set')
30 | parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results',
31 | help='results dir')
32 | parser.add_argument('--save', metavar='SAVE', default='',
33 | help='saved folder')
34 | parser.add_argument('--datasets-dir', metavar='DATASETS_DIR', default='~/Datasets',
35 | help='datasets dir')
36 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
37 | help='dataset name or folder')
38 | parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet',
39 | choices=model_names,
40 | help='model architecture: ' +
41 | ' | '.join(model_names) +
42 | ' (default: alexnet)')
43 | parser.add_argument('--input-size', type=int, default=None,
44 | help='image input size')
45 | parser.add_argument('--model-config', default='',
46 | help='additional architecture configuration')
47 | parser.add_argument('--dtype', default='float',
48 | help='type of tensor: ' +
49 | ' | '.join(torch_dtypes.keys()) +
50 | ' (default: float)')
51 | parser.add_argument('--device', default='cuda',
52 | help='device assignment ("cpu" or "cuda")')
53 | parser.add_argument('--device-ids', default=[0], type=int, nargs='+',
54 | help='device ids assignment (e.g 0 1 2 3')
55 | parser.add_argument('--world-size', default=-1, type=int,
56 | help='number of distributed processes')
57 | parser.add_argument('--local_rank', default=-1, type=int,
58 | help='rank of distributed processes')
59 | parser.add_argument('--dist-init', default='env://', type=str,
60 | help='init used to set up distributed training')
61 | parser.add_argument('--dist-backend', default='nccl', type=str,
62 | help='distributed backend')
63 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
64 | help='number of data loading workers (default: 8)')
65 | parser.add_argument('-b', '--batch-size', default=256, type=int,
66 | metavar='N', help='mini-batch size (default: 256)')
67 | parser.add_argument('--label-smoothing', default=0, type=float,
68 | help='label smoothing coefficient - default 0')
69 | parser.add_argument('--mixup', default=None, type=float,
70 | help='mixup alpha coefficient - default None')
71 | parser.add_argument('--duplicates', default=1, type=int,
72 | help='number of augmentations over singel example')
73 | parser.add_argument('--chunk-batch', default=1, type=int,
74 | help='chunk batch size for multiple passes (training)')
75 | parser.add_argument('--augment', action='store_true', default=False,
76 | help='perform augmentations')
77 | parser.add_argument('--cutout', action='store_true', default=False,
78 | help='cutout augmentations')
79 | parser.add_argument('--autoaugment', action='store_true', default=False,
80 | help='use autoaugment policies')
81 | parser.add_argument('--avg-out', action='store_true', default=False,
82 | help='average outputs')
83 | parser.add_argument('--print-freq', '-p', default=10, type=int,
84 | metavar='N', help='print frequency (default: 10)')
85 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
86 | help='path to latest checkpoint (default: none)')
87 |
88 | parser.add_argument('--seed', default=123, type=int,
89 | help='random seed (default: 123)')
90 |
91 |
92 | def main():
93 | args = parser.parse_args()
94 | main_worker(args)
95 |
96 |
97 | def main_worker(args):
98 | global best_prec1, dtype
99 | best_prec1 = 0
100 | dtype = torch_dtypes.get(args.dtype)
101 | torch.manual_seed(args.seed)
102 | time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
103 | if args.evaluate:
104 | args.results_dir = '/tmp'
105 | if args.save is '':
106 | args.save = time_stamp
107 | save_path = os.path.join(args.results_dir, args.save)
108 |
109 | args.distributed = args.local_rank >= 0 or args.world_size > 1
110 |
111 | if not os.path.exists(save_path) and not (args.distributed and args.local_rank > 0):
112 | os.makedirs(save_path)
113 |
114 | setup_logging(os.path.join(save_path, 'log.txt'),
115 | resume=args.resume is not '',
116 | dummy=args.distributed and args.local_rank > 0)
117 |
118 | results_path = os.path.join(save_path, 'results')
119 | results = ResultsLog(
120 | results_path, title='Training Results - %s' % args.save)
121 |
122 | if 'cuda' in args.device and torch.cuda.is_available():
123 | torch.cuda.manual_seed_all(args.seed)
124 | torch.cuda.set_device(args.device_ids[0])
125 | cudnn.benchmark = True
126 | else:
127 | args.device_ids = None
128 |
129 | if not os.path.isfile(args.evaluate):
130 | parser.error('invalid checkpoint: {}'.format(args.evaluate))
131 | checkpoint = torch.load(args.evaluate, map_location="cpu")
132 | # Overrride configuration with checkpoint info
133 | args.model = checkpoint.get('model', args.model)
134 | args.model_config = checkpoint.get('config', args.model_config)
135 |
136 | logging.info("saving to %s", save_path)
137 | logging.debug("run arguments: %s", args)
138 | logging.info("creating model %s", args.model)
139 |
140 | # create model
141 | model = models.__dict__[args.model]
142 | model_config = {'dataset': args.dataset}
143 |
144 | if args.model_config is not '':
145 | model_config = dict(model_config, **literal_eval(args.model_config))
146 |
147 | model = model(**model_config)
148 | logging.info("created model with configuration: %s", model_config)
149 | num_parameters = sum([l.nelement() for l in model.parameters()])
150 | logging.info("number of parameters: %d", num_parameters)
151 |
152 | # load checkpoint
153 | model.load_state_dict(checkpoint['state_dict'])
154 | logging.info("loaded checkpoint '%s' (epoch %s)",
155 | args.evaluate, checkpoint['epoch'])
156 |
157 | # define loss function (criterion) and optimizer
158 | loss_params = {}
159 | if args.label_smoothing > 0:
160 | loss_params['smooth_eps'] = args.label_smoothing
161 | criterion = getattr(model, 'criterion', nn.NLLLoss)(**loss_params)
162 | criterion.to(args.device, dtype)
163 | model.to(args.device, dtype)
164 |
165 | # Batch-norm should always be done in float
166 | if 'half' in args.dtype:
167 | FilterModules(model, module=is_bn).to(dtype=torch.float)
168 |
169 | trainer = Trainer(model, criterion,
170 | device_ids=args.device_ids, device=args.device, dtype=dtype,
171 | mixup=args.mixup, print_freq=args.print_freq)
172 |
173 | # Evaluation Data loading code
174 | val_data = DataRegime(getattr(model, 'data_eval_regime', None),
175 | defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val', 'augment': args.augment,
176 | 'input_size': args.input_size, 'batch_size': args.batch_size, 'shuffle': False, 'duplicates': args.duplicates, 'autoaugment': args.autoaugment,
177 | 'cutout': {'holes': 1, 'length': 16} if args.cutout else None, 'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})
178 |
179 | results = trainer.validate(val_data.get_loader(),
180 | duplicates=val_data.get('duplicates'),
181 | average_output=args.avg_out)
182 | logging.info(results)
183 | return results
184 |
185 |
186 | if __name__ == '__main__':
187 | main()
188 |
--------------------------------------------------------------------------------
/AdaPrune/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 |
--------------------------------------------------------------------------------
/AdaPrune/models/modules/batch_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import BatchNorm1d as _BatchNorm1d
4 | from torch.nn import BatchNorm2d as _BatchNorm2d
5 | from torch.nn import BatchNorm3d as _BatchNorm3d
6 |
7 | """
8 | BatchNorm variants that can be disabled by removing all parameters and running stats
9 | """
10 |
11 |
12 | def has_running_stats(m):
13 | return getattr(m, 'running_mean', None) is not None\
14 | or getattr(m, 'running_var', None) is not None
15 |
16 |
17 | def has_parameters(m):
18 | return getattr(m, 'weight', None) is not None\
19 | or getattr(m, 'bias', None) is not None
20 |
21 |
22 | class BatchNorm1d(_BatchNorm1d):
23 | def forward(self, inputs):
24 | if not (has_parameters(self) or has_running_stats(self)):
25 | return inputs
26 | return super(BatchNorm1d, self).forward(inputs)
27 |
28 |
29 | class BatchNorm2d(_BatchNorm2d):
30 | def forward(self, inputs):
31 | if not (has_parameters(self) or has_running_stats(self)):
32 | return inputs
33 | return super(BatchNorm2d, self).forward(inputs)
34 |
35 |
36 | class BatchNorm3d(_BatchNorm3d):
37 | def forward(self, inputs):
38 | if not (has_parameters(self) or has_running_stats(self)):
39 | return inputs
40 | return super(BatchNorm3d, self).forward(inputs)
41 |
42 |
43 | class MeanBatchNorm2d(nn.BatchNorm2d):
44 | """BatchNorm with mean-only normalization"""
45 |
46 | def __init__(self, num_features, momentum=0.1, bias=True):
47 | nn.Module.__init__(self)
48 | self.register_buffer('running_mean', torch.zeros(num_features))
49 | self.momentum = momentum
50 | self.num_features = num_features
51 | if bias:
52 | self.bias = nn.Parameter(torch.zeros(num_features))
53 | else:
54 | self.register_parameter('bias', None)
55 |
56 | def forward(self, x):
57 | if not (has_parameters(self) or has_running_stats(self)):
58 | return x
59 | if self.training:
60 | numel = x.size(0) * x.size(2) * x.size(3)
61 | mean = x.sum((0, 2, 3)) / numel
62 | with torch.no_grad():
63 | self.running_mean.mul_(self.momentum)\
64 | .add_(1 - self.momentum, mean)
65 | else:
66 | mean = self.running_mean
67 | if self.bias is not None:
68 | mean = mean - self.bias
69 | return x - mean.view(1, -1, 1, 1)
70 |
71 | def extra_repr(self):
72 | return '{num_features}, momentum={momentum}, bias={has_bias}'.format(
73 | has_bias=self.bias is not None, **self.__dict__)
74 |
--------------------------------------------------------------------------------
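Usage note: these wrappers turn into identities once their parameters and running statistics have been stripped, which is exactly what `utils/absorb_bn.py` does after folding a BN into the preceding layer. A minimal sketch of that behaviour (illustrative only; assumes the AdaPrune directory is the working directory):

    import torch
    from models.modules.batch_norm import BatchNorm2d

    bn = BatchNorm2d(8)
    # Strip parameters and running stats, as utils/absorb_bn.remove_bn_params does:
    bn.register_buffer('running_mean', None)
    bn.register_buffer('running_var', None)
    bn.register_parameter('weight', None)
    bn.register_parameter('bias', None)
    x = torch.randn(2, 8, 4, 4)
    assert torch.equal(bn(x), x)  # the layer is now a no-op
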
/AdaPrune/models/modules/birelu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.function import InplaceFunction
3 | import torch.nn as nn
4 |
5 |
6 | class BiReLUFunction(InplaceFunction):
7 |
8 | @staticmethod
9 | def forward(ctx, input, inplace=False):
10 | if input.size(1) % 2 != 0:
11 | raise RuntimeError("dimension 1 of input must be multiple of 2, "
12 | "but got {}".format(input.size(1)))
13 | ctx.inplace = inplace
14 |
15 | if ctx.inplace:
16 | ctx.mark_dirty(input)
17 | output = input
18 | else:
19 | output = input.clone()
20 |
21 | pos, neg = output.chunk(2, dim=1)
22 | pos.clamp_(min=0)
23 | neg.clamp_(max=0)
24 | ctx.save_for_backward(output)
25 | return output
26 |
27 | @staticmethod
28 | def backward(ctx, grad_output):
29 | output, = ctx.saved_variables
30 | grad_input = grad_output.masked_fill(output.eq(0), 0)
31 | return grad_input, None
32 |
33 |
34 | def birelu(x, inplace=False):
35 | return BiReLUFunction().apply(x, inplace)
36 |
37 |
38 | class BiReLU(nn.Module):
39 | """docstring for BiReLU."""
40 |
41 | def __init__(self, inplace=False):
42 | super(BiReLU, self).__init__()
43 | self.inplace = inplace
44 |
45 | def forward(self, inputs):
46 | return birelu(inputs, inplace=self.inplace)
47 |
48 |
--------------------------------------------------------------------------------
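Usage note: BiReLU applies an ordinary ReLU to the first half of the channel dimension and a mirrored (negative-side) ReLU to the second half, so the channel count must be even. A quick sanity check (illustrative, not part of the repository):

    import torch
    from models.modules.birelu import BiReLU

    x = torch.randn(4, 6, 8, 8)   # channel count must be a multiple of 2
    y = BiReLU()(x)
    assert (y[:, :3] >= 0).all()  # first half clamped to min=0
    assert (y[:, 3:] <= 0).all()  # second half clamped to max=0
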
/AdaPrune/models/modules/bwn.py:
--------------------------------------------------------------------------------
1 | """
2 | Weight Normalization from https://arxiv.org/abs/1602.07868
3 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
4 | """
5 | import torch
6 | from torch.nn.parameter import Parameter
7 | from torch.autograd import Function
8 | import torch.nn as nn
9 |
10 |
11 | def _norm(x, dim, p=2):
12 | """Computes the norm over all dimensions except dim"""
13 | if p == -1:
14 | def func(x, dim): return x.max(dim=dim)[0] - x.min(dim=dim)[0]
15 | elif p == float('inf'):
16 | def func(x, dim): return x.max(dim=dim)[0]
17 | else:
18 | def func(x, dim): return torch.norm(x, dim=dim, p=p)
19 | if dim is None:
20 | return x.norm(p=p)
21 | elif dim == 0:
22 | output_size = (x.size(0),) + (1,) * (x.dim() - 1)
23 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size)
24 | elif dim == x.dim() - 1:
25 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),)
26 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size)
27 | else:
28 | return _norm(x.transpose(0, dim), 0).transpose(0, dim)
29 |
30 |
31 | def _mean(p, dim):
32 | """Computes the mean over all dimensions except dim"""
33 | if dim is None:
34 | return p.mean()
35 | elif dim == 0:
36 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
37 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
38 | elif dim == p.dim() - 1:
39 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
40 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
41 | else:
42 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
43 |
44 |
45 | class BoundedWeightNorm(object):
46 |
47 | def __init__(self, name, dim, p):
48 | self.name = name
49 | self.dim = dim
50 |
51 | def compute_weight(self, module):
52 |
53 | v = getattr(module, self.name + '_v')
54 | v.data.div_(_norm(v, self.dim))
55 | init_norm = getattr(module, self.name + '_init_norm')
56 | return v * (init_norm / _norm(v, self.dim))
57 |
58 | @staticmethod
59 | def apply(module, name, dim, p):
60 | fn = BoundedWeightNorm(name, dim, p)
61 |
62 | weight = getattr(module, name)
63 |
64 | # remove w from parameter list
65 | del module._parameters[name]
66 | module.register_buffer(
67 | name + '_init_norm', torch.Tensor([_norm(weight, dim, p=p).data.mean()]))
68 | module.register_parameter(name + '_v', Parameter(weight.data))
69 | setattr(module, name, fn.compute_weight(module))
70 |
71 | # recompute weight before every forward()
72 | module.register_forward_pre_hook(fn)
73 | return fn
74 |
75 | def remove(self, module):
76 | weight = self.compute_weight(module)
77 | delattr(module, self.name)
78 | del module._parameters[self.name + '_v']
79 | module.register_parameter(self.name, Parameter(weight.data))
80 |
81 | def __call__(self, module, inputs):
82 | setattr(module, self.name, self.compute_weight(module))
83 |
84 |
85 | def weight_norm(module, name='weight', dim=0, p=2):
86 | r"""Applies weight normalization to a parameter in the given module.
87 |
88 | .. math::
89 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
90 |
91 | Weight normalization is a reparameterization that decouples the magnitude
92 | of a weight tensor from its direction. This replaces the parameter specified
93 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude
94 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
95 | Weight normalization is implemented via a hook that recomputes the weight
96 | tensor from the magnitude and direction before every :meth:`~Module.forward`
97 | call.
98 |
99 | By default, with `dim=0`, the norm is computed independently per output
100 | channel/plane. To compute a norm over the entire weight tensor, use
101 | `dim=None`.
102 |
103 | See https://arxiv.org/abs/1602.07868
104 |
105 | Args:
106 | module (nn.Module): containing module
107 | name (str, optional): name of weight parameter
108 | dim (int, optional): dimension over which to compute the norm
109 |
110 | Returns:
111 | The original module with the weight norm hook
112 |
113 | Example::
114 |
115 | >>> m = weight_norm(nn.Linear(20, 40), name='weight')
116 | Linear (20 -> 40)
117 | >>> m.weight_g.size()
118 | torch.Size([40, 1])
119 | >>> m.weight_v.size()
120 | torch.Size([40, 20])
121 |
122 | """
123 | BoundedWeightNorm.apply(module, name, dim, p)
124 | return module
125 |
126 |
127 | def remove_weight_norm(module, name='weight'):
128 | r"""Removes the weight normalization reparameterization from a module.
129 |
130 | Args:
131 | module (nn.Module): containing module
132 | name (str, optional): name of weight parameter
133 |
134 | Example:
135 | >>> m = weight_norm(nn.Linear(20, 40))
136 | >>> remove_weight_norm(m)
137 | """
138 | for k, hook in module._forward_pre_hooks.items():
139 | if isinstance(hook, BoundedWeightNorm) and hook.name == name:
140 | hook.remove(module)
141 | del module._forward_pre_hooks[k]
142 | return module
143 |
144 | raise ValueError("weight_norm of '{}' not found in {}"
145 | .format(name, module))
146 |
--------------------------------------------------------------------------------
/AdaPrune/models/modules/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
4 |
5 |
6 | class CheckpointModule(nn.Module):
7 | def __init__(self, module, num_segments=1):
8 | super(CheckpointModule, self).__init__()
9 | assert num_segments == 1 or isinstance(module, nn.Sequential)
10 | self.module = module
11 | self.num_segments = num_segments
12 |
13 | def forward(self, *inputs):
14 | if self.num_segments > 1:
15 | return checkpoint_sequential(self.module, self.num_segments, *inputs)
16 | else:
17 | return checkpoint(self.module, *inputs)
18 |
--------------------------------------------------------------------------------
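Usage note: wrapping a model in `CheckpointModule` trades compute for memory by recomputing activations during the backward pass; with `num_segments > 1` the wrapped module must be an `nn.Sequential` so it can be split into segments. A minimal sketch (illustrative):

    import torch
    import torch.nn as nn
    from models.modules.checkpoint import CheckpointModule

    net = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
    wrapped = CheckpointModule(net, num_segments=3)
    x = torch.randn(4, 32, requires_grad=True)  # input must require grad for checkpointing
    wrapped(x).sum().backward()
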
/AdaPrune/models/modules/evolved_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | adapted from https://github.com/quark0/darts
3 | """
4 | from collections import namedtuple
5 | import torch
6 | import torch.nn as nn
7 |
8 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
9 |
10 | OPS = {
11 | 'avg_pool_3x3': lambda channels, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
12 | 'max_pool_3x3': lambda channels, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
13 | 'skip_connect': lambda channels, stride, affine: Identity() if stride == 1 else FactorizedReduce(channels, channels, affine=affine),
14 | 'sep_conv_3x3': lambda channels, stride, affine: SepConv(channels, channels, 3, stride, 1, affine=affine),
15 | 'sep_conv_5x5': lambda channels, stride, affine: SepConv(channels, channels, 5, stride, 2, affine=affine),
16 | 'sep_conv_7x7': lambda channels, stride, affine: SepConv(channels, channels, 7, stride, 3, affine=affine),
17 | 'dil_conv_3x3': lambda channels, stride, affine: DilConv(channels, channels, 3, stride, 2, 2, affine=affine),
18 | 'dil_conv_5x5': lambda channels, stride, affine: DilConv(channels, channels, 5, stride, 4, 2, affine=affine),
19 | 'conv_7x1_1x7': lambda channels, stride, affine: nn.Sequential(
20 | nn.ReLU(inplace=False),
21 | nn.Conv2d(channels, channels, (1, 7), stride=(1, stride),
22 | padding=(0, 3), bias=False),
23 | nn.Conv2d(channels, channels, (7, 1), stride=(stride, 1),
24 | padding=(3, 0), bias=False),
25 | nn.BatchNorm2d(channels, affine=affine)
26 | ),
27 | }
28 |
29 |
30 | # genotypes
31 | GENOTYPES = dict(
32 | NASNet=Genotype(
33 | normal=[
34 | ('sep_conv_5x5', 1),
35 | ('sep_conv_3x3', 0),
36 | ('sep_conv_5x5', 0),
37 | ('sep_conv_3x3', 0),
38 | ('avg_pool_3x3', 1),
39 | ('skip_connect', 0),
40 | ('avg_pool_3x3', 0),
41 | ('avg_pool_3x3', 0),
42 | ('sep_conv_3x3', 1),
43 | ('skip_connect', 1),
44 | ],
45 | normal_concat=[2, 3, 4, 5, 6],
46 | reduce=[
47 | ('sep_conv_5x5', 1),
48 | ('sep_conv_7x7', 0),
49 | ('max_pool_3x3', 1),
50 | ('sep_conv_7x7', 0),
51 | ('avg_pool_3x3', 1),
52 | ('sep_conv_5x5', 0),
53 | ('skip_connect', 3),
54 | ('avg_pool_3x3', 2),
55 | ('sep_conv_3x3', 2),
56 | ('max_pool_3x3', 1),
57 | ],
58 | reduce_concat=[4, 5, 6],
59 | ),
60 |
61 | AmoebaNet=Genotype(
62 | normal=[
63 | ('avg_pool_3x3', 0),
64 | ('max_pool_3x3', 1),
65 | ('sep_conv_3x3', 0),
66 | ('sep_conv_5x5', 2),
67 | ('sep_conv_3x3', 0),
68 | ('avg_pool_3x3', 3),
69 | ('sep_conv_3x3', 1),
70 | ('skip_connect', 1),
71 | ('skip_connect', 0),
72 | ('avg_pool_3x3', 1),
73 | ],
74 | normal_concat=[4, 5, 6],
75 | reduce=[
76 | ('avg_pool_3x3', 0),
77 | ('sep_conv_3x3', 1),
78 | ('max_pool_3x3', 0),
79 | ('sep_conv_7x7', 2),
80 | ('sep_conv_7x7', 0),
81 | ('avg_pool_3x3', 1),
82 | ('max_pool_3x3', 0),
83 | ('max_pool_3x3', 1),
84 | ('conv_7x1_1x7', 0),
85 | ('sep_conv_3x3', 5),
86 | ],
87 | reduce_concat=[3, 4, 6]
88 | ),
89 |
90 | DARTS_V1=Genotype(
91 | normal=[
92 | ('sep_conv_3x3', 1),
93 | ('sep_conv_3x3', 0),
94 | ('skip_connect', 0),
95 | ('sep_conv_3x3', 1),
96 | ('skip_connect', 0),
97 | ('sep_conv_3x3', 1),
98 | ('sep_conv_3x3', 0),
99 | ('skip_connect', 2)],
100 | normal_concat=[2, 3, 4, 5],
101 | reduce=[('max_pool_3x3', 0),
102 | ('max_pool_3x3', 1),
103 | ('skip_connect', 2),
104 | ('max_pool_3x3', 0),
105 | ('max_pool_3x3', 0),
106 | ('skip_connect', 2),
107 | ('skip_connect', 2),
108 | ('avg_pool_3x3', 0)],
109 | reduce_concat=[2, 3, 4, 5]),
110 | DARTS=Genotype(normal=[('sep_conv_3x3', 0),
111 | ('sep_conv_3x3', 1),
112 | ('sep_conv_3x3', 0),
113 | ('sep_conv_3x3', 1),
114 | ('sep_conv_3x3', 1),
115 | ('skip_connect', 0),
116 | ('skip_connect', 0),
117 | ('dil_conv_3x3', 2)],
118 | normal_concat=[2, 3, 4, 5],
119 | reduce=[('max_pool_3x3', 0),
120 | ('max_pool_3x3', 1),
121 | ('skip_connect', 2),
122 | ('max_pool_3x3', 1),
123 | ('max_pool_3x3', 0),
124 | ('skip_connect', 2),
125 | ('skip_connect', 2),
126 | ('max_pool_3x3', 1)],
127 | reduce_concat=[2, 3, 4, 5]),
128 | )
129 |
130 |
131 | class ReLUConvBN(nn.Module):
132 |
133 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
134 | super(ReLUConvBN, self).__init__()
135 | self.op = nn.Sequential(
136 | nn.ReLU(inplace=False),
137 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
138 | padding=padding, bias=False),
139 | nn.BatchNorm2d(C_out, affine=affine)
140 | )
141 |
142 | def forward(self, x):
143 | return self.op(x)
144 |
145 |
146 | class DilConv(nn.Module):
147 |
148 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
149 | super(DilConv, self).__init__()
150 | self.op = nn.Sequential(
151 | nn.ReLU(inplace=False),
152 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
153 | padding=padding, dilation=dilation, groups=C_in, bias=False),
154 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
155 | nn.BatchNorm2d(C_out, affine=affine),
156 | )
157 |
158 | def forward(self, x):
159 | return self.op(x)
160 |
161 |
162 | class SepConv(nn.Module):
163 |
164 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
165 | super(SepConv, self).__init__()
166 | self.op = nn.Sequential(
167 | nn.ReLU(inplace=False),
168 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
169 | padding=padding, groups=C_in, bias=False),
170 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
171 | nn.BatchNorm2d(C_in, affine=affine),
172 | nn.ReLU(inplace=False),
173 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1,
174 | padding=padding, groups=C_in, bias=False),
175 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
176 | nn.BatchNorm2d(C_out, affine=affine),
177 | )
178 |
179 | def forward(self, x):
180 | return self.op(x)
181 |
182 |
183 | class Identity(nn.Module):
184 |
185 | def __init__(self):
186 | super(Identity, self).__init__()
187 |
188 | def forward(self, x):
189 | return x
190 |
191 |
192 | class FactorizedReduce(nn.Module):
193 |
194 | def __init__(self, C_in, C_out, affine=True):
195 | super(FactorizedReduce, self).__init__()
196 | assert C_out % 2 == 0
197 | self.relu = nn.ReLU(inplace=False)
198 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1,
199 | stride=2, padding=0, bias=False)
200 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1,
201 | stride=2, padding=0, bias=False)
202 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
203 |
204 | def forward(self, x):
205 | x = self.relu(x)
206 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
207 | out = self.bn(out)
208 | return out
209 |
210 |
211 | def drop_path(x, drop_prob):
212 | if drop_prob > 0.:
213 | keep_prob = 1.-drop_prob
214 | mask = x.new(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
215 | x.div_(keep_prob)
216 | x.mul_(mask)
217 | return x
218 |
219 |
220 | class Cell(nn.Module):
221 |
222 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
223 | super(Cell, self).__init__()
224 | if reduction_prev:
225 | self.preprocess0 = FactorizedReduce(C_prev_prev, C)
226 | else:
227 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
228 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
229 |
230 | if reduction:
231 | op_names, indices = zip(*genotype.reduce)
232 | concat = genotype.reduce_concat
233 | else:
234 | op_names, indices = zip(*genotype.normal)
235 | concat = genotype.normal_concat
236 | self._compile(C, op_names, indices, concat, reduction)
237 |
238 | def _compile(self, C, op_names, indices, concat, reduction):
239 | assert len(op_names) == len(indices)
240 | self._steps = len(op_names) // 2
241 | self._concat = concat
242 | self.multiplier = len(concat)
243 |
244 | self._ops = nn.ModuleList()
245 | for name, index in zip(op_names, indices):
246 | stride = 2 if reduction and index < 2 else 1
247 | op = OPS[name](C, stride, True)
248 | self._ops += [op]
249 | self._indices = indices
250 |
251 | def forward(self, s0, s1, drop_prob):
252 | s0 = self.preprocess0(s0)
253 | s1 = self.preprocess1(s1)
254 |
255 | states = [s0, s1]
256 | for i in range(self._steps):
257 | h1 = states[self._indices[2*i]]
258 | h2 = states[self._indices[2*i+1]]
259 | op1 = self._ops[2*i]
260 | op2 = self._ops[2*i+1]
261 | h1 = op1(h1)
262 | h2 = op2(h2)
263 | if self.training and drop_prob > 0.:
264 | if not isinstance(op1, Identity):
265 | h1 = drop_path(h1, drop_prob)
266 | if not isinstance(op2, Identity):
267 | h2 = drop_path(h2, drop_prob)
268 | s = h1 + h2
269 | states += [s]
270 | return torch.cat([states[i] for i in self._concat], dim=1)
271 |
272 |
273 | class NasNetCell(Cell):
274 | def __init__(self, *kargs, **kwargs):
275 | super(NasNetCell, self).__init__(GENOTYPES['NASNet'], *kargs, **kwargs)
276 |
277 |
278 | class AmoebaNetCell(Cell):
279 | def __init__(self, *kargs, **kwargs):
280 | super(AmoebaNetCell, self).__init__(
281 | GENOTYPES['AmoebaNet'], *kargs, **kwargs)
282 |
283 |
284 | class DARTSCell(Cell):
285 | def __init__(self, *kargs, **kwargs):
286 | super(DARTSCell, self).__init__(GENOTYPES['DARTS'], *kargs, **kwargs)
287 |
--------------------------------------------------------------------------------
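Shape note: a cell concatenates `len(concat)` intermediate states along the channel dimension, so a DARTS normal cell with `C` channels outputs `4*C` channels at the input resolution. A sketch with arbitrary channel sizes (illustrative, not part of the repository):

    import torch
    from models.modules.evolved_modules import DARTSCell

    cell = DARTSCell(48, 48, 16, reduction=False, reduction_prev=False)
    s0 = s1 = torch.randn(2, 48, 32, 32)
    out = cell(s0, s1, drop_prob=0.)
    assert out.shape == (2, 64, 32, 32)  # 4 concatenated states of 16 channels each
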
/AdaPrune/models/modules/fixed_proj.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 | from torch.autograd import Variable
5 | from scipy.linalg import hadamard
6 |
7 | class HadamardProj(nn.Module):
8 |
9 | def __init__(self, input_size, output_size, bias=True, fixed_weights=True, fixed_scale=None):
10 | super(HadamardProj, self).__init__()
11 | self.output_size = output_size
12 | self.input_size = input_size
13 | sz = 2 ** int(math.ceil(math.log(max(input_size, output_size), 2)))
14 | mat = torch.from_numpy(hadamard(sz))
15 | if fixed_weights:
16 | self.proj = Variable(mat, requires_grad=False)
17 | else:
18 | self.proj = nn.Parameter(mat)
19 |
20 | init_scale = 1. / math.sqrt(self.output_size)
21 |
22 | if fixed_scale is not None:
23 | self.scale = Variable(torch.Tensor(
24 | [fixed_scale]), requires_grad=False)
25 | else:
26 | self.scale = nn.Parameter(torch.Tensor([init_scale]))
27 |
28 | if bias:
29 | self.bias = nn.Parameter(torch.Tensor(
30 | output_size).uniform_(-init_scale, init_scale))
31 | else:
32 | self.register_parameter('bias', None)
33 |
34 | self.eps = 1e-8
35 |
36 | def forward(self, x):
37 | if not isinstance(self.scale, nn.Parameter):
38 | self.scale = self.scale.type_as(x)
39 | x = x / (x.norm(2, -1, keepdim=True) + self.eps)
40 | w = self.proj.type_as(x)
41 |
42 | out = -self.scale * \
43 | nn.functional.linear(x, w[:self.output_size, :self.input_size])
44 | if self.bias is not None:
45 | out = out + self.bias.view(1, -1)
46 | return out
47 |
48 |
49 | class Proj(nn.Module):
50 |
51 | def __init__(self, input_size, output_size, bias=True, init_scale=10):
52 | super(Proj, self).__init__()
53 | if init_scale is not None:
54 | self.weight = nn.Parameter(torch.Tensor(1).fill_(init_scale))
55 | if bias:
56 | self.bias = nn.Parameter(torch.Tensor(output_size).fill_(0))
57 | self.proj = Variable(torch.Tensor(
58 | output_size, input_size), requires_grad=False)
59 | torch.manual_seed(123)
60 | nn.init.orthogonal_(self.proj)
61 |
62 | def forward(self, x):
63 | w = self.proj.type_as(x)
64 | x = x / x.norm(2, -1, keepdim=True)
65 | out = nn.functional.linear(x, w)
66 | if hasattr(self, 'weight'):
67 | out = out * self.weight
68 | if hasattr(self, 'bias'):
69 | out = out + self.bias.view(1, -1)
70 | return out
71 |
72 | class LinearFixed(nn.Linear):
73 |
74 | def __init__(self, input_size, output_size, bias=True, init_scale=10):
75 | super(LinearFixed, self).__init__(input_size, output_size, bias)
76 | self.scale = nn.Parameter(torch.Tensor(1).fill_(init_scale))
77 |
78 | def forward(self, x):
79 | w = self.weight / self.weight.norm(2, -1, keepdim=True)
80 | x = x / x.norm(2, -1, keepdim=True)
81 | out = nn.functional.linear(x, w, self.bias)
82 | return out
83 |
--------------------------------------------------------------------------------
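Usage note: `HadamardProj` scores L2-normalized features against the rows of a fixed Hadamard matrix, whose size is the smallest power of two covering `max(input_size, output_size)`. A shape sketch (illustrative):

    import torch
    from models.modules.fixed_proj import HadamardProj

    proj = HadamardProj(512, 10)     # internally builds a 512x512 Hadamard matrix
    out = proj(torch.randn(4, 512))  # -> (4, 10) class scores
    assert out.shape == (4, 10)
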
/AdaPrune/models/modules/fixup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def _sum_tensor_scalar(tensor, scalar, expand_size):
6 | if scalar is not None:
7 | scalar = scalar.expand(expand_size).contiguous()
8 | else:
9 | return tensor
10 | if tensor is None:
11 | return scalar
12 | return tensor + scalar
13 |
14 |
15 | class ZIConv2d(nn.Conv2d):
16 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
17 | padding=0, dilation=1, groups=1, bias=False,
18 | multiplier=False, pre_bias=True, post_bias=True):
19 | super(ZIConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
20 | padding, dilation, groups, bias)
21 | if pre_bias:
22 | self.pre_bias = nn.Parameter(torch.tensor([0.]))
23 | else:
24 | self.register_parameter('pre_bias', None)
25 | if post_bias:
26 | self.post_bias = nn.Parameter(torch.tensor([0.]))
27 | else:
28 | self.register_parameter('post_bias', None)
29 | if multiplier:
30 | self.multiplier = nn.Parameter(torch.tensor([1.]))
31 | else:
32 | self.register_parameter('multiplier', None)
33 |
34 | def forward(self, x):
35 | if self.pre_bias is not None:
36 | x = x + self.pre_bias
37 | weight = self.weight if self.multiplier is None\
38 | else self.weight * self.multiplier
39 | bias = _sum_tensor_scalar(self.bias, self.post_bias, self.out_channels)
40 | return nn.functional.conv2d(x, weight, bias, self.stride,
41 | self.padding, self.dilation, self.groups)
42 |
43 |
44 | class ZILinear(nn.Linear):
45 | def __init__(self, in_features, out_features, bias=False,
46 | multiplier=False, pre_bias=True, post_bias=True):
47 | super(ZILinear, self).__init__(in_features, out_features, bias)
48 | if pre_bias:
49 | self.pre_bias = nn.Parameter(torch.tensor([0.]))
50 | else:
51 | self.register_parameter('pre_bias', None)
52 | if post_bias:
53 | self.post_bias = nn.Parameter(torch.tensor([0.]))
54 | else:
55 | self.register_parameter('post_bias', None)
56 | if multiplier:
57 | self.multiplier = nn.Parameter(torch.tensor([1.]))
58 | else:
59 | self.register_parameter('multiplier', None)
60 |
61 | def forward(self, x):
62 | if self.pre_bias is not None:
63 | x = x + self.pre_bias
64 | weight = self.weight if self.multiplier is None\
65 | else self.weight * self.multiplier
66 | bias = _sum_tensor_scalar(self.bias, self.post_bias, self.out_features)
67 | return nn.functional.linear(x, weight, bias)
68 |
--------------------------------------------------------------------------------
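Usage note: the zero-init ("ZI") layers follow the Fixup recipe: a learned scalar bias before the operation, an optional scalar multiplier on the weight, and a scalar bias after it, all initialized so the layer starts out as a plain conv/linear. A shape sketch (illustrative):

    import torch
    from models.modules.fixup import ZIConv2d

    conv = ZIConv2d(16, 32, kernel_size=3, padding=1, multiplier=True)
    y = conv(torch.randn(2, 16, 8, 8))  # pre_bias + conv(weight * multiplier) + post_bias
    assert y.shape == (2, 32, 8, 8)
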
/AdaPrune/models/modules/se.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class SEBlock(nn.Module):
5 | def __init__(self, in_channels, out_channels=None, ratio=16):
6 | super(SEBlock, self).__init__()
7 | self.in_channels = in_channels
8 | if out_channels is None:
9 | out_channels = in_channels
10 | self.ratio = ratio
11 | self.relu = nn.ReLU(True)
12 | self.global_pool = nn.AdaptiveAvgPool2d(1)
13 | self.transform = nn.Sequential(
14 | nn.Linear(in_channels, in_channels // ratio),
15 | nn.ReLU(inplace=True),
16 | nn.Linear(in_channels // ratio, out_channels),
17 | nn.Sigmoid()
18 | )
19 |
20 | def forward(self, x):
21 | x_avg = self.global_pool(x).view(x.size(0), -1)
22 | mask = self.transform(x_avg)
23 | return x * mask.view(x.size(0), -1, 1, 1)
24 |
25 |
--------------------------------------------------------------------------------
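Usage note: the SE block pools each channel to a scalar, runs it through a bottleneck MLP with reduction `ratio`, and rescales the input channel-wise with the resulting sigmoid gates. A shape sketch (illustrative):

    import torch
    from models.modules.se import SEBlock

    se = SEBlock(64, ratio=16)          # bottleneck width: 64 // 16 = 4
    y = se(torch.randn(8, 64, 32, 32))  # per-channel gates in [0, 1]
    assert y.shape == (8, 64, 32, 32)
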
/AdaPrune/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.transforms as transforms
4 | import random
5 | import PIL
6 |
7 |
8 | _IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406],
9 | 'std': [0.229, 0.224, 0.225]}
10 |
11 | _IMAGENET_PCA = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, num_crops=1, normalize=_IMAGENET_STATS):
22 | assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}"
23 | convert_tensor = transforms.Compose([transforms.ToTensor(),
24 | transforms.Normalize(**normalize)])
25 | if num_crops == 1:
26 | t_list = [
27 | transforms.CenterCrop(input_size),
28 | convert_tensor
29 | ]
30 | else:
31 | if num_crops == 5:
32 | t_list = [transforms.FiveCrop(input_size)]
33 | elif num_crops == 10:
34 | t_list = [transforms.TenCrop(input_size)]
35 | # returns a 4D tensor
36 | t_list.append(transforms.Lambda(lambda crops:
37 | torch.stack([convert_tensor(crop) for crop in crops])))
38 |
39 | if scale_size != input_size:
40 | t_list = [transforms.Resize(scale_size)] + t_list
41 |
42 | return transforms.Compose(t_list)
43 |
44 |
45 | def scale_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
46 | t_list = [
47 | transforms.RandomCrop(input_size),
48 | transforms.ToTensor(),
49 | transforms.Normalize(**normalize),
50 | ]
51 | if scale_size != input_size:
52 | t_list = [transforms.Resize(scale_size)] + t_list
53 |
54 | return transforms.Compose(t_list)
55 |
56 |
57 | def pad_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
58 | padding = int((scale_size - input_size) / 2)
59 | return transforms.Compose([
60 | transforms.RandomCrop(input_size, padding=padding),
61 | transforms.RandomHorizontalFlip(),
62 | transforms.ToTensor(),
63 | transforms.Normalize(**normalize),
64 | ])
65 |
66 |
67 |
68 | def inception_preproccess(input_size, normalize=_IMAGENET_STATS):
69 | return transforms.Compose([
70 | transforms.RandomResizedCrop(input_size),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | transforms.Normalize(**normalize)
74 | ])
75 |
76 |
77 | def inception_color_preproccess(input_size, normalize=_IMAGENET_STATS):
78 | return transforms.Compose([
79 | transforms.RandomResizedCrop(input_size),
80 | transforms.RandomHorizontalFlip(),
81 | transforms.ColorJitter(
82 | brightness=0.4,
83 | contrast=0.4,
84 | saturation=0.4,
85 | ),
86 | transforms.ToTensor(),
87 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
88 | transforms.Normalize(**normalize)
89 | ])
90 |
91 |
92 | def multi_transform(transform_fn, duplicates=1, dim=0):
93 | """preforms multiple transforms, useful to implement inference time augmentation or
94 | "batch augmentation" from https://openreview.net/forum?id=H1V4QhAqYQ¬eId=BylUSs_3Y7
95 | """
96 | if duplicates > 1:
97 | return transforms.Lambda(lambda x: torch.stack([transform_fn(x) for _ in range(duplicates)], dim=dim))
98 | else:
99 | return transform_fn
100 |
101 |
102 | def get_transform(transform_name='imagenet', input_size=None, scale_size=None,
103 | normalize=None, augment=True, cutout=None, autoaugment=False,
104 | duplicates=1, num_crops=1):
105 | normalize = normalize or _IMAGENET_STATS
106 | transform_fn = None
107 |
108 | if 'imagenet' in transform_name: # inception augmentation is default for imagenet
109 | scale_size = scale_size or 256
110 | input_size = input_size or 224
111 | if augment:
112 | transform_fn = inception_preproccess(input_size,
113 | normalize=normalize)
114 | else:
115 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
116 | num_crops=num_crops, normalize=normalize)
117 | elif 'cifar' in transform_name: # resnet augmentation is default for cifar
118 | input_size = input_size or 32
119 | if augment:
120 | scale_size = scale_size or 40
121 | if autoaugment:
122 | transform_fn = cifar_autoaugment(input_size, scale_size=scale_size,
123 | normalize=normalize)
124 | else:
125 | transform_fn = pad_random_crop(input_size, scale_size=scale_size,
126 | normalize=normalize)
127 | else:
128 | scale_size = scale_size or 32
129 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
130 | num_crops=num_crops, normalize=normalize)
131 | elif transform_name == 'mnist':
132 | normalize = {'mean': [0.5], 'std': [0.5]}
133 | input_size = input_size or 28
134 | if augment:
135 | scale_size = scale_size or 32
136 | transform_fn = pad_random_crop(input_size, scale_size=scale_size,
137 | normalize=normalize)
138 | else:
139 | scale_size = scale_size or 32
140 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
141 | num_crops=num_crops, normalize=normalize)
142 | if cutout is not None:
143 | transform_fn.transforms.append(Cutout(**cutout))
144 | return multi_transform(transform_fn, duplicates)
145 |
146 |
147 | class Lighting(object):
148 | """Lighting noise(AlexNet - style PCA - based noise)"""
149 |
150 | def __init__(self, alphastd, eigval, eigvec):
151 | self.alphastd = alphastd
152 | self.eigval = eigval
153 | self.eigvec = eigvec
154 |
155 | def __call__(self, img):
156 | if self.alphastd == 0:
157 | return img
158 |
159 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
160 | rgb = self.eigvec.type_as(img).clone()\
161 | .mul(alpha.view(1, 3).expand(3, 3))\
162 | .mul(self.eigval.view(1, 3).expand(3, 3))\
163 | .sum(1).squeeze()
164 |
165 | return img.add(rgb.view(3, 1, 1).expand_as(img))
166 |
167 |
168 | class Cutout(object):
169 | """
170 | Randomly mask out one or more patches from an image.
171 | taken from https://github.com/uoguelph-mlrg/Cutout
172 |
173 |
174 | Args:
175 | holes (int): Number of patches to cut out of each image.
176 | length (int): The length (in pixels) of each square patch.
177 | """
178 |
179 | def __init__(self, holes, length):
180 | self.holes = holes
181 | self.length = length
182 |
183 | def __call__(self, img):
184 | """
185 | Args:
186 | img (Tensor): Tensor image of size (C, H, W).
187 | Returns:
188 | Tensor: Image with holes of dimension length x length cut out of it.
189 | """
190 | h = img.size(1)
191 | w = img.size(2)
192 |
193 | mask = np.ones((h, w), np.float32)
194 |
195 | for n in range(self.holes):
196 | y = np.random.randint(h)
197 | x = np.random.randint(w)
198 |
199 | y1 = np.clip(y - self.length // 2, 0, h)
200 | y2 = np.clip(y + self.length // 2, 0, h)
201 | x1 = np.clip(x - self.length // 2, 0, w)
202 | x2 = np.clip(x + self.length // 2, 0, w)
203 |
204 | mask[y1: y2, x1: x2] = 0.
205 |
206 | mask = torch.from_numpy(mask)
207 | mask = mask.expand_as(img)
208 | img = img * mask
209 |
210 | return img
211 |
--------------------------------------------------------------------------------
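Usage note: `get_transform` is the entry point the data pipeline uses; for ImageNet it defaults to inception-style random-resized-crop augmentation for training and Resize(256)/CenterCrop(224) for evaluation. A sketch mirroring the defaults wired up by `evaluate.py` (illustrative):

    from preprocess import get_transform

    # Evaluation: Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize
    t_eval = get_transform('imagenet', augment=False)
    # Training with one 16x16 cutout hole, as enabled by the --cutout flag
    t_train = get_transform('imagenet', augment=True,
                            cutout={'holes': 1, 'length': 16})
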
/AdaPrune/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | bokeh
4 | pandas
5 |
--------------------------------------------------------------------------------
/AdaPrune/scripts/adaprune_dense_bnt.sh:
--------------------------------------------------------------------------------
1 | export datasets_dir=/home/Datasets
2 | export model=${1:-"resnet"}
3 | export model_vis=${2:-"resnet50"}
4 | export depth=${3:-50}
5 | export adaprune_suffix=''
6 | if [ "$5" = True ]; then
7 | export adaprune_suffix='.adaprune'
8 | fi
9 | export workdir='dense_'${model_vis}$adaprune_suffix
10 | export perC=True
11 |
12 |
13 | echo ./results/$workdir/resnet
14 | #Download and absorb_bn resnet50 and
15 | python main.py --model $model --save $workdir -b 128 -lfv $model_vis --model-config "{'batch_norm': False,'depth':$depth}" --device-id 1
16 |
17 | # Run adaprune to minimize the MSE of the output with respect to perturbations in the parameters
18 | python main.py --optimize-weights --model $model -b 200 --evaluate results/$workdir/$model.absorb_bn --model-config "{'batch_norm': False,'depth':$depth}" --dataset imagenet_calib --datasets-dir $datasets_dir --adaprune --prune_bs 8 --prune_topk 4 --device-id 0 --keep_first_last #--unstructured --sparsity_level 0.5
19 | python main.py --batch-norn-tuning --model $model -lfv $model_vis -b 200 --evaluate results/$workdir/$model.absorb_bn.adaprune --model-config "{'batch_norm': False,'depth':$depth}" --dataset imagenet_calib --datasets-dir $datasets_dir --device-id 0
20 |
21 |
--------------------------------------------------------------------------------
/AdaPrune/scripts/adaprune_sparse.sh:
--------------------------------------------------------------------------------
1 | export datasets_dir=/home/Datasets
2 | export model=${1:-"resnet"}
3 | export model_vis=${2:-"resnet50"}
4 | export depth=${3:-50}
5 | export adaprune_suffix='.adaprune'
6 |
7 | export workdir='sparse_'${model_vis}$adaprune_suffix
8 | mkdir ./results/$workdir
9 | echo ./results/$workdir/resnet
10 |
11 | #copy sparse model to workdir
12 | cp ./results/resnet50/model_best.pth.tar ./results/$workdir/resnet
13 |
14 | # Run adaprune to minimize the MSE of the output with respect to small perturbations in the parameters
15 | python main.py --optimize-weights --model $model -b 200 --evaluate results/$workdir/$model --model-config "{'batch_norm': True,'depth':$depth}" --dataset imagenet_calib --datasets-dir $datasets_dir --adaprune --prune_bs 4 --prune_topk 2
16 |
17 |
--------------------------------------------------------------------------------
/AdaPrune/utils/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Elad Hoffer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/AdaPrune/utils/absorb_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 | # from efficientnet_pytorch.utils import Conv2dSamePadding
5 |
6 | def remove_bn_params(bn_module):
7 | bn_module.register_buffer('running_mean', None)
8 | bn_module.register_buffer('running_var', None)
9 | bn_module.register_parameter('weight', None)
10 | bn_module.register_parameter('bias', None)
11 |
12 |
13 | def init_bn_params(bn_module):
14 | bn_module.running_mean.fill_(0)
15 | bn_module.running_var.fill_(1)
16 | if bn_module.affine:
17 | bn_module.weight.fill_(1)
18 | bn_module.bias.fill_(0)
19 |
20 |
21 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False):
22 | with torch.no_grad():
23 | w = module.weight
24 | if module.bias is None:
25 | zeros = torch.zeros(module.out_channels,
26 | dtype=w.dtype, device=w.device)
27 | bias = nn.Parameter(zeros)
28 | module.register_parameter('bias', bias)
29 | b = module.bias
30 |
31 | if hasattr(bn_module, 'running_mean'):
32 | b.add_(-bn_module.running_mean)
33 | if hasattr(bn_module, 'running_var'):
34 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
35 | w.mul_(invstd.view(w.size(0), 1, 1, 1))
36 | b.mul_(invstd)
37 | if hasattr(module, 'quantize_weight'):
38 | module.quantize_weight.running_range.mul_(invstd.view(w.size(0), 1, 1, 1))
39 | module.quantize_weight.running_zero_point.mul_(invstd.view(w.size(0), 1, 1, 1))
40 |
41 | if hasattr(bn_module, 'weight'):
42 | w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
43 | b.mul_(bn_module.weight)
44 | module.register_parameter('gamma', nn.Parameter(bn_module.weight.data.clone()))
45 | if hasattr(module, 'quantize_weight'):
46 | module.quantize_weight.running_range.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
47 | module.quantize_weight.running_zero_point.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
48 | if hasattr(bn_module, 'bias'):
49 | b.add_(bn_module.bias)
50 | module.register_parameter('beta', nn.Parameter(bn_module.bias.data.clone()))
51 |
52 | if remove_bn:
53 | remove_bn_params(bn_module)
54 | else:
55 | init_bn_params(bn_module)
56 |
57 | if verbose:
58 | logging.info('BN module %s was absorbed into layer %s' %
59 | (bn_module, module))
60 |
61 |
62 | def is_bn(m):
63 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
64 |
65 |
66 | def is_absorbing(m):
67 | return isinstance(m, (nn.Conv2d, nn.Linear))  # add Conv2dSamePadding here if the efficientnet import above is enabled
68 |
69 |
70 | def search_absorbe_bn(model, prev=None, remove_bn=True, verbose=False):
71 | with torch.no_grad():
72 | for m in model.children():
73 | if is_bn(m) and is_absorbing(prev):
74 | # print(prev,m)
75 | absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose)
76 | search_absorbe_bn(m, remove_bn=remove_bn, verbose=verbose)
77 | prev = m
78 |
79 |
80 | def absorb_fake_bn(module, bn_module, verbose=False):
81 | with torch.no_grad():
82 | w = module.weight
83 | if module.bias is None:
84 | zeros = torch.zeros(module.out_channels,
85 | dtype=w.dtype, device=w.device)
86 | bias = nn.Parameter(zeros)
87 | module.register_parameter('bias', bias)
88 |
89 | if verbose:
90 | logging.info('BN module %s was absorbed into layer %s' %
91 | (bn_module, module))
92 |
93 |
94 | def is_fake_bn(m):
95 | from models.resnet import Lambda
96 | return isinstance(m, Lambda)
97 |
98 |
99 | def search_absorbe_fake_bn(model, prev=None, remove_bn=True, verbose=False):
100 | with torch.no_grad():
101 | for m in model.children():
102 | if is_fake_bn(m) and is_absorbing(prev):
103 | # print(prev,m)
104 | absorb_fake_bn(prev, m, verbose=verbose)
105 | search_absorbe_fake_bn(m, remove_bn=remove_bn, verbose=verbose)
106 | prev = m
107 |
108 |
109 | def add_bn(module, bn_module, verbose=False):
110 | bn = nn.BatchNorm2d(module.out_channels)
111 |
112 | def bn_forward(bn, x):
113 | res = bn(x)
114 | return res
115 |
116 | bn_module.forward_orig = bn_module.forward
117 | bn_module.forward = lambda x: bn_forward(bn, x)
118 | bn.to(module.weight.device)
119 |
120 | bn.register_buffer('running_var', module.gamma**2)
121 | bn.register_buffer('running_mean', module.beta.clone())
122 | bn.register_parameter('weight', nn.Parameter(torch.sqrt(bn.running_var + bn.eps)))
123 | bn.register_parameter('bias', nn.Parameter(bn.running_mean.clone()))
124 |
125 | bn_module.bn = bn
126 |
127 |
128 | def need_tuning(module):
129 | return hasattr(module, 'num_bits') #and module.groups == 1
130 |
131 |
132 | def search_add_bn(model, prev=None, remove_bn=True, verbose=False):
133 | with torch.no_grad():
134 | for m in model.children():
135 | if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev):
136 | # print(prev,m)
137 | add_bn(prev, m, verbose=verbose)
138 | search_add_bn(m, remove_bn=remove_bn, verbose=verbose)
139 | prev = m
140 |
141 |
142 | def search_absorbe_tuning_bn(model, prev=None, remove_bn=True, verbose=False):
143 | with torch.no_grad():
144 | for m in model.children():
145 | if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev):
146 | # print(prev,m)
147 | absorb_bn(prev, m.bn, remove_bn=remove_bn, verbose=verbose)
148 | m.forward = m.forward_orig
149 | m.bn = None
150 | search_absorbe_tuning_bn(m, remove_bn=remove_bn, verbose=verbose)
151 | prev = m
152 |
153 |
154 | def copy_bn_params(module, bn_module, remove_bn=True, verbose=False):
155 | with torch.no_grad():
156 | if hasattr(bn_module, 'weight'):
157 | module.register_parameter('gamma', nn.Parameter(bn_module.weight.data.clone()))
158 |
159 | if hasattr(bn_module, 'bias'):
160 | module.register_parameter('beta', nn.Parameter(bn_module.bias.data.clone()))
161 |
162 |
163 | def search_copy_bn_params(model, prev=None, remove_bn=True, verbose=False):
164 | with torch.no_grad():
165 | for m in model.children():
166 | if is_bn(m) and is_absorbing(prev):
167 | # print(prev,m)
168 | copy_bn_params(prev, m, remove_bn=remove_bn, verbose=verbose)
169 | search_copy_bn_params(m, remove_bn=remove_bn, verbose=verbose)
170 | prev = m
171 |
172 |
173 | # def recalibrate_bn(module, bn_module, verbose=False):
174 | # bn = bn_module.bn
175 | # bn.register_parameter('weight', nn.Parameter(torch.sqrt(bn.running_var + bn.eps)))
176 | # bn.register_parameter('bias', nn.Parameter(bn.running_mean.clone()))
177 | #
178 | #
179 | # def search_bn_recalibrate(model, prev=None, remove_bn=True, verbose=False):
180 | # with torch.no_grad():
181 | # for m in model.children():
182 | # if is_fake_bn(m) and is_absorbing(prev) and need_tuning(prev):
183 | # recalibrate_bn(prev, m, verbose=verbose)
184 | # search_bn_recalibrate(m, remove_bn=remove_bn, verbose=verbose)
185 | # prev = m
186 |
--------------------------------------------------------------------------------
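Usage note: `search_absorbe_bn` walks the module tree and folds each BatchNorm into the convolution or linear layer immediately preceding it, then strips the BN's parameters so the repo's BatchNorm wrappers become identities. A minimal folding sketch (illustrative; uses the identity-capable BatchNorm2d from models/modules):

    import torch
    import torch.nn as nn
    from models.modules.batch_norm import BatchNorm2d  # identity once params are stripped
    from utils.absorb_bn import search_absorbe_bn

    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1, bias=False), BatchNorm2d(8)).eval()
    net[1].running_mean.normal_()          # give the BN non-trivial statistics
    net[1].running_var.uniform_(0.5, 1.5)
    x = torch.randn(1, 3, 16, 16)
    y_ref = net(x)
    search_absorbe_bn(net)                 # fold BN into the conv, then disable it
    assert torch.allclose(net(x), y_ref, atol=1e-5)
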
/AdaPrune/utils/adaprune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | import scipy.optimize as opt
7 | import math
8 |
9 |
10 |
11 | def adaprune(layer, mask, cached_inps, cached_outs, test_inp, test_out, lr1=1e-4, lr2=1e-2, iters=1000, progress=True, batch_size=50, relu=False, bs=8, no_optimization=False, keep_first_last=True):  # lr1/lr2 are unused; lr_w/lr_b below are the effective rates
12 | print("\nRun adaprune")
13 | test_inp = test_inp.to(layer.weight.device)
14 | test_out = test_out.to(layer.weight.device)
15 | layer.quantize=False
16 | if keep_first_last and (layer.weight.dim()==2 or layer.weight.shape[1]==3):
17 | return 0.1, 0.1
18 | with torch.no_grad():
19 | layer.weight.data = absorb_mean_to_nz(layer.weight,mask,bs=bs)
20 | layer.weight.mul_(mask.to(layer.weight.device))
21 | mse_before = F.mse_loss(layer(test_inp), test_out)
22 | if no_optimization:
23 | return mse_before.item(),mse_before.item()
24 |
25 | lr_w = 1e-3
26 | lr_b = 1e-2
27 |
28 | opt_w = torch.optim.Adam([layer.weight], lr=lr_w)
29 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias = torch.optim.Adam([layer.bias], lr=lr_b)
30 |
31 | losses = []
32 |
33 | for j in (tqdm(range(iters)) if progress else range(iters)):
34 | idx = torch.randperm(cached_inps.size(0))[:batch_size]
35 |
36 | train_inp = cached_inps[idx].to(layer.weight.device)
37 | train_out = cached_outs[idx].to(layer.weight.device)
38 | qout = layer(train_inp)
39 | if relu:
40 | loss = F.mse_loss(F.relu(qout), F.relu(train_out))
41 | else:
42 | loss = F.mse_loss(qout, train_out)
43 |
44 | losses.append(loss.item())
45 | opt_w.zero_grad()
46 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias.zero_grad()
47 | loss.backward()
48 | opt_w.step()
49 | if hasattr(layer, 'bias') and layer.bias is not None: opt_bias.step()
50 | with torch.no_grad():
51 | layer.weight.mul_(mask.to(layer.weight.device))
52 |
53 | mse_after = F.mse_loss(layer(test_inp), test_out)
54 | return mse_before.item(), mse_after.item()
55 |
56 | def absorb_mean_to_nz(weight,mask,bs=8):
57 |     """Redistributes the mass of the pruned (masked-out) weights of each block of size bs onto the surviving weights."""
58 | if weight.dim()>2:
59 | Co,Ci,k1,k2=weight.shape
60 | pad_size=bs-(Ci*k1*k2)%bs if bs>1 else 0
61 | weight_pad = torch.cat((weight.permute(0,2,3,1).contiguous().view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1)
62 | mask_pad = torch.cat((mask.permute(0,2,3,1).contiguous().view(Co,-1).float(),torch.ones(Co,pad_size).to(weight.data).float()),1)
63 | else:
64 | Co,Ci=weight.shape
65 | pad_size=bs-Ci%bs if bs>1 else 0
66 | weight_pad = torch.cat((weight.view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1)
67 | mask_pad = torch.cat((mask.view(Co,-1).float(),torch.ones(Co,pad_size).to(weight.data).float()),1)
68 |
69 | weight_pad = weight_pad.view(Co,-1,bs)+weight_pad.view(Co,-1,bs).mul(1-mask_pad.view(Co,-1,bs)).sum(2,keepdim=True).div(mask_pad.view(Co,-1,bs).sum(2,keepdim=True))
70 | weight_pad.mul_(mask_pad.view(Co,-1,bs))
71 | if weight.dim()>2:
72 | weight_pad = weight_pad.view(Co,-1)[:,:Ci*k1*k2]
73 | weight_pad = weight_pad.view(Co,k1,k2,Ci).permute(0,3,1,2)
74 | else:
75 | weight_pad = weight_pad.view(Co,-1)[:,:Ci]
76 | return weight_pad
77 |
78 | def create_block_magnitude_mask(weight, bs=2, topk=1):
79 |     """Keeps the topk largest-magnitude weights in each block of bs consecutive weights and prunes the rest."""
80 | if weight.dim()>2:
81 | Co,Ci,k1,k2=weight.shape
82 | pad_size=bs-(Ci*k1*k2)%bs if bs>1 else 0
83 | weight_pad = torch.cat((weight.permute(0,2,3,1).contiguous().view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1)
84 | else:
85 | Co,Ci=weight.shape
86 | pad_size=bs-Ci%bs if bs>1 else 0
87 | weight_pad = torch.cat((weight.view(Co,-1),torch.zeros(Co,pad_size).to(weight.data)),1)
88 |
89 | block_weight = weight_pad.data.abs().view(Co,-1,bs).topk(k=topk,dim=2,sorted=False)[1].reshape(Co,-1,topk)
90 | block_masks = torch.zeros_like(weight_pad).reshape(Co, -1, bs).scatter_(2, block_weight, torch.ones(block_weight.shape).to(weight))
91 |
92 | if weight.dim()>2:
93 | block_masks = block_masks.view(Co,-1)[:,:Ci*k1*k2]
94 | block_masks = block_masks.view(Co,k1,k2,Ci).permute(0,3,1,2)
95 | else:
96 | block_masks = block_masks.view(Co,-1)[:,:Ci]
97 | return block_masks
98 |
99 | def create_global_unstructured_magnitude_mask(param,global_val):
100 | eps = 0.1 if param.shape[1]==3 else 0
101 | return param.abs().gt(global_val-eps)
102 |
103 | def create_unstructured_magnitude_mask(param,sparsity_level,absorb_mean=True):
104 | topk = int(param.numel()*sparsity_level)
105 | val = param.view(-1).abs().topk(topk,sorted=True)[0][-1]
106 | mask = param.abs().gt(val)
107 | if absorb_mean:
108 | with torch.no_grad():
109 | mean_val=param[~mask].mean()
110 | aa = param+mask*mean_val
111 | param.copy_(aa)
112 |     print('unstructured mask created with %f density'%(mask.sum().float()/mask.numel()))
113 | return mask
114 |
115 | def extract_topk(param,bs,global_val,conf_level=0.95):
116 | if global_val is not None:
117 | param = create_global_unstructured_magnitude_mask(param,global_val)
118 | p = (1 - param.ne(0).float().sum() / param.numel()).item()
119 | n = bs
120 | P=[]
121 | B=param.numel()/n
122 | for k in range(n):
123 | S = 0
124 | for i in range(k,n+1):
125 | C = math.factorial(n)/(math.factorial(i)*math.factorial(n-i))
126 | S = min(S + C*(p**i)*(1-p)**(n-i),1.0)
127 | P.append(S)
128 | RSD = [math.sqrt((1-pp)/(B*pp)) for pp in P]
129 | P_RSD = np.array(P) #- np.array(RSD)*5
130 | aa = [i for i,p in enumerate(P_RSD) if p>conf_level]
131 | if len(aa)>0:
132 | topk = n-[i for i,p in enumerate(P_RSD) if p>conf_level][-1]
133 | else:
134 | topk=n
135 | return topk
136 |
137 | def create_mask(layer,bs=8,topk=4,prune_extract_topk=False,unstructured =True,sparsity_level=0.5,global_val=None,conf_level=0.95):
138 | if unstructured:
139 | if global_val is not None and not prune_extract_topk:
140 | print('Creating unstructured mask for layer %s'%(layer.name))
141 |             return create_global_unstructured_magnitude_mask(layer.weight,global_val)  # takes no conf_level argument
142 | else:
143 | return create_unstructured_magnitude_mask(layer.weight,sparsity_level=sparsity_level)
144 | if prune_extract_topk: topk = extract_topk(layer.weight,bs,global_val,conf_level=conf_level)
145 | print('Creating mask for layer %s with bs %d ,topk %d'%(layer.name,bs,topk))
146 | return create_block_magnitude_mask(layer.weight,bs=bs,topk=topk)
147 |
148 | def optimize_layer(layer, in_out, optimize_weights=False,bs=4,topk=2,extract_topk=False,unstructured=False,sparsity_level=0.5,global_val=None,conf_level=0.95):
149 | batch_size = 100
150 |     mask = None  # stays None when only qparams are optimized (no pruning mask is created)
151 | cached_inps = torch.cat([x[0] for x in in_out])
152 | cached_outs = torch.cat([x[1] for x in in_out])
153 |
154 | idx = torch.randperm(cached_inps.size(0))[:batch_size]
155 |
156 | test_inp = cached_inps[idx]
157 | test_out = cached_outs[idx]
158 |
159 | if optimize_weights:
160 | mask = create_mask(layer,bs=bs,topk=topk,prune_extract_topk=extract_topk,unstructured=unstructured,sparsity_level=sparsity_level,global_val=global_val,conf_level=conf_level)
161 |         # the original conv1/conv2 special-case called adaprune with identical
162 |         # arguments in both branches, so a single call suffices
163 |         mse_before, mse_after = adaprune(layer, mask, cached_inps, cached_outs, test_inp,
164 |                                          test_out, iters=1000, lr1=1e-5, lr2=1e-4, relu=False, bs=bs)
165 |
166 | mse_before_opt = mse_before
167 | print("MSE before adaprune (opt weight): {}".format(mse_before))
168 | print("MSE after adaprune (opt weight): {}".format(mse_after))
169 | torch.cuda.empty_cache()
170 | else:
171 |         mse_before, mse_after = optimize_qparams(layer, cached_inps, cached_outs, test_inp, test_out)  # optimize_qparams is expected to be defined elsewhere
172 | mse_before_opt = mse_before
173 | print("MSE before qparams: {}".format(mse_before))
174 | print("MSE after qparams: {}".format(mse_after))
175 |
176 | mse_after_opt = mse_after
177 |
178 | with torch.no_grad():
179 | N = test_out.numel()
180 | snr_before = (1/math.sqrt(N)) * math.sqrt(N * mse_before_opt) / torch.norm(test_out).item()
181 | snr_after = (1/math.sqrt(N)) * math.sqrt(N * mse_after_opt) / torch.norm(test_out).item()
182 |
183 |
184 | kurt_in = kurtosis(test_inp).item()
185 | kurt_w = kurtosis(layer.weight).item()
186 |
187 | del cached_inps
188 | del cached_outs
189 | torch.cuda.empty_cache()
190 |
191 | return mse_before_opt, mse_after_opt, snr_before, snr_after, kurt_in, kurt_w, mask
192 |
193 |
194 | def kurtosis(x):
195 | var = torch.mean((x - x.mean())**2)
196 | return torch.mean((x - x.mean())**4 / var**2)
197 |
198 |
199 | def dump(model_name, layer, in_out):
200 | path = os.path.join("dump", model_name, layer.name)
201 | if os.path.exists(path):
202 | shutil.rmtree(path)
203 | os.makedirs(path)
204 |
205 | if hasattr(layer, 'groups'):
206 | f = open(os.path.join(path, "groups_{}".format(layer.groups)), 'x')
207 | f.close()
208 |
209 | cached_inps = torch.cat([x[0] for x in in_out])
210 | cached_outs = torch.cat([x[1] for x in in_out])
211 | torch.save(cached_inps, os.path.join(path, "input.pt"))
212 | torch.save(cached_outs, os.path.join(path, "output.pt"))
213 | torch.save(layer.weight, os.path.join(path, 'weight.pt'))
214 | if layer.bias is not None:
215 | torch.save(layer.bias, os.path.join(path, 'bias.pt'))
216 |
217 |
218 |
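219 | if __name__ == '__main__':
220 |     # Minimal sanity check (illustrative sketch, not part of the original code):
221 |     # build a 2:4 block-magnitude mask for a random conv weight, verify that every
222 |     # block of 4 consecutive input-channel-major weights keeps exactly 2 entries,
223 |     # then absorb the pruned mass into the survivors.
224 |     torch.manual_seed(0)
225 |     w = torch.randn(8, 16, 3, 3)
226 |     mask = create_block_magnitude_mask(w, bs=4, topk=2)
227 |     blocks = mask.permute(0, 2, 3, 1).reshape(8, -1, 4)
228 |     assert torch.all(blocks.sum(dim=2) == 2)
229 |     w_nz = absorb_mean_to_nz(w, mask, bs=4)
230 |     assert w_nz.shape == w.shape
231 |     print('mask density: %.2f' % (mask.sum() / mask.numel()).item())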
--------------------------------------------------------------------------------
/AdaPrune/utils/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .misc import onehot
6 |
7 |
8 | def _is_long(x):
9 | if hasattr(x, 'data'):
10 | x = x.data
11 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)
12 |
13 |
14 | def cross_entropy(logits, target, weight=None, ignore_index=-100, reduction='mean',
15 | smooth_eps=None, smooth_dist=None):
16 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567"""
17 | smooth_eps = smooth_eps or 0
18 |
19 |     # ordinary log-likelihood - use cross_entropy from nn
20 | if _is_long(target) and smooth_eps == 0:
21 | return F.cross_entropy(logits, target, weight, ignore_index=ignore_index, reduction=reduction)
22 |
23 | masked_indices = None
24 | num_classes = logits.size(-1)
25 |
26 | if _is_long(target) and ignore_index >= 0:
27 | masked_indices = target.eq(ignore_index)
28 |
29 | if smooth_eps > 0 and smooth_dist is not None:
30 | if _is_long(target):
31 | target = onehot(target, num_classes).type_as(logits)
32 | if smooth_dist.dim() < target.dim():
33 | smooth_dist = smooth_dist.unsqueeze(0)
34 | target.lerp_(smooth_dist, smooth_eps)
35 |
36 | # log-softmax of logits
37 | lsm = F.log_softmax(logits, dim=-1)
38 |
39 | if weight is not None:
40 | lsm = lsm * weight.unsqueeze(0)
41 |
42 | if _is_long(target):
43 | eps = smooth_eps / (num_classes - 1)
44 | nll = -lsm.gather(dim=-1, index=target.unsqueeze(-1))
45 | loss = (1. - 2 * eps) * nll - eps * lsm.sum(-1)
46 | else:
47 | loss = -(target * lsm).sum(-1)
48 |
49 | if masked_indices is not None:
50 | loss.masked_fill_(masked_indices, 0)
51 |
52 | if reduction == 'sum':
53 | loss = loss.sum()
54 | elif reduction == 'mean':
55 | if masked_indices is None:
56 | loss = loss.mean()
57 | else:
58 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum())
59 |
60 | return loss
61 |
62 |
63 | class CrossEntropyLoss(nn.CrossEntropyLoss):
64 | """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing"""
65 |
66 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None):
67 | super(CrossEntropyLoss, self).__init__(weight=weight,
68 | ignore_index=ignore_index, reduction=reduction)
69 | self.smooth_eps = smooth_eps
70 | self.smooth_dist = smooth_dist
71 |
72 | def forward(self, input, target, smooth_dist=None):
73 | if smooth_dist is None:
74 | smooth_dist = self.smooth_dist
75 | return cross_entropy(input, target, self.weight, self.ignore_index, self.reduction, self.smooth_eps, smooth_dist)
76 |
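77 |
78 | if __name__ == '__main__':
79 |     # Minimal usage sketch (illustrative): run as `python -m utils.cross_entropy`
80 |     # from the AdaPrune directory so the relative import above resolves.
81 |     torch.manual_seed(0)
82 |     logits = torch.randn(4, 10)
83 |     target = torch.tensor([1, 0, 3, 9])
84 |     # with smooth_eps unset this reduces exactly to F.cross_entropy
85 |     assert torch.isclose(cross_entropy(logits, target),
86 |                          F.cross_entropy(logits, target))
87 |     criterion = CrossEntropyLoss(smooth_eps=0.1)
88 |     print('label-smoothed loss:', criterion(logits, target).item())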
--------------------------------------------------------------------------------
/AdaPrune/utils/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import pickle
3 | import PIL
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torch.utils.data.sampler import Sampler, RandomSampler, BatchSampler
7 | from numpy.random import choice
8 | _int_classes = int  # _int_classes was removed from torch.utils.data.sampler; plain int is equivalent
9 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler):
10 | """Samples elements randomly, with replacement.
11 | Arguments:
12 | data_source (Dataset): dataset to sample from
13 | """
14 |
15 | def __init__(self, data_source):
16 | self.num_samples = len(data_source)
17 |
18 | def __iter__(self):
19 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True)))
20 |
21 | def __len__(self):
22 | return self.num_samples
23 |
24 |
25 | class LimitDataset(Dataset):
26 |
27 | def __init__(self, dset, max_len):
28 | self.dset = dset
29 | self.max_len = max_len
30 |
31 | def __len__(self):
32 | return min(len(self.dset), self.max_len)
33 |
34 | def __getitem__(self, index):
35 | return self.dset[index]
36 |
37 | class ByClassDataset(Dataset):
38 |
39 | def __init__(self, ds):
40 | self.dataset = ds
41 | self.idx_by_class = {}
42 | for idx, (_, c) in enumerate(ds):
43 | self.idx_by_class.setdefault(c, [])
44 | self.idx_by_class[c].append(idx)
45 |
46 | def __len__(self):
47 | return min([len(d) for d in self.idx_by_class.values()])
48 |
49 | def __getitem__(self, idx):
50 | idx_per_class = [self.idx_by_class[c][idx]
51 | for c in range(len(self.idx_by_class))]
52 | labels = torch.LongTensor([self.dataset[i][1]
53 | for i in idx_per_class])
54 | items = [self.dataset[i][0] for i in idx_per_class]
55 | if torch.is_tensor(items[0]):
56 | items = torch.stack(items)
57 |
58 | return (items, labels)
59 |
60 |
61 | class IdxDataset(Dataset):
62 | """docstring for IdxDataset."""
63 |
64 | def __init__(self, dset):
65 | super(IdxDataset, self).__init__()
66 | self.dset = dset
67 | self.idxs = range(len(self.dset))
68 |
69 | def __getitem__(self, idx):
70 | data, labels = self.dset[self.idxs[idx]]
71 | return (idx, data, labels)
72 |
73 | def __len__(self):
74 | return len(self.idxs)
75 |
76 |
77 | def image_loader(imagebytes):
78 | img = PIL.Image.open(BytesIO(imagebytes))
79 | return img.convert('RGB')
80 |
81 |
82 | class IndexedFileDataset(Dataset):
83 | """ A dataset that consists of an indexed file (with sample offsets in
84 | another file). For example, a .tar that contains image files.
85 | The dataset does not extract the samples, but works with the indexed
86 | file directly.
87 | NOTE: The index file is assumed to be a pickled list of 3-tuples:
88 | (name, offset, size).
89 | """
90 | def __init__(self, filename, index_filename=None, extract_target_fn=None,
91 | transform=None, target_transform=None, loader=image_loader):
92 | super(IndexedFileDataset, self).__init__()
93 |
94 | # Defaults
95 | if index_filename is None:
96 | index_filename = filename + '.index'
97 | if extract_target_fn is None:
98 | extract_target_fn = lambda *args: args
99 |
100 | # Read index
101 | with open(index_filename, 'rb') as index_fp:
102 | sample_list = pickle.load(index_fp)
103 |
104 | # Collect unique targets (sorted by name)
105 | targetset = set(extract_target_fn(target) for target, _, _ in sample_list)
106 | targetmap = {target: i for i, target in enumerate(sorted(targetset))}
107 |
108 | self.samples = [(targetmap[extract_target_fn(target)], offset, size)
109 | for target, offset, size in sample_list]
110 | self.filename = filename
111 |
112 | self.loader = loader
113 | self.transform = transform
114 | self.target_transform = target_transform
115 |
116 | def _get_sample(self, fp, idx):
117 | target, offset, size = self.samples[idx]
118 | fp.seek(offset)
119 | sample = self.loader(fp.read(size))
120 |
121 | if self.transform is not None:
122 | sample = self.transform(sample)
123 | if self.target_transform is not None:
124 | target = self.target_transform(target)
125 |
126 | return sample, target
127 |
128 | def __getitem__(self, index):
129 | with open(self.filename, 'rb') as fp:
130 | # Handle slices
131 | if isinstance(index, slice):
132 | return [self._get_sample(fp, subidx) for subidx in
133 | range(index.start or 0, index.stop or len(self),
134 | index.step or 1)]
135 |
136 | return self._get_sample(fp, index)
137 |
138 | def __len__(self):
139 | return len(self.samples)
140 |
141 |
142 | class DuplicateBatchSampler(Sampler):
143 | def __init__(self, sampler, batch_size, duplicates, drop_last):
144 | if not isinstance(sampler, Sampler):
145 | raise ValueError("sampler should be an instance of "
146 | "torch.utils.data.Sampler, but got sampler={}"
147 | .format(sampler))
148 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
149 | batch_size <= 0:
150 | raise ValueError("batch_size should be a positive integeral value, "
151 | "but got batch_size={}".format(batch_size))
152 | if not isinstance(drop_last, bool):
153 | raise ValueError("drop_last should be a boolean value, but got "
154 | "drop_last={}".format(drop_last))
155 | self.sampler = sampler
156 | self.batch_size = batch_size
157 | self.drop_last = drop_last
158 | self.duplicates = duplicates
159 |
160 | def __iter__(self):
161 | batch = []
162 | for idx in self.sampler:
163 | batch.append(idx)
164 | if len(batch) == self.batch_size:
165 | yield batch * self.duplicates
166 | batch = []
167 | if len(batch) > 0 and not self.drop_last:
168 | yield batch * self.duplicates
169 |
170 | def __len__(self):
171 | if self.drop_last:
172 | return len(self.sampler) // self.batch_size
173 | else:
174 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
175 |
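176 |
177 | if __name__ == '__main__':
178 |     # Minimal usage sketch (illustrative): each batch of indices is yielded
179 |     # `duplicates` times in a row, concatenated into one list.
180 |     from torch.utils.data import TensorDataset
181 |     dset = TensorDataset(torch.arange(10))
182 |     batch_sampler = DuplicateBatchSampler(RandomSampler(dset), batch_size=4,
183 |                                           duplicates=2, drop_last=True)
184 |     for batch in batch_sampler:
185 |         assert len(batch) == 8 and list(batch[:4]) == list(batch[4:])
186 |     print('%d duplicated batches per epoch' % len(batch_sampler))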
--------------------------------------------------------------------------------
/AdaPrune/utils/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.function import Function
3 |
4 | class ScaleGrad(Function):
5 |
6 | @staticmethod
7 | def forward(ctx, input, scale):
8 | ctx.scale = scale
9 | return input
10 |
11 | @staticmethod
12 | def backward(ctx, grad_output):
13 | grad_input = ctx.scale * grad_output
14 | return grad_input, None
15 |
16 |
17 | def scale_grad(x, scale):
18 |     return ScaleGrad.apply(x, scale)  # apply is invoked on the Function class, not an instance
19 |
20 | def negate_grad(x):
21 | return scale_grad(x, -1)
22 |
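23 |
24 | if __name__ == '__main__':
25 |     # Quick gradient check (illustrative sketch): scale_grad leaves the forward
26 |     # value untouched but multiplies the incoming gradient by `scale`.
27 |     x = torch.ones(3, requires_grad=True)
28 |     y = scale_grad(x, 0.5).sum()
29 |     y.backward()
30 |     assert torch.all(x.grad == 0.5)
31 |     print('forward:', y.item(), '| grad:', x.grad.tolist())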
--------------------------------------------------------------------------------
/AdaPrune/utils/log.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | from itertools import cycle
4 | import torch
5 | import logging.config
6 | from datetime import datetime
7 | import json
8 |
9 | import pandas as pd
10 | from bokeh.io import output_file, save, show
11 | from bokeh.plotting import figure
12 | from bokeh.layouts import column
13 | from bokeh.models import Div
14 |
15 | try:
16 | import hyperdash
17 | HYPERDASH_AVAILABLE = True
18 | except ImportError:
19 | HYPERDASH_AVAILABLE = False
20 |
21 |
22 | def export_args_namespace(args, filename):
23 | """
24 | args: argparse.Namespace
25 | arguments to save
26 | filename: string
27 | filename to save at
28 | """
29 | with open(filename, 'w') as fp:
30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4)
31 |
32 |
33 | def setup_logging(log_file='log.txt', resume=False, dummy=False):
34 | """
35 | Setup logging configuration
36 | """
37 | if dummy:
38 | logging.getLogger('dummy')
39 | else:
40 | if os.path.isfile(log_file) and resume:
41 | file_mode = 'a'
42 | else:
43 | file_mode = 'w'
44 |
45 | root_logger = logging.getLogger()
46 | if root_logger.handlers:
47 | root_logger.removeHandler(root_logger.handlers[0])
48 | logging.basicConfig(level=logging.DEBUG,
49 | format="%(asctime)s - %(levelname)s - %(message)s",
50 | datefmt="%Y-%m-%d %H:%M:%S",
51 | filename=log_file,
52 | filemode=file_mode)
53 | console = logging.StreamHandler()
54 | console.setLevel(logging.INFO)
55 | formatter = logging.Formatter('%(message)s')
56 | console.setFormatter(formatter)
57 | logging.getLogger('').addHandler(console)
58 |
59 |
60 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None,
61 | x_axis_type='linear', y_axis_type='linear',
62 | width=800, height=400, line_width=2,
63 | colors=['red', 'green', 'blue', 'orange',
64 | 'black', 'purple', 'brown'],
65 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save',
66 | append_figure=None):
67 | """
68 |     creates a new plot figure
69 |     example:
70 |         plot_figure(x='epoch', y=['train_loss', 'val_loss'],
71 |                     title='Loss', ylabel='loss')
72 | """
73 | if not isinstance(y, list):
74 | y = [y]
75 | xlabel = xlabel or x
76 | legend = legend or y
77 | assert len(legend) == len(y)
78 | if append_figure is not None:
79 | f = append_figure
80 | else:
81 | f = figure(title=title, tools=tools,
82 | width=width, height=height,
83 | x_axis_label=xlabel or x,
84 | y_axis_label=ylabel or '',
85 | x_axis_type=x_axis_type,
86 | y_axis_type=y_axis_type)
87 | colors = cycle(colors)
88 | for i, yi in enumerate(y):
89 | f.line(data[x], data[yi],
90 | line_width=line_width,
91 | line_color=next(colors), legend=legend[i])
92 | f.legend.click_policy = "hide"
93 | return f
94 |
95 |
96 | class ResultsLog(object):
97 |
98 | supported_data_formats = ['csv', 'json']
99 |
100 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'):
101 | """
102 | Parameters
103 | ----------
104 | path: string
105 | path to directory to save data files
106 | plot_path: string
107 | path to directory to save plot files
108 | title: string
109 | title of HTML file
110 | params: Namespace
111 | optionally save parameters for results
112 | resume: bool
113 | resume previous logging
114 | data_format: str('csv'|'json')
115 | which file format to use to save the data
116 | """
117 | if data_format not in ResultsLog.supported_data_formats:
118 |             raise ValueError('data_format must be one of the following: ' +
119 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats]))
120 |
121 | if data_format == 'json':
122 | self.data_path = '{}.json'.format(path)
123 | else:
124 | self.data_path = '{}.csv'.format(path)
125 | if params is not None:
126 | export_args_namespace(params, '{}.json'.format(path))
127 | self.plot_path = '{}.html'.format(path)
128 | self.results = None
129 | self.clear()
130 | self.first_save = True
131 | if os.path.isfile(self.data_path):
132 | if resume:
133 | self.load(self.data_path)
134 | self.first_save = False
135 | else:
136 | os.remove(self.data_path)
137 | self.results = pd.DataFrame()
138 | else:
139 | self.results = pd.DataFrame()
140 |
141 | self.title = title
142 | self.data_format = data_format
143 |
144 | if HYPERDASH_AVAILABLE:
145 | name = self.title if title != '' else path
146 | self.hd_experiment = hyperdash.Experiment(name)
147 | if params is not None:
148 | for k, v in params._get_kwargs():
149 | self.hd_experiment.param(k, v, log=False)
150 |
151 | def clear(self):
152 | self.figures = []
153 |
154 | def add(self, **kwargs):
155 | """Add a new row to the dataframe
156 | example:
157 | resultsLog.add(epoch=epoch_num, train_loss=loss,
158 | test_loss=test_loss)
159 | """
160 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys())
161 | self.results = self.results.append(df, ignore_index=True)
162 | if hasattr(self, 'hd_experiment'):
163 | for k, v in kwargs.items():
164 | self.hd_experiment.metric(k, v, log=False)
165 |
166 | def smooth(self, column_name, window):
167 | """Select an entry to smooth over time"""
168 | # TODO: smooth only new data
169 | smoothed_column = self.results[column_name].rolling(
170 | window=window, center=False).mean()
171 | self.results[column_name + '_smoothed'] = smoothed_column
172 |
173 | def save(self, title=None):
174 | """save the json file.
175 | Parameters
176 | ----------
177 | title: string
178 | title of the HTML file
179 | """
180 | title = title or self.title
181 | if len(self.figures) > 0:
182 | if os.path.isfile(self.plot_path):
183 | os.remove(self.plot_path)
184 | if self.first_save:
185 | self.first_save = False
186 | logging.info('Plot file saved at: {}'.format(
187 | os.path.abspath(self.plot_path)))
188 |
189 | output_file(self.plot_path, title=title)
190 | plot = column(
191 |                 Div(text='{}'.format(title)), *self.figures)
192 | save(plot)
193 | self.clear()
194 |
195 | if self.data_format == 'json':
196 | self.results.to_json(self.data_path, orient='records', lines=True)
197 | else:
198 | self.results.to_csv(self.data_path, index=False, index_label=False)
199 |
200 | def load(self, path=None):
201 | """load the data file
202 | Parameters
203 | ----------
204 | path:
205 | path to load the json|csv file from
206 | """
207 | path = path or self.data_path
208 | if os.path.isfile(path):
209 | if self.data_format == 'json':
210 |                 self.results = pd.read_json(path, orient='records', lines=True)
211 |             else:
212 |                 self.results = pd.read_csv(path)
213 |         else:
214 |             raise ValueError("{} isn't a file".format(path))
215 |
216 | def show(self, title=None):
217 | title = title or self.title
218 | if len(self.figures) > 0:
219 | plot = column(
220 |                 Div(text='{}'.format(title)), *self.figures)
221 | show(plot)
222 |
223 | def plot(self, *kargs, **kwargs):
224 | """
225 | add a new plot to the HTML file
226 | example:
227 | results.plot(x='epoch', y=['train_loss', 'val_loss'],
228 |                          title='Loss', ylabel='loss')
229 | """
230 | f = plot_figure(self.results, *kargs, **kwargs)
231 | self.figures.append(f)
232 |
233 | def image(self, *kargs, **kwargs):
234 | fig = figure()
235 | fig.image(*kargs, **kwargs)
236 | self.figures.append(fig)
237 |
238 | def end(self):
239 | if hasattr(self, 'hd_experiment'):
240 | self.hd_experiment.end()
241 |
242 |
243 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
244 | filename = os.path.join(path, filename)
245 | torch.save(state, filename)
246 | if is_best:
247 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
248 | if save_all:
249 | shutil.copyfile(filename, os.path.join(
250 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))
251 |
--------------------------------------------------------------------------------
/AdaPrune/utils/meters.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | class OnlineMeter(object):
24 | """Computes and stores the average and variance/std values of tensor"""
25 |
26 | def __init__(self):
27 | self.mean = torch.FloatTensor(1).fill_(-1)
28 | self.M2 = torch.FloatTensor(1).zero_()
29 | self.count = 0.
30 | self.needs_init = True
31 |
32 | def reset(self, x):
33 | self.mean = x.new(x.size()).zero_()
34 | self.M2 = x.new(x.size()).zero_()
35 | self.count = 0.
36 | self.needs_init = False
37 |
38 | def update(self, x):
39 | self.val = x
40 | if self.needs_init:
41 | self.reset(x)
42 | self.count += 1
43 | delta = x - self.mean
44 | self.mean.add_(delta / self.count)
45 | delta2 = x - self.mean
46 | self.M2.add_(delta * delta2)
47 |
48 | @property
49 | def var(self):
50 | if self.count < 2:
51 | return self.M2.clone().zero_()
52 | return self.M2 / (self.count - 1)
53 |
54 | @property
55 | def std(self):
56 |         return self.var.sqrt()  # var is a property, not a method
57 |
58 |
59 | def accuracy(output, target, topk=(1,)):
60 | """Computes the precision@k for the specified values of k"""
61 | maxk = max(topk)
62 | batch_size = target.size(0)
63 |
64 | _, pred = output.topk(maxk, 1, True, True)
65 | pred = pred.t().type_as(target)
66 | correct = pred.eq(target.view(1, -1).expand_as(pred))
67 |
68 | res = []
69 | for k in topk:
70 | correct_k = correct[:k].view(-1).float().sum(0)
71 | res.append(correct_k.mul_(100.0 / batch_size))
72 | return res
73 |
74 |
75 | class AccuracyMeter(object):
76 | """Computes and stores the average and current topk accuracy"""
77 |
78 | def __init__(self, topk=(1,)):
79 | self.topk = topk
80 | self.reset()
81 |
82 | def reset(self):
83 | self._meters = {}
84 | for k in self.topk:
85 | self._meters[k] = AverageMeter()
86 |
87 | def update(self, output, target):
88 | n = target.nelement()
89 | acc_vals = accuracy(output, target, self.topk)
90 | for i, k in enumerate(self.topk):
91 | self._meters[k].update(acc_vals[i])
92 |
93 | @property
94 | def val(self):
95 | return {n: meter.val for (n, meter) in self._meters.items()}
96 |
97 | @property
98 | def avg(self):
99 | return {n: meter.avg for (n, meter) in self._meters.items()}
100 |
101 | @property
102 | def avg_error(self):
103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()}
104 |
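105 |
106 | if __name__ == '__main__':
107 |     # Minimal usage sketch (illustrative): accumulate top-1/top-5 accuracy
108 |     # over a batch of random predictions.
109 |     torch.manual_seed(0)
110 |     meter = AccuracyMeter(topk=(1, 5))
111 |     output, target = torch.randn(32, 10), torch.randint(0, 10, (32,))
112 |     meter.update(output, target)
113 |     print({k: v.item() for k, v in meter.avg.items()})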
--------------------------------------------------------------------------------
/AdaPrune/utils/misc.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
6 |
7 | torch_dtypes = {
8 | 'float': torch.float,
9 | 'float32': torch.float32,
10 | 'float64': torch.float64,
11 | 'double': torch.double,
12 | 'float16': torch.float16,
13 | 'half': torch.half,
14 | 'uint8': torch.uint8,
15 | 'int8': torch.int8,
16 | 'int16': torch.int16,
17 | 'short': torch.short,
18 | 'int32': torch.int32,
19 | 'int': torch.int,
20 | 'int64': torch.int64,
21 | 'long': torch.long
22 | }
23 |
24 |
25 | def onehot(indexes, N=None, ignore_index=None):
26 | """
27 |     Creates a one-hot representation of indexes with N possible entries.
28 |     If N is not specified, it is inferred from the maximum index present.
29 |     indexes is a long-tensor of indexes;
30 |     entries equal to ignore_index will be all-zero in the one-hot representation.
31 | """
32 | if N is None:
33 | N = indexes.max() + 1
34 | sz = list(indexes.size())
35 | output = indexes.new().byte().resize_(*sz, N).zero_()
36 | output.scatter_(-1, indexes.unsqueeze(-1), 1)
37 | if ignore_index is not None and ignore_index >= 0:
38 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
39 | return output
40 |
41 |
42 | def set_global_seeds(i):
43 | try:
44 | import torch
45 | except ImportError:
46 | pass
47 | else:
48 | torch.manual_seed(i)
49 | if torch.cuda.is_available():
50 | torch.cuda.manual_seed_all(i)
51 | np.random.seed(i)
52 | random.seed(i)
53 |
54 |
55 | class CheckpointModule(nn.Module):
56 | def __init__(self, module, num_segments=1):
57 | super(CheckpointModule, self).__init__()
58 | assert num_segments == 1 or isinstance(module, nn.Sequential)
59 | self.module = module
60 | self.num_segments = num_segments
61 |
62 | def forward(self, x):
63 | if self.num_segments > 1:
64 | return checkpoint_sequential(self.module, self.num_segments, x)
65 | else:
66 | return checkpoint(self.module, x)
67 |
68 |
69 | def normalize_module_name(layer_name):
70 | """Normalize a module's name.
71 |
72 |     PyTorch lets you parallelize the computation of a model by wrapping it with a
73 |     DataParallel module. Unfortunately, this changes the fully-qualified name of a module,
74 |     even though the actual functionality of the module doesn't change.
75 |     Often, when we search for modules by name, we are indifferent to the DataParallel
76 |     wrapper and want to use the same module name whether the module is parallel or not.
77 |     We call this module name normalization, and it is implemented here.
78 | """
79 | modules = layer_name.split('.')
80 | try:
81 | idx = modules.index('module')
82 | except ValueError:
83 | return layer_name
84 | del modules[idx]
85 | return '.'.join(modules)
86 |
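87 |
88 | if __name__ == '__main__':
89 |     # Minimal usage sketch (illustrative): strip the DataParallel prefix and
90 |     # build a one-hot encoding.
91 |     assert normalize_module_name('module.layer1.0.conv1') == 'layer1.0.conv1'
92 |     assert normalize_module_name('layer1.0.conv1') == 'layer1.0.conv1'
93 |     print(onehot(torch.tensor([0, 2]), N=3))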
--------------------------------------------------------------------------------
/AdaPrune/utils/mixup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from numpy.random import beta
4 | from .misc import onehot
5 |
6 |
7 | class MixUp(nn.Module):
8 | def __init__(self, batch_dim=0):
9 | super(MixUp, self).__init__()
10 | self.batch_dim = batch_dim
11 | self.reset()
12 |
13 | def reset(self):
14 | self.enabled = False
15 | self.mix_values = None
16 | self.mix_index = None
17 |
18 | def mix(self, x1, x2):
19 | if not torch.is_tensor(self.mix_values): # scalar
20 | return x2.lerp(x1, self.mix_values)
21 | else:
22 | view = [1] * int(x1.dim())
23 | view[self.batch_dim] = -1
24 | mix_val = self.mix_values.to(device=x1.device).view(*view)
25 | return mix_val * x1 + (1.-mix_val) * x2
26 |
27 | def sample(self, alpha, batch_size, sample_batch=False):
28 | self.mix_index = torch.randperm(batch_size)
29 | if sample_batch:
30 | values = beta(alpha, alpha, size=batch_size)
31 | self.mix_values = torch.tensor(values, dtype=torch.float)
32 | else:
33 | self.mix_values = torch.tensor([beta(alpha, alpha)],
34 | dtype=torch.float)
35 |
36 | def mix_target(self, y, n_class):
37 | if not self.training or \
38 |                 self.mix_values is None or \
39 |                 self.mix_index is None:
40 | return y
41 | y = onehot(y, n_class).to(dtype=torch.float)
42 | idx = self.mix_index.to(device=y.device)
43 | y_mix = y.index_select(self.batch_dim, idx)
44 | return self.mix(y, y_mix)
45 |
46 | def forward(self, x):
47 | if not self.training or \
48 |                 self.mix_values is None or \
49 |                 self.mix_index is None:
50 | return x
51 | idx = self.mix_index.to(device=x.device)
52 | x_mix = x.index_select(self.batch_dim, idx)
53 | return self.mix(x, x_mix)
54 |
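55 |
56 | if __name__ == '__main__':
57 |     # Minimal usage sketch (illustrative): run as `python -m utils.mixup` from
58 |     # the AdaPrune directory so the relative import above resolves.
59 |     torch.manual_seed(0)
60 |     mixer = MixUp()
61 |     mixer.train()
62 |     mixer.sample(alpha=1.0, batch_size=4)
63 |     x = torch.randn(4, 3)
64 |     y = torch.tensor([0, 1, 2, 3])
65 |     print('mixed input:', mixer(x).shape, '| mixed target:', mixer.mix_target(y, n_class=4).shape)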
--------------------------------------------------------------------------------
/AdaPrune/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging.config
3 | from copy import deepcopy
4 | from six import string_types
5 | from .regime import Regime
6 | from .param_filter import FilterParameters
7 | from . import regularization
8 | import torch.nn as nn
9 |
10 | _OPTIMIZERS = {name: func for name, func in torch.optim.__dict__.items()}
11 |
12 | try:
13 | from adabound import AdaBound
14 | _OPTIMIZERS['AdaBound'] = AdaBound
15 | except ImportError:
16 | pass
17 |
18 |
19 | def copy_params(param_target, param_src):
20 | with torch.no_grad():
21 | for p_src, p_target in zip(param_src, param_target):
22 | p_target.copy_(p_src)
23 |
24 |
25 | def copy_params_grad(param_target, param_src):
26 | for p_src, p_target in zip(param_src, param_target):
27 | if p_target.grad is None:
28 | p_target.backward(p_src.grad.to(dtype=p_target.dtype))
29 | else:
30 | p_target.grad.detach().copy_(p_src.grad)
31 |
32 |
33 | class ModuleFloatShadow(nn.Module):
34 | def __init__(self, module):
35 | super(ModuleFloatShadow, self).__init__()
36 | self.original_module = module
37 | self.float_module = deepcopy(module)
38 | self.float_module.to(dtype=torch.float)
39 |
40 | def parameters(self, *kargs, **kwargs):
41 | return self.float_module.parameters(*kargs, **kwargs)
42 |
43 | def named_parameters(self, *kargs, **kwargs):
44 | return self.float_module.named_parameters(*kargs, **kwargs)
45 |
46 | def modules(self, *kargs, **kwargs):
47 | return self.float_module.modules(*kargs, **kwargs)
48 |
49 | def named_modules(self, *kargs, **kwargs):
50 | return self.float_module.named_modules(*kargs, **kwargs)
51 |
52 | def original_parameters(self, *kargs, **kwargs):
53 | return self.original_module.parameters(*kargs, **kwargs)
54 |
55 | def original_named_parameters(self, *kargs, **kwargs):
56 | return self.original_module.named_parameters(*kargs, **kwargs)
57 |
58 | def original_modules(self, *kargs, **kwargs):
59 | return self.original_module.modules(*kargs, **kwargs)
60 |
61 | def original_named_modules(self, *kargs, **kwargs):
62 | return self.original_module.named_modules(*kargs, **kwargs)
63 |
64 |
65 | class OptimRegime(Regime):
66 | """
67 | Reconfigures the optimizer according to setting list.
68 | Exposes optimizer methods - state, step, zero_grad, add_param_group
69 |
70 | Examples for regime:
71 |
72 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
73 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
74 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4},
75 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5}
76 | ]"
77 | 2)
78 | "[{'step_lambda':
79 | "lambda t: {
80 | 'optimizer': 'Adam',
81 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5),
82 | 'betas': (0.9, 0.98), 'eps':1e-9}
83 | }]"
84 | """
85 |
86 | def __init__(self, model, regime, defaults={}, filter=None, use_float_copy=False):
87 | super(OptimRegime, self).__init__(regime, defaults)
88 | if filter is not None:
89 | model = FilterParameters(model, **filter)
90 | if use_float_copy:
91 | model = ModuleFloatShadow(model)
92 | self._original_parameters = list(model.original_parameters())
93 |
94 | self.parameters = list(model.parameters())
95 | self.optimizer = torch.optim.SGD(self.parameters, lr=0)
96 | self.regularizer = regularization.Regularizer(model)
97 | self.use_float_copy = use_float_copy
98 |
99 | def update(self, epoch=None, train_steps=None):
100 | """adjusts optimizer according to current epoch or steps and training regime.
101 | """
102 | if super(OptimRegime, self).update(epoch, train_steps):
103 | self.adjust(self.setting)
104 | return True
105 | else:
106 | return False
107 |
108 | def adjust(self, setting):
109 | """adjusts optimizer according to a setting dict.
110 |         e.g: setting={'optimizer': 'Adam', 'lr': 5e-4}
111 | """
112 | if 'optimizer' in setting:
113 | optim_method = _OPTIMIZERS[setting['optimizer']]
114 | if not isinstance(self.optimizer, optim_method):
115 | self.optimizer = optim_method(self.optimizer.param_groups)
116 | logging.debug('OPTIMIZER - setting method = %s' %
117 | setting['optimizer'])
118 | for param_group in self.optimizer.param_groups:
119 | for key in param_group.keys():
120 | if key in setting:
121 | new_val = setting[key]
122 | if new_val != param_group[key]:
123 | logging.debug('OPTIMIZER - setting %s = %s' %
124 | (key, setting[key]))
125 | param_group[key] = setting[key]
126 | # fix for AdaBound
127 | if key == 'lr' and hasattr(self.optimizer, 'base_lrs'):
128 | self.optimizer.base_lrs = list(
129 | map(lambda group: group['lr'], self.optimizer.param_groups))
130 |
131 | if 'regularizer' in setting:
132 | reg_list = deepcopy(setting['regularizer'])
133 | if not (isinstance(reg_list, list) or isinstance(reg_list, tuple)):
134 | reg_list = (reg_list,)
135 | regularizers = []
136 | for reg in reg_list:
137 | if isinstance(reg, dict):
138 | logging.debug('OPTIMIZER - Regularization - %s' % reg)
139 | name = reg.pop('name')
140 | regularizers.append((regularization.__dict__[name], reg))
141 | elif isinstance(reg, regularization.Regularizer):
142 | regularizers.append(reg)
143 | else: # callable on model
144 | regularizers.append(reg(self.regularizer._model))
145 | self.regularizer = regularization.RegularizerList(self.regularizer._model,
146 | regularizers)
147 |
148 | def __getstate__(self):
149 | return {
150 | 'optimizer_state': self.optimizer.__getstate__(),
151 | 'regime': self.regime,
152 | }
153 |
154 | def __setstate__(self, state):
155 | self.regime = state.get('regime')
156 | self.optimizer.__setstate__(state.get('optimizer_state'))
157 |
158 | def state_dict(self):
159 | """Returns the state of the optimizer as a :class:`dict`.
160 | """
161 | return {
162 | 'optimizer_state': self.optimizer.state_dict(),
163 | 'regime': self.regime,
164 | }
165 |
166 | def load_state_dict(self, state_dict):
167 | """Loads the optimizer state.
168 |
169 | Arguments:
170 | state_dict (dict): optimizer state. Should be an object returned
171 | from a call to :meth:`state_dict`.
172 | """
173 | # deepcopy, to be consistent with module API
174 | optimizer_state_dict = state_dict['optimizer_state']
175 |
176 | self.__setstate__({'optimizer_state': optimizer_state_dict,
177 | 'regime': state_dict['regime']})
178 |
179 | def zero_grad(self):
180 | """Clears the gradients of all optimized :class:`Variable` s."""
181 | self.optimizer.zero_grad()
182 | if self.use_float_copy:
183 | for p in self._original_parameters:
184 | if p.grad is not None:
185 | p.grad.detach().zero_()
186 |
187 | def step(self, closure=None):
188 | """Performs a single optimization step (parameter update).
189 |
190 | Arguments:
191 | closure (callable): A closure that reevaluates the model and
192 | returns the loss. Optional for most optimizers.
193 | """
194 | if self.use_float_copy:
195 | copy_params_grad(self.parameters, self._original_parameters)
196 | self.regularizer.pre_step()
197 | self.optimizer.step(closure)
198 | self.regularizer.post_step()
199 | if self.use_float_copy:
200 | copy_params(self._original_parameters, self.parameters)
201 |
202 | def pre_forward(self):
203 | """ allows modification pre-forward pass - e.g for regularization
204 | """
205 | self.regularizer.pre_forward()
206 |
207 | def pre_backward(self):
208 | """ allows modification post-forward pass and pre-backward - e.g for regularization
209 | """
210 | self.regularizer.pre_backward()
211 |
212 |
213 | class MultiOptimRegime(OptimRegime):
214 |
215 | def __init__(self, *optim_regime_list):
216 | self.optim_regime_list = []
217 | for optim_regime in optim_regime_list:
218 | assert isinstance(optim_regime, OptimRegime)
219 | self.optim_regime_list.append(optim_regime)
220 |
221 | def update(self, epoch=None, train_steps=None):
222 | """adjusts optimizer according to current epoch or steps and training regime.
223 | """
224 | updated = False
225 | for i, optim in enumerate(self.optim_regime_list):
226 | current_updated = optim.update(epoch, train_steps)
227 | if current_updated:
228 | logging.debug('OPTIMIZER #%s was updated' % i)
229 | updated = updated or current_updated
230 | return updated
231 |
232 | def zero_grad(self):
233 | """Clears the gradients of all optimized :class:`Variable` s."""
234 | for optim in self.optim_regime_list:
235 | optim.zero_grad()
236 |
237 | def step(self, closure=None):
238 | """Performs a single optimization step (parameter update).
239 |
240 | Arguments:
241 | closure (callable): A closure that reevaluates the model and
242 | returns the loss. Optional for most optimizers.
243 | """
244 | for optim in self.optim_regime_list:
245 | optim.step(closure)
246 |
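247 |
248 | if __name__ == '__main__':
249 |     # Minimal usage sketch (illustrative): run as `python -m utils.optim` from
250 |     # the AdaPrune directory so the relative imports above resolve. The lr
251 |     # follows the regime phase whose start epoch has been reached.
252 |     model = nn.Linear(4, 2)
253 |     regime = [{'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 'momentum': 0.9},
254 |               {'epoch': 2, 'lr': 1e-2}]
255 |     optim_regime = OptimRegime(model, regime)
256 |     for epoch in range(4):
257 |         optim_regime.update(epoch=epoch)
258 |         print(epoch, optim_regime.optimizer.param_groups[0]['lr'])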
--------------------------------------------------------------------------------
/AdaPrune/utils/param_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def is_not_bias(name):
6 | return not name.endswith('bias')
7 |
8 |
9 | def is_bn(module):
10 | return isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d)
11 |
12 |
13 | def is_not_bn(module):
14 | return not is_bn(module)
15 |
16 |
17 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None):
18 | if memo is None:
19 | memo = set()
20 |
21 | for module_name, module in model.named_modules():
22 | if module_fn is not None and not module_fn(module):
23 | continue
24 | if module_name_fn is not None and not module_name_fn(module_name):
25 | continue
26 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False):
27 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name):
28 | continue
29 | if param not in memo:
30 | memo.add(param)
31 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)}
32 |
33 |
34 | class FilterParameters(object):
35 | def __init__(self, source, module=None, module_name=None, parameter_name=None):
36 | if isinstance(source, FilterParameters):
37 | self._filtered_parameter_info = list(source.filter(
38 | module=module,
39 | module_name=module_name,
40 | parameter_name=parameter_name))
41 | elif isinstance(source, torch.nn.Module): # source is a model
42 | self._filtered_parameter_info = list(filtered_parameter_info(source,
43 | module_fn=module,
44 | module_name_fn=module_name,
45 | parameter_name_fn=parameter_name))
46 |
47 | def named_parameters(self):
48 | for p in self._filtered_parameter_info:
49 | yield p['named_parameter']
50 |
51 | def parameters(self):
52 | for _, p in self.named_parameters():
53 | yield p
54 |
55 | def filter(self, module=None, module_name=None, parameter_name=None):
56 | for p_info in self._filtered_parameter_info:
57 | if (module is None or module(p_info['named_module'][1])
58 | and (module_name is None or module_name(p_info['named_module'][0]))
59 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))):
60 | yield p_info
61 |
62 | def named_modules(self):
63 | for m in self._filtered_parameter_info:
64 | yield m['named_module']
65 |
66 | def modules(self):
67 | for _, m in self.named_modules():
68 | yield m
69 |
70 | def to(self, *kargs, **kwargs):
71 | for m in self.modules():
72 | m.to(*kargs, **kwargs)
73 |
74 |
75 | class FilterModules(FilterParameters):
76 | pass
77 |
78 | if __name__ == '__main__':
79 | from torchvision.models import resnet50
80 | model = resnet50()
81 | filterd_params = FilterParameters(model,
82 | module=lambda m: isinstance(
83 | m, torch.nn.Linear),
84 | parameter_name=lambda n: 'bias' in n)
85 |
--------------------------------------------------------------------------------
/AdaPrune/utils/regime.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | from six import string_types
4 |
5 |
6 | def eval_func(f, x):
7 | if isinstance(f, string_types):
8 | f = eval(f)
9 | return f(x)
10 |
11 |
12 | class Regime(object):
13 | """
14 | Examples for regime:
15 |
16 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
17 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
18 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4},
19 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5}
20 | ]"
21 | 2)
22 | "[{'step_lambda':
23 | "lambda t: {
24 | 'optimizer': 'Adam',
25 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5),
26 | 'betas': (0.9, 0.98), 'eps':1e-9}
27 | }]"
28 | """
29 |
30 | def __init__(self, regime, defaults={}):
31 | self.regime = regime
32 | self.current_regime_phase = None
33 | self.setting = defaults
34 |
35 | def update(self, epoch=None, train_steps=None):
36 | """adjusts according to current epoch or steps and regime.
37 | """
38 | if self.regime is None:
39 | return False
40 | epoch = -1 if epoch is None else epoch
41 | train_steps = -1 if train_steps is None else train_steps
42 | setting = deepcopy(self.setting)
43 | if self.current_regime_phase is None:
44 |             # Find the first entry whose start epoch/step has already been reached
45 | for regime_phase, regime_setting in enumerate(self.regime):
46 | start_epoch = regime_setting.get('epoch', 0)
47 | start_step = regime_setting.get('step', 0)
48 | if epoch >= start_epoch or train_steps >= start_step:
49 | self.current_regime_phase = regime_phase
50 | break
51 | # each entry is updated from previous
52 | setting.update(regime_setting)
53 | if len(self.regime) > self.current_regime_phase + 1:
54 | next_phase = self.current_regime_phase + 1
55 | # Any more regime steps?
56 | start_epoch = self.regime[next_phase].get('epoch', float('inf'))
57 | start_step = self.regime[next_phase].get('step', float('inf'))
58 | if epoch >= start_epoch or train_steps >= start_step:
59 | self.current_regime_phase = next_phase
60 | setting.update(self.regime[self.current_regime_phase])
61 |
62 | if 'lr_decay_rate' in setting and 'lr' in setting:
63 | decay_steps = setting.pop('lr_decay_steps', 100)
64 | if train_steps % decay_steps == 0:
65 | decay_rate = setting.pop('lr_decay_rate')
66 | setting['lr'] *= decay_rate ** (train_steps / decay_steps)
67 | elif 'step_lambda' in setting:
68 | setting.update(eval_func(setting.pop('step_lambda'), train_steps))
69 | elif 'epoch_lambda' in setting:
70 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch))
71 |
72 | if 'execute' in setting:
73 | setting.pop('execute')()
74 |
75 | if 'execute_once' in setting:
76 | setting.pop('execute_once')()
77 | # remove from regime, so won't happen again
78 | self.regime[self.current_regime_phase].pop('execute_once', None)
79 |
80 | if setting == self.setting:
81 | return False
82 | else:
83 | self.setting = setting
84 | return True
85 |
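86 |
87 | if __name__ == '__main__':
88 |     # Minimal usage sketch (illustrative): the setting dict switches to the
89 |     # phase whose start epoch has been reached.
90 |     regime = Regime([{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
91 |                      {'epoch': 2, 'lr': 1e-4}])
92 |     for epoch in range(4):
93 |         regime.update(epoch=epoch)
94 |         print(epoch, regime.setting)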
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Accelerated Sparse Neural Training: A Provable and Efficient Method to Find N:M Transposable Masks
2 | Recently, researchers proposed pruning the weights of deep neural networks (DNNs) using an $N:M$ fine-grained block sparsity mask, in which each block of M weights contains at least N zeros. In contrast to unstructured sparsity, N:M fine-grained block sparsity allows acceleration on actual modern hardware. Previously suggested solutions enabled DNN acceleration at the inference phase. To also allow such acceleration in the training phase, we suggest a novel transposable fine-grained sparsity mask, where the same mask can be used for both forward and backward passes. Our transposable mask ensures that both the weight matrix and its transpose follow the same sparsity pattern; thus the matrix multiplication required for passing the error backward can also be accelerated. We discuss the transposable constraint and devise a new measure for mask constraints, called mask diversity (MD), which correlates with their expected accuracy. Lastly, we formulate the problem of finding the optimal transposable mask as a minimum-cost flow problem and suggest a fast linear approximation that can be used when the masks change dynamically during training. Our experiments suggest a 2x speed-up with no accuracy degradation over vision and language models. A reference implementation is provided in this repository.
3 | ## Reproducing the results
4 |
5 | This repository is partially based on the [convNet.pytorch](https://github.com/eladhoffer/convNet.pytorch) repo. Please ensure that you are using PyTorch 1.7+.
6 | Reproducing AdaPrune results:
7 | ```bash
8 | cd AdaPrune
9 | sh scripts/adaprune_dense_bnt.sh
10 | sh scripts/adaprune_sparse.sh
11 | ```
12 | Reproducing static NM-transposable starting from dense pre-trained model:
13 | ```bash
14 | cd static_TNM
15 | sh scripts/prune_pretrained_R50.sh
16 | ```
17 | Reproducing dynamic NM-transposable from scratch:
18 | ```bash
19 | cd dynamic_TNM
20 | sh scripts/clone_and_copy.sh
21 | sh scripts/run_R18.sh
22 | sh scripts/run_R50.sh
23 | ```
24 |
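25 | ## What is a transposable N:M mask?
26 | A transposable N:M mask keeps the N:M fine-grained sparsity pattern along both the rows and the columns of the weight matrix, so that both `W*M` and its transpose can use sparse acceleration. The toy sketch below is illustrative only (the repository computes masks via the minimum-cost-flow formulation and its approximations); a checkerboard pattern trivially satisfies a 4:8 constraint in every row and column:
27 | ```python
28 | import torch
29 |
30 | W = torch.randn(8, 8)
31 | idx = torch.arange(8)
32 | M = ((idx[:, None] + idx[None, :]) % 2 == 0).float()  # keeps 4 of every 8 per row and per column
33 | assert torch.all(M.sum(dim=0) == 4) and torch.all(M.sum(dim=1) == 4)
34 | W_sparse = W * M  # both W_sparse and W_sparse.t() satisfy the 4:8 pattern
35 | ```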
--------------------------------------------------------------------------------
/common/flatten_object.py:
--------------------------------------------------------------------------------
1 | def flatten_object(obj, delimiter='.', prefix=''):
2 | def flatten(x, name=prefix):
3 | if isinstance(x, dict):
4 | for a in x:
5 | flatten(x[a], name + a + delimiter)
6 | elif isinstance(x, list) or isinstance(x, tuple):
7 | for i, a in enumerate(x):
8 | flatten(a, name + str(i) + delimiter)
9 | else:
10 | out[name[:-1]] = x
11 |
12 | out = {}
13 | flatten(obj)
14 | return out
15 |
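16 |
17 | if __name__ == '__main__':
18 |     # Minimal usage sketch (illustrative): nested dicts/lists/tuples flatten
19 |     # into delimiter-joined keys.
20 |     cfg = {'train': {'lr': 0.1, 'milestones': [30, 60]}, 'seed': 42}
21 |     print(flatten_object(cfg))
22 |     # {'train.lr': 0.1, 'train.milestones.0': 30, 'train.milestones.1': 60, 'seed': 42}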
--------------------------------------------------------------------------------
/common/json_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def json_is_serializable(obj):
5 | serializable = True
6 | try:
7 | json.dumps(obj)
8 | except TypeError:
9 | serializable = False
10 | return serializable
11 |
12 |
13 | def json_force_serializable(obj):
14 | if isinstance(obj, dict):
15 | for k, v in obj.items():
16 | obj[k] = json_force_serializable(v)
17 | elif isinstance(obj, list):
18 | for i, v in enumerate(obj):
19 | obj[i] = json_force_serializable(v)
20 | elif isinstance(obj, tuple):
21 | obj = list(obj)
22 | for i, v in enumerate(obj):
23 | obj[i] = json_force_serializable(v)
24 | obj = tuple(obj)
25 | elif not json_is_serializable(obj):
26 | obj = 'filtered by json'
27 | return obj
28 |
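29 |
30 | if __name__ == '__main__':
31 |     # Minimal usage sketch (illustrative): non-serializable leaves are replaced
32 |     # with a placeholder string instead of raising.
33 |     print(json_force_serializable({'lr': 0.1, 'fn': lambda x: x}))
34 |     # {'lr': 0.1, 'fn': 'filtered by json'}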
--------------------------------------------------------------------------------
/common/timer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 |
4 | class Timer:
5 | def __init__(self):
6 | self.start = None
7 | self.final = None
8 |
9 | def __enter__(self):
10 | self.start = datetime.datetime.now()
11 | return self
12 |
13 | def __exit__(self, exception_type, exception_value, traceback):
14 | self.final = datetime.datetime.now() - self.start
15 |
16 | def total(self):
17 | if self.final is None:
18 |             raise RuntimeError('Timer total called before exit; start={}'.format(self.start))
19 | return self.final
20 |
21 | def elapsed(self):
22 | if self.start is None:
23 | raise RuntimeError('Timer elapsed called before start')
24 | return datetime.datetime.now() - self.start
25 |
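26 |
27 | if __name__ == '__main__':
28 |     # Minimal usage sketch (illustrative): time a block via the context manager.
29 |     import time
30 |     with Timer() as t:
31 |         time.sleep(0.1)
32 |     print('total:', t.total())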
--------------------------------------------------------------------------------
/dynamic_TNM/scripts/clone_and_copy.sh:
--------------------------------------------------------------------------------
1 | git clone https://github.com/NM-sparsity/NM-sparsity.git
2 | cd NM-sparsity
3 | git checkout d8419d99ad84ae47e3581db0125ed375ee416bb3
4 | cd ..
5 | cp src/dist_utils.py NM-sparsity/devkit/core/
6 | cp src/sparse_ops.py NM-sparsity/devkit/sparse_ops/
7 | cp src/train_imagenet.py NM-sparsity/classification/train_imagenet.py
8 | cp src/resnet.py NM-sparsity/classification/models/
9 | cp src/train_val.sh NM-sparsity/classification
10 | cp src/sparse_ops_init.py NM-sparsity/devkit/sparse_ops/__init__.py
11 | cp src/utils.py NM-sparsity/devkit/core/
12 |
--------------------------------------------------------------------------------
/dynamic_TNM/scripts/run_R18.sh:
--------------------------------------------------------------------------------
1 | cd NM-sparsity/classification/
2 | sh train_val.sh ../../src/configs/config_resnet18_4by8_transpose.yaml
3 |
--------------------------------------------------------------------------------
/dynamic_TNM/scripts/run_R50.sh:
--------------------------------------------------------------------------------
1 | cd NM-sparsity/classification/
2 | sh train_val.sh ../../src/configs/config_resnet50_4by8_transpose.yaml
3 |
--------------------------------------------------------------------------------
/dynamic_TNM/src/configs/config_resnet18_4by8_transpose.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | model: resnet18
3 | N: 4
4 | M: 8
5 | sparse_optimizer: 1
6 | load_mask: True
7 | init_mask: False
8 | save_mask: False
9 | mask_path: 'path_to_masks/'
10 |
11 | # gpu: 3
12 |
13 | workers: 3
14 | batch_size: 512
15 | epochs: 120
16 |
17 | lr_mode : cosine
18 | base_lr: 0.2
19 | warmup_epochs: 5
20 | warmup_lr: 0.0
21 | targetlr : 0.0
22 |
23 | momentum: 0.9
24 | weight_decay: 0.00005
25 |
26 |
27 | print_freq: 100
28 | model_dir: checkpoint/resnet18_4by8
29 |
30 | train_root: /path_to/imagenet/train
31 | val_root: /path_to/imagenet/val
32 |
33 |
34 |
35 |
36 | TEST:
37 | checkpoint_path : data/pretrained_model/
38 |
--------------------------------------------------------------------------------
/dynamic_TNM/src/configs/config_resnet50_4by8_transpose.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | model: resnet50
3 | N: 4
4 | M: 8
5 | sparse_optimizer: 1
6 | load_mask: True
7 | init_mask: False
8 | save_mask: False
9 |   mask_path: 'path_to_masks/'
10 |
11 | # gpu: 3
12 |
13 |
14 | workers: 3
15 | batch_size: 512
16 | epochs: 120
17 |
18 | lr_mode : cosine
19 | base_lr: 0.2
20 | warmup_epochs: 5
21 | warmup_lr: 0.0
22 | targetlr : 0.0
23 |
24 | momentum: 0.9
25 | weight_decay: 0.00005
26 |
27 |
28 | print_freq: 100
29 | model_dir: checkpoint/resnet50_4by8
30 |
31 | train_root: /path_to/imagenet/train
32 | val_root: /path_to/imagenet/val
33 |
34 |
35 |
36 |
37 | TEST:
38 | checkpoint_path : data/pretrained_model/
39 |
--------------------------------------------------------------------------------
/dynamic_TNM/src/configs/config_resnext50_4by8_transpose.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | model: resnext50_32x4d
3 | N: 4
4 | M: 8
5 | sparse_optimizer: 1
6 | load_mask: False
7 | init_mask: False
8 | save_mask: True
9 | mask_path: 'path_to_masks/'
10 |
11 | # gpu: 3
12 |
13 |
14 | workers: 3
15 | batch_size: 512
16 | epochs: 120
17 |
18 | lr_mode : cosine
19 | base_lr: 0.2
20 | warmup_epochs: 5
21 | warmup_lr: 0.0
22 | targetlr : 0.0
23 |
24 | momentum: 0.9
25 | weight_decay: 0.00005
26 |
27 |
28 | print_freq: 100
29 |   model_dir: checkpoint/resnext50_4by8
30 |
31 | train_root: /path_to/imagenet/train
32 | val_root: /path_to/imagenet/val
33 |
34 |
35 |
36 |
37 | TEST:
38 | checkpoint_path : data/pretrained_model/
39 |
--------------------------------------------------------------------------------
/dynamic_TNM/src/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.multiprocessing as mp
4 | import torch.distributed as dist
5 |
6 | __all__ = [
7 | 'init_dist', 'broadcast_params','average_gradients']
8 |
9 | def init_dist(backend='nccl',
10 | master_ip='127.0.0.1',
11 | port=29500):
12 | #if mp.get_start_method(allow_none=True) is None:
13 | # mp.set_start_method('spawn')
14 | #os.environ['MASTER_ADDR'] = master_ip
15 | #os.environ['MASTER_PORT'] = str(port)
16 | rank = int(os.environ['RANK'])
17 | world_size = int(os.environ['WORLD_SIZE'])
18 | num_gpus = torch.cuda.device_count()
19 | #import pdb; pdb.set_trace()
20 | #torch.cuda.set_device(rank % num_gpus)
21 | #dist.init_process_group(backend=backend,init_method='tcp://127.0.0.1:6320',world_size=world_size,rank=rank)
22 | print('INIT')
23 | return rank, world_size,num_gpus,backend
24 |
25 | def average_gradients(model):
26 | for param in model.parameters():
27 | if param.requires_grad and not (param.grad is None):
28 | dist.all_reduce(param.grad.data)
29 |
30 | def broadcast_params(model):
31 | for p in model.state_dict().values():
32 | dist.broadcast(p, 0)
33 |
34 |
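Note: init_dist above only reads RANK and WORLD_SIZE from the environment (the actual
process-group creation is commented out), so callers must export those variables, as
train_val.sh does. A minimal single-process sketch (the import path is an assumption):

    import os
    from src.dist_utils import init_dist, broadcast_params, average_gradients

    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    rank, world_size, num_gpus, backend = init_dist()
    # after building a model: broadcast_params(model) syncs initial weights from rank 0,
    # and average_gradients(model) all-reduces gradients before each optimizer step.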
--------------------------------------------------------------------------------
/dynamic_TNM/src/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import sys
4 | import os.path as osp
5 | sys.path.append(osp.abspath(osp.join(__file__, '../../../')))
6 | #from devkit.ops import SyncBatchNorm2d
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import autograd
10 | from torch.nn.modules.utils import _pair as pair
11 | from torch.nn import init
12 | #from devkit.sparse_ops import SparseConv
13 | from devkit.sparse_ops import SparseConvTranspose as SparseConv
14 |
15 |
16 |
17 | __all__ = ['ResNetV1', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
18 | 'resnet152']
19 |
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, N=2, M=4):
23 | """3x3 convolution with padding"""
24 | return SparseConv(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=1, bias=False, N=N, M=M)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None, N=2, M=4):
32 | super(BasicBlock, self).__init__()
33 |
34 | self.conv1 = conv3x3(inplanes, planes, stride, N=N, M=M)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv3x3(planes, planes, N=N, M=M)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 |
42 | def forward(self, x):
43 | residual = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 |
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None, N=2, M=4):
64 | super(Bottleneck, self).__init__()
65 |
66 | self.conv1 = SparseConv(inplanes, planes, kernel_size=1, bias=False, N=N, M=M)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = SparseConv(planes, planes, kernel_size=3, stride=stride,
69 | padding=1, bias=False, N=N, M=M)
70 | self.bn2 = nn.BatchNorm2d(planes)
71 | self.conv3 = SparseConv(planes, planes * 4, kernel_size=1, bias=False, N=N, M=M)
72 | self.bn3 = nn.BatchNorm2d(planes * 4)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | residual = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | residual = self.downsample(x)
93 |
94 | out += residual
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 | class ResNetV1(nn.Module):
100 |
101 | def __init__(self, block, layers, num_classes=1000, N=2, M=4):
102 | super(ResNetV1, self).__init__()
103 |
104 |
105 | self.N = N
106 | self.M = M
107 |
108 | self.inplanes = 64
109 | self.conv1 = SparseConv(3, 64, kernel_size=7, stride=2, padding=3,
110 | bias=False, N=self.N, M=self.M)
111 | self.bn1 = nn.BatchNorm2d(64)
112 | self.relu = nn.ReLU(inplace=True)
113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
114 | self.layer1 = self._make_layer(block, 64, layers[0], N = self.N, M = self.M)
115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, N = self.N, M = self.M)
116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, N = self.N, M = self.M)
117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, N = self.N, M = self.M)
118 | self.avgpool = nn.AvgPool2d(7, stride=1)
119 | self.fc = nn.Linear(512 * block.expansion, num_classes)
120 |
121 | for m in self.modules():
122 | if isinstance(m, SparseConv):
123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
124 | m.weight.data.normal_(0, math.sqrt(2. / n))
125 |
126 | def _make_layer(self, block, planes, blocks, stride=1, N = 2, M = 4):
127 | downsample = None
128 | if stride != 1 or self.inplanes != planes * block.expansion:
129 | downsample = nn.Sequential(
130 | SparseConv(self.inplanes, planes * block.expansion,
131 | kernel_size=1, stride=stride, bias=False, N=N, M=M),
132 | nn.BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample, N=N, M=M))
137 | self.inplanes = planes * block.expansion
138 | for i in range(1, blocks):
139 | layers.append(block(self.inplanes, planes, N=N, M=M))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 |
145 | x = self.conv1(x)
146 | x = self.bn1(x)
147 | x = self.relu(x)
148 | x = self.maxpool(x)
149 |
150 | x = self.layer1(x)
151 | x = self.layer2(x)
152 | x = self.layer3(x)
153 | x = self.layer4(x)
154 |
155 | x = self.avgpool(x)
156 | x = x.view(x.size(0), -1)
157 | x = self.fc(x)
158 |
159 | return x
160 |
161 |
162 | def resnet18(**kwargs):
163 | model = ResNetV1(BasicBlock, [2, 2, 2, 2], **kwargs)
164 | return model
165 |
166 |
167 | def resnet34(**kwargs):
168 | model = ResNetV1(BasicBlock, [3, 4, 6, 3], **kwargs)
169 | return model
170 |
171 |
172 | def resnet50(**kwargs):
173 | model = ResNetV1(Bottleneck, [3, 4, 6, 3], **kwargs)
174 | return model
175 |
176 |
177 | def resnet101(**kwargs):
178 | model = ResNetV1(Bottleneck, [3, 4, 23, 3], **kwargs)
179 | return model
180 |
181 |
182 | def resnet152(**kwargs):
183 | model = ResNetV1(Bottleneck, [3, 8, 36, 3], **kwargs)
184 | return model
185 |
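Note: every convolution in ResNetV1 above, including the stem and the downsample
shortcuts, is a SparseConvTranspose (imported under the alias SparseConv), so the N:M
pattern is fixed at construction time. A minimal sketch, assuming devkit has been set
up as in scripts/clone_and_copy.sh; SparseTranspose expects weight.mask to exist, so
the demo seeds all-ones masks (training normally loads them via load_state_and_masks):

    import torch
    from src.resnet import resnet18
    from devkit.sparse_ops import SparseConvTranspose

    model = resnet18(N=4, M=8)                  # keep 4 of every 8 weights per block
    for m in model.modules():
        if isinstance(m, SparseConvTranspose):
            m.weight.mask = torch.ones_like(m.weight)
    y = model.train()(torch.randn(2, 3, 224, 224))
    print(y.shape)                              # torch.Size([2, 1000])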
--------------------------------------------------------------------------------
/dynamic_TNM/src/sparse_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import autograd, nn
3 | import torch.nn.functional as F
4 |
5 | from itertools import repeat
6 | import collections.abc as container_abcs  # torch._six was removed in torch >= 1.9; collections.abc is equivalent
7 | import time
8 | from prune.pruning_method_transposable_block_l1 import PruningMethodTransposableBlockL1
9 |
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, container_abcs.Iterable):
13 | return x
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 | _single = _ntuple(1)
18 | _pair = _ntuple(2)
19 | _triple = _ntuple(3)
20 | _quadruple = _ntuple(4)
21 |
22 | def update_mask_approx2(data, mask, topk=4,BS=8):
23 | mask.fill_(0)
24 | Co = data.shape[0]
25 | #topk=BS//2
26 |     _, idx_sort = data.sort(1, descending=True)  # blocks x (BS*BS)
27 |     for k in range(BS**2):
28 |         if k < ...:  # lines 28-69 were garbled out of this dump (a '<...>' span was stripped); the per-iteration mask update computing new_mask/mask_update is missing
70 |         mask[mask_update] = new_mask[mask_update]
71 |         if sum(mask_update) == 0:
72 |             break
73 |     return mask
70 | mask[mask_update] = new_mask[mask_update]
71 | if sum(mask_update) == 0:
72 | break
73 | return mask
74 |
75 | class SparseTranspose(autograd.Function):
76 |     """Prune the unimportant weights in the forward pass, but pass the gradient to the dense weight using STE in the backward pass."""
77 |
78 | @staticmethod
79 | def forward(ctx, weight, N, M, counter, freq, absorb_mean):
80 | weight.mask = weight.mask.to(weight)
81 | output = weight.clone()
82 | if counter%freq==0:
83 | weight_temp = weight.detach().abs().reshape(-1, M*M)
84 | weight_mask = weight.mask.detach().reshape(-1, M*M)
85 | #weight_mask = update_mask(weight_temp,weight_mask,BS=M)
86 | weight_mask = update_mask_approx2(weight_temp,weight_mask,BS=M)
87 | if absorb_mean:
88 | output = output.reshape(-1, M*M).clone()
89 |                 output += output.mul(1 - weight_mask).mean(1, keepdim=True)
90 | output=output.reshape(weight.shape)
91 | weight.mask=weight_mask.reshape(weight.shape)
92 | return output*weight.mask, weight.mask
93 |
94 | @staticmethod
95 | def backward(ctx, grad_output, _):
96 | return grad_output, None, None, None, None, None
97 |
98 |
99 | class Sparse(autograd.Function):
100 |     """Prune the unimportant weights in the forward pass, but pass the gradient to the dense weight using STE in the backward pass."""
101 |
102 | @staticmethod
103 | def forward(ctx, weight, N, M):
104 |
105 | output = weight.clone()
106 | length = weight.numel()
107 | group = int(length/M)
108 |
109 | weight_temp = weight.detach().abs().reshape(group, M)
110 | index = torch.argsort(weight_temp, dim=1)[:, :int(M-N)]
111 |
112 | w_b = torch.ones(weight_temp.shape, device=weight_temp.device)
113 | w_b = w_b.scatter_(dim=1, index=index, value=0).reshape(weight.shape)
114 |
115 | return output*w_b, w_b
116 |
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output, _):
120 | return grad_output, None, None
121 |
122 | class SparseTransposeV2(autograd.Function):
123 |     """Prune the unimportant weights in the forward pass, but pass the gradient to the dense weight using STE in the backward pass."""
124 |
125 | @staticmethod
126 | def forward(ctx, weight, N, M, counter):
127 | weight.mask = weight.mask.to(weight)
128 | output = weight.reshape(-1, M*M).clone()
129 | weight_mask = weight.mask.reshape(-1, M*M)
130 | output+=torch.mean(output.mul(1-weight_mask),dim=1,keepdim=True)
131 | weight.mask=weight_mask.reshape(weight.shape)
132 | output=output.reshape(weight.shape)
133 | return output*weight.mask, weight.mask
134 |
135 | @staticmethod
136 | def backward(ctx, grad_output, _):
137 | return grad_output, None, None, None
138 |
139 | class SparseConvTranspose(nn.Conv2d):
140 |
141 |
142 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', N=2, M=4, **kwargs):
143 | self.N = N
144 | self.M = M
145 | self.counter = 0
146 | self.freq = 1
147 | self.absorb_mean = False
148 | super(SparseConvTranspose, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs)
149 |
150 |
151 | def get_sparse_weights(self):
152 | return SparseTranspose.apply(self.weight, self.N, self.M, self.counter, self.freq, self.absorb_mean)
153 |
154 |
155 |
156 | def forward(self, x):
157 | if self.training:
158 | self.counter+=1
159 | self.freq = 40 #min(self.freq+self.counter//100,100)
160 | w, mask = self.get_sparse_weights()
161 | setattr(self.weight, "mask", mask)
162 | else:
163 | w = self.weight * self.weight.mask
164 | x = F.conv2d(
165 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
166 | )
167 | return x
168 |
169 | class SparseConvTransposeV2(nn.Conv2d):
170 |
171 |
172 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', N=2, M=4, **kwargs):
173 | self.N = N
174 | self.M = M
175 | self.counter = 0
176 | self.freq = 1
177 | self.rerun_ip = 0.01
178 | self.ipClass = PruningMethodTransposableBlockL1(block_size=self.M, topk=self.N)
179 | super(SparseConvTransposeV2, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs)
180 |
181 |
182 | def get_sparse_weights(self):
183 | with torch.no_grad():
184 | weight_temp = self.weight.detach().abs().reshape(-1, self.M*self.M)
185 | weight_mask = self.weight.mask.detach().reshape(-1, self.M*self.M)
186 | num_samples_ip= int(self.rerun_ip*weight_temp.shape[0])
187 | idx=torch.randperm(weight_temp.shape[0])[:num_samples_ip]
188 | sample_weight = weight_temp[idx]
189 | mask_new = self.ipClass.compute_mask(sample_weight,torch.ones_like(sample_weight))
190 | weight_mask = weight_mask.to(self.weight.device)
191 | weight_mask[idx]=mask_new.to(self.weight.device)
192 | return SparseTransposeV2.apply(self.weight, self.N, self.M, self.counter)
193 |
194 | def forward(self, x):
195 | # self.counter+=1
196 | # self.freq = min(self.freq+self.counter//100,100)
197 | w, mask = self.get_sparse_weights()
198 | setattr(self.weight, "mask", mask)
199 | x = F.conv2d(
200 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
201 | )
202 | return x
203 |
204 | class SparseConv(nn.Conv2d):
205 |
206 |
207 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', N=2, M=4, **kwargs):
208 | self.N = N
209 | self.M = M
210 | super(SparseConv, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kwargs)
211 |
212 |
213 | def get_sparse_weights(self):
214 |
215 | return Sparse.apply(self.weight, self.N, self.M)
216 |
217 |
218 |
219 | def forward(self, x):
220 |
221 | w, mask = self.get_sparse_weights()
222 | setattr(self.weight, "mask", mask)
223 | x = F.conv2d(
224 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
225 | )
226 | return x
227 |
228 | class SparseLinear(nn.Linear):
229 |     def __init__(self, in_features, out_features, bias=True, N=2, M=4):
230 |         # the original stub was incomplete; this signature mirrors
231 |         # SparseLinearTranspose below and is an assumed completion
232 |         self.N = N
233 |         self.M = M
234 |         super(SparseLinear, self).__init__(in_features, out_features, bias)
235 |
236 |
237 | class SparseLinearTranspose(nn.Linear):
238 |
239 | def __init__(self, in_channels, out_channels, bias=True, N=2, M=4, **kwargs):
240 | self.N = N
241 | self.M = M
242 | self.counter = 0
243 | self.freq = 10
244 | super(SparseLinearTranspose, self).__init__(in_channels, out_channels, bias,)
245 |
246 | def get_sparse_weights(self):
247 | return SparseTranspose.apply(self.weight, self.N, self.M, self.counter, self.freq, False)
248 |
249 | def forward(self, x):
250 | if self.training:
251 | self.counter += 1
252 | self.freq = 40 # min(self.freq+self.counter//100,100)
253 | w, mask = self.get_sparse_weights()
254 | setattr(self.weight, "mask", mask)
255 | else:
256 | w = self.weight * self.weight.mask
257 | x = F.linear(
258 | x, w, self.bias
259 | )
260 | return x
261 |
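Note: the plain Sparse function above implements standard (non-transposable) N:M
pruning: within every group of M consecutive weights the M-N smallest magnitudes are
masked, while the backward pass is a straight-through estimator. A self-contained
sketch of the same masking logic on a toy tensor:

    import torch

    w = torch.tensor([0.9, -0.1, 0.4, 0.05, -0.7, 0.2, -0.3, 0.6])  # two groups, M=4
    N, M = 2, 4
    groups = w.abs().reshape(-1, M)
    idx = torch.argsort(groups, dim=1)[:, :M - N]     # indices of the M-N smallest
    mask = torch.ones_like(groups).scatter_(1, idx, 0.0)
    print((w.reshape(-1, M) * mask).reshape(-1))
    # tensor([ 0.9000, -0.0000,  0.4000,  0.0000, -0.7000,  0.0000, -0.0000,  0.6000])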
--------------------------------------------------------------------------------
/dynamic_TNM/src/sparse_ops_init.py:
--------------------------------------------------------------------------------
1 | from .syncbn_layer import SyncBatchNorm2d
2 | from .sparse_ops import SparseConv, SparseConvTranspose, SparseLinearTranspose
3 |
--------------------------------------------------------------------------------
/dynamic_TNM/src/train_val.sh:
--------------------------------------------------------------------------------
1 | now=$(date +"%Y%m%d_%H%M%S")
2 | export RANK=0
3 | export WORLD_SIZE=8
4 | export PYTHONPATH="path_to_TNM_repo"
5 | python train_imagenet.py \
6 | --config $1 2>&1|tee train-$now.log
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/dynamic_TNM/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import shutil
4 | from devkit.sparse_ops import SparseConvTranspose,SparseLinearTranspose
5 |
6 |
7 | def save_checkpoint(model_dir, state, is_best):
8 | epoch = state['epoch']
9 | path = os.path.join(model_dir, 'model.pth-' + str(epoch))
10 | torch.save(state, path)
11 | checkpoint_file = os.path.join(model_dir, 'checkpoint')
12 | checkpoint = open(checkpoint_file, 'w+')
13 | checkpoint.write('model_checkpoint_path:%s\n' % path)
14 | checkpoint.close()
15 | if is_best:
16 | shutil.copyfile(path, os.path.join(model_dir, 'model-best.pth'))
17 |
18 |
19 | def load_state(model_dir, model, optimizer=None):
20 | if not os.path.exists(model_dir + '/checkpoint'):
21 | print("=> no checkpoint found at '{}', train from scratch".format(model_dir))
22 | return 0, 0
23 | else:
24 | ckpt = open(model_dir + '/checkpoint')
25 | model_path = ckpt.readlines()[0].split(':')[1].strip('\n')
26 | checkpoint = torch.load(model_path,map_location='cuda:{}'.format(torch.cuda.current_device()))
27 | model.load_state_dict(checkpoint['state_dict'], strict=False)
28 | ckpt_keys = set(checkpoint['state_dict'].keys())
29 | own_keys = set(model.state_dict().keys())
30 | missing_keys = own_keys - ckpt_keys
31 | for k in missing_keys:
32 | print('missing keys from checkpoint {}: {}'.format(model_dir, k))
33 |
34 | print("=> loaded model from checkpoint '{}'".format(model_dir))
35 | if optimizer != None:
36 | best_prec1 = 0
37 | if 'best_prec1' in checkpoint.keys():
38 | best_prec1 = checkpoint['best_prec1']
39 | start_epoch = checkpoint['epoch']
40 | optimizer.load_state_dict(checkpoint['optimizer'])
41 | print("=> also loaded optimizer from checkpoint '{}' (epoch {})"
42 | .format(model_dir, start_epoch))
43 | return best_prec1, start_epoch
44 |
45 |
46 | def load_state_epoch(model_dir, model, epoch):
47 | model_path = model_dir + '/model.pth-' + str(epoch)
48 | checkpoint = torch.load(model_path,map_location='cuda:{}'.format(torch.cuda.current_device()))
49 |
50 | model.load_state_dict(checkpoint['state_dict'], strict=False)
51 | ckpt_keys = set(checkpoint['state_dict'].keys())
52 | own_keys = set(model.state_dict().keys())
53 | missing_keys = own_keys - ckpt_keys
54 | for k in missing_keys:
55 | print('missing keys from checkpoint {}: {}'.format(model_dir, k))
56 |
57 | print("=> loaded model from checkpoint '{}'".format(model_dir))
58 |
59 |
60 | def load_state_ckpt(model_path, model):
61 | checkpoint = torch.load(model_path, map_location='cuda:{}'.format(torch.cuda.current_device()))
62 | model.load_state_dict(checkpoint['state_dict'], strict=False)
63 | ckpt_keys = set(checkpoint['state_dict'].keys())
64 | own_keys = set(model.state_dict().keys())
65 | missing_keys = own_keys - ckpt_keys
66 | for k in missing_keys:
67 | print('missing keys from checkpoint {}: {}'.format(model_path, k))
68 |
69 | print("=> loaded model from checkpoint '{}'".format(model_path))
70 |
71 | def save_masks(model,args):
72 | masks = {}
73 | for n, m in model.named_modules():
74 | if isinstance(m, SparseConvTranspose) or isinstance(m,SparseLinearTranspose):
75 | masks[n] = m.weight.mask.cpu()
76 | masks['state_dict'] = model.state_dict()
77 | torch.save(masks, args.mask_path + args.model + '_' + str(args.N) + '_' + str(args.M))
78 |
79 | def load_state_and_masks(model, args):
80 | masks = torch.load(args.mask_path + args.model + '_' + str(args.N) + '_' + str(args.M))
81 |
82 | #load weights
83 | model.load_state_dict(masks['state_dict'], strict=False)
84 | ckpt_keys = set(masks['state_dict'].keys())
85 | own_keys = set(model.state_dict().keys())
86 | missing_keys = own_keys - ckpt_keys
87 | for k in missing_keys:
88 | print('missing keys from checkpoint {}'.format( k))
89 |
90 | #load_masks
91 | for n, m in model.named_modules():
92 | if isinstance(m, SparseConvTranspose) or isinstance(m,SparseLinearTranspose):
93 |
94 | # m.maskBuff.data = masks[n]
95 | setattr(m.weight, "mask", masks[n])
96 |
97 |
98 |
--------------------------------------------------------------------------------
/dynamic_TNM/train-20210211_125543.log:
--------------------------------------------------------------------------------
1 | python: can't open file 'train_imagenet.py': [Errno 2] No such file or directory
2 |
--------------------------------------------------------------------------------
/prune/prune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | from torch import load as torch_load, save as torch_save, ones_like as torch_ones_like
5 | from common.timer import Timer
6 | from prune.pruning_method_utils import permute_to_nhwc, pad_inner_dims
7 | from prune.pruning_method_transposable_block_l1 import PruningMethodTransposableBlockL1
8 |
9 |
10 | def get_args():
11 | parser = argparse.ArgumentParser(description='Pruner')
12 | parser.add_argument('--checkpoint', required=True, type=str, help='path to checkpoint')
13 | parser.add_argument('--n-workers', default=None, type=int, help='number of processes')
14 | parser.add_argument('--save', default=None, type=str, help='path to pruned checkpoint')
15 | parser.add_argument('--bs', default=8, type=int, help='block size')
16 | parser.add_argument('--topk', default=4, type=int, help='topk')
17 | parser.add_argument('--sd-key', default='state_dict', type=str, help='state dict key in checkpoint')
18 | parser.add_argument('--optimize-transposed', action='store_true', default=False,
19 | help='if true, transposable pruning method will optimize for (block + block.T)')
20 | parser.add_argument('--include', nargs='*', default=None,
21 | help='list of layers that will be included in pruning')
22 | parser.add_argument('--exclude', nargs='*', default=None,
23 | help='list of layers that will be excluded from pruning')
24 | parser.add_argument('--debug-key', default=None, type=str, help='variable key to print first block')
25 | args = parser.parse_args()
26 | return args
27 |
28 |
29 | def load_checkpoint(filename):
30 | if not os.path.isfile(filename):
31 | raise FileNotFoundError('Checkpoint {} not found'.format(filename))
32 |
33 | checkpoint = torch_load(filename, map_location='cpu')
34 | return checkpoint
35 |
36 |
37 | def load_sd_from_checkpoint(filename, sd_key):
38 | checkpoint = load_checkpoint(filename)
39 | sd = checkpoint[sd_key] if sd_key is not None else checkpoint.copy()
40 | del checkpoint
41 | return sd
42 |
43 |
44 | def load_var_from_checkpoint(filename, name, sd_key):
45 | sd = load_sd_from_checkpoint(filename, sd_key)
46 | if name not in sd:
47 | raise RuntimeError('Variable {} not found in {}'.format(name, filename))
48 | v = sd[name]
49 | del sd
50 | return v
51 |
52 |
53 | def save_var_to_checkpoint(filename, name, mask, sd_key):
54 | checkpoint = load_checkpoint(filename)
55 | sd = checkpoint[sd_key] if sd_key is not None else checkpoint
56 | if name not in sd:
57 | raise RuntimeError('Variable {} not found in {}'.format(name, filename))
58 | sd[name] = sd[name] * mask
59 | torch_save(checkpoint, filename)
60 | del checkpoint
61 |
62 |
63 | def prune(checkpoint, save, sd_key, bs=8, topk=4, optimize_transposed=False,
64 | include=None, exclude=None, n_workers=None, debug_key=None):
65 |
66 | with Timer() as t:
67 | sd = load_sd_from_checkpoint(checkpoint, sd_key)
68 | print('Loading checkpoint, elapsed={}'.format(t.total()))
69 |
70 | save = checkpoint + '.pruned' if save is None else save
71 | shutil.copyfile(checkpoint, save)
72 |
73 | prune_method = PruningMethodTransposableBlockL1(block_size=bs, topk=topk,
74 | optimize_transposed=optimize_transposed,
75 | n_workers=n_workers, with_tqdm=True)
76 |
77 | keys = [k for k in sd.keys() if sd[k].dim() > 1 and 'bias' not in k and 'running' not in k]
78 |
79 | if include:
80 | invalid_keys = [k for k in include if k not in keys]
81 | assert not invalid_keys, 'Requested params to include={} not in model'.format(invalid_keys)
82 |         print('Including {}'.format(include))
83 | keys = include
84 |
85 | if exclude:
86 | invalid_keys = [k for k in exclude if k not in keys]
87 | assert not invalid_keys, 'Requested params to exclude={} not in model'.format(invalid_keys)
88 | print('Excluding {}'.format(exclude))
89 | keys = [k for k in keys if k not in exclude]
90 |
91 | del sd
92 |
93 | with Timer() as t:
94 | for key in keys:
95 | v = load_var_from_checkpoint(checkpoint, key, sd_key)
96 | print('Pruning ' + key)
97 | prune_weight_mask = prune_method.compute_mask(v, torch_ones_like(v))
98 | save_var_to_checkpoint(save, key, prune_weight_mask, sd_key)
99 | print('Total elapsed time: {}'.format(t.total()))
100 |
101 | if debug_key:
102 |         # reshape the pruned variable's mask into bs x bs blocks for display below
103 | sd = load_sd_from_checkpoint(save, sd_key)
104 | v = sd[debug_key]
105 |
106 | # print first block
107 | permuted_mask = permute_to_nhwc(v)
108 | permuted_mask = pad_inner_dims(permuted_mask, bs * bs)
109 | permuted_mask = permuted_mask.reshape(-1, (bs * bs))
110 | print('first block=\n{}'.format(permuted_mask.numpy()[0, :].reshape(1, -1, bs, bs)))
111 |
112 |
113 | def main():
114 | args = get_args()
115 | prune(checkpoint=args.checkpoint, save=args.save, sd_key=args.sd_key, bs=args.bs, topk=args.topk,
116 | optimize_transposed=args.optimize_transposed, include=args.include, exclude=args.exclude,
117 | n_workers=args.n_workers, debug_key=args.debug_key)
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
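Note: prune() above trades I/O for memory: it re-loads the checkpoint for every pruned
variable instead of holding the whole state dict while solving. It can also be called
directly from Python; a sketch with illustrative paths and layer names:

    from prune.prune import prune

    prune(checkpoint='/path/to/resnet50.pth', save='/path/to/resnet50.pruned.pth',
          sd_key='state_dict', bs=8, topk=4, exclude=['fc.weight'])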
--------------------------------------------------------------------------------
/prune/pruning_method_based_mask.py:
--------------------------------------------------------------------------------
1 | from prune.pruning_method_utils import *
2 | import torch.nn.utils.prune as prune
3 |
4 |
5 | class PruningMethodBasedMask(prune.BasePruningMethod):
6 | """ Pruning based on fixed mask """
7 |
8 | PRUNING_TYPE = 'unstructured' # pruning type "structured" refers to channels
9 |
10 | def __init__(self, mask=None):
11 | super(PruningMethodBasedMask, self).__init__()
12 | self.mask = mask
13 |
14 | def compute_mask(self, t, default_mask):
15 | validate_tensor_shape_2d_4d(t)
16 | mask = self.mask.detach().mul_(default_mask)
17 | return mask.byte()
18 |
19 | def apply_like_self(self, module, name, **kwargs):
20 | assert 'mask' in kwargs
21 | cls = self.__class__
22 | return super(PruningMethodBasedMask, cls).apply(module, name, kwargs['mask'])
23 |
--------------------------------------------------------------------------------
/prune/pruning_method_transposable_block_l1.py:
--------------------------------------------------------------------------------
1 | import re
2 | from pulp import *
3 | from tqdm import tqdm
4 | from multiprocessing import Pool
5 | from common.timer import Timer
6 | from prune.pruning_method_utils import *
7 | import numpy as np
8 | import torch.nn.utils.prune as prune
9 |
10 | class PruningMethodTransposableBlockL1(prune.BasePruningMethod):
11 |
12 | PRUNING_TYPE = 'unstructured' # pruning type "structured" refers to channels
13 |
14 | RUN_SPEED_TEST = False
15 |
16 | def __init__(self, block_size, topk, optimize_transposed=False, n_workers=None, with_tqdm=True):
17 | super(PruningMethodTransposableBlockL1, self).__init__()
18 | assert topk <= block_size
19 | assert n_workers is None or n_workers > 0
20 | self.bs = block_size
21 | self.topk = topk
22 | self.optimize_transposed = optimize_transposed
23 | self.n_workers = n_workers
24 | self.with_tqdm = with_tqdm
25 | # used for multiprocess in order to avoid serialize/deserialize tensors etc.
26 | self.mp_tensor, self.mp_mask = None, None
27 |
28 | def ip_transpose(self, data):
29 | prob = LpProblem('TransposableMask', LpMaximize)
30 | combinations = []
31 | magnitude_loss = {}
32 | indicators = {}
33 | bs = self.bs
34 | for r in range(bs):
35 | for c in range(bs):
36 | combinations.append('ind' + '_{}r_{}c'.format(r, c))
37 | magnitude_loss['ind' + '_{}r_{}c'.format(r, c)] = abs(data[r, c])
38 | indicators['ind' + '_{}r_{}c'.format(r, c)] = \
39 | LpVariable('ind' + '_{}r_{}c'.format(r, c), 0, 1, LpInteger)
40 |
41 | prob += lpSum([indicators[ind] * magnitude_loss[ind] for ind in magnitude_loss.keys()])
42 |
43 | for r in range(bs):
44 | prob += lpSum([indicators[key] for key in combinations if '_{}r'.format(r) in key]) == self.topk
45 | for c in range(bs):
46 | prob += lpSum([indicators[key] for key in combinations if '_{}c'.format(c) in key]) == self.topk
47 |
48 | solver = LpSolverDefault
49 | solver.msg = False
50 | prob.solve(solver)
51 | assert prob.status != -1, 'Infeasible'
52 | mask = np.zeros([self.bs, self.bs])
53 | for v in prob.variables():
54 | if 'ind' in v.name:
55 | rc = re.findall(r'\d+', v.name)
56 | mask[int(rc[0]), int(rc[1])] = v.varValue
57 | return mask
58 |
59 | def get_mask_iter(self, c):
60 | co, inners = self.mp_tensor.shape
61 | block_numel = self.bs ** 2
62 | n_blocks = inners // block_numel
63 | for j in range(n_blocks):
64 | offset = j * block_numel
65 | w_block = self.mp_tensor[c, offset:offset + block_numel].reshape(self.bs, self.bs)
66 | w_block = w_block + w_block.T if self.optimize_transposed else w_block
67 | mask_block = self.ip_transpose(w_block).reshape(-1)
68 | self.mp_mask[c, offset:offset + block_numel] = torch.from_numpy(mask_block)
69 |
70 | def get_mask(self, t):
71 | self.mp_tensor = t
72 | self.mp_mask = torch.zeros_like(t)
73 |
74 | co, inners = t.shape
75 | n_blocks = inners // (self.bs ** 2)
76 |
77 | if self.RUN_SPEED_TEST:
78 | self.RUN_SPEED_TEST = False
79 | with Timer() as t:
80 | self.get_mask_iter(0)
81 | elapsed = t.total().total_seconds()
82 | print('Single core speed test: blocks={} secs={} block-time={}'.format(n_blocks, elapsed, elapsed/n_blocks))
83 |
84 | p = Pool(self.n_workers)
85 | n_iterations = co
86 | bar = tqdm(total=n_iterations, ncols=80) if self.with_tqdm else None
87 | bar.set_postfix_str('n_processes={}, blocks/iter={}'.format(p._processes, n_blocks)) if self.with_tqdm else None
88 | block_indexes = range(co)
89 | for _ in p.imap_unordered(self.get_mask_iter, block_indexes):
90 | bar.update(1) if self.with_tqdm else None
91 | bar.close() if self.with_tqdm else None
92 | p.close()
93 |
94 | return self.mp_mask
95 |
96 | def compute_mask(self, t, default_mask):
97 | # permute and pad
98 | validate_tensor_shape_2d_4d(t)
99 | t_masked = t.clone().detach().mul_(default_mask)
100 | t_permuted = permute_to_nhwc(t_masked)
101 | pad_to = self.bs ** 2
102 | t_padded = pad_inner_dims(t_permuted, pad_to)
103 | t = t_padded.data.abs().to(t)
104 |
105 | # compute mask
106 | mask = self.get_mask(t)
107 |
108 | # restore to original shape
109 | block_mask = clip_padding(mask, t_permuted.shape).reshape(t_permuted.shape)
110 | block_mask = permute_to_nchw(block_mask)
111 | return block_mask
112 |
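Note: ip_transpose above solves one small integer program per bs x bs block: binary
indicators pick exactly topk entries in every row and every column while maximizing the
kept L1 mass, so the mask remains N:M sparse under transposition. A quick sketch
checking that property on a single block (assumes pulp is installed):

    import torch
    from prune.pruning_method_transposable_block_l1 import PruningMethodTransposableBlockL1

    method = PruningMethodTransposableBlockL1(block_size=4, topk=2)
    block = torch.randn(4, 4).abs()
    mask = method.ip_transpose(block)          # numpy array with 0/1 entries
    print(mask.sum(axis=0), mask.sum(axis=1))  # both [2. 2. 2. 2.]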
--------------------------------------------------------------------------------
/prune/pruning_method_transposable_block_l1_graphs.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | from tqdm import tqdm
3 | from multiprocessing import Pool
4 | from common.timer import Timer
5 | from prune.pruning_method_utils import *
6 | import numpy as np
7 | import torch.nn.utils.prune as prune
8 |
9 |
10 | class PruningMethodTransposableBlockL1Graphs(prune.BasePruningMethod):
11 |
12 | PRUNING_TYPE = 'unstructured' # pruning type "structured" refers to channels
13 |
14 | RUN_SPEED_TEST = False
15 |
16 | def __init__(self, block_size, topk, optimize_transposed=False, n_workers=None, with_tqdm=True):
17 | super(PruningMethodTransposableBlockL1Graphs, self).__init__()
18 | assert topk <= block_size
19 | assert n_workers is None or n_workers > 0
20 | self.bs = block_size
21 | self.topk = topk
22 | self.optimize_transposed = optimize_transposed
23 | self.n_workers = n_workers
24 | self.with_tqdm = with_tqdm
25 | # used for multiprocess in order to avoid serialize/deserialize tensors etc.
26 | self.mp_tensor, self.mp_mask = None, None
27 |
28 | def nxGraph(self, data):
29 | bs = data.shape[0]
30 | G = nx.DiGraph()
31 | G.add_node('s', demand=-int(bs ** 2 / 2))
32 | G.add_node('t', demand=int(bs ** 2 / 2))
33 | names = []
34 | for i in range(bs):
35 | G.add_edge('s', 'row' + str(i), capacity=self.topk, weight=0)
36 | G.add_edge('col' + str(i), 't', capacity=self.topk, weight=0)
37 | for j in range(bs):
38 | G.add_edge('row' + str(i), 'col' + str(j), capacity=1, weight=data[i, j].numpy())
39 | names.append('row' + str(i))
40 | dictMinFLow = nx.min_cost_flow(G)
41 | mask = []
42 | for w in names:
43 | mask.append(list(dictMinFLow[w].values()))
44 | return np.array(mask)
45 |
46 | def get_mask_iter(self, c):
47 | co, inners = self.mp_tensor.shape
48 | block_numel = self.bs ** 2
49 | n_blocks = inners // block_numel
50 | for j in range(n_blocks):
51 | offset = j * block_numel
52 | w_block = self.mp_tensor[c, offset:offset + block_numel].reshape(self.bs, self.bs)
53 | w_block = w_block + w_block.T if self.optimize_transposed else w_block
54 |             mask_block = self.nxGraph(-1 * w_block).reshape(-1)  # negate weights: maximizing kept |w| becomes a min-cost flow
55 | self.mp_mask[c, offset:offset + block_numel] = torch.from_numpy(mask_block)
56 |
57 | def get_mask(self, t):
58 | self.mp_tensor = t
59 | self.mp_mask = torch.zeros_like(t)
60 |
61 | co, inners = t.shape
62 | n_blocks = inners // (self.bs ** 2)
63 |
64 | if self.RUN_SPEED_TEST:
65 | self.RUN_SPEED_TEST = False
66 | with Timer() as t:
67 | self.get_mask_iter(0)
68 | elapsed = t.total().total_seconds()
69 | print('Single core speed test: blocks={} secs={} block-time={}'.format(n_blocks, elapsed, elapsed/n_blocks))
70 |
71 | p = Pool(self.n_workers)
72 | n_iterations = co
73 | bar = tqdm(total=n_iterations, ncols=80) if self.with_tqdm else None
74 | bar.set_postfix_str('n_processes={}, blocks/iter={}'.format(p._processes, n_blocks)) if self.with_tqdm else None
75 | block_indexes = range(co)
76 | for _ in p.imap_unordered(self.get_mask_iter, block_indexes):
77 | bar.update(1) if self.with_tqdm else None
78 | bar.close() if self.with_tqdm else None
79 | p.close()
80 |
81 | return self.mp_mask
82 |
83 | def compute_mask(self, t, default_mask):
84 | # permute and pad
85 | validate_tensor_shape_2d_4d(t)
86 | t_masked = t.clone().detach().mul_(default_mask)
87 | t_permuted = permute_to_nhwc(t_masked)
88 | pad_to = self.bs ** 2
89 | t_padded = pad_inner_dims(t_permuted, pad_to)
90 | t = t_padded.data.abs().to(t)
91 |
92 | # compute mask
93 | mask = self.get_mask(t)
94 |
95 | # restore to original shape
96 | block_mask = clip_padding(mask, t_permuted.shape).reshape(t_permuted.shape)
97 | block_mask = permute_to_nchw(block_mask)
98 | return block_mask
99 |
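Note: this variant yields the same row/column-balanced masks as the IP solver above but
casts each block as a min-cost-flow problem (source -> rows -> columns -> sink, capacity
topk on row/column edges, cost -|w| per entry), which networkx typically solves much
faster than a generic IP. The same single-block check applies:

    import torch
    from prune.pruning_method_transposable_block_l1_graphs import PruningMethodTransposableBlockL1Graphs

    method = PruningMethodTransposableBlockL1Graphs(block_size=4, topk=2)
    block = torch.randn(4, 4).abs()
    mask = method.nxGraph(-1 * block)          # negated weights: max kept |w| == min cost
    print(mask.sum(axis=0), mask.sum(axis=1))  # both [2 2 2 2]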
--------------------------------------------------------------------------------
/prune/pruning_method_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def validate_tensor_shape_2d_4d(t):
5 | shape = t.shape
6 | if len(shape) not in (2, 4):
7 | raise ValueError(
8 |             "Only 2D and 4D tensor shapes are supported. "
9 |             "Found tensor of shape {} with {} dims".format(shape, len(shape))
10 | )
11 |
12 |
13 | def pad_inner_dims(t, pad_to):
14 | """ return padded-to-block tensor """
15 | inner_flattened = t.view(t.shape[0], -1)
16 | co, inners = inner_flattened.shape
17 | pad_required = pad_to > 1 and inners % pad_to != 0
18 | pad_size = pad_to - inners % pad_to if pad_required else 0
19 | pad = torch.zeros(co, pad_size).to(inner_flattened.data)
20 | t_padded = torch.cat((inner_flattened, pad), 1)
21 | return t_padded
22 |
23 |
24 | def clip_padding(t, orig_shape):
25 | """ return tensor with clipped padding """
26 | co = orig_shape[0]
27 | inners = 1
28 | for s in orig_shape[1:]:
29 | inners *= s
30 | t_clipped = t.view(co, -1)[:, :inners]
31 | return t_clipped
32 |
33 |
34 | def permute_to_nhwc(t):
35 | """ for 4D tensors, convert data layout from NCHW to NHWC """
36 | res = t.permute(0, 2, 3, 1).contiguous() if t.dim() == 4 else t
37 | return res
38 |
39 |
40 | def permute_to_nchw(t):
41 | """ for 4D tensors, convert data layout from NHWC to NCHW """
42 | res = t.permute(0, 3, 1, 2).contiguous() if t.dim() == 4 else t
43 | return res
44 |
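Note: the helpers above flatten a weight to (out_channels, inner) in NHWC order and
zero-pad the inner dimension up to a multiple of the block size, so every output channel
is tiled by whole bs*bs blocks; clip_padding undoes the padding. A round-trip sketch:

    import torch
    from prune.pruning_method_utils import (permute_to_nhwc, permute_to_nchw,
                                            pad_inner_dims, clip_padding)

    w = torch.randn(8, 3, 3, 3)        # inner dims flatten to 3*3*3 = 27
    p = permute_to_nhwc(w)
    padded = pad_inner_dims(p, 16)     # 27 -> 32 columns (5 zeros appended per row)
    restored = clip_padding(padded, p.shape).reshape(p.shape)
    assert torch.equal(permute_to_nchw(restored), w)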
--------------------------------------------------------------------------------
/prune/sparsity_freezer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .pruning_method_based_mask import PruningMethodBasedMask
3 |
4 |
5 | class SparsityFreezer:
6 | """ Keeps sparsity level in model by applying pruning based on current zeros of parameters """
7 | @staticmethod
8 | def freeze(model):
9 | with torch.no_grad():
10 | params = SparsityFreezer._get_model_params(model)
11 | SparsityFreezer._enforce_mask_based_on_zeros(params)
12 |
13 | @staticmethod
14 | def _enforce_mask_based_on_zeros(params):
15 | prune_method = PruningMethodBasedMask()
16 | for param_info in params.values():
17 | module, name = param_info
18 | param = getattr(module, name)
19 | mask = param.ne(0).float()
20 | prune_method.apply_like_self(module=module, name=name, mask=mask)
21 |
22 | @staticmethod
23 | def _get_model_params(model):
24 | params = {}
25 | for m_info in list(model.named_modules()):
26 | module_name, module = m_info
27 | for p_info in list(module.named_parameters(recurse=False)):
28 | param_name, param = p_info
29 | key = module_name + '.' + param_name
30 | if param.dim() > 1 and 'bias' not in key and 'running' not in key:
31 | params[key] = (module, param_name)
32 |
33 | # a shared parameter will only appear once in model.named_parameters()
34 | # therefore, filter to get only parameters that appear in model.named_parameters()
35 | model_named_params = set([name for name, _ in model.named_parameters()])
36 | params = {p: v for p, v in params.items() if p in model_named_params}
37 | return params
38 |
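Note: SparsityFreezer derives a mask from the zeros already present in each multi-dim
parameter and registers it through torch.nn.utils.prune, so later fine-tuning (as in
the static_TNM flow below) cannot resurrect pruned weights. A minimal sketch:

    import torch.nn as nn
    from prune.sparsity_freezer import SparsityFreezer

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
    model[0].weight.data[0].zero_()            # pretend this filter was pruned
    SparsityFreezer.freeze(model)
    print(hasattr(model[0], 'weight_mask'))    # True: weight = weight_orig * weight_mask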
--------------------------------------------------------------------------------
/static_TNM/scripts/prune_pretrained_R50.sh:
--------------------------------------------------------------------------------
1 | export datasets_dir=/datasets
2 | export dataset=imagenet
3 | export workdir='./results/static_TNM'
4 |
5 | echo $workdir
6 | cd ..
7 | python -m static_TNM.src.prune_pretrained_model -a resnet50 --save $workdir/resnet50-pruned.pth
8 | cp $workdir/resnet50-pruned.pth $workdir/resnet50.pth
9 | python -m vision.main --model resnet --resume $workdir/resnet50.pth --save $workdir --sparsity-freezer -b 256 --device-ids 0 1 2 3 --dataset $dataset --datasets-dir $datasets_dir
10 |
--------------------------------------------------------------------------------
/static_TNM/src/prune_pretrained_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import torchvision.models as models
5 | from torch.hub import get_dir
6 | from glob import glob
7 | from prune.prune import prune
8 |
9 |
10 | def main():
11 | # get supported models
12 | model_names = sorted(name for name in models.__dict__
13 | if name.islower() and not name.startswith("__")
14 | and callable(models.__dict__[name]))
15 |
16 | # get arguments
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=model_names,
19 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
20 | parser.add_argument('--save', default=None, type=str, help='pruned checkpoint')
21 | args = parser.parse_args()
22 |
23 | # if required, download pre trained model
24 | models.__dict__[args.arch](pretrained=True)
25 |
26 | # get pre trained checkpoint
27 | checkpoint_path = os.path.join(get_dir(), 'checkpoints')
28 | files = glob(os.path.join(checkpoint_path, '{}-*.pth').format(args.arch))
29 | assert len(files) == 1
30 | checkpoint_file = files[0]
31 |
32 |     # prune and save checkpoint (fall back to prune()'s default '<checkpoint>.pruned' path)
33 |     save = args.save if args.save is not None else checkpoint_file + '.pruned'
34 |     prune(checkpoint=checkpoint_file, save=save, sd_key=None, bs=8, topk=4)
35 |
36 |     # add expected fields to checkpoint
37 |     sd = torch.load(save)
38 |     checkpoint = {'state_dict': sd, 'epoch': 0, 'best_prec1': 0}
39 |     torch.save(checkpoint, save)
40 |
41 | if __name__ == '__main__':
42 | main()
43 |
--------------------------------------------------------------------------------
/vision/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Elad Hoffer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/vision/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | from torch.utils.data.distributed import DistributedSampler
5 | from torch.utils.data import Subset
6 | from itertools import accumulate as _accumulate  # torch._utils._accumulate is private and removed in newer torch
7 | from vision.utils.regime import Regime
8 | from vision.utils.dataset import IndexedFileDataset
9 | from vision.preprocess import get_transform
10 | from itertools import chain
11 | from copy import deepcopy
12 | import warnings
13 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
14 |
15 |
16 | def get_dataset(name, split='train', transform=None,
17 | target_transform=None, download=True, datasets_path='~/Datasets'):
18 | train = (split == 'train')
19 | root = os.path.join(os.path.expanduser(datasets_path), name)
20 | if name == 'cifar10':
21 | return datasets.CIFAR10(root=root,
22 | train=train,
23 | transform=transform,
24 | target_transform=target_transform,
25 | download=download)
26 | elif name == 'cifar100':
27 | return datasets.CIFAR100(root=root,
28 | train=train,
29 | transform=transform,
30 | target_transform=target_transform,
31 | download=download)
32 | elif name == 'mnist':
33 | return datasets.MNIST(root=root,
34 | train=train,
35 | transform=transform,
36 | target_transform=target_transform,
37 | download=download)
38 | elif name == 'stl10':
39 | return datasets.STL10(root=root,
40 | split=split,
41 | transform=transform,
42 | target_transform=target_transform,
43 | download=download)
44 | elif name == 'imagenet':
45 | root = os.path.join(root, split)
46 | return datasets.ImageFolder(root=root,
47 | transform=transform,
48 | target_transform=target_transform)
49 | elif name == 'imagenet_tar':
50 | if train:
51 | root = os.path.join(root, 'imagenet_train.tar')
52 | else:
53 | root = os.path.join(root, 'imagenet_validation.tar')
54 | return IndexedFileDataset(root, extract_target_fn=(
55 | lambda fname: fname.split('/')[0]),
56 | transform=transform,
57 | target_transform=target_transform)
58 |
59 |
60 | _DATA_ARGS = {'name', 'split', 'transform',
61 | 'target_transform', 'download', 'datasets_path'}
62 | _DATALOADER_ARGS = {'batch_size', 'shuffle', 'sampler', 'batch_sampler',
63 | 'num_workers', 'collate_fn', 'pin_memory', 'drop_last',
64 | 'timeout', 'worker_init_fn'}
65 | _TRANSFORM_ARGS = {'transform_name', 'input_size', 'scale_size', 'normalize', 'augment',
66 | 'cutout', 'duplicates', 'num_crops', 'autoaugment'}
67 | _OTHER_ARGS = {'distributed'}
68 |
69 |
70 | class DataRegime(object):
71 | def __init__(self, regime, defaults={}):
72 | self.regime = Regime(regime, deepcopy(defaults))
73 | self.epoch = 0
74 | self.steps = None
75 | self.get_loader(True)
76 |
77 | def get_setting(self):
78 | setting = self.regime.setting
79 | loader_setting = {k: v for k,
80 | v in setting.items() if k in _DATALOADER_ARGS}
81 | data_setting = {k: v for k, v in setting.items() if k in _DATA_ARGS}
82 | transform_setting = {
83 | k: v for k, v in setting.items() if k in _TRANSFORM_ARGS}
84 | other_setting = {k: v for k, v in setting.items() if k in _OTHER_ARGS}
85 | transform_setting.setdefault('transform_name', data_setting['name'])
86 | return {'data': data_setting, 'loader': loader_setting,
87 | 'transform': transform_setting, 'other': other_setting}
88 |
89 | def get(self, key, default=None):
90 | return self.regime.setting.get(key, default)
91 |
92 | def get_loader(self, force_update=False, override_settings=None, subset_indices=None):
93 | if force_update or self.regime.update(self.epoch, self.steps):
94 | setting = self.get_setting()
95 | if override_settings is not None:
96 | setting.update(override_settings)
97 | self._transform = get_transform(**setting['transform'])
98 | setting['data'].setdefault('transform', self._transform)
99 | self._data = get_dataset(**setting['data'])
100 | if subset_indices is not None:
101 | self._data = Subset(self._data, subset_indices)
102 | if setting['other'].get('distributed', False):
103 | setting['loader']['sampler'] = DistributedSampler(self._data)
104 | setting['loader']['shuffle'] = None
105 | # pin-memory currently broken for distributed
106 | setting['loader']['pin_memory'] = False
107 | self._sampler = setting['loader'].get('sampler', None)
108 | self._loader = torch.utils.data.DataLoader(
109 | self._data, **setting['loader'])
110 | return self._loader
111 |
112 | def set_epoch(self, epoch):
113 | self.epoch = epoch
114 | if self._sampler is not None and hasattr(self._sampler, 'set_epoch'):
115 | self._sampler.set_epoch(epoch)
116 |
117 | def __len__(self):
118 | return len(self._data)
119 |
120 | def __repr__(self):
121 | return str(self.regime)
122 |
123 |
124 | class SampledDataLoader(object):
125 | def __init__(self, dl_list):
126 | self.dl_list = dl_list
127 | self.epoch = 0
128 |
129 | def generate_order(self):
130 |
131 | order = [[idx]*len(dl) for idx, dl in enumerate(self.dl_list)]
132 | order = list(chain(*order))
133 | g = torch.Generator()
134 | g.manual_seed(self.epoch)
135 | return torch.tensor(order)[torch.randperm(len(order), generator=g)].tolist()
136 |
137 | def __len__(self):
138 | return sum([len(dl) for dl in self.dl_list])
139 |
140 | def __iter__(self):
141 | order = self.generate_order()
142 |
143 | iterators = [iter(dl) for dl in self.dl_list]
144 | for idx in order:
145 | yield next(iterators[idx])
146 | return
147 |
148 |
149 | class SampledDataRegime(DataRegime):
150 | def __init__(self, data_regime_list, probs, split_data=True):
151 | self.probs = probs
152 | self.data_regime_list = data_regime_list
153 | self.split_data = split_data
154 |
155 | def get_setting(self):
156 | return [data_regime.get_setting() for data_regime in self.data_regime_list]
157 |
158 | def get(self, key, default=None):
159 | return [data_regime.get(key, default) for data_regime in self.data_regime_list]
160 |
161 | def get_loader(self, force_update=False):
162 | settings = self.get_setting()
163 | if self.split_data:
164 | dset_sizes = [len(get_dataset(**s['data'])) for s in settings]
165 | assert len(set(dset_sizes)) == 1, \
166 | "all datasets should be same size"
167 | dset_size = dset_sizes[0]
168 | lengths = [int(prob * dset_size) for prob in self.probs]
169 | lengths[-1] = dset_size - sum(lengths[:-1])
170 | indices = torch.randperm(dset_size).tolist()
171 | indices_split = [indices[offset - length:offset]
172 | for offset, length in zip(_accumulate(lengths), lengths)]
173 | loaders = [data_regime.get_loader(force_update=True, subset_indices=indices_split[i])
174 | for i, data_regime in enumerate(self.data_regime_list)]
175 | else:
176 | loaders = [data_regime.get_loader(
177 | force_update=force_update) for data_regime in self.data_regime_list]
178 | self._loader = SampledDataLoader(loaders)
179 | self._loader.epoch = self.epoch
180 |
181 | return self._loader
182 |
183 | def set_epoch(self, epoch):
184 | self.epoch = epoch
185 | if hasattr(self, '_loader'):
186 | self._loader.epoch = epoch
187 | for data_regime in self.data_regime_list:
188 | if data_regime._sampler is not None and hasattr(data_regime._sampler, 'set_epoch'):
189 | data_regime._sampler.set_epoch(epoch)
190 |
191 | def __len__(self):
192 | return sum([len(data_regime._data)
193 | for data_regime in self.data_regime_list])
194 |
195 | def __repr__(self):
196 | print_str = 'Sampled Data Regime:\n'
197 | for p, config in zip(self.probs, self.data_regime_list):
198 | print_str += 'w.p. %s: %s\n' % (p, config)
199 | return print_str
200 |
201 |
202 | if __name__ == '__main__':
203 | reg1 = DataRegime(None, {'name': 'imagenet', 'batch_size': 16})
204 | reg2 = DataRegime(None, {'name': 'imagenet', 'batch_size': 32})
205 | reg1.set_epoch(0)
206 | reg2.set_epoch(0)
207 |     mreg = SampledDataRegime([reg1, reg2], probs=[0.5, 0.5])
208 |
209 | for x, _ in mreg.get_loader():
210 | print(x.shape)
211 |
--------------------------------------------------------------------------------
/vision/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .alexnet import *
2 | from .resnet import *
3 |
--------------------------------------------------------------------------------
/vision/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision.transforms as transforms
3 |
4 | __all__ = ['alexnet']
5 |
6 |
7 | class AlexNetOWT_BN(nn.Module):
8 |
9 | def __init__(self, num_classes=1000):
10 | super(AlexNetOWT_BN, self).__init__()
11 | self.features = nn.Sequential(
12 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2,
13 | bias=False),
14 | nn.MaxPool2d(kernel_size=3, stride=2),
15 | nn.BatchNorm2d(64),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(64, 192, kernel_size=5, padding=2, bias=False),
18 | nn.MaxPool2d(kernel_size=3, stride=2),
19 | nn.ReLU(inplace=True),
20 | nn.BatchNorm2d(192),
21 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False),
22 | nn.ReLU(inplace=True),
23 | nn.BatchNorm2d(384),
24 | nn.Conv2d(384, 256, kernel_size=3, padding=1, bias=False),
25 | nn.ReLU(inplace=True),
26 | nn.BatchNorm2d(256),
27 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
28 | nn.MaxPool2d(kernel_size=3, stride=2),
29 | nn.ReLU(inplace=True),
30 | nn.BatchNorm2d(256)
31 | )
32 | self.classifier = nn.Sequential(
33 | nn.Linear(256 * 6 * 6, 4096, bias=False),
34 | nn.BatchNorm1d(4096),
35 | nn.ReLU(inplace=True),
36 | nn.Dropout(0.5),
37 | nn.Linear(4096, 4096, bias=False),
38 | nn.BatchNorm1d(4096),
39 | nn.ReLU(inplace=True),
40 | nn.Dropout(0.5),
41 | nn.Linear(4096, num_classes)
42 | )
43 |
44 | self.regime = [
45 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-2,
46 | 'weight_decay': 5e-4, 'momentum': 0.9},
47 | {'epoch': 10, 'lr': 5e-3},
48 | {'epoch': 15, 'lr': 1e-3, 'weight_decay': 0},
49 | {'epoch': 20, 'lr': 5e-4},
50 | {'epoch': 25, 'lr': 1e-4}
51 | ]
52 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
53 | std=[0.229, 0.224, 0.225])
54 | self.data_regime = [{
55 | 'transform': transforms.Compose([
56 | transforms.Resize(256),
57 | transforms.RandomCrop(224),
58 | transforms.RandomHorizontalFlip(),
59 | transforms.ToTensor(),
60 | normalize])
61 | }]
62 | self.data_eval_regime = [{
63 | 'transform': transforms.Compose([
64 | transforms.Resize(256),
65 | transforms.CenterCrop(224),
66 | transforms.ToTensor(),
67 | normalize])
68 | }]
69 | def forward(self, x):
70 | x = self.features(x)
71 | x = x.view(-1, 256 * 6 * 6)
72 | x = self.classifier(x)
73 | return x
74 |
75 |
76 | def alexnet(**kwargs):
77 |     num_classes = kwargs.get('num_classes', 1000)
78 | return AlexNetOWT_BN(num_classes)
79 |
--------------------------------------------------------------------------------
/vision/models/modules/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | @torch.jit.script
7 | def swish(x):
8 | # type: (Tensor) -> Tensor
9 | return x * x.sigmoid()
10 |
11 |
12 | @torch.jit.script
13 | def hard_sigmoid(x):
14 | # type: (Tensor) -> Tensor
15 | return F.relu6(x+3).div_(6)
16 |
17 |
18 | @torch.jit.script
19 | def hard_swish(x):
20 | # type: (Tensor) -> Tensor
21 | return x * hard_sigmoid(x)
22 |
23 |
24 | class Swish(nn.Module):
25 | def __init__(self):
26 | super(Swish, self).__init__()
27 |
28 | def forward(self, x):
29 | return swish(x)
30 |
31 |
32 | class HardSigmoid(nn.Module):
33 | def __init__(self):
34 | super(HardSigmoid, self).__init__()
35 |
36 | def forward(self, x):
37 | return hard_sigmoid(x)
38 |
39 |
40 | class HardSwish(nn.Module):
41 | def __init__(self):
42 | super(HardSwish, self).__init__()
43 |
44 | def forward(self, x):
45 | return hard_swish(x)
46 |
--------------------------------------------------------------------------------
/vision/models/modules/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
4 |
5 |
6 | class CheckpointModule(nn.Module):
7 | def __init__(self, module, num_segments=1):
8 | super(CheckpointModule, self).__init__()
9 | assert num_segments == 1 or isinstance(module, nn.Sequential)
10 | self.module = module
11 | self.num_segments = num_segments
12 |
13 | def forward(self, *inputs):
14 | if self.num_segments > 1:
15 | return checkpoint_sequential(self.module, self.num_segments, *inputs)
16 | else:
17 | return checkpoint(self.module, *inputs)
18 |
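Note: CheckpointModule trades compute for memory: activations inside the wrapped module
are dropped during the forward pass and recomputed segment by segment in backward. A
minimal sketch:

    import torch
    import torch.nn as nn
    from vision.models.modules.checkpoint import CheckpointModule

    body = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
    net = CheckpointModule(body, num_segments=2)
    x = torch.randn(8, 64, requires_grad=True)  # input must require grad for checkpointing
    net(x).sum().backward()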
--------------------------------------------------------------------------------
/vision/models/modules/se.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .activations import Swish, HardSwish, HardSigmoid
4 |
5 |
6 | class SEBlock(nn.Module):
7 | def __init__(self, in_channels, out_channels=None, ratio=16):
8 | super(SEBlock, self).__init__()
9 | self.in_channels = in_channels
10 | if out_channels is None:
11 | out_channels = in_channels
12 | self.ratio = ratio
13 | self.relu = nn.ReLU(True)
14 | self.global_pool = nn.AdaptiveAvgPool2d(1)
15 | self.transform = nn.Sequential(
16 | nn.Linear(in_channels, in_channels // ratio),
17 | nn.ReLU(inplace=True),
18 | nn.Linear(in_channels // ratio, out_channels),
19 | nn.Sigmoid()
20 | )
21 |
22 | def forward(self, x):
23 | x_avg = self.global_pool(x).flatten(1, -1)
24 | mask = self.transform(x_avg)
25 | return x * mask.unsqueeze(-1).unsqueeze(-1)
26 |
27 |
28 | class SESwishBlock(nn.Module):
29 | """ squeeze-excite block for MBConv """
30 |
31 | def __init__(self, in_channels, out_channels=None, interm_channels=None, ratio=None, hard_act=False):
32 | super(SESwishBlock, self).__init__()
33 | assert not (interm_channels is None and ratio is None)
34 | interm_channels = interm_channels or in_channels // ratio
35 | self.in_channels = in_channels
36 | if out_channels is None:
37 | out_channels = in_channels
38 | self.ratio = ratio
39 |         self.activation = HardSwish() if hard_act else Swish()
40 | self.global_pool = nn.AdaptiveAvgPool2d(1)
41 | self.transform = nn.Sequential(
42 | nn.Linear(in_channels, interm_channels),
43 | HardSwish() if hard_act else Swish(),
44 | nn.Linear(interm_channels, out_channels),
45 | HardSigmoid() if hard_act else nn.Sigmoid()
46 | )
47 |
48 | def forward(self, x):
49 | x_avg = self.global_pool(x).flatten(1, -1)
50 | mask = self.transform(x_avg)
51 | return x * mask.unsqueeze(-1).unsqueeze(-1)
52 |
--------------------------------------------------------------------------------
/vision/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.transforms as transforms
4 | from vision.autoaugment import ImageNetPolicy, CIFAR10Policy
5 |
6 |
7 | _IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406],
8 | 'std': [0.229, 0.224, 0.225]}
9 |
10 |
11 | _IMAGENET_PCA = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, num_crops=1, normalize=_IMAGENET_STATS):
22 | assert num_crops in [1, 5, 10], "num crops must be in {1,5,10}"
23 | convert_tensor = transforms.Compose([transforms.ToTensor(),
24 | transforms.Normalize(**normalize)])
25 | if num_crops == 1:
26 | t_list = [
27 | transforms.CenterCrop(input_size),
28 | convert_tensor
29 | ]
30 | else:
31 | if num_crops == 5:
32 | t_list = [transforms.FiveCrop(input_size)]
33 | elif num_crops == 10:
34 | t_list = [transforms.TenCrop(input_size)]
35 | # returns a 4D tensor
36 | t_list.append(transforms.Lambda(lambda crops:
37 | torch.stack([convert_tensor(crop) for crop in crops])))
38 |
39 | if scale_size != input_size:
40 | t_list = [transforms.Resize(scale_size)] + t_list
41 |
42 | return transforms.Compose(t_list)
43 |
44 |
45 | def random_crop(input_size, scale_size=None, padding=None, normalize=_IMAGENET_STATS):
46 | scale_size = scale_size or input_size
47 | T = transforms.Compose([
48 | transforms.RandomCrop(scale_size, padding=padding),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.ToTensor(),
51 | transforms.Normalize(**normalize),
52 | ])
53 | if input_size != scale_size:
54 | T.transforms.insert(1, transforms.Resize(input_size))
55 | return T
56 |
57 |
58 | def pad_random_crop(input_size, scale_size=None, normalize=_IMAGENET_STATS):
59 | padding = int(((scale_size or input_size) - input_size) / 2)
60 | return transforms.Compose([
61 | transforms.RandomCrop(input_size, padding=padding),
62 | transforms.RandomHorizontalFlip(),
63 | transforms.ToTensor(),
64 | transforms.Normalize(**normalize),
65 | ])
66 |
67 |
68 | def cifar_autoaugment(input_size, scale_size=None, padding=None, normalize=_IMAGENET_STATS):
69 | scale_size = scale_size or input_size
70 | T = transforms.Compose([
71 | transforms.RandomCrop(scale_size, padding=padding),
72 | transforms.RandomHorizontalFlip(),
73 | CIFAR10Policy(fillcolor=(128, 128, 128)),
74 | transforms.ToTensor(),
75 | transforms.Normalize(**normalize),
76 | ])
77 | if input_size != scale_size:
78 | T.transforms.insert(1, transforms.Resize(input_size))
79 | return T
80 |
81 |
82 | def inception_preprocess(input_size, normalize=_IMAGENET_STATS):
83 | return transforms.Compose([
84 | transforms.RandomResizedCrop(input_size),
85 | transforms.RandomHorizontalFlip(),
86 | transforms.ToTensor(),
87 | transforms.Normalize(**normalize)
88 | ])
89 |
90 |
91 | def inception_autoaugment_preprocess(input_size, normalize=_IMAGENET_STATS):
92 | return transforms.Compose([
93 | transforms.RandomResizedCrop(input_size),
94 | transforms.RandomHorizontalFlip(),
95 | ImageNetPolicy(fillcolor=(128, 128, 128)),
96 | transforms.ToTensor(),
97 | transforms.Normalize(**normalize)
98 | ])
99 |
100 |
101 | def inception_color_preprocess(input_size, normalize=_IMAGENET_STATS):
102 | return transforms.Compose([
103 | transforms.RandomResizedCrop(input_size),
104 | transforms.RandomHorizontalFlip(),
105 | transforms.ColorJitter(
106 | brightness=0.4,
107 | contrast=0.4,
108 | saturation=0.4,
109 | ),
110 | transforms.ToTensor(),
111 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
112 | transforms.Normalize(**normalize)
113 | ])
114 |
115 |
116 | def multi_transform(transform_fn, duplicates=1, dim=0):
117 | """preforms multiple transforms, useful to implement inference time augmentation or
118 | "batch augmentation" from https://openreview.net/forum?id=H1V4QhAqYQ¬eId=BylUSs_3Y7
119 | """
120 | if duplicates > 1:
121 | return transforms.Lambda(lambda x: torch.stack([transform_fn(x) for _ in range(duplicates)], dim=dim))
122 | else:
123 | return transform_fn
124 |
125 |
126 | def get_transform(transform_name='imagenet', input_size=None, scale_size=None,
127 | normalize=None, augment=True, cutout=None, autoaugment=False,
128 | padding=None, duplicates=1, num_crops=1):
129 | normalize = normalize or _IMAGENET_STATS
130 | transform_fn = None
131 | if 'imagenet' in transform_name: # inception augmentation is default for imagenet
132 | input_size = input_size or 224
133 | scale_size = scale_size or int(input_size * 8/7)
134 | if augment:
135 | if autoaugment:
136 | transform_fn = inception_autoaugment_preprocess(input_size,
137 | normalize=normalize)
138 | else:
139 | transform_fn = inception_preprocess(input_size,
140 | normalize=normalize)
141 | else:
142 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
143 | num_crops=num_crops, normalize=normalize)
144 | elif 'cifar' in transform_name: # resnet-style augmentation is default for cifar
145 | input_size = input_size or 32
146 | if augment:
147 | scale_size = scale_size or 32
148 | padding = padding or 4
149 | if autoaugment:
150 | transform_fn = cifar_autoaugment(input_size, scale_size=scale_size,
151 | padding=padding, normalize=normalize)
152 | else:
153 | transform_fn = random_crop(input_size, scale_size=scale_size,
154 | padding=padding, normalize=normalize)
155 | else:
156 | scale_size = scale_size or 32
157 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
158 | num_crops=num_crops, normalize=normalize)
159 | elif transform_name == 'mnist':
160 | normalize = {'mean': [0.5], 'std': [0.5]}
161 | input_size = input_size or 28
162 | if augment:
163 | scale_size = scale_size or 32
164 | transform_fn = pad_random_crop(input_size, scale_size=scale_size,
165 | normalize=normalize)
166 | else:
167 | scale_size = scale_size or 32
168 | transform_fn = scale_crop(input_size=input_size, scale_size=scale_size,
169 | num_crops=num_crops, normalize=normalize)
170 | if cutout is not None:
171 | transform_fn.transforms.append(Cutout(**cutout))
172 | return multi_transform(transform_fn, duplicates)
173 |
174 |
175 | class Lighting(object):
176 | """Lighting noise(AlexNet - style PCA - based noise)"""
177 |
178 | def __init__(self, alphastd, eigval, eigvec):
179 | self.alphastd = alphastd
180 | self.eigval = eigval
181 | self.eigvec = eigvec
182 |
183 | def __call__(self, img):
184 | if self.alphastd == 0:
185 | return img
186 |
187 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
188 | rgb = self.eigvec.type_as(img).clone()\
189 | .mul(alpha.view(1, 3).expand(3, 3))\
190 | .mul(self.eigval.view(1, 3).expand(3, 3))\
191 | .sum(1).squeeze()
192 |
193 | return img.add(rgb.view(3, 1, 1).expand_as(img))
194 |
195 |
196 | class Cutout(object):
197 | """
198 | Randomly mask out one or more patches from an image.
199 | taken from https://github.com/uoguelph-mlrg/Cutout
200 |
201 |
202 | Args:
203 | holes (int): Number of patches to cut out of each image.
204 | length (int): The length (in pixels) of each square patch.
205 | """
206 |
207 | def __init__(self, holes, length):
208 | self.holes = holes
209 | self.length = length
210 |
211 | def __call__(self, img):
212 | """
213 | Args:
214 | img (Tensor): Tensor image of size (C, H, W).
215 | Returns:
216 | Tensor: Image with holes of dimension length x length cut out of it.
217 | """
218 | h = img.size(1)
219 | w = img.size(2)
220 |
221 | mask = np.ones((h, w), np.float32)
222 |
223 | for n in range(self.holes):
224 | y = np.random.randint(h)
225 | x = np.random.randint(w)
226 |
227 | y1 = np.clip(y - self.length // 2, 0, h)
228 | y2 = np.clip(y + self.length // 2, 0, h)
229 | x1 = np.clip(x - self.length // 2, 0, w)
230 | x2 = np.clip(x + self.length // 2, 0, w)
231 |
232 | mask[y1: y2, x1: x2] = 0.
233 |
234 | mask = torch.from_numpy(mask)
235 | mask = mask.expand_as(img)
236 | img = img * mask
237 |
238 | return img
239 |
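
A sketch of how get_transform is typically wired into a dataset — the
'cifar10' name and the './data' path are illustrative:

    from torchvision.datasets import CIFAR10
    from vision.preprocess import get_transform

    train_tf = get_transform('cifar10', augment=True)   # random crop + flip
    val_tf = get_transform('cifar10', augment=False)    # deterministic crop
    train_set = CIFAR10('./data', train=True, download=True, transform=train_tf)
    val_set = CIFAR10('./data', train=False, download=True, transform=val_tf)
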
--------------------------------------------------------------------------------
/vision/utils/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Elad Hoffer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/vision/utils/absorb_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 |
5 |
6 | def remove_bn_params(bn_module):
7 | bn_module.register_buffer('running_mean', None)
8 | bn_module.register_buffer('running_var', None)
9 | bn_module.register_parameter('weight', None)
10 | bn_module.register_parameter('bias', None)
11 |
12 |
13 | def init_bn_params(bn_module):
14 | bn_module.running_mean.fill_(0)
15 | bn_module.running_var.fill_(1)
16 |
17 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False):
18 | with torch.no_grad():
19 | w = module.weight
20 | if module.bias is None:
21 | zeros = torch.zeros(module.out_channels,
22 | dtype=w.dtype, device=w.device)
23 | bias = nn.Parameter(zeros)
24 | module.register_parameter('bias', bias)
25 | b = module.bias
26 |
27 | if hasattr(bn_module, 'running_mean'):
28 | b.add_(-bn_module.running_mean)
29 | if hasattr(bn_module, 'running_var'):
30 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
31 | w.mul_(invstd.view(w.size(0), 1, 1, 1))
32 | b.mul_(invstd)
33 |
34 | if remove_bn:
35 | if hasattr(bn_module, 'weight'):
36 | w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
37 | b.mul_(bn_module.weight)
38 | if hasattr(bn_module, 'bias'):
39 | b.add_(bn_module.bias)
40 | remove_bn_params(bn_module)
41 | else:
42 | init_bn_params(bn_module)
43 |
44 | if verbose:
45 | logging.info('BN module %s was absorbed into layer %s' %
46 | (bn_module, module))
47 |
48 |
49 | def is_bn(m):
50 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
51 |
52 |
53 | def is_absorbing(m):
54 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear)
55 |
56 |
57 | def search_absorb_bn(model, prev=None, remove_bn=True, verbose=False):
58 | with torch.no_grad():
59 | for m in model.children():
60 | if is_bn(m) and is_absorbing(prev):
61 | absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose)
62 | search_absorb_bn(m, remove_bn=remove_bn, verbose=verbose)
63 | prev = m
64 |
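
A sketch of folding BN statistics into the preceding convolutions of a
torchvision ResNet. With remove_bn=False the BN affine transform stays in
place, so eval-mode outputs are preserved up to float error; with
remove_bn=True the emptied BN modules must additionally be bypassed by the
caller:

    import torch
    from torchvision.models import resnet18
    from vision.utils.absorb_bn import search_absorb_bn

    model = resnet18().eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        y_ref = model(x)
        search_absorb_bn(model, remove_bn=False)  # running stats -> conv weights
        y_fold = model(x)
    print((y_ref - y_fold).abs().max())           # expected to be tiny (~1e-5)
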
--------------------------------------------------------------------------------
/vision/utils/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .misc import onehot
5 |
6 |
7 | def _is_long(x):
8 | if hasattr(x, 'data'):
9 | x = x.data
10 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)
11 |
12 |
13 | def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean',
14 | smooth_eps=None, smooth_dist=None, from_logits=True):
15 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567"""
16 | smooth_eps = smooth_eps or 0
17 |
18 | # ordinary log-likelihood - use cross_entropy from nn
19 | if _is_long(target) and smooth_eps == 0:
20 | if from_logits:
21 | return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
22 | else:
23 | return F.nll_loss(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
24 |
25 | if from_logits:
26 | # log-softmax of inputs
27 | lsm = F.log_softmax(inputs, dim=-1)
28 | else:
29 | lsm = inputs
30 |
31 | masked_indices = None
32 | num_classes = inputs.size(-1)
33 |
34 | if _is_long(target) and ignore_index >= 0:
35 | masked_indices = target.eq(ignore_index)
36 |
37 | if smooth_eps > 0 and smooth_dist is not None:
38 | if _is_long(target):
39 | target = onehot(target, num_classes).type_as(inputs)
40 | if smooth_dist.dim() < target.dim():
41 | smooth_dist = smooth_dist.unsqueeze(0)
42 | target.lerp_(smooth_dist, smooth_eps)
43 |
44 | if weight is not None:
45 | lsm = lsm * weight.unsqueeze(0)
46 |
47 | if _is_long(target):
48 | eps_sum = smooth_eps / num_classes
49 | eps_nll = 1. - eps_sum - smooth_eps
50 | likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
51 | loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1))
52 | else:
53 | loss = -(target * lsm).sum(-1)
54 |
55 | if masked_indices is not None:
56 | loss.masked_fill_(masked_indices, 0)
57 |
58 | if reduction == 'sum':
59 | loss = loss.sum()
60 | elif reduction == 'mean':
61 | if masked_indices is None:
62 | loss = loss.mean()
63 | else:
64 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum())
65 |
66 | return loss
67 |
68 |
69 | class CrossEntropyLoss(nn.CrossEntropyLoss):
70 | """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing"""
71 |
72 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None, from_logits=True):
73 | super(CrossEntropyLoss, self).__init__(weight=weight,
74 | ignore_index=ignore_index, reduction=reduction)
75 | self.smooth_eps = smooth_eps
76 | self.smooth_dist = smooth_dist
77 | self.from_logits = from_logits
78 |
79 | def forward(self, input, target, smooth_dist=None):
80 | if smooth_dist is None:
81 | smooth_dist = self.smooth_dist
82 | return cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
83 | reduction=self.reduction, smooth_eps=self.smooth_eps,
84 | smooth_dist=smooth_dist, from_logits=self.from_logits)
85 |
86 |
87 | def binary_cross_entropy(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=False):
88 | """cross entropy loss, with support for label smoothing https://arxiv.org/abs/1512.00567"""
89 | smooth_eps = smooth_eps or 0
90 | if smooth_eps > 0:
91 | target = target.float()
92 | target.mul_(1. - smooth_eps).add_(smooth_eps / 2.)  # standard binary smoothing: 1 -> 1-eps/2, 0 -> eps/2
93 | if from_logits:
94 | return F.binary_cross_entropy_with_logits(inputs, target, weight=weight, reduction=reduction)
95 | else:
96 | return F.binary_cross_entropy(inputs, target, weight=weight, reduction=reduction)
97 |
98 |
99 | def binary_cross_entropy_with_logits(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=True):
100 | return binary_cross_entropy(inputs, target, weight, reduction, smooth_eps, from_logits)
101 |
102 |
103 | class BCELoss(nn.BCELoss):
104 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=False):
105 | super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
106 | self.smooth_eps = smooth_eps
107 | self.from_logits = from_logits
108 |
109 | def forward(self, input, target):
110 | return binary_cross_entropy(input, target,
111 | weight=self.weight, reduction=self.reduction,
112 | smooth_eps=self.smooth_eps, from_logits=self.from_logits)
113 |
114 |
115 | class BCEWithLogitsLoss(BCELoss):
116 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=True):
117 | super(BCEWithLogitsLoss, self).__init__(weight, size_average,
118 | reduce, reduction, smooth_eps=smooth_eps, from_logits=from_logits)
119 |
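
A sketch of the label-smoothing path — random logits, 10 classes, and a 0.1
smoothing factor, all illustrative:

    import torch
    from vision.utils.cross_entropy import CrossEntropyLoss

    logits = torch.randn(8, 10, requires_grad=True)
    target = torch.randint(0, 10, (8,))
    criterion = CrossEntropyLoss(smooth_eps=0.1)  # 10% of the mass is smoothed
    loss = criterion(logits, target)
    loss.backward()
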
--------------------------------------------------------------------------------
/vision/utils/log.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | from itertools import cycle
4 | import torch
5 | import logging.config
6 | import json
7 |
8 | import pandas as pd
9 | from bokeh.io import output_file, save, show
10 | from bokeh.plotting import figure
11 | from bokeh.layouts import column
12 | from bokeh.models import Div
13 |
14 | try:
15 | import hyperdash
16 | HYPERDASH_AVAILABLE = True
17 | except ImportError:
18 | HYPERDASH_AVAILABLE = False
19 |
20 |
21 | def export_args_namespace(args, filename):
22 | """
23 | args: argparse.Namespace
24 | arguments to save
25 | filename: string
26 | filename to save at
27 | """
28 | with open(filename, 'w') as fp:
29 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4)
30 |
31 |
32 | def setup_logging(log_file='log.txt', resume=False, dummy=False):
33 | """
34 | Setup logging configuration
35 | """
36 | if dummy:
37 | logging.getLogger('dummy')
38 | else:
39 | if os.path.isfile(log_file) and resume:
40 | file_mode = 'a'
41 | else:
42 | file_mode = 'w'
43 |
44 | root_logger = logging.getLogger()
45 | if root_logger.handlers:
46 | root_logger.removeHandler(root_logger.handlers[0])
47 | logging.basicConfig(level=logging.INFO,
48 | format="%(asctime)s - %(levelname)s - %(message)s",
49 | datefmt="%Y-%m-%d %H:%M:%S",
50 | filename=log_file,
51 | filemode=file_mode)
52 | console = logging.StreamHandler()
53 | console.setLevel(logging.INFO)
54 | formatter = logging.Formatter('%(message)s')
55 | console.setFormatter(formatter)
56 | logging.getLogger('').addHandler(console)
57 |
58 |
59 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None,
60 | x_axis_type='linear', y_axis_type='linear',
61 | width=800, height=400, line_width=2,
62 | colors=['red', 'green', 'blue', 'orange',
63 | 'black', 'purple', 'brown'],
64 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save',
65 | append_figure=None):
66 | """
67 | creates a new plot figures
68 | example:
69 | plot_figure(x='epoch', y=['train_loss', 'val_loss'],
70 | title='Loss', ylabel='loss')
71 | """
72 | if not isinstance(y, list):
73 | y = [y]
74 | xlabel = xlabel or x
75 | legend = legend or y
76 | assert len(legend) == len(y)
77 | if append_figure is not None:
78 | f = append_figure
79 | else:
80 | f = figure(title=title, tools=tools,
81 | width=width, height=height,
82 | x_axis_label=xlabel or x,
83 | y_axis_label=ylabel or '',
84 | x_axis_type=x_axis_type,
85 | y_axis_type=y_axis_type)
86 | colors = cycle(colors)
87 | for i, yi in enumerate(y):
88 | f.line(data[x], data[yi],
89 | line_width=line_width,
90 | line_color=next(colors), legend_label=legend[i])
91 | f.legend.click_policy = "hide"
92 | return f
93 |
94 |
95 | class ResultsLog(object):
96 |
97 | supported_data_formats = ['csv', 'json']
98 |
99 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'):
100 | """
101 | Parameters
102 | ----------
103 | path: string
104 | base path (without extension) for the data file;
105 | a matching plot file is saved at the same base
106 | path with an '.html' extension
107 | title: string
108 | title of HTML file
109 | params: Namespace
110 | optionally save parameters for results
111 | resume: bool
112 | resume previous logging
113 | data_format: str('csv'|'json')
114 | which file format to use to save the data
115 | """
116 | if data_format not in ResultsLog.supported_data_formats:
117 | raise ValueError('data_format must be one of: ' +
118 | '|'.join(ResultsLog.supported_data_formats))
119 |
120 | if data_format == 'json':
121 | self.data_path = '{}.json'.format(path)
122 | else:
123 | self.data_path = '{}.csv'.format(path)
124 | if params is not None:
125 | export_args_namespace(params, '{}.json'.format(path))
126 | self.plot_path = '{}.html'.format(path)
127 | self.results = None
128 | self.clear()
129 | self.first_save = True
130 | if os.path.isfile(self.data_path):
131 | if resume:
132 | self.load(self.data_path)
133 | self.first_save = False
134 | else:
135 | os.remove(self.data_path)
136 | self.results = pd.DataFrame()
137 | else:
138 | self.results = pd.DataFrame()
139 |
140 | self.title = title
141 | self.data_format = data_format
142 |
143 | if HYPERDASH_AVAILABLE:
144 | name = self.title if title != '' else path
145 | self.hd_experiment = hyperdash.Experiment(name)
146 | if params is not None:
147 | for k, v in params._get_kwargs():
148 | self.hd_experiment.param(k, v, log=False)
149 |
150 | def clear(self):
151 | self.figures = []
152 |
153 | def add(self, **kwargs):
154 | """Add a new row to the dataframe
155 | example:
156 | resultsLog.add(epoch=epoch_num, train_loss=loss,
157 | test_loss=test_loss)
158 | """
159 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys())
160 | self.results = pd.concat([self.results, df], ignore_index=True)
161 | if hasattr(self, 'hd_experiment'):
162 | for k, v in kwargs.items():
163 | self.hd_experiment.metric(k, v, log=False)
164 |
165 | def smooth(self, column_name, window):
166 | """Select an entry to smooth over time"""
167 | # TODO: smooth only new data
168 | smoothed_column = self.results[column_name].rolling(
169 | window=window, center=False).mean()
170 | self.results[column_name + '_smoothed'] = smoothed_column
171 |
172 | def save(self, title=None):
173 | """save the json file.
174 | Parameters
175 | ----------
176 | title: string
177 | title of the HTML file
178 | """
179 | title = title or self.title
180 | if len(self.figures) > 0:
181 | if os.path.isfile(self.plot_path):
182 | os.remove(self.plot_path)
183 | if self.first_save:
184 | self.first_save = False
185 | logging.info('Plot file saved at: {}'.format(
186 | os.path.abspath(self.plot_path)))
187 |
188 | output_file(self.plot_path, title=title)
189 | plot = column(
190 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures)
191 | save(plot)
192 | self.clear()
193 |
194 | if self.data_format == 'json':
195 | self.results.to_json(self.data_path, orient='records', lines=True)
196 | else:
197 | self.results.to_csv(self.data_path, index=False, index_label=False)
198 |
199 | def load(self, path=None):
200 | """load the data file
201 | Parameters
202 | ----------
203 | path:
204 | path to load the json|csv file from
205 | """
206 | path = path or self.data_path
207 | if os.path.isfile(path):
208 | if self.data_format == 'json':
209 | self.results = pd.read_json(path)
210 | else:
211 | self.results = pd.read_csv(path)
212 | else:
213 | raise ValueError("{} isn't a file".format(path))
214 |
215 | def show(self, title=None):
216 | title = title or self.title
217 | if len(self.figures) > 0:
218 | plot = column(
219 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures)
220 | show(plot)
221 |
222 | def plot(self, *kargs, **kwargs):
223 | """
224 | add a new plot to the HTML file
225 | example:
226 | results.plot(x='epoch', y=['train_loss', 'val_loss'],
227 | title='Loss', ylabel='loss')
228 | """
229 | f = plot_figure(self.results, *kargs, **kwargs)
230 | self.figures.append(f)
231 |
232 | def image(self, *kargs, **kwargs):
233 | fig = figure()
234 | fig.image(*kargs, **kwargs)
235 | self.figures.append(fig)
236 |
237 | def end(self):
238 | if hasattr(self, 'hd_experiment'):
239 | self.hd_experiment.end()
240 |
241 |
242 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
243 | filename = os.path.join(path, filename)
244 | torch.save(state, filename)
245 | if is_best:
246 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
247 | if save_all:
248 | shutil.copyfile(filename, os.path.join(
249 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))
250 |
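
A sketch of the ResultsLog lifecycle — paths and metric names are
illustrative:

    import os
    from vision.utils.log import ResultsLog, setup_logging

    setup_logging('train.log')
    os.makedirs('results', exist_ok=True)
    results = ResultsLog(path='results/run1', title='run1')
    for epoch in range(3):
        results.add(epoch=epoch, train_loss=1.0 / (epoch + 1))
    results.plot(x='epoch', y=['train_loss'], title='Loss', ylabel='loss')
    results.save()   # writes results/run1.csv and results/run1.html
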
--------------------------------------------------------------------------------
/vision/utils/meters.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | class OnlineMeter(object):
24 | """Computes and stores the average and variance/std values of tensor"""
25 |
26 | def __init__(self):
27 | self.mean = torch.FloatTensor(1).fill_(-1)
28 | self.M2 = torch.FloatTensor(1).zero_()
29 | self.count = 0.
30 | self.needs_init = True
31 |
32 | def reset(self, x):
33 | self.mean = x.new(x.size()).zero_()
34 | self.M2 = x.new(x.size()).zero_()
35 | self.count = 0.
36 | self.needs_init = False
37 |
38 | def update(self, x):
39 | self.val = x
40 | if self.needs_init:
41 | self.reset(x)
42 | self.count += 1
43 | delta = x - self.mean
44 | self.mean.add_(delta / self.count)
45 | delta2 = x - self.mean
46 | self.M2.add_(delta * delta2)
47 |
48 | @property
49 | def var(self):
50 | if self.count < 2:
51 | return self.M2.clone().zero_()
52 | return self.M2 / (self.count - 1)
53 |
54 | @property
55 | def std(self):
56 | return self.var.sqrt()
57 |
58 |
59 | def accuracy(output, target, topk=(1,)):
60 | """Computes the precision@k for the specified values of k"""
61 | maxk = max(topk)
62 | batch_size = target.size(0)
63 |
64 | _, pred = output.topk(maxk, 1, True, True)
65 | pred = pred.t().type_as(target)
66 | correct = pred.eq(target.view(1, -1).expand_as(pred))
67 |
68 | res = []
69 | for k in topk:
70 | correct_k = correct[:k].reshape(-1).float().sum(0)
71 | res.append(correct_k.mul_(100.0 / batch_size))
72 | return res
73 |
74 |
75 | class AccuracyMeter(object):
76 | """Computes and stores the average and current topk accuracy"""
77 |
78 | def __init__(self, topk=(1,)):
79 | self.topk = topk
80 | self.reset()
81 |
82 | def reset(self):
83 | self._meters = {}
84 | for k in self.topk:
85 | self._meters[k] = AverageMeter()
86 |
87 | def update(self, output, target):
88 | n = target.nelement()
89 | acc_vals = accuracy(output, target, self.topk)
90 | for i, k in enumerate(self.topk):
91 | self._meters[k].update(acc_vals[i], n)
92 |
93 | @property
94 | def val(self):
95 | return {n: meter.val for (n, meter) in self._meters.items()}
96 |
97 | @property
98 | def avg(self):
99 | return {n: meter.avg for (n, meter) in self._meters.items()}
100 |
101 | @property
102 | def avg_error(self):
103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()}
104 |
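
A sketch of accuracy tracking over mini-batches — random logits, so the
averages should land near chance level:

    import torch
    from vision.utils.meters import AccuracyMeter

    meter = AccuracyMeter(topk=(1, 5))
    for _ in range(10):
        output = torch.randn(32, 100)            # fake logits, 100 classes
        target = torch.randint(0, 100, (32,))
        meter.update(output, target)
    print(meter.avg)   # roughly {1: 1.0, 5: 5.0} percent for random guesses
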
--------------------------------------------------------------------------------
/vision/utils/misc.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
6 |
7 | torch_dtypes = {
8 | 'float': torch.float,
9 | 'float32': torch.float32,
10 | 'float64': torch.float64,
11 | 'double': torch.double,
12 | 'float16': torch.float16,
13 | 'half': torch.half,
14 | 'uint8': torch.uint8,
15 | 'int8': torch.int8,
16 | 'int16': torch.int16,
17 | 'short': torch.short,
18 | 'int32': torch.int32,
19 | 'int': torch.int,
20 | 'int64': torch.int64,
21 | 'long': torch.long
22 | }
23 |
24 |
25 | def onehot(indexes, N=None, ignore_index=None):
26 | """
27 | Creates a one-hot representation of indexes with N possible entries.
28 | If N is not specified, it is inferred from the maximum index appearing.
29 | indexes is a long-tensor of indexes
30 | ignore_index will be zero in onehot representation
31 | """
32 | if N is None:
33 | N = int(indexes.max()) + 1
34 | sz = list(indexes.size())
35 | output = indexes.new().byte().resize_(*sz, N).zero_()
36 | output.scatter_(-1, indexes.unsqueeze(-1), 1)
37 | if ignore_index is not None and ignore_index >= 0:
38 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
39 | return output
40 |
41 |
42 | def set_global_seeds(i):
43 | try:
44 | import torch
45 | except ImportError:
46 | pass
47 | else:
48 | torch.manual_seed(i)
49 | if torch.cuda.is_available():
50 | torch.cuda.manual_seed_all(i)
51 | np.random.seed(i)
52 | random.seed(i)
53 |
54 |
55 | class CheckpointModule(nn.Module):
56 | def __init__(self, module, num_segments=1):
57 | super(CheckpointModule, self).__init__()
58 | assert num_segments == 1 or isinstance(module, nn.Sequential)
59 | self.module = module
60 | self.num_segments = num_segments
61 |
62 | def forward(self, x):
63 | if self.num_segments > 1:
64 | return checkpoint_sequential(self.module, self.num_segments, x)
65 | else:
66 | return checkpoint(self.module, x)
67 |
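
A sketch of the onehot helper — the labels are illustrative:

    import torch
    from vision.utils.misc import onehot, set_global_seeds

    set_global_seeds(0)
    labels = torch.tensor([1, 0, 2])
    print(onehot(labels, N=4))
    # tensor([[0, 1, 0, 0],
    #         [1, 0, 0, 0],
    #         [0, 0, 1, 0]], dtype=torch.uint8)
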
--------------------------------------------------------------------------------
/vision/utils/mixup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from numpy.random import beta
5 | from torch.nn.functional import one_hot
6 |
7 |
8 | class MixUp(nn.Module):
9 | def __init__(self, batch_dim=0):
10 | super(MixUp, self).__init__()
11 | self.batch_dim = batch_dim
12 | self.reset()
13 |
14 | def reset(self):
15 | self.enabled = False
16 | self.mix_values = None
17 | self.mix_index = None
18 |
19 | def mix(self, x1, x2):
20 | if not torch.is_tensor(self.mix_values): # scalar
21 | return x2.lerp(x1, self.mix_values)
22 | else:
23 | view = [1] * int(x1.dim())
24 | view[self.batch_dim] = -1
25 | mix_val = self.mix_values.to(device=x1.device).view(*view)
26 | return mix_val * x1 + (1.-mix_val) * x2
27 |
28 | def sample(self, alpha, batch_size, sample_batch=False):
29 | self.mix_index = torch.randperm(batch_size)
30 | if sample_batch:
31 | values = beta(alpha, alpha, size=batch_size)
32 | self.mix_values = torch.tensor(values, dtype=torch.float)
33 | else:
34 | self.mix_values = torch.tensor([beta(alpha, alpha)],
35 | dtype=torch.float)
36 |
37 | def mix_target(self, y, n_class):
38 | if not self.training or \
39 | self.mix_values is None or \
40 | self.mix_index is None:
41 | return y
42 | y = one_hot(y, n_class).to(dtype=torch.float)
43 | idx = self.mix_index.to(device=y.device)
44 | y_mix = y.index_select(self.batch_dim, idx)
45 | return self.mix(y, y_mix)
46 |
47 | def forward(self, x):
48 | if not self.training or \
49 | self.mix_values is None or \
50 | self.mix_index is None:
51 | return x
52 | idx = self.mix_index.to(device=x.device)
53 | x_mix = x.index_select(self.batch_dim, idx)
54 | return self.mix(x, x_mix)
55 |
56 |
57 | def rand_bbox(size, lam):
58 | W = size[2]
59 | H = size[3]
60 | cut_rat = np.sqrt(1. - lam)
61 | cut_w = int(W * cut_rat)
62 | cut_h = int(H * cut_rat)
63 |
64 | # uniform
65 | cx = np.random.randint(W)
66 | cy = np.random.randint(H)
67 |
68 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
69 | bby1 = np.clip(cy - cut_h // 2, 0, H)
70 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
71 | bby2 = np.clip(cy + cut_h // 2, 0, H)
72 |
73 | return bbx1, bby1, bbx2, bby2
74 |
75 |
76 | class CutMix(MixUp):
77 | def __init__(self, batch_dim=0):
78 | super(CutMix, self).__init__(batch_dim)
79 |
80 | def mix_image(self, x1, x2):
81 | assert not torch.is_tensor(self.mix_values) or \
82 | self.mix_values.nelement() == 1
83 | lam = float(self.mix_values)
84 | bbx1, bby1, bbx2, bby2 = rand_bbox(x1.size(), lam)
85 | x1[:, :, bbx1:bbx2, bby1:bby2] = x2[:, :, bbx1:bbx2, bby1:bby2]
86 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
87 | (x1.size()[-1] * x1.size()[-2]))
88 | self.mix_values.fill_(lam)
89 | return x1
90 |
91 | def sample(self, alpha, batch_size, sample_batch=False):
92 | assert not sample_batch
93 | super(CutMix, self).sample(alpha, batch_size, sample_batch)
94 |
95 | def forward(self, x):
96 | if not self.training or \
97 | self.mix_values is None or \
98 | self.mix_index is None:
99 | return x
100 | idx = self.mix_index.to(device=x.device)
101 | x_mix = x.index_select(self.batch_dim, idx)
102 | return self.mix_image(x, x_mix)
103 |
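
A sketch of mixup on a single batch — alpha and the class count are
illustrative:

    import torch
    from vision.utils.mixup import MixUp

    mixer = MixUp().train()                    # mixing is a no-op in eval mode
    x = torch.randn(16, 3, 32, 32)
    y = torch.randint(0, 10, (16,))
    mixer.sample(alpha=0.2, batch_size=16)     # draw pairing and Beta weight
    x_mixed = mixer(x)                         # convex combination of pairs
    y_mixed = mixer.mix_target(y, n_class=10)  # matching soft targets
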
--------------------------------------------------------------------------------
/vision/utils/param_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def is_not_bias(name):
6 | return not name.endswith('bias')
7 |
8 |
9 | def is_bn(module):
10 | return isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d)
11 |
12 |
13 | def is_not_bn(module):
14 | return not is_bn(module)
15 |
16 |
17 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None):
18 | if memo is None:
19 | memo = set()
20 |
21 | for module_name, module in model.named_modules():
22 | if module_fn is not None and not module_fn(module):
23 | continue
24 | if module_name_fn is not None and not module_name_fn(module_name):
25 | continue
26 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False):
27 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name):
28 | continue
29 | if param not in memo:
30 | memo.add(param)
31 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)}
32 |
33 |
34 | class FilterParameters(object):
35 | def __init__(self, source, module=None, module_name=None, parameter_name=None):
36 | if isinstance(source, FilterParameters):
37 | self._filtered_parameter_info = list(source.filter(
38 | module=module,
39 | module_name=module_name,
40 | parameter_name=parameter_name))
41 | elif isinstance(source, torch.nn.Module): # source is a model
42 | self._filtered_parameter_info = list(filtered_parameter_info(source,
43 | module_fn=module,
44 | module_name_fn=module_name,
45 | parameter_name_fn=parameter_name))
46 |
47 | def named_parameters(self):
48 | for p in self._filtered_parameter_info:
49 | yield p['named_parameter']
50 |
51 | def parameters(self):
52 | for _, p in self.named_parameters():
53 | yield p
54 |
55 | def filter(self, module=None, module_name=None, parameter_name=None):
56 | for p_info in self._filtered_parameter_info:
57 | if ((module is None or module(p_info['named_module'][1]))
58 | and (module_name is None or module_name(p_info['named_module'][0]))
59 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))):
60 | yield p_info
61 |
62 | def named_modules(self):
63 | for m in self._filtered_parameter_info:
64 | yield m['named_module']
65 |
66 | def modules(self):
67 | for _, m in self.named_modules():
68 | yield m
69 |
70 | def to(self, *kargs, **kwargs):
71 | for m in self.modules():
72 | m.to(*kargs, **kwargs)
73 |
74 |
75 | class FilterModules(FilterParameters):
76 | pass
77 |
78 | if __name__ == '__main__':
79 | from torchvision.models import resnet50
80 | model = resnet50()
81 | filtered_params = FilterParameters(model,
82 | module=lambda m: isinstance(
83 | m, torch.nn.Linear),
84 | parameter_name=lambda n: 'bias' in n)
85 |
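
The typical use of these filters is to exclude BatchNorm and bias parameters
from weight decay. A sketch (only the decayed subset is optimized here; a
full setup would add the remaining parameters in a second, no-decay group):

    import torch
    from torchvision.models import resnet18
    from vision.utils.param_filter import FilterParameters, is_not_bn, is_not_bias

    model = resnet18()
    decayed = FilterParameters(model, module=is_not_bn, parameter_name=is_not_bias)
    optimizer = torch.optim.SGD(decayed.parameters(), lr=0.1, weight_decay=1e-4)
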
--------------------------------------------------------------------------------
/vision/utils/regime.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from six import string_types
3 |
4 |
5 | def eval_func(f, x):
6 | if isinstance(f, string_types):
7 | f = eval(f)
8 | return f(x)
9 |
10 |
11 | class Regime(object):
12 | """
13 | Examples for regime:
14 |
15 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
16 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
17 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4},
18 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5}
19 | ]"
20 | 2)
21 | "[{'step_lambda':
22 | "lambda t: {
23 | 'optimizer': 'Adam',
24 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5),
25 | 'betas': (0.9, 0.98), 'eps':1e-9}
26 | }]"
27 | """
28 |
29 | def __init__(self, regime, defaults=None):
30 | self.regime = regime
31 | self.current_regime_phase = None
32 | self.setting = defaults or {}
33 |
34 | def update(self, epoch=None, train_steps=None):
35 | """adjusts according to current epoch or steps and regime.
36 | """
37 | if self.regime is None:
38 | return False
39 | epoch = -1 if epoch is None else epoch
40 | train_steps = -1 if train_steps is None else train_steps
41 | setting = deepcopy(self.setting)
42 | if self.current_regime_phase is None:
43 | # Find the first phase whose start epoch/step has been reached
44 | for regime_phase, regime_setting in enumerate(self.regime):
45 | start_epoch = regime_setting.get('epoch', 0)
46 | start_step = regime_setting.get('step', 0)
47 | if epoch >= start_epoch or train_steps >= start_step:
48 | self.current_regime_phase = regime_phase
49 | break
50 | # each entry is updated from previous
51 | setting.update(regime_setting)
52 | if len(self.regime) > self.current_regime_phase + 1:
53 | next_phase = self.current_regime_phase + 1
54 | # Any more regime steps?
55 | start_epoch = self.regime[next_phase].get('epoch', float('inf'))
56 | start_step = self.regime[next_phase].get('step', float('inf'))
57 | if epoch >= start_epoch or train_steps >= start_step:
58 | self.current_regime_phase = next_phase
59 | setting.update(self.regime[self.current_regime_phase])
60 |
61 | if 'lr_decay_rate' in setting and 'lr' in setting:
62 | decay_steps = setting.pop('lr_decay_steps', 100)
63 | if train_steps % decay_steps == 0:
64 | decay_rate = setting.pop('lr_decay_rate')
65 | setting['lr'] *= decay_rate ** (train_steps / decay_steps)
66 | elif 'step_lambda' in setting:
67 | setting.update(eval_func(setting.pop('step_lambda'), train_steps))
68 | elif 'epoch_lambda' in setting:
69 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch))
70 |
71 | if 'execute' in setting:
72 | setting.pop('execute')()
73 |
74 | if 'execute_once' in setting:
75 | setting.pop('execute_once')()
76 | # remove from regime, so won't happen again
77 | self.regime[self.current_regime_phase].pop('execute_once', None)
78 |
79 | if setting == self.setting:
80 | return False
81 | else:
82 | self.setting = setting
83 | return True
84 |
85 | def __repr__(self):
86 | return 'Current: %s\n Regime:%s' % (self.setting, self.regime)
87 |
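
A sketch of stepping a Regime by epoch — the schedule is illustrative:

    from vision.utils.regime import Regime

    regime = Regime([{'epoch': 0, 'optimizer': 'SGD', 'lr': 0.1},
                     {'epoch': 30, 'lr': 0.01},
                     {'epoch': 60, 'lr': 0.001}])
    for epoch in range(90):
        if regime.update(epoch=epoch):     # True only when the setting changes
            print(epoch, regime.setting)   # fires at epochs 0, 30 and 60
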
--------------------------------------------------------------------------------