├── output
│   └── .gitignore
├── figure
│   └── deepmcdd.png
├── models
│   ├── __init__.py
│   ├── mlp.py
│   ├── densenet.py
│   └── resnet.py
├── README.md
├── dataloader_table.py
├── train_deepmcdd_table.py
├── dataloader_image.py
├── utils.py
└── train_deepmcdd_image.py

/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | 
--------------------------------------------------------------------------------
/figure/deepmcdd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donalee/DeepMCDD/HEAD/figure/deepmcdd.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import *
2 | from .resnet import *
3 | from .densenet import *
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-class Data Description (Deep-MCDD) for Out-of-distribution Detection
2 | 
3 | This is the author code of ["Multi-class Data Description for Out-of-distribution Detection"](https://dl.acm.org/doi/abs/10.1145/3394486.3403189).
4 | Some code is implemented based on the [Deep Mahalanobis Detector](https://github.com/pokaxpoka/deep_Mahalanobis_detector).
5 | 
6 | ## Overview
7 | 
8 | <p align="center">
9 |   <img src="./figure/deepmcdd.png" alt="Overview of the Deep-MCDD framework">
10 | </p>
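
The preprocessed tabular datasets linked in the next section are stored as pickled numpy object arrays holding `(features, labels, num_classes)`, exactly as they are unpacked in `dataloader_table.py`. A minimal loading sketch (assuming `gas_preproc.npy` has already been placed in `./table_data/`):

```python
import numpy as np

# unpack the object array produced by the authors' preprocessing
features, labels, num_classes = np.load('./table_data/gas_preproc.npy', allow_pickle=True)
print(features.shape, labels.shape, num_classes)
```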
11 | 
12 | ## Downloading tabular datasets
13 | 
14 | The four multi-class tabular datasets reported in the paper can be downloaded from the links below.
15 | We provide preprocessed versions of the datasets, converted into the numpy array format, so that users can easily load them.
16 | For more details on the datasets, please refer to the UCI repository.
17 | Place them in the directory `./table_data/`.
18 | 
19 | - **GasSensor**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/gas_preproc.npy) [[Raw format]](https://archive.ics.uci.edu/ml/datasets/Gas+Sensor+Array+Drift+Dataset#)
20 | - **Shuttle**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/shuttle_preproc.npy) [[Raw format]](https://archive.ics.uci.edu/ml/datasets/Statlog+(Shuttle))
21 | - **DriveDiagnosis**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/drive_preproc.npy) [[Raw format]](https://archive.ics.uci.edu/ml/datasets/Dataset+for+Sensorless+Drive+Diagnosis)
22 | - **MNIST**: [[Numpy format]](http://di.postech.ac.kr/donalee/deepmcdd/mnist_preproc.npy) [[Raw format]](http://yann.lecun.com/exdb/mnist/)
23 | 
24 | ## Downloading image datasets
25 | 
26 | The three in-distribution datasets (i.e., **SVHN**, **CIFAR-10**, and **CIFAR-100**) are downloaded automatically via `torchvision`.
27 | For the two additional out-of-distribution datasets (i.e., **TinyImageNet** and **LSUN**), we use the download links provided by the [Deep Mahalanobis Detector](https://github.com/pokaxpoka/deep_Mahalanobis_detector) and the [ODIN Detector](https://github.com/ShiyuLiang/odin-pytorch).
28 | Place them in the directory `./image_data/`.
29 | 
30 | ## Running the code
31 | 
32 | - python
33 | - torch (GPU version only)
34 | 
35 | #### Training MLP with Deep-MCDD for tabular data
36 | ```
37 | python train_deepmcdd_table.py --dataset gas --net_type mlp --oodclass_idx 0
38 | ```
39 | 
40 | #### Training CNN with Deep-MCDD for image data
41 | ```
42 | python train_deepmcdd_image.py --dataset cifar10 --net_type resnet
43 | ```
44 | 
45 | ## Citation
46 | ```
47 | @inproceedings{lee2020multi,
48 |   author = {Lee, Dongha and Yu, Sehun and Yu, Hwanjo},
49 |   title = {Multi-Class Data Description for Out-of-Distribution Detection},
50 |   year = {2020},
51 |   booktitle = {Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining},
52 |   pages = {1362–1370}
53 | }
54 | ```
55 | 
--------------------------------------------------------------------------------
/dataloader_table.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import DataLoader, TensorDataset
4 | import numpy as np
5 | 
6 | def get_table_data(batch_size, data_dir, dataset, oodclass_idx, fold_idx, **kwargs):
7 | 
8 |     data_path = os.path.join(data_dir, dataset + '_preproc.npy')
9 |     features, labels, num_classes = np.load(data_path, allow_pickle=True)
10 |     n_data = len(labels)
11 | 
12 |     id_classes = [c for c in range(num_classes) if c != oodclass_idx]
13 |     id_indices = {c: [i for i in range(len(labels)) if labels[i] == c] for c in id_classes}
14 | 
15 |     np.random.seed(0)
16 |     for c in id_classes:
17 |         np.random.shuffle(id_indices[c])
18 |     test_id_indices = {c: id_indices[c][int(0.2*fold_idx*len(id_indices[c])):int(0.2*(fold_idx+1)*len(id_indices[c]))] for c in id_classes}  # the fold_idx-th 20% slice of each ID class (5-fold CV)
19 | 
20 |     id_indices = np.concatenate(list(id_indices.values()))
21 |     test_id_indices = np.concatenate(list(test_id_indices.values()))
22 |     train_id_indices = np.array([i for i in id_indices if i not in test_id_indices])
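    # every sample of the held-out class is reserved for OOD testing; the remaining ID labels are re-indexed below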
23 |     ood_indices = [i for i in range(len(labels)) if labels[i] == oodclass_idx]
24 | 
25 |     for i in range(len(labels)):  # relabel: the OOD class becomes -1, later ID classes shift down by one
26 |         if labels[i] == oodclass_idx:
27 |             labels[i] = -1
28 |         elif labels[i] > oodclass_idx:
29 |             labels[i] -= 1
30 | 
31 |     train_dataset = TensorDataset(torch.Tensor(features[train_id_indices]),
32 |                                   torch.LongTensor(labels[train_id_indices]))
33 |     test_id_dataset = TensorDataset(torch.Tensor(features[test_id_indices]),
34 |                                     torch.LongTensor(labels[test_id_indices]))
35 |     test_ood_dataset = TensorDataset(torch.Tensor(features[ood_indices]),
36 |                                      torch.LongTensor(labels[ood_indices]))
37 | 
38 |     train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
39 |     test_id_loader = DataLoader(dataset=test_id_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
40 |     test_ood_loader = DataLoader(dataset=test_ood_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
41 | 
42 |     return train_loader, test_id_loader, test_ood_loader
43 | 
44 | def get_total_table_data(batch_size, data_dir, dataset, fold_idx, **kwargs):
45 | 
46 |     data_path = os.path.join(data_dir, dataset + '_preproc.npy')
47 |     features, labels, num_classes = np.load(data_path, allow_pickle=True)
48 |     n_data = len(labels)
49 |     id_indices = [i for i in range(len(labels))]
50 | 
51 |     np.random.seed(0)
52 |     np.random.shuffle(id_indices)
53 |     test_id_indices = id_indices[int(0.2*fold_idx*len(id_indices)):int(0.2*(fold_idx+1)*len(id_indices))]
54 |     train_id_indices = np.array([i for i in id_indices if i not in test_id_indices])
55 | 
56 |     test_dataset = TensorDataset(torch.Tensor(features[test_id_indices]),
57 |                                  torch.LongTensor(labels[test_id_indices]))
58 |     train_dataset = TensorDataset(torch.Tensor(features[train_id_indices]),
59 |                                   torch.LongTensor(labels[train_id_indices]))
60 | 
61 |     test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
62 |     train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
63 | 
64 |     return train_loader, test_loader
65 | 
66 | 
--------------------------------------------------------------------------------
/train_deepmcdd_table.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import numpy as np
4 | 
5 | import models
6 | from dataloader_table import get_table_data
7 | from utils import compute_confscores, compute_metrics, print_ood_results, print_ood_results_total
8 | 
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--dataset', required=True, help='gas | shuttle | drive | mnist')
11 | parser.add_argument('--net_type', required=True, help='mlp')
12 | parser.add_argument('--datadir', default='./table_data/', help='path to dataset')
13 | parser.add_argument('--outdir', default='./output/', help='folder to output results')
14 | parser.add_argument('--oodclass_idx', type=int, default=0, help='index of the OOD class')
15 | parser.add_argument('--batch_size', type=int, default=200, help='batch size for data loader')
16 | parser.add_argument('--latent_size', type=int, default=128, help='dimension size for latent representation')
17 | parser.add_argument('--num_layers', type=int, default=3, help='the number of hidden layers in MLP')
18 | parser.add_argument('--num_folds', type=int, default=5, help='the number of cross-validation folds')
19 | parser.add_argument('--num_epochs', type=int, default=10, help='the number of epochs for training sc-layers')
20 | parser.add_argument('--learning_rate', type=float, default=0.001, help='initial learning rate of Adam optimizer')
21 | parser.add_argument('--reg_lambda', type=float, default=1.0, help='regularization coefficient')
22 | parser.add_argument('--gpu', type=int, default=0, help='gpu index')
23 | 
24 | args = parser.parse_args()
25 | print(args)
26 | 
27 | def main():
28 |     outdir = os.path.join(args.outdir, args.net_type + '_' + args.dataset)
29 | 
30 |     if os.path.isdir(outdir) == False:
31 |         os.mkdir(outdir)
32 | 
33 |     torch.manual_seed(0)
34 |     torch.cuda.manual_seed_all(0)
35 |     torch.cuda.set_device(args.gpu)
36 | 
37 |     best_idacc_list, best_oodacc_list = [], []
38 |     for fold_idx in range(args.num_folds):
39 | 
40 |         train_loader, test_id_loader, test_ood_loader = get_table_data(args.batch_size, args.datadir, args.dataset, args.oodclass_idx, fold_idx)
41 | 
42 |         if args.dataset == 'gas':
43 |             num_classes, num_features = 6, 128
44 |         elif args.dataset == 'drive':
45 |             num_classes, num_features = 11, 48
46 |         elif args.dataset == 'shuttle':
47 |             num_classes, num_features = 7, 9
48 |         elif args.dataset == 'mnist':
49 |             num_classes, num_features = 10, 784
50 | 
51 |         model = models.MLP_DeepMCDD(num_features, args.num_layers*[args.latent_size], num_classes=num_classes-1)  # one class is held out as OOD
52 |         model.cuda()
53 | 
54 |         ce_loss = torch.nn.CrossEntropyLoss()
55 |         optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
56 | 
57 |         idacc_list, oodacc_list = [], []
58 |         total_step = len(train_loader)
59 |         for epoch in range(args.num_epochs):
60 |             model.train()
61 |             total_loss = 0.0
62 | 
63 |             for i, (data, labels) in enumerate(train_loader):
64 |                 data, labels = data.cuda(), labels.cuda()
65 |                 dists = model(data)
66 |                 scores = - dists + model.alphas  # class score: alpha_k - D_k(x)
67 | 
68 |                 label_mask = torch.zeros(labels.size(0), model.num_classes).cuda().scatter_(1, labels.unsqueeze(dim=1), 1)
69 | 
70 |                 pull_loss = torch.mean(torch.sum(torch.mul(label_mask, dists), dim=1))  # distance to the true class center
71 |                 push_loss = ce_loss(scores, labels)  # softmax cross-entropy over the biased scores
72 |                 loss = args.reg_lambda * pull_loss + push_loss
73 | 
74 |                 optimizer.zero_grad()
75 |                 loss.backward()
76 |                 optimizer.step()
77 | 
78 |                 total_loss += loss.item()
79 | 
80 |             model.eval()
81 |             with torch.no_grad():
82 |                 # (1) evaluate ID classification
83 |                 correct, total = 0, 0
84 |                 for data, labels in test_id_loader:
85 |                     data, labels = data.cuda(), labels.cuda()
86 |                     scores = - model(data) + model.alphas
87 |                     _, predicted = torch.max(scores, 1)
88 |                     total += labels.size(0)
89 |                     correct += (predicted == labels).sum().item()
90 |                 idacc_list.append(100 * correct / total)
91 | 
92 |                 # (2) evaluate OOD detection
93 |                 compute_confscores(model, test_id_loader, outdir, True)
94 |                 compute_confscores(model, test_ood_loader, outdir, False)
95 |                 oodacc_list.append(compute_metrics(outdir))
96 | 
97 |         best_idacc = max(idacc_list)
98 |         best_oodacc = oodacc_list[idacc_list.index(best_idacc)]
99 | 
100 |         print('== Results of fold {fidx:1d} =='.format(fidx=fold_idx+1))
101 |         print('The best ID accuracy on "{idset:s}" test samples : {val:6.2f}'.format(idset=args.dataset, val=best_idacc))
102 |         print('The best OOD accuracy on "{oodset:s}" test samples :'.format(oodset=args.dataset+'_'+str(args.oodclass_idx)))
103 |         print_ood_results(best_oodacc)
104 | 
105 |         best_idacc_list.append(best_idacc)
106 |         best_oodacc_list.append(best_oodacc)
107 | 
108 |     print('== Final results ==')
109 |     print('The best ID accuracy on "{idset:s}" test samples : {mean:6.2f} ({std:6.3f})'.format(idset=args.dataset, mean=np.mean(best_idacc_list), std=np.std(best_idacc_list)))
110 |     print('The best OOD accuracy on "{oodset:s}" test samples :'.format(oodset='class_'+str(args.oodclass_idx)))
111 |     print_ood_results_total(best_oodacc_list)
112 | 
113 | 
114 | if __name__ == '__main__':
115 |     main()
116 | 
--------------------------------------------------------------------------------
/dataloader_image.py:
--------------------------------------------------------------------------------
1 | #
2 | # This file is implemented based on the author code of
3 | # Lee et al., "A simple unified framework for detecting out-of-distribution samples and adversarial attacks", in NeurIPS 2018.
4 | #
5 | 
6 | import os
7 | import torch
8 | from torchvision import datasets
9 | from torch.utils.data import DataLoader
10 | 
11 | def get_svhn(batch_size, train_TF, test_TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
12 |     data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
13 |     num_workers = kwargs.setdefault('num_workers', 1)
14 |     kwargs.pop('input_size', None)
15 | 
16 |     ds = []
17 |     if train:
18 |         train_loader = torch.utils.data.DataLoader(
19 |             datasets.SVHN(
20 |                 root=data_root, split='train', download=True,
21 |                 transform=train_TF,
22 |             ),
23 |             batch_size=batch_size, shuffle=True, **kwargs)
24 |         ds.append(train_loader)
25 | 
26 |     if val:
27 |         test_loader = torch.utils.data.DataLoader(
28 |             datasets.SVHN(
29 |                 root=data_root, split='test', download=True,
30 |                 transform=test_TF,
31 |             ),
32 |             batch_size=batch_size, shuffle=False, **kwargs)
33 |         ds.append(test_loader)
34 |     ds = ds[0] if len(ds) == 1 else ds
35 |     return ds
36 | 
37 | def get_cifar10(batch_size, train_TF, test_TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
38 |     data_root = os.path.expanduser(os.path.join(data_root, 'cifar10-data'))
39 |     num_workers = kwargs.setdefault('num_workers', 1)
40 |     kwargs.pop('input_size', None)
41 |     ds = []
42 | 
43 |     if train:
44 |         train_loader = torch.utils.data.DataLoader(
45 |             datasets.CIFAR10(
46 |                 root=data_root, train=True, download=True,
47 |                 transform=train_TF),
48 |             batch_size=batch_size, shuffle=True, **kwargs)
49 |         ds.append(train_loader)
50 | 
51 |     if val:
52 |         test_loader = torch.utils.data.DataLoader(
53 |             datasets.CIFAR10(
54 |                 root=data_root, train=False, download=True,
55 |                 transform=test_TF),
56 |             batch_size=batch_size, shuffle=False, **kwargs)
57 |         ds.append(test_loader)
58 | 
59 |     ds = ds[0] if len(ds) == 1 else ds
60 |     return ds
61 | 
62 | def get_cifar100(batch_size, train_TF, test_TF, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
63 |     data_root = os.path.expanduser(os.path.join(data_root, 'cifar100-data'))
64 |     num_workers = kwargs.setdefault('num_workers', 1)
65 |     kwargs.pop('input_size', None)
66 |     ds = []
67 | 
68 |     if train:
69 |         train_loader = torch.utils.data.DataLoader(
70 |             datasets.CIFAR100(
71 |                 root=data_root, train=True, download=True,
72 |                 transform=train_TF),
73 |             batch_size=batch_size, shuffle=True, **kwargs)
74 |         ds.append(train_loader)
75 | 
76 |     if val:
77 |         test_loader = torch.utils.data.DataLoader(
78 |             datasets.CIFAR100(
79 |                 root=data_root, train=False, download=True,
80 |                 transform=test_TF),
81 |             batch_size=batch_size, shuffle=False, **kwargs)
82 |         ds.append(test_loader)
83 | 
84 |     ds = ds[0] if len(ds) == 1 else ds
85 |     return ds
86 | 
87 | def get_id_image_data(data_type, batch_size, train_TF, test_TF, dataroot):
88 |     if data_type == 'cifar10':
89 |         train_loader, test_loader = get_cifar10(batch_size=batch_size, train_TF=train_TF, test_TF=test_TF, data_root=dataroot, num_workers=1)
90 |     elif data_type == 'cifar100':
91 |         train_loader, test_loader = get_cifar100(batch_size=batch_size, train_TF=train_TF, test_TF=test_TF, data_root=dataroot, num_workers=1)
92 |     elif data_type == 'svhn':
93 |         train_loader, test_loader = get_svhn(batch_size=batch_size, train_TF=train_TF, test_TF=test_TF, data_root=dataroot, num_workers=1)
94 | 
95 |     return train_loader, test_loader
96 | 
97 | def get_ood_image_data(data_type, batch_size, input_TF, dataroot):
98 |     if data_type == 'cifar10':
99 |         _, test_loader = get_cifar10(batch_size=batch_size, train_TF=input_TF, test_TF=input_TF, data_root=dataroot, num_workers=1)
100 |     elif data_type == 'svhn':
101 |         _, test_loader = get_svhn(batch_size=batch_size, train_TF=input_TF, test_TF=input_TF, data_root=dataroot, num_workers=1)
102 |     elif data_type == 'cifar100':
103 |         _, test_loader = get_cifar100(batch_size=batch_size, train_TF=input_TF, test_TF=input_TF, data_root=dataroot, num_workers=1)
104 |     elif data_type == 'imagenet_crop':
105 |         dataroot = os.path.expanduser(os.path.join(dataroot, 'Imagenet_crop'))
106 |         testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
107 |         test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
108 |     elif data_type == 'lsun_crop':
109 |         dataroot = os.path.expanduser(os.path.join(dataroot, 'LSUN_crop'))
110 |         testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
111 |         test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
112 |     elif data_type == 'imagenet_resize':
113 |         dataroot = os.path.expanduser(os.path.join(dataroot, 'Imagenet_resize'))
114 |         testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
115 |         test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
116 |     elif data_type == 'lsun_resize':
117 |         dataroot = os.path.expanduser(os.path.join(dataroot, 'LSUN_resize'))
118 |         testsetout = datasets.ImageFolder(dataroot, transform=input_TF)
119 |         test_loader = torch.utils.data.DataLoader(testsetout, batch_size=batch_size, shuffle=False, num_workers=1)
120 | 
121 |     return test_loader
122 | 
123 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # This file is implemented based on the author code of
3 | # Lee et al., "A simple unified framework for detecting out-of-distribution samples and adversarial attacks", in NeurIPS 2018.
4 | #
5 | 
6 | import os
7 | import torch
8 | import numpy as np
9 | 
10 | def compute_confscores(model, test_loader, outdir, id_flag):
11 |     total = 0
12 |     if id_flag == True:
13 |         outfile = os.path.join(outdir, 'confscores_id.txt')
14 |     else:
15 |         outfile = os.path.join(outdir, 'confscores_ood.txt')
16 | 
17 |     f = open(outfile, 'w')
18 | 
19 |     for data, _ in test_loader:
20 |         dists = model(data.cuda())
21 |         confscores, _ = torch.min(dists, dim=1)  # confidence = distance to the nearest class center
22 |         total += data.size(0)
23 | 
24 |         for i in range(data.size(0)):
25 |             f.write("{}\n".format(-confscores[i].item()))  # .item() so a plain float (not a tensor repr) is written
26 | 
27 |     f.close()
28 | 
29 | def get_auroc_curve(indir):
30 |     known = np.loadtxt(os.path.join(indir, 'confscores_id.txt'))
31 |     novel = np.loadtxt(os.path.join(indir, 'confscores_ood.txt'))
32 |     known.sort()
33 |     novel.sort()
34 | 
35 |     end = np.max([np.max(known), np.max(novel)])
36 |     start = np.min([np.min(known), np.min(novel)])
37 | 
38 |     num_k = known.shape[0]
39 |     num_n = novel.shape[0]
40 | 
41 |     tp = -np.ones([num_k+num_n+1], dtype=int)  # tp[l]: ID samples above the l-th smallest score
42 |     fp = -np.ones([num_k+num_n+1], dtype=int)  # fp[l]: OOD samples above the l-th smallest score
43 |     tp[0], fp[0] = num_k, num_n
44 |     k, n = 0, 0
45 |     for l in range(num_k+num_n):
46 |         if k == num_k:
47 |             tp[l+1:] = tp[l]
48 |             fp[l+1:] = np.arange(fp[l]-1, -1, -1)
49 |             break
50 |         elif n == num_n:
51 |             tp[l+1:] = np.arange(tp[l]-1, -1, -1)
52 |             fp[l+1:] = fp[l]
53 |             break
54 |         else:
55 |             if novel[n] < known[k]:
56 |                 n += 1
57 |                 tp[l+1] = tp[l]
58 |                 fp[l+1] = fp[l] - 1
59 |             else:
60 |                 k += 1
61 |                 tp[l+1] = tp[l] - 1
62 |                 fp[l+1] = fp[l]
63 |     tpr85_pos = np.abs(tp / num_k - .85).argmin()
64 |     tpr95_pos = np.abs(tp / num_k - .95).argmin()
65 |     tnr_at_tpr85 = 1. - fp[tpr85_pos] / num_n
66 |     tnr_at_tpr95 = 1. - fp[tpr95_pos] / num_n
67 |     return tp, fp, tnr_at_tpr85, tnr_at_tpr95
68 | 
69 | def compute_metrics(dir_name, verbose=False):
70 |     tp, fp, tnr_at_tpr85, tnr_at_tpr95 = get_auroc_curve(dir_name)
71 |     results = dict()
72 |     mtypes = ['TNR85', 'TNR95', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']
73 |     if verbose:
74 |         print(' ', end='')
75 |         for mtype in mtypes:
76 |             print(' {mtype:6s}'.format(mtype=mtype), end='')
77 |         print('')
78 | 
79 |     if verbose:
80 |         print('{:5s} '.format(''), end='')  # row-label placeholder
81 | 
82 | 
83 |     # TNR85
84 |     mtype = 'TNR85'
85 |     results[mtype] = tnr_at_tpr85
86 |     if verbose:
87 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
88 | 
89 |     # TNR95
90 |     mtype = 'TNR95'
91 |     results[mtype] = tnr_at_tpr95
92 |     if verbose:
93 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
94 | 
95 |     # AUROC
96 |     mtype = 'AUROC'
97 |     tpr = np.concatenate([[1.], tp/tp[0], [0.]])
98 |     fpr = np.concatenate([[1.], fp/fp[0], [0.]])
99 |     results[mtype] = -np.trapz(1. - fpr, tpr)
100 |     if verbose:
101 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
102 | 
103 |     # DTACC
104 |     mtype = 'DTACC'
105 |     results[mtype] = .5 * (tp/tp[0] + 1. - fp/fp[0]).max()
106 |     if verbose:
107 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
108 | 
109 |     # AUIN
110 |     mtype = 'AUIN'
111 |     denom = tp + fp
112 |     denom[denom == 0.] = -1.
113 |     pin_ind = np.concatenate([[True], denom > 0., [True]])
114 |     pin = np.concatenate([[.5], tp/denom, [0.]])
115 |     results[mtype] = -np.trapz(pin[pin_ind], tpr[pin_ind])
116 |     if verbose:
117 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
118 | 
119 |     # AUOUT
120 |     mtype = 'AUOUT'
121 |     denom = tp[0] - tp + fp[0] - fp
122 |     denom[denom == 0.] = -1.
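    # (fp[0]-fp)/denom below is the OOD precision: among samples below a threshold, the fraction that are truly OOD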
123 |     pout_ind = np.concatenate([[True], denom > 0., [True]])
124 |     pout = np.concatenate([[0.], (fp[0] - fp)/denom, [.5]])
125 |     results[mtype] = np.trapz(pout[pout_ind], 1. - fpr[pout_ind])
126 |     if verbose:
127 |         print(' {val:6.3f}'.format(val=100.*results[mtype]), end='')
128 |         print('')
129 | 
130 |     return results
131 | 
132 | def print_ood_results(ood_result):
133 | 
134 |     for mtype in ['TNR85', 'TNR95', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']:
135 |         print(' {mtype:6s}'.format(mtype=mtype), end='')
136 |     print('\n{val:6.2f}'.format(val=100.*ood_result['TNR85']), end='')
137 |     print(' {val:6.2f}'.format(val=100.*ood_result['TNR95']), end='')
138 |     print(' {val:6.2f}'.format(val=100.*ood_result['AUROC']), end='')
139 |     print(' {val:6.2f}'.format(val=100.*ood_result['DTACC']), end='')
140 |     print(' {val:6.2f}'.format(val=100.*ood_result['AUIN']), end='')
141 |     print(' {val:6.2f}\n'.format(val=100.*ood_result['AUOUT']), end='')
142 |     print('')
143 | 
144 | def print_ood_results_total(ood_result_list):
145 | 
146 |     TNR85_list = [100.*ood_result['TNR85'] for ood_result in ood_result_list]
147 |     TNR95_list = [100.*ood_result['TNR95'] for ood_result in ood_result_list]
148 |     AUROC_list = [100.*ood_result['AUROC'] for ood_result in ood_result_list]
149 |     DTACC_list = [100.*ood_result['DTACC'] for ood_result in ood_result_list]
150 |     AUIN_list = [100.*ood_result['AUIN'] for ood_result in ood_result_list]
151 |     AUOUT_list = [100.*ood_result['AUOUT'] for ood_result in ood_result_list]
152 | 
153 |     for mtype in ['TNR85', 'TNR95', 'AUROC', 'DTACC', 'AUIN', 'AUOUT']:
154 |         print(' {mtype:15s}'.format(mtype=mtype), end='')
155 |     print('\n{mean:6.2f} ({std:6.3f})'.format(mean=np.mean(TNR85_list), std=np.std(TNR85_list)), end='')
156 |     print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(TNR95_list), std=np.std(TNR95_list)), end='')
157 |     print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(AUROC_list), std=np.std(AUROC_list)), end='')
158 |     print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(DTACC_list), std=np.std(DTACC_list)), end='')
159 |     print(' {mean:6.2f} ({std:6.3f})'.format(mean=np.mean(AUIN_list), std=np.std(AUIN_list)), end='')
160 |     print(' {mean:6.2f} ({std:6.3f})\n'.format(mean=np.mean(AUOUT_list), std=np.std(AUOUT_list)), end='')
161 |     print('')
162 | 
163 | 
--------------------------------------------------------------------------------
/train_deepmcdd_image.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import numpy as np
4 | from torchvision import transforms
5 | 
6 | import models
7 | from dataloader_image import get_id_image_data, get_ood_image_data
8 | from utils import compute_confscores, compute_metrics, print_ood_results
9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--dataset', required=True, help='svhn | cifar10 | cifar100')
12 | parser.add_argument('--net_type', required=True, help='resnet | densenet')
13 | parser.add_argument('--datadir', default='./image_data/', help='path to dataset')
14 | parser.add_argument('--outdir', default='./output/', help='folder to output results')
15 | parser.add_argument('--modeldir', default='./trained_model/', help='folder for the trained model')
16 | parser.add_argument('--batch_size', type=int, default=128, help='batch size for data loader')
17 | parser.add_argument('--num_epochs', type=int, default=200, help='the number of epochs for training sc-layers')
18 | parser.add_argument('--learning_rate', type=float, default=0.1, help='initial learning rate of SGD optimizer')
19 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum factor for nesterov momentum of SGD optimizer')
20 | parser.add_argument('--reg_lambda', type=float, default=0.1, help='regularization coefficient')
21 | parser.add_argument('--pretrained', action='store_true', help='initialize the network with pretrained weights (default: random weights)')
22 | parser.add_argument('--gpu', type=int, default=0, help='gpu index')
23 | 
24 | args = parser.parse_args()
25 | print(args)
26 | 
27 | def main():
28 |     # set the path to pre-trained model and output
29 |     outdir = os.path.join(args.outdir, args.net_type + '_' + args.dataset)
30 |     pretrained_path = os.path.join('./pretrained/', args.net_type + '_' + args.dataset + '.pth')
31 |     model_path = os.path.join(args.modeldir, args.net_type + '_' + args.dataset + '.pth')
32 | 
33 |     if os.path.isdir(outdir) == False:
34 |         os.mkdir(outdir)
35 |     if os.path.isdir(args.modeldir) == False:
36 |         os.mkdir(args.modeldir)
37 | 
38 |     torch.manual_seed(0)  # seed the CPU RNG as well as CUDA
39 |     torch.cuda.manual_seed_all(0)
40 |     torch.cuda.set_device(args.gpu)
41 | 
42 |     if args.dataset == 'svhn':
43 |         num_classes = 10
44 |         ood_list = ['cifar10', 'imagenet_crop', 'lsun_crop']
45 |     elif args.dataset == 'cifar10':
46 |         num_classes = 10
47 |         ood_list = ['svhn', 'imagenet_crop', 'lsun_crop']
48 |     elif args.dataset == 'cifar100':
49 |         num_classes = 100
50 |         ood_list = ['svhn', 'imagenet_crop', 'lsun_crop']
51 | 
52 |     if args.net_type == 'densenet':
53 |         model = models.DenseNet_DeepMCDD(num_classes=num_classes)
54 |         if args.pretrained == True:
55 |             model.load_fe_weights(torch.load(pretrained_path, map_location = "cuda:" + str(args.gpu)))
56 |         in_transform_train = transforms.Compose([
57 |             transforms.RandomCrop(32, padding=4),
58 |             transforms.RandomHorizontalFlip(),
59 |             transforms.ToTensor(),
60 |             transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0)),
61 |         ])
62 |         in_transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0)),])
63 | 
64 |     elif args.net_type == 'resnet':
65 |         model = models.ResNet_DeepMCDD(num_classes=num_classes)
66 |         if args.pretrained == True:
67 |             model.load_fe_weights(torch.load(pretrained_path, map_location = "cuda:" + str(args.gpu)))
68 |         in_transform_train = transforms.Compose([
69 |             transforms.RandomCrop(32, padding=4),
70 |             transforms.RandomHorizontalFlip(),
71 |             transforms.ToTensor(),
72 |             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
73 |         ])
74 |         in_transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
75 | 
76 |     model.cuda()
77 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=1e-4, nesterov=True)
78 |     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.num_epochs*0.5), int(args.num_epochs*0.75)], gamma=0.1)
79 | 
80 |     train_loader, test_id_loader = get_id_image_data(args.dataset, args.batch_size, in_transform_train, in_transform_test, args.datadir)
81 |     ce_loss = torch.nn.CrossEntropyLoss()
82 | 
83 |     for epoch in range(args.num_epochs):
84 |         model.train()
85 |         total_loss = 0.0
86 | 
87 |         for i, (images, labels) in enumerate(train_loader):
88 |             images, labels = images.cuda(), labels.cuda()
89 |             dists = model(images)
90 |             scores = - dists + model.alphas  # class score: alpha_k - D_k(x)
91 | 
92 |             label_mask = torch.zeros(labels.size(0), num_classes).cuda().scatter_(1, labels.unsqueeze(dim=1), 1)
93 | 
94 |             pull_loss = torch.mean(torch.sum(torch.mul(label_mask, dists), dim=1))  # pull each sample toward its own class center
95 |             push_loss = ce_loss(scores, labels)  # push other classes' centers away via softmax cross-entropy
96 |             loss = args.reg_lambda * pull_loss + push_loss
97 | 
98 |             optimizer.zero_grad()
99 |             loss.backward()
100 |             optimizer.step()
101 | 
102 |             total_loss += loss.item()
103 | 
104 |         scheduler.step()
105 | 
106 |         model.eval()
107 |         with torch.no_grad():
108 |             # (1) evaluate ID classification
109 |             correct, total = 0, 0
110 |             for images, labels in test_id_loader:
111 |                 images, labels = images.cuda(), labels.cuda()
112 |                 scores = - model(images) + model.alphas
113 |                 _, predicted = torch.max(scores, 1)
114 |                 total += labels.size(0)
115 |                 correct += (predicted == labels).sum().item()
116 |             idacc = 100 * correct / total
117 | 
118 |             ood_results_list = []
119 |             compute_confscores(model, test_id_loader, outdir, True)
120 | 
121 |             for ood in ood_list:
122 |                 test_ood_loader = get_ood_image_data(ood, args.batch_size, in_transform_test, args.datadir)
123 |                 compute_confscores(model, test_ood_loader, outdir, False)
124 |                 ood_results_list.append(compute_metrics(outdir))
125 | 
126 |         print('== Epoch [{}/{}], Loss {} =='.format(epoch+1, args.num_epochs, total_loss))
127 |         print('ID Accuracy on "{idset:s}" test images : {val:6.2f}\n'.format(idset=args.dataset, val=idacc))
128 |         for ood_idx, ood_results in enumerate(ood_results_list):
129 |             print('OOD accuracy on "{oodset:s}" test samples :'.format(oodset=ood_list[ood_idx]))
130 |             print_ood_results(ood_results)
131 | 
132 |         torch.save(model.state_dict(), model_path)
133 | 
134 | if __name__ == '__main__':
135 |     main()
136 | 
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | 
7 | 
8 | # MLP with a Softmax classifier
9 | 
10 | class MLP(nn.Module):
11 |     def __init__(self, input_size, hidden_sizes, num_classes):
12 |         super(MLP, self).__init__()
13 |         self.input_size = input_size
14 |         self.hidden_sizes = hidden_sizes[:-1]
15 |         self.latent_size = hidden_sizes[-1]
16 |         self.num_classes = num_classes
17 | 
18 |         self.build_fe()
19 |         self.init_fe_weights()
20 | 
21 |     def build_fe(self):
22 |         layers = []
23 |         layer_sizes = [self.input_size] + self.hidden_sizes
24 | 
25 |         for i in range(len(layer_sizes)-1):
26 |             layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
27 |             layers.append(nn.ReLU())
28 |         layers.append(nn.Linear(layer_sizes[-1], self.latent_size))
29 |         layers.append(nn.Linear(self.latent_size, self.num_classes))
30 |         self.layers = nn.ModuleList(layers)
31 | 
32 |     def init_fe_weights(self):
33 |         for m in self.modules():
34 |             if isinstance(m, nn.Linear):
35 |                 nn.init.xavier_uniform_(m.weight)
36 |                 nn.init.zeros_(m.bias)
37 | 
38 |     def _forward(self, x):
39 |         for i, layer in enumerate(self.layers[:-1]):
40 |             x = layer(x)
41 |         return x
42 | 
43 |     def forward(self, x):
44 |         out = self._forward(x)
45 |         score = self.layers[-1](out)
46 |         return score
47 | 
48 |     def feature_list(self, x):
49 |         out = self._forward(x)
50 |         score = self.layers[-1](out)
51 |         return score, [out]
52 | 
53 |     def intermediate_forward(self, x, layer_index):
54 |         out = self._forward(x)
55 |         return out
56 | 
57 | 
58 | # MLP with a Deep-MCDD classifier
59 | 
60 | class MLP_DeepMCDD(nn.Module):
61 |     def __init__(self, input_size, hidden_sizes, num_classes):
62 |         super(MLP_DeepMCDD, self).__init__()
63 |         self.input_size = input_size
64 |         self.hidden_sizes = hidden_sizes[:-1]
65 |         self.latent_size = hidden_sizes[-1]
66 |         self.num_classes = num_classes
67 | 
68 |         self.centers = torch.nn.Parameter(torch.zeros([num_classes, self.latent_size]), requires_grad=True)  # per-class centers c_k
69 |         self.alphas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)  # per-class biases alpha_k
70 |         self.logsigmas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)  # per-class log-scales
71 | 
72 |         self.build_fe()
73 |         self.init_fe_weights()
74 | 
75 |     def build_fe(self):
76 |         layers = []
77 |         layer_sizes = [self.input_size] + self.hidden_sizes
78 | 
79 |         for i in range(len(layer_sizes)-1):
80 |             layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
81 |             layers.append(nn.ReLU())
82 |         layers.append(nn.Linear(layer_sizes[-1], self.latent_size))
83 |         self.layers = nn.ModuleList(layers)
84 | 
85 |     def init_fe_weights(self):
86 |         for m in self.modules():
87 |             if isinstance(m, nn.Linear):
88 |                 nn.init.xavier_uniform_(m.weight)
89 |                 nn.init.zeros_(m.bias)
90 |         nn.init.xavier_uniform_(self.centers)
91 |         nn.init.zeros_(self.alphas)
92 |         nn.init.zeros_(self.logsigmas)
93 | 
94 |     def _forward(self, x):
95 |         for i, layer in enumerate(self.layers):
96 |             x = layer(x)
97 |         return x
98 | 
99 |     def forward(self, x):
100 |         out = self._forward(x)
101 |         out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
102 |         scores = torch.sum((out - self.centers)**2, dim=2) / 2 / torch.exp(2 * F.relu(self.logsigmas)) + self.latent_size * F.relu(self.logsigmas)  # D_k(x) = ||f(x)-c_k||^2 / (2*sigma_k^2) + d*log(sigma_k), with sigma_k = exp(relu(logsigma_k))
103 |         return scores
104 | 
105 | 
106 | # MLP with a Soft-MCDD classifier
107 | 
108 | class MLP_SoftMCDD(nn.Module):
109 |     def __init__(self, input_size, hidden_sizes, num_classes, epsilon=0.1):
110 |         super(MLP_SoftMCDD, self).__init__()
111 |         self.input_size = input_size
112 |         self.hidden_sizes = hidden_sizes[:-1]
113 |         self.latent_size = hidden_sizes[-1]
114 |         self.num_classes = num_classes
115 |         self.epsilon = epsilon
116 | 
117 |         self.centers = torch.zeros([num_classes, self.latent_size]).cuda()
118 |         self.radii = torch.ones(num_classes).cuda()
119 | 
120 |         self.build_fe()
121 |         self.init_fe_weights()
122 | 
123 |     def build_fe(self):
124 |         layers = []
125 |         layer_sizes = [self.input_size] + self.hidden_sizes
126 | 
127 |         for i in range(len(layer_sizes)-1):
128 |             layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
129 |             layers.append(nn.ReLU())
130 |         layers.append(nn.Linear(layer_sizes[-1], self.latent_size))
131 |         self.layers = nn.ModuleList(layers)
132 | 
133 |     def init_fe_weights(self):
134 |         for m in self.modules():
135 |             if isinstance(m, nn.Linear):
136 |                 nn.init.xavier_uniform_(m.weight)
137 |                 nn.init.zeros_(m.bias)
138 |         nn.init.xavier_uniform_(self.centers)
139 | 
140 |     def init_centers(self):
141 |         nn.init.xavier_uniform_(self.centers)
142 |         self.centers = 10 * self.centers / torch.norm(self.centers, dim=1, keepdim=True)  # keepdim=True so each center (row) is normalized
143 | 
144 |     def update_centers(self, data_loader):
145 |         class_outputs = {i : [] for i in range(self.num_classes)}
146 |         centers = torch.zeros([self.num_classes, self.latent_size]).cuda()
147 | 
148 |         with torch.no_grad():
149 |             for data in data_loader:
150 |                 inputs, labels = data
151 |                 inputs, labels = inputs.cuda(), labels.cuda()
152 |                 outputs = self._forward(inputs)
153 | 
154 |                 for k in range(self.num_classes):
155 |                     indices = (labels == k).nonzero().squeeze(dim=1)
156 |                     class_outputs[k].append(outputs[indices])
157 | 
158 |         for k in range(self.num_classes):
159 |             class_outputs[k] = torch.cat(class_outputs[k], dim=0)
160 |             centers[k] = torch.mean(class_outputs[k], dim=0)
161 | 
162 | 
        self.centers.data = centers
163 | 
164 |     def update_radii(self, data_loader):
165 |         class_scores = {i : [] for i in range(self.num_classes)}
166 |         radii = np.zeros(self.num_classes)
167 | 
168 |         with torch.no_grad():
169 |             for data in data_loader:
170 |                 inputs, labels = data
171 |                 inputs, labels = inputs.cuda(), labels.cuda()
172 |                 scores = self.forward(inputs)
173 | 
174 |                 for k in range(self.num_classes):
175 |                     indices = (labels == k).nonzero().squeeze(dim=1)
176 |                     class_scores[k].append(torch.sqrt(scores[indices][:, k]))
177 | 
178 |         for k in range(self.num_classes):
179 |             class_scores[k] = torch.cat(class_scores[k], dim=0)
180 |             radii[k] = np.quantile(class_scores[k].cpu().numpy(), 1 - self.epsilon)
181 | 
182 |         self.radii = torch.Tensor(radii).cuda()
183 | 
184 |     def update_centers_and_radii(self, data_loader):
185 |         class_outputs = {i : [] for i in range(self.num_classes)}
186 |         class_scores = {i : [] for i in range(self.num_classes)}
187 |         centers = torch.zeros([self.num_classes, self.latent_size]).cuda()
188 |         radii = np.zeros(self.num_classes)
189 | 
190 |         with torch.no_grad():
191 |             for data in data_loader:
192 |                 inputs, labels = data
193 |                 inputs, labels = inputs.cuda(), labels.cuda()
194 |                 outputs, scores = self._forward(inputs), self.forward(inputs)
195 | 
196 |                 for k in range(self.num_classes):
197 |                     indices = (labels == k).nonzero().squeeze(dim=1)
198 |                     class_outputs[k].append(outputs[indices])
199 |                     class_scores[k].append(torch.sqrt(scores[indices][:, k]))
200 | 
201 |         for k in range(self.num_classes):
202 |             class_outputs[k] = torch.cat(class_outputs[k], dim=0)
203 |             class_scores[k] = torch.cat(class_scores[k], dim=0)
204 |             centers[k] = torch.mean(class_outputs[k], dim=0)
205 |             radii[k] = np.quantile(class_scores[k].cpu().numpy(), 1 - self.epsilon)
206 | 
207 |         self.centers.data = centers
208 |         self.radii = torch.Tensor(radii).cuda()
209 | 
210 |     def _forward(self, x):
211 |         for i, layer in enumerate(self.layers):
212 |             x = layer(x)
213 |         return x
214 | 
215 |     def forward(self, x):
216 |         out = self._forward(x)
217 |         out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
218 |         scores = torch.sum((out - self.centers)**2, dim=2)
219 |         return scores
220 | 
221 | 
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class BasicBlock(nn.Module):
8 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
9 |         super(BasicBlock, self).__init__()
10 |         self.bn1 = nn.BatchNorm2d(in_planes)
11 |         self.relu = nn.ReLU(inplace=True)
12 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
13 |                                padding=1, bias=False)
14 |         self.droprate = dropRate
15 |     def forward(self, x):
16 |         out = self.conv1(self.relu(self.bn1(x)))
17 |         if self.droprate > 0:
18 |             out = F.dropout(out, p=self.droprate, training=self.training)
19 |         return torch.cat([x, out], 1)
20 | 
21 | class BottleneckBlock(nn.Module):
22 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
23 |         super(BottleneckBlock, self).__init__()
24 |         inter_planes = out_planes * 4
25 |         self.bn1 = nn.BatchNorm2d(in_planes)
26 |         self.relu = nn.ReLU(inplace=True)
27 |         self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
28 |                                padding=0, bias=False)
29 |         self.bn2 = nn.BatchNorm2d(inter_planes)
30 |         self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
31 |                                padding=1, bias=False)
32 |         self.droprate = dropRate
33 |     def forward(self, x):
34 |         out = self.conv1(self.relu(self.bn1(x)))
35 |         if self.droprate > 0:
36 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
37 |         out = self.conv2(self.relu(self.bn2(out)))
38 |         if self.droprate > 0:
39 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
40 |         return torch.cat([x, out], 1)
41 | 
42 | class TransitionBlock(nn.Module):
43 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
44 |         super(TransitionBlock, self).__init__()
45 |         self.bn1 = nn.BatchNorm2d(in_planes)
46 |         self.relu = nn.ReLU(inplace=True)
47 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
48 |                                padding=0, bias=False)
49 |         self.droprate = dropRate
50 |     def forward(self, x):
51 |         out = self.conv1(self.relu(self.bn1(x)))
52 |         if self.droprate > 0:
53 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
54 |         return F.avg_pool2d(out, 2)
55 | 
56 | class DenseBlock(nn.Module):
57 |     def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
58 |         super(DenseBlock, self).__init__()
59 |         self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
60 |     def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
61 |         layers = []
62 |         for i in range(int(nb_layers)):
63 |             layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))
64 |         return nn.Sequential(*layers)
65 |     def forward(self, x):
66 |         return self.layer(x)
67 | 
68 | class DenseNet3(nn.Module):
69 |     def __init__(self, depth, num_classes, growth_rate=12,
70 |                  reduction=0.5, bottleneck=True, dropRate=0.0):
71 |         super(DenseNet3, self).__init__()
72 |         in_planes = 2 * growth_rate
73 |         n = (depth - 4) / 3
74 |         if bottleneck == True:
75 |             n = n/2
76 |             block = BottleneckBlock
77 |         else:
78 |             block = BasicBlock
79 |         # 1st conv before any dense block
80 |         self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
81 |                                padding=1, bias=False)
82 |         # 1st block
83 |         self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
84 |         in_planes = int(in_planes+n*growth_rate)
85 |         self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
86 |         in_planes = int(math.floor(in_planes*reduction))
87 |         # 2nd block
88 |         self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
89 |         in_planes = int(in_planes+n*growth_rate)
90 |         self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
91 |         in_planes = int(math.floor(in_planes*reduction))
92 |         # 3rd block
93 |         self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
94 |         in_planes = int(in_planes+n*growth_rate)
95 |         # global average pooling and classifier
96 |         self.bn1 = nn.BatchNorm2d(in_planes)
97 |         self.relu = nn.ReLU(inplace=True)
98 |         self.fc = nn.Linear(in_planes, num_classes)
99 |         self.in_planes = in_planes
100 | 
101 |         for m in self.modules():
102 |             if isinstance(m, nn.Conv2d):
103 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
105 |             elif isinstance(m, nn.BatchNorm2d):
106 |                 m.weight.data.fill_(1)
107 |                 m.bias.data.zero_()
108 |             elif isinstance(m, nn.Linear):
109 |                 m.bias.data.zero_()
110 | 
111 |     def forward(self, x):
112 |         out = self.conv1(x)
113 |         out = self.trans1(self.block1(out))
114 |         out = self.trans2(self.block2(out))
115 |         out = self.block3(out)
116 |         out = self.relu(self.bn1(out))
117 |         out = F.avg_pool2d(out, 8)
118 |         out = out.view(-1, self.in_planes)
119 |         return self.fc(out)
120 | 
121 |     # function to extract the multiple features
122 |     def feature_list(self, x):
123 |         out_list = []
124 |         out = self.conv1(x)
125 |         out_list.append(out)
126 |         out = self.trans1(self.block1(out))
127 |         out_list.append(out)
128 |         out = self.trans2(self.block2(out))
129 |         out_list.append(out)
130 |         out = self.block3(out)
131 |         out = self.relu(self.bn1(out))
132 |         out_list.append(out)
133 |         out = F.avg_pool2d(out, 8)
134 |         out = out.view(-1, self.in_planes)
135 | 
136 |         return self.fc(out), out_list
137 | 
138 |     def intermediate_forward(self, x, layer_index):
139 |         out = self.conv1(x)
140 |         if layer_index == 1:
141 |             out = self.trans1(self.block1(out))
142 |         elif layer_index == 2:
143 |             out = self.trans1(self.block1(out))
144 |             out = self.trans2(self.block2(out))
145 |         elif layer_index == 3:
146 |             out = self.trans1(self.block1(out))
147 |             out = self.trans2(self.block2(out))
148 |             out = self.block3(out)
149 |             out = self.relu(self.bn1(out))
150 |         return out
151 | 
152 |     # function to extract the penultimate features
153 |     def penultimate_forward(self, x):
154 |         out = self.conv1(x)
155 |         out = self.trans1(self.block1(out))
156 |         out = self.trans2(self.block2(out))
157 |         out = self.block3(out)
158 |         penultimate = self.relu(self.bn1(out))
159 |         out = F.avg_pool2d(penultimate, 8)
160 |         out = out.view(-1, self.in_planes)
161 |         return self.fc(out), penultimate
162 | 
163 | class DenseNet_DeepMCDD(nn.Module):
164 |     def __init__(self, latent_size=342, num_classes=10):
165 |         super(DenseNet_DeepMCDD, self).__init__()
166 |         self.latent_size = latent_size
167 |         self.num_classes = num_classes
168 | 
169 |         self.centers = torch.nn.Parameter(torch.zeros([num_classes, latent_size]), requires_grad=True)
170 |         self.alphas = torch.nn.Parameter(torch.ones(num_classes), requires_grad=True)
171 |         self.logsigmas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)
172 |         self.radii = torch.zeros(num_classes).cuda()
173 | 
174 |         self.fe = DenseNet3(100, num_classes)
175 |         self.init_weights()
176 | 
177 |     def load_fe_weights(self, weights):
178 |         self.fe.load_state_dict(weights)
179 | 
180 |     def load_weights(self, weights):
181 |         self.load_state_dict(weights)
182 | 
183 |     def init_weights(self):
184 |         for m in self.modules():
185 |             if isinstance(m, nn.Conv2d):
186 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
187 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
188 |             elif isinstance(m, nn.BatchNorm2d):
189 |                 m.weight.data.fill_(1)
190 |                 m.bias.data.zero_()
191 |             elif isinstance(m, nn.Linear):
192 |                 m.bias.data.zero_()
193 |         nn.init.xavier_uniform_(self.centers)
194 |         nn.init.zeros_(self.alphas)
195 |         nn.init.zeros_(self.logsigmas)
196 | 
197 |     def _forward(self, x):
198 |         out = self.fe.conv1(x)
199 |         out = self.fe.trans1(self.fe.block1(out))
200 |         out = self.fe.trans2(self.fe.block2(out))
201 |         out = self.fe.block3(out)
202 |         out = self.fe.relu(self.fe.bn1(out))
203 |         out = F.avg_pool2d(out, 8)
204 |         out = out.view(-1, self.fe.in_planes)
205 |         return out
206 | 
207 |     def forward(self, x):
208 |         out = self._forward(x)
209 |         out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
210 |         scores = torch.sum((out - self.centers)**2, dim=2) / 2 / torch.exp(2 * F.relu(self.logsigmas)) + self.latent_size * F.relu(self.logsigmas)
211 |         return scores
212 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | BasicBlock and Bottleneck modules are from the original ResNet paper:
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 |     Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | PreActBlock and PreActBottleneck modules are from the later paper:
6 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 |     Identity Mappings in Deep Residual Networks. arXiv:1603.05027
8 | Original code is from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
9 | '''
10 | import os, math
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.nn.parameter import Parameter
15 | 
16 | def conv3x3(in_planes, out_planes, stride=1):
17 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 | 
19 | 
20 | class BasicBlock(nn.Module):
21 |     expansion = 1
22 | 
23 |     def __init__(self, in_planes, planes, stride=1):
24 |         super(BasicBlock, self).__init__()
25 |         self.conv1 = conv3x3(in_planes, planes, stride)
26 |         self.bn1 = nn.BatchNorm2d(planes)
27 |         self.conv2 = conv3x3(planes, planes)
28 |         self.bn2 = nn.BatchNorm2d(planes)
29 | 
30 |         self.shortcut = nn.Sequential()
31 |         if stride != 1 or in_planes != self.expansion*planes:
32 |             self.shortcut = nn.Sequential(
33 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
34 |                 nn.BatchNorm2d(self.expansion*planes)
35 |             )
36 | 
37 |     def forward(self, x):
38 |         out = F.relu(self.bn1(self.conv1(x)))
39 |         out = self.bn2(self.conv2(out))
40 |         out += self.shortcut(x)
41 |         out = F.relu(out)
42 |         return out
43 | 
44 | 
45 | class PreActBlock(nn.Module):
46 |     '''Pre-activation version of the BasicBlock.'''
47 |     expansion = 1
48 | 
49 |     def __init__(self, in_planes, planes, stride=1):
50 |         super(PreActBlock, self).__init__()
51 |         self.bn1 = nn.BatchNorm2d(in_planes)
52 |         self.conv1 = conv3x3(in_planes, planes, stride)
53 |         self.bn2 = nn.BatchNorm2d(planes)
54 |         self.conv2 = conv3x3(planes, planes)
55 | 
56 |         if stride != 1 or in_planes != self.expansion*planes:
57 |             self.shortcut = nn.Sequential(
58 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
59 |             )
60 | 
61 |     def forward(self, x):
62 |         out = F.relu(self.bn1(x))
63 |         shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
64 |         out = self.conv1(out)
65 |         out = self.conv2(F.relu(self.bn2(out)))
66 |         out += shortcut
67 |         return out
68 | 
69 | 
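# note: the PreAct* blocks apply BN and ReLU before each convolution (the identity-mapping
# ordering of [2] in the docstring); ResNet18 below is built from PreActBlock, while
# ResNet34 and the Bottleneck-based variants use the post-activation blocks.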
70 | class Bottleneck(nn.Module):
71 |     expansion = 4
72 | 
73 |     def __init__(self, in_planes, planes, stride=1):
74 |         super(Bottleneck, self).__init__()
75 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
76 |         self.bn1 = nn.BatchNorm2d(planes)
77 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
78 |         self.bn2 = nn.BatchNorm2d(planes)
79 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
80 |         self.bn3 = nn.BatchNorm2d(self.expansion*planes)
81 | 
82 |         self.shortcut = nn.Sequential()
83 |         if stride != 1 or in_planes != self.expansion*planes:
84 |             self.shortcut = nn.Sequential(
85 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
86 |                 nn.BatchNorm2d(self.expansion*planes)
87 |             )
88 | 
89 |     def forward(self, x):
90 |         out = F.relu(self.bn1(self.conv1(x)))
91 |         out = F.relu(self.bn2(self.conv2(out)))
92 |         out = self.bn3(self.conv3(out))
93 |         out += self.shortcut(x)
94 |         out = F.relu(out)
95 |         return out
96 | 
97 | 
98 | class PreActBottleneck(nn.Module):
99 |     '''Pre-activation version of the original Bottleneck module.'''
100 |     expansion = 4
101 | 
102 |     def __init__(self, in_planes, planes, stride=1):
103 |         super(PreActBottleneck, self).__init__()
104 |         self.bn1 = nn.BatchNorm2d(in_planes)
105 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
106 |         self.bn2 = nn.BatchNorm2d(planes)
107 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
108 |         self.bn3 = nn.BatchNorm2d(planes)
109 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
110 | 
111 |         if stride != 1 or in_planes != self.expansion*planes:
112 |             self.shortcut = nn.Sequential(
113 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
114 |             )
115 | 
116 |     def forward(self, x):
117 |         out = F.relu(self.bn1(x))
118 |         shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
119 |         out = self.conv1(out)
120 |         out = self.conv2(F.relu(self.bn2(out)))
121 |         out = self.conv3(F.relu(self.bn3(out)))
122 |         out += shortcut
123 |         return out
124 | 
125 | 
126 | class ResNet(nn.Module):
127 |     def __init__(self, block, num_blocks, num_classes=10):
128 |         super(ResNet, self).__init__()
129 |         self.in_planes = 64
130 | 
131 |         self.conv1 = conv3x3(3,64)
132 |         self.bn1 = nn.BatchNorm2d(64)
133 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
134 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
135 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
136 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
137 |         self.linear = nn.Linear(512*block.expansion, num_classes)
138 | 
139 |         self.init_weights()
140 | 
141 |     def init_weights(self):
142 |         for m in self.modules():
143 |             if isinstance(m, nn.Conv2d):
144 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
145 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
146 |             elif isinstance(m, nn.BatchNorm2d):
147 |                 m.weight.data.fill_(1)
148 |                 m.bias.data.zero_()
149 |             elif isinstance(m, nn.Linear):
150 |                 m.bias.data.zero_()
151 | 
152 |     def _make_layer(self, block, planes, num_blocks, stride):
153 |         strides = [stride] + [1]*(num_blocks-1)
154 |         layers = []
155 |         for stride in strides:
156 |             layers.append(block(self.in_planes, planes, stride))
157 |             self.in_planes = planes * block.expansion
158 |         return nn.Sequential(*layers)
159 | 
160 |     def forward(self, x):
161 |         out = F.relu(self.bn1(self.conv1(x)))
162 |         out = self.layer1(out)
163 |         out = self.layer2(out)
164 |         out = self.layer3(out)
165 |         out = self.layer4(out)
166 |         out = F.avg_pool2d(out, 4)
167 |         out = out.view(out.size(0), -1)
168 |         y = self.linear(out)
169 |         return y
170 | 
171 |     # function to extract the multiple features
172 |     def feature_list(self, x):
173 |         out_list = []
174 |         out = F.relu(self.bn1(self.conv1(x)))
175 |         out_list.append(out)
176 |         out = self.layer1(out)
177 |         out_list.append(out)
178 |         out = self.layer2(out)
179 |         out_list.append(out)
180 |         out = self.layer3(out)
181 |         out_list.append(out)
182 |         out = self.layer4(out)
183 |         out_list.append(out)
184 |         out = F.avg_pool2d(out, 4)
185 |         out = out.view(out.size(0), -1)
186 |         y = self.linear(out)
187 |         return y, out_list
188 | 
189 |     # function to extract a specific feature
190 |     def intermediate_forward(self, x, layer_index):
191 |         out = F.relu(self.bn1(self.conv1(x)))
192 |         if layer_index == 1:
193 |             out = self.layer1(out)
194 |         elif layer_index == 2:
195 |             out = self.layer1(out)
196 |             out = self.layer2(out)
197 |         elif layer_index == 3:
198 |             out = self.layer1(out)
199 |             out = self.layer2(out)
200 |             out = self.layer3(out)
201 |         elif layer_index == 4:
202 |             out = self.layer1(out)
203 |             out = self.layer2(out)
204 |             out = self.layer3(out)
205 |             out = self.layer4(out)
206 |         return out
207 | 
208 |     # function to extract the penultimate features
209 |     def penultimate_forward(self, x):
210 |         out = F.relu(self.bn1(self.conv1(x)))
211 |         out = self.layer1(out)
212 |         out = self.layer2(out)
213 |         out = self.layer3(out)
214 |         penultimate = self.layer4(out)
215 |         out = F.avg_pool2d(penultimate, 4)
216 |         out = out.view(out.size(0), -1)
217 |         y = self.linear(out)
218 |         return y, penultimate
219 | 
220 | def ResNet18(num_c):
221 |     return ResNet(PreActBlock, [2,2,2,2], num_classes=num_c)
222 | 
223 | def ResNet34(num_c):
224 |     return ResNet(BasicBlock, [3,4,6,3], num_classes=num_c)
225 | 
226 | def ResNet50():
227 |     return ResNet(Bottleneck, [3,4,6,3])
228 | 
229 | def ResNet101():
230 |     return ResNet(Bottleneck, [3,4,23,3])
231 | 
232 | def ResNet152():
233 |     return ResNet(Bottleneck, [3,8,36,3])
234 | 
235 | class ResNet_DeepMCDD(nn.Module):
236 |     def __init__(self, block_expansion=1, latent_size=512, num_classes=10):
237 |         super(ResNet_DeepMCDD, self).__init__()
238 |         self.latent_size = latent_size
239 |         self.num_classes = num_classes
240 | 
241 |         self.centers = torch.nn.Parameter(torch.zeros([num_classes, latent_size]), requires_grad=True)
242 |         self.alphas = torch.nn.Parameter(torch.ones(num_classes), requires_grad=True)
243 |         self.logsigmas = torch.nn.Parameter(torch.zeros(num_classes), requires_grad=True)
244 |         self.radii = torch.ones(num_classes).cuda()
245 | 
246 |         self.fe = ResNet34(num_c=num_classes)
247 |         self.init_weights()
248 | 
249 |     def load_fe_weights(self, weights):
250 |         self.fe.load_state_dict(weights)
251 | 
252 |     def load_weights(self, weights):
253 |         self.load_state_dict(weights)
254 | 
255 |     def init_weights(self):
256 |         self.fe.init_weights()
257 |         for m in self.modules():
258 |             if isinstance(m, nn.Conv2d):
259 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
260 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
261 |             elif isinstance(m, nn.BatchNorm2d):
262 |                 m.weight.data.fill_(1)
263 |                 m.bias.data.zero_()
264 |             elif isinstance(m, nn.Linear):
265 |                 m.bias.data.zero_()
266 |         nn.init.xavier_uniform_(self.centers)
267 |         nn.init.zeros_(self.alphas)
268 |         nn.init.zeros_(self.logsigmas)
269 | 
270 |     def _forward(self, x):
271 |         out = F.relu(self.fe.bn1(self.fe.conv1(x)))
272 |         out = self.fe.layer1(out)
273 |         out = self.fe.layer2(out)
274 |         out = self.fe.layer3(out)
275 |         out = self.fe.layer4(out)
276 |         out = F.avg_pool2d(out, 4)
277 |         out = out.view(out.size(0), -1)
278 |         return out
279 | 
280 |     def forward(self, x):
281 |         out = self._forward(x)
282 |         out = out.unsqueeze(dim=1).repeat([1, self.num_classes, 1])
283 |         scores = torch.sum((out - self.centers)**2, dim=2) / 2 / torch.exp(2 * F.relu(self.logsigmas)) + self.latent_size * F.relu(self.logsigmas)
284 |         return scores
285 | 
--------------------------------------------------------------------------------
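
For reference, a minimal inference sketch with hypothetical tensors (an untrained model here; a real run would load trained weights first and, like the rest of this repo, requires a GPU). It mirrors the evaluation loops in the training scripts and `utils.compute_confscores`: the `*_DeepMCDD` models return per-class distances, the ID prediction maximizes the biased score `alphas - dists`, and the OOD confidence is the negated distance to the nearest class center.

```python
import torch
from models import ResNet_DeepMCDD

model = ResNet_DeepMCDD(num_classes=10).cuda().eval()
images = torch.randn(8, 3, 32, 32).cuda()   # hypothetical CIFAR-sized batch

with torch.no_grad():
    dists = model(images)                          # (batch, num_classes) class-conditional distances
    preds = (model.alphas - dists).argmax(dim=1)   # in-distribution class prediction
    confs = -dists.min(dim=1).values               # higher = closer to some class center = more ID
```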