├── src ├── __init__.py ├── datasets │ ├── btad.py │ ├── __init__.py │ └── mvtec.py ├── backbones.py ├── metrics.py ├── multi_variate_gaussian.py ├── utils.py ├── sampler.py ├── common.py └── softpatch.py ├── requirements_dev.txt ├── images └── intuition.png ├── requirements.txt ├── run_btad.sh ├── run_mvtec.sh ├── README.md └── main.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | black>=21.11b0 3 | flake8>=4.0.1 4 | isort>=5.10.1 5 | pytest>=6.2.5 -------------------------------------------------------------------------------- /images/intuition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentYoutuResearch/AnomalyDetection-SoftPatch/HEAD/images/intuition.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click>= 8.0.3 2 | cudatoolkit>= 10.2 3 | # faiss-cpu 4 | faiss-gpu 5 | matplotlib>= 3.5.0 6 | pillow>= 8.4.0 7 | pretrainedmodels>= 0.7.4 8 | torch>= 1.10.0 9 | scikit-image>= 0.18.3 10 | scikit-learn>= 1.0.1 11 | scipy>= 1.7.1 12 | torchvision>= 0.11.1 13 | tqdm>= 4.62.3 14 | -------------------------------------------------------------------------------- /run_btad.sh: -------------------------------------------------------------------------------- 1 | ##################### BTAD 2 | datapath=../BTAD 3 | datasets=('01' '02' '03') 4 | dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '$dataset; done)) 5 | 6 | python main.py --dataset btad --data_path ../../BTAD --noise 0 "${dataset_flags[@]}" --seed 0 \ 7 | --gpu 1 --resize 512 --imagesize 512 --sampling_ratio 0.01 -------------------------------------------------------------------------------- /src/datasets/btad.py: -------------------------------------------------------------------------------- 1 | from .mvtec import MVTecDataset, DatasetSplit # NOCA:unused-import(used) 2 | 3 | 4 | class BTADDataset(MVTecDataset): 5 | def __init__(self, 6 | **kwargs): 7 | super(BTADDataset, self).__init__(**kwargs) 8 | 9 | def set_normal_class(self): 10 | self.normal_class = "ok" 11 | -------------------------------------------------------------------------------- /run_mvtec.sh: -------------------------------------------------------------------------------- 1 | datapath=../../MVTec 2 | datasets=('bottle' 'cable' 'capsule' 'carpet' 'grid' 'hazelnut' 3 | 'leather' 'metal_nut' 'pill' 'screw' 'tile' 'toothbrush' 'transistor' 'wood' 'zipper') 4 | dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '$dataset; done)) 5 | 6 | python main.py --dataset mvtec --data_path ../../MVTec --noise 0.1 "${dataset_flags[@]}" --gpu 0 7 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import PIL 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | 7 | 8 | class AddSaltPepperNoise(object): 9 | 10 | def __init__(self, density=0.0, prob=0.5): 11 | self.density = density 12 | self.prob = prob 13 | 14 | def __call__(self, img): 15 | if random.uniform(0, 1) < self.prob: 16 | img = 
np.array(img) 17 | height, width, channel = img.shape 18 | density = self.density 19 | s_d = 1 - density 20 | mask = np.random.choice((0, 1, 2), size=(height, width, 1), p=[density / 2.0, density / 2.0, s_d]) 21 | mask = np.repeat(mask, channel, axis=2) 22 | img[mask == 0] = 0 23 | img[mask == 1] = 255 24 | img = PIL.Image.fromarray(img.astype('uint8')).convert('RGB') 25 | return img 26 | else: 27 | return img 28 | 29 | 30 | class AddGaussianNoise(object): 31 | def __init__(self, mean=0., std=1.): 32 | self.std = std 33 | self.mean = mean 34 | 35 | def __call__(self, tensor): 36 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 40 | 41 | 42 | class NoiseDataset(torch.utils.data.Dataset): 43 | def __init__( 44 | self, 45 | source, 46 | ): 47 | self.source = source 48 | # transform 49 | self.transform_noise = transforms.Compose([ 50 | # transforms.RandomChoice(transforms), 51 | # AddSaltPepperNoise(0.05, 1), 52 | # AddGaussianNoise(std=0.05), 53 | # transforms.GaussianBlur(3), 54 | # transforms.RandomHorizontalFlip(p=1), 55 | # transforms.RandomRotation(10), 56 | transforms.RandomAffine(10, (0.1, 0.1), (0.9, 1.1), 10) 57 | ]) 58 | 59 | 60 | def __len__(self): 61 | return len(self.source) 62 | 63 | def __getitem__(self, idx): 64 | item = self.source[idx] 65 | 66 | item["image"] = self.transform_noise(item["image"]) 67 | return item 68 | -------------------------------------------------------------------------------- /src/backbones.py: -------------------------------------------------------------------------------- 1 | import timm # noqa # NOCA:unused-import(used) 2 | import torchvision.models as models # noqa # NOCA:unused-import(used) 3 | 4 | _BACKBONES = { 5 | "alexnet": "models.alexnet(pretrained=True)", 6 | "bninception": 'pretrainedmodels.__dict__["bninception"]' 7 | '(pretrained="imagenet", num_classes=1000)', 8 | "resnet50": "models.resnet50(pretrained=True)", 9 | "resnet101": "models.resnet101(pretrained=True)", 10 | "resnext101": "models.resnext101_32x8d(pretrained=True)", 11 | "resnet200": 'timm.create_model("resnet200", pretrained=True)', 12 | "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)', 13 | "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)', 14 | "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)', 15 | "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)', 16 | "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)', 17 | "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)', 18 | "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)', 19 | "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)', 20 | "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)', 21 | "vgg11": "models.vgg11(pretrained=True)", 22 | "vgg19": "models.vgg19(pretrained=True)", 23 | "vgg19_bn": "models.vgg19_bn(pretrained=True)", 24 | "wideresnet50": "models.wide_resnet50_2(pretrained=True)", 25 | "wideresnet101": "models.wide_resnet101_2(pretrained=True)", 26 | "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)', 27 | "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)', 28 | "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)', 29 | "densenet121": 'timm.create_model("densenet121", 
pretrained=True)', 30 | "densenet201": 'timm.create_model("densenet201", pretrained=True)', 31 | "inception_v4": 'timm.create_model("inception_v4", pretrained=True)', 32 | "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)', 33 | "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)', 34 | "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)', 35 | "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)', 36 | "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)', 37 | "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)', 38 | "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)', 39 | "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)', 40 | "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)', 41 | "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)', 42 | "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)', 43 | "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)', 44 | "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)', 45 | "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)', 46 | "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)', 47 | } 48 | 49 | 50 | def load(name): 51 | return eval(_BACKBONES[name]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AnomalyDetection-SoftPatch/SoftPatch+ 2 | This repository contains the official PyTorch implementation of the NeurIPS 2022 paper "[SoftPatch: Unsupervised Anomaly Detection with Noisy Data](https://proceedings.neurips.cc//paper_files/paper/2022/hash/637a456d89289769ac1ab29617ef7213-Abstract-Conference.html)" and of its improved version, SoftPatch+. 3 | 4 | ![softpatch_intuition](images/intuition.png) 5 | 6 | ## Quick Start 7 | 8 | ### Requirement 9 | Our results were computed using Python 3.8 with the packages and respective versions noted in requirements.txt. 10 | 11 | ### MVTec-AD 12 | 13 | - **Default Training**. To train SoftPatch on [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) with a noise ratio of 0.1 (anomalous samples amounting to 10% of the training set are injected as noise), run 14 | 15 | ``` 16 | datapath=/path_to_mvtec_folder/mvtec 17 | datasets=('bottle' 'cable' 'capsule' 'carpet' 'grid' 'hazelnut' 18 | 'leather' 'metal_nut' 'pill' 'screw' 'tile' 'toothbrush' 'transistor' 'wood' 'zipper') 19 | dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '$dataset; done)) 20 | 21 | python main.py --dataset mvtec --data_path "$datapath" --noise 0.1 "${dataset_flags[@]}" --gpu 0 22 | ``` 23 | The default setting in ```run_mvtec.sh``` uses a 224x224 image size and a WideResNet50 backbone pretrained on ImageNet. 24 | 25 | - **Expected Performance**. Training on 1 GPU (NVIDIA Tesla V100 32GB) results in the following performance. 
26 | 27 | | Row Names | image_auroc | pixel_auroc | 28 | |------------------|-------------|-------------| 29 | | mvtec_bottle | 1.0000 | 0.9878 | 30 | | mvtec_cable | 0.9904 | 0.9862 | 31 | | mvtec_capsule | 0.9654 | 0.9883 | 32 | | mvtec_carpet | 0.9965 | 0.9920 | 33 | | mvtec_grid | 1.0000 | 0.9939 | 34 | | mvtec_hazelnut | 1.0000 | 0.9906 | 35 | | mvtec_leather | 1.0000 | 0.9931 | 36 | | mvtec_metal_nut | 0.9987 | 0.9845 | 37 | | mvtec_pill | 0.9562 | 0.9798 | 38 | | mvtec_screw | 0.9526 | 0.9944 | 39 | | mvtec_tile | 0.9866 | 0.9645 | 40 | | mvtec_toothbrush | 0.9931 | 0.9860 | 41 | | mvtec_transistor | 0.9974 | 0.9064 | 42 | | mvtec_wood | 0.9854 | 0.9714 | 43 | | mvtec_zipper | 0.9753 | 0.9892 | 44 | | Mean | 0.9865 | 0.9805 | 45 | 46 | - **Parameter Setting**. 47 | 48 | To choose a different noise discriminator, pass one of ```'lof', 'nearest', 'gaussian' or 'lof_gpu'``` to the ```--weight_method``` argument; for example, append ```--weight_method lof_gpu``` to the training command above. 'lof_gpu' computes the LOF scores on the GPU, which is usually faster. 49 | 50 | To tune the other SoftPatch hyper-parameters, set ```--threshold``` (default 0.15) and ```--lof_k``` (default 6), or add ```--without_soft_weight``` to disable the soft weighting. 51 | 52 | 53 | ### BTAD 54 | To train SoftPatch on [BTAD](https://www.kaggle.com/datasets/thtuan/btad-beantech-anomaly-detection), run: 55 | ``` 56 | datapath=/path_to_btad_folder/BTAD 57 | datasets=('01' '02' '03') 58 | dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '$dataset; done)) 59 | 60 | python main.py --dataset btad --data_path "$datapath" --noise 0 "${dataset_flags[@]}" --seed 0 \ 61 | --gpu 1 --resize 512 --imagesize 512 --sampling_ratio 0.01 62 | ``` 63 | The default setting in ```run_btad.sh``` uses a 512x512 image size and a WideResNet50 backbone pretrained on ImageNet. 64 | 65 | | Row Names | image_auroc | pixel_auroc | 66 | |-----------|-------------|-------------| 67 | | btad_01 | 0.9981 | 0.9761 | 68 | | btad_02 | 0.9343 | 0.9662 | 69 | | btad_03 | 0.9969 | 0.9935 | 70 | | Mean | 0.9764 | 0.9786 | 71 | 72 | # Comments 73 | - Our codebase for the coreset construction builds heavily on the [PatchCore](https://github.com/amazon-science/patchcore-inspection) codebase. Thanks for open-sourcing! 74 | 75 | # Citation 76 | Please cite the following paper if this work helps your project: 77 | ``` 78 | @misc{xisoftpatch, 79 | title={SoftPatch: Unsupervised Anomaly Detection with Noisy Data}, 80 | author={Xi, Jiang and Liu, Jianlin and Wang, Jinbao and Nie, Qiang and Kai, WU and Liu, Yong and Wang, Chengjie and Zheng, Feng}, 81 | booktitle={Advances in Neural Information Processing Systems} 82 | } 83 | ``` 84 | 85 | # License 86 | This project is licensed under the Apache-2.0 License. 87 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | """Anomaly metrics.""" 2 | import numpy as np 3 | from sklearn import metrics 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def compute_imagewise_retrieval_metrics( 8 | anomaly_prediction_weights, anomaly_ground_truth_labels 9 | ): 10 | """ 11 | Computes retrieval statistics (AUROC, FPR, TPR). 12 | 13 | Args: 14 | anomaly_prediction_weights: [np.array or list] [N] Assignment weights 15 | per image. Higher indicates higher 16 | probability of being an anomaly. 17 | anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1 18 | if image is an anomaly, 0 if not. 
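Returns: [dict] The AUROC value together with the 'fpr', 'tpr' and 'threshold' arrays of the ROC curve.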
19 | """ 20 | fpr, tpr, thresholds = metrics.roc_curve( 21 | anomaly_ground_truth_labels, anomaly_prediction_weights 22 | ) 23 | auroc = metrics.roc_auc_score( 24 | anomaly_ground_truth_labels, anomaly_prediction_weights 25 | ) 26 | # TODO: draw_curve 27 | # draw_curve(fpr, tpr, auroc) 28 | return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds} 29 | 30 | 31 | def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks): 32 | """ 33 | Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations 34 | and ground truth segmentation masks. 35 | 36 | Args: 37 | anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains 38 | generated segmentation masks. 39 | ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains 40 | predefined ground truth segmentation masks 41 | """ 42 | if isinstance(anomaly_segmentations, list): 43 | anomaly_segmentations = np.stack(anomaly_segmentations) 44 | if isinstance(ground_truth_masks, list): 45 | ground_truth_masks = np.stack(ground_truth_masks) 46 | 47 | flat_anomaly_segmentations = anomaly_segmentations.ravel() 48 | flat_ground_truth_masks = ground_truth_masks.ravel() 49 | 50 | fpr, tpr, thresholds = metrics.roc_curve( 51 | flat_ground_truth_masks.astype(int), flat_anomaly_segmentations 52 | ) 53 | auroc = metrics.roc_auc_score( 54 | flat_ground_truth_masks.astype(int), flat_anomaly_segmentations 55 | ) 56 | 57 | precision, recall, thresholds = metrics.precision_recall_curve( 58 | flat_ground_truth_masks.astype(int), flat_anomaly_segmentations 59 | ) 60 | f1_scores = np.divide( 61 | 2 * precision * recall, 62 | precision + recall, 63 | out=np.zeros_like(precision), 64 | where=(precision + recall) != 0, 65 | ) 66 | 67 | optimal_threshold = thresholds[np.argmax(f1_scores)] 68 | predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int) 69 | fpr_optim = np.mean(predictions > flat_ground_truth_masks) 70 | fnr_optim = np.mean(predictions < flat_ground_truth_masks) 71 | 72 | return { 73 | "auroc": auroc, 74 | "fpr": fpr, 75 | "tpr": tpr, 76 | "optimal_threshold": optimal_threshold, 77 | "optimal_fpr": fpr_optim, 78 | "optimal_fnr": fnr_optim, 79 | } 80 | 81 | 82 | def draw_curve(fpr, tpr, auroc): 83 | plt.plot(fpr, tpr, 'k--', label='ROC (area = {0:.4f})'.format(auroc), lw=2) 84 | 85 | plt.xlim([-0.05, 1.05]) 86 | plt.ylim([-0.05, 1.05]) 87 | plt.xlabel('False Positive Rate') 88 | plt.ylabel('True Positive Rate') 89 | plt.title('ROC Curve') 90 | plt.legend(loc="lower right") 91 | 92 | error = 0.015 93 | miss = 0.1 94 | plt.plot([error, error], [-0.05, 1.05], 'k:', lw=1) 95 | plt.plot([-0.05, 1.05], [1-miss, 1-miss], 'k:', lw=1) 96 | error_y, miss_x = 0, 1 97 | for i in range(len(fpr)): 98 | if fpr[i] <= error <= fpr[i + 1]: 99 | error_y = tpr[i] 100 | if tpr[i] <= 1-miss <= tpr[i + 1]: 101 | miss_x = fpr[i] 102 | # plt.scatter(error, error_y, c='k') 103 | # plt.scatter(miss_x, 1-miss, c='k') 104 | plt.text(error, error_y, "({0}, {1:.4f})".format(error, error_y), color='k') 105 | plt.text(miss_x, 1-miss, "({0:.4f}, {1})".format(miss_x, 1-miss), color='k') 106 | plt.show() 107 | -------------------------------------------------------------------------------- /src/multi_variate_gaussian.py: -------------------------------------------------------------------------------- 1 | """Multi Variate Gaussian Distribution.""" 2 | 3 | from typing import Any, List, Optional 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | 8 | 9 | class MultiVariateGaussian(nn.Module): 10 | 
"""Multi Variate Gaussian Distribution.""" 11 | 12 | def __init__(self, n_features, n_patches): 13 | super().__init__() 14 | 15 | self.register_buffer("mean", torch.zeros(n_features, n_patches)) 16 | self.register_buffer("inv_covariance", torch.eye(n_features).unsqueeze(0).repeat(n_patches, 1, 1)) 17 | 18 | self.mean: Tensor 19 | self.inv_covariance: Tensor 20 | 21 | @staticmethod 22 | def _cov( 23 | observations: Tensor, 24 | rowvar: bool = False, 25 | bias: bool = False, 26 | ddof: Optional[int] = None, 27 | aweights: Tensor = None, 28 | ) -> Tensor: 29 | 30 | # ensure at least 2D 31 | if observations.dim() == 1: 32 | observations = observations.view(-1, 1) 33 | 34 | # treat each column as a data point, each row as a variable 35 | if rowvar and observations.shape[0] != 1: 36 | observations = observations.t() 37 | 38 | if ddof is None: 39 | if bias == 0: 40 | ddof = 1 41 | else: 42 | ddof = 0 43 | 44 | weights = aweights 45 | weights_sum: Any 46 | 47 | if weights is not None: 48 | if not torch.is_tensor(weights): 49 | weights = torch.tensor(weights, dtype=torch.float) # pylint: disable=not-callable 50 | weights_sum = torch.sum(weights) 51 | avg = torch.sum(observations * (weights / weights_sum)[:, None], 0) 52 | else: 53 | avg = torch.mean(observations, 0) 54 | 55 | # Determine the normalization 56 | if weights is None: 57 | fact = observations.shape[0] - ddof 58 | elif ddof == 0: 59 | fact = weights_sum 60 | elif aweights is None: 61 | fact = weights_sum - ddof 62 | else: 63 | fact = weights_sum - ddof * torch.sum(weights * weights) / weights_sum 64 | 65 | observations_m = observations.sub(avg.expand_as(observations)) 66 | 67 | if weights is None: 68 | x_transposed = observations_m.t() 69 | else: 70 | x_transposed = torch.mm(torch.diag(weights), observations_m).t() 71 | 72 | covariance = torch.mm(x_transposed, observations_m) 73 | covariance = covariance / fact 74 | 75 | return covariance.squeeze() 76 | 77 | def forward(self, embedding: Tensor) -> List[Tensor]: 78 | """Calculate multivariate Gaussian distribution. 79 | 80 | Args: 81 | embedding (Tensor): CNN features whose dimensionality is reduced via either random sampling or PCA. 82 | 83 | Returns: 84 | mean and inverse covariance of the multi-variate gaussian distribution that fits the features. 85 | """ 86 | device = embedding.device 87 | patch, _, channel = embedding.shape 88 | embedding_vectors = embedding.permute(1, 2, 0) 89 | 90 | # batch, channel, height, width = embedding.size() 91 | # embedding_vectors = embedding.view(batch, channel, height * width) 92 | self.mean = torch.mean(embedding_vectors, dim=0) 93 | covariance = torch.zeros(size=(channel, channel, patch), device=device) 94 | identity = torch.eye(channel).to(device) 95 | for i in range(patch): 96 | covariance[:, :, i] = self._cov(embedding_vectors[:, :, i], rowvar=False) + 0.01 * identity 97 | # (evals, evecs) = torch.eig(covariance[:, :, i]) # 98 | # compaction[i] = evals[:, 0].max() #torch.max(evals[:, 0]) 99 | # calculate inverse covariance as we need only the inverse 100 | self.inv_covariance = torch.linalg.inv(covariance.permute(2, 0, 1)) 101 | # compaction = covariance.norm(p=2, dim=(0, 1)) 102 | 103 | return [self.mean, self.inv_covariance] # 104 | # return [self.mean, self.inv_covariance, compaction] # 105 | 106 | def fit(self, embedding: Tensor) -> List[Tensor]: 107 | """Fit multi-variate gaussian distribution to the input embedding. 108 | 109 | Args: 110 | embedding (Tensor): Embedding vector extracted from CNN. 
111 | 112 | Returns: 113 | Mean and the covariance of the embedding. 114 | """ 115 | return self.forward(embedding) 116 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | import os 4 | import random 5 | import tqdm 6 | 7 | import PIL 8 | import torch 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | def plot_segmentation_images( 17 | savefolder, 18 | image_paths, 19 | segmentations, 20 | anomaly_scores=None, 21 | mask_paths=None, 22 | image_transform=lambda x: x, 23 | mask_transform=lambda x: x, 24 | save_depth=2, 25 | # dataset = None 26 | ): 27 | """Generate anomaly segmentation images. 28 | 29 | Args: 30 | image_paths: List[str] List of paths to images. 31 | segmentations: [List[np.ndarray]] Generated anomaly segmentations. 32 | anomaly_scores: [List[float]] Anomaly scores for each image. 33 | mask_paths: [List[str]] List of paths to ground truth masks. 34 | image_transform: [function or lambda] Optional transformation of images. 35 | mask_transform: [function or lambda] Optional transformation of masks. 36 | save_depth: [int] Number of path-strings to use for image savenames. 37 | """ 38 | if mask_paths is None: 39 | mask_paths = ["-1" for _ in range(len(image_paths))] 40 | masks_provided = mask_paths[0] != "-1" 41 | if anomaly_scores is None: 42 | anomaly_scores = ["-1" for _ in range(len(image_paths))] 43 | 44 | os.makedirs(savefolder, exist_ok=True) 45 | 46 | # ''' 47 | for image_path, mask_path, anomaly_score, segmentation in tqdm.tqdm( 48 | zip(image_paths, mask_paths, anomaly_scores, segmentations), 49 | total=len(image_paths), 50 | desc="Generating Segmentation Images...", 51 | leave=False, 52 | ): 53 | image = PIL.Image.open(image_path).convert("RGB") 54 | image = image_transform(image) 55 | if not isinstance(image, np.ndarray): 56 | image = image.numpy() 57 | 58 | # if masks_provided: 59 | if mask_path is not None: 60 | mask = PIL.Image.open(mask_path).convert("RGB") 61 | mask = mask_transform(mask) 62 | if not isinstance(mask, np.ndarray): 63 | mask = mask.numpy() 64 | else: 65 | mask = np.zeros_like(image) 66 | # ''' 67 | savename = image_path.split("/") 68 | savename = "_".join(savename[-save_depth:]) 69 | savename = os.path.join(savefolder, savename) 70 | figure, axes = plt.subplots(1, 2 + int(masks_provided)) 71 | axes[0].imshow(image.transpose(1, 2, 0)) 72 | # axes[2].imshow(segmentation) 73 | axes[1].imshow(image.transpose(1, 2, 0), alpha=1) 74 | axes[1].imshow(segmentation, alpha=0.5) 75 | axes[2].imshow(mask.transpose(1, 2, 0)) 76 | figure.suptitle("Anomaly Score: {:.3f}".format(anomaly_score)) 77 | figure.set_size_inches(3 * (2 + int(masks_provided)), 3) 78 | figure.tight_layout() 79 | figure.savefig(savename, dpi=300) 80 | plt.close() 81 | 82 | 83 | def create_storage_folder( 84 | main_folder_path, project_folder, group_folder, mode="iterate" 85 | ): 86 | os.makedirs(main_folder_path, exist_ok=True) 87 | project_path = os.path.join(main_folder_path, project_folder) 88 | os.makedirs(project_path, exist_ok=True) 89 | save_path = os.path.join(project_path, group_folder) 90 | if mode == "iterate": 91 | counter = 0 92 | while os.path.exists(save_path): 93 | save_path = os.path.join(project_path, group_folder + "_" + str(counter)) 94 | counter += 1 95 | os.makedirs(save_path) 96 | elif mode == "overwrite": 97 | os.makedirs(save_path, 
exist_ok=True) 98 | 99 | return save_path 100 | 101 | 102 | def set_torch_device(gpu_ids): 103 | """Returns correct torch.device. 104 | 105 | Args: 106 | gpu_ids: [list] list of gpu ids. If empty, cpu is used. 107 | """ 108 | if len(gpu_ids) > 0: 109 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 110 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids[0]) 111 | return torch.device("cuda:{}".format(gpu_ids[0])) 112 | return torch.device("cpu") 113 | 114 | 115 | def fix_seeds(seed, with_torch=True, with_cuda=True): 116 | """Fixed available seeds for reproducibility. 117 | 118 | Args: 119 | seed: [int] Seed value. 120 | with_torch: Flag. If true, torch-related seeds are fixed. 121 | with_cuda: Flag. If true, torch+cuda-related seeds are fixed 122 | """ 123 | random.seed(seed) 124 | np.random.seed(seed) 125 | if with_torch: 126 | torch.manual_seed(seed) 127 | if with_cuda: 128 | torch.cuda.manual_seed(seed) 129 | torch.cuda.manual_seed_all(seed) 130 | torch.backends.cudnn.deterministic = True 131 | 132 | 133 | def compute_and_store_final_results( 134 | results_path, 135 | results, 136 | row_names=None, 137 | column_names=None, 138 | ): 139 | """Store computed results as CSV file. 140 | 141 | Args: 142 | results_path: [str] Where to store result csv. 143 | results: [List[List]] List of lists containing results per dataset, 144 | with results[i][0] == 'dataset_name' and results[i][1:6] = 145 | [instance_auroc, full_pixelwisew_auroc, full_pro, 146 | anomaly-only_pw_auroc, anomaly-only_pro] 147 | """ 148 | if column_names is None: 149 | column_names = [ 150 | "Instance AUROC", 151 | "Full Pixel AUROC", 152 | "Full PRO", 153 | "Anomaly Pixel AUROC", 154 | "Anomaly PRO", 155 | ] 156 | if row_names is not None: 157 | assert len(row_names) == len(results), "#Rownames != #Result-rows." 
158 | 159 | mean_metrics = {} 160 | for i, result_key in enumerate(column_names): 161 | mean_metrics[result_key] = np.mean([x[i] for x in results]) 162 | LOGGER.info("{0}: {1:3.3f}".format(result_key, mean_metrics[result_key])) 163 | 164 | savename = os.path.join(results_path, "results.csv") 165 | with open(savename, "w") as csv_file: 166 | csv_writer = csv.writer(csv_file, delimiter=",") 167 | header = column_names 168 | if row_names is not None: 169 | header = ["Row Names"] + header 170 | 171 | csv_writer.writerow(header) 172 | for i, result_list in enumerate(results): 173 | csv_row = result_list 174 | if row_names is not None: 175 | csv_row = [row_names[i]] + result_list 176 | csv_writer.writerow(csv_row) 177 | mean_scores = list(mean_metrics.values()) 178 | if row_names is not None: 179 | mean_scores = ["Mean"] + mean_scores 180 | csv_writer.writerow(mean_scores) 181 | 182 | mean_metrics = {"mean_{0}".format(key): item for key, item in mean_metrics.items()} 183 | return mean_metrics 184 | -------------------------------------------------------------------------------- /src/datasets/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | from enum import Enum 3 | import PIL 4 | import torch 5 | from torchvision import transforms 6 | import numpy as np 7 | 8 | _CLASSNAMES = [ 9 | "bottle", 10 | "cable", 11 | "capsule", 12 | "carpet", 13 | "grid", 14 | "hazelnut", 15 | "leather", 16 | "metal_nut", 17 | "pill", 18 | "screw", 19 | "tile", 20 | "toothbrush", 21 | "transistor", 22 | "wood", 23 | "zipper", 24 | ] 25 | 26 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 27 | IMAGENET_STD = [0.229, 0.224, 0.225] 28 | 29 | 30 | class DatasetSplit(Enum): 31 | TRAIN = "train" 32 | VAL = "val" 33 | TEST = "test" 34 | 35 | 36 | class MVTecDataset(torch.utils.data.Dataset): 37 | """ 38 | PyTorch Dataset for MVTec. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | source, 44 | classname, 45 | resize=256, 46 | imagesize=224, 47 | split=DatasetSplit.TRAIN, 48 | train_val_split=1.0, 49 | **kwargs, 50 | ): 51 | """ 52 | Args: 53 | source: [str]. Path to the MVTec data folder. 54 | classname: [str or None]. Name of MVTec class that should be 55 | provided in this dataset. If None, the datasets 56 | iterates over all available images. 57 | resize: [int]. (Square) Size the loaded image initially gets 58 | resized to. 59 | imagesize: [int]. (Square) Size the resized loaded image gets 60 | (center-)cropped to. 61 | split: [enum-option]. Indicates if training or test split of the 62 | data should be used. Has to be an option taken from 63 | DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that 64 | mvtec.DatasetSplit.TEST will also load mask data. 
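train_val_split: [float]. Fraction of images per class assigned to the TRAIN split; the remaining images form the VAL split.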
65 | """ 66 | super().__init__() 67 | self.normal_class = "good" 68 | self.set_normal_class() 69 | self.source = source 70 | self.split = split 71 | self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES 72 | self.train_val_split = train_val_split 73 | 74 | self.imgpaths_per_class, self.data_to_iterate = self.get_image_data() 75 | 76 | self.transform_img = [ 77 | transforms.Resize(resize), 78 | transforms.CenterCrop(imagesize), 79 | transforms.ToTensor(), 80 | transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), 81 | ] 82 | self.transform_img = transforms.Compose(self.transform_img) 83 | 84 | self.transform_mask = [ 85 | transforms.Resize(resize), 86 | transforms.CenterCrop(imagesize), 87 | transforms.ToTensor(), 88 | ] 89 | self.transform_mask = transforms.Compose(self.transform_mask) 90 | 91 | self.imagesize = (3, imagesize, imagesize) 92 | self.transform_std = IMAGENET_STD 93 | self.transform_mean = IMAGENET_MEAN 94 | 95 | def set_normal_class(self): 96 | self.normal_class = "good" 97 | 98 | def __getitem__(self, idx): 99 | classname, anomaly, image_path, mask_path = self.data_to_iterate[idx] 100 | image = PIL.Image.open(image_path).convert("RGB") 101 | image = self.transform_img(image) 102 | 103 | if self.split == DatasetSplit.TEST and mask_path is not None: 104 | mask = PIL.Image.open(mask_path) 105 | mask = self.transform_mask(mask) 106 | else: 107 | mask = torch.zeros([1, *image.size()[1:]]) 108 | 109 | return { 110 | "image": image, 111 | "mask": mask, 112 | "classname": classname, 113 | "anomaly": anomaly, 114 | "is_anomaly": int(anomaly != self.normal_class), 115 | "image_name": "/".join(image_path.split("/")[-4:]), 116 | "image_path": image_path, 117 | # "mask_path": mask_path, 118 | } 119 | 120 | def __len__(self): 121 | return len(self.data_to_iterate) 122 | 123 | def get_image_data(self): 124 | imgpaths_per_class = {} 125 | maskpaths_per_class = {} 126 | 127 | for classname in self.classnames_to_use: 128 | classpath = os.path.join(self.source, classname, self.split.value) 129 | maskpath = os.path.join(self.source, classname, "ground_truth") 130 | anomaly_types = os.listdir(classpath) 131 | 132 | imgpaths_per_class[classname] = {} 133 | maskpaths_per_class[classname] = {} 134 | 135 | for anomaly in anomaly_types: 136 | anomaly_path = os.path.join(classpath, anomaly) 137 | anomaly_files = sorted(os.listdir(anomaly_path)) 138 | imgpaths_per_class[classname][anomaly] = [ 139 | os.path.join(anomaly_path, x) for x in anomaly_files 140 | ] 141 | 142 | if self.train_val_split < 1.0: 143 | n_images = len(imgpaths_per_class[classname][anomaly]) 144 | train_val_split_idx = int(n_images * self.train_val_split) 145 | if self.split == DatasetSplit.TRAIN: 146 | imgpaths_per_class[classname][anomaly] = imgpaths_per_class[ 147 | classname 148 | ][anomaly][:train_val_split_idx] 149 | elif self.split == DatasetSplit.VAL: 150 | imgpaths_per_class[classname][anomaly] = imgpaths_per_class[ 151 | classname 152 | ][anomaly][train_val_split_idx:] 153 | 154 | if self.split == DatasetSplit.TEST and anomaly != self.normal_class: 155 | anomaly_mask_path = os.path.join(maskpath, anomaly) 156 | anomaly_mask_files = sorted(os.listdir(anomaly_mask_path)) 157 | maskpaths_per_class[classname][anomaly] = [ 158 | os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files 159 | ] 160 | else: 161 | maskpaths_per_class[classname][self.normal_class] = None 162 | 163 | # Unrolls the data dictionary to an easy-to-iterate list. 
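# Each entry is [classname, anomaly_type, image_path, mask_path], where mask_path is None when no ground-truth mask applies.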
164 | data_to_iterate = [] 165 | for classname in sorted(imgpaths_per_class.keys()): 166 | for anomaly in sorted(imgpaths_per_class[classname].keys()): 167 | for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]): 168 | data_tuple = [classname, anomaly, image_path] 169 | if self.split == DatasetSplit.TEST and anomaly != self.normal_class: 170 | data_tuple.append(maskpaths_per_class[classname][anomaly][i]) 171 | else: 172 | data_tuple.append(None) 173 | data_to_iterate.append(data_tuple) 174 | data_to_iterate = np.array(data_to_iterate) 175 | return imgpaths_per_class, data_to_iterate 176 | -------------------------------------------------------------------------------- /src/sampler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Union 3 | import tqdm 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class IdentitySampler: 10 | def run( 11 | self, features: Union[torch.Tensor, np.ndarray] 12 | ) -> Union[torch.Tensor, np.ndarray]: 13 | return features 14 | 15 | 16 | class BaseSampler(abc.ABC): 17 | def __init__(self, percentage: float): 18 | if not 0 < percentage < 1: 19 | raise ValueError("Percentage value not in (0, 1).") 20 | self.percentage = percentage 21 | 22 | @abc.abstractmethod 23 | def run( 24 | self, features: Union[torch.Tensor, np.ndarray] 25 | ) -> Union[torch.Tensor, np.ndarray]: 26 | pass 27 | 28 | def _store_type(self, features: Union[torch.Tensor, np.ndarray]) -> None: 29 | self.features_is_numpy = isinstance(features, np.ndarray) 30 | if not self.features_is_numpy: 31 | self.features_device = features.device 32 | 33 | def _restore_type(self, features: torch.Tensor) -> Union[torch.Tensor, np.ndarray]: 34 | if self.features_is_numpy: 35 | return features.cpu().numpy() 36 | return features.to(self.features_device) 37 | 38 | 39 | class GreedyCoresetSampler(BaseSampler): 40 | def __init__( 41 | self, 42 | percentage: float, 43 | device: torch.device, 44 | dimension_to_project_features_to=128, 45 | ): 46 | """Greedy Coreset sampling base class.""" 47 | super().__init__(percentage) 48 | 49 | self.device = device 50 | self.dimension_to_project_features_to = dimension_to_project_features_to 51 | 52 | def _reduce_features(self, features): 53 | if features.shape[1] == self.dimension_to_project_features_to: 54 | return features 55 | mapper = torch.nn.Linear( 56 | features.shape[1], self.dimension_to_project_features_to, bias=False 57 | ) 58 | _ = mapper.to(self.device) 59 | features = features.to(self.device) 60 | return mapper(features) 61 | 62 | def run( 63 | self, features: Union[torch.Tensor, np.ndarray] 64 | ) -> Union[torch.Tensor, np.ndarray]: 65 | """Subsamples features using Greedy Coreset. 
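Each greedy step adds the feature farthest from the already selected coreset; the subsampled features are returned together with their selected indices.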
66 | 67 | Args: 68 | features: [N x D] 69 | """ 70 | if self.percentage == 1: 71 | return features 72 | self._store_type(features) 73 | if isinstance(features, np.ndarray): 74 | features = torch.from_numpy(features) 75 | reduced_features = self._reduce_features(features) 76 | sample_indices = self._compute_greedy_coreset_indices(reduced_features) 77 | features = features[sample_indices] 78 | return self._restore_type(features), sample_indices 79 | 80 | @staticmethod 81 | def _compute_batchwise_differences( 82 | matrix_a: torch.Tensor, matrix_b: torch.Tensor 83 | ) -> torch.Tensor: 84 | """Computes batchwise Euclidean distances using PyTorch.""" 85 | a_times_a = matrix_a.unsqueeze(1).bmm(matrix_a.unsqueeze(2)).reshape(-1, 1) 86 | b_times_b = matrix_b.unsqueeze(1).bmm(matrix_b.unsqueeze(2)).reshape(1, -1) 87 | a_times_b = matrix_a.mm(matrix_b.T) 88 | 89 | return (-2 * a_times_b + a_times_a + b_times_b).clamp(0, None).sqrt() 90 | 91 | def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray: 92 | """Runs iterative greedy coreset selection. 93 | 94 | Args: 95 | features: [NxD] input feature bank to sample. 96 | """ 97 | distance_matrix = self._compute_batchwise_differences(features, features) 98 | coreset_anchor_distances = torch.norm(distance_matrix, dim=1) 99 | 100 | coreset_indices = [] 101 | num_coreset_samples = int(len(features) * self.percentage) 102 | 103 | for _ in range(num_coreset_samples): 104 | select_idx = torch.argmax(coreset_anchor_distances).item() 105 | coreset_indices.append(select_idx) 106 | 107 | coreset_select_distance = distance_matrix[ 108 | :, select_idx : select_idx + 1 # noqa E203 109 | ] 110 | coreset_anchor_distances = torch.cat( 111 | [coreset_anchor_distances.unsqueeze(-1), coreset_select_distance], dim=1 112 | ) 113 | coreset_anchor_distances = torch.min(coreset_anchor_distances, dim=1).values 114 | 115 | return np.array(coreset_indices) 116 | 117 | 118 | class ApproximateGreedyCoresetSampler(GreedyCoresetSampler): 119 | def __init__( 120 | self, 121 | percentage: float, 122 | device: torch.device, 123 | number_of_starting_points: int = 10, 124 | dimension_to_project_features_to: int = 128, 125 | ): 126 | """Approximate Greedy Coreset sampling base class.""" 127 | self.number_of_starting_points = number_of_starting_points 128 | self.sampling_weight = None 129 | super().__init__(percentage, device, dimension_to_project_features_to) 130 | 131 | def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray: 132 | """Runs approximate iterative greedy coreset selection. 133 | 134 | This greedy coreset implementation does not require computation of the 135 | full N x N distance matrix and thus requires a lot less memory, however 136 | at the cost of increased sampling times. 137 | 138 | Args: 139 | features: [NxD] input feature bank to sample. 
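Returns: [np.ndarray] Indices of the selected coreset samples.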
140 | """ 141 | number_of_starting_points = np.clip( 142 | self.number_of_starting_points, None, len(features) 143 | ) 144 | start_points = np.random.choice( 145 | len(features), number_of_starting_points, replace=False 146 | ).tolist() 147 | 148 | approximate_distance_matrix = self._compute_batchwise_differences( 149 | features, features[start_points] 150 | ) 151 | approximate_coreset_anchor_distances = torch.mean( 152 | approximate_distance_matrix, axis=-1 153 | ).reshape(-1, 1) 154 | coreset_indices = [] 155 | num_coreset_samples = int(len(features) * self.percentage) 156 | 157 | with torch.no_grad(): 158 | for _ in tqdm.tqdm(range(num_coreset_samples), desc="Subsampling..."): 159 | select_idx = torch.argmax(approximate_coreset_anchor_distances).item() 160 | coreset_indices.append(select_idx) 161 | coreset_select_distance = self._compute_batchwise_differences( 162 | features, features[select_idx : select_idx + 1] 163 | ) 164 | if self.sampling_weight is not None: 165 | coreset_select_distance = coreset_select_distance*self.sampling_weight.unsqueeze(-1) 166 | approximate_coreset_anchor_distances = torch.cat( 167 | [approximate_coreset_anchor_distances, coreset_select_distance], 168 | dim=-1, 169 | ) 170 | approximate_coreset_anchor_distances = torch.min( 171 | approximate_coreset_anchor_distances, dim=1 172 | ).values.reshape(-1, 1) 173 | 174 | return np.array(coreset_indices) 175 | 176 | 177 | class RandomSampler(BaseSampler): 178 | def __init__(self, percentage: float): 179 | super().__init__(percentage) 180 | 181 | def run( 182 | self, features: Union[torch.Tensor, np.ndarray] 183 | ) -> Union[torch.Tensor, np.ndarray]: 184 | """Randomly samples input feature collection. 185 | 186 | Args: 187 | features: [N x D] 188 | """ 189 | num_random_samples = int(len(features) * self.percentage) 190 | subset_indices = np.random.choice( 191 | len(features), num_random_samples, replace=False 192 | ) 193 | subset_indices = np.array(subset_indices) 194 | return features[subset_indices] 195 | 196 | 197 | class WeightedGreedyCoresetSampler(ApproximateGreedyCoresetSampler): 198 | def __init__( 199 | self, 200 | percentage: float, 201 | device: torch.device, 202 | number_of_starting_points: int = 10, 203 | dimension_to_project_features_to: int = 128, 204 | ): 205 | """Approximate Greedy Coreset sampling base class.""" 206 | self.number_of_starting_points = number_of_starting_points 207 | super().__init__(percentage, device, dimension_to_project_features_to) 208 | self.sampling_weight = None 209 | 210 | def set_sampling_weight(self, sampling_weight): 211 | self.sampling_weight = sampling_weight 212 | -------------------------------------------------------------------------------- /src/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import pickle 4 | from typing import List 5 | from typing import Union 6 | 7 | import faiss 8 | import numpy as np 9 | import scipy.ndimage as ndimage 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 14 | class FaissNN(object): 15 | def __init__(self, on_gpu: bool = False, num_workers: int = 4, device=0) -> None: 16 | """FAISS Nearest neighbourhood search. 17 | 18 | Args: 19 | on_gpu: If set true, nearest neighbour searches are done on GPU. 20 | num_workers: Number of workers to use with FAISS for similarity search. 
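device: GPU index on which the FAISS index is created when on_gpu is set.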
21 | """ 22 | faiss.omp_set_num_threads(num_workers) 23 | self.on_gpu = on_gpu 24 | self.search_index = None 25 | 26 | self.device = device 27 | 28 | def _gpu_cloner_options(self): 29 | return faiss.GpuClonerOptions() 30 | 31 | def _index_to_gpu(self, index): 32 | if self.on_gpu: 33 | # For the non-gpu faiss python package, there is no GpuClonerOptions 34 | # so we can not make a default in the function header. 35 | return faiss.index_cpu_to_gpu( 36 | faiss.StandardGpuResources(), self.device, index, self._gpu_cloner_options() 37 | ) 38 | return index 39 | 40 | def _index_to_cpu(self, index): 41 | if self.on_gpu: 42 | return faiss.index_gpu_to_cpu(index) 43 | return index 44 | 45 | def _create_index(self, dimension): 46 | if self.on_gpu: 47 | cfg = faiss.GpuIndexFlatConfig() 48 | cfg.device = self.device 49 | return faiss.GpuIndexFlatL2( 50 | faiss.StandardGpuResources(), dimension, cfg 51 | ) 52 | return faiss.IndexFlatL2(dimension) 53 | 54 | def fit(self, features: np.ndarray) -> None: 55 | """ 56 | Adds features to the FAISS search index. 57 | 58 | Args: 59 | features: Array of size NxD. 60 | """ 61 | if self.search_index: 62 | self.reset_index() 63 | self.search_index = self._create_index(features.shape[-1]) 64 | self._train(self.search_index, features) 65 | self.search_index.add(features) 66 | 67 | def _train(self, _index, _features): 68 | pass 69 | 70 | def run( 71 | self, 72 | n_nearest_neighbours, 73 | query_features: np.ndarray, 74 | index_features: np.ndarray = None, 75 | ) -> Union[np.ndarray, np.ndarray, np.ndarray]: 76 | """ 77 | Returns distances and indices of nearest neighbour search. 78 | 79 | Args: 80 | query_features: Features to retrieve. 81 | index_features: [optional] Index features to search in. 82 | """ 83 | if index_features is None: 84 | return self.search_index.search(query_features, n_nearest_neighbours) 85 | 86 | # Build a search index just for this search. 
87 | search_index = self._create_index(index_features.shape[-1]) 88 | self._train(search_index, index_features) 89 | search_index.add(index_features) 90 | return search_index.search(query_features, n_nearest_neighbours) 91 | 92 | def save(self, filename: str) -> None: 93 | faiss.write_index(self._index_to_cpu(self.search_index), filename) 94 | 95 | def load(self, filename: str) -> None: 96 | self.search_index = self._index_to_gpu(faiss.read_index(filename)) 97 | 98 | def reset_index(self): 99 | if self.search_index: 100 | self.search_index.reset() 101 | self.search_index = None 102 | 103 | 104 | class ApproximateFaissNN(FaissNN): 105 | def _train(self, index, features): 106 | index.train(features) 107 | 108 | def _gpu_cloner_options(self): 109 | cloner = faiss.GpuClonerOptions() 110 | cloner.useFloat16 = True 111 | return cloner 112 | 113 | def _create_index(self, dimension): 114 | index = faiss.IndexIVFPQ( 115 | faiss.IndexFlatL2(dimension), 116 | dimension, 117 | 512, # n_centroids 118 | 64, # sub-quantizers 119 | 8, 120 | ) # nbits per code 121 | return self._index_to_gpu(index) 122 | 123 | 124 | class _BaseMerger: 125 | def __init__(self): 126 | """Merges feature embedding by name.""" 127 | 128 | def merge(self, features: list): 129 | features = [self._reduce(feature) for feature in features] 130 | return np.concatenate(features, axis=1) 131 | 132 | 133 | class AverageMerger(_BaseMerger): 134 | @staticmethod 135 | def _reduce(features): 136 | # NxCxWxH -> NxC 137 | return features.reshape([features.shape[0], features.shape[1], -1]).mean( 138 | axis=-1 139 | ) 140 | 141 | 142 | class ConcatMerger(_BaseMerger): 143 | @staticmethod 144 | def _reduce(features): 145 | # NxCxWxH -> NxCWH 146 | return features.reshape(len(features), -1) 147 | 148 | 149 | class Preprocessing(torch.nn.Module): 150 | def __init__(self, input_dims, output_dim): 151 | super(Preprocessing, self).__init__() 152 | self.input_dims = input_dims 153 | self.output_dim = output_dim 154 | 155 | self.preprocessing_modules = torch.nn.ModuleList() 156 | for _ in input_dims: 157 | module = MeanMapper(output_dim) 158 | self.preprocessing_modules.append(module) 159 | 160 | def forward(self, features): 161 | _features = [] 162 | for module, feature in zip(self.preprocessing_modules, features): 163 | _features.append(module(feature)) 164 | return torch.stack(_features, dim=1) 165 | 166 | 167 | class MeanMapper(torch.nn.Module): 168 | def __init__(self, preprocessing_dim): 169 | super(MeanMapper, self).__init__() 170 | self.preprocessing_dim = preprocessing_dim 171 | 172 | def forward(self, features): 173 | features = features.reshape(len(features), 1, -1) 174 | return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) 175 | 176 | 177 | class Aggregator(torch.nn.Module): 178 | def __init__(self, target_dim): 179 | super(Aggregator, self).__init__() 180 | self.target_dim = target_dim 181 | 182 | def forward(self, features): 183 | """Returns reshaped and average pooled features.""" 184 | # batchsize x number_of_layers x input_dim -> batchsize x target_dim 185 | features = features.reshape(len(features), 1, -1) 186 | features = F.adaptive_avg_pool1d(features, self.target_dim) 187 | return features.reshape(len(features), -1) 188 | 189 | 190 | class RescaleSegmentor: 191 | def __init__(self, device, target_size=224): 192 | self.device = device 193 | self.target_size = target_size 194 | self.smoothing = 4 195 | 196 | def convert_to_segmentation(self, patch_scores): 197 | 198 | with torch.no_grad(): 199 | if 
isinstance(patch_scores, np.ndarray): 200 | patch_scores = torch.from_numpy(patch_scores) 201 | _scores = patch_scores.to(self.device) 202 | _scores = _scores.unsqueeze(1) 203 | _scores = F.interpolate( 204 | _scores, size=self.target_size, mode="bilinear", align_corners=False 205 | ) 206 | _scores = _scores.squeeze(1) 207 | patch_scores = _scores.cpu().numpy() 208 | 209 | return [ 210 | ndimage.gaussian_filter(patch_score, sigma=self.smoothing) 211 | for patch_score in patch_scores 212 | ] 213 | 214 | 215 | class NetworkFeatureAggregator(torch.nn.Module): 216 | """Efficient extraction of network features.""" 217 | 218 | def __init__(self, backbone, layers_to_extract_from, device): 219 | super(NetworkFeatureAggregator, self).__init__() 220 | """Extraction of network features. 221 | 222 | Runs a network only to the last layer of the list of layers where 223 | network features should be extracted from. 224 | 225 | Args: 226 | backbone: torchvision.model 227 | layers_to_extract_from: [list of str] 228 | """ 229 | self.layers_to_extract_from = layers_to_extract_from 230 | self.backbone = backbone 231 | self.device = device 232 | if not hasattr(backbone, "hook_handles"): 233 | self.backbone.hook_handles = [] 234 | for handle in self.backbone.hook_handles: 235 | handle.remove() 236 | self.outputs = {} 237 | 238 | for extract_layer in layers_to_extract_from: 239 | forward_hook = ForwardHook( 240 | self.outputs, extract_layer, layers_to_extract_from[-1] 241 | ) 242 | if "." in extract_layer: 243 | extract_block, extract_idx = extract_layer.split(".") 244 | network_layer = backbone.__dict__["_modules"][extract_block] 245 | if extract_idx.isnumeric(): 246 | extract_idx = int(extract_idx) 247 | network_layer = network_layer[extract_idx] 248 | else: 249 | network_layer = network_layer.__dict__["_modules"][extract_idx] 250 | else: 251 | network_layer = backbone.__dict__["_modules"][extract_layer] 252 | 253 | if isinstance(network_layer, torch.nn.Sequential): 254 | self.backbone.hook_handles.append( 255 | network_layer[-1].register_forward_hook(forward_hook) 256 | ) 257 | else: 258 | self.backbone.hook_handles.append( 259 | network_layer.register_forward_hook(forward_hook) 260 | ) 261 | self.to(self.device) 262 | 263 | def forward(self, images): 264 | self.outputs.clear() 265 | 266 | with torch.no_grad(): 267 | # The backbone will throw an Exception once it reached the last 268 | # layer to compute features from. Computation will stop there. 
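# The hooks registered in __init__ copy each requested layer's output into self.outputs as the forward pass runs.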
269 | try: 270 | _ = self.backbone(images) 271 | except LastLayerToExtractReachedException: 272 | pass 273 | return self.outputs 274 | 275 | def feature_dimensions(self, input_shape): 276 | """Computes the feature dimensions for all layers given input_shape.""" 277 | _input = torch.ones([1] + list(input_shape)).to(self.device) 278 | _output = self(_input) 279 | return [_output[layer].shape[1] for layer in self.layers_to_extract_from] 280 | 281 | 282 | class ForwardHook: 283 | def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str): 284 | self.hook_dict = hook_dict 285 | self.layer_name = layer_name 286 | self.raise_exception_to_break = copy.deepcopy( 287 | layer_name == last_layer_to_extract 288 | ) 289 | 290 | def __call__(self, module, input, output): 291 | # pdb.set_trace() 292 | self.hook_dict[self.layer_name] = output.requires_grad_(True) 293 | if self.raise_exception_to_break: 294 | raise LastLayerToExtractReachedException() 295 | return None 296 | 297 | 298 | class LastLayerToExtractReachedException(Exception): 299 | pass 300 | 301 | 302 | class NearestNeighbourScorer(object): 303 | def __init__(self, n_nearest_neighbours: int, nn_method=FaissNN(False, 4)) -> None: 304 | """ 305 | Neearest-Neighbourhood Anomaly Scorer class. 306 | 307 | Args: 308 | n_nearest_neighbours: [int] Number of nearest neighbours used to 309 | determine anomalous pixels. 310 | nn_method: Nearest neighbour search method. 311 | """ 312 | self.feature_merger = ConcatMerger() 313 | 314 | self.n_nearest_neighbours = n_nearest_neighbours 315 | self.nn_method = nn_method 316 | 317 | self.imagelevel_nn = lambda query: self.nn_method.run( 318 | n_nearest_neighbours, query 319 | ) 320 | self.pixelwise_nn = lambda query, index: self.nn_method.run(1, query, index) 321 | 322 | def fit(self, detection_features: List[np.ndarray]) -> None: 323 | """Calls the fit function of the nearest neighbour method. 324 | 325 | Args: 326 | detection_features: [list of np.arrays] 327 | [[bs x d_i] for i in n] Contains a list of 328 | np.arrays for all training images corresponding to respective 329 | features VECTORS (or maps, but will be resized) produced by 330 | some backbone network which should be used for image-level 331 | anomaly detection. 332 | """ 333 | self.detection_features = self.feature_merger.merge( 334 | detection_features, 335 | ) 336 | self.nn_method.fit(self.detection_features) 337 | 338 | def predict( 339 | self, query_features: List[np.ndarray] 340 | ) -> Union[np.ndarray, np.ndarray, np.ndarray]: 341 | """Predicts anomaly score. 342 | 343 | Searches for nearest neighbours of test images in all 344 | support training images. 345 | 346 | Args: 347 | detection_query_features: [dict of np.arrays] List of np.arrays 348 | corresponding to the test features generated by 349 | some backbone network. 
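Returns: Image-level anomaly scores (mean distance to the n nearest training features), the per-query nearest-neighbour distances, and the indices of the matched training features.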
350 | """ 351 | query_features = self.feature_merger.merge( 352 | query_features, 353 | ) 354 | query_distances, query_nns = self.imagelevel_nn(query_features) 355 | 356 | anomaly_scores = np.mean(query_distances, axis=-1) 357 | return anomaly_scores, query_distances, query_nns 358 | 359 | @staticmethod 360 | def _detection_file(folder, prepend=""): 361 | return os.path.join(folder, prepend + "nnscorer_features.pkl") 362 | 363 | @staticmethod 364 | def _index_file(folder, prepend=""): 365 | return os.path.join(folder, prepend + "nnscorer_search_index.faiss") 366 | 367 | @staticmethod 368 | def _save(filename, features): 369 | if features is None: 370 | return 371 | with open(filename, "wb") as save_file: 372 | pickle.dump(features, save_file, pickle.HIGHEST_PROTOCOL) 373 | 374 | @staticmethod 375 | def _load(filename: str): 376 | with open(filename, "rb") as load_file: 377 | return pickle.load(load_file) 378 | 379 | def save( 380 | self, 381 | save_folder: str, 382 | save_features_separately: bool = False, 383 | prepend: str = "", 384 | ) -> None: 385 | self.nn_method.save(self._index_file(save_folder, prepend)) 386 | if save_features_separately: 387 | self._save( 388 | self._detection_file(save_folder, prepend), self.detection_features 389 | ) 390 | 391 | def save_and_reset(self, save_folder: str) -> None: 392 | self.save(save_folder) 393 | self.nn_method.reset_index() 394 | 395 | def load(self, load_folder: str, prepend: str = "") -> None: 396 | self.nn_method.load(self._index_file(load_folder, prepend)) 397 | if os.path.exists(self._detection_file(load_folder, prepend)): 398 | self.detection_features = self._load( 399 | self._detection_file(load_folder, prepend) 400 | ) 401 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import os 4 | import sys 5 | 6 | import time 7 | import random 8 | from pathlib import Path 9 | import argparse 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Subset, ConcatDataset 14 | 15 | # sys.path.append('src') 16 | import src.backbones as backbones 17 | import src.common as common 18 | import src.metrics as metrics 19 | import src.sampler as sampler 20 | import src.utils as utils 21 | import src.softpatch as softpatch 22 | import src.datasets as datasets 23 | LOGGER = logging.getLogger(__name__) 24 | 25 | _DATASETS = {"mvtec": ["datasets.mvtec", "MVTecDataset"], 26 | "btad": ["datasets.btad", "BTADDataset"]} 27 | 28 | 29 | def parse_args(): 30 | # project 31 | parser = argparse.ArgumentParser(description='SoftPatch') 32 | parser.add_argument('--gpu', type=int, default=[], action='append') 33 | parser.add_argument("--seed", type=int, default=0) 34 | parser.add_argument("--results_path", type=str, default="result") 35 | parser.add_argument("--log_project", type=str, default="project") 36 | parser.add_argument("--log_group", type=str, default="group") 37 | parser.add_argument("--save_segmentation_images", action='store_true') 38 | # backbone 39 | parser.add_argument("--backbone_names", "-b", type=str, action='append', default=['wideresnet50']) 40 | parser.add_argument("--layers_to_extract_from", "-le", type=str, action='append', default=['layer2', 'layer3']) 41 | # coreset sampler 42 | parser.add_argument("--sampler_name", type=str, default="approx_greedy_coreset") 43 | parser.add_argument("--sampling_ratio", type=float, default=0.1) 44 | 
parser.add_argument("--faiss_on_gpu", action='store_true') 45 | parser.add_argument("--faiss_num_workers", type=int, default=4) 46 | # SoftPatch hyper-parameter 47 | parser.add_argument("--weight_method", type=str, default="lof") 48 | parser.add_argument("--threshold", type=float, default=0.15) 49 | parser.add_argument("--lof_k", type=int, default=6) 50 | parser.add_argument("--without_soft_weight", action='store_true') 51 | # dataset 52 | parser.add_argument("--dataset", type=str, required=True) 53 | parser.add_argument("--data_path", type=str, required=True) 54 | parser.add_argument("--subdatasets", "-d", action='append', type=str, required=True) 55 | parser.add_argument("--batch_size", default=8, type=int) 56 | parser.add_argument("--resize", default=256, type=int) 57 | parser.add_argument("--imagesize", default=224, type=int) 58 | parser.add_argument("--noise", type=float, default=0) 59 | parser.add_argument("--overlap", action='store_true') 60 | parser.add_argument("--noise_augmentation", action='store_true') 61 | parser.add_argument("--fold", type=int, default=0) 62 | 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | def get_dataloaders(args): 68 | data_path = args.data_path 69 | batch_size = args.batch_size 70 | resize = args.resize 71 | imagesize = args.imagesize 72 | noise = args.noise 73 | overlap = args.overlap 74 | noise_augmentation = args.noise_augmentation 75 | fold = args.fold 76 | 77 | 78 | dataset_info = _DATASETS[args.dataset] 79 | dataset_library = __import__(dataset_info[0], fromlist=[dataset_info[1]]) 80 | dataloaders = [] 81 | for subdataset in args.subdatasets: 82 | train_dataset = dataset_library.__dict__[dataset_info[1]]( 83 | source=data_path, 84 | classname=subdataset, 85 | resize=resize, 86 | imagesize=imagesize, 87 | split=dataset_library.DatasetSplit.TRAIN, 88 | ) 89 | 90 | test_dataset = dataset_library.__dict__[dataset_info[1]]( 91 | source=data_path, 92 | classname=subdataset, 93 | resize=resize, 94 | imagesize=imagesize, 95 | split=dataset_library.DatasetSplit.TEST, 96 | ) 97 | 98 | if noise >= 0: 99 | anomaly_index = [index for index in range(len(test_dataset)) if test_dataset[index]["is_anomaly"]] 100 | train_length = len(train_dataset) 101 | noise_number = int(noise * train_length) 102 | LOGGER.info("{} anomaly samples are being added into train dataset as noise.".format(noise_number)) 103 | 104 | noise_index_path = Path("noise_index" + "/" + str(args.dataset) + "_noise" 105 | + str(noise) + "_fold" + str(fold)) 106 | 107 | noise_index_path.mkdir(parents=True, exist_ok=True) 108 | path = noise_index_path / (subdataset + "-noise" + str(noise) + ".pth") 109 | if path.exists(): 110 | noise_index = torch.load(path) 111 | assert len(noise_index) == noise_number 112 | else: 113 | noise_index = random.sample(anomaly_index, noise_number) 114 | torch.save(noise_index, path) 115 | 116 | noise_dataset = Subset(test_dataset, noise_index) 117 | if noise_augmentation: 118 | noise_dataset = datasets.NoiseDataset(noise_dataset) 119 | train_dataset = ConcatDataset([train_dataset, noise_dataset]) 120 | 121 | train_dataset.imagesize = train_dataset.datasets[0].imagesize 122 | 123 | if not overlap: 124 | new_test_data_index = list(set(range(len(test_dataset))) - set(noise_index)) 125 | test_dataset = Subset(test_dataset, new_test_data_index) 126 | else: 127 | test_dataset = Subset(test_dataset, range(len(test_dataset))) 128 | 129 | 130 | train_dataloader = torch.utils.data.DataLoader( 131 | train_dataset, 132 | batch_size=batch_size, 133 | shuffle=False, 134 
| pin_memory=True, 135 | ) 136 | 137 | test_dataloader = torch.utils.data.DataLoader( 138 | test_dataset, 139 | batch_size=batch_size, 140 | shuffle=False, 141 | pin_memory=True, 142 | ) 143 | 144 | train_dataloader.name = args.dataset 145 | if subdataset is not None: 146 | train_dataloader.name += "_" + subdataset 147 | 148 | dataloader_dict = { 149 | "training": train_dataloader, 150 | "testing": test_dataloader, 151 | } 152 | 153 | dataloaders.append(dataloader_dict) 154 | return dataloaders 155 | 156 | 157 | def get_sampler(sampler_name, sampling_ratio, device): 158 | if sampler_name == "identity": 159 | return sampler.IdentitySampler() 160 | elif sampler_name == "greedy_coreset": 161 | return sampler.GreedyCoresetSampler(sampling_ratio, device) 162 | elif sampler_name == "approx_greedy_coreset": 163 | return sampler.ApproximateGreedyCoresetSampler(sampling_ratio, device) 164 | 165 | 166 | def get_coreset(args, imagesize, sampler, device): 167 | input_shape = (3, imagesize, imagesize) 168 | backbone_names = list(args.backbone_names) 169 | if len(backbone_names) > 1: 170 | layers_to_extract_from_coll = [[] for _ in range(len(backbone_names))] 171 | for layer in args.layers_to_extract_from: 172 | idx = int(layer.split(".")[0]) 173 | layer = ".".join(layer.split(".")[1:]) 174 | layers_to_extract_from_coll[idx].append(layer) 175 | else: 176 | layers_to_extract_from_coll = [args.layers_to_extract_from] 177 | 178 | loaded_coresets = [] 179 | for backbone_name, layers_to_extract_from in zip( 180 | backbone_names, layers_to_extract_from_coll 181 | ): 182 | backbone_seed = None 183 | if ".seed-" in backbone_name: 184 | backbone_name, backbone_seed = backbone_name.split(".seed-")[0], int( 185 | backbone_name.split("-")[-1] 186 | ) 187 | backbone = backbones.load(backbone_name) 188 | backbone.name, backbone.seed = backbone_name, backbone_seed 189 | 190 | nn_method = common.FaissNN(args.faiss_on_gpu, args.faiss_num_workers, device=device.index) 191 | 192 | coreset_instance = softpatch.SoftPatch(device) 193 | coreset_instance.load( 194 | backbone=backbone, 195 | layers_to_extract_from=layers_to_extract_from, 196 | device=device, 197 | input_shape=input_shape, 198 | featuresampler=sampler, 199 | nn_method=nn_method, 200 | LOF_k=args.lof_k, 201 | threshold=args.threshold, 202 | weight_method=args.weight_method, 203 | soft_weight_flag=not args.without_soft_weight, 204 | ) 205 | loaded_coresets.append(coreset_instance) 206 | return loaded_coresets 207 | 208 | 209 | def run(args): 210 | 211 | seed = args.seed 212 | run_save_path = utils.create_storage_folder( 213 | args.results_path, args.log_project, args.log_group, mode="iterate" 214 | ) 215 | 216 | list_of_dataloaders = get_dataloaders(args) 217 | 218 | device = utils.set_torch_device(args.gpu) 219 | LOGGER.info(device) 220 | # Device context here is specifically set and used later 221 | # because there was GPU memory-bleeding which I could only fix with 222 | # context managers. 
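# torch.cuda.device(...) below pins the whole evaluation block to the selected GPU;
# on a CPU-only run, contextlib.suppress() with no arguments acts as a do-nothing stand-in.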
223 | device_context = ( 224 | torch.cuda.device("cuda:{}".format(device.index)) 225 | if "cuda" in device.type.lower() 226 | else contextlib.suppress() 227 | ) 228 | 229 | result_collect = [] 230 | 231 | for dataloader_count, dataloaders in enumerate(list_of_dataloaders): 232 | dataset_name = dataloaders["training"].name 233 | LOGGER.info( 234 | "Evaluating dataset [{}] ({}/{})...".format( 235 | dataloaders["training"].name, 236 | dataloader_count + 1, 237 | len(list_of_dataloaders), 238 | ) 239 | ) 240 | start_time = time.time() 241 | utils.fix_seeds(seed, device) 242 | 243 | 244 | 245 | with device_context: 246 | torch.cuda.empty_cache() 247 | sampler = get_sampler(args.sampler_name, args.sampling_ratio, device) 248 | coreset_list = get_coreset(args, args.imagesize, sampler, device) 249 | if len(coreset_list) > 1: 250 | LOGGER.info( 251 | "Utilizing Coreset Ensemble (N={}).".format(len(coreset_list)) 252 | ) 253 | for i, coreset in enumerate(coreset_list): 254 | torch.cuda.empty_cache() 255 | if coreset.backbone.seed is not None: 256 | utils.fix_seeds(coreset.backbone.seed, device) 257 | LOGGER.info( 258 | "Training models ({}/{})".format(i + 1, len(coreset_list)) 259 | ) 260 | # for epoch in range(20): 261 | # coreset._train(dataloaders["training"]) 262 | coreset.fit(dataloaders["training"]) 263 | train_end = time.time() 264 | torch.cuda.empty_cache() 265 | aggregator = {"scores": [], "segmentations": []} 266 | for i, coreset in enumerate(coreset_list): 267 | torch.cuda.empty_cache() 268 | LOGGER.info( 269 | "Embedding test data with models ({}/{})".format( 270 | i + 1, len(coreset_list) 271 | ) 272 | ) 273 | scores, segmentations, labels_gt, masks_gt = coreset.predict( 274 | dataloaders["testing"] 275 | ) 276 | aggregator["scores"].append(scores) 277 | aggregator["segmentations"].append(segmentations) 278 | 279 | scores = np.array(aggregator["scores"]) 280 | min_scores = scores.min(axis=-1).reshape(-1, 1) 281 | max_scores = scores.max(axis=-1).reshape(-1, 1) 282 | scores = (scores - min_scores) / (max_scores - min_scores + 1e-5) 283 | scores = np.mean(scores, axis=0) 284 | 285 | segmentations = np.array(aggregator["segmentations"]) 286 | min_scores = ( 287 | segmentations.reshape(len(segmentations), -1) 288 | .min(axis=-1) 289 | .reshape(-1, 1, 1, 1) 290 | ) 291 | max_scores = ( 292 | segmentations.reshape(len(segmentations), -1) 293 | .max(axis=-1) 294 | .reshape(-1, 1, 1, 1) 295 | ) 296 | segmentations = (segmentations - min_scores) / (max_scores - min_scores) 297 | segmentations = np.mean(segmentations, axis=0) 298 | 299 | # anomaly_labels = [ 300 | # x[1] != "good" for x in dataloaders["testing"].dataset.data_to_iterate 301 | # ] 302 | 303 | test_end = time.time() 304 | LOGGER.info("Training time:{}, Testing time:{}".format(train_end - start_time, test_end - train_end)) 305 | 306 | # (Optional) Plot example images. 
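# Only runs when save_segmentation_images is set; the test loader wraps a Subset, so image
# and mask paths are looked up via .dataset.data_to_iterate using the subset's .indices.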
307 | if args.save_segmentation_images: 308 | # dataset = dataloaders["testing"].dataset 309 | image_paths = [ 310 | x[2] for x in 311 | dataloaders["testing"].dataset.dataset.data_to_iterate[dataloaders["testing"].dataset.indices] 312 | ] 313 | mask_paths = [ 314 | x[3] for x in 315 | dataloaders["testing"].dataset.dataset.data_to_iterate[dataloaders["testing"].dataset.indices] 316 | ] 317 | 318 | def image_transform(image): 319 | in_std = np.array( 320 | dataloaders["testing"].dataset.dataset.transform_std 321 | ).reshape(-1, 1, 1) 322 | in_mean = np.array( 323 | dataloaders["testing"].dataset.dataset.transform_mean 324 | ).reshape(-1, 1, 1) 325 | image = dataloaders["testing"].dataset.dataset.transform_img(image) 326 | return np.clip( 327 | (image.numpy() * in_std + in_mean) * 255, 0, 255 328 | ).astype(np.uint8) 329 | 330 | def mask_transform(mask): 331 | return dataloaders["testing"].dataset.dataset.transform_mask(mask).numpy() 332 | 333 | image_save_path = os.path.join( 334 | run_save_path, "segmentation_images", dataset_name 335 | ) 336 | os.makedirs(image_save_path, exist_ok=True) 337 | utils.plot_segmentation_images( 338 | image_save_path, 339 | image_paths, 340 | segmentations, 341 | scores, 342 | mask_paths, 343 | image_transform=image_transform, 344 | mask_transform=mask_transform, 345 | # dataset=dataset 346 | ) 347 | 348 | LOGGER.info("Computing evaluation metrics.") 349 | auroc = metrics.compute_imagewise_retrieval_metrics( 350 | scores, labels_gt 351 | )["auroc"] 352 | 353 | # Compute PRO score & PW Auroc for all images 354 | pixel_scores = metrics.compute_pixelwise_retrieval_metrics( 355 | segmentations, masks_gt 356 | ) 357 | full_pixel_auroc = pixel_scores["auroc"] 358 | 359 | # Compute PRO score & PW Auroc only images with anomalies 360 | # sel_idxs = [] 361 | # for i in range(len(masks_gt)): 362 | # if np.sum(masks_gt[i]) > 0: 363 | # sel_idxs.append(i) 364 | # pixel_scores = coreset.metrics.compute_pixelwise_retrieval_metrics( 365 | # [segmentations[i] for i in sel_idxs], 366 | # [masks_gt[i] for i in sel_idxs], 367 | # ) 368 | # anomaly_pixel_auroc = pixel_scores["auroc"] 369 | 370 | result_collect.append( 371 | { 372 | "dataset_name": dataset_name, 373 | "image_auroc": auroc, 374 | "pixel_auroc": full_pixel_auroc, 375 | # "anomaly_pixel_auroc": anomaly_pixel_auroc, 376 | } 377 | ) 378 | 379 | for key, item in result_collect[-1].items(): 380 | if key != "dataset_name": 381 | LOGGER.info("{0}: {1:4.4f}".format(key, item)) 382 | 383 | LOGGER.info("\n\n-----\n") 384 | 385 | # Store all results and mean scores to a csv-file. 
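# One row per evaluated sub-dataset, one column per metric (image_auroc, pixel_auroc),
# plus the mean over all rows.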
386 | result_metric_names = list(result_collect[-1].keys())[1:] 387 | result_dataset_names = [results["dataset_name"] for results in result_collect] 388 | result_scores = [list(results.values())[1:] for results in result_collect] 389 | utils.compute_and_store_final_results( 390 | run_save_path, 391 | result_scores, 392 | column_names=result_metric_names, 393 | row_names=result_dataset_names, 394 | ) 395 | 396 | if __name__ == "__main__": 397 | logging.basicConfig(level=logging.INFO) 398 | LOGGER.info("Command line arguments: {}".format(" ".join(sys.argv))) 399 | args = parse_args() 400 | run(args) 401 | -------------------------------------------------------------------------------- /src/softpatch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import tqdm 5 | 6 | import torch 7 | import common 8 | import sampler 9 | import multi_variate_gaussian 10 | 11 | from sklearn.neighbors import LocalOutlierFactor 12 | import backbones 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | # from torch_cluster import graclus_cluster 17 | 18 | LOGGER = logging.getLogger(__name__) 19 | 20 | 21 | class SoftPatch(torch.nn.Module): 22 | def __init__(self, device): 23 | super(SoftPatch, self).__init__() 24 | self.device = device 25 | 26 | def load( 27 | self, 28 | backbone, 29 | device, 30 | input_shape, 31 | layers_to_extract_from=("layer2", "layer2"), 32 | pretrain_embed_dimension=1024, 33 | target_embed_dimension=1024, 34 | patchsize=3, 35 | patchstride=1, 36 | anomaly_score_num_nn=1, 37 | featuresampler=sampler.ApproximateGreedyCoresetSampler(percentage=0.1, device=torch.device("cuda")), 38 | nn_method=common.FaissNN(False, 4), 39 | lof_k=5, 40 | threshold=0.15, 41 | weight_method="lof", 42 | soft_weight_flag=True, 43 | **kwargs, 44 | ): 45 | self.backbone = backbone.to(device) 46 | self.layers_to_extract_from = layers_to_extract_from 47 | self.input_shape = input_shape 48 | 49 | self.device = device 50 | self.patch_maker = PatchMaker(patchsize, stride=patchstride) 51 | 52 | self.forward_modules = torch.nn.ModuleDict({}) 53 | 54 | feature_aggregator = common.NetworkFeatureAggregator( 55 | self.backbone, self.layers_to_extract_from, self.device 56 | ) 57 | feature_dimensions = feature_aggregator.feature_dimensions(input_shape) 58 | self.forward_modules["feature_aggregator"] = feature_aggregator 59 | 60 | preprocessing = common.Preprocessing( 61 | feature_dimensions, pretrain_embed_dimension 62 | ) 63 | self.forward_modules["preprocessing"] = preprocessing 64 | 65 | self.target_embed_dimension = target_embed_dimension 66 | preadapt_aggregator = common.Aggregator( 67 | target_dim=target_embed_dimension 68 | ) 69 | 70 | _ = preadapt_aggregator.to(self.device) 71 | 72 | self.forward_modules["preadapt_aggregator"] = preadapt_aggregator 73 | 74 | self.anomaly_scorer = common.NearestNeighbourScorer( 75 | n_nearest_neighbours=anomaly_score_num_nn, nn_method=nn_method 76 | ) 77 | 78 | self.anomaly_segmentor = common.RescaleSegmentor( 79 | device=self.device, target_size=input_shape[-2:] 80 | ) 81 | 82 | self.featuresampler = featuresampler 83 | 84 | ############SoftPatch ########## 85 | self.featuresampler = sampler.WeightedGreedyCoresetSampler(featuresampler.percentage, 86 | featuresampler.device) 87 | self.patch_weight = None 88 | self.feature_shape = [] 89 | self.lof_k = lof_k 90 | self.threshold = threshold 91 | self.coreset_weight = None 92 | self.weight_method = weight_method 93 | 
self.soft_weight_flag = soft_weight_flag 94 | 95 | def embed(self, data): 96 | if isinstance(data, torch.utils.data.DataLoader): 97 | features = [] 98 | for image in data: 99 | if isinstance(image, dict): 100 | image = image["image"] 101 | with torch.no_grad(): 102 | input_image = image.to(torch.float).to(self.device) 103 | features.append(self._embed(input_image)) 104 | return features 105 | return self._embed(data) 106 | 107 | def _embed(self, images, detach=True, provide_patch_shapes=False): 108 | """Returns feature embeddings for images.""" 109 | 110 | def _detach(features): 111 | if detach: 112 | return [x.detach().cpu().numpy() for x in features] 113 | return features 114 | 115 | _ = self.forward_modules["feature_aggregator"].eval() 116 | with torch.no_grad(): 117 | features = self.forward_modules["feature_aggregator"](images) 118 | 119 | features = [features[layer] for layer in self.layers_to_extract_from] 120 | 121 | features = [ 122 | self.patch_maker.patchify(x, return_spatial_info=True) for x in features 123 | ] 124 | patch_shapes = [x[1] for x in features] 125 | features = [x[0] for x in features] 126 | ref_num_patches = patch_shapes[0] 127 | 128 | for i in range(1, len(features)): 129 | _features = features[i] 130 | patch_dims = patch_shapes[i] 131 | 132 | _features = _features.reshape( 133 | _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:] 134 | ) 135 | _features = _features.permute(0, -3, -2, -1, 1, 2) 136 | perm_base_shape = _features.shape 137 | _features = _features.reshape(-1, *_features.shape[-2:]) 138 | _features = F.interpolate( 139 | _features.unsqueeze(1), 140 | size=(ref_num_patches[0], ref_num_patches[1]), 141 | mode="bilinear", 142 | align_corners=False, 143 | ) 144 | _features = _features.squeeze(1) 145 | _features = _features.reshape( 146 | *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1] 147 | ) 148 | _features = _features.permute(0, -2, -1, 1, 2, 3) 149 | _features = _features.reshape(len(_features), -1, *_features.shape[-3:]) 150 | features[i] = _features 151 | features = [x.reshape(-1, *x.shape[-3:]) for x in features] 152 | 153 | # As different feature backbones & patching provide differently 154 | # sized features, these are brought into the correct form here. 155 | features = self.forward_modules["preprocessing"](features) 156 | features = self.forward_modules["preadapt_aggregator"](features) 157 | 158 | if provide_patch_shapes: 159 | return _detach(features), patch_shapes 160 | return _detach(features) 161 | 162 | def fit(self, training_data): 163 | """ 164 | This function computes the embeddings of the training data and fills the 165 | memory bank of SPADE. 
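        Concretely, _fill_memory_bank() embeds every training patch, scores each patch
        position with the configured weight_method (LOF by default), excludes the noisiest
        fraction (self.threshold) from the sampling pool, runs weighted greedy coreset
        sampling on the remaining patches, and fits the nearest-neighbour scorer; the weights
        of the kept coreset patches are reused at test time to soft-weight the anomaly scores.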
166 | """ 167 | self._fill_memory_bank(training_data) 168 | 169 | def _fill_memory_bank(self, input_data): 170 | """Computes and sets the support features for SPADE.""" 171 | _ = self.forward_modules.eval() 172 | 173 | def _image_to_features(input_image): 174 | with torch.no_grad(): 175 | input_image = input_image.to(torch.float).to(self.device) 176 | return self._embed(input_image) 177 | 178 | features = [] 179 | with tqdm.tqdm( 180 | input_data, desc="Computing support features...", leave=True 181 | ) as data_iterator: 182 | for image in data_iterator: 183 | if isinstance(image, dict): 184 | image = image["image"] 185 | features.append(_image_to_features(image)) 186 | 187 | features = np.concatenate(features, axis=0) 188 | 189 | with torch.no_grad(): 190 | # pdb.set_trace() 191 | self.feature_shape = self._embed(image.to(torch.float).to(self.device), provide_patch_shapes=True)[1][0] 192 | patch_weight = self._compute_patch_weight(features) 193 | 194 | # normalization 195 | # patch_weight = (patch_weight - patch_weight.quantile(0.5, dim=1, keepdim=True)).reshape(-1) + 1 196 | 197 | patch_weight = patch_weight.reshape(-1) 198 | threshold = torch.quantile(patch_weight, 1 - self.threshold) 199 | sampling_weight = torch.where(patch_weight > threshold, 0, 1) 200 | self.featuresampler.set_sampling_weight(sampling_weight) 201 | self.patch_weight = patch_weight.clamp(min=0) 202 | 203 | sample_features, sample_indices = self.featuresampler.run(features) 204 | features = sample_features 205 | self.coreset_weight = self.patch_weight[sample_indices].cpu().numpy() 206 | 207 | self.anomaly_scorer.fit(detection_features=[features]) 208 | 209 | def _compute_patch_weight(self, features: np.ndarray): 210 | if isinstance(features, np.ndarray): 211 | features = torch.from_numpy(features) 212 | 213 | reduced_features = self.featuresampler._reduce_features(features) 214 | patch_features = \ 215 | reduced_features.reshape(-1, self.feature_shape[0]*self.feature_shape[1], reduced_features.shape[-1]) 216 | 217 | # if aligned: 218 | # codebook = patch_features[0] 219 | # assign = [] 220 | # for i in range(1, patch_features.shape[0]): 221 | # dist = torch.cdist(codebook, patch_features[i]).cpu().numpy() 222 | # row_ind, col_ind = linear_assignment(dist) 223 | # assign.append(col_ind) 224 | # patch_features[i]=torch.index_select(patch_features[i], 0, torch.from_numpy(col_ind).to(self.device)) 225 | 226 | patch_features = patch_features.permute(1, 0, 2) 227 | 228 | if self.weight_method == "lof": 229 | patch_weight = self._compute_lof(self.lof_k, patch_features).transpose(-1, -2) 230 | elif self.weight_method == "lof_gpu": 231 | patch_weight = self._compute_lof_gpu(self.lof_k, patch_features).transpose(-1, -2) 232 | elif self.weight_method == "nearest": 233 | patch_weight = self._compute_nearest_distance(patch_features).transpose(-1, -2) 234 | patch_weight = patch_weight + 1 235 | elif self.weight_method == "gaussian": 236 | gaussian = multi_variate_gaussian.MultiVariateGaussian(patch_features.shape[2], patch_features.shape[0]) 237 | stats = gaussian.fit(patch_features) 238 | patch_weight = self._compute_distance_with_gaussian(patch_features, stats).transpose(-1, -2) 239 | patch_weight = patch_weight + 1 240 | else: 241 | raise ValueError("Unexpected weight method") 242 | 243 | # if aligned: 244 | # patch_weight = patch_weight.cpu().numpy() 245 | # for i in range(0, patch_weight.shape[0]): 246 | # patch_weight[i][assign[i]] = patch_weight[i] 247 | # patch_weight = torch.from_numpy(patch_weight).to(self.device) 248 | 
249 | return patch_weight 250 | 251 | def _compute_distance_with_gaussian(self, embedding: torch.Tensor, stats: [torch.Tensor]) -> torch.Tensor: 252 | """ 253 | Args: 254 | embedding (Tensor): Embedding Vector 255 | stats (List[Tensor]): Mean and Covariance Matrix of the multivariate Gaussian distribution 256 | 257 | Returns: 258 | Anomaly score of a test image via mahalanobis distance. 259 | """ 260 | # patch, batch, channel = embedding.shape 261 | embedding = embedding.permute(1, 2, 0) 262 | 263 | mean, inv_covariance = stats 264 | delta = (embedding - mean).permute(2, 0, 1) 265 | 266 | distances = (torch.matmul(delta, inv_covariance) * delta).sum(2) 267 | distances = torch.sqrt(distances) 268 | 269 | return distances 270 | 271 | def _compute_nearest_distance(self, embedding: torch.Tensor) -> torch.Tensor: 272 | patch, batch, _ = embedding.shape 273 | 274 | x_x = (embedding ** 2).sum(dim=-1, keepdim=True).expand(patch, batch, batch) 275 | dist_mat = (x_x + x_x.transpose(-1, -2) - 2 * embedding.matmul(embedding.transpose(-1, -2))).abs() ** 0.5 276 | nearest_distance = torch.topk(dist_mat, dim=-1, largest=False, k=2)[0].sum(dim=-1) # 277 | # nearest_distance = nearest_distance.transpose(0, 1).reshape(batch * patch) 278 | return nearest_distance 279 | 280 | def _compute_lof(self, k, embedding: torch.Tensor) -> torch.Tensor: 281 | patch, batch, _ = embedding.shape # 784x219x128 282 | clf = LocalOutlierFactor(n_neighbors=int(k), metric='l2') 283 | scores = torch.zeros(size=(patch, batch), device=embedding.device) 284 | for i in range(patch): 285 | clf.fit(embedding[i].cpu()) 286 | scores[i] = torch.Tensor(- clf.negative_outlier_factor_) 287 | # scores[i] = scores[i] / scores[i].mean() # normalization 288 | # embedding = embedding.reshape(patch*batch, channel) 289 | # clf.fit(embedding.cpu()) 290 | # scores = torch.Tensor(- clf.negative_outlier_factor_) 291 | # scores = scores.reshape(patch, batch) 292 | return scores 293 | 294 | def _compute_lof_gpu(self, k, embedding: torch.Tensor) -> torch.Tensor: 295 | """ 296 | GPU support 297 | """ 298 | 299 | patch, batch, _ = embedding.shape 300 | 301 | # calculate distance 302 | x_x = (embedding ** 2).sum(dim=-1, keepdim=True).expand(patch, batch, batch) 303 | dist_mat = (x_x + x_x.transpose(-1, -2) - 2 * embedding.matmul(embedding.transpose(-1, -2))).abs() ** 0.5 + 1e-6 304 | 305 | # find neighborhoods 306 | top_k_distance_mat, top_k_index = torch.topk(dist_mat, dim=-1, largest=False, k=k + 1) 307 | top_k_distance_mat, top_k_index = top_k_distance_mat[:, :, 1:], top_k_index[:, :, 1:] 308 | k_distance_value_mat = top_k_distance_mat[:, :, -1] 309 | 310 | # calculate reachability distance 311 | reach_dist_mat = torch.max(dist_mat, k_distance_value_mat.unsqueeze(2).expand(patch, batch, batch) 312 | .transpose(-1, -2)) # Transposing is important 313 | top_k_index_hot = torch.zeros(size=dist_mat.shape, device=top_k_index.device).scatter_(-1, top_k_index, 1) 314 | 315 | # Local reachability density 316 | lrd_mat = k / (top_k_index_hot * reach_dist_mat).sum(dim=-1) 317 | 318 | # calculate local outlier factor 319 | lof_mat = ((lrd_mat.unsqueeze(2).expand(patch, batch, batch).transpose(-1, -2) * top_k_index_hot).sum( 320 | dim=-1) / k) / lrd_mat 321 | return lof_mat 322 | 323 | 324 | def _chunk_lof(self, k, embedding: torch.Tensor) -> torch.Tensor: 325 | width, height, batch, channel = embedding.shape 326 | chunk_size = 2 327 | 328 | new_width, new_height = int(width / chunk_size), int(height / chunk_size) 329 | new_patch = new_width * new_height 330 | 
new_batch = batch * chunk_size * chunk_size 331 | 332 | split_width = torch.stack(embedding.split(chunk_size, dim=0), dim=0) 333 | split_height = torch.stack(split_width.split(chunk_size, dim=1 + 1), dim=1) 334 | 335 | new_embedding = split_height.view(new_patch, new_batch, channel) 336 | lof_mat = self._compute_lof(k, new_embedding) 337 | chunk_lof_mat = lof_mat.reshape(new_width, new_height, chunk_size, chunk_size, batch) 338 | chunk_lof_mat = chunk_lof_mat.transpose(1, 2).reshape(width, height, batch) 339 | return chunk_lof_mat 340 | 341 | 342 | def predict(self, data): 343 | if isinstance(data, torch.utils.data.DataLoader): 344 | return self._predict_dataloader(data) 345 | return self._predict(data) 346 | 347 | def _predict_dataloader(self, dataloader): 348 | """This function provides anomaly scores/maps for full dataloaders.""" 349 | _ = self.forward_modules.eval() 350 | 351 | scores = [] 352 | masks = [] 353 | labels_gt = [] 354 | masks_gt = [] 355 | with tqdm.tqdm(dataloader, desc="Inferring...", leave=True) as data_iterator: 356 | for image in data_iterator: 357 | if isinstance(image, dict): 358 | labels_gt.extend(image["is_anomaly"].numpy().tolist()) 359 | masks_gt.extend(image["mask"].numpy().tolist()) 360 | image = image["image"] 361 | _scores, _masks = self._predict(image) 362 | for score, mask in zip(_scores, _masks): 363 | scores.append(score) 364 | masks.append(mask) 365 | return scores, masks, labels_gt, masks_gt 366 | 367 | def _predict(self, images): 368 | """Infer score and mask for a batch of images.""" 369 | images = images.to(torch.float).to(self.device) 370 | _ = self.forward_modules.eval() 371 | 372 | batchsize = images.shape[0] 373 | with torch.no_grad(): 374 | features, patch_shapes = self._embed(images, provide_patch_shapes=True) 375 | features = np.asarray(features) 376 | 377 | image_scores, _, indices = self.anomaly_scorer.predict([features]) 378 | if self.soft_weight_flag: 379 | indices = indices.squeeze() 380 | # indices = torch.tensor(indices).to(self.device) 381 | weight = np.take(self.coreset_weight, axis=0, indices=indices) 382 | 383 | image_scores = image_scores * weight 384 | # image_scores = weight 385 | 386 | patch_scores = image_scores 387 | 388 | image_scores = self.patch_maker.unpatch_scores( 389 | image_scores, batchsize=batchsize 390 | ) 391 | image_scores = image_scores.reshape(*image_scores.shape[:2], -1) 392 | image_scores = self.patch_maker.score(image_scores) 393 | 394 | patch_scores = self.patch_maker.unpatch_scores( 395 | patch_scores, batchsize=batchsize 396 | ) 397 | scales = patch_shapes[0] 398 | patch_scores = patch_scores.reshape(batchsize, scales[0], scales[1]) 399 | 400 | masks = self.anomaly_segmentor.convert_to_segmentation(patch_scores) 401 | 402 | return [score for score in image_scores], [mask for mask in masks] 403 | 404 | @staticmethod 405 | def _params_file(filepath, prepend=""): 406 | return os.path.join(filepath, prepend + "params.pkl") 407 | 408 | def save_to_path(self, save_path: str, prepend: str = "") -> None: 409 | LOGGER.info("Saving data.") 410 | self.anomaly_scorer.save( 411 | save_path, save_features_separately=False, prepend=prepend 412 | ) 413 | params = { 414 | "backbone.name": self.backbone.name, 415 | "layers_to_extract_from": self.layers_to_extract_from, 416 | "input_shape": self.input_shape, 417 | "pretrain_embed_dimension": self.forward_modules[ 418 | "preprocessing" 419 | ].output_dim, 420 | "target_embed_dimension": self.forward_modules[ 421 | "preadapt_aggregator" 422 | ].target_dim, 423 | "patchsize": 
self.patch_maker.patchsize, 424 | "patchstride": self.patch_maker.stride, 425 | "anomaly_scorer_num_nn": self.anomaly_scorer.n_nearest_neighbours, 426 | } 427 | with open(self._params_file(save_path, prepend), "wb") as save_file: 428 | pickle.dump(params, save_file, pickle.HIGHEST_PROTOCOL) 429 | 430 | def load_from_path( 431 | self, 432 | load_path: str, 433 | device: torch.device, 434 | nn_method: common.FaissNN(False, 4), 435 | prepend: str = "", 436 | ) -> None: 437 | LOGGER.info("Loading and initializing.") 438 | with open(self._params_file(load_path, prepend), "rb") as load_file: 439 | params = pickle.load(load_file) 440 | params["backbone"] = backbones.load( 441 | params["backbone.name"] 442 | ) 443 | params["backbone"].name = params["backbone.name"] 444 | del params["backbone.name"] 445 | self.load(**params, device=device, nn_method=nn_method) 446 | 447 | self.anomaly_scorer.load(load_path, prepend) 448 | 449 | 450 | # Image handling classes. 451 | class PatchMaker: 452 | def __init__(self, patchsize, stride=None): 453 | self.patchsize = patchsize 454 | self.stride = stride 455 | 456 | def patchify(self, features, return_spatial_info=False): 457 | """Convert a tensor into a tensor of respective patches. 458 | Args: 459 | x: [torch.Tensor, bs x c x w x h] 460 | Returns: 461 | x: [torch.Tensor, bs * w//stride * h//stride, c, patchsize, 462 | patchsize] 463 | """ 464 | padding = int((self.patchsize - 1) / 2) 465 | unfolder = torch.nn.Unfold( 466 | kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1 467 | ) 468 | unfolded_features = unfolder(features) 469 | number_of_total_patches = [] 470 | for side in features.shape[-2:]: 471 | n_patches = ( 472 | side + 2 * padding - 1 * (self.patchsize - 1) - 1 473 | ) / self.stride + 1 474 | number_of_total_patches.append(int(n_patches)) 475 | unfolded_features = unfolded_features.reshape( 476 | *features.shape[:2], self.patchsize, self.patchsize, -1 477 | ) 478 | unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) 479 | 480 | if return_spatial_info: 481 | return unfolded_features, number_of_total_patches 482 | return unfolded_features 483 | 484 | def unpatch_scores(self, patch_scores, batchsize): 485 | return patch_scores.reshape(batchsize, -1, *patch_scores.shape[1:]) 486 | 487 | def score(self, image_scores): 488 | was_numpy = False 489 | if isinstance(image_scores, np.ndarray): 490 | was_numpy = True 491 | image_scores = torch.from_numpy(image_scores) 492 | while image_scores.ndim > 1: 493 | image_scores = torch.max(image_scores, dim=-1).values 494 | if was_numpy: 495 | return image_scores.numpy() 496 | return image_scores 497 | --------------------------------------------------------------------------------
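For orientation, the listing above is driven by main.py roughly as in the sketch below. The backbone name, layer pair, image size and the dataloaders passed in are illustrative assumptions rather than values fixed by this repository; only the module, class and argument names are taken from the sources shown above.

import torch

import backbones   # src/backbones.py
import common      # src/common.py
import sampler     # src/sampler.py
import softpatch   # src/softpatch.py


def run_softpatch(train_dataloader, test_dataloader, device=torch.device("cuda:0")):
    """Minimal SoftPatch round-trip; the dataloaders are assumed to yield the
    dict batches produced by src/datasets (image / is_anomaly / mask keys)."""
    backbone = backbones.load("wideresnet50")          # any key from _BACKBONES
    nn_method = common.FaissNN(False, 4)               # (faiss_on_gpu, faiss_num_workers)
    coreset_sampler = sampler.ApproximateGreedyCoresetSampler(0.1, device)

    model = softpatch.SoftPatch(device)
    model.load(
        backbone=backbone,
        layers_to_extract_from=["layer2", "layer3"],   # assumed layer pair
        device=device,
        input_shape=(3, 224, 224),                     # assumed 224x224 inputs
        featuresampler=coreset_sampler,
        nn_method=nn_method,
        lof_k=6,                                       # load() expects lower-case lof_k
        threshold=0.15,
        weight_method="lof",
        soft_weight_flag=True,
    )

    # Fill the denoised coreset from the (possibly noise-contaminated) training data,
    # then score the test set: image-level scores, pixel-level maps, ground truth.
    model.fit(train_dataloader)
    scores, masks, labels_gt, masks_gt = model.predict(test_dataloader)
    return scores, masks, labels_gt, masks_gt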