├── LICENSE ├── README.md ├── assets ├── figure1.png ├── figure2.png ├── figure3.png ├── figure4.png └── poster.pdf ├── checkpoint ├── bayesian_svi_avu │ └── bayesian_resnet20_cifar.pth ├── bayesian_svi_avuts │ ├── model_with_temperature_avuts.pth │ └── valid_indices.pth └── deterministic │ └── resnet20_cifar.pth ├── models ├── bayesian │ ├── resnet.py │ ├── resnet_large.py │ └── simple_cnn.py └── deterministic │ ├── resnet.py │ ├── resnet_large.py │ └── simple_cnn.py ├── requirements.txt ├── scripts ├── test_bayesian_cifar.sh ├── test_bayesian_cifar_auavu.sh ├── test_bayesian_cifar_avu.sh ├── test_bayesian_imagenet_avu.sh ├── test_deterministic_cifar.sh ├── train_bayesian_cifar.sh ├── train_bayesian_cifar_auavu.sh ├── train_bayesian_cifar_avu.sh ├── train_bayesian_imagenet_avu.sh └── train_deterministic_cifar.sh ├── src ├── avuc_loss.py ├── main_bayesian_cifar.py ├── main_bayesian_cifar_auavu.py ├── main_bayesian_cifar_avu.py ├── main_bayesian_imagenet_avu.py ├── main_deterministic_cifar.py └── util.py └── variational_layers ├── conv_variational.py └── linear_variational.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Intel Labs 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accuracy versus Uncertainty Calibration 2 | 3 | > :warning: **DISCONTINUATION OF PROJECT** - 4 | > *This project will no longer be maintained by Intel. 
5 | > Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.* 6 | > **Intel no longer accepts patches to this project.** 7 | > *If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.* 8 | 9 | 10 | **[Overview](#overview)** 11 | | **[Requirements](#requirements)** 12 | | **[Example usage](#example-usage)** 13 | | **[Results](#results)** 14 | | **[Paper](https://papers.nips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf)** 15 | | **[Citing](#citing)** 16 | 17 | 18 | Code to accompany the paper [Improving model calibration with accuracy versus uncertainty optimization](https://papers.nips.cc/paper/2020/hash/d3d9446802a44259755d38e6d163e820-Abstract.html) [NeurIPS 2020]. 19 | 20 | **Abstract**: Obtaining reliable and accurate quantification of uncertainty estimates from deep neural networks is important in safety critical applications. A well-calibrated model should be accurate when it is certain about its prediction and indicate high uncertainty when it is likely to be inaccurate. Uncertainty calibration is a challenging problem as there is no ground truth available for uncertainty estimates. We propose an optimization method that leverages the relationship between accuracy and uncertainty as an anchor for uncertainty calibration. We introduce a differentiable *accuracy versus uncertainty calibration* (AvUC) loss function as an additional penalty term within loss-calibrated approximate inference framework. AvUC enables a model to learn to provide well-calibrated uncertainties, in addition to improved accuracy. We also demonstrate the same methodology can be extended to post-hoc uncertainty calibration on pretrained models. 
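The AvU measure that the AvUC loss approximates partitions predictions into four groups — accurate-certain (AC), accurate-uncertain (AU), inaccurate-certain (IC), and inaccurate-uncertain (IU) — and is defined as AvU = (n_AC + n_IU) / (n_AC + n_AU + n_IC + n_IU). Below is a minimal NumPy sketch of the (non-differentiable) measure; the helper name and threshold value are illustrative, not the repository's implementation (see `src/avuc_loss.py` for the trainable loss):

```python
import numpy as np

def avu(labels, preds, uncertainty, threshold):
    """Accuracy vs Uncertainty: fraction of predictions that are either
    accurate-and-certain or inaccurate-and-uncertain."""
    accurate = preds == labels
    certain = uncertainty < threshold
    n_ac = np.sum(accurate & certain)     # accurate and certain
    n_au = np.sum(accurate & ~certain)    # accurate but uncertain
    n_ic = np.sum(~accurate & certain)    # inaccurate but certain
    n_iu = np.sum(~accurate & ~certain)   # inaccurate and uncertain
    return (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)

labels = np.array([0, 1, 2, 3])
preds = np.array([0, 1, 2, 0])            # last prediction is wrong
unc = np.array([0.1, 0.2, 0.1, 0.9])      # the wrong prediction is uncertain
print(avu(labels, preds, unc, threshold=0.5))  # -> 1.0 for this toy case
```

A well-calibrated model drives AvU toward 1: it is certain exactly when it is accurate.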
21 | 22 | ## Overview 23 | This repository has code for the *accuracy vs uncertainty calibration* (AvUC) loss and variational layers (convolutional and linear) to perform mean-field stochastic variational inference (SVI) in Bayesian neural networks. It includes implementations of the SVI and SVI-AvUC methods with ResNet-20 (CIFAR10) and ResNet-50 (ImageNet) architectures. 24 | 25 | We propose an optimization method that leverages the relationship between accuracy and uncertainty as an anchor for uncertainty calibration in deep neural network classifiers (Bayesian and non-Bayesian). 26 | We propose a differentiable approximation to the *accuracy vs uncertainty* (AvU) measure [[Mukhoti & Gal 2018](https://arxiv.org/abs/1811.12709)] and introduce a trainable AvUC loss function. 27 | A task-specific utility function is employed in Bayesian decision theory [Berger 1985] to accomplish optimal predictions. 28 | In this work, the AvU utility function is optimized during training to obtain well-calibrated uncertainties along with improved accuracy. 29 | We use the AvUC loss as an additional utility-dependent penalty term to accomplish the task of improving uncertainty calibration, relying on the theoretically sound loss-calibrated 30 | approximate inference framework [[Lacoste-Julien et al. 2011](http://proceedings.mlr.press/v15/lacoste_julien11a.html), [Cobb et al. 2018](https://arxiv.org/abs/1805.03901)] 31 | rooted in Bayesian decision theory. 32 | ## Requirements 33 | This code has been tested on PyTorch v1.6.0 and torchvision v0.7.0 with Python 3.7.7. 34 | 35 | Datasets: 36 | - CIFAR-10 [[Krizhevsky 2009](https://www.cs.toronto.edu/~kriz/cifar.html)] 37 | - ImageNet [[Deng et al.
2009](http://image-net.org/download)] (download dataset to data/imagenet folder) 38 | - CIFAR10 with corruptions [[Hendrycks & Dietterich 2019](https://github.com/hendrycks/robustness)] for dataset shift evaluation (download [CIFAR-10-C](https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1) dataset to 'data/CIFAR-10-C' folder) 39 | - ImageNet with corruptions [[Hendrycks & Dietterich 2019](https://github.com/hendrycks/robustness)] for dataset shift evaluation (download [ImageNet-C](https://zenodo.org/record/2235448#.X6u3NmhKjOg) dataset to 'data/ImageNet-C' folder) 40 | - SVHN [[Netzer et al. 2011](http://ufldl.stanford.edu/housenumbers/)] for out-of-distribution evaluation (download [file](https://zenodo.org/record/4267245/files/svhn-test.npy) to 'data/SVHN' folder) 41 | 42 | Dependencies: 43 | 44 | - Create conda environment with python=3.7 45 | - Install PyTorch and torchvision packages within conda environment following instructions from [PyTorch install guide](https://pytorch.org/get-started/locally/) 46 | - conda install -c conda-forge accimage 47 | - pip install tensorboard 48 | - pip install scikit-learn 49 | 50 | ## Example usage 51 | We provide [example usages](src/) for SVI-AvUC, SVI and Vanilla (deterministic) methods on CIFAR-10 and ImageNet to train/evaluate the [models](models) along with the implementation of [AvUC loss](src/avuc_loss.py) and [Bayesian layers](variational_layers/). 52 | 53 | ### Training 54 | 55 | To train the Bayesian ResNet-20 model on CIFAR10 with SVI-AvUC method, run this script: 56 | ```train_svi_avuc 57 | sh scripts/train_bayesian_cifar_avu.sh 58 | ``` 59 | 60 | 61 | To train the Bayesian ResNet-50 model on ImageNet with SVI-AvUC method, run this script: 62 | ```train_imagenet_svi_avuc 63 | sh scripts/train_bayesian_imagenet_avu.sh 64 | ``` 65 | ### Evaluation 66 | 67 | Our trained models can be downloaded from [here](https://zenodo.org/record/4267245#.X6uNBWhKiUm). 
Download and untar the [SVI-AvUC ImageNet model](https://zenodo.org/record/4267245/files/svi-avuc-imagenet-trained-model.tar.gz?download=1) to the 'checkpoint/imagenet/bayesian_svi_avu/' folder. 68 | 69 | 70 | To evaluate SVI-AvUC on CIFAR10, CIFAR10-C and SVHN, run the script below. Results (numpy files) will be saved in the logs/cifar/bayesian_svi_avu/preds folder. 71 | 72 | ```eval_cifar 73 | sh scripts/test_bayesian_cifar_avu.sh 74 | ``` 75 | 76 | 77 | To evaluate SVI-AvUC on ImageNet and ImageNet-C, run the script below. Results (numpy files) will be saved in the logs/imagenet/bayesian_svi_avu/preds folder. 78 | 79 | ```eval_imagenet 80 | sh scripts/test_bayesian_imagenet_avu.sh 81 | ``` 82 | 83 | ## Results 84 | 85 | 86 | **Model calibration under dataset shift** 87 | 88 | The figure below shows model calibration evaluation with Expected calibration error (ECE↓) and Expected uncertainty calibration error (UCE↓) on ImageNet under dataset shift. At each shift intensity level, the boxplot summarizes the results across 16 different dataset shift types. A well-calibrated model should 89 | provide lower calibration errors even at increased dataset shift. 90 | 91 |
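ECE bins test predictions by confidence and averages the per-bin gap between accuracy and mean confidence, weighted by bin occupancy; UCE is the analogous quantity with predictive uncertainty in place of confidence and error rate in place of accuracy. A minimal sketch of ECE — the binning and function name are illustrative, not the repository's evaluation code:

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: occupancy-weighted average of |accuracy - confidence| over bins."""
    ece = 0.0
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = correct[in_bin].mean()    # accuracy within the bin
            conf = confidences[in_bin].mean()  # mean confidence within the bin
            ece += in_bin.mean() * abs(acc - conf)
    return ece

conf = np.array([0.95, 0.9, 0.6, 0.55])
correct = np.array([1.0, 1.0, 0.0, 1.0])
print(round(expected_calibration_error(conf, correct), 6))
```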

92 | 93 |

94 | 95 | **Model performance with respect to confidence and uncertainty estimates** 96 | 97 | A reliable model should be accurate when it is certain about its prediction and indicate high uncertainty when it is likely to be inaccurate. We evaluate the quality of confidence and predictive uncertainty estimates using accuracy vs confidence and p(uncertain | inaccurate) plots, respectively. 98 | The plots below indicate that SVI-AvUC is more accurate at higher confidence and more uncertain when making inaccurate predictions under distributional shift (ImageNet corrupted with Gaussian blur), compared to other methods. 99 |
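p(uncertain | inaccurate) is the fraction of misclassified examples whose predictive uncertainty exceeds a threshold. A small illustrative sketch (function name and threshold are hypothetical, not the repository's code):

```python
import numpy as np

def p_uncertain_given_inaccurate(preds, labels, uncertainty, threshold):
    """Fraction of misclassified examples flagged as uncertain."""
    inaccurate = preds != labels
    if not inaccurate.any():
        return float('nan')  # undefined when there are no errors
    return (uncertainty[inaccurate] >= threshold).mean()

preds = np.array([0, 1, 2, 2])
labels = np.array([0, 1, 0, 1])         # last two predictions are wrong
unc = np.array([0.1, 0.2, 0.8, 0.3])    # only one error carries high uncertainty
print(p_uncertain_given_inaccurate(preds, labels, unc, 0.5))  # -> 0.5
```

Sweeping the threshold over the range of observed uncertainties produces the curves shown in the figure.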

100 | 101 |

102 | 103 | **Model reliability towards out-of-distribution data** 104 | 105 | Out-of-distribution evaluation with SVHN data on the model trained with CIFAR10. SVI-AvUC assigns high confidence to fewer examples and provides higher predictive uncertainty estimates on out-of-distribution data. 106 |
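The predictive uncertainty behind these histograms can be obtained from Monte Carlo samples of the variational posterior: average the softmax outputs over several stochastic forward passes and take the entropy of the mean. A sketch with fixed probability vectors standing in for network outputs (not the repository's implementation):

```python
import numpy as np

def predictive_entropy(mc_softmax):
    """Entropy of the mean predictive distribution over T MC samples.

    mc_softmax: array of shape (T, num_classes) holding per-sample softmax outputs.
    """
    p = mc_softmax.mean(axis=0)               # average over MC samples
    return -np.sum(p * np.log(p + 1e-12))     # small epsilon avoids log(0)

# Confident: every posterior sample puts its mass on class 0 -> low entropy.
confident = np.tile([0.99, 0.005, 0.005], (5, 1))
# Uncertain: samples disagree -> the mean is diffuse, entropy is high.
uncertain = np.array([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]])
print(predictive_entropy(confident) < predictive_entropy(uncertain))  # True
```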

107 | 108 |

109 | 110 | **Distributional shift detection performance** 111 | 112 | Distributional shift detection is performed using predictive uncertainty estimates. 113 | For dataset shift detection on ImageNet, test data corrupted with Gaussian blur of intensity level 5 is used. 114 | SVHN is used as out-of-distribution (OOD) data for OOD detection on the model trained with CIFAR10. 115 | 116 |
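Detection quality is typically summarized by AUROC, treating predictive uncertainty as the score that separates shifted/OOD inputs from clean test inputs. A rank-based AUROC sketch (illustrative, not the repository's evaluation code):

```python
import numpy as np

def auroc(scores_in, scores_out):
    """AUROC of separating OOD (positive) from in-distribution inputs by score.

    Equals the probability that a random OOD score exceeds a random
    in-distribution score, counting ties as one half.
    """
    total = 0.0
    for s_out in scores_out:
        total += np.sum(s_out > scores_in) + 0.5 * np.sum(s_out == scores_in)
    return total / (len(scores_in) * len(scores_out))

unc_in = np.array([0.1, 0.2, 0.3])    # uncertainty on clean test data
unc_ood = np.array([0.4, 0.9, 0.25])  # uncertainty on shifted / OOD data
print(auroc(unc_in, unc_ood))
```

An AUROC of 1.0 means uncertainty perfectly separates the two sets; 0.5 is chance level.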

117 | 118 |

119 | 120 | Please refer to the paper for more results. The results for Vanilla, Temp scaling, Ensemble and Dropout methods are computed from the model predictions provided in [UQ benchmark](https://console.cloud.google.com/storage/browser/uq-benchmark-2019) [[Ovadia et al. 2019](https://arxiv.org/abs/1906.02530)] 121 | 122 | ## Citing 123 | 124 | If you find this useful in your research, please cite as: 125 | ```sh 126 | @article{krishnan2020improving, 127 | author={Ranganath Krishnan and Omesh Tickoo}, 128 | title={Improving model calibration with accuracy versus uncertainty optimization}, 129 | journal={Advances in Neural Information Processing Systems}, 130 | year={2020} 131 | } 132 | ``` 133 | -------------------------------------------------------------------------------- /assets/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure1.png -------------------------------------------------------------------------------- /assets/figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure2.png -------------------------------------------------------------------------------- /assets/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure3.png -------------------------------------------------------------------------------- /assets/figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure4.png -------------------------------------------------------------------------------- /assets/poster.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/poster.pdf -------------------------------------------------------------------------------- /checkpoint/bayesian_svi_avu/bayesian_resnet20_cifar.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/bayesian_svi_avu/bayesian_resnet20_cifar.pth -------------------------------------------------------------------------------- /checkpoint/bayesian_svi_avuts/model_with_temperature_avuts.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/bayesian_svi_avuts/model_with_temperature_avuts.pth -------------------------------------------------------------------------------- /checkpoint/bayesian_svi_avuts/valid_indices.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/bayesian_svi_avuts/valid_indices.pth -------------------------------------------------------------------------------- /checkpoint/deterministic/resnet20_cifar.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/deterministic/resnet20_cifar.pth -------------------------------------------------------------------------------- /models/bayesian/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Bayesian ResNet for CIFAR10. 
3 | 4 | ResNet architecture ref: 5 | https://arxiv.org/abs/1512.03385 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.init as init 12 | from variational_layers.conv_variational import Conv2dVariational 13 | from variational_layers.linear_variational import LinearVariational 14 | 15 | __all__ = [ 16 | 'ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110' 17 | ] 18 | 19 | prior_mu = 0.0 20 | prior_sigma = 1.0 21 | posterior_mu_init = 0.0 22 | posterior_rho_init = -2.0 23 | 24 | 25 | def _weights_init(m): 26 | classname = m.__class__.__name__ 27 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 28 | init.kaiming_normal_(m.weight) 29 | 30 | 31 | class LambdaLayer(nn.Module): 32 | def __init__(self, lambd): 33 | super(LambdaLayer, self).__init__() 34 | self.lambd = lambd 35 | 36 | def forward(self, x): 37 | return self.lambd(x) 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, in_planes, planes, stride=1, option='A'): 44 | super(BasicBlock, self).__init__() 45 | self.conv1 = Conv2dVariational(prior_mu, 46 | prior_sigma, 47 | posterior_mu_init, 48 | posterior_rho_init, 49 | in_planes, 50 | planes, 51 | kernel_size=3, 52 | stride=stride, 53 | padding=1, 54 | bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = Conv2dVariational(prior_mu, 57 | prior_sigma, 58 | posterior_mu_init, 59 | posterior_rho_init, 60 | planes, 61 | planes, 62 | kernel_size=3, 63 | stride=1, 64 | padding=1, 65 | bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | 68 | self.shortcut = nn.Sequential() 69 | if stride != 1 or in_planes != planes: 70 | if option == 'A': 71 | """ 72 | For CIFAR10 ResNet paper uses option A. 
73 | """ 74 | self.shortcut = LambdaLayer(lambda x: F.pad( 75 | x[:, :, ::2, ::2], 76 | (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)) 77 | elif option == 'B': 78 | self.shortcut = nn.Sequential( 79 | Conv2dVariational(prior_mu, 80 | prior_sigma, 81 | posterior_mu_init, 82 | posterior_rho_init, 83 | in_planes, 84 | self.expansion * planes, 85 | kernel_size=1, 86 | stride=stride, 87 | bias=False), 88 | nn.BatchNorm2d(self.expansion * planes)) 89 | 90 | def forward(self, x): 91 | kl_sum = 0 92 | out, kl = self.conv1(x) 93 | kl_sum += kl 94 | out = self.bn1(out) 95 | out = F.relu(out) 96 | out, kl = self.conv2(out) 97 | kl_sum += kl 98 | out = self.bn2(out) 99 | out += self.shortcut(x) 100 | out = F.relu(out) 101 | return out, kl_sum 102 | 103 | 104 | class ResNet(nn.Module): 105 | def __init__(self, block, num_blocks, num_classes=10): 106 | super(ResNet, self).__init__() 107 | self.in_planes = 16 108 | 109 | self.conv1 = Conv2dVariational(prior_mu, 110 | prior_sigma, 111 | posterior_mu_init, 112 | posterior_rho_init, 113 | 3, 114 | 16, 115 | kernel_size=3, 116 | stride=1, 117 | padding=1, 118 | bias=False) 119 | self.bn1 = nn.BatchNorm2d(16) 120 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 121 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 122 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 123 | self.linear = LinearVariational(prior_mu, prior_sigma, 124 | posterior_mu_init, posterior_rho_init, 125 | 64, num_classes) 126 | 127 | self.apply(_weights_init) 128 | 129 | def _make_layer(self, block, planes, num_blocks, stride): 130 | strides = [stride] + [1] * (num_blocks - 1) 131 | layers = [] 132 | for stride in strides: 133 | layers.append(block(self.in_planes, planes, stride)) 134 | self.in_planes = planes * block.expansion 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | kl_sum = 0 140 | out, kl = self.conv1(x) 141 | kl_sum += kl 142 | out = self.bn1(out) 143 
| out = F.relu(out) 144 | for l in self.layer1: 145 | out, kl = l(out) 146 | kl_sum += kl 147 | for l in self.layer2: 148 | out, kl = l(out) 149 | kl_sum += kl 150 | for l in self.layer3: 151 | out, kl = l(out) 152 | kl_sum += kl 153 | 154 | out = F.avg_pool2d(out, out.size()[3]) 155 | out = out.view(out.size(0), -1) 156 | out, kl = self.linear(out) 157 | kl_sum += kl 158 | return out, kl_sum 159 | 160 | 161 | def resnet20(): 162 | return ResNet(BasicBlock, [3, 3, 3]) 163 | 164 | 165 | def resnet32(): 166 | return ResNet(BasicBlock, [5, 5, 5]) 167 | 168 | 169 | def resnet44(): 170 | return ResNet(BasicBlock, [7, 7, 7]) 171 | 172 | 173 | def resnet56(): 174 | return ResNet(BasicBlock, [9, 9, 9]) 175 | 176 | 177 | def resnet110(): 178 | return ResNet(BasicBlock, [18, 18, 18]) 179 | 180 | 181 | def test(net): 182 | import numpy as np 183 | total_params = 0 184 | 185 | for x in filter(lambda p: p.requires_grad, net.parameters()): 186 | total_params += np.prod(x.data.numpy().shape) 187 | print("Total number of params", total_params) 188 | print( 189 | "Total layers", 190 | len( 191 | list( 192 | filter(lambda p: p.requires_grad and len(p.data.size()) > 1, 193 | net.parameters())))) 194 | 195 | 196 | if __name__ == "__main__": 197 | for net_name in __all__: 198 | if net_name.startswith('resnet'): 199 | print(net_name) 200 | test(globals()[net_name]()) 201 | print() 202 | -------------------------------------------------------------------------------- /models/bayesian/resnet_large.py: -------------------------------------------------------------------------------- 1 | # Bayesian ResNet for ImageNet 2 | # ResNet architecture ref: 3 | # https://arxiv.org/abs/1512.03385 4 | # Code adapted from torchvision package to build Bayesian model from deterministic model 5 | 6 | import torch.nn as nn 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.init as init 12 | from 
variational_layers.conv_variational import Conv2dVariational 13 | from variational_layers.linear_variational import LinearVariational 14 | 15 | __all__ = [ 16 | 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152' 17 | ] 18 | 19 | prior_mu = 0.0 20 | prior_sigma = 1.0 21 | posterior_mu_init = 0.0 22 | posterior_rho_init = -9.0 23 | 24 | model_urls = { 25 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 26 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 27 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 28 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 29 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 30 | } 31 | 32 | 33 | def conv3x3(in_planes, out_planes, stride=1): 34 | """3x3 convolution with padding""" 35 | return Conv2dVariational(prior_mu, 36 | prior_sigma, 37 | posterior_mu_init, 38 | posterior_rho_init, 39 | in_planes, 40 | out_planes, 41 | kernel_size=3, 42 | stride=stride, 43 | padding=1, 44 | bias=False) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | kl_sum = 0 63 | out, kl = self.conv1(x) 64 | kl_sum += kl 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out, kl = self.conv2(out) 69 | kl_sum += kl 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out, kl_sum 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 
83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None): 85 | super(Bottleneck, self).__init__() 86 | self.conv1 = Conv2dVariational(prior_mu, 87 | prior_sigma, 88 | posterior_mu_init, 89 | posterior_rho_init, 90 | inplanes, 91 | planes, 92 | kernel_size=1, 93 | bias=False) 94 | self.bn1 = nn.BatchNorm2d(planes) 95 | self.conv2 = Conv2dVariational(prior_mu, 96 | prior_sigma, 97 | posterior_mu_init, 98 | posterior_rho_init, 99 | planes, 100 | planes, 101 | kernel_size=3, 102 | stride=stride, 103 | padding=1, 104 | bias=False) 105 | self.bn2 = nn.BatchNorm2d(planes) 106 | self.conv3 = Conv2dVariational(prior_mu, 107 | prior_sigma, 108 | posterior_mu_init, 109 | posterior_rho_init, 110 | planes, 111 | planes * 4, 112 | kernel_size=1, 113 | bias=False) 114 | self.bn3 = nn.BatchNorm2d(planes * 4) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.downsample = downsample 117 | self.stride = stride 118 | 119 | def forward(self, x): 120 | residual = x 121 | kl_sum = 0 122 | out, kl = self.conv1(x) 123 | kl_sum += kl 124 | out = self.bn1(out) 125 | out = self.relu(out) 126 | 127 | out, kl = self.conv2(out) 128 | kl_sum += kl 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out, kl = self.conv3(out) 133 | kl_sum += kl 134 | out = self.bn3(out) 135 | 136 | if self.downsample is not None: 137 | residual = self.downsample(x) 138 | 139 | out += residual 140 | out = self.relu(out) 141 | 142 | return out, kl_sum 143 | 144 | 145 | class ResNet(nn.Module): 146 | def __init__(self, block, layers, num_classes=1000): 147 | self.inplanes = 64 148 | super(ResNet, self).__init__() 149 | self.conv1 = Conv2dVariational(prior_mu, 150 | prior_sigma, 151 | posterior_mu_init, 152 | posterior_rho_init, 153 | 3, 154 | 64, 155 | kernel_size=7, 156 | stride=2, 157 | padding=3, 158 | bias=False) 159 | self.bn1 = nn.BatchNorm2d(64) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 162 | self.layer1 = 
self._make_layer(block, 64, layers[0]) 163 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 164 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 165 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 166 | self.avgpool = nn.AvgPool2d(7, stride=1) 167 | self.fc = LinearVariational(prior_mu, prior_sigma, posterior_mu_init, 168 | posterior_rho_init, 512 * block.expansion, 169 | num_classes) 170 | 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. / n)) 175 | elif isinstance(m, nn.BatchNorm2d): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | 179 | def _make_layer(self, block, planes, blocks, stride=1): 180 | downsample = None 181 | if stride != 1 or self.inplanes != planes * block.expansion: 182 | downsample = nn.Sequential( 183 | nn.Conv2d(self.inplanes, 184 | planes * block.expansion, 185 | kernel_size=1, 186 | stride=stride, 187 | bias=False), 188 | nn.BatchNorm2d(planes * block.expansion), 189 | ) 190 | 191 | layers = [] 192 | layers.append(block(self.inplanes, planes, stride, downsample)) 193 | self.inplanes = planes * block.expansion 194 | for i in range(1, blocks): 195 | layers.append(block(self.inplanes, planes)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | kl_sum = 0 201 | x, kl = self.conv1(x) 202 | kl_sum += kl 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | for layer in self.layer1: 208 | if 'Variational' in str(layer): 209 | x, kl = layer(x) 210 | if kl is not None: 211 | kl_sum += kl 212 | else: 213 | x = layer(x) 214 | 215 | for layer in self.layer2: 216 | if 'Variational' in str(layer): 217 | x, kl = layer(x) 218 | if kl is not None: 219 | kl_sum += kl 220 | else: 221 | x = layer(x) 222 | 223 | for layer in self.layer3: 224 | if 'Variational' in str(layer): 225 | x, kl = layer(x) 226 | if kl is not None: 227 |
kl_sum += kl 228 | else: 229 | x = layer(x) 230 | 231 | for layer in self.layer4: 232 | if 'Variational' in str(layer): 233 | x, kl = layer(x) 234 | if kl is not None: 235 | kl_sum += kl 236 | else: 237 | x = layer(x) 238 | 239 | x = self.avgpool(x) 240 | x = x.view(x.size(0), -1) 241 | x, kl = self.fc(x) 242 | kl_sum += kl 243 | 244 | return x, kl_sum 245 | 246 | 247 | def resnet18(pretrained=False, **kwargs): 248 | """Constructs a ResNet-18 model. 249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | """ 253 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 254 | if pretrained: 255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 256 | return model 257 | 258 | 259 | def resnet34(pretrained=False, **kwargs): 260 | """Constructs a ResNet-34 model. 261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | """ 265 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 266 | if pretrained: 267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 268 | return model 269 | 270 | 271 | def resnet50(pretrained=False, **kwargs): 272 | """Constructs a ResNet-50 model. 273 | 274 | Args: 275 | pretrained (bool): If True, returns a model pre-trained on ImageNet 276 | """ 277 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 278 | if pretrained: 279 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 280 | return model 281 | 282 | 283 | def resnet101(pretrained=False, **kwargs): 284 | """Constructs a ResNet-101 model. 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | """ 289 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 290 | if pretrained: 291 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 292 | return model 293 | 294 | 295 | def resnet152(pretrained=False, **kwargs): 296 | """Constructs a ResNet-152 model.
297 | 298 | Args: 299 | pretrained (bool): If True, returns a model pre-trained on ImageNet 300 | """ 301 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 304 | return model 305 | -------------------------------------------------------------------------------- /models/bayesian/simple_cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from variational_layers.conv_variational import Conv2dVariational 8 | from variational_layers.linear_variational import LinearVariational 9 | 10 | prior_mu = 0.0 11 | prior_sigma = 1.0 12 | posterior_mu_init = 0.0 13 | posterior_rho_init = -3.0 14 | 15 | 16 | class SCNN(nn.Module): 17 | def __init__(self): 18 | super(SCNN, self).__init__() 19 | self.conv1 = Conv2dVariational(prior_mu, prior_sigma, 20 | posterior_mu_init, posterior_rho_init, 21 | 1, 32, 3, 1) 22 | self.conv2 = Conv2dVariational(prior_mu, prior_sigma, 23 | posterior_mu_init, posterior_rho_init, 24 | 32, 64, 3, 1) 25 | self.dropout1 = nn.Dropout2d(0.1) 26 | self.dropout2 = nn.Dropout2d(0.1) 27 | self.fc1 = LinearVariational(prior_mu, prior_sigma, posterior_mu_init, 28 | posterior_rho_init, 9216, 128) 29 | self.fc2 = LinearVariational(prior_mu, prior_sigma, posterior_mu_init, 30 | posterior_rho_init, 128, 10) 31 | 32 | def forward(self, x): 33 | kl_sum = 0 34 | x, kl = self.conv1(x) 35 | kl_sum += kl 36 | x = F.relu(x) 37 | x, kl = self.conv2(x) 38 | kl_sum += kl 39 | x = F.relu(x) 40 | x = F.max_pool2d(x, 2) 41 | x = self.dropout1(x) 42 | x = torch.flatten(x, 1) 43 | x, kl = self.fc1(x) 44 | kl_sum += kl 45 | x = F.relu(x) 46 | x = self.dropout2(x) 47 | x, kl = self.fc2(x) 48 | kl_sum += kl 49 | output = F.log_softmax(x, dim=1) 50 | return output, kl_sum 51 |
-------------------------------------------------------------------------------- /models/deterministic/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ResNet for CIFAR10. 3 | 4 | Ref for ResNet architecture: 5 | https://arxiv.org/abs/1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.init as init 11 | 12 | __all__ = [ 13 | 'ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 14 | 'resnet1202' 15 | ] 16 | 17 | 18 | def _weights_init(m): 19 | classname = m.__class__.__name__ 20 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 21 | init.kaiming_normal_(m.weight) 22 | 23 | 24 | class LambdaLayer(nn.Module): 25 | def __init__(self, lambd): 26 | super(LambdaLayer, self).__init__() 27 | self.lambd = lambd 28 | 29 | def forward(self, x): 30 | return self.lambd(x) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, in_planes, planes, stride=1, option='A'): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = nn.Conv2d(in_planes, 39 | planes, 40 | kernel_size=3, 41 | stride=stride, 42 | padding=1, 43 | bias=False) 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.conv2 = nn.Conv2d(planes, 46 | planes, 47 | kernel_size=3, 48 | stride=1, 49 | padding=1, 50 | bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != planes: 55 | if option == 'A': 56 | self.shortcut = LambdaLayer(lambda x: F.pad( 57 | x[:, :, ::2, ::2], 58 | (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)) 59 | elif option == 'B': 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, 62 | self.expansion * planes, 63 | kernel_size=1, 64 | stride=stride, 65 | bias=False), 66 | nn.BatchNorm2d(self.expansion * planes)) 67 | 68 | def forward(self, x): 69 | out = F.relu(self.bn1(self.conv1(x))) 70 | out = self.bn2(self.conv2(out)) 71 | out 
+= self.shortcut(x) 72 | out = F.relu(out) 73 | return out 74 | 75 | 76 | class ResNet(nn.Module): 77 | def __init__(self, block, num_blocks, num_classes=10): 78 | super(ResNet, self).__init__() 79 | self.in_planes = 16 80 | 81 | self.conv1 = nn.Conv2d(3, 82 | 16, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | bias=False) 87 | self.bn1 = nn.BatchNorm2d(16) 88 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 89 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 90 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 91 | self.linear = nn.Linear(64, num_classes) 92 | 93 | self.apply(_weights_init) 94 | 95 | def _make_layer(self, block, planes, num_blocks, stride): 96 | strides = [stride] + [1] * (num_blocks - 1) 97 | layers = [] 98 | for stride in strides: 99 | layers.append(block(self.in_planes, planes, stride)) 100 | self.in_planes = planes * block.expansion 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def forward(self, x): 105 | out = F.relu(self.bn1(self.conv1(x))) 106 | out = self.layer1(out) 107 | out = self.layer2(out) 108 | out = self.layer3(out) 109 | out = F.avg_pool2d(out, out.size()[3]) 110 | out = out.view(out.size(0), -1) 111 | out = self.linear(out) 112 | return out 113 | 114 | 115 | def resnet20(): 116 | return ResNet(BasicBlock, [3, 3, 3]) 117 | 118 | 119 | def resnet32(): 120 | return ResNet(BasicBlock, [5, 5, 5]) 121 | 122 | 123 | def resnet44(): 124 | return ResNet(BasicBlock, [7, 7, 7]) 125 | 126 | 127 | def resnet56(): 128 | return ResNet(BasicBlock, [9, 9, 9]) 129 | 130 | 131 | def resnet110(): 132 | return ResNet(BasicBlock, [18, 18, 18]) 133 | 134 | 135 | def test(net): 136 | import numpy as np 137 | total_params = 0 138 | 139 | for x in filter(lambda p: p.requires_grad, net.parameters()): 140 | total_params += np.prod(x.data.numpy().shape) 141 | print("Total number of params", total_params) 142 | print( 143 | "Total layers", 144 | len( 145 | list( 146 | filter(lambda p: 
p.requires_grad and len(p.data.size()) > 1, 147 | net.parameters())))) 148 | 149 | 150 | if __name__ == "__main__": 151 | for net_name in __all__: 152 | if net_name.startswith('resnet'): 153 | print(net_name) 154 | test(globals()[net_name]()) 155 | print() 156 | -------------------------------------------------------------------------------- /models/deterministic/resnet_large.py: -------------------------------------------------------------------------------- 1 | # deterministic model from torchvision package 2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | 4 | import torch.nn as nn 5 | import math 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152' 10 | ] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, 24 | out_planes, 25 | kernel_size=3, 26 | stride=stride, 27 | padding=1, 28 | bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None): 35 | super(BasicBlock, self).__init__() 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | residual = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | 
out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | expansion = 4 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None): 67 | super(Bottleneck, self).__init__() 68 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = nn.Conv2d(planes, 71 | planes, 72 | kernel_size=3, 73 | stride=stride, 74 | padding=1, 75 | bias=False) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(planes * 4) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | def __init__(self, block, layers, num_classes=1000): 108 | self.inplanes = 64 109 | super(ResNet, self).__init__() 110 | self.conv1 = nn.Conv2d(3, 111 | 64, 112 | kernel_size=7, 113 | stride=2, 114 | padding=3, 115 | bias=False) 116 | self.bn1 = nn.BatchNorm2d(64) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | self.layer1 = self._make_layer(block, 64, layers[0]) 120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 122 | self.layer4 = self._make_layer(block, 512, 
layers[3], stride=2) 123 | self.avgpool = nn.AvgPool2d(7, stride=1) 124 | self.fc = nn.Linear(512 * block.expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. / n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, 139 | planes * block.expansion, 140 | kernel_size=1, 141 | stride=stride, 142 | bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = [] 147 | layers.append(block(self.inplanes, planes, stride, downsample)) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes)) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def forward(self, x): 155 | x = self.conv1(x) 156 | x = self.bn1(x) 157 | x = self.relu(x) 158 | x = self.maxpool(x) 159 | 160 | x = self.layer1(x) 161 | x = self.layer2(x) 162 | x = self.layer3(x) 163 | x = self.layer4(x) 164 | 165 | x = self.avgpool(x) 166 | x = x.view(x.size(0), -1) 167 | x = self.fc(x) 168 | 169 | return x 170 | 171 | 172 | def resnet18(pretrained=False, **kwargs): 173 | """Constructs a ResNet-18 model. 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 179 | if pretrained: 180 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 181 | return model 182 | 183 | 184 | def resnet34(pretrained=False, **kwargs): 185 | """Constructs a ResNet-34 model. 
186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 191 | if pretrained: 192 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 193 | return model 194 | 195 | 196 | def resnet50(pretrained=False, **kwargs): 197 | """Constructs a ResNet-50 model. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 203 | if pretrained: 204 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 205 | return model 206 | 207 | 208 | def resnet101(pretrained=False, **kwargs): 209 | """Constructs a ResNet-101 model. 210 | 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 215 | if pretrained: 216 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 217 | return model 218 | 219 | 220 | def resnet152(pretrained=False, **kwargs): 221 | """Constructs a ResNet-152 model. 
222 | 223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | """ 226 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 227 | if pretrained: 228 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 229 | return model 230 | -------------------------------------------------------------------------------- /models/deterministic/simple_cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SCNN(nn.Module): 9 | def __init__(self): 10 | super(SCNN, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 12 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 13 | self.dropout1 = nn.Dropout2d(0.23) 14 | self.dropout2 = nn.Dropout2d(0.23) 15 | self.fc1 = nn.Linear(9216, 128) 16 | self.fc2 = nn.Linear(128, 10) 17 | 18 | def forward(self, x): 19 | x = self.conv1(x) 20 | x = F.relu(x) 21 | x = self.conv2(x) 22 | x = F.relu(x) 23 | x = F.max_pool2d(x, 2) 24 | x = self.dropout1(x) 25 | x = torch.flatten(x, 1) 26 | x = self.fc1(x) 27 | x = F.relu(x) 28 | x = self.dropout2(x) 29 | x = self.fc2(x) 30 | output = F.log_softmax(x, dim=1) 31 | return output 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | scikit-learn 4 | tensorboard 5 | -------------------------------------------------------------------------------- /scripts/test_bayesian_cifar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='test' 4 | batch_size=10000 5 | num_monte_carlo=128 6 | 7 | python src/main_bayesian_cifar.py --arch=$model --mode=$mode --batch-size=$batch_size --num_monte_carlo=$num_monte_carlo 8 | 
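The evaluation scripts run `num_monte_carlo` stochastic forward passes and reduce the resulting samples to uncertainty estimates via `predictive_entropy` and `mutual_information` in `src/avuc_loss.py`. A self-contained NumPy sketch of those two measures (the toy probabilities below are illustrative assumptions):

```python
import numpy as np

def entropy(p, eps=1e-15):
    # Shannon entropy (in nats) along the class axis
    return -np.sum(p * np.log(p + eps), axis=-1)

def predictive_entropy(mc_preds):
    # total uncertainty: entropy of the mean predictive distribution
    return entropy(np.mean(mc_preds, axis=0))

def mutual_information(mc_preds):
    # model uncertainty: predictive entropy minus expected per-sample entropy
    return predictive_entropy(mc_preds) - np.mean(entropy(mc_preds), axis=0)

# Two Monte Carlo samples that disagree for one example with two classes:
# the mean distribution is [0.5, 0.5], so predictive entropy is ln 2,
# and the disagreement shows up as high mutual information.
mc = np.array([[[0.9, 0.1]],
               [[0.1, 0.9]]])  # shape: (num_mc, num_examples, num_classes)
pe = predictive_entropy(mc)
mi = mutual_information(mc)
```

Mutual information is always bounded above by predictive entropy; it isolates the disagreement between Monte Carlo samples from the noise each sample already contains.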
-------------------------------------------------------------------------------- /scripts/test_bayesian_cifar_auavu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='test' 4 | batch_size=10000 5 | num_monte_carlo=128 6 | 7 | python src/main_bayesian_cifar_auavu.py --arch=$model --mode=$mode --batch-size=$batch_size --num_monte_carlo=$num_monte_carlo 8 | -------------------------------------------------------------------------------- /scripts/test_bayesian_cifar_avu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='test' 4 | batch_size=10000 5 | num_monte_carlo=128 6 | 7 | python src/main_bayesian_cifar_avu.py --arch=$model --mode=$mode --batch-size=$batch_size --num_monte_carlo=$num_monte_carlo 8 | -------------------------------------------------------------------------------- /scripts/test_bayesian_imagenet_avu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet50 3 | mode='test' 4 | val_batch_size=2500 5 | num_monte_carlo=128 6 | 7 | python src/main_bayesian_imagenet_avu.py data/imagenet --arch=$model --mode=$mode --val_batch_size=$val_batch_size --num_monte_carlo=$num_monte_carlo 8 | -------------------------------------------------------------------------------- /scripts/test_deterministic_cifar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='test' 4 | batch_size=10000 5 | 6 | python src/main_deterministic_cifar.py --arch=$model --mode=$mode --batch-size=$batch_size 7 | -------------------------------------------------------------------------------- /scripts/train_bayesian_cifar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='train' 4 | batch_size=107 5 | lr=0.001189 6 | 7 | 
python src/main_bayesian_cifar.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size 8 | -------------------------------------------------------------------------------- /scripts/train_bayesian_cifar_auavu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='train' 4 | batch_size=107 5 | lr=0.001189 6 | moped=False 7 | 8 | python src/main_bayesian_cifar_auavu.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size --moped=$moped 9 | -------------------------------------------------------------------------------- /scripts/train_bayesian_cifar_avu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='train' 4 | batch_size=107 5 | lr=0.001189 6 | moped=False 7 | 8 | python src/main_bayesian_cifar_avu.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size --moped=$moped 9 | -------------------------------------------------------------------------------- /scripts/train_bayesian_imagenet_avu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet50 3 | mode='train' 4 | batch_size=384 5 | lr=0.001 6 | moped=True 7 | 8 | python -u src/main_bayesian_imagenet_avu.py data/imagenet --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size --moped=$moped 9 | -------------------------------------------------------------------------------- /scripts/train_deterministic_cifar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet20 3 | mode='train' 4 | batch_size=512 5 | lr=0.001 6 | 7 | python src/main_deterministic_cifar.py --arch=$model --mode=$mode --batch-size=$batch_size 8 | -------------------------------------------------------------------------------- /src/avuc_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 
2020 Intel Corporation 2 | # 3 | # BSD-3-Clause License 4 | # 5 | # Redistribution and use in source and binary forms, with or without modification, 6 | # are permitted provided that the following conditions are met: 7 | # 1. Redistributions of source code must retain the above copyright notice, 8 | # this list of conditions and the following disclaimer. 9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 3. Neither the name of the copyright holder nor the names of its contributors 13 | # may be used to endorse or promote products derived from this software 14 | # without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT 22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | # 28 | # AvULoss -> compute accuracy versus uncertainty calibration loss 29 | # AUAvULoss -> compute accuracy versus uncertainty calibration loss 30 | # without uncertainty threshold 31 | # accuracy_vs_uncertainty -> compute AvU metric 32 | # eval_avu -> get AvU scores at different uncertainty thresholds 33 | # predictive_entropy -> compute predictive uncertainty of the model 34 | # mutual_information -> compute model uncertainty of the model 35 | # 36 | # @author: Ranganath Krishnan 37 | # 38 | # =============================================================================================== 39 | 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | import torch.nn.functional as F 44 | import torch 45 | from torch import nn 46 | import numpy as np 47 | from sklearn.metrics import auc 48 | 49 | 50 | class AUAvULoss(nn.Module): 51 | """ 52 | Calculates Accuracy vs Uncertainty Loss of a model. 53 | The input to this loss is logits from Monte Carlo sampling of the model, true labels, 54 | and the type of uncertainty to be used [0: predictive uncertainty (default); 55 | 1: model uncertainty] 56 | """ 57 | def __init__(self, beta=1): 58 | super(AUAvULoss, self).__init__() 59 | self.beta = beta 60 | self.eps = 1e-10 61 | 62 | def entropy(self, prob): 63 | return -1 * torch.sum(prob * torch.log(prob + self.eps), dim=-1) 64 | 65 | def expected_entropy(self, mc_preds): 66 | return torch.mean(self.entropy(mc_preds), dim=0) 67 | 68 | def predictive_uncertainty(self, mc_preds): 69 | """ 70 | Compute the entropy of the mean of the predictive distribution 71 | obtained from Monte Carlo sampling. 72 | """ 73 | return self.entropy(torch.mean(mc_preds, dim=0)) 74 | 75 | def model_uncertainty(self, mc_preds): 76 | """ 77 | Compute the difference between the entropy of the mean of the 78 | predictive distribution and the mean of the entropy.
79 | """ 80 | return self.entropy(torch.mean( 81 | mc_preds, dim=0)) - self.expected_entropy(mc_preds) 82 | 83 | def auc_avu(self, logits, labels, unc): 84 | """ returns AvU at various uncertainty thresholds""" 85 | th_list = np.linspace(0, 1, 21) 86 | umin = torch.min(unc) 87 | umax = torch.max(unc) 88 | avu_list = [] 89 | unc_list = [] 90 | 91 | probs = F.softmax(logits, dim=1) 92 | confidences, predictions = torch.max(probs, 1) 93 | 94 | auc_avu = torch.ones(1, device=labels.device) 95 | auc_avu.requires_grad_(True) 96 | 97 | for t in th_list: 98 | unc_th = umin + (torch.tensor(t) * (umax - umin)) 99 | n_ac = torch.zeros( 100 | 1, 101 | device=labels.device) # number of samples accurate and certain 102 | n_ic = torch.zeros(1, device=labels.device 103 | ) # number of samples inaccurate and certain 104 | n_au = torch.zeros(1, device=labels.device 105 | ) # number of samples accurate and uncertain 106 | n_iu = torch.zeros(1, device=labels.device 107 | ) # number of samples inaccurate and uncertain 108 | 109 | for i in range(len(labels)): 110 | if ((labels[i].item() == predictions[i].item()) 111 | and unc[i].item() <= unc_th.item()): 112 | """ accurate and certain """ 113 | n_ac += confidences[i] * (1 - torch.tanh(unc[i])) 114 | elif ((labels[i].item() == predictions[i].item()) 115 | and unc[i].item() > unc_th.item()): 116 | """ accurate and uncertain """ 117 | n_au += confidences[i] * torch.tanh(unc[i]) 118 | elif ((labels[i].item() != predictions[i].item()) 119 | and unc[i].item() <= unc_th.item()): 120 | """ inaccurate and certain """ 121 | n_ic += (1 - confidences[i]) * (1 - torch.tanh(unc[i])) 122 | elif ((labels[i].item() != predictions[i].item()) 123 | and unc[i].item() > unc_th.item()): 124 | """ inaccurate and uncertain """ 125 | n_iu += (1 - confidences[i]) * torch.tanh(unc[i]) 126 | 127 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-10) 128 | avu_list.append(AvU.data.cpu().numpy()) 129 | unc_list.append(unc_th) 130 | 131 | auc_avu = auc(th_list, 
avu_list) 132 | return auc_avu 133 | 134 | def accuracy_vs_uncertainty(self, prediction, true_label, uncertainty, 135 | optimal_threshold): 136 | n_ac = torch.zeros( 137 | 1, 138 | device=true_label.device) # number of samples accurate and certain 139 | n_ic = torch.zeros(1, device=true_label.device 140 | ) # number of samples inaccurate and certain 141 | n_au = torch.zeros(1, device=true_label.device 142 | ) # number of samples accurate and uncertain 143 | n_iu = torch.zeros(1, device=true_label.device 144 | ) # number of samples inaccurate and uncertain 145 | 146 | avu = torch.ones(1, device=true_label.device) 147 | avu.requires_grad_(True) 148 | 149 | for i in range(len(true_label)): 150 | if ((true_label[i].item() == prediction[i].item()) 151 | and uncertainty[i].item() <= optimal_threshold): 152 | """ accurate and certain """ 153 | n_ac += 1 154 | elif ((true_label[i].item() == prediction[i].item()) 155 | and uncertainty[i].item() > optimal_threshold): 156 | """ accurate and uncertain """ 157 | n_au += 1 158 | elif ((true_label[i].item() != prediction[i].item()) 159 | and uncertainty[i].item() <= optimal_threshold): 160 | """ inaccurate and certain """ 161 | n_ic += 1 162 | elif ((true_label[i].item() != prediction[i].item()) 163 | and uncertainty[i].item() > optimal_threshold): 164 | """ inaccurate and uncertain """ 165 | n_iu += 1 166 | 167 | print('n_ac: ', n_ac, ' ; n_au: ', n_au, ' ; n_ic: ', n_ic, ' ;n_iu: ', 168 | n_iu) 169 | avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu) 170 | 171 | return avu 172 | 173 | def forward(self, logits, labels, type=0): 174 | 175 | probs = F.softmax(logits, dim=1) 176 | confidences, predictions = torch.max(probs, 1) 177 | 178 | if type == 0: 179 | unc = self.entropy(probs) 180 | else: 181 | unc = self.model_uncertainty(probs) 182 | 183 | th_list = np.linspace(0, 1, 21) 184 | umin = torch.min(unc) 185 | umax = torch.max(unc) 186 | avu_list = [] 187 | unc_list = [] 188 | 189 | probs = F.softmax(logits, dim=1) 190 | 
confidences, predictions = torch.max(probs, 1) 191 | 192 | auc_avu = torch.ones(1, device=labels.device) 193 | auc_avu.requires_grad_(True) 194 | 195 | for t in th_list: 196 | unc_th = umin + (torch.tensor(t, device=labels.device) * 197 | (umax - umin)) 198 | n_ac = torch.zeros( 199 | 1, 200 | device=labels.device) # number of samples accurate and certain 201 | n_ic = torch.zeros(1, device=labels.device 202 | ) # number of samples inaccurate and certain 203 | n_au = torch.zeros(1, device=labels.device 204 | ) # number of samples accurate and uncertain 205 | n_iu = torch.zeros(1, device=labels.device 206 | ) # number of samples inaccurate and uncertain 207 | 208 | for i in range(len(labels)): 209 | if ((labels[i].item() == predictions[i].item()) 210 | and unc[i].item() <= unc_th.item()): 211 | """ accurate and certain """ 212 | n_ac += confidences[i] * (1 - torch.tanh(unc[i])) 213 | elif ((labels[i].item() == predictions[i].item()) 214 | and unc[i].item() > unc_th.item()): 215 | """ accurate and uncertain """ 216 | n_au += confidences[i] * torch.tanh(unc[i]) 217 | elif ((labels[i].item() != predictions[i].item()) 218 | and unc[i].item() <= unc_th.item()): 219 | """ inaccurate and certain """ 220 | n_ic += (1 - confidences[i]) * (1 - torch.tanh(unc[i])) 221 | elif ((labels[i].item() != predictions[i].item()) 222 | and unc[i].item() > unc_th.item()): 223 | """ inaccurate and uncertain """ 224 | n_iu += (1 - confidences[i]) * torch.tanh(unc[i]) 225 | 226 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + self.eps) 227 | avu_list.append(AvU) 228 | unc_list.append(unc_th) 229 | 230 | auc_avu = auc(th_list, avu_list) 231 | avu_loss = -1 * self.beta * torch.log(auc_avu + self.eps) 232 | return avu_loss, auc_avu 233 | 234 | 235 | class AvULoss(nn.Module): 236 | """ 237 | Calculates Accuracy vs Uncertainty Loss of a model. 
238 | The input to this loss is logits from Monte_carlo sampling of the model, true labels, 239 | and the type of uncertainty to be used [0: predictive uncertainty (default); 240 | 1: model uncertainty] 241 | """ 242 | def __init__(self, beta=1): 243 | super(AvULoss, self).__init__() 244 | self.beta = beta 245 | self.eps = 1e-10 246 | 247 | def entropy(self, prob): 248 | return -1 * torch.sum(prob * torch.log(prob + self.eps), dim=-1) 249 | 250 | def expected_entropy(self, mc_preds): 251 | return torch.mean(self.entropy(mc_preds), dim=0) 252 | 253 | def predictive_uncertainty(self, mc_preds): 254 | """ 255 | Compute the entropy of the mean of the predictive distribution 256 | obtained from Monte Carlo sampling. 257 | """ 258 | return self.entropy(torch.mean(mc_preds, dim=0)) 259 | 260 | def model_uncertainty(self, mc_preds): 261 | """ 262 | Compute the difference between the entropy of the mean of the 263 | predictive distribution and the mean of the entropy. 264 | """ 265 | return self.entropy(torch.mean( 266 | mc_preds, dim=0)) - self.expected_entropy(mc_preds) 267 | 268 | def accuracy_vs_uncertainty(self, prediction, true_label, uncertainty, 269 | optimal_threshold): 270 | # number of samples accurate and certain 271 | n_ac = torch.zeros(1, device=true_label.device) 272 | # number of samples inaccurate and certain 273 | n_ic = torch.zeros(1, device=true_label.device) 274 | # number of samples accurate and uncertain 275 | n_au = torch.zeros(1, device=true_label.device) 276 | # number of samples inaccurate and uncertain 277 | n_iu = torch.zeros(1, device=true_label.device) 278 | 279 | avu = torch.ones(1, device=true_label.device) 280 | avu.requires_grad_(True) 281 | 282 | for i in range(len(true_label)): 283 | if ((true_label[i].item() == prediction[i].item()) 284 | and uncertainty[i].item() <= optimal_threshold): 285 | """ accurate and certain """ 286 | n_ac += 1 287 | elif ((true_label[i].item() == prediction[i].item()) 288 | and uncertainty[i].item() > 
optimal_threshold): 289 | """ accurate and uncertain """ 290 | n_au += 1 291 | elif ((true_label[i].item() != prediction[i].item()) 292 | and uncertainty[i].item() <= optimal_threshold): 293 | """ inaccurate and certain """ 294 | n_ic += 1 295 | elif ((true_label[i].item() != prediction[i].item()) 296 | and uncertainty[i].item() > optimal_threshold): 297 | """ inaccurate and uncertain """ 298 | n_iu += 1 299 | 300 | print('n_ac: ', n_ac, ' ; n_au: ', n_au, ' ; n_ic: ', n_ic, ' ;n_iu: ', 301 | n_iu) 302 | avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu) 303 | 304 | return avu 305 | 306 | def forward(self, logits, labels, optimal_uncertainty_threshold, type=0): 307 | 308 | probs = F.softmax(logits, dim=1) 309 | confidences, predictions = torch.max(probs, 1) 310 | 311 | if type == 0: 312 | unc = self.entropy(probs) 313 | else: 314 | unc = self.model_uncertainty(probs) 315 | 316 | unc_th = torch.tensor(optimal_uncertainty_threshold, 317 | device=logits.device) 318 | 319 | n_ac = torch.zeros( 320 | 1, device=logits.device) # number of samples accurate and certain 321 | n_ic = torch.zeros( 322 | 1, 323 | device=logits.device) # number of samples inaccurate and certain 324 | n_au = torch.zeros( 325 | 1, 326 | device=logits.device) # number of samples accurate and uncertain 327 | n_iu = torch.zeros( 328 | 1, 329 | device=logits.device) # number of samples inaccurate and uncertain 330 | 331 | avu = torch.ones(1, device=logits.device) 332 | avu_loss = torch.zeros(1, device=logits.device) 333 | 334 | for i in range(len(labels)): 335 | if ((labels[i].item() == predictions[i].item()) 336 | and unc[i].item() <= unc_th.item()): 337 | """ accurate and certain """ 338 | n_ac += confidences[i] * (1 - torch.tanh(unc[i])) 339 | elif ((labels[i].item() == predictions[i].item()) 340 | and unc[i].item() > unc_th.item()): 341 | """ accurate and uncertain """ 342 | n_au += confidences[i] * torch.tanh(unc[i]) 343 | elif ((labels[i].item() != predictions[i].item()) 344 | and unc[i].item() 
<= unc_th.item()): 345 | """ inaccurate and certain """ 346 | n_ic += (1 - confidences[i]) * (1 - torch.tanh(unc[i])) 347 | elif ((labels[i].item() != predictions[i].item()) 348 | and unc[i].item() > unc_th.item()): 349 | """ inaccurate and uncertain """ 350 | n_iu += (1 - confidences[i]) * torch.tanh(unc[i]) 351 | 352 | avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + self.eps) 353 | p_ac = (n_ac) / (n_ac + n_ic) 354 | p_ui = (n_iu) / (n_iu + n_ic) 355 | #print('Actual AvU: ', self.accuracy_vs_uncertainty(predictions, labels, uncertainty, optimal_threshold)) 356 | avu_loss = -1 * self.beta * torch.log(avu + self.eps) 357 | return avu_loss 358 | 359 | 360 | def entropy(prob): 361 | return -1 * np.sum(prob * np.log(prob + 1e-15), axis=-1) 362 | 363 | 364 | def predictive_entropy(mc_preds): 365 | """ 366 | Compute the entropy of the mean of the predictive distribution 367 | obtained from Monte Carlo sampling during prediction phase. 368 | """ 369 | return entropy(np.mean(mc_preds, axis=0)) 370 | 371 | 372 | def mutual_information(mc_preds): 373 | """ 374 | Compute the difference between the entropy of the mean of the 375 | predictive distribution and the mean of the entropy. 
376 | """ 377 | MI = entropy(np.mean(mc_preds, axis=0)) - np.mean(entropy(mc_preds), 378 | axis=0) 379 | return MI 380 | 381 | 382 | def eval_avu(pred_label, true_label, uncertainty): 383 | """ returns AvU at various uncertainty thresholds""" 384 | t_list = np.linspace(0, 1, 21) 385 | umin = np.amin(uncertainty, axis=0) 386 | umax = np.amax(uncertainty, axis=0) 387 | avu_list = [] 388 | unc_list = [] 389 | for t in t_list: 390 | u_th = umin + (t * (umax - umin)) 391 | n_ac = 0 392 | n_ic = 0 393 | n_au = 0 394 | n_iu = 0 395 | for i in range(len(true_label)): 396 | if ((true_label[i] == pred_label[i]) and uncertainty[i] <= u_th): 397 | n_ac += 1 398 | elif ((true_label[i] == pred_label[i]) and uncertainty[i] > u_th): 399 | n_au += 1 400 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] <= u_th): 401 | n_ic += 1 402 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] > u_th): 403 | n_iu += 1 404 | 405 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-15) 406 | avu_list.append(AvU) 407 | unc_list.append(u_th) 408 | return np.asarray(avu_list), np.asarray(unc_list) 409 | 410 | 411 | def accuracy_vs_uncertainty(pred_label, true_label, uncertainty, 412 | optimal_threshold): 413 | 414 | n_ac = 0 415 | n_ic = 0 416 | n_au = 0 417 | n_iu = 0 418 | for i in range(len(true_label)): 419 | if ((true_label[i] == pred_label[i]) 420 | and uncertainty[i] <= optimal_threshold): 421 | n_ac += 1 422 | elif ((true_label[i] == pred_label[i]) 423 | and uncertainty[i] > optimal_threshold): 424 | n_au += 1 425 | elif ((true_label[i] != pred_label[i]) 426 | and uncertainty[i] <= optimal_threshold): 427 | n_ic += 1 428 | elif ((true_label[i] != pred_label[i]) 429 | and uncertainty[i] > optimal_threshold): 430 | n_iu += 1 431 | 432 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu) 433 | return AvU 434 | -------------------------------------------------------------------------------- /src/main_bayesian_cifar.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | import models.bayesian.resnet as resnet 15 | import numpy as np 16 | import csv 17 | #from utils import calib 18 | 19 | model_names = sorted( 20 | name for name in resnet.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and name.startswith("resnet") and callable(resnet.__dict__[name])) 23 | 24 | print(model_names) 25 | len_trainset = 50000 26 | len_testset = 10000 27 | 28 | parser = argparse.ArgumentParser(description='CIFAR10') 29 | parser.add_argument('--arch', 30 | '-a', 31 | metavar='ARCH', 32 | default='resnet20', 33 | choices=model_names, 34 | help='model architecture: ' + ' | '.join(model_names) + 35 | ' (default: resnet20)') 36 | parser.add_argument('-j', 37 | '--workers', 38 | default=8, 39 | type=int, 40 | metavar='N', 41 | help='number of data loading workers (default: 8)') 42 | parser.add_argument('--epochs', 43 | default=200, 44 | type=int, 45 | metavar='N', 46 | help='number of total epochs to run') 47 | parser.add_argument('--start-epoch', 48 | default=0, 49 | type=int, 50 | metavar='N', 51 | help='manual epoch number (useful on restarts)') 52 | parser.add_argument('-b', 53 | '--batch-size', 54 | default=512, 55 | type=int, 56 | metavar='N', 57 | help='mini-batch size (default: 512)') 58 | parser.add_argument('--lr', 59 | '--learning-rate', 60 | default=0.1, 61 | type=float, 62 | metavar='LR', 63 | help='initial learning rate') 64 | parser.add_argument('--momentum', 65 | default=0.9, 66 | type=float, 67 | metavar='M', 68 | help='momentum') 69 | 
parser.add_argument('--weight-decay', 70 | '--wd', 71 | default=1e-4, 72 | type=float, 73 | metavar='W', 74 | help='weight decay (default: 1e-4)') 75 | parser.add_argument('--print-freq', 76 | '-p', 77 | default=50, 78 | type=int, 79 | metavar='N', 80 | help='print frequency (default: 50)') 81 | parser.add_argument('--resume', 82 | default='', 83 | type=str, 84 | metavar='PATH', 85 | help='path to latest checkpoint (default: none)') 86 | parser.add_argument('-e', 87 | '--evaluate', 88 | dest='evaluate', 89 | action='store_true', 90 | help='evaluate model on validation set') 91 | parser.add_argument('--pretrained', 92 | dest='pretrained', 93 | action='store_true', 94 | help='use pre-trained model') 95 | parser.add_argument('--half', 96 | dest='half', 97 | action='store_true', 98 | help='use half-precision (16-bit)') 99 | parser.add_argument('--save-dir', 100 | dest='save_dir', 101 | help='The directory used to save the trained models', 102 | default='./checkpoint/bayesian_svi', 103 | type=str) 104 | parser.add_argument( 105 | '--save-every', 106 | dest='save_every', 107 | help='Saves checkpoints at every specified number of epochs', 108 | type=int, 109 | default=10) 110 | parser.add_argument('--mode', type=str, required=True, help='train | test') 111 | parser.add_argument('--num_monte_carlo', 112 | type=int, 113 | default=20, 114 | metavar='N', 115 | help='number of Monte Carlo samples') 116 | parser.add_argument( 117 | '--tensorboard', 118 | type=bool, 119 | default=True, 120 | metavar='N', 121 | help='use tensorboard for logging and visualization of training progress') 122 | parser.add_argument( 123 | '--log_dir', 124 | type=str, 125 | default='./logs/cifar/bayesian_svi', 126 | metavar='PATH', 127 | help='directory for logs, saved predictions and results') 128 | best_prec1 = 0 129 | 130 | 131 | class CorruptDataset(torch.utils.data.Dataset): 132 | def __init__(self, data, target, transform=None): 133 | self.data = data 134 | self.target = target 135
| self.transform = transform 136 | 137 | def __getitem__(self, index): 138 | x = self.data[index] 139 | y = self.target[index] 140 | 141 | if self.transform: 142 | x = self.transform(x) 143 | 144 | return x, y 145 | 146 | def __len__(self): 147 | return len(self.data) 148 | 149 | 150 | class OODDataset(torch.utils.data.Dataset): 151 | def __init__(self, data, target, transform=None): 152 | self.data = data 153 | self.target = target 154 | self.transform = transform 155 | 156 | def __getitem__(self, index): 157 | x = self.data[index] 158 | y = self.target[index] 159 | 160 | if self.transform: 161 | x = self.transform(x) 162 | 163 | return x, y 164 | 165 | def __len__(self): 166 | return len(self.data) 167 | 168 | 169 | def get_ood_dataloader(ood_images, ood_labels): 170 | ood_dataset = OODDataset(ood_images, 171 | ood_labels, 172 | transform=transforms.Compose( 173 | [transforms.ToTensor()])) 174 | ood_data_loader = torch.utils.data.DataLoader(ood_dataset, 175 | batch_size=args.batch_size, 176 | shuffle=False, 177 | num_workers=args.workers, 178 | pin_memory=True) 179 | 180 | return ood_data_loader 181 | 182 | 183 | corruptions = [ 184 | 'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 185 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 186 | 'pixelate', 'saturate', 'shot_noise', 'spatter', 'speckle_noise', 187 | 'zoom_blur' 188 | ] 189 | 190 | 191 | def get_corrupt_dataloader(corrupted_images, corrupted_labels, level): 192 | corrupted_images_1 = corrupted_images[0:10000, :, :, :] 193 | corrupted_labels_1 = corrupted_labels[0:10000] 194 | corrupted_images_2 = corrupted_images[10000:20000, :, :, :] 195 | corrupted_labels_2 = corrupted_labels[10000:20000] 196 | corrupted_images_3 = corrupted_images[20000:30000, :, :, :] 197 | corrupted_labels_3 = corrupted_labels[20000:30000] 198 | corrupted_images_4 = corrupted_images[30000:40000, :, :, :] 199 | corrupted_labels_4 = corrupted_labels[30000:40000] 200 | 
corrupted_images_5 = corrupted_images[40000:50000, :, :, :] 201 | corrupted_labels_5 = corrupted_labels[40000:50000] 202 | if level == 1: 203 | corrupt_val_dataset = CorruptDataset(corrupted_images_1, 204 | corrupted_labels_1, 205 | transform=transforms.Compose( 206 | [transforms.ToTensor()])) 207 | elif level == 2: 208 | corrupt_val_dataset = CorruptDataset(corrupted_images_2, 209 | corrupted_labels_2, 210 | transform=transforms.Compose( 211 | [transforms.ToTensor()])) 212 | elif level == 3: 213 | corrupt_val_dataset = CorruptDataset(corrupted_images_3, 214 | corrupted_labels_3, 215 | transform=transforms.Compose( 216 | [transforms.ToTensor()])) 217 | elif level == 4: 218 | corrupt_val_dataset = CorruptDataset(corrupted_images_4, 219 | corrupted_labels_4, 220 | transform=transforms.Compose( 221 | [transforms.ToTensor()])) 222 | elif level == 5: 223 | corrupt_val_dataset = CorruptDataset(corrupted_images_5, 224 | corrupted_labels_5, 225 | transform=transforms.Compose( 226 | [transforms.ToTensor()])) 227 | 228 | corrupt_val_loader = torch.utils.data.DataLoader( 229 | corrupt_val_dataset, 230 | batch_size=args.batch_size, 231 | shuffle=False, 232 | num_workers=args.workers, 233 | pin_memory=True) 234 | 235 | return corrupt_val_loader 236 | 237 | 238 | def main(): 239 | global args, best_prec1 240 | args = parser.parse_args() 241 | 242 | # Check the save_dir exists or not 243 | if not os.path.exists(args.save_dir): 244 | os.makedirs(args.save_dir) 245 | 246 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]()) 247 | if torch.cuda.is_available(): 248 | model.cuda() 249 | else: 250 | model.cpu() 251 | 252 | # optionally resume from a checkpoint 253 | if args.resume: 254 | if os.path.isfile(args.resume): 255 | print("=> loading checkpoint '{}'".format(args.resume)) 256 | checkpoint = torch.load(args.resume) 257 | args.start_epoch = checkpoint['epoch'] 258 | best_prec1 = checkpoint['best_prec1'] 259 | model.load_state_dict(checkpoint['state_dict']) 260 | print("=> 
loaded checkpoint '{}' (epoch {})".format( 261 | args.resume, checkpoint['epoch'])) 262 | else: 263 | print("=> no checkpoint found at '{}'".format(args.resume)) 264 | 265 | cudnn.benchmark = True 266 | 267 | tb_writer = None 268 | if args.tensorboard: 269 | logger_dir = os.path.join(args.log_dir, 'tb_logger') 270 | if not os.path.exists(logger_dir): 271 | os.makedirs(logger_dir) 272 | tb_writer = SummaryWriter(logger_dir) 273 | 274 | preds_dir = os.path.join(args.log_dir, 'preds') 275 | if not os.path.exists(preds_dir): 276 | os.makedirs(preds_dir) 277 | results_dir = os.path.join(args.log_dir, 'results') 278 | if not os.path.exists(results_dir): 279 | os.makedirs(results_dir) 280 | 281 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 282 | std=[0.229, 0.224, 0.225]) 283 | 284 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10( 285 | root='./data', 286 | train=True, 287 | transform=transforms.Compose([ 288 | transforms.RandomHorizontalFlip(), 289 | transforms.RandomCrop(32, 4), 290 | transforms.ToTensor() 291 | ]), 292 | download=True), 293 | batch_size=args.batch_size, 294 | shuffle=True, 295 | num_workers=args.workers, 296 | pin_memory=True) 297 | 298 | val_loader = torch.utils.data.DataLoader(datasets.CIFAR10( 299 | root='./data', 300 | train=False, 301 | transform=transforms.Compose([transforms.ToTensor()])), 302 | batch_size=args.batch_size, 303 | shuffle=False, 304 | num_workers=args.workers, 305 | pin_memory=True) 306 | 307 | if not os.path.exists(args.save_dir): 308 | os.makedirs(args.save_dir) 309 | 310 | if torch.cuda.is_available(): 311 | criterion = nn.CrossEntropyLoss().cuda() 312 | else: 313 | criterion = nn.CrossEntropyLoss().cpu() 314 | 315 | if args.half: 316 | model.half() 317 | criterion.half() 318 | 319 | if args.arch in ['resnet110']: 320 | # the optimizer is created per epoch below, so scale the base lr here 321 | args.lr = args.lr * 0.1 322 | 323 | if args.evaluate: 324 | validate(val_loader, model, criterion, args.start_epoch) 325 | return 326 |
327 | if args.mode == 'train': 328 | for epoch in range(args.start_epoch, args.epochs): 329 | 330 | lr = args.lr 331 | if (epoch >= 80 and epoch < 120): 332 | lr = 0.1 * args.lr 333 | elif (epoch >= 120 and epoch < 160): 334 | lr = 0.01 * args.lr 335 | elif (epoch >= 160 and epoch < 180): 336 | lr = 0.001 * args.lr 337 | elif (epoch >= 180): 338 | lr = 0.0005 * args.lr 339 | 340 | optimizer = torch.optim.Adam(model.parameters(), lr) 341 | # train for one epoch 342 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 343 | train(train_loader, model, criterion, optimizer, epoch, tb_writer) 344 | prec1 = validate(val_loader, model, criterion, epoch, tb_writer) 345 | is_best = prec1 > best_prec1 346 | best_prec1 = max(prec1, best_prec1) 347 | 348 | if epoch > 0: 349 | if is_best: 350 | save_checkpoint( 351 | { 352 | 'epoch': epoch + 1, 353 | 'state_dict': model.state_dict(), 354 | 'best_prec1': best_prec1, 355 | }, 356 | is_best, 357 | filename=os.path.join( 358 | args.save_dir, 359 | 'bayesian_{}_cifar.pth'.format(args.arch))) 360 | 361 | elif args.mode == 'test': 362 | checkpoint_file = args.save_dir + '/bayesian_{}_cifar.pth'.format( 363 | args.arch) 364 | if torch.cuda.is_available(): 365 | checkpoint = torch.load(checkpoint_file) 366 | else: 367 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) 368 | model.load_state_dict(checkpoint['state_dict']) 369 | 370 | #header = ['corrupt', 'test_acc', 'brier', 'ece'] 371 | header = ['corrupt', 'test_acc'] 372 | 373 | #Evaluate on OOD dataset (SVHN) 374 | ood_images_file = 'data/SVHN/svhn-test.npy' 375 | ood_images = np.load(ood_images_file) 376 | ood_images = ood_images[:10000, :, :, :] 377 | ood_labels = np.arange(len(ood_images)) + 10 #create dummy labels 378 | ood_loader = get_ood_dataloader(ood_images, ood_labels) 379 | ood_acc = evaluate(model, ood_loader, corrupt='ood', level=None) 380 | print('******OOD data***********\n') 381 | #print('ood_acc: ', ood_acc, ' | Brier: ', 
ood_brier, ' | ECE: ', ood_ece, '\n') 382 | print('ood_acc: ', ood_acc) 383 | ''' 384 | o_file = args.log_dir + '/results/ood_results.csv' 385 | with open(o_file, 'wt') as o_file: 386 | writer = csv.writer(o_file, delimiter=',', lineterminator='\n') 387 | writer.writerow([j for j in header]) 388 | writer.writerow(['ood', ood_acc, ood_brier, ood_ece]) 389 | o_file.close() 390 | ''' 391 | #Evaluate on test dataset 392 | test_acc = evaluate(model, val_loader, corrupt=None, level=None) 393 | print('******Test data***********\n') 394 | #print('test_acc: ', test_acc, ' | Brier: ', brier, ' | ECE: ', ece, '\n') 395 | print('test_acc: ', test_acc) 396 | ''' 397 | t_file = args.log_dir + '/test_results.csv' 398 | with open(t_file, 'wt') as t_file: 399 | writer = csv.writer(t_file, delimiter=',', lineterminator='\n') 400 | writer.writerow([j for j in header]) 401 | writer.writerow(['test', test_acc, brier, ece]) 402 | t_file.close() 403 | ''' 404 | 405 | for level in range(1, 6): 406 | print('******Corruption Level: ', level, ' ***********\n') 407 | results_file = args.log_dir + '/level' + str(level) + '.csv' 408 | with open(results_file, 'wt') as results_file: 409 | writer = csv.writer(results_file, 410 | delimiter=',', 411 | lineterminator='\n') 412 | writer.writerow([j for j in header]) 413 | for c in corruptions: 414 | images_file = 'data/CIFAR-10-C/' + c + '.npy' 415 | labels_file = 'data/CIFAR-10-C/labels.npy' 416 | corrupt_images = np.load(images_file) 417 | corrupt_labels = np.load(labels_file) 418 | val_loader = get_corrupt_dataloader( 419 | corrupt_images, corrupt_labels, level) 420 | test_acc = evaluate(model, 421 | val_loader, 422 | corrupt=c, 423 | level=level) 424 | print('############ Corruption type: ', c, 425 | ' ################') 426 | #print('test_acc: ', test_acc, ' | Brier: ', brier, ' | ECE: ', ece, '\n') 427 | print('test_acc: ', test_acc) 428 | #writer.writerow([c, test_acc, brier, ece]) 429 | writer.writerow([c, test_acc]) 430 | results_file.close() 
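(Aside: the `train()`/`validate()` routines below optimize a minibatch ELBO — cross-entropy plus the KL divergence from the variational layers scaled by the training-set size, so the KL penalty is amortized across minibatches. A minimal standalone sketch with toy tensors; the stand-in logits and KL value are ours, not the repo's model.)

```python
import torch
import torch.nn.functional as F

len_trainset = 50000                              # CIFAR-10 training-set size, as above
logits = torch.randn(8, 10, requires_grad=True)   # stand-in for the model's output
targets = torch.randint(0, 10, (8,))
kl = torch.tensor(1250.0)                         # stand-in for the summed KL term

cross_entropy_loss = F.cross_entropy(logits, targets)
scaled_kl = kl / len_trainset                     # 1250 / 50000 = 0.025
loss = cross_entropy_loss + scaled_kl             # minibatch ELBO objective
loss.backward()                                   # gradients flow through the NLL term
```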
431 | 432 | 433 | def train(train_loader, model, criterion, optimizer, epoch, tb_writer=None): 434 | batch_time = AverageMeter() 435 | data_time = AverageMeter() 436 | losses = AverageMeter() 437 | top1 = AverageMeter() 438 | 439 | # switch to train mode 440 | model.train() 441 | 442 | end = time.time() 443 | for i, (input, target) in enumerate(train_loader): 444 | 445 | # measure data loading time 446 | data_time.update(time.time() - end) 447 | 448 | if torch.cuda.is_available(): 449 | target = target.cuda() 450 | input_var = input.cuda() 451 | target_var = target.cuda() 452 | else: 453 | target = target.cpu() 454 | input_var = input.cpu() 455 | target_var = target.cpu() 456 | if args.half: 457 | input_var = input_var.half() 458 | 459 | # compute output 460 | output, kl = model(input_var) 461 | cross_entropy_loss = criterion(output, target_var) 462 | scaled_kl = kl.data / (len_trainset) 463 | loss = cross_entropy_loss + scaled_kl 464 | 465 | # compute gradient and do SGD step 466 | optimizer.zero_grad() 467 | loss.backward() 468 | optimizer.step() 469 | 470 | output = output.float() 471 | loss = loss.float() 472 | # measure accuracy and record loss 473 | prec1 = accuracy(output.data, target)[0] 474 | losses.update(loss.item(), input.size(0)) 475 | top1.update(prec1.item(), input.size(0)) 476 | 477 | # measure elapsed time 478 | batch_time.update(time.time() - end) 479 | end = time.time() 480 | 481 | if i % args.print_freq == 0: 482 | print('Epoch: [{0}][{1}/{2}]\t' 483 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 484 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 485 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 486 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 487 | epoch, 488 | i, 489 | len(train_loader), 490 | batch_time=batch_time, 491 | data_time=data_time, 492 | loss=losses, 493 | top1=top1)) 494 | 495 | if tb_writer is not None: 496 | tb_writer.add_scalar('train/cross_entropy_loss', 497 | cross_entropy_loss.item(), epoch) 498 | 
tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch) 499 | tb_writer.add_scalar('train/elbo_loss', loss.item(), epoch) 500 | tb_writer.add_scalar('train/accuracy', prec1.item(), epoch) 501 | tb_writer.flush() 502 | 503 | 504 | def validate(val_loader, model, criterion, epoch, tb_writer=None): 505 | batch_time = AverageMeter() 506 | losses = AverageMeter() 507 | top1 = AverageMeter() 508 | 509 | # switch to evaluate mode 510 | model.eval() 511 | 512 | end = time.time() 513 | with torch.no_grad(): 514 | for i, (input, target) in enumerate(val_loader): 515 | if torch.cuda.is_available(): 516 | target = target.cuda() 517 | input_var = input.cuda() 518 | target_var = target.cuda() 519 | else: 520 | target = target.cpu() 521 | input_var = input.cpu() 522 | target_var = target.cpu() 523 | 524 | if args.half: 525 | input_var = input_var.half() 526 | 527 | # compute output 528 | output, kl = model(input_var) 529 | cross_entropy_loss = criterion(output, target_var) 530 | scaled_kl = kl.data / (len_trainset) 531 | loss = cross_entropy_loss + scaled_kl 532 | 533 | output = output.float() 534 | loss = loss.float() 535 | 536 | # measure accuracy and record loss 537 | prec1 = accuracy(output.data, target)[0] 538 | losses.update(loss.item(), input.size(0)) 539 | top1.update(prec1.item(), input.size(0)) 540 | 541 | # measure elapsed time 542 | batch_time.update(time.time() - end) 543 | end = time.time() 544 | 545 | if i % args.print_freq == 0: 546 | print('Test: [{0}/{1}]\t' 547 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 548 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 549 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 550 | i, 551 | len(val_loader), 552 | batch_time=batch_time, 553 | loss=losses, 554 | top1=top1)) 555 | 556 | if tb_writer is not None: 557 | tb_writer.add_scalar('val/cross_entropy_loss', 558 | cross_entropy_loss.item(), epoch) 559 | tb_writer.add_scalar('val/kl_div', scaled_kl.item(), epoch) 560 | tb_writer.add_scalar('val/elbo_loss', 
loss.item(), epoch) 561 | tb_writer.add_scalar('val/accuracy', prec1.item(), epoch) 562 | tb_writer.flush() 563 | 564 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 565 | 566 | return top1.avg 567 | 568 | 569 | def evaluate(model, val_loader, corrupt=None, level=None): 570 | model.eval() 571 | batch_probs_mc = []  # per-batch MC samples, each of shape [num_monte_carlo, B, C] 572 | target_list = [] 573 | with torch.no_grad(): 574 | for batch_idx, (data, target) in enumerate(val_loader): 575 | if torch.cuda.is_available(): 576 | data, target = data.cuda(), target.cuda() 577 | else: 578 | data, target = data.cpu(), target.cpu() 579 | mc_runs = [] 580 | for mc_run in range(args.num_monte_carlo): 581 | output, _ = model(data) 582 | pred_probs = torch.nn.functional.softmax(output, dim=1) 583 | mc_runs.append(pred_probs.cpu().data.numpy()) 584 | batch_probs_mc.append(np.stack(mc_runs)) 585 | target_list.append(target.cpu().data.numpy()) 586 | 587 | # concatenate over batches so pred_probs_mc has shape [num_mc, N, classes] 588 | pred_probs_mc = np.concatenate(batch_probs_mc, axis=1) 589 | 590 | if corrupt == 'ood': 591 | np.save(args.log_dir + '/preds/svi_ood_probs.npy', pred_probs_mc) 592 | print('saved predictions') 593 | return None 594 | 595 | target_labels = np.concatenate(target_list) 596 | pred_mean = np.mean(pred_probs_mc, axis=0) 597 | Y_pred = np.argmax(pred_mean, axis=1) 598 | test_acc = (Y_pred == target_labels).mean() 599 | #brier = np.mean(calib.brier_scores(target_labels, probs=pred_mean)) 600 | #ece = calib.expected_calibration_error_multiclass(pred_mean, target_labels) 601 | print('Test accuracy:', test_acc * 100) 602 | if corrupt is not None: 603 | np.save( 604 | args.log_dir + 605 | '/preds/svi_corrupt-static-{}-{}_probs.npy'.format( 606 | corrupt, level), pred_probs_mc) 607 | np.save( 608 | args.log_dir + 609 | '/preds/svi_corrupt-static-{}-{}_labels.npy'.format( 610 | corrupt, level), target_labels) 611 | print('saved predictions') 612 | else: 613 | np.save(args.log_dir + '/preds/svi_test_probs.npy', pred_probs_mc) 614 |
np.save(args.log_dir + '/preds/svi_test_labels.npy', target_labels) 615 | print('saved predictions') 616 | 617 | return test_acc 618 | 619 | 620 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 621 | """ 622 | Save the training model 623 | """ 624 | torch.save(state, filename) 625 | 626 | 627 | class AverageMeter(object): 628 | """Computes and stores the average and current value""" 629 | def __init__(self): 630 | self.reset() 631 | 632 | def reset(self): 633 | self.val = 0 634 | self.avg = 0 635 | self.sum = 0 636 | self.count = 0 637 | 638 | def update(self, val, n=1): 639 | self.val = val 640 | self.sum += val * n 641 | self.count += n 642 | self.avg = self.sum / self.count 643 | 644 | 645 | def accuracy(output, target, topk=(1, )): 646 | """Computes the precision@k for the specified values of k""" 647 | maxk = max(topk) 648 | batch_size = target.size(0) 649 | 650 | _, pred = output.topk(maxk, 1, True, True) 651 | pred = pred.t() 652 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 653 | 654 | res = [] 655 | for k in topk: 656 | correct_k = correct[:k].view(-1).float().sum(0) 657 | res.append(correct_k.mul_(100.0 / batch_size)) 658 | return res 659 | 660 | 661 | if __name__ == '__main__': 662 | main() 663 | -------------------------------------------------------------------------------- /src/main_bayesian_cifar_auavu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | from torch.utils.tensorboard import SummaryWriter 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | import models.bayesian.resnet as resnet 16 | import models.deterministic.resnet as det_resnet 17 | import numpy as np 18 | from src import util 19 | 
#from utils import calib 20 | import csv 21 | from src.util import get_rho 22 | from src.avuc_loss import AvULoss, AUAvULoss 23 | from torchsummary import summary 24 | 25 | model_names = sorted( 26 | name for name in resnet.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and name.startswith("resnet") and callable(resnet.__dict__[name])) 29 | 30 | print(model_names) 31 | 32 | parser = argparse.ArgumentParser(description='CIFAR10') 33 | parser.add_argument('--arch', 34 | '-a', 35 | metavar='ARCH', 36 | default='resnet20', 37 | choices=model_names, 38 | help='model architecture: ' + ' | '.join(model_names) + 39 | ' (default: resnet20)') 40 | parser.add_argument('-j', 41 | '--workers', 42 | default=8, 43 | type=int, 44 | metavar='N', 45 | help='number of data loading workers (default: 8)') 46 | parser.add_argument('--epochs', 47 | default=200, 48 | type=int, 49 | metavar='N', 50 | help='number of total epochs to run') 51 | parser.add_argument('--start-epoch', 52 | default=0, 53 | type=int, 54 | metavar='N', 55 | help='manual epoch number (useful on restarts)') 56 | parser.add_argument('-b', 57 | '--batch-size', 58 | default=512, 59 | type=int, 60 | metavar='N', 61 | help='mini-batch size (default: 512)') 62 | parser.add_argument('--lr', 63 | '--learning-rate', 64 | default=0.1, 65 | type=float, 66 | metavar='LR', 67 | help='initial learning rate') 68 | parser.add_argument('--momentum', 69 | default=0.9, 70 | type=float, 71 | metavar='M', 72 | help='momentum') 73 | parser.add_argument('--weight-decay', 74 | '--wd', 75 | default=1e-4, 76 | type=float, 77 | metavar='W', 78 | help='weight decay (default: 1e-4)') 79 | parser.add_argument('--print-freq', 80 | '-p', 81 | default=50, 82 | type=int, 83 | metavar='N', 84 | help='print frequency (default: 50)') 85 | parser.add_argument('--resume', 86 | default='', 87 | type=str, 88 | metavar='PATH', 89 | help='path to latest checkpoint (default: none)') 90 | parser.add_argument('-e', 91 | '--evaluate', 92 |
dest='evaluate', 93 | action='store_true', 94 | help='evaluate model on validation set') 95 | parser.add_argument('--pretrained', 96 | dest='pretrained', 97 | action='store_true', 98 | help='use pre-trained model') 99 | parser.add_argument('--half', 100 | dest='half', 101 | action='store_true', 102 | help='use half-precision(16-bit) ') 103 | parser.add_argument('--save-dir', 104 | dest='save_dir', 105 | help='The directory used to save the trained models', 106 | default='./checkpoint/bayesian_svi_auavu', 107 | type=str) 108 | parser.add_argument( 109 | '--save-every', 110 | dest='save_every', 111 | help='Saves checkpoints at every specified number of epochs', 112 | type=int, 113 | default=10) 114 | parser.add_argument('--mode', type=str, required=True, help='train | test') 115 | parser.add_argument('--num_monte_carlo', 116 | type=int, 117 | default=20, 118 | metavar='N', 119 | help='number of Monte Carlo samples') 120 | parser.add_argument( 121 | '--tensorboard', 122 | type=bool, 123 | default=True, 124 | metavar='N', 125 | help='use tensorboard for logging and visualization of training progress') 126 | parser.add_argument('--val_batch_size', default=10000, type=int) 127 | parser.add_argument( 128 | '--log_dir', 129 | type=str, 130 | default='./logs/cifar/bayesian_svi_auavu', 131 | metavar='N', 132 | help='use tensorboard for logging and visualization of training progress') 133 | parser.add_argument( 134 | '--moped', 135 | type=bool, 136 | default=False, 137 | help='set prior and initialize approx posterior with Empirical Bayes') 138 | parser.add_argument('--delta', 139 | type=float, 140 | default=1.0, 141 | help='delta value for variance scaling in MOPED') 142 | len_trainset = 50000 143 | len_testset = 10000 144 | beta = 3.0 145 | optimal_threshold = 1.0 146 | opt_th = optimal_threshold 147 | best_prec1 = 0 148 | best_avu = 0 149 | 150 | 151 | class CorruptDataset(torch.utils.data.Dataset): 152 | def __init__(self, data, target, transform=None): 153 | self.data = 
data 154 | self.target = target 155 | self.transform = transform 156 | 157 | def __getitem__(self, index): 158 | x = self.data[index] 159 | y = self.target[index] 160 | 161 | if self.transform: 162 | x = self.transform(x) 163 | 164 | return x, y 165 | 166 | def __len__(self): 167 | return len(self.data) 168 | 169 | 170 | class OODDataset(torch.utils.data.Dataset): 171 | def __init__(self, data, target, transform=None): 172 | self.data = data 173 | self.target = target 174 | self.transform = transform 175 | 176 | def __getitem__(self, index): 177 | x = self.data[index] 178 | y = self.target[index] 179 | 180 | if self.transform: 181 | x = self.transform(x) 182 | 183 | return x, y 184 | 185 | def __len__(self): 186 | return len(self.data) 187 | 188 | 189 | def get_ood_dataloader(ood_images, ood_labels): 190 | ood_dataset = OODDataset(ood_images, 191 | ood_labels, 192 | transform=transforms.Compose( 193 | [transforms.ToTensor()])) 194 | ood_data_loader = torch.utils.data.DataLoader(ood_dataset, 195 | batch_size=args.batch_size, 196 | shuffle=False, 197 | num_workers=args.workers, 198 | pin_memory=True) 199 | return ood_data_loader 200 | 201 | 202 | corruptions = [ 203 | 'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 204 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 205 | 'pixelate', 'saturate', 'shot_noise', 'spatter', 'speckle_noise', 206 | 'zoom_blur' 207 | ] 208 | 209 | 210 | def get_corrupt_dataloader(corrupted_images, corrupted_labels, level): 211 | 212 | corrupted_images_1 = corrupted_images[0:10000, :, :, :] 213 | corrupted_labels_1 = corrupted_labels[0:10000] 214 | corrupted_images_2 = corrupted_images[10000:20000, :, :, :] 215 | corrupted_labels_2 = corrupted_labels[10000:20000] 216 | corrupted_images_3 = corrupted_images[20000:30000, :, :, :] 217 | corrupted_labels_3 = corrupted_labels[20000:30000] 218 | corrupted_images_4 = corrupted_images[30000:40000, :, :, :] 219 | corrupted_labels_4 = 
corrupted_labels[30000:40000] 220 | corrupted_images_5 = corrupted_images[40000:50000, :, :, :] 221 | corrupted_labels_5 = corrupted_labels[40000:50000] 222 | 223 | if level == 1: 224 | corrupt_val_dataset = CorruptDataset(corrupted_images_1, 225 | corrupted_labels_1, 226 | transform=transforms.Compose( 227 | [transforms.ToTensor()])) 228 | elif level == 2: 229 | corrupt_val_dataset = CorruptDataset(corrupted_images_2, 230 | corrupted_labels_2, 231 | transform=transforms.Compose( 232 | [transforms.ToTensor()])) 233 | elif level == 3: 234 | corrupt_val_dataset = CorruptDataset(corrupted_images_3, 235 | corrupted_labels_3, 236 | transform=transforms.Compose( 237 | [transforms.ToTensor()])) 238 | elif level == 4: 239 | corrupt_val_dataset = CorruptDataset(corrupted_images_4, 240 | corrupted_labels_4, 241 | transform=transforms.Compose( 242 | [transforms.ToTensor()])) 243 | elif level == 5: 244 | corrupt_val_dataset = CorruptDataset(corrupted_images_5, 245 | corrupted_labels_5, 246 | transform=transforms.Compose( 247 | [transforms.ToTensor()])) 248 | 249 | corrupt_val_loader = torch.utils.data.DataLoader( 250 | corrupt_val_dataset, 251 | batch_size=args.val_batch_size, 252 | shuffle=False, 253 | num_workers=args.workers, 254 | pin_memory=True) 255 | 256 | return corrupt_val_loader 257 | 258 | 259 | def MOPED_layer(layer, det_layer, delta): 260 | """ 261 | Set the priors and initialize surrogate posteriors of Bayesian NN with Empirical Bayes 262 | MOPED (Model Priors with Empirical Bayes using Deterministic DNN) 263 | Ref: https://arxiv.org/abs/1906.05323 264 | 'Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes'. AAAI 2020. 
265 | """ 266 | 267 | if (str(layer) == 'Conv2dVariational()'): 268 | #set the priors 269 | print(str(layer)) 270 | layer.prior_weight_mu = det_layer.weight.data 271 | if layer.prior_bias_mu is not None: 272 | layer.prior_bias_mu = det_layer.bias.data 273 | 274 | #initialize surrogate posteriors 275 | layer.mu_kernel.data = det_layer.weight.data 276 | #layer.rho_kernel.data = get_rho(det_layer.weight.data, delta) 277 | if layer.mu_bias is not None: 278 | layer.mu_bias.data = det_layer.bias.data 279 | #layer.rho_bias.data = get_rho(det_layer.bias.data, delta) 280 | 281 | elif (isinstance(layer, nn.Conv2d)): 282 | print(str(layer)) 283 | layer.weight.data = det_layer.weight.data 284 | if layer.bias is not None: 285 | layer.bias.data = det_layer.bias.data 286 | 287 | elif (str(layer) == 'LinearVariational()'): 288 | print(str(layer)) 289 | layer.prior_weight_mu = det_layer.weight.data 290 | if layer.prior_bias_mu is not None: 291 | layer.prior_bias_mu = det_layer.bias.data 292 | 293 | #initialize the surrogate posteriors 294 | 295 | layer.mu_weight.data = det_layer.weight.data 296 | layer.rho_weight.data = get_rho(det_layer.weight.data, delta) 297 | if layer.mu_bias is not None: 298 | layer.mu_bias.data = det_layer.bias.data 299 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta) 300 | 301 | elif str(layer).startswith('Batch'): 302 | #initialize parameters 303 | print(str(layer)) 304 | layer.weight.data = det_layer.weight.data 305 | if layer.bias is not None: 306 | layer.bias.data = det_layer.bias.data 307 | layer.running_mean.data = det_layer.running_mean.data 308 | layer.running_var.data = det_layer.running_var.data 309 | layer.num_batches_tracked.data = det_layer.num_batches_tracked.data 310 | 311 | 312 | def main(): 313 | global args, best_prec1, best_avu 314 | args = parser.parse_args() 315 | 316 | # Check the save_dir exists or not 317 | if not os.path.exists(args.save_dir): 318 | os.makedirs(args.save_dir) 319 | 320 | model =
torch.nn.DataParallel(resnet.__dict__[args.arch]()) 321 | if torch.cuda.is_available(): 322 | model.cuda() 323 | else: 324 | model.cpu() 325 | 326 | # optionally resume from a checkpoint 327 | if args.resume: 328 | if os.path.isfile(args.resume): 329 | print("=> loading checkpoint '{}'".format(args.resume)) 330 | checkpoint = torch.load(args.resume) 331 | args.start_epoch = checkpoint['epoch'] 332 | best_prec1 = checkpoint['best_prec1'] 333 | model.load_state_dict(checkpoint['state_dict']) 334 | print("=> loaded checkpoint '{}' (epoch {})".format( 335 | args.resume, checkpoint['epoch'])) 336 | else: 337 | print("=> no checkpoint found at '{}'".format(args.resume)) 338 | 339 | cudnn.benchmark = True 340 | 341 | tb_writer = None 342 | if args.tensorboard: 343 | logger_dir = os.path.join(args.log_dir, 'tb_logger') 344 | if not os.path.exists(logger_dir): 345 | os.makedirs(logger_dir) 346 | tb_writer = SummaryWriter(logger_dir) 347 | 348 | preds_dir = os.path.join(args.log_dir, 'preds') 349 | if not os.path.exists(preds_dir): 350 | os.makedirs(preds_dir) 351 | results_dir = os.path.join(args.log_dir, 'results') 352 | if not os.path.exists(results_dir): 353 | os.makedirs(results_dir) 354 | 355 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 356 | std=[0.229, 0.224, 0.225]) 357 | 358 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10( 359 | root='./data', 360 | train=True, 361 | transform=transforms.Compose([ 362 | transforms.RandomHorizontalFlip(), 363 | transforms.RandomCrop(32, 4), 364 | transforms.ToTensor() 365 | ]), 366 | download=True), 367 | batch_size=args.batch_size, 368 | shuffle=True, 369 | num_workers=args.workers, 370 | pin_memory=True) 371 | 372 | val_loader = torch.utils.data.DataLoader(datasets.CIFAR10( 373 | root='./data', 374 | train=False, 375 | transform=transforms.Compose([transforms.ToTensor()])), 376 | batch_size=args.val_batch_size, 377 | shuffle=False, 378 | num_workers=args.workers, 379 | pin_memory=True) 380 | 381 |
if not os.path.exists(args.save_dir): 382 | os.makedirs(args.save_dir) 383 | 384 | if torch.cuda.is_available(): 385 | criterion = nn.CrossEntropyLoss().cuda() 386 | avu_criterion = AUAvULoss().cuda() 387 | else: 388 | criterion = nn.CrossEntropyLoss().cpu() 389 | avu_criterion = AUAvULoss().cpu() 390 | 391 | if args.half: 392 | model.half() 393 | criterion.half() 394 | 395 | if args.arch in ['resnet110']: 396 | # the optimizer is created per epoch below, so scale the base lr here 397 | args.lr = args.lr * 0.1 398 | 399 | if args.evaluate: 400 | validate(val_loader, model, criterion, avu_criterion, args.start_epoch) 401 | return 402 | 403 | if args.mode == 'train': 404 | 405 | if (args.moped): 406 | print("MOPED enabled") 407 | det_model = torch.nn.DataParallel(det_resnet.__dict__[args.arch]()) 408 | if torch.cuda.is_available(): 409 | det_model.cuda() 410 | else: 411 | det_model.cpu() 412 | checkpoint_file = 'checkpoint/deterministic/{}_cifar.pth'.format( 413 | args.arch) 414 | checkpoint = torch.load(checkpoint_file) 415 | det_model.load_state_dict(checkpoint['state_dict']) 416 | 417 | for (idx_1, layer_1), (det_idx_1, det_layer_1) in zip( 418 | enumerate(model.children()), 419 | enumerate(det_model.children())): 420 | MOPED_layer(layer_1, det_layer_1, args.delta) 421 | for (idx_2, layer_2), (det_idx_2, det_layer_2) in zip( 422 | enumerate(layer_1.children()), 423 | enumerate(det_layer_1.children())): 424 | MOPED_layer(layer_2, det_layer_2, args.delta) 425 | for (idx_3, layer_3), (det_idx_3, det_layer_3) in zip( 426 | enumerate(layer_2.children()), 427 | enumerate(det_layer_2.children())): 428 | MOPED_layer(layer_3, det_layer_3, args.delta) 429 | for (idx_4, layer_4), (det_idx_4, det_layer_4) in zip( 430 | enumerate(layer_3.children()), 431 | enumerate(det_layer_3.children())): 432 | MOPED_layer(layer_4, det_layer_4, args.delta) 433 | 434 | model.state_dict() 435 | 436 | for epoch in range(args.start_epoch, args.epochs): 437 | 438 | lr = args.lr 439 | if (epoch >= 80 and epoch < 120): 440 | lr = 0.1 * args.lr
441 | elif (epoch >= 120 and epoch < 160): 442 | lr = 0.01 * args.lr 443 | elif (epoch >= 160 and epoch < 180): 444 | lr = 0.001 * args.lr 445 | elif (epoch >= 180): 446 | lr = 0.0005 * args.lr 447 | 448 | optimizer = torch.optim.Adam(model.parameters(), lr) 449 | # train for one epoch 450 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 451 | train(train_loader, model, criterion, avu_criterion, optimizer, 452 | epoch, tb_writer) 453 | #lr_scheduler.step() 454 | 455 | prec1 = validate(val_loader, model, criterion, avu_criterion, 456 | epoch, tb_writer) 457 | 458 | is_best = prec1 > best_prec1 459 | best_prec1 = max(prec1, best_prec1) 460 | 461 | #if epoch > 0 and epoch % args.save_every == 0: 462 | if epoch > 0: 463 | if is_best: 464 | save_checkpoint( 465 | { 466 | 'epoch': epoch + 1, 467 | 'state_dict': model.state_dict(), 468 | 'best_prec1': best_prec1, 469 | }, 470 | is_best, 471 | filename=os.path.join( 472 | args.save_dir, 473 | 'bayesian_{}_cifar.pth'.format(args.arch))) 474 | 475 | elif args.mode == 'test': 476 | checkpoint_file = args.save_dir + '/bayesian_{}_cifar.pth'.format( 477 | args.arch) 478 | if torch.cuda.is_available(): 479 | checkpoint = torch.load(checkpoint_file) 480 | else: 481 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) 482 | print('load checkpoint. 
epoch :', checkpoint['epoch']) 483 | model.load_state_dict(checkpoint['state_dict']) 484 | 485 | #header = ['corrupt', 'test_acc', 'brier', 'ece'] 486 | header = ['corrupt', 'test_acc'] 487 | 488 | #Evaluate on OOD dataset (SVHN) 489 | ood_images_file = 'data/SVHN/svhn-test.npy' 490 | ood_images = np.load(ood_images_file) 491 | ood_images = ood_images[:10000, :, :, :] 492 | ood_labels = np.arange(len(ood_images)) + 10 #create dummy labels 493 | ood_loader = get_ood_dataloader(ood_images, ood_labels) 494 | ood_acc = evaluate(model, ood_loader, corrupt='ood', level=None) 495 | print('******OOD data***********\n') 496 | print('ood_acc: ', ood_acc) 497 | ''' 498 | o_file = args.log_dir + '/results/ood_results.csv' 499 | with open(o_file, 'wt') as o_file: 500 | writer = csv.writer(o_file, delimiter=',', lineterminator='\n') 501 | writer.writerow([j for j in header]) 502 | writer.writerow(['ood', ood_acc, ood_brier, ood_ece]) 503 | o_file.close() 504 | ''' 505 | 506 | #Evaluate on test dataset 507 | test_acc = evaluate(model, val_loader, corrupt=None, level=None) 508 | print('******Test data***********\n') 509 | print('test_acc: ', test_acc) 510 | ''' 511 | t_file = args.log_dir + '/results/test_results.csv' 512 | with open(t_file, 'wt') as t_file: 513 | writer = csv.writer(t_file, delimiter=',', lineterminator='\n') 514 | writer.writerow([j for j in header]) 515 | writer.writerow(['test', test_acc, brier, ece]) 516 | t_file.close() 517 | ''' 518 | 519 | for level in range(1, 6): 520 | print('******Corruption Level: ', level, ' ***********\n') 521 | results_file = args.log_dir + '/results/level' + str( 522 | level) + '.csv' 523 | with open(results_file, 'wt') as results_file: 524 | writer = csv.writer(results_file, 525 | delimiter=',', 526 | lineterminator='\n') 527 | writer.writerow([j for j in header]) 528 | for c in corruptions: 529 | images_file = 'data/CIFAR-10-C/' + c + '.npy' 530 | labels_file = 'data/CIFAR-10-C/labels.npy' 531 | corrupt_images = 
np.load(images_file) 532 | corrupt_labels = np.load(labels_file) 533 | 534 | val_loader = get_corrupt_dataloader(corrupt_images, 535 | corrupt_labels, 536 | level=level) 537 | test_acc = evaluate(model, 538 | val_loader, 539 | corrupt=c, 540 | level=level) 541 | print('############ Corruption type: ', c, 542 | ' ################') 543 | print('test_acc: ', test_acc, '\n') 544 | writer.writerow([c, test_acc]) 545 | results_file.close() 546 | 547 | 548 | def train(train_loader, 549 | model, 550 | criterion, 551 | avu_criterion, 552 | optimizer, 553 | epoch, 554 | tb_writer=None): 555 | batch_time = AverageMeter() 556 | data_time = AverageMeter() 557 | losses = AverageMeter() 558 | top1 = AverageMeter() 559 | avg_unc = AverageMeter() 560 | 561 | # switch to train mode 562 | model.train() 563 | 564 | end = time.time() 565 | for i, (input, target) in enumerate(train_loader): 566 | 567 | # measure data loading time 568 | data_time.update(time.time() - end) 569 | 570 | if torch.cuda.is_available(): 571 | target = target.cuda() 572 | input_var = input.cuda() 573 | else: 574 | target = target.cpu() 575 | input_var = input.cpu() 576 | target_var = target 577 | if args.half: 578 | input_var = input_var.half() 579 | 580 | optimizer.zero_grad() 581 | 582 | output, kl = model(input_var) 583 | probs_ = torch.nn.functional.softmax(output, dim=1) 584 | probs = probs_.data.cpu().numpy() 585 | 586 | pred_entropy = util.entropy(probs) 587 | unc = np.mean(pred_entropy, axis=0) 588 | preds = np.argmax(probs, axis=-1) 589 | 590 | cross_entropy_loss = criterion(output, target_var) 591 | scaled_kl = kl.data / len_trainset 592 | elbo_loss = cross_entropy_loss + scaled_kl 593 | avu_loss, auc_avu = avu_criterion(output, target_var, type=0) 594 | avu_loss = beta * avu_loss 595 | loss = cross_entropy_loss + scaled_kl + avu_loss 596 | 597 | # compute gradient and do SGD step 598 | loss.backward() 599 | optimizer.step() 600 | 601 | output = output.float() 602 | loss = loss.float() 603 | # measure 
accuracy and record loss 604 | prec1 = accuracy(output.data, target)[0] 605 | losses.update(loss.item(), input.size(0)) 606 | top1.update(prec1.item(), input.size(0)) 607 | avg_unc.update(unc, input.size(0)) 608 | 609 | # measure elapsed time 610 | batch_time.update(time.time() - end) 611 | end = time.time() 612 | 613 | if i % args.print_freq == 0: 614 | #print('opt_th: ', opt_th) 615 | print('Epoch: [{0}][{1}/{2}]\t' 616 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 617 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 618 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 619 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 620 | 'Avg_Unc {avg_unc.val:.3f} ({avg_unc.avg:.3f})'.format( 621 | epoch, 622 | i, 623 | len(train_loader), 624 | batch_time=batch_time, 625 | data_time=data_time, 626 | loss=losses, 627 | top1=top1, 628 | avg_unc=avg_unc)) 629 | 630 | if tb_writer is not None: 631 | tb_writer.add_scalar('train/cross_entropy_loss', 632 | cross_entropy_loss.item(), epoch) 633 | tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch) 634 | tb_writer.add_scalar('train/elbo_loss', elbo_loss.item(), epoch) 635 | tb_writer.add_scalar('train/avu_loss', avu_loss, epoch) 636 | tb_writer.add_scalar('train/loss', loss.item(), epoch) 637 | tb_writer.add_scalar('train/AUC-AvU', auc_avu, epoch) 638 | tb_writer.add_scalar('train/accuracy', prec1.item(), epoch) 639 | tb_writer.flush() 640 | 641 | 642 | def validate(val_loader, 643 | model, 644 | criterion, 645 | avu_criterion, 646 | epoch, 647 | tb_writer=None): 648 | batch_time = AverageMeter() 649 | losses = AverageMeter() 650 | top1 = AverageMeter() 651 | avg_unc = AverageMeter() 652 | global opt_th 653 | 654 | # switch to evaluate mode 655 | model.eval() 656 | 657 | end = time.time() 658 | preds_list = [] 659 | labels_list = [] 660 | unc_list = [] 661 | th_list = np.linspace(0, 1, 21) 662 | with torch.no_grad(): 663 | for i, (input, target) in enumerate(val_loader): 664 | if torch.cuda.is_available(): 665 | target 
= target.cuda() 666 | input_var = input.cuda() 667 | target_var = target.cuda() 668 | else: 669 | target = target.cpu() 670 | input_var = input.cpu() 671 | target_var = target.cpu() 672 | 673 | if args.half: 674 | input_var = input_var.half() 675 | 676 | output, kl = model(input_var) 677 | probs_ = torch.nn.functional.softmax(output, dim=1) 678 | probs = probs_.data.cpu().numpy() 679 | 680 | pred_entropy = util.entropy(probs) 681 | unc = np.mean(pred_entropy, axis=0) 682 | preds = np.argmax(probs, axis=-1) 683 | preds_list.append(preds) 684 | labels_list.append(target.cpu().data.numpy()) 685 | unc_list.append(pred_entropy) 686 | 687 | cross_entropy_loss = criterion(output, target_var) 688 | scaled_kl = kl.data / len_trainset 689 | elbo_loss = cross_entropy_loss + scaled_kl 690 | avu_loss, auc_avu = avu_criterion(output, target_var, type=0) 691 | avu_loss = beta * avu_loss 692 | loss = cross_entropy_loss + scaled_kl + avu_loss 693 | 694 | output = output.float() 695 | loss = loss.float() 696 | 697 | # measure accuracy and record loss 698 | prec1 = accuracy(output.data, target)[0] 699 | losses.update(loss.item(), input.size(0)) 700 | top1.update(prec1.item(), input.size(0)) 701 | avg_unc.update(unc, input.size(0)) 702 | 703 | # measure elapsed time 704 | batch_time.update(time.time() - end) 705 | end = time.time() 706 | 707 | if i % args.print_freq == 0: 708 | print('Test: [{0}/{1}]\t' 709 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 710 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 711 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 712 | 'Avg_Unc {avg_unc.val:.3f} ({avg_unc.avg:.3f})'.format( 713 | i, 714 | len(val_loader), 715 | batch_time=batch_time, 716 | loss=losses, 717 | top1=top1, 718 | avg_unc=avg_unc)) 719 | 720 | if tb_writer is not None: 721 | tb_writer.add_scalar('val/cross_entropy_loss', 722 | cross_entropy_loss.item(), epoch) 723 | tb_writer.add_scalar('val/kl_div', scaled_kl.item(), epoch) 724 | tb_writer.add_scalar('val/elbo_loss', 
elbo_loss.item(), epoch) 725 | tb_writer.add_scalar('val/avu_loss', avu_loss, epoch) 726 | tb_writer.add_scalar('val/loss', loss.item(), epoch) 727 | tb_writer.add_scalar('val/AUC-AvU', auc_avu, epoch) 728 | tb_writer.add_scalar('val/accuracy', prec1.item(), epoch) 729 | tb_writer.flush() 730 | 731 | preds = np.hstack(np.asarray(preds_list)) 732 | labels = np.hstack(np.asarray(labels_list)) 733 | unc_ = np.hstack(np.asarray(unc_list)) 734 | avu_th, unc_th = util.eval_avu(preds, labels, unc_) 735 | print('max AvU: ', np.amax(avu_th)) 736 | unc_correct = np.take(unc_, np.where(preds == labels)) 737 | unc_incorrect = np.take(unc_, np.where(preds != labels)) 738 | print('avg unc correct preds: ', 739 | np.mean(unc_correct, axis=1)) 740 | print('avg unc incorrect preds: ', 741 | np.mean(unc_incorrect, axis=1)) 742 | ''' 743 | print('unc @max AvU: ', unc_th[np.argmax(avu_th)]) 744 | print('avg unc: ', np.mean(unc_, axis=0)) 745 | print('avg unc: ', np.mean(unc_th, axis=0)) 746 | print('min unc: ', np.amin(unc_)) 747 | print('max unc: ', np.amax(unc_)) 748 | ''' 749 | if epoch <= 5: 750 | opt_th = (np.mean(unc_correct, axis=1) + 751 | np.mean(unc_incorrect, axis=1)) / 2 752 | 753 | print('opt_th: ', opt_th) 754 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 755 | return top1.avg 756 | 757 | 758 | def evaluate(model, val_loader, corrupt=None, level=None): 759 | pred_probs_mc = [] 760 | test_loss = 0 761 | correct = 0 762 | with torch.no_grad(): 763 | pred_probs_mc = [] 764 | for batch_idx, (data, target) in enumerate(val_loader): 765 | if torch.cuda.is_available(): 766 | data, target = data.cuda(), target.cuda() 767 | else: 768 | data, target = data.cpu(), target.cpu() 769 | for mc_run in range(args.num_monte_carlo): 770 | model.eval() 771 | output, _ = model(data) 772 | pred_probs = torch.nn.functional.softmax(output, dim=1) 773 | pred_probs_mc.append(pred_probs.cpu().data.numpy()) 774 | 775 | if
corrupt == 'ood': 776 | np.save(args.log_dir + '/preds/svi_avu_ood_probs.npy', 777 | pred_probs_mc) 778 | print('saved predictive probabilities') 779 | return None 780 | 781 | target_labels = target.cpu().data.numpy() 782 | pred_mean = np.mean(pred_probs_mc, axis=0) 783 | #print(pred_mean) 784 | Y_pred = np.argmax(pred_mean, axis=1) 785 | test_acc = (Y_pred == target_labels).mean() 786 | #brier = np.mean(calib.brier_scores(target_labels, probs=pred_mean)) 787 | #ece = calib.expected_calibration_error_multiclass(pred_mean, target_labels) 788 | print('Test accuracy:', test_acc * 100) 789 | #print('Brier score: ', brier) 790 | #print('ECE: ', ece) 791 | if corrupt is not None: 792 | np.save( 793 | args.log_dir + 794 | '/preds/svi_avu_corrupt-static-{}-{}_probs.npy'.format( 795 | corrupt, level), pred_probs_mc) 796 | np.save( 797 | args.log_dir + 798 | '/preds/svi_avu_corrupt-static-{}-{}_labels.npy'.format( 799 | corrupt, level), target_labels) 800 | print('saved predictive probabilities') 801 | else: # corrupt is None for the clean test set 802 | np.save(args.log_dir + '/preds/svi_avu_test_probs.npy', 803 | pred_probs_mc) 804 | np.save(args.log_dir + '/preds/svi_avu_test_labels.npy', 805 | target_labels) 806 | print('saved predictive probabilities') 807 | return test_acc 808 | 809 | 810 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 811 | """ 812 | Save the training model 813 | """ 814 | torch.save(state, filename) 815 | 816 | 817 | class AverageMeter(object): 818 | """Computes and stores the average and current value""" 819 | def __init__(self): 820 | self.reset() 821 | 822 | def reset(self): 823 | self.val = 0 824 | self.avg = 0 825 | self.sum = 0 826 | self.count = 0 827 | 828 | def update(self, val, n=1): 829 | self.val = val 830 | self.sum += val * n 831 | self.count += n 832 | self.avg = self.sum / self.count 833 | 834 | 835 | def accuracy(output, target, topk=(1, )): 836 | """Computes the precision@k for the specified values of k""" 837 | maxk = max(topk) 838
| batch_size = target.size(0) 839 | 840 | _, pred = output.topk(maxk, 1, True, True) 841 | pred = pred.t() 842 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 843 | 844 | res = [] 845 | for k in topk: 846 | correct_k = correct[:k].view(-1).float().sum(0) 847 | res.append(correct_k.mul_(100.0 / batch_size)) 848 | return res 849 | 850 | 851 | if __name__ == '__main__': 852 | main() 853 | -------------------------------------------------------------------------------- /src/main_bayesian_imagenet_avu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | from torch.nn import functional as F 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | import models.bayesian.resnet_large as resnet 23 | import models.deterministic.resnet_large as det_resnet 24 | from src import util 25 | #from utils import calib 26 | import csv 27 | import numpy as np 28 | from src.util import get_rho 29 | from src.avuc_loss import AvULoss 30 | from torch.utils.tensorboard import SummaryWriter 31 | 32 | torchvision.set_image_backend('accimage') 33 | 34 | model_names = sorted( 35 | name for name in resnet.__dict__ 36 | if name.islower() and not name.startswith("__") 37 | and name.startswith("resnet") and callable(resnet.__dict__[name])) 38 | 39 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 40 | parser.add_argument('data', 41 | metavar='DIR', 42 | default='data/imagenet', 43 | help='path to dataset') 44 | 
parser.add_argument('--corrupt_data', 45 | type=str, 46 | default='data/ImageNet-C', 47 | metavar='N', 48 | help='path to corrupt dataset') 49 | parser.add_argument('-a', 50 | '--arch', 51 | metavar='ARCH', 52 | default='resnet50', 53 | choices=model_names, 54 | help='model architecture: ' + ' | '.join(model_names) + 55 | ' (default: resnet50)') 56 | parser.add_argument('-j', 57 | '--workers', 58 | default=8, 59 | type=int, 60 | metavar='N', 61 | help='number of data loading workers (default: 8)') 62 | parser.add_argument('--epochs', 63 | default=90, 64 | type=int, 65 | metavar='N', 66 | help='number of total epochs to run') 67 | parser.add_argument('--start-epoch', 68 | default=0, 69 | type=int, 70 | metavar='N', 71 | help='manual epoch number (useful on restarts)') 72 | parser.add_argument('--val_batch_size', default=1000, type=int) 73 | parser.add_argument('-b', 74 | '--batch-size', 75 | default=32, 76 | type=int, 77 | metavar='N', 78 | help='mini-batch size (default: 32), this is the total ' 79 | 'batch size of all GPUs on the current node when ' 80 | 'using Data Parallel or Distributed Data Parallel') 81 | parser.add_argument('--lr', 82 | '--learning-rate', 83 | default=0.001, 84 | type=float, 85 | metavar='LR', 86 | help='initial learning rate', 87 | dest='lr') 88 | parser.add_argument('--momentum', 89 | default=0.9, 90 | type=float, 91 | metavar='M', 92 | help='momentum') 93 | parser.add_argument('--wd', 94 | '--weight-decay', 95 | default=1e-4, 96 | type=float, 97 | metavar='W', 98 | help='weight decay (default: 1e-4)', 99 | dest='weight_decay') 100 | parser.add_argument('-p', 101 | '--print-freq', 102 | default=10, 103 | type=int, 104 | metavar='N', 105 | help='print frequency (default: 10)') 106 | parser.add_argument('--resume', 107 | default='', 108 | type=str, 109 | metavar='PATH', 110 | help='path to latest checkpoint (default: none)') 111 | parser.add_argument('-e', 112 | '--evaluate', 113 | dest='evaluate', 114 | action='store_true', 115 |
help='evaluate model on validation set') 116 | parser.add_argument('--pretrained', 117 | dest='pretrained', 118 | action='store_true', 119 | default=True, 120 | help='use pre-trained model') 121 | parser.add_argument('--world-size', 122 | default=-1, 123 | type=int, 124 | help='number of nodes for distributed training') 125 | parser.add_argument('--rank', 126 | default=-1, 127 | type=int, 128 | help='node rank for distributed training') 129 | parser.add_argument('--dist-url', 130 | default='tcp://224.66.41.62:23456', 131 | type=str, 132 | help='url used to set up distributed training') 133 | parser.add_argument('--dist-backend', 134 | default='nccl', 135 | type=str, 136 | help='distributed backend') 137 | parser.add_argument('--seed', 138 | default=None, 139 | type=int, 140 | help='seed for initializing training. ') 141 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 142 | parser.add_argument('--multiprocessing-distributed', 143 | action='store_true', 144 | help='Use multi-processing distributed training to launch ' 145 | 'N processes per node, which has N GPUs. 
This is the ' 146 | 'fastest way to use PyTorch for either single node or ' 147 | 'multi node data parallel training') 148 | parser.add_argument('--mode', type=str, required=True, help='train | test') 149 | parser.add_argument('--save-dir', 150 | dest='save_dir', 151 | help='The directory used to save the trained models', 152 | default='./checkpoint/imagenet/bayesian_svi_avu', 153 | type=str) 154 | parser.add_argument( 155 | '--tensorboard', 156 | type=bool, 157 | default=True, 158 | metavar='N', 159 | help='use tensorboard for logging and visualization of training progress') 160 | parser.add_argument( 161 | '--log_dir', 162 | type=str, 163 | default='./logs/imagenet/bayesian_svi_avu', 164 | metavar='N', 165 | help='directory to save logs, predictions and results') 166 | parser.add_argument('--num_monte_carlo', 167 | type=int, 168 | default=20, 169 | metavar='N', 170 | help='number of Monte Carlo samples') 171 | parser.add_argument( 172 | '--moped', 173 | type=bool, 174 | default=True, 175 | help='set prior and initialize approx posterior with Empirical Bayes') 176 | parser.add_argument('--delta', 177 | type=float, 178 | default=0.5, 179 | help='delta value for variance scaling in MOPED') 180 | best_acc1 = 0 181 | opt_th = 1.0 182 | len_trainset = 1281167 183 | beta = 3.0 184 | 185 | 186 | def MOPED_layer(layer, det_layer, delta): 187 | """ 188 | Set the priors and initialize surrogate posteriors of Bayesian NN with Empirical Bayes 189 | MOPED (Model Priors with Empirical Bayes using Deterministic DNN) 190 | Ref: https://arxiv.org/abs/1906.05323 191 | 'Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes'. AAAI 2020.
192 | """ 193 | 194 | if (str(layer) == 'Conv2dVariational()'): 195 | #set the priors 196 | print(str(layer)) 197 | layer.prior_weight_mu = det_layer.weight.data 198 | if layer.prior_bias_mu is not None: 199 | layer.prior_bias_mu = det_layer.bias.data 200 | 201 | #initialize surrogate posteriors 202 | layer.mu_kernel.data = det_layer.weight.data 203 | if layer.mu_bias is not None: 204 | layer.mu_bias.data = det_layer.bias.data 205 | 206 | elif (isinstance(layer, nn.Conv2d)): 207 | print(str(layer)) 208 | layer.weight.data = det_layer.weight.data 209 | if layer.bias is not None: 210 | layer.bias.data = det_layer.bias.data 211 | 212 | elif (str(layer) == 'LinearVariational()'): 213 | print(str(layer)) 214 | layer.prior_weight_mu = det_layer.weight.data 215 | if layer.prior_bias_mu is not None: 216 | layer.prior_bias_mu = det_layer.bias.data 217 | 218 | #initialize the surrogate posteriors 219 | 220 | layer.mu_weight.data = det_layer.weight.data 221 | layer.rho_weight.data = get_rho(det_layer.weight.data, delta) 222 | if layer.mu_bias is not None: 223 | layer.mu_bias.data = det_layer.bias.data 224 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta) 225 | 226 | elif str(layer).startswith('Batch'): 227 | #initialize parameters 228 | print(str(layer)) 229 | layer.weight.data = det_layer.weight.data 230 | if layer.bias is not None: 231 | layer.bias.data = det_layer.bias.data 232 | layer.running_mean.data = det_layer.running_mean.data 233 | layer.running_var.data = det_layer.running_var.data 234 | layer.num_batches_tracked.data = det_layer.num_batches_tracked.data 235 | 236 | corruptions = [ 237 | 'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 238 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 239 | 'pixelate', 'saturate', 'shot_noise', 'spatter', 'speckle_noise', 240 | 'zoom_blur' 241 | ] 242 | 243 | 244 | def get_corrupt_dataloader(args, corrupt_type, level): 245 | 246 | corrupt_dir =
os.path.join(args.corrupt_data, str(corrupt_type), 247 | str(level)) 248 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 249 | std=[0.229, 0.224, 0.225]) 250 | corrupt_dataset = datasets.ImageFolder( 251 | corrupt_dir, 252 | transforms.Compose([ 253 | transforms.Resize(256), 254 | transforms.CenterCrop(224), 255 | transforms.ToTensor(), 256 | normalize, 257 | ])) 258 | 259 | corrupt_val_loader = torch.utils.data.DataLoader( 260 | corrupt_dataset, 261 | batch_size=args.val_batch_size, 262 | shuffle=False, 263 | num_workers=args.workers, 264 | pin_memory=True) 265 | 266 | return corrupt_val_loader 267 | 268 | 269 | def main(): 270 | args = parser.parse_args() 271 | 272 | if args.seed is not None: 273 | random.seed(args.seed) 274 | torch.manual_seed(args.seed) 275 | cudnn.deterministic = True 276 | warnings.warn('You have chosen to seed training. ' 277 | 'This will turn on the CUDNN deterministic setting, ' 278 | 'which can slow down your training considerably! ' 279 | 'You may see unexpected behavior when restarting ' 280 | 'from checkpoints.') 281 | 282 | if args.gpu is not None: 283 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 284 | 'disable data parallelism.') 285 | 286 | if args.dist_url == "env://" and args.world_size == -1: 287 | args.world_size = int(os.environ["WORLD_SIZE"]) 288 | 289 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 290 | 291 | if torch.cuda.is_available(): 292 | ngpus_per_node = torch.cuda.device_count() 293 | if args.multiprocessing_distributed: 294 | # Since we have ngpus_per_node processes per node, the total world_size 295 | # needs to be adjusted accordingly 296 | args.world_size = ngpus_per_node * args.world_size 297 | # Use torch.multiprocessing.spawn to launch distributed processes: the 298 | # main_worker process function 299 | mp.spawn(main_worker, 300 | nprocs=ngpus_per_node, 301 | args=(ngpus_per_node, args)) 302 | else: 303 | # Simply call main_worker function 304 | main_worker(args.gpu, ngpus_per_node, args) 305 | 306 | 307 | def main_worker(gpu, ngpus_per_node, args): 308 | global best_acc1 309 | args.gpu = gpu 310 | 311 | if args.gpu is not None: 312 | print("Use GPU: {} for training".format(args.gpu)) 313 | 314 | if args.distributed: 315 | if args.dist_url == "env://" and args.rank == -1: 316 | args.rank = int(os.environ["RANK"]) 317 | if args.multiprocessing_distributed: 318 | # For multiprocessing distributed training, rank needs to be the 319 | # global rank among all the processes 320 | args.rank = args.rank * ngpus_per_node + gpu 321 | dist.init_process_group(backend=args.dist_backend, 322 | init_method=args.dist_url, 323 | world_size=args.world_size, 324 | rank=args.rank) 325 | 326 | if not os.path.exists(args.save_dir): 327 | os.makedirs(args.save_dir) 328 | 329 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]()) 330 | if torch.cuda.is_available(): 331 | model.cuda() 332 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 333 | avu_criterion = AvULoss().cuda() 334 | else: 335 | model.cpu() 336 | criterion = nn.CrossEntropyLoss().cpu() 337 | avu_criterion = AvULoss().cpu() 338 | 339 
| optimizer = torch.optim.SGD(model.parameters(), 340 | args.lr, 341 | momentum=args.momentum, 342 | weight_decay=args.weight_decay) 343 | 344 | # optionally resume from a checkpoint 345 | if args.resume: 346 | if os.path.isfile(args.resume): 347 | print("=> loading checkpoint '{}'".format(args.resume)) 348 | if args.gpu is None: 349 | checkpoint = torch.load(args.resume) 350 | else: 351 | # Map model to be loaded to specified single gpu. 352 | loc = 'cuda:{}'.format(args.gpu) 353 | checkpoint = torch.load(args.resume, map_location=loc) 354 | args.start_epoch = checkpoint['epoch'] 355 | best_acc1 = checkpoint['best_acc1'] 356 | if args.gpu is not None: 357 | # best_acc1 may be from a checkpoint from a different GPU 358 | best_acc1 = best_acc1.to(args.gpu) 359 | model.load_state_dict(checkpoint['state_dict']) 360 | optimizer.load_state_dict(checkpoint['optimizer']) 361 | print("=> loaded checkpoint '{}' (epoch {})".format( 362 | args.resume, checkpoint['epoch'])) 363 | else: 364 | print("=> no checkpoint found at '{}'".format(args.resume)) 365 | 366 | cudnn.benchmark = True 367 | 368 | tb_writer = None 369 | if args.tensorboard: 370 | logger_dir = os.path.join(args.log_dir, 'tb_logger') 371 | if not os.path.exists(logger_dir): 372 | os.makedirs(logger_dir) 373 | tb_writer = SummaryWriter(logger_dir) 374 | 375 | preds_dir = os.path.join(args.log_dir, 'preds') 376 | if not os.path.exists(preds_dir): 377 | os.makedirs(preds_dir) 378 | results_dir = os.path.join(args.log_dir, 'results') 379 | if not os.path.exists(results_dir): 380 | os.makedirs(results_dir) 381 | 382 | # Data loading code 383 | traindir = os.path.join(args.data, 'train') 384 | valdir = os.path.join(args.data, 'val') 385 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 386 | std=[0.229, 0.224, 0.225]) 387 | 388 | train_dataset = datasets.ImageFolder( 389 | traindir, 390 | transforms.Compose([ 391 | transforms.RandomResizedCrop(224), 392 | transforms.RandomHorizontalFlip(), 393 | 
transforms.ToTensor(), 394 | normalize, 395 | ])) 396 | 397 | val_dataset = datasets.ImageFolder( 398 | valdir, 399 | transforms.Compose([ 400 | transforms.Resize(256), 401 | transforms.CenterCrop(224), 402 | transforms.ToTensor(), 403 | normalize, 404 | ])) 405 | 406 | print('len trainset: ', len(train_dataset)) 407 | print('len valset: ', len(val_dataset)) 408 | len_trainset = len(train_dataset) 409 | 410 | if args.distributed: 411 | train_sampler = torch.utils.data.distributed.DistributedSampler( 412 | train_dataset) 413 | else: 414 | train_sampler = None 415 | 416 | train_loader = torch.utils.data.DataLoader(train_dataset, 417 | batch_size=args.batch_size, 418 | shuffle=(train_sampler is None), sampler=train_sampler, 419 | num_workers=args.workers, 420 | pin_memory=True, 421 | drop_last=True) 422 | 423 | val_loader = torch.utils.data.DataLoader(val_dataset, 424 | batch_size=args.val_batch_size, 425 | shuffle=False, 426 | num_workers=args.workers, 427 | pin_memory=True) 428 | 429 | if args.evaluate: 430 | validate(val_loader, model, criterion, avu_criterion, args.start_epoch, args, tb_writer) 431 | return 432 | 433 | if args.mode == 'train': 434 | 435 | if (args.moped): 436 | print("MOPED enabled") 437 | det_model = torch.nn.DataParallel( 438 | det_resnet.__dict__[args.arch](pretrained=True)) 439 | if torch.cuda.is_available(): 440 | det_model.cuda() 441 | else: 442 | det_model.cpu() 443 | 444 | for (idx_1, layer_1), (det_idx_1, det_layer_1) in zip( 445 | enumerate(model.children()), 446 | enumerate(det_model.children())): 447 | MOPED_layer(layer_1, det_layer_1, args.delta) 448 | for (idx_2, layer_2), (det_idx_2, det_layer_2) in zip( 449 | enumerate(layer_1.children()), 450 | enumerate(det_layer_1.children())): 451 | MOPED_layer(layer_2, det_layer_2, args.delta) 452 | for (idx_3, layer_3), (det_idx_3, det_layer_3) in zip( 453 | enumerate(layer_2.children()), 454 | enumerate(det_layer_2.children())): 455 | MOPED_layer(layer_3, det_layer_3, args.delta) 456 | for (idx_4, layer_4), (det_idx_4, det_layer_4) in zip( 457 | enumerate(layer_3.children()),
458 | enumerate(det_layer_3.children())): 459 | MOPED_layer(layer_4, det_layer_4, args.delta) 460 | for (idx_5, 461 | layer_5), (det_idx_5, det_layer_5) in zip( 462 | enumerate(layer_4.children()), 463 | enumerate(det_layer_4.children())): 464 | MOPED_layer(layer_5, det_layer_5, args.delta) 465 | for (idx_6, 466 | layer_6), (det_idx_6, det_layer_6) in zip( 467 | enumerate(layer_5.children()), 468 | enumerate(det_layer_5.children())): 469 | MOPED_layer(layer_6, det_layer_6, 470 | args.delta) 471 | 472 | model.state_dict() 473 | del det_model 474 | 475 | for epoch in range(args.start_epoch, args.epochs): 476 | if args.distributed: 477 | train_sampler.set_epoch(epoch) 478 | adjust_learning_rate(optimizer, epoch, args) 479 | 480 | # train for one epoch 481 | train(train_loader, model, criterion, avu_criterion, optimizer, 482 | epoch, args, tb_writer) 483 | 484 | # evaluate on validation set 485 | acc1 = validate(val_loader, model, criterion, avu_criterion, epoch, 486 | args, tb_writer) 487 | 488 | # remember best acc@1 and save checkpoint 489 | is_best = acc1 > best_acc1 490 | best_acc1 = max(acc1, best_acc1) 491 | 492 | if is_best: 493 | save_checkpoint( 494 | { 495 | 'epoch': epoch + 1, 496 | 'arch': args.arch, 497 | 'state_dict': model.state_dict(), 498 | 'best_acc1': best_acc1, 499 | 'optimizer': optimizer.state_dict(), 500 | }, 501 | is_best, 502 | filename=os.path.join( 503 | args.save_dir, 504 | 'bayesian_{}_imagenet.pth'.format(args.arch))) 505 | 506 | elif args.mode == 'test': 507 | 508 | checkpoint_file = args.save_dir + '/bayesian_{}_imagenet.pth'.format( 509 | args.arch) 510 | if torch.cuda.is_available(): 511 | checkpoint = torch.load(checkpoint_file) 512 | else: 513 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) 514 | print('load checkpoint.') 515 | model.load_state_dict(checkpoint['state_dict']) 516 | 517 | #header = ['corrupt', 'test_acc', 'brier', 'ece'] 518 | header = ['corrupt', 'test_acc'] 519 | 520 | #Evaluate on test 
dataset 521 | #test_acc, brier, ece = evaluate(model, val_loader, args, corrupt=None, level=None) 522 | test_acc = evaluate(model, val_loader, args, corrupt=None, level=None) 523 | print('******Test data***********\n') 524 | #print('test_acc: ', test_acc, ' | Brier: ', brier, ' | ECE: ', ece, '\n') 525 | print('test_acc: ', test_acc) 526 | ''' 527 | t_file = args.log_dir + '/results/test_results.csv' 528 | with open(t_file, 'wt') as t_file: 529 | writer = csv.writer(t_file, delimiter=',', lineterminator='\n') 530 | writer.writerow([j for j in header]) 531 | writer.writerow(['test', test_acc, brier, ece]) 532 | t_file.close() 533 | ''' 534 | 535 | for level in range(1, 6): 536 | print('******Corruption Level: ', level, ' ***********\n') 537 | results_file = args.log_dir + '/results/level' + str( 538 | level) + '.csv' 539 | with open(results_file, 'wt') as results_file: 540 | writer = csv.writer(results_file, 541 | delimiter=',', 542 | lineterminator='\n') 543 | writer.writerow([j for j in header]) 544 | for c in corruptions: 545 | val_loader = get_corrupt_dataloader(args, c, level) 546 | test_acc = evaluate(model, 547 | val_loader, 548 | args, 549 | corrupt=c, 550 | level=level) 551 | print('############ Corruption type: ', c, 552 | ' ################') 553 | print('test_acc: ', test_acc, '\n') 554 | writer.writerow([c, test_acc]) 555 | results_file.close() 556 | 557 | 558 | def train(train_loader, model, criterion, avu_criterion, optimizer, epoch, 559 | args, tb_writer): 560 | batch_time = AverageMeter('Time', ':6.3f') 561 | data_time = AverageMeter('Data', ':6.3f') 562 | losses = AverageMeter('Loss', ':.4e') 563 | top1 = AverageMeter('Acc@1', ':6.2f') 564 | top5 = AverageMeter('Acc@5', ':6.2f') 565 | global opt_th 566 | progress = ProgressMeter(len(train_loader), 567 | [batch_time, data_time, losses, top1, top5], 568 | prefix="Epoch: [{}]".format(epoch)) 569 | 570 | # switch to train mode 571 | model.train() 572 | preds_list, labels_list, unc_list = [], [], [] 573 | end = time.time() 574 | for i, (images,
target) in enumerate(train_loader): 575 | # measure data loading time 576 | data_time.update(time.time() - end) 577 | 578 | if torch.cuda.is_available(): 579 | images = images.cuda(non_blocking=True) 580 | target = target.cuda(non_blocking=True) 581 | else: 582 | images = images.cpu()  # Tensor.cpu() takes no non_blocking argument 583 | target = target.cpu() 584 | 585 | # compute output 586 | output, kl = model(images) 587 | probs_ = torch.nn.functional.softmax(output, dim=1) 588 | probs = probs_.data.cpu().numpy() 589 | 590 | pred_entropy = util.entropy(probs) 591 | preds = np.argmax(probs, axis=-1) 592 | AvU = util.accuracy_vs_uncertainty(np.array(preds), 593 | np.array(target.cpu().data.numpy()), 594 | np.array(pred_entropy), opt_th) 595 | 596 | preds_list.append(preds) 597 | labels_list.append(target.cpu().data.numpy()) 598 | unc_list.append(pred_entropy) 599 | 600 | cross_entropy_loss = criterion(output, target) 601 | scaled_kl = (kl.data[0] / len_trainset) 602 | elbo_loss = cross_entropy_loss + scaled_kl 603 | avu_loss = beta * avu_criterion(output, target, opt_th, type=0) 604 | loss = cross_entropy_loss + scaled_kl + avu_loss 605 | 606 | output = output.float() 607 | loss = loss.float() 608 | 609 | # measure accuracy and record loss 610 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 611 | losses.update(loss.item(), images.size(0)) 612 | top1.update(acc1[0], images.size(0)) 613 | top5.update(acc5[0], images.size(0)) 614 | 615 | # compute gradient and do SGD step 616 | optimizer.zero_grad() 617 | loss.mean().backward() 618 | optimizer.step() 619 | 620 | # measure elapsed time 621 | batch_time.update(time.time() - end) 622 | end = time.time() 623 | 624 | if i % args.print_freq == 0: 625 | progress.display(i) 626 | 627 | if tb_writer is not None: 628 | tb_writer.add_scalar('train/cross_entropy_loss', 629 | cross_entropy_loss.item(), epoch) 630 | tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch) 631 | tb_writer.add_scalar('train/elbo_loss', elbo_loss.item(), 
epoch) 632 | tb_writer.add_scalar('train/avu_loss', avu_loss.item(), epoch) 633 | tb_writer.add_scalar('train/loss', loss.item(), epoch) 634 | tb_writer.add_scalar('train/AvU', AvU, epoch) 635 | tb_writer.add_scalar('train/accuracy', acc1.item(), epoch) 636 | tb_writer.flush() 637 | 638 | preds = np.hstack(np.asarray(preds_list)) 639 | labels = np.hstack(np.asarray(labels_list)) 640 | unc_ = np.hstack(np.asarray(unc_list)) 641 | unc_correct = np.take(unc_, np.where(preds == labels)) 642 | unc_incorrect = np.take(unc_, np.where(preds != labels)) 643 | #print('avg unc correct preds: ', np.mean(np.take(unc_,np.where(preds == labels)), axis=1)) 644 | #print('avg unc incorrect preds: ', np.mean(np.take(unc_,np.where(preds != labels)), axis=1)) 645 | if epoch <= 1: 646 | opt_th = (np.mean(unc_correct, axis=1) + 647 | np.mean(unc_incorrect, axis=1)) / 2 648 | 649 | print('opt_th: ', opt_th) 650 | 651 | 652 | def validate(val_loader, model, criterion, avu_criterion, epoch, args, 653 | tb_writer): 654 | batch_time = AverageMeter('Time', ':6.3f') 655 | losses = AverageMeter('Loss', ':.4e') 656 | top1 = AverageMeter('Acc@1', ':6.2f') 657 | top5 = AverageMeter('Acc@5', ':6.2f') 658 | progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5], 659 | prefix='Test: ') 660 | 661 | # switch to evaluate mode 662 | model.eval() 663 | 664 | preds_list = [] 665 | labels_list = [] 666 | unc_list = [] 667 | with torch.no_grad(): 668 | end = time.time() 669 | for i, (images, target) in enumerate(val_loader): 670 | if torch.cuda.is_available(): 671 | images = images.cuda(non_blocking=True) 672 | target = target.cuda(non_blocking=True) 673 | else: 674 | images = images.cpu()  # Tensor.cpu() takes no non_blocking argument 675 | target = target.cpu() 676 | 677 | # compute output 678 | output, kl = model(images) 679 | cross_entropy_loss = criterion(output, target) 680 | scaled_kl = (kl.data[0] / len_trainset) 681 | elbo_loss = cross_entropy_loss + scaled_kl 682 | avu_loss = beta * 
avu_criterion(output, target, opt_th, type=0) 683 | loss = cross_entropy_loss + scaled_kl + avu_loss 684 | 685 | output = output.float() 686 | loss = loss.float() 687 | 688 | # measure accuracy and record loss 689 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 690 | losses.update(loss.item(), images.size(0)) 691 | top1.update(acc1[0], images.size(0)) 692 | top5.update(acc5[0], images.size(0)) 693 | 694 | # measure elapsed time 695 | batch_time.update(time.time() - end) 696 | end = time.time() 697 | 698 | if i % args.print_freq == 0: 699 | progress.display(i) 700 | 701 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, 702 | top5=top5)) 703 | 704 | return top1.avg 705 | 706 | 707 | def evaluate(model, val_loader, args, corrupt=None, level=None): 708 | pred_probs_mc = [] 709 | test_loss = 0 710 | correct = 0 711 | with torch.no_grad(): 712 | pred_probs_mc = [] 713 | output_list = [] 714 | label_list = [] 715 | for batch_idx, (data, target) in enumerate(val_loader): 716 | #print('Batch idx {}, data shape {}, target shape {}'.format(batch_idx, data.shape, target.shape)) 717 | if torch.cuda.is_available(): 718 | data, target = data.cuda(), target.cuda() 719 | else: 720 | data, target = data.cpu(), target.cpu() 721 | output_mc = [] 722 | output_mc_np = [] 723 | for mc_run in range(args.num_monte_carlo): 724 | model.eval() 725 | output, _ = model.forward(data) 726 | #output_mc_np.append(output.data.cpu().numpy()) 727 | pred_probs = torch.nn.functional.softmax(output, dim=1) 728 | output_mc_np.append(pred_probs.cpu().data.numpy()) 729 | 730 | output_mc = torch.from_numpy( 731 | np.mean(np.asarray(output_mc_np), axis=0)) 732 | output_list.append(output_mc) 733 | label_list.append(target) 734 | 735 | if torch.cuda.is_available(): 736 | labels = torch.cat(label_list).cuda() 737 | probs = torch.cat(output_list).cuda() 738 | else: 739 | labels = torch.cat(label_list).cpu() 740 | probs = torch.cat(output_list).cpu() 741 | 742 | target_labels = 
labels.data.cpu().numpy() 743 | pred_mean = probs.data.cpu().numpy() 744 | Y_pred = np.argmax(pred_mean, axis=1) 745 | test_acc = (Y_pred == target_labels).mean() 746 | #brier = np.mean(calib.brier_scores(target_labels, probs=pred_mean)) 747 | #ece = calib.expected_calibration_error_multiclass(pred_mean, target_labels) 748 | print('Test accuracy:', test_acc * 100) 749 | #print('Brier score: ', brier) 750 | #print('ECE: ', ece) 751 | if corrupt is not None: 752 | np.save( 753 | args.log_dir + 754 | '/preds/svi_avu_corrupt-static-{}-{}_probs.npy'.format( 755 | corrupt, level), pred_mean) 756 | np.save( 757 | args.log_dir + 758 | '/preds/svi_avu_corrupt-static-{}-{}_labels.npy'.format( 759 | corrupt, level), target_labels) 760 | print('saved predictions') 761 | else: 762 | np.save(args.log_dir + '/preds/svi_avu_test_probs.npy', pred_mean) 763 | np.save(args.log_dir + '/preds/svi_avu_test_labels.npy', 764 | target_labels) 765 | print('saved predictions') 766 | 767 | return test_acc 768 | 769 | 770 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 771 | torch.save(state, filename) 772 | if is_best: 773 | shutil.copyfile(filename, 'model_best.pth.tar') 774 | 775 | 776 | class AverageMeter(object): 777 | """Computes and stores the average and current value""" 778 | def __init__(self, name, fmt=':f'): 779 | self.name = name 780 | self.fmt = fmt 781 | self.reset() 782 | 783 | def reset(self): 784 | self.val = 0 785 | self.avg = 0 786 | self.sum = 0 787 | self.count = 0 788 | 789 | def update(self, val, n=1): 790 | self.val = val 791 | self.sum += val * n 792 | self.count += n 793 | self.avg = self.sum / self.count 794 | 795 | def __str__(self): 796 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 797 | return fmtstr.format(**self.__dict__) 798 | 799 | 800 | class ProgressMeter(object): 801 | def __init__(self, num_batches, meters, prefix=""): 802 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 803 | self.meters = meters 804 | 
self.prefix = prefix 805 | 806 | def display(self, batch): 807 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 808 | entries += [str(meter) for meter in self.meters] 809 | print('\t'.join(entries)) 810 | 811 | def _get_batch_fmtstr(self, num_batches): 812 | num_digits = len(str(num_batches // 1)) 813 | fmt = '{:' + str(num_digits) + 'd}' 814 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 815 | 816 | 817 | def adjust_learning_rate(optimizer, epoch, args): 818 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 819 | lr = args.lr * (0.1**(epoch // 30)) 820 | for param_group in optimizer.param_groups: 821 | param_group['lr'] = lr 822 | 823 | 824 | def accuracy(output, target, topk=(1, )): 825 | """Computes the accuracy over the k top predictions for the specified values of k""" 826 | with torch.no_grad(): 827 | maxk = max(topk) 828 | batch_size = target.size(0) 829 | 830 | _, pred = output.topk(maxk, 1, True, True) 831 | pred = pred.t() 832 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 833 | 834 | res = [] 835 | for k in topk: 836 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 837 | res.append(correct_k.mul_(100.0 / batch_size)) 838 | return res 839 | 840 | 841 | if __name__ == '__main__': 842 | main() 843 | -------------------------------------------------------------------------------- /src/main_deterministic_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import shutil 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torch.utils.data 13 | from torch.utils.tensorboard import SummaryWriter 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import models.deterministic.resnet as resnet 17 | import numpy as np 18 | 19 | model_names = 
sorted( 20 | name for name in resnet.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and name.startswith("resnet") and callable(resnet.__dict__[name])) 23 | 24 | print(model_names) 25 | 26 | parser = argparse.ArgumentParser(description='CIFAR10') 27 | parser.add_argument('--arch', 28 | '-a', 29 | metavar='ARCH', 30 | default='resnet20', 31 | choices=model_names, 32 | help='model architecture: ' + ' | '.join(model_names) + 33 | ' (default: resnet20)') 34 | parser.add_argument('-j', 35 | '--workers', 36 | default=8, 37 | type=int, 38 | metavar='N', 39 | help='number of data loading workers (default: 8)') 40 | parser.add_argument('--epochs', 41 | default=200, 42 | type=int, 43 | metavar='N', 44 | help='number of total epochs to run') 45 | parser.add_argument('--start-epoch', 46 | default=0, 47 | type=int, 48 | metavar='N', 49 | help='manual epoch number (useful on restarts)') 50 | parser.add_argument('-b', 51 | '--batch-size', 52 | default=512, 53 | type=int, 54 | metavar='N', 55 | help='mini-batch size (default: 512)') 56 | parser.add_argument('--lr', 57 | '--learning-rate', 58 | default=0.1, 59 | type=float, 60 | metavar='LR', 61 | help='initial learning rate') 62 | parser.add_argument('--momentum', 63 | default=0.9, 64 | type=float, 65 | metavar='M', 66 | help='momentum') 67 | parser.add_argument('--weight-decay', 68 | '--wd', 69 | default=1e-4, 70 | type=float, 71 | metavar='W', 72 | help='weight decay (default: 1e-4)') 73 | parser.add_argument('--print-freq', 74 | '-p', 75 | default=50, 76 | type=int, 77 | metavar='N', 78 | help='print frequency (default: 50)') 79 | parser.add_argument('--resume', 80 | default='', 81 | type=str, 82 | metavar='PATH', 83 | help='path to latest checkpoint (default: none)') 84 | parser.add_argument('-e', 85 | '--evaluate', 86 | dest='evaluate', 87 | action='store_true', 88 | help='evaluate model on validation set') 89 | parser.add_argument('--pretrained', 90 | dest='pretrained', 91 | action='store_true', 92 | 
help='use pre-trained model') 93 | parser.add_argument('--half', 94 | dest='half', 95 | action='store_true', 96 | help='use half-precision (16-bit)') 97 | parser.add_argument('--save-dir', 98 | dest='save_dir', 99 | help='The directory used to save the trained models', 100 | default='./checkpoint/deterministic', 101 | type=str) 102 | parser.add_argument( 103 | '--save-every', 104 | dest='save_every', 105 | help='Saves checkpoints at every specified number of epochs', 106 | type=int, 107 | default=10) 108 | parser.add_argument('--mode', type=str, required=True, help='train | test') 109 | parser.add_argument( 110 | '--tensorboard', 111 | type=bool, 112 | default=True, 113 | metavar='N', 114 | help='use tensorboard for logging and visualization of training progress') 115 | parser.add_argument( 116 | '--log_dir', 117 | type=str, 118 | default='./logs/cifar/deterministic', 119 | metavar='N', 120 | help='directory for logs, predictions and results') 121 | best_prec1 = 0 122 | 123 | 124 | def main(): 125 | global args, best_prec1 126 | args = parser.parse_args() 127 | 128 | # Check the save_dir exists or not 129 | if not os.path.exists(args.save_dir): 130 | os.makedirs(args.save_dir) 131 | 132 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]()) 133 | if torch.cuda.is_available(): 134 | model.cuda() 135 | else: 136 | model.cpu() 137 | 138 | # optionally resume from a checkpoint 139 | if args.resume: 140 | if os.path.isfile(args.resume): 141 | print("=> loading checkpoint '{}'".format(args.resume)) 142 | checkpoint = torch.load(args.resume) 143 | args.start_epoch = checkpoint['epoch'] 144 | best_prec1 = checkpoint['best_prec1'] 145 | model.load_state_dict(checkpoint['state_dict']) 146 | print("=> loaded checkpoint '{}' (epoch {})".format( 147 | args.resume, checkpoint['epoch'])) 148 | else: 149 | print("=> no checkpoint found at '{}'".format(args.resume)) 150 | 151 | cudnn.benchmark = True 152 | 153 | tb_writer = None 154 | if 
args.tensorboard: 155 | logger_dir = os.path.join(args.log_dir, 'tb_logger') 156 | if not os.path.exists(logger_dir): 157 | os.makedirs(logger_dir) 158 | tb_writer = SummaryWriter(logger_dir) 159 | 160 | preds_dir = os.path.join(args.log_dir, 'preds') 161 | if not os.path.exists(preds_dir): 162 | os.makedirs(preds_dir) 163 | results_dir = os.path.join(args.log_dir, 'results') 164 | if not os.path.exists(results_dir): 165 | os.makedirs(results_dir) 166 | 167 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 168 | std=[0.229, 0.224, 0.225]) 169 | 170 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10( 171 | root='./data', 172 | train=True, 173 | transform=transforms.Compose([ 174 | transforms.RandomHorizontalFlip(), 175 | transforms.RandomCrop(32, 4), 176 | transforms.ToTensor(), 177 | ]), 178 | download=True), 179 | batch_size=args.batch_size, 180 | shuffle=True, 181 | num_workers=args.workers, 182 | pin_memory=True) 183 | 184 | val_loader = torch.utils.data.DataLoader(datasets.CIFAR10( 185 | root='./data', 186 | train=False, 187 | transform=transforms.Compose([ 188 | transforms.ToTensor(), 189 | ])), 190 | batch_size=args.batch_size, 191 | shuffle=False, 192 | num_workers=args.workers, 193 | pin_memory=True) 194 | 195 | if not os.path.exists(args.save_dir): 196 | os.makedirs(args.save_dir) 197 | 198 | if torch.cuda.is_available(): 199 | criterion = nn.CrossEntropyLoss().cuda() 200 | else: 201 | criterion = nn.CrossEntropyLoss().cpu() 202 | 203 | if args.half: 204 | model.half() 205 | criterion.half() 206 | 207 | optimizer = torch.optim.SGD(model.parameters(), 208 | args.lr, 209 | momentum=args.momentum, 210 | weight_decay=args.weight_decay) 211 | 212 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 213 | optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1) 214 | 215 | if args.arch in ['resnet110']: 216 | for param_group in optimizer.param_groups: 217 | param_group['lr'] = args.lr * 0.1 218 | 219 | if args.evaluate: 220 | 
validate(val_loader, model, criterion, epoch=0)  # validate() requires an epoch; pass 0 for a one-off evaluation 221 | return 222 | 223 | if args.mode == 'train': 224 | for epoch in range(args.start_epoch, args.epochs): 225 | 226 | # train for one epoch 227 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 228 | train(train_loader, model, criterion, optimizer, epoch, tb_writer) 229 | lr_scheduler.step() 230 | 231 | prec1 = validate(val_loader, model, criterion, epoch, tb_writer) 232 | 233 | is_best = prec1 > best_prec1 234 | best_prec1 = max(prec1, best_prec1) 235 | 236 | if epoch > 0 and epoch % args.save_every == 0: 237 | if is_best: 238 | save_checkpoint( 239 | { 240 | 'epoch': epoch + 1, 241 | 'state_dict': model.state_dict(), 242 | 'best_prec1': best_prec1, 243 | }, 244 | is_best, 245 | filename=os.path.join( 246 | args.save_dir, 247 | 'det_{}_cifar.pth'.format(args.arch))) 248 | 249 | elif args.mode == 'test': 250 | checkpoint_file = args.save_dir + '/det_{}_cifar.pth'.format(args.arch) 251 | if torch.cuda.is_available(): 252 | checkpoint = torch.load(checkpoint_file) 253 | else: 254 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) 255 | model.load_state_dict(checkpoint['state_dict']) 256 | evaluate(model, val_loader) 257 | 258 | 259 | def train(train_loader, model, criterion, optimizer, epoch, tb_writer=None): 260 | """ 261 | Run one train epoch 262 | """ 263 | batch_time = AverageMeter() 264 | data_time = AverageMeter() 265 | losses = AverageMeter() 266 | top1 = AverageMeter() 267 | 268 | # switch to train mode 269 | model.train() 270 | 271 | end = time.time() 272 | for i, (input, target) in enumerate(train_loader): 273 | 274 | # measure data loading time 275 | data_time.update(time.time() - end) 276 | 277 | if torch.cuda.is_available(): 278 | target = target.cuda() 279 | input_var = input.cuda() 280 | target_var = target.cuda() 281 | else: 282 | target = target.cpu() 283 | input_var = input.cpu() 284 | target_var = target.cpu() 285 | if args.half: 286 | input_var = input_var.half() 
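The top-k bookkeeping that `accuracy()` and the `AverageMeter`s perform for this loop can be sanity-checked outside the full pipeline. A minimal numpy sketch (the helper name `topk_accuracy` is illustrative, not part of this repo) that mirrors the same semantics as `accuracy()` below:

```python
import numpy as np

def topk_accuracy(logits, targets, ks=(1,)):
    # class indices sorted by descending score for each sample
    order = np.argsort(-logits, axis=1)
    res = []
    for k in ks:
        # a sample counts as correct if the target is among the k highest scores
        hit = (order[:, :k] == targets[:, None]).any(axis=1)
        res.append(100.0 * hit.mean())
    return res

logits = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1],
                   [0.3, 0.3, 0.4]])
targets = np.array([1, 1, 2])
top1, top2 = topk_accuracy(logits, targets, ks=(1, 2))  # top1 ~ 66.67, top2 = 100.0
```

The second sample is wrong at top-1 (class 0 scores highest) but recovered at top-2, which is exactly the gap `Acc@1` vs `Acc@5` measures in the loops above.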
287 | 288 | # compute output 289 | output = model(input_var) 290 | loss = criterion(output, target_var) 291 | 292 | # compute gradient and do SGD step 293 | optimizer.zero_grad() 294 | loss.backward() 295 | optimizer.step() 296 | 297 | output = output.float() 298 | loss = loss.float() 299 | # measure accuracy and record loss 300 | prec1 = accuracy(output.data, target)[0] 301 | losses.update(loss.item(), input.size(0)) 302 | top1.update(prec1.item(), input.size(0)) 303 | 304 | # measure elapsed time 305 | batch_time.update(time.time() - end) 306 | end = time.time() 307 | 308 | if i % args.print_freq == 0: 309 | print('Epoch: [{0}][{1}/{2}]\t' 310 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 311 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 312 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 313 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 314 | epoch, 315 | i, 316 | len(train_loader), 317 | batch_time=batch_time, 318 | data_time=data_time, 319 | loss=losses, 320 | top1=top1)) 321 | 322 | if tb_writer is not None: 323 | tb_writer.add_scalar('train/loss', loss.item(), epoch) 324 | tb_writer.add_scalar('train/accuracy', prec1.item(), epoch) 325 | tb_writer.flush() 326 | 327 | 328 | def validate(val_loader, model, criterion, epoch, tb_writer=None): 329 | """ 330 | Run evaluation 331 | """ 332 | batch_time = AverageMeter() 333 | losses = AverageMeter() 334 | top1 = AverageMeter() 335 | 336 | # switch to evaluate mode 337 | model.eval() 338 | 339 | end = time.time() 340 | with torch.no_grad(): 341 | for i, (input, target) in enumerate(val_loader): 342 | if torch.cuda.is_available(): 343 | target = target.cuda() 344 | input_var = input.cuda() 345 | target_var = target.cuda() 346 | else: 347 | target = target.cpu() 348 | input_var = input.cpu() 349 | target_var = target.cpu() 350 | 351 | if args.half: 352 | input_var = input_var.half() 353 | 354 | # compute output 355 | output = model(input_var) 356 | loss = criterion(output, target_var) 357 | 358 | output 
= output.float() 359 | loss = loss.float() 360 | 361 | # measure accuracy and record loss 362 | prec1 = accuracy(output.data, target)[0] 363 | losses.update(loss.item(), input.size(0)) 364 | top1.update(prec1.item(), input.size(0)) 365 | 366 | # measure elapsed time 367 | batch_time.update(time.time() - end) 368 | end = time.time() 369 | 370 | if i % args.print_freq == 0: 371 | print('Test: [{0}/{1}]\t' 372 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 373 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 374 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 375 | i, 376 | len(val_loader), 377 | batch_time=batch_time, 378 | loss=losses, 379 | top1=top1)) 380 | 381 | if tb_writer is not None: 382 | tb_writer.add_scalar('val/loss', loss.item(), epoch) 383 | tb_writer.add_scalar('val/accuracy', prec1.item(), epoch) 384 | tb_writer.flush() 385 | 386 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 387 | 388 | return top1.avg 389 | 390 | 391 | def evaluate(model, val_loader): 392 | model.eval() 393 | correct = 0 394 | with torch.no_grad(): 395 | for data, target in val_loader: 396 | if torch.cuda.is_available(): 397 | data, target = data.cuda(), target.cuda() 398 | else: 399 | data, target = data.cpu(), target.cpu() 400 | output = model(data) 401 | output = torch.nn.functional.softmax(output, dim=1) 402 | pred = output.argmax( 403 | dim=1, 404 | keepdim=True) # get the index of the max log-probability 405 | correct += pred.eq(target.view_as(pred)).sum().item() 406 | 407 | print('\nTest set: Accuracy: {:.2f}%\n'.format(100. 
* correct / 408 | len(val_loader.dataset))) 409 | target_labels = target.cpu().data.numpy() 410 | np.save('./probs_cifar_det.npy', output.cpu().data.numpy()) 411 | np.save('./cifar_test_labels.npy', target_labels) 412 | 413 | 414 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 415 | """ 416 | Save the training model 417 | """ 418 | torch.save(state, filename) 419 | 420 | 421 | class AverageMeter(object): 422 | """Computes and stores the average and current value""" 423 | def __init__(self): 424 | self.reset() 425 | 426 | def reset(self): 427 | self.val = 0 428 | self.avg = 0 429 | self.sum = 0 430 | self.count = 0 431 | 432 | def update(self, val, n=1): 433 | self.val = val 434 | self.sum += val * n 435 | self.count += n 436 | self.avg = self.sum / self.count 437 | 438 | 439 | def accuracy(output, target, topk=(1, )): 440 | """Computes the precision@k for the specified values of k""" 441 | maxk = max(topk) 442 | batch_size = target.size(0) 443 | 444 | _, pred = output.topk(maxk, 1, True, True) 445 | pred = pred.t() 446 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 447 | 448 | res = [] 449 | for k in topk: 450 | correct_k = correct[:k].view(-1).float().sum(0) 451 | res.append(correct_k.mul_(100.0 / batch_size)) 452 | return res 453 | 454 | 455 | if __name__ == '__main__': 456 | main() 457 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 Intel Corporation 2 | # 3 | # BSD-3-Clause License 4 | # 5 | # Redistribution and use in source and binary forms, with or without modification, 6 | # are permitted provided that the following conditions are met: 7 | # 1. Redistributions of source code must retain the above copyright notice, 8 | # this list of conditions and the following disclaimer. 9 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 3. Neither the name of the copyright holder nor the names of its contributors 13 | # may be used to endorse or promote products derived from this software 14 | # without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT 22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | # Utily functions for variational inference in Bayesian deep neural networks 29 | # ELBO_loss -> to compute evidence lower bound loss 30 | # get_rho -> variance (sigma) is represented by softplus function 'sigma = log(1 + exp(rho))' 31 | # to make sure it remains always positive and non-transformed 'rho' gets 32 | # updated during training. 
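The softplus parameterization described in this header can be checked numerically. A small sketch (assuming only numpy; it mirrors `get_rho` defined later in this file) showing that softplus inverts `get_rho`, so the initialized `rho` yields a standard deviation of `delta * |w|`:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def get_rho(sigma, delta):
    # inverse softplus of delta*|sigma|, as in util.get_rho
    return np.log(np.expm1(delta * np.abs(sigma)) + 1e-20)

w = np.array([-0.5, 0.1, 2.0])   # stand-in for deterministic weights
delta = 0.5
rho = get_rho(w, delta)
# softplus maps rho back to the intended standard deviation delta*|w|
recovered = softplus(rho)
```

Because `rho` is unconstrained while `sigma = softplus(rho)` is always positive, the optimizer can update `rho` freely during training without ever producing an invalid (negative) standard deviation.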
33 | # MOPED -> set the priors and initialize approximate variational posteriors of Bayesian NN 34 | # with Empirical Bayes 35 | # 36 | # @authors: Ranganath Krishnan 37 | # 38 | # =============================================================================================== 39 | 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | import torch.nn.functional as F 44 | import torch 45 | from torch import nn 46 | import numpy as np 47 | from sklearn.metrics import auc 48 | 49 | 50 | def ELBO_loss(out, y, kl_loss, num_data_samples, batch_size): 51 | nll_loss = F.cross_entropy(out, y) 52 | return nll_loss + ((1.0 / num_data_samples) * kl_loss) 53 | 54 | 55 | def get_rho(sigma, delta): 56 | rho = torch.log(torch.expm1(delta * torch.abs(sigma)) + 1e-20) 57 | return rho 58 | 59 | 60 | def entropy(prob): 61 | return -1 * np.sum(prob * np.log(prob + 1e-15), axis=-1) 62 | 63 | 64 | def predictive_entropy(mc_preds): 65 | """ 66 | Compute the entropy of the mean of the predictive distribution 67 | obtained from Monte Carlo sampling during prediction phase. 68 | """ 69 | return entropy(np.mean(mc_preds, axis=0)) 70 | 71 | 72 | def mutual_information(mc_preds): 73 | """ 74 | Compute the difference between the entropy of the mean of the 75 | predictive distribution and the mean of the entropy. 76 | """ 77 | MI = entropy(np.mean(mc_preds, axis=0)) - np.mean(entropy(mc_preds), 78 | axis=0) 79 | return MI 80 | 81 | 82 | def MOPED(model, det_model, det_checkpoint, delta): 83 | """ 84 | Set the priors and initialize surrogate posteriors of Bayesian NN with Empirical Bayes 85 | MOPED (Model Priors with Empirical Bayes using Deterministic DNN) 86 | Ref: https://arxiv.org/abs/1906.05323 87 | 'Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes'. AAAI 2020. 
88 | """ 89 | det_model.load_state_dict(torch.load(det_checkpoint)) 90 | for (idx, layer), (det_idx, 91 | det_layer) in zip(enumerate(model.modules()), 92 | enumerate(det_model.modules())): 93 | if (str(layer) == 'Conv1dVariational()' 94 | or str(layer) == 'Conv2dVariational()' 95 | or str(layer) == 'Conv3dVariational()' 96 | or str(layer) == 'ConvTranspose1dVariational()' 97 | or str(layer) == 'ConvTranspose2dVariational()' 98 | or str(layer) == 'ConvTranspose3dVariational()'): 99 | #set the priors 100 | layer.prior_weight_mu.data = det_layer.weight 101 | layer.prior_bias_mu.data = det_layer.bias 102 | 103 | #initialize surrogate posteriors 104 | layer.mu_kernel.data = det_layer.weight 105 | layer.rho_kernel.data = get_rho(det_layer.weight.data, delta) 106 | layer.mu_bias.data = det_layer.bias 107 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta) 108 | elif (str(layer) == 'LinearVariational()'): 109 | #set the priors 110 | layer.prior_weight_mu.data = det_layer.weight 111 | layer.prior_bias_mu.data = det_layer.bias 112 | 113 | #initialize the surrogate posteriors 114 | layer.mu_weight.data = det_layer.weight 115 | layer.rho_weight.data = get_rho(det_layer.weight.data, delta) 116 | layer.mu_bias.data = det_layer.bias 117 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta) 118 | elif str(layer).startswith('Batch'): 119 | #initialize parameters 120 | layer.weight.data = det_layer.weight 121 | layer.bias.data = det_layer.bias 122 | layer.running_mean.data = det_layer.running_mean 123 | layer.running_var.data = det_layer.running_var 124 | layer.num_batches_tracked.data = det_layer.num_batches_tracked 125 | 126 | model.state_dict() 127 | return model 128 | 129 | 130 | def eval_avu(pred_label, true_label, uncertainty): 131 | """ returns AvU at various uncertainty thresholds""" 132 | t_list = np.linspace(0, 1, 21) 133 | umin = np.amin(uncertainty, axis=0) 134 | umax = np.amax(uncertainty, axis=0) 135 | avu_list = [] 136 | unc_list = [] 137 | for t in 
t_list: 138 | u_th = umin + (t * (umax - umin)) 139 | n_ac = 0 140 | n_ic = 0 141 | n_au = 0 142 | n_iu = 0 143 | for i in range(len(true_label)): 144 | if ((true_label[i] == pred_label[i]) and uncertainty[i] <= u_th): 145 | n_ac += 1 146 | elif ((true_label[i] == pred_label[i]) and uncertainty[i] > u_th): 147 | n_au += 1 148 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] <= u_th): 149 | n_ic += 1 150 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] > u_th): 151 | n_iu += 1 152 | 153 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-15) 154 | avu_list.append(AvU) 155 | unc_list.append(u_th) 156 | return np.asarray(avu_list), np.asarray(unc_list) 157 | 158 | 159 | def accuracy_vs_uncertainty(pred_label, true_label, uncertainty, 160 | optimal_threshold): 161 | 162 | n_ac = 0 163 | n_ic = 0 164 | n_au = 0 165 | n_iu = 0 166 | for i in range(len(true_label)): 167 | if ((true_label[i] == pred_label[i]) 168 | and uncertainty[i] <= optimal_threshold): 169 | n_ac += 1 170 | elif ((true_label[i] == pred_label[i]) 171 | and uncertainty[i] > optimal_threshold): 172 | n_au += 1 173 | elif ((true_label[i] != pred_label[i]) 174 | and uncertainty[i] <= optimal_threshold): 175 | n_ic += 1 176 | elif ((true_label[i] != pred_label[i]) 177 | and uncertainty[i] > optimal_threshold): 178 | n_iu += 1 179 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu) 180 | 181 | return AvU 182 | -------------------------------------------------------------------------------- /variational_layers/conv_variational.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 Intel Corporation 2 | # 3 | # BSD-3-Clause License 4 | # 5 | # Redistribution and use in source and binary forms, with or without modification, 6 | # are permitted provided that the following conditions are met: 7 | # 1. Redistributions of source code must retain the above copyright notice, 8 | # this list of conditions and the following disclaimer. 
9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 3. Neither the name of the copyright holder nor the names of its contributors 13 | # may be used to endorse or promote products derived from this software 14 | # without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT 22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | # Convolutional Variational Layers with reparameterization estimator to perform 29 | # mean-field variational inference in Bayesian neural networks. Variational layers 30 | # enables Monte Carlo approximation of the distribution over 'kernel' and 'bias'. 31 | # 32 | # Kullback-Leibler divergence between the surrogate posterior and prior is computed 33 | # and returned along with the tensors of outputs after convolution operation, which is 34 | # required to compute Evidence Lower Bound (ELBO) loss for variational inference. 
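The reparameterization estimator and KL term described in this header have simple closed forms for fully factorized Gaussians. An illustrative numpy sketch (helper names are hypothetical, not this file's API), assuming a Gaussian surrogate posterior `N(mu, softplus(rho)^2)` and a standard-normal prior:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_weight(mu, rho):
    # reparameterization trick: w = mu + sigma * eps, with sigma = softplus(rho)
    sigma = np.log1p(np.exp(rho))
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps

def kl_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    # closed-form KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), summed over weights
    return np.sum(np.log(sigma_p / sigma_q)
                  + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
                  - 0.5)

mu = np.zeros(4)
rho = np.full(4, 0.5)
w = sample_weight(mu, rho)  # one Monte Carlo sample of the kernel weights
kl = kl_gaussian(mu, np.log1p(np.exp(rho)), np.zeros(4), np.ones(4))
```

Sampling through `mu + sigma * eps` keeps the draw differentiable in `mu` and `rho`, and the layer returns this KL alongside the convolution output so the training scripts can form the ELBO loss.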
35 | # 36 | # @authors: Ranganath Krishnan 37 | # 38 | # ====================================================================================== 39 | 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import torch 45 | import torch.nn as nn 46 | import torch.nn.functional as F 47 | from torch.nn import Module, Parameter 48 | import math 49 | 50 | __all__ = [ 51 | 'Conv1dVariational', 52 | 'Conv2dVariational', 53 | 'Conv3dVariational', 54 | 'ConvTranspose1dVariational', 55 | 'ConvTranspose2dVariational', 56 | 'ConvTranspose3dVariational', 57 | ] 58 | 59 | 60 | class Conv1dVariational(Module): 61 | def __init__(self, 62 | prior_mean, 63 | prior_variance, 64 | posterior_mu_init, 65 | posterior_rho_init, 66 | in_channels, 67 | out_channels, 68 | kernel_size, 69 | stride=1, 70 | padding=0, 71 | dilation=1, 72 | groups=1, 73 | bias=True): 74 | 75 | super(Conv1dVariational, self).__init__() 76 | if in_channels % groups != 0: 77 | raise ValueError('invalid in_channels size') 78 | if out_channels % groups != 0: 79 | raise ValueError('invalid out_channels size') 80 | 81 | self.in_channels = in_channels 82 | self.out_channels = out_channels 83 | self.kernel_size = kernel_size 84 | self.stride = stride 85 | self.padding = padding 86 | self.dilation = dilation 87 | self.groups = groups 88 | self.prior_mean = prior_mean 89 | self.prior_variance = prior_variance 90 | self.posterior_mu_init = posterior_mu_init, # mean of weight 91 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 92 | self.bias = bias 93 | 94 | self.mu_kernel = Parameter( 95 | torch.Tensor(out_channels, in_channels // groups, kernel_size)) 96 | self.rho_kernel = Parameter( 97 | torch.Tensor(out_channels, in_channels // groups, kernel_size)) 98 | self.register_buffer( 99 | 'eps_kernel', 100 | torch.Tensor(out_channels, in_channels // groups, kernel_size)) 101 | self.register_buffer( 102 
| 'prior_weight_mu', 103 | torch.Tensor(out_channels, in_channels // groups, kernel_size)) 104 | self.register_buffer( 105 | 'prior_weight_sigma', 106 | torch.Tensor(out_channels, in_channels // groups, kernel_size)) 107 | 108 | if self.bias: 109 | self.mu_bias = Parameter(torch.Tensor(out_channels)) 110 | self.rho_bias = Parameter(torch.Tensor(out_channels)) 111 | self.register_buffer('eps_bias', torch.Tensor(out_channels)) 112 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels)) 113 | self.register_buffer('prior_bias_sigma', 114 | torch.Tensor(out_channels)) 115 | else: 116 | self.register_parameter('mu_bias', None) 117 | self.register_parameter('rho_bias', None) 118 | self.register_buffer('eps_bias', None) 119 | self.register_buffer('prior_bias_mu', None) 120 | self.register_buffer('prior_bias_sigma', None) 121 | 122 | self.init_parameters() 123 | 124 | def init_parameters(self): 125 | self.prior_weight_mu.data.fill_(self.prior_mean) 126 | self.prior_weight_sigma.fill_(self.prior_variance) 127 | 128 | self.mu_kernel.data.normal_(std=0.1) 129 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 130 | if self.bias: 131 | self.prior_bias_mu.data.fill_(self.prior_mean) 132 | self.prior_bias_sigma.fill_(self.prior_variance) 133 | 134 | self.mu_bias.data.normal_(std=0.1) 135 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 136 | std=0.1) 137 | 138 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 139 | kl = torch.log(sigma_p + 1e-15) - torch.log( 140 | sigma_q + 1e-15) + (sigma_q**2 + 141 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 142 | return kl.sum() 143 | 144 | def forward(self, input): 145 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) 146 | eps_kernel = self.eps_kernel.normal_() 147 | weight = self.mu_kernel + (sigma_weight * eps_kernel) 148 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight, 149 | self.prior_weight_mu, self.prior_weight_sigma) 150 | bias = None 151 | 152 | if self.bias: 153 | 
sigma_bias = torch.log1p(torch.exp(self.rho_bias)) 154 | eps_bias = self.eps_bias.normal_() 155 | bias = self.mu_bias + (sigma_bias * eps_bias) 156 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 157 | self.prior_bias_sigma) 158 | 159 | out = F.conv1d(input, weight, bias, self.stride, self.padding, 160 | self.dilation, self.groups) 161 | if self.bias: 162 | kl = kl_weight + kl_bias 163 | else: 164 | kl = kl_weight 165 | 166 | return out, kl 167 | 168 | 169 | class Conv2dVariational(Module): 170 | def __init__(self, 171 | prior_mean, 172 | prior_variance, 173 | posterior_mu_init, 174 | posterior_rho_init, 175 | in_channels, 176 | out_channels, 177 | kernel_size, 178 | stride=1, 179 | padding=0, 180 | dilation=1, 181 | groups=1, 182 | bias=True): 183 | 184 | super(Conv2dVariational, self).__init__() 185 | if in_channels % groups != 0: 186 | raise ValueError('invalid in_channels size') 187 | if out_channels % groups != 0: 188 | raise ValueError('invalid out_channels size') 189 | 190 | self.in_channels = in_channels 191 | self.out_channels = out_channels 192 | self.kernel_size = kernel_size 193 | self.stride = stride 194 | self.padding = padding 195 | self.dilation = dilation 196 | self.groups = groups 197 | self.prior_mean = prior_mean 198 | self.prior_variance = prior_variance 199 | self.posterior_mu_init = posterior_mu_init, # mean of weight 200 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 201 | self.bias = bias 202 | 203 | self.mu_kernel = Parameter( 204 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 205 | kernel_size)) 206 | self.rho_kernel = Parameter( 207 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 208 | kernel_size)) 209 | self.register_buffer( 210 | 'eps_kernel', 211 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 212 | kernel_size)) 213 | self.register_buffer( 214 | 'prior_weight_mu', 215 | torch.Tensor(out_channels, in_channels 
// groups, kernel_size, 216 | kernel_size)) 217 | self.register_buffer( 218 | 'prior_weight_sigma', 219 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 220 | kernel_size)) 221 | 222 | if self.bias: 223 | self.mu_bias = Parameter(torch.Tensor(out_channels)) 224 | self.rho_bias = Parameter(torch.Tensor(out_channels)) 225 | self.register_buffer('eps_bias', torch.Tensor(out_channels)) 226 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels)) 227 | self.register_buffer('prior_bias_sigma', 228 | torch.Tensor(out_channels)) 229 | else: 230 | self.register_parameter('mu_bias', None) 231 | self.register_parameter('rho_bias', None) 232 | self.register_buffer('eps_bias', None) 233 | self.register_buffer('prior_bias_mu', None) 234 | self.register_buffer('prior_bias_sigma', None) 235 | 236 | self.init_parameters() 237 | 238 | def init_parameters(self): 239 | self.prior_weight_mu.fill_(self.prior_mean) 240 | self.prior_weight_sigma.fill_(self.prior_variance) 241 | 242 | self.mu_kernel.data.normal_(std=0.1) 243 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 244 | if self.bias: 245 | self.prior_bias_mu.fill_(self.prior_mean) 246 | self.prior_bias_sigma.fill_(self.prior_variance) 247 | 248 | self.mu_bias.data.normal_(std=0.1) 249 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 250 | std=0.1) 251 | 252 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 253 | kl = torch.log(sigma_p + 1e-15) - torch.log( 254 | sigma_q + 1e-15) + (sigma_q**2 + 255 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 256 | return kl.sum() 257 | 258 | def forward(self, input): 259 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) 260 | eps_kernel = self.eps_kernel.normal_() 261 | weight = self.mu_kernel + (sigma_weight * eps_kernel) 262 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight, 263 | self.prior_weight_mu, self.prior_weight_sigma) 264 | bias = None 265 | 266 | if self.bias: 267 | sigma_bias = 
torch.log1p(torch.exp(self.rho_bias)) 268 | eps_bias = self.eps_bias.normal_() 269 | bias = self.mu_bias + (sigma_bias * eps_bias) 270 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 271 | self.prior_bias_sigma) 272 | 273 | out = F.conv2d(input, weight, bias, self.stride, self.padding, 274 | self.dilation, self.groups) 275 | if self.bias: 276 | kl = kl_weight + kl_bias 277 | else: 278 | kl = kl_weight 279 | 280 | return out, kl 281 | 282 | 283 | class Conv3dVariational(Module): 284 | def __init__(self, 285 | prior_mean, 286 | prior_variance, 287 | posterior_mu_init, 288 | posterior_rho_init, 289 | in_channels, 290 | out_channels, 291 | kernel_size, 292 | stride=1, 293 | padding=0, 294 | dilation=1, 295 | groups=1, 296 | bias=True): 297 | 298 | super(Conv3dVariational, self).__init__() 299 | if in_channels % groups != 0: 300 | raise ValueError('invalid in_channels size') 301 | if out_channels % groups != 0: 302 | raise ValueError('invalid out_channels size') 303 | 304 | self.in_channels = in_channels 305 | self.out_channels = out_channels 306 | self.kernel_size = kernel_size 307 | self.stride = stride 308 | self.padding = padding 309 | self.dilation = dilation 310 | self.groups = groups 311 | self.prior_mean = prior_mean 312 | self.prior_variance = prior_variance 313 | self.posterior_mu_init = posterior_mu_init, # mean of weight 314 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 315 | self.bias = bias 316 | 317 | self.mu_kernel = Parameter( 318 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 319 | kernel_size, kernel_size)) 320 | self.rho_kernel = Parameter( 321 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 322 | kernel_size, kernel_size)) 323 | self.register_buffer( 324 | 'eps_kernel', 325 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 326 | kernel_size, kernel_size)) 327 | self.register_buffer( 328 | 'prior_weight_mu', 329 | 
torch.Tensor(out_channels, in_channels // groups, kernel_size, 330 | kernel_size, kernel_size)) 331 | self.register_buffer( 332 | 'prior_weight_sigma', 333 | torch.Tensor(out_channels, in_channels // groups, kernel_size, 334 | kernel_size, kernel_size)) 335 | 336 | if self.bias: 337 | self.mu_bias = Parameter(torch.Tensor(out_channels)) 338 | self.rho_bias = Parameter(torch.Tensor(out_channels)) 339 | self.register_buffer('eps_bias', torch.Tensor(out_channels)) 340 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels)) 341 | self.register_buffer('prior_bias_sigma', 342 | torch.Tensor(out_channels)) 343 | else: 344 | self.register_parameter('mu_bias', None) 345 | self.register_parameter('rho_bias', None) 346 | self.register_buffer('eps_bias', None) 347 | self.register_buffer('prior_bias_mu', None) 348 | self.register_buffer('prior_bias_sigma', None) 349 | 350 | self.init_parameters() 351 | 352 | def init_parameters(self): 353 | self.prior_weight_mu.fill_(self.prior_mean) 354 | self.prior_weight_sigma.fill_(self.prior_variance) 355 | 356 | self.mu_kernel.data.normal_(std=0.1) 357 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 358 | if self.bias: 359 | self.prior_bias_mu.fill_(self.prior_mean) 360 | self.prior_bias_sigma.fill_(self.prior_variance) 361 | 362 | self.mu_bias.data.normal_(std=0.1) 363 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 364 | std=0.1) 365 | 366 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 367 | kl = torch.log(sigma_p + 1e-15) - torch.log( 368 | sigma_q + 1e-15) + (sigma_q**2 + 369 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 370 | return kl.sum() 371 | 372 | def forward(self, input): 373 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) 374 | eps_kernel = self.eps_kernel.normal_() 375 | weight = self.mu_kernel + (sigma_weight * eps_kernel) 376 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight, 377 | self.prior_weight_mu, self.prior_weight_sigma) 378 | bias = None 379 | 380 
| if self.bias: 381 | sigma_bias = torch.log1p(torch.exp(self.rho_bias)) 382 | eps_bias = self.eps_bias.normal_() 383 | bias = self.mu_bias + (sigma_bias * eps_bias) 384 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 385 | self.prior_bias_sigma) 386 | 387 | out = F.conv3d(input, weight, bias, self.stride, self.padding, 388 | self.dilation, self.groups) 389 | if self.bias: 390 | kl = kl_weight + kl_bias 391 | else: 392 | kl = kl_weight 393 | 394 | return out, kl 395 | 396 | 397 | class ConvTranspose1dVariational(Module): 398 | def __init__(self, 399 | prior_mean, 400 | prior_variance, 401 | posterior_mu_init, 402 | posterior_rho_init, 403 | in_channels, 404 | out_channels, 405 | kernel_size, 406 | stride=1, 407 | padding=0, 408 | output_padding=0, 409 | dilation=1, 410 | groups=1, 411 | bias=True): 412 | 413 | super(ConvTranspose1dVariational, self).__init__() 414 | if in_channels % groups != 0: 415 | raise ValueError('invalid in_channels size') 416 | if out_channels % groups != 0: 417 | raise ValueError('invalid out_channels size') 418 | 419 | self.in_channels = in_channels 420 | self.out_channels = out_channels 421 | self.kernel_size = kernel_size 422 | self.stride = stride 423 | self.padding = padding 424 | self.output_padding = output_padding 425 | self.dilation = dilation 426 | self.groups = groups 427 | self.prior_mean = prior_mean 428 | self.prior_variance = prior_variance 429 | self.posterior_mu_init = posterior_mu_init, # mean of weight 430 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 431 | self.bias = bias 432 | 433 | self.mu_kernel = Parameter( 434 | torch.Tensor(in_channels, out_channels // groups, kernel_size)) 435 | self.rho_kernel = Parameter( 436 | torch.Tensor(in_channels, out_channels // groups, kernel_size)) 437 | self.register_buffer( 438 | 'eps_kernel', 439 | torch.Tensor(in_channels, out_channels // groups, kernel_size)) 440 | self.register_buffer( 441 | 
'prior_weight_mu', 442 | torch.Tensor(in_channels, out_channels // groups, kernel_size)) 443 | self.register_buffer( 444 | 'prior_weight_sigma', 445 | torch.Tensor(in_channels, out_channels // groups, kernel_size)) 446 | 447 | if self.bias: 448 | self.mu_bias = Parameter(torch.Tensor(out_channels)) 449 | self.rho_bias = Parameter(torch.Tensor(out_channels)) 450 | self.register_buffer('eps_bias', torch.Tensor(out_channels)) 451 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels)) 452 | self.register_buffer('prior_bias_sigma', 453 | torch.Tensor(out_channels)) 454 | else: 455 | self.register_parameter('mu_bias', None) 456 | self.register_parameter('rho_bias', None) 457 | self.register_buffer('eps_bias', None) 458 | self.register_buffer('prior_bias_mu', None) 459 | self.register_buffer('prior_bias_sigma', None) 460 | 461 | self.init_parameters() 462 | 463 | def init_parameters(self): 464 | self.prior_weight_mu.fill_(self.prior_mean) 465 | self.prior_weight_sigma.fill_(self.prior_variance) 466 | 467 | self.mu_kernel.data.normal_(std=0.1) 468 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 469 | if self.bias: 470 | self.prior_bias_mu.fill_(self.prior_mean) 471 | self.prior_bias_sigma.fill_(self.prior_variance) 472 | 473 | self.mu_bias.data.normal_(std=0.1) 474 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 475 | std=0.1) 476 | 477 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 478 | kl = torch.log(sigma_p + 1e-15) - torch.log( 479 | sigma_q + 1e-15) + (sigma_q**2 + 480 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 481 | return kl.sum() 482 | 483 | def forward(self, input): 484 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) 485 | eps_kernel = self.eps_kernel.normal_() 486 | weight = self.mu_kernel + (sigma_weight * eps_kernel) 487 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight, 488 | self.prior_weight_mu, self.prior_weight_sigma) 489 | bias = None 490 | 491 | if self.bias: 492 | sigma_bias = 
torch.log1p(torch.exp(self.rho_bias)) 493 | eps_bias = self.eps_bias.normal_() 494 | bias = self.mu_bias + (sigma_bias * eps_bias) 495 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 496 | self.prior_bias_sigma) 497 | 498 | out = F.conv_transpose1d(input, weight, bias, self.stride, 499 | self.padding, self.output_padding, 500 | self.dilation, self.groups) 501 | if self.bias: 502 | kl = kl_weight + kl_bias 503 | else: 504 | kl = kl_weight 505 | 506 | return out, kl 507 | 508 | 509 | class ConvTranspose2dVariational(Module): 510 | def __init__(self, 511 | prior_mean, 512 | prior_variance, 513 | posterior_mu_init, 514 | posterior_rho_init, 515 | in_channels, 516 | out_channels, 517 | kernel_size, 518 | stride=1, 519 | padding=0, 520 | output_padding=0, 521 | dilation=1, 522 | groups=1, 523 | bias=True): 524 | 525 | super(ConvTranspose2dVariational, self).__init__() 526 | if in_channels % groups != 0: 527 | raise ValueError('invalid in_channels size') 528 | if out_channels % groups != 0: 529 | raise ValueError('invalid out_channels size') 530 | 531 | self.in_channels = in_channels 532 | self.out_channels = out_channels 533 | self.kernel_size = kernel_size 534 | self.stride = stride 535 | self.padding = padding 536 | self.output_padding = output_padding 537 | self.dilation = dilation 538 | self.groups = groups 539 | self.prior_mean = prior_mean 540 | self.prior_variance = prior_variance 541 | self.posterior_mu_init = posterior_mu_init, # mean of weight 542 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 543 | self.bias = bias 544 | 545 | self.mu_kernel = Parameter( 546 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 547 | kernel_size)) 548 | self.rho_kernel = Parameter( 549 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 550 | kernel_size)) 551 | self.register_buffer( 552 | 'eps_kernel', 553 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 554 | 
kernel_size)) 555 | self.register_buffer( 556 | 'prior_weight_mu', 557 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 558 | kernel_size)) 559 | self.register_buffer( 560 | 'prior_weight_sigma', 561 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 562 | kernel_size)) 563 | 564 | if self.bias: 565 | self.mu_bias = Parameter(torch.Tensor(out_channels)) 566 | self.rho_bias = Parameter(torch.Tensor(out_channels)) 567 | self.register_buffer('eps_bias', torch.Tensor(out_channels)) 568 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels)) 569 | self.register_buffer('prior_bias_sigma', 570 | torch.Tensor(out_channels)) 571 | else: 572 | self.register_parameter('mu_bias', None) 573 | self.register_parameter('rho_bias', None) 574 | self.register_buffer('eps_bias', None) 575 | self.register_buffer('prior_bias_mu', None) 576 | self.register_buffer('prior_bias_sigma', None) 577 | 578 | self.init_parameters() 579 | 580 | def init_parameters(self): 581 | self.prior_weight_mu.fill_(self.prior_mean) 582 | self.prior_weight_sigma.fill_(self.prior_variance) 583 | 584 | self.mu_kernel.data.normal_(std=0.1) 585 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 586 | if self.bias: 587 | self.prior_bias_mu.fill_(self.prior_mean) 588 | self.prior_bias_sigma.fill_(self.prior_variance) 589 | 590 | self.mu_bias.data.normal_(std=0.1) 591 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 592 | std=0.1) 593 | 594 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 595 | kl = torch.log(sigma_p + 1e-15) - torch.log( 596 | sigma_q + 1e-15) + (sigma_q**2 + 597 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 598 | return kl.sum() 599 | 600 | def forward(self, input): 601 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) 602 | eps_kernel = self.eps_kernel.normal_() 603 | weight = self.mu_kernel + (sigma_weight * eps_kernel) 604 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight, 605 | self.prior_weight_mu, 
self.prior_weight_sigma) 606 | bias = None 607 | 608 | if self.bias: 609 | sigma_bias = torch.log1p(torch.exp(self.rho_bias)) 610 | eps_bias = self.eps_bias.normal_() 611 | bias = self.mu_bias + (sigma_bias * eps_bias) 612 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 613 | self.prior_bias_sigma) 614 | 615 | out = F.conv_transpose2d(input, weight, bias, self.stride, 616 | self.padding, self.output_padding, 617 | self.dilation, self.groups) 618 | if self.bias: 619 | kl = kl_weight + kl_bias 620 | else: 621 | kl = kl_weight 622 | 623 | return out, kl 624 | 625 | 626 | class ConvTranspose3dVariational(Module): 627 | def __init__(self, 628 | prior_mean, 629 | prior_variance, 630 | posterior_mu_init, 631 | posterior_rho_init, 632 | in_channels, 633 | out_channels, 634 | kernel_size, 635 | stride=1, 636 | padding=0, 637 | output_padding=0, 638 | dilation=1, 639 | groups=1, 640 | bias=True): 641 | 642 | super(ConvTranspose3dVariational, self).__init__() 643 | if in_channels % groups != 0: 644 | raise ValueError('invalid in_channels size') 645 | if out_channels % groups != 0: 646 | raise ValueError('invalid out_channels size') 647 | 648 | self.in_channels = in_channels 649 | self.out_channels = out_channels 650 | self.kernel_size = kernel_size 651 | self.stride = stride 652 | self.padding = padding 653 | self.output_padding = output_padding 654 | self.dilation = dilation 655 | self.groups = groups 656 | self.prior_mean = prior_mean 657 | self.prior_variance = prior_variance 658 | self.posterior_mu_init = posterior_mu_init, # mean of weight 659 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 660 | self.bias = bias 661 | 662 | self.mu_kernel = Parameter( 663 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 664 | kernel_size, kernel_size)) 665 | self.rho_kernel = Parameter( 666 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 667 | kernel_size, kernel_size)) 668 | 
self.register_buffer( 669 | 'eps_kernel', 670 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 671 | kernel_size, kernel_size)) 672 | self.register_buffer( 673 | 'prior_weight_mu', 674 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 675 | kernel_size, kernel_size)) 676 | self.register_buffer( 677 | 'prior_weight_sigma', 678 | torch.Tensor(in_channels, out_channels // groups, kernel_size, 679 | kernel_size, kernel_size)) 680 | 681 | if self.bias: 682 | self.mu_bias = Parameter(torch.Tensor(out_channels)) 683 | self.rho_bias = Parameter(torch.Tensor(out_channels)) 684 | self.register_buffer('eps_bias', torch.Tensor(out_channels)) 685 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels)) 686 | self.register_buffer('prior_bias_sigma', 687 | torch.Tensor(out_channels)) 688 | else: 689 | self.register_parameter('mu_bias', None) 690 | self.register_parameter('rho_bias', None) 691 | self.register_buffer('eps_bias', None) 692 | self.register_buffer('prior_bias_mu', None) 693 | self.register_buffer('prior_bias_sigma', None) 694 | 695 | self.init_parameters() 696 | 697 | def init_parameters(self): 698 | self.prior_weight_mu.fill_(self.prior_mean) 699 | self.prior_weight_sigma.fill_(self.prior_variance) 700 | 701 | self.mu_kernel.data.normal_(std=0.1) 702 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 703 | if self.bias: 704 | self.prior_bias_mu.fill_(self.prior_mean) 705 | self.prior_bias_sigma.fill_(self.prior_variance) 706 | 707 | self.mu_bias.data.normal_(std=0.1) 708 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 709 | std=0.1) 710 | 711 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 712 | kl = torch.log(sigma_p + 1e-15) - torch.log( 713 | sigma_q + 1e-15) + (sigma_q**2 + 714 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 715 | return kl.sum() 716 | 717 | def forward(self, input): 718 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) 719 | eps_kernel = 
self.eps_kernel.normal_() 720 | weight = self.mu_kernel + (sigma_weight * eps_kernel) 721 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight, 722 | self.prior_weight_mu, self.prior_weight_sigma) 723 | bias = None 724 | 725 | if self.bias: 726 | sigma_bias = torch.log1p(torch.exp(self.rho_bias)) 727 | eps_bias = self.eps_bias.normal_() 728 | bias = self.mu_bias + (sigma_bias * eps_bias) 729 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 730 | self.prior_bias_sigma) 731 | 732 | out = F.conv_transpose3d(input, weight, bias, self.stride, 733 | self.padding, self.output_padding, 734 | self.dilation, self.groups) 735 | if self.bias: 736 | kl = kl_weight + kl_bias 737 | else: 738 | kl = kl_weight 739 | 740 | return out, kl 741 | -------------------------------------------------------------------------------- /variational_layers/linear_variational.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 Intel Corporation 2 | # 3 | # BSD-3-Clause License 4 | # 5 | # Redistribution and use in source and binary forms, with or without modification, 6 | # are permitted provided that the following conditions are met: 7 | # 1. Redistributions of source code must retain the above copyright notice, 8 | # this list of conditions and the following disclaimer. 9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 3. Neither the name of the copyright holder nor the names of its contributors 13 | # may be used to endorse or promote products derived from this software 14 | # without specific prior written permission. 
15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS 20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT 22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | # Linear Variational Layers with reparameterization estimator to perform 29 | # mean-field variational inference in Bayesian neural networks. Variational layers 30 | # enable Monte Carlo approximation of the distribution over 'kernel' and 'bias'. 31 | # 32 | # Kullback-Leibler divergence between the surrogate posterior and prior is computed 33 | # and returned along with the tensors of outputs after the linear operation, which is 34 | # required to compute the Evidence Lower Bound (ELBO) loss for variational inference. 
35 | # 36 | # @authors: Ranganath Krishnan 37 | # 38 | # ====================================================================================== 39 | 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import torch 45 | import torch.nn as nn 46 | import torch.nn.functional as F 47 | from torch.nn import Module, Parameter 48 | import math 49 | 50 | 51 | class LinearVariational(Module): 52 | def __init__(self, 53 | prior_mean, 54 | prior_variance, 55 | posterior_mu_init, 56 | posterior_rho_init, 57 | in_features, 58 | out_features, 59 | bias=True): 60 | 61 | super(LinearVariational, self).__init__() 62 | 63 | self.in_features = in_features 64 | self.out_features = out_features 65 | self.prior_mean = prior_mean 66 | self.prior_variance = prior_variance 67 | self.posterior_mu_init = posterior_mu_init, # mean of weight 68 | self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho)) 69 | self.bias = bias 70 | 71 | self.mu_weight = Parameter(torch.Tensor(out_features, in_features)) 72 | self.rho_weight = Parameter(torch.Tensor(out_features, in_features)) 73 | self.register_buffer('eps_weight', 74 | torch.Tensor(out_features, in_features)) 75 | self.register_buffer('prior_weight_mu', 76 | torch.Tensor(out_features, in_features)) 77 | if bias: 78 | self.mu_bias = Parameter(torch.Tensor(out_features)) 79 | self.rho_bias = Parameter(torch.Tensor(out_features)) 80 | self.register_buffer('eps_bias', torch.Tensor(out_features)) 81 | self.register_buffer('prior_bias_mu', torch.Tensor(out_features)) 82 | else: 83 | self.register_buffer('prior_bias_mu', None) 84 | self.register_parameter('mu_bias', None) 85 | self.register_parameter('rho_bias', None) 86 | self.register_buffer('eps_bias', None) 87 | 88 | self.init_parameters() 89 | 90 | def init_parameters(self): 91 | self.prior_weight_mu.fill_(self.prior_mean) 92 | 93 | self.mu_weight.data.normal_(std=0.1) 94 | 
self.rho_weight.data.normal_(mean=self.posterior_rho_init[0], std=0.1) 95 | if self.mu_bias is not None: 96 | self.prior_bias_mu.fill_(self.prior_mean) 97 | self.mu_bias.data.normal_(std=0.1) 98 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0], 99 | std=0.1) 100 | 101 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 102 | sigma_p = torch.tensor(sigma_p) 103 | kl = torch.log(sigma_p) - torch.log( 104 | sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * 105 | (sigma_p**2)) - 0.5 106 | return kl.sum() 107 | 108 | def forward(self, input): 109 | sigma_weight = torch.log1p(torch.exp(self.rho_weight)) 110 | weight = self.mu_weight + (sigma_weight * self.eps_weight.normal_()) 111 | kl_weight = self.kl_div(self.mu_weight, sigma_weight, 112 | self.prior_weight_mu, self.prior_variance) 113 | bias = None 114 | 115 | if self.mu_bias is not None: 116 | sigma_bias = torch.log1p(torch.exp(self.rho_bias)) 117 | bias = self.mu_bias + (sigma_bias * self.eps_bias.normal_()) 118 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, 119 | self.prior_variance) 120 | 121 | out = F.linear(input, weight, bias) 122 | if self.mu_bias is not None: 123 | kl = kl_weight + kl_bias 124 | else: 125 | kl = kl_weight 126 | 127 | return out, kl 128 | --------------------------------------------------------------------------------
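The `kl_div` methods in the layers above sum a closed-form per-element KL divergence between two univariate Gaussians (posterior `N(mu_q, sigma_q^2)` vs. prior `N(mu_p, sigma_p^2)`). A minimal pure-Python sketch of that formula, not part of this repository (the function name `gaussian_kl` is illustrative), makes the expression explicit:

```python
import math

def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    # Closed-form KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ),
    # the same per-element expression the layers' kl_div methods sum.
    return (math.log(sigma_p) - math.log(sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5)

# Identical distributions have zero divergence.
assert abs(gaussian_kl(0.0, 1.0, 0.0, 1.0)) < 1e-12
# Shifting the posterior mean one sigma away from the prior costs exactly 0.5 nats.
assert abs(gaussian_kl(1.0, 1.0, 0.0, 1.0) - 0.5) < 1e-12
```

The layers evaluate this elementwise over the whole `mu_kernel`/`sigma_weight` tensors and return the sum, which is the KL term of the ELBO loss.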