├── scalablebdl ├── __init__.py ├── low_rank │ ├── __init__.py │ ├── converter.py │ ├── linear.py │ └── conv.py ├── empirical │ ├── __init__.py │ ├── converter.py │ ├── linear.py │ ├── prelu.py │ ├── conv.py │ └── batchnorm.py ├── implicit │ ├── __init__.py │ ├── converter.py │ ├── linear.py │ ├── prelu.py │ ├── conv.py │ └── batchnorm.py ├── mean_field │ ├── __init__.py │ ├── utils.py │ ├── prelu.py │ ├── converter.py │ ├── linear.py │ ├── batchnorm.py │ └── conv.py ├── prior_reg.py └── bnn_utils.py ├── mnist_ckpts └── best.ckpt ├── .gitignore ├── laplace ├── curvature │ ├── __init__.py │ ├── backpack.py │ ├── asdl.py │ └── curvature.py ├── utils │ ├── __init__.py │ ├── swag.py │ ├── feature_extractor.py │ └── utils.py ├── __init__.py ├── laplace.py ├── subnetlaplace.py ├── lllaplace.py └── marglik_training.py ├── pytorch_cifar_models ├── __init__.py ├── vgg.py ├── resnet.py └── shufflenetv2.py ├── README.md └── data.py /scalablebdl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mnist_ckpts/best.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/ELLA/HEAD/mnist_ckpts/best.ckpt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data 3 | unused 4 | *.pyc 5 | CIFAR-10-C 6 | CIFAR-100-C 7 | logs 8 | log.txt 9 | pdfs 10 | plot* 11 | main_bk.py 12 | -------------------------------------------------------------------------------- /scalablebdl/low_rank/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import BayesLinearLR 2 | from .conv import BayesConv2dLR 3 | from .converter import to_deterministic, to_bayesian 4 | 
# ---- /scalablebdl/empirical/__init__.py ----
from .linear import BayesLinearEMP
from .batchnorm import BayesBatchNorm2dEMP
from .conv import BayesConv2dEMP
from .prelu import BayesPReLUEMP
from .converter import to_deterministic, to_bayesian

# ---- /scalablebdl/implicit/__init__.py ----
from .linear import BayesLinearIMP
from .batchnorm import BayesBatchNorm2dIMP
from .conv import BayesConv2dIMP
from .prelu import BayesPReLUIMP
from .converter import to_deterministic, to_bayesian

# ---- /scalablebdl/mean_field/__init__.py ----
from .linear import BayesLinearMF
from .batchnorm import BayesBatchNorm2dMF
from .conv import BayesConv2dMF
from .prelu import BayesPReLUMF
from .utils import MulExpAddFunction
from .converter import to_deterministic, to_bayesian

# ---- /scalablebdl/mean_field/utils.py ----
import torch


class MulExpAddFunction(torch.autograd.Function):
    """Fused, in-place ``input * exp(psi) + mu`` with a hand-written backward.

    ``forward`` overwrites ``input`` (it is marked dirty), so callers should
    pass a scratch tensor they no longer need, e.g. freshly drawn noise.
    No gradient is returned for ``input``.
    """

    @staticmethod
    def forward(ctx, input, psi, mu):
        ctx.mark_dirty(input)
        scaled = input.mul_(psi.exp())
        result = scaled.add_(mu)
        # Keep only ``mu`` and the result: ``input * exp(psi)`` is recovered
        # in backward as ``result - mu`` without storing a second tensor.
        ctx.save_for_backward(mu, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        mu, result = ctx.saved_tensors
        # d result / d psi = input * exp(psi) = result - mu; d result / d mu = 1.
        # Reduction over dim 0 assumes ``input`` has one extra leading (batch)
        # dimension compared with ``psi``/``mu`` -- TODO confirm for all callers.
        grad_mu = grad_output.sum(0)
        grad_psi = grad_output.mul(result - mu).sum(0)
        return None, grad_psi, grad_mu
# ---- /laplace/curvature/__init__.py ----
import logging

from laplace.curvature.curvature import CurvatureInterface, GGNInterface, EFInterface

# ``__all__`` only lists names that were actually imported. A star-import
# (``from laplace.curvature import *``) raises AttributeError for any name in
# ``__all__`` that the module does not define, which is what happens when an
# optional backend (backpack / asdfghjkl) is not installed.
__all__ = ['CurvatureInterface', 'GGNInterface', 'EFInterface']

try:
    from laplace.curvature.backpack import BackPackGGN, BackPackEF, BackPackInterface
except ModuleNotFoundError:
    logging.info('Backpack not available.')
else:
    __all__ += ['BackPackInterface', 'BackPackGGN', 'BackPackEF']

try:
    from laplace.curvature.asdl import AsdlHessian, AsdlGGN, AsdlEF, AsdlInterface
except ModuleNotFoundError:
    logging.info('asdfghjkl backend not available.')
else:
    __all__ += ['AsdlInterface', 'AsdlGGN', 'AsdlEF', 'AsdlHessian']

# ---- /laplace/utils/__init__.py ----
from laplace.utils.utils import get_nll, validate, parameters_per_layer, invsqrt_precision, _is_batchnorm, _is_valid_scalar, kron, diagonal_add_scalar, symeig, block_diag, expand_prior_precision, normal_samples
from laplace.utils.feature_extractor import FeatureExtractor
from laplace.utils.matrix import Kron, KronDecomposed
from laplace.utils.swag import fit_diagonal_swag_var
from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask


__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
           'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision',
           'FeatureExtractor',
           'Kron', 'KronDecomposed',
           'fit_diagonal_swag_var',
           'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
           'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']

# ---- /laplace/__init__.py ----
"""
.. include:: ../README.md

.. include:: ../examples/regression_example.md
.. include:: ../examples/calibration_example.md
"""
REGRESSION = 'regression'
CLASSIFICATION = 'classification'

from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace
from laplace.laplace import Laplace
from laplace.marglik_training import marglik_training

__all__ = ['Laplace',  # direct access to all Laplace classes via unified interface
           'BaseLaplace', 'ParametricLaplace',  # base-class and its (first-level) subclasses
           'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace',  # all-weights
           'LLLaplace',  # base-class last-layer
           'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace',  # last-layer
           'SubnetLaplace',  # base-class subnetwork
           'FullSubnetLaplace', 'DiagSubnetLaplace',  # subnetwork
           'marglik_training']  # methods

# ---- /laplace/laplace.py ----
from laplace.baselaplace import ParametricLaplace
# Star-import the package so every ParametricLaplace subclass is loaded before
# ``_all_subclasses`` walks the class hierarchy.
from laplace import *


def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure='kron',
            *args, **kwargs):
    """Simplified Laplace access using strings instead of different classes.

    Parameters
    ----------
    model : torch.nn.Module
    likelihood : {'classification', 'regression'}
    subset_of_weights : {'last_layer', 'subnetwork', 'all'}, default='last_layer'
        subset of weights to consider for inference
    hessian_structure : {'diag', 'kron', 'full', 'lowrank'}, default='kron'
        structure of the Hessian approximation
    *args, **kwargs
        forwarded unchanged to the constructor of the selected class

    Returns
    -------
    laplace : ParametricLaplace
        chosen subclass of ParametricLaplace instantiated with additional arguments

    Raises
    ------
    ValueError
        if ``subset_of_weights='subnetwork'`` is combined with a Hessian
        structure other than ``'full'`` or ``'diag'``.
    KeyError
        if no loaded ParametricLaplace subclass declares the pair
        ``(subset_of_weights, hessian_structure)`` as its ``_key``.
    """
    if subset_of_weights == 'subnetwork' and hessian_structure not in ['full', 'diag']:
        raise ValueError('Subnetwork Laplace requires a full or diagonal Hessian approximation!')

    # Classes opt in by defining ``_key = (subset_of_weights, hessian_structure)``.
    laplace_map = {subclass._key: subclass for subclass in _all_subclasses(ParametricLaplace)
                   if hasattr(subclass, '_key')}
    laplace_class = laplace_map[(subset_of_weights, hessian_structure)]
    return laplace_class(model, likelihood, *args, **kwargs)


def _all_subclasses(cls):
    """Return the set of all direct and indirect subclasses of *cls*."""
    return set(cls.__subclasses__()).union(
        [s for c in cls.__subclasses__() for s in _all_subclasses(c)])

# ---- /pytorch_cifar_models/__init__.py ----
from .resnet import cifar10_resnet20
from .resnet import cifar10_resnet32
from .resnet import cifar10_resnet44
from .resnet import cifar10_resnet56

from .resnet import cifar100_resnet20
from .resnet import cifar100_resnet32
from .resnet import cifar100_resnet44
from .resnet import cifar100_resnet56

from .vgg import cifar10_vgg11_bn
from .vgg import cifar10_vgg13_bn
from .vgg import cifar10_vgg16_bn
from .vgg import cifar10_vgg19_bn

from .vgg import cifar100_vgg11_bn
from .vgg import cifar100_vgg13_bn
from .vgg import cifar100_vgg16_bn
from .vgg import cifar100_vgg19_bn

# NOTE(review): the repository tree at the top of this dump lists only
# resnet.py, vgg.py and shufflenetv2.py under pytorch_cifar_models/; the
# mobilenetv2 / repvgg / vit imports below fail unless those modules exist -- verify.
from .mobilenetv2 import cifar10_mobilenetv2_x0_5
from .mobilenetv2 import cifar10_mobilenetv2_x0_75
from .mobilenetv2 import cifar10_mobilenetv2_x1_0
from .mobilenetv2 import cifar10_mobilenetv2_x1_4

from .mobilenetv2 import cifar100_mobilenetv2_x0_5
from .mobilenetv2 import cifar100_mobilenetv2_x0_75
from .mobilenetv2 import cifar100_mobilenetv2_x1_0
from .mobilenetv2 import cifar100_mobilenetv2_x1_4

from .shufflenetv2 import cifar10_shufflenetv2_x0_5
from .shufflenetv2 import cifar10_shufflenetv2_x1_0
from .shufflenetv2 import cifar10_shufflenetv2_x1_5
from .shufflenetv2 import cifar10_shufflenetv2_x2_0

from .shufflenetv2 import cifar100_shufflenetv2_x0_5
from .shufflenetv2 import cifar100_shufflenetv2_x1_0
from .shufflenetv2 import cifar100_shufflenetv2_x1_5
from .shufflenetv2 import cifar100_shufflenetv2_x2_0

from .repvgg import cifar10_repvgg_a0
from .repvgg import cifar10_repvgg_a1
from .repvgg import cifar10_repvgg_a2

from .repvgg import cifar100_repvgg_a0
from .repvgg import cifar100_repvgg_a1
from .repvgg import cifar100_repvgg_a2

from .vit import cifar10_vit_b16
from .vit import cifar10_vit_b32
from .vit import cifar10_vit_l16
from .vit import cifar10_vit_l32
from .vit import cifar10_vit_h14

from .vit import cifar100_vit_b16
from .vit import cifar100_vit_b32
from .vit import cifar100_vit_l16
from .vit import cifar100_vit_l32
from .vit import cifar100_vit_h14

__version__ = "0.1.0-alpha"

# ---- /scalablebdl/empirical/converter.py ----
# refer to https://github.com/Harry24k/pytorch-custom-utils/blob/master/torchhk/transform.py
import copy
import warnings

import torch
import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU

from . import BayesLinearEMP, BayesConv2dEMP, BayesBatchNorm2dEMP, BayesPReLUEMP


def to_bayesian(input, num_mc_samples=20):
    """Recursively swap supported layers of *input* for empirical Bayesian ones.

    ``Linear``, ``Conv2d``, ``BatchNorm2d`` and ``PReLU`` leaves are replaced by
    their ``*EMP`` counterparts. Their ``weights``/``biases`` hold
    ``num_mc_samples`` stacked copies of the original weight/bias. For any other
    module the children are converted recursively and replaced in place via
    ``add_module``, and that same (mutated) module is returned; no copy of
    the network is made.

    BatchNorm running statistics are shared with, not copied from, the
    original layer.

    Args:
        input: layer or network to convert.
        num_mc_samples: number of weight copies per converted layer.

    Returns:
        The converted module.
    """
    if isinstance(input, (Linear, Conv2d, BatchNorm2d, PReLU)):
        if isinstance(input, Linear):
            output = BayesLinearEMP(input.in_features, input.out_features,
                                    input.bias, num_mc_samples=num_mc_samples)
        elif isinstance(input, Conv2d):
            output = BayesConv2dEMP(input.in_channels, input.out_channels,
                                    input.kernel_size, input.stride,
                                    input.padding, input.dilation,
                                    input.groups, input.bias,
                                    num_mc_samples=num_mc_samples)
        elif isinstance(input, PReLU):
            output = BayesPReLUEMP(input.num_parameters, num_mc_samples=num_mc_samples)
        else:
            output = BayesBatchNorm2dEMP(input.num_features, input.eps,
                                         input.momentum, input.affine,
                                         input.track_running_stats,
                                         num_mc_samples=num_mc_samples)
            # Same tensor objects: statistics stay shared with the original layer.
            output.running_mean = input.running_mean
            output.running_var = input.running_var
            output.num_batches_tracked = input.num_batches_tracked

        # Every MC sample starts from the deterministic weights.
        if input.weight is not None:
            output.weights.data = input.weight.unsqueeze(0).repeat(
                num_mc_samples, *([1,] * input.weight.dim())).data
        if getattr(input, 'bias', None) is not None:
            output.biases.data = input.bias.unsqueeze(0).repeat(
                num_mc_samples, *([1,] * input.bias.dim())).data
        return output

    for name, module in input.named_children():
        input.add_module(name, to_bayesian(module, num_mc_samples))
    return input


def to_deterministic(input):
    """Not supported: an empirical BNN cannot be collapsed back into a DNN.

    Raises:
        NotImplementedError: always.
    """
    # Raise explicitly instead of ``assert False``: asserts are stripped under
    # ``python -O``, which made this function silently return ``None``.
    raise NotImplementedError("Cannot convert an empirical BNN into DNN")

# ---- /scalablebdl/low_rank/converter.py ----
# refer to
https://github.com/Harry24k/pytorch-custom-utils/blob/master/torchhk/transform.py 2 | import copy 3 | import warnings 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d 8 | 9 | from . import BayesLinearLR, BayesConv2dLR 10 | 11 | def to_bayesian(input, rank=1, num_mc_samples=20, pert_init_std=0.2): 12 | 13 | if isinstance(input, (Linear, Conv2d)): 14 | if isinstance(input, (Linear)): 15 | output = BayesLinearLR(input.in_features, input.out_features, 16 | input.bias, num_mc_samples=num_mc_samples, 17 | rank=rank, pert_init_std=pert_init_std) 18 | elif isinstance(input, (Conv2d)): 19 | output = BayesConv2dLR(input.in_channels, input.out_channels, 20 | input.kernel_size, input.stride, 21 | input.padding, input.dilation, 22 | input.groups, input.bias, 23 | num_mc_samples=num_mc_samples, 24 | rank=rank, pert_init_std=pert_init_std) 25 | 26 | if input.weight is not None: 27 | with torch.no_grad(): 28 | output.weight_mu = input.weight 29 | 30 | if hasattr(input, 'bias') and input.bias is not None: 31 | with torch.no_grad(): 32 | output.bias = input.bias 33 | del input 34 | return output 35 | 36 | output = input 37 | for name, module in input.named_children(): 38 | output.add_module(name, to_bayesian(module, rank, num_mc_samples, pert_init_std)) 39 | del input 40 | return output 41 | 42 | def to_deterministic(input): 43 | 44 | if isinstance(input, (BayesLinearLR, BayesConv2dLR)): 45 | if isinstance(input, (BayesLinearLR)): 46 | output = Linear(input.in_features, input.out_features, input.bias) 47 | elif isinstance(input, (BayesConv2dLR)): 48 | output = Conv2d(input.in_channels, input.out_channels, 49 | input.kernel_size, input.stride, 50 | input.padding, input.dilation, 51 | input.groups, input.bias) 52 | 53 | with torch.no_grad(): 54 | output.weight = input.weight_mu 55 | output.bias = input.bias 56 | del input 57 | return output 58 | 59 | output = input 60 | for name, module in input.named_children(): 61 | output.add_module(name, 
to_deterministic(module)) 62 | del input 63 | return output 64 | -------------------------------------------------------------------------------- /scalablebdl/empirical/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import Module, Parameter 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | class BayesLinearEMP(Module): 9 | r""" 10 | Applies Bayesian Linear 11 | """ 12 | __constants__ = ['bias', 'in_features', 'out_features', 'num_mc_samples'] 13 | 14 | def __init__(self, in_features, out_features, bias=True, num_mc_samples=20): 15 | super(BayesLinearEMP, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.mc_sample_id = None 19 | self.num_mc_samples = num_mc_samples 20 | 21 | self.parallel_eval = False 22 | 23 | self.weights = Parameter(torch.Tensor(num_mc_samples, out_features, in_features)) 24 | 25 | if bias is None or bias is False: 26 | self.bias = False 27 | else: 28 | self.bias = True 29 | 30 | if self.bias: 31 | self.biases = Parameter(torch.Tensor(num_mc_samples, out_features)) 32 | else: 33 | self.register_parameter('biases', None) 34 | 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | stdv = 1. / math.sqrt(self.weights.size(2)) 39 | for i in range(self.num_mc_samples): 40 | self.weights[i].data.uniform_(-stdv, stdv) 41 | if self.bias: 42 | self.biases[i].data.uniform_(-stdv, stdv) 43 | 44 | def forward(self, input): 45 | r""" 46 | Overriden. 
47 | """ 48 | if self.parallel_eval: 49 | if input.dim() == 2: 50 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1) 51 | out = torch.bmm(self.weights, input.permute(1, 2, 0)).permute(2, 0, 1) 52 | if self.bias: 53 | out = out + self.biases 54 | elif isinstance(self.mc_sample_id, int): 55 | weight = self.weights[self.mc_sample_id % self.num_mc_samples] 56 | bias = self.biases[self.mc_sample_id % self.num_mc_samples] if self.bias else None 57 | out = F.linear(input, weight, bias) 58 | else: 59 | bs = input.shape[0] 60 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 61 | weight = self.weights[idx] 62 | out = torch.bmm(weight, input.unsqueeze(2)).squeeze() 63 | if self.bias: 64 | bias = self.biases[idx] 65 | out = out + bias 66 | return out 67 | 68 | def extra_repr(self): 69 | r""" 70 | Overriden. 71 | """ 72 | return 'in_features={}, out_features={}, bias={}, num_mc_samples={}'.format( 73 | self.in_features, self.out_features, self.bias is not None, self.num_mc_samples) 74 | -------------------------------------------------------------------------------- /scalablebdl/prior_reg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class PriorRegularizor: 6 | def __init__(self, model, decay, num_data, num_mc_samples, posterior_type, MOPED): 7 | super(PriorRegularizor, self).__init__() 8 | 9 | self.model = model 10 | self.decay = decay 11 | self.num_data = num_data 12 | self.num_mc_samples = num_mc_samples 13 | self.posterior_type = posterior_type 14 | self.MOPED = MOPED 15 | 16 | if MOPED: 17 | self.init_mus = {} 18 | for name, param in model.named_parameters(): 19 | if '_mu' in name: 20 | self.init_mus[name] = param.data.clone().detach() 21 | 22 | @torch.no_grad() 23 | def get_param_by_name(self, name): 24 | o = self.model 25 | for i in name.split('.'): 26 | o = getattr(o, i) 27 | return o 28 | 29 | @torch.no_grad() 30 | def step(self,): 31 | for name, param in 
self.model.named_parameters(): 32 | if self.posterior_type == "mf" or self.posterior_type == "mean_field": 33 | if '_psi' in name: 34 | param.grad.data.add_((param*2).exp(), alpha=self.decay).sub_(1./self.num_data) 35 | else: 36 | if self.MOPED: 37 | param.grad.data.add_(param - self.init_mus[name], alpha=self.decay) 38 | else: 39 | param.grad.data.add_(param, alpha=self.decay) 40 | elif self.posterior_type == "emp" or self.posterior_type == "empirical": 41 | if 'weights' in name or 'biases' in name: 42 | param.grad.data.add_(param, alpha=self.decay/self.num_mc_samples) 43 | else: 44 | param.grad.data.add_(param, alpha=self.decay) 45 | elif self.posterior_type == "lr" or self.posterior_type == "low_rank": 46 | if 'weight_mu' in name: 47 | b = self.get_param_by_name(name.replace("weight_mu", "in_perturbations")) 48 | a = self.get_param_by_name(name.replace("weight_mu", "out_perturbations")) 49 | m_ = torch.bmm(a, b) 50 | param.grad.data.add_((m_**2).mean(0).view_as(param) * param, alpha=self.decay) 51 | m_ = m_.mul_((param**2).flatten(1, -1)) 52 | a.grad.data.add_(torch.bmm(m_, b.permute(0, 2, 1)), alpha=self.decay/self.num_mc_samples) 53 | b.grad.data.add_(torch.bmm(a.permute(0, 2, 1), m_), alpha=self.decay/self.num_mc_samples) 54 | elif not 'perturbations' in name: 55 | param.grad.data.add_(param, alpha=self.decay) 56 | elif self.posterior_type is None: 57 | param.grad.data.add_(param, alpha=self.decay) 58 | else: 59 | raise NotImplementedError 60 | -------------------------------------------------------------------------------- /scalablebdl/implicit/converter.py: -------------------------------------------------------------------------------- 1 | # refer to https://github.com/Harry24k/pytorch-custom-utils/blob/master/torchhk/transform.py 2 | import copy 3 | import warnings 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm2d 8 | 9 | from . 
import BayesLinearIMP, BayesConv2dIMP, BayesBatchNorm2dIMP 10 | 11 | def to_bayesian(input, num_mc_samples=20, is_residual=False): 12 | return _to_bayesian(copy.deepcopy(input), num_mc_samples, is_residual) 13 | 14 | def _to_bayesian(input, num_mc_samples=20, is_residual=False): 15 | 16 | if isinstance(input, (Linear, Conv2d, BatchNorm2d)): 17 | if isinstance(input, (Linear)): 18 | output = BayesLinearIMP(input.in_features, input.out_features, 19 | input.bias, num_mc_samples=num_mc_samples) 20 | elif isinstance(input, (Conv2d)): 21 | output = BayesConv2dIMP(input.in_channels, input.out_channels, 22 | input.kernel_size, input.stride, 23 | input.padding, input.dilation, 24 | input.groups, input.bias, 25 | num_mc_samples=num_mc_samples) 26 | else: 27 | output = BayesBatchNorm2dIMP(input.num_features, input.eps, 28 | input.momentum, input.affine, 29 | input.track_running_stats, 30 | num_mc_samples=num_mc_samples) 31 | setattr(output, 'running_mean', getattr(input, 'running_mean')) 32 | setattr(output, 'running_var', getattr(input, 'running_var')) 33 | setattr(output, 'num_batches_tracked', getattr(input, 'num_batches_tracked')) 34 | 35 | if input.weight is not None: 36 | if is_residual: 37 | if isinstance(input, (Conv2d)): 38 | output.weights.data = torch.eye(output.weight_mu.data.size( 39 | 0)).unsqueeze(2).unsqueeze(3).float().unsqueeze(0).repeat( 40 | num_mc_samples, 1, 1, 1, 1).data 41 | elif isinstance(input, BatchNorm2d): 42 | output.weights.data.fill_(1.) 
43 | else: 44 | output.weights.data = input.weight.unsqueeze(0).repeat( 45 | num_mc_samples, *([1,]*input.weight.dim())).data 46 | if input.bias is not None: 47 | if is_residual: 48 | output.biases.data.zero_() 49 | else: 50 | output.biases.data = input.bias.unsqueeze(0).repeat( 51 | num_mc_samples, *([1,]*input.bias.dim())).data 52 | 53 | return output 54 | else: 55 | for name, module in input.named_children(): 56 | setattr(input, name, _to_bayesian(module, num_mc_samples, is_residual)) 57 | return input 58 | 59 | def to_deterministic(input): 60 | assert False, "Cannot convert an empirical BNN into DNN" 61 | -------------------------------------------------------------------------------- /scalablebdl/implicit/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import Module, Parameter 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | class BayesLinearIMP(Module): 9 | r""" 10 | Applies Bayesian Linear 11 | """ 12 | __constants__ = ['bias', 'in_features', 'out_features'] 13 | 14 | def __init__(self, in_features, out_features, bias=True, 15 | deterministic=False): 16 | super(BayesLinearIMP, self).__init__() 17 | self.in_features = in_features 18 | self.out_features = out_features 19 | self.deterministic = deterministic 20 | 21 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) 22 | self.weight_psi = Parameter(torch.Tensor(out_features, in_features)) 23 | 24 | if bias is None or bias is False: 25 | self.bias = False 26 | else: 27 | self.bias = True 28 | 29 | if self.bias: 30 | self.bias_mu = Parameter(torch.Tensor(out_features)) 31 | self.bias_psi = Parameter(torch.Tensor(out_features)) 32 | else: 33 | self.register_parameter('bias_mu', None) 34 | self.register_parameter('bias_psi', None) 35 | 36 | self.reset_parameters() 37 | 38 | self.weight_size = list(self.weight_mu.shape) 39 | self.bias_size = list(self.bias_mu.shape) if 
self.bias else None 40 | self.mul_exp_add = MulExpAddFunction.apply 41 | 42 | def reset_parameters(self): 43 | stdv = 1. / math.sqrt(self.weight_mu.size(1)) 44 | self.weight_mu.data.uniform_(-stdv, stdv) 45 | self.weight_psi.data.uniform_(-6, -5) 46 | if self.bias: 47 | self.bias_mu.data.uniform_(-stdv, stdv) 48 | self.bias_psi.data.uniform_(-6, -5) 49 | 50 | def forward(self, input): 51 | r""" 52 | Overriden. 53 | """ 54 | if self.deterministic: 55 | weight = self.weight_mu 56 | bias = self.bias_mu if self.bias else None 57 | out = F.linear(input, weight, bias) 58 | else: 59 | bs = input.shape[0] 60 | weight = self.mul_exp_add(torch.randn(bs, *self.weight_size, 61 | device=input.device, 62 | dtype=input.dtype), 63 | self.weight_psi, self.weight_mu) 64 | 65 | out = torch.bmm(weight, input.unsqueeze(2)).squeeze() 66 | if self.bias: 67 | bias = self.mul_exp_add(torch.randn(bs, *self.bias_size, 68 | device=input.device, 69 | dtype=input.dtype), 70 | self.bias_psi, self.bias_mu) 71 | out = out + bias 72 | return out 73 | 74 | def extra_repr(self): 75 | r""" 76 | Overriden. 77 | """ 78 | return 'in_features={}, out_features={}, bias={}'.format( 79 | self.in_features, self.out_features, self.bias is not None) 80 | -------------------------------------------------------------------------------- /scalablebdl/implicit/prelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Module, Parameter 4 | import torch.nn.functional as F 5 | 6 | class BayesPReLUIMP(Module): 7 | r"""Applies the element-wise function: 8 | 9 | .. math:: 10 | \text{PReLU}(x) = \max(0,x) + a * \min(0,x) 11 | 12 | or 13 | 14 | .. math:: 15 | \text{PReLU}(x) = 16 | \begin{cases} 17 | x, & \text{ if } x \geq 0 \\ 18 | ax, & \text{ otherwise } 19 | \end{cases} 20 | 21 | Here :math:`a` is a learnable parameter. 
When called without arguments, `nn.PReLU()` uses a single 22 | parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, 23 | a separate :math:`a` is used for each input channel. 24 | 25 | 26 | .. note:: 27 | weight decay should not be used when learning :math:`a` for good performance. 28 | 29 | .. note:: 30 | Channel dim is the 2nd dim of input. When input has dims < 2, then there is 31 | no channel dim and the number of channels = 1. 32 | 33 | Args: 34 | num_parameters (int): number of :math:`a` to learn. 35 | Although it takes an int as input, there is only two values are legitimate: 36 | 1, or the number of channels at input. Default: 1 37 | init (float): the initial value of :math:`a`. Default: 0.25 38 | 39 | Shape: 40 | - Input: :math:`(N, *)` where `*` means, any number of additional 41 | dimensions 42 | - Output: :math:`(N, *)`, same shape as the input 43 | 44 | Attributes: 45 | weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). 46 | 47 | .. 
image:: ../scripts/activation_images/PReLU.png 48 | 49 | Examples:: 50 | 51 | >>> m = nn.PReLU() 52 | >>> input = torch.randn(2) 53 | >>> output = m(input) 54 | """ 55 | __constants__ = ['num_parameters', 'num_mc_samples'] 56 | num_parameters: int 57 | num_mc_samples: int 58 | 59 | def __init__(self, num_parameters: int = 1, num_mc_samples: int = 20, init: float = 0.25) -> None: 60 | super(BayesPReLUIMP, self).__init__() 61 | self.num_parameters = num_parameters 62 | self.mc_sample_id = None 63 | self.num_mc_samples = num_mc_samples 64 | self.parallel_eval = False 65 | 66 | self.weights = Parameter(torch.Tensor(num_mc_samples, num_parameters).fill_(init)) 67 | 68 | def forward(self, input: Tensor) -> Tensor: 69 | assert False 70 | if self.parallel_eval: 71 | return torch.maximum(input, 0) + self.weights[None, :, :, None, None] * torch.minimum(input, 0) 72 | elif isinstance(self.mc_sample_id, int): 73 | self.mc_sample_id = self.mc_sample_id % self.num_mc_samples 74 | weight = self.weights[self.mc_sample_id] 75 | return F.prelu(input, weight) 76 | else: 77 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 78 | weight = self.weights[idx] 79 | return torch.maximum(input, 0) + weight[:, :, None, None] * torch.minimum(input, 0) 80 | 81 | def extra_repr(self) -> str: 82 | return 'num_parameters={}, num_mc_samples={}'.format(self.num_parameters, self.num_mc_samples) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Env 2 | ``` 3 | Python==3.8.10 4 | torch==1.11.0 5 | ``` 6 | 7 | # MNIST 8 | ``` 9 | python main.py --balanced --batch-size 500 --test-batch-size 500 --sigma 0.0001 \ 10 | --K 20 --M 1000 --dataset mnist --arch mnist_model \ 11 | --pretrained mnist_ckpts/best.ckpt --check --measure-speed 12 | ``` 13 | 14 | # CIFAR-10 15 | 16 | ### ELLA 17 | ``` 18 | python main.py --balanced --batch-size 500 
--test-batch-size 500 \ 19 | --sigma2 0.04 --K 20 --M 2000 --search-freq 1000 \ 20 | --arch [cifar10_resnet20, cifar10_resnet32, cifar10_resnet44, cifar10_resnet56] 21 | ``` 22 | 23 | ### LLA* 24 | ``` 25 | python laplace_baseline.py --batch-size 200 --test-batch-size 200 \ 26 | --subset-of-weights last_layer --hessian-structure full --job-id lastl-full \ 27 | --arch [cifar10_resnet20, cifar10_resnet32, cifar10_resnet44, cifar10_resnet56] 28 | ``` 29 | 30 | ### LLA*-KFAC 31 | ``` 32 | python laplace_baseline.py --batch-size 200 --test-batch-size 200 \ 33 | --subset-of-weights last_layer --hessian-structure kron --job-id lastl-kron \ 34 | --arch [cifar10_resnet20, cifar10_resnet32, cifar10_resnet44, cifar10_resnet56] 35 | ``` 36 | 37 | ### LLA-Diag 38 | ``` 39 | python laplace_baseline.py --batch-size 200 --test-batch-size 200 \ 40 | --subset-of-weights all --hessian-structure diag --job-id all-diag \ 41 | --arch [cifar10_resnet20, cifar10_resnet32, cifar10_resnet44, cifar10_resnet56] 42 | ``` 43 | 44 | ### LLA-KFAC 45 | ``` 46 | python laplace_baseline.py --batch-size 200 --test-batch-size 200 \ 47 | --subset-of-weights all --hessian-structure kron --job-id all-kron \ 48 | --arch [cifar10_resnet20, cifar10_resnet32, cifar10_resnet44, cifar10_resnet56] 49 | ``` 50 | 51 | ### MFVI-BF 52 | ``` 53 | python mfvi_baseline.py --batch-size 256 --test-batch-size 256 --dataset cifar10 \ 54 | --epochs 12 --lr 1e-3 --ft_lr 1e-4 --decay 0.0005 \ 55 | --arch [cifar10_resnet20, cifar10_resnet32, cifar10_resnet44, cifar10_resnet56] 56 | ``` 57 | 58 | 59 | # ImageNet 60 | 61 | ### ELLA 62 | ``` 63 | python main.py --balanced --batch-size 100 --test-batch-size 100 \ 64 | --sigma2 0.01 --K 20 --M 2000 --I 100 --search-freq 100 --dataset imagenet \ 65 | --arch [resnet18, resnet34, resnet50] 66 | ``` 67 | 68 | ### MFVI-BF 69 | ``` 70 | python mfvi_baseline.py --batch-size 128 --test-batch-size 256 --dataset imagenet \ 71 | --epochs 4 --lr 1e-3 --ft_lr 1e-4 --decay 0.0001 \ 72 | --arch 
[resnet18, resnet34, resnet50] 73 | ``` 74 | 75 | ### ELLA on ViT-B 76 | ``` 77 | python main.py --balanced --batch-size 100 --test-batch-size 100 \ 78 | --sigma2 0.00001 --K 20 --M 2000 --I 80 --search-freq 100 --dataset imagenet \ 79 | --arch vit_base_patch16_224 80 | ``` 81 | 82 | # BibTeX 83 | ``` 84 | @inproceedings{ 85 | deng2022accelerated, 86 | title={Accelerated Linearized Laplace Approximation for Bayesian Deep Learning}, 87 | author={Zhijie Deng and Feng Zhou and Jun Zhu}, 88 | booktitle={Thirty-Sixth Conference on Neural Information Processing Systems}, 89 | year={2022}, 90 | url={https://openreview.net/forum?id=jftNpltMgz} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /scalablebdl/empirical/prelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Module, Parameter 4 | import torch.nn.functional as F 5 | 6 | class BayesPReLUEMP(Module): 7 | r"""Applies the element-wise function: 8 | 9 | .. math:: 10 | \text{PReLU}(x) = \max(0,x) + a * \min(0,x) 11 | 12 | or 13 | 14 | .. math:: 15 | \text{PReLU}(x) = 16 | \begin{cases} 17 | x, & \text{ if } x \geq 0 \\ 18 | ax, & \text{ otherwise } 19 | \end{cases} 20 | 21 | Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single 22 | parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, 23 | a separate :math:`a` is used for each input channel. 24 | 25 | 26 | .. note:: 27 | weight decay should not be used when learning :math:`a` for good performance. 28 | 29 | .. note:: 30 | Channel dim is the 2nd dim of input. When input has dims < 2, then there is 31 | no channel dim and the number of channels = 1. 32 | 33 | Args: 34 | num_parameters (int): number of :math:`a` to learn. 
35 | Although it takes an int as input, there is only two values are legitimate: 36 | 1, or the number of channels at input. Default: 1 37 | init (float): the initial value of :math:`a`. Default: 0.25 38 | 39 | Shape: 40 | - Input: :math:`(N, *)` where `*` means, any number of additional 41 | dimensions 42 | - Output: :math:`(N, *)`, same shape as the input 43 | 44 | Attributes: 45 | weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). 46 | 47 | .. image:: ../scripts/activation_images/PReLU.png 48 | 49 | Examples:: 50 | 51 | >>> m = nn.PReLU() 52 | >>> input = torch.randn(2) 53 | >>> output = m(input) 54 | """ 55 | __constants__ = ['num_parameters', 'num_mc_samples'] 56 | num_parameters: int 57 | num_mc_samples: int 58 | 59 | def __init__(self, num_parameters: int = 1, num_mc_samples: int = 20, init: float = 0.25) -> None: 60 | super(BayesPReLUEMP, self).__init__() 61 | self.num_parameters = num_parameters 62 | self.mc_sample_id = None 63 | self.num_mc_samples = num_mc_samples 64 | self.parallel_eval = False 65 | 66 | self.weights = Parameter(torch.Tensor(num_mc_samples, num_parameters).fill_(init)) 67 | 68 | def forward(self, input: Tensor) -> Tensor: 69 | if self.parallel_eval: 70 | if input.dim() == 4: 71 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1) 72 | return torch.maximum(input, torch.tensor(0., device=input.device)) + self.weights[None, :, :, None, None] * torch.minimum(input, torch.tensor(0., device=input.device)) 73 | elif isinstance(self.mc_sample_id, int): 74 | self.mc_sample_id = self.mc_sample_id % self.num_mc_samples 75 | weight = self.weights[self.mc_sample_id] 76 | return F.prelu(input, weight) 77 | else: 78 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 79 | weight = self.weights[idx] 80 | return torch.maximum(input, torch.tensor(0., device=input.device)) + weight[:, :, None, None] * torch.minimum(input, torch.tensor(0., device=input.device)) 81 | 82 | def extra_repr(self) 
-> str: 83 | return 'num_parameters={}, num_mc_samples={}'.format(self.num_parameters, self.num_mc_samples) 84 | -------------------------------------------------------------------------------- /laplace/utils/swag.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | from torch.nn.utils import parameters_to_vector 5 | 6 | 7 | __all__ = ['fit_diagonal_swag_var'] 8 | 9 | 10 | def _param_vector(model): 11 | return parameters_to_vector(model.parameters()).detach() 12 | 13 | 14 | def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, 15 | lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30): 16 | """ 17 | Fit diagonal SWAG [1], which estimates marginal variances of model parameters by 18 | computing the first and second moment of SGD iterates with a large learning rate. 19 | 20 | Implementation partly adapted from: 21 | - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py 22 | - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py 23 | 24 | References 25 | ---------- 26 | [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. 27 | [*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476). 28 | NeurIPS 2019. 
29 | 30 | Parameters 31 | ---------- 32 | model : torch.nn.Module 33 | train_loader : torch.data.utils.DataLoader 34 | training data loader to use for snapshot collection 35 | criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss 36 | loss function to use for snapshot collection 37 | n_snapshots_total : int 38 | total number of model snapshots to collect 39 | snapshot_freq : int 40 | snapshot collection frequency (in epochs) 41 | lr : float 42 | SGD learning rate for collecting snapshots 43 | momentum : float 44 | SGD momentum 45 | weight_decay : float 46 | SGD weight decay 47 | min_var : float 48 | minimum parameter variance to clamp to (for numerical stability) 49 | 50 | Returns 51 | ------- 52 | param_variances : torch.Tensor 53 | vector of marginal variances for each model parameter 54 | """ 55 | 56 | # create a copy of the model to avoid undesired changes to the original model parameters 57 | _model = deepcopy(model) 58 | _model.train() 59 | device = next(_model.parameters()).device 60 | 61 | # initialize running estimates of first and second moment of model parameters 62 | mean = torch.zeros_like(_param_vector(_model)) 63 | sq_mean = torch.zeros_like(_param_vector(_model)) 64 | n_snapshots = 0 65 | 66 | # run SGD to collect model snapshots 67 | optimizer = torch.optim.SGD( 68 | _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay) 69 | n_epochs = snapshot_freq * n_snapshots_total 70 | for epoch in range(n_epochs): 71 | for inputs, targets in train_loader: 72 | inputs, targets = inputs.to(device), targets.to(device) 73 | optimizer.zero_grad() 74 | loss = criterion(_model(inputs), targets) 75 | loss.backward() 76 | optimizer.step() 77 | 78 | if epoch % snapshot_freq == 0: 79 | # update running estimates of first and second moment of model parameters 80 | old_fac, new_fac = n_snapshots / (n_snapshots + 1), 1 / (n_snapshots + 1) 81 | mean = mean * old_fac + _param_vector(_model) * new_fac 82 | sq_mean = sq_mean * old_fac + 
_param_vector(_model) ** 2 * new_fac 83 | n_snapshots += 1 84 | 85 | # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2 86 | param_variances = torch.clamp(sq_mean - mean ** 2, min_var) 87 | return param_variances 88 | -------------------------------------------------------------------------------- /scalablebdl/low_rank/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import Module, Parameter 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | class BayesLinearLR(Module): 9 | r""" 10 | Applies Bayesian Linear 11 | """ 12 | __constants__ = ['bias', 'in_features', 'out_features', 'num_mc_samples', 'rank'] 13 | 14 | def __init__(self, in_features, out_features, bias=True, num_mc_samples=20, rank=1, pert_init_std=0.2): 15 | super(BayesLinearLR, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.mc_sample_id = None 19 | self.deterministic = False 20 | self.num_mc_samples = num_mc_samples 21 | self.rank = rank 22 | self.pert_init_std = pert_init_std 23 | 24 | self.parallel_eval = False 25 | 26 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) 27 | self.in_perturbations = Parameter(torch.Tensor(num_mc_samples, rank, in_features)) 28 | self.out_perturbations = Parameter(torch.Tensor(num_mc_samples, out_features, rank)) 29 | 30 | if bias is None or bias is False: 31 | self.bias = Parameter(torch.Tensor(out_features)) 32 | else: 33 | self.register_parameter('bias', None) 34 | 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | stdv = 1. 
/ math.sqrt(self.weight_mu.size(1)) 39 | self.weight_mu.data.uniform_(-stdv, stdv) 40 | if self.bias is not None: 41 | self.bias.data.uniform_(-stdv, stdv) 42 | m = math.sqrt(1./self.rank) 43 | v = math.sqrt((math.sqrt(self.rank*(self.pert_init_std**2)+1) - 1)/self.rank) 44 | self.in_perturbations.data.normal_(m, v) 45 | self.out_perturbations.data.normal_(m, v) 46 | 47 | def forward(self, input): 48 | r""" 49 | Overriden. 50 | """ 51 | if self.deterministic: 52 | out = F.linear(input, self.weight_mu, self.bias) 53 | elif self.parallel_eval: 54 | if input.dim() == 2: 55 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1) 56 | perturbations = torch.bmm(self.out_perturbations, self.in_perturbations) 57 | weight = perturbations.mul_(self.weight_mu) 58 | out = torch.bmm(weight, input.permute(1, 2, 0)).permute(2, 0, 1) 59 | if self.bias is not None: 60 | out = out + self.bias 61 | elif isinstance(self.mc_sample_id, int): 62 | self.mc_sample_id %= self.num_mc_samples 63 | perturbations = torch.matmul(self.out_perturbations[self.mc_sample_id], 64 | self.in_perturbations[self.mc_sample_id]) 65 | weight = perturbations.mul_(self.weight_mu) 66 | out = F.linear(input, weight, self.bias) 67 | else: 68 | bs = input.shape[0] 69 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 70 | perturbations = torch.bmm(self.out_perturbations[idx], self.in_perturbations[idx]) 71 | weight = perturbations.mul_(self.weight_mu) 72 | out = torch.bmm(weight, input.unsqueeze(2)).squeeze() 73 | if self.bias is not None: 74 | out = out + self.bias 75 | return out 76 | 77 | def extra_repr(self): 78 | r""" 79 | Overriden. 
80 | """ 81 | return 'in_features={}, out_features={}, bias={}, num_mc_samples={}, rank={}'.format( 82 | self.in_features, self.out_features, self.bias is not None, self.num_mc_samples, self.rank) 83 | -------------------------------------------------------------------------------- /scalablebdl/mean_field/prelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Module, Parameter 4 | import torch.nn.functional as F 5 | 6 | class BayesPReLUMF(Module): 7 | r"""Applies the element-wise function: 8 | 9 | .. math:: 10 | \text{PReLU}(x) = \max(0,x) + a * \min(0,x) 11 | 12 | or 13 | 14 | .. math:: 15 | \text{PReLU}(x) = 16 | \begin{cases} 17 | x, & \text{ if } x \geq 0 \\ 18 | ax, & \text{ otherwise } 19 | \end{cases} 20 | 21 | Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single 22 | parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, 23 | a separate :math:`a` is used for each input channel. 24 | 25 | 26 | .. note:: 27 | weight decay should not be used when learning :math:`a` for good performance. 28 | 29 | .. note:: 30 | Channel dim is the 2nd dim of input. When input has dims < 2, then there is 31 | no channel dim and the number of channels = 1. 32 | 33 | Args: 34 | num_parameters (int): number of :math:`a` to learn. 35 | Although it takes an int as input, there is only two values are legitimate: 36 | 1, or the number of channels at input. Default: 1 37 | init (float): the initial value of :math:`a`. Default: 0.25 38 | 39 | Shape: 40 | - Input: :math:`(N, *)` where `*` means, any number of additional 41 | dimensions 42 | - Output: :math:`(N, *)`, same shape as the input 43 | 44 | Attributes: 45 | weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). 46 | 47 | .. 
image:: ../scripts/activation_images/PReLU.png 48 | 49 | Examples:: 50 | 51 | >>> m = nn.PReLU() 52 | >>> input = torch.randn(2) 53 | >>> output = m(input) 54 | """ 55 | __constants__ = ['num_parameters', 'num_mc_samples'] 56 | num_parameters: int 57 | num_mc_samples: int 58 | 59 | def __init__(self, num_parameters: int = 1, num_mc_samples: int = 20, init: float = 0.25, deterministic: bool = False) -> None: 60 | super(BayesPReLUMF, self).__init__() 61 | self.num_parameters = num_parameters 62 | self.deterministic = deterministic 63 | self.num_mc_samples = num_mc_samples 64 | self.parallel_eval = False 65 | 66 | self.weight_mu = Parameter(torch.Tensor(num_parameters).fill_(init)) 67 | self.weight_psi = Parameter(torch.Tensor(num_parameters).uniform_(-6, -5)) 68 | 69 | def forward(self, input: Tensor) -> Tensor: 70 | if self.deterministic: 71 | return F.prelu(input, self.weight_mu) 72 | elif not self.parallel_eval: 73 | weight = torch.randn(input.size(0), self.num_parameters, device=input.device, dtype=input.dtype) * self.weight_psi.exp() + self.weight_mu 74 | return torch.maximum(input, torch.tensor(0., device=input.device)) + weight[:, :, None, None] * torch.minimum(input, torch.tensor(0., device=input.device)) 75 | else: 76 | if input.dim() == 4: 77 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1) 78 | weight = torch.randn(self.num_mc_samples, self.num_parameters, device=input.device, dtype=input.dtype) * self.weight_psi.exp() + self.weight_mu 79 | return torch.maximum(input, torch.tensor(0., device=input.device)) + weight[None, :, :, None, None] * torch.minimum(input, torch.tensor(0., device=input.device)) 80 | 81 | def extra_repr(self) -> str: 82 | return 'num_parameters={}, num_mc_samples={}'.format(self.num_parameters, self.num_mc_samples) 83 | -------------------------------------------------------------------------------- /scalablebdl/mean_field/converter.py: -------------------------------------------------------------------------------- 
1 | # refer to https://github.com/Harry24k/pytorch-custom-utils/blob/master/torchhk/transform.py 2 | import copy 3 | import warnings 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU 8 | 9 | from . import BayesLinearMF, BayesConv2dMF, BayesBatchNorm2dMF, BayesPReLUMF 10 | 11 | def to_bayesian(input, psi_init_range=[-6, -5], num_mc_samples=20): 12 | 13 | if isinstance(input, (Linear, Conv2d, BatchNorm2d, PReLU)): 14 | if isinstance(input, (Linear)): 15 | output = BayesLinearMF(input.in_features, input.out_features, 16 | input.bias, num_mc_samples=num_mc_samples) 17 | elif isinstance(input, (Conv2d)): 18 | output = BayesConv2dMF(input.in_channels, input.out_channels, 19 | input.kernel_size, input.stride, 20 | input.padding, input.dilation, 21 | input.groups, input.bias, 22 | num_mc_samples=num_mc_samples) 23 | elif isinstance(input, (PReLU)): 24 | output = BayesPReLUMF(input.num_parameters, num_mc_samples=num_mc_samples) 25 | else: 26 | output = BayesBatchNorm2dMF(input.num_features, input.eps, 27 | input.momentum, input.affine, 28 | input.track_running_stats, 29 | num_mc_samples=num_mc_samples) 30 | output.running_mean = input.running_mean 31 | output.running_var = input.running_var 32 | output.num_batches_tracked = input.num_batches_tracked 33 | 34 | if input.weight is not None: 35 | with torch.no_grad(): 36 | output.weight_mu = input.weight 37 | 38 | if hasattr(input, 'bias') and input.bias is not None: 39 | with torch.no_grad(): 40 | output.bias_mu = input.bias 41 | 42 | if output.weight_psi is not None: 43 | output.weight_psi.data.uniform_(psi_init_range[0], psi_init_range[1]) 44 | if hasattr(output, 'bias_psi') and output.bias_psi is not None: 45 | output.bias_psi.data.uniform_(psi_init_range[0], psi_init_range[1]) 46 | del input 47 | return output 48 | 49 | output = input 50 | for name, module in input.named_children(): 51 | output.add_module(name, to_bayesian(module, psi_init_range, num_mc_samples)) 52 
| del input 53 | return output 54 | 55 | def to_deterministic(input): 56 | 57 | if isinstance(input, (BayesLinearMF, BayesConv2dMF, BayesBatchNorm2dMF)): 58 | if isinstance(input, (BayesLinearMF)): 59 | output = Linear(input.in_features, input.out_features, input.bias) 60 | elif isinstance(input, (BayesConv2dMF)): 61 | output = Conv2d(input.in_channels, input.out_channels, 62 | input.kernel_size, input.stride, 63 | input.padding, input.dilation, 64 | input.groups, input.bias) 65 | elif isinstance(input, (BayesPReLUMF)): 66 | output = PReLU(input.num_parameters) 67 | else: 68 | output = BatchNorm2d(input.num_features, input.eps, 69 | input.momentum, input.affine, 70 | input.track_running_stats) 71 | output.running_mean = input.running_mean 72 | output.running_var = input.running_var 73 | output.num_batches_tracked = input.num_batches_tracked 74 | 75 | with torch.no_grad(): 76 | if input.weight is not None: 77 | output.weight = input.weight_mu 78 | if hasattr(input, 'bias') and input.bias is not None: 79 | output.bias = input.bias_mu 80 | del input 81 | return output 82 | output = input 83 | for name, module in input.named_children(): 84 | output.add_module(name, to_deterministic(module)) 85 | del input 86 | return output 87 | -------------------------------------------------------------------------------- /scalablebdl/bnn_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .mean_field import BayesLinearMF, BayesConv2dMF, BayesBatchNorm2dMF, BayesPReLUMF 4 | from .empirical import BayesLinearEMP, BayesConv2dEMP, BayesBatchNorm2dEMP, BayesPReLUEMP 5 | from .low_rank import BayesLinearLR, BayesConv2dLR 6 | # from .implicit import BayesLinearIMP, BayesConv2dIMP, BayesBatchNorm2dIMP, BayesPReLUIMP 7 | 8 | # freeze and unfreeze work for mean-field and implicit posteriors 9 | def freeze(net): 10 | net.apply(_freeze) 11 | 12 | def _freeze(m): 13 | if isinstance(m, (BayesConv2dMF, 
BayesLinearMF, BayesBatchNorm2dMF, BayesPReLUMF)) \ 14 | or isinstance(m, (BayesConv2dLR, BayesLinearLR)): 15 | m.deterministic = True 16 | 17 | def unfreeze(net): 18 | net.apply(_unfreeze) 19 | 20 | def _unfreeze(m): 21 | if isinstance(m, (BayesConv2dMF, BayesLinearMF, BayesBatchNorm2dMF, BayesPReLUMF)) \ 22 | or isinstance(m, (BayesConv2dLR, BayesLinearLR)): 23 | m.deterministic = False 24 | 25 | # set_mc_sample_id only works for empirical posterior 26 | def set_mc_sample_id(net, num_mc_samples, mc_sample_id=None, batch_size=None): 27 | if mc_sample_id is None: 28 | mc_sample_id = np.random.randint(0, num_mc_samples, size=batch_size) 29 | else: 30 | if isinstance(mc_sample_id, int): 31 | assert mc_sample_id >= 0 and mc_sample_id < num_mc_samples, \ 32 | "Mc_sample_id must be in [0, num_mc_samples)" 33 | else: 34 | assert isinstance(mc_sample_id, np.ndarray) 35 | for m in net.modules(): 36 | if isinstance(m, (BayesConv2dEMP, BayesLinearEMP, BayesBatchNorm2dEMP, BayesPReLUEMP, 37 | BayesConv2dLR, BayesLinearLR)): 38 | m.mc_sample_id = mc_sample_id 39 | 40 | def disable_dropout(net): 41 | for m in net.modules(): 42 | if m.__class__.__name__.startswith('Dropout'): 43 | m.p = 0 44 | 45 | def parallel_eval(net): 46 | net.apply(_parallel_eval) 47 | 48 | def _parallel_eval(m): 49 | if isinstance(m, (BayesConv2dMF, BayesLinearMF, BayesBatchNorm2dMF, BayesPReLUMF)) \ 50 | or isinstance(m, (BayesConv2dLR, BayesLinearLR)) \ 51 | or isinstance(m, (BayesLinearEMP, BayesConv2dEMP, BayesBatchNorm2dEMP, BayesPReLUEMP)): 52 | m.parallel_eval = True 53 | 54 | def disable_parallel_eval(net): 55 | net.apply(_disable_parallel_eval) 56 | 57 | def _disable_parallel_eval(m): 58 | if isinstance(m, (BayesConv2dMF, BayesLinearMF, BayesBatchNorm2dMF, BayesPReLUMF)) \ 59 | or isinstance(m, (BayesConv2dLR, BayesLinearLR)) \ 60 | or isinstance(m, (BayesLinearEMP, BayesConv2dEMP, BayesBatchNorm2dEMP, BayesPReLUEMP)): 61 | m.parallel_eval = False 62 | 63 | def use_single_eps(net): 64 | 
net.apply(_use_single_eps) 65 | 66 | def _use_single_eps(m): 67 | if isinstance(m, (BayesConv2dMF, BayesLinearMF)): 68 | m.single_eps = True 69 | 70 | def use_local_reparam(net): 71 | net.apply(_use_local_reparam) 72 | 73 | def _use_local_reparam(m): 74 | if isinstance(m, (BayesConv2dMF, BayesLinearMF)): 75 | m.local_reparam = True 76 | 77 | def use_flipout(net): 78 | net.apply(_use_flipout) 79 | 80 | def _use_flipout(m): 81 | if isinstance(m, (BayesConv2dMF, BayesLinearMF)): 82 | m.flipout = True 83 | 84 | def Bayes_ensemble(loader, model, loss_metric=torch.nn.functional.cross_entropy, 85 | acc_metric=lambda arg1, arg2: (arg1.argmax(-1)==arg2).float().mean()): 86 | model.eval() 87 | parallel_eval(model) 88 | with torch.no_grad(): 89 | total_loss, total_acc = 0, 0 90 | for i, (input, target) in enumerate(loader): 91 | input = input.cuda(non_blocking=True) 92 | target = target.cuda(non_blocking=True) 93 | outputs = model(input).softmax(-1) 94 | output = outputs.mean(-2) 95 | total_loss += loss_metric(output.log(), target).item() 96 | total_acc += acc_metric(output, target).item() 97 | total_loss /= len(loader) 98 | total_acc /= len(loader) 99 | disable_parallel_eval(model) 100 | return total_loss, total_acc 101 | -------------------------------------------------------------------------------- /scalablebdl/implicit/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.init as init 6 | from torch.nn import Module, Parameter 7 | import torch.nn.functional as F 8 | from torch.nn.modules.utils import _pair 9 | 10 | class _BayesConvNdIMP(Module): 11 | r""" 12 | Applies Bayesian Convolution 13 | """ 14 | __constants__ = ['stride', 'padding', 'dilation', 15 | 'groups', 'bias', 'in_channels', 16 | 'out_channels', 'kernel_size'] 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, stride, 19 | padding, dilation, groups, bias): 20 | 
super(_BayesConvNdIMP, self).__init__() 21 | if in_channels % groups != 0: 22 | raise ValueError('in_channels must be divisible by groups') 23 | if out_channels % groups != 0: 24 | raise ValueError('out_channels must be divisible by groups') 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.kernel_size = kernel_size 28 | if isinstance(self.kernel_size, int): 29 | self.kernel_size = (self.kernel_size, self.kernel_size) 30 | self.stride = stride 31 | self.padding = padding 32 | self.dilation = dilation 33 | self.groups = groups 34 | 35 | self.weight_mu = Parameter(torch.Tensor( 36 | out_channels, in_channels // groups, *self.kernel_size)) 37 | self.weight_psi = Parameter(torch.Tensor( 38 | out_channels, in_channels // groups, *self.kernel_size)) 39 | 40 | if bias is None or bias is False : 41 | self.bias = False 42 | else: 43 | self.bias = True 44 | 45 | if self.bias: 46 | self.bias_mu = Parameter(torch.Tensor(out_channels)) 47 | self.bias_psi = Parameter(torch.Tensor(out_channels)) 48 | else: 49 | self.register_parameter('bias_mu', None) 50 | self.register_parameter('bias_psi', None) 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | n = self.in_channels 55 | n *= np.prod(list(self.kernel_size)) 56 | stdv = 1.0 / math.sqrt(n) 57 | self.weight_mu.data.uniform_(-stdv, stdv) 58 | self.weight_psi.data.uniform_(-6, -5) 59 | 60 | if self.bias : 61 | self.bias_mu.data.uniform_(-stdv, stdv) 62 | self.bias_psi.data.uniform_(-6, -5) 63 | 64 | def extra_repr(self): 65 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 66 | ', stride={stride}') 67 | s += ', padding={padding}' 68 | s += ', dilation={dilation}' 69 | s += ', groups={groups}' 70 | s += ', bias=False' 71 | return s.format(**self.__dict__) 72 | 73 | def __setstate__(self, state): 74 | super(_BayesConvNdMF, self).__setstate__(state) 75 | 76 | class BayesConv2dIMP(_BayesConvNdIMP): 77 | r""" 78 | Applies Bayesian Convolution for 2D inputs 79 | """ 80 
| def __init__(self, in_channels, out_channels, kernel_size, stride=1, 81 | padding=0, dilation=1, groups=1, bias=False, deterministic=False): 82 | super(BayesConv2dIMP, self).__init__( 83 | in_channels, out_channels, kernel_size, stride, 84 | padding, dilation, groups, bias) 85 | self.deterministic = deterministic 86 | self.weight_size = list(self.weight_mu.shape) 87 | self.bias_size = list(self.bias_mu.shape) if self.bias else None 88 | self.mul_exp_add = MulExpAddFunction.apply 89 | 90 | def forward(self, input): 91 | r""" 92 | Overriden. 93 | """ 94 | if self.deterministic: 95 | out = F.conv2d(input, weight=self.weight_mu, bias=self.bias_mu, 96 | stride=self.stride, dilation=self.dilation, 97 | groups=self.groups, padding=self.padding) 98 | else: 99 | bs = input.shape[0] 100 | weight = self.mul_exp_add(torch.randn(bs, *self.weight_size, 101 | device=input.device, 102 | dtype=input.dtype), 103 | self.weight_psi, self.weight_mu).view( 104 | bs*self.weight_size[0], *self.weight_size[1:]) 105 | out = F.conv2d(input.view(1, -1, input.shape[2], input.shape[3]), 106 | weight=weight, bias=None, 107 | stride=self.stride, dilation=self.dilation, 108 | groups=self.groups*bs, padding=self.padding) 109 | out = out.view(bs, self.out_channels, out.shape[2], out.shape[3]) 110 | 111 | if self.bias: 112 | bias = self.mul_exp_add(torch.randn(bs, *self.bias_size, 113 | device=input.device, 114 | dtype=input.dtype), 115 | self.bias_psi, self.bias_mu) 116 | out = out + bias[:, :, None, None] 117 | return out 118 | -------------------------------------------------------------------------------- /scalablebdl/empirical/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.init as init 6 | from torch.nn import Module, Parameter 7 | import torch.nn.functional as F 8 | from torch.nn.modules.utils import _pair 9 | 10 | class _BayesConvNdEMP(Module): 11 | r""" 12 | Applies 
Bayesian Convolution 13 | """ 14 | __constants__ = ['stride', 'padding', 'dilation', 15 | 'groups', 'bias', 'in_channels', 16 | 'out_channels', 'kernel_size', 'num_mc_samples'] 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, stride, 19 | padding, dilation, groups, bias, num_mc_samples): 20 | super(_BayesConvNdEMP, self).__init__() 21 | if in_channels % groups != 0: 22 | raise ValueError('in_channels must be divisible by groups') 23 | if out_channels % groups != 0: 24 | raise ValueError('out_channels must be divisible by groups') 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.kernel_size = kernel_size 28 | if isinstance(self.kernel_size, int): 29 | self.kernel_size = (self.kernel_size, self.kernel_size) 30 | self.stride = stride 31 | self.padding = padding 32 | self.dilation = dilation 33 | self.groups = groups 34 | self.mc_sample_id = None 35 | self.num_mc_samples = num_mc_samples 36 | 37 | self.weights = Parameter(torch.Tensor( 38 | num_mc_samples, out_channels, in_channels // groups, *self.kernel_size)) 39 | self.weight_size = list(self.weights.shape)[1:] 40 | 41 | if bias is None or bias is False: 42 | self.bias = False 43 | self.register_parameter('biases', None) 44 | else: 45 | self.bias = True 46 | self.biases = Parameter(torch.Tensor(num_mc_samples, out_channels)) 47 | self.reset_parameters() 48 | 49 | def reset_parameters(self): 50 | n = self.in_channels 51 | n *= np.prod(list(self.kernel_size)) 52 | stdv = 1.0 / math.sqrt(n) 53 | for i in range(self.num_mc_samples): 54 | self.weights[i].data.uniform_(-stdv, stdv) 55 | if self.bias : 56 | self.biases[i].data.uniform_(-stdv, stdv) 57 | 58 | def extra_repr(self): 59 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 60 | ', stride={stride}') 61 | s += ', padding={padding}' 62 | s += ', dilation={dilation}' 63 | s += ', groups={groups}' 64 | s += ', bias={bias}' 65 | s += ', num_mc_samples={num_mc_samples}' 66 | return 
s.format(**self.__dict__) 67 | 68 | def __setstate__(self, state): 69 | super(_BayesConvNdEMP, self).__setstate__(state) 70 | 71 | class BayesConv2dEMP(_BayesConvNdEMP): 72 | r""" 73 | Applies Bayesian Convolution for 2D inputs 74 | """ 75 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 76 | padding=0, dilation=1, groups=1, bias=False, num_mc_samples=20): 77 | super(BayesConv2dEMP, self).__init__( 78 | in_channels, out_channels, kernel_size, stride, 79 | padding, dilation, groups, bias, num_mc_samples) 80 | self.parallel_eval = False 81 | 82 | def forward(self, input): 83 | r""" 84 | Overriden. 85 | """ 86 | if self.parallel_eval: 87 | if input.dim() == 4: 88 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1) 89 | 90 | out = F.conv2d(input.flatten(start_dim=1, end_dim=2), 91 | weight=self.weights.flatten(0, 1), bias=None, 92 | stride=self.stride, dilation=self.dilation, 93 | groups=self.groups*self.num_mc_samples, 94 | padding=self.padding) 95 | out = out.view(out.shape[0], self.num_mc_samples, 96 | self.out_channels, out.shape[2], out.shape[3]) 97 | if self.bias: 98 | out = out + self.biases[None, :, :, None, None] 99 | elif isinstance(self.mc_sample_id, int): 100 | out = F.conv2d(input, weight=self.weights[self.mc_sample_id % self.num_mc_samples], 101 | bias=self.biases[self.mc_sample_id % self.num_mc_samples] if self.bias else None, 102 | stride=self.stride, dilation=self.dilation, 103 | groups=self.groups, padding=self.padding) 104 | else: 105 | bs = input.shape[0] 106 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 107 | weight = self.weights[idx].view(bs*self.weight_size[0], *self.weight_size[1:]) 108 | out = F.conv2d(input.view(1, -1, input.shape[2], input.shape[3]), 109 | weight=weight, bias=None, 110 | stride=self.stride, dilation=self.dilation, 111 | groups=self.groups*bs, padding=self.padding) 112 | out = out.view(bs, self.out_channels, out.shape[2], out.shape[3]) 113 | 114 | if 
self.bias: 115 | bias = self.biases[idx] 116 | out = out + bias[:, :, None, None] 117 | return out 118 | -------------------------------------------------------------------------------- /laplace/utils/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple, Callable, Optional 4 | 5 | 6 | __all__ = ['FeatureExtractor'] 7 | 8 | 9 | class FeatureExtractor(nn.Module): 10 | """Feature extractor for a PyTorch neural network. 11 | A wrapper which can return the output of the penultimate layer in addition to 12 | the output of the last layer for each forward pass. If the name of the last 13 | layer is not known, it can determine it automatically. It assumes that the 14 | last layer is linear and that for every forward pass the last layer is the same. 15 | If the name of the last layer is known, it can be passed as a parameter at 16 | initilization; this is the safest way to use this class. 17 | Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76. 18 | 19 | Parameters 20 | ---------- 21 | model : torch.nn.Module 22 | PyTorch model 23 | last_layer_name : str, default=None 24 | if the name of the last layer is already known, otherwise it will 25 | be determined automatically. 26 | """ 27 | def __init__(self, model: nn.Module, last_layer_name: Optional[str] = None) -> None: 28 | super().__init__() 29 | self.model = model 30 | self._features = dict() 31 | if last_layer_name is None: 32 | self.last_layer = None 33 | else: 34 | self.set_last_layer(last_layer_name) 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | """Forward pass. If the last layer is not known yet, it will be 38 | determined when this function is called for the first time. 
39 | 40 | Parameters 41 | ---------- 42 | x : torch.Tensor 43 | one batch of data to use as input for the forward pass 44 | """ 45 | if self.last_layer is None: 46 | # if this is the first forward pass and last layer is unknown 47 | out = self.find_last_layer(x) 48 | else: 49 | # if last and penultimate layers are already known 50 | out = self.model(x) 51 | return out 52 | 53 | def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 54 | """Forward pass which returns the output of the penultimate layer along 55 | with the output of the last layer. If the last layer is not known yet, 56 | it will be determined when this function is called for the first time. 57 | 58 | Parameters 59 | ---------- 60 | x : torch.Tensor 61 | one batch of data to use as input for the forward pass 62 | """ 63 | out = self.forward(x) 64 | features = self._features[self._last_layer_name] 65 | return out, features 66 | 67 | def set_last_layer(self, last_layer_name: str) -> None: 68 | """Set the last layer of the model by its name. This sets the forward 69 | hook to get the output of the penultimate layer. 70 | 71 | Parameters 72 | ---------- 73 | last_layer_name : str 74 | the name of the last layer (fixed in `model.named_modules()`). 
75 | """ 76 | # set last_layer attributes and check if it is linear 77 | self._last_layer_name = last_layer_name 78 | self.last_layer = dict(self.model.named_modules())[last_layer_name] 79 | if not isinstance(self.last_layer, nn.Linear): 80 | raise ValueError('Use model with a linear last layer.') 81 | 82 | # set forward hook to extract features in future forward passes 83 | self.last_layer.register_forward_hook(self._get_hook(last_layer_name)) 84 | 85 | def _get_hook(self, name: str) -> Callable: 86 | def hook(_, input, __): 87 | # only accepts one input (expects linear layer) 88 | self._features[name] = input[0].detach() 89 | return hook 90 | 91 | def find_last_layer(self, x: torch.Tensor) -> torch.Tensor: 92 | """Automatically determines the last layer of the model with one 93 | forward pass. It assumes that the last layer is the same for every 94 | forward pass and that it is an instance of `torch.nn.Linear`. 95 | Might not work with every architecture, but is tested with all PyTorch 96 | torchvision classification models (besides SqueezeNet, which has no 97 | linear last layer). 
98 | 99 | Parameters 100 | ---------- 101 | x : torch.Tensor 102 | one batch of data to use as input for the forward pass 103 | """ 104 | if self.last_layer is not None: 105 | raise ValueError('Last layer is already known.') 106 | 107 | act_out = dict() 108 | def get_act_hook(name): 109 | def act_hook(_, input, __): 110 | # only accepts one input (expects linear layer) 111 | try: 112 | act_out[name] = input[0].detach() 113 | except (IndexError, AttributeError): 114 | act_out[name] = None 115 | # remove hook 116 | handles[name].remove() 117 | return act_hook 118 | 119 | # set hooks for all modules 120 | handles = dict() 121 | for name, module in self.model.named_modules(): 122 | handles[name] = module.register_forward_hook(get_act_hook(name)) 123 | 124 | # check if model has more than one module 125 | # (there might be pathological exceptions) 126 | if len(handles) <= 2: 127 | raise ValueError('The model only has one module.') 128 | 129 | # forward pass to find execution order 130 | out = self.model(x) 131 | 132 | # find the last layer, store features, return output of forward pass 133 | keys = list(act_out.keys()) 134 | for key in reversed(keys): 135 | layer = dict(self.model.named_modules())[key] 136 | if len(list(layer.children())) == 0: 137 | self.set_last_layer(key) 138 | 139 | # save features from first forward pass 140 | self._features[key] = act_out[key] 141 | 142 | return out 143 | 144 | raise ValueError('Something went wrong (all modules have children).') 145 | -------------------------------------------------------------------------------- /scalablebdl/mean_field/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import Module, Parameter 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | from .utils import MulExpAddFunction 9 | 10 | class BayesLinearMF(Module): 11 | r""" 12 | Applies Bayesian Linear 13 | """ 14 | __constants__ = 
['bias', 'in_features', 'out_features'] 15 | 16 | def __init__(self, in_features, out_features, bias=True, 17 | deterministic=False, num_mc_samples=None): 18 | super(BayesLinearMF, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.deterministic = deterministic 22 | self.num_mc_samples = num_mc_samples 23 | self.parallel_eval = False 24 | 25 | self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) 26 | self.weight_psi = Parameter(torch.Tensor(out_features, in_features)) 27 | 28 | if bias is None or bias is False: 29 | self.bias = False 30 | else: 31 | self.bias = True 32 | 33 | if self.bias: 34 | self.bias_mu = Parameter(torch.Tensor(out_features)) 35 | self.bias_psi = Parameter(torch.Tensor(out_features)) 36 | else: 37 | self.register_parameter('bias_mu', None) 38 | self.register_parameter('bias_psi', None) 39 | 40 | self.reset_parameters() 41 | 42 | self.weight_size = list(self.weight_mu.shape) 43 | self.bias_size = list(self.bias_mu.shape) if self.bias else None 44 | self.mul_exp_add = MulExpAddFunction.apply 45 | 46 | self.local_reparam = False 47 | self.flipout = False 48 | self.single_eps = False 49 | 50 | def reset_parameters(self): 51 | stdv = 1. / math.sqrt(self.weight_mu.size(1)) 52 | self.weight_mu.data.uniform_(-stdv, stdv) 53 | self.weight_psi.data.uniform_(-6, -5) 54 | if self.bias: 55 | self.bias_mu.data.uniform_(-stdv, stdv) 56 | self.bias_psi.data.uniform_(-6, -5) 57 | 58 | def forward(self, input): 59 | r""" 60 | Overriden. 
61 | """ 62 | if self.deterministic: 63 | weight = self.weight_mu 64 | bias = self.bias_mu if self.bias else None 65 | out = F.linear(input, weight, bias) 66 | elif not self.parallel_eval: 67 | if self.single_eps: 68 | weight = torch.randn_like(self.weight_mu).mul_(self.weight_psi.exp()).add_(self.weight_mu) 69 | if self.bias: 70 | bias = torch.randn_like(self.bias_mu).mul_(self.bias_psi.exp()).add_(self.bias_mu) 71 | else: 72 | bias = None 73 | out = F.linear(input, weight, bias) 74 | elif self.local_reparam: 75 | act_mu = F.linear(input, self.weight_mu, None) 76 | act_var = F.linear(input**2, (self.weight_psi*2).exp_(), None) 77 | act_std = act_var.clamp(1e-8).sqrt_() 78 | out = torch.randn_like(act_mu).mul_(act_std).add_(act_mu) 79 | if self.bias: 80 | bias = torch.randn(input.shape[0], *self.bias_size, device=input.device, dtype=input.dtype).mul_(self.bias_psi.exp()).add_(self.bias_mu) 81 | out = out + bias 82 | elif self.flipout: 83 | outputs = F.linear(input, self.weight_mu, self.bias_mu if self.bias else None) 84 | # sampling perturbation signs 85 | sign_input = torch.empty_like(input).uniform_(-1, 1).sign() 86 | sign_output = torch.empty_like(outputs).uniform_(-1, 1).sign() 87 | # gettin perturbation weights 88 | delta_kernel = torch.randn_like(self.weight_psi).mul(self.weight_psi.exp()) 89 | delta_bias = torch.randn_like(self.bias_psi).mul(self.bias_psi.exp()) if self.bias else None 90 | # perturbed feedforward 91 | perturbed_outputs = F.linear(input * sign_input, delta_kernel, delta_bias) 92 | out = outputs + perturbed_outputs * sign_output 93 | else: 94 | bs = input.shape[0] 95 | weight = self.mul_exp_add(torch.empty(bs, *self.weight_size, 96 | device=input.device, 97 | dtype=input.dtype).normal_(0, 1), 98 | self.weight_psi, self.weight_mu) 99 | 100 | out = torch.bmm(weight, input.unsqueeze(2)).squeeze() 101 | if self.bias: 102 | bias = self.mul_exp_add(torch.empty(bs, *self.bias_size, 103 | device=input.device, 104 | dtype=input.dtype).normal_(0, 1), 
105 | self.bias_psi, self.bias_mu) 106 | out = out + bias 107 | else: 108 | if input.dim() == 2: 109 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1) 110 | weight = self.mul_exp_add(torch.empty(self.num_mc_samples, 111 | *self.weight_size, 112 | device=input.device, 113 | dtype=input.dtype).normal_(0, 1), 114 | self.weight_psi, self.weight_mu) 115 | out = torch.bmm(weight, input.permute(1, 2, 0)).permute(2, 0, 1) 116 | if self.bias: 117 | bias = self.mul_exp_add(torch.empty(self.num_mc_samples, 118 | *self.bias_size, 119 | device=input.device, 120 | dtype=input.dtype).normal_(0, 1), 121 | self.bias_psi, self.bias_mu) 122 | out = out + bias 123 | return out 124 | 125 | def extra_repr(self): 126 | r""" 127 | Overriden. 128 | """ 129 | return 'in_features={}, out_features={}, bias={}'.format( 130 | self.in_features, self.out_features, self.bias is not None) 131 | -------------------------------------------------------------------------------- /laplace/curvature/backpack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from backpack import backpack, extend, memory_cleanup 4 | from backpack.extensions import DiagGGNExact, DiagGGNMC, KFAC, KFLR, SumGradSquared, BatchGrad 5 | from backpack.context import CTX 6 | 7 | from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface 8 | from laplace.utils import Kron 9 | 10 | 11 | class BackPackInterface(CurvatureInterface): 12 | """Interface for Backpack backend. 13 | """ 14 | def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): 15 | super().__init__(model, likelihood, last_layer, subnetwork_indices) 16 | extend(self._model) 17 | extend(self.lossfunc) 18 | 19 | def jacobians(self, x): 20 | """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\) 21 | using backpack's BatchGrad per output dimension. 
22 | 23 | Parameters 24 | ---------- 25 | x : torch.Tensor 26 | input data `(batch, input_shape)` on compatible device with model. 27 | 28 | Returns 29 | ------- 30 | Js : torch.Tensor 31 | Jacobians `(batch, parameters, outputs)` 32 | f : torch.Tensor 33 | output function `(batch, outputs)` 34 | """ 35 | model = extend(self.model) 36 | to_stack = [] 37 | for i in range(model.output_size): 38 | model.zero_grad() 39 | out = model(x) 40 | with backpack(BatchGrad()): 41 | if model.output_size > 1: 42 | out[:, i].sum().backward() 43 | else: 44 | out.sum().backward() 45 | to_cat = [] 46 | for param in model.parameters(): 47 | to_cat.append(param.grad_batch.detach().reshape(x.shape[0], -1)) 48 | delattr(param, 'grad_batch') 49 | Jk = torch.cat(to_cat, dim=1) 50 | if self.subnetwork_indices is not None: 51 | Jk = Jk[:, self.subnetwork_indices] 52 | to_stack.append(Jk) 53 | if i == 0: 54 | f = out.detach() 55 | 56 | model.zero_grad() 57 | CTX.remove_hooks() 58 | _cleanup(model) 59 | if model.output_size > 1: 60 | return torch.stack(to_stack, dim=2).transpose(1, 2), f 61 | else: 62 | return Jk.unsqueeze(-1).transpose(1, 2), f 63 | 64 | def gradients(self, x, y): 65 | """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter 66 | \\(\\theta\\) using Backpack's BatchGrad. 67 | 68 | Parameters 69 | ---------- 70 | x : torch.Tensor 71 | input data `(batch, input_shape)` on compatible device with model. 
72 | y : torch.Tensor 73 | 74 | Returns 75 | ------- 76 | loss : torch.Tensor 77 | Gs : torch.Tensor 78 | gradients `(batch, parameters)` 79 | """ 80 | f = self.model(x) 81 | loss = self.lossfunc(f, y) 82 | with backpack(BatchGrad()): 83 | loss.backward() 84 | Gs = torch.cat([p.grad_batch.data.flatten(start_dim=1) 85 | for p in self._model.parameters()], dim=1) 86 | if self.subnetwork_indices is not None: 87 | Gs = Gs[:, self.subnetwork_indices] 88 | return Gs, loss 89 | 90 | 91 | class BackPackGGN(BackPackInterface, GGNInterface): 92 | """Implementation of the `GGNInterface` using Backpack. 93 | """ 94 | def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): 95 | super().__init__(model, likelihood, last_layer, subnetwork_indices) 96 | self.stochastic = stochastic 97 | 98 | def _get_diag_ggn(self): 99 | if self.stochastic: 100 | return torch.cat([p.diag_ggn_mc.data.flatten() for p in self._model.parameters()]) 101 | else: 102 | return torch.cat([p.diag_ggn_exact.data.flatten() for p in self._model.parameters()]) 103 | 104 | def _get_kron_factors(self): 105 | if self.stochastic: 106 | return Kron([p.kfac for p in self._model.parameters()]) 107 | else: 108 | return Kron([p.kflr for p in self._model.parameters()]) 109 | 110 | @staticmethod 111 | def _rescale_kron_factors(kron, M, N): 112 | # Renormalize Kronecker factor to sum up correctly over N data points with batches of M 113 | # for M=N (full-batch) just M/N=1 114 | for F in kron.kfacs: 115 | if len(F) == 2: 116 | F[1] *= M/N 117 | return kron 118 | 119 | def diag(self, X, y, **kwargs): 120 | context = DiagGGNMC if self.stochastic else DiagGGNExact 121 | f = self.model(X) 122 | loss = self.lossfunc(f, y) 123 | with backpack(context()): 124 | loss.backward() 125 | dggn = self._get_diag_ggn() 126 | if self.subnetwork_indices is not None: 127 | dggn = dggn[self.subnetwork_indices] 128 | 129 | return self.factor * loss.detach(), self.factor * dggn 130 | 131 | def 
kron(self, X, y, N, **kwargs) -> [torch.Tensor, Kron]: 132 | context = KFAC if self.stochastic else KFLR 133 | f = self.model(X) 134 | loss = self.lossfunc(f, y) 135 | with backpack(context()): 136 | loss.backward() 137 | kron = self._get_kron_factors() 138 | kron = self._rescale_kron_factors(kron, len(y), N) 139 | 140 | return self.factor * loss.detach(), self.factor * kron 141 | 142 | 143 | class BackPackEF(BackPackInterface, EFInterface): 144 | """Implementation of `EFInterface` using Backpack. 145 | """ 146 | 147 | def diag(self, X, y, **kwargs): 148 | f = self.model(X) 149 | loss = self.lossfunc(f, y) 150 | with backpack(SumGradSquared()): 151 | loss.backward() 152 | diag_EF = torch.cat([p.sum_grad_squared.data.flatten() 153 | for p in self._model.parameters()]) 154 | if self.subnetwork_indices is not None: 155 | diag_EF = diag_EF[self.subnetwork_indices] 156 | 157 | return self.factor * loss.detach(), self.factor * diag_EF 158 | 159 | def kron(self, X, y, **kwargs): 160 | raise NotImplementedError('Unavailable through Backpack.') 161 | 162 | 163 | def _cleanup(module): 164 | for child in module.children(): 165 | _cleanup(child) 166 | 167 | setattr(module, "_backpack_extend", False) 168 | memory_cleanup(module) 169 | -------------------------------------------------------------------------------- /scalablebdl/implicit/batchnorm.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import Module, Parameter 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | class _BayesBatchNormIMP(Module): 9 | r""" 10 | Applies Bayesian Batch Normalization over a 2D or 3D input 11 | """ 12 | __constants__ = ['track_running_stats', 13 | 'momentum', 'eps', 'weight', 'bias', 14 | 'running_mean', 'running_var', 'num_batches_tracked', 15 | 'num_features', 'affine'] 16 | 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 18 | 
track_running_stats=True, deterministic=False): 19 | super(_BayesBatchNormIMP, self).__init__() 20 | self.num_features = num_features 21 | self.eps = eps 22 | self.momentum = momentum 23 | self.affine = affine 24 | self.track_running_stats = track_running_stats 25 | self.deterministic = deterministic 26 | if self.affine: 27 | self.weight_mu = Parameter(torch.Tensor(num_features)) 28 | self.weight_psi = Parameter(torch.Tensor(num_features)) 29 | 30 | self.bias_mu = Parameter(torch.Tensor(num_features)) 31 | self.bias_psi = Parameter(torch.Tensor(num_features)) 32 | else: 33 | self.register_parameter('weight_mu', None) 34 | self.register_parameter('weight_psi', None) 35 | self.register_parameter('bias_mu', None) 36 | self.register_parameter('bias_psi', None) 37 | if self.track_running_stats: 38 | self.register_buffer('running_mean', torch.zeros(num_features)) 39 | self.register_buffer('running_var', torch.ones(num_features)) 40 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 41 | else: 42 | self.register_parameter('running_mean', None) 43 | self.register_parameter('running_var', None) 44 | self.register_parameter('num_batches_tracked', None) 45 | self.reset_parameters() 46 | 47 | self.weight_size = list(self.weight_mu.shape) if self.affine else None 48 | self.bias_size = list(self.bias_mu.shape) if self.affine else None 49 | self.mul_exp_add = MulExpAddFunction.apply 50 | 51 | def reset_running_stats(self): 52 | if self.track_running_stats: 53 | self.running_mean.zero_() 54 | self.running_var.fill_(1) 55 | self.num_batches_tracked.zero_() 56 | 57 | def reset_parameters(self): 58 | self.reset_running_stats() 59 | if self.affine: 60 | self.weight_mu.data.fill_(1) 61 | self.weight_psi.data.uniform_(-6, -5) 62 | self.bias_mu.data.zero_() 63 | self.bias_psi.data.uniform_(-6, -5) 64 | 65 | def _check_input_dim(self, input): 66 | raise NotImplementedError 67 | 68 | def forward(self, input): 69 | self._check_input_dim(input) 70 | 71 | if 
self.momentum is None: 72 | exponential_average_factor = 0.0 73 | else: 74 | exponential_average_factor = self.momentum 75 | 76 | if self.training and self.track_running_stats: 77 | if self.num_batches_tracked is not None: 78 | self.num_batches_tracked += 1 79 | if self.momentum is None: 80 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 81 | else: 82 | exponential_average_factor = self.momentum 83 | 84 | out = F.batch_norm( 85 | input, self.running_mean, self.running_var, None, None, 86 | self.training or not self.track_running_stats, 87 | exponential_average_factor, self.eps) 88 | 89 | if self.affine : 90 | if self.deterministic: 91 | weight = self.weight_mu.unsqueeze(0) 92 | bias = self.bias_mu.unsqueeze(0) 93 | else: 94 | bs = input.shape[0] 95 | weight = self.mul_exp_add(torch.randn(bs, *self.weight_size, 96 | device=input.device, 97 | dtype=input.dtype), 98 | self.weight_psi, self.weight_mu) 99 | 100 | bias = self.mul_exp_add(torch.randn(bs, *self.bias_size, 101 | device=input.device, 102 | dtype=input.dtype), 103 | self.bias_psi, self.bias_mu) 104 | 105 | if out.dim() == 4: 106 | out = torch.addcmul(bias[:, :, None, None], 107 | weight[:, :, None, None], out) 108 | elif out.dim() == 2: 109 | out = torch.addcmul(bias, weight, out) 110 | else: 111 | raise NotImplementedError 112 | return out 113 | 114 | def extra_repr(self): 115 | return '{num_features}, ' \ 116 | 'eps={eps}, momentum={momentum}, affine={affine}, ' \ 117 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 118 | 119 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 120 | missing_keys, unexpected_keys, error_msgs): 121 | version = local_metadata.get('version', None) 122 | 123 | if (version is None or version < 2) and self.track_running_stats: 124 | num_batches_tracked_key = prefix + 'num_batches_tracked' 125 | if num_batches_tracked_key not in state_dict: 126 | state_dict[num_batches_tracked_key] = torch.tensor(0, 
dtype=torch.long) 127 | 128 | super(_BayesBatchNormIMP, self)._load_from_state_dict( 129 | state_dict, prefix, local_metadata, strict, 130 | missing_keys, unexpected_keys, error_msgs) 131 | 132 | class BayesBatchNorm2dIMP(_BayesBatchNormIMP): 133 | r""" 134 | Applies Bayesian Batch Normalization over a 2D input 135 | """ 136 | def _check_input_dim(self, input): 137 | if input.dim() != 4: 138 | raise ValueError('expected 4D input (got {}D input)' 139 | .format(input.dim())) 140 | 141 | class BayesBatchNorm1dIMP(_BayesBatchNormIMP): 142 | def _check_input_dim(self, input): 143 | if input.dim() != 2 and input.dim() != 3: 144 | raise ValueError('expected 2D or 3D input (got {}D input)' 145 | .format(input.dim())) 146 | -------------------------------------------------------------------------------- /scalablebdl/low_rank/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.init as init 6 | from torch.nn import Module, Parameter 7 | import torch.nn.functional as F 8 | from torch.nn.modules.utils import _pair 9 | 10 | class _BayesConvNdLR(Module): 11 | r""" 12 | Applies Bayesian Convolution 13 | """ 14 | __constants__ = ['stride', 'padding', 'dilation', 15 | 'groups', 'bias', 'in_channels', 16 | 'out_channels', 'kernel_size', 'num_mc_samples', 'rank'] 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, stride, 19 | padding, dilation, groups, bias, num_mc_samples, rank, pert_init_std): 20 | super(_BayesConvNdLR, self).__init__() 21 | if in_channels % groups != 0: 22 | raise ValueError('in_channels must be divisible by groups') 23 | if out_channels % groups != 0: 24 | raise ValueError('out_channels must be divisible by groups') 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.kernel_size = kernel_size 28 | if isinstance(self.kernel_size, int): 29 | self.kernel_size = (self.kernel_size, self.kernel_size) 30 | 
self.stride = stride 31 | self.padding = padding 32 | self.dilation = dilation 33 | self.groups = groups 34 | self.mc_sample_id = None 35 | self.deterministic = False 36 | self.num_mc_samples = num_mc_samples 37 | self.rank = rank 38 | self.pert_init_std = pert_init_std 39 | 40 | self.weight_mu = Parameter(torch.Tensor(out_channels, 41 | in_channels // groups, *self.kernel_size)) 42 | self.in_perturbations = Parameter(torch.Tensor(num_mc_samples, rank, 43 | in_channels//groups*np.prod(list( self.kernel_size)))) 44 | self.out_perturbations = Parameter(torch.Tensor(num_mc_samples, 45 | out_channels, rank)) 46 | self.weight_size = list(self.weight_mu.shape) 47 | 48 | if bias is None or bias is False: 49 | self.register_parameter('bias', None) 50 | else: 51 | self.bias = Parameter(torch.Tensor(out_channels)) 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | n = self.in_channels 56 | n *= np.prod(list(self.kernel_size)) 57 | stdv = 1.0 / math.sqrt(n) 58 | self.weight_mu.data.uniform_(-stdv, stdv) 59 | if self.bias is not None: 60 | self.bias.data.uniform_(-stdv, stdv) 61 | m = math.sqrt(1./self.rank) 62 | v = math.sqrt((math.sqrt(self.rank*(self.pert_init_std**2)+1) - 1)/self.rank) 63 | self.in_perturbations.data.normal_(m, v) 64 | self.out_perturbations.data.normal_(m, v) 65 | 66 | def extra_repr(self): 67 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 68 | ', stride={stride}') 69 | s += ', padding={padding}' 70 | s += ', dilation={dilation}' 71 | s += ', groups={groups}' 72 | s += ', num_mc_samples={num_mc_samples}' 73 | s += ', rank={rank}' 74 | if self.bias is None: 75 | s += ', bias=False' 76 | return s.format(**self.__dict__) 77 | 78 | def __setstate__(self, state): 79 | super(_BayesConvNdLR, self).__setstate__(state) 80 | 81 | class BayesConv2dLR(_BayesConvNdLR): 82 | r""" 83 | Applies Bayesian Convolution for 2D inputs 84 | """ 85 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 86 | padding=0, 
dilation=1, groups=1, bias=False, 87 | num_mc_samples=20, rank=1, pert_init_std=0.2): 88 | super(BayesConv2dLR, self).__init__( 89 | in_channels, out_channels, kernel_size, stride, 90 | padding, dilation, groups, bias, num_mc_samples, rank, pert_init_std) 91 | self.parallel_eval = False 92 | 93 | def forward(self, input): 94 | r""" 95 | Overriden. 96 | """ 97 | if self.deterministic: 98 | out = F.conv2d(input, weight=self.weight_mu, 99 | bias=self.bias, 100 | stride=self.stride, dilation=self.dilation, 101 | groups=self.groups, padding=self.padding) 102 | elif self.parallel_eval: 103 | if input.dim() == 4: 104 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1) 105 | 106 | perturbations = torch.bmm(self.out_perturbations, self.in_perturbations) 107 | weight = perturbations.view(self.num_mc_samples, *self.weight_size).mul_(self.weight_mu) 108 | out = F.conv2d(input.flatten(start_dim=1, end_dim=2), 109 | weight=weight.flatten(0, 1), bias=None, 110 | stride=self.stride, dilation=self.dilation, 111 | groups=self.groups*self.num_mc_samples, 112 | padding=self.padding) 113 | out = out.view(out.shape[0], self.num_mc_samples, 114 | self.out_channels, out.shape[2], out.shape[3]) 115 | if self.bias is not None: 116 | out = out + self.bias[None, None, :, None, None] 117 | elif isinstance(self.mc_sample_id, int): 118 | self.mc_sample_id %= self.num_mc_samples 119 | perturbations = torch.matmul(self.out_perturbations[self.mc_sample_id], 120 | self.in_perturbations[self.mc_sample_id]) 121 | weight = perturbations.view(*self.weight_size).mul_(self.weight_mu) 122 | out = F.conv2d(input, weight=weight, 123 | bias=self.bias, 124 | stride=self.stride, dilation=self.dilation, 125 | groups=self.groups, padding=self.padding) 126 | else: 127 | bs = input.shape[0] 128 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 129 | perturbations = torch.bmm(self.out_perturbations[idx], self.in_perturbations[idx]) 130 | weight = perturbations.view(bs, 
*self.weight_size).mul_(self.weight_mu) 131 | out = F.conv2d(input.view(1, -1, input.shape[2], input.shape[3]), 132 | weight=weight.flatten(0, 1), bias=None, 133 | stride=self.stride, dilation=self.dilation, 134 | groups=self.groups*bs, padding=self.padding) 135 | out = out.view(bs, self.out_channels, out.shape[2], out.shape[3]) 136 | 137 | if self.bias is not None: 138 | out = out + bias[None, :, None, None] 139 | return out 140 | -------------------------------------------------------------------------------- /scalablebdl/empirical/batchnorm.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import Module, Parameter 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | class _BayesBatchNormEMP(Module): 9 | r""" 10 | Applies Bayesian Batch Normalization over a 2D or 3D input 11 | """ 12 | __constants__ = ['track_running_stats', 13 | 'momentum', 'eps', 'weight', 'bias', 14 | 'running_mean', 'running_var', 'num_batches_tracked', 15 | 'num_features', 'affine', 'num_mc_samples'] 16 | 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 18 | track_running_stats=True, num_mc_samples=20): 19 | super(_BayesBatchNormEMP, self).__init__() 20 | self.num_features = num_features 21 | self.eps = eps 22 | self.momentum = momentum 23 | self.affine = affine 24 | self.track_running_stats = track_running_stats 25 | self.mc_sample_id = None 26 | self.num_mc_samples = num_mc_samples 27 | self.parallel_eval = False 28 | if self.affine: 29 | self.weights = Parameter(torch.Tensor(num_mc_samples, num_features)) 30 | self.biases = Parameter(torch.Tensor(num_mc_samples, num_features)) 31 | else: 32 | self.register_parameter('weights', None) 33 | self.register_parameter('biases', None) 34 | if self.track_running_stats: 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | 
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 38 | else: 39 | self.register_parameter('running_mean', None) 40 | self.register_parameter('running_var', None) 41 | self.register_parameter('num_batches_tracked', None) 42 | self.reset_parameters() 43 | 44 | def reset_running_stats(self): 45 | if self.track_running_stats: 46 | self.running_mean.zero_() 47 | self.running_var.fill_(1) 48 | self.num_batches_tracked.zero_() 49 | 50 | def reset_parameters(self): 51 | self.reset_running_stats() 52 | if self.affine: 53 | self.weights.data.fill_(1) 54 | self.biases.data.zero_() 55 | 56 | def _check_input_dim(self, input): 57 | raise NotImplementedError 58 | 59 | def forward(self, input): 60 | self._check_input_dim(input) 61 | 62 | if self.momentum is None: 63 | exponential_average_factor = 0.0 64 | else: 65 | exponential_average_factor = self.momentum 66 | 67 | if self.training and self.track_running_stats: 68 | if self.num_batches_tracked is not None: 69 | self.num_batches_tracked += 1 70 | if self.momentum is None: 71 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 72 | else: 73 | exponential_average_factor = self.momentum 74 | 75 | if self.parallel_eval: 76 | if input.dim() == 4: 77 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1) 78 | elif input.dim() == 2: 79 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1) 80 | input = input.flatten(start_dim=0, end_dim=1) 81 | out = F.batch_norm( 82 | input, self.running_mean, self.running_var, None, None, 83 | self.training or not self.track_running_stats, 84 | exponential_average_factor, self.eps) 85 | 86 | if self.affine : 87 | if self.parallel_eval: 88 | if out.dim() == 4: 89 | out = out.view(-1, self.num_mc_samples, out.shape[1], 90 | out.shape[2], out.shape[3]) * self.weights[None, :, :, None, None] \ 91 | + self.biases[None, :, :, None, None] 92 | elif out.dim() == 2: 93 | out = out.view(-1, self.num_mc_samples, out.shape[1]) \ 94 | * 
self.weights[None, :, :] + self.biases[None, :, :] 95 | else: 96 | raise NotImplementedError 97 | 98 | elif isinstance(self.mc_sample_id, int): 99 | self.mc_sample_id = self.mc_sample_id % self.num_mc_samples 100 | weight = self.weights[self.mc_sample_id:(self.mc_sample_id+1)] 101 | bias = self.biases[self.mc_sample_id:(self.mc_sample_id+1)] 102 | 103 | if out.dim() == 4: 104 | out = torch.addcmul(bias[:, :, None, None], 105 | weight[:, :, None, None], out) 106 | elif out.dim() == 2: 107 | out = torch.addcmul(bias, weight, out) 108 | else: 109 | raise NotImplementedError 110 | else: 111 | bs = input.shape[0] 112 | idx = torch.tensor(self.mc_sample_id, device=input.device, dtype=torch.long) 113 | weight = self.weights[idx] 114 | bias = self.biases[idx] 115 | 116 | if out.dim() == 4: 117 | out = torch.addcmul(bias[:, :, None, None], 118 | weight[:, :, None, None], out) 119 | elif out.dim() == 2: 120 | out = torch.addcmul(bias, weight, out) 121 | else: 122 | raise NotImplementedError 123 | 124 | 125 | return out 126 | 127 | def extra_repr(self): 128 | return '{num_features}, {num_mc_samples},' \ 129 | 'eps={eps}, momentum={momentum}, affine={affine}, ' \ 130 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 131 | 132 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 133 | missing_keys, unexpected_keys, error_msgs): 134 | version = local_metadata.get('version', None) 135 | 136 | if (version is None or version < 2) and self.track_running_stats: 137 | num_batches_tracked_key = prefix + 'num_batches_tracked' 138 | if num_batches_tracked_key not in state_dict: 139 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 140 | 141 | super(_BayesBatchNormEMP, self)._load_from_state_dict( 142 | state_dict, prefix, local_metadata, strict, 143 | missing_keys, unexpected_keys, error_msgs) 144 | 145 | class BayesBatchNorm2dEMP(_BayesBatchNormEMP): 146 | r""" 147 | Applies Bayesian Batch Normalization over a 2D input 
148 | """ 149 | def _check_input_dim(self, input): 150 | if input.dim() != 4 and input.dim() != 5: 151 | raise ValueError('expected 4D input (got {}D input)' 152 | .format(input.dim())) 153 | 154 | class BayesBatchNorm1dEMP(_BayesBatchNormEMP): 155 | def _check_input_dim(self, input): 156 | if input.dim() != 2 and input.dim() != 3: 157 | raise ValueError('expected 2D or 3D input (got {}D input)' 158 | .format(input.dim())) 159 | -------------------------------------------------------------------------------- /pytorch_cifar_models/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/vgg.py 3 | 4 | BSD 3-Clause License 5 | 6 | Copyright (c) Soumith Chintala 2016, 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | ''' 34 | 35 | import sys 36 | import torch 37 | import torch.nn as nn 38 | try: 39 | from torch.hub import load_state_dict_from_url 40 | except ImportError: 41 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 42 | from functools import partial 43 | from typing import Union, List, Dict, Any, cast 44 | 45 | cifar10_pretrained_weight_urls = { 46 | 'vgg11_bn': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg11_bn-eaeebf42.pt', 47 | 'vgg13_bn': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg13_bn-c01e4a43.pt', 48 | 'vgg16_bn': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg16_bn-6ee7ea24.pt', 49 | 'vgg19_bn': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg19_bn-57191229.pt', 50 | } 51 | 52 | cifar100_pretrained_weight_urls = { 53 | 'vgg11_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg11_bn-57d0759e.pt', 54 | 'vgg13_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg13_bn-5ebe5778.pt', 55 | 'vgg16_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg16_bn-7d8c4031.pt', 56 | 'vgg19_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg19_bn-b98f7bd7.pt', 57 | } 58 | 59 | 60 | class VGG(nn.Module): 61 | 62 | def 
__init__( 63 | self, 64 | features: nn.Module, 65 | num_classes: int = 10, 66 | init_weights: bool = True 67 | ) -> None: 68 | super(VGG, self).__init__() 69 | self.features = features 70 | self.classifier = nn.Sequential( 71 | nn.Linear(512, 512), 72 | nn.ReLU(True), 73 | nn.Dropout(), 74 | nn.Linear(512, 512), 75 | nn.ReLU(True), 76 | nn.Dropout(), 77 | nn.Linear(512, num_classes), 78 | ) 79 | if init_weights: 80 | self._initialize_weights() 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | x = self.features(x) 84 | x = torch.flatten(x, 1) 85 | x = self.classifier(x) 86 | return x 87 | 88 | def _initialize_weights(self) -> None: 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | if m.bias is not None: 93 | nn.init.constant_(m.bias, 0) 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.constant_(m.weight, 1) 96 | nn.init.constant_(m.bias, 0) 97 | elif isinstance(m, nn.Linear): 98 | nn.init.normal_(m.weight, 0, 0.01) 99 | nn.init.constant_(m.bias, 0) 100 | 101 | 102 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 103 | layers: List[nn.Module] = [] 104 | in_channels = 3 105 | for v in cfg: 106 | if v == 'M': 107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 108 | else: 109 | v = cast(int, v) 110 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 111 | if batch_norm: 112 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 113 | else: 114 | layers += [conv2d, nn.ReLU(inplace=True)] 115 | in_channels = v 116 | return nn.Sequential(*layers) 117 | 118 | 119 | cfgs: Dict[str, List[Union[str, int]]] = { 120 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 121 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 122 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 123 | 'E': [64, 64, 'M', 128, 
128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 124 | } 125 | 126 | 127 | def _vgg(arch: str, cfg: str, batch_norm: bool, 128 | model_urls: Dict[str, str], 129 | pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 130 | if pretrained: 131 | kwargs['init_weights'] = False 132 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 133 | if pretrained: 134 | state_dict = load_state_dict_from_url(model_urls[arch], 135 | progress=progress) 136 | model.load_state_dict(state_dict) 137 | return model 138 | 139 | 140 | def cifar10_vgg11_bn(*args, **kwargs) -> VGG: pass 141 | def cifar10_vgg13_bn(*args, **kwargs) -> VGG: pass 142 | def cifar10_vgg16_bn(*args, **kwargs) -> VGG: pass 143 | def cifar10_vgg19_bn(*args, **kwargs) -> VGG: pass 144 | 145 | 146 | def cifar100_vgg11_bn(*args, **kwargs) -> VGG: pass 147 | def cifar100_vgg13_bn(*args, **kwargs) -> VGG: pass 148 | def cifar100_vgg16_bn(*args, **kwargs) -> VGG: pass 149 | def cifar100_vgg19_bn(*args, **kwargs) -> VGG: pass 150 | 151 | 152 | thismodule = sys.modules[__name__] 153 | for dataset in ["cifar10", "cifar100"]: 154 | for cfg, model_name in zip(["A", "B", "D", "E"], ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]): 155 | method_name = f"{dataset}_{model_name}" 156 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 157 | num_classes = 10 if dataset == "cifar10" else 100 158 | setattr( 159 | thismodule, 160 | method_name, 161 | partial(_vgg, 162 | arch=model_name, 163 | cfg=cfg, 164 | batch_norm=True, 165 | model_urls=model_urls, 166 | num_classes=num_classes) 167 | ) 168 | -------------------------------------------------------------------------------- /pytorch_cifar_models/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/resnet.py 3 | 4 | BSD 
3-Clause License 5 | 6 | Copyright (c) Soumith Chintala 2016, 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
33 | ''' 34 | import sys 35 | import torch.nn as nn 36 | try: 37 | from torch.hub import load_state_dict_from_url 38 | except ImportError: 39 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 40 | 41 | from functools import partial 42 | from typing import Dict, Type, Any, Callable, Union, List, Optional 43 | 44 | 45 | cifar10_pretrained_weight_urls = { 46 | 'resnet20': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt', 47 | 'resnet32': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet32-ef93fc4d.pt', 48 | 'resnet44': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt', 49 | 'resnet56': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet56-187c023a.pt', 50 | } 51 | 52 | cifar100_pretrained_weight_urls = { 53 | 'resnet20': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet20-23dac2f1.pt', 54 | 'resnet32': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet32-84213ce6.pt', 55 | 'resnet44': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet44-ffe32858.pt', 56 | 'resnet56': 'http://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet56-f2eff4c8.pt', 57 | } 58 | 59 | 60 | def conv3x3(in_planes, out_planes, stride=1): 61 | """3x3 convolution with padding""" 62 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 63 | 64 | 65 | def conv1x1(in_planes, out_planes, stride=1): 66 | """1x1 convolution""" 67 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 68 | 69 | 70 | class BasicBlock(nn.Module): 71 | expansion = 1 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None): 74 | super(BasicBlock, self).__init__() 75 | self.conv1 = 
conv3x3(inplanes, planes, stride) 76 | self.bn1 = nn.BatchNorm2d(planes) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.conv2 = conv3x3(planes, planes) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | identity = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | 93 | if self.downsample is not None: 94 | identity = self.downsample(x) 95 | 96 | out += identity 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class CifarResNet(nn.Module): 103 | 104 | def __init__(self, block, layers, num_classes=10): 105 | super(CifarResNet, self).__init__() 106 | self.inplanes = 16 107 | self.conv1 = conv3x3(3, 16) 108 | self.bn1 = nn.BatchNorm2d(16) 109 | self.relu = nn.ReLU(inplace=True) 110 | 111 | self.layer1 = self._make_layer(block, 16, layers[0]) 112 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 114 | 115 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 116 | self.fc = nn.Linear(64 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | conv1x1(self.inplanes, planes * block.expansion, stride), 130 | nn.BatchNorm2d(planes * block.expansion), 131 | ) 132 | 133 | layers = [] 134 | layers.append(block(self.inplanes, planes, stride, downsample)) 135 | self.inplanes = planes * block.expansion 136 | for _ in range(1, blocks): 137 | layers.append(block(self.inplanes, planes)) 
138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | 146 | x = self.layer1(x) 147 | x = self.layer2(x) 148 | x = self.layer3(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def _resnet( 158 | arch: str, 159 | layers: List[int], 160 | model_urls: Dict[str, str], 161 | progress: bool = True, 162 | pretrained: bool = False, 163 | **kwargs: Any 164 | ) -> CifarResNet: 165 | model = CifarResNet(BasicBlock, layers, **kwargs) 166 | if pretrained: 167 | state_dict = load_state_dict_from_url(model_urls[arch], 168 | progress=progress) 169 | model.load_state_dict(state_dict) 170 | return model 171 | 172 | 173 | def cifar10_resnet20(*args, **kwargs) -> CifarResNet: pass 174 | def cifar10_resnet32(*args, **kwargs) -> CifarResNet: pass 175 | def cifar10_resnet44(*args, **kwargs) -> CifarResNet: pass 176 | def cifar10_resnet56(*args, **kwargs) -> CifarResNet: pass 177 | 178 | 179 | def cifar100_resnet20(*args, **kwargs) -> CifarResNet: pass 180 | def cifar100_resnet32(*args, **kwargs) -> CifarResNet: pass 181 | def cifar100_resnet44(*args, **kwargs) -> CifarResNet: pass 182 | def cifar100_resnet56(*args, **kwargs) -> CifarResNet: pass 183 | 184 | 185 | thismodule = sys.modules[__name__] 186 | for dataset in ["cifar10", "cifar100"]: 187 | for layers, model_name in zip([[3]*3, [5]*3, [7]*3, [9]*3], 188 | ["resnet20", "resnet32", "resnet44", "resnet56"]): 189 | method_name = f"{dataset}_{model_name}" 190 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 191 | num_classes = 10 if dataset == "cifar10" else 100 192 | setattr( 193 | thismodule, 194 | method_name, 195 | partial(_resnet, 196 | arch=model_name, 197 | layers=layers, 198 | model_urls=model_urls, 199 | num_classes=num_classes) 200 | ) 201 | 
# --------------------------------------------------------------------------------
# /laplace/utils/utils.py
# --------------------------------------------------------------------------------
import logging
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
from torch.distributions.multivariate_normal import _precision_to_scale_tril


__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
           'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision']


def get_nll(out_dist, targets):
    """Mean negative log-likelihood of `targets` under predicted probabilities.

    Parameters
    ----------
    out_dist : torch.Tensor
        predicted class probabilities, e.g. `(batch, classes)`; they are passed
        through ``torch.log`` before ``F.nll_loss``, so they must be
        probabilities, not logits
    targets : torch.Tensor
        integer class indices, e.g. `(batch,)`

    Returns
    -------
    nll : torch.Tensor
        scalar, averaged over the batch (``F.nll_loss`` default reduction)
    """
    return F.nll_loss(torch.log(out_dist), targets)


@torch.no_grad()
def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_samples=100):
    """Run `laplace` over every batch of `val_loader` and collect its predictions.

    The underlying model is put into eval mode and gradients are disabled.

    Parameters
    ----------
    laplace : Laplace object
        must expose ``.model`` and ``._device`` and be callable as
        ``laplace(X, pred_type=..., link_approx=..., n_samples=...)``
    val_loader : iterable of `(X, y)` batches
        each batch is moved to ``laplace._device``
    pred_type, link_approx, n_samples
        forwarded unchanged to ``laplace(...)``

    Returns
    -------
    `(means, targets)` if `laplace` returns single tensors, or
    `((means, variances), targets)` if it returns `(mean, var)` tuples;
    every tensor is concatenated along dim 0.
    """
    laplace.model.eval()
    output_means, output_vars = list(), list()
    targets = list()
    for X, y in val_loader:
        X, y = X.to(laplace._device), y.to(laplace._device)
        out = laplace(
            X, pred_type=pred_type,
            link_approx=link_approx,
            n_samples=n_samples)

        if isinstance(out, tuple):
            output_means.append(out[0])
            output_vars.append(out[1])
        else:
            output_means.append(out)

        targets.append(y)

    if not output_vars:
        return torch.cat(output_means, dim=0), torch.cat(targets, dim=0)
    return ((torch.cat(output_means, dim=0), torch.cat(output_vars, dim=0)),
            torch.cat(targets, dim=0))


def parameters_per_layer(model):
    """Get the number of scalar entries of every parameter tensor of `model`.

    Weights and biases are separate parameter tensors and therefore separate
    entries; the order follows ``model.parameters()``.

    Parameters
    ----------
    model : torch.nn.Module

    Returns
    -------
    params_per_layer : list[int]
    """
    # numel() is always a Python int (np.prod of an empty shape is the float 1.0)
    return [p.numel() for p in model.parameters()]


def invsqrt_precision(M):
    """Compute a square root of ``M^{-1}``: the *lower-triangular* matrix ``L``
    with ``L @ L.T == M^{-1}`` (the ``scale_tril`` of a Gaussian whose
    precision matrix is ``M``).

    Parameters
    ----------
    M : torch.Tensor
        symmetric positive-definite precision matrix

    Returns
    -------
    M_invsqrt : torch.Tensor
        lower-triangular factor of ``M^{-1}``
    """
    return _precision_to_scale_tril(M)


def _is_batchnorm(module):
    """Return True if `module` is a ``BatchNorm1d``, ``BatchNorm2d`` or ``BatchNorm3d``."""
    return isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d))


def _is_valid_scalar(scalar: Union[float, int, torch.Tensor]) -> bool:
    """Return True if `scalar` holds exactly one real value.

    Accepted are scalars for which ``np.isscalar`` and ``np.isreal`` hold,
    0-d tensors, and 1-d tensors with a single element.
    """
    if np.isscalar(scalar) and np.isreal(scalar):
        return True
    if torch.is_tensor(scalar) and scalar.ndim <= 1:
        # a 1-d tensor only qualifies when it has exactly one element
        return scalar.ndim == 0 or len(scalar) == 1
    return False


def kron(t1, t2):
    """Computes the Kronecker product between two 2-d tensors.

    Parameters
    ----------
    t1 : torch.Tensor
        `(h1, w1)`
    t2 : torch.Tensor
        `(h2, w2)`

    Returns
    -------
    kron_product : torch.Tensor
        `(h1 * h2, w1 * w2)`; block ``(i, j)`` equals ``t1[i, j] * t2``
    """
    t1_height, t1_width = t1.size()
    t2_height, t2_width = t2.size()
    out_height = t1_height * t2_height
    out_width = t1_width * t2_width

    # t2 tiled over the h1 x w1 grid of blocks ...
    tiled_t2 = t2.repeat(t1_height, t1_width)
    # ... times every entry of t1 spread over its own h2 x w2 block
    expanded_t1 = (
        t1.unsqueeze(2)
        .unsqueeze(3)
        .repeat(1, t2_height, t2_width, 1)
        .view(out_height, out_width)
    )

    return expanded_t1 * tiled_t2


def diagonal_add_scalar(X, value):
    """Add scalar value `value` to the diagonal of `X`, out of place.

    Parameters
    ----------
    X : torch.Tensor
        matrix; its first ``X.shape[0]`` diagonal entries are updated
    value : torch.Tensor or float

    Returns
    -------
    X_add_scalar : torch.Tensor
        new tensor; `X` itself is not modified
    """
    # Build the index on X's own device so this works for any device
    # (any CUDA index, MPS, ...), not just the CPU or the current CUDA device.
    diag_idx = torch.arange(X.shape[0], device=X.device)
    values = X.new_ones(X.shape[0]).mul(value)
    return X.index_put((diag_idx, diag_idx), values, accumulate=True)


def symeig(M):
    """Symmetric eigendecomposition avoiding failure cases by
    adding and removing jitter to the diagonal.

    If ``torch.linalg.eigh`` raises, it is retried on ``M + I`` and 1 is
    subtracted from the resulting eigenvalues. Eigenvalues are then clamped
    at 0, and non-finite values in both outputs are replaced via
    ``torch.nan_to_num``.

    Parameters
    ----------
    M : torch.Tensor
        symmetric matrix; only its upper triangle is used (``UPLO='U'``)

    Returns
    -------
    L : torch.Tensor
        eigenvalues
    W : torch.Tensor
        eigenvectors

    Raises
    ------
    RuntimeError
        if the decomposition also fails after adding jitter
    """
    try:
        L, W = torch.linalg.eigh(M, UPLO='U')
    except RuntimeError:  # did not converge
        logging.info('SYMEIG: adding jitter, did not converge.')
        # use W L W^T + I = W (L + I) W^T
        M = M + torch.eye(M.shape[0], device=M.device, dtype=M.dtype)
        try:
            L, W = torch.linalg.eigh(M, UPLO='U')
            L -= 1.
        except RuntimeError as err:
            stats = f'diag: {M.diagonal()}, max: {M.abs().max()}, '
            stats = stats + f'min: {M.abs().min()}, mean: {M.abs().mean()}'
            logging.error(f'SYMEIG: adding jitter failed. Stats: {stats}')
            # Raise instead of terminating the interpreter, so callers can react.
            raise RuntimeError(f'SYMEIG: adding jitter failed. Stats: {stats}') from err
    # eigenvalues of symeig at least 0
    L = L.clamp(min=0.0)
    L = torch.nan_to_num(L)
    W = torch.nan_to_num(W)
    return L, W


def block_diag(blocks):
    """Compose block-diagonal matrix of individual square blocks.

    Parameters
    ----------
    blocks : list[torch.Tensor]
        square matrices; each block's size is read from ``shape[0]``

    Returns
    -------
    M : torch.Tensor
        `(P, P)` with ``P = sum(b.shape[0] for b in blocks)``, allocated with the
        device and dtype of the first block (CPU float32 if `blocks` is empty)
    """
    P = sum(b.shape[0] for b in blocks)
    # Allocate next to the data instead of always on the CPU in float32.
    M = blocks[0].new_zeros(P, P) if len(blocks) > 0 else torch.zeros(P, P)
    p_cur = 0
    for block in blocks:
        p_block = block.shape[0]
        M[p_cur:p_cur + p_block, p_cur:p_cur + p_block] = block
        p_cur += p_block
    return M


def expand_prior_precision(prior_prec, model):
    """Expand prior precision to match the shape of the model parameters.

    Parameters
    ----------
    prior_prec : torch.Tensor 1-dimensional
        prior precision: one value (scalar), one value per parameter tensor of
        `model` (per-layer), or one value per scalar parameter (diagonal)
    model : torch.nn.Module
        torch model with parameters that are regularized by prior_prec

    Returns
    -------
    expanded_prior_prec : torch.Tensor
        expanded prior precision has the same shape as model parameters
        (flattened), on the parameters' device

    Raises
    ------
    ValueError
        if `prior_prec` has neither 1, one-per-parameter-tensor, nor
        one-per-scalar-parameter entries
    """
    theta = parameters_to_vector(model.parameters())
    device, P = theta.device, len(theta)
    assert prior_prec.ndim == 1
    # Move once up front so every branch combines tensors on the same device.
    prior_prec = prior_prec.to(device)
    if len(prior_prec) == 1:  # scalar
        return torch.ones(P, device=device) * prior_prec
    elif len(prior_prec) == P:  # full diagonal
        return prior_prec
    params = list(model.parameters())
    if len(prior_prec) != len(params):
        # zip() would otherwise silently truncate to a vector of the wrong length
        raise ValueError(f'Invalid prior precision length {len(prior_prec)}: expected 1, '
                         f'{len(params)} (per parameter tensor) or {P} (diagonal).')
    return torch.cat([delta * torch.ones_like(m).flatten() for delta, m
                      in zip(prior_prec, params)])


def normal_samples(mean, var, n_samples, generator=None):
    """Produce samples from a batch of Normal distributions either parameterized
    by a diagonal or full covariance given by `var`.

    The same standard-normal draws `(output_dim, n_samples)` are reused for
    every element of the batch.

    Parameters
    ----------
    mean : torch.Tensor
        `(batch_size, output_dim)`
    var : torch.Tensor
        (co)variance of the Normal distribution
        `(batch_size, output_dim, output_dim)` or `(batch_size, output_dim)`
    n_samples : int
        number of samples per distribution
    generator : torch.Generator
        random number generator

    Returns
    -------
    samples : torch.Tensor
        `(n_samples, batch_size, output_dim)`

    Raises
    ------
    ValueError
        if the shapes of `mean` and `var` do not match either case
    """
    assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.'
    _, output_dim = mean.shape
    randn_samples = torch.randn((output_dim, n_samples), device=mean.device, generator=generator)

    if mean.shape == var.shape:
        # diagonal covariance
        scaled_samples = var.sqrt().unsqueeze(-1) * randn_samples.unsqueeze(0)
        return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1))
    elif mean.shape == var.shape[:2] and var.shape[-1] == mean.shape[1]:
        # full covariance
        scale = torch.linalg.cholesky(var)
        scaled_samples = torch.matmul(scale, randn_samples.unsqueeze(0))  # expand batch dim
        return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1))
    else:
        raise ValueError('Invalid input shapes.')


# --------------------------------------------------------------------------------
# /laplace/curvature/asdl.py
# --------------------------------------------------------------------------------
import warnings
import numpy as np
import torch

from asdfghjkl import FISHER_EXACT, FISHER_MC, COV
from asdfghjkl import SHAPE_KRON, SHAPE_DIAG, SHAPE_FULL
from asdfghjkl import fisher_for_cross_entropy
from asdfghjkl.hessian import hessian_eigenvalues, hessian_for_loss
from asdfghjkl.gradient import batch_gradient

from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
from laplace.utils import Kron, _is_batchnorm

# eigenvalues <= EPS are discarded by AsdlHessian.eig_lowrank
EPS = 1e-6


class AsdlInterface(CurvatureInterface):
    """Interface for asdfghjkl backend.
    """

    def jacobians(self, x):
        """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
        using asdfghjkl's gradient per output dimension.

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)` on compatible device with model.

        Returns
        -------
        Js : torch.Tensor
            Jacobians `(batch, outputs, parameters)`
        f : torch.Tensor
            output function `(batch, outputs)`
        """
        Js = list()
        for i in range(self.model.output_size):
            # one pass per output dimension: the "loss" is output i summed over
            # the batch, so the per-sample gradients are the rows of J_i
            def loss_fn(outputs, targets):
                return outputs[:, i].sum()

            f = batch_gradient(self.model, loss_fn, x, None).detach()
            Jk = _get_batch_grad(self.model)
            if self.subnetwork_indices is not None:
                Jk = Jk[:, self.subnetwork_indices]
            Js.append(Jk)
        # every Jk is (batch, parameters); stacking on dim 1 -> (batch, outputs, parameters)
        Js = torch.stack(Js, dim=1)
        return Js, f

    def gradients(self, x, y):
        """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter
        \\(\\theta\\) using asdfghjkl's backend.

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)` on compatible device with model.
        y : torch.Tensor

        Returns
        -------
        Gs : torch.Tensor
            gradients `(batch, parameters)`
        loss : torch.Tensor
        """
        f = batch_gradient(self.model, self.lossfunc, x, y).detach()
        Gs = _get_batch_grad(self._model)
        if self.subnetwork_indices is not None:
            Gs = Gs[:, self.subnetwork_indices]
        loss = self.lossfunc(f, y)
        return Gs, loss

    @property
    def _ggn_type(self):
        """asdfghjkl curvature type used by `diag`/`kron`; set by subclasses."""
        raise NotImplementedError

    def _get_kron_factors(self, curv, M):
        """Convert asdfghjkl's per-module Kronecker factors into a `Kron`.

        For a module with a bias, the last row/column of factor `A` holds the
        bias entry: the weight gets ``[B, A[:-1, :-1]]`` and the bias the single
        factor ``B * A[-1, -1] / M`` (`M` is the batch size). Weight-only
        modules get ``[B, A]``, merged into one factor when both are
        single-element. BatchNorm layers are skipped with a warning; modules
        without curvature statistics are ignored.
        """
        kfacs = list()
        for module in curv._model.modules():
            if _is_batchnorm(module):
                warnings.warn('BatchNorm unsupported for Kron, ignore.')
                continue

            stats = getattr(module, self._ggn_type, None)
            if stats is None:
                continue
            if hasattr(module, 'bias') and module.bias is not None:
                # split up bias and weights
                kfacs.append([stats.kron.B, stats.kron.A[:-1, :-1]])
                kfacs.append([stats.kron.B * stats.kron.A[-1, -1] / M])
            elif hasattr(module, 'weight'):
                p, q = np.prod(stats.kron.B.shape), np.prod(stats.kron.A.shape)
                if p == q == 1:
                    kfacs.append([stats.kron.B * stats.kron.A])
                else:
                    kfacs.append([stats.kron.B, stats.kron.A])
            else:
                raise ValueError(f'Whats happening with {module}?')
        return Kron(kfacs)

    @staticmethod
    def _rescale_kron_factors(kron, N):
        """Divide the second factor of every two-factor block by `N`, in place; returns `kron`."""
        for F in kron.kfacs:
            if len(F) == 2:
                F[1] *= 1/N
        return kron

    def diag(self, X, y, **kwargs):
        """Diagonal curvature of the loss on batch `(X, y)`.

        Returns ``(factor * loss, factor * diag)``. With ``last_layer`` the
        curvature is computed on the features from ``forward_with_features``.
        """
        with torch.no_grad():
            if self.last_layer:
                f, X = self.model.forward_with_features(X)
            else:
                f = self.model(X)
            loss = self.lossfunc(f, y)
        curv = fisher_for_cross_entropy(self._model, self._ggn_type, SHAPE_DIAG,
                                        inputs=X, targets=y)
        diag_ggn = curv.matrices_to_vector(None)
        if self.subnetwork_indices is not None:
            diag_ggn = diag_ggn[self.subnetwork_indices]
        return self.factor * loss, self.factor * diag_ggn

    def kron(self, X, y, N, **kwargs):
        """Kronecker-factored curvature of the loss on batch `(X, y)`.

        `N` is used to rescale the factors (see `_rescale_kron_factors`).
        Returns ``(factor * loss, factor * kron)``.
        """
        with torch.no_grad():
            if self.last_layer:
                f, X = self.model.forward_with_features(X)
            else:
                f = self.model(X)
            loss = self.lossfunc(f, y)
        curv = fisher_for_cross_entropy(self._model, self._ggn_type, SHAPE_KRON,
                                        inputs=X, targets=y)
        M = len(y)
        kron = self._get_kron_factors(curv, M)
        kron = self._rescale_kron_factors(kron, N)
        return self.factor * loss, self.factor * kron


class AsdlHessian(AsdlInterface):
    """Exact Hessian (full or truncated low-rank) via asdfghjkl."""

    def __init__(self, model, likelihood, last_layer=False, low_rank=10):
        super().__init__(model, likelihood, last_layer)
        # number of leading eigenpairs requested by eig_lowrank
        self.low_rank = low_rank

    @property
    def _ggn_type(self):
        raise NotImplementedError()

    def full(self, x, y, **kwargs):
        """Full Hessian of the loss on `(x, y)`; returns ``(factor * loss, factor * H)``."""
        hessian_for_loss(self.model, self.lossfunc, SHAPE_FULL, x, y)
        H = self._model.hessian.data
        loss = self.lossfunc(self.model(x), y).detach()
        return self.factor * loss, self.factor * H

    def eig_lowrank(self, data_loader):
        """Top-`low_rank` eigendecomposition of the Hessian over `data_loader`.

        Returns ``(eigvecs, factor * eigvals, factor * loss)`` where `eigvecs` is
        `(parameters, k)` and only eigenvalues above `EPS` are kept.
        """
        # compute truncated eigendecomposition of the Hessian, only keep eigvals > EPS
        eigvals, eigvecs = hessian_eigenvalues(self.model, self.lossfunc, data_loader,
                                               top_n=self.low_rank, max_iters=self.low_rank*10)
        eigvals = torch.from_numpy(np.array(eigvals))
        mask = (eigvals > EPS)
        eigvecs = torch.stack([torch.cat([p.flatten() for p in params])
                               for params in eigvecs], dim=1)[:, mask]
        device = eigvecs.device
        eigvals = eigvals[mask].to(eigvecs.dtype).to(device)
        loss = sum([self.lossfunc(self.model(x.to(device)).detach(), y.to(device)) for x, y in data_loader])
        return eigvecs, self.factor * eigvals, self.factor * loss


class AsdlGGN(AsdlInterface, GGNInterface):
    """Implementation of the `GGNInterface` using asdfghjkl.
    """
    def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False):
        if likelihood != 'classification':
            raise ValueError('This backend only supports classification currently.')
        super().__init__(model, likelihood, last_layer, subnetwork_indices)
        # True -> Monte-Carlo Fisher, False -> exact Fisher
        self.stochastic = stochastic

    @property
    def _ggn_type(self):
        return FISHER_MC if self.stochastic else FISHER_EXACT


class AsdlEF(AsdlInterface, EFInterface):
    """Implementation of the `EFInterface` using asdfghjkl.
    """
    def __init__(self, model, likelihood, last_layer=False):
        if likelihood != 'classification':
            raise ValueError('This backend only supports classification currently.')
        super().__init__(model, likelihood, last_layer)

    @property
    def _ggn_type(self):
        return COV


def _flatten_after_batch(tensor: torch.Tensor):
    """Reshape to `(batch, -1)`; a 1-d tensor becomes `(batch, 1)`."""
    if tensor.ndim == 1:
        return tensor.unsqueeze(-1)
    else:
        return tensor.flatten(start_dim=1)


def _get_batch_grad(model):
    """Collect the per-sample gradients asdfghjkl stored on `model`'s modules.

    Reads ``module.op_results['batch_grads']`` of every module that has
    ``op_results`` and concatenates the flattened weight and bias gradients
    (weight before bias, in module order) into `(batch, parameters)`.

    Raises
    ------
    ValueError
        if gradients are stored for a parameter other than weight/bias
    """
    batch_grads = list()
    for module in model.modules():
        if hasattr(module, 'op_results'):
            res = module.op_results['batch_grads']
            if 'weight' in res:
                batch_grads.append(_flatten_after_batch(res['weight']))
            if 'bias' in res:
                batch_grads.append(_flatten_after_batch(res['bias']))
            if len(set(res.keys()) - {'weight', 'bias'}) > 0:
                raise ValueError(f'Invalid parameter keys {res.keys()}')
    return torch.cat(batch_grads, dim=1)


# --------------------------------------------------------------------------------
# /scalablebdl/mean_field/batchnorm.py
# --------------------------------------------------------------------------------
import math

import torch
from torch.nn import Module, Parameter
import torch.nn.init as init
import torch.nn.functional as F

from .utils import MulExpAddFunction


class _BayesBatchNormMF(Module):
    r"""
    Applies Bayesian Batch Normalization over a 2D or 3D input

    Mean-field Gaussian over the affine parameters: ``weight`` and ``bias`` each
    have a mean (``*_mu``) and a log standard deviation (``*_psi``); a sample
    is ``eps * exp(psi) + mu`` with ``eps ~ N(0, 1)`` (``MulExpAddFunction``).
    Normalisation itself uses ``F.batch_norm`` without affine parameters.
    """
    __constants__ = ['track_running_stats',
                     'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked',
                     'num_features', 'affine', 'num_mc_samples']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, deterministic=False, num_mc_samples=20):
        super(_BayesBatchNormMF, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # True -> use only the means of weight/bias (no sampling)
        self.deterministic = deterministic
        # number of weight/bias samples drawn per input in parallel_eval mode
        self.num_mc_samples = num_mc_samples
        # when True, forward replicates each input num_mc_samples times
        self.parallel_eval = False
        if self.affine:
            self.weight_mu = Parameter(torch.Tensor(num_features))
            self.weight_psi = Parameter(torch.Tensor(num_features))

            self.bias_mu = Parameter(torch.Tensor(num_features))
            self.bias_psi = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight_mu', None)
            self.register_parameter('weight_psi', None)
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_psi', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

        self.weight_size = list(self.weight_mu.shape) if self.affine else None
        self.bias_size = list(self.bias_mu.shape) if self.affine else None
        self.mul_exp_add = MulExpAddFunction.apply

    def reset_running_stats(self):
        """Reset running mean/var to 0/1 and the batch counter to 0 (if tracked)."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running stats; set weight mean 1, bias mean 0, log-stds in [-6, -5)."""
        self.reset_running_stats()
        if self.affine:
            self.weight_mu.data.fill_(1)
            self.weight_psi.data.uniform_(-6, -5)
            self.bias_mu.data.zero_()
            self.bias_psi.data.uniform_(-6, -5)

    def _check_input_dim(self, input):
        raise NotImplementedError

    @staticmethod
    def _per_row_affine(out, weight, bias):
        """Return ``bias + weight * out`` for `weight`/`bias` of shape `(rows, C)`,
        broadcast over the spatial dims of a 4-D `out` or applied directly to a 2-D one."""
        if out.dim() == 4:
            return torch.addcmul(bias[:, :, None, None],
                                 weight[:, :, None, None], out)
        elif out.dim() == 2:
            return torch.addcmul(bias, weight, out)
        raise NotImplementedError

    def forward(self, input):
        """Normalise `input`, then apply deterministic or sampled affine parameters.

        - ``deterministic``: the means ``weight_mu``/``bias_mu`` are used.
        - default: one weight/bias sample per example in the batch.
        - ``parallel_eval``: every example is replicated ``num_mc_samples``
          times and each replica gets its own sample; the affine output is
          reshaped to `(batch, num_mc_samples, ...)`.
        """
        self._check_input_dim(input)

        exponential_average_factor = 0.0 if self.momentum is None else self.momentum

        if self.training and self.track_running_stats and self.num_batches_tracked is not None:
            self.num_batches_tracked += 1
            if self.momentum is None:  # cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)

        if self.parallel_eval:
            # Add a sample dim (dim 1) holding num_mc_samples copies, then fold it into
            # the batch dim. Already 5-D / 3-D inputs are presumably (batch, samples, ...)
            # from a previous Bayesian layer and are only folded — TODO confirm.
            if input.dim() == 4:
                input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1)
            elif input.dim() == 2:
                input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1)
            input = input.flatten(start_dim=0, end_dim=1)
        # normalisation only: weight/bias are None, the affine part follows below
        out = F.batch_norm(
            input, self.running_mean, self.running_var, None, None,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

        if self.affine:
            if self.deterministic:
                weight = self.weight_mu.unsqueeze(0)
                bias = self.bias_mu.unsqueeze(0)
                out = self._per_row_affine(out, weight, bias)
            elif not self.parallel_eval:
                # one independent weight/bias sample per example
                bs = input.shape[0]
                weight = self.mul_exp_add(torch.randn(bs, *self.weight_size,
                                                      device=input.device,
                                                      dtype=input.dtype),
                                          self.weight_psi, self.weight_mu)

                bias = self.mul_exp_add(torch.randn(bs, *self.bias_size,
                                                    device=input.device,
                                                    dtype=input.dtype),
                                        self.bias_psi, self.bias_mu)
                out = self._per_row_affine(out, weight, bias)
            else:
                # one weight/bias sample per MC replica, shared across the batch
                weight = self.mul_exp_add(torch.randn(self.num_mc_samples,
                                                      *self.weight_size,
                                                      device=input.device,
                                                      dtype=input.dtype),
                                          self.weight_psi, self.weight_mu)
                bias = self.mul_exp_add(torch.randn(self.num_mc_samples,
                                                    *self.bias_size,
                                                    device=input.device,
                                                    dtype=input.dtype),
                                        self.bias_psi, self.bias_mu)
                if out.dim() == 4:
                    out = out.view(-1, self.num_mc_samples, out.shape[1],
                                   out.shape[2], out.shape[3]) * weight[None, :, :, None, None] \
                        + bias[None, :, :, None, None]
                elif out.dim() == 2:
                    out = out.view(-1, self.num_mc_samples, out.shape[1]) \
                        * weight[None, :, :] + bias[None, :, :]
                else:
                    raise NotImplementedError

        return out

    def extra_repr(self):
        return '{num_features}, {num_mc_samples}, ' \
               'eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # checkpoints without a version (or version < 2) lack num_batches_tracked;
        # insert a zero counter so loading does not fail
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_BayesBatchNormMF, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


class BayesBatchNorm2dMF(_BayesBatchNormMF):
    r"""
    Applies Bayesian Batch Normalization over a 2D input

    Accepts 4-D input, or 5-D input (folded into 4-D in ``parallel_eval`` mode).
    """
    def _check_input_dim(self, input):
        if input.dim() != 4 and input.dim() != 5:
            raise ValueError('expected 4D or 5D input (got {}D input)'
                             .format(input.dim()))


class BayesBatchNorm1dMF(_BayesBatchNormMF):
    r"""
    Applies Bayesian Batch Normalization over a 2D or 3D input

    3-D input is folded into 2-D in ``parallel_eval`` mode.
    """
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
-------------------------------------------------------------------------------- /laplace/subnetlaplace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import MultivariateNormal 3 | 4 | from laplace.baselaplace import ParametricLaplace, FullLaplace, DiagLaplace 5 | 6 | 7 | __all__ = ['SubnetLaplace', 'FullSubnetLaplace', 'DiagSubnetLaplace'] 8 | 9 | 10 | class SubnetLaplace(ParametricLaplace): 11 | """Class for subnetwork Laplace, which computes the Laplace approximation over just a subset 12 | of the model parameters (i.e. a subnetwork within the neural network), as proposed in [1]. 13 | Subnetwork Laplace can only be used with either a full or a diagonal Hessian approximation. 14 | 15 | A Laplace approximation is represented by a MAP which is given by the 16 | `model` parameter and a posterior precision or covariance specifying 17 | a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\). 18 | Here, only a subset of the model parameters (i.e. a subnetwork of the 19 | neural network) are treated probabilistically. 20 | The goal of this class is to compute the posterior precision \\(P\\) 21 | which sums as 22 | \\[ 23 | P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta) 24 | \\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}. 25 | \\] 26 | The prior is assumed to be Gaussian and therefore we have a simple form for 27 | \\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\). 28 | In particular, we assume a scalar or diagonal prior precision so that in 29 | all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied. 30 | 31 | The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood 32 | Hessian approximation and hence posterior precision. 
Based on the chosen `backend` 33 | parameter, the full approximation can be, for example, a generalized Gauss-Newton 34 | matrix. Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\). 35 | See `FullLaplace` and `BaseLaplace` for the full interface. 36 | 37 | References 38 | ---------- 39 | [1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. 40 | [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689). 41 | ICML 2021. 42 | 43 | Parameters 44 | ---------- 45 | model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` 46 | likelihood : {'classification', 'regression'} 47 | determines the log likelihood Hessian approximation 48 | subnetwork_indices : torch.LongTensor 49 | indices of the vectorized model parameters 50 | (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`) 51 | that define the subnetwork to apply the Laplace approximation over 52 | sigma_noise : torch.Tensor or float, default=1 53 | observation noise for the regression setting; must be 1 for classification 54 | prior_precision : torch.Tensor or float, default=1 55 | prior precision of a Gaussian prior (= weight decay); 56 | can be scalar, per-layer, or diagonal in the most general case 57 | prior_mean : torch.Tensor or float, default=0 58 | prior mean of a Gaussian prior, useful for continual learning 59 | temperature : float, default=1 60 | temperature of the likelihood; lower temperature leads to more 61 | concentrated posterior and vice versa. 62 | backend : subclasses of `laplace.curvature.CurvatureInterface` 63 | backend for access to curvature/Hessian approximations 64 | backend_kwargs : dict, default=None 65 | arguments passed to the backend on initialization, for example to 66 | set the number of MC samples for stochastic approximations. 
67 | """ 68 | def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_precision=1., 69 | prior_mean=0., temperature=1., backend=None, backend_kwargs=None): 70 | self.H = None 71 | super().__init__(model, likelihood, sigma_noise=sigma_noise, 72 | prior_precision=prior_precision, prior_mean=prior_mean, 73 | temperature=temperature, backend=backend, backend_kwargs=backend_kwargs) 74 | # check validity of subnetwork indices and pass them to backend 75 | self._check_subnetwork_indices(subnetwork_indices) 76 | self.backend.subnetwork_indices = subnetwork_indices 77 | self.n_params_subnet = len(subnetwork_indices) 78 | self._init_H() 79 | 80 | def _check_subnetwork_indices(self, subnetwork_indices): 81 | """Check that subnetwork indices are valid indices of the vectorized model parameters 82 | (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`). 83 | """ 84 | if subnetwork_indices is None: 85 | raise ValueError('Subnetwork indices cannot be None.') 86 | elif not ((isinstance(subnetwork_indices, torch.LongTensor) or 87 | isinstance(subnetwork_indices, torch.cuda.LongTensor)) and 88 | subnetwork_indices.numel() > 0 and len(subnetwork_indices.shape) == 1): 89 | raise ValueError('Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.') 90 | elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and 91 | len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0): 92 | raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.') 93 | elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)): 94 | raise ValueError('Subnetwork indices must not contain duplicate entries.') 95 | 96 | @property 97 | def prior_precision_diag(self): 98 | """Obtain the diagonal prior precision \\(p_0\\) constructed from either 99 | a scalar or diagonal prior precision. 
100 | 101 | Returns 102 | ------- 103 | prior_precision_diag : torch.Tensor 104 | """ 105 | if len(self.prior_precision) == 1: # scalar 106 | return self.prior_precision * torch.ones(self.n_params_subnet, device=self._device) 107 | 108 | elif len(self.prior_precision) == self.n_params_subnet: # diagonal 109 | return self.prior_precision 110 | 111 | else: 112 | raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.') 113 | 114 | @property 115 | def mean_subnet(self): 116 | return self.mean[self.backend.subnetwork_indices] 117 | 118 | @property 119 | def scatter(self): 120 | delta = (self.mean_subnet - self.prior_mean) 121 | return (delta * self.prior_precision_diag) @ delta 122 | 123 | def assemble_full_samples(self, subnet_samples): 124 | full_samples = self.mean.repeat(subnet_samples.shape[0], 1) 125 | full_samples[:, self.backend.subnetwork_indices] = subnet_samples 126 | return full_samples 127 | 128 | 129 | class FullSubnetLaplace(SubnetLaplace, FullLaplace): 130 | """Subnetwork Laplace approximation with full, i.e., dense, log likelihood Hessian 131 | approximation and hence posterior precision. Based on the chosen `backend` parameter, 132 | the full approximation can be, for example, a generalized Gauss-Newton matrix. 133 | Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\). 134 | See `FullLaplace`, `SubnetLaplace`, and `BaseLaplace` for the full interface. 
135 | """ 136 | # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) 137 | _key = ('subnetwork', 'full') 138 | 139 | def _init_H(self): 140 | self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device) 141 | 142 | def sample(self, n_samples=100): 143 | # sample only subnetwork parameters and set all other parameters to their MAP estimates 144 | dist = MultivariateNormal(loc=self.mean_subnet, scale_tril=self.posterior_scale) 145 | subnet_samples = dist.sample((n_samples,)) 146 | return self.assemble_full_samples(subnet_samples) 147 | 148 | 149 | class DiagSubnetLaplace(SubnetLaplace, DiagLaplace): 150 | """Subnetwork Laplace approximation with diagonal log likelihood Hessian approximation 151 | and hence posterior precision. 152 | Mathematically, we have \\(P \\approx \\textrm{diag}(P)\\). 153 | See `DiagLaplace`, `SubnetLaplace`, and `BaseLaplace` for the full interface. 154 | """ 155 | # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) 156 | _key = ('subnetwork', 'diag') 157 | 158 | def _init_H(self): 159 | self.H = torch.zeros(self.n_params_subnet, device=self._device) 160 | 161 | def _check_jacobians(self, Js): 162 | if not isinstance(Js, torch.Tensor): 163 | raise ValueError('Jacobians have to be torch.Tensor.') 164 | if not Js.device == self._device: 165 | raise ValueError('Jacobians need to be on the same device as Laplace.') 166 | m, k, p = Js.size() 167 | if p != self.n_params_subnet: 168 | raise ValueError('Invalid Jacobians shape for Laplace posterior approx.') 169 | 170 | def sample(self, n_samples=100): 171 | # sample only subnetwork parameters and set all other parameters to their MAP estimates 172 | samples = torch.randn(n_samples, self.n_params_subnet, device=self._device) 173 | samples = samples * self.posterior_scale.reshape(1, self.n_params_subnet) 174 | subnet_samples = self.mean_subnet.reshape(1, self.n_params_subnet) + samples 175 | return 
self.assemble_full_samples(subnet_samples) 176 | -------------------------------------------------------------------------------- /scalablebdl/mean_field/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.init as init 6 | from torch.nn import Module, Parameter 7 | import torch.nn.functional as F 8 | from torch.nn.modules.utils import _pair 9 | 10 | from .utils import MulExpAddFunction 11 | 12 | class _BayesConvNdMF(Module): 13 | r""" 14 | Applies Bayesian Convolution 15 | """ 16 | __constants__ = ['stride', 'padding', 'dilation', 17 | 'groups', 'bias', 'in_channels', 18 | 'out_channels', 'kernel_size'] 19 | 20 | def __init__(self, in_channels, out_channels, kernel_size, stride, 21 | padding, dilation, groups, bias): 22 | super(_BayesConvNdMF, self).__init__() 23 | if in_channels % groups != 0: 24 | raise ValueError('in_channels must be divisible by groups') 25 | if out_channels % groups != 0: 26 | raise ValueError('out_channels must be divisible by groups') 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.kernel_size = kernel_size 30 | if isinstance(self.kernel_size, int): 31 | self.kernel_size = (self.kernel_size, self.kernel_size) 32 | self.stride = stride 33 | self.padding = padding 34 | self.dilation = dilation 35 | self.groups = groups 36 | 37 | self.weight_mu = Parameter(torch.Tensor( 38 | out_channels, in_channels // groups, *self.kernel_size)) 39 | self.weight_psi = Parameter(torch.Tensor( 40 | out_channels, in_channels // groups, *self.kernel_size)) 41 | 42 | if bias is None or bias is False : 43 | self.bias = False 44 | else: 45 | self.bias = True 46 | 47 | if self.bias: 48 | self.bias_mu = Parameter(torch.Tensor(out_channels)) 49 | self.bias_psi = Parameter(torch.Tensor(out_channels)) 50 | else: 51 | self.register_parameter('bias_mu', None) 52 | self.register_parameter('bias_psi', None) 53 | 
self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | n = self.in_channels 57 | n *= np.prod(list(self.kernel_size)) 58 | stdv = 1.0 / math.sqrt(n) 59 | self.weight_mu.data.uniform_(-stdv, stdv) 60 | self.weight_psi.data.uniform_(-6, -5) 61 | 62 | if self.bias : 63 | self.bias_mu.data.uniform_(-stdv, stdv) 64 | self.bias_psi.data.uniform_(-6, -5) 65 | 66 | def extra_repr(self): 67 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 68 | ', stride={stride}') 69 | s += ', padding={padding}' 70 | s += ', dilation={dilation}' 71 | s += ', groups={groups}' 72 | s += ', bias=False' 73 | return s.format(**self.__dict__) 74 | 75 | def __setstate__(self, state): 76 | super(_BayesConvNdMF, self).__setstate__(state) 77 | 78 | class BayesConv2dMF(_BayesConvNdMF): 79 | r""" 80 | Applies Bayesian Convolution for 2D inputs 81 | """ 82 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 83 | padding=0, dilation=1, groups=1, bias=False, 84 | deterministic=False, num_mc_samples=None): 85 | super(BayesConv2dMF, self).__init__( 86 | in_channels, out_channels, kernel_size, stride, 87 | padding, dilation, groups, bias) 88 | self.deterministic = deterministic 89 | self.num_mc_samples = num_mc_samples 90 | self.parallel_eval = False 91 | self.weight_size = list(self.weight_mu.shape) 92 | self.bias_size = list(self.bias_mu.shape) if self.bias else None 93 | self.mul_exp_add = MulExpAddFunction.apply 94 | 95 | self.local_reparam = False 96 | self.flipout = False 97 | self.single_eps = False 98 | 99 | def forward(self, input): 100 | r""" 101 | Overriden. 
102 | """ 103 | if self.deterministic: 104 | out = F.conv2d(input, weight=self.weight_mu, bias=self.bias_mu, 105 | stride=self.stride, dilation=self.dilation, 106 | groups=self.groups, padding=self.padding) 107 | elif not self.parallel_eval: 108 | if self.single_eps: 109 | assert not self.bias 110 | weight = torch.randn_like(self.weight_mu).mul_(self.weight_psi.exp()).add_(self.weight_mu) 111 | out = F.conv2d(input, weight=weight, bias=None, 112 | stride=self.stride, dilation=self.dilation, 113 | groups=self.groups, padding=self.padding) 114 | elif self.local_reparam: 115 | assert not self.bias 116 | act_mu = F.conv2d(input, weight=self.weight_mu, bias=None, 117 | stride=self.stride, padding=self.padding, 118 | dilation=self.dilation, groups=self.groups) 119 | act_var = F.conv2d(input**2, 120 | weight=(self.weight_psi*2).exp_(), bias=None, 121 | stride=self.stride, padding=self.padding, 122 | dilation=self.dilation, groups=self.groups) 123 | act_std = act_var.clamp(1e-8).sqrt_() 124 | # print(act_std.data.norm().item()) 125 | out = torch.randn_like(act_mu).mul_(act_std).add_(act_mu) 126 | elif self.flipout: 127 | assert not self.bias 128 | outputs = F.conv2d(input, weight=self.weight_mu, bias=None, 129 | stride=self.stride, padding=self.padding, 130 | dilation=self.dilation, groups=self.groups) 131 | 132 | # sampling perturbation signs 133 | sign_input = torch.empty(input.size(0), input.size(1), 1, 1, device=input.device).uniform_(-1, 1).sign() 134 | sign_output = torch.empty(outputs.size(0), outputs.size(1), 1, 1, device=input.device).uniform_(-1, 1).sign() 135 | 136 | # gettin perturbation weights 137 | delta_kernel = torch.randn_like(self.weight_psi) * torch.exp(self.weight_psi) 138 | 139 | # perturbed feedforward 140 | perturbed_outputs = F.conv2d(input * sign_input, 141 | weight=delta_kernel, 142 | bias=None, 143 | stride=self.stride, 144 | padding=self.padding, 145 | dilation=self.dilation, 146 | groups=self.groups) 147 | out = outputs + perturbed_outputs * 
sign_output 148 | else: 149 | bs = input.shape[0] 150 | weight = self.mul_exp_add(torch.empty(bs, *self.weight_size, 151 | device=input.device, 152 | dtype=input.dtype).normal_(0, 1), 153 | self.weight_psi, self.weight_mu).view( 154 | bs*self.weight_size[0], *self.weight_size[1:]) 155 | out = F.conv2d(input.view(1, -1, input.shape[2], input.shape[3]), 156 | weight=weight, bias=None, 157 | stride=self.stride, dilation=self.dilation, 158 | groups=self.groups*bs, padding=self.padding) 159 | out = out.view(bs, self.out_channels, out.shape[2], out.shape[3]) 160 | 161 | if self.bias: 162 | bias = self.mul_exp_add(torch.empty(bs, *self.bias_size, 163 | device=input.device, 164 | dtype=input.dtype).normal_(0, 1), 165 | self.bias_psi, self.bias_mu) 166 | out = out + bias[:, :, None, None] 167 | else: 168 | if input.dim() == 4: 169 | input = input.unsqueeze(1).repeat(1, self.num_mc_samples, 1, 1, 1) 170 | 171 | weight = self.mul_exp_add(torch.empty(self.num_mc_samples, 172 | *self.weight_size, 173 | device=input.device, 174 | dtype=input.dtype).normal_(0, 1), 175 | self.weight_psi, self.weight_mu).view( 176 | self.num_mc_samples*self.weight_size[0], *self.weight_size[1:]) 177 | out = F.conv2d(input.flatten(start_dim=1, end_dim=2), 178 | weight=weight, bias=None, 179 | stride=self.stride, dilation=self.dilation, 180 | groups=self.groups*self.num_mc_samples, 181 | padding=self.padding) 182 | out = out.view(out.shape[0], self.num_mc_samples, 183 | self.out_channels, out.shape[2], out.shape[3]) 184 | if self.bias: 185 | bias = self.mul_exp_add(torch.empty(self.num_mc_samples, 186 | *self.bias_size, 187 | device=input.device, 188 | dtype=input.dtype).normal_(0, 1), 189 | self.bias_psi, self.bias_mu) 190 | out = out + bias[None, :, :, None, None] 191 | return out 192 | -------------------------------------------------------------------------------- /laplace/lllaplace.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 
2 | import torch 3 | from torch.nn.utils import parameters_to_vector, vector_to_parameters 4 | 5 | from laplace.baselaplace import ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace 6 | from laplace.utils import FeatureExtractor, Kron 7 | 8 | 9 | __all__ = ['LLLaplace', 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace'] 10 | 11 | 12 | class LLLaplace(ParametricLaplace): 13 | """Baseclass for all last-layer Laplace approximations in this library. 14 | Subclasses specify the structure of the Hessian approximation. 15 | See `BaseLaplace` for the full interface. 16 | 17 | A Laplace approximation is represented by a MAP which is given by the 18 | `model` parameter and a posterior precision or covariance specifying 19 | a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\). 20 | Here, only the parameters of the last layer of the neural network 21 | are treated probabilistically. 22 | The goal of this class is to compute the posterior precision \\(P\\) 23 | which sums as 24 | \\[ 25 | P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta) 26 | \\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}. 27 | \\] 28 | Every subclass implements different approximations to the log likelihood Hessians, 29 | for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have 30 | a simple form for \\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\). 31 | In particular, we assume a scalar or diagonal prior precision so that in 32 | all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied. 
33 | 34 | Parameters 35 | ---------- 36 | model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` 37 | likelihood : {'classification', 'regression'} 38 | determines the log likelihood Hessian approximation 39 | sigma_noise : torch.Tensor or float, default=1 40 | observation noise for the regression setting; must be 1 for classification 41 | prior_precision : torch.Tensor or float, default=1 42 | prior precision of a Gaussian prior (= weight decay); 43 | can be scalar, per-layer, or diagonal in the most general case 44 | prior_mean : torch.Tensor or float, default=0 45 | prior mean of a Gaussian prior, useful for continual learning 46 | temperature : float, default=1 47 | temperature of the likelihood; lower temperature leads to more 48 | concentrated posterior and vice versa. 49 | backend : subclasses of `laplace.curvature.CurvatureInterface` 50 | backend for access to curvature/Hessian approximations 51 | last_layer_name: str, default=None 52 | name of the model's last layer, if None it will be determined automatically 53 | backend_kwargs : dict, default=None 54 | arguments passed to the backend on initialization, for example to 55 | set the number of MC samples for stochastic approximations. 
56 | """ 57 | def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1., 58 | prior_mean=0., temperature=1., backend=None, last_layer_name=None, 59 | backend_kwargs=None): 60 | self.H = None 61 | super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=1., 62 | prior_mean=0., temperature=temperature, backend=backend, 63 | backend_kwargs=backend_kwargs) 64 | self.model = FeatureExtractor(deepcopy(model), last_layer_name=last_layer_name) 65 | if self.model.last_layer is None: 66 | self.mean = None 67 | self.n_params = None 68 | self.n_layers = None 69 | # ignore checks of prior mean setter temporarily, check on .fit() 70 | self._prior_precision = prior_precision 71 | self._prior_mean = prior_mean 72 | else: 73 | self.n_params = len(parameters_to_vector(self.model.last_layer.parameters())) 74 | self.n_layers = len(list(self.model.last_layer.parameters())) 75 | self.prior_precision = prior_precision 76 | self.prior_mean = prior_mean 77 | self.mean = self.prior_mean 78 | self._init_H() 79 | self._backend_kwargs['last_layer'] = True 80 | 81 | def fit(self, train_loader, override=True): 82 | """Fit the local Laplace approximation at the parameters of the model. 83 | 84 | Parameters 85 | ---------- 86 | train_loader : torch.data.utils.DataLoader 87 | each iterate is a training batch (X, y); 88 | `train_loader.dataset` needs to be set to access \\(N\\), size of the data set 89 | override : bool, default=True 90 | whether to initialize H, loss, and n_data again; setting to False is useful for 91 | online learning settings to accumulate a sequential posterior approximation. 
92 | """ 93 | if not override: 94 | raise ValueError('Last-layer Laplace approximations do not support `override=False`.') 95 | 96 | self.model.eval() 97 | 98 | if self.model.last_layer is None: 99 | X, _ = next(iter(train_loader)) 100 | with torch.no_grad(): 101 | try: 102 | self.model.find_last_layer(X[:1].to(self._device)) 103 | except (TypeError, AttributeError): 104 | self.model.find_last_layer(X.to(self._device)) 105 | params = parameters_to_vector(self.model.last_layer.parameters()).detach() 106 | self.n_params = len(params) 107 | self.n_layers = len(list(self.model.last_layer.parameters())) 108 | # here, check the already set prior precision again 109 | self.prior_precision = self._prior_precision 110 | self.prior_mean = self._prior_mean 111 | self._init_H() 112 | 113 | super().fit(train_loader, override=override) 114 | self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach() 115 | 116 | def _glm_predictive_distribution(self, X): 117 | Js, f_mu = self.backend.last_layer_jacobians(X) 118 | f_var = self.functional_variance(Js) 119 | return f_mu.detach(), f_var.detach() 120 | 121 | def _nn_predictive_samples(self, X, n_samples=100): 122 | fs = list() 123 | for sample in self.sample(n_samples): 124 | vector_to_parameters(sample, self.model.last_layer.parameters()) 125 | fs.append(self.model(X.to(self._device)).detach()) 126 | vector_to_parameters(self.mean, self.model.last_layer.parameters()) 127 | fs = torch.stack(fs) 128 | if self.likelihood == 'classification': 129 | fs = torch.softmax(fs, dim=-1) 130 | return fs 131 | 132 | @property 133 | def prior_precision_diag(self): 134 | """Obtain the diagonal prior precision \\(p_0\\) constructed from either 135 | a scalar or diagonal prior precision. 
136 | 137 | Returns 138 | ------- 139 | prior_precision_diag : torch.Tensor 140 | """ 141 | if len(self.prior_precision) == 1: # scalar 142 | return self.prior_precision * torch.ones_like(self.mean) 143 | 144 | elif len(self.prior_precision) == self.n_params: # diagonal 145 | return self.prior_precision 146 | 147 | else: 148 | raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.') 149 | 150 | 151 | class FullLLLaplace(LLLaplace, FullLaplace): 152 | """Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation 153 | and hence posterior precision. Based on the chosen `backend` parameter, the full 154 | approximation can be, for example, a generalized Gauss-Newton matrix. 155 | Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\). 156 | See `FullLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface. 157 | """ 158 | # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) 159 | _key = ('last_layer', 'full') 160 | 161 | 162 | class KronLLLaplace(LLLaplace, KronLaplace): 163 | """Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation 164 | and hence posterior precision. 165 | Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, 166 | that \\P\\approx Q \\otimes H\\. 167 | See `KronLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface and see 168 | `laplace.utils.matrix.Kron` and `laplace.utils.matrix.KronDecomposed` for the structure of 169 | the Kronecker factors. `Kron` is used to aggregate factors by summing up and 170 | `KronDecomposed` is used to add the prior, a Hessian factor (e.g. temperature), 171 | and computing posterior covariances, marginal likelihood, etc. 172 | Use of `damping` is possible by initializing or setting `damping=True`. 
173 | """ 174 | # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) 175 | _key = ('last_layer', 'kron') 176 | 177 | def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1., 178 | prior_mean=0., temperature=1., backend=None, last_layer_name=None, 179 | damping=False, **backend_kwargs): 180 | self.damping = damping 181 | super().__init__(model, likelihood, sigma_noise, prior_precision, 182 | prior_mean, temperature, backend, last_layer_name, backend_kwargs) 183 | 184 | def _init_H(self): 185 | self.H = Kron.init_from_model(self.model.last_layer, self._device) 186 | 187 | 188 | class DiagLLLaplace(LLLaplace, DiagLaplace): 189 | """Last-layer Laplace approximation with diagonal log likelihood Hessian approximation 190 | and hence posterior precision. 191 | Mathematically, we have \\(P \\approx \\textrm{diag}(P)\\). 192 | See `DiagLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface. 193 | """ 194 | # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) 195 | _key = ('last_layer', 'diag') 196 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import torchvision.transforms as transforms 5 | import torchvision.datasets as datasets 6 | import numpy as np 7 | 8 | from timm.data import create_transform 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | 11 | def subsample(loader, num_classes, subsample_number, balanced, device, verbose=False): 12 | xs, ys = [], [] 13 | cnt = np.zeros(num_classes) 14 | for x_batch, y_batch in loader: 15 | for x, y in zip(x_batch, y_batch): 16 | if np.all(balanced and cnt >= subsample_number//num_classes) or \ 17 | ((not balanced) and cnt.sum() >= subsample_number): 18 | xs = torch.stack(xs) 19 | ys = torch.stack(ys) 
20 | if verbose: 21 | print("The frequency of the sampled labels") 22 | print(np.unique(ys.numpy(), return_counts=True)) 23 | return xs.to(device), ys.to(device) 24 | if balanced and cnt[y.item()] >= subsample_number//num_classes: 25 | continue 26 | xs.append(x); ys.append(y) 27 | cnt[y.item()] += 1 28 | 29 | def data_loaders(args, valid_size=None, noaug=None): 30 | if 'cifar' in args.dataset: 31 | return cifar_loaders(args, valid_size, noaug) 32 | elif args.dataset == 'mnist': 33 | return mnist_loaders(args, valid_size, noaug) 34 | else: 35 | return imagenet_loaders(args, valid_size, noaug) 36 | 37 | def mnist_loaders(args, valid_size=None, noaug=None): 38 | dset = datasets.MNIST 39 | T = transforms.Compose([ 40 | transforms.ToTensor()]) 41 | 42 | test_loader = torch.utils.data.DataLoader( 43 | dset(root=args.data_root, train=False, transform=T, download=True), 44 | batch_size=args.test_batch_size, shuffle=False, 45 | num_workers=args.workers, pin_memory=True) 46 | 47 | train_dataset = dset(root=args.data_root, train=True, transform=T, download=True) 48 | if valid_size is not None: 49 | valid_dataset = dset(root=args.data_root, train=True, transform=transforms.Compose([ 50 | transforms.RandomHorizontalFlip(), 51 | transforms.RandomCrop(28, 4), 52 | transforms.ToTensor(), 53 | Cutout(16), 54 | ]), download=True) 55 | num_train = len(train_dataset) 56 | indices = list(range(num_train)) 57 | split = int(np.floor(valid_size * num_train)) 58 | np.random.shuffle(indices) 59 | 60 | train_idx, valid_idx = indices[split:], indices[:split] 61 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx) 62 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx) 63 | 64 | train_loader = torch.utils.data.DataLoader( 65 | train_dataset, batch_size=args.batch_size, sampler=train_sampler, 66 | num_workers=args.workers, pin_memory=True) 67 | val_loader = torch.utils.data.DataLoader( 68 | valid_dataset, batch_size=args.batch_size, 
sampler=valid_sampler,
            num_workers=args.workers, pin_memory=True)
        return train_loader, val_loader, test_loader
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)
        return train_loader, test_loader

def cifar_loaders(args, valid_size=None, noaug=None):
    """Build CIFAR-10 / CIFAR-100 data loaders.

    Parameters
    ----------
    args : namespace
        Reads ``dataset`` ('cifar10' or 'cifar100'), ``data_root``,
        ``batch_size``, ``test_batch_size`` and ``workers``; ``arch`` is
        additionally read when ``valid_size`` is given. Datasets are
        downloaded into ``data_root`` if missing (``download=True``).
    valid_size : float, optional
        Fraction of the training set to hold out as a validation split.
    noaug : bool, optional
        If truthy, training images get only ToTensor + normalization
        (no random horizontal flip / random crop).

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)`` when ``valid_size`` is
        not None, otherwise ``(train_loader, test_loader)``.

    Raises
    ------
    NotImplementedError
        If ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    if args.dataset == 'cifar10':
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.201])
        dset = datasets.CIFAR10
    elif args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.507, 0.4865, 0.4409],
                                         std=[0.2673, 0.2564, 0.2761])
        dset = datasets.CIFAR100
    else:
        raise NotImplementedError

    if noaug:
        T = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    else:
        T = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])

    # Evaluation transform: no augmentation.
    T_val = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    test_loader = torch.utils.data.DataLoader(
        dset(root=args.data_root, train=False, transform=T_val, download=True),
        batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    train_dataset = dset(root=args.data_root, train=True, transform=T, download=True)
    if valid_size is not None:
        # The validation split is drawn from the *training* images, viewed
        # through a RandomResizedCrop (crop-scale lower bound 0.6 for
        # 'cifar10_resnet44', 0.5 otherwise).
        valid_dataset = dset(root=args.data_root, train=True, transform=transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomCrop(32, 4),
            # transforms.ToTensor(),
            # normalize,
            # Cutout(16),
            transforms.RandomResizedCrop(size=32, scale=(0.6 if args.arch == 'cifar10_resnet44' else 0.5, 1.)),
            # transforms.RandomGrayscale(p=0.2),
            # transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            # Cutout(2),
        ]), download=True)
        # Random train/validation index split. Uses the global NumPy RNG, so
        # it is reproducible only if the caller seeds np.random.
        num_train = len(train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(valid_size * num_train))
        np.random.shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, sampler=train_sampler,
            num_workers=args.workers, pin_memory=False)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, sampler=valid_sampler,
            num_workers=args.workers, pin_memory=False)

        # remove the randomness
        # (the augmented validation images are drawn once and frozen into an
        # in-memory TensorDataset, so every later pass sees identical inputs)
        xs, ys = [], []
        for _ in range(1):
            for x, y in val_loader:
                xs.append(x); ys.append(y)
        xs = torch.cat(xs); ys = torch.cat(ys)
        valid_dataset = torch.utils.data.TensorDataset(xs, ys)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        return train_loader, val_loader, test_loader
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False)
        return train_loader, test_loader

class Cutout(object):
    """Cutout augmentation: zero out one square patch of an image tensor.

    Dims 1 and 2 of the input are treated as height and width (presumably the
    (C, H, W) output of ``ToTensor`` -- TODO confirm). The patch is centred on
    a uniformly random pixel, has side ``2 * (length // 2)`` before clipping,
    and is clipped at the image border.
    """

    def __init__(self, length):
        # nominal side length of the square patch to erase
        self.length = length

    def __call__(self, img):
        """Zero the patch in ``img`` *in place* (``img *= mask``) and return it."""
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        # broadcast the (H, W) mask over the leading (channel) dimension
        mask = mask.expand_as(img)
        img *= mask
        return img


def imagenet_loaders(args, valid_size=None, noaug=None):
    """Build ImageNet data loaders from ``<data_root>/train`` and ``<data_root>/val`` ImageFolders.

    Architectures whose name contains 'vit' use mean/std 0.5 and bicubic
    resizing to 224. All others use the standard ImageNet statistics with
    bilinear resizing to 256. Evaluation always takes a 224 center crop.

    Parameters
    ----------
    args : namespace
        Reads ``arch``, ``data_root``, ``batch_size``, ``test_batch_size``
        and ``workers``.
    valid_size : float, optional
        Fraction of the training set to hold out as a validation split.
    noaug : bool, optional
        If truthy, training images use the deterministic evaluation pipeline
        instead of RandomResizedCrop + horizontal flip.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)`` when ``valid_size`` is
        not None, otherwise ``(train_loader, test_loader)``.
    """

    if 'vit' in args.arch:
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    if noaug:
        T = transforms.Compose([
            transforms.Resize(224 if 'vit' in args.arch else 256, interpolation=transforms.InterpolationMode.BICUBIC if 'vit' in args.arch else transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        T = transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=transforms.InterpolationMode.BICUBIC if 'vit' in args.arch else transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    T_val = transforms.Compose([
        transforms.Resize(224 if 'vit' in args.arch else 256, interpolation=transforms.InterpolationMode.BICUBIC if 'vit' in args.arch else transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(os.path.join(args.data_root, 'val'), transform=T_val),
        batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    train_dataset = datasets.ImageFolder(os.path.join(args.data_root, 'train'), transform=T)
    if valid_size is not None:
        # NOTE(review): `create_transform` and IMAGENET_DEFAULT_MEAN/STD are
        # imported outside this view (presumably from timm) -- confirm.
        # scale=(0.08, 0.1) keeps only very small crops; looks deliberate for
        # this experiment, but confirm.
        valid_dataset = datasets.ImageFolder(os.path.join(args.data_root, 'train'),
            transform=create_transform(
                input_size=224,
                scale=(0.08, 0.1),
                is_training=True,
                color_jitter=0.4,
                auto_augment=None, #'original', #'v0' #'rand-m9-mstd0.5-inc1', #'v0', 'original'
                interpolation='bicubic',
                re_prob=0.25, #0.25,
                re_mode='pixel',
                re_count=1,
                mean=IMAGENET_DEFAULT_MEAN,
                std=IMAGENET_DEFAULT_STD,
            )
        )
        # Random train/validation index split (global NumPy RNG).
        num_train = len(train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(valid_size * num_train))
        np.random.shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, sampler=train_sampler,
            num_workers=args.workers, pin_memory=False)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, sampler=valid_sampler,
            num_workers=args.workers, pin_memory=False)

        # Freeze one augmented draw of the validation split into an in-memory
        # TensorDataset so later passes are deterministic.
        # NOTE(review): this holds every validation image as a float tensor in
        # memory, which can be very large for ImageNet-sized splits.
        xs, ys = [], []
        for _ in range(1):
            for x, y in val_loader:
                xs.append(x); ys.append(y)
        xs = torch.cat(xs); ys = torch.cat(ys)
        valid_dataset = torch.utils.data.TensorDataset(xs, ys)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        return train_loader, val_loader, test_loader
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False)
        return train_loader, test_loader

--------------------------------------------------------------------------------
/pytorch_cifar_models/shufflenetv2.py:
--------------------------------------------------------------------------------
'''
Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/shufflenetv2.py

BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights
reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
33 | ''' 34 | import sys 35 | import torch 36 | import torch.nn as nn 37 | from torch import Tensor 38 | 39 | try: 40 | from torch.hub import load_state_dict_from_url 41 | except ImportError: 42 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 43 | 44 | from functools import partial 45 | from typing import Dict, Type, Any, Callable, Union, List, Optional 46 | 47 | 48 | cifar10_pretrained_weight_urls = { 49 | 'shufflenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x0_5-1308b4e9.pt', 50 | 'shufflenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_0-98807be3.pt', 51 | 'shufflenetv2_x1_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_5-296694dd.pt', 52 | 'shufflenetv2_x2_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x2_0-ec31611c.pt', 53 | } 54 | 55 | cifar100_pretrained_weight_urls = { 56 | 'shufflenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x0_5-1977720f.pt', 57 | 'shufflenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_0-9ae22beb.pt', 58 | 'shufflenetv2_x1_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_5-e2c85ad8.pt', 59 | 'shufflenetv2_x2_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x2_0-e7e584cd.pt', 60 | } 61 | 62 | 63 | def channel_shuffle(x: Tensor, groups: int) -> Tensor: 64 | batchsize, num_channels, height, width = x.size() 65 | channels_per_group = num_channels // groups 66 | 67 | # reshape 68 | x = x.view(batchsize, groups, 69 | channels_per_group, height, width) 70 | 71 | x = torch.transpose(x, 1, 2).contiguous() 
72 | 73 | # flatten 74 | x = x.view(batchsize, -1, height, width) 75 | 76 | return x 77 | 78 | 79 | class InvertedResidual(nn.Module): 80 | def __init__( 81 | self, 82 | inp: int, 83 | oup: int, 84 | stride: int 85 | ) -> None: 86 | super(InvertedResidual, self).__init__() 87 | 88 | if not (1 <= stride <= 3): 89 | raise ValueError('illegal stride value') 90 | self.stride = stride 91 | 92 | branch_features = oup // 2 93 | assert (self.stride != 1) or (inp == branch_features << 1) 94 | 95 | if self.stride > 1: 96 | self.branch1 = nn.Sequential( 97 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 98 | nn.BatchNorm2d(inp), 99 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 100 | nn.BatchNorm2d(branch_features), 101 | nn.ReLU(inplace=True), 102 | ) 103 | else: 104 | self.branch1 = nn.Sequential() 105 | 106 | self.branch2 = nn.Sequential( 107 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 108 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 109 | nn.BatchNorm2d(branch_features), 110 | nn.ReLU(inplace=True), 111 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 112 | nn.BatchNorm2d(branch_features), 113 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 114 | nn.BatchNorm2d(branch_features), 115 | nn.ReLU(inplace=True), 116 | ) 117 | 118 | @staticmethod 119 | def depthwise_conv( 120 | i: int, 121 | o: int, 122 | kernel_size: int, 123 | stride: int = 1, 124 | padding: int = 0, 125 | bias: bool = False 126 | ) -> nn.Conv2d: 127 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 128 | 129 | def forward(self, x: Tensor) -> Tensor: 130 | if self.stride == 1: 131 | x1, x2 = x.chunk(2, dim=1) 132 | out = torch.cat((x1, self.branch2(x2)), dim=1) 133 | else: 134 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 135 | 136 | out = channel_shuffle(out, 2) 137 
| 138 | return out 139 | 140 | 141 | class ShuffleNetV2(nn.Module): 142 | def __init__( 143 | self, 144 | stages_repeats: List[int], 145 | stages_out_channels: List[int], 146 | num_classes: int = 1000, 147 | inverted_residual: Callable[..., nn.Module] = InvertedResidual 148 | ) -> None: 149 | super(ShuffleNetV2, self).__init__() 150 | 151 | if len(stages_repeats) != 3: 152 | raise ValueError('expected stages_repeats as list of 3 positive ints') 153 | if len(stages_out_channels) != 5: 154 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 155 | self._stage_out_channels = stages_out_channels 156 | 157 | input_channels = 3 158 | output_channels = self._stage_out_channels[0] 159 | self.conv1 = nn.Sequential( 160 | nn.Conv2d(input_channels, output_channels, 3, 1, 1, bias=False), # NOTE: change stride 2 -> 1 for CIFAR10/100 161 | nn.BatchNorm2d(output_channels), 162 | nn.ReLU(inplace=True), 163 | ) 164 | input_channels = output_channels 165 | 166 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) NOTE: remove this maxpool layer for CIFAR10/100 167 | 168 | # Static annotations for mypy 169 | self.stage2: nn.Sequential 170 | self.stage3: nn.Sequential 171 | self.stage4: nn.Sequential 172 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 173 | for name, repeats, output_channels in zip( 174 | stage_names, stages_repeats, self._stage_out_channels[1:]): 175 | seq = [inverted_residual(input_channels, output_channels, 2)] 176 | for i in range(repeats - 1): 177 | seq.append(inverted_residual(output_channels, output_channels, 1)) 178 | setattr(self, name, nn.Sequential(*seq)) 179 | input_channels = output_channels 180 | 181 | output_channels = self._stage_out_channels[-1] 182 | self.conv5 = nn.Sequential( 183 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 184 | nn.BatchNorm2d(output_channels), 185 | nn.ReLU(inplace=True), 186 | ) 187 | 188 | self.fc = nn.Linear(output_channels, num_classes) 189 | 190 | def 
_forward_impl(self, x: Tensor) -> Tensor: 191 | # See note [TorchScript super()] 192 | x = self.conv1(x) 193 | # x = self.maxpool(x) NOTE: remove this maxpool layer for CIFAR10/100 194 | x = self.stage2(x) 195 | x = self.stage3(x) 196 | x = self.stage4(x) 197 | x = self.conv5(x) 198 | x = x.mean([2, 3]) # globalpool 199 | x = self.fc(x) 200 | return x 201 | 202 | def forward(self, x: Tensor) -> Tensor: 203 | return self._forward_impl(x) 204 | 205 | 206 | def _shufflenet_v2( 207 | arch: str, 208 | stages_repeats: List[int], 209 | stages_out_channels: List[int], 210 | model_urls: Dict[str, str], 211 | progress: bool = True, 212 | pretrained: bool = False, 213 | **kwargs: Any 214 | ) -> ShuffleNetV2: 215 | model = ShuffleNetV2(stages_repeats=stages_repeats, stages_out_channels=stages_out_channels, ** kwargs) 216 | if pretrained: 217 | state_dict = load_state_dict_from_url(model_urls[arch], 218 | progress=progress) 219 | model.load_state_dict(state_dict) 220 | return model 221 | 222 | 223 | def cifar10_shufflenetv2_x0_5(*args, **kwargs) -> ShuffleNetV2: pass 224 | def cifar10_shufflenetv2_x1_0(*args, **kwargs) -> ShuffleNetV2: pass 225 | def cifar10_shufflenetv2_x1_5(*args, **kwargs) -> ShuffleNetV2: pass 226 | def cifar10_shufflenetv2_x2_0(*args, **kwargs) -> ShuffleNetV2: pass 227 | 228 | 229 | def cifar100_shufflenetv2_x0_5(*args, **kwargs) -> ShuffleNetV2: pass 230 | def cifar100_shufflenetv2_x1_0(*args, **kwargs) -> ShuffleNetV2: pass 231 | def cifar100_shufflenetv2_x1_5(*args, **kwargs) -> ShuffleNetV2: pass 232 | def cifar100_shufflenetv2_x2_0(*args, **kwargs) -> ShuffleNetV2: pass 233 | 234 | 235 | thismodule = sys.modules[__name__] 236 | for dataset in ["cifar10", "cifar100"]: 237 | for stages_repeats, stages_out_channels, model_name in \ 238 | zip([[4, 8, 4]]*4, 239 | [[24, 48, 96, 192, 1024], [24, 116, 232, 464, 1024], [24, 176, 352, 704, 1024], [24, 244, 488, 976, 2048]], 240 | ["shufflenetv2_x0_5", "shufflenetv2_x1_0", "shufflenetv2_x1_5", 
"shufflenetv2_x2_0"]): 241 | method_name = f"{dataset}_{model_name}" 242 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 243 | num_classes = 10 if dataset == "cifar10" else 100 244 | setattr( 245 | thismodule, 246 | method_name, 247 | partial(_shufflenet_v2, 248 | arch=model_name, 249 | stages_repeats=stages_repeats, 250 | stages_out_channels=stages_out_channels, 251 | model_urls=model_urls, 252 | num_classes=num_classes) 253 | ) 254 | -------------------------------------------------------------------------------- /laplace/curvature/curvature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import MSELoss, CrossEntropyLoss 3 | 4 | 5 | class CurvatureInterface: 6 | """Interface to access curvature for a model and corresponding likelihood. 7 | A `CurvatureInterface` must inherit from this baseclass and implement the 8 | necessary functions `jacobians`, `full`, `kron`, and `diag`. 9 | The interface might be extended in the future to account for other curvature 10 | structures, for example, a block-diagonal one. 11 | 12 | Parameters 13 | ---------- 14 | model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` 15 | torch model (neural network) 16 | likelihood : {'classification', 'regression'} 17 | last_layer : bool, default=False 18 | only consider curvature of last layer 19 | subnetwork_indices : torch.Tensor, default=None 20 | indices of the vectorized model parameters that define the subnetwork 21 | to apply the Laplace approximation over 22 | 23 | Attributes 24 | ---------- 25 | lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss 26 | factor : float 27 | conversion factor between torch losses and base likelihoods 28 | For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss. 
29 | """ 30 | def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): 31 | assert likelihood in ['regression', 'classification'] 32 | self.likelihood = likelihood 33 | self.model = model 34 | self.last_layer = last_layer 35 | self.subnetwork_indices = subnetwork_indices 36 | if likelihood == 'regression': 37 | self.lossfunc = MSELoss(reduction='sum') 38 | self.factor = 0.5 39 | else: 40 | self.lossfunc = CrossEntropyLoss(reduction='sum') 41 | self.factor = 1. 42 | 43 | @property 44 | def _model(self): 45 | return self.model.last_layer if self.last_layer else self.model 46 | 47 | def jacobians(self, x): 48 | """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\). 49 | 50 | Parameters 51 | ---------- 52 | x : torch.Tensor 53 | input data `(batch, input_shape)` on compatible device with model. 54 | 55 | Returns 56 | ------- 57 | Js : torch.Tensor 58 | Jacobians `(batch, parameters, outputs)` 59 | f : torch.Tensor 60 | output function `(batch, outputs)` 61 | """ 62 | raise NotImplementedError 63 | 64 | def last_layer_jacobians(self, x): 65 | """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\) 66 | only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\). 
67 | 68 | Parameters 69 | ---------- 70 | x : torch.Tensor 71 | 72 | Returns 73 | ------- 74 | Js : torch.Tensor 75 | Jacobians `(batch, last-layer-parameters, outputs)` 76 | f : torch.Tensor 77 | output function `(batch, outputs)` 78 | """ 79 | f, phi = self.model.forward_with_features(x) 80 | bsize = phi.shape[0] 81 | output_size = f.shape[-1] 82 | 83 | # calculate Jacobians using the feature vector 'phi' 84 | identity = torch.eye(output_size, device=x.device).unsqueeze(0).tile(bsize, 1, 1) 85 | # Jacobians are batch x output x params 86 | Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1) 87 | if self.model.last_layer.bias is not None: 88 | Js = torch.cat([Js, identity], dim=2) 89 | 90 | return Js, f.detach() 91 | 92 | def gradients(self, x, y): 93 | """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter \\(\\theta\\). 94 | 95 | Parameters 96 | ---------- 97 | x : torch.Tensor 98 | input data `(batch, input_shape)` on compatible device with model. 99 | y : torch.Tensor 100 | 101 | Returns 102 | ------- 103 | loss : torch.Tensor 104 | Gs : torch.Tensor 105 | gradients `(batch, parameters)` 106 | """ 107 | raise NotImplementedError 108 | 109 | def full(self, x, y, **kwargs): 110 | """Compute a dense curvature (approximation) in the form of a \\(P \\times P\\) matrix 111 | \\(H\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\). 112 | 113 | Parameters 114 | ---------- 115 | x : torch.Tensor 116 | input data `(batch, input_shape)` 117 | y : torch.Tensor 118 | labels `(batch, label_shape)` 119 | 120 | Returns 121 | ------- 122 | loss : torch.Tensor 123 | H : torch.Tensor 124 | Hessian approximation `(parameters, parameters)` 125 | """ 126 | raise NotImplementedError 127 | 128 | def kron(self, x, y, **kwargs): 129 | """Compute a Kronecker factored curvature approximation (such as KFAC). 
130 | The approximation to \\(H\\) takes the form of two Kronecker factors \\(Q, H\\), 131 | i.e., \\(H \\approx Q \\otimes H\\) for each Module in the neural network permitting 132 | such curvature. 133 | \\(Q\\) is quadratic in the input-dimension of a module \\(p_{in} \\times p_{in}\\) 134 | and \\(H\\) in the output-dimension \\(p_{out} \\times p_{out}\\). 135 | 136 | Parameters 137 | ---------- 138 | x : torch.Tensor 139 | input data `(batch, input_shape)` 140 | y : torch.Tensor 141 | labels `(batch, label_shape)` 142 | 143 | Returns 144 | ------- 145 | loss : torch.Tensor 146 | H : `laplace.utils.matrix.Kron` 147 | Kronecker factored Hessian approximation. 148 | """ 149 | raise NotImplementedError 150 | 151 | def diag(self, x, y, **kwargs): 152 | """Compute a diagonal Hessian approximation to \\(H\\) and is represented as a 153 | vector of the dimensionality of parameters \\(\\theta\\). 154 | 155 | Parameters 156 | ---------- 157 | x : torch.Tensor 158 | input data `(batch, input_shape)` 159 | y : torch.Tensor 160 | labels `(batch, label_shape)` 161 | 162 | Returns 163 | ------- 164 | loss : torch.Tensor 165 | H : torch.Tensor 166 | vector representing the diagonal of H 167 | """ 168 | raise NotImplementedError 169 | 170 | 171 | class GGNInterface(CurvatureInterface): 172 | """Generalized Gauss-Newton or Fisher Curvature Interface. 173 | The GGN is equal to the Fisher information for the available likelihoods. 174 | In addition to `CurvatureInterface`, methods for Jacobians are required by subclasses. 
175 | 176 | Parameters 177 | ---------- 178 | model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` 179 | torch model (neural network) 180 | likelihood : {'classification', 'regression'} 181 | last_layer : bool, default=False 182 | only consider curvature of last layer 183 | subnetwork_indices : torch.Tensor, default=None 184 | indices of the vectorized model parameters that define the subnetwork 185 | to apply the Laplace approximation over 186 | stochastic : bool, default=False 187 | Fisher if stochastic else GGN 188 | """ 189 | def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): 190 | self.stochastic = stochastic 191 | super().__init__(model, likelihood, last_layer, subnetwork_indices) 192 | 193 | def _get_full_ggn(self, Js, f, y): 194 | """Compute full GGN from Jacobians. 195 | 196 | Parameters 197 | ---------- 198 | Js : torch.Tensor 199 | Jacobians `(batch, parameters, outputs)` 200 | f : torch.Tensor 201 | functions `(batch, outputs)` 202 | y : torch.Tensor 203 | labels compatible with loss 204 | 205 | Returns 206 | ------- 207 | loss : torch.Tensor 208 | H_ggn : torch.Tensor 209 | full GGN approximation `(parameters, parameters)` 210 | """ 211 | loss = self.factor * self.lossfunc(f, y) 212 | if self.likelihood == 'regression': 213 | H_ggn = torch.einsum('mkp,mkq->pq', Js, Js) 214 | else: 215 | # second derivative of log lik is diag(p) - pp^T 216 | ps = torch.softmax(f, dim=-1) 217 | H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps) 218 | H_ggn = torch.einsum('mcp,mck,mkq->pq', Js, H_lik, Js) 219 | return loss.detach(), H_ggn 220 | 221 | def full(self, x, y, **kwargs): 222 | """Compute the full GGN \\(P \\times P\\) matrix as Hessian approximation 223 | \\(H_{ggn}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\). 
224 | For last-layer, reduced to \\(\\theta_{last}\\) 225 | 226 | Parameters 227 | ---------- 228 | x : torch.Tensor 229 | input data `(batch, input_shape)` 230 | y : torch.Tensor 231 | labels `(batch, label_shape)` 232 | 233 | Returns 234 | ------- 235 | loss : torch.Tensor 236 | H_ggn : torch.Tensor 237 | GGN `(parameters, parameters)` 238 | """ 239 | if self.stochastic: 240 | raise ValueError('Stochastic approximation not implemented for full GGN.') 241 | 242 | if self.last_layer: 243 | Js, f = self.last_layer_jacobians(x) 244 | else: 245 | Js, f = self.jacobians(x) 246 | loss, H_ggn = self._get_full_ggn(Js, f, y) 247 | 248 | return loss, H_ggn 249 | 250 | 251 | class EFInterface(CurvatureInterface): 252 | """Interface for Empirical Fisher as Hessian approximation. 253 | In addition to `CurvatureInterface`, methods for gradients are required by subclasses. 254 | 255 | Parameters 256 | ---------- 257 | model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` 258 | torch model (neural network) 259 | likelihood : {'classification', 'regression'} 260 | last_layer : bool, default=False 261 | only consider curvature of last layer 262 | subnetwork_indices : torch.Tensor, default=None 263 | indices of the vectorized model parameters that define the subnetwork 264 | to apply the Laplace approximation over 265 | 266 | Attributes 267 | ---------- 268 | lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss 269 | factor : float 270 | conversion factor between torch losses and base likelihoods 271 | For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss. 272 | """ 273 | 274 | def full(self, x, y, **kwargs): 275 | """Compute the full EF \\(P \\times P\\) matrix as Hessian approximation 276 | \\(H_{ef}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\). 
277 | For last-layer, reduced to \\(\\theta_{last}\\) 278 | 279 | Parameters 280 | ---------- 281 | x : torch.Tensor 282 | input data `(batch, input_shape)` 283 | y : torch.Tensor 284 | labels `(batch, label_shape)` 285 | 286 | Returns 287 | ------- 288 | loss : torch.Tensor 289 | H_ef : torch.Tensor 290 | EF `(parameters, parameters)` 291 | """ 292 | Gs, loss = self.gradients(x, y) 293 | H_ef = Gs.T @ Gs 294 | return self.factor * loss.detach(), self.factor * H_ef 295 | -------------------------------------------------------------------------------- /laplace/marglik_training.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | import torch 4 | from torch.optim import Adam 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | from torch.nn.utils import parameters_to_vector 7 | import warnings 8 | import logging 9 | 10 | from laplace import Laplace 11 | from laplace.curvature import AsdlGGN 12 | from laplace.utils import expand_prior_precision 13 | 14 | 15 | def marglik_training( 16 | model, 17 | train_loader, 18 | likelihood='classification', 19 | hessian_structure='kron', 20 | backend=AsdlGGN, 21 | optimizer_cls=Adam, 22 | optimizer_kwargs=None, 23 | scheduler_cls=None, 24 | scheduler_kwargs=None, 25 | n_epochs=300, 26 | lr_hyp=1e-1, 27 | prior_structure='layerwise', 28 | n_epochs_burnin=0, 29 | n_hypersteps=10, 30 | marglik_frequency=1, 31 | prior_prec_init=1., 32 | sigma_noise_init=1., 33 | temperature=1. 34 | ): 35 | """Marginal-likelihood based training (Algorithm 1 in [1]). 36 | Optimize model parameters and hyperparameters jointly. 37 | Model parameters are optimized to minimize negative log joint (train loss) 38 | while hyperparameters minimize negative log marginal likelihood. 39 | 40 | This method replaces standard neural network training and adds hyperparameter 41 | optimization to the procedure. 
42 | 43 | The settings of standard training can be controlled by passing `train_loader`, 44 | `optimizer_cls`, `optimizer_kwargs`, `scheduler_cls`, `scheduler_kwargs`, and `n_epochs`. 45 | The `model` should return logits, i.e., no softmax should be applied. 46 | With `likelihood='classification'` or `'regression'`, one can choose between 47 | categorical likelihood (CrossEntropyLoss) and Gaussian likelihood (MSELoss). 48 | 49 | As in [1], we optimize prior precision and, for regression, observation noise 50 | using the marginal likelihood. The prior precision structure can be chosen 51 | as `'scalar'`, `'layerwise'`, or `'diagonal'`. `'layerwise'` is a good default 52 | and available to all Laplace approximations. `lr_hyp` is the step size of the 53 | Adam hyperparameter optimizer, `n_hypersteps` controls the number of steps 54 | for each estimated marginal likelihood, `n_epochs_burnin` controls how many 55 | epochs to skip marginal likelihood estimation, `marglik_frequency` controls 56 | how often to estimate the marginal likelihood (default of 1 re-estimates 57 | after every epoch, 5 would estimate every 5-th epoch). 58 | 59 | References 60 | ---------- 61 | [1] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. 62 | [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). 63 | ICML 2021. 64 | 65 | Parameters 66 | ---------- 67 | model : torch.nn.Module 68 | torch neural network model (needs to comply with Backend choice) 69 | train_loader : DataLoader 70 | pytorch dataloader that implements `len(train_loader.dataset)` to obtain number of data points 71 | likelihood : str, default='classification' 72 | 'classification' or 'regression' 73 | hessian_structure : {'diag', 'kron', 'full'}, default='kron' 74 | structure of the Hessian approximation 75 | backend : Backend, default=AsdlGGN 76 | Curvature subclass, e.g. 
AsdlGGN/AsdlEF or BackPackGGN/BackPackEF 77 | optimizer_cls : torch.optim.Optimizer, default=Adam 78 | optimizer to use for optimizing the neural network parameters togeth with `train_loader` 79 | optimizer_kwargs : dict, default=None 80 | keyword arguments for `optimizer_cls`, for example to change learning rate or momentum 81 | scheduler_cls : torch.optim.lr_scheduler._LRScheduler, default=None 82 | optionally, a scheduler to use on the learning rate of the optimizer. 83 | `scheduler.step()` is called after every batch of the standard training. 84 | scheduler_kwargs : dict, default=None 85 | keyword arguments for `scheduler_cls`, e.g. `lr_min` for CosineAnnealingLR 86 | n_epochs : int, default=300 87 | number of epochs to train for 88 | lr_hyp : float, default=0.1 89 | Adam learning rate for hyperparameters 90 | prior_structure : str, default='layerwise' 91 | structure of the prior. one of `['scalar', 'layerwise', 'diagonal']` 92 | n_epochs_burnin : int default=0 93 | how many epochs to train without estimating and differentiating marglik 94 | n_hypersteps : int, default=10 95 | how many steps to take on the hyperparameters when marglik is estimated 96 | marglik_frequency : int 97 | how often to estimate (and differentiate) the marginal likelihood 98 | `marglik_frequency=1` would be every epoch, 99 | `marglik_frequency=5` would be every 5 epochs. 100 | prior_prec_init : float, default=1.0 101 | initial prior precision 102 | sigma_noise_init : float, default=1.0 103 | initial observation noise (for regression only) 104 | temperature : float, default=1.0 105 | factor for the likelihood for 'overcounting' data. Might be required for data augmentation. 
106 | 107 | Returns 108 | ------- 109 | lap : laplace 110 | fit Laplace approximation with the best obtained marginal likelihood during training 111 | model : torch.nn.Module 112 | corresponding model with the MAP parameters 113 | margliks : list 114 | list of marginal likelihoods obtained during training (to monitor convergence) 115 | losses : list 116 | list of losses (log joints) obtained during training (to monitor convergence) 117 | """ 118 | if 'weight_decay' in optimizer_kwargs: 119 | warnings.warn('Weight decay is handled and optimized. Will be set to 0.') 120 | optimizer_kwargs['weight_decay'] = 0.0 121 | 122 | # get device, data set size N, number of layers H, number of parameters P 123 | device = parameters_to_vector(model.parameters()).device 124 | N = len(train_loader.dataset) 125 | H = len(list(model.parameters())) 126 | P = len(parameters_to_vector(model.parameters())) 127 | 128 | # differentiable hyperparameters 129 | hyperparameters = list() 130 | # prior precision 131 | log_prior_prec_init = np.log(temperature * prior_prec_init) 132 | if prior_structure == 'scalar': 133 | log_prior_prec = log_prior_prec_init * torch.ones(1, device=device) 134 | elif prior_structure == 'layerwise': 135 | log_prior_prec = log_prior_prec_init * torch.ones(H, device=device) 136 | elif prior_structure == 'diagonal': 137 | log_prior_prec = log_prior_prec_init * torch.ones(P, device=device) 138 | else: 139 | raise ValueError(f'Invalid prior structure {prior_structure}') 140 | log_prior_prec.requires_grad = True 141 | hyperparameters.append(log_prior_prec) 142 | 143 | # set up loss (and observation noise hyperparam) 144 | if likelihood == 'classification': 145 | criterion = CrossEntropyLoss(reduction='mean') 146 | sigma_noise = 1. 
147 | elif likelihood == 'regression': 148 | criterion = MSELoss(reduction='mean') 149 | log_sigma_noise_init = np.log(sigma_noise_init) 150 | log_sigma_noise = log_sigma_noise_init * torch.ones(1, device=device) 151 | log_sigma_noise.requires_grad = True 152 | hyperparameters.append(log_sigma_noise) 153 | 154 | # set up model optimizer 155 | if optimizer_kwargs is None: 156 | optimizer_kwargs = dict() 157 | optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) 158 | 159 | # set up learning rate scheduler 160 | if scheduler_cls is not None: 161 | if scheduler_kwargs is None: 162 | scheduler_kwargs = dict() 163 | scheduler = scheduler_cls(optimizer, **scheduler_kwargs) 164 | 165 | # set up hyperparameter optimizer 166 | hyper_optimizer = Adam(hyperparameters, lr=lr_hyp) 167 | 168 | best_marglik = np.inf 169 | best_model_dict = None 170 | best_precision = None 171 | losses = list() 172 | margliks = list() 173 | 174 | for epoch in range(1, n_epochs + 1): 175 | epoch_loss = 0 176 | epoch_perf = 0 177 | 178 | # standard NN training per batch 179 | for X, y in train_loader: 180 | X, y = X.to(device), y.to(device) 181 | optimizer.zero_grad() 182 | if likelihood == 'regression': 183 | sigma_noise = torch.exp(log_sigma_noise).detach() 184 | crit_factor = temperature / (2 * sigma_noise.square()) 185 | else: 186 | crit_factor = temperature 187 | prior_prec = torch.exp(log_prior_prec).detach() 188 | theta = parameters_to_vector(model.parameters()) 189 | delta = expand_prior_precision(prior_prec, model) 190 | f = model(X) 191 | loss = criterion(f, y) + (0.5 * (delta * theta) @ theta) / N / crit_factor 192 | loss.backward() 193 | optimizer.step() 194 | epoch_loss += loss.cpu().item() * len(y) 195 | if likelihood == 'regression': 196 | epoch_perf += (f.detach() - y).square().sum() 197 | else: 198 | epoch_perf += torch.sum(torch.argmax(f.detach(), dim=-1) == y).item() 199 | if scheduler_cls is not None: 200 | scheduler.step() 201 | 202 | losses.append(epoch_loss / N) 
203 | 204 | # compute validation error to report during training 205 | logging.info(f'MARGLIK[epoch={epoch}]: network training. Loss={losses[-1]:.3f}.' + 206 | f'Perf={epoch_perf/N:.3f}') 207 | 208 | # only update hyperparameters every marglik_frequency steps after burnin 209 | if (epoch % marglik_frequency) != 0 or epoch < n_epochs_burnin: 210 | continue 211 | 212 | # optimizer hyperparameters by differentiating marglik 213 | # 1. fit laplace approximation 214 | sigma_noise = 1 if likelihood == 'classification' else torch.exp(log_sigma_noise) 215 | prior_prec = torch.exp(log_prior_prec) 216 | lap = Laplace( 217 | model, likelihood, hessian_structure=hessian_structure, sigma_noise=sigma_noise, 218 | prior_precision=prior_prec, temperature=temperature, backend=backend, 219 | subset_of_weights='all' 220 | ) 221 | lap.fit(train_loader) 222 | 223 | # 2. differentiate wrt. hyperparameters for n_hypersteps 224 | for _ in range(n_hypersteps): 225 | hyper_optimizer.zero_grad() 226 | if likelihood == 'classification': 227 | sigma_noise = None 228 | elif likelihood == 'regression': 229 | sigma_noise = torch.exp(log_sigma_noise) 230 | prior_prec = torch.exp(log_prior_prec) 231 | marglik = -lap.log_marginal_likelihood(prior_prec, sigma_noise) 232 | marglik.backward() 233 | hyper_optimizer.step() 234 | margliks.append(marglik.item()) 235 | 236 | # early stopping on marginal likelihood 237 | if margliks[-1] < best_marglik: 238 | best_model_dict = deepcopy(model.state_dict()) 239 | best_precision = deepcopy(prior_prec.detach()) 240 | best_sigma = 1 if likelihood == 'classification' else deepcopy(sigma_noise.detach()) 241 | best_marglik = margliks[-1] 242 | logging.info(f'MARGLIK[epoch={epoch}]: marglik optimization. MargLik={best_marglik:.2f}. ' 243 | + 'Saving new best model.') 244 | else: 245 | logging.info(f'MARGLIK[epoch={epoch}]: marglik optimization. MargLik={margliks[-1]:.2f}.' 
246 | + f'No improvement over {best_marglik:.2f}') 247 | 248 | logging.info('MARGLIK: finished training. Recover best model and fit Laplace.') 249 | if best_model_dict is not None: 250 | model.load_state_dict(best_model_dict) 251 | sigma_noise = best_sigma 252 | prior_prec = best_precision 253 | lap = Laplace( 254 | model, likelihood, hessian_structure=hessian_structure, sigma_noise=sigma_noise, 255 | prior_precision=prior_prec, temperature=temperature, backend=backend, 256 | subset_of_weights='all' 257 | ) 258 | lap.fit(train_loader) 259 | return lap, model, margliks, losses 260 | --------------------------------------------------------------------------------