├── README.md
├── code
│   ├── cal.py
│   ├── calData.py
│   ├── calMetric.py
│   ├── densenet.py
│   ├── main.py
│   ├── softmax_scores
│   │   ├── .gitignore
│   │   ├── confidence_Base_In.txt
│   │   ├── confidence_Base_Out.txt
│   │   ├── confidence_Our_In.txt
│   │   └── confidence_Our_Out.txt
│   └── wideresnet.py
└── figures
    ├── original_optimal_shade.png
    └── performance.png

/README.md:
--------------------------------------------------------------------------------

# ODIN: Out-of-Distribution Detector for Neural Networks


This is a [PyTorch](http://pytorch.org) implementation for detecting out-of-distribution examples in neural networks. The method is described in the paper [Principled Detection of Out-of-Distribution Examples in Neural Networks](https://arxiv.org/abs/1706.02690) by S. Liang, [Yixuan Li](http://www.cs.cornell.edu/~yli) and [R. Srikant](https://sites.google.com/a/illinois.edu/srikant/). On a DenseNet trained on CIFAR-10, the method reduces the false positive rate from the baseline's 34.7% to 4.3% at a true positive rate of 95%.

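In brief, the detector scales the network's logits by a large temperature, perturbs the input by a small step against the gradient of the resulting cross-entropy loss (which pushes the predicted class's softmax score up), and then thresholds the maximum softmax probability of the perturbed input. The snippet below is a minimal sketch of that scoring rule in current PyTorch; the repository's actual implementation lives in `code/calData.py` (using the older `Variable` API and an extra per-channel rescaling of the gradient), and `odin_score`, `model`, and `x` are illustrative names rather than code from this repo.

```
import torch
import torch.nn.functional as F

def odin_score(model, x, temperature=1000.0, epsilon=0.0014):
    # x: a single normalized image of shape (1, 3, 32, 32); model returns logits.
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    logits = model(x) / temperature
    # Perturb the input against the cross-entropy loss of the predicted class,
    # which nudges it towards higher softmax confidence.
    loss = F.cross_entropy(logits, logits.argmax(dim=1))
    loss.backward()
    # calData.py additionally rescales the gradient by the per-channel
    # normalization std before this step; that detail is omitted here.
    x_perturbed = x.detach() - epsilon * x.grad.sign()
    with torch.no_grad():
        score = F.softmax(model(x_perturbed) / temperature, dim=1).max().item()
    return score  # in-distribution inputs tend to receive larger scores
```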

## Experimental Results

We used two neural network architectures, [DenseNet-BC](https://arxiv.org/abs/1608.06993) and [Wide ResNet](https://arxiv.org/abs/1605.07146).
The PyTorch implementation of [DenseNet-BC](https://arxiv.org/abs/1608.06993) is provided by [Andreas Veit](https://github.com/andreasveit/densenet-pytorch) and [Brandon Amos](https://github.com/bamos/densenet.pytorch). The PyTorch implementation of [Wide ResNet](https://arxiv.org/abs/1605.07146) is provided by [Sergey Zagoruyko](https://github.com/szagoruyko/wide-residual-networks).
The experimental results are shown below. The definition of each metric can be found in the [paper](https://arxiv.org/abs/1706.02690).
![performance](./figures/performance.png)



## Pre-trained Models

We provide four pre-trained neural networks: (1) two [DenseNet-BC](https://arxiv.org/abs/1608.06993) networks trained on CIFAR-10 and CIFAR-100 respectively; (2) two [Wide ResNet](https://arxiv.org/abs/1605.07146) networks trained on CIFAR-10 and CIFAR-100 respectively. The test error rates are given by:

Architecture | CIFAR-10 | CIFAR-100
------------ | --------- | ---------
DenseNet-BC  | 4.81      | 22.37
Wide ResNet  | 3.71      | 19.86


## Running the code

### Dependencies

* CUDA 8.0
* PyTorch
* Anaconda2 or 3
* At least **three** GPUs

**Note:** Reproducing the DenseNet-BC results only requires **one** GPU, but reproducing the Wide ResNet results requires **three** GPUs. A single-GPU version for Wide ResNet will be released soon.

### Downloading Out-of-Distribution Datasets
We provide download links for five out-of-distribution datasets:

* **[Tiny-ImageNet (crop)](https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz)**
* **[Tiny-ImageNet (resize)](https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz)**
* **[LSUN (crop)](https://www.dropbox.com/s/fhtsw1m3qxlwj6h/LSUN.tar.gz)**
* **[LSUN (resize)](https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz)**
* **[iSUN](https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz)**

Here is an example of downloading the Tiny-ImageNet (crop) dataset. In the **root** directory, run

```
mkdir data
cd data
wget https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz
tar -xvzf Imagenet.tar.gz
cd ..
```

### Downloading Neural Network Models

We provide download links for four pre-trained models:

* **[DenseNet-BC trained on CIFAR-10](https://www.dropbox.com/s/wr4kjintq1tmorr/densenet10.pth.tar.gz)**
* **[DenseNet-BC trained on CIFAR-100](https://www.dropbox.com/s/vxuv11jjg8bw2v9/densenet100.pth.tar.gz)**
* **[Wide ResNet trained on CIFAR-10](https://www.dropbox.com/s/uiye5nw0uj6ie53/wideresnet10.pth.tar.gz)**
* **[Wide ResNet trained on CIFAR-100](https://www.dropbox.com/s/uiye5nw0uj6ie53/wideresnet100.pth.tar.gz)**

Here is an example of downloading the DenseNet-BC model trained on CIFAR-10. In the **root** directory, run

```
mkdir models
cd models
wget https://www.dropbox.com/s/wr4kjintq1tmorr/densenet10.pth.tar.gz
tar -xvzf densenet10.pth.tar.gz
cd ..
```
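To quickly verify a download, you can try loading the checkpoint before running the full experiments. The snippet below is a small optional check, assuming the archive extracts to `models/densenet10.pth` (the path that `code/cal.py` loads); since the checkpoint is a fully pickled model rather than a plain state dict, run it from the `code` directory so that `densenet.py` and `wideresnet.py` are importable.

```
# Optional check: load the extracted checkpoint and count its parameters.
# Run from the `code` directory; the pickled model needs densenet.py /
# wideresnet.py on the import path to deserialize.
import torch

net = torch.load("../models/densenet10.pth",
                 map_location=lambda storage, loc: storage)   # keep on CPU
print(type(net).__name__)                                     # e.g. DenseNet3
print(sum(p.numel() for p in net.parameters()), "parameters")
```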
### Running

Here is an example that reproduces the results of DenseNet-BC trained on CIFAR-10 with Tiny-ImageNet (crop) as the out-of-distribution dataset. The temperature is set to 1000 and the perturbation magnitude to 0.0014. In the **root** directory, run

```
cd code
# model: DenseNet-BC, in-distribution: CIFAR-10, out-of-distribution: Tiny-ImageNet (crop)
# magnitude: 0.0014, temperature: 1000, gpu: 0
python main.py --nn densenet10 --out_dataset Imagenet --magnitude 0.0014 --temperature 1000 --gpu 0
```
**Note:** Please choose the arguments according to the tables below.

#### args
* **args.nn**: the supported neural network models are listed below

Neural Network Models | args.nn
----------------------|--------
DenseNet-BC trained on CIFAR-10 | densenet10
DenseNet-BC trained on CIFAR-100 | densenet100
Wide ResNet trained on CIFAR-10 | wideresnet10
Wide ResNet trained on CIFAR-100 | wideresnet100

* **args.out_dataset**: the supported out-of-distribution datasets are listed below

Out-of-Distribution Datasets | args.out_dataset
------------------------------------|-----------------
Tiny-ImageNet (crop) | Imagenet
Tiny-ImageNet (resize) | Imagenet_resize
LSUN (crop) | LSUN
LSUN (resize) | LSUN_resize
iSUN | iSUN
Uniform random noise | Uniform
Gaussian random noise | Gaussian

* **args.magnitude**: the optimal noise magnitudes can be found below. In practice, the optimal noise magnitude is model-specific and needs to be tuned accordingly (a short tuning sketch follows these tables).

Out-of-Distribution Datasets | densenet10 | densenet100 | wideresnet10 | wideresnet100
------------------------------------|------------|-------------|--------------|--------------
Tiny-ImageNet (crop) | 0.0014 | 0.0014 | 0.0005 | 0.0028
Tiny-ImageNet (resize) | 0.0014 | 0.0028 | 0.0011 | 0.0028
LSUN (crop) | 0 | 0.0028 | 0 | 0.0048
LSUN (resize) | 0.0014 | 0.0028 | 0.0006 | 0.002
iSUN | 0.0014 | 0.0028 | 0.0008 | 0.0028
Uniform random noise | 0.0014 | 0.0028 | 0.0014 | 0.0028
Gaussian random noise | 0.0014 | 0.0028 | 0.0014 | 0.0028

* **args.temperature**: the temperature is set to 1000 in all cases.
* **args.gpu**: make sure to use the following GPU index when running the code:

Neural Network Models | args.gpu
----------------------|----------
densenet10 | 0
densenet100 | 0
wideresnet10 | 1
wideresnet100 | 2
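If you apply ODIN to a new model or dataset, the magnitudes listed under **args.magnitude** will not necessarily transfer. One simple recipe, sketched below under the assumption that you have already collected ODIN softmax scores for held-out in- and out-of-distribution images at each candidate magnitude (the provided scripts skip the first 1,000 test images, which leaves such a split available), is to keep the magnitude with the lowest FPR at 95% TPR:

```
# Hypothetical helper for tuning the perturbation magnitude: given ODIN scores
# on held-out in- and out-of-distribution images for each candidate magnitude,
# keep the magnitude whose FPR at 95% TPR is lowest. The score arrays are
# placeholders, not outputs of this repository.
import numpy as np

def fpr_at_95_tpr(in_scores, out_scores):
    # Threshold chosen so that 95% of in-distribution images are accepted.
    in_scores, out_scores = np.asarray(in_scores), np.asarray(out_scores)
    threshold = np.percentile(in_scores, 5)
    return float(np.mean(out_scores >= threshold))

def pick_magnitude(scores_by_eps):
    # scores_by_eps: {epsilon: (in_scores, out_scores)} on a validation split.
    return min(scores_by_eps, key=lambda eps: fpr_at_95_tpr(*scores_by_eps[eps]))
```

This mirrors the FPR-at-95%-TPR number reported by `calMetric.py`, but uses a single percentile threshold rather than the threshold sweep used in the script.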
### Outputs
Here is an example of the output.

```
Neural network architecture:        DenseNet-BC-100
In-distribution dataset:                   CIFAR-10
Out-of-distribution dataset:   Tiny-ImageNet (crop)

                         Baseline         Our Method
FPR at TPR 95%:             34.8%               4.3%
Detection error:             9.9%               4.6%
AUROC:                      95.3%              99.1%
AUPR In:                    96.4%              99.2%
AUPR Out:                   93.8%              99.1%
```
--------------------------------------------------------------------------------
/code/cal.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 19 20:55:56 2015

@author: liangshiyu
"""

from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
from scipy import misc
import calMetric as m
import calData as d
#CUDA_DEVICE = 0

start = time.time()
#loading data sets

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0)),
])




# loading neural network

# Name of neural networks
# Densenet trained on CIFAR-10: densenet10
# Densenet trained on CIFAR-100: densenet100
# Wide ResNet trained on CIFAR-10: wideresnet10
# Wide ResNet trained on CIFAR-100: wideresnet100
#nnName = "densenet10"

#imName = "Imagenet"



criterion = nn.CrossEntropyLoss()



def test(nnName, dataName, CUDA_DEVICE, epsilon, temperature):
    # Load the pre-trained model, build the in- and out-of-distribution test
    # loaders, write the softmax scores to ./softmax_scores/ and print the
    # detection metrics.
    net1 = torch.load("../models/{}.pth".format(nnName))
    optimizer1 = optim.SGD(net1.parameters(), lr = 0, momentum = 0)
    net1.cuda(CUDA_DEVICE)

    if dataName != "Uniform" and dataName != "Gaussian":
        testsetout = torchvision.datasets.ImageFolder("../data/{}".format(dataName), transform=transform)
        testloaderOut = torch.utils.data.DataLoader(testsetout, batch_size=1,
                                                    shuffle=False, num_workers=2)

    if nnName == "densenet10" or nnName == "wideresnet10":
        testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=1,
                                                   shuffle=False, num_workers=2)
    if nnName == "densenet100" or nnName == "wideresnet100":
        testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform)
        testloaderIn = torch.utils.data.DataLoader(testset, batch_size=1,
                                                   shuffle=False, num_workers=2)

    if dataName == "Gaussian":
        d.testGaussian(net1, criterion, CUDA_DEVICE, testloaderIn, testloaderIn, nnName, dataName, epsilon, temperature)
        m.metric(nnName, dataName)

    elif dataName == "Uniform":
        d.testUni(net1, criterion, CUDA_DEVICE, testloaderIn, testloaderIn, nnName, dataName, epsilon, temperature)
        m.metric(nnName, dataName)
    else:
        d.testData(net1, criterion, CUDA_DEVICE, testloaderIn, testloaderOut, nnName, dataName, epsilon, temperature)
        m.metric(nnName, dataName)



--------------------------------------------------------------------------------
/code/calData.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Sep 19 20:55:56 2015 4 | 5 | @author: liangshiyu 6 | """ 7 | 8 | from __future__ import print_function 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import torch.optim as optim 15 | import torchvision 16 | import torchvision.transforms as transforms 17 | import numpy as np 18 | import time 19 | from scipy import misc 20 | 21 | def testData(net1, criterion, CUDA_DEVICE, testloader10, testloader, nnName, dataName, noiseMagnitude1, temper): 22 | t0 = time.time() 23 | f1 = open("./softmax_scores/confidence_Base_In.txt", 'w') 24 | f2 = open("./softmax_scores/confidence_Base_Out.txt", 'w') 25 | g1 = open("./softmax_scores/confidence_Our_In.txt", 'w') 26 | g2 = open("./softmax_scores/confidence_Our_Out.txt", 'w') 27 | N = 10000 28 | if dataName == "iSUN": N = 8925 29 | print("Processing in-distribution images") 30 | ########################################In-distribution########################################### 31 | for j, data in enumerate(testloader10): 32 | if j<1000: continue 33 | images, _ = data 34 | 35 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True) 36 | outputs = net1(inputs) 37 | 38 | 39 | # Calculating the confidence of the output, no perturbation added here, no temperature scaling used 40 | nnOutputs = outputs.data.cpu() 41 | nnOutputs = nnOutputs.numpy() 42 | nnOutputs = nnOutputs[0] 43 | nnOutputs = nnOutputs - np.max(nnOutputs) 44 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 45 | f1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 46 | 47 | # Using temperature scaling 48 | outputs = outputs / temper 49 | 50 | # Calculating the perturbation we need to add, that is, 51 | # the sign of gradient of cross entropy loss w.r.t. 
input 52 | maxIndexTemp = np.argmax(nnOutputs) 53 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE)) 54 | loss = criterion(outputs, labels) 55 | loss.backward() 56 | 57 | # Normalizing the gradient to binary in {0, 1} 58 | gradient = torch.ge(inputs.grad.data, 0) 59 | gradient = (gradient.float() - 0.5) * 2 60 | # Normalizing the gradient to the same space of image 61 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0) 62 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0) 63 | gradient[0][2] = (gradient[0][2])/(66.7/255.0) 64 | # Adding small perturbations to images 65 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient) 66 | outputs = net1(Variable(tempInputs)) 67 | outputs = outputs / temper 68 | # Calculating the confidence after adding perturbations 69 | nnOutputs = outputs.data.cpu() 70 | nnOutputs = nnOutputs.numpy() 71 | nnOutputs = nnOutputs[0] 72 | nnOutputs = nnOutputs - np.max(nnOutputs) 73 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 74 | g1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 75 | if j % 100 == 99: 76 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0)) 77 | t0 = time.time() 78 | 79 | if j == N - 1: break 80 | 81 | 82 | t0 = time.time() 83 | print("Processing out-of-distribution images") 84 | ###################################Out-of-Distributions##################################### 85 | for j, data in enumerate(testloader): 86 | if j<1000: continue 87 | images, _ = data 88 | 89 | 90 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True) 91 | outputs = net1(inputs) 92 | 93 | 94 | 95 | # Calculating the confidence of the output, no perturbation added here 96 | nnOutputs = outputs.data.cpu() 97 | nnOutputs = nnOutputs.numpy() 98 | nnOutputs = nnOutputs[0] 99 | nnOutputs = nnOutputs - np.max(nnOutputs) 100 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 101 | f2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 102 | 103 | # Using temperature scaling 104 | outputs = outputs / temper 105 | 106 | 107 | # Calculating the perturbation we need to add, that is, 108 | # the sign of gradient of cross entropy loss w.r.t. 
input 109 | maxIndexTemp = np.argmax(nnOutputs) 110 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE)) 111 | loss = criterion(outputs, labels) 112 | loss.backward() 113 | 114 | # Normalizing the gradient to binary in {0, 1} 115 | gradient = (torch.ge(inputs.grad.data, 0)) 116 | gradient = (gradient.float() - 0.5) * 2 117 | # Normalizing the gradient to the same space of image 118 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0) 119 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0) 120 | gradient[0][2] = (gradient[0][2])/(66.7/255.0) 121 | # Adding small perturbations to images 122 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient) 123 | outputs = net1(Variable(tempInputs)) 124 | outputs = outputs / temper 125 | # Calculating the confidence after adding perturbations 126 | nnOutputs = outputs.data.cpu() 127 | nnOutputs = nnOutputs.numpy() 128 | nnOutputs = nnOutputs[0] 129 | nnOutputs = nnOutputs - np.max(nnOutputs) 130 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 131 | g2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 132 | if j % 100 == 99: 133 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0)) 134 | t0 = time.time() 135 | 136 | if j== N-1: break 137 | 138 | 139 | 140 | 141 | def testGaussian(net1, criterion, CUDA_DEVICE, testloader10, testloader, nnName, dataName, noiseMagnitude1, temper): 142 | t0 = time.time() 143 | f1 = open("./softmax_scores/confidence_Base_In.txt", 'w') 144 | f2 = open("./softmax_scores/confidence_Base_Out.txt", 'w') 145 | g1 = open("./softmax_scores/confidence_Our_In.txt", 'w') 146 | g2 = open("./softmax_scores/confidence_Our_Out.txt", 'w') 147 | ########################################In-Distribution############################################### 148 | N = 10000 149 | print("Processing in-distribution images") 150 | for j, data in enumerate(testloader10): 151 | 152 | if j<1000: continue 153 | images, _ = data 154 | 155 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True) 156 | outputs = net1(inputs) 157 | 158 | 159 | # Calculating the confidence of the output, no perturbation added here 160 | nnOutputs = outputs.data.cpu() 161 | nnOutputs = nnOutputs.numpy() 162 | nnOutputs = nnOutputs[0] 163 | nnOutputs = nnOutputs - np.max(nnOutputs) 164 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 165 | f1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 166 | 167 | # Using temperature scaling 168 | outputs = outputs / temper 169 | 170 | # Calculating the perturbation we need to add, that is, 171 | # the sign of gradient of cross entropy loss w.r.t. 
input 172 | maxIndexTemp = np.argmax(nnOutputs) 173 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE)) 174 | loss = criterion(outputs, labels) 175 | loss.backward() 176 | 177 | 178 | # Normalizing the gradient to binary in {0, 1} 179 | gradient = (torch.ge(inputs.grad.data, 0)) 180 | gradient = (gradient.float() - 0.5) * 2 181 | # Normalizing the gradient to the same space of image 182 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0) 183 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0) 184 | gradient[0][2] = (gradient[0][2])/(66.7/255.0) 185 | # Adding small perturbations to images 186 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient) 187 | outputs = net1(Variable(tempInputs)) 188 | outputs = outputs / temper 189 | # Calculating the confidence after adding perturbations 190 | nnOutputs = outputs.data.cpu() 191 | nnOutputs = nnOutputs.numpy() 192 | nnOutputs = nnOutputs[0] 193 | nnOutputs = nnOutputs - np.max(nnOutputs) 194 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 195 | 196 | g1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 197 | if j % 100 == 99: 198 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0)) 199 | t0 = time.time() 200 | 201 | 202 | 203 | ########################################Out-of-Distribution###################################### 204 | print("Processing out-of-distribution images") 205 | for j, data in enumerate(testloader): 206 | if j<1000: continue 207 | 208 | images = torch.randn(1,3,32,32) + 0.5 209 | images = torch.clamp(images, 0, 1) 210 | images[0][0] = (images[0][0] - 125.3/255) / (63.0/255) 211 | images[0][1] = (images[0][1] - 123.0/255) / (62.1/255) 212 | images[0][2] = (images[0][2] - 113.9/255) / (66.7/255) 213 | 214 | 215 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True) 216 | outputs = net1(inputs) 217 | 218 | 219 | 220 | # Calculating the confidence of the output, no perturbation added here 221 | nnOutputs = outputs.data.cpu() 222 | nnOutputs = nnOutputs.numpy() 223 | nnOutputs = nnOutputs[0] 224 | nnOutputs = nnOutputs - np.max(nnOutputs) 225 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 226 | f2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 227 | 228 | # Using temperature scaling 229 | outputs = outputs / temper 230 | 231 | # Calculating the perturbation we need to add, that is, 232 | # the sign of gradient of cross entropy loss w.r.t. 
input 233 | maxIndexTemp = np.argmax(nnOutputs) 234 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE)) 235 | loss = criterion(outputs, labels) 236 | loss.backward() 237 | 238 | # Normalizing the gradient to binary in {0, 1} 239 | gradient = (torch.ge(inputs.grad.data, 0)) 240 | gradient = (gradient.float() - 0.5) * 2 241 | # Normalizing the gradient to the same space of image 242 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0) 243 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0) 244 | gradient[0][2] = (gradient[0][2])/(66.7/255.0) 245 | # Adding small perturbations to images 246 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient) 247 | outputs = net1(Variable(tempInputs)) 248 | outputs = outputs / temper 249 | # Calculating the confidence after adding perturbations 250 | nnOutputs = outputs.data.cpu() 251 | nnOutputs = nnOutputs.numpy() 252 | nnOutputs = nnOutputs[0] 253 | nnOutputs = nnOutputs - np.max(nnOutputs) 254 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 255 | g2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 256 | 257 | if j % 100 == 99: 258 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0)) 259 | t0 = time.time() 260 | 261 | if j== N-1: break 262 | 263 | 264 | 265 | 266 | def testUni(net1, criterion, CUDA_DEVICE, testloader10, testloader, nnName, dataName, noiseMagnitude1, temper): 267 | t0 = time.time() 268 | f1 = open("./softmax_scores/confidence_Base_In.txt", 'w') 269 | f2 = open("./softmax_scores/confidence_Base_Out.txt", 'w') 270 | g1 = open("./softmax_scores/confidence_Our_In.txt", 'w') 271 | g2 = open("./softmax_scores/confidence_Our_Out.txt", 'w') 272 | ########################################In-Distribution############################################### 273 | N = 10000 274 | print("Processing in-distribution images") 275 | for j, data in enumerate(testloader10): 276 | if j<1000: continue 277 | 278 | images, _ = data 279 | 280 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True) 281 | outputs = net1(inputs) 282 | 283 | 284 | # Calculating the confidence of the output, no perturbation added here 285 | nnOutputs = outputs.data.cpu() 286 | nnOutputs = nnOutputs.numpy() 287 | nnOutputs = nnOutputs[0] 288 | nnOutputs = nnOutputs - np.max(nnOutputs) 289 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 290 | f1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 291 | 292 | # Using temperature scaling 293 | outputs = outputs / temper 294 | 295 | # Calculating the perturbation we need to add, that is, 296 | # the sign of gradient of cross entropy loss w.r.t. 
input 297 | maxIndexTemp = np.argmax(nnOutputs) 298 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE)) 299 | loss = criterion(outputs, labels) 300 | loss.backward() 301 | 302 | 303 | # Normalizing the gradient to binary in {0, 1} 304 | gradient = (torch.ge(inputs.grad.data, 0)) 305 | gradient = (gradient.float() - 0.5) * 2 306 | # Normalizing the gradient to the same space of image 307 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0) 308 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0) 309 | gradient[0][2] = (gradient[0][2])/(66.7/255.0) 310 | # Adding small perturbations to images 311 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient) 312 | outputs = net1(Variable(tempInputs)) 313 | outputs = outputs / temper 314 | # Calculating the confidence after adding perturbations 315 | nnOutputs = outputs.data.cpu() 316 | nnOutputs = nnOutputs.numpy() 317 | nnOutputs = nnOutputs[0] 318 | nnOutputs = nnOutputs - np.max(nnOutputs) 319 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 320 | 321 | g1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 322 | if j % 100 == 99: 323 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0)) 324 | t0 = time.time() 325 | 326 | 327 | 328 | ########################################Out-of-Distribution###################################### 329 | print("Processing out-of-distribution images") 330 | for j, data in enumerate(testloader): 331 | if j<1000: continue 332 | 333 | images = torch.rand(1,3,32,32) 334 | images[0][0] = (images[0][0] - 125.3/255) / (63.0/255) 335 | images[0][1] = (images[0][1] - 123.0/255) / (62.1/255) 336 | images[0][2] = (images[0][2] - 113.9/255) / (66.7/255) 337 | 338 | 339 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True) 340 | outputs = net1(inputs) 341 | 342 | 343 | 344 | # Calculating the confidence of the output, no perturbation added here 345 | nnOutputs = outputs.data.cpu() 346 | nnOutputs = nnOutputs.numpy() 347 | nnOutputs = nnOutputs[0] 348 | nnOutputs = nnOutputs - np.max(nnOutputs) 349 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 350 | f2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 351 | 352 | # Using temperature scaling 353 | outputs = outputs / temper 354 | 355 | # Calculating the perturbation we need to add, that is, 356 | # the sign of gradient of cross entropy loss w.r.t. 
input 357 | maxIndexTemp = np.argmax(nnOutputs) 358 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE)) 359 | loss = criterion(outputs, labels) 360 | loss.backward() 361 | 362 | # Normalizing the gradient to binary in {0, 1} 363 | gradient = (torch.ge(inputs.grad.data, 0)) 364 | gradient = (gradient.float() - 0.5) * 2 365 | # Normalizing the gradient to the same space of image 366 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0) 367 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0) 368 | gradient[0][2] = (gradient[0][2])/(66.7/255.0) 369 | # Adding small perturbations to images 370 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient) 371 | outputs = net1(Variable(tempInputs)) 372 | outputs = outputs / temper 373 | # Calculating the confidence after adding perturbations 374 | nnOutputs = outputs.data.cpu() 375 | nnOutputs = nnOutputs.numpy() 376 | nnOutputs = nnOutputs[0] 377 | nnOutputs = nnOutputs - np.max(nnOutputs) 378 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs)) 379 | g2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs))) 380 | if j % 100 == 99: 381 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0)) 382 | t0 = time.time() 383 | 384 | if j== N-1: break 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | -------------------------------------------------------------------------------- /code/calMetric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Sep 19 20:55:56 2015 4 | 5 | @author: liangshiyu 6 | """ 7 | 8 | from __future__ import print_function 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import torch.optim as optim 15 | import torchvision 16 | import torchvision.transforms as transforms 17 | #import matplotlib.pyplot as plt 18 | import numpy as np 19 | import time 20 | from scipy import misc 21 | 22 | 23 | def tpr95(name): 24 | #calculate the falsepositive error when tpr is 95% 25 | # calculate baseline 26 | T = 1 27 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',') 28 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',') 29 | if name == "CIFAR-10": 30 | start = 0.1 31 | end = 1 32 | if name == "CIFAR-100": 33 | start = 0.01 34 | end = 1 35 | gap = (end- start)/100000 36 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 37 | Y1 = other[:, 2] 38 | X1 = cifar[:, 2] 39 | total = 0.0 40 | fpr = 0.0 41 | for delta in np.arange(start, end, gap): 42 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1)) 43 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1)) 44 | if tpr <= 0.9505 and tpr >= 0.9495: 45 | fpr += error2 46 | total += 1 47 | fprBase = fpr/total 48 | 49 | # calculate our algorithm 50 | T = 1000 51 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',') 52 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',') 53 | if name == "CIFAR-10": 54 | start = 0.1 55 | end = 0.12 56 | if name == "CIFAR-100": 57 | start = 0.01 58 | end = 0.0104 59 | gap = (end- start)/100000 60 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 61 | Y1 = other[:, 2] 62 | X1 = cifar[:, 2] 63 | total = 0.0 64 | fpr = 0.0 65 | for delta in np.arange(start, end, gap): 66 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1)) 67 | error2 = 
np.sum(np.sum(Y1 > delta)) / np.float(len(Y1)) 68 | if tpr <= 0.9505 and tpr >= 0.9495: 69 | fpr += error2 70 | total += 1 71 | fprNew = fpr/total 72 | 73 | return fprBase, fprNew 74 | 75 | def auroc(name): 76 | #calculate the AUROC 77 | # calculate baseline 78 | T = 1 79 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',') 80 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',') 81 | if name == "CIFAR-10": 82 | start = 0.1 83 | end = 1 84 | if name == "CIFAR-100": 85 | start = 0.01 86 | end = 1 87 | gap = (end- start)/100000 88 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 89 | Y1 = other[:, 2] 90 | X1 = cifar[:, 2] 91 | aurocBase = 0.0 92 | fprTemp = 1.0 93 | for delta in np.arange(start, end, gap): 94 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1)) 95 | fpr = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1)) 96 | aurocBase += (-fpr+fprTemp)*tpr 97 | fprTemp = fpr 98 | aurocBase += fpr * tpr 99 | # calculate our algorithm 100 | T = 1000 101 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',') 102 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',') 103 | if name == "CIFAR-10": 104 | start = 0.1 105 | end = 0.12 106 | if name == "CIFAR-100": 107 | start = 0.01 108 | end = 0.0104 109 | gap = (end- start)/100000 110 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 111 | Y1 = other[:, 2] 112 | X1 = cifar[:, 2] 113 | aurocNew = 0.0 114 | fprTemp = 1.0 115 | for delta in np.arange(start, end, gap): 116 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1)) 117 | fpr = np.sum(np.sum(Y1 >= delta)) / np.float(len(Y1)) 118 | aurocNew += (-fpr+fprTemp)*tpr 119 | fprTemp = fpr 120 | aurocNew += fpr * tpr 121 | return aurocBase, aurocNew 122 | 123 | def auprIn(name): 124 | #calculate the AUPR 125 | # calculate baseline 126 | T = 1 127 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',') 128 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',') 129 | if name == "CIFAR-10": 130 | start = 0.1 131 | end = 1 132 | if name == "CIFAR-100": 133 | start = 0.01 134 | end = 1 135 | gap = (end- start)/100000 136 | precisionVec = [] 137 | recallVec = [] 138 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 139 | Y1 = other[:, 2] 140 | X1 = cifar[:, 2] 141 | auprBase = 0.0 142 | recallTemp = 1.0 143 | for delta in np.arange(start, end, gap): 144 | tp = np.sum(np.sum(X1 >= delta)) / np.float(len(X1)) 145 | fp = np.sum(np.sum(Y1 >= delta)) / np.float(len(Y1)) 146 | if tp + fp == 0: continue 147 | precision = tp / (tp + fp) 148 | recall = tp 149 | precisionVec.append(precision) 150 | recallVec.append(recall) 151 | auprBase += (recallTemp-recall)*precision 152 | recallTemp = recall 153 | auprBase += recall * precision 154 | #print(recall, precision) 155 | 156 | # calculate our algorithm 157 | T = 1000 158 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',') 159 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',') 160 | if name == "CIFAR-10": 161 | start = 0.1 162 | end = 0.12 163 | if name == "CIFAR-100": 164 | start = 0.01 165 | end = 0.0104 166 | gap = (end- start)/100000 167 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 168 | Y1 = other[:, 2] 169 | X1 = cifar[:, 2] 170 | auprNew = 0.0 171 | recallTemp = 1.0 172 | for delta in np.arange(start, end, gap): 173 | tp = np.sum(np.sum(X1 >= delta)) / np.float(len(X1)) 174 | fp = np.sum(np.sum(Y1 
>= delta)) / np.float(len(Y1)) 175 | if tp + fp == 0: continue 176 | precision = tp / (tp + fp) 177 | recall = tp 178 | #precisionVec.append(precision) 179 | #recallVec.append(recall) 180 | auprNew += (recallTemp-recall)*precision 181 | recallTemp = recall 182 | auprNew += recall * precision 183 | return auprBase, auprNew 184 | 185 | def auprOut(name): 186 | #calculate the AUPR 187 | # calculate baseline 188 | T = 1 189 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',') 190 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',') 191 | if name == "CIFAR-10": 192 | start = 0.1 193 | end = 1 194 | if name == "CIFAR-100": 195 | start = 0.01 196 | end = 1 197 | gap = (end- start)/100000 198 | Y1 = other[:, 2] 199 | X1 = cifar[:, 2] 200 | auprBase = 0.0 201 | recallTemp = 1.0 202 | for delta in np.arange(end, start, -gap): 203 | fp = np.sum(np.sum(X1 < delta)) / np.float(len(X1)) 204 | tp = np.sum(np.sum(Y1 < delta)) / np.float(len(Y1)) 205 | if tp + fp == 0: break 206 | precision = tp / (tp + fp) 207 | recall = tp 208 | auprBase += (recallTemp-recall)*precision 209 | recallTemp = recall 210 | auprBase += recall * precision 211 | 212 | 213 | # calculate our algorithm 214 | T = 1000 215 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',') 216 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',') 217 | if name == "CIFAR-10": 218 | start = 0.1 219 | end = 0.12 220 | if name == "CIFAR-100": 221 | start = 0.01 222 | end = 0.0104 223 | gap = (end- start)/100000 224 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 225 | Y1 = other[:, 2] 226 | X1 = cifar[:, 2] 227 | auprNew = 0.0 228 | recallTemp = 1.0 229 | for delta in np.arange(end, start, -gap): 230 | fp = np.sum(np.sum(X1 < delta)) / np.float(len(X1)) 231 | tp = np.sum(np.sum(Y1 < delta)) / np.float(len(Y1)) 232 | if tp + fp == 0: break 233 | precision = tp / (tp + fp) 234 | recall = tp 235 | auprNew += (recallTemp-recall)*precision 236 | recallTemp = recall 237 | auprNew += recall * precision 238 | return auprBase, auprNew 239 | 240 | 241 | 242 | def detection(name): 243 | #calculate the minimum detection error 244 | # calculate baseline 245 | T = 1 246 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',') 247 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',') 248 | if name == "CIFAR-10": 249 | start = 0.1 250 | end = 1 251 | if name == "CIFAR-100": 252 | start = 0.01 253 | end = 1 254 | gap = (end- start)/100000 255 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 256 | Y1 = other[:, 2] 257 | X1 = cifar[:, 2] 258 | errorBase = 1.0 259 | for delta in np.arange(start, end, gap): 260 | tpr = np.sum(np.sum(X1 < delta)) / np.float(len(X1)) 261 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1)) 262 | errorBase = np.minimum(errorBase, (tpr+error2)/2.0) 263 | 264 | # calculate our algorithm 265 | T = 1000 266 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',') 267 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',') 268 | if name == "CIFAR-10": 269 | start = 0.1 270 | end = 0.12 271 | if name == "CIFAR-100": 272 | start = 0.01 273 | end = 0.0104 274 | gap = (end- start)/100000 275 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w') 276 | Y1 = other[:, 2] 277 | X1 = cifar[:, 2] 278 | errorNew = 1.0 279 | for delta in np.arange(start, end, gap): 280 | tpr = np.sum(np.sum(X1 < delta)) / 
np.float(len(X1)) 281 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1)) 282 | errorNew = np.minimum(errorNew, (tpr+error2)/2.0) 283 | 284 | return errorBase, errorNew 285 | 286 | 287 | 288 | 289 | def metric(nn, data): 290 | if nn == "densenet10" or nn == "wideresnet10": indis = "CIFAR-10" 291 | if nn == "densenet100" or nn == "wideresnet100": indis = "CIFAR-100" 292 | if nn == "densenet10" or nn == "densenet100": nnStructure = "DenseNet-BC-100" 293 | if nn == "wideresnet10" or nn == "wideresnet100": nnStructure = "Wide-ResNet-28-10" 294 | 295 | if data == "Imagenet": dataName = "Tiny-ImageNet (crop)" 296 | if data == "Imagenet_resize": dataName = "Tiny-ImageNet (resize)" 297 | if data == "LSUN": dataName = "LSUN (crop)" 298 | if data == "LSUN_resize": dataName = "LSUN (resize)" 299 | if data == "iSUN": dataName = "iSUN" 300 | if data == "Gaussian": dataName = "Gaussian noise" 301 | if data == "Uniform": dataName = "Uniform Noise" 302 | fprBase, fprNew = tpr95(indis) 303 | errorBase, errorNew = detection(indis) 304 | aurocBase, aurocNew = auroc(indis) 305 | auprinBase, auprinNew = auprIn(indis) 306 | auproutBase, auproutNew = auprOut(indis) 307 | print("{:31}{:>22}".format("Neural network architecture:", nnStructure)) 308 | print("{:31}{:>22}".format("In-distribution dataset:", indis)) 309 | print("{:31}{:>22}".format("Out-of-distribution dataset:", dataName)) 310 | print("") 311 | print("{:>34}{:>19}".format("Baseline", "Our Method")) 312 | print("{:20}{:13.1f}%{:>18.1f}% ".format("FPR at TPR 95%:",fprBase*100, fprNew*100)) 313 | print("{:20}{:13.1f}%{:>18.1f}%".format("Detection error:",errorBase*100, errorNew*100)) 314 | print("{:20}{:13.1f}%{:>18.1f}%".format("AUROC:",aurocBase*100, aurocNew*100)) 315 | print("{:20}{:13.1f}%{:>18.1f}%".format("AUPR In:",auprinBase*100, auprinNew*100)) 316 | print("{:20}{:13.1f}%{:>18.1f}%".format("AUPR Out:",auproutBase*100, auproutNew*100)) 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | -------------------------------------------------------------------------------- /code/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 13 | padding=1, bias=False) 14 | self.droprate = dropRate 15 | def forward(self, x): 16 | out = self.conv1(self.relu(self.bn1(x))) 17 | if self.droprate > 0: 18 | out = F.dropout(out, p=self.droprate, training=self.training) 19 | return torch.cat([x, out], 1) 20 | 21 | class BottleneckBlock(nn.Module): 22 | def __init__(self, in_planes, out_planes, dropRate=0.0): 23 | super(BottleneckBlock, self).__init__() 24 | inter_planes = out_planes * 4 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | self.bn2 = nn.BatchNorm2d(inter_planes) 30 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, 31 | padding=1, bias=False) 32 | self.droprate = dropRate 33 | def forward(self, x): 34 | out = self.conv1(self.relu(self.bn1(x))) 35 | if self.droprate > 0: 36 | out = F.dropout(out, p=self.droprate, inplace=False, 
training=self.training) 37 | out = self.conv2(self.relu(self.bn2(out))) 38 | if self.droprate > 0: 39 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 40 | return torch.cat([x, out], 1) 41 | 42 | class TransitionBlock(nn.Module): 43 | def __init__(self, in_planes, out_planes, dropRate=0.0): 44 | super(TransitionBlock, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 48 | padding=0, bias=False) 49 | self.droprate = dropRate 50 | def forward(self, x): 51 | out = self.conv1(self.relu(self.bn1(x))) 52 | if self.droprate > 0: 53 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 54 | return F.avg_pool2d(out, 2) 55 | 56 | class DenseBlock(nn.Module): 57 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0): 58 | super(DenseBlock, self).__init__() 59 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate) 60 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate): 61 | layers = [] 62 | for i in range(nb_layers): 63 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate)) 64 | return nn.Sequential(*layers) 65 | def forward(self, x): 66 | return self.layer(x) 67 | 68 | class DenseNet3(nn.Module): 69 | def __init__(self, depth, num_classes, growth_rate=12, 70 | reduction=0.5, bottleneck=True, dropRate=0.0): 71 | super(DenseNet3, self).__init__() 72 | in_planes = 2 * growth_rate 73 | n = (depth - 4) / 3 74 | if bottleneck == True: 75 | n = n/2 76 | block = BottleneckBlock 77 | else: 78 | block = BasicBlock 79 | # 1st conv before any dense block 80 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, 81 | padding=1, bias=False) 82 | # 1st block 83 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 84 | in_planes = int(in_planes+n*growth_rate) 85 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 86 | in_planes = int(math.floor(in_planes*reduction)) 87 | # 2nd block 88 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 89 | in_planes = int(in_planes+n*growth_rate) 90 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 91 | in_planes = int(math.floor(in_planes*reduction)) 92 | # 3rd block 93 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 94 | in_planes = int(in_planes+n*growth_rate) 95 | # global average pooling and classifier 96 | self.bn1 = nn.BatchNorm2d(in_planes) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.fc = nn.Linear(in_planes, num_classes) 99 | self.in_planes = in_planes 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 105 | elif isinstance(m, nn.BatchNorm2d): 106 | m.weight.data.fill_(1) 107 | m.bias.data.zero_() 108 | elif isinstance(m, nn.Linear): 109 | m.bias.data.zero_() 110 | def forward(self, x): 111 | out = self.conv1(x) 112 | out = self.trans1(self.block1(out)) 113 | out = self.trans2(self.block2(out)) 114 | out = self.block3(out) 115 | out = self.relu(self.bn1(out)) 116 | out = F.avg_pool2d(out, 8) 117 | out = out.view(-1, self.in_planes) 118 | return self.fc(out) 119 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Sep 19 20:55:56 2015 4 | 5 | @author: liangshiyu 6 | """ 7 | 8 | from __future__ import print_function 9 | import argparse 10 | import os 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import numpy as np 16 | import torch.optim as optim 17 | import torchvision 18 | import torchvision.transforms as transforms 19 | #import matplotlib.pyplot as plt 20 | import numpy as np 21 | import time 22 | #import lmdb 23 | from scipy import misc 24 | import cal as c 25 | 26 | 27 | parser = argparse.ArgumentParser(description='Pytorch Detecting Out-of-distribution examples in neural networks') 28 | 29 | parser.add_argument('--nn', default="densenet10", type=str, 30 | help='neural network name and training set') 31 | parser.add_argument('--out_dataset', default="Imagenet", type=str, 32 | help='out-of-distribution dataset') 33 | parser.add_argument('--magnitude', default=0.0014, type=float, 34 | help='perturbation magnitude') 35 | parser.add_argument('--temperature', default=1000, type=int, 36 | help='temperature scaling') 37 | parser.add_argument('--gpu', default = 0, type = int, 38 | help='gpu index') 39 | parser.set_defaults(argument=True) 40 | 41 | 42 | 43 | 44 | 45 | # Setting the name of neural networks 46 | 47 | # Densenet trained on CIFAR-10: densenet10 48 | # Densenet trained on CIFAR-100: densenet100 49 | # Wide-ResNet trained on CIFAR-10: wideresnet10 50 | # Wide-ResNet trained on CIFAR-100: wideresnet100 51 | #nnName = "densenet10" 52 | 53 | # Setting the name of the out-of-distribution dataset 54 | 55 | # Tiny-ImageNet (crop): Imagenet 56 | # Tiny-ImageNet (resize): Imagenet_resize 57 | # LSUN (crop): LSUN 58 | # LSUN (resize): LSUN_resize 59 | # iSUN: iSUN 60 | # Gaussian noise: Gaussian 61 | # Uniform noise: Uniform 62 | #dataName = "Imagenet" 63 | 64 | 65 | # Setting the perturbation magnitude 66 | #epsilon = 0.0014 67 | 68 | # Setting the temperature 69 | #temperature = 1000 70 | def main(): 71 | global args 72 | args = parser.parse_args() 73 | c.test(args.nn, args.out_dataset, args.gpu, args.magnitude, args.temperature) 74 | 75 | if __name__ == '__main__': 76 | main() 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /code/softmax_scores/.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | -------------------------------------------------------------------------------- /code/softmax_scores/confidence_Base_In.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/softmax_scores/confidence_Base_Out.txt: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/softmax_scores/confidence_Our_In.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/softmax_scores/confidence_Our_Out.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | def forward(self, x): 23 | if not self.equalInOut: 24 | x = self.relu1(self.bn1(x)) 25 | else: 26 | out = self.relu1(self.bn1(x)) 27 | out = self.conv1(self.equalInOut and out or x) 28 | if self.droprate > 0: 29 | out = F.dropout(out, p=self.droprate, training=self.training) 30 | out = self.conv2(self.relu2(self.bn2(out))) 31 | return torch.add((not self.equalInOut) and self.convShortcut(x) or x, out) 32 | 33 | class NetworkBlock(nn.Module): 34 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 35 | super(NetworkBlock, self).__init__() 36 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 37 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 38 | layers = [] 39 | for i in range(nb_layers): 40 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 41 | return nn.Sequential(*layers) 42 | def forward(self, x): 43 | return self.layer(x) 44 | 45 | class WideResNet(nn.Module): 46 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 47 | super(WideResNet, self).__init__() 48 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 49 | assert((depth - 4) % 6 == 0) 50 | n = (depth - 4) / 6 51 | block = BasicBlock 52 | # 1st conv before any network block 53 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 54 | padding=1, bias=False) 55 | # 1st block 56 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 57 | # 2nd block 58 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 59 | # 3rd block 60 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 61 | # global average pooling and classifier 62 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.fc = nn.Linear(nChannels[3], num_classes) 65 | self.nChannels = nChannels[3] 66 | 67 | for m in 
self.modules(): 68 | if isinstance(m, nn.Conv2d): 69 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 70 | m.weight.data.normal_(0, math.sqrt(2. / n)) 71 | elif isinstance(m, nn.BatchNorm2d): 72 | m.weight.data.fill_(1) 73 | m.bias.data.zero_() 74 | elif isinstance(m, nn.Linear): 75 | m.bias.data.zero_() 76 | def forward(self, x): 77 | out = self.conv1(x) 78 | out = self.block1(out) 79 | out = self.block2(out) 80 | out = self.block3(out) 81 | out = self.relu(self.bn1(out)) 82 | out = F.avg_pool2d(out, 8) 83 | out = out.view(-1, self.nChannels) 84 | return self.fc(out) 85 | -------------------------------------------------------------------------------- /figures/original_optimal_shade.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiyuLiang/odin-pytorch/34e53f5a982811a0d74baba049538d34efc0732d/figures/original_optimal_shade.png -------------------------------------------------------------------------------- /figures/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiyuLiang/odin-pytorch/34e53f5a982811a0d74baba049538d34efc0732d/figures/performance.png --------------------------------------------------------------------------------