├── output
│   └── .gitignore
├── figure
│   └── deepmcdd.png
├── models
│   ├── __init__.py
│   ├── mlp.py
│   ├── densenet.py
│   └── resnet.py
├── README.md
├── dataloader_table.py
├── train_deepmcdd_table.py
├── dataloader_image.py
├── utils.py
└── train_deepmcdd_image.py
/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/figure/deepmcdd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donalee/DeepMCDD/HEAD/figure/deepmcdd.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import *
2 | from .resnet import *
3 | from .densenet import *
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-class Data Description (Deep-MCDD) for Out-of-distribution Detection
2 |
3 | This is the authors' code for ["Multi-class Data Description for Out-of-distribution Detection"](https://dl.acm.org/doi/abs/10.1145/3394486.3403189) (KDD 2020).
4 | Parts of the code are based on the [Deep Mahalanobis Detector](https://github.com/pokaxpoka/deep_Mahalanobis_detector).
5 |
6 | ## Overview
7 |
8 | ![An overview of Deep-MCDD](./figure/deepmcdd.png)
9 |
10 |
11 |
12 | ## Downloading tabular datasets
13 |
14 | The four multi-class tabular datasets reported in the paper can be downloaded from the links below.
15 | We provide preprocessed versions of the datasets, converted into the NumPy array format, so that users can easily load them (a minimal loading sketch follows the list).
16 | For more details on the datasets, please refer to the UCI repository.
17 | Place them in the directory `./table_data/`.
18 |
19 | - **GasSensor**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/gas_preproc.npy) [[Raw format]](https://archive.ics.uci.edu/ml/datasets/Gas+Sensor+Array+Drift+Dataset#)
20 | - **Shuttle**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/shuttle_preproc.npy) [[Raw format]](https://archive.ics.uci.edu/ml/datasets/Statlog+(Shuttle))
21 | - **DriveDiagnosis**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/drive_preproc.npy) [[Raw format]](https://archive.ics.uci.edu/ml/datasets/Dataset+for+Sensorless+Drive+Diagnosis)
22 | - **MNIST**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/mnist_preproc.npy) [[Raw format]](http://yann.lecun.com/exdb/mnist/)
23 |
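A minimal loading sketch (assuming each `*_preproc.npy` file stores the tuple `(features, labels, num_classes)` as an object array, which is how `dataloader_table.py` reads it):

```python
import numpy as np

# Load one preprocessed tabular dataset, e.g. GasSensor.
features, labels, num_classes = np.load('./table_data/gas_preproc.npy', allow_pickle=True)
print(features.shape, labels.shape, num_classes)  # GasSensor has 128 features and 6 classes
```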
24 | ## Downloading image datasets
25 |
26 | The three in-distribution datasets (i.e., **SVHN**, **CIFAR-10**, and **CIFAR-100**) are downloaded automatically via `torchvision`, as sketched below.
27 | For the two additional out-of-distribution datasets (i.e., **TinyImageNet** and **LSUN**), we use the download links provided by [Deep Mahalanobis Detector](https://github.com/pokaxpoka/deep_Mahalanobis_detector) and [ODIN Detector](https://github.com/ShiyuLiang/odin-pytorch).
28 | Place them in the directory `./image_data/`.
29 |
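For reference, the automatic download boils down to a `torchvision` call like the ones in `dataloader_image.py` (a sketch; the actual loaders also apply the training/test transforms defined in `train_deepmcdd_image.py`):

```python
from torchvision import datasets, transforms

# CIFAR-10 is fetched on first use; TinyImageNet and LSUN must be downloaded
# manually and unpacked under ./image_data/.
trainset = datasets.CIFAR10(root='./image_data/cifar10-data', train=True,
                            download=True, transform=transforms.ToTensor())
print(len(trainset))  # 50000 training images
```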
30 | ## Running the codes
31 |
32 | - python
33 | - torch (the code assumes a CUDA-enabled GPU)
34 |
35 | #### Training MLP with Deep-MCDD for tabular data
36 | ```
37 | python train_deepmcdd_table.py --dataset gas --net_type mlp --oodclass_idx 0
38 | ```
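Here `--oodclass_idx` selects the class that is held out of training and used as the OOD test set; the remaining classes are re-indexed from 0 (see `dataloader_table.py`).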
39 |
40 | #### Training CNN with Deep-MCDD for image data
41 | ```
42 | python train_deepmcdd_image.py --dataset cifar10 --net_type resnet
43 | ```
44 |
45 | ## Citation
46 | ```
47 | @inproceedings{lee2020multi,
48 | author = {Lee, Dongha and Yu, Sehun and Yu, Hwanjo},
49 | title = {Multi-Class Data Description for Out-of-Distribution Detection},
50 | year = {2020},
51 | booktitle = {Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
52 | pages = {1362–1370}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
/dataloader_table.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import DataLoader, TensorDataset
4 | import numpy as np
5 |
6 | def get_table_data(batch_size, data_dir, dataset, oodclass_idx, fold_idx, **kwargs):
7 |
8 | data_path = os.path.join(data_dir, dataset + '_preproc.npy')
9 | features, labels, num_classes = np.load(data_path, allow_pickle=True)
10 | n_data = len(labels)
11 |
12 | id_classes = [c for c in range(num_classes) if c != oodclass_idx]
13 | id_indices = {c: [i for i in range(len(labels)) if labels[i] == c] for c in id_classes}
14 |
15 | np.random.seed(0)
16 | for c in id_classes:
17 | np.random.shuffle(id_indices[c])
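    # Each class's shuffled indices are split into contiguous 20% slices;
    # fold_idx selects which slice serves as the held-out test split.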
18 | test_id_indices = {c: id_indices[c][int(0.2*fold_idx*len(id_indices[c])):int(0.2*(fold_idx+1)*len(id_indices[c]))] for c in id_classes}
19 |
20 | id_indices = np.concatenate(list(id_indices.values()))
21 | test_id_indices = np.concatenate(list(test_id_indices.values()))
22 |     train_id_indices = np.setdiff1d(id_indices, test_id_indices)
23 | ood_indices = [i for i in range(len(labels)) if labels[i] == oodclass_idx]
24 |
25 |     for i in range(len(labels)):  # relabel: the OOD class becomes -1; higher class ids shift down by one
26 | if labels[i] == oodclass_idx:
27 | labels[i] = -1
28 | elif labels[i] > oodclass_idx:
29 | labels[i] -= 1
30 |
31 | train_dataset = TensorDataset(torch.Tensor(features[train_id_indices]),
32 | torch.LongTensor(labels[train_id_indices]))
33 | test_id_dataset = TensorDataset(torch.Tensor(features[test_id_indices]),
34 | torch.LongTensor(labels[test_id_indices]))
35 | test_ood_dataset = TensorDataset(torch.Tensor(features[ood_indices]),
36 | torch.LongTensor(labels[ood_indices]))
37 |
38 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
39 | test_id_loader = DataLoader(dataset=test_id_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
40 | test_ood_loader = DataLoader(dataset=test_ood_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
41 |
42 | return train_loader, test_id_loader, test_ood_loader
43 |
44 | def get_total_table_data(batch_size, data_dir, dataset, fold_idx, **kwargs):
45 |
46 | data_path = os.path.join(data_dir, dataset + '_preproc.npy')
47 | features, labels, num_classes = np.load(data_path, allow_pickle=True)
48 | n_data = len(labels)
49 | id_indices = [i for i in range(len(labels))]
50 |
51 | np.random.seed(0)
52 | np.random.shuffle(id_indices)
53 | test_id_indices = id_indices[int(0.2*fold_idx*len(id_indices)):int(0.2*(fold_idx+1)*len(id_indices))]
54 |     train_id_indices = np.setdiff1d(id_indices, test_id_indices)
55 |
56 | test_dataset = TensorDataset(torch.Tensor(features[test_id_indices]),
57 | torch.LongTensor(labels[test_id_indices]))
58 | train_dataset = TensorDataset(torch.Tensor(features[train_id_indices]),
59 | torch.LongTensor(labels[train_id_indices]))
60 |
61 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
62 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
63 |
64 | return train_loader, test_loader
65 |
66 |
--------------------------------------------------------------------------------
/train_deepmcdd_table.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import numpy as np
4 |
5 | import models
6 | from dataloader_table import get_table_data
7 | from utils import compute_confscores, compute_metrics, print_ood_results, print_ood_results_total
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--dataset', required=True, help='gas | shuttle | drive | mnist')
11 | parser.add_argument('--net_type', required=True, help='mlp')
12 | parser.add_argument('--datadir', default='./table_data/', help='path to dataset')
13 | parser.add_argument('--outdir', default='./output/', help='folder to output results')
14 | parser.add_argument('--oodclass_idx', type=int, default=0, help='index of the OOD class')
15 | parser.add_argument('--batch_size', type=int, default=200, help='batch size for data loader')
16 | parser.add_argument('--latent_size', type=int, default=128, help='dimension size for latent representation')
17 | parser.add_argument('--num_layers', type=int, default=3, help='the number of hidden layers in MLP')
18 | parser.add_argument('--num_folds', type=int, default=5, help='the number of cross-validation folds')
19 | parser.add_argument('--num_epochs', type=int, default=10, help='the number of epochs for training sc-layers')
20 | parser.add_argument('--learning_rate', type=float, default=0.001, help='initial learning rate of Adam optimizer')
21 | parser.add_argument('--reg_lambda', type=float, default=1.0, help='regularization coefficient')
22 | parser.add_argument('--gpu', type=int, default=0, help='gpu index')
23 |
24 | args = parser.parse_args()
25 | print(args)
26 |
27 | def main():
28 | outdir = os.path.join(args.outdir, args.net_type + '_' + args.dataset)
29 |
30 |     if not os.path.isdir(outdir):
31 |         os.makedirs(outdir)
32 |
33 | torch.manual_seed(0)
34 | torch.cuda.manual_seed_all(0)
35 | torch.cuda.set_device(args.gpu)
36 |
37 | best_idacc_list, best_oodacc_list = [], []
38 | for fold_idx in range(args.num_folds):
39 |
40 | train_loader, test_id_loader, test_ood_loader = get_table_data(args.batch_size, args.datadir, args.dataset, args.oodclass_idx, fold_idx)
41 |
42 | if args.dataset == 'gas':
43 | num_classes, num_features = 6, 128
44 | elif args.dataset == 'drive':
45 | num_classes, num_features = 11, 48
46 | elif args.dataset == 'shuttle':
47 | num_classes, num_features = 7, 9
48 | elif args.dataset == 'mnist':
49 | num_classes, num_features = 10, 784
50 |
51 | model = models.MLP_DeepMCDD(num_features, args.num_layers*[args.latent_size], num_classes=num_classes-1)
52 | model.cuda()
53 |
54 | ce_loss = torch.nn.CrossEntropyLoss()
55 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
56 |
57 | idacc_list, oodacc_list = [], []
58 | total_step = len(train_loader)
59 | for epoch in range(args.num_epochs):
60 | model.train()
61 | total_loss = 0.0
62 |
63 | for i, (data, labels) in enumerate(train_loader):
64 | data, labels = data.cuda(), labels.cuda()
65 | dists = model(data)
66 |                 scores = - dists + model.alphas  # score_k = alpha_k - dist_k
67 |
68 | label_mask = torch.zeros(labels.size(0), model.num_classes).cuda().scatter_(1, labels.unsqueeze(dim=1), 1)
69 |
70 |                 pull_loss = torch.mean(torch.sum(torch.mul(label_mask, dists), dim=1))  # mean distance to the ground-truth class center
71 |                 push_loss = ce_loss(scores, labels)  # cross-entropy over scores pushes samples away from the other centers
72 | loss = args.reg_lambda * pull_loss + push_loss
73 |
74 | optimizer.zero_grad()
75 | loss.backward()
76 | optimizer.step()
77 |
78 | total_loss += loss.item()
79 |
80 | model.eval()
81 | with torch.no_grad():
82 | # (1) evaluate ID classification
83 | correct, total = 0, 0
84 | for data, labels in test_id_loader:
85 | data, labels = data.cuda(), labels.cuda()
86 | scores = - model(data) + model.alphas
87 | _, predicted = torch.max(scores, 1)
88 | total += labels.size(0)
89 | correct += (predicted == labels).sum().item()
90 | idacc_list.append(100 * correct / total)
91 |
92 | # (2) evaluate OOD detection
93 | compute_confscores(model, test_id_loader, outdir, True)
94 | compute_confscores(model, test_ood_loader, outdir, False)
95 | oodacc_list.append(compute_metrics(outdir))
96 |
97 | best_idacc = max(idacc_list)
98 | best_oodacc = oodacc_list[idacc_list.index(best_idacc)]
99 |
100 | print('== {fidx:1d}-th fold results =='.format(fidx=fold_idx+1))
101 | print('The best ID accuracy on "{idset:s}" test samples : {val:6.2f}'.format(idset=args.dataset, val=best_idacc))
102 | print('The best OOD accuracy on "{oodset:s}" test samples :'.format(oodset=args.dataset+'_'+str(args.oodclass_idx)))
103 | print_ood_results(best_oodacc)
104 |
105 | best_idacc_list.append(best_idacc)
106 | best_oodacc_list.append(best_oodacc)
107 |
108 | print('== Final results ==')
109 | print('The best ID accuracy on "{idset:s}" test samples : {mean:6.2f} ({std:6.3f})'.format(idset=args.dataset, mean=np.mean(best_idacc_list), std=np.std(best_idacc_list)))
110 | print('The best OOD accuracy on "{oodset:s}" test samples :'.format(oodset='class_'+str(args.oodclass_idx)))
111 | print_ood_results_total(best_oodacc_list)
112 |
113 |
114 | if __name__ == '__main__':
115 | main()
116 |
--------------------------------------------------------------------------------
/dataloader_image.py:
--------------------------------------------------------------------------------
1 | #
2 | # This file is implemented based on the author code of
3 | # Lee et al., "A simple unified framework for detecting out-of-distribution samples and adversarial attacks", in NeurIPS 2018.
4 | #
5 |
6 | import os
7 | import torch
8 | from torchvision import datasets
9 | from torch.utils.data import DataLoader
10 |
11 | def get_svhn(batch_size, train_TF, test_TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
12 | data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
13 | num_workers = kwargs.setdefault('num_workers', 1)
14 | kwargs.pop('input_size', None)
15 |
16 | ds = []
17 | if train:
18 | train_loader = torch.utils.data.DataLoader(
19 | datasets.SVHN(
20 | root=data_root, split='train', download=True,
21 | transform=train_TF,
22 | ),
23 | batch_size=batch_size, shuffle=True, **kwargs)
24 | ds.append(train_loader)
25 |
26 | if val:
27 | test_loader = torch.utils.data.DataLoader(
28 | datasets.SVHN(
29 | root=data_root, split='test', download=True,
30 | transform=test_TF,
31 | ),
32 | batch_size=batch_size, shuffle=False, **kwargs)
33 | ds.append(test_loader)
34 | ds = ds[0] if len(ds) == 1 else ds
35 | return ds
36 |
37 | def get_cifar10(batch_size, train_TF, test_TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
38 | data_root = os.path.expanduser(os.path.join(data_root, 'cifar10-data'))
39 | num_workers = kwargs.setdefault('num_workers', 1)
40 | kwargs.pop('input_size', None)
41 | ds = []
42 |
43 | if train:
44 | train_loader = torch.utils.data.DataLoader(
45 | datasets.CIFAR10(
46 | root=data_root, train=True, download=True,
47 | transform=train_TF),
48 | batch_size=batch_size, shuffle=True, **kwargs)
49 | ds.append(train_loader)
50 |
51 | if val:
52 | test_loader = torch.utils.data.DataLoader(
53 | datasets.CIFAR10(
54 | root=data_root, train=False, download=True,
55 | transform=test_TF),
56 | batch_size=batch_size, shuffle=False, **kwargs)
57 | ds.append(test_loader)
58 |
59 | ds = ds[0] if len(ds) == 1 else ds
60 | return ds
61 |
62 | def get_cifar100(batch_size, train_TF, test_TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
63 | data_root = os.path.expanduser(os.path.join(data_root, 'cifar100-data'))
64 | num_workers = kwargs.setdefault('num_workers', 1)
65 | kwargs.pop('input_size', None)
66 | ds = []
67 |
68 | if train:
69 | train_loader = torch.utils.data.DataLoader(
70 | datasets.CIFAR100(
71 | root=data_root, train=True, download=True,
72 | transform=train_TF),
73 | batch_size=batch_size, shuffle=True, **kwargs)
74 | ds.append(train_loader)
75 |
76 | if val:
77 | test_loader = torch.utils.data.DataLoader(
78 | datasets.CIFAR100(
79 | root=data_root, train=False, download=True,
80 | transform=test_TF),
81 | batch_size=batch_size, shuffle=False, **kwargs)
82 | ds.append(test_loader)
83 |
84 | ds = ds[0] if len(ds) == 1 else ds
85 | return ds
86 |
87 | def get_id_image_data(data_type, batch_size, train_TF, test_TF, dataroot):
88 | if data_type == 'cifar10':
89 | train_loader, test_loader = get_cifar10(batch_size=batch_size, train_TF=train_TF, test_TF=test_TF, data_root=dataroot, num_workers=1)
90 | elif data_type == 'cifar100':
91 | train_loader, test_loader = get_cifar100(batch_size=batch_size, train_TF=train_TF, test_TF=test_TF, data_root=dataroot, num_workers=1)
92 | elif data_type == 'svhn':
93 | train_loader, test_loader = get_svhn(batch_size=batch_size, train_TF=train_TF, test_TF=test_TF, data_root=dataroot, num_workers=1)
94 |
95 | return train_loader, test_loader
96 |
97 | def get_ood_image_data(data_type, batch_size, input_TF, dataroot):
98 | if data_type == 'cifar10':
99 | _, test_loader = get_cifar10(batch_size=batch_size, train_TF=input_TF, test_TF=input_TF, data_root=dataroot, num_workers=1)
100 | elif data_type == 'svhn':
101 | _, test_loader = get_svhn(batch_size=batch_size, train_TF=input_TF, test_TF=input_TF, data_root=dataroot, num_workers=1)
102 | elif data_type == 'cifar100':
103 | _, test_loader = get_cifar100(batch_size=batch_size, train_TF=input_TF, test_TF=input_TF, data_root=dataroot, num_workers=1)
104 | elif data_type == 'imagenet_crop':
105 | dataroot = os.path.expanduser(os.path.join(dataroot, 'Imagenet_crop'))
106 | testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
107 | test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
108 | elif data_type == 'lsun_crop':
109 | dataroot = os.path.expanduser(os.path.join(dataroot, 'LSUN_crop'))
110 | testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
111 | test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
112 | elif data_type == 'imagenet_resize':
113 | dataroot = os.path.expanduser(os.path.join(dataroot, 'Imagenet_resize'))
114 | testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
115 | test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
116 | elif data_type == 'lsun_resize':
117 | dataroot = os.path.expanduser(os.path.join(dataroot, 'LSUN_resize'))
118 | testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
119 | test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
120 |
121 | return test_loader
122 |
123 |
124 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # This file is implemented based on the author code of
3 | # Lee et al., "A simple unified framework for detecting out-of-distribution samples and adversarial attacks", in NeurIPS 2018.
4 | #
5 |
6 | import os
7 | import torch
8 | import numpy as np
9 |
10 | def compute_confscores(model, test_loader, outdir, id_flag):
11 | total = 0
12 | if id_flag == True:
13 | outfile = os.path.join(outdir, 'confscores_id.txt')
14 | else:
15 | outfile = os.path.join(outdir, 'confscores_ood.txt')
16 |
17 | f = open(outfile, 'w')
18 |
19 | for data, _ in test_loader:
20 | dists = model(data.cuda())
21 | confscores, _ = torch.min(dists, dim=1)
22 | total += data.size(0)
23 |
24 | for i in range(data.size(0)):
25 |             f.write("{}\n".format(-confscores[i].item()))
26 |
27 | f.close()
28 |
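# The confidence score written above is -min_k dist_k(x): an in-distribution sample
# should lie close to at least one class center (small minimum distance, high
# confidence), while an OOD sample is far from every center (low confidence).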
29 | def get_auroc_curve(indir):
30 |     known = np.loadtxt(os.path.join(indir, 'confscores_id.txt'))
31 |     novel = np.loadtxt(os.path.join(indir, 'confscores_ood.txt'))
32 | known.sort()
33 | novel.sort()
34 |
35 | end = np.max([np.max(known), np.max(novel)])
36 | start = np.min([np.min(known),np.min(novel)])
37 |
38 | num_k = known.shape[0]
39 | num_n = novel.shape[0]
40 |
41 | tp = -np.ones([num_k+num_n+1], dtype=int)
42 | fp = -np.ones([num_k+num_n+1], dtype=int)
43 | tp[0], fp[0] = num_k, num_n
44 | k, n = 0, 0
45 | for l in range(num_k+num_n):
46 | if k == num_k:
47 | tp[l+1:] = tp[l]
48 | fp[l+1:] = np.arange(fp[l]-1, -1, -1)
49 | break
50 | elif n == num_n:
51 | tp[l+1:] = np.arange(tp[l]-1, -1, -1)
52 | fp[l+1:] = fp[l]
53 | break
54 | else:
55 | if novel[n] < known[k]:
56 | n += 1
57 | tp[l+1] = tp[l]
58 | fp[l+1] = fp[l] - 1
59 | else:
60 | k += 1
61 | tp[l+1] = tp[l] - 1
62 | fp[l+1] = fp[l]
63 | tpr85_pos = np.abs(tp / num_k - .85).argmin()
64 | tpr95_pos = np.abs(tp / num_k - .95).argmin()
65 | tnr_at_tpr85 = 1. - fp[tpr85_pos] / num_n
66 | tnr_at_tpr95 = 1. - fp[tpr95_pos] / num_n
67 | return tp, fp, tnr_at_tpr85, tnr_at_tpr95
68 |
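# In get_auroc_curve above, tp[l] and fp[l] count the ID and OOD samples whose
# confidence exceeds the l-th smallest observed score, so (tp/num_k, fp/num_n)
# traces the ROC curve as the decision threshold sweeps over all scores.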
69 | def compute_metrics(dir_name, verbose=False):
70 | tp, fp, tnr_at_tpr85, tnr_at_tpr95 = get_auroc_curve(dir_name)
71 | results = dict()
72 | mtypes = ['TNR85', 'TNR95', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
73 | if verbose:
74 | print(' ', end='')
75 | for mtype in mtypes:
76 | print(' {mtype:6s}'.format(mtype=mtype), end='')
77 | print('')
78 |
79 |     if verbose:
80 |         print('{stype:5s} '.format(stype=''), end='')
81 | 
82 |
83 | # TNR85
84 | mtype = 'TNR85'
85 | results[mtype] = tnr_at_tpr85
86 | if verbose:
87 | print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
88 |
89 | # TNR95
90 | mtype = 'TNR95'
91 | results[mtype] = tnr_at_tpr95
92 | if verbose:
93 | print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
94 |
95 | # AUROC
96 | mtype = 'AUROC'
97 | tpr = np.concatenate([[1.], tp/tp[0], [0.]])
98 | fpr = np.concatenate([[1.], fp/fp[0], [0.]])
99 | results[mtype] = -np.trapz(1. - fpr, tpr)
100 | if verbose:
101 | print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
102 |
103 | # DTACC
104 | mtype = 'DTACC'
105 | results[mtype] = .5 * (tp/tp[0] + 1. - fp/fp[0]).max()
106 | if verbose:
107 | print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
108 |
109 | # AUIN
110 | mtype = 'AUIN'
111 | denom = tp + fp
112 | denom[denom == 0.] = -1.
113 | pin_ind = np.concatenate([[True], denom > 0., [True]])
114 | pin = np.concatenate([[.5], tp/denom, [0.]])
115 | results[mtype] = -np.trapz(pin[pin_ind], tpr[pin_ind])
116 | if verbose:
117 | print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
118 |
119 | # AUOUT
120 | mtype = 'AUOUT'
121 | denom = tp[0] - tp + fp[0] - fp
122 | denom[denom == 0.] = -1.
123 | pout_ind = np.concatenate([[True], denom > 0., [True]])
124 | pout = np.concatenate([[0.], (fp[0] - fp)/denom, [.5]])
125 | results[mtype] = np.trapz(pout[pout_ind], 1. - fpr[pout_ind])
126 | if verbose:
127 | print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
128 | print('')
129 |
130 | return results
131 |
132 | def print_ood_results(ood_result):
133 |
134 | for mtype in ['TNR85', 'TNR95', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']:
135 | print(' {mtype:6s}'.format(mtype=mtype), end='')
136 | print('\n{val:6.2f}'.format(val=100.*ood_result['TNR85']), end='')
137 | print(' {val:6.2f}'.format(val=100.*ood_result['TNR95']), end='')
138 | print(' {val:6.2f}'.format(val=100.*ood_result['AUROC']), end='')
139 | print(' {val:6.2f}'.format(val=100.*ood_result['DTACC']), end='')
140 | print(' {val:6.2f}'.format(val=100.*ood_result['AUIN']), end='')
141 | print(' {val:6.2f}\n'.format(val=100.*ood_result['AUOUT']), end='')
142 | print('')
143 |
144 | def print_ood_results_total(ood_result_list):
145 |
146 | TNR85_list = [100.*ood_result['TNR85'] for ood_result in ood_result_list]
147 | TNR95_list = [100.*ood_result['TNR95'] for ood_result in ood_result_list]
148 | AUROC_list = [100.*ood_result['AUROC'] for ood_result in ood_result_list]
149 | DTACC_list = [100.*ood_result['DTACC'] for ood_result in ood_result_list]
150 | AUIN_list = [100.*ood_result['AUIN'] for ood_result in ood_result_list]
151 | AUOUT_list = [100.*ood_result['AUOUT'] for ood_result in ood_result_list]
152 |
153 | for mtype in ['TNR85', 'TNR95', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']:
154 | print(' {mtype:15s}'.format(mtype=mtype), end='')
155 | print('\n{mean:6.2f} ({std:6.3f})'.format(mean=np.mean(TNR85_list), std=np.std(TNR85_list)), end='')
156 | print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(TNR95_list), std=np.std(TNR95_list)), end='')
157 | print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(AUROC_list), std=np.std(AUROC_list)), end='')
158 | print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(DTACC_list), std=np.std(DTACC_list)), end='')
159 | print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(AUIN_list), std=np.std(AUIN_list)), end='')
160 | print(' {mean:6.2f} ({std:6.3f})\n'.format(mean=np.mean(AUOUT_list), std=np.std(AUOUT_list)), end='')
161 | print('')
162 |
163 |
--------------------------------------------------------------------------------
/train_deepmcdd_image.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import numpy as np
4 | from torchvision import transforms
5 |
6 | import models
7 | from dataloader_image import get_id_image_data, get_ood_image_data
8 | from utils import compute_confscores, compute_metrics, print_ood_results
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--dataset', required=True, help='svhn | cifar10 | cifar100')
12 | parser.add_argument('--net_type', required=True, help='resnet | densenet')
13 | parser.add_argument('--datadir', default='./image_data/', help='path to dataset')
14 | parser.add_argument('--outdir', default='./output/', help='folder to output results')
15 | parser.add_argument('--modeldir', default='./trained_model/', help='folder to trained model')
16 | parser.add_argument('--batch_size', type=int, default=128, help='batch size for data loader')
17 | parser.add_argument('--num_epochs', type=int, default=200, help='the number of epochs for training sc-layers')
18 | parser.add_argument('--learning_rate', type=float, default=0.1, help='initial learning rate of SGD optimizer')
19 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum factor for nesterov momentum of SGD optimizer')
20 | parser.add_argument('--reg_lambda', type=float, default=0.1, help='regularization coefficient')
21 | parser.add_argument('--pretrained', action='store_true', help='initialize the network with pretrained weights instead of random weights')
22 | parser.add_argument('--gpu', type=int, default=0, help='gpu index')
23 |
24 | args = parser.parse_args()
25 | print(args)
26 |
27 | def main():
28 | # set the path to pre-trained model and output
29 | outdir = os.path.join(args.outdir, args.net_type + '_' + args.dataset)
30 | pretrained_path = os.path.join('./pretrained/', args.net_type + '_' + args.dataset + '.pth')
31 | model_path = os.path.join(args.modeldir, args.net_type + '_' + args.dataset + '.pth')
32 |
33 |     if not os.path.isdir(outdir):
34 |         os.makedirs(outdir)
35 |     if not os.path.isdir(args.modeldir):
36 |         os.makedirs(args.modeldir)
37 |
38 |     torch.manual_seed(0)
39 | torch.cuda.manual_seed_all(0)
40 | torch.cuda.set_device(args.gpu)
41 |
42 | if args.dataset == 'svhn':
43 | num_classes = 10
44 | ood_list = ['cifar10', 'imagenet_crop', 'lsun_crop']
45 | elif args.dataset == 'cifar10':
46 | num_classes = 10
47 | ood_list = ['svhn', 'imagenet_crop', 'lsun_crop']
48 | elif args.dataset == 'cifar100':
49 | num_classes = 100
50 | ood_list = ['svhn', 'imagenet_crop', 'lsun_crop']
51 |
52 | if args.net_type == 'densenet':
53 | model = models.DenseNet_DeepMCDD(num_classes=num_classes)
54 |         if args.pretrained:
55 | model.load_fe_weights(torch.load(pretrained_path, map_location = "cuda:" + str(args.gpu)))
56 | in_transform_train = transforms.Compose([
57 | transforms.RandomCrop(32, padding=4),
58 | transforms.RandomHorizontalFlip(),
59 | transforms.ToTensor(),
60 | transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0)),
61 | ])
62 | in_transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0)),])
63 |
64 | elif args.net_type == 'resnet':
65 | model = models.ResNet_DeepMCDD(num_classes=num_classes)
66 |         if args.pretrained:
67 | model.load_fe_weights(torch.load(pretrained_path, map_location = "cuda:" + str(args.gpu)))
68 | in_transform_train = transforms.Compose([
69 | transforms.RandomCrop(32, padding=4),
70 | transforms.RandomHorizontalFlip(),
71 | transforms.ToTensor(),
72 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
73 | ])
74 | in_transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
75 |
76 | model.cuda()
77 | optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=1e-4, nesterov=True)
78 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.num_epochs*0.5), int(args.num_epochs*0.75)], gamma=0.1)
79 |
80 | train_loader, test_id_loader = get_id_image_data(args.dataset, args.batch_size, in_transform_train, in_transform_test, args.datadir)
81 | ce_loss = torch.nn.CrossEntropyLoss()
82 |
83 | for epoch in range(args.num_epochs):
84 | model.train()
85 | total_loss = 0.0
86 |
87 | for i, (images, labels) in enumerate(train_loader):
88 | images, labels = images.cuda(), labels.cuda()
89 | dists = model(images)
90 | scores = - dists + model.alphas
91 |
92 | label_mask = torch.zeros(labels.size(0), num_classes).cuda().scatter_(1, labels.unsqueeze(dim=1), 1)
93 |
94 | pull_loss = torch.mean(torch.sum(torch.mul(label_mask, dists), dim=1))
95 | push_loss = ce_loss(scores, labels)
96 | loss = args.reg_lambda * pull_loss + push_loss
97 |
98 | optimizer.zero_grad()
99 | loss.backward()
100 | optimizer.step()
101 |
102 | total_loss += loss.item()
103 |
104 | scheduler.step()
105 |
106 | model.eval()
107 | with torch.no_grad():
108 | # (1) evaluate ID classification
109 | correct, total = 0, 0
110 | for images, labels in test_id_loader:
111 | images, labels = images.cuda(), labels.cuda()
112 | scores = - model(images) + model.alphas
113 | _, predicted = torch.max(scores, 1)
114 | total += labels.size(0)
115 | correct += (predicted == labels).sum().item()
116 | idacc = 100 * correct / total
117 |
118 | ood_results_list = []
119 | compute_confscores(model, test_id_loader, outdir, True)
120 |
121 | for ood in ood_list:
122 | test_ood_loader = get_ood_image_data(ood, args.batch_size, in_transform_test, args.datadir)
123 | compute_confscores(model, test_ood_loader, outdir, False)
124 | ood_results_list.append(compute_metrics(outdir))
125 |
126 |         print('== Epoch [{}/{}], Loss {:.4f} =='.format(epoch+1, args.num_epochs, total_loss))
127 | print('ID Accuracy on "{idset:s}" test images : {val:6.2f}\n'.format(idset=args.dataset, val=idacc))
128 | for ood_idx, ood_results in enumerate(ood_results_list):
129 | print('OOD accuracy on "{oodset:s}" test samples :'.format(oodset=ood_list[ood_idx]))
130 | print_ood_results(ood_results)
131 |
132 | torch.save(model.state_dict(), model_path)
133 |
134 | if __name__ == '__main__':
135 | main()
136 |
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | # MLP with a Softmax classifier
9 |
10 | class MLP(nn.Module):
11 | def __init__(self, input_size, hidden_sizes, num_classes):
12 | super(MLP, self).__init__()
13 | self.input_size = input_size
14 | self.hidden_sizes = hidden_sizes[:-1]
15 | self.latent_size = hidden_sizes[-1]
16 | self.num_classes = num_classes
17 |
18 | self.build_fe()
19 | self.init_fe_weights()
20 |
21 | def build_fe(self):
22 | layers = []
23 | layer_sizes = [self.input_size] + self.hidden_sizes
24 |
25 | for i in range(len(layer_sizes)-1):
26 | layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
27 | layers.append(nn.ReLU())
28 | layers.append(nn.Linear(layer_sizes[-1], self.latent_size))
29 | layers.append(nn.Linear(self.latent_size, self.num_classes))
30 | self.layers = nn.ModuleList(layers)
31 |
32 | def init_fe_weights(self):
33 | for m in self.modules():
34 | if isinstance(m, nn.Linear):
35 | nn.init.xavier_uniform_(m.weight)
36 | nn.init.zeros_(m.bias)
37 |
38 | def _forward(self, x):
39 | for i, layer in enumerate(self.layers[:-1]):
40 | x = layer(x)
41 | return x
42 |
43 | def forward(self, x):
44 | out = self._forward(x)
45 | score = self.layers[-1](out)
46 | return score
47 |
48 | def feature_list(self, x):
49 | out = self._forward(x)
50 | score = self.layers[-1](out)
51 | return score, [out]
52 |
53 |     def intermediate_forward(self, x, layer_index):  # layer_index is unused; the MLP exposes a single feature layer
54 | out = self._forward(x)
55 | return out
56 |
57 |
58 | # MLP with a Deep-MCDD classifier
59 |
60 | class MLP_DeepMCDD(nn.Module):
61 | def __init__(self, input_size, hidden_sizes, num_classes):
62 | super(MLP_DeepMCDD, self).__init__()
63 | self.input_size = input_size
64 | self.hidden_sizes = hidden_sizes[:-1]
65 | self.latent_size = hidden_sizes[-1]
66 | self.num_classes = num_classes
67 |
68 | self.centers = torch.nn.Parameter(torch.zeros([num_classes, self.latent_size]), requires_grad=True)
69 | self.alphas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)
70 | self.logsigmas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)
71 |
72 | self.build_fe()
73 | self.init_fe_weights()
74 |
75 | def build_fe(self):
76 | layers = []
77 | layer_sizes = [self.input_size] + self.hidden_sizes
78 |
79 | for i in range(len(layer_sizes)-1):
80 | layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
81 | layers.append(nn.ReLU())
82 | layers.append(nn.Linear(layer_sizes[-1], self.latent_size))
83 | self.layers = nn.ModuleList(layers)
84 |
85 | def init_fe_weights(self):
86 | for m in self.modules():
87 | if isinstance(m, nn.Linear):
88 | nn.init.xavier_uniform_(m.weight)
89 | nn.init.zeros_(m.bias)
90 | nn.init.xavier_uniform_(self.centers)
91 | nn.init.zeros_(self.alphas)
92 | nn.init.zeros_(self.logsigmas)
93 |
94 | def _forward(self, x):
95 | for i, layer in enumerate(self.layers):
96 | x = layer(x)
97 | return x
98 |
99 | def forward(self, x):
100 | out = self._forward(x)
101 | out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
102 | scores = torch.sum((out - self.centers)**2, dim=2) / 2 / torch.exp(2 * F.relu(self.logsigmas)) + self.latent_size * F.relu(self.logsigmas)
103 | return scores
104 |
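# With sigma_k = exp(relu(logsigmas[k])) >= 1, the score returned by forward() is
#   score_k(x) = ||f(x) - c_k||^2 / (2 * sigma_k^2) + d * log(sigma_k),
# which is, up to an additive constant, the negative log-density of an isotropic
# Gaussian centered at c_k with scale sigma_k (d = latent dimension).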
105 |
106 | # MLP with a Soft-MCDD classifier
107 |
108 | class MLP_SoftMCDD(nn.Module):
109 | def __init__(self, input_size, hidden_sizes, num_classes, epsilon=0.1):
110 | super(MLP_SoftMCDD, self).__init__()
111 | self.input_size = input_size
112 | self.hidden_sizes = hidden_sizes[:-1]
113 | self.latent_size = hidden_sizes[-1]
114 | self.num_classes = num_classes
115 | self.epsilon = epsilon
116 |
117 | self.centers = torch.zeros([num_classes, self.latent_size]).cuda()
118 | self.radii = torch.ones(num_classes).cuda()
119 |
120 | self.build_fe()
121 | self.init_fe_weights()
122 |
123 | def build_fe(self):
124 | layers = []
125 | layer_sizes = [self.input_size] + self.hidden_sizes
126 |
127 | for i in range(len(layer_sizes)-1):
128 | layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
129 | layers.append(nn.ReLU())
130 | layers.append(nn.Linear(layer_sizes[-1], self.latent_size))
131 | self.layers = nn.ModuleList(layers)
132 |
133 | def init_fe_weights(self):
134 | for m in self.modules():
135 | if isinstance(m, nn.Linear):
136 | nn.init.xavier_uniform_(m.weight)
137 | nn.init.zeros_(m.bias)
138 | nn.init.xavier_uniform_(self.centers)
139 |
140 | def init_centers(self):
141 | nn.init.xavier_uniform_(self.centers)
142 |         self.centers = 10 * self.centers / torch.norm(self.centers, dim=1, keepdim=True)
143 |
144 | def update_centers(self, data_loader):
145 | class_outputs = {i : [] for i in range(self.num_classes)}
146 | centers = torch.zeros([self.num_classes, self.latent_size]).cuda()
147 |
148 | with torch.no_grad():
149 | for data in data_loader:
150 | inputs, labels = data
151 | inputs, labels = inputs.cuda(), labels.cuda()
152 | outputs = self._forward(inputs)
153 |
154 | for k in range(self.num_classes):
155 | indices = (labels == k).nonzero().squeeze(dim=1)
156 | class_outputs[k].append(outputs[indices])
157 |
158 | for k in range(self.num_classes):
159 | class_outputs[k] = torch.cat(class_outputs[k], dim=0)
160 | centers[k] = torch.mean(class_outputs[k], dim=0)
161 |
162 | self.centers.data = centers
163 |
164 | def update_radii(self, data_loader):
165 | class_scores = {i : [] for i in range(self.num_classes)}
166 | radii = np.zeros(self.num_classes)
167 |
168 | with torch.no_grad():
169 | for data in data_loader:
170 | inputs, labels = data
171 | inputs, labels = inputs.cuda(), labels.cuda()
172 | scores = self.forward(inputs)
173 |
174 | for k in range(self.num_classes):
175 | indices = (labels == k).nonzero().squeeze(dim=1)
176 | class_scores[k].append(torch.sqrt(scores[indices][:, k]))
177 |
178 | for k in range(self.num_classes):
179 | class_scores[k] = torch.cat(class_scores[k], dim=0)
180 | radii[k] = np.quantile(class_scores[k].cpu().numpy(), 1 - self.epsilon)
181 |
182 | self.radii = torch.Tensor(radii).cuda()
183 |
184 | def update_centers_and_radii(self, data_loader):
185 | class_outputs = {i : [] for i in range(self.num_classes)}
186 | class_scores = {i : [] for i in range(self.num_classes)}
187 | centers = torch.zeros([self.num_classes, self.latent_size]).cuda()
188 | radii = np.zeros(self.num_classes)
189 |
190 | with torch.no_grad():
191 | for data in data_loader:
192 | inputs, labels = data
193 | inputs, labels = inputs.cuda(), labels.cuda()
194 | outputs, scores = self._forward(inputs), self.forward(inputs)
195 |
196 | for k in range(self.num_classes):
197 | indices = (labels == k).nonzero().squeeze(dim=1)
198 | class_outputs[k].append(outputs[indices])
199 | class_scores[k].append(torch.sqrt(scores[indices][:, k]))
200 |
201 | for k in range(self.num_classes):
202 | class_outputs[k] = torch.cat(class_outputs[k], dim=0)
203 | class_scores[k] = torch.cat(class_scores[k], dim=0)
204 | centers[k] = torch.mean(class_outputs[k], dim=0)
205 | radii[k] = np.quantile(class_scores[k].cpu().numpy(), 1 - self.epsilon)
206 |
207 | self.centers.data = centers
208 | self.radii = torch.Tensor(radii).cuda()
209 |
210 | def _forward(self, x):
211 | for i, layer in enumerate(self.layers):
212 | x = layer(x)
213 | return x
214 |
215 | def forward(self, x):
216 | out = self._forward(x)
217 | out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
218 | scores = torch.sum((out - self.centers)**2, dim=2)
219 | return scores
220 |
221 |
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
13 | padding=1, bias=False)
14 | self.droprate = dropRate
15 | def forward(self, x):
16 | out = self.conv1(self.relu(self.bn1(x)))
17 | if self.droprate > 0:
18 | out = F.dropout(out, p=self.droprate, training=self.training)
19 | return torch.cat([x, out], 1)
20 |
21 | class BottleneckBlock(nn.Module):
22 | def __init__(self, in_planes, out_planes, dropRate=0.0):
23 | super(BottleneckBlock, self).__init__()
24 | inter_planes = out_planes * 4
25 | self.bn1 = nn.BatchNorm2d(in_planes)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
28 | padding=0, bias=False)
29 | self.bn2 = nn.BatchNorm2d(inter_planes)
30 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
31 | padding=1, bias=False)
32 | self.droprate = dropRate
33 | def forward(self, x):
34 | out = self.conv1(self.relu(self.bn1(x)))
35 | if self.droprate > 0:
36 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
37 | out = self.conv2(self.relu(self.bn2(out)))
38 | if self.droprate > 0:
39 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
40 | return torch.cat([x, out], 1)
41 |
42 | class TransitionBlock(nn.Module):
43 | def __init__(self, in_planes, out_planes, dropRate=0.0):
44 | super(TransitionBlock, self).__init__()
45 | self.bn1 = nn.BatchNorm2d(in_planes)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
48 | padding=0, bias=False)
49 | self.droprate = dropRate
50 | def forward(self, x):
51 | out = self.conv1(self.relu(self.bn1(x)))
52 | if self.droprate > 0:
53 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
54 | return F.avg_pool2d(out, 2)
55 |
56 | class DenseBlock(nn.Module):
57 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
58 | super(DenseBlock, self).__init__()
59 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
60 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
61 | layers = []
62 | for i in range(int(nb_layers)):
63 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))
64 | return nn.Sequential(*layers)
65 | def forward(self, x):
66 | return self.layer(x)
67 |
68 | class DenseNet3(nn.Module):
69 | def __init__(self, depth, num_classes, growth_rate=12,
70 | reduction=0.5, bottleneck=True, dropRate=0.0):
71 | super(DenseNet3, self).__init__()
72 | in_planes = 2 * growth_rate
73 |         n = (depth - 4) // 3
74 |         if bottleneck:
75 |             n = n // 2
76 | block = BottleneckBlock
77 | else:
78 | block = BasicBlock
79 | # 1st conv before any dense block
80 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
81 | padding=1, bias=False)
82 | # 1st block
83 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
84 | in_planes = int(in_planes+n*growth_rate)
85 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
86 | in_planes = int(math.floor(in_planes*reduction))
87 | # 2nd block
88 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
89 | in_planes = int(in_planes+n*growth_rate)
90 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
91 | in_planes = int(math.floor(in_planes*reduction))
92 | # 3rd block
93 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
94 | in_planes = int(in_planes+n*growth_rate)
95 | # global average pooling and classifier
96 | self.bn1 = nn.BatchNorm2d(in_planes)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.fc = nn.Linear(in_planes, num_classes)
99 | self.in_planes = in_planes
100 |
101 | for m in self.modules():
102 | if isinstance(m, nn.Conv2d):
103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104 | m.weight.data.normal_(0, math.sqrt(2. / n))
105 | elif isinstance(m, nn.BatchNorm2d):
106 | m.weight.data.fill_(1)
107 | m.bias.data.zero_()
108 | elif isinstance(m, nn.Linear):
109 | m.bias.data.zero_()
110 |
111 | def forward(self, x):
112 | out = self.conv1(x)
113 | out = self.trans1(self.block1(out))
114 | out = self.trans2(self.block2(out))
115 | out = self.block3(out)
116 | out = self.relu(self.bn1(out))
117 | out = F.avg_pool2d(out, 8)
118 | out = out.view(-1, self.in_planes)
119 | return self.fc(out)
120 |
121 |     # function to extract the multiple features
122 | def feature_list(self, x):
123 | out_list = []
124 | out = self.conv1(x)
125 | out_list.append(out)
126 | out = self.trans1(self.block1(out))
127 | out_list.append(out)
128 | out = self.trans2(self.block2(out))
129 | out_list.append(out)
130 | out = self.block3(out)
131 | out = self.relu(self.bn1(out))
132 | out_list.append(out)
133 | out = F.avg_pool2d(out, 8)
134 | out = out.view(-1, self.in_planes)
135 |
136 | return self.fc(out), out_list
137 |
138 | def intermediate_forward(self, x, layer_index):
139 | out = self.conv1(x)
140 | if layer_index == 1:
141 | out = self.trans1(self.block1(out))
142 | elif layer_index == 2:
143 | out = self.trans1(self.block1(out))
144 | out = self.trans2(self.block2(out))
145 | elif layer_index == 3:
146 | out = self.trans1(self.block1(out))
147 | out = self.trans2(self.block2(out))
148 | out = self.block3(out)
149 | out = self.relu(self.bn1(out))
150 | return out
151 |
152 |     # function to extract the penultimate features
153 | def penultimate_forward(self, x):
154 | out = self.conv1(x)
155 | out = self.trans1(self.block1(out))
156 | out = self.trans2(self.block2(out))
157 | out = self.block3(out)
158 | penultimate = self.relu(self.bn1(out))
159 | out = F.avg_pool2d(penultimate, 8)
160 | out = out.view(-1, self.in_planes)
161 | return self.fc(out), penultimate
162 |
163 | class DenseNet_DeepMCDD(nn.Module):
164 | def __init__(self, latent_size=342, num_classes=10):
165 | super(DenseNet_DeepMCDD, self).__init__()
166 | self.latent_size = latent_size
167 | self.num_classes = num_classes
168 |
169 | self.centers = torch.nn.Parameter(torch.zeros([num_classes, latent_size]), requires_grad=True)
170 | self.alphas = torch.nn.Parameter(torch.ones(num_classes), requires_grad=True)
171 | self.logsigmas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)
172 | self.radii = torch.zeros(num_classes).cuda()
173 |
174 | self.fe = DenseNet3(100, num_classes)
175 | self.init_weights()
176 |
177 | def load_fe_weights(self, weights):
178 | self.fe.load_state_dict(weights)
179 |
180 | def load_weights(self, weights):
181 | self.load_state_dict(weights)
182 |
183 | def init_weights(self):
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv2d):
186 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
187 | m.weight.data.normal_(0, math.sqrt(2. / n))
188 | elif isinstance(m, nn.BatchNorm2d):
189 | m.weight.data.fill_(1)
190 | m.bias.data.zero_()
191 | elif isinstance(m, nn.Linear):
192 | m.bias.data.zero_()
193 | nn.init.xavier_uniform_(self.centers)
194 | nn.init.zeros_(self.alphas)
195 | nn.init.zeros_(self.logsigmas)
196 |
197 | def _forward(self, x):
198 | out = self.fe.conv1(x)
199 | out = self.fe.trans1(self.fe.block1(out))
200 | out = self.fe.trans2(self.fe.block2(out))
201 | out = self.fe.block3(out)
202 | out = self.fe.relu(self.fe.bn1(out))
203 | out = F.avg_pool2d(out, 8)
204 | out = out.view(-1, self.fe.in_planes)
205 | return out
206 |
207 | def forward(self, x):
208 | out = self._forward(x)
209 | out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
210 | scores = torch.sum((out - self.centers)**2, dim=2) / 2 / torch.exp(2 * F.relu(self.logsigmas)) + self.latent_size * F.relu(self.logsigmas)
211 | return scores
212 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | The BasicBlock and Bottleneck modules are from the original ResNet paper:
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | The PreActBlock and PreActBottleneck modules are from the later paper:
6 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
8 | Original code is from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
9 | '''
10 | import os, math
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.nn.parameter import Parameter
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, in_planes, planes, stride=1):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(in_planes, planes, stride)
26 | self.bn1 = nn.BatchNorm2d(planes)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 |
30 | self.shortcut = nn.Sequential()
31 | if stride != 1 or in_planes != self.expansion*planes:
32 | self.shortcut = nn.Sequential(
33 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
34 | nn.BatchNorm2d(self.expansion*planes)
35 | )
36 |
37 | def forward(self, x):
38 | out = F.relu(self.bn1(self.conv1(x)))
39 | out = self.bn2(self.conv2(out))
40 | out += self.shortcut(x)
41 | out = F.relu(out)
42 | return out
43 |
44 |
45 | class PreActBlock(nn.Module):
46 | '''Pre-activation version of the BasicBlock.'''
47 | expansion = 1
48 |
49 | def __init__(self, in_planes, planes, stride=1):
50 | super(PreActBlock, self).__init__()
51 | self.bn1 = nn.BatchNorm2d(in_planes)
52 | self.conv1 = conv3x3(in_planes, planes, stride)
53 | self.bn2 = nn.BatchNorm2d(planes)
54 | self.conv2 = conv3x3(planes, planes)
55 |
56 | if stride != 1 or in_planes != self.expansion*planes:
57 | self.shortcut = nn.Sequential(
58 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
59 | )
60 |
61 | def forward(self, x):
62 | out = F.relu(self.bn1(x))
63 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
64 | out = self.conv1(out)
65 | out = self.conv2(F.relu(self.bn2(out)))
66 | out += shortcut
67 | return out
68 |
69 |
70 | class Bottleneck(nn.Module):
71 | expansion = 4
72 |
73 | def __init__(self, in_planes, planes, stride=1):
74 | super(Bottleneck, self).__init__()
75 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
76 | self.bn1 = nn.BatchNorm2d(planes)
77 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
78 | self.bn2 = nn.BatchNorm2d(planes)
79 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
80 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
81 |
82 | self.shortcut = nn.Sequential()
83 | if stride != 1 or in_planes != self.expansion*planes:
84 | self.shortcut = nn.Sequential(
85 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
86 | nn.BatchNorm2d(self.expansion*planes)
87 | )
88 |
89 | def forward(self, x):
90 | out = F.relu(self.bn1(self.conv1(x)))
91 | out = F.relu(self.bn2(self.conv2(out)))
92 | out = self.bn3(self.conv3(out))
93 | out += self.shortcut(x)
94 | out = F.relu(out)
95 | return out
96 |
97 |
98 | class PreActBottleneck(nn.Module):
99 | '''Pre-activation version of the original Bottleneck module.'''
100 | expansion = 4
101 |
102 | def __init__(self, in_planes, planes, stride=1):
103 | super(PreActBottleneck, self).__init__()
104 | self.bn1 = nn.BatchNorm2d(in_planes)
105 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
106 | self.bn2 = nn.BatchNorm2d(planes)
107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
108 | self.bn3 = nn.BatchNorm2d(planes)
109 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
110 |
111 | if stride != 1 or in_planes != self.expansion*planes:
112 | self.shortcut = nn.Sequential(
113 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
114 | )
115 |
116 | def forward(self, x):
117 | out = F.relu(self.bn1(x))
118 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
119 | out = self.conv1(out)
120 | out = self.conv2(F.relu(self.bn2(out)))
121 | out = self.conv3(F.relu(self.bn3(out)))
122 | out += shortcut
123 | return out
124 |
125 |
126 | class ResNet(nn.Module):
127 | def __init__(self, block, num_blocks, num_classes=10):
128 | super(ResNet, self).__init__()
129 | self.in_planes = 64
130 |
131 | self.conv1 = conv3x3(3,64)
132 | self.bn1 = nn.BatchNorm2d(64)
133 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
134 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
135 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
136 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
137 | self.linear = nn.Linear(512*block.expansion, num_classes)
138 |
139 | self.init_weights()
140 |
141 | def init_weights(self):
142 | for m in self.modules():
143 | if isinstance(m, nn.Conv2d):
144 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
145 | m.weight.data.normal_(0, math.sqrt(2. / n))
146 | elif isinstance(m, nn.BatchNorm2d):
147 | m.weight.data.fill_(1)
148 | m.bias.data.zero_()
149 | elif isinstance(m, nn.Linear):
150 | m.bias.data.zero_()
151 |
152 | def _make_layer(self, block, planes, num_blocks, stride):
153 | strides = [stride] + [1]*(num_blocks-1)
154 | layers = []
155 | for stride in strides:
156 | layers.append(block(self.in_planes, planes, stride))
157 | self.in_planes = planes * block.expansion
158 | return nn.Sequential(*layers)
159 |
160 | def forward(self, x):
161 | out = F.relu(self.bn1(self.conv1(x)))
162 | out = self.layer1(out)
163 | out = self.layer2(out)
164 | out = self.layer3(out)
165 | out = self.layer4(out)
166 | out = F.avg_pool2d(out, 4)
167 | out = out.view(out.size(0), -1)
168 | y = self.linear(out)
169 | return y
170 |
171 |     # function to extract the multiple features
172 | def feature_list(self, x):
173 | out_list = []
174 | out = F.relu(self.bn1(self.conv1(x)))
175 | out_list.append(out)
176 | out = self.layer1(out)
177 | out_list.append(out)
178 | out = self.layer2(out)
179 | out_list.append(out)
180 | out = self.layer3(out)
181 | out_list.append(out)
182 | out = self.layer4(out)
183 | out_list.append(out)
184 | out = F.avg_pool2d(out, 4)
185 | out = out.view(out.size(0), -1)
186 | y = self.linear(out)
187 | return y, out_list
188 |
189 |     # function to extract a specific feature
190 | def intermediate_forward(self, x, layer_index):
191 | out = F.relu(self.bn1(self.conv1(x)))
192 | if layer_index == 1:
193 | out = self.layer1(out)
194 | elif layer_index == 2:
195 | out = self.layer1(out)
196 | out = self.layer2(out)
197 | elif layer_index == 3:
198 | out = self.layer1(out)
199 | out = self.layer2(out)
200 | out = self.layer3(out)
201 | elif layer_index == 4:
202 | out = self.layer1(out)
203 | out = self.layer2(out)
204 | out = self.layer3(out)
205 | out = self.layer4(out)
206 | return out
207 |
208 |     # function to extract the penultimate features
209 | def penultimate_forward(self, x):
210 | out = F.relu(self.bn1(self.conv1(x)))
211 | out = self.layer1(out)
212 | out = self.layer2(out)
213 | out = self.layer3(out)
214 | penultimate = self.layer4(out)
215 | out = F.avg_pool2d(penultimate, 4)
216 | out = out.view(out.size(0), -1)
217 | y = self.linear(out)
218 | return y, penultimate
219 |
220 | def ResNet18(num_c):
221 | return ResNet(PreActBlock, [2,2,2,2], num_classes=num_c)
222 |
223 | def ResNet34(num_c):
224 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_c)
225 |
226 | def ResNet50():
227 | return ResNet(Bottleneck, [3,4,6,3])
228 |
229 | def ResNet101():
230 | return ResNet(Bottleneck, [3,4,23,3])
231 |
232 | def ResNet152():
233 | return ResNet(Bottleneck, [3,8,36,3])
234 |
235 | class ResNet_DeepMCDD(nn.Module):
236 | def __init__(self, block_expansion=1, latent_size=512, num_classes=10):
237 | super(ResNet_DeepMCDD, self).__init__()
238 | self.latent_size = latent_size
239 | self.num_classes = num_classes
240 |
241 | self.centers = torch.nn.Parameter(torch.zeros([num_classes, latent_size]), requires_grad=True)
242 | self.alphas = torch.nn.Parameter(torch.ones(num_classes), requires_grad=True)
243 | self.logsigmas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)
244 | self.radii = torch.ones(num_classes).cuda()
245 |
246 | self.fe = ResNet34(num_c=num_classes)
247 | self.init_weights()
248 |
249 | def load_fe_weights(self, weights):
250 | self.fe.load_state_dict(weights)
251 |
252 | def load_weights(self, weights):
253 | self.load_state_dict(weights)
254 |
255 | def init_weights(self):
256 | self.fe.init_weights()
257 | for m in self.modules():
258 | if isinstance(m, nn.Conv2d):
259 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
260 | m.weight.data.normal_(0, math.sqrt(2. / n))
261 | elif isinstance(m, nn.BatchNorm2d):
262 | m.weight.data.fill_(1)
263 | m.bias.data.zero_()
264 | elif isinstance(m, nn.Linear):
265 | m.bias.data.zero_()
266 | nn.init.xavier_uniform_(self.centers)
267 | nn.init.zeros_(self.alphas)
268 | nn.init.zeros_(self.logsigmas)
269 |
270 | def _forward(self, x):
271 | out = F.relu(self.fe.bn1(self.fe.conv1(x)))
272 | out = self.fe.layer1(out)
273 | out = self.fe.layer2(out)
274 | out = self.fe.layer3(out)
275 | out = self.fe.layer4(out)
276 | out = F.avg_pool2d(out, 4)
277 | out = out.view(out.size(0), -1)
278 | return out
279 |
280 | def forward(self, x):
281 | out = self._forward(x)
282 | out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
283 | scores = torch.sum((out - self.centers)**2, dim=2) / 2 / torch.exp(2 * F.relu(self.logsigmas)) + self.latent_size * F.relu(self.logsigmas)
284 | return scores
285 |
--------------------------------------------------------------------------------