├── README.md
├── code
│   ├── cal.py
│   ├── calData.py
│   ├── calMetric.py
│   ├── densenet.py
│   ├── main.py
│   ├── softmax_scores
│   │   ├── .gitignore
│   │   ├── confidence_Base_In.txt
│   │   ├── confidence_Base_Out.txt
│   │   ├── confidence_Our_In.txt
│   │   └── confidence_Our_Out.txt
│   └── wideresnet.py
└── figures
    ├── original_optimal_shade.png
    └── performance.png
/README.md:
--------------------------------------------------------------------------------
1 | # ODIN: Out-of-Distribution Detector for Neural Networks
2 |
3 |
4 | This is a [PyTorch](http://pytorch.org) implementation for detecting out-of-distribution examples in neural networks. The method is described in the paper [Principled Detection of Out-of-Distribution Examples in Neural Networks](https://arxiv.org/abs/1706.02690) by S. Liang, [Yixuan Li](http://www.cs.cornell.edu/~yli) and [R. Srikant](https://sites.google.com/a/illinois.edu/srikant/). The method reduces the false positive rate from the baseline 34.7% to 4.3% on the DenseNet (applied to CIFAR-10) when the true positive rate is 95%.
5 |
6 |
7 |
8 |
9 |
10 | ## Experimental Results
11 |
12 | We used two neural network architectures, [DenseNet-BC](https://arxiv.org/abs/1608.06993) and [Wide ResNet](https://arxiv.org/abs/1605.07146).
13 | The PyTorch implementation of [DenseNet-BC](https://arxiv.org/abs/1608.06993) is provided by [Andreas Veit](https://github.com/andreasveit/densenet-pytorch) and [Brandon Amos](https://github.com/bamos/densenet.pytorch). The PyTorch implementation of [Wide ResNet](https://arxiv.org/abs/1605.07146) is provided by [Sergey Zagoruyko](https://github.com/szagoruyko/wide-residual-networks).
14 | The experimental results are shown below. The definition of each metric can be found in the [paper](https://arxiv.org/abs/1706.02690).
15 | ![Performance](./figures/performance.png)
16 |
17 |
18 |
19 | ## Pre-trained Models
20 |
21 | We provide four pre-trained neural networks: (1) two [DenseNet-BC](https://arxiv.org/abs/1608.06993) networks trained on CIFAR-10 and CIFAR-100, respectively; and (2) two [Wide ResNet](https://arxiv.org/abs/1605.07146) networks trained on CIFAR-10 and CIFAR-100, respectively. The test error rates (%) are given below:
22 |
23 | Architecture | CIFAR-10 | CIFAR-100
24 | ------------ | --------- | ---------
25 | DenseNet-BC | 4.81 | 22.37
26 | Wide ResNet | 3.71 | 19.86
27 |
28 |
29 | ## Running the code
30 |
31 | ### Dependencies
32 |
33 | * CUDA 8.0
34 | * PyTorch
35 | * Anaconda2 or 3
36 | * At least **three** GPUs
37 |
38 | **Note:** Reproducing the DenseNet-BC results requires only **one** GPU, but reproducing the Wide ResNet results requires **three** GPUs. A single-GPU version for Wide ResNet will be released in the future.
39 |
40 | ### Downloading Out-of-Distribution Datasets
41 | We provide download links for five out-of-distribution datasets:
42 |
43 | * **[Tiny-ImageNet (crop)](https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz)**
44 | * **[Tiny-ImageNet (resize)](https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz)**
45 | * **[LSUN (crop)](https://www.dropbox.com/s/fhtsw1m3qxlwj6h/LSUN.tar.gz)**
46 | * **[LSUN (resize)](https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz)**
47 | * **[iSUN](https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz)**
48 |
49 | Here is an example of downloading the Tiny-ImageNet (crop) dataset. In the **root** directory, run
50 |
51 | ```
52 | mkdir data
53 | cd data
54 | wget https://www.dropbox.com/s/avgm2u562itwpkl/Imagenet.tar.gz
55 | tar -xvzf Imagenet.tar.gz
56 | cd ..
57 | ```
58 |
59 | ### Downloading Neural Network Models
60 |
61 | We provide download links for the four pre-trained models.
62 |
63 | * **[DenseNet-BC trained on CIFAR-10](https://www.dropbox.com/s/wr4kjintq1tmorr/densenet10.pth.tar.gz)**
64 | * **[DenseNet-BC trained on CIFAR-100](https://www.dropbox.com/s/vxuv11jjg8bw2v9/densenet100.pth.tar.gz)**
65 | * **[Wide ResNet trained on CIFAR-10](https://www.dropbox.com/s/uiye5nw0uj6ie53/wideresnet10.pth.tar.gz)**
66 | * **[Wide ResNet trained on CIFAR-100](https://www.dropbox.com/s/uiye5nw0uj6ie53/wideresnet100.pth.tar.gz)**
67 |
68 | Here is an example of downloading the DenseNet-BC model trained on CIFAR-10. In the **root** directory, run
69 |
70 | ```
71 | mkdir models
72 | cd models
73 | wget https://www.dropbox.com/s/wr4kjintq1tmorr/densenet10.pth.tar.gz
74 | tar -xvzf densenet10.pth.tar.gz
75 | cd ..
76 | ```
77 |
78 |
79 | ### Running
80 |
81 | Here is an example that reproduces the results of DenseNet-BC trained on CIFAR-10, with Tiny-ImageNet (crop) as the out-of-distribution dataset. The temperature is set to 1000, and the perturbation magnitude is set to 0.0014. In the **root** directory, run
82 |
83 | ```
84 | cd code
85 | # model: DenseNet-BC, in-distribution: CIFAR-10, out-distribution: TinyImageNet (crop)
86 | # magnitude: 0.0014, temperature 1000, gpu: 0
87 | python main.py --nn densenet10 --out_dataset Imagenet --magnitude 0.0014 --temperature 1000 --gpu 0
88 | ```
89 | **Note:** Please choose the arguments according to the following tables.
90 |
91 | #### args
92 | * **args.nn**: the name of the neural network model; the valid options are shown below
93 |
94 | Neural Network Models | args.nn
95 | ----------------------|--------
96 | DenseNet-BC trained on CIFAR-10| densenet10
97 | DenseNet-BC trained on CIFAR-100| densenet100
98 | * **args.out_dataset**: the name of the out-of-distribution dataset; the valid options are shown below
99 |
100 | Out-of-Distribution Datasets | args.out_dataset
101 | ------------------------------------|-----------------
102 | Tiny-ImageNet (crop) | Imagenet
103 | Tiny-ImageNet (resize) | Imagenet_resize
104 | LSUN (crop) | LSUN
105 | LSUN (resize) | LSUN_resize
106 | iSUN | iSUN
107 | Uniform random noise | Uniform
108 | Gaussian random noise | Gaussian
109 |
110 | * **args.magnitude**: the perturbation (noise) magnitude. The optimal magnitudes we found are listed below; in practice, the optimal choice is model-specific and needs to be tuned accordingly.
111 |
112 | Out-of-Distribution Datasets | densenet10 | densenet100 | wideresnet10 | wideresnet100
113 | ------------------------------------|------------------|------------- | -------------- |--------------
114 | Tiny-ImageNet (crop) | 0.0014 | 0.0014 | 0.0005 | 0.0028
115 | Tiny-ImageNet (resize) | 0.0014 | 0.0028 | 0.0011 | 0.0028
116 | LSUN (crop) | 0 | 0.0028 | 0 | 0.0048
117 | LSUN (resize) | 0.0014 | 0.0028 | 0.0006 | 0.002
118 | iSUN | 0.0014 | 0.0028 | 0.0008 | 0.0028
119 | Uniform random noise | 0.0014 | 0.0028 | 0.0014 | 0.0028
120 | Gaussian random noise | 0.0014 |0.0028 | 0.0014 | 0.0028
121 |
122 | * **args.temperature**: temperature is set to 1000 in all cases.
123 | * **args.gpu**: the GPU index; make sure you use the following GPU when running the code:
124 |
125 | Neural Network Models | args.gpu
126 | ----------------------|----------
127 | densenet10 | 0
128 | densenet100 | 0
129 | wideresnet10 | 1
130 | wideresnet100 | 2
131 |
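For example, to evaluate the Wide ResNet trained on CIFAR-100 with LSUN (crop) as the out-of-distribution dataset, the tables above give a noise magnitude of 0.0048, a temperature of 1000, and GPU index 2 (on a machine with at least three GPUs; see the note above). A sketch of the corresponding command, assuming the wideresnet100 model and the LSUN dataset have been downloaded as described earlier:

```
cd code
# model: Wide ResNet, in-distribution: CIFAR-100, out-distribution: LSUN (crop)
# magnitude: 0.0048, temperature 1000, gpu: 2
python main.py --nn wideresnet100 --out_dataset LSUN --magnitude 0.0048 --temperature 1000 --gpu 2
```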
132 | ### Outputs
133 | Here is an example of the output:
134 |
135 | ```
136 | Neural network architecture: DenseNet-BC-100
137 | In-distribution dataset: CIFAR-10
138 | Out-of-distribution dataset: Tiny-ImageNet (crop)
139 |
140 | Baseline Our Method
141 | FPR at TPR 95%: 34.8% 4.3%
142 | Detection error: 9.9% 4.6%
143 | AUROC: 95.3% 99.1%
144 | AUPR In: 96.4% 99.2%
145 | AUPR Out: 93.8% 99.1%
146 | ```
147 |
--------------------------------------------------------------------------------
/code/cal.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Sep 19 20:55:56 2015
4 |
5 | @author: liangshiyu
6 | """
7 |
8 | from __future__ import print_function
9 | import torch
10 | from torch.autograd import Variable
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import numpy as np
14 | import torch.optim as optim
15 | import torchvision
16 | import torchvision.transforms as transforms
17 | import numpy as np
18 | import time
19 | from scipy import misc
20 | import calMetric as m
21 | import calData as d
22 | #CUDA_DEVICE = 0
23 |
24 | start = time.time()
25 | #loading data sets
26 |
27 | transform = transforms.Compose([
28 | transforms.ToTensor(),
29 | transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0)),
30 | ])
31 |
32 |
33 |
34 |
35 | # loading neural network
36 |
37 | # Name of neural networks
38 | # Densenet trained on CIFAR-10: densenet10
39 | # Densenet trained on CIFAR-100: densenet100
40 | # Wide-ResNet trained on CIFAR-10: wideresnet10
41 | # Wide-ResNet trained on CIFAR-100: wideresnet100
42 | #nnName = "densenet10"
43 |
44 | #imName = "Imagenet"
45 |
46 |
47 |
48 | criterion = nn.CrossEntropyLoss()
49 |
50 |
51 |
52 | def test(nnName, dataName, CUDA_DEVICE, epsilon, temperature):
53 |
54 | net1 = torch.load("../models/{}.pth".format(nnName))
55 | optimizer1 = optim.SGD(net1.parameters(), lr = 0, momentum = 0)
56 | net1.cuda(CUDA_DEVICE)
57 |
58 | if dataName != "Uniform" and dataName != "Gaussian":
59 | testsetout = torchvision.datasets.ImageFolder("../data/{}".format(dataName), transform=transform)
60 | testloaderOut = torch.utils.data.DataLoader(testsetout, batch_size=1,
61 | shuffle=False, num_workers=2)
62 |
63 | if nnName == "densenet10" or nnName == "wideresnet10":
64 | testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
65 | testloaderIn = torch.utils.data.DataLoader(testset, batch_size=1,
66 | shuffle=False, num_workers=2)
67 | if nnName == "densenet100" or nnName == "wideresnet100":
68 | testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform)
69 | testloaderIn = torch.utils.data.DataLoader(testset, batch_size=1,
70 | shuffle=False, num_workers=2)
71 |
72 | if dataName == "Gaussian":
73 | d.testGaussian(net1, criterion, CUDA_DEVICE, testloaderIn, testloaderIn, nnName, dataName, epsilon, temperature)
74 | m.metric(nnName, dataName)
75 |
76 | elif dataName == "Uniform":
77 | d.testUni(net1, criterion, CUDA_DEVICE, testloaderIn, testloaderIn, nnName, dataName, epsilon, temperature)
78 | m.metric(nnName, dataName)
79 | else:
80 | d.testData(net1, criterion, CUDA_DEVICE, testloaderIn, testloaderOut, nnName, dataName, epsilon, temperature)
81 | m.metric(nnName, dataName)
82 |
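# Example (a sketch, not part of the original pipeline): test() can also be called
# directly from a Python session instead of going through main.py, e.g.
#
#     import cal
#     cal.test("densenet10", "Imagenet", 0, 0.0014, 1000)
#
# which scores DenseNet-BC (CIFAR-10) against Tiny-ImageNet (crop) on GPU 0 with
# perturbation magnitude 0.0014 and temperature 1000.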
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/code/calData.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Sep 19 20:55:56 2015
4 |
5 | @author: liangshiyu
6 | """
7 |
8 | from __future__ import print_function
9 | import torch
10 | from torch.autograd import Variable
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import numpy as np
14 | import torch.optim as optim
15 | import torchvision
16 | import torchvision.transforms as transforms
17 | import numpy as np
18 | import time
19 | from scipy import misc
20 |
21 | def testData(net1, criterion, CUDA_DEVICE, testloader10, testloader, nnName, dataName, noiseMagnitude1, temper):
22 | t0 = time.time()
23 | f1 = open("./softmax_scores/confidence_Base_In.txt", 'w')
24 | f2 = open("./softmax_scores/confidence_Base_Out.txt", 'w')
25 | g1 = open("./softmax_scores/confidence_Our_In.txt", 'w')
26 | g2 = open("./softmax_scores/confidence_Our_Out.txt", 'w')
27 | N = 10000
28 | if dataName == "iSUN": N = 8925
29 | print("Processing in-distribution images")
30 | ########################################In-distribution###########################################
31 | for j, data in enumerate(testloader10):
32 | if j<1000: continue
33 | images, _ = data
34 |
35 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True)
36 | outputs = net1(inputs)
37 |
38 |
39 | # Calculating the confidence of the output, no perturbation added here, no temperature scaling used
40 | nnOutputs = outputs.data.cpu()
41 | nnOutputs = nnOutputs.numpy()
42 | nnOutputs = nnOutputs[0]
43 | nnOutputs = nnOutputs - np.max(nnOutputs)
44 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
45 | f1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
46 |
47 | # Using temperature scaling
48 | outputs = outputs / temper
49 |
50 | # Calculating the perturbation we need to add, that is,
51 | # the sign of gradient of cross entropy loss w.r.t. input
52 | maxIndexTemp = np.argmax(nnOutputs)
53 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE))
54 | loss = criterion(outputs, labels)
55 | loss.backward()
56 |
57 | # Taking the sign of the gradient: ge maps to {0, 1}, then rescale to {-1, +1}
58 | gradient = torch.ge(inputs.grad.data, 0)
59 | gradient = (gradient.float() - 0.5) * 2
60 | # Normalizing the gradient to the same space of image
61 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0)
62 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0)
63 | gradient[0][2] = (gradient[0][2])/(66.7/255.0)
64 | # Adding small perturbations to images
65 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
66 | outputs = net1(Variable(tempInputs))
67 | outputs = outputs / temper
68 | # Calculating the confidence after adding perturbations
69 | nnOutputs = outputs.data.cpu()
70 | nnOutputs = nnOutputs.numpy()
71 | nnOutputs = nnOutputs[0]
72 | nnOutputs = nnOutputs - np.max(nnOutputs)
73 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
74 | g1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
75 | if j % 100 == 99:
76 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0))
77 | t0 = time.time()
78 |
79 | if j == N - 1: break
80 |
81 |
82 | t0 = time.time()
83 | print("Processing out-of-distribution images")
84 | ###################################Out-of-Distributions#####################################
85 | for j, data in enumerate(testloader):
86 | if j<1000: continue
87 | images, _ = data
88 |
89 |
90 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True)
91 | outputs = net1(inputs)
92 |
93 |
94 |
95 | # Calculating the confidence of the output, no perturbation added here
96 | nnOutputs = outputs.data.cpu()
97 | nnOutputs = nnOutputs.numpy()
98 | nnOutputs = nnOutputs[0]
99 | nnOutputs = nnOutputs - np.max(nnOutputs)
100 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
101 | f2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
102 |
103 | # Using temperature scaling
104 | outputs = outputs / temper
105 |
106 |
107 | # Calculating the perturbation we need to add, that is,
108 | # the sign of gradient of cross entropy loss w.r.t. input
109 | maxIndexTemp = np.argmax(nnOutputs)
110 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE))
111 | loss = criterion(outputs, labels)
112 | loss.backward()
113 |
114 | # Taking the sign of the gradient: ge maps to {0, 1}, then rescale to {-1, +1}
115 | gradient = (torch.ge(inputs.grad.data, 0))
116 | gradient = (gradient.float() - 0.5) * 2
117 | # Normalizing the gradient to the same space of image
118 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0)
119 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0)
120 | gradient[0][2] = (gradient[0][2])/(66.7/255.0)
121 | # Adding small perturbations to images
122 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
123 | outputs = net1(Variable(tempInputs))
124 | outputs = outputs / temper
125 | # Calculating the confidence after adding perturbations
126 | nnOutputs = outputs.data.cpu()
127 | nnOutputs = nnOutputs.numpy()
128 | nnOutputs = nnOutputs[0]
129 | nnOutputs = nnOutputs - np.max(nnOutputs)
130 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
131 | g2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
132 | if j % 100 == 99:
133 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0))
134 | t0 = time.time()
135 |
136 | if j== N-1: break
137 |
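# Summary of the scoring procedure implemented in testData above (the same steps are
# mirrored in testGaussian and testUni below):
#   1. Forward pass; the maximum softmax probability (computed in numpy with the max
#      logit subtracted for numerical stability) is written as the baseline score.
#   2. Divide the logits by the temperature.
#   3. Backpropagate the cross-entropy loss w.r.t. the predicted label and keep the
#      sign of the input gradient, divided per channel by the normalization stds.
#   4. Subtract noiseMagnitude1 times this signed gradient from the input.
#   5. Forward the perturbed input again with temperature scaling; the new maximum
#      softmax probability is written as the ODIN score (confidence_Our_* files).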
138 |
139 |
140 |
141 | def testGaussian(net1, criterion, CUDA_DEVICE, testloader10, testloader, nnName, dataName, noiseMagnitude1, temper):
142 | t0 = time.time()
143 | f1 = open("./softmax_scores/confidence_Base_In.txt", 'w')
144 | f2 = open("./softmax_scores/confidence_Base_Out.txt", 'w')
145 | g1 = open("./softmax_scores/confidence_Our_In.txt", 'w')
146 | g2 = open("./softmax_scores/confidence_Our_Out.txt", 'w')
147 | ########################################In-Distribution###############################################
148 | N = 10000
149 | print("Processing in-distribution images")
150 | for j, data in enumerate(testloader10):
151 |
152 | if j<1000: continue
153 | images, _ = data
154 |
155 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True)
156 | outputs = net1(inputs)
157 |
158 |
159 | # Calculating the confidence of the output, no perturbation added here
160 | nnOutputs = outputs.data.cpu()
161 | nnOutputs = nnOutputs.numpy()
162 | nnOutputs = nnOutputs[0]
163 | nnOutputs = nnOutputs - np.max(nnOutputs)
164 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
165 | f1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
166 |
167 | # Using temperature scaling
168 | outputs = outputs / temper
169 |
170 | # Calculating the perturbation we need to add, that is,
171 | # the sign of gradient of cross entropy loss w.r.t. input
172 | maxIndexTemp = np.argmax(nnOutputs)
173 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE))
174 | loss = criterion(outputs, labels)
175 | loss.backward()
176 |
177 |
178 | # Taking the sign of the gradient: ge maps to {0, 1}, then rescale to {-1, +1}
179 | gradient = (torch.ge(inputs.grad.data, 0))
180 | gradient = (gradient.float() - 0.5) * 2
181 | # Normalizing the gradient to the same space of image
182 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0)
183 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0)
184 | gradient[0][2] = (gradient[0][2])/(66.7/255.0)
185 | # Adding small perturbations to images
186 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
187 | outputs = net1(Variable(tempInputs))
188 | outputs = outputs / temper
189 | # Calculating the confidence after adding perturbations
190 | nnOutputs = outputs.data.cpu()
191 | nnOutputs = nnOutputs.numpy()
192 | nnOutputs = nnOutputs[0]
193 | nnOutputs = nnOutputs - np.max(nnOutputs)
194 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
195 |
196 | g1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
197 | if j % 100 == 99:
198 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0))
199 | t0 = time.time()
200 |
201 |
202 |
203 | ########################################Out-of-Distribution######################################
204 | print("Processing out-of-distribution images")
205 | for j, data in enumerate(testloader):
206 | if j<1000: continue
207 |
208 | images = torch.randn(1,3,32,32) + 0.5
209 | images = torch.clamp(images, 0, 1)
210 | images[0][0] = (images[0][0] - 125.3/255) / (63.0/255)
211 | images[0][1] = (images[0][1] - 123.0/255) / (62.1/255)
212 | images[0][2] = (images[0][2] - 113.9/255) / (66.7/255)
213 |
214 |
215 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True)
216 | outputs = net1(inputs)
217 |
218 |
219 |
220 | # Calculating the confidence of the output, no perturbation added here
221 | nnOutputs = outputs.data.cpu()
222 | nnOutputs = nnOutputs.numpy()
223 | nnOutputs = nnOutputs[0]
224 | nnOutputs = nnOutputs - np.max(nnOutputs)
225 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
226 | f2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
227 |
228 | # Using temperature scaling
229 | outputs = outputs / temper
230 |
231 | # Calculating the perturbation we need to add, that is,
232 | # the sign of gradient of cross entropy loss w.r.t. input
233 | maxIndexTemp = np.argmax(nnOutputs)
234 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE))
235 | loss = criterion(outputs, labels)
236 | loss.backward()
237 |
238 | # Taking the sign of the gradient: ge maps to {0, 1}, then rescale to {-1, +1}
239 | gradient = (torch.ge(inputs.grad.data, 0))
240 | gradient = (gradient.float() - 0.5) * 2
241 | # Normalizing the gradient to the same space of image
242 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0)
243 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0)
244 | gradient[0][2] = (gradient[0][2])/(66.7/255.0)
245 | # Adding small perturbations to images
246 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
247 | outputs = net1(Variable(tempInputs))
248 | outputs = outputs / temper
249 | # Calculating the confidence after adding perturbations
250 | nnOutputs = outputs.data.cpu()
251 | nnOutputs = nnOutputs.numpy()
252 | nnOutputs = nnOutputs[0]
253 | nnOutputs = nnOutputs - np.max(nnOutputs)
254 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
255 | g2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
256 |
257 | if j % 100 == 99:
258 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0))
259 | t0 = time.time()
260 |
261 | if j== N-1: break
262 |
263 |
264 |
265 |
266 | def testUni(net1, criterion, CUDA_DEVICE, testloader10, testloader, nnName, dataName, noiseMagnitude1, temper):
267 | t0 = time.time()
268 | f1 = open("./softmax_scores/confidence_Base_In.txt", 'w')
269 | f2 = open("./softmax_scores/confidence_Base_Out.txt", 'w')
270 | g1 = open("./softmax_scores/confidence_Our_In.txt", 'w')
271 | g2 = open("./softmax_scores/confidence_Our_Out.txt", 'w')
272 | ########################################In-Distribution###############################################
273 | N = 10000
274 | print("Processing in-distribution images")
275 | for j, data in enumerate(testloader10):
276 | if j<1000: continue
277 |
278 | images, _ = data
279 |
280 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True)
281 | outputs = net1(inputs)
282 |
283 |
284 | # Calculating the confidence of the output, no perturbation added here
285 | nnOutputs = outputs.data.cpu()
286 | nnOutputs = nnOutputs.numpy()
287 | nnOutputs = nnOutputs[0]
288 | nnOutputs = nnOutputs - np.max(nnOutputs)
289 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
290 | f1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
291 |
292 | # Using temperature scaling
293 | outputs = outputs / temper
294 |
295 | # Calculating the perturbation we need to add, that is,
296 | # the sign of gradient of cross entropy loss w.r.t. input
297 | maxIndexTemp = np.argmax(nnOutputs)
298 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE))
299 | loss = criterion(outputs, labels)
300 | loss.backward()
301 |
302 |
303 | # Taking the sign of the gradient: ge maps to {0, 1}, then rescale to {-1, +1}
304 | gradient = (torch.ge(inputs.grad.data, 0))
305 | gradient = (gradient.float() - 0.5) * 2
306 | # Normalizing the gradient to the same space of image
307 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0)
308 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0)
309 | gradient[0][2] = (gradient[0][2])/(66.7/255.0)
310 | # Adding small perturbations to images
311 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
312 | outputs = net1(Variable(tempInputs))
313 | outputs = outputs / temper
314 | # Calculating the confidence after adding perturbations
315 | nnOutputs = outputs.data.cpu()
316 | nnOutputs = nnOutputs.numpy()
317 | nnOutputs = nnOutputs[0]
318 | nnOutputs = nnOutputs - np.max(nnOutputs)
319 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
320 |
321 | g1.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
322 | if j % 100 == 99:
323 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0))
324 | t0 = time.time()
325 |
326 |
327 |
328 | ########################################Out-of-Distribution######################################
329 | print("Processing out-of-distribution images")
330 | for j, data in enumerate(testloader):
331 | if j<1000: continue
332 |
333 | images = torch.rand(1,3,32,32)
334 | images[0][0] = (images[0][0] - 125.3/255) / (63.0/255)
335 | images[0][1] = (images[0][1] - 123.0/255) / (62.1/255)
336 | images[0][2] = (images[0][2] - 113.9/255) / (66.7/255)
337 |
338 |
339 | inputs = Variable(images.cuda(CUDA_DEVICE), requires_grad = True)
340 | outputs = net1(inputs)
341 |
342 |
343 |
344 | # Calculating the confidence of the output, no perturbation added here
345 | nnOutputs = outputs.data.cpu()
346 | nnOutputs = nnOutputs.numpy()
347 | nnOutputs = nnOutputs[0]
348 | nnOutputs = nnOutputs - np.max(nnOutputs)
349 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
350 | f2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
351 |
352 | # Using temperature scaling
353 | outputs = outputs / temper
354 |
355 | # Calculating the perturbation we need to add, that is,
356 | # the sign of gradient of cross entropy loss w.r.t. input
357 | maxIndexTemp = np.argmax(nnOutputs)
358 | labels = Variable(torch.LongTensor([maxIndexTemp]).cuda(CUDA_DEVICE))
359 | loss = criterion(outputs, labels)
360 | loss.backward()
361 |
362 | # Taking the sign of the gradient: ge maps to {0, 1}, then rescale to {-1, +1}
363 | gradient = (torch.ge(inputs.grad.data, 0))
364 | gradient = (gradient.float() - 0.5) * 2
365 | # Normalizing the gradient to the same space of image
366 | gradient[0][0] = (gradient[0][0] )/(63.0/255.0)
367 | gradient[0][1] = (gradient[0][1] )/(62.1/255.0)
368 | gradient[0][2] = (gradient[0][2])/(66.7/255.0)
369 | # Adding small perturbations to images
370 | tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
371 | outputs = net1(Variable(tempInputs))
372 | outputs = outputs / temper
373 | # Calculating the confidence after adding perturbations
374 | nnOutputs = outputs.data.cpu()
375 | nnOutputs = nnOutputs.numpy()
376 | nnOutputs = nnOutputs[0]
377 | nnOutputs = nnOutputs - np.max(nnOutputs)
378 | nnOutputs = np.exp(nnOutputs)/np.sum(np.exp(nnOutputs))
379 | g2.write("{}, {}, {}\n".format(temper, noiseMagnitude1, np.max(nnOutputs)))
380 | if j % 100 == 99:
381 | print("{:4}/{:4} images processed, {:.1f} seconds used.".format(j+1-1000, N-1000, time.time()-t0))
382 | t0 = time.time()
383 |
384 | if j== N-1: break
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
--------------------------------------------------------------------------------
/code/calMetric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Sep 19 20:55:56 2015
4 |
5 | @author: liangshiyu
6 | """
7 |
8 | from __future__ import print_function
9 | import torch
10 | from torch.autograd import Variable
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import numpy as np
14 | import torch.optim as optim
15 | import torchvision
16 | import torchvision.transforms as transforms
17 | #import matplotlib.pyplot as plt
18 | import numpy as np
19 | import time
20 | from scipy import misc
21 |
22 |
23 | def tpr95(name):
24 | # calculate the false positive rate (FPR) when the TPR is 95%
25 | # calculate baseline
26 | T = 1
27 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',')
28 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',')
29 | if name == "CIFAR-10":
30 | start = 0.1
31 | end = 1
32 | if name == "CIFAR-100":
33 | start = 0.01
34 | end = 1
35 | gap = (end- start)/100000
36 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
37 | Y1 = other[:, 2]
38 | X1 = cifar[:, 2]
39 | total = 0.0
40 | fpr = 0.0
41 | for delta in np.arange(start, end, gap):
42 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1))
43 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1))
44 | if tpr <= 0.9505 and tpr >= 0.9495:
45 | fpr += error2
46 | total += 1
47 | fprBase = fpr/total
48 |
49 | # calculate our algorithm
50 | T = 1000
51 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',')
52 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',')
53 | if name == "CIFAR-10":
54 | start = 0.1
55 | end = 0.12
56 | if name == "CIFAR-100":
57 | start = 0.01
58 | end = 0.0104
59 | gap = (end- start)/100000
60 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
61 | Y1 = other[:, 2]
62 | X1 = cifar[:, 2]
63 | total = 0.0
64 | fpr = 0.0
65 | for delta in np.arange(start, end, gap):
66 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1))
67 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1))
68 | if tpr <= 0.9505 and tpr >= 0.9495:
69 | fpr += error2
70 | total += 1
71 | fprNew = fpr/total
72 |
73 | return fprBase, fprNew
74 |
75 | def auroc(name):
76 | #calculate the AUROC
77 | # calculate baseline
78 | T = 1
79 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',')
80 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',')
81 | if name == "CIFAR-10":
82 | start = 0.1
83 | end = 1
84 | if name == "CIFAR-100":
85 | start = 0.01
86 | end = 1
87 | gap = (end- start)/100000
88 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
89 | Y1 = other[:, 2]
90 | X1 = cifar[:, 2]
91 | aurocBase = 0.0
92 | fprTemp = 1.0
93 | for delta in np.arange(start, end, gap):
94 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1))
95 | fpr = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1))
96 | aurocBase += (-fpr+fprTemp)*tpr
97 | fprTemp = fpr
98 | aurocBase += fpr * tpr
99 | # calculate our algorithm
100 | T = 1000
101 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',')
102 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',')
103 | if name == "CIFAR-10":
104 | start = 0.1
105 | end = 0.12
106 | if name == "CIFAR-100":
107 | start = 0.01
108 | end = 0.0104
109 | gap = (end- start)/100000
110 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
111 | Y1 = other[:, 2]
112 | X1 = cifar[:, 2]
113 | aurocNew = 0.0
114 | fprTemp = 1.0
115 | for delta in np.arange(start, end, gap):
116 | tpr = np.sum(np.sum(X1 >= delta)) / np.float(len(X1))
117 | fpr = np.sum(np.sum(Y1 >= delta)) / np.float(len(Y1))
118 | aurocNew += (-fpr+fprTemp)*tpr
119 | fprTemp = fpr
120 | aurocNew += fpr * tpr
121 | return aurocBase, aurocNew
122 |
123 | def auprIn(name):
124 | # calculate the AUPR with in-distribution samples as the positive class
125 | # calculate baseline
126 | T = 1
127 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',')
128 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',')
129 | if name == "CIFAR-10":
130 | start = 0.1
131 | end = 1
132 | if name == "CIFAR-100":
133 | start = 0.01
134 | end = 1
135 | gap = (end- start)/100000
136 | precisionVec = []
137 | recallVec = []
138 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
139 | Y1 = other[:, 2]
140 | X1 = cifar[:, 2]
141 | auprBase = 0.0
142 | recallTemp = 1.0
143 | for delta in np.arange(start, end, gap):
144 | tp = np.sum(np.sum(X1 >= delta)) / np.float(len(X1))
145 | fp = np.sum(np.sum(Y1 >= delta)) / np.float(len(Y1))
146 | if tp + fp == 0: continue
147 | precision = tp / (tp + fp)
148 | recall = tp
149 | precisionVec.append(precision)
150 | recallVec.append(recall)
151 | auprBase += (recallTemp-recall)*precision
152 | recallTemp = recall
153 | auprBase += recall * precision
154 | #print(recall, precision)
155 |
156 | # calculate our algorithm
157 | T = 1000
158 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',')
159 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',')
160 | if name == "CIFAR-10":
161 | start = 0.1
162 | end = 0.12
163 | if name == "CIFAR-100":
164 | start = 0.01
165 | end = 0.0104
166 | gap = (end- start)/100000
167 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
168 | Y1 = other[:, 2]
169 | X1 = cifar[:, 2]
170 | auprNew = 0.0
171 | recallTemp = 1.0
172 | for delta in np.arange(start, end, gap):
173 | tp = np.sum(np.sum(X1 >= delta)) / np.float(len(X1))
174 | fp = np.sum(np.sum(Y1 >= delta)) / np.float(len(Y1))
175 | if tp + fp == 0: continue
176 | precision = tp / (tp + fp)
177 | recall = tp
178 | #precisionVec.append(precision)
179 | #recallVec.append(recall)
180 | auprNew += (recallTemp-recall)*precision
181 | recallTemp = recall
182 | auprNew += recall * precision
183 | return auprBase, auprNew
184 |
185 | def auprOut(name):
186 | # calculate the AUPR with out-of-distribution samples as the positive class
187 | # calculate baseline
188 | T = 1
189 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',')
190 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',')
191 | if name == "CIFAR-10":
192 | start = 0.1
193 | end = 1
194 | if name == "CIFAR-100":
195 | start = 0.01
196 | end = 1
197 | gap = (end- start)/100000
198 | Y1 = other[:, 2]
199 | X1 = cifar[:, 2]
200 | auprBase = 0.0
201 | recallTemp = 1.0
202 | for delta in np.arange(end, start, -gap):
203 | fp = np.sum(np.sum(X1 < delta)) / np.float(len(X1))
204 | tp = np.sum(np.sum(Y1 < delta)) / np.float(len(Y1))
205 | if tp + fp == 0: break
206 | precision = tp / (tp + fp)
207 | recall = tp
208 | auprBase += (recallTemp-recall)*precision
209 | recallTemp = recall
210 | auprBase += recall * precision
211 |
212 |
213 | # calculate our algorithm
214 | T = 1000
215 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',')
216 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',')
217 | if name == "CIFAR-10":
218 | start = 0.1
219 | end = 0.12
220 | if name == "CIFAR-100":
221 | start = 0.01
222 | end = 0.0104
223 | gap = (end- start)/100000
224 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
225 | Y1 = other[:, 2]
226 | X1 = cifar[:, 2]
227 | auprNew = 0.0
228 | recallTemp = 1.0
229 | for delta in np.arange(end, start, -gap):
230 | fp = np.sum(np.sum(X1 < delta)) / np.float(len(X1))
231 | tp = np.sum(np.sum(Y1 < delta)) / np.float(len(Y1))
232 | if tp + fp == 0: break
233 | precision = tp / (tp + fp)
234 | recall = tp
235 | auprNew += (recallTemp-recall)*precision
236 | recallTemp = recall
237 | auprNew += recall * precision
238 | return auprBase, auprNew
239 |
240 |
241 |
242 | def detection(name):
243 | #calculate the minimum detection error
244 | # calculate baseline
245 | T = 1
246 | cifar = np.loadtxt('./softmax_scores/confidence_Base_In.txt', delimiter=',')
247 | other = np.loadtxt('./softmax_scores/confidence_Base_Out.txt', delimiter=',')
248 | if name == "CIFAR-10":
249 | start = 0.1
250 | end = 1
251 | if name == "CIFAR-100":
252 | start = 0.01
253 | end = 1
254 | gap = (end- start)/100000
255 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
256 | Y1 = other[:, 2]
257 | X1 = cifar[:, 2]
258 | errorBase = 1.0
259 | for delta in np.arange(start, end, gap):
260 | tpr = np.sum(np.sum(X1 < delta)) / np.float(len(X1))
261 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1))
262 | errorBase = np.minimum(errorBase, (tpr+error2)/2.0)
263 |
264 | # calculate our algorithm
265 | T = 1000
266 | cifar = np.loadtxt('./softmax_scores/confidence_Our_In.txt', delimiter=',')
267 | other = np.loadtxt('./softmax_scores/confidence_Our_Out.txt', delimiter=',')
268 | if name == "CIFAR-10":
269 | start = 0.1
270 | end = 0.12
271 | if name == "CIFAR-100":
272 | start = 0.01
273 | end = 0.0104
274 | gap = (end- start)/100000
275 | #f = open("./{}/{}/T_{}.txt".format(nnName, dataName, T), 'w')
276 | Y1 = other[:, 2]
277 | X1 = cifar[:, 2]
278 | errorNew = 1.0
279 | for delta in np.arange(start, end, gap):
280 | tpr = np.sum(np.sum(X1 < delta)) / np.float(len(X1))
281 | error2 = np.sum(np.sum(Y1 > delta)) / np.float(len(Y1))
282 | errorNew = np.minimum(errorNew, (tpr+error2)/2.0)
283 |
284 | return errorBase, errorNew
285 |
286 |
287 |
288 |
289 | def metric(nn, data):
290 | if nn == "densenet10" or nn == "wideresnet10": indis = "CIFAR-10"
291 | if nn == "densenet100" or nn == "wideresnet100": indis = "CIFAR-100"
292 | if nn == "densenet10" or nn == "densenet100": nnStructure = "DenseNet-BC-100"
293 | if nn == "wideresnet10" or nn == "wideresnet100": nnStructure = "Wide-ResNet-28-10"
294 |
295 | if data == "Imagenet": dataName = "Tiny-ImageNet (crop)"
296 | if data == "Imagenet_resize": dataName = "Tiny-ImageNet (resize)"
297 | if data == "LSUN": dataName = "LSUN (crop)"
298 | if data == "LSUN_resize": dataName = "LSUN (resize)"
299 | if data == "iSUN": dataName = "iSUN"
300 | if data == "Gaussian": dataName = "Gaussian noise"
301 | if data == "Uniform": dataName = "Uniform Noise"
302 | fprBase, fprNew = tpr95(indis)
303 | errorBase, errorNew = detection(indis)
304 | aurocBase, aurocNew = auroc(indis)
305 | auprinBase, auprinNew = auprIn(indis)
306 | auproutBase, auproutNew = auprOut(indis)
307 | print("{:31}{:>22}".format("Neural network architecture:", nnStructure))
308 | print("{:31}{:>22}".format("In-distribution dataset:", indis))
309 | print("{:31}{:>22}".format("Out-of-distribution dataset:", dataName))
310 | print("")
311 | print("{:>34}{:>19}".format("Baseline", "Our Method"))
312 | print("{:20}{:13.1f}%{:>18.1f}% ".format("FPR at TPR 95%:",fprBase*100, fprNew*100))
313 | print("{:20}{:13.1f}%{:>18.1f}%".format("Detection error:",errorBase*100, errorNew*100))
314 | print("{:20}{:13.1f}%{:>18.1f}%".format("AUROC:",aurocBase*100, aurocNew*100))
315 | print("{:20}{:13.1f}%{:>18.1f}%".format("AUPR In:",auprinBase*100, auprinNew*100))
316 | print("{:20}{:13.1f}%{:>18.1f}%".format("AUPR Out:",auproutBase*100, auproutNew*100))
317 |
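# Optional sanity check (a sketch, not used by metric() above): if scikit-learn is
# installed, the AUROC can also be computed directly from the saved score files and
# compared against the estimate returned by auroc().
def aurocSklearn(inFile='./softmax_scores/confidence_Our_In.txt',
                 outFile='./softmax_scores/confidence_Our_Out.txt'):
    from sklearn.metrics import roc_auc_score
    inScores = np.loadtxt(inFile, delimiter=',')[:, 2]
    outScores = np.loadtxt(outFile, delimiter=',')[:, 2]
    # in-distribution samples are labeled 1, out-of-distribution samples 0
    labels = np.concatenate([np.ones(len(inScores)), np.zeros(len(outScores))])
    scores = np.concatenate([inScores, outScores])
    return roc_auc_score(labels, scores)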
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
--------------------------------------------------------------------------------
/code/densenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
13 | padding=1, bias=False)
14 | self.droprate = dropRate
15 | def forward(self, x):
16 | out = self.conv1(self.relu(self.bn1(x)))
17 | if self.droprate > 0:
18 | out = F.dropout(out, p=self.droprate, training=self.training)
19 | return torch.cat([x, out], 1)
20 |
21 | class BottleneckBlock(nn.Module):
22 | def __init__(self, in_planes, out_planes, dropRate=0.0):
23 | super(BottleneckBlock, self).__init__()
24 | inter_planes = out_planes * 4
25 | self.bn1 = nn.BatchNorm2d(in_planes)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
28 | padding=0, bias=False)
29 | self.bn2 = nn.BatchNorm2d(inter_planes)
30 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
31 | padding=1, bias=False)
32 | self.droprate = dropRate
33 | def forward(self, x):
34 | out = self.conv1(self.relu(self.bn1(x)))
35 | if self.droprate > 0:
36 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
37 | out = self.conv2(self.relu(self.bn2(out)))
38 | if self.droprate > 0:
39 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
40 | return torch.cat([x, out], 1)
41 |
42 | class TransitionBlock(nn.Module):
43 | def __init__(self, in_planes, out_planes, dropRate=0.0):
44 | super(TransitionBlock, self).__init__()
45 | self.bn1 = nn.BatchNorm2d(in_planes)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
48 | padding=0, bias=False)
49 | self.droprate = dropRate
50 | def forward(self, x):
51 | out = self.conv1(self.relu(self.bn1(x)))
52 | if self.droprate > 0:
53 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
54 | return F.avg_pool2d(out, 2)
55 |
56 | class DenseBlock(nn.Module):
57 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
58 | super(DenseBlock, self).__init__()
59 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
60 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
61 | layers = []
62 | for i in range(nb_layers):
63 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))
64 | return nn.Sequential(*layers)
65 | def forward(self, x):
66 | return self.layer(x)
67 |
68 | class DenseNet3(nn.Module):
69 | def __init__(self, depth, num_classes, growth_rate=12,
70 | reduction=0.5, bottleneck=True, dropRate=0.0):
71 | super(DenseNet3, self).__init__()
72 | in_planes = 2 * growth_rate
73 | n = (depth - 4) // 3
74 | if bottleneck:
75 | n = n // 2
76 | block = BottleneckBlock
77 | else:
78 | block = BasicBlock
79 | # 1st conv before any dense block
80 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
81 | padding=1, bias=False)
82 | # 1st block
83 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
84 | in_planes = int(in_planes+n*growth_rate)
85 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
86 | in_planes = int(math.floor(in_planes*reduction))
87 | # 2nd block
88 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
89 | in_planes = int(in_planes+n*growth_rate)
90 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
91 | in_planes = int(math.floor(in_planes*reduction))
92 | # 3rd block
93 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
94 | in_planes = int(in_planes+n*growth_rate)
95 | # global average pooling and classifier
96 | self.bn1 = nn.BatchNorm2d(in_planes)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.fc = nn.Linear(in_planes, num_classes)
99 | self.in_planes = in_planes
100 |
101 | for m in self.modules():
102 | if isinstance(m, nn.Conv2d):
103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104 | m.weight.data.normal_(0, math.sqrt(2. / n))
105 | elif isinstance(m, nn.BatchNorm2d):
106 | m.weight.data.fill_(1)
107 | m.bias.data.zero_()
108 | elif isinstance(m, nn.Linear):
109 | m.bias.data.zero_()
110 | def forward(self, x):
111 | out = self.conv1(x)
112 | out = self.trans1(self.block1(out))
113 | out = self.trans2(self.block2(out))
114 | out = self.block3(out)
115 | out = self.relu(self.bn1(out))
116 | out = F.avg_pool2d(out, 8)
117 | out = out.view(-1, self.in_planes)
118 | return self.fc(out)
119 |
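# Note (a sketch, for reference): the DenseNet-BC-100 architecture referred to in
# calMetric.py would correspond to DenseNet3(depth=100, num_classes=10) for CIFAR-10
# (num_classes=100 for CIFAR-100) with the default growth_rate=12. The pre-trained
# checkpoints in this repository are loaded with torch.load and already contain the
# full model object, so this class mainly needs to be importable when unpickling.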
--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sat Sep 19 20:55:56 2015
4 |
5 | @author: liangshiyu
6 | """
7 |
8 | from __future__ import print_function
9 | import argparse
10 | import os
11 | import torch
12 | from torch.autograd import Variable
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import numpy as np
16 | import torch.optim as optim
17 | import torchvision
18 | import torchvision.transforms as transforms
19 | #import matplotlib.pyplot as plt
20 | import numpy as np
21 | import time
22 | #import lmdb
23 | from scipy import misc
24 | import cal as c
25 |
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch code for detecting out-of-distribution examples in neural networks')
28 |
29 | parser.add_argument('--nn', default="densenet10", type=str,
30 | help='neural network name and training set')
31 | parser.add_argument('--out_dataset', default="Imagenet", type=str,
32 | help='out-of-distribution dataset')
33 | parser.add_argument('--magnitude', default=0.0014, type=float,
34 | help='perturbation magnitude')
35 | parser.add_argument('--temperature', default=1000, type=int,
36 | help='temperature scaling')
37 | parser.add_argument('--gpu', default = 0, type = int,
38 | help='gpu index')
39 | parser.set_defaults(argument=True)
40 |
41 |
42 |
43 |
44 |
45 | # Setting the name of neural networks
46 |
47 | # Densenet trained on CIFAR-10: densenet10
48 | # Densenet trained on CIFAR-100: densenet100
49 | # Wide-ResNet trained on CIFAR-10: wideresnet10
50 | # Wide-ResNet trained on CIFAR-100: wideresnet100
51 | #nnName = "densenet10"
52 |
53 | # Setting the name of the out-of-distribution dataset
54 |
55 | # Tiny-ImageNet (crop): Imagenet
56 | # Tiny-ImageNet (resize): Imagenet_resize
57 | # LSUN (crop): LSUN
58 | # LSUN (resize): LSUN_resize
59 | # iSUN: iSUN
60 | # Gaussian noise: Gaussian
61 | # Uniform noise: Uniform
62 | #dataName = "Imagenet"
63 |
64 |
65 | # Setting the perturbation magnitude
66 | #epsilon = 0.0014
67 |
68 | # Setting the temperature
69 | #temperature = 1000
70 | def main():
71 | global args
72 | args = parser.parse_args()
73 | c.test(args.nn, args.out_dataset, args.gpu, args.magnitude, args.temperature)
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/code/softmax_scores/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 |
--------------------------------------------------------------------------------
/code/softmax_scores/confidence_Base_In.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/softmax_scores/confidence_Base_Out.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/softmax_scores/confidence_Our_In.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/softmax_scores/confidence_Our_Out.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/wideresnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu1 = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(out_planes)
15 | self.relu2 = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 | padding=1, bias=False)
18 | self.droprate = dropRate
19 | self.equalInOut = (in_planes == out_planes)
20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 | padding=0, bias=False) or None
22 | def forward(self, x):
23 | if not self.equalInOut:
24 | x = self.relu1(self.bn1(x))
25 | else:
26 | out = self.relu1(self.bn1(x))
27 | out = self.conv1(self.equalInOut and out or x)
28 | if self.droprate > 0:
29 | out = F.dropout(out, p=self.droprate, training=self.training)
30 | out = self.conv2(self.relu2(self.bn2(out)))
31 | return torch.add((not self.equalInOut) and self.convShortcut(x) or x, out)
32 |
33 | class NetworkBlock(nn.Module):
34 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
35 | super(NetworkBlock, self).__init__()
36 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
37 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
38 | layers = []
39 | for i in range(nb_layers):
40 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
41 | return nn.Sequential(*layers)
42 | def forward(self, x):
43 | return self.layer(x)
44 |
45 | class WideResNet(nn.Module):
46 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
47 | super(WideResNet, self).__init__()
48 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
49 | assert((depth - 4) % 6 == 0)
50 | n = (depth - 4) // 6
51 | block = BasicBlock
52 | # 1st conv before any network block
53 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
54 | padding=1, bias=False)
55 | # 1st block
56 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
57 | # 2nd block
58 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
59 | # 3rd block
60 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
61 | # global average pooling and classifier
62 | self.bn1 = nn.BatchNorm2d(nChannels[3])
63 | self.relu = nn.ReLU(inplace=True)
64 | self.fc = nn.Linear(nChannels[3], num_classes)
65 | self.nChannels = nChannels[3]
66 |
67 | for m in self.modules():
68 | if isinstance(m, nn.Conv2d):
69 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
70 | m.weight.data.normal_(0, math.sqrt(2. / n))
71 | elif isinstance(m, nn.BatchNorm2d):
72 | m.weight.data.fill_(1)
73 | m.bias.data.zero_()
74 | elif isinstance(m, nn.Linear):
75 | m.bias.data.zero_()
76 | def forward(self, x):
77 | out = self.conv1(x)
78 | out = self.block1(out)
79 | out = self.block2(out)
80 | out = self.block3(out)
81 | out = self.relu(self.bn1(out))
82 | out = F.avg_pool2d(out, 8)
83 | out = out.view(-1, self.nChannels)
84 | return self.fc(out)
85 |
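# Note (a sketch, for reference): the Wide-ResNet-28-10 architecture referred to in
# calMetric.py would correspond to WideResNet(depth=28, num_classes=10, widen_factor=10)
# for CIFAR-10 (num_classes=100 for CIFAR-100). As with densenet.py, the pre-trained
# checkpoints are loaded with torch.load, so this class mainly needs to be importable
# when unpickling.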
--------------------------------------------------------------------------------
/figures/original_optimal_shade.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiyuLiang/odin-pytorch/34e53f5a982811a0d74baba049538d34efc0732d/figures/original_optimal_shade.png
--------------------------------------------------------------------------------
/figures/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiyuLiang/odin-pytorch/34e53f5a982811a0d74baba049538d34efc0732d/figures/performance.png
--------------------------------------------------------------------------------