├── LICENSE
├── README.md
├── assets
│   ├── figure1.png
│   ├── figure2.png
│   ├── figure3.png
│   ├── figure4.png
│   └── poster.pdf
├── checkpoint
│   ├── bayesian_svi_avu
│   │   └── bayesian_resnet20_cifar.pth
│   ├── bayesian_svi_avuts
│   │   ├── model_with_temperature_avuts.pth
│   │   └── valid_indices.pth
│   └── deterministic
│       └── resnet20_cifar.pth
├── models
│   ├── bayesian
│   │   ├── resnet.py
│   │   ├── resnet_large.py
│   │   └── simple_cnn.py
│   └── deterministic
│       ├── resnet.py
│       ├── resnet_large.py
│       └── simple_cnn.py
├── requirements.txt
├── scripts
│   ├── test_bayesian_cifar.sh
│   ├── test_bayesian_cifar_auavu.sh
│   ├── test_bayesian_cifar_avu.sh
│   ├── test_bayesian_imagenet_avu.sh
│   ├── test_deterministic_cifar.sh
│   ├── train_bayesian_cifar.sh
│   ├── train_bayesian_cifar_auavu.sh
│   ├── train_bayesian_cifar_avu.sh
│   ├── train_bayesian_imagenet_avu.sh
│   └── train_deterministic_cifar.sh
├── src
│   ├── avuc_loss.py
│   ├── main_bayesian_cifar.py
│   ├── main_bayesian_cifar_auavu.py
│   ├── main_bayesian_cifar_avu.py
│   ├── main_bayesian_imagenet_avu.py
│   ├── main_deterministic_cifar.py
│   └── util.py
└── variational_layers
    ├── conv_variational.py
    └── linear_variational.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Intel Labs
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Accuracy versus Uncertainty Calibration
2 |
3 | > :warning: **DISCONTINUATION OF PROJECT** -
4 | > *This project will no longer be maintained by Intel.
5 | > Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.*
6 | > **Intel no longer accepts patches to this project.**
7 | > *If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.*
8 |
9 |
10 | **[Overview](#overview)**
11 | | **[Requirements](#requirements)**
12 | | **[Example usage](#example-usage)**
13 | | **[Results](#results)**
14 | | **[Paper](https://papers.nips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf)**
15 | | **[Citing](#citing)**
16 |
17 |
18 | Code to accompany the paper [Improving model calibration with accuracy versus uncertainty optimization](https://papers.nips.cc/paper/2020/hash/d3d9446802a44259755d38e6d163e820-Abstract.html) [NeurIPS 2020].
19 |
20 | **Abstract**: Obtaining reliable and accurate quantification of uncertainty estimates from deep neural networks is important in safety critical applications. A well-calibrated model should be accurate when it is certain about its prediction and indicate high uncertainty when it is likely to be inaccurate. Uncertainty calibration is a challenging problem as there is no ground truth available for uncertainty estimates. We propose an optimization method that leverages the relationship between accuracy and uncertainty as an anchor for uncertainty calibration. We introduce a differentiable *accuracy versus uncertainty calibration* (AvUC) loss function as an additional penalty term within loss-calibrated approximate inference framework. AvUC enables a model to learn to provide well-calibrated uncertainties, in addition to improved accuracy. We also demonstrate the same methodology can be extended to post-hoc uncertainty calibration on pretrained models.
21 |
22 | ## Overview
23 | This repository provides code for the *accuracy versus uncertainty calibration* (AvUC) loss and for variational layers (convolutional and linear) that perform mean-field stochastic variational inference (SVI) in Bayesian neural networks. It includes implementations of the SVI and SVI-AvUC methods with ResNet-20 (CIFAR-10) and ResNet-50 (ImageNet) architectures.
24 |
25 | We propose an optimization method that leverages the relationship between accuracy and uncertainty as an anchor for uncertainty calibration in deep neural network classifiers (Bayesian and non-Bayesian).
26 | We propose a differentiable approximation to the *accuracy vs uncertainty* (AvU) measure [[Mukhoti & Gal 2018](https://arxiv.org/abs/1811.12709)] and introduce a trainable AvUC loss function.
27 | In Bayesian decision theory, a task-specific utility function is employed to obtain optimal predictions [Berger 1985].
28 | In this work, the AvU utility function is optimized during training to obtain well-calibrated uncertainties along with improved accuracy.
29 | We use the AvUC loss as an additional utility-dependent penalty term to improve uncertainty calibration, relying on the theoretically sound loss-calibrated
30 | approximate inference framework [[Lacoste-Julien et al. 2011](http://proceedings.mlr.press/v15/lacoste_julien11a.html), [Cobb et al. 2018](https://arxiv.org/abs/1805.03901)]
31 | rooted in Bayesian decision theory.
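
For intuition, the AvU measure that AvUC approximates can be sketched in a few lines of plain Python. This is an illustrative toy version, not the repository's `src/avuc_loss.py`; it assumes predictive entropy as the uncertainty estimate and a fixed uncertainty threshold:

```python
import math

def predictive_entropy(probs):
    """Entropy of one predictive distribution (a list of class probabilities)."""
    return -sum(p * math.log(p + 1e-12) for p in probs)

def avu(probs_batch, labels, unc_threshold):
    """Accuracy vs Uncertainty (AvU) measure [Mukhoti & Gal 2018].

    Counts predictions that are accurate-and-certain (n_ac) or
    inaccurate-and-uncertain (n_iu); a perfectly reliable model
    (AvU = 1) is certain exactly when it is accurate.
    """
    n_ac = n_au = n_ic = n_iu = 0
    for probs, label in zip(probs_batch, labels):
        pred = max(range(len(probs)), key=probs.__getitem__)
        accurate = pred == label
        certain = predictive_entropy(probs) < unc_threshold
        if accurate and certain:
            n_ac += 1          # accurate and certain
        elif accurate:
            n_au += 1          # accurate but uncertain
        elif certain:
            n_ic += 1          # inaccurate but certain
        else:
            n_iu += 1          # inaccurate and uncertain
    return (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)
```

Confident correct predictions and hesitant wrong ones raise AvU; confident mistakes and hesitant correct answers lower it.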
32 | ## Requirements
33 | This code has been tested with PyTorch v1.6.0, torchvision v0.7.0, and Python 3.7.7.
34 |
35 | Datasets:
36 | - CIFAR-10 [[Krizhevsky 2009](https://www.cs.toronto.edu/~kriz/cifar.html)]
37 | - ImageNet [[Deng et al. 2009](http://image-net.org/download)] (download dataset to data/imagenet folder)
38 | - CIFAR10 with corruptions [[Hendrycks & Dietterich 2019](https://github.com/hendrycks/robustness)] for dataset shift evaluation (download [CIFAR-10-C](https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1) dataset to 'data/CIFAR-10-C' folder)
39 | - ImageNet with corruptions [[Hendrycks & Dietterich 2019](https://github.com/hendrycks/robustness)] for dataset shift evaluation (download [ImageNet-C](https://zenodo.org/record/2235448#.X6u3NmhKjOg) dataset to 'data/ImageNet-C' folder)
40 | - SVHN [[Netzer et al. 2011](http://ufldl.stanford.edu/housenumbers/)] for out-of-distribution evaluation (download [file](https://zenodo.org/record/4267245/files/svhn-test.npy) to 'data/SVHN' folder)
41 |
42 | Dependencies:
43 |
44 | - Create conda environment with python=3.7
45 | - Install PyTorch and torchvision packages within conda environment following instructions from [PyTorch install guide](https://pytorch.org/get-started/locally/)
46 | - `conda install -c conda-forge accimage`
47 | - `pip install tensorboard`
48 | - `pip install scikit-learn`
49 |
50 | ## Example usage
51 | We provide [example usages](src/) for SVI-AvUC, SVI and Vanilla (deterministic) methods on CIFAR-10 and ImageNet to train/evaluate the [models](models) along with the implementation of [AvUC loss](src/avuc_loss.py) and [Bayesian layers](variational_layers/).
52 |
53 | ### Training
54 |
55 | To train the Bayesian ResNet-20 model on CIFAR10 with SVI-AvUC method, run this script:
56 | ```train_svi_avuc
57 | sh scripts/train_bayesian_cifar_avu.sh
58 | ```
59 |
60 |
61 | To train the Bayesian ResNet-50 model on ImageNet with SVI-AvUC method, run this script:
62 | ```train_imagenet_svi_avuc
63 | sh scripts/train_bayesian_imagenet_avu.sh
64 | ```
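
During SVI-AvUC training, the AvUC penalty is added to the usual variational objective. The sketch below is a simplified, threshold-free scalar version of the AvUC idea (soft counts weighted by the predicted-class confidence and by `tanh` of predictive entropy); the actual implementation in `src/avuc_loss.py` differs in details such as thresholded groups and batched tensor ops, so treat this as illustration only:

```python
import math

def avuc_penalty(probs_batch, labels):
    """Simplified AvUC-style penalty: small when the model is confident on
    correct predictions and uncertain on wrong ones, large otherwise.

    Soft counts follow the AvUC construction: confidence p of the predicted
    class, uncertainty proxy u = tanh(predictive entropy).
    """
    n_ac = n_au = n_ic = n_iu = 0.0
    for probs, label in zip(probs_batch, labels):
        pred = max(range(len(probs)), key=probs.__getitem__)
        p = probs[pred]                                  # predictive confidence
        entropy = -sum(q * math.log(q + 1e-12) for q in probs)
        u = math.tanh(entropy)                           # squashed uncertainty
        if pred == label:                                # accurate prediction
            n_ac += p * (1.0 - u)                        # accurate & certain
            n_au += p * u                                # accurate & uncertain
        else:                                            # inaccurate prediction
            n_ic += (1.0 - p) * (1.0 - u)                # inaccurate & certain
            n_iu += (1.0 - p) * u                        # inaccurate & uncertain
    return math.log(1.0 + (n_au + n_ic) / (n_ac + n_iu + 1e-12))
```

A confident wrong prediction inflates `n_ic` and therefore the penalty, which is what pushes the model toward higher uncertainty on likely mistakes.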
65 | ### Evaluation
66 |
67 | Our trained models can be downloaded from [here](https://zenodo.org/record/4267245#.X6uNBWhKiUm). Download and untar the [SVI-AvUC ImageNet model](https://zenodo.org/record/4267245/files/svi-avuc-imagenet-trained-model.tar.gz?download=1) to the 'checkpoint/imagenet/bayesian_svi_avu/' folder.
68 |
69 |
70 | To evaluate SVI-AvUC on CIFAR10, CIFAR10-C and SVHN, run the script below. Results (numpy files) will be saved in logs/cifar/bayesian_svi_avu/preds folder.
71 |
72 | ```eval_cifar
73 | sh scripts/test_bayesian_cifar_avu.sh
74 | ```
75 |
76 |
77 | To evaluate SVI-AvUC on ImageNet and ImageNet-C, run the script below. Results (numpy files) will be saved in logs/imagenet/bayesian_svi_avu/preds folder.
78 |
79 | ```eval_imagenet
80 | sh scripts/test_bayesian_imagenet_avu.sh
81 | ```
82 |
83 | ## Results
84 |
85 |
86 | **Model calibration under dataset shift**
87 |
88 | The figure below shows model calibration evaluated with expected calibration error (ECE↓) and expected uncertainty calibration error (UCE↓) on ImageNet under dataset shift. At each shift intensity level, the boxplot summarizes the results across 16 different dataset shift types. A well-calibrated model should
89 | provide low calibration errors even at increased dataset shift.
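
ECE can be computed from held-out predictions as follows. This is a minimal equal-width-bin sketch for illustration, not the evaluation code used in this repository:

```python
def expected_calibration_error(confidences, accuracies, n_bins=10):
    """ECE: partition predictions into equal-width confidence bins and
    average the per-bin |confidence - accuracy| gap, weighted by bin size.

    confidences: max predicted probability per example, in [0, 1].
    accuracies: 1 if the prediction was correct, else 0.
    """
    bins = [[] for _ in range(n_bins)]
    for conf, acc in zip(confidences, accuracies):
        # clamp conf == 1.0 into the last bin
        bins[min(int(conf * n_bins), n_bins - 1)].append((conf, acc))
    total = len(confidences)
    ece = 0.0
    for samples in bins:
        if not samples:
            continue
        avg_conf = sum(c for c, _ in samples) / len(samples)
        avg_acc = sum(a for _, a in samples) / len(samples)
        ece += (len(samples) / total) * abs(avg_conf - avg_acc)
    return ece
```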
90 |
91 |
92 |
93 |
94 |
95 | **Model performance with respect to confidence and uncertainty estimates**
96 |
97 | A reliable model should be accurate when it is certain about its prediction and indicate high uncertainty when it is likely to be inaccurate. We evaluate the quality of confidence and predictive uncertainty estimates using accuracy vs confidence and p(uncertain | inaccurate) plots respectively.
98 | The plots below indicate that SVI-AvUC is more accurate at higher confidence and more uncertain when making inaccurate predictions under distributional shift (ImageNet corrupted with Gaussian blur), compared to other methods.
99 |
100 |
101 |
102 |
103 | **Model reliability towards out-of-distribution data**
104 |
105 | Out-of-distribution evaluation with SVHN data on the model trained with CIFAR10. SVI-AvUC yields fewer examples with high confidence and provides higher predictive uncertainty estimates on out-of-distribution data.
106 |
107 |
108 |
109 |
110 | **Distributional shift detection performance**
111 |
112 | Distributional shift detection using predictive uncertainty estimates.
113 | For dataset shift detection on ImageNet, test data corrupted with Gaussian blur of intensity level 5 is used.
114 | SVHN is used as out-of-distribution (OOD) data for OOD detection on model trained with CIFAR10.
115 |
116 |
117 |
118 |
119 |
120 | Please refer to the paper for more results. The results for the Vanilla, Temp scaling, Ensemble and Dropout methods are computed from the model predictions provided in the [UQ benchmark](https://console.cloud.google.com/storage/browser/uq-benchmark-2019) [[Ovadia et al. 2019](https://arxiv.org/abs/1906.02530)].
121 |
122 | ## Citing
123 |
124 | If you find this useful in your research, please cite as:
125 | ```bibtex
126 | @article{krishnan2020improving,
127 | author={Ranganath Krishnan and Omesh Tickoo},
128 | title={Improving model calibration with accuracy versus uncertainty optimization},
129 | journal={Advances in Neural Information Processing Systems},
130 | year={2020}
131 | }
132 | ```
133 |
--------------------------------------------------------------------------------
/assets/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure1.png
--------------------------------------------------------------------------------
/assets/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure2.png
--------------------------------------------------------------------------------
/assets/figure3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure3.png
--------------------------------------------------------------------------------
/assets/figure4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/figure4.png
--------------------------------------------------------------------------------
/assets/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/assets/poster.pdf
--------------------------------------------------------------------------------
/checkpoint/bayesian_svi_avu/bayesian_resnet20_cifar.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/bayesian_svi_avu/bayesian_resnet20_cifar.pth
--------------------------------------------------------------------------------
/checkpoint/bayesian_svi_avuts/model_with_temperature_avuts.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/bayesian_svi_avuts/model_with_temperature_avuts.pth
--------------------------------------------------------------------------------
/checkpoint/bayesian_svi_avuts/valid_indices.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/bayesian_svi_avuts/valid_indices.pth
--------------------------------------------------------------------------------
/checkpoint/deterministic/resnet20_cifar.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/AVUC/326987d604e71eaa3b6db7e057ca1a34530b632c/checkpoint/deterministic/resnet20_cifar.pth
--------------------------------------------------------------------------------
/models/bayesian/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Bayesian ResNet for CIFAR10.
3 |
4 | ResNet architecture ref:
5 | https://arxiv.org/abs/1512.03385
6 | '''
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.init as init
12 | from variational_layers.conv_variational import Conv2dVariational
13 | from variational_layers.linear_variational import LinearVariational
14 |
15 | __all__ = [
16 | 'ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110'
17 | ]
18 |
19 | prior_mu = 0.0
20 | prior_sigma = 1.0
21 | posterior_mu_init = 0.0
22 | posterior_rho_init = -2.0
23 |
24 |
25 | def _weights_init(m):
26 | classname = m.__class__.__name__
27 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
28 | init.kaiming_normal_(m.weight)
29 |
30 |
31 | class LambdaLayer(nn.Module):
32 | def __init__(self, lambd):
33 | super(LambdaLayer, self).__init__()
34 | self.lambd = lambd
35 |
36 | def forward(self, x):
37 | return self.lambd(x)
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, in_planes, planes, stride=1, option='A'):
44 | super(BasicBlock, self).__init__()
45 | self.conv1 = Conv2dVariational(prior_mu,
46 | prior_sigma,
47 | posterior_mu_init,
48 | posterior_rho_init,
49 | in_planes,
50 | planes,
51 | kernel_size=3,
52 | stride=stride,
53 | padding=1,
54 | bias=False)
55 | self.bn1 = nn.BatchNorm2d(planes)
56 | self.conv2 = Conv2dVariational(prior_mu,
57 | prior_sigma,
58 | posterior_mu_init,
59 | posterior_rho_init,
60 | planes,
61 | planes,
62 | kernel_size=3,
63 | stride=1,
64 | padding=1,
65 | bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 |
68 | self.shortcut = nn.Sequential()
69 | if stride != 1 or in_planes != planes:
70 | if option == 'A':
71 | """
 72 |                 For CIFAR10, the ResNet paper uses option A.
73 | """
74 | self.shortcut = LambdaLayer(lambda x: F.pad(
75 | x[:, :, ::2, ::2],
76 | (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
77 | elif option == 'B':
78 | self.shortcut = nn.Sequential(
79 | Conv2dVariational(prior_mu,
80 | prior_sigma,
81 | posterior_mu_init,
82 | posterior_rho_init,
83 | in_planes,
84 | self.expansion * planes,
85 | kernel_size=1,
86 | stride=stride,
87 | bias=False),
88 | nn.BatchNorm2d(self.expansion * planes))
89 |
90 | def forward(self, x):
91 | kl_sum = 0
92 | out, kl = self.conv1(x)
93 | kl_sum += kl
94 | out = self.bn1(out)
95 | out = F.relu(out)
96 | out, kl = self.conv2(out)
97 | kl_sum += kl
98 | out = self.bn2(out)
99 | out += self.shortcut(x)
100 | out = F.relu(out)
101 | return out, kl_sum
102 |
103 |
104 | class ResNet(nn.Module):
105 | def __init__(self, block, num_blocks, num_classes=10):
106 | super(ResNet, self).__init__()
107 | self.in_planes = 16
108 |
109 | self.conv1 = Conv2dVariational(prior_mu,
110 | prior_sigma,
111 | posterior_mu_init,
112 | posterior_rho_init,
113 | 3,
114 | 16,
115 | kernel_size=3,
116 | stride=1,
117 | padding=1,
118 | bias=False)
119 | self.bn1 = nn.BatchNorm2d(16)
120 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
121 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
122 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
123 | self.linear = LinearVariational(prior_mu, prior_sigma,
124 | posterior_mu_init, posterior_rho_init,
125 | 64, num_classes)
126 |
127 | self.apply(_weights_init)
128 |
129 | def _make_layer(self, block, planes, num_blocks, stride):
130 | strides = [stride] + [1] * (num_blocks - 1)
131 | layers = []
132 | for stride in strides:
133 | layers.append(block(self.in_planes, planes, stride))
134 | self.in_planes = planes * block.expansion
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | kl_sum = 0
140 | out, kl = self.conv1(x)
141 | kl_sum += kl
142 | out = self.bn1(out)
143 | out = F.relu(out)
144 | for l in self.layer1:
145 | out, kl = l(out)
146 | kl_sum += kl
147 | for l in self.layer2:
148 | out, kl = l(out)
149 | kl_sum += kl
150 | for l in self.layer3:
151 | out, kl = l(out)
152 | kl_sum += kl
153 |
154 | out = F.avg_pool2d(out, out.size()[3])
155 | out = out.view(out.size(0), -1)
156 | out, kl = self.linear(out)
157 | kl_sum += kl
158 | return out, kl_sum
159 |
160 |
161 | def resnet20():
162 | return ResNet(BasicBlock, [3, 3, 3])
163 |
164 |
165 | def resnet32():
166 | return ResNet(BasicBlock, [5, 5, 5])
167 |
168 |
169 | def resnet44():
170 | return ResNet(BasicBlock, [7, 7, 7])
171 |
172 |
173 | def resnet56():
174 | return ResNet(BasicBlock, [9, 9, 9])
175 |
176 |
177 | def resnet110():
178 | return ResNet(BasicBlock, [18, 18, 18])
179 |
180 |
181 | def test(net):
182 | import numpy as np
183 | total_params = 0
184 |
185 | for x in filter(lambda p: p.requires_grad, net.parameters()):
186 | total_params += np.prod(x.data.numpy().shape)
187 | print("Total number of params", total_params)
188 | print(
189 | "Total layers",
190 | len(
191 | list(
192 | filter(lambda p: p.requires_grad and len(p.data.size()) > 1,
193 | net.parameters()))))
194 |
195 |
196 | if __name__ == "__main__":
197 | for net_name in __all__:
198 | if net_name.startswith('resnet'):
199 | print(net_name)
200 | test(globals()[net_name]())
201 | print()
202 |
--------------------------------------------------------------------------------
/models/bayesian/resnet_large.py:
--------------------------------------------------------------------------------
1 | # Bayesian ResNet for ImageNet
2 | # ResNet architecture ref:
3 | # https://arxiv.org/abs/1512.03385
4 | # Code adapted from torchvision package to build Bayesian model from deterministic model
5 |
6 | import torch.nn as nn
7 | import math
8 | import torch.utils.model_zoo as model_zoo
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.init as init
12 | from variational_layers.conv_variational import Conv2dVariational
13 | from variational_layers.linear_variational import LinearVariational
14 |
15 | __all__ = [
16 | 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
17 | ]
18 |
19 | prior_mu = 0.0
20 | prior_sigma = 1.0
21 | posterior_mu_init = 0.0
22 | posterior_rho_init = -9.0
23 |
24 | model_urls = {
25 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
26 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
27 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
28 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
29 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
30 | }
31 |
32 |
33 | def conv3x3(in_planes, out_planes, stride=1):
34 | """3x3 convolution with padding"""
35 | return Conv2dVariational(prior_mu,
36 | prior_sigma,
37 | posterior_mu_init,
38 | posterior_rho_init,
39 | in_planes,
40 | out_planes,
41 | kernel_size=3,
42 | stride=stride,
43 | padding=1,
44 | bias=False)
45 |
46 |
47 | class BasicBlock(nn.Module):
48 | expansion = 1
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super(BasicBlock, self).__init__()
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = nn.BatchNorm2d(planes)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | residual = x
62 | kl_sum = 0
63 | out, kl = self.conv1(x)
64 | kl_sum += kl
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out, kl = self.conv2(out)
69 | kl_sum += kl
70 | out = self.bn2(out)
71 |
72 | if self.downsample is not None:
73 | residual = self.downsample(x)
74 |
75 | out += residual
76 | out = self.relu(out)
77 |
78 | return out, kl_sum
79 |
80 |
81 | class Bottleneck(nn.Module):
82 | expansion = 4
83 |
84 | def __init__(self, inplanes, planes, stride=1, downsample=None):
85 | super(Bottleneck, self).__init__()
86 | self.conv1 = Conv2dVariational(prior_mu,
87 | prior_sigma,
88 | posterior_mu_init,
89 | posterior_rho_init,
90 | inplanes,
91 | planes,
92 | kernel_size=1,
93 | bias=False)
94 | self.bn1 = nn.BatchNorm2d(planes)
95 | self.conv2 = Conv2dVariational(prior_mu,
96 | prior_sigma,
97 | posterior_mu_init,
98 | posterior_rho_init,
99 | planes,
100 | planes,
101 | kernel_size=3,
102 | stride=stride,
103 | padding=1,
104 | bias=False)
105 | self.bn2 = nn.BatchNorm2d(planes)
106 | self.conv3 = Conv2dVariational(prior_mu,
107 | prior_sigma,
108 | posterior_mu_init,
109 | posterior_rho_init,
110 | planes,
111 | planes * 4,
112 | kernel_size=1,
113 | bias=False)
114 | self.bn3 = nn.BatchNorm2d(planes * 4)
115 | self.relu = nn.ReLU(inplace=True)
116 | self.downsample = downsample
117 | self.stride = stride
118 |
119 | def forward(self, x):
120 | residual = x
121 | kl_sum = 0
122 | out, kl = self.conv1(x)
123 | kl_sum += kl
124 | out = self.bn1(out)
125 | out = self.relu(out)
126 |
127 | out, kl = self.conv2(out)
128 | kl_sum += kl
129 | out = self.bn2(out)
130 | out = self.relu(out)
131 |
132 | out, kl = self.conv3(out)
133 | kl_sum += kl
134 | out = self.bn3(out)
135 |
136 | if self.downsample is not None:
137 | residual = self.downsample(x)
138 |
139 | out += residual
140 | out = self.relu(out)
141 |
142 | return out, kl_sum
143 |
144 |
145 | class ResNet(nn.Module):
146 | def __init__(self, block, layers, num_classes=1000):
147 | self.inplanes = 64
148 | super(ResNet, self).__init__()
149 | self.conv1 = Conv2dVariational(prior_mu,
150 | prior_sigma,
151 | posterior_mu_init,
152 | posterior_rho_init,
153 | 3,
154 | 64,
155 | kernel_size=7,
156 | stride=2,
157 | padding=3,
158 | bias=False)
159 | self.bn1 = nn.BatchNorm2d(64)
160 | self.relu = nn.ReLU(inplace=True)
161 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
162 | self.layer1 = self._make_layer(block, 64, layers[0])
163 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
164 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
165 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
166 | self.avgpool = nn.AvgPool2d(7, stride=1)
167 | self.fc = LinearVariational(prior_mu, prior_sigma, posterior_mu_init,
168 | posterior_rho_init, 512 * block.expansion,
169 | num_classes)
170 |
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | elif isinstance(m, nn.BatchNorm2d):
176 | m.weight.data.fill_(1)
177 | m.bias.data.zero_()
178 |
179 | def _make_layer(self, block, planes, blocks, stride=1):
180 | downsample = None
181 | if stride != 1 or self.inplanes != planes * block.expansion:
182 | downsample = nn.Sequential(
183 | nn.Conv2d(self.inplanes,
184 | planes * block.expansion,
185 | kernel_size=1,
186 | stride=stride,
187 | bias=False),
188 | nn.BatchNorm2d(planes * block.expansion),
189 | )
190 |
191 | layers = []
192 | layers.append(block(self.inplanes, planes, stride, downsample))
193 | self.inplanes = planes * block.expansion
194 | for i in range(1, blocks):
195 | layers.append(block(self.inplanes, planes))
196 |
197 | return nn.Sequential(*layers)
198 |
199 | def forward(self, x):
200 | kl_sum = 0
201 | x, kl = self.conv1(x)
202 | kl_sum += kl
203 | x = self.bn1(x)
204 | x = self.relu(x)
205 | x = self.maxpool(x)
206 |
207 | for layer in self.layer1:
208 | if 'Variational' in str(layer):
209 | x, kl = layer(x)
210 |                 if kl is not None:
211 | kl_sum += kl
212 | else:
213 | x = layer(x)
214 |
215 | for layer in self.layer2:
216 | if 'Variational' in str(layer):
217 | x, kl = layer(x)
218 |                 if kl is not None:
219 | kl_sum += kl
220 | else:
221 | x = layer(x)
222 |
223 | for layer in self.layer3:
224 | if 'Variational' in str(layer):
225 | x, kl = layer(x)
226 |                 if kl is not None:
227 | kl_sum += kl
228 | else:
229 | x = layer(x)
230 |
231 | for layer in self.layer4:
232 | if 'Variational' in str(layer):
233 | x, kl = layer(x)
234 |                 if kl is not None:
235 | kl_sum += kl
236 | else:
237 | x = layer(x)
238 |
239 | x = self.avgpool(x)
240 | x = x.view(x.size(0), -1)
241 | x, kl = self.fc(x)
242 | kl_sum += kl
243 |
244 | return x, kl_sum
245 |
246 |
247 | def resnet18(pretrained=False, **kwargs):
248 | """Constructs a ResNet-18 model.
249 |
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | """
253 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
254 | if pretrained:
255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
256 | return model
257 |
258 |
259 | def resnet34(pretrained=False, **kwargs):
260 | """Constructs a ResNet-34 model.
261 |
262 | Args:
263 | pretrained (bool): If True, returns a model pre-trained on ImageNet
264 | """
265 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
266 | if pretrained:
267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
268 | return model
269 |
270 |
271 | def resnet50(pretrained=False, **kwargs):
272 | """Constructs a ResNet-50 model.
273 |
274 | Args:
275 | pretrained (bool): If True, returns a model pre-trained on ImageNet
276 | """
277 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
278 | if pretrained:
279 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
280 | return model
281 |
282 |
283 | def resnet101(pretrained=False, **kwargs):
284 | """Constructs a ResNet-101 model.
285 |
286 | Args:
287 | pretrained (bool): If True, returns a model pre-trained on ImageNet
288 | """
289 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
290 | if pretrained:
291 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
292 | return model
293 |
294 |
295 | def resnet152(pretrained=False, **kwargs):
296 | """Constructs a ResNet-152 model.
297 |
298 | Args:
299 | pretrained (bool): If True, returns a model pre-trained on ImageNet
300 | """
301 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
302 | if pretrained:
303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
304 | return model
305 |
--------------------------------------------------------------------------------
/models/bayesian/simple_cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from variational_layers.conv_variational import Conv2dVariational
8 | from variational_layers.linear_variational import LinearVariational
9 |
10 | prior_mu = 0.0
11 | prior_sigma = 1.0
12 | posterior_mu_init = 0.0
13 | posterior_rho_init = -3.0
14 |
15 |
16 | class SCNN(nn.Module):
17 | def __init__(self):
18 | super(SCNN, self).__init__()
19 | self.conv1 = Conv2dVariational(prior_mu, prior_sigma,
20 | posterior_mu_init, posterior_rho_init,
21 | 1, 32, 3, 1)
22 | self.conv2 = Conv2dVariational(prior_mu, prior_sigma,
23 | posterior_mu_init, posterior_rho_init,
24 | 32, 64, 3, 1)
25 | self.dropout1 = nn.Dropout2d(0.1)
26 | self.dropout2 = nn.Dropout2d(0.1)
27 | self.fc1 = LinearVariational(prior_mu, prior_sigma, posterior_mu_init,
28 | posterior_rho_init, 9216, 128)
29 | self.fc2 = LinearVariational(prior_mu, prior_sigma, posterior_mu_init,
30 | posterior_rho_init, 128, 10)
31 |
32 | def forward(self, x):
33 | kl_sum = 0
34 | x, kl = self.conv1(x)
35 | kl_sum += kl
36 | x = F.relu(x)
37 | x, kl = self.conv2(x)
38 | kl_sum += kl
39 | x = F.relu(x)
40 | x = F.max_pool2d(x, 2)
41 | x = self.dropout1(x)
42 | x = torch.flatten(x, 1)
43 | x, kl = self.fc1(x)
44 | kl_sum += kl
45 | x = F.relu(x)
46 | x = self.dropout2(x)
47 | x, kl = self.fc2(x)
48 | kl_sum += kl
49 | output = F.log_softmax(x, dim=1)
50 |         return output, kl_sum
51 |
--------------------------------------------------------------------------------
/models/deterministic/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | ResNet for CIFAR10.
3 |
4 | Ref for ResNet architecture:
5 | https://arxiv.org/abs/1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.nn.init as init
11 |
12 | __all__ = [
13 | 'ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
14 | 'resnet1202'
15 | ]
16 |
17 |
18 | def _weights_init(m):
19 | classname = m.__class__.__name__
20 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
21 | init.kaiming_normal_(m.weight)
22 |
23 |
24 | class LambdaLayer(nn.Module):
25 | def __init__(self, lambd):
26 | super(LambdaLayer, self).__init__()
27 | self.lambd = lambd
28 |
29 | def forward(self, x):
30 | return self.lambd(x)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, in_planes, planes, stride=1, option='A'):
37 | super(BasicBlock, self).__init__()
38 | self.conv1 = nn.Conv2d(in_planes,
39 | planes,
40 | kernel_size=3,
41 | stride=stride,
42 | padding=1,
43 | bias=False)
44 | self.bn1 = nn.BatchNorm2d(planes)
45 | self.conv2 = nn.Conv2d(planes,
46 | planes,
47 | kernel_size=3,
48 | stride=1,
49 | padding=1,
50 | bias=False)
51 | self.bn2 = nn.BatchNorm2d(planes)
52 |
53 | self.shortcut = nn.Sequential()
54 | if stride != 1 or in_planes != planes:
55 | if option == 'A':
56 | self.shortcut = LambdaLayer(lambda x: F.pad(
57 | x[:, :, ::2, ::2],
58 | (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
59 | elif option == 'B':
60 | self.shortcut = nn.Sequential(
61 | nn.Conv2d(in_planes,
62 | self.expansion * planes,
63 | kernel_size=1,
64 | stride=stride,
65 | bias=False),
66 | nn.BatchNorm2d(self.expansion * planes))
67 |
68 | def forward(self, x):
69 | out = F.relu(self.bn1(self.conv1(x)))
70 | out = self.bn2(self.conv2(out))
71 | out += self.shortcut(x)
72 | out = F.relu(out)
73 | return out
74 |
75 |
76 | class ResNet(nn.Module):
77 | def __init__(self, block, num_blocks, num_classes=10):
78 | super(ResNet, self).__init__()
79 | self.in_planes = 16
80 |
81 | self.conv1 = nn.Conv2d(3,
82 | 16,
83 | kernel_size=3,
84 | stride=1,
85 | padding=1,
86 | bias=False)
87 | self.bn1 = nn.BatchNorm2d(16)
88 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
89 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
90 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
91 | self.linear = nn.Linear(64, num_classes)
92 |
93 | self.apply(_weights_init)
94 |
95 | def _make_layer(self, block, planes, num_blocks, stride):
96 | strides = [stride] + [1] * (num_blocks - 1)
97 | layers = []
98 | for stride in strides:
99 | layers.append(block(self.in_planes, planes, stride))
100 | self.in_planes = planes * block.expansion
101 |
102 | return nn.Sequential(*layers)
103 |
104 | def forward(self, x):
105 | out = F.relu(self.bn1(self.conv1(x)))
106 | out = self.layer1(out)
107 | out = self.layer2(out)
108 | out = self.layer3(out)
109 | out = F.avg_pool2d(out, out.size()[3])
110 | out = out.view(out.size(0), -1)
111 | out = self.linear(out)
112 | return out
113 |
114 |
115 | def resnet20():
116 | return ResNet(BasicBlock, [3, 3, 3])
117 |
118 |
119 | def resnet32():
120 | return ResNet(BasicBlock, [5, 5, 5])
121 |
122 |
123 | def resnet44():
124 | return ResNet(BasicBlock, [7, 7, 7])
125 |
126 |
127 | def resnet56():
128 | return ResNet(BasicBlock, [9, 9, 9])
129 |
130 |
131 | def resnet110():
132 | return ResNet(BasicBlock, [18, 18, 18])
133 |
134 |
135 | def resnet1202():
136 | return ResNet(BasicBlock, [200, 200, 200])
137 |
138 |
135 | def test(net):
136 | import numpy as np
137 | total_params = 0
138 |
139 | for x in filter(lambda p: p.requires_grad, net.parameters()):
140 | total_params += np.prod(x.data.numpy().shape)
141 | print("Total number of params", total_params)
142 | print(
143 | "Total layers",
144 | len(
145 | list(
146 | filter(lambda p: p.requires_grad and len(p.data.size()) > 1,
147 | net.parameters()))))
148 |
149 |
150 | if __name__ == "__main__":
151 | for net_name in __all__:
152 | if net_name.startswith('resnet'):
153 | print(net_name)
154 | test(globals()[net_name]())
155 | print()
156 |
--------------------------------------------------------------------------------
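The constructors above follow the 6n+2 depth rule from the ResNet paper for CIFAR: with `num_blocks=[n, n, n]`, each of the three stages stacks n BasicBlocks of two convolutions, plus the stem convolution and the final linear classifier. A quick sanity check (the helper name is mine):

```python
def cifar_resnet_depth(n):
    """Depth of a CIFAR ResNet built from num_blocks=[n, n, n] BasicBlocks:
    3 stages * n blocks * 2 convs each, plus the stem conv and the classifier."""
    return 6 * n + 2

# resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202
print([cifar_resnet_depth(n) for n in (3, 5, 7, 9, 18, 200)])
# [20, 32, 44, 56, 110, 1202]
```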
/models/deterministic/resnet_large.py:
--------------------------------------------------------------------------------
1 | # deterministic model from torchvision package
2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 |
4 | import torch.nn as nn
5 | import math
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | __all__ = [
9 | 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
10 | ]
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 |
21 | def conv3x3(in_planes, out_planes, stride=1):
22 | """3x3 convolution with padding"""
23 | return nn.Conv2d(in_planes,
24 | out_planes,
25 | kernel_size=3,
26 | stride=stride,
27 | padding=1,
28 | bias=False)
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | expansion = 1
33 |
34 | def __init__(self, inplanes, planes, stride=1, downsample=None):
35 | super(BasicBlock, self).__init__()
36 | self.conv1 = conv3x3(inplanes, planes, stride)
37 | self.bn1 = nn.BatchNorm2d(planes)
38 | self.relu = nn.ReLU(inplace=True)
39 | self.conv2 = conv3x3(planes, planes)
40 | self.bn2 = nn.BatchNorm2d(planes)
41 | self.downsample = downsample
42 | self.stride = stride
43 |
44 | def forward(self, x):
45 | residual = x
46 |
47 | out = self.conv1(x)
48 | out = self.bn1(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv2(out)
52 | out = self.bn2(out)
53 |
54 | if self.downsample is not None:
55 | residual = self.downsample(x)
56 |
57 | out += residual
58 | out = self.relu(out)
59 |
60 | return out
61 |
62 |
63 | class Bottleneck(nn.Module):
64 | expansion = 4
65 |
66 | def __init__(self, inplanes, planes, stride=1, downsample=None):
67 | super(Bottleneck, self).__init__()
68 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
69 | self.bn1 = nn.BatchNorm2d(planes)
70 | self.conv2 = nn.Conv2d(planes,
71 | planes,
72 | kernel_size=3,
73 | stride=stride,
74 | padding=1,
75 | bias=False)
76 | self.bn2 = nn.BatchNorm2d(planes)
77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
78 | self.bn3 = nn.BatchNorm2d(planes * 4)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.relu(out)
102 |
103 | return out
104 |
105 |
106 | class ResNet(nn.Module):
107 | def __init__(self, block, layers, num_classes=1000):
108 | self.inplanes = 64
109 | super(ResNet, self).__init__()
110 | self.conv1 = nn.Conv2d(3,
111 | 64,
112 | kernel_size=7,
113 | stride=2,
114 | padding=3,
115 | bias=False)
116 | self.bn1 = nn.BatchNorm2d(64)
117 | self.relu = nn.ReLU(inplace=True)
118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
119 | self.layer1 = self._make_layer(block, 64, layers[0])
120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
122 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
123 | self.avgpool = nn.AvgPool2d(7, stride=1)
124 | self.fc = nn.Linear(512 * block.expansion, num_classes)
125 |
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
129 | m.weight.data.normal_(0, math.sqrt(2. / n))
130 | elif isinstance(m, nn.BatchNorm2d):
131 | m.weight.data.fill_(1)
132 | m.bias.data.zero_()
133 |
134 | def _make_layer(self, block, planes, blocks, stride=1):
135 | downsample = None
136 | if stride != 1 or self.inplanes != planes * block.expansion:
137 | downsample = nn.Sequential(
138 | nn.Conv2d(self.inplanes,
139 | planes * block.expansion,
140 | kernel_size=1,
141 | stride=stride,
142 | bias=False),
143 | nn.BatchNorm2d(planes * block.expansion),
144 | )
145 |
146 | layers = []
147 | layers.append(block(self.inplanes, planes, stride, downsample))
148 | self.inplanes = planes * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.inplanes, planes))
151 |
152 | return nn.Sequential(*layers)
153 |
154 | def forward(self, x):
155 | x = self.conv1(x)
156 | x = self.bn1(x)
157 | x = self.relu(x)
158 | x = self.maxpool(x)
159 |
160 | x = self.layer1(x)
161 | x = self.layer2(x)
162 | x = self.layer3(x)
163 | x = self.layer4(x)
164 |
165 | x = self.avgpool(x)
166 | x = x.view(x.size(0), -1)
167 | x = self.fc(x)
168 |
169 | return x
170 |
171 |
172 | def resnet18(pretrained=False, **kwargs):
173 | """Constructs a ResNet-18 model.
174 |
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | """
178 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
179 | if pretrained:
180 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
181 | return model
182 |
183 |
184 | def resnet34(pretrained=False, **kwargs):
185 | """Constructs a ResNet-34 model.
186 |
187 | Args:
188 | pretrained (bool): If True, returns a model pre-trained on ImageNet
189 | """
190 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
191 | if pretrained:
192 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
193 | return model
194 |
195 |
196 | def resnet50(pretrained=False, **kwargs):
197 | """Constructs a ResNet-50 model.
198 |
199 | Args:
200 | pretrained (bool): If True, returns a model pre-trained on ImageNet
201 | """
202 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
203 | if pretrained:
204 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
205 | return model
206 |
207 |
208 | def resnet101(pretrained=False, **kwargs):
209 | """Constructs a ResNet-101 model.
210 |
211 | Args:
212 | pretrained (bool): If True, returns a model pre-trained on ImageNet
213 | """
214 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
215 | if pretrained:
216 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
217 | return model
218 |
219 |
220 | def resnet152(pretrained=False, **kwargs):
221 | """Constructs a ResNet-152 model.
222 |
223 | Args:
224 | pretrained (bool): If True, returns a model pre-trained on ImageNet
225 | """
226 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
227 | if pretrained:
228 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
229 | return model
230 |
--------------------------------------------------------------------------------
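The ImageNet variants above get their names from a weighted-layer count: the stem convolution, every block's convolutions (two for BasicBlock, three for Bottleneck), and the final fully connected layer. A small arithmetic check (helper name is mine):

```python
def imagenet_resnet_depth(layers, convs_per_block):
    """Weighted-layer count: stem conv + all block convs + final fc.
    BasicBlock contributes 2 convs per block, Bottleneck contributes 3."""
    return 1 + sum(layers) * convs_per_block + 1

print(imagenet_resnet_depth([2, 2, 2, 2], 2))   # 18  (resnet18, BasicBlock)
print(imagenet_resnet_depth([3, 4, 6, 3], 3))   # 50  (resnet50, Bottleneck)
print(imagenet_resnet_depth([3, 8, 36, 3], 3))  # 152 (resnet152, Bottleneck)
```

Note that the 1x1 shortcut convolutions created by `_make_layer` are, by convention, not counted toward the depth.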
/models/deterministic/simple_cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class SCNN(nn.Module):
9 | def __init__(self):
10 | super(SCNN, self).__init__()
11 | self.conv1 = nn.Conv2d(1, 32, 3, 1)
12 | self.conv2 = nn.Conv2d(32, 64, 3, 1)
13 | self.dropout1 = nn.Dropout2d(0.23)
14 | self.dropout2 = nn.Dropout2d(0.23)
15 | self.fc1 = nn.Linear(9216, 128)
16 | self.fc2 = nn.Linear(128, 10)
17 |
18 | def forward(self, x):
19 | x = self.conv1(x)
20 | x = F.relu(x)
21 | x = self.conv2(x)
22 | x = F.relu(x)
23 | x = F.max_pool2d(x, 2)
24 | x = self.dropout1(x)
25 | x = torch.flatten(x, 1)
26 | x = self.fc1(x)
27 | x = F.relu(x)
28 | x = self.dropout2(x)
29 | x = self.fc2(x)
30 | output = F.log_softmax(x, dim=1)
31 | return output
32 |
--------------------------------------------------------------------------------
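The hard-coded 9216 input size of `fc1` implies a 28x28 single-channel (MNIST-style) input: two unpadded 3x3 convolutions shrink the spatial size from 28 to 24, the 2x2 max-pool halves it to 12, and 64 output channels give 64 * 12 * 12 = 9216. A quick trace (the helper name is mine):

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output spatial size of a convolution (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28              # MNIST-style input
size = conv_out(size)  # conv1: 28 -> 26
size = conv_out(size)  # conv2: 26 -> 24
size = size // 2       # 2x2 max-pool: 24 -> 12
flat = 64 * size * size  # 64 channels after conv2
print(flat)  # 9216
```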
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | scikit-learn
5 | tensorboard
6 |
--------------------------------------------------------------------------------
/scripts/test_bayesian_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='test'
4 | batch_size=10000
5 | num_monte_carlo=128
6 |
7 | python src/main_bayesian_cifar.py --arch=$model --mode=$mode --batch-size=$batch_size --num_monte_carlo=$num_monte_carlo
8 |
--------------------------------------------------------------------------------
/scripts/test_bayesian_cifar_auavu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='test'
4 | batch_size=10000
5 | num_monte_carlo=128
6 |
7 | python src/main_bayesian_cifar_auavu.py --arch=$model --mode=$mode --batch-size=$batch_size --num_monte_carlo=$num_monte_carlo
8 |
--------------------------------------------------------------------------------
/scripts/test_bayesian_cifar_avu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='test'
4 | batch_size=10000
5 | num_monte_carlo=128
6 |
7 | python src/main_bayesian_cifar_avu.py --arch=$model --mode=$mode --batch-size=$batch_size --num_monte_carlo=$num_monte_carlo
8 |
--------------------------------------------------------------------------------
/scripts/test_bayesian_imagenet_avu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet50
3 | mode='test'
4 | val_batch_size=2500
5 | num_monte_carlo=128
6 |
7 | python src/main_bayesian_imagenet_avu.py data/imagenet --arch=$model --mode=$mode --val_batch_size=$val_batch_size --num_monte_carlo=$num_monte_carlo
8 |
--------------------------------------------------------------------------------
/scripts/test_deterministic_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='test'
4 | batch_size=10000
5 |
6 | python src/main_deterministic_cifar.py --arch=$model --mode=$mode --batch-size=$batch_size
7 |
--------------------------------------------------------------------------------
/scripts/train_bayesian_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='train'
4 | batch_size=107
5 | lr=0.001189
6 |
7 | python src/main_bayesian_cifar.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size
8 |
--------------------------------------------------------------------------------
/scripts/train_bayesian_cifar_auavu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='train'
4 | batch_size=107
5 | lr=0.001189
6 | moped=False
7 |
8 | python src/main_bayesian_cifar_auavu.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size --moped=$moped
9 |
--------------------------------------------------------------------------------
/scripts/train_bayesian_cifar_avu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='train'
4 | batch_size=107
5 | lr=0.001189
6 | moped=False
7 |
8 | python src/main_bayesian_cifar_avu.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size --moped=$moped
9 |
--------------------------------------------------------------------------------
/scripts/train_bayesian_imagenet_avu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet50
3 | mode='train'
4 | batch_size=384
5 | lr=0.001
6 | moped=True
7 |
8 | python -u src/main_bayesian_imagenet_avu.py data/imagenet --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size --moped=$moped
9 |
--------------------------------------------------------------------------------
/scripts/train_deterministic_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet20
3 | mode='train'
4 | batch_size=512
5 | lr=0.001
6 |
7 | python src/main_deterministic_cifar.py --lr=$lr --arch=$model --mode=$mode --batch-size=$batch_size
8 |
--------------------------------------------------------------------------------
/src/avuc_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2020 Intel Corporation
2 | #
3 | # BSD-3-Clause License
4 | #
5 | # Redistribution and use in source and binary forms, with or without modification,
6 | # are permitted provided that the following conditions are met:
7 | # 1. Redistributions of source code must retain the above copyright notice,
8 | # this list of conditions and the following disclaimer.
9 | # 2. Redistributions in binary form must reproduce the above copyright notice,
10 | # this list of conditions and the following disclaimer in the documentation
11 | # and/or other materials provided with the distribution.
12 | # 3. Neither the name of the copyright holder nor the names of its contributors
13 | # may be used to endorse or promote products derived from this software
14 | # without specific prior written permission.
15 | #
16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | #
28 | # AvULoss -> compute accuracy versus uncertainty calibration loss
29 | # AUAvULoss -> compute accuracy versus uncertainty calibration loss
30 | # without uncertainty threshold
31 | # accuracy_vs_uncertainty -> compute AvU metric
32 | # eval_avu -> get AvU scores at different uncertainty thresholds
33 | # predictive_entropy -> compute predictive uncertainty of the model
34 | # mutual_information -> compute model uncertainty of the model
35 | #
36 | # @authors: Ranganath Krishnan
37 | #
38 | # ===============================================================================================
39 |
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 | import torch.nn.functional as F
44 | import torch
45 | from torch import nn
46 | import numpy as np
47 | from sklearn.metrics import auc
48 |
49 |
50 | class AUAvULoss(nn.Module):
51 | """
52 | Calculates the area under the Accuracy vs Uncertainty (AvU) curve as a loss,
53 | without requiring a fixed uncertainty threshold. The input to this loss is
54 | logits from Monte Carlo sampling of the model, true labels, and the type of
55 | uncertainty to be used [0: predictive uncertainty (default); 1: model uncertainty]
56 | """
57 | def __init__(self, beta=1):
58 | super(AUAvULoss, self).__init__()
59 | self.beta = beta
60 | self.eps = 1e-10
61 |
62 | def entropy(self, prob):
63 | return -1 * torch.sum(prob * torch.log(prob + self.eps), dim=-1)
64 |
65 | def expected_entropy(self, mc_preds):
66 | return torch.mean(self.entropy(mc_preds), dim=0)
67 |
68 | def predictive_uncertainty(self, mc_preds):
69 | """
70 | Compute the entropy of the mean of the predictive distribution
71 | obtained from Monte Carlo sampling.
72 | """
73 | return self.entropy(torch.mean(mc_preds, dim=0))
74 |
75 | def model_uncertainty(self, mc_preds):
76 | """
77 | Compute the difference between the entropy of the mean of the
78 | predictive distribution and the mean of the entropy.
79 | """
80 | return self.entropy(torch.mean(
81 | mc_preds, dim=0)) - self.expected_entropy(mc_preds)
82 |
83 | def auc_avu(self, logits, labels, unc):
84 | """ returns the area under the AvU curve computed over uncertainty thresholds """
85 | th_list = np.linspace(0, 1, 21)
86 | umin = torch.min(unc)
87 | umax = torch.max(unc)
88 | avu_list = []
89 | unc_list = []
90 |
91 | probs = F.softmax(logits, dim=1)
92 | confidences, predictions = torch.max(probs, 1)
93 |
94 | auc_avu = torch.ones(1, device=labels.device)
95 | auc_avu.requires_grad_(True)
96 |
97 | for t in th_list:
98 | unc_th = umin + (torch.tensor(t, device=labels.device) * (umax - umin))
99 | n_ac = torch.zeros(
100 | 1,
101 | device=labels.device) # number of samples accurate and certain
102 | n_ic = torch.zeros(1, device=labels.device
103 | ) # number of samples inaccurate and certain
104 | n_au = torch.zeros(1, device=labels.device
105 | ) # number of samples accurate and uncertain
106 | n_iu = torch.zeros(1, device=labels.device
107 | ) # number of samples inaccurate and uncertain
108 |
109 | for i in range(len(labels)):
110 | if ((labels[i].item() == predictions[i].item())
111 | and unc[i].item() <= unc_th.item()):
112 | """ accurate and certain """
113 | n_ac += confidences[i] * (1 - torch.tanh(unc[i]))
114 | elif ((labels[i].item() == predictions[i].item())
115 | and unc[i].item() > unc_th.item()):
116 | """ accurate and uncertain """
117 | n_au += confidences[i] * torch.tanh(unc[i])
118 | elif ((labels[i].item() != predictions[i].item())
119 | and unc[i].item() <= unc_th.item()):
120 | """ inaccurate and certain """
121 | n_ic += (1 - confidences[i]) * (1 - torch.tanh(unc[i]))
122 | elif ((labels[i].item() != predictions[i].item())
123 | and unc[i].item() > unc_th.item()):
124 | """ inaccurate and uncertain """
125 | n_iu += (1 - confidences[i]) * torch.tanh(unc[i])
126 |
127 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-10)
128 | avu_list.append(AvU.data.cpu().numpy())
129 | unc_list.append(unc_th)
130 |
131 | auc_avu = auc(th_list, avu_list)
132 | return auc_avu
133 |
134 | def accuracy_vs_uncertainty(self, prediction, true_label, uncertainty,
135 | optimal_threshold):
136 | n_ac = torch.zeros(
137 | 1,
138 | device=true_label.device) # number of samples accurate and certain
139 | n_ic = torch.zeros(1, device=true_label.device
140 | ) # number of samples inaccurate and certain
141 | n_au = torch.zeros(1, device=true_label.device
142 | ) # number of samples accurate and uncertain
143 | n_iu = torch.zeros(1, device=true_label.device
144 | ) # number of samples inaccurate and uncertain
145 |
146 | avu = torch.ones(1, device=true_label.device)
147 | avu.requires_grad_(True)
148 |
149 | for i in range(len(true_label)):
150 | if ((true_label[i].item() == prediction[i].item())
151 | and uncertainty[i].item() <= optimal_threshold):
152 | """ accurate and certain """
153 | n_ac += 1
154 | elif ((true_label[i].item() == prediction[i].item())
155 | and uncertainty[i].item() > optimal_threshold):
156 | """ accurate and uncertain """
157 | n_au += 1
158 | elif ((true_label[i].item() != prediction[i].item())
159 | and uncertainty[i].item() <= optimal_threshold):
160 | """ inaccurate and certain """
161 | n_ic += 1
162 | elif ((true_label[i].item() != prediction[i].item())
163 | and uncertainty[i].item() > optimal_threshold):
164 | """ inaccurate and uncertain """
165 | n_iu += 1
166 |
167 | print('n_ac: ', n_ac, ' ; n_au: ', n_au, ' ; n_ic: ', n_ic, ' ; n_iu: ',
168 | n_iu)
169 | avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)
170 |
171 | return avu
172 |
173 | def forward(self, logits, labels, type=0):
174 |
175 | probs = F.softmax(logits, dim=1)
176 | confidences, predictions = torch.max(probs, 1)
177 |
178 | if type == 0:
179 | unc = self.entropy(probs)
180 | else:
181 | unc = self.model_uncertainty(probs)
182 |
183 | th_list = np.linspace(0, 1, 21)
184 | umin = torch.min(unc)
185 | umax = torch.max(unc)
186 | avu_list = []
187 | unc_list = []
188 |
191 |
192 | auc_avu = torch.ones(1, device=labels.device)
193 | auc_avu.requires_grad_(True)
194 |
195 | for t in th_list:
196 | unc_th = umin + (torch.tensor(t, device=labels.device) *
197 | (umax - umin))
198 | n_ac = torch.zeros(
199 | 1,
200 | device=labels.device) # number of samples accurate and certain
201 | n_ic = torch.zeros(1, device=labels.device
202 | ) # number of samples inaccurate and certain
203 | n_au = torch.zeros(1, device=labels.device
204 | ) # number of samples accurate and uncertain
205 | n_iu = torch.zeros(1, device=labels.device
206 | ) # number of samples inaccurate and uncertain
207 |
208 | for i in range(len(labels)):
209 | if ((labels[i].item() == predictions[i].item())
210 | and unc[i].item() <= unc_th.item()):
211 | """ accurate and certain """
212 | n_ac += confidences[i] * (1 - torch.tanh(unc[i]))
213 | elif ((labels[i].item() == predictions[i].item())
214 | and unc[i].item() > unc_th.item()):
215 | """ accurate and uncertain """
216 | n_au += confidences[i] * torch.tanh(unc[i])
217 | elif ((labels[i].item() != predictions[i].item())
218 | and unc[i].item() <= unc_th.item()):
219 | """ inaccurate and certain """
220 | n_ic += (1 - confidences[i]) * (1 - torch.tanh(unc[i]))
221 | elif ((labels[i].item() != predictions[i].item())
222 | and unc[i].item() > unc_th.item()):
223 | """ inaccurate and uncertain """
224 | n_iu += (1 - confidences[i]) * torch.tanh(unc[i])
225 |
226 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + self.eps)
227 | avu_list.append(AvU)
228 | unc_list.append(unc_th)
229 |
230 | # differentiable trapezoidal area under the AvU curve;
231 | # sklearn.metrics.auc would require detaching the tensors
232 | # and break the gradient
233 | auc_avu = torch.trapz(torch.cat(avu_list),
234 | torch.tensor(th_list, device=labels.device))
235 | avu_loss = -1 * self.beta * torch.log(auc_avu + self.eps)
236 | return avu_loss, auc_avu
233 |
234 |
235 | class AvULoss(nn.Module):
236 | """
237 | Calculates Accuracy vs Uncertainty Loss of a model.
238 | The input to this loss is logits from Monte Carlo sampling of the model, true labels,
239 | and the type of uncertainty to be used [0: predictive uncertainty (default);
240 | 1: model uncertainty]
241 | """
242 | def __init__(self, beta=1):
243 | super(AvULoss, self).__init__()
244 | self.beta = beta
245 | self.eps = 1e-10
246 |
247 | def entropy(self, prob):
248 | return -1 * torch.sum(prob * torch.log(prob + self.eps), dim=-1)
249 |
250 | def expected_entropy(self, mc_preds):
251 | return torch.mean(self.entropy(mc_preds), dim=0)
252 |
253 | def predictive_uncertainty(self, mc_preds):
254 | """
255 | Compute the entropy of the mean of the predictive distribution
256 | obtained from Monte Carlo sampling.
257 | """
258 | return self.entropy(torch.mean(mc_preds, dim=0))
259 |
260 | def model_uncertainty(self, mc_preds):
261 | """
262 | Compute the difference between the entropy of the mean of the
263 | predictive distribution and the mean of the entropy.
264 | """
265 | return self.entropy(torch.mean(
266 | mc_preds, dim=0)) - self.expected_entropy(mc_preds)
267 |
268 | def accuracy_vs_uncertainty(self, prediction, true_label, uncertainty,
269 | optimal_threshold):
270 | # number of samples accurate and certain
271 | n_ac = torch.zeros(1, device=true_label.device)
272 | # number of samples inaccurate and certain
273 | n_ic = torch.zeros(1, device=true_label.device)
274 | # number of samples accurate and uncertain
275 | n_au = torch.zeros(1, device=true_label.device)
276 | # number of samples inaccurate and uncertain
277 | n_iu = torch.zeros(1, device=true_label.device)
278 |
279 | avu = torch.ones(1, device=true_label.device)
280 | avu.requires_grad_(True)
281 |
282 | for i in range(len(true_label)):
283 | if ((true_label[i].item() == prediction[i].item())
284 | and uncertainty[i].item() <= optimal_threshold):
285 | """ accurate and certain """
286 | n_ac += 1
287 | elif ((true_label[i].item() == prediction[i].item())
288 | and uncertainty[i].item() > optimal_threshold):
289 | """ accurate and uncertain """
290 | n_au += 1
291 | elif ((true_label[i].item() != prediction[i].item())
292 | and uncertainty[i].item() <= optimal_threshold):
293 | """ inaccurate and certain """
294 | n_ic += 1
295 | elif ((true_label[i].item() != prediction[i].item())
296 | and uncertainty[i].item() > optimal_threshold):
297 | """ inaccurate and uncertain """
298 | n_iu += 1
299 |
300 | print('n_ac: ', n_ac, ' ; n_au: ', n_au, ' ; n_ic: ', n_ic, ' ; n_iu: ',
301 | n_iu)
302 | avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)
303 |
304 | return avu
305 |
306 | def forward(self, logits, labels, optimal_uncertainty_threshold, type=0):
307 |
308 | probs = F.softmax(logits, dim=1)
309 | confidences, predictions = torch.max(probs, 1)
310 |
311 | if type == 0:
312 | unc = self.entropy(probs)
313 | else:
314 | unc = self.model_uncertainty(probs)
315 |
316 | unc_th = torch.tensor(optimal_uncertainty_threshold,
317 | device=logits.device)
318 |
319 | n_ac = torch.zeros(
320 | 1, device=logits.device) # number of samples accurate and certain
321 | n_ic = torch.zeros(
322 | 1,
323 | device=logits.device) # number of samples inaccurate and certain
324 | n_au = torch.zeros(
325 | 1,
326 | device=logits.device) # number of samples accurate and uncertain
327 | n_iu = torch.zeros(
328 | 1,
329 | device=logits.device) # number of samples inaccurate and uncertain
330 |
331 | avu = torch.ones(1, device=logits.device)
332 | avu_loss = torch.zeros(1, device=logits.device)
333 |
334 | for i in range(len(labels)):
335 | if ((labels[i].item() == predictions[i].item())
336 | and unc[i].item() <= unc_th.item()):
337 | """ accurate and certain """
338 | n_ac += confidences[i] * (1 - torch.tanh(unc[i]))
339 | elif ((labels[i].item() == predictions[i].item())
340 | and unc[i].item() > unc_th.item()):
341 | """ accurate and uncertain """
342 | n_au += confidences[i] * torch.tanh(unc[i])
343 | elif ((labels[i].item() != predictions[i].item())
344 | and unc[i].item() <= unc_th.item()):
345 | """ inaccurate and certain """
346 | n_ic += (1 - confidences[i]) * (1 - torch.tanh(unc[i]))
347 | elif ((labels[i].item() != predictions[i].item())
348 | and unc[i].item() > unc_th.item()):
349 | """ inaccurate and uncertain """
350 | n_iu += (1 - confidences[i]) * torch.tanh(unc[i])
351 |
352 | avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + self.eps)
353 | avu_loss = -1 * self.beta * torch.log(avu + self.eps)
357 | return avu_loss
358 |
359 |
360 | def entropy(prob):
361 | return -1 * np.sum(prob * np.log(prob + 1e-15), axis=-1)
362 |
363 |
364 | def predictive_entropy(mc_preds):
365 | """
366 | Compute the entropy of the mean of the predictive distribution
367 | obtained from Monte Carlo sampling during prediction phase.
368 | """
369 | return entropy(np.mean(mc_preds, axis=0))
370 |
371 |
372 | def mutual_information(mc_preds):
373 | """
374 | Compute the difference between the entropy of the mean of the
375 | predictive distribution and the mean of the entropy.
376 | """
377 | MI = entropy(np.mean(mc_preds, axis=0)) - np.mean(entropy(mc_preds),
378 | axis=0)
379 | return MI
380 |
381 |
382 | def eval_avu(pred_label, true_label, uncertainty):
383 | """ returns AvU at various uncertainty thresholds"""
384 | t_list = np.linspace(0, 1, 21)
385 | umin = np.amin(uncertainty, axis=0)
386 | umax = np.amax(uncertainty, axis=0)
387 | avu_list = []
388 | unc_list = []
389 | for t in t_list:
390 | u_th = umin + (t * (umax - umin))
391 | n_ac = 0
392 | n_ic = 0
393 | n_au = 0
394 | n_iu = 0
395 | for i in range(len(true_label)):
396 | if ((true_label[i] == pred_label[i]) and uncertainty[i] <= u_th):
397 | n_ac += 1
398 | elif ((true_label[i] == pred_label[i]) and uncertainty[i] > u_th):
399 | n_au += 1
400 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] <= u_th):
401 | n_ic += 1
402 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] > u_th):
403 | n_iu += 1
404 |
405 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-15)
406 | avu_list.append(AvU)
407 | unc_list.append(u_th)
408 | return np.asarray(avu_list), np.asarray(unc_list)
409 |
410 |
411 | def accuracy_vs_uncertainty(pred_label, true_label, uncertainty,
412 | optimal_threshold):
413 |
414 | n_ac = 0
415 | n_ic = 0
416 | n_au = 0
417 | n_iu = 0
418 | for i in range(len(true_label)):
419 | if ((true_label[i] == pred_label[i])
420 | and uncertainty[i] <= optimal_threshold):
421 | n_ac += 1
422 | elif ((true_label[i] == pred_label[i])
423 | and uncertainty[i] > optimal_threshold):
424 | n_au += 1
425 | elif ((true_label[i] != pred_label[i])
426 | and uncertainty[i] <= optimal_threshold):
427 | n_ic += 1
428 | elif ((true_label[i] != pred_label[i])
429 | and uncertainty[i] > optimal_threshold):
430 | n_iu += 1
431 |
432 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)
433 | return AvU
434 |
--------------------------------------------------------------------------------
/src/main_bayesian_cifar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim
10 | import torch.utils.data
11 | from torch.utils.tensorboard import SummaryWriter
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 | import models.bayesian.resnet as resnet
15 | import numpy as np
16 | import csv
17 | #from utils import calib
18 |
19 | model_names = sorted(
20 | name for name in resnet.__dict__
21 | if name.islower() and not name.startswith("__")
22 | and name.startswith("resnet") and callable(resnet.__dict__[name]))
23 |
24 | print(model_names)
25 | len_trainset = 50000
26 | len_testset = 10000
27 |
28 | parser = argparse.ArgumentParser(description='CIFAR10')
29 | parser.add_argument('--arch',
30 | '-a',
31 | metavar='ARCH',
32 | default='resnet20',
33 | choices=model_names,
34 | help='model architecture: ' + ' | '.join(model_names) +
35 | ' (default: resnet20)')
36 | parser.add_argument('-j',
37 | '--workers',
38 | default=8,
39 | type=int,
40 | metavar='N',
41 | help='number of data loading workers (default: 8)')
42 | parser.add_argument('--epochs',
43 | default=200,
44 | type=int,
45 | metavar='N',
46 | help='number of total epochs to run')
47 | parser.add_argument('--start-epoch',
48 | default=0,
49 | type=int,
50 | metavar='N',
51 | help='manual epoch number (useful on restarts)')
52 | parser.add_argument('-b',
53 | '--batch-size',
54 | default=512,
55 | type=int,
56 | metavar='N',
57 | help='mini-batch size (default: 512)')
58 | parser.add_argument('--lr',
59 | '--learning-rate',
60 | default=0.1,
61 | type=float,
62 | metavar='LR',
63 | help='initial learning rate')
64 | parser.add_argument('--momentum',
65 | default=0.9,
66 | type=float,
67 | metavar='M',
68 | help='momentum')
69 | parser.add_argument('--weight-decay',
70 | '--wd',
71 | default=1e-4,
72 | type=float,
73 | metavar='W',
74 |                     help='weight decay (default: 1e-4)')
75 | parser.add_argument('--print-freq',
76 | '-p',
77 | default=50,
78 | type=int,
79 | metavar='N',
80 |                     help='print frequency (default: 50)')
81 | parser.add_argument('--resume',
82 | default='',
83 | type=str,
84 | metavar='PATH',
85 | help='path to latest checkpoint (default: none)')
86 | parser.add_argument('-e',
87 | '--evaluate',
88 | dest='evaluate',
89 | action='store_true',
90 | help='evaluate model on validation set')
91 | parser.add_argument('--pretrained',
92 | dest='pretrained',
93 | action='store_true',
94 | help='use pre-trained model')
95 | parser.add_argument('--half',
96 | dest='half',
97 | action='store_true',
98 |                     help='use half-precision (16-bit)')
99 | parser.add_argument('--save-dir',
100 | dest='save_dir',
101 | help='The directory used to save the trained models',
102 | default='./checkpoint/bayesian_svi',
103 | type=str)
104 | parser.add_argument(
105 | '--save-every',
106 | dest='save_every',
107 | help='Saves checkpoints at every specified number of epochs',
108 | type=int,
109 | default=10)
110 | parser.add_argument('--mode', type=str, required=True, help='train | test')
111 | parser.add_argument('--num_monte_carlo',
112 | type=int,
113 | default=20,
114 | metavar='N',
115 | help='number of Monte Carlo samples')
116 | parser.add_argument(
117 | '--tensorboard',
118 | type=bool,
119 | default=True,
120 | metavar='N',
121 | help='use tensorboard for logging and visualization of training progress')
122 | parser.add_argument(
123 | '--log_dir',
124 | type=str,
125 | default='./logs/cifar/bayesian_svi',
126 | metavar='N',
127 |     help='directory in which to store logs and saved predictions')
128 | best_prec1 = 0
129 |
130 |
131 | class CorruptDataset(torch.utils.data.Dataset):
132 | def __init__(self, data, target, transform=None):
133 | self.data = data
134 | self.target = target
135 | self.transform = transform
136 |
137 | def __getitem__(self, index):
138 | x = self.data[index]
139 | y = self.target[index]
140 |
141 | if self.transform:
142 | x = self.transform(x)
143 |
144 | return x, y
145 |
146 | def __len__(self):
147 | return len(self.data)
148 |
149 |
150 | class OODDataset(torch.utils.data.Dataset):
151 | def __init__(self, data, target, transform=None):
152 | self.data = data
153 | self.target = target
154 | self.transform = transform
155 |
156 | def __getitem__(self, index):
157 | x = self.data[index]
158 | y = self.target[index]
159 |
160 | if self.transform:
161 | x = self.transform(x)
162 |
163 | return x, y
164 |
165 | def __len__(self):
166 | return len(self.data)
167 |
168 |
169 | def get_ood_dataloader(ood_images, ood_labels):
170 | ood_dataset = OODDataset(ood_images,
171 | ood_labels,
172 | transform=transforms.Compose(
173 | [transforms.ToTensor()]))
174 | ood_data_loader = torch.utils.data.DataLoader(ood_dataset,
175 | batch_size=args.batch_size,
176 | shuffle=False,
177 | num_workers=args.workers,
178 | pin_memory=True)
179 |
180 | return ood_data_loader
181 |
182 |
183 | corruptions = [
184 | 'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog',
185 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise',
186 | 'pixelate', 'saturate', 'shot_noise', 'spatter', 'speckle_noise',
187 | 'zoom_blur'
188 | ]
189 |
190 |
191 | def get_corrupt_dataloader(corrupted_images, corrupted_labels, level):
192 |     # CIFAR-10-C stacks 5 severity levels of 10000 images each along the
193 |     # first axis; slice out the requested level (1-5)
194 |     start, end = (level - 1) * 10000, level * 10000
195 |     corrupt_val_dataset = CorruptDataset(corrupted_images[start:end],
196 |                                          corrupted_labels[start:end],
197 |                                          transform=transforms.Compose(
198 |                                              [transforms.ToTensor()]))
227 |
228 | corrupt_val_loader = torch.utils.data.DataLoader(
229 | corrupt_val_dataset,
230 | batch_size=args.batch_size,
231 | shuffle=False,
232 | num_workers=args.workers,
233 | pin_memory=True)
234 |
235 | return corrupt_val_loader
236 |
237 |
238 | def main():
239 | global args, best_prec1
240 | args = parser.parse_args()
241 |
242 | # Check the save_dir exists or not
243 | if not os.path.exists(args.save_dir):
244 | os.makedirs(args.save_dir)
245 |
246 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
247 | if torch.cuda.is_available():
248 | model.cuda()
249 | else:
250 | model.cpu()
251 |
252 | # optionally resume from a checkpoint
253 | if args.resume:
254 | if os.path.isfile(args.resume):
255 | print("=> loading checkpoint '{}'".format(args.resume))
256 | checkpoint = torch.load(args.resume)
257 | args.start_epoch = checkpoint['epoch']
258 | best_prec1 = checkpoint['best_prec1']
259 | model.load_state_dict(checkpoint['state_dict'])
260 |             print("=> loaded checkpoint '{}' (epoch {})".format(
261 |                 args.resume, checkpoint['epoch']))
262 | else:
263 | print("=> no checkpoint found at '{}'".format(args.resume))
264 |
265 | cudnn.benchmark = True
266 |
267 | tb_writer = None
268 | if args.tensorboard:
269 | logger_dir = os.path.join(args.log_dir, 'tb_logger')
270 | if not os.path.exists(logger_dir):
271 | os.makedirs(logger_dir)
272 | tb_writer = SummaryWriter(logger_dir)
273 |
274 | preds_dir = os.path.join(args.log_dir, 'preds')
275 | if not os.path.exists(preds_dir):
276 | os.makedirs(preds_dir)
277 | results_dir = os.path.join(args.log_dir, 'results')
278 | if not os.path.exists(results_dir):
279 | os.makedirs(results_dir)
280 |
281 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
282 | std=[0.229, 0.224, 0.225])
283 |
284 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
285 | root='./data',
286 | train=True,
287 | transform=transforms.Compose([
288 | transforms.RandomHorizontalFlip(),
289 | transforms.RandomCrop(32, 4),
290 | transforms.ToTensor()
291 | ]),
292 | download=True),
293 | batch_size=args.batch_size,
294 | shuffle=True,
295 | num_workers=args.workers,
296 | pin_memory=True)
297 |
298 | val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
299 | root='./data',
300 | train=False,
301 | transform=transforms.Compose([transforms.ToTensor()])),
302 | batch_size=args.batch_size,
303 | shuffle=False,
304 | num_workers=args.workers,
305 | pin_memory=True)
306 |
307 | if not os.path.exists(args.save_dir):
308 | os.makedirs(args.save_dir)
309 |
310 | if torch.cuda.is_available():
311 | criterion = nn.CrossEntropyLoss().cuda()
312 | else:
313 | criterion = nn.CrossEntropyLoss().cpu()
314 |
315 | if args.half:
316 | model.half()
317 | criterion.half()
318 |
319 |     if args.arch in ['resnet110']:
320 |         # optimizer is created per-epoch below; warm up by scaling the base lr
321 |         args.lr = args.lr * 0.1
322 |
323 | if args.evaluate:
324 | validate(val_loader, model, criterion)
325 | return
326 |
327 | if args.mode == 'train':
328 | for epoch in range(args.start_epoch, args.epochs):
329 |
330 | lr = args.lr
331 | if (epoch >= 80 and epoch < 120):
332 | lr = 0.1 * args.lr
333 | elif (epoch >= 120 and epoch < 160):
334 | lr = 0.01 * args.lr
335 | elif (epoch >= 160 and epoch < 180):
336 | lr = 0.001 * args.lr
337 | elif (epoch >= 180):
338 | lr = 0.0005 * args.lr
339 |
340 | optimizer = torch.optim.Adam(model.parameters(), lr)
341 | # train for one epoch
342 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
343 | train(train_loader, model, criterion, optimizer, epoch, tb_writer)
344 | prec1 = validate(val_loader, model, criterion, epoch, tb_writer)
345 | is_best = prec1 > best_prec1
346 | best_prec1 = max(prec1, best_prec1)
347 |
348 | if epoch > 0:
349 | if is_best:
350 | save_checkpoint(
351 | {
352 | 'epoch': epoch + 1,
353 | 'state_dict': model.state_dict(),
354 | 'best_prec1': best_prec1,
355 | },
356 | is_best,
357 | filename=os.path.join(
358 | args.save_dir,
359 | 'bayesian_{}_cifar.pth'.format(args.arch)))
360 |
361 | elif args.mode == 'test':
362 | checkpoint_file = args.save_dir + '/bayesian_{}_cifar.pth'.format(
363 | args.arch)
364 | if torch.cuda.is_available():
365 | checkpoint = torch.load(checkpoint_file)
366 | else:
367 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
368 | model.load_state_dict(checkpoint['state_dict'])
369 |
370 | #header = ['corrupt', 'test_acc', 'brier', 'ece']
371 | header = ['corrupt', 'test_acc']
372 |
373 | #Evaluate on OOD dataset (SVHN)
374 | ood_images_file = 'data/SVHN/svhn-test.npy'
375 | ood_images = np.load(ood_images_file)
376 | ood_images = ood_images[:10000, :, :, :]
377 | ood_labels = np.arange(len(ood_images)) + 10 #create dummy labels
378 | ood_loader = get_ood_dataloader(ood_images, ood_labels)
379 | ood_acc = evaluate(model, ood_loader, corrupt='ood', level=None)
380 | print('******OOD data***********\n')
381 | #print('ood_acc: ', ood_acc, ' | Brier: ', ood_brier, ' | ECE: ', ood_ece, '\n')
382 | print('ood_acc: ', ood_acc)
383 | '''
384 | o_file = args.log_dir + '/results/ood_results.csv'
385 | with open(o_file, 'wt') as o_file:
386 | writer = csv.writer(o_file, delimiter=',', lineterminator='\n')
387 | writer.writerow([j for j in header])
388 | writer.writerow(['ood', ood_acc, ood_brier, ood_ece])
389 | o_file.close()
390 | '''
391 | #Evaluate on test dataset
392 | test_acc = evaluate(model, val_loader, corrupt=None, level=None)
393 | print('******Test data***********\n')
394 | #print('test_acc: ', test_acc, ' | Brier: ', brier, ' | ECE: ', ece, '\n')
395 | print('test_acc: ', test_acc)
396 | '''
397 | t_file = args.log_dir + '/test_results.csv'
398 | with open(t_file, 'wt') as t_file:
399 | writer = csv.writer(t_file, delimiter=',', lineterminator='\n')
400 | writer.writerow([j for j in header])
401 | writer.writerow(['test', test_acc, brier, ece])
402 | t_file.close()
403 | '''
404 |
405 | for level in range(1, 6):
406 | print('******Corruption Level: ', level, ' ***********\n')
407 | results_file = args.log_dir + '/level' + str(level) + '.csv'
408 | with open(results_file, 'wt') as results_file:
409 | writer = csv.writer(results_file,
410 | delimiter=',',
411 | lineterminator='\n')
412 | writer.writerow([j for j in header])
413 | for c in corruptions:
414 | images_file = 'data/CIFAR-10-C/' + c + '.npy'
415 | labels_file = 'data/CIFAR-10-C/labels.npy'
416 | corrupt_images = np.load(images_file)
417 | corrupt_labels = np.load(labels_file)
418 | val_loader = get_corrupt_dataloader(
419 | corrupt_images, corrupt_labels, level)
420 | test_acc = evaluate(model,
421 | val_loader,
422 | corrupt=c,
423 | level=level)
424 | print('############ Corruption type: ', c,
425 | ' ################')
426 | #print('test_acc: ', test_acc, ' | Brier: ', brier, ' | ECE: ', ece, '\n')
427 | print('test_acc: ', test_acc)
428 | #writer.writerow([c, test_acc, brier, ece])
429 | writer.writerow([c, test_acc])
430 | results_file.close()
431 |
432 |
433 | def train(train_loader, model, criterion, optimizer, epoch, tb_writer=None):
434 | batch_time = AverageMeter()
435 | data_time = AverageMeter()
436 | losses = AverageMeter()
437 | top1 = AverageMeter()
438 |
439 | # switch to train mode
440 | model.train()
441 |
442 | end = time.time()
443 | for i, (input, target) in enumerate(train_loader):
444 |
445 | # measure data loading time
446 | data_time.update(time.time() - end)
447 |
448 | if torch.cuda.is_available():
449 | target = target.cuda()
450 | input_var = input.cuda()
451 | target_var = target.cuda()
452 | else:
453 | target = target.cpu()
454 | input_var = input.cpu()
455 | target_var = target.cpu()
456 | if args.half:
457 | input_var = input_var.half()
458 |
459 | # compute output
460 | output, kl = model(input_var)
461 | cross_entropy_loss = criterion(output, target_var)
462 |         scaled_kl = kl / len_trainset  # keep KL in the autograd graph
463 | loss = cross_entropy_loss + scaled_kl
464 |
465 | # compute gradient and do SGD step
466 | optimizer.zero_grad()
467 | loss.backward()
468 | optimizer.step()
469 |
470 | output = output.float()
471 | loss = loss.float()
472 | # measure accuracy and record loss
473 | prec1 = accuracy(output.data, target)[0]
474 | losses.update(loss.item(), input.size(0))
475 | top1.update(prec1.item(), input.size(0))
476 |
477 | # measure elapsed time
478 | batch_time.update(time.time() - end)
479 | end = time.time()
480 |
481 | if i % args.print_freq == 0:
482 | print('Epoch: [{0}][{1}/{2}]\t'
483 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
484 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
485 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
486 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
487 | epoch,
488 | i,
489 | len(train_loader),
490 | batch_time=batch_time,
491 | data_time=data_time,
492 | loss=losses,
493 | top1=top1))
494 |
495 | if tb_writer is not None:
496 | tb_writer.add_scalar('train/cross_entropy_loss',
497 | cross_entropy_loss.item(), epoch)
498 | tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch)
499 | tb_writer.add_scalar('train/elbo_loss', loss.item(), epoch)
500 | tb_writer.add_scalar('train/accuracy', prec1.item(), epoch)
501 | tb_writer.flush()
502 |
503 |
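The batch loss assembled in `train()` above (`cross_entropy_loss + scaled_kl`) is a minibatch estimate of the per-example negative ELBO: the mean cross-entropy over the batch plus the full KL divergence divided by the training-set size. A numeric sketch, with hypothetical values for illustration only:

```python
# Minibatch ELBO scaling as in train(): batch loss = mean CE + KL / N,
# which estimates (1/N) * (total NLL over the dataset + KL), N = len(trainset).
N = 50000            # len_trainset for CIFAR-10
kl = 1.0e4           # hypothetical KL(q || prior) summed over all weights
mean_ce = 0.5        # hypothetical mean cross-entropy for one batch
batch_loss = mean_ce + kl / N   # 0.5 + 0.2
```

Dividing the KL term by N rather than by the batch size keeps its weight independent of the batch size, so the expected per-batch loss matches the per-example variational objective.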
504 | def validate(val_loader, model, criterion, epoch, tb_writer=None):
505 | batch_time = AverageMeter()
506 | losses = AverageMeter()
507 | top1 = AverageMeter()
508 |
509 | # switch to evaluate mode
510 | model.eval()
511 |
512 | end = time.time()
513 | with torch.no_grad():
514 | for i, (input, target) in enumerate(val_loader):
515 | if torch.cuda.is_available():
516 | target = target.cuda()
517 | input_var = input.cuda()
518 | target_var = target.cuda()
519 | else:
520 | target = target.cpu()
521 | input_var = input.cpu()
522 | target_var = target.cpu()
523 |
524 | if args.half:
525 | input_var = input_var.half()
526 |
527 | # compute output
528 | output, kl = model(input_var)
529 | cross_entropy_loss = criterion(output, target_var)
530 | scaled_kl = kl.data / (len_trainset)
531 | loss = cross_entropy_loss + scaled_kl
532 |
533 | output = output.float()
534 | loss = loss.float()
535 |
536 | # measure accuracy and record loss
537 | prec1 = accuracy(output.data, target)[0]
538 | losses.update(loss.item(), input.size(0))
539 | top1.update(prec1.item(), input.size(0))
540 |
541 | # measure elapsed time
542 | batch_time.update(time.time() - end)
543 | end = time.time()
544 |
545 | if i % args.print_freq == 0:
546 | print('Test: [{0}/{1}]\t'
547 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
548 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
549 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
550 | i,
551 | len(val_loader),
552 | batch_time=batch_time,
553 | loss=losses,
554 | top1=top1))
555 |
556 | if tb_writer is not None:
557 | tb_writer.add_scalar('val/cross_entropy_loss',
558 | cross_entropy_loss.item(), epoch)
559 | tb_writer.add_scalar('val/kl_div', scaled_kl.item(), epoch)
560 | tb_writer.add_scalar('val/elbo_loss', loss.item(), epoch)
561 | tb_writer.add_scalar('val/accuracy', prec1.item(), epoch)
562 | tb_writer.flush()
563 |
564 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
565 |
566 | return top1.avg
567 |
568 |
569 | def evaluate(model, val_loader, corrupt=None, level=None):
570 |     # note: assumes val_loader yields the full dataset in a single batch,
571 |     # so pred_probs_mc holds one softmax array per Monte Carlo sample and
572 |     # np.mean(pred_probs_mc, axis=0) averages over the MC draws
573 | with torch.no_grad():
574 | pred_probs_mc = []
575 | for batch_idx, (data, target) in enumerate(val_loader):
576 | #print('Batch idx {}, data shape {}, target shape {}'.format(batch_idx, data.shape, target.shape))
577 | if torch.cuda.is_available():
578 | data, target = data.cuda(), target.cuda()
579 | else:
580 | data, target = data.cpu(), target.cpu()
581 | for mc_run in range(args.num_monte_carlo):
582 | model.eval()
583 |                 output, _ = model(data)
584 | pred_probs = torch.nn.functional.softmax(output, dim=1)
585 | pred_probs_mc.append(pred_probs.cpu().data.numpy())
586 |
587 | if corrupt == 'ood':
588 | np.save(args.log_dir + '/preds/svi_ood_probs.npy', pred_probs_mc)
589 | print('saved predictions')
590 | return None
591 |
592 | target_labels = target.cpu().data.numpy()
593 | pred_mean = np.mean(pred_probs_mc, axis=0)
594 | #print(pred_mean)
595 | Y_pred = np.argmax(pred_mean, axis=1)
596 | test_acc = (Y_pred == target_labels).mean()
597 | #brier = np.mean(calib.brier_scores(target_labels, probs=pred_mean))
598 | #ece = calib.expected_calibration_error_multiclass(pred_mean, target_labels)
599 | print('Test accuracy:', test_acc * 100)
600 | #print('Brier score: ', brier)
601 | #print('ECE: ', ece)
602 | if corrupt is not None:
603 | np.save(
604 | args.log_dir +
605 | '/preds/svi_corrupt-static-{}-{}_probs.npy'.format(
606 | corrupt, level), pred_probs_mc)
607 | np.save(
608 | args.log_dir +
609 | '/preds/svi_corrupt-static-{}-{}_labels.npy'.format(
610 | corrupt, level), target_labels)
611 | print('saved predictions')
612 | else:
613 | np.save(args.log_dir + '/preds/svi_test_probs.npy', pred_probs_mc)
614 | np.save(args.log_dir + '/preds/svi_test_labels.npy', target_labels)
615 | print('saved predictions')
616 |
617 | return test_acc
618 |
619 |
620 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
621 | """
622 | Save the training model
623 | """
624 | torch.save(state, filename)
625 |
626 |
627 | class AverageMeter(object):
628 | """Computes and stores the average and current value"""
629 | def __init__(self):
630 | self.reset()
631 |
632 | def reset(self):
633 | self.val = 0
634 | self.avg = 0
635 | self.sum = 0
636 | self.count = 0
637 |
638 | def update(self, val, n=1):
639 | self.val = val
640 | self.sum += val * n
641 | self.count += n
642 | self.avg = self.sum / self.count
643 |
644 |
645 | def accuracy(output, target, topk=(1, )):
646 | """Computes the precision@k for the specified values of k"""
647 | maxk = max(topk)
648 | batch_size = target.size(0)
649 |
650 | _, pred = output.topk(maxk, 1, True, True)
651 | pred = pred.t()
652 | correct = pred.eq(target.view(1, -1).expand_as(pred))
653 |
654 | res = []
655 | for k in topk:
656 |         correct_k = correct[:k].reshape(-1).float().sum(0)
657 | res.append(correct_k.mul_(100.0 / batch_size))
658 | return res
659 |
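The top-k counting in `accuracy()` above can be cross-checked with a small NumPy sketch (a standalone re-derivation, not the repo's code):

```python
import numpy as np

# Standalone numpy sketch of the precision@k computed by accuracy() above.
def precision_at_k(output, target, k=1):
    topk = np.argsort(-output, axis=1)[:, :k]      # indices of the k largest scores
    hits = (topk == target[:, None]).any(axis=1)   # is the target within the top-k?
    return 100.0 * hits.mean()

output = np.array([[0.1, 0.7, 0.2],    # top-1 = class 1 (matches target)
                   [0.8, 0.1, 0.1]])   # top-1 = class 0 (target is 2)
target = np.array([1, 2])
```

Here precision@1 is 50% (one of two samples ranks its target first), while precision@3 is 100% since every target appears among the top three classes.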
660 |
661 | if __name__ == '__main__':
662 | main()
663 |
--------------------------------------------------------------------------------
/src/main_bayesian_cifar_auavu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | from torch.utils.tensorboard import SummaryWriter
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 | import models.bayesian.resnet as resnet
16 | import models.deterministic.resnet as det_resnet
17 | import numpy as np
18 | from src import util
19 | #from utils import calib
20 | import csv
21 | from src.util import get_rho
22 | from src.avuc_loss import AvULoss, AUAvULoss
23 | from torchsummary import summary
24 |
25 | model_names = sorted(
26 | name for name in resnet.__dict__
27 | if name.islower() and not name.startswith("__")
28 | and name.startswith("resnet") and callable(resnet.__dict__[name]))
29 |
30 | print(model_names)
31 |
32 | parser = argparse.ArgumentParser(description='CIFAR10')
33 | parser.add_argument('--arch',
34 | '-a',
35 | metavar='ARCH',
36 | default='resnet20',
37 | choices=model_names,
38 | help='model architecture: ' + ' | '.join(model_names) +
39 | ' (default: resnet20)')
40 | parser.add_argument('-j',
41 | '--workers',
42 | default=8,
43 | type=int,
44 | metavar='N',
45 | help='number of data loading workers (default: 8)')
46 | parser.add_argument('--epochs',
47 | default=200,
48 | type=int,
49 | metavar='N',
50 | help='number of total epochs to run')
51 | parser.add_argument('--start-epoch',
52 | default=0,
53 | type=int,
54 | metavar='N',
55 | help='manual epoch number (useful on restarts)')
56 | parser.add_argument('-b',
57 | '--batch-size',
58 | default=512,
59 | type=int,
60 | metavar='N',
61 | help='mini-batch size (default: 512)')
62 | parser.add_argument('--lr',
63 | '--learning-rate',
64 | default=0.1,
65 | type=float,
66 | metavar='LR',
67 | help='initial learning rate')
68 | parser.add_argument('--momentum',
69 | default=0.9,
70 | type=float,
71 | metavar='M',
72 | help='momentum')
73 | parser.add_argument('--weight-decay',
74 | '--wd',
75 | default=1e-4,
76 | type=float,
77 | metavar='W',
78 |                     help='weight decay (default: 1e-4)')
79 | parser.add_argument('--print-freq',
80 | '-p',
81 | default=50,
82 | type=int,
83 | metavar='N',
84 |                     help='print frequency (default: 50)')
85 | parser.add_argument('--resume',
86 | default='',
87 | type=str,
88 | metavar='PATH',
89 | help='path to latest checkpoint (default: none)')
90 | parser.add_argument('-e',
91 | '--evaluate',
92 | dest='evaluate',
93 | action='store_true',
94 | help='evaluate model on validation set')
95 | parser.add_argument('--pretrained',
96 | dest='pretrained',
97 | action='store_true',
98 | help='use pre-trained model')
99 | parser.add_argument('--half',
100 | dest='half',
101 | action='store_true',
102 |                     help='use half-precision (16-bit)')
103 | parser.add_argument('--save-dir',
104 | dest='save_dir',
105 | help='The directory used to save the trained models',
106 | default='./checkpoint/bayesian_svi_auavu',
107 | type=str)
108 | parser.add_argument(
109 | '--save-every',
110 | dest='save_every',
111 | help='Saves checkpoints at every specified number of epochs',
112 | type=int,
113 | default=10)
114 | parser.add_argument('--mode', type=str, required=True, help='train | test')
115 | parser.add_argument('--num_monte_carlo',
116 | type=int,
117 | default=20,
118 | metavar='N',
119 | help='number of Monte Carlo samples')
120 | parser.add_argument(
121 | '--tensorboard',
122 | type=bool,
123 | default=True,
124 | metavar='N',
125 | help='use tensorboard for logging and visualization of training progress')
126 | parser.add_argument('--val_batch_size', default=10000, type=int)
127 | parser.add_argument(
128 | '--log_dir',
129 | type=str,
130 | default='./logs/cifar/bayesian_svi_auavu',
131 | metavar='N',
132 |     help='directory in which to store logs and saved predictions')
133 | parser.add_argument(
134 | '--moped',
135 | type=bool,
136 | default=False,
137 | help='set prior and initialize approx posterior with Empirical Bayes')
138 | parser.add_argument('--delta',
139 | type=float,
140 | default=1.0,
141 | help='delta value for variance scaling in MOPED')
142 | len_trainset = 50000
143 | len_testset = 10000
144 | beta = 3.0
145 | optimal_threshold = 1.0
146 | opt_th = optimal_threshold
147 | best_prec1 = 0
148 | best_avu = 0
149 |
150 |
151 | class CorruptDataset(torch.utils.data.Dataset):
152 | def __init__(self, data, target, transform=None):
153 | self.data = data
154 | self.target = target
155 | self.transform = transform
156 |
157 | def __getitem__(self, index):
158 | x = self.data[index]
159 | y = self.target[index]
160 |
161 | if self.transform:
162 | x = self.transform(x)
163 |
164 | return x, y
165 |
166 | def __len__(self):
167 | return len(self.data)
168 |
169 |
170 | class OODDataset(torch.utils.data.Dataset):
171 | def __init__(self, data, target, transform=None):
172 | self.data = data
173 | self.target = target
174 | self.transform = transform
175 |
176 | def __getitem__(self, index):
177 | x = self.data[index]
178 | y = self.target[index]
179 |
180 | if self.transform:
181 | x = self.transform(x)
182 |
183 | return x, y
184 |
185 | def __len__(self):
186 | return len(self.data)
187 |
188 |
189 | def get_ood_dataloader(ood_images, ood_labels):
190 | ood_dataset = OODDataset(ood_images,
191 | ood_labels,
192 | transform=transforms.Compose(
193 | [transforms.ToTensor()]))
194 | ood_data_loader = torch.utils.data.DataLoader(ood_dataset,
195 | batch_size=args.batch_size,
196 | shuffle=False,
197 | num_workers=args.workers,
198 | pin_memory=True)
199 | return ood_data_loader
200 |
201 |
202 | corruptions = [
203 | 'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog',
204 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise',
205 | 'pixelate', 'saturate', 'shot_noise', 'spatter', 'speckle_noise',
206 | 'zoom_blur'
207 | ]
208 |
209 |
210 | def get_corrupt_dataloader(corrupted_images, corrupted_labels, level):
211 |
212 |     # CIFAR-10-C stacks 5 severity levels of 10000 images each along the
213 |     # first axis; slice out the requested level (1-5)
214 |     start, end = (level - 1) * 10000, level * 10000
215 |     corrupt_val_dataset = CorruptDataset(corrupted_images[start:end],
216 |                                          corrupted_labels[start:end],
217 |                                          transform=transforms.Compose(
218 |                                              [transforms.ToTensor()]))
248 |
249 | corrupt_val_loader = torch.utils.data.DataLoader(
250 | corrupt_val_dataset,
251 | batch_size=args.val_batch_size,
252 | shuffle=False,
253 | num_workers=args.workers,
254 | pin_memory=True)
255 |
256 | return corrupt_val_loader
257 |
258 |
259 | def MOPED_layer(layer, det_layer, delta):
260 | """
261 | Set the priors and initialize surrogate posteriors of Bayesian NN with Empirical Bayes
262 | MOPED (Model Priors with Empirical Bayes using Deterministic DNN)
263 | Ref: https://arxiv.org/abs/1906.05323
264 | 'Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes'. AAAI 2020.
265 | """
266 |
267 | if (str(layer) == 'Conv2dVariational()'):
268 | #set the priors
269 | print(str(layer))
270 | layer.prior_weight_mu = det_layer.weight.data
271 | if layer.prior_bias_mu is not None:
272 | layer.prior_bias_mu = det_layer.bias.data
273 |
274 | #initialize surrogate posteriors
275 | layer.mu_kernel.data = det_layer.weight.data
276 | #layer.rho_kernel.data = get_rho(det_layer.weight.data, delta)
277 | if layer.mu_bias is not None:
278 | layer.mu_bias.data = det_layer.bias.data
279 | #layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
280 |
281 | elif (isinstance(layer, nn.Conv2d)):
282 | print(str(layer))
283 | layer.weight.data = det_layer.weight.data
284 | if layer.bias is not None:
285 |             layer.bias.data = det_layer.bias.data
286 |
287 | elif (str(layer) == 'LinearVariational()'):
288 | print(str(layer))
289 | layer.prior_weight_mu = det_layer.weight.data
290 | if layer.prior_bias_mu is not None:
291 | layer.prior_bias_mu = det_layer.bias.data
292 |
293 | #initialize the surrogate posteriors
294 |
295 | layer.mu_weight.data = det_layer.weight.data
296 | layer.rho_weight.data = get_rho(det_layer.weight.data, delta)
297 | if layer.mu_bias is not None:
298 | layer.mu_bias.data = det_layer.bias.data
299 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
300 |
301 | elif str(layer).startswith('Batch'):
302 | #initialize parameters
303 | print(str(layer))
304 | layer.weight.data = det_layer.weight.data
305 | if layer.bias is not None:
306 | layer.bias.data = det_layer.bias.data
307 | layer.running_mean.data = det_layer.running_mean.data
308 | layer.running_var.data = det_layer.running_var.data
309 | layer.num_batches_tracked.data = det_layer.num_batches_tracked.data
310 |
311 |
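`MOPED_layer` above initializes the variational posterior from a deterministic network via `util.get_rho`. A standalone sketch of that initialization, assuming (not confirmed from this excerpt) that the posterior standard deviation is set to `delta * |w|` and stored as rho through the inverse softplus, so that `softplus(rho)` recovers sigma:

```python
import numpy as np

# Hypothetical re-derivation of the MOPED rho initialization (an assumption
# about util.get_rho): sigma = delta * |w|, rho = softplus^{-1}(sigma).
def get_rho_sketch(weight, delta):
    sigma = delta * np.abs(weight)
    return np.log(np.expm1(sigma))      # inverse softplus: log(exp(sigma) - 1)

w = np.array([0.5, -1.0])               # hypothetical deterministic weights
rho = get_rho_sketch(w, delta=0.5)
sigma_back = np.log1p(np.exp(rho))      # softplus(rho) recovers delta * |w|
```

Parameterizing sigma through rho keeps the standard deviation positive by construction during gradient updates, which is why the inverse mapping is needed at initialization time.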
312 | def main():
313 | global args, best_prec1, best_avu
314 | args = parser.parse_args()
315 |
316 | # Check the save_dir exists or not
317 | if not os.path.exists(args.save_dir):
318 | os.makedirs(args.save_dir)
319 |
320 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
321 | if torch.cuda.is_available():
322 | model.cuda()
323 | else:
324 | model.cpu()
325 |
326 | # optionally resume from a checkpoint
327 | if args.resume:
328 | if os.path.isfile(args.resume):
329 | print("=> loading checkpoint '{}'".format(args.resume))
330 | checkpoint = torch.load(args.resume)
331 | args.start_epoch = checkpoint['epoch']
332 | best_prec1 = checkpoint['best_prec1']
333 | model.load_state_dict(checkpoint['state_dict'])
334 | print("=> loaded checkpoint '{}' (epoch {})".format(
335 | args.resume, checkpoint['epoch']))
336 | else:
337 | print("=> no checkpoint found at '{}'".format(args.resume))
338 |
339 | cudnn.benchmark = True
340 |
341 | tb_writer = None
342 | if args.tensorboard:
343 | logger_dir = os.path.join(args.log_dir, 'tb_logger')
344 | if not os.path.exists(logger_dir):
345 | os.makedirs(logger_dir)
346 | tb_writer = SummaryWriter(logger_dir)
347 |
348 | preds_dir = os.path.join(args.log_dir, 'preds')
349 | if not os.path.exists(preds_dir):
350 | os.makedirs(preds_dir)
351 | results_dir = os.path.join(args.log_dir, 'results')
352 | if not os.path.exists(results_dir):
353 | os.makedirs(results_dir)
354 |
355 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
356 | std=[0.229, 0.224, 0.225])
357 |
358 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
359 | root='./data',
360 | train=True,
361 | transform=transforms.Compose([
362 | transforms.RandomHorizontalFlip(),
363 | transforms.RandomCrop(32, 4),
364 | transforms.ToTensor()
365 | ]),
366 | download=True),
367 | batch_size=args.batch_size,
368 | shuffle=True,
369 | num_workers=args.workers,
370 | pin_memory=True)
371 |
372 | val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
373 | root='./data',
374 | train=False,
375 | transform=transforms.Compose([transforms.ToTensor()])),
376 | batch_size=args.val_batch_size,
377 | shuffle=False,
378 | num_workers=args.workers,
379 | pin_memory=True)
380 |
381 | if not os.path.exists(args.save_dir):
382 | os.makedirs(args.save_dir)
383 |
384 | if torch.cuda.is_available():
385 | criterion = nn.CrossEntropyLoss().cuda()
386 | avu_criterion = AUAvULoss().cuda()
387 | else:
388 | criterion = nn.CrossEntropyLoss().cpu()
389 | avu_criterion = AUAvULoss().cpu()
390 |
391 | if args.half:
392 | model.half()
393 | criterion.half()
394 |
395 | if args.arch in ['resnet110']:
396 | # the optimizer is created per-epoch below; scale the base lr instead
397 | args.lr = args.lr * 0.1
398 |
399 | if args.evaluate:
400 | validate(val_loader, model, criterion)
401 | return
402 |
403 | if args.mode == 'train':
404 |
405 | if (args.moped):
406 | print("MOPED enabled")
407 | det_model = torch.nn.DataParallel(det_resnet.__dict__[args.arch]())
408 | if torch.cuda.is_available():
409 | det_model.cuda()
410 | else:
411 | det_model.cpu()
412 | checkpoint_file = 'checkpoint/deterministic/{}_cifar.pth'.format(
413 | args.arch)
414 | checkpoint = torch.load(checkpoint_file)
415 | det_model.load_state_dict(checkpoint['state_dict'])
416 |
417 | for (idx_1, layer_1), (det_idx_1, det_layer_1) in zip(
418 | enumerate(model.children()),
419 | enumerate(det_model.children())):
420 | MOPED_layer(layer_1, det_layer_1, args.delta)
421 | for (idx_2, layer_2), (det_idx_2, det_layer_2) in zip(
422 | enumerate(layer_1.children()),
423 | enumerate(det_layer_1.children())):
424 | MOPED_layer(layer_2, det_layer_2, args.delta)
425 | for (idx_3, layer_3), (det_idx_3, det_layer_3) in zip(
426 | enumerate(layer_2.children()),
427 | enumerate(det_layer_2.children())):
428 | MOPED_layer(layer_3, det_layer_3, args.delta)
429 | for (idx_4, layer_4), (det_idx_4, det_layer_4) in zip(
430 | enumerate(layer_3.children()),
431 | enumerate(det_layer_3.children())):
432 | MOPED_layer(layer_4, det_layer_4, args.delta)
433 |
434 | model.state_dict()
435 |
436 | for epoch in range(args.start_epoch, args.epochs):
437 |
438 | lr = args.lr
439 | if (epoch >= 80 and epoch < 120):
440 | lr = 0.1 * args.lr
441 | elif (epoch >= 120 and epoch < 160):
442 | lr = 0.01 * args.lr
443 | elif (epoch >= 160 and epoch < 180):
444 | lr = 0.001 * args.lr
445 | elif (epoch >= 180):
446 | lr = 0.0005 * args.lr
447 |
448 | optimizer = torch.optim.Adam(model.parameters(), lr)
449 | # train for one epoch
450 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
451 | train(train_loader, model, criterion, avu_criterion, optimizer,
452 | epoch, tb_writer)
453 | #lr_scheduler.step()
454 |
455 | prec1 = validate(val_loader, model, criterion, avu_criterion,
456 | epoch, tb_writer)
457 |
458 | is_best = prec1 > best_prec1
459 | best_prec1 = max(prec1, best_prec1)
460 |
461 | #if epoch > 0 and epoch % args.save_every == 0:
462 | if epoch > 0:
463 | if is_best:
464 | save_checkpoint(
465 | {
466 | 'epoch': epoch + 1,
467 | 'state_dict': model.state_dict(),
468 | 'best_prec1': best_prec1,
469 | },
470 | is_best,
471 | filename=os.path.join(
472 | args.save_dir,
473 | 'bayesian_{}_cifar.pth'.format(args.arch)))
474 |
475 | elif args.mode == 'test':
476 | checkpoint_file = args.save_dir + '/bayesian_{}_cifar.pth'.format(
477 | args.arch)
478 | if torch.cuda.is_available():
479 | checkpoint = torch.load(checkpoint_file)
480 | else:
481 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
482 | print('loaded checkpoint (epoch {})'.format(checkpoint['epoch']))
483 | model.load_state_dict(checkpoint['state_dict'])
484 |
485 | #header = ['corrupt', 'test_acc', 'brier', 'ece']
486 | header = ['corrupt', 'test_acc']
487 |
488 | #Evaluate on OOD dataset (SVHN)
489 | ood_images_file = 'data/SVHN/svhn-test.npy'
490 | ood_images = np.load(ood_images_file)
491 | ood_images = ood_images[:10000, :, :, :]
492 | ood_labels = np.arange(len(ood_images)) + 10 #create dummy labels
493 | ood_loader = get_ood_dataloader(ood_images, ood_labels)
494 | ood_acc = evaluate(model, ood_loader, corrupt='ood', level=None)
495 | print('******OOD data***********\n')
496 | print('ood_acc: ', ood_acc)
497 | '''
498 | o_file = args.log_dir + '/results/ood_results.csv'
499 | with open(o_file, 'wt') as o_file:
500 | writer = csv.writer(o_file, delimiter=',', lineterminator='\n')
501 | writer.writerow([j for j in header])
502 | writer.writerow(['ood', ood_acc, ood_brier, ood_ece])
503 | o_file.close()
504 | '''
505 |
506 | #Evaluate on test dataset
507 | test_acc = evaluate(model, val_loader, corrupt=None, level=None)
508 | print('******Test data***********\n')
509 | print('test_acc: ', test_acc)
510 | '''
511 | t_file = args.log_dir + '/results/test_results.csv'
512 | with open(t_file, 'wt') as t_file:
513 | writer = csv.writer(t_file, delimiter=',', lineterminator='\n')
514 | writer.writerow([j for j in header])
515 | writer.writerow(['test', test_acc, brier, ece])
516 | t_file.close()
517 | '''
518 |
519 | for level in range(1, 6):
520 | print('******Corruption Level: ', level, ' ***********\n')
521 | results_file = args.log_dir + '/results/level' + str(
522 | level) + '.csv'
523 | with open(results_file, 'wt') as results_file:
524 | writer = csv.writer(results_file,
525 | delimiter=',',
526 | lineterminator='\n')
527 | writer.writerow([j for j in header])
528 | for c in corruptions:
529 | images_file = 'data/CIFAR-10-C/' + c + '.npy'
530 | labels_file = 'data/CIFAR-10-C/labels.npy'
531 | corrupt_images = np.load(images_file)
532 | corrupt_labels = np.load(labels_file)
533 |
534 | val_loader = get_corrupt_dataloader(corrupt_images,
535 | corrupt_labels,
536 | level=level)
537 | test_acc = evaluate(model,
538 | val_loader,
539 | corrupt=c,
540 | level=level)
541 | print('############ Corruption type: ', c,
542 | ' ################')
543 | print('test_acc: ', test_acc, '\n')
544 | writer.writerow([c, test_acc])
545 | results_file.close()
546 |
547 |
548 | def train(train_loader,
549 | model,
550 | criterion,
551 | avu_criterion,
552 | optimizer,
553 | epoch,
554 | tb_writer=None):
555 | batch_time = AverageMeter()
556 | data_time = AverageMeter()
557 | losses = AverageMeter()
558 | top1 = AverageMeter()
559 | avg_unc = AverageMeter()
560 |
561 | # switch to train mode
562 | model.train()
563 |
564 | end = time.time()
565 | for i, (input, target) in enumerate(train_loader):
566 |
567 | # measure data loading time
568 | data_time.update(time.time() - end)
569 |
570 | if torch.cuda.is_available():
571 | target = target.cuda()
572 | input_var = input.cuda()
573 | else:
574 | target = target.cpu()
575 | input_var = input.cpu()
576 | target_var = target
577 | if args.half:
578 | input_var = input_var.half()
579 |
580 | optimizer.zero_grad()
581 |
582 | output, kl = model(input_var)
583 | probs_ = torch.nn.functional.softmax(output, dim=1)
584 | probs = probs_.data.cpu().numpy()
585 |
586 | pred_entropy = util.entropy(probs)
587 | unc = np.mean(pred_entropy, axis=0)
588 | preds = np.argmax(probs, axis=-1)
589 |
590 | cross_entropy_loss = criterion(output, target_var)
591 | scaled_kl = kl.data / len_trainset
592 | elbo_loss = cross_entropy_loss + scaled_kl
593 | avu_loss, auc_avu = avu_criterion(output, target_var, type=0)
594 | avu_loss = beta * avu_loss
595 | loss = cross_entropy_loss + scaled_kl + avu_loss
596 |
597 | # compute gradient and do SGD step
598 | loss.backward()
599 | optimizer.step()
600 |
601 | output = output.float()
602 | loss = loss.float()
603 | # measure accuracy and record loss
604 | prec1 = accuracy(output.data, target)[0]
605 | losses.update(loss.item(), input.size(0))
606 | top1.update(prec1.item(), input.size(0))
607 | avg_unc.update(unc, input.size(0))
608 |
609 | # measure elapsed time
610 | batch_time.update(time.time() - end)
611 | end = time.time()
612 |
613 | if i % args.print_freq == 0:
614 | #print('opt_th: ', opt_th)
615 | print('Epoch: [{0}][{1}/{2}]\t'
616 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
617 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
618 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
619 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
620 | 'Avg_Unc {avg_unc.val:.3f} ({avg_unc.avg:.3f})'.format(
621 | epoch,
622 | i,
623 | len(train_loader),
624 | batch_time=batch_time,
625 | data_time=data_time,
626 | loss=losses,
627 | top1=top1,
628 | avg_unc=avg_unc))
629 |
630 | if tb_writer is not None:
631 | tb_writer.add_scalar('train/cross_entropy_loss',
632 | cross_entropy_loss.item(), epoch)
633 | tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch)
634 | tb_writer.add_scalar('train/elbo_loss', elbo_loss.item(), epoch)
635 | tb_writer.add_scalar('train/avu_loss', avu_loss, epoch)
636 | tb_writer.add_scalar('train/loss', loss.item(), epoch)
637 | tb_writer.add_scalar('train/AUC-AvU', auc_avu, epoch)
638 | tb_writer.add_scalar('train/accuracy', prec1.item(), epoch)
639 | tb_writer.flush()
640 |
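`util.entropy` is defined elsewhere in the repo; as a hedged reference for how `pred_entropy` above is computed and then averaged into `Avg_Unc`, a minimal NumPy equivalent might look like this (the function name is illustrative):

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    # Per-sample entropy H(p) = -sum_c p_c * log(p_c) of the softmax output;
    # eps guards against log(0) for saturated probabilities.
    return -np.sum(probs * np.log(probs + eps), axis=-1)

probs = np.array([[0.1] * 10,               # uniform: maximally uncertain
                  [0.91] + [0.01] * 9])     # peaked: confident prediction
ent = predictive_entropy(probs)
print(ent[0] > ent[1])  # the uniform distribution carries higher uncertainty
```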
641 |
642 | def validate(val_loader,
643 | model,
644 | criterion,
645 | avu_criterion,
646 | epoch,
647 | tb_writer=None):
648 | batch_time = AverageMeter()
649 | losses = AverageMeter()
650 | top1 = AverageMeter()
651 | avg_unc = AverageMeter()
652 | global opt_th
653 |
654 | # switch to evaluate mode
655 | model.eval()
656 |
657 | end = time.time()
658 | preds_list = []
659 | labels_list = []
660 | unc_list = []
661 | th_list = np.linspace(0, 1, 21)
662 | with torch.no_grad():
663 | for i, (input, target) in enumerate(val_loader):
664 | if torch.cuda.is_available():
665 | target = target.cuda()
666 | input_var = input.cuda()
667 | target_var = target.cuda()
668 | else:
669 | target = target.cpu()
670 | input_var = input.cpu()
671 | target_var = target.cpu()
672 |
673 | if args.half:
674 | input_var = input_var.half()
675 |
676 | output, kl = model(input_var)
677 | probs_ = torch.nn.functional.softmax(output, dim=1)
678 | probs = probs_.data.cpu().numpy()
679 |
680 | pred_entropy = util.entropy(probs)
681 | unc = np.mean(pred_entropy, axis=0)
682 | preds = np.argmax(probs, axis=-1)
683 | preds_list.append(preds)
684 | labels_list.append(target.cpu().data.numpy())
685 | unc_list.append(pred_entropy)
686 |
687 | cross_entropy_loss = criterion(output, target_var)
688 | scaled_kl = kl.data / len_trainset
689 | elbo_loss = cross_entropy_loss + scaled_kl
690 | avu_loss, auc_avu = avu_criterion(output, target_var, type=0)
691 | avu_loss = beta * avu_loss
692 | loss = cross_entropy_loss + scaled_kl + avu_loss
693 |
694 | output = output.float()
695 | loss = loss.float()
696 |
697 | # measure accuracy and record loss
698 | prec1 = accuracy(output.data, target)[0]
699 | losses.update(loss.item(), input.size(0))
700 | top1.update(prec1.item(), input.size(0))
701 | avg_unc.update(unc, input.size(0))
702 |
703 | # measure elapsed time
704 | batch_time.update(time.time() - end)
705 | end = time.time()
706 |
707 | if i % args.print_freq == 0:
708 | print('Test: [{0}/{1}]\t'
709 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
710 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
711 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
712 | 'Avg_Unc {avg_unc.val:.3f} ({avg_unc.avg:.3f})'.format(
713 | i,
714 | len(val_loader),
715 | batch_time=batch_time,
716 | loss=losses,
717 | top1=top1,
718 | avg_unc=avg_unc))
719 |
720 | if tb_writer is not None:
721 | tb_writer.add_scalar('val/cross_entropy_loss',
722 | cross_entropy_loss.item(), epoch)
723 | tb_writer.add_scalar('val/kl_div', scaled_kl.item(), epoch)
724 | tb_writer.add_scalar('val/elbo_loss', elbo_loss.item(), epoch)
725 | tb_writer.add_scalar('val/avu_loss', avu_loss, epoch)
726 | tb_writer.add_scalar('val/loss', loss.item(), epoch)
727 | tb_writer.add_scalar('val/AUC-AvU', auc_avu, epoch)
728 | tb_writer.add_scalar('val/accuracy', prec1.item(), epoch)
729 | tb_writer.flush()
730 |
731 | preds = np.hstack(preds_list)
732 | labels = np.hstack(labels_list)
733 | unc_ = np.hstack(unc_list)
734 | avu_th, unc_th = util.eval_avu(preds, labels, unc_)
735 | print('max AvU: ', np.amax(avu_th))
736 | unc_correct = np.take(unc_, np.where(preds == labels))
737 | unc_incorrect = np.take(unc_, np.where(preds != labels))
738 | print('avg unc correct preds: ', np.mean(unc_correct, axis=1))
739 | print('avg unc incorrect preds: ', np.mean(unc_incorrect, axis=1))
742 | '''
743 | print('unc @max AvU: ', unc_th[np.argmax(avu_th)])
744 | print('avg unc: ', np.mean(unc_, axis=0))
745 | print('avg unc: ', np.mean(unc_th, axis=0))
746 | print('min unc: ', np.amin(unc_))
747 | print('max unc: ', np.amax(unc_))
748 | '''
749 | if epoch <= 5:
750 | opt_th = (np.mean(unc_correct, axis=1) +
751 | np.mean(unc_incorrect, axis=1)) / 2
752 |
753 | print('opt_th: ', opt_th)
754 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
755 | return top1.avg
756 |
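`util.eval_avu` sweeps uncertainty thresholds for the Accuracy-versus-Uncertainty measure that `opt_th` is tuned against. A self-contained sketch of the AvU score as in the AvUC formulation, counting accurate-and-certain (AC) plus inaccurate-and-uncertain (IU) samples; the function name is illustrative:

```python
import numpy as np

def avu_score(preds, labels, unc, threshold):
    # AvU = (nAC + nIU) / (nAC + nAU + nIC + nIU): the fraction of samples
    # that are either accurate-and-certain or inaccurate-and-uncertain.
    accurate = preds == labels
    uncertain = unc >= threshold
    n_ac = np.sum(accurate & ~uncertain)
    n_au = np.sum(accurate & uncertain)
    n_ic = np.sum(~accurate & ~uncertain)
    n_iu = np.sum(~accurate & uncertain)
    return (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)

preds = np.array([0, 1, 2, 3])
labels = np.array([0, 1, 2, 9])          # last prediction is wrong
unc = np.array([0.1, 0.2, 0.1, 0.9])     # and it is flagged as uncertain
print(avu_score(preds, labels, unc, threshold=0.5))  # -> 1.0
```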
757 |
758 | def evaluate(model, val_loader, corrupt=None, level=None):
759 | test_loss = 0
760 | correct = 0
761 | with torch.no_grad():
762 | pred_probs_mc = []
764 | for batch_idx, (data, target) in enumerate(val_loader):
765 | if torch.cuda.is_available():
766 | data, target = data.cuda(), target.cuda()
767 | else:
768 | data, target = data.cpu(), target.cpu()
769 | for mc_run in range(args.num_monte_carlo):
770 | model.eval()
771 | output, _ = model.forward(data)
772 | pred_probs = torch.nn.functional.softmax(output, dim=1)
773 | pred_probs_mc.append(pred_probs.cpu().data.numpy())
774 |
775 | if corrupt == 'ood':
776 | np.save(args.log_dir + '/preds/svi_avu_ood_probs.npy',
777 | pred_probs_mc)
778 | print('saved predictive probabilities')
779 | return None
780 |
781 | target_labels = target.cpu().data.numpy()
782 | pred_mean = np.mean(pred_probs_mc, axis=0)
783 | #print(pred_mean)
784 | Y_pred = np.argmax(pred_mean, axis=1)
785 | test_acc = (Y_pred == target_labels).mean()
786 | #brier = np.mean(calib.brier_scores(target_labels, probs=pred_mean))
787 | #ece = calib.expected_calibration_error_multiclass(pred_mean, target_labels)
788 | print('Test accuracy:', test_acc * 100)
789 | #print('Brier score: ', brier)
790 | #print('ECE: ', ece)
791 | if corrupt == 'test':  # checked first; the generic branch below would otherwise shadow it
792 | np.save(args.log_dir + '/preds/svi_avu_test_probs.npy',
793 | pred_probs_mc)
794 | np.save(args.log_dir + '/preds/svi_avu_test_labels.npy',
795 | target_labels)
796 | print('saved predictive probabilities')
797 | elif corrupt is not None:
798 | np.save(
799 | args.log_dir +
800 | '/preds/svi_avu_corrupt-static-{}-{}_probs.npy'.format(
801 | corrupt, level), pred_probs_mc)
802 | np.save(
803 | args.log_dir +
804 | '/preds/svi_avu_corrupt-static-{}-{}_labels.npy'.format(
805 | corrupt, level), target_labels)
806 | print('saved predictive probabilities')
807 | return test_acc
808 |
809 |
810 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
811 | """
812 | Save the training model
813 | """
814 | torch.save(state, filename)
815 |
816 |
817 | class AverageMeter(object):
818 | """Computes and stores the average and current value"""
819 | def __init__(self):
820 | self.reset()
821 |
822 | def reset(self):
823 | self.val = 0
824 | self.avg = 0
825 | self.sum = 0
826 | self.count = 0
827 |
828 | def update(self, val, n=1):
829 | self.val = val
830 | self.sum += val * n
831 | self.count += n
832 | self.avg = self.sum / self.count
833 |
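`AverageMeter.update` weights each value by `n` (the batch size), so the running `avg` stays an exact sample mean even when the final batch is smaller. A quick standalone check, re-declaring the class so the snippet runs on its own:

```python
class AverageMeter:
    """Running average of a metric, weighted by the number of samples."""
    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

m = AverageMeter()
m.update(0.5, n=128)   # full batch with mean loss 0.5
m.update(1.0, n=64)    # smaller final batch with mean loss 1.0
print(m.avg)           # (0.5*128 + 1.0*64) / 192
```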
834 |
835 | def accuracy(output, target, topk=(1, )):
836 | """Computes the precision@k for the specified values of k"""
837 | maxk = max(topk)
838 | batch_size = target.size(0)
839 |
840 | _, pred = output.topk(maxk, 1, True, True)
841 | pred = pred.t()
842 | correct = pred.eq(target.view(1, -1).expand_as(pred))
843 |
844 | res = []
845 | for k in topk:
846 | correct_k = correct[:k].reshape(-1).float().sum(0)
847 | res.append(correct_k.mul_(100.0 / batch_size))
848 | return res
849 |
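For reference, `accuracy` above returns precision@k as a percentage. The same top-k logic can be re-derived in NumPy on a tiny batch (values chosen arbitrarily for illustration):

```python
import numpy as np

# logits for 3 samples over 4 classes
output = np.array([[0.1, 0.9, 0.0, 0.0],
                   [0.8, 0.1, 0.0, 0.1],
                   [0.2, 0.3, 0.4, 0.1]])
target = np.array([1, 0, 1])

# top-2 class indices per sample, highest score first (mirrors output.topk)
pred = np.argsort(-output, axis=1)[:, :2]
prec1 = 100.0 * np.mean(pred[:, 0] == target)                       # top-1
prec2 = 100.0 * np.mean([t in row for row, t in zip(pred, target)]) # top-2
print(prec1, prec2)
```

Sample 3 is wrong at top-1 (class 2 beats the true class 1) but recovered at top-2, so precision@2 exceeds precision@1.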
850 |
851 | if __name__ == '__main__':
852 | main()
853 |
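The four-level nested zip/enumerate loops in `main` above hard-code the depth of the module tree. An equivalent recursive traversal makes the lockstep walk explicit; this is a sketch with `MOPED_layer` stubbed out and plain objects standing in for `nn.Module`, so it runs standalone:

```python
visited = []

def MOPED_layer(layer, det_layer, delta):
    # stub: record which (bayesian, deterministic) pair was visited
    visited.append((layer.name, det_layer.name))

class Node:
    """Minimal stand-in for an nn.Module with .children()."""
    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)

def apply_moped_recursive(bayes_module, det_module, delta,
                          depth=0, max_depth=4):
    # walk matching children of both models in lockstep at every level,
    # applying MOPED_layer to each pair (same order as the unrolled loops)
    for bayes_child, det_child in zip(bayes_module.children(),
                                      det_module.children()):
        MOPED_layer(bayes_child, det_child, delta)
        if depth + 1 < max_depth:
            apply_moped_recursive(bayes_child, det_child, delta,
                                  depth + 1, max_depth)

bayes = Node('root', [Node('conv1'), Node('block', [Node('conv2')])])
det = Node('root', [Node('conv1_d'), Node('block_d', [Node('conv2_d')])])
apply_moped_recursive(bayes, det, delta=0.5)
print(visited)
```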
--------------------------------------------------------------------------------
/src/main_bayesian_imagenet_avu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | from torch.nn import functional as F
18 | import torchvision
19 | import torchvision.transforms as transforms
20 | import torchvision.datasets as datasets
21 | import torchvision.models as models
22 | import models.bayesian.resnet_large as resnet
23 | import models.deterministic.resnet_large as det_resnet
24 | from src import util
25 | #from utils import calib
26 | import csv
27 | import numpy as np
28 | from src.util import get_rho
29 | from src.avuc_loss import AvULoss
30 | from torch.utils.tensorboard import SummaryWriter
31 |
32 | torchvision.set_image_backend('accimage')
33 |
34 | model_names = sorted(
35 | name for name in resnet.__dict__
36 | if name.islower() and not name.startswith("__")
37 | and name.startswith("resnet") and callable(resnet.__dict__[name]))
38 |
39 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
40 | parser.add_argument('data',
41 | metavar='DIR',
42 | default='data/imagenet',
43 | help='path to dataset')
44 | parser.add_argument('--corrupt_data',
45 | type=str,
46 | default='data/ImageNet-C',
47 | metavar='N',
48 | help='path to corrupt dataset')
49 | parser.add_argument('-a',
50 | '--arch',
51 | metavar='ARCH',
52 | default='resnet50',
53 | choices=model_names,
54 | help='model architecture: ' + ' | '.join(model_names) +
55 | ' (default: resnet50)')
56 | parser.add_argument('-j',
57 | '--workers',
58 | default=8,
59 | type=int,
60 | metavar='N',
61 | help='number of data loading workers (default: 8)')
62 | parser.add_argument('--epochs',
63 | default=90,
64 | type=int,
65 | metavar='N',
66 | help='number of total epochs to run')
67 | parser.add_argument('--start-epoch',
68 | default=0,
69 | type=int,
70 | metavar='N',
71 | help='manual epoch number (useful on restarts)')
72 | parser.add_argument('--val_batch_size', default=1000, type=int)
73 | parser.add_argument('-b',
74 | '--batch-size',
75 | default=32,
76 | type=int,
77 | metavar='N',
78 | help='mini-batch size (default: 32), this is the total '
79 | 'batch size of all GPUs on the current node when '
80 | 'using Data Parallel or Distributed Data Parallel')
81 | parser.add_argument('--lr',
82 | '--learning-rate',
83 | default=0.001,
84 | type=float,
85 | metavar='LR',
86 | help='initial learning rate',
87 | dest='lr')
88 | parser.add_argument('--momentum',
89 | default=0.9,
90 | type=float,
91 | metavar='M',
92 | help='momentum')
93 | parser.add_argument('--wd',
94 | '--weight-decay',
95 | default=1e-4,
96 | type=float,
97 | metavar='W',
98 | help='weight decay (default: 1e-4)',
99 | dest='weight_decay')
100 | parser.add_argument('-p',
101 | '--print-freq',
102 | default=10,
103 | type=int,
104 | metavar='N',
105 | help='print frequency (default: 10)')
106 | parser.add_argument('--resume',
107 | default='',
108 | type=str,
109 | metavar='PATH',
110 | help='path to latest checkpoint (default: none)')
111 | parser.add_argument('-e',
112 | '--evaluate',
113 | dest='evaluate',
114 | action='store_true',
115 | help='evaluate model on validation set')
116 | parser.add_argument('--pretrained',
117 | dest='pretrained',
118 | action='store_true',
119 | default=True,
120 | help='use pre-trained model')
121 | parser.add_argument('--world-size',
122 | default=-1,
123 | type=int,
124 | help='number of nodes for distributed training')
125 | parser.add_argument('--rank',
126 | default=-1,
127 | type=int,
128 | help='node rank for distributed training')
129 | parser.add_argument('--dist-url',
130 | default='tcp://224.66.41.62:23456',
131 | type=str,
132 | help='url used to set up distributed training')
133 | parser.add_argument('--dist-backend',
134 | default='nccl',
135 | type=str,
136 | help='distributed backend')
137 | parser.add_argument('--seed',
138 | default=None,
139 | type=int,
140 | help='seed for initializing training. ')
141 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
142 | parser.add_argument('--multiprocessing-distributed',
143 | action='store_true',
144 | help='Use multi-processing distributed training to launch '
145 | 'N processes per node, which has N GPUs. This is the '
146 | 'fastest way to use PyTorch for either single node or '
147 | 'multi node data parallel training')
148 | parser.add_argument('--mode', type=str, required=True, help='train | test')
149 | parser.add_argument('--save-dir',
150 | dest='save_dir',
151 | help='The directory used to save the trained models',
152 | default='./checkpoint/imagenet/bayesian_svi_avu',
153 | type=str)
154 | parser.add_argument(
155 | '--tensorboard',
156 | type=bool,
157 | default=True,
158 | metavar='N',
159 | help='use tensorboard for logging and visualization of training progress')
160 | parser.add_argument(
161 | '--log_dir',
162 | type=str,
163 | default='./logs/imagenet/bayesian_svi_avu',
164 | metavar='N',
165 | help='directory for logs, predictions and results')
166 | parser.add_argument('--num_monte_carlo',
167 | type=int,
168 | default=20,
169 | metavar='N',
170 | help='number of Monte Carlo samples')
171 | parser.add_argument(
172 | '--moped',
173 | type=bool,
174 | default=True,
175 | help='set prior and initialize approx posterior with Empirical Bayes')
176 | parser.add_argument('--delta',
177 | type=float,
178 | default=0.5,
179 | help='delta value for variance scaling in MOPED')
180 | best_acc1 = 0
181 | opt_th = 1.0
182 | len_trainset = 1281167
183 | beta = 3.0
184 |
185 |
186 | def MOPED_layer(layer, det_layer, delta):
187 | """
188 | Set the priors and initialize surrogate posteriors of Bayesian NN with Empirical Bayes
189 | MOPED (Model Priors with Empirical Bayes using Deterministic DNN)
190 | Ref: https://arxiv.org/abs/1906.05323
191 | 'Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes'. AAAI 2020.
192 | """
193 |
194 | if (str(layer) == 'Conv2dVariational()'):
195 | #set the priors
196 | print(str(layer))
197 | layer.prior_weight_mu = det_layer.weight.data
198 | if layer.prior_bias_mu is not None:
199 | layer.prior_bias_mu = det_layer.bias.data
200 |
201 | #initialize surrogate posteriors
202 | layer.mu_kernel.data = det_layer.weight.data
203 | if layer.mu_bias is not None:
204 | layer.mu_bias.data = det_layer.bias.data
205 |
206 | elif (isinstance(layer, nn.Conv2d)):
207 | print(str(layer))
208 | layer.weight.data = det_layer.weight.data
209 | if layer.bias is not None:
210 | layer.bias.data = det_layer.bias.data
211 |
212 | elif (str(layer) == 'LinearVariational()'):
213 | print(str(layer))
214 | layer.prior_weight_mu = det_layer.weight.data
215 | if layer.prior_bias_mu is not None:
216 | layer.prior_bias_mu = det_layer.bias.data
217 |
218 | #initialize the surrogate posteriors
219 |
220 | layer.mu_weight.data = det_layer.weight.data
221 | layer.rho_weight.data = get_rho(det_layer.weight.data, delta)
222 | if layer.mu_bias is not None:
223 | layer.mu_bias.data = det_layer.bias.data
224 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
225 |
226 | elif str(layer).startswith('Batch'):
227 | #initialize parameters
228 | print(str(layer))
229 | layer.weight.data = det_layer.weight.data
230 | if layer.bias is not None:
231 | layer.bias.data = det_layer.bias.data
232 | layer.running_mean.data = det_layer.running_mean.data
233 | layer.running_var.data = det_layer.running_var.data
234 | layer.num_batches_tracked.data = det_layer.num_batches_tracked.data
235 |
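`get_rho` (imported from `src.util`, not shown in this chunk) presumably implements the MOPED variance scaling: the surrogate posterior standard deviation is initialized as sigma = delta * |w| and stored as rho via the inverse softplus, so that softplus(rho) recovers sigma. A hedged sketch of that mapping:

```python
import numpy as np

def get_rho_sketch(weight, delta):
    # sigma = delta * |w|; rho is the inverse softplus of sigma, so
    # softplus(rho) = log(1 + exp(rho)) reproduces sigma exactly
    sigma = delta * np.abs(weight)
    return np.log(np.expm1(sigma))

w = np.array([0.5, -1.0, 2.0])
rho = get_rho_sketch(w, delta=0.5)
sigma = np.log1p(np.exp(rho))  # softplus round trip
print(np.allclose(sigma, 0.5 * np.abs(w)))  # -> True
```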
236 | corruptions = [
237 | 'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog',
238 | 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise',
239 | 'pixelate', 'saturate', 'shot_noise', 'spatter', 'speckle_noise',
240 | 'zoom_blur'
241 | ]
242 |
243 |
244 | def get_corrupt_dataloader(args, corrupt_type, level):
245 |
246 | corrupt_dir = os.path.join(args.corrupt_data, str(corrupt_type),
247 | str(level))
248 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
249 | std=[0.229, 0.224, 0.225])
250 | corrupt_dataset = datasets.ImageFolder(
251 | corrupt_dir,
252 | transforms.Compose([
253 | transforms.Resize(256),
254 | transforms.CenterCrop(224),
255 | transforms.ToTensor(),
256 | normalize,
257 | ]))
258 |
259 | corrupt_val_loader = torch.utils.data.DataLoader(
260 | corrupt_dataset,
261 | batch_size=args.val_batch_size,
262 | shuffle=False,
263 | num_workers=args.workers,
264 | pin_memory=True)
265 |
266 | return corrupt_val_loader
267 |
268 |
269 | def main():
270 | args = parser.parse_args()
271 |
272 | if args.seed is not None:
273 | random.seed(args.seed)
274 | torch.manual_seed(args.seed)
275 | cudnn.deterministic = True
276 | warnings.warn('You have chosen to seed training. '
277 | 'This will turn on the CUDNN deterministic setting, '
278 | 'which can slow down your training considerably! '
279 | 'You may see unexpected behavior when restarting '
280 | 'from checkpoints.')
281 |
282 | if args.gpu is not None:
283 | warnings.warn('You have chosen a specific GPU. This will completely '
284 | 'disable data parallelism.')
285 |
286 | if args.dist_url == "env://" and args.world_size == -1:
287 | args.world_size = int(os.environ["WORLD_SIZE"])
288 |
289 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
290 |
291 | if torch.cuda.is_available():
292 | ngpus_per_node = torch.cuda.device_count()
293 | if args.multiprocessing_distributed:
294 | # Since we have ngpus_per_node processes per node, the total world_size
295 | # needs to be adjusted accordingly
296 | args.world_size = ngpus_per_node * args.world_size
297 | # Use torch.multiprocessing.spawn to launch distributed processes: the
298 | # main_worker process function
299 | mp.spawn(main_worker,
300 | nprocs=ngpus_per_node,
301 | args=(ngpus_per_node, args))
302 | else:
303 | # Simply call main_worker function
304 | main_worker(args.gpu, ngpus_per_node, args)
305 |
306 |
307 | def main_worker(gpu, ngpus_per_node, args):
308 | global best_acc1
309 | args.gpu = gpu
310 |
311 | if args.gpu is not None:
312 | print("Use GPU: {} for training".format(args.gpu))
313 |
314 | if args.distributed:
315 | if args.dist_url == "env://" and args.rank == -1:
316 | args.rank = int(os.environ["RANK"])
317 | if args.multiprocessing_distributed:
318 | # For multiprocessing distributed training, rank needs to be the
319 | # global rank among all the processes
320 | args.rank = args.rank * ngpus_per_node + gpu
321 | dist.init_process_group(backend=args.dist_backend,
322 | init_method=args.dist_url,
323 | world_size=args.world_size,
324 | rank=args.rank)
325 |
326 | if not os.path.exists(args.save_dir):
327 | os.makedirs(args.save_dir)
328 |
329 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
330 | if torch.cuda.is_available():
331 | model.cuda()
332 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
333 | avu_criterion = AvULoss().cuda()
334 | else:
335 | model.cpu()
336 | criterion = nn.CrossEntropyLoss().cpu()
337 | avu_criterion = AvULoss().cpu()
338 |
339 | optimizer = torch.optim.SGD(model.parameters(),
340 | args.lr,
341 | momentum=args.momentum,
342 | weight_decay=args.weight_decay)
343 |
344 | # optionally resume from a checkpoint
345 | if args.resume:
346 | if os.path.isfile(args.resume):
347 | print("=> loading checkpoint '{}'".format(args.resume))
348 | if args.gpu is None:
349 | checkpoint = torch.load(args.resume)
350 | else:
351 | # Map model to be loaded to specified single gpu.
352 | loc = 'cuda:{}'.format(args.gpu)
353 | checkpoint = torch.load(args.resume, map_location=loc)
354 | args.start_epoch = checkpoint['epoch']
355 | best_acc1 = checkpoint['best_acc1']
356 | if args.gpu is not None:
357 | # best_acc1 may be from a checkpoint from a different GPU
358 | best_acc1 = best_acc1.to(args.gpu)
359 | model.load_state_dict(checkpoint['state_dict'])
360 | optimizer.load_state_dict(checkpoint['optimizer'])
361 | print("=> loaded checkpoint '{}' (epoch {})".format(
362 | args.resume, checkpoint['epoch']))
363 | else:
364 | print("=> no checkpoint found at '{}'".format(args.resume))
365 |
366 | cudnn.benchmark = True
367 |
368 | tb_writer = None
369 | if args.tensorboard:
370 | logger_dir = os.path.join(args.log_dir, 'tb_logger')
371 | if not os.path.exists(logger_dir):
372 | os.makedirs(logger_dir)
373 | tb_writer = SummaryWriter(logger_dir)
374 |
375 | preds_dir = os.path.join(args.log_dir, 'preds')
376 | if not os.path.exists(preds_dir):
377 | os.makedirs(preds_dir)
378 | results_dir = os.path.join(args.log_dir, 'results')
379 | if not os.path.exists(results_dir):
380 | os.makedirs(results_dir)
381 |
382 | # Data loading code
383 | traindir = os.path.join(args.data, 'train')
384 | valdir = os.path.join(args.data, 'val')
385 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
386 | std=[0.229, 0.224, 0.225])
387 |
388 | train_dataset = datasets.ImageFolder(
389 | traindir,
390 | transforms.Compose([
391 | transforms.RandomResizedCrop(224),
392 | transforms.RandomHorizontalFlip(),
393 | transforms.ToTensor(),
394 | normalize,
395 | ]))
396 |
397 | val_dataset = datasets.ImageFolder(
398 | valdir,
399 | transforms.Compose([
400 | transforms.Resize(256),
401 | transforms.CenterCrop(224),
402 | transforms.ToTensor(),
403 | normalize,
404 | ]))
405 |
406 | print('len trainset: ', len(train_dataset))
407 | print('len valset: ', len(val_dataset))
408 | len_trainset = len(train_dataset)
409 |
410 | if args.distributed:
411 | train_sampler = torch.utils.data.distributed.DistributedSampler(
412 | train_dataset)
413 | else:
414 | train_sampler = None
415 |
416 | train_loader = torch.utils.data.DataLoader(train_dataset,
417 | batch_size=args.batch_size,
418 | shuffle=(train_sampler is None),
419 | sampler=train_sampler,
420 | num_workers=args.workers,
421 | pin_memory=True, drop_last=True)
422 |
423 | val_loader = torch.utils.data.DataLoader(val_dataset,
424 | batch_size=args.val_batch_size,
425 | shuffle=False,
426 | num_workers=args.workers,
427 | pin_memory=True)
428 |
429 | if args.evaluate:
430 | validate(val_loader, model, criterion, args)
431 | return
432 |
433 | if args.mode == 'train':
434 |
435 | if (args.moped):
436 | print("MOPED enabled")
437 | det_model = torch.nn.DataParallel(
438 | det_resnet.__dict__[args.arch](pretrained=True))
439 | if torch.cuda.is_available():
440 | det_model.cuda()
441 | else:
442 | det_model.cpu()
443 |
444 | for (idx_1, layer_1), (det_idx_1, det_layer_1) in zip(
445 | enumerate(model.children()),
446 | enumerate(det_model.children())):
447 | MOPED_layer(layer_1, det_layer_1, args.delta)
448 | for (idx_2, layer_2), (det_idx_2, det_layer_2) in zip(
449 | enumerate(layer_1.children()),
450 | enumerate(det_layer_1.children())):
451 | MOPED_layer(layer_2, det_layer_2, args.delta)
452 | for (idx_3, layer_3), (det_idx_3, det_layer_3) in zip(
453 | enumerate(layer_2.children()),
454 | enumerate(det_layer_2.children())):
455 | MOPED_layer(layer_3, det_layer_3, args.delta)
456 | for (idx_4, layer_4), (det_idx_4, det_layer_4) in zip(
457 | enumerate(layer_3.children()),
458 | enumerate(det_layer_3.children())):
459 | MOPED_layer(layer_4, det_layer_4, args.delta)
460 |                         for (idx_5, layer_5), (det_idx_5, det_layer_5) in zip(
461 |                                 enumerate(layer_4.children()),
462 |                                 enumerate(det_layer_4.children())):
463 |                             MOPED_layer(layer_5, det_layer_5, args.delta)
464 |                             for (idx_6, layer_6), (det_idx_6, det_layer_6) in zip(
465 |                                     enumerate(layer_5.children()),
466 |                                     enumerate(det_layer_5.children())):
467 |                                 MOPED_layer(layer_6, det_layer_6, args.delta)
471 |
472 | model.state_dict()
473 | del det_model
474 |
475 | for epoch in range(args.start_epoch, args.epochs):
476 | if args.distributed:
477 | train_sampler.set_epoch(epoch)
478 | adjust_learning_rate(optimizer, epoch, args)
479 |
480 | # train for one epoch
481 | train(train_loader, model, criterion, avu_criterion, optimizer,
482 | epoch, args, tb_writer)
483 |
484 | # evaluate on validation set
485 | acc1 = validate(val_loader, model, criterion, avu_criterion, epoch,
486 | args, tb_writer)
487 |
488 | # remember best acc@1 and save checkpoint
489 | is_best = acc1 > best_acc1
490 | best_acc1 = max(acc1, best_acc1)
491 |
492 | if is_best:
493 | save_checkpoint(
494 | {
495 | 'epoch': epoch + 1,
496 | 'arch': args.arch,
497 | 'state_dict': model.state_dict(),
498 | 'best_acc1': best_acc1,
499 | 'optimizer': optimizer.state_dict(),
500 | },
501 | is_best,
502 | filename=os.path.join(
503 | args.save_dir,
504 | 'bayesian_{}_imagenet.pth'.format(args.arch)))
505 |
506 | elif args.mode == 'test':
507 |
508 | checkpoint_file = args.save_dir + '/bayesian_{}_imagenet.pth'.format(
509 | args.arch)
510 | if torch.cuda.is_available():
511 | checkpoint = torch.load(checkpoint_file)
512 | else:
513 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
514 |         print('loaded checkpoint')
515 | model.load_state_dict(checkpoint['state_dict'])
516 |
517 | #header = ['corrupt', 'test_acc', 'brier', 'ece']
518 | header = ['corrupt', 'test_acc']
519 |
520 | #Evaluate on test dataset
521 | #test_acc, brier, ece = evaluate(model, val_loader, args, corrupt=None, level=None)
522 | test_acc = evaluate(model, val_loader, args, corrupt=None, level=None)
523 | print('******Test data***********\n')
524 | #print('test_acc: ', test_acc, ' | Brier: ', brier, ' | ECE: ', ece, '\n')
525 | print('test_acc: ', test_acc)
526 | '''
527 | t_file = args.log_dir + '/results/test_results.csv'
528 | with open(t_file, 'wt') as t_file:
529 | writer = csv.writer(t_file, delimiter=',', lineterminator='\n')
530 | writer.writerow([j for j in header])
531 | writer.writerow(['test', test_acc, brier, ece])
532 | t_file.close()
533 | '''
534 |
535 | for level in range(1, 6):
536 | print('******Corruption Level: ', level, ' ***********\n')
537 | results_file = args.log_dir + '/results/level' + str(
538 | level) + '.csv'
539 | with open(results_file, 'wt') as results_file:
540 | writer = csv.writer(results_file,
541 | delimiter=',',
542 | lineterminator='\n')
543 |                 writer.writerow(header)
544 | for c in corruptions:
545 | val_loader = get_corrupt_dataloader(args, c, level)
546 | test_acc = evaluate(model,
547 | val_loader,
548 | args,
549 | corrupt=c,
550 | level=level)
551 | print('############ Corruption type: ', c,
552 | ' ################')
553 | print('test_acc: ', test_acc, '\n')
554 | writer.writerow([c, test_acc])
555 |                 # file is closed automatically when the with-block exits
556 |
557 |
558 | def train(train_loader, model, criterion, avu_criterion, optimizer, epoch,
559 | args, tb_writer):
560 | batch_time = AverageMeter('Time', ':6.3f')
561 | data_time = AverageMeter('Data', ':6.3f')
562 | losses = AverageMeter('Loss', ':.4e')
563 | top1 = AverageMeter('Acc@1', ':6.2f')
564 | top5 = AverageMeter('Acc@5', ':6.2f')
565 |     global opt_th
566 |     preds_list, labels_list, unc_list = [], [], []  # collected per epoch
566 | progress = ProgressMeter(len(train_loader),
567 | [batch_time, data_time, losses, top1, top5],
568 | prefix="Epoch: [{}]".format(epoch))
569 |
570 | # switch to train mode
571 | model.train()
572 |
573 | end = time.time()
574 | for i, (images, target) in enumerate(train_loader):
575 | # measure data loading time
576 | data_time.update(time.time() - end)
577 |
578 | if torch.cuda.is_available():
579 | images = images.cuda(non_blocking=True)
580 | target = target.cuda(non_blocking=True)
581 | else:
582 |             images = images.cpu()
583 |             target = target.cpu()
584 |
585 | # compute output
586 | output, kl = model(images)
587 | probs_ = torch.nn.functional.softmax(output, dim=1)
588 | probs = probs_.data.cpu().numpy()
589 |
590 | pred_entropy = util.entropy(probs)
591 | preds = np.argmax(probs, axis=-1)
592 | AvU = util.accuracy_vs_uncertainty(np.array(preds),
593 | np.array(target.cpu().data.numpy()),
594 | np.array(pred_entropy), opt_th)
595 |
596 | preds_list.append(preds)
597 | labels_list.append(target.cpu().data.numpy())
598 | unc_list.append(pred_entropy)
599 |
600 | cross_entropy_loss = criterion(output, target)
601 | scaled_kl = (kl.data[0] / len_trainset)
602 | elbo_loss = cross_entropy_loss + scaled_kl
603 | avu_loss = beta * avu_criterion(output, target, opt_th, type=0)
604 | loss = cross_entropy_loss + scaled_kl + avu_loss
605 |
606 | output = output.float()
607 | loss = loss.float()
608 |
609 | # measure accuracy and record loss
610 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
611 | losses.update(loss.item(), images.size(0))
612 | top1.update(acc1[0], images.size(0))
613 | top5.update(acc5[0], images.size(0))
614 |
615 | # compute gradient and do SGD step
616 | optimizer.zero_grad()
617 | loss.mean().backward()
618 | optimizer.step()
619 |
620 | # measure elapsed time
621 | batch_time.update(time.time() - end)
622 | end = time.time()
623 |
624 | if i % args.print_freq == 0:
625 | progress.display(i)
626 |
627 | if tb_writer is not None:
628 | tb_writer.add_scalar('train/cross_entropy_loss',
629 | cross_entropy_loss.item(), epoch)
630 | tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch)
631 | tb_writer.add_scalar('train/elbo_loss', elbo_loss.item(), epoch)
632 | tb_writer.add_scalar('train/avu_loss', avu_loss.item(), epoch)
633 | tb_writer.add_scalar('train/loss', loss.item(), epoch)
634 | tb_writer.add_scalar('train/AvU', AvU, epoch)
635 | tb_writer.add_scalar('train/accuracy', acc1.item(), epoch)
636 | tb_writer.flush()
637 |
638 | preds = np.hstack(np.asarray(preds_list))
639 | labels = np.hstack(np.asarray(labels_list))
640 | unc_ = np.hstack(np.asarray(unc_list))
641 | unc_correct = np.take(unc_, np.where(preds == labels))
642 | unc_incorrect = np.take(unc_, np.where(preds != labels))
643 | #print('avg unc correct preds: ', np.mean(np.take(unc_,np.where(preds == labels)), axis=1))
644 | #print('avg unc incorrect preds: ', np.mean(np.take(unc_,np.where(preds != labels)), axis=1))
645 | if epoch <= 1:
646 | opt_th = (np.mean(unc_correct, axis=1) +
647 | np.mean(unc_incorrect, axis=1)) / 2
648 |
649 | print('opt_th: ', opt_th)
650 |
651 |
652 | def validate(val_loader, model, criterion, avu_criterion, epoch, args,
653 | tb_writer):
654 | batch_time = AverageMeter('Time', ':6.3f')
655 | losses = AverageMeter('Loss', ':.4e')
656 | top1 = AverageMeter('Acc@1', ':6.2f')
657 | top5 = AverageMeter('Acc@5', ':6.2f')
658 | progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
659 | prefix='Test: ')
660 |
661 | # switch to evaluate mode
662 | model.eval()
663 |
664 | preds_list = []
665 | labels_list = []
666 | unc_list = []
667 | with torch.no_grad():
668 | end = time.time()
669 | for i, (images, target) in enumerate(val_loader):
670 | if torch.cuda.is_available():
671 | images = images.cuda(non_blocking=True)
672 | target = target.cuda(non_blocking=True)
673 | else:
674 |                 images = images.cpu()
675 |                 target = target.cpu()
676 |
677 | # compute output
678 | output, kl = model(images)
679 | cross_entropy_loss = criterion(output, target)
680 | scaled_kl = (kl.data[0] / len_trainset)
681 | elbo_loss = cross_entropy_loss + scaled_kl
682 | avu_loss = beta * avu_criterion(output, target, opt_th, type=0)
683 | loss = cross_entropy_loss + scaled_kl + avu_loss
684 |
685 | output = output.float()
686 | loss = loss.float()
687 |
688 | # measure accuracy and record loss
689 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
690 | losses.update(loss.item(), images.size(0))
691 | top1.update(acc1[0], images.size(0))
692 | top5.update(acc5[0], images.size(0))
693 |
694 | # measure elapsed time
695 | batch_time.update(time.time() - end)
696 | end = time.time()
697 |
698 | if i % args.print_freq == 0:
699 | progress.display(i)
700 |
701 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
702 | top5=top5))
703 |
704 | return top1.avg
705 |
706 |
707 | def evaluate(model, val_loader, args, corrupt=None, level=None):
708 |     with torch.no_grad():
709 |         output_list = []
710 |         label_list = []
715 | for batch_idx, (data, target) in enumerate(val_loader):
716 | #print('Batch idx {}, data shape {}, target shape {}'.format(batch_idx, data.shape, target.shape))
717 | if torch.cuda.is_available():
718 | data, target = data.cuda(), target.cuda()
719 | else:
720 | data, target = data.cpu(), target.cpu()
721 |             output_mc_np = []
723 | for mc_run in range(args.num_monte_carlo):
724 | model.eval()
725 | output, _ = model.forward(data)
726 | #output_mc_np.append(output.data.cpu().numpy())
727 | pred_probs = torch.nn.functional.softmax(output, dim=1)
728 | output_mc_np.append(pred_probs.cpu().data.numpy())
729 |
730 | output_mc = torch.from_numpy(
731 | np.mean(np.asarray(output_mc_np), axis=0))
732 | output_list.append(output_mc)
733 | label_list.append(target)
734 |
735 | if torch.cuda.is_available():
736 | labels = torch.cat(label_list).cuda()
737 | probs = torch.cat(output_list).cuda()
738 | else:
739 | labels = torch.cat(label_list).cpu()
740 | probs = torch.cat(output_list).cpu()
741 |
742 | target_labels = labels.data.cpu().numpy()
743 | pred_mean = probs.data.cpu().numpy()
744 | Y_pred = np.argmax(pred_mean, axis=1)
745 | test_acc = (Y_pred == target_labels).mean()
746 | #brier = np.mean(calib.brier_scores(target_labels, probs=pred_mean))
747 | #ece = calib.expected_calibration_error_multiclass(pred_mean, target_labels)
748 | print('Test accuracy:', test_acc * 100)
749 | #print('Brier score: ', brier)
750 | #print('ECE: ', ece)
751 | if corrupt is not None:
752 | np.save(
753 | args.log_dir +
754 | '/preds/svi_avu_corrupt-static-{}-{}_probs.npy'.format(
755 | corrupt, level), pred_mean)
756 | np.save(
757 | args.log_dir +
758 | '/preds/svi_avu_corrupt-static-{}-{}_labels.npy'.format(
759 | corrupt, level), target_labels)
760 | print('saved predictions')
761 | else:
762 | np.save(args.log_dir + '/preds/svi_avu_test_probs.npy', pred_mean)
763 | np.save(args.log_dir + '/preds/svi_avu_test_labels.npy',
764 | target_labels)
765 | print('saved predictions')
766 |
767 | return test_acc
768 |
769 |
770 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
771 | torch.save(state, filename)
772 | if is_best:
773 | shutil.copyfile(filename, 'model_best.pth.tar')
774 |
775 |
776 | class AverageMeter(object):
777 | """Computes and stores the average and current value"""
778 | def __init__(self, name, fmt=':f'):
779 | self.name = name
780 | self.fmt = fmt
781 | self.reset()
782 |
783 | def reset(self):
784 | self.val = 0
785 | self.avg = 0
786 | self.sum = 0
787 | self.count = 0
788 |
789 | def update(self, val, n=1):
790 | self.val = val
791 | self.sum += val * n
792 | self.count += n
793 | self.avg = self.sum / self.count
794 |
795 | def __str__(self):
796 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
797 | return fmtstr.format(**self.__dict__)
798 |
799 |
800 | class ProgressMeter(object):
801 | def __init__(self, num_batches, meters, prefix=""):
802 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
803 | self.meters = meters
804 | self.prefix = prefix
805 |
806 | def display(self, batch):
807 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
808 | entries += [str(meter) for meter in self.meters]
809 | print('\t'.join(entries))
810 |
811 | def _get_batch_fmtstr(self, num_batches):
812 |         num_digits = len(str(num_batches))
813 | fmt = '{:' + str(num_digits) + 'd}'
814 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
815 |
816 |
817 | def adjust_learning_rate(optimizer, epoch, args):
818 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
819 | lr = args.lr * (0.1**(epoch // 30))
820 | for param_group in optimizer.param_groups:
821 | param_group['lr'] = lr
822 |
823 |
824 | def accuracy(output, target, topk=(1, )):
825 | """Computes the accuracy over the k top predictions for the specified values of k"""
826 | with torch.no_grad():
827 | maxk = max(topk)
828 | batch_size = target.size(0)
829 |
830 | _, pred = output.topk(maxk, 1, True, True)
831 | pred = pred.t()
832 | correct = pred.eq(target.view(1, -1).expand_as(pred))
833 |
834 | res = []
835 | for k in topk:
836 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
837 | res.append(correct_k.mul_(100.0 / batch_size))
838 | return res
839 |
840 |
841 | if __name__ == '__main__':
842 | main()
843 |
--------------------------------------------------------------------------------
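The training loop above gates the AvU loss on a per-epoch uncertainty threshold (`opt_th`, the midpoint between the mean predictive entropy of correct and incorrect predictions) and logs the AvU metric each epoch. A minimal NumPy sketch of that metric, written here as a hypothetical `avu_score` helper mirroring what `util.accuracy_vs_uncertainty` computes:

```python
import numpy as np

def avu_score(preds, labels, uncertainty, threshold):
    """Accuracy-vs-Uncertainty: fraction of predictions that are either
    accurate-and-certain or inaccurate-and-uncertain (sketch of the
    metric used by util.accuracy_vs_uncertainty)."""
    accurate = preds == labels
    certain = uncertainty <= threshold
    n_ac = np.sum(accurate & certain)    # accurate and certain
    n_au = np.sum(accurate & ~certain)   # accurate but uncertain
    n_ic = np.sum(~accurate & certain)   # inaccurate yet certain
    n_iu = np.sum(~accurate & ~certain)  # inaccurate and uncertain
    return (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)

preds = np.array([0, 1, 2, 1])
labels = np.array([0, 1, 1, 1])
unc = np.array([0.1, 0.2, 0.9, 0.8])
print(avu_score(preds, labels, unc, threshold=0.5))  # 0.75
```

A well-calibrated model scores close to 1.0: it is certain when right and uncertain when wrong.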
/src/main_deterministic_cifar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import shutil
5 | import time
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.utils.data
13 | from torch.utils.tensorboard import SummaryWriter
14 | import torchvision.transforms as transforms
15 | import torchvision.datasets as datasets
16 | import models.deterministic.resnet as resnet
17 | import numpy as np
18 |
19 | model_names = sorted(
20 | name for name in resnet.__dict__
21 | if name.islower() and not name.startswith("__")
22 | and name.startswith("resnet") and callable(resnet.__dict__[name]))
23 |
24 | print(model_names)
25 |
26 | parser = argparse.ArgumentParser(description='CIFAR10')
27 | parser.add_argument('--arch',
28 | '-a',
29 | metavar='ARCH',
30 | default='resnet20',
31 | choices=model_names,
32 | help='model architecture: ' + ' | '.join(model_names) +
33 | ' (default: resnet20)')
34 | parser.add_argument('-j',
35 | '--workers',
36 | default=8,
37 | type=int,
38 | metavar='N',
39 | help='number of data loading workers (default: 8)')
40 | parser.add_argument('--epochs',
41 | default=200,
42 | type=int,
43 | metavar='N',
44 | help='number of total epochs to run')
45 | parser.add_argument('--start-epoch',
46 | default=0,
47 | type=int,
48 | metavar='N',
49 | help='manual epoch number (useful on restarts)')
50 | parser.add_argument('-b',
51 | '--batch-size',
52 | default=512,
53 | type=int,
54 | metavar='N',
55 | help='mini-batch size (default: 512)')
56 | parser.add_argument('--lr',
57 | '--learning-rate',
58 | default=0.1,
59 | type=float,
60 | metavar='LR',
61 | help='initial learning rate')
62 | parser.add_argument('--momentum',
63 | default=0.9,
64 | type=float,
65 | metavar='M',
66 | help='momentum')
67 | parser.add_argument('--weight-decay',
68 | '--wd',
69 | default=1e-4,
70 | type=float,
71 | metavar='W',
72 |                     help='weight decay (default: 1e-4)')
73 | parser.add_argument('--print-freq',
74 | '-p',
75 | default=50,
76 | type=int,
77 | metavar='N',
78 |                     help='print frequency (default: 50)')
79 | parser.add_argument('--resume',
80 | default='',
81 | type=str,
82 | metavar='PATH',
83 | help='path to latest checkpoint (default: none)')
84 | parser.add_argument('-e',
85 | '--evaluate',
86 | dest='evaluate',
87 | action='store_true',
88 | help='evaluate model on validation set')
89 | parser.add_argument('--pretrained',
90 | dest='pretrained',
91 | action='store_true',
92 | help='use pre-trained model')
93 | parser.add_argument('--half',
94 | dest='half',
95 | action='store_true',
96 | help='use half-precision(16-bit) ')
97 | parser.add_argument('--save-dir',
98 | dest='save_dir',
99 | help='The directory used to save the trained models',
100 | default='./checkpoint/deterministic',
101 | type=str)
102 | parser.add_argument(
103 | '--save-every',
104 | dest='save_every',
105 | help='Saves checkpoints at every specified number of epochs',
106 | type=int,
107 | default=10)
108 | parser.add_argument('--mode', type=str, required=True, help='train | test')
109 | parser.add_argument(
110 | '--tensorboard',
111 | type=bool,
112 | default=True,
113 | metavar='N',
114 | help='use tensorboard for logging and visualization of training progress')
115 | parser.add_argument(
116 | '--log_dir',
117 | type=str,
118 | default='./logs/cifar/deterministic',
119 | metavar='N',
120 |     help='directory for tensorboard logs and saved predictions')
121 | best_prec1 = 0
122 |
123 |
124 | def main():
125 | global args, best_prec1
126 | args = parser.parse_args()
127 |
128 | # Check the save_dir exists or not
129 | if not os.path.exists(args.save_dir):
130 | os.makedirs(args.save_dir)
131 |
132 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
133 | if torch.cuda.is_available():
134 | model.cuda()
135 | else:
136 | model.cpu()
137 |
138 | # optionally resume from a checkpoint
139 | if args.resume:
140 | if os.path.isfile(args.resume):
141 | print("=> loading checkpoint '{}'".format(args.resume))
142 | checkpoint = torch.load(args.resume)
143 | args.start_epoch = checkpoint['epoch']
144 | best_prec1 = checkpoint['best_prec1']
145 | model.load_state_dict(checkpoint['state_dict'])
146 |             print("=> loaded checkpoint '{}' (epoch {})".format(
147 |                 args.resume, checkpoint['epoch']))
148 | else:
149 | print("=> no checkpoint found at '{}'".format(args.resume))
150 |
151 | cudnn.benchmark = True
152 |
153 | tb_writer = None
154 | if args.tensorboard:
155 | logger_dir = os.path.join(args.log_dir, 'tb_logger')
156 | if not os.path.exists(logger_dir):
157 | os.makedirs(logger_dir)
158 | tb_writer = SummaryWriter(logger_dir)
159 |
160 | preds_dir = os.path.join(args.log_dir, 'preds')
161 | if not os.path.exists(preds_dir):
162 | os.makedirs(preds_dir)
163 | results_dir = os.path.join(args.log_dir, 'results')
164 | if not os.path.exists(results_dir):
165 | os.makedirs(results_dir)
166 |
167 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
168 | std=[0.229, 0.224, 0.225])
169 |
170 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
171 | root='./data',
172 | train=True,
173 | transform=transforms.Compose([
174 | transforms.RandomHorizontalFlip(),
175 | transforms.RandomCrop(32, 4),
176 | transforms.ToTensor(),
177 | ]),
178 | download=True),
179 | batch_size=args.batch_size,
180 | shuffle=True,
181 | num_workers=args.workers,
182 | pin_memory=True)
183 |
184 | val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
185 | root='./data',
186 | train=False,
187 | transform=transforms.Compose([
188 | transforms.ToTensor(),
189 | ])),
190 | batch_size=args.batch_size,
191 | shuffle=False,
192 | num_workers=args.workers,
193 | pin_memory=True)
194 |
195 | if not os.path.exists(args.save_dir):
196 | os.makedirs(args.save_dir)
197 |
198 | if torch.cuda.is_available():
199 | criterion = nn.CrossEntropyLoss().cuda()
200 | else:
201 | criterion = nn.CrossEntropyLoss().cpu()
202 |
203 | if args.half:
204 | model.half()
205 | criterion.half()
206 |
207 | optimizer = torch.optim.SGD(model.parameters(),
208 | args.lr,
209 | momentum=args.momentum,
210 | weight_decay=args.weight_decay)
211 |
212 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
213 | optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
214 |
215 | if args.arch in ['resnet110']:
216 | for param_group in optimizer.param_groups:
217 | param_group['lr'] = args.lr * 0.1
218 |
219 | if args.evaluate:
220 | validate(val_loader, model, criterion)
221 | return
222 |
223 | if args.mode == 'train':
224 | for epoch in range(args.start_epoch, args.epochs):
225 |
226 | # train for one epoch
227 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
228 | train(train_loader, model, criterion, optimizer, epoch, tb_writer)
229 | lr_scheduler.step()
230 |
231 | prec1 = validate(val_loader, model, criterion, epoch, tb_writer)
232 |
233 | is_best = prec1 > best_prec1
234 | best_prec1 = max(prec1, best_prec1)
235 |
236 | if epoch > 0 and epoch % args.save_every == 0:
237 | if is_best:
238 | save_checkpoint(
239 | {
240 | 'epoch': epoch + 1,
241 | 'state_dict': model.state_dict(),
242 | 'best_prec1': best_prec1,
243 | },
244 | is_best,
245 | filename=os.path.join(
246 | args.save_dir,
247 | 'det_{}_cifar.pth'.format(args.arch)))
248 |
249 | elif args.mode == 'test':
250 | checkpoint_file = args.save_dir + '/det_{}_cifar.pth'.format(args.arch)
251 | if torch.cuda.is_available():
252 | checkpoint = torch.load(checkpoint_file)
253 | else:
254 | checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
255 | model.load_state_dict(checkpoint['state_dict'])
256 | evaluate(model, val_loader)
257 |
258 |
259 | def train(train_loader, model, criterion, optimizer, epoch, tb_writer=None):
260 | """
261 | Run one train epoch
262 | """
263 | batch_time = AverageMeter()
264 | data_time = AverageMeter()
265 | losses = AverageMeter()
266 | top1 = AverageMeter()
267 |
268 | # switch to train mode
269 | model.train()
270 |
271 | end = time.time()
272 | for i, (input, target) in enumerate(train_loader):
273 |
274 | # measure data loading time
275 | data_time.update(time.time() - end)
276 |
277 | if torch.cuda.is_available():
278 | target = target.cuda()
279 | input_var = input.cuda()
280 | target_var = target.cuda()
281 | else:
282 | target = target.cpu()
283 | input_var = input.cpu()
284 | target_var = target.cpu()
285 | if args.half:
286 | input_var = input_var.half()
287 |
288 | # compute output
289 | output = model(input_var)
290 | loss = criterion(output, target_var)
291 |
292 | # compute gradient and do SGD step
293 | optimizer.zero_grad()
294 | loss.backward()
295 | optimizer.step()
296 |
297 | output = output.float()
298 | loss = loss.float()
299 | # measure accuracy and record loss
300 | prec1 = accuracy(output.data, target)[0]
301 | losses.update(loss.item(), input.size(0))
302 | top1.update(prec1.item(), input.size(0))
303 |
304 | # measure elapsed time
305 | batch_time.update(time.time() - end)
306 | end = time.time()
307 |
308 | if i % args.print_freq == 0:
309 | print('Epoch: [{0}][{1}/{2}]\t'
310 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
311 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
312 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
313 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
314 | epoch,
315 | i,
316 | len(train_loader),
317 | batch_time=batch_time,
318 | data_time=data_time,
319 | loss=losses,
320 | top1=top1))
321 |
322 | if tb_writer is not None:
323 | tb_writer.add_scalar('train/loss', loss.item(), epoch)
324 | tb_writer.add_scalar('train/accuracy', prec1.item(), epoch)
325 | tb_writer.flush()
326 |
327 |
328 | def validate(val_loader, model, criterion, epoch, tb_writer=None):
329 | """
330 | Run evaluation
331 | """
332 | batch_time = AverageMeter()
333 | losses = AverageMeter()
334 | top1 = AverageMeter()
335 |
336 | # switch to evaluate mode
337 | model.eval()
338 |
339 | end = time.time()
340 | with torch.no_grad():
341 | for i, (input, target) in enumerate(val_loader):
342 | if torch.cuda.is_available():
343 | target = target.cuda()
344 | input_var = input.cuda()
345 | target_var = target.cuda()
346 | else:
347 | target = target.cpu()
348 | input_var = input.cpu()
349 | target_var = target.cpu()
350 |
351 | if args.half:
352 | input_var = input_var.half()
353 |
354 | # compute output
355 | output = model(input_var)
356 | loss = criterion(output, target_var)
357 |
358 | output = output.float()
359 | loss = loss.float()
360 |
361 | # measure accuracy and record loss
362 | prec1 = accuracy(output.data, target)[0]
363 | losses.update(loss.item(), input.size(0))
364 | top1.update(prec1.item(), input.size(0))
365 |
366 | # measure elapsed time
367 | batch_time.update(time.time() - end)
368 | end = time.time()
369 |
370 | if i % args.print_freq == 0:
371 | print('Test: [{0}/{1}]\t'
372 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
373 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
374 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
375 | i,
376 | len(val_loader),
377 | batch_time=batch_time,
378 | loss=losses,
379 | top1=top1))
380 |
381 | if tb_writer is not None:
382 | tb_writer.add_scalar('val/loss', loss.item(), epoch)
383 | tb_writer.add_scalar('val/accuracy', prec1.item(), epoch)
384 | tb_writer.flush()
385 |
386 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
387 |
388 | return top1.avg
389 |
390 |
391 | def evaluate(model, val_loader):
392 |     model.eval()
393 |     correct = 0
394 |     probs_list, label_list = [], []
395 |     with torch.no_grad():
396 |         for data, target in val_loader:
397 |             if torch.cuda.is_available():
398 |                 data, target = data.cuda(), target.cuda()
399 |             else:
400 |                 data, target = data.cpu(), target.cpu()
401 |             output = model(data)
402 |             output = torch.nn.functional.softmax(output, dim=1)
403 |             pred = output.argmax(dim=1, keepdim=True)  # index of the max probability
404 |             correct += pred.eq(target.view_as(pred)).sum().item()
405 |             probs_list.append(output.cpu().data.numpy())
406 |             label_list.append(target.cpu().data.numpy())
407 |
408 |     print('\nTest set: Accuracy: {:.2f}%\n'.format(100. * correct /
409 |                                                    len(val_loader.dataset)))
410 |     # save predictions over the full test set, not only the last batch
411 |     np.save('./probs_cifar_det.npy', np.concatenate(probs_list))
412 |     np.save('./cifar_test_labels.npy', np.concatenate(label_list))
412 |
413 |
414 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
415 | """
416 | Save the training model
417 | """
418 | torch.save(state, filename)
419 |
420 |
421 | class AverageMeter(object):
422 | """Computes and stores the average and current value"""
423 | def __init__(self):
424 | self.reset()
425 |
426 | def reset(self):
427 | self.val = 0
428 | self.avg = 0
429 | self.sum = 0
430 | self.count = 0
431 |
432 | def update(self, val, n=1):
433 | self.val = val
434 | self.sum += val * n
435 | self.count += n
436 | self.avg = self.sum / self.count
437 |
438 |
439 | def accuracy(output, target, topk=(1, )):
440 | """Computes the precision@k for the specified values of k"""
441 | maxk = max(topk)
442 | batch_size = target.size(0)
443 |
444 | _, pred = output.topk(maxk, 1, True, True)
445 | pred = pred.t()
446 | correct = pred.eq(target.view(1, -1).expand_as(pred))
447 |
448 | res = []
449 | for k in topk:
450 |         correct_k = correct[:k].reshape(-1).float().sum(0)
451 | res.append(correct_k.mul_(100.0 / batch_size))
452 | return res
453 |
454 |
455 | if __name__ == '__main__':
456 | main()
457 |
--------------------------------------------------------------------------------
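The `accuracy` helper in the script above reports top-k precision by comparing the k highest logits against the target. The same computation can be sketched in plain NumPy (names here are illustrative, not part of the repo):

```python
import numpy as np

def topk_accuracy(logits, target, topk=(1,)):
    """Percentage of samples whose true label appears among the top-k
    scores (NumPy sketch of the torch `accuracy` helper)."""
    maxk = max(topk)
    # indices of the top-maxk classes per sample, highest score first
    topk_idx = np.argsort(logits, axis=1)[:, ::-1][:, :maxk]
    correct = topk_idx == target[:, None]
    return [100.0 * correct[:, :k].any(axis=1).mean() for k in topk]

logits = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1],
                   [0.3, 0.3, 0.4]])
target = np.array([1, 0, 1])
top1, top2 = topk_accuracy(logits, target, topk=(1, 2))
print(top1, top2)  # 66.67 (2 of 3 correct at k=1), 100.0 at k=2
```

Note the torch version slices a transposed (non-contiguous) tensor, which is why `.reshape(-1)` is safer than `.view(-1)` there.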
/src/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2020 Intel Corporation
2 | #
3 | # BSD-3-Clause License
4 | #
5 | # Redistribution and use in source and binary forms, with or without modification,
6 | # are permitted provided that the following conditions are met:
7 | # 1. Redistributions of source code must retain the above copyright notice,
8 | # this list of conditions and the following disclaimer.
9 | # 2. Redistributions in binary form must reproduce the above copyright notice,
10 | # this list of conditions and the following disclaimer in the documentation
11 | # and/or other materials provided with the distribution.
12 | # 3. Neither the name of the copyright holder nor the names of its contributors
13 | # may be used to endorse or promote products derived from this software
14 | # without specific prior written permission.
15 | #
16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | #
28 | # Utility functions for variational inference in Bayesian deep neural networks
29 | # ELBO_loss -> to compute evidence lower bound loss
30 | # get_rho -> the standard deviation (sigma) is parameterized via the softplus function 'sigma = log(1 + exp(rho))'
31 | # to make sure it remains always positive and non-transformed 'rho' gets
32 | # updated during training.
33 | # MOPED -> set the priors and initialize approximate variational posteriors of Bayesian NN
34 | # with Empirical Bayes
35 | #
36 | # @authors: Ranganath Krishnan
37 | #
38 | # ===============================================================================================
39 |
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 | import torch.nn.functional as F
44 | import torch
45 | from torch import nn
46 | import numpy as np
47 | from sklearn.metrics import auc
48 |
49 |
50 | def ELBO_loss(out, y, kl_loss, num_data_samples, batch_size):
51 | nll_loss = F.cross_entropy(out, y)
52 | return nll_loss + ((1.0 / num_data_samples) * kl_loss)
53 |
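`ELBO_loss` scales the KL term by `1/num_data_samples` so that, summed over an epoch of minibatches, the KL divergence is counted roughly once — the same scaling as `scaled_kl = kl / len_trainset` in the training scripts. A toy check of the scaling, with made-up values:

```python
# ELBO-style objective: NLL plus KL divergence scaled by 1/N,
# where N is the number of training samples. Values are illustrative.
def elbo(nll, kl, num_data_samples):
    return nll + kl / num_data_samples

print(elbo(nll=0.25, kl=500.0, num_data_samples=50000))  # 0.25 + 0.01 = 0.26
```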
54 |
55 | def get_rho(sigma, delta):
56 | rho = torch.log(torch.expm1(delta * torch.abs(sigma)) + 1e-20)
57 | return rho
58 |
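`get_rho` inverts the softplus parameterization: it picks `rho` so that `softplus(rho) = log(1 + exp(rho))` equals the target standard deviation `delta * |sigma|`, which is how MOPED initializes posterior scales from deterministic weights. A NumPy round-trip check (the `_np` suffix marks this as a mirror of the torch function, not part of the repo):

```python
import numpy as np

def get_rho_np(sigma, delta):
    # inverse of softplus: rho such that log(1 + exp(rho)) == delta * |sigma|
    return np.log(np.expm1(delta * np.abs(sigma)) + 1e-20)

w = np.array([0.5, -1.0, 2.0])
rho = get_rho_np(w, delta=0.1)
recovered = np.log1p(np.exp(rho))  # softplus recovers 0.1 * |w|
print(recovered)
```

Training the untransformed `rho` keeps the effective standard deviation positive without any constraint on the parameter itself.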
59 |
60 | def entropy(prob):
61 | return -1 * np.sum(prob * np.log(prob + 1e-15), axis=-1)
62 |
63 |
64 | def predictive_entropy(mc_preds):
65 | """
66 | Compute the entropy of the mean of the predictive distribution
67 | obtained from Monte Carlo sampling during prediction phase.
68 | """
69 | return entropy(np.mean(mc_preds, axis=0))
70 |
71 |
72 | def mutual_information(mc_preds):
73 | """
74 | Compute the difference between the entropy of the mean of the
75 | predictive distribution and the mean of the entropy.
76 | """
77 | MI = entropy(np.mean(mc_preds, axis=0)) - np.mean(entropy(mc_preds),
78 | axis=0)
79 | return MI
80 |
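The two measures above decompose total predictive uncertainty: `predictive_entropy` is the entropy of the Monte Carlo-averaged softmax, and `mutual_information` subtracts the mean per-sample entropy (the aleatoric part), leaving the epistemic part. A small NumPy check with hypothetical MC softmax outputs:

```python
import numpy as np

def entropy(prob):
    return -1 * np.sum(prob * np.log(prob + 1e-15), axis=-1)

# Hypothetical MC softmax outputs for one input, shape [n_mc_samples, n_classes].
# The two runs disagree strongly, so the mutual-information (epistemic) term is large.
mc_preds = np.array([[0.9, 0.1],
                     [0.1, 0.9]])

pe = entropy(np.mean(mc_preds, axis=0))        # entropy of the averaged prediction (~ln 2)
mi = pe - np.mean(entropy(mc_preds), axis=0)   # mutual information: epistemic part
print(pe, mi)
```

If the MC samples all agreed, `pe` would match the mean per-sample entropy and `mi` would be near zero.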
81 |
82 | def MOPED(model, det_model, det_checkpoint, delta):
83 | """
84 | Set the priors and initialize surrogate posteriors of Bayesian NN with Empirical Bayes
85 | MOPED (Model Priors with Empirical Bayes using Deterministic DNN)
86 | Ref: https://arxiv.org/abs/1906.05323
87 | 'Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes'. AAAI 2020.
88 | """
89 | det_model.load_state_dict(torch.load(det_checkpoint))
90 | for (idx, layer), (det_idx,
91 | det_layer) in zip(enumerate(model.modules()),
92 | enumerate(det_model.modules())):
93 |         if str(layer) in ('Conv1dVariational()',
94 |                           'Conv2dVariational()',
95 |                           'Conv3dVariational()',
96 |                           'ConvTranspose1dVariational()',
97 |                           'ConvTranspose2dVariational()',
98 |                           'ConvTranspose3dVariational()'):
99 | #set the priors
100 | layer.prior_weight_mu.data = det_layer.weight
101 | layer.prior_bias_mu.data = det_layer.bias
102 |
103 | #initialize surrogate posteriors
104 | layer.mu_kernel.data = det_layer.weight
105 | layer.rho_kernel.data = get_rho(det_layer.weight.data, delta)
106 | layer.mu_bias.data = det_layer.bias
107 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
108 | elif (str(layer) == 'LinearVariational()'):
109 | #set the priors
110 | layer.prior_weight_mu.data = det_layer.weight
111 | layer.prior_bias_mu.data = det_layer.bias
112 |
113 | #initialize the surrogate posteriors
114 | layer.mu_weight.data = det_layer.weight
115 | layer.rho_weight.data = get_rho(det_layer.weight.data, delta)
116 | layer.mu_bias.data = det_layer.bias
117 | layer.rho_bias.data = get_rho(det_layer.bias.data, delta)
118 | elif str(layer).startswith('Batch'):
119 | #initialize parameters
120 | layer.weight.data = det_layer.weight
121 | layer.bias.data = det_layer.bias
122 | layer.running_mean.data = det_layer.running_mean
123 | layer.running_var.data = det_layer.running_var
124 | layer.num_batches_tracked.data = det_layer.num_batches_tracked
125 |
126 |
127 | return model
128 |
129 |
130 | def eval_avu(pred_label, true_label, uncertainty):
131 |     """Return AvU at a sweep of uncertainty thresholds, along with the thresholds."""
132 | t_list = np.linspace(0, 1, 21)
133 | umin = np.amin(uncertainty, axis=0)
134 | umax = np.amax(uncertainty, axis=0)
135 | avu_list = []
136 | unc_list = []
137 | for t in t_list:
138 | u_th = umin + (t * (umax - umin))
139 | n_ac = 0
140 | n_ic = 0
141 | n_au = 0
142 | n_iu = 0
143 | for i in range(len(true_label)):
144 | if ((true_label[i] == pred_label[i]) and uncertainty[i] <= u_th):
145 | n_ac += 1
146 | elif ((true_label[i] == pred_label[i]) and uncertainty[i] > u_th):
147 | n_au += 1
148 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] <= u_th):
149 | n_ic += 1
150 | elif ((true_label[i] != pred_label[i]) and uncertainty[i] > u_th):
151 | n_iu += 1
152 |
153 | AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-15)
154 | avu_list.append(AvU)
155 | unc_list.append(u_th)
156 | return np.asarray(avu_list), np.asarray(unc_list)
157 |
158 |
159 | def accuracy_vs_uncertainty(pred_label, true_label, uncertainty,
160 | optimal_threshold):
161 |
162 | n_ac = 0
163 | n_ic = 0
164 | n_au = 0
165 | n_iu = 0
166 | for i in range(len(true_label)):
167 | if ((true_label[i] == pred_label[i])
168 | and uncertainty[i] <= optimal_threshold):
169 | n_ac += 1
170 | elif ((true_label[i] == pred_label[i])
171 | and uncertainty[i] > optimal_threshold):
172 | n_au += 1
173 | elif ((true_label[i] != pred_label[i])
174 | and uncertainty[i] <= optimal_threshold):
175 | n_ic += 1
176 | elif ((true_label[i] != pred_label[i])
177 | and uncertainty[i] > optimal_threshold):
178 | n_iu += 1
179 |     AvU = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu + 1e-15)
180 |
181 | return AvU
182 |
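The counting logic shared by `eval_avu` and `accuracy_vs_uncertainty` can be sketched in vectorized form on toy data (a standalone NumPy illustration with made-up labels, not part of the repo):

```python
import numpy as np

pred_label  = np.array([0, 1, 2, 2])
true_label  = np.array([0, 1, 1, 0])
uncertainty = np.array([0.1, 0.2, 0.9, 0.15])
u_th = 0.5  # uncertainty threshold

accurate = pred_label == true_label
certain = uncertainty <= u_th
n_ac = np.sum(accurate & certain)    # accurate and certain
n_au = np.sum(accurate & ~certain)   # accurate but uncertain
n_ic = np.sum(~accurate & certain)   # inaccurate yet certain
n_iu = np.sum(~accurate & ~certain)  # inaccurate and uncertain

avu = (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)  # (2 + 1) / 4 = 0.75
```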
--------------------------------------------------------------------------------
/variational_layers/conv_variational.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2020 Intel Corporation
2 | #
3 | # BSD-3-Clause License
4 | #
5 | # Redistribution and use in source and binary forms, with or without modification,
6 | # are permitted provided that the following conditions are met:
7 | # 1. Redistributions of source code must retain the above copyright notice,
8 | # this list of conditions and the following disclaimer.
9 | # 2. Redistributions in binary form must reproduce the above copyright notice,
10 | # this list of conditions and the following disclaimer in the documentation
11 | # and/or other materials provided with the distribution.
12 | # 3. Neither the name of the copyright holder nor the names of its contributors
13 | # may be used to endorse or promote products derived from this software
14 | # without specific prior written permission.
15 | #
16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | #
28 | # Convolutional Variational Layers with reparameterization estimator to perform
29 | # mean-field variational inference in Bayesian neural networks. Variational layers
30 | # enable Monte Carlo approximation of the distribution over 'kernel' and 'bias'.
31 | #
32 | # Kullback-Leibler divergence between the surrogate posterior and prior is computed
33 | # and returned along with the tensors of outputs after convolution operation, which is
34 | # required to compute Evidence Lower Bound (ELBO) loss for variational inference.
35 | #
36 | # @authors: Ranganath Krishnan
37 | #
38 | # ======================================================================================
39 |
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 |
44 | import torch
45 | import torch.nn as nn
46 | import torch.nn.functional as F
47 | from torch.nn import Module, Parameter
48 | import math
49 |
50 | __all__ = [
51 | 'Conv1dVariational',
52 | 'Conv2dVariational',
53 | 'Conv3dVariational',
54 | 'ConvTranspose1dVariational',
55 | 'ConvTranspose2dVariational',
56 | 'ConvTranspose3dVariational',
57 | ]
58 |
59 |
60 | class Conv1dVariational(Module):
61 | def __init__(self,
62 | prior_mean,
63 | prior_variance,
64 | posterior_mu_init,
65 | posterior_rho_init,
66 | in_channels,
67 | out_channels,
68 | kernel_size,
69 | stride=1,
70 | padding=0,
71 | dilation=1,
72 | groups=1,
73 | bias=True):
74 |
75 | super(Conv1dVariational, self).__init__()
76 | if in_channels % groups != 0:
77 | raise ValueError('invalid in_channels size')
78 | if out_channels % groups != 0:
79 |             raise ValueError('invalid out_channels size')
80 |
81 | self.in_channels = in_channels
82 | self.out_channels = out_channels
83 | self.kernel_size = kernel_size
84 | self.stride = stride
85 | self.padding = padding
86 | self.dilation = dilation
87 | self.groups = groups
88 | self.prior_mean = prior_mean
89 | self.prior_variance = prior_variance
90 |         self.posterior_mu_init = (posterior_mu_init,)  # mean of weight
91 |         self.posterior_rho_init = (posterior_rho_init,)  # std of weight via sigma = log(1 + exp(rho))
92 | self.bias = bias
93 |
94 | self.mu_kernel = Parameter(
95 | torch.Tensor(out_channels, in_channels // groups, kernel_size))
96 | self.rho_kernel = Parameter(
97 | torch.Tensor(out_channels, in_channels // groups, kernel_size))
98 | self.register_buffer(
99 | 'eps_kernel',
100 | torch.Tensor(out_channels, in_channels // groups, kernel_size))
101 | self.register_buffer(
102 | 'prior_weight_mu',
103 | torch.Tensor(out_channels, in_channels // groups, kernel_size))
104 | self.register_buffer(
105 | 'prior_weight_sigma',
106 | torch.Tensor(out_channels, in_channels // groups, kernel_size))
107 |
108 | if self.bias:
109 | self.mu_bias = Parameter(torch.Tensor(out_channels))
110 | self.rho_bias = Parameter(torch.Tensor(out_channels))
111 | self.register_buffer('eps_bias', torch.Tensor(out_channels))
112 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels))
113 | self.register_buffer('prior_bias_sigma',
114 | torch.Tensor(out_channels))
115 | else:
116 | self.register_parameter('mu_bias', None)
117 | self.register_parameter('rho_bias', None)
118 | self.register_buffer('eps_bias', None)
119 | self.register_buffer('prior_bias_mu', None)
120 | self.register_buffer('prior_bias_sigma', None)
121 |
122 | self.init_parameters()
123 |
124 | def init_parameters(self):
125 | self.prior_weight_mu.data.fill_(self.prior_mean)
126 | self.prior_weight_sigma.fill_(self.prior_variance)
127 |
128 | self.mu_kernel.data.normal_(std=0.1)
129 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
130 | if self.bias:
131 | self.prior_bias_mu.data.fill_(self.prior_mean)
132 | self.prior_bias_sigma.fill_(self.prior_variance)
133 |
134 | self.mu_bias.data.normal_(std=0.1)
135 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
136 | std=0.1)
137 |
138 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
139 | kl = torch.log(sigma_p + 1e-15) - torch.log(
140 | sigma_q + 1e-15) + (sigma_q**2 +
141 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
142 | return kl.sum()
143 |
144 | def forward(self, input):
145 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
146 | eps_kernel = self.eps_kernel.normal_()
147 | weight = self.mu_kernel + (sigma_weight * eps_kernel)
148 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
149 | self.prior_weight_mu, self.prior_weight_sigma)
150 | bias = None
151 |
152 | if self.bias:
153 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
154 | eps_bias = self.eps_bias.normal_()
155 | bias = self.mu_bias + (sigma_bias * eps_bias)
156 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
157 | self.prior_bias_sigma)
158 |
159 | out = F.conv1d(input, weight, bias, self.stride, self.padding,
160 | self.dilation, self.groups)
161 | if self.bias:
162 | kl = kl_weight + kl_bias
163 | else:
164 | kl = kl_weight
165 |
166 | return out, kl
167 |
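The `forward` above samples weights with the reparameterization trick. A scalar sketch of that step (standalone, with made-up values):

```python
import math
import random

random.seed(0)  # for a repeatable draw

mu, rho = 0.5, -3.0                 # variational posterior parameters
sigma = math.log1p(math.exp(rho))   # softplus keeps sigma positive
eps = random.gauss(0.0, 1.0)        # eps ~ N(0, 1), resampled every call
weight_sample = mu + sigma * eps    # w ~ N(mu, sigma^2)
```

Because the noise enters additively, gradients flow through `mu` and `rho` while the randomness stays in `eps`, which is why `eps_kernel` is registered as a buffer rather than a `Parameter`.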
168 |
169 | class Conv2dVariational(Module):
170 | def __init__(self,
171 | prior_mean,
172 | prior_variance,
173 | posterior_mu_init,
174 | posterior_rho_init,
175 | in_channels,
176 | out_channels,
177 | kernel_size,
178 | stride=1,
179 | padding=0,
180 | dilation=1,
181 | groups=1,
182 | bias=True):
183 |
184 | super(Conv2dVariational, self).__init__()
185 | if in_channels % groups != 0:
186 | raise ValueError('invalid in_channels size')
187 | if out_channels % groups != 0:
188 |             raise ValueError('invalid out_channels size')
189 |
190 | self.in_channels = in_channels
191 | self.out_channels = out_channels
192 | self.kernel_size = kernel_size
193 | self.stride = stride
194 | self.padding = padding
195 | self.dilation = dilation
196 | self.groups = groups
197 | self.prior_mean = prior_mean
198 | self.prior_variance = prior_variance
199 |         self.posterior_mu_init = (posterior_mu_init,)  # mean of weight
200 |         self.posterior_rho_init = (posterior_rho_init,)  # std of weight via sigma = log(1 + exp(rho))
201 | self.bias = bias
202 |
203 | self.mu_kernel = Parameter(
204 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
205 | kernel_size))
206 | self.rho_kernel = Parameter(
207 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
208 | kernel_size))
209 | self.register_buffer(
210 | 'eps_kernel',
211 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
212 | kernel_size))
213 | self.register_buffer(
214 | 'prior_weight_mu',
215 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
216 | kernel_size))
217 | self.register_buffer(
218 | 'prior_weight_sigma',
219 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
220 | kernel_size))
221 |
222 | if self.bias:
223 | self.mu_bias = Parameter(torch.Tensor(out_channels))
224 | self.rho_bias = Parameter(torch.Tensor(out_channels))
225 | self.register_buffer('eps_bias', torch.Tensor(out_channels))
226 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels))
227 | self.register_buffer('prior_bias_sigma',
228 | torch.Tensor(out_channels))
229 | else:
230 | self.register_parameter('mu_bias', None)
231 | self.register_parameter('rho_bias', None)
232 | self.register_buffer('eps_bias', None)
233 | self.register_buffer('prior_bias_mu', None)
234 | self.register_buffer('prior_bias_sigma', None)
235 |
236 | self.init_parameters()
237 |
238 | def init_parameters(self):
239 | self.prior_weight_mu.fill_(self.prior_mean)
240 | self.prior_weight_sigma.fill_(self.prior_variance)
241 |
242 | self.mu_kernel.data.normal_(std=0.1)
243 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
244 | if self.bias:
245 | self.prior_bias_mu.fill_(self.prior_mean)
246 | self.prior_bias_sigma.fill_(self.prior_variance)
247 |
248 | self.mu_bias.data.normal_(std=0.1)
249 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
250 | std=0.1)
251 |
252 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
253 | kl = torch.log(sigma_p + 1e-15) - torch.log(
254 | sigma_q + 1e-15) + (sigma_q**2 +
255 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
256 | return kl.sum()
257 |
258 | def forward(self, input):
259 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
260 | eps_kernel = self.eps_kernel.normal_()
261 | weight = self.mu_kernel + (sigma_weight * eps_kernel)
262 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
263 | self.prior_weight_mu, self.prior_weight_sigma)
264 | bias = None
265 |
266 | if self.bias:
267 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
268 | eps_bias = self.eps_bias.normal_()
269 | bias = self.mu_bias + (sigma_bias * eps_bias)
270 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
271 | self.prior_bias_sigma)
272 |
273 | out = F.conv2d(input, weight, bias, self.stride, self.padding,
274 | self.dilation, self.groups)
275 | if self.bias:
276 | kl = kl_weight + kl_bias
277 | else:
278 | kl = kl_weight
279 |
280 | return out, kl
281 |
282 |
283 | class Conv3dVariational(Module):
284 | def __init__(self,
285 | prior_mean,
286 | prior_variance,
287 | posterior_mu_init,
288 | posterior_rho_init,
289 | in_channels,
290 | out_channels,
291 | kernel_size,
292 | stride=1,
293 | padding=0,
294 | dilation=1,
295 | groups=1,
296 | bias=True):
297 |
298 | super(Conv3dVariational, self).__init__()
299 | if in_channels % groups != 0:
300 | raise ValueError('invalid in_channels size')
301 | if out_channels % groups != 0:
302 |             raise ValueError('invalid out_channels size')
303 |
304 | self.in_channels = in_channels
305 | self.out_channels = out_channels
306 | self.kernel_size = kernel_size
307 | self.stride = stride
308 | self.padding = padding
309 | self.dilation = dilation
310 | self.groups = groups
311 | self.prior_mean = prior_mean
312 | self.prior_variance = prior_variance
313 |         self.posterior_mu_init = (posterior_mu_init,)  # mean of weight
314 |         self.posterior_rho_init = (posterior_rho_init,)  # std of weight via sigma = log(1 + exp(rho))
315 | self.bias = bias
316 |
317 | self.mu_kernel = Parameter(
318 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
319 | kernel_size, kernel_size))
320 | self.rho_kernel = Parameter(
321 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
322 | kernel_size, kernel_size))
323 | self.register_buffer(
324 | 'eps_kernel',
325 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
326 | kernel_size, kernel_size))
327 | self.register_buffer(
328 | 'prior_weight_mu',
329 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
330 | kernel_size, kernel_size))
331 | self.register_buffer(
332 | 'prior_weight_sigma',
333 | torch.Tensor(out_channels, in_channels // groups, kernel_size,
334 | kernel_size, kernel_size))
335 |
336 | if self.bias:
337 | self.mu_bias = Parameter(torch.Tensor(out_channels))
338 | self.rho_bias = Parameter(torch.Tensor(out_channels))
339 | self.register_buffer('eps_bias', torch.Tensor(out_channels))
340 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels))
341 | self.register_buffer('prior_bias_sigma',
342 | torch.Tensor(out_channels))
343 | else:
344 | self.register_parameter('mu_bias', None)
345 | self.register_parameter('rho_bias', None)
346 | self.register_buffer('eps_bias', None)
347 | self.register_buffer('prior_bias_mu', None)
348 | self.register_buffer('prior_bias_sigma', None)
349 |
350 | self.init_parameters()
351 |
352 | def init_parameters(self):
353 | self.prior_weight_mu.fill_(self.prior_mean)
354 | self.prior_weight_sigma.fill_(self.prior_variance)
355 |
356 | self.mu_kernel.data.normal_(std=0.1)
357 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
358 | if self.bias:
359 | self.prior_bias_mu.fill_(self.prior_mean)
360 | self.prior_bias_sigma.fill_(self.prior_variance)
361 |
362 | self.mu_bias.data.normal_(std=0.1)
363 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
364 | std=0.1)
365 |
366 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
367 | kl = torch.log(sigma_p + 1e-15) - torch.log(
368 | sigma_q + 1e-15) + (sigma_q**2 +
369 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
370 | return kl.sum()
371 |
372 | def forward(self, input):
373 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
374 | eps_kernel = self.eps_kernel.normal_()
375 | weight = self.mu_kernel + (sigma_weight * eps_kernel)
376 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
377 | self.prior_weight_mu, self.prior_weight_sigma)
378 | bias = None
379 |
380 | if self.bias:
381 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
382 | eps_bias = self.eps_bias.normal_()
383 | bias = self.mu_bias + (sigma_bias * eps_bias)
384 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
385 | self.prior_bias_sigma)
386 |
387 | out = F.conv3d(input, weight, bias, self.stride, self.padding,
388 | self.dilation, self.groups)
389 | if self.bias:
390 | kl = kl_weight + kl_bias
391 | else:
392 | kl = kl_weight
393 |
394 | return out, kl
395 |
396 |
397 | class ConvTranspose1dVariational(Module):
398 | def __init__(self,
399 | prior_mean,
400 | prior_variance,
401 | posterior_mu_init,
402 | posterior_rho_init,
403 | in_channels,
404 | out_channels,
405 | kernel_size,
406 | stride=1,
407 | padding=0,
408 | output_padding=0,
409 | dilation=1,
410 | groups=1,
411 | bias=True):
412 |
413 | super(ConvTranspose1dVariational, self).__init__()
414 | if in_channels % groups != 0:
415 | raise ValueError('invalid in_channels size')
416 | if out_channels % groups != 0:
417 |             raise ValueError('invalid out_channels size')
418 |
419 | self.in_channels = in_channels
420 | self.out_channels = out_channels
421 | self.kernel_size = kernel_size
422 | self.stride = stride
423 | self.padding = padding
424 | self.output_padding = output_padding
425 | self.dilation = dilation
426 | self.groups = groups
427 | self.prior_mean = prior_mean
428 | self.prior_variance = prior_variance
429 |         self.posterior_mu_init = (posterior_mu_init,)  # mean of weight
430 |         self.posterior_rho_init = (posterior_rho_init,)  # std of weight via sigma = log(1 + exp(rho))
431 | self.bias = bias
432 |
433 | self.mu_kernel = Parameter(
434 | torch.Tensor(in_channels, out_channels // groups, kernel_size))
435 | self.rho_kernel = Parameter(
436 | torch.Tensor(in_channels, out_channels // groups, kernel_size))
437 | self.register_buffer(
438 | 'eps_kernel',
439 | torch.Tensor(in_channels, out_channels // groups, kernel_size))
440 | self.register_buffer(
441 | 'prior_weight_mu',
442 | torch.Tensor(in_channels, out_channels // groups, kernel_size))
443 | self.register_buffer(
444 | 'prior_weight_sigma',
445 | torch.Tensor(in_channels, out_channels // groups, kernel_size))
446 |
447 | if self.bias:
448 | self.mu_bias = Parameter(torch.Tensor(out_channels))
449 | self.rho_bias = Parameter(torch.Tensor(out_channels))
450 | self.register_buffer('eps_bias', torch.Tensor(out_channels))
451 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels))
452 | self.register_buffer('prior_bias_sigma',
453 | torch.Tensor(out_channels))
454 | else:
455 | self.register_parameter('mu_bias', None)
456 | self.register_parameter('rho_bias', None)
457 | self.register_buffer('eps_bias', None)
458 | self.register_buffer('prior_bias_mu', None)
459 | self.register_buffer('prior_bias_sigma', None)
460 |
461 | self.init_parameters()
462 |
463 | def init_parameters(self):
464 | self.prior_weight_mu.fill_(self.prior_mean)
465 | self.prior_weight_sigma.fill_(self.prior_variance)
466 |
467 | self.mu_kernel.data.normal_(std=0.1)
468 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
469 | if self.bias:
470 | self.prior_bias_mu.fill_(self.prior_mean)
471 | self.prior_bias_sigma.fill_(self.prior_variance)
472 |
473 | self.mu_bias.data.normal_(std=0.1)
474 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
475 | std=0.1)
476 |
477 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
478 | kl = torch.log(sigma_p + 1e-15) - torch.log(
479 | sigma_q + 1e-15) + (sigma_q**2 +
480 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
481 | return kl.sum()
482 |
483 | def forward(self, input):
484 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
485 | eps_kernel = self.eps_kernel.normal_()
486 | weight = self.mu_kernel + (sigma_weight * eps_kernel)
487 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
488 | self.prior_weight_mu, self.prior_weight_sigma)
489 | bias = None
490 |
491 | if self.bias:
492 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
493 | eps_bias = self.eps_bias.normal_()
494 | bias = self.mu_bias + (sigma_bias * eps_bias)
495 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
496 | self.prior_bias_sigma)
497 |
498 | out = F.conv_transpose1d(input, weight, bias, self.stride,
499 | self.padding, self.output_padding,
500 | self.dilation, self.groups)
501 | if self.bias:
502 | kl = kl_weight + kl_bias
503 | else:
504 | kl = kl_weight
505 |
506 | return out, kl
507 |
508 |
509 | class ConvTranspose2dVariational(Module):
510 | def __init__(self,
511 | prior_mean,
512 | prior_variance,
513 | posterior_mu_init,
514 | posterior_rho_init,
515 | in_channels,
516 | out_channels,
517 | kernel_size,
518 | stride=1,
519 | padding=0,
520 | output_padding=0,
521 | dilation=1,
522 | groups=1,
523 | bias=True):
524 |
525 | super(ConvTranspose2dVariational, self).__init__()
526 | if in_channels % groups != 0:
527 | raise ValueError('invalid in_channels size')
528 | if out_channels % groups != 0:
529 |             raise ValueError('invalid out_channels size')
530 |
531 | self.in_channels = in_channels
532 | self.out_channels = out_channels
533 | self.kernel_size = kernel_size
534 | self.stride = stride
535 | self.padding = padding
536 | self.output_padding = output_padding
537 | self.dilation = dilation
538 | self.groups = groups
539 | self.prior_mean = prior_mean
540 | self.prior_variance = prior_variance
541 |         self.posterior_mu_init = (posterior_mu_init,)  # mean of weight
542 |         self.posterior_rho_init = (posterior_rho_init,)  # std of weight via sigma = log(1 + exp(rho))
543 | self.bias = bias
544 |
545 | self.mu_kernel = Parameter(
546 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
547 | kernel_size))
548 | self.rho_kernel = Parameter(
549 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
550 | kernel_size))
551 | self.register_buffer(
552 | 'eps_kernel',
553 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
554 | kernel_size))
555 | self.register_buffer(
556 | 'prior_weight_mu',
557 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
558 | kernel_size))
559 | self.register_buffer(
560 | 'prior_weight_sigma',
561 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
562 | kernel_size))
563 |
564 | if self.bias:
565 | self.mu_bias = Parameter(torch.Tensor(out_channels))
566 | self.rho_bias = Parameter(torch.Tensor(out_channels))
567 | self.register_buffer('eps_bias', torch.Tensor(out_channels))
568 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels))
569 | self.register_buffer('prior_bias_sigma',
570 | torch.Tensor(out_channels))
571 | else:
572 | self.register_parameter('mu_bias', None)
573 | self.register_parameter('rho_bias', None)
574 | self.register_buffer('eps_bias', None)
575 | self.register_buffer('prior_bias_mu', None)
576 | self.register_buffer('prior_bias_sigma', None)
577 |
578 | self.init_parameters()
579 |
580 | def init_parameters(self):
581 | self.prior_weight_mu.fill_(self.prior_mean)
582 | self.prior_weight_sigma.fill_(self.prior_variance)
583 |
584 | self.mu_kernel.data.normal_(std=0.1)
585 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
586 | if self.bias:
587 | self.prior_bias_mu.fill_(self.prior_mean)
588 | self.prior_bias_sigma.fill_(self.prior_variance)
589 |
590 | self.mu_bias.data.normal_(std=0.1)
591 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
592 | std=0.1)
593 |
594 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
595 | kl = torch.log(sigma_p + 1e-15) - torch.log(
596 | sigma_q + 1e-15) + (sigma_q**2 +
597 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
598 | return kl.sum()
599 |
600 | def forward(self, input):
601 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
602 | eps_kernel = self.eps_kernel.normal_()
603 | weight = self.mu_kernel + (sigma_weight * eps_kernel)
604 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
605 | self.prior_weight_mu, self.prior_weight_sigma)
606 | bias = None
607 |
608 | if self.bias:
609 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
610 | eps_bias = self.eps_bias.normal_()
611 | bias = self.mu_bias + (sigma_bias * eps_bias)
612 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
613 | self.prior_bias_sigma)
614 |
615 | out = F.conv_transpose2d(input, weight, bias, self.stride,
616 | self.padding, self.output_padding,
617 | self.dilation, self.groups)
618 | if self.bias:
619 | kl = kl_weight + kl_bias
620 | else:
621 | kl = kl_weight
622 |
623 | return out, kl
624 |
625 |
626 | class ConvTranspose3dVariational(Module):
627 | def __init__(self,
628 | prior_mean,
629 | prior_variance,
630 | posterior_mu_init,
631 | posterior_rho_init,
632 | in_channels,
633 | out_channels,
634 | kernel_size,
635 | stride=1,
636 | padding=0,
637 | output_padding=0,
638 | dilation=1,
639 | groups=1,
640 | bias=True):
641 |
642 | super(ConvTranspose3dVariational, self).__init__()
643 | if in_channels % groups != 0:
644 | raise ValueError('invalid in_channels size')
645 | if out_channels % groups != 0:
646 |             raise ValueError('invalid out_channels size')
647 |
648 | self.in_channels = in_channels
649 | self.out_channels = out_channels
650 | self.kernel_size = kernel_size
651 | self.stride = stride
652 | self.padding = padding
653 | self.output_padding = output_padding
654 | self.dilation = dilation
655 | self.groups = groups
656 | self.prior_mean = prior_mean
657 | self.prior_variance = prior_variance
658 |         self.posterior_mu_init = (posterior_mu_init,)  # mean of weight
659 |         self.posterior_rho_init = (posterior_rho_init,)  # std of weight via sigma = log(1 + exp(rho))
660 | self.bias = bias
661 |
662 | self.mu_kernel = Parameter(
663 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
664 | kernel_size, kernel_size))
665 | self.rho_kernel = Parameter(
666 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
667 | kernel_size, kernel_size))
668 | self.register_buffer(
669 | 'eps_kernel',
670 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
671 | kernel_size, kernel_size))
672 | self.register_buffer(
673 | 'prior_weight_mu',
674 | torch.Tensor(in_channels, out_channels // groups, kernel_size,
675 | kernel_size, kernel_size))
676 | self.register_buffer(
677 | 'prior_weight_sigma',
678 |             torch.Tensor(in_channels, out_channels // groups, kernel_size,
679 | kernel_size, kernel_size))
680 |
681 | if self.bias:
682 | self.mu_bias = Parameter(torch.Tensor(out_channels))
683 | self.rho_bias = Parameter(torch.Tensor(out_channels))
684 | self.register_buffer('eps_bias', torch.Tensor(out_channels))
685 | self.register_buffer('prior_bias_mu', torch.Tensor(out_channels))
686 | self.register_buffer('prior_bias_sigma',
687 | torch.Tensor(out_channels))
688 | else:
689 | self.register_parameter('mu_bias', None)
690 | self.register_parameter('rho_bias', None)
691 | self.register_buffer('eps_bias', None)
692 | self.register_buffer('prior_bias_mu', None)
693 | self.register_buffer('prior_bias_sigma', None)
694 |
695 | self.init_parameters()
696 |
697 | def init_parameters(self):
698 | self.prior_weight_mu.fill_(self.prior_mean)
699 | self.prior_weight_sigma.fill_(self.prior_variance)
700 |
701 | self.mu_kernel.data.normal_(std=0.1)
702 | self.rho_kernel.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
703 | if self.bias:
704 | self.prior_bias_mu.fill_(self.prior_mean)
705 | self.prior_bias_sigma.fill_(self.prior_variance)
706 |
707 | self.mu_bias.data.normal_(std=0.1)
708 | self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
709 | std=0.1)
710 |
711 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
712 | kl = torch.log(sigma_p + 1e-15) - torch.log(
713 | sigma_q + 1e-15) + (sigma_q**2 +
714 | (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
715 | return kl.sum()
716 |
717 | def forward(self, input):
718 | sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
719 | eps_kernel = self.eps_kernel.normal_()
720 | weight = self.mu_kernel + (sigma_weight * eps_kernel)
721 | kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
722 | self.prior_weight_mu, self.prior_weight_sigma)
723 | bias = None
724 |
725 | if self.bias:
726 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
727 | eps_bias = self.eps_bias.normal_()
728 | bias = self.mu_bias + (sigma_bias * eps_bias)
729 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
730 | self.prior_bias_sigma)
731 |
732 | out = F.conv_transpose3d(input, weight, bias, self.stride,
733 | self.padding, self.output_padding,
734 | self.dilation, self.groups)
735 | if self.bias:
736 | kl = kl_weight + kl_bias
737 | else:
738 | kl = kl_weight
739 |
740 | return out, kl
741 |
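The `kl_div` method repeated in each layer above implements the closed-form KL divergence between two univariate Gaussians (the repo version adds a 1e-15 inside the logs for numerical safety and sums over all weights). A scalar version with two easy checks, as a standalone sketch:

```python
import math

def kl_gauss(mu_q, sigma_q, mu_p, sigma_p):
    # KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), per element
    return (math.log(sigma_p) - math.log(sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)

zero = kl_gauss(0.0, 1.0, 0.0, 1.0)  # identical distributions -> 0
half = kl_gauss(1.0, 1.0, 0.0, 1.0)  # unit mean shift -> 0.5
```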
--------------------------------------------------------------------------------
/variational_layers/linear_variational.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2020 Intel Corporation
2 | #
3 | # BSD-3-Clause License
4 | #
5 | # Redistribution and use in source and binary forms, with or without modification,
6 | # are permitted provided that the following conditions are met:
7 | # 1. Redistributions of source code must retain the above copyright notice,
8 | # this list of conditions and the following disclaimer.
9 | # 2. Redistributions in binary form must reproduce the above copyright notice,
10 | # this list of conditions and the following disclaimer in the documentation
11 | # and/or other materials provided with the distribution.
12 | # 3. Neither the name of the copyright holder nor the names of its contributors
13 | # may be used to endorse or promote products derived from this software
14 | # without specific prior written permission.
15 | #
16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
20 | # BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
21 | # OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
22 | # OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
25 | # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
26 | # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | #
28 | # Linear variational layers with a reparameterization estimator to perform
29 | # mean-field variational inference in Bayesian neural networks. Variational layers
30 | # enable Monte Carlo approximation of the distribution over 'kernel' and 'bias'.
31 | #
32 | # The Kullback-Leibler divergence between the surrogate posterior and the prior is
33 | # computed and returned along with the output tensor of the linear operation, as
34 | # required to compute the Evidence Lower Bound (ELBO) loss for variational inference.
35 | #
36 | # @author: Ranganath Krishnan
37 | #
38 | # ======================================================================================
39 |
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 |
44 | import torch
45 | import torch.nn as nn
46 | import torch.nn.functional as F
47 | from torch.nn import Module, Parameter
48 | import math
49 |
50 |
51 | class LinearVariational(Module):
52 | def __init__(self,
53 | prior_mean,
54 | prior_variance,
55 | posterior_mu_init,
56 | posterior_rho_init,
57 | in_features,
58 | out_features,
59 | bias=True):
60 |
61 | super(LinearVariational, self).__init__()
62 |
63 | self.in_features = in_features
64 | self.out_features = out_features
65 | self.prior_mean = prior_mean
66 | self.prior_variance = prior_variance
67 |         self.posterior_mu_init = posterior_mu_init  # init mean of weight posterior
68 |         self.posterior_rho_init = posterior_rho_init  # init rho of posterior; std dev sigma = log(1 + exp(rho))
69 |         self.bias = bias
70 |
71 |         self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
72 |         self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
73 |         self.register_buffer('eps_weight',
74 |                              torch.Tensor(out_features, in_features))
75 |         self.register_buffer('prior_weight_mu',
76 |                              torch.Tensor(out_features, in_features))
77 |         if bias:
78 |             self.mu_bias = Parameter(torch.Tensor(out_features))
79 |             self.rho_bias = Parameter(torch.Tensor(out_features))
80 |             self.register_buffer('eps_bias', torch.Tensor(out_features))
81 |             self.register_buffer('prior_bias_mu', torch.Tensor(out_features))
82 |         else:
83 |             self.register_buffer('prior_bias_mu', None)
84 |             self.register_parameter('mu_bias', None)
85 |             self.register_parameter('rho_bias', None)
86 |             self.register_buffer('eps_bias', None)
87 |
88 |         self.init_parameters()
89 |
90 |     def init_parameters(self):
91 |         self.prior_weight_mu.fill_(self.prior_mean)
92 |
93 |         self.mu_weight.data.normal_(mean=self.posterior_mu_init, std=0.1)
94 |         self.rho_weight.data.normal_(mean=self.posterior_rho_init, std=0.1)
95 |         if self.mu_bias is not None:
96 |             self.prior_bias_mu.fill_(self.prior_mean)
97 |             self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
98 |             self.rho_bias.data.normal_(mean=self.posterior_rho_init,
99 |                                        std=0.1)
100 |
101 |     def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
102 |         # Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)), summed over elements
103 |         sigma_p = torch.as_tensor(sigma_p, dtype=mu_q.dtype, device=mu_q.device)
104 |         kl = (torch.log(sigma_p) - torch.log(sigma_q) +
105 |               (sigma_q**2 + (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5)
106 |         return kl.sum()
107 |
108 |     def forward(self, input):
109 |         sigma_weight = torch.log1p(torch.exp(self.rho_weight))  # softplus keeps sigma > 0
110 |         weight = self.mu_weight + (sigma_weight * self.eps_weight.normal_())  # reparameterization trick
111 |         kl_weight = self.kl_div(self.mu_weight, sigma_weight,
112 |                                 self.prior_weight_mu, self.prior_variance)
113 |         bias = None
113 | bias = None
114 |
115 | if self.mu_bias is not None:
116 | sigma_bias = torch.log1p(torch.exp(self.rho_bias))
117 | bias = self.mu_bias + (sigma_bias * self.eps_bias.normal_())
118 | kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
119 | self.prior_variance)
120 |
121 | out = F.linear(input, weight, bias)
122 | if self.mu_bias is not None:
123 | kl = kl_weight + kl_bias
124 | else:
125 | kl = kl_weight
126 |
127 | return out, kl
128 |
--------------------------------------------------------------------------------
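The closed-form Gaussian KL used in `LinearVariational.kl_div` above can be cross-checked against PyTorch's own `torch.distributions.kl_divergence`. A minimal sketch (the standalone `kl_div` function here mirrors the method's formula; shapes and the unit-variance prior are illustrative assumptions):

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_div(mu_q, sigma_q, mu_p, sigma_p):
    # Same closed form as LinearVariational.kl_div:
    # KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)), summed over elements
    sigma_p = torch.as_tensor(sigma_p, dtype=mu_q.dtype)
    kl = (torch.log(sigma_p) - torch.log(sigma_q)
          + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5)
    return kl.sum()

torch.manual_seed(0)
mu_q = torch.randn(4, 3)
sigma_q = torch.rand(4, 3) + 0.1   # keep std devs strictly positive
mu_p = torch.zeros(4, 3)

kl_manual = kl_div(mu_q, sigma_q, mu_p, 1.0)
kl_ref = kl_divergence(Normal(mu_q, sigma_q),
                       Normal(mu_p, torch.ones(4, 3))).sum()
assert torch.allclose(kl_manual, kl_ref, atol=1e-5)
```

Agreement with the library implementation confirms the hand-derived formula term by term, including the `-0.5` constant.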