├── data
│   ├── .gitkeep
│   ├── imagenet1k
│   │   └── .gitkeep
│   └── fall11_whole_extracted
│       └── .gitkeep
├── log
│   └── .gitkeep
├── src
│   ├── __init__.py
│   ├── optim
│   │   ├── __init__.py
│   │   └── classifier_trainer.py
│   ├── base
│   │   ├── __init__.py
│   │   ├── base_net.py
│   │   ├── torchvision_dataset.py
│   │   ├── base_dataset.py
│   │   └── base_trainer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── diag.py
│   │   └── visualization
│   │       └── plot_fun.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── modules
│   │   │   └── focal_loss.py
│   │   ├── toy_Net.py
│   │   ├── classifier_Net.py
│   │   ├── main.py
│   │   ├── mnist_LeNet.py
│   │   ├── cifar10_LeNet.py
│   │   ├── cbam.py
│   │   └── imagenet_WideResNet.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── array_dataset.py
│   │   ├── mnist.py
│   │   ├── main.py
│   │   ├── cifar10.py
│   │   ├── emnist.py
│   │   ├── imagenet1k.py
│   │   ├── cifar100.py
│   │   ├── tinyimages.py
│   │   └── imagenet22k.py
│   ├── experiments
│   │   ├── cifar10_OE_tinyimages.sh
│   │   ├── imagenet1k_OE_imagenet22k.sh
│   │   ├── mnist_OE_emnist_classes.sh
│   │   ├── mnist_OE_emnist_blur.sh
│   │   ├── cifar10_OE_cifar100_classes.sh
│   │   ├── cifar10_OE_tinyimages_sizes.sh
│   │   ├── cifar10_OE_tinyimages_blur.sh
│   │   ├── imagenet1k_OE_imagenet22k_sizes.sh
│   │   ├── imagenet1k_OE_imagenet22k_blur.sh
│   │   ├── focal_sensitivity.sh
│   │   └── cifar10_OE_tinyimages_hsc_sensitivity.sh
│   ├── classifier.py
│   └── main.py
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md
/data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /log/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/imagenet1k/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/fall11_whole_extracted/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier_trainer import ClassifierTrainer 2 | -------------------------------------------------------------------------------- /src/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import * 2 | from .torchvision_dataset import * 3 | from .base_net import * 4 | from .base_trainer import * 5 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | from .diag import plot_extreme_samples 3 | from .visualization.plot_fun import plot_images_grid, plot_dist, plot_line 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Click==7.0 2 | matplotlib==3.1.0 3 | numpy==1.16.4 4 | pandas==0.24.2 5 | Pillow==6.1.0 6 | scikit-learn==0.21.2 7 | scipy==1.3.0 8 | seaborn==0.9.0 9 | torch==1.5.0 10 | torchvision==0.6.0 11 | --------------------------------------------------------------------------------
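The __init__.py files above expose the package's entry points (load_dataset, build_network, the Classifier plumbing, and the plotting helpers). As a rough orientation sketch only — the actual experiment wiring lives in src/main.py, which is not reproduced in this snapshot — the components shown in this listing compose roughly as follows when run from within src/, mirroring the flags used by the experiment scripts below:

# Hedged sketch of the Python API; src/main.py (not shown) does the real CLI wiring via Click.
from datasets.main import load_dataset
from classifier import Classifier

normal_set = load_dataset('cifar10', data_path='../data', normal_class=0,
                          data_augmentation=True, normalize=True)
oe_set = load_dataset('tinyimages', data_path='../data', normal_class=0,  # normal_class unused for OE sets
                      data_augmentation=True, normalize=True,
                      outlier_exposure=True, seed=1)

model = Classifier(objective='hsc')              # or 'bce' / 'focal'
model.set_network('cifar10_LeNet', rep_dim=256)
model.train(normal_set, oe_dataset=oe_set, lr=0.001, n_epochs=200,
            lr_milestones=(100, 150), batch_size=128, device='cuda')
model.test(normal_set, device='cuda')
model.save_results('../log/results.json')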
/src/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import build_network 2 | from .toy_Net import Toy_Net 3 | from .mnist_LeNet import MNIST_LeNet 4 | from .cifar10_LeNet import CIFAR10_LeNet 5 | from .cbam import CBAM 6 | from .imagenet_WideResNet import ImageNet_WideResNet 7 | from .classifier_Net import ClassifierNet 8 | from .modules.focal_loss import FocalLoss 9 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import load_dataset 2 | from .mnist import MNIST_Dataset 3 | from .emnist import EMNIST_Dataset 4 | from .cifar10 import CIFAR10_Dataset 5 | from .cifar100 import CIFAR100_Dataset 6 | from .tinyimages import TinyImages_Dataset 7 | from .imagenet1k import ImageNet1K_Dataset 8 | from .imagenet22k import ImageNet22K_Dataset 9 | from .array_dataset import Array_Dataset 10 | -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class Config(object): 5 | """Base class for experimental setting/configuration.""" 6 | 7 | def __init__(self, settings): 8 | self.settings = settings 9 | 10 | def load_config(self, import_json): 11 | """Load settings dict from import_json (path/filename.json) JSON-file.""" 12 | 13 | with open(import_json, 'r') as fp: 14 | settings = json.load(fp) 15 | 16 | for key, value in settings.items(): 17 | self.settings[key] = value 18 | 19 | def save_config(self, export_json): 20 | """Save settings dict to export_json (path/filename.json) JSON-file.""" 21 | 22 | with open(export_json, 'w') as fp: 23 | json.dump(self.settings, fp) 24 | -------------------------------------------------------------------------------- /src/networks/modules/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.modules.loss import _Loss 4 | 5 | 6 | # Implementation of the focal loss from https://arxiv.org/abs/1708.02002 7 | class FocalLoss(_Loss): 8 | def __init__(self, gamma=2.0, eps=1e-7): 9 | super(FocalLoss, self).__init__(size_average=None, reduce=None, reduction='mean') 10 | self.gamma = gamma 11 | self.eps = eps 12 | 13 | def forward(self, input, target): 14 | BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none') 15 | pt = torch.exp(-BCE_loss) # prevents nans when probability 0 16 | pt = pt.clamp(self.eps, 1. 
- self.eps) 17 | F_loss = (1 - pt).pow(self.gamma) * BCE_loss 18 | return F_loss.mean() 19 | -------------------------------------------------------------------------------- /src/networks/toy_Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from base.base_net import BaseNet 6 | 7 | 8 | class Toy_Net(BaseNet): 9 | 10 | def __init__(self, rep_dim=128, bias_terms=True): 11 | super().__init__() 12 | 13 | self.rep_dim = rep_dim 14 | 15 | self.fc1 = nn.Linear(2, 256, bias=bias_terms) 16 | nn.init.xavier_normal_(self.fc1.weight, gain=nn.init.calculate_gain('leaky_relu')) 17 | self.bn1 = nn.BatchNorm1d(256, affine=bias_terms) 18 | 19 | self.fc2 = nn.Linear(256, self.rep_dim, bias=bias_terms) 20 | nn.init.xavier_normal_(self.fc2.weight, gain=nn.init.calculate_gain('sigmoid')) 21 | 22 | def forward(self, x): 23 | x = F.elu(self.bn1(self.fc1(x))) 24 | x = torch.sigmoid(self.fc2(x)) 25 | return x 26 | -------------------------------------------------------------------------------- /src/base/base_net.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class BaseNet(nn.Module): 7 | """Base class for all neural networks.""" 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self.logger = logging.getLogger(self.__class__.__name__) 12 | self.rep_dim = None # representation dimensionality, i.e. dim of the last or code layer 13 | 14 | def forward(self, *input): 15 | """ 16 | Forward pass logic 17 | :return: Network output 18 | """ 19 | raise NotImplementedError 20 | 21 | def summary(self): 22 | """Network summary.""" 23 | net_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in net_parameters]) 25 | self.logger.info('Trainable parameters: {}'.format(params)) 26 | self.logger.info(self) 27 | -------------------------------------------------------------------------------- /src/experiments/cifar10_OE_tinyimages.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/cifar10_oe_tinyimages; 5 | 6 | methods=( hsc deepSAD bce focal); 7 | 8 | for seed in $(seq 1 10); 9 | do 10 | for exp in $(seq 0 9); 11 | do 12 | for method in "${methods[@]}"; 13 | do 14 | mkdir ../log/cifar10_oe_tinyimages/${method}; 15 | mkdir ../log/cifar10_oe_tinyimages/${method}/${exp}_vs_rest; 16 | mkdir ../log/cifar10_oe_tinyimages/${method}/${exp}_vs_rest/seed_${seed}; 17 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_tinyimages/${method}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name tinyimages --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 18 | done 19 | done 20 | done 21 | -------------------------------------------------------------------------------- /src/experiments/imagenet1k_OE_imagenet22k.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/imagenet1k_oe_imagenet22k; 5 | 6 | methods=( hsc deepSAD bce focal); 7 | 8 | for seed in $(seq 1 10); 9 | do 10 | for exp in $(seq 0 29); 11 | do 12 | for method in "${methods[@]}"; 13 | do 14 | mkdir 
../log/imagenet1k_oe_imagenet22k/${method}; 15 | mkdir ../log/imagenet1k_oe_imagenet22k/${method}/${exp}_vs_rest; 16 | mkdir ../log/imagenet1k_oe_imagenet22k/${method}/${exp}_vs_rest/seed_${seed}; 17 | python main.py imagenet1k imagenet_WideResNet ../log/imagenet1k_oe_imagenet22k/${method}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name imagenet22k --device cuda --seed ${seed} --lr 0.001 --n_epochs 150 --lr_milestone 100 --lr_milestone 125 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 18 | done 19 | done 20 | done 21 | -------------------------------------------------------------------------------- /src/base/torchvision_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseADDataset 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | class TorchvisionDataset(BaseADDataset): 6 | """TorchvisionDataset class for datasets already implemented in torchvision.datasets.""" 7 | 8 | def __init__(self, root: str): 9 | super().__init__(root) 10 | 11 | self.image_size = None # tuple with the size of an image from the dataset (e.g. (1, 28, 28) for MNIST) 12 | 13 | def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0, 14 | pin_memory: bool = False) -> (DataLoader, DataLoader): 15 | train_loader = DataLoader(dataset=self.train_set, batch_size=batch_size, shuffle=shuffle_train, 16 | num_workers=num_workers, pin_memory=pin_memory, drop_last=True) 17 | test_loader = DataLoader(dataset=self.test_set, batch_size=batch_size, shuffle=shuffle_test, 18 | num_workers=num_workers, pin_memory=pin_memory, drop_last=False) 19 | return train_loader, test_loader 20 | -------------------------------------------------------------------------------- /src/networks/classifier_Net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from base.base_net import BaseNet 4 | from .mnist_LeNet import MNIST_LeNet 5 | from .cifar10_LeNet import CIFAR10_LeNet 6 | from .imagenet_WideResNet import ImageNet_WideResNet 7 | from .toy_Net import Toy_Net 8 | 9 | 10 | class ClassifierNet(BaseNet): 11 | 12 | def __init__(self, net_name, rep_dim=64, bias_terms=False): 13 | super().__init__() 14 | 15 | if net_name == 'mnist_LeNet_classifier': 16 | self.network = MNIST_LeNet(rep_dim=rep_dim, bias_terms=bias_terms) 17 | if net_name == 'cifar10_LeNet_classifier': 18 | self.network = CIFAR10_LeNet(rep_dim=rep_dim, bias_terms=bias_terms) 19 | if net_name == 'imagenet_WideResNet_classifier': 20 | self.network = ImageNet_WideResNet(rep_dim=rep_dim) 21 | if net_name == 'toy_Net_classifier': 22 | self.network = Toy_Net(rep_dim=rep_dim) 23 | 24 | self.linear = nn.Linear(self.network.rep_dim, 1) 25 | 26 | def forward(self, x): 27 | x = self.network(x) 28 | x = self.linear(x) 29 | return x 30 | -------------------------------------------------------------------------------- /src/base/base_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | class BaseADDataset(ABC): 6 | """Anomaly detection dataset base class.""" 7 | 8 | def __init__(self, root: str): 9 | super().__init__() 10 | self.root = root # root path to data 11 | 12 | self.n_classes = 2 # 0: normal, 1: outlier 13 | self.normal_classes = None # tuple with original class labels 
that define the normal class 14 | self.outlier_classes = None # tuple with original class labels that define the outlier class 15 | 16 | self.train_set = None # must be of type torch.utils.data.Dataset 17 | self.test_set = None # must be of type torch.utils.data.Dataset 18 | 19 | @abstractmethod 20 | def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> ( 21 | DataLoader, DataLoader): 22 | """Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set.""" 23 | pass 24 | 25 | def __repr__(self): 26 | return self.__class__.__name__ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 lukasruff 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/experiments/mnist_OE_emnist_classes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/mnist_oe_emnist; 5 | 6 | methods=( hsc bce ); 7 | n_classes=( 1 2 3 5 10 15 20 26 ); 8 | 9 | for seed in $(seq 1 10); 10 | do 11 | for k in "${n_classes[@]}"; 12 | do 13 | for exp in $(seq 0 9); 14 | do 15 | for method in "${methods[@]}"; 16 | do 17 | mkdir ../log/mnist_oe_emnist/${method}; 18 | mkdir ../log/mnist_oe_emnist/${method}/${k}_oe_classes; 19 | mkdir ../log/mnist_oe_emnist/${method}/${k}_oe_classes/${exp}_vs_rest; 20 | mkdir ../log/mnist_oe_emnist/${method}/${k}_oe_classes/${exp}_vs_rest/seed_${seed}; 21 | python main.py mnist mnist_LeNet ../log/mnist_oe_emnist/${method}/${k}_oe_classes/${exp}_vs_rest/seed_${seed} ../data --rep_dim 32 --objective ${method} --outlier_exposure True --oe_dataset_name emnist --oe_n_classes ${k} --device cuda --seed ${seed} --lr 0.001 --n_epochs 150 --lr_milestone 50 --lr_milestone 100 --batch_size 128 --normal_class ${exp}; 22 | done 23 | done 24 | done 25 | done 26 | -------------------------------------------------------------------------------- /src/experiments/mnist_OE_emnist_blur.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/mnist_oe_emnist; 5 | mkdir ../log/mnist_oe_emnist/blur; 6 | 7 | methods=( hsc bce ); 8 | stds=( 1 2 4 8 16 32 ); 9 | 10 | for seed in $(seq 1 5); 11 | do 12 | for exp in $(seq 0 9); 13 | do 14 | for std in "${stds[@]}"; 15 | do 16 | for method in "${methods[@]}"; 17 | do 18 | mkdir ../log/mnist_oe_emnist/blur/${method}; 19 | mkdir ../log/mnist_oe_emnist/blur/${method}/blur_std=${std}; 20 | mkdir ../log/mnist_oe_emnist/blur/${method}/blur_std=${std}/${exp}_vs_rest; 21 | mkdir ../log/mnist_oe_emnist/blur/${method}/blur_std=${std}/${exp}_vs_rest/seed_${seed}; 22 | python main.py mnist mnist_LeNet ../log/mnist_oe_emnist/blur/${method}/blur_std=${std}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 32 --objective ${method} --outlier_exposure True --oe_dataset_name emnist --oe_n_classes 26 --blur_oe True --blur_std ${std} --device cuda --seed ${seed} --lr 0.001 --n_epochs 150 --lr_milestone 50 --lr_milestone 100 --batch_size 128 --normal_class ${exp}; 23 | done 24 | done 25 | done 26 | done 27 | -------------------------------------------------------------------------------- /src/experiments/cifar10_OE_cifar100_classes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/cifar10_oe_cifar100; 5 | 6 | methods=( hsc bce ); 7 | n_classes=( 1 2 4 8 16 32 64 100 ); 8 | 9 | for seed in $(seq 1 10); 10 | do 11 | for k in "${n_classes[@]}"; 12 | do 13 | for exp in $(seq 0 9); 14 | do 15 | for method in "${methods[@]}"; 16 | do 17 | mkdir ../log/cifar10_oe_cifar100/${method}; 18 | mkdir ../log/cifar10_oe_cifar100/${method}/${k}_oe_classes; 19 | mkdir ../log/cifar10_oe_cifar100/${method}/${k}_oe_classes/${exp}_vs_rest; 20 | mkdir ../log/cifar10_oe_cifar100/${method}/${k}_oe_classes/${exp}_vs_rest/seed_${seed}; 21 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_cifar100/${method}/${k}_oe_classes/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name cifar100 --oe_n_classes ${k} --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 
--lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 22 | done 23 | done 24 | done 25 | done 26 | -------------------------------------------------------------------------------- /src/base/base_trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from .base_dataset import BaseADDataset 3 | from .base_net import BaseNet 4 | 5 | 6 | class BaseTrainer(ABC): 7 | """Trainer base class.""" 8 | 9 | def __init__(self, optimizer_name: str, lr: float, n_epochs: int, lr_milestones: tuple, batch_size: int, 10 | weight_decay: float, device: str, n_jobs_dataloader: int): 11 | super().__init__() 12 | self.optimizer_name = optimizer_name 13 | self.lr = lr 14 | self.n_epochs = n_epochs 15 | self.lr_milestones = lr_milestones 16 | self.batch_size = batch_size 17 | self.weight_decay = weight_decay 18 | self.device = device 19 | self.n_jobs_dataloader = n_jobs_dataloader 20 | 21 | @abstractmethod 22 | def train(self, dataset: BaseADDataset, net: BaseNet) -> BaseNet: 23 | """ 24 | Implement train method that trains the given network using the train_set of dataset. 25 | :return: Trained net 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def test(self, dataset: BaseADDataset, net: BaseNet): 31 | """ 32 | Implement test method that evaluates the test_set of dataset on the given network. 33 | """ 34 | pass 35 | -------------------------------------------------------------------------------- /src/experiments/cifar10_OE_tinyimages_sizes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/cifar10_oe_tinyimages; 5 | 6 | methods=( hsc bce ); 7 | sizes=( 1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536 131072 262144 524288 1048576 2097152 ); 8 | 9 | for seed in $(seq 1 10); 10 | do 11 | for size in "${sizes[@]}"; 12 | do 13 | for exp in $(seq 0 9); 14 | do 15 | for method in "${methods[@]}"; 16 | do 17 | mkdir ../log/cifar10_oe_tinyimages/${method}; 18 | mkdir ../log/cifar10_oe_tinyimages/${method}/oe_size_${size}; 19 | mkdir ../log/cifar10_oe_tinyimages/${method}/oe_size_${size}/${exp}_vs_rest; 20 | mkdir ../log/cifar10_oe_tinyimages/${method}/oe_size_${size}/${exp}_vs_rest/seed_${seed}; 21 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_tinyimages/${method}/oe_size_${size}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name tinyimages --oe_size ${size} --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 22 | done 23 | done 24 | done 25 | done 26 | -------------------------------------------------------------------------------- /src/experiments/cifar10_OE_tinyimages_blur.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/cifar10_oe_tinyimages; 5 | mkdir ../log/cifar10_oe_tinyimages/blur; 6 | 7 | methods=( hsc bce ); 8 | stds=( 1 2 4 8 16 32 ); 9 | 10 | for seed in $(seq 1 10); 11 | do 12 | for std in "${stds[@]}"; 13 | do 14 | for exp in $(seq 0 9); 15 | do 16 | for method in "${methods[@]}"; 17 | do 18 | mkdir ../log/cifar10_oe_tinyimages/blur/${method}; 19 | mkdir ../log/cifar10_oe_tinyimages/blur/${method}/blur_std=${std}; 20 | mkdir 
../log/cifar10_oe_tinyimages/blur/${method}/blur_std=${std}/${exp}_vs_rest; 21 | mkdir ../log/cifar10_oe_tinyimages/blur/${method}/blur_std=${std}/${exp}_vs_rest/seed_${seed}; 22 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_tinyimages/blur/${method}/blur_std=${std}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name tinyimages --oe_size 128 --blur_oe True --blur_std ${std} --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 23 | done 24 | done 25 | done 26 | done 27 | -------------------------------------------------------------------------------- /src/experiments/imagenet1k_OE_imagenet22k_sizes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/imagenet1k_oe_imagenet22k; 5 | 6 | methods=( hsc bce ); 7 | sizes=( 1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536 131072 262144 ); 8 | 9 | for seed in $(seq 1 5); 10 | do 11 | for size in "${sizes[@]}"; 12 | do 13 | for exp in $(seq 0 29); 14 | do 15 | for method in "${methods[@]}"; 16 | do 17 | mkdir ../log/imagenet1k_oe_imagenet22k/${method}; 18 | mkdir ../log/imagenet1k_oe_imagenet22k/${method}/oe_size_${size}; 19 | mkdir ../log/imagenet1k_oe_imagenet22k/${method}/oe_size_${size}/${exp}_vs_rest; 20 | mkdir ../log/imagenet1k_oe_imagenet22k/${method}/oe_size_${size}/${exp}_vs_rest/seed_${seed}; 21 | python main.py imagenet1k imagenet_WideResNet ../log/imagenet1k_oe_imagenet22k/${method}/oe_size_${size}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name imagenet22k --oe_size ${size} --device cuda --seed ${seed} --lr 0.001 --n_epochs 150 --lr_milestone 100 --lr_milestone 125 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 22 | done 23 | done 24 | done 25 | done 26 | -------------------------------------------------------------------------------- /src/experiments/imagenet1k_OE_imagenet22k_blur.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/imagenet1k_oe_imagenet22k; 5 | mkdir ../log/imagenet1k_oe_imagenet22k/blur; 6 | 7 | methods=( hsc bce ); 8 | stds=( 1 2 4 8 16 32 ); 9 | 10 | for seed in $(seq 1 5); 11 | do 12 | for std in "${stds[@]}"; 13 | do 14 | for exp in $(seq 0 29); 15 | do 16 | for method in "${methods[@]}"; 17 | do 18 | mkdir ../log/imagenet1k_oe_imagenet22k/blur/${method}; 19 | mkdir ../log/imagenet1k_oe_imagenet22k/blur/${method}/blur_std=${std}; 20 | mkdir ../log/imagenet1k_oe_imagenet22k/blur/${method}/blur_std=${std}/${exp}_vs_rest; 21 | mkdir ../log/imagenet1k_oe_imagenet22k/blur/${method}/blur_std=${std}/${exp}_vs_rest/seed_${seed}; 22 | python main.py imagenet1k imagenet_WideResNet ../log/imagenet1k_oe_imagenet22k/blur/${method}/blur_std=${std}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective ${method} --outlier_exposure True --oe_dataset_name imagenet22k --oe_size 64 --blur_oe True --blur_std ${std} --device cuda --seed ${seed} --lr 0.001 --n_epochs 150 --lr_milestone 100 --lr_milestone 125 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 23 | done 24 | done 25 | done 26 | done 27 | -------------------------------------------------------------------------------- 
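The *_blur.sh experiments above corrupt the outlier-exposure samples with a Gaussian blur whose strength is swept via --blur_oe/--blur_std. In the dataset classes further down (e.g. emnist.py), this amounts to a PIL filter prepended to the torchvision transform pipeline; a minimal standalone sketch of that corruption (the input file name is hypothetical):

from PIL import Image
from PIL.ImageFilter import GaussianBlur
import torchvision.transforms as transforms

blur_std = 8.0  # one of the swept stds ( 1 2 4 8 16 32 )
transform = transforms.Compose([
    transforms.Lambda(lambda x: x.filter(GaussianBlur(radius=blur_std))),  # blur the PIL image
    transforms.ToTensor(),  # then scale features to [0, 1]
])
x = transform(Image.open('oe_sample.png'))  # hypothetical OE image -> blurred tensor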
/src/networks/main.py: -------------------------------------------------------------------------------- 1 | from .mnist_LeNet import MNIST_LeNet 2 | from .cifar10_LeNet import CIFAR10_LeNet 3 | from .imagenet_WideResNet import ImageNet_WideResNet 4 | from .toy_Net import Toy_Net 5 | from .classifier_Net import ClassifierNet 6 | 7 | 8 | def build_network(net_name, rep_dim=64, bias_terms=False): 9 | """Builds the neural network.""" 10 | 11 | implemented_networks = ('mnist_LeNet', 'cifar10_LeNet', 'imagenet_WideResNet', 'toy_Net', 12 | 'mnist_LeNet_classifier', 'cifar10_LeNet_classifier', 'imagenet_WideResNet_classifier', 13 | 'toy_Net_classifier') 14 | assert net_name in implemented_networks 15 | 16 | net = None 17 | 18 | if net_name == 'mnist_LeNet': 19 | net = MNIST_LeNet(rep_dim=rep_dim, bias_terms=bias_terms) 20 | 21 | if net_name == 'cifar10_LeNet': 22 | net = CIFAR10_LeNet(rep_dim=rep_dim, bias_terms=bias_terms) 23 | 24 | if net_name == 'imagenet_WideResNet': 25 | net = ImageNet_WideResNet(rep_dim=rep_dim) 26 | 27 | if net_name == 'toy_Net': 28 | net = Toy_Net(rep_dim=rep_dim) 29 | 30 | if net_name in ['mnist_LeNet_classifier', 'cifar10_LeNet_classifier', 'imagenet_WideResNet_classifier', 31 | 'toy_Net_classifier']: 32 | net = ClassifierNet(net_name, rep_dim=rep_dim, bias_terms=bias_terms) 33 | 34 | return net 35 | -------------------------------------------------------------------------------- /src/networks/mnist_LeNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from base.base_net import BaseNet 5 | 6 | 7 | class MNIST_LeNet(BaseNet): 8 | 9 | def __init__(self, rep_dim=32, bias_terms=False): 10 | super().__init__() 11 | 12 | self.rep_dim = rep_dim 13 | self.pool = nn.MaxPool2d(2, 2) 14 | 15 | self.conv1 = nn.Conv2d(1, 16, 5, bias=bias_terms, padding=2) 16 | nn.init.xavier_normal_(self.conv1.weight, gain=nn.init.calculate_gain('leaky_relu')) 17 | self.bn2d1 = nn.BatchNorm2d(16, eps=1e-04, affine=bias_terms) 18 | self.conv2 = nn.Conv2d(16, 32, 5, bias=bias_terms, padding=2) 19 | nn.init.xavier_normal_(self.conv2.weight, gain=nn.init.calculate_gain('leaky_relu')) 20 | self.bn2d2 = nn.BatchNorm2d(32, eps=1e-04, affine=bias_terms) 21 | self.fc1 = nn.Linear(32 * 7 * 7, 64, bias=bias_terms) 22 | nn.init.xavier_normal_(self.fc1.weight, gain=nn.init.calculate_gain('leaky_relu')) 23 | self.bn1d1 = nn.BatchNorm1d(64, eps=1e-04, affine=bias_terms) 24 | self.fc2 = nn.Linear(64, self.rep_dim, bias=bias_terms) 25 | nn.init.xavier_normal_(self.fc2.weight) 26 | 27 | def forward(self, x): 28 | x = x.view(-1, 1, 28, 28) 29 | x = self.conv1(x) 30 | x = self.pool(F.leaky_relu(self.bn2d1(x))) 31 | x = self.conv2(x) 32 | x = self.pool(F.leaky_relu(self.bn2d2(x))) 33 | x = x.view(int(x.size(0)), -1) 34 | x = self.fc1(x) 35 | x = F.leaky_relu(self.bn1d1(x)) 36 | x = self.fc2(x) 37 | return x 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.p 6 | *.pkl 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | venv/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these 
files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | 58 | # Sphinx documentation 59 | docs/_build/ 60 | 61 | # PyBuilder 62 | target/ 63 | 64 | #Ipython Notebook 65 | .ipynb_checkpoints 66 | 67 | # Eclipse config files 68 | .project 69 | .pydevproject 70 | 71 | # Sublime config file 72 | cnn.sublime-workspace 73 | cnn.sublime-project 74 | 75 | # PyCharm config files 76 | .idea 77 | 78 | # Temporary files 79 | *.m~ 80 | *.py~ 81 | 82 | # macOS General 83 | .DS_Store 84 | .AppleDouble 85 | .LSOverride 86 | 87 | # Icon must end with two \r 88 | Icon 89 | 90 | # Thumbnails 91 | ._* 92 | 93 | # Files that might appear in the root of a volume 94 | .DocumentRevisions-V100 95 | .fseventsd 96 | .Spotlight-V100 97 | .TemporaryItems 98 | .Trashes 99 | .VolumeIcon.icns 100 | .com.apple.timemachine.donotpresent 101 | 102 | # Directories potentially created on remote AFP share 103 | .AppleDB 104 | .AppleDesktop 105 | Network Trash Folder 106 | Temporary Items 107 | .apdisk -------------------------------------------------------------------------------- /src/networks/cifar10_LeNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from base.base_net import BaseNet 5 | 6 | 7 | class CIFAR10_LeNet(BaseNet): 8 | 9 | def __init__(self, rep_dim=256, bias_terms=False): 10 | super().__init__() 11 | 12 | self.rep_dim = rep_dim 13 | self.pool = nn.MaxPool2d(2, 2) 14 | 15 | # Encoder network 16 | self.conv1 = nn.Conv2d(3, 32, 5, bias=bias_terms, padding=2) 17 | nn.init.xavier_normal_(self.conv1.weight, gain=nn.init.calculate_gain('leaky_relu')) 18 | self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=bias_terms) 19 | self.conv2 = nn.Conv2d(32, 64, 5, bias=bias_terms, padding=2) 20 | nn.init.xavier_normal_(self.conv2.weight, gain=nn.init.calculate_gain('leaky_relu')) 21 | self.bn2d2 = nn.BatchNorm2d(64, eps=1e-04, affine=bias_terms) 22 | self.conv3 = nn.Conv2d(64, 128, 5, bias=bias_terms, padding=2) 23 | nn.init.xavier_normal_(self.conv3.weight, gain=nn.init.calculate_gain('leaky_relu')) 24 | self.bn2d3 = nn.BatchNorm2d(128, eps=1e-04, affine=bias_terms) 25 | self.fc1 = nn.Linear(128 * 4 * 4, 512, bias=bias_terms) 26 | nn.init.xavier_normal_(self.fc1.weight, gain=nn.init.calculate_gain('leaky_relu')) 27 | self.bn1d1 = nn.BatchNorm1d(512, eps=1e-04, affine=bias_terms) 28 | self.fc2 = nn.Linear(512, self.rep_dim, bias=bias_terms) 29 | nn.init.xavier_normal_(self.fc2.weight) 30 | 31 | def forward(self, x): 32 | x = x.view(-1, 3, 32, 32) 33 | x = self.conv1(x) 34 | x = self.pool(F.leaky_relu(self.bn2d1(x))) 35 | x = self.conv2(x) 36 | x = self.pool(F.leaky_relu(self.bn2d2(x))) 37 | x = self.conv3(x) 38 | x = self.pool(F.leaky_relu(self.bn2d3(x))) 39 | x = x.view(int(x.size(0)), -1) 40 | x = self.fc1(x) 41 | x = F.leaky_relu(self.bn1d1(x)) 42 | x = self.fc2(x) 43 | return x 44 | -------------------------------------------------------------------------------- /src/experiments/focal_sensitivity.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/cifar10_oe_tinyimages; 5 | mkdir ../log/cifar10_oe_tinyimages/focal_sensitivity; 6 | 7 | mkdir ../log/imagenet1k_oe_imagenet22k; 8 | mkdir ../log/imagenet1k_oe_imagenet22k/focal_sensitivity; 9 | 10 | gammas=( 0 0.5 2 4 ); 11 | 12 | for seed in $(seq 1 10); 13 | do 14 | for exp in $(seq 0 9); 15 | do 16 | for gamma in "${gammas[@]}"; 17 | do 18 | mkdir ../log/cifar10_oe_tinyimages/focal_sensitivity/gamma=${gamma}; 19 | mkdir ../log/cifar10_oe_tinyimages/focal_sensitivity/gamma=${gamma}/${exp}_vs_rest; 20 | mkdir ../log/cifar10_oe_tinyimages/focal_sensitivity/gamma=${gamma}/${exp}_vs_rest/seed_${seed}; 21 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_tinyimages/focal_sensitivity/gamma=${gamma}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective focal --focal_gamma ${gamma} --outlier_exposure True --oe_dataset_name tinyimages --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 22 | 23 | mkdir ../log/imagenet1k_oe_imagenet22k/focal_sensitivity/gamma=${gamma}; 24 | mkdir ../log/imagenet1k_oe_imagenet22k/focal_sensitivity/gamma=${gamma}/${exp}_vs_rest; 25 | mkdir ../log/imagenet1k_oe_imagenet22k/focal_sensitivity/gamma=${gamma}/${exp}_vs_rest/seed_${seed}; 26 | python main.py imagenet1k imagenet_WideResNet ../log/imagenet1k_oe_imagenet22k/focal_sensitivity/gamma=${gamma}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective focal --focal_gamma ${gamma} --outlier_exposure True --oe_dataset_name imagenet22k --device cuda --seed ${seed} --lr 0.001 --n_epochs 150 --lr_milestone 100 --lr_milestone 125 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /src/experiments/cifar10_OE_tinyimages_hsc_sensitivity.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | mkdir ../log/cifar10_oe_tinyimages; 5 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity; 6 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/data_augment; 7 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/no_data_augment; 8 | 9 | norms=( l1 l2 l2_squared l2_squared_linear ); 10 | 11 | for seed in $(seq 1 10); 12 | do 13 | for exp in $(seq 0 9); 14 | do 15 | for norm in "${norms[@]}"; 16 | do 17 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/data_augment; 18 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/data_augment/${norm}; 19 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/data_augment/${norm}/${exp}_vs_rest; 20 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/data_augment/${norm}/${exp}_vs_rest/seed_${seed}; 21 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_tinyimages/hsc_sensitivity/data_augment/${norm}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective hsc --hsc_norm ${norm} --outlier_exposure True --oe_dataset_name tinyimages --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class ${exp}; 22 | 23 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/no_data_augment/${norm}; 24 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/no_data_augment/${norm}/${exp}_vs_rest; 
25 | mkdir ../log/cifar10_oe_tinyimages/hsc_sensitivity/no_data_augment/${norm}/${exp}_vs_rest/seed_${seed}; 26 | python main.py cifar10 cifar10_LeNet ../log/cifar10_oe_tinyimages/hsc_sensitivity/no_data_augment/${norm}/${exp}_vs_rest/seed_${seed} ../data --rep_dim 256 --objective hsc --hsc_norm ${norm} --outlier_exposure True --oe_dataset_name tinyimages --device cuda --seed ${seed} --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation False --data_normalization True --normal_class ${exp}; 27 | done 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /src/datasets/array_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from base.base_dataset import BaseADDataset 3 | 4 | 5 | class Array_Dataset(BaseADDataset): 6 | 7 | def __init__(self, X_train, y_train, y_semi_train, X_test=None, y_test=None, y_semi_test=None): 8 | super().__init__(root='') 9 | 10 | self.shuffle = True 11 | 12 | # Get train set 13 | self.train_set = MyTensorDataset(X_train, y_train, y_semi_train) 14 | 15 | # Get test set 16 | if X_test is None: 17 | self.test_set = None 18 | else: 19 | self.test_set = MyTensorDataset(X_test, y_test, y_semi_test) 20 | 21 | def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0, 22 | pin_memory: bool = False) -> (DataLoader, DataLoader): 23 | train_loader = DataLoader(dataset=self.train_set, batch_size=batch_size, shuffle=shuffle_train, 24 | num_workers=num_workers, pin_memory=pin_memory, drop_last=True) 25 | test_loader = DataLoader(dataset=self.test_set, batch_size=batch_size, shuffle=shuffle_test, 26 | num_workers=num_workers, pin_memory=pin_memory, drop_last=False) 27 | return train_loader, test_loader 28 | 29 | 30 | class MyTensorDataset(Dataset): 31 | """Dataset wrapping tensors. 32 | 33 | Each sample will be retrieved by indexing tensors along the first dimension. 34 | 35 | Arguments: 36 | *tensors: Triple of (X, target, semi_target) tensors that have the same size of the first dimension. 37 | """ 38 | 39 | def __init__(self, *tensors): 40 | 41 | assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) 42 | assert len(tensors) == 3 43 | 44 | self.tensors = tensors 45 | self.shuffle_idxs = False 46 | 47 | def __getitem__(self, index): 48 | """ 49 | Args: 50 | index (int): Index 51 | 52 | Returns: 53 | tuple: (x, target, semi_target, index) 54 | """ 55 | return self.tensors[0][index], self.tensors[1][index], self.tensors[2][index], index 56 | 57 | def __len__(self): 58 | return self.tensors[0].size(0) 59 | -------------------------------------------------------------------------------- /src/utils/diag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from utils.visualization.plot_fun import plot_images_grid 5 | 6 | 7 | def plot_extreme_samples(data, scores, labels, idx, xp_path, n=32, prefix='', suffix=''): 8 | """ 9 | Plots the n extreme (most anomalous and most normal) train and test set samples. 10 | 11 | :param data: The data of type torch.Tensor or np.ndarray. 12 | :param scores: iterable with scores. 13 | :param labels: iterable with labels. 14 | :param idx: iterable with indices. 15 | :param xp_path: export path for image files. 16 | :param n: number of extreme samples to plot. 17 | :param prefix: option to add some prefix to filenames. 
18 | :param suffix: option to add some suffix to filenames. 19 | """ 20 | 21 | if not (torch.is_tensor(data)): 22 | data = torch.tensor(data) # convert to torch.Tensor if np.ndarray 23 | if data.dim() == 3: # single-channel images with B x H x W 24 | data = data.unsqueeze(1) # add channel C dimension 25 | if data.dim() == 4 and not data.size(1) in (1, 3): 26 | data = data.permute(0, 3, 1, 2) # Convert from (B x H x W x C) to (B x C x H x W) 27 | 28 | # Overall data 29 | idx_sorted = idx[np.argsort(scores)] # by score, from lowest to highest 30 | X_low = data[idx_sorted[:n], ...] 31 | X_high = data[idx_sorted[-n:], ...] 32 | plot_images_grid(X_low, xp_path=xp_path, filename= prefix + 'all_low' + suffix, padding=2) 33 | plot_images_grid(X_high, xp_path=xp_path, filename=prefix + 'all_high' + suffix, padding=2) 34 | 35 | # Normal samples (within-class scoring) 36 | idx_sorted = idx[labels == 0][np.argsort(scores[labels == 0])] # by score, from lowest to highest 37 | X_low = data[idx_sorted[:n], ...] 38 | X_high = data[idx_sorted[-n:], ...] 39 | plot_images_grid(X_low, xp_path=xp_path, filename=prefix + 'normal_low' + suffix, padding=2) 40 | plot_images_grid(X_high, xp_path=xp_path, filename=prefix + 'normal_high' + suffix, padding=2) 41 | 42 | # Outlier samples (out-of-class scoring) 43 | if np.sum(labels) > 0: 44 | idx_sorted = idx[labels == 1][np.argsort(scores[labels == 1])] # by score, from lowest to highest 45 | X_low = data[idx_sorted[:n], ...] 46 | X_high = data[idx_sorted[-n:], ...] 47 | plot_images_grid(X_low, xp_path=xp_path, filename=prefix + 'outlier_low' + suffix, padding=2) 48 | plot_images_grid(X_high, xp_path=xp_path, filename=prefix + 'outlier_high' + suffix, padding=2) 49 | -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from PIL import Image 3 | from torchvision.datasets import MNIST 4 | from base.torchvision_dataset import TorchvisionDataset 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class MNIST_Dataset(TorchvisionDataset): 12 | 13 | def __init__(self, root: str, normal_class: int = 0): 14 | super().__init__(root) 15 | 16 | self.image_size = (1, 28, 28) 17 | 18 | # Define normal and outlier classes 19 | self.n_classes = 2 # 0: normal, 1: outlier 20 | self.normal_classes = tuple([normal_class]) 21 | self.outlier_classes = list(range(0, 10)) 22 | self.outlier_classes.remove(normal_class) 23 | self.outlier_classes = tuple(self.outlier_classes) 24 | 25 | # MNIST preprocessing: feature scaling to [0, 1] 26 | transform = transforms.ToTensor() 27 | target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes)) 28 | 29 | # Get train set 30 | train_set = MyMNIST(root=self.root, train=True, transform=transform, target_transform=target_transform, 31 | download=True) 32 | 33 | # Subset train_set to normal_classes 34 | idx = np.argwhere(np.isin(train_set.targets.cpu().data.numpy(), self.normal_classes)) 35 | idx = idx.flatten().tolist() 36 | train_set.semi_targets[idx] = torch.zeros(len(idx)).long() 37 | self.train_set = Subset(train_set, idx) 38 | 39 | # Get test set 40 | self.test_set = MyMNIST(root=self.root, train=False, transform=transform, target_transform=target_transform, 41 | download=True) 42 | 43 | 44 | class MyMNIST(MNIST): 45 | """ 46 | Torchvision MNIST class with additional targets for the outlier exposure setting 
and a patched __getitem__ method 47 | that also returns the outlier exposure target as well as the index of a data sample. 48 | """ 49 | 50 | def __init__(self, *args, **kwargs): 51 | super(MyMNIST, self).__init__(*args, **kwargs) 52 | 53 | self.semi_targets = torch.zeros_like(self.targets) 54 | 55 | def __getitem__(self, index): 56 | """Override the original method of the MNIST class. 57 | Args: 58 | index (int): Index 59 | 60 | Returns: 61 | tuple: (image, target, semi_target, index) 62 | """ 63 | img, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index]) 64 | 65 | # doing this so that it is consistent with all other datasets 66 | # to return a PIL Image 67 | img = Image.fromarray(img.numpy(), mode='L') 68 | 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | 72 | if self.target_transform is not None: 73 | target = self.target_transform(target) 74 | 75 | return img, target, semi_target, index 76 | -------------------------------------------------------------------------------- /src/datasets/main.py: -------------------------------------------------------------------------------- 1 | from .mnist import MNIST_Dataset 2 | from .emnist import EMNIST_Dataset 3 | from .cifar10 import CIFAR10_Dataset 4 | from .cifar100 import CIFAR100_Dataset 5 | from .tinyimages import TinyImages_Dataset 6 | from .imagenet1k import ImageNet1K_Dataset 7 | from .imagenet22k import ImageNet22K_Dataset 8 | 9 | 10 | def load_dataset(dataset_name, data_path, normal_class, data_augmentation: bool = False, normalize: bool = False, 11 | seed=None, outlier_exposure: bool = False, oe_size: int = 79302016, oe_n_classes: int = -1, 12 | blur_oe: bool = False, blur_std: float = 1.0): 13 | """Loads the dataset.""" 14 | 15 | implemented_datasets = ('mnist', 'emnist', 'cifar10', 'cifar100', 'tinyimages', 'imagenet1k', 'imagenet22k') 16 | assert dataset_name in implemented_datasets 17 | 18 | # Set default number of OE classes if oe_n_classes == -1 19 | if oe_n_classes == -1: 20 | if dataset_name == 'emnist': 21 | oe_n_classes = 26 22 | if dataset_name == 'cifar100': 23 | oe_n_classes = 100 24 | 25 | dataset = None 26 | 27 | if dataset_name == 'mnist': 28 | dataset = MNIST_Dataset(root=data_path, normal_class=normal_class) 29 | 30 | if dataset_name == 'emnist': 31 | dataset = EMNIST_Dataset(root=data_path, 32 | normal_class=normal_class, 33 | outlier_exposure=outlier_exposure, 34 | oe_n_classes=oe_n_classes, 35 | blur_oe=blur_oe, 36 | blur_std=blur_std, 37 | seed=seed) 38 | 39 | if dataset_name == 'cifar10': 40 | dataset = CIFAR10_Dataset(root=data_path, 41 | normal_class=normal_class, 42 | data_augmentation=data_augmentation, 43 | normalize=normalize) 44 | 45 | if dataset_name == 'cifar100': 46 | dataset = CIFAR100_Dataset(root=data_path, 47 | normal_class=normal_class, 48 | data_augmentation=data_augmentation, 49 | normalize=normalize, 50 | outlier_exposure=outlier_exposure, 51 | oe_n_classes=oe_n_classes, 52 | seed=seed) 53 | 54 | if dataset_name == 'tinyimages': 55 | dataset = TinyImages_Dataset(root=data_path, 56 | data_augmentation=data_augmentation, 57 | normalize=normalize, 58 | size=oe_size, 59 | blur_oe=blur_oe, 60 | blur_std=blur_std, 61 | seed=seed) 62 | 63 | if dataset_name == 'imagenet1k': 64 | dataset = ImageNet1K_Dataset(root=data_path, 65 | normal_class=normal_class, 66 | data_augmentation=data_augmentation, 67 | normalize=normalize) 68 | 69 | if dataset_name == 'imagenet22k': 70 | dataset = ImageNet22K_Dataset(root=data_path, 71 |
data_augmentation=data_augmentation, 72 | normalize=normalize, 73 | size=oe_size, 74 | blur_oe=blur_oe, 75 | blur_std=blur_std, 76 | seed=seed) 77 | 78 | return dataset 79 | -------------------------------------------------------------------------------- /src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from PIL import Image 3 | from torchvision.datasets import CIFAR10 4 | from base.torchvision_dataset import TorchvisionDataset 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class CIFAR10_Dataset(TorchvisionDataset): 12 | 13 | def __init__(self, root: str, normal_class: int = 5, data_augmentation: bool = False, normalize: bool = False): 14 | super().__init__(root) 15 | 16 | self.image_size = (3, 32, 32) 17 | 18 | # Define normal and outlier classes 19 | self.n_classes = 2 # 0: normal, 1: outlier 20 | self.normal_classes = tuple([normal_class]) 21 | self.outlier_classes = list(range(0, 10)) 22 | self.outlier_classes.remove(normal_class) 23 | self.outlier_classes = tuple(self.outlier_classes) 24 | 25 | # CIFAR-10 preprocessing: feature scaling to [0, 1], data normalization, and data augmentation 26 | train_transform = [] 27 | test_transform = [] 28 | if data_augmentation: 29 | # only augment training data 30 | train_transform += [transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), 31 | transforms.RandomHorizontalFlip(p=0.5), 32 | transforms.RandomCrop(32, padding=4)] 33 | train_transform += [transforms.ToTensor()] 34 | test_transform += [transforms.ToTensor()] 35 | if data_augmentation: 36 | train_transform += [transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x))] 37 | if normalize: 38 | train_transform += [transforms.Normalize((0.491373, 0.482353, 0.446667), (0.247059, 0.243529, 0.261569))] 39 | test_transform += [transforms.Normalize((0.491373, 0.482353, 0.446667), (0.247059, 0.243529, 0.261569))] 40 | train_transform = transforms.Compose(train_transform) 41 | test_transform = transforms.Compose(test_transform) 42 | 43 | target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes)) 44 | 45 | # Get train set 46 | train_set = MyCIFAR10(root=self.root, train=True, transform=train_transform, target_transform=target_transform, 47 | download=True) 48 | 49 | # Subset train_set to normal_classes 50 | idx = np.argwhere(np.isin(np.array(train_set.targets), self.normal_classes)) 51 | idx = idx.flatten().tolist() 52 | train_set.semi_targets[idx] = torch.zeros(len(idx)).long() 53 | self.train_set = Subset(train_set, idx) 54 | 55 | # Get test set 56 | self.test_set = MyCIFAR10(root=self.root, train=False, transform=test_transform, 57 | target_transform=target_transform, download=True) 58 | 59 | 60 | class MyCIFAR10(CIFAR10): 61 | """ 62 | Torchvision CIFAR10 class with additional targets for the outlier exposure setting and a patched __getitem__ method 63 | that also returns the outlier exposure target as well as the index of a data sample. 64 | """ 65 | 66 | def __init__(self, *args, **kwargs): 67 | super(MyCIFAR10, self).__init__(*args, **kwargs) 68 | 69 | self.semi_targets = torch.zeros(len(self.targets), dtype=torch.int64) 70 | 71 | def __getitem__(self, index): 72 | """Override the original method of the CIFAR10 class.
73 | Args: 74 | index (int): Index 75 | 76 | Returns: 77 | tuple: (image, target, semi_target, index) 78 | """ 79 | img, target, semi_target = self.data[index], self.targets[index], int(self.semi_targets[index]) 80 | 81 | # doing this so that it is consistent with all other datasets 82 | # to return a PIL Image 83 | img = Image.fromarray(img) 84 | 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | 91 | return img, target, semi_target, index 92 | -------------------------------------------------------------------------------- /src/classifier.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | from base.base_dataset import BaseADDataset 5 | from networks.main import build_network 6 | from optim import ClassifierTrainer 7 | 8 | 9 | class Classifier(object): 10 | """A class for an anomaly detection classifier model. 11 | 12 | Attributes: 13 | objective: Hypersphere ('hsc'), binary cross-entropy ('bce'), or focal loss ('focal') classifier. 14 | hsc_norm: Set specific norm to use with HSC ('l1', 'l2', 'l2_squared', 'l2_squared_linear'). 15 | focal_gamma: Gamma parameter of the focal loss. 16 | net_name: A string indicating the name of the neural network to use. 17 | net: The neural network. 18 | trainer: ClassifierTrainer to train the classifier model. 19 | optimizer_name: A string indicating the optimizer to use for training. 20 | results: A dictionary to save the results. 21 | """ 22 | 23 | def __init__(self, objective: str, hsc_norm: str = 'l2_squared_linear', focal_gamma: float = 2.0): 24 | """Inits Classifier.""" 25 | 26 | self.objective = objective 27 | 28 | self.hsc_norm = hsc_norm 29 | self.focal_gamma = focal_gamma 30 | 31 | self.net_name = None 32 | self.net = None 33 | 34 | self.trainer = None 35 | self.optimizer_name = None 36 | 37 | self.results = { 38 | 'train_time': None, 39 | 'train_scores': None, 40 | 'test_time': None, 41 | 'test_scores': None, 42 | 'test_auc': None 43 | } 44 | 45 | def set_network(self, net_name, rep_dim=64, bias_terms=False): 46 | """Builds the neural network.""" 47 | self.net_name = net_name 48 | self.net = build_network(net_name, rep_dim=rep_dim, bias_terms=bias_terms) 49 | 50 | def train(self, dataset: BaseADDataset, oe_dataset: BaseADDataset = None, optimizer_name: str = 'adam', 51 | lr: float = 0.001, n_epochs: int = 50, lr_milestones: tuple = (), batch_size: int = 128, 52 | weight_decay: float = 1e-6, device: str = 'cuda', n_jobs_dataloader: int = 0): 53 | """Trains the classifier on the training data.""" 54 | 55 | self.optimizer_name = optimizer_name 56 | self.trainer = ClassifierTrainer(self.objective, self.hsc_norm, self.focal_gamma, optimizer_name=optimizer_name, 57 | lr=lr, n_epochs=n_epochs, lr_milestones=lr_milestones, batch_size=batch_size, 58 | weight_decay=weight_decay, device=device, n_jobs_dataloader=n_jobs_dataloader) 59 | 60 | # Get results 61 | self.net = self.trainer.train(dataset=dataset, oe_dataset=oe_dataset, net=self.net) 62 | self.results['train_time'] = self.trainer.train_time 63 | self.results['train_scores'] = self.trainer.train_scores 64 | 65 | def test(self, dataset: BaseADDataset, device: str = 'cuda', n_jobs_dataloader: int = 0): 66 | """Tests the Classifier on the test data.""" 67 | 68 | if self.trainer is None: 69 | self.trainer = ClassifierTrainer(self.objective, self.hsc_norm, self.focal_gamma, device=device, 70 | 
n_jobs_dataloader=n_jobs_dataloader) 71 | 72 | self.trainer.test(dataset, self.net) 73 | 74 | # Get results 75 | self.results['test_time'] = self.trainer.test_time 76 | self.results['test_scores'] = self.trainer.test_scores 77 | self.results['test_auc'] = self.trainer.test_auc 78 | 79 | def save_model(self, export_model): 80 | """Save the classifier model to export_model.""" 81 | net_dict = self.net.state_dict() 82 | torch.save({'net_dict': net_dict}, export_model) 83 | 84 | def load_model(self, model_path, map_location='cpu'): 85 | """Load the classifier model from model_path.""" 86 | model_dict = torch.load(model_path, map_location=map_location) 87 | self.net.load_state_dict(model_dict['net_dict']) 88 | 89 | def save_results(self, export_json): 90 | """Save results dict to a JSON-file.""" 91 | with open(export_json, 'w') as fp: 92 | json.dump(self.results, fp) 93 | -------------------------------------------------------------------------------- /src/networks/cbam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Credits to: https://github.com/hendrycks/ss-ood 7 | class BasicConv(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 9 | bn=True, bias=False): 10 | super(BasicConv, self).__init__() 11 | self.out_channels = out_planes 12 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 13 | dilation=dilation, groups=groups, bias=bias) 14 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 15 | self.relu = nn.ReLU() if relu else None 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | if self.bn is not None: 20 | x = self.bn(x) 21 | if self.relu is not None: 22 | x = self.relu(x) 23 | return x 24 | 25 | 26 | class Flatten(nn.Module): 27 | def forward(self, x): 28 | return x.view(x.size(0), -1) 29 | 30 | 31 | class ChannelGate(nn.Module): 32 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 33 | super(ChannelGate, self).__init__() 34 | self.gate_channels = gate_channels 35 | self.mlp = nn.Sequential( 36 | Flatten(), 37 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 38 | nn.ReLU(), 39 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 40 | ) 41 | self.pool_types = pool_types 42 | 43 | def forward(self, x): 44 | channel_att_sum = None 45 | for pool_type in self.pool_types: 46 | if pool_type == 'avg': 47 | avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 48 | channel_att_raw = self.mlp(avg_pool) 49 | elif pool_type == 'max': 50 | max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 51 | channel_att_raw = self.mlp(max_pool) 52 | elif pool_type == 'lp': 53 | lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 54 | channel_att_raw = self.mlp(lp_pool) 55 | elif pool_type == 'lse': 56 | # LSE pool only 57 | lse_pool = logsumexp_2d(x) 58 | channel_att_raw = self.mlp(lse_pool) 59 | 60 | if channel_att_sum is None: 61 | channel_att_sum = channel_att_raw 62 | else: 63 | channel_att_sum = channel_att_sum + channel_att_raw 64 | 65 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) 66 | return x * scale 67 | 68 | 69 | def logsumexp_2d(tensor): 70 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 71 | s, _ = 
torch.max(tensor_flatten, dim=2, keepdim=True) 72 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 73 | return outputs 74 | 75 | 76 | class ChannelPool(nn.Module): 77 | def forward(self, x): 78 | return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) 79 | 80 | 81 | class SpatialGate(nn.Module): 82 | def __init__(self): 83 | super(SpatialGate, self).__init__() 84 | kernel_size = 7 85 | self.compress = ChannelPool() 86 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False) 87 | 88 | def forward(self, x): 89 | x_compress = self.compress(x) 90 | x_out = self.spatial(x_compress) 91 | scale = torch.sigmoid(x_out) # broadcasting 92 | return x * scale 93 | 94 | 95 | class CBAM(nn.Module): 96 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 97 | super(CBAM, self).__init__() 98 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 99 | self.no_spatial = no_spatial 100 | if not no_spatial: 101 | self.SpatialGate = SpatialGate() 102 | 103 | def forward(self, x): 104 | x_out = self.ChannelGate(x) 105 | if not self.no_spatial: 106 | x_out = self.SpatialGate(x_out) 107 | return x_out 108 | -------------------------------------------------------------------------------- /src/datasets/emnist.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from PIL import Image 3 | from torchvision.datasets import EMNIST 4 | from base.torchvision_dataset import TorchvisionDataset 5 | from PIL.ImageFilter import GaussianBlur 6 | 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms as transforms 10 | import random 11 | 12 | 13 | class EMNIST_Dataset(TorchvisionDataset): 14 | 15 | def __init__(self, root: str, split: str = 'letters', normal_class: int = 1, outlier_exposure: bool = False, 16 | oe_n_classes: int = 26, blur_oe: bool = False, blur_std: float = 1.0, seed: int = 0): 17 | super().__init__(root) 18 | 19 | self.image_size = (1, 28, 28) 20 | 21 | self.n_classes = 2 # 0: normal, 1: outlier 22 | self.shuffle = True 23 | self.split = split 24 | random.seed(seed) # set seed 25 | 26 | if outlier_exposure: 27 | self.normal_classes = None 28 | self.outlier_classes = list(range(1, 27)) 29 | self.known_outlier_classes = tuple(random.sample(self.outlier_classes, oe_n_classes)) 30 | else: 31 | # Define normal and outlier classes 32 | self.normal_classes = tuple([normal_class]) 33 | self.outlier_classes = list(range(1, 27)) 34 | self.outlier_classes.remove(normal_class) 35 | self.outlier_classes = tuple(self.outlier_classes) 36 | 37 | # EMNIST preprocessing: feature scaling to [0, 1] 38 | transform = [] 39 | if blur_oe: 40 | transform += [transforms.Lambda(lambda x: x.filter(GaussianBlur(radius=blur_std)))] 41 | transform += [transforms.ToTensor()] 42 | transform = transforms.Compose(transform) 43 | target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes)) 44 | 45 | # Get train set 46 | train_set = MyEMNIST(root=self.root, split=self.split, train=True, transform=transform, 47 | target_transform=target_transform, download=True) 48 | 49 | if outlier_exposure: 50 | idx = np.argwhere(np.isin(train_set.targets.cpu().data.numpy(), self.known_outlier_classes)) 51 | idx = idx.flatten().tolist() 52 | train_set.semi_targets[idx] = -1 * torch.ones(len(idx)).long() # set outlier exposure labels 53 | 54 | # Subset train_set to selected classes 55 
| self.train_set = Subset(train_set, idx) 56 | self.train_set.shuffle_idxs = False 57 | self.test_set = None 58 | else: 59 | # Subset train_set to normal_classes 60 | idx = np.argwhere(np.isin(train_set.targets.cpu().data.numpy(), self.normal_classes)) 61 | idx = idx.flatten().tolist() 62 | train_set.semi_targets[idx] = torch.zeros(len(idx)).long() 63 | self.train_set = Subset(train_set, idx) 64 | 65 | # Get test set 66 | self.test_set = MyEMNIST(root=self.root, split=self.split, train=False, transform=transform, 67 | target_transform=target_transform, download=True) 68 | 69 | 70 | class MyEMNIST(EMNIST): 71 | """ 72 | Torchvision EMNIST class with additional targets for the outlier exposure setting and patch of __getitem__ method 73 | to also return the outlier exposure target as well as the index of a data sample. 74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | super(MyEMNIST, self).__init__(*args, **kwargs) 78 | 79 | self.semi_targets = torch.zeros_like(self.targets) 80 | self.shuffle_idxs = False 81 | 82 | def __getitem__(self, index): 83 | """Override the original method of the EMNIST class. 84 | Args: 85 | index (int): Index 86 | 87 | Returns: 88 | tuple: (image, target, semi_target, index) 89 | """ 90 | img, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index]) 91 | 92 | # doing this so that it is consistent with all other datasets 93 | # to return a PIL Image 94 | img = Image.fromarray(img.numpy(), mode='L') 95 | 96 | if self.transform is not None: 97 | img = self.transform(img) 98 | 99 | if self.target_transform is not None: 100 | target = self.target_transform(target) 101 | 102 | return img, target, semi_target, index 103 | -------------------------------------------------------------------------------- /src/datasets/imagenet1k.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from torchvision.datasets import ImageFolder 3 | from base.torchvision_dataset import TorchvisionDataset 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class ImageNet1K_Dataset(TorchvisionDataset): 11 | 12 | def __init__(self, root: str, normal_class: int = 0, data_augmentation: bool = False, normalize: bool = False): 13 | super().__init__(root) 14 | 15 | classes = ['acorn', 'airliner', 'ambulance', 'american_alligator', 'banjo', 'barn', 'bikini', 'digital_clock', 16 | 'dragonfly', 'dumbbell', 'forklift', 'goblet', 'grand_piano', 'hotdog', 'hourglass', 'manhole_cover', 17 | 'mosque', 'nail', 'parking_meter', 'pillow', 'revolver', 'rotary_dial_telephone', 'schooner', 18 | 'snowmobile', 'soccer_ball', 'stingray', 'strawberry', 'tank', 'toaster', 'volcano'] 19 | 20 | self.image_size = (3, 224, 224) 21 | 22 | # Define normal and outlier classes 23 | self.n_classes = 2 # 0: normal, 1: outlier 24 | self.normal_classes = tuple([normal_class]) 25 | self.outlier_classes = list(range(0, 30)) 26 | self.outlier_classes.remove(normal_class) 27 | self.outlier_classes = tuple(self.outlier_classes) 28 | 29 | # ImageNet preprocessing: feature scaling to [0, 1], data normalization, and data augmentation 30 | train_transform = [transforms.Resize(256)] 31 | test_transform = [transforms.Resize(256), 32 | transforms.CenterCrop(224)] 33 | if data_augmentation: 34 | # only augment training data 35 | train_transform += [transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), 36 | 
transforms.RandomHorizontalFlip(p=0.5), 37 | transforms.RandomCrop(224)] 38 | else: 39 | train_transform += [transforms.CenterCrop(224)] 40 | train_transform += [transforms.ToTensor()] 41 | test_transform += [transforms.ToTensor()] 42 | if data_augmentation: 43 | train_transform += [transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x))] 44 | if normalize: 45 | train_transform += [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 46 | test_transform += [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 47 | train_transform = transforms.Compose(train_transform) 48 | test_transform = transforms.Compose(test_transform) 49 | 50 | target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes)) 51 | 52 | # Get train set 53 | train_set = MyImageNet1K(root=self.root + '/imagenet1k/one_class_train', transform=train_transform, 54 | target_transform=target_transform) 55 | 56 | # Subset train_set to normal_classes 57 | idx = np.argwhere(np.isin(np.array(train_set.targets), self.normal_classes)) 58 | idx = idx.flatten().tolist() 59 | train_set.semi_targets[idx] = torch.zeros(len(idx)).long() 60 | self.train_set = Subset(train_set, idx) 61 | 62 | # Get test set 63 | self.test_set = MyImageNet1K(root=self.root + '/imagenet1k/one_class_test', transform=test_transform, 64 | target_transform=target_transform) 65 | 66 | 67 | class MyImageNet1K(ImageFolder): 68 | """ 69 | Torchvision ImageFolder class with additional targets for the outlier exposure setting and patch of __getitem__ 70 | method to also return the outlier exposure target as well as the index of a data sample. 71 | """ 72 | 73 | def __init__(self, *args, **kwargs): 74 | super(MyImageNet1K, self).__init__(*args, **kwargs) 75 | 76 | self.semi_targets = torch.zeros(len(self.targets), dtype=torch.int64) 77 | 78 | def __getitem__(self, index): 79 | """Override the original method of the ImageFolder class. 
80 | Args: 81 | index (int): Index 82 | 83 | Returns: 84 | tuple: (sample, target, semi_target, index) 85 | """ 86 | path, target = self.samples[index] 87 | sample = self.loader(path) 88 | semi_target = int(self.semi_targets[index]) 89 | 90 | if self.transform is not None: 91 | sample = self.transform(sample) 92 | 93 | if self.target_transform is not None: 94 | target = self.target_transform(target) 95 | 96 | return sample, target, semi_target, index 97 | -------------------------------------------------------------------------------- /src/networks/imagenet_WideResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from base.base_net import BaseNet 4 | from networks.cbam import CBAM 5 | from torch.nn import init 6 | 7 | 8 | # Credits to: https://github.com/hendrycks/ss-ood 9 | class ImageNet_WideResNet(BaseNet): 10 | 11 | def __init__(self, rep_dim=256): 12 | self.inplanes = 64 13 | super().__init__() 14 | 15 | self.rep_dim = rep_dim 16 | att_type = 'CBAM' 17 | layers = [2, 2, 2, 2] 18 | 19 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 20 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 21 | self.avgpool = nn.AvgPool2d(7) 22 | self.bn1 = nn.BatchNorm2d(64) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | self.bam1, self.bam2, self.bam3 = None, None, None 26 | 27 | self.layer1 = self._make_layer(BasicBlock, 64, layers[0], att_type=att_type) 28 | self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2, att_type=att_type) 29 | self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2, att_type=att_type) 30 | self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2, att_type=att_type) 31 | 32 | self.fc = nn.Linear(512 * BasicBlock.expansion, self.rep_dim) 33 | 34 | init.kaiming_normal_(self.fc.weight) 35 | for key in self.state_dict(): 36 | if key.split('.')[-1] == "weight": 37 | if "conv" in key: 38 | init.kaiming_normal_(self.state_dict()[key], mode='fan_out') 39 | if "bn" in key: 40 | if "SpatialGate" in key: 41 | self.state_dict()[key][...] = 0 42 | else: 43 | self.state_dict()[key][...] = 1 44 | elif key.split(".")[-1] == 'bias': 45 | self.state_dict()[key][...] 
= 0 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1, att_type=None): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type == 'CBAM')) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes, use_cbam=att_type == 'CBAM')) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = x.view(-1, 3, 224, 224) 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | x = self.maxpool(x) 70 | 71 | x = self.layer1(x) 72 | if not self.bam1 is None: 73 | x = self.bam1(x) 74 | 75 | x = self.layer2(x) 76 | if not self.bam2 is None: 77 | x = self.bam2(x) 78 | 79 | x = self.layer3(x) 80 | if not self.bam3 is None: 81 | x = self.bam3(x) 82 | 83 | x = self.layer4(x) 84 | x = self.avgpool(x) 85 | x = x.view(x.size(0), -1) 86 | return self.fc(x) 87 | 88 | 89 | class BasicBlock(nn.Module): 90 | expansion = 1 91 | 92 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): 93 | super(BasicBlock, self).__init__() 94 | self.conv1 = conv3x3(inplanes, planes, stride) 95 | self.bn1 = nn.BatchNorm2d(planes) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.conv2 = conv3x3(planes, planes) 98 | self.bn2 = nn.BatchNorm2d(planes) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | if use_cbam: 103 | self.cbam = CBAM(planes, 16) 104 | else: 105 | self.cbam = None 106 | 107 | def forward(self, x): 108 | residual = x 109 | 110 | out = self.conv1(x) 111 | out = self.bn1(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv2(out) 115 | out = self.bn2(out) 116 | 117 | if self.downsample is not None: 118 | residual = self.downsample(x) 119 | 120 | if not self.cbam is None: 121 | out = self.cbam(out) 122 | 123 | out += residual 124 | out = self.relu(out) 125 | 126 | return out 127 | 128 | 129 | def conv3x3(in_planes, out_planes, stride=1): 130 | "3x3 convolution with padding" 131 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 132 | padding=1, bias=False) 133 | -------------------------------------------------------------------------------- /src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Subset 2 | from PIL import Image 3 | from torchvision.datasets import CIFAR100 4 | from base.torchvision_dataset import TorchvisionDataset 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | import random 10 | 11 | 12 | class CIFAR100_Dataset(TorchvisionDataset): 13 | 14 | def __init__(self, root: str, normal_class: int = 0, data_augmentation: bool = False, normalize: bool = False, 15 | outlier_exposure: bool = False, oe_n_classes: int = 100, seed: int = 0): 16 | super().__init__(root) 17 | 18 | self.image_size = (3, 32, 32) 19 | 20 | self.n_classes = 2 # 0: normal, 1: outlier 21 | self.shuffle = True 22 | random.seed(seed) # set seed 23 | 24 | if outlier_exposure: 25 | self.normal_classes = None 26 | self.outlier_classes = list(range(0, 100)) 27 | self.known_outlier_classes = tuple(random.sample(self.outlier_classes, oe_n_classes)) 28 | else: 29 | # Define normal and outlier classes 30 | 
self.normal_classes = tuple([normal_class]) 31 | self.outlier_classes = list(range(0, 100)) 32 | self.outlier_classes.remove(normal_class) 33 | self.outlier_classes = tuple(self.outlier_classes) 34 | 35 | # CIFAR-100 preprocessing: feature scaling to [0, 1], data normalization, and data augmentation 36 | train_transform = [] 37 | test_transform = [] 38 | if data_augmentation: 39 | # only augment training data 40 | train_transform += [transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), 41 | transforms.RandomHorizontalFlip(p=0.5), 42 | transforms.RandomCrop(32, padding=4)] 43 | train_transform += [transforms.ToTensor()] 44 | test_transform += [transforms.ToTensor()] 45 | if data_augmentation: 46 | train_transform += [transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x))] 47 | if normalize: 48 | train_transform += [transforms.Normalize((0.491373, 0.482353, 0.446667), (0.247059, 0.243529, 0.261569))] 49 | test_transform += [transforms.Normalize((0.491373, 0.482353, 0.446667), (0.247059, 0.243529, 0.261569))] 50 | train_transform = transforms.Compose(train_transform) 51 | test_transform = transforms.Compose(test_transform) 52 | 53 | target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes)) 54 | 55 | # Get train set 56 | train_set = MyCIFAR100(root=self.root, train=True, transform=train_transform, target_transform=target_transform, 57 | download=True) 58 | 59 | if outlier_exposure: 60 | idx = np.argwhere(np.isin(np.array(train_set.targets), self.known_outlier_classes)) 61 | idx = idx.flatten().tolist() 62 | train_set.semi_targets[idx] = -1 * torch.ones(len(idx)).long() # set outlier exposure labels 63 | 64 | # Subset train_set to selected classes 65 | self.train_set = Subset(train_set, idx) 66 | self.train_set.shuffle_idxs = False 67 | self.test_set = None 68 | else: 69 | # Subset train_set to normal_classes 70 | idx = np.argwhere(np.isin(np.array(train_set.targets), self.normal_classes)) 71 | idx = idx.flatten().tolist() 72 | train_set.semi_targets[idx] = torch.zeros(len(idx)).long() 73 | self.train_set = Subset(train_set, idx) 74 | 75 | # Get test set 76 | self.test_set = MyCIFAR100(root=self.root, train=False, transform=test_transform, 77 | target_transform=target_transform, download=True) 78 | 79 | 80 | class MyCIFAR100(CIFAR100): 81 | """ 82 | Torchvision CIFAR100 class with additional targets for the outlier exposure setting and patch of __getitem__ method 83 | to also return the outlier exposure target as well as the index of a data sample. 84 | """ 85 | 86 | def __init__(self, *args, **kwargs): 87 | super(MyCIFAR100, self).__init__(*args, **kwargs) 88 | 89 | self.semi_targets = torch.zeros(len(self.targets), dtype=torch.int64) 90 | self.shuffle_idxs = False 91 | 92 | def __getitem__(self, index): 93 | """Override the original method of the CIFAR100 class. 
94 | Args: 95 | index (int): Index 96 | 97 | Returns: 98 | tuple: (image, target, semi_target, index) 99 | """ 100 | img, target, semi_target = self.data[index], self.targets[index], int(self.semi_targets[index]) 101 | 102 | # doing this so that it is consistent with all other datasets 103 | # to return a PIL Image 104 | img = Image.fromarray(img) 105 | 106 | if self.transform is not None: 107 | img = self.transform(img) 108 | 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | 112 | return img, target, semi_target, index 113 | -------------------------------------------------------------------------------- /src/utils/visualization/plot_fun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib 5 | 6 | matplotlib.use('Agg') # or 'PS', 'PDF', 'SVG' 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | from torchvision.utils import make_grid 11 | 12 | 13 | def plot_images_grid(x: torch.tensor, xp_path, filename, title='', nrow=8, padding=2, normalize=False, pad_value=0): 14 | """Plot 4D Tensor of images of shape (B x C x H x W) as a grid.""" 15 | 16 | grid = make_grid(x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value) 17 | npgrid = grid.cpu().numpy() 18 | 19 | plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest') 20 | 21 | ax = plt.gca() 22 | ax.xaxis.set_visible(False) 23 | ax.yaxis.set_visible(False) 24 | 25 | if not (title == ''): 26 | plt.title(title) 27 | 28 | plt.savefig(xp_path + '/' + filename, bbox_inches='tight', pad_inches=0.1) 29 | plt.clf() 30 | 31 | 32 | def plot_dist(x, xp_path, filename, target=None, title='', axlabel=None, legendlabel=None): 33 | """ 34 | Plot the univariate distribution (histogram and kde) of values in x. 35 | 36 | :param x: Data as a series, 1d-array, or list. 37 | :param xp_path: Export path for the plot as string. 38 | :param filename: Filename as string. 39 | :param target: Target as series, 1d-array, or list that categorize subplots. Optional. 40 | :param title: Title for the plot as string. Optional. 41 | :param axlabel: Name for the support axis label as string. Optional. 42 | :param legendlabel: Legend label for the relevant component of the plot as string. Optional. 
43 | """ 44 | 45 | # Set plot parameters 46 | sns.set() 47 | sns.set_style('white') 48 | sns.set_palette('colorblind') 49 | 50 | # Convert data to pandas DataFrame and set legend labels 51 | if target is not None: 52 | data = {'x': list(x), 'target': list(target)} 53 | columns = ['x', 'target'] 54 | unique_targets = list(set(target)) 55 | if legendlabel is not None: 56 | assert len(legendlabel) == len(unique_targets) 57 | label = legendlabel 58 | else: 59 | label = None 60 | else: 61 | data = {'x': list(x)} 62 | columns = ['x'] 63 | if legendlabel is not None: 64 | assert len(legendlabel) == 1 65 | label = legendlabel 66 | else: 67 | label = None 68 | 69 | dataset = pd.DataFrame(data=data, columns=columns) 70 | 71 | if target is not None: 72 | # sort dataframe by target 73 | df_list = [dataset.loc[dataset['target'] == val] for val in unique_targets] 74 | for i, df in enumerate(df_list): 75 | sns.distplot(df[['x']], norm_hist=True, axlabel=axlabel, label=label[i]) 76 | else: 77 | sns.distplot(dataset, norm_hist=True, axlabel=axlabel, label=label) 78 | 79 | if not (title == ''): 80 | plt.title(title) 81 | plt.legend() 82 | 83 | plt.savefig(xp_path + '/' + filename, bbox_inches='tight') 84 | plt.clf() 85 | 86 | 87 | def plot_line(x, xp_path, filename, title='', xlabel='Epochs', ylabel='Values', legendlabel=None, log_scale=False): 88 | """ 89 | Draw a line plot with grouping options. 90 | 91 | :param x: Data as a series, 1d-array, or list. 92 | :param xp_path: Export path for the plot as string. 93 | :param filename: Filename as string. 94 | :param title: Title for the plot as string. Optional. 95 | :param xlabel: Label for x-axis as string. Optional. 96 | :param ylabel: Label for y-axis as string. Optional. 97 | :param legendlabel: String or list of strings with data series legend labels. Optional. 98 | :param log_scale: Boolean to set y-axis to log-scale. 
99 | """ 100 | 101 | # Set plot parameters 102 | sns.set() 103 | sns.set_style('whitegrid') 104 | 105 | # Convert data to pandas DataFrame and set legend labels 106 | data = { 107 | 'x': [], 108 | 'y': [], 109 | 'label': [] 110 | } 111 | 112 | if isinstance(x, list): 113 | n_series = len(x) 114 | 115 | if legendlabel is None: 116 | legendlabel = ['series ' + str(i + 1) for i in range(n_series)] 117 | else: 118 | assert len(legendlabel) == n_series 119 | 120 | for i, series in enumerate(x): 121 | data['x'].extend(list(range(1, len(x[i]) + 1))) 122 | data['y'].extend(list(x[i])) 123 | data['label'].extend([legendlabel[i]] * len(x[i])) 124 | else: 125 | if legendlabel is None: 126 | legendlabel = ['series 1'] 127 | else: 128 | assert len(legendlabel) == 1 129 | 130 | data['x'].extend(list(range(1, len(x) + 1))) 131 | data['y'].extend(list(x)) 132 | data['label'].extend(legendlabel * len(x)) 133 | 134 | df = pd.DataFrame(data, columns=['x', 'y', 'label']) 135 | 136 | sns.lineplot(x='x', y='y', hue='label', data=df, palette='colorblind') 137 | 138 | if log_scale: 139 | plt.yscale('symlog') 140 | plt.grid(True, axis='both') 141 | else: 142 | plt.grid(False, axis='x') 143 | plt.grid(True, axis='y') 144 | 145 | # Add title, axis labels, and legend 146 | if not (title == ''): 147 | plt.title(title) 148 | plt.legend(legendlabel, title=False) 149 | plt.xlabel(xlabel) 150 | plt.ylabel(ylabel) 151 | 152 | plt.savefig(xp_path + '/' + filename, bbox_inches='tight') 153 | plt.clf() 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rethinking Assumptions in Deep Anomaly Detection 2 | 3 | This repository provides the code for the methods and experiments presented in our ICML UDL 2021 workshop paper 'Rethinking Assumptions in Deep Anomaly Detection.' 4 | 5 | ## Citation and Contact 6 | 7 | You find a PDF of the paper on arXiv: [https://arxiv.org/abs/2006.00339](https://arxiv.org/abs/2006.00339). 8 | If you find our work useful, please cite: 9 | ``` 10 | @inproceedings{ruff2020rethinking, 11 | title = {Rethinking Assumptions in Deep Anomaly Detection}, 12 | author = {Ruff, Lukas and Vandermeulen, Robert A and Franks, Billy Joe and M{\"u}ller, Klaus-Robert and Kloft, Marius}, 13 | booktitle = {ICML 2021 Workshop on Uncertainty \& Robustness in Deep Learning}, 14 | year = {2021} 15 | } 16 | ``` 17 | 18 | ## Abstract 19 | 20 | > > Though anomaly detection (AD) can be viewed as a classification problem (nominal vs. anomalous) it is usually treated in an unsupervised manner since one typically does not have access to, or it is infeasible to utilize, a dataset that sufficiently characterizes what it means to be "anomalous." In this paper we present results demonstrating that this intuition surprisingly seems not to extend to deep AD on images. For a recent AD benchmark on ImageNet, classifiers trained to discern between normal samples and just a few (64) random natural images are able to outperform the current state of the art in deep AD. Experimentally we discover that the multiscale structure of image data makes example anomalies exceptionally informative. 21 | 22 | 23 | ## Installation 24 | This code is written in `Python 3.7` and requires the packages listed in `requirements.txt`. For running the code, we recommend to set up a virtual environment, e.g. 
via `virtualenv` or `conda`, and installing the packages there in the specified versions: 25 | 26 | ### `virtualenv` 27 | 28 | ``` 29 | # pip install virtualenv 30 | cd <path-to-repo> 31 | virtualenv myenv 32 | source myenv/bin/activate 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ### `conda` 37 | 38 | ``` 39 | cd <path-to-repo> 40 | conda create --name myenv 41 | conda activate myenv 42 | while read requirement; do conda install -n myenv --yes $requirement; done < requirements.txt 43 | ``` 44 | 45 | ## Data 46 | 47 | We present experiments using the [MNIST](http://yann.lecun.com/exdb/mnist/), [EMNIST](https://www.nist.gov/itl/products-and-services/emnist-dataset), [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html), [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html), [80 Million Tiny Images](https://groups.csail.mit.edu/vision/TinyImages/), [ImageNet-1K](http://www.image-net.org/), and [ImageNet-22K](http://www.image-net.org/) datasets in our paper. These datasets are automatically downloaded to the `./data` directory the first time experiments are run on them, except for ImageNet-1K and ImageNet-22K, which must be downloaded manually. The ImageNet-1K one-vs-rest anomaly detection benchmark data can be downloaded from [https://github.com/hendrycks/ss-ood](https://github.com/hendrycks/ss-ood), which is the repository of the paper that introduced the benchmark, and should be placed in the `./data/imagenet1k` directory. 48 | The ImageNet-22K dataset can be downloaded from [http://www.image-net.org](http://www.image-net.org/), which requires registration. Note that our implementation assumes that the ImageNet-22K `*.tar` archives have been extracted into the `./data/fall11_whole_extracted` directory. 49 | 50 | 51 | ## Running experiments 52 | 53 | All the experiments presented in our paper can be run via the `main.py` script. The specific method (`hsc`, `deepSAD`, `bce`, or `focal`) can be set via the `--objective` option, e.g. `--objective hsc`. 54 | 55 | The `main.py` script features various options and experimental parameters. Have a look at `main.py` for all available options and arguments. 56 | 57 | Below, we present two examples for the CIFAR-10 and ImageNet-1K one-vs-rest anomaly detection benchmarks. The complete bash scripts to reproduce all experimental results reported in our paper are given in `./src/experiments`.
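If you prefer to script experiments in Python directly, the `main.py` CLI is a thin wrapper around the `Classifier` class (`src/classifier.py`) and `load_dataset` (`src/datasets/main.py`). The following minimal sketch, run from within `./src`, mirrors those calls for the CIFAR-10 example below; all paths and hyperparameters are illustrative only:

```
# Minimal programmatic sketch mirroring the calls in main.py; all settings are illustrative.
from classifier import Classifier
from datasets.main import load_dataset

# Normal training/test data: class 0 (airplane) of CIFAR-10 as the normal class
dataset = load_dataset(dataset_name='cifar10', data_path='../data', normal_class=0,
                       data_augmentation=True, normalize=True, seed=42)

# 80 Million Tiny Images as the outlier exposure (OE) dataset
oe_dataset = load_dataset(dataset_name='tinyimages', data_path='../data', normal_class=0,
                          data_augmentation=True, normalize=True, seed=42,
                          outlier_exposure=True, oe_size=79302016, oe_n_classes=-1,
                          blur_oe=False, blur_std=1.0)

classifier = Classifier(objective='hsc')  # Hypersphere Classifier
classifier.set_network('cifar10_LeNet', rep_dim=256)
classifier.train(dataset, oe_dataset=oe_dataset, lr=0.001, n_epochs=200,
                 lr_milestones=(100, 150), batch_size=128, device='cuda')
classifier.test(dataset, device='cuda')
print('Test AUC: {:.2f}'.format(100. * classifier.results['test_auc']))
```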
58 | 59 | ### CIFAR-10 One-vs-Rest Benchmark using 80 Million Tiny Images as OE 60 | 61 | The following runs a Hypersphere Classifier (`--objective hsc`) experiment on CIFAR-10 with class `0` (airplane) considered the normal class, using 80 Million Tiny Images as OE (`--oe_dataset_name tinyimages`): 62 | 63 | ``` 64 | cd <path-to-repo> 65 | 66 | # activate virtual environment 67 | source myenv/bin/activate # or 'conda activate myenv' for conda 68 | 69 | # create folder for experimental outputs 70 | mkdir log/cifar10_test 71 | 72 | # change to source directory 73 | cd src 74 | 75 | # run experiment 76 | python main.py cifar10 cifar10_LeNet ../log/cifar10_test ../data --rep_dim 256 --objective hsc --outlier_exposure True --oe_dataset_name tinyimages --device cuda --seed 42 --lr 0.001 --n_epochs 200 --lr_milestone 100 --lr_milestone 150 --batch_size 128 --data_augmentation True --data_normalization True --normal_class 0; 77 | ``` 78 | 79 | ### ImageNet-1K One-vs-Rest Benchmark using ImageNet-22K as OE 80 | 81 | The following runs a Binary Cross-Entropy Classifier (`--objective bce`) experiment on ImageNet-1K with class `4` (banjo) considered the normal class, using ImageNet-22K as OE (`--oe_dataset_name imagenet22k`): 82 | 83 | ``` 84 | cd <path-to-repo> 85 | 86 | # activate virtual environment 87 | source myenv/bin/activate # or 'conda activate myenv' for conda 88 | 89 | # create folder for experimental outputs 90 | mkdir log/imagenet_test 91 | 92 | # change to source directory 93 | cd src 94 | 95 | # run classifier experiment 96 | python main.py imagenet1k imagenet_WideResNet ../log/imagenet_test ../data --rep_dim 256 --objective bce --outlier_exposure True --oe_dataset_name imagenet22k --device cuda --seed 42 --lr 0.001 --n_epochs 150 --lr_milestone 100 --lr_milestone 125 --batch_size 128 --data_augmentation True --data_normalization True --normal_class 4; 97 | ``` 98 | 99 | ## License 100 | MIT 101 | -------------------------------------------------------------------------------- /src/datasets/tinyimages.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.vision import VisionDataset 2 | from torchvision.datasets.utils import download_url, check_integrity 3 | from base.torchvision_dataset import TorchvisionDataset 4 | from PIL.ImageFilter import GaussianBlur 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | import random 10 | import os 11 | 12 | 13 | class TinyImages_Dataset(TorchvisionDataset): 14 | 15 | def __init__(self, root: str, data_augmentation: bool = True, normalize: bool = False, size: int = 79302016, 16 | blur_oe: bool = False, blur_std: float = 1.0, seed: int = 0): 17 | super().__init__(root) 18 | 19 | self.image_size = (3, 32, 32) 20 | 21 | self.n_classes = 1 # only class 1: outlier since 80 Million Tiny Images is used for outlier exposure 22 | self.shuffle = False 23 | self.size = size 24 | 25 | # TinyImages preprocessing: feature scaling to [0, 1] and data augmentation if specified 26 | transform = [transforms.ToTensor(), 27 | transforms.ToPILImage()] 28 | if data_augmentation: 29 | transform += [transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), 30 | transforms.RandomHorizontalFlip(p=0.5), 31 | transforms.RandomCrop(32, padding=4)] 32 | else: 33 | transform += [transforms.CenterCrop(32)] 34 | if blur_oe: 35 | transform += [transforms.Lambda(lambda x: x.filter(GaussianBlur(radius=blur_std)))] 36 | transform += [transforms.ToTensor()] 37 | if
data_augmentation: 38 | transform += [transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x))] 39 | if normalize: 40 | # CIFAR-10 mean and std 41 | transform += [transforms.Normalize((0.491373, 0.482353, 0.446667), (0.247059, 0.243529, 0.261569))] 42 | transform = transforms.Compose(transform) 43 | 44 | # Get dataset 45 | self.train_set = TinyImages(root=self.root, size=self.size, transform=transform, download=True, seed=seed) 46 | self.test_set = None 47 | 48 | 49 | class TinyImages(VisionDataset): 50 | """`80 Million Tiny Images <https://groups.csail.mit.edu/vision/TinyImages/>`_ Dataset. 51 | 52 | VisionDataset class with additional targets for the outlier exposure setting and patch of __getitem__ method 53 | to also return the outlier exposure target as well as the index of a data sample. 54 | 55 | Args: 56 | root (string): Root directory of dataset where ``tiny_images.bin`` file exists or will be saved to 57 | if download is set to True. 58 | size (int, optional): Set the dataset sample size. Default = 79302016 (full dataset). 59 | exclude_cifar (bool, optional): If true, exclude the CIFAR samples from the 80 million tiny images dataset. 60 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 61 | version. E.g, ``transforms.RandomCrop`` 62 | download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If 63 | dataset is already downloaded, it is not downloaded again. 64 | seed (int, optional): Seed for drawing dataset sample if size is not full dataset 65 | """ 66 | url = 'http://horatio.cs.nyu.edu/mit/tiny/data/tiny_images.bin' 67 | filename = 'tiny_images.bin' 68 | 69 | def __init__(self, root, size: int = 79302016, exclude_cifar=True, transform=None, download=False, seed: int = 0): 70 | 71 | super(TinyImages, self).__init__(root) 72 | self.size = size 73 | self.exclude_cifar = exclude_cifar 74 | self.transform = transform 75 | 76 | # Draw random permutation of indices of self.size if not full dataset 77 | self.shuffle_idxs = True 78 | if self.size < 79302016: 79 | random.seed(seed) 80 | self.idxs = random.sample(range(79302016), self.size) # set seed to have a fair comparison across models 81 | else: 82 | self.idxs = list(range(79302016)) 83 | 84 | if download: 85 | self.download() 86 | 87 | data_file = open(os.path.join(root, self.filename), 'rb') 88 | 89 | def load_image(idx): 90 | data_file.seek(idx * 3072) 91 | data = data_file.read(3072) 92 | return np.frombuffer(data, dtype='uint8').copy().reshape(32, 32, 3, order='F') # frombuffer + copy replaces the deprecated np.fromstring 93 | 94 | self.load_image = load_image 95 | self.offset = 0 # offset index 96 | 97 | if exclude_cifar: 98 | self.cifar_idxs = [] 99 | with open(os.path.join(root, '80mn_cifar_idxs.txt'), 'r') as idxs: 100 | for idx in idxs: 101 | # indices in file take the 80mn database to start at 1, hence "- 1" 102 | self.cifar_idxs.append(int(idx) - 1) 103 | 104 | # use a set (hash table) for fast membership checks 105 | self.cifar_idxs = set(self.cifar_idxs) 106 | self.in_cifar = lambda x: x in self.cifar_idxs 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | 113 | Returns: 114 | tuple: (image, target, semi_target, index) 115 | """ 116 | index = (index + self.offset) % self.size 117 | index = self.idxs[index] 118 | 119 | if self.exclude_cifar: 120 | while self.in_cifar(index): 121 | index = np.random.randint(79302016) 122 | 123 | img = self.load_image(index) 124 | if self.transform is not None: 125 | img = self.transform(img) 126 | 127 | return img, 1, -1, index 128 | 129 | def __len__(self): 130 | return
79302016 131 | 132 | def _check_integrity(self): 133 | root = self.root 134 | filename = self.filename 135 | fpath = os.path.join(root, filename) 136 | return check_integrity(fpath) 137 | 138 | def download(self): 139 | 140 | if self._check_integrity(): 141 | print('Files already downloaded and verified') 142 | return 143 | 144 | download_url(self.url, self.root, self.filename) 145 | -------------------------------------------------------------------------------- /src/datasets/imagenet22k.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | from base.torchvision_dataset import TorchvisionDataset 3 | from PIL.ImageFilter import GaussianBlur 4 | 5 | import torch 6 | import torchvision.transforms as transforms 7 | import random 8 | 9 | 10 | class ImageNet22K_Dataset(TorchvisionDataset): 11 | 12 | def __init__(self, root: str, data_augmentation: bool = True, normalize: bool = False, size: int = 14155519, 13 | blur_oe: bool = False, blur_std: float = 1.0, seed: int = 0): 14 | super().__init__(root) 15 | 16 | self.image_size = (3, 224, 224) 17 | 18 | self.n_classes = 1 # only class 1: outlier since ImageNet22K is used for outlier exposure 19 | self.shuffle = False 20 | self.size = size 21 | 22 | # ImageNet preprocessing: feature scaling to [0, 1], data normalization, and data augmentation 23 | transform = [transforms.Resize(256)] 24 | if data_augmentation: 25 | transform += [transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), 26 | transforms.RandomHorizontalFlip(p=0.5), 27 | transforms.RandomCrop(224)] 28 | else: 29 | transform += [transforms.CenterCrop(224)] 30 | if blur_oe: 31 | transform += [transforms.Lambda(lambda x: x.filter(GaussianBlur(radius=blur_std)))] 32 | transform += [transforms.ToTensor()] 33 | if data_augmentation: 34 | transform += [transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x))] 35 | if normalize: 36 | transform += [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 37 | transform = transforms.Compose(transform) 38 | 39 | # Get dataset 40 | self.train_set = MyImageNet22K(root=self.root + '/fall11_whole_extracted', size=self.size, transform=transform, 41 | seed=seed) 42 | self.test_set = None 43 | 44 | 45 | class MyImageNet22K(ImageFolder): 46 | """ 47 | Torchvision ImageFolder class with additional targets for the outlier exposure setting and patch of __getitem__ 48 | method to also return the outlier exposure target as well as the index of a data sample. 49 | 50 | Args: 51 | root (string): Root directory ``fall11_whole_extracted`` of the ImageNet22K dataset. 52 | size (int, optional): Set the dataset sample size. Default = 14155519 (full dataset; excluding ImageNet1K). 53 | exclude_imagenet1k (bool, optional): If true, exclude the ImageNet1K samples from the ImageNet22K dataset. 54 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 55 | version. 
E.g, ``transforms.RandomCrop`` 56 | seed (int, optional): Seed for drawing dataset sample if size is not full dataset 57 | """ 58 | imagenet1k_pairs = [('acorn', 'n12267677'), 59 | ('airliner', 'n02690373'), 60 | ('ambulance', 'n02701002'), 61 | ('american_alligator', 'n01698640'), 62 | ('banjo', 'n02787622'), 63 | ('barn', 'n02793495'), 64 | ('bikini', 'n02837789'), 65 | ('digital_clock', 'n03196217'), 66 | ('dragonfly', 'n02268443'), 67 | ('dumbbell', 'n03255030'), 68 | ('forklift', 'n03384352'), 69 | ('goblet', 'n03443371'), 70 | ('grand_piano', 'n03452741'), 71 | ('hotdog', 'n07697537'), 72 | ('hourglass', 'n03544143'), 73 | ('manhole_cover', 'n03717622'), 74 | ('mosque', 'n03788195'), 75 | ('nail', 'n03804744'), 76 | ('parking_meter', 'n03891332'), 77 | ('pillow', 'n03938244'), 78 | ('revolver', 'n04086273'), 79 | ('rotary_dial_telephone', 'n03187595'), 80 | ('schooner', 'n04147183'), 81 | ('snowmobile', 'n04252077'), 82 | ('soccer_ball', 'n04254680'), 83 | ('stingray', 'n01498041'), 84 | ('strawberry', 'n07745940'), 85 | ('tank', 'n04389033'), 86 | ('toaster', 'n04442312'), 87 | ('volcano', 'n09472597')] 88 | imagenet1k_labels = [label for name, label in imagenet1k_pairs] 89 | 90 | def __init__(self, size: int = 14155519, exclude_imagenet1k=True, seed: int = 0, *args, **kwargs): 91 | 92 | super(MyImageNet22K, self).__init__(*args, **kwargs) 93 | self.size = size 94 | self.exclude_imagenet1k = exclude_imagenet1k 95 | 96 | if exclude_imagenet1k: 97 | imagenet1k_idxs = tuple([self.class_to_idx.get(label) for label in self.imagenet1k_labels]) 98 | self.samples = [s for s in self.samples if s[1] not in imagenet1k_idxs] # s = ('', idx) pair 99 | self.targets = [s[1] for s in self.samples] 100 | self.imgs = self.samples 101 | 102 | for label in self.imagenet1k_labels: 103 | try: 104 | self.classes.remove(label) 105 | del self.class_to_idx[label] 106 | except: 107 | pass 108 | 109 | # Draw random permutation of indices of self.size if not full dataset 110 | self.shuffle_idxs = True 111 | if self.size < 14155519: 112 | random.seed(seed) # set seed to have a fair comparison across models 113 | self.idxs = random.sample(range(len(self.samples)), self.size) 114 | else: 115 | self.idxs = list(range(len(self.samples))) 116 | 117 | self.offset = 0 # offset index 118 | 119 | def __getitem__(self, index): 120 | """Override the original method of the ImageFolder class. 
121 | Args: 122 | index (int): Index 123 | 124 | Returns: 125 | tuple: (sample, target, semi_target, index) 126 | """ 127 | index = (index + self.offset) % self.size 128 | index = self.idxs[index] 129 | 130 | path, target = self.samples[index] 131 | sample = self.loader(path) 132 | 133 | if self.transform is not None: 134 | sample = self.transform(sample) 135 | 136 | return sample, 1, -1, index 137 | -------------------------------------------------------------------------------- /src/optim/classifier_trainer.py: -------------------------------------------------------------------------------- 1 | from base.base_trainer import BaseTrainer 2 | from base.base_dataset import BaseADDataset 3 | from base.base_net import BaseNet 4 | from networks.modules.focal_loss import FocalLoss 5 | from torch.utils.data import DataLoader, RandomSampler 6 | from sklearn.metrics import roc_auc_score 7 | 8 | import logging 9 | import time 10 | import random 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import numpy as np 15 | 16 | 17 | class ClassifierTrainer(BaseTrainer): 18 | 19 | def __init__(self, objective: str, hsc_norm: str = 'l2_squared_linear', focal_gamma: float = 2.0, 20 | optimizer_name: str = 'adam', lr: float = 0.001, n_epochs: int = 150, lr_milestones: tuple = (), 21 | batch_size: int = 128, weight_decay: float = 1e-6, device: str = 'cuda', n_jobs_dataloader: int = 0): 22 | super().__init__(optimizer_name, lr, n_epochs, lr_milestones, batch_size, weight_decay, device, 23 | n_jobs_dataloader) 24 | 25 | # Classifier parameters 26 | self.objective = objective 27 | self.hsc_norm = hsc_norm 28 | self.focal_gamma = focal_gamma 29 | self.eps = 1e-9 30 | 31 | # Results 32 | self.train_time = None 33 | self.train_scores = None 34 | self.test_time = None 35 | self.test_scores = None 36 | self.test_auc = None 37 | 38 | def train(self, dataset: BaseADDataset, oe_dataset: BaseADDataset, net: BaseNet): 39 | logger = logging.getLogger() 40 | 41 | # Get train data loader 42 | if oe_dataset is not None: 43 | num_workers = int(self.n_jobs_dataloader / 2) 44 | else: 45 | num_workers = self.n_jobs_dataloader 46 | 47 | train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=num_workers) 48 | if oe_dataset is not None: 49 | if oe_dataset.shuffle: 50 | if len(dataset.train_set) > len(oe_dataset.train_set): 51 | oe_sampler = RandomSampler(oe_dataset.train_set, 52 | replacement=True, num_samples=len(dataset.train_set)) 53 | oe_loader = DataLoader(dataset=oe_dataset.train_set, batch_size=self.batch_size, shuffle=False, 54 | sampler=oe_sampler, num_workers=num_workers, drop_last=True) 55 | else: 56 | oe_loader = DataLoader(dataset=oe_dataset.train_set, batch_size=self.batch_size, shuffle=True, 57 | num_workers=num_workers, drop_last=True) 58 | 59 | else: 60 | oe_loader = DataLoader(dataset=oe_dataset.train_set, batch_size=self.batch_size, shuffle=False, 61 | num_workers=num_workers, drop_last=True) 62 | dataset_loader = zip(train_loader, oe_loader) 63 | else: 64 | dataset_loader = train_loader 65 | 66 | # Set loss 67 | if self.objective in ['bce', 'focal']: 68 | if self.objective == 'bce': 69 | criterion = nn.BCEWithLogitsLoss() 70 | if self.objective == 'focal': 71 | criterion = FocalLoss(gamma=self.focal_gamma) 72 | criterion = criterion.to(self.device) 73 | 74 | # Set device 75 | net = net.to(self.device) 76 | 77 | # Set optimizer 78 | optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay) 79 | 80 | # Set learning rate scheduler 81 | 
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1) 82 | 83 | # Training 84 | logger.info('Starting training...') 85 | net.train() 86 | start_time = time.time() 87 | 88 | for epoch in range(self.n_epochs + 1): 89 | epoch_loss = 0.0 90 | n_batches = 0 91 | idx_label_score = [] 92 | epoch_start_time = time.time() 93 | 94 | # start at random point for the outlier exposure dataset in each epoch 95 | if (oe_dataset is not None) and (epoch < self.n_epochs): 96 | oe_loader.dataset.offset = np.random.randint(len(oe_loader.dataset)) 97 | if oe_loader.dataset.shuffle_idxs: 98 | random.shuffle(oe_loader.dataset.idxs) 99 | dataset_loader = zip(train_loader, oe_loader) 100 | 101 | # only load samples from the original training set in a last epoch for saving train scores 102 | if epoch == self.n_epochs: 103 | dataset_loader = train_loader 104 | net.eval() 105 | 106 | for data in dataset_loader: 107 | if (oe_dataset is not None) and (epoch < self.n_epochs): 108 | inputs = torch.cat((data[0][0], data[1][0]), 0) 109 | labels = torch.cat((data[0][1], data[1][1]), 0) 110 | semi_targets = torch.cat((data[0][2], data[1][2]), 0) 111 | idx = torch.cat((data[0][3], data[1][3]), 0) 112 | else: 113 | inputs, labels, semi_targets, idx = data 114 | 115 | inputs = inputs.to(self.device) 116 | labels = labels.to(self.device) 117 | semi_targets = semi_targets.to(self.device) 118 | idx = idx.to(self.device) 119 | 120 | # Zero the network parameter gradients 121 | if epoch < self.n_epochs: 122 | optimizer.zero_grad() 123 | 124 | # Update network parameters via backpropagation: forward + backward + optimize 125 | outputs = net(inputs) 126 | 127 | if self.objective == 'hsc': 128 | if self.hsc_norm == 'l1': 129 | dists = torch.norm(outputs, p=1, dim=1) 130 | if self.hsc_norm == 'l2': 131 | dists = torch.norm(outputs, p=2, dim=1) 132 | if self.hsc_norm == 'l2_squared': 133 | dists = torch.norm(outputs, p=2, dim=1) ** 2 134 | if self.hsc_norm == 'l2_squared_linear': 135 | dists = torch.sqrt(torch.norm(outputs, p=2, dim=1) ** 2 + 1) - 1 136 | 137 | scores = 1 - torch.exp(-dists) 138 | losses = torch.where(semi_targets == 0, dists, -torch.log(scores + self.eps)) 139 | loss = torch.mean(losses) 140 | 141 | if self.objective == 'deepSAD': 142 | dists = torch.norm(outputs, p=2, dim=1) ** 2 143 | scores = dists 144 | losses = torch.where(semi_targets == 0, dists, ((dists + self.eps) ** semi_targets.float())) 145 | loss = torch.mean(losses) 146 | 147 | if self.objective in ['bce', 'focal']: 148 | targets = torch.zeros(inputs.size(0)) 149 | targets[semi_targets == -1] = 1 150 | targets = targets.view(-1, 1).to(self.device) 151 | 152 | scores = torch.sigmoid(outputs) 153 | loss = criterion(outputs, targets) 154 | 155 | if epoch < self.n_epochs: 156 | loss.backward() 157 | optimizer.step() 158 | 159 | # save train scores in last epoch 160 | if epoch == self.n_epochs: 161 | idx_label_score += list(zip(idx.cpu().data.numpy().tolist(), 162 | labels.cpu().data.numpy().tolist(), 163 | scores.flatten().cpu().data.numpy().tolist())) 164 | 165 | epoch_loss += loss.item() 166 | n_batches += 1 167 | 168 | # Take learning rate scheduler step 169 | scheduler.step() 170 | if epoch in self.lr_milestones: 171 | logger.info(' LR scheduler: new learning rate is %g' % float(scheduler.get_last_lr()[0])) 172 | 173 | # log epoch statistics 174 | epoch_train_time = time.time() - epoch_start_time 175 | logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s ' 176 | f'| Train 
Loss: {epoch_loss / n_batches:.6f} |') 177 | 178 | self.train_time = time.time() - start_time 179 | self.train_scores = idx_label_score 180 | 181 | # Log results 182 | logger.info('Train Time: {:.3f}s'.format(self.train_time)) 183 | logger.info('Train Loss: {:.6f}'.format(epoch_loss / n_batches)) 184 | logger.info('Finished training.') 185 | 186 | return net 187 | 188 | def test(self, dataset: BaseADDataset, net: BaseNet): 189 | logger = logging.getLogger() 190 | 191 | # Get test data loader 192 | _, test_loader = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader) 193 | 194 | # Set loss 195 | if self.objective in ['bce', 'focal']: 196 | if self.objective == 'bce': 197 | criterion = nn.BCEWithLogitsLoss() 198 | if self.objective == 'focal': 199 | criterion = FocalLoss(gamma=self.focal_gamma) 200 | criterion = criterion.to(self.device) 201 | 202 | # Set device for network 203 | net = net.to(self.device) 204 | 205 | # Testing 206 | logger.info('Starting testing...') 207 | net.eval() 208 | epoch_loss = 0.0 209 | n_batches = 0 210 | idx_label_score = [] 211 | start_time = time.time() 212 | 213 | with torch.no_grad(): 214 | for data in test_loader: 215 | inputs, labels, semi_targets, idx = data 216 | 217 | inputs = inputs.to(self.device) 218 | labels = labels.to(self.device) 219 | semi_targets = semi_targets.to(self.device) 220 | idx = idx.to(self.device) 221 | 222 | outputs = net(inputs) 223 | 224 | if self.objective == 'hsc': 225 | if self.hsc_norm == 'l1': 226 | dists = torch.norm(outputs, p=1, dim=1) 227 | if self.hsc_norm == 'l2': 228 | dists = torch.norm(outputs, p=2, dim=1) 229 | if self.hsc_norm == 'l2_squared': 230 | dists = torch.norm(outputs, p=2, dim=1) ** 2 231 | if self.hsc_norm == 'l2_squared_linear': 232 | dists = torch.sqrt(torch.norm(outputs, p=2, dim=1) ** 2 + 1) - 1 233 | 234 | scores = 1 - torch.exp(-dists) 235 | losses = torch.where(semi_targets == 0, dists, -torch.log(scores + self.eps)) 236 | loss = torch.mean(losses) 237 | 238 | if self.objective == 'deepSAD': 239 | dists = torch.norm(outputs, p=2, dim=1) ** 2 240 | scores = dists 241 | losses = torch.where(semi_targets == 0, dists, ((dists + self.eps) ** semi_targets.float())) 242 | loss = torch.mean(losses) 243 | 244 | if self.objective in ['bce', 'focal']: 245 | targets = torch.zeros(inputs.size(0)) 246 | targets[semi_targets == -1] = 1 247 | targets = targets.view(-1, 1).to(self.device) 248 | 249 | scores = torch.sigmoid(outputs) 250 | loss = criterion(outputs, targets) 251 | 252 | # Save triple of (idx, label, score) in a list 253 | idx_label_score += list(zip(idx.cpu().data.numpy().tolist(), 254 | labels.cpu().data.numpy().tolist(), 255 | scores.flatten().cpu().data.numpy().tolist())) 256 | 257 | epoch_loss += loss.item() 258 | n_batches += 1 259 | 260 | self.test_time = time.time() - start_time 261 | self.test_scores = idx_label_score 262 | 263 | # Compute AUC 264 | _, labels, scores = zip(*idx_label_score) 265 | labels = np.array(labels) 266 | scores = np.array(scores) 267 | self.test_auc = roc_auc_score(labels, scores) 268 | 269 | # Log results 270 | logger.info('Test Time: {:.3f}s'.format(self.test_time)) 271 | logger.info('Test Loss: {:.6f}'.format(epoch_loss / n_batches)) 272 | logger.info('Test AUC: {:.2f}'.format(100. 
* self.test_auc)) 273 | logger.info('Finished testing.') 274 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | import logging 4 | import random 5 | import numpy as np 6 | 7 | from classifier import Classifier 8 | from datasets.main import load_dataset 9 | from utils.config import Config 10 | from utils.visualization.plot_fun import plot_dist 11 | 12 | 13 | ################################################################################ 14 | # Settings 15 | ################################################################################ 16 | @click.command() 17 | @click.argument('dataset_name', type=click.Choice(['mnist', 'emnist', 'cifar10', 'cifar100', 'imagenet1k'])) 18 | @click.argument('net_name', type=click.Choice(['mnist_LeNet', 'cifar10_LeNet', 'imagenet_WideResNet', 'toy_Net'])) 19 | @click.argument('xp_path', type=click.Path(exists=True)) 20 | @click.argument('data_path', type=click.Path(exists=True)) 21 | @click.option('--load_config', type=click.Path(exists=True), default=None, 22 | help='Config JSON-file path (default: None).') 23 | @click.option('--load_model', type=click.Path(exists=True), default=None, help='Model file path (default: None).') 24 | @click.option('--rep_dim', type=int, default=64, help='Final layer dimensionality.') 25 | @click.option('--bias_terms', type=bool, default=True, help='Option to include bias terms in the network.') 26 | @click.option('--objective', type=click.Choice(['hsc', 'deepSAD', 'bce', 'focal']), 27 | default='hsc', help='Set specific type of classification objective to use.') 28 | @click.option('--hsc_norm', type=click.Choice(['l1', 'l2', 'l2_squared', 'l2_squared_linear']), 29 | default='l2_squared_linear', help='Set specific norm to use with HSC.') 30 | @click.option('--focal_gamma', type=float, default=2.0, help='Focal loss hyperparameter gamma. Default=2.0') 31 | @click.option('--outlier_exposure', type=bool, default=False, 32 | help='Apply outlier exposure using oe_dataset_name. Doubles the specified batch_size.') 33 | @click.option('--oe_dataset_name', type=click.Choice(['emnist', 'tinyimages', 'cifar100', 'imagenet22k', 'noise']), 34 | default='tinyimages', help='Choose the dataset to use as outlier exposure.') 35 | @click.option('--oe_size', type=int, default=79302016, 36 | help='Size of the outlier exposure dataset (option to train on subsets).') 37 | @click.option('--oe_n_classes', type=int, default=-1, 38 | help='Number of classes in the outlier exposure dataset. ' 39 | 'If -1, all classes. ' 40 | 'If > 1, the specified number of classes will be sampled at random.') 41 | @click.option('--blur_oe', type=bool, default=False, help='Option to blur (Gaussian filter) OE samples.') 42 | @click.option('--blur_std', type=float, default=1.0, help='Gaussian blurring filter standard deviation. Default=1.0') 43 | @click.option('--device', type=str, default='cuda', help='Computation device to use ("cpu", "cuda", "cuda:2", etc.).') 44 | @click.option('--seed', type=int, default=-1, help='Set seed. If -1, use randomization.') 45 | @click.option('--optimizer_name', type=click.Choice(['adam']), default='adam', 46 | help='Name of the optimizer to use for training.') 47 | @click.option('--lr', type=float, default=0.001, help='Initial learning rate for training.
Default=0.001') 48 | @click.option('--n_epochs', type=int, default=50, help='Number of epochs to train.') 49 | @click.option('--lr_milestone', type=int, default=0, multiple=True, 50 | help='Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.') 51 | @click.option('--batch_size', type=int, default=64, help='Batch size for mini-batch training.') 52 | @click.option('--weight_decay', type=float, default=0.5e-6, help='Weight decay (L2 penalty) hyperparameter.') 53 | @click.option('--data_augmentation', type=bool, default=False, 54 | help='Apply data augmentation (color jitter, random horizontal flipping, and random cropping) for training.') 55 | @click.option('--data_normalization', type=bool, default=False, help='Normalize data wrt dataset sample mean and std.') 56 | @click.option('--num_threads', type=int, default=0, 57 | help='Number of threads used for parallelizing CPU operations. 0 means that all resources are used.') 58 | @click.option('--n_jobs_dataloader', type=int, default=0, 59 | help='Number of workers for data loading. 0 means that the data will be loaded in the main process.') 60 | @click.option('--normal_class', type=int, default=0, 61 | help='Specify the normal class of the dataset (all other classes are considered anomalous).') 62 | def main(dataset_name, net_name, xp_path, data_path, load_config, load_model, rep_dim, bias_terms, objective, hsc_norm, 63 | focal_gamma, outlier_exposure, oe_dataset_name, oe_size, oe_n_classes, blur_oe, blur_std, device, seed, 64 | optimizer_name, lr, n_epochs, lr_milestone, batch_size, weight_decay, data_augmentation, data_normalization, 65 | num_threads, n_jobs_dataloader, normal_class): 66 | """ 67 | A binary classification model. 68 | 69 | :arg DATASET_NAME: Name of the dataset to load. 70 | :arg NET_NAME: Name of the neural network to use. 71 | :arg XP_PATH: Export path for logging the experiment. 72 | :arg DATA_PATH: Root path of data. 73 | """ 74 | 75 | # Get configuration 76 | cfg = Config(locals().copy()) 77 | 78 | # Set up logging 79 | logging.basicConfig(level=logging.INFO) 80 | logger = logging.getLogger() 81 | logger.setLevel(logging.INFO) 82 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 83 | log_file = xp_path + '/log.txt' 84 | file_handler = logging.FileHandler(log_file) 85 | file_handler.setLevel(logging.INFO) 86 | file_handler.setFormatter(formatter) 87 | logger.addHandler(file_handler) 88 | 89 | # Print paths 90 | logger.info('Log file is %s' % log_file) 91 | logger.info('Data path is %s' % data_path) 92 | logger.info('Export path is %s' % xp_path) 93 | 94 | # If specified, load experiment config from JSON-file 95 | if load_config: 96 | cfg.load_config(import_json=load_config) 97 | logger.info('Loaded configuration from %s.'
% load_config) 98 | 99 | # Print experimental setup 100 | logger.info('Dataset: %s' % dataset_name) 101 | logger.info('Normal class: %d' % cfg.settings['normal_class']) 102 | logger.info('Apply outlier exposure: %s' % cfg.settings['outlier_exposure']) 103 | if outlier_exposure: 104 | logger.info('Outlier exposure dataset: %s' % cfg.settings['oe_dataset_name']) 105 | logger.info('Size of OE dataset: %d' % cfg.settings['oe_size']) 106 | if cfg.settings['oe_dataset_name'] in ['emnist', 'cifar100']: 107 | logger.info('Number of classes in OE dataset: %d' % cfg.settings['oe_n_classes']) 108 | logger.info('Blur OE samples with a Gaussian filter: %s' % cfg.settings['blur_oe']) 109 | if cfg.settings['blur_oe']: 110 | logger.info('Gaussian blur filter StdDev: %g' % cfg.settings['blur_std']) 111 | logger.info('Network: %s' % net_name) 112 | logger.info('Representation dimensionality: %d' % cfg.settings['rep_dim']) 113 | logger.info('Include bias terms into the network: %s' % cfg.settings['bias_terms']) 114 | logger.info('Use data augmentation: %s' % cfg.settings['data_augmentation']) 115 | logger.info('Normalize data: %s' % cfg.settings['data_normalization']) 116 | 117 | # Print model configuration 118 | logger.info('Objective: %s' % cfg.settings['objective']) 119 | if cfg.settings['objective'] == 'hsc': 120 | logger.info('HSC norm: %s' % cfg.settings['hsc_norm']) 121 | if cfg.settings['objective'] == 'focal': 122 | logger.info('Focal loss gamma: %g' % cfg.settings['focal_gamma']) 123 | 124 | # Set seed 125 | if cfg.settings['seed'] != -1: 126 | random.seed(cfg.settings['seed']) 127 | np.random.seed(cfg.settings['seed']) 128 | np_random_state = np.random.RandomState(cfg.settings['seed']) 129 | torch.manual_seed(cfg.settings['seed']) 130 | torch.cuda.manual_seed(cfg.settings['seed']) 131 | torch.backends.cudnn.deterministic = True 132 | logger.info('Set seed to %d.' 
        logger.info('Set seed to %d.' % cfg.settings['seed'])
    else:
        cfg.settings['seed'] = None

    # Default device to 'cpu' if cuda is not available
    if not torch.cuda.is_available():
        device = 'cpu'
    else:
        torch.cuda.set_device(device)
    # Set the number of threads used for parallelizing CPU operations
    if num_threads > 0:
        torch.set_num_threads(num_threads)
    logger.info('Computation device: %s' % device)
    logger.info('Number of threads: %d' % num_threads)
    logger.info('Number of dataloader workers: %d' % n_jobs_dataloader)

    # Load data
    dataset = load_dataset(dataset_name=dataset_name, data_path=data_path, normal_class=cfg.settings['normal_class'],
                           data_augmentation=cfg.settings['data_augmentation'],
                           normalize=cfg.settings['data_normalization'], seed=cfg.settings['seed'])
    # Load outlier exposure dataset if specified
    if outlier_exposure:
        oe_dataset = load_dataset(dataset_name=cfg.settings['oe_dataset_name'], data_path=data_path,
                                  normal_class=cfg.settings['normal_class'],
                                  data_augmentation=cfg.settings['data_augmentation'],
                                  normalize=cfg.settings['data_normalization'], seed=cfg.settings['seed'],
                                  outlier_exposure=cfg.settings['outlier_exposure'], oe_size=cfg.settings['oe_size'],
                                  oe_n_classes=cfg.settings['oe_n_classes'], blur_oe=cfg.settings['blur_oe'],
                                  blur_std=cfg.settings['blur_std'])
    else:
        oe_dataset = None

    # Initialize Classifier model and set neural network
    classifier = Classifier(cfg.settings['objective'], cfg.settings['hsc_norm'], cfg.settings['focal_gamma'])
    if cfg.settings['objective'] in ['bce', 'focal']:
        net_name = net_name + '_classifier'
    classifier.set_network(net_name, rep_dim=cfg.settings['rep_dim'], bias_terms=cfg.settings['bias_terms'])

    # If specified, load model
    if load_model:
        classifier.load_model(model_path=load_model, map_location=device)
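        # map_location remaps the checkpoint tensors onto the selected device,
        # so a model trained on GPU can also be restored on a CPU-only machine.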
        logger.info('Loaded model from %s.' % load_model)

    # Log training details
    logger.info('Training optimizer: %s' % cfg.settings['optimizer_name'])
    logger.info('Training learning rate: %g' % cfg.settings['lr'])
    logger.info('Training epochs: %d' % cfg.settings['n_epochs'])
    logger.info('Training learning rate scheduler milestones: %s' % (cfg.settings['lr_milestone'],))
    logger.info('Training batch size: %d' % cfg.settings['batch_size'])
    logger.info('Training weight decay: %g' % cfg.settings['weight_decay'])

    # Train model on dataset
    classifier.train(dataset, oe_dataset,
                     optimizer_name=cfg.settings['optimizer_name'],
                     lr=cfg.settings['lr'],
                     n_epochs=cfg.settings['n_epochs'],
                     lr_milestones=cfg.settings['lr_milestone'],
                     batch_size=cfg.settings['batch_size'],
                     weight_decay=cfg.settings['weight_decay'],
                     device=device,
                     n_jobs_dataloader=n_jobs_dataloader)

    # Test model
    classifier.test(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)

    # Get scores: 'results' holds (index, label, score) triples, which zip(*...)
    # unzips into separate columns.
    train_idx, train_labels, train_scores = zip(*classifier.results['train_scores'])
    train_idx, train_labels, train_scores = np.array(train_idx), np.array(train_labels), np.array(train_scores)
    test_idx, test_labels, test_scores = zip(*classifier.results['test_scores'])
    test_idx, test_labels, test_scores = np.array(test_idx), np.array(test_labels), np.array(test_scores)

    # Plot score distributions
    plot_dist(x=train_scores, xp_path=xp_path, filename='train_scores',
              title='Distribution of anomaly scores (train set)', axlabel='Anomaly Score', legendlabel=['normal'])
    plot_dist(x=test_scores, xp_path=xp_path, filename='test_scores', target=test_labels,
              title='Distribution of anomaly scores (test set)', axlabel='Anomaly Score',
              legendlabel=['normal', 'outlier'])

    # Save results, model, and configuration
    classifier.save_results(export_json=xp_path + '/results.json')
    classifier.save_model(export_model=xp_path + '/model.tar')
    cfg.save_config(export_json=xp_path + '/config.json')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
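Note: a minimal sketch of how this entry script might be invoked from the command
line. The dataset name, network name, paths, and option values below are
illustrative only; the accepted name strings are resolved in src/datasets/main.py
and src/networks/main.py, and the entry script is assumed to be src/main.py:

cd src
python main.py cifar10 cifar10_LeNet ../log/cifar10_test ../data --objective hsc --n_epochs 100 --batch_size 128 --normal_class 3

Compare the shell scripts under src/experiments/ for the invocations actually
used in the experiments.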