├── BayesianLayers.py
├── LICENSE
├── README.md
├── compression.py
├── environment.yml
├── example.py
├── figures
│   ├── pixel.gif
│   ├── weight0_e.gif
│   └── weight1_e.gif
└── utils.py
/BayesianLayers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Variational Dropout version of linear and convolutional layers
6 |
7 |
8 | Karen Ullrich, Christos Louizos, Oct 2017
9 | """
10 |
11 | import math
12 |
13 | import torch
14 | from torch.nn.parameter import Parameter
15 | import torch.nn.functional as F
16 | from torch import nn
17 | from torch.nn.modules import Module
18 | from torch.autograd import Variable
19 | from torch.nn.modules import utils
20 |
21 |
22 | def reparametrize(mu, logvar, cuda=False, sampling=True):
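   |     # samples mu + std * eps with eps ~ N(0, 1) and std = exp(logvar / 2); the
   |     # reparameterization trick keeps the sample differentiable w.r.t. mu and logvar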
23 | if sampling:
24 | std = logvar.mul(0.5).exp_()
25 | if cuda:
26 | eps = torch.cuda.FloatTensor(std.size()).normal_()
27 | else:
28 | eps = torch.FloatTensor(std.size()).normal_()
29 | eps = Variable(eps)
30 | return mu + eps * std
31 | else:
32 | return mu
33 |
34 |
35 | # -------------------------------------------------------
36 | # LINEAR LAYER
37 | # -------------------------------------------------------
38 |
39 | class LinearGroupNJ(Module):
40 | """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).
41 |
42 | References:
43 | [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
44 | [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
45 | [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
46 | """
47 |
48 | def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):
49 |
50 | super(LinearGroupNJ, self).__init__()
51 | self.cuda = cuda
52 | self.in_features = in_features
53 | self.out_features = out_features
54 | self.clip_var = clip_var
55 | self.deterministic = False # flag is used for compressed inference
56 |         # trainable params according to [3] Eq.(6)
57 | # dropout params
58 | self.z_mu = Parameter(torch.Tensor(in_features))
59 | self.z_logvar = Parameter(torch.Tensor(in_features)) # = z_mu^2 * alpha
60 | # weight params
61 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
62 | self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))
63 |
64 | self.bias_mu = Parameter(torch.Tensor(out_features))
65 | self.bias_logvar = Parameter(torch.Tensor(out_features))
66 |
67 | # init params either random or with pretrained net
68 | self.reset_parameters(init_weight, init_bias)
69 |
70 | # activations for kl
71 | self.sigmoid = nn.Sigmoid()
72 | self.softplus = nn.Softplus()
73 |
74 | # numerical stability param
75 | self.epsilon = 1e-8
76 |
77 | def reset_parameters(self, init_weight, init_bias):
78 | # init means
79 | stdv = 1. / math.sqrt(self.weight_mu.size(1))
80 |
81 | self.z_mu.data.normal_(1, 1e-2)
82 |
83 | if init_weight is not None:
84 | self.weight_mu.data = torch.Tensor(init_weight)
85 | else:
86 | self.weight_mu.data.normal_(0, stdv)
87 |
88 | if init_bias is not None:
89 | self.bias_mu.data = torch.Tensor(init_bias)
90 | else:
91 | self.bias_mu.data.fill_(0)
92 |
93 | # init logvars
94 | self.z_logvar.data.normal_(-9, 1e-2)
95 | self.weight_logvar.data.normal_(-9, 1e-2)
96 | self.bias_logvar.data.normal_(-9, 1e-2)
97 |
98 | def clip_variances(self):
99 | if self.clip_var:
100 | self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
101 | self.bias_logvar.data.clamp_(max=math.log(self.clip_var))
102 |
103 | def get_log_dropout_rates(self):
104 | log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
105 | return log_alpha
106 |
107 | def compute_posterior_params(self):
108 | weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
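   |         # variance of a product of two independent Gaussians z and w:
   |         # var(z * w) = mu_z^2 * var_w + var_z * mu_w^2 + var_z * var_w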
109 | self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
110 | self.post_weight_mu = self.weight_mu * self.z_mu
111 | return self.post_weight_mu, self.post_weight_var
112 |
113 | def forward(self, x):
114 | if self.deterministic:
115 | assert self.training == False, "Flag deterministic is True. This should not be used in training."
116 | return F.linear(x, self.post_weight_mu, self.bias_mu)
117 |
118 | batch_size = x.size()[0]
119 | # compute z
120 | # note that we reparametrise according to [2] Eq. (11) (not [1])
121 | z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
122 | cuda=self.cuda)
123 |
124 | # apply local reparametrisation trick see [1] Eq. (6)
125 | # to the parametrisation given in [3] Eq. (6)
126 | xz = x * z
127 | mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
128 | var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())
129 |
130 | return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)
131 |
132 | def kl_divergence(self):
133 | # KL(q(z)||p(z))
134 | # we use the kl divergence approximation given by [2] Eq.(14)
135 | k1, k2, k3 = 0.63576, 1.87320, 1.48695
136 | log_alpha = self.get_log_dropout_rates()
137 | KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)
138 |
139 | # KL(q(w|z)||p(w|z))
140 | # we use the kl divergence given by [3] Eq.(8)
141 | KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
142 | KLD += torch.sum(KLD_element)
143 |
144 | # KL bias
145 | KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
146 | KLD += torch.sum(KLD_element)
147 |
148 | return KLD
149 |
150 | def __repr__(self):
151 | return self.__class__.__name__ + ' (' \
152 | + str(self.in_features) + ' -> ' \
153 | + str(self.out_features) + ')'
154 |
155 |
156 | # -------------------------------------------------------
157 | # CONVOLUTIONAL LAYER
158 | # -------------------------------------------------------
159 |
160 | class _ConvNdGroupNJ(Module):
161 | """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).
162 |
163 | References:
164 | [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
165 | [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
166 | [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
167 | """
168 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
169 | groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
170 | super(_ConvNdGroupNJ, self).__init__()
171 | if in_channels % groups != 0:
172 | raise ValueError('in_channels must be divisible by groups')
173 | if out_channels % groups != 0:
174 | raise ValueError('out_channels must be divisible by groups')
175 | self.in_channels = in_channels
176 | self.out_channels = out_channels
177 | self.kernel_size = kernel_size
178 | self.stride = stride
179 | self.padding = padding
180 | self.dilation = dilation
181 | self.transposed = transposed
182 | self.output_padding = output_padding
183 | self.groups = groups
184 |
185 | self.cuda = cuda
186 | self.clip_var = clip_var
187 | self.deterministic = False # flag is used for compressed inference
188 |
189 | if transposed:
190 | self.weight_mu = Parameter(torch.Tensor(
191 | in_channels, out_channels // groups, *kernel_size))
192 | self.weight_logvar = Parameter(torch.Tensor(
193 | in_channels, out_channels // groups, *kernel_size))
194 | else:
195 | self.weight_mu = Parameter(torch.Tensor(
196 | out_channels, in_channels // groups, *kernel_size))
197 | self.weight_logvar = Parameter(torch.Tensor(
198 | out_channels, in_channels // groups, *kernel_size))
199 |
200 | self.bias_mu = Parameter(torch.Tensor(out_channels))
201 | self.bias_logvar = Parameter(torch.Tensor(out_channels))
202 |
203 | self.z_mu = Parameter(torch.Tensor(self.out_channels))
204 | self.z_logvar = Parameter(torch.Tensor(self.out_channels))
205 |
206 | self.reset_parameters(init_weight, init_bias)
207 |
208 | # activations for kl
209 | self.sigmoid = nn.Sigmoid()
210 | self.softplus = nn.Softplus()
211 | # numerical stability param
212 | self.epsilon = 1e-8
213 |
214 | def reset_parameters(self, init_weight, init_bias):
215 | # init means
216 | n = self.in_channels
217 | for k in self.kernel_size:
218 | n *= k
219 | stdv = 1. / math.sqrt(n)
220 |
221 | # init means
222 | if init_weight is not None:
223 | self.weight_mu.data = init_weight
224 | else:
225 | self.weight_mu.data.uniform_(-stdv, stdv)
226 |
227 | if init_bias is not None:
228 | self.bias_mu.data = init_bias
229 | else:
230 | self.bias_mu.data.fill_(0)
231 |
232 |         # init z
233 | self.z_mu.data.normal_(1, 1e-2)
234 |
235 | # init logvars
236 | self.z_logvar.data.normal_(-9, 1e-2)
237 | self.weight_logvar.data.normal_(-9, 1e-2)
238 | self.bias_logvar.data.normal_(-9, 1e-2)
239 |
240 | def clip_variances(self):
241 | if self.clip_var:
242 | self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
243 | self.bias_logvar.data.clamp_(max=math.log(self.clip_var))
244 |
245 | def get_log_dropout_rates(self):
246 | log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
247 | return log_alpha
248 |
249 | def compute_posterior_params(self):
250 | weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
251 | self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
252 | self.post_weight_mu = self.weight_mu * self.z_mu
253 | return self.post_weight_mu, self.post_weight_var
254 |
255 | def kl_divergence(self):
256 | # KL(q(z)||p(z))
257 | # we use the kl divergence approximation given by [2] Eq.(14)
258 | k1, k2, k3 = 0.63576, 1.87320, 1.48695
259 | log_alpha = self.get_log_dropout_rates()
260 | KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)
261 |
262 | # KL(q(w|z)||p(w|z))
263 | # we use the kl divergence given by [3] Eq.(8)
264 | KLD_element = - 0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
265 | KLD += torch.sum(KLD_element)
266 |
267 | # KL bias
268 | KLD_element = - 0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
269 | KLD += torch.sum(KLD_element)
270 |
271 | return KLD
272 |
273 | def __repr__(self):
274 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
275 | ', stride={stride}')
276 | if self.padding != (0,) * len(self.padding):
277 | s += ', padding={padding}'
278 | if self.dilation != (1,) * len(self.dilation):
279 | s += ', dilation={dilation}'
280 | if self.output_padding != (0,) * len(self.output_padding):
281 | s += ', output_padding={output_padding}'
282 | if self.groups != 1:
283 | s += ', groups={groups}'
284 |         # a bias is always learned in this implementation, so there is no
285 |         # 'bias=False' case to report (self.bias is never defined)
286 | s += ')'
287 | return s.format(name=self.__class__.__name__, **self.__dict__)
288 |
289 |
290 | class Conv1dGroupNJ(_ConvNdGroupNJ):
291 |     r"""1D convolutional Group Normal-Jeffreys layer (aka Group Variational Dropout).
292 |     """
293 |
294 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
295 | cuda=False, init_weight=None, init_bias=None, clip_var=None):
296 | kernel_size = utils._single(kernel_size)
297 | stride = utils._single(stride)
298 | padding = utils._single(padding)
299 | dilation = utils._single(dilation)
300 |
301 | super(Conv1dGroupNJ, self).__init__(
302 | in_channels, out_channels, kernel_size, stride, padding, dilation,
303 |             False, utils._single(0), groups, bias, init_weight, init_bias, cuda, clip_var)
304 |
305 | def forward(self, x):
306 | if self.deterministic:
307 | assert self.training == False, "Flag deterministic is True. This should not be used in training."
308 | return F.conv1d(x, self.post_weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
309 | batch_size = x.size()[0]
310 | # apply local reparametrisation trick see [1] Eq. (6)
311 | # to the parametrisation given in [3] Eq. (6)
312 | mu_activations = F.conv1d(x, self.weight_mu, self.bias_mu, self.stride,
313 | self.padding, self.dilation, self.groups)
314 |
315 | var_activations = F.conv1d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride,
316 | self.padding, self.dilation, self.groups)
317 | # compute z
318 | # note that we reparametrise according to [2] Eq. (11) (not [1])
319 |         z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
320 | sampling=self.training, cuda=self.cuda)
321 | z = z[:, :, None]
322 |
323 | return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
324 | cuda=self.cuda)
325 |
326 | def __repr__(self):
327 | return self.__class__.__name__ + ' (' \
328 |                + str(self.in_channels) + ' -> ' \
329 |                + str(self.out_channels) + ')'
330 |
331 |
332 | class Conv2dGroupNJ(_ConvNdGroupNJ):
333 |     r"""2D convolutional Group Normal-Jeffreys layer (aka Group Variational Dropout).
334 |     """
335 |
336 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
337 | cuda=False, init_weight=None, init_bias=None, clip_var=None):
338 | kernel_size = utils._pair(kernel_size)
339 | stride = utils._pair(stride)
340 | padding = utils._pair(padding)
341 | dilation = utils._pair(dilation)
342 |
343 | super(Conv2dGroupNJ, self).__init__(
344 | in_channels, out_channels, kernel_size, stride, padding, dilation,
345 | False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var)
346 |
347 | def forward(self, x):
348 | if self.deterministic:
349 | assert self.training == False, "Flag deterministic is True. This should not be used in training."
350 | return F.conv2d(x, self.post_weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
351 | batch_size = x.size()[0]
352 | # apply local reparametrisation trick see [1] Eq. (6)
353 | # to the parametrisation given in [3] Eq. (6)
354 | mu_activations = F.conv2d(x, self.weight_mu, self.bias_mu, self.stride,
355 | self.padding, self.dilation, self.groups)
356 |
357 | var_activations = F.conv2d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride,
358 | self.padding, self.dilation, self.groups)
359 | # compute z
360 | # note that we reparametrise according to [2] Eq. (11) (not [1])
361 | z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
362 | sampling=self.training, cuda=self.cuda)
363 | z = z[:, :, None, None]
364 |
365 | return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
366 | cuda=self.cuda)
367 |
368 | def __repr__(self):
369 | return self.__class__.__name__ + ' (' \
370 |                + str(self.in_channels) + ' -> ' \
371 |                + str(self.out_channels) + ')'
372 |
373 |
374 | class Conv3dGroupNJ(_ConvNdGroupNJ):
375 |     r"""3D convolutional Group Normal-Jeffreys layer (aka Group Variational Dropout).
376 |     """
377 |
378 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
379 | cuda=False, init_weight=None, init_bias=None, clip_var=None):
380 | kernel_size = utils._triple(kernel_size)
381 | stride = utils._triple(stride)
382 | padding = utils._triple(padding)
383 |         dilation = utils._triple(dilation)
384 |
385 | super(Conv3dGroupNJ, self).__init__(
386 | in_channels, out_channels, kernel_size, stride, padding, dilation,
387 |             False, utils._triple(0), groups, bias, init_weight, init_bias, cuda, clip_var)
388 |
389 | def forward(self, x):
390 | if self.deterministic:
391 | assert self.training == False, "Flag deterministic is True. This should not be used in training."
392 | return F.conv3d(x, self.post_weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
393 | batch_size = x.size()[0]
394 | # apply local reparametrisation trick see [1] Eq. (6)
395 | # to the parametrisation given in [3] Eq. (6)
396 | mu_activations = F.conv3d(x, self.weight_mu, self.bias_mu, self.stride,
397 | self.padding, self.dilation, self.groups)
398 |
399 | var_weights = self.weight_logvar.exp()
400 | var_activations = F.conv3d(x.pow(2), var_weights, self.bias_logvar.exp(), self.stride,
401 | self.padding, self.dilation, self.groups)
402 | # compute z
403 | # note that we reparametrise according to [2] Eq. (11) (not [1])
404 |         z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
405 | sampling=self.training, cuda=self.cuda)
406 | z = z[:, :, None, None, None]
407 |
408 | return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
409 | cuda=self.cuda)
410 |
411 | def __repr__(self):
412 | return self.__class__.__name__ + ' (' \
413 |                + str(self.in_channels) + ' -> ' \
414 |                + str(self.out_channels) + ')'
415 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Karen Ullrich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code release for "Bayesian Compression for Deep Learning"
2 |
3 |
4 | In "Bayesian Compression for Deep Learning" we adopt a Bayesian view for the compression of neural networks.
5 | By revisiting the connection between the minimum description length principle and variational inference we are
6 | able to achieve up to 700x compression and up to 50x speed up (CPU to sparse GPU) for neural networks.
7 |
8 | We visualize the learning process in the following figures for a dense network with 300 and 100 connections.
9 | White color represents redundancy whereas red and blue represent positive and negative weights respectively.
10 |
11 | |First layer weights |Second Layer weights|
12 | | :------ |:------: |
13 | |![First layer weights](figures/weight0_e.gif)|![Second layer weights](figures/weight1_e.gif)|
14 |
15 | For dense networks it is also simple to reconstruct input feature importance. We show this for a mask and 5 randomly chosen digits.
16 | ![Input feature importance](figures/pixel.gif)
17 |
18 |
19 | ## Results
20 |
21 |
22 | | Model | Method | Error [%] | Compression after pruning | Compression after precision reduction |
23 | | ------ | :------ |:------: | ------: |------: |
24 | |LeNet-5-Caffe |[DC](https://arxiv.org/abs/1510.00149) | 0.7 | 6* | -|
25 | | |[DNS](https://arxiv.org/abs/1608.04493) | 0.9 | 55* | -|
26 | | |[SWS](https://arxiv.org/abs/1702.04008) | 1.0 | 100* | -|
27 | | |[Sparse VD](https://arxiv.org/pdf/1701.05369.pdf) | 1.0 | 63* | 228|
28 | | |BC-GNJ | 1.0 | 108* | 361|
29 | | |BC-GHS | 1.0 | 156* | 419|
30 | | VGG |BC-GNJ | 8.6 | 14* | 56|
31 | | |BC-GHS | 9.0 | 18* | 59|
32 |
33 | ## Usage
34 | We provide an implementation in PyTorch for fully connected and convolutional layers for the group normal-Jeffreys prior (aka Group Variational Dropout) via:
35 | ```python
36 | import BayesianLayers
37 | ```
38 | The layers can then be straightforwardly included as follows:
39 | ```python
40 | class Net(nn.Module):
41 | def __init__(self):
42 | super(Net, self).__init__()
43 | # activation
44 | self.relu = nn.ReLU()
45 | # layers
46 | self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04)
47 | self.fc2 = BayesianLayers.LinearGroupNJ(300, 100)
48 | self.fc3 = BayesianLayers.LinearGroupNJ(100, 10)
49 | # layers including kl_divergence
50 | self.kl_list = [self.fc1, self.fc2, self.fc3]
51 |
52 | def forward(self, x):
53 | x = x.view(-1, 28 * 28)
54 | x = self.relu(self.fc1(x))
55 | x = self.relu(self.fc2(x))
56 | return self.fc3(x)
57 |
58 | def kl_divergence(self):
59 | KLD = 0
60 | for layer in self.kl_list:
61 | KLD += layer.kl_divergence()
62 | return KLD
63 | ```
64 | The only additional effort is to include the KL-divergence in the objective.
65 | This is necessary if we want to optimize the variational lower bound that leads to sparse solutions:
66 | ```python
67 | N = 60000.
68 | discrimination_loss = nn.functional.cross_entropy
69 |
70 | def objective(output, target, kl_divergence):
71 | discrimination_error = discrimination_loss(output, target)
72 | return discrimination_error + kl_divergence / N
73 | ```
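   | For completeness, here is a minimal training-loop sketch that ties the pieces together (it assumes the `Net`, `objective` and a standard MNIST `train_loader`; the `Variable` wrapper reflects the PyTorch 0.3 API this repository targets):
   | ```python
   | import torch.optim as optim
   | from torch.autograd import Variable
   | 
   | model = Net()
   | optimizer = optim.Adam(model.parameters())
   | 
   | model.train()
   | for data, target in train_loader:
   |     data, target = Variable(data), Variable(target)
   |     optimizer.zero_grad()
   |     output = model(data)
   |     # adding the KL term makes the loss the (scaled) variational lower bound
   |     loss = objective(output, target, model.kl_divergence())
   |     loss.backward()
   |     optimizer.step()
   |     # keep variances bounded for layers constructed with clip_var (cf. example.py)
   |     for layer in model.kl_list:
   |         layer.clip_variances()
   | ```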
74 | ## Run an example
75 | We provide a simple example, the LeNet-300-100 trained with the group normal-Jeffreys prior:
76 | ```sh
77 | python example.py
78 | ```
79 |
80 | ## Retraining a regular neural network
81 | Instead of training a network from scratch we often need to compress an already existing network.
82 | In this case we can simply initialize the weights with those of the pretrained network:
83 | ```python
84 | BayesianLayers.LinearGroupNJ(28*28, 300, init_weight=pretrained_weight, init_bias=pretrained_bias)
85 | ```
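   | For instance, with a hypothetical pretrained deterministic MLP `pretrained_net`, its first layer could be transferred like this (a sketch, not part of this repository):
   | ```python
   | # pretrained_net is a hypothetical standard nn.Module with a Linear layer fc1
   | pretrained_weight = pretrained_net.fc1.weight.data.cpu().numpy()
   | pretrained_bias = pretrained_net.fc1.bias.data.cpu().numpy()
   | fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300,
   |                                    init_weight=pretrained_weight,
   |                                    init_bias=pretrained_bias)
   | ```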
86 | ## *Reference*
87 | The paper "Bayesian Compression for Deep Learning" has been accepted to NIPS 2017. Please cite us:
88 |
89 | @article{louizos2017bayesian,
90 | title={Bayesian Compression for Deep Learning},
91 | author={Louizos, Christos and Ullrich, Karen and Welling, Max},
92 | journal={Conference on Neural Information Processing Systems (NIPS)},
93 | year={2017}
94 | }
--------------------------------------------------------------------------------
/compression.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Compression Tools
6 |
7 |
8 | Karen Ullrich, Oct 2017
9 |
10 | References:
11 |
12 | [1] Michael T. Heath. 1996. Scientific Computing: An Introductory Survey (2nd ed.). Eric M. Munson (Ed.). McGraw-Hill Higher Education. Chapter 1
13 | """
14 |
15 | import numpy as np
16 |
17 | # -------------------------------------------------------
18 | # General tools
19 | # -------------------------------------------------------
20 |
21 |
22 | def unit_round_off(t=23):
23 | """
24 | :param t:
25 |         number of significand bits
26 | :return:
27 | unit round off based on nearest interpolation, for reference see [1]
28 | """
29 | return 0.5 * 2. ** (1. - t)
30 |
31 |
32 | SIGNIFICANT_BIT_PRECISION = [unit_round_off(t=i + 1) for i in range(23)]
33 |
34 |
35 | def float_precision(x):
36 |
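   |     # counts how many significand-bit settings t = 1..23 still have a unit round-off
   |     # larger than x, i.e. the precision at which rounding error drops below x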
37 | out = np.sum([x < sbp for sbp in SIGNIFICANT_BIT_PRECISION])
38 | return out
39 |
40 |
41 | def float_precisions(X, dist_fun, layer=1):
42 |
43 | X = X.flatten()
44 | out = [float_precision(2 * x) for x in X]
45 | out = np.ceil(dist_fun(out))
46 | return out
47 |
48 |
49 | def special_round(input, significant_bit):
50 | delta = unit_round_off(t=significant_bit)
51 | rounded = np.floor(input / delta + 0.5)
52 | rounded = rounded * delta
53 | return rounded
54 |
55 |
56 | def fast_inference_weights(w, exponent_bit, significant_bit):
57 |
58 | return special_round(w, significant_bit)
59 |
60 |
61 | def compress_matrix(x):
62 |
63 | if len(x.shape) != 2:
64 | A, B, C, D = x.shape
65 | x = x.reshape(A * B, C * D)
66 |         # remove all-zero columns and rows (pruned filters)
67 | x = x[:, (x != 0).any(axis=0)]
68 | x = x[(x != 0).any(axis=1), :]
69 | else:
70 | # remove unnecessary rows, columns
71 | x = x[(x != 0).any(axis=1), :]
72 | x = x[:, (x != 0).any(axis=0)]
73 | return x
74 |
75 |
76 | def extract_pruned_params(layers, masks):
77 |
78 | post_weight_mus = []
79 | post_weight_vars = []
80 |
81 | for i, (layer, mask) in enumerate(zip(layers, masks)):
82 | # compute posteriors
83 | post_weight_mu, post_weight_var = layer.compute_posterior_params()
84 | post_weight_var = post_weight_var.cpu().data.numpy()
85 | post_weight_mu = post_weight_mu.cpu().data.numpy()
86 | # apply mask to mus and variances
87 | post_weight_mu = post_weight_mu * mask
88 | post_weight_var = post_weight_var * mask
89 |
90 | post_weight_mus.append(post_weight_mu)
91 | post_weight_vars.append(post_weight_var)
92 |
93 | return post_weight_mus, post_weight_vars
94 |
95 |
96 | # -------------------------------------------------------
97 | # Compression rates (fast inference scenario)
98 | # -------------------------------------------------------
99 |
100 |
101 | def _compute_compression_rate(vars, in_precision=32., dist_fun=lambda x: np.max(x), overflow=10e38):
102 | # compute in number of bits occupied by the original architecture
103 | sizes = [v.size for v in vars]
104 | nb_weights = float(np.sum(sizes))
105 | IN_BITS = in_precision * nb_weights
106 | # prune architecture
107 | vars = [compress_matrix(v) for v in vars]
108 | sizes = [v.size for v in vars]
109 | # compute
110 | significant_bits = [float_precisions(v, dist_fun, layer=k + 1) for k, v in enumerate(vars)]
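   |     # exponent bits: enough to encode exponents up to log2(overflow), plus one bit for the exponent's sign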
111 | exponent_bit = np.ceil(np.log2(np.log2(overflow) + 1.) + 1.)
112 | total_bits = [1. + exponent_bit + sb for sb in significant_bits]
113 | OUT_BITS = np.sum(np.asarray(sizes) * np.asarray(total_bits))
114 | return nb_weights / np.sum(sizes), IN_BITS / OUT_BITS, significant_bits, exponent_bit
115 |
116 |
117 | def compute_compression_rate(layers, masks):
118 | # reduce architecture
119 | weight_mus, weight_vars = extract_pruned_params(layers, masks)
120 | # compute overflow level based on maximum weight
121 | overflow = np.max([np.max(np.abs(w)) for w in weight_mus])
122 | # compute compression rate
123 | CR_architecture, CR_fast_inference, _, _ = _compute_compression_rate(weight_vars, dist_fun=lambda x: np.mean(x), overflow=overflow)
124 | print("Compressing the architecture will degrease the model by a factor of %.1f." % (CR_architecture))
125 | print("Making use of weight uncertainty can reduce the model by a factor of %.1f." % (CR_fast_inference))
126 |
127 |
128 | def compute_reduced_weights(layers, masks):
129 | weight_mus, weight_vars = extract_pruned_params(layers, masks)
130 | overflow = np.max([np.max(np.abs(w)) for w in weight_mus])
131 | _, _, significant_bits, exponent_bits = _compute_compression_rate(weight_vars, dist_fun=lambda x: np.mean(x), overflow=overflow)
132 |     weights = [fast_inference_weights(weight_mu, exponent_bits, significant_bit) for weight_mu, significant_bit in
133 | zip(weight_mus, significant_bits)]
134 | return weights
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: BCDL
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - ca-certificates=2018.03.07=0
7 | - certifi=2018.1.18=py27_0
8 | - cffi=1.11.5=py27h9745a5d_0
9 | - cudatoolkit=8.0=3
10 | - freetype=2.8=hab7d2ae_1
11 | - imageio=2.3.0=py27_0
12 | - intel-openmp=2018.0.0=8
13 | - jpeg=9b=h024ee3a_2
14 | - libedit=3.1=heed3624_0
15 | - libffi=3.2.1=hd88cf55_4
16 | - libgcc-ng=7.2.0=hdf63c60_3
17 | - libgfortran-ng=7.2.0=hdf63c60_3
18 | - libpng=1.6.34=hb9fc6fc_0
19 | - libstdcxx-ng=7.2.0=hdf63c60_3
20 | - libtiff=4.0.9=h28f6b97_0
21 | - mkl=2018.0.2=1
22 | - mkl_fft=1.0.1=py27h3010b51_0
23 | - mkl_random=1.0.1=py27h629b387_0
24 | - ncurses=6.0=h9df7e31_2
25 | - numpy=1.14.2=py27hdbf6ddf_1
26 | - olefile=0.45.1=py27_0
27 | - openssl=1.0.2o=h20670df_0
28 | - pillow=5.0.0=py27h3deb7b8_0
29 | - pip=9.0.3=py27_0
30 | - pycparser=2.18=py27hefa08c5_1
31 | - python=2.7.14=h1571d57_31
32 | - readline=7.0=ha6073c6_4
33 | - scipy=1.0.1=py27hfc37229_0
34 | - setuptools=39.0.1=py27_0
35 | - six=1.11.0=py27h5f960f1_1
36 | - sqlite=3.22.0=h1bed415_0
37 | - tk=8.6.7=hc745277_3
38 | - wheel=0.31.0=py27_0
39 | - xz=5.2.3=h55aa19d_2
40 | - zlib=1.2.11=ha838bed_2
41 | - pytorch=0.3.1=py27_cuda8.0.61_cudnn7.1.2_3
42 | - torchvision=0.2.0=py27hfb27419_1
43 | - pip:
44 | - backports-abc==0.4
45 | - backports.functools-lru-cache==1.5
46 | - backports.ssl-match-hostname==3.5.0.1
47 | - cycler==0.10.0
48 | - functools32==3.2.3.post2
49 | - ipython-genutils==0.1.0
50 | - ipywidgets==4.1.1
51 | - jsonschema==2.5.1
52 | - kiwisolver==1.0.1
53 | - matplotlib==2.2.2
54 | - nbformat==4.0.1
55 | - pandas==0.22.0
56 | - path.py==8.1.2
57 | - ptyprocess==0.5.1
58 | - pyparsing==2.2.0
59 | - python-dateutil==2.7.2
60 | - pytz==2018.4
61 | - seaborn==0.8.1
62 | - singledispatch==3.4.0.3
63 | - subprocess32==3.2.7
64 | - terminado==0.6
65 | - torch==0.3.1.post3
66 | - tornado==4.3
67 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Linear Bayesian Model
6 |
7 |
8 | Karen Ullrich, Christos Louizos, Oct 2017
9 | """
10 |
11 |
12 | # libraries
13 | from __future__ import print_function
14 | import numpy as np
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | from torchvision import datasets, transforms
21 | from torch.autograd import Variable
22 |
23 | import BayesianLayers
24 | from compression import compute_compression_rate, compute_reduced_weights
25 | from utils import visualize_pixel_importance, generate_gif, visualise_weights
26 |
27 | N = 60000. # number of data points in the training set
28 |
29 |
30 | def main():
31 | # import data
32 | kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
33 |
34 | train_loader = torch.utils.data.DataLoader(
35 | datasets.MNIST('./data', train=True, download=True,
36 | transform=transforms.Compose([
37 |                            transforms.ToTensor(), lambda x: 2 * (x - 0.5),
38 | ])),
39 | batch_size=FLAGS.batchsize, shuffle=True, **kwargs)
40 |
41 | test_loader = torch.utils.data.DataLoader(
42 | datasets.MNIST('./data', train=False, transform=transforms.Compose([
43 | transforms.ToTensor(), lambda x: 2 * (x - 0.5),
44 | ])),
45 | batch_size=FLAGS.batchsize, shuffle=True, **kwargs)
46 |
47 | # for later analysis we take some sample digits
48 | mask = 255. * (np.ones((1, 28, 28)))
49 | examples = train_loader.sampler.data_source.train_data[0:5].numpy()
50 | images = np.vstack([mask, examples])
51 |
52 | # build a simple MLP
53 | class Net(nn.Module):
54 | def __init__(self):
55 | super(Net, self).__init__()
56 | # activation
57 | self.relu = nn.ReLU()
58 | # layers
59 | self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04, cuda=FLAGS.cuda)
60 | self.fc2 = BayesianLayers.LinearGroupNJ(300, 100, cuda=FLAGS.cuda)
61 | self.fc3 = BayesianLayers.LinearGroupNJ(100, 10, cuda=FLAGS.cuda)
62 | # layers including kl_divergence
63 | self.kl_list = [self.fc1, self.fc2, self.fc3]
64 |
65 | def forward(self, x):
66 | x = x.view(-1, 28 * 28)
67 | x = self.relu(self.fc1(x))
68 | x = self.relu(self.fc2(x))
69 | return self.fc3(x)
70 |
71 |         def get_masks(self, thresholds):
72 | weight_masks = []
73 | mask = None
74 | for i, (layer, threshold) in enumerate(zip(self.kl_list, thresholds)):
75 | # compute dropout mask
76 | if mask is None:
77 | log_alpha = layer.get_log_dropout_rates().cpu().data.numpy()
78 | mask = log_alpha < threshold
79 | else:
80 | mask = np.copy(next_mask)
81 |                 try:
82 |                     log_alpha = self.kl_list[i + 1].get_log_dropout_rates().cpu().data.numpy()
83 |                     next_mask = log_alpha < thresholds[i + 1]
84 |                 except IndexError:
85 |                     # must be the last mask
86 |                     next_mask = np.ones(10)
87 |
88 | weight_mask = np.expand_dims(mask, axis=0) * np.expand_dims(next_mask, axis=1)
89 | weight_masks.append(weight_mask.astype(np.float))
90 | return weight_masks
91 |
92 | def kl_divergence(self):
93 | KLD = 0
94 | for layer in self.kl_list:
95 | KLD += layer.kl_divergence()
96 | return KLD
97 |
98 | # init model
99 | model = Net()
100 | if FLAGS.cuda:
101 | model.cuda()
102 |
103 | # init optimizer
104 | optimizer = optim.Adam(model.parameters())
105 |
106 | # we optimize the variational lower bound scaled by the number of data
107 | # points (so we can keep our intuitions about hyper-params such as the learning rate)
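   |     # i.e. loss = batch-averaged cross-entropy + KL / N, the negative variational lower bound per data point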
108 | discrimination_loss = nn.functional.cross_entropy
109 |
110 | def objective(output, target, kl_divergence):
111 | discrimination_error = discrimination_loss(output, target)
112 | variational_bound = discrimination_error + kl_divergence / N
113 | if FLAGS.cuda:
114 | variational_bound = variational_bound.cuda()
115 | return variational_bound
116 |
117 | def train(epoch):
118 | model.train()
119 | for batch_idx, (data, target) in enumerate(train_loader):
120 | if FLAGS.cuda:
121 | data, target = data.cuda(), target.cuda()
122 | data, target = Variable(data), Variable(target)
123 | optimizer.zero_grad()
124 | output = model(data)
125 | loss = objective(output, target, model.kl_divergence())
126 | loss.backward()
127 | optimizer.step()
128 | # clip the variances after each step
129 | for layer in model.kl_list:
130 | layer.clip_variances()
131 | print('Epoch: {} \tTrain loss: {:.6f} \t'.format(
132 | epoch, loss.data[0]))
133 |
134 | def test():
135 | model.eval()
136 | test_loss = 0
137 | correct = 0
138 | for data, target in test_loader:
139 | if FLAGS.cuda:
140 | data, target = data.cuda(), target.cuda()
141 | data, target = Variable(data, volatile=True), Variable(target)
142 | output = model(data)
143 | test_loss += discrimination_loss(output, target, size_average=False).data[0]
144 | pred = output.data.max(1, keepdim=True)[1]
145 | correct += pred.eq(target.data.view_as(pred)).cpu().sum()
146 | test_loss /= len(test_loader.dataset)
147 | print('Test loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
148 | test_loss, correct, len(test_loader.dataset),
149 | 100. * correct / len(test_loader.dataset)))
150 |
151 | # train the model and save some visualisations on the way
152 | for epoch in range(1, FLAGS.epochs + 1):
153 | train(epoch)
154 | test()
155 | # visualizations
156 | weight_mus = [model.fc1.weight_mu, model.fc2.weight_mu]
157 | log_alphas = [model.fc1.get_log_dropout_rates(), model.fc2.get_log_dropout_rates(),
158 | model.fc3.get_log_dropout_rates()]
159 | visualise_weights(weight_mus, log_alphas, epoch=epoch)
160 | log_alpha = model.fc1.get_log_dropout_rates().cpu().data.numpy()
161 | visualize_pixel_importance(images, log_alpha=log_alpha, epoch=str(epoch))
162 |
163 | generate_gif(save='pixel', epochs=FLAGS.epochs)
164 | generate_gif(save='weight0_e', epochs=FLAGS.epochs)
165 | generate_gif(save='weight1_e', epochs=FLAGS.epochs)
166 |
167 | # compute compression rate and new model accuracy
168 | layers = [model.fc1, model.fc2, model.fc3]
169 | thresholds = FLAGS.thresholds
170 | compute_compression_rate(layers, model.get_masks(thresholds))
171 |
172 | print("Test error after with reduced bit precision:")
173 |
174 | weights = compute_reduced_weights(layers, model.get_masks(thresholds))
175 | for layer, weight in zip(layers, weights):
176 | if FLAGS.cuda:
177 | layer.post_weight_mu.data = torch.Tensor(weight).cuda()
178 | else:
179 | layer.post_weight_mu.data = torch.Tensor(weight)
180 | for layer in layers: layer.deterministic = True
181 | test()
182 |
183 |
184 | if __name__ == '__main__':
185 | import argparse
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument('--epochs', type=int, default=5)
188 | parser.add_argument('--batchsize', type=int, default=128)
189 | parser.add_argument('--thresholds', type=float, nargs='*', default=[-2.8, -3., -5.])
190 |
191 | FLAGS = parser.parse_args()
192 | FLAGS.cuda = torch.cuda.is_available() # check if we can put the net on the GPU
193 | main()
194 |
--------------------------------------------------------------------------------
/figures/pixel.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarenUllrich/Tutorial_BayesianCompressionForDL/f1e7c7910a61d5ce86490089e82cbbfb01119052/figures/pixel.gif
--------------------------------------------------------------------------------
/figures/weight0_e.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarenUllrich/Tutorial_BayesianCompressionForDL/f1e7c7910a61d5ce86490089e82cbbfb01119052/figures/weight0_e.gif
--------------------------------------------------------------------------------
/figures/weight1_e.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarenUllrich/Tutorial_BayesianCompressionForDL/f1e7c7910a61d5ce86490089e82cbbfb01119052/figures/weight1_e.gif
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Utilities
6 |
7 |
8 | Karen Ullrich, Oct 2017
9 | """
10 |
11 | import os
12 | import numpy as np
13 | import imageio
14 |
15 | import matplotlib.pyplot as plt
16 | import seaborn as sns
17 |
18 | sns.set_style("whitegrid")
19 | cmap = sns.diverging_palette(240, 10, sep=100, as_cmap=True)
20 |
21 | # -------------------------------------------------------
22 | # VISUALISATION TOOLS
23 | # -------------------------------------------------------
24 |
25 |
26 | def visualize_pixel_importance(imgs, log_alpha, epoch="pixel_importance"):
27 | num_imgs = len(imgs)
28 |
29 | f, ax = plt.subplots(1, num_imgs)
30 | plt.title("Epoch:" + epoch)
31 | for i, img in enumerate(imgs):
32 | img = (img / 255.) - 0.5
33 | mask = log_alpha.reshape(img.shape)
34 | mask = 1 - np.clip(np.exp(mask), 0.0, 1)
35 | ax[i].imshow(img * mask, cmap=cmap, interpolation='none', vmin=-0.5, vmax=0.5)
36 | ax[i].grid("off")
37 | ax[i].set_yticks([])
38 | ax[i].set_xticks([])
39 | plt.savefig("./.pixel" + epoch + ".png", bbox_inches='tight')
40 | plt.close()
41 |
42 |
43 | def visualise_weights(weight_mus, log_alphas, epoch):
44 | num_layers = len(weight_mus)
45 |
46 | for i in range(num_layers):
47 | f, ax = plt.subplots(1, 1)
48 | weight_mu = np.transpose(weight_mus[i].cpu().data.numpy())
49 | # alpha
50 | log_alpha_fc1 = log_alphas[i].unsqueeze(1).cpu().data.numpy()
51 | log_alpha_fc1 = log_alpha_fc1 < -3
52 | log_alpha_fc2 = log_alphas[i + 1].unsqueeze(0).cpu().data.numpy()
53 | log_alpha_fc2 = log_alpha_fc2 < -3
54 | mask = log_alpha_fc1 + log_alpha_fc2
55 | # weight
56 | c = np.max(np.abs(weight_mu))
57 | s = ax.imshow(weight_mu * mask, cmap='seismic', interpolation='none', vmin=-c, vmax=c)
58 | ax.grid("off")
59 | ax.set_yticks([])
60 | ax.set_xticks([])
61 | s.set_clim([-c * 0.5, c * 0.5])
62 | f.colorbar(s)
63 | plt.title("Epoch:" + str(epoch))
64 | plt.savefig("./.weight" + str(i) + '_e' + str(epoch) + ".png", bbox_inches='tight')
65 | plt.close()
66 |
67 |
68 | def generate_gif(save='tmp', epochs=10):
69 | images = []
70 | filenames = ["./." + save + "%d.png" % (epoch + 1) for epoch in np.arange(epochs)]
71 | for filename in filenames:
72 | images.append(imageio.imread(filename))
73 | os.remove(filename)
74 | imageio.mimsave('./figures/' + save + '.gif', images, duration=.5)
75 |
--------------------------------------------------------------------------------