├── tests
│   ├── __init__.py
│   ├── context.py
│   ├── test_scorers.py
│   └── test_utils.py
├── .gitignore
├── setup.py
├── License
├── pytorchpruner
│   ├── __init__.py
│   ├── unitscorers.py
│   ├── modules.py
│   ├── scorers.py
│   ├── pruners.py
│   └── utils.py
├── README.md
└── notebooks
    ├── demo_scorers.ipynb
    ├── demo_hessian_toynetwork_mnist.ipynb
    ├── demo_pruner_toyMNISTcnn.ipynb
    └── demo_utils.ipynb

/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | notebooks/data
3 | notebooks/.*
4 | .cache*
5 | build/
6 | *.mod
--------------------------------------------------------------------------------
/tests/context.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4 | 
5 | import pytorchpruner
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | from distutils.core import setup
4 | 
5 | setup(name='pytorchpruner',
6 |       version='0.1',
7 |       description='Pruner extension with scoring functions and second-order helpers',
8 |       author='Utku Evci',
9 |       author_email='utkuevci@nyu.edu',
10 |       url='https://evcu.github.io/',
11 |       packages=['pytorchpruner'],
12 |       )
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017-2018 evcu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/pytorchpruner/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 | 
4 | Copyright (c) 2017-2018 evcu
5 | 
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 | 
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorchpruner
2 | 
3 | Run `python setup.py install` to install.
4 | Check `notebooks/` for examples.
5 | 
6 | `pytorchpruner` is a [pytorch](https://pytorch.org/) package for pruning neural networks. It is intended for research; its main goal is not to be the fastest pruning framework, but it is still reasonably efficient. It uses masks to simulate pruning and supports the two main pruning strategies (parameter-wise and unit-wise). It also implements various second-order functions such as the Hessian and Hessian-vector products.
7 | 
8 | There are five main parts of the library:
9 | 
10 | 1. **Parameter Pruning (pytorchpruner.scorers)**: Saliency measures that return a tensor of scores with the same shape as the provided parameter tensor.
11 | 2. **Unit Pruning (pytorchpruner.unitscorers)**: Saliency measures that return a vector of scores, one for each unit of the provided parameter tensor.
12 | 3. **Pruners (pytorchpruner.pruners)**: Provides two pruner engines, one for each pruning strategy (parameter vs. unit). The `remove_empty_filters` function in this file shrinks the network by copying the parameters into smaller tensors when possible.
13 | 4. **Auxiliary Modules (pytorchpruner.modules)**: Implements `meanOutputReplacer` and `MaskedModule`, two important wrappers for `torch.nn.Module` instances. The first replaces a unit's output with its mean value when enabled, and the second simulates pruning via masks.
14 | 5. **Various first/second-order functionality (pytorchpruner.utils)**: Implements Hessian computation, Hessian-vector products, search functionality, and other utility functions.
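A minimal end-to-end sketch, distilled from `notebooks/demo_pruner_toyMNISTcnn.ipynb`: wrap a model in `MaskedModule`, hand it to `BasePruner`, prune, and keep the pruned weights at zero during training. The one-layer model and batch shapes below are illustrative placeholders only; any network built from `nn.Conv2d`/`nn.Linear` layers is handled the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from pytorchpruner.modules import MaskedModule
from pytorchpruner.pruners import BasePruner

# Wrap any model; a single Linear layer keeps the sketch short (assumed, not from the demo).
model = MaskedModule(nn.Linear(20, 10))
pruner = BasePruner(model)

x = Variable(torch.rand(32, 20))      # dummy inputs
y = Variable(torch.ones(32).long())   # dummy targets

pruner.prune(0.5)                     # mask the lowest-scoring half of each layer's parameters

out = F.log_softmax(model(x), dim=1)
loss = F.nll_loss(out, y)
loss.backward()
model.apply_mask_on_gradients()       # zero the gradients of pruned weights before the optimizer step
print(model.calculateSparsity())      # fraction of parameters currently masked
```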
15 | -------------------------------------------------------------------------------- /tests/test_scorers.py: -------------------------------------------------------------------------------- 1 | from .context import pytorchpruner 2 | import pytest 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.nn import Parameter 6 | 7 | TOLERANCE = 1e-5 8 | from pytorchpruner.scorers import magnitudeScorer,gradientScorer,hessianScorer 9 | 10 | def L2d(w): 11 | ''' 12 | A custom loss function 13 | w00 w01 14 | w10 w11 15 | ''' 16 | return (w[1,0]**2)*w[1,1]+4*(w[0,0]**3)*w[1,0] 17 | def GScore(w): 18 | #Jacobian of the L(w) 19 | wd = w.data 20 | return torch.Tensor([ 21 | [12*(wd[0,0]**2)*wd[1,0], 22 | 0], 23 | [2*wd[1,0]*wd[1,1]+4*(wd[0,0]**3), 24 | wd[1,0]**2] 25 | ]) 26 | 27 | 28 | def HScore(w): 29 | #Hessian of the L2(w) 30 | wd=w.data 31 | gw12 = 2*wd[1,0] 32 | gw13 = 12*(wd[0,0]**2) 33 | gw23 = 0 34 | gw11 = 2*wd[1,1] 35 | gw22 = 0 36 | gw33 = 24*wd[0,0]*wd[1,0] 37 | ## x3 0 38 | ## x1 x2 39 | return torch.Tensor([[gw33+0+gw13+gw23, #0,0,: #x3 with others 40 | 0+0+0+0], #0,1,: 41 | [gw13+0+gw11+gw12, #1,0,: #x1 with others 42 | gw23+0+gw12+gw22]]) #1,1,: #x2 with others 43 | 44 | def test_gradientScorer(): 45 | w = Parameter(torch.rand(2,2)) 46 | 47 | loss_val = L2d(w) 48 | grad_score = gradientScorer(w,loss_val) 49 | #emprical score 50 | correct_score = GScore(w).abs() 51 | assert (grad_score-correct_score).abs().sum() < TOLERANCE 52 | 53 | def test_hessianScorer(): 54 | w = Parameter(torch.rand(2,2)) 55 | loss_val = L2d(w) 56 | 57 | hessian_score = hessianScorer(w,loss_val) 58 | #emprical score 59 | correct_score = HScore(w).abs() 60 | assert (hessian_score-correct_score).abs().sum() < TOLERANCE 61 | -------------------------------------------------------------------------------- /pytorchpruner/unitscorers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Each scorer should implement score function, which given parameter tensor returns the 3 | scores in the same shape. High scores mean high salincies. 4 | ''' 5 | 6 | from .modules import meanOutputReplacer 7 | from torch import nn 8 | import torch 9 | #loss is not needer, but adding it here to have a generic set of parameters 10 | def normScorer(layer,p=1,**kwargs): 11 | if isinstance(layer,meanOutputReplacer): 12 | layer = layer.module 13 | if isinstance(layer,(nn.Conv2d,nn.Linear)): 14 | normed_param = layer.weight.data.norm(p=p,dim=1) 15 | while normed_param.dim()>1: 16 | normed_param = normed_param.norm(p=p,dim=1) 17 | return normed_param 18 | else: 19 | raise ValueError("Invalid type, received: %s. should be a nn.Conv2d or nn.Linear" % type(layer)) 20 | 21 | def normScorerL1(layer,**kwargs): 22 | return normScorer(layer,p=1,**kwargs) 23 | 24 | def normScorerL2(layer,**kwargs): 25 | return normScorer(layer,p=2,**kwargs) 26 | 27 | def randomScorer(layer,p=1,**kwargs): 28 | old_state = torch.get_rng_state() 29 | if isinstance(layer,meanOutputReplacer): 30 | layer = layer.module 31 | if isinstance(layer,(nn.Conv2d,nn.Linear)): 32 | res = torch.rand(layer.weight.data.size(0)) 33 | torch.set_rng_state(old_state) 34 | return res 35 | else: 36 | raise ValueError("Invalid type, received: %s. should be a nn.Conv2d or nn.Linear" % type(layer)) 37 | 38 | def mrsScorer(layer): 39 | if isinstance(layer,(meanOutputReplacer)): 40 | # print(f'is_mr:{layer.is_mean_replace},enabled:{layer.enabled}') 41 | return layer.mrss 42 | else: 43 | raise ValueError("Invalid type, received: %s. 
should be wrapped with meanOutputReplacer" % type(layer)) 44 | 45 | def mrpScorer(layer,loss_diff_f=None): 46 | if loss_diff_f is None: 47 | raise ValueError(f'loss_diff_f:{loss_diff_f} cannot be None') 48 | if isinstance(layer,(meanOutputReplacer)): 49 | n_units = layer.weight.data.size(0) 50 | dat_diffs = torch.zeros(n_units) 51 | old_state = layer.is_mean_replace 52 | layer.is_mean_replace = True 53 | for ui in range(n_units): 54 | layer.unit_id = ui 55 | dat_diffs[ui]= loss_diff_f() 56 | layer.is_mean_replace = old_state 57 | return dat_diffs 58 | else: 59 | raise ValueError("Invalid type, received: %s. should be wrapped with meanOutputReplacer" % type(layer)) 60 | -------------------------------------------------------------------------------- /notebooks/demo_scorers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## How to use Scorers\n", 8 | "This demo includes how to use scorers to get same size scores for each score. " 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": { 15 | "collapsed": true 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import torch\n", 20 | "from torch.autograd import Variable\n", 21 | "from torch.nn import Parameter\n", 22 | "import sys\n", 23 | "sys.path.insert(0,\"../\")\n", 24 | "from pytorchpruner.scorers import magnitudeScorer,gradientScorer,hessianScorer\n", 25 | "\n", 26 | "TOLERANCE = 1e-5" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "Parameter containing:\n", 39 | " 0.3538 0.9319\n", 40 | " 0.1769 0.5440\n", 41 | "[torch.FloatTensor of size 2x2]\n", 42 | "\n", 43 | "\n", 44 | " 0.2656 0.0000\n", 45 | " 0.3695 0.0313\n", 46 | "[torch.FloatTensor of size 2x2]\n", 47 | "\n", 48 | "True\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "# We have 4 random parameters\n", 54 | "w = Parameter(torch.rand(2,2))\n", 55 | "print(w)\n", 56 | "\n", 57 | "def L2d(w):\n", 58 | " '''\n", 59 | " A custom loss function\n", 60 | " w00 w01\n", 61 | " w10 w11\n", 62 | " ''' \n", 63 | " return (w[1,0]**2)*w[1,1]+4*(w[0,0]**3)*w[1,0]\n", 64 | "\n", 65 | "def G2d(w):\n", 66 | " #Jacobian of the L(w)\n", 67 | " wd = w.data\n", 68 | " return torch.Tensor([\n", 69 | " [12*(wd[0,0]**2)*wd[1,0],\n", 70 | " 0],\n", 71 | " [2*wd[1,0]*wd[1,1]+4*(wd[0,0]**3),\n", 72 | " wd[1,0]**2]\n", 73 | " ])\n", 74 | "\n", 75 | "#torch.autograd\n", 76 | "loss_val = L2d(w)\n", 77 | "grad_score = gradientScorer(loss_val, w)\n", 78 | "print(grad_score)\n", 79 | "#emprical gradient\n", 80 | "correct_score = G2d(w)\n", 81 | "# print(correct_score)\n", 82 | "print((grad_score-correct_score).abs().sum() < TOLERANCE)\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "\n", 95 | " 3.0034 0.0000\n", 96 | " 2.9434 0.3537\n", 97 | "[torch.FloatTensor of size 2x2]\n", 98 | "\n", 99 | "True\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "hessian_score = hessianScorer(loss_val,w)\n", 105 | "\n", 106 | "print(hessian_score)\n", 107 | "def H2dScore(w):\n", 108 | " #Hessian of the L2(w)\n", 109 | " wd=w.data\n", 110 | " gw12 = 2*wd[1,0]\n", 111 | " gw13 = 12*(wd[0,0]**2)\n", 112 | " gw23 = 0\n", 113 | " gw11 = 2*wd[1,1]\n", 114 | " gw22 = 
0\n", 115 | " gw33 = 24*wd[0,0]*wd[1,0]\n", 116 | " ## x3 0\n", 117 | " ## x1 x2\n", 118 | " return torch.Tensor([[gw33+0+gw13+gw23, #0,0,: #x3 with others\n", 119 | " 0+0+0+0], #0,1,:\n", 120 | " [gw13+0+gw11+gw12, #1,0,: #x1 with others\n", 121 | " gw23+0+gw12+gw22]]) #1,1,: #x2 with others\n", 122 | "\n", 123 | "correct_score = H2dScore( w)\n", 124 | "print((correct_score-hessian_score).abs().sum() < TOLERANCE)" 125 | ] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Python 3", 131 | "language": "python", 132 | "name": "python3" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.6.3" 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 2 149 | } 150 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from .context import pytorchpruner 2 | import pytest 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.nn import Parameter 6 | from torch import nn 7 | 8 | TOLERANCE = 1e-5 9 | from pytorchpruner.utils import hessian_fun,gradient_fun,flatten_params,get_reverse_flatten_params_fun 10 | 11 | def L(w): 12 | #A custom loss function 13 | return (w[0]**2)*w[1]+4*(w[2]**3)*w[0] 14 | def G(w): 15 | #Jacobian of the L(w) 16 | wd = w.data 17 | return torch.Tensor([2*wd[0]*wd[1]+4*(wd[2]**3), 18 | wd[0]**2, 19 | 12*(wd[2]**2)*wd[0]]) 20 | def H(w): 21 | #Hessian of the L(w) 22 | wd=w.data 23 | gw12 = 2*wd[0] 24 | gw13 = 12*(wd[2]**2) 25 | gw23 = 0 26 | gw11 = 2*wd[1] 27 | gw22 = 0 28 | gw33 = 24*wd[2]*wd[0] 29 | return torch.Tensor([[gw11,gw12,gw13], 30 | [gw12,gw22,gw23], 31 | [gw13,gw23,gw33]]) 32 | def L2d(w): 33 | ''' 34 | A custom loss function 35 | w00 w01 36 | w10 w11 37 | ''' 38 | return (w[1,0]**2)*w[1,1]+4*(w[0,0]**3)*w[1,0] 39 | def G2d(w): 40 | #Jacobian of the L(w) 41 | wd = w.data 42 | return torch.Tensor([ 43 | [12*(wd[0,0]**2)*wd[1,0], 44 | 0], 45 | [2*wd[1,0]*wd[1,1]+4*(wd[0,0]**3), 46 | wd[1,0]**2] 47 | ]) 48 | 49 | def H2d(w): 50 | #Hessian of the L2(w) 51 | wd=w.data 52 | gw12 = 2*wd[1,0] 53 | gw13 = 12*(wd[0,0]**2) 54 | gw23 = 0 55 | gw11 = 2*wd[1,1] 56 | gw22 = 0 57 | gw33 = 24*wd[0,0]*wd[1,0] 58 | ## x3 0 59 | ## x1 x2 60 | return torch.Tensor([[[[gw33,0], #0,0,0,: 61 | [gw13,gw23]], #0,0,1,: #x3 with others 62 | [[0,0], #0,1,0,: 63 | [0,0]]], #0,1,1,: 64 | [[[gw13,0], #1,0,0,: 65 | [gw11,gw12]], #1,0,1,: #x1 with others 66 | [[gw23,0], #1,1,0,: 67 | [gw12,gw22]]]]) #1,1,1,: #x2 with others 68 | 69 | def L2(w,w2): 70 | #A custom loss function 71 | return (w[0]**2)*w[1]+4*(w[2]**3)*w[0]+(w2[0]**2)*w2[1]+4*(w2[2]**3)*w2[0] 72 | 73 | def G2(w,w2): 74 | #Jacobian of the L(w) 75 | return (G(w), 76 | G(w2) ) 77 | def H2(w,w2): 78 | #Hessian of the L(w) 79 | return (H(w), 80 | H(w2) ) 81 | 82 | def test_gradient_fun_1d(): 83 | w = Parameter(torch.rand(3)) 84 | 85 | loss_val = L(w) 86 | autograd_grad = gradient_fun(loss_val, w).data 87 | #emprical gradient 88 | correct_grad = G(w) 89 | assert (autograd_grad-correct_grad).abs().sum() < TOLERANCE 90 | 91 | def test_gradient_fun_2d(): 92 | # We have 4 random parameters 93 | w = Parameter(torch.rand(2,2)) 94 | 95 | #torch.autograd 96 | loss_val = L2d(w) 97 | autograd_grad = gradient_fun(loss_val, 
w).data 98 | #emprical gradient 99 | correct_grad = G2d(w) 100 | assert (autograd_grad-correct_grad).abs().sum() < TOLERANCE 101 | 102 | def test_hessian_fun_1d(): 103 | w = Parameter(torch.rand(3)) 104 | 105 | a = L(w) 106 | hessian = hessian_fun(a,w) 107 | 108 | assert torch.sum(torch.abs(hessian-H(w))) < TOLERANCE 109 | 110 | 111 | def test_hessian_fun_2d(): 112 | 113 | w = Parameter(torch.rand(2,2)) 114 | 115 | a = L2d(w) 116 | hessian = hessian_fun(a,w) 117 | 118 | assert torch.sum(torch.abs(hessian-H2d(w))) < TOLERANCE 119 | 120 | def test_gradient_fun_selection(): 121 | # We have 3 random parameters 122 | genesis = torch.rand(3) 123 | w = Parameter(genesis) 124 | w2 = Parameter(genesis) 125 | 126 | #torch.autograd 127 | loss_val = L2(w,w2) 128 | autograd_grad1 = gradient_fun(loss_val, w,retain_graph=True).data 129 | autograd_grad2 = gradient_fun(loss_val, w2).data 130 | #emprical gradient 131 | correct_grad = G2(w,w2) 132 | assert (autograd_grad2-correct_grad[1]).abs().sum() < TOLERANCE 133 | assert (autograd_grad1-correct_grad[0]).abs().sum() < TOLERANCE 134 | 135 | def test_hessian_fun_selection(): 136 | # We have 3 random parameters 137 | genesis = torch.rand(3) 138 | w = Parameter(genesis) 139 | w2 = Parameter(genesis) 140 | 141 | losses = L2(w,w2) 142 | correct_hessians = H2(w,w2) 143 | hessian1 = hessian_fun(losses,w) 144 | hessian2 = hessian_fun(losses,w2) 145 | assert (hessian1-correct_hessians[0]).abs().sum() < TOLERANCE 146 | assert (hessian2-correct_hessians[1]).abs().sum() < TOLERANCE 147 | 148 | def test_flatten_params_single_parameter(): 149 | conv1 = nn.Conv2d(1, 2, kernel_size=5) 150 | params = list(conv1.parameters()) 151 | param = params[0] 152 | 153 | flatp = flatten_params(param) 154 | rev_f = get_reverse_flatten_params_fun(param) 155 | assert (rev_f(flatp).data-param.data).abs().sum() < TOLERANCE 156 | 157 | 158 | flatps = flatten_params(params) 159 | rev_f = get_reverse_flatten_params_fun(params) 160 | for a,b in zip(rev_f(flatps),params): 161 | assert (a.data-b.data).abs().sum() < TOLERANCE 162 | 163 | def test_flatten_params_network(): 164 | import torch.nn.functional as F 165 | from itertools import tee 166 | class Net(nn.Module): 167 | def __init__(self): 168 | super(Net, self).__init__() 169 | self.conv1 = nn.Conv2d(1, 2, kernel_size=5) 170 | self.conv2 = nn.Conv2d(2, 1, kernel_size=5) 171 | self.fc1 = nn.Linear(16, 2) 172 | self.fc2 = nn.Linear(2, 10) 173 | self.nonlins = {'conv1':('max_relu',(2,2)),'conv2':('max_relu',(2,2)),'fc1':'relu','fc2':'log_softmax'} 174 | 175 | def forward(self, x): 176 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 177 | x = F.relu(F.max_pool2d(self.conv2(x), 2)) 178 | x = x.view(-1, 16) 179 | x = F.relu(self.fc1(x)) 180 | x = self.fc2(x) 181 | return F.log_softmax(x,dim=1) 182 | model = Net() 183 | f_params = flatten_params(model.parameters()) 184 | reverse_gen = get_reverse_flatten_params_fun(model.parameters()) 185 | 186 | for a,b in zip(reverse_gen(f_params),model.parameters()): 187 | assert torch.sum(torch.abs(a-b)).data[0]1: 40 | raise ValueError('meanOutputReplacer is not implemented for layyer getting multiple inputs') 41 | if self.enabled: 42 | out = self.module(*inputs, **kwargs) 43 | 44 | out_mean = out.mean(0) 45 | # second dim is the n_outputs 46 | while out_mean.dim()>1: 47 | out_mean = out_mean.mean(1) 48 | self.cy_mean = out_mean.data 49 | # import pdb;pdb.set_trace() 50 | if self.is_mean_replace: 51 | if isinstance(self.unit_id,torch.ByteTensor): 52 | ## CAN DO THIS MORE EFFICIENTLY USING TENSOR OPS 53 | for i in 
range(len(self.unit_id)): 54 | if self.unit_id[i]==1: 55 | out[:,i] = out_mean[i].expand(out.size(0),*out.size()[2:]) 56 | else: 57 | out[:,self.unit_id] = out_mean[self.unit_id].expand(out.size(0),*out.size()[2:]) 58 | else: 59 | out_mean_expanded = out_mean.data.expand(out.size(0), 60 | *out.size()[2:], 61 | -1) 62 | if out_mean_expanded.dim()>2: 63 | out_mean_expanded= out_mean_expanded.transpose(1,3) 64 | self.cy_zeromean = out.data-out_mean_expanded 65 | else: 66 | out = self.module(*inputs, **kwargs) 67 | return out 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ + '(\n\t' \ 71 | + 'module=' + str(self.module) \ 72 | + '\n\t,is_mean_replace=' + str(self.is_mean_replace) \ 73 | + '\n\t,enabled=' + str(self.enabled) + ')' 74 | 75 | class MaskedModule(Module): 76 | r"""Implements masked module for prunning etc... 77 | it creates a mask for each layer and holds those masks inside a dictionary belongs to 78 | this module 79 | """ 80 | DEFAULT_MASKED_MODULES = (torch.nn.Conv2d,torch.nn.Linear) 81 | def __init__(self, module, masked_modules=DEFAULT_MASKED_MODULES): 82 | super().__init__() 83 | self.module = module 84 | self.def_masked_modules = masked_modules 85 | self._mask_dict = {} # nn.Module->torch.ByteTensor 86 | # self._inp_dict = {} # nn.Module->torch.Tensor 87 | # self._ginp_dict = {} # nn.Module->torch.Tensor 88 | # self.__fhook_dict = {} # nn.Module->fuction 89 | # self.__bhook_dict = {} # nn.Module->fuction 90 | self.initialize_transparent_masks() 91 | 92 | def initialize_transparent_masks(self): 93 | def helper(m): 94 | if isinstance(m, self.def_masked_modules): 95 | self._mask_dict[m] = [torch.zeros(m.weight.data.size()).byte(), 96 | torch.zeros(m.bias.data.size()).byte()] 97 | self.module.apply(helper) 98 | 99 | def forward(self, *inputs, **kwargs): 100 | return self.module.forward(*inputs, **kwargs) 101 | 102 | def cpu(self): 103 | for k,v in self._mask_dict.items(): 104 | self._mask_dict[k]=[v[0].cpu(),v[1].cpu()] 105 | self.module.cpu() 106 | super().cpu() 107 | return self 108 | def cuda(self): 109 | for k,v in self._mask_dict.items(): 110 | self._mask_dict[k]=[v[0].cuda(),v[1].cuda()] 111 | self.module.cuda() 112 | super().cuda() 113 | return self 114 | def apply_mask_on_gradients(self): 115 | def helper(m): 116 | if isinstance(m, self.def_masked_modules): 117 | m.weight.grad.data[self._mask_dict[m][0]]=0 118 | m.bias.grad.data[self._mask_dict[m][1]]=0 119 | self.module.apply(helper) 120 | 121 | def calculateSparsity(self): 122 | sum_zeros = 0 123 | sum_elements = 0 124 | if self._mask_dict: 125 | for mask_w,mask_b in self._mask_dict.values(): 126 | sum_zeros += mask_w.sum()+mask_b.sum() 127 | sum_elements += mask_w.nelement() + mask_b.nelement() 128 | return (sum_zeros/sum_elements) 129 | else: 130 | error('Mask is not initilized, this should not happen currently. 01.04.2018') 131 | def __repr__(self): 132 | return self.__class__.__name__ + '(\n\t' \ 133 | + 'module=' + str(self.module) + ')' 134 | # def initiliaze_forward_hooks(self): 135 | # def f_hook(module,inp,out): 136 | # #TODO check inp is a tuple of size 1. 
137 | # self._inp_dict[module]=inp[0] 138 | # def helper(m): 139 | # if isinstance(m, self.DEFAULT_MASKED_MODULES): 140 | # self.__fhook_dict[m] = m.register_forward_hook(f_hook) 141 | # self.module.apply(helper) 142 | # 143 | # def remove_forward_hooks(self): 144 | # if self.__fhook_dict: 145 | # for v in self.__fhook_dict.values(): 146 | # v.remove() 147 | # self._inp_dict = {} 148 | # self.__fhook_dict = {} 149 | # 150 | # def initiliaze_backward_hooks(self): 151 | # def b_hook(module,ginp,gout): 152 | # #TODO check inp is Tensor or tuple 153 | # self._ginp_dict[module]=ginp 154 | # def helper(m): 155 | # if isinstance(m, self.DEFAULT_MASKED_MODULES): 156 | # self.__bhook_dict[m] = m.register_backward_hook(b_hook) 157 | # self.module.apply(helper) 158 | # 159 | # def remove_backward_hooks(self): 160 | # if self.__bhook_dict: 161 | # for v in self.__bhook_dict.values(): 162 | # v.remove() 163 | # self._ginp_dict = {} 164 | # self.__bhook_dict = {} 165 | -------------------------------------------------------------------------------- /pytorchpruner/scorers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Each scorer should implement score function, which given parameter tensor returns the 3 | scores in the same shape. High scores mean high salincies. 4 | ''' 5 | 6 | from .utils import hessian_fun,gradient_fun,get_reverse_flatten_params_fun,hessian_vector_product,flatten_params 7 | from torch import nn 8 | import collections 9 | import torch 10 | from itertools import product 11 | #loss is not needer, but adding it here to have a generic set of parameters 12 | def magnitudeScorer(params,*args,**kwargs): 13 | if isinstance(params,nn.Parameter): 14 | result = params.data.clone().abs() 15 | elif isinstance(params,collections.Iterable): 16 | # Case 2 17 | result = list(map(lambda x:x.data.clone().abs(),params)) 18 | else: 19 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 20 | parameters or a single parameter" % type(params)) 21 | return result 22 | 23 | def taylor1Scorer(params,loss,w_fun=lambda a: -a): 24 | """ 25 | taylor1Scorer 26 | 27 | """ 28 | if isinstance(params,nn.Parameter): 29 | dw = w_fun(params.data) 30 | grad = gradient_fun(loss,params,retain_graph=True).data 31 | result = torch.mul(dw,grad) 32 | elif isinstance(params,collections.Iterable): 33 | params = list(params) 34 | grads = gradient_fun(loss,params,retain_graph=True) 35 | result = [torch.mul(w_fun(w.data),g.data) for w,g in zip(params,grads)] 36 | else: 37 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 38 | parameters or a single parameter" % type(params)) 39 | return result 40 | 41 | def taylor1ScorerAbs(params,loss,w_fun=lambda a:-a): 42 | if isinstance(params,nn.Parameter): 43 | return taylor1Scorer(params,loss,w_fun=w_fun).abs() 44 | elif isinstance(params,collections.Iterable): 45 | return [s.abs() for s in taylor1Scorer(params,loss,w_fun=w_fun)] 46 | else: 47 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 48 | parameters or a single parameter" % type(params)) 49 | 50 | 51 | 52 | def hessianScorer(params,loss,w_fun=lambda a:-a): 53 | """ 54 | hessian Scorer which basically returns the sum of the row of the hessian 55 | using efficient hessian-vector product. 
56 | 57 | params: 58 | single nn.Parameter -> trivial 59 | iterator of Parameter's -> In this case we flattened the parameters 60 | returns: 61 | single nn.Parameter -> a single score tensor same size as the input Parameter 62 | iterator of Parameter's -> an iterator of scores each is same size as the input Parameters 63 | 64 | Example: 65 | check `test_hessianScorer` 66 | """ 67 | if isinstance(params,nn.Parameter): 68 | vector = w_fun(params.data.clone()) 69 | hv = hessian_vector_product(loss,params,vector,retain_graph=True) 70 | result = torch.mul(w_fun(params.data),hv) 71 | elif isinstance(params,collections.Iterable): 72 | # Case 2 73 | params = list(params) 74 | rev_f,n_elements = get_reverse_flatten_params_fun(params,get_count=True) 75 | vector = flatten_params((w_fun(p.data.clone()) for p in params)) 76 | flat_hv = hessian_vector_product(loss,params,vector,retain_graph=True,flattened=True) 77 | hv = rev_f(flat_hv) 78 | result = [torch.mul(w_fun(w.data),h) for w,h in zip(params,hv)] 79 | else: 80 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 81 | parameters or a single parameter" % type(params)) 82 | 83 | return result 84 | 85 | def hessianScorerAbs(params,loss,w_fun=lambda a:-a): 86 | if isinstance(params,nn.Parameter): 87 | return hessianScorer(params,loss,w_fun=w_fun).abs() 88 | elif isinstance(params,collections.Iterable): 89 | return [s.abs() for s in hessianScorer(params,loss,w_fun=w_fun)] 90 | else: 91 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 92 | parameters or a single parameter" % type(params)) 93 | 94 | def taylor2Scorer(params,loss,w_fun=lambda a: -a,scale=1): 95 | """ 96 | taylor2Scorer 97 | 98 | """ 99 | if not isinstance(scale, (int, float)): 100 | raise ValueError(f'scale={float} needs tobe a float or int') 101 | if isinstance(params,nn.Parameter): 102 | vector = w_fun(params.data.clone()) 103 | hv = hessian_vector_product(loss,params,vector,retain_graph=True) 104 | grad = gradient_fun(loss,params,retain_graph=True).data 105 | result = torch.mul(w_fun(params.data),scale*hv+grad) 106 | elif isinstance(params,collections.Iterable): 107 | # Case 2 108 | params = list(params) 109 | rev_f,n_elements = get_reverse_flatten_params_fun(params,get_count=True) 110 | vector = flatten_params((w_fun(p.data.clone()) for p in params)) 111 | flat_hv = hessian_vector_product(loss,params,vector,retain_graph=True,flattened=True) 112 | hv = rev_f(flat_hv) 113 | grads = gradient_fun(loss,params,retain_graph=True) 114 | result = [torch.mul(w_fun(w.data),scale*h+g.data) for w,h,g in zip(params,hv,grads)] 115 | else: 116 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 117 | parameters or a single parameter" % type(params)) 118 | return result 119 | 120 | def taylor2ScorerAbs(params,loss,w_fun=lambda a:-a,scale=1): 121 | if isinstance(params,nn.Parameter): 122 | return taylor2Scorer(params,loss,w_fun=w_fun,scale=scale).abs() 123 | elif isinstance(params,collections.Iterable): 124 | return [s.abs() for s in taylor2Scorer(params,loss,w_fun=w_fun,scale=scale)] 125 | else: 126 | raise ValueError("Invalid type, received: %s. 
either supply iterable of \ 127 | parameters or a single parameter" % type(params)) 128 | 129 | def lossChangeScorer(params,loss,loss_calc_f=None): 130 | if loss_calc_f is None: 131 | raise ValueError(f'loss_calc_f:{loss_calc_f} cannot be None') 132 | 133 | if isinstance(params,nn.Parameter): 134 | scores = params.data.clone() 135 | for idx in product(*map(range,scores.size())): 136 | old_val,params.data[idx] = params.data[idx],0 137 | scores[idx] = loss_calc_f()-loss.data[0] 138 | params.data[idx] = old_val 139 | elif isinstance(params,collections.Iterable): 140 | scores = [lossChangeScorer(p,loss,loss_calc_f=loss_calc_f) for p in params] 141 | else: 142 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 143 | parameters or a single parameter" % type(params)) 144 | 145 | return scores 146 | 147 | def lossChangeScorerAbs(params,loss,**kwargs): 148 | if isinstance(params,nn.Parameter): 149 | return lossChangeScorer(params,loss,**kwargs).abs() 150 | elif isinstance(params,collections.Iterable): 151 | return [s.abs() for s in lossChangeScorer(params,loss,**kwargs)] 152 | else: 153 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 154 | parameters or a single parameter" % type(params)) 155 | 156 | def randomScorer(params,loss): 157 | old_state = torch.get_rng_state() 158 | if isinstance(params,nn.Parameter): 159 | res = torch.rand(params.size()) 160 | elif isinstance(params,collections.Iterable): 161 | res = [randomScorer(p,loss) for p in params] 162 | else: 163 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 164 | parameters or a single parameter" % type(params)) 165 | torch.set_rng_state(old_state) 166 | return res 167 | -------------------------------------------------------------------------------- /notebooks/demo_pruner_toyMNISTcnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## pytorch Pruner Demo\n", 8 | "Purpose:\n", 9 | "- Demonstrate the pruners.BasePruner module with modules.MaskedModule" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 5, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import sys\n", 21 | "sys.path.insert(0,'../')\n", 22 | "from pytorchpruner.modules import MaskedModule\n", 23 | "from pytorchpruner.pruners import BasePruner" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## a Toy CNN\n", 31 | "Let's define a toy CNN" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 21, 37 | "metadata": { 38 | "collapsed": true 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "import torch\n", 43 | "import torch.nn as nn\n", 44 | "import torch.nn.functional as F\n", 45 | "from torch.autograd import Variable\n", 46 | "\n", 47 | "class Net(nn.Module):\n", 48 | " def __init__(self):\n", 49 | " super(Net, self).__init__()\n", 50 | " self.conv1 = nn.Conv2d(1, 8, kernel_size=5)\n", 51 | " self.conv2 = nn.Conv2d(8,16, kernel_size=5)\n", 52 | " self.fc1 = nn.Linear(256, 64)\n", 53 | " self.fc2 = nn.Linear(64, 10)\n", 54 | " self.nonlins = {'conv1':('max_relu',(2,2)),'conv2':('max_relu',(2,2)),'fc1':'relu','fc2':'log_softmax'}\n", 55 | "\n", 56 | " def forward(self, x):\n", 57 | " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", 58 | " x = F.relu(F.max_pool2d(self.conv2(x), 2))\n", 59 | " x = x.view(-1, 256)\n", 60 | " x = 
F.relu(self.fc1(x))\n", 61 | " x = self.fc2(x)\n", 62 | " return F.log_softmax(x,dim=1)\n", 63 | "\n", 64 | "def weight_init(m):\n", 65 | " if isinstance(m,(torch.nn.Conv2d,torch.nn.Linear)):\n", 66 | " nn.init.xavier_uniform(m.weight)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Masked module and Pruner\n", 74 | "Masked module is needed Pruner to work on the network. This can be done explicitly or implicitly." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 22, 80 | "metadata": { 81 | "collapsed": true 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "model = MaskedModule(Net())\n", 86 | "# model = Net() #implicit MaskedModel(model) is called during initilization\n", 87 | "pruner = BasePruner(model)\n", 88 | "#dummy batch sample\n", 89 | "x=Variable(torch.Tensor(32,1,28,28)) #mnist batch\n", 90 | "y=Variable((torch.ones(32)).long())" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 23, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Variable containing:\n", 103 | "-2.3978\n", 104 | "-2.3298\n", 105 | "-2.1738\n", 106 | "-2.3781\n", 107 | "-2.2667\n", 108 | "-2.4077\n", 109 | "-2.2205\n", 110 | "-2.1774\n", 111 | "-2.3512\n", 112 | "-2.3596\n", 113 | "[torch.FloatTensor of size 10]\n", 114 | "\n", 115 | "Variable containing:\n", 116 | "(0 ,.,.) = \n", 117 | " -0.1216 0.0017 0.0447 0.0141 -0.1600\n", 118 | " 0.0304 -0.1724 0.0179 0.1675 -0.0229\n", 119 | " -0.1923 0.1447 -0.1297 0.0304 0.0879\n", 120 | " -0.1815 0.0625 0.1408 0.0132 -0.0177\n", 121 | " 0.1821 0.0707 0.1053 -0.0330 0.0306\n", 122 | "[torch.FloatTensor of size 1x5x5]\n", 123 | "\n", 124 | "Variable containing:\n", 125 | "-2.3958\n", 126 | "-2.3270\n", 127 | "-2.1862\n", 128 | "-2.3798\n", 129 | "-2.2650\n", 130 | "-2.3988\n", 131 | "-2.2199\n", 132 | "-2.1820\n", 133 | "-2.3522\n", 134 | "-2.3523\n", 135 | "[torch.FloatTensor of size 10]\n", 136 | "\n", 137 | "Variable containing:\n", 138 | "(0 ,.,.) = \n", 139 | " -0.1216 0.0000 0.0000 0.0000 -0.1600\n", 140 | " 0.0000 -0.1724 0.0000 0.1675 0.0000\n", 141 | " -0.1923 0.1447 -0.1297 0.0000 0.0000\n", 142 | " -0.1815 0.0000 0.1408 0.0000 0.0000\n", 143 | " 0.1821 0.0000 0.1053 0.0000 0.0000\n", 144 | "[torch.FloatTensor of size 1x5x5]\n", 145 | "\n", 146 | "\n", 147 | "(0 ,.,.) = \n", 148 | " 1 0 0 0 1\n", 149 | " 0 1 0 1 0\n", 150 | " 1 1 1 0 0\n", 151 | " 1 0 1 0 0\n", 152 | " 1 0 1 0 0\n", 153 | "[torch.ByteTensor of size 1x5x5]\n", 154 | "\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "#Before and after pruning\n", 160 | "print(model(x)[0])\n", 161 | "print(pruner.masked_model.module.conv1.weight[0])\n", 162 | "pruner.prune(0.5)\n", 163 | "print(model(x)[0])\n", 164 | "print(model.module.conv1.weight[0])\n", 165 | "## Printing the mask tensor of the first conv layer\n", 166 | "print(model._mask_dict[model.module.conv1][0])\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 24, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "Variable containing:\n", 179 | "(0 ,.,.) = \n", 180 | " 0 0 0 0 0\n", 181 | " 0 0 0 0 0\n", 182 | " 0 0 0 0 0\n", 183 | " 0 0 0 0 0\n", 184 | " 0 0 0 0 0\n", 185 | "[torch.FloatTensor of size 1x5x5]\n", 186 | "\n", 187 | "Variable containing:\n", 188 | "(0 ,.,.) 
= \n", 189 | " 0 0 0 0 0\n", 190 | " 0 0 0 0 0\n", 191 | " 0 0 0 0 0\n", 192 | " 0 0 0 0 0\n", 193 | " 0 0 0 0 0\n", 194 | "[torch.FloatTensor of size 1x5x5]\n", 195 | "\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "#applying mask on gradients\n", 201 | "output = model(x)\n", 202 | "loss = F.nll_loss(output, y)\n", 203 | "loss.backward()\n", 204 | "print(model.module.conv1.weight.grad[0])\n", 205 | "model.apply_mask_on_gradients()\n", 206 | "print(model.module.conv1.weight.grad[0])" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## Saving and loading maskedModule\n" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 10, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "name": "stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "MaskedModule(\n", 226 | " (module): Net(\n", 227 | " (conv1): Conv2d (1, 8, kernel_size=(5, 5), stride=(1, 1))\n", 228 | " (conv2): Conv2d (8, 16, kernel_size=(5, 5), stride=(1, 1))\n", 229 | " (fc1): Linear(in_features=256, out_features=64)\n", 230 | " (fc2): Linear(in_features=64, out_features=10)\n", 231 | " )\n", 232 | ")\n", 233 | "\n", 234 | "(0 ,.,.) = \n", 235 | " 0 1 1 1 0\n", 236 | " 1 1 0 0 1\n", 237 | " 1 1 1 0 1\n", 238 | " 0 1 1 1 0\n", 239 | " 1 0 0 0 0\n", 240 | "[torch.ByteTensor of size 1x5x5]\n", 241 | "\n" 242 | ] 243 | }, 244 | { 245 | "name": "stderr", 246 | "output_type": "stream", 247 | "text": [ 248 | "/Users/evcu/anaconda3/lib/python3.6/site-packages/torch/serialization.py:158: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.\n", 249 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "torch.save(model,'test.mod')\n", 255 | "model2 = torch.load('test.mod')\n", 256 | "print(model2)\n", 257 | "print(model2._mask_dict[model2.module.conv1][0])" 258 | ] 259 | } 260 | ], 261 | "metadata": { 262 | "kernelspec": { 263 | "display_name": "Python 3", 264 | "language": "python", 265 | "name": "python3" 266 | }, 267 | "language_info": { 268 | "codemirror_mode": { 269 | "name": "ipython", 270 | "version": 3 271 | }, 272 | "file_extension": ".py", 273 | "mimetype": "text/x-python", 274 | "name": "python", 275 | "nbconvert_exporter": "python", 276 | "pygments_lexer": "ipython3", 277 | "version": "3.6.3" 278 | } 279 | }, 280 | "nbformat": 4, 281 | "nbformat_minor": 2 282 | } 283 | -------------------------------------------------------------------------------- /pytorchpruner/pruners.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules import Module 3 | from .modules import MaskedModule,meanOutputReplacer 4 | from .utils import _find_fan_out_weights,my_select 5 | from torch.nn import Parameter 6 | from torch.autograd import Variable 7 | from torch import nn 8 | from .scorers import magnitudeScorer 9 | from .unitscorers import normScorer 10 | 11 | def get_pruning_mask(scores,f=0.1): 12 | """ 13 | @params scores scorer 14 | @f lowest f fraction of scores elements would have a 1 in the mask. 15 | """ 16 | srted = torch.sort(scores.view(-1))[0] 17 | if (len(srted)*f)>(len(srted)-1): 18 | #In this case we prune all weights 19 | print("Warning! whole tensor is about to be pruned.") 20 | thres = float('inf') 21 | else: 22 | thres = srted[int(len(srted)*f)] 23 | mask = scores.Tensor or list of torch.Tensor 85 | 1d of size n->returns n*n tensor. 
86 | 2d of size m*n -> returns m*n*m*n tensor. 87 | 88 | NOTE: retains the graph 89 | TODO:extend to arbitrary params1 params2. such that we can get any part of the big network hessian. 90 | """ 91 | if isinstance(params,Parameter): 92 | # Case 1 93 | pass 94 | elif isinstance(params,collections.Iterable) and flattened: 95 | # Case 2 96 | params = list(params) #this is to prevent grad eating the generator. to use generator multiple times(model.parameters()) 97 | pass 98 | else: 99 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 100 | parameters or a single parameters" % type(params)) 101 | 102 | 103 | jacobian = gradient_fun(loss, params,flattened=flattened, create_graph=True) 104 | hessian_rows = [] 105 | # Iteratively calculate hessian of L(w) by multipliying the hessian with the one-hot vectors 106 | # note that ind can be a tuple or single int 107 | for ind in index_generator(jacobian.size()): 108 | hessian_row_i = hessian_vector_product( None, #when params_grad given loss is not needed 109 | params, 110 | generate_onehot_var(jacobian.size(),ind), 111 | params_grad = jacobian, 112 | retain_graph = True, 113 | flattened = True) 114 | hessian_rows.append(hessian_row_i) 115 | hessian_shape = jacobian.size()*2 116 | result = torch.stack(hessian_rows).view(hessian_shape) 117 | return result 118 | 119 | 120 | def gradient_fun(loss,params,flattened=False,create_graph=False,retain_graph=True): 121 | """ 122 | loss: a scalar Variable 123 | params: Parameter with params.size()=S 124 | 125 | returns: Tensor with size S containing the gradient. 126 | 127 | """ 128 | if create_graph: 129 | #if you are creating it you are reataining it by default. 130 | retain_graph = True 131 | gradient = torch.autograd.grad(loss, 132 | params, 133 | create_graph=create_graph, 134 | retain_graph=retain_graph) 135 | if flattened: 136 | gradient = flatten_params(gradient) 137 | elif isinstance(params,Parameter): 138 | gradient = gradient[0] 139 | 140 | return gradient 141 | 142 | 143 | 144 | def hessian_vector_product(loss,params,vector,params_grad=None,retain_graph=False,flattened=False): 145 | """ 146 | params: Case 1: Parameter 147 | Then the param:vector should be a Tensor with same size. The result is same size as the Parameter. 148 | Case 2: iterator of Parameters 149 | This is allowed only when flattened=True. 150 | loss: needed only params_grad is not provided 151 | vector: Same size as the params_grad. If you are flattened without providing the params_grad note that your vector 152 | match the size of the flattened parameters. 153 | params_grad: is for preventing recalculation and to be able to use in hessian 154 | flattened: if true then the params should be list of parameters. Then the hessian vector product is flattened. 155 | In this setting I am not returning the reverse functon that flatten_params generate since 156 | the only instance where I flatten is during the hessian and I get the same function during grad calcualtion. 157 | Future use cases may require and one can return. 
158 | """ 159 | 160 | #We need a Variable, so ensure 161 | if torch.is_tensor(vector): 162 | vector = Variable(vector,requires_grad=False) 163 | elif isinstance(vector,Variable): 164 | pass 165 | else: 166 | raise ValueError("Vector passed is not a Variable or Tensor: {}".format(type(vector))) 167 | 168 | if isinstance(params,Parameter): 169 | # Case 1 170 | pass 171 | elif isinstance(params,collections.Iterable) and flattened: 172 | # Case 2 173 | params = list(params) 174 | pass 175 | else: 176 | raise ValueError("Invalid type, received: %s. either supply iterable of \ 177 | parameters or a single parameters" % type(params)) 178 | 179 | if isinstance(params_grad,Variable): 180 | pass 181 | else: 182 | params_grad = torch.autograd.grad(loss, params, create_graph=True) 183 | if flattened: 184 | params_grad = flatten_params(params_grad) 185 | else: 186 | params_grad = params_grad[0] 187 | if params_grad.is_cuda: vector= vector.cuda() 188 | # import pdb;pdb.set_trace() 189 | grad_vector_dot = torch.sum(params_grad * vector) 190 | hv_params = torch.autograd.grad(grad_vector_dot, params,retain_graph=retain_graph) 191 | if flattened: 192 | hv_params = flatten_params(hv_params) 193 | else: 194 | hv_params = hv_params[0] 195 | 196 | return hv_params.data 197 | 198 | def my_select(tensor,dim,indices): 199 | if dim==0: 200 | return tensor[indices,] 201 | elif dim==1: 202 | return tensor[:,indices,] 203 | elif dim==2: 204 | return tensor[:,:,indices,] 205 | else: 206 | raise ValueError('That is enough') 207 | 208 | def _find_fan_out_weights(defs_conv,defs_fc,layer_name,unit_index): 209 | """ 210 | Arguments: 211 | - unit_index: int or slice 212 | finding outgoing layer and the indices associated with the unit provided. 213 | 1. Find layer, check unit_index 214 | 2. Save consecutive layer 215 | 3. If the layer=conv next_layer=fc we need to update the slice 216 | Example: 217 | 218 | self.defs_conv = [1, 219 | ('conv1',8,5,2), 220 | ('conv2',16,5,2)] 221 | self.defs_fc = [16*16, 222 | ('fc1',64), 223 | ('fc2',10)] 224 | """ 225 | is_found = False 226 | next_layer = False 227 | 228 | #if slice we need better error checking 229 | if isinstance(unit_index,slice): 230 | compare_f = lambda i,n_out: 0<=min(i.start,i.stop) and max(i.start,i.stop)