├── .gitignore ├── Bayesian ├── BayesianLayers.py └── BayesianModule.py ├── README.md ├── Tools ├── Objective.py └── visualization.py ├── bayes_opt ├── __init__.py ├── bayesian_optimization.py ├── helpers.py └── target_space.py ├── compression.py ├── example.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | figures/ 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | .static_storage/ 59 | .media/ 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ -------------------------------------------------------------------------------- /Bayesian/BayesianLayers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Variational Dropout version of linear and convolutional layers 6 | Karen Ullrich, Christos Louizos, Oct 2017 7 | """ 8 | 9 | import math 10 | 11 | import torch 12 | from torch.nn.parameter import Parameter 13 | import torch.nn.functional as F 14 | from torch import nn 15 | from torch.nn.modules import Module 16 | from torch.autograd import Variable 17 | from torch.nn.modules import utils 18 | 19 | 20 | def reparametrize(mu, logvar, cuda=False, sampling=True): 21 | if sampling: 22 | std = logvar.mul(0.5).exp_() 23 | if cuda: 24 | eps = torch.cuda.FloatTensor(std.size()).normal_() 25 | else: 26 | eps = torch.FloatTensor(std.size()).normal_() 27 | eps = Variable(eps) 28 | return mu + eps * std 29 | else: 30 | return mu 31 | 32 | 33 | class BayesianLayers(nn.Module): 34 | """ 35 | References: 36 | [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). 37 | [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). 38 | [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). 
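    A minimal usage sketch, assuming the LinearGroupNJ subclass defined below
    and CPU tensors (shapes are illustrative only):

    >>> layer = LinearGroupNJ(784, 300, cuda=False)
    >>> x = Variable(torch.randn(32, 784))
    >>> out = layer(x)               # stochastic forward pass (training mode)
    >>> kl = layer.kl_divergence()   # scalar KL term to add to the task loss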
39 | """ 40 | 41 | def __init__(self): 42 | super(BayesianLayers, self).__init__() 43 | self.sigmoid = nn.Sigmoid() 44 | self.softplus = nn.Softplus() 45 | self.epsilon = 1e-8 46 | 47 | def compute_posterior_params(self): 48 | weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() 49 | self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var 50 | self.post_weight_mu = self.weight_mu * self.z_mu 51 | return self.post_weight_mu, self.post_weight_var 52 | 53 | def get_log_dropout_rates(self): 54 | log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) 55 | return log_alpha 56 | 57 | def kl_divergence(self): 58 | # KL(q(z)||p(z)) 59 | # we use the kl divergence approximation given by [2] Eq.(14) 60 | k1, k2, k3 = 0.63576, 1.87320, 1.48695 61 | log_alpha = self.get_log_dropout_rates() 62 | KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) 63 | 64 | # KL(q(w|z)||p(w|z)) 65 | # we use the kl divergence given by [3] Eq.(8) 66 | KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5 67 | KLD += torch.sum(KLD_element) 68 | 69 | # KL bias 70 | KLD_element = -self.bias_logvar + 0.5 * (self.bias_logvar.exp().pow(2) + self.bias_mu.pow(2)) - 0.5 71 | KLD += torch.sum(KLD_element) 72 | 73 | return KLD 74 | 75 | def __repr__(self): 76 | return self.__class__.__name__ 77 | 78 | 79 | # ------------------------------------------------------- 80 | # LINEAR LAYER 81 | # ------------------------------------------------------- 82 | 83 | class LinearGroupNJ(BayesianLayers): 84 | """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout). 85 | 86 | References: 87 | [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). 88 | [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). 89 | [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). 90 | """ 91 | 92 | def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None): 93 | 94 | super(LinearGroupNJ, self).__init__() 95 | self.cuda = cuda 96 | self.in_features = in_features 97 | self.out_features = out_features 98 | self.clip_var = clip_var 99 | self.deterministic = False # flag is used for compressed inference 100 | # trainable params according to Eq.(6) 101 | # dropout params 102 | self.z_mu = Parameter(torch.Tensor(in_features)) 103 | self.z_logvar = Parameter(torch.Tensor(in_features)) # = z_mu^2 * alpha 104 | # weight params 105 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) 106 | self.weight_logvar = Parameter(torch.Tensor(out_features, in_features)) 107 | 108 | self.bias_mu = Parameter(torch.Tensor(out_features)) 109 | self.bias_logvar = Parameter(torch.Tensor(out_features)) 110 | 111 | # init params either random or with pretrained net 112 | self.reset_parameters(init_weight, init_bias) 113 | 114 | # activations for kl 115 | self.sigmoid = nn.Sigmoid() 116 | self.softplus = nn.Softplus() 117 | 118 | # numerical stability param 119 | self.epsilon = 1e-8 120 | 121 | def reset_parameters(self, init_weight, init_bias): 122 | # init means 123 | stdv = 1. 
/ math.sqrt(self.weight_mu.size(1)) 124 | 125 | self.z_mu.data.normal_(1, 1e-2) 126 | 127 | if init_weight is not None: 128 | self.weight_mu.data = torch.Tensor(init_weight) 129 | else: 130 | self.weight_mu.data.normal_(0, stdv) 131 | 132 | if init_bias is not None: 133 | self.bias_mu.data = torch.Tensor(init_bias) 134 | else: 135 | self.bias_mu.data.fill_(0) 136 | 137 | # init logvars 138 | self.z_logvar.data.normal_(-9, 1e-2) 139 | self.weight_logvar.data.normal_(-9, 1e-2) 140 | self.bias_logvar.data.normal_(-9, 1e-2) 141 | 142 | def clip_variances(self): 143 | if self.clip_var: 144 | self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) 145 | self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) 146 | 147 | def get_log_dropout_rates(self): 148 | log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) 149 | return log_alpha 150 | 151 | def compute_posterior_params(self): 152 | weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() 153 | self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var 154 | self.post_weight_mu = self.weight_mu * self.z_mu 155 | # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size()) 156 | # print("weight_var: ", weight_var.size()) 157 | # print("z_var: ", z_var.size()) 158 | # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size()) 159 | # print("weight_var: ", weight_var.size()) 160 | # print("post_weight_mu: ", self.post_weight_mu.size()) 161 | # print("post_weight_var: ", self.post_weight_var.size()) 162 | return self.post_weight_mu, self.post_weight_var 163 | 164 | def forward(self, x): 165 | if self.deterministic: 166 | assert self.training == False, "Flag deterministic is True. This should not be used in training." 167 | return F.linear(x, self.post_weight_mu, self.bias_mu) 168 | 169 | batch_size = x.size()[0] 170 | # compute z 171 | # note that we reparametrise according to [2] Eq. (11) (not [1]) 172 | z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training, 173 | cuda=self.cuda) 174 | 175 | # apply local reparametrisation trick see [1] Eq. (6) 176 | # to the parametrisation given in [3] Eq. 
(6) 177 | xz = x * z 178 | mu_activations = F.linear(xz, self.weight_mu, self.bias_mu) 179 | var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp()) 180 | 181 | return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda) 182 | 183 | def kl_divergence(self): 184 | # KL(q(z)||p(z)) 185 | # we use the kl divergence approximation given by [2] Eq.(14) 186 | k1, k2, k3 = 0.63576, 1.87320, 1.48695 187 | log_alpha = self.get_log_dropout_rates() 188 | KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) 189 | 190 | # KL(q(w|z)||p(w|z)) 191 | # we use the kl divergence given by [3] Eq.(8) 192 | KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5 193 | KLD += torch.sum(KLD_element) 194 | 195 | # KL bias 196 | KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5 197 | KLD += torch.sum(KLD_element) 198 | 199 | return KLD 200 | 201 | def __repr__(self): 202 | return self.__class__.__name__ + ' (' \ 203 | + str(self.in_features) + ' -> ' \ 204 | + str(self.out_features) + ')' 205 | 206 | 207 | # ------------------------------------------------------- 208 | # CONVOLUTIONAL LAYER 209 | # ------------------------------------------------------- 210 | 211 | class _ConvNdGroupNJ(BayesianLayers): 212 | """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout). 213 | 214 | References: 215 | [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). 216 | [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). 217 | [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). 
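    A minimal usage sketch, mirroring the __main__ check at the bottom of this
    file (shapes are illustrative only):

    >>> conv = Conv2dGroupNJ(3, 12, 3, padding=1, cuda=False)
    >>> x = Variable(torch.randn(2, 3, 28, 28))
    >>> conv(x).size()                # stochastic forward pass (training mode)
    torch.Size([2, 12, 28, 28])
    >>> kl = conv.kl_divergence()     # KL contribution of this layer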
218 | """ 219 | 220 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, 221 | groups, bias, init_weight, init_bias, cuda=False, clip_var=None): 222 | super(_ConvNdGroupNJ, self).__init__() 223 | if in_channels % groups != 0: 224 | raise ValueError('in_channels must be divisible by groups') 225 | if out_channels % groups != 0: 226 | raise ValueError('out_channels must be divisible by groups') 227 | self.in_channels = in_channels 228 | self.out_channels = out_channels 229 | self.kernel_size = kernel_size 230 | self.stride = stride 231 | self.padding = padding 232 | self.dilation = dilation 233 | self.transposed = transposed 234 | self.output_padding = output_padding 235 | self.groups = groups 236 | 237 | self.cuda = cuda 238 | self.clip_var = clip_var 239 | self.deterministic = False # flag is used for compressed inference 240 | 241 | if transposed: 242 | self.weight_mu = Parameter(torch.Tensor( 243 | in_channels, out_channels // groups, *kernel_size)) 244 | self.weight_logvar = Parameter(torch.Tensor( 245 | in_channels, out_channels // groups, *kernel_size)) 246 | else: 247 | self.weight_mu = Parameter(torch.Tensor( 248 | out_channels, in_channels // groups, *kernel_size)) 249 | self.weight_logvar = Parameter(torch.Tensor( 250 | out_channels, in_channels // groups, *kernel_size)) 251 | 252 | self.bias_mu = Parameter(torch.Tensor(out_channels)) 253 | self.bias_logvar = Parameter(torch.Tensor(out_channels)) 254 | 255 | self.z_mu = Parameter(torch.Tensor(self.out_channels)) 256 | self.z_logvar = Parameter(torch.Tensor(self.out_channels)) 257 | 258 | self.reset_parameters(init_weight, init_bias) 259 | 260 | # activations for kl 261 | self.sigmoid = nn.Sigmoid() 262 | self.softplus = nn.Softplus() 263 | # numerical stability param 264 | self.epsilon = 1e-8 265 | 266 | def reset_parameters(self, init_weight, init_bias): 267 | # init means 268 | n = self.in_channels 269 | for k in self.kernel_size: 270 | n *= k 271 | stdv = 1. 
/ math.sqrt(n) 272 | 273 | # init means 274 | if init_weight is not None: 275 | self.weight_mu.data = init_weight 276 | else: 277 | self.weight_mu.data.uniform_(-stdv, stdv) 278 | 279 | if init_bias is not None: 280 | self.bias_mu.data = init_bias 281 | else: 282 | self.bias_mu.data.fill_(0) 283 | 284 | # inti z 285 | self.z_mu.data.normal_(1, 1e-2) 286 | 287 | # init logvars 288 | self.z_logvar.data.normal_(-9, 1e-2) 289 | self.weight_logvar.data.normal_(-9, 1e-2) 290 | self.bias_logvar.data.normal_(-9, 1e-2) 291 | 292 | def clip_variances(self): 293 | if self.clip_var: 294 | self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) 295 | self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) 296 | 297 | def get_log_dropout_rates(self): 298 | log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) 299 | return log_alpha 300 | 301 | def compute_posterior_params(self): 302 | weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() 303 | print("self.z_mu.pow(2): ", self.z_mu.pow(2).size()) 304 | print("weight_var: ", weight_var.size()) 305 | print("z_var: ", z_var.size()) 306 | print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size()) 307 | print("weight_var: ", weight_var.size()) 308 | part1 = self.z_mu.pow(2) * weight_var 309 | part2 = z_var * self.weight_mu.pow(2) 310 | part3 = z_var * weight_var 311 | self.post_weight_var = part1 + part2 + part3 312 | self.post_weight_mu = self.weight_mu * self.z_mu 313 | print("post_weight_mu: ", self.post_weight_mu.size()) 314 | print("post_weight_var: ", self.post_weight_var.size()) 315 | return self.post_weight_mu, self.post_weight_var 316 | 317 | def kl_divergence(self): 318 | # KL(q(z)||p(z)) 319 | # we use the kl divergence approximation given by [2] Eq.(14) 320 | k1, k2, k3 = 0.63576, 1.87320, 1.48695 321 | log_alpha = self.get_log_dropout_rates() 322 | KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) 323 | 324 | # KL(q(w|z)||p(w|z)) 325 | # we use the kl divergence given by [3] Eq.(8) 326 | KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5 327 | KLD += torch.sum(KLD_element) 328 | 329 | # KL bias 330 | KLD_element = -self.bias_logvar + 0.5 * (self.bias_logvar.exp().pow(2) + self.bias_mu.pow(2)) - 0.5 331 | KLD += torch.sum(KLD_element) 332 | 333 | return KLD 334 | 335 | def __repr__(self): 336 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 337 | ', stride={stride}') 338 | if self.padding != (0,) * len(self.padding): 339 | s += ', padding={padding}' 340 | if self.dilation != (1,) * len(self.dilation): 341 | s += ', dilation={dilation}' 342 | if self.output_padding != (0,) * len(self.output_padding): 343 | s += ', output_padding={output_padding}' 344 | if self.groups != 1: 345 | s += ', groups={groups}' 346 | if self.bias is None: 347 | s += ', bias=False' 348 | s += ')' 349 | return s.format(name=self.__class__.__name__, **self.__dict__) 350 | 351 | 352 | class Conv1dGroupNJ(_ConvNdGroupNJ): 353 | r""" 354 | """ 355 | 356 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 357 | cuda=False, init_weight=None, init_bias=None, clip_var=None): 358 | kernel_size = utils._single(kernel_size) 359 | stride = utils._single(stride) 360 | padding = utils._single(padding) 361 | dilation = utils._single(dilation) 362 | 363 | super(Conv1dGroupNJ, self).__init__( 364 | in_channels, out_channels, kernel_size, stride, padding, dilation, 365 
| False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var) 366 | 367 | def forward(self, x): 368 | if self.deterministic: 369 | assert self.training == False, "Flag deterministic is True. This should not be used in training." 370 | return F.conv1d(x, self.post_weight_mu, self.bias_mu) 371 | batch_size = x.size()[0] 372 | # apply local reparametrisation trick see [1] Eq. (6) 373 | # to the parametrisation given in [3] Eq. (6) 374 | mu_activations = F.conv1d(x, self.weight_mu, self.bias_mu, self.stride, 375 | self.padding, self.dilation, self.groups) 376 | 377 | var_activations = F.conv1d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride, 378 | self.padding, self.dilation, self.groups) 379 | # compute z 380 | # note that we reparametrise according to [2] Eq. (11) (not [1]) 381 | z = reparametrize(self.z_mu.repeat(batch_size, 1, 1), self.z_logvar.repeat(batch_size, 1, 1), 382 | sampling=self.training, cuda=self.cuda) 383 | z = z[:, :, None] 384 | 385 | return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training, 386 | cuda=self.cuda) 387 | 388 | def __repr__(self): 389 | return self.__class__.__name__ + ' (' \ 390 | + str(self.in_features) + ' -> ' \ 391 | + str(self.out_features) + ')' 392 | 393 | 394 | class Conv2dGroupNJ(_ConvNdGroupNJ): 395 | r""" 396 | """ 397 | 398 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 399 | cuda=False, init_weight=None, init_bias=None, clip_var=None): 400 | kernel_size = utils._pair(kernel_size) 401 | stride = utils._pair(stride) 402 | padding = utils._pair(padding) 403 | dilation = utils._pair(dilation) 404 | 405 | super(Conv2dGroupNJ, self).__init__( 406 | in_channels, out_channels, kernel_size, stride, padding, dilation, 407 | False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var) 408 | 409 | def forward(self, x): 410 | if self.deterministic: 411 | assert self.training == False, "Flag deterministic is True. This should not be used in training." 412 | return F.conv2d(x, self.post_weight_mu, self.bias_mu) 413 | batch_size = x.size()[0] 414 | # apply local reparametrisation trick see [1] Eq. (6) 415 | # to the parametrisation given in [3] Eq. (6) 416 | mu_activations = F.conv2d(x, self.weight_mu, self.bias_mu, self.stride, 417 | self.padding, self.dilation, self.groups) 418 | 419 | var_activations = F.conv2d(x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp(), self.stride, 420 | self.padding, self.dilation, self.groups) 421 | # compute z 422 | # note that we reparametrise according to [2] Eq. 
(11) (not [1]) 423 | z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), 424 | sampling=self.training, cuda=self.cuda) 425 | z = z[:, :, None, None] 426 | 427 | return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training, 428 | cuda=self.cuda) 429 | 430 | def __repr__(self): 431 | return self.__class__.__name__ + ' (' \ 432 | + str(self.in_channels) + ' -> ' \ 433 | + str(self.out_channels) + ')' 434 | 435 | 436 | class Conv3dGroupNJ(_ConvNdGroupNJ): 437 | r""" 438 | """ 439 | 440 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 441 | cuda=False, init_weight=None, init_bias=None, clip_var=None): 442 | kernel_size = utils._triple(kernel_size) 443 | stride = utils._triple(stride) 444 | padding = utils._triple(padding) 445 | dilation = utils._triple(dilation) 446 | 447 | super(Conv3dGroupNJ, self).__init__( 448 | in_channels, out_channels, kernel_size, stride, padding, dilation, 449 | False, utils._pair(0), groups, bias, init_weight, init_bias, cuda, clip_var) 450 | 451 | def forward(self, x): 452 | if self.deterministic: 453 | assert self.training == False, "Flag deterministic is True. This should not be used in training." 454 | return F.conv3d(x, self.post_weight_mu, self.bias_mu) 455 | batch_size = x.size()[0] 456 | # apply local reparametrisation trick see [1] Eq. (6) 457 | # to the parametrisation given in [3] Eq. (6) 458 | mu_activations = F.conv3d(x, self.weight_mu, self.bias_mu, self.stride, 459 | self.padding, self.dilation, self.groups) 460 | 461 | var_weights = self.weight_logvar.exp() 462 | var_activations = F.conv3d(x.pow(2), var_weights, self.bias_logvar.exp(), self.stride, 463 | self.padding, self.dilation, self.groups) 464 | # compute z 465 | # note that we reparametrise according to [2] Eq. 
(11) (not [1]) 466 | z = reparametrize(self.z_mu.repeat(batch_size, 1, 1, 1, 1), self.z_logvar.repeat(batch_size, 1, 1, 1, 1), 467 | sampling=self.training, cuda=self.cuda) 468 | z = z[:, :, None, None, None] 469 | 470 | return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training, 471 | cuda=self.cuda) 472 | 473 | def __repr__(self): 474 | return self.__class__.__name__ + ' (' \ 475 | + str(self.in_channels) + ' -> ' \ 476 | + str(self.out_channels) + ')' 477 | 478 | 479 | if __name__ == "__main__": 480 | conv = Conv2dGroupNJ(3, 12, 3, stride=1, padding=1, dilation=1, groups=1, bias=True, 481 | cuda=False, init_weight=None, init_bias=None, clip_var=None) 482 | data = torch.Tensor(3, 3, 28, 28) 483 | data = Variable(data) 484 | print(conv(data).size()) 485 | -------------------------------------------------------------------------------- /Bayesian/BayesianModule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | from Bayesian import BayesianLayers 11 | 12 | 13 | class BayesianModule(nn.Module): 14 | def __init__(self): 15 | self.kl_list = [] 16 | self.layers = [] 17 | super(BayesianModule, self).__init__() 18 | 19 | def __setattr__(self, name, value): 20 | super(BayesianModule, self).__setattr__(name, value) 21 | # simple hack to collect bayesian layer automatically 22 | if isinstance(value, BayesianLayers.BayesianLayers) and not isinstance(value, BayesianLayers._ConvNdGroupNJ): 23 | self.kl_list.append(value) 24 | self.layers.append(value) 25 | 26 | def get_masks(self, thresholds): 27 | weight_masks = [] 28 | mask = None 29 | for i, (layer, threshold) in enumerate(zip(self.kl_list, thresholds)): 30 | # compute dropout mask 31 | if mask is None: 32 | log_alpha = layer.get_log_dropout_rates().cpu().data.numpy() 33 | mask = log_alpha < threshold 34 | else: 35 | mask = np.copy(next_mask) 36 | 37 | try: 38 | log_alpha = self.layers[i + 1].get_log_dropout_rates().cpu().data.numpy() 39 | next_mask = log_alpha < thresholds[i + 1] 40 | except: 41 | # must be the last mask 42 | next_mask = np.ones(10) 43 | 44 | weight_mask = np.expand_dims(mask, axis=0) * np.expand_dims(next_mask, axis=1) 45 | weight_masks.append(weight_mask.astype(np.float)) 46 | return weight_masks 47 | 48 | def kl_divergence(self): 49 | KLD = 0 50 | for layer in self.kl_list: 51 | KLD += layer.kl_divergence() 52 | return KLD 53 | 54 | 55 | class MLP_Cifar(BayesianModule): 56 | def __init__(self, num_classes=10, use_cuda=True): 57 | super(MLP_Cifar, self).__init__() 58 | 59 | self.fc1 = BayesianLayers.LinearGroupNJ(3 * 32 * 32, 300, clip_var=0.04, cuda=use_cuda) 60 | self.fc2 = BayesianLayers.LinearGroupNJ(300, 100, cuda=use_cuda) 61 | self.fc3 = BayesianLayers.LinearGroupNJ(100, num_classes, cuda=use_cuda) 62 | 63 | def forward(self, x): 64 | x = x.view(-1, 3 * 32 * 32) 65 | x = F.relu(self.fc1(x)) 66 | x = F.relu(self.fc2(x)) 67 | x = self.fc3(x) 68 | return x 69 | 70 | 71 | class MLP_MNIST(BayesianModule): 72 | def __init__(self, num_classes=10, use_cuda=True): 73 | super(MLP_MNIST, self).__init__() 74 | 75 | self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04, cuda=use_cuda) 76 | self.fc2 = BayesianLayers.LinearGroupNJ(300, 100, cuda=use_cuda) 77 | self.fc3 = BayesianLayers.LinearGroupNJ(100, num_classes, cuda=use_cuda) 78 | 79 | def 
forward(self, x): 80 | x = x.view(-1, 28 * 28) 81 | x = F.relu(self.fc1(x)) 82 | x = F.relu(self.fc2(x)) 83 | x = self.fc3(x) 84 | return x 85 | 86 | 87 | class LeNet_Cifar(BayesianModule): 88 | def __init__(self, num_classes=10, use_cuda=True): 89 | super(LeNet_Cifar, self).__init__() 90 | 91 | self.conv1 = BayesianLayers.Conv2dGroupNJ(3, 6, 5, cuda=use_cuda) 92 | self.conv2 = BayesianLayers.Conv2dGroupNJ(6, 16, 5, cuda=use_cuda) 93 | 94 | self.fc1 = BayesianLayers.LinearGroupNJ(16 * 5 * 5, 120, clip_var=0.04, cuda=use_cuda) 95 | self.fc2 = BayesianLayers.LinearGroupNJ(120, 84, cuda=use_cuda) 96 | self.fc3 = BayesianLayers.LinearGroupNJ(84, num_classes, cuda=use_cuda) 97 | 98 | def forward(self, x): 99 | x = F.relu(self.conv1(x)) 100 | x = F.max_pool2d(x, 2) 101 | x = F.relu(self.conv2(x)) 102 | x = F.max_pool2d(x, 2) 103 | out = x.view(x.size(0), -1) 104 | out = F.relu(self.fc1(out)) 105 | out = F.relu(self.fc2(out)) 106 | out = self.fc3(out) 107 | return out 108 | 109 | 110 | class LeNet_MNIST(BayesianModule): 111 | def __init__(self, num_classes=10, use_cuda=True): 112 | super(LeNet_MNIST, self).__init__() 113 | 114 | self.conv1 = BayesianLayers.Conv2dGroupNJ(1, 10, 5, cuda=use_cuda) 115 | self.conv2 = BayesianLayers.Conv2dGroupNJ(10, 20, 5, cuda=use_cuda) 116 | 117 | self.fc1 = BayesianLayers.LinearGroupNJ(320, 50, clip_var=0.04, cuda=use_cuda) 118 | self.fc2 = BayesianLayers.LinearGroupNJ(50, num_classes, cuda=use_cuda) 119 | 120 | def forward(self, x): 121 | x = F.relu(self.conv1(x)) 122 | x = F.max_pool2d(x, 2) 123 | x = F.relu(self.conv2(x)) 124 | x = F.max_pool2d(x, 2) 125 | out = x.view(x.size(0), -1) 126 | out = F.relu(self.fc1(out)) 127 | out = self.fc2(out) 128 | return out 129 | 130 | 131 | if __name__ == "__main__": 132 | net = LeNet_Cifar(use_cuda=False) 133 | data = torch.randn([1, 3, 32, 32]) 134 | data = Variable(data) 135 | print(net(data).size()) 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian-Compression-for-Deep-Learning 2 | Remplementation of paper [https://arxiv.org/abs/1705.08665](https://arxiv.org/abs/1705.08665). This repo utilizes code from [Bayesian Optimization](https://github.com/fmfn/BayesianOptimization) and [Bayesian Tutorial](https://github.com/KarenUllrich/Tutorial_BayesianCompressionForDL) 3 | 4 | # Results 5 | 6 | Network | Dataset | Epochs | Accuracy(before) | Accuracy(after) | Compression Rate | 7 | --- | --- | --- | --- | --- | --- 8 | 3-Layer MLP | MNIST | 50 | 98.33% | 98.33% | 1.3 9 | 3-Layer MLP | CIFAR 10 | 50 | 56.26% | 54.47% | 3.5 10 | LeNet | MNIST | 50 | 99.24% | 99.26% | 1.4 11 | 12 | # Usage 13 | ```bash 14 | python example.py --dataset mnist --nettype mlp --epochs 50 15 | ``` 16 | 17 | # TODO: 18 | 1. Clean the train-prune code 19 | 2. Modularize Bayesian Layer and Module 20 | 3. 
Fix bug in Convolution -------------------------------------------------------------------------------- /Tools/Objective.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class objection(object): 4 | def __init__(self, data_loader, use_cuda=True): 5 | self.d_loss = nn.functional.cross_entropy 6 | self.N = 0 7 | if isinstance(data_loader, list): 8 | self.N += len(data_loader) 9 | else: 10 | self.N = len(data_loader) 11 | self.use_cuda = use_cuda 12 | 13 | def __call__(self, output, target, kl_divergence): 14 | d_error = self.d_loss(output, target) 15 | variational_bound = d_error + kl_divergence / self.N # TODO: why divide by N? 16 | if self.use_cuda: 17 | variational_bound = variational_bound.cuda() 18 | return variational_bound -------------------------------------------------------------------------------- /Tools/visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Utilities 6 | 7 | 8 | Karen Ullrich, Oct 2017 9 | """ 10 | 11 | import os 12 | import numpy as np 13 | import imageio 14 | 15 | import matplotlib.pyplot as plt 16 | import seaborn as sns 17 | 18 | sns.set_style("whitegrid") 19 | cmap = sns.diverging_palette(240, 10, sep=100, as_cmap=True) 20 | 21 | # ------------------------------------------------------- 22 | # VISUALISATION TOOLS 23 | # ------------------------------------------------------- 24 | 25 | def visualize_pixel_importance(imgs, log_alpha, epoch="pixel_importance"): 26 | num_imgs = len(imgs) 27 | 28 | f, ax = plt.subplots(1, num_imgs) 29 | plt.title("Epoch:" + epoch) 30 | for i, img in enumerate(imgs): 31 | img = (img / 255.) - 0.5 32 | mask = log_alpha.reshape(img.shape) 33 | mask = 1 - np.clip(np.exp(mask), 0.0, 1) 34 | ax[i].imshow(img * mask, cmap=cmap, interpolation='none', vmin=-0.5, vmax=0.5) 35 | ax[i].grid("off") 36 | ax[i].set_yticks([]) 37 | ax[i].set_xticks([]) 38 | plt.savefig("./.pixel" + epoch + ".png", bbox_inches='tight') 39 | plt.close() 40 | 41 | 42 | def visualise_weights(weight_mus, log_alphas, epoch): 43 | num_layers = len(weight_mus) 44 | 45 | for i in range(num_layers): 46 | f, ax = plt.subplots(1, 1) 47 | weight_mu = np.transpose(weight_mus[i].cpu().data.numpy()) 48 | # alpha 49 | log_alpha_fc1 = log_alphas[i].unsqueeze(1).cpu().data.numpy() 50 | log_alpha_fc1 = log_alpha_fc1 < -3 51 | log_alpha_fc2 = log_alphas[i + 1].unsqueeze(0).cpu().data.numpy() 52 | log_alpha_fc2 = log_alpha_fc2 < -3 53 | mask = log_alpha_fc1 + log_alpha_fc2 54 | # weight 55 | c = np.max(np.abs(weight_mu)) 56 | s = ax.imshow(weight_mu * mask, cmap='seismic', interpolation='none', vmin=-c, vmax=c) 57 | ax.grid("off") 58 | ax.set_yticks([]) 59 | ax.set_xticks([]) 60 | s.set_clim([-c * 0.5, c * 0.5]) 61 | f.colorbar(s) 62 | plt.title("Epoch:" + str(epoch)) 63 | plt.savefig("./.weight" + str(i) + '_e' + str(epoch) + ".png", bbox_inches='tight') 64 | plt.close() 65 | 66 | 67 | def generate_gif(save='tmp', epochs=10): 68 | images = [] 69 | filenames = ["./." 
+ save + "%d.png" % (epoch + 1) for epoch in np.arange(epochs)] 70 | for filename in filenames: 71 | images.append(imageio.imread(filename)) 72 | os.remove(filename) 73 | imageio.mimsave('./figures/' + save + '.gif', images, duration=.5) 74 | -------------------------------------------------------------------------------- /bayes_opt/__init__.py: -------------------------------------------------------------------------------- 1 | from .bayesian_optimization import BayesianOptimization 2 | from .helpers import UtilityFunction 3 | 4 | __all__ = ["BayesianOptimization", "UtilityFunction"] 5 | -------------------------------------------------------------------------------- /bayes_opt/bayesian_optimization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import numpy as np 5 | import warnings 6 | from sklearn.gaussian_process import GaussianProcessRegressor 7 | from sklearn.gaussian_process.kernels import Matern 8 | from .helpers import (UtilityFunction, PrintLog, acq_max, ensure_rng) 9 | from .target_space import TargetSpace 10 | 11 | 12 | class BayesianOptimization(object): 13 | 14 | def __init__(self, f, pbounds, random_state=None, verbose=1): 15 | """ 16 | :param f: 17 | Function to be maximized. 18 | 19 | :param pbounds: 20 | Dictionary with parameters names as keys and a tuple with minimum 21 | and maximum values. 22 | 23 | :param verbose: 24 | Whether or not to print progress. 25 | 26 | """ 27 | # Store the original dictionary 28 | self.pbounds = pbounds 29 | 30 | self.random_state = ensure_rng(random_state) 31 | 32 | # Data structure containing the function to be optimized, the bounds of 33 | # its domain, and a record of the evaluations we have done so far 34 | self.space = TargetSpace(f, pbounds, random_state) 35 | 36 | # Initialization flag 37 | self.initialized = False 38 | 39 | # Initialization lists --- stores starting points before process begins 40 | self.init_points = [] 41 | self.x_init = [] 42 | self.y_init = [] 43 | 44 | # Counter of iterations 45 | self.i = 0 46 | 47 | # Internal GP regressor 48 | self.gp = GaussianProcessRegressor( 49 | kernel=Matern(nu=2.5), 50 | n_restarts_optimizer=25, 51 | random_state=self.random_state 52 | ) 53 | 54 | # Utility Function placeholder 55 | self.util = None 56 | 57 | # PrintLog object 58 | self.plog = PrintLog(self.space.keys) 59 | 60 | # Output dictionary 61 | self.res = {} 62 | # Output dictionary 63 | self.res['max'] = {'max_val': None, 64 | 'max_params': None} 65 | self.res['all'] = {'values': [], 'params': []} 66 | 67 | # non-public config for maximizing the aquisition function 68 | # (used to speedup tests, but generally leave these as is) 69 | self._acqkw = {'n_warmup': 100000, 'n_iter': 250} 70 | 71 | # Verbose 72 | self.verbose = verbose 73 | 74 | def init(self, init_points): 75 | """ 76 | Initialization method to kick start the optimization process. It is a 77 | combination of points passed by the user, and randomly sampled ones. 78 | 79 | :param init_points: 80 | Number of random points to probe. 81 | """ 82 | # Concatenate new random points to possible existing 83 | # points from self.explore method. 
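        # Note: init() combines three sources of observations:
        #   1. points queued earlier via explore(..., eager=False) -> self.init_points
        #   2. `init_points` fresh random samples drawn just below
        #   3. known (x, y) pairs registered via initialize()/initialize_df(),
        #      which are appended to the target space after the random probes.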
84 | rand_points = self.space.random_points(init_points) 85 | self.init_points.extend(rand_points) 86 | 87 | # Evaluate target function at all initialization points 88 | for x in self.init_points: 89 | y = self._observe_point(x) 90 | 91 | # Add the points from `self.initialize` to the observations 92 | if self.x_init: 93 | x_init = np.vstack(self.x_init) 94 | y_init = np.hstack(self.y_init) 95 | for x, y in zip(x_init, y_init): 96 | self.space.add_observation(x, y) 97 | if self.verbose: 98 | self.plog.print_step(x, y) 99 | 100 | # Updates the flag 101 | self.initialized = True 102 | 103 | def _observe_point(self, x): 104 | y = self.space.observe_point(x) 105 | if self.verbose: 106 | self.plog.print_step(x, y) 107 | return y 108 | 109 | def explore(self, points_dict, eager=False): 110 | """Method to explore user defined points. 111 | 112 | :param points_dict: 113 | :param eager: if True, these points are evaulated immediately 114 | """ 115 | if eager: 116 | self.plog.reset_timer() 117 | if self.verbose: 118 | self.plog.print_header(initialization=True) 119 | 120 | points = self.space._dict_to_points(points_dict) 121 | for x in points: 122 | self._observe_point(x) 123 | else: 124 | points = self.space._dict_to_points(points_dict) 125 | self.init_points = points 126 | 127 | def initialize(self, points_dict): 128 | """ 129 | Method to introduce points for which the target function value is known 130 | 131 | :param points_dict: 132 | dictionary with self.keys and 'target' as keys, and list of 133 | corresponding values as values. 134 | 135 | ex: 136 | { 137 | 'target': [-1166.19102, -1142.71370, -1138.68293], 138 | 'alpha': [7.0034, 6.6186, 6.0798], 139 | 'colsample_bytree': [0.6849, 0.7314, 0.9540], 140 | 'gamma': [8.3673, 3.5455, 2.3281], 141 | } 142 | 143 | :return: 144 | """ 145 | 146 | self.y_init.extend(points_dict['target']) 147 | for i in range(len(points_dict['target'])): 148 | all_points = [] 149 | for key in self.space.keys: 150 | all_points.append(points_dict[key][i]) 151 | self.x_init.append(all_points) 152 | 153 | def initialize_df(self, points_df): 154 | """ 155 | Method to introduce point for which the target function 156 | value is known from pandas dataframe file 157 | 158 | :param points_df: 159 | pandas dataframe with columns (target, {list of columns matching 160 | self.keys}) 161 | 162 | ex: 163 | target alpha colsample_bytree gamma 164 | -1166.19102 7.0034 0.6849 8.3673 165 | -1142.71370 6.6186 0.7314 3.5455 166 | -1138.68293 6.0798 0.9540 2.3281 167 | -1146.65974 2.4566 0.9290 0.3456 168 | -1160.32854 1.9821 0.5298 8.7863 169 | 170 | :return: 171 | """ 172 | 173 | for i in points_df.index: 174 | self.y_init.append(points_df.loc[i, 'target']) 175 | 176 | all_points = [] 177 | for key in self.space.keys: 178 | all_points.append(points_df.loc[i, key]) 179 | 180 | self.x_init.append(all_points) 181 | 182 | def set_bounds(self, new_bounds): 183 | """ 184 | A method that allows changing the lower and upper searching bounds 185 | 186 | :param new_bounds: 187 | A dictionary with the parameter name and its new bounds 188 | 189 | """ 190 | # Update the internal object stored dict 191 | self.pbounds.update(new_bounds) 192 | self.space.set_bounds(new_bounds) 193 | 194 | def maximize(self, 195 | init_points=5, 196 | n_iter=25, 197 | acq='ucb', 198 | kappa=2.576, 199 | xi=0.0, 200 | **gp_params): 201 | """ 202 | Main optimization method. 
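        Unless the optimizer was already primed via explore() or initialize(),
        this first probes `init_points` random points, then alternates between
        refitting the internal GP and maximizing the acquisition function for
        `n_iter` rounds, accumulating results in self.res.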
203 | 204 | Parameters 205 | ---------- 206 | :param init_points: 207 | Number of randomly chosen points to sample the 208 | target function before fitting the gp. 209 | 210 | :param n_iter: 211 | Total number of times the process is to repeated. Note that 212 | currently this methods does not have stopping criteria (due to a 213 | number of reasons), therefore the total number of points to be 214 | sampled must be specified. 215 | 216 | :param acq: 217 | Acquisition function to be used, defaults to Upper Confidence Bound. 218 | 219 | :param gp_params: 220 | Parameters to be passed to the Scikit-learn Gaussian Process object 221 | 222 | Returns 223 | ------- 224 | :return: Nothing 225 | 226 | Example: 227 | >>> xs = np.linspace(-2, 10, 10000) 228 | >>> f = np.exp(-(xs - 2)**2) + np.exp(-(xs - 6)**2/10) + 1/ (xs**2 + 1) 229 | >>> bo = BayesianOptimization(f=lambda x: f[int(x)], 230 | >>> pbounds={"x": (0, len(f)-1)}) 231 | >>> bo.maximize(init_points=2, n_iter=25, acq="ucb", kappa=1) 232 | """ 233 | # Reset timer 234 | self.plog.reset_timer() 235 | 236 | # Set acquisition function 237 | self.util = UtilityFunction(kind=acq, kappa=kappa, xi=xi) 238 | 239 | # Initialize x, y and find current y_max 240 | if not self.initialized: 241 | if self.verbose: 242 | self.plog.print_header() 243 | self.init(init_points) 244 | 245 | y_max = self.space.Y.max() 246 | 247 | # Set parameters if any was passed 248 | self.gp.set_params(**gp_params) 249 | 250 | # Find unique rows of X to avoid GP from breaking 251 | self.gp.fit(self.space.X, self.space.Y) 252 | 253 | # Finding argmax of the acquisition function. 254 | x_max = acq_max(ac=self.util.utility, 255 | gp=self.gp, 256 | y_max=y_max, 257 | bounds=self.space.bounds, 258 | random_state=self.random_state, 259 | **self._acqkw) 260 | 261 | # Print new header 262 | if self.verbose: 263 | self.plog.print_header(initialization=False) 264 | # Iterative process of searching for the maximum. At each round the 265 | # most recent x and y values probed are added to the X and Y arrays 266 | # used to train the Gaussian Process. Next the maximum known value 267 | # of the target function is found and passed to the acq_max function. 268 | # The arg_max of the acquisition function is found and this will be 269 | # the next probed value of the target function in the next round. 270 | for i in range(n_iter): 271 | # Test if x_max is repeated, if it is, draw another one at random 272 | # If it is repeated, print a warning 273 | pwarning = False 274 | while x_max in self.space: 275 | x_max = self.space.random_points(1)[0] 276 | pwarning = True 277 | 278 | # Append most recently generated values to X and Y arrays 279 | y = self.space.observe_point(x_max) 280 | if self.verbose: 281 | self.plog.print_step(x_max, y, pwarning) 282 | 283 | # Updating the GP. 284 | self.gp.fit(self.space.X, self.space.Y) 285 | 286 | # Update the best params seen so far 287 | self.res['max'] = self.space.max_point() 288 | self.res['all']['values'].append(y) 289 | self.res['all']['params'].append(dict(zip(self.space.keys, x_max))) 290 | 291 | # Update maximum value to search for next probe point. 
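            # Note: y_max is only consumed by the 'ei' and 'poi' acquisition
            # functions in helpers.UtilityFunction; the default 'ucb' ignores it.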
292 | if self.space.Y[-1] > y_max: 293 | y_max = self.space.Y[-1] 294 | 295 | # Maximize acquisition function to find next probing point 296 | x_max = acq_max(ac=self.util.utility, 297 | gp=self.gp, 298 | y_max=y_max, 299 | bounds=self.space.bounds, 300 | random_state=self.random_state, 301 | **self._acqkw) 302 | 303 | # Keep track of total number of iterations 304 | self.i += 1 305 | 306 | # Print a final report if verbose active. 307 | if self.verbose: 308 | self.plog.print_summary() 309 | 310 | def points_to_csv(self, file_name): 311 | """ 312 | After training all points for which we know target variable 313 | (both from initialization and optimization) are saved 314 | 315 | :param file_name: name of the file where points will be saved in the csv 316 | format 317 | 318 | :return: None 319 | """ 320 | 321 | points = np.hstack((self.space.X, np.expand_dims(self.space.Y, axis=1))) 322 | header = ', '.join(self.space.keys + ['target']) 323 | np.savetxt(file_name, points, header=header, delimiter=',') 324 | 325 | # --- API compatibility --- 326 | 327 | @property 328 | def X(self): 329 | warnings.warn("use self.space.X instead", DeprecationWarning) 330 | return self.space.X 331 | 332 | @property 333 | def Y(self): 334 | warnings.warn("use self.space.Y instead", DeprecationWarning) 335 | return self.space.Y 336 | 337 | @property 338 | def keys(self): 339 | warnings.warn("use self.space.keys instead", DeprecationWarning) 340 | return self.space.keys 341 | 342 | @property 343 | def f(self): 344 | warnings.warn("use self.space.target_func instead", DeprecationWarning) 345 | return self.space.target_func 346 | 347 | @property 348 | def bounds(self): 349 | warnings.warn("use self.space.dim instead", DeprecationWarning) 350 | return self.space.bounds 351 | 352 | @property 353 | def dim(self): 354 | warnings.warn("use self.space.dim instead", DeprecationWarning) 355 | return self.space.dim 356 | -------------------------------------------------------------------------------- /bayes_opt/helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import numpy as np 4 | from datetime import datetime 5 | from scipy.stats import norm 6 | from scipy.optimize import minimize 7 | 8 | 9 | def acq_max(ac, gp, y_max, bounds, random_state, n_warmup=100000, n_iter=250): 10 | """ 11 | A function to find the maximum of the acquisition function 12 | 13 | It uses a combination of random sampling (cheap) and the 'L-BFGS-B' 14 | optimization method. First by sampling `n_warmup` (1e5) points at random, 15 | and then running L-BFGS-B from `n_iter` (250) random starting points. 16 | 17 | Parameters 18 | ---------- 19 | :param ac: 20 | The acquisition function object that return its point-wise value. 21 | 22 | :param gp: 23 | A gaussian process fitted to the relevant data. 24 | 25 | :param y_max: 26 | The current maximum known value of the target function. 27 | 28 | :param bounds: 29 | The variables bounds to limit the search of the acq max. 30 | 31 | :param random_state: 32 | instance of np.RandomState random number generator 33 | 34 | :param n_warmup: 35 | number of times to randomly sample the aquisition function 36 | 37 | :param n_iter: 38 | number of times to run scipy.minimize 39 | 40 | Returns 41 | ------- 42 | :return: x_max, The arg max of the acquisition function. 
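    A minimal calling sketch; the 1-D bounds, toy data and default GP kernel are
    illustrative assumptions, not requirements of this function:

    >>> from sklearn.gaussian_process import GaussianProcessRegressor
    >>> rng = ensure_rng(0)
    >>> bounds = np.array([[0.0, 1.0]])
    >>> X = rng.uniform(0, 1, size=(5, 1)); Y = np.sin(3 * X).ravel()
    >>> gp = GaussianProcessRegressor().fit(X, Y)
    >>> util = UtilityFunction(kind='ucb', kappa=2.576, xi=0.0)
    >>> x_next = acq_max(ac=util.utility, gp=gp, y_max=Y.max(),
    ...                  bounds=bounds, random_state=rng)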
43 | """ 44 | 45 | # Warm up with random points 46 | x_tries = random_state.uniform(bounds[:, 0], bounds[:, 1], 47 | size=(n_warmup, bounds.shape[0])) 48 | ys = ac(x_tries, gp=gp, y_max=y_max) 49 | x_max = x_tries[ys.argmax()] 50 | max_acq = ys.max() 51 | 52 | # Explore the parameter space more throughly 53 | x_seeds = random_state.uniform(bounds[:, 0], bounds[:, 1], 54 | size=(n_iter, bounds.shape[0])) 55 | for x_try in x_seeds: 56 | # Find the minimum of minus the acquisition function 57 | res = minimize(lambda x: -ac(x.reshape(1, -1), gp=gp, y_max=y_max), 58 | x_try.reshape(1, -1), 59 | bounds=bounds, 60 | method="L-BFGS-B") 61 | 62 | # Store it if better than previous minimum(maximum). 63 | if max_acq is None or -res.fun[0] >= max_acq: 64 | x_max = res.x 65 | max_acq = -res.fun[0] 66 | 67 | # Clip output to make sure it lies within the bounds. Due to floating 68 | # point technicalities this is not always the case. 69 | return np.clip(x_max, bounds[:, 0], bounds[:, 1]) 70 | 71 | 72 | class UtilityFunction(object): 73 | """ 74 | An object to compute the acquisition functions. 75 | """ 76 | 77 | def __init__(self, kind, kappa, xi): 78 | """ 79 | If UCB is to be used, a constant kappa is needed. 80 | """ 81 | self.kappa = kappa 82 | 83 | self.xi = xi 84 | 85 | if kind not in ['ucb', 'ei', 'poi']: 86 | err = "The utility function " \ 87 | "{} has not been implemented, " \ 88 | "please choose one of ucb, ei, or poi.".format(kind) 89 | raise NotImplementedError(err) 90 | else: 91 | self.kind = kind 92 | 93 | def utility(self, x, gp, y_max): 94 | if self.kind == 'ucb': 95 | return self._ucb(x, gp, self.kappa) 96 | if self.kind == 'ei': 97 | return self._ei(x, gp, y_max, self.xi) 98 | if self.kind == 'poi': 99 | return self._poi(x, gp, y_max, self.xi) 100 | 101 | @staticmethod 102 | def _ucb(x, gp, kappa): 103 | mean, std = gp.predict(x, return_std=True) 104 | return mean + kappa * std 105 | 106 | @staticmethod 107 | def _ei(x, gp, y_max, xi): 108 | mean, std = gp.predict(x, return_std=True) 109 | z = (mean - y_max - xi)/std 110 | return (mean - y_max - xi) * norm.cdf(z) + std * norm.pdf(z) 111 | 112 | @staticmethod 113 | def _poi(x, gp, y_max, xi): 114 | mean, std = gp.predict(x, return_std=True) 115 | z = (mean - y_max - xi)/std 116 | return norm.cdf(z) 117 | 118 | 119 | def unique_rows(a): 120 | """ 121 | A functions to trim repeated rows that may appear when optimizing. 122 | This is necessary to avoid the sklearn GP object from breaking 123 | 124 | :param a: array to trim repeated rows from 125 | 126 | :return: mask of unique rows 127 | """ 128 | if a.size == 0: 129 | return np.empty((0,)) 130 | 131 | # Sort array and kep track of where things should go back to 132 | order = np.lexsort(a.T) 133 | reorder = np.argsort(order) 134 | 135 | a = a[order] 136 | diff = np.diff(a, axis=0) 137 | ui = np.ones(len(a), 'bool') 138 | ui[1:] = (diff != 0).any(axis=1) 139 | 140 | return ui[reorder] 141 | 142 | 143 | def ensure_rng(random_state=None): 144 | """ 145 | Creates a random number generator based on an optional seed. This can be 146 | an integer or another random state for a seeded rng, or None for an 147 | unseeded rng. 
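    A short usage sketch:

    >>> rng = ensure_rng(42)        # seeded np.random.RandomState
    >>> rng is ensure_rng(rng)      # an existing RandomState is passed through
    True
    >>> fresh = ensure_rng()        # unseeded RandomState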
148 | """ 149 | if random_state is None: 150 | random_state = np.random.RandomState() 151 | elif isinstance(random_state, int): 152 | random_state = np.random.RandomState(random_state) 153 | else: 154 | assert isinstance(random_state, np.random.RandomState) 155 | return random_state 156 | 157 | 158 | class BColours(object): 159 | BLUE = '\033[94m' 160 | CYAN = '\033[36m' 161 | GREEN = '\033[32m' 162 | MAGENTA = '\033[35m' 163 | RED = '\033[31m' 164 | ENDC = '\033[0m' 165 | 166 | 167 | class PrintLog(object): 168 | 169 | def __init__(self, params): 170 | 171 | self.ymax = None 172 | self.xmax = None 173 | self.params = params 174 | self.ite = 1 175 | 176 | self.start_time = datetime.now() 177 | self.last_round = datetime.now() 178 | 179 | # sizes of parameters name and all 180 | self.sizes = [max(len(ps), 7) for ps in params] 181 | 182 | # Sorted indexes to access parameters 183 | self.sorti = sorted(range(len(self.params)), 184 | key=self.params.__getitem__) 185 | 186 | def reset_timer(self): 187 | self.start_time = datetime.now() 188 | self.last_round = datetime.now() 189 | 190 | def print_header(self, initialization=True): 191 | 192 | if initialization: 193 | print("{}Initialization{}".format(BColours.RED, 194 | BColours.ENDC)) 195 | else: 196 | print("{}Bayesian Optimization{}".format(BColours.RED, 197 | BColours.ENDC)) 198 | 199 | print(BColours.BLUE + "-" * (29 + sum([s + 5 for s in self.sizes])) + 200 | BColours.ENDC) 201 | 202 | print("{0:>{1}}".format("Step", 5), end=" | ") 203 | print("{0:>{1}}".format("Time", 6), end=" | ") 204 | print("{0:>{1}}".format("Value", 10), end=" | ") 205 | 206 | for index in self.sorti: 207 | print("{0:>{1}}".format(self.params[index], 208 | self.sizes[index] + 2), 209 | end=" | ") 210 | print('') 211 | 212 | def print_step(self, x, y, warning=False): 213 | 214 | print("{:>5d}".format(self.ite), end=" | ") 215 | 216 | m, s = divmod((datetime.now() - self.last_round).total_seconds(), 60) 217 | print("{:>02d}m{:>02d}s".format(int(m), int(s)), end=" | ") 218 | 219 | if self.ymax is None or self.ymax < y: 220 | self.ymax = y 221 | self.xmax = x 222 | print("{0}{2: >10.5f}{1}".format(BColours.MAGENTA, 223 | BColours.ENDC, 224 | y), 225 | end=" | ") 226 | 227 | for index in self.sorti: 228 | print("{0}{2: >{3}.{4}f}{1}".format( 229 | BColours.GREEN, BColours.ENDC, 230 | x[index], 231 | self.sizes[index] + 2, 232 | min(self.sizes[index] - 3, 6 - 2) 233 | ), 234 | end=" | ") 235 | else: 236 | print("{: >10.5f}".format(y), end=" | ") 237 | for index in self.sorti: 238 | print("{0: >{1}.{2}f}".format(x[index], 239 | self.sizes[index] + 2, 240 | min(self.sizes[index] - 3, 6 - 2)), 241 | end=" | ") 242 | 243 | if warning: 244 | print("{}Warning: Test point chose at " 245 | "random due to repeated sample.{}".format(BColours.RED, 246 | BColours.ENDC)) 247 | 248 | print() 249 | 250 | self.last_round = datetime.now() 251 | self.ite += 1 252 | 253 | def print_summary(self): 254 | pass 255 | -------------------------------------------------------------------------------- /bayes_opt/target_space.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import numpy as np 3 | from .helpers import ensure_rng, unique_rows 4 | 5 | 6 | def _hashable(x): 7 | """ ensure that an point is hashable by a python dict """ 8 | return tuple(map(float, x)) 9 | 10 | 11 | class TargetSpace(object): 12 | """ 13 | Holds the param-space coordinates (X) and target values (Y) 14 | Allows for constant-time 
appends while ensuring no duplicates are added 15 | 16 | Example 17 | ------- 18 | >>> def target_func(p1, p2): 19 | >>> return p1 + p2 20 | >>> pbounds = {'p1': (0, 1), 'p2': (1, 100)} 21 | >>> space = TargetSpace(target_func, pbounds, random_state=0) 22 | >>> x = space.random_points(1)[0] 23 | >>> y = space.observe_point(x) 24 | >>> assert self.max_point()['max_val'] == y 25 | """ 26 | def __init__(self, target_func, pbounds, random_state=None): 27 | """ 28 | Parameters 29 | ---------- 30 | target_func : function 31 | Function to be maximized. 32 | 33 | pbounds : dict 34 | Dictionary with parameters names as keys and a tuple with minimum 35 | and maximum values. 36 | 37 | random_state : int, RandomState, or None 38 | optionally specify a seed for a random number generator 39 | """ 40 | 41 | self.random_state = ensure_rng(random_state) 42 | 43 | # Some function to be optimized 44 | self.target_func = target_func 45 | 46 | # Get the name of the parameters 47 | self.keys = list(pbounds.keys()) 48 | # Create an array with parameters bounds 49 | self.bounds = np.array(list(pbounds.values()), dtype=np.float) 50 | # Find number of parameters 51 | self.dim = len(self.keys) 52 | 53 | # preallocated memory for X and Y points 54 | self._Xarr = None 55 | self._Yarr = None 56 | 57 | # Number of observations 58 | self._length = 0 59 | 60 | # Views of the preallocated arrays showing only populated data 61 | self._Xview = None 62 | self._Yview = None 63 | 64 | self._cache = {} # keep track of unique points we have seen so far 65 | 66 | @property 67 | def X(self): 68 | return self._Xview 69 | 70 | @property 71 | def Y(self): 72 | return self._Yview 73 | 74 | def __contains__(self, x): 75 | return _hashable(x) in self._cache 76 | 77 | def __len__(self): 78 | return self._length 79 | 80 | def _dict_to_points(self, points_dict): 81 | """ 82 | Example: 83 | ------- 84 | >>> pbounds = {'p1': (0, 1), 'p2': (1, 100)} 85 | >>> space = TargetSpace(lambda p1, p2: p1 + p2, pbounds) 86 | >>> points_dict = {'p1': [0, .5, 1], 'p2': [0, 1, 2]} 87 | >>> space._dict_to_points(points_dict) 88 | [[0, 0], [1, 0.5], [2, 1]] 89 | """ 90 | # Consistency check 91 | param_tup_lens = [] 92 | 93 | for key in self.keys: 94 | param_tup_lens.append(len(list(points_dict[key]))) 95 | 96 | if all([e == param_tup_lens[0] for e in param_tup_lens]): 97 | pass 98 | else: 99 | raise ValueError('The same number of initialization points ' 100 | 'must be entered for every parameter.') 101 | 102 | # Turn into list of lists 103 | all_points = [] 104 | for key in self.keys: 105 | all_points.append(points_dict[key]) 106 | 107 | # Take transpose of list 108 | points = list(map(list, zip(*all_points))) 109 | return points 110 | 111 | def observe_point(self, x): 112 | """ 113 | Evaulates a single point x, to obtain the value y and then records them 114 | as observations. 115 | 116 | Notes 117 | ----- 118 | If x has been previously seen returns a cached value of y. 119 | 120 | Parameters 121 | ---------- 122 | x : ndarray 123 | a single point, with len(x) == self.dim 124 | 125 | Returns 126 | ------- 127 | y : float 128 | target function value. 
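        A usage sketch, reusing the toy target and pbounds from the class
        docstring above:

        >>> space = TargetSpace(lambda p1, p2: p1 + p2,
        ...                     {'p1': (0, 1), 'p2': (1, 100)})
        >>> x = space.random_points(1)[0]
        >>> y = space.observe_point(x)    # calls target_func(**dict(zip(keys, x)))
        >>> space.observe_point(x) == y   # a repeated point is served from the cache
        True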
129 | """ 130 | x = np.asarray(x).ravel() 131 | assert x.size == self.dim, 'x must have the same dimensions' 132 | 133 | if x in self: 134 | # Lookup previously seen point 135 | y = self._cache[_hashable(x)] 136 | else: 137 | # measure the target function 138 | params = dict(zip(self.keys, x)) 139 | y = self.target_func(**params) 140 | self.add_observation(x, y) 141 | return y 142 | 143 | def add_observation(self, x, y): 144 | """ 145 | Append a point and its target value to the known data. 146 | 147 | Parameters 148 | ---------- 149 | x : ndarray 150 | a single point, with len(x) == self.dim 151 | 152 | y : float 153 | target function value 154 | 155 | Raises 156 | ------ 157 | KeyError: 158 | if the point is not unique 159 | 160 | Notes 161 | ----- 162 | runs in ammortized constant time 163 | 164 | Example 165 | ------- 166 | >>> pbounds = {'p1': (0, 1), 'p2': (1, 100)} 167 | >>> space = TargetSpace(lambda p1, p2: p1 + p2, pbounds) 168 | >>> len(space) 169 | 0 170 | >>> x = np.array([0, 0]) 171 | >>> y = 1 172 | >>> space.add_observation(x, y) 173 | >>> len(space) 174 | 1 175 | """ 176 | if x in self: 177 | raise KeyError('Data point {} is not unique'.format(x)) 178 | 179 | if self._length >= self._n_alloc_rows: 180 | self._allocate((self._length + 1) * 2) 181 | 182 | x = np.asarray(x).ravel() 183 | 184 | # Insert data into unique dictionary 185 | self._cache[_hashable(x)] = y 186 | 187 | # Insert data into preallocated arrays 188 | self._Xarr[self._length] = x 189 | self._Yarr[self._length] = y 190 | # Expand views to encompass the new data point 191 | self._length += 1 192 | 193 | # Create views of the data 194 | self._Xview = self._Xarr[:self._length] 195 | self._Yview = self._Yarr[:self._length] 196 | 197 | def _allocate(self, num): 198 | """ 199 | Allocate enough memory to store `num` points 200 | """ 201 | if num <= self._n_alloc_rows: 202 | raise ValueError('num must be larger than current array length') 203 | 204 | self._assert_internal_invariants() 205 | 206 | # Allocate new memory 207 | _Xnew = np.empty((num, self.bounds.shape[0])) 208 | _Ynew = np.empty(num) 209 | 210 | # Copy the old data into the new 211 | if self._Xarr is not None: 212 | _Xnew[:self._length] = self._Xarr[:self._length] 213 | _Ynew[:self._length] = self._Yarr[:self._length] 214 | self._Xarr = _Xnew 215 | self._Yarr = _Ynew 216 | 217 | # Create views of the data 218 | self._Xview = self._Xarr[:self._length] 219 | self._Yview = self._Yarr[:self._length] 220 | 221 | @property 222 | def _n_alloc_rows(self): 223 | """ Number of allocated rows """ 224 | return 0 if self._Xarr is None else self._Xarr.shape[0] 225 | 226 | def random_points(self, num): 227 | """ 228 | Creates random points within the bounds of the space 229 | 230 | Parameters 231 | ---------- 232 | num : int 233 | Number of random points to create 234 | 235 | Returns 236 | ---------- 237 | data: ndarray 238 | [num x dim] array points with dimensions corresponding to `self.keys` 239 | 240 | Example 241 | ------- 242 | >>> target_func = lambda p1, p2: p1 + p2 243 | >>> pbounds = {'p1': (0, 1), 'p2': (1, 100)} 244 | >>> space = TargetSpace(target_func, pbounds, random_state=0) 245 | >>> space.random_points(3) 246 | array([[ 55.33253689, 0.54488318], 247 | [ 71.80374727, 0.4236548 ], 248 | [ 60.67357423, 0.64589411]]) 249 | """ 250 | # TODO: support integer, category, and basic scipy.optimize constraints 251 | data = np.empty((num, self.dim)) 252 | for col, (lower, upper) in enumerate(self.bounds): 253 | data.T[col] = self.random_state.uniform(lower, 
upper, size=num) 254 | return data 255 | 256 | def max_point(self): 257 | """ 258 | Return the parameters that currently maximize the target function, 259 | together with that maximum value. 260 | """ 261 | return {'max_val': self.Y.max(), 262 | 'max_params': dict(zip(self.keys, 263 | self.X[self.Y.argmax()]))} 264 | 265 | def set_bounds(self, new_bounds): 266 | """ 267 | A method that allows changing the lower and upper search bounds 268 | 269 | Parameters 270 | ---------- 271 | new_bounds : dict 272 | A dictionary mapping parameter names to their new bounds 273 | """ 274 | # Loop through all bounds and reset the min-max bound matrix 275 | for row, key in enumerate(self.keys): 276 | if key in new_bounds: 277 | self.bounds[row] = new_bounds[key] 278 | 279 | def _assert_internal_invariants(self, fast=True): 280 | """ 281 | Run internal consistency checks to ensure that data structure 282 | assumptions have not been violated. 283 | """ 284 | if self._Xarr is None: 285 | assert self._Yarr is None 286 | assert self._Xview is None 287 | assert self._Yview is None 288 | else: 289 | assert self._Yarr is not None 290 | assert self._Xview is not None 291 | assert self._Yview is not None 292 | assert len(self._Xview) == self._length 293 | assert len(self._Yview) == self._length 294 | assert len(self._Xarr) == len(self._Yarr) 295 | 296 | if not fast: 297 | # run slower checks 298 | assert np.all(unique_rows(self.X)) 299 | # assert np.may_share_memory(self._Xview, self._Xarr) 300 | # assert np.may_share_memory(self._Yview, self._Yarr) 301 | -------------------------------------------------------------------------------- /compression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Compression Tools 6 | 7 | 8 | Karen Ullrich, Oct 2017 9 | 10 | References: 11 | 12 | [1] Michael T. Heath. 1996. Scientific Computing: An Introductory Survey (2nd ed.). Eric M. Munson (Ed.). McGraw-Hill Higher Education. Chapter 1 13 | """ 14 | 15 | import numpy as np 16 | 17 | # ------------------------------------------------------- 18 | # General tools 19 | # ------------------------------------------------------- 20 | 21 | 22 | def unit_round_off(t=23): 23 | """ 24 | :param t: 25 | number of significand bits 26 | :return: 27 | unit round-off assuming rounding to nearest; for reference see [1] 28 | """ 29 | return 0.5 * 2. ** (1.
- t) 30 | 31 | 32 | SIGNIFICANT_BIT_PRECISION = [unit_round_off(t=i + 1) for i in range(23)] 33 | 34 | 35 | def float_precision(x): 36 | # number of significand bits whose unit round-off is still larger than x 37 | out = np.sum([x < sbp for sbp in SIGNIFICANT_BIT_PRECISION]) 38 | return out 39 | 40 | 41 | def float_precisions(X, dist_fun, layer=1): 42 | # aggregate the per-weight precisions with dist_fun and round up 43 | X = X.flatten() 44 | out = [float_precision(2 * x) for x in X] 45 | out = np.ceil(dist_fun(out)) 46 | return out 47 | 48 | 49 | def special_round(input, significant_bit): 50 | delta = unit_round_off(t=significant_bit) 51 | rounded = np.floor(input / delta + 0.5) 52 | rounded = rounded * delta 53 | return rounded 54 | 55 | 56 | def fast_inference_weights(w, exponent_bit, significant_bit): 57 | # round the weights to the reduced significand precision used at inference time 58 | return special_round(w, significant_bit) 59 | 60 | 61 | def compress_matrix(x): 62 | # drop all-zero rows and columns; convolutional kernels are reshaped to 2D first 63 | if len(x.shape) != 2: 64 | A, B, C, D = x.shape 65 | x = x.reshape(A * B, C * D) 66 | # remove unnecessary filters and rows 67 | x = x[:, (x != 0).any(axis=0)] 68 | x = x[(x != 0).any(axis=1), :] 69 | else: 70 | # remove unnecessary rows, columns 71 | x = x[(x != 0).any(axis=1), :] 72 | x = x[:, (x != 0).any(axis=0)] 73 | return x 74 | 75 | 76 | def extract_pruned_params(layers, masks): 77 | 78 | post_weight_mus = [] 79 | post_weight_vars = [] 80 | 81 | for i, (layer, mask) in enumerate(zip(layers, masks)): 82 | # compute posteriors 83 | post_weight_mu, post_weight_var = layer.compute_posterior_params() 84 | post_weight_var = post_weight_var.cpu().data.numpy() 85 | post_weight_mu = post_weight_mu.cpu().data.numpy() 86 | # apply mask to mus and variances 87 | post_weight_mu = post_weight_mu * mask 88 | post_weight_var = post_weight_var * mask 89 | 90 | post_weight_mus.append(post_weight_mu) 91 | post_weight_vars.append(post_weight_var) 92 | 93 | return post_weight_mus, post_weight_vars 94 | 95 | 96 | # ------------------------------------------------------- 97 | # Compression rates (fast inference scenario) 98 | # ------------------------------------------------------- 99 | 100 | 101 | def _compute_compression_rate(vars, in_precision=32., dist_fun=lambda x: np.max(x), overflow=10e38): 102 | # compute the number of bits occupied by the original architecture 103 | sizes = [v.size for v in vars] 104 | nb_weights = float(np.sum(sizes)) 105 | IN_BITS = in_precision * nb_weights 106 | # prune architecture 107 | vars = [compress_matrix(v) for v in vars] 108 | sizes = [v.size for v in vars] 109 | # compute the per-layer bit requirements 110 | significant_bits = [float_precisions(v, dist_fun, layer=k + 1) for k, v in enumerate(vars)] 111 | exponent_bit = np.ceil(np.log2(np.log2(overflow) + 1.) + 1.) 112 | total_bits = [1. + exponent_bit + sb for sb in significant_bits] 113 | OUT_BITS = np.sum(np.asarray(sizes) * np.asarray(total_bits)) 114 | return nb_weights / np.sum(sizes), IN_BITS / OUT_BITS, significant_bits, exponent_bit 115 | 116 | 117 | def compute_compression_rate(layers, masks): 118 | # reduce architecture 119 | weight_mus, weight_vars = extract_pruned_params(layers, masks) 120 | # compute overflow level based on maximum weight 121 | overflow = np.max([np.max(np.abs(w)) for w in weight_mus]) 122 | # compute compression rate 123 | CR_architecture, CR_fast_inference, _, _ = _compute_compression_rate(weight_vars, dist_fun=lambda x: np.mean(x), overflow=overflow) 124 | print("Compressing the architecture will decrease the model size by a factor of %.1f." % (CR_architecture)) 125 | print("Making use of weight uncertainty can reduce the model size by a factor of %.1f."
% (CR_fast_inference)) 126 | 127 | 128 | def compute_reduced_weights(layers, masks): 129 | weight_mus, weight_vars = extract_pruned_params(layers, masks) 130 | overflow = np.max([np.max(np.abs(w)) for w in weight_mus]) 131 | _, _, significant_bits, exponent_bits = _compute_compression_rate(weight_vars, dist_fun=lambda x: np.mean(x), overflow=overflow) 132 | weights = [fast_inference_weights(weight_mu, exponent_bits, significant_bit) for weight_mu, significant_bit in 133 | zip(weight_mus, significant_bits)] 134 | return weights -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from torchvision import datasets, transforms 9 | 10 | from compression import compute_compression_rate, compute_reduced_weights 11 | from Bayesian import BayesianModule 12 | 13 | N = 60000. # number of data points in the training set 14 | 15 | 16 | def main(FLAGS): 17 | # import data 18 | kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {} 19 | 20 | if FLAGS.dataset == "cifar10": 21 | proj_dst = datasets.CIFAR10 22 | num_classes = 10 23 | elif FLAGS.dataset == "cifar100": 24 | proj_dst = datasets.CIFAR100 25 | num_classes = 100 26 | elif FLAGS.dataset == "mnist": 27 | proj_dst = datasets.MNIST 28 | num_classes = 10 29 | # build both loaders from the dataset selected via --dataset (proj_dst) instead of hard-coding MNIST 30 | train_loader = torch.utils.data.DataLoader( 31 | proj_dst('../data', train=True, download=True, 32 | transform=transforms.Compose([ 33 | transforms.ToTensor(), 34 | lambda x: 2 * (x - 0.5), 35 | ])), 36 | batch_size=FLAGS.batchsize, shuffle=True, **kwargs) 37 | 38 | test_loader = torch.utils.data.DataLoader( 39 | proj_dst('../data', train=False, transform=transforms.Compose([ 40 | transforms.ToTensor(), 41 | lambda x: 2 * (x - 0.5), 42 | ])), 43 | batch_size=FLAGS.batchsize, shuffle=True, **kwargs) 44 | 45 | if FLAGS.dataset.startswith("cifar"): 46 | if FLAGS.nettype == "lenet": 47 | model = BayesianModule.LeNet_Cifar(num_classes) 48 | elif FLAGS.nettype == "mlp": 49 | model = BayesianModule.MLP_Cifar(num_classes) 50 | elif FLAGS.dataset == "mnist": 51 | if FLAGS.nettype == "lenet": 52 | model = BayesianModule.LeNet_MNIST(num_classes) 53 | elif FLAGS.nettype == "mlp": 54 | model = BayesianModule.MLP_MNIST(num_classes) 55 | 56 | print(FLAGS.dataset, FLAGS.nettype) 57 | if FLAGS.cuda: 58 | model.cuda() 59 | 60 | # init optimizer 61 | optimizer = optim.Adam(model.parameters()) 62 | 63 | # we optimize the variational lower bound scaled by the number of data 64 | # points (so we can keep our intuitions about hyper-params such as the learning rate) 65 | discrimination_loss = nn.functional.cross_entropy 66 | 67 | class objection(object): 68 | def __init__(self, N, use_cuda=True): 69 | self.d_loss = nn.functional.cross_entropy 70 | self.N = N 71 | self.use_cuda = use_cuda 72 | 73 | def __call__(self, output, target, kl_divergence): 74 | d_error = self.d_loss(output, target) 75 | variational_bound = d_error + kl_divergence / self.N # TODO: why divide by N?
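            # Note (added): a sketch of the reasoning behind the TODO above. d_error is the
            # cross-entropy averaged over the minibatch, i.e. a per-data-point estimate of the
            # negative log-likelihood, so the KL term is divided by N to put it on the same
            # per-data-point scale. The sum is then the negative variational lower bound per
            # example, which keeps learning-rate intuitions unchanged, as the comment above
            # the class notes.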
76 | if self.use_cuda: 77 | variational_bound = variational_bound.cuda() 78 | return variational_bound 79 | 80 | objective = objection(len(train_loader.dataset)) 81 | 82 | from trainer import Trainer 83 | trainer = Trainer(model, train_loader, test_loader, optimizer, objective) 84 | # train the model and save some visualisations along the way 85 | for epoch in range(1, FLAGS.epochs + 1): 86 | trainer.train(epoch) 87 | trainer.test() 88 | 89 | # compute compression rate and new model accuracy 90 | layers = model.layers 91 | thresholds = FLAGS.thresholds 92 | compute_compression_rate(layers, model.get_masks(thresholds)) 93 | 94 | print("Test error after compression with reduced bit precision:") 95 | 96 | weights = compute_reduced_weights(layers, model.get_masks(thresholds)) 97 | for layer, weight in zip(layers, weights): 98 | if FLAGS.cuda: 99 | layer.post_weight_mu.data = torch.Tensor(weight).cuda() 100 | else: 101 | layer.post_weight_mu.data = torch.Tensor(weight) 102 | 103 | for layer in layers: 104 | layer.deterministic = True 105 | trainer.test() 106 | 107 | 108 | if __name__ == '__main__': 109 | import argparse 110 | 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('--epochs', type=int, default=50) 113 | parser.add_argument('--batchsize', type=int, default=256) 114 | parser.add_argument('--thresholds', type=float, nargs='*', default=[-2.8, -3., -5., ]) 115 | 116 | parser.add_argument('--dataset', type=str, choices=["cifar10", "cifar100", "mnist"], default="mnist") 117 | parser.add_argument('--nettype', type=str, choices=["mlp", "lenet"], default="mlp") 118 | 119 | FLAGS = parser.parse_args() 120 | FLAGS.cuda = torch.cuda.is_available() # check if we can put the net on the GPU 121 | 122 | main(FLAGS) 123 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from torch.autograd import Variable 10 | 11 | 12 | class Trainer(): 13 | def __init__(self, model, train_loader, test_loader, optimizer, criterion, 14 | use_cuda=True): 15 | self.model = model 16 | self.train_loader = train_loader 17 | self.test_loader = test_loader 18 | self.optimizer = optimizer 19 | self.criterion = criterion 20 | self.use_cuda = use_cuda 21 | 22 | def train(self, epoch): 23 | self.model.train() 24 | for batch_idx, (data, target) in enumerate(self.train_loader): 25 | if self.use_cuda: 26 | data, target = data.cuda(), target.cuda() 27 | data, target = Variable(data), Variable(target) 28 | self.optimizer.zero_grad() 29 | output = self.model(data) 30 | loss = self.criterion(output, target, self.model.kl_divergence()) 31 | loss.backward() 32 | self.optimizer.step() 33 | # clip the variances after each step 34 | for layer in self.model.kl_list: 35 | layer.clip_variances() 36 | print('Epoch: {} \tTrain loss: {:.6f} \t'.format(epoch, loss.data[0])) 37 | 38 | def test(self): 39 | self.model.eval() 40 | test_loss = 0 41 | correct = 0 42 | for data, target in self.test_loader: 43 | if self.use_cuda: 44 | data, target = data.cuda(), target.cuda() 45 | data, target = Variable(data, volatile=True), Variable(target) 46 | output = self.model(data) 47 | test_loss += F.cross_entropy(output, target, size_average=False).data[0] 48 | pred = output.data.max(1, keepdim=True)[1] 49 |
correct += pred.eq(target.data.view_as(pred)).cpu().sum() 50 | test_loss /= len(self.test_loader.dataset) 51 | print('Test loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 52 | test_loss, correct, len(self.test_loader.dataset), 53 | 100. * correct / len(self.test_loader.dataset))) 54 | 55 | --------------------------------------------------------------------------------