├── BayesianLayers.py
├── LICENSE
├── README.md
├── compression.py
├── environment.yml
├── example.py
├── figures
│   ├── pixel.gif
│   ├── weight0_e.gif
│   └── weight1_e.gif
└── utils.py

/BayesianLayers.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Variational Dropout version of linear and convolutional layers


Karen Ullrich, Christos Louizos, Oct 2017
"""

import math

import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch import nn
from torch.nn.modules import Module
from torch.autograd import Variable
from torch.nn.modules import utils


def reparametrize(mu, logvar, cuda=False, sampling=True):
    if sampling:
        # sample via the reparametrisation trick: mu + std * eps, eps ~ N(0, I)
        std = logvar.mul(0.5).exp_()
        if cuda:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return mu + eps * std
    else:
        return mu


# -------------------------------------------------------
# LINEAR LAYER
# -------------------------------------------------------

class LinearGroupNJ(Module):
    """Fully Connected Group Normal-Jeffreys layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        # posterior moments of the effective weight z * w, i.e. the moments of
        # a product of independent Gaussians
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
                          cuda=self.cuda)

        # apply local reparametrisation trick, see [1] Eq. (6),
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the KL divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the KL divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'


# -------------------------------------------------------
# CONVOLUTIONAL LAYER
# -------------------------------------------------------

class _ConvNdGroupNJ(Module):
    """Convolutional Group Normal-Jeffreys layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # compute fan-in for the init scale
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # init means
        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init z
        self.z_mu.data.normal_(1, 1e-2)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        # posterior moments of the effective weight z * w, i.e. the moments of
        # a product of independent Gaussians
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the KL divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the KL divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        # bias parameters always exist in this implementation, so there is
        # no 'bias=False' case to report
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)


class Conv1dGroupNJ(_ConvNdGroupNJ):
    r"""1D convolutional layer with a group normal-Jeffreys prior.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 cuda=False, init_weight=None, init_bias=None, clip_var=None):
        kernel_size = utils._single(kernel_size)
        stride = utils._single(stride)
        padding = utils._single(padding)
        dilation = utils._single(dilation)

        super(Conv1dGroupNJ, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, utils._single(0), groups, bias, init_weight, init_bias, cuda, clip_var)

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.conv1d(x, self.post_weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
        batch_size = x.size()[0]
        # apply local reparametrisation trick, see [1] Eq. (6),
        # to the parametrisation given in [3] Eq. (6)
        mu_activations = F.conv1d(x, self.weight_mu, self.bias_mu, self.stride,
                                  self.padding, self.dilation, self.groups)

        var_activations = F.conv1d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride,
                                   self.padding, self.dilation, self.groups)
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
                          sampling=self.training, cuda=self.cuda)
        z = z[:, :, None]

        return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
                             cuda=self.cuda)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_channels) + ' -> ' \
               + str(self.out_channels) + ')'


class Conv2dGroupNJ(_ConvNdGroupNJ):
    r"""2D convolutional layer with a group normal-Jeffreys prior.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 cuda=False, init_weight=None, init_bias=None, clip_var=None):
        kernel_size = utils._pair(kernel_size)
        stride = utils._pair(stride)
        padding = utils._pair(padding)
        dilation = utils._pair(dilation)

        super(Conv2dGroupNJ, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var)

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.conv2d(x, self.post_weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
        batch_size = x.size()[0]
        # apply local reparametrisation trick, see [1] Eq. (6),
        # to the parametrisation given in [3] Eq. (6)
        mu_activations = F.conv2d(x, self.weight_mu, self.bias_mu, self.stride,
                                  self.padding, self.dilation, self.groups)

        var_activations = F.conv2d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride,
                                   self.padding, self.dilation, self.groups)
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
                          sampling=self.training, cuda=self.cuda)
        z = z[:, :, None, None]

        return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
                             cuda=self.cuda)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_channels) + ' -> ' \
               + str(self.out_channels) + ')'


class Conv3dGroupNJ(_ConvNdGroupNJ):
    r"""3D convolutional layer with a group normal-Jeffreys prior.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 cuda=False, init_weight=None, init_bias=None, clip_var=None):
        kernel_size = utils._triple(kernel_size)
        stride = utils._triple(stride)
        padding = utils._triple(padding)
        dilation = utils._triple(dilation)

        super(Conv3dGroupNJ, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, utils._triple(0), groups, bias, init_weight, init_bias, cuda, clip_var)

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.conv3d(x, self.post_weight_mu, self.bias_mu, self.stride, self.padding, self.dilation, self.groups)
        batch_size = x.size()[0]
        # apply local reparametrisation trick, see [1] Eq. (6),
        # to the parametrisation given in [3] Eq. (6)
        mu_activations = F.conv3d(x, self.weight_mu, self.bias_mu, self.stride,
                                  self.padding, self.dilation, self.groups)

        var_weights = self.weight_logvar.exp()
        var_activations = F.conv3d(x.pow(2), var_weights, self.bias_logvar.exp(), self.stride,
                                   self.padding, self.dilation, self.groups)
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1),
                          sampling=self.training, cuda=self.cuda)
        z = z[:, :, None, None, None]

        return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
                             cuda=self.cuda)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_channels) + ' -> ' \
               + str(self.out_channels) + ')'
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Karen Ullrich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Code release for "Bayesian Compression for Deep Learning"


In "Bayesian Compression for Deep Learning" we adopt a Bayesian view for the compression of neural networks.
By revisiting the connection between the minimum description length principle and variational inference we are
able to achieve up to 700x compression and up to 50x speed-up (CPU to sparse GPU) for neural networks.

We visualize the learning process in the following figures for a dense network with hidden layers of 300 and 100 units.
White color represents redundancy whereas red and blue represent positive and negative weights respectively.

|First layer weights |Second layer weights|
| :------ |:------: |
|![alt text](./figures/weight0_e.gif "First layer weights")|![alt text](./figures/weight1_e.gif "Second layer weights")|

For dense networks it is also simple to reconstruct input feature importance. We show this for a mask and 5 randomly chosen digits.
![alt text](./figures/pixel.gif "Pixel importance")


## Results


| Model | Method | Error [%] | Compression after pruning | Compression after precision reduction |
| ------ | :------ |:------: | ------: |------: |
|LeNet-5-Caffe |[DC](https://arxiv.org/abs/1510.00149) | 0.7 | 6* | -|
| |[DNS](https://arxiv.org/abs/1608.04493) | 0.9 | 55* | -|
| |[SWS](https://arxiv.org/abs/1702.04008) | 1.0 | 100* | -|
| |[Sparse VD](https://arxiv.org/pdf/1701.05369.pdf) | 1.0 | 63* | 228|
| |BC-GNJ | 1.0 | 108* | 361|
| |BC-GHS | 1.0 | 156* | 419|
| VGG |BC-GNJ | 8.6 | 14* | 56|
| |BC-GHS | 9.0 | 18* | 59|
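
## Setup
The repository ships a conda environment file pinning Python 2.7 and PyTorch 0.3.1. A minimal setup, assuming a working conda installation, is:
```sh
conda env create -f environment.yml
source activate BCDL
```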

## Usage
We provide an implementation in PyTorch for fully connected and convolutional layers for the group normal-Jeffreys prior (aka Group Variational Dropout) via:
```python
import BayesianLayers
```
The layers can then be straightforwardly included as follows:
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # activation
        self.relu = nn.ReLU()
        # layers
        self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04)
        self.fc2 = BayesianLayers.LinearGroupNJ(300, 100)
        self.fc3 = BayesianLayers.LinearGroupNJ(100, 10)
        # layers including kl_divergence
        self.kl_list = [self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

    def kl_divergence(self):
        KLD = 0
        for layer in self.kl_list:
            KLD += layer.kl_divergence()
        return KLD
```
The only additional effort is to include the KL-divergence in the objective.
This is necessary if we want to optimize the variational lower bound that leads to sparse solutions:
```python
N = 60000.
discrimination_loss = nn.functional.cross_entropy

def objective(output, target, kl_divergence):
    discrimination_error = discrimination_loss(output, target)
    return discrimination_error + kl_divergence / N
```
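Putting these pieces together, one training epoch looks like the sketch below, a condensed version of the loop in `example.py` (it assumes `model` and `train_loader` are set up as above):
```python
optimizer = optim.Adam(model.parameters())

model.train()
for data, target in train_loader:
    data, target = Variable(data), Variable(target)
    optimizer.zero_grad()
    output = model(data)
    # variational objective: discrimination error plus KL term scaled by 1/N
    loss = objective(output, target, model.kl_divergence())
    loss.backward()
    optimizer.step()
    # keep the variances bounded for layers constructed with clip_var
    for layer in model.kl_list:
        layer.clip_variances()
```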

## Run an example
We provide a simple example, the LeNet-300-100 trained with the group normal-Jeffreys prior:
```sh
python example.py
```

## Retraining a regular neural network
Instead of training a network from scratch we often need to compress an already existing network.
In this case we can simply initialize the weights with those of the pretrained network:
```python
BayesianLayers.LinearGroupNJ(28*28, 300, init_weight=pretrained_weight, init_bias=pretrained_bias)
```
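
## Compressed inference
After training, the weight and dropout posteriors can be collapsed into a single posterior per effective weight, pruned, quantized, and used deterministically. The following is a minimal sketch, assuming a trained `model` and the pruning `thresholds` as in `example.py`:
```python
from compression import compute_compression_rate, compute_reduced_weights

layers = [model.fc1, model.fc2, model.fc3]
masks = model.get_masks(thresholds)

# report the compression rates from pruning and from reduced bit precision
compute_compression_rate(layers, masks)

# quantize the posterior weight means and switch to deterministic inference
weights = compute_reduced_weights(layers, masks)
for layer, weight in zip(layers, weights):
    layer.post_weight_mu.data = torch.Tensor(weight)
    layer.deterministic = True
model.eval()
```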

## *Reference*
The paper "Bayesian Compression for Deep Learning" has been accepted to NIPS 2017. Please cite us:

    @article{louizos2017bayesian,
      title={Bayesian Compression for Deep Learning},
      author={Louizos, Christos and Ullrich, Karen and Welling, Max},
      journal={Conference on Neural Information Processing Systems (NIPS)},
      year={2017}
    }
--------------------------------------------------------------------------------
/compression.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Compression Tools


Karen Ullrich, Oct 2017

References:

[1] Michael T. Heath. 1996. Scientific Computing: An Introductory Survey (2nd ed.). Eric M. Munson (Ed.). McGraw-Hill Higher Education. Chapter 1
"""

import numpy as np

# -------------------------------------------------------
# General tools
# -------------------------------------------------------


def unit_round_off(t=23):
    """
    :param t:
        number of significand bits
    :return:
        unit round off based on nearest interpolation, for reference see [1]
    """
    return 0.5 * 2. ** (1. - t)


SIGNIFICANT_BIT_PRECISION = [unit_round_off(t=i + 1) for i in range(23)]


def float_precision(x):

    out = np.sum([x < sbp for sbp in SIGNIFICANT_BIT_PRECISION])
    return out


def float_precisions(X, dist_fun, layer=1):

    X = X.flatten()
    out = [float_precision(2 * x) for x in X]
    out = np.ceil(dist_fun(out))
    return out


def special_round(input, significant_bit):
    delta = unit_round_off(t=significant_bit)
    rounded = np.floor(input / delta + 0.5)
    rounded = rounded * delta
    return rounded


def fast_inference_weights(w, exponent_bit, significant_bit):

    return special_round(w, significant_bit)


def compress_matrix(x):

    if len(x.shape) != 2:
        A, B, C, D = x.shape
        x = x.reshape(A * B, C * D)
        # remove unnecessary filters and rows
        x = x[:, (x != 0).any(axis=0)]
        x = x[(x != 0).any(axis=1), :]
    else:
        # remove unnecessary rows, columns
        x = x[(x != 0).any(axis=1), :]
        x = x[:, (x != 0).any(axis=0)]
    return x


def extract_pruned_params(layers, masks):

    post_weight_mus = []
    post_weight_vars = []

    for i, (layer, mask) in enumerate(zip(layers, masks)):
        # compute posteriors
        post_weight_mu, post_weight_var = layer.compute_posterior_params()
        post_weight_var = post_weight_var.cpu().data.numpy()
        post_weight_mu = post_weight_mu.cpu().data.numpy()
        # apply mask to mus and variances
        post_weight_mu = post_weight_mu * mask
        post_weight_var = post_weight_var * mask

        post_weight_mus.append(post_weight_mu)
        post_weight_vars.append(post_weight_var)

    return post_weight_mus, post_weight_vars


# -------------------------------------------------------
# Compression rates (fast inference scenario)
# -------------------------------------------------------


def _compute_compression_rate(vars, in_precision=32., dist_fun=lambda x: np.max(x), overflow=10e38):
    # compute the number of bits occupied by the original architecture
    sizes = [v.size for v in vars]
    nb_weights = float(np.sum(sizes))
    IN_BITS = in_precision * nb_weights
    # prune architecture
    vars = [compress_matrix(v) for v in vars]
    sizes = [v.size for v in vars]
    # compute the bit allocation of the pruned architecture
    significant_bits = [float_precisions(v, dist_fun, layer=k + 1) for k, v in enumerate(vars)]
    exponent_bit = np.ceil(np.log2(np.log2(overflow) + 1.) + 1.)
    total_bits = [1. + exponent_bit + sb for sb in significant_bits]
    OUT_BITS = np.sum(np.asarray(sizes) * np.asarray(total_bits))
    return nb_weights / np.sum(sizes), IN_BITS / OUT_BITS, significant_bits, exponent_bit


def compute_compression_rate(layers, masks):
    # reduce architecture
    weight_mus, weight_vars = extract_pruned_params(layers, masks)
    # compute overflow level based on maximum weight
    overflow = np.max([np.max(np.abs(w)) for w in weight_mus])
    # compute compression rate
    CR_architecture, CR_fast_inference, _, _ = _compute_compression_rate(weight_vars, dist_fun=lambda x: np.mean(x), overflow=overflow)
    print("Compressing the architecture will decrease the model size by a factor of %.1f." % (CR_architecture))
    print("Making use of weight uncertainty can reduce the model size by a factor of %.1f." % (CR_fast_inference))


def compute_reduced_weights(layers, masks):
    weight_mus, weight_vars = extract_pruned_params(layers, masks)
    overflow = np.max([np.max(np.abs(w)) for w in weight_mus])
    _, _, significant_bits, exponent_bits = _compute_compression_rate(weight_vars, dist_fun=lambda x: np.mean(x), overflow=overflow)

    weights = [fast_inference_weights(weight_mu, exponent_bits, significant_bit) for weight_mu, significant_bit in
               zip(weight_mus, significant_bits)]
    return weights
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: BCDL
channels:
  - pytorch
  - defaults
dependencies:
  - ca-certificates=2018.03.07=0
  - certifi=2018.1.18=py27_0
  - cffi=1.11.5=py27h9745a5d_0
  - cudatoolkit=8.0=3
  - freetype=2.8=hab7d2ae_1
  - imageio=2.3.0=py27_0
  - intel-openmp=2018.0.0=8
  - jpeg=9b=h024ee3a_2
  - libedit=3.1=heed3624_0
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=7.2.0=hdf63c60_3
  - libgfortran-ng=7.2.0=hdf63c60_3
  - libpng=1.6.34=hb9fc6fc_0
  - libstdcxx-ng=7.2.0=hdf63c60_3
  - libtiff=4.0.9=h28f6b97_0
  - mkl=2018.0.2=1
  - mkl_fft=1.0.1=py27h3010b51_0
  - mkl_random=1.0.1=py27h629b387_0
  - ncurses=6.0=h9df7e31_2
  - numpy=1.14.2=py27hdbf6ddf_1
  - olefile=0.45.1=py27_0
  - openssl=1.0.2o=h20670df_0
  - pillow=5.0.0=py27h3deb7b8_0
  - pip=9.0.3=py27_0
  - pycparser=2.18=py27hefa08c5_1
  - python=2.7.14=h1571d57_31
  - readline=7.0=ha6073c6_4
  - scipy=1.0.1=py27hfc37229_0
  - setuptools=39.0.1=py27_0
  - six=1.11.0=py27h5f960f1_1
  - sqlite=3.22.0=h1bed415_0
  - tk=8.6.7=hc745277_3
  - wheel=0.31.0=py27_0
  - xz=5.2.3=h55aa19d_2
  - zlib=1.2.11=ha838bed_2
  - pytorch=0.3.1=py27_cuda8.0.61_cudnn7.1.2_3
  - torchvision=0.2.0=py27hfb27419_1
  - pip:
    - backports-abc==0.4
    - backports.functools-lru-cache==1.5
    - backports.ssl-match-hostname==3.5.0.1
    - cycler==0.10.0
    - functools32==3.2.3.post2
    - ipython-genutils==0.1.0
    - ipywidgets==4.1.1
    - jsonschema==2.5.1
    - kiwisolver==1.0.1
    - matplotlib==2.2.2
    - nbformat==4.0.1
    - pandas==0.22.0
    - path.py==8.1.2
    - ptyprocess==0.5.1
    - pyparsing==2.2.0
    - python-dateutil==2.7.2
    - pytz==2018.4
    - seaborn==0.8.1
    - singledispatch==3.4.0.3
    - subprocess32==3.2.7
    - terminado==0.6
    - torch==0.3.1.post3
    - tornado==4.3
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Linear Bayesian Model


Karen Ullrich, Christos Louizos, Oct 2017
"""


# libraries
from __future__ import print_function
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

import BayesianLayers
from compression import compute_compression_rate, compute_reduced_weights
from utils import visualize_pixel_importance, generate_gif, visualise_weights

N = 60000.  # number of data points in the training set


def main():
    # import data
    kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(), lambda x: 2 * (x - 0.5),
                       ])),
        batch_size=FLAGS.batchsize, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(), lambda x: 2 * (x - 0.5),
        ])),
        batch_size=FLAGS.batchsize, shuffle=True, **kwargs)

    # for later analysis we take some sample digits
    mask = 255. * (np.ones((1, 28, 28)))
    examples = train_loader.sampler.data_source.train_data[0:5].numpy()
    images = np.vstack([mask, examples])

    # build a simple MLP
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # activation
            self.relu = nn.ReLU()
            # layers
            self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04, cuda=FLAGS.cuda)
            self.fc2 = BayesianLayers.LinearGroupNJ(300, 100, cuda=FLAGS.cuda)
            self.fc3 = BayesianLayers.LinearGroupNJ(100, 10, cuda=FLAGS.cuda)
            # layers including kl_divergence
            self.kl_list = [self.fc1, self.fc2, self.fc3]

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            return self.fc3(x)

        def get_masks(self, thresholds):
            weight_masks = []
            mask = None
            for i, (layer, threshold) in enumerate(zip(self.kl_list, thresholds)):
                # compute dropout mask
                if mask is None:
                    log_alpha = layer.get_log_dropout_rates().cpu().data.numpy()
                    mask = log_alpha < threshold
                else:
                    mask = np.copy(next_mask)
                try:
                    log_alpha = self.kl_list[i + 1].get_log_dropout_rates().cpu().data.numpy()
                    next_mask = log_alpha < thresholds[i + 1]
                except IndexError:
                    # must be the last mask
                    next_mask = np.ones(10)

                weight_mask = np.expand_dims(mask, axis=0) * np.expand_dims(next_mask, axis=1)
                weight_masks.append(weight_mask.astype(np.float))
            return weight_masks

        def kl_divergence(self):
            KLD = 0
            for layer in self.kl_list:
                KLD += layer.kl_divergence()
            return KLD

    # init model
    model = Net()
    if FLAGS.cuda:
        model.cuda()

    # init optimizer
    optimizer = optim.Adam(model.parameters())

    # we optimize the variational lower bound scaled by the number of data
    # points (so we can keep our intuitions about hyper-params such as the learning rate)
    discrimination_loss = nn.functional.cross_entropy

    def objective(output, target, kl_divergence):
        discrimination_error = discrimination_loss(output, target)
        variational_bound = discrimination_error + kl_divergence / N
        if FLAGS.cuda:
            variational_bound = variational_bound.cuda()
        return variational_bound

    def train(epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            if FLAGS.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = objective(output, target, model.kl_divergence())
            loss.backward()
            optimizer.step()
            # clip the variances after each step
            for layer in model.kl_list:
                layer.clip_variances()
        print('Epoch: {} \tTrain loss: {:.6f} \t'.format(
            epoch, loss.data[0]))

    def test():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            if FLAGS.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = model(data)
            test_loss += discrimination_loss(output, target, size_average=False).data[0]
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        test_loss /= len(test_loader.dataset)
        print('Test loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

    # train the model and save some visualisations on the way
    for epoch in range(1, FLAGS.epochs + 1):
        train(epoch)
        test()
        # visualizations
        weight_mus = [model.fc1.weight_mu, model.fc2.weight_mu]
        log_alphas = [model.fc1.get_log_dropout_rates(), model.fc2.get_log_dropout_rates(),
                      model.fc3.get_log_dropout_rates()]
        visualise_weights(weight_mus, log_alphas, epoch=epoch)
        log_alpha = model.fc1.get_log_dropout_rates().cpu().data.numpy()
        visualize_pixel_importance(images, log_alpha=log_alpha, epoch=str(epoch))

    generate_gif(save='pixel', epochs=FLAGS.epochs)
    generate_gif(save='weight0_e', epochs=FLAGS.epochs)
    generate_gif(save='weight1_e', epochs=FLAGS.epochs)

    # compute compression rate and new model accuracy
    layers = [model.fc1, model.fc2, model.fc3]
    thresholds = FLAGS.thresholds
    compute_compression_rate(layers, model.get_masks(thresholds))

    print("Test error with reduced bit precision:")

    weights = compute_reduced_weights(layers, model.get_masks(thresholds))
    for layer, weight in zip(layers, weights):
        if FLAGS.cuda:
            layer.post_weight_mu.data = torch.Tensor(weight).cuda()
        else:
            layer.post_weight_mu.data = torch.Tensor(weight)
    for layer in layers:
        layer.deterministic = True
    test()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batchsize', type=int, default=128)
    parser.add_argument('--thresholds', type=float, nargs='*', default=[-2.8, -3., -5.])

    FLAGS = parser.parse_args()
    FLAGS.cuda = torch.cuda.is_available()  # check if we can put the net on the GPU
    main()
--------------------------------------------------------------------------------
/figures/pixel.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarenUllrich/Tutorial_BayesianCompressionForDL/f1e7c7910a61d5ce86490089e82cbbfb01119052/figures/pixel.gif
--------------------------------------------------------------------------------
/figures/weight0_e.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarenUllrich/Tutorial_BayesianCompressionForDL/f1e7c7910a61d5ce86490089e82cbbfb01119052/figures/weight0_e.gif
--------------------------------------------------------------------------------
/figures/weight1_e.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarenUllrich/Tutorial_BayesianCompressionForDL/f1e7c7910a61d5ce86490089e82cbbfb01119052/figures/weight1_e.gif
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Utilities


Karen Ullrich, Oct 2017
"""

import os
import numpy as np
import imageio

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")
cmap = sns.diverging_palette(240, 10, sep=100, as_cmap=True)

# -------------------------------------------------------
# VISUALISATION TOOLS
# -------------------------------------------------------


def visualize_pixel_importance(imgs, log_alpha, epoch="pixel_importance"):
    num_imgs = len(imgs)

    f, ax = plt.subplots(1, num_imgs)
    plt.title("Epoch:" + epoch)
    for i, img in enumerate(imgs):
        img = (img / 255.) - 0.5
        mask = log_alpha.reshape(img.shape)
        mask = 1 - np.clip(np.exp(mask), 0.0, 1)
        ax[i].imshow(img * mask, cmap=cmap, interpolation='none', vmin=-0.5, vmax=0.5)
        ax[i].grid(False)
        ax[i].set_yticks([])
        ax[i].set_xticks([])
    plt.savefig("./.pixel" + epoch + ".png", bbox_inches='tight')
    plt.close()


def visualise_weights(weight_mus, log_alphas, epoch):
    num_layers = len(weight_mus)

    for i in range(num_layers):
        f, ax = plt.subplots(1, 1)
        weight_mu = np.transpose(weight_mus[i].cpu().data.numpy())
        # alpha
        log_alpha_fc1 = log_alphas[i].unsqueeze(1).cpu().data.numpy()
        log_alpha_fc1 = log_alpha_fc1 < -3
        log_alpha_fc2 = log_alphas[i + 1].unsqueeze(0).cpu().data.numpy()
        log_alpha_fc2 = log_alpha_fc2 < -3
        # a weight survives only if both its input and its output unit are kept
        mask = log_alpha_fc1 * log_alpha_fc2
        # weight
        c = np.max(np.abs(weight_mu))
        s = ax.imshow(weight_mu * mask, cmap='seismic', interpolation='none', vmin=-c, vmax=c)
        ax.grid(False)
        ax.set_yticks([])
        ax.set_xticks([])
        s.set_clim([-c * 0.5, c * 0.5])
        f.colorbar(s)
        plt.title("Epoch:" + str(epoch))
        plt.savefig("./.weight" + str(i) + '_e' + str(epoch) + ".png", bbox_inches='tight')
        plt.close()


def generate_gif(save='tmp', epochs=10):
    images = []
    filenames = ["./." + save + "%d.png" % (epoch + 1) for epoch in np.arange(epochs)]
    for filename in filenames:
        images.append(imageio.imread(filename))
        os.remove(filename)
    imageio.mimsave('./figures/' + save + '.gif', images, duration=.5)
--------------------------------------------------------------------------------