├── README.md
├── filter_pruning.py
├── models.py
├── models
│   ├── convnet_pretrained.pkl
│   └── mlp_pretrained.pkl
├── pruning
│   ├── __init__.py
│   ├── layers.py
│   ├── methods.py
│   └── utils.py
└── weight_pruning.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neural Network Pruning PyTorch Implementation

Luyu Wang & Gavin Ding

Borealis AI

## Motivation
Neural network pruning has become a trendy research topic, but we haven't found an easy-to-use PyTorch implementation. We want to take advantage of the power of PyTorch and build pruned networks to study their properties.

**Note**: this implementation does not aim at computational efficiency; it is meant to make studying the properties of pruned networks convenient. Discussions on how to build an efficient implementation are welcome. Thanks!

## High-level idea
1. We write [wrappers](https://github.com/wanglouis49/pytorch-weights_pruning/blob/master/pruning/layers.py) around the PyTorch Linear and Conv2d layers.
2. For each layer, once a binary mask tensor has been computed, it is multiplied with the actual weight tensor on the forward pass.
3. Multiplying by the mask is a differentiable operation, so the backward pass is handled by automatic differentiation (no explicit coding needed).

## Pruning methods

### Weight pruning
Han et al. propose to compress deep learning models by pruning weights [Han et al, NIPS 2015](http://papers.nips.cc/paper/5784-learning-both-weights-and-connections-for-efficient-neural-network). This repo is a PyTorch implementation of that idea, except that the pruning criterion is replaced by the "class-blind" method of [See et al, CoNLL 2016](https://arxiv.org/abs/1606.09274), which is much easier to implement and performs better as well.

### Filter pruning
Pruning whole convolution filters has the advantage of being more hardware-friendly. We also implement the "minimum weight" approach of [Molchanov et al, ICLR 2017](https://arxiv.org/abs/1611.06440).
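## Usage
A minimal sketch of the weight-pruning workflow (this mirrors `weight_pruning.py`; see that script and `filter_pruning.py` for the complete versions, including data loading and retraining):

```python
import torch
from pruning.methods import weight_prune
from pruning.utils import test, prune_rate
from models import MLP

net = MLP()
net.load_state_dict(torch.load('models/mlp_pretrained.pkl'))

masks = weight_prune(net, 90.)  # prune 90% of the weights globally
net.set_masks(masks)

test(net, loader_test)          # loader_test: a standard MNIST DataLoader
prune_rate(net)                 # report per-layer and overall prune rates
```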
--------------------------------------------------------------------------------
/filter_pruning.py:
--------------------------------------------------------------------------------
"""
Pruning a ConvNet by filters iteratively
"""

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from pruning.methods import filter_prune
from pruning.utils import to_var, train, test, prune_rate
from models import ConvNet


# Hyper Parameters
param = {
    'pruning_perc': 50.,
    'batch_size': 128,
    'test_batch_size': 100,
    'num_epochs': 5,
    'learning_rate': 0.001,
    'weight_decay': 5e-4,
}


# Data loaders
train_dataset = datasets.MNIST(root='../data/', train=True, download=True,
    transform=transforms.ToTensor())
loader_train = torch.utils.data.DataLoader(train_dataset,
    batch_size=param['batch_size'], shuffle=True)

test_dataset = datasets.MNIST(root='../data/', train=False, download=True,
    transform=transforms.ToTensor())
loader_test = torch.utils.data.DataLoader(test_dataset,
    batch_size=param['test_batch_size'], shuffle=True)


# Load the pretrained model
net = ConvNet()
net.load_state_dict(torch.load('models/convnet_pretrained.pkl'))
if torch.cuda.is_available():
    print('CUDA enabled.')
    net.cuda()
print("--- Pretrained network loaded ---")
test(net, loader_test)

# Prune the filters
masks = filter_prune(net, param['pruning_perc'])
net.set_masks(masks)
print("--- {}% parameters pruned ---".format(param['pruning_perc']))
test(net, loader_test)


# Retraining
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'],
    weight_decay=param['weight_decay'])

train(net, criterion, optimizer, param, loader_train)


# Check accuracy and nonzero weights in each layer
print("--- After retraining ---")
test(net, loader_test)
prune_rate(net)


# Save the pruned model's state dict
torch.save(net.state_dict(), 'models/convnet_pruned.pkl')
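Note (not part of the original script): `torch.save(net.state_dict(), ...)` stores the already-zeroed weights but not the masks, so a reloaded model will not stay pruned under further training unless the masks are reapplied. A hedged reload sketch, assuming the `masks` list from above is still available:

```python
# Hypothetical reload sketch: reapply the masks after loading the state dict,
# otherwise the pruned weights can regrow during any further training.
net = ConvNet()
net.load_state_dict(torch.load('models/convnet_pruned.pkl'))
net.set_masks(masks)  # `masks` as returned by filter_prune above
```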
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from pruning.layers import MaskedLinear, MaskedConv2d


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear1 = MaskedLinear(28*28, 200)
        self.relu1 = nn.ReLU(inplace=True)
        self.linear2 = MaskedLinear(200, 200)
        self.relu2 = nn.ReLU(inplace=True)
        self.linear3 = MaskedLinear(200, 10)

    def forward(self, x):
        out = x.view(x.size(0), -1)
        out = self.relu1(self.linear1(out))
        out = self.relu2(self.linear2(out))
        out = self.linear3(out)
        return out

    def set_masks(self, masks):
        # TODO: find a less manual way to set masks
        # weight_prune returns torch tensors, so no conversion is needed
        self.linear1.set_mask(masks[0])
        self.linear2.set_mask(masks[1])
        self.linear3.set_mask(masks[2])


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()

        self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = MaskedConv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.relu3 = nn.ReLU(inplace=True)

        self.linear1 = nn.Linear(7*7*64, 10)

    def forward(self, x):
        out = self.maxpool1(self.relu1(self.conv1(x)))
        out = self.maxpool2(self.relu2(self.conv2(out)))
        out = self.relu3(self.conv3(out))
        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        return out

    def set_masks(self, masks):
        # TODO: find a less manual way to set masks
        # filter_prune returns numpy arrays, so convert them to tensors first
        self.conv1.set_mask(torch.from_numpy(masks[0]))
        self.conv2.set_mask(torch.from_numpy(masks[1]))
        self.conv3.set_mask(torch.from_numpy(masks[2]))
--------------------------------------------------------------------------------
/models/convnet_pretrained.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wanglouis49/pytorch-weights_pruning/487cd67a93de5e282b0e400f591b2c2c7d8e8491/models/convnet_pretrained.pkl
--------------------------------------------------------------------------------
/models/mlp_pretrained.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wanglouis49/pytorch-weights_pruning/487cd67a93de5e282b0e400f591b2c2c7d8e8491/models/mlp_pretrained.pkl
--------------------------------------------------------------------------------
/pruning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wanglouis49/pytorch-weights_pruning/487cd67a93de5e282b0e400f591b2c2c7d8e8491/pruning/__init__.py
--------------------------------------------------------------------------------
/pruning/layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from pruning.utils import to_var


class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__(in_features, out_features, bias)
        self.mask_flag = False

    def set_mask(self, mask):
        self.mask = to_var(mask, requires_grad=False)
        # zero out the pruned weights right away
        self.weight.data = self.weight.data * self.mask.data
        self.mask_flag = True

    def get_mask(self):
        print(self.mask_flag)
        return self.mask

    def forward(self, x):
        if self.mask_flag:
            weight = self.weight * self.mask
            return F.linear(x, weight, self.bias)
        else:
            return F.linear(x, self.weight, self.bias)


class MaskedConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels,
            kernel_size, stride, padding, dilation, groups, bias)
        self.mask_flag = False

    def set_mask(self, mask):
        self.mask = to_var(mask, requires_grad=False)
        # zero out the pruned weights right away
        self.weight.data = self.weight.data * self.mask.data
        self.mask_flag = True

    def get_mask(self):
        print(self.mask_flag)
        return self.mask

    def forward(self, x):
        if self.mask_flag:
            weight = self.weight * self.mask
            return F.conv2d(x, weight, self.bias, self.stride,
                self.padding, self.dilation, self.groups)
        else:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                self.padding, self.dilation, self.groups)
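A small standalone check (not part of the repo) illustrating point 3 of the README: because the mask multiplication sits inside the autograd graph, pruned weights automatically receive zero gradient. Run on CPU, assuming a PyTorch version compatible with the repo's `Variable`-based code:

```python
import torch
from pruning.layers import MaskedLinear
from pruning.utils import to_var

layer = MaskedLinear(4, 3)
mask = torch.ones(3, 4)
mask[0, :] = 0.                 # prune all weights of the first output unit
layer.set_mask(mask)

out = layer(to_var(torch.ones(2, 4)))
out.sum().backward()

print(layer.weight.grad[0])     # all zeros: pruned weights get no gradient
print(layer.weight.grad[1])     # nonzero: surviving weights still learn
```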
--------------------------------------------------------------------------------
/pruning/methods.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pruning.utils import prune_rate, arg_nonzero_min


def weight_prune(model, pruning_perc):
    '''
    Prune pruning_perc% of the weights globally (not layer-wise)
    arXiv: 1606.09274
    '''
    all_weights = []
    for p in model.parameters():
        if len(p.data.size()) != 1:  # skip bias terms
            all_weights += list(p.cpu().data.abs().numpy().flatten())
    threshold = np.percentile(np.array(all_weights), pruning_perc)

    # generate the masks: keep only weights above the global threshold
    masks = []
    for p in model.parameters():
        if len(p.data.size()) != 1:
            pruned_inds = p.data.abs() > threshold
            masks.append(pruned_inds.float())
    return masks


def prune_one_filter(model, masks):
    '''
    Prune the one least "important" feature map, measured by the scaled
    l2-norm of its kernel weights
    arXiv: 1611.06440
    '''
    NO_MASKS = False
    # construct masks if there are none yet
    if not masks:
        masks = []
        NO_MASKS = True

    values = []
    for p in model.parameters():

        if len(p.data.size()) == 4:  # crude way of selecting conv layers
            p_np = p.data.cpu().numpy()

            # construct the mask for this layer if there is none
            if NO_MASKS:
                masks.append(np.ones(p_np.shape).astype('float32'))

            # find the scaled l2-norm of each filter in this layer
            value_this_layer = np.square(p_np).sum(axis=1).sum(axis=1)\
                .sum(axis=1)/(p_np.shape[1]*p_np.shape[2]*p_np.shape[3])
            # normalization across filters (important)
            value_this_layer = value_this_layer / \
                np.sqrt(np.square(value_this_layer).sum())
            min_value, min_ind = arg_nonzero_min(list(value_this_layer))
            values.append([min_value, min_ind])

    assert len(masks) == len(values), "something wrong here"

    values = np.array(values)

    # zero the mask slice corresponding to the filter to prune
    # (indexing the first dimension zeroes the whole [in, k, k] slice)
    to_prune_layer_ind = np.argmin(values[:, 0])
    to_prune_filter_ind = int(values[to_prune_layer_ind, 1])
    masks[to_prune_layer_ind][to_prune_filter_ind] = 0.

    print('Prune filter #{} in layer #{}'.format(
        to_prune_filter_ind,
        to_prune_layer_ind))

    return masks


def filter_prune(model, pruning_perc):
    '''
    Prune filters one by one until pruning_perc is reached
    (no retraining between pruning steps)
    '''
    masks = []
    current_pruning_perc = 0.

    while current_pruning_perc < pruning_perc:
        masks = prune_one_filter(model, masks)
        model.set_masks(masks)
        current_pruning_perc = prune_rate(model, verbose=False)
        print('{:.2f}% pruned'.format(current_pruning_perc))

    return masks
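A toy illustration (not in the repo) of the global magnitude threshold that `weight_prune` computes: weights from all layers are pooled, and any weight whose absolute value falls at or below the `pruning_perc`-th percentile is masked out.

```python
import numpy as np

weights = np.array([0.01, -0.5, 0.2, -0.05, 0.8, 0.003])
pruning_perc = 50.

# global threshold over absolute values: here 0.125
threshold = np.percentile(np.abs(weights), pruning_perc)
mask = (np.abs(weights) > threshold).astype(np.float32)

print(mask)  # [0. 1. 1. 0. 1. 0.]: the 50% smallest-magnitude weights are cut
```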
--------------------------------------------------------------------------------
/pruning/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import sampler


def to_var(x, requires_grad=False, volatile=False):
    """
    Wrap a tensor in a Variable that is automatically placed on CPU or CUDA
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad, volatile=volatile)


def train(model, loss_fn, optimizer, param, loader_train, loader_val=None):

    model.train()
    for epoch in range(param['num_epochs']):
        print('Starting epoch %d / %d' % (epoch + 1, param['num_epochs']))

        for t, (x, y) in enumerate(loader_train):
            x_var, y_var = to_var(x), to_var(y.long())

            scores = model(x_var)
            loss = loss_fn(scores, y_var)

            if (t + 1) % 100 == 0:
                print('t = %d, loss = %.8f' % (t + 1, loss.data[0]))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def test(model, loader):

    model.eval()

    num_correct, num_samples = 0, len(loader.dataset)
    for x, y in loader:
        x_var = to_var(x, volatile=True)
        scores = model(x_var)
        _, preds = scores.data.cpu().max(1)
        num_correct += (preds == y).sum()

    acc = float(num_correct) / num_samples

    print('Test accuracy: {:.2f}% ({}/{})'.format(
        100.*acc,
        num_correct,
        num_samples,
    ))

    return acc


def prune_rate(model, verbose=True):
    """
    Print out the prune rate for each layer and the whole network
    """
    total_nb_param = 0
    nb_zero_param = 0

    layer_id = 0

    for parameter in model.parameters():

        param_this_layer = 1
        for dim in parameter.data.size():
            param_this_layer *= dim
        total_nb_param += param_this_layer

        # only pruning linear and conv layers (skip bias terms)
        if len(parameter.data.size()) != 1:
            layer_id += 1
            zero_param_this_layer = \
                np.count_nonzero(parameter.cpu().data.numpy() == 0)
            nb_zero_param += zero_param_this_layer

            if verbose:
                print("Layer {} | {} layer | {:.2f}% parameters pruned" \
                    .format(
                        layer_id,
                        'Conv' if len(parameter.data.size()) == 4 \
                            else 'Linear',
                        100.*zero_param_this_layer/param_this_layer,
                        ))
    pruning_perc = 100.*nb_zero_param/total_nb_param
    if verbose:
        print("Final pruning rate: {:.2f}%".format(pruning_perc))
    return pruning_perc


def arg_nonzero_min(a):
    """
    nonzero argmin of a non-negative list
    """

    if not a:
        return

    min_ix, min_v = None, None
    # find a starting value (must be nonzero)
    for i, e in enumerate(a):
        if e != 0:
            min_ix = i
            min_v = e
            break
    if min_ix is None:  # 'not min_ix' would wrongly trigger for index 0
        print('Warning: all zero')
        return np.inf, np.inf

    # search for the smallest nonzero value
    for i, e in enumerate(a):
        if e < min_v and e != 0:
            min_v = e
            min_ix = i

    return min_v, min_ix
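A quick example (not in the repo) of what `arg_nonzero_min` returns. `prune_one_filter` relies on the nonzero-min behavior: already-pruned filters have zero norm, so they are never selected again.

```python
from pruning.utils import arg_nonzero_min

print(arg_nonzero_min([0.0, 3.0, 0.5, 2.0]))  # (0.5, 2): smallest nonzero value, its index
print(arg_nonzero_min([0.0, 0.0]))            # (inf, inf), with a warning
```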
--------------------------------------------------------------------------------
/weight_pruning.py:
--------------------------------------------------------------------------------
"""
Pruning an MLP by weights in one shot
"""

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from pruning.methods import weight_prune
from pruning.utils import to_var, train, test, prune_rate
from models import MLP


# Hyper Parameters
param = {
    'pruning_perc': 90.,
    'batch_size': 128,
    'test_batch_size': 100,
    'num_epochs': 5,
    'learning_rate': 0.001,
    'weight_decay': 5e-4,
}


# Data loaders
train_dataset = datasets.MNIST(root='../data/', train=True, download=True,
    transform=transforms.ToTensor())
loader_train = torch.utils.data.DataLoader(train_dataset,
    batch_size=param['batch_size'], shuffle=True)

test_dataset = datasets.MNIST(root='../data/', train=False, download=True,
    transform=transforms.ToTensor())
loader_test = torch.utils.data.DataLoader(test_dataset,
    batch_size=param['test_batch_size'], shuffle=True)


# Load the pretrained model
net = MLP()
net.load_state_dict(torch.load('models/mlp_pretrained.pkl'))
if torch.cuda.is_available():
    print('CUDA enabled.')
    net.cuda()
print("--- Pretrained network loaded ---")
test(net, loader_test)

# Prune the weights
masks = weight_prune(net, param['pruning_perc'])
net.set_masks(masks)
print("--- {}% parameters pruned ---".format(param['pruning_perc']))
test(net, loader_test)


# Retraining
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'],
    weight_decay=param['weight_decay'])

train(net, criterion, optimizer, param, loader_train)


# Check accuracy and nonzero weights in each layer
print("--- After retraining ---")
test(net, loader_test)
prune_rate(net)


# Save the pruned model's state dict
torch.save(net.state_dict(), 'models/mlp_pruned.pkl')
--------------------------------------------------------------------------------