├── README.md
├── dataset.py
├── finetune.py
└── prune.py

/README.md:
--------------------------------------------------------------------------------
## PyTorch implementation of [\[1611.06440 Pruning Convolutional Neural Networks for Resource Efficient Inference\]](https://arxiv.org/abs/1611.06440) ##

This demonstrates pruning a VGG16-based classifier that classifies a small dog/cat dataset.

This reduced the CPU runtime by 3x and the model size by 4x.

For more details you can read the [blog post](https://jacobgil.github.io/deeplearning/pruning-deep-learning).

At each pruning step 512 filters are removed from the network.


Usage
-----

This repository uses the PyTorch ImageFolder loader, so it assumes that the images of each category are in a separate directory:

```
Train
......... dogs
......... cats

Test
......... dogs
......... cats
```

By default the scripts read the directories `train` and `test`; use `--train_path` / `--test_path` to point elsewhere.

The images were taken from [here](https://www.kaggle.com/c/dogs-vs-cats), but you should try training this on your own data and see if it works!

Training:
`python finetune.py --train`

Pruning:
`python finetune.py --prune`

Add `--use-cuda` to either command to run on an NVIDIA GPU.

TBD
---

- Change the pruning to be done in one pass. Currently each of the 512 filters is pruned sequentially:

  ```python
  for layer_index, filter_index in prune_targets:
      model = prune_vgg16_conv_layer(model, layer_index, filter_index)
  ```

  This is inefficient, since allocating new layers (especially fully connected layers with lots of parameters) is slow.

  In principle this can be done in a single pass.

- Change prune_vgg16_conv_layer to support additional architectures.
  The most immediate one would be VGG with batch norm.
# --------------------------------------------------------------------------------
# /dataset.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
import os

# Per-channel normalization statistics (the standard ImageNet values).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _image_folder_loader(path, transform_list, shuffle, batch_size, num_workers, pin_memory):
    """Build a DataLoader over ``ImageFolder(path)`` (one sub-directory per class).

    ``transform_list`` is applied first, followed by ``ToTensor`` and normalization.
    """
    transform = transforms.Compose(
        list(transform_list)
        + [
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
        ]
    )
    return data.DataLoader(
        datasets.ImageFolder(path, transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def loader(path, batch_size=32, num_workers=4, pin_memory=True):
    """Return a shuffled training DataLoader with random crop/flip augmentation.

    Images are resized so the shorter side is 256, randomly resized-cropped to
    224x224, randomly flipped horizontally and normalized.
    """
    # transforms.Scale / transforms.RandomSizedCrop were deprecated aliases of
    # Resize / RandomResizedCrop and no longer exist in current torchvision.
    return _image_folder_loader(
        path,
        [
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ],
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def test_loader(path, batch_size=32, num_workers=4, pin_memory=True):
    """Return a deterministic (unshuffled, center-cropped 224x224) evaluation DataLoader."""
    return _image_folder_loader(
        path,
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ],
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


# --------------------------------------------------------------------------------
# /finetune.py:
# --------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dataset
from prune import *
import argparse
from operator import itemgetter
from heapq import nsmallest
import time


class ModifiedVGG16Model(torch.nn.Module):
    """Frozen, ImageNet-pretrained VGG16 conv features with a new 2-class classifier."""

    def __init__(self):
        super(ModifiedVGG16Model, self).__init__()

        model = models.vgg16(pretrained=True)
        self.features = model.features

        # Only the new classifier is trained during the initial fine-tuning;
        # prune() re-enables gradients for the features later.
        for param in self.features.parameters():
            param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(25088, 4096),  # 512 channels * 7 * 7 for a 224x224 input
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class FilterPrunner:
    """Accumulates a Taylor-expansion rank (activation * gradient) for every conv filter."""

    def __init__(self, model):
        self.model = model
        self.reset()

    def reset(self):
        # activation index (i-th Conv2d in model.features) -> per-filter rank tensor
        self.filter_ranks = {}

    def forward(self, x):
        """Forward *x* through the model, recording conv outputs and hooking their gradients."""
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        # activation index -> index of the producing module inside model.features
        self.activation_to_layer = {}

        activation_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                x.register_hook(self.compute_rank)
                self.activations.append(x)
                self.activation_to_layer[activation_index] = layer
                activation_index += 1

        return self.model.classifier(x.view(x.size(0), -1))

    def compute_rank(self, grad):
        """Gradient hook: add mean(activation * grad) per filter to ``filter_ranks``.

        Assumes hooks fire in reverse forward order, so the n-th call belongs to
        the n-th recorded activation counted from the end.
        """
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]

        taylor = activation * grad
        # Get the average value for every filter,
        # across all the other dimensions (batch and spatial).
        taylor = taylor.mean(dim=(0, 2, 3)).data

        if activation_index not in self.filter_ranks:
            # Allocate on the activation's own device instead of consulting the
            # script-level ``args`` global, so this works on CPU and GPU alike.
            self.filter_ranks[activation_index] = torch.zeros(
                activation.size(1), device=activation.device)

        self.filter_ranks[activation_index] += taylor
        self.grad_index += 1

    def lowest_ranking_filters(self, num):
        """Return the *num* (features_layer_index, filter_index, rank) tuples with the lowest rank."""
        data = []
        for i in sorted(self.filter_ranks.keys()):
            for j in range(self.filter_ranks[i].size(0)):
                data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j]))

        return nsmallest(num, data, itemgetter(2))

    def normalize_ranks_per_layer(self):
        """Replace each layer's ranks by their absolute values, L2-normalized, on the CPU."""
        for i in self.filter_ranks:
            v = torch.abs(self.filter_ranks[i])
            # torch.sqrt works on any device (np.sqrt fails on CUDA tensors).
            norm = torch.sqrt(torch.sum(v * v))
            # An all-zero layer would otherwise become NaN and break the ordering.
            if norm > 0:
                v = v / norm
            self.filter_ranks[i] = v.cpu()

    def get_prunning_plan(self, num_filters_to_prune):
        """Return [(features_layer_index, filter_index)] to be pruned one at a time, in order.

        After each filter is pruned, the indices of the later filters in the same
        layer shift because the layer is smaller; the returned indices already
        account for that.
        """
        filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune)

        filters_to_prune_per_layer = {}
        for (l, f, _) in filters_to_prune:
            filters_to_prune_per_layer.setdefault(l, []).append(f)

        plan = []
        for l, filters in filters_to_prune_per_layer.items():
            # The i-th smallest index moves down by the i filters removed before it.
            for i, f in enumerate(sorted(filters)):
                plan.append((l, f - i))

        return plan


class PrunningFineTuner_VGG16:
    """Fine-tunes and iteratively prunes a VGG16-style model.

    Methods read the module-level ``args`` (set in ``__main__``) for ``use_cuda``.
    """

    def __init__(self, train_path, test_path, model):
        self.train_data_loader = dataset.loader(train_path)
        self.test_data_loader = dataset.test_loader(test_path)

        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()
        self.prunner = FilterPrunner(self.model)
        self.model.train()

    def test(self):
        """Print the accuracy of ``self.model`` on the test set, then restore train mode."""
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch, label in self.test_data_loader:
                if args.use_cuda:
                    batch = batch.cuda()
                output = self.model(batch)
                pred = output.data.max(1)[1]
                correct += pred.cpu().eq(label).sum().item()
                total += label.size(0)

        print("Accuracy :", float(correct) / total)

        self.model.train()

    def train(self, optimizer=None, epoches=10):
        """Train for *epoches* epochs (default optimizer: SGD over the classifier only)."""
        if optimizer is None:
            optimizer = optim.SGD(self.model.classifier.parameters(), lr=0.0001, momentum=0.9)

        for i in range(epoches):
            print("Epoch: ", i)
            self.train_epoch(optimizer)
            self.test()
        print("Finished fine tuning.")

    def train_batch(self, optimizer, batch, label, rank_filters):
        """One forward/backward pass; step *optimizer* only when not ranking filters."""
        if args.use_cuda:
            batch = batch.cuda()
            label = label.cuda()

        self.model.zero_grad()

        if rank_filters:
            # Ranking pass: the backward only feeds FilterPrunner's hooks.
            output = self.prunner.forward(batch)
            self.criterion(output, label).backward()
        else:
            self.criterion(self.model(batch), label).backward()
            optimizer.step()

    def train_epoch(self, optimizer=None, rank_filters=False):
        for batch, label in self.train_data_loader:
            self.train_batch(optimizer, batch, label, rank_filters)

    def get_candidates_to_prune(self, num_filters_to_prune):
        """Rank all filters over one training epoch and return the pruning plan."""
        self.prunner.reset()
        self.train_epoch(rank_filters=True)
        self.prunner.normalize_ranks_per_layer()
        return self.prunner.get_prunning_plan(num_filters_to_prune)

    def total_num_filters(self):
        """Total number of output channels over all Conv2d layers in ``model.features``."""
        filters = 0
        for name, module in self.model.features._modules.items():
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                filters = filters + module.out_channels
        return filters

    def prune(self):
        """Iteratively prune about two thirds of the conv filters, fine-tuning in between.

        Each iteration ranks the filters over one epoch, removes the 512
        lowest-ranked ones and fine-tunes every layer for 10 epochs. The final
        model is fine-tuned for 15 more epochs and its state dict is saved to
        ``model_prunned``.
        """
        # Get the accuracy before pruning
        self.test()
        self.model.train()

        # Make sure all the layers are trainable
        for param in self.model.features.parameters():
            param.requires_grad = True

        number_of_filters = self.total_num_filters()
        num_filters_to_prune_per_iteration = 512
        iterations = int(float(number_of_filters) / num_filters_to_prune_per_iteration)

        iterations = int(iterations * 2.0 / 3)

        print("Number of prunning iterations to reduce 67% filters", iterations)

        optimizer = None
        for _ in range(iterations):
            print("Ranking filters.. ")
            prune_targets = self.get_candidates_to_prune(num_filters_to_prune_per_iteration)
            layers_prunned = {}
            for layer_index, _filter_index in prune_targets:
                layers_prunned[layer_index] = layers_prunned.get(layer_index, 0) + 1

            print("Layers that will be prunned", layers_prunned)
            print("Prunning filters.. ")
            model = self.model.cpu()
            for layer_index, filter_index in prune_targets:
                model = prune_vgg16_conv_layer(model, layer_index, filter_index, use_cuda=args.use_cuda)

            self.model = model
            if args.use_cuda:
                self.model = self.model.cuda()
            # Keep the ranker pointed at the current (pruned) model.
            self.prunner.model = self.model

            # Percentage of the original filters that are still present.
            message = str(100 * float(self.total_num_filters()) / number_of_filters) + "%"
            print("Filters remaining", message)
            self.test()
            print("Fine tuning to recover from prunning iteration.")
            optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
            self.train(optimizer, epoches=10)

        print("Finished. Going to fine tune the model a bit more")
        if optimizer is None:
            # No pruning iteration ran, so the loop never created an optimizer.
            optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        self.train(optimizer, epoches=15)
        torch.save(self.model.state_dict(), "model_prunned")


def get_args():
    """Parse the command line; ``use_cuda`` is only true when CUDA is actually available."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", dest="train", action="store_true")
    parser.add_argument("--prune", dest="prune", action="store_true")
    parser.add_argument("--train_path", type=str, default="train")
    parser.add_argument("--test_path", type=str, default="test")
    parser.add_argument('--use-cuda', action='store_true', default=False, help='Use NVIDIA GPU acceleration')
    parser.set_defaults(train=False)
    parser.set_defaults(prune=False)
    args = parser.parse_args()
    args.use_cuda = args.use_cuda and torch.cuda.is_available()

    return args


if __name__ == '__main__':
    args = get_args()

    if args.train:
        model = ModifiedVGG16Model()
    elif args.prune:
        # NOTE: torch.load unpickles the whole module object; only load trusted files.
        model = torch.load("model", map_location=lambda storage, loc: storage)
    else:
        sys.exit("Specify either --train or --prune")

    if args.use_cuda:
        model = model.cuda()

    fine_tuner = PrunningFineTuner_VGG16(args.train_path, args.test_path, model)

    if args.train:
        fine_tuner.train(epoches=10)
        torch.save(model, "model")

    elif args.prune:
        fine_tuner.prune()


# --------------------------------------------------------------------------------
# /prune.py:
# --------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
import time


def replace_layers(model, i, indexes, layers):
    """Return ``layers[k]`` if ``i == indexes[k]``, otherwise the original ``model[i]``."""
    if i in indexes:
        return layers[indexes.index(i)]
    return model[i]


def _without_index(tensor, index, axis):
    """Return a new CPU tensor equal to *tensor* with *index* (int or slice) removed along *axis*."""
    return torch.from_numpy(np.delete(tensor.data.cpu().numpy(), index, axis=axis))


def _to_device(tensor, use_cuda):
    """Move *tensor* to the GPU when *use_cuda* is set; otherwise return it unchanged."""
    return tensor.cuda() if use_cuda else tensor


def prune_vgg16_conv_layer(model, layer_index, filter_index, use_cuda=False):
    """Remove output filter *filter_index* from the Conv2d at ``model.features[layer_index]``.

    The next Conv2d in ``model.features`` loses the matching input channel. If
    the pruned conv is the last one, the first ``nn.Linear`` of
    ``model.classifier`` instead loses the block of inputs fed by that channel.

    Args:
        model: module with ``features`` and ``classifier`` ``nn.Sequential`` children.
        layer_index: index of the Conv2d inside ``model.features``.
        filter_index: index of the output filter to remove.
        use_cuda: move the newly created weights to the GPU.

    Returns:
        ``model`` itself, with ``features`` (and possibly ``classifier``) replaced.

    Raises:
        ValueError: the last conv is pruned but the classifier has no ``nn.Linear``.
    """
    feature_modules = list(model.features._modules.items())
    _, conv = feature_modules[layer_index]

    # Locate the next Conv2d; ``offset`` ends as its distance from layer_index.
    next_conv = None
    offset = 1
    while layer_index + offset < len(feature_modules):
        _, candidate = feature_modules[layer_index + offset]
        if isinstance(candidate, torch.nn.modules.conv.Conv2d):
            next_conv = candidate
            break
        offset = offset + 1

    # When pruning the last conv, find the classifier's first Linear up front so
    # a missing layer is reported before the model has been modified.
    linear_index = 0
    old_linear_layer = None
    if next_conv is None:
        for _, module in model.classifier._modules.items():
            if isinstance(module, torch.nn.Linear):
                old_linear_layer = module
                break
            linear_index = linear_index + 1
        if old_linear_layer is None:
            raise ValueError("No linear layer found in classifier")

    new_conv = torch.nn.Conv2d(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels - 1,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=(conv.bias is not None))
    new_conv.weight.data = _to_device(_without_index(conv.weight, filter_index, axis=0), use_cuda)
    if conv.bias is not None:
        new_conv.bias.data = _to_device(_without_index(conv.bias, filter_index, axis=0), use_cuda)

    if next_conv is not None:
        next_new_conv = torch.nn.Conv2d(
            in_channels=next_conv.in_channels - 1,
            out_channels=next_conv.out_channels,
            kernel_size=next_conv.kernel_size,
            stride=next_conv.stride,
            padding=next_conv.padding,
            dilation=next_conv.dilation,
            groups=next_conv.groups,
            bias=(next_conv.bias is not None))
        # Drop the input channel that consumed the removed filter.
        next_new_conv.weight.data = _to_device(
            _without_index(next_conv.weight, filter_index, axis=1), use_cuda)
        if next_conv.bias is not None:
            next_new_conv.bias.data = next_conv.bias.data

        features = torch.nn.Sequential(
            *(replace_layers(model.features, i, [layer_index, layer_index + offset],
                             [new_conv, next_new_conv]) for i, _ in enumerate(model.features)))
        del model.features
        model.features = features

    else:
        # Pruning the last conv layer. This affects the first linear layer of the classifier.
        model.features = torch.nn.Sequential(
            *(replace_layers(model.features, i, [layer_index], [new_conv])
              for i, _ in enumerate(model.features)))

        # The features are flattened channel-major (x.view(batch, -1)), so each
        # conv channel feeds a contiguous block of this many linear inputs.
        params_per_input_channel = old_linear_layer.in_features // conv.out_channels

        new_linear_layer = torch.nn.Linear(
            old_linear_layer.in_features - params_per_input_channel,
            old_linear_layer.out_features,
            bias=(old_linear_layer.bias is not None))

        start = filter_index * params_per_input_channel
        new_linear_layer.weight.data = _to_device(
            _without_index(old_linear_layer.weight,
                           slice(start, start + params_per_input_channel), axis=1),
            use_cuda)
        if old_linear_layer.bias is not None:
            new_linear_layer.bias.data = old_linear_layer.bias.data

        classifier = torch.nn.Sequential(
            *(replace_layers(model.classifier, i, [linear_index], [new_linear_layer])
              for i, _ in enumerate(model.classifier)))

        del model.classifier
        model.classifier = classifier

    return model


if __name__ == '__main__':
    model = models.vgg16(pretrained=True)
    model.train()

    t0 = time.time()
    model = prune_vgg16_conv_layer(model, 28, 10)
    print("The prunning took", time.time() - t0)