├── datasets ├── __init__.py └── datasetsHelper.py ├── models ├── conv │ ├── __init__.py │ ├── cnn.py │ └── nets.py ├── fc │ ├── __init__.py │ └── nets.py ├── utils │ ├── __init__.py │ ├── loss_functions.py │ ├── score.py │ └── modules.py ├── vae.py └── main_model.py ├── training ├── __init__.py ├── train_classifier.py ├── train_cvae.py └── fine_tune.py ├── CSI+CMG ├── CSI_model │ ├── __init__.py │ ├── base_model.py │ ├── classifier.py │ ├── resnet.py │ ├── resnet_imagenet.py │ └── transform_layers.py ├── models │ ├── conv │ │ ├── __init__.py │ │ ├── cnn.py │ │ └── nets.py │ ├── fc │ │ ├── __init__.py │ │ └── nets.py │ ├── utils │ │ ├── __init__.py │ │ ├── loss_functions.py │ │ ├── score.py │ │ └── modules.py │ └── vae.py ├── utils.py ├── readme.md └── main.py ├── figure └── model.png ├── .gitignore ├── utils.py ├── train.py ├── readme.md └── eval.py /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/conv/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/fc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CSI+CMG/CSI_model/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /CSI+CMG/models/conv/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CSI+CMG/models/fc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CSI+CMG/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figure/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyijia/CMG/HEAD/figure/model.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Development env 2 | .idea 3 | 4 | # Cache files 5 | __pycache__/ 6 | *.pyc -------------------------------------------------------------------------------- /CSI+CMG/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--device', type=str, default='cuda:0', 7 | help='device for training') 8 | parser.add_argument('--params-dict-name', type=str, 9 | help='name of the classifier checkpoint file') 10 | parser.add_argument('--params-dict-name2', type=str, 11 | help='name of the CVAE checkpoint file') 12 | return parser.parse_args() 13 | -------------------------------------------------------------------------------- /models/utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def weighted_average(tensor, weights=None, dim=0): 5 | 
"""Computes weighted average of [tensor] over dimension [dim].""" 6 | 7 | if weights is None: 8 | mean = torch.mean(tensor, dim=dim) 9 | else: 10 | batch_size = tensor.size(dim) if len(tensor.size()) > 0 else 1 11 | assert len(weights) == batch_size 12 | norm_weights = torch.tensor([weight for weight in weights]).to(tensor.device) 13 | mean = torch.mean(norm_weights * tensor, dim=dim) 14 | return mean 15 | -------------------------------------------------------------------------------- /CSI+CMG/models/utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def weighted_average(tensor, weights=None, dim=0): 5 | """Computes weighted average of [tensor] over dimension [dim].""" 6 | if weights is None: 7 | mean = torch.mean(tensor, dim=dim) 8 | else: 9 | batch_size = tensor.size(dim) if len(tensor.size()) > 0 else 1 10 | assert len(weights) == batch_size 11 | norm_weights = torch.tensor([weight for weight in weights]).to(tensor.device) 12 | mean = torch.mean(norm_weights * tensor, dim=dim) 13 | return mean 14 | -------------------------------------------------------------------------------- /models/fc/nets.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | 5 | class MLP(nn.Module): 6 | """ 7 | Module for a multi-layer perceptron (MLP). 
8 | 9 | input: <2D tensor> [batch_size] * [input_dim] 10 | output: <2D tensor> [batch_size] * [classes] 11 | 12 | """ 13 | 14 | def __init__(self, input_dim, classes, latent_dim=512): 15 | super(MLP, self).__init__() 16 | self.input_dim = input_dim 17 | self.classes = classes 18 | self.fc1 = nn.Linear(input_dim, latent_dim) 19 | self.fc2 = nn.Linear(latent_dim, classes) 20 | 21 | def forward(self, x): 22 | x = x.view(-1, self.input_dim) 23 | x = F.relu(self.fc1(x)) 24 | result = self.fc2(x) 25 | 26 | return result 27 | -------------------------------------------------------------------------------- /CSI+CMG/models/fc/nets.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | 5 | class MLP(nn.Module): 6 | """ 7 | Module for a multi-layer perceptron (MLP). 8 | 9 | input: <2D tensor> [batch_size] * [input_dim] 10 | output: <2D tensor> [batch_size] * [classes] 11 | 12 | """ 13 | 14 | def __init__(self, input_dim, classes, latent_dim=512): 15 | super(MLP, self).__init__() 16 | self.input_dim = input_dim 17 | self.classes = classes 18 | self.fc1 = nn.Linear(input_dim, latent_dim) 19 | self.fc2 = nn.Linear(latent_dim, classes) 20 | 21 | def forward(self, x): 22 | x = x.view(-1, self.input_dim) 23 | x = F.relu(self.fc1(x)) 24 | result = self.fc2(x) 25 | 26 | return result 27 | -------------------------------------------------------------------------------- /training/train_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def train_classifier(model, train_loader, device, params_dict_name, dataset='mnist'): 5 | """Trains the IND classifier on the given dataset.""" 6 | 7 | if dataset == 'mnist': 8 | n_epochs = 100 9 | elif dataset == 'cifar10': 10 | n_epochs = 200 11 | else: 12 | raise ValueError 13 | LR = 0.001 14 | model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) 15 | 16 | for epoch 
in range(n_epochs): 17 | for data, target in train_loader: 18 | data = data.to(device) 19 | target = target.long().to(device) 20 | train_loss = model.train_a_batch(data, target) 21 | 22 | print('Epoch: {}, loss = {}'.format(epoch, train_loss)) 23 | 24 | torch.save(model.state_dict(), params_dict_name) 25 | -------------------------------------------------------------------------------- /models/utils/score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from sklearn.metrics import roc_auc_score 4 | 5 | 6 | def softmax_result(fx, y): 7 | """Calculates roc_auc using softmax score of the unseen slot. 8 | 9 | Args: 10 | fx: Last layer output of the model, assumes the unseen slot to be the last one. 11 | y: Class Label, assumes the label of unseen data to be -1. 12 | Returns: 13 | roc_auc: Unseen data as positive, seen data as negative. 14 | """ 15 | score = F.softmax(fx, dim=1)[:, -1] 16 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), score.cpu().detach().numpy()) 17 | 18 | return rocauc 19 | 20 | 21 | def energy_result(fx, y): 22 | """Calculates roc_auc using energy score. 23 | 24 | Args: 25 | fx: Last layer output of the model, assumes the unseen slot to be the last one. 26 | y: Class Label, assumes the label of unseen data to be -1. 27 | Returns: 28 | roc_auc: Unseen data as positive, seen data as negative. 
29 | """ 30 | energy_score = - torch.logsumexp(fx[:, :-1], dim=1) 31 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), energy_score.cpu().detach().numpy()) 32 | 33 | return rocauc 34 | -------------------------------------------------------------------------------- /CSI+CMG/models/utils/score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from sklearn.metrics import roc_auc_score 4 | 5 | 6 | def softmax_result(fx, y): 7 | """Calculates roc_auc using softmax score of the unseen slot. 8 | 9 | Args: 10 | fx: Last layer output of the model, assumes the unseen slot to be the last one. 11 | y: Class Label, assumes the label of unseen data to be -1. 12 | Returns: 13 | roc_auc: Unseen data as positive, seen data as negative. 14 | """ 15 | score = F.softmax(fx, dim=1)[:, -1] 16 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), score.cpu().detach().numpy()) 17 | 18 | return rocauc 19 | 20 | 21 | def energy_result(fx, y): 22 | """Calculates roc_auc using energy score. 23 | 24 | Args: 25 | fx: Last layer output of the model, assumes the unseen slot to be the last one. 26 | y: Class Label, assumes the label of unseen data to be -1. 27 | Returns: 28 | roc_auc: Unseen data as positive, seen data as negative. 
29 | """ 30 | energy_score = - torch.logsumexp(fx[:, :-1], dim=1) 31 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), energy_score.cpu().detach().numpy()) 32 | 33 | return rocauc 34 | -------------------------------------------------------------------------------- /CSI+CMG/readme.md: -------------------------------------------------------------------------------- 1 | # Reproduce CSI+CMG for new SOTA 2 | 3 | ## Requirements 4 | 5 | ### Environments 6 | 7 | The required packages are as follows: 8 | 9 | - python 3.5 10 | - torch 1.2 11 | - torchvision 0.4 12 | - CUDA 10.0 13 | - scikit-learn 0.22 14 | 15 | ### Datasets 16 | 17 | Please download datasets to `./data` and rename the file. Or you may modify the data path in [main.py](main.py). 18 | 19 | ### Checkpoints 20 | 21 | Please download 22 | the [CSI pretrained model](https://drive.google.com/file/d/1rW2-0MJEzPHLb_PAW-LvCivHt-TkDpRO/view?usp=sharing) provided 23 | by [CSI](https://github.com/alinlab/CSI) and save it as `./checkpoint/cifar10_labeled.model`. You can also train your 24 | own model with CSI's code for other settings. 25 | 26 | Also, you need to pretrain a CVAE model with CIFAR10 training data according to CMG stage 1, and save the checkpoint 27 | as `./checkpoint/cvae_10class.pkl`. 
28 | 29 | ## Applying CMG and Evaluations 30 | 31 | To perform CMG tuning on CSI models and get the final result on CIFAR10 (OOD Detection on different datasets), run this 32 | command: 33 | 34 | ``` 35 | python -m main \ 36 | --device {the available GPU in your cluster, e.g., cuda:0} \ 37 | --params-dict-name './checkpoint/cifar10_labeled.model' \ 38 | --params-dict-name2 './checkpoint/cvae_10class.pkl' 39 | ``` 40 | -------------------------------------------------------------------------------- /CSI+CMG/CSI_model/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import * 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class BaseModel(nn.Module, metaclass=ABCMeta): 7 | def __init__(self, last_dim, num_classes=10, simclr_dim=128): 8 | super(BaseModel, self).__init__() 9 | self.linear = nn.Linear(last_dim, num_classes) 10 | self.simclr_layer = nn.Sequential( 11 | nn.Linear(last_dim, last_dim), 12 | nn.ReLU(), 13 | nn.Linear(last_dim, simclr_dim), 14 | ) 15 | self.shift_cls_layer = nn.Linear(last_dim, 2) 16 | self.joint_distribution_layer = nn.Linear(last_dim, 4 * num_classes) 17 | 18 | @abstractmethod 19 | def penultimate(self, inputs, all_features=False): 20 | pass 21 | 22 | def forward(self, inputs, penultimate=False, simclr=False, shift=False, joint=False): 23 | _aux = {} 24 | _return_aux = False 25 | 26 | features = self.penultimate(inputs) 27 | 28 | output = self.linear(features) 29 | 30 | if penultimate: 31 | _return_aux = True 32 | _aux['penultimate'] = features 33 | 34 | if simclr: 35 | _return_aux = True 36 | _aux['simclr'] = self.simclr_layer(features) 37 | 38 | if shift: 39 | _return_aux = True 40 | _aux['shift'] = self.shift_cls_layer(features) 41 | 42 | if joint: 43 | _return_aux = True 44 | _aux['joint'] = self.joint_distribution_layer(features) 45 | 46 | if _return_aux: 47 | return output, _aux 48 | 49 | return output 50 | 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--task', type=str, default='same_dataset_mnist', 7 | choices=['same_dataset_mnist', 'same_dataset_cifar10', 'different_dataset'], 8 | help='the current task: same_dataset_mnist/same_dataset_cifar10/different_dataset') 9 | parser.add_argument('--partition', type=str, default='partition1') 10 | parser.add_argument('--command', type=str, default='train_classifier', 11 | choices=['train_classifier', 'train_cvae'], 12 | help='command for CMG stage 1: train_classifier/train_cvae') 13 | parser.add_argument('--ood-dataset', type=str, 14 | choices=['SVHN', 'LSUN', 'LSUN-FIX', 'tinyImageNet', 'ImageNet-FIX', 'CIFAR100'], 15 | help='OOD dataset for setting 2: SVHN/LSUN/LSUN-FIX/tinyImageNet/ImageNet-FIX/CIFAR100') 16 | parser.add_argument('--mode', type=str, default='CMG-energy', choices=['CMG-softmax', 'CMG-energy'], 17 | help="CMG-softmax/CMG-energy") 18 | parser.add_argument('--device', type=str, default='cuda:0', 19 | help='device for training') 20 | parser.add_argument('--params-dict-name', type=str, 21 | help='name of the classifier checkpoint file') 22 | parser.add_argument('--params-dict-name2', type=str, 23 | help='name of the CVAE checkpoint file') 24 | parser.add_argument('--seed', type=int, default=123, help='set random seed') 25 | return parser.parse_args() 26 | -------------------------------------------------------------------------------- /training/train_cvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.utils import loss_functions as lf 4 | 5 | 6 | def train_cvae(model, train_loader, device, params_dict_name, dataset='mnist'): 7 | """Trains CVAE on the given dataset.""" 8 | 9 | if dataset == 'mnist': 10 | 
n_epochs = 100 11 | elif dataset == 'cifar10': 12 | n_epochs = 200 13 | else: 14 | raise ValueError 15 | LR = 0.001 16 | 17 | optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) 18 | 19 | for epoch in range(n_epochs): 20 | for data, y in train_loader: 21 | optimizer.zero_grad() 22 | data = data.to(device) 23 | y = y.long() 24 | y_onehot = torch.Tensor(y.shape[0], model.class_num) 25 | y_onehot.zero_() 26 | y_onehot.scatter_(1, y.view(-1, 1), 1) 27 | y_onehot = y_onehot.to(device) 28 | mu, logvar, recon = model(data, y_onehot) 29 | 30 | variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) 31 | variatL = lf.weighted_average(variatL, weights=None, dim=0) 32 | variatL /= (model.image_channels * model.image_size * model.image_size) 33 | 34 | data_resize = data.reshape(-1, model.image_channels * model.image_size * model.image_size) 35 | recon_resize = recon.reshape(-1, model.image_channels * model.image_size * model.image_size) 36 | reconL = (data_resize - recon_resize) ** 2 37 | reconL = torch.mean(reconL, 1) 38 | reconL = lf.weighted_average(reconL, weights=None, dim=0) 39 | 40 | loss = variatL + reconL 41 | 42 | loss.backward() 43 | optimizer.step() 44 | 45 | print("epoch: {}, loss = {}, reconL = {}, variaL = {}".format(epoch, loss, reconL, variatL)) 46 | 47 | torch.save(model.state_dict(), params_dict_name) 48 | -------------------------------------------------------------------------------- /models/utils/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | 4 | 5 | class Identity(nn.Module): 6 | """A nn-module to simply pass on the input data.""" 7 | 8 | def forward(self, x): 9 | return x 10 | 11 | def __repr__(self): 12 | tmpstr = self.__class__.__name__ + '()' 13 | return tmpstr 14 | 15 | 16 | class Shape(nn.Module): 17 | """A nn-module to shape a tensor of shape [shape].""" 18 | 19 | def __init__(self, shape): 20 | super().__init__() 21 | 
self.shape = shape 22 | self.dim = len(shape) 23 | 24 | def forward(self, x): 25 | return x.view(*self.shape) 26 | 27 | def __repr__(self): 28 | tmpstr = self.__class__.__name__ + '(shape = {})'.format(self.shape) 29 | return tmpstr 30 | 31 | 32 | class Reshape(nn.Module): 33 | """A nn-module to reshape a tensor(-tuple) to a 4-dim "image"-tensor(-tuple) with [image_channels] channels.""" 34 | 35 | def __init__(self, image_channels): 36 | super().__init__() 37 | self.image_channels = image_channels 38 | 39 | def forward(self, x): 40 | if type(x) == tuple: 41 | batch_size = x[0].size(0) # first dimension should be batch-dimension. 42 | image_size = int(np.sqrt(x[0].nelement() / (batch_size * self.image_channels))) 43 | return (x_item.view(batch_size, self.image_channels, image_size, image_size) for x_item in x) 44 | else: 45 | batch_size = x.size(0) # first dimension should be batch-dimension. 46 | image_size = int(np.sqrt(x.nelement() / (batch_size * self.image_channels))) 47 | return x.view(batch_size, self.image_channels, image_size, image_size) 48 | 49 | def __repr__(self): 50 | tmpstr = self.__class__.__name__ + '(channels = {})'.format(self.image_channels) 51 | return tmpstr 52 | 53 | 54 | class Flatten(nn.Module): 55 | """A nn-module to flatten a multi-dimensional tensor to 2-dim tensor.""" 56 | 57 | def forward(self, x): 58 | batch_size = x.size(0) # first dimension should be batch-dimension. 
59 | return x.contiguous().view(batch_size, -1) 60 | 61 | def __repr__(self): 62 | tmpstr = self.__class__.__name__ + '()' 63 | return tmpstr 64 | -------------------------------------------------------------------------------- /CSI+CMG/models/utils/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | 4 | 5 | class Identity(nn.Module): 6 | """A nn-module to simply pass on the input data.""" 7 | 8 | def forward(self, x): 9 | return x 10 | 11 | def __repr__(self): 12 | tmpstr = self.__class__.__name__ + '()' 13 | return tmpstr 14 | 15 | 16 | class Shape(nn.Module): 17 | """A nn-module to shape a tensor of shape [shape].""" 18 | 19 | def __init__(self, shape): 20 | super().__init__() 21 | self.shape = shape 22 | self.dim = len(shape) 23 | 24 | def forward(self, x): 25 | return x.view(*self.shape) 26 | 27 | def __repr__(self): 28 | tmpstr = self.__class__.__name__ + '(shape = {})'.format(self.shape) 29 | return tmpstr 30 | 31 | 32 | class Reshape(nn.Module): 33 | """A nn-module to reshape a tensor(-tuple) to a 4-dim "image"-tensor(-tuple) with [image_channels] channels.""" 34 | 35 | def __init__(self, image_channels): 36 | super().__init__() 37 | self.image_channels = image_channels 38 | 39 | def forward(self, x): 40 | if type(x) == tuple: 41 | batch_size = x[0].size(0) # first dimension should be batch-dimension. 42 | image_size = int(np.sqrt(x[0].nelement() / (batch_size * self.image_channels))) 43 | return (x_item.view(batch_size, self.image_channels, image_size, image_size) for x_item in x) 44 | else: 45 | batch_size = x.size(0) # first dimension should be batch-dimension. 
46 | image_size = int(np.sqrt(x.nelement() / (batch_size * self.image_channels))) 47 | return x.view(batch_size, self.image_channels, image_size, image_size) 48 | 49 | def __repr__(self): 50 | tmpstr = self.__class__.__name__ + '(channels = {})'.format(self.image_channels) 51 | return tmpstr 52 | 53 | 54 | class Flatten(nn.Module): 55 | """A nn-module to flatten a multi-dimensional tensor to 2-dim tensor.""" 56 | 57 | def forward(self, x): 58 | batch_size = x.size(0) # first dimension should be batch-dimension. 59 | return x.contiguous().view(batch_size, -1) 60 | 61 | def __repr__(self): 62 | tmpstr = self.__class__.__name__ + '()' 63 | return tmpstr 64 | -------------------------------------------------------------------------------- /CSI+CMG/CSI_model/classifier.py: -------------------------------------------------------------------------------- 1 | import CSI_model.transform_layers as TL 2 | import torch.nn as nn 3 | from CSI_model.resnet import ResNet18, ResNet34, ResNet50 4 | from CSI_model.resnet_imagenet import resnet18, resnet50 5 | 6 | 7 | def get_simclr_augmentation(P, image_size): 8 | # parameter for resizecrop 9 | resize_scale = (P.resize_factor, 1.0) # resize scaling factor 10 | if P.resize_fix: # if resize_fix is True, use same scale 11 | resize_scale = (P.resize_factor, P.resize_factor) 12 | 13 | # Align augmentation 14 | color_jitter = TL.ColorJitterLayer(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8) 15 | color_gray = TL.RandomColorGrayLayer(p=0.2) 16 | resize_crop = TL.RandomResizedCropLayer(scale=resize_scale, size=image_size) 17 | 18 | # Transform define # 19 | if P.dataset == 'imagenet': # Using RandomResizedCrop at PIL transform 20 | transform = nn.Sequential( 21 | color_jitter, 22 | color_gray, 23 | ) 24 | else: 25 | transform = nn.Sequential( 26 | color_jitter, 27 | color_gray, 28 | resize_crop, 29 | ) 30 | 31 | return transform 32 | 33 | 34 | def get_shift_module(P, eval=False): 35 | if P.shift_trans_type == 'rotation': 36 | 
shift_transform = TL.Rotation() 37 | K_shift = 4 38 | elif P.shift_trans_type == 'cutperm': 39 | shift_transform = TL.CutPerm() 40 | K_shift = 4 41 | else: 42 | shift_transform = nn.Identity() 43 | K_shift = 1 44 | 45 | if not eval and not ('sup' in P.mode): 46 | assert P.batch_size == int(128 / K_shift) 47 | 48 | return shift_transform, K_shift 49 | 50 | 51 | def get_shift_classifer(model, K_shift): 52 | model.shift_cls_layer = nn.Linear(model.last_dim, K_shift) 53 | 54 | return model 55 | 56 | 57 | def get_classifier(mode, n_classes=10): 58 | if mode == 'resnet18': 59 | classifier = ResNet18(num_classes=n_classes) 60 | elif mode == 'resnet34': 61 | classifier = ResNet34(num_classes=n_classes) 62 | elif mode == 'resnet50': 63 | classifier = ResNet50(num_classes=n_classes) 64 | elif mode == 'resnet18_imagenet': 65 | classifier = resnet18(num_classes=n_classes) 66 | elif mode == 'resnet50_imagenet': 67 | classifier = resnet50(num_classes=n_classes) 68 | else: 69 | raise NotImplementedError() 70 | 71 | return classifier 72 | -------------------------------------------------------------------------------- /CSI+CMG/models/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.conv.nets import ConvLayers, DeconvLayers, ResNet18, DeconvResnet 6 | from models.utils import modules 7 | 8 | 9 | class ConditionalVAE(nn.Module): 10 | def __init__(self, z_dim=32, image_channels=1, image_size=28, class_num=10, dataset='mnist'): 11 | super(ConditionalVAE, self).__init__() 12 | self.class_num = class_num 13 | self.image_size = image_size 14 | self.image_channels = image_channels 15 | 16 | if dataset == 'mnist': 17 | self.convE = ConvLayers(image_channels) 18 | self.convD = DeconvLayers(image_channels) 19 | elif dataset == 'cifar10': 20 | self.convE = ResNet18() 21 | self.convD = DeconvResnet(image_channels) 22 | else: 23 | raise NotImplementedError 
24 | self.flatten = modules.Flatten() 25 | self.fcE = nn.Linear(self.convE.out_feature_dim, 1024) 26 | self.z_dim = z_dim 27 | self.fcE_mean = nn.Linear(1024, self.z_dim) 28 | self.fcE_logvar = nn.Linear(1024, self.z_dim) 29 | self.fromZ = nn.Linear(2 * self.z_dim, 1024) 30 | self.fcD = nn.Linear(1024, self.convD.in_feature_dim) 31 | self.to_image = modules.Reshape(image_channels=self.convE.out_channels) 32 | self.device = None 33 | 34 | self.class_embed = nn.Linear(class_num, self.z_dim) 35 | 36 | def encode(self, x): 37 | hidden_x = self.convE(x) 38 | feature = self.flatten(hidden_x) 39 | 40 | hE = F.relu(self.fcE(feature)) 41 | 42 | z_mean = self.fcE_mean(hE) 43 | z_logvar = self.fcE_logvar(hE) 44 | 45 | return z_mean, z_logvar, hE, hidden_x 46 | 47 | def reparameterize(self, mu, logvar): 48 | std = torch.exp(0.5 * logvar).to(self.device) 49 | z = torch.randn(std.size()).to(self.device) * std + mu.to(self.device) 50 | return z 51 | 52 | def decode(self, z, y_embed): 53 | z = torch.cat([z, y_embed], dim=1) # add label information 54 | hD = F.relu(self.fromZ(z)) 55 | feature = self.fcD(hD) 56 | image_recon = self.convD(feature.view(-1, self.convD.in_channel, self.convD.in_size, self.convD.in_size)) 57 | 58 | return image_recon 59 | 60 | def forward(self, x, y_tensor): 61 | mu, logvar, hE, hidden_x = self.encode(x) 62 | z = self.reparameterize(mu, logvar) 63 | y_embed = self.class_embed(y_tensor) 64 | x_recon = self.decode(z, y_embed) 65 | return mu, logvar, x_recon 66 | -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.conv.nets import ConvLayers, DeconvLayers, ResNet18, DeconvResnet 6 | from models.utils import modules 7 | 8 | 9 | class ConditionalVAE(nn.Module): 10 | """Variational Auto-Encoder with class conditional information""" 11 | 12 
| def __init__(self, z_dim=32, image_channels=1, image_size=28, class_num=10, dataset='mnist'): 13 | super(ConditionalVAE, self).__init__() 14 | self.class_num = class_num 15 | self.image_size = image_size 16 | self.image_channels = image_channels 17 | 18 | if dataset == 'mnist': 19 | self.convE = ConvLayers(image_channels) 20 | self.convD = DeconvLayers(image_channels) 21 | elif dataset == 'cifar10': 22 | self.convE = ResNet18() 23 | self.convD = DeconvResnet(image_channels) 24 | else: 25 | raise NotImplementedError 26 | self.flatten = modules.Flatten() 27 | self.fcE = nn.Linear(self.convE.out_feature_dim, 1024) 28 | self.z_dim = z_dim 29 | self.fcE_mean = nn.Linear(1024, self.z_dim) 30 | self.fcE_logvar = nn.Linear(1024, self.z_dim) 31 | self.fromZ = nn.Linear(2 * self.z_dim, 1024) 32 | self.fcD = nn.Linear(1024, self.convD.in_feature_dim) 33 | self.to_image = modules.Reshape(image_channels=self.convE.out_channels) 34 | self.device = None 35 | 36 | self.class_embed = nn.Linear(class_num, self.z_dim) 37 | 38 | def encode(self, x): 39 | hidden_x = self.convE(x) 40 | feature = self.flatten(hidden_x) 41 | 42 | hE = F.relu(self.fcE(feature)) 43 | 44 | z_mean = self.fcE_mean(hE) 45 | z_logvar = self.fcE_logvar(hE) 46 | 47 | return z_mean, z_logvar, hE, hidden_x 48 | 49 | def reparameterize(self, mu, logvar): 50 | std = torch.exp(0.5 * logvar).to(self.device) 51 | z = torch.randn(std.size()).to(self.device) * std + mu.to(self.device) 52 | return z 53 | 54 | def decode(self, z, y_embed): 55 | z = torch.cat([z, y_embed], dim=1) # add label information 56 | hD = F.relu(self.fromZ(z)) 57 | feature = self.fcD(hD) 58 | image_recon = self.convD(feature.view(-1, self.convD.in_channel, self.convD.in_size, self.convD.in_size)) 59 | 60 | return image_recon 61 | 62 | def forward(self, x, y_tensor): 63 | mu, logvar, hE, hidden_x = self.encode(x) 64 | z = self.reparameterize(mu, logvar) 65 | y_embed = self.class_embed(y_tensor) 66 | x_recon = self.decode(z, y_embed) 67 | return mu, 
logvar, x_recon 68 | -------------------------------------------------------------------------------- /models/conv/cnn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def weights_init(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | m.weight.data.normal_(0.0, 0.05) 8 | elif classname.find('BatchNorm') != -1: 9 | m.weight.data.normal_(1.0, 0.02) 10 | m.bias.data.fill_(0) 11 | 12 | 13 | class ConvEncoder(nn.Module): 14 | """9-layer CNN encoder (feature extractor)""" 15 | 16 | def __init__(self, image_channels=1): 17 | super(self.__class__, self).__init__() 18 | self.conv1 = nn.Conv2d(image_channels, 64, 3, 1, 1, bias=False) 19 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False) 20 | self.conv3 = nn.Conv2d(64, 128, 3, 2, 1, bias=False) 21 | 22 | self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 23 | self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 24 | self.conv6 = nn.Conv2d(128, 128, 3, 2, 1, bias=False) 25 | 26 | self.conv7 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 27 | self.conv8 = nn.Conv2d(128, 128, 3, 1, 1, bias=False) 28 | self.conv9 = nn.Conv2d(128, 128, 3, 2, 1, bias=False) 29 | 30 | self.bn1 = nn.BatchNorm2d(64) 31 | self.bn2 = nn.BatchNorm2d(64) 32 | self.bn3 = nn.BatchNorm2d(128) 33 | 34 | self.bn4 = nn.BatchNorm2d(128) 35 | self.bn5 = nn.BatchNorm2d(128) 36 | self.bn6 = nn.BatchNorm2d(128) 37 | 38 | self.bn7 = nn.BatchNorm2d(128) 39 | self.bn8 = nn.BatchNorm2d(128) 40 | self.bn9 = nn.BatchNorm2d(128) 41 | 42 | self.dr1 = nn.Dropout2d(0.2) 43 | self.dr2 = nn.Dropout2d(0.2) 44 | self.dr3 = nn.Dropout2d(0.2) 45 | 46 | self.apply(weights_init) 47 | self.out_channels = 128 48 | self.out_feature_dim = 128 * 4 * 4 49 | 50 | def forward(self, x): 51 | x = self.dr1(x) 52 | x = self.conv1(x) 53 | x = self.bn1(x) 54 | x = nn.LeakyReLU(0.2)(x) 55 | x = self.conv2(x) 56 | x = self.bn2(x) 57 | x = nn.LeakyReLU(0.2)(x) 58 | x = self.conv3(x) 59 | x = 
def weights_init(m):
    """Initialise convolution and batch-norm layers in place.

    Meant to be passed to ``nn.Module.apply``.  Convolution weights are drawn
    from N(0, 0.05); batch-norm scales are drawn from N(1, 0.02) and their
    biases are zeroed.  Every other module is left untouched.

    Layers are recognised by type instead of by searching the class name for
    ``'Conv'``: ``apply`` also calls the function on the container it was
    invoked on, and a container such as ``ConvEncoder`` matches that substring
    while owning no ``weight`` attribute.
    """
    conv_types = (
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    )
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    if isinstance(m, conv_types):
        m.weight.data.normal_(0.0, 0.05)
    elif isinstance(m, norm_types):
        # Layers built with affine=False have no learnable scale/shift.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)


class ConvEncoder(nn.Module):
    """9-layer CNN encoder (feature extractor).

    Three stages of three bias-free 3x3 convolutions (64-64-128 channels, then
    128 throughout), each followed by batch norm and LeakyReLU(0.2).  The last
    convolution of every stage has stride 2 and every stage is preceded by
    channel dropout (p=0.2).

    Attributes:
        out_channels: channels of the returned feature map (128).
        out_feature_dim: flattened feature size, 128 * 4 * 4.  Three stride-2
            convolutions reduce a 28x28 or a 32x32 input to a 4x4 map.
    """

    def __init__(self, image_channels=1):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(ConvEncoder, self).__init__()
        self.conv1 = nn.Conv2d(image_channels, 64, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
        self.conv3 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)

        self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.conv6 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)

        self.conv7 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.conv8 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.conv9 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)

        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(128)
        self.bn6 = nn.BatchNorm2d(128)

        self.bn7 = nn.BatchNorm2d(128)
        self.bn8 = nn.BatchNorm2d(128)
        self.bn9 = nn.BatchNorm2d(128)

        self.dr1 = nn.Dropout2d(0.2)
        self.dr2 = nn.Dropout2d(0.2)
        self.dr3 = nn.Dropout2d(0.2)

        self.apply(weights_init)
        self.out_channels = 128
        self.out_feature_dim = 128 * 4 * 4

    def forward(self, x):
        """Encode an image batch (N, image_channels, H, W) into a 128-channel feature map."""
        # Functional activation: no module is allocated per call.
        act = nn.functional.leaky_relu

        x = self.dr1(x)
        x = act(self.bn1(self.conv1(x)), 0.2)
        x = act(self.bn2(self.conv2(x)), 0.2)
        x = act(self.bn3(self.conv3(x)), 0.2)

        x = self.dr2(x)
        x = act(self.bn4(self.conv4(x)), 0.2)
        x = act(self.bn5(self.conv5(x)), 0.2)
        x = act(self.bn6(self.conv6(x)), 0.2)

        x = self.dr3(x)
        x = act(self.bn7(self.conv7(x)), 0.2)
        x = act(self.bn8(self.conv8(x)), 0.2)
        x = act(self.bn9(self.conv9(x)), 0.2)

        return x
class MainModel(nn.Module):
    """Image classifier: a convolutional feature extractor followed by an MLP head.

    The extractor is ``ConvEncoder`` for ``dataset='mnist'`` and ``ResNet18``
    for ``dataset='cifar10'``; any other value raises ``NotImplementedError``.
    ``optimizer`` and ``device`` start out as ``None`` and have to be assigned
    by the caller.
    """

    def __init__(self, image_size, image_channels, classes, dataset='mnist'):
        super().__init__()
        self.image_size = image_size
        self.image_channels = image_channels
        self.classes = classes

        # Feature extractor, chosen per dataset.
        if dataset == 'cifar10':
            self.convE = ResNet18()
        elif dataset == 'mnist':
            self.convE = ConvEncoder(image_channels=image_channels)
        else:
            raise NotImplementedError
        self.flatten = modules.Flatten()

        # Classification head, sized from the extractor's reported feature size.
        self.classifier = MLP(self.convE.out_feature_dim, classes)

        # Needs to be set before training starts.
        self.optimizer = None

        # Needs to be set before using the model.
        self.device = None

    # --------- FORWARD FUNCTIONS ---------#
    def encode(self, x):
        """Run [x] through the feature extractor and return the image features."""
        return self.convE(x)

    def classify(self, x):
        """Return the classifier scores for already extracted features [x]."""
        return self.classifier(x)

    def forward(self, x):
        """Propagate [x] through the encoder and then the classifier."""
        features = self.encode(x)
        return self.classifier(features)

    # ------------------TRAINING FUNCTIONS----------------------#
    def train_a_batch(self, x, y):
        """Run one optimisation step on the batch ([x], [y]) and return the loss as a float."""
        # Training mode, fresh gradients.
        self.train()
        self.optimizer.zero_grad()

        # Per-sample cross-entropy, averaged over the batch dimension.
        scores = self.classifier(self.encode(x))
        per_sample = F.cross_entropy(scores, y, reduction='none')
        loss = lf.weighted_average(per_sample, weights=None, dim=0)

        loss.backward()
        self.optimizer.step()

        return loss.item()
def get_subclass_dataset(dataset, classes):
    """Return the ``Subset`` of ``dataset`` whose labels are in ``classes``.

    Args:
        dataset: iterable, indexable dataset whose items carry the label at
            position 1.
        classes: a single label or a list of labels to keep.

    Every item is fetched once, so the label that is compared is the one the
    dataset actually yields (i.e. after any target transform it applies).
    """
    if not isinstance(classes, list):
        classes = [classes]
    indices = [idx for idx, data in enumerate(dataset) if data[1] in classes]
    return Subset(dataset, indices)


def _class_order(seen):
    """Return the labels 0-9 with the digits of ``seen`` first and the rest in ascending order."""
    order = [int(num) for num in seen]
    remaining = [i for i in range(10) if i not in order]
    return order + remaining


def get_dataset(dataset, train_transform, test_transform, download=False, seen=None):
    """Get datasets for setting 1 (OOD Detection on the Same Dataset).

    Labels are remapped to their position in the list "digits of ``seen``,
    then the remaining classes".  Remapped labels 0-5 form the seen split and
    6-9 the unseen split, so ``seen`` is expected to name six classes.

    Args:
        dataset: ``'cifar10'`` or ``'mnist'``.
        train_transform: transform for the training set.
        test_transform: transform for the test set.
        download: forwarded to the torchvision dataset constructor.
        seen: string (or iterable) of class digits, e.g. ``'012345'``.

    Returns:
        ``(train_set, test_set_seen, test_set_unseen)``; the training set
        contains seen classes only.

    Raises:
        NotImplementedError: for an unsupported ``dataset``.
        ValueError: if ``seen`` is not given.
    """
    if dataset not in ('cifar10', 'mnist'):
        raise NotImplementedError
    if seen is None:
        raise ValueError("`seen` must list the digits of the seen classes, e.g. '012345'")

    dataset_cls = datasets.CIFAR10 if dataset == 'cifar10' else datasets.MNIST
    # The data folder is named after the dataset ('cifar10' / 'mnist').
    data_path = os.path.join(root_path, dataset)

    class_idx = _class_order(seen)
    # A bound method instead of a lambda: same mapping, but picklable, which
    # DataLoader worker processes require under the "spawn" start method.
    relabel = class_idx.index

    train_set = dataset_cls(data_path, train=True, download=download,
                            transform=train_transform, target_transform=relabel)
    test_set = dataset_cls(data_path, train=False, download=download,
                           transform=test_transform, target_transform=relabel)

    seen_class_idx = [0, 1, 2, 3, 4, 5]
    unseen_class_idx = [6, 7, 8, 9]

    train_set = get_subclass_dataset(train_set, seen_class_idx)
    test_set_seen = get_subclass_dataset(test_set, seen_class_idx)
    test_set_unseen = get_subclass_dataset(test_set, unseen_class_idx)

    return train_set, test_set_seen, test_set_unseen


def get_ood_dataset(dataset):
    """Get datasets for setting 2 (OOD Detection on Different Datasets).

    ``'SVHN'`` and ``'CIFAR100'`` are downloaded through torchvision (test
    split); ``'LSUN'``, ``'tinyImageNet'``, ``'LSUN-FIX'`` and
    ``'ImageNet-FIX'`` are read with ``ImageFolder`` from a folder next to
    this module.  All of them use ``ToTensor`` as their only transform.

    Raises:
        NotImplementedError: for any other name.
    """
    image_folders = {
        'LSUN': 'LSUN_resize',
        'tinyImageNet': 'Imagenet_resize',
        'LSUN-FIX': 'LSUN_fix',
        'ImageNet-FIX': 'Imagenet_fix',
    }

    if dataset == 'SVHN':
        data_dir = os.path.join(root_path, 'svhn')
        return datasets.SVHN(root=data_dir, split='test', download=True,
                             transform=transforms.ToTensor())
    if dataset == 'CIFAR100':
        data_dir = os.path.join(root_path, 'cifar100')
        return datasets.CIFAR100(root=data_dir, train=False,
                                 transform=transforms.ToTensor(), download=True)
    if dataset in image_folders:
        data_dir = os.path.join(root_path, image_folders[dataset])
        return datasets.ImageFolder(data_dir, transform=transforms.ToTensor())
    raise NotImplementedError
class ConvLayers(nn.Module):
    """Convolutional feature extractor model for (natural) images.

    Two 4x4, stride-2 convolutions with padding 15, each followed by ReLU.
    ``out_feature_dim`` (128 * 28 * 28) is the flattened output size for a
    28x28 input.
    """

    def __init__(self, image_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=(4, 4), stride=2, padding=(15, 15))
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=2, padding=(15, 15))
        self.out_channels = 128
        self.out_feature_dim = 128 * 28 * 28

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))


class DeconvLayers(nn.Module):
    """'Deconvolutional' feature decoder model for (natural) images.

    Two transposed convolutions, each doubling the spatial size; the output is
    squashed to (0, 1) with a sigmoid.
    """

    def __init__(self, image_channels):
        super().__init__()
        self.image_channels = image_channels
        self.in_channel = 128
        self.in_size = 7
        self.in_feature_dim = 7 * 7 * 128
        self.deconv1 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=4, padding=1, stride=2)
        self.deconv2 = nn.ConvTranspose2d(
            in_channels=64, out_channels=self.image_channels, kernel_size=4, padding=1, stride=2)

    def forward(self, x):
        hidden = F.relu(self.deconv1(x))
        return torch.sigmoid(self.deconv2(hidden))


# ---------------------------------------------------------------------------------------------------


class ResBlock(nn.Module):
    """
    Input: [batch_size] x [dim] x [image_size] x [image_size] tensor
    Output: [batch_size] x [dim] x [image_size] x [image_size] tensor
    """

    def __init__(self, dim):
        super().__init__()
        # NOTE: the leading ReLU is in-place, so it rectifies the input tensor
        # itself before the residual addition in forward().
        stages = [
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 1),
            nn.BatchNorm2d(dim),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return x + self.block(x)


class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a (possibly projected) skip connection."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # 1x1 projection so the shortcut matches the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet trunk without a classification layer.

    forward() ends with a 4x4 average pooling and returns the pooled feature
    map; ``out_feature_dim`` is 512 * block.expansion.
    """

    def __init__(self, block, num_blocks, channel_num=3):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(channel_num, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.out_channels = 512
        self.out_feature_dim = 512 * block.expansion

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return F.avg_pool2d(out, 4)


def ResNet18(channel_num=3):
    """Build a ``ResNet`` of ``BasicBlock``s with two blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2], channel_num=channel_num)


class DeconvResnet(nn.Module):
    """'Deconvolutional' feature decoder model for (natural) images using ResBlock as the backbone"""

    def __init__(self, channel_num, dim=512):
        super().__init__()
        self.image_channels = channel_num
        self.in_channel = dim
        self.in_size = 4
        self.in_feature_dim = dim * 4 * 4
        stages = [
            ResBlock(dim),
            ResBlock(dim),

            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),

            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),

            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, channel_num, 4, 2, 1),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        decoded = self.decoder(x)
        return torch.sigmoid(decoded)
class ConvLayers(nn.Module):
    """Convolutional feature extractor model for (natural) images.

    Two 4x4, stride-2 convolutions with padding 15, each followed by ReLU.
    ``out_feature_dim`` (128 * 28 * 28) is the flattened output size for a
    28x28 input.
    """

    def __init__(self, image_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=(4, 4), stride=2, padding=(15, 15))
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=2, padding=(15, 15))
        self.out_channels = 128
        self.out_feature_dim = 128 * 28 * 28

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))


class DeconvLayers(nn.Module):
    """'Deconvolutional' feature decoder model for (natural) images.

    Two transposed convolutions, each doubling the spatial size; the output is
    squashed to (0, 1) with a sigmoid.
    """

    def __init__(self, image_channels):
        super().__init__()
        self.image_channels = image_channels
        self.in_channel = 128
        self.in_size = 7
        self.in_feature_dim = 7 * 7 * 128
        self.deconv1 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=4, padding=1, stride=2)
        self.deconv2 = nn.ConvTranspose2d(
            in_channels=64, out_channels=self.image_channels, kernel_size=4, padding=1, stride=2)

    def forward(self, x):
        hidden = F.relu(self.deconv1(x))
        return torch.sigmoid(self.deconv2(hidden))


# ---------------------------------------------------------------------------------------------------


class ResBlock(nn.Module):
    """
    Input: [batch_size] x [dim] x [image_size] x [image_size] tensor
    Output: [batch_size] x [dim] x [image_size] x [image_size] tensor
    """

    def __init__(self, dim):
        super().__init__()
        # NOTE: the leading ReLU is in-place, so it rectifies the input tensor
        # itself before the residual addition in forward().
        stages = [
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 1),
            nn.BatchNorm2d(dim),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return x + self.block(x)


class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a (possibly projected) skip connection."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # 1x1 projection so the shortcut matches the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet trunk without a classification layer.

    forward() ends with a 4x4 average pooling and returns the pooled feature
    map; ``out_feature_dim`` is 512 * block.expansion.
    """

    def __init__(self, block, num_blocks, channel_num=3):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(channel_num, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.out_channels = 512
        self.out_feature_dim = 512 * block.expansion

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return F.avg_pool2d(out, 4)


def ResNet18(channel_num=3):
    """Build a ``ResNet`` of ``BasicBlock``s with two blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2], channel_num=channel_num)


class DeconvResnet(nn.Module):
    """'Deconvolutional' feature decoder model for (natural) images using ResBlock as the backbone"""

    def __init__(self, channel_num, dim=512):
        super().__init__()
        self.image_channels = channel_num
        self.in_channel = dim
        self.in_size = 4
        self.in_feature_dim = dim * 4 * 4
        stages = [
            ResBlock(dim),
            ResBlock(dim),

            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),

            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),

            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, channel_num, 4, 2, 1),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        decoded = self.decoder(x)
        return torch.sigmoid(decoded)
args = get_args()


def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``, and make cuDNN deterministic."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(args.seed)

batch_size = 512

# gpu
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(device)

# Digits of the seen classes for every partition of the same-dataset setting.
_PARTITION_SEEN = {
    'partition1': '012345',
    'partition2': '123456',
    'partition3': '234567',
    'partition4': '345678',
    'partition5': '456789',
}

# task -> (image_size, image_channels, number of classifier outputs, dataset name)
_TASK_SPECS = {
    'same_dataset_mnist': (28, 1, 11, 'mnist'),
    'same_dataset_cifar10': (32, 3, 11, 'cifar10'),
    'different_dataset': (32, 3, 110, 'cifar10'),
}

_COMMANDS = ('train_classifier', 'train_cvae')


def _seen_classes(partition):
    """Return the seen-class digits of ``partition`` or raise ``NotImplementedError``."""
    try:
        return _PARTITION_SEEN[partition]
    except KeyError:
        raise NotImplementedError


def _cifar_train_transform(command):
    """Return the CIFAR-10 training transform: augmented for the classifier, plain for the CVAE."""
    if command == 'train_classifier':
        return transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if command == 'train_cvae':
        return transforms.ToTensor()
    raise NotImplementedError


def main():
    """Run CMG stage 1: train the IND classifier or the CVAE selected by ``args``.

    Raises:
        NotImplementedError: for an unknown task, partition or command.
    """
    # Reject an unknown command up front.  Previously the CIFAR tasks reached
    # an unbound `train_transform` (UnboundLocalError) and the MNIST task
    # loaded the whole dataset before failing.
    if args.command not in _COMMANDS:
        raise NotImplementedError

    # prepare data
    if args.task == 'same_dataset_mnist':
        train_data, _, _ = get_dataset('mnist', transforms.ToTensor(), transforms.ToTensor(),
                                       seen=_seen_classes(args.partition))
    elif args.task == 'same_dataset_cifar10':
        train_transform = _cifar_train_transform(args.command)
        train_data, _, _ = get_dataset('cifar10', train_transform=train_transform,
                                       test_transform=transforms.ToTensor(),
                                       seen=_seen_classes(args.partition))
    elif args.task == 'different_dataset':
        train_transform = _cifar_train_transform(args.command)
        root_path = os.path.dirname(__file__)
        data_path = os.path.join(root_path, 'datasets/cifar10')
        train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
    else:
        raise NotImplementedError

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)

    # CMG stage 1
    image_size, image_channels, classes, dataset = _TASK_SPECS[args.task]
    if args.command == 'train_classifier':
        model = MainModel(image_size, image_channels, classes, dataset=dataset)
        model.to(device)
        train_classifier(model, train_loader, device, args.params_dict_name, dataset=dataset)
    else:  # 'train_cvae'
        model = ConditionalVAE(image_channels=image_channels, image_size=image_size, dataset=dataset)
        model.device = device
        model.to(device)
        train_cvae(model, train_loader, device, args.params_dict_name, dataset=dataset)
model.to(device) 117 | train_cvae(model, train_loader, device, args.params_dict_name, dataset='mnist') 118 | elif args.task == 'same_dataset_cifar10': 119 | model = ConditionalVAE(image_channels=3, image_size=32, dataset='cifar10') 120 | model.device = device 121 | model.to(device) 122 | train_cvae(model, train_loader, device, args.params_dict_name, dataset='cifar10') 123 | elif args.task == 'different_dataset': 124 | model = ConditionalVAE(image_channels=3, image_size=32, dataset='cifar10') 125 | model.device = device 126 | model.to(device) 127 | train_cvae(model, train_loader, device, args.params_dict_name, dataset='cifar10') 128 | 129 | else: 130 | raise NotImplementedError 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # CMG: A Class-Mixed Generation Approach to Out-of-Distribution Detection 2 | 3 | This repository contains the code for our ECML'22 paper [CMG: A Class-Mixed Generation Approach to Out-of-Distribution 4 | Detection](https://2022.ecmlpkdd.org/wp-content/uploads/2022/09/sub_531.pdf) by Mengyu Wang*, Yijia Shao*, Haowei Lin, Wenpeng Hu, and Bing Liu. 5 | 6 | ## Overview 7 | 8 | We propose CMG (*Class-Mixed Generation*) for efficient out-of-distribution (OOD) detection. CMG uses CVAE to generate 9 | pseudo-OOD training samples based on class-mixed information in the latent space. The pseudo-OOD data are used to 10 | fine-tune a classifier in a 2-stage manner. By using different loss functions, we propose two versions of the CMG 11 | system, CMG-softmax (CMG-s) and CMG-energy (CMG-e). The figure below illustrates our CMG system. 
12 | 13 | model 14 | 15 | ## Requirements 16 | 17 | ### Environments 18 | 19 | The required packages are as follows: 20 | 21 | - python 3.5 22 | - torch 1.2 23 | - torchvision 0.4 24 | - CUDA 10.0 25 | - scikit-learn 0.22 26 | 27 | ### Datasets 28 | 29 | For Setting 1 - OOD Detection on the Same Dataset, this repository supports MNIST and CIFAR10 which can be directly 30 | downloaded through `torchvision`. 31 | 32 | For Setting 2 - OOD Detection on Different Datasets, this repository supports CIFAR10 as IND data and SVHN / LSUN / 33 | tinyImageNet / LSUN-FIX / ImageNet-FIX / CIFAR-100 as OOD data. SVHN and CIFAR 100 can be directly downloaded 34 | through `torchvision`. The remaining data have been processed by the CSI paper and you can download 35 | them [here](https://github.com/alinlab/CSI). 36 | 37 | Please download datasets to [./datasets](./datasets) and rename the file. See [./datasets/datasetsHelper.py](./datasets/datasetsHelper.py) and our paper for more details. 38 | 39 | ## Training CMG 40 | 41 | ### CMG Stage 1 42 | 43 | CMG Stage 1 involves IND classifier building and CVAE training. 44 | 45 | The standard code is running with a single GPU and you can assign a specific GPU in the command line. 46 | 47 | For more details, please view [utils.py](./utils.py). 
48 | 49 | #### Train IND classifier 50 | 51 | To train IND classifier on MNIST for Setting 1, run this command: 52 | 53 | ``` 54 | python -m train \ 55 | --task 'same_dataset_mnist' \ 56 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \ 57 | --command 'train_classifier' \ 58 | --device {the available GPU in your cluster, e.g., cuda:0} \ 59 | --params-dict-name {checkpoint name, e.g., './ckpt/main_model_partition1.pkl'} 60 | ``` 61 | 62 | To train IND classifier on CIFAR10 for Setting 1, run this command: 63 | 64 | ``` 65 | python -m train \ 66 | --task 'same_dataset_cifar10' \ 67 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \ 68 | --command 'train_classifier' \ 69 | --device {the available GPU in your cluster, e.g., cuda:0} \ 70 | --params-dict-name {checkpoint name, e.g., './ckpt/main_model_partition1.pkl'} 71 | ``` 72 | 73 | To train IND classifier on CIFAR10 for Setting 2, run this command: 74 | 75 | ``` 76 | python -m train \ 77 | --task 'different_dataset' \ 78 | --command 'train_classifier' \ 79 | --device {the available GPU in your cluster, e.g., cuda:0} \ 80 | --params-dict-name {checkpoint name, e.g., './ckpt/main_model_different_dataset.pkl'} 81 | ``` 82 | 83 | #### Train CVAE 84 | 85 | To train CVAE on MNIST for Setting 1, run this command: 86 | 87 | ``` 88 | python -m train \ 89 | --task 'same_dataset_mnist' \ 90 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \ 91 | --command 'train_cvae' \ 92 | --device {the available GPU in your cluster, e.g., cuda:0} \ 93 | --params-dict-name {checkpoint name, e.g., './ckpt/cvae_partition1.pkl'} 94 | ``` 95 | 96 | To train CVAE on CIFAR10 for Setting 1, run this command: 97 | 98 | ``` 99 | python -m train \ 100 | --task 'same_dataset_cifar10' \ 101 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \ 102 | --command 'train_cvae' \ 103 | --device {the available GPU in your cluster, e.g., 
cuda:0} \ 104 | --params-dict-name {checkpoint name, e.g., './ckpt/cvae_partition1.pkl'} 105 | ``` 106 | 107 | To train CVAE on CIFAR10 for Setting 2, run this command: 108 | 109 | ``` 110 | python -m train \ 111 | --task 'different_dataset' \ 112 | --command 'train_cvae' \ 113 | --device {the available GPU in your cluster, e.g., cuda:0} \ 114 | --params-dict-name {checkpoint name, e.g., './ckpt/cvae_different_dataset.pkl'} 115 | ``` 116 | 117 | 118 | ### CMG Stage 2 and Evaluation 119 | 120 | To perform CMG Stage 2 and get the final result on MNIST for Setting 1, run this command: 121 | 122 | ``` 123 | python -m eval \ 124 | --task 'same_dataset_mnist' \ 125 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \ 126 | --device {the available GPU in your cluster, e.g., cuda:0} \ 127 | --params-dict-name {main model checkpoint name} \ 128 | --params-dict-name2 {cvae checkpoint name} \ 129 | --mode {'CMG-energy'/'CMG-softmax'} 130 | ``` 131 | 132 | To perform CMG Stage 2 and get the final result on CIFAR10 for Setting 1, run this command: 133 | 134 | ``` 135 | python -m eval \ 136 | --task 'same_dataset_cifar10' \ 137 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \ 138 | --device {the available GPU in your cluster, e.g., cuda:0} \ 139 | --params-dict-name {main model checkpoint name} \ 140 | --params-dict-name2 {cvae checkpoint name} \ 141 | --mode {'CMG-energy'/'CMG-softmax'} 142 | ``` 143 | 144 | To perform CMG Stage 2 and get the final result on CIFAR10 for Setting 2, run this command: 145 | 146 | ``` 147 | python -m eval \ 148 | --task 'different_dataset' \ 149 | --ood-dataset {'SVHN'/'LSUN'/'LSUN-FIX'/'tinyImageNet'/'ImageNet-FIX'/'CIFAR100'} \ 150 | --device {the available GPU in your cluster, e.g., cuda:0} \ 151 | --params-dict-name {main model checkpoint name} \ 152 | --params-dict-name2 {cvae checkpoint name} \ 153 | --mode {'CMG-energy'/'CMG-softmax'} 154 | ``` 155 | 156 | ## Apply CMG to CSI 157 | 
158 | CMG is a training paradigm orthogonal to existing OOD detection models and can enhance existing systems to further 159 | improve their performance (please see our paper for details). 160 | 161 | Here we also provide code to reproduce the new SOTA by applying CMG to [CSI](https://github.com/alinlab/CSI). 162 | See [./CSI+CMG](./CSI+CMG). 163 | 164 | ## Acknowledgements 165 | 166 | We thank [CSI](https://github.com/alinlab/CSI) for providing download links for their processed data and our "CSI+CMG" 167 | code is also based on their implementation. 168 | 169 | ## Citation 170 | 171 | Please cite our paper if you use this code or parts of it: 172 | ``` 173 | @article{wangcmg, 174 | title={CMG: A Class-Mixed Generation Approach to Out-of-Distribution Detection}, 175 | author={Wang, Mengyu and Shao, Yijia and Lin, Haowei and Hu, Wenpeng and Liu, Bing} } 176 | ``` 177 | -------------------------------------------------------------------------------- /CSI+CMG/CSI_model/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | BasicBlock and Bottleneck module is from the original ResNet paper: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | PreActBlock and PreActBottleneck module is from the later paper: 6 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 8 | ''' 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from models_copy.base_model import BaseModel 13 | from models_copy.transform_layers import NormalizeLayer 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, in_planes, planes, stride=1): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(in_planes, planes, stride) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | 30 | self.shortcut = nn.Sequential() 31 | if stride != 1 or in_planes != self.expansion * planes: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(self.expansion * planes) 35 | ) 36 | 37 | def forward(self, x): 38 | out = F.relu(self.bn1(self.conv1(x))) 39 | out = self.bn2(self.conv2(out)) 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | '''Pre-activation version of the BasicBlock.''' 47 | expansion = 1 48 | 49 | def __init__(self, in_planes, planes, stride=1): 50 | super(PreActBlock, self).__init__() 51 | self.conv1 = conv3x3(in_planes, planes, stride) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn1 = nn.BatchNorm2d(in_planes) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion * planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | out += shortcut 68 | return 
class Bottleneck(nn.Module):
    '''Residual bottleneck block: 1x1 conv, 3x3 conv (strided), 1x1 conv.

    Each convolution is followed by batch norm. The block outputs
    ``expansion * planes`` channels. The skip path is a 1x1 convolution plus
    batch norm when the stride is not 1 or the channel count changes, and an
    empty ``nn.Sequential`` (identity) otherwise.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Only project the skip connection when its shape would not match.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        '''Return ``relu(main_path(x) + shortcut(x))``.'''
        features = x
        # The first two stages are conv -> bn -> relu; the third has no relu
        # until after the residual addition.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = F.relu(norm(conv(features)))
        features = self.bn3(self.conv3(features))
        features = features + self.shortcut(x)
        return F.relu(features)
class ResNet(BaseModel):
    '''ResNet backbone with a 3x3 stem and four residual stages.

    ``penultimate`` returns the pooled, flattened feature vector; the width of
    that vector (``512 * block.expansion``) is handed to ``BaseModel`` together
    with ``num_classes``.
    '''

    def __init__(self, block, num_blocks, num_classes=10):
        last_dim = 512 * block.expansion
        super().__init__(last_dim, num_classes)

        # Running input width for the next block built by _make_layer.
        self.in_planes = 64
        self.last_dim = last_dim

        self.normalize = NormalizeLayer()

        # Stem: one 3x3 convolution from 3 input channels to 64.
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)

        depth1, depth2, depth3, depth4 = num_blocks[0], num_blocks[1], num_blocks[2], num_blocks[3]
        self.layer1 = self._make_layer(block, 64, depth1, stride=1)
        self.layer2 = self._make_layer(block, 128, depth2, stride=2)
        self.layer3 = self._make_layer(block, 256, depth3, stride=2)
        self.layer4 = self._make_layer(block, 512, depth4, stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage: the first block uses ``stride``, the rest use stride 1.'''
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def penultimate(self, x, all_features=False):
        '''Return the pooled feature vector for ``x``.

        When ``all_features`` is true, also return the list of intermediate
        feature maps: the stem output followed by the output of each of the
        four stages.
        '''
        features = F.relu(self.bn1(self.conv1(self.normalize(x))))
        collected = [features]

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
            collected.append(features)

        # Fixed 4x4 pooling window -- presumably sized for 32x32 inputs; confirm.
        pooled = F.avg_pool2d(features, 4)
        pooled = pooled.view(pooled.size(0), -1)

        if all_features:
            return pooled, collected
        return pooled
def _build_loaders(train_data, test_data_seen, test_data_unseen):
    """Wrap the three datasets in loaders (shuffled training, fixed-order 512-sized test batches)."""
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
    test_loader_seen = DataLoader(test_data_seen, batch_size=512, num_workers=8)
    test_loader_unseen = DataLoader(test_data_unseen, batch_size=512, num_workers=8)
    return train_loader, test_loader_seen, test_loader_unseen


def _load_models(image_size, image_channels, num_outputs, dataset):
    """Build the classifier and CVAE, load both checkpoints on CPU, then move them to ``device``."""
    classifier = MainModel(image_size, image_channels, num_outputs, dataset=dataset)
    classifier.load_state_dict(torch.load(args.params_dict_name, map_location='cpu'))
    classifier.to(device)
    vae = ConditionalVAE(image_channels=image_channels, image_size=image_size, dataset=dataset)
    vae.load_state_dict(torch.load(args.params_dict_name2, map_location='cpu'))
    vae.to(device)
    vae.device = device
    return classifier, vae


def main():
    """Run CMG stage 2 for the task selected on the command line.

    Reads ``args.task``, ``args.partition``, ``args.ood_dataset``, ``args.mode``
    and the two checkpoint paths, fine-tunes the classifier and prints the best
    ROC AUC returned by the fine-tuning routine.

    Raises:
        NotImplementedError: If ``args.task`` or ``args.partition`` is not one
            of the supported values.
    """
    # Seen-class digits for every supported partition of the same-dataset setting.
    seen_by_partition = {
        'partition1': '012345',
        'partition2': '123456',
        'partition3': '234567',
        'partition4': '345678',
        'partition5': '456789',
    }
    # task name -> (dataset name, image size, image channels)
    same_dataset_tasks = {
        'same_dataset_mnist': ('mnist', 28, 1),
        'same_dataset_cifar10': ('cifar10', 32, 3),
    }

    if args.task in same_dataset_tasks:
        dataset, image_size, image_channels = same_dataset_tasks[args.task]

        # prepare data
        if args.partition not in seen_by_partition:
            raise NotImplementedError('unsupported partition: {!r}'.format(args.partition))
        train_data, test_data_seen, test_data_unseen = get_dataset(
            dataset, transforms.ToTensor(), transforms.ToTensor(), seen=seen_by_partition[args.partition])
        train_loader, test_loader_seen, test_loader_unseen = _build_loaders(
            train_data, test_data_seen, test_data_unseen)

        # prepare model
        classifier, vae = _load_models(image_size, image_channels, 11, dataset)

        # CMG Stage 2
        result = fine_tune_same_dataset(
            classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, dataset=dataset,
            mode=args.mode)

        print('{}, same dataset {} {}: max roc auc = {}'.format(args.mode, dataset, args.partition, result))

    elif args.task == 'different_dataset':
        # prepare data
        root_path = os.path.dirname(__file__)
        data_path = os.path.join(root_path, 'datasets/cifar10')
        train_data = datasets.CIFAR10(root=data_path, train=True, transform=transforms.ToTensor())
        test_data_seen = datasets.CIFAR10(root=data_path, train=False, transform=transforms.ToTensor())
        test_data_unseen = get_ood_dataset(args.ood_dataset)
        train_loader, test_loader_seen, test_loader_unseen = _build_loaders(
            train_data, test_data_seen, test_data_unseen)

        # prepare model
        classifier, vae = _load_models(32, 3, 110, 'cifar10')

        # CMG Stage 2
        result = fine_tune_different_dataset(
            classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, mode=args.mode)

        print('{}, different dataset, ood dataset {}: max roc auc = {}'.format(args.mode, args.ood_dataset, result))

    else:
        raise NotImplementedError('unsupported task: {!r}'.format(args.task))
base_width != 64: 28 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 29 | if dilation > 1: 30 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 31 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = norm_layer(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = norm_layer(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | identity = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 61 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 62 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 63 | # This variant is also known as ResNet V1.5 and improves accuracy according to 64 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
class ResNet(BaseModel):
    """ResNet backbone with a 7x7 strided stem, max pooling and four residual stages.

    ``penultimate`` returns the globally pooled, flattened feature vector; its
    width (``512 * block.expansion``) is handed to ``BaseModel`` together with
    ``num_classes``.
    """

    def __init__(self, block, layers, num_classes=10,
                 zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        last_dim = 512 * block.expansion
        super().__init__(last_dim, num_classes)

        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1

        # One flag per strided stage (layer2..layer4): swap the stride for dilation.
        dilate_flags = replace_stride_with_dilation
        if dilate_flags is None:
            dilate_flags = [False, False, False]
        if len(dilate_flags) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(dilate_flags))

        self.groups = groups
        self.base_width = width_per_group

        # Stem
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=dilate_flags[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=dilate_flags[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=dilate_flags[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.normalize = NormalizeLayer()
        self.last_dim = last_dim

        # Kaiming init for convolutions, unit scale / zero shift for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Optionally zero the last norm layer of every residual branch so each
        # block starts out as an identity mapping (https://arxiv.org/abs/1706.02677).
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks producing ``planes * block.expansion`` channels."""
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion
        dilation_before = self.dilation

        if dilate:
            # Trade the spatial stride for a larger dilation.
            self.dilation *= stride
            stride = 1

        # The first block needs a projection when its output shape differs from its input.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        first = block(self.inplanes, planes, stride, downsample, self.groups,
                      self.base_width, dilation_before, norm_layer)
        self.inplanes = out_planes
        rest = [
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        ]

        return nn.Sequential(first, *rest)

    def penultimate(self, x, all_features=False):
        """Return the pooled feature vector for ``x``.

        When ``all_features`` is true, also return the list of intermediate
        feature maps: the stem output (after max pooling) followed by the
        output of each of the four stages.
        """
        features = self.normalize(x)
        features = self.conv1(features)
        features = self.bn1(features)
        features = self.relu(features)
        features = self.maxpool(features)
        collected = [features]

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
            collected.append(features)

        pooled = torch.flatten(self.avgpool(features), 1)

        if all_features:
            return pooled, collected
        return pooled
def generate_pseudo_data(vae, class_num, scalar, neg_item_per_batch, device):
    """Generates pseudo data for CMG stage 2.

    Two class labels are drawn per item from ``[0, class_num)``, embedded with
    ``vae.class_embed``, and mixed element-wise by a random binary mask. The
    mixed embedding is decoded together with a latent drawn from a standard
    normal scaled by ``scalar``.

    Returns:
        Whatever ``vae.decode`` returns for the ``neg_item_per_batch`` items.
    """
    count = neg_item_per_batch

    # Draw both label sets first so the random stream is consumed in a fixed order.
    labels_a = torch.randint(0, class_num, (count, 1))
    labels_b = torch.randint(0, class_num, (count, 1))

    # One-hot encode on the CPU, then move to the target device.
    onehot_a = torch.zeros(count, vae.class_num).scatter_(1, labels_a, 1)
    onehot_b = torch.zeros(count, vae.class_num).scatter_(1, labels_b, 1)
    embed_a = vae.class_embed(onehot_a.to(device))
    embed_b = vae.class_embed(onehot_b.to(device))

    # Per-element coin flip: 0 keeps the first embedding, 1 takes the second.
    coin = torch.randint(0, 2, embed_a.shape).to(device)
    mixed_embed = torch.where(coin == 0, embed_a, embed_b)

    # sample in N(0, sigma^2)
    latent = torch.randn(count, vae.z_dim).to(device) * scalar

    return vae.decode(latent, mixed_embed)
def fine_tune_same_dataset(
        classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, dataset='mnist',
        mode='CMG-energy'):
    """
    Performs CMG stage2 by fine-tuning the OOD detector using generated pseudo data on Setting 1.

    The parameters of ``classifier.convE`` and of ``vae`` are frozen; the remaining trainable
    classifier parameters are optimized with Adam for 15 epochs. Each step feeds one real batch
    and one batch of 128 generated samples through the classifier.

    Args:
        classifier: Model to fine-tune; must expose ``convE`` and ``classifier`` submodules.
        vae: Conditional VAE used by ``generate_pseudo_data``.
        train_loader: Yields ``(data, target)`` training batches.
        test_loader_seen: Yields ``(x, y)`` batches of seen-class test data.
        test_loader_unseen: Yields ``(x, _)`` batches treated as unseen (labelled -1).
        device: Device the batches and generated data are moved to.
        dataset: ``'mnist'`` (latent scale 3.0) or ``'cifar10'`` (latent scale 5.0).
        mode: ``'CMG-softmax'`` or ``'CMG-energy'``.

    Returns:
        The largest ROC AUC seen across the periodic evaluations.

    Raises:
        NotImplementedError: If ``dataset`` or ``mode`` is not one of the values above.
    """
    # Freeze the encoder part of the classifier.
    for p in classifier.convE.parameters():
        p.requires_grad = False

    # The VAE is only used as a fixed generator.
    vae.eval()
    for p in vae.parameters():
        p.requires_grad = False

    n_epochs = 15
    LR = 0.0001
    # Scale applied to the sampled latent in generate_pseudo_data.
    if dataset == 'mnist':
        scalar = 3.0
    elif dataset == 'cifar10':
        scalar = 5.0
    else:
        raise NotImplementedError
    neg_item_per_batch = 128
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=LR)

    if mode == 'CMG-energy':
        # define hyper-parameters for CMG-energy
        mu = 0.1
        m_in, m_out = get_m(train_loader, classifier, vae, 6, scalar, device)

    # Global batch counter; starts at -1 so the first batch has index 0.
    index = -1
    max_rocauc = 0
    for epoch in range(n_epochs):
        for data, target in train_loader:

            # Whole model in eval mode, then only the `classifier` submodule back in train mode.
            classifier.eval()
            classifier.classifier.train()
            optimizer.zero_grad()

            index += 1
            data = data.to(device)
            target = target.long().to(device)
            prediction = classifier(data)
            # Labels for the generated data are drawn from 6 classes.
            x_generate = generate_pseudo_data(vae, 6, scalar, neg_item_per_batch, device)
            prediction_generate = classifier(x_generate)

            if mode == 'CMG-softmax':
                loss_input = F.cross_entropy(prediction, target, reduction='none')
                loss_input = lf.weighted_average(loss_input, weights=None, dim=0)

                # Generated samples are all assigned label 6.
                y_generate = torch.ones(neg_item_per_batch) * 6
                y_generate = y_generate.long().to(device)

                # Only the first 7 outputs take part in the loss on generated data.
                loss_generate = F.cross_entropy(prediction_generate[:, 0:7], y_generate)

                loss = loss_generate + loss_input

                loss.backward()
                optimizer.step()

            elif mode == 'CMG-energy':
                loss_ce = F.cross_entropy(prediction, target)

                # Energy = -logsumexp(logits); penalize real data above m_in and
                # generated data below m_out (squared hinge).
                Ec_in = - torch.logsumexp(prediction, dim=1)
                Ec_out = - torch.logsumexp(prediction_generate, dim=1)
                loss_energy = torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean()

                loss = loss_ce + mu * loss_energy

                loss.backward()
                optimizer.step()

            else:
                raise NotImplementedError

            # evaluate (every 20 batches, including the very first one)
            if index % 20 == 0:
                classifier.eval()
                with torch.no_grad():
                    fx = []
                    labels = []

                    for x, y in test_loader_seen:
                        x = x.to(device)
                        y = y.long().to(device)
                        output = classifier(x)

                        fx.append(output[:, 0:7])
                        labels.append(y)

                    for x, _ in test_loader_unseen:
                        x = x.to(device)
                        output = classifier(x)

                        fx.append(output[:, 0:7])
                        # Unseen samples get label -1 regardless of their dataset label.
                        labels.append(-torch.ones(x.size(0)).long().to(device))

                    labels = torch.cat(labels, 0)
                    fx = torch.cat(fx, 0)

                    if mode == 'CMG-softmax':
                        roc_auc = score.softmax_result(fx, labels)
                    elif mode == 'CMG-energy':
                        roc_auc = score.energy_result(fx, labels)
                    else:
                        raise NotImplementedError

                    # Keep the best score seen so far.
                    if roc_auc > max_rocauc:
                        max_rocauc = roc_auc

                    if mode == 'CMG-softmax':
                        print('Epoch:', epoch, 'Index:', index, 'loss input', loss_input.item(),
                              'loss neg', loss_generate.item())
                    elif mode == 'CMG-energy':
                        print('Epoch:', epoch, 'Index:', index, 'loss ce', loss_ce.item(),
                              'loss energy', loss_energy.item())
                    else:
                        raise NotImplementedError

    print('max roc auc:', max_rocauc)

    return max_rocauc
198 | """ 199 | for p in classifier.convE.parameters(): 200 | p.requires_grad = False 201 | 202 | vae.eval() 203 | for p in vae.parameters(): 204 | p.requires_grad = False 205 | 206 | n_epochs = 15 207 | LR = 0.0001 208 | scalar = 5.0 209 | neg_item_per_batch = 128 210 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=LR) 211 | 212 | if mode == 'CMG-energy': 213 | # define hyper-parameters for CMG-energy 214 | mu = 0.1 215 | m_in, m_out = get_m(train_loader, classifier, vae, 10, scalar, device) 216 | 217 | index = -1 218 | max_rocauc = 0 219 | for epoch in range(n_epochs): 220 | for data, target in train_loader: 221 | 222 | classifier.eval() 223 | classifier.classifier.train() 224 | optimizer.zero_grad() 225 | 226 | index += 1 227 | data = data.to(device) 228 | target = target.long().to(device) 229 | prediction = classifier(data) 230 | x_generate = generate_pseudo_data(vae, 10, scalar, neg_item_per_batch, device) 231 | prediction_generate = classifier(x_generate) 232 | 233 | if mode == 'CMG-softmax': 234 | loss_input = F.cross_entropy(prediction, target, reduction='none') 235 | loss_input = lf.weighted_average(loss_input, weights=None, dim=0) 236 | 237 | y_generate = torch.ones(neg_item_per_batch) * 10 238 | y_generate = y_generate.long().to(device) 239 | 240 | loss_generate = F.cross_entropy(prediction_generate[:, 0:11], y_generate) 241 | 242 | loss = loss_generate + loss_input 243 | 244 | loss.backward() 245 | optimizer.step() 246 | 247 | elif mode == 'CMG-energy': 248 | loss_ce = F.cross_entropy(prediction, target) 249 | 250 | Ec_in = - torch.logsumexp(prediction, dim=1) 251 | Ec_out = - torch.logsumexp(prediction_generate, dim=1) 252 | loss_energy = torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean() 253 | 254 | loss = loss_ce + mu * loss_energy 255 | 256 | loss.backward() 257 | optimizer.step() 258 | 259 | else: 260 | raise NotImplementedError 261 | 262 | # evaluate 263 | if index 
% 100 == 0: 264 | classifier.eval() 265 | with torch.no_grad(): 266 | fx = [] 267 | labels = [] 268 | 269 | for x, y in test_loader_seen: 270 | x = x.to(device) 271 | y = y.long().to(device) 272 | output = classifier(x) 273 | 274 | fx.append(output[:, 0:11]) 275 | labels.append(y) 276 | 277 | for x, _ in test_loader_unseen: 278 | x = x.to(device) 279 | output = classifier(x) 280 | 281 | fx.append(output[:, 0:11]) 282 | labels.append(-torch.ones(x.size(0)).long().to(device)) 283 | 284 | labels = torch.cat(labels, 0) 285 | fx = torch.cat(fx, 0) 286 | 287 | if mode == 'CMG-softmax': 288 | roc_auc = score.softmax_result(fx, labels) 289 | elif mode == 'CMG-energy': 290 | roc_auc = score.energy_result(fx, labels) 291 | else: 292 | raise NotImplementedError 293 | 294 | if roc_auc > max_rocauc: 295 | max_rocauc = roc_auc 296 | 297 | if mode == 'CMG-softmax': 298 | print('Epoch:', epoch, 'Index:', index, 'loss input', loss_input.item(), 299 | 'loss neg', loss_generate.item()) 300 | elif mode == 'CMG-energy': 301 | print('Epoch:', epoch, 'Index:', index, 'loss ce', loss_ce.item(), 302 | 'loss energy', loss_energy.item()) 303 | else: 304 | raise NotImplementedError 305 | 306 | print('max roc auc:', max_rocauc) 307 | 308 | return max_rocauc 309 | -------------------------------------------------------------------------------- /CSI+CMG/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from sklearn.metrics import roc_auc_score 8 | from torch.utils.data import DataLoader 9 | from torchvision import datasets, transforms 10 | 11 | import CSI_model.classifier as C 12 | from models.vae import ConditionalVAE2 13 | from utils import get_args 14 | 15 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 16 | 17 | 18 | def energy_result(fx, y): 19 | """Calculates roc_auc using energy score. 
def setup_seed(seed):
    """Seeds every random number generator used by the script.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and possibly pick a different
    # algorithm on each run, which would undo the determinism requested above.
    torch.backends.cudnn.benchmark = False
def get_m(train_loader, classifier, vae):
    """Gets the energy margins for the energy loss using the training data.

    For every training batch the energy ``-T * logsumexp(logits / T)`` (with
    ``T = 1``) is computed for the real data and for one batch of generated
    pseudo data. Both collections are sorted; ``m_in`` is the real-data energy
    at index ``int(0.2 * N_in)`` and ``m_out`` the generated-data energy at
    index ``int(0.8 * N_out)``.

    Args:
        train_loader: Yields ``(data, target)`` batches; the targets are ignored.
        classifier: Model whose logits define the energies; must expose ``linear``.
        vae: Generator passed to ``generate_pseudo_data``.

    Returns:
        Tuple ``(m_in, m_out)`` of zero-dimensional tensors.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    print("=====get_m======")
    T = 1
    energies_in, energies_out = [], []

    # The mode does not change during the pass, so set it once up front.
    classifier.eval()
    classifier.linear.train()

    with torch.no_grad():
        for data, _ in train_loader:
            prediction = classifier(data.to(device))
            prediction_generate = classifier(generate_pseudo_data(vae))

            # calculate energy of the training data and generated negative data
            energies_in.append(-T * torch.logsumexp(prediction / T, dim=1))
            energies_out.append(-T * torch.logsumexp(prediction_generate / T, dim=1))

        if not energies_in:
            raise ValueError("get_m needs at least one batch from train_loader")

        # Concatenate once at the end instead of re-copying the running tensor every batch.
        Ec_in = torch.cat(energies_in, dim=0).sort()[0]
        Ec_out = torch.cat(energies_out, dim=0).sort()[0]
        m_in = Ec_in[int(Ec_in.size(0) * 0.2)]
        m_out = Ec_out[int(Ec_out.size(0) * 0.8)]
        print("m_in = ", m_in, ",m_out=", m_out)
    return m_in, m_out
ImagenNet / LSUN(FIX) / ImageNet(FIX) / CIFAR100. 141 | """ 142 | 143 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=LR) 144 | 145 | # prepare data ------------------------------------------------------------ 146 | transform_train = transforms.ToTensor() 147 | 148 | train_data = datasets.CIFAR10( 149 | root='./data/cifar10', train=True, download=True, 150 | transform=transform_train) 151 | test_data = datasets.CIFAR10( 152 | root='./data/cifar10', train=False, download=True, 153 | transform=transforms.ToTensor()) 154 | 155 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4) 156 | test_loader_seen = DataLoader(test_data, batch_size=batch_size, num_workers=4) 157 | test_dir_svhn = os.path.join('./data', 'svhn') 158 | svhn_data = datasets.SVHN( 159 | root=test_dir_svhn, split='test', download=True, 160 | transform=transforms.ToTensor()) 161 | test_loader_svhn = DataLoader(svhn_data, batch_size=512, num_workers=4) 162 | test_dir_lsun = os.path.join('./data', 'LSUN_resize') 163 | lsun_data = datasets.ImageFolder(test_dir_lsun, transform=transforms.ToTensor()) 164 | test_loader_lsun = DataLoader(lsun_data, batch_size=512, num_workers=4) 165 | test_dir_imagenet = os.path.join('./data', 'Imagenet_resize') 166 | imagenet_data = datasets.ImageFolder(test_dir_imagenet, transform=transforms.ToTensor()) 167 | test_loader_imagenet = DataLoader(imagenet_data, batch_size=512, num_workers=4) 168 | test_dir_lsun_fix = os.path.join('./data', 'LSUN_fix') 169 | lsun_fix_data = datasets.ImageFolder(test_dir_lsun_fix, transform=transforms.ToTensor()) 170 | test_loader_lsun_fix = DataLoader(lsun_fix_data, batch_size=512, num_workers=4) 171 | test_dir_imagenet_fix = os.path.join('./data', 'Imagenet_fix') 172 | imagenet_data = datasets.ImageFolder(test_dir_imagenet_fix, transform=transforms.ToTensor()) 173 | test_loader_imagenet_fix = DataLoader(imagenet_data, batch_size=512, num_workers=4) 174 | 
test_dir_cifar100 = os.path.join('./data', 'cifar100') 175 | cifar100_data = datasets.CIFAR100( 176 | root=test_dir_cifar100, train=False, transform=transforms.ToTensor(), download=True) 177 | test_loader_cifar100 = DataLoader(cifar100_data, batch_size=512, num_workers=4) 178 | 179 | # hyper params for energy loss 180 | mu = 0.1 181 | m_in, m_out = get_m(train_loader, classifier, vae) 182 | 183 | # fine-tuning the classification head ---------------------------------------------------- 184 | index = -1 185 | max_roc_auc = {'svhn': 0, 'lsun': 0, 'imagenet': 0, 'lsun_fix': 0, 'imagenet_fix': 0, 'cifar100': 0} 186 | 187 | for epoch in range(n_epochs): 188 | 189 | for data, target in train_loader: 190 | 191 | classifier.eval() 192 | classifier.linear.train() 193 | optimizer.zero_grad() 194 | index += 1 195 | 196 | data = data.to(device) 197 | target = target.long().to(device) 198 | prediction = classifier(data) 199 | loss_ce = F.cross_entropy(prediction, target) 200 | Ec_in = -torch.logsumexp(prediction, dim=1) 201 | 202 | x_generate = generate_pseudo_data(vae) 203 | 204 | prediction_generate = classifier(x_generate)[:, 0:10] 205 | 206 | Ec_out = -torch.logsumexp(prediction_generate, dim=1) 207 | 208 | # energy loss 209 | loss_energy = torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean() 210 | 211 | loss = loss_ce + mu * loss_energy 212 | 213 | loss.backward() 214 | optimizer.step() 215 | 216 | # evaluate (every 100 batches) ------------------------------------ 217 | if index % 100 == 0: 218 | classifier.eval() 219 | with torch.no_grad(): 220 | 221 | output_ind = [] 222 | output_svhn, output_lsun, output_imagenet, output_lsun_fix, output_imagenet_fix, output_cifar100, \ 223 | = [], [], [], [], [], [] 224 | 225 | for x, _ in test_loader_seen: 226 | x = x.to(device) 227 | output = classifier(x) 228 | output_ind.append(output) 229 | 230 | for x, _ in test_loader_svhn: 231 | x = x.to(device) 232 | output = classifier(x) 233 | 
output_svhn.append(output) 234 | 235 | for x, _ in test_loader_lsun: 236 | x = x.to(device) 237 | output = classifier(x) 238 | output_lsun.append(output) 239 | 240 | for x, _ in test_loader_imagenet: 241 | x = x.to(device) 242 | output = classifier(x) 243 | output_imagenet.append(output) 244 | 245 | for x, _ in test_loader_lsun_fix: 246 | x = x.to(device) 247 | output = classifier(x) 248 | output_lsun_fix.append(output) 249 | 250 | for x, _ in test_loader_imagenet_fix: 251 | x = x.to(device) 252 | output = classifier(x) 253 | output_imagenet_fix.append(output) 254 | 255 | for x, _ in test_loader_cifar100: 256 | x = x.to(device) 257 | output = classifier(x) 258 | output_cifar100.append(output) 259 | 260 | output_ind = torch.cat(output_ind, 0) 261 | output_svhn = torch.cat(output_svhn, 0) 262 | output_lsun = torch.cat(output_lsun, 0) 263 | output_imagenet = torch.cat(output_imagenet, 0) 264 | output_lsun_fix = torch.cat(output_lsun_fix, 0) 265 | output_imagenet_fix = torch.cat(output_imagenet_fix, 0) 266 | output_cifar100 = torch.cat(output_cifar100, 0) 267 | 268 | roc_auc_svhn = energy_result(torch.cat([output_ind, output_svhn]), torch.cat( 269 | [torch.ones(output_ind.size(0)), -torch.ones(output_svhn.size(0))]).long().to(device)) 270 | roc_auc_lsun = energy_result(torch.cat([output_ind, output_lsun]), torch.cat( 271 | [torch.ones(output_ind.size(0)), -torch.ones(output_lsun.size(0))]).long().to(device)) 272 | roc_auc_imagenet = energy_result(torch.cat([output_ind, output_imagenet]), torch.cat( 273 | [torch.ones(output_ind.size(0)), -torch.ones(output_imagenet.size(0))]).long().to(device)) 274 | roc_auc_lsun_fix = energy_result(torch.cat([output_ind, output_lsun_fix]), torch.cat( 275 | [torch.ones(output_ind.size(0)), -torch.ones(output_lsun_fix.size(0))]).long().to(device)) 276 | roc_auc_imagenet_fix = energy_result(torch.cat([output_ind, output_imagenet_fix]), torch.cat( 277 | [torch.ones(output_ind.size(0)), 
-torch.ones(output_imagenet_fix.size(0))]).long().to(device)) 278 | roc_auc_cifar100 = energy_result(torch.cat([output_ind, output_cifar100]), torch.cat( 279 | [torch.ones(output_ind.size(0)), -torch.ones(output_cifar100.size(0))]).long().to(device)) 280 | 281 | max_roc_auc['svhn'] = max(max_roc_auc['svhn'], roc_auc_svhn) 282 | max_roc_auc['lsun'] = max(max_roc_auc['lsun'], roc_auc_lsun) 283 | max_roc_auc['imagenet'] = max(max_roc_auc['imagenet'], roc_auc_imagenet) 284 | max_roc_auc['lsun_fix'] = max(max_roc_auc['lsun_fix'], roc_auc_lsun_fix) 285 | max_roc_auc['imagenet_fix'] = max(max_roc_auc['imagenet_fix'], roc_auc_imagenet_fix) 286 | max_roc_auc['cifar100'] = max(max_roc_auc['cifar100'], roc_auc_cifar100) 287 | 288 | print('Epoch: {} Index: {}'.format(epoch, index)) 289 | print('Max rocauc result') 290 | print(max_roc_auc) 291 | 292 | 293 | if __name__ == '__main__': 294 | tune_main_model() 295 | -------------------------------------------------------------------------------- /CSI+CMG/CSI_model/transform_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Function 9 | 10 | if torch.__version__ >= '1.4.0': 11 | kwargs = {'align_corners': False} 12 | else: 13 | kwargs = {} 14 | 15 | 16 | def rgb2hsv(rgb): 17 | """Convert a 4-d RGB tensor to the HSV counterpart. 18 | 19 | Here, we compute hue using atan2() based on the definition in [1], 20 | instead of using the common lookup table approach as in [2, 3]. 21 | Those values agree when the angle is a multiple of 30°, 22 | otherwise they may differ at most ~1.2°. 
class RandomResizedCropLayer(nn.Module):
    """Batched random-resized ("Inception") crop implemented with an affine grid.

    Each sample gets its own crop, described by four numbers
    ``(w, h, w_bias, h_bias)``: the crop width and height as fractions of the
    input width and height, and the horizontal / vertical offsets written into
    the translation column of the affine matrix.
    """

    def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        '''
        Inception Crop
        size (tuple): size of forwarding image; when given, the cropped result
            is passed through ``F.adaptive_avg_pool2d(output, size)``
        scale (tuple): range of size of the origin size cropped
        ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        '''
        super(RandomResizedCropLayer, self).__init__()

        _eye = torch.eye(2, 3)
        self.size = size
        # Identity affine matrix; a buffer so that it follows .to(device).
        self.register_buffer('_eye', _eye)
        self.scale = scale
        self.ratio = ratio

    def forward(self, inputs, whbias=None):
        """Crop every sample of ``inputs`` and resample it on the input grid.

        Args:
            inputs: 4-d image batch.
            whbias: Optional ``(N, 4)`` tensor of ``(w, h, w_bias, h_bias)``
                rows; sampled with ``_sample_latent`` when omitted.

        Returns:
            The cropped batch, average-pooled to ``self.size`` if that is set.
        """
        _device = inputs.device
        N = inputs.size(0)
        _theta = self._eye.repeat(N, 1, 1)

        if whbias is None:
            whbias = self._sample_latent(inputs)

        # Row 0 of theta acts on x (width), row 1 on y (height).
        _theta[:, 0, 0] = whbias[:, 0]
        _theta[:, 1, 1] = whbias[:, 1]
        _theta[:, 0, 2] = whbias[:, 2]
        _theta[:, 1, 2] = whbias[:, 3]

        grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
        output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)

        if self.size is not None:
            output = F.adaptive_avg_pool2d(output, self.size)

        return output

    def _clamp(self, whbias):
        """Project ``(N, 4)`` crop parameters back into their valid ranges.

        ``w`` and ``h`` are clamped to ``self.scale``, ``w`` is then limited to
        ``[ratio[0] * h, ratio[1] * h]``, and the biases to
        ``[w - 1, 1 - w]`` / ``[h - 1, 1 - h]``. The relu formulation keeps the
        operation differentiable.
        """

        w = whbias[:, 0]
        h = whbias[:, 1]
        w_bias = whbias[:, 2]
        h_bias = whbias[:, 3]

        # Clamp with scale
        w = torch.clamp(w, *self.scale)
        h = torch.clamp(h, *self.scale)

        # Clamp with ratio
        w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h)
        w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w)

        # Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h)
        w_bias = w - 1 + torch.relu(w_bias - w + 1)
        w_bias = 1 - w - torch.relu(1 - w - w_bias)

        h_bias = h - 1 + torch.relu(h_bias - h + 1)
        h_bias = 1 - h - torch.relu(1 - h - h_bias)

        whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t()

        return whbias

    def _sample_latent(self, inputs):
        """Draw one ``(w, h, w_bias, h_bias)`` row per sample of ``inputs``.

        Ten candidate crops per sample are drawn with NumPy; candidates that do
        not fit inside the image are discarded. If fewer than ``N`` candidates
        survive, the missing rows fall back to the full image (zero offset).

        Returns:
            A ``(N, 4)`` tensor on ``inputs.device``.
        """

        _device = inputs.device
        # Image batches are (N, C, H, W). The width must be the LAST dimension
        # because `w` ends up in the x row of the affine matrix; unpacking it
        # the other way round distorts crops of non-square inputs.
        N, _, height, width = inputs.shape

        # N * 10 trial
        area = width * height
        target_area = np.random.uniform(*self.scale, N * 10) * area
        log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
        aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))

        # Keep only the candidates that fit inside the image
        w = np.round(np.sqrt(target_area * aspect_ratio))
        h = np.round(np.sqrt(target_area / aspect_ratio))
        cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
        w = w[cond]
        h = h[cond]
        cond_len = w.shape[0]
        if cond_len >= N:
            w = w[:N]
            h = h[:N]
        else:
            # Not enough valid candidates: pad with full-image crops
            w = np.concatenate([w, np.ones(N - cond_len) * width])
            h = np.concatenate([h, np.ones(N - cond_len) * height])

        w_bias = np.random.randint(w - width, width - w + 1) / width
        h_bias = np.random.randint(h - height, height - h + 1) / height
        w = w / width
        h = h / height

        whbias = np.column_stack([w, h, w_bias, h_bias])
        whbias = torch.tensor(whbias, device=_device)

        return whbias
class Rotation(nn.Module):
    """Rotate a 4-d batch by multiples of 90 degrees in its last two dimensions.

    With an explicit ``aug_index`` every sample is rotated by
    ``aug_index % max_range`` quarter turns. Without one, a single random
    number of quarter turns is drawn for the whole batch and each sample
    independently keeps its original orientation with probability ``prob``.
    """

    def __init__(self, max_range=4):
        super().__init__()
        self.max_range = max_range
        self.prob = 0.5

    def forward(self, input, aug_index=None):
        # Unpacking also enforces that the input is 4-dimensional.
        batch, _channels, _rows, _cols = input.size()

        if aug_index is not None:
            return torch.rot90(input, aug_index % self.max_range, (2, 3))

        quarter_turns = np.random.randint(4)
        rotated = torch.rot90(input, quarter_turns, (2, 3))

        # keep == 1 -> sample stays unrotated, keep == 0 -> rotated version
        keep = torch.bernoulli(input.new_full((batch,), self.prob)).view(-1, 1, 1, 1)
        return keep * input + (1 - keep) * rotated
class ColorJitterLayer(nn.Module):
    """Apply random contrast and HSV jitter to a batch, per sample with probability ``p``.

    ``brightness``, ``contrast``, ``saturation`` and ``hue`` are each turned
    into a sampling range by ``_check_input``; a range that would leave the
    image untouched is stored as ``None`` and skipped.
    """

    def __init__(self, p, brightness, contrast, saturation, hue):
        super().__init__()
        self.prob = p
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(
            hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalise a jitter setting to a ``[low, high]`` range or ``None``.

        A single non-negative number ``v`` becomes ``[center - v, center + v]``
        (lower end floored at 0 when ``clip_first_on_zero``). A 2-element
        list/tuple is returned as given after checking it lies within
        ``bound``. Ranges equal to ``[center, center]`` yield ``None``.

        Raises:
            ValueError: negative number, or a pair outside ``bound``.
            TypeError: anything that is neither a number nor a pair.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            lower = center - value
            if clip_first_on_zero:
                lower = max(lower, 0)
            value = [lower, center + value]
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
        elif not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # A degenerate range centred on `center` means "leave unchanged".
        is_identity = value[0] == value[1] == center
        return None if is_identity else value

    def adjust_contrast(self, x):
        """Scale each sample around its per-channel mean, then clamp to [0, 1]."""
        if not self.contrast:
            return torch.clamp(x, 0, 1)

        scale = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
        channel_mean = torch.mean(x, dim=[2, 3], keepdim=True)
        return torch.clamp((x - channel_mean) * scale + channel_mean, 0, 1)

    def adjust_hsv(self, x):
        """Draw per-sample hue / saturation / value factors and apply them."""
        batch = x.size(0)
        hue_shift = x.new_zeros(batch, 1, 1)
        sat_scale = x.new_ones(batch, 1, 1)
        val_scale = x.new_ones(batch, 1, 1)

        # Disabled components keep their neutral value (0 shift / 1 scale).
        if self.hue:
            hue_shift.uniform_(*self.hue)
        if self.saturation:
            sat_scale.uniform_(*self.saturation)
        if self.brightness:
            val_scale.uniform_(*self.brightness)

        return RandomHSVFunction.apply(x, hue_shift, sat_scale, val_scale)

    def transform(self, inputs):
        """Run contrast and HSV jitter in a randomly chosen order."""
        pipeline = [self.adjust_contrast, self.adjust_hsv]
        if np.random.rand() <= 0.5:
            pipeline.reverse()

        for step in pipeline:
            inputs = step(inputs)

        return inputs

    def forward(self, inputs):
        apply_prob = inputs.new_full((inputs.size(0),), self.prob)
        jitter_mask = torch.bernoulli(apply_prob).view(-1, 1, 1, 1)
        jittered = self.transform(inputs)
        return inputs * (1 - jitter_mask) + jittered * jitter_mask
class NormalizeLayer(nn.Module):
    """Parameter-free layer that shifts inputs by 0.5 and divides by 0.5.

    Values in [0, 1] are therefore mapped to [-1, 1]. Keeping this as a module
    lets the normalisation run inside the network instead of in data
    preprocessing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        shift, scale = 0.5, 0.5
        centered = inputs - shift
        return centered / scale