├── .gitignore ├── README.md ├── dataset.py ├── viExperiment.py └── viModel.py /.gitignore: -------------------------------------------------------------------------------- 1 | #tmp files 2 | checkpoints/* 3 | data/* 4 | 5 | #compiled files 6 | *.pyc 7 | 8 | #backup and cache files 9 | *~ 10 | __pycache__ 11 | 12 | #project files 13 | .spyproject 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian MNIST 2 | 3 | Bayesian MNIST is a companion toy example for our tutorial ["Hands-on Bayesian Neural Networks - A Tutorial for Deep Learning Users"](https://doi.org/10.1109/MCI.2022.3155327). It is a minimal hello-world project showing how a BNN can be implemented to perform classification on MNIST. 4 | 5 | ## Dependencies 6 | 7 | The code depends on: 8 | 9 | - numpy (tested with version 1.19.2), 10 | - pytorch (tested with version 1.8.1), 11 | - torchvision (tested with version 0.9.1), 12 | - matplotlib (tested with version 3.1.1), 13 | 14 | and two modules from the Python standard library: argparse and os. 15 | 16 | It has been tested with Python 3.6.9. 17 | 18 | ## Usage 19 | 20 | The project is split into multiple files: 21 | 22 | - dataset.py implements a few routines to filter the MNIST dataset, allowing us to train the model with one digit held out; that digit is later presented to the model to see how it reacts. 23 | - viModel.py implements the variational inference layers and the model we are using. 24 | - viExperiment.py is the script that runs the actual experiment. It can be called with the -h option to get a contextual help message: 25 | 26 | python viExperiment.py -h 27 | 28 | ## Citation 29 | 30 | If you use our code in your project, please cite our tutorial: 31 | 32 | @ARTICLE{9756596, 33 | author={Jospin, Laurent Valentin and Laga, Hamid and Boussaid, Farid and Buntine, Wray and Bennamoun, Mohammed}, 34 | journal={IEEE Computational Intelligence Magazine}, 35 | title={Hands-On Bayesian Neural Networks—A Tutorial for Deep Learning Users}, 36 | year={2022}, 37 | volume={17}, 38 | number={2}, 39 | pages={29-48}, 40 | doi={10.1109/MCI.2022.3155327} 41 | } 42 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon May 17 11:05:05 2021 5 | 6 | @author: laurent 7 | """ 8 | 9 | import torch 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | import torchvision 12 | 13 | def getSets(filteredClass = None, removeFiltered = True) : 14 | """ 15 | Return the MNIST train and test datasets, optionally keeping or removing the class given by filteredClass 16 | """ 17 | 18 | train = torchvision.datasets.MNIST('./data/', train=True, download=True, 19 | transform=torchvision.transforms.Compose([ 20 | torchvision.transforms.ToTensor(), 21 | torchvision.transforms.Normalize((0.1307,), (0.3081,)) 22 | ])) 23 | 24 | test = torchvision.datasets.MNIST('./data/', train=False, download=True, 25 | transform=torchvision.transforms.Compose([ 26 | torchvision.transforms.ToTensor(), 27 | torchvision.transforms.Normalize((0.1307,), (0.3081,)) 28 | ])) 29 | 30 | if filteredClass is not None : 31 | 32 | train_loader = torch.utils.data.DataLoader(train, batch_size=len(train)) 33 | 34 | train_labels = next(iter(train_loader))[1].squeeze() 35 | 36 | test_loader = torch.utils.data.DataLoader(test, batch_size=len(test)) 37 | 38 | test_labels = 
next(iter(test_loader))[1].squeeze() 39 | 40 | if removeFiltered : 41 | trainIndices = torch.nonzero(train_labels != filteredClass).squeeze() 42 | testIndices = torch.nonzero(test_labels != filteredClass).squeeze() 43 | else : 44 | trainIndices = torch.nonzero(train_labels == filteredClass).squeeze() 45 | testIndices = torch.nonzero(test_labels == filteredClass).squeeze() 46 | 47 | train = torch.utils.data.Subset(train, trainIndices) 48 | test = torch.utils.data.Subset(test, testIndices) 49 | 50 | return train, test 51 | 52 | if __name__ == "__main__" : 53 | 54 | #test the getSets function 55 | train, test = getSets(filteredClass = 3, removeFiltered = False) 56 | 57 | test_loader = torch.utils.data.DataLoader(test, batch_size=len(test)) 58 | 59 | images, labels = next(iter(test_loader)) 60 | 61 | print(images.shape) 62 | print(torch.unique(labels.squeeze())) -------------------------------------------------------------------------------- /viExperiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon May 17 16:04:42 2021 5 | 6 | @author: laurent 7 | """ 8 | 9 | from dataset import getSets 10 | from viModel import BayesianMnistNet 11 | 12 | import numpy as np 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | import torch 17 | from torch.optim import Adam 18 | from torch.utils.data import DataLoader 19 | 20 | import os 21 | 22 | import argparse as args 23 | 24 | def saveModels(models, savedir) : 25 | 26 | for i, m in enumerate(models) : 27 | 28 | saveFileName = os.path.join(savedir, "model{}.pth".format(i)) 29 | 30 | torch.save({"model_state_dict": m.state_dict()}, os.path.abspath(saveFileName)) 31 | 32 | def loadModels(savedir) : 33 | 34 | models = [] 35 | 36 | for f in os.listdir(savedir) : 37 | 38 | model = BayesianMnistNet(p_mc_dropout=None) 39 | model.load_state_dict(torch.load(os.path.abspath(os.path.join(savedir, f)))["model_state_dict"]) 40 | models.append(model) 41 | 42 | return models 43 | 44 | if __name__ == "__main__" : 45 | 46 | parser = args.ArgumentParser(description='Train a BNN on MNIST') 47 | 48 | parser.add_argument('--filteredclass', type=int, default = 5, choices = [x for x in range(10)], help="The class to ignore during training") 49 | parser.add_argument('--testclass', type=int, default = 4, choices = [x for x in range(10)], help="The seen class to test against; it should differ from the filtered class") 50 | 51 | 52 | parser.add_argument('--savedir', default = None, help="Directory where the models can be saved to or loaded from") 53 | parser.add_argument('--notrain', action = "store_true", help="Load the models directly instead of training") 54 | 55 | parser.add_argument('--nepochs', type=int, default = 10, help="The number of epochs to train for") 56 | parser.add_argument('--nbatch', type=int, default = 64, help="Batch size used for training") 57 | parser.add_argument('--nruntests', type=int, default = 50, help="The number of passes to use at test time for Monte Carlo uncertainty estimation") 58 | parser.add_argument('--learningrate', type=float, default = 5e-3, help="The learning rate of the optimizer") 59 | parser.add_argument('--numnetworks', type=int, default = 10, help="The number of networks to train to make an ensemble") 60 | 61 | args = parser.parse_args() 62 | plt.rcParams["font.family"] = "serif" 63 | 64 | 65 | train, test = getSets(filteredClass = args.filteredclass) 66 | train_filtered, test_filtered = getSets(filteredClass = args.filteredclass, removeFiltered = 
False) 67 | 68 | N = len(train) 69 | 70 | train_loader = torch.utils.data.DataLoader(train, batch_size=args.nbatch) 71 | test_loader = torch.utils.data.DataLoader(test, batch_size=args.nbatch) 72 | 73 | batchLen = len(train_loader) 74 | digitsBatchLen = len(str(batchLen)) 75 | 76 | models = [] 77 | 78 | # Training or Loading 79 | if args.notrain : 80 | 81 | models = loadModels(args.savedir) 82 | 83 | else : 84 | 85 | for i in np.arange(args.numnetworks) : 86 | print("Training model {}/{}:".format(i+1, args.numnetworks)) 87 | 88 | #Initialize the model 89 | model = BayesianMnistNet(p_mc_dropout=None) #p_mc_dropout=None will disable MC-Dropout for this bnn, as we found out it makes learning much much slower. 90 | loss = torch.nn.NLLLoss(reduction='mean') #negative log likelihood will be part of the ELBO 91 | 92 | optimizer = Adam(model.parameters(), lr=args.learningrate) 93 | optimizer.zero_grad() 94 | 95 | for n in np.arange(args.nepochs) : 96 | 97 | for batch_id, sampl in enumerate(train_loader) : 98 | 99 | images, labels = sampl 100 | 101 | pred = model(images, stochastic=True) 102 | 103 | logprob = loss(pred, labels) 104 | l = N*logprob 105 | 106 | modelloss = model.evalAllLosses() 107 | l += modelloss 108 | 109 | optimizer.zero_grad() 110 | l.backward() 111 | 112 | optimizer.step() 113 | 114 | print("\r", ("\tEpoch {}/{}: Train step {"+(":0{}d".format(digitsBatchLen))+"}/{} prob = {:.4f} model = {:.4f} loss = {:.4f} ").format( 115 | n+1, args.nepochs, 116 | batch_id+1, 117 | batchLen, 118 | torch.exp(-logprob.detach().cpu()).item(), 119 | modelloss.detach().cpu().item(), 120 | l.detach().cpu().item()), end="") 121 | print("") 122 | 123 | models.append(model) 124 | 125 | if args.savedir is not None : 126 | saveModels(models, args.savedir) 127 | 128 | 129 | # Testing 130 | if args.testclass != args.filteredclass : 131 | 132 | train_filtered_seen, test_filtered_seen = getSets(filteredClass = args.testclass, removeFiltered = False) 133 | 134 | print("") 135 | print("Testing against seen class:") 136 | 137 | with torch.no_grad() : 138 | 139 | samples = torch.zeros((args.nruntests, len(test_filtered_seen), 10)) 140 | 141 | test_loader = DataLoader(test_filtered_seen, batch_size=len(test_filtered_seen)) 142 | images, labels = next(iter(test_loader)) 143 | 144 | for i in np.arange(args.nruntests) : 145 | print("\r", "\tTest run {}/{}".format(i+1, args.nruntests), end="") 146 | model = np.random.randint(args.numnetworks) 147 | model = models[model] 148 | 149 | samples[i,:,:] = torch.exp(model(images)) 150 | 151 | print("") 152 | 153 | withinSampleMean = torch.mean(samples, dim=0) 154 | samplesMean = torch.mean(samples, dim=(0,1)) 155 | 156 | withinSampleStd = torch.sqrt(torch.mean(torch.var(samples, dim=0), dim=0)) 157 | acrossSamplesStd = torch.std(withinSampleMean, dim=0) 158 | 159 | print("") 160 | print("Class prediction analysis:") 161 | print("\tMean class probabilities:") 162 | print(samplesMean) 163 | print("\tPrediction standard deviation per sample:") 164 | print(withinSampleStd) 165 | print("\tPrediction standard deviation across samples:") 166 | print(acrossSamplesStd) 167 | 168 | plt.figure("Seen class probabilities") 169 | plt.bar(np.arange(10), samplesMean.numpy()) 170 | plt.xlabel('digits') 171 | plt.ylabel('digit prob') 172 | plt.ylim([0,1]) 173 | plt.xticks(np.arange(10)) 174 | 175 | plt.figure("Seen inner and outter sample std") 176 | plt.bar(np.arange(10)-0.2, withinSampleStd.numpy(), width = 0.4, label="Within sample") 177 | plt.bar(np.arange(10)+0.2, 
acrossSamplesStd.numpy(), width = 0.4, label="Across samples") 178 | plt.legend() 179 | plt.xlabel('digits') 180 | plt.ylabel('std digit prob') 181 | plt.xticks(np.arange(10)) 182 | 183 | 184 | 185 | 186 | 187 | print("") 188 | print("Testing against unseen class:") 189 | 190 | with torch.no_grad() : 191 | 192 | samples = torch.zeros((args.nruntests, len(test_filtered), 10)) 193 | 194 | test_loader = DataLoader(test_filtered, batch_size=len(test_filtered)) 195 | images, labels = next(iter(test_loader)) 196 | 197 | for i in np.arange(args.nruntests) : 198 | print("\r", "\tTest run {}/{}".format(i+1, args.nruntests), end="") 199 | model = np.random.randint(args.numnetworks) 200 | model = models[model] 201 | 202 | samples[i,:,:] = torch.exp(model(images)) 203 | 204 | print("") 205 | 206 | withinSampleMean = torch.mean(samples, dim=0) 207 | samplesMean = torch.mean(samples, dim=(0,1)) 208 | 209 | withinSampleStd = torch.sqrt(torch.mean(torch.var(samples, dim=0), dim=0)) 210 | acrossSamplesStd = torch.std(withinSampleMean, dim=0) 211 | 212 | print("") 213 | print("Class prediction analysis:") 214 | print("\tMean class probabilities:") 215 | print(samplesMean) 216 | print("\tPrediction standard deviation per sample:") 217 | print(withinSampleStd) 218 | print("\tPrediction standard deviation across samples:") 219 | print(acrossSamplesStd) 220 | 221 | plt.figure("Unseen class probabilities") 222 | plt.bar(np.arange(10), samplesMean.numpy()) 223 | plt.xlabel('digits') 224 | plt.ylabel('digit prob') 225 | plt.ylim([0,1]) 226 | plt.xticks(np.arange(10)) 227 | 228 | plt.figure("Unseen inner and outter sample std") 229 | plt.bar(np.arange(10)-0.2, withinSampleStd.numpy(), width = 0.4, label="Within sample") 230 | plt.bar(np.arange(10)+0.2, acrossSamplesStd.numpy(), width = 0.4, label="Across samples") 231 | plt.legend() 232 | plt.xlabel('digits') 233 | plt.ylabel('std digit prob') 234 | plt.xticks(np.arange(10)) 235 | 236 | 237 | 238 | 239 | 240 | print("") 241 | print("Testing against pure white noise:") 242 | 243 | with torch.no_grad() : 244 | 245 | l = 1000 246 | 247 | samples = torch.zeros((args.nruntests, l, 10)) 248 | 249 | random = torch.rand((l,1,28,28)) 250 | 251 | for i in np.arange(args.nruntests) : 252 | print("\r", "\tTest run {}/{}".format(i+1, args.nruntests), end="") 253 | model = np.random.randint(args.numnetworks) 254 | model = models[model] 255 | 256 | samples[i,:,:] = torch.exp(model(random)) 257 | 258 | print("") 259 | 260 | withinSampleMean = torch.mean(samples, dim=0) 261 | samplesMean = torch.mean(samples, dim=(0,1)) 262 | 263 | withinSampleStd = torch.sqrt(torch.mean(torch.var(samples, dim=0), dim=0)) 264 | acrossSamplesStd = torch.std(withinSampleMean, dim=0) 265 | 266 | print("") 267 | print("Class prediction analysis:") 268 | print("\tMean class probabilities:") 269 | print(samplesMean) 270 | print("\tPrediction standard deviation per sample:") 271 | print(withinSampleStd) 272 | print("\tPrediction standard deviation across samples:") 273 | print(acrossSamplesStd) 274 | 275 | plt.figure("White noise class probabilities") 276 | plt.bar(np.arange(10), samplesMean.numpy()) 277 | plt.xlabel('digits') 278 | plt.ylabel('digit prob') 279 | plt.ylim([0,1]) 280 | plt.xticks(np.arange(10)) 281 | 282 | plt.figure("White noise inner and outter sample std") 283 | plt.bar(np.arange(10)-0.2, withinSampleStd.numpy(), width = 0.4, label="Within sample") 284 | plt.bar(np.arange(10)+0.2, acrossSamplesStd.numpy(), width = 0.4, label="Across samples") 285 | plt.legend() 286 | 
plt.xlabel('digits') 287 | plt.ylabel('std digit prob') 288 | plt.xticks(np.arange(10)) 289 | 290 | plt.show() -------------------------------------------------------------------------------- /viModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon May 17 13:05:55 2021 5 | 6 | @author: laurent 7 | """ 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn.parameter import Parameter 14 | 15 | from torch.distributions.normal import Normal 16 | 17 | class VIModule(nn.Module) : 18 | """ 19 | A mixin class to attach loss functions to layer. This is usefull when doing variational inference with deep learning. 20 | """ 21 | 22 | def __init__(self, *args, **kwargs) : 23 | super().__init__(*args, **kwargs) 24 | 25 | self._internalLosses = [] 26 | self.lossScaleFactor = 1 27 | 28 | def addLoss(self, func) : 29 | self._internalLosses.append(func) 30 | 31 | def evalLosses(self) : 32 | t_loss = 0 33 | 34 | for l in self._internalLosses : 35 | t_loss = t_loss + l(self) 36 | 37 | return t_loss 38 | 39 | def evalAllLosses(self) : 40 | 41 | t_loss = self.evalLosses()*self.lossScaleFactor 42 | 43 | for m in self.children() : 44 | if isinstance(m, VIModule) : 45 | t_loss = t_loss + m.evalAllLosses()*self.lossScaleFactor 46 | 47 | return t_loss 48 | 49 | 50 | class MeanFieldGaussianFeedForward(VIModule) : 51 | """ 52 | A feed forward layer with a Gaussian prior distribution and a Gaussian variational posterior. 53 | """ 54 | 55 | def __init__(self, 56 | in_features, 57 | out_features, 58 | bias = True, 59 | groups=1, 60 | weightPriorMean = 0, 61 | weightPriorSigma = 1., 62 | biasPriorMean = 0, 63 | biasPriorSigma = 1., 64 | initMeanZero = False, 65 | initBiasMeanZero = False, 66 | initPriorSigmaScale = 0.01) : 67 | 68 | 69 | super(MeanFieldGaussianFeedForward, self).__init__() 70 | 71 | self.samples = {'weights' : None, 'bias' : None, 'wNoiseState' : None, 'bNoiseState' : None} 72 | 73 | self.in_features = in_features 74 | self.out_features = out_features 75 | self.has_bias = bias 76 | 77 | self.weights_mean = Parameter((0. if initMeanZero else 1.)*(torch.rand(out_features, int(in_features/groups))-0.5)) 78 | self.lweights_sigma = Parameter(torch.log(initPriorSigmaScale*weightPriorSigma*torch.ones(out_features, int(in_features/groups)))) 79 | 80 | self.noiseSourceWeights = Normal(torch.zeros(out_features, int(in_features/groups)), 81 | torch.ones(out_features, int(in_features/groups))) 82 | 83 | self.addLoss(lambda s : 0.5*s.getSampledWeights().pow(2).sum()/weightPriorSigma**2) 84 | self.addLoss(lambda s : -self.out_features/2*np.log(2*np.pi) - 0.5*s.samples['wNoiseState'].pow(2).sum() - s.lweights_sigma.sum()) 85 | 86 | if self.has_bias : 87 | self.bias_mean = Parameter((0. 
if initBiasMeanZero else 1.)*(torch.rand(out_features)-0.5)) 88 | self.lbias_sigma = Parameter(torch.log(initPriorSigmaScale*biasPriorSigma*torch.ones(out_features))) 89 | 90 | self.noiseSourceBias = Normal(torch.zeros(out_features), torch.ones(out_features)) 91 | 92 | self.addLoss(lambda s : 0.5*s.getSampledBias().pow(2).sum()/biasPriorSigma**2) 93 | self.addLoss(lambda s : -self.out_features/2*np.log(2*np.pi) - 0.5*s.samples['bNoiseState'].pow(2).sum() - self.lbias_sigma.sum()) 94 | 95 | 96 | def sampleTransform(self, stochastic=True) : 97 | self.samples['wNoiseState'] = self.noiseSourceWeights.sample().to(device=self.weights_mean.device) 98 | self.samples['weights'] = self.weights_mean + (torch.exp(self.lweights_sigma)*self.samples['wNoiseState'] if stochastic else 0) 99 | 100 | if self.has_bias : 101 | self.samples['bNoiseState'] = self.noiseSourceBias.sample().to(device=self.bias_mean.device) 102 | self.samples['bias'] = self.bias_mean + (torch.exp(self.lbias_sigma)*self.samples['bNoiseState'] if stochastic else 0) 103 | 104 | def getSampledWeights(self) : 105 | return self.samples['weights'] 106 | 107 | def getSampledBias(self) : 108 | return self.samples['bias'] 109 | 110 | def forward(self, x, stochastic=True) : 111 | 112 | self.sampleTransform(stochastic=stochastic) 113 | 114 | return nn.functional.linear(x, self.samples['weights'], bias = self.samples['bias'] if self.has_bias else None) 115 | 116 | 117 | class MeanFieldGaussian2DConvolution(VIModule) : 118 | """ 119 | A Bayesian module that fit a posterior gaussian distribution on a 2D convolution module with normal prior. 120 | """ 121 | 122 | def __init__(self, 123 | in_channels, 124 | out_channels, 125 | kernel_size, 126 | stride=1, 127 | padding=0, 128 | dilation=1, 129 | groups=1, 130 | bias=True, 131 | padding_mode='zeros', 132 | wPriorSigma = 1., 133 | bPriorSigma = 1., 134 | initMeanZero = False, 135 | initBiasMeanZero = False, 136 | initPriorSigmaScale = 0.01) : 137 | 138 | super(MeanFieldGaussian2DConvolution, self).__init__() 139 | 140 | self.samples = {'weights' : None, 'bias' : None, 'wNoiseState' : None, 'bNoiseState' : None} 141 | 142 | self.in_channels = in_channels 143 | self.out_channels = out_channels 144 | self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) 145 | self.stride = stride 146 | self.padding = padding 147 | self.dilation = dilation 148 | self.groups = groups 149 | self.has_bias = bias 150 | self.padding_mode = padding_mode 151 | 152 | 153 | self.weights_mean = Parameter((0. if initMeanZero else 1.)*(torch.rand(out_channels, int(in_channels/groups), self.kernel_size[0], self.kernel_size[1])-0.5)) 154 | self.lweights_sigma = Parameter(torch.log(initPriorSigmaScale*wPriorSigma*torch.ones(out_channels, int(in_channels/groups), self.kernel_size[0], self.kernel_size[1]))) 155 | 156 | self.noiseSourceWeights = Normal(torch.zeros(out_channels, int(in_channels/groups), self.kernel_size[0], self.kernel_size[1]), 157 | torch.ones(out_channels, int(in_channels/groups), self.kernel_size[0], self.kernel_size[1])) 158 | 159 | self.addLoss(lambda s : 0.5*s.getSampledWeights().pow(2).sum()/wPriorSigma**2) 160 | self.addLoss(lambda s : -self.out_channels/2*np.log(2*np.pi) - 0.5*s.samples['wNoiseState'].pow(2).sum() - s.lweights_sigma.sum()) 161 | 162 | 163 | if self.has_bias : 164 | self.bias_mean = Parameter((0. 
if initBiasMeanZero else 1.)*(torch.rand(out_channels)-0.5)) 165 | self.lbias_sigma = Parameter(torch.log(initPriorSigmaScale*bPriorSigma*torch.ones(out_channels))) 166 | 167 | self.noiseSourceBias = Normal(torch.zeros(out_channels), torch.ones(out_channels)) 168 | 169 | self.addLoss(lambda s : 0.5*s.getSampledBias().pow(2).sum()/bPriorSigma**2) 170 | self.addLoss(lambda s : -self.out_channels/2*np.log(2*np.pi) - 0.5*s.samples['bNoiseState'].pow(2).sum() - self.lbias_sigma.sum()) 171 | 172 | 173 | def sampleTransform(self, stochastic=True) : 174 | self.samples['wNoiseState'] = self.noiseSourceWeights.sample().to(device=self.weights_mean.device) 175 | self.samples['weights'] = self.weights_mean + (torch.exp(self.lweights_sigma)*self.samples['wNoiseState'] if stochastic else 0) 176 | 177 | if self.has_bias : 178 | self.samples['bNoiseState'] = self.noiseSourceBias.sample().to(device=self.bias_mean.device) 179 | self.samples['bias'] = self.bias_mean + (torch.exp(self.lbias_sigma)*self.samples['bNoiseState'] if stochastic else 0) 180 | 181 | def getSampledWeights(self) : 182 | return self.samples['weights'] 183 | 184 | def getSampledBias(self) : 185 | return self.samples['bias'] 186 | 187 | def forward(self, x, stochastic=True) : 188 | 189 | self.sampleTransform(stochastic=stochastic) 190 | 191 | if self.padding != 0 and self.padding != (0,0) : 192 | padkernel = (self.padding, self.padding, self.padding, self.padding) if isinstance(self.padding, int) else (self.padding[1], self.padding[1], self.padding[0], self.padding[0]) 193 | mx = nn.functional.pad(x, padkernel, mode=self.padding_mode, value=0) 194 | else : 195 | mx = x 196 | 197 | return nn.functional.conv2d(mx, 198 | self.samples['weights'], 199 | bias = self.samples['bias'] if self.has_bias else None, 200 | stride= self.stride, 201 | padding=0, 202 | dilation=self.dilation, 203 | groups=self.groups) 204 | 205 | class BayesianMnistNet(VIModule): 206 | def __init__(self, 207 | convWPriorSigma = 1., 208 | convBPriorSigma = 5., 209 | linearWPriorSigma = 1., 210 | linearBPriorSigma = 5., 211 | p_mc_dropout = 0.5) : 212 | 213 | super().__init__() 214 | 215 | self.p_mc_dropout = p_mc_dropout 216 | 217 | self.conv1 = MeanFieldGaussian2DConvolution(1, 16, 218 | wPriorSigma = convWPriorSigma, 219 | bPriorSigma = convBPriorSigma, 220 | kernel_size=5, 221 | initPriorSigmaScale=1e-7) 222 | self.conv2 = MeanFieldGaussian2DConvolution(16, 32, 223 | wPriorSigma = convWPriorSigma, 224 | bPriorSigma = convBPriorSigma, 225 | kernel_size=5, 226 | initPriorSigmaScale=1e-7) 227 | self.linear1 = MeanFieldGaussianFeedForward(512, 128, 228 | weightPriorSigma = linearWPriorSigma, 229 | biasPriorSigma = linearBPriorSigma, 230 | initPriorSigmaScale=1e-7) 231 | self.linear2 = MeanFieldGaussianFeedForward(128, 10, 232 | weightPriorSigma = linearWPriorSigma, 233 | biasPriorSigma = linearBPriorSigma, 234 | initPriorSigmaScale=1e-7) 235 | 236 | def forward(self, x, stochastic=True): 237 | 238 | x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x, stochastic=stochastic), 2)) 239 | x = self.conv2(x, stochastic=stochastic) 240 | 241 | if self.p_mc_dropout is not None : 242 | x = nn.functional.dropout2d(x, p = self.p_mc_dropout, training=stochastic) #MC-Dropout 243 | 244 | x = nn.functional.relu(nn.functional.max_pool2d(x, 2)) 245 | 246 | x = x.view(-1, 512) 247 | 248 | x = nn.functional.relu(self.linear1(x, stochastic=stochastic)) 249 | 250 | if self.p_mc_dropout is not None : 251 | x = nn.functional.dropout(x, p = self.p_mc_dropout, training=stochastic) 
#MC-Dropout 252 | 253 | x = self.linear2(x, stochastic=stochastic) 254 | return nn.functional.log_softmax(x, dim=-1) --------------------------------------------------------------------------------
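
The files above form the complete project. As a complement, here is a minimal usage sketch that is not part of the original repository: it assumes an ensemble has already been trained and saved with `python viExperiment.py --savedir checkpoints` (the `checkpoints` directory name is an assumption), reloads it with `loadModels`, and uses one stochastic forward pass per ensemble member to get a predictive mean and a per-class standard deviation for a single test image. Everything else follows the functions defined in dataset.py, viExperiment.py and viModel.py.

# example_predict.py -- illustrative sketch, not part of the original repository
import torch
from dataset import getSets
from viExperiment import loadModels

# Rebuild the test split used in the experiment (digit 5 held out, the default --filteredclass).
_, test = getSets(filteredClass=5)
image, label = test[0]

# Reload the ensemble saved by viExperiment.py (the "checkpoints" path is an assumption).
models = loadModels("checkpoints")

with torch.no_grad():
    # One stochastic forward pass per ensemble member; the network returns log-probabilities.
    probs = torch.stack([torch.exp(m(image.unsqueeze(0))) for m in models], dim=0)

predictive_mean = probs.mean(dim=0).squeeze()  # Monte Carlo estimate of the class probabilities
predictive_std = probs.std(dim=0).squeeze()    # disagreement between the sampled networks

print("true label:", int(label))
print("predicted digit:", int(predictive_mean.argmax()))
print("per-class std:", predictive_std)

Increasing the number of stochastic passes per network, as the --nruntests option does in viExperiment.py, would give a smoother estimate of both the mean probabilities and their spread.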