├── OE_results.png ├── detailed_results.png ├── setup.sh ├── README.md ├── densenet.py ├── calculate_log.py ├── ResNet_SVHN.ipynb ├── ResNet_Cifar10.ipynb ├── ResNet_Cifar100.ipynb ├── DenseNet_SVHN.ipynb ├── DenseNet_Cifar100.ipynb └── DenseNet_Cifar10.ipynb /OE_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/gram-ood-detection/HEAD/OE_results.png -------------------------------------------------------------------------------- /detailed_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VectorInstitute/gram-ood-detection/HEAD/detailed_results.png -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz 2 | tar -xzf Imagenet.tar.gz 3 | 4 | wget https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz 5 | tar -xzf Imagenet_resize.tar.gz 6 | 7 | wget https://www.dropbox.com/s/fhtsw1m3qxlwj6h/LSUN.tar.gz 8 | tar -xzf LSUN.tar.gz 9 | 10 | wget https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz 11 | tar -xzf LSUN_resize.tar.gz 12 | 13 | wget https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz 14 | tar -xzf iSUN.tar.gz 15 | 16 | rm *.gz 17 | 18 | wget https://www.dropbox.com/s/pnbvr16gnpyr1zg/densenet_cifar10.pth 19 | wget https://www.dropbox.com/s/7ur9qo81u30od36/densenet_cifar100.pth 20 | wget https://www.dropbox.com/s/9ol1h2tb3xjdpp1/densenet_svhn.pth 21 | wget https://www.dropbox.com/s/ynidbn7n7ccadog/resnet_cifar10.pth 22 | wget https://www.dropbox.com/s/yzfzf4bwqe4du6w/resnet_cifar100.pth 23 | wget https://www.dropbox.com/s/uvgpgy9pu7s9ps2/resnet_svhn.pth 24 | 25 | 26 | wget 
https://raw.githubusercontent.com/hendrycks/outlier-exposure/master/CIFAR/snapshots/oe_scratch/cifar100_wrn_oe_scratch_epoch_99.pt 27 | wget https://raw.githubusercontent.com/hendrycks/outlier-exposure/master/CIFAR/snapshots/oe_scratch/cifar10_wrn_oe_scratch_epoch_99.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Detecting Out-of-Distribution Examples with In-distribution Examples and Gram Matrices 2 | ICML 2020: paper, supplementary material and bibtex available at http://proceedings.mlr.press/v119/sastry20a.html 3 | 4 | ## Dependencies 5 | The code is written in Python 3 with PyTorch 1.1. 6 | 7 | ## Results 8 | ![results](https://github.com/VectorInstitute/gram-ood-detection/blob/master/detailed_results.png) 9 | 10 | (Please refer to this [repository](https://github.com/chandramouli-sastry/deep_Mahalanobis_detector) for the results of Baseline/ODIN/Mahalanobis on dataset-pairs not presented in the Mahalanobis paper) 11 | 12 | ## Combining Outlier Exposure (OE) and Ours 13 | ![results](https://github.com/VectorInstitute/gram-ood-detection/blob/master/OE_results.png) 14 | 15 | ## Downloading Out-of-Distribution Datasets and Pre-trained Models 16 | We used the out-of-distribution datasets presented in [odin-pytorch](https://github.com/facebookresearch/odin) 17 | 18 | We used pre-trained neural networks open-sourced by [Mahalanobis](https://github.com/pokaxpoka/deep_Mahalanobis_detector/) and [odin-pytorch](https://github.com/ShiyuLiang/odin-pytorch). The DenseNets trained on CIFAR-10 and CIFAR-100 are by ODIN; the remaining networks are by Mahalanobis. 19 | 20 | For experiments on OE-trained networks, we used the pre-trained networks open-sourced by [OE](https://github.com/hendrycks/outlier-exposure) 21 | 22 | Running setup.sh downloads the Out-of-Distribution Datasets and pre-trained models. 
"""DenseNet with three dense blocks, plus helpers exposing intermediate features."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """BN -> ReLU -> 3x3 conv; the result is concatenated onto the input.

    Args:
        in_planes: number of input channels.
        out_planes: number of channels the block adds.
        dropRate: dropout probability applied to the new channels (0 disables).
    """

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        """Return ``x`` with ``out_planes`` freshly computed channels appended."""
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        return torch.cat([x, out], 1)


class BottleneckBlock(nn.Module):
    """BN -> ReLU -> 1x1 conv (to ``4 * out_planes``) -> BN -> ReLU -> 3x3 conv.

    As in ``BasicBlock`` the ``out_planes`` new channels are concatenated onto
    the input.
    """

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlock, self).__init__()
        inter_planes = out_planes * 4
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        """Return ``x`` with ``out_planes`` freshly computed channels appended."""
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return torch.cat([x, out], 1)


class TransitionBlock(nn.Module):
    """BN -> ReLU -> 1x1 conv to ``out_planes`` channels, then 2x2 average pooling."""

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        """Change the channel count to ``out_planes`` and halve the spatial size."""
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return F.avg_pool2d(out, 2)


class DenseBlock(nn.Module):
    """A chain of ``nb_layers`` blocks, each adding ``growth_rate`` channels."""

    def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
        super(DenseBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)

    def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
        # Layer i sees the block input plus the channels added by layers 0..i-1.
        layers = [block(in_planes + i * growth_rate, growth_rate, dropRate)
                  for i in range(int(nb_layers))]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class DenseNet3(nn.Module):
    """DenseNet made of three dense blocks separated by two transition blocks.

    Args:
        depth: nominal network depth; every dense block gets
            ``(depth - 4) // 3`` layers, halved when ``bottleneck`` is set.
        num_classes: size of the final linear layer's output.
        growth_rate: channels added by each layer inside a dense block.
        reduction: factor applied to the channel count by each transition.
        bottleneck: use ``BottleneckBlock`` instead of ``BasicBlock``.
        dropRate: dropout probability forwarded to every block.

    The classifier pools with a fixed 8x8 window, so every sample collapses to
    a single ``in_planes`` vector when the last feature map is 8x8 -- which is
    what 32x32 inputs produce after the two 2x transition poolings.
    """

    def __init__(self, depth, num_classes, growth_rate=12,
                 reduction=0.5, bottleneck=True, dropRate=0.0):
        super(DenseNet3, self).__init__()
        in_planes = 2 * growth_rate
        # Integer layer count: DenseBlock truncates its layer count with int(),
        # so a fractional value here would make the channel bookkeeping below
        # disagree with the layers that are actually built.
        n = (depth - 4) // 3
        if bottleneck:
            n //= 2
            block = BottleneckBlock
        else:
            block = BasicBlock
        # 1st conv before any dense block
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        reduced = int(math.floor(in_planes * reduction))
        self.trans1 = TransitionBlock(in_planes, reduced, dropRate=dropRate)
        in_planes = reduced
        # 2nd block
        self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        reduced = int(math.floor(in_planes * reduction))
        self.trans2 = TransitionBlock(in_planes, reduced, dropRate=dropRate)
        in_planes = reduced
        # 3rd block
        self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(in_planes, num_classes)
        self.in_planes = in_planes

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def _stage_outputs(self, x, last_stage=3):
        """Run the network up to ``last_stage`` and return every stage output.

        Stage 0 is the first conv, 1 and 2 are the outputs of the two
        transitions, 3 is the BN+ReLU output that follows the third dense block.
        """
        outs = [self.conv1(x)]
        if last_stage >= 1:
            outs.append(self.trans1(self.block1(outs[-1])))
        if last_stage >= 2:
            outs.append(self.trans2(self.block2(outs[-1])))
        if last_stage >= 3:
            outs.append(self.relu(self.bn1(self.block3(outs[-1]))))
        return outs

    def _classify(self, feat):
        """Average-pool the stage-3 feature map (8x8 window) and apply the classifier."""
        out = F.avg_pool2d(feat, 8)
        out = out.view(-1, self.in_planes)
        return self.fc(out)

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self._classify(self._stage_outputs(x)[-1])

    # function to extract the multiple features
    def feature_list(self, x):
        """Return ``(logits, [stage0, stage1, stage2, stage3])`` for ``x``."""
        out_list = self._stage_outputs(x)
        return self._classify(out_list[-1]), out_list

    def intermediate_forward(self, x, layer_index):
        """Return the output of stage ``layer_index`` (1, 2 or 3).

        Any other ``layer_index`` yields the output of the first conv.
        """
        stage = layer_index if layer_index in (1, 2, 3) else 0
        return self._stage_outputs(x, last_stage=stage)[-1]

    # function to extract the penultimate features
    def penultimate_forward(self, x):
        """Return ``(logits, penultimate)``; ``penultimate`` is the stage-3 feature map."""
        penultimate = self._stage_outputs(x)[-1]
        return self._classify(penultimate), penultimate
def penultimate_forward(self, x): 154 | out = self.conv1(x) 155 | out = self.trans1(self.block1(out)) 156 | out = self.trans2(self.block2(out)) 157 | out = self.block3(out) 158 | penultimate = self.relu(self.bn1(out)) 159 | out = F.avg_pool2d(penultimate, 8) 160 | out = out.view(-1, self.in_planes) 161 | return self.fc(out), penultimate -------------------------------------------------------------------------------- /calculate_log.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function,division 2 | import torch 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import torch.optim as optim 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import numpy as np 11 | import time 12 | from scipy import misc 13 | 14 | import matplotlib 15 | # matplotlib.use('Agg') 16 | import matplotlib.pyplot as plt 17 | 18 | def compute_metric(known, novel): 19 | stype = "" 20 | 21 | tp, fp = dict(), dict() 22 | tnr_at_tpr95 = dict() 23 | 24 | known.sort() 25 | novel.sort() 26 | end = np.max([np.max(known), np.max(novel)]) 27 | start = np.min([np.min(known),np.min(novel)]) 28 | num_k = known.shape[0] 29 | num_n = novel.shape[0] 30 | tp[stype] = -np.ones([num_k+num_n+1], dtype=int) 31 | fp[stype] = -np.ones([num_k+num_n+1], dtype=int) 32 | tp[stype][0], fp[stype][0] = num_k, num_n 33 | k, n = 0, 0 34 | for l in range(num_k+num_n): 35 | if k == num_k: 36 | tp[stype][l+1:] = tp[stype][l] 37 | fp[stype][l+1:] = np.arange(fp[stype][l]-1, -1, -1) 38 | break 39 | elif n == num_n: 40 | tp[stype][l+1:] = np.arange(tp[stype][l]-1, -1, -1) 41 | fp[stype][l+1:] = fp[stype][l] 42 | break 43 | else: 44 | if novel[n] < known[k]: 45 | n += 1 46 | tp[stype][l+1] = tp[stype][l] 47 | fp[stype][l+1] = fp[stype][l] - 1 48 | else: 49 | k += 1 50 | tp[stype][l+1] = tp[stype][l] - 1 51 | fp[stype][l+1] = fp[stype][l] 52 | tpr95_pos = 
# NumPy 2.0 renamed ``trapz`` to ``trapezoid``; use whichever this NumPy provides.
_trapezoid = getattr(np, 'trapezoid', None) or np.trapz

_METRIC_NAMES = ('TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT')


def _threshold_counts(known, novel):
    """Count, for every threshold position, the scores that are still above it.

    The thresholds sweep upward through the merged, sorted scores.  Entry 0 of
    each returned array is the full count; every following entry has one more
    score removed.  On ties a ``known`` score is removed before a ``novel``
    one.  Scores are assumed to be NaN-free.

    Returns:
        ``(tp, fp, tnr_at_tpr95)``: integer arrays of length
        ``len(known) + len(novel) + 1`` holding the remaining known / novel
        counts, and the true-negative rate at the position whose true-positive
        rate is closest to 0.95 (first such position).

    Raises:
        ValueError: if either input is empty.
    """
    known = np.asarray(known)
    novel = np.asarray(novel)
    num_k, num_n = known.shape[0], novel.shape[0]
    if num_k == 0 or num_n == 0:
        raise ValueError('known and novel must each contain at least one score')
    # A stable sort of [known, novel] keeps known before novel among equal
    # values; the inputs themselves are left untouched.
    order = np.argsort(np.concatenate([known, novel]), kind='mergesort')
    from_known = order < num_k
    tp = np.concatenate([[num_k], num_k - np.cumsum(from_known)])
    fp = np.concatenate([[num_n], num_n - np.cumsum(~from_known)])
    tpr95_pos = np.abs(tp / float(num_k) - .95).argmin()
    tnr_at_tpr95 = 1. - fp[tpr95_pos] / float(num_n)
    return tp, fp, tnr_at_tpr95


def _summarize(tp, fp, tnr_at_tpr95):
    """Turn the count curves of ``_threshold_counts`` into the five metrics."""
    tp = np.asarray(tp, dtype=float)
    fp = np.asarray(fp, dtype=float)
    tpr_curve = tp / tp[0]
    fpr_curve = fp / fp[0]
    tpr = np.concatenate([[1.], tpr_curve, [0.]])
    fpr = np.concatenate([[1.], fpr_curve, [0.]])

    results = dict()
    results['TNR'] = tnr_at_tpr95
    # tpr runs from 1 down to 0, hence the sign flip on the integral.
    results['AUROC'] = -_trapezoid(1. - fpr, tpr)
    results['DTACC'] = .5 * (tpr_curve + 1. - fpr_curve).max()

    # AUIN: positions where nothing is left above the threshold (denominator 0)
    # are masked out of the integral.
    denom = tp + fp
    denom[denom == 0.] = -1.
    pin_ind = np.concatenate([[True], denom > 0., [True]])
    pin = np.concatenate([[.5], tp / denom, [0.]])
    results['AUIN'] = -_trapezoid(pin[pin_ind], tpr[pin_ind])

    # AUOUT: same construction for the scores at or below the threshold.
    denom = tp[0] - tp + fp[0] - fp
    denom[denom == 0.] = -1.
    pout_ind = np.concatenate([[True], denom > 0., [True]])
    pout = np.concatenate([[0.], (fp[0] - fp) / denom, [.5]])
    results['AUOUT'] = _trapezoid(pout[pout_ind], 1. - fpr[pout_ind])
    return results


def compute_metric(known, novel):
    """Compute TNR, AUROC, DTACC, AUIN and AUOUT for two score arrays.

    Args:
        known: 1-D scores of the first group; the curves treat them as the
            positives (``tp``).
        novel: 1-D scores of the second group (``fp``).

    Returns:
        dict mapping 'TNR', 'AUROC', 'DTACC', 'AUIN', 'AUOUT' to fractions
        (not percentages).  Neither input array is modified.

    Raises:
        ValueError: if either input is empty.
    """
    return _summarize(*_threshold_counts(known, novel))


def print_results(results):
    """Print the metric names and, below them, the values as percentages."""
    print(''.join(' {mtype:6s}'.format(mtype=mtype) for mtype in _METRIC_NAMES))
    print(''.join(' {val:6.3f}'.format(val=100. * results[mtype])
                  for mtype in _METRIC_NAMES))


def get_curve(dir_name, stypes=('Baseline', 'Gaussian_LDA')):
    """Load the score files of each ``stype`` and build their count curves.

    For every ``stype`` the files ``confidence_<stype>_In.txt`` and
    ``confidence_<stype>_Out.txt`` inside ``dir_name`` are read (one score per
    line).

    Returns:
        ``(tp, fp, tnr_at_tpr95)``, each a dict keyed by ``stype``.
    """
    tp, fp, tnr_at_tpr95 = dict(), dict(), dict()
    for stype in stypes:
        # Default whitespace splitting reads one value per line; NumPy >= 1.23
        # rejects an explicit newline delimiter.  ndmin=1 keeps a one-line file
        # a 1-D array.
        known = np.loadtxt('{}/confidence_{}_In.txt'.format(dir_name, stype), ndmin=1)
        novel = np.loadtxt('{}/confidence_{}_Out.txt'.format(dir_name, stype), ndmin=1)
        tp[stype], fp[stype], tnr_at_tpr95[stype] = _threshold_counts(known, novel)
    return tp, fp, tnr_at_tpr95


def metric(dir_name, stypes=('Bas', 'Gau'), verbose=False):
    """Compute the five metrics for every ``stype`` from files in ``dir_name``.

    Args:
        dir_name: directory holding the ``confidence_<stype>_{In,Out}.txt`` files.
        stypes: score types to evaluate.
        verbose: also print a table with one row of percentages per ``stype``.

    Returns:
        dict mapping each ``stype`` to its metric dict (see ``compute_metric``).
    """
    tp, fp, tnr_at_tpr95 = get_curve(dir_name, stypes)
    results = dict()
    if verbose:
        print(' ' + ''.join(' {mtype:6s}'.format(mtype=mtype)
                            for mtype in _METRIC_NAMES))

    for stype in stypes:
        results[stype] = _summarize(tp[stype], fp[stype], tnr_at_tpr95[stype])
        if verbose:
            print('{stype:5s} '.format(stype=stype)
                  + ''.join(' {val:6.3f}'.format(val=100. * results[stype][mtype])
                            for mtype in _METRIC_NAMES))

    return results
188 | pout_ind = np.concatenate([[True], denom > 0., [True]]) 189 | pout = np.concatenate([[0.], (fp[stype][0]-fp[stype])/denom, [.5]]) 190 | results[stype][mtype] = np.trapz(pout[pout_ind], 1.-fpr[pout_ind]) 191 | if verbose: 192 | print(' {val:6.3f}'.format(val=100.*results[stype][mtype]), end='') 193 | print('') 194 | 195 | return results -------------------------------------------------------------------------------- /ResNet_SVHN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

ResNet: SVHN

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import division,print_function\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import sys\n", 30 | "from tqdm import tqdm_notebook as tqdm\n", 31 | "\n", 32 | "import random\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import math\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import torch.optim as optim\n", 42 | "import torch.nn.init as init\n", 43 | "from torch.autograd import Variable, grad\n", 44 | "from torchvision import datasets, transforms\n", 45 | "from torch.nn.parameter import Parameter\n", 46 | "\n", 47 | "import calculate_log as callog\n", 48 | "\n", 49 | "import warnings\n", 50 | "warnings.filterwarnings('ignore')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.set_device(0) #Select the GPU" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Model definition" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": { 73 | "scrolled": true 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "Done\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "def conv3x3(in_planes, out_planes, stride=1):\n", 86 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 87 | "\n", 88 | "class BasicBlock(nn.Module):\n", 89 | " expansion = 1\n", 90 | "\n", 91 | " def __init__(self, in_planes, planes, stride=1):\n", 92 | " super(BasicBlock, 
self).__init__()\n", 93 | " self.conv1 = conv3x3(in_planes, planes, stride)\n", 94 | " self.bn1 = nn.BatchNorm2d(planes)\n", 95 | " self.conv2 = conv3x3(planes, planes)\n", 96 | " self.bn2 = nn.BatchNorm2d(planes)\n", 97 | "\n", 98 | " self.shortcut = nn.Sequential()\n", 99 | " if stride != 1 or in_planes != self.expansion*planes:\n", 100 | " self.shortcut = nn.Sequential(\n", 101 | " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", 102 | " nn.BatchNorm2d(self.expansion*planes)\n", 103 | " )\n", 104 | " \n", 105 | " def forward(self, x):\n", 106 | " t = self.conv1(x)\n", 107 | " out = F.relu(self.bn1(t))\n", 108 | " torch_model.record(t)\n", 109 | " torch_model.record(out)\n", 110 | " t = self.conv2(out)\n", 111 | " out = self.bn2(self.conv2(out))\n", 112 | " torch_model.record(t)\n", 113 | " torch_model.record(out)\n", 114 | " t = self.shortcut(x)\n", 115 | " out += t\n", 116 | " torch_model.record(t)\n", 117 | " out = F.relu(out)\n", 118 | " torch_model.record(out)\n", 119 | " \n", 120 | " return out#, out_list\n", 121 | "\n", 122 | "class ResNet(nn.Module):\n", 123 | " def __init__(self, block, num_blocks, num_classes=10):\n", 124 | " super(ResNet, self).__init__()\n", 125 | " self.in_planes = 64\n", 126 | "\n", 127 | " self.conv1 = conv3x3(3,64)\n", 128 | " self.bn1 = nn.BatchNorm2d(64)\n", 129 | " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n", 130 | " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n", 131 | " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n", 132 | " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n", 133 | " self.linear = nn.Linear(512*block.expansion, num_classes)\n", 134 | " \n", 135 | " self.collecting = False\n", 136 | " \n", 137 | " def _make_layer(self, block, planes, num_blocks, stride):\n", 138 | " strides = [stride] + [1]*(num_blocks-1)\n", 139 | " layers = []\n", 140 | " for stride in strides:\n", 141 | " 
layers.append(block(self.in_planes, planes, stride))\n", 142 | " self.in_planes = planes * block.expansion\n", 143 | " return nn.Sequential(*layers)\n", 144 | " \n", 145 | " def forward(self, x):\n", 146 | " out = F.relu(self.bn1(self.conv1(x)))\n", 147 | " out = self.layer1(out)\n", 148 | " out = self.layer2(out)\n", 149 | " out = self.layer3(out)\n", 150 | " out = self.layer4(out)\n", 151 | " out = F.avg_pool2d(out, 4)\n", 152 | " out = out.view(out.size(0), -1)\n", 153 | " y = self.linear(out)\n", 154 | " return y\n", 155 | " \n", 156 | " def record(self, t):\n", 157 | " if self.collecting:\n", 158 | " self.gram_feats.append(t)\n", 159 | " \n", 160 | " def gram_feature_list(self,x):\n", 161 | " self.collecting = True\n", 162 | " self.gram_feats = []\n", 163 | " self.forward(x)\n", 164 | " self.collecting = False\n", 165 | " temp = self.gram_feats\n", 166 | " self.gram_feats = []\n", 167 | " return temp\n", 168 | " \n", 169 | " def load(self, path=\"resnet_svhn.pth\"):\n", 170 | " tm = torch.load(path,map_location=\"cpu\") \n", 171 | " self.load_state_dict(tm)\n", 172 | " \n", 173 | " def get_min_max(self, data, power):\n", 174 | " mins = []\n", 175 | " maxs = []\n", 176 | " \n", 177 | " for i in range(0,len(data),128):\n", 178 | " batch = data[i:i+128].cuda()\n", 179 | " feat_list = self.gram_feature_list(batch)\n", 180 | " for L,feat_L in enumerate(feat_list):\n", 181 | " if L==len(mins):\n", 182 | " mins.append([None]*len(power))\n", 183 | " maxs.append([None]*len(power))\n", 184 | " \n", 185 | " for p,P in enumerate(power):\n", 186 | " g_p = G_p(feat_L,P)\n", 187 | " \n", 188 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n", 189 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n", 190 | " \n", 191 | " if mins[L][p] is None:\n", 192 | " mins[L][p] = current_min\n", 193 | " maxs[L][p] = current_max\n", 194 | " else:\n", 195 | " mins[L][p] = torch.min(current_min,mins[L][p])\n", 196 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n", 197 | " \n", 198 | " 
return mins,maxs\n", 199 | " \n", 200 | " def get_deviations(self,data,power,mins,maxs):\n", 201 | " deviations = []\n", 202 | " \n", 203 | " for i in range(0,len(data),128): \n", 204 | " batch = data[i:i+128].cuda()\n", 205 | " feat_list = self.gram_feature_list(batch)\n", 206 | " batch_deviations = []\n", 207 | " for L,feat_L in enumerate(feat_list):\n", 208 | " dev = 0\n", 209 | " for p,P in enumerate(power):\n", 210 | " g_p = G_p(feat_L,P)\n", 211 | " \n", 212 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 213 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 214 | " batch_deviations.append(dev.cpu().detach().numpy())\n", 215 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n", 216 | " deviations.append(batch_deviations)\n", 217 | " deviations = np.concatenate(deviations,axis=0)\n", 218 | " \n", 219 | " return deviations\n", 220 | "\n", 221 | "torch_model = ResNet(BasicBlock, [3,4,6,3], num_classes=10)\n", 222 | "torch_model.load()\n", 223 | "torch_model.cuda()\n", 224 | "torch_model.params = list(torch_model.parameters())\n", 225 | "torch_model.eval()\n", 226 | "print(\"Done\") " 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "## Datasets" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "In-distribution Datasets" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 4, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "Using downloaded and verified file: data/train_32x32.mat\n", 253 | "Using downloaded and verified file: data/test_32x32.mat\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "batch_size = 128\n", 259 | "mean = np.array([[0.4914, 0.4822, 0.4465]]).T\n", 260 | "\n", 261 | "std = np.array([[0.2023, 0.1994, 0.2010]]).T\n", 262 | "normalize = 
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n", 263 | "\n", 264 | "transform_train = transforms.Compose([\n", 265 | " transforms.RandomCrop(32, padding=4),\n", 266 | " transforms.RandomHorizontalFlip(),\n", 267 | " transforms.ToTensor(),\n", 268 | " normalize\n", 269 | " \n", 270 | " ])\n", 271 | "transform_test = transforms.Compose([\n", 272 | " transforms.CenterCrop(size=(32, 32)),\n", 273 | " transforms.ToTensor(),\n", 274 | " normalize\n", 275 | " ])\n", 276 | "\n", 277 | "\n", 278 | "train_loader = torch.utils.data.DataLoader(\n", 279 | " datasets.SVHN('data', split=\"train\", download=True,\n", 280 | " transform=transform_train),\n", 281 | " batch_size=batch_size, shuffle=True)\n", 282 | "test_loader = torch.utils.data.DataLoader(\n", 283 | " datasets.SVHN('data', split=\"test\", download=True, transform=transform_test),\n", 284 | " batch_size=batch_size)\n" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 5, 290 | "metadata": { 291 | "scrolled": true 292 | }, 293 | "outputs": [ 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "Using downloaded and verified file: data/train_32x32.mat\n" 299 | ] 300 | } 301 | ], 302 | "source": [ 303 | "data_train = list(list(torch.utils.data.DataLoader(\n", 304 | " datasets.SVHN('data', split=\"train\", download=True,\n", 305 | " transform=transform_test),\n", 306 | " batch_size=1, shuffle=True)))" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 6, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "Using downloaded and verified file: data/test_32x32.mat\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "data = list(list(torch.utils.data.DataLoader(\n", 324 | " datasets.SVHN('data', split=\"test\", download=True,\n", 325 | " transform=transform_test),\n", 326 | " batch_size=1, shuffle=False)))" 327 | ] 328 | }, 329 | { 330 | 
"cell_type": "code", 331 | "execution_count": 7, 332 | "metadata": {}, 333 | "outputs": [ 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "Accuracy: 0.9668484941610326\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "torch_model.eval()\n", 344 | "correct = 0\n", 345 | "total = 0\n", 346 | "for x,y in test_loader:\n", 347 | " x = x.cuda()\n", 348 | " y = y.numpy()\n", 349 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n", 350 | " total += y.shape[0]\n", 351 | "print(\"Accuracy: \",correct/total)\n" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "Out-of-distribution Datasets" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 8, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "name": "stdout", 368 | "output_type": "stream", 369 | "text": [ 370 | "Files already downloaded and verified\n" 371 | ] 372 | } 373 | ], 374 | "source": [ 375 | "cifar10 = list(torch.utils.data.DataLoader(\n", 376 | " datasets.CIFAR10('data', train=False, download=True,\n", 377 | " transform=transform_test),\n", 378 | " batch_size=1, shuffle=True))" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 9, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "isun = list(torch.utils.data.DataLoader(\n", 388 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 10, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "lsun_c = list(torch.utils.data.DataLoader(\n", 398 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 11, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "lsun_r = list(torch.utils.data.DataLoader(\n", 408 | " 
datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 12, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n", 418 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 13, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n", 428 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": {}, 434 | "source": [ 435 | "## Code for Detecting OODs" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | " Extract predictions for train and test data " 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 14, 448 | "metadata": { 449 | "scrolled": true 450 | }, 451 | "outputs": [ 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "Done\n", 457 | "Done\n" 458 | ] 459 | } 460 | ], 461 | "source": [ 462 | "train_preds = []\n", 463 | "train_confs = []\n", 464 | "train_logits = []\n", 465 | "for idx in range(0,len(data_train),128):\n", 466 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n", 467 | " \n", 468 | " logits = torch_model(batch)\n", 469 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 470 | " preds = np.argmax(confs,axis=1)\n", 471 | " logits = (logits.cpu().detach().numpy())\n", 472 | "\n", 473 | " train_confs.extend(np.max(confs,axis=1)) \n", 474 | " train_preds.extend(preds)\n", 475 | " train_logits.extend(logits)\n", 476 | "print(\"Done\")\n", 477 | "\n", 478 | "test_preds = []\n", 479 | "test_confs = []\n", 480 | 
"test_logits = []\n", 481 | "\n", 482 | "for idx in range(0,len(data),128):\n", 483 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n", 484 | " \n", 485 | " logits = torch_model(batch)\n", 486 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 487 | " preds = np.argmax(confs,axis=1)\n", 488 | " logits = (logits.cpu().detach().numpy())\n", 489 | "\n", 490 | " test_confs.extend(np.max(confs,axis=1)) \n", 491 | " test_preds.extend(preds)\n", 492 | " test_logits.extend(logits)\n", 493 | "print(\"Done\")" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | " Code for detecting OODs by identifying anomalies in correlations " 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 15, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "import calculate_log as callog\n", 510 | "\n", 511 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n", 512 | " average_results = {}\n", 513 | " for i in range(1,11):\n", 514 | " random.seed(i)\n", 515 | " \n", 516 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n", 517 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n", 518 | "\n", 519 | " validation = all_test_deviations[validation_indices]\n", 520 | " test_deviations = all_test_deviations[test_indices]\n", 521 | "\n", 522 | " t95 = validation.mean(axis=0)+10**-7\n", 523 | " if not normalize:\n", 524 | " t95 = np.ones_like(t95)\n", 525 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 526 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 527 | " \n", 528 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n", 529 | " for m in results:\n", 530 | " average_results[m] = average_results.get(m,0)+results[m]\n", 531 | " \n", 532 | " for m in 
average_results:\n", 533 | " average_results[m] /= i\n", 534 | " if verbose:\n", 535 | " callog.print_results(average_results)\n", 536 | " return average_results\n", 537 | "\n", 538 | "def cpu(ob):\n", 539 | " for i in range(len(ob)):\n", 540 | " for j in range(len(ob[i])):\n", 541 | " ob[i][j] = ob[i][j].cpu()\n", 542 | " return ob\n", 543 | "\n", 544 | "def cuda(ob):\n", 545 | " for i in range(len(ob)):\n", 546 | " for j in range(len(ob[i])):\n", 547 | " ob[i][j] = ob[i][j].cuda()\n", 548 | " return ob\n", 549 | "\n", 550 | "class Detector:\n", 551 | " def __init__(self):\n", 552 | " self.all_test_deviations = None\n", 553 | " self.mins = {}\n", 554 | " self.maxs = {}\n", 555 | " \n", 556 | " self.classes = range(10)\n", 557 | " \n", 558 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n", 559 | " for PRED in tqdm(self.classes):\n", 560 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n", 561 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n", 562 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n", 563 | " self.mins[PRED] = cpu(mins)\n", 564 | " self.maxs[PRED] = cpu(maxs)\n", 565 | " torch.cuda.empty_cache()\n", 566 | " \n", 567 | " def compute_test_deviations(self,POWERS=[10]):\n", 568 | " all_test_deviations = None\n", 569 | " test_classes = []\n", 570 | " for PRED in tqdm(self.classes):\n", 571 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n", 572 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n", 573 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n", 574 | " \n", 575 | " test_classes.extend([PRED]*len(test_indices))\n", 576 | " \n", 577 | " mins = cuda(self.mins[PRED])\n", 578 | " maxs = cuda(self.maxs[PRED])\n", 579 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n", 580 | " cpu(mins)\n", 581 | " cpu(maxs)\n", 582 | " if 
all_test_deviations is None:\n", 583 | " all_test_deviations = test_deviations\n", 584 | " else:\n", 585 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n", 586 | " torch.cuda.empty_cache()\n", 587 | " self.all_test_deviations = all_test_deviations\n", 588 | " \n", 589 | " self.test_classes = np.array(test_classes)\n", 590 | " \n", 591 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n", 592 | " ood_preds = []\n", 593 | " ood_confs = []\n", 594 | " \n", 595 | " for idx in range(0,len(ood),128):\n", 596 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n", 597 | " logits = torch_model(batch)\n", 598 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 599 | " preds = np.argmax(confs,axis=1)\n", 600 | " \n", 601 | " ood_confs.extend(np.max(confs,axis=1))\n", 602 | " ood_preds.extend(preds) \n", 603 | " torch.cuda.empty_cache()\n", 604 | " print(\"Done\")\n", 605 | " \n", 606 | " ood_classes = []\n", 607 | " all_ood_deviations = None\n", 608 | " for PRED in tqdm(self.classes):\n", 609 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n", 610 | " if len(ood_indices)==0:\n", 611 | " continue\n", 612 | " ood_classes.extend([PRED]*len(ood_indices))\n", 613 | " \n", 614 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n", 615 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n", 616 | " mins = cuda(self.mins[PRED])\n", 617 | " maxs = cuda(self.maxs[PRED])\n", 618 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n", 619 | " cpu(self.mins[PRED])\n", 620 | " cpu(self.maxs[PRED]) \n", 621 | " if all_ood_deviations is None:\n", 622 | " all_ood_deviations = ood_deviations\n", 623 | " else:\n", 624 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n", 625 | " torch.cuda.empty_cache()\n", 626 | " \n", 627 | " self.ood_classes = 
np.array(ood_classes)\n", 628 | " \n", 629 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n", 630 | " return average_results, self.all_test_deviations, all_ood_deviations\n" 631 | ] 632 | }, 633 | { 634 | "cell_type": "markdown", 635 | "metadata": {}, 636 | "source": [ 637 | "

Results

" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": 16, 643 | "metadata": {}, 644 | "outputs": [ 645 | { 646 | "data": { 647 | "application/vnd.jupyter.widget-view+json": { 648 | "model_id": "973bca4a7644474685a068878c21aa61", 649 | "version_major": 2, 650 | "version_minor": 0 651 | }, 652 | "text/plain": [ 653 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 654 | ] 655 | }, 656 | "metadata": {}, 657 | "output_type": "display_data" 658 | }, 659 | { 660 | "name": "stdout", 661 | "output_type": "stream", 662 | "text": [ 663 | "\n" 664 | ] 665 | }, 666 | { 667 | "data": { 668 | "application/vnd.jupyter.widget-view+json": { 669 | "model_id": "3c037095b42c4ea5897065610bd9ad1d", 670 | "version_major": 2, 671 | "version_minor": 0 672 | }, 673 | "text/plain": [ 674 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 675 | ] 676 | }, 677 | "metadata": {}, 678 | "output_type": "display_data" 679 | }, 680 | { 681 | "name": "stdout", 682 | "output_type": "stream", 683 | "text": [ 684 | "\n", 685 | "iSUN\n", 686 | "Done\n" 687 | ] 688 | }, 689 | { 690 | "data": { 691 | "application/vnd.jupyter.widget-view+json": { 692 | "model_id": "60342b9ed1854718ab89229ffb69b6e8", 693 | "version_major": 2, 694 | "version_minor": 0 695 | }, 696 | "text/plain": [ 697 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 698 | ] 699 | }, 700 | "metadata": {}, 701 | "output_type": "display_data" 702 | }, 703 | { 704 | "name": "stdout", 705 | "output_type": "stream", 706 | "text": [ 707 | "\n", 708 | " TNR AUROC DTACC AUIN AUOUT \n", 709 | " 99.444 99.766 98.113 99.911 99.337\n", 710 | "LSUN (R)\n", 711 | "Done\n" 712 | ] 713 | }, 714 | { 715 | "data": { 716 | "application/vnd.jupyter.widget-view+json": { 717 | "model_id": "00fe9a5a01bb4c5ebedb0c66cc732200", 718 | "version_major": 2, 719 | "version_minor": 0 720 | }, 721 | "text/plain": [ 722 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 723 | ] 724 | }, 
725 | "metadata": {}, 726 | "output_type": "display_data" 727 | }, 728 | { 729 | "name": "stdout", 730 | "output_type": "stream", 731 | "text": [ 732 | "\n", 733 | " TNR AUROC DTACC AUIN AUOUT \n", 734 | " 99.555 99.823 98.481 99.926 99.538\n", 735 | "LSUN (C)\n", 736 | "Done\n" 737 | ] 738 | }, 739 | { 740 | "data": { 741 | "application/vnd.jupyter.widget-view+json": { 742 | "model_id": "7c5b88d59b6841edb2823e57bb5d77e5", 743 | "version_major": 2, 744 | "version_minor": 0 745 | }, 746 | "text/plain": [ 747 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 748 | ] 749 | }, 750 | "metadata": {}, 751 | "output_type": "display_data" 752 | }, 753 | { 754 | "name": "stdout", 755 | "output_type": "stream", 756 | "text": [ 757 | "\n", 758 | " TNR AUROC DTACC AUIN AUOUT \n", 759 | " 94.191 98.739 94.669 99.430 97.383\n", 760 | "TinyImgNet (R)\n", 761 | "Done\n" 762 | ] 763 | }, 764 | { 765 | "data": { 766 | "application/vnd.jupyter.widget-view+json": { 767 | "model_id": "a25a7a21142b46688b86f890ddb076f1", 768 | "version_major": 2, 769 | "version_minor": 0 770 | }, 771 | "text/plain": [ 772 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 773 | ] 774 | }, 775 | "metadata": {}, 776 | "output_type": "display_data" 777 | }, 778 | { 779 | "name": "stdout", 780 | "output_type": "stream", 781 | "text": [ 782 | "\n", 783 | " TNR AUROC DTACC AUIN AUOUT \n", 784 | " 99.280 99.725 97.860 99.885 99.297\n", 785 | "TinyImgNet (C)\n", 786 | "Done\n" 787 | ] 788 | }, 789 | { 790 | "data": { 791 | "application/vnd.jupyter.widget-view+json": { 792 | "model_id": "7e6afbe421d4410eb4a52313ab2ce3bf", 793 | "version_major": 2, 794 | "version_minor": 0 795 | }, 796 | "text/plain": [ 797 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 798 | ] 799 | }, 800 | "metadata": {}, 801 | "output_type": "display_data" 802 | }, 803 | { 804 | "name": "stdout", 805 | "output_type": "stream", 806 | "text": [ 807 | "\n", 808 | " TNR AUROC DTACC AUIN AUOUT \n", 
809 | " 98.392 99.481 96.958 99.744 98.750\n", 810 | "CIFAR-10\n", 811 | "Done\n" 812 | ] 813 | }, 814 | { 815 | "data": { 816 | "application/vnd.jupyter.widget-view+json": { 817 | "model_id": "4fb502a55bc044618cbd5647ae07cd8c", 818 | "version_major": 2, 819 | "version_minor": 0 820 | }, 821 | "text/plain": [ 822 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 823 | ] 824 | }, 825 | "metadata": {}, 826 | "output_type": "display_data" 827 | }, 828 | { 829 | "name": "stdout", 830 | "output_type": "stream", 831 | "text": [ 832 | "\n", 833 | " TNR AUROC DTACC AUIN AUOUT \n", 834 | " 85.753 97.305 91.985 98.871 93.185\n" 835 | ] 836 | } 837 | ], 838 | "source": [ 839 | "def G_p(ob, p):\n", 840 | " temp = ob.detach()\n", 841 | " \n", 842 | " temp = temp**p\n", 843 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n", 844 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n", 845 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n", 846 | " \n", 847 | " return temp\n", 848 | "\n", 849 | "\n", 850 | "detector = Detector()\n", 851 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n", 852 | "\n", 853 | "detector.compute_test_deviations(POWERS=range(1,11))\n", 854 | "\n", 855 | "print(\"iSUN\")\n", 856 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n", 857 | "print(\"LSUN (R)\")\n", 858 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n", 859 | "print(\"LSUN (C)\")\n", 860 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n", 861 | "print(\"TinyImgNet (R)\")\n", 862 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n", 863 | "print(\"TinyImgNet (C)\")\n", 864 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n", 865 | "print(\"CIFAR-10\")\n", 866 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))" 867 | ] 868 | } 869 
| ], 870 | "metadata": { 871 | "kernelspec": { 872 | "display_name": "Python 2", 873 | "language": "python", 874 | "name": "python2" 875 | }, 876 | "language_info": { 877 | "codemirror_mode": { 878 | "name": "ipython", 879 | "version": 3 880 | }, 881 | "file_extension": ".py", 882 | "mimetype": "text/x-python", 883 | "name": "python", 884 | "nbconvert_exporter": "python", 885 | "pygments_lexer": "ipython3", 886 | "version": "3.6.9" 887 | } 888 | }, 889 | "nbformat": 4, 890 | "nbformat_minor": 2 891 | } 892 | -------------------------------------------------------------------------------- /ResNet_Cifar10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

ResNet: Cifar10

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import division,print_function\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import sys\n", 30 | "from tqdm import tqdm_notebook as tqdm\n", 31 | "\n", 32 | "import random\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import math\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import torch.optim as optim\n", 42 | "import torch.nn.init as init\n", 43 | "from torch.autograd import Variable, grad\n", 44 | "from torchvision import datasets, transforms\n", 45 | "from torch.nn.parameter import Parameter\n", 46 | "\n", 47 | "import calculate_log as callog\n", 48 | "\n", 49 | "import warnings\n", 50 | "warnings.filterwarnings('ignore')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.set_device(1) #Select the GPU" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Model definition" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Done\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "def conv3x3(in_planes, out_planes, stride=1):\n", 84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 85 | "\n", 86 | "class BasicBlock(nn.Module):\n", 87 | " expansion = 1\n", 88 | "\n", 89 | " def __init__(self, in_planes, planes, stride=1):\n", 90 | " super(BasicBlock, self).__init__()\n", 91 | " self.conv1 = 
conv3x3(in_planes, planes, stride)\n", 92 | " self.bn1 = nn.BatchNorm2d(planes)\n", 93 | " self.conv2 = conv3x3(planes, planes)\n", 94 | " self.bn2 = nn.BatchNorm2d(planes)\n", 95 | "\n", 96 | " self.shortcut = nn.Sequential()\n", 97 | " if stride != 1 or in_planes != self.expansion*planes:\n", 98 | " self.shortcut = nn.Sequential(\n", 99 | " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", 100 | " nn.BatchNorm2d(self.expansion*planes)\n", 101 | " )\n", 102 | " \n", 103 | " def forward(self, x):\n", 104 | " t = self.conv1(x)\n", 105 | " out = F.relu(self.bn1(t))\n", 106 | " torch_model.record(t) # relies on the global torch_model created at the end of this cell\n", 107 | " torch_model.record(out)\n", 108 | " t = self.conv2(out)\n", 109 | " out = self.bn2(t) # normalise the conv2 output computed above instead of running conv2 a second time\n", 110 | " torch_model.record(t)\n", 111 | " torch_model.record(out)\n", 112 | " t = self.shortcut(x)\n", 113 | " out += t\n", 114 | " torch_model.record(t)\n", 115 | " out = F.relu(out)\n", 116 | " torch_model.record(out)\n", 117 | " \n", 118 | " return out\n", 119 | "\n", 120 | "class ResNet(nn.Module):\n", 121 | " def __init__(self, block, num_blocks, num_classes=10):\n", 122 | " super(ResNet, self).__init__()\n", 123 | " self.in_planes = 64\n", 124 | "\n", 125 | " self.conv1 = conv3x3(3,64)\n", 126 | " self.bn1 = nn.BatchNorm2d(64)\n", 127 | " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n", 128 | " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n", 129 | " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n", 130 | " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n", 131 | " self.linear = nn.Linear(512*block.expansion, num_classes)\n", 132 | " \n", 133 | " self.collecting = False\n", 134 | " \n", 135 | " def _make_layer(self, block, planes, num_blocks, stride):\n", 136 | " strides = [stride] + [1]*(num_blocks-1)\n", 137 | " layers = []\n", 138 | " for stride in strides:\n", 139 | " layers.append(block(self.in_planes, planes, 
stride))\n", 140 | " self.in_planes = planes * block.expansion\n", 141 | " return nn.Sequential(*layers)\n", 142 | " \n", 143 | " def forward(self, x):\n", 144 | " out = F.relu(self.bn1(self.conv1(x)))\n", 145 | " out = self.layer1(out)\n", 146 | " out = self.layer2(out)\n", 147 | " out = self.layer3(out)\n", 148 | " out = self.layer4(out)\n", 149 | " out = F.avg_pool2d(out, 4)\n", 150 | " out = out.view(out.size(0), -1)\n", 151 | " y = self.linear(out)\n", 152 | " return y\n", 153 | " \n", 154 | " def record(self, t):\n", 155 | " if self.collecting:\n", 156 | " self.gram_feats.append(t)\n", 157 | " \n", 158 | " def gram_feature_list(self,x):\n", 159 | " self.collecting = True\n", 160 | " self.gram_feats = []\n", 161 | " self.forward(x)\n", 162 | " self.collecting = False\n", 163 | " temp = self.gram_feats\n", 164 | " self.gram_feats = []\n", 165 | " return temp\n", 166 | " \n", 167 | " def load(self, path=\"resnet_cifar10.pth\"):\n", 168 | " tm = torch.load(path,map_location=\"cpu\") \n", 169 | " self.load_state_dict(tm)\n", 170 | " \n", 171 | " def get_min_max(self, data, power):\n", 172 | " mins = []\n", 173 | " maxs = []\n", 174 | " \n", 175 | " for i in range(0,len(data),128):\n", 176 | " batch = data[i:i+128].cuda()\n", 177 | " feat_list = self.gram_feature_list(batch)\n", 178 | " for L,feat_L in enumerate(feat_list):\n", 179 | " if L==len(mins):\n", 180 | " mins.append([None]*len(power))\n", 181 | " maxs.append([None]*len(power))\n", 182 | " \n", 183 | " for p,P in enumerate(power):\n", 184 | " g_p = G_p(feat_L,P)\n", 185 | " \n", 186 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n", 187 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n", 188 | " \n", 189 | " if mins[L][p] is None:\n", 190 | " mins[L][p] = current_min\n", 191 | " maxs[L][p] = current_max\n", 192 | " else:\n", 193 | " mins[L][p] = torch.min(current_min,mins[L][p])\n", 194 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n", 195 | " \n", 196 | " return mins,maxs\n", 197 | " \n", 198 | " 
def get_deviations(self,data,power,mins,maxs):\n", 199 | " deviations = []\n", 200 | " \n", 201 | " for i in range(0,len(data),128): \n", 202 | " batch = data[i:i+128].cuda()\n", 203 | " feat_list = self.gram_feature_list(batch)\n", 204 | " batch_deviations = []\n", 205 | " for L,feat_L in enumerate(feat_list):\n", 206 | " dev = 0\n", 207 | " for p,P in enumerate(power):\n", 208 | " g_p = G_p(feat_L,P)\n", 209 | " \n", 210 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 211 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 212 | " batch_deviations.append(dev.cpu().detach().numpy())\n", 213 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n", 214 | " deviations.append(batch_deviations)\n", 215 | " deviations = np.concatenate(deviations,axis=0)\n", 216 | " \n", 217 | " return deviations\n", 218 | "\n", 219 | "\n", 220 | "torch_model = ResNet(BasicBlock, [3,4,6,3], num_classes=10)\n", 221 | "torch_model.load()\n", 222 | "torch_model.cuda()\n", 223 | "torch_model.params = list(torch_model.parameters())\n", 224 | "torch_model.eval()\n", 225 | "print(\"Done\") " 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "## Datasets" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "In-distribution Datasets" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 4, 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "Files already downloaded and verified\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "batch_size = 128\n", 257 | "mean = np.array([[0.4914, 0.4822, 0.4465]]).T\n", 258 | "\n", 259 | "std = np.array([[0.2023, 0.1994, 0.2010]]).T\n", 260 | "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n", 261 | "\n", 262 | "transform_train = 
transforms.Compose([\n", 263 | " transforms.RandomCrop(32, padding=4),\n", 264 | " transforms.RandomHorizontalFlip(),\n", 265 | " transforms.ToTensor(),\n", 266 | " normalize\n", 267 | " \n", 268 | " ])\n", 269 | "transform_test = transforms.Compose([\n", 270 | " transforms.CenterCrop(size=(32, 32)),\n", 271 | " transforms.ToTensor(),\n", 272 | " normalize\n", 273 | " ])\n", 274 | "\n", 275 | "train_loader = torch.utils.data.DataLoader(\n", 276 | " datasets.CIFAR10('data', train=True, download=True,\n", 277 | " transform=transform_train),\n", 278 | " batch_size=batch_size, shuffle=True)\n", 279 | "test_loader = torch.utils.data.DataLoader(\n", 280 | " datasets.CIFAR10('data', train=False, transform=transform_test),\n", 281 | " batch_size=batch_size)\n" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 5, 287 | "metadata": { 288 | "scrolled": true 289 | }, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "Files already downloaded and verified\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "data_train = list(torch.utils.data.DataLoader(\n", 301 | " datasets.CIFAR10('data', train=True, download=True,\n", 302 | " transform=transform_test),\n", 303 | " batch_size=1, shuffle=False))" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 6, 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "Files already downloaded and verified\n" 316 | ] 317 | } 318 | ], 319 | "source": [ 320 | "data = list(torch.utils.data.DataLoader(\n", 321 | " datasets.CIFAR10('data', train=False, download=True,\n", 322 | " transform=transform_test),\n", 323 | " batch_size=1, shuffle=False))" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 7, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "Accuracy: 0.9367\n" 
336 | ] 337 | } 338 | ], 339 | "source": [ 340 | "torch_model.eval()\n", 341 | "correct = 0\n", 342 | "total = 0\n", 343 | "for x,y in test_loader:\n", 344 | " x = x.cuda()\n", 345 | " y = y.numpy()\n", 346 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n", 347 | " total += y.shape[0]\n", 348 | "print(\"Accuracy: \",correct/total)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "Out-of-distribution Datasets" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 8, 361 | "metadata": {}, 362 | "outputs": [ 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "Files already downloaded and verified\n" 368 | ] 369 | } 370 | ], 371 | "source": [ 372 | "cifar100 = list(torch.utils.data.DataLoader(\n", 373 | " datasets.CIFAR100('data', train=False, download=True,\n", 374 | " transform=transform_test),\n", 375 | " batch_size=1, shuffle=True))" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 9, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "name": "stdout", 385 | "output_type": "stream", 386 | "text": [ 387 | "Using downloaded and verified file: data/test_32x32.mat\n" 388 | ] 389 | } 390 | ], 391 | "source": [ 392 | "svhn = list(torch.utils.data.DataLoader(\n", 393 | " datasets.SVHN('data', split=\"test\", download=True,\n", 394 | " transform=transform_test),\n", 395 | " batch_size=1, shuffle=True))" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 10, 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "isun = list(torch.utils.data.DataLoader(\n", 405 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 11, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "lsun_c = list(torch.utils.data.DataLoader(\n", 415 | " 
datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 12, 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "lsun_r = list(torch.utils.data.DataLoader(\n", 425 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 13, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n", 435 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 14, 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n", 445 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "## Code for Detecting OODs" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": {}, 458 | "source": [ 459 | " Extract predictions for train and test data " 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 15, 465 | "metadata": {}, 466 | "outputs": [ 467 | { 468 | "name": "stdout", 469 | "output_type": "stream", 470 | "text": [ 471 | "Done\n", 472 | "Done\n" 473 | ] 474 | } 475 | ], 476 | "source": [ 477 | "train_preds = []\n", 478 | "train_confs = []\n", 479 | "train_logits = []\n", 480 | "for idx in range(0,len(data_train),128):\n", 481 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n", 482 | " \n", 483 | " logits = torch_model(batch)\n", 484 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 485 | " preds = np.argmax(confs,axis=1)\n", 486 | " logits = 
(logits.cpu().detach().numpy())\n", 487 | "\n", 488 | " train_confs.extend(np.max(confs,axis=1)) \n", 489 | " train_preds.extend(preds)\n", 490 | " train_logits.extend(logits)\n", 491 | "print(\"Done\")\n", 492 | "\n", 493 | "test_preds = []\n", 494 | "test_confs = []\n", 495 | "test_logits = []\n", 496 | "\n", 497 | "for idx in range(0,len(data),128):\n", 498 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n", 499 | " \n", 500 | " logits = torch_model(batch)\n", 501 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 502 | " preds = np.argmax(confs,axis=1)\n", 503 | " logits = (logits.cpu().detach().numpy())\n", 504 | "\n", 505 | " test_confs.extend(np.max(confs,axis=1)) \n", 506 | " test_preds.extend(preds)\n", 507 | " test_logits.extend(logits)\n", 508 | "print(\"Done\")" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [ 515 | " Code for detecting OODs by identifying anomalies in correlations " 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 16, 521 | "metadata": {}, 522 | "outputs": [], 523 | "source": [ 524 | "import calculate_log as callog\n", 525 | "\n", 526 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n", 527 | " average_results = {}\n", 528 | " for i in range(1,11):\n", 529 | " random.seed(i)\n", 530 | " \n", 531 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n", 532 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n", 533 | "\n", 534 | " validation = all_test_deviations[validation_indices]\n", 535 | " test_deviations = all_test_deviations[test_indices]\n", 536 | "\n", 537 | " t95 = validation.mean(axis=0)+10**-7\n", 538 | " if not normalize:\n", 539 | " t95 = np.ones_like(t95)\n", 540 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 541 | " ood_deviations = 
(all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 542 | " \n", 543 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n", 544 | " for m in results:\n", 545 | " average_results[m] = average_results.get(m,0)+results[m]\n", 546 | " \n", 547 | " for m in average_results:\n", 548 | " average_results[m] /= i\n", 549 | " if verbose:\n", 550 | " callog.print_results(average_results)\n", 551 | " return average_results\n", 552 | "\n", 553 | "def cpu(ob):\n", 554 | " for i in range(len(ob)):\n", 555 | " for j in range(len(ob[i])):\n", 556 | " ob[i][j] = ob[i][j].cpu()\n", 557 | " return ob\n", 558 | "\n", 559 | "def cuda(ob):\n", 560 | " for i in range(len(ob)):\n", 561 | " for j in range(len(ob[i])):\n", 562 | " ob[i][j] = ob[i][j].cuda()\n", 563 | " return ob\n", 564 | "\n", 565 | "class Detector:\n", 566 | " def __init__(self):\n", 567 | " self.all_test_deviations = None\n", 568 | " self.mins = {}\n", 569 | " self.maxs = {}\n", 570 | " \n", 571 | " self.classes = range(10)\n", 572 | " \n", 573 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n", 574 | " for PRED in tqdm(self.classes):\n", 575 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n", 576 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n", 577 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n", 578 | " self.mins[PRED] = cpu(mins)\n", 579 | " self.maxs[PRED] = cpu(maxs)\n", 580 | " torch.cuda.empty_cache()\n", 581 | " \n", 582 | " def compute_test_deviations(self,POWERS=[10]):\n", 583 | " all_test_deviations = None\n", 584 | " test_classes = []\n", 585 | " for PRED in tqdm(self.classes):\n", 586 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n", 587 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n", 588 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n", 589 | " \n", 590 | " test_classes.extend([PRED]*len(test_indices))\n", 591 | " \n", 592 | 
" mins = cuda(self.mins[PRED])\n", 593 | " maxs = cuda(self.maxs[PRED])\n", 594 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n", 595 | " cpu(mins)\n", 596 | " cpu(maxs)\n", 597 | " if all_test_deviations is None:\n", 598 | " all_test_deviations = test_deviations\n", 599 | " else:\n", 600 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n", 601 | " torch.cuda.empty_cache()\n", 602 | " self.all_test_deviations = all_test_deviations\n", 603 | " \n", 604 | " self.test_classes = np.array(test_classes)\n", 605 | " \n", 606 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n", 607 | " ood_preds = []\n", 608 | " ood_confs = []\n", 609 | " \n", 610 | " for idx in range(0,len(ood),128):\n", 611 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n", 612 | " logits = torch_model(batch)\n", 613 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 614 | " preds = np.argmax(confs,axis=1)\n", 615 | " \n", 616 | " ood_confs.extend(np.max(confs,axis=1))\n", 617 | " ood_preds.extend(preds) \n", 618 | " torch.cuda.empty_cache()\n", 619 | " print(\"Done\")\n", 620 | " \n", 621 | " ood_classes = []\n", 622 | " all_ood_deviations = None\n", 623 | " for PRED in tqdm(self.classes):\n", 624 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n", 625 | " if len(ood_indices)==0:\n", 626 | " continue\n", 627 | " ood_classes.extend([PRED]*len(ood_indices))\n", 628 | " \n", 629 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n", 630 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n", 631 | " mins = cuda(self.mins[PRED])\n", 632 | " maxs = cuda(self.maxs[PRED])\n", 633 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n", 634 | " cpu(self.mins[PRED])\n", 635 | " cpu(self.maxs[PRED]) \n", 636 | " if 
all_ood_deviations is None:\n", 637 | " all_ood_deviations = ood_deviations\n", 638 | " else:\n", 639 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n", 640 | " torch.cuda.empty_cache()\n", 641 | " \n", 642 | " self.ood_classes = np.array(ood_classes)\n", 643 | " \n", 644 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n", 645 | " return average_results, self.all_test_deviations, all_ood_deviations\n" 646 | ] 647 | }, 648 | { 649 | "cell_type": "markdown", 650 | "metadata": {}, 651 | "source": [ 652 | "

Results

" 653 | ] 654 | }, 655 | { 656 | "cell_type": "code", 657 | "execution_count": 17, 658 | "metadata": {}, 659 | "outputs": [ 660 | { 661 | "data": { 662 | "application/vnd.jupyter.widget-view+json": { 663 | "model_id": "c2ff5c67696a455a888f8fe4b19339c9", 664 | "version_major": 2, 665 | "version_minor": 0 666 | }, 667 | "text/plain": [ 668 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 669 | ] 670 | }, 671 | "metadata": {}, 672 | "output_type": "display_data" 673 | }, 674 | { 675 | "name": "stdout", 676 | "output_type": "stream", 677 | "text": [ 678 | "\n" 679 | ] 680 | }, 681 | { 682 | "data": { 683 | "application/vnd.jupyter.widget-view+json": { 684 | "model_id": "e816c6c68303437a8a4fdc97091a8853", 685 | "version_major": 2, 686 | "version_minor": 0 687 | }, 688 | "text/plain": [ 689 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 690 | ] 691 | }, 692 | "metadata": {}, 693 | "output_type": "display_data" 694 | }, 695 | { 696 | "name": "stdout", 697 | "output_type": "stream", 698 | "text": [ 699 | "\n", 700 | " TNR AUROC DTACC AUIN AUOUT \n", 701 | " 99.257 99.831 98.077 99.827 99.829\n", 702 | "LSUN (R)\n", 703 | "Done\n" 704 | ] 705 | }, 706 | { 707 | "data": { 708 | "application/vnd.jupyter.widget-view+json": { 709 | "model_id": "42829af13dc84e7fadccd7014a655912", 710 | "version_major": 2, 711 | "version_minor": 0 712 | }, 713 | "text/plain": [ 714 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 715 | ] 716 | }, 717 | "metadata": {}, 718 | "output_type": "display_data" 719 | }, 720 | { 721 | "name": "stdout", 722 | "output_type": "stream", 723 | "text": [ 724 | "\n", 725 | " TNR AUROC DTACC AUIN AUOUT \n", 726 | " 99.585 99.889 98.641 99.866 99.899\n", 727 | "LSUN (C)\n", 728 | "Done\n" 729 | ] 730 | }, 731 | { 732 | "data": { 733 | "application/vnd.jupyter.widget-view+json": { 734 | "model_id": "da08da96632648d0b841cbc97fe61456", 735 | "version_major": 2, 736 | "version_minor": 0 737 | }, 738 | "text/plain": 
[ 739 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 740 | ] 741 | }, 742 | "metadata": {}, 743 | "output_type": "display_data" 744 | }, 745 | { 746 | "name": "stdout", 747 | "output_type": "stream", 748 | "text": [ 749 | "\n", 750 | " TNR AUROC DTACC AUIN AUOUT \n", 751 | " 89.798 97.796 92.591 97.433 98.172\n", 752 | "TinyImgNet (R)\n", 753 | "Done\n" 754 | ] 755 | }, 756 | { 757 | "data": { 758 | "application/vnd.jupyter.widget-view+json": { 759 | "model_id": "430ab4eefa264916afbf916f35418a05", 760 | "version_major": 2, 761 | "version_minor": 0 762 | }, 763 | "text/plain": [ 764 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 765 | ] 766 | }, 767 | "metadata": {}, 768 | "output_type": "display_data" 769 | }, 770 | { 771 | "name": "stdout", 772 | "output_type": "stream", 773 | "text": [ 774 | "\n", 775 | " TNR AUROC DTACC AUIN AUOUT \n", 776 | " 98.746 99.717 97.797 99.648 99.765\n", 777 | "TinyImgNet (C)\n", 778 | "Done\n" 779 | ] 780 | }, 781 | { 782 | "data": { 783 | "application/vnd.jupyter.widget-view+json": { 784 | "model_id": "4168c5ab119d4577a35895922e4d8430", 785 | "version_major": 2, 786 | "version_minor": 0 787 | }, 788 | "text/plain": [ 789 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 790 | ] 791 | }, 792 | "metadata": {}, 793 | "output_type": "display_data" 794 | }, 795 | { 796 | "name": "stdout", 797 | "output_type": "stream", 798 | "text": [ 799 | "\n", 800 | " TNR AUROC DTACC AUIN AUOUT \n", 801 | " 96.666 99.242 96.074 99.082 99.377\n", 802 | "SVHN\n", 803 | "Done\n" 804 | ] 805 | }, 806 | { 807 | "data": { 808 | "application/vnd.jupyter.widget-view+json": { 809 | "model_id": "b91cf4a3135f49aaa74884b0d4149131", 810 | "version_major": 2, 811 | "version_minor": 0 812 | }, 813 | "text/plain": [ 814 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 815 | ] 816 | }, 817 | "metadata": {}, 818 | "output_type": "display_data" 819 | }, 820 | { 821 | "name": "stdout", 822 | 
"output_type": "stream", 823 | "text": [ 824 | "\n", 825 | " TNR AUROC DTACC AUIN AUOUT \n", 826 | " 97.614 99.502 96.708 98.446 99.831\n", 827 | "CIFAR-100\n", 828 | "Done\n" 829 | ] 830 | }, 831 | { 832 | "data": { 833 | "application/vnd.jupyter.widget-view+json": { 834 | "model_id": "e51aaecf0a214053a20b00b2e58b1e6e", 835 | "version_major": 2, 836 | "version_minor": 0 837 | }, 838 | "text/plain": [ 839 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 840 | ] 841 | }, 842 | "metadata": {}, 843 | "output_type": "display_data" 844 | }, 845 | { 846 | "name": "stdout", 847 | "output_type": "stream", 848 | "text": [ 849 | "\n", 850 | " TNR AUROC DTACC AUIN AUOUT \n", 851 | " 32.896 79.015 71.711 74.991 80.666\n" 852 | ] 853 | } 854 | ], 855 | "source": [ 856 | "def G_p(ob, p):\n", 857 | " temp = ob.detach()\n", 858 | " \n", 859 | " temp = temp**p\n", 860 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n", 861 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n", 862 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n", 863 | " \n", 864 | " return temp\n", 865 | "\n", 866 | "\n", 867 | "detector = Detector()\n", 868 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n", 869 | "\n", 870 | "detector.compute_test_deviations(POWERS=range(1,11))\n", 871 | "\n", 872 | "print(\"iSUN\")\n", 873 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n", 874 | "print(\"LSUN (R)\")\n", 875 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n", 876 | "print(\"LSUN (C)\")\n", 877 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n", 878 | "print(\"TinyImgNet (R)\")\n", 879 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n", 880 | "print(\"TinyImgNet (C)\")\n", 881 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n", 882 | "print(\"SVHN\")\n", 883 | 
"svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n", 884 | "print(\"CIFAR-100\")\n", 885 | "c100_results = detector.compute_ood_deviations(cifar100,POWERS=range(1,11))" 886 | ] 887 | } 888 | ], 889 | "metadata": { 890 | "kernelspec": { 891 | "display_name": "Python 3", 892 | "language": "python", 893 | "name": "python3" 894 | }, 895 | "language_info": { 896 | "codemirror_mode": { 897 | "name": "ipython", 898 | "version": 3 899 | }, 900 | "file_extension": ".py", 901 | "mimetype": "text/x-python", 902 | "name": "python", 903 | "nbconvert_exporter": "python", 904 | "pygments_lexer": "ipython3", 905 | "version": "3.6.9" 906 | } 907 | }, 908 | "nbformat": 4, 909 | "nbformat_minor": 2 910 | } 911 | -------------------------------------------------------------------------------- /ResNet_Cifar100.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

ResNet: Cifar100

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import division,print_function\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import sys\n", 30 | "from tqdm import tqdm_notebook as tqdm\n", 31 | "\n", 32 | "import random\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import math\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import torch.optim as optim\n", 42 | "import torch.nn.init as init\n", 43 | "from torch.autograd import Variable, grad\n", 44 | "from torchvision import datasets, transforms\n", 45 | "from torch.nn.parameter import Parameter\n", 46 | "\n", 47 | "import calculate_log as callog\n", 48 | "\n", 49 | "import warnings\n", 50 | "warnings.filterwarnings('ignore')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.set_device(1) #Select the GPU" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Model definition" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Done\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "def conv3x3(in_planes, out_planes, stride=1):\n", 84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 85 | "\n", 86 | "class BasicBlock(nn.Module):\n", 87 | " expansion = 1\n", 88 | "\n", 89 | " def __init__(self, in_planes, planes, stride=1):\n", 90 | " super(BasicBlock, self).__init__()\n", 91 | " self.conv1 = 
conv3x3(in_planes, planes, stride)\n", 92 | " self.bn1 = nn.BatchNorm2d(planes)\n", 93 | " self.conv2 = conv3x3(planes, planes)\n", 94 | " self.bn2 = nn.BatchNorm2d(planes)\n", 95 | "\n", 96 | " self.shortcut = nn.Sequential()\n", 97 | " if stride != 1 or in_planes != self.expansion*planes:\n", 98 | " self.shortcut = nn.Sequential(\n", 99 | " nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", 100 | " nn.BatchNorm2d(self.expansion*planes)\n", 101 | " )\n", 102 | " \n", 103 | " def forward(self, x):\n", 104 | " t = self.conv1(x)\n", 105 | " out = F.relu(self.bn1(t))\n", 106 | " torch_model.record(t)\n", 107 | " torch_model.record(out)\n", 108 | " t = self.conv2(out)\n", 109 | " out = self.bn2(self.conv2(out))\n", 110 | " torch_model.record(t)\n", 111 | " torch_model.record(out)\n", 112 | " t = self.shortcut(x)\n", 113 | " out += t\n", 114 | " torch_model.record(t)\n", 115 | " out = F.relu(out)\n", 116 | " torch_model.record(out)\n", 117 | " \n", 118 | " return out\n", 119 | "\n", 120 | "class ResNet(nn.Module):\n", 121 | " def __init__(self, block, num_blocks, num_classes=10):\n", 122 | " super(ResNet, self).__init__()\n", 123 | " self.in_planes = 64\n", 124 | "\n", 125 | " self.conv1 = conv3x3(3,64)\n", 126 | " self.bn1 = nn.BatchNorm2d(64)\n", 127 | " self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n", 128 | " self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n", 129 | " self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n", 130 | " self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n", 131 | " self.linear = nn.Linear(512*block.expansion, num_classes)\n", 132 | " \n", 133 | " self.collecting = False\n", 134 | " \n", 135 | " def _make_layer(self, block, planes, num_blocks, stride):\n", 136 | " strides = [stride] + [1]*(num_blocks-1)\n", 137 | " layers = []\n", 138 | " for stride in strides:\n", 139 | " layers.append(block(self.in_planes, planes, 
stride))\n", 140 | " self.in_planes = planes * block.expansion\n", 141 | " return nn.Sequential(*layers)\n", 142 | " \n", 143 | " def forward(self, x):\n", 144 | " out = F.relu(self.bn1(self.conv1(x)))\n", 145 | " out = self.layer1(out)\n", 146 | " out = self.layer2(out)\n", 147 | " out = self.layer3(out)\n", 148 | " out = self.layer4(out)\n", 149 | " out = F.avg_pool2d(out, 4)\n", 150 | " out = out.view(out.size(0), -1)\n", 151 | " y = self.linear(out)\n", 152 | " return y\n", 153 | " \n", 154 | " def record(self, t):\n", 155 | " if self.collecting:\n", 156 | " self.gram_feats.append(t)\n", 157 | " \n", 158 | " def gram_feature_list(self,x):\n", 159 | " self.collecting = True\n", 160 | " self.gram_feats = []\n", 161 | " self.forward(x)\n", 162 | " self.collecting = False\n", 163 | " temp = self.gram_feats\n", 164 | " self.gram_feats = []\n", 165 | " return temp\n", 166 | " \n", 167 | " def load(self, path=\"resnet_cifar100.pth\"):\n", 168 | " tm = torch.load(path,map_location=\"cpu\") \n", 169 | " self.load_state_dict(tm)\n", 170 | " \n", 171 | " def get_min_max(self, data, power):\n", 172 | " mins = []\n", 173 | " maxs = []\n", 174 | " \n", 175 | " for i in range(0,len(data),128):\n", 176 | " batch = data[i:i+128].cuda()\n", 177 | " feat_list = self.gram_feature_list(batch)\n", 178 | " for L,feat_L in enumerate(feat_list):\n", 179 | " if L==len(mins):\n", 180 | " mins.append([None]*len(power))\n", 181 | " maxs.append([None]*len(power))\n", 182 | " \n", 183 | " for p,P in enumerate(power):\n", 184 | " g_p = G_p(feat_L,P)\n", 185 | " \n", 186 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n", 187 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n", 188 | " \n", 189 | " if mins[L][p] is None:\n", 190 | " mins[L][p] = current_min\n", 191 | " maxs[L][p] = current_max\n", 192 | " else:\n", 193 | " mins[L][p] = torch.min(current_min,mins[L][p])\n", 194 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n", 195 | " \n", 196 | " return mins,maxs\n", 197 | " \n", 198 | " 
def get_deviations(self,data,power,mins,maxs):\n", 199 | " deviations = []\n", 200 | " \n", 201 | " for i in range(0,len(data),128): \n", 202 | " batch = data[i:i+128].cuda()\n", 203 | " feat_list = self.gram_feature_list(batch)\n", 204 | " batch_deviations = []\n", 205 | " for L,feat_L in enumerate(feat_list):\n", 206 | " dev = 0\n", 207 | " for p,P in enumerate(power):\n", 208 | " g_p = G_p(feat_L,P)\n", 209 | " \n", 210 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 211 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 212 | " batch_deviations.append(dev.cpu().detach().numpy())\n", 213 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n", 214 | " deviations.append(batch_deviations)\n", 215 | " deviations = np.concatenate(deviations,axis=0)\n", 216 | " \n", 217 | " return deviations\n", 218 | "\n", 219 | "\n", 220 | "torch_model = ResNet(BasicBlock, [3,4,6,3], num_classes=100)\n", 221 | "torch_model.load()\n", 222 | "torch_model.cuda()\n", 223 | "torch_model.params = list(torch_model.parameters())\n", 224 | "torch_model.eval()\n", 225 | "print(\"Done\") " 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "## Datasets" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "In-distribution Datasets" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 4, 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "Files already downloaded and verified\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "batch_size = 128\n", 257 | "mean = np.array([[0.4914, 0.4822, 0.4465]]).T\n", 258 | "\n", 259 | "std = np.array([[0.2023, 0.1994, 0.2010]]).T\n", 260 | "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n", 261 | "\n", 262 | "transform_train = 
transforms.Compose([\n", 263 | " transforms.RandomCrop(32, padding=4),\n", 264 | " transforms.RandomHorizontalFlip(),\n", 265 | " transforms.ToTensor(),\n", 266 | " normalize\n", 267 | " \n", 268 | " ])\n", 269 | "\n", 270 | "transform_test = transforms.Compose([\n", 271 | " transforms.CenterCrop(size=(32, 32)),\n", 272 | " transforms.ToTensor(),\n", 273 | " normalize\n", 274 | " ])\n", 275 | "\n", 276 | "train_loader = torch.utils.data.DataLoader(\n", 277 | " datasets.CIFAR100('data', train=True, download=True,\n", 278 | " transform=transform_train),\n", 279 | " batch_size=batch_size, shuffle=True)\n", 280 | "test_loader = torch.utils.data.DataLoader(\n", 281 | " datasets.CIFAR100('data', train=False, transform=transform_test),\n", 282 | " batch_size=batch_size)\n" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 5, 288 | "metadata": { 289 | "scrolled": true 290 | }, 291 | "outputs": [ 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "Files already downloaded and verified\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "data_train = list(torch.utils.data.DataLoader(\n", 302 | " datasets.CIFAR100('data', train=True, download=True,\n", 303 | " transform=transform_test),\n", 304 | " batch_size=1, shuffle=False))" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 6, 310 | "metadata": { 311 | "scrolled": true 312 | }, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "Files already downloaded and verified\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "data = list(torch.utils.data.DataLoader(\n", 324 | " datasets.CIFAR100('data', train=False, download=True,\n", 325 | " transform=transform_test),\n", 326 | " batch_size=1, shuffle=False))" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 7, 332 | "metadata": {}, 333 | "outputs": [ 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 
337 | "text": [ 338 | "Accuracy: 0.7834\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "torch_model.eval()\n", 344 | "correct = 0\n", 345 | "total = 0\n", 346 | "for x,y in test_loader:\n", 347 | " x = x.cuda()\n", 348 | " y = y.numpy()\n", 349 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n", 350 | " total += y.shape[0]\n", 351 | "print(\"Accuracy: \",correct/total)\n" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "Out-of-distribution Datasets" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 8, 364 | "metadata": { 365 | "scrolled": true 366 | }, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "Files already downloaded and verified\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "cifar10 = list(torch.utils.data.DataLoader(\n", 378 | " datasets.CIFAR10('data', train=False, download=True,\n", 379 | " transform=transform_test),\n", 380 | " batch_size=1, shuffle=True))" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 9, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "Using downloaded and verified file: data/test_32x32.mat\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "svhn = list(torch.utils.data.DataLoader(\n", 398 | " datasets.SVHN('data', split=\"test\", download=True,\n", 399 | " transform=transform_test),\n", 400 | " batch_size=1, shuffle=True))" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 10, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "isun = list(torch.utils.data.DataLoader(\n", 410 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 11, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 
419 | "lsun_c = list(torch.utils.data.DataLoader(\n", 420 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 12, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "lsun_r = list(torch.utils.data.DataLoader(\n", 430 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 13, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n", 440 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 14, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n", 450 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": {}, 456 | "source": [ 457 | "## Code for Detecting OODs" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | " Extract predictions for train and test data " 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 15, 470 | "metadata": {}, 471 | "outputs": [ 472 | { 473 | "name": "stdout", 474 | "output_type": "stream", 475 | "text": [ 476 | "Done\n", 477 | "Done\n" 478 | ] 479 | } 480 | ], 481 | "source": [ 482 | "train_preds = []\n", 483 | "train_confs = []\n", 484 | "train_logits = []\n", 485 | "for idx in range(0,len(data_train),128):\n", 486 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n", 487 | " \n", 488 | " logits = torch_model(batch)\n", 489 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 490 | " preds = 
np.argmax(confs,axis=1)\n", 491 | " logits = (logits.cpu().detach().numpy())#**2)#.sum(axis=1)\n", 492 | "\n", 493 | " train_confs.extend(np.max(confs,axis=1)) \n", 494 | " train_preds.extend(preds)\n", 495 | " train_logits.extend(logits)\n", 496 | "print(\"Done\")\n", 497 | "\n", 498 | "test_preds = []\n", 499 | "test_confs = []\n", 500 | "test_logits = []\n", 501 | "\n", 502 | "for idx in range(0,len(data),128):\n", 503 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n", 504 | " \n", 505 | " logits = torch_model(batch)\n", 506 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 507 | " preds = np.argmax(confs,axis=1)\n", 508 | " logits = (logits.cpu().detach().numpy())#**2)#.sum(axis=1)\n", 509 | "\n", 510 | " test_confs.extend(np.max(confs,axis=1)) \n", 511 | " test_preds.extend(preds)\n", 512 | " test_logits.extend(logits)\n", 513 | "print(\"Done\")" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | " Code for detecting OODs by identifying anomalies in correlations " 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 16, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "import calculate_log as callog\n", 530 | "\n", 531 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n", 532 | " average_results = {}\n", 533 | " for i in range(1,11):\n", 534 | " random.seed(i)\n", 535 | " \n", 536 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n", 537 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n", 538 | " \n", 539 | " validation = all_test_deviations[validation_indices]\n", 540 | " test_deviations = all_test_deviations[test_indices]\n", 541 | "\n", 542 | " t95 = validation.mean(axis=0)+10**-7\n", 543 | " if not normalize:\n", 544 | " t95 = np.ones_like(t95)\n", 545 | " test_deviations = 
(test_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 546 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 547 | " \n", 548 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n", 549 | " for m in results:\n", 550 | " average_results[m] = average_results.get(m,0)+results[m]\n", 551 | " \n", 552 | " for m in average_results:\n", 553 | " average_results[m] /= i\n", 554 | " if verbose:\n", 555 | " callog.print_results(average_results)\n", 556 | " return average_results\n", 557 | "\n", 558 | "\n", 559 | "def cpu(ob):\n", 560 | " for i in range(len(ob)):\n", 561 | " for j in range(len(ob[i])):\n", 562 | " ob[i][j] = ob[i][j].cpu()\n", 563 | " return ob\n", 564 | "\n", 565 | "def cuda(ob):\n", 566 | " for i in range(len(ob)):\n", 567 | " for j in range(len(ob[i])):\n", 568 | " ob[i][j] = ob[i][j].cuda()\n", 569 | " return ob\n", 570 | "\n", 571 | "class Detector:\n", 572 | " def __init__(self):\n", 573 | " self.all_test_deviations = None\n", 574 | " self.mins = {}\n", 575 | " self.maxs = {}\n", 576 | " \n", 577 | " self.classes = range(100)\n", 578 | " \n", 579 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n", 580 | " for PRED in tqdm(self.classes):\n", 581 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n", 582 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n", 583 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n", 584 | " self.mins[PRED] = cpu(mins)\n", 585 | " self.maxs[PRED] = cpu(maxs)\n", 586 | " torch.cuda.empty_cache()\n", 587 | " \n", 588 | " def compute_test_deviations(self,POWERS=[10]):\n", 589 | " all_test_deviations = None\n", 590 | " for PRED in tqdm(self.classes):\n", 591 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n", 592 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n", 593 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n", 594 | " mins = 
cuda(self.mins[PRED])\n", 595 | " maxs = cuda(self.maxs[PRED])\n", 596 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n", 597 | " cpu(mins)\n", 598 | " cpu(maxs)\n", 599 | " if all_test_deviations is None:\n", 600 | " all_test_deviations = test_deviations\n", 601 | " else:\n", 602 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n", 603 | " torch.cuda.empty_cache()\n", 604 | " self.all_test_deviations = all_test_deviations\n", 605 | " \n", 606 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n", 607 | " ood_preds = []\n", 608 | " ood_confs = []\n", 609 | " \n", 610 | " for idx in range(0,len(ood),128):\n", 611 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n", 612 | " logits = torch_model(batch)\n", 613 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 614 | " preds = np.argmax(confs,axis=1)\n", 615 | " \n", 616 | " ood_confs.extend(np.max(confs,axis=1))\n", 617 | " ood_preds.extend(preds) \n", 618 | " torch.cuda.empty_cache()\n", 619 | " print(\"Done\")\n", 620 | " \n", 621 | " all_ood_deviations = None\n", 622 | " for PRED in tqdm(self.classes):\n", 623 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n", 624 | " if len(ood_indices)==0:\n", 625 | " continue\n", 626 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n", 627 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n", 628 | " mins = cuda(self.mins[PRED])\n", 629 | " maxs = cuda(self.maxs[PRED])\n", 630 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n", 631 | " cpu(self.mins[PRED])\n", 632 | " cpu(self.maxs[PRED]) \n", 633 | " if all_ood_deviations is None:\n", 634 | " all_ood_deviations = ood_deviations\n", 635 | " else:\n", 636 | " all_ood_deviations = 
np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n", 637 | " torch.cuda.empty_cache()\n", 638 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n", 639 | " return average_results, self.all_test_deviations, all_ood_deviations\n" 640 | ] 641 | }, 642 | { 643 | "cell_type": "markdown", 644 | "metadata": {}, 645 | "source": [ 646 | "

Results

" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": 17, 652 | "metadata": { 653 | "scrolled": false 654 | }, 655 | "outputs": [ 656 | { 657 | "data": { 658 | "application/vnd.jupyter.widget-view+json": { 659 | "model_id": "b356f59735b74cccbf45df07cd7a9a50", 660 | "version_major": 2, 661 | "version_minor": 0 662 | }, 663 | "text/plain": [ 664 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 665 | ] 666 | }, 667 | "metadata": {}, 668 | "output_type": "display_data" 669 | }, 670 | { 671 | "name": "stdout", 672 | "output_type": "stream", 673 | "text": [ 674 | "\n" 675 | ] 676 | }, 677 | { 678 | "data": { 679 | "application/vnd.jupyter.widget-view+json": { 680 | "model_id": "16234970aabe4b4daccf0dc1e74a0f79", 681 | "version_major": 2, 682 | "version_minor": 0 683 | }, 684 | "text/plain": [ 685 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 686 | ] 687 | }, 688 | "metadata": {}, 689 | "output_type": "display_data" 690 | }, 691 | { 692 | "name": "stdout", 693 | "output_type": "stream", 694 | "text": [ 695 | "\n", 696 | "iSUN\n", 697 | "Done\n" 698 | ] 699 | }, 700 | { 701 | "data": { 702 | "application/vnd.jupyter.widget-view+json": { 703 | "model_id": "326d25cb5b24408e870cd8d64513b538", 704 | "version_major": 2, 705 | "version_minor": 0 706 | }, 707 | "text/plain": [ 708 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 709 | ] 710 | }, 711 | "metadata": {}, 712 | "output_type": "display_data" 713 | }, 714 | { 715 | "name": "stdout", 716 | "output_type": "stream", 717 | "text": [ 718 | "\n", 719 | " TNR AUROC DTACC AUIN AUOUT \n", 720 | " 94.756 98.814 94.976 98.755 98.803\n", 721 | "LSUN (R)\n", 722 | "Done\n" 723 | ] 724 | }, 725 | { 726 | "data": { 727 | "application/vnd.jupyter.widget-view+json": { 728 | "model_id": "3cc2523b5d4b4d3cbecaa37e04ad9b2a", 729 | "version_major": 2, 730 | "version_minor": 0 731 | }, 732 | "text/plain": [ 733 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 734 | ] 735 | }, 736 
| "metadata": {}, 737 | "output_type": "display_data" 738 | }, 739 | { 740 | "name": "stdout", 741 | "output_type": "stream", 742 | "text": [ 743 | "\n", 744 | " TNR AUROC DTACC AUIN AUOUT \n", 745 | " 96.606 99.202 95.969 99.171 99.189\n", 746 | "LSUN (C)\n", 747 | "Done\n" 748 | ] 749 | }, 750 | { 751 | "data": { 752 | "application/vnd.jupyter.widget-view+json": { 753 | "model_id": "f5b3b69bc8a44dbebb3dcf663abb0fd0", 754 | "version_major": 2, 755 | "version_minor": 0 756 | }, 757 | "text/plain": [ 758 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 759 | ] 760 | }, 761 | "metadata": {}, 762 | "output_type": "display_data" 763 | }, 764 | { 765 | "name": "stdout", 766 | "output_type": "stream", 767 | "text": [ 768 | "\n", 769 | " TNR AUROC DTACC AUIN AUOUT \n", 770 | " 64.800 92.132 84.186 90.995 93.004\n", 771 | "TinyImgNet (R)\n", 772 | "Done\n" 773 | ] 774 | }, 775 | { 776 | "data": { 777 | "application/vnd.jupyter.widget-view+json": { 778 | "model_id": "837ccf6cbff94839b97117303d56ceb2", 779 | "version_major": 2, 780 | "version_minor": 0 781 | }, 782 | "text/plain": [ 783 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 784 | ] 785 | }, 786 | "metadata": {}, 787 | "output_type": "display_data" 788 | }, 789 | { 790 | "name": "stdout", 791 | "output_type": "stream", 792 | "text": [ 793 | "\n", 794 | " TNR AUROC DTACC AUIN AUOUT \n", 795 | " 94.789 98.898 94.960 98.771 98.956\n", 796 | "TinyImgNet (C)\n", 797 | "Done\n" 798 | ] 799 | }, 800 | { 801 | "data": { 802 | "application/vnd.jupyter.widget-view+json": { 803 | "model_id": "5afe4ee2584e4c17bb24768b1419df3d", 804 | "version_major": 2, 805 | "version_minor": 0 806 | }, 807 | "text/plain": [ 808 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 809 | ] 810 | }, 811 | "metadata": {}, 812 | "output_type": "display_data" 813 | }, 814 | { 815 | "name": "stdout", 816 | "output_type": "stream", 817 | "text": [ 818 | "\n", 819 | " TNR AUROC DTACC AUIN AUOUT \n", 820 | " 88.478 97.661 92.167 
97.392 97.891\n", 821 | "SVHN\n", 822 | "Done\n" 823 | ] 824 | }, 825 | { 826 | "data": { 827 | "application/vnd.jupyter.widget-view+json": { 828 | "model_id": "77a2d847101143f187a8376fb47f8242", 829 | "version_major": 2, 830 | "version_minor": 0 831 | }, 832 | "text/plain": [ 833 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 834 | ] 835 | }, 836 | "metadata": {}, 837 | "output_type": "display_data" 838 | }, 839 | { 840 | "name": "stdout", 841 | "output_type": "stream", 842 | "text": [ 843 | "\n", 844 | " TNR AUROC DTACC AUIN AUOUT \n", 845 | " 80.752 96.049 89.610 90.480 98.487\n", 846 | "CIFAR-10\n", 847 | "Done\n" 848 | ] 849 | }, 850 | { 851 | "data": { 852 | "application/vnd.jupyter.widget-view+json": { 853 | "model_id": "d7e39e3b74c44ca4b80409baddbe0207", 854 | "version_major": 2, 855 | "version_minor": 0 856 | }, 857 | "text/plain": [ 858 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 859 | ] 860 | }, 861 | "metadata": {}, 862 | "output_type": "display_data" 863 | }, 864 | { 865 | "name": "stdout", 866 | "output_type": "stream", 867 | "text": [ 868 | "\n", 869 | " TNR AUROC DTACC AUIN AUOUT \n", 870 | " 12.196 67.951 63.475 66.102 66.957\n" 871 | ] 872 | } 873 | ], 874 | "source": [ 875 | "def G_p(ob, p):\n", 876 | " temp = ob.detach()\n", 877 | " \n", 878 | " temp = temp**p\n", 879 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n", 880 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n", 881 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n", 882 | " \n", 883 | " return temp\n", 884 | "\n", 885 | "\n", 886 | "detector = Detector()\n", 887 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n", 888 | "\n", 889 | "detector.compute_test_deviations(POWERS=range(1,11))\n", 890 | "\n", 891 | "print(\"iSUN\")\n", 892 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n", 893 | "print(\"LSUN (R)\")\n", 894 | "lsunr_results = 
detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n", 895 | "print(\"LSUN (C)\")\n", 896 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n", 897 | "print(\"TinyImgNet (R)\")\n", 898 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n", 899 | "print(\"TinyImgNet (C)\")\n", 900 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n", 901 | "print(\"SVHN\")\n", 902 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n", 903 | "print(\"CIFAR-10\")\n", 904 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))" 905 | ] 906 | }, 907 | { 908 | "cell_type": "code", 909 | "execution_count": null, 910 | "metadata": {}, 911 | "outputs": [], 912 | "source": [] 913 | } 914 | ], 915 | "metadata": { 916 | "kernelspec": { 917 | "display_name": "Python 3", 918 | "language": "python", 919 | "name": "python3" 920 | }, 921 | "language_info": { 922 | "codemirror_mode": { 923 | "name": "ipython", 924 | "version": 3 925 | }, 926 | "file_extension": ".py", 927 | "mimetype": "text/x-python", 928 | "name": "python", 929 | "nbconvert_exporter": "python", 930 | "pygments_lexer": "ipython3", 931 | "version": "3.6.9" 932 | } 933 | }, 934 | "nbformat": 4, 935 | "nbformat_minor": 2 936 | } 937 | -------------------------------------------------------------------------------- /DenseNet_SVHN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

DenseNet: SVHN

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import division,print_function\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import sys\n", 30 | "from tqdm import tqdm_notebook as tqdm\n", 31 | "\n", 32 | "import random\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import math\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import torch.optim as optim\n", 42 | "import torch.nn.init as init\n", 43 | "from torch.autograd import Variable, grad\n", 44 | "from torchvision import datasets, transforms\n", 45 | "from torch.nn.parameter import Parameter\n", 46 | "\n", 47 | "import calculate_log as callog\n", 48 | "\n", 49 | "import warnings\n", 50 | "warnings.filterwarnings('ignore')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.set_device(2) #Select the GPU" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Model definition" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Done\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "def conv3x3(in_planes, out_planes, stride=1):\n", 84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 85 | "\n", 86 | "class BottleneckBlock(nn.Module):\n", 87 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n", 88 | " super(BottleneckBlock, self).__init__()\n", 89 | " inter_planes = out_planes * 
4\n", 90 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 91 | " self.relu = nn.ReLU(inplace=True)\n", 92 | " self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,\n", 93 | " padding=0, bias=False)\n", 94 | " self.bn2 = nn.BatchNorm2d(inter_planes)\n", 95 | " self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,\n", 96 | " padding=1, bias=False)\n", 97 | " self.droprate = dropRate\n", 98 | " \n", 99 | " def forward(self, x):\n", 100 | " \n", 101 | " out = self.conv1(self.relu(self.bn1(x)))\n", 102 | " \n", 103 | " torch_model.record(out)\n", 104 | " \n", 105 | " if self.droprate > 0:\n", 106 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 107 | " \n", 108 | " out = self.conv2(self.relu(self.bn2(out)))\n", 109 | " torch_model.record(out)\n", 110 | " \n", 111 | " if self.droprate > 0:\n", 112 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 113 | " return torch.cat([x, out], 1)\n", 114 | "\n", 115 | "class TransitionBlock(nn.Module):\n", 116 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n", 117 | " super(TransitionBlock, self).__init__()\n", 118 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 119 | " self.relu = nn.ReLU(inplace=True)\n", 120 | " self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n", 121 | " padding=0, bias=False)\n", 122 | " self.droprate = dropRate\n", 123 | " \n", 124 | " def forward(self, x):\n", 125 | " t=self.relu(self.bn1(x))\n", 126 | " out = self.conv1(t)\n", 127 | " \n", 128 | " torch_model.record(t)\n", 129 | " torch_model.record(out)\n", 130 | " \n", 131 | " if self.droprate > 0:\n", 132 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 133 | " return F.avg_pool2d(out, 2)\n", 134 | "\n", 135 | "class DenseBlock(nn.Module):\n", 136 | " def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):\n", 137 | " super(DenseBlock, self).__init__()\n", 138 | " 
self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)\n", 139 | " \n", 140 | " def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):\n", 141 | " layers = []\n", 142 | " for i in range(int(nb_layers)):\n", 143 | " layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))\n", 144 | " return nn.Sequential(*layers)\n", 145 | " \n", 146 | " def forward(self, x):\n", 147 | " t = self.layer(x)\n", 148 | " torch_model.record(t)\n", 149 | " return t\n", 150 | "\n", 151 | "\n", 152 | "class DenseNet3(nn.Module):\n", 153 | " def __init__(self, depth, num_classes, growth_rate=12,\n", 154 | " reduction=0.5, bottleneck=True, dropRate=0.0):\n", 155 | " super(DenseNet3, self).__init__()\n", 156 | " \n", 157 | " self.collecting = False\n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " in_planes = 2 * growth_rate\n", 162 | " n = (depth - 4) / 3\n", 163 | " if bottleneck == True:\n", 164 | " n = n/2\n", 165 | " block = BottleneckBlock\n", 166 | " else:\n", 167 | " block = BasicBlock\n", 168 | " # 1st conv before any dense block\n", 169 | " self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,\n", 170 | " padding=1, bias=False)\n", 171 | " # 1st block\n", 172 | " self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 173 | " in_planes = int(in_planes+n*growth_rate)\n", 174 | " self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n", 175 | " in_planes = int(math.floor(in_planes*reduction))\n", 176 | " # 2nd block\n", 177 | " self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 178 | " in_planes = int(in_planes+n*growth_rate)\n", 179 | " self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n", 180 | " in_planes = int(math.floor(in_planes*reduction))\n", 181 | " # 3rd block\n", 182 | " self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 183 | " in_planes = 
int(in_planes+n*growth_rate)\n", 184 | " # global average pooling and classifier\n", 185 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 186 | " self.relu = nn.ReLU(inplace=True)\n", 187 | " self.fc = nn.Linear(in_planes, num_classes)\n", 188 | " self.in_planes = in_planes\n", 189 | "\n", 190 | " for m in self.modules():\n", 191 | " if isinstance(m, nn.Conv2d):\n", 192 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n", 193 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n", 194 | " elif isinstance(m, nn.BatchNorm2d):\n", 195 | " m.weight.data.fill_(1)\n", 196 | " m.bias.data.zero_()\n", 197 | " elif isinstance(m, nn.Linear):\n", 198 | " m.bias.data.zero_()\n", 199 | " \n", 200 | " def forward(self, x):\n", 201 | " out = self.conv1(x)\n", 202 | " self.record(out)\n", 203 | " out = self.trans1(self.block1(out))\n", 204 | " out = self.trans2(self.block2(out))\n", 205 | " out = self.block3(out)\n", 206 | " out = self.relu(self.bn1(out))\n", 207 | " self.record(out)\n", 208 | " out = F.avg_pool2d(out, 8)\n", 209 | " out = out.view(-1, self.in_planes)\n", 210 | " return self.fc(out)\n", 211 | " \n", 212 | " def load(self, path=\"densenet_svhn.pth\"):\n", 213 | " tm = torch.load(path,map_location=\"cpu\")\n", 214 | " self.load_state_dict(tm,strict=False)\n", 215 | " \n", 216 | " def record(self, t):\n", 217 | " if self.collecting:\n", 218 | " self.gram_feats.append(t)\n", 219 | " \n", 220 | " def gram_feature_list(self,x):\n", 221 | " self.collecting = True\n", 222 | " self.gram_feats = []\n", 223 | " self.forward(x)\n", 224 | " self.collecting = False\n", 225 | " temp = self.gram_feats\n", 226 | " self.gram_feats = []\n", 227 | " return temp\n", 228 | " \n", 229 | " def get_min_max(self, data, power):\n", 230 | " mins = []\n", 231 | " maxs = []\n", 232 | " \n", 233 | " for i in range(0,len(data),64):\n", 234 | " batch = data[i:i+64].cuda()\n", 235 | " feat_list = self.gram_feature_list(batch)\n", 236 | " for L,feat_L in enumerate(feat_list):\n", 237 | " if 
L==len(mins):\n", 238 | " mins.append([None]*len(power))\n", 239 | " maxs.append([None]*len(power))\n", 240 | " \n", 241 | " for p,P in enumerate(power):\n", 242 | " g_p = G_p(feat_L,P)\n", 243 | " \n", 244 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n", 245 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n", 246 | " \n", 247 | " if mins[L][p] is None:\n", 248 | " mins[L][p] = current_min\n", 249 | " maxs[L][p] = current_max\n", 250 | " else:\n", 251 | " mins[L][p] = torch.min(current_min,mins[L][p])\n", 252 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n", 253 | " \n", 254 | " return mins,maxs\n", 255 | " \n", 256 | " def get_deviations(self,data,power,mins,maxs):\n", 257 | " deviations = []\n", 258 | " \n", 259 | " for i in range(0,len(data),64): \n", 260 | " batch = data[i:i+64].cuda()\n", 261 | " feat_list = self.gram_feature_list(batch)\n", 262 | " batch_deviations = []\n", 263 | " for L,feat_L in enumerate(feat_list):\n", 264 | " dev = 0\n", 265 | " for p,P in enumerate(power):\n", 266 | " g_p = G_p(feat_L,P)\n", 267 | " \n", 268 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 269 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 270 | " batch_deviations.append(dev.cpu().detach().numpy())\n", 271 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n", 272 | " deviations.append(batch_deviations)\n", 273 | " deviations = np.concatenate(deviations,axis=0)\n", 274 | " \n", 275 | " return deviations\n", 276 | "\n", 277 | "torch_model = DenseNet3(100, num_classes=10)\n", 278 | "torch_model.load()\n", 279 | "torch_model.cuda()\n", 280 | "torch_model.params = list(torch_model.parameters())\n", 281 | "torch_model.eval()\n", 282 | "print(\"Done\") " 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "## Datasets" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 
| "In-distribution Datasets" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 4, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "name": "stdout", 306 | "output_type": "stream", 307 | "text": [ 308 | "Using downloaded and verified file: data/train_32x32.mat\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "batch_size = 128\n", 314 | "mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T\n", 315 | "\n", 316 | "std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T\n", 317 | "normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))\n", 318 | "\n", 319 | "transform_train = transforms.Compose([\n", 320 | " transforms.RandomCrop(32, padding=4),\n", 321 | " transforms.RandomHorizontalFlip(),\n", 322 | " transforms.ToTensor(),\n", 323 | " normalize\n", 324 | " \n", 325 | " ])\n", 326 | "transform_test = transforms.Compose([\n", 327 | " transforms.CenterCrop(size=(32, 32)),\n", 328 | " transforms.ToTensor(),\n", 329 | " normalize\n", 330 | " ])\n", 331 | "\n", 332 | "train_loader = torch.utils.data.DataLoader(\n", 333 | " datasets.SVHN('data', split=\"train\", download=True,\n", 334 | " transform=transform_train),\n", 335 | " batch_size=batch_size, shuffle=True)\n", 336 | "test_loader = torch.utils.data.DataLoader(\n", 337 | " datasets.SVHN('data', split=\"test\", transform=transform_test),\n", 338 | " batch_size=batch_size)\n" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 5, 344 | "metadata": { 345 | "scrolled": true 346 | }, 347 | "outputs": [ 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "Using downloaded and verified file: data/train_32x32.mat\n" 353 | ] 354 | } 355 | ], 356 | "source": [ 357 | "data_train = list(list(torch.utils.data.DataLoader(\n", 358 | " datasets.SVHN('data', split=\"train\", download=True,\n", 359 | " transform=transform_test),\n", 360 | " batch_size=1, shuffle=True)))" 361 | ] 362 | }, 363 | { 
364 | "cell_type": "code", 365 | "execution_count": 6, 366 | "metadata": { 367 | "scrolled": true 368 | }, 369 | "outputs": [ 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "Using downloaded and verified file: data/test_32x32.mat\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "data = list(list(torch.utils.data.DataLoader(\n", 380 | " datasets.SVHN('data', split=\"test\", download=True,\n", 381 | " transform=transform_test),\n", 382 | " batch_size=1, shuffle=False)))" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 7, 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "Accuracy: 0.9637753534111863\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "torch_model.eval()\n", 400 | "correct = 0\n", 401 | "total = 0\n", 402 | "for x,y in test_loader:\n", 403 | " x = x.cuda()\n", 404 | " y = y.numpy()\n", 405 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n", 406 | " total += y.shape[0]\n", 407 | "print(\"Accuracy: \",correct/total)\n" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "Out-of-distribution Datasets" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 8, 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "Files already downloaded and verified\n" 427 | ] 428 | } 429 | ], 430 | "source": [ 431 | "cifar10 = list(torch.utils.data.DataLoader(\n", 432 | " datasets.CIFAR10('data', train=False, download=True,\n", 433 | " transform=transform_test),\n", 434 | " batch_size=1, shuffle=False))" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": 9, 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "isun = list(torch.utils.data.DataLoader(\n", 444 | " 
datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 10, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "lsun_c = list(torch.utils.data.DataLoader(\n", 454 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 11, 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "lsun_r = list(torch.utils.data.DataLoader(\n", 464 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 12, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n", 474 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 13, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n", 484 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "## Code for Detecting OODs" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | " Extract predictions for train and test data " 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 14, 504 | "metadata": {}, 505 | "outputs": [ 506 | { 507 | "name": "stdout", 508 | "output_type": "stream", 509 | "text": [ 510 | "Done\n", 511 | "Done\n" 512 | ] 513 | } 514 | ], 515 | "source": [ 516 | "train_preds = []\n", 517 | "train_confs = []\n", 518 | "train_logits = []\n", 519 | "for idx in range(0,len(data_train),128):\n", 520 | " batch = 
torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n", 521 | " \n", 522 | " logits = torch_model(batch)\n", 523 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 524 | " preds = np.argmax(confs,axis=1)\n", 525 | " logits = (logits.cpu().detach().numpy())\n", 526 | "\n", 527 | " train_confs.extend(np.max(confs,axis=1)) \n", 528 | " train_preds.extend(preds)\n", 529 | " train_logits.extend(logits)\n", 530 | "print(\"Done\")\n", 531 | "\n", 532 | "test_preds = []\n", 533 | "test_confs = []\n", 534 | "test_logits = []\n", 535 | "\n", 536 | "for idx in range(0,len(data),128):\n", 537 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n", 538 | " \n", 539 | " logits = torch_model(batch)\n", 540 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 541 | " preds = np.argmax(confs,axis=1)\n", 542 | " logits = (logits.cpu().detach().numpy())\n", 543 | "\n", 544 | " test_confs.extend(np.max(confs,axis=1)) \n", 545 | " test_preds.extend(preds)\n", 546 | " test_logits.extend(logits)\n", 547 | "print(\"Done\")" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | " Code for detecting OODs by identifying anomalies in correlations " 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 15, 560 | "metadata": {}, 561 | "outputs": [], 562 | "source": [ 563 | "import calculate_log as callog\n", 564 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n", 565 | " average_results = {}\n", 566 | " for i in range(1,11):\n", 567 | " random.seed(i)\n", 568 | " \n", 569 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n", 570 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n", 571 | "\n", 572 | " validation = all_test_deviations[validation_indices]\n", 573 | " test_deviations = 
all_test_deviations[test_indices]\n", 574 | "\n", 575 | " t95 = validation.mean(axis=0)+10**-7\n", 576 | " if not normalize:\n", 577 | " t95 = np.ones_like(t95)\n", 578 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 579 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 580 | " \n", 581 | " results = callog.compute_metric(ood_deviations,test_deviations)\n", 582 | " for m in results:\n", 583 | " average_results[m] = average_results.get(m,0)+results[m]\n", 584 | " \n", 585 | " for m in average_results:\n", 586 | " average_results[m] /= i\n", 587 | " if verbose:\n", 588 | " callog.print_results(average_results)\n", 589 | " return average_results\n", 590 | "\n", 591 | "\n", 592 | "def cpu(ob):\n", 593 | " for i in range(len(ob)):\n", 594 | " for j in range(len(ob[i])):\n", 595 | " ob[i][j] = ob[i][j].cpu()\n", 596 | " return ob\n", 597 | " \n", 598 | "def cuda(ob):\n", 599 | " for i in range(len(ob)):\n", 600 | " for j in range(len(ob[i])):\n", 601 | " ob[i][j] = ob[i][j].cuda()\n", 602 | " return ob\n", 603 | "\n", 604 | "class Detector:\n", 605 | " def __init__(self):\n", 606 | " self.all_test_deviations = None\n", 607 | " self.mins = {}\n", 608 | " self.maxs = {}\n", 609 | " self.classes = range(10)\n", 610 | " \n", 611 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n", 612 | " for PRED in tqdm(self.classes):\n", 613 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n", 614 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n", 615 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n", 616 | " self.mins[PRED] = cpu(mins)\n", 617 | " self.maxs[PRED] = cpu(maxs)\n", 618 | " torch.cuda.empty_cache()\n", 619 | " \n", 620 | " def compute_test_deviations(self,POWERS=[10]):\n", 621 | " all_test_deviations = None\n", 622 | " for PRED in tqdm(self.classes):\n", 623 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n", 624 | " test_PRED = 
torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n", 625 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n", 626 | " mins = cuda(self.mins[PRED])\n", 627 | " maxs = cuda(self.maxs[PRED])\n", 628 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n", 629 | " cpu(mins)\n", 630 | " cpu(maxs)\n", 631 | " if all_test_deviations is None:\n", 632 | " all_test_deviations = test_deviations\n", 633 | " else:\n", 634 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n", 635 | " torch.cuda.empty_cache()\n", 636 | " self.all_test_deviations = all_test_deviations\n", 637 | " \n", 638 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n", 639 | " ood_preds = []\n", 640 | " ood_confs = []\n", 641 | " \n", 642 | " for idx in range(0,len(ood),128):\n", 643 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n", 644 | " logits = torch_model(batch)\n", 645 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 646 | " preds = np.argmax(confs,axis=1)\n", 647 | " \n", 648 | " ood_confs.extend(np.max(confs,axis=1))\n", 649 | " ood_preds.extend(preds) \n", 650 | " torch.cuda.empty_cache()\n", 651 | " print(\"Done\")\n", 652 | " \n", 653 | " all_ood_deviations = None\n", 654 | " for PRED in tqdm(self.classes):\n", 655 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n", 656 | " if len(ood_indices)==0:\n", 657 | " continue\n", 658 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n", 659 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n", 660 | " mins = cuda(self.mins[PRED])\n", 661 | " maxs = cuda(self.maxs[PRED])\n", 662 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n", 663 | " cpu(self.mins[PRED])\n", 664 | " cpu(self.maxs[PRED]) \n", 665 | " if all_ood_deviations 
is None:\n", 666 | " all_ood_deviations = ood_deviations\n", 667 | " else:\n", 668 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n", 669 | " torch.cuda.empty_cache()\n", 670 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n", 671 | " return average_results, self.all_test_deviations, all_ood_deviations\n", 672 | " " 673 | ] 674 | }, 675 | { 676 | "cell_type": "markdown", 677 | "metadata": {}, 678 | "source": [ 679 | "

Results

" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 16, 685 | "metadata": { 686 | "scrolled": false 687 | }, 688 | "outputs": [ 689 | { 690 | "data": { 691 | "application/vnd.jupyter.widget-view+json": { 692 | "model_id": "b04957cc14ae4240957cf9c9dbf194be", 693 | "version_major": 2, 694 | "version_minor": 0 695 | }, 696 | "text/plain": [ 697 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 698 | ] 699 | }, 700 | "metadata": {}, 701 | "output_type": "display_data" 702 | }, 703 | { 704 | "name": "stdout", 705 | "output_type": "stream", 706 | "text": [ 707 | "\n" 708 | ] 709 | }, 710 | { 711 | "data": { 712 | "application/vnd.jupyter.widget-view+json": { 713 | "model_id": "9a717142b49e48e4abaec21798c62771", 714 | "version_major": 2, 715 | "version_minor": 0 716 | }, 717 | "text/plain": [ 718 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 719 | ] 720 | }, 721 | "metadata": {}, 722 | "output_type": "display_data" 723 | }, 724 | { 725 | "name": "stdout", 726 | "output_type": "stream", 727 | "text": [ 728 | "\n", 729 | "iSUN\n", 730 | "Done\n" 731 | ] 732 | }, 733 | { 734 | "data": { 735 | "application/vnd.jupyter.widget-view+json": { 736 | "model_id": "b6ae6ea2d1fd41e4ade8fb012369ccc1", 737 | "version_major": 2, 738 | "version_minor": 0 739 | }, 740 | "text/plain": [ 741 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 742 | ] 743 | }, 744 | "metadata": {}, 745 | "output_type": "display_data" 746 | }, 747 | { 748 | "name": "stdout", 749 | "output_type": "stream", 750 | "text": [ 751 | "\n", 752 | " TNR AUROC DTACC AUIN AUOUT \n", 753 | " 99.348 99.796 98.329 99.312 99.925\n", 754 | "LSUN (R)\n", 755 | "Done\n" 756 | ] 757 | }, 758 | { 759 | "data": { 760 | "application/vnd.jupyter.widget-view+json": { 761 | "model_id": "89195f94b7544783b00e0c1b60c065d2", 762 | "version_major": 2, 763 | "version_minor": 0 764 | }, 765 | "text/plain": [ 766 | "HBox(children=(IntProgress(value=0, max=10), 
HTML(value='')))" 767 | ] 768 | }, 769 | "metadata": {}, 770 | "output_type": "display_data" 771 | }, 772 | { 773 | "name": "stdout", 774 | "output_type": "stream", 775 | "text": [ 776 | "\n", 777 | " TNR AUROC DTACC AUIN AUOUT \n", 778 | " 99.504 99.844 98.581 99.500 99.937\n", 779 | "LSUN (C)\n", 780 | "Done\n" 781 | ] 782 | }, 783 | { 784 | "data": { 785 | "application/vnd.jupyter.widget-view+json": { 786 | "model_id": "8bf5d901b38a45d59559203a50e86fec", 787 | "version_major": 2, 788 | "version_minor": 0 789 | }, 790 | "text/plain": [ 791 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 792 | ] 793 | }, 794 | "metadata": {}, 795 | "output_type": "display_data" 796 | }, 797 | { 798 | "name": "stdout", 799 | "output_type": "stream", 800 | "text": [ 801 | "\n", 802 | " TNR AUROC DTACC AUIN AUOUT \n", 803 | " 93.325 98.585 94.326 97.113 99.114\n", 804 | "TinyImgNet (R)\n", 805 | "Done\n" 806 | ] 807 | }, 808 | { 809 | "data": { 810 | "application/vnd.jupyter.widget-view+json": { 811 | "model_id": "109d4db4dfbe4b999a3e48b7d8d6b329", 812 | "version_major": 2, 813 | "version_minor": 0 814 | }, 815 | "text/plain": [ 816 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 817 | ] 818 | }, 819 | "metadata": {}, 820 | "output_type": "display_data" 821 | }, 822 | { 823 | "name": "stdout", 824 | "output_type": "stream", 825 | "text": [ 826 | "\n", 827 | " TNR AUROC DTACC AUIN AUOUT \n", 828 | " 99.095 99.736 97.940 99.251 99.891\n", 829 | "TinyImgNet (C)\n", 830 | "Done\n" 831 | ] 832 | }, 833 | { 834 | "data": { 835 | "application/vnd.jupyter.widget-view+json": { 836 | "model_id": "8295e559036c46c3a27c988a4ace9853", 837 | "version_major": 2, 838 | "version_minor": 0 839 | }, 840 | "text/plain": [ 841 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 842 | ] 843 | }, 844 | "metadata": {}, 845 | "output_type": "display_data" 846 | }, 847 | { 848 | "name": "stdout", 849 | "output_type": "stream", 850 | "text": [ 851 | "\n", 852 | 
" TNR AUROC DTACC AUIN AUOUT \n", 853 | " 97.881 99.455 96.796 98.636 99.733\n", 854 | "CIFAR-10\n", 855 | "Done\n" 856 | ] 857 | }, 858 | { 859 | "data": { 860 | "application/vnd.jupyter.widget-view+json": { 861 | "model_id": "228d21c857644610a9942393d8539d1e", 862 | "version_major": 2, 863 | "version_minor": 0 864 | }, 865 | "text/plain": [ 866 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 867 | ] 868 | }, 869 | "metadata": {}, 870 | "output_type": "display_data" 871 | }, 872 | { 873 | "name": "stdout", 874 | "output_type": "stream", 875 | "text": [ 876 | "\n", 877 | " TNR AUROC DTACC AUIN AUOUT \n", 878 | " 80.409 95.533 89.051 89.599 97.752\n" 879 | ] 880 | } 881 | ], 882 | "source": [ 883 | "def G_p(ob, p):\n", 884 | " temp = ob.detach()\n", 885 | " \n", 886 | " temp = temp**p\n", 887 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n", 888 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n", 889 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n", 890 | " \n", 891 | " return temp\n", 892 | "\n", 893 | "detector = Detector()\n", 894 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n", 895 | "detector.compute_test_deviations(POWERS=range(1,11))\n", 896 | "\n", 897 | "print(\"iSUN\")\n", 898 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n", 899 | "print(\"LSUN (R)\")\n", 900 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n", 901 | "print(\"LSUN (C)\")\n", 902 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n", 903 | "print(\"TinyImgNet (R)\")\n", 904 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n", 905 | "print(\"TinyImgNet (C)\")\n", 906 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n", 907 | "print(\"CIFAR-10\")\n", 908 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))" 909 | ] 
910 | } 911 | ], 912 | "metadata": { 913 | "kernelspec": { 914 | "display_name": "Python 2", 915 | "language": "python", 916 | "name": "python2" 917 | }, 918 | "language_info": { 919 | "codemirror_mode": { 920 | "name": "ipython", 921 | "version": 3 922 | }, 923 | "file_extension": ".py", 924 | "mimetype": "text/x-python", 925 | "name": "python", 926 | "nbconvert_exporter": "python", 927 | "pygments_lexer": "ipython3", 928 | "version": "3.6.9" 929 | } 930 | }, 931 | "nbformat": 4, 932 | "nbformat_minor": 2 933 | } 934 | -------------------------------------------------------------------------------- /DenseNet_Cifar100.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

DenseNet: Cifar100

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import division,print_function\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import sys\n", 30 | "from tqdm import tqdm_notebook as tqdm\n", 31 | "\n", 32 | "import random\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import math\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import torch.optim as optim\n", 42 | "import torch.nn.init as init\n", 43 | "from torch.autograd import Variable, grad\n", 44 | "from torchvision import datasets, transforms\n", 45 | "from torch.nn.parameter import Parameter\n", 46 | "\n", 47 | "import calculate_log as callog\n", 48 | "\n", 49 | "import warnings\n", 50 | "warnings.filterwarnings('ignore')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.set_device(2) #Select the GPU" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Model definition" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Done\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "def conv3x3(in_planes, out_planes, stride=1):\n", 84 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 85 | "\n", 86 | "\n", 87 | "class BottleneckBlock(nn.Module):\n", 88 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n", 89 | " super(BottleneckBlock, self).__init__()\n", 90 | " inter_planes = 
out_planes * 4\n", 91 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 92 | " self.relu = nn.ReLU(inplace=True)\n", 93 | " self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,\n", 94 | " padding=0, bias=False)\n", 95 | " self.bn2 = nn.BatchNorm2d(inter_planes)\n", 96 | " self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,\n", 97 | " padding=1, bias=False)\n", 98 | " self.droprate = dropRate\n", 99 | " \n", 100 | " def forward(self, x):\n", 101 | " \n", 102 | " out = self.conv1(self.relu(self.bn1(x)))\n", 103 | " \n", 104 | " torch_model.record(out)\n", 105 | " \n", 106 | " if self.droprate > 0:\n", 107 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 108 | " \n", 109 | " out = self.conv2(self.relu(self.bn2(out)))\n", 110 | " torch_model.record(out)\n", 111 | " \n", 112 | " if self.droprate > 0:\n", 113 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 114 | " return torch.cat([x, out], 1)\n", 115 | "\n", 116 | "class TransitionBlock(nn.Module):\n", 117 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n", 118 | " super(TransitionBlock, self).__init__()\n", 119 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 120 | " self.relu = nn.ReLU(inplace=True)\n", 121 | " self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n", 122 | " padding=0, bias=False)\n", 123 | " self.droprate = dropRate\n", 124 | " \n", 125 | " def forward(self, x):\n", 126 | " out = self.conv1(self.relu(self.bn1(x)))\n", 127 | " torch_model.record(out)\n", 128 | " \n", 129 | " if self.droprate > 0:\n", 130 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 131 | " return F.avg_pool2d(out, 2)\n", 132 | "\n", 133 | "class DenseBlock(nn.Module):\n", 134 | " def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):\n", 135 | " super(DenseBlock, self).__init__()\n", 136 | " self.layer = self._make_layer(block, in_planes, 
growth_rate, nb_layers, dropRate)\n", 137 | " \n", 138 | " def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):\n", 139 | " layers = []\n", 140 | " for i in range(int(nb_layers)):\n", 141 | " layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))\n", 142 | " return nn.Sequential(*layers)\n", 143 | " \n", 144 | " def forward(self, x):\n", 145 | " t = self.layer(x)\n", 146 | " torch_model.record(t)\n", 147 | " return t\n", 148 | "\n", 149 | "\n", 150 | "class DenseNet3(nn.Module):\n", 151 | " def __init__(self, depth, num_classes, growth_rate=12,\n", 152 | " reduction=0.5, bottleneck=True, dropRate=0.0):\n", 153 | " super(DenseNet3, self).__init__()\n", 154 | " \n", 155 | " self.collecting = False\n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " in_planes = 2 * growth_rate\n", 160 | " n = (depth - 4) / 3\n", 161 | " if bottleneck == True:\n", 162 | " n = n/2\n", 163 | " block = BottleneckBlock\n", 164 | " else:\n", 165 | " block = BasicBlock\n", 166 | " # 1st conv before any dense block\n", 167 | " self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,\n", 168 | " padding=1, bias=False)\n", 169 | " # 1st block\n", 170 | " self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 171 | " in_planes = int(in_planes+n*growth_rate)\n", 172 | " self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n", 173 | " in_planes = int(math.floor(in_planes*reduction))\n", 174 | " # 2nd block\n", 175 | " self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 176 | " in_planes = int(in_planes+n*growth_rate)\n", 177 | " self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n", 178 | " in_planes = int(math.floor(in_planes*reduction))\n", 179 | " # 3rd block\n", 180 | " self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 181 | " in_planes = int(in_planes+n*growth_rate)\n", 182 | " # global average pooling 
and classifier\n", 183 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 184 | " self.relu = nn.ReLU(inplace=True)\n", 185 | " self.fc = nn.Linear(in_planes, num_classes)\n", 186 | " self.in_planes = in_planes\n", 187 | "\n", 188 | " for m in self.modules():\n", 189 | " if isinstance(m, nn.Conv2d):\n", 190 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n", 191 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n", 192 | " elif isinstance(m, nn.BatchNorm2d):\n", 193 | " m.weight.data.fill_(1)\n", 194 | " m.bias.data.zero_()\n", 195 | " elif isinstance(m, nn.Linear):\n", 196 | " m.bias.data.zero_()\n", 197 | " \n", 198 | " def forward(self, x):\n", 199 | " out = self.conv1(x)\n", 200 | " out = self.trans1(self.block1(out))\n", 201 | " out = self.trans2(self.block2(out))\n", 202 | " out = self.block3(out)\n", 203 | " out = self.relu(self.bn1(out))\n", 204 | " out = F.avg_pool2d(out, 8)\n", 205 | " out = out.view(-1, self.in_planes)\n", 206 | " return self.fc(out)\n", 207 | " \n", 208 | " def load(self, path=\"densenet_cifar100.pth\"):\n", 209 | " tm = torch.load(path,map_location=\"cpu\")\n", 210 | " self.load_state_dict(tm.state_dict(),strict=False)\n", 211 | " \n", 212 | " def record(self, t):\n", 213 | " if self.collecting:\n", 214 | " self.gram_feats.append(t)\n", 215 | " \n", 216 | " def gram_feature_list(self,x):\n", 217 | " self.collecting = True\n", 218 | " self.gram_feats = []\n", 219 | " self.forward(x)\n", 220 | " self.collecting = False\n", 221 | " temp = self.gram_feats\n", 222 | " self.gram_feats = []\n", 223 | " return temp\n", 224 | " \n", 225 | " def get_min_max(self, data, power):\n", 226 | " mins = []\n", 227 | " maxs = []\n", 228 | " \n", 229 | " for i in range(0,len(data),64):\n", 230 | " batch = data[i:i+64].cuda()\n", 231 | " feat_list = self.gram_feature_list(batch)\n", 232 | " for L,feat_L in enumerate(feat_list):\n", 233 | " if L==len(mins):\n", 234 | " mins.append([None]*len(power))\n", 235 | " maxs.append([None]*len(power))\n", 236 | 
" \n", 237 | " for p,P in enumerate(power):\n", 238 | " g_p = G_p(feat_L,P)\n", 239 | " \n", 240 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n", 241 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n", 242 | " \n", 243 | " if mins[L][p] is None:\n", 244 | " mins[L][p] = current_min\n", 245 | " maxs[L][p] = current_max\n", 246 | " else:\n", 247 | " mins[L][p] = torch.min(current_min,mins[L][p])\n", 248 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n", 249 | " \n", 250 | " return mins,maxs\n", 251 | " \n", 252 | " def get_deviations(self,data,power,mins,maxs):\n", 253 | " deviations = []\n", 254 | " \n", 255 | " for i in range(0,len(data),64): \n", 256 | " batch = data[i:i+64].cuda()\n", 257 | " feat_list = self.gram_feature_list(batch)\n", 258 | " batch_deviations = []\n", 259 | " for L,feat_L in enumerate(feat_list):\n", 260 | " dev = 0\n", 261 | " for p,P in enumerate(power):\n", 262 | " g_p = G_p(feat_L,P)\n", 263 | " \n", 264 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 265 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 266 | " batch_deviations.append(dev.cpu().detach().numpy())\n", 267 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n", 268 | " deviations.append(batch_deviations)\n", 269 | " deviations = np.concatenate(deviations,axis=0)\n", 270 | " \n", 271 | " return deviations\n", 272 | "\n", 273 | "torch_model = DenseNet3(100, num_classes=100)\n", 274 | "torch_model.load()\n", 275 | "torch_model.cuda()\n", 276 | "torch_model.params = list(torch_model.parameters())\n", 277 | "torch_model.eval()\n", 278 | "print(\"Done\") " 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "## Datasets" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "In-distribution Datasets" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 4, 
298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "Files already downloaded and verified\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "batch_size = 128\n", 310 | "mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T\n", 311 | "\n", 312 | "std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T\n", 313 | "normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))\n", 314 | "\n", 315 | "transform_train = transforms.Compose([\n", 316 | " transforms.RandomCrop(32, padding=4),\n", 317 | " transforms.RandomHorizontalFlip(),\n", 318 | " transforms.ToTensor(),\n", 319 | " normalize\n", 320 | " \n", 321 | " ])\n", 322 | "transform_test = transforms.Compose([\n", 323 | " transforms.CenterCrop(size=(32, 32)),\n", 324 | " transforms.ToTensor(),\n", 325 | " normalize\n", 326 | " ])\n", 327 | "\n", 328 | "train_loader = torch.utils.data.DataLoader(\n", 329 | " datasets.CIFAR100('data', train=True, download=True,\n", 330 | " transform=transform_train),\n", 331 | " batch_size=batch_size, shuffle=True)\n", 332 | "test_loader = torch.utils.data.DataLoader(\n", 333 | " datasets.CIFAR100('data', train=False, transform=transform_test),\n", 334 | " batch_size=batch_size)\n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 5, 340 | "metadata": { 341 | "scrolled": true 342 | }, 343 | "outputs": [ 344 | { 345 | "name": "stdout", 346 | "output_type": "stream", 347 | "text": [ 348 | "Files already downloaded and verified\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "data_train = list(torch.utils.data.DataLoader(\n", 354 | " datasets.CIFAR100('data', train=True, download=True,\n", 355 | " transform=transform_test),\n", 356 | " batch_size=1, shuffle=False))" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 6, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": 
"stream", 367 | "text": [ 368 | "Files already downloaded and verified\n" 369 | ] 370 | } 371 | ], 372 | "source": [ 373 | "data = list(torch.utils.data.DataLoader(\n", 374 | " datasets.CIFAR100('data', train=False, download=True,\n", 375 | " transform=transform_test),\n", 376 | " batch_size=1, shuffle=False))" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 7, 382 | "metadata": {}, 383 | "outputs": [ 384 | { 385 | "name": "stdout", 386 | "output_type": "stream", 387 | "text": [ 388 | "Accuracy: 0.7763\n" 389 | ] 390 | } 391 | ], 392 | "source": [ 393 | "torch_model.eval()\n", 394 | "correct = 0\n", 395 | "total = 0\n", 396 | "for x,y in test_loader:\n", 397 | " x = x.cuda()\n", 398 | " y = y.numpy()\n", 399 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n", 400 | " total += y.shape[0]\n", 401 | "print(\"Accuracy: \",correct/total)\n" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "Out-of-distribution Datasets" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 8, 414 | "metadata": {}, 415 | "outputs": [ 416 | { 417 | "name": "stdout", 418 | "output_type": "stream", 419 | "text": [ 420 | "Files already downloaded and verified\n" 421 | ] 422 | } 423 | ], 424 | "source": [ 425 | "cifar10 = list(torch.utils.data.DataLoader(\n", 426 | " datasets.CIFAR10('data', train=False, download=True,\n", 427 | " transform=transform_test),\n", 428 | " batch_size=1, shuffle=True))" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 9, 434 | "metadata": {}, 435 | "outputs": [ 436 | { 437 | "name": "stdout", 438 | "output_type": "stream", 439 | "text": [ 440 | "Using downloaded and verified file: data/test_32x32.mat\n" 441 | ] 442 | } 443 | ], 444 | "source": [ 445 | "svhn = list(torch.utils.data.DataLoader(\n", 446 | " datasets.SVHN('data', split=\"test\", download=True,\n", 447 | " transform=transform_test),\n", 
448 | " batch_size=1, shuffle=True))" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 10, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "isun = list(torch.utils.data.DataLoader(\n", 458 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 11, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "lsun_c = list(torch.utils.data.DataLoader(\n", 468 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 12, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "lsun_r = list(torch.utils.data.DataLoader(\n", 478 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 13, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n", 488 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 14, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n", 498 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": {}, 504 | "source": [ 505 | "## Code for Detecting OODs" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": {}, 511 | "source": [ 512 | " Extract predictions for train and test data " 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 15, 518 | "metadata": {}, 519 | "outputs": [ 520 | { 521 | "name": "stdout", 522 | "output_type": "stream", 523 | 
"text": [ 524 | "Done\n", 525 | "Done\n" 526 | ] 527 | } 528 | ], 529 | "source": [ 530 | "train_preds = []\n", 531 | "train_confs = []\n", 532 | "train_logits = []\n", 533 | "for idx in range(0,len(data_train),128):\n", 534 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n", 535 | " \n", 536 | " logits = torch_model(batch)\n", 537 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 538 | " preds = np.argmax(confs,axis=1)\n", 539 | " logits = (logits.cpu().detach().numpy())\n", 540 | " \n", 541 | " train_confs.extend(np.max(confs,axis=1)) \n", 542 | " train_preds.extend(preds)\n", 543 | " train_logits.extend(logits)\n", 544 | "print(\"Done\")\n", 545 | "\n", 546 | "test_preds = []\n", 547 | "test_confs = []\n", 548 | "test_logits = []\n", 549 | "\n", 550 | "for idx in range(0,len(data),128):\n", 551 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n", 552 | " \n", 553 | " logits = torch_model(batch)\n", 554 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 555 | " preds = np.argmax(confs,axis=1)\n", 556 | " logits = (logits.cpu().detach().numpy())\n", 557 | "\n", 558 | " test_confs.extend(np.max(confs,axis=1)) \n", 559 | " test_preds.extend(preds)\n", 560 | " test_logits.extend(logits)\n", 561 | "print(\"Done\")" 562 | ] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "metadata": {}, 567 | "source": [ 568 | " Code for detecting OODs by identifying anomalies in correlations " 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 16, 574 | "metadata": {}, 575 | "outputs": [], 576 | "source": [ 577 | "import calculate_log as callog\n", 578 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n", 579 | " average_results = {}\n", 580 | " for i in range(1,11):\n", 581 | " random.seed(i)\n", 582 | " \n", 583 | " validation_indices = 
random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n", 584 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n", 585 | "\n", 586 | " validation = all_test_deviations[validation_indices]\n", 587 | " test_deviations = all_test_deviations[test_indices]\n", 588 | "\n", 589 | " t95 = validation.mean(axis=0)+10**-7\n", 590 | " if not normalize:\n", 591 | " t95 = np.ones_like(t95)\n", 592 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 593 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 594 | " \n", 595 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n", 596 | " for m in results:\n", 597 | " average_results[m] = average_results.get(m,0)+results[m]\n", 598 | " \n", 599 | " for m in average_results:\n", 600 | " average_results[m] /= i\n", 601 | " if verbose:\n", 602 | " callog.print_results(average_results)\n", 603 | " return average_results\n", 604 | "\n", 605 | "\n", 606 | "def cpu(ob):\n", 607 | " for i in range(len(ob)):\n", 608 | " for j in range(len(ob[i])):\n", 609 | " ob[i][j] = ob[i][j].cpu()\n", 610 | " return ob\n", 611 | " \n", 612 | "def cuda(ob):\n", 613 | " for i in range(len(ob)):\n", 614 | " for j in range(len(ob[i])):\n", 615 | " ob[i][j] = ob[i][j].cuda()\n", 616 | " return ob\n", 617 | "\n", 618 | "class Detector:\n", 619 | " def __init__(self):\n", 620 | " self.all_test_deviations = None\n", 621 | " self.mins = {}\n", 622 | " self.maxs = {}\n", 623 | " self.classes = range(100)\n", 624 | " \n", 625 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n", 626 | " for PRED in tqdm(self.classes):\n", 627 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n", 628 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n", 629 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n", 630 | " self.mins[PRED] = cpu(mins)\n", 631 | " self.maxs[PRED] = 
cpu(maxs)\n", 632 | " torch.cuda.empty_cache()\n", 633 | " \n", 634 | " def compute_test_deviations(self,POWERS=[10]):\n", 635 | " all_test_deviations = None\n", 636 | " for PRED in tqdm(self.classes):\n", 637 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n", 638 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n", 639 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n", 640 | " mins = cuda(self.mins[PRED])\n", 641 | " maxs = cuda(self.maxs[PRED])\n", 642 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n", 643 | " cpu(mins)\n", 644 | " cpu(maxs)\n", 645 | " if all_test_deviations is None:\n", 646 | " all_test_deviations = test_deviations\n", 647 | " else:\n", 648 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n", 649 | " torch.cuda.empty_cache()\n", 650 | " self.all_test_deviations = all_test_deviations\n", 651 | " \n", 652 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n", 653 | " ood_preds = []\n", 654 | " ood_confs = []\n", 655 | " \n", 656 | " for idx in range(0,len(ood),128):\n", 657 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n", 658 | " logits = torch_model(batch)\n", 659 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 660 | " preds = np.argmax(confs,axis=1)\n", 661 | " \n", 662 | " ood_confs.extend(np.max(confs,axis=1))\n", 663 | " ood_preds.extend(preds) \n", 664 | " torch.cuda.empty_cache()\n", 665 | " print(\"Done\")\n", 666 | " \n", 667 | " all_ood_deviations = None\n", 668 | " for PRED in tqdm(self.classes):\n", 669 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n", 670 | " if len(ood_indices)==0:\n", 671 | " continue\n", 672 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n", 673 | " ood_confs_PRED = np.array([ood_confs[i] for i in ood_indices])\n", 674 | " mins = 
cuda(self.mins[PRED])\n", 675 | " maxs = cuda(self.maxs[PRED])\n", 676 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n", 677 | " cpu(self.mins[PRED])\n", 678 | " cpu(self.maxs[PRED]) \n", 679 | " if all_ood_deviations is None:\n", 680 | " all_ood_deviations = ood_deviations\n", 681 | " else:\n", 682 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n", 683 | " torch.cuda.empty_cache()\n", 684 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n", 685 | " return average_results, self.all_test_deviations, all_ood_deviations\n" 686 | ] 687 | }, 688 | { 689 | "cell_type": "markdown", 690 | "metadata": {}, 691 | "source": [ 692 | "

Results

" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 17, 698 | "metadata": {}, 699 | "outputs": [ 700 | { 701 | "data": { 702 | "application/vnd.jupyter.widget-view+json": { 703 | "model_id": "6d38db591504460ea13535698dea877c", 704 | "version_major": 2, 705 | "version_minor": 0 706 | }, 707 | "text/plain": [ 708 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 709 | ] 710 | }, 711 | "metadata": {}, 712 | "output_type": "display_data" 713 | }, 714 | { 715 | "name": "stdout", 716 | "output_type": "stream", 717 | "text": [ 718 | "\n" 719 | ] 720 | }, 721 | { 722 | "data": { 723 | "application/vnd.jupyter.widget-view+json": { 724 | "model_id": "ffae25d812914b6c95aa3c1c7fd392a6", 725 | "version_major": 2, 726 | "version_minor": 0 727 | }, 728 | "text/plain": [ 729 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 730 | ] 731 | }, 732 | "metadata": {}, 733 | "output_type": "display_data" 734 | }, 735 | { 736 | "name": "stdout", 737 | "output_type": "stream", 738 | "text": [ 739 | "\n", 740 | "iSUN\n", 741 | "Done\n" 742 | ] 743 | }, 744 | { 745 | "data": { 746 | "application/vnd.jupyter.widget-view+json": { 747 | "model_id": "716da00211014f8988f1c9937cf1d35c", 748 | "version_major": 2, 749 | "version_minor": 0 750 | }, 751 | "text/plain": [ 752 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 753 | ] 754 | }, 755 | "metadata": {}, 756 | "output_type": "display_data" 757 | }, 758 | { 759 | "name": "stdout", 760 | "output_type": "stream", 761 | "text": [ 762 | "\n", 763 | " TNR AUROC DTACC AUIN AUOUT \n", 764 | " 95.867 99.042 95.632 98.990 99.083\n", 765 | "LSUN (R)\n", 766 | "Done\n" 767 | ] 768 | }, 769 | { 770 | "data": { 771 | "application/vnd.jupyter.widget-view+json": { 772 | "model_id": "0397655250134e9d853d8598410ced8b", 773 | "version_major": 2, 774 | "version_minor": 0 775 | }, 776 | "text/plain": [ 777 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 778 | ] 779 | }, 780 | "metadata": {}, 781 | 
"output_type": "display_data" 782 | }, 783 | { 784 | "name": "stdout", 785 | "output_type": "stream", 786 | "text": [ 787 | "\n", 788 | " TNR AUROC DTACC AUIN AUOUT \n", 789 | " 97.234 99.349 96.393 99.264 99.407\n", 790 | "LSUN (C)\n", 791 | "Done\n" 792 | ] 793 | }, 794 | { 795 | "data": { 796 | "application/vnd.jupyter.widget-view+json": { 797 | "model_id": "36483d6fcf034a8887df000863cb187a", 798 | "version_major": 2, 799 | "version_minor": 0 800 | }, 801 | "text/plain": [ 802 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 803 | ] 804 | }, 805 | "metadata": {}, 806 | "output_type": "display_data" 807 | }, 808 | { 809 | "name": "stdout", 810 | "output_type": "stream", 811 | "text": [ 812 | "\n", 813 | " TNR AUROC DTACC AUIN AUOUT \n", 814 | " 65.517 91.391 83.644 89.553 92.749\n", 815 | "TinyImgNet (R)\n", 816 | "Done\n" 817 | ] 818 | }, 819 | { 820 | "data": { 821 | "application/vnd.jupyter.widget-view+json": { 822 | "model_id": "9aa5b886321c4ba793127afe191bb73c", 823 | "version_major": 2, 824 | "version_minor": 0 825 | }, 826 | "text/plain": [ 827 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 828 | ] 829 | }, 830 | "metadata": {}, 831 | "output_type": "display_data" 832 | }, 833 | { 834 | "name": "stdout", 835 | "output_type": "stream", 836 | "text": [ 837 | "\n", 838 | " TNR AUROC DTACC AUIN AUOUT \n", 839 | " 95.748 98.973 95.522 98.768 99.126\n", 840 | "TinyImgNet (C)\n", 841 | "Done\n" 842 | ] 843 | }, 844 | { 845 | "data": { 846 | "application/vnd.jupyter.widget-view+json": { 847 | "model_id": "1ab175d00f4e44688f226557051e413a", 848 | "version_major": 2, 849 | "version_minor": 0 850 | }, 851 | "text/plain": [ 852 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 853 | ] 854 | }, 855 | "metadata": {}, 856 | "output_type": "display_data" 857 | }, 858 | { 859 | "name": "stdout", 860 | "output_type": "stream", 861 | "text": [ 862 | "\n", 863 | " TNR AUROC DTACC AUIN AUOUT \n", 864 | " 89.013 97.687 92.452 97.303 98.018\n", 865 | 
"SVHN\n", 866 | "Done\n" 867 | ] 868 | }, 869 | { 870 | "data": { 871 | "application/vnd.jupyter.widget-view+json": { 872 | "model_id": "999c654468da4044bae289c360b2de38", 873 | "version_major": 2, 874 | "version_minor": 0 875 | }, 876 | "text/plain": [ 877 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 878 | ] 879 | }, 880 | "metadata": {}, 881 | "output_type": "display_data" 882 | }, 883 | { 884 | "name": "stdout", 885 | "output_type": "stream", 886 | "text": [ 887 | "\n", 888 | " TNR AUROC DTACC AUIN AUOUT \n", 889 | " 89.341 97.322 92.364 91.652 99.080\n", 890 | "CIFAR-10\n", 891 | "Done\n" 892 | ] 893 | }, 894 | { 895 | "data": { 896 | "application/vnd.jupyter.widget-view+json": { 897 | "model_id": "ebf38974d3bf4350aa44ea06f629e369", 898 | "version_major": 2, 899 | "version_minor": 0 900 | }, 901 | "text/plain": [ 902 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 903 | ] 904 | }, 905 | "metadata": {}, 906 | "output_type": "display_data" 907 | }, 908 | { 909 | "name": "stdout", 910 | "output_type": "stream", 911 | "text": [ 912 | "\n", 913 | " TNR AUROC DTACC AUIN AUOUT \n", 914 | " 10.596 64.227 60.404 61.350 64.092\n" 915 | ] 916 | } 917 | ], 918 | "source": [ 919 | "def G_p(ob, p):\n", 920 | " temp = ob.detach()\n", 921 | " \n", 922 | " temp = temp**p\n", 923 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n", 924 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n", 925 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n", 926 | " \n", 927 | " return temp\n", 928 | "\n", 929 | "detector = Detector()\n", 930 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n", 931 | "\n", 932 | "detector.compute_test_deviations(POWERS=range(1,11))\n", 933 | "\n", 934 | "print(\"iSUN\")\n", 935 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n", 936 | "print(\"LSUN (R)\")\n", 937 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n", 
938 | "print(\"LSUN (C)\")\n", 939 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n", 940 | "print(\"TinyImgNet (R)\")\n", 941 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n", 942 | "print(\"TinyImgNet (C)\")\n", 943 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n", 944 | "print(\"SVHN\")\n", 945 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n", 946 | "print(\"CIFAR-10\")\n", 947 | "c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))" 948 | ] 949 | } 950 | ], 951 | "metadata": { 952 | "kernelspec": { 953 | "display_name": "Python 2", 954 | "language": "python", 955 | "name": "python2" 956 | }, 957 | "language_info": { 958 | "codemirror_mode": { 959 | "name": "ipython", 960 | "version": 3 961 | }, 962 | "file_extension": ".py", 963 | "mimetype": "text/x-python", 964 | "name": "python", 965 | "nbconvert_exporter": "python", 966 | "pygments_lexer": "ipython3", 967 | "version": "3.6.9" 968 | } 969 | }, 970 | "nbformat": 4, 971 | "nbformat_minor": 2 972 | } 973 | -------------------------------------------------------------------------------- /DenseNet_Cifar10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

DenseNet: Cifar10

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import division,print_function\n", 24 | "\n", 25 | "%matplotlib inline\n", 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import sys\n", 30 | "from tqdm import tqdm_notebook as tqdm\n", 31 | "\n", 32 | "import random\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import math\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import torch\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "import torch.optim as optim\n", 42 | "import torch.nn.init as init\n", 43 | "from torch.autograd import Variable, grad\n", 44 | "from torchvision import datasets, transforms\n", 45 | "from torch.nn.parameter import Parameter\n", 46 | "\n", 47 | "import calculate_log as callog\n", 48 | "\n", 49 | "import warnings\n", 50 | "warnings.filterwarnings('ignore')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.set_device(0) #Select the GPU" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Model definition" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": { 73 | "scrolled": true 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "Done\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "def conv3x3(in_planes, out_planes, stride=1):\n", 86 | " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", 87 | "\n", 88 | "class BottleneckBlock(nn.Module):\n", 89 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n", 90 | " super(BottleneckBlock, self).__init__()\n", 91 | " 
inter_planes = out_planes * 4\n", 92 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 93 | " self.relu = nn.ReLU(inplace=True)\n", 94 | " self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,\n", 95 | " padding=0, bias=False)\n", 96 | " self.bn2 = nn.BatchNorm2d(inter_planes)\n", 97 | " self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,\n", 98 | " padding=1, bias=False)\n", 99 | " self.droprate = dropRate\n", 100 | " \n", 101 | " def forward(self, x):\n", 102 | " \n", 103 | " out = self.conv1(self.relu(self.bn1(x)))\n", 104 | " \n", 105 | " torch_model.record(out)\n", 106 | " \n", 107 | " if self.droprate > 0:\n", 108 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 109 | " \n", 110 | " out = self.conv2(self.relu(self.bn2(out)))\n", 111 | " torch_model.record(out)\n", 112 | " \n", 113 | " if self.droprate > 0:\n", 114 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 115 | " return torch.cat([x, out], 1)\n", 116 | "\n", 117 | "class TransitionBlock(nn.Module):\n", 118 | " def __init__(self, in_planes, out_planes, dropRate=0.0):\n", 119 | " super(TransitionBlock, self).__init__()\n", 120 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 121 | " self.relu = nn.ReLU(inplace=True)\n", 122 | " self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n", 123 | " padding=0, bias=False)\n", 124 | " self.droprate = dropRate\n", 125 | " \n", 126 | " def forward(self, x):\n", 127 | " out = self.conv1(self.relu(self.bn1(x)))\n", 128 | " torch_model.record(out)\n", 129 | " \n", 130 | " if self.droprate > 0:\n", 131 | " out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)\n", 132 | " return F.avg_pool2d(out, 2)\n", 133 | "\n", 134 | "class DenseBlock(nn.Module):\n", 135 | " def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):\n", 136 | " super(DenseBlock, self).__init__()\n", 137 | " self.layer = 
self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)\n", 138 | " \n", 139 | " def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):\n", 140 | " layers = []\n", 141 | " for i in range(int(nb_layers)):\n", 142 | " layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))\n", 143 | " return nn.Sequential(*layers)\n", 144 | " \n", 145 | " def forward(self, x):\n", 146 | " t = self.layer(x)\n", 147 | " torch_model.record(t)\n", 148 | " return t\n", 149 | "\n", 150 | "\n", 151 | "class DenseNet3(nn.Module):\n", 152 | " def __init__(self, depth, num_classes, growth_rate=12,\n", 153 | " reduction=0.5, bottleneck=True, dropRate=0.0):\n", 154 | " super(DenseNet3, self).__init__()\n", 155 | " \n", 156 | " self.collecting = False\n", 157 | " \n", 158 | " in_planes = 2 * growth_rate\n", 159 | " n = (depth - 4) / 3\n", 160 | " if bottleneck == True:\n", 161 | " n = n/2\n", 162 | " block = BottleneckBlock\n", 163 | " else:\n", 164 | " block = BasicBlock\n", 165 | " # 1st conv before any dense block\n", 166 | " self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,\n", 167 | " padding=1, bias=False)\n", 168 | " # 1st block\n", 169 | " self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 170 | " in_planes = int(in_planes+n*growth_rate)\n", 171 | " self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n", 172 | " in_planes = int(math.floor(in_planes*reduction))\n", 173 | " # 2nd block\n", 174 | " self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 175 | " in_planes = int(in_planes+n*growth_rate)\n", 176 | " self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)\n", 177 | " in_planes = int(math.floor(in_planes*reduction))\n", 178 | " # 3rd block\n", 179 | " self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)\n", 180 | " in_planes = int(in_planes+n*growth_rate)\n", 181 | " # global 
average pooling and classifier\n", 182 | " self.bn1 = nn.BatchNorm2d(in_planes)\n", 183 | " self.relu = nn.ReLU(inplace=True)\n", 184 | " self.fc = nn.Linear(in_planes, num_classes)\n", 185 | " self.in_planes = in_planes\n", 186 | "\n", 187 | " for m in self.modules():\n", 188 | " if isinstance(m, nn.Conv2d):\n", 189 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n", 190 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n", 191 | " elif isinstance(m, nn.BatchNorm2d):\n", 192 | " m.weight.data.fill_(1)\n", 193 | " m.bias.data.zero_()\n", 194 | " elif isinstance(m, nn.Linear):\n", 195 | " m.bias.data.zero_()\n", 196 | " \n", 197 | " def forward(self, x):\n", 198 | " out = self.conv1(x)\n", 199 | " out = self.trans1(self.block1(out))\n", 200 | " out = self.trans2(self.block2(out))\n", 201 | " out = self.block3(out)\n", 202 | " out = self.relu(self.bn1(out))\n", 203 | " out = F.avg_pool2d(out, 8)\n", 204 | " out = out.view(-1, self.in_planes)\n", 205 | " return self.fc(out)\n", 206 | " \n", 207 | " def load(self, path=\"densenet_cifar10.pth\"):\n", 208 | " tm = torch.load(path,map_location=\"cpu\")\n", 209 | " self.load_state_dict(tm.state_dict(),strict=False)\n", 210 | " \n", 211 | " def record(self, t):\n", 212 | " if self.collecting:\n", 213 | " self.gram_feats.append(t)\n", 214 | " \n", 215 | " def gram_feature_list(self,x):\n", 216 | " self.collecting = True\n", 217 | " self.gram_feats = []\n", 218 | " self.forward(x)\n", 219 | " self.collecting = False\n", 220 | " temp = self.gram_feats\n", 221 | " self.gram_feats = []\n", 222 | " return temp\n", 223 | " \n", 224 | " def get_min_max(self, data, power):\n", 225 | " mins = []\n", 226 | " maxs = []\n", 227 | " \n", 228 | " for i in range(0,len(data),64):\n", 229 | " batch = data[i:i+64].cuda()\n", 230 | " feat_list = self.gram_feature_list(batch)\n", 231 | " for L,feat_L in enumerate(feat_list):\n", 232 | " if L==len(mins):\n", 233 | " mins.append([None]*len(power))\n", 234 | " 
maxs.append([None]*len(power))\n", 235 | " \n", 236 | " for p,P in enumerate(power):\n", 237 | " g_p = G_p(feat_L,P)\n", 238 | " \n", 239 | " current_min = g_p.min(dim=0,keepdim=True)[0]\n", 240 | " current_max = g_p.max(dim=0,keepdim=True)[0]\n", 241 | " \n", 242 | " if mins[L][p] is None:\n", 243 | " mins[L][p] = current_min\n", 244 | " maxs[L][p] = current_max\n", 245 | " else:\n", 246 | " mins[L][p] = torch.min(current_min,mins[L][p])\n", 247 | " maxs[L][p] = torch.max(current_max,maxs[L][p])\n", 248 | " \n", 249 | " return mins,maxs\n", 250 | " \n", 251 | " def get_deviations(self,data,power,mins,maxs):\n", 252 | " deviations = []\n", 253 | " \n", 254 | " for i in range(0,len(data),64): \n", 255 | " batch = data[i:i+64].cuda()\n", 256 | " feat_list = self.gram_feature_list(batch)\n", 257 | " batch_deviations = []\n", 258 | " for L,feat_L in enumerate(feat_list):\n", 259 | " dev = 0\n", 260 | " for p,P in enumerate(power):\n", 261 | " g_p = G_p(feat_L,P)\n", 262 | " \n", 263 | " dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 264 | " dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)\n", 265 | " batch_deviations.append(dev.cpu().detach().numpy())\n", 266 | " batch_deviations = np.concatenate(batch_deviations,axis=1)\n", 267 | " deviations.append(batch_deviations)\n", 268 | " deviations = np.concatenate(deviations,axis=0)\n", 269 | " \n", 270 | " return deviations\n", 271 | "\n", 272 | "\n", 273 | "torch_model = DenseNet3(100, num_classes=10)\n", 274 | "torch_model.load()\n", 275 | "torch_model.cuda()\n", 276 | "torch_model.params = list(torch_model.parameters())\n", 277 | "torch_model.eval()\n", 278 | "print(\"Done\") " 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "## Datasets" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "In-distribution Datasets" 293 | ] 294 | }, 295 | { 
296 | "cell_type": "code", 297 | "execution_count": 4, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "Files already downloaded and verified\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "batch_size = 128\n", 310 | "mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T\n", 311 | "\n", 312 | "std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T\n", 313 | "normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))\n", 314 | "\n", 315 | "transform_train = transforms.Compose([\n", 316 | " transforms.RandomCrop(32, padding=4),\n", 317 | " transforms.RandomHorizontalFlip(),\n", 318 | " transforms.ToTensor(),\n", 319 | " normalize\n", 320 | " \n", 321 | " ])\n", 322 | "transform_test = transforms.Compose([\n", 323 | " transforms.CenterCrop(size=(32, 32)),\n", 324 | " transforms.ToTensor(),\n", 325 | " normalize\n", 326 | " ])\n", 327 | "\n", 328 | "train_loader = torch.utils.data.DataLoader(\n", 329 | " datasets.CIFAR10('data', train=True, download=True,\n", 330 | " transform=transform_train),\n", 331 | " batch_size=batch_size, shuffle=True)\n", 332 | "test_loader = torch.utils.data.DataLoader(\n", 333 | " datasets.CIFAR10('data', train=False, transform=transform_test),\n", 334 | " batch_size=batch_size)\n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 5, 340 | "metadata": { 341 | "scrolled": true 342 | }, 343 | "outputs": [ 344 | { 345 | "name": "stdout", 346 | "output_type": "stream", 347 | "text": [ 348 | "Files already downloaded and verified\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "data_train = list(torch.utils.data.DataLoader(\n", 354 | " datasets.CIFAR10('data', train=True, download=True,\n", 355 | " transform=transform_test),\n", 356 | " batch_size=1, shuffle=False))" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 6, 362 | "metadata": { 363 | "scrolled": true 
364 | }, 365 | "outputs": [ 366 | { 367 | "name": "stdout", 368 | "output_type": "stream", 369 | "text": [ 370 | "Files already downloaded and verified\n" 371 | ] 372 | } 373 | ], 374 | "source": [ 375 | "data = list(torch.utils.data.DataLoader(\n", 376 | " datasets.CIFAR10('data', train=False, download=True,\n", 377 | " transform=transform_test),\n", 378 | " batch_size=1, shuffle=False))" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 7, 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "Accuracy: 0.9519\n" 391 | ] 392 | } 393 | ], 394 | "source": [ 395 | "torch_model.eval()\n", 396 | "correct = 0\n", 397 | "total = 0\n", 398 | "for x,y in test_loader:\n", 399 | " x = x.cuda()\n", 400 | " y = y.numpy()\n", 401 | " correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()\n", 402 | " total += y.shape[0]\n", 403 | "print(\"Accuracy: \",correct/total)\n" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": {}, 409 | "source": [ 410 | "Out-of-distribution Datasets" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 8, 416 | "metadata": {}, 417 | "outputs": [ 418 | { 419 | "name": "stdout", 420 | "output_type": "stream", 421 | "text": [ 422 | "Files already downloaded and verified\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "cifar100 = list(torch.utils.data.DataLoader(\n", 428 | " datasets.CIFAR100('data', train=False, download=True,\n", 429 | " transform=transform_test),\n", 430 | " batch_size=1, shuffle=False))" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 9, 436 | "metadata": {}, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "Using downloaded and verified file: data/test_32x32.mat\n" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "svhn = list(torch.utils.data.DataLoader(\n", 448 | " 
datasets.SVHN('data', split=\"test\", download=True,\n", 449 | " transform=transform_test),\n", 450 | " batch_size=1, shuffle=True))" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 10, 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [ 459 | "isun = list(torch.utils.data.DataLoader(\n", 460 | " datasets.ImageFolder(\"iSUN/\",transform=transform_test),batch_size=1,shuffle=False))" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 11, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "lsun_c = list(torch.utils.data.DataLoader(\n", 470 | " datasets.ImageFolder(\"LSUN/\",transform=transform_test),batch_size=1,shuffle=True))" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 12, 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "lsun_r = list(torch.utils.data.DataLoader(\n", 480 | " datasets.ImageFolder(\"LSUN_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 13, 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "tinyimagenet_c = list(torch.utils.data.DataLoader(\n", 490 | " datasets.ImageFolder(\"Imagenet/\",transform=transform_test),batch_size=1,shuffle=True))" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 14, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "tinyimagenet_r = list(torch.utils.data.DataLoader(\n", 500 | " datasets.ImageFolder(\"Imagenet_resize/\",transform=transform_test),batch_size=1,shuffle=True))" 501 | ] 502 | }, 503 | { 504 | "cell_type": "markdown", 505 | "metadata": {}, 506 | "source": [ 507 | "## Code for Detecting OODs" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | " Extract predictions for train and test data " 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 15, 520 | 
"metadata": {}, 521 | "outputs": [ 522 | { 523 | "name": "stdout", 524 | "output_type": "stream", 525 | "text": [ 526 | "Done\n", 527 | "Done\n" 528 | ] 529 | } 530 | ], 531 | "source": [ 532 | "train_preds = []\n", 533 | "train_confs = []\n", 534 | "train_logits = []\n", 535 | "for idx in range(0,len(data_train),128):\n", 536 | " batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()\n", 537 | " \n", 538 | " logits = torch_model(batch)\n", 539 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 540 | " preds = np.argmax(confs,axis=1)\n", 541 | " logits = (logits.cpu().detach().numpy())\n", 542 | "\n", 543 | " train_confs.extend(np.max(confs,axis=1)) \n", 544 | " train_preds.extend(preds)\n", 545 | " train_logits.extend(logits)\n", 546 | "print(\"Done\")\n", 547 | "\n", 548 | "test_preds = []\n", 549 | "test_confs = []\n", 550 | "test_logits = []\n", 551 | "\n", 552 | "for idx in range(0,len(data),128):\n", 553 | " batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()\n", 554 | " \n", 555 | " logits = torch_model(batch)\n", 556 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 557 | " preds = np.argmax(confs,axis=1)\n", 558 | " logits = (logits.cpu().detach().numpy())\n", 559 | "\n", 560 | " test_confs.extend(np.max(confs,axis=1)) \n", 561 | " test_preds.extend(preds)\n", 562 | " test_logits.extend(logits)\n", 563 | "print(\"Done\")" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | " Code for detecting OODs by identifying anomalies in correlations " 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 16, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "import calculate_log as callog\n", 580 | "def detect(all_test_deviations,all_ood_deviations, verbose=True, normalize=True):\n", 581 | " average_results = {}\n", 582 | " for i in range(1,11):\n", 583 | " random.seed(i)\n", 584 | " \n", 
585 | " validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))\n", 586 | " test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))\n", 587 | "\n", 588 | " validation = all_test_deviations[validation_indices]\n", 589 | " test_deviations = all_test_deviations[test_indices]\n", 590 | "\n", 591 | " t95 = validation.mean(axis=0)+10**-7\n", 592 | " if not normalize:\n", 593 | " t95 = np.ones_like(t95)\n", 594 | " test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 595 | " ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)\n", 596 | " \n", 597 | " results = callog.compute_metric(-test_deviations,-ood_deviations)\n", 598 | " for m in results:\n", 599 | " average_results[m] = average_results.get(m,0)+results[m]\n", 600 | " \n", 601 | " for m in average_results:\n", 602 | " average_results[m] /= i\n", 603 | " if verbose:\n", 604 | " callog.print_results(average_results)\n", 605 | " return average_results\n", 606 | "\n", 607 | "\n", 608 | "def cpu(ob):\n", 609 | " for i in range(len(ob)):\n", 610 | " for j in range(len(ob[i])):\n", 611 | " ob[i][j] = ob[i][j].cpu()\n", 612 | " return ob\n", 613 | " \n", 614 | "def cuda(ob):\n", 615 | " for i in range(len(ob)):\n", 616 | " for j in range(len(ob[i])):\n", 617 | " ob[i][j] = ob[i][j].cuda()\n", 618 | " return ob\n", 619 | "\n", 620 | "class Detector:\n", 621 | " def __init__(self):\n", 622 | " self.all_test_deviations = None\n", 623 | " self.mins = {}\n", 624 | " self.maxs = {}\n", 625 | " self.classes = range(10)\n", 626 | " \n", 627 | " def compute_minmaxs(self,data_train,POWERS=[10]):\n", 628 | " for PRED in tqdm(self.classes):\n", 629 | " train_indices = np.where(np.array(train_preds)==PRED)[0]\n", 630 | " train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)\n", 631 | " mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)\n", 632 | " self.mins[PRED] = cpu(mins)\n", 
633 | " self.maxs[PRED] = cpu(maxs)\n", 634 | " torch.cuda.empty_cache()\n", 635 | " \n", 636 | " def compute_test_deviations(self,POWERS=[10]):\n", 637 | " all_test_deviations = None\n", 638 | " for PRED in tqdm(self.classes):\n", 639 | " test_indices = np.where(np.array(test_preds)==PRED)[0]\n", 640 | " test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)\n", 641 | " test_confs_PRED = np.array([test_confs[i] for i in test_indices])\n", 642 | " mins = cuda(self.mins[PRED])\n", 643 | " maxs = cuda(self.maxs[PRED])\n", 644 | " test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)/test_confs_PRED[:,np.newaxis]\n", 645 | " cpu(mins)\n", 646 | " cpu(maxs)\n", 647 | " if all_test_deviations is None:\n", 648 | " all_test_deviations = test_deviations\n", 649 | " else:\n", 650 | " all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)\n", 651 | " torch.cuda.empty_cache()\n", 652 | " self.all_test_deviations = all_test_deviations\n", 653 | " \n", 654 | " def compute_ood_deviations(self,ood,POWERS=[10]):\n", 655 | " ood_preds = []\n", 656 | " ood_confs = []\n", 657 | " \n", 658 | " for idx in range(0,len(ood),128):\n", 659 | " batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()\n", 660 | " logits = torch_model(batch)\n", 661 | " confs = F.softmax(logits,dim=1).cpu().detach().numpy()\n", 662 | " preds = np.argmax(confs,axis=1)\n", 663 | " \n", 664 | " ood_confs.extend(np.max(confs,axis=1))\n", 665 | " ood_preds.extend(preds) \n", 666 | " torch.cuda.empty_cache()\n", 667 | " print(\"Done\")\n", 668 | " \n", 669 | " all_ood_deviations = None\n", 670 | " for PRED in tqdm(self.classes):\n", 671 | " ood_indices = np.where(np.array(ood_preds)==PRED)[0]\n", 672 | " if len(ood_indices)==0:\n", 673 | " continue\n", 674 | " ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)\n", 675 | " ood_confs_PRED = np.array([ood_confs[i] for i in 
ood_indices])\n", 676 | " mins = cuda(self.mins[PRED])\n", 677 | " maxs = cuda(self.maxs[PRED])\n", 678 | " ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)/ood_confs_PRED[:,np.newaxis]\n", 679 | " cpu(self.mins[PRED])\n", 680 | " cpu(self.maxs[PRED]) \n", 681 | " if all_ood_deviations is None:\n", 682 | " all_ood_deviations = ood_deviations\n", 683 | " else:\n", 684 | " all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)\n", 685 | " torch.cuda.empty_cache()\n", 686 | " average_results = detect(self.all_test_deviations,all_ood_deviations)\n", 687 | " return average_results, self.all_test_deviations, all_ood_deviations\n" 688 | ] 689 | }, 690 | { 691 | "cell_type": "markdown", 692 | "metadata": {}, 693 | "source": [ 694 | "

Results

" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 17, 700 | "metadata": { 701 | "scrolled": false 702 | }, 703 | "outputs": [ 704 | { 705 | "data": { 706 | "application/vnd.jupyter.widget-view+json": { 707 | "model_id": "2b50f073b57840a0bd22fb057602fc78", 708 | "version_major": 2, 709 | "version_minor": 0 710 | }, 711 | "text/plain": [ 712 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 713 | ] 714 | }, 715 | "metadata": {}, 716 | "output_type": "display_data" 717 | }, 718 | { 719 | "name": "stdout", 720 | "output_type": "stream", 721 | "text": [ 722 | "\n" 723 | ] 724 | }, 725 | { 726 | "data": { 727 | "application/vnd.jupyter.widget-view+json": { 728 | "model_id": "5fc60ad5623a4ef2b163dbdf8562d051", 729 | "version_major": 2, 730 | "version_minor": 0 731 | }, 732 | "text/plain": [ 733 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 734 | ] 735 | }, 736 | "metadata": {}, 737 | "output_type": "display_data" 738 | }, 739 | { 740 | "name": "stdout", 741 | "output_type": "stream", 742 | "text": [ 743 | "\n", 744 | "iSUN\n", 745 | "Done\n" 746 | ] 747 | }, 748 | { 749 | "data": { 750 | "application/vnd.jupyter.widget-view+json": { 751 | "model_id": "59d3443d14a04e2dba5374e894cbdcb3", 752 | "version_major": 2, 753 | "version_minor": 0 754 | }, 755 | "text/plain": [ 756 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 757 | ] 758 | }, 759 | "metadata": {}, 760 | "output_type": "display_data" 761 | }, 762 | { 763 | "name": "stdout", 764 | "output_type": "stream", 765 | "text": [ 766 | "\n", 767 | " TNR AUROC DTACC AUIN AUOUT \n", 768 | " 99.030 99.800 97.935 99.795 99.802\n", 769 | "LSUN (R)\n", 770 | "Done\n" 771 | ] 772 | }, 773 | { 774 | "data": { 775 | "application/vnd.jupyter.widget-view+json": { 776 | "model_id": "b704f16c4a6b47fc89194c4952de7f88", 777 | "version_major": 2, 778 | "version_minor": 0 779 | }, 780 | "text/plain": [ 781 | "HBox(children=(IntProgress(value=0, max=10), 
HTML(value='')))" 782 | ] 783 | }, 784 | "metadata": {}, 785 | "output_type": "display_data" 786 | }, 787 | { 788 | "name": "stdout", 789 | "output_type": "stream", 790 | "text": [ 791 | "\n", 792 | " TNR AUROC DTACC AUIN AUOUT \n", 793 | " 99.476 99.881 98.633 99.858 99.894\n", 794 | "LSUN (C)\n", 795 | "Done\n" 796 | ] 797 | }, 798 | { 799 | "data": { 800 | "application/vnd.jupyter.widget-view+json": { 801 | "model_id": "d1213c56bf9e43c29a250c7b73578898", 802 | "version_major": 2, 803 | "version_minor": 0 804 | }, 805 | "text/plain": [ 806 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 807 | ] 808 | }, 809 | "metadata": {}, 810 | "output_type": "display_data" 811 | }, 812 | { 813 | "name": "stdout", 814 | "output_type": "stream", 815 | "text": [ 816 | "\n", 817 | " TNR AUROC DTACC AUIN AUOUT \n", 818 | " 88.383 97.446 91.987 96.434 97.914\n", 819 | "TinyImgNet (R)\n", 820 | "Done\n" 821 | ] 822 | }, 823 | { 824 | "data": { 825 | "application/vnd.jupyter.widget-view+json": { 826 | "model_id": "ce16d61aec4f4acf945cc17c65f6d79c", 827 | "version_major": 2, 828 | "version_minor": 0 829 | }, 830 | "text/plain": [ 831 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 832 | ] 833 | }, 834 | "metadata": {}, 835 | "output_type": "display_data" 836 | }, 837 | { 838 | "name": "stdout", 839 | "output_type": "stream", 840 | "text": [ 841 | "\n", 842 | " TNR AUROC DTACC AUIN AUOUT \n", 843 | " 98.783 99.714 97.891 99.562 99.769\n", 844 | "TinyImgNet (C)\n", 845 | "Done\n" 846 | ] 847 | }, 848 | { 849 | "data": { 850 | "application/vnd.jupyter.widget-view+json": { 851 | "model_id": "352bca808e51495eb060301d59df88db", 852 | "version_major": 2, 853 | "version_minor": 0 854 | }, 855 | "text/plain": [ 856 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 857 | ] 858 | }, 859 | "metadata": {}, 860 | "output_type": "display_data" 861 | }, 862 | { 863 | "name": "stdout", 864 | "output_type": "stream", 865 | "text": [ 866 | "\n", 867 | 
" TNR AUROC DTACC AUIN AUOUT \n", 868 | " 96.693 99.253 96.137 98.957 99.390\n", 869 | "SVHN\n", 870 | "Done\n" 871 | ] 872 | }, 873 | { 874 | "data": { 875 | "application/vnd.jupyter.widget-view+json": { 876 | "model_id": "06cd21e80d4e4992958854774563e5fa", 877 | "version_major": 2, 878 | "version_minor": 0 879 | }, 880 | "text/plain": [ 881 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 882 | ] 883 | }, 884 | "metadata": {}, 885 | "output_type": "display_data" 886 | }, 887 | { 888 | "name": "stdout", 889 | "output_type": "stream", 890 | "text": [ 891 | "\n", 892 | " TNR AUROC DTACC AUIN AUOUT \n", 893 | " 96.072 99.126 95.863 96.782 99.709\n", 894 | "CIFAR-100\n", 895 | "Done\n" 896 | ] 897 | }, 898 | { 899 | "data": { 900 | "application/vnd.jupyter.widget-view+json": { 901 | "model_id": "368efd02834c401ca641c579f9f3ab94", 902 | "version_major": 2, 903 | "version_minor": 0 904 | }, 905 | "text/plain": [ 906 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 907 | ] 908 | }, 909 | "metadata": {}, 910 | "output_type": "display_data" 911 | }, 912 | { 913 | "name": "stdout", 914 | "output_type": "stream", 915 | "text": [ 916 | "\n", 917 | " TNR AUROC DTACC AUIN AUOUT \n", 918 | " 26.683 72.043 67.226 61.419 75.722\n" 919 | ] 920 | } 921 | ], 922 | "source": [ 923 | "def G_p(ob, p):\n", 924 | " temp = ob.detach()\n", 925 | " \n", 926 | " temp = temp**p\n", 927 | " temp = temp.reshape(temp.shape[0],temp.shape[1],-1)\n", 928 | " temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) \n", 929 | " temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)\n", 930 | " \n", 931 | " return temp\n", 932 | "\n", 933 | "detector = Detector()\n", 934 | "detector.compute_minmaxs(data_train,POWERS=range(1,11))\n", 935 | "\n", 936 | "detector.compute_test_deviations(POWERS=range(1,11))\n", 937 | "\n", 938 | "print(\"iSUN\")\n", 939 | "isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))\n", 940 | 
"print(\"LSUN (R)\")\n", 941 | "lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))\n", 942 | "print(\"LSUN (C)\")\n", 943 | "lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))\n", 944 | "print(\"TinyImgNet (R)\")\n", 945 | "timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))\n", 946 | "print(\"TinyImgNet (C)\")\n", 947 | "timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))\n", 948 | "print(\"SVHN\")\n", 949 | "svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))\n", 950 | "print(\"CIFAR-100\")\n", 951 | "c100_results = detector.compute_ood_deviations(cifar100,POWERS=range(1,11))" 952 | ] 953 | } 954 | ], 955 | "metadata": { 956 | "kernelspec": { 957 | "display_name": "Python 2", 958 | "language": "python", 959 | "name": "python2" 960 | }, 961 | "language_info": { 962 | "codemirror_mode": { 963 | "name": "ipython", 964 | "version": 3 965 | }, 966 | "file_extension": ".py", 967 | "mimetype": "text/x-python", 968 | "name": "python", 969 | "nbconvert_exporter": "python", 970 | "pygments_lexer": "ipython3", 971 | "version": "3.6.9" 972 | } 973 | }, 974 | "nbformat": 4, 975 | "nbformat_minor": 2 976 | } 977 | --------------------------------------------------------------------------------