├── datasets
│   ├── __init__.py
│   └── datasetsHelper.py
├── models
│   ├── conv
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   └── nets.py
│   ├── fc
│   │   ├── __init__.py
│   │   └── nets.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── loss_functions.py
│   │   ├── score.py
│   │   └── modules.py
│   ├── vae.py
│   └── main_model.py
├── training
│   ├── __init__.py
│   ├── train_classifier.py
│   ├── train_cvae.py
│   └── fine_tune.py
├── CSI+CMG
│   ├── CSI_model
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── classifier.py
│   │   ├── resnet.py
│   │   ├── resnet_imagenet.py
│   │   └── transform_layers.py
│   ├── models
│   │   ├── conv
│   │   │   ├── __init__.py
│   │   │   ├── cnn.py
│   │   │   └── nets.py
│   │   ├── fc
│   │   │   ├── __init__.py
│   │   │   └── nets.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── loss_functions.py
│   │   │   ├── score.py
│   │   │   └── modules.py
│   │   └── vae.py
│   ├── utils.py
│   ├── readme.md
│   └── main.py
├── figure
│   └── model.png
├── .gitignore
├── utils.py
├── train.py
├── readme.md
└── eval.py
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/conv/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/fc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CSI+CMG/CSI_model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CSI+CMG/models/conv/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CSI+CMG/models/fc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/CSI+CMG/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figure/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoyijia/CMG/HEAD/figure/model.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Development env
2 | .idea
3 |
4 | # Cache files
5 | __pycache__/
6 | *.pyc
--------------------------------------------------------------------------------
/CSI+CMG/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--device', type=str, default='cuda:0',
7 | help='device for training')
8 | parser.add_argument('--params-dict-name', type=str,
9 | help='name of the classifier checkpoint file')
10 | parser.add_argument('--params-dict-name2', type=str,
11 | help='name of the CVAE checkpoint file')
12 | return parser.parse_args()
13 |
--------------------------------------------------------------------------------
/models/utils/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def weighted_average(tensor, weights=None, dim=0):
5 | """Computes weighted average of [tensor] over dimension [dim]."""
6 |
7 | if weights is None:
8 | mean = torch.mean(tensor, dim=dim)
9 | else:
10 | batch_size = tensor.size(dim) if len(tensor.size()) > 0 else 1
11 | assert len(weights) == batch_size
12 | norm_weights = torch.tensor([weight for weight in weights]).to(tensor.device)
13 | mean = torch.mean(norm_weights * tensor, dim=dim)
14 | return mean
15 |
--------------------------------------------------------------------------------
/CSI+CMG/models/utils/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def weighted_average(tensor, weights=None, dim=0):
5 | """Computes weighted average of [tensor] over dimension [dim]."""
6 | if weights is None:
7 | mean = torch.mean(tensor, dim=dim)
8 | else:
9 | batch_size = tensor.size(dim) if len(tensor.size()) > 0 else 1
10 | assert len(weights) == batch_size
11 | norm_weights = torch.tensor([weight for weight in weights]).to(tensor.device)
12 | mean = torch.mean(norm_weights * tensor, dim=dim)
13 | return mean
14 |
--------------------------------------------------------------------------------
/models/fc/nets.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 |
4 |
5 | class MLP(nn.Module):
6 | """
7 | Module for a multi-layer perceptron (MLP).
8 |
9 | input: <2D tensor> [batch_size] * [input_dim]
10 | output: <2D tensor> [batch_size] * [classes]
11 |
12 | """
13 |
14 | def __init__(self, input_dim, classes, latent_dim=512):
15 | super(MLP, self).__init__()
16 | self.input_dim = input_dim
17 | self.classes = classes
18 | self.fc1 = nn.Linear(input_dim, latent_dim)
19 | self.fc2 = nn.Linear(latent_dim, classes)
20 |
21 | def forward(self, x):
22 | x = x.view(-1, self.input_dim)
23 | x = F.relu(self.fc1(x))
24 | result = self.fc2(x)
25 |
26 | return result
27 |
--------------------------------------------------------------------------------
/CSI+CMG/models/fc/nets.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 |
4 |
5 | class MLP(nn.Module):
6 | """
7 | Module for a multi-layer perceptron (MLP).
8 |
9 | input: <2D tensor> [batch_size] * [input_dim]
10 | output: <2D tensor> [batch_size] * [classes]
11 |
12 | """
13 |
14 | def __init__(self, input_dim, classes, latent_dim=512):
15 | super(MLP, self).__init__()
16 | self.input_dim = input_dim
17 | self.classes = classes
18 | self.fc1 = nn.Linear(input_dim, latent_dim)
19 | self.fc2 = nn.Linear(latent_dim, classes)
20 |
21 | def forward(self, x):
22 | x = x.view(-1, self.input_dim)
23 | x = F.relu(self.fc1(x))
24 | result = self.fc2(x)
25 |
26 | return result
27 |
--------------------------------------------------------------------------------
/training/train_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def train_classifier(model, train_loader, device, params_dict_name, dataset='mnist'):
5 | """Trains the IND classifier on the given dataset."""
6 |
7 | if dataset == 'mnist':
8 | n_epochs = 100
9 | elif dataset == 'cifar10':
10 | n_epochs = 200
11 | else:
12 | raise ValueError
13 | LR = 0.001
14 | model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
15 |
16 | for epoch in range(n_epochs):
17 | for data, target in train_loader:
18 | data = data.to(device)
19 | target = target.long().to(device)
20 | train_loss = model.train_a_batch(data, target)
21 |
22 | print('Epoch: {}, loss = {}'.format(epoch, train_loss))
23 |
24 | torch.save(model.state_dict(), params_dict_name)
25 |
--------------------------------------------------------------------------------
/models/utils/score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from sklearn.metrics import roc_auc_score
4 |
5 |
6 | def softmax_result(fx, y):
7 | """Calculates roc_auc using softmax score of the unseen slot.
8 |
9 | Args:
10 | fx: Last layer output of the model, assumes the unseen slot to be the last one.
11 | y: Class Label, assumes the label of unseen data to be -1.
12 | Returns:
13 | roc_auc: Unseen data as positive, seen data as negative.
14 | """
15 | score = F.softmax(fx, dim=1)[:, -1]
16 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), score.cpu().detach().numpy())
17 |
18 | return rocauc
19 |
20 |
21 | def energy_result(fx, y):
22 | """Calculates roc_auc using energy score.
23 |
24 | Args:
25 | fx: Last layer output of the model, assumes the unseen slot to be the last one.
26 | y: Class Label, assumes the label of unseen data to be -1.
27 | Returns:
28 | roc_auc: Unseen data as positive, seen data as negative.
29 | """
30 | energy_score = - torch.logsumexp(fx[:, :-1], dim=1)
31 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), energy_score.cpu().detach().numpy())
32 |
33 | return rocauc
34 |
--------------------------------------------------------------------------------
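For reference, a minimal usage sketch of the two detectors above (our annotation, not repository code; the logits and labels are made up):

```python
# Score a mixed batch of seen and unseen samples with both detectors.
import torch
from models.utils.score import softmax_result, energy_result

fx = torch.randn(8, 7)                         # logits: 6 seen classes + 1 unseen slot
y = torch.tensor([0, 1, 2, -1, 3, -1, 4, 5])   # label -1 marks unseen (OOD) samples

print('softmax AUROC:', softmax_result(fx, y))
print('energy AUROC: ', energy_result(fx, y))
```

Both functions return an AUROC in which unseen samples are the positive class, so higher is better.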
/CSI+CMG/models/utils/score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from sklearn.metrics import roc_auc_score
4 |
5 |
6 | def softmax_result(fx, y):
7 | """Calculates roc_auc using softmax score of the unseen slot.
8 |
9 | Args:
10 | fx: Last layer output of the model, assumes the unseen slot to be the last one.
11 | y: Class Label, assumes the label of unseen data to be -1.
12 | Returns:
13 | roc_auc: Unseen data as positive, seen data as negative.
14 | """
15 | score = F.softmax(fx, dim=1)[:, -1]
16 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), score.cpu().detach().numpy())
17 |
18 | return rocauc
19 |
20 |
21 | def energy_result(fx, y):
22 | """Calculates roc_auc using energy score.
23 |
24 | Args:
25 | fx: Last layer output of the model, assumes the unseen slot to be the last one.
26 | y: Class Label, assumes the label of unseen data to be -1.
27 | Returns:
28 | roc_auc: Unseen data as positive, seen data as negative.
29 | """
30 | energy_score = - torch.logsumexp(fx[:, :-1], dim=1)
31 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), energy_score.cpu().detach().numpy())
32 |
33 | return rocauc
34 |
--------------------------------------------------------------------------------
/CSI+CMG/readme.md:
--------------------------------------------------------------------------------
1 | # Reproduce CSI+CMG for new SOTA
2 |
3 | ## Requirements
4 |
5 | ### Environments
6 |
7 | The required packages are as follows:
8 |
9 | - python 3.5
10 | - torch 1.2
11 | - torchvision 0.4
12 | - CUDA 10.0
13 | - scikit-learn 0.22
14 |
15 | ### Datasets
16 |
17 | Please download datasets to `./data` and rename the file. Or you may modify the data path in [main.py](main.py).
18 |
19 | ### Checkpoints
20 |
21 | Please download
22 | the [CSI pretrained model](https://drive.google.com/file/d/1rW2-0MJEzPHLb_PAW-LvCivHt-TkDpRO/view?usp=sharing) provided
23 | by [CSI](https://github.com/alinlab/CSI) and save it as `./checkpoint/cifar10_labeled.model`. You can also train your
24 | own model with CSI's code for other settings.
25 |
26 | Also, you need to pretrain a CVAE model with CIFAR10 training data according to CMG stage 1, and save the checkpoint
27 | as `./checkpoint/cvae_10class.pkl`.
28 |
29 | ## Applying CMG and Evaluations
30 |
31 | To perform CMG tuning on CSI models and get the final result on CIFAR10 (OOD Detection on different datasets), run this
32 | command:
33 |
34 | ```
35 | python -m main \
36 | --device {the available GPU in your cluser, e.g., cuda:0} \
37 | --params-dict-name './checkpoint/cifar10_labled.model' \
38 | --params-dict-name2 './checkpoint/cvae_10class.pkl'
39 | ```
40 |
--------------------------------------------------------------------------------
/CSI+CMG/CSI_model/base_model.py:
--------------------------------------------------------------------------------
1 | from abc import *
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class BaseModel(nn.Module, metaclass=ABCMeta):
7 | def __init__(self, last_dim, num_classes=10, simclr_dim=128):
8 | super(BaseModel, self).__init__()
9 | self.linear = nn.Linear(last_dim, num_classes)
10 | self.simclr_layer = nn.Sequential(
11 | nn.Linear(last_dim, last_dim),
12 | nn.ReLU(),
13 | nn.Linear(last_dim, simclr_dim),
14 | )
15 | self.shift_cls_layer = nn.Linear(last_dim, 2)
16 | self.joint_distribution_layer = nn.Linear(last_dim, 4 * num_classes)
17 |
18 | @abstractmethod
19 | def penultimate(self, inputs, all_features=False):
20 | pass
21 |
22 | def forward(self, inputs, penultimate=False, simclr=False, shift=False, joint=False):
23 | _aux = {}
24 | _return_aux = False
25 |
26 | features = self.penultimate(inputs)
27 |
28 | output = self.linear(features)
29 |
30 | if penultimate:
31 | _return_aux = True
32 | _aux['penultimate'] = features
33 |
34 | if simclr:
35 | _return_aux = True
36 | _aux['simclr'] = self.simclr_layer(features)
37 |
38 | if shift:
39 | _return_aux = True
40 | _aux['shift'] = self.shift_cls_layer(features)
41 |
42 | if joint:
43 | _return_aux = True
44 | _aux['joint'] = self.joint_distribution_layer(features)
45 |
46 | if _return_aux:
47 | return output, _aux
48 |
49 | return output
50 |
--------------------------------------------------------------------------------
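A minimal sketch (our annotation, not repository code) of how this multi-head interface is used: a toy subclass implements `penultimate`, and passing any of the flags makes `forward` return an `(output, aux)` tuple instead of bare logits.

```python
import torch
import torch.nn as nn
from CSI_model.base_model import BaseModel

class ToyModel(BaseModel):
    """Toy subclass whose penultimate() just flattens and projects the input."""

    def __init__(self):
        super().__init__(last_dim=64, num_classes=10)
        self.fc = nn.Linear(3 * 32 * 32, 64)

    def penultimate(self, inputs, all_features=False):
        return self.fc(inputs.view(inputs.size(0), -1))

model = ToyModel()
x = torch.rand(4, 3, 32, 32)
logits = model(x)                                     # plain forward: shape (4, 10)
logits, aux = model(x, penultimate=True, simclr=True)
print(aux['penultimate'].shape, aux['simclr'].shape)  # (4, 64) and (4, 128)
```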
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--task', type=str, default='same_dataset_mnist',
7 | choices=['same_dataset_mnist', 'same_dataset_cifar10', 'different_dataset'],
8 | help='the current task: same_dataset_mnist/same_dataset_cifar10/different_dataset')
9 | parser.add_argument('--partition', type=str, default='partition1')
10 | parser.add_argument('--command', type=str, default='train_classifier',
11 | choices=['train_classifier', 'train_cvae'],
12 | help='command for CMG stage 1: train_classifier/train_cvae')
13 | parser.add_argument('--ood-dataset', type=str,
14 | choices=['SVHN', 'LSUN', 'LSUN-FIX', 'tinyImageNet', 'ImageNet-FIX', 'CIFAR100'],
15 | help='OOD dataset for setting 2: SVHN/LSUN/LSUN-FIX/tinyImageNet/ImageNet-FIX/CIFAR100')
16 | parser.add_argument('--mode', type=str, default='CMG-energy', choices=['CMG-softmax', 'CMG-energy'],
17 | help="CMG-softmax/CMG-energy")
18 | parser.add_argument('--device', type=str, default='cuda:0',
19 | help='device for training')
20 | parser.add_argument('--params-dict-name', type=str,
21 | help='name of the classifier checkpoint file')
22 | parser.add_argument('--params-dict-name2', type=str,
23 | help='name of the CVAE checkpoint file')
24 | parser.add_argument('--seed', type=int, default=123, help='set random seed')
25 | return parser.parse_args()
26 |
--------------------------------------------------------------------------------
/training/train_cvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models.utils import loss_functions as lf
4 |
5 |
6 | def train_cvae(model, train_loader, device, params_dict_name, dataset='mnist'):
7 | """Trains CVAE on the given dataset."""
8 |
9 | if dataset == 'mnist':
10 | n_epochs = 100
11 | elif dataset == 'cifar10':
12 | n_epochs = 200
13 | else:
14 | raise ValueError
15 | LR = 0.001
16 |
17 | optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
18 |
19 | for epoch in range(n_epochs):
20 | for data, y in train_loader:
21 | optimizer.zero_grad()
22 | data = data.to(device)
23 | y = y.long()
24 | y_onehot = torch.Tensor(y.shape[0], model.class_num)
25 | y_onehot.zero_()
26 | y_onehot.scatter_(1, y.view(-1, 1), 1)
27 | y_onehot = y_onehot.to(device)
28 | mu, logvar, recon = model(data, y_onehot)
29 |
30 | variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
31 | variatL = lf.weighted_average(variatL, weights=None, dim=0)
32 | variatL /= (model.image_channels * model.image_size * model.image_size)
33 |
34 | data_resize = data.reshape(-1, model.image_channels * model.image_size * model.image_size)
35 | recon_resize = recon.reshape(-1, model.image_channels * model.image_size * model.image_size)
36 | reconL = (data_resize - recon_resize) ** 2
37 | reconL = torch.mean(reconL, 1)
38 | reconL = lf.weighted_average(reconL, weights=None, dim=0)
39 |
40 | loss = variatL + reconL
41 |
42 | loss.backward()
43 | optimizer.step()
44 |
45 | print("epoch: {}, loss = {}, reconL = {}, variaL = {}".format(epoch, loss, reconL, variatL))
46 |
47 | torch.save(model.state_dict(), params_dict_name)
48 |
--------------------------------------------------------------------------------
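The `variatL` term above is the closed-form KL divergence between the encoder posterior N(mu, sigma^2) and the standard normal prior. A small sanity check (our annotation, not repository code) confirming the formula against `torch.distributions`:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu, logvar = torch.randn(4, 32), torch.randn(4, 32)

# Closed form used in train_cvae.py, summed over latent dimensions.
closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
reference = kl_divergence(posterior, prior).sum(dim=1)

print(torch.allclose(closed_form, reference, atol=1e-5))  # True
```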
/models/utils/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import nn
3 |
4 |
5 | class Identity(nn.Module):
6 | """A nn-module to simply pass on the input data."""
7 |
8 | def forward(self, x):
9 | return x
10 |
11 | def __repr__(self):
12 | tmpstr = self.__class__.__name__ + '()'
13 | return tmpstr
14 |
15 |
16 | class Shape(nn.Module):
17 | """A nn-module to shape a tensor of shape [shape]."""
18 |
19 | def __init__(self, shape):
20 | super().__init__()
21 | self.shape = shape
22 | self.dim = len(shape)
23 |
24 | def forward(self, x):
25 | return x.view(*self.shape)
26 |
27 | def __repr__(self):
28 | tmpstr = self.__class__.__name__ + '(shape = {})'.format(self.shape)
29 | return tmpstr
30 |
31 |
32 | class Reshape(nn.Module):
33 | """A nn-module to reshape a tensor(-tuple) to a 4-dim "image"-tensor(-tuple) with [image_channels] channels."""
34 |
35 | def __init__(self, image_channels):
36 | super().__init__()
37 | self.image_channels = image_channels
38 |
39 | def forward(self, x):
40 | if type(x) == tuple:
41 | batch_size = x[0].size(0) # first dimension should be batch-dimension.
42 | image_size = int(np.sqrt(x[0].nelement() / (batch_size * self.image_channels)))
43 | return (x_item.view(batch_size, self.image_channels, image_size, image_size) for x_item in x)
44 | else:
45 | batch_size = x.size(0) # first dimension should be batch-dimension.
46 | image_size = int(np.sqrt(x.nelement() / (batch_size * self.image_channels)))
47 | return x.view(batch_size, self.image_channels, image_size, image_size)
48 |
49 | def __repr__(self):
50 | tmpstr = self.__class__.__name__ + '(channels = {})'.format(self.image_channels)
51 | return tmpstr
52 |
53 |
54 | class Flatten(nn.Module):
55 | """A nn-module to flatten a multi-dimensional tensor to 2-dim tensor."""
56 |
57 | def forward(self, x):
58 | batch_size = x.size(0) # first dimension should be batch-dimension.
59 | return x.contiguous().view(batch_size, -1)
60 |
61 | def __repr__(self):
62 | tmpstr = self.__class__.__name__ + '()'
63 | return tmpstr
64 |
--------------------------------------------------------------------------------
/CSI+CMG/models/utils/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import nn
3 |
4 |
5 | class Identity(nn.Module):
6 | """A nn-module to simply pass on the input data."""
7 |
8 | def forward(self, x):
9 | return x
10 |
11 | def __repr__(self):
12 | tmpstr = self.__class__.__name__ + '()'
13 | return tmpstr
14 |
15 |
16 | class Shape(nn.Module):
17 | """A nn-module to shape a tensor of shape [shape]."""
18 |
19 | def __init__(self, shape):
20 | super().__init__()
21 | self.shape = shape
22 | self.dim = len(shape)
23 |
24 | def forward(self, x):
25 | return x.view(*self.shape)
26 |
27 | def __repr__(self):
28 | tmpstr = self.__class__.__name__ + '(shape = {})'.format(self.shape)
29 | return tmpstr
30 |
31 |
32 | class Reshape(nn.Module):
33 | """A nn-module to reshape a tensor(-tuple) to a 4-dim "image"-tensor(-tuple) with [image_channels] channels."""
34 |
35 | def __init__(self, image_channels):
36 | super().__init__()
37 | self.image_channels = image_channels
38 |
39 | def forward(self, x):
40 | if type(x) == tuple:
41 | batch_size = x[0].size(0) # first dimension should be batch-dimension.
42 | image_size = int(np.sqrt(x[0].nelement() / (batch_size * self.image_channels)))
43 | return (x_item.view(batch_size, self.image_channels, image_size, image_size) for x_item in x)
44 | else:
45 | batch_size = x.size(0) # first dimension should be batch-dimension.
46 | image_size = int(np.sqrt(x.nelement() / (batch_size * self.image_channels)))
47 | return x.view(batch_size, self.image_channels, image_size, image_size)
48 |
49 | def __repr__(self):
50 | tmpstr = self.__class__.__name__ + '(channels = {})'.format(self.image_channels)
51 | return tmpstr
52 |
53 |
54 | class Flatten(nn.Module):
55 | """A nn-module to flatten a multi-dimensional tensor to 2-dim tensor."""
56 |
57 | def forward(self, x):
58 | batch_size = x.size(0) # first dimension should be batch-dimension.
59 | return x.contiguous().view(batch_size, -1)
60 |
61 | def __repr__(self):
62 | tmpstr = self.__class__.__name__ + '()'
63 | return tmpstr
64 |
--------------------------------------------------------------------------------
/CSI+CMG/CSI_model/classifier.py:
--------------------------------------------------------------------------------
1 | import CSI_model.transform_layers as TL
2 | import torch.nn as nn
3 | from CSI_model.resnet import ResNet18, ResNet34, ResNet50
4 | from CSI_model.resnet_imagenet import resnet18, resnet50
5 |
6 |
7 | def get_simclr_augmentation(P, image_size):
8 | # parameter for resizecrop
9 | resize_scale = (P.resize_factor, 1.0) # resize scaling factor
10 | if P.resize_fix: # if resize_fix is True, use same scale
11 | resize_scale = (P.resize_factor, P.resize_factor)
12 |
13 | # Align augmentation
14 | color_jitter = TL.ColorJitterLayer(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8)
15 | color_gray = TL.RandomColorGrayLayer(p=0.2)
16 | resize_crop = TL.RandomResizedCropLayer(scale=resize_scale, size=image_size)
17 |
18 | # Transform define #
19 | if P.dataset == 'imagenet': # Using RandomResizedCrop at PIL transform
20 | transform = nn.Sequential(
21 | color_jitter,
22 | color_gray,
23 | )
24 | else:
25 | transform = nn.Sequential(
26 | color_jitter,
27 | color_gray,
28 | resize_crop,
29 | )
30 |
31 | return transform
32 |
33 |
34 | def get_shift_module(P, eval=False):
35 | if P.shift_trans_type == 'rotation':
36 | shift_transform = TL.Rotation()
37 | K_shift = 4
38 | elif P.shift_trans_type == 'cutperm':
39 | shift_transform = TL.CutPerm()
40 | K_shift = 4
41 | else:
42 | shift_transform = nn.Identity()
43 | K_shift = 1
44 |
45 | if not eval and not ('sup' in P.mode):
46 | assert P.batch_size == int(128 / K_shift)
47 |
48 | return shift_transform, K_shift
49 |
50 |
51 | def get_shift_classifer(model, K_shift):
52 | model.shift_cls_layer = nn.Linear(model.last_dim, K_shift)
53 |
54 | return model
55 |
56 |
57 | def get_classifier(mode, n_classes=10):
58 | if mode == 'resnet18':
59 | classifier = ResNet18(num_classes=n_classes)
60 | elif mode == 'resnet34':
61 | classifier = ResNet34(num_classes=n_classes)
62 | elif mode == 'resnet50':
63 | classifier = ResNet50(num_classes=n_classes)
64 | elif mode == 'resnet18_imagenet':
65 | classifier = resnet18(num_classes=n_classes)
66 | elif mode == 'resnet50_imagenet':
67 | classifier = resnet50(num_classes=n_classes)
68 | else:
69 | raise NotImplementedError()
70 |
71 | return classifier
72 |
--------------------------------------------------------------------------------
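`get_simclr_augmentation` expects a config object `P`. A hypothetical invocation (our annotation; the attribute values are assumptions modeled on CSI's defaults, not values confirmed by this repository):

```python
from types import SimpleNamespace

import CSI_model.classifier as C

# P mimics the argparse namespace CSI passes around; the values are assumptions.
P = SimpleNamespace(resize_factor=0.54, resize_fix=False, dataset='cifar10')
simclr_aug = C.get_simclr_augmentation(P, image_size=(32, 32, 3))
print(simclr_aug)  # Sequential of color jitter, random grayscale, random resized crop
```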
/CSI+CMG/models/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.conv.nets import ConvLayers, DeconvLayers, ResNet18, DeconvResnet
6 | from models.utils import modules
7 |
8 |
9 | class ConditionalVAE(nn.Module):
10 | def __init__(self, z_dim=32, image_channels=1, image_size=28, class_num=10, dataset='mnist'):
11 | super(ConditionalVAE, self).__init__()
12 | self.class_num = class_num
13 | self.image_size = image_size
14 | self.image_channels = image_channels
15 |
16 | if dataset == 'mnist':
17 | self.convE = ConvLayers(image_channels)
18 | self.convD = DeconvLayers(image_channels)
19 | elif dataset == 'cifar10':
20 | self.convE = ResNet18()
21 | self.convD = DeconvResnet(image_channels)
22 | else:
23 | raise NotImplementedError
24 | self.flatten = modules.Flatten()
25 | self.fcE = nn.Linear(self.convE.out_feature_dim, 1024)
26 | self.z_dim = z_dim
27 | self.fcE_mean = nn.Linear(1024, self.z_dim)
28 | self.fcE_logvar = nn.Linear(1024, self.z_dim)
29 | self.fromZ = nn.Linear(2 * self.z_dim, 1024)
30 | self.fcD = nn.Linear(1024, self.convD.in_feature_dim)
31 | self.to_image = modules.Reshape(image_channels=self.convE.out_channels)
32 | self.device = None
33 |
34 | self.class_embed = nn.Linear(class_num, self.z_dim)
35 |
36 | def encode(self, x):
37 | hidden_x = self.convE(x)
38 | feature = self.flatten(hidden_x)
39 |
40 | hE = F.relu(self.fcE(feature))
41 |
42 | z_mean = self.fcE_mean(hE)
43 | z_logvar = self.fcE_logvar(hE)
44 |
45 | return z_mean, z_logvar, hE, hidden_x
46 |
47 | def reparameterize(self, mu, logvar):
48 | std = torch.exp(0.5 * logvar).to(self.device)
49 | z = torch.randn(std.size()).to(self.device) * std + mu.to(self.device)
50 | return z
51 |
52 | def decode(self, z, y_embed):
53 | z = torch.cat([z, y_embed], dim=1) # add label information
54 | hD = F.relu(self.fromZ(z))
55 | feature = self.fcD(hD)
56 | image_recon = self.convD(feature.view(-1, self.convD.in_channel, self.convD.in_size, self.convD.in_size))
57 |
58 | return image_recon
59 |
60 | def forward(self, x, y_tensor):
61 | mu, logvar, hE, hidden_x = self.encode(x)
62 | z = self.reparameterize(mu, logvar)
63 | y_embed = self.class_embed(y_tensor)
64 | x_recon = self.decode(z, y_embed)
65 | return mu, logvar, x_recon
66 |
--------------------------------------------------------------------------------
/models/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.conv.nets import ConvLayers, DeconvLayers, ResNet18, DeconvResnet
6 | from models.utils import modules
7 |
8 |
9 | class ConditionalVAE(nn.Module):
10 | """Variational Auto-Encoder with class conditional information"""
11 |
12 | def __init__(self, z_dim=32, image_channels=1, image_size=28, class_num=10, dataset='mnist'):
13 | super(ConditionalVAE, self).__init__()
14 | self.class_num = class_num
15 | self.image_size = image_size
16 | self.image_channels = image_channels
17 |
18 | if dataset == 'mnist':
19 | self.convE = ConvLayers(image_channels)
20 | self.convD = DeconvLayers(image_channels)
21 | elif dataset == 'cifar10':
22 | self.convE = ResNet18()
23 | self.convD = DeconvResnet(image_channels)
24 | else:
25 | raise NotImplementedError
26 | self.flatten = modules.Flatten()
27 | self.fcE = nn.Linear(self.convE.out_feature_dim, 1024)
28 | self.z_dim = z_dim
29 | self.fcE_mean = nn.Linear(1024, self.z_dim)
30 | self.fcE_logvar = nn.Linear(1024, self.z_dim)
31 | self.fromZ = nn.Linear(2 * self.z_dim, 1024)
32 | self.fcD = nn.Linear(1024, self.convD.in_feature_dim)
33 | self.to_image = modules.Reshape(image_channels=self.convE.out_channels)
34 | self.device = None
35 |
36 | self.class_embed = nn.Linear(class_num, self.z_dim)
37 |
38 | def encode(self, x):
39 | hidden_x = self.convE(x)
40 | feature = self.flatten(hidden_x)
41 |
42 | hE = F.relu(self.fcE(feature))
43 |
44 | z_mean = self.fcE_mean(hE)
45 | z_logvar = self.fcE_logvar(hE)
46 |
47 | return z_mean, z_logvar, hE, hidden_x
48 |
49 | def reparameterize(self, mu, logvar):
50 | std = torch.exp(0.5 * logvar).to(self.device)
51 | z = torch.randn(std.size()).to(self.device) * std + mu.to(self.device)
52 | return z
53 |
54 | def decode(self, z, y_embed):
55 | z = torch.cat([z, y_embed], dim=1) # add label information
56 | hD = F.relu(self.fromZ(z))
57 | feature = self.fcD(hD)
58 | image_recon = self.convD(feature.view(-1, self.convD.in_channel, self.convD.in_size, self.convD.in_size))
59 |
60 | return image_recon
61 |
62 | def forward(self, x, y_tensor):
63 | mu, logvar, hE, hidden_x = self.encode(x)
64 | z = self.reparameterize(mu, logvar)
65 | y_embed = self.class_embed(y_tensor)
66 | x_recon = self.decode(z, y_embed)
67 | return mu, logvar, x_recon
68 |
--------------------------------------------------------------------------------
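A quick smoke test of the MNIST configuration (our annotation, not repository code): with random weights, one forward pass reproduces the expected shapes.

```python
import torch
from models.vae import ConditionalVAE

model = ConditionalVAE(z_dim=32, image_channels=1, image_size=28, class_num=10, dataset='mnist')
model.device = torch.device('cpu')  # reparameterize() reads self.device

x = torch.rand(8, 1, 28, 28)                          # fake MNIST batch
y_onehot = torch.eye(10)[torch.randint(0, 10, (8,))]  # one-hot class labels
mu, logvar, recon = model(x, y_onehot)
print(mu.shape, logvar.shape, recon.shape)  # (8, 32), (8, 32), (8, 1, 28, 28)
```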
/models/conv/cnn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def weights_init(m):
5 | classname = m.__class__.__name__
6 | if classname.find('Conv') != -1:
7 | m.weight.data.normal_(0.0, 0.05)
8 | elif classname.find('BatchNorm') != -1:
9 | m.weight.data.normal_(1.0, 0.02)
10 | m.bias.data.fill_(0)
11 |
12 |
13 | class ConvEncoder(nn.Module):
14 | """9-layer CNN encoder (feature extractor)"""
15 |
16 | def __init__(self, image_channels=1):
17 | super(self.__class__, self).__init__()
18 | self.conv1 = nn.Conv2d(image_channels, 64, 3, 1, 1, bias=False)
19 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
20 | self.conv3 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
21 |
22 | self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
23 | self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
24 | self.conv6 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)
25 |
26 | self.conv7 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
27 | self.conv8 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
28 | self.conv9 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)
29 |
30 | self.bn1 = nn.BatchNorm2d(64)
31 | self.bn2 = nn.BatchNorm2d(64)
32 | self.bn3 = nn.BatchNorm2d(128)
33 |
34 | self.bn4 = nn.BatchNorm2d(128)
35 | self.bn5 = nn.BatchNorm2d(128)
36 | self.bn6 = nn.BatchNorm2d(128)
37 |
38 | self.bn7 = nn.BatchNorm2d(128)
39 | self.bn8 = nn.BatchNorm2d(128)
40 | self.bn9 = nn.BatchNorm2d(128)
41 |
42 | self.dr1 = nn.Dropout2d(0.2)
43 | self.dr2 = nn.Dropout2d(0.2)
44 | self.dr3 = nn.Dropout2d(0.2)
45 |
46 | self.apply(weights_init)
47 | self.out_channels = 128
48 | self.out_feature_dim = 128 * 4 * 4
49 |
50 | def forward(self, x):
51 | x = self.dr1(x)
52 | x = self.conv1(x)
53 | x = self.bn1(x)
54 | x = nn.LeakyReLU(0.2)(x)
55 | x = self.conv2(x)
56 | x = self.bn2(x)
57 | x = nn.LeakyReLU(0.2)(x)
58 | x = self.conv3(x)
59 | x = self.bn3(x)
60 | x = nn.LeakyReLU(0.2)(x)
61 |
62 | x = self.dr2(x)
63 | x = self.conv4(x)
64 | x = self.bn4(x)
65 | x = nn.LeakyReLU(0.2)(x)
66 | x = self.conv5(x)
67 | x = self.bn5(x)
68 | x = nn.LeakyReLU(0.2)(x)
69 | x = self.conv6(x)
70 | x = self.bn6(x)
71 | x = nn.LeakyReLU(0.2)(x)
72 |
73 | x = self.dr3(x)
74 | x = self.conv7(x)
75 | x = self.bn7(x)
76 | x = nn.LeakyReLU(0.2)(x)
77 | x = self.conv8(x)
78 | x = self.bn8(x)
79 | x = nn.LeakyReLU(0.2)(x)
80 | x = self.conv9(x)
81 | x = self.bn9(x)
82 | x = nn.LeakyReLU(0.2)(x)
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
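A shape check (our annotation): the three stride-2 convolutions halve the spatial size three times, so a 28x28 MNIST image yields 128 channels at 4x4, matching `out_feature_dim = 128 * 4 * 4`.

```python
import torch
from models.conv.cnn import ConvEncoder

enc = ConvEncoder(image_channels=1)
enc.eval()  # disable the Dropout2d layers for a deterministic pass
print(enc(torch.rand(2, 1, 28, 28)).shape)  # torch.Size([2, 128, 4, 4])
```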
/CSI+CMG/models/conv/cnn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def weights_init(m):
5 | classname = m.__class__.__name__
6 | if classname.find('Conv') != -1:
7 | m.weight.data.normal_(0.0, 0.05)
8 | elif classname.find('BatchNorm') != -1:
9 | m.weight.data.normal_(1.0, 0.02)
10 | m.bias.data.fill_(0)
11 |
12 |
13 | class ConvEncoder(nn.Module):
14 | """9-layer CNN encoder (feature extractor)"""
15 |
16 | def __init__(self, image_channels=1):
17 | super(self.__class__, self).__init__()
18 | self.conv1 = nn.Conv2d(image_channels, 64, 3, 1, 1, bias=False)
19 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
20 | self.conv3 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
21 |
22 | self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
23 | self.conv5 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
24 | self.conv6 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)
25 |
26 | self.conv7 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
27 | self.conv8 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
28 | self.conv9 = nn.Conv2d(128, 128, 3, 2, 1, bias=False)
29 |
30 | self.bn1 = nn.BatchNorm2d(64)
31 | self.bn2 = nn.BatchNorm2d(64)
32 | self.bn3 = nn.BatchNorm2d(128)
33 |
34 | self.bn4 = nn.BatchNorm2d(128)
35 | self.bn5 = nn.BatchNorm2d(128)
36 | self.bn6 = nn.BatchNorm2d(128)
37 |
38 | self.bn7 = nn.BatchNorm2d(128)
39 | self.bn8 = nn.BatchNorm2d(128)
40 | self.bn9 = nn.BatchNorm2d(128)
41 |
42 | self.dr1 = nn.Dropout2d(0.2)
43 | self.dr2 = nn.Dropout2d(0.2)
44 | self.dr3 = nn.Dropout2d(0.2)
45 |
46 | self.apply(weights_init)
47 | self.out_channels = 128
48 | self.out_feature_dim = 128 * 4 * 4
49 |
50 | def forward(self, x):
51 | x = self.dr1(x)
52 | x = self.conv1(x)
53 | x = self.bn1(x)
54 | x = nn.LeakyReLU(0.2)(x)
55 | x = self.conv2(x)
56 | x = self.bn2(x)
57 | x = nn.LeakyReLU(0.2)(x)
58 | x = self.conv3(x)
59 | x = self.bn3(x)
60 | x = nn.LeakyReLU(0.2)(x)
61 |
62 | x = self.dr2(x)
63 | x = self.conv4(x)
64 | x = self.bn4(x)
65 | x = nn.LeakyReLU(0.2)(x)
66 | x = self.conv5(x)
67 | x = self.bn5(x)
68 | x = nn.LeakyReLU(0.2)(x)
69 | x = self.conv6(x)
70 | x = self.bn6(x)
71 | x = nn.LeakyReLU(0.2)(x)
72 |
73 | x = self.dr3(x)
74 | x = self.conv7(x)
75 | x = self.bn7(x)
76 | x = nn.LeakyReLU(0.2)(x)
77 | x = self.conv8(x)
78 | x = self.bn8(x)
79 | x = nn.LeakyReLU(0.2)(x)
80 | x = self.conv9(x)
81 | x = self.bn9(x)
82 | x = nn.LeakyReLU(0.2)(x)
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
/models/main_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from models.conv.cnn import ConvEncoder
5 | from models.conv.nets import ResNet18
6 | from models.fc.nets import MLP
7 | from models.utils import loss_functions as lf
8 | from models.utils import modules
9 |
10 |
11 | class MainModel(nn.Module):
12 | """
13 | feature_extractor(CNN) -> classifier (MLP)
14 | """
15 |
16 | def __init__(self, image_size, image_channels, classes, dataset='mnist'):
17 | super(MainModel, self).__init__()
18 | self.image_size = image_size
19 | self.image_channels = image_channels
20 | self.classes = classes
21 |
22 | # for encoder
23 | if dataset == 'mnist':
24 | self.convE = ConvEncoder(image_channels=image_channels)
25 | elif dataset == 'cifar10':
26 | self.convE = ResNet18()
27 | else:
28 | raise NotImplementedError
29 | self.flatten = modules.Flatten()
30 |
31 | # classifier
32 | self.classifier = MLP(self.convE.out_feature_dim, classes)
33 |
34 | self.optimizer = None # needs to be set before training starts
35 |
36 | self.device = None # needs to be set before using the model
37 |
38 |     # --------- FORWARD FUNCTIONS ---------#
39 | def encode(self, x):
40 | """
41 | pass input through feed-forward connections to get [image_features]
42 | """
43 | # Forward-pass through conv-layers
44 | hidden_x = self.convE(x)
45 |
46 | return hidden_x
47 |
48 | def classify(self, x):
49 | """
50 |         For input [x] (image or extracted "internal" image features),
51 | return predicted scores (<2D tensor> [batch_size] * [classes])
52 | """
53 | result = self.classifier(x)
54 | return result
55 |
56 | def forward(self, x):
57 | """
58 | Forward function to propagate [x] through the encoder and the classifier.
59 | """
60 | hidden_x = self.encode(x)
61 | prediction = self.classifier(hidden_x)
62 | return prediction
63 |
64 | # ------------------TRAINING FUNCTIONS----------------------#
65 | def train_a_batch(self, x, y):
66 | """
67 | Train model for one batch ([x], [y])
68 | """
69 | # Set model to training-mode
70 | self.train()
71 |
72 | # Reset optimizer
73 | self.optimizer.zero_grad()
74 |
75 | # Run the model
76 | hidden_x = self.encode(x)
77 | prediction = self.classifier(hidden_x)
78 | predL = F.cross_entropy(prediction, y, reduction='none')
79 | loss = lf.weighted_average(predL, weights=None, dim=0)
80 |
81 | loss.backward()
82 |
83 | self.optimizer.step()
84 |
85 | return loss.item()
86 |
--------------------------------------------------------------------------------
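A minimal usage sketch (our annotation, not repository code): `optimizer` must be attached before `train_a_batch` is called, exactly as `training/train_classifier.py` does.

```python
import torch
from models.main_model import MainModel

model = MainModel(image_size=28, image_channels=1, classes=11, dataset='mnist')
model.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.rand(16, 1, 28, 28)    # fake image batch
y = torch.randint(0, 11, (16,))  # fake integer labels
print('batch loss:', model.train_a_batch(x, y))
```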
/datasets/datasetsHelper.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torch.utils.data.dataset import Subset
4 | from torchvision import datasets, transforms
5 |
6 | root_path = os.path.dirname(__file__)
7 |
8 |
9 | def get_subclass_dataset(dataset, classes):
10 | if not isinstance(classes, list):
11 | classes = [classes]
12 | indices = []
13 | for idx, data in enumerate(dataset):
14 | if data[1] in classes:
15 | indices.append(idx)
16 |
17 | dataset = Subset(dataset, indices)
18 | return dataset
19 |
20 |
21 | def get_dataset(dataset, train_transform, test_transform, download=False, seen=None):
22 | """Get datasets for setting 1 (OOD Detection on the Same Dataset)."""
23 |
24 | if dataset == 'cifar10':
25 | DATA_PATH = os.path.join(root_path, 'cifar10')
26 | class_idx = [int(num) for num in seen]
27 | for i in range(10):
28 | if i in class_idx:
29 | continue
30 | class_idx.append(i)
31 | train_set = datasets.CIFAR10(DATA_PATH, train=True, download=download, transform=train_transform,
32 | target_transform=lambda x: class_idx.index(x))
33 | test_set = datasets.CIFAR10(DATA_PATH, train=False, download=download, transform=test_transform,
34 | target_transform=lambda x: class_idx.index(x))
35 | seen_class_idx = [0, 1, 2, 3, 4, 5]
36 | unseen_class_idx = [6, 7, 8, 9]
37 |
38 | train_set = get_subclass_dataset(train_set, seen_class_idx)
39 | test_set_seen = get_subclass_dataset(test_set, seen_class_idx)
40 | test_set_unseen = get_subclass_dataset(test_set, unseen_class_idx)
41 |
42 | elif dataset == 'mnist':
43 | DATA_PATH = os.path.join(root_path, 'mnist')
44 | class_idx = [int(num) for num in seen]
45 | for i in range(10):
46 | if i in class_idx:
47 | continue
48 | class_idx.append(i)
49 | train_set = datasets.MNIST(DATA_PATH, train=True, download=download, transform=train_transform,
50 | target_transform=lambda x: class_idx.index(x))
51 | test_set = datasets.MNIST(DATA_PATH, train=False, download=download, transform=test_transform,
52 | target_transform=lambda x: class_idx.index(x))
53 | seen_class_idx = [0, 1, 2, 3, 4, 5]
54 | unseen_class_idx = [6, 7, 8, 9]
55 |
56 | train_set = get_subclass_dataset(train_set, seen_class_idx)
57 | test_set_seen = get_subclass_dataset(test_set, seen_class_idx)
58 | test_set_unseen = get_subclass_dataset(test_set, unseen_class_idx)
59 |
60 | else:
61 | raise NotImplementedError
62 |
63 | return train_set, test_set_seen, test_set_unseen
64 |
65 |
66 | def get_ood_dataset(dataset):
67 | """Get datasets for setting 2 (OOD Detection on Different Datasets)."""
68 |
69 | if dataset == 'SVHN':
70 | dir = os.path.join(root_path, 'svhn')
71 | data = datasets.SVHN(root=dir, split='test', download=True, transform=transforms.ToTensor())
72 | elif dataset == 'LSUN':
73 | dir = os.path.join(root_path, 'LSUN_resize')
74 | data = datasets.ImageFolder(dir, transform=transforms.ToTensor())
75 | elif dataset == 'tinyImageNet':
76 | dir = os.path.join(root_path, 'Imagenet_resize')
77 | data = datasets.ImageFolder(dir, transform=transforms.ToTensor())
78 | elif dataset == 'LSUN-FIX':
79 | dir = os.path.join(root_path, 'LSUN_fix')
80 | data = datasets.ImageFolder(dir, transform=transforms.ToTensor())
81 | elif dataset == 'ImageNet-FIX':
82 | dir = os.path.join(root_path, 'Imagenet_fix')
83 | data = datasets.ImageFolder(dir, transform=transforms.ToTensor())
84 | elif dataset == 'CIFAR100':
85 | dir = os.path.join(root_path, 'cifar100')
86 | data = datasets.CIFAR100(
87 | root=dir, train=False, transform=transforms.ToTensor(), download=True)
88 | else:
89 | raise NotImplementedError
90 |
91 | return data
92 |
--------------------------------------------------------------------------------
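The `seen` argument drives a label remapping: the `target_transform` sends the seen classes to indices 0-5 and the remaining classes to 6-9, which is why the hard-coded `seen_class_idx`/`unseen_class_idx` splits work for every partition. A small illustration (our annotation, not repository code):

```python
seen = '234567'
class_idx = [int(c) for c in seen]  # [2, 3, 4, 5, 6, 7]
for i in range(10):                 # append the unseen classes in order
    if i not in class_idx:
        class_idx.append(i)         # -> [2, 3, 4, 5, 6, 7, 0, 1, 8, 9]

print([class_idx.index(c) for c in range(10)])
# [6, 7, 0, 1, 2, 3, 4, 5, 8, 9]: seen classes 2-7 map to 0-5,
# while the unseen classes 0, 1, 8, 9 land in 6-9
```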
/models/conv/nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class ConvLayers(nn.Module):
7 | """Convolutional feature extractor model for (natural) images."""
8 |
9 | def __init__(self, image_channels):
10 | super(ConvLayers, self).__init__()
11 | self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=(4, 4), stride=2, padding=(15, 15))
12 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=2, padding=(15, 15))
13 | self.out_channels = 128
14 | self.out_feature_dim = 128 * 28 * 28
15 |
16 | def forward(self, x):
17 | x = F.relu(self.conv1(x))
18 | feature = F.relu(self.conv2(x))
19 |
20 | return feature
21 |
22 |
23 | class DeconvLayers(nn.Module):
24 | """'Deconvolutional' feature decoder model for (natural) images."""
25 |
26 | def __init__(self, image_channels):
27 | super(DeconvLayers, self).__init__()
28 | self.image_channels = image_channels
29 | self.in_channel = 128
30 | self.in_size = 7
31 | self.in_feature_dim = 7 * 7 * 128
32 | self.deconv1 = nn.ConvTranspose2d(
33 | in_channels=128, out_channels=64, kernel_size=4, padding=1, stride=2)
34 | self.deconv2 = nn.ConvTranspose2d(
35 | in_channels=64, out_channels=self.image_channels, kernel_size=4, padding=1, stride=2)
36 |
37 | def forward(self, x):
38 | x = F.relu(self.deconv1(x))
39 | x = torch.sigmoid(self.deconv2(x))
40 |
41 | return x
42 |
43 |
44 | # ---------------------------------------------------------------------------------------------------
45 |
46 |
47 | class ResBlock(nn.Module):
48 | """
49 | Input: [batch_size] x [dim] x [image_size] x [image_size] tensor
50 | Output: [batch_size] x [dim] x [image_size] x [image_size] tensor
51 | """
52 |
53 | def __init__(self, dim):
54 | super().__init__()
55 | self.block = nn.Sequential(
56 | nn.ReLU(True),
57 | nn.Conv2d(dim, dim, 3, 1, 1),
58 | nn.BatchNorm2d(dim),
59 | nn.ReLU(True),
60 | nn.Conv2d(dim, dim, 1),
61 | nn.BatchNorm2d(dim)
62 | )
63 |
64 | def forward(self, x):
65 | return x + self.block(x)
66 |
67 |
68 | class BasicBlock(nn.Module):
69 | expansion = 1
70 |
71 | def __init__(self, in_planes, planes, stride=1):
72 | super(BasicBlock, self).__init__()
73 | self.conv1 = nn.Conv2d(
74 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
75 | self.bn1 = nn.BatchNorm2d(planes)
76 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
77 | stride=1, padding=1, bias=False)
78 | self.bn2 = nn.BatchNorm2d(planes)
79 |
80 | self.shortcut = nn.Sequential()
81 | if stride != 1 or in_planes != self.expansion * planes:
82 | self.shortcut = nn.Sequential(
83 | nn.Conv2d(in_planes, self.expansion * planes,
84 | kernel_size=1, stride=stride, bias=False),
85 | nn.BatchNorm2d(self.expansion * planes)
86 | )
87 |
88 | def forward(self, x):
89 | out = F.relu(self.bn1(self.conv1(x)))
90 | out = self.bn2(self.conv2(out))
91 | out += self.shortcut(x)
92 | out = F.relu(out)
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 | def __init__(self, block, num_blocks, channel_num=3):
98 | super(ResNet, self).__init__()
99 | self.in_planes = 64
100 |
101 | self.conv1 = nn.Conv2d(channel_num, 64, kernel_size=3,
102 | stride=1, padding=1, bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
105 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
106 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
107 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
108 | self.out_channels = 512
109 | self.out_feature_dim = 512 * block.expansion
110 |
111 | def _make_layer(self, block, planes, num_blocks, stride):
112 | strides = [stride] + [1] * (num_blocks - 1)
113 | layers = []
114 | for stride in strides:
115 | layers.append(block(self.in_planes, planes, stride))
116 | self.in_planes = planes * block.expansion
117 | return nn.Sequential(*layers)
118 |
119 | def forward(self, x):
120 | out = F.relu(self.bn1(self.conv1(x)))
121 | out = self.layer1(out)
122 | out = self.layer2(out)
123 | out = self.layer3(out)
124 | out = self.layer4(out)
125 | out = F.avg_pool2d(out, 4)
126 | return out
127 |
128 |
129 | def ResNet18(channel_num=3):
130 | return ResNet(BasicBlock, [2, 2, 2, 2], channel_num=channel_num)
131 |
132 |
133 | class DeconvResnet(nn.Module):
134 | """'Deconvolutional' feature decoder model for (natural) images using ResBlock as the backbone"""
135 |
136 | def __init__(self, channel_num, dim=512):
137 | super(DeconvResnet, self).__init__()
138 | self.image_channels = channel_num
139 | self.in_channel = dim
140 | self.in_size = 4
141 | self.in_feature_dim = dim * 4 * 4
142 | self.decoder = nn.Sequential(
143 | ResBlock(dim),
144 | ResBlock(dim),
145 |
146 | nn.ReLU(True),
147 | nn.ConvTranspose2d(dim, dim, 4, 2, 1),
148 |
149 | nn.BatchNorm2d(dim),
150 | nn.ReLU(True),
151 | nn.ConvTranspose2d(dim, dim, 4, 2, 1),
152 |
153 | nn.BatchNorm2d(dim),
154 | nn.ReLU(True),
155 | nn.ConvTranspose2d(dim, channel_num, 4, 2, 1)
156 | )
157 |
158 | def forward(self, x):
159 | x = self.decoder(x)
160 | x = torch.sigmoid(x)
161 |
162 | return x
163 |
--------------------------------------------------------------------------------
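A shape check for the MNIST encoder/decoder pair (our annotation): with kernel 4, stride 2, and padding 15, `ConvLayers` keeps the 28x28 spatial size (hence `out_feature_dim = 128 * 28 * 28`), while `DeconvLayers` maps 7x7 feature maps back to 28x28 images.

```python
import torch
from models.conv.nets import ConvLayers, DeconvLayers

enc, dec = ConvLayers(image_channels=1), DeconvLayers(image_channels=1)
print(enc(torch.rand(2, 1, 28, 28)).shape)  # torch.Size([2, 128, 28, 28])
print(dec(torch.rand(2, 128, 7, 7)).shape)  # torch.Size([2, 1, 28, 28])
```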
/CSI+CMG/models/conv/nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class ConvLayers(nn.Module):
7 | """Convolutional feature extractor model for (natural) images."""
8 |
9 | def __init__(self, image_channels):
10 | super(ConvLayers, self).__init__()
11 | self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=(4, 4), stride=2, padding=(15, 15))
12 | self.conv2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=2, padding=(15, 15))
13 | self.out_channels = 128
14 | self.out_feature_dim = 128 * 28 * 28
15 |
16 | def forward(self, x):
17 | x = F.relu(self.conv1(x))
18 | feature = F.relu(self.conv2(x))
19 |
20 | return feature
21 |
22 |
23 | class DeconvLayers(nn.Module):
24 | """'Deconvolutional' feature decoder model for (natural) images."""
25 |
26 | def __init__(self, image_channels):
27 | super(DeconvLayers, self).__init__()
28 | self.image_channels = image_channels
29 | self.in_channel = 128
30 | self.in_size = 7
31 | self.in_feature_dim = 7 * 7 * 128
32 | self.deconv1 = nn.ConvTranspose2d(
33 | in_channels=128, out_channels=64, kernel_size=4, padding=1, stride=2)
34 | self.deconv2 = nn.ConvTranspose2d(
35 | in_channels=64, out_channels=self.image_channels, kernel_size=4, padding=1, stride=2)
36 |
37 | def forward(self, x):
38 | x = F.relu(self.deconv1(x))
39 | x = torch.sigmoid(self.deconv2(x))
40 |
41 | return x
42 |
43 |
44 | # ---------------------------------------------------------------------------------------------------
45 |
46 |
47 | class ResBlock(nn.Module):
48 | """
49 | Input: [batch_size] x [dim] x [image_size] x [image_size] tensor
50 | Output: [batch_size] x [dim] x [image_size] x [image_size] tensor
51 | """
52 |
53 | def __init__(self, dim):
54 | super().__init__()
55 | self.block = nn.Sequential(
56 | nn.ReLU(True),
57 | nn.Conv2d(dim, dim, 3, 1, 1),
58 | nn.BatchNorm2d(dim),
59 | nn.ReLU(True),
60 | nn.Conv2d(dim, dim, 1),
61 | nn.BatchNorm2d(dim)
62 | )
63 |
64 | def forward(self, x):
65 | return x + self.block(x)
66 |
67 |
68 | class BasicBlock(nn.Module):
69 | expansion = 1
70 |
71 | def __init__(self, in_planes, planes, stride=1):
72 | super(BasicBlock, self).__init__()
73 | self.conv1 = nn.Conv2d(
74 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
75 | self.bn1 = nn.BatchNorm2d(planes)
76 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
77 | stride=1, padding=1, bias=False)
78 | self.bn2 = nn.BatchNorm2d(planes)
79 |
80 | self.shortcut = nn.Sequential()
81 | if stride != 1 or in_planes != self.expansion * planes:
82 | self.shortcut = nn.Sequential(
83 | nn.Conv2d(in_planes, self.expansion * planes,
84 | kernel_size=1, stride=stride, bias=False),
85 | nn.BatchNorm2d(self.expansion * planes)
86 | )
87 |
88 | def forward(self, x):
89 | out = F.relu(self.bn1(self.conv1(x)))
90 | out = self.bn2(self.conv2(out))
91 | out += self.shortcut(x)
92 | out = F.relu(out)
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 | def __init__(self, block, num_blocks, channel_num=3):
98 | super(ResNet, self).__init__()
99 | self.in_planes = 64
100 |
101 | self.conv1 = nn.Conv2d(channel_num, 64, kernel_size=3,
102 | stride=1, padding=1, bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
105 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
106 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
107 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
108 | self.out_channels = 512
109 | self.out_feature_dim = 512 * block.expansion
110 |
111 | def _make_layer(self, block, planes, num_blocks, stride):
112 | strides = [stride] + [1] * (num_blocks - 1)
113 | layers = []
114 | for stride in strides:
115 | layers.append(block(self.in_planes, planes, stride))
116 | self.in_planes = planes * block.expansion
117 | return nn.Sequential(*layers)
118 |
119 | def forward(self, x):
120 | out = F.relu(self.bn1(self.conv1(x)))
121 | out = self.layer1(out)
122 | out = self.layer2(out)
123 | out = self.layer3(out)
124 | out = self.layer4(out)
125 | out = F.avg_pool2d(out, 4)
126 | return out
127 |
128 |
129 | def ResNet18(channel_num=3):
130 | return ResNet(BasicBlock, [2, 2, 2, 2], channel_num=channel_num)
131 |
132 |
133 | class DeconvResnet(nn.Module):
134 | """'Deconvolutional' feature decoder model for (natural) images using ResBlock as the backbone"""
135 |
136 | def __init__(self, channel_num, dim=512):
137 | super(DeconvResnet, self).__init__()
138 | self.image_channels = channel_num
139 | self.in_channel = dim
140 | self.in_size = 4
141 | self.in_feature_dim = dim * 4 * 4
142 | self.decoder = nn.Sequential(
143 | ResBlock(dim),
144 | ResBlock(dim),
145 |
146 | nn.ReLU(True),
147 | nn.ConvTranspose2d(dim, dim, 4, 2, 1),
148 |
149 | nn.BatchNorm2d(dim),
150 | nn.ReLU(True),
151 | nn.ConvTranspose2d(dim, dim, 4, 2, 1),
152 |
153 | nn.BatchNorm2d(dim),
154 | nn.ReLU(True),
155 | nn.ConvTranspose2d(dim, channel_num, 4, 2, 1)
156 | )
157 |
158 | def forward(self, x):
159 | x = self.decoder(x)
160 | x = torch.sigmoid(x)
161 |
162 | return x
163 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """CMG Stage 1: IND classifier building & CVAE training."""
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torchvision import datasets, transforms
9 |
10 | from datasets.datasetsHelper import get_dataset
11 | from models.main_model import MainModel
12 | from models.vae import ConditionalVAE
13 | from training.train_classifier import train_classifier
14 | from training.train_cvae import train_cvae
15 | from utils import get_args
16 |
17 | args = get_args()
18 |
19 |
20 | def setup_seed(seed):
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed)
24 | random.seed(seed)
25 | torch.backends.cudnn.deterministic = True
26 |
27 |
28 | setup_seed(args.seed)
29 |
30 | batch_size = 512
31 |
32 | # gpu
33 | device = torch.device(args.device if torch.cuda.is_available() else "cpu")
34 | print(device)
35 |
36 |
37 | def main():
38 | # prepare data
39 | if args.task == 'same_dataset_mnist':
40 | if args.partition == 'partition1':
41 | train_data, _, _ = get_dataset('mnist', transforms.ToTensor(), transforms.ToTensor(), seen='012345')
42 | elif args.partition == 'partition2':
43 | train_data, _, _ = get_dataset('mnist', transforms.ToTensor(), transforms.ToTensor(), seen='123456')
44 | elif args.partition == 'partition3':
45 | train_data, _, _ = get_dataset('mnist', transforms.ToTensor(), transforms.ToTensor(), seen='234567')
46 | elif args.partition == 'partition4':
47 | train_data, _, _ = get_dataset('mnist', transforms.ToTensor(), transforms.ToTensor(), seen='345678')
48 | elif args.partition == 'partition5':
49 | train_data, _, _ = get_dataset('mnist', transforms.ToTensor(), transforms.ToTensor(), seen='456789')
50 | else:
51 | raise NotImplementedError
52 | elif args.task == 'same_dataset_cifar10':
53 | if args.command == 'train_classifier':
54 | train_transform = transforms.Compose([
55 | transforms.RandomCrop(32, padding=4),
56 | transforms.RandomHorizontalFlip(),
57 | transforms.ToTensor(),
58 | ])
59 | elif args.command == 'train_cvae':
60 | train_transform = transforms.ToTensor()
61 |
62 | if args.partition == 'partition1':
63 | train_data, _, _ = get_dataset('cifar10', train_transform=train_transform,
64 | test_transform=transforms.ToTensor(), seen='012345')
65 | elif args.partition == 'partition2':
66 | train_data, _, _ = get_dataset('cifar10', train_transform=train_transform,
67 | test_transform=transforms.ToTensor(), seen='123456')
68 | elif args.partition == 'partition3':
69 | train_data, _, _ = get_dataset('cifar10', train_transform=train_transform,
70 | test_transform=transforms.ToTensor(), seen='234567')
71 | elif args.partition == 'partition4':
72 | train_data, _, _ = get_dataset('cifar10', train_transform=train_transform,
73 | test_transform=transforms.ToTensor(), seen='345678')
74 | elif args.partition == 'partition5':
75 | train_data, _, _ = get_dataset('cifar10', train_transform=train_transform,
76 | test_transform=transforms.ToTensor(), seen='456789')
77 | else:
78 | raise NotImplementedError
79 | elif args.task == 'different_dataset':
80 | if args.command == 'train_classifier':
81 | train_transform = transforms.Compose([
82 | transforms.RandomCrop(32, padding=4),
83 | transforms.RandomHorizontalFlip(),
84 | transforms.ToTensor(),
85 | ])
86 | elif args.command == 'train_cvae':
87 | train_transform = transforms.ToTensor()
88 |
89 | root_path = os.path.dirname(__file__)
90 | data_path = os.path.join(root_path, 'datasets/cifar10')
91 | train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
92 | else:
93 | raise NotImplementedError
94 |
95 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
96 |
97 | # CMG stage 1
98 | if args.command == 'train_classifier':
99 | if args.task == 'same_dataset_mnist':
100 | model = MainModel(28, 1, 11, dataset='mnist')
101 | model.to(device)
102 | train_classifier(model, train_loader, device, args.params_dict_name, dataset='mnist')
103 | elif args.task == 'same_dataset_cifar10':
104 | model = MainModel(32, 3, 11, dataset='cifar10')
105 | model.to(device)
106 | train_classifier(model, train_loader, device, args.params_dict_name, dataset='cifar10')
107 | elif args.task == 'different_dataset':
108 | model = MainModel(32, 3, 110, dataset='cifar10')
109 | model.to(device)
110 | train_classifier(model, train_loader, device, args.params_dict_name, dataset='cifar10')
111 |
112 | elif args.command == 'train_cvae':
113 | if args.task == 'same_dataset_mnist':
114 | model = ConditionalVAE(image_channels=1, image_size=28, dataset='mnist')
115 | model.device = device
116 | model.to(device)
117 | train_cvae(model, train_loader, device, args.params_dict_name, dataset='mnist')
118 | elif args.task == 'same_dataset_cifar10':
119 | model = ConditionalVAE(image_channels=3, image_size=32, dataset='cifar10')
120 | model.device = device
121 | model.to(device)
122 | train_cvae(model, train_loader, device, args.params_dict_name, dataset='cifar10')
123 | elif args.task == 'different_dataset':
124 | model = ConditionalVAE(image_channels=3, image_size=32, dataset='cifar10')
125 | model.device = device
126 | model.to(device)
127 | train_cvae(model, train_loader, device, args.params_dict_name, dataset='cifar10')
128 |
129 | else:
130 | raise NotImplementedError
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # CMG: A Class-Mixed Generation Approach to Out-of-Distribution Detection
2 |
3 | This repository contains the code for our ECML'22 paper [CMG: A Class-Mixed Generation Approach to Out-of-Distribution
4 | Detection](https://2022.ecmlpkdd.org/wp-content/uploads/2022/09/sub_531.pdf) by Mengyu Wang*, Yijia Shao*, Haowei Lin, Wenpeng Hu, and Bing Liu.
5 |
6 | ## Overview
7 |
8 | We propose CMG (*Class-Mixed Generation*) for efficient out-of-distribution (OOD) detection. CMG uses a CVAE to
9 | generate pseudo-OOD training samples by mixing class information in the latent space. The pseudo-OOD data are then
10 | used to fine-tune a classifier in the second of the two training stages. By using different loss functions, we
11 | obtain two versions of the CMG system, CMG-softmax (CMG-s) and CMG-energy (CMG-e). The figure below illustrates our CMG system.
12 |
13 | ![The CMG system](./figure/model.png)
14 |
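The class-mixing step at the heart of CMG is compact. The sketch below is condensed from `generate_pseudo_data` in [training/fine_tune.py](./training/fine_tune.py): two randomly chosen class embeddings are mixed elementwise with a random 0/1 mask, and the CVAE decoder is fed a latent sample with enlarged variance to produce pseudo-OOD images.

```python
import torch

def generate_pseudo_ood(vae, class_num, scalar, n, device):
    # one-hot labels for two randomly chosen classes per sample
    y1 = torch.zeros(n, vae.class_num, device=device)
    y2 = torch.zeros(n, vae.class_num, device=device)
    y1.scatter_(1, torch.randint(0, class_num, (n, 1), device=device), 1)
    y2.scatter_(1, torch.randint(0, class_num, (n, 1), device=device), 1)

    # mix the two class embeddings elementwise with a random 0/1 mask
    e1, e2 = vae.class_embed(y1), vae.class_embed(y2)
    mask = torch.randint(0, 2, e1.shape, device=device)
    class_embed = torch.where(mask == 0, e1, e2)

    # sample z ~ N(0, scalar^2 I) and decode into pseudo-OOD images
    z = torch.randn(n, vae.z_dim, device=device) * scalar
    return vae.decode(z, class_embed)
```
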
15 | ## Requirements
16 |
17 | ### Environments
18 |
19 | The required packages are as follows:
20 |
21 | - python 3.5
22 | - torch 1.2
23 | - torchvision 0.4
24 | - CUDA 10.0
25 | - scikit-learn 0.22
26 |
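For example, a matching environment can be installed with pip (pick the torch build that matches your CUDA version):

```
pip install torch==1.2.0 torchvision==0.4.0 scikit-learn==0.22
```
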
27 | ### Datasets
28 |
29 | For Setting 1 - OOD Detection on the Same Dataset, this repository supports MNIST and CIFAR10, which can be directly
30 | downloaded through `torchvision`.
31 |
32 | For Setting 2 - OOD Detection on Different Datasets, this repository supports CIFAR10 as IND data and SVHN / LSUN /
33 | tinyImageNet / LSUN-FIX / ImageNet-FIX / CIFAR-100 as OOD data. SVHN and CIFAR-100 can be directly downloaded
34 | through `torchvision`. The remaining datasets were processed by the CSI authors and you can download
35 | them [here](https://github.com/alinlab/CSI).
36 |
37 | Please download the datasets to [./datasets](./datasets) and rename the files. See [./datasets/datasetsHelper.py](./datasets/datasetsHelper.py) and our paper for more details.
38 |
39 | ## Training CMG
40 |
41 | ### CMG Stage 1
42 |
43 | CMG Stage 1 involves IND classifier building and CVAE training.
44 |
45 | The code runs on a single GPU by default; you can assign a specific GPU on the command line.
46 |
47 | For more details, please view [utils.py](./utils.py).
48 |
49 | #### Train IND classifier
50 |
51 | To train IND classifier on MNIST for Setting 1, run this command:
52 |
53 | ```
54 | python -m train \
55 | --task 'same_dataset_mnist' \
56 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \
57 | --command 'train_classifier' \
58 | --device {the available GPU in your cluster, e.g., cuda:0} \
59 | --params-dict-name {checkpoint name, e.g., './ckpt/main_model_partition1.pkl'}
60 | ```
61 |
62 | To train IND classifier on CIFAR10 for Setting 1, run this command:
63 |
64 | ```
65 | python -m train \
66 | --task 'same_dataset_cifar10' \
67 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \
68 | --command 'train_classifier' \
69 | --device {the available GPU in your cluster, e.g., cuda:0} \
70 | --params-dict-name {checkpoint name, e.g., './ckpt/main_model_partition1.pkl'}
71 | ```
72 |
73 | To train IND classifier on CIFAR10 for Setting 2, run this command:
74 |
75 | ```
76 | python -m train \
77 | --task 'different_dataset' \
78 | --command 'train_classifier' \
79 | --device {the available GPU in your cluster, e.g., cuda:0} \
80 | --params-dict-name {checkpoint name, e.g., './ckpt/main_model_different_dataset.pkl'}
81 | ```
82 |
83 | #### Train CVAE
84 |
85 | To train CVAE on MNIST for Setting 1, run this command:
86 |
87 | ```
88 | python -m train \
89 | --task 'same_dataset_mnist' \
90 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \
91 | --command 'train_cvae' \
92 | --device {the available GPU in your cluster, e.g., cuda:0} \
93 | --params-dict-name {checkpoint name, e.g., './ckpt/cvae_partition1.pkl'}
94 | ```
95 |
96 | To train CVAE on CIFAR10 for Setting 1, run this command:
97 |
98 | ```
99 | python -m train \
100 | --task 'same_dataset_cifar10' \
101 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \
102 | --command 'train_cvae' \
103 | --device {the available GPU in your cluster, e.g., cuda:0} \
104 | --params-dict-name {checkpoint name, e.g., './ckpt/cvae_partition1.pkl'}
105 | ```
106 |
107 | To train CVAE on CIFAR10 for Setting 2, run this command:
108 |
109 | ```
110 | python -m train \
111 | --task 'different_dataset' \
112 | --command 'train_cvae' \
113 | --device {the available GPU in your cluster, e.g., cuda:0} \
114 | --params-dict-name {checkpoint name, e.g., './ckpt/cvae_different_dataset.pkl'}
115 | ```
116 |
117 |
118 | ### CMG Stage 2 and Evaluation
119 |
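CMG Stage 2 fine-tunes the classification head with IND data and pseudo-OOD data generated by the trained CVAE (see [eval.py](./eval.py) and [training/fine_tune.py](./training/fine_tune.py)). For reference, CMG-e scores a sample by the energy of the classifier logits, mirroring `energy_result` in this repository:

```python
import torch

def energy_score(logits):
    # higher energy => more OOD-like
    return -torch.logsumexp(logits, dim=1)
```
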
120 | To perform CMG Stage 2 and get the final result on MNIST for Setting 1, run this command:
121 |
122 | ```
123 | python -m eval \
124 | --task 'same_dataset_mnist' \
125 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \
126 | --device {the available GPU in your cluster, e.g., cuda:0} \
127 | --params-dict-name {main model checkpoint name} \
128 | --params-dict-name2 {cvae checkpoint name} \
129 | --mode {'CMG-energy'/'CMG-softmax'}
130 | ```
131 |
132 | To perform CMG Stage 2 and get the final result on CIFAR10 for Setting 1, run this command:
133 |
134 | ```
135 | python -m eval \
136 | --task 'same_dataset_cifar10' \
137 | --partition {'partition1'/'partition2'/'partition3'/'partition4'/'partition5'} \
138 | --device {the available GPU in your cluster, e.g., cuda:0} \
139 | --params-dict-name {main model checkpoint name} \
140 | --params-dict-name2 {cvae checkpoint name} \
141 | --mode {'CMG-energy'/'CMG-softmax'}
142 | ```
143 |
144 | To perform CMG Stage 2 and get the final result on CIFAR10 for Setting 2, run this command:
145 |
146 | ```
147 | python -m eval \
148 | --task 'different_dataset' \
149 | --ood-dataset {'SVHN'/'LSUN'/'LSUN-FIX'/'tinyImageNet'/'ImageNet-FIX'/'CIFAR100'} \
150 | --device {the available GPU in your cluster, e.g., cuda:0} \
151 | --params-dict-name {main model checkpoint name} \
152 | --params-dict-name2 {cvae checkpoint name} \
153 | --mode {'CMG-energy'/'CMG-softmax'}
154 | ```
155 |
156 | ## Apply CMG to CSI
157 |
158 | CMG is a training paradigm orthogonal to existing OOD detection models and can be combined with them to further
159 | improve their performance (please see our paper for details).
160 |
161 | Here we also provide code to reproduce the new SOTA by applying CMG to [CSI](https://github.com/alinlab/CSI).
162 | See [./CSI+CMG](./CSI+CMG).
163 |
164 | ## Acknowledgements
165 |
166 | We thank [CSI](https://github.com/alinlab/CSI) for providing download links for their processed data; our "CSI+CMG"
167 | code is also based on their implementation.
168 |
169 | ## Citation
170 |
171 | Please cite our paper if you use this code or parts of it:
172 | ```
173 | @article{wangcmg,
174 | title={CMG: A Class-Mixed Generation Approach to Out-of-Distribution Detection},
175 | author={Wang, Mengyu and Shao, Yijia and Lin, Haowei and Hu, Wenpeng and Liu, Bing},
176 | year={2022}
177 | }
178 | ```
179 |
--------------------------------------------------------------------------------
/CSI+CMG/CSI_model/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | BasicBlock and Bottleneck modules are from the original ResNet paper:
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | PreActBlock and PreActBottleneck modules are from the later paper:
6 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
8 | '''
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from CSI_model.base_model import BaseModel
13 | from CSI_model.transform_layers import NormalizeLayer
14 |
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, in_planes, planes, stride=1):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(in_planes, planes, stride)
26 | self.conv2 = conv3x3(planes, planes)
27 | self.bn1 = nn.BatchNorm2d(planes)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 |
30 | self.shortcut = nn.Sequential()
31 | if stride != 1 or in_planes != self.expansion * planes:
32 | self.shortcut = nn.Sequential(
33 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
34 | nn.BatchNorm2d(self.expansion * planes)
35 | )
36 |
37 | def forward(self, x):
38 | out = F.relu(self.bn1(self.conv1(x)))
39 | out = self.bn2(self.conv2(out))
40 | out += self.shortcut(x)
41 | out = F.relu(out)
42 | return out
43 |
44 |
45 | class PreActBlock(nn.Module):
46 | '''Pre-activation version of the BasicBlock.'''
47 | expansion = 1
48 |
49 | def __init__(self, in_planes, planes, stride=1):
50 | super(PreActBlock, self).__init__()
51 | self.conv1 = conv3x3(in_planes, planes, stride)
52 | self.conv2 = conv3x3(planes, planes)
53 | self.bn1 = nn.BatchNorm2d(in_planes)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 |
56 | self.shortcut = nn.Sequential()
57 | if stride != 1 or in_planes != self.expansion * planes:
58 | self.shortcut = nn.Sequential(
59 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
60 | )
61 |
62 | def forward(self, x):
63 | out = F.relu(self.bn1(x))
64 | shortcut = self.shortcut(out)
65 | out = self.conv1(out)
66 | out = self.conv2(F.relu(self.bn2(out)))
67 | out += shortcut
68 | return out
69 |
70 |
71 | class Bottleneck(nn.Module):
72 | expansion = 4
73 |
74 | def __init__(self, in_planes, planes, stride=1):
75 | super(Bottleneck, self).__init__()
76 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
77 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
78 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
79 | self.bn1 = nn.BatchNorm2d(planes)
80 | self.bn2 = nn.BatchNorm2d(planes)
81 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
82 |
83 | self.shortcut = nn.Sequential()
84 | if stride != 1 or in_planes != self.expansion * planes:
85 | self.shortcut = nn.Sequential(
86 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
87 | nn.BatchNorm2d(self.expansion * planes)
88 | )
89 |
90 | def forward(self, x):
91 | out = F.relu(self.bn1(self.conv1(x)))
92 | out = F.relu(self.bn2(self.conv2(out)))
93 | out = self.bn3(self.conv3(out))
94 | out += self.shortcut(x)
95 | out = F.relu(out)
96 | return out
97 |
98 |
99 | class PreActBottleneck(nn.Module):
100 | '''Pre-activation version of the original Bottleneck module.'''
101 | expansion = 4
102 |
103 | def __init__(self, in_planes, planes, stride=1):
104 | super(PreActBottleneck, self).__init__()
105 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
107 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(in_planes)
109 | self.bn2 = nn.BatchNorm2d(planes)
110 | self.bn3 = nn.BatchNorm2d(planes)
111 |
112 | self.shortcut = nn.Sequential()
113 | if stride != 1 or in_planes != self.expansion * planes:
114 | self.shortcut = nn.Sequential(
115 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
116 | )
117 |
118 | def forward(self, x):
119 | out = F.relu(self.bn1(x))
120 | shortcut = self.shortcut(out)
121 | out = self.conv1(out)
122 | out = self.conv2(F.relu(self.bn2(out)))
123 | out = self.conv3(F.relu(self.bn3(out)))
124 | out += shortcut
125 | return out
126 |
127 |
128 | class ResNet(BaseModel):
129 | def __init__(self, block, num_blocks, num_classes=10):
130 | last_dim = 512 * block.expansion
131 | super(ResNet, self).__init__(last_dim, num_classes)
132 |
133 | self.in_planes = 64
134 | self.last_dim = last_dim
135 |
136 | self.normalize = NormalizeLayer()
137 |
138 | self.conv1 = conv3x3(3, 64)
139 | self.bn1 = nn.BatchNorm2d(64)
140 |
141 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
142 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
143 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
144 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
145 |
146 | def _make_layer(self, block, planes, num_blocks, stride):
147 | strides = [stride] + [1] * (num_blocks - 1)
148 | layers = []
149 | for stride in strides:
150 | layers.append(block(self.in_planes, planes, stride))
151 | self.in_planes = planes * block.expansion
152 | return nn.Sequential(*layers)
153 |
154 | def penultimate(self, x, all_features=False):
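# Pooled feature vector before the final linear head; with all_features=True, also returns the per-stage feature maps.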
155 | out_list = []
156 |
157 | out = self.normalize(x)
158 | out = self.conv1(out)
159 | out = self.bn1(out)
160 | out = F.relu(out)
161 | out_list.append(out)
162 |
163 | out = self.layer1(out)
164 | out_list.append(out)
165 | out = self.layer2(out)
166 | out_list.append(out)
167 | out = self.layer3(out)
168 | out_list.append(out)
169 | out = self.layer4(out)
170 | out_list.append(out)
171 |
172 | out = F.avg_pool2d(out, 4)
173 | out = out.view(out.size(0), -1)
174 |
175 | if all_features:
176 | return out, out_list
177 | else:
178 | return out
179 |
180 |
181 | def ResNet18(num_classes):
182 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
183 |
184 |
185 | def ResNet34(num_classes):
186 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
187 |
188 |
189 | def ResNet50(num_classes):
190 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
191 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """CMG Stage 2: Fine-tuning the classification head using IND data and pseudo-OOD data generated by the CVAE."""
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torchvision import datasets, transforms
9 |
10 | from datasets.datasetsHelper import get_dataset, get_ood_dataset
11 | from models.main_model import MainModel
12 | from models.vae import ConditionalVAE
13 | from training.fine_tune import fine_tune_same_dataset, fine_tune_different_dataset
14 | from utils import get_args
15 |
16 | args = get_args()
17 |
18 |
19 | def setup_seed(seed):
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 | np.random.seed(seed)
23 | random.seed(seed)
24 | torch.backends.cudnn.deterministic = True
25 |
26 |
27 | setup_seed(args.seed)
28 |
29 | batch_size = 128
30 |
31 | # gpu
32 | device = torch.device(args.device if torch.cuda.is_available() else "cpu")
33 | print(device)
34 |
35 |
36 | def main():
37 | if args.task == 'same_dataset_mnist':
38 | # prepare data
39 | if args.partition == 'partition1':
40 | train_data, test_data_seen, test_data_unseen = get_dataset('mnist', transforms.ToTensor(),
41 | transforms.ToTensor(), seen='012345')
42 | elif args.partition == 'partition2':
43 | train_data, test_data_seen, test_data_unseen = get_dataset('mnist', transforms.ToTensor(),
44 | transforms.ToTensor(), seen='123456')
45 | elif args.partition == 'partition3':
46 | train_data, test_data_seen, test_data_unseen = get_dataset('mnist', transforms.ToTensor(),
47 | transforms.ToTensor(), seen='234567')
48 | elif args.partition == 'partition4':
49 | train_data, test_data_seen, test_data_unseen = get_dataset('mnist', transforms.ToTensor(),
50 | transforms.ToTensor(), seen='345678')
51 | elif args.partition == 'partition5':
52 | train_data, test_data_seen, test_data_unseen = get_dataset('mnist', transforms.ToTensor(),
53 | transforms.ToTensor(), seen='456789')
54 | else:
55 | raise NotImplementedError
56 |
57 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
58 | test_loader_seen = DataLoader(test_data_seen, batch_size=512, num_workers=8)
59 | test_loader_unseen = DataLoader(test_data_unseen, batch_size=512, num_workers=8)
60 |
61 | # prepare model
62 | classifier = MainModel(28, 1, 11, dataset='mnist')
63 | classifier.load_state_dict(torch.load(args.params_dict_name, map_location='cpu'))
64 | classifier.to(device)
65 | vae = ConditionalVAE(image_channels=1, image_size=28, dataset='mnist')
66 | vae.load_state_dict(torch.load(args.params_dict_name2, map_location='cpu'))
67 | vae.to(device)
68 | vae.device = device
69 |
70 | # CMG Stage 2
71 | result = fine_tune_same_dataset(
72 | classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, dataset='mnist',
73 | mode=args.mode)
74 |
75 | print('{}, same dataset mnist {}: max roc auc = {}'.format(args.mode, args.partition, result))
76 |
77 | elif args.task == 'same_dataset_cifar10':
78 | # prepare data
79 | if args.partition == 'partition1':
80 | train_data, test_data_seen, test_data_unseen = get_dataset('cifar10', transforms.ToTensor(),
81 | transforms.ToTensor(), seen='012345')
82 | elif args.partition == 'partition2':
83 | train_data, test_data_seen, test_data_unseen = get_dataset('cifar10', transforms.ToTensor(),
84 | transforms.ToTensor(), seen='123456')
85 | elif args.partition == 'partition3':
86 | train_data, test_data_seen, test_data_unseen = get_dataset('cifar10', transforms.ToTensor(),
87 | transforms.ToTensor(), seen='234567')
88 | elif args.partition == 'partition4':
89 | train_data, test_data_seen, test_data_unseen = get_dataset('cifar10', transforms.ToTensor(),
90 | transforms.ToTensor(), seen='345678')
91 | elif args.partition == 'partition5':
92 | train_data, test_data_seen, test_data_unseen = get_dataset('cifar10', transforms.ToTensor(),
93 | transforms.ToTensor(), seen='456789')
94 | else:
95 | raise NotImplementedError
96 |
97 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
98 | test_loader_seen = DataLoader(test_data_seen, batch_size=512, num_workers=8)
99 | test_loader_unseen = DataLoader(test_data_unseen, batch_size=512, num_workers=8)
100 |
101 | # prepare model
102 | classifier = MainModel(32, 3, 11, dataset='cifar10')
103 | classifier.load_state_dict(torch.load(args.params_dict_name, map_location='cpu'))
104 | classifier.to(device)
105 | vae = ConditionalVAE(image_channels=3, image_size=32, dataset='cifar10')
106 | vae.load_state_dict(torch.load(args.params_dict_name2, map_location='cpu'))
107 | vae.to(device)
108 | vae.device = device
109 |
110 | # CMG Stage 2
111 | result = fine_tune_same_dataset(
112 | classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, dataset='cifar10',
113 | mode=args.mode)
114 |
115 | print('{}, same dataset cifar10 {}: max roc auc = {}'.format(args.mode, args.partition, result))
116 |
117 | elif args.task == 'different_dataset':
118 | # prepare data
119 | root_path = os.path.dirname(__file__)
120 | data_path = os.path.join(root_path, 'datasets/cifar10')
121 | train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms.ToTensor())
122 | test_data_seen = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transforms.ToTensor())
123 | test_data_unseen = get_ood_dataset(args.ood_dataset)
124 |
125 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
126 | test_loader_seen = DataLoader(test_data_seen, batch_size=512, num_workers=8)
127 | test_loader_unseen = DataLoader(test_data_unseen, batch_size=512, num_workers=8)
128 |
129 | # prepare model
130 | classifier = MainModel(32, 3, 110, dataset='cifar10')
131 | classifier.load_state_dict(torch.load(args.params_dict_name, map_location='cpu'))
132 | classifier.to(device)
133 | vae = ConditionalVAE(image_channels=3, image_size=32, dataset='cifar10')
134 | vae.load_state_dict(torch.load(args.params_dict_name2, map_location='cpu'))
135 | vae.to(device)
136 | vae.device = device
137 |
138 | # CMG Stage 2
139 | result = fine_tune_different_dataset(
140 | classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, mode=args.mode)
141 |
142 | print('{}, different dataset, ood dataset {}: max roc auc = {}'.format(args.mode, args.ood_dataset, result))
143 |
144 | else:
145 | raise NotImplementedError
146 |
147 |
148 | if __name__ == '__main__':
149 | main()
150 |
--------------------------------------------------------------------------------
/CSI+CMG/CSI_model/resnet_imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from CSI_model.base_model import BaseModel
5 | from CSI_model.transform_layers import NormalizeLayer
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9 | """3x3 convolution with padding"""
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=dilation, groups=groups, bias=False, dilation=dilation)
12 |
13 |
14 | def conv1x1(in_planes, out_planes, stride=1):
15 | """1x1 convolution"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
23 | base_width=64, dilation=1, norm_layer=None):
24 | super(BasicBlock, self).__init__()
25 | if norm_layer is None:
26 | norm_layer = nn.BatchNorm2d
27 | if groups != 1 or base_width != 64:
28 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
29 | if dilation > 1:
30 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
31 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
32 | self.conv1 = conv3x3(inplanes, planes, stride)
33 | self.bn1 = norm_layer(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.conv2 = conv3x3(planes, planes)
36 | self.bn2 = norm_layer(planes)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | identity = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu(out)
55 |
56 | return out
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
61 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
62 | # according to "Deep residual learning for image recognition", https://arxiv.org/abs/1512.03385.
63 | # This variant is also known as ResNet V1.5 and improves accuracy according to
64 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
65 |
66 | expansion = 4
67 |
68 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
69 | base_width=64, dilation=1, norm_layer=None):
70 | super(Bottleneck, self).__init__()
71 | if norm_layer is None:
72 | norm_layer = nn.BatchNorm2d
73 | width = int(planes * (base_width / 64.)) * groups
74 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
75 | self.conv1 = conv1x1(inplanes, width)
76 | self.bn1 = norm_layer(width)
77 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
78 | self.bn2 = norm_layer(width)
79 | self.conv3 = conv1x1(width, planes * self.expansion)
80 | self.bn3 = norm_layer(planes * self.expansion)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.downsample = downsample
83 | self.stride = stride
84 |
85 | def forward(self, x):
86 | identity = x
87 |
88 | out = self.conv1(x)
89 | out = self.bn1(out)
90 | out = self.relu(out)
91 |
92 | out = self.conv2(out)
93 | out = self.bn2(out)
94 | out = self.relu(out)
95 |
96 | out = self.conv3(out)
97 | out = self.bn3(out)
98 |
99 | if self.downsample is not None:
100 | identity = self.downsample(x)
101 |
102 | out += identity
103 | out = self.relu(out)
104 |
105 | return out
106 |
107 |
108 | class ResNet(BaseModel):
109 | def __init__(self, block, layers, num_classes=10,
110 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
111 | norm_layer=None):
112 | last_dim = 512 * block.expansion
113 | super(ResNet, self).__init__(last_dim, num_classes)
114 | if norm_layer is None:
115 | norm_layer = nn.BatchNorm2d
116 | self._norm_layer = norm_layer
117 |
118 | self.inplanes = 64
119 | self.dilation = 1
120 | if replace_stride_with_dilation is None:
121 | # each element in the tuple indicates if we should replace
122 | # the 2x2 stride with a dilated convolution instead
123 | replace_stride_with_dilation = [False, False, False]
124 | if len(replace_stride_with_dilation) != 3:
125 | raise ValueError("replace_stride_with_dilation should be None "
126 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
127 | self.groups = groups
128 | self.base_width = width_per_group
129 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
130 | bias=False)
131 | self.bn1 = norm_layer(self.inplanes)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
134 | self.layer1 = self._make_layer(block, 64, layers[0])
135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
136 | dilate=replace_stride_with_dilation[0])
137 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
138 | dilate=replace_stride_with_dilation[1])
139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
140 | dilate=replace_stride_with_dilation[2])
141 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
142 | self.normalize = NormalizeLayer()
143 | self.last_dim = 512 * block.expansion
144 |
145 | for m in self.modules():
146 | if isinstance(m, nn.Conv2d):
147 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
148 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
149 | nn.init.constant_(m.weight, 1)
150 | nn.init.constant_(m.bias, 0)
151 |
152 | # Zero-initialize the last BN in each residual branch,
153 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
154 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
155 | if zero_init_residual:
156 | for m in self.modules():
157 | if isinstance(m, Bottleneck):
158 | nn.init.constant_(m.bn3.weight, 0)
159 | elif isinstance(m, BasicBlock):
160 | nn.init.constant_(m.bn2.weight, 0)
161 |
162 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
163 | norm_layer = self._norm_layer
164 | downsample = None
165 | previous_dilation = self.dilation
166 | if dilate:
167 | self.dilation *= stride
168 | stride = 1
169 | if stride != 1 or self.inplanes != planes * block.expansion:
170 | downsample = nn.Sequential(
171 | conv1x1(self.inplanes, planes * block.expansion, stride),
172 | norm_layer(planes * block.expansion),
173 | )
174 |
175 | layers = []
176 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
177 | self.base_width, previous_dilation, norm_layer))
178 | self.inplanes = planes * block.expansion
179 | for _ in range(1, blocks):
180 | layers.append(block(self.inplanes, planes, groups=self.groups,
181 | base_width=self.base_width, dilation=self.dilation,
182 | norm_layer=norm_layer))
183 |
184 | return nn.Sequential(*layers)
185 |
186 | def penultimate(self, x, all_features=False):
187 | # See note [TorchScript super()]
188 | out_list = []
189 |
190 | x = self.normalize(x)
191 | x = self.conv1(x)
192 | x = self.bn1(x)
193 | x = self.relu(x)
194 | x = self.maxpool(x)
195 | out_list.append(x)
196 |
197 | x = self.layer1(x)
198 | out_list.append(x)
199 | x = self.layer2(x)
200 | out_list.append(x)
201 | x = self.layer3(x)
202 | out_list.append(x)
203 | x = self.layer4(x)
204 | out_list.append(x)
205 |
206 | x = self.avgpool(x)
207 | x = torch.flatten(x, 1)
208 |
209 | if all_features:
210 | return x, out_list
211 | else:
212 | return x
213 |
214 |
215 | def _resnet(arch, block, layers, **kwargs):
216 | model = ResNet(block, layers, **kwargs)
217 | return model
218 |
219 |
220 | def resnet18(**kwargs):
221 | r"""ResNet-18 model from
222 | `"Deep Residual Learning for Image Recognition" `_
223 | """
224 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
225 |
226 |
227 | def resnet50(**kwargs):
228 | r"""ResNet-50 model from
229 | `"Deep Residual Learning for Image Recognition" `_
230 | """
231 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)
232 |
--------------------------------------------------------------------------------
/training/fine_tune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from models.utils import loss_functions as lf
5 | from models.utils import score
6 |
7 |
8 | def generate_pseudo_data(vae, class_num, scalar, neg_item_per_batch, device):
9 | """Generates pseudo data for CMG stage 2."""
10 |
11 | # prepare for class embedding
12 | y1 = torch.Tensor(neg_item_per_batch, vae.class_num)
13 | y1.zero_()
14 | y2 = torch.Tensor(neg_item_per_batch, vae.class_num)
15 | y2.zero_()
16 | ind = torch.randint(0, class_num, (neg_item_per_batch, 1))
17 | ind2 = torch.randint(0, class_num, (neg_item_per_batch, 1))
18 | y1.scatter_(1, ind.view(-1, 1), 1)
19 | y2.scatter_(1, ind2.view(-1, 1), 1)
20 | y1 = y1.to(device)
21 | y2 = y2.to(device)
22 | class_embed1 = vae.class_embed(y1)
23 | class_embed2 = vae.class_embed(y2)
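# elementwise 0/1 mask: each embedding dimension is taken from one of the two sampled classes, giving a class-mixed embedding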
24 | rv = torch.randint(0, 2, class_embed1.shape).to(device)
25 | class_embed = torch.where(rv == 0, class_embed1, class_embed2)
26 |
27 | # sample in N(0, sigma^2)
28 | random_z = torch.randn(neg_item_per_batch, vae.z_dim).to(device) * scalar
29 |
30 | x_generate = vae.decode(random_z, class_embed)
31 |
32 | return x_generate
33 |
34 |
35 | def get_m(train_loader, classifier, vae, class_num, scalar, device):
36 | """Gets hyper-parameters for CMG-energy using training data."""
37 |
38 | print("=====get_m======")
39 | with torch.no_grad():
40 | Ec_out, Ec_in = None, None
41 | for data, target in train_loader:
42 | classifier.eval()
43 | classifier.classifier.train()
44 |
45 | data = data.to(device)
46 | prediction = classifier(data)
47 |
48 | x_generate = generate_pseudo_data(vae, class_num, scalar, 128, device)
49 |
50 | prediction_generate = classifier(x_generate)
51 |
52 | # calculate energy of the training data and generated negative data
53 | T = 1
54 | if Ec_in is None and Ec_out is None:
55 | Ec_in = -T * torch.logsumexp(prediction / T, dim=1)
56 | Ec_out = -T * torch.logsumexp(prediction_generate / T, dim=1)
57 | else:
58 | Ec_in = torch.cat((Ec_in, (-T * torch.logsumexp(prediction / T, dim=1))), dim=0)
59 | Ec_out = torch.cat((Ec_out, (-T * torch.logsumexp(prediction_generate / T, dim=1))), dim=0)
60 |
61 | Ec_in = Ec_in.sort()[0]
62 | Ec_out = Ec_out.sort()[0]
63 | in_size = Ec_in.size(0)
64 | out_size = Ec_out.size(0)
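# margins for the squared-hinge energy loss: m_in is the 20th-percentile energy of IND data, m_out the 80th-percentile energy of pseudo-OOD data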
65 | m_in, m_out = Ec_in[int(in_size * 0.2)], Ec_out[int(out_size * 0.8)]
66 | print("m_in = ", m_in, ",m_out=", m_out)
67 |
68 | return m_in, m_out
69 |
70 |
71 | def fine_tune_same_dataset(
72 | classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, dataset='mnist',
73 | mode='CMG-energy'):
74 | """
75 | Performs CMG Stage 2 by fine-tuning the OOD detector using generated pseudo data on Setting 1.
76 | """
77 | for p in classifier.convE.parameters():
78 | p.requires_grad = False
79 |
80 | vae.eval()
81 | for p in vae.parameters():
82 | p.requires_grad = False
83 |
84 | n_epochs = 15
85 | LR = 0.0001
86 | if dataset == 'mnist':
87 | scalar = 3.0
88 | elif dataset == 'cifar10':
89 | scalar = 5.0
90 | else:
91 | raise NotImplementedError
92 | neg_item_per_batch = 128
93 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=LR)
94 |
95 | if mode == 'CMG-energy':
96 | # define hyper-parameters for CMG-energy
97 | mu = 0.1
98 | m_in, m_out = get_m(train_loader, classifier, vae, 6, scalar, device)
99 |
100 | index = -1
101 | max_rocauc = 0
102 | for epoch in range(n_epochs):
103 | for data, target in train_loader:
104 |
105 | classifier.eval()
106 | classifier.classifier.train()
107 | optimizer.zero_grad()
108 |
109 | index += 1
110 | data = data.to(device)
111 | target = target.long().to(device)
112 | prediction = classifier(data)
113 | x_generate = generate_pseudo_data(vae, 6, scalar, neg_item_per_batch, device)
114 | prediction_generate = classifier(x_generate)
115 |
116 | if mode == 'CMG-softmax':
117 | loss_input = F.cross_entropy(prediction, target, reduction='none')
118 | loss_input = lf.weighted_average(loss_input, weights=None, dim=0)
119 |
120 | y_generate = torch.ones(neg_item_per_batch) * 6
121 | y_generate = y_generate.long().to(device)
122 |
123 | loss_generate = F.cross_entropy(prediction_generate[:, 0:7], y_generate)
124 |
125 | loss = loss_generate + loss_input
126 |
127 | loss.backward()
128 | optimizer.step()
129 |
130 | elif mode == 'CMG-energy':
131 | loss_ce = F.cross_entropy(prediction, target)
132 |
133 | Ec_in = - torch.logsumexp(prediction, dim=1)
134 | Ec_out = - torch.logsumexp(prediction_generate, dim=1)
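# squared hinge on energies: penalize IND samples whose energy exceeds m_in and pseudo-OOD samples whose energy falls below m_out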
135 | loss_energy = torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean()
136 |
137 | loss = loss_ce + mu * loss_energy
138 |
139 | loss.backward()
140 | optimizer.step()
141 |
142 | else:
143 | raise NotImplementedError
144 |
145 | # evaluate
146 | if index % 20 == 0:
147 | classifier.eval()
148 | with torch.no_grad():
149 | fx = []
150 | labels = []
151 |
152 | for x, y in test_loader_seen:
153 | x = x.to(device)
154 | y = y.long().to(device)
155 | output = classifier(x)
156 |
157 | fx.append(output[:, 0:7])
158 | labels.append(y)
159 |
160 | for x, _ in test_loader_unseen:
161 | x = x.to(device)
162 | output = classifier(x)
163 |
164 | fx.append(output[:, 0:7])
165 | labels.append(-torch.ones(x.size(0)).long().to(device))
166 |
167 | labels = torch.cat(labels, 0)
168 | fx = torch.cat(fx, 0)
169 |
170 | if mode == 'CMG-softmax':
171 | roc_auc = score.softmax_result(fx, labels)
172 | elif mode == 'CMG-energy':
173 | roc_auc = score.energy_result(fx, labels)
174 | else:
175 | raise NotImplementedError
176 |
177 | if roc_auc > max_rocauc:
178 | max_rocauc = roc_auc
179 |
180 | if mode == 'CMG-softmax':
181 | print('Epoch:', epoch, 'Index:', index, 'loss input', loss_input.item(),
182 | 'loss neg', loss_generate.item())
183 | elif mode == 'CMG-energy':
184 | print('Epoch:', epoch, 'Index:', index, 'loss ce', loss_ce.item(),
185 | 'loss energy', loss_energy.item())
186 | else:
187 | raise NotImplementedError
188 |
189 | print('max roc auc:', max_rocauc)
190 |
191 | return max_rocauc
192 |
193 |
194 | def fine_tune_different_dataset(
195 | classifier, vae, train_loader, test_loader_seen, test_loader_unseen, device, mode='CMG-energy'):
196 | """
197 | Performs CMG Stage 2 by fine-tuning the OOD detector using generated pseudo data on Setting 2.
198 | """
199 | for p in classifier.convE.parameters():
200 | p.requires_grad = False
201 |
202 | vae.eval()
203 | for p in vae.parameters():
204 | p.requires_grad = False
205 |
206 | n_epochs = 15
207 | LR = 0.0001
208 | scalar = 5.0
209 | neg_item_per_batch = 128
210 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=LR)
211 |
212 | if mode == 'CMG-energy':
213 | # define hyper-parameters for CMG-energy
214 | mu = 0.1
215 | m_in, m_out = get_m(train_loader, classifier, vae, 10, scalar, device)
216 |
217 | index = -1
218 | max_rocauc = 0
219 | for epoch in range(n_epochs):
220 | for data, target in train_loader:
221 |
222 | classifier.eval()
223 | classifier.classifier.train()
224 | optimizer.zero_grad()
225 |
226 | index += 1
227 | data = data.to(device)
228 | target = target.long().to(device)
229 | prediction = classifier(data)
230 | x_generate = generate_pseudo_data(vae, 10, scalar, neg_item_per_batch, device)
231 | prediction_generate = classifier(x_generate)
232 |
233 | if mode == 'CMG-softmax':
234 | loss_input = F.cross_entropy(prediction, target, reduction='none')
235 | loss_input = lf.weighted_average(loss_input, weights=None, dim=0)
236 |
237 | y_generate = torch.ones(neg_item_per_batch) * 10
238 | y_generate = y_generate.long().to(device)
239 |
240 | loss_generate = F.cross_entropy(prediction_generate[:, 0:11], y_generate)
241 |
242 | loss = loss_generate + loss_input
243 |
244 | loss.backward()
245 | optimizer.step()
246 |
247 | elif mode == 'CMG-energy':
248 | loss_ce = F.cross_entropy(prediction, target)
249 |
250 | Ec_in = - torch.logsumexp(prediction, dim=1)
251 | Ec_out = - torch.logsumexp(prediction_generate, dim=1)
252 | loss_energy = torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean()
253 |
254 | loss = loss_ce + mu * loss_energy
255 |
256 | loss.backward()
257 | optimizer.step()
258 |
259 | else:
260 | raise NotImplementedError
261 |
262 | # evaluate
263 | if index % 100 == 0:
264 | classifier.eval()
265 | with torch.no_grad():
266 | fx = []
267 | labels = []
268 |
269 | for x, y in test_loader_seen:
270 | x = x.to(device)
271 | y = y.long().to(device)
272 | output = classifier(x)
273 |
274 | fx.append(output[:, 0:11])
275 | labels.append(y)
276 |
277 | for x, _ in test_loader_unseen:
278 | x = x.to(device)
279 | output = classifier(x)
280 |
281 | fx.append(output[:, 0:11])
282 | labels.append(-torch.ones(x.size(0)).long().to(device))
283 |
284 | labels = torch.cat(labels, 0)
285 | fx = torch.cat(fx, 0)
286 |
287 | if mode == 'CMG-softmax':
288 | roc_auc = score.softmax_result(fx, labels)
289 | elif mode == 'CMG-energy':
290 | roc_auc = score.energy_result(fx, labels)
291 | else:
292 | raise NotImplementedError
293 |
294 | if roc_auc > max_rocauc:
295 | max_rocauc = roc_auc
296 |
297 | if mode == 'CMG-softmax':
298 | print('Epoch:', epoch, 'Index:', index, 'loss input', loss_input.item(),
299 | 'loss neg', loss_generate.item())
300 | elif mode == 'CMG-energy':
301 | print('Epoch:', epoch, 'Index:', index, 'loss ce', loss_ce.item(),
302 | 'loss energy', loss_energy.item())
303 | else:
304 | raise NotImplementedError
305 |
306 | print('max roc auc:', max_rocauc)
307 |
308 | return max_rocauc
309 |
--------------------------------------------------------------------------------
/CSI+CMG/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from sklearn.metrics import roc_auc_score
8 | from torch.utils.data import DataLoader
9 | from torchvision import datasets, transforms
10 |
11 | import CSI_model.classifier as C
12 | from models.vae import ConditionalVAE2
13 | from utils import get_args
14 |
15 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
16 |
17 |
18 | def energy_result(fx, y):
19 | """Calculates roc_auc using energy score.
20 |
21 | Args:
22 | fx: Last layer output of the model.
23 | y: Class Label, assumes the label of unseen data to be -1.
24 | Returns:
25 | roc_auc: Unseen data as positive, seen data as negative.
26 | """
27 | energy_score = - torch.logsumexp(fx, dim=1)
28 | rocauc = roc_auc_score((y == -1).cpu().detach().numpy(), energy_score.cpu().detach().numpy())
29 |
30 | return rocauc
31 |
32 |
33 | # Set up seed -----------------------------------------------------------------
34 | def setup_seed(seed):
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed_all(seed)
37 | np.random.seed(seed)
38 | random.seed(seed)
39 | torch.backends.cudnn.deterministic = True
40 |
41 |
42 | setup_seed(222)
43 |
44 | # Define hyper parameters and model -------------------------------------------
45 | args = get_args()
46 | batch_size = 128
47 | n_epochs = 15
48 | LR = 0.0001
49 | device = torch.device(args.device if torch.cuda.is_available() else "cpu")
50 | print(device)
51 |
52 | classifier = C.get_classifier('resnet18', n_classes=10).to(device)
53 | checkpoint = torch.load(args.params_dict_name)
54 | classifier.load_state_dict(checkpoint, strict=False)
55 |
56 | # freeze the encoder part
57 | for p in classifier.layer1.parameters():
58 | p.requires_grad = False
59 | for p in classifier.layer2.parameters():
60 | p.requires_grad = False
61 | for p in classifier.layer3.parameters():
62 | p.requires_grad = False
63 | for p in classifier.layer4.parameters():
64 | p.requires_grad = False
65 | for p in classifier.conv1.parameters():
66 | p.requires_grad = False
67 |
68 | # load the CVAE model
69 | vae = ConditionalVAE2()
70 | vae.load_state_dict(torch.load(args.params_dict_name2, map_location='cpu'))
71 | vae.to(device)
72 | vae.eval()
73 | for p in vae.parameters():
74 | p.requires_grad = False
75 |
76 |
77 | def generate_pseudo_data(vae):
78 | scalar = 5.0
79 | neg_item_per_batch = 128
80 |
81 | # prepare for class embedding
82 | y1 = torch.Tensor(neg_item_per_batch, vae.class_num)
83 | y1.zero_()
84 | y2 = torch.Tensor(neg_item_per_batch, vae.class_num)
85 | y2.zero_()
86 | ind = torch.randint(0, 10, (neg_item_per_batch, 1))
87 | ind2 = torch.randint(0, 10, (neg_item_per_batch, 1))
88 | y1.scatter_(1, ind.view(-1, 1), 1)
89 | y2.scatter_(1, ind2.view(-1, 1), 1)
90 | y1 = y1.to(device)
91 | y2 = y2.to(device)
92 | class_embed1 = vae.class_embed(y1)
93 | class_embed2 = vae.class_embed(y2)
94 | rv = torch.randint(0, 2, class_embed1.shape).to(device)
95 | class_embed = torch.where(rv == 0, class_embed1, class_embed2)
96 |
97 | # sample in N(0, sigma^2)
98 | random_z = torch.randn(neg_item_per_batch, vae.z_dim).to(device) * scalar
99 |
100 | x_generate = vae.decode(random_z, class_embed)
101 |
102 | return x_generate
103 |
104 |
105 | def get_m(train_loader, classifier, vae):
106 | print("=====get_m======")
107 | with torch.no_grad():
108 | Ec_out, Ec_in = None, None
109 | for data, target in train_loader:
110 | classifier.eval()
111 | classifier.linear.train()
112 |
113 | data = data.to(device)
114 | prediction = classifier(data)
115 |
116 | x_generate = generate_pseudo_data(vae)
117 |
118 | prediction_generate = classifier(x_generate)
119 |
120 | # calculate energy of the training data and generated negative data
121 | T = 1
122 | if Ec_in is None and Ec_out is None:
123 | Ec_in = -T * torch.logsumexp(prediction / T, dim=1)
124 | Ec_out = -T * torch.logsumexp(prediction_generate / T, dim=1)
125 | else:
126 | Ec_in = torch.cat((Ec_in, (-T * torch.logsumexp(prediction / T, dim=1))), dim=0)
127 | Ec_out = torch.cat((Ec_out, (-T * torch.logsumexp(prediction_generate / T, dim=1))), dim=0)
128 | Ec_in = Ec_in.sort()[0]
129 | Ec_out = Ec_out.sort()[0]
130 | in_size = Ec_in.size(0)
131 | out_size = Ec_out.size(0)
132 | m_in, m_out = Ec_in[int(in_size * 0.2)], Ec_out[int(out_size * 0.8)]
133 | print("m_in = ", m_in, ",m_out=", m_out)
134 | return m_in, m_out
135 |
136 |
137 | def tune_main_model():
138 | """
139 | seen: cifar10
140 | unseen: SVHN / LSUN / ImageNet / LSUN(FIX) / ImageNet(FIX) / CIFAR100.
141 | """
142 |
143 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr=LR)
144 |
145 | # prepare data ------------------------------------------------------------
146 | transform_train = transforms.ToTensor()
147 |
148 | train_data = datasets.CIFAR10(
149 | root='./data/cifar10', train=True, download=True,
150 | transform=transform_train)
151 | test_data = datasets.CIFAR10(
152 | root='./data/cifar10', train=False, download=True,
153 | transform=transforms.ToTensor())
154 |
155 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
156 | test_loader_seen = DataLoader(test_data, batch_size=batch_size, num_workers=4)
157 | test_dir_svhn = os.path.join('./data', 'svhn')
158 | svhn_data = datasets.SVHN(
159 | root=test_dir_svhn, split='test', download=True,
160 | transform=transforms.ToTensor())
161 | test_loader_svhn = DataLoader(svhn_data, batch_size=512, num_workers=4)
162 | test_dir_lsun = os.path.join('./data', 'LSUN_resize')
163 | lsun_data = datasets.ImageFolder(test_dir_lsun, transform=transforms.ToTensor())
164 | test_loader_lsun = DataLoader(lsun_data, batch_size=512, num_workers=4)
165 | test_dir_imagenet = os.path.join('./data', 'Imagenet_resize')
166 | imagenet_data = datasets.ImageFolder(test_dir_imagenet, transform=transforms.ToTensor())
167 | test_loader_imagenet = DataLoader(imagenet_data, batch_size=512, num_workers=4)
168 | test_dir_lsun_fix = os.path.join('./data', 'LSUN_fix')
169 | lsun_fix_data = datasets.ImageFolder(test_dir_lsun_fix, transform=transforms.ToTensor())
170 | test_loader_lsun_fix = DataLoader(lsun_fix_data, batch_size=512, num_workers=4)
171 | test_dir_imagenet_fix = os.path.join('./data', 'Imagenet_fix')
172 | imagenet_data = datasets.ImageFolder(test_dir_imagenet_fix, transform=transforms.ToTensor())
173 | test_loader_imagenet_fix = DataLoader(imagenet_data, batch_size=512, num_workers=4)
174 | test_dir_cifar100 = os.path.join('./data', 'cifar100')
175 | cifar100_data = datasets.CIFAR100(
176 | root=test_dir_cifar100, train=False, transform=transforms.ToTensor(), download=True)
177 | test_loader_cifar100 = DataLoader(cifar100_data, batch_size=512, num_workers=4)
178 |
179 | # hyper-parameters for the energy loss: mu weights the energy term; m_in / m_out are margins estimated by get_m
180 | mu = 0.1
181 | m_in, m_out = get_m(train_loader, classifier, vae)
182 |
183 | # fine-tuning the classification head ----------------------------------------------------
184 | index = -1
185 | max_roc_auc = {'svhn': 0, 'lsun': 0, 'imagenet': 0, 'lsun_fix': 0, 'imagenet_fix': 0, 'cifar100': 0}
186 |
187 | for epoch in range(n_epochs):
188 |
189 | for data, target in train_loader:
190 |
191 | classifier.eval()
192 | classifier.linear.train()
193 | optimizer.zero_grad()
194 | index += 1
195 |
196 | data = data.to(device)
197 | target = target.long().to(device)
198 | prediction = classifier(data)
199 | loss_ce = F.cross_entropy(prediction, target)
200 | Ec_in = -torch.logsumexp(prediction, dim=1)
201 |
202 | x_generate = generate_pseudo_data(vae)
203 |
204 | prediction_generate = classifier(x_generate)[:, 0:10]
205 |
206 | Ec_out = -torch.logsumexp(prediction_generate, dim=1)
207 |
208 | # squared-hinge energy loss with margins m_in / m_out
209 | loss_energy = torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean()
210 |
211 | loss = loss_ce + mu * loss_energy
212 |
213 | loss.backward()
214 | optimizer.step()
215 |
216 | # evaluate (every 100 batches) ------------------------------------
217 | if index % 100 == 0:
218 | classifier.eval()
219 | with torch.no_grad():
220 |
221 | output_ind = []
222 | output_svhn, output_lsun, output_imagenet, output_lsun_fix, output_imagenet_fix, output_cifar100 = \
223 |     [], [], [], [], [], []
224 |
225 | for x, _ in test_loader_seen:
226 | x = x.to(device)
227 | output = classifier(x)
228 | output_ind.append(output)
229 |
230 | for x, _ in test_loader_svhn:
231 | x = x.to(device)
232 | output = classifier(x)
233 | output_svhn.append(output)
234 |
235 | for x, _ in test_loader_lsun:
236 | x = x.to(device)
237 | output = classifier(x)
238 | output_lsun.append(output)
239 |
240 | for x, _ in test_loader_imagenet:
241 | x = x.to(device)
242 | output = classifier(x)
243 | output_imagenet.append(output)
244 |
245 | for x, _ in test_loader_lsun_fix:
246 | x = x.to(device)
247 | output = classifier(x)
248 | output_lsun_fix.append(output)
249 |
250 | for x, _ in test_loader_imagenet_fix:
251 | x = x.to(device)
252 | output = classifier(x)
253 | output_imagenet_fix.append(output)
254 |
255 | for x, _ in test_loader_cifar100:
256 | x = x.to(device)
257 | output = classifier(x)
258 | output_cifar100.append(output)
259 |
260 | output_ind = torch.cat(output_ind, 0)
261 | output_svhn = torch.cat(output_svhn, 0)
262 | output_lsun = torch.cat(output_lsun, 0)
263 | output_imagenet = torch.cat(output_imagenet, 0)
264 | output_lsun_fix = torch.cat(output_lsun_fix, 0)
265 | output_imagenet_fix = torch.cat(output_imagenet_fix, 0)
266 | output_cifar100 = torch.cat(output_cifar100, 0)
267 |
268 | roc_auc_svhn = energy_result(torch.cat([output_ind, output_svhn]), torch.cat(
269 | [torch.ones(output_ind.size(0)), -torch.ones(output_svhn.size(0))]).long().to(device))
270 | roc_auc_lsun = energy_result(torch.cat([output_ind, output_lsun]), torch.cat(
271 | [torch.ones(output_ind.size(0)), -torch.ones(output_lsun.size(0))]).long().to(device))
272 | roc_auc_imagenet = energy_result(torch.cat([output_ind, output_imagenet]), torch.cat(
273 | [torch.ones(output_ind.size(0)), -torch.ones(output_imagenet.size(0))]).long().to(device))
274 | roc_auc_lsun_fix = energy_result(torch.cat([output_ind, output_lsun_fix]), torch.cat(
275 | [torch.ones(output_ind.size(0)), -torch.ones(output_lsun_fix.size(0))]).long().to(device))
276 | roc_auc_imagenet_fix = energy_result(torch.cat([output_ind, output_imagenet_fix]), torch.cat(
277 | [torch.ones(output_ind.size(0)), -torch.ones(output_imagenet_fix.size(0))]).long().to(device))
278 | roc_auc_cifar100 = energy_result(torch.cat([output_ind, output_cifar100]), torch.cat(
279 | [torch.ones(output_ind.size(0)), -torch.ones(output_cifar100.size(0))]).long().to(device))
280 |
281 | max_roc_auc['svhn'] = max(max_roc_auc['svhn'], roc_auc_svhn)
282 | max_roc_auc['lsun'] = max(max_roc_auc['lsun'], roc_auc_lsun)
283 | max_roc_auc['imagenet'] = max(max_roc_auc['imagenet'], roc_auc_imagenet)
284 | max_roc_auc['lsun_fix'] = max(max_roc_auc['lsun_fix'], roc_auc_lsun_fix)
285 | max_roc_auc['imagenet_fix'] = max(max_roc_auc['imagenet_fix'], roc_auc_imagenet_fix)
286 | max_roc_auc['cifar100'] = max(max_roc_auc['cifar100'], roc_auc_cifar100)
287 |
288 | print('Epoch: {} Index: {}'.format(epoch, index))
289 | print('Max rocauc result')
290 | print(max_roc_auc)
291 |
292 |
293 | if __name__ == '__main__':
294 | tune_main_model()
295 |
--------------------------------------------------------------------------------
/CSI+CMG/CSI_model/transform_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Function
9 |
10 | if tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) >= (1, 4):  # numeric, not lexicographic
11 | kwargs = {'align_corners': False}
12 | else:
13 | kwargs = {}
14 |
15 |
16 | def rgb2hsv(rgb):
17 | """Convert a 4-d RGB tensor to the HSV counterpart.
18 |
19 | Here, we compute hue using atan2() based on the definition in [1],
20 | instead of using the common lookup table approach as in [2, 3].
21 | Those values agree when the angle is a multiple of 30°,
22 | otherwise they may differ by at most ~1.2°.
23 |
24 | References
25 | [1] https://en.wikipedia.org/wiki/Hue
26 | [2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html
27 | [3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212
28 | """
29 |
30 | r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
31 |
32 | Cmax = rgb.max(1)[0]
33 | Cmin = rgb.min(1)[0]
34 | delta = Cmax - Cmin
35 |
36 | hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b)
37 | hue = (hue % (2 * math.pi)) / (2 * math.pi)
38 | saturate = delta / Cmax
39 | value = Cmax
40 | hsv = torch.stack([hue, saturate, value], dim=1)
41 | hsv[~torch.isfinite(hsv)] = 0.
42 | return hsv
43 |
44 |
45 | def hsv2rgb(hsv):
46 | """Convert a 4-d HSV tensor to the RGB counterpart.
47 |
48 | >>> %timeit hsv2rgb(hsv)
49 | 2.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
50 | >>> %timeit rgb2hsv_fast(rgb)
51 | 298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
52 | >>> torch.allclose(hsv2rgb(hsv), hsv2rgb_fast(hsv), atol=1e-6)
53 | True
54 |
55 | References
56 | [1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative
57 | """
58 | h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]]
59 | c = v * s
60 |
61 | n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
62 | k = (n + h * 6) % 6
63 | t = torch.min(k, 4 - k)
64 | t = torch.clamp(t, 0, 1)
65 |
66 | return v - c * t
67 |
68 |
69 | class RandomResizedCropLayer(nn.Module):
70 | def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
71 | '''
72 | Inception Crop
73 | size (tuple): size of the forwarded image (C, W, H)
74 | scale (tuple): range of size of the origin size cropped
75 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
76 | '''
77 | super(RandomResizedCropLayer, self).__init__()
78 |
79 | _eye = torch.eye(2, 3)
80 | self.size = size
81 | self.register_buffer('_eye', _eye)
82 | self.scale = scale
83 | self.ratio = ratio
84 |
85 | def forward(self, inputs, whbias=None):
86 | _device = inputs.device
87 | N = inputs.size(0)
88 | _theta = self._eye.repeat(N, 1, 1)
89 |
90 | if whbias is None:
91 | whbias = self._sample_latent(inputs)
92 |
93 | _theta[:, 0, 0] = whbias[:, 0]
94 | _theta[:, 1, 1] = whbias[:, 1]
95 | _theta[:, 0, 2] = whbias[:, 2]
96 | _theta[:, 1, 2] = whbias[:, 3]
97 |
98 | grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
99 | output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
100 |
101 | if self.size is not None:
102 | output = F.adaptive_avg_pool2d(output, self.size)
103 |
104 | return output
105 |
106 | def _clamp(self, whbias):
107 |
108 | w = whbias[:, 0]
109 | h = whbias[:, 1]
110 | w_bias = whbias[:, 2]
111 | h_bias = whbias[:, 3]
112 |
113 | # Clamp with scale
114 | w = torch.clamp(w, *self.scale)
115 | h = torch.clamp(h, *self.scale)
116 |
117 | # Clamp with ratio
118 | w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h)
119 | w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w)
120 |
121 | # Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h)
122 | w_bias = w - 1 + torch.relu(w_bias - w + 1)
123 | w_bias = 1 - w - torch.relu(1 - w - w_bias)
124 |
125 | h_bias = h - 1 + torch.relu(h_bias - h + 1)
126 | h_bias = 1 - h - torch.relu(1 - h - h_bias)
127 |
128 | whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t()
129 |
130 | return whbias
131 |
132 | def _sample_latent(self, inputs):
133 |
134 | _device = inputs.device
135 | N, _, width, height = inputs.shape
136 |
137 | # N * 10 trial
138 | area = width * height
139 | target_area = np.random.uniform(*self.scale, N * 10) * area
140 | log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
141 | aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))
142 |
143 | # If doesn't satisfy ratio condition, then do central crop
144 | w = np.round(np.sqrt(target_area * aspect_ratio))
145 | h = np.round(np.sqrt(target_area / aspect_ratio))
146 | cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
147 | w = w[cond]
148 | h = h[cond]
149 | cond_len = w.shape[0]
150 | if cond_len >= N:
151 | w = w[:N]
152 | h = h[:N]
153 | else:
154 | w = np.concatenate([w, np.ones(N - cond_len) * width])
155 | h = np.concatenate([h, np.ones(N - cond_len) * height])
156 |
157 | w_bias = np.random.randint(w - width, width - w + 1) / width
158 | h_bias = np.random.randint(h - height, height - h + 1) / height
159 | w = w / width
160 | h = h / height
161 |
162 | whbias = np.column_stack([w, h, w_bias, h_bias])
163 | whbias = torch.tensor(whbias, device=_device)
164 |
165 | return whbias
166 |
167 |
168 | class HorizontalFlipRandomCrop(nn.Module):
169 | def __init__(self, max_range):
170 | super(HorizontalFlipRandomCrop, self).__init__()
171 | self.max_range = max_range
172 | _eye = torch.eye(2, 3)
173 | self.register_buffer('_eye', _eye)
174 |
175 | def forward(self, input, sign=None, bias=None, rotation=None):
176 | _device = input.device
177 | N = input.size(0)
178 | _theta = self._eye.repeat(N, 1, 1)
179 |
180 | if sign is None:
181 | sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
182 | if bias is None:
183 | bias = torch.empty((N, 2), device=_device).uniform_(-self.max_range, self.max_range)
184 | _theta[:, 0, 0] = sign
185 | _theta[:, :, 2] = bias
186 |
187 | if rotation is not None:
188 | _theta[:, 0:2, 0:2] = rotation
189 |
190 | grid = F.affine_grid(_theta, input.size(), **kwargs).to(_device)
191 | output = F.grid_sample(input, grid, padding_mode='reflection', **kwargs)
192 |
193 | return output
194 |
195 | def _sample_latent(self, N, device=None):
196 | sign = torch.bernoulli(torch.ones(N, device=device) * 0.5) * 2 - 1
197 | bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range)
198 | return sign, bias
199 |
200 |
201 | class Rotation(nn.Module):
202 | def __init__(self, max_range=4):
203 | super(Rotation, self).__init__()
204 | self.max_range = max_range
205 | self.prob = 0.5
206 |
207 | def forward(self, input, aug_index=None):
208 | _device = input.device
209 |
210 | _, _, H, W = input.size()
211 |
212 | if aug_index is None:
213 | aug_index = np.random.randint(4)
214 |
215 | output = torch.rot90(input, aug_index, (2, 3))
216 |
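# keep the original image with probability self.prob; otherwise use the rotated copy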
217 | _prob = input.new_full((input.size(0),), self.prob)
218 | _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
219 | output = _mask * input + (1 - _mask) * output
220 |
221 | else:
222 | aug_index = aug_index % self.max_range
223 | output = torch.rot90(input, aug_index, (2, 3))
224 |
225 | return output
226 |
227 |
228 | class CutPerm(nn.Module):
229 | def __init__(self, max_range=4):
230 | super(CutPerm, self).__init__()
231 | self.max_range = max_range
232 | self.prob = 0.5
233 |
234 | def forward(self, input, aug_index=None):
235 | _device = input.device
236 |
237 | _, _, H, W = input.size()
238 |
239 | if aug_index is None:
240 | aug_index = np.random.randint(4)
241 |
242 | output = self._cutperm(input, aug_index)
243 |
244 | _prob = input.new_full((input.size(0),), self.prob)
245 | _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
246 | output = _mask * input + (1 - _mask) * output
247 |
248 | else:
249 | aug_index = aug_index % self.max_range
250 | output = self._cutperm(input, aug_index)
251 |
252 | return output
253 |
254 | def _cutperm(self, inputs, aug_index):
255 |
256 | _, _, H, W = inputs.size()
257 | h_mid = int(H / 2)
258 | w_mid = int(W / 2)
259 |
260 | jigsaw_h = aug_index // 2
261 | jigsaw_v = aug_index % 2
262 |
263 | if jigsaw_h == 1:
264 | inputs = torch.cat((inputs[:, :, h_mid:, :], inputs[:, :, 0:h_mid, :]), dim=2)
265 | if jigsaw_v == 1:
266 | inputs = torch.cat((inputs[:, :, :, w_mid:], inputs[:, :, :, 0:w_mid]), dim=3)
267 |
268 | return inputs
269 |
270 |
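# --- Usage sketch (hypothetical, not part of the original file). _cutperm is a
# half-swap shuffle: bit 1 of aug_index swaps the top/bottom halves and bit 0
# swaps the left/right halves, so aug_index=0 is the identity and each
# permutation is self-inverse when H and W are even.
def _demo_cutperm():
    import torch
    layer = CutPerm()
    x = torch.arange(16.).view(1, 1, 4, 4)
    y = layer(x, aug_index=3)                     # swap both halves
    assert torch.equal(layer(y, aug_index=3), x)  # applying it twice restores x
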
271 | class HorizontalFlipLayer(nn.Module):
272 | def __init__(self):
273 |         """
274 |         Randomly flips each image in a batch horizontally with probability
275 |         0.5. The flip is expressed as a 2x3 affine matrix whose x-scale is
276 |         +1 or -1 per sample, so the layer is differentiable and operates
277 |         directly on batched tensors on any device.
278 |         """
279 | super(HorizontalFlipLayer, self).__init__()
280 |
281 | _eye = torch.eye(2, 3)
282 | self.register_buffer('_eye', _eye)
283 |
284 | def forward(self, inputs):
285 | _device = inputs.device
286 |
287 | N = inputs.size(0)
288 | _theta = self._eye.repeat(N, 1, 1)
289 | r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
290 | _theta[:, 0, 0] = r_sign
291 | grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device)
292 | inputs = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
293 |
294 | return inputs
295 |
296 |
297 | class RandomColorGrayLayer(nn.Module):
298 | def __init__(self, p):
299 | super(RandomColorGrayLayer, self).__init__()
300 | self.prob = p
301 |
302 | _weight = torch.tensor([[0.299, 0.587, 0.114]])
303 | self.register_buffer('_weight', _weight.view(1, 3, 1, 1))
304 |
305 | def forward(self, inputs, aug_index=None):
306 |
307 | if aug_index == 0:
308 | return inputs
309 |
310 |         l = F.conv2d(inputs, self._weight)  # Rec. 601 luma: 0.299 R + 0.587 G + 0.114 B
311 | gray = torch.cat([l, l, l], dim=1)
312 |
313 | if aug_index is None:
314 | _prob = inputs.new_full((inputs.size(0),), self.prob)
315 | _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
316 |
317 | gray = inputs * (1 - _mask) + gray * _mask
318 |
319 | return gray
320 |
321 |
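# --- Usage sketch (hypothetical, not part of the original file). The 1x1
# convolution projects RGB onto Rec. 601 luma and repeats it over the three
# channels. aug_index=0 is a no-op, aug_index=None grayscales each sample
# independently with probability p, and any other index grayscales the batch.
def _demo_color_gray():
    import torch
    layer = RandomColorGrayLayer(p=0.2)
    x = torch.rand(4, 3, 32, 32)
    gray = layer(x, aug_index=1)
    assert torch.allclose(gray[:, 0], gray[:, 1])  # identical luma in every channel
    assert torch.allclose(gray[:, 1], gray[:, 2])
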
322 | class ColorJitterLayer(nn.Module):
323 | def __init__(self, p, brightness, contrast, saturation, hue):
324 | super(ColorJitterLayer, self).__init__()
325 | self.prob = p
326 | self.brightness = self._check_input(brightness, 'brightness')
327 | self.contrast = self._check_input(contrast, 'contrast')
328 | self.saturation = self._check_input(saturation, 'saturation')
329 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
330 | clip_first_on_zero=False)
331 |
332 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
333 | if isinstance(value, numbers.Number):
334 | if value < 0:
335 |                 raise ValueError("If {} is a single number, it must be non-negative.".format(name))
336 | value = [center - value, center + value]
337 | if clip_first_on_zero:
338 | value[0] = max(value[0], 0)
339 | elif isinstance(value, (tuple, list)) and len(value) == 2:
340 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
341 | raise ValueError("{} values should be between {}".format(name, bound))
342 | else:
343 |             raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
344 |
345 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
346 | # or (0., 0.) for hue, do nothing
347 | if value[0] == value[1] == center:
348 | value = None
349 | return value
350 |
351 | def adjust_contrast(self, x):
352 | if self.contrast:
353 | factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
354 | means = torch.mean(x, dim=[2, 3], keepdim=True)
355 | x = (x - means) * factor + means
356 | return torch.clamp(x, 0, 1)
357 |
358 | def adjust_hsv(self, x):
359 | f_h = x.new_zeros(x.size(0), 1, 1)
360 | f_s = x.new_ones(x.size(0), 1, 1)
361 | f_v = x.new_ones(x.size(0), 1, 1)
362 |
363 | if self.hue:
364 | f_h.uniform_(*self.hue)
365 | if self.saturation:
366 | f_s = f_s.uniform_(*self.saturation)
367 | if self.brightness:
368 | f_v = f_v.uniform_(*self.brightness)
369 |
370 | return RandomHSVFunction.apply(x, f_h, f_s, f_v)
371 |
372 | def transform(self, inputs):
373 |         # Apply the contrast and HSV adjustments in a random order
374 | if np.random.rand() > 0.5:
375 | transforms = [self.adjust_contrast, self.adjust_hsv]
376 | else:
377 | transforms = [self.adjust_hsv, self.adjust_contrast]
378 |
379 | for t in transforms:
380 | inputs = t(inputs)
381 |
382 | return inputs
383 |
384 | def forward(self, inputs):
385 | _prob = inputs.new_full((inputs.size(0),), self.prob)
386 | _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
387 | return inputs * (1 - _mask) + self.transform(inputs) * _mask
388 |
389 |
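# --- Usage sketch (hypothetical, not part of the original file). The
# constructor arguments follow torchvision's ColorJitter conventions, with p
# the per-sample probability of applying any jitter at all. Outputs stay in
# [0, 1], assuming the hsv2rgb helper maps valid HSV back into that range.
def _demo_color_jitter():
    import torch
    layer = ColorJitterLayer(p=0.8, brightness=0.4, contrast=0.4,
                             saturation=0.4, hue=0.1)
    x = torch.rand(4, 3, 32, 32)
    out = layer(x)
    assert out.shape == x.shape
    assert out.min() >= 0 and out.max() <= 1
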
390 | class RandomHSVFunction(Function):
391 | @staticmethod
392 | def forward(ctx, x, f_h, f_s, f_v):
393 | # ctx is a context object that can be used to stash information
394 | # for backward computation
395 | x = rgb2hsv(x)
396 | h = x[:, 0, :, :]
397 | h += (f_h * 255. / 360.)
398 | h = (h % 1)
399 | x[:, 0, :, :] = h
400 | x[:, 1, :, :] = x[:, 1, :, :] * f_s
401 | x[:, 2, :, :] = x[:, 2, :, :] * f_v
402 | x = torch.clamp(x, 0, 1)
403 | x = hsv2rgb(x)
404 | return x
405 |
406 | @staticmethod
407 | def backward(ctx, grad_output):
408 | # We return as many input gradients as there were arguments.
409 | # Gradients of non-Tensor arguments to forward must be None.
410 | grad_input = None
411 | if ctx.needs_input_grad[0]:
412 | grad_input = grad_output.clone()
413 | return grad_input, None, None, None
414 |
415 |
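# --- Note with sketch (hypothetical, not part of the original file). backward()
# above implements a straight-through estimator: the gradient reaches the input
# image unchanged, and the sampled jitter factors get no gradient. Assumes the
# rgb2hsv/hsv2rgb helpers referenced in forward() are imported at the top of
# this file.
def _demo_hsv_straight_through():
    import torch
    x = torch.rand(2, 3, 8, 8, requires_grad=True)
    f_h = torch.zeros(2, 1, 1)   # no hue shift
    f_s = torch.ones(2, 1, 1)    # no saturation change
    f_v = torch.ones(2, 1, 1)    # no brightness change
    y = RandomHSVFunction.apply(x, f_h, f_s, f_v)
    y.sum().backward()
    assert torch.allclose(x.grad, torch.ones_like(x))  # identity gradient
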
416 | class NormalizeLayer(nn.Module):
417 | """
418 |     Standardization is performed as the first layer of the classifier rather than as a
419 |     preprocessing step, so that Gaussian noise can be added _before_ standardizing and
420 |     certified radii are measured in the original pixel coordinates.
421 | """
422 |
423 | def __init__(self):
424 | super(NormalizeLayer, self).__init__()
425 |
426 | def forward(self, inputs):
427 | return (inputs - 0.5) / 0.5
428 |
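# --- Usage sketch (hypothetical, not part of the original file). NormalizeLayer
# bakes a fixed mean/std of 0.5 into the network, mapping [0, 1] pixels to
# [-1, 1].
def _demo_normalize():
    import torch
    layer = NormalizeLayer()
    x = torch.rand(2, 3, 8, 8)
    y = layer(x)
    assert y.min() >= -1 and y.max() <= 1
    assert torch.allclose(layer(torch.tensor([0.5])), torch.tensor([0.]))
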
--------------------------------------------------------------------------------