├── mask-impl ├── models │ ├── __init__.py │ ├── vgg.py │ ├── preresnet.py │ └── densenet.py ├── README.md ├── prune_mask.py └── main_mask.py ├── models ├── __init__.py ├── channel_selection.py ├── vgg.py ├── preresnet.py └── densenet.py ├── LICENSE ├── README.md ├── vggprune.py ├── denseprune.py ├── main.py └── resprune.py /mask-impl/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | from .preresnet import * 5 | from .densenet import * -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | from .preresnet import * 5 | from .densenet import * 6 | from .channel_selection import * -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Mingjie Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /models/channel_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class channel_selection(nn.Module): 7 | """ 8 | Select a subset of channels from the output of a BatchNorm2d layer. Place it directly after that BatchNorm2d layer. 9 | The number of output channels equals the number of nonzero entries in `self.indexes`. 10 | """ 11 | def __init__(self, num_channels): 12 | """ 13 | Initialize `indexes` as an all-ones vector of length `num_channels`, so every channel is kept. 14 | During pruning, the entries of `indexes` that correspond to the channels to be pruned are set to 0. 15 | """ 16 | super(channel_selection, self).__init__() 17 | self.indexes = nn.Parameter(torch.ones(num_channels)) 18 | 19 | def forward(self, input_tensor): 20 | """ 21 | Parameters 22 | ---------- 23 | input_tensor: (N,C,H,W), expected to be the output of a BatchNorm2d layer. Returns it restricted to the channels whose `indexes` entry is nonzero. 24 | """ 25 | selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy())) 26 | if selected_index.size == 1: # np.squeeze reduces a single match to a 0-d array; resize it to shape (1,) so it is still a 1-d index list 27 | selected_index = np.resize(selected_index, (1,)) 28 | output = input_tensor[:, selected_index, :, :] 29 | return output -------------------------------------------------------------------------------- /mask-impl/README.md: -------------------------------------------------------------------------------- 1 | ## Mask Implementation of Network Slimming 2 | During pruning, we set those scaling factors in BN layer which correspond to pruned channels to be 0.
3 | When training the pruned model, in each iteration, before 4 | we call `optimizer.step()`, we update the gradient of those 0 scaling factors to be 0. This is achieved in `BN_grad_zero` function. 5 | ### Pros 6 | - We don't need to introduce channel selection layer which adds to the training time. 7 | - Even if a layer is pruned to zero channels, it won't raise any error. Instead, this layer will simply output an all-0 tensor. 8 | 9 | ### Cons 10 | - Not easy to compute flops and parameters. 11 | 12 | ## Baseline 13 | ```shell 14 | python main_mask.py --dataset cifar100 --arch resnet --depth 164 15 | ``` 16 | 17 | ## Sparsity 18 | ```shell 19 | python main_mask.py --dataset cifar100 --arch resnet --depth 164 -sr --s 0.00001 20 | ``` 21 | 22 | ## Prune 23 | ```shell 24 | python prune_mask.py --dataset cifar100 --arch resnet --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 25 | ``` 26 | 27 | ## Fine-tune 28 | ```shell 29 | python main_mask.py --dataset cifar100 --arch resnet --depth 164 --refine [DIRECTORY TO THE PRUNED MODEL] 30 | ``` 31 | ## Results 32 | | CIFAR100-Resnet-164 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) | 33 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |:--------------------: | :-----------------:| 34 | | Top1 Accuracy (%) | 76.68 | 76.89 | 48.61 | 77.33 | 1.91 | 76.07 | 35 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | __all__ = ['vgg'] 8 | 9 | defaultcfg = { 10 | 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 11 | 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 16 : [64, 64, 'M', 128, 128, 'M', 256, 
256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 13 | 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 14 | } 15 | 16 | class vgg(nn.Module): 17 | def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None): 18 | super(vgg, self).__init__() 19 | if cfg is None: 20 | cfg = defaultcfg[depth] 21 | 22 | self.feature = self.make_layers(cfg, True) 23 | 24 | if dataset == 'cifar10': 25 | num_classes = 10 26 | elif dataset == 'cifar100': 27 | num_classes = 100 28 | self.classifier = nn.Linear(cfg[-1], num_classes) 29 | if init_weights: 30 | self._initialize_weights() 31 | 32 | def make_layers(self, cfg, batch_norm=False): 33 | layers = [] 34 | in_channels = 3 35 | for v in cfg: 36 | if v == 'M': 37 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 38 | else: 39 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 40 | if batch_norm: 41 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 42 | else: 43 | layers += [conv2d, nn.ReLU(inplace=True)] 44 | in_channels = v 45 | return nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | x = self.feature(x) 49 | x = nn.AvgPool2d(2)(x) 50 | x = x.view(x.size(0), -1) 51 | y = self.classifier(x) 52 | return y 53 | 54 | def _initialize_weights(self): 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 58 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | elif isinstance(m, nn.BatchNorm2d): 62 | m.weight.data.fill_(0.5) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.Linear): 65 | m.weight.data.normal_(0, 0.01) 66 | m.bias.data.zero_() 67 | 68 | if __name__ == '__main__': 69 | net = vgg() 70 | x = Variable(torch.FloatTensor(16, 3, 40, 40)) 71 | y = net(x) 72 | print(y.data.shape) -------------------------------------------------------------------------------- /mask-impl/models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | __all__ = ['vgg'] 8 | 9 | defaultcfg = { 10 | 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 11 | 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 13 | 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 14 | } 15 | 16 | class vgg(nn.Module): 17 | def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None): 18 | super(vgg, self).__init__() 19 | if cfg is None: 20 | cfg = defaultcfg[depth] 21 | 22 | self.feature = self.make_layers(cfg, True) 23 | 24 | if dataset == 'cifar10': 25 | num_classes = 10 26 | elif dataset == 'cifar100': 27 | num_classes = 100 28 | self.classifier = nn.Linear(cfg[-1], num_classes) 29 | if init_weights: 30 | self._initialize_weights() 31 | 32 | def make_layers(self, cfg, batch_norm=False): 33 | layers = [] 34 | in_channels = 3 35 | for v in cfg: 36 | if v == 'M': 37 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 38 | else: 39 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 40 | if batch_norm: 41 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 42 | else: 43 | layers += [conv2d, nn.ReLU(inplace=True)] 44 | 
in_channels = v 45 | return nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | x = self.feature(x) 49 | x = nn.AvgPool2d(2)(x) 50 | x = x.view(x.size(0), -1) 51 | y = self.classifier(x) 52 | return y 53 | 54 | def _initialize_weights(self): 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 58 | m.weight.data.normal_(0, math.sqrt(2. / n)) 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | elif isinstance(m, nn.BatchNorm2d): 62 | m.weight.data.fill_(0.5) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.Linear): 65 | m.weight.data.normal_(0, 0.01) 66 | m.bias.data.zero_() 67 | 68 | if __name__ == '__main__': 69 | net = vgg() 70 | x = Variable(torch.FloatTensor(16, 3, 40, 40)) 71 | y = net(x) 72 | print(y.data.shape) -------------------------------------------------------------------------------- /mask-impl/models/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import math 3 | import torch.nn as nn 4 | 5 | 6 | __all__ = ['resnet'] 7 | 8 | """ 9 | preactivation resnet with bottleneck design. 
10 | """ 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, cfg, stride=1, downsample=None): 16 | super(Bottleneck, self).__init__() 17 | self.bn1 = nn.BatchNorm2d(inplanes) 18 | self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(cfg[1]) 20 | self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(cfg[2]) 23 | self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.bn1(x) 32 | out = self.relu(out) 33 | out = self.conv1(out) 34 | 35 | out = self.bn2(out) 36 | out = self.relu(out) 37 | out = self.conv2(out) 38 | 39 | out = self.bn3(out) 40 | out = self.relu(out) 41 | out = self.conv3(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | 48 | return out 49 | 50 | class resnet(nn.Module): 51 | def __init__(self, depth=164, dataset='cifar10', cfg=None): 52 | super(resnet, self).__init__() 53 | assert (depth - 2) % 9 == 0, 'depth should be 9n+2' 54 | 55 | n = (depth - 2) // 9 56 | block = Bottleneck 57 | 58 | if cfg is None: 59 | # Construct config variable. 
60 | cfg = [[16, 16, 16], [64, 16, 16]*(n-1), [64, 32, 32], [128, 32, 32]*(n-1), [128, 64, 64], [256, 64, 64]*(n-1), [256]] 61 | cfg = [item for sub_list in cfg for item in sub_list] 62 | 63 | self.inplanes = 16 64 | 65 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 66 | bias=False) 67 | self.layer1 = self._make_layer(block, 16, n, cfg = cfg[0:3*n]) 68 | self.layer2 = self._make_layer(block, 32, n, cfg = cfg[3*n:6*n], stride=2) 69 | self.layer3 = self._make_layer(block, 64, n, cfg = cfg[6*n:9*n], stride=2) 70 | self.bn = nn.BatchNorm2d(64 * block.expansion) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.avgpool = nn.AvgPool2d(8) 73 | 74 | if dataset == 'cifar10': 75 | self.fc = nn.Linear(cfg[-1], 10) 76 | elif dataset == 'cifar100': 77 | self.fc = nn.Linear(cfg[-1], 100) 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 82 | m.weight.data.normal_(0, math.sqrt(2. / n)) 83 | elif isinstance(m, nn.BatchNorm2d): 84 | m.weight.data.fill_(0.5) 85 | m.bias.data.zero_() 86 | 87 | def _make_layer(self, block, planes, blocks, cfg, stride=1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | nn.Conv2d(self.inplanes, planes * block.expansion, 92 | kernel_size=1, stride=stride, bias=False), 93 | ) 94 | 95 | layers = [] 96 | layers.append(block(self.inplanes, planes, cfg[0:3], stride, downsample)) 97 | self.inplanes = planes * block.expansion 98 | for i in range(1, blocks): 99 | layers.append(block(self.inplanes, planes, cfg[3*i: 3*(i+1)])) 100 | 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | x = self.conv1(x) 105 | 106 | x = self.layer1(x) # 32x32 107 | x = self.layer2(x) # 16x16 108 | x = self.layer3(x) # 8x8 109 | x = self.bn(x) 110 | x = self.relu(x) 111 | 112 | x = self.avgpool(x) 113 | x = x.view(x.size(0), -1) 114 | x = self.fc(x) 115 | 116 | return x 
-------------------------------------------------------------------------------- /models/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import math 3 | import torch.nn as nn 4 | from .channel_selection import channel_selection 5 | 6 | 7 | __all__ = ['resnet'] 8 | 9 | """ 10 | preactivation resnet with bottleneck design. 11 | """ 12 | 13 | class Bottleneck(nn.Module): 14 | expansion = 4 15 | 16 | def __init__(self, inplanes, planes, cfg, stride=1, downsample=None): 17 | super(Bottleneck, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.select = channel_selection(inplanes) 20 | self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(cfg[1]) 22 | self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | self.bn3 = nn.BatchNorm2d(cfg[2]) 25 | self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.bn1(x) 34 | out = self.select(out) 35 | out = self.relu(out) 36 | out = self.conv1(out) 37 | 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | out = self.conv2(out) 41 | 42 | out = self.bn3(out) 43 | out = self.relu(out) 44 | out = self.conv3(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | 51 | return out 52 | 53 | class resnet(nn.Module): 54 | def __init__(self, depth=164, dataset='cifar10', cfg=None): 55 | super(resnet, self).__init__() 56 | assert (depth - 2) % 9 == 0, 'depth should be 9n+2' 57 | 58 | n = (depth - 2) // 9 59 | block = Bottleneck 60 | 61 | if cfg is None: 62 | # Construct config variable. 
63 | cfg = [[16, 16, 16], [64, 16, 16]*(n-1), [64, 32, 32], [128, 32, 32]*(n-1), [128, 64, 64], [256, 64, 64]*(n-1), [256]] 64 | cfg = [item for sub_list in cfg for item in sub_list] 65 | 66 | self.inplanes = 16 67 | 68 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 69 | bias=False) 70 | self.layer1 = self._make_layer(block, 16, n, cfg = cfg[0:3*n]) 71 | self.layer2 = self._make_layer(block, 32, n, cfg = cfg[3*n:6*n], stride=2) 72 | self.layer3 = self._make_layer(block, 64, n, cfg = cfg[6*n:9*n], stride=2) 73 | self.bn = nn.BatchNorm2d(64 * block.expansion) 74 | self.select = channel_selection(64 * block.expansion) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.avgpool = nn.AvgPool2d(8) 77 | 78 | if dataset == 'cifar10': 79 | self.fc = nn.Linear(cfg[-1], 10) 80 | elif dataset == 'cifar100': 81 | self.fc = nn.Linear(cfg[-1], 100) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(0.5) 89 | m.bias.data.zero_() 90 | 91 | def _make_layer(self, block, planes, blocks, cfg, stride=1): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = nn.Sequential( 95 | nn.Conv2d(self.inplanes, planes * block.expansion, 96 | kernel_size=1, stride=stride, bias=False), 97 | ) 98 | 99 | layers = [] 100 | layers.append(block(self.inplanes, planes, cfg[0:3], stride, downsample)) 101 | self.inplanes = planes * block.expansion 102 | for i in range(1, blocks): 103 | layers.append(block(self.inplanes, planes, cfg[3*i: 3*(i+1)])) 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | x = self.conv1(x) 109 | 110 | x = self.layer1(x) # 32x32 111 | x = self.layer2(x) # 16x16 112 | x = self.layer3(x) # 8x8 113 | x = self.bn(x) 114 | x = self.select(x) 115 | x = self.relu(x) 116 | 117 | x = self.avgpool(x) 118 | x = x.view(x.size(0), -1) 119 | x = self.fc(x) 120 | 121 | return x -------------------------------------------------------------------------------- /mask-impl/models/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | __all__ = ['densenet'] 9 | 10 | """ 11 | densenet with basic block. 
12 | """ 13 | 14 | class BasicBlock(nn.Module): 15 | def __init__(self, inplanes, cfg, expansion=1, growthRate=12, dropRate=0): 16 | super(BasicBlock, self).__init__() 17 | planes = expansion * growthRate 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.conv1 = nn.Conv2d(cfg, growthRate, kernel_size=3, 20 | padding=1, bias=False) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.dropRate = dropRate 23 | 24 | def forward(self, x): 25 | out = self.bn1(x) 26 | out = self.relu(out) 27 | out = self.conv1(out) 28 | if self.dropRate > 0: 29 | out = F.dropout(out, p=self.dropRate, training=self.training) 30 | 31 | out = torch.cat((x, out), 1) 32 | 33 | return out 34 | 35 | class Transition(nn.Module): 36 | def __init__(self, inplanes, outplanes, cfg): 37 | super(Transition, self).__init__() 38 | self.bn1 = nn.BatchNorm2d(inplanes) 39 | self.conv1 = nn.Conv2d(cfg, outplanes, kernel_size=1, 40 | bias=False) 41 | self.relu = nn.ReLU(inplace=True) 42 | 43 | def forward(self, x): 44 | out = self.bn1(x) 45 | out = self.relu(out) 46 | out = self.conv1(out) 47 | out = F.avg_pool2d(out, 2) 48 | return out 49 | 50 | class densenet(nn.Module): 51 | 52 | def __init__(self, depth=40, 53 | dropRate=0, dataset='cifar10', growthRate=12, compressionRate=1, cfg = None): 54 | super(densenet, self).__init__() 55 | 56 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 57 | n = (depth - 4) // 3 58 | block = BasicBlock 59 | 60 | self.growthRate = growthRate 61 | self.dropRate = dropRate 62 | 63 | if cfg == None: 64 | cfg = [] 65 | start = growthRate*2 66 | for i in range(3): 67 | cfg.append([start+ growthRate*i for i in range(n+1)]) 68 | start += growthRate*n 69 | cfg = [item for sub_list in cfg for item in sub_list] 70 | 71 | assert len(cfg) == 3*n+3, 'length of config variable cfg should be 3n+3' 72 | 73 | # self.inplanes is a global variable used across multiple 74 | # helper functions 75 | self.inplanes = growthRate * 2 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, 
padding=1, 77 | bias=False) 78 | self.dense1 = self._make_denseblock(block, n, cfg[0:n]) 79 | self.trans1 = self._make_transition(compressionRate, cfg[n]) 80 | self.dense2 = self._make_denseblock(block, n, cfg[n+1:2*n+1]) 81 | self.trans2 = self._make_transition(compressionRate, cfg[2*n+1]) 82 | self.dense3 = self._make_denseblock(block, n, cfg[2*n+2:3*n+2]) 83 | self.bn = nn.BatchNorm2d(self.inplanes) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.avgpool = nn.AvgPool2d(8) 86 | 87 | if dataset == 'cifar10': 88 | self.fc = nn.Linear(cfg[-1], 10) 89 | elif dataset == 'cifar100': 90 | self.fc = nn.Linear(cfg[-1], 100) 91 | 92 | # Weight initialization 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 96 | m.weight.data.normal_(0, math.sqrt(2. / n)) 97 | elif isinstance(m, nn.BatchNorm2d): 98 | m.weight.data.fill_(0.5) 99 | m.bias.data.zero_() 100 | 101 | def _make_denseblock(self, block, blocks, cfg): 102 | layers = [] 103 | assert blocks == len(cfg), 'Length of the cfg parameter is not right.' 104 | for i in range(blocks): 105 | # Currently we fix the expansion ratio as the default value 106 | layers.append(block(self.inplanes, cfg = cfg[i], growthRate=self.growthRate, dropRate=self.dropRate)) 107 | self.inplanes += self.growthRate 108 | 109 | return nn.Sequential(*layers) 110 | 111 | def _make_transition(self, compressionRate, cfg): 112 | # cfg is a number in this case. 
113 | inplanes = self.inplanes 114 | outplanes = int(math.floor(self.inplanes // compressionRate)) 115 | self.inplanes = outplanes 116 | return Transition(inplanes, outplanes, cfg) 117 | 118 | def forward(self, x): 119 | x = self.conv1(x) 120 | 121 | x = self.trans1(self.dense1(x)) 122 | x = self.trans2(self.dense2(x)) 123 | x = self.dense3(x) 124 | x = self.bn(x) 125 | x = self.relu(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc(x) 130 | 131 | return x -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from .channel_selection import channel_selection 7 | 8 | 9 | __all__ = ['densenet'] 10 | 11 | """ 12 | densenet with basic block. 13 | """ 14 | 15 | class BasicBlock(nn.Module): 16 | def __init__(self, inplanes, cfg, expansion=1, growthRate=12, dropRate=0): 17 | super(BasicBlock, self).__init__() 18 | planes = expansion * growthRate 19 | self.bn1 = nn.BatchNorm2d(inplanes) 20 | self.select = channel_selection(inplanes) 21 | self.conv1 = nn.Conv2d(cfg, growthRate, kernel_size=3, 22 | padding=1, bias=False) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.dropRate = dropRate 25 | 26 | def forward(self, x): 27 | out = self.bn1(x) 28 | out = self.select(out) 29 | out = self.relu(out) 30 | out = self.conv1(out) 31 | if self.dropRate > 0: 32 | out = F.dropout(out, p=self.dropRate, training=self.training) 33 | 34 | out = torch.cat((x, out), 1) 35 | 36 | return out 37 | 38 | class Transition(nn.Module): 39 | def __init__(self, inplanes, outplanes, cfg): 40 | super(Transition, self).__init__() 41 | self.bn1 = nn.BatchNorm2d(inplanes) 42 | self.select = channel_selection(inplanes) 43 | self.conv1 = nn.Conv2d(cfg, outplanes, kernel_size=1, 44 | bias=False) 45 | self.relu = 
nn.ReLU(inplace=True) 46 | 47 | def forward(self, x): 48 | out = self.bn1(x) 49 | out = self.select(out) 50 | out = self.relu(out) 51 | out = self.conv1(out) 52 | out = F.avg_pool2d(out, 2) 53 | return out 54 | 55 | class densenet(nn.Module): 56 | 57 | def __init__(self, depth=40, 58 | dropRate=0, dataset='cifar10', growthRate=12, compressionRate=1, cfg = None): 59 | super(densenet, self).__init__() 60 | 61 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 62 | n = (depth - 4) // 3 63 | block = BasicBlock 64 | 65 | self.growthRate = growthRate 66 | self.dropRate = dropRate 67 | 68 | if cfg == None: 69 | cfg = [] 70 | start = growthRate*2 71 | for _ in range(3): 72 | cfg.append([start + growthRate*i for i in range(n+1)]) 73 | start += growthRate*n 74 | cfg = [item for sub_list in cfg for item in sub_list] 75 | 76 | assert len(cfg) == 3*n+3, 'length of config variable cfg should be 3n+3' 77 | 78 | # self.inplanes is a global variable used across multiple 79 | # helper functions 80 | self.inplanes = growthRate * 2 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 82 | bias=False) 83 | self.dense1 = self._make_denseblock(block, n, cfg[0:n]) 84 | self.trans1 = self._make_transition(compressionRate, cfg[n]) 85 | self.dense2 = self._make_denseblock(block, n, cfg[n+1:2*n+1]) 86 | self.trans2 = self._make_transition(compressionRate, cfg[2*n+1]) 87 | self.dense3 = self._make_denseblock(block, n, cfg[2*n+2:3*n+2]) 88 | self.bn = nn.BatchNorm2d(self.inplanes) 89 | self.select = channel_selection(self.inplanes) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.avgpool = nn.AvgPool2d(8) 92 | 93 | if dataset == 'cifar10': 94 | self.fc = nn.Linear(cfg[-1], 10) 95 | elif dataset == 'cifar100': 96 | self.fc = nn.Linear(cfg[-1], 100) 97 | 98 | # Weight initialization 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | m.weight.data.fill_(0.5) 105 | m.bias.data.zero_() 106 | 107 | def _make_denseblock(self, block, blocks, cfg): 108 | layers = [] 109 | assert blocks == len(cfg), 'Length of the cfg parameter is not right.' 110 | for i in range(blocks): 111 | # Currently we fix the expansion ratio as the default value 112 | layers.append(block(self.inplanes, cfg = cfg[i], growthRate=self.growthRate, dropRate=self.dropRate)) 113 | self.inplanes += self.growthRate 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def _make_transition(self, compressionRate, cfg): 118 | # cfg is a number in this case. 119 | inplanes = self.inplanes 120 | outplanes = int(math.floor(self.inplanes // compressionRate)) 121 | self.inplanes = outplanes 122 | return Transition(inplanes, outplanes, cfg) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | 127 | x = self.trans1(self.dense1(x)) 128 | x = self.trans2(self.dense2(x)) 129 | x = self.dense3(x) 130 | x = self.bn(x) 131 | x = self.select(x) 132 | x = self.relu(x) 133 | 134 | x = self.avgpool(x) 135 | x = x.view(x.size(0), -1) 136 | x = self.fc(x) 137 | 138 | return x -------------------------------------------------------------------------------- /mask-impl/prune_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | from torchvision import datasets, transforms 8 | import models 9 | 10 | 11 | # Prune settings 12 | parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune') 13 | parser.add_argument('--dataset', type=str, default='cifar100', 14 | help='training dataset (default: cifar100)') 15 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 16 | help='input batch size for testing (default: 100)') 17 | parser.add_argument('--no-cuda', action='store_true', 
default=False, 18 | help='disables CUDA training') 19 | parser.add_argument('--percent', type=float, default=0.5, 20 | help='scale sparse rate (default: 0.5)') 21 | parser.add_argument('--model', default='', type=str, metavar='PATH', 22 | help='path to raw trained model (default: none)') 23 | parser.add_argument('--save', default='.', type=str, metavar='PATH', 24 | help='path to save prune model (default: none)') 25 | parser.add_argument('--depth', default=19, type=int, 26 | help='depth of resnet and densenet') 27 | parser.add_argument('--arch', default='vgg', type=str, 28 | help='architecture to use') 29 | args = parser.parse_args() 30 | args.cuda = not args.no_cuda and torch.cuda.is_available() 31 | 32 | if not os.path.exists(args.save): 33 | os.makedirs(args.save) 34 | 35 | model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth) 36 | 37 | if args.cuda: 38 | model.cuda() 39 | if args.model: 40 | if os.path.isfile(args.model): 41 | print("=> loading checkpoint '{}'".format(args.model)) 42 | checkpoint = torch.load(args.model) 43 | args.start_epoch = checkpoint['epoch'] 44 | best_prec1 = checkpoint['best_prec1'] 45 | model.load_state_dict(checkpoint['state_dict']) 46 | print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}" 47 | .format(args.model, checkpoint['epoch'], best_prec1)) 48 | 49 | print(model) 50 | total = 0 51 | for m in model.modules(): 52 | if isinstance(m, nn.BatchNorm2d): 53 | total += m.weight.data.shape[0] 54 | 55 | bn = torch.zeros(total) 56 | index = 0 57 | for m in model.modules(): 58 | if isinstance(m, nn.BatchNorm2d): 59 | size = m.weight.data.shape[0] 60 | bn[index:(index+size)] = m.weight.data.abs().clone() 61 | index += size 62 | 63 | y, i = torch.sort(bn) 64 | thre_index = int(total * args.percent) 65 | thre = y[thre_index] 66 | 67 | pruned = 0 68 | cfg = [] 69 | cfg_mask = [] 70 | for k, m in enumerate(model.modules()): 71 | if isinstance(m, nn.BatchNorm2d): 72 | weight_copy = m.weight.data.abs().clone() 73 | mask = 
weight_copy.gt(thre).float() # mask stays on the same device as the BN weights, so both --cuda and --no-cuda runs work 74 | pruned = pruned + mask.shape[0] - torch.sum(mask) 75 | m.weight.data.mul_(mask) 76 | m.bias.data.mul_(mask) 77 | cfg.append(int(torch.sum(mask))) 78 | cfg_mask.append(mask.clone()) 79 | print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'. 80 | format(k, mask.shape[0], int(torch.sum(mask)))) 81 | elif isinstance(m, nn.MaxPool2d): 82 | cfg.append('M') 83 | 84 | torch.save({'cfg': cfg, 'state_dict': model.state_dict()}, os.path.join(args.save, 'pruned.pth.tar')) 85 | 86 | pruned_ratio = pruned/total 87 | 88 | print('Pre-processing Successful!') 89 | 90 | 91 | # simple test model after Pre-processing prune (simple set BN scales to zeros) 92 | def test(): 93 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 94 | if args.dataset == 'cifar10': 95 | test_loader = torch.utils.data.DataLoader( 96 | datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ 97 | transforms.ToTensor(), 98 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 99 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 100 | elif args.dataset == 'cifar100': 101 | test_loader = torch.utils.data.DataLoader( 102 | datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([ 103 | transforms.ToTensor(), 104 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 105 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 106 | else: 107 | raise ValueError("No valid dataset is given.") 108 | model.eval() 109 | correct = 0 110 | for data, target in test_loader: 111 | if args.cuda: 112 | data, target = data.cuda(), target.cuda() 113 | data, target = Variable(data, volatile=True), Variable(target) 114 | output = model(data) 115 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 116 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 117 | 118 | print('\nTest set: Accuracy: {}/{}
({:.1f}%)\n'.format( 119 | correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) 120 | return correct / float(len(test_loader.dataset)) 121 | 122 | acc = test() 123 | print(cfg) 124 | 125 | savepath = os.path.join(args.save, "prune.txt") 126 | with open(savepath, "w") as fp: 127 | fp.write("Configuration: \n") 128 | fp.write(str(cfg)+"\n") 129 | fp.write("Test accuracy: \n") 130 | fp.write(str(acc)) 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Network Slimming (Pytorch) 2 | 3 | This repository contains an official pytorch implementation for the following paper 4 | [Learning Efficient Convolutional Networks Through Network Slimming](http://openaccess.thecvf.com/content_iccv_2017/html/Liu_Learning_Efficient_Convolutional_ICCV_2017_paper.html) (ICCV 2017). 5 | [Zhuang Liu](https://liuzhuang13.github.io/), [Jianguo Li](https://sites.google.com/site/leeplus/), [Zhiqiang Shen](http://zhiqiangshen.com/), [Gao Huang](http://www.cs.cornell.edu/~gaohuang/), [Shoumeng Yan](https://scholar.google.com/citations?user=f0BtDUQAAAAJ&hl=en), [Changshui Zhang](http://bigeye.au.tsinghua.edu.cn/english/Introduction.html). 6 | 7 | Original implementation: [slimming](https://github.com/liuzhuang13/slimming) in Torch. 8 | The code is based on [pytorch-slimming](https://github.com/foolwood/pytorch-slimming). We add support for ResNet and DenseNet. 
9 | 10 | Citation: 11 | ``` 12 | @InProceedings{Liu_2017_ICCV, 13 | author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui}, 14 | title = {Learning Efficient Convolutional Networks Through Network Slimming}, 15 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 16 | month = {Oct}, 17 | year = {2017} 18 | } 19 | ``` 20 | 21 | 22 | ## Dependencies 23 | torch v0.3.1, torchvision v0.2.0 24 | 25 | ## Channel Selection Layer 26 | We introduce a `channel selection` layer to help prune ResNet and DenseNet. This layer is easy to implement. It stores a parameter `indexes` which is initialized to an all-1 vector. During pruning, the entries of `indexes` that correspond to the pruned channels are set to 0. 27 | 28 | ## Baseline 29 | 30 | The `dataset` argument specifies which dataset to use: `cifar10` or `cifar100`. The `arch` argument specifies the architecture to use: `vgg`, `resnet` or 31 | `densenet`. The depth is chosen to be the same as that of the networks used in the paper.
32 | ```shell 33 | python main.py --dataset cifar10 --arch vgg --depth 19 34 | python main.py --dataset cifar10 --arch resnet --depth 164 35 | python main.py --dataset cifar10 --arch densenet --depth 40 36 | ``` 37 | 38 | ## Train with Sparsity 39 | 40 | ```shell 41 | python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19 42 | python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164 43 | python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40 44 | ``` 45 | 46 | ## Prune 47 | 48 | ```shell 49 | python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 50 | python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 51 | python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 52 | ``` 53 | The pruned model will be named `pruned.pth.tar`. 54 | 55 | ## Fine-tune 56 | 57 | ```shell 58 | python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160 59 | ``` 60 | 61 | ## Results 62 | 63 | The results are fairly close to those of the original paper, which were produced with Torch. Note that, due to different random seeds, accuracy may fluctuate by up to ~0.5%/1.5% on CIFAR-10/100 between runs, in our experience.
64 | ### CIFAR10 65 | | CIFAR10-Vgg | Baseline | Sparsity (1e-4) | Prune (70%) | Fine-tune-160(70%) | 66 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | 67 | | Top1 Accuracy (%) | 93.77 | 93.30 | 32.54 | 93.78 | 68 | | Parameters | 20.04M | 20.04M | 2.25M | 2.25M | 69 | 70 | | CIFAR10-Resnet-164 | Baseline | Sparsity (1e-5) | Prune(40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) | 71 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | :----------------:| :--------------------:| 72 | | Top1 Accuracy (%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 | 73 | | Parameters | 1.71M | 1.73M | 1.45M | 1.45M | 1.12M | 1.12M | 74 | 75 | | CIFAR10-Densenet-40 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) | 76 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | :--------------------: | :-----------------:| 77 | | Top1 Accuracy (%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 | 78 | | Parameters | 1.07M | 1.07M | 0.69M | 0.69M | 0.49M | 0.49M | 79 | 80 | ### CIFAR100 81 | | CIFAR100-Vgg | Baseline | Sparsity (1e-4) | Prune (50%) | Fine-tune-160(50%) | 82 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: | 83 | | Top1 Accuracy (%) | 72.12 | 72.05 | 5.31 | 73.32 | 84 | | Parameters | 20.04M | 20.04M | 4.93M | 4.93M | 85 | 86 | | CIFAR100-Resnet-164 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) | 87 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |:--------------------: | :-----------------:| 88 | | Top1 Accuracy (%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- | 89 | | Parameters | 1.73M | 1.73M | 1.49M | 1.49M |--- | --- | 90 | 91 | Note: For results of pruning 
60% of the channels for resnet164-cifar100, in this implementation, sometimes some layers are all pruned and there would be error. However, we also provide a [mask implementation](https://github.com/Eric-mingjie/network-slimming/tree/master/mask-impl) where we apply a mask to the scaling factor in BN layer. For mask implementaion, when pruning 60% of the channels in resnet164-cifar100, we can also train the pruned network. 92 | 93 | | CIFAR100-Densenet-40 | Baseline | Sparsity (1e-5) | Prune (40%) | Fine-tune-160(40%) | Prune(60%) | Fine-tune-160(60%) | 94 | | :---------------: | :------: | :--------------------------: | :-----------------: | :-------------------: |:--------------------: | :-----------------:| 95 | | Top1 Accuracy (%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 | 96 | | Parameters | 1.10M | 1.10M | 0.71M | 0.71M | 0.50M | 0.50M | 97 | 98 | ## Contact 99 | sunmj15 at gmail.com 100 | liuzhuangthu at gmail.com 101 | -------------------------------------------------------------------------------- /vggprune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | from torchvision import datasets, transforms 8 | from models import * 9 | 10 | 11 | # Prune settings 12 | parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune') 13 | parser.add_argument('--dataset', type=str, default='cifar100', 14 | help='training dataset (default: cifar10)') 15 | parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', 16 | help='input batch size for testing (default: 256)') 17 | parser.add_argument('--no-cuda', action='store_true', default=False, 18 | help='disables CUDA training') 19 | parser.add_argument('--depth', type=int, default=19, 20 | help='depth of the vgg') 21 | parser.add_argument('--percent', type=float, default=0.5, 22 | help='scale sparse rate 
(default: 0.5)') 23 | parser.add_argument('--model', default='', type=str, metavar='PATH', 24 | help='path to the model (default: none)') 25 | parser.add_argument('--save', default='', type=str, metavar='PATH', 26 | help='path to save pruned model (default: none)') 27 | args = parser.parse_args() 28 | args.cuda = not args.no_cuda and torch.cuda.is_available() 29 | 30 | if not os.path.exists(args.save): 31 | os.makedirs(args.save) 32 | 33 | model = vgg(dataset=args.dataset, depth=args.depth) 34 | if args.cuda: 35 | model.cuda() 36 | 37 | if args.model: 38 | if os.path.isfile(args.model): 39 | print("=> loading checkpoint '{}'".format(args.model)) 40 | checkpoint = torch.load(args.model) 41 | args.start_epoch = checkpoint['epoch'] 42 | best_prec1 = checkpoint['best_prec1'] 43 | model.load_state_dict(checkpoint['state_dict']) 44 | print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}" 45 | .format(args.model, checkpoint['epoch'], best_prec1)) 46 | else: 47 | print("=> no checkpoint found at '{}'".format(args.model)) 48 | 49 | print(model) 50 | total = 0 51 | for m in model.modules(): 52 | if isinstance(m, nn.BatchNorm2d): 53 | total += m.weight.data.shape[0] 54 | 55 | bn = torch.zeros(total) 56 | index = 0 57 | for m in model.modules(): 58 | if isinstance(m, nn.BatchNorm2d): 59 | size = m.weight.data.shape[0] 60 | bn[index:(index+size)] = m.weight.data.abs().clone() 61 | index += size 62 | 63 | y, i = torch.sort(bn) 64 | thre_index = int(total * args.percent) 65 | thre = y[thre_index] 66 | 67 | pruned = 0 68 | cfg = [] 69 | cfg_mask = [] 70 | for k, m in enumerate(model.modules()): 71 | if isinstance(m, nn.BatchNorm2d): 72 | weight_copy = m.weight.data.abs().clone() 73 | mask = weight_copy.gt(thre).float().cuda() 74 | pruned = pruned + mask.shape[0] - torch.sum(mask) 75 | m.weight.data.mul_(mask) 76 | m.bias.data.mul_(mask) 77 | cfg.append(int(torch.sum(mask))) 78 | cfg_mask.append(mask.clone()) 79 | print('layer index: {:d} \t total channel: {:d} \t remaining 
channel: {:d}'. 80 | format(k, mask.shape[0], int(torch.sum(mask)))) 81 | elif isinstance(m, nn.MaxPool2d): 82 | cfg.append('M') 83 | 84 | pruned_ratio = pruned/total 85 | 86 | print('Pre-processing Successful!') 87 | 88 | # simple test model after Pre-processing prune (simple set BN scales to zeros) 89 | def test(model): 90 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 91 | if args.dataset == 'cifar10': 92 | test_loader = torch.utils.data.DataLoader( 93 | datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 96 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 97 | elif args.dataset == 'cifar100': 98 | test_loader = torch.utils.data.DataLoader( 99 | datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 102 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 103 | else: 104 | raise ValueError("No valid dataset is given.") 105 | model.eval() 106 | correct = 0 107 | for data, target in test_loader: 108 | if args.cuda: 109 | data, target = data.cuda(), target.cuda() 110 | data, target = Variable(data, volatile=True), Variable(target) 111 | output = model(data) 112 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 113 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 114 | 115 | print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format( 116 | correct, len(test_loader.dataset), 100. 
* correct / len(test_loader.dataset))) 117 | return correct / float(len(test_loader.dataset)) 118 | 119 | acc = test(model) 120 | 121 | # Make real prune 122 | print(cfg) 123 | newmodel = vgg(dataset=args.dataset, cfg=cfg) 124 | if args.cuda: 125 | newmodel.cuda() 126 | 127 | num_parameters = sum([param.nelement() for param in newmodel.parameters()]) 128 | savepath = os.path.join(args.save, "prune.txt") 129 | with open(savepath, "w") as fp: 130 | fp.write("Configuration: \n"+str(cfg)+"\n") 131 | fp.write("Number of parameters: \n"+str(num_parameters)+"\n") 132 | fp.write("Test accuracy: \n"+str(acc)) 133 | 134 | layer_id_in_cfg = 0 135 | start_mask = torch.ones(3) 136 | end_mask = cfg_mask[layer_id_in_cfg] 137 | for [m0, m1] in zip(model.modules(), newmodel.modules()): 138 | if isinstance(m0, nn.BatchNorm2d): 139 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 140 | if idx1.size == 1: 141 | idx1 = np.resize(idx1,(1,)) 142 | m1.weight.data = m0.weight.data[idx1.tolist()].clone() 143 | m1.bias.data = m0.bias.data[idx1.tolist()].clone() 144 | m1.running_mean = m0.running_mean[idx1.tolist()].clone() 145 | m1.running_var = m0.running_var[idx1.tolist()].clone() 146 | layer_id_in_cfg += 1 147 | start_mask = end_mask.clone() 148 | if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC 149 | end_mask = cfg_mask[layer_id_in_cfg] 150 | elif isinstance(m0, nn.Conv2d): 151 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 152 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 153 | print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size)) 154 | if idx0.size == 1: 155 | idx0 = np.resize(idx0, (1,)) 156 | if idx1.size == 1: 157 | idx1 = np.resize(idx1, (1,)) 158 | w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() 159 | w1 = w1[idx1.tolist(), :, :, :].clone() 160 | m1.weight.data = w1.clone() 161 | elif isinstance(m0, nn.Linear): 162 | idx0 = 
np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 163 | if idx0.size == 1: 164 | idx0 = np.resize(idx0, (1,)) 165 | m1.weight.data = m0.weight.data[:, idx0].clone() 166 | m1.bias.data = m0.bias.data.clone() 167 | 168 | torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar')) 169 | 170 | print(newmodel) 171 | model = newmodel 172 | test(model) 173 | -------------------------------------------------------------------------------- /denseprune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | from torchvision import datasets, transforms 8 | from models import * 9 | 10 | 11 | # Prune settings 12 | parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune') 13 | parser.add_argument('--dataset', type=str, default='cifar100', 14 | help='training dataset (default: cifar10)') 15 | parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', 16 | help='input batch size for testing (default: 256)') 17 | parser.add_argument('--no-cuda', action='store_true', default=False, 18 | help='disables CUDA training') 19 | parser.add_argument('--depth', type=int, default=40, 20 | help='depth of the resnet') 21 | parser.add_argument('--percent', type=float, default=0.5, 22 | help='scale sparse rate (default: 0.5)') 23 | parser.add_argument('--model', default='', type=str, metavar='PATH', 24 | help='path to the model (default: none)') 25 | parser.add_argument('--save', default='', type=str, metavar='PATH', 26 | help='path to save pruned model (default: none)') 27 | 28 | args = parser.parse_args() 29 | args.cuda = not args.no_cuda and torch.cuda.is_available() 30 | 31 | if not os.path.exists(args.save): 32 | os.makedirs(args.save) 33 | 34 | model = densenet(depth=args.depth, dataset=args.dataset) 35 | 36 | if args.cuda: 37 | 
model.cuda() 38 | if args.model: 39 | if os.path.isfile(args.model): 40 | print("=> loading checkpoint '{}'".format(args.model)) 41 | checkpoint = torch.load(args.model) 42 | args.start_epoch = checkpoint['epoch'] 43 | best_prec1 = checkpoint['best_prec1'] 44 | model.load_state_dict(checkpoint['state_dict']) 45 | print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}" 46 | .format(args.model, checkpoint['epoch'], best_prec1)) 47 | else: 48 | print("=> no checkpoint found at '{}'".format(args.resume)) 49 | 50 | total = 0 51 | for m in model.modules(): 52 | if isinstance(m, nn.BatchNorm2d): 53 | total += m.weight.data.shape[0] 54 | 55 | bn = torch.zeros(total) 56 | index = 0 57 | for m in model.modules(): 58 | if isinstance(m, nn.BatchNorm2d): 59 | size = m.weight.data.shape[0] 60 | bn[index:(index+size)] = m.weight.data.abs().clone() 61 | index += size 62 | 63 | y, i = torch.sort(bn) 64 | thre_index = int(total * args.percent) 65 | thre = y[thre_index] 66 | 67 | pruned = 0 68 | cfg = [] 69 | cfg_mask = [] 70 | for k, m in enumerate(model.modules()): 71 | if isinstance(m, nn.BatchNorm2d): 72 | weight_copy = m.weight.data.abs().clone() 73 | mask = weight_copy.gt(thre).float().cuda() 74 | pruned = pruned + mask.shape[0] - torch.sum(mask) 75 | m.weight.data.mul_(mask) 76 | m.bias.data.mul_(mask) 77 | cfg.append(int(torch.sum(mask))) 78 | cfg_mask.append(mask.clone()) 79 | print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'. 
80 | format(k, mask.shape[0], int(torch.sum(mask)))) 81 | elif isinstance(m, nn.MaxPool2d): 82 | cfg.append('M') 83 | 84 | pruned_ratio = pruned/total 85 | 86 | print('Pre-processing Successful!') 87 | 88 | # simple test model after Pre-processing prune (simple set BN scales to zeros) 89 | def test(model): 90 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 91 | if args.dataset == 'cifar10': 92 | test_loader = torch.utils.data.DataLoader( 93 | datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 96 | batch_size=args.test_batch_size, shuffle=False, **kwargs) 97 | elif args.dataset == 'cifar100': 98 | test_loader = torch.utils.data.DataLoader( 99 | datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])), 102 | batch_size=args.test_batch_size, shuffle=False, **kwargs) 103 | else: 104 | raise ValueError("No valid dataset is given.") 105 | model.eval() 106 | correct = 0 107 | for data, target in test_loader: 108 | if args.cuda: 109 | data, target = data.cuda(), target.cuda() 110 | data, target = Variable(data, volatile=True), Variable(target) 111 | output = model(data) 112 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 113 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 114 | 115 | print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format( 116 | correct, len(test_loader.dataset), 100. 
* correct / len(test_loader.dataset))) 117 | return correct / float(len(test_loader.dataset)) 118 | 119 | acc = test(model) 120 | 121 | print("Cfg:") 122 | print(cfg) 123 | 124 | newmodel = densenet(depth=args.depth, dataset=args.dataset, cfg=cfg) 125 | 126 | if args.cuda: 127 | newmodel.cuda() 128 | 129 | num_parameters = sum([param.nelement() for param in newmodel.parameters()]) 130 | savepath = os.path.join(args.save, "prune.txt") 131 | with open(savepath, "w") as fp: 132 | fp.write("Configuration: \n"+str(cfg)+"\n") 133 | fp.write("Number of parameters: \n"+str(num_parameters)+"\n") 134 | fp.write("Test accuracy: \n"+str(acc)) 135 | 136 | old_modules = list(model.modules()) 137 | new_modules = list(newmodel.modules()) 138 | 139 | layer_id_in_cfg = 0 140 | start_mask = torch.ones(3) 141 | end_mask = cfg_mask[layer_id_in_cfg] 142 | first_conv = True 143 | 144 | for layer_id in range(len(old_modules)): 145 | m0 = old_modules[layer_id] 146 | m1 = new_modules[layer_id] 147 | if isinstance(m0, nn.BatchNorm2d): 148 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 149 | if idx1.size == 1: 150 | idx1 = np.resize(idx1,(1,)) 151 | 152 | if isinstance(old_modules[layer_id + 1], channel_selection): 153 | # If the next layer is the channel selection layer, then the current batch normalization layer won't be pruned. 154 | m1.weight.data = m0.weight.data.clone() 155 | m1.bias.data = m0.bias.data.clone() 156 | m1.running_mean = m0.running_mean.clone() 157 | m1.running_var = m0.running_var.clone() 158 | 159 | # We need to set the mask parameter `indexes` for the channel selection layer. 
160 | m2 = new_modules[layer_id + 1] 161 | m2.indexes.data.zero_() 162 | m2.indexes.data[idx1.tolist()] = 1.0 163 | 164 | layer_id_in_cfg += 1 165 | start_mask = end_mask.clone() 166 | if layer_id_in_cfg < len(cfg_mask): 167 | end_mask = cfg_mask[layer_id_in_cfg] 168 | continue 169 | 170 | elif isinstance(m0, nn.Conv2d): 171 | if first_conv: 172 | # We don't change the first convolution layer. 173 | m1.weight.data = m0.weight.data.clone() 174 | first_conv = False 175 | continue 176 | if isinstance(old_modules[layer_id - 1], channel_selection): 177 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 178 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 179 | print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size)) 180 | if idx0.size == 1: 181 | idx0 = np.resize(idx0, (1,)) 182 | if idx1.size == 1: 183 | idx1 = np.resize(idx1, (1,)) 184 | 185 | # If the last layer is channel selection layer, then we don't change the number of output channels of the current 186 | # convolutional layer. 
187 | w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() 188 | m1.weight.data = w1.clone() 189 | continue 190 | 191 | elif isinstance(m0, nn.Linear): 192 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 193 | if idx0.size == 1: 194 | idx0 = np.resize(idx0, (1,)) 195 | 196 | m1.weight.data = m0.weight.data[:, idx0].clone() 197 | m1.bias.data = m0.bias.data.clone() 198 | 199 | torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar')) 200 | 201 | print(newmodel) 202 | model = newmodel 203 | test(model) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import shutil 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | import models 13 | 14 | 15 | # Training settings 16 | parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR training') 17 | parser.add_argument('--dataset', type=str, default='cifar100', 18 | help='training dataset (default: cifar100)') 19 | parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true', 20 | help='train with channel sparsity regularization') 21 | parser.add_argument('--s', type=float, default=0.0001, 22 | help='scale sparse rate (default: 0.0001)') 23 | parser.add_argument('--refine', default='', type=str, metavar='PATH', 24 | help='path to the pruned model to be fine tuned') 25 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 26 | help='input batch size for training (default: 64)') 27 | parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', 28 | help='input batch size for testing (default: 256)') 29 | 
parser.add_argument('--epochs', type=int, default=160, metavar='N', 30 | help='number of epochs to train (default: 160)') 31 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 32 | help='manual epoch number (useful on restarts)') 33 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', 34 | help='learning rate (default: 0.1)') 35 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 36 | help='SGD momentum (default: 0.9)') 37 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 38 | metavar='W', help='weight decay (default: 1e-4)') 39 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 40 | help='path to latest checkpoint (default: none)') 41 | parser.add_argument('--no-cuda', action='store_true', default=False, 42 | help='disables CUDA training') 43 | parser.add_argument('--seed', type=int, default=1, metavar='S', 44 | help='random seed (default: 1)') 45 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 46 | help='how many batches to wait before logging training status') 47 | parser.add_argument('--save', default='./logs', type=str, metavar='PATH', 48 | help='path to save prune model (default: current directory)') 49 | parser.add_argument('--arch', default='vgg', type=str, 50 | help='architecture to use') 51 | parser.add_argument('--depth', default=19, type=int, 52 | help='depth of the neural network') 53 | 54 | args = parser.parse_args() 55 | args.cuda = not args.no_cuda and torch.cuda.is_available() 56 | 57 | torch.manual_seed(args.seed) 58 | if args.cuda: 59 | torch.cuda.manual_seed(args.seed) 60 | 61 | if not os.path.exists(args.save): 62 | os.makedirs(args.save) 63 | 64 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 65 | if args.dataset == 'cifar10': 66 | train_loader = torch.utils.data.DataLoader( 67 | datasets.CIFAR10('./data.cifar10', train=True, download=True, 68 | transform=transforms.Compose([ 69 | 
transforms.Pad(4), 70 | transforms.RandomCrop(32), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 74 | ])), 75 | batch_size=args.batch_size, shuffle=True, **kwargs) 76 | test_loader = torch.utils.data.DataLoader( 77 | datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 80 | ])), 81 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 82 | else: 83 | train_loader = torch.utils.data.DataLoader( 84 | datasets.CIFAR100('./data.cifar100', train=True, download=True, 85 | transform=transforms.Compose([ 86 | transforms.Pad(4), 87 | transforms.RandomCrop(32), 88 | transforms.RandomHorizontalFlip(), 89 | transforms.ToTensor(), 90 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 91 | ])), 92 | batch_size=args.batch_size, shuffle=True, **kwargs) 93 | test_loader = torch.utils.data.DataLoader( 94 | datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([ 95 | transforms.ToTensor(), 96 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 97 | ])), 98 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 99 | 100 | if args.refine: 101 | checkpoint = torch.load(args.refine) 102 | model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg']) 103 | model.load_state_dict(checkpoint['state_dict']) 104 | else: 105 | model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth) 106 | 107 | if args.cuda: 108 | model.cuda() 109 | 110 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 111 | 112 | if args.resume: 113 | if os.path.isfile(args.resume): 114 | print("=> loading checkpoint '{}'".format(args.resume)) 115 | checkpoint = torch.load(args.resume) 116 | 
args.start_epoch = checkpoint['epoch'] 117 | best_prec1 = checkpoint['best_prec1'] 118 | model.load_state_dict(checkpoint['state_dict']) 119 | optimizer.load_state_dict(checkpoint['optimizer']) 120 | print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}" 121 | .format(args.resume, checkpoint['epoch'], best_prec1)) 122 | else: 123 | print("=> no checkpoint found at '{}'".format(args.resume)) 124 | 125 | # additional subgradient descent on the sparsity-induced penalty term 126 | def updateBN(): 127 | for m in model.modules(): 128 | if isinstance(m, nn.BatchNorm2d): 129 | m.weight.grad.data.add_(args.s*torch.sign(m.weight.data)) # L1 130 | 131 | def train(epoch): 132 | model.train() 133 | for batch_idx, (data, target) in enumerate(train_loader): 134 | if args.cuda: 135 | data, target = data.cuda(), target.cuda() 136 | data, target = Variable(data), Variable(target) 137 | optimizer.zero_grad() 138 | output = model(data) 139 | loss = F.cross_entropy(output, target) 140 | pred = output.data.max(1, keepdim=True)[1] 141 | loss.backward() 142 | if args.sr: 143 | updateBN() 144 | optimizer.step() 145 | if batch_idx % args.log_interval == 0: 146 | print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format( 147 | epoch, batch_idx * len(data), len(train_loader.dataset), 148 | 100. 
* batch_idx / len(train_loader), loss.data[0])) 149 | 150 | def test(): 151 | model.eval() 152 | test_loss = 0 153 | correct = 0 154 | for data, target in test_loader: 155 | if args.cuda: 156 | data, target = data.cuda(), target.cuda() 157 | data, target = Variable(data, volatile=True), Variable(target) 158 | output = model(data) 159 | test_loss += F.cross_entropy(output, target, size_average=False).data[0] # sum up batch loss 160 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 161 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 162 | 163 | test_loss /= len(test_loader.dataset) 164 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format( 165 | test_loss, correct, len(test_loader.dataset), 166 | 100. * correct / len(test_loader.dataset))) 167 | return correct / float(len(test_loader.dataset)) 168 | 169 | def save_checkpoint(state, is_best, filepath): 170 | torch.save(state, os.path.join(filepath, 'checkpoint.pth.tar')) 171 | if is_best: 172 | shutil.copyfile(os.path.join(filepath, 'checkpoint.pth.tar'), os.path.join(filepath, 'model_best.pth.tar')) 173 | 174 | best_prec1 = 0. 
175 | for epoch in range(args.start_epoch, args.epochs): 176 | if epoch in [args.epochs*0.5, args.epochs*0.75]: 177 | for param_group in optimizer.param_groups: 178 | param_group['lr'] *= 0.1 179 | train(epoch) 180 | prec1 = test() 181 | is_best = prec1 > best_prec1 182 | best_prec1 = max(prec1, best_prec1) 183 | save_checkpoint({ 184 | 'epoch': epoch + 1, 185 | 'state_dict': model.state_dict(), 186 | 'best_prec1': best_prec1, 187 | 'optimizer': optimizer.state_dict(), 188 | }, is_best, filepath=args.save) 189 | 190 | print("Best accuracy: "+str(best_prec1)) -------------------------------------------------------------------------------- /mask-impl/main_mask.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import math 4 | import os 5 | import shutil 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torchvision import datasets, transforms 12 | from torch.autograd import Variable 13 | import models 14 | 15 | 16 | # Training settings 17 | parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR training') 18 | parser.add_argument('--dataset', type=str, default='cifar100', 19 | help='training dataset (default: cifar100)') 20 | parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true', 21 | help='train with channel sparsity regularization') 22 | parser.add_argument('--s', type=float, default=0.0001, 23 | help='scale sparse rate (default: 0.0001)') 24 | parser.add_argument('--refine', default='', type=str, metavar='PATH', 25 | help='refine from prune model') 26 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 27 | help='input batch size for training (default: 100)') 28 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 29 | help='input batch size for testing (default: 1000)') 30 | 
# ---- /mask-impl/main_mask.py (continued) ----
parser.add_argument('--epochs', type=int, default=160, metavar='N',
                    help='number of epochs to train (default: 160)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save', default='./logs', type=str, metavar='PATH',
                    help='path to save checkpoints (default: ./logs)')
parser.add_argument('--arch', default='vgg', type=str,
                    help='architecture to use')
parser.add_argument('--depth', default=19, type=int,
                    help='depth of the neural network')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if not os.path.exists(args.save):
    os.makedirs(args.save)

# The same per-channel normalization constants are used for CIFAR-10 and CIFAR-100.
_CIFAR_NORMALIZE = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))


def _build_cifar_loaders(dataset_name, batch_size, test_batch_size, loader_kwargs):
    """Return ``(train_loader, test_loader)`` for ``'cifar10'`` or ``'cifar100'``.

    The training split is augmented (pad by 4, random 32x32 crop, random horizontal
    flip); the test split is only converted to a tensor and normalized. Both loaders
    shuffle, as in the original script.

    Raises:
        ValueError: if ``dataset_name`` is neither ``'cifar10'`` nor ``'cifar100'``.
    """
    if dataset_name == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif dataset_name == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        # Fail fast instead of silently training on CIFAR-100.
        raise ValueError("No valid dataset is given.")

    train_set = dataset_cls(root, train=True, download=True,
                            transform=transforms.Compose([
                                transforms.Pad(4),
                                transforms.RandomCrop(32),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                _CIFAR_NORMALIZE,
                            ]))
    test_set = dataset_cls(root, train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        _CIFAR_NORMALIZE,
    ]))
    train_dl = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_dl = torch.utils.data.DataLoader(
        test_set, batch_size=test_batch_size, shuffle=True, **loader_kwargs)
    return train_dl, test_dl


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader, test_loader = _build_cifar_loaders(
    args.dataset, args.batch_size, args.test_batch_size, kwargs)

model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)

if args.refine:
    # Load onto the CPU so GPU-saved checkpoints also work with --no-cuda; the
    # weights are copied into the model, which is moved to the GPU below if enabled.
    checkpoint = torch.load(args.refine, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# Initialized BEFORE the resume block so that a resumed run keeps the best score
# stored in its checkpoint (resetting it afterwards made the first resumed epoch
# always overwrite model_best.pth.tar).
best_prec1 = 0.

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))


# additional subgradient descent on the sparsity-induced penalty term
def updateBN():
    """Add ``args.s * sign(gamma)`` (the L1 subgradient) to every BN scale gradient."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s * torch.sign(m.weight.data))  # L1


def BN_grad_zero():
    """Zero the gradients of BN channels whose scale factor is exactly 0.

    In this mask implementation a pruned channel is represented by a zero BN scale,
    so masking the weight and bias gradients keeps the loss gradient from reviving it.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            # The mask is created on the weights' own device, so no explicit
            # .cuda() is needed (that call crashed CPU-only / --no-cuda runs).
            mask = (m.weight.data != 0).float()
            m.weight.grad.data.mul_(mask)
            m.bias.grad.data.mul_(mask)


def train(epoch):
    """Run one training epoch over ``train_loader`` and log the loss periodically."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if args.sr:
            updateBN()
        # Always re-apply the mask so zero-scale (pruned) channels stay masked.
        BN_grad_zero()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    """Evaluate on ``test_loader``; print loss/accuracy and return accuracy as a float in [0, 1]."""
    model.eval()
    test_loss = 0.
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # predicted class = argmax of the class scores
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filepath):
    """Save ``state`` to ``filepath/checkpoint.pth.tar``; copy it to ``model_best.pth.tar`` if ``is_best``."""
    torch.save(state, os.path.join(filepath, 'checkpoint.pth.tar'))
    if is_best:
        shutil.copyfile(os.path.join(filepath, 'checkpoint.pth.tar'), os.path.join(filepath, 'model_best.pth.tar'))
# ---- /mask-impl/main_mask.py (training loop) ----
# Divide the LR by 10 at 50% and 75% of the run. The milestones are rounded up to
# whole epoch indices: comparing the int `epoch` against float products such as
# 0.75 * 10 == 7.5 would never match, silently skipping the decay.
lr_milestones = {math.ceil(args.epochs * 0.5), math.ceil(args.epochs * 0.75)}
for epoch in range(args.start_epoch, args.epochs):
    if epoch in lr_milestones:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    train(epoch)
    prec1 = test()
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, is_best, filepath=args.save)

print("Best accuracy: " + str(best_prec1))

# --------------------------------------------------------------------------------
# /resprune.py
# --------------------------------------------------------------------------------
# Prune a `resnet` from `models`:
#   1. Rank all BN scale factors globally by |gamma|.
#   2. Zero the lowest `--percent` fraction (threshold = sorted value at that rank).
#   3. Build a thinner `resnet(cfg=...)` and copy the surviving weights into it.
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from models import *


# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar100)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=164,
                    help='depth of the resnet')
parser.add_argument('--percent', type=float, default=0.5,
                    help='fraction of BN channels to prune (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# os.makedirs('') raises, so only create a directory when one was given. With the
# default '' the outputs below are written to the current directory.
if args.save and not os.path.exists(args.save):
    os.makedirs(args.save)

model = resnet(depth=args.depth, dataset=args.dataset)

if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # Load onto the CPU so GPU-saved checkpoints also work with --no-cuda.
        checkpoint = torch.load(args.model, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

total = 0

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# Gather |gamma| of every BN layer into one flat vector to pick a global threshold.
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]


pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # The mask stays on the weights' device (CPU or GPU); downstream code only
        # reads it through .cpu(), so no explicit .cuda() is needed.
        mask = weight_copy.gt(thre).float()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total

print('Pre-processing Successful!')


# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test(model):
    """Evaluate ``model`` on the CIFAR test split; print and return accuracy in [0, 1].

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cifar10'`` nor ``'cifar100'``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=False, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=False, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # predicted class = argmax of the class scores
            correct += pred.eq(target.view_as(pred)).sum().item()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


acc = test(model)

print("Cfg:")
print(cfg)

newmodel = resnet(depth=args.depth, dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()

num_parameters = sum([param.nelement() for param in newmodel.parameters()])
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)  # input image channels are never pruned
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))

        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # We need to set the channel selection layer.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0

            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            # This covers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()

            # If the current convolution is not the last convolution in the residual block, then we can change the
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue

        # We need to consider the case where there are downsampling convolutions.
        # For these convolutions, we just copy the weights.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)