├── OE_results.png
├── detailed_results.png
├── setup.sh
├── README.md
├── densenet.py
├── calculate_log.py
├── ResNet_SVHN.ipynb
├── ResNet_Cifar10.ipynb
├── ResNet_Cifar100.ipynb
├── DenseNet_SVHN.ipynb
├── DenseNet_Cifar100.ipynb
└── DenseNet_Cifar10.ipynb
/OE_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorInstitute/gram-ood-detection/HEAD/OE_results.png
--------------------------------------------------------------------------------
/detailed_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VectorInstitute/gram-ood-detection/HEAD/detailed_results.png
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | wget https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz
2 | tar -xzf Imagenet.tar.gz
3 |
4 | wget https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz
5 | tar -xzf Imagenet_resize.tar.gz
6 |
7 | wget https://www.dropbox.com/s/fhtsw1m3qxlwj6h/LSUN.tar.gz
8 | tar -xzf LSUN.tar.gz
9 |
10 | wget https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz
11 | tar -xzf LSUN_resize.tar.gz
12 |
13 | wget https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz
14 | tar -xzf iSUN.tar.gz
15 |
16 | rm *.gz
17 |
18 | wget https://www.dropbox.com/s/pnbvr16gnpyr1zg/densenet_cifar10.pth
19 | wget https://www.dropbox.com/s/7ur9qo81u30od36/densenet_cifar100.pth
20 | wget https://www.dropbox.com/s/9ol1h2tb3xjdpp1/densenet_svhn.pth
21 | wget https://www.dropbox.com/s/ynidbn7n7ccadog/resnet_cifar10.pth
22 | wget https://www.dropbox.com/s/yzfzf4bwqe4du6w/resnet_cifar100.pth
23 | wget https://www.dropbox.com/s/uvgpgy9pu7s9ps2/resnet_svhn.pth
24 |
25 |
26 | wget https://raw.githubusercontent.com/hendrycks/outlier-exposure/master/CIFAR/snapshots/oe_scratch/cifar100_wrn_oe_scratch_epoch_99.pt
27 | wget https://raw.githubusercontent.com/hendrycks/outlier-exposure/master/CIFAR/snapshots/oe_scratch/cifar10_wrn_oe_scratch_epoch_99.pt
--------------------------------------------------------------------------------
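The script above only fetches and unpacks files; as a quick, hedged sanity check (assuming it was run from the repository root), one might confirm the expected folders and checkpoints exist before opening the notebooks:

```python
# Hypothetical helper: verify that setup.sh left the OOD folders and checkpoints
# listed above in the current working directory.
import os

expected = [
    "Imagenet", "Imagenet_resize", "LSUN", "LSUN_resize", "iSUN",
    "densenet_cifar10.pth", "densenet_cifar100.pth", "densenet_svhn.pth",
    "resnet_cifar10.pth", "resnet_cifar100.pth", "resnet_svhn.pth",
]
missing = [name for name in expected if not os.path.exists(name)]
print("missing:", missing if missing else "none")
```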
/README.md:
--------------------------------------------------------------------------------
1 | # Detecting Out-of-Distribution Examples with In-distribution Examples and Gram Matrices
2 | ICML 2020: the paper, supplementary material, and BibTeX entry are available at http://proceedings.mlr.press/v119/sastry20a.html
3 |
4 | ## Dependencies
5 | The code is written in Python 3 with PyTorch 1.1.
6 |
7 | ## Results
8 | ![Detailed results](detailed_results.png)
9 |
10 | (Please refer to this [repository](https://github.com/chandramouli-sastry/deep_Mahalanobis_detector) for the results of Baseline/ODIN/Mahalanobis on dataset-pairs not presented in the Mahalanobis paper)
11 |
12 | ## Combining Outlier Exposure (OE) and Ours
13 | ![OE results](OE_results.png)
14 |
15 | ## Downloading Out-of-Distribution Datasets and Pre-trained Models
16 | We used the out-of-distribution datasets presented in [odin-pytorch](https://github.com/facebookresearch/odin).
17 |
18 | We used pre-trained neural networks open-sourced by [Mahalanobis](https://github.com/pokaxpoka/deep_Mahalanobis_detector/) and [odin-pytorch](https://github.com/ShiyuLiang/odin-pytorch). The DenseNets trained on CIFAR-10 and CIFAR-100 are by ODIN; the remaining networks are by Mahalanobis.
19 |
20 | For experiments on OE-trained networks, we used the pre-trained networks open-sourced by [OE](https://github.com/hendrycks/outlier-exposure).
21 |
22 | Running setup.sh downloads the out-of-distribution datasets and pre-trained models.
23 |
--------------------------------------------------------------------------------
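As a small illustration of how the downloaded data is consumed (this mirrors the notebook code later in this dump), the unpacked OOD folders can be read with torchvision's `ImageFolder`; the normalization constants below are the CIFAR statistics used throughout the notebooks:

```python
# Minimal sketch (mirrors the notebooks): load one of the unpacked OOD folders.
from torchvision import datasets, transforms

transform_test = transforms.Compose([
    transforms.CenterCrop(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
isun = datasets.ImageFolder("iSUN/", transform=transform_test)
print(len(isun), "iSUN images")
```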
/densenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
9 |         super(BasicBlock, self).__init__()
10 |         self.bn1 = nn.BatchNorm2d(in_planes)
11 |         self.relu = nn.ReLU(inplace=True)
12 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
13 |                                padding=1, bias=False)
14 |         self.droprate = dropRate
15 |     def forward(self, x):
16 |         out = self.conv1(self.relu(self.bn1(x)))
17 |         if self.droprate > 0:
18 |             out = F.dropout(out, p=self.droprate, training=self.training)
19 |         return torch.cat([x, out], 1)
20 |
21 | class BottleneckBlock(nn.Module):
22 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
23 |         super(BottleneckBlock, self).__init__()
24 |         inter_planes = out_planes * 4
25 |         self.bn1 = nn.BatchNorm2d(in_planes)
26 |         self.relu = nn.ReLU(inplace=True)
27 |         self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
28 |                                padding=0, bias=False)
29 |         self.bn2 = nn.BatchNorm2d(inter_planes)
30 |         self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
31 |                                padding=1, bias=False)
32 |         self.droprate = dropRate
33 |     def forward(self, x):
34 |         out = self.conv1(self.relu(self.bn1(x)))
35 |         if self.droprate > 0:
36 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
37 |         out = self.conv2(self.relu(self.bn2(out)))
38 |         if self.droprate > 0:
39 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
40 |         return torch.cat([x, out], 1)
41 |
42 | class TransitionBlock(nn.Module):
43 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
44 |         super(TransitionBlock, self).__init__()
45 |         self.bn1 = nn.BatchNorm2d(in_planes)
46 |         self.relu = nn.ReLU(inplace=True)
47 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
48 |                                padding=0, bias=False)
49 |         self.droprate = dropRate
50 |     def forward(self, x):
51 |         out = self.conv1(self.relu(self.bn1(x)))
52 |         if self.droprate > 0:
53 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
54 |         return F.avg_pool2d(out, 2)
55 |
56 | class DenseBlock(nn.Module):
57 |     def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
58 |         super(DenseBlock, self).__init__()
59 |         self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
60 |     def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
61 |         layers = []
62 |         for i in range(int(nb_layers)):
63 |             layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))
64 |         return nn.Sequential(*layers)
65 |     def forward(self, x):
66 |         return self.layer(x)
67 |
68 | class DenseNet3(nn.Module):
69 |     def __init__(self, depth, num_classes, growth_rate=12,
70 |                  reduction=0.5, bottleneck=True, dropRate=0.0):
71 |         super(DenseNet3, self).__init__()
72 |         in_planes = 2 * growth_rate
73 |         n = (depth - 4) / 3
74 |         if bottleneck == True:
75 |             n = n/2
76 |             block = BottleneckBlock
77 |         else:
78 |             block = BasicBlock
79 |         # 1st conv before any dense block
80 |         self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
81 |                                padding=1, bias=False)
82 |         # 1st block
83 |         self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
84 |         in_planes = int(in_planes+n*growth_rate)
85 |         self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
86 |         in_planes = int(math.floor(in_planes*reduction))
87 |         # 2nd block
88 |         self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
89 |         in_planes = int(in_planes+n*growth_rate)
90 |         self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
91 |         in_planes = int(math.floor(in_planes*reduction))
92 |         # 3rd block
93 |         self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
94 |         in_planes = int(in_planes+n*growth_rate)
95 |         # global average pooling and classifier
96 |         self.bn1 = nn.BatchNorm2d(in_planes)
97 |         self.relu = nn.ReLU(inplace=True)
98 |         self.fc = nn.Linear(in_planes, num_classes)
99 |         self.in_planes = in_planes
100 |
101 |         for m in self.modules():
102 |             if isinstance(m, nn.Conv2d):
103 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
105 |             elif isinstance(m, nn.BatchNorm2d):
106 |                 m.weight.data.fill_(1)
107 |                 m.bias.data.zero_()
108 |             elif isinstance(m, nn.Linear):
109 |                 m.bias.data.zero_()
110 |
111 |     def forward(self, x):
112 |         out = self.conv1(x)
113 |         out = self.trans1(self.block1(out))
114 |         out = self.trans2(self.block2(out))
115 |         out = self.block3(out)
116 |         out = self.relu(self.bn1(out))
117 |         out = F.avg_pool2d(out, 8)
118 |         out = out.view(-1, self.in_planes)
119 |         return self.fc(out)
120 |
121 |     # function to extract the multiple features
122 |     def feature_list(self, x):
123 |         out_list = []
124 |         out = self.conv1(x)
125 |         out_list.append(out)
126 |         out = self.trans1(self.block1(out))
127 |         out_list.append(out)
128 |         out = self.trans2(self.block2(out))
129 |         out_list.append(out)
130 |         out = self.block3(out)
131 |         out = self.relu(self.bn1(out))
132 |         out_list.append(out)
133 |         out = F.avg_pool2d(out, 8)
134 |         out = out.view(-1, self.in_planes)
135 |
136 |         return self.fc(out), out_list
137 |
138 |     def intermediate_forward(self, x, layer_index):
139 |         out = self.conv1(x)
140 |         if layer_index == 1:
141 |             out = self.trans1(self.block1(out))
142 |         elif layer_index == 2:
143 |             out = self.trans1(self.block1(out))
144 |             out = self.trans2(self.block2(out))
145 |         elif layer_index == 3:
146 |             out = self.trans1(self.block1(out))
147 |             out = self.trans2(self.block2(out))
148 |             out = self.block3(out)
149 |             out = self.relu(self.bn1(out))
150 |         return out
151 |
152 |     # function to extract the penultimate features
153 |     def penultimate_forward(self, x):
154 |         out = self.conv1(x)
155 |         out = self.trans1(self.block1(out))
156 |         out = self.trans2(self.block2(out))
157 |         out = self.block3(out)
158 |         penultimate = self.relu(self.bn1(out))
159 |         out = F.avg_pool2d(penultimate, 8)
160 |         out = out.view(-1, self.in_planes)
161 |         return self.fc(out), penultimate
--------------------------------------------------------------------------------
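densenet.py is imported by the DenseNet notebooks; a minimal sketch of exercising its feature taps follows. The depth-100, growth-rate-12 DenseNet-BC configuration is an assumption about the checkpoints listed in setup.sh, and the random input stands in for a real batch:

```python
# Hedged sketch: build a CIFAR-sized DenseNet3 and inspect its feature taps.
import torch
from densenet import DenseNet3

model = DenseNet3(depth=100, num_classes=10, growth_rate=12)
model.eval()

x = torch.randn(2, 3, 32, 32)              # dummy CIFAR-shaped batch
with torch.no_grad():
    logits, feats = model.feature_list(x)   # logits plus the four tapped feature maps
print(logits.shape)                         # torch.Size([2, 10])
print([tuple(f.shape) for f in feats])
```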
/calculate_log.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function,division
2 | import torch
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import torch.optim as optim
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | import numpy as np
11 | import time
12 | from scipy import misc
13 |
14 | import matplotlib
15 | # matplotlib.use('Agg')
16 | import matplotlib.pyplot as plt
17 |
18 | def compute_metric(known, novel):
19 |     stype = ""
20 |
21 |     tp, fp = dict(), dict()
22 |     tnr_at_tpr95 = dict()
23 |
24 |     known.sort()
25 |     novel.sort()
26 |     end = np.max([np.max(known), np.max(novel)])
27 |     start = np.min([np.min(known),np.min(novel)])
28 |     num_k = known.shape[0]
29 |     num_n = novel.shape[0]
30 |     tp[stype] = -np.ones([num_k+num_n+1], dtype=int)
31 |     fp[stype] = -np.ones([num_k+num_n+1], dtype=int)
32 |     tp[stype][0], fp[stype][0] = num_k, num_n
33 |     k, n = 0, 0
34 |     for l in range(num_k+num_n):
35 |         if k == num_k:
36 |             tp[stype][l+1:] = tp[stype][l]
37 |             fp[stype][l+1:] = np.arange(fp[stype][l]-1, -1, -1)
38 |             break
39 |         elif n == num_n:
40 |             tp[stype][l+1:] = np.arange(tp[stype][l]-1, -1, -1)
41 |             fp[stype][l+1:] = fp[stype][l]
42 |             break
43 |         else:
44 |             if novel[n] < known[k]:
45 |                 n += 1
46 |                 tp[stype][l+1] = tp[stype][l]
47 |                 fp[stype][l+1] = fp[stype][l] - 1
48 |             else:
49 |                 k += 1
50 |                 tp[stype][l+1] = tp[stype][l] - 1
51 |                 fp[stype][l+1] = fp[stype][l]
52 |     tpr95_pos = np.abs(tp[stype] / num_k - .95).argmin()
53 |     tnr_at_tpr95[stype] = 1. - fp[stype][tpr95_pos] / num_n
54 |     mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
55 |     results = dict()
56 |     results[stype] = dict()
57 |
58 |     # TNR
59 |     mtype = 'TNR'
60 |     results[stype][mtype] = tnr_at_tpr95[stype]
61 |
62 |     # AUROC
63 |     mtype = 'AUROC'
64 |     tpr = np.concatenate([[1.], tp[stype]/tp[stype][0], [0.]])
65 |     fpr = np.concatenate([[1.], fp[stype]/fp[stype][0], [0.]])
66 |     results[stype][mtype] = -np.trapz(1.-fpr, tpr)
67 |
68 |     # DTACC
69 |     mtype = 'DTACC'
70 |     results[stype][mtype] = .5 * (tp[stype]/tp[stype][0] + 1.-fp[stype]/fp[stype][0]).max()
71 |
72 |     # AUIN
73 |     mtype = 'AUIN'
74 |     denom = tp[stype]+fp[stype]
75 |     denom[denom == 0.] = -1.
76 |     pin_ind = np.concatenate([[True], denom > 0., [True]])
77 |     pin = np.concatenate([[.5], tp[stype]/denom, [0.]])
78 |     results[stype][mtype] = -np.trapz(pin[pin_ind], tpr[pin_ind])
79 |
80 |     # AUOUT
81 |     mtype = 'AUOUT'
82 |     denom = tp[stype][0]-tp[stype]+fp[stype][0]-fp[stype]
83 |     denom[denom == 0.] = -1.
84 |     pout_ind = np.concatenate([[True], denom > 0., [True]])
85 |     pout = np.concatenate([[0.], (fp[stype][0]-fp[stype])/denom, [.5]])
86 |     results[stype][mtype] = np.trapz(pout[pout_ind], 1.-fpr[pout_ind])
87 |
88 |     return results[stype]
89 |
90 | def print_results(results):
91 |     mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
92 |     for mtype in mtypes:
93 |         print(' {mtype:6s}'.format(mtype=mtype), end='')
94 |     print('')
95 |     for mtype in mtypes:
96 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
97 |     print('')
98 |
99 |
100 | def get_curve(dir_name, stypes = ['Baseline', 'Gaussian_LDA']):
101 |     tp, fp = dict(), dict()
102 |     tnr_at_tpr95 = dict()
103 |     for stype in stypes:
104 |         known = np.loadtxt('{}/confidence_{}_In.txt'.format(dir_name, stype), delimiter='\n')
105 |         novel = np.loadtxt('{}/confidence_{}_Out.txt'.format(dir_name, stype), delimiter='\n')
106 |         known.sort()
107 |         novel.sort()
108 |         end = np.max([np.max(known), np.max(novel)])
109 |         start = np.min([np.min(known),np.min(novel)])
110 |         num_k = known.shape[0]
111 |         num_n = novel.shape[0]
112 |         tp[stype] = -np.ones([num_k+num_n+1], dtype=int)
113 |         fp[stype] = -np.ones([num_k+num_n+1], dtype=int)
114 |         tp[stype][0], fp[stype][0] = num_k, num_n
115 |         k, n = 0, 0
116 |         for l in range(num_k+num_n):
117 |             if k == num_k:
118 |                 tp[stype][l+1:] = tp[stype][l]
119 |                 fp[stype][l+1:] = np.arange(fp[stype][l]-1, -1, -1)
120 |                 break
121 |             elif n == num_n:
122 |                 tp[stype][l+1:] = np.arange(tp[stype][l]-1, -1, -1)
123 |                 fp[stype][l+1:] = fp[stype][l]
124 |                 break
125 |             else:
126 |                 if novel[n] < known[k]:
127 |                     n += 1
128 |                     tp[stype][l+1] = tp[stype][l]
129 |                     fp[stype][l+1] = fp[stype][l] - 1
130 |                 else:
131 |                     k += 1
132 |                     tp[stype][l+1] = tp[stype][l] - 1
133 |                     fp[stype][l+1] = fp[stype][l]
134 |         tpr95_pos = np.abs(tp[stype] / num_k - .95).argmin()
135 |         tnr_at_tpr95[stype] = 1. - fp[stype][tpr95_pos] / num_n
136 |
137 |     return tp, fp, tnr_at_tpr95
138 |
139 | def metric(dir_name, stypes = ['Bas', 'Gau'], verbose=False):
140 |     tp, fp, tnr_at_tpr95 = get_curve(dir_name, stypes)
141 |     results = dict()
142 |     mtypes = ['TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
143 |     if verbose:
144 |         print(' ', end='')
145 |         for mtype in mtypes:
146 |             print(' {mtype:6s}'.format(mtype=mtype), end='')
147 |         print('')
148 |
149 |     for stype in stypes:
150 |         if verbose:
151 |             print('{stype:5s} '.format(stype=stype), end='')
152 |         results[stype] = dict()
153 |
154 |         # TNR
155 |         mtype = 'TNR'
156 |         results[stype][mtype] = tnr_at_tpr95[stype]
157 |         if verbose:
158 |             print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='')
159 |
160 |         # AUROC
161 |         mtype = 'AUROC'
162 |         tpr = np.concatenate([[1.], tp[stype]/tp[stype][0], [0.]])
163 |         fpr = np.concatenate([[1.], fp[stype]/fp[stype][0], [0.]])
164 |         results[stype][mtype] = -np.trapz(1.-fpr, tpr)
165 |         if verbose:
166 |             print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='')
167 |
168 |         # DTACC
169 |         mtype = 'DTACC'
170 |         results[stype][mtype] = .5 * (tp[stype]/tp[stype][0] + 1.-fp[stype]/fp[stype][0]).max()
171 |         if verbose:
172 |             print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='')
173 |
174 |         # AUIN
175 |         mtype = 'AUIN'
176 |         denom = tp[stype]+fp[stype]
177 |         denom[denom == 0.] = -1.
178 |         pin_ind = np.concatenate([[True], denom > 0., [True]])
179 |         pin = np.concatenate([[.5], tp[stype]/denom, [0.]])
180 |         results[stype][mtype] = -np.trapz(pin[pin_ind], tpr[pin_ind])
181 |         if verbose:
182 |             print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='')
183 |
184 |         # AUOUT
185 |         mtype = 'AUOUT'
186 |         denom = tp[stype][0]-tp[stype]+fp[stype][0]-fp[stype]
187 |         denom[denom == 0.] = -1.
188 |         pout_ind = np.concatenate([[True], denom > 0., [True]])
189 |         pout = np.concatenate([[0.], (fp[stype][0]-fp[stype])/denom, [.5]])
190 |         results[stype][mtype] = np.trapz(pout[pout_ind], 1.-fpr[pout_ind])
191 |         if verbose:
192 |             print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='')
193 |             print('')
194 |
195 |     return results
--------------------------------------------------------------------------------
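calculate_log.compute_metric is called from the notebooks with negated deviation scores, so larger values mean "more in-distribution". A small synthetic example (random scores, purely illustrative) shows the expected inputs and outputs:

```python
# Hedged sketch: compute_metric takes 1-D numpy arrays of scores for the
# in-distribution ("known") and OOD ("novel") sets and returns TNR@TPR95,
# AUROC, DTACC, AUIN and AUOUT as a dict.
import numpy as np
import calculate_log as callog

rng = np.random.RandomState(0)
known = rng.normal(1.0, 0.5, size=2000)    # synthetic in-distribution scores
novel = rng.normal(-1.0, 0.5, size=2000)   # synthetic OOD scores

results = callog.compute_metric(known, novel)
callog.print_results(results)
```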
/ResNet_SVHN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "
ResNet: SVHN
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from __future__ import division,print_function\n",
24 | "\n",
25 | "%matplotlib inline\n",
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import sys\n",
30 | "from tqdm import tqdm_notebook as tqdm\n",
31 | "\n",
32 | "import random\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import math\n",
35 | "\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | "import torch.nn.functional as F\n",
41 | "import torch.optim as optim\n",
42 | "import torch.nn.init as init\n",
43 | "from torch.autograd import Variable, grad\n",
44 | "from torchvision import datasets, transforms\n",
45 | "from torch.nn.parameter import Parameter\n",
46 | "\n",
47 | "import calculate_log as callog\n",
48 | "\n",
49 | "import warnings\n",
50 | "warnings.filterwarnings('ignore')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "torch.cuda.set_device(0) #Select the GPU"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Model definition"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {
73 | "scrolled": true
74 | },
75 | "outputs": [
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "Done\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "def conv3x3(in_planes, out_planes, stride=1):\n",
86 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
87 | "\n",
88 | "class BasicBlock(nn.Module):\n",
89 | " expansion = 1\n",
90 | "\n",
91 | " def __init__(self, in_planes, planes, stride=1):\n",
92 | " super(BasicBlock, self).__init__()\n",
93 | " self.conv1 = conv3x3(in_planes, planes, stride)\n",
94 | " self.bn1 = nn.BatchNorm2d(planes)\n",
95 | " self.conv2 = conv3x3(planes, planes)\n",
96 | " self.bn2 = nn.BatchNorm2d(planes)\n",
97 | "\n",
98 | " self.shortcut = nn.Sequential()\n",
99 | " if stride != 1 or in_planes != self.expansion*planes:\n",
100 | " self.shortcut = nn.Sequential(\n",
101 | " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n",
102 | " nn.BatchNorm2d(self.expansion*planes)\n",
103 | " )\n",
104 | " \n",
105 | " def forward(self, x):\n",
106 | " t = self.conv1(x)\n",
107 | " out = F.relu(self.bn1(t))\n",
108 | " torch_model.record(t)\n",
109 | " torch_model.record(out)\n",
110 | " t = self.conv2(out)\n",
111 | " out = self.bn2(self.conv2(out))\n",
112 | " torch_model.record(t)\n",
113 | " torch_model.record(out)\n",
114 | " t = self.shortcut(x)\n",
115 | " out += t\n",
116 | " torch_model.record(t)\n",
117 | " out = F.relu(out)\n",
118 | " torch_model.record(out)\n",
119 | " \n",
120 | " return out#, out_list\n",
121 | "\n",
122 | "class ResNet(nn.Module):\n",
123 | " def __init__(self, block, num_blocks, num_classes=10):\n",
124 | " super(ResNet, self).__init__()\n",
125 | " self.in_planes = 64\n",
126 | "\n",
127 | " self.conv1 = conv3x3(3,64)\n",
128 | " self.bn1 = nn.BatchNorm2d(64)\n",
129 | " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
130 | " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
131 | " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
132 | " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
133 | " self.linear = nn.Linear(512*block.expansion, num_classes)\n",
134 | " \n",
135 | " self.collecting = False\n",
136 | " \n",
137 | " def _make_layer(self, block, planes, num_blocks, stride):\n",
138 | " strides = [stride] + [1]*(num_blocks-1)\n",
139 | " layers = []\n",
140 | " for stride in strides:\n",
141 | " layers.append(block(self.in_planes, planes, stride))\n",
142 | " self.in_planes = planes * block.expansion\n",
143 | " return nn.Sequential(*layers)\n",
144 | " \n",
145 | " def forward(self, x):\n",
146 | " out = F.relu(self.bn1(self.conv1(x)))\n",
147 | " out = self.layer1(out)\n",
148 | " out = self.layer2(out)\n",
149 | " out = self.layer3(out)\n",
150 | " out = self.layer4(out)\n",
151 | " out = F.avg_pool2d(out, 4)\n",
152 | " out = out.view(out.size(0), -1)\n",
153 | " y = self.linear(out)\n",
154 | " return y\n",
155 | " \n",
156 | " def record(self, t):\n",
157 | " if self.collecting:\n",
158 | " self.gram_feats.append(t)\n",
159 | " \n",
160 | " def gram_feature_list(self,x):\n",
161 | " self.collecting = True\n",
162 | " self.gram_feats = []\n",
163 | " self.forward(x)\n",
164 | " self.collecting = False\n",
165 | " temp = self.gram_feats\n",
166 | " self.gram_feats = []\n",
167 | " return temp\n",
168 | " \n",
169 | " def load(self, path=\"resnet_svhn.pth\"):\n",
170 | " tm = torch.load(path,map_location=\"cpu\") \n",
171 | " self.load_state_dict(tm)\n",
172 | " \n",
173 | " def get_min_max(self, data, power):\n",
174 | " mins = []\n",
175 | " maxs = []\n",
176 | " \n",
177 | " for i in range(0,len(data),128):\n",
178 | " batch = data[i:i+128].cuda()\n",
179 | " feat_list = self.gram_feature_list(batch)\n",
180 | " for L,feat_L in enumerate(feat_list):\n",
181 | " if L==len(mins):\n",
182 | " mins.append([None]*len(power))\n",
183 | " maxs.append([None]*len(power))\n",
184 | " \n",
185 | " for p,P in enumerate(power):\n",
186 | " g_p = G_p(feat_L,P)\n",
187 | " \n",
188 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n",
189 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n",
190 | " \n",
191 | " if mins[L][p] is None:\n",
192 | " mins[L][p] = current_min\n",
193 | " maxs[L][p] = current_max\n",
194 | " else:\n",
195 | " mins[L][p] = torch.min(current_min,mins[L][p])\n",
196 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n",
197 | " \n",
198 | " return mins,maxs\n",
199 | " \n",
200 | " def get_deviations(self,data,power,mins,maxs):\n",
201 | " deviations = []\n",
202 | " \n",
203 | " for i in range(0,len(data),128): \n",
204 | " batch = data[i:i+128].cuda()\n",
205 | " feat_list = self.gram_feature_list(batch)\n",
206 | " batch_deviations = []\n",
207 | " for L,feat_L in enumerate(feat_list):\n",
208 | " dev = 0\n",
209 | " for p,P in enumerate(power):\n",
210 | " g_p = G_p(feat_L,P)\n",
211 | " \n",
212 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
213 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
214 | " batch_deviations.append(dev.cpu().detach().numpy())\n",
215 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n",
216 | " deviations.append(batch_deviations)\n",
217 | " deviations = np.concatenate(deviations,axis=0)\n",
218 | " \n",
219 | " return deviations\n",
220 | "\n",
221 | "torch_model = ResNet(BasicBlock, [3,4,6,3], num_classes=10)\n",
222 | "torch_model.load()\n",
223 | "torch_model.cuda()\n",
224 | "torch_model.params = list(torch_model.parameters())\n",
225 | "torch_model.eval()\n",
226 | "print(\"Done\") "
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {},
232 | "source": [
233 | "## Datasets"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "In-distribution Datasets"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 4,
246 | "metadata": {},
247 | "outputs": [
248 | {
249 | "name": "stdout",
250 | "output_type": "stream",
251 | "text": [
252 | "Using downloaded and verified file: data/train_32x32.mat\n",
253 | "Using downloaded and verified file: data/test_32x32.mat\n"
254 | ]
255 | }
256 | ],
257 | "source": [
258 | "batch_size = 128\n",
259 | "mean = np.array([[0.4914, 0.4822, 0.4465]]).T\n",
260 | "\n",
261 | "std = np.array([[0.2023, 0.1994, 0.2010]]).T\n",
262 | "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
263 | "\n",
264 | "transform_train = transforms.Compose([\n",
265 | " transforms.RandomCrop(32, padding=4),\n",
266 | " transforms.RandomHorizontalFlip(),\n",
267 | " transforms.ToTensor(),\n",
268 | " normalize\n",
269 | " \n",
270 | " ])\n",
271 | "transform_test = transforms.Compose([\n",
272 | " transforms.CenterCrop(size=(32, 32)),\n",
273 | " transforms.ToTensor(),\n",
274 | " normalize\n",
275 | " ])\n",
276 | "\n",
277 | "\n",
278 | "train_loader = torch.utils.data.DataLoader(\n",
279 | " datasets.SVHN('data', split=\"train\", download=True,\n",
280 | " transform=transform_train),\n",
281 | " batch_size=batch_size, shuffle=True)\n",
282 | "test_loader = torch.utils.data.DataLoader(\n",
283 | " datasets.SVHN('data', split=\"test\", download=True, transform=transform_test),\n",
284 | " batch_size=batch_size)\n"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 5,
290 | "metadata": {
291 | "scrolled": true
292 | },
293 | "outputs": [
294 | {
295 | "name": "stdout",
296 | "output_type": "stream",
297 | "text": [
298 | "Using downloaded and verified file: data/train_32x32.mat\n"
299 | ]
300 | }
301 | ],
302 | "source": [
303 | "data_train = list(list(torch.utils.data.DataLoader(\n",
304 | " datasets.SVHN('data', split=\"train\", download=True,\n",
305 | " transform=transform_test),\n",
306 | " batch_size=1, shuffle=True)))"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": 6,
312 | "metadata": {},
313 | "outputs": [
314 | {
315 | "name": "stdout",
316 | "output_type": "stream",
317 | "text": [
318 | "Using downloaded and verified file: data/test_32x32.mat\n"
319 | ]
320 | }
321 | ],
322 | "source": [
323 | "data = list(list(torch.utils.data.DataLoader(\n",
324 | " datasets.SVHN('data', split=\"test\", download=True,\n",
325 | " transform=transform_test),\n",
326 | " batch_size=1, shuffle=False)))"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 7,
332 | "metadata": {},
333 | "outputs": [
334 | {
335 | "name": "stdout",
336 | "output_type": "stream",
337 | "text": [
338 | "Accuracy: 0.9668484941610326\n"
339 | ]
340 | }
341 | ],
342 | "source": [
343 | "torch_model.eval()\n",
344 | "correct = 0\n",
345 | "total = 0\n",
346 | "for x,y in test_loader:\n",
347 | " x = x.cuda()\n",
348 | " y = y.numpy()\n",
349 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n",
350 | " total += y.shape[0]\n",
351 | "print(\"Accuracy: \",correct/total)\n"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "Out-of-distribution Datasets"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 8,
364 | "metadata": {},
365 | "outputs": [
366 | {
367 | "name": "stdout",
368 | "output_type": "stream",
369 | "text": [
370 | "Files already downloaded and verified\n"
371 | ]
372 | }
373 | ],
374 | "source": [
375 | "cifar10 = list(torch.utils.data.DataLoader(\n",
376 | " datasets.CIFAR10('data', train=False, download=True,\n",
377 | " transform=transform_test),\n",
378 | " batch_size=1, shuffle=True))"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 9,
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "isun = list(torch.utils.data.DataLoader(\n",
388 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 10,
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "lsun_c = list(torch.utils.data.DataLoader(\n",
398 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": 11,
404 | "metadata": {},
405 | "outputs": [],
406 | "source": [
407 | "lsun_r = list(torch.utils.data.DataLoader(\n",
408 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": 12,
414 | "metadata": {},
415 | "outputs": [],
416 | "source": [
417 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n",
418 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": 13,
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n",
428 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
429 | ]
430 | },
431 | {
432 | "cell_type": "markdown",
433 | "metadata": {},
434 | "source": [
435 | "## Code for Detecting OODs"
436 | ]
437 | },
438 | {
439 | "cell_type": "markdown",
440 | "metadata": {},
441 | "source": [
442 | " Extract predictions for train and test data "
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": 14,
448 | "metadata": {
449 | "scrolled": true
450 | },
451 | "outputs": [
452 | {
453 | "name": "stdout",
454 | "output_type": "stream",
455 | "text": [
456 | "Done\n",
457 | "Done\n"
458 | ]
459 | }
460 | ],
461 | "source": [
462 | "train_preds = []\n",
463 | "train_confs = []\n",
464 | "train_logits = []\n",
465 | "for idx in range(0,len(data_train),128):\n",
466 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n",
467 | " \n",
468 | " logits = torch_model(batch)\n",
469 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
470 | " preds = np.argmax(confs,axis=1)\n",
471 | " logits = (logits.cpu().detach().numpy())\n",
472 | "\n",
473 | " train_confs.extend(np.max(confs,axis=1)) \n",
474 | " train_preds.extend(preds)\n",
475 | " train_logits.extend(logits)\n",
476 | "print(\"Done\")\n",
477 | "\n",
478 | "test_preds = []\n",
479 | "test_confs = []\n",
480 | "test_logits = []\n",
481 | "\n",
482 | "for idx in range(0,len(data),128):\n",
483 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n",
484 | " \n",
485 | " logits = torch_model(batch)\n",
486 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
487 | " preds = np.argmax(confs,axis=1)\n",
488 | " logits = (logits.cpu().detach().numpy())\n",
489 | "\n",
490 | " test_confs.extend(np.max(confs,axis=1)) \n",
491 | " test_preds.extend(preds)\n",
492 | " test_logits.extend(logits)\n",
493 | "print(\"Done\")"
494 | ]
495 | },
496 | {
497 | "cell_type": "markdown",
498 | "metadata": {},
499 | "source": [
500 | " Code for detecting OODs by identifying anomalies in correlations "
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": 15,
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "import calculate_log as callog\n",
510 | "\n",
511 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n",
512 | " average_results = {}\n",
513 | " for i in range(1,11):\n",
514 | " random.seed(i)\n",
515 | " \n",
516 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n",
517 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n",
518 | "\n",
519 | " validation = all_test_deviations[validation_indices]\n",
520 | " test_deviations = all_test_deviations[test_indices]\n",
521 | "\n",
522 | " t95 = validation.mean(axis=0)+10**-7\n",
523 | " if not normalize:\n",
524 | " t95 = np.ones_like(t95)\n",
525 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
526 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
527 | " \n",
528 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n",
529 | " for m in results:\n",
530 | " average_results[m] = average_results.get(m,0)+results[m]\n",
531 | " \n",
532 | " for m in average_results:\n",
533 | " average_results[m] /= i\n",
534 | " if verbose:\n",
535 | " callog.print_results(average_results)\n",
536 | " return average_results\n",
537 | "\n",
538 | "def cpu(ob):\n",
539 | " for i in range(len(ob)):\n",
540 | " for j in range(len(ob[i])):\n",
541 | " ob[i][j] = ob[i][j].cpu()\n",
542 | " return ob\n",
543 | "\n",
544 | "def cuda(ob):\n",
545 | " for i in range(len(ob)):\n",
546 | " for j in range(len(ob[i])):\n",
547 | " ob[i][j] = ob[i][j].cuda()\n",
548 | " return ob\n",
549 | "\n",
550 | "class Detector:\n",
551 | " def __init__(self):\n",
552 | " self.all_test_deviations = None\n",
553 | " self.mins = {}\n",
554 | " self.maxs = {}\n",
555 | " \n",
556 | " self.classes = range(10)\n",
557 | " \n",
558 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n",
559 | " for PRED in tqdm(self.classes):\n",
560 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n",
561 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n",
562 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n",
563 | " self.mins[PRED] = cpu(mins)\n",
564 | " self.maxs[PRED] = cpu(maxs)\n",
565 | " torch.cuda.empty_cache()\n",
566 | " \n",
567 | " def compute_test_deviations(self,POWERS=[10]):\n",
568 | " all_test_deviations = None\n",
569 | " test_classes = []\n",
570 | " for PRED in tqdm(self.classes):\n",
571 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n",
572 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n",
573 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n",
574 | " \n",
575 | " test_classes.extend([PRED]*len(test_indices))\n",
576 | " \n",
577 | " mins = cuda(self.mins[PRED])\n",
578 | " maxs = cuda(self.maxs[PRED])\n",
579 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n",
580 | " cpu(mins)\n",
581 | " cpu(maxs)\n",
582 | " if all_test_deviations is None:\n",
583 | " all_test_deviations = test_deviations\n",
584 | " else:\n",
585 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n",
586 | " torch.cuda.empty_cache()\n",
587 | " self.all_test_deviations = all_test_deviations\n",
588 | " \n",
589 | " self.test_classes = np.array(test_classes)\n",
590 | " \n",
591 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n",
592 | " ood_preds = []\n",
593 | " ood_confs = []\n",
594 | " \n",
595 | " for idx in range(0,len(ood),128):\n",
596 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n",
597 | " logits = torch_model(batch)\n",
598 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
599 | " preds = np.argmax(confs,axis=1)\n",
600 | " \n",
601 | " ood_confs.extend(np.max(confs,axis=1))\n",
602 | " ood_preds.extend(preds) \n",
603 | " torch.cuda.empty_cache()\n",
604 | " print(\"Done\")\n",
605 | " \n",
606 | " ood_classes = []\n",
607 | " all_ood_deviations = None\n",
608 | " for PRED in tqdm(self.classes):\n",
609 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n",
610 | " if len(ood_indices)==0:\n",
611 | " continue\n",
612 | " ood_classes.extend([PRED]*len(ood_indices))\n",
613 | " \n",
614 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n",
615 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n",
616 | " mins = cuda(self.mins[PRED])\n",
617 | " maxs = cuda(self.maxs[PRED])\n",
618 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n",
619 | " cpu(self.mins[PRED])\n",
620 | " cpu(self.maxs[PRED]) \n",
621 | " if all_ood_deviations is None:\n",
622 | " all_ood_deviations = ood_deviations\n",
623 | " else:\n",
624 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n",
625 | " torch.cuda.empty_cache()\n",
626 | " \n",
627 | " self.ood_classes = np.array(ood_classes)\n",
628 | " \n",
629 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n",
630 | " return average_results, self.all_test_deviations, all_ood_deviations\n"
631 | ]
632 | },
633 | {
634 | "cell_type": "markdown",
635 | "metadata": {},
636 | "source": [
637 | " Results
"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": 16,
643 | "metadata": {},
644 | "outputs": [
645 | {
646 | "data": {
647 | "application/vnd.jupyter.widget-view+json": {
648 | "model_id": "973bca4a7644474685a068878c21aa61",
649 | "version_major": 2,
650 | "version_minor": 0
651 | },
652 | "text/plain": [
653 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
654 | ]
655 | },
656 | "metadata": {},
657 | "output_type": "display_data"
658 | },
659 | {
660 | "name": "stdout",
661 | "output_type": "stream",
662 | "text": [
663 | "\n"
664 | ]
665 | },
666 | {
667 | "data": {
668 | "application/vnd.jupyter.widget-view+json": {
669 | "model_id": "3c037095b42c4ea5897065610bd9ad1d",
670 | "version_major": 2,
671 | "version_minor": 0
672 | },
673 | "text/plain": [
674 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
675 | ]
676 | },
677 | "metadata": {},
678 | "output_type": "display_data"
679 | },
680 | {
681 | "name": "stdout",
682 | "output_type": "stream",
683 | "text": [
684 | "\n",
685 | "iSUN\n",
686 | "Done\n"
687 | ]
688 | },
689 | {
690 | "data": {
691 | "application/vnd.jupyter.widget-view+json": {
692 | "model_id": "60342b9ed1854718ab89229ffb69b6e8",
693 | "version_major": 2,
694 | "version_minor": 0
695 | },
696 | "text/plain": [
697 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
698 | ]
699 | },
700 | "metadata": {},
701 | "output_type": "display_data"
702 | },
703 | {
704 | "name": "stdout",
705 | "output_type": "stream",
706 | "text": [
707 | "\n",
708 | " TNR AUROC DTACC AUIN AUOUT \n",
709 | " 99.444 99.766 98.113 99.911 99.337\n",
710 | "LSUN (R)\n",
711 | "Done\n"
712 | ]
713 | },
714 | {
715 | "data": {
716 | "application/vnd.jupyter.widget-view+json": {
717 | "model_id": "00fe9a5a01bb4c5ebedb0c66cc732200",
718 | "version_major": 2,
719 | "version_minor": 0
720 | },
721 | "text/plain": [
722 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
723 | ]
724 | },
725 | "metadata": {},
726 | "output_type": "display_data"
727 | },
728 | {
729 | "name": "stdout",
730 | "output_type": "stream",
731 | "text": [
732 | "\n",
733 | " TNR AUROC DTACC AUIN AUOUT \n",
734 | " 99.555 99.823 98.481 99.926 99.538\n",
735 | "LSUN (C)\n",
736 | "Done\n"
737 | ]
738 | },
739 | {
740 | "data": {
741 | "application/vnd.jupyter.widget-view+json": {
742 | "model_id": "7c5b88d59b6841edb2823e57bb5d77e5",
743 | "version_major": 2,
744 | "version_minor": 0
745 | },
746 | "text/plain": [
747 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
748 | ]
749 | },
750 | "metadata": {},
751 | "output_type": "display_data"
752 | },
753 | {
754 | "name": "stdout",
755 | "output_type": "stream",
756 | "text": [
757 | "\n",
758 | " TNR AUROC DTACC AUIN AUOUT \n",
759 | " 94.191 98.739 94.669 99.430 97.383\n",
760 | "TinyImgNet (R)\n",
761 | "Done\n"
762 | ]
763 | },
764 | {
765 | "data": {
766 | "application/vnd.jupyter.widget-view+json": {
767 | "model_id": "a25a7a21142b46688b86f890ddb076f1",
768 | "version_major": 2,
769 | "version_minor": 0
770 | },
771 | "text/plain": [
772 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
773 | ]
774 | },
775 | "metadata": {},
776 | "output_type": "display_data"
777 | },
778 | {
779 | "name": "stdout",
780 | "output_type": "stream",
781 | "text": [
782 | "\n",
783 | " TNR AUROC DTACC AUIN AUOUT \n",
784 | " 99.280 99.725 97.860 99.885 99.297\n",
785 | "TinyImgNet (C)\n",
786 | "Done\n"
787 | ]
788 | },
789 | {
790 | "data": {
791 | "application/vnd.jupyter.widget-view+json": {
792 | "model_id": "7e6afbe421d4410eb4a52313ab2ce3bf",
793 | "version_major": 2,
794 | "version_minor": 0
795 | },
796 | "text/plain": [
797 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
798 | ]
799 | },
800 | "metadata": {},
801 | "output_type": "display_data"
802 | },
803 | {
804 | "name": "stdout",
805 | "output_type": "stream",
806 | "text": [
807 | "\n",
808 | " TNR AUROC DTACC AUIN AUOUT \n",
809 | " 98.392 99.481 96.958 99.744 98.750\n",
810 | "CIFAR-10\n",
811 | "Done\n"
812 | ]
813 | },
814 | {
815 | "data": {
816 | "application/vnd.jupyter.widget-view+json": {
817 | "model_id": "4fb502a55bc044618cbd5647ae07cd8c",
818 | "version_major": 2,
819 | "version_minor": 0
820 | },
821 | "text/plain": [
822 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
823 | ]
824 | },
825 | "metadata": {},
826 | "output_type": "display_data"
827 | },
828 | {
829 | "name": "stdout",
830 | "output_type": "stream",
831 | "text": [
832 | "\n",
833 | " TNR AUROC DTACC AUIN AUOUT \n",
834 | " 85.753 97.305 91.985 98.871 93.185\n"
835 | ]
836 | }
837 | ],
838 | "source": [
839 | "def G_p(ob, p):\n",
840 | " temp = ob.detach()\n",
841 | " \n",
842 | " temp = temp**p\n",
843 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n",
844 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n",
845 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n",
846 | " \n",
847 | " return temp\n",
848 | "\n",
849 | "\n",
850 | "detector = Detector()\n",
851 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n",
852 | "\n",
853 | "detector.compute_test_deviations(POWERS=range(1,11))\n",
854 | "\n",
855 | "print(\"iSUN\")\n",
856 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n",
857 | "print(\"LSUN (R)\")\n",
858 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n",
859 | "print(\"LSUN (C)\")\n",
860 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n",
861 | "print(\"TinyImgNet (R)\")\n",
862 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n",
863 | "print(\"TinyImgNet (C)\")\n",
864 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n",
865 | "print(\"CIFAR-10\")\n",
866 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))"
867 | ]
868 | }
869 | ],
870 | "metadata": {
871 | "kernelspec": {
872 | "display_name": "Python 2",
873 | "language": "python",
874 | "name": "python2"
875 | },
876 | "language_info": {
877 | "codemirror_mode": {
878 | "name": "ipython",
879 | "version": 3
880 | },
881 | "file_extension": ".py",
882 | "mimetype": "text/x-python",
883 | "name": "python",
884 | "nbconvert_exporter": "python",
885 | "pygments_lexer": "ipython3",
886 | "version": "3.6.9"
887 | }
888 | },
889 | "nbformat": 4,
890 | "nbformat_minor": 2
891 | }
892 |
--------------------------------------------------------------------------------
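The detection logic in the notebook above hinges on the G_p statistics: each feature map is raised to a power p, channel-wise Gram sums are formed, and a signed p-th root is taken, giving one vector of statistics per example whose per-class, per-layer min/max ranges are recorded on training data; test examples are scored by how far they fall outside those ranges. A self-contained restatement of just that computation, on dummy feature maps outside the notebook:

```python
# Hedged, standalone restatement of the notebook's G_p and deviation computation.
import torch
import torch.nn.functional as F

def G_p(ob, p):
    temp = ob.detach() ** p
    temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
    temp = torch.matmul(temp, temp.transpose(dim0=2, dim1=1)).sum(dim=2)
    temp = (temp.sign() * torch.abs(temp) ** (1.0 / p)).reshape(temp.shape[0], -1)
    return temp

train_feat = torch.randn(16, 64, 8, 8)     # dummy "training" feature maps
test_feat = torch.randn(4, 64, 8, 8)       # dummy "test" feature maps

g_train = G_p(train_feat, p=2)
mins = g_train.min(dim=0, keepdim=True)[0]  # per-feature lower bounds (as in get_min_max)
maxs = g_train.max(dim=0, keepdim=True)[0]  # per-feature upper bounds

g_test = G_p(test_feat, p=2)
dev = (F.relu(mins - g_test) / torch.abs(mins + 1e-6)).sum(dim=1, keepdim=True) \
    + (F.relu(g_test - maxs) / torch.abs(maxs + 1e-6)).sum(dim=1, keepdim=True)
print(g_test.shape, dev.shape)              # torch.Size([4, 64]) torch.Size([4, 1])
```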
/ResNet_Cifar10.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "ResNet: Cifar10
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from __future__ import division,print_function\n",
24 | "\n",
25 | "%matplotlib inline\n",
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import sys\n",
30 | "from tqdm import tqdm_notebook as tqdm\n",
31 | "\n",
32 | "import random\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import math\n",
35 | "\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | "import torch.nn.functional as F\n",
41 | "import torch.optim as optim\n",
42 | "import torch.nn.init as init\n",
43 | "from torch.autograd import Variable, grad\n",
44 | "from torchvision import datasets, transforms\n",
45 | "from torch.nn.parameter import Parameter\n",
46 | "\n",
47 | "import calculate_log as callog\n",
48 | "\n",
49 | "import warnings\n",
50 | "warnings.filterwarnings('ignore')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "torch.cuda.set_device(1) #Select the GPU"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Model definition"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "Done\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "def conv3x3(in_planes, out_planes, stride=1):\n",
84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
85 | "\n",
86 | "class BasicBlock(nn.Module):\n",
87 | " expansion = 1\n",
88 | "\n",
89 | " def __init__(self, in_planes, planes, stride=1):\n",
90 | " super(BasicBlock, self).__init__()\n",
91 | " self.conv1 = conv3x3(in_planes, planes, stride)\n",
92 | " self.bn1 = nn.BatchNorm2d(planes)\n",
93 | " self.conv2 = conv3x3(planes, planes)\n",
94 | " self.bn2 = nn.BatchNorm2d(planes)\n",
95 | "\n",
96 | " self.shortcut = nn.Sequential()\n",
97 | " if stride != 1 or in_planes != self.expansion*planes:\n",
98 | " self.shortcut = nn.Sequential(\n",
99 | " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n",
100 | " nn.BatchNorm2d(self.expansion*planes)\n",
101 | " )\n",
102 | " \n",
103 | " def forward(self, x):\n",
104 | " t = self.conv1(x)\n",
105 | " out = F.relu(self.bn1(t))\n",
106 | " torch_model.record(t)\n",
107 | " torch_model.record(out)\n",
108 | " t = self.conv2(out)\n",
109 | " out = self.bn2(self.conv2(out))\n",
110 | " torch_model.record(t)\n",
111 | " torch_model.record(out)\n",
112 | " t = self.shortcut(x)\n",
113 | " out += t\n",
114 | " torch_model.record(t)\n",
115 | " out = F.relu(out)\n",
116 | " torch_model.record(out)\n",
117 | " \n",
118 | " return out\n",
119 | "\n",
120 | "class ResNet(nn.Module):\n",
121 | " def __init__(self, block, num_blocks, num_classes=10):\n",
122 | " super(ResNet, self).__init__()\n",
123 | " self.in_planes = 64\n",
124 | "\n",
125 | " self.conv1 = conv3x3(3,64)\n",
126 | " self.bn1 = nn.BatchNorm2d(64)\n",
127 | " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
128 | " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
129 | " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
130 | " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
131 | " self.linear = nn.Linear(512*block.expansion, num_classes)\n",
132 | " \n",
133 | " self.collecting = False\n",
134 | " \n",
135 | " def _make_layer(self, block, planes, num_blocks, stride):\n",
136 | " strides = [stride] + [1]*(num_blocks-1)\n",
137 | " layers = []\n",
138 | " for stride in strides:\n",
139 | " layers.append(block(self.in_planes, planes, stride))\n",
140 | " self.in_planes = planes * block.expansion\n",
141 | " return nn.Sequential(*layers)\n",
142 | " \n",
143 | " def forward(self, x):\n",
144 | " out = F.relu(self.bn1(self.conv1(x)))\n",
145 | " out = self.layer1(out)\n",
146 | " out = self.layer2(out)\n",
147 | " out = self.layer3(out)\n",
148 | " out = self.layer4(out)\n",
149 | " out = F.avg_pool2d(out, 4)\n",
150 | " out = out.view(out.size(0), -1)\n",
151 | " y = self.linear(out)\n",
152 | " return y\n",
153 | " \n",
154 | " def record(self, t):\n",
155 | " if self.collecting:\n",
156 | " self.gram_feats.append(t)\n",
157 | " \n",
158 | " def gram_feature_list(self,x):\n",
159 | " self.collecting = True\n",
160 | " self.gram_feats = []\n",
161 | " self.forward(x)\n",
162 | " self.collecting = False\n",
163 | " temp = self.gram_feats\n",
164 | " self.gram_feats = []\n",
165 | " return temp\n",
166 | " \n",
167 | " def load(self, path=\"resnet_cifar10.pth\"):\n",
168 | " tm = torch.load(path,map_location=\"cpu\") \n",
169 | " self.load_state_dict(tm)\n",
170 | " \n",
171 | " def get_min_max(self, data, power):\n",
172 | " mins = []\n",
173 | " maxs = []\n",
174 | " \n",
175 | " for i in range(0,len(data),128):\n",
176 | " batch = data[i:i+128].cuda()\n",
177 | " feat_list = self.gram_feature_list(batch)\n",
178 | " for L,feat_L in enumerate(feat_list):\n",
179 | " if L==len(mins):\n",
180 | " mins.append([None]*len(power))\n",
181 | " maxs.append([None]*len(power))\n",
182 | " \n",
183 | " for p,P in enumerate(power):\n",
184 | " g_p = G_p(feat_L,P)\n",
185 | " \n",
186 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n",
187 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n",
188 | " \n",
189 | " if mins[L][p] is None:\n",
190 | " mins[L][p] = current_min\n",
191 | " maxs[L][p] = current_max\n",
192 | " else:\n",
193 | " mins[L][p] = torch.min(current_min,mins[L][p])\n",
194 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n",
195 | " \n",
196 | " return mins,maxs\n",
197 | " \n",
198 | " def get_deviations(self,data,power,mins,maxs):\n",
199 | " deviations = []\n",
200 | " \n",
201 | " for i in range(0,len(data),128): \n",
202 | " batch = data[i:i+128].cuda()\n",
203 | " feat_list = self.gram_feature_list(batch)\n",
204 | " batch_deviations = []\n",
205 | " for L,feat_L in enumerate(feat_list):\n",
206 | " dev = 0\n",
207 | " for p,P in enumerate(power):\n",
208 | " g_p = G_p(feat_L,P)\n",
209 | " \n",
210 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
211 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
212 | " batch_deviations.append(dev.cpu().detach().numpy())\n",
213 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n",
214 | " deviations.append(batch_deviations)\n",
215 | " deviations = np.concatenate(deviations,axis=0)\n",
216 | " \n",
217 | " return deviations\n",
218 | "\n",
219 | "\n",
220 | "torch_model = ResNet(BasicBlock, [3,4,6,3], num_classes=10)\n",
221 | "torch_model.load()\n",
222 | "torch_model.cuda()\n",
223 | "torch_model.params = list(torch_model.parameters())\n",
224 | "torch_model.eval()\n",
225 | "print(\"Done\") "
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "## Datasets"
233 | ]
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "metadata": {},
238 | "source": [
239 | "In-distribution Datasets"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": 4,
245 | "metadata": {},
246 | "outputs": [
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "Files already downloaded and verified\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "batch_size = 128\n",
257 | "mean = np.array([[0.4914, 0.4822, 0.4465]]).T\n",
258 | "\n",
259 | "std = np.array([[0.2023, 0.1994, 0.2010]]).T\n",
260 | "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
261 | "\n",
262 | "transform_train = transforms.Compose([\n",
263 | " transforms.RandomCrop(32, padding=4),\n",
264 | " transforms.RandomHorizontalFlip(),\n",
265 | " transforms.ToTensor(),\n",
266 | " normalize\n",
267 | " \n",
268 | " ])\n",
269 | "transform_test = transforms.Compose([\n",
270 | " transforms.CenterCrop(size=(32, 32)),\n",
271 | " transforms.ToTensor(),\n",
272 | " normalize\n",
273 | " ])\n",
274 | "\n",
275 | "train_loader = torch.utils.data.DataLoader(\n",
276 | " datasets.CIFAR10('data', train=True, download=True,\n",
277 | " transform=transform_train),\n",
278 | " batch_size=batch_size, shuffle=True)\n",
279 | "test_loader = torch.utils.data.DataLoader(\n",
280 | " datasets.CIFAR10('data', train=False, transform=transform_test),\n",
281 | " batch_size=batch_size)\n"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 5,
287 | "metadata": {
288 | "scrolled": true
289 | },
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "Files already downloaded and verified\n"
296 | ]
297 | }
298 | ],
299 | "source": [
300 | "data_train = list(torch.utils.data.DataLoader(\n",
301 | " datasets.CIFAR10('data', train=True, download=True,\n",
302 | " transform=transform_test),\n",
303 | " batch_size=1, shuffle=False))"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": 6,
309 | "metadata": {},
310 | "outputs": [
311 | {
312 | "name": "stdout",
313 | "output_type": "stream",
314 | "text": [
315 | "Files already downloaded and verified\n"
316 | ]
317 | }
318 | ],
319 | "source": [
320 | "data = list(torch.utils.data.DataLoader(\n",
321 | " datasets.CIFAR10('data', train=False, download=True,\n",
322 | " transform=transform_test),\n",
323 | " batch_size=1, shuffle=False))"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 7,
329 | "metadata": {},
330 | "outputs": [
331 | {
332 | "name": "stdout",
333 | "output_type": "stream",
334 | "text": [
335 | "Accuracy: 0.9367\n"
336 | ]
337 | }
338 | ],
339 | "source": [
340 | "torch_model.eval()\n",
341 | "correct = 0\n",
342 | "total = 0\n",
343 | "for x,y in test_loader:\n",
344 | " x = x.cuda()\n",
345 | " y = y.numpy()\n",
346 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n",
347 | " total += y.shape[0]\n",
348 | "print(\"Accuracy: \",correct/total)"
349 | ]
350 | },
351 | {
352 | "cell_type": "markdown",
353 | "metadata": {},
354 | "source": [
355 | "Out-of-distribution Datasets"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": 8,
361 | "metadata": {},
362 | "outputs": [
363 | {
364 | "name": "stdout",
365 | "output_type": "stream",
366 | "text": [
367 | "Files already downloaded and verified\n"
368 | ]
369 | }
370 | ],
371 | "source": [
372 | "cifar100 = list(torch.utils.data.DataLoader(\n",
373 | " datasets.CIFAR100('data', train=False, download=True,\n",
374 | " transform=transform_test),\n",
375 | " batch_size=1, shuffle=True))"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": 9,
381 | "metadata": {},
382 | "outputs": [
383 | {
384 | "name": "stdout",
385 | "output_type": "stream",
386 | "text": [
387 | "Using downloaded and verified file: data/test_32x32.mat\n"
388 | ]
389 | }
390 | ],
391 | "source": [
392 | "svhn = list(torch.utils.data.DataLoader(\n",
393 | " datasets.SVHN('data', split=\"test\", download=True,\n",
394 | " transform=transform_test),\n",
395 | " batch_size=1, shuffle=True))"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 10,
401 | "metadata": {},
402 | "outputs": [],
403 | "source": [
404 | "isun = list(torch.utils.data.DataLoader(\n",
405 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": 11,
411 | "metadata": {},
412 | "outputs": [],
413 | "source": [
414 | "lsun_c = list(torch.utils.data.DataLoader(\n",
415 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))"
416 | ]
417 | },
418 | {
419 | "cell_type": "code",
420 | "execution_count": 12,
421 | "metadata": {},
422 | "outputs": [],
423 | "source": [
424 | "lsun_r = list(torch.utils.data.DataLoader(\n",
425 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": 13,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": [
434 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n",
435 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": 14,
441 | "metadata": {},
442 | "outputs": [],
443 | "source": [
444 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n",
445 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
446 | ]
447 | },
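The six out-of-distribution DataLoader cells above differ only in the directory they read; the ImageFolder ones can be expressed once with a small helper. A sketch under the assumption that transform_test is defined as above and the directories were created by setup.sh; load_ood_folder is a name introduced here for illustration:

import torch
from torchvision import datasets

def load_ood_folder(root, transform, shuffle=True):
    # Materialise an ImageFolder as a list of (image, label) pairs with batch size 1,
    # matching how the other datasets are stored in this notebook.
    return list(torch.utils.data.DataLoader(
        datasets.ImageFolder(root, transform=transform),
        batch_size=1, shuffle=shuffle))

# e.g. lsun_r = load_ood_folder("LSUN_resize/", transform_test)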
448 | {
449 | "cell_type": "markdown",
450 | "metadata": {},
451 | "source": [
452 | "## Code for Detecting OODs"
453 | ]
454 | },
455 | {
456 | "cell_type": "markdown",
457 | "metadata": {},
458 | "source": [
459 | " Extract predictions for train and test data "
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 15,
465 | "metadata": {},
466 | "outputs": [
467 | {
468 | "name": "stdout",
469 | "output_type": "stream",
470 | "text": [
471 | "Done\n",
472 | "Done\n"
473 | ]
474 | }
475 | ],
476 | "source": [
477 | "train_preds = []\n",
478 | "train_confs = []\n",
479 | "train_logits = []\n",
480 | "for idx in range(0,len(data_train),128):\n",
481 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n",
482 | " \n",
483 | " logits = torch_model(batch)\n",
484 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
485 | " preds = np.argmax(confs,axis=1)\n",
486 | " logits = (logits.cpu().detach().numpy())\n",
487 | "\n",
488 | " train_confs.extend(np.max(confs,axis=1)) \n",
489 | " train_preds.extend(preds)\n",
490 | " train_logits.extend(logits)\n",
491 | "print(\"Done\")\n",
492 | "\n",
493 | "test_preds = []\n",
494 | "test_confs = []\n",
495 | "test_logits = []\n",
496 | "\n",
497 | "for idx in range(0,len(data),128):\n",
498 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n",
499 | " \n",
500 | " logits = torch_model(batch)\n",
501 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
502 | " preds = np.argmax(confs,axis=1)\n",
503 | " logits = (logits.cpu().detach().numpy())\n",
504 | "\n",
505 | " test_confs.extend(np.max(confs,axis=1)) \n",
506 | " test_preds.extend(preds)\n",
507 | " test_logits.extend(logits)\n",
508 | "print(\"Done\")"
509 | ]
510 | },
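The train and test loops in the cell above are identical apart from the list they iterate over; the sketch below factors them into one function. extract_preds is a name introduced here for illustration, and it assumes torch_model plus the list-of-(image, label) format produced by the batch_size=1 DataLoaders:

import numpy as np
import torch
import torch.nn.functional as F

def extract_preds(model, samples, batch_size=128):
    preds, confs, logits_all = [], [], []
    with torch.no_grad():
        for idx in range(0, len(samples), batch_size):
            # Stack the stored single-image tensors back into a batch.
            batch = torch.squeeze(
                torch.stack([x[0] for x in samples[idx:idx + batch_size]]), dim=1).cuda()
            logits = model(batch)
            conf = F.softmax(logits, dim=1).cpu().numpy()
            preds.extend(np.argmax(conf, axis=1))
            confs.extend(np.max(conf, axis=1))
            logits_all.extend(logits.cpu().numpy())
    return preds, confs, logits_all

# Hypothetical usage with the notebook's objects:
# train_preds, train_confs, train_logits = extract_preds(torch_model, data_train)
# test_preds, test_confs, test_logits = extract_preds(torch_model, data)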
511 | {
512 | "cell_type": "markdown",
513 | "metadata": {},
514 | "source": [
515 | " Code for detecting OODs by identifying anomalies in correlations "
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": 16,
521 | "metadata": {},
522 | "outputs": [],
523 | "source": [
524 | "import calculate_log as callog\n",
525 | "\n",
526 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n",
527 | " average_results = {}\n",
528 | " for i in range(1,11):\n",
529 | " random.seed(i)\n",
530 | " \n",
531 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n",
532 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n",
533 | "\n",
534 | " validation = all_test_deviations[validation_indices]\n",
535 | " test_deviations = all_test_deviations[test_indices]\n",
536 | "\n",
537 | " t95 = validation.mean(axis=0)+10**-7\n",
538 | " if not normalize:\n",
539 | " t95 = np.ones_like(t95)\n",
540 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
541 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
542 | " \n",
543 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n",
544 | " for m in results:\n",
545 | " average_results[m] = average_results.get(m,0)+results[m]\n",
546 | " \n",
547 | " for m in average_results:\n",
548 | " average_results[m] /= i\n",
549 | " if verbose:\n",
550 | " callog.print_results(average_results)\n",
551 | " return average_results\n",
552 | "\n",
553 | "def cpu(ob):\n",
554 | " for i in range(len(ob)):\n",
555 | " for j in range(len(ob[i])):\n",
556 | " ob[i][j] = ob[i][j].cpu()\n",
557 | " return ob\n",
558 | "\n",
559 | "def cuda(ob):\n",
560 | " for i in range(len(ob)):\n",
561 | " for j in range(len(ob[i])):\n",
562 | " ob[i][j] = ob[i][j].cuda()\n",
563 | " return ob\n",
564 | "\n",
565 | "class Detector:\n",
566 | " def __init__(self):\n",
567 | " self.all_test_deviations = None\n",
568 | " self.mins = {}\n",
569 | " self.maxs = {}\n",
570 | " \n",
571 | " self.classes = range(10)\n",
572 | " \n",
573 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n",
574 | " for PRED in tqdm(self.classes):\n",
575 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n",
576 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n",
577 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n",
578 | " self.mins[PRED] = cpu(mins)\n",
579 | " self.maxs[PRED] = cpu(maxs)\n",
580 | " torch.cuda.empty_cache()\n",
581 | " \n",
582 | " def compute_test_deviations(self,POWERS=[10]):\n",
583 | " all_test_deviations = None\n",
584 | " test_classes = []\n",
585 | " for PRED in tqdm(self.classes):\n",
586 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n",
587 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n",
588 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n",
589 | " \n",
590 | " test_classes.extend([PRED]*len(test_indices))\n",
591 | " \n",
592 | " mins = cuda(self.mins[PRED])\n",
593 | " maxs = cuda(self.maxs[PRED])\n",
594 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n",
595 | " cpu(mins)\n",
596 | " cpu(maxs)\n",
597 | " if all_test_deviations is None:\n",
598 | " all_test_deviations = test_deviations\n",
599 | " else:\n",
600 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n",
601 | " torch.cuda.empty_cache()\n",
602 | " self.all_test_deviations = all_test_deviations\n",
603 | " \n",
604 | " self.test_classes = np.array(test_classes)\n",
605 | " \n",
606 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n",
607 | " ood_preds = []\n",
608 | " ood_confs = []\n",
609 | " \n",
610 | " for idx in range(0,len(ood),128):\n",
611 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n",
612 | " logits = torch_model(batch)\n",
613 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
614 | " preds = np.argmax(confs,axis=1)\n",
615 | " \n",
616 | " ood_confs.extend(np.max(confs,axis=1))\n",
617 | " ood_preds.extend(preds) \n",
618 | " torch.cuda.empty_cache()\n",
619 | " print(\"Done\")\n",
620 | " \n",
621 | " ood_classes = []\n",
622 | " all_ood_deviations = None\n",
623 | " for PRED in tqdm(self.classes):\n",
624 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n",
625 | " if len(ood_indices)==0:\n",
626 | " continue\n",
627 | " ood_classes.extend([PRED]*len(ood_indices))\n",
628 | " \n",
629 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n",
630 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n",
631 | " mins = cuda(self.mins[PRED])\n",
632 | " maxs = cuda(self.maxs[PRED])\n",
633 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n",
634 | " cpu(self.mins[PRED])\n",
635 | " cpu(self.maxs[PRED]) \n",
636 | " if all_ood_deviations is None:\n",
637 | " all_ood_deviations = ood_deviations\n",
638 | " else:\n",
639 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n",
640 | " torch.cuda.empty_cache()\n",
641 | " \n",
642 | " self.ood_classes = np.array(ood_classes)\n",
643 | " \n",
644 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n",
645 | " return average_results, self.all_test_deviations, all_ood_deviations\n"
646 | ]
647 | },
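Inside detect() above, each column of the deviation matrix (one column per recorded layer) is divided by its mean over a random 10% validation split of the test deviations, and the columns are then summed into a single score per image; note that, despite its name, t95 holds that validation mean rather than a 95th percentile. A self-contained numpy sketch of the normalisation with made-up numbers, purely illustrative:

import numpy as np

rng = np.random.RandomState(0)
test_dev = rng.rand(1000, 4)        # toy: 1000 in-distribution images x 4 per-layer deviations
ood_dev = rng.rand(500, 4) + 0.5    # toy OOD deviations, shifted upward

val_idx = rng.choice(len(test_dev), 100, replace=False)
t95 = test_dev[val_idx].mean(axis=0) + 10**-7        # per-column scale, as computed in detect()

# Rescaling keeps any single layer from dominating; summing yields one score per image.
test_scores = (test_dev / t95[np.newaxis, :]).sum(axis=1)
ood_scores = (ood_dev / t95[np.newaxis, :]).sum(axis=1)
print(test_scores.mean(), ood_scores.mean())         # the OOD scores come out larger on average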
648 | {
649 | "cell_type": "markdown",
650 | "metadata": {},
651 | "source": [
652 | " Results
"
653 | ]
654 | },
655 | {
656 | "cell_type": "code",
657 | "execution_count": 17,
658 | "metadata": {},
659 | "outputs": [
660 | {
661 | "data": {
662 | "application/vnd.jupyter.widget-view+json": {
663 | "model_id": "c2ff5c67696a455a888f8fe4b19339c9",
664 | "version_major": 2,
665 | "version_minor": 0
666 | },
667 | "text/plain": [
668 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
669 | ]
670 | },
671 | "metadata": {},
672 | "output_type": "display_data"
673 | },
674 | {
675 | "name": "stdout",
676 | "output_type": "stream",
677 | "text": [
678 | "\n"
679 | ]
680 | },
681 | {
682 | "data": {
683 | "application/vnd.jupyter.widget-view+json": {
684 | "model_id": "e816c6c68303437a8a4fdc97091a8853",
685 | "version_major": 2,
686 | "version_minor": 0
687 | },
688 | "text/plain": [
689 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
690 | ]
691 | },
692 | "metadata": {},
693 | "output_type": "display_data"
694 | },
695 | {
696 | "name": "stdout",
697 | "output_type": "stream",
698 | "text": [
699 | "\n",
700 | " TNR AUROC DTACC AUIN AUOUT \n",
701 | " 99.257 99.831 98.077 99.827 99.829\n",
702 | "LSUN (R)\n",
703 | "Done\n"
704 | ]
705 | },
706 | {
707 | "data": {
708 | "application/vnd.jupyter.widget-view+json": {
709 | "model_id": "42829af13dc84e7fadccd7014a655912",
710 | "version_major": 2,
711 | "version_minor": 0
712 | },
713 | "text/plain": [
714 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
715 | ]
716 | },
717 | "metadata": {},
718 | "output_type": "display_data"
719 | },
720 | {
721 | "name": "stdout",
722 | "output_type": "stream",
723 | "text": [
724 | "\n",
725 | " TNR AUROC DTACC AUIN AUOUT \n",
726 | " 99.585 99.889 98.641 99.866 99.899\n",
727 | "LSUN (C)\n",
728 | "Done\n"
729 | ]
730 | },
731 | {
732 | "data": {
733 | "application/vnd.jupyter.widget-view+json": {
734 | "model_id": "da08da96632648d0b841cbc97fe61456",
735 | "version_major": 2,
736 | "version_minor": 0
737 | },
738 | "text/plain": [
739 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
740 | ]
741 | },
742 | "metadata": {},
743 | "output_type": "display_data"
744 | },
745 | {
746 | "name": "stdout",
747 | "output_type": "stream",
748 | "text": [
749 | "\n",
750 | " TNR AUROC DTACC AUIN AUOUT \n",
751 | " 89.798 97.796 92.591 97.433 98.172\n",
752 | "TinyImgNet (R)\n",
753 | "Done\n"
754 | ]
755 | },
756 | {
757 | "data": {
758 | "application/vnd.jupyter.widget-view+json": {
759 | "model_id": "430ab4eefa264916afbf916f35418a05",
760 | "version_major": 2,
761 | "version_minor": 0
762 | },
763 | "text/plain": [
764 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
765 | ]
766 | },
767 | "metadata": {},
768 | "output_type": "display_data"
769 | },
770 | {
771 | "name": "stdout",
772 | "output_type": "stream",
773 | "text": [
774 | "\n",
775 | " TNR AUROC DTACC AUIN AUOUT \n",
776 | " 98.746 99.717 97.797 99.648 99.765\n",
777 | "TinyImgNet (C)\n",
778 | "Done\n"
779 | ]
780 | },
781 | {
782 | "data": {
783 | "application/vnd.jupyter.widget-view+json": {
784 | "model_id": "4168c5ab119d4577a35895922e4d8430",
785 | "version_major": 2,
786 | "version_minor": 0
787 | },
788 | "text/plain": [
789 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
790 | ]
791 | },
792 | "metadata": {},
793 | "output_type": "display_data"
794 | },
795 | {
796 | "name": "stdout",
797 | "output_type": "stream",
798 | "text": [
799 | "\n",
800 | " TNR AUROC DTACC AUIN AUOUT \n",
801 | " 96.666 99.242 96.074 99.082 99.377\n",
802 | "SVHN\n",
803 | "Done\n"
804 | ]
805 | },
806 | {
807 | "data": {
808 | "application/vnd.jupyter.widget-view+json": {
809 | "model_id": "b91cf4a3135f49aaa74884b0d4149131",
810 | "version_major": 2,
811 | "version_minor": 0
812 | },
813 | "text/plain": [
814 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
815 | ]
816 | },
817 | "metadata": {},
818 | "output_type": "display_data"
819 | },
820 | {
821 | "name": "stdout",
822 | "output_type": "stream",
823 | "text": [
824 | "\n",
825 | " TNR AUROC DTACC AUIN AUOUT \n",
826 | " 97.614 99.502 96.708 98.446 99.831\n",
827 | "CIFAR-100\n",
828 | "Done\n"
829 | ]
830 | },
831 | {
832 | "data": {
833 | "application/vnd.jupyter.widget-view+json": {
834 | "model_id": "e51aaecf0a214053a20b00b2e58b1e6e",
835 | "version_major": 2,
836 | "version_minor": 0
837 | },
838 | "text/plain": [
839 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
840 | ]
841 | },
842 | "metadata": {},
843 | "output_type": "display_data"
844 | },
845 | {
846 | "name": "stdout",
847 | "output_type": "stream",
848 | "text": [
849 | "\n",
850 | " TNR AUROC DTACC AUIN AUOUT \n",
851 | " 32.896 79.015 71.711 74.991 80.666\n"
852 | ]
853 | }
854 | ],
855 | "source": [
856 | "def G_p(ob, p):\n",
857 | " temp = ob.detach()\n",
858 | " \n",
859 | " temp = temp**p\n",
860 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n",
861 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n",
862 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n",
863 | " \n",
864 | " return temp\n",
865 | "\n",
866 | "\n",
867 | "detector = Detector()\n",
868 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n",
869 | "\n",
870 | "detector.compute_test_deviations(POWERS=range(1,11))\n",
871 | "\n",
872 | "print(\"iSUN\")\n",
873 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n",
874 | "print(\"LSUN (R)\")\n",
875 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n",
876 | "print(\"LSUN (C)\")\n",
877 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n",
878 | "print(\"TinyImgNet (R)\")\n",
879 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n",
880 | "print(\"TinyImgNet (C)\")\n",
881 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n",
882 | "print(\"SVHN\")\n",
883 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n",
884 | "print(\"CIFAR-100\")\n",
885 | "c100_results = detector.compute_ood_deviations(cifar100,POWERS=range(1,11))"
886 | ]
887 | }
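G_p in the cell above turns a (B, C, H, W) feature map into one descriptor per channel: the activations are raised elementwise to the power p, flattened to (B, C, H*W), multiplied by their transpose to form a (B, C, C) Gram matrix, summed over one channel axis, and mapped back through a signed p-th root, giving a (B, C) output. A quick shape check, reusing the G_p defined in the cell above on a toy tensor (values are random and purely illustrative):

import torch

feat = torch.randn(8, 64, 16, 16)   # (batch, channels, height, width)
print(G_p(feat, 1).shape)           # torch.Size([8, 64]): one Gram row-sum per channel
print(G_p(feat, 4).shape)           # higher powers keep the same output shape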
888 | ],
889 | "metadata": {
890 | "kernelspec": {
891 | "display_name": "Python 3",
892 | "language": "python",
893 | "name": "python3"
894 | },
895 | "language_info": {
896 | "codemirror_mode": {
897 | "name": "ipython",
898 | "version": 3
899 | },
900 | "file_extension": ".py",
901 | "mimetype": "text/x-python",
902 | "name": "python",
903 | "nbconvert_exporter": "python",
904 | "pygments_lexer": "ipython3",
905 | "version": "3.6.9"
906 | }
907 | },
908 | "nbformat": 4,
909 | "nbformat_minor": 2
910 | }
911 |
--------------------------------------------------------------------------------
/ResNet_Cifar100.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "ResNet: Cifar100
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from __future__ import division,print_function\n",
24 | "\n",
25 | "%matplotlib inline\n",
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import sys\n",
30 | "from tqdm import tqdm_notebook as tqdm\n",
31 | "\n",
32 | "import random\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import math\n",
35 | "\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | "import torch.nn.functional as F\n",
41 | "import torch.optim as optim\n",
42 | "import torch.nn.init as init\n",
43 | "from torch.autograd import Variable, grad\n",
44 | "from torchvision import datasets, transforms\n",
45 | "from torch.nn.parameter import Parameter\n",
46 | "\n",
47 | "import calculate_log as callog\n",
48 | "\n",
49 | "import warnings\n",
50 | "warnings.filterwarnings('ignore')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "torch.cuda.set_device(1) #Select the GPU"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Model definition"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "Done\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "def conv3x3(in_planes, out_planes, stride=1):\n",
84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
85 | "\n",
86 | "class BasicBlock(nn.Module):\n",
87 | " expansion = 1\n",
88 | "\n",
89 | " def __init__(self, in_planes, planes, stride=1):\n",
90 | " super(BasicBlock, self).__init__()\n",
91 | " self.conv1 = conv3x3(in_planes, planes, stride)\n",
92 | " self.bn1 = nn.BatchNorm2d(planes)\n",
93 | " self.conv2 = conv3x3(planes, planes)\n",
94 | " self.bn2 = nn.BatchNorm2d(planes)\n",
95 | "\n",
96 | " self.shortcut = nn.Sequential()\n",
97 | " if stride != 1 or in_planes != self.expansion*planes:\n",
98 | " self.shortcut = nn.Sequential(\n",
99 | " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n",
100 | " nn.BatchNorm2d(self.expansion*planes)\n",
101 | " )\n",
102 | " \n",
103 | " def forward(self, x):\n",
104 | " t = self.conv1(x)\n",
105 | " out = F.relu(self.bn1(t))\n",
106 | " torch_model.record(t)\n",
107 | " torch_model.record(out)\n",
108 | " t = self.conv2(out)\n",
109 |         out = self.bn2(t)  # t already holds conv2(out); reuse it instead of recomputing\n",
110 | " torch_model.record(t)\n",
111 | " torch_model.record(out)\n",
112 | " t = self.shortcut(x)\n",
113 | " out += t\n",
114 | " torch_model.record(t)\n",
115 | " out = F.relu(out)\n",
116 | " torch_model.record(out)\n",
117 | " \n",
118 | " return out\n",
119 | "\n",
120 | "class ResNet(nn.Module):\n",
121 | " def __init__(self, block, num_blocks, num_classes=10):\n",
122 | " super(ResNet, self).__init__()\n",
123 | " self.in_planes = 64\n",
124 | "\n",
125 | " self.conv1 = conv3x3(3,64)\n",
126 | " self.bn1 = nn.BatchNorm2d(64)\n",
127 | " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
128 | " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
129 | " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
130 | " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
131 | " self.linear = nn.Linear(512*block.expansion, num_classes)\n",
132 | " \n",
133 | " self.collecting = False\n",
134 | " \n",
135 | " def _make_layer(self, block, planes, num_blocks, stride):\n",
136 | " strides = [stride] + [1]*(num_blocks-1)\n",
137 | " layers = []\n",
138 | " for stride in strides:\n",
139 | " layers.append(block(self.in_planes, planes, stride))\n",
140 | " self.in_planes = planes * block.expansion\n",
141 | " return nn.Sequential(*layers)\n",
142 | " \n",
143 | " def forward(self, x):\n",
144 | " out = F.relu(self.bn1(self.conv1(x)))\n",
145 | " out = self.layer1(out)\n",
146 | " out = self.layer2(out)\n",
147 | " out = self.layer3(out)\n",
148 | " out = self.layer4(out)\n",
149 | " out = F.avg_pool2d(out, 4)\n",
150 | " out = out.view(out.size(0), -1)\n",
151 | " y = self.linear(out)\n",
152 | " return y\n",
153 | " \n",
154 | " def record(self, t):\n",
155 | " if self.collecting:\n",
156 | " self.gram_feats.append(t)\n",
157 | " \n",
158 | " def gram_feature_list(self,x):\n",
159 | " self.collecting = True\n",
160 | " self.gram_feats = []\n",
161 | " self.forward(x)\n",
162 | " self.collecting = False\n",
163 | " temp = self.gram_feats\n",
164 | " self.gram_feats = []\n",
165 | " return temp\n",
166 | " \n",
167 | " def load(self, path=\"resnet_cifar100.pth\"):\n",
168 | " tm = torch.load(path,map_location=\"cpu\") \n",
169 | " self.load_state_dict(tm)\n",
170 | " \n",
171 | " def get_min_max(self, data, power):\n",
172 | " mins = []\n",
173 | " maxs = []\n",
174 | " \n",
175 | " for i in range(0,len(data),128):\n",
176 | " batch = data[i:i+128].cuda()\n",
177 | " feat_list = self.gram_feature_list(batch)\n",
178 | " for L,feat_L in enumerate(feat_list):\n",
179 | " if L==len(mins):\n",
180 | " mins.append([None]*len(power))\n",
181 | " maxs.append([None]*len(power))\n",
182 | " \n",
183 | " for p,P in enumerate(power):\n",
184 | " g_p = G_p(feat_L,P)\n",
185 | " \n",
186 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n",
187 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n",
188 | " \n",
189 | " if mins[L][p] is None:\n",
190 | " mins[L][p] = current_min\n",
191 | " maxs[L][p] = current_max\n",
192 | " else:\n",
193 | " mins[L][p] = torch.min(current_min,mins[L][p])\n",
194 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n",
195 | " \n",
196 | " return mins,maxs\n",
197 | " \n",
198 | " def get_deviations(self,data,power,mins,maxs):\n",
199 | " deviations = []\n",
200 | " \n",
201 | " for i in range(0,len(data),128): \n",
202 | " batch = data[i:i+128].cuda()\n",
203 | " feat_list = self.gram_feature_list(batch)\n",
204 | " batch_deviations = []\n",
205 | " for L,feat_L in enumerate(feat_list):\n",
206 | " dev = 0\n",
207 | " for p,P in enumerate(power):\n",
208 | " g_p = G_p(feat_L,P)\n",
209 | " \n",
210 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
211 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
212 | " batch_deviations.append(dev.cpu().detach().numpy())\n",
213 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n",
214 | " deviations.append(batch_deviations)\n",
215 | " deviations = np.concatenate(deviations,axis=0)\n",
216 | " \n",
217 | " return deviations\n",
218 | "\n",
219 | "\n",
220 | "torch_model = ResNet(BasicBlock, [3,4,6,3], num_classes=100)\n",
221 | "torch_model.load()\n",
222 | "torch_model.cuda()\n",
223 | "torch_model.params = list(torch_model.parameters())\n",
224 | "torch_model.eval()\n",
225 | "print(\"Done\") "
226 | ]
227 | },
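get_deviations above only penalises features that leave the per-class [min, max] range recorded by get_min_max: the F.relu terms are zero for in-range values, and out-of-range values contribute their distance from the violated bound, scaled by that bound's magnitude. A tiny numeric sketch of the per-element rule (toy values, purely illustrative):

import torch
import torch.nn.functional as F

mins = torch.tensor([[0.0, 1.0, 2.0]])    # toy per-feature minima from training data
maxs = torch.tensor([[1.0, 2.0, 3.0]])    # toy per-feature maxima from training data
g_p  = torch.tensor([[0.5, 2.5, -1.0]])   # a test feature: in range, above max, below min

eps = 10 ** -6
dev  = (F.relu(mins - g_p) / torch.abs(mins + eps)).sum(dim=1, keepdim=True)
dev += (F.relu(g_p - maxs) / torch.abs(maxs + eps)).sum(dim=1, keepdim=True)
print(dev)   # 1.75 = 0 (in range) + 0.5/2 (above max) + 3/2 (below min)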
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "## Datasets"
233 | ]
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "metadata": {},
238 | "source": [
239 | "In-distribution Datasets"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": 4,
245 | "metadata": {},
246 | "outputs": [
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "Files already downloaded and verified\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "batch_size = 128\n",
257 | "mean = np.array([[0.4914, 0.4822, 0.4465]]).T\n",
258 | "\n",
259 | "std = np.array([[0.2023, 0.1994, 0.2010]]).T\n",
260 | "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
261 | "\n",
262 | "transform_train = transforms.Compose([\n",
263 | " transforms.RandomCrop(32, padding=4),\n",
264 | " transforms.RandomHorizontalFlip(),\n",
265 | " transforms.ToTensor(),\n",
266 | " normalize\n",
267 | " \n",
268 | " ])\n",
269 | "\n",
270 | "transform_test = transforms.Compose([\n",
271 | " transforms.CenterCrop(size=(32, 32)),\n",
272 | " transforms.ToTensor(),\n",
273 | " normalize\n",
274 | " ])\n",
275 | "\n",
276 | "train_loader = torch.utils.data.DataLoader(\n",
277 | " datasets.CIFAR100('data', train=True, download=True,\n",
278 | " transform=transform_train),\n",
279 | " batch_size=batch_size, shuffle=True)\n",
280 | "test_loader = torch.utils.data.DataLoader(\n",
281 | " datasets.CIFAR100('data', train=False, transform=transform_test),\n",
282 | " batch_size=batch_size)\n"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 5,
288 | "metadata": {
289 | "scrolled": true
290 | },
291 | "outputs": [
292 | {
293 | "name": "stdout",
294 | "output_type": "stream",
295 | "text": [
296 | "Files already downloaded and verified\n"
297 | ]
298 | }
299 | ],
300 | "source": [
301 | "data_train = list(torch.utils.data.DataLoader(\n",
302 | " datasets.CIFAR100('data', train=True, download=True,\n",
303 | " transform=transform_test),\n",
304 | " batch_size=1, shuffle=False))"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 6,
310 | "metadata": {
311 | "scrolled": true
312 | },
313 | "outputs": [
314 | {
315 | "name": "stdout",
316 | "output_type": "stream",
317 | "text": [
318 | "Files already downloaded and verified\n"
319 | ]
320 | }
321 | ],
322 | "source": [
323 | "data = list(torch.utils.data.DataLoader(\n",
324 | " datasets.CIFAR100('data', train=False, download=True,\n",
325 | " transform=transform_test),\n",
326 | " batch_size=1, shuffle=False))"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 7,
332 | "metadata": {},
333 | "outputs": [
334 | {
335 | "name": "stdout",
336 | "output_type": "stream",
337 | "text": [
338 | "Accuracy: 0.7834\n"
339 | ]
340 | }
341 | ],
342 | "source": [
343 | "torch_model.eval()\n",
344 | "correct = 0\n",
345 | "total = 0\n",
346 | "for x,y in test_loader:\n",
347 | " x = x.cuda()\n",
348 | " y = y.numpy()\n",
349 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n",
350 | " total += y.shape[0]\n",
351 | "print(\"Accuracy: \",correct/total)\n"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "Out-of-distribution Datasets"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 8,
364 | "metadata": {
365 | "scrolled": true
366 | },
367 | "outputs": [
368 | {
369 | "name": "stdout",
370 | "output_type": "stream",
371 | "text": [
372 | "Files already downloaded and verified\n"
373 | ]
374 | }
375 | ],
376 | "source": [
377 | "cifar10 = list(torch.utils.data.DataLoader(\n",
378 | " datasets.CIFAR10('data', train=False, download=True,\n",
379 | " transform=transform_test),\n",
380 | " batch_size=1, shuffle=True))"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": 9,
386 | "metadata": {},
387 | "outputs": [
388 | {
389 | "name": "stdout",
390 | "output_type": "stream",
391 | "text": [
392 | "Using downloaded and verified file: data/test_32x32.mat\n"
393 | ]
394 | }
395 | ],
396 | "source": [
397 | "svhn = list(torch.utils.data.DataLoader(\n",
398 | " datasets.SVHN('data', split=\"test\", download=True,\n",
399 | " transform=transform_test),\n",
400 | " batch_size=1, shuffle=True))"
401 | ]
402 | },
403 | {
404 | "cell_type": "code",
405 | "execution_count": 10,
406 | "metadata": {},
407 | "outputs": [],
408 | "source": [
409 | "isun = list(torch.utils.data.DataLoader(\n",
410 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": 11,
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "lsun_c = list(torch.utils.data.DataLoader(\n",
420 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": 12,
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "lsun_r = list(torch.utils.data.DataLoader(\n",
430 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": 13,
436 | "metadata": {},
437 | "outputs": [],
438 | "source": [
439 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n",
440 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": 14,
446 | "metadata": {},
447 | "outputs": [],
448 | "source": [
449 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n",
450 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
451 | ]
452 | },
453 | {
454 | "cell_type": "markdown",
455 | "metadata": {},
456 | "source": [
457 | "## Code for Detecting OODs"
458 | ]
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {},
463 | "source": [
464 | " Extract predictions for train and test data "
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 15,
470 | "metadata": {},
471 | "outputs": [
472 | {
473 | "name": "stdout",
474 | "output_type": "stream",
475 | "text": [
476 | "Done\n",
477 | "Done\n"
478 | ]
479 | }
480 | ],
481 | "source": [
482 | "train_preds = []\n",
483 | "train_confs = []\n",
484 | "train_logits = []\n",
485 | "for idx in range(0,len(data_train),128):\n",
486 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n",
487 | " \n",
488 | " logits = torch_model(batch)\n",
489 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
490 | " preds = np.argmax(confs,axis=1)\n",
491 |     logits = (logits.cpu().detach().numpy())\n",
492 | "\n",
493 | " train_confs.extend(np.max(confs,axis=1)) \n",
494 | " train_preds.extend(preds)\n",
495 | " train_logits.extend(logits)\n",
496 | "print(\"Done\")\n",
497 | "\n",
498 | "test_preds = []\n",
499 | "test_confs = []\n",
500 | "test_logits = []\n",
501 | "\n",
502 | "for idx in range(0,len(data),128):\n",
503 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n",
504 | " \n",
505 | " logits = torch_model(batch)\n",
506 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
507 | " preds = np.argmax(confs,axis=1)\n",
508 |     logits = (logits.cpu().detach().numpy())\n",
509 | "\n",
510 | " test_confs.extend(np.max(confs,axis=1)) \n",
511 | " test_preds.extend(preds)\n",
512 | " test_logits.extend(logits)\n",
513 | "print(\"Done\")"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | " Code for detecting OODs by identifying anomalies in correlations "
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": 16,
526 | "metadata": {},
527 | "outputs": [],
528 | "source": [
529 | "import calculate_log as callog\n",
530 | "\n",
531 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n",
532 | " average_results = {}\n",
533 | " for i in range(1,11):\n",
534 | " random.seed(i)\n",
535 | " \n",
536 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n",
537 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n",
538 | " \n",
539 | " validation = all_test_deviations[validation_indices]\n",
540 | " test_deviations = all_test_deviations[test_indices]\n",
541 | "\n",
542 | " t95 = validation.mean(axis=0)+10**-7\n",
543 | " if not normalize:\n",
544 | " t95 = np.ones_like(t95)\n",
545 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
546 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
547 | " \n",
548 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n",
549 | " for m in results:\n",
550 | " average_results[m] = average_results.get(m,0)+results[m]\n",
551 | " \n",
552 | " for m in average_results:\n",
553 | " average_results[m] /= i\n",
554 | " if verbose:\n",
555 | " callog.print_results(average_results)\n",
556 | " return average_results\n",
557 | "\n",
558 | "\n",
559 | "def cpu(ob):\n",
560 | " for i in range(len(ob)):\n",
561 | " for j in range(len(ob[i])):\n",
562 | " ob[i][j] = ob[i][j].cpu()\n",
563 | " return ob\n",
564 | "\n",
565 | "def cuda(ob):\n",
566 | " for i in range(len(ob)):\n",
567 | " for j in range(len(ob[i])):\n",
568 | " ob[i][j] = ob[i][j].cuda()\n",
569 | " return ob\n",
570 | "\n",
571 | "class Detector:\n",
572 | " def __init__(self):\n",
573 | " self.all_test_deviations = None\n",
574 | " self.mins = {}\n",
575 | " self.maxs = {}\n",
576 | " \n",
577 | " self.classes = range(100)\n",
578 | " \n",
579 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n",
580 | " for PRED in tqdm(self.classes):\n",
581 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n",
582 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n",
583 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n",
584 | " self.mins[PRED] = cpu(mins)\n",
585 | " self.maxs[PRED] = cpu(maxs)\n",
586 | " torch.cuda.empty_cache()\n",
587 | " \n",
588 | " def compute_test_deviations(self,POWERS=[10]):\n",
589 | " all_test_deviations = None\n",
590 | " for PRED in tqdm(self.classes):\n",
591 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n",
592 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n",
593 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n",
594 | " mins = cuda(self.mins[PRED])\n",
595 | " maxs = cuda(self.maxs[PRED])\n",
596 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n",
597 | " cpu(mins)\n",
598 | " cpu(maxs)\n",
599 | " if all_test_deviations is None:\n",
600 | " all_test_deviations = test_deviations\n",
601 | " else:\n",
602 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n",
603 | " torch.cuda.empty_cache()\n",
604 | " self.all_test_deviations = all_test_deviations\n",
605 | " \n",
606 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n",
607 | " ood_preds = []\n",
608 | " ood_confs = []\n",
609 | " \n",
610 | " for idx in range(0,len(ood),128):\n",
611 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n",
612 | " logits = torch_model(batch)\n",
613 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
614 | " preds = np.argmax(confs,axis=1)\n",
615 | " \n",
616 | " ood_confs.extend(np.max(confs,axis=1))\n",
617 | " ood_preds.extend(preds) \n",
618 | " torch.cuda.empty_cache()\n",
619 | " print(\"Done\")\n",
620 | " \n",
621 | " all_ood_deviations = None\n",
622 | " for PRED in tqdm(self.classes):\n",
623 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n",
624 | " if len(ood_indices)==0:\n",
625 | " continue\n",
626 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n",
627 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n",
628 | " mins = cuda(self.mins[PRED])\n",
629 | " maxs = cuda(self.maxs[PRED])\n",
630 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n",
631 | " cpu(self.mins[PRED])\n",
632 | " cpu(self.maxs[PRED]) \n",
633 | " if all_ood_deviations is None:\n",
634 | " all_ood_deviations = ood_deviations\n",
635 | " else:\n",
636 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n",
637 | " torch.cuda.empty_cache()\n",
638 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n",
639 | " return average_results, self.all_test_deviations, all_ood_deviations\n"
640 | ]
641 | },
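The bounds stored by compute_minmaxs above are nested per predicted class, per recorded feature map, and per power: self.mins[c][L][p] is a (1, C_L) tensor of per-feature minima for class c, layer L and the p-th entry of POWERS (likewise for self.maxs). A short sketch of inspecting that structure, assuming a detector on which compute_minmaxs has already been run, as in the results cell below:

c = 0                                  # one of the 100 predicted classes
print(len(detector.mins[c]))           # number of recorded feature maps (one per record() call)
print(len(detector.mins[c][0]))        # number of powers, 10 for POWERS=range(1, 11)
print(detector.mins[c][0][0].shape)    # torch.Size([1, C]): per-feature minima for layer 0, power 1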
642 | {
643 | "cell_type": "markdown",
644 | "metadata": {},
645 | "source": [
646 | " Results
"
647 | ]
648 | },
649 | {
650 | "cell_type": "code",
651 | "execution_count": 17,
652 | "metadata": {
653 | "scrolled": false
654 | },
655 | "outputs": [
656 | {
657 | "data": {
658 | "application/vnd.jupyter.widget-view+json": {
659 | "model_id": "b356f59735b74cccbf45df07cd7a9a50",
660 | "version_major": 2,
661 | "version_minor": 0
662 | },
663 | "text/plain": [
664 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
665 | ]
666 | },
667 | "metadata": {},
668 | "output_type": "display_data"
669 | },
670 | {
671 | "name": "stdout",
672 | "output_type": "stream",
673 | "text": [
674 | "\n"
675 | ]
676 | },
677 | {
678 | "data": {
679 | "application/vnd.jupyter.widget-view+json": {
680 | "model_id": "16234970aabe4b4daccf0dc1e74a0f79",
681 | "version_major": 2,
682 | "version_minor": 0
683 | },
684 | "text/plain": [
685 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
686 | ]
687 | },
688 | "metadata": {},
689 | "output_type": "display_data"
690 | },
691 | {
692 | "name": "stdout",
693 | "output_type": "stream",
694 | "text": [
695 | "\n",
696 | "iSUN\n",
697 | "Done\n"
698 | ]
699 | },
700 | {
701 | "data": {
702 | "application/vnd.jupyter.widget-view+json": {
703 | "model_id": "326d25cb5b24408e870cd8d64513b538",
704 | "version_major": 2,
705 | "version_minor": 0
706 | },
707 | "text/plain": [
708 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
709 | ]
710 | },
711 | "metadata": {},
712 | "output_type": "display_data"
713 | },
714 | {
715 | "name": "stdout",
716 | "output_type": "stream",
717 | "text": [
718 | "\n",
719 | " TNR AUROC DTACC AUIN AUOUT \n",
720 | " 94.756 98.814 94.976 98.755 98.803\n",
721 | "LSUN (R)\n",
722 | "Done\n"
723 | ]
724 | },
725 | {
726 | "data": {
727 | "application/vnd.jupyter.widget-view+json": {
728 | "model_id": "3cc2523b5d4b4d3cbecaa37e04ad9b2a",
729 | "version_major": 2,
730 | "version_minor": 0
731 | },
732 | "text/plain": [
733 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
734 | ]
735 | },
736 | "metadata": {},
737 | "output_type": "display_data"
738 | },
739 | {
740 | "name": "stdout",
741 | "output_type": "stream",
742 | "text": [
743 | "\n",
744 | " TNR AUROC DTACC AUIN AUOUT \n",
745 | " 96.606 99.202 95.969 99.171 99.189\n",
746 | "LSUN (C)\n",
747 | "Done\n"
748 | ]
749 | },
750 | {
751 | "data": {
752 | "application/vnd.jupyter.widget-view+json": {
753 | "model_id": "f5b3b69bc8a44dbebb3dcf663abb0fd0",
754 | "version_major": 2,
755 | "version_minor": 0
756 | },
757 | "text/plain": [
758 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
759 | ]
760 | },
761 | "metadata": {},
762 | "output_type": "display_data"
763 | },
764 | {
765 | "name": "stdout",
766 | "output_type": "stream",
767 | "text": [
768 | "\n",
769 | " TNR AUROC DTACC AUIN AUOUT \n",
770 | " 64.800 92.132 84.186 90.995 93.004\n",
771 | "TinyImgNet (R)\n",
772 | "Done\n"
773 | ]
774 | },
775 | {
776 | "data": {
777 | "application/vnd.jupyter.widget-view+json": {
778 | "model_id": "837ccf6cbff94839b97117303d56ceb2",
779 | "version_major": 2,
780 | "version_minor": 0
781 | },
782 | "text/plain": [
783 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
784 | ]
785 | },
786 | "metadata": {},
787 | "output_type": "display_data"
788 | },
789 | {
790 | "name": "stdout",
791 | "output_type": "stream",
792 | "text": [
793 | "\n",
794 | " TNR AUROC DTACC AUIN AUOUT \n",
795 | " 94.789 98.898 94.960 98.771 98.956\n",
796 | "TinyImgNet (C)\n",
797 | "Done\n"
798 | ]
799 | },
800 | {
801 | "data": {
802 | "application/vnd.jupyter.widget-view+json": {
803 | "model_id": "5afe4ee2584e4c17bb24768b1419df3d",
804 | "version_major": 2,
805 | "version_minor": 0
806 | },
807 | "text/plain": [
808 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
809 | ]
810 | },
811 | "metadata": {},
812 | "output_type": "display_data"
813 | },
814 | {
815 | "name": "stdout",
816 | "output_type": "stream",
817 | "text": [
818 | "\n",
819 | " TNR AUROC DTACC AUIN AUOUT \n",
820 | " 88.478 97.661 92.167 97.392 97.891\n",
821 | "SVHN\n",
822 | "Done\n"
823 | ]
824 | },
825 | {
826 | "data": {
827 | "application/vnd.jupyter.widget-view+json": {
828 | "model_id": "77a2d847101143f187a8376fb47f8242",
829 | "version_major": 2,
830 | "version_minor": 0
831 | },
832 | "text/plain": [
833 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
834 | ]
835 | },
836 | "metadata": {},
837 | "output_type": "display_data"
838 | },
839 | {
840 | "name": "stdout",
841 | "output_type": "stream",
842 | "text": [
843 | "\n",
844 | " TNR AUROC DTACC AUIN AUOUT \n",
845 | " 80.752 96.049 89.610 90.480 98.487\n",
846 | "CIFAR-10\n",
847 | "Done\n"
848 | ]
849 | },
850 | {
851 | "data": {
852 | "application/vnd.jupyter.widget-view+json": {
853 | "model_id": "d7e39e3b74c44ca4b80409baddbe0207",
854 | "version_major": 2,
855 | "version_minor": 0
856 | },
857 | "text/plain": [
858 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
859 | ]
860 | },
861 | "metadata": {},
862 | "output_type": "display_data"
863 | },
864 | {
865 | "name": "stdout",
866 | "output_type": "stream",
867 | "text": [
868 | "\n",
869 | " TNR AUROC DTACC AUIN AUOUT \n",
870 | " 12.196 67.951 63.475 66.102 66.957\n"
871 | ]
872 | }
873 | ],
874 | "source": [
875 | "def G_p(ob, p):\n",
876 | " temp = ob.detach()\n",
877 | " \n",
878 | " temp = temp**p\n",
879 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n",
880 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n",
881 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n",
882 | " \n",
883 | " return temp\n",
884 | "\n",
885 | "\n",
886 | "detector = Detector()\n",
887 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n",
888 | "\n",
889 | "detector.compute_test_deviations(POWERS=range(1,11))\n",
890 | "\n",
891 | "print(\"iSUN\")\n",
892 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n",
893 | "print(\"LSUN (R)\")\n",
894 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n",
895 | "print(\"LSUN (C)\")\n",
896 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n",
897 | "print(\"TinyImgNet (R)\")\n",
898 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n",
899 | "print(\"TinyImgNet (C)\")\n",
900 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n",
901 | "print(\"SVHN\")\n",
902 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n",
903 | "print(\"CIFAR-10\")\n",
904 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))"
905 | ]
906 | },
907 | {
908 | "cell_type": "code",
909 | "execution_count": null,
910 | "metadata": {},
911 | "outputs": [],
912 | "source": []
913 | }
914 | ],
915 | "metadata": {
916 | "kernelspec": {
917 | "display_name": "Python 3",
918 | "language": "python",
919 | "name": "python3"
920 | },
921 | "language_info": {
922 | "codemirror_mode": {
923 | "name": "ipython",
924 | "version": 3
925 | },
926 | "file_extension": ".py",
927 | "mimetype": "text/x-python",
928 | "name": "python",
929 | "nbconvert_exporter": "python",
930 | "pygments_lexer": "ipython3",
931 | "version": "3.6.9"
932 | }
933 | },
934 | "nbformat": 4,
935 | "nbformat_minor": 2
936 | }
937 |
--------------------------------------------------------------------------------
/DenseNet_SVHN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "DenseNet: SVHN
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from __future__ import division,print_function\n",
24 | "\n",
25 | "%matplotlib inline\n",
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import sys\n",
30 | "from tqdm import tqdm_notebook as tqdm\n",
31 | "\n",
32 | "import random\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import math\n",
35 | "\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | "import torch.nn.functional as F\n",
41 | "import torch.optim as optim\n",
42 | "import torch.nn.init as init\n",
43 | "from torch.autograd import Variable, grad\n",
44 | "from torchvision import datasets, transforms\n",
45 | "from torch.nn.parameter import Parameter\n",
46 | "\n",
47 | "import calculate_log as callog\n",
48 | "\n",
49 | "import warnings\n",
50 | "warnings.filterwarnings('ignore')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "torch.cuda.set_device(2) #Select the GPU"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Model definition"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "Done\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "def conv3x3(in_planes, out_planes, stride=1):\n",
84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
85 | "\n",
86 | "class BottleneckBlock(nn.Module):\n",
87 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n",
88 | " super(BottleneckBlock, self).__init__()\n",
89 | " inter_planes = out_planes * 4\n",
90 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
91 | " self.relu = nn.ReLU(inplace=True)\n",
92 | " self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,\n",
93 | " padding=0, bias=False)\n",
94 | " self.bn2 = nn.BatchNorm2d(inter_planes)\n",
95 | " self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,\n",
96 | " padding=1, bias=False)\n",
97 | " self.droprate = dropRate\n",
98 | " \n",
99 | " def forward(self, x):\n",
100 | " \n",
101 | " out = self.conv1(self.relu(self.bn1(x)))\n",
102 | " \n",
103 | " torch_model.record(out)\n",
104 | " \n",
105 | " if self.droprate > 0:\n",
106 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
107 | " \n",
108 | " out = self.conv2(self.relu(self.bn2(out)))\n",
109 | " torch_model.record(out)\n",
110 | " \n",
111 | " if self.droprate > 0:\n",
112 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
113 | " return torch.cat([x, out], 1)\n",
114 | "\n",
115 | "class TransitionBlock(nn.Module):\n",
116 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n",
117 | " super(TransitionBlock, self).__init__()\n",
118 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
119 | " self.relu = nn.ReLU(inplace=True)\n",
120 | " self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n",
121 | " padding=0, bias=False)\n",
122 | " self.droprate = dropRate\n",
123 | " \n",
124 | " def forward(self, x):\n",
125 | " t=self.relu(self.bn1(x))\n",
126 | " out = self.conv1(t)\n",
127 | " \n",
128 | " torch_model.record(t)\n",
129 | " torch_model.record(out)\n",
130 | " \n",
131 | " if self.droprate > 0:\n",
132 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
133 | " return F.avg_pool2d(out, 2)\n",
134 | "\n",
135 | "class DenseBlock(nn.Module):\n",
136 | " def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):\n",
137 | " super(DenseBlock, self).__init__()\n",
138 | " self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)\n",
139 | " \n",
140 | " def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):\n",
141 | " layers = []\n",
142 | " for i in range(int(nb_layers)):\n",
143 | " layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))\n",
144 | " return nn.Sequential(*layers)\n",
145 | " \n",
146 | " def forward(self, x):\n",
147 | " t = self.layer(x)\n",
148 | " torch_model.record(t)\n",
149 | " return t\n",
150 | "\n",
151 | "\n",
152 | "class DenseNet3(nn.Module):\n",
153 | " def __init__(self, depth, num_classes, growth_rate=12,\n",
154 | " reduction=0.5, bottleneck=True, dropRate=0.0):\n",
155 | " super(DenseNet3, self).__init__()\n",
156 | " \n",
157 | " self.collecting = False\n",
158 | " \n",
159 | " \n",
160 | " \n",
161 | " in_planes = 2 * growth_rate\n",
162 | " n = (depth - 4) / 3\n",
163 | " if bottleneck == True:\n",
164 | " n = n/2\n",
165 | " block = BottleneckBlock\n",
166 | " else:\n",
167 | " block = BasicBlock\n",
168 | " # 1st conv before any dense block\n",
169 | " self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,\n",
170 | " padding=1, bias=False)\n",
171 | " # 1st block\n",
172 | " self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
173 | " in_planes = int(in_planes+n*growth_rate)\n",
174 | " self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n",
175 | " in_planes = int(math.floor(in_planes*reduction))\n",
176 | " # 2nd block\n",
177 | " self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
178 | " in_planes = int(in_planes+n*growth_rate)\n",
179 | " self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n",
180 | " in_planes = int(math.floor(in_planes*reduction))\n",
181 | " # 3rd block\n",
182 | " self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
183 | " in_planes = int(in_planes+n*growth_rate)\n",
184 | " # global average pooling and classifier\n",
185 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
186 | " self.relu = nn.ReLU(inplace=True)\n",
187 | " self.fc = nn.Linear(in_planes, num_classes)\n",
188 | " self.in_planes = in_planes\n",
189 | "\n",
190 | " for m in self.modules():\n",
191 | " if isinstance(m, nn.Conv2d):\n",
192 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
193 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n",
194 | " elif isinstance(m, nn.BatchNorm2d):\n",
195 | " m.weight.data.fill_(1)\n",
196 | " m.bias.data.zero_()\n",
197 | " elif isinstance(m, nn.Linear):\n",
198 | " m.bias.data.zero_()\n",
199 | " \n",
200 | " def forward(self, x):\n",
201 | " out = self.conv1(x)\n",
202 | " self.record(out)\n",
203 | " out = self.trans1(self.block1(out))\n",
204 | " out = self.trans2(self.block2(out))\n",
205 | " out = self.block3(out)\n",
206 | " out = self.relu(self.bn1(out))\n",
207 | " self.record(out)\n",
208 | " out = F.avg_pool2d(out, 8)\n",
209 | " out = out.view(-1, self.in_planes)\n",
210 | " return self.fc(out)\n",
211 | " \n",
212 | " def load(self, path=\"densenet_svhn.pth\"):\n",
213 | " tm = torch.load(path,map_location=\"cpu\")\n",
214 | " self.load_state_dict(tm,strict=False)\n",
215 | " \n",
216 | " def record(self, t):\n",
217 | " if self.collecting:\n",
218 | " self.gram_feats.append(t)\n",
219 | " \n",
220 | " def gram_feature_list(self,x):\n",
221 | " self.collecting = True\n",
222 | " self.gram_feats = []\n",
223 | " self.forward(x)\n",
224 | " self.collecting = False\n",
225 | " temp = self.gram_feats\n",
226 | " self.gram_feats = []\n",
227 | " return temp\n",
228 | " \n",
229 | " def get_min_max(self, data, power):\n",
230 | " mins = []\n",
231 | " maxs = []\n",
232 | " \n",
233 | " for i in range(0,len(data),64):\n",
234 | " batch = data[i:i+64].cuda()\n",
235 | " feat_list = self.gram_feature_list(batch)\n",
236 | " for L,feat_L in enumerate(feat_list):\n",
237 | " if L==len(mins):\n",
238 | " mins.append([None]*len(power))\n",
239 | " maxs.append([None]*len(power))\n",
240 | " \n",
241 | " for p,P in enumerate(power):\n",
242 | " g_p = G_p(feat_L,P)\n",
243 | " \n",
244 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n",
245 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n",
246 | " \n",
247 | " if mins[L][p] is None:\n",
248 | " mins[L][p] = current_min\n",
249 | " maxs[L][p] = current_max\n",
250 | " else:\n",
251 | " mins[L][p] = torch.min(current_min,mins[L][p])\n",
252 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n",
253 | " \n",
254 | " return mins,maxs\n",
255 | " \n",
256 | " def get_deviations(self,data,power,mins,maxs):\n",
257 | " deviations = []\n",
258 | " \n",
259 | " for i in range(0,len(data),64): \n",
260 | " batch = data[i:i+64].cuda()\n",
261 | " feat_list = self.gram_feature_list(batch)\n",
262 | " batch_deviations = []\n",
263 | " for L,feat_L in enumerate(feat_list):\n",
264 | " dev = 0\n",
265 | " for p,P in enumerate(power):\n",
266 | " g_p = G_p(feat_L,P)\n",
267 | " \n",
268 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
269 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
270 | " batch_deviations.append(dev.cpu().detach().numpy())\n",
271 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n",
272 | " deviations.append(batch_deviations)\n",
273 | " deviations = np.concatenate(deviations,axis=0)\n",
274 | " \n",
275 | " return deviations\n",
276 | "\n",
277 | "torch_model = DenseNet3(100, num_classes=10)\n",
278 | "torch_model.load()\n",
279 | "torch_model.cuda()\n",
280 | "torch_model.params = list(torch_model.parameters())\n",
281 | "torch_model.eval()\n",
282 | "print(\"Done\") "
283 | ]
284 | },
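DenseNet3 above gathers intermediate activations by calling torch_model.record(...) inside each block's forward. A less invasive way to collect similar activations from an unmodified network is a forward hook; the sketch below is a generic illustration of that alternative (not the repository's method) and grabs the output of every Conv2d in an arbitrary model:

import torch
import torch.nn as nn

def collect_conv_features(model, x):
    feats, handles = [], []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # The hook appends each Conv2d output during the forward pass.
            handles.append(m.register_forward_hook(lambda mod, inp, out: feats.append(out)))
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return feats

# Toy usage with a stand-in network (purely illustrative):
toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
feats = collect_conv_features(toy, torch.randn(2, 3, 32, 32))
print([tuple(f.shape) for f in feats])   # [(2, 8, 32, 32), (2, 16, 32, 32)]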
285 | {
286 | "cell_type": "markdown",
287 | "metadata": {},
288 | "source": [
289 | "## Datasets"
290 | ]
291 | },
292 | {
293 | "cell_type": "markdown",
294 | "metadata": {},
295 | "source": [
296 | "In-distribution Datasets"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 4,
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "name": "stdout",
306 | "output_type": "stream",
307 | "text": [
308 | "Using downloaded and verified file: data/train_32x32.mat\n"
309 | ]
310 | }
311 | ],
312 | "source": [
313 | "batch_size = 128\n",
314 | "mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T\n",
315 | "\n",
316 | "std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T\n",
317 | "normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))\n",
318 | "\n",
319 | "transform_train = transforms.Compose([\n",
320 | " transforms.RandomCrop(32, padding=4),\n",
321 | " transforms.RandomHorizontalFlip(),\n",
322 | " transforms.ToTensor(),\n",
323 | " normalize\n",
324 | " \n",
325 | " ])\n",
326 | "transform_test = transforms.Compose([\n",
327 | " transforms.CenterCrop(size=(32, 32)),\n",
328 | " transforms.ToTensor(),\n",
329 | " normalize\n",
330 | " ])\n",
331 | "\n",
332 | "train_loader = torch.utils.data.DataLoader(\n",
333 | " datasets.SVHN('data', split=\"train\", download=True,\n",
334 | " transform=transform_train),\n",
335 | " batch_size=batch_size, shuffle=True)\n",
336 | "test_loader = torch.utils.data.DataLoader(\n",
337 | " datasets.SVHN('data', split=\"test\", transform=transform_test),\n",
338 | " batch_size=batch_size)\n"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 5,
344 | "metadata": {
345 | "scrolled": true
346 | },
347 | "outputs": [
348 | {
349 | "name": "stdout",
350 | "output_type": "stream",
351 | "text": [
352 | "Using downloaded and verified file: data/train_32x32.mat\n"
353 | ]
354 | }
355 | ],
356 | "source": [
357 | "data_train = list(list(torch.utils.data.DataLoader(\n",
358 | " datasets.SVHN('data', split=\"train\", download=True,\n",
359 | " transform=transform_test),\n",
360 | " batch_size=1, shuffle=True)))"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 6,
366 | "metadata": {
367 | "scrolled": true
368 | },
369 | "outputs": [
370 | {
371 | "name": "stdout",
372 | "output_type": "stream",
373 | "text": [
374 | "Using downloaded and verified file: data/test_32x32.mat\n"
375 | ]
376 | }
377 | ],
378 | "source": [
379 | "data = list(list(torch.utils.data.DataLoader(\n",
380 | " datasets.SVHN('data', split=\"test\", download=True,\n",
381 | " transform=transform_test),\n",
382 | " batch_size=1, shuffle=False)))"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": 7,
388 | "metadata": {},
389 | "outputs": [
390 | {
391 | "name": "stdout",
392 | "output_type": "stream",
393 | "text": [
394 | "Accuracy: 0.9637753534111863\n"
395 | ]
396 | }
397 | ],
398 | "source": [
399 | "torch_model.eval()\n",
400 | "correct = 0\n",
401 | "total = 0\n",
402 | "for x,y in test_loader:\n",
403 | " x = x.cuda()\n",
404 | " y = y.numpy()\n",
405 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n",
406 | " total += y.shape[0]\n",
407 | "print(\"Accuracy: \",correct/total)\n"
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "metadata": {},
413 | "source": [
414 | "Out-of-distribution Datasets"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": 8,
420 | "metadata": {},
421 | "outputs": [
422 | {
423 | "name": "stdout",
424 | "output_type": "stream",
425 | "text": [
426 | "Files already downloaded and verified\n"
427 | ]
428 | }
429 | ],
430 | "source": [
431 | "cifar10 = list(torch.utils.data.DataLoader(\n",
432 | " datasets.CIFAR10('data', train=False, download=True,\n",
433 | " transform=transform_test),\n",
434 | " batch_size=1, shuffle=False))"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": 9,
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "isun = list(torch.utils.data.DataLoader(\n",
444 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": 10,
450 | "metadata": {},
451 | "outputs": [],
452 | "source": [
453 | "lsun_c = list(torch.utils.data.DataLoader(\n",
454 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": 11,
460 | "metadata": {},
461 | "outputs": [],
462 | "source": [
463 | "lsun_r = list(torch.utils.data.DataLoader(\n",
464 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 12,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n",
474 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 13,
480 | "metadata": {},
481 | "outputs": [],
482 | "source": [
483 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n",
484 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "## Code for Detecting OODs"
492 | ]
493 | },
494 | {
495 | "cell_type": "markdown",
496 | "metadata": {},
497 | "source": [
498 | " Extract predictions for train and test data "
499 | ]
500 | },
501 | {
502 | "cell_type": "code",
503 | "execution_count": 14,
504 | "metadata": {},
505 | "outputs": [
506 | {
507 | "name": "stdout",
508 | "output_type": "stream",
509 | "text": [
510 | "Done\n",
511 | "Done\n"
512 | ]
513 | }
514 | ],
515 | "source": [
516 | "train_preds = []\n",
517 | "train_confs = []\n",
518 | "train_logits = []\n",
519 | "for idx in range(0,len(data_train),128):\n",
520 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n",
521 | " \n",
522 | " logits = torch_model(batch)\n",
523 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
524 | " preds = np.argmax(confs,axis=1)\n",
525 | " logits = (logits.cpu().detach().numpy())\n",
526 | "\n",
527 | " train_confs.extend(np.max(confs,axis=1)) \n",
528 | " train_preds.extend(preds)\n",
529 | " train_logits.extend(logits)\n",
530 | "print(\"Done\")\n",
531 | "\n",
532 | "test_preds = []\n",
533 | "test_confs = []\n",
534 | "test_logits = []\n",
535 | "\n",
536 | "for idx in range(0,len(data),128):\n",
537 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n",
538 | " \n",
539 | " logits = torch_model(batch)\n",
540 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
541 | " preds = np.argmax(confs,axis=1)\n",
542 | " logits = (logits.cpu().detach().numpy())\n",
543 | "\n",
544 | " test_confs.extend(np.max(confs,axis=1)) \n",
545 | " test_preds.extend(preds)\n",
546 | " test_logits.extend(logits)\n",
547 | "print(\"Done\")"
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "metadata": {},
553 | "source": [
554 | " Code for detecting OODs by identifying anomalies in correlations "
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": 15,
560 | "metadata": {},
561 | "outputs": [],
562 | "source": [
563 | "import calculate_log as callog\n",
564 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n",
565 | " average_results = {}\n",
566 | " for i in range(1,11):\n",
567 | " random.seed(i)\n",
568 | " \n",
569 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n",
570 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n",
571 | "\n",
572 | " validation = all_test_deviations[validation_indices]\n",
573 | " test_deviations = all_test_deviations[test_indices]\n",
574 | "\n",
575 | " t95 = validation.mean(axis=0)+10**-7\n",
576 | " if not normalize:\n",
577 | " t95 = np.ones_like(t95)\n",
578 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
579 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
580 | " \n",
581 | " results = callog.compute_metric(ood_deviations,test_deviations)\n",
582 | " for m in results:\n",
583 | " average_results[m] = average_results.get(m,0)+results[m]\n",
584 | " \n",
585 | " for m in average_results:\n",
586 | " average_results[m] /= i\n",
587 | " if verbose:\n",
588 | " callog.print_results(average_results)\n",
589 | " return average_results\n",
590 | "\n",
591 | "\n",
592 | "def cpu(ob):\n",
593 | " for i in range(len(ob)):\n",
594 | " for j in range(len(ob[i])):\n",
595 | " ob[i][j] = ob[i][j].cpu()\n",
596 | " return ob\n",
597 | " \n",
598 | "def cuda(ob):\n",
599 | " for i in range(len(ob)):\n",
600 | " for j in range(len(ob[i])):\n",
601 | " ob[i][j] = ob[i][j].cuda()\n",
602 | " return ob\n",
603 | "\n",
604 | "class Detector:\n",
605 | " def __init__(self):\n",
606 | " self.all_test_deviations = None\n",
607 | " self.mins = {}\n",
608 | " self.maxs = {}\n",
609 | " self.classes = range(10)\n",
610 | " \n",
611 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n",
612 | " for PRED in tqdm(self.classes):\n",
613 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n",
614 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n",
615 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n",
616 | " self.mins[PRED] = cpu(mins)\n",
617 | " self.maxs[PRED] = cpu(maxs)\n",
618 | " torch.cuda.empty_cache()\n",
619 | " \n",
620 | " def compute_test_deviations(self,POWERS=[10]):\n",
621 | " all_test_deviations = None\n",
622 | " for PRED in tqdm(self.classes):\n",
623 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n",
624 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n",
625 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n",
626 | " mins = cuda(self.mins[PRED])\n",
627 | " maxs = cuda(self.maxs[PRED])\n",
628 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n",
629 | " cpu(mins)\n",
630 | " cpu(maxs)\n",
631 | " if all_test_deviations is None:\n",
632 | " all_test_deviations = test_deviations\n",
633 | " else:\n",
634 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n",
635 | " torch.cuda.empty_cache()\n",
636 | " self.all_test_deviations = all_test_deviations\n",
637 | " \n",
638 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n",
639 | " ood_preds = []\n",
640 | " ood_confs = []\n",
641 | " \n",
642 | " for idx in range(0,len(ood),128):\n",
643 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n",
644 | " logits = torch_model(batch)\n",
645 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
646 | " preds = np.argmax(confs,axis=1)\n",
647 | " \n",
648 | " ood_confs.extend(np.max(confs,axis=1))\n",
649 | " ood_preds.extend(preds) \n",
650 | " torch.cuda.empty_cache()\n",
651 | " print(\"Done\")\n",
652 | " \n",
653 | " all_ood_deviations = None\n",
654 | " for PRED in tqdm(self.classes):\n",
655 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n",
656 | " if len(ood_indices)==0:\n",
657 | " continue\n",
658 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n",
659 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n",
660 | " mins = cuda(self.mins[PRED])\n",
661 | " maxs = cuda(self.maxs[PRED])\n",
662 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n",
663 | " cpu(self.mins[PRED])\n",
664 | " cpu(self.maxs[PRED]) \n",
665 | " if all_ood_deviations is None:\n",
666 | " all_ood_deviations = ood_deviations\n",
667 | " else:\n",
668 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n",
669 | " torch.cuda.empty_cache()\n",
670 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n",
671 | " return average_results, self.all_test_deviations, all_ood_deviations\n",
672 | " "
673 | ]
674 | },
675 | {
676 | "cell_type": "markdown",
677 | "metadata": {},
678 | "source": [
679 | " Results
"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": 16,
685 | "metadata": {
686 | "scrolled": false
687 | },
688 | "outputs": [
689 | {
690 | "data": {
691 | "application/vnd.jupyter.widget-view+json": {
692 | "model_id": "b04957cc14ae4240957cf9c9dbf194be",
693 | "version_major": 2,
694 | "version_minor": 0
695 | },
696 | "text/plain": [
697 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
698 | ]
699 | },
700 | "metadata": {},
701 | "output_type": "display_data"
702 | },
703 | {
704 | "name": "stdout",
705 | "output_type": "stream",
706 | "text": [
707 | "\n"
708 | ]
709 | },
710 | {
711 | "data": {
712 | "application/vnd.jupyter.widget-view+json": {
713 | "model_id": "9a717142b49e48e4abaec21798c62771",
714 | "version_major": 2,
715 | "version_minor": 0
716 | },
717 | "text/plain": [
718 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
719 | ]
720 | },
721 | "metadata": {},
722 | "output_type": "display_data"
723 | },
724 | {
725 | "name": "stdout",
726 | "output_type": "stream",
727 | "text": [
728 | "\n",
729 | "iSUN\n",
730 | "Done\n"
731 | ]
732 | },
733 | {
734 | "data": {
735 | "application/vnd.jupyter.widget-view+json": {
736 | "model_id": "b6ae6ea2d1fd41e4ade8fb012369ccc1",
737 | "version_major": 2,
738 | "version_minor": 0
739 | },
740 | "text/plain": [
741 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
742 | ]
743 | },
744 | "metadata": {},
745 | "output_type": "display_data"
746 | },
747 | {
748 | "name": "stdout",
749 | "output_type": "stream",
750 | "text": [
751 | "\n",
752 | " TNR AUROC DTACC AUIN AUOUT \n",
753 | " 99.348 99.796 98.329 99.312 99.925\n",
754 | "LSUN (R)\n",
755 | "Done\n"
756 | ]
757 | },
758 | {
759 | "data": {
760 | "application/vnd.jupyter.widget-view+json": {
761 | "model_id": "89195f94b7544783b00e0c1b60c065d2",
762 | "version_major": 2,
763 | "version_minor": 0
764 | },
765 | "text/plain": [
766 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
767 | ]
768 | },
769 | "metadata": {},
770 | "output_type": "display_data"
771 | },
772 | {
773 | "name": "stdout",
774 | "output_type": "stream",
775 | "text": [
776 | "\n",
777 | " TNR AUROC DTACC AUIN AUOUT \n",
778 | " 99.504 99.844 98.581 99.500 99.937\n",
779 | "LSUN (C)\n",
780 | "Done\n"
781 | ]
782 | },
783 | {
784 | "data": {
785 | "application/vnd.jupyter.widget-view+json": {
786 | "model_id": "8bf5d901b38a45d59559203a50e86fec",
787 | "version_major": 2,
788 | "version_minor": 0
789 | },
790 | "text/plain": [
791 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
792 | ]
793 | },
794 | "metadata": {},
795 | "output_type": "display_data"
796 | },
797 | {
798 | "name": "stdout",
799 | "output_type": "stream",
800 | "text": [
801 | "\n",
802 | " TNR AUROC DTACC AUIN AUOUT \n",
803 | " 93.325 98.585 94.326 97.113 99.114\n",
804 | "TinyImgNet (R)\n",
805 | "Done\n"
806 | ]
807 | },
808 | {
809 | "data": {
810 | "application/vnd.jupyter.widget-view+json": {
811 | "model_id": "109d4db4dfbe4b999a3e48b7d8d6b329",
812 | "version_major": 2,
813 | "version_minor": 0
814 | },
815 | "text/plain": [
816 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
817 | ]
818 | },
819 | "metadata": {},
820 | "output_type": "display_data"
821 | },
822 | {
823 | "name": "stdout",
824 | "output_type": "stream",
825 | "text": [
826 | "\n",
827 | " TNR AUROC DTACC AUIN AUOUT \n",
828 | " 99.095 99.736 97.940 99.251 99.891\n",
829 | "TinyImgNet (C)\n",
830 | "Done\n"
831 | ]
832 | },
833 | {
834 | "data": {
835 | "application/vnd.jupyter.widget-view+json": {
836 | "model_id": "8295e559036c46c3a27c988a4ace9853",
837 | "version_major": 2,
838 | "version_minor": 0
839 | },
840 | "text/plain": [
841 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
842 | ]
843 | },
844 | "metadata": {},
845 | "output_type": "display_data"
846 | },
847 | {
848 | "name": "stdout",
849 | "output_type": "stream",
850 | "text": [
851 | "\n",
852 | " TNR AUROC DTACC AUIN AUOUT \n",
853 | " 97.881 99.455 96.796 98.636 99.733\n",
854 | "CIFAR-10\n",
855 | "Done\n"
856 | ]
857 | },
858 | {
859 | "data": {
860 | "application/vnd.jupyter.widget-view+json": {
861 | "model_id": "228d21c857644610a9942393d8539d1e",
862 | "version_major": 2,
863 | "version_minor": 0
864 | },
865 | "text/plain": [
866 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
867 | ]
868 | },
869 | "metadata": {},
870 | "output_type": "display_data"
871 | },
872 | {
873 | "name": "stdout",
874 | "output_type": "stream",
875 | "text": [
876 | "\n",
877 | " TNR AUROC DTACC AUIN AUOUT \n",
878 | " 80.409 95.533 89.051 89.599 97.752\n"
879 | ]
880 | }
881 | ],
882 | "source": [
883 | "def G_p(ob, p):\n",
884 | " temp = ob.detach()\n",
885 | " \n",
886 | " temp = temp**p\n",
887 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n",
888 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n",
889 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n",
890 | " \n",
891 | " return temp\n",
892 | "\n",
893 | "detector = Detector()\n",
894 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n",
895 | "detector.compute_test_deviations(POWERS=range(1,11))\n",
896 | "\n",
897 | "print(\"iSUN\")\n",
898 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n",
899 | "print(\"LSUN (R)\")\n",
900 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n",
901 | "print(\"LSUN (C)\")\n",
902 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n",
903 | "print(\"TinyImgNet (R)\")\n",
904 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n",
905 | "print(\"TinyImgNet (C)\")\n",
906 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n",
907 | "print(\"CIFAR-10\")\n",
908 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))"
909 | ]
910 | }
911 | ],
912 | "metadata": {
913 | "kernelspec": {
914 | "display_name": "Python 2",
915 | "language": "python",
916 | "name": "python2"
917 | },
918 | "language_info": {
919 | "codemirror_mode": {
920 | "name": "ipython",
921 | "version": 3
922 | },
923 | "file_extension": ".py",
924 | "mimetype": "text/x-python",
925 | "name": "python",
926 | "nbconvert_exporter": "python",
927 | "pygments_lexer": "ipython3",
928 | "version": "3.6.9"
929 | }
930 | },
931 | "nbformat": 4,
932 | "nbformat_minor": 2
933 | }
934 |
--------------------------------------------------------------------------------
/DenseNet_Cifar100.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "DenseNet: Cifar100
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from __future__ import division,print_function\n",
24 | "\n",
25 | "%matplotlib inline\n",
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import sys\n",
30 | "from tqdm import tqdm_notebook as tqdm\n",
31 | "\n",
32 | "import random\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import math\n",
35 | "\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | "import torch.nn.functional as F\n",
41 | "import torch.optim as optim\n",
42 | "import torch.nn.init as init\n",
43 | "from torch.autograd import Variable, grad\n",
44 | "from torchvision import datasets, transforms\n",
45 | "from torch.nn.parameter import Parameter\n",
46 | "\n",
47 | "import calculate_log as callog\n",
48 | "\n",
49 | "import warnings\n",
50 | "warnings.filterwarnings('ignore')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "torch.cuda.set_device(2) #Select the GPU"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Model definition"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "Done\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "def conv3x3(in_planes, out_planes, stride=1):\n",
84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
85 | "\n",
86 | "\n",
87 | "class BottleneckBlock(nn.Module):\n",
88 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n",
89 | " super(BottleneckBlock, self).__init__()\n",
90 | " inter_planes = out_planes * 4\n",
91 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
92 | " self.relu = nn.ReLU(inplace=True)\n",
93 | " self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,\n",
94 | " padding=0, bias=False)\n",
95 | " self.bn2 = nn.BatchNorm2d(inter_planes)\n",
96 | " self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,\n",
97 | " padding=1, bias=False)\n",
98 | " self.droprate = dropRate\n",
99 | " \n",
100 | " def forward(self, x):\n",
101 | " \n",
102 | " out = self.conv1(self.relu(self.bn1(x)))\n",
103 | " \n",
104 | " torch_model.record(out)\n",
105 | " \n",
106 | " if self.droprate > 0:\n",
107 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
108 | " \n",
109 | " out = self.conv2(self.relu(self.bn2(out)))\n",
110 | " torch_model.record(out)\n",
111 | " \n",
112 | " if self.droprate > 0:\n",
113 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
114 | " return torch.cat([x, out], 1)\n",
115 | "\n",
116 | "class TransitionBlock(nn.Module):\n",
117 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n",
118 | " super(TransitionBlock, self).__init__()\n",
119 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
120 | " self.relu = nn.ReLU(inplace=True)\n",
121 | " self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n",
122 | " padding=0, bias=False)\n",
123 | " self.droprate = dropRate\n",
124 | " \n",
125 | " def forward(self, x):\n",
126 | " out = self.conv1(self.relu(self.bn1(x)))\n",
127 | " torch_model.record(out)\n",
128 | " \n",
129 | " if self.droprate > 0:\n",
130 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
131 | " return F.avg_pool2d(out, 2)\n",
132 | "\n",
133 | "class DenseBlock(nn.Module):\n",
134 | " def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):\n",
135 | " super(DenseBlock, self).__init__()\n",
136 | " self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)\n",
137 | " \n",
138 | " def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):\n",
139 | " layers = []\n",
140 | " for i in range(int(nb_layers)):\n",
141 | " layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))\n",
142 | " return nn.Sequential(*layers)\n",
143 | " \n",
144 | " def forward(self, x):\n",
145 | " t = self.layer(x)\n",
146 | " torch_model.record(t)\n",
147 | " return t\n",
148 | "\n",
149 | "\n",
150 | "class DenseNet3(nn.Module):\n",
151 | " def __init__(self, depth, num_classes, growth_rate=12,\n",
152 | " reduction=0.5, bottleneck=True, dropRate=0.0):\n",
153 | " super(DenseNet3, self).__init__()\n",
154 | " \n",
155 | " self.collecting = False\n",
156 | " \n",
157 | " \n",
158 | " \n",
159 | " in_planes = 2 * growth_rate\n",
160 | " n = (depth - 4) / 3\n",
161 | " if bottleneck == True:\n",
162 | " n = n/2\n",
163 | " block = BottleneckBlock\n",
164 | " else:\n",
165 | " block = BasicBlock\n",
166 | " # 1st conv before any dense block\n",
167 | " self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,\n",
168 | " padding=1, bias=False)\n",
169 | " # 1st block\n",
170 | " self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
171 | " in_planes = int(in_planes+n*growth_rate)\n",
172 | " self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n",
173 | " in_planes = int(math.floor(in_planes*reduction))\n",
174 | " # 2nd block\n",
175 | " self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
176 | " in_planes = int(in_planes+n*growth_rate)\n",
177 | " self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n",
178 | " in_planes = int(math.floor(in_planes*reduction))\n",
179 | " # 3rd block\n",
180 | " self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
181 | " in_planes = int(in_planes+n*growth_rate)\n",
182 | " # global average pooling and classifier\n",
183 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
184 | " self.relu = nn.ReLU(inplace=True)\n",
185 | " self.fc = nn.Linear(in_planes, num_classes)\n",
186 | " self.in_planes = in_planes\n",
187 | "\n",
188 | " for m in self.modules():\n",
189 | " if isinstance(m, nn.Conv2d):\n",
190 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
191 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n",
192 | " elif isinstance(m, nn.BatchNorm2d):\n",
193 | " m.weight.data.fill_(1)\n",
194 | " m.bias.data.zero_()\n",
195 | " elif isinstance(m, nn.Linear):\n",
196 | " m.bias.data.zero_()\n",
197 | " \n",
198 | " def forward(self, x):\n",
199 | " out = self.conv1(x)\n",
200 | " out = self.trans1(self.block1(out))\n",
201 | " out = self.trans2(self.block2(out))\n",
202 | " out = self.block3(out)\n",
203 | " out = self.relu(self.bn1(out))\n",
204 | " out = F.avg_pool2d(out, 8)\n",
205 | " out = out.view(-1, self.in_planes)\n",
206 | " return self.fc(out)\n",
207 | " \n",
208 | " def load(self, path=\"densenet_cifar100.pth\"):\n",
209 | " tm = torch.load(path,map_location=\"cpu\")\n",
210 | " self.load_state_dict(tm.state_dict(),strict=False)\n",
211 | " \n",
212 | " def record(self, t):\n",
213 | " if self.collecting:\n",
214 | " self.gram_feats.append(t)\n",
215 | " \n",
216 | " def gram_feature_list(self,x):\n",
217 | " self.collecting = True\n",
218 | " self.gram_feats = []\n",
219 | " self.forward(x)\n",
220 | " self.collecting = False\n",
221 | " temp = self.gram_feats\n",
222 | " self.gram_feats = []\n",
223 | " return temp\n",
224 | " \n",
225 | " def get_min_max(self, data, power):\n",
226 | " mins = []\n",
227 | " maxs = []\n",
228 | " \n",
229 | " for i in range(0,len(data),64):\n",
230 | " batch = data[i:i+64].cuda()\n",
231 | " feat_list = self.gram_feature_list(batch)\n",
232 | " for L,feat_L in enumerate(feat_list):\n",
233 | " if L==len(mins):\n",
234 | " mins.append([None]*len(power))\n",
235 | " maxs.append([None]*len(power))\n",
236 | " \n",
237 | " for p,P in enumerate(power):\n",
238 | " g_p = G_p(feat_L,P)\n",
239 | " \n",
240 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n",
241 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n",
242 | " \n",
243 | " if mins[L][p] is None:\n",
244 | " mins[L][p] = current_min\n",
245 | " maxs[L][p] = current_max\n",
246 | " else:\n",
247 | " mins[L][p] = torch.min(current_min,mins[L][p])\n",
248 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n",
249 | " \n",
250 | " return mins,maxs\n",
251 | " \n",
252 | " def get_deviations(self,data,power,mins,maxs):\n",
253 | " deviations = []\n",
254 | " \n",
255 | " for i in range(0,len(data),64): \n",
256 | " batch = data[i:i+64].cuda()\n",
257 | " feat_list = self.gram_feature_list(batch)\n",
258 | " batch_deviations = []\n",
259 | " for L,feat_L in enumerate(feat_list):\n",
260 | " dev = 0\n",
261 | " for p,P in enumerate(power):\n",
262 | " g_p = G_p(feat_L,P)\n",
263 | " \n",
264 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
265 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
266 | " batch_deviations.append(dev.cpu().detach().numpy())\n",
267 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n",
268 | " deviations.append(batch_deviations)\n",
269 | " deviations = np.concatenate(deviations,axis=0)\n",
270 | " \n",
271 | " return deviations\n",
272 | "\n",
273 | "torch_model = DenseNet3(100, num_classes=100)\n",
274 | "torch_model.load()\n",
275 | "torch_model.cuda()\n",
276 | "torch_model.params = list(torch_model.parameters())\n",
277 | "torch_model.eval()\n",
278 | "print(\"Done\") "
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "## Datasets"
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
292 | "In-distribution Datasets"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 4,
298 | "metadata": {},
299 | "outputs": [
300 | {
301 | "name": "stdout",
302 | "output_type": "stream",
303 | "text": [
304 | "Files already downloaded and verified\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "batch_size = 128\n",
310 | "mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T\n",
311 | "\n",
312 | "std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T\n",
313 | "normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))\n",
314 | "\n",
315 | "transform_train = transforms.Compose([\n",
316 | " transforms.RandomCrop(32, padding=4),\n",
317 | " transforms.RandomHorizontalFlip(),\n",
318 | " transforms.ToTensor(),\n",
319 | " normalize\n",
320 | " \n",
321 | " ])\n",
322 | "transform_test = transforms.Compose([\n",
323 | " transforms.CenterCrop(size=(32, 32)),\n",
324 | " transforms.ToTensor(),\n",
325 | " normalize\n",
326 | " ])\n",
327 | "\n",
328 | "train_loader = torch.utils.data.DataLoader(\n",
329 | " datasets.CIFAR100('data', train=True, download=True,\n",
330 | " transform=transform_train),\n",
331 | " batch_size=batch_size, shuffle=True)\n",
332 | "test_loader = torch.utils.data.DataLoader(\n",
333 | " datasets.CIFAR100('data', train=False, transform=transform_test),\n",
334 | " batch_size=batch_size)\n"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 5,
340 | "metadata": {
341 | "scrolled": true
342 | },
343 | "outputs": [
344 | {
345 | "name": "stdout",
346 | "output_type": "stream",
347 | "text": [
348 | "Files already downloaded and verified\n"
349 | ]
350 | }
351 | ],
352 | "source": [
353 | "data_train = list(torch.utils.data.DataLoader(\n",
354 | " datasets.CIFAR100('data', train=True, download=True,\n",
355 | " transform=transform_test),\n",
356 | " batch_size=1, shuffle=False))"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 6,
362 | "metadata": {},
363 | "outputs": [
364 | {
365 | "name": "stdout",
366 | "output_type": "stream",
367 | "text": [
368 | "Files already downloaded and verified\n"
369 | ]
370 | }
371 | ],
372 | "source": [
373 | "data = list(torch.utils.data.DataLoader(\n",
374 | " datasets.CIFAR100('data', train=False, download=True,\n",
375 | " transform=transform_test),\n",
376 | " batch_size=1, shuffle=False))"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": 7,
382 | "metadata": {},
383 | "outputs": [
384 | {
385 | "name": "stdout",
386 | "output_type": "stream",
387 | "text": [
388 | "Accuracy: 0.7763\n"
389 | ]
390 | }
391 | ],
392 | "source": [
393 | "torch_model.eval()\n",
394 | "correct = 0\n",
395 | "total = 0\n",
396 | "for x,y in test_loader:\n",
397 | " x = x.cuda()\n",
398 | " y = y.numpy()\n",
399 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n",
400 | " total += y.shape[0]\n",
401 | "print(\"Accuracy: \",correct/total)\n"
402 | ]
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {},
407 | "source": [
408 | "Out-of-distribution Datasets"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": 8,
414 | "metadata": {},
415 | "outputs": [
416 | {
417 | "name": "stdout",
418 | "output_type": "stream",
419 | "text": [
420 | "Files already downloaded and verified\n"
421 | ]
422 | }
423 | ],
424 | "source": [
425 | "cifar10 = list(torch.utils.data.DataLoader(\n",
426 | " datasets.CIFAR10('data', train=False, download=True,\n",
427 | " transform=transform_test),\n",
428 | " batch_size=1, shuffle=True))"
429 | ]
430 | },
431 | {
432 | "cell_type": "code",
433 | "execution_count": 9,
434 | "metadata": {},
435 | "outputs": [
436 | {
437 | "name": "stdout",
438 | "output_type": "stream",
439 | "text": [
440 | "Using downloaded and verified file: data/test_32x32.mat\n"
441 | ]
442 | }
443 | ],
444 | "source": [
445 | "svhn = list(torch.utils.data.DataLoader(\n",
446 | " datasets.SVHN('data', split=\"test\", download=True,\n",
447 | " transform=transform_test),\n",
448 | " batch_size=1, shuffle=True))"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": 10,
454 | "metadata": {},
455 | "outputs": [],
456 | "source": [
457 | "isun = list(torch.utils.data.DataLoader(\n",
458 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": 11,
464 | "metadata": {},
465 | "outputs": [],
466 | "source": [
467 | "lsun_c = list(torch.utils.data.DataLoader(\n",
468 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))"
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": 12,
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "lsun_r = list(torch.utils.data.DataLoader(\n",
478 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 13,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n",
488 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": 14,
494 | "metadata": {},
495 | "outputs": [],
496 | "source": [
497 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n",
498 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
499 | ]
500 | },
501 | {
502 | "cell_type": "markdown",
503 | "metadata": {},
504 | "source": [
505 | "## Code for Detecting OODs"
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {},
511 | "source": [
512 | " Extract predictions for train and test data "
513 | ]
514 | },
515 | {
516 | "cell_type": "code",
517 | "execution_count": 15,
518 | "metadata": {},
519 | "outputs": [
520 | {
521 | "name": "stdout",
522 | "output_type": "stream",
523 | "text": [
524 | "Done\n",
525 | "Done\n"
526 | ]
527 | }
528 | ],
529 | "source": [
530 | "train_preds = []\n",
531 | "train_confs = []\n",
532 | "train_logits = []\n",
533 | "for idx in range(0,len(data_train),128):\n",
534 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n",
535 | " \n",
536 | " logits = torch_model(batch)\n",
537 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
538 | " preds = np.argmax(confs,axis=1)\n",
539 | " logits = (logits.cpu().detach().numpy())\n",
540 | " \n",
541 | " train_confs.extend(np.max(confs,axis=1)) \n",
542 | " train_preds.extend(preds)\n",
543 | " train_logits.extend(logits)\n",
544 | "print(\"Done\")\n",
545 | "\n",
546 | "test_preds = []\n",
547 | "test_confs = []\n",
548 | "test_logits = []\n",
549 | "\n",
550 | "for idx in range(0,len(data),128):\n",
551 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n",
552 | " \n",
553 | " logits = torch_model(batch)\n",
554 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
555 | " preds = np.argmax(confs,axis=1)\n",
556 | " logits = (logits.cpu().detach().numpy())\n",
557 | "\n",
558 | " test_confs.extend(np.max(confs,axis=1)) \n",
559 | " test_preds.extend(preds)\n",
560 | " test_logits.extend(logits)\n",
561 | "print(\"Done\")"
562 | ]
563 | },
564 | {
565 | "cell_type": "markdown",
566 | "metadata": {},
567 | "source": [
568 | " Code for detecting OODs by identifying anomalies in correlations "
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": 16,
574 | "metadata": {},
575 | "outputs": [],
576 | "source": [
577 | "import calculate_log as callog\n",
578 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n",
579 | " average_results = {}\n",
580 | " for i in range(1,11):\n",
581 | " random.seed(i)\n",
582 | " \n",
583 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n",
584 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n",
585 | "\n",
586 | " validation = all_test_deviations[validation_indices]\n",
587 | " test_deviations = all_test_deviations[test_indices]\n",
588 | "\n",
589 | " t95 = validation.mean(axis=0)+10**-7\n",
590 | " if not normalize:\n",
591 | " t95 = np.ones_like(t95)\n",
592 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
593 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
594 | " \n",
595 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n",
596 | " for m in results:\n",
597 | " average_results[m] = average_results.get(m,0)+results[m]\n",
598 | " \n",
599 | " for m in average_results:\n",
600 | " average_results[m] /= i\n",
601 | " if verbose:\n",
602 | " callog.print_results(average_results)\n",
603 | " return average_results\n",
604 | "\n",
605 | "\n",
606 | "def cpu(ob):\n",
607 | " for i in range(len(ob)):\n",
608 | " for j in range(len(ob[i])):\n",
609 | " ob[i][j] = ob[i][j].cpu()\n",
610 | " return ob\n",
611 | " \n",
612 | "def cuda(ob):\n",
613 | " for i in range(len(ob)):\n",
614 | " for j in range(len(ob[i])):\n",
615 | " ob[i][j] = ob[i][j].cuda()\n",
616 | " return ob\n",
617 | "\n",
618 | "class Detector:\n",
619 | " def __init__(self):\n",
620 | " self.all_test_deviations = None\n",
621 | " self.mins = {}\n",
622 | " self.maxs = {}\n",
623 | " self.classes = range(100)\n",
624 | " \n",
625 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n",
626 | " for PRED in tqdm(self.classes):\n",
627 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n",
628 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n",
629 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n",
630 | " self.mins[PRED] = cpu(mins)\n",
631 | " self.maxs[PRED] = cpu(maxs)\n",
632 | " torch.cuda.empty_cache()\n",
633 | " \n",
634 | " def compute_test_deviations(self,POWERS=[10]):\n",
635 | " all_test_deviations = None\n",
636 | " for PRED in tqdm(self.classes):\n",
637 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n",
638 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n",
639 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n",
640 | " mins = cuda(self.mins[PRED])\n",
641 | " maxs = cuda(self.maxs[PRED])\n",
642 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n",
643 | " cpu(mins)\n",
644 | " cpu(maxs)\n",
645 | " if all_test_deviations is None:\n",
646 | " all_test_deviations = test_deviations\n",
647 | " else:\n",
648 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n",
649 | " torch.cuda.empty_cache()\n",
650 | " self.all_test_deviations = all_test_deviations\n",
651 | " \n",
652 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n",
653 | " ood_preds = []\n",
654 | " ood_confs = []\n",
655 | " \n",
656 | " for idx in range(0,len(ood),128):\n",
657 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n",
658 | " logits = torch_model(batch)\n",
659 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
660 | " preds = np.argmax(confs,axis=1)\n",
661 | " \n",
662 | " ood_confs.extend(np.max(confs,axis=1))\n",
663 | " ood_preds.extend(preds) \n",
664 | " torch.cuda.empty_cache()\n",
665 | " print(\"Done\")\n",
666 | " \n",
667 | " all_ood_deviations = None\n",
668 | " for PRED in tqdm(self.classes):\n",
669 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n",
670 | " if len(ood_indices)==0:\n",
671 | " continue\n",
672 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n",
673 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n",
674 | " mins = cuda(self.mins[PRED])\n",
675 | " maxs = cuda(self.maxs[PRED])\n",
676 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n",
677 | " cpu(self.mins[PRED])\n",
678 | " cpu(self.maxs[PRED]) \n",
679 | " if all_ood_deviations is None:\n",
680 | " all_ood_deviations = ood_deviations\n",
681 | " else:\n",
682 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n",
683 | " torch.cuda.empty_cache()\n",
684 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n",
685 | " return average_results, self.all_test_deviations, all_ood_deviations\n"
686 | ]
687 | },
688 | {
689 | "cell_type": "markdown",
690 | "metadata": {},
691 | "source": [
692 | " Results
"
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": 17,
698 | "metadata": {},
699 | "outputs": [
700 | {
701 | "data": {
702 | "application/vnd.jupyter.widget-view+json": {
703 | "model_id": "6d38db591504460ea13535698dea877c",
704 | "version_major": 2,
705 | "version_minor": 0
706 | },
707 | "text/plain": [
708 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
709 | ]
710 | },
711 | "metadata": {},
712 | "output_type": "display_data"
713 | },
714 | {
715 | "name": "stdout",
716 | "output_type": "stream",
717 | "text": [
718 | "\n"
719 | ]
720 | },
721 | {
722 | "data": {
723 | "application/vnd.jupyter.widget-view+json": {
724 | "model_id": "ffae25d812914b6c95aa3c1c7fd392a6",
725 | "version_major": 2,
726 | "version_minor": 0
727 | },
728 | "text/plain": [
729 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
730 | ]
731 | },
732 | "metadata": {},
733 | "output_type": "display_data"
734 | },
735 | {
736 | "name": "stdout",
737 | "output_type": "stream",
738 | "text": [
739 | "\n",
740 | "iSUN\n",
741 | "Done\n"
742 | ]
743 | },
744 | {
745 | "data": {
746 | "application/vnd.jupyter.widget-view+json": {
747 | "model_id": "716da00211014f8988f1c9937cf1d35c",
748 | "version_major": 2,
749 | "version_minor": 0
750 | },
751 | "text/plain": [
752 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
753 | ]
754 | },
755 | "metadata": {},
756 | "output_type": "display_data"
757 | },
758 | {
759 | "name": "stdout",
760 | "output_type": "stream",
761 | "text": [
762 | "\n",
763 | " TNR AUROC DTACC AUIN AUOUT \n",
764 | " 95.867 99.042 95.632 98.990 99.083\n",
765 | "LSUN (R)\n",
766 | "Done\n"
767 | ]
768 | },
769 | {
770 | "data": {
771 | "application/vnd.jupyter.widget-view+json": {
772 | "model_id": "0397655250134e9d853d8598410ced8b",
773 | "version_major": 2,
774 | "version_minor": 0
775 | },
776 | "text/plain": [
777 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
778 | ]
779 | },
780 | "metadata": {},
781 | "output_type": "display_data"
782 | },
783 | {
784 | "name": "stdout",
785 | "output_type": "stream",
786 | "text": [
787 | "\n",
788 | " TNR AUROC DTACC AUIN AUOUT \n",
789 | " 97.234 99.349 96.393 99.264 99.407\n",
790 | "LSUN (C)\n",
791 | "Done\n"
792 | ]
793 | },
794 | {
795 | "data": {
796 | "application/vnd.jupyter.widget-view+json": {
797 | "model_id": "36483d6fcf034a8887df000863cb187a",
798 | "version_major": 2,
799 | "version_minor": 0
800 | },
801 | "text/plain": [
802 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
803 | ]
804 | },
805 | "metadata": {},
806 | "output_type": "display_data"
807 | },
808 | {
809 | "name": "stdout",
810 | "output_type": "stream",
811 | "text": [
812 | "\n",
813 | " TNR AUROC DTACC AUIN AUOUT \n",
814 | " 65.517 91.391 83.644 89.553 92.749\n",
815 | "TinyImgNet (R)\n",
816 | "Done\n"
817 | ]
818 | },
819 | {
820 | "data": {
821 | "application/vnd.jupyter.widget-view+json": {
822 | "model_id": "9aa5b886321c4ba793127afe191bb73c",
823 | "version_major": 2,
824 | "version_minor": 0
825 | },
826 | "text/plain": [
827 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
828 | ]
829 | },
830 | "metadata": {},
831 | "output_type": "display_data"
832 | },
833 | {
834 | "name": "stdout",
835 | "output_type": "stream",
836 | "text": [
837 | "\n",
838 | " TNR AUROC DTACC AUIN AUOUT \n",
839 | " 95.748 98.973 95.522 98.768 99.126\n",
840 | "TinyImgNet (C)\n",
841 | "Done\n"
842 | ]
843 | },
844 | {
845 | "data": {
846 | "application/vnd.jupyter.widget-view+json": {
847 | "model_id": "1ab175d00f4e44688f226557051e413a",
848 | "version_major": 2,
849 | "version_minor": 0
850 | },
851 | "text/plain": [
852 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
853 | ]
854 | },
855 | "metadata": {},
856 | "output_type": "display_data"
857 | },
858 | {
859 | "name": "stdout",
860 | "output_type": "stream",
861 | "text": [
862 | "\n",
863 | " TNR AUROC DTACC AUIN AUOUT \n",
864 | " 89.013 97.687 92.452 97.303 98.018\n",
865 | "SVHN\n",
866 | "Done\n"
867 | ]
868 | },
869 | {
870 | "data": {
871 | "application/vnd.jupyter.widget-view+json": {
872 | "model_id": "999c654468da4044bae289c360b2de38",
873 | "version_major": 2,
874 | "version_minor": 0
875 | },
876 | "text/plain": [
877 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
878 | ]
879 | },
880 | "metadata": {},
881 | "output_type": "display_data"
882 | },
883 | {
884 | "name": "stdout",
885 | "output_type": "stream",
886 | "text": [
887 | "\n",
888 | " TNR AUROC DTACC AUIN AUOUT \n",
889 | " 89.341 97.322 92.364 91.652 99.080\n",
890 | "CIFAR-10\n",
891 | "Done\n"
892 | ]
893 | },
894 | {
895 | "data": {
896 | "application/vnd.jupyter.widget-view+json": {
897 | "model_id": "ebf38974d3bf4350aa44ea06f629e369",
898 | "version_major": 2,
899 | "version_minor": 0
900 | },
901 | "text/plain": [
902 | "HBox(children=(IntProgress(value=0), HTML(value='')))"
903 | ]
904 | },
905 | "metadata": {},
906 | "output_type": "display_data"
907 | },
908 | {
909 | "name": "stdout",
910 | "output_type": "stream",
911 | "text": [
912 | "\n",
913 | " TNR AUROC DTACC AUIN AUOUT \n",
914 | " 10.596 64.227 60.404 61.350 64.092\n"
915 | ]
916 | }
917 | ],
918 | "source": [
919 | "def G_p(ob, p):\n",
920 | " temp = ob.detach()\n",
921 | " \n",
922 | " temp = temp**p\n",
923 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n",
924 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n",
925 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n",
926 | " \n",
927 | " return temp\n",
928 | "\n",
929 | "detector = Detector()\n",
930 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n",
931 | "\n",
932 | "detector.compute_test_deviations(POWERS=range(1,11))\n",
933 | "\n",
934 | "print(\"iSUN\")\n",
935 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n",
936 | "print(\"LSUN (R)\")\n",
937 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n",
938 | "print(\"LSUN (C)\")\n",
939 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n",
940 | "print(\"TinyImgNet (R)\")\n",
941 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n",
942 | "print(\"TinyImgNet (C)\")\n",
943 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n",
944 | "print(\"SVHN\")\n",
945 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n",
946 | "print(\"CIFAR-10\")\n",
947 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))"
948 | ]
949 | }
950 | ],
951 | "metadata": {
952 | "kernelspec": {
953 | "display_name": "Python 2",
954 | "language": "python",
955 | "name": "python2"
956 | },
957 | "language_info": {
958 | "codemirror_mode": {
959 | "name": "ipython",
960 | "version": 3
961 | },
962 | "file_extension": ".py",
963 | "mimetype": "text/x-python",
964 | "name": "python",
965 | "nbconvert_exporter": "python",
966 | "pygments_lexer": "ipython3",
967 | "version": "3.6.9"
968 | }
969 | },
970 | "nbformat": 4,
971 | "nbformat_minor": 2
972 | }
973 |
--------------------------------------------------------------------------------
/DenseNet_Cifar10.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "DenseNet: Cifar10
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Imports"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from __future__ import division,print_function\n",
24 | "\n",
25 | "%matplotlib inline\n",
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import sys\n",
30 | "from tqdm import tqdm_notebook as tqdm\n",
31 | "\n",
32 | "import random\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import math\n",
35 | "\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | "import torch.nn.functional as F\n",
41 | "import torch.optim as optim\n",
42 | "import torch.nn.init as init\n",
43 | "from torch.autograd import Variable, grad\n",
44 | "from torchvision import datasets, transforms\n",
45 | "from torch.nn.parameter import Parameter\n",
46 | "\n",
47 | "import calculate_log as callog\n",
48 | "\n",
49 | "import warnings\n",
50 | "warnings.filterwarnings('ignore')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "torch.cuda.set_device(0) #Select the GPU"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Model definition"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {
73 | "scrolled": true
74 | },
75 | "outputs": [
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "Done\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "def conv3x3(in_planes, out_planes, stride=1):\n",
86 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
87 | "\n",
88 | "class BottleneckBlock(nn.Module):\n",
89 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n",
90 | " super(BottleneckBlock, self).__init__()\n",
91 | " inter_planes = out_planes * 4\n",
92 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
93 | " self.relu = nn.ReLU(inplace=True)\n",
94 | " self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,\n",
95 | " padding=0, bias=False)\n",
96 | " self.bn2 = nn.BatchNorm2d(inter_planes)\n",
97 | " self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,\n",
98 | " padding=1, bias=False)\n",
99 | " self.droprate = dropRate\n",
100 | " \n",
101 | " def forward(self, x):\n",
102 | " \n",
103 | " out = self.conv1(self.relu(self.bn1(x)))\n",
104 | " \n",
105 | " torch_model.record(out)\n",
106 | " \n",
107 | " if self.droprate > 0:\n",
108 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
109 | " \n",
110 | " out = self.conv2(self.relu(self.bn2(out)))\n",
111 | " torch_model.record(out)\n",
112 | " \n",
113 | " if self.droprate > 0:\n",
114 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
115 | " return torch.cat([x, out], 1)\n",
116 | "\n",
117 | "class TransitionBlock(nn.Module):\n",
118 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n",
119 | " super(TransitionBlock, self).__init__()\n",
120 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
121 | " self.relu = nn.ReLU(inplace=True)\n",
122 | " self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n",
123 | " padding=0, bias=False)\n",
124 | " self.droprate = dropRate\n",
125 | " \n",
126 | " def forward(self, x):\n",
127 | " out = self.conv1(self.relu(self.bn1(x)))\n",
128 | " torch_model.record(out)\n",
129 | " \n",
130 | " if self.droprate > 0:\n",
131 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n",
132 | " return F.avg_pool2d(out, 2)\n",
133 | "\n",
134 | "class DenseBlock(nn.Module):\n",
135 | " def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):\n",
136 | " super(DenseBlock, self).__init__()\n",
137 | " self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)\n",
138 | " \n",
139 | " def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):\n",
140 | " layers = []\n",
141 | " for i in range(int(nb_layers)):\n",
142 | " layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))\n",
143 | " return nn.Sequential(*layers)\n",
144 | " \n",
145 | " def forward(self, x):\n",
146 | " t = self.layer(x)\n",
147 | " torch_model.record(t)\n",
148 | " return t\n",
149 | "\n",
150 | "\n",
151 | "class DenseNet3(nn.Module):\n",
152 | " def __init__(self, depth, num_classes, growth_rate=12,\n",
153 | " reduction=0.5, bottleneck=True, dropRate=0.0):\n",
154 | " super(DenseNet3, self).__init__()\n",
155 | " \n",
156 | " self.collecting = False\n",
157 | " \n",
158 | " in_planes = 2 * growth_rate\n",
159 | " n = (depth - 4) / 3\n",
160 | " if bottleneck == True:\n",
161 | " n = n/2\n",
162 | " block = BottleneckBlock\n",
163 | " else:\n",
164 | " block = BasicBlock\n",
165 | " # 1st conv before any dense block\n",
166 | " self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,\n",
167 | " padding=1, bias=False)\n",
168 | " # 1st block\n",
169 | " self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
170 | " in_planes = int(in_planes+n*growth_rate)\n",
171 | " self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n",
172 | " in_planes = int(math.floor(in_planes*reduction))\n",
173 | " # 2nd block\n",
174 | " self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
175 | " in_planes = int(in_planes+n*growth_rate)\n",
176 | " self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n",
177 | " in_planes = int(math.floor(in_planes*reduction))\n",
178 | " # 3rd block\n",
179 | " self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n",
180 | " in_planes = int(in_planes+n*growth_rate)\n",
181 | " # global average pooling and classifier\n",
182 | " self.bn1 = nn.BatchNorm2d(in_planes)\n",
183 | " self.relu = nn.ReLU(inplace=True)\n",
184 | " self.fc = nn.Linear(in_planes, num_classes)\n",
185 | " self.in_planes = in_planes\n",
186 | "\n",
187 | " for m in self.modules():\n",
188 | " if isinstance(m, nn.Conv2d):\n",
189 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
190 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n",
191 | " elif isinstance(m, nn.BatchNorm2d):\n",
192 | " m.weight.data.fill_(1)\n",
193 | " m.bias.data.zero_()\n",
194 | " elif isinstance(m, nn.Linear):\n",
195 | " m.bias.data.zero_()\n",
196 | " \n",
197 | " def forward(self, x):\n",
198 | " out = self.conv1(x)\n",
199 | " out = self.trans1(self.block1(out))\n",
200 | " out = self.trans2(self.block2(out))\n",
201 | " out = self.block3(out)\n",
202 | " out = self.relu(self.bn1(out))\n",
203 | " out = F.avg_pool2d(out, 8)\n",
204 | " out = out.view(-1, self.in_planes)\n",
205 | " return self.fc(out)\n",
206 | " \n",
207 | " def load(self, path=\"densenet_cifar10.pth\"):\n",
208 | " tm = torch.load(path,map_location=\"cpu\")\n",
209 | " self.load_state_dict(tm.state_dict(),strict=False)\n",
210 | " \n",
211 | " def record(self, t):\n",
212 | " if self.collecting:\n",
213 | " self.gram_feats.append(t)\n",
214 | " \n",
215 | " def gram_feature_list(self,x):\n",
216 | " self.collecting = True\n",
217 | " self.gram_feats = []\n",
218 | " self.forward(x)\n",
219 | " self.collecting = False\n",
220 | " temp = self.gram_feats\n",
221 | " self.gram_feats = []\n",
222 | " return temp\n",
223 | " \n",
224 | " def get_min_max(self, data, power):\n",
225 | " mins = []\n",
226 | " maxs = []\n",
227 | " \n",
228 | " for i in range(0,len(data),64):\n",
229 | " batch = data[i:i+64].cuda()\n",
230 | " feat_list = self.gram_feature_list(batch)\n",
231 | " for L,feat_L in enumerate(feat_list):\n",
232 | " if L==len(mins):\n",
233 | " mins.append([None]*len(power))\n",
234 | " maxs.append([None]*len(power))\n",
235 | " \n",
236 | " for p,P in enumerate(power):\n",
237 | " g_p = G_p(feat_L,P)\n",
238 | " \n",
239 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n",
240 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n",
241 | " \n",
242 | " if mins[L][p] is None:\n",
243 | " mins[L][p] = current_min\n",
244 | " maxs[L][p] = current_max\n",
245 | " else:\n",
246 | " mins[L][p] = torch.min(current_min,mins[L][p])\n",
247 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n",
248 | " \n",
249 | " return mins,maxs\n",
250 | " \n",
251 | " def get_deviations(self,data,power,mins,maxs):\n",
252 | " deviations = []\n",
253 | " \n",
254 | " for i in range(0,len(data),64): \n",
255 | " batch = data[i:i+64].cuda()\n",
256 | " feat_list = self.gram_feature_list(batch)\n",
257 | " batch_deviations = []\n",
258 | " for L,feat_L in enumerate(feat_list):\n",
259 | " dev = 0\n",
260 | " for p,P in enumerate(power):\n",
261 | " g_p = G_p(feat_L,P)\n",
262 | " \n",
263 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
264 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n",
265 | " batch_deviations.append(dev.cpu().detach().numpy())\n",
266 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n",
267 | " deviations.append(batch_deviations)\n",
268 | " deviations = np.concatenate(deviations,axis=0)\n",
269 | " \n",
270 | " return deviations\n",
271 | "\n",
272 | "\n",
273 | "torch_model = DenseNet3(100, num_classes=10)\n",
274 | "torch_model.load()\n",
275 | "torch_model.cuda()\n",
276 | "torch_model.params = list(torch_model.parameters())\n",
277 | "torch_model.eval()\n",
278 | "print(\"Done\") "
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "## Datasets"
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
292 | "In-distribution Datasets"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 4,
298 | "metadata": {},
299 | "outputs": [
300 | {
301 | "name": "stdout",
302 | "output_type": "stream",
303 | "text": [
304 | "Files already downloaded and verified\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "batch_size = 128\n",
310 | "mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T\n",
311 | "\n",
312 | "std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T\n",
313 | "normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))\n",
314 | "\n",
315 | "transform_train = transforms.Compose([\n",
316 | " transforms.RandomCrop(32, padding=4),\n",
317 | " transforms.RandomHorizontalFlip(),\n",
318 | " transforms.ToTensor(),\n",
319 | " normalize\n",
320 | " \n",
321 | " ])\n",
322 | "transform_test = transforms.Compose([\n",
323 | " transforms.CenterCrop(size=(32, 32)),\n",
324 | " transforms.ToTensor(),\n",
325 | " normalize\n",
326 | " ])\n",
327 | "\n",
328 | "train_loader = torch.utils.data.DataLoader(\n",
329 | " datasets.CIFAR10('data', train=True, download=True,\n",
330 | " transform=transform_train),\n",
331 | " batch_size=batch_size, shuffle=True)\n",
332 | "test_loader = torch.utils.data.DataLoader(\n",
333 | " datasets.CIFAR10('data', train=False, transform=transform_test),\n",
334 | " batch_size=batch_size)\n"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 5,
340 | "metadata": {
341 | "scrolled": true
342 | },
343 | "outputs": [
344 | {
345 | "name": "stdout",
346 | "output_type": "stream",
347 | "text": [
348 | "Files already downloaded and verified\n"
349 | ]
350 | }
351 | ],
352 | "source": [
353 | "data_train = list(torch.utils.data.DataLoader(\n",
354 | " datasets.CIFAR10('data', train=True, download=True,\n",
355 | " transform=transform_test),\n",
356 | " batch_size=1, shuffle=False))"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 6,
362 | "metadata": {
363 | "scrolled": true
364 | },
365 | "outputs": [
366 | {
367 | "name": "stdout",
368 | "output_type": "stream",
369 | "text": [
370 | "Files already downloaded and verified\n"
371 | ]
372 | }
373 | ],
374 | "source": [
375 | "data = list(torch.utils.data.DataLoader(\n",
376 | " datasets.CIFAR10('data', train=False, download=True,\n",
377 | " transform=transform_test),\n",
378 | " batch_size=1, shuffle=False))"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 7,
384 | "metadata": {},
385 | "outputs": [
386 | {
387 | "name": "stdout",
388 | "output_type": "stream",
389 | "text": [
390 | "Accuracy: 0.9519\n"
391 | ]
392 | }
393 | ],
394 | "source": [
395 | "torch_model.eval()\n",
396 | "correct = 0\n",
397 | "total = 0\n",
398 | "for x,y in test_loader:\n",
399 | " x = x.cuda()\n",
400 | " y = y.numpy()\n",
401 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n",
402 | " total += y.shape[0]\n",
403 | "print(\"Accuracy: \",correct/total)\n"
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "metadata": {},
409 | "source": [
410 | "Out-of-distribution Datasets"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": 8,
416 | "metadata": {},
417 | "outputs": [
418 | {
419 | "name": "stdout",
420 | "output_type": "stream",
421 | "text": [
422 | "Files already downloaded and verified\n"
423 | ]
424 | }
425 | ],
426 | "source": [
427 | "cifar100 = list(torch.utils.data.DataLoader(\n",
428 | " datasets.CIFAR100('data', train=False, download=True,\n",
429 | " transform=transform_test),\n",
430 | " batch_size=1, shuffle=False))"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": 9,
436 | "metadata": {},
437 | "outputs": [
438 | {
439 | "name": "stdout",
440 | "output_type": "stream",
441 | "text": [
442 | "Using downloaded and verified file: data/test_32x32.mat\n"
443 | ]
444 | }
445 | ],
446 | "source": [
447 | "svhn = list(torch.utils.data.DataLoader(\n",
448 | " datasets.SVHN('data', split=\"test\", download=True,\n",
449 | " transform=transform_test),\n",
450 | " batch_size=1, shuffle=True))"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": 10,
456 | "metadata": {},
457 | "outputs": [],
458 | "source": [
459 | "isun = list(torch.utils.data.DataLoader(\n",
460 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": 11,
466 | "metadata": {},
467 | "outputs": [],
468 | "source": [
469 | "lsun_c = list(torch.utils.data.DataLoader(\n",
470 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))"
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "execution_count": 12,
476 | "metadata": {},
477 | "outputs": [],
478 | "source": [
479 | "lsun_r = list(torch.utils.data.DataLoader(\n",
480 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": 13,
486 | "metadata": {},
487 | "outputs": [],
488 | "source": [
489 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n",
490 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": 14,
496 | "metadata": {},
497 | "outputs": [],
498 | "source": [
499 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n",
500 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))"
501 | ]
502 | },
503 | {
504 | "cell_type": "markdown",
505 | "metadata": {},
506 | "source": [
507 | "## Code for Detecting OODs"
508 | ]
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {},
513 | "source": [
514 | " Extract predictions for train and test data "
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": 15,
520 | "metadata": {},
521 | "outputs": [
522 | {
523 | "name": "stdout",
524 | "output_type": "stream",
525 | "text": [
526 | "Done\n",
527 | "Done\n"
528 | ]
529 | }
530 | ],
531 | "source": [
532 | "train_preds = []\n",
533 | "train_confs = []\n",
534 | "train_logits = []\n",
535 | "for idx in range(0,len(data_train),128):\n",
536 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n",
537 | " \n",
538 | " logits = torch_model(batch)\n",
539 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
540 | " preds = np.argmax(confs,axis=1)\n",
541 | " logits = (logits.cpu().detach().numpy())\n",
542 | "\n",
543 | " train_confs.extend(np.max(confs,axis=1)) \n",
544 | " train_preds.extend(preds)\n",
545 | " train_logits.extend(logits)\n",
546 | "print(\"Done\")\n",
547 | "\n",
548 | "test_preds = []\n",
549 | "test_confs = []\n",
550 | "test_logits = []\n",
551 | "\n",
552 | "for idx in range(0,len(data),128):\n",
553 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n",
554 | " \n",
555 | " logits = torch_model(batch)\n",
556 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
557 | " preds = np.argmax(confs,axis=1)\n",
558 | " logits = (logits.cpu().detach().numpy())\n",
559 | "\n",
560 | " test_confs.extend(np.max(confs,axis=1)) \n",
561 | " test_preds.extend(preds)\n",
562 | " test_logits.extend(logits)\n",
563 | "print(\"Done\")"
564 | ]
565 | },
566 | {
567 | "cell_type": "markdown",
568 | "metadata": {},
569 | "source": [
570 | " Code for detecting OODs by identifying anomalies in correlations "
571 | ]
572 | },
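 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "As a reading aid, the layer-wise deviation accumulated by `get_deviations` (defined with the model above) can be restated as\n",
   "\n",
   "$$\\delta_l(x) = \\sum_{p}\\sum_{i} \\frac{\\mathrm{ReLU}(\\mathrm{min}_{l,p,i} - G^l_p(x)_i)}{|\\mathrm{min}_{l,p,i} + 10^{-6}|} + \\frac{\\mathrm{ReLU}(G^l_p(x)_i - \\mathrm{max}_{l,p,i})}{|\\mathrm{max}_{l,p,i} + 10^{-6}|},$$\n",
   "\n",
   "where $G^l_p(x)_i$ is the $i$-th order-$p$ Gram feature of layer $l$ and the min/max bounds are taken over training images that share the predicted class; the `Detector` below additionally divides each deviation by the predicted-class confidence."
  ]
 },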
573 | {
574 | "cell_type": "code",
575 | "execution_count": 16,
576 | "metadata": {},
577 | "outputs": [],
578 | "source": [
579 | "import calculate_log as callog\n",
580 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n",
581 | " average_results = {}\n",
582 | " for i in range(1,11):\n",
583 | " random.seed(i)\n",
584 | " \n",
585 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n",
586 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n",
587 | "\n",
588 | " validation = all_test_deviations[validation_indices]\n",
589 | " test_deviations = all_test_deviations[test_indices]\n",
590 | "\n",
591 | " t95 = validation.mean(axis=0)+10**-7\n",
592 | " if not normalize:\n",
593 | " t95 = np.ones_like(t95)\n",
594 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
595 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n",
596 | " \n",
597 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n",
598 | " for m in results:\n",
599 | " average_results[m] = average_results.get(m,0)+results[m]\n",
600 | " \n",
601 | " for m in average_results:\n",
602 | " average_results[m] /= i\n",
603 | " if verbose:\n",
604 | " callog.print_results(average_results)\n",
605 | " return average_results\n",
606 | "\n",
607 | "\n",
608 | "def cpu(ob):\n",
609 | " for i in range(len(ob)):\n",
610 | " for j in range(len(ob[i])):\n",
611 | " ob[i][j] = ob[i][j].cpu()\n",
612 | " return ob\n",
613 | " \n",
614 | "def cuda(ob):\n",
615 | " for i in range(len(ob)):\n",
616 | " for j in range(len(ob[i])):\n",
617 | " ob[i][j] = ob[i][j].cuda()\n",
618 | " return ob\n",
619 | "\n",
620 | "class Detector:\n",
621 | " def __init__(self):\n",
622 | " self.all_test_deviations = None\n",
623 | " self.mins = {}\n",
624 | " self.maxs = {}\n",
625 | " self.classes = range(10)\n",
626 | " \n",
627 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n",
628 | " for PRED in tqdm(self.classes):\n",
629 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n",
630 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n",
631 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n",
632 | " self.mins[PRED] = cpu(mins)\n",
633 | " self.maxs[PRED] = cpu(maxs)\n",
634 | " torch.cuda.empty_cache()\n",
635 | " \n",
636 | " def compute_test_deviations(self,POWERS=[10]):\n",
637 | " all_test_deviations = None\n",
638 | " for PRED in tqdm(self.classes):\n",
639 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n",
640 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n",
641 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n",
642 | " mins = cuda(self.mins[PRED])\n",
643 | " maxs = cuda(self.maxs[PRED])\n",
644 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n",
645 | " cpu(mins)\n",
646 | " cpu(maxs)\n",
647 | " if all_test_deviations is None:\n",
648 | " all_test_deviations = test_deviations\n",
649 | " else:\n",
650 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n",
651 | " torch.cuda.empty_cache()\n",
652 | " self.all_test_deviations = all_test_deviations\n",
653 | " \n",
654 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n",
655 | " ood_preds = []\n",
656 | " ood_confs = []\n",
657 | " \n",
658 | " for idx in range(0,len(ood),128):\n",
659 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n",
660 | " logits = torch_model(batch)\n",
661 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n",
662 | " preds = np.argmax(confs,axis=1)\n",
663 | " \n",
664 | " ood_confs.extend(np.max(confs,axis=1))\n",
665 | " ood_preds.extend(preds) \n",
666 | " torch.cuda.empty_cache()\n",
667 | " print(\"Done\")\n",
668 | " \n",
669 | " all_ood_deviations = None\n",
670 | " for PRED in tqdm(self.classes):\n",
671 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n",
672 | " if len(ood_indices)==0:\n",
673 | " continue\n",
674 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n",
675 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n",
676 | " mins = cuda(self.mins[PRED])\n",
677 | " maxs = cuda(self.maxs[PRED])\n",
678 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n",
679 | " cpu(self.mins[PRED])\n",
680 | " cpu(self.maxs[PRED]) \n",
681 | " if all_ood_deviations is None:\n",
682 | " all_ood_deviations = ood_deviations\n",
683 | " else:\n",
684 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n",
685 | " torch.cuda.empty_cache()\n",
686 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n",
687 | " return average_results, self.all_test_deviations, all_ood_deviations\n"
688 | ]
689 | },
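 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "`detect` above reports metrics averaged over ten random splits, each holding out 10% of the test deviations as a validation set; the validation set only supplies the per-layer means `t95` used to normalize each layer's deviation before summing.\n",
   "\n",
   "A toy scoring example (illustrative numbers only, not taken from the experiments): with `t95 = [2.0, 4.0]`, a sample whose per-layer deviations are `[1.0, 8.0]` receives the score $1.0/2.0 + 8.0/4.0 = 2.5$; samples with larger normalized deviations are ranked as more likely to be out-of-distribution."
  ]
 },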
690 | {
691 | "cell_type": "markdown",
692 | "metadata": {},
693 | "source": [
694 |     "## Results"
695 | ]
696 | },
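 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "The next cell defines `G_p`, which raises a layer's feature map elementwise to the power $p$, forms the channel-by-channel Gram matrix over spatial locations, keeps only the row sums of that matrix, and takes the signed $p$-th root to restore the original scale, giving one value per channel. A minimal shape check, assuming `torch` and `G_p` are already in scope (the dummy tensor is purely illustrative):\n",
   "\n",
   "```python\n",
   "feat = torch.randn(2, 8, 4, 4)   # (batch, channels, height, width)\n",
   "print(G_p(feat, p=2).shape)      # torch.Size([2, 8]), one Gram row-sum feature per channel\n",
   "```"
  ]
 },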
697 | {
698 | "cell_type": "code",
699 | "execution_count": 17,
700 | "metadata": {
701 | "scrolled": false
702 | },
703 | "outputs": [
704 | {
705 | "data": {
706 | "application/vnd.jupyter.widget-view+json": {
707 | "model_id": "2b50f073b57840a0bd22fb057602fc78",
708 | "version_major": 2,
709 | "version_minor": 0
710 | },
711 | "text/plain": [
712 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
713 | ]
714 | },
715 | "metadata": {},
716 | "output_type": "display_data"
717 | },
718 | {
719 | "name": "stdout",
720 | "output_type": "stream",
721 | "text": [
722 | "\n"
723 | ]
724 | },
725 | {
726 | "data": {
727 | "application/vnd.jupyter.widget-view+json": {
728 | "model_id": "5fc60ad5623a4ef2b163dbdf8562d051",
729 | "version_major": 2,
730 | "version_minor": 0
731 | },
732 | "text/plain": [
733 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
734 | ]
735 | },
736 | "metadata": {},
737 | "output_type": "display_data"
738 | },
739 | {
740 | "name": "stdout",
741 | "output_type": "stream",
742 | "text": [
743 | "\n",
744 | "iSUN\n",
745 | "Done\n"
746 | ]
747 | },
748 | {
749 | "data": {
750 | "application/vnd.jupyter.widget-view+json": {
751 | "model_id": "59d3443d14a04e2dba5374e894cbdcb3",
752 | "version_major": 2,
753 | "version_minor": 0
754 | },
755 | "text/plain": [
756 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
757 | ]
758 | },
759 | "metadata": {},
760 | "output_type": "display_data"
761 | },
762 | {
763 | "name": "stdout",
764 | "output_type": "stream",
765 | "text": [
766 | "\n",
767 | " TNR AUROC DTACC AUIN AUOUT \n",
768 | " 99.030 99.800 97.935 99.795 99.802\n",
769 | "LSUN (R)\n",
770 | "Done\n"
771 | ]
772 | },
773 | {
774 | "data": {
775 | "application/vnd.jupyter.widget-view+json": {
776 | "model_id": "b704f16c4a6b47fc89194c4952de7f88",
777 | "version_major": 2,
778 | "version_minor": 0
779 | },
780 | "text/plain": [
781 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
782 | ]
783 | },
784 | "metadata": {},
785 | "output_type": "display_data"
786 | },
787 | {
788 | "name": "stdout",
789 | "output_type": "stream",
790 | "text": [
791 | "\n",
792 | " TNR AUROC DTACC AUIN AUOUT \n",
793 | " 99.476 99.881 98.633 99.858 99.894\n",
794 | "LSUN (C)\n",
795 | "Done\n"
796 | ]
797 | },
798 | {
799 | "data": {
800 | "application/vnd.jupyter.widget-view+json": {
801 | "model_id": "d1213c56bf9e43c29a250c7b73578898",
802 | "version_major": 2,
803 | "version_minor": 0
804 | },
805 | "text/plain": [
806 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
807 | ]
808 | },
809 | "metadata": {},
810 | "output_type": "display_data"
811 | },
812 | {
813 | "name": "stdout",
814 | "output_type": "stream",
815 | "text": [
816 | "\n",
817 | " TNR AUROC DTACC AUIN AUOUT \n",
818 | " 88.383 97.446 91.987 96.434 97.914\n",
819 | "TinyImgNet (R)\n",
820 | "Done\n"
821 | ]
822 | },
823 | {
824 | "data": {
825 | "application/vnd.jupyter.widget-view+json": {
826 | "model_id": "ce16d61aec4f4acf945cc17c65f6d79c",
827 | "version_major": 2,
828 | "version_minor": 0
829 | },
830 | "text/plain": [
831 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
832 | ]
833 | },
834 | "metadata": {},
835 | "output_type": "display_data"
836 | },
837 | {
838 | "name": "stdout",
839 | "output_type": "stream",
840 | "text": [
841 | "\n",
842 | " TNR AUROC DTACC AUIN AUOUT \n",
843 | " 98.783 99.714 97.891 99.562 99.769\n",
844 | "TinyImgNet (C)\n",
845 | "Done\n"
846 | ]
847 | },
848 | {
849 | "data": {
850 | "application/vnd.jupyter.widget-view+json": {
851 | "model_id": "352bca808e51495eb060301d59df88db",
852 | "version_major": 2,
853 | "version_minor": 0
854 | },
855 | "text/plain": [
856 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
857 | ]
858 | },
859 | "metadata": {},
860 | "output_type": "display_data"
861 | },
862 | {
863 | "name": "stdout",
864 | "output_type": "stream",
865 | "text": [
866 | "\n",
867 | " TNR AUROC DTACC AUIN AUOUT \n",
868 | " 96.693 99.253 96.137 98.957 99.390\n",
869 | "SVHN\n",
870 | "Done\n"
871 | ]
872 | },
873 | {
874 | "data": {
875 | "application/vnd.jupyter.widget-view+json": {
876 | "model_id": "06cd21e80d4e4992958854774563e5fa",
877 | "version_major": 2,
878 | "version_minor": 0
879 | },
880 | "text/plain": [
881 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
882 | ]
883 | },
884 | "metadata": {},
885 | "output_type": "display_data"
886 | },
887 | {
888 | "name": "stdout",
889 | "output_type": "stream",
890 | "text": [
891 | "\n",
892 | " TNR AUROC DTACC AUIN AUOUT \n",
893 | " 96.072 99.126 95.863 96.782 99.709\n",
894 | "CIFAR-100\n",
895 | "Done\n"
896 | ]
897 | },
898 | {
899 | "data": {
900 | "application/vnd.jupyter.widget-view+json": {
901 | "model_id": "368efd02834c401ca641c579f9f3ab94",
902 | "version_major": 2,
903 | "version_minor": 0
904 | },
905 | "text/plain": [
906 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))"
907 | ]
908 | },
909 | "metadata": {},
910 | "output_type": "display_data"
911 | },
912 | {
913 | "name": "stdout",
914 | "output_type": "stream",
915 | "text": [
916 | "\n",
917 | " TNR AUROC DTACC AUIN AUOUT \n",
918 | " 26.683 72.043 67.226 61.419 75.722\n"
919 | ]
920 | }
921 | ],
922 | "source": [
923 | "def G_p(ob, p):\n",
924 | " temp = ob.detach()\n",
925 | " \n",
926 | " temp = temp**p\n",
927 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n",
928 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n",
929 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n",
930 | " \n",
931 | " return temp\n",
932 | "\n",
933 | "detector = Detector()\n",
934 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n",
935 | "\n",
936 | "detector.compute_test_deviations(POWERS=range(1,11))\n",
937 | "\n",
938 | "print(\"iSUN\")\n",
939 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n",
940 | "print(\"LSUN (R)\")\n",
941 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n",
942 | "print(\"LSUN (C)\")\n",
943 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n",
944 | "print(\"TinyImgNet (R)\")\n",
945 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n",
946 | "print(\"TinyImgNet (C)\")\n",
947 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n",
948 | "print(\"SVHN\")\n",
949 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n",
950 | "print(\"CIFAR-100\")\n",
951 | "c100_results = detector.compute_ood_deviations(cifar100,POWERS=range(1,11))"
952 | ]
953 | }
954 | ],
955 | "metadata": {
956 | "kernelspec": {
957 |    "display_name": "Python 3",
958 |    "language": "python",
959 |    "name": "python3"
960 | },
961 | "language_info": {
962 | "codemirror_mode": {
963 | "name": "ipython",
964 | "version": 3
965 | },
966 | "file_extension": ".py",
967 | "mimetype": "text/x-python",
968 | "name": "python",
969 | "nbconvert_exporter": "python",
970 | "pygments_lexer": "ipython3",
971 | "version": "3.6.9"
972 | }
973 | },
974 | "nbformat": 4,
975 | "nbformat_minor": 2
976 | }
977 |
--------------------------------------------------------------------------------