├── .gitignore ├── Images └── vgg_network_last.png ├── README.md ├── configs └── config.yaml ├── dataloader.py ├── loss_functions.py ├── models └── network.py ├── requirements.txt ├── test.py ├── test_functions.py ├── train.py └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | .idea/ 13 | venv/ 14 | train/ 15 | test/ 16 | temp_data/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
101 | __pypackages__/
102 | 
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 | 
107 | # SageMath parsed files
108 | *.sage.py
109 | 
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 | 
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 | 
123 | # Rope project settings
124 | .ropeproject
125 | 
126 | # mkdocs documentation
127 | /site
128 | 
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 | 
134 | # Pyre type checker
135 | .pyre/
136 | 
--------------------------------------------------------------------------------
/Images/vgg_network_last.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Niousha12/Knowledge_Distillation_AD/5e08e2d72b26bda55700150cacfafc76983d9c75/Images/vgg_network_last.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multiresolution Knowledge Distillation for Anomaly Detection
2 | 
3 | This repository contains code for training and evaluating the method proposed in our paper [Multiresolution Knowledge Distillation for Anomaly Detection](https://arxiv.org/pdf/2011.11108.pdf).
4 | 
5 | ## Citation
6 | If you find this useful for your research, please cite the following paper:
7 | ```
8 | @article{salehi2020distillation,
9 |   title={Multiresolution Knowledge Distillation for Anomaly Detection},
10 |   author={Salehi, Mohammadreza and Sadjadi, Niousha and Baselizadeh, Soroosh and Rohban, Mohammad Hossein and Rabiee, Hamid R},
11 |   year={2020},
12 |   eprint={2011.11108},
13 |   archivePrefix={arXiv},
14 |   primaryClass={cs.CV}
15 | }
16 | ```
17 | 
18 | ### 1- Clone this repo:
19 | ``` bash
20 | git clone https://github.com/Niousha12/Knowledge_Distillation_AD.git
21 | cd Knowledge_Distillation_AD
22 | ```
23 | ### 2- Datasets:
24 | This repository performs novelty/anomaly detection on the following datasets: MNIST, Fashion-MNIST, CIFAR-10, MVTecAD, Retina (OCT2017), and two medical datasets (Head CT Hemorrhage and Brain MRI Images for Brain Tumor Detection).
25 | 
26 | Furthermore, anomaly localization has been performed on the MVTecAD dataset.
27 | 
28 | The MNIST, Fashion-MNIST and CIFAR-10 datasets will be downloaded by Torchvision. You have to download [MVTecAD](https://www.mvtec.com/company/research/datasets/mvtec-ad/), [Retina](https://www.kaggle.com/paultimothymooney/kermany2018), [Head CT Hemorrhage](http://www.kaggle.com/felipekitamura/head-ct-hemorrhage), and [Brain MRI Images for Brain Tumor Detection](http://www.kaggle.com/navoneel/brain-mri-images-for-brain-tumor-detection), and unpack them into the `Dataset` folder.
29 | 
30 | ##### For the localization test, you should remove the `good` folder from the `{mvtec_class_name}/test/` folder.
31 | 
32 | ### 3- Train the Model:
33 | Start training with the following command. Checkpoints will be saved in `outputs/{experiment_name}/{dataset_name}/checkpoints`.
34 | 
35 | Training parameters such as `experiment_name`, `dataset_name`, `normal_class`, and `batch_size` can be specified in `configs/config.yaml`.
36 | ``` bash
37 | python train.py --config configs/config.yaml
38 | ```
39 | 
40 | ### 4- Test the Trained Model:
41 | Test parameters can also be specified in `configs/config.yaml`.
42 | ``` bash
43 | python test.py --config configs/config.yaml
44 | ```
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # Data parameters
2 | experiment_name: 'local_equal_net'
3 | dataset_name: cifar10  # [mnist, fashionmnist, cifar10, mvtec, retina]
4 | last_checkpoint: 201
5 | 
6 | 
7 | # Training parameters
8 | num_epochs: 201  # mnist/fashionmnist: 51, cifar10: 201, mvtec: 601
9 | batch_size: 64
10 | learning_rate: 1e-3
11 | mvtec_img_size: 128
12 | 
13 | normal_class: 3  # mvtec: 'capsule', mnist: 3
14 | 
15 | lamda: 0.01  # mvtec: 0.5, others: 0.01
16 | 
17 | 
18 | pretrain: True  # True: use a pre-trained VGG as the source network --- False: use random initialization
19 | use_bias: False  # True: use bias terms in the network layers
20 | equal_network_size: False  # True: cloner and source networks have equal size --- False: smaller cloner network
21 | direction_loss_only: False
22 | continue_train: False
23 | 
24 | 
25 | # Test parameters
26 | localization_test: False  # True: localization test --- False: detection test
27 | localization_method: 'gbp'  # gradients, smooth_grad, gbp
28 | 
29 | 
30 | 
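As a concrete example, a hypothetical MVTec run would override the values below, following the comments in the file (`'capsule'` is one of the fifteen MVTec categories, referenced by folder name):

```yaml
# Hypothetical MVTec overrides, taken from the comments in config.yaml:
dataset_name: mvtec
num_epochs: 601
normal_class: 'capsule'
lamda: 0.5
mvtec_img_size: 128
```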
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
5 | import torchvision.transforms as transforms
6 | from torchvision.datasets import MNIST, CIFAR10, FashionMNIST
7 | from torchvision.datasets import ImageFolder
8 | from PIL import Image
9 | 
10 | 
11 | def load_data(config):
12 |     normal_class = config['normal_class']
13 |     batch_size = config['batch_size']
14 | 
15 |     if config['dataset_name'] in ['cifar10']:
16 |         img_transform = transforms.Compose([
17 |             transforms.Resize((256, 256), Image.ANTIALIAS),
18 |             transforms.CenterCrop(224),
19 |             transforms.ToTensor(),
20 |             transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
21 |         ])
22 | 
23 |         os.makedirs("./Dataset/CIFAR10/train", exist_ok=True)
24 |         dataset = CIFAR10('./Dataset/CIFAR10/train', train=True, download=True, transform=img_transform)
25 |         print("Cifar10 DataLoader Called...")
26 |         print("All Train Data: ", dataset.data.shape)
27 |         dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
28 |         dataset.targets = [normal_class] * dataset.data.shape[0]
29 |         print("Normal Train Data: ", dataset.data.shape)
30 | 
31 |         os.makedirs("./Dataset/CIFAR10/test", exist_ok=True)
32 |         test_set = CIFAR10("./Dataset/CIFAR10/test", train=False, download=True, transform=img_transform)
33 |         print("Test Data: ", test_set.data.shape)
34 | 
35 |     elif config['dataset_name'] in ['mnist']:
36 |         img_transform = transforms.Compose([
37 |             transforms.Resize((32, 32)),
38 |             transforms.ToTensor()
39 |         ])
40 | 
41 |         os.makedirs("./Dataset/MNIST/train", exist_ok=True)
42 |         dataset = MNIST('./Dataset/MNIST/train', train=True, download=True, transform=img_transform)
43 |         print("MNIST DataLoader Called...")
44 |         print("All Train Data: ", dataset.data.shape)
45 |         dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
46 |         dataset.targets = [normal_class] * dataset.data.shape[0]
47 |         print("Normal Train Data: ", dataset.data.shape)
48 | 
49 |         os.makedirs("./Dataset/MNIST/test", exist_ok=True)
50 |         test_set = MNIST("./Dataset/MNIST/test", train=False, download=True, transform=img_transform)
51 |         print("Test Data: ", test_set.data.shape)
52 | 
53 |     elif config['dataset_name'] in ['fashionmnist']:
54 |         img_transform = transforms.Compose([
55 |             transforms.Resize((32, 32)),
56 |             transforms.ToTensor()
57 |         ])
58 | 
59 |         os.makedirs("./Dataset/FashionMNIST/train", exist_ok=True)
60 |         dataset = FashionMNIST('./Dataset/FashionMNIST/train', train=True, download=True, transform=img_transform)
61 |         print("FashionMNIST DataLoader Called...")
62 |         print("All Train Data: ", dataset.data.shape)
63 |         dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
64 |         dataset.targets = [normal_class] * dataset.data.shape[0]
65 |         print("Normal Train Data: ", dataset.data.shape)
66 | 
67 |         os.makedirs("./Dataset/FashionMNIST/test", exist_ok=True)
68 |         test_set = FashionMNIST("./Dataset/FashionMNIST/test", train=False, download=True, transform=img_transform)
69 |         print("Test Data: ", test_set.data.shape)
70 | 
71 |     elif config['dataset_name'] in ['mvtec']:
72 |         data_path = 'Dataset/MVTec/' + normal_class + '/train'
73 |         mvtec_img_size = config['mvtec_img_size']
74 | 
75 |         orig_transform = transforms.Compose([
76 |             transforms.Resize([mvtec_img_size, mvtec_img_size]),
77 |             transforms.ToTensor()
78 |         ])
79 | 
80 |         dataset = ImageFolder(root=data_path, transform=orig_transform)
81 | 
82 |         test_data_path = 'Dataset/MVTec/' + normal_class + '/test'
83 |         test_set = ImageFolder(root=test_data_path, transform=orig_transform)
84 | 
85 |     elif config['dataset_name'] in ['retina']:
86 |         data_path = 'Dataset/OCT2017/train'
87 | 
88 |         orig_transform = transforms.Compose([
89 |             transforms.Resize([128, 128]),
90 |             transforms.ToTensor()
91 |         ])
92 | 
93 |         dataset = ImageFolder(root=data_path, transform=orig_transform)
94 | 
95 |         test_data_path = 'Dataset/OCT2017/test'
96 |         test_set = ImageFolder(root=test_data_path, transform=orig_transform)
97 | 
98 |     else:
99 |         raise Exception(
100 |             "You entered {} as the dataset, which is not a valid dataset for this repository!".format(config['dataset_name']))
101 | 
102 |     train_dataloader = torch.utils.data.DataLoader(
103 |         dataset,
104 |         batch_size=batch_size,
105 |         shuffle=True,
106 |     )
107 |     test_dataloader = torch.utils.data.DataLoader(
108 |         test_set,
109 |         batch_size=batch_size,
110 |         shuffle=False,
111 |     )
112 | 
113 |     return train_dataloader, test_dataloader
114 | 
115 | 
116 | def load_localization_data(config):
117 |     normal_class = config['normal_class']
118 |     mvtec_img_size = config['mvtec_img_size']
119 | 
120 |     orig_transform = transforms.Compose([
121 |         transforms.Resize([mvtec_img_size, mvtec_img_size]),
122 |         transforms.ToTensor()
123 |     ])
124 | 
125 |     test_data_path = 'Dataset/MVTec/' + normal_class + '/test'
126 |     test_set = ImageFolder(root=test_data_path, transform=orig_transform)
127 |     test_dataloader = torch.utils.data.DataLoader(
128 |         test_set,
129 |         batch_size=512,
130 |         shuffle=False,
131 |     )
132 | 
133 |     ground_data_path = 'Dataset/MVTec/' + normal_class + '/ground_truth'
134 |     ground_dataset = ImageFolder(root=ground_data_path, transform=orig_transform)
135 |     ground_dataloader = torch.utils.data.DataLoader(
136 |         ground_dataset,
137 |         batch_size=512,
138 |         num_workers=0,
139 |         shuffle=False
140 |     )
141 | 
142 |     x_ground = next(iter(ground_dataloader))[0].numpy()
143 |     ground_temp = x_ground
144 | 
145 |     std_ground_temp = np.transpose(ground_temp, (0, 2, 3, 1))  # masks to channel-last (N, H, W, C)
146 |     x_ground = std_ground_temp
147 | 
148 |     return test_dataloader, x_ground
149 | 
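For orientation, a minimal smoke test of the loaders might look like the sketch below (illustrative only; it assumes the shipped `configs/config.yaml` and an internet connection for the Torchvision download):

```python
from utils.utils import get_config
from dataloader import load_data

config = get_config('configs/config.yaml')
train_loader, test_loader = load_data(config)

x, y = next(iter(train_loader))
print(x.shape)           # e.g. torch.Size([64, 3, 224, 224]) for cifar10
print(set(y.tolist()))   # the train split contains only config['normal_class']
```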
--------------------------------------------------------------------------------
/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class MseDirectionLoss(nn.Module):
6 |     def __init__(self, lamda):
7 |         super(MseDirectionLoss, self).__init__()
8 |         self.lamda = lamda
9 |         self.criterion = nn.MSELoss()
10 |         self.similarity_loss = torch.nn.CosineSimilarity()
11 | 
12 |     def forward(self, output_pred, output_real):
13 |         y_pred_0, y_pred_1, y_pred_2, y_pred_3 = output_pred[3], output_pred[6], output_pred[9], output_pred[12]  # critical-layer activations of the cloner
14 |         y_0, y_1, y_2, y_3 = output_real[3], output_real[6], output_real[9], output_real[12]  # matching activations of the source
15 | 
16 |         # per-layer loss terms: MSE (magnitude) and 1 - cosine similarity (direction)
17 |         abs_loss_0 = self.criterion(y_pred_0, y_0)
18 |         loss_0 = torch.mean(1 - self.similarity_loss(y_pred_0.view(y_pred_0.shape[0], -1), y_0.view(y_0.shape[0], -1)))
19 |         abs_loss_1 = self.criterion(y_pred_1, y_1)
20 |         loss_1 = torch.mean(1 - self.similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1)))
21 |         abs_loss_2 = self.criterion(y_pred_2, y_2)
22 |         loss_2 = torch.mean(1 - self.similarity_loss(y_pred_2.view(y_pred_2.shape[0], -1), y_2.view(y_2.shape[0], -1)))
23 |         abs_loss_3 = self.criterion(y_pred_3, y_3)
24 |         loss_3 = torch.mean(1 - self.similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1)))
25 | 
26 |         total_loss = loss_0 + loss_1 + loss_2 + loss_3 + self.lamda * (
27 |                 abs_loss_0 + abs_loss_1 + abs_loss_2 + abs_loss_3)
28 | 
29 |         return total_loss
30 | 
31 | 
32 | class DirectionOnlyLoss(nn.Module):
33 |     def __init__(self):
34 |         super(DirectionOnlyLoss, self).__init__()
35 |         self.similarity_loss = torch.nn.CosineSimilarity()
36 | 
37 |     def forward(self, output_pred, output_real):
38 |         y_pred_0, y_pred_1, y_pred_2, y_pred_3 = output_pred[3], output_pred[6], output_pred[9], output_pred[12]
39 |         y_0, y_1, y_2, y_3 = output_real[3], output_real[6], output_real[9], output_real[12]
40 | 
41 |         loss_0 = torch.mean(1 - self.similarity_loss(y_pred_0.view(y_pred_0.shape[0], -1), y_0.view(y_0.shape[0], -1)))
42 |         loss_1 = torch.mean(1 - self.similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1)))
43 |         loss_2 = torch.mean(1 - self.similarity_loss(y_pred_2.view(y_pred_2.shape[0], -1), y_2.view(y_2.shape[0], -1)))
44 |         loss_3 = torch.mean(1 - self.similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1)))
45 | 
46 |         total_loss = loss_0 + loss_1 + loss_2 + loss_3
47 | 
48 |         return total_loss
49 | 
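In symbols, writing $(\hat{y}_k, y_k)$ for the cloner/source activation pair at the $k$-th critical layer, `MseDirectionLoss` implements

$$\mathcal{L}_{\text{total}} = \sum_{k=0}^{3}\Big(1-\cos(\hat{y}_k,\, y_k)\Big) \;+\; \lambda \sum_{k=0}^{3}\mathrm{MSE}(\hat{y}_k,\, y_k),$$

and `DirectionOnlyLoss` keeps only the first (directional) sum.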
--------------------------------------------------------------------------------
/models/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision.models import vgg16
4 | from pathlib import Path
5 | 
6 | 
7 | class VGG(nn.Module):
8 |     '''
9 |     VGG model
10 |     '''
11 | 
12 |     def __init__(self, features):
13 |         super(VGG, self).__init__()
14 |         self.features = features
15 | 
16 |         # placeholder for the gradients
17 |         self.gradients = None
18 |         self.activation = None
19 | 
20 |     # hook for the gradients of the activations
21 |     def activations_hook(self, grad):
22 |         self.gradients = grad
23 | 
24 |     def forward(self, x, target_layer=11):
25 |         result = []
26 |         for i in range(len(self.features)):
27 |             x = self.features[i](x)
28 |             if i == target_layer:
29 |                 self.activation = x
30 |                 x.register_hook(self.activations_hook)
31 |             if i == 2 or i == 5 or i == 8 or i == 11 or i == 14 or i == 17 or i == 20 or i == 23 or i == 26 or i == 29 or i == 32 or i == 35 or i == 38:
32 |                 result.append(x)
33 | 
34 |         return result
35 | 
36 |     def get_activations_gradient(self):
37 |         return self.gradients
38 | 
39 |     def get_activations(self, x):
40 |         return self.activation
41 | 
42 | 
43 | def make_layers(cfg, use_bias, batch_norm=False):
44 |     layers = []
45 |     in_channels = 3
46 |     outputs = []
47 |     for i in range(len(cfg)):
48 |         if cfg[i] == 'O':
49 |             outputs.append(nn.Sequential(*layers))
50 |         elif cfg[i] == 'M':
51 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 |         else:
53 |             conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=1, bias=use_bias)
54 |             torch.nn.init.xavier_uniform_(conv2d.weight)
55 |             if batch_norm and cfg[i + 1] != 'M':
56 |                 layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
57 |             else:
58 |                 layers += [conv2d, nn.ReLU(inplace=True)]
59 |             in_channels = cfg[i]
60 |     return nn.Sequential(*layers)
61 | 
62 | 
63 | def make_arch(idx, cfg, use_bias, batch_norm=False):
64 |     return VGG(make_layers(cfg[idx], use_bias, batch_norm=batch_norm))
65 | 
66 | 
67 | class Vgg16(torch.nn.Module):
68 |     def __init__(self, pretrain):
69 |         super(Vgg16, self).__init__()
70 |         features = list(vgg16(pretrained=True).features)  # downloads vgg16-397923af.pth
71 | 
72 |         if not pretrain:
73 |             for ind, f in enumerate(features):
74 |                 # nn.init.xavier_normal_(f)
75 |                 if isinstance(f, nn.Conv2d):
76 |                     torch.nn.init.xavier_uniform_(f.weight)
77 |                     print("Initialized", ind, f)
78 |                 else:
79 |                     print("Bypassed", ind, f)
80 |         # print("Pre-trained Network loaded")
81 |         self.features = nn.ModuleList(features).eval()
82 |         self.output = []
83 | 
84 |     def forward(self, x):
85 |         output = []
86 |         for i in range(31):
87 |             x = self.features[i](x)
88 |             if i == 1 or i == 4 or i == 6 or i == 9 or i == 11 or i == 13 or i == 16 or i == 18 or i == 20 or i == 23 or i == 25 or i == 27 or i == 30:
89 |                 output.append(x)
90 |         return output
91 | 
92 | 
93 | def get_networks(config, load_checkpoint=False):
94 |     equal_network_size = config['equal_network_size']
95 |     pretrain = config['pretrain']
96 |     experiment_name = config['experiment_name']
97 |     dataset_name = config['dataset_name']
98 |     normal_class = config['normal_class']
99 |     use_bias = config['use_bias']
100 |     cfg = {
101 |         'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
102 |         'B': [16, 16, 'M', 16, 128, 'M', 16, 16, 256, 'M', 16, 16, 512, 'M', 16, 16, 512, 'M'],
103 |     }
104 | 
105 |     if equal_network_size:
106 |         config_type = 'A'
107 |     else:
108 |         config_type = 'B'
109 | 
110 |     vgg = Vgg16(pretrain).cuda()
111 |     model = make_arch(config_type, cfg, use_bias, True).cuda()
112 | 
113 |     for j, item in enumerate(model.features):
114 |         print('layer : {} {}'.format(j, item))
115 | 
116 |     if load_checkpoint:
117 |         last_checkpoint = config['last_checkpoint']
118 |         checkpoint_path = "./outputs/{}/{}/checkpoints/".format(experiment_name, dataset_name)
119 |         model.load_state_dict(
120 |             torch.load('{}Cloner_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, last_checkpoint)))
121 |         if not pretrain:
122 |             vgg.load_state_dict(
123 |                 torch.load('{}Source_{}_random_vgg.pth'.format(checkpoint_path, normal_class)))
124 |     elif not pretrain:
125 |         checkpoint_path = "./outputs/{}/{}/checkpoints/".format(experiment_name, dataset_name)
126 |         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
127 | 
128 |         torch.save(vgg.state_dict(), '{}Source_{}_random_vgg.pth'.format(checkpoint_path, normal_class))
129 |         print("Source Checkpoint saved!")
130 | 
131 |     return vgg, model
132 | 
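As a quick sanity check, a sketch like the following (assuming a CUDA device and the shipped config; the tensor sizes are illustrative) shows that both networks return thirteen intermediate activations, of which indices 3, 6, 9, and 12 are the ones compared by the loss functions:

```python
import torch
from utils.utils import get_config
from models.network import get_networks

config = get_config('configs/config.yaml')
vgg, model = get_networks(config)   # source VGG-16 and the (smaller) cloner

x = torch.randn(2, 3, 224, 224).cuda()
source_feats = vgg(x)               # list of 13 intermediate activations
cloner_feats = model(x)             # list of 13 intermediate activations
for k in (3, 6, 9, 12):             # the critical layers used by the losses
    print(k, cloner_feats[k].shape, source_feats[k].shape)  # shapes match
```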
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | torchvision==0.7.0
3 | scikit-learn  # the old 'sklearn==0.0' placeholder is no longer installable from PyPI
4 | matplotlib==3.3.0
5 | PyYAML==5.3.1
6 | opencv-python==4.4.0.46
7 | scipy==1.5
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from utils.utils import get_config
3 | from dataloader import load_data, load_localization_data
4 | from test_functions import detection_test, localization_test
5 | from models.network import get_networks
6 | 
7 | parser = ArgumentParser()
8 | parser.add_argument('--config', type=str, default='configs/config.yaml', help="training configuration")
9 | 
10 | 
11 | def main():
12 |     args = parser.parse_args()
13 |     config = get_config(args.config)
14 |     vgg, model = get_networks(config, load_checkpoint=True)
15 | 
16 |     # Localization test
17 |     if config['localization_test']:
18 |         test_dataloader, ground_truth = load_localization_data(config)
19 |         roc_auc = localization_test(model=model, vgg=vgg, test_dataloader=test_dataloader, ground_truth=ground_truth,
20 |                                     config=config)
21 | 
22 |     # Detection test
23 |     else:
24 |         _, test_dataloader = load_data(config)
25 |         roc_auc = detection_test(model=model, vgg=vgg, test_dataloader=test_dataloader, config=config)
26 |     last_checkpoint = config['last_checkpoint']
27 |     print("RocAUC after {} epoch:".format(last_checkpoint), roc_auc)
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     main()
--------------------------------------------------------------------------------
/test_functions.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from sklearn.metrics import roc_curve, auc
3 | from utils.utils import morphological_process, convert_to_grayscale, max_regarding_to_abs
4 | from scipy.ndimage import gaussian_filter
5 | import numpy as np
6 | import torch
7 | from torch.autograd import Variable
8 | from copy import deepcopy
9 | from torch.nn import ReLU
10 | 
11 | 
12 | def detection_test(model, vgg, test_dataloader, config):
13 |     normal_class = config["normal_class"]
14 |     lamda = config['lamda']
15 |     dataset_name = config['dataset_name']
16 |     direction_only = config['direction_loss_only']
17 | 
18 |     if dataset_name != "mvtec":
19 |         target_class = normal_class
20 |     else:
21 |         mvtec_good_dict = {'bottle': 3, 'cable': 5, 'capsule': 2, 'carpet': 2,
22 |                            'grid': 3, 'hazelnut': 2, 'leather': 4, 'metal_nut': 3, 'pill': 5,
23 |                            'screw': 0, 'tile': 2, 'toothbrush': 1, 'transistor': 3, 'wood': 2,
24 |                            'zipper': 4
25 |                            }
26 |         target_class = mvtec_good_dict[normal_class]
27 | 
28 |     similarity_loss = torch.nn.CosineSimilarity()
29 |     label_score = []
30 |     model.eval()
31 |     for data in test_dataloader:
32 |         X, Y = data
33 |         if X.shape[1] == 1:
34 |             X = X.repeat(1, 3, 1, 1)
35 |         X = Variable(X).cuda()
36 |         output_pred = model.forward(X)
37 |         output_real = vgg(X)
38 |         y_pred_1, y_pred_2, y_pred_3 = output_pred[6], output_pred[9], output_pred[12]
39 |         y_1, y_2, y_3 = output_real[6], output_real[9], output_real[12]
40 | 
41 |         if direction_only:
42 |             loss_1 = 1 - similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1))
43 |             loss_2 = 1 - similarity_loss(y_pred_2.view(y_pred_2.shape[0], -1), y_2.view(y_2.shape[0], -1))
44 |             loss_3 = 1 - similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1))
45 |             total_loss = loss_1 + loss_2 + loss_3
46 |         else:
47 |             abs_loss_1 = torch.mean((y_pred_1 - y_1) ** 2, dim=(1, 2, 3))
48 |             loss_1 = 1 - similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1))
49 |             abs_loss_2 = torch.mean((y_pred_2 - y_2) ** 2, dim=(1, 2, 3))
50 |             loss_2 = 1 - similarity_loss(y_pred_2.view(y_pred_2.shape[0], -1), y_2.view(y_2.shape[0], -1))
51 |             abs_loss_3 = torch.mean((y_pred_3 - y_3) ** 2, dim=(1, 2, 3))
52 |             loss_3 = 1 - similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1))
53 |             total_loss = loss_1 + loss_2 + loss_3 + lamda * (abs_loss_1 + abs_loss_2 + abs_loss_3)
54 | 
55 |         label_score += list(zip(Y.cpu().data.numpy().tolist(), total_loss.cpu().data.numpy().tolist()))
56 | 
57 |     labels, scores = zip(*label_score)
58 |     labels = np.array(labels)
59 |     indx1 = labels == target_class
60 |     indx2 = labels != target_class
61 |     labels[indx1] = 1
62 |     labels[indx2] = 0
63 |     scores = np.array(scores)
64 |     fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=0)
65 |     roc_auc = auc(fpr, tpr)
66 |     roc_auc = round(roc_auc, 4)
67 |     return roc_auc
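Note the label convention: after relabeling, `1` marks the normal (target) class and `0` marks anomalies, while the score is a loss that should be *higher* for anomalies, so `pos_label=0` treats anomalies as the positive class. A toy illustration with made-up numbers:

```python
from sklearn.metrics import roc_curve, auc

labels = [1, 1, 0, 0]               # 1 = normal, 0 = anomalous
scores = [0.10, 0.20, 0.80, 0.90]   # loss-based scores: anomalies score higher
fpr, tpr, _ = roc_curve(labels, scores, pos_label=0)
print(auc(fpr, tpr))                # 1.0 -- perfect separation in this toy case
```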
68 | 
69 | 
70 | def localization_test(model, vgg, test_dataloader, ground_truth, config):
71 |     localization_method = config['localization_method']
72 |     if localization_method == 'gradients':
73 |         grad = gradients_localization(model, vgg, test_dataloader, config)
74 |     elif localization_method == 'smooth_grad':
75 |         grad = smooth_grad_localization(model, vgg, test_dataloader, config)
76 |     elif localization_method == 'gbp':
77 |         grad = gbp_localization(model, vgg, test_dataloader, config)
78 | 
79 |     return compute_localization_auc(grad, ground_truth)
80 | 
81 | 
82 | def grad_calc(inputs, model, vgg, config):
83 |     inputs = inputs.cuda()
84 |     inputs.requires_grad = True
85 |     temp = torch.zeros(inputs.shape)
86 |     lamda = config['lamda']
87 |     criterion = nn.MSELoss()
88 |     similarity_loss = torch.nn.CosineSimilarity()
89 | 
90 |     for i in range(inputs.shape[0]):
91 |         output_pred = model.forward(inputs[i].unsqueeze(0), target_layer=14)
92 |         output_real = vgg(inputs[i].unsqueeze(0))
93 |         y_pred_1, y_pred_2, y_pred_3 = output_pred[6], output_pred[9], output_pred[12]
94 |         y_1, y_2, y_3 = output_real[6], output_real[9], output_real[12]
95 |         abs_loss_1 = criterion(y_pred_1, y_1)
96 |         loss_1 = torch.mean(1 - similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1)))
97 |         abs_loss_2 = criterion(y_pred_2, y_2)
98 |         loss_2 = torch.mean(1 - similarity_loss(y_pred_2.view(y_pred_2.shape[0], -1), y_2.view(y_2.shape[0], -1)))
99 |         abs_loss_3 = criterion(y_pred_3, y_3)
100 |         loss_3 = torch.mean(1 - similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1)))
101 |         total_loss = loss_1 + loss_2 + loss_3 + lamda * (abs_loss_1 + abs_loss_2 + abs_loss_3)
102 |         model.zero_grad()
103 |         total_loss.backward()
104 | 
105 |         temp[i] = inputs.grad[i]
106 | 
107 |     return temp
108 | 
109 | 
110 | def gradients_localization(model, vgg, test_dataloader, config):
111 |     model.eval()
112 |     print("Vanilla Backpropagation:")
113 |     temp = []
114 |     for data in test_dataloader:
115 |         X, Y = data
116 |         grad = grad_calc(X, model, vgg, config)
117 |         # collect maps across all batches (previously only the last batch survived)
118 |         for i in range(grad.shape[0]):
119 |             grad_temp = convert_to_grayscale(grad[i].cpu().numpy())
120 |             grad_temp = grad_temp.squeeze(0)
121 |             grad_temp = gaussian_filter(grad_temp, sigma=4)
122 |             temp.append(grad_temp)
123 |     return np.array(temp)
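In other words, the vanilla saliency map is the channel-collapsed absolute input gradient of the distillation loss, smoothed with a Gaussian:

$$M(x) = G_{\sigma=4}\Big(\mathrm{gray}\big(\big|\nabla_{x}\,\mathcal{L}_{\text{total}}(x)\big|\big)\Big),$$

where $\mathrm{gray}(\cdot)$ sums absolute values over channels and rescales to $[0, 1]$ (see `convert_to_grayscale` in `utils/utils.py`).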
124 | 
125 | 
126 | class VanillaSaliency():
127 |     def __init__(self, model, vgg, device, config):
128 |         self.model = model
129 |         self.vgg = vgg
130 |         self.device = device
131 |         self.config = config
132 |         self.model.eval()
133 | 
134 |     def generate_saliency(self, data, make_single_channel=True):
135 |         data_var_sal = Variable(data).to(self.device)
136 |         self.model.zero_grad()
137 |         if data_var_sal.grad is not None:
138 |             data_var_sal.grad.data.zero_()
139 |         data_var_sal.requires_grad_(True)
140 | 
141 |         lamda = self.config['lamda']
142 |         criterion = nn.MSELoss()
143 |         similarity_loss = torch.nn.CosineSimilarity()
144 | 
145 |         output_pred = self.model.forward(data_var_sal)
146 |         output_real = self.vgg(data_var_sal)
147 |         y_pred_1, y_pred_2, y_pred_3 = output_pred[6], output_pred[9], output_pred[12]
148 |         y_1, y_2, y_3 = output_real[6], output_real[9], output_real[12]
149 | 
150 |         abs_loss_1 = criterion(y_pred_1, y_1)
151 |         loss_1 = torch.mean(1 - similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1)))
152 |         abs_loss_2 = criterion(y_pred_2, y_2)
153 |         loss_2 = torch.mean(1 - similarity_loss(y_pred_2.view(y_pred_2.shape[0], -1), y_2.view(y_2.shape[0], -1)))
154 |         abs_loss_3 = criterion(y_pred_3, y_3)
155 |         loss_3 = torch.mean(1 - similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1)))
156 |         total_loss = loss_1 + loss_2 + loss_3 + lamda * (abs_loss_1 + abs_loss_2 + abs_loss_3)
157 |         self.model.zero_grad()
158 |         total_loss.backward()
159 |         grad = data_var_sal.grad.data.detach().cpu()
160 | 
161 |         if make_single_channel:
162 |             grad = np.asarray(grad.detach().cpu().squeeze(0))
163 |             # grad = max_regarding_to_abs(np.max(grad, axis=0), np.min(grad, axis=0))
164 |             # grad = np.expand_dims(grad, axis=0)
165 |             grad = convert_to_grayscale(grad)
166 |         else:
167 |             grad = np.asarray(grad)
168 |         return grad
169 | 
170 | 
171 | def generate_smooth_grad(data, param_n, param_sigma_multiplier, vbp, single_channel=True):
172 |     smooth_grad = None
173 | 
174 |     mean = 0
175 |     # noise scale is derived from the dynamic range of the input image
176 |     sigma = param_sigma_multiplier / (torch.max(data) - torch.min(data)).item()
177 |     VBP = vbp
178 |     for _ in range(param_n):
179 |         noise = Variable(data.data.new(data.size()).normal_(mean, sigma ** 2))
180 |         noisy_img = data + noise
181 |         vanilla_grads = VBP.generate_saliency(noisy_img, single_channel)
182 |         if not isinstance(vanilla_grads, np.ndarray):
183 |             vanilla_grads = vanilla_grads.detach().cpu().numpy()
184 |         if smooth_grad is None:
185 |             smooth_grad = vanilla_grads
186 |         else:
187 |             smooth_grad = smooth_grad + vanilla_grads
188 | 
189 |     smooth_grad = smooth_grad / param_n
190 |     return smooth_grad
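`generate_smooth_grad` implements SmoothGrad: the saliency map is averaged over `param_n` noisy copies of the input,

$$M_{\text{smooth}}(x) = \frac{1}{n}\sum_{j=1}^{n} M\!\left(x + \epsilon_j\right), \qquad \epsilon_j \sim \mathcal{N}\!\left(0,\, \sigma^{2}\right),$$

with $n = 50$ and the noise scale derived from `param_sigma_multiplier` and the dynamic range of the image (see `smooth_grad_localization` below).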
191 | 
192 | 
193 | class IntegratedGradients():
194 |     def __init__(self, model, vgg, device, config):
195 |         self.model = model
196 |         self.vgg = vgg
197 |         self.gradients = None
198 |         self.device, self.config = device, config  # config is required by VanillaSaliency
199 |         # Put model in evaluation mode
200 |         self.model.eval()
201 | 
202 |     def generate_images_on_linear_path(self, input_image, steps):
203 |         step_list = np.arange(steps + 1) / steps
204 |         xbar_list = [input_image * step for step in step_list]
205 |         return xbar_list
206 | 
207 |     def generate_gradients(self, input_image, make_single_channel=True):
208 |         vanillaSaliency = VanillaSaliency(self.model, self.vgg, self.device, self.config)
209 |         saliency = vanillaSaliency.generate_saliency(input_image, make_single_channel)
210 |         if not isinstance(saliency, np.ndarray):
211 |             saliency = saliency.detach().cpu().numpy()
212 |         return saliency
213 | 
214 |     def generate_integrated_gradients(self, input_image, steps, make_single_channel=True):
215 |         xbar_list = self.generate_images_on_linear_path(input_image, steps)
216 |         integrated_grads = None
217 |         for xbar_image in xbar_list:
218 |             single_integrated_grad = self.generate_gradients(xbar_image, False)
219 |             if integrated_grads is None:
220 |                 integrated_grads = deepcopy(single_integrated_grad)
221 |             else:
222 |                 integrated_grads = (integrated_grads + single_integrated_grad)
223 |         integrated_grads /= steps
224 |         saliency = integrated_grads[0]
225 |         img = input_image.detach().cpu().numpy().squeeze(0)
226 |         saliency = np.asarray(saliency) * img
227 |         if make_single_channel:
228 |             saliency = max_regarding_to_abs(np.max(saliency, axis=0), np.min(saliency, axis=0))
229 |         return saliency
230 | 
231 | 
232 | def generate_integrad_saliency_maps(model, vgg, preprocessed_image, device, config, steps=100, make_single_channel=True):
233 |     IG = IntegratedGradients(model, vgg, device, config)
234 |     integrated_grads = IG.generate_integrated_gradients(preprocessed_image, steps, make_single_channel)
235 |     if make_single_channel:
236 |         integrated_grads = convert_to_grayscale(integrated_grads)
237 |     return integrated_grads
238 | 
239 | 
240 | class GuidedBackprop():
241 |     def __init__(self, model, vgg, device):
242 |         self.model = model
243 |         self.vgg = vgg
244 |         self.gradients = None
245 |         self.forward_relu_outputs = []
246 |         self.device = device
247 |         self.hooks = []
248 |         self.model.eval()
249 |         self.update_relus()
250 | 
251 |     def update_relus(self):
252 | 
253 |         def relu_backward_hook_function(module, grad_in, grad_out):
254 |             corresponding_forward_output = self.forward_relu_outputs[-1]
255 |             corresponding_forward_output[corresponding_forward_output > 0] = 1
256 |             modified_grad_out = corresponding_forward_output * torch.clamp(grad_in[0], min=0.0)
257 |             del self.forward_relu_outputs[-1]  # Remove last forward output
258 |             return (modified_grad_out,)
259 | 
260 |         def relu_forward_hook_function(module, ten_in, ten_out):
261 |             self.forward_relu_outputs.append(ten_out)
262 | 
263 |         # Loop through layers, hook up ReLUs
264 |         for module in self.model.modules():
265 |             if isinstance(module, ReLU):
266 |                 self.hooks.append(module.register_backward_hook(relu_backward_hook_function))
267 |                 self.hooks.append(module.register_forward_hook(relu_forward_hook_function))
268 | 
269 |     def generate_gradients(self, input_image, config, make_single_channel=True):
270 |         vanillaSaliency = VanillaSaliency(self.model, self.vgg, self.device, config=config)
271 |         sal = vanillaSaliency.generate_saliency(input_image, make_single_channel)
272 |         if not isinstance(sal, np.ndarray):
273 |             sal = sal.detach().cpu().numpy()
274 |         for hook in self.hooks:
275 |             hook.remove()
276 |         return sal
277 | 
278 | 
279 | def gbp_localization(model, vgg, test_dataloader, config):
280 |     model.eval()
281 |     print("GBP Method:")
282 | 
283 |     grad1 = []  # accumulate saliency maps across batches (no global index)
284 | 
285 |     for data in test_dataloader:
286 |         X, Y = data
287 |         for x in X:
288 |             data = x.view(1, 3, 128, 128)
289 | 
290 |             GBP = GuidedBackprop(model, vgg, 'cuda:0')
291 |             gbp_saliency = abs(GBP.generate_gradients(data, config))
292 |             gbp_saliency = (gbp_saliency - min(gbp_saliency.flatten())) / (
293 |                     max(gbp_saliency.flatten()) - min(gbp_saliency.flatten()))
294 |             saliency = gbp_saliency
295 | 
296 |             saliency = gaussian_filter(saliency, sigma=4)
297 |             grad1.append(saliency)
298 | 
299 |     grad1 = np.array(grad1, dtype=np.float32).reshape(-1, 128, 128)
300 |     return grad1
301 | 
302 | 
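The backward hook in `GuidedBackprop` implements the guided-backpropagation rule: the gradient is passed through a ReLU only where both the forward activation and the incoming gradient are positive,

$$R^{(l)} = \mathbb{1}\!\left[f^{(l)} > 0\right] \cdot \max\!\left(R^{(l+1)},\, 0\right).$$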
303 | def smooth_grad_localization(model, vgg, test_dataloader, config):
304 |     model.eval()
305 |     print("Smooth Grad Method:")
306 | 
307 |     grad1 = []  # accumulate saliency maps across batches (no global index)
308 | 
309 |     for data in test_dataloader:
310 |         X, Y = data
311 |         for x in X:
312 |             data = x.view(1, 3, 128, 128)
313 | 
314 |             vbp = VanillaSaliency(model, vgg, 'cuda:0', config)
315 | 
316 |             smooth_grad_saliency = abs(generate_smooth_grad(data, 50, 0.05, vbp))
317 |             smooth_grad_saliency = (smooth_grad_saliency - min(smooth_grad_saliency.flatten())) / (
318 |                     max(smooth_grad_saliency.flatten()) - min(smooth_grad_saliency.flatten()))
319 |             saliency = smooth_grad_saliency
320 | 
321 |             saliency = gaussian_filter(saliency, sigma=4)
322 |             grad1.append(saliency)
323 | 
324 |     grad1 = np.array(grad1, dtype=np.float32).reshape(-1, 128, 128)
325 |     return grad1
326 | 
327 | 
328 | def compute_localization_auc(grad, x_ground):
329 |     tpr = []
330 |     fpr = []
331 |     x_ground_comp = np.mean(x_ground, axis=3)
332 | 
333 |     thresholds = [0.001 * i for i in range(1000)]
334 | 
335 |     for threshold in thresholds:
336 |         grad_t = 1.0 * (grad >= threshold)
337 |         grad_t = morphological_process(grad_t)
338 |         tp_map = np.multiply(grad_t, x_ground_comp)
339 |         tpr.append(np.sum(tp_map) / np.sum(x_ground_comp))
340 | 
341 |         inv_x_ground = 1 - x_ground_comp
342 |         fp_map = np.multiply(grad_t, inv_x_ground)
343 |         tn_map = np.multiply(1 - grad_t, 1 - x_ground_comp)
344 |         fpr.append(np.sum(fp_map) / (np.sum(fp_map) + np.sum(tn_map)))
345 | 
346 |     return auc(fpr, tpr)
347 | 
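`compute_localization_auc` sweeps 1000 thresholds over the saliency maps; with $M_t$ the binarized-and-opened map at threshold $t$ and $G$ the channel-averaged ground-truth mask, each threshold contributes

$$\mathrm{TPR}(t) = \frac{\sum M_t \odot G}{\sum G}, \qquad \mathrm{FPR}(t) = \frac{\sum M_t \odot (1 - G)}{\sum M_t \odot (1 - G) + \sum (1 - M_t) \odot (1 - G)},$$

and the pixel-level ROC-AUC is the area under the resulting curve.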
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from test import *
2 | from utils.utils import *
3 | from dataloader import *
4 | from pathlib import Path
5 | from torch.autograd import Variable
6 | import pickle
7 | from test_functions import detection_test
8 | from loss_functions import *
9 | 
10 | parser = ArgumentParser()
11 | parser.add_argument('--config', type=str, default='configs/config.yaml', help="training configuration")
12 | 
13 | 
14 | def train(config):
15 |     direction_loss_only = config["direction_loss_only"]
16 |     normal_class = config["normal_class"]
17 |     learning_rate = float(config['learning_rate'])
18 |     num_epochs = config["num_epochs"]
19 |     lamda = config['lamda']
20 |     continue_train = config['continue_train']
21 |     last_checkpoint = config['last_checkpoint']
22 | 
23 |     checkpoint_path = "./outputs/{}/{}/checkpoints/".format(config['experiment_name'], config['dataset_name'])
24 | 
25 |     # create directory
26 |     Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
27 | 
28 |     train_dataloader, test_dataloader = load_data(config)
29 |     if continue_train:
30 |         vgg, model = get_networks(config, load_checkpoint=True)
31 |     else:
32 |         vgg, model = get_networks(config)
33 | 
34 |     # Criteria And Optimizers
35 |     if direction_loss_only:
36 |         criterion = DirectionOnlyLoss()
37 |     else:
38 |         criterion = MseDirectionLoss(lamda)
39 | 
40 |     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
41 |     if continue_train:
42 |         optimizer.load_state_dict(
43 |             torch.load('{}Opt_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, last_checkpoint)))
44 | 
45 |     losses = []
46 |     roc_aucs = []
47 |     if continue_train:
48 |         with open('{}Auc_{}_epoch_{}.pickle'.format(checkpoint_path, normal_class, last_checkpoint), 'rb') as f:
49 |             roc_aucs = pickle.load(f)
50 | 
51 |     for epoch in range(num_epochs + 1):
52 |         model.train()
53 |         epoch_loss = 0
54 |         for data in train_dataloader:
55 |             X = data[0]
56 |             if X.shape[1] == 1:
57 |                 X = X.repeat(1, 3, 1, 1)
58 |             X = Variable(X).cuda()
59 | 
60 |             output_pred = model.forward(X)
61 |             output_real = vgg(X)
62 | 
63 |             total_loss = criterion(output_pred, output_real)
64 | 
65 |             # Add loss to the list
66 |             epoch_loss += total_loss.item()
67 |             losses.append(total_loss.item())
68 | 
69 |             # Clear the previous gradients
70 |             optimizer.zero_grad()
71 |             # Compute gradients
72 |             total_loss.backward()
73 |             # Adjust weights
74 |             optimizer.step()
75 | 
76 |         print('epoch [{}/{}], loss:{:.4f}'.format(epoch, num_epochs, epoch_loss))
77 |         if epoch % 10 == 0:
78 |             roc_auc = detection_test(model, vgg, test_dataloader, config)
79 |             roc_aucs.append(roc_auc)
80 |             print("RocAUC at epoch {}:".format(epoch), roc_auc)
81 | 
82 |         if epoch % 50 == 0:
83 |             torch.save(model.state_dict(),
84 |                        '{}Cloner_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, epoch))
85 |             torch.save(optimizer.state_dict(),
86 |                        '{}Opt_{}_epoch_{}.pth'.format(checkpoint_path, normal_class, epoch))
87 |             with open('{}Auc_{}_epoch_{}.pickle'.format(checkpoint_path, normal_class, epoch),
88 |                       'wb') as f:
89 |                 pickle.dump(roc_aucs, f)
90 | 
91 | 
92 | def main():
93 |     args = parser.parse_args()
94 |     config = get_config(args.config)
95 |     train(config)
96 | 
97 | 
98 | if __name__ == '__main__':
99 |     main()
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import numpy as np
3 | import cv2
4 | 
5 | 
6 | def get_config(config):
7 |     with open(config, 'r') as stream:
8 |         return yaml.safe_load(stream)
9 | 
10 | 
11 | def convert_to_grayscale(im_as_arr):
12 |     grayscale_im = np.sum(np.abs(im_as_arr), axis=0)
13 |     im_max = np.percentile(grayscale_im, 99)
14 |     im_min = np.min(grayscale_im)
15 |     grayscale_im = (np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1))
16 |     grayscale_im = np.expand_dims(grayscale_im, axis=0)
17 |     return grayscale_im
18 | 
19 | 
20 | # opening morphological process for localization
21 | def morphological_process(x):
22 |     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
23 |     kernel = kernel.astype(np.uint8)
24 |     binary_map = x.astype(np.uint8)
25 |     opening = cv2.morphologyEx(binary_map[0], cv2.MORPH_OPEN, kernel)
26 |     opening = opening.reshape(1, opening.shape[0], opening.shape[1])
27 |     for index in range(1, binary_map.shape[0]):
28 |         temp = cv2.morphologyEx(binary_map[index], cv2.MORPH_OPEN, kernel)
29 |         temp = temp.reshape(1, temp.shape[0], temp.shape[1])
30 |         opening = np.concatenate((opening, temp), axis=0)
31 |     return opening
32 | 
33 | 
34 | def max_regarding_to_abs(a, b):
35 |     # element-wise: pick whichever of a, b has the larger absolute value
36 |     return np.where(np.abs(a) >= np.abs(b), a, b)
--------------------------------------------------------------------------------
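A toy check of the helpers above (illustrative values only):

```python
import numpy as np
from utils.utils import convert_to_grayscale, max_regarding_to_abs

g = np.random.randn(3, 8, 8)               # fake 3-channel gradient map
gray = convert_to_grayscale(g)
print(gray.shape, gray.min(), gray.max())  # (1, 8, 8), values clipped to [0, 1]

a = np.array([[1.0, -5.0]])
b = np.array([[-2.0, 4.0]])
print(max_regarding_to_abs(a, b))          # [[-2. -5.]]
```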