├── LICENSE ├── README.md ├── cifar ├── README.md ├── l1-norm-pruning │ ├── README.md │ ├── compute_flops.py │ ├── main.py │ ├── main_B.py │ ├── main_E.py │ ├── main_finetune.py │ ├── models │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── vgg.py │ ├── res110prune.py │ ├── res56prune.py │ └── vggprune.py ├── lottery-ticket │ ├── README.md │ ├── l1-norm-pruning │ │ ├── README.md │ │ ├── lottery_res110prune.py │ │ ├── lottery_resprune.py │ │ ├── lottery_vggprune.py │ │ ├── main.py │ │ ├── main_lottery.py │ │ ├── main_scratch_mask.py │ │ └── models │ │ │ ├── __init__.py │ │ │ ├── resnet.py │ │ │ └── vgg.py │ └── weight-level │ │ ├── README.md │ │ ├── cifar.py │ │ ├── cifar_prune_iterative.py │ │ ├── cifar_scratch_no_longer.py │ │ ├── lottery_ticket.py │ │ ├── models │ │ ├── __init__.py │ │ └── cifar │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── preresnet.py │ │ │ ├── resnet.py │ │ │ ├── resnext.py │ │ │ ├── vgg.py │ │ │ └── wrn.py │ │ └── utils │ │ ├── __init__.py │ │ ├── eval.py │ │ ├── logger.py │ │ ├── misc.py │ │ └── visualize.py ├── network-slimming │ ├── README.md │ ├── compute_flops.py │ ├── denseprune.py │ ├── main.py │ ├── main_B.py │ ├── main_E.py │ ├── main_finetune.py │ ├── models │ │ ├── __init__.py │ │ ├── channel_selection.py │ │ ├── densenet.py │ │ ├── preresnet.py │ │ └── vgg.py │ ├── resprune.py │ └── vggprune.py ├── soft-filter-pruning │ ├── README.md │ ├── compute_flops.py │ ├── pruning_cifar10_pretrain.py │ ├── pruning_cifar10_resnet.py │ ├── pruning_resnet_longer_scratch.py │ ├── pruning_resnet_scratch.py │ └── utils.py └── weight-level │ ├── README.md │ ├── cifar.py │ ├── cifar_B.py │ ├── cifar_E.py │ ├── cifar_finetune.py │ ├── cifar_prune.py │ ├── count_flops.py │ ├── models │ ├── __init__.py │ └── cifar │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── densenet.py │ │ ├── preresnet.py │ │ ├── resnet.py │ │ └── vgg.py │ └── utils │ ├── __init__.py │ ├── eval.py │ ├── logger.py │ ├── misc.py │ └── visualize.py └── imagenet ├── README.md ├── l1-norm-pruning ├── README.md ├── compute_flops.py ├── main_B.py ├── main_E.py ├── main_finetune.py ├── prune.py └── resnet.py ├── network-slimming ├── README.md ├── compute_flops.py ├── main.py ├── main_B.py ├── main_E.py ├── main_finetune.py ├── prune.py └── vgg.py ├── regression-pruning ├── README.md ├── compute_flops.py ├── main_B.py ├── main_E.py └── models │ ├── __init__.py │ ├── channel_selection.py │ ├── filter.pkl │ ├── resnet.py │ ├── resnet_2x.py │ └── vgg_5x.py ├── sparse-structure-selection └── README.md ├── thinet ├── README.md ├── compute_flops.py ├── main_B.py ├── main_E.py └── models │ ├── __init__.py │ ├── thinetconv.py │ ├── thinetresnet.py │ └── thinetvgg.py └── weight-level ├── README.md ├── compute_flops.py ├── main_B.py ├── main_E.py ├── main_finetune.py └── prune.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Mingjie Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions 
of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rethinking the Value of Network Pruning 2 | This repository contains the code, as well as trained ImageNet models, for reproducing the results in the following paper: 3 | 4 | Rethinking the Value of Network Pruning. [[arXiv]](https://arxiv.org/abs/1810.05270) [[OpenReview]](https://openreview.net/forum?id=rJlnB3C5Ym) 5 | 6 | [Zhuang Liu](https://liuzhuang13.github.io/)\*, [Mingjie Sun](https://eric-mingjie.github.io/)\*, [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz/), [Gao Huang](http://www.gaohuang.net/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) (\* equal contribution). 7 | 8 | ICLR 2019. Also [Best Paper Award](https://nips.cc/Conferences/2018/Schedule?showEvent=10941) at the NIPS 2018 Workshop on Compact Deep Neural Networks. 9 | 10 | The implementations of several pruning methods contained in this repo can also be readily used for other research purposes. 11 | 12 | ## Paper Summary 13 | 14 |
Fig 1: A typical three-stage network pruning pipeline.
24 | 25 | Our paper shows that for **structured** pruning, **training the pruned model from scratch can almost always achieve comparable or higher level of accuracy than the model obtained from the typical "training, pruning and fine-tuning" (Fig. 1) procedure**. We conclude that for those pruning methods: 26 | 27 | 1. Training a large, over-parameterized model is often not necessary to obtain an efficient final model. 28 | 2. Learned “important” weights of the large model are typically not useful for the small pruned model. 29 | 3. The pruned architecture itself, rather than a set of inherited “important” weights, is more crucial to the efficiency in the final model, which suggests that in some cases pruning can be useful as an architecture search paradigm. 30 | 31 | Our results suggest the need for more careful baseline evaluations in future research on structured pruning methods. 32 | 33 |
Fig 2: Difference between predefined and automatically discovered target architectures, in channel pruning. The pruning ratio x is user-specified, while a, b, c, d are determined by the pruning algorithm. Unstructured sparse pruning can also be viewed as automatic. Our finding has different implications for predefined and automatic methods: for a predefined method, it is possible to skip the traditional "training, pruning and fine-tuning" pipeline and directly train the pruned model; for automatic methods, the pruning can be seen as a form of architecture learning.
40 | 41 | We also compare with the "[Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635)" (Frankle & Carbin 2019), and find that with the optimal learning rate, the "winning ticket" initialization as used in Frankle & Carbin (2019) does not bring improvement over random initialization. For more details, please refer to our paper. 42 | 43 | ## Implementation 44 | We evaluated the following seven pruning methods. 45 | 46 | 1. [L1-norm based channel pruning](https://arxiv.org/abs/1608.08710) 47 | 2. [ThiNet](https://arxiv.org/abs/1707.06342) 48 | 3. [Regression based feature reconstruction](https://arxiv.org/abs/1707.06168) 49 | 4. [Network Slimming](https://arxiv.org/abs/1708.06519) 50 | 5. [Sparse Structure Selection](https://arxiv.org/abs/1707.01213) 51 | 6. [Soft filter pruning](https://www.ijcai.org/proceedings/2018/0309.pdf) 52 | 7. [Unstructured weight-level pruning](https://arxiv.org/abs/1506.02626) 53 | 54 | The first six are structured, while the last one is unstructured (or sparse). For CIFAR, our code is based on [pytorch-classification](https://github.com/bearpaw/pytorch-classification) and [network-slimming](https://github.com/Eric-mingjie/network-slimming). For ImageNet, we use the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/blob/0.3.1/imagenet/main.py). The instructions and models are in each subfolder. 55 | 56 | For experiments on [The Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635), please refer to the folder [cifar/lottery-ticket](https://github.com/Eric-mingjie/rethinking-network-pruning/tree/master/cifar/lottery-ticket). 57 | 58 | Our experiment environment is Python 3.6 & PyTorch 0.3.1. 59 | 60 | ## Contact 61 | Feel free to discuss papers/code with us through issues/emails! 62 | 63 | sunmj15 at gmail.com 64 | liuzhuangthu at gmail.com 65 | 66 | ## Citation 67 | If you use our code in your research, please cite: 68 | ``` 69 | @inproceedings{liu2018rethinking, 70 | title={Rethinking the Value of Network Pruning}, 71 | author={Liu, Zhuang and Sun, Mingjie and Zhou, Tinghui and Huang, Gao and Darrell, Trevor}, 72 | booktitle={ICLR}, 73 | year={2019} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /cifar/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR Experiments 2 | This directory contains all the CIFAR experiments in the paper, covering four pruning methods: 3 | 4 | 1. [L1-norm based channel pruning](https://arxiv.org/abs/1608.08710) 5 | 2. [Network Slimming](https://arxiv.org/abs/1708.06519) 6 | 3. [Soft filter pruning](https://www.ijcai.org/proceedings/2018/0309.pdf) 7 | 4. [Non-structured weight-level pruning](https://arxiv.org/abs/1506.02626) 8 | 9 | For each method, we give example commands for baseline training, fine-tuning, scratch-E training and scratch-B training (a sketch of the scratch-B budget rule is given at the end of this file). 10 | 11 | We also give our implementation of the [Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635). 12 | 13 | ## Implementation 14 | Our code is based on [network-slimming](https://github.com/Eric-mingjie/network-slimming) and [pytorch-classification](https://github.com/bearpaw/pytorch-classification). 
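The distinction between the two scratch baselines: scratch-E retrains the pruned model for the same number of *epochs* as the large model, while scratch-B retrains it for the same computation *budget*, scaling the epoch count up by the FLOPs saved. A minimal illustrative sketch (the function name and the 160-epoch CIFAR default are assumptions for illustration; the actual schedules live in each method's `main_E.py`/`main_B.py`):

```python
import math

def scratch_b_epochs(baseline_flops, pruned_flops, baseline_epochs=160):
    # Each epoch of the pruned model costs proportionally fewer FLOPs, so
    # scratch-B trains for proportionally more epochs to keep the total
    # training computation roughly equal to the baseline's.
    return int(math.ceil(baseline_epochs * float(baseline_flops) / pruned_flops))

# e.g. a pruned model with 2x FLOPs reduction trains for 2 * 160 = 320 epochs.
```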
15 | -------------------------------------------------------------------------------- /cifar/l1-norm-pruning/README.md: -------------------------------------------------------------------------------- 1 | # Pruning Filters For Efficient ConvNets 2 | 3 | This directory contains a pytorch re-implementation of all CIFAR experiments of the following paper 4 | [Pruning Filters for Efficient ConvNets](https://arxiv.org/abs/1608.08710) (ICLR 2017). 5 | 6 | ## Dependencies 7 | torch v0.3.1, torchvision v0.2.0 8 | 9 | ## Baseline 10 | 11 | The `dataset` argument specifies which dataset to use: `cifar10` or `cifar100`. The `arch` argument specifies the architecture to use: `vgg` or `resnet`. The depth is chosen to be the same as the networks used in the paper. 12 | ```shell 13 | python main.py --dataset cifar10 --arch vgg --depth 16 14 | python main.py --dataset cifar10 --arch resnet --depth 56 15 | python main.py --dataset cifar10 --arch resnet --depth 110 16 | ``` 17 | 18 | ## Prune 19 | 20 | ```shell 21 | python vggprune.py --dataset cifar10 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 22 | python res56prune.py --dataset cifar10 -v A --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 23 | python res110prune.py --dataset cifar10 -v A --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 24 | ``` 25 | Here in `res56prune.py` and `res110prune.py`, the `-v` argument is `A` or `B`, which refers to the naming of the pruned model in the original paper. The pruned model will be named `pruned.pth.tar`. 26 | 27 | ## Fine-tune 28 | 29 | ```shell 30 | python main_finetune.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 16 31 | python main_finetune.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 56 32 | python main_finetune.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 110 33 | ``` 34 | 35 | ## Scratch-E 36 | ``` 37 | python main_E.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 16 38 | python main_E.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 56 39 | python main_E.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 110 40 | ``` 41 | 42 | ## Scratch-B 43 | ``` 44 | python main_B.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 16 45 | python main_B.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 56 46 | python main_B.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 110 47 | ``` 48 | 49 | -------------------------------------------------------------------------------- /cifar/l1-norm-pruning/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 
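# Estimates model size and compute: print_model_param_nums counts parameters;
# print_model_param_flops registers forward hooks on Conv2d, Linear,
# BatchNorm2d, ReLU, pooling and Upsample modules, runs one dummy forward
# pass on a batch of 3 random images, and sums the per-layer operation
# counts recorded by the hooks (divided by 3 to report per-image FLOPs).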
2 | import numpy as np 3 | 4 | import torch 5 | import torchvision 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | 10 | def print_model_param_nums(model=None, multiply_adds=True): 11 | if model == None: 12 | model = torchvision.models.alexnet() 13 | total = sum([param.nelement() for param in model.parameters()]) 14 | print(' + Number of params: %.2fM' % (total / 1e6)) 15 | 16 | def print_model_param_flops(model=None, input_res=224, multiply_adds=True): 17 | 18 | prods = {} 19 | def save_hook(name): 20 | def hook_per(self, input, output): 21 | prods[name] = np.prod(input[0].shape) 22 | return hook_per 23 | 24 | list_1=[] 25 | def simple_hook(self, input, output): 26 | list_1.append(np.prod(input[0].shape)) 27 | list_2={} 28 | def simple_hook2(self, input, output): 29 | list_2['names'] = np.prod(input[0].shape) 30 | 31 | list_conv=[] 32 | def conv_hook(self, input, output): 33 | batch_size, input_channels, input_height, input_width = input[0].size() 34 | output_channels, output_height, output_width = output[0].size() 35 | 36 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 37 | bias_ops = 1 if self.bias is not None else 0 38 | 39 | params = output_channels * (kernel_ops + bias_ops) 40 | flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 41 | 42 | list_conv.append(flops) 43 | 44 | list_linear=[] 45 | def linear_hook(self, input, output): 46 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 47 | 48 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 49 | bias_ops = self.bias.nelement() 50 | 51 | flops = batch_size * (weight_ops + bias_ops) 52 | list_linear.append(flops) 53 | 54 | list_bn=[] 55 | def bn_hook(self, input, output): 56 | list_bn.append(input[0].nelement() * 2) 57 | 58 | list_relu=[] 59 | def relu_hook(self, input, output): 60 | list_relu.append(input[0].nelement()) 61 | 62 | list_pooling=[] 63 | def pooling_hook(self, input, output): 64 | batch_size, input_channels, input_height, input_width = input[0].size() 65 | output_channels, output_height, output_width = output[0].size() 66 | 67 | kernel_ops = self.kernel_size * self.kernel_size 68 | bias_ops = 0 69 | params = 0 70 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 71 | 72 | list_pooling.append(flops) 73 | 74 | list_upsample=[] 75 | # For bilinear upsample 76 | def upsample_hook(self, input, output): 77 | batch_size, input_channels, input_height, input_width = input[0].size() 78 | output_channels, output_height, output_width = output[0].size() 79 | 80 | flops = output_height * output_width * output_channels * batch_size * 12 81 | list_upsample.append(flops) 82 | 83 | def foo(net): 84 | childrens = list(net.children()) 85 | if not childrens: 86 | if isinstance(net, torch.nn.Conv2d): 87 | net.register_forward_hook(conv_hook) 88 | if isinstance(net, torch.nn.Linear): 89 | net.register_forward_hook(linear_hook) 90 | if isinstance(net, torch.nn.BatchNorm2d): 91 | net.register_forward_hook(bn_hook) 92 | if isinstance(net, torch.nn.ReLU): 93 | net.register_forward_hook(relu_hook) 94 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 95 | net.register_forward_hook(pooling_hook) 96 | if isinstance(net, torch.nn.Upsample): 97 | net.register_forward_hook(upsample_hook) 98 | return 99 | for c in childrens: 100 | foo(c) 101 | 102 | if model == None: 103 | model = torchvision.models.alexnet() 104 | 
foo(model) 105 | input = Variable(torch.rand(3, 3, input_res, input_res), requires_grad = True) 106 | out = model(input) 107 | 108 | 109 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 110 | 111 | print(' + Number of FLOPs: %.5fG' % (total_flops / 3 / 1e9)) 112 | 113 | return total_flops / 3 114 | -------------------------------------------------------------------------------- /cifar/l1-norm-pruning/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | from .resnet import * -------------------------------------------------------------------------------- /cifar/l1-norm-pruning/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from functools import partial 8 | from torch.autograd import Variable 9 | 10 | 11 | __all__ = ['resnet'] 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, cfg, stride=1, downsample=None): 23 | # cfg should be a number in this case 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, cfg, stride) 26 | self.bn1 = nn.BatchNorm2d(cfg) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(cfg, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | def downsample_basic_block(x, planes): 52 | x = nn.AvgPool2d(2,2)(x) 53 | zero_pads = torch.Tensor( 54 | x.size(0), planes - x.size(1), x.size(2), x.size(3)).zero_() 55 | if isinstance(x.data, torch.cuda.FloatTensor): 56 | zero_pads = zero_pads.cuda() 57 | 58 | out = Variable(torch.cat([x.data, zero_pads], dim=1)) 59 | 60 | return out 61 | 62 | class ResNet(nn.Module): 63 | 64 | def __init__(self, depth, dataset='cifar10', cfg=None): 65 | super(ResNet, self).__init__() 66 | # Model type specifies number of layers for CIFAR-10 model 67 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 68 | n = (depth - 2) // 6 69 | 70 | block = BasicBlock 71 | if cfg == None: 72 | cfg = [[16]*n, [32]*n, [64]*n] 73 | cfg = [item for sub_list in cfg for item in sub_list] 74 | 75 | self.cfg = cfg 76 | 77 | self.inplanes = 16 78 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 79 | bias=False) 80 | self.bn1 = nn.BatchNorm2d(16) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.layer1 = self._make_layer(block, 16, n, cfg=cfg[0:n]) 83 | self.layer2 = self._make_layer(block, 32, n, cfg=cfg[n:2*n], stride=2) 84 | self.layer3 = self._make_layer(block, 64, n, cfg=cfg[2*n:3*n], stride=2) 85 | self.avgpool = nn.AvgPool2d(8) 86 | if dataset == 'cifar10': 87 | num_classes = 10 88 | elif dataset == 'cifar100': 89 | num_classes = 100 90 | self.fc = nn.Linear(64 * block.expansion, 
num_classes) 91 | 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 95 | m.weight.data.normal_(0, math.sqrt(2. / n)) 96 | elif isinstance(m, nn.BatchNorm2d): 97 | m.weight.data.fill_(1) 98 | m.bias.data.zero_() 99 | 100 | def _make_layer(self, block, planes, blocks, cfg, stride=1): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = partial(downsample_basic_block, planes=planes*block.expansion) 104 | 105 | layers = [] 106 | layers.append(block(self.inplanes, planes, cfg[0], stride, downsample)) 107 | self.inplanes = planes * block.expansion 108 | for i in range(1, blocks): 109 | layers.append(block(self.inplanes, planes, cfg[i])) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | x = self.conv1(x) 115 | x = self.bn1(x) 116 | x = self.relu(x) # 32x32 117 | 118 | x = self.layer1(x) # 32x32 119 | x = self.layer2(x) # 16x16 120 | x = self.layer3(x) # 8x8 121 | 122 | x = self.avgpool(x) 123 | x = x.view(x.size(0), -1) 124 | x = self.fc(x) 125 | 126 | return x 127 | 128 | def resnet(**kwargs): 129 | """ 130 | Constructs a ResNet model. 131 | """ 132 | return ResNet(**kwargs) 133 | 134 | if __name__ == '__main__': 135 | net = resnet(depth=56) 136 | x=Variable(torch.FloatTensor(16, 3, 32, 32)) 137 | y = net(x) 138 | print(y.data.shape) -------------------------------------------------------------------------------- /cifar/l1-norm-pruning/models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | __all__ = ['vgg'] 9 | 10 | defaultcfg = { 11 | 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 13 | 16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 14 | 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 15 | } 16 | 17 | class vgg(nn.Module): 18 | def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None): 19 | super(vgg, self).__init__() 20 | if cfg is None: 21 | cfg = defaultcfg[depth] 22 | 23 | self.cfg = cfg 24 | 25 | self.feature = self.make_layers(cfg, True) 26 | 27 | if dataset == 'cifar10': 28 | num_classes = 10 29 | elif dataset == 'cifar100': 30 | num_classes = 100 31 | self.classifier = nn.Sequential( 32 | nn.Linear(cfg[-1], 512), 33 | nn.BatchNorm1d(512), 34 | nn.ReLU(inplace=True), 35 | nn.Linear(512, num_classes) 36 | ) 37 | if init_weights: 38 | self._initialize_weights() 39 | 40 | def make_layers(self, cfg, batch_norm=False): 41 | layers = [] 42 | in_channels = 3 43 | for v in cfg: 44 | if v == 'M': 45 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 46 | else: 47 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 48 | if batch_norm: 49 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 50 | else: 51 | layers += [conv2d, nn.ReLU(inplace=True)] 52 | in_channels = v 53 | return nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | x = self.feature(x) 57 | x = nn.AvgPool2d(2)(x) 58 | x = x.view(x.size(0), -1) 59 | y = self.classifier(x) 60 | return y 61 | 62 | def _initialize_weights(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | 
m.weight.data.normal_(0, math.sqrt(2. / n)) 67 | if m.bias is not None: 68 | m.bias.data.zero_() 69 | elif isinstance(m, nn.BatchNorm2d): 70 | m.weight.data.fill_(0.5) 71 | m.bias.data.zero_() 72 | elif isinstance(m, nn.Linear): 73 | m.weight.data.normal_(0, 0.01) 74 | m.bias.data.zero_() 75 | 76 | if __name__ == '__main__': 77 | net = vgg() 78 | x = Variable(torch.FloatTensor(16, 3, 40, 40)) 79 | y = net(x) 80 | print(y.data.shape) -------------------------------------------------------------------------------- /cifar/lottery-ticket/README.md: -------------------------------------------------------------------------------- 1 | # The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks 2 | 3 | This directory contains a pytorch implementation of the [Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635) (ICLR 2019) for non-structured weight pruning and L1-norm based filter pruning. 4 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/l1-norm-pruning/README.md: -------------------------------------------------------------------------------- 1 | # The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks 2 | 3 | This directory contains a pytorch implementation of the [Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635) for the L1-norm based filter pruning introduced in this [paper](https://arxiv.org/abs/1608.08710) (ICLR 2017). 4 | 5 | ## Dependencies 6 | torch v0.3.1, torchvision v0.2.0 7 | 8 | ## Overview 9 | Since the lottery ticket hypothesis reuses the initialization of the baseline model from before training, it is easiest to implement with masks, where pruned weights are held at zero rather than physically removed (more explanation [here](https://github.com/Eric-mingjie/network-slimming/tree/master/mask-impl#mask-implementation-of-network-slimming)). 10 | 11 | ## Baseline 12 | 13 | ```shell 14 | python main.py --arch vgg --depth 16 --dataset cifar10 \ 15 | --lr 0.1 --save [PATH TO SAVE THE MODEL] 16 | ``` 17 | Note that the initialization is stored in a file called `init.pth.tar`, which will be used when training the lottery ticket. 
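To make the mechanism concrete, a minimal sketch of the two steps (illustrative only, assuming the `model` and `args` objects of these scripts; not the verbatim code of `main.py`/`main_lottery.py`):

```python
import os
import torch

# In main.py, before the first weight update: snapshot the random
# initialization so the "winning ticket" can be rewound to it later.
torch.save({'state_dict': model.state_dict()},
           os.path.join(args.save, 'init.pth.tar'))

# In main_lottery.py: restore that snapshot (assuming --model points at
# the saved init.pth.tar), then train while the mask derived from the
# pruned model holds the removed filters at zero.
init = torch.load(args.model)
model.load_state_dict(init['state_dict'])
```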
18 | 19 | ## Iterative Prune 20 | 21 | ```shell 22 | python lottery_vggprune.py --dataset cifar10 --model [PATH TO THE MODEL] --save [DIRECTORY TO SAVE THE MODEL] 23 | python lottery_res56prune.py --dataset cifar10 -v A --model [PATH TO THE MODEL] --save [DIRECTORY TO SAVE THE MODEL] 24 | python lottery_res110prune.py --dataset cifar10 -v A --model [PATH TO THE MODEL] --save [DIRECTORY TO SAVE THE MODEL] 25 | ``` 26 | 27 | ## Lottery Ticket 28 | 29 | ```shell 30 | python main_lottery.py --dataset cifar10 --arch vgg --depth 16 \ 31 | --lr 0.1 --resume [PATH TO THE PRUNED MODEL] \ 32 | --model [PATH TO THE STORED INITIALIZATION] \ 33 | --save [PATH TO SAVE THE MODEL] 34 | ``` 35 | 36 | ## Scratch-E 37 | ``` 38 | python main_scratch_mask.py --dataset cifar10 --arch vgg --depth 16 \ 39 | --lr 0.1 --resume [PATH TO THE PRUNED MODEL] \ 40 | --save [PATH TO SAVE THE MODEL] 41 | ``` 42 | 43 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/l1-norm-pruning/lottery_vggprune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torchvision import datasets, transforms 9 | 10 | from models import * 11 | 12 | 13 | # Prune settings 14 | parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune') 15 | parser.add_argument('--dataset', type=str, default='cifar10', 16 | help='training dataset (default: cifar10)') 17 | parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', 18 | help='input batch size for testing (default: 256)') 19 | parser.add_argument('--no-cuda', action='store_true', default=False, 20 | help='disables CUDA training') 21 | parser.add_argument('--depth', type=int, default=16, 22 | help='depth of the vgg') 23 | parser.add_argument('--model', default='', type=str, metavar='PATH', 24 | help='path to the model (default: none)') 25 | parser.add_argument('--save', default='.', type=str, metavar='PATH', 26 | help='path to save pruned model (default: none)') 27 | parser.add_argument('--prune', default='large', type=str, 28 | help='prune method to use') 29 | args = parser.parse_args() 30 | args.cuda = not args.no_cuda and torch.cuda.is_available() 31 | 32 | if not os.path.exists(args.save): 33 | os.makedirs(args.save) 34 | 35 | model = vgg(dataset=args.dataset, depth=args.depth) 36 | 37 | if args.cuda: 38 | model.cuda() 39 | 40 | if args.model: 41 | if os.path.isfile(args.model): 42 | print("=> loading checkpoint '{}'".format(args.model)) 43 | checkpoint = torch.load(args.model) 44 | args.start_epoch = checkpoint['epoch'] 45 | best_prec1 = checkpoint['best_prec1'] 46 | model.load_state_dict(checkpoint['state_dict']) 47 | print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}" 48 | .format(args.model, checkpoint['epoch'], best_prec1)) 49 | else: 50 | print("=> no checkpoint found at '{}'".format(args.model)) 51 | 52 | print('Pre-processing Successful!') 53 | 54 | # simple test of the model after the pruning pre-processing (pruned filters are set to zero) 55 | def test(model): 56 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 57 | if args.dataset == 'cifar10': 58 | test_loader = torch.utils.data.DataLoader( 59 | datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 62 |
batch_size=args.test_batch_size, shuffle=True, **kwargs) 63 | elif args.dataset == 'cifar100': 64 | test_loader = torch.utils.data.DataLoader( 65 | datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([ 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 68 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 69 | else: 70 | raise ValueError("No valid dataset is given.") 71 | model.eval() 72 | correct = 0 73 | for data, target in test_loader: 74 | if args.cuda: 75 | data, target = data.cuda(), target.cuda() 76 | data, target = Variable(data, volatile=True), Variable(target) 77 | output = model(data) 78 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 79 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 80 | 81 | print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format( 82 | correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) 83 | return correct / float(len(test_loader.dataset)) 84 | 85 | acc = test(model) 86 | cfg = [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256] 87 | 88 | cfg_mask = [] 89 | layer_id = 0 90 | for m in model.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | out_channels = m.weight.data.shape[0] 93 | if out_channels == cfg[layer_id]: 94 | cfg_mask.append(torch.ones(out_channels)) 95 | layer_id += 1 96 | continue 97 | weight_copy = m.weight.data.abs().clone() 98 | weight_copy = weight_copy.cpu().numpy() 99 | L1_norm = np.sum(weight_copy, axis=(1, 2, 3)) 100 | arg_max = np.argsort(L1_norm) 101 | if args.prune == 'large': 102 | arg_max_rev = arg_max[::-1][:cfg[layer_id]] 103 | elif args.prune == 'small': 104 | arg_max_rev = arg_max[:cfg[layer_id]] 105 | elif args.prune == 'random': 106 | arg_max_rev = np.random.choice(arg_max, cfg[layer_id], replace=False) 107 | assert arg_max_rev.size == cfg[layer_id], "size of arg_max_rev not correct" 108 | mask = torch.zeros(out_channels) 109 | mask[arg_max_rev.tolist()] = 1 110 | 111 | mask_4d = mask.view(-1, 1, 1, 1) # 1 for kept filters, 0 for pruned ones 112 | if args.cuda: mask_4d = mask_4d.cuda() 113 | m.weight.data.mul_(mask_4d) # zero out the pruned filters in place 114 | 115 | cfg_mask.append(mask) 116 | layer_id += 1 117 | elif isinstance(m, nn.MaxPool2d): 118 | layer_id += 1 119 | 120 | torch.save({'cfg': cfg, 'state_dict': model.state_dict()}, os.path.join(args.save, 'pruned.pth.tar')) 121 | acc = test(model) 122 | 123 | num_parameters = sum([param.nelement() for param in model.parameters()]) 124 | with open(os.path.join(args.save, "prune.txt"), "w") as fp: 125 | fp.write("Number of parameters: \n"+str(num_parameters)+"\n") 126 | fp.write("Test accuracy: \n"+str(acc)+"\n") -------------------------------------------------------------------------------- /cifar/lottery-ticket/l1-norm-pruning/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | from .resnet import * -------------------------------------------------------------------------------- /cifar/lottery-ticket/l1-norm-pruning/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported from 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import math 13 | import torch.nn.functional as F 14 | from functools import partial 15 | from torch.autograd import Variable 16 | 17 | 18 | __all__ = ['resnet'] 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, cfg, stride=1, downsample=None): 30 | # cfg should be a number in this case 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, cfg, stride) 33 | self.bn1 = nn.BatchNorm2d(cfg) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(cfg, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | def downsample_basic_block(x, planes): 59 | x = nn.AvgPool2d(2,2)(x) 60 | zero_pads = torch.Tensor( 61 | x.size(0), planes - x.size(1), x.size(2), x.size(3)).zero_() 62 | if isinstance(x.data, torch.cuda.FloatTensor): 63 | zero_pads = zero_pads.cuda() 64 | 65 | out = Variable(torch.cat([x.data, zero_pads], dim=1)) 66 | 67 | return out 68 | 69 | class ResNet(nn.Module): 70 | 71 | def __init__(self, depth, dataset='cifar10', cfg=None): 72 | super(ResNet, self).__init__() 73 | # Model type specifies number of layers for CIFAR-10 model 74 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 75 | n = (depth - 2) // 6 76 | 77 | block = BasicBlock 78 | if cfg == None: 79 | cfg = [[16]*n, [32]*n, [64]*n] 80 | cfg = [item for sub_list in cfg for item in sub_list] 81 | 82 | self.cfg = cfg 83 | 84 | self.inplanes = 16 85 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 86 | bias=False) 87 | self.bn1 = nn.BatchNorm2d(16) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.layer1 = self._make_layer(block, 16, n, cfg=cfg[0:n]) 90 | self.layer2 = self._make_layer(block, 32, n, cfg=cfg[n:2*n], stride=2) 91 | self.layer3 = self._make_layer(block, 64, n, cfg=cfg[2*n:3*n], stride=2) 92 | self.avgpool = nn.AvgPool2d(8) 93 | if dataset == 'cifar10': 94 | num_classes = 10 95 | elif dataset == 'cifar100': 96 | num_classes = 100 97 | self.fc = nn.Linear(64 * block.expansion, num_classes) 98 | 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | m.weight.data.fill_(1) 105 | m.bias.data.zero_() 106 | 107 | def _make_layer(self, block, planes, blocks, cfg, stride=1): 108 | downsample = None 109 | if stride != 1 or self.inplanes != planes * block.expansion: 110 | downsample = partial(downsample_basic_block, planes=planes*block.expansion) 111 | 112 | layers = [] 113 | layers.append(block(self.inplanes, planes, cfg[0], stride, downsample)) 114 | self.inplanes = planes * block.expansion 115 | for i in range(1, blocks): 116 | layers.append(block(self.inplanes, planes, cfg[i])) 117 | 118 | return nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | x = self.conv1(x) 122 | x = self.bn1(x) 123 | x = self.relu(x) # 32x32 124 | 125 | x = self.layer1(x) # 32x32 126 | x = self.layer2(x) # 16x16 127 | x = self.layer3(x) # 8x8 128 | 129 | x = self.avgpool(x) 130 | x = x.view(x.size(0), -1) 131 | x = self.fc(x) 132 | 133 | return x 134 | 135 | def resnet(**kwargs): 136 | """ 137 | Constructs a ResNet model. 138 | """ 139 | return ResNet(**kwargs) 140 | 141 | if __name__ == '__main__': 142 | net = resnet(depth=56) 143 | x=Variable(torch.FloatTensor(16, 3, 32, 32)) 144 | y = net(x) 145 | print(y.data.shape) -------------------------------------------------------------------------------- /cifar/lottery-ticket/l1-norm-pruning/models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | __all__ = ['vgg'] 8 | 9 | defaultcfg = { 10 | 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 11 | 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 13 | 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 14 | } 15 | 16 | class vgg(nn.Module): 17 | def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None): 18 | super(vgg, self).__init__() 19 | if cfg is None: 20 | cfg = defaultcfg[depth] 21 | 22 | self.cfg = cfg 23 | 24 | self.feature = self.make_layers(cfg, True) 25 | 26 | if dataset == 'cifar10': 27 | num_classes = 10 28 | elif dataset == 'cifar100': 29 | num_classes = 100 30 | self.classifier = nn.Sequential( 31 | nn.Linear(cfg[-1], 512), 32 | nn.BatchNorm1d(512), 33 | nn.ReLU(inplace=True), 34 | nn.Linear(512, num_classes) 35 | ) 36 | if init_weights: 37 | self._initialize_weights() 38 | 39 | def make_layers(self, cfg, batch_norm=False): 40 | layers = [] 41 | in_channels = 3 42 | for v in cfg: 43 | if v == 'M': 44 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 45 | else: 46 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 47 | if batch_norm: 48 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 49 | else: 50 | layers += [conv2d, nn.ReLU(inplace=True)] 51 | in_channels = v 52 | return nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | x = self.feature(x) 56 | x = nn.AvgPool2d(2)(x) 57 | x = x.view(x.size(0), -1) 58 | y = self.classifier(x) 59 | return y 60 | 61 | def _initialize_weights(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 65 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 66 | if m.bias is not None: 67 | m.bias.data.zero_() 68 | elif isinstance(m, nn.BatchNorm2d): 69 | m.weight.data.fill_(0.5) 70 | m.bias.data.zero_() 71 | elif isinstance(m, nn.Linear): 72 | m.weight.data.normal_(0, 0.01) 73 | m.bias.data.zero_() 74 | 75 | if __name__ == '__main__': 76 | net = vgg() 77 | x = Variable(torch.FloatTensor(16, 3, 40, 40)) 78 | y = net(x) 79 | print(y.data.shape) -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/README.md: -------------------------------------------------------------------------------- 1 | # The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks 2 | 3 | This directory contains a pytorch implementation of the [Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635) for the non-structured weight pruning introduced in this [paper](https://arxiv.org/abs/1506.02626) (NIPS 2015). 4 | 5 | ## Dependencies 6 | torch v0.3.1, torchvision v0.2.0 7 | 8 | ## Baseline 9 | 10 | ```shell 11 | python cifar.py --dataset cifar10 --arch vgg16_bn --depth 16 \ 12 | --lr 0.1 --save_dir [PATH TO SAVE THE MODEL] 13 | ``` 14 | Note that the initialization is stored in a file called `init.pth.tar`, which will be used when training the lottery ticket. 15 | 16 | ## Iterative Prune 17 | 18 | ```shell 19 | python cifar_prune_iterative.py --dataset cifar10 --arch vgg16_bn --depth 16 \ 20 | --percent RATIO --resume [PATH TO THE MODEL TO BE PRUNED] \ 21 | --save_dir [PATH TO SAVE THE PRUNED MODEL] 22 | ``` 23 | Note that `cifar_prune_iterative.py` prunes among the nonzero elements of the model: the ratio given by `--percent` is the prune ratio with respect to the total number of nonzero elements. When pruning a model iteratively, pass the model produced by the previous iteration to `--resume` and set `--percent` to the prune ratio desired for the current iteration. 
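A minimal illustrative sketch of this rule (the function name and the restriction to conv/linear weights are assumptions, not the exact code of `cifar_prune_iterative.py`): the magnitude threshold is computed over the currently surviving nonzero weights only, so a 50% ratio always removes half of what is left.

```python
import torch
import torch.nn as nn

def iterative_prune_masks(model, percent):
    # Collect conv/linear weight tensors; biases and BN are left untouched.
    weights = [m.weight.data for m in model.modules()
               if isinstance(m, (nn.Conv2d, nn.Linear))]
    # Magnitudes of the currently surviving (nonzero) weights only.
    alive = torch.cat([w.abs().view(-1) for w in weights])
    alive = alive[alive > 0]
    # Threshold at the `percent` quantile of the nonzero magnitudes.
    k = min(int(alive.numel() * percent), alive.numel() - 1)
    threshold = alive.sort()[0][k]
    # Zero out everything below the threshold; keep the binary masks so
    # that retraining can hold pruned weights at zero.
    masks = [(w.abs() > threshold).float() for w in weights]
    for w, mask in zip(weights, masks):
        w.mul_(mask)
    return masks
```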
24 | 25 | ## Lottery Ticket 26 | 27 | ```shell 28 | python lottery_ticket.py --dataset cifar10 --arch vgg16_bn --depth 16 \ 29 | --lr 0.1 --resume [PATH TO THE PRUNED MODEL] \ 30 | --model [PATH TO THE STORED INITIALIZATION] \ 31 | --save_dir [PATH TO SAVE THE MODEL] 32 | ``` 33 | 34 | ## Scratch-E 35 | ``` 36 | python cifar_scratch_no_longer.py --dataset cifar10 --arch vgg16_bn --depth 16 \ 37 | --lr 0.1 --resume [PATH TO THE PRUNED MODEL] \ 38 | --save_dir [PATH TO SAVE THE MODEL] 39 | ``` 40 | 41 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eric-mingjie/rethinking-network-pruning/2ac473d70a09810df888e932bb394f225f9ed2d1/cifar/lottery-ticket/weight-level/models/__init__.py -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .alexnet import * 4 | from .vgg import * 5 | from .resnet import * 6 | from .resnext import * 7 | from .wrn import * 8 | from .preresnet import * 9 | from .densenet import * -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | __all__ = ['alexnet'] 5 | 6 | 7 | class AlexNet(nn.Module): 8 | 9 | def __init__(self, num_classes=10): 10 | super(AlexNet, self).__init__() 11 | self.features = nn.Sequential( 12 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool2d(kernel_size=2, stride=2), 15 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 16 | nn.ReLU(inplace=True), 17 | nn.MaxPool2d(kernel_size=2, stride=2), 18 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=2, stride=2), 25 | ) 26 | self.classifier = nn.Linear(256, num_classes) 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = x.view(x.size(0), -1) 31 | x = self.classifier(x) 32 | return x 33 | 34 | 35 | def alexnet(**kwargs): 36 | r"""AlexNet model architecture from the 37 | `"One weird trick..." `_ paper. 
38 | """ 39 | model = AlexNet(**kwargs) 40 | return model 41 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | __all__ = ['densenet'] 9 | 10 | 11 | from torch.autograd import Variable 12 | 13 | class Bottleneck(nn.Module): 14 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 15 | super(Bottleneck, self).__init__() 16 | planes = expansion * growthRate 17 | self.bn1 = nn.BatchNorm2d(inplanes) 18 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 21 | padding=1, bias=False) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.dropRate = dropRate 24 | 25 | def forward(self, x): 26 | out = self.bn1(x) 27 | out = self.relu(out) 28 | out = self.conv1(out) 29 | out = self.bn2(out) 30 | out = self.relu(out) 31 | out = self.conv2(out) 32 | if self.dropRate > 0: 33 | out = F.dropout(out, p=self.dropRate, training=self.training) 34 | 35 | out = torch.cat((x, out), 1) 36 | 37 | return out 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 42 | super(BasicBlock, self).__init__() 43 | planes = expansion * growthRate 44 | self.bn1 = nn.BatchNorm2d(inplanes) 45 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 46 | padding=1, bias=False) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.dropRate = dropRate 49 | 50 | def forward(self, x): 51 | out = self.bn1(x) 52 | out = self.relu(out) 53 | out = self.conv1(out) 54 | if self.dropRate > 0: 55 | out = F.dropout(out, p=self.dropRate, training=self.training) 56 | 57 | out = torch.cat((x, out), 1) 58 | 59 | return out 60 | 61 | 62 | class Transition(nn.Module): 63 | def __init__(self, inplanes, outplanes): 64 | super(Transition, self).__init__() 65 | self.bn1 = nn.BatchNorm2d(inplanes) 66 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 67 | bias=False) 68 | self.relu = nn.ReLU(inplace=True) 69 | 70 | def forward(self, x): 71 | out = self.bn1(x) 72 | out = self.relu(out) 73 | out = self.conv1(out) 74 | out = F.avg_pool2d(out, 2) 75 | return out 76 | 77 | 78 | class DenseNet(nn.Module): 79 | 80 | def __init__(self, depth=22, block=BasicBlock, 81 | dropRate=0, num_classes=10, growthRate=12, compressionRate=1): 82 | super(DenseNet, self).__init__() 83 | 84 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 85 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 86 | n = int(n) 87 | 88 | self.growthRate = growthRate 89 | self.dropRate = dropRate 90 | 91 | # self.inplanes is a global variable used across multiple 92 | # helper functions 93 | self.inplanes = growthRate * 2 94 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 95 | bias=False) 96 | self.dense1 = self._make_denseblock(block, n) 97 | self.trans1 = self._make_transition(compressionRate) 98 | self.dense2 = self._make_denseblock(block, n) 99 | self.trans2 = self._make_transition(compressionRate) 100 | self.dense3 = self._make_denseblock(block, n) 101 | self.bn = nn.BatchNorm2d(self.inplanes) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.avgpool = nn.AvgPool2d(8) 104 | self.fc = nn.Linear(self.inplanes, num_classes) 105 | 106 | # Weight initialization 107 | for 
m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_denseblock(self, block, blocks): 116 | layers = [] 117 | for i in range(blocks): 118 | # Currently we fix the expansion ratio as the default value 119 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 120 | self.inplanes += self.growthRate 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def _make_transition(self, compressionRate): 125 | inplanes = self.inplanes 126 | outplanes = int(math.floor(self.inplanes // compressionRate)) 127 | self.inplanes = outplanes 128 | return Transition(inplanes, outplanes) 129 | 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | 134 | x = self.trans1(self.dense1(x)) 135 | x = self.trans2(self.dense2(x)) 136 | x = self.dense3(x) 137 | x = self.bn(x) 138 | x = self.relu(x) 139 | 140 | x = self.avgpool(x) 141 | x = x.view(x.size(0), -1) 142 | x = self.fc(x) 143 | 144 | return x 145 | 146 | 147 | def densenet(**kwargs): 148 | """ 149 | Constructs a ResNet model. 150 | """ 151 | return DenseNet(**kwargs) 152 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | 7 | 8 | __all__ = ['preresnet'] 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.bn1 = nn.BatchNorm2d(inplanes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.bn1(x) 33 | out = self.relu(out) 34 | out = self.conv1(out) 35 | 36 | out = self.bn2(out) 37 | out = self.relu(out) 38 | out = self.conv2(out) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | out += residual 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.bn1 = nn.BatchNorm2d(inplanes) 54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 57 | padding=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes) 59 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.bn1(x) 68 | out = self.relu(out) 69 | out = self.conv1(out) 70 | 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | out = self.conv2(out) 74 | 75 | out = self.bn3(out) 76 | out = self.relu(out) 77 | 
out = self.conv3(out) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out += residual 83 | 84 | return out 85 | 86 | 87 | class PreResNet(nn.Module): 88 | 89 | def __init__(self, depth, num_classes=1000): 90 | super(PreResNet, self).__init__() 91 | # Model type specifies number of layers for CIFAR-10 model 92 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 93 | n = (depth - 2) // 6 94 | 95 | block = Bottleneck if depth >=44 else BasicBlock 96 | 97 | self.inplanes = 16 98 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 99 | bias=False) 100 | self.layer1 = self._make_layer(block, 16, n) 101 | self.layer2 = self._make_layer(block, 32, n, stride=2) 102 | self.layer3 = self._make_layer(block, 64, n, stride=2) 103 | self.bn = nn.BatchNorm2d(64 * block.expansion) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.avgpool = nn.AvgPool2d(8) 106 | self.fc = nn.Linear(64 * block.expansion, num_classes) 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | m.weight.data.normal_(0, math.sqrt(2. / n)) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | 135 | x = self.layer1(x) # 32x32 136 | x = self.layer2(x) # 16x16 137 | x = self.layer3(x) # 8x8 138 | x = self.bn(x) 139 | x = self.relu(x) 140 | 141 | x = self.avgpool(x) 142 | x = x.view(x.size(0), -1) 143 | x = self.fc(x) 144 | 145 | return x 146 | 147 | 148 | def preresnet(**kwargs): 149 | """ 150 | Constructs a ResNet model. 
151 | """ 152 | return PreResNet(**kwargs) 153 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | 7 | 8 | __all__ = ['resnet'] 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * 4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class ResNet(nn.Module): 88 | 89 | def __init__(self, depth, num_classes=1000): 90 | super(ResNet, self).__init__() 91 | # Model type specifies number of layers for CIFAR-10 model 92 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 93 | n = (depth - 2) // 6 94 | 95 | block = Bottleneck if depth >=54 else BasicBlock 96 | 97 | self.inplanes = 16 98 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 99 | bias=False) 100 | self.bn1 = nn.BatchNorm2d(16) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.layer1 = self._make_layer(block, 16, n) 103 | self.layer2 = self._make_layer(block, 32, n, stride=2) 104 | self.layer3 = self._make_layer(block, 64, n, stride=2) 105 | self.avgpool = nn.AvgPool2d(8) 106 | self.fc = nn.Linear(64 * block.expansion, num_classes) 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | nn.BatchNorm2d(planes * block.expansion), 123 | ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | x = self.relu(x) # 32x32 137 | 138 | x = self.layer1(x) # 32x32 139 | x = self.layer2(x) # 16x16 140 | x = self.layer3(x) # 8x8 141 | 142 | x = self.avgpool(x) 143 | x = x.view(x.size(0), -1) 144 | x = self.fc(x) 145 | 146 | return x 147 | 148 | 149 | def resnet(**kwargs): 150 | """ 151 | Constructs a ResNet model. 152 | """ 153 | return ResNet(**kwargs) 154 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | } 19 | 20 | 21 | class VGG(nn.Module): 22 | 23 | def __init__(self, features, num_classes=1000): 24 | super(VGG, self).__init__() 25 | self.features = features 26 | self.classifier = nn.Linear(512, num_classes) 27 | self._initialize_weights() 28 | 29 | def forward(self, x): 30 | x = self.features(x) 31 | x = nn.AvgPool2d(2)(x) 32 | x = x.view(x.size(0), -1) 33 | x = self.classifier(x) 34 | return x 35 | 36 | def _initialize_weights(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | n = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels) 40 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | elif isinstance(m, nn.Linear): 47 | n = m.weight.size(1) 48 | m.weight.data.normal_(0, 0.01) 49 | m.bias.data.zero_() 50 | 51 | 52 | def make_layers(cfg, batch_norm=False): 53 | layers = [] 54 | in_channels = 3 55 | for v in cfg: 56 | if v == 'M': 57 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 58 | else: 59 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 60 | if batch_norm: 61 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 62 | else: 63 | layers += [conv2d, nn.ReLU(inplace=True)] 64 | in_channels = v 65 | return nn.Sequential(*layers) 66 | 67 | 68 | cfg = { 69 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 70 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 71 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 72 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 73 | # 'E': [64, 128, 'M', 128, 256, 'M', 64, 128, 256, 512, 1024, 'M', 64, 128, 256, 512, 1024, 2048,'M',256, 512, 1024, 512,'M'] 74 | } 75 | 76 | 77 | def vgg11(**kwargs): 78 | """VGG 11-layer model (configuration "A") 79 | 80 | Args: 81 | pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | """ 83 | model = VGG(make_layers(cfg['A']), **kwargs) 84 | return model 85 | 86 | 87 | def vgg11_bn(**kwargs): 88 | """VGG 11-layer model (configuration "A") with batch normalization""" 89 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 90 | return model 91 | 92 | 93 | def vgg13(**kwargs): 94 | """VGG 13-layer model (configuration "B") 95 | 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | """ 99 | model = VGG(make_layers(cfg['B']), **kwargs) 100 | return model 101 | 102 | 103 | def vgg13_bn(**kwargs): 104 | """VGG 13-layer model (configuration "B") with batch normalization""" 105 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 106 | return model 107 | 108 | 109 | def vgg16(**kwargs): 110 | """VGG 16-layer model (configuration "D") 111 | 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | model = VGG(make_layers(cfg['D']), **kwargs) 116 | return model 117 | 118 | 119 | def vgg16_bn(**kwargs): 120 | """VGG 16-layer model (configuration "D") with batch normalization""" 121 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 122 | return model 123 | 124 | 125 | def vgg19(**kwargs): 126 | """VGG 19-layer model (configuration "E") 127 | 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | """ 131 | model = VGG(make_layers(cfg['E']), **kwargs) 132 | return model 133 | 134 | 135 | def vgg19_bn(**kwargs): 136 | """VGG 19-layer model (configuration 'E') with batch normalization""" 137 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 138 | return model 139 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/models/cifar/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['wrn'] 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 11 | 
super(BasicBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.relu1 = nn.ReLU(inplace=True) 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(out_planes) 17 | self.relu2 = nn.ReLU(inplace=True) 18 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 19 | padding=1, bias=False) 20 | self.droprate = dropRate 21 | self.equalInOut = (in_planes == out_planes) 22 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 23 | padding=0, bias=False) or None 24 | def forward(self, x): 25 | if not self.equalInOut: 26 | x = self.relu1(self.bn1(x)) 27 | else: 28 | out = self.relu1(self.bn1(x)) 29 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 30 | if self.droprate > 0: 31 | out = F.dropout(out, p=self.droprate, training=self.training) 32 | out = self.conv2(out) 33 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 40 | layers = [] 41 | for i in range(nb_layers): 42 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 43 | return nn.Sequential(*layers) 44 | def forward(self, x): 45 | return self.layer(x) 46 | 47 | class WideResNet(nn.Module): 48 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 49 | super(WideResNet, self).__init__() 50 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 51 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 52 | n = (depth - 4) // 6 53 | block = BasicBlock 54 | # 1st conv before any network block 55 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 56 | padding=1, bias=False) 57 | # 1st block 58 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 59 | # 2nd block 60 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 61 | # 3rd block 62 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 63 | # global average pooling and classifier 64 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.fc = nn.Linear(nChannels[3], num_classes) 67 | self.nChannels = nChannels[3] 68 | 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 72 | m.weight.data.normal_(0, math.sqrt(2. / n)) 73 | elif isinstance(m, nn.BatchNorm2d): 74 | m.weight.data.fill_(1) 75 | m.bias.data.zero_() 76 | elif isinstance(m, nn.Linear): 77 | m.bias.data.zero_() 78 | 79 | def forward(self, x): 80 | out = self.conv1(x) 81 | out = self.block1(out) 82 | out = self.block2(out) 83 | out = self.block3(out) 84 | out = self.relu(self.bn1(out)) 85 | out = F.avg_pool2d(out, 8) 86 | out = out.view(-1, self.nChannels) 87 | return self.fc(out) 88 | 89 | def wrn(**kwargs): 90 | """ 91 | Constructs a Wide Residual Networks. 
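    Depth must satisfy depth = 6n + 4; width is controlled by `widen_factor`.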
92 | """ 93 | model = WideResNet(**kwargs) 94 | return model 95 | -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 9 | 10 | def savefig(fname, dpi=None): 11 | dpi = 150 if dpi == None else dpi 12 | plt.savefig(fname, dpi=dpi) 13 | 14 | def plot_overlap(logger, names=None): 15 | names = logger.names if names == None else names 16 | numbers = logger.numbers 17 | for _, name in enumerate(names): 18 | x = np.arange(len(numbers[name])) 19 | plt.plot(x, np.asarray(numbers[name])) 20 | return [logger.title + '(' + name + ')' for name in names] 21 | 22 | class Logger(object): 23 | '''Save training process to log file with simple plot function.''' 24 | def __init__(self, fpath, title=None, resume=False): 25 | self.file = None 26 | self.resume = resume 27 | self.title = '' if title == None else title 28 | if fpath is not None: 29 | if resume: 30 | self.file = open(fpath, 'r') 31 | name = self.file.readline() 32 | self.names = name.rstrip().split('\t') 33 | self.numbers = {} 34 | for _, name in enumerate(self.names): 35 | self.numbers[name] = [] 36 | 37 | for numbers in self.file: 38 | numbers = numbers.rstrip().split('\t') 39 | for i in range(0, len(numbers)): 40 | self.numbers[self.names[i]].append(numbers[i]) 41 | self.file.close() 42 | self.file = open(fpath, 'a') 43 | else: 44 | self.file = open(fpath, 'w') 45 | 46 | def set_names(self, names): 47 | if self.resume: 48 | pass 49 | # initialize numbers as empty list 50 | self.numbers = {} 51 | self.names = names 52 | for _, name in enumerate(self.names): 53 | self.file.write(name) 54 | self.file.write('\t') 55 | self.numbers[name] = [] 56 | self.file.write('\n') 57 | self.file.flush() 58 | 59 | 60 | def append(self, numbers): 61 | assert len(self.names) == len(numbers), 'Numbers do not match names' 62 | for index, num in enumerate(numbers): 63 | self.file.write("{0:.6f}".format(num)) 64 | self.file.write('\t') 65 | 
self.numbers[self.names[index]].append(num) 66 | self.file.write('\n') 67 | self.file.flush() 68 | 69 | def plot(self, names=None): 70 | names = self.names if names == None else names 71 | numbers = self.numbers 72 | for _, name in enumerate(names): 73 | x = np.arange(len(numbers[name])) 74 | plt.plot(x, np.asarray(numbers[name])) 75 | plt.legend([self.title + '(' + name + ')' for name in names]) 76 | plt.grid(True) 77 | 78 | def close(self): 79 | if self.file is not None: 80 | self.file.close() 81 | 82 | class LoggerMonitor(object): 83 | '''Load and visualize multiple logs.''' 84 | def __init__ (self, paths): 85 | '''paths is a distionary with {name:filepath} pair''' 86 | self.loggers = [] 87 | for title, path in paths.items(): 88 | logger = Logger(path, title=title, resume=True) 89 | self.loggers.append(logger) 90 | 91 | def plot(self, names=None): 92 | plt.figure() 93 | plt.subplot(121) 94 | legend_text = [] 95 | for logger in self.loggers: 96 | legend_text += plot_overlap(logger, names) 97 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 98 | plt.grid(True) 99 | 100 | if __name__ == '__main__': 101 | # # Example 102 | # logger = Logger('test.txt') 103 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 104 | 105 | # length = 100 106 | # t = np.arange(length) 107 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 108 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | 111 | # for i in range(0, length): 112 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 113 | # logger.plot() 114 | 115 | # Example: logger monitor 116 | paths = { 117 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 118 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 119 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 120 | } 121 | 122 | field = ['Valid Acc.'] 123 | 124 | monitor = LoggerMonitor(paths) 125 | monitor.plot(names=field) 126 | savefig('test.eps') -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
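- get_conv_zero_param: count zero-valued weights in Conv2d layers of a pruned model.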
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import torch 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | from torch.autograd import Variable 16 | 17 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 18 | 19 | 20 | def get_mean_and_std(dataset): 21 | '''Compute the mean and std value of dataset.''' 22 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 23 | 24 | mean = torch.zeros(3) 25 | std = torch.zeros(3) 26 | print('==> Computing mean and std..') 27 | for inputs, targets in dataloader: 28 | for i in range(3): 29 | mean[i] += inputs[:,i,:,:].mean() 30 | std[i] += inputs[:,i,:,:].std() 31 | mean.div_(len(dataset)) 32 | std.div_(len(dataset)) 33 | return mean, std 34 | 35 | def get_conv_zero_param(model): 36 | total = 0 37 | for m in model.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | total += torch.sum(m.weight.data.eq(0)) 40 | return total 41 | 42 | def init_params(net): 43 | '''Init layer parameters.''' 44 | for m in net.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | init.kaiming_normal(m.weight, mode='fan_out') 47 | if m.bias: 48 | init.constant(m.bias, 0) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | init.constant(m.weight, 1) 51 | init.constant(m.bias, 0) 52 | elif isinstance(m, nn.Linear): 53 | init.normal(m.weight, std=1e-3) 54 | if m.bias: 55 | init.constant(m.bias, 0) 56 | 57 | def mkdir_p(path): 58 | '''make dir if not exist''' 59 | try: 60 | os.makedirs(path) 61 | except OSError as exc: # Python >2.5 62 | if exc.errno == errno.EEXIST and os.path.isdir(path): 63 | pass 64 | else: 65 | raise 66 | 67 | class AverageMeter(object): 68 | """Computes and stores the average and current value 69 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 70 | """ 71 | def __init__(self): 72 | self.reset() 73 | 74 | def reset(self): 75 | self.val = 0 76 | self.avg = 0 77 | self.sum = 0 78 | self.count = 0 79 | 80 | def update(self, val, n=1): 81 | self.val = val 82 | self.sum += val * n 83 | self.count += n 84 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /cifar/lottery-ticket/weight-level/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | 
cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /cifar/network-slimming/README.md: -------------------------------------------------------------------------------- 1 | # Network Slimming 2 | 3 | This directory contains the pytorch implementation for [network slimming](http://openaccess.thecvf.com/content_iccv_2017/html/Liu_Learning_Efficient_Convolutional_ICCV_2017_paper.html) (ICCV 2017). 4 | 5 | ## Channel Selection Layer 6 | We introduce `channel selection` layer to help the pruning of ResNet and DenseNet. This layer is easy to implement. It stores a parameter `indexes` which is initialized to an all-1 vector. During pruning, it will set some places to 0 which correspond to the pruned channels. 
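A minimal usage sketch of the masking idea (hypothetical channel counts; the actual layer lives in `models/channel_selection.py`):
```python
import torch

indexes = torch.ones(64)           # one entry per channel; all channels kept initially
indexes[[3, 17, 42]] = 0           # "prune" three channels by zeroing their entries

x = torch.randn(8, 64, 32, 32)     # output of a BatchNorm2d layer (N, C, H, W)
kept = indexes.nonzero().squeeze(1)
y = x[:, kept, :, :]               # only the selected channels are passed on
print(y.shape)                     # torch.Size([8, 61, 32, 32])
```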
7 | 8 | ## Baseline 9 | 10 | The `dataset` argument specifies which dataset to use: `cifar10` or `cifar100`. The `arch` argument specifies the architecture to use: `vgg`, `resnet` or 11 | `densenet`. The depth is chosen to be the same as the networks used in the paper. 12 | ```shell 13 | python main.py --dataset cifar10 --arch vgg --depth 19 14 | python main.py --dataset cifar10 --arch resnet --depth 164 15 | python main.py --dataset cifar10 --arch densenet --depth 40 16 | ``` 17 | 18 | ## Train with Sparsity 19 | The `-sr` flag turns on the sparsity regularization on the batch-norm scaling factors; `--s` sets its strength. 20 | ```shell 21 | python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19 22 | python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164 23 | python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40 24 | ``` 25 | 26 | ## Prune 27 | 28 | ```shell 29 | python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 30 | python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 31 | python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 32 | ``` 33 | The pruned model will be named `pruned.pth.tar`. 34 | 35 | ## Fine-tune 36 | 37 | ```shell 38 | python main_finetune.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 39 | python main_finetune.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 164 40 | python main_finetune.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch densenet --depth 40 41 | ``` 42 | 43 | ## Scratch-E 44 | ```shell 45 | python main_E.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 46 | python main_E.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 164 47 | python main_E.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch densenet --depth 40 48 | ``` 49 | 50 | ## Scratch-B 51 | ```shell 52 | python main_B.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 53 | python main_B.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch resnet --depth 164 54 | python main_B.py --scratch [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch densenet --depth 40 55 | ``` 56 | 57 | -------------------------------------------------------------------------------- /cifar/network-slimming/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools.
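# The counters below work by registering forward hooks on Conv2d, Linear,
# BatchNorm2d, ReLU, pooling and Upsample modules, then running one dummy
# forward pass so each hook can record the FLOPs of its layer.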
2 | import numpy as np 3 | 4 | import torch 5 | import torchvision 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | 10 | def print_model_param_nums(model=None): 11 | if model == None: 12 | model = torchvision.models.alexnet() 13 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 14 | print(' + Number of params: %.2fM' % (total / 1e6)) 15 | return total 16 | 17 | def print_model_param_flops(model=None, input_res=224, multiply_adds=True): 18 | 19 | prods = {} 20 | def save_hook(name): 21 | def hook_per(self, input, output): 22 | prods[name] = np.prod(input[0].shape) 23 | return hook_per 24 | 25 | list_1=[] 26 | def simple_hook(self, input, output): 27 | list_1.append(np.prod(input[0].shape)) 28 | 29 | list_2={} 30 | def simple_hook2(self, input, output): 31 | list_2['names'] = np.prod(input[0].shape) 32 | 33 | list_conv=[] 34 | def conv_hook(self, input, output): 35 | batch_size, input_channels, input_height, input_width = input[0].size() 36 | output_channels, output_height, output_width = output[0].size() 37 | 38 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 39 | bias_ops = 1 if self.bias is not None else 0 40 | 41 | params = output_channels * (kernel_ops + bias_ops) 42 | flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 43 | 44 | list_conv.append(flops) 45 | 46 | list_linear=[] 47 | def linear_hook(self, input, output): 48 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 49 | 50 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 51 | bias_ops = self.bias.nelement() 52 | 53 | flops = batch_size * (weight_ops + bias_ops) 54 | list_linear.append(flops) 55 | 56 | list_bn=[] 57 | def bn_hook(self, input, output): 58 | list_bn.append(input[0].nelement() * 2) 59 | 60 | list_relu=[] 61 | def relu_hook(self, input, output): 62 | list_relu.append(input[0].nelement()) 63 | 64 | list_pooling=[] 65 | def pooling_hook(self, input, output): 66 | batch_size, input_channels, input_height, input_width = input[0].size() 67 | output_channels, output_height, output_width = output[0].size() 68 | 69 | kernel_ops = self.kernel_size * self.kernel_size 70 | bias_ops = 0 71 | params = 0 72 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 73 | 74 | list_pooling.append(flops) 75 | 76 | list_upsample=[] 77 | # For bilinear upsample 78 | def upsample_hook(self, input, output): 79 | batch_size, input_channels, input_height, input_width = input[0].size() 80 | output_channels, output_height, output_width = output[0].size() 81 | 82 | flops = output_height * output_width * output_channels * batch_size * 12 83 | list_upsample.append(flops) 84 | 85 | def foo(net): 86 | childrens = list(net.children()) 87 | if not childrens: 88 | if isinstance(net, torch.nn.Conv2d): 89 | net.register_forward_hook(conv_hook) 90 | if isinstance(net, torch.nn.Linear): 91 | net.register_forward_hook(linear_hook) 92 | if isinstance(net, torch.nn.BatchNorm2d): 93 | net.register_forward_hook(bn_hook) 94 | if isinstance(net, torch.nn.ReLU): 95 | net.register_forward_hook(relu_hook) 96 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 97 | net.register_forward_hook(pooling_hook) 98 | if isinstance(net, torch.nn.Upsample): 99 | net.register_forward_hook(upsample_hook) 100 | return 101 | for c in childrens: 102 | foo(c) 103 | 104 | if model == None: 105 | model = 
torchvision.models.alexnet() 106 | foo(model) 107 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 108 | out = model(input) 109 | 110 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 111 | 112 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 113 | 114 | return total_flops -------------------------------------------------------------------------------- /cifar/network-slimming/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | from .preresnet import * 5 | from .densenet import * 6 | from .channel_selection import * -------------------------------------------------------------------------------- /cifar/network-slimming/models/channel_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class channel_selection(nn.Module): 8 | """ 9 | Select channels from the output of BatchNorm2d layer. It should be put directly after BatchNorm2d layer. 10 | The output shape of this layer is determined by the number of 1 in `self.indexes`. 11 | """ 12 | def __init__(self, num_channels): 13 | """ 14 | Initialize the `indexes` with all one vector with the length same as the number of channels. 15 | During pruning, the places in `indexes` which correpond to the channels to be pruned will be set to 0. 16 | """ 17 | super(channel_selection, self).__init__() 18 | self.indexes = nn.Parameter(torch.ones(num_channels)) 19 | 20 | def forward(self, input_tensor): 21 | """ 22 | Parameter 23 | --------- 24 | input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer. 25 | """ 26 | selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy())) 27 | if selected_index.size == 1: 28 | selected_index = np.resize(selected_index, (1,)) 29 | output = input_tensor[:, selected_index, :, :] 30 | return output -------------------------------------------------------------------------------- /cifar/network-slimming/models/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from .channel_selection import channel_selection 9 | 10 | 11 | __all__ = ['densenet'] 12 | 13 | """ 14 | densenet with basic block. 
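each block is BN -> channel_selection -> ReLU -> 3x3 conv, and its output is concatenated with the block input (adding growthRate feature maps).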
15 | """ 16 | 17 | class BasicBlock(nn.Module): 18 | def __init__(self, inplanes, cfg, expansion=1, growthRate=12, dropRate=0): 19 | super(BasicBlock, self).__init__() 20 | planes = expansion * growthRate 21 | self.bn1 = nn.BatchNorm2d(inplanes) 22 | self.select = channel_selection(inplanes) 23 | self.conv1 = nn.Conv2d(cfg, growthRate, kernel_size=3, 24 | padding=1, bias=False) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.dropRate = dropRate 27 | 28 | def forward(self, x): 29 | out = self.bn1(x) 30 | out = self.select(out) 31 | out = self.relu(out) 32 | out = self.conv1(out) 33 | if self.dropRate > 0: 34 | out = F.dropout(out, p=self.dropRate, training=self.training) 35 | 36 | out = torch.cat((x, out), 1) 37 | 38 | return out 39 | 40 | class Transition(nn.Module): 41 | def __init__(self, inplanes, outplanes, cfg): 42 | super(Transition, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(inplanes) 44 | self.select = channel_selection(inplanes) 45 | self.conv1 = nn.Conv2d(cfg, outplanes, kernel_size=1, 46 | bias=False) 47 | self.relu = nn.ReLU(inplace=True) 48 | 49 | def forward(self, x): 50 | out = self.bn1(x) 51 | out = self.select(out) 52 | out = self.relu(out) 53 | out = self.conv1(out) 54 | out = F.avg_pool2d(out, 2) 55 | return out 56 | 57 | class densenet(nn.Module): 58 | 59 | def __init__(self, depth=40, 60 | dropRate=0, dataset='cifar10', growthRate=12, compressionRate=1, cfg = None): 61 | super(densenet, self).__init__() 62 | 63 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 64 | n = (depth - 4) // 3 65 | block = BasicBlock 66 | 67 | self.growthRate = growthRate 68 | self.dropRate = dropRate 69 | 70 | if cfg == None: 71 | cfg = [] 72 | start = growthRate*2 73 | for i in range(3): 74 | cfg.append([start+12*i for i in range(n+1)]) 75 | start += growthRate*12 76 | cfg = [item for sub_list in cfg for item in sub_list] 77 | 78 | assert len(cfg) == 3*n+3, 'length of config variable cfg should be 3n+3' 79 | 80 | # self.inplanes is a global variable used across multiple 81 | # helper functions 82 | self.inplanes = growthRate * 2 83 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 84 | bias=False) 85 | self.dense1 = self._make_denseblock(block, n, cfg[0:n]) 86 | self.trans1 = self._make_transition(compressionRate, cfg[n]) 87 | self.dense2 = self._make_denseblock(block, n, cfg[n+1:2*n+1]) 88 | self.trans2 = self._make_transition(compressionRate, cfg[2*n+1]) 89 | self.dense3 = self._make_denseblock(block, n, cfg[2*n+2:3*n+2]) 90 | self.bn = nn.BatchNorm2d(self.inplanes) 91 | self.select = channel_selection(self.inplanes) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.avgpool = nn.AvgPool2d(8) 94 | 95 | if dataset == 'cifar10': 96 | self.fc = nn.Linear(cfg[-1], 10) 97 | elif dataset == 'cifar100': 98 | self.fc = nn.Linear(cfg[-1], 100) 99 | 100 | # Weight initialization 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. / n)) 105 | elif isinstance(m, nn.BatchNorm2d): 106 | m.weight.data.fill_(0.5) 107 | m.bias.data.zero_() 108 | 109 | def _make_denseblock(self, block, blocks, cfg): 110 | layers = [] 111 | assert blocks == len(cfg), 'Length of the cfg parameter is not right.' 
112 | for i in range(blocks): 113 | # Currently we fix the expansion ratio as the default value 114 | layers.append(block(self.inplanes, cfg = cfg[i], growthRate=self.growthRate, dropRate=self.dropRate)) 115 | self.inplanes += self.growthRate 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def _make_transition(self, compressionRate, cfg): 120 | # cfg is a number in this case. 121 | inplanes = self.inplanes 122 | outplanes = int(math.floor(self.inplanes // compressionRate)) 123 | self.inplanes = outplanes 124 | return Transition(inplanes, outplanes, cfg) 125 | 126 | def forward(self, x): 127 | x = self.conv1(x) 128 | 129 | x = self.trans1(self.dense1(x)) 130 | x = self.trans2(self.dense2(x)) 131 | x = self.dense3(x) 132 | x = self.bn(x) 133 | x = self.select(x) 134 | x = self.relu(x) 135 | 136 | x = self.avgpool(x) 137 | x = x.view(x.size(0), -1) 138 | x = self.fc(x) 139 | 140 | return x -------------------------------------------------------------------------------- /cifar/network-slimming/models/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | from .channel_selection import channel_selection 7 | 8 | 9 | __all__ = ['resnet'] 10 | 11 | """ 12 | preactivation resnet with bottleneck design. 13 | """ 14 | 15 | class Bottleneck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, cfg, stride=1, downsample=None): 19 | super(Bottleneck, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(inplanes) 21 | self.select = channel_selection(inplanes) 22 | self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(cfg[1]) 24 | self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(cfg[2]) 27 | self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.bn1(x) 36 | out = self.select(out) 37 | out = self.relu(out) 38 | out = self.conv1(out) 39 | 40 | out = self.bn2(out) 41 | out = self.relu(out) 42 | out = self.conv2(out) 43 | 44 | out = self.bn3(out) 45 | out = self.relu(out) 46 | out = self.conv3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | 53 | return out 54 | 55 | class resnet(nn.Module): 56 | def __init__(self, depth=164, dataset='cifar10', cfg=None): 57 | super(resnet, self).__init__() 58 | assert (depth - 2) % 9 == 0, 'depth should be 9n+2' 59 | 60 | n = (depth - 2) // 9 61 | block = Bottleneck 62 | 63 | if cfg is None: 64 | # Construct config variable. 
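            # Three channel counts per bottleneck (the inputs of its three convs)
            # for all 3n blocks, plus the width of the final batch-norm layer;
            # the nested list is flattened below.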
65 | cfg = [[16, 16, 16], [64, 16, 16]*(n-1), [64, 32, 32], [128, 32, 32]*(n-1), [128, 64, 64], [256, 64, 64]*(n-1), [256]] 66 | cfg = [item for sub_list in cfg for item in sub_list] 67 | 68 | self.inplanes = 16 69 | 70 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 71 | bias=False) 72 | self.layer1 = self._make_layer(block, 16, n, cfg = cfg[0:3*n]) 73 | self.layer2 = self._make_layer(block, 32, n, cfg = cfg[3*n:6*n], stride=2) 74 | self.layer3 = self._make_layer(block, 64, n, cfg = cfg[6*n:9*n], stride=2) 75 | self.bn = nn.BatchNorm2d(64 * block.expansion) 76 | self.select = channel_selection(64 * block.expansion) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.avgpool = nn.AvgPool2d(8) 79 | 80 | if dataset == 'cifar10': 81 | self.fc = nn.Linear(cfg[-1], 10) 82 | elif dataset == 'cifar100': 83 | self.fc = nn.Linear(cfg[-1], 100) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2. / n)) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(0.5) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, planes, blocks, cfg, stride=1): 94 | downsample = None 95 | if stride != 1 or self.inplanes != planes * block.expansion: 96 | downsample = nn.Sequential( 97 | nn.Conv2d(self.inplanes, planes * block.expansion, 98 | kernel_size=1, stride=stride, bias=False), 99 | ) 100 | 101 | layers = [] 102 | layers.append(block(self.inplanes, planes, cfg[0:3], stride, downsample)) 103 | self.inplanes = planes * block.expansion 104 | for i in range(1, blocks): 105 | layers.append(block(self.inplanes, planes, cfg[3*i: 3*(i+1)])) 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | x = self.conv1(x) 111 | 112 | x = self.layer1(x) # 32x32 113 | x = self.layer2(x) # 16x16 114 | x = self.layer3(x) # 8x8 115 | x = self.bn(x) 116 | x = self.select(x) 117 | x = self.relu(x) 118 | 119 | x = self.avgpool(x) 120 | x = x.view(x.size(0), -1) 121 | x = self.fc(x) 122 | 123 | return x -------------------------------------------------------------------------------- /cifar/network-slimming/models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | __all__ = ['vgg'] 9 | 10 | defaultcfg = { 11 | 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 13 | 16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 14 | 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 15 | } 16 | 17 | class vgg(nn.Module): 18 | def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None): 19 | super(vgg, self).__init__() 20 | if cfg is None: 21 | cfg = defaultcfg[depth] 22 | 23 | self.feature = self.make_layers(cfg, True) 24 | 25 | if dataset == 'cifar10': 26 | num_classes = 10 27 | elif dataset == 'cifar100': 28 | num_classes = 100 29 | self.classifier = nn.Linear(cfg[-1], num_classes) 30 | if init_weights: 31 | self._initialize_weights() 32 | 33 | def make_layers(self, cfg, batch_norm=False): 34 | layers = [] 35 | in_channels = 3 36 | for v in cfg: 37 | if v == 'M': 38 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 39 | else: 40 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 41 | if batch_norm: 
42 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 43 | else: 44 | layers += [conv2d, nn.ReLU(inplace=True)] 45 | in_channels = v 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | x = self.feature(x) 50 | x = nn.AvgPool2d(2)(x) 51 | x = x.view(x.size(0), -1) 52 | y = self.classifier(x) 53 | return y 54 | 55 | def _initialize_weights(self): 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 59 | m.weight.data.normal_(0, math.sqrt(2. / n)) 60 | if m.bias is not None: 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.BatchNorm2d): 63 | m.weight.data.fill_(0.5) 64 | m.bias.data.zero_() 65 | elif isinstance(m, nn.Linear): 66 | m.weight.data.normal_(0, 0.01) 67 | m.bias.data.zero_() 68 | 69 | if __name__ == '__main__': 70 | net = vgg() 71 | x = Variable(torch.FloatTensor(16, 3, 40, 40)) 72 | y = net(x) 73 | print(y.data.shape) -------------------------------------------------------------------------------- /cifar/soft-filter-pruning/README.md: -------------------------------------------------------------------------------- 1 | # Soft Filter Pruning for Accelerating Deep Convolutional Neural Networks 2 | 3 | This directory contains the pytorch implementation for [soft filter pruning](https://www.ijcai.org/proceedings/2018/0309.pdf) (IJCAI 2018). 4 | Official implementation: [soft-filter-pruning](https://github.com/he-y/soft-filter-pruning). 5 | 6 | ## Dependencies 7 | - torch v0.3.1, torchvision v0.3.0 8 | 9 | ## Overview 10 | Specify the path to dataset in `DATA`. The argument `--arch` can be [`resnet20`,`resnet32`,`resnet56`,`resnet110`]. 11 | Below shows the choice of the argument `--layer_end` over different architectures: 12 | `resnet20`: 54 `resnet32`: 90 `resnet56`: 162 `resnet110`:324 13 | The hyperparameter settings are the same as those in the original paper. 14 | 15 | ## Baseline 16 | ```shell 17 | python pruning_cifar10_pretrain.py DATA --dataset cifar10 \ 18 | --arch resnet56 --save_path [PATH TO SAVE THE MODEL] \ 19 | --epochs 200 --schedule 1 60 120 160 --gammas 10 0.2 0.2 0.2 \ 20 | --learning_rate 0.01 --decay 0.0005 --batch_size 128 --rate 0.7 \ 21 | --layer_begin 0 --layer_end 162 --layer_inter 3 --epoch_prune 1 22 | ``` 23 | 24 | ## Soft Filter Pruning 25 | By not passing the argument `--resume`, we do not use the pretrained model. To use pretrained models, pass the argument `--resume` with the path to the pretrained model. 
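Conceptually, each pruning step zeroes the filters with the smallest L2 norms while training continues, so a "pruned" filter is still updated and can be revived in later epochs. A minimal sketch of one such step (an illustration that assumes `rate` is the fraction of filters kept, not the repository's exact code):
```python
import torch
import torch.nn as nn

def soft_prune_step(model, rate=0.7):
    """Zero the (1 - rate) smallest-L2-norm filters of every conv layer (soft pruning)."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            w = m.weight.data                               # (out_channels, in_ch, k, k)
            norms = w.view(w.size(0), -1).norm(p=2, dim=1)  # one L2 norm per filter
            n_prune = int(w.size(0) * (1 - rate))
            if n_prune > 0:
                _, idx = torch.sort(norms)
                w[idx[:n_prune]] = 0                        # soft: weights stay trainable
```
The command below runs the actual training-with-pruning loop: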
26 | ```shell 27 | python pruning_cifar10_resnet.py DATA --dataset cifar10 \ 28 | --arch resnet56 --save_path [PATH TO SAVE THE PRUNED MODEL] \ 29 | --epochs 200 --schedule 1 60 120 160 --gammas 10 0.2 0.2 0.2 \ 30 | --learning_rate 0.001 --decay 0.0005 --batch_size 128 --rate 0.7 \ 31 | --layer_begin 0 --layer_end 162 --layer_inter 3 --epoch_prune 1 32 | ``` 33 | 34 | ## Scratch-E 35 | ```shell 36 | python pruning_resnet_scratch.py DATA --dataset cifar10 \ 37 | --arch resnet56 --resume [PATH TO THE PRUNED MODEL] \ 38 | --save [PATH TO SAVE THE MODEL] \ 39 | --epochs 200 --schedule 1 60 120 160 --gammas 10 0.2 0.2 0.2 \ 40 | --learning_rate 0.01 --decay 0.0005 --batch_size 128 --rate 0.7 \ 41 | --layer_begin 0 --layer_end 162 --layer_inter 3 --epoch_prune 1 42 | ``` 43 | 44 | ## Scratch-B 45 | ```shell 46 | python pruning_resnet_longer_scratch.py DATA --dataset cifar10 \ 47 | --arch resnet56 --resume [PATH TO THE PRUNED MODEL] \ 48 | --save [PATH TO SAVE THE MODEL] \ 49 | --epochs 200 --schedule 1 60 120 160 --gammas 10 0.2 0.2 0.2 \ 50 | --learning_rate 0.01 --decay 0.0005 --batch_size 128 --rate 0.7 \ 51 | --layer_begin 0 --layer_end 162 --layer_inter 3 --epoch_prune 1 52 | ``` -------------------------------------------------------------------------------- /cifar/soft-filter-pruning/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | import models.cifar as models 11 | 12 | 13 | def print_model_param_nums(model=None): 14 | if model == None: 15 | model = torchvision.models.alexnet() 16 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 17 | for m in model.modules(): 18 | if isinstance(m, nn.Conv2d): 19 | total -= (m.weight.data == 0).sum() 20 | print(' + Number of params: %.2fM' % (total / 1e6)) 21 | return total 22 | 23 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 24 | 25 | prods = {} 26 | def save_hook(name): 27 | def hook_per(self, input, output): 28 | prods[name] = np.prod(input[0].shape) 29 | return hook_per 30 | 31 | list_1=[] 32 | def simple_hook(self, input, output): 33 | list_1.append(np.prod(input[0].shape)) 34 | list_2={} 35 | def simple_hook2(self, input, output): 36 | list_2['names'] = np.prod(input[0].shape) 37 | 38 | 39 | list_conv=[] 40 | def conv_hook(self, input, output): 41 | batch_size, input_channels, input_height, input_width = input[0].size() 42 | output_channels, output_height, output_width = output[0].size() 43 | 44 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 45 | bias_ops = 1 if self.bias is not None else 0 46 | 47 | params = output_channels * (kernel_ops + bias_ops) 48 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 49 | 50 | num_weight_params = (self.weight.data != 0).float().sum() 51 | assert self.weight.numel() == kernel_ops * output_channels, "Not match" 52 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 53 | 54 | list_conv.append(flops) 55 | 56 | list_linear=[] 57 | def linear_hook(self, input, output): 58 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 59 | 60 | weight_ops = self.weight.nelement() * (2 if
multiply_adds else 1) 61 | bias_ops = self.bias.nelement() 62 | 63 | flops = batch_size * (weight_ops + bias_ops) 64 | list_linear.append(flops) 65 | 66 | list_bn=[] 67 | def bn_hook(self, input, output): 68 | list_bn.append(input[0].nelement() * 2) 69 | 70 | list_relu=[] 71 | def relu_hook(self, input, output): 72 | list_relu.append(input[0].nelement()) 73 | 74 | list_pooling=[] 75 | def pooling_hook(self, input, output): 76 | batch_size, input_channels, input_height, input_width = input[0].size() 77 | output_channels, output_height, output_width = output[0].size() 78 | 79 | kernel_ops = self.kernel_size * self.kernel_size 80 | bias_ops = 0 81 | params = 0 82 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 83 | 84 | list_pooling.append(flops) 85 | 86 | list_upsample=[] 87 | # For bilinear upsample 88 | def upsample_hook(self, input, output): 89 | batch_size, input_channels, input_height, input_width = input[0].size() 90 | output_channels, output_height, output_width = output[0].size() 91 | 92 | flops = output_height * output_width * output_channels * batch_size * 12 93 | list_upsample.append(flops) 94 | 95 | def foo(net): 96 | childrens = list(net.children()) 97 | if not childrens: 98 | if isinstance(net, torch.nn.Conv2d): 99 | net.register_forward_hook(conv_hook) 100 | if isinstance(net, torch.nn.Linear): 101 | net.register_forward_hook(linear_hook) 102 | if isinstance(net, torch.nn.BatchNorm2d): 103 | net.register_forward_hook(bn_hook) 104 | if isinstance(net, torch.nn.ReLU): 105 | net.register_forward_hook(relu_hook) 106 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 107 | net.register_forward_hook(pooling_hook) 108 | if isinstance(net, torch.nn.Upsample): 109 | net.register_forward_hook(upsample_hook) 110 | return 111 | for c in childrens: 112 | foo(c) 113 | 114 | if model == None: 115 | model = torchvision.models.alexnet() 116 | foo(model) 117 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 118 | out = model(input) 119 | 120 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 121 | 122 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 123 | 124 | return total_flops -------------------------------------------------------------------------------- /cifar/soft-filter-pruning/utils.py: -------------------------------------------------------------------------------- 1 | import os, random, sys, time 2 | import numpy as np 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class RecorderMeter(object): 24 | """Computes and stores the minimum loss value and its epoch index""" 25 | def __init__(self, total_epoch): 26 | self.reset(total_epoch) 27 | 28 | def reset(self, total_epoch): 29 | assert total_epoch > 0 30 | self.total_epoch = total_epoch 31 | self.current_epoch = 0 32 | self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] 33 | self.epoch_losses = self.epoch_losses - 1 34 | 35 | self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] 36 | self.epoch_accuracy= 
self.epoch_accuracy 37 | 38 | def update(self, idx, train_loss, train_acc, val_loss, val_acc): 39 | assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx) 40 | self.epoch_losses [idx, 0] = train_loss 41 | self.epoch_losses [idx, 1] = val_loss 42 | self.epoch_accuracy[idx, 0] = train_acc 43 | self.epoch_accuracy[idx, 1] = val_acc 44 | self.current_epoch = idx + 1 45 | return self.max_accuracy(False) == val_acc 46 | 47 | def max_accuracy(self, istrain): 48 | if self.current_epoch <= 0: return 0 49 | if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max() 50 | else: return self.epoch_accuracy[:self.current_epoch, 1].max() 51 | 52 | def plot_curve(self, save_path): 53 | title = 'the accuracy/loss curve of train/val' 54 | dpi = 80 55 | width, height = 1200, 800 56 | legend_fontsize = 10 57 | scale_distance = 48.8 58 | figsize = width / float(dpi), height / float(dpi) 59 | 60 | fig = plt.figure(figsize=figsize) 61 | x_axis = np.array([i for i in range(self.total_epoch)]) # epochs 62 | y_axis = np.zeros(self.total_epoch) 63 | 64 | plt.xlim(0, self.total_epoch) 65 | plt.ylim(0, 100) 66 | interval_y = 5 67 | interval_x = 5 68 | plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x)) 69 | plt.yticks(np.arange(0, 100 + interval_y, interval_y)) 70 | plt.grid() 71 | plt.title(title, fontsize=20) 72 | plt.xlabel('the training epoch', fontsize=16) 73 | plt.ylabel('accuracy', fontsize=16) 74 | 75 | y_axis[:] = self.epoch_accuracy[:, 0] 76 | plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2) 77 | plt.legend(loc=4, fontsize=legend_fontsize) 78 | 79 | y_axis[:] = self.epoch_accuracy[:, 1] 80 | plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2) 81 | plt.legend(loc=4, fontsize=legend_fontsize) 82 | 83 | 84 | y_axis[:] = self.epoch_losses[:, 0] 85 | plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2) 86 | plt.legend(loc=4, fontsize=legend_fontsize) 87 | 88 | y_axis[:] = self.epoch_losses[:, 1] 89 | plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2) 90 | plt.legend(loc=4, fontsize=legend_fontsize) 91 | 92 | if save_path is not None: 93 | fig.savefig(save_path, dpi=dpi, bbox_inches='tight') 94 | print ('---- save figure {} into {}'.format(title, save_path)) 95 | plt.close(fig) 96 | 97 | 98 | def time_string(): 99 | ISOTIMEFORMAT='%Y-%m-%d %X' 100 | string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 101 | return string 102 | 103 | def convert_secs2time(epoch_time): 104 | need_hour = int(epoch_time / 3600) 105 | need_mins = int((epoch_time - 3600*need_hour) / 60) 106 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 107 | return need_hour, need_mins, need_secs 108 | 109 | def time_file_str(): 110 | ISOTIMEFORMAT='%Y-%m-%d' 111 | string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 112 | return string + '-{}'.format(random.randint(1, 10000)) 113 | -------------------------------------------------------------------------------- /cifar/weight-level/README.md: -------------------------------------------------------------------------------- 1 | # Non-Structured Pruning/Weight-Level Pruning 2 | 3 | This directory contains a pytorch implementation of the CIFAR experiments of non-structured pruning introduced in this [paper](https://arxiv.org/abs/1506.02626) (NIPS 2015). 
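At its core, the method zeroes the individual weights with the smallest magnitudes and keeps a binary mask so that they stay at zero during further training. A minimal sketch (a hypothetical helper, not the exact code in these scripts; the Implementation section below describes the real mask handling):
```python
import torch
import torch.nn as nn

def magnitude_prune(model, percent=0.3):
    """Zero the smallest `percent` fraction of conv weights and return the masks."""
    all_weights = torch.cat([m.weight.data.abs().view(-1)
                             for m in model.modules() if isinstance(m, nn.Conv2d)])
    k = max(1, int(all_weights.numel() * percent))
    threshold = all_weights.kthvalue(k)[0].item()   # global magnitude threshold
    masks = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            masks[name] = (m.weight.data.abs() > threshold).float()
            m.weight.data.mul_(masks[name])         # pruned weights become exactly 0
    return masks

# During fine-tuning, multiply each conv layer's weight gradient by its mask
# after every backward pass so the pruned weights are never updated.
```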
4 | 5 | ## Dependencies 6 | progress v1.3, torch v0.3.1, torchvision v0.2.0 7 | 8 | ## Implementation 9 | We prune only the weights in the convolutional layers, using the mask implementation sketched above: during pruning, the pruned weights are set to 0, and during training we make sure the pruned parameters are never updated. 10 | 11 | ## Baseline 12 | 13 | The `dataset` argument specifies which dataset to use: `cifar10` or `cifar100`. The `arch` argument specifies the architecture to use: `vgg19_bn`, `preresnet` or `densenet`. The depth is chosen to be the same as the networks used in the paper. 14 | ```shell 15 | python cifar.py --dataset cifar10 --arch vgg19_bn --depth 19 16 | python cifar.py --dataset cifar10 --arch preresnet --depth 110 17 | python cifar.py --dataset cifar10 --arch densenet --depth 40 18 | python cifar.py --dataset cifar10 --arch densenet --depth 100 --compressionRate 2 19 | ``` 20 | 21 | ## Prune 22 | 23 | ```shell 24 | python cifar_prune.py --arch vgg19_bn --depth 19 --dataset cifar10 --percent 0.3 --resume [PATH TO THE MODEL] --save_dir [DIRECTORY TO STORE RESULT] 25 | python cifar_prune.py --arch preresnet --depth 110 --dataset cifar10 --percent 0.3 --resume [PATH TO THE MODEL] --save_dir [DIRECTORY TO STORE RESULT] 26 | python cifar_prune.py --arch densenet --depth 40 --dataset cifar10 --percent 0.3 --resume [PATH TO THE MODEL] --save_dir [DIRECTORY TO STORE RESULT] 27 | python cifar_prune.py --arch densenet --depth 100 --compressionRate 2 --dataset cifar10 --percent 0.3 --resume [PATH TO THE MODEL] --save_dir [DIRECTORY TO STORE RESULT] 28 | ``` 29 | 30 | 31 | ## Fine-tune 32 | ```shell 33 | python cifar_finetune.py --arch vgg19_bn --depth 19 --dataset cifar10 --resume [PATH TO THE PRUNED MODEL] 34 | python cifar_finetune.py --arch preresnet --depth 110 --dataset cifar10 --resume [PATH TO THE PRUNED MODEL] 35 | python cifar_finetune.py --arch densenet --depth 40 --dataset cifar10 --resume [PATH TO THE PRUNED MODEL] 36 | python cifar_finetune.py --arch densenet --depth 100 --compressionRate 2 --dataset cifar10 --resume [PATH TO THE PRUNED MODEL] 37 | ``` 38 | 39 | ## Scratch-E 40 | ```shell 41 | python cifar_E.py --arch vgg19_bn --depth 19 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 42 | python cifar_E.py --arch preresnet --depth 110 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 43 | python cifar_E.py --arch densenet --depth 40 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 44 | python cifar_E.py --arch densenet --depth 100 --compressionRate 2 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 45 | ``` 46 | 47 | ## Scratch-B 48 | ```shell 49 | python cifar_B.py --arch vgg19_bn --depth 19 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 50 | python cifar_B.py --arch preresnet --depth 110 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 51 | python cifar_B.py --arch densenet --depth 40 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 52 | python cifar_B.py --arch densenet --depth 100 --compressionRate 2 --dataset cifar10 --scratch [PATH TO THE PRUNED MODEL] 53 | ``` 54 | 55 | -------------------------------------------------------------------------------- /cifar/weight-level/count_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools.
2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | import models.cifar as models 11 | 12 | 13 | def print_model_param_nums(model=None): 14 | if model == None: 15 | model = torchvision.models.alexnet() 16 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 17 | for m in model.modules(): 18 | if isinstance(m, nn.Conv2d): 19 | total -= (m.weight.data == 0).sum() 20 | print(' + Number of params: %.2fM' % (total / 1e6)) 21 | return total 22 | 23 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 24 | 25 | prods = {} 26 | def save_hook(name): 27 | def hook_per(self, input, output): 28 | prods[name] = np.prod(input[0].shape) 29 | return hook_per 30 | 31 | list_1=[] 32 | def simple_hook(self, input, output): 33 | list_1.append(np.prod(input[0].shape)) 34 | list_2={} 35 | def simple_hook2(self, input, output): 36 | list_2['names'] = np.prod(input[0].shape) 37 | 38 | 39 | list_conv=[] 40 | def conv_hook(self, input, output): 41 | batch_size, input_channels, input_height, input_width = input[0].size() 42 | output_channels, output_height, output_width = output[0].size() 43 | 44 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 45 | bias_ops = 1 if self.bias is not None else 0 46 | 47 | params = output_channels * (kernel_ops + bias_ops) 48 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 49 | 50 | num_weight_params = (self.weight.data != 0).float().sum() 51 | assert self.weight.numel() == kernel_ops * output_channels, "Not match" 52 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 53 | 54 | list_conv.append(flops) 55 | 56 | list_linear=[] 57 | def linear_hook(self, input, output): 58 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 59 | 60 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 61 | bias_ops = self.bias.nelement() 62 | 63 | flops = batch_size * (weight_ops + bias_ops) 64 | list_linear.append(flops) 65 | 66 | list_bn=[] 67 | def bn_hook(self, input, output): 68 | list_bn.append(input[0].nelement() * 2) 69 | 70 | list_relu=[] 71 | def relu_hook(self, input, output): 72 | list_relu.append(input[0].nelement()) 73 | 74 | list_pooling=[] 75 | def pooling_hook(self, input, output): 76 | batch_size, input_channels, input_height, input_width = input[0].size() 77 | output_channels, output_height, output_width = output[0].size() 78 | 79 | kernel_ops = self.kernel_size * self.kernel_size 80 | bias_ops = 0 81 | params = 0 82 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 83 | 84 | list_pooling.append(flops) 85 | 86 | list_upsample=[] 87 | # For bilinear upsample 88 | def upsample_hook(self, input, output): 89 | batch_size, input_channels, input_height, input_width = input[0].size() 90 | output_channels, output_height, output_width = output[0].size() 91 | 92 | flops = output_height * output_width * output_channels * batch_size * 12 93 | list_upsample.append(flops) 94 | 95 | def foo(net): 96 | childrens = list(net.children()) 97 | if not childrens: 98 | if isinstance(net, torch.nn.Conv2d): 99 | net.register_forward_hook(conv_hook) 100 | if isinstance(net, torch.nn.Linear): 101 | net.register_forward_hook(linear_hook) 102 | if isinstance(net, 
torch.nn.BatchNorm2d): 103 | net.register_forward_hook(bn_hook) 104 | if isinstance(net, torch.nn.ReLU): 105 | net.register_forward_hook(relu_hook) 106 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 107 | net.register_forward_hook(pooling_hook) 108 | if isinstance(net, torch.nn.Upsample): 109 | net.register_forward_hook(upsample_hook) 110 | return 111 | for c in childrens: 112 | foo(c) 113 | 114 | if model == None: 115 | model = torchvision.models.alexnet() 116 | foo(model) 117 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 118 | out = model(input) 119 | 120 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 121 | 122 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 123 | 124 | return total_flops -------------------------------------------------------------------------------- /cifar/weight-level/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eric-mingjie/rethinking-network-pruning/2ac473d70a09810df888e932bb394f225f9ed2d1/cifar/weight-level/models/__init__.py -------------------------------------------------------------------------------- /cifar/weight-level/models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | """The models subpackage contains definitions for the following model for CIFAR10/CIFAR100 4 | architectures: 5 | 6 | - `AlexNet`_ 7 | - `VGG`_ 8 | - `ResNet`_ 9 | - `SqueezeNet`_ 10 | - `DenseNet`_ 11 | 12 | You can construct a model with random weights by calling its constructor: 13 | 14 | .. code:: python 15 | 16 | import torchvision.models as models 17 | resnet18 = models.resnet18() 18 | alexnet = models.alexnet() 19 | squeezenet = models.squeezenet1_0() 20 | densenet = models.densenet_161() 21 | 22 | We provide pre-trained models for the ResNet variants and AlexNet, using the 23 | PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing 24 | ``pretrained=True``: 25 | 26 | .. code:: python 27 | 28 | import torchvision.models as models 29 | resnet18 = models.resnet18(pretrained=True) 30 | alexnet = models.alexnet(pretrained=True) 31 | 32 | ImageNet 1-crop error rates (224x224) 33 | 34 | ======================== ============= ============= 35 | Network Top-1 error Top-5 error 36 | ======================== ============= ============= 37 | ResNet-18 30.24 10.92 38 | ResNet-34 26.70 8.58 39 | ResNet-50 23.85 7.13 40 | ResNet-101 22.63 6.44 41 | ResNet-152 21.69 5.94 42 | Inception v3 22.55 6.44 43 | AlexNet 43.45 20.91 44 | VGG-11 30.98 11.37 45 | VGG-13 30.07 10.75 46 | VGG-16 28.41 9.62 47 | VGG-19 27.62 9.12 48 | SqueezeNet 1.0 41.90 19.58 49 | SqueezeNet 1.1 41.81 19.38 50 | Densenet-121 25.35 7.83 51 | Densenet-169 24.00 7.00 52 | Densenet-201 22.80 6.43 53 | Densenet-161 22.35 6.20 54 | ======================== ============= ============= 55 | 56 | 57 | .. _AlexNet: https://arxiv.org/abs/1404.5997 58 | .. _VGG: https://arxiv.org/abs/1409.1556 59 | .. _ResNet: https://arxiv.org/abs/1512.03385 60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360 61 | .. 
_DenseNet: https://arxiv.org/abs/1608.06993 62 | """ 63 | 64 | from .alexnet import * 65 | from .vgg import * 66 | from .resnet import * 67 | from .preresnet import * 68 | from .densenet import * 69 | -------------------------------------------------------------------------------- /cifar/weight-level/models/cifar/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | __all__ = ['alexnet'] 5 | 6 | class AlexNet(nn.Module): 7 | 8 | def __init__(self, num_classes=10): 9 | super(AlexNet, self).__init__() 10 | self.features = nn.Sequential( 11 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 12 | nn.ReLU(inplace=True), 13 | nn.MaxPool2d(kernel_size=2, stride=2), 14 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 15 | nn.ReLU(inplace=True), 16 | nn.MaxPool2d(kernel_size=2, stride=2), 17 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=2, stride=2), 24 | ) 25 | self.classifier = nn.Linear(256, num_classes) 26 | 27 | def forward(self, x): 28 | x = self.features(x) 29 | x = x.view(x.size(0), -1) 30 | x = self.classifier(x) 31 | return x 32 | 33 | def alexnet(**kwargs): 34 | r"""AlexNet model architecture from the 35 | `"One weird trick..." `_ paper. 36 | """ 37 | model = AlexNet(**kwargs) 38 | return model -------------------------------------------------------------------------------- /cifar/weight-level/models/cifar/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | __all__ = ['densenet'] 9 | 10 | class Bottleneck(nn.Module): 11 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 12 | super(Bottleneck, self).__init__() 13 | planes = expansion * growthRate 14 | self.bn1 = nn.BatchNorm2d(inplanes) 15 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 18 | padding=1, bias=False) 19 | self.relu = nn.ReLU(inplace=True) 20 | self.dropRate = dropRate 21 | 22 | def forward(self, x): 23 | out = self.bn1(x) 24 | out = self.relu(out) 25 | out = self.conv1(out) 26 | out = self.bn2(out) 27 | out = self.relu(out) 28 | out = self.conv2(out) 29 | if self.dropRate > 0: 30 | out = F.dropout(out, p=self.dropRate, training=self.training) 31 | 32 | out = torch.cat((x, out), 1) 33 | 34 | return out 35 | 36 | class BasicBlock(nn.Module): 37 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 38 | super(BasicBlock, self).__init__() 39 | planes = expansion * growthRate 40 | self.bn1 = nn.BatchNorm2d(inplanes) 41 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 42 | padding=1, bias=False) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.dropRate = dropRate 45 | 46 | def forward(self, x): 47 | out = self.bn1(x) 48 | out = self.relu(out) 49 | out = self.conv1(out) 50 | if self.dropRate > 0: 51 | out = F.dropout(out, p=self.dropRate, training=self.training) 52 | 53 | out = torch.cat((x, out), 1) 54 | 55 | return out 56 | 57 | class Transition(nn.Module): 58 | def __init__(self, inplanes, outplanes): 59 | super(Transition, self).__init__() 60 | self.bn1 = 
nn.BatchNorm2d(inplanes) 61 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 62 | bias=False) 63 | self.relu = nn.ReLU(inplace=True) 64 | 65 | def forward(self, x): 66 | out = self.bn1(x) 67 | out = self.relu(out) 68 | out = self.conv1(out) 69 | out = F.avg_pool2d(out, 2) 70 | return out 71 | 72 | class DenseNet(nn.Module): 73 | 74 | def __init__(self, depth=22, block=BasicBlock, 75 | dropRate=0, num_classes=10, growthRate=12, compressionRate=1): 76 | super(DenseNet, self).__init__() 77 | 78 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 79 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 80 | n = int(n) 81 | 82 | self.growthRate = growthRate 83 | self.dropRate = dropRate 84 | 85 | # self.inplanes is a global variable used across multiple 86 | # helper functions 87 | self.inplanes = growthRate * 2 88 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 89 | bias=False) 90 | self.dense1 = self._make_denseblock(block, n) 91 | self.trans1 = self._make_transition(compressionRate) 92 | self.dense2 = self._make_denseblock(block, n) 93 | self.trans2 = self._make_transition(compressionRate) 94 | self.dense3 = self._make_denseblock(block, n) 95 | self.bn = nn.BatchNorm2d(self.inplanes) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.avgpool = nn.AvgPool2d(8) 98 | self.fc = nn.Linear(self.inplanes, num_classes) 99 | 100 | # Weight initialization 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. / n)) 105 | elif isinstance(m, nn.BatchNorm2d): 106 | m.weight.data.fill_(1) 107 | m.bias.data.zero_() 108 | 109 | def _make_denseblock(self, block, blocks): 110 | layers = [] 111 | for i in range(blocks): 112 | # Currently we fix the expansion ratio as the default value 113 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 114 | self.inplanes += self.growthRate 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def _make_transition(self, compressionRate): 119 | inplanes = self.inplanes 120 | outplanes = int(math.floor(self.inplanes // compressionRate)) 121 | self.inplanes = outplanes 122 | return Transition(inplanes, outplanes) 123 | 124 | 125 | def forward(self, x): 126 | x = self.conv1(x) 127 | 128 | x = self.trans1(self.dense1(x)) 129 | x = self.trans2(self.dense2(x)) 130 | x = self.dense3(x) 131 | x = self.bn(x) 132 | x = self.relu(x) 133 | 134 | x = self.avgpool(x) 135 | x = x.view(x.size(0), -1) 136 | x = self.fc(x) 137 | 138 | return x 139 | 140 | def densenet(**kwargs): 141 | """ 142 | Constructs a ResNet model. 
143 | """ 144 | return DenseNet(**kwargs) -------------------------------------------------------------------------------- /cifar/weight-level/models/cifar/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | 7 | __all__ = ['preresnet'] 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.bn1 = nn.BatchNorm2d(inplanes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.bn1(x) 31 | out = self.relu(out) 32 | out = self.conv1(out) 33 | 34 | out = self.bn2(out) 35 | out = self.relu(out) 36 | out = self.conv2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | 43 | return out 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(inplanes) 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 54 | padding=1, bias=False) 55 | self.bn3 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.bn1(x) 65 | out = self.relu(out) 66 | out = self.conv1(out) 67 | 68 | out = self.bn2(out) 69 | out = self.relu(out) 70 | out = self.conv2(out) 71 | 72 | out = self.bn3(out) 73 | out = self.relu(out) 74 | out = self.conv3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | 81 | return out 82 | 83 | class PreResNet(nn.Module): 84 | 85 | def __init__(self, depth, num_classes=10): 86 | super(PreResNet, self).__init__() 87 | # Model type specifies number of layers for CIFAR-10 model 88 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 89 | n = (depth - 2) // 9 90 | 91 | block = Bottleneck if depth >=44 else BasicBlock 92 | 93 | self.inplanes = 16 94 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 95 | bias=False) 96 | self.layer1 = self._make_layer(block, 16, n) 97 | self.layer2 = self._make_layer(block, 32, n, stride=2) 98 | self.layer3 = self._make_layer(block, 64, n, stride=2) 99 | self.bn = nn.BatchNorm2d(64 * block.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.avgpool = nn.AvgPool2d(8) 102 | self.fc = nn.Linear(64 * block.expansion, num_classes) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_layer(self, block, planes, blocks, stride=1): 113 | downsample = None 114 | if stride != 1 or self.inplanes != planes * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d(self.inplanes, planes * block.expansion, 117 | kernel_size=1, stride=stride, bias=False), 118 | ) 119 | 120 | layers = [] 121 | layers.append(block(self.inplanes, planes, stride, downsample)) 122 | self.inplanes = planes * block.expansion 123 | for i in range(1, blocks): 124 | layers.append(block(self.inplanes, planes)) 125 | 126 | return nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | x = self.conv1(x) 130 | 131 | x = self.layer1(x) # 32x32 132 | x = self.layer2(x) # 16x16 133 | x = self.layer3(x) # 8x8 134 | x = self.bn(x) 135 | x = self.relu(x) 136 | 137 | x = self.avgpool(x) 138 | x = x.view(x.size(0), -1) 139 | x = self.fc(x) 140 | 141 | return x 142 | 143 | def preresnet(**kwargs): 144 | """ 145 | Constructs a ResNet model. 146 | """ 147 | return PreResNet(**kwargs) 148 | -------------------------------------------------------------------------------- /cifar/weight-level/models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | 7 | __all__ = ['resnet'] 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | class ResNet(nn.Module): 84 | 85 | def __init__(self, depth, 
num_classes=1000): 86 | super(ResNet, self).__init__() 87 | # Model type specifies number of layers for CIFAR-10 model 88 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 89 | n = (depth - 2) // 6 90 | 91 | block = Bottleneck if depth >=54 else BasicBlock 92 | 93 | self.inplanes = 16 94 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 95 | bias=False) 96 | self.bn1 = nn.BatchNorm2d(16) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.layer1 = self._make_layer(block, 16, n) 99 | self.layer2 = self._make_layer(block, 32, n, stride=2) 100 | self.layer3 = self._make_layer(block, 64, n, stride=2) 101 | self.avgpool = nn.AvgPool2d(8) 102 | self.fc = nn.Linear(64 * block.expansion, num_classes) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. / n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_layer(self, block, planes, blocks, stride=1): 113 | downsample = None 114 | if stride != 1 or self.inplanes != planes * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d(self.inplanes, planes * block.expansion, 117 | kernel_size=1, stride=stride, bias=False), 118 | nn.BatchNorm2d(planes * block.expansion), 119 | ) 120 | 121 | layers = [] 122 | layers.append(block(self.inplanes, planes, stride, downsample)) 123 | self.inplanes = planes * block.expansion 124 | for i in range(1, blocks): 125 | layers.append(block(self.inplanes, planes)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) # 32x32 133 | 134 | x = self.layer1(x) # 32x32 135 | x = self.layer2(x) # 16x16 136 | x = self.layer3(x) # 8x8 137 | 138 | x = self.avgpool(x) 139 | x = x.view(x.size(0), -1) 140 | x = self.fc(x) 141 | 142 | return x 143 | 144 | def resnet(**kwargs): 145 | """ 146 | Constructs a ResNet model. 147 | """ 148 | return ResNet(**kwargs) -------------------------------------------------------------------------------- /cifar/weight-level/models/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, features, num_classes=1000): 26 | super(VGG, self).__init__() 27 | self.features = features 28 | self.classifier = nn.Linear(512, num_classes) 29 | self._initialize_weights() 30 | 31 | def forward(self, x): 32 | x = self.features(x) 33 | x = nn.AvgPool2d(2)(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | def _initialize_weights(self): 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels) 42 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.BatchNorm2d): 46 | m.weight.data.fill_(1) 47 | m.bias.data.zero_() 48 | elif isinstance(m, nn.Linear): 49 | n = m.weight.size(1) 50 | m.weight.data.normal_(0, 0.01) 51 | m.bias.data.zero_() 52 | 53 | 54 | def make_layers(cfg, batch_norm=False): 55 | layers = [] 56 | in_channels = 3 57 | for v in cfg: 58 | if v == 'M': 59 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 60 | else: 61 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 62 | if batch_norm: 63 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 64 | else: 65 | layers += [conv2d, nn.ReLU(inplace=True)] 66 | in_channels = v 67 | return nn.Sequential(*layers) 68 | 69 | 70 | cfg = { 71 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 72 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 73 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 74 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 75 | # 'E': [64, 128, 'M', 128, 256, 'M', 64, 128, 256, 512, 1024, 'M', 64, 128, 256, 512, 1024, 2048,'M',256, 512, 1024, 512,'M'] 76 | } 77 | 78 | 79 | def vgg11(**kwargs): 80 | """VGG 11-layer model (configuration "A") 81 | 82 | Args: 83 | pretrained (bool): If True, returns a model pre-trained on ImageNet 84 | """ 85 | model = VGG(make_layers(cfg['A']), **kwargs) 86 | return model 87 | 88 | 89 | def vgg11_bn(**kwargs): 90 | """VGG 11-layer model (configuration "A") with batch normalization""" 91 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 92 | return model 93 | 94 | 95 | def vgg13(**kwargs): 96 | """VGG 13-layer model (configuration "B") 97 | 98 | Args: 99 | pretrained (bool): If True, returns a model pre-trained on ImageNet 100 | """ 101 | model = VGG(make_layers(cfg['B']), **kwargs) 102 | return model 103 | 104 | 105 | def vgg13_bn(**kwargs): 106 | """VGG 13-layer model (configuration "B") with batch normalization""" 107 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 108 | return model 109 | 110 | 111 | def vgg16(**kwargs): 112 | """VGG 16-layer model (configuration "D") 113 | 114 | Args: 115 | pretrained (bool): If True, returns a model pre-trained on ImageNet 116 | """ 117 | model = VGG(make_layers(cfg['D']), **kwargs) 118 | return model 119 | 120 | 121 | def vgg16_bn(**kwargs): 122 | """VGG 16-layer model (configuration "D") with batch normalization""" 123 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 124 | return model 125 | 126 | 127 | def vgg19(**kwargs): 128 | """VGG 19-layer model (configuration "E") 129 | 130 | Args: 131 | pretrained (bool): If True, returns a model pre-trained on ImageNet 132 | """ 133 | model = VGG(make_layers(cfg['E']), **kwargs) 134 | return model 135 | 136 | 137 | def vgg19_bn(**kwargs): 138 | """VGG 19-layer model (configuration 'E') with batch normalization""" 139 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 140 | return model 141 | -------------------------------------------------------------------------------- /cifar/weight-level/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from 
progress.bar import Bar as Bar -------------------------------------------------------------------------------- /cifar/weight-level/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | 4 | __all__ = ['accuracy'] 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | """Computes the precision@k for the specified values of k""" 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | res = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(0) 18 | res.append(correct_k.mul_(100.0 / batch_size)) 19 | return res -------------------------------------------------------------------------------- /cifar/weight-level/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | import sys 6 | 7 | 8 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 9 | 10 | def savefig(fname, dpi=None): 11 | dpi = 150 if dpi == None else dpi 12 | plt.savefig(fname, dpi=dpi) 13 | 14 | def plot_overlap(logger, names=None): 15 | names = logger.names if names == None else names 16 | numbers = logger.numbers 17 | for _, name in enumerate(names): 18 | x = np.arange(len(numbers[name])) 19 | plt.plot(x, np.asarray(numbers[name])) 20 | return [logger.title + '(' + name + ')' for name in names] 21 | 22 | class Logger(object): 23 | '''Save training process to log file with simple plot function.''' 24 | def __init__(self, fpath, title=None, resume=False): 25 | self.file = None 26 | self.resume = resume 27 | self.title = '' if title == None else title 28 | if fpath is not None: 29 | if resume: 30 | self.file = open(fpath, 'r') 31 | name = self.file.readline() 32 | self.names = name.rstrip().split('\t') 33 | self.numbers = {} 34 | for _, name in enumerate(self.names): 35 | self.numbers[name] = [] 36 | 37 | for numbers in self.file: 38 | numbers = numbers.rstrip().split('\t') 39 | for i in range(0, len(numbers)): 40 | self.numbers[self.names[i]].append(numbers[i]) 41 | self.file.close() 42 | self.file = open(fpath, 'a') 43 | else: 44 | self.file = open(fpath, 'w') 45 | 46 | def set_names(self, names): 47 | if self.resume: 48 | pass 49 | # initialize numbers as empty list 50 | self.numbers = {} 51 | self.names = names 52 | for _, name in enumerate(self.names): 53 | self.file.write(name) 54 | self.file.write('\t') 55 | self.numbers[name] = [] 56 | self.file.write('\n') 57 | self.file.flush() 58 | 59 | 60 | def append(self, numbers): 61 | assert len(self.names) == len(numbers), 'Numbers do not match names' 62 | for index, num in enumerate(numbers): 63 | self.file.write("{0:.6f}".format(num)) 64 | self.file.write('\t') 65 | self.numbers[self.names[index]].append(num) 66 | self.file.write('\n') 67 | self.file.flush() 68 | 69 | def plot(self, names=None): 70 | names = self.names if names == None else names 71 | numbers = self.numbers 72 | for _, name in enumerate(names): 73 | x = np.arange(len(numbers[name])) 74 | plt.plot(x, np.asarray(numbers[name])) 75 | plt.legend([self.title + '(' + name + ')' for name in names]) 76 | plt.grid(True) 77 | 78 | def close(self): 79 | if self.file is not None: 80 | self.file.close() 81 | 82 | class LoggerMonitor(object): 83 | '''Load and visualize multiple logs.''' 84 | def __init__ 
(self, paths): 85 | '''paths is a dictionary with {name: filepath} pairs''' 86 | self.loggers = [] 87 | for title, path in paths.items(): 88 | logger = Logger(path, title=title, resume=True) 89 | self.loggers.append(logger) 90 | 91 | def plot(self, names=None): 92 | plt.figure() 93 | plt.subplot(121) 94 | legend_text = [] 95 | for logger in self.loggers: 96 | legend_text += plot_overlap(logger, names) 97 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 98 | plt.grid(True) 99 | 100 | if __name__ == '__main__': 101 | # # Example 102 | # logger = Logger('test.txt') 103 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 104 | 105 | # length = 100 106 | # t = np.arange(length) 107 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 108 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | 111 | # for i in range(0, length): 112 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 113 | # logger.plot() 114 | 115 | # Example: logger monitor 116 | paths = { 117 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 118 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 119 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 120 | } 121 | 122 | field = ['Valid Acc.'] 123 | 124 | monitor = LoggerMonitor(paths) 125 | monitor.plot(names=field) 126 | savefig('test.eps') -------------------------------------------------------------------------------- /cifar/weight-level/utils/misc.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import math 3 | import os 4 | import sys 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | from torch.autograd import Variable 11 | 12 | 13 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 14 | 15 | def get_mean_and_std(dataset): 16 | '''Compute the mean and std value of dataset.''' 17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 18 | 19 | mean = torch.zeros(3) 20 | std = torch.zeros(3) 21 | print('==> Computing mean and std..') 22 | for inputs, targets in dataloader: 23 | for i in range(3): 24 | mean[i] += inputs[:,i,:,:].mean() 25 | std[i] += inputs[:,i,:,:].std() 26 | mean.div_(len(dataset)) 27 | std.div_(len(dataset)) 28 | return mean, std 29 | 30 | def get_conv_zero_param(model): 31 | total = 0 32 | for m in model.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | total += torch.sum(m.weight.data.eq(0)) 35 | return total 36 | 37 | def init_params(net): 38 | '''Init layer parameters.''' 39 | for m in net.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | init.kaiming_normal(m.weight, mode='fan_out') 42 | if m.bias is not None: 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.BatchNorm2d): 45 | init.constant(m.weight, 1) 46 | init.constant(m.bias, 0) 47 | elif isinstance(m, nn.Linear): 48 | init.normal(m.weight, std=1e-3) 49 | if m.bias is not None: 50 | init.constant(m.bias, 0) 51 | 52 | def mkdir_p(path): 53 | '''make dir if it does not exist''' 54 | try: 55 | os.makedirs(path) 56 | except OSError as exc: # Python >2.5 57 | if exc.errno == errno.EEXIST and os.path.isdir(path): 58 | pass 59 | else: 60 | raise 61 | 62 | class AverageMeter(object): 63 | """Computes and stores the average and current value 64 | Imported from
https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 65 | """ 66 | def __init__(self): 67 | self.reset() 68 | 69 | def reset(self): 70 | self.val = 0 71 | self.avg = 0 72 | self.sum = 0 73 | self.count = 0 74 | 75 | def update(self, val, n=1): 76 | self.val = val 77 | self.sum += val * n 78 | self.count += n 79 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /cifar/weight-level/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | 9 | from .misc import * 10 | 11 | 12 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 13 | 14 | # functions to show an image 15 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 16 | for i in range(0, 3): 17 | img[i] = img[i] * std[i] + mean[i] # unnormalize 18 | npimg = img.numpy() 19 | return np.transpose(npimg, (1, 2, 0)) 20 | 21 | def gauss(x,a,b,c): 22 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 23 | 24 | def colorize(x): 25 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 26 | if x.dim() == 2: 27 | torch.unsqueeze(x, 0, out=x) 28 | if x.dim() == 3: 29 | cl = torch.zeros([3, x.size(1), x.size(2)]) 30 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 31 | cl[1] = gauss(x,1,.5,.3) 32 | cl[2] = gauss(x,1,.2,.3) 33 | cl[cl.gt(1)] = 1 34 | elif x.dim() == 4: 35 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 36 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 37 | cl[:,1,:,:] = gauss(x,1,.5,.3) 38 | cl[:,2,:,:] = gauss(x,1,.2,.3) 39 | return cl 40 | 41 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 42 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 43 | plt.imshow(images) 44 | plt.show() 45 | 46 | 47 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 48 | im_size = images.size(2) 49 | 50 | # save for adding mask 51 | im_data = images.clone() 52 | for i in range(0, 3): 53 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 54 | 55 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 56 | plt.subplot(2, 1, 1) 57 | plt.imshow(images) 58 | plt.axis('off') 59 | 60 | # for b in range(mask.size(0)): 61 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 62 | mask_size = mask.size(2) 63 | # print('Max %f Min %f' % (mask.max(), mask.min())) 64 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 65 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 66 | # for c in range(3): 67 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 68 | 69 | # print(mask.size()) 70 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 71 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 72 | plt.subplot(2, 1, 2) 73 | plt.imshow(mask) 74 | plt.axis('off') 75 | 76 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 77 | im_size = images.size(2) 78 | 79 | # save for adding mask 80 | im_data = images.clone() 81 | for i in range(0, 3): 82 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 83 | 84 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 85 | plt.subplot(1+len(masklist), 1, 1) 86 | plt.imshow(images) 87 | plt.axis('off') 88 | 89 | 
for i in range(len(masklist)): 90 | mask = masklist[i].data.cpu() 91 | # for b in range(mask.size(0)): 92 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 93 | mask_size = mask.size(2) 94 | # print('Max %f Min %f' % (mask.max(), mask.min())) 95 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 96 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 97 | # for c in range(3): 98 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 99 | 100 | # print(mask.size()) 101 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 102 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 103 | plt.subplot(1+len(masklist), 1, i+2) 104 | plt.imshow(mask) 105 | plt.axis('off') 106 | 107 | 108 | 109 | # x = torch.zeros(1, 3, 3) 110 | # out = colorize(x) 111 | # out_im = make_image(out) 112 | # plt.imshow(out_im) 113 | # plt.show() -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | # ImageNet 2 | This directory contains a pytorch implementation of the ImageNet experiments for six pruning methods: 3 | 4 | 1. [L1-norm based channel pruning](https://arxiv.org/abs/1608.08710) 5 | 2. [ThiNet](https://arxiv.org/abs/1707.06342) 6 | 3. [Regression based feature reconstruction](https://arxiv.org/abs/1707.06168) 7 | 4. [Network Slimming](https://arxiv.org/abs/1708.06519) 8 | 5. [Sparse Structure Selection](https://arxiv.org/abs/1707.01213) 9 | 6. [Non-structured weight-level pruning](https://arxiv.org/abs/1506.02626) 10 | 11 | ## Implementation 12 | We use the [official Pytorch ImageNet training code](https://github.com/pytorch/examples/blob/0.3.1/imagenet/main.py). 13 | 14 | ## Dependencies 15 | torch v0.3.1, torchvision v0.2.0 16 | 17 | -------------------------------------------------------------------------------- /imagenet/l1-norm-pruning/README.md: -------------------------------------------------------------------------------- 1 | # Pruning Filters for Efficient Convnets 2 | 3 | ## Baseline 4 | We get the ResNet-34 baseline model from Pytorch model zoo. 5 | 6 | ## Prune 7 | ``` 8 | python prune.py -v A --save [PATH TO SAVE RESULTS] [IMAGENET] 9 | python prune.py -v B --save [PATH TO SAVE RESULTS] [IMAGENET] 10 | ``` 11 | Here `-v` specifies the pruned model: ResNet-34-A or ResNet-34-B. 
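For reference, the selection criterion behind both variants is the filters' L1 norms. Below is a minimal sketch of that ranking step (a hypothetical helper for illustration, not this repo's exact code):

```python
import torch

def smallest_filters_by_l1(conv_weight, num_prune):
    # conv_weight: the (out_channels, in_channels, kH, kW) weight tensor of
    # a Conv2d layer. Each output filter is scored by the L1 norm of its
    # weights; the num_prune lowest-scoring filters are the ones removed.
    l1_norms = conv_weight.abs().view(conv_weight.size(0), -1).sum(1)
    _, order = l1_norms.sort()
    return order[:num_prune]
```

Variants A and B differ only in which layers are pruned and how aggressively, following the configurations in the paper.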
12 | 13 | ## Finetune 14 | ``` 15 | python main_finetune.py --arch resnet34 --refine [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 16 | ``` 17 | 18 | ## Scratch-E 19 | ``` 20 | python main_E.py --arch resnet34 --scratch [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 21 | ``` 22 | 23 | ## Scratch-B 24 | ``` 25 | python main_B.py --arch resnet34 --scratch [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 26 | ``` 27 | 28 | ## Models 29 | Network|Training method|Top-1|Top-5|Download 30 | :---:|:---:|:---:|:---:|:---: 31 | ResNet34-A|finetune| 72.56| 90.99| [pytorch model (154 MB)](https://drive.google.com/open?id=1EmxoTa0kCHSsFBpL8jEZ5CVF6_4bkxGM) 32 | ResNet34-A|scratch-E| 72.77| 91.20| [pytorch model (154 MB)](https://drive.google.com/open?id=1f-x3XHBFpCbUM5Y3cuH1_J9lN4CcbZ_0) 33 | ResNet34-A|scratch-B| 73.08| 91.29| [pytorch model (154 MB)](https://drive.google.com/open?id=1fQT68PATrGk9zt6HXgCzTNr-zsS9aLeq) 34 | ResNet34-B|finetune| 72.29| 90.72| [pytorch model (149 MB)](https://drive.google.com/open?id=1pW05JPHAPd_-bR862CmQPsP_q1_jsRwG) 35 | ResNet34-B|scratch-E| 72.55| 91.07| [pytorch model (149 MB)](https://drive.google.com/open?id=1YPcKrh1ctxUYsn2Yk-D4cuW7DVV1hX4-) 36 | ResNet34-B|scratch-B| 72.84| 91.19| [pytorch model (149 MB)](https://drive.google.com/open?id=1f_Nl-bcxBdhp3R2bY1nby4PIbwHOTaUv) 37 | -------------------------------------------------------------------------------- /imagenet/l1-norm-pruning/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | 11 | def print_model_param_nums(model=None): 12 | if model == None: 13 | model = torchvision.models.alexnet() 14 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 15 | print(' + Number of params: %.4fM' % (total / 1e6)) 16 | 17 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 18 | 19 | prods = {} 20 | def save_hook(name): 21 | def hook_per(self, input, output): 22 | prods[name] = np.prod(input[0].shape) 23 | return hook_per 24 | 25 | list_1=[] 26 | def simple_hook(self, input, output): 27 | list_1.append(np.prod(input[0].shape)) 28 | list_2={} 29 | def simple_hook2(self, input, output): 30 | list_2['names'] = np.prod(input[0].shape) 31 | 32 | 33 | list_conv=[] 34 | def conv_hook(self, input, output): 35 | batch_size, input_channels, input_height, input_width = input[0].size() 36 | output_channels, output_height, output_width = output[0].size() 37 | 38 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 39 | bias_ops = 1 if self.bias is not None else 0 40 | 41 | params = output_channels * (kernel_ops + bias_ops) 42 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 43 | 44 | num_weight_params = (self.weight.data != 0).float().sum() 45 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 46 | 47 | list_conv.append(flops) 48 | 49 | list_linear=[] 50 | def linear_hook(self, input, output): 51 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 52 | 53 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 54 | bias_ops = 
self.bias.nelement() 55 | 56 | flops = batch_size * (weight_ops + bias_ops) 57 | list_linear.append(flops) 58 | 59 | list_bn=[] 60 | def bn_hook(self, input, output): 61 | list_bn.append(input[0].nelement() * 2) 62 | 63 | list_relu=[] 64 | def relu_hook(self, input, output): 65 | list_relu.append(input[0].nelement()) 66 | 67 | list_pooling=[] 68 | def pooling_hook(self, input, output): 69 | batch_size, input_channels, input_height, input_width = input[0].size() 70 | output_channels, output_height, output_width = output[0].size() 71 | 72 | kernel_ops = self.kernel_size * self.kernel_size 73 | bias_ops = 0 74 | params = 0 75 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 76 | 77 | list_pooling.append(flops) 78 | 79 | list_upsample=[] 80 | 81 | # For bilinear upsample 82 | def upsample_hook(self, input, output): 83 | batch_size, input_channels, input_height, input_width = input[0].size() 84 | output_channels, output_height, output_width = output[0].size() 85 | 86 | flops = output_height * output_width * output_channels * batch_size * 12 87 | list_upsample.append(flops) 88 | 89 | def foo(net): 90 | childrens = list(net.children()) 91 | if not childrens: 92 | if isinstance(net, torch.nn.Conv2d): 93 | net.register_forward_hook(conv_hook) 94 | if isinstance(net, torch.nn.Linear): 95 | net.register_forward_hook(linear_hook) 96 | if isinstance(net, torch.nn.BatchNorm2d): 97 | net.register_forward_hook(bn_hook) 98 | if isinstance(net, torch.nn.ReLU): 99 | net.register_forward_hook(relu_hook) 100 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 101 | net.register_forward_hook(pooling_hook) 102 | if isinstance(net, torch.nn.Upsample): 103 | net.register_forward_hook(upsample_hook) 104 | return 105 | for c in childrens: 106 | foo(c) 107 | 108 | if model == None: 109 | model = torchvision.models.alexnet() 110 | foo(model) 111 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 112 | out = model(input) 113 | 114 | 115 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 116 | 117 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 118 | 119 | return total_flops -------------------------------------------------------------------------------- /imagenet/network-slimming/README.md: -------------------------------------------------------------------------------- 1 | # Network Slimming 2 | This directory contains the code for implementing Network Slimming on ImageNet. 3 | 4 | ## Implementation 5 | We use the `mask implementation` for finetuning: during pruning, we set to 0 the channel scaling factors 6 | whose corresponding channels are pruned. When finetuning the pruned model, in each iteration, before we call `optimizer.step()`, we zero out the gradients of those masked scaling factors so that they stay at 0. This is achieved in the `BN_grad_zero` function, sketched below.
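A minimal sketch of this gradient masking (assuming the pruned scaling factors have already been set to 0; the repo's actual `BN_grad_zero` may differ in details):

```python
import torch.nn as nn

def BN_grad_zero(model):
    # For every BatchNorm layer, zero the gradients wherever the scaling
    # factor has been pruned to 0, so that optimizer.step() leaves the
    # masked channels untouched.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            mask = (m.weight.data != 0).float()
            m.weight.grad.data.mul_(mask)
            m.bias.grad.data.mul_(mask)
```

It is called once per iteration, between `loss.backward()` and `optimizer.step()`.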
7 | 8 | ## Train with sparsity 9 | ``` 10 | python main.py --arch vgg11_bn --s 0.00001 --save [PATH TO SAVE RESULTS] [IMAGENET] 11 | ``` 12 | 13 | ## Prune 14 | ``` 15 | python prune.py --percent 0.5 --model [PATH TO THE BASE MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 16 | ``` 17 | 18 | ## Finetune 19 | ``` 20 | python main_finetune.py --arch vgg11_bn --refine [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 21 | ``` 22 | 23 | ## Scratch-E 24 | ``` 25 | python main_E.py --arch vgg11_bn --scratch [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 26 | ``` 27 | 28 | ## Scratch-B 29 | ``` 30 | python main_B.py --arch vgg11_bn --scratch [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 31 | ``` 32 | 33 | ## Models 34 | Network|Prune ratio|Training method|Top-1|Top-5|Download 35 | :---:|:---:|:---:|:---:|:---:|:---: 36 | VGG-11(mask-impl)|50%|finetune| 68.62| 88.77| [pytorch model (1014 MB)](https://drive.google.com/open?id=10uscgVM_5ghsxI110y5-sl8T3Kkzki6N) 37 | VGG-11(mask-impl)|50%|scratch-E| 70.00| 89.33| [pytorch model (1014 MB)](https://drive.google.com/open?id=11ITIlGYUu9wZAF-sp06L5h5JoTKYtWsS) 38 | VGG-11|50%|scratch-B| 71.18| 90.08| [pytorch model (282 MB)](https://drive.google.com/open?id=1HjCAETR2kAx2uORe9yxKXZxidxQJboQx) -------------------------------------------------------------------------------- /imagenet/network-slimming/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | 11 | def print_model_param_nums(model=None): 12 | if model == None: 13 | model = torchvision.models.alexnet() 14 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 15 | print(' + Number of params: %.2fM' % (total / 1e6)) 16 | 17 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 18 | 19 | prods = {} 20 | def save_hook(name): 21 | def hook_per(self, input, output): 22 | prods[name] = np.prod(input[0].shape) 23 | return hook_per 24 | 25 | list_1=[] 26 | def simple_hook(self, input, output): 27 | list_1.append(np.prod(input[0].shape)) 28 | list_2={} 29 | def simple_hook2(self, input, output): 30 | list_2['names'] = np.prod(input[0].shape) 31 | 32 | 33 | list_conv=[] 34 | def conv_hook(self, input, output): 35 | batch_size, input_channels, input_height, input_width = input[0].size() 36 | output_channels, output_height, output_width = output[0].size() 37 | 38 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 39 | bias_ops = 1 if self.bias is not None else 0 40 | 41 | params = output_channels * (kernel_ops + bias_ops) 42 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 43 | 44 | num_weight_params = (self.weight.data != 0).float().sum() 45 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 46 | 47 | list_conv.append(flops) 48 | 49 | list_linear=[] 50 | def linear_hook(self, input, output): 51 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 52 | 53 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 54 | bias_ops = self.bias.nelement() 55 | 56 | flops = batch_size * (weight_ops + bias_ops) 57 | 
list_linear.append(flops) 58 | 59 | list_bn=[] 60 | def bn_hook(self, input, output): 61 | list_bn.append(input[0].nelement() * 2) 62 | 63 | list_relu=[] 64 | def relu_hook(self, input, output): 65 | list_relu.append(input[0].nelement()) 66 | 67 | list_pooling=[] 68 | def pooling_hook(self, input, output): 69 | batch_size, input_channels, input_height, input_width = input[0].size() 70 | output_channels, output_height, output_width = output[0].size() 71 | 72 | kernel_ops = self.kernel_size * self.kernel_size 73 | bias_ops = 0 74 | params = 0 75 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 76 | 77 | list_pooling.append(flops) 78 | 79 | list_upsample=[] 80 | # For bilinear upsample 81 | def upsample_hook(self, input, output): 82 | batch_size, input_channels, input_height, input_width = input[0].size() 83 | output_channels, output_height, output_width = output[0].size() 84 | 85 | flops = output_height * output_width * output_channels * batch_size * 12 86 | list_upsample.append(flops) 87 | 88 | def foo(net): 89 | childrens = list(net.children()) 90 | if not childrens: 91 | if isinstance(net, torch.nn.Conv2d): 92 | net.register_forward_hook(conv_hook) 93 | if isinstance(net, torch.nn.Linear): 94 | net.register_forward_hook(linear_hook) 95 | if isinstance(net, torch.nn.BatchNorm2d): 96 | net.register_forward_hook(bn_hook) 97 | if isinstance(net, torch.nn.ReLU): 98 | net.register_forward_hook(relu_hook) 99 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 100 | net.register_forward_hook(pooling_hook) 101 | if isinstance(net, torch.nn.Upsample): 102 | net.register_forward_hook(upsample_hook) 103 | return 104 | for c in childrens: 105 | foo(c) 106 | 107 | if model == None: 108 | model = torchvision.models.alexnet() 109 | foo(model) 110 | # input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 111 | input = Variable(torch.rand(3,3,input_res,input_res), requires_grad = True) 112 | out = model(input) 113 | 114 | 115 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 116 | 117 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 118 | 119 | return total_flops 120 | -------------------------------------------------------------------------------- /imagenet/network-slimming/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = [ 8 | 'slimmingvgg', 9 | ] 10 | 11 | model_urls = { 12 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 13 | } 14 | 15 | class VGG(nn.Module): 16 | 17 | def __init__(self, features, cfg, num_classes=1000, init_weights=True): 18 | super(VGG, self).__init__() 19 | self.features = features 20 | self.classifier = nn.Sequential( 21 | nn.Linear(cfg[0] * 7 * 7, cfg[1]), 22 | nn.BatchNorm1d(cfg[1]), 23 | nn.ReLU(True), 24 | nn.Linear(cfg[1],cfg[2]), 25 | nn.BatchNorm1d(cfg[2]), 26 | nn.ReLU(True), 27 | nn.Linear(cfg[2], num_classes) 28 | ) 29 | if init_weights: 30 | self._initialize_weights() 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | def _initialize_weights(self): 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | nn.init.kaiming_normal(m.weight, mode='fan_out')#, nonlinearity='relu') 42 | if m.bias is not None: 
43 | m.bias.data.zero_() 44 | elif isinstance(m, nn.BatchNorm2d): 45 | m.weight.data.fill_(0.5) 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.Linear): 48 | m.weight.data.normal_(0, 0.01) 49 | m.bias.data.zero_() 50 | elif isinstance(m, nn.BatchNorm1d): 51 | m.weight.data.fill_(0.5) 52 | m.bias.data.zero_() 53 | 54 | def make_layers(cfg, batch_norm=False): 55 | layers = [] 56 | in_channels = 3 57 | for v in cfg: 58 | if v == 'M': 59 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 60 | else: 61 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 62 | if batch_norm: 63 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 64 | else: 65 | layers += [conv2d, nn.ReLU(inplace=True)] 66 | in_channels = v 67 | return nn.Sequential(*layers) 68 | 69 | cfg = { 70 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M', 4096, 4096] 71 | } 72 | 73 | def slimmingvgg(pretrained=False, config=None, **kwargs): 74 | """VGG 11-layer model (configuration "A") with batch normalization 75 | 76 | Args: 77 | pretrained (bool): If True, returns a model pre-trained on ImageNet 78 | """ 79 | if pretrained: 80 | kwargs['init_weights'] = False 81 | if config == None: 82 | config = cfg['A'] 83 | config2 = [config[-4],config[-2],config[-1]] 84 | model = VGG(make_layers(config[:-2], batch_norm=True), config2, **kwargs) 85 | if pretrained: 86 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 87 | return model -------------------------------------------------------------------------------- /imagenet/regression-pruning/README.md: -------------------------------------------------------------------------------- 1 | # Channel Pruning for Accelerating Very Deep Neural Networks 2 | This directory contains a pytorch implementation of the ImageNet experiments of this [paper](https://arxiv.org/abs/1707.06168). The authors have released their code and models in this [repository](https://github.com/yihui-he/channel-pruning). 3 | 4 | ## Implementation 5 | For ResNet-2x, we introduce a `channel selection` layer for pruning the first convolutional layer in a residual block. The `indexes` of the `channel selection` layer are stored in `models/filter.pkl`, which is computed from the [official released model](https://github.com/yihui-he/channel-pruning/releases/tag/ResNet-50-2X). When loading the ResNet-2x model, the indexes in `filter.pkl` are automatically loaded into the network. 6 | 7 | ## Finetune 8 | For finetuning we use the models released in their repository, which are in Caffe; therefore, we test those models in Caffe and report the accuracy in the paper. 9 | 10 | ## Scratch-E 11 | ``` 12 | python main_E.py --arch vgg16 --model vgg-5x --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 13 | python main_E.py --arch resnet50 --model resnet-2x --save [PATH TO SAVE MODEL] [IMAGENET] 14 | ``` 15 | 16 | ## Scratch-B 17 | ``` 18 | python main_B.py --arch vgg16 --model vgg-5x --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 19 | python main_B.py --arch resnet50 --model resnet-2x --save [PATH TO SAVE MODEL] [IMAGENET] 20 | ``` 21 | Here for VGG-5x, the number of epochs for scratch-B training is 180; for ResNet-2x, the number of epochs for scratch-B training (132 epochs) is computed according to the actual FLOPs reduction ratio. 22 | 23 | ## Models 24 | We test the models using the standard scheme: resize the shorter edge to 256, then center crop to (224,224), as in the snippet below.
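In torchvision this testing scheme corresponds to the following evaluation transform (a sketch; the normalization constants are the standard ImageNet statistics, assumed here since they are not listed in this README):

```python
import torchvision.transforms as transforms

# Resize the shorter edge to 256, then take a (224, 224) center crop.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Standard ImageNet normalization (assumed; not stated in the README).
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```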
25 | 26 | Network|Training method|Top-1|Top-5|Download 27 | :---:|:---:|:---:|:---:|:---: 28 | VGG-5x|scratch-E| 68.05| 88.15| [pytorch model (999 MB)](https://drive.google.com/open?id=151ysF8v39GuZHxAK9YjvTWoUqBqiIdJ0) 29 | VGG-5x|scratch-B| 71.00| 89.96| [pytorch model (999 MB)](https://drive.google.com/open?id=1FiTQhRs4L19bp_YKoGXn2M_6BkxWSC_-) 30 | ResNet-2x|scratch-E| 71.26| 90.68| [pytorch model (151 MB)](https://drive.google.com/open?id=1hdbcrB3-3z5n1WnQRJ6VfYeWrlYmDmkL) 31 | ResNet-2x|scratch-B| 74.58| 92.23| [pytorch model (151 MB)](https://drive.google.com/open?id=16rgMbYHMtwl5rXOJHgKqyrTZB7O2sV40) -------------------------------------------------------------------------------- /imagenet/regression-pruning/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | 11 | def print_model_param_nums(model=None): 12 | if model == None: 13 | model = torchvision.models.alexnet() 14 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 15 | print(' + Number of params: %.4fM' % (total / 1e6)) 16 | 17 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 18 | 19 | prods = {} 20 | def save_hook(name): 21 | def hook_per(self, input, output): 22 | prods[name] = np.prod(input[0].shape) 23 | return hook_per 24 | 25 | list_1=[] 26 | def simple_hook(self, input, output): 27 | list_1.append(np.prod(input[0].shape)) 28 | list_2={} 29 | def simple_hook2(self, input, output): 30 | list_2['names'] = np.prod(input[0].shape) 31 | 32 | 33 | list_conv=[] 34 | def conv_hook(self, input, output): 35 | batch_size, input_channels, input_height, input_width = input[0].size() 36 | output_channels, output_height, output_width = output[0].size() 37 | 38 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 39 | bias_ops = 1 if self.bias is not None else 0 40 | 41 | params = output_channels * (kernel_ops + bias_ops) 42 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 43 | 44 | num_weight_params = (self.weight.data != 0).float().sum() 45 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 46 | 47 | list_conv.append(flops) 48 | 49 | list_linear=[] 50 | def linear_hook(self, input, output): 51 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 52 | 53 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 54 | bias_ops = self.bias.nelement() 55 | 56 | flops = batch_size * (weight_ops + bias_ops) 57 | list_linear.append(flops) 58 | 59 | list_bn=[] 60 | def bn_hook(self, input, output): 61 | list_bn.append(input[0].nelement() * 2) 62 | 63 | list_relu=[] 64 | def relu_hook(self, input, output): 65 | list_relu.append(input[0].nelement()) 66 | 67 | list_pooling=[] 68 | def pooling_hook(self, input, output): 69 | batch_size, input_channels, input_height, input_width = input[0].size() 70 | output_channels, output_height, output_width = output[0].size() 71 | 72 | kernel_ops = self.kernel_size * self.kernel_size 73 | bias_ops = 0 74 | params = 0 75 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 76 | 77 | list_pooling.append(flops) 78 | 79 | 
list_upsample=[] 80 | 81 | # For bilinear upsample 82 | def upsample_hook(self, input, output): 83 | batch_size, input_channels, input_height, input_width = input[0].size() 84 | output_channels, output_height, output_width = output[0].size() 85 | 86 | flops = output_height * output_width * output_channels * batch_size * 12 87 | list_upsample.append(flops) 88 | 89 | def foo(net): 90 | childrens = list(net.children()) 91 | if not childrens: 92 | if isinstance(net, torch.nn.Conv2d): 93 | net.register_forward_hook(conv_hook) 94 | if isinstance(net, torch.nn.Linear): 95 | net.register_forward_hook(linear_hook) 96 | if isinstance(net, torch.nn.BatchNorm2d): 97 | net.register_forward_hook(bn_hook) 98 | if isinstance(net, torch.nn.ReLU): 99 | net.register_forward_hook(relu_hook) 100 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 101 | net.register_forward_hook(pooling_hook) 102 | if isinstance(net, torch.nn.Upsample): 103 | net.register_forward_hook(upsample_hook) 104 | return 105 | for c in childrens: 106 | foo(c) 107 | 108 | if model == None: 109 | model = torchvision.models.alexnet() 110 | foo(model) 111 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 112 | out = model(input) 113 | 114 | 115 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 116 | 117 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 118 | 119 | return total_flops -------------------------------------------------------------------------------- /imagenet/regression-pruning/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg_5x import vgg_5x, vgg_official 2 | from .resnet_2x import resnet_2x 3 | from .resnet import resnet50_official -------------------------------------------------------------------------------- /imagenet/regression-pruning/models/channel_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class channel_selection(nn.Module): 8 | """ 9 | Select channels from the output of a BatchNorm2d layer. It should be placed directly after a BatchNorm2d layer. 10 | The output shape of this layer is determined by the number of 1s in `self.indexes`. 11 | """ 12 | def __init__(self, num_channels, mask): 13 | """ 14 | Initialize `indexes` as an all-ones vector whose length equals the number of channels. 15 | During pruning, the entries of `indexes` which correspond to the channels to be pruned are set to 0. 16 | """ 17 | super(channel_selection, self).__init__() 18 | self.indexes = nn.Parameter(torch.ones(num_channels)) 19 | assert len(mask) == num_channels 20 | mask = torch.from_numpy(mask) 21 | self.indexes.data.mul_(mask) 22 | 23 | def forward(self, input_tensor): 24 | """ 25 | Parameter 26 | --------- 27 | input_tensor: (N,C,H,W). It should be the output of a BatchNorm2d layer.
28 | """ 29 | selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy())) 30 | if selected_index.size == 1: 31 | selected_index = np.resize(selected_index, (1,)) 32 | output = input_tensor[:, selected_index, :, :] 33 | return output -------------------------------------------------------------------------------- /imagenet/regression-pruning/models/filter.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eric-mingjie/rethinking-network-pruning/2ac473d70a09810df888e932bb394f225f9ed2d1/imagenet/regression-pruning/models/filter.pkl -------------------------------------------------------------------------------- /imagenet/regression-pruning/models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from compute_flops import * 7 | 8 | 9 | __all__ = ['resnet50_official'] 10 | 11 | model_urls = { 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | } 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | class Bottleneck(nn.Module): 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, cfg, stride=1, downsample=None): 24 | super(Bottleneck, self).__init__() 25 | self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False) 26 | self.bn1 = nn.BatchNorm2d(cfg[1]) 27 | self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, 28 | padding=1, bias=False) 29 | self.bn2 = nn.BatchNorm2d(cfg[2]) 30 | self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False) 31 | self.bn3 = nn.BatchNorm2d(planes * 4) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv3(out) 48 | out = self.bn3(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class ResNet(nn.Module): 60 | 61 | def __init__(self, block, layers, cfg, num_classes=1000): 62 | self.inplanes = 64 63 | super(ResNet, self).__init__() 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 65 | bias=False) 66 | self.bn1 = nn.BatchNorm2d(64) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 69 | self.layer1 = self._make_layer(block, cfg[0:9], 64, layers[0]) 70 | self.layer2 = self._make_layer(block, cfg[9:21], 128, layers[1], stride=2) 71 | self.layer3 = self._make_layer(block, cfg[21:39], 256, layers[2], stride=2) 72 | self.layer4 = self._make_layer(block, cfg[39:48], 512, layers[3], stride=2) 73 | self.avgpool = nn.AvgPool2d(7, stride=1) 74 | self.fc = nn.Linear(512 * block.expansion, num_classes) 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 80 | nn.init.kaiming_normal(m.weight, mode='fan_out') 81 | # m.bias.data.zero_() 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | 86 | def _make_layer(self, block, cfg, planes, blocks, stride=1): 87 | downsample = None 88 | if stride != 1 or self.inplanes != planes * block.expansion: 89 | downsample = nn.Sequential( 90 | nn.Conv2d(self.inplanes, planes * block.expansion, 91 | kernel_size=1, stride=stride, bias=True), 92 | nn.BatchNorm2d(planes * block.expansion), 93 | ) 94 | 95 | layers = [] 96 | layers.append(block(self.inplanes, planes, cfg[:3], stride, downsample)) 97 | self.inplanes = planes * block.expansion 98 | for i in range(1, blocks): 99 | layers.append(block(self.inplanes, planes, cfg[3*i:3*(i+1)])) 100 | 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | x = self.conv1(x) 105 | x = self.bn1(x) 106 | x = self.relu(x) 107 | x = self.maxpool(x) 108 | 109 | x = self.layer1(x) 110 | x = self.layer2(x) 111 | x = self.layer3(x) 112 | x = self.layer4(x) 113 | 114 | x = self.avgpool(x) 115 | x = x.view(x.size(0), -1) 116 | x = self.fc(x) 117 | 118 | return x 119 | 120 | cfg_official = [[64, 64, 64], [256, 64, 64] * 2, [256, 128, 128], [512, 128, 128] * 3, 121 | [512, 256, 256], [1024, 256, 256] * 5, [1024, 512, 512], [2048, 512, 512] * 2] 122 | cfg_official = [item for sublist in cfg_official for item in sublist] 123 | assert len(cfg_official) == 48, "Length of cfg_official is not right" 124 | 125 | 126 | def resnet50_official(pretrained=False, **kwargs): 127 | """Constructs a ResNet-50 model. 128 | 129 | Args: 130 | pretrained (bool): If True, returns a model pre-trained on ImageNet 131 | """ 132 | model = ResNet(Bottleneck, [3, 4, 6, 3], cfg_official, **kwargs) 133 | if pretrained: 134 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 135 | return model -------------------------------------------------------------------------------- /imagenet/regression-pruning/models/resnet_2x.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.autograd import Variable 7 | 8 | from channel_selection import * 9 | 10 | 11 | __all__ = ['resnet_2x'] 12 | 13 | model_urls = { 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | } 16 | 17 | with open("models/filter.pkl",'rb') as f: 18 | filter_index = pickle.load(f) 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | class Bottleneck(nn.Module): 26 | expansion = 4 27 | 28 | def __init__(self, inplanes, planes, cfg, mask, stride=1, downsample=None): 29 | super(Bottleneck, self).__init__() 30 | self.selection = channel_selection(inplanes, mask) 31 | self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False) 32 | self.bn1 = nn.BatchNorm2d(cfg[1]) 33 | self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, 34 | padding=1, bias=False) 35 | self.bn2 = nn.BatchNorm2d(cfg[2]) 36 | self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False) 37 | self.bn3 = nn.BatchNorm2d(planes * 4) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.selection(x) 46 | 47 | out = self.conv1(out) 48 | out = 
self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv3(out) 56 | out = self.bn3(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | 69 | def __init__(self, block, layers, cfg, num_classes=1000): 70 | self.inplanes = 64 71 | super(ResNet, self).__init__() 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 73 | bias=False) 74 | self.bn1 = nn.BatchNorm2d(64) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 77 | self.count = 0 78 | self.layer1 = self._make_layer(block, cfg[0:9], 64, layers[0]) 79 | self.layer2 = self._make_layer(block, cfg[9:21], 128, layers[1], stride=2) 80 | self.layer3 = self._make_layer(block, cfg[21:39], 256, layers[2], stride=2) 81 | self.layer4 = self._make_layer(block, cfg[39:48], 512, layers[3], stride=2) 82 | self.avgpool = nn.AvgPool2d(7, stride=1) 83 | self.fc = nn.Linear(512 * block.expansion, num_classes) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | nn.init.kaiming_normal(m.weight, mode='fan_out') 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, cfg, planes, blocks, stride=1): 94 | downsample = None 95 | if stride != 1 or self.inplanes != planes * block.expansion: 96 | downsample = nn.Sequential( 97 | nn.Conv2d(self.inplanes, planes * block.expansion, 98 | kernel_size=1, stride=stride, bias=True), 99 | nn.BatchNorm2d(planes * block.expansion), 100 | ) 101 | 102 | layers = [] 103 | mask = filter_index[self.count] 104 | layers.append(block(self.inplanes, planes, cfg[:3], mask, stride, downsample)) 105 | self.count += 1 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, blocks): 108 | mask = filter_index[self.count] 109 | layers.append(block(self.inplanes, planes, cfg[3*i:3*(i+1)], mask)) 110 | self.count += 1 111 | 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | x = self.conv1(x) 116 | x = self.bn1(x) 117 | x = self.relu(x) 118 | x = self.maxpool(x) 119 | 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | 125 | x = self.avgpool(x) 126 | x = x.view(x.size(0), -1) 127 | x = self.fc(x) 128 | 129 | return x 130 | 131 | cfg_2x = [35, 64, 55, 101, 51, 39, 97, 50, 37, 144, 128, 106, 205, 105, 72, 198, 105, 72, 288, 128, 110, 278, 256, 225, 418, 209, 147, 132 | 407, 204, 158, 423, 212, 155, 412, 211, 148, 595, 256, 213, 606, 512, 433, 1222, 512, 437, 1147, 512, 440] 133 | assert len(cfg_2x) == 48, "Length of cfg variable is not right." 134 | 135 | 136 | def resnet_2x(pretrained=False, **kwargs): 137 | """Constructs a ResNet-50 model. 
138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | """ 142 | model = ResNet(Bottleneck, [3, 4, 6, 3], cfg_2x, **kwargs) 143 | if pretrained: 144 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 145 | return model -------------------------------------------------------------------------------- /imagenet/regression-pruning/models/vgg_5x.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from torch.autograd import Variable 6 | 7 | 8 | __all__ = [ 9 | 'vgg_5x', 'vgg_official', 10 | ] 11 | 12 | model_urls = { 13 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 14 | } 15 | 16 | class VGG(nn.Module): 17 | 18 | def __init__(self, features, num_classes=1000, init_weights=True): 19 | super(VGG, self).__init__() 20 | self.features = features 21 | self.classifier = nn.Sequential( 22 | nn.Linear(512 * 7 * 7, 4096), 23 | nn.ReLU(True), 24 | nn.Dropout(), 25 | nn.Linear(4096, 4096), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(4096, num_classes), 29 | ) 30 | if init_weights: 31 | self._initialize_weights() 32 | 33 | def forward(self, x): 34 | x = self.features(x) 35 | x = x.view(x.size(0), -1) 36 | x = self.classifier(x) 37 | return x 38 | 39 | def _initialize_weights(self): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | nn.init.kaiming_normal(m.weight, mode='fan_out')#, nonlinearity='relu') 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.Linear): 46 | m.weight.data.normal_(0, 0.01) 47 | m.bias.data.zero_() 48 | elif isinstance(m, nn.BatchNorm2d): 49 | m.weight.data.fill_(1) 50 | m.bias.data.zero_() 51 | 52 | def make_layers(cfg, batch_norm=False): 53 | layers = [] 54 | in_channels = 3 55 | for v in cfg: 56 | if v == 'M': 57 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 58 | else: 59 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 60 | if batch_norm: 61 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 62 | else: 63 | layers += [conv2d, nn.ReLU(inplace=True)] 64 | in_channels = v 65 | return nn.Sequential(*layers) 66 | 67 | cfg_5x = [24, 22, 'M', 41, 51, 'M', 108, 89, 111, 'M', 184, 276, 228, 'M', 512, 512, 512, 'M'] 68 | cfg_official = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 69 | 70 | def vgg_5x(pretrained=False, **kwargs): 71 | """VGG 16-layer model (configuration "D") 72 | 73 | Args: 74 | pretrained (bool): If True, returns a model pre-trained on ImageNet 75 | """ 76 | if pretrained: 77 | kwargs['init_weights'] = False 78 | model = VGG(make_layers(cfg_5x, False), **kwargs) 79 | if pretrained: 80 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 81 | return model 82 | 83 | def vgg_official(pretrained=False, **kwargs): 84 | """VGG 16-layer model (configuration "D") 85 | 86 | Args: 87 | pretrained (bool): If True, returns a model pre-trained on ImageNet 88 | """ 89 | if pretrained: 90 | kwargs['init_weights'] = False 91 | model = VGG(make_layers(cfg_official, False), **kwargs) 92 | if pretrained: 93 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 94 | return model -------------------------------------------------------------------------------- /imagenet/sparse-structure-selection/README.md: -------------------------------------------------------------------------------- 1 | # Sparse Structure Selection 2 | The authors 
have released the [code](https://github.com/TuSimple/sparse-structure-selection) for [Data Driven Sparse Structure Selection for Deep Neural Networks](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zehao_Huang_Data-Driven_Sparse_Structure_ECCV_2018_paper.pdf). 3 | 4 | In this repository, we describe how we use the released code for our experiments. For the accuracy of the pruned models, we use the results reported in the original paper. 5 | 6 | 1. Modify the file `config/cfgs.py` to set `sss=False`. 7 | 2. Modify the file `symbol/resnet.py` to support the ResNet-41, ResNet-32 and ResNet-26 architectures from the original paper. The details are as follows: we add a new parameter to the function [residual_unit](https://github.com/TuSimple/sparse-structure-selection/blob/master/symbol/resnet.py#L10) in `symbol/resnet.py` to indicate whether the block's residual connection is pruned. With this parameter, we can modify the code [here](https://github.com/TuSimple/sparse-structure-selection/blob/master/symbol/resnet.py#L10) to support all pruned ResNet variants. 8 | It would be helpful to create a new `config.units` format as follows: 9 | a. ResNet-41: `[(0,False), (4,True), (6,True), (3,True)]` 10 | b. ResNet-32: `[(1,False), (4,True), (4,True), (1,True)]` 11 | c. ResNet-26: `[(0,False), (2,False), (5,False), (1,True)]` 12 | where each `(a, b)` pair describes one stage: `a` is the number of blocks remaining in that stage, and `b` is True when the first block of that stage is not pruned. (We distinguish the first block because the first block of each stage contains a downsample convolution, which is a corner case in the code.) 13 | 3. Training: `python train.py`. Specify the gpu configuration in `config/cfgs.py`. For scratch-E training, use the standard 100 epochs with learning rate decay at epochs 30, 60 and 90. For scratch-B training, modify `lr_step` in `config/cfgs.py`: each learning rate stage is expanded by a uniform ratio (the FLOPs reduction ratio), giving the schedule below. 14 | ### Scratch-B training schedule 15 | Network|Epochs|lr step| 16 | :---:|:---:|:---:| 17 | ResNet-41|117| [35, 70, 105] 18 | ResNet-32|145| [43, 86, 129] 19 | ResNet-26|179| [53, 106, 159] 20 | 21 | ## Models 22 | Network|Training method|Top-1|Top-5|Download 23 | :---:|:---:|:---:|:---:|:---: 24 | ResNet-26|scratch-E| 72.31| 90.90| [MXNET model (57 MB)](https://drive.google.com/open?id=1wWPu5EGyT9lzKpYhqG2_1emdE59ctmvb) 25 | ResNet-26|scratch-B| 73.41| 91.39| [MXNET model (57 MB)](https://drive.google.com/open?id=1lC5TXPtz_9Py3yeOA_MGXvXwBiL6wdYS) 26 | ResNet-32|scratch-E| 73.79| 91.80| [MXNET model (55 MB)](https://drive.google.com/open?id=1h8iwPIT8z3h8ETFGP740FeEgYfxteLqU) 27 | ResNet-32|scratch-B| 74.67| 92.22| [MXNET model (55 MB)](https://drive.google.com/open?id=1ud-K1p_g7ltD3MTJqiqkFgpLJeN_yAVB) 28 | ResNet-41|scratch-E| 75.70| 92.74| [MXNET model (130 MB)](https://drive.google.com/open?id=1DgaqzjMqiFZz1vftKSw8yESCrrp6R6QV) 29 | ResNet-41|scratch-B| 76.17| 92.90| [MXNET model (130 MB)](https://drive.google.com/open?id=1DgaqzjMqiFZz1vftKSw8yESCrrp6R6QV) 30 | -------------------------------------------------------------------------------- /imagenet/thinet/README.md: -------------------------------------------------------------------------------- 1 | # ThiNet 2 | This directory contains a pytorch implementation of the ImageNet experiments of [ThiNet](https://arxiv.org/abs/1707.06342). The authors have released their code and models in this [repository](https://github.com/Roll920/ThiNet).
3 | 4 | ## Finetune 5 | For the finetuned results, we use the models released in the authors' repository, which are in Caffe format. We therefore test these models in Caffe and report the resulting accuracy in the paper. 6 | 7 | ## Scratch-E 8 | ``` 9 | python main_E.py --arch vgg16 --model thinet-conv --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 10 | python main_E.py --arch vgg16 --model thinet-gap --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 11 | python main_E.py --arch vgg16 --model thinet-tiny --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 12 | python main_E.py --arch resnet50 --model thinet-30 --save [PATH TO SAVE MODEL] [IMAGENET] 13 | python main_E.py --arch resnet50 --model thinet-50 --save [PATH TO SAVE MODEL] [IMAGENET] 14 | python main_E.py --arch resnet50 --model thinet-70 --save [PATH TO SAVE MODEL] [IMAGENET] 15 | ``` 16 | Here, `thinet-conv`, `thinet-gap`, `thinet-tiny`, `thinet-30`, `thinet-50` and `thinet-70` refer to the pruned models from the ThiNet paper. 17 | 18 | ## Scratch-B 19 | 20 | ``` 21 | python main_B.py --arch vgg16 --model thinet-conv --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 22 | python main_B.py --arch vgg16 --model thinet-gap --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 23 | python main_B.py --arch vgg16 --model thinet-tiny --lr 0.01 --save [PATH TO SAVE MODEL] [IMAGENET] 24 | python main_B.py --arch resnet50 --model thinet-30 --save [PATH TO SAVE MODEL] [IMAGENET] 25 | python main_B.py --arch resnet50 --model thinet-50 --save [PATH TO SAVE MODEL] [IMAGENET] 26 | python main_B.py --arch resnet50 --model thinet-70 --save [PATH TO SAVE MODEL] [IMAGENET] 27 | ``` 28 | For all networks other than `thinet-70`, scratch-B training uses 180 epochs; for `thinet-70`, it uses 141 epochs (see the sketch below for how these budgets relate to the FLOPs reduction ratio). 29 | 30 | ## Models 31 | We test the models using the standard scheme: resize the shorter edge to 256, then center crop to (224,224).
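A hypothetical sketch of how the scratch-B budgets above can be read, assuming a 90-epoch baseline schedule for these PyTorch ImageNet models and a 2x cap on the extended budget (both assumptions; the actual FLOPs ratios can be measured with `compute_flops.py`):

```
# Hypothetical helper (not code from this repo): scale the baseline epoch
# budget by the FLOPs reduction ratio, capped at 2x.
def scratch_b_epochs(flops_reduction_ratio, base_epochs=90, cap=2.0):
    return int(round(base_epochs * min(flops_reduction_ratio, cap)))

# Under this reading, thinet-70 reduces FLOPs by roughly 1.57x -> about 141
# epochs, while the more aggressively pruned models hit the 2x cap -> 180.
```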
32 | ### VGG 33 | Network|Training method|Top-1|Top-5|Download 34 | :---:|:---:|:---:|:---:|:---: 35 | VGG-Conv|scratch-E| 68.76| 88.71| [pytorch model (1003 MB)](https://drive.google.com/open?id=1Jr7n5q4BiYEHUVEv1FuzfAn0S26CNFsd) 36 | VGG-Conv|scratch-B| 71.72| 90.34| [pytorch model (1003 MB)](https://drive.google.com/open?id=12DC2hpbQNVUSpS3ojcxjncW69jh1NpbN) 37 | VGG-GAP|scratch-E| 66.85| 87.07| [pytorch model (64 MB)](https://drive.google.com/open?id=1FnPVJGjlL36tOJo1__7nr3Jykk-VhMLv) 38 | VGG-GAP|scratch-B| 68.66| 88.13| [pytorch model (64 MB)](https://drive.google.com/open?id=1YqDnc6JbXQl83E50P1fmUDO7J7SuEdgK) 39 | VGG-Tiny|scratch-E| 57.15| 79.92| [pytorch model (10 MB)](https://drive.google.com/open?id=1J-ydiASraEdKYEwDu-u5kG8FFgdyXpV_) 40 | VGG-Tiny|scratch-B| 59.93| 82.07| [pytorch model (10 MB)](https://drive.google.com/open?id=1J1JRBLd-2AbDNk57621Wst02QlrC4jf4) 41 | 42 | ### ResNet 43 | Network|Training method|Top-1|Top-5|Download 44 | :---:|:---:|:---:|:---:|:---: 45 | ThiNet-30|scratch-E| 70.91| 90.14| [pytorch model (66 MB)](https://drive.google.com/open?id=14cJ_oF4bAatcXiKEhBQHWny5kiC0hMBu) 46 | ThiNet-30|scratch-B| 71.57| 90.49| [pytorch model (66 MB)](https://drive.google.com/open?id=1RkiTHKxFfmP6jYl2vo_gYC2NZzwfIe1M) 47 | ThiNet-50|scratch-E| 73.31| 91.49| [pytorch model (95 MB)](https://drive.google.com/open?id=1E3c_7wvGXeUywXYVWIhzup5rp47TB9Tw) 48 | ThiNet-50|scratch-B| 73.90| 91.98| [pytorch model (95 MB)](https://drive.google.com/open?id=1-0ip4ZDSxpbQx7D_5-VgOs-B8ww2jNC4) 49 | ThiNet-70|scratch-E| 74.42| 92.07| [pytorch model (130 MB)](https://drive.google.com/open?id=1rTdotQKYBVHr03n1kYjYAVLDwxyYJok1) 50 | ThiNet-70|scratch-B| 75.14| 92.34| [pytorch model (130 MB)](https://drive.google.com/open?id=1p2ER072IyFmDZRoAdQrRsedLElnmuct4) 51 | -------------------------------------------------------------------------------- /imagenet/thinet/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 
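# This script counts parameters and FLOPs by registering forward hooks on
# Conv2d / Linear / BatchNorm2d / ReLU / pooling / Upsample modules, running
# a single dummy forward pass, and summing the per-layer op counts collected
# in the module-level lists. Conv FLOPs are computed from the number of
# *nonzero* weights, so masked (weight-pruned) models report their effective
# rather than nominal FLOPs.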
2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | 11 | def print_model_param_nums(model=None): 12 | if model == None: 13 | model = torchvision.models.alexnet() 14 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 15 | print(' + Number of params: %.4fM' % (total / 1e6)) 16 | 17 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 18 | 19 | prods = {} 20 | def save_hook(name): 21 | def hook_per(self, input, output): 22 | prods[name] = np.prod(input[0].shape) 23 | return hook_per 24 | 25 | list_1=[] 26 | def simple_hook(self, input, output): 27 | list_1.append(np.prod(input[0].shape)) 28 | list_2={} 29 | def simple_hook2(self, input, output): 30 | list_2['names'] = np.prod(input[0].shape) 31 | 32 | 33 | list_conv=[] 34 | def conv_hook(self, input, output): 35 | batch_size, input_channels, input_height, input_width = input[0].size() 36 | output_channels, output_height, output_width = output[0].size() 37 | 38 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 39 | bias_ops = 1 if self.bias is not None else 0 40 | 41 | params = output_channels * (kernel_ops + bias_ops) 42 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 43 | 44 | num_weight_params = (self.weight.data != 0).float().sum() 45 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 46 | 47 | list_conv.append(flops) 48 | 49 | list_linear=[] 50 | def linear_hook(self, input, output): 51 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 52 | 53 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 54 | bias_ops = self.bias.nelement() 55 | 56 | flops = batch_size * (weight_ops + bias_ops) 57 | list_linear.append(flops) 58 | 59 | list_bn=[] 60 | def bn_hook(self, input, output): 61 | list_bn.append(input[0].nelement() * 2) 62 | 63 | list_relu=[] 64 | def relu_hook(self, input, output): 65 | list_relu.append(input[0].nelement()) 66 | 67 | list_pooling=[] 68 | def pooling_hook(self, input, output): 69 | batch_size, input_channels, input_height, input_width = input[0].size() 70 | output_channels, output_height, output_width = output[0].size() 71 | 72 | kernel_ops = self.kernel_size * self.kernel_size 73 | bias_ops = 0 74 | params = 0 75 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 76 | 77 | list_pooling.append(flops) 78 | 79 | list_upsample=[] 80 | 81 | # For bilinear upsample 82 | def upsample_hook(self, input, output): 83 | batch_size, input_channels, input_height, input_width = input[0].size() 84 | output_channels, output_height, output_width = output[0].size() 85 | 86 | flops = output_height * output_width * output_channels * batch_size * 12 87 | list_upsample.append(flops) 88 | 89 | def foo(net): 90 | childrens = list(net.children()) 91 | if not childrens: 92 | if isinstance(net, torch.nn.Conv2d): 93 | net.register_forward_hook(conv_hook) 94 | if isinstance(net, torch.nn.Linear): 95 | net.register_forward_hook(linear_hook) 96 | if isinstance(net, torch.nn.BatchNorm2d): 97 | net.register_forward_hook(bn_hook) 98 | if isinstance(net, torch.nn.ReLU): 99 | net.register_forward_hook(relu_hook) 100 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 101 | 
net.register_forward_hook(pooling_hook) 102 | if isinstance(net, torch.nn.Upsample): 103 | net.register_forward_hook(upsample_hook) 104 | return 105 | for c in childrens: 106 | foo(c) 107 | 108 | if model == None: 109 | model = torchvision.models.alexnet() 110 | foo(model) 111 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 112 | out = model(input) 113 | 114 | 115 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 116 | 117 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 118 | 119 | return total_flops 120 | -------------------------------------------------------------------------------- /imagenet/thinet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from thinetconv import thinet_conv, vgg_official 2 | from thinetvgg import thinet_gap, thinet_tiny 3 | from thinetresnet import thinet30, thinet50, thinet70, resnet50_official -------------------------------------------------------------------------------- /imagenet/thinet/models/thinetconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from torch.autograd import Variable 6 | 7 | 8 | __all__ = [ 9 | 'thinet_conv' 10 | ] 11 | 12 | model_urls = { 13 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 14 | } 15 | 16 | class VGG(nn.Module): 17 | 18 | def __init__(self, features, num_classes=1000, init_weights=True): 19 | super(VGG, self).__init__() 20 | self.features = features 21 | self.classifier = nn.Sequential( 22 | nn.Linear(512 * 7 * 7, 4096), 23 | nn.ReLU(True), 24 | nn.Dropout(), 25 | nn.Linear(4096, 4096), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(4096, num_classes), 29 | ) 30 | if init_weights: 31 | self._initialize_weights() 32 | 33 | def forward(self, x): 34 | x = self.features(x) 35 | x = x.view(x.size(0), -1) 36 | x = self.classifier(x) 37 | return x 38 | 39 | def _initialize_weights(self): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | nn.init.kaiming_normal(m.weight, mode='fan_out')#, nonlinearity='relu') 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.Linear): 46 | m.weight.data.normal_(0, 0.01) 47 | m.bias.data.zero_() 48 | elif isinstance(m, nn.BatchNorm2d): 49 | m.weight.data.fill_(1) 50 | m.bias.data.zero_() 51 | 52 | def make_layers(cfg, batch_norm=False): 53 | layers = [] 54 | in_channels = 3 55 | for v in cfg: 56 | if v == 'M': 57 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 58 | else: 59 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 60 | if batch_norm: 61 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 62 | else: 63 | layers += [conv2d, nn.ReLU(inplace=True)] 64 | in_channels = v 65 | return nn.Sequential(*layers) 66 | 67 | conv = [32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M'] 68 | cfg_official = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 69 | 70 | def thinet_conv(pretrained=False, **kwargs): 71 | """VGG 16-layer model (configuration "D") 72 | 73 | Args: 74 | pretrained (bool): If True, returns a model pre-trained on ImageNet 75 | """ 76 | if pretrained: 77 | kwargs['init_weights'] = False 78 | model = VGG(make_layers(conv, False), **kwargs) 79 | if pretrained: 80 | 
model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 81 | return model 82 | 83 | def vgg_official(pretrained=False, **kwargs): 84 | """VGG 16-layer model (configuration "D") 85 | Args: 86 | pretrained (bool): If True, returns a model pre-trained on ImageNet 87 | """ 88 | if pretrained: 89 | kwargs['init_weights'] = False 90 | model = VGG(make_layers(cfg_official, False), **kwargs) 91 | if pretrained: 92 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 93 | return model -------------------------------------------------------------------------------- /imagenet/thinet/models/thinetvgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from torch.autograd import Variable 6 | 7 | 8 | __all__ = [ 9 | 'thinet_gap', 'thinet_tiny' 10 | ] 11 | 12 | model_urls = { 13 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 14 | } 15 | 16 | class VGG(nn.Module): 17 | 18 | def __init__(self, features, cfg, num_classes=1000, init_weights=True): 19 | super(VGG, self).__init__() 20 | self.features = features 21 | self.classifier = nn.Sequential( 22 | nn.Linear(cfg, num_classes), 23 | ) 24 | if init_weights: 25 | self._initialize_weights() 26 | 27 | def forward(self, x): 28 | x = self.features(x) 29 | x = nn.AvgPool2d(14,2)(x) 30 | x = x.view(x.size(0), -1) 31 | x = self.classifier(x) 32 | return x 33 | 34 | def _initialize_weights(self): 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | nn.init.kaiming_normal(m.weight, mode='fan_out')#, nonlinearity='relu') 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | elif isinstance(m, nn.Linear): 41 | m.weight.data.normal_(0, 0.01) 42 | m.bias.data.zero_() 43 | 44 | def make_layers(cfg, batch_norm=False): 45 | layers = [] 46 | in_channels = 3 47 | for v in cfg: 48 | if v == 'M': 49 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 50 | else: 51 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 52 | if batch_norm: 53 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 54 | else: 55 | layers += [conv2d, nn.ReLU(inplace=True)] 56 | in_channels = v 57 | return nn.Sequential(*layers) 58 | 59 | gap = [32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] 60 | tiny = [16, 16, 'M', 32, 32, 'M', 64, 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 256] 61 | 62 | def thinet_gap(pretrained=False, **kwargs): 63 | """VGG 16-layer model (configuration "D") 64 | 65 | Args: 66 | pretrained (bool): If True, returns a model pre-trained on ImageNet 67 | """ 68 | if pretrained: 69 | kwargs['init_weights'] = False 70 | model = VGG(make_layers(gap), 512, **kwargs) 71 | if pretrained: 72 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 73 | return model 74 | 75 | def thinet_tiny(pretrained=False, **kwargs): 76 | """VGG 16-layer model (configuration "D") 77 | 78 | Args: 79 | pretrained (bool): If True, returns a model pre-trained on ImageNet 80 | """ 81 | if pretrained: 82 | kwargs['init_weights'] = False 83 | model = VGG(make_layers(tiny), 256, **kwargs) 84 | if pretrained: 85 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 86 | return model 87 | -------------------------------------------------------------------------------- /imagenet/weight-level/README.md: -------------------------------------------------------------------------------- 1 | # Non-Structured Pruning/Weight-Level Pruning 2 | 3 | This directory 
contains a pytorch implementation of the ImageNet experiments of non-structured pruning. 4 | 5 | ## Implementation 6 | We prune only the weights in the convolutional layers. We use a mask implementation: during pruning, the pruned weights are set to 0, and during training we make sure that those pruned parameters are not updated. 7 | 8 | ## Baseline 9 | We take the base models of VGG-16 and ResNet-50 from the Pytorch [Model Zoo](https://pytorch.org/docs/stable/torchvision/models.html). 10 | 11 | ## Prune 12 | ``` 13 | python prune.py --arch vgg16_bn --pretrained --percent 0.3 --save [PATH TO SAVE RESULTS] [IMAGENET] 14 | python prune.py --arch resnet50 --pretrained --percent 0.3 --save [PATH TO SAVE RESULTS] [IMAGENET] 15 | ``` 16 | 17 | ## Finetune 18 | ``` 19 | python main_finetune.py --arch vgg16_bn --resume [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 20 | python main_finetune.py --arch resnet50 --resume [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 21 | ``` 22 | 23 | ## Scratch-E 24 | ``` 25 | python main_E.py --arch vgg16_bn --resume [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 26 | python main_E.py --arch resnet50 --resume [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 27 | ``` 28 | 29 | ## Scratch-B 30 | ``` 31 | python main_B.py --arch vgg16_bn --resume [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 32 | python main_B.py --arch resnet50 --resume [PATH TO THE PRUNED MODEL] --save [PATH TO SAVE RESULTS] [IMAGENET] 33 | ``` 34 | 35 | ## Models 36 | ### VGG 37 | Network|Prune ratio|Training method|Top-1|Top-5|Download 38 | :---:|:---:|:---:|:---:|:---:|:---: 39 | VGG-16|30%|finetune| 73.68| 91.53| [pytorch model (1024 MB)](https://drive.google.com/open?id=1OWGaJ-tXAlS4Ne5zhZ1M4k-1rk723do0) 40 | VGG-16|30%|scratch-E| 72.75| 91.06| [pytorch model (1024 MB)](https://drive.google.com/open?id=1kgGiBaG1Y6Kh-EK27APWoMzeV7jO_jlL) 41 | VGG-16|30%|scratch-B| 74.02| 91.78| [pytorch model (1024 MB)](https://drive.google.com/open?id=1ADbEpkziEMs_FPKAP-6BytcBHfqshrlg) 42 | VGG-16|60%|finetune| 73.63| 91.54| [pytorch model (1024 MB)](https://drive.google.com/open?id=1xZOFuxKJEdv9AtoHcv5VvZsrWM7-vujY) 43 | VGG-16|60%|scratch-E| 71.50| 90.43| [pytorch model (1024 MB)](https://drive.google.com/open?id=1s4yETDG0WB7ZerHmGudVRo2Z0JWuxZXr) 44 | VGG-16|60%|scratch-B| 73.42| 91.48| [pytorch model (1024 MB)](https://drive.google.com/open?id=1APsXiwxq2VCitKvGoeqfHieEdWEpMk6W) 45 | 46 | ### ResNet 47 | Network|Prune ratio|Training method|Top-1|Top-5|Download 48 | :---:|:---:|:---:|:---:|:---:|:---: 49 | ResNet-50|30%|finetune| 76.06| 92.88| [pytorch model (195 MB)](https://drive.google.com/open?id=17bzfWtHjTkCture96d7MG0afrFiY9xFF) 50 | ResNet-50|30%|scratch-E| 74.77| 92.19| [pytorch model (195 MB)](https://drive.google.com/open?id=1C3VxBlWbOwjtvlFe_5cRZpFFY0H4NJRp) 51 | ResNet-50|30%|scratch-B| 75.60| 92.75| [pytorch model (195 MB)](https://drive.google.com/open?id=1z3ABz6Pk0drVueWJRucG68MeGnAyA6t7) 52 | ResNet-50|60%|finetune| 76.09| 92.91| [pytorch model (195 MB)](https://drive.google.com/open?id=1iTwXpW61OodacsefyuSDljtGPj0FxvUY) 53 | ResNet-50|60%|scratch-E| 73.69| 91.61| [pytorch model (195 MB)](https://drive.google.com/open?id=1LYyCHVypbkkS23RVOgcE8clUs3haRlHA) 54 | ResNet-50|60%|scratch-B| 74.90| 92.28| [pytorch model (195 MB)](https://drive.google.com/open?id=17pqC05Sakt18xoRnpddAf-sYNs_vaKPk) --------------------------------------------------------------------------------
/imagenet/weight-level/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 2 | import numpy as np 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | 9 | import torch 10 | import torchvision 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | import torch.optim 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | 23 | 24 | def print_model_param_nums(model=None): 25 | if model == None: 26 | model = torchvision.models.alexnet() 27 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 28 | print(' + Number of params: %.2fM' % (total / 1e6)) 29 | 30 | def count_model_param_flops(model=None, input_res=224, multiply_adds=True): 31 | 32 | prods = {} 33 | def save_hook(name): 34 | def hook_per(self, input, output): 35 | prods[name] = np.prod(input[0].shape) 36 | return hook_per 37 | 38 | list_1=[] 39 | def simple_hook(self, input, output): 40 | list_1.append(np.prod(input[0].shape)) 41 | list_2={} 42 | def simple_hook2(self, input, output): 43 | list_2['names'] = np.prod(input[0].shape) 44 | 45 | 46 | list_conv=[] 47 | def conv_hook(self, input, output): 48 | batch_size, input_channels, input_height, input_width = input[0].size() 49 | output_channels, output_height, output_width = output[0].size() 50 | 51 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 52 | bias_ops = 1 if self.bias is not None else 0 53 | 54 | params = output_channels * (kernel_ops + bias_ops) 55 | 56 | num_weight_params = (self.weight.data != 0).float().sum() 57 | assert self.weight.numel() == kernel_ops * output_channels, "Not match" 58 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 59 | 60 | list_conv.append(flops) 61 | 62 | list_linear=[] 63 | def linear_hook(self, input, output): 64 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 65 | 66 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 67 | bias_ops = self.bias.nelement() 68 | 69 | flops = batch_size * (weight_ops + bias_ops) 70 | list_linear.append(flops) 71 | 72 | list_bn=[] 73 | def bn_hook(self, input, output): 74 | list_bn.append(input[0].nelement() * 2) 75 | 76 | list_relu=[] 77 | def relu_hook(self, input, output): 78 | list_relu.append(input[0].nelement()) 79 | 80 | list_pooling=[] 81 | def pooling_hook(self, input, output): 82 | batch_size, input_channels, input_height, input_width = input[0].size() 83 | output_channels, output_height, output_width = output[0].size() 84 | 85 | kernel_ops = self.kernel_size * self.kernel_size 86 | bias_ops = 0 87 | params = 0 88 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 89 | 90 | list_pooling.append(flops) 91 | 92 | list_upsample=[] 93 | # For bilinear upsample 94 | def upsample_hook(self, input, output): 95 | batch_size, input_channels, input_height, input_width = input[0].size() 96 | output_channels, output_height, output_width = output[0].size() 97 | 98 | flops = output_height * output_width * output_channels * batch_size * 12 99 | 
list_upsample.append(flops) 100 | 101 | def foo(net): 102 | childrens = list(net.children()) 103 | if not childrens: 104 | if isinstance(net, torch.nn.Conv2d): 105 | net.register_forward_hook(conv_hook) 106 | if isinstance(net, torch.nn.Linear): 107 | net.register_forward_hook(linear_hook) 108 | if isinstance(net, torch.nn.BatchNorm2d): 109 | net.register_forward_hook(bn_hook) 110 | if isinstance(net, torch.nn.ReLU): 111 | net.register_forward_hook(relu_hook) 112 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 113 | net.register_forward_hook(pooling_hook) 114 | if isinstance(net, torch.nn.Upsample): 115 | net.register_forward_hook(upsample_hook) 116 | return 117 | for c in childrens: 118 | foo(c) 119 | 120 | if model == None: 121 | model = torchvision.models.alexnet() 122 | foo(model) 123 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 124 | out = model(input) 125 | 126 | 127 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 128 | 129 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 130 | 131 | return total_flops --------------------------------------------------------------------------------
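To tie the FLOPs-counting utilities above back to the scratch-B schedules in the READMEs, here is a hypothetical usage sketch (an illustration, not a script from this repo). It assumes you run it from `imagenet/regression-pruning`, so that its `compute_flops` module and `models` package are importable, and uses the `vgg_official` / `vgg_5x` pair defined there:

```
# Hypothetical usage sketch: measure the FLOPs reduction of a pruned model
# relative to its unpruned baseline, the ratio used to set scratch-B budgets.
from compute_flops import count_model_param_flops, print_model_param_nums
from models import vgg_official, vgg_5x

print_model_param_nums(vgg_5x())  # parameter count of the pruned model
base_flops = count_model_param_flops(vgg_official(), input_res=224)
pruned_flops = count_model_param_flops(vgg_5x(), input_res=224)
print('FLOPs reduction: %.2fx' % (float(base_flops) / float(pruned_flops)))
```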