├── models
│   ├── __init__.py
│   ├── densenet.py
│   └── wideresnet.py
├── LICENSE
├── README.md
├── train.py
├── prune.py
└── funcs.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .wideresnet import *
2 | from .densenet import *
3 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 BayesWatch
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Closer Look at Structured Pruning for Neural Network Compression
2 | 
3 | Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.
4 | 
5 | To prune, we fill our networks with custom `MaskBlocks`, which are manipulated using `Pruner` in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.
6 | ## Setup
7 | This is best done in a clean conda environment:
8 | 
9 | ```
10 | conda create -n prunes python=3.6
11 | conda activate prunes
12 | conda install pytorch torchvision -c pytorch
13 | ```
14 | 
15 | ## Repository layout
16 | - `train.py`: contains all of the code for training large models from scratch and for training pruned models from scratch
17 | - `prune.py`: contains the code for pruning trained models
18 | - `funcs.py`: contains useful pruning functions and any functions we used commonly
19 | 
20 | ## CIFAR Experiments
21 | First, you will need some initial models.
22 | 
23 | To train a WRN-40-2:
24 | ```
25 | python train.py --net='res' --depth=40 --width=2.0 --data_loc=<path_to_data> --save_file='res'
26 | ```
27 | 
28 | The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:
29 | 
30 | ```
31 | python train.py --net='dense' --depth=100 --data_loc=<path_to_data> --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1
32 | ```
33 | 
34 | These will automatically save checkpoints to the `checkpoints` folder.
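Each checkpoint is a plain dictionary saved with `torch.save`, holding the epoch count, the model `state_dict`, and the validation error history. A minimal sketch of inspecting one (assuming the WRN-40-2 above was saved with `--save_file='res'`):

```
import torch

ckpt = torch.load('checkpoints/res.t7', map_location='cpu')
print(ckpt['epoch'])          # number of epochs completed
print(ckpt['error_history'])  # top-1 validation error after each epoch
```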
35 | 
36 | 
37 | 
38 | ### Pruning
39 | Once training is finished, we can prune our networks using prune.py (defaults are set to WRN pruning, so extra arguments are needed for DenseNets):
40 | ```
41 | python prune.py --net='res' --data_loc=<path_to_data> --base_model='res' --save_file='res_fisher'
42 | python prune.py --net='res' --data_loc=<path_to_data> --l1_prune=True --base_model='res' --save_file='res_l1'
43 | 
44 | python prune.py --net='dense' --depth 100 --data_loc=<path_to_data> --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
45 | python prune.py --net='dense' --depth 100 --data_loc=<path_to_data> --l1_prune=True --base_model='dense' --save_file='dense_l1' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
46 | 
47 | ```
48 | Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
49 | Once finished, we can train the pruned models from scratch, e.g.:
50 | ```
51 | python train.py --data_loc=<path_to_data> --net='res' --base_file='res_fisher__prunes' --deploy --mask=1 --save_file='res_fisher__prunes_scratch'
52 | ```
53 | 
54 | Each model can then be evaluated using:
55 | ```
56 | python train.py --deploy --eval --data_loc=<path_to_data> --net='res' --mask=1 --base_file='res_fisher__prunes'
57 | ```
58 | 
59 | 
60 | ### Training Reduced models
61 | 
62 | This can be done by varying the input arguments to train.py. To reduce the depth or width of a WRN, change the corresponding option:
63 | ```
64 | python train.py --net='res' --depth=<depth> --width=<width> --data_loc=<path_to_data> --save_file='res_reduced'
65 | ```
66 | 
67 | To add bottlenecks, use the following:
68 | 
69 | ```
70 | python train.py --net='res' --depth=40 --width=2.0 --data_loc=<path_to_data> --save_file='res_bottle' --bottle --bottle_mult <mult>
71 | ```
72 | 
73 | With DenseNets you can modify the `depth` or `growth`, or use `--bottle --bottle_mult <mult>` as above.
74 | 
75 | 
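When sizing a reduced model against a pruned one, the helpers in funcs.py give parameter and FLOP counts. A rough sketch, assuming a reduced WRN-16-1 (the printed totals will differ per model):

```
from models import WideResNet
from funcs import get_inf_params, measure_model

model = WideResNet(depth=16, widen_factor=1.0)
print(get_inf_params(model, verbose=False))   # total parameter count
ops, params = measure_model(model, 32, 32)    # FLOP estimate at CIFAR resolution
print(ops, params)
```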
76 | ### Acknowledgements
77 | 
78 | [Jack Turner][jack] wrote the L1 stuff, and some other stuff for that matter.
79 | 
80 | Code has been liberally borrowed from many a repo, including, but not limited to:
81 | 
82 | ```
83 | https://github.com/xternalz/WideResNet-pytorch
84 | https://github.com/bamos/densenet.pytorch
85 | https://github.com/kuangliu/pytorch-cifar
86 | https://github.com/ShichenLiu/CondenseNet
87 | ```
88 | ### Citing this work
89 | 
90 | If you would like to cite this work, please use the following bibtex entry:
91 | 
92 | ```
93 | @article{crowley2018pruning,
94 |   title={A Closer Look at Structured Pruning for Neural Network Compression},
95 |   author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
96 |   journal={arXiv preprint arXiv:1810.04622},
97 |   year={2018},
98 | }
99 | ```
100 | [jack]: https://github.com/jack-willturner
101 | 
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | 
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | 
9 | import torchvision.datasets as dset
10 | import torchvision.transforms as transforms
11 | from torch.utils.data import DataLoader
12 | 
13 | import torchvision.models as models
14 | 
15 | import sys
16 | import math
17 | 
18 | 
19 | class Identity(nn.Module):
20 |     def __init__(self):
21 |         super(Identity, self).__init__()
22 | 
23 |     def forward(self, x):
24 |         return x
25 | 
26 | 
27 | class Zero(nn.Module):
28 |     def __init__(self):
29 |         super(Zero, self).__init__()
30 | 
31 |     def forward(self, x):
32 |         return x * 0
33 | 
34 | 
35 | class ZeroMake(nn.Module):
36 |     def __init__(self, channels, spatial):
37 |         super(ZeroMake, self).__init__()
38 |         self.spatial = spatial
39 |         self.channels = channels
40 | 
41 |     def forward(self, x):
42 |         return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],
43 |                            dtype=x.dtype, layout=x.layout, device=x.device)
44 | 
45 | 
46 | class MaskBlock(nn.Module):
47 |     def __init__(self, nChannels, growthRate):
48 |         super(MaskBlock, self).__init__()
49 |         interChannels = 4 * growthRate
50 |         self.bn1 = nn.BatchNorm2d(nChannels)
51 |         self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
52 |                                bias=False)
53 |         self.bn2 = nn.BatchNorm2d(interChannels)
54 |         self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
55 |                                padding=1, bias=False)
56 | 
57 |         self.activation = Identity()
58 |         self.activation.register_backward_hook(self._fisher)
59 |         self.register_buffer('mask', None)
60 | 
61 |         self.input_shape = None
62 |         self.output_shape = None
63 |         self.flops = None
64 |         self.params = None
65 |         self.in_channels = nChannels
66 |         self.out_channels = growthRate
67 |         self.stride = 1
68 | 
69 |         # Fisher method is called on backward passes
70 |         self.running_fisher = 0
71 | 
72 |     def forward(self, x):
73 |         out = self.conv1(F.relu(self.bn1(x)))
74 |         out = F.relu(self.bn2(out))
75 |         if self.mask is not None:
76 |             out = out * self.mask[None, :, None, None]
77 |         else:
78 |             self._create_mask(x, out)
79 |         out = self.activation(out)
80 |         self.act = out
81 | 
82 |         out = self.conv2(out)
83 |         out = torch.cat([x, out], 1)
84 |         return out
85 | 
86 |     def _create_mask(self, x, out):
87 |         """This takes an activation to generate the exact mask required. It also records input and output shapes
88 |         for posterity."""
89 |         self.mask = x.new_ones(out.shape[1])
90 |         self.input_shape = x.size()
91 |         self.output_shape = out.size()
92 | 
93 |     def _fisher(self, _, __, grad_output):
94 |         act = self.act.detach()
95 |         grad = grad_output[0].detach()
96 |         # Fisher importance per channel k: delta_k = 0.5 * E_n[(sum_ij act * grad)^2]
97 |         g_nk = (act * grad).sum(-1).sum(-1)
98 |         del_k = g_nk.pow(2).mean(0).mul(0.5)
99 |         self.running_fisher += del_k
100 | 
101 |     def reset_fisher(self):
102 |         self.running_fisher = 0 * self.running_fisher
103 | 
104 |     def update(self, previous_mask):
105 |         # This is only required for non-modular nets.
106 |         return None
107 | 
108 |     def cost(self):
109 | 
110 |         in_channels = self.in_channels
111 |         out_channels = self.out_channels
112 |         middle_channels = int(self.mask.sum().item())
113 | 
114 |         conv1_size = self.conv1.weight.size()
115 |         conv2_size = self.conv2.weight.size()
116 |         # conv params, counting only the unmasked middle channels
117 |         self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
118 |                       conv2_size[2] * conv2_size[3]
119 |         # batchnorm affine params
120 |         self.params += 2 * in_channels + 2 * middle_channels
121 | 
122 | 
123 |     def compress_weights(self):
124 |         middle_dim = int(self.mask.sum().item())
125 | 
126 |         if middle_dim != 0:
127 |             conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=1, stride=1, bias=False)  # conv1 is 1x1 in this block
128 |             conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])
129 | 
130 |             # Batch norm 2 changes
131 |             bn2 = nn.BatchNorm2d(middle_dim)
132 |             bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
133 |             bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
134 |             bn2.running_mean = self.bn2.running_mean[self.mask == 1]
135 |             bn2.running_var = self.bn2.running_var[self.mask == 1]
136 | 
137 |             conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
138 |             conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])
139 | 
140 |         if middle_dim == 0:
141 |             conv1 = Zero()
142 |             bn2 = Zero()
143 |             conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)
144 | 
145 |         self.conv1 = conv1
146 |         self.conv2 = conv2
147 |         self.bn2 = bn2
148 | 
149 |         if middle_dim != 0:
150 |             self.mask = torch.ones(middle_dim)
151 |         else:
152 |             self.mask = torch.ones(1)
153 | 
154 | 
155 | class Bottleneck(nn.Module):
156 |     def __init__(self, nChannels, growthRate, width=1):
157 |         super(Bottleneck, self).__init__()
158 |         interChannels = int(4 * growthRate * width)
159 |         self.bn1 = nn.BatchNorm2d(nChannels)
160 |         self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
161 |                                bias=False)
162 |         self.bn2 = nn.BatchNorm2d(interChannels)
163 |         self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
164 |                                padding=1, bias=False)
165 | 
166 |     def forward(self, x):
167 |         out = self.conv1(F.relu(self.bn1(x)))
168 |         out = self.conv2(F.relu(self.bn2(out)))
169 |         out = torch.cat((x, out), 1)
170 |         return out
171 | 
172 | 
173 | class SingleLayer(nn.Module):
174 |     def __init__(self, nChannels, growthRate):
175 |         super(SingleLayer, self).__init__()
176 |         self.bn1 = nn.BatchNorm2d(nChannels)
177 |         self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
178 |                                padding=1, bias=False)
179 | 
180 |     def forward(self, x):
181 |         out = self.conv1(F.relu(self.bn1(x)))
182 |         out = torch.cat((x, out), 1)
183 |         return out
184 | 
185 | 
186 | class Transition(nn.Module):
187 |     def __init__(self, nChannels, nOutChannels):
188 |         super(Transition, self).__init__()
189 |         self.bn1 = nn.BatchNorm2d(nChannels)
190 |         self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
191 | bias=False) 192 | 193 | def forward(self, x): 194 | out = self.conv1(F.relu(self.bn1(x))) 195 | out = F.avg_pool2d(out, 2) 196 | return out 197 | 198 | 199 | class DenseNet(nn.Module): 200 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck, mask=False, width=1.): 201 | super(DenseNet, self).__init__() 202 | 203 | nDenseBlocks = (depth - 4) // 3 204 | if bottleneck: 205 | nDenseBlocks //= 2 206 | 207 | nChannels = 2 * growthRate 208 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, 209 | bias=False) 210 | 211 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width) 212 | nChannels += nDenseBlocks * growthRate 213 | nOutChannels = int(math.floor(nChannels * reduction)) 214 | self.trans1 = Transition(nChannels, nOutChannels) 215 | 216 | nChannels = nOutChannels 217 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width) 218 | nChannels += nDenseBlocks * growthRate 219 | nOutChannels = int(math.floor(nChannels * reduction)) 220 | self.trans2 = Transition(nChannels, nOutChannels) 221 | 222 | nChannels = nOutChannels 223 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width) 224 | nChannels += nDenseBlocks * growthRate 225 | 226 | self.bn1 = nn.BatchNorm2d(nChannels) 227 | self.fc = nn.Linear(nChannels, nClasses) 228 | 229 | # Count params that don't exist in blocks (conv1, bn1, fc, trans1, trans2, trans3) 230 | self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \ 231 | len(self.fc.weight.view(-1)) + len(self.fc.bias) 232 | self.fixed_params += len(self.trans1.conv1.weight.view(-1)) + 2 * len(self.trans1.bn1.weight) 233 | self.fixed_params += len(self.trans2.conv1.weight.view(-1)) + 2 * len(self.trans2.bn1.weight) 234 | 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 239 | elif isinstance(m, nn.BatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.Linear): 243 | m.bias.data.zero_() 244 | 245 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, mask=False, width=1): 246 | layers = [] 247 | for i in range(int(nDenseBlocks)): 248 | if bottleneck and mask: 249 | layers.append(MaskBlock(nChannels, growthRate)) 250 | elif bottleneck: 251 | layers.append(Bottleneck(nChannels, growthRate, width)) 252 | else: 253 | layers.append(SingleLayer(nChannels, growthRate)) 254 | nChannels += growthRate 255 | return nn.Sequential(*layers) 256 | 257 | def forward(self, x): 258 | out = self.conv1(x) 259 | out = self.trans1(self.dense1(out)) 260 | out = self.trans2(self.dense2(out)) 261 | out = self.dense3(out) 262 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 263 | out = self.fc(out) 264 | return out 265 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """This script just trains models from scratch, to later be pruned""" 2 | 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | from models import * 11 | 12 | from funcs import * 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Pruning') 16 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers') 17 | parser.add_argument('--GPU', default='0', type=str, help='GPU to use') 18 | parser.add_argument('--save_file', default='saveto', type=str, help='save file for checkpoints') 19 | parser.add_argument('--base_file', default='bbb', type=str, help='base file for checkpoints') 20 | parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') 21 | parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar') 22 | 23 | # Learning specific arguments 24 | parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 256)') 25 | parser.add_argument('-lr', '--learning_rate', default=.1, type=float, metavar='LR', help='initial learning rate') 26 | parser.add_argument('-epochs', '--no_epochs', default=200, type=int, metavar='epochs', help='no. 
epochs') 27 | parser.add_argument('--epoch_step', default='[60,120,160]', type=str, help='json list with epochs to drop lr on') 28 | parser.add_argument('--lr_decay_ratio', default=0.2, type=float, help='learning rate decay factor') 29 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 30 | parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay') 31 | parser.add_argument('--eval', '-e', action='store_true', help='resume from checkpoint') 32 | parser.add_argument('--mask', '-m', type=int, help='mask mode', default=0) 33 | parser.add_argument('--deploy', '-de', action='store_true', help='prune and deploy model') 34 | parser.add_argument('--params_left', '-pl', default=0, type=int, help='prune til...') 35 | parser.add_argument('--net', choices=['res', 'dense'], default='res') 36 | 37 | # Net specific 38 | parser.add_argument('--depth', '-d', default=40, type=int, metavar='D', help='depth of wideresnet/densenet') 39 | parser.add_argument('--width', '-w', default=2.0, type=float, metavar='W', help='width of wideresnet') 40 | parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet') 41 | parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet') 42 | 43 | 44 | # Uniform bottlenecks 45 | parser.add_argument('--bottle', action='store_true', help='Linearly scale bottlenecks') 46 | parser.add_argument('--bottle_mult', default=0.5, type=float, help='bottleneck multiplier') 47 | 48 | 49 | if not os.path.exists('checkpoints/'): 50 | os.makedirs('checkpoints/') 51 | 52 | args = parser.parse_args() 53 | print(args) 54 | os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU 55 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 56 | 57 | if args.net == 'res': 58 | if not args.bottle: 59 | model = WideResNet(args.depth, args.width, mask=args.mask) 60 | else: 61 | model = WideResNetBottle(args.depth, args.width, bottle_mult=args.bottle_mult) 62 | elif args.net == 'dense': 63 | if not args.bottle: 64 | model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask) 65 | else: 66 | model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, width=args.bottle_mult) 67 | 68 | else: 69 | raise ValueError('pick a valid net') 70 | 71 | pruner = Pruner() 72 | 73 | if args.deploy: 74 | # Feed example to activate masks 75 | model(torch.rand(1, 3, 32, 32)) 76 | SD = torch.load('checkpoints/%s.t7' % args.base_file) 77 | 78 | if not args.eval: 79 | 80 | pruner = Pruner() 81 | pruner._get_masks(model) 82 | 83 | for ii in SD['prune_history']: 84 | pruner.fixed_prune(model, ii) 85 | 86 | else: 87 | model.load_state_dict(SD['state_dict']) 88 | 89 | pruner.compress(model) 90 | 91 | get_inf_params(model) 92 | time.sleep(1) 93 | model.to(device) 94 | 95 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 96 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 97 | 98 | print('==> Preparing data..') 99 | num_classes = 10 100 | 101 | transform_train = transforms.Compose([ 102 | transforms.ToTensor(), 103 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), 104 | (4, 4, 4, 4), mode='reflect').squeeze()), 105 | transforms.ToPILImage(), 106 | transforms.RandomCrop(32), 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor(), 109 | normalize, 110 | ]) 111 | 112 | transform_val = transforms.Compose([ 113 | transforms.ToTensor(), 114 | normalize, 115 | 116 | ]) 117 | 118 | trainset = 
torchvision.datasets.CIFAR10(root=args.data_loc, 119 | train=True, download=True, transform=transform_train) 120 | valset = torchvision.datasets.CIFAR10(root=args.data_loc, 121 | train=False, download=True, transform=transform_val) 122 | 123 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 124 | num_workers=args.workers, 125 | pin_memory=False) 126 | valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False, 127 | num_workers=args.workers, 128 | pin_memory=False) 129 | 130 | error_history = [] 131 | epoch_step = json.loads(args.epoch_step) 132 | 133 | 134 | def train(): 135 | batch_time = AverageMeter() 136 | data_time = AverageMeter() 137 | losses = AverageMeter() 138 | top1 = AverageMeter() 139 | top5 = AverageMeter() 140 | 141 | # switch to train mode 142 | model.train() 143 | 144 | end = time.time() 145 | 146 | for i, (input, target) in enumerate(trainloader): 147 | 148 | # measure data loading time 149 | data_time.update(time.time() - end) 150 | 151 | input, target = input.to(device), target.to(device) 152 | 153 | # compute output 154 | output = model(input) 155 | 156 | loss = criterion(output, target) 157 | 158 | # measure accuracy and record loss 159 | err1, err5 = get_error(output.detach(), target, topk=(1, 5)) 160 | 161 | losses.update(loss.item(), input.size(0)) 162 | top1.update(err1.item(), input.size(0)) 163 | top5.update(err5.item(), input.size(0)) 164 | 165 | # compute gradient and do SGD step 166 | optimizer.zero_grad() 167 | loss.backward() 168 | optimizer.step() 169 | 170 | # measure elapsed time 171 | batch_time.update(time.time() - end) 172 | end = time.time() 173 | 174 | if i % args.print_freq == 0: 175 | print('Epoch: [{0}][{1}/{2}]\t' 176 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 177 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 178 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 179 | 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 180 | 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 181 | epoch, i, len(trainloader), batch_time=batch_time, 182 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 183 | 184 | 185 | 186 | 187 | def validate(): 188 | global error_history 189 | 190 | batch_time = AverageMeter() 191 | data_time = AverageMeter() 192 | 193 | losses = AverageMeter() 194 | top1 = AverageMeter() 195 | top5 = AverageMeter() 196 | 197 | # switch to evaluate mode 198 | model.eval() 199 | 200 | end = time.time() 201 | 202 | for i, (input, target) in enumerate(valloader): 203 | 204 | # measure data loading time 205 | data_time.update(time.time() - end) 206 | 207 | input, target = input.to(device), target.to(device) 208 | 209 | # compute output 210 | output = model(input) 211 | 212 | loss = criterion(output, target) 213 | 214 | # measure accuracy and record loss 215 | err1, err5 = get_error(output.detach(), target, topk=(1, 5)) 216 | 217 | losses.update(loss.item(), input.size(0)) 218 | top1.update(err1.item(), input.size(0)) 219 | top5.update(err5.item(), input.size(0)) 220 | 221 | # measure elapsed time 222 | batch_time.update(time.time() - end) 223 | end = time.time() 224 | 225 | if i % args.print_freq == 0: 226 | print('Test: [{0}/{1}]\t' 227 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 228 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 229 | 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 230 | 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 231 | i, len(valloader), batch_time=batch_time, loss=losses, 232 | top1=top1, top5=top5)) 233 | 234 | print(' * Error@1 
{top1.avg:.3f} Error@5 {top5.avg:.3f}' 235 | .format(top1=top1, top5=top5)) 236 | 237 | 238 | # Record Top 1 for CIFAR 239 | error_history.append(top1.avg) 240 | 241 | 242 | if __name__ == '__main__': 243 | 244 | filename = 'checkpoints/%s.t7' % args.save_file 245 | criterion = nn.CrossEntropyLoss() 246 | optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad], 247 | lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 248 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio) 249 | 250 | if not args.eval: 251 | 252 | for epoch in range(args.no_epochs): 253 | 254 | print('Epoch %d:' % epoch) 255 | print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0]) 256 | # train for one epoch 257 | train() 258 | scheduler.step() 259 | # # evaluate on validation set 260 | validate() 261 | 262 | save_checkpoint({ 263 | 'epoch': epoch + 1, 264 | 'state_dict': model.state_dict(), 265 | 'error_history': error_history, 266 | }, filename=filename) 267 | 268 | else: 269 | if not args.deploy: 270 | model.load_state_dict(torch.load(filename)['state_dict']) 271 | epoch = 0 272 | validate() 273 | -------------------------------------------------------------------------------- /prune.py: -------------------------------------------------------------------------------- 1 | """Pruning script""" 2 | 3 | import argparse 4 | import os 5 | 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | from funcs import * 9 | from models import * 10 | 11 | 12 | parser = argparse.ArgumentParser(description='Pruning') 13 | parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers') 14 | parser.add_argument('--GPU', default='0', type=str, help='GPU to use') 15 | parser.add_argument('--save_file', default='wrn16_2_p', type=str, help='save file for checkpoints') 16 | parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') 17 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 18 | parser.add_argument('--resume_ckpt', default='checkpoint', type=str, 19 | help='save file for resumed checkpoint') 20 | parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='where is the dataset') 21 | 22 | # Learning specific arguments 23 | parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='sgd', type=str, help='optimizer') 24 | parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 256)') 25 | parser.add_argument('-lr', '--learning_rate', default=8e-4, type=float, metavar='LR', help='initial learning rate') 26 | parser.add_argument('-epochs', '--no_epochs', default=1300, type=int, metavar='epochs', help='no. 
epochs')
27 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
28 | parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')
29 | parser.add_argument('--prune_every', default=100, type=int, help='prune every X steps')
30 | parser.add_argument('--save_every', default=100, type=int, help='save model every X EPOCHS')
31 | parser.add_argument('--random', default=False, type=lambda s: s.lower() in ('true', '1'), help='Prune at random')  # plain type=bool would parse "False" as True
32 | parser.add_argument('--base_model', default='base_model', type=str, help='basemodel')
33 | parser.add_argument('--val_every', default=1, type=int, help='val model every X EPOCHS')
34 | parser.add_argument('--mask', default=1, type=int, help='Mask type')
35 | parser.add_argument('--l1_prune', default=False, type=lambda s: s.lower() in ('true', '1'), help='Prune via l1 norm')
36 | parser.add_argument('--net', default='dense', type=str, help='dense, res')
37 | parser.add_argument('--width', default=2.0, type=float, metavar='W', help='width of wideresnet')
38 | parser.add_argument('--depth', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')
39 | parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')
40 | parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')
41 | 
42 | args = parser.parse_args()
43 | print(args)
44 | os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
45 | 
46 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
47 | 
48 | 
49 | if args.net == 'res':
50 |     model = WideResNet(args.depth, args.width, mask=args.mask)
51 | elif args.net == 'dense':
52 |     model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)
53 | 
54 | model.load_state_dict(torch.load('checkpoints/%s.t7' % args.base_model, map_location='cpu')['state_dict'], strict=True)
55 | 
56 | if args.resume:
57 |     state = torch.load('checkpoints/%s.t7' % args.resume_ckpt, map_location='cpu')
58 | 
59 |     model = resume_from(state, model_type=args.net)
60 |     error_history = state['error_history']
61 |     prune_history = state['prune_history']
62 |     flop_history = state.get('flop_history', [])  # not every checkpoint stores this
63 |     param_history = state['param_history']
64 |     start_epoch = state['epoch']
65 | 
66 | else:
67 | 
68 |     error_history = []
69 |     prune_history = []
70 |     param_history = []
71 |     start_epoch = 0
72 | 
73 | model.to(device)
74 | 
75 | normMean = [0.49139968, 0.48215827, 0.44653124]
76 | normStd = [0.24703233, 0.24348505, 0.26158768]
77 | normTransform = transforms.Normalize(normMean, normStd)
78 | 
79 | print('==> Preparing data..')
80 | num_classes = 10
81 | 
82 | transform_train = transforms.Compose([
83 |     transforms.RandomCrop(32, padding=4),
84 |     transforms.RandomHorizontalFlip(),
85 |     transforms.ToTensor(),
86 |     normTransform
87 | ])
88 | 
89 | transform_val = transforms.Compose([
90 |     transforms.ToTensor(),
91 |     normTransform
92 | 
93 | ])
94 | 
95 | trainset = torchvision.datasets.CIFAR10(root=args.data_loc,
96 |                                         train=True, download=True, transform=transform_train)
97 | valset = torchvision.datasets.CIFAR10(root=args.data_loc,
98 |                                       train=False, download=True, transform=transform_val)
99 | 
100 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
101 |                                           num_workers=args.workers,
102 |                                           pin_memory=False)
103 | valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False,
104 |                                         num_workers=args.workers,
105 |                                         pin_memory=False)
106 | 
107 | prune_count = 0
108 | pruner = Pruner()
109 | pruner.prune_history = prune_history
110 | 
111 | NO_STEPS = args.prune_every
112 | 
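# A note on the loop below: each "epoch" is really one pruning round of
# NO_STEPS minibatches of fine-tuning. During those steps every MaskBlock
# accumulates running_fisher through its backward hook; prune() then turns
# off the mask entry of the single channel with the lowest estimated
# importance (or the smallest l1 norm, or a random live channel). One channel
# goes per round, so --no_epochs sets how many channels are removed in total.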
113 | 114 | def finetune(): 115 | batch_time = AverageMeter() 116 | data_time = AverageMeter() 117 | losses = AverageMeter() 118 | top1 = AverageMeter() 119 | top5 = AverageMeter() 120 | 121 | # switch to train mode 122 | model.train() 123 | 124 | end = time.time() 125 | 126 | dataiter = iter(trainloader) 127 | 128 | for i in range(0, NO_STEPS): 129 | 130 | try: 131 | input, target = dataiter.next() 132 | except StopIteration: 133 | dataiter = iter(trainloader) 134 | input, target = dataiter.next() 135 | 136 | # measure data loading time 137 | data_time.update(time.time() - end) 138 | 139 | input, target = input.to(device), target.to(device) 140 | 141 | # compute output 142 | output = model(input) 143 | 144 | loss = criterion(output, target) 145 | 146 | # measure accuracy and record loss 147 | err1, err5 = get_error(output.detach(), target, topk=(1, 5)) 148 | 149 | losses.update(loss.item(), input.size(0)) 150 | top1.update(err1.item(), input.size(0)) 151 | top5.update(err5.item(), input.size(0)) 152 | 153 | # compute gradient and do SGD step 154 | optimizer.zero_grad() 155 | loss.backward() 156 | optimizer.step() 157 | 158 | # measure elapsed time 159 | batch_time.update(time.time() - end) 160 | end = time.time() 161 | 162 | if i % args.print_freq == 0: 163 | print('Prunepoch: [{0}][{1}/{2}]\t' 164 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 165 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 166 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 167 | 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 168 | 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 169 | epoch, i, NO_STEPS, batch_time=batch_time, 170 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 171 | 172 | 173 | 174 | 175 | def prune(): 176 | print('Pruning') 177 | if args.random is False: 178 | if args.l1_prune is False: 179 | print('fisher pruning') 180 | pruner.fisher_prune(model, prune_every=args.prune_every) 181 | else: 182 | print('l1 pruning') 183 | pruner.l1_prune(model, prune_every=args.prune_every) 184 | else: 185 | print('random pruning') 186 | pruner.random_prune(model, ) 187 | 188 | 189 | def validate(): 190 | global error_history 191 | 192 | batch_time = AverageMeter() 193 | data_time = AverageMeter() 194 | 195 | losses = AverageMeter() 196 | top1 = AverageMeter() 197 | top5 = AverageMeter() 198 | 199 | # switch to evaluate mode 200 | model.eval() 201 | 202 | end = time.time() 203 | 204 | for i, (input, target) in enumerate(valloader): 205 | 206 | # measure data loading time 207 | data_time.update(time.time() - end) 208 | 209 | input, target = input.to(device), target.to(device) 210 | 211 | # compute output 212 | output = model(input) 213 | 214 | loss = criterion(output, target) 215 | 216 | # measure accuracy and record loss 217 | err1, err5 = get_error(output.detach(), target, topk=(1, 5)) 218 | 219 | losses.update(loss.item(), input.size(0)) 220 | top1.update(err1.item(), input.size(0)) 221 | top5.update(err5.item(), input.size(0)) 222 | 223 | # measure elapsed time 224 | batch_time.update(time.time() - end) 225 | end = time.time() 226 | 227 | if i % args.print_freq == 0: 228 | print('Test: [{0}/{1}]\t' 229 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 230 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 231 | 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 232 | 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 233 | i, len(valloader), batch_time=batch_time, loss=losses, 234 | top1=top1, top5=top5)) 235 | 236 | print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}' 237 | .format(top1=top1, 
top5=top5))
238 | 
239 | 
240 | 
241 |     # Record Top 1 for CIFAR
242 |     error_history.append(top1.avg)
243 | 
244 | 
245 | if __name__ == '__main__':
246 | 
247 |     criterion = nn.CrossEntropyLoss()
248 | 
249 |     optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
250 |                                 lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
251 | 
252 |     for epoch in range(start_epoch, args.no_epochs):
253 | 
254 |         print('Epoch %d:' % epoch)
255 |         print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])
256 | 
257 |         # finetune for one epoch
258 |         finetune()
259 |         # evaluate on validation set (also at the last epoch!)
260 |         if epoch != 0 and ((epoch % args.val_every == 0) or (epoch + 1 == args.no_epochs)):
261 |             validate()
262 | 
263 |             # Error history is recorded in validate(). Record params here
264 |             no_params = pruner.get_cost(model) + model.fixed_params
265 |             param_history.append(no_params)
266 | 
267 |         # Save before pruning
268 |         if epoch != 0 and ((epoch % args.save_every == 0) or (epoch + 1 == args.no_epochs)):
269 |             filename = 'checkpoints/%s_%d_prunes.t7' % (args.save_file, epoch)
270 |             save_checkpoint({
271 |                 'epoch': epoch + 1,
272 |                 'state_dict': model.state_dict(),
273 |                 'error_history': error_history,
274 |                 'param_history': param_history,
275 |                 'prune_history': pruner.prune_history,
276 |             }, filename=filename)
277 | 
278 |         # Prune one channel
279 |         prune()
280 | 
281 | 
--------------------------------------------------------------------------------
/funcs.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import random
3 | import time
4 | from functools import reduce
5 | 
6 | import numpy as np
7 | 
8 | import torch
9 | import torchvision
10 | import torchvision.transforms as transforms
11 | from torch.autograd import Variable
12 | 
13 | from models import *
14 | 
15 | 
16 | 
17 | class Pruner:
18 |     def __init__(self, module_name='MaskBlock'):
19 |         # First get vector of masks
20 |         self.module_name = module_name
21 |         self.masks = []
22 |         self.prune_history = []
23 | 
24 |     def fisher_prune(self, model, prune_every):
25 | 
26 |         self._get_fisher(model)
27 |         tot_loss = self.fisher.div(prune_every) + 1e6 * (1 - self.masks)  # dummy value for off masks
28 |         # tot_loss has one entry per maskable channel in the network
29 |         _, argmin = torch.min(tot_loss, 0)
30 |         self.prune(model, argmin.item())
31 |         self.prune_history.append(argmin.item())
32 | 
33 |     def fixed_prune(self, model, ID):
34 |         self.prune(model, ID)
35 |         self.prune_history.append(ID)
36 | 
37 |     def random_prune(self, model):
38 | 
39 |         self._get_fisher(model)
40 |         # Calling _get_fisher also updates each block's cost.
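        # Collect the current masks from every MaskBlock: entries still set
        # to 1 are live channels, and one is switched off uniformly at random.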
41 |         masks = []
42 |         for m in model.modules():
43 |             if m._get_name() == self.module_name:
44 |                 masks.append(m.mask.detach())
45 | 
46 |         masks = self.concat(masks)
47 |         masks_on = [i for i, v in enumerate(masks) if v == 1]
48 |         random_pick = random.choice(masks_on)
49 |         self.prune(model, random_pick)
50 |         self.prune_history.append(random_pick)
51 | 
52 |     def l1_prune(self, model, prune_every):
53 |         masks = []
54 |         l1_norms = []
55 | 
56 |         for m in model.modules():
57 |             if m._get_name() == self.module_name:
58 |                 l1_norm = torch.sum(m.conv1.weight.abs(), (1, 2, 3)).detach().cpu().numpy()  # L1 norm of each conv1 filter
59 |                 masks.append(m.mask.detach())
60 |                 l1_norms.append(l1_norm)
61 | 
62 |         masks = self.concat(masks)
63 |         self.masks = masks
64 |         l1_norms = np.concatenate(l1_norms)
65 | 
66 |         l1_norms_on = []
67 |         for m, l in zip(masks, l1_norms):
68 |             if m == 1:
69 |                 l1_norms_on.append(l)
70 |             else:
71 |                 l1_norms_on.append(9999.)  # dummy value
72 | 
73 |         smallest_norm = min(l1_norms_on)
74 |         pick = np.where(l1_norms == smallest_norm)[0][0]
75 | 
76 |         self.prune(model, pick)
77 |         self.prune_history.append(pick)
78 | 
79 |     def prune(self, model, feat_index):
80 |         """feat_index refers to the index of a feature map. This function modifies the mask to turn it off."""
81 |         print('Pruned %d out of %d channels so far' % (len(self.prune_history), len(self.masks)))
82 |         if len(self.prune_history) > len(self.masks):
83 |             raise Exception('Time to stop')
84 |         safe = 0
85 |         running_index = 0
86 |         for m in model.modules():
87 |             if m._get_name() == self.module_name:
88 |                 mask_indices = range(running_index, running_index + len(m.mask))
89 | 
90 |                 if feat_index in mask_indices:
91 |                     print('Pruning channel %d' % feat_index)
92 |                     local_index = mask_indices.index(feat_index)
93 |                     m.mask[local_index] = 0
94 |                     safe = 1
95 |                     break
96 |                 else:
97 |                     running_index += len(m.mask)
98 | 
99 |         if not safe:
100 |             raise Exception('The provided index doesn''t correspond to any feature maps. 
This is bad.') 101 | 102 | def compress(self, model): 103 | for m in model.modules(): 104 | if m._get_name() == 'MaskBlock': 105 | m.compress_weights() 106 | 107 | def _get_fisher(self, model): 108 | masks = [] 109 | fisher = [] 110 | 111 | self._update_cost(model) 112 | 113 | for m in model.modules(): 114 | if m._get_name() == self.module_name: 115 | masks.append(m.mask.detach()) 116 | fisher.append(m.running_fisher.detach()) 117 | 118 | # Now clear the fisher cache 119 | m.reset_fisher() 120 | 121 | self.masks = self.concat(masks) 122 | self.fisher = self.concat(fisher) 123 | 124 | def _get_masks(self, model): 125 | masks = [] 126 | 127 | for m in model.modules(): 128 | if m._get_name() == self.module_name: 129 | masks.append(m.mask.detach()) 130 | 131 | self.masks = self.concat(masks) 132 | 133 | def _update_cost(self, model): 134 | for m in model.modules(): 135 | if m._get_name() == self.module_name: 136 | m.cost() 137 | 138 | def get_cost(self, model): 139 | params = 0 140 | for m in model.modules(): 141 | if m._get_name() == self.module_name: 142 | m.cost() 143 | params += m.params 144 | return params 145 | 146 | @staticmethod 147 | def concat(input): 148 | return torch.cat([item for item in input]) 149 | 150 | 151 | def find(input): 152 | # Find as in MATLAB to find indices in a binary vector 153 | return [i for i, j in enumerate(input) if j] 154 | 155 | 156 | def concat(input): 157 | return torch.cat([item for item in input]) 158 | 159 | 160 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 161 | torch.save(state, filename) 162 | 163 | 164 | def get_error(output, target, topk=(1,)): 165 | """Computes the error@k for the specified values of k""" 166 | maxk = max(topk) 167 | batch_size = target.size(0) 168 | 169 | _, pred = output.topk(maxk, 1, True, True) 170 | pred = pred.t() 171 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 172 | 173 | res = [] 174 | for k in topk: 175 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 176 | res.append(100.0 - correct_k.mul_(100.0 / batch_size)) 177 | return res 178 | 179 | 180 | class AverageMeter(object): 181 | """Computes and stores the average and current value""" 182 | 183 | def __init__(self): 184 | self.reset() 185 | 186 | def reset(self): 187 | self.val = 0 188 | self.avg = 0 189 | self.sum = 0 190 | self.count = 0 191 | 192 | def update(self, val, n=1): 193 | self.val = val 194 | self.sum += val * n 195 | self.count += n 196 | self.avg = self.sum / self.count 197 | 198 | 199 | def get_inf_params(net, verbose=True, sd=False): 200 | if sd: 201 | params = net 202 | else: 203 | params = net.state_dict() 204 | tot = 0 205 | conv_tot = 0 206 | for p in params: 207 | no = params[p].view(-1).__len__() 208 | 209 | if ('num_batches_tracked' not in p) and ('running' not in p) and ('mask' not in p): 210 | tot += no 211 | 212 | if verbose: 213 | print('%s has %d params' % (p, no)) 214 | if 'conv' in p: 215 | conv_tot += no 216 | 217 | if verbose: 218 | print('Net has %d conv params' % conv_tot) 219 | print('Net has %d params in total' % tot) 220 | 221 | return tot 222 | 223 | 224 | count_ops = 0 225 | count_params = 0 226 | 227 | 228 | def get_num_gen(gen): 229 | return sum(1 for x in gen) 230 | 231 | 232 | def is_pruned(layer): 233 | try: 234 | layer.mask 235 | return True 236 | except AttributeError: 237 | return False 238 | 239 | 240 | def is_leaf(model): 241 | return get_num_gen(model.children()) == 0 242 | 243 | 244 | def get_layer_info(layer): 245 | layer_str = str(layer) 246 | type_name = 
layer_str[:layer_str.find('(')].strip() 247 | return type_name 248 | 249 | 250 | def get_layer_param(model): 251 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 252 | 253 | 254 | ### The input batch size should be 1 to call this function 255 | def measure_layer(layer, x): 256 | global count_ops, count_params 257 | delta_ops = 0 258 | delta_params = 0 259 | multi_add = 1 260 | type_name = get_layer_info(layer) 261 | 262 | ### ops_conv 263 | if type_name in ['Conv2d']: 264 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 265 | layer.stride[0] + 1) 266 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 267 | layer.stride[1] + 1) 268 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 269 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 270 | delta_params = get_layer_param(layer) 271 | 272 | ### ops_learned_conv 273 | elif type_name in ['LearnedGroupConv']: 274 | measure_layer(layer.relu, x) 275 | measure_layer(layer.norm, x) 276 | conv = layer.conv 277 | out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / 278 | conv.stride[0] + 1) 279 | out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / 280 | conv.stride[1] + 1) 281 | delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \ 282 | conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add 283 | delta_params = get_layer_param(conv) / layer.condense_factor 284 | 285 | ### ops_nonlinearity 286 | elif type_name in ['ReLU']: 287 | delta_ops = x.numel() 288 | delta_params = get_layer_param(layer) 289 | 290 | ### ops_pooling 291 | elif type_name in ['AvgPool2d']: 292 | in_w = x.size()[2] 293 | kernel_ops = layer.kernel_size * layer.kernel_size 294 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 295 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 296 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops 297 | print(delta_ops) 298 | delta_params = get_layer_param(layer) 299 | 300 | elif type_name in ['AdaptiveAvgPool2d']: 301 | delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] 302 | delta_params = get_layer_param(layer) 303 | 304 | ### ops_linear 305 | elif type_name in ['Linear']: 306 | weight_ops = layer.weight.numel() * multi_add 307 | bias_ops = layer.bias.numel() 308 | delta_ops = x.size()[0] * (weight_ops + bias_ops) 309 | delta_params = get_layer_param(layer) 310 | 311 | ### ops_nothing 312 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']: 313 | delta_params = get_layer_param(layer) 314 | 315 | ### unknown layer type 316 | else: 317 | None 318 | # raise TypeError('unknown layer type: %s' % type_name) 319 | 320 | count_ops += delta_ops 321 | count_params += delta_params 322 | return 323 | 324 | 325 | def measure_model(model, H, W): 326 | global count_ops, count_params 327 | count_ops = 0 328 | count_params = 0 329 | data = Variable(torch.zeros(1, 3, H, W)) 330 | 331 | def should_measure(x): 332 | return is_leaf(x) or is_pruned(x) 333 | 334 | def modify_forward(model): 335 | for child in model.children(): 336 | if should_measure(child): 337 | def new_forward(m): 338 | def lambda_forward(x): 339 | measure_layer(m, x) 340 | return m.old_forward(x) 341 | 342 | return lambda_forward 343 | 344 | child.old_forward = child.forward 345 | child.forward = new_forward(child) 346 | else: 347 | modify_forward(child) 348 | 349 | def 
restore_forward(model): 350 | for child in model.children(): 351 | # leaf node 352 | if is_leaf(child) and hasattr(child, 'old_forward'): 353 | child.forward = child.old_forward 354 | child.old_forward = None 355 | else: 356 | restore_forward(child) 357 | 358 | modify_forward(model) 359 | model.forward(data) 360 | restore_forward(model) 361 | 362 | return count_ops, count_params 363 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Identity(nn.Module): 8 | def __init__(self): 9 | super(Identity, self).__init__() 10 | 11 | def forward(self, x): 12 | return x 13 | 14 | 15 | class Zero(nn.Module): 16 | def __init__(self): 17 | super(Zero, self).__init__() 18 | 19 | def forward(self, x): 20 | return x * 0 21 | 22 | 23 | class ZeroMake(nn.Module): 24 | def __init__(self, channels, spatial): 25 | super(ZeroMake, self).__init__() 26 | self.spatial = spatial 27 | self.channels = channels 28 | 29 | def forward(self, x): 30 | return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial], 31 | dtype=x.dtype, layout=x.layout, device=x.device) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | def __init__(self, in_channels, out_channels, stride, dropRate=0.0): 36 | super(BasicBlock, self).__init__() 37 | self.bn1 = nn.BatchNorm2d(in_channels) 38 | self.relu1 = nn.ReLU(inplace=True) 39 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, 40 | padding=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(out_channels) 42 | self.relu2 = nn.ReLU(inplace=True) 43 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, 44 | padding=1, bias=False) 45 | self.droprate = dropRate 46 | self.equalInOut = (in_channels == out_channels) 47 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, 48 | padding=0, bias=False) or None 49 | 50 | def forward(self, x): 51 | if not self.equalInOut: 52 | x = self.relu1(self.bn1(x)) 53 | else: 54 | out = self.relu1(self.bn1(x)) 55 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 56 | if self.droprate > 0: 57 | out = F.dropout(out, p=self.droprate, training=self.training) 58 | out = self.conv2(out) 59 | 60 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 61 | 62 | 63 | class BottleBlock(nn.Module): 64 | def __init__(self, in_channels, out_channels, mid_channels, stride, dropRate=0.0): 65 | super(BottleBlock, self).__init__() 66 | self.bn1 = nn.BatchNorm2d(in_channels) 67 | self.relu1 = nn.ReLU(inplace=True) 68 | self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(mid_channels) 71 | self.relu2 = nn.ReLU(inplace=True) 72 | self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, 73 | padding=1, bias=False) 74 | self.droprate = dropRate 75 | self.equalInOut = (in_channels == out_channels) 76 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, 77 | padding=0, bias=False) or None 78 | 79 | def forward(self, x): 80 | if not self.equalInOut: 81 | x = self.relu1(self.bn1(x)) 82 | else: 83 | out = self.relu1(self.bn1(x)) 84 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else 
x))) 85 | if self.droprate > 0: 86 | out = F.dropout(out, p=self.droprate, training=self.training) 87 | out = self.conv2(out) 88 | 89 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 90 | 91 | 92 | class MaskBlock(nn.Module): 93 | expansion = 1 94 | 95 | def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0): 96 | 97 | super(MaskBlock, self).__init__() 98 | self.bn1 = nn.BatchNorm2d(in_channels) 99 | self.relu1 = nn.ReLU(inplace=True) 100 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 101 | self.bn2 = nn.BatchNorm2d(out_channels) 102 | self.relu2 = nn.ReLU(inplace=True) 103 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 104 | 105 | self.droprate = dropRate 106 | self.equalInOut = (in_channels == out_channels) 107 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, 108 | padding=0, bias=False) or None 109 | 110 | self.activation = Identity() 111 | self.activation.register_backward_hook(self._fisher) 112 | self.register_buffer('mask', None) 113 | 114 | self.input_shape = None 115 | self.output_shape = None 116 | self.flops = None 117 | self.params = None 118 | self.in_channels = in_channels 119 | self.out_channels = out_channels 120 | self.stride = stride 121 | self.got_shapes = False 122 | 123 | # Fisher method is called on backward passes 124 | self.running_fisher = 0 125 | 126 | def forward(self, x): 127 | 128 | if not self.equalInOut: 129 | x = self.relu1(self.bn1(x)) 130 | else: 131 | out = self.relu1(self.bn1(x)) 132 | 133 | out = self.conv1(out if self.equalInOut else x) 134 | 135 | out = self.relu2(self.bn2(out)) 136 | 137 | if self.mask is not None: 138 | out = out * self.mask[None, :, None, None] 139 | 140 | else: 141 | self._create_mask(x, out) 142 | 143 | out = self.activation(out) 144 | self.act = out 145 | 146 | if self.droprate > 0: 147 | out = F.dropout(out, p=self.droprate, training=self.training) 148 | 149 | out = self.conv2(out) 150 | 151 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 152 | 153 | def _create_mask(self, x, out): 154 | 155 | self.mask = x.new_ones(out.shape[1]) 156 | self.input_shape = x.size() 157 | self.output_shape = out.size() 158 | 159 | def _fisher(self, notused1, notused2, grad_output): 160 | act = self.act.detach() 161 | grad = grad_output[0].detach() 162 | 163 | g_nk = (act * grad).sum(-1).sum(-1) 164 | del_k = g_nk.pow(2).mean(0).mul(0.5) 165 | self.running_fisher += del_k 166 | 167 | def reset_fisher(self): 168 | self.running_fisher = 0 * self.running_fisher 169 | 170 | def cost(self): 171 | 172 | in_channels = self.in_channels 173 | out_channels = self.out_channels 174 | middle_channels = int(self.mask.sum().item()) 175 | 176 | conv1_size = self.conv1.weight.size() 177 | conv2_size = self.conv2.weight.size() 178 | 179 | # convs 180 | self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \ 181 | conv2_size[2] * conv2_size[3] 182 | 183 | # batchnorms, assuming running stats are absorbed 184 | self.params += 2 * in_channels + 2 * middle_channels 185 | 186 | # skip 187 | if not self.equalInOut: 188 | self.params += in_channels * out_channels 189 | else: 190 | self.params += 0 191 | 192 | def compress_weights(self): 193 | 194 | middle_dim = int(self.mask.sum().item()) 195 | print(middle_dim) 196 | 197 | if middle_dim is not 0: 198 | conv1 = 
nn.Conv2d(self.in_channels, middle_dim, kernel_size=3, stride=self.stride, padding=1, bias=False) 199 | conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :]) 200 | 201 | # Batch norm 2 changes 202 | bn2 = nn.BatchNorm2d(middle_dim) 203 | bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1]) 204 | bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1]) 205 | bn2.running_mean = self.bn2.running_mean[self.mask == 1] 206 | bn2.running_var = self.bn2.running_var[self.mask == 1] 207 | 208 | conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False) 209 | conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :]) 210 | 211 | if middle_dim is 0: 212 | conv1 = Zero() 213 | bn2 = Zero() 214 | conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride) 215 | 216 | self.conv1 = conv1 217 | self.conv2 = conv2 218 | self.bn2 = bn2 219 | 220 | if middle_dim is not 0: 221 | self.mask = torch.ones(middle_dim) 222 | else: 223 | self.mask = torch.ones(1) 224 | 225 | 226 | class NetworkBlock(nn.Module): 227 | def __init__(self, nb_layers, in_channels, out_channels, block, stride, dropRate=0.0): 228 | super(NetworkBlock, self).__init__() 229 | self.layer = self._make_layer(block, in_channels, out_channels, nb_layers, stride, dropRate) 230 | 231 | def _make_layer(self, block, in_channels, out_channels, nb_layers, stride, dropRate): 232 | layers = [] 233 | for i in range(int(nb_layers)): 234 | layers.append(block(i == 0 and in_channels or out_channels, out_channels, i == 0 and stride or 1, dropRate)) 235 | return nn.Sequential(*layers) 236 | 237 | def forward(self, x): 238 | return self.layer(x) 239 | 240 | 241 | class NetworkBlockBottle(nn.Module): 242 | def __init__(self, nb_layers, in_channels, out_channels, mid_channels, block, stride, dropRate=0.0): 243 | super(NetworkBlockBottle, self).__init__() 244 | self.layer = self._make_layer(block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate) 245 | 246 | def _make_layer(self, block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate): 247 | layers = [] 248 | for i in range(int(nb_layers)): 249 | layers.append( 250 | block(i == 0 and in_channels or out_channels, out_channels, mid_channels, i == 0 and stride or 1, 251 | dropRate)) 252 | return nn.Sequential(*layers) 253 | 254 | def forward(self, x): 255 | return self.layer(x) 256 | 257 | 258 | class WideResNet(nn.Module): 259 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, mask=False): 260 | super(WideResNet, self).__init__() 261 | 262 | nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)] 263 | 264 | assert ((depth - 4) % 6 == 0) 265 | n = (depth - 4) / 6 266 | 267 | if mask == 1: 268 | block = MaskBlock 269 | else: 270 | block = BasicBlock 271 | 272 | # 1st conv before any network block 273 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 274 | padding=1, bias=False) 275 | # 1st block 276 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 277 | # 2nd block 278 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 279 | # 3rd block 280 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 281 | # global average pooling and classifier 282 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 283 | self.relu = nn.ReLU(inplace=True) 284 | self.fc = nn.Linear(nChannels[3], num_classes) 285 | self.nChannels = nChannels[3] 286 | 287 | # Count params that 
don't exist in blocks (conv1, bn1, fc) 288 | self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \ 289 | len(self.fc.weight.view(-1)) + len(self.fc.bias) 290 | 291 | 292 | for m in self.modules(): 293 | if isinstance(m, nn.Conv2d): 294 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 295 | m.weight.data.normal_(0, math.sqrt(2. / n)) 296 | elif isinstance(m, nn.BatchNorm2d): 297 | m.weight.data.fill_(1) 298 | m.bias.data.zero_() 299 | elif isinstance(m, nn.Linear): 300 | m.bias.data.zero_() 301 | 302 | def forward(self, x): 303 | out = self.conv1(x) 304 | out = self.block1(out) 305 | out = self.block2(out) 306 | out = self.block3(out) 307 | out = self.relu(self.bn1(out)) 308 | out = F.avg_pool2d(out, 8) 309 | out = out.view(-1, self.nChannels) 310 | return self.fc(out) 311 | 312 | 313 | class WideResNetBottle(nn.Module): 314 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, bottle_mult=0.5): 315 | super(WideResNetBottle, self).__init__() 316 | 317 | nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)] 318 | 319 | assert ((depth - 4) % 6 == 0) 320 | n = (depth - 4) / 6 321 | 322 | block = BottleBlock 323 | 324 | # 1st conv before any network block 325 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 326 | padding=1, bias=False) 327 | # 1st block 328 | self.block1 = NetworkBlockBottle(n, nChannels[0], nChannels[1], int(nChannels[1] * bottle_mult), block, 1, 329 | dropRate) 330 | # 2nd block 331 | self.block2 = NetworkBlockBottle(n, nChannels[1], nChannels[2], int(nChannels[2] * bottle_mult), block, 2, 332 | dropRate) 333 | # 3rd block 334 | self.block3 = NetworkBlockBottle(n, nChannels[2], nChannels[3], int(nChannels[3] * bottle_mult), block, 2, 335 | dropRate) 336 | # global average pooling and classifier 337 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 338 | self.relu = nn.ReLU(inplace=True) 339 | self.fc = nn.Linear(nChannels[3], num_classes) 340 | self.nChannels = nChannels[3] 341 | 342 | for m in self.modules(): 343 | if isinstance(m, nn.Conv2d): 344 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 345 | m.weight.data.normal_(0, math.sqrt(2. / n)) 346 | elif isinstance(m, nn.BatchNorm2d): 347 | m.weight.data.fill_(1) 348 | m.bias.data.zero_() 349 | elif isinstance(m, nn.Linear): 350 | m.bias.data.zero_() 351 | 352 | def forward(self, x): 353 | out = self.conv1(x) 354 | out = self.block1(out) 355 | out = self.block2(out) 356 | out = self.block3(out) 357 | out = self.relu(self.bn1(out)) 358 | out = F.avg_pool2d(out, 8) 359 | out = out.view(-1, self.nChannels) 360 | return self.fc(out) 361 | --------------------------------------------------------------------------------
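For a quick smoke test of the masked models, a sketch along these lines should run on CPU (shapes follow the CIFAR-10 setup used throughout):

```
import torch
from models import WideResNet, DenseNet

wrn = WideResNet(depth=16, widen_factor=2.0, mask=1)  # builds with MaskBlocks
out = wrn(torch.rand(2, 3, 32, 32))                   # first pass also creates the masks
assert out.shape == (2, 10)

dn = DenseNet(growthRate=12, depth=40, reduction=0.5, nClasses=10, bottleneck=True, mask=True)
assert dn(torch.rand(2, 3, 32, 32)).shape == (2, 10)
```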