├── patchcore ├── utils │ ├── __pycache__ │ │ ├── PU.cpython-38.pyc │ │ ├── eval.cpython-38.pyc │ │ ├── mlp.cpython-38.pyc │ │ ├── utils.cpython-38.pyc │ │ ├── dataset.cpython-38.pyc │ │ └── features.cpython-38.pyc │ ├── set_mtd.py │ ├── dataset.py │ └── utils.py ├── components │ ├── filters │ │ ├── __init__.py │ │ └── blur.py │ ├── freia │ │ ├── modules │ │ │ ├── __pycache__ │ │ │ │ ├── base.cpython-37.pyc │ │ │ │ ├── base.cpython-38.pyc │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── all_in_one_block.cpython-37.pyc │ │ │ │ └── all_in_one_block.cpython-38.pyc │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── all_in_one_block.py │ │ ├── framework │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── sequence_inn.cpython-37.pyc │ │ │ │ └── sequence_inn.cpython-38.pyc │ │ │ ├── __init__.py │ │ │ └── sequence_inn.py │ │ ├── __init__.py │ │ └── README.md │ ├── sampling │ │ ├── __init__.py │ │ └── k_center_greedy.py │ ├── feature_extractors │ │ ├── __init__.py │ │ └── feature_extractor.py │ ├── stats │ │ ├── __init__.py │ │ ├── kde.py │ │ └── multi_variate_gaussian.py │ ├── base │ │ ├── __init__.py │ │ ├── dynamic_module.py │ │ └── anomaly_module.py │ ├── dimensionality_reduction │ │ ├── __init__.py │ │ ├── pca.py │ │ └── random_projection.py │ └── __init__.py ├── pre_processing │ ├── transforms │ │ ├── __init__.py │ │ └── custom.py │ ├── __init__.py │ ├── pre_process.py │ └── tiler.py ├── README.md ├── metrics │ ├── min_max.py │ ├── collection.py │ ├── optimal_f1.py │ ├── anomaly_score_distribution.py │ ├── adaptive_threshold.py │ ├── auroc.py │ ├── plotting_utils.py │ ├── __init__.py │ ├── aupr.py │ └── aupro.py ├── anomaly_map.py ├── lightning_model.py ├── run.py └── torch_model.py ├── run_unet_v2.sh ├── run_ad.sh ├── preprocess ├── polygon2mask.py └── preprocess.py ├── dataset_2d_sup.py ├── README.md ├── train_normal_unet.py └── finetune_cnn_coord.py /patchcore/utils/__pycache__/PU.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/utils/__pycache__/PU.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/filters/__init__.py: -------------------------------------------------------------------------------- 1 | """Implements filters used by models.""" 2 | 3 | from .blur import GaussianBlur2d 4 | 5 | __all__ = ["GaussianBlur2d"] 6 | -------------------------------------------------------------------------------- /patchcore/utils/__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/utils/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/utils/__pycache__/mlp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/utils/__pycache__/mlp.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/utils/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/utils/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/utils/__pycache__/features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/utils/__pycache__/features.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/modules/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/modules/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/framework/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/framework/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/framework/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/framework/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/framework/__pycache__/sequence_inn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/framework/__pycache__/sequence_inn.cpython-37.pyc 
-------------------------------------------------------------------------------- /patchcore/components/freia/framework/__pycache__/sequence_inn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/framework/__pycache__/sequence_inn.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__pycache__/all_in_one_block.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/modules/__pycache__/all_in_one_block.cpython-37.pyc -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__pycache__/all_in_one_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oopil/PSAD_logical_anomaly_detection/HEAD/patchcore/components/freia/modules/__pycache__/all_in_one_block.cpython-38.pyc -------------------------------------------------------------------------------- /patchcore/components/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """Sampling methods.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .k_center_greedy import KCenterGreedy 7 | 8 | __all__ = ["KCenterGreedy"] 9 | -------------------------------------------------------------------------------- /patchcore/components/feature_extractors/__init__.py: -------------------------------------------------------------------------------- 1 | """Feature extractors.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .feature_extractor import FeatureExtractor 7 | 8 | __all__ = ["FeatureExtractor"] 9 | -------------------------------------------------------------------------------- /patchcore/pre_processing/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """Anomalib Data Transforms.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .custom import Denormalize, ToNumpy 7 | 8 | __all__ = ["Denormalize", "ToNumpy"] 9 | -------------------------------------------------------------------------------- /patchcore/components/freia/framework/__init__.py: -------------------------------------------------------------------------------- 1 | """Framework.""" 2 | 3 | # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. 
4 | # SPDX-License-Identifier: MIT 5 | # 6 | 7 | from .sequence_inn import SequenceINN 8 | 9 | __all__ = ["SequenceINN"] 10 | -------------------------------------------------------------------------------- /patchcore/components/stats/__init__.py: -------------------------------------------------------------------------------- 1 | """Statistical functions.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .kde import GaussianKDE 7 | from .multi_variate_gaussian import MultiVariateGaussian 8 | 9 | __all__ = ["GaussianKDE", "MultiVariateGaussian"] 10 | -------------------------------------------------------------------------------- /patchcore/pre_processing/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for pre-processing the input before passing to the model.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .pre_process import PreProcessor 7 | from .tiler import Tiler 8 | 9 | __all__ = ["PreProcessor", "Tiler"] 10 | -------------------------------------------------------------------------------- /patchcore/components/base/__init__.py: -------------------------------------------------------------------------------- 1 | """Base classes for all anomaly components.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .anomaly_module import AnomalyModule 7 | from .dynamic_module import DynamicBufferModule 8 | 9 | __all__ = ["AnomalyModule", "DynamicBufferModule"] 10 | -------------------------------------------------------------------------------- /patchcore/components/freia/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Modules.""" 2 | 3 | # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | 7 | from .all_in_one_block import AllInOneBlock 8 | from .base import InvertibleModule 9 | 10 | __all__ = ["AllInOneBlock", "InvertibleModule"] 11 | -------------------------------------------------------------------------------- /patchcore/components/dimensionality_reduction/__init__.py: -------------------------------------------------------------------------------- 1 | """Algorithms for decomposition and dimensionality reduction.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .pca import PCA 7 | from .random_projection import SparseRandomProjection 8 | 9 | __all__ = ["PCA", "SparseRandomProjection"] 10 | -------------------------------------------------------------------------------- /patchcore/components/freia/__init__.py: -------------------------------------------------------------------------------- 1 | """Framework for Easily Invertible Architectures. 2 | 3 | Module to construct invertible networks with pytorch, based on a graph 4 | structure of operations. 5 | 6 | Link to the original repo: https://github.com/VLL-HD/FrEIA 7 | """ 8 | 9 | # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. 
10 | # SPDX-License-Identifier: MIT 11 | # 12 | 13 | from .framework import SequenceINN 14 | from .modules import AllInOneBlock 15 | 16 | __all__ = ["SequenceINN", "AllInOneBlock"] 17 | -------------------------------------------------------------------------------- /patchcore/components/freia/README.md: -------------------------------------------------------------------------------- 1 | # FrEIA 2 | 3 | This sub-package vendors FrEIA modules for use within flow-based algorithms such as CFlow. 4 | 5 | ## Description 6 | 7 | The [FrEIA](https://github.com/VLL-HD/FrEIA) package is currently not available on PyPI to install via pip. The only way to install it is `pip install git+https://github.com/VLL-HD/FrEIA.git`. PyPI, however, does not support installing packages from git links, which would prevent `anomalib` from being published there. To avoid this, `anomalib` contains some of the [FrEIA](https://github.com/VLL-HD/FrEIA) modules to facilitate CFlow training/inference. 8 | -------------------------------------------------------------------------------- /run_unet_v2.sh: -------------------------------------------------------------------------------- 1 | GPU_ID=$1 2 | DIR=$2 3 | gpu_list=(0 4 5 6 7) 4 | idx=0 5 | 6 | SEG_DIR="fss_comparison/${DIR}" 7 | SAVE_DIR="orig_512_seg/${DIR}" 8 | SNAP_DIR="output/${DIR}" 9 | 10 | for obj in "breakfast_box" "screw_bag" "juice_bottle" "splicing_connectors" "pushpins" 11 | do 12 | CUDA_VISIBLE_DEVICES=${gpu_list[$idx]} python train_normal_unet.py \ 13 | --obj_name ${obj} \ 14 | --num_epochs 300 \ 15 | --snapshot_dir ${SAVE_DIR} \ 16 | --save_dir ${SAVE_DIR} \ 17 | --seg_dir ${SEG_DIR} \ 18 | --learning_rate 1e-3 \ 19 | --pretrained False & 20 | 21 | idx=$(($idx+1)) 22 | sleep 1 23 | done 24 | 25 | exit 0 -------------------------------------------------------------------------------- /patchcore/components/__init__.py: -------------------------------------------------------------------------------- 1 | """Components used within the models.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from .base import AnomalyModule, DynamicBufferModule 7 | from .dimensionality_reduction import PCA, SparseRandomProjection 8 | from .feature_extractors import FeatureExtractor 9 | from .filters import GaussianBlur2d 10 | from .sampling import KCenterGreedy 11 | from .stats import GaussianKDE, MultiVariateGaussian 12 | 13 | __all__ = [ 14 | "AnomalyModule", 15 | "DynamicBufferModule", 16 | "PCA", 17 | "SparseRandomProjection", 18 | "FeatureExtractor", 19 | "KCenterGreedy", 20 | "GaussianKDE", 21 | "GaussianBlur2d", 22 | "MultiVariateGaussian", 23 | ] 24 | -------------------------------------------------------------------------------- /run_ad.sh: -------------------------------------------------------------------------------- 1 | GPU_ID=$1 2 | DIR=$2 3 | mtype=$3 4 | stype=$4 5 | gpu_list=(0 4 5 6 7) 6 | idx=0 7 | 8 | SEG_DIR="orig_512_seg/${DIR}" 9 | CUDA_VISIBLE_DEVICES=${GPU_ID} python psad.py \ 10 | --type logical \ 11 | --standardize 1 \ 12 | --avgpool_size 5 \ 13 | --less_data 1 \ 14 | --memory_type ${mtype} \ 15 | --scale_type ${stype} \ 16 | --seg_dir ${SEG_DIR} \ 17 | --save_img 0 \ 18 | --save_csv 0 19 | 20 | CUDA_VISIBLE_DEVICES=${GPU_ID} python psad.py \ 21 | --type structural \ 22 | --standardize 1 \ 23 | --avgpool_size 5 \ 24 | --less_data 1 \ 25 | --memory_type ${mtype} \ 26 | --scale_type ${stype} \ 27 | --seg_dir ${SEG_DIR} \ 28 | --save_img 0 \ 29 | --save_csv 0 30 | 31 | exit 0 32 | 33 | 34 | 
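# Usage sketch (inferred from the positional parameters above; the argument values are illustrative placeholders, not prescribed):
#   bash run_ad.sh 0 my_run <memory_type> <scale_type>
# where $1 = GPU id, $2 = sub-directory under orig_512_seg/, and $3/$4 are forwarded to psad.py as --memory_type/--scale_type.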
-------------------------------------------------------------------------------- /patchcore/README.md: -------------------------------------------------------------------------------- 1 | # Few shot Anomaly Detection using Positive Unlabeled Learning with Cycle Consistency and Co-occurrence Features 2 | ## Getting Started 3 | * To set up the Python environment, use the command below. 4 | ``` 5 | $ conda env create -f environment.yml 6 | ``` 7 | * Download datasets from: 8 | MVTec AD: 9 | MPDD: 10 | MTD: 11 | 12 | Note: If you use MTD, follow the process below: 13 | ``` 14 | $ cd [MTD dataset path] 15 | $ mkdir ./MT 16 | $ cd MT 17 | $ cp utils/set_mtd.py ./ 18 | $ python set_mtd.py 19 | ``` 20 | 21 | ## Train & Evaluation 22 | * To train and evaluate the model, run run.py as follows: 23 | * To run evaluation only (after training), run run.py with --mode test 24 | ``` 25 | $ python run.py --mode train --gpu [gpu id] --datapath [datapath] --dataset [dataset] --category [category] --few [the number of labeled samples] 26 | ``` 27 | 28 | This code is based on the anomaly detection library anomalib: https://github.com/openvinotoolkit/anomalib 29 | -------------------------------------------------------------------------------- /patchcore/metrics/min_max.py: -------------------------------------------------------------------------------- 1 | """Module that tracks the min and max values of the observations in each batch.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torchmetrics import Metric 11 | 12 | 13 | class MinMax(Metric): 14 | """Track the min and max values of the observations in each batch.""" 15 | 16 | def __init__(self, **kwargs): 17 | super().__init__(**kwargs) 18 | self.add_state("min", torch.tensor(float("inf")), persistent=True) # pylint: disable=not-callable 19 | self.add_state("max", torch.tensor(float("-inf")), persistent=True) # pylint: disable=not-callable 20 | 21 | self.min = torch.tensor(float("inf")) # pylint: disable=not-callable 22 | self.max = torch.tensor(float("-inf")) # pylint: disable=not-callable 23 | 24 | # pylint: disable=arguments-differ 25 | def update(self, predictions: Tensor) -> None: # type: ignore 26 | """Update the min and max values.""" 27 | self.max = torch.max(self.max, torch.max(predictions)) 28 | self.min = torch.min(self.min, torch.min(predictions)) 29 | 30 | def compute(self) -> Tuple[Tensor, Tensor]: 31 | """Return min and max values.""" 32 | return self.min, self.max 33 | -------------------------------------------------------------------------------- /patchcore/metrics/collection.py: -------------------------------------------------------------------------------- 1 | """Anomalib Metric Collection.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from torchmetrics import MetricCollection 7 | 8 | 9 | class AnomalibMetricCollection(MetricCollection): 10 | """Extends the MetricCollection class for use in the Anomalib pipeline.""" 11 | 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self._update_called = False 15 | self._threshold = 0.5 16 | 17 | def set_threshold(self, threshold_value): 18 | """Update the threshold value for all metrics that have the threshold attribute.""" 19 | self._threshold = threshold_value 20 | for metric in self.values(): 21 | if hasattr(metric, "threshold"): 22 | metric.threshold = threshold_value 23 | 24 | def update(self, 
*args, **kwargs) -> None: 25 | """Add data to the metrics.""" 26 | super().update(*args, **kwargs) 27 | self._update_called = True 28 | 29 | @property 30 | def update_called(self) -> bool: 31 | """Returns a boolean indicating if the update method has been called at least once.""" 32 | return self._update_called 33 | 34 | @property 35 | def threshold(self) -> float: 36 | """Return the value of the anomaly threshold.""" 37 | return self._threshold 38 | -------------------------------------------------------------------------------- /patchcore/utils/set_mtd.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from shutil import copyfile 4 | import random 5 | 6 | category = os.listdir('../') 7 | category.remove('MT') 8 | 9 | for c in category: 10 | imgs = glob.glob(os.path.join('..', c, '*', '*.jpg')) 11 | random.shuffle(imgs) 12 | gts = glob.glob(os.path.join('..', c, '*', '*.png')) 13 | 14 | if 'Free' in c: 15 | trainset = int(len(imgs)*0.8) 16 | train_imgs = imgs[:trainset] 17 | test_imgs = imgs[trainset:] 18 | if not os.path.exists(os.path.join('.', 'train', 'good')): 19 | os.makedirs(os.path.join('.', 'train', 'good')) 20 | 21 | for i in train_imgs: 22 | copyfile(i, os.path.join('.', 'train', 'good', i.split('/')[-1])) 23 | 24 | if not os.path.exists(os.path.join('.', 'test', 'good')): 25 | os.makedirs(os.path.join('.', 'test', 'good')) 26 | 27 | for i in test_imgs: 28 | copyfile(i, os.path.join('.', 'test', 'good', i.split('/')[-1])) 29 | 30 | else: 31 | if not os.path.exists(os.path.join('.', 'test', c)): 32 | os.makedirs(os.path.join('.', 'test', c)) 33 | 34 | test_imgs = imgs 35 | for i in test_imgs: 36 | copyfile(i, os.path.join('.', 'test', c, i.split('/')[-1])) 37 | 38 | if not os.path.exists(os.path.join('.', 'ground_truth', c)): 39 | os.makedirs(os.path.join('.', 'ground_truth', c)) 40 | 41 | gt_imgs = gts 42 | for i in gt_imgs: 43 | copyfile(i, os.path.join('.', 'ground_truth', c, i.split('/')[-1])) 44 | -------------------------------------------------------------------------------- /patchcore/metrics/optimal_f1.py: -------------------------------------------------------------------------------- 1 | """Implementation of Optimal F1 score based on TorchMetrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | import torch 7 | from torchmetrics import Metric, PrecisionRecallCurve 8 | 9 | 10 | class OptimalF1(Metric): 11 | """Optimal F1 Metric. 12 | 13 | Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the 14 | predicted anomaly scores. 15 | """ 16 | 17 | def __init__(self, num_classes: int, **kwargs): 18 | super().__init__(**kwargs) 19 | 20 | self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes) 21 | 22 | self.threshold: torch.Tensor 23 | 24 | # pylint: disable=arguments-differ 25 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore 26 | """Update the precision-recall curve metric.""" 27 | self.precision_recall_curve.update(preds, target) 28 | 29 | def compute(self) -> torch.Tensor: 30 | """Compute the value of the optimal F1 score. 31 | 32 | Compute the F1 scores while varying the threshold. Store the optimal 33 | threshold as attribute and return the maximum value of the F1 score. 34 | 35 | Returns: 36 | Value of the F1 score at the optimal threshold. 
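(Here F1 = 2 * precision * recall / (precision + recall); a small epsilon is added to the denominator below for numerical stability.)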
37 | """ 38 | precision: torch.Tensor 39 | recall: torch.Tensor 40 | thresholds: torch.Tensor 41 | 42 | precision, recall, thresholds = self.precision_recall_curve.compute() 43 | f1_score = (2 * precision * recall) / (precision + recall + 1e-10) 44 | self.threshold = thresholds[torch.argmax(f1_score)] 45 | optimal_f1_score = torch.max(f1_score) 46 | return optimal_f1_score 47 | 48 | def reset(self) -> None: 49 | """Reset the metric.""" 50 | self.precision_recall_curve.reset() 51 | -------------------------------------------------------------------------------- /patchcore/components/base/dynamic_module.py: -------------------------------------------------------------------------------- 1 | """Dynamic Buffer Module.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from abc import ABC 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class DynamicBufferModule(ABC, nn.Module): 12 | """Torch module that allows loading variables from the state dict even in the case of shape mismatch.""" 13 | 14 | def get_tensor_attribute(self, attribute_name: str) -> Tensor: 15 | """Get attribute of the tensor given the name. 16 | 17 | Args: 18 | attribute_name (str): Name of the tensor 19 | 20 | Raises: 21 | ValueError: `attribute_name` is not a torch Tensor 22 | 23 | Returns: 24 | Tensor: Tensor attribute 25 | """ 26 | attribute = getattr(self, attribute_name) 27 | if isinstance(attribute, Tensor): 28 | return attribute 29 | 30 | raise ValueError(f"Attribute with name '{attribute_name}' is not a torch Tensor") 31 | 32 | def _load_from_state_dict(self, state_dict: dict, prefix: str, *args): 33 | """Resizes the local buffers to match those stored in the state dict. 34 | 35 | Overrides method from parent class. 36 | 37 | Args: 38 | state_dict (dict): State dictionary containing weights 39 | prefix (str): Prefix of the weight file. 
40 | *args: Additional arguments passed through to the parent method. 41 | """ 42 | persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} 43 | local_buffers = {k: v for k, v in persistent_buffers.items() if v is not None} 44 | 45 | for param in local_buffers.keys(): 46 | for key in state_dict.keys(): 47 | if key.startswith(prefix) and key[len(prefix) :].split(".")[0] == param: 48 | if not local_buffers[param].shape == state_dict[key].shape: 49 | attribute = self.get_tensor_attribute(param) 50 | attribute.resize_(state_dict[key].shape) 51 | 52 | super()._load_from_state_dict(state_dict, prefix, *args) 53 | -------------------------------------------------------------------------------- /patchcore/metrics/anomaly_score_distribution.py: -------------------------------------------------------------------------------- 1 | """Module that computes the parameters of the normal data distribution of the training set.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torchmetrics import Metric 11 | 12 | 13 | class AnomalyScoreDistribution(Metric): 14 | """Mean and standard deviation of the anomaly scores of normal training data.""" 15 | 16 | def __init__(self, **kwargs): 17 | super().__init__(**kwargs) 18 | self.anomaly_maps = [] 19 | self.anomaly_scores = [] 20 | 21 | self.add_state("image_mean", torch.empty(0), persistent=True) 22 | self.add_state("image_std", torch.empty(0), persistent=True) 23 | self.add_state("pixel_mean", torch.empty(0), persistent=True) 24 | self.add_state("pixel_std", torch.empty(0), persistent=True) 25 | 26 | self.image_mean = torch.empty(0) 27 | self.image_std = torch.empty(0) 28 | self.pixel_mean = torch.empty(0) 29 | self.pixel_std = torch.empty(0) 30 | 31 | # pylint: disable=arguments-differ 32 | def update( # type: ignore 33 | self, anomaly_scores: Optional[Tensor] = None, anomaly_maps: Optional[Tensor] = None 34 | ) -> None: 35 | """Accumulate anomaly scores and anomaly maps from the current batch.""" 36 | if anomaly_maps is not None: 37 | self.anomaly_maps.append(anomaly_maps) 38 | if anomaly_scores is not None: 39 | self.anomaly_scores.append(anomaly_scores) 40 | 41 | def compute(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 42 | """Compute stats.""" 43 | anomaly_scores = torch.hstack(self.anomaly_scores) 44 | anomaly_scores = torch.log(anomaly_scores) 45 | 46 | self.image_mean = anomaly_scores.mean() 47 | self.image_std = anomaly_scores.std() 48 | 49 | if self.anomaly_maps: 50 | anomaly_maps = torch.vstack(self.anomaly_maps) 51 | anomaly_maps = torch.log(anomaly_maps).cpu() 52 | 53 | self.pixel_mean = anomaly_maps.mean(dim=0).squeeze() 54 | self.pixel_std = anomaly_maps.std(dim=0).squeeze() 55 | 56 | return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std 57 | -------------------------------------------------------------------------------- /patchcore/metrics/adaptive_threshold.py: -------------------------------------------------------------------------------- 1 | """Implementation of the adaptive threshold metric based on TorchMetrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | import torch 7 | from torchmetrics import Metric, PrecisionRecallCurve 8 | import pdb 9 | 10 | 11 | class AdaptiveThreshold(Metric): 12 | """Adaptive Threshold Metric. 13 | 14 | Compute the threshold that maximizes the F1 score, based on the true labels and the 15 | predicted anomaly scores. 
16 | """ 17 | 18 | def __init__(self, default_value: float = 0.5, **kwargs): 19 | super().__init__(**kwargs) 20 | 21 | self.precision_recall_curve = PrecisionRecallCurve(num_classes=1) 22 | self.add_state("value", default=torch.tensor(default_value), persistent=True) # pylint: disable=not-callable 23 | self.value = torch.tensor(default_value) # pylint: disable=not-callable 24 | 25 | # pylint: disable=arguments-differ 26 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore 27 | """Update the precision-recall curve metric.""" 28 | self.precision_recall_curve.update(preds, target) 29 | 30 | def compute(self) -> torch.Tensor: 31 | """Compute the threshold that yields the optimal F1 score. 32 | 33 | Compute the F1 scores while varying the threshold. Store the optimal 34 | threshold as attribute and return the maximum value of the F1 score. 35 | 36 | Returns: 37 | Value of the F1 score at the optimal threshold. 38 | """ 39 | precision: torch.Tensor 40 | recall: torch.Tensor 41 | thresholds: torch.Tensor 42 | # pdb.set_trace() 43 | precision, recall, thresholds = self.precision_recall_curve.compute() 44 | f1_score = (2 * precision * recall) / (precision + recall + 1e-10) 45 | if thresholds.dim() == 0: 46 | # special case where recall is 1.0 even for the highest threshold. 47 | # In this case 'thresholds' will be scalar. 48 | self.value = thresholds 49 | else: 50 | self.value = thresholds[torch.argmax(f1_score)] 51 | # print('BBBBBBBBBBBBBBBBBB', self.value) 52 | 53 | return self.value 54 | 55 | def reset(self) -> None: 56 | """Reset the metric.""" 57 | self.precision_recall_curve.reset() 58 | -------------------------------------------------------------------------------- /preprocess/polygon2mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import json 5 | import shutil 6 | import numpy as np 7 | from glob import glob 8 | from PIL import Image,ImageDraw 9 | 10 | # root_dir = '/media/NAS/nas_187/soopil/project/battery/sample_crack/data/imgs_train' 11 | # src_dir = f"{root_dir}/imgs" 12 | shot = "fewshot" 13 | trg_dir = f"./{shot}_mask" 14 | # trg_dir = f"./{shot}_mask_vis" 15 | 16 | if not os.path.exists(trg_dir): 17 | os.makedirs(trg_dir) 18 | 19 | label_fpath = f"./{shot}/ano_vgg_v6.json" 20 | # label_fpath = f"{root_dir}/{label_fname}" 21 | print(label_fpath) 22 | with open(label_fpath, "r") as json_file: 23 | data = json.load(json_file) 24 | 25 | 26 | palette = [0, 0, 0, 204, 241, 227, 112, 142, 18, 254, 8, 23, 207, 149, 84, 202, 24, 214, 27 | 230, 192, 37, 241, 80, 68, 74, 127, 0, 2, 81, 216, 24, 240, 129, 20, 215, 125, 161, 31, 204, 28 | 254, 52, 116, 117, 198, 203, 4, 41, 68, 127, 252, 61, 21, 3, 142, 40, 10, 159, 241, 61, 36, 29 | 14, 175, 77, 144, 61, 115, 131, 79, 97, 109, 177, 163, 58, 198, 140, 17, 235, 168, 47, 128, 91, 30 | 238, 103, 45, 124, 35, 228, 101, 48, 232, 74, 124, 114, 78, 49, 30, 35, 167, 27, 137, 231, 47, 31 | 235, 32, 39, 56, 112, 32, 62, 173, 79, 86, 44, 201, 77, 47, 217, 246, 223, 57, ] 32 | # Pad with zeroes to 768 values, i.e. 
256 RGB colours 33 | palette = palette + [0]*(768-len(palette)) 34 | 35 | # pdb.set_trace() 36 | imgs = list(data.keys()) 37 | for img in imgs: 38 | print(img) 39 | img_path = f"./{shot}/{img}" 40 | im = Image.open(img_path).convert("RGB") 41 | (w,h) = im.size 42 | 43 | d = data[img] 44 | num_objs = len(d['regions']) 45 | mask = Image.new('L', (w,h), 0) 46 | for idx_obj in reversed(range(num_objs)): 47 | x_coords = d['regions'][str(idx_obj)]['shape_attributes']['all_points_x'] 48 | y_coords = d['regions'][str(idx_obj)]['shape_attributes']['all_points_y'] 49 | coords = [(x,y) for x,y in zip(x_coords, y_coords)] 50 | label = int(d['regions'][str(idx_obj)]['region_attributes']['label']) 51 | # print(len(coords), label) 52 | ImageDraw.Draw(mask).polygon(coords, fill=(label)*1) 53 | print(idx_obj,label) 54 | # ImageDraw.Draw(mask).polygon(coords, outline=idx_obj+1, fill=idx_obj+1) 55 | mask = np.array(mask) 56 | print(mask.shape, np.unique(mask)) 57 | 58 | opath = f"{trg_dir}/{img}" 59 | # new_im = Image.fromarray(mask) 60 | # new_im.save(opath) 61 | pi = Image.fromarray(mask,'P') 62 | pi.putpalette(palette) 63 | # pi.show() 64 | pi.save(opath) 65 | 66 | -------------------------------------------------------------------------------- /patchcore/metrics/auroc.py: -------------------------------------------------------------------------------- 1 | """Implementation of AUROC metric based on TorchMetrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | from matplotlib.figure import Figure 10 | from torch import Tensor 11 | from torchmetrics import ROC 12 | from torchmetrics.functional import auc 13 | 14 | from .plotting_utils import plot_figure 15 | 16 | 17 | class AUROC(ROC): 18 | """Area under the ROC curve.""" 19 | 20 | def compute(self) -> Tensor: 21 | """First compute ROC curve, then compute area under the curve. 22 | 23 | Returns: 24 | Tensor: Value of the AUROC metric 25 | """ 26 | tpr: Tensor 27 | fpr: Tensor 28 | 29 | fpr, tpr = self._compute() 30 | # TODO: use stable sort after upgrading to pytorch 1.9.x (https://github.com/openvinotoolkit/anomalib/issues/92) 31 | if not (torch.all(fpr.diff() <= 0) or torch.all(fpr.diff() >= 0)): 32 | return auc(fpr, tpr, reorder=True) # only reorder if fpr is not increasing or decreasing 33 | return auc(fpr, tpr) 34 | 35 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 36 | """Update state with new values. 37 | 38 | Need to flatten new values as ROC expects them in this format for binary classification. 39 | 40 | Args: 41 | preds (Tensor): predictions of the model 42 | target (Tensor): ground truth targets 43 | """ 44 | super().update(preds.flatten(), target.flatten()) 45 | 46 | def _compute(self) -> Tuple[Tensor, Tensor]: 47 | """Compute fpr/tpr value pairs. 48 | 49 | Returns: 50 | Tuple containing Tensors for fpr and tpr 51 | """ 52 | tpr: Tensor 53 | fpr: Tensor 54 | fpr, tpr, _thresholds = super().compute() 55 | return (fpr, tpr) 56 | 57 | def generate_figure(self) -> Tuple[Figure, str]: 58 | """Generate a figure containing the ROC curve, the baseline and the AUROC. 
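The dashed diagonal plotted below is the chance baseline, corresponding to AUROC = 0.5.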
59 | 60 | Returns: 61 | Tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging 62 | """ 63 | fpr, tpr = self._compute() 64 | auroc = self.compute() 65 | 66 | xlim = (0.0, 1.0) 67 | ylim = (0.0, 1.0) 68 | xlabel = "False Positive Rate" 69 | ylabel = "True Positive Rate" 70 | loc = "lower right" 71 | title = "ROC" 72 | 73 | fig, axis = plot_figure(fpr, tpr, auroc, xlim, ylim, xlabel, ylabel, loc, title) 74 | 75 | axis.plot( 76 | [0, 1], 77 | [0, 1], 78 | color="navy", 79 | lw=2, 80 | linestyle="--", 81 | figure=fig, 82 | ) 83 | 84 | return fig, title 85 | -------------------------------------------------------------------------------- /patchcore/metrics/plotting_utils.py: -------------------------------------------------------------------------------- 1 | """Helper functions to generate ROC-style plots of various metrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | from matplotlib import pyplot as plt 10 | from matplotlib.axis import Axis 11 | from matplotlib.figure import Figure 12 | from torch import Tensor 13 | 14 | 15 | def plot_figure( 16 | x_vals: Tensor, 17 | y_vals: Tensor, 18 | auc: Tensor, 19 | xlim: Tuple[float, float], 20 | ylim: Tuple[float, float], 21 | xlabel: str, 22 | ylabel: str, 23 | loc: str, 24 | title: str, 25 | sample_points: int = 1000, 26 | ) -> Tuple[Figure, Axis]: 27 | """Generate a simple, ROC-style plot, where x_vals is plotted against y_vals. 28 | 29 | Note that a subsampling is applied if > sample_points are present in x/y, as matplotlib plotting draws 30 | every single plot which takes very long, especially for high-resolution segmentations. 31 | 32 | Args: 33 | x_vals (Tensor): x values to plot 34 | y_vals (Tensor): y values to plot 35 | auc (Tensor): normalized area under the curve spanned by x_vals, y_vals 36 | xlim (Tuple[float, float]): displayed range for x-axis 37 | ylim (Tuple[float, float]): displayed range for y-axis 38 | xlabel (str): label of x axis 39 | ylabel (str): label of y axis 40 | loc (str): string-based legend location, for details see 41 | https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html 42 | title (str): title of the plot 43 | sample_points (int): number of sampling points to subsample x_vals/y_vals with 44 | 45 | Returns: 46 | Tuple[Figure, Axis]: Figure and the contained Axis 47 | """ 48 | fig, axis = plt.subplots() 49 | 50 | x_vals = x_vals.detach().cpu() 51 | y_vals = y_vals.detach().cpu() 52 | 53 | if sample_points < x_vals.size(0): 54 | possible_idx = range(x_vals.size(0)) 55 | interval = len(possible_idx) // sample_points 56 | 57 | idx = [0] # make sure to start at first point 58 | idx.extend(possible_idx[::interval]) 59 | idx.append(possible_idx[-1]) # also include last point 60 | 61 | idx = torch.tensor( 62 | idx, 63 | device=x_vals.device, 64 | ) 65 | x_vals = torch.index_select(x_vals, 0, idx) 66 | y_vals = torch.index_select(y_vals, 0, idx) 67 | 68 | axis.plot( 69 | x_vals, 70 | y_vals, 71 | color="darkorange", 72 | figure=fig, 73 | lw=2, 74 | label=f"AUC: {auc.detach().cpu():0.2f}", 75 | ) 76 | 77 | axis.set_xlim(xlim) 78 | axis.set_ylim(ylim) 79 | axis.set_xlabel(xlabel) 80 | axis.set_ylabel(ylabel) 81 | axis.legend(loc=loc) 82 | axis.set_title(title) 83 | return fig, axis 84 | -------------------------------------------------------------------------------- /patchcore/pre_processing/transforms/custom.py: 
-------------------------------------------------------------------------------- 1 | """Dataset Utils.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import List, Optional, Tuple 7 | 8 | import numpy as np 9 | from torch import Tensor 10 | 11 | 12 | class Denormalize: 13 | """Denormalize Torch Tensor into np image format.""" 14 | 15 | def __init__(self, mean: Optional[List[float]] = None, std: Optional[List[float]] = None): 16 | """Denormalize Torch Tensor into np image format. 17 | 18 | Args: 19 | mean: Mean 20 | std: Standard deviation. 21 | """ 22 | # If no mean and std provided, assign ImageNet values. 23 | if mean is None: 24 | mean = [0.485, 0.456, 0.406] 25 | 26 | if std is None: 27 | std = [0.229, 0.224, 0.225] 28 | 29 | self.mean = Tensor(mean) 30 | self.std = Tensor(std) 31 | 32 | def __call__(self, tensor: Tensor) -> np.ndarray: 33 | """Denormalize the input. 34 | 35 | Args: 36 | tensor (Tensor): Input tensor image (C, H, W) 37 | 38 | Returns: 39 | Denormalized numpy array (H, W, C). 40 | """ 41 | if tensor.dim() == 4: 42 | if tensor.size(0) == 1: 43 | tensor = tensor.squeeze(0) 44 | else: 45 | raise ValueError(f"Tensor has batch size of {tensor.size(0)}. Only single batch is supported.") 46 | 47 | for tnsr, mean, std in zip(tensor, self.mean, self.std): 48 | tnsr.mul_(std).add_(mean) 49 | 50 | array = (tensor * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8) 51 | return array 52 | 53 | def __repr__(self): 54 | """Representational string.""" 55 | return self.__class__.__name__ + "()" 56 | 57 | 58 | class ToNumpy: 59 | """Convert Tensor into Numpy Array.""" 60 | 61 | def __call__(self, tensor: Tensor, dims: Optional[Tuple[int, ...]] = None) -> np.ndarray: 62 | """Convert Tensor into Numpy Array. 63 | 64 | Args: 65 | tensor (Tensor): Tensor to convert. Input tensor in range 0-1. 66 | dims (Optional[Tuple[int, ...]], optional): Convert dimensions from torch to numpy format. 67 | Tuple corresponding to axis permutation from torch tensor to numpy array. Defaults to None. 68 | 69 | Returns: 70 | Converted numpy ndarray. 
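Example (illustrative; the shapes follow the default permutation and squeeze logic below):
    >>> ToNumpy()(torch.rand(1, 3, 4, 4)).shape
    (4, 4, 3)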
71 | """ 72 | # Default support is (C, H, W) or (N, C, H, W) 73 | if dims is None: 74 | dims = (0, 2, 3, 1) if len(tensor.shape) == 4 else (1, 2, 0) 75 | 76 | array = (tensor * 255).permute(dims).cpu().numpy().astype(np.uint8) 77 | 78 | if array.shape[0] == 1: 79 | array = array.squeeze(0) 80 | if array.shape[-1] == 1: 81 | array = array.squeeze(-1) 82 | 83 | return array 84 | 85 | def __repr__(self) -> str: 86 | """Representational string.""" 87 | return self.__class__.__name__ + "()" 88 | -------------------------------------------------------------------------------- /patchcore/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom anomaly evaluation metrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | import importlib 7 | import warnings 8 | from typing import List, Optional, Tuple, Union 9 | 10 | import torchmetrics 11 | from omegaconf import DictConfig, ListConfig 12 | 13 | from .adaptive_threshold import AdaptiveThreshold 14 | from .anomaly_score_distribution import AnomalyScoreDistribution 15 | from .aupr import AUPR 16 | from .aupro import AUPRO 17 | from .auroc import AUROC 18 | from .collection import AnomalibMetricCollection 19 | from .min_max import MinMax 20 | from .optimal_f1 import OptimalF1 21 | 22 | __all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"] 23 | 24 | 25 | def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]: 26 | """Create metric collections based on the config. 27 | 28 | Args: 29 | config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf 30 | 31 | Returns: 32 | AnomalibMetricCollection: Image-level metric collection 33 | AnomalibMetricCollection: Pixel-level metric collection 34 | """ 35 | image_metric_names = config.metrics.image if "image" in config.metrics.keys() else [] 36 | pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else [] 37 | image_metrics = metric_collection_from_names(image_metric_names, "image_") 38 | pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_") 39 | return image_metrics, pixel_metrics 40 | 41 | 42 | def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection: 43 | """Create a metric collection from a list of metric names. 44 | 45 | The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module, 46 | then in TorchMetrics package. 47 | 48 | Args: 49 | metric_names (List[str]): List of metric names to be included in the collection. 50 | prefix (Optional[str]): prefix to assign to the metrics in the collection. 51 | 52 | Returns: 53 | AnomalibMetricCollection: Collection of metrics. 
54 | """ 55 | metrics_module = importlib.import_module("anomalib.utils.metrics") 56 | metrics = AnomalibMetricCollection([], prefix=prefix) 57 | for metric_name in metric_names: 58 | if hasattr(metrics_module, metric_name): 59 | metric_cls = getattr(metrics_module, metric_name) 60 | metrics.add_metrics(metric_cls()) 61 | elif hasattr(torchmetrics, metric_name): 62 | try: 63 | metric_cls = getattr(torchmetrics, metric_name) 64 | metrics.add_metrics(metric_cls()) 65 | except TypeError: 66 | warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.") 67 | else: 68 | warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.") 69 | return metrics 70 | -------------------------------------------------------------------------------- /patchcore/metrics/aupr.py: -------------------------------------------------------------------------------- 1 | """Implementation of AUROC metric based on TorchMetrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | from matplotlib.figure import Figure 10 | from torch import Tensor 11 | from torchmetrics import PrecisionRecallCurve 12 | from torchmetrics.functional import auc 13 | from torchmetrics.utilities.data import dim_zero_cat 14 | 15 | from .plotting_utils import plot_figure 16 | 17 | 18 | class AUPR(PrecisionRecallCurve): 19 | """Area under the PR curve.""" 20 | 21 | def compute(self) -> Tensor: 22 | """First compute PR curve, then compute area under the curve. 23 | 24 | Returns: 25 | Value of the AUPR metric 26 | """ 27 | prec: Tensor 28 | rec: Tensor 29 | 30 | prec, rec = self._compute() 31 | # TODO: use stable sort after upgrading to pytorch 1.9.x (https://github.com/openvinotoolkit/anomalib/issues/92) 32 | if not (torch.all(prec.diff() <= 0) or torch.all(prec.diff() >= 0)): 33 | return auc(rec, prec, reorder=True) # only reorder if rec is not increasing or decreasing 34 | return auc(rec, prec) 35 | 36 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 37 | """Update state with new values. 38 | 39 | Need to flatten new values as PrecicionRecallCurve expects them in this format for binary classification. 40 | 41 | Args: 42 | preds (Tensor): predictions of the model 43 | target (Tensor): ground truth targets 44 | """ 45 | super().update(preds.flatten(), target.flatten()) 46 | 47 | def _compute(self) -> Tuple[Tensor, Tensor]: 48 | """Compute prec/rec value pairs. 49 | 50 | Returns: 51 | Tuple containing Tensors for rec and prec 52 | """ 53 | prec: Tensor 54 | rec: Tensor 55 | prec, rec, _ = super().compute() 56 | return (prec, rec) 57 | 58 | def generate_figure(self) -> Tuple[Figure, str]: 59 | """Generate a figure containing the PR curve as well as the random baseline and the AUC. 
60 | 61 | Returns: 62 | Tuple[Figure, str]: Tuple containing both the PR curve and the figure title to be used for logging 63 | """ 64 | prec, rec = self._compute() 65 | aupr = self.compute() 66 | 67 | xlim = (0.0, 1.0) 68 | ylim = (0.0, 1.0) 69 | xlabel = "Recall" 70 | ylabel = "Precision" 71 | loc = "best" 72 | title = "AUPR" 73 | 74 | fig, axis = plot_figure(rec, prec, aupr, xlim, ylim, xlabel, ylabel, loc, title) 75 | 76 | # Baseline in PR-curve is the prevalence of the positive class 77 | rate = (dim_zero_cat(self.target) == 1).sum() / (dim_zero_cat(self.target).size(0)) 78 | axis.plot( 79 | (0, 1), 80 | (rate.detach().cpu(), rate.detach().cpu()), 81 | color="navy", 82 | lw=2, 83 | linestyle="--", 84 | figure=fig, 85 | ) 86 | 87 | return fig, title 88 | -------------------------------------------------------------------------------- /patchcore/components/feature_extractors/feature_extractor.py: -------------------------------------------------------------------------------- 1 | """Feature Extractor. 2 | 3 | This script extracts features from a CNN. 4 | """ 5 | 6 | # Copyright (C) 2022 Intel Corporation 7 | # SPDX-License-Identifier: Apache-2.0 8 | 9 | import warnings 10 | from typing import Dict, List 11 | 12 | import timm 13 | import torch 14 | from torch import Tensor, nn 15 | 16 | 17 | class FeatureExtractor(nn.Module): 18 | """Extract features from a CNN. 19 | 20 | Args: 21 | backbone (str): Name of the backbone to which the feature extraction hooks are attached. 22 | layers (Iterable[str]): List of layer names of the backbone to which the hooks are attached. 23 | 24 | Example: 25 | >>> import torch 26 | >>> from anomalib.core.model.feature_extractor import FeatureExtractor 27 | 28 | >>> model = FeatureExtractor(model="resnet18", layers=['layer1', 'layer2', 'layer3']) 29 | >>> input = torch.rand((32, 3, 256, 256)) 30 | >>> features = model(input) 31 | 32 | >>> [layer for layer in features.keys()] 33 | ['layer1', 'layer2', 'layer3'] 34 | >>> [feature.shape for feature in features.values()] 35 | [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])] 36 | """ 37 | 38 | def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True): 39 | super().__init__() 40 | self.backbone = backbone 41 | self.layers = layers 42 | self.idx = self._map_layer_to_idx() 43 | self.feature_extractor = timm.create_model( 44 | backbone, 45 | pretrained=pre_trained, 46 | features_only=True, 47 | exportable=True, 48 | out_indices=self.idx, 49 | ) 50 | self.out_dims = self.feature_extractor.feature_info.channels() 51 | self._features = {layer: torch.empty(0) for layer in self.layers} 52 | 53 | def _map_layer_to_idx(self, offset: int = 3) -> List[int]: 54 | """Maps set of layer names to indices of model. 55 | 56 | Args: 57 | offset (int): `timm` ignores the first few layers when indexing; please update the offset based on need 58 | 59 | Returns: 60 | List of indices of the requested layers 61 | """ 62 | idx = [] 63 | features = timm.create_model( 64 | self.backbone, 65 | pretrained=False, 66 | features_only=False, 67 | exportable=True, 68 | ) 69 | for i in self.layers: 70 | try: 71 | idx.append(list(dict(features.named_children()).keys()).index(i) - offset) 72 | except ValueError: 73 | warnings.warn(f"Layer {i} not found in model {self.backbone}") 74 | # Remove unfound key from layer dict 75 | self.layers.remove(i) 76 | 77 | return idx 78 | 79 | def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]: 80 | """Forward-pass input tensor into the CNN. 
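The output is a dictionary mapping each requested layer name to its feature map, as in the class-level Example above.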
81 | 82 | Args: 83 | input_tensor (Tensor): Input tensor 84 | 85 | Returns: 86 | Feature map extracted from the CNN 87 | """ 88 | features = dict(zip(self.layers, self.feature_extractor(input_tensor))) 89 | return features 90 | -------------------------------------------------------------------------------- /patchcore/components/stats/kde.py: -------------------------------------------------------------------------------- 1 | """Gaussian Kernel Density Estimation.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | import math 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import Tensor 11 | import os 12 | import sys 13 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 14 | from base import DynamicBufferModule 15 | 16 | 17 | class GaussianKDE(DynamicBufferModule): 18 | """Gaussian Kernel Density Estimation. 19 | 20 | Args: 21 | dataset (Optional[Tensor], optional): Dataset on which to fit the KDE model. Defaults to None. 22 | """ 23 | 24 | def __init__(self, dataset: Optional[Tensor] = None): 25 | super().__init__() 26 | 27 | self.register_buffer("bw_transform", Tensor()) 28 | self.register_buffer("dataset", Tensor()) 29 | self.register_buffer("norm", Tensor()) 30 | 31 | self.bw_transform = Tensor() 32 | self.dataset = Tensor() 33 | self.norm = Tensor() 34 | 35 | if dataset is not None: 36 | self.fit(dataset) 37 | 38 | def forward(self, features: Tensor) -> Tensor: 39 | """Get the KDE estimates from the feature map. 40 | 41 | Args: 42 | features (Tensor): Feature map extracted from the CNN 43 | 44 | Returns: KDE Estimates 45 | """ 46 | features = torch.matmul(features, self.bw_transform) 47 | 48 | estimate = torch.zeros(features.shape[0]).to(features.device) 49 | for i in range(features.shape[0]): 50 | embedding = ((self.dataset - features[i]) ** 2).sum(dim=1) 51 | embedding = torch.exp(-embedding / 2) * self.norm 52 | estimate[i] = torch.mean(embedding) 53 | 54 | return estimate 55 | 56 | def fit(self, dataset: Tensor) -> None: 57 | """Fit a KDE model to the input dataset. 58 | 59 | Args: 60 | dataset (Tensor): Input dataset. 61 | 62 | Returns: 63 | None 64 | """ 65 | num_samples, dimension = dataset.shape 66 | 67 | # compute scott's bandwidth factor 68 | factor = num_samples ** (-1 / (dimension + 4)) 69 | 70 | cov_mat = self.cov(dataset.T) 71 | inv_cov_mat = torch.linalg.inv(cov_mat) 72 | inv_cov = inv_cov_mat / factor**2 73 | 74 | # transform data to account for bandwidth 75 | bw_transform = torch.linalg.cholesky(inv_cov) 76 | dataset = torch.matmul(dataset, bw_transform) 77 | 78 | # normalization constant of the Gaussian kernel 79 | norm = torch.prod(torch.diag(bw_transform)) 80 | norm *= math.pow((2 * math.pi), (-dimension / 2)) 81 | 82 | self.bw_transform = bw_transform 83 | self.dataset = dataset 84 | self.norm = norm 85 | 86 | @staticmethod 87 | def cov(tensor: Tensor) -> Tensor: 88 | """Calculate the unbiased covariance matrix. 89 | 90 | Args: 91 | tensor (Tensor): Input tensor from which covariance matrix is computed. 92 | 93 | Returns: 94 | Output covariance matrix. 
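Uses the unbiased estimator cov(X) = X_c @ X_c.T / (n - 1), where X_c is the input with its per-row mean subtracted and n is the number of observations (columns).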
95 | """ 96 | mean = torch.mean(tensor, dim=1) 97 | tensor -= mean[:, None] 98 | cov = torch.matmul(tensor, tensor.T) / (tensor.size(1) - 1) 99 | return cov 100 | -------------------------------------------------------------------------------- /patchcore/components/filters/blur.py: -------------------------------------------------------------------------------- 1 | """Gaussian blurring via pytorch.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from kornia.filters import get_gaussian_kernel2d 6 | from kornia.filters.filter import _compute_padding 7 | from kornia.filters.kernels import normalize_kernel2d 8 | from torch import Tensor, nn 9 | from torch.nn import functional as F 10 | 11 | 12 | class GaussianBlur2d(nn.Module): 13 | """Compute GaussianBlur in 2d. 14 | 15 | Makes use of kornia functions, but most notably the kernel is not computed 16 | during the forward pass, and does not depend on the input size. As a caveat, 17 | the number of channels that are expected have to be provided during initialization. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | kernel_size: Union[Tuple[int, int], int], 23 | sigma: Union[Tuple[float, float], float], 24 | channels: int, 25 | normalize: bool = True, 26 | border_type: str = "reflect", 27 | padding: str = "same", 28 | ) -> None: 29 | """Initialize model, setup kernel etc.. 30 | 31 | Args: 32 | kernel_size (Union[Tuple[int, int], int]): size of the Gaussian kernel to use. 33 | sigma (Union[Tuple[float, float], float]): standard deviation to use for constructing the Gaussian kernel. 34 | channels (int): channels of the input 35 | normalize (bool, optional): Whether to normalize the kernel or not (i.e. all elements sum to 1). 36 | Defaults to True. 37 | border_type (str, optional): Border type to use for padding of the input. Defaults to "reflect". 38 | padding (str, optional): Type of padding to apply. Defaults to "same". 39 | """ 40 | super().__init__() 41 | kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) 42 | sigma = sigma if isinstance(sigma, tuple) else (sigma, sigma) 43 | self.kernel: Tensor 44 | self.register_buffer("kernel", get_gaussian_kernel2d(kernel_size=kernel_size, sigma=sigma)) 45 | if normalize: 46 | self.kernel = normalize_kernel2d(self.kernel) 47 | self.channels = channels 48 | self.kernel.unsqueeze_(0).unsqueeze_(0) 49 | self.kernel = self.kernel.expand(self.channels, -1, -1, -1) 50 | self.border_type = border_type 51 | self.padding = padding 52 | self.height, self.width = self.kernel.shape[-2:] 53 | self.padding_shape = _compute_padding([self.height, self.width]) 54 | 55 | def forward(self, input_tensor: Tensor) -> Tensor: 56 | """Blur the input with the computed Gaussian. 57 | 58 | Args: 59 | input_tensor (Tensor): Input tensor to be blurred. 60 | 61 | Returns: 62 | Tensor: Blurred output tensor. 63 | """ 64 | batch, channel, height, width = input_tensor.size() 65 | 66 | if self.padding == "same": 67 | input_tensor = F.pad(input_tensor, self.padding_shape, mode=self.border_type) 68 | 69 | # convolve the tensor with the kernel. 
70 | output = F.conv2d(input_tensor, self.kernel, groups=self.channels, padding=0, stride=1) 71 | 72 | if self.padding == "same": 73 | out = output.view(batch, channel, height, width) 74 | else: 75 | out = output.view(batch, channel, height - self.height + 1, width - self.width + 1) 76 | 77 | return out 78 | -------------------------------------------------------------------------------- /preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import tqdm 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import nibabel as nib 8 | import numpy as np 9 | from glob import glob 10 | from pdb import set_trace 11 | from PIL import Image, ImageDraw 12 | import json 13 | 14 | def crop(obj_name, img, is_mask=False): 15 | w,h = img.size 16 | arr = np.array(img) 17 | mg = 0 18 | if obj_name == "screw_bag": 19 | mg = 150 20 | 21 | if is_mask: 22 | arr = arr[mg:h-mg, mg:w-mg] 23 | else: 24 | arr = arr[mg:h-mg, mg:w-mg] 25 | return Image.fromarray(arr) 26 | 27 | 28 | def zero_pad(img, size, is_mask=False): 29 | if is_mask: 30 | new_arr = np.zeros((size, size), dtype=np.uint8) 31 | arr = np.array(img) 32 | print(arr.shape) 33 | w, h = arr.shape 34 | new_arr[:w, :h] = arr 35 | else: 36 | new_arr = np.zeros((size, size, 3), dtype=np.uint8) 37 | arr = np.array(img) 38 | w, h, _ = arr.shape 39 | new_arr[:w, :h] = arr 40 | return Image.fromarray(new_arr) 41 | 42 | 43 | def resize_im(trg_size, im, is_mask=False): 44 | w, h = im.size 45 | d_max = np.amax([w, h]) 46 | scale = trg_size/d_max 47 | new_size = [int(w*scale), int(h*scale)] 48 | if is_mask: 49 | im = im.resize(new_size, Image.NEAREST) 50 | else: 51 | im = im.resize(new_size) 52 | return im 53 | 54 | 55 | obj_names = ["breakfast_box", "juice_bottle", 56 | "pushpins", "screw_bag", "splicing_connectors"] 57 | 58 | def main(obj_name): 59 | trg_size = 448 60 | shot = "fewshot" 61 | print(obj_name) 62 | dpath = f"./orig/{obj_name}/train/good" 63 | os.makedirs(f"./Annotations_{trg_size}/{obj_name}", exist_ok=True) 64 | os.makedirs(f"./Images_{trg_size}/{obj_name}", exist_ok=True) 65 | fpaths = glob(f"{dpath}/*.png") 66 | # print(fpaths) 67 | fpaths.sort() 68 | for i, fpath in enumerate(fpaths): 69 | # read nii file 70 | im = Image.open(fpath).convert("RGB") 71 | im = crop(obj_name, im, is_mask=False) 72 | im = resize_im(trg_size, im, is_mask=False) 73 | im = zero_pad(im, trg_size, is_mask=False) 74 | im.save(f"./Images_{trg_size}/{obj_name}/{str(i).zfill(3)}.png") 75 | print(f"{i}/{len(fpaths)}", end="\r") 76 | 77 | return 78 | dpath = f"./{shot}" 79 | fpaths = glob(f"{dpath}/{obj_name}*.png") 80 | fpaths.sort() 81 | os.makedirs(f"./Annotations_{shot}_{trg_size}/{obj_name}", exist_ok=True) 82 | os.makedirs(f"./Images_{shot}_{trg_size}/{obj_name}", exist_ok=True) 83 | # print(fpaths) 84 | for i, fpath in enumerate(fpaths): 85 | # read nii file 86 | im = Image.open(fpath).convert("RGB") 87 | im = crop(obj_name, im, is_mask=False) 88 | im = resize_im(trg_size, im, is_mask=False) 89 | # print(im.size) 90 | # break 91 | im = zero_pad(im, trg_size, is_mask=False) 92 | im.save(f"./Images_{shot}_{trg_size}/{obj_name}/{str(i).zfill(3)}.png") 93 | 94 | lpath = fpath.replace(f"{shot}",f"{shot}_mask") 95 | im = Image.open(lpath) 96 | # print(lpath) 97 | im = crop(obj_name, im, is_mask=True) 98 | im = resize_im(trg_size, im, is_mask=True) 99 | im = zero_pad(im, trg_size, is_mask=True) 100 | 
im.save(f"./Annotations_{shot}_{trg_size}/{obj_name}/{str(i).zfill(3)}.png") 101 | 102 | print(f"{i}/{len(fpaths)}", end="\r") 103 | 104 | 105 | if __name__ == '__main__': 106 | for obj_name in obj_names: 107 | main(obj_name) 108 | -------------------------------------------------------------------------------- /patchcore/components/dimensionality_reduction/pca.py: -------------------------------------------------------------------------------- 1 | """Principle Component Analysis (PCA) with PyTorch.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Union 7 | import torch 8 | from torch import Tensor 9 | import os 10 | import sys 11 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 12 | 13 | from base import DynamicBufferModule 14 | 15 | 16 | class PCA(DynamicBufferModule): 17 | """Principle Component Analysis (PCA). 18 | 19 | Args: 20 | n_components (float): Number of components. Can be either integer number of components 21 | or a ratio between 0-1. 22 | """ 23 | 24 | def __init__(self, n_components: Union[float, int]): 25 | super().__init__() 26 | self.n_components = n_components 27 | 28 | self.register_buffer("singular_vectors", Tensor()) 29 | self.register_buffer("mean", Tensor()) 30 | self.register_buffer("num_components", Tensor()) 31 | 32 | self.singular_vectors: Tensor 33 | self.singular_values: Tensor 34 | self.mean: Tensor 35 | self.num_components: Tensor 36 | 37 | def fit(self, dataset: Tensor) -> None: 38 | """Fits the PCA model to the dataset. 39 | 40 | Args: 41 | dataset (Tensor): Input dataset to fit the model. 42 | """ 43 | mean = dataset.mean(dim=0) 44 | dataset -= mean 45 | 46 | _, sig, v_h = torch.linalg.svd(dataset.double()) 47 | num_components: int 48 | if self.n_components <= 1: 49 | variance_ratios = torch.cumsum(sig * sig, dim=0) / torch.sum(sig * sig) 50 | num_components = torch.nonzero(variance_ratios >= self.n_components)[0] 51 | else: 52 | num_components = int(self.n_components) 53 | 54 | self.num_components = Tensor([num_components]) 55 | 56 | self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float() 57 | self.singular_values = sig[:num_components].float() 58 | self.mean = mean 59 | 60 | def fit_transform(self, dataset: Tensor) -> Tensor: 61 | """Fit and transform PCA to dataset. 62 | 63 | Args: 64 | dataset (Tensor): Dataset to which the PCA if fit and transformed 65 | 66 | Returns: 67 | Transformed dataset 68 | """ 69 | mean = dataset.mean(dim=0) 70 | dataset -= mean 71 | num_components = int(self.n_components) 72 | self.num_components = Tensor([num_components]) 73 | 74 | v_h = torch.linalg.svd(dataset)[-1] 75 | self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components] 76 | self.mean = mean 77 | 78 | return torch.matmul(dataset, self.singular_vectors) 79 | 80 | def transform(self, features: Tensor) -> Tensor: 81 | """Transforms the features based on singular vectors calculated earlier. 82 | 83 | Args: 84 | features (Tensor): Input features 85 | 86 | Returns: 87 | Transformed features 88 | """ 89 | 90 | features -= self.mean 91 | return torch.matmul(features, self.singular_vectors) 92 | 93 | def inverse_transform(self, features: Tensor) -> Tensor: 94 | """Inverses the transformed features. 
95 | 
96 |         Args:
97 |             features (Tensor): Transformed features
98 | 
99 |         Returns: Inverse features
100 |         """
101 |         inv_features = torch.matmul(features, self.singular_vectors.transpose(-2, -1))
102 |         return inv_features
103 | 
104 |     def forward(self, features: Tensor) -> Tensor:
105 |         """Transforms the features.
106 | 
107 |         Args:
108 |             features (Tensor): Input features
109 | 
110 |         Returns:
111 |             Transformed features
112 |         """
113 |         return self.transform(features)
114 | 
--------------------------------------------------------------------------------
/patchcore/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import faiss
4 | import glob
5 | from PIL import Image
6 | 
7 | import torch
8 | from torchvision import transforms
9 | from torchvision.transforms.functional import to_pil_image
10 | from torch.utils.data import DataLoader, Dataset
11 | 
12 | class MVTecLOCODataset(Dataset):
13 |     def __init__(self, root, transform, phase, args, anomal_type=None):
14 |         self.transform = transform
15 |         self.phase = phase
16 |         self.args = args
17 |         self.resize = transforms.Resize((int(args.size/8), int(args.size/8)), Image.NEAREST)
18 | 
19 |         self.img_paths = glob.glob(os.path.join(root, 'orig_512', args.category, phase, '*', '*.png'))
20 |         self.seg_paths = glob.glob(os.path.join(root, 'orig_512_seg_scratch_3level_unet', args.category, phase, '*', '*.png'))
21 | 
22 |         if anomal_type is not None:
23 |             self.img_paths = [i for i in self.img_paths if anomal_type in i or 'good' in i]
24 |             self.seg_paths = [i for i in self.seg_paths if anomal_type in i or 'good' in i]
25 | 
26 |         self.labels = []
27 |         for i in self.img_paths:
28 |             if 'good' in i:
29 |                 self.labels.append(0)
30 |             else:
31 |                 self.labels.append(1)
32 | 
33 |         self.totensor = transforms.ToTensor()
34 | 
35 |     def __len__(self):
36 |         return len(self.img_paths)
37 | 
38 |     def __getitem__(self, idx):
39 |         img = Image.open(self.img_paths[idx]).convert('RGB')
40 |         img = self.transform(img)
41 |         seg = Image.open(self.seg_paths[idx]).convert('RGB')
42 |         seg = self.resize(seg)
43 |         seg = self.totensor(seg)
44 |         label = self.labels[idx]
45 |         name = self.img_paths[idx].split('/')[-3:]
46 |         name = name[0] + '/' + name[1] + '/' + name[2].replace('.png', '')
47 | 
48 |         return img, seg, label, name
49 | 
50 | class MVTecADDataset(Dataset):
51 |     def __init__(self, root, transform, phase, args, anomal_type=None):
52 |         self.transform = transform
53 |         self.phase = phase
54 |         self.args = args
55 |         self.resize = transforms.Resize((int(args.size/8), int(args.size/8)), Image.NEAREST)
56 | 
57 |         self.img_paths = glob.glob(os.path.join(root, args.category, phase, '*', '*.png'))
58 | 
59 |         self.labels = []
60 |         for i in self.img_paths:
61 |             if 'good' in i:
62 |                 self.labels.append(0)
63 |             else:
64 |                 self.labels.append(1)
65 | 
66 |         self.totensor = transforms.ToTensor()
67 | 
68 |     def __len__(self):
69 |         return len(self.img_paths)
70 | 
71 |     def __getitem__(self, idx):
72 |         img = Image.open(self.img_paths[idx]).convert('RGB')
73 |         img = self.transform(img)
74 |         label = self.labels[idx]
75 |         name = self.img_paths[idx].split('/')[-3:]
76 |         name = name[0] + '/' + name[1] + '/' + name[2].replace('.png', '')
77 | 
78 |         return img, img, label, name
79 | 
80 | class VisADataset(Dataset):
81 |     def __init__(self, root, transform, phase, args, anomal_type=None):
82 |         self.transform = transform
83 |         self.phase = phase
84 |         self.args = args
85 |         self.resize = transforms.Resize((int(args.size/8), int(args.size/8)), Image.NEAREST)
86 | 
87 |         self.img_paths = glob.glob(os.path.join(root, args.category, phase, '*', '*.JPG'))
88 | 
89 |         self.labels = []
90 |         for i in self.img_paths:
91 |             if 'good' in i:
92 |                 self.labels.append(0)
93 |             else:
94 |                 self.labels.append(1)
95 | 
96 |         self.totensor = transforms.ToTensor()
97 | 
98 |     def __len__(self):
99 |         return len(self.img_paths)
100 | 
101 |     def __getitem__(self, idx):
102 |         img = Image.open(self.img_paths[idx]).convert('RGB')
103 |         img = self.transform(img)
104 |         label = self.labels[idx]
105 |         name = self.img_paths[idx].split('/')[-3:]
106 |         name = name[0] + '/' + name[1] + '/' + name[2].replace('.JPG', '')
107 | 
108 |         return img, img, label, name
--------------------------------------------------------------------------------
/patchcore/anomaly_map.py:
--------------------------------------------------------------------------------
1 | """Anomaly Map Generator for the PatchCore model implementation."""
2 | 
3 | # Copyright (C) 2022 Intel Corporation
4 | # SPDX-License-Identifier: Apache-2.0
5 | 
6 | from typing import Tuple, Union
7 | 
8 | import torch
9 | import torch.nn.functional as F
10 | from omegaconf import ListConfig
11 | from torch import nn
12 | import numpy as np
13 | import pdb
14 | 
15 | from components import GaussianBlur2d
16 | 
17 | 
18 | class AnomalyMapGenerator(nn.Module):
19 |     """Generate Anomaly Heatmap."""
20 | 
21 |     def __init__(
22 |         self,
23 |         input_size: Union[ListConfig, Tuple],
24 |         sigma: int = 4,
25 |     ) -> None:
26 |         super().__init__()
27 |         self.input_size = input_size
28 |         kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
29 |         self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1)
30 | 
31 |     def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> Tuple[torch.Tensor, torch.Tensor]:
32 |         """Pixel Level Anomaly Heatmap.
33 | 
34 |         Args:
35 |             patch_scores (torch.Tensor): Patch-level anomaly scores
36 |             feature_map_shape (torch.Size): 2-D feature map shape (width, height)
37 | 
38 |         Returns:
39 |             Tuple[torch.Tensor, torch.Tensor]: Up-sampled, blurred pixel-level anomaly map, and the raw feature-resolution map
40 |         """
41 |         width, height = feature_map_shape
42 |         batch_size = len(patch_scores) // (width * height)
43 | 
44 |         anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height))
45 |         # anomaly_map_ = self.blur(anomaly_map)
46 |         anomaly_map_ = anomaly_map.clone()
47 |         anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1]))
48 | 
49 |         anomaly_map = self.blur(anomaly_map)
50 | 
51 |         return anomaly_map/10, anomaly_map_/10
52 | 
53 |     @staticmethod
54 |     def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor:
55 |         """Compute Image-Level Anomaly Score.
56 | 
57 |         Args:
58 |             patch_scores (torch.Tensor): Patch-level anomaly scores
59 |         Returns:
60 |             torch.Tensor: Image-level anomaly scores
61 |         """
62 |         patch_scores /= 10
63 | 
64 |         if patch_scores.shape[1] != 1:
65 |             max_scores = torch.argmax(patch_scores[:, 0])
66 |             confidence = torch.index_select(patch_scores, 0, max_scores)
67 |             weights = 1 - (torch.max(torch.exp(confidence)) / torch.sum(torch.exp(confidence)))
68 |             score = weights * torch.max(patch_scores[:, 0])
69 |         else:
70 |             score = torch.max(patch_scores[:, 0])
71 | 
72 |         # if np.isnan(score.detach().cpu().item()):
73 |         #     pdb.set_trace()
74 |         #     print('')
75 | 
76 |         return score
77 | 
78 |     def forward(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
79 |         """Returns anomaly_map and anomaly_score.
80 | 
81 |         Expects `patch_scores` keyword to be passed explicitly
82 |         Expects `feature_map_shape` keyword to be passed explicitly
83 | 
84 |         Example:
85 |         >>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
86 |         >>> map, score, map_ = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape)
87 | 
88 |         Raises:
89 |             ValueError: If `patch_scores` or `feature_map_shape` key is not found
90 | 
91 |         Returns:
92 |             Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score, and the raw feature-resolution map
93 |         """
94 | 
95 |         if "patch_scores" not in kwargs:
96 |             raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}")
97 | 
98 |         if "feature_map_shape" not in kwargs:
99 |             raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}")
100 | 
101 |         patch_scores = kwargs["patch_scores"]
102 |         feature_map_shape = kwargs["feature_map_shape"]
103 | 
104 |         anomaly_map, anomaly_map_ = self.compute_anomaly_map(patch_scores, feature_map_shape)
105 |         anomaly_score = self.compute_anomaly_score(patch_scores)
106 | 
107 |         return anomaly_map, anomaly_score, anomaly_map_
108 | 
--------------------------------------------------------------------------------
/patchcore/lightning_model.py:
--------------------------------------------------------------------------------
1 | """Towards Total Recall in Industrial Anomaly Detection.
2 | 
3 | Paper https://arxiv.org/abs/2106.08265.
4 | """
5 | 
6 | # Copyright (C) 2022 Intel Corporation
7 | # SPDX-License-Identifier: Apache-2.0
8 | 
9 | import logging
10 | from typing import List, Tuple, Union
11 | 
12 | import torch
13 | from omegaconf import DictConfig, ListConfig
14 | from torch import Tensor
15 | 
16 | from components import AnomalyModule
17 | from torch_model import PatchcoreModel
18 | 
19 | logger = logging.getLogger(__name__)
20 | 
21 | class Patchcore(AnomalyModule):
22 |     """PatchcoreLightning Module to train the PatchCore algorithm.
23 | 
24 |     Args:
25 |         input_size (Tuple[int, int]): Size of the model input.
26 |         backbone (str): Backbone CNN network
27 |         layers (List[str]): Layers to extract features from the backbone CNN
28 |         pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
29 |         coreset_sampling_ratio (float, optional): Coreset sampling ratio to subsample embedding.
30 |             Defaults to 1 (no subsampling).
31 |         num_neighbors (int, optional): Number of nearest neighbors. Defaults to 9.
32 |     """
33 | 
34 |     def __init__(
35 |         self,
36 |         input_size: Tuple[int, int],
37 |         backbone: str,
38 |         layers: List[str],
39 |         pre_trained: bool = True,
40 |         coreset_sampling_ratio: float = 1,
41 |         num_neighbors: int = 9,
42 |     ) -> None:
43 |         super().__init__()
44 | 
45 |         self.model: PatchcoreModel = PatchcoreModel(
46 |             input_size=input_size,
47 |             backbone=backbone,
48 |             pre_trained=pre_trained,
49 |             layers=layers,
50 |             num_neighbors=num_neighbors,
51 |         )
52 |         self.coreset_sampling_ratio = coreset_sampling_ratio
53 |         self.embeddings: List[Tensor] = []
54 | 
55 |     def configure_optimizers(self) -> None:
56 |         """Configure optimizers.
57 | 
58 |         Returns:
59 |             None: Do not set optimizers by returning None.
60 |         """
61 |         return None
62 | 
63 |     def training_step(self, batch, _batch_idx):  # pylint: disable=arguments-differ
64 |         """Generate feature embedding of the batch.
65 | 
66 |         Args:
67 |             batch (Dict[str, Any]): Batch containing image filename, image, label and mask
68 |             _batch_idx (int): Batch Index
69 | 
70 |         Returns:
71 |             Dict[str, np.ndarray]: Embedding Vector
72 |         """
73 |         self.model.feature_extractor.eval()
74 |         embedding = self.model(batch["image"])
75 | 
76 |         # NOTE: `self.embeddings` appends each batch embedding to
77 |         # store the training set embedding. We manually append these
78 |         # values mainly due to the new order of hooks introduced after PL v1.4.0
79 |         # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
80 |         self.embeddings.append(embedding)
81 | 
82 |     def on_validation_start(self) -> None:
83 |         """Apply subsampling to the embedding collected from the training set."""
84 |         # NOTE: Previous anomalib versions fit subsampling at the end of the epoch.
85 |         # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
86 |         # is run within train epoch.
87 |         logger.info("Aggregating the embedding extracted from the training set.")
88 |         embeddings = torch.vstack(self.embeddings)
89 | 
90 |         logger.info("Applying core-set subsampling to get the embedding.")
91 |         # self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)  # NOTE: subsampling is disabled here; the full embedding bank is kept
92 | 
93 |     def validation_step(self, batch, _):  # pylint: disable=arguments-differ
94 |         """Get batch of anomaly maps from input image batch.
95 | 
96 |         Args:
97 |             batch (Dict[str, Any]): Batch containing image filename,
98 |                 image, label and mask
99 |             _ (int): Batch Index
100 | 
101 |         Returns:
102 |             Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
103 |         """
104 | 
105 |         anomaly_maps, anomaly_score = self.model(batch["image"])
106 |         batch["anomaly_maps"] = anomaly_maps
107 |         batch["pred_scores"] = anomaly_score.unsqueeze(0)
108 | 
109 |         return batch
--------------------------------------------------------------------------------
/patchcore/components/freia/framework/sequence_inn.py:
--------------------------------------------------------------------------------
1 | """Sequence INN."""
2 | 
3 | # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
4 | # SPDX-License-Identifier: MIT
5 | #
6 | 
7 | # pylint: disable=invalid-name
8 | # flake8: noqa
9 | # pylint: skip-file
10 | # type: ignore
11 | # pydocstyle: noqa
12 | 
13 | from typing import Iterable, List, Tuple
14 | 
15 | import torch
16 | from torch import Tensor, nn
17 | 
18 | from anomalib.models.components.freia.modules.base import InvertibleModule
19 | 
20 | 
21 | class SequenceINN(InvertibleModule):
22 |     """Simpler than FrEIA.framework.GraphINN.
23 | 
24 |     Only supports a sequential series of modules (no splitting, merging,
25 |     branching off).
26 |     Has an append() method, to add new blocks in a simpler way than the
27 |     computation-graph based approach of GraphINN. For example:
28 |     .. code-block:: python
29 |        inn = SequenceINN(channels, dims_H, dims_W)
30 |        for i in range(n_blocks):
31 |            inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True)
32 |        inn.append(FrEIA.modules.HaarDownsampling)
33 |        # and so on
34 |     """
35 | 
36 |     def __init__(self, *dims: int, force_tuple_output=False):
37 |         super().__init__([dims])
38 | 
39 |         self.shapes = [tuple(dims)]
40 |         self.conditions = []
41 |         self.module_list = nn.ModuleList()
42 | 
43 |         self.force_tuple_output = force_tuple_output
44 | 
45 |     def append(self, module_class, cond=None, cond_shape=None, **kwargs):
46 |         """Append a reversible block from FrEIA.modules to the network.
47 | 
48 |         Args:
49 |             module_class: Class from FrEIA.modules.
50 |             cond (int): index of which condition to use (conditions will be passed as list to forward()).
51 |                 Conditioning nodes are not needed for SequenceINN.
52 |             cond_shape (tuple[int]): the shape of the condition tensor.
53 |             **kwargs: Further keyword arguments that are passed to the constructor of module_class (see example).
54 |         """
55 | 
56 |         dims_in = [self.shapes[-1]]
57 |         self.conditions.append(cond)
58 | 
59 |         if cond is not None:
60 |             kwargs["dims_c"] = [cond_shape]
61 | 
62 |         module = module_class(dims_in, **kwargs)
63 |         self.module_list.append(module)
64 |         output_dims = module.output_dims(dims_in)
65 |         assert len(output_dims) == 1, "Module has more than one output"
66 |         self.shapes.append(output_dims[0])
67 | 
68 |     def __getitem__(self, item):
69 |         """Get item."""
70 |         return self.module_list.__getitem__(item)
71 | 
72 |     def __len__(self):
73 |         """Get length."""
74 |         return self.module_list.__len__()
75 | 
76 |     def __iter__(self):
77 |         """Iter."""
78 |         return self.module_list.__iter__()
79 | 
80 |     def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
81 |         """Output Dims."""
82 |         if not self.force_tuple_output:
83 |             raise ValueError(
84 |                 "You can only call output_dims on a SequenceINN when setting force_tuple_output=True."
85 |             )
86 |         return input_dims
87 | 
88 |     def forward(
89 |         self, x_or_z: Tensor, c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True
90 |     ) -> Tuple[Tensor, Tensor]:
91 |         """Execute the sequential INN in forward or inverse (rev=True) direction.
92 | 
93 |         Args:
94 |             x_or_z: input tensor (in contrast to GraphINN, a list of
95 |                 tensors is not supported, as SequenceINN only has
96 |                 one input).
97 |             c: list of conditions.
98 |             rev: whether to compute the network forward or reversed.
99 |             jac: whether to compute the log jacobian
100 |         Returns:
101 |             z_or_x (Tensor): network output.
102 |             jac (Tensor): log-jacobian-determinant.
103 |         """
104 | 
105 |         iterator = range(len(self.module_list))
106 |         log_det_jac = 0
107 | 
108 |         if rev:
109 |             iterator = reversed(iterator)
110 | 
111 |         if torch.is_tensor(x_or_z):
112 |             x_or_z = (x_or_z,)
113 |         for i in iterator:
114 |             if self.conditions[i] is None:
115 |                 x_or_z, j = self.module_list[i](x_or_z, jac=jac, rev=rev)
116 |             else:
117 |                 x_or_z, j = self.module_list[i](x_or_z, c=[c[self.conditions[i]]], jac=jac, rev=rev)
118 |             log_det_jac = j + log_det_jac
119 | 
120 |         return x_or_z if self.force_tuple_output else x_or_z[0], log_det_jac
121 | 
--------------------------------------------------------------------------------
/patchcore/components/sampling/k_center_greedy.py:
--------------------------------------------------------------------------------
1 | """This module comprises PatchCore Sampling Methods for the embedding.
2 | 
3 | - k Center Greedy Method
4 |     Returns points that minimize the maximum distance of any point to a center.
5 |     https://arxiv.org/abs/1708.00489
6 | """
7 | 
8 | from typing import List, Optional
9 | 
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import Tensor
13 | import os
14 | import sys
15 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
16 | from dimensionality_reduction import SparseRandomProjection
17 | 
18 | 
19 | class KCenterGreedy:
20 |     """Implements k-center-greedy method.
21 | 
22 |     Args:
23 |         embedding (Tensor): Embedding vector extracted from a CNN
24 |         sampling_ratio (float): Ratio to choose coreset size from the embedding size.
25 | 
26 |     Example:
27 |         >>> embedding.shape
28 |         torch.Size([219520, 1536])
29 |         >>> sampler = KCenterGreedy(embedding=embedding, sampling_ratio=0.001)
30 |         >>> sampled_idxs = sampler.select_coreset_idxs()
31 |         >>> coreset = embedding[sampled_idxs]
32 |         >>> coreset.shape
33 |         torch.Size([219, 1536])
34 |     """
35 | 
36 |     def __init__(self, embedding: Tensor, sampling_ratio: float) -> None:
37 |         self.embedding = embedding
38 |         self.coreset_size = int(embedding.shape[0] * sampling_ratio)
39 |         self.model = SparseRandomProjection(eps=0.9)
40 | 
41 |         self.features: Tensor
42 |         self.min_distances: Optional[Tensor] = None
43 |         self.n_observations = self.embedding.shape[0]
44 | 
45 |     def reset_distances(self) -> None:
46 |         """Reset minimum distances."""
47 |         self.min_distances = None
48 | 
49 |     def update_distances(self, cluster_centers: List[int]) -> None:
50 |         """Update min distances given cluster centers.
51 | 
52 |         Args:
53 |             cluster_centers (List[int]): indices of cluster centers
54 |         """
55 | 
56 |         if cluster_centers:
57 |             centers = self.features[cluster_centers]
58 | 
59 |             distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)
60 | 
61 |             if self.min_distances is None:
62 |                 self.min_distances = distance
63 |             else:
64 |                 self.min_distances = torch.minimum(self.min_distances, distance)
65 | 
66 |     def get_new_idx(self) -> int:
67 |         """Get index value of a sample.
68 | 
69 |         Returns the sample farthest from the current cluster centers (argmax of the minimum distances).
70 | 
71 |         Returns:
72 |             int: Sample index
73 |         """
74 | 
75 |         if isinstance(self.min_distances, Tensor):
76 |             idx = int(torch.argmax(self.min_distances).item())
77 |         else:
78 |             raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")
79 | 
80 |         return idx
81 | 
82 |     def select_coreset_idxs(self, selected_idxs: Optional[List[int]] = None) -> List[int]:
83 |         """Greedily form a coreset to minimize the maximum distance of a cluster.
84 | 
85 |         Args:
86 |             selected_idxs: index of samples already selected. Defaults to an empty set.
87 | 
88 |         Returns:
89 |             indices of samples selected to minimize distance to cluster centers
90 |         """
91 | 
92 |         if selected_idxs is None:
93 |             selected_idxs = []
94 | 
95 |         if self.embedding.ndim == 2:
96 |             self.model.fit(self.embedding)
97 |             self.features = self.model.transform(self.embedding)
98 |             self.reset_distances()
99 |         else:
100 |             self.features = self.embedding.reshape(self.embedding.shape[0], -1)
101 |             self.update_distances(cluster_centers=selected_idxs)
102 | 
103 |         selected_coreset_idxs: List[int] = []
104 |         idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
105 |         for _ in range(self.coreset_size):
106 |             self.update_distances(cluster_centers=[idx])
107 |             idx = self.get_new_idx()
108 |             if idx in selected_idxs:
109 |                 raise ValueError("New indices should not be in selected indices.")
110 |             self.min_distances[idx] = 0
111 |             selected_coreset_idxs.append(idx)
112 | 
113 |         return selected_coreset_idxs
114 | 
115 |     def sample_coreset(self, selected_idxs: Optional[List[int]] = None) -> Tensor:
116 |         """Select coreset from the embedding.
117 | 
118 |         Args:
119 |             selected_idxs: index of samples already selected. Defaults to an empty set.
120 | 
121 |         Returns:
122 |             Tensor: Output coreset
123 | 
124 |         Example:
125 |             >>> embedding.shape
126 |             torch.Size([219520, 1536])
127 |             >>> sampler = KCenterGreedy(...)
128 |             >>> coreset = sampler.sample_coreset()
129 |             >>> coreset.shape
130 |             torch.Size([219, 1536])
131 |         """
132 | 
133 |         idxs = self.select_coreset_idxs(selected_idxs)
134 |         coreset = self.embedding[idxs]
135 | 
136 |         return coreset
137 | 
--------------------------------------------------------------------------------
/patchcore/components/freia/modules/base.py:
--------------------------------------------------------------------------------
1 | """Base Module."""
2 | 
3 | # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
4 | # SPDX-License-Identifier: MIT
5 | #
6 | 
7 | # flake8: noqa
8 | # pylint: skip-file
9 | # type: ignore
10 | # pydocstyle: noqa
11 | 
12 | from typing import Iterable, List, Tuple
13 | 
14 | import torch.nn as nn
15 | from torch import Tensor
16 | 
17 | 
18 | class InvertibleModule(nn.Module):
19 |     r"""Base class for all invertible modules in FrEIA.
20 | 
21 |     Given ``module``, an instance of some InvertibleModule, this ``module``
22 |     shall be invertible in its input dimensions,
23 |     so that the input can be recovered by applying the module
24 |     in backwards mode (``rev=True``), not to be confused with
25 |     ``pytorch.backward()`` which computes the gradient of an operation::
26 |         x = torch.randn(BATCH_SIZE, DIM_COUNT)
27 |         c = torch.randn(BATCH_SIZE, CONDITION_DIM)
28 |         # Forward mode
29 |         z, jac = module([x], [c], jac=True)
30 |         # Backward mode
31 |         x_rev, jac_rev = module(z, [c], rev=True)
32 |     The ``module`` returns :math:`\log \det J = \log \left| \det \frac{\partial f}{\partial x} \right|`
33 |     of the operation in forward mode, and
34 |     :math:`-\log | \det J | = \log \left| \det \frac{\partial f^{-1}}{\partial z} \right| = -\log \left| \det \frac{\partial f}{\partial x} \right|`
35 |     in backward mode (``rev=True``).
36 |     Then, ``torch.allclose(x, x_rev) == True`` and ``torch.allclose(jac, -jac_rev) == True``.
37 |     """
38 | 
39 |     def __init__(self, dims_in: Iterable[Tuple[int]], dims_c: Iterable[Tuple[int]] = None):
40 |         """Initialize.
41 | 
42 |         Args:
43 |             dims_in: list of tuples specifying the shape of the inputs to this
44 |                 operator: ``dims_in = [shape_x_0, shape_x_1, ...]``
45 |             dims_c: list of tuples specifying the shape of the conditions to
46 |                 this operator.
47 |         """
48 |         super().__init__()
49 |         if dims_c is None:
50 |             dims_c = []
51 |         self.dims_in = list(dims_in)
52 |         self.dims_c = list(dims_c)
53 | 
54 |     def forward(
55 |         self, x_or_z: Iterable[Tensor], c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True
56 |     ) -> Tuple[Tuple[Tensor], Tensor]:
57 |         r"""Forward/Backward Pass.
58 | 
59 |         Perform a forward (default, ``rev=False``) or backward pass (``rev=True``) through this module/operator.
60 | 
61 |         **Note to implementers:**
62 |         - Subclasses MUST return a Jacobian when ``jac=True``, but CAN return a
63 |           valid Jacobian when ``jac=False`` (not punished). The latter is only recommended
64 |           if the computation of the Jacobian is trivial.
65 |         - Subclasses MUST follow the convention that the returned Jacobian be
66 |           consistent with the evaluation direction. Let's make this more precise:
67 |           Let :math:`f` be the function that the subclass represents. Then:
68 |           .. math::
69 |               J &= \log \det \frac{\partial f}{\partial x} \\
70 |               -J &= \log \det \frac{\partial f^{-1}}{\partial z}.
71 |           Any subclass MUST return :math:`J` for forward evaluation (``rev=False``),
72 |           and :math:`-J` for backward evaluation (``rev=True``).
73 | 
74 |         Args:
75 |             x_or_z: input data (array-like of one or more tensors)
76 |             c: conditioning data (array-like of zero or more tensors)
77 |             rev: perform backward pass
78 |             jac: return Jacobian associated to the direction
79 |         """
80 |         raise NotImplementedError(f"{self.__class__.__name__} does not provide forward(...) method")
81 | 
82 |     def log_jacobian(self, *args, **kwargs):
83 |         """Deprecated; raises a DeprecationWarning pointing to ``forward(..., jac=True)``."""
84 |         raise DeprecationWarning(
85 |             "module.log_jacobian(...) is deprecated. "
86 |             "module.forward(..., jac=True) returns a "
87 |             "tuple (out, jacobian) now."
88 |         )
89 | 
90 |     def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
91 |         """Use for shape inference during construction of the graph.
92 | 
93 |         MUST be implemented for each subclass of ``InvertibleModule``.
94 | 
95 |         Args:
96 |             input_dims: A list with one entry for each input to the module.
97 |                 Even if the module only has one input, must be a list with one
98 |                 entry. Each entry is a tuple giving the shape of that input,
99 |                 excluding the batch dimension. For example for a module with one
100 |                 input, which receives a 32x32 pixel RGB image, ``input_dims`` would
101 |                 be ``[(3, 32, 32)]``
102 | 
103 |         Returns:
104 |             A list structured in the same way as ``input_dims``. Each entry
105 |             represents one output of the module, and the entry is a tuple giving
106 |             the shape of that output. For example if the module splits the image
107 |             into a right and a left half, the return value should be
108 |             ``[(3, 16, 32), (3, 16, 32)]``. It is up to the implementor of the
109 |             subclass to ensure that the total number of elements in all inputs
110 |             and all outputs is consistent.
111 |         """
112 |         raise NotImplementedError(f"{self.__class__.__name__} does not provide output_dims(...)")
113 | 
--------------------------------------------------------------------------------
/patchcore/components/dimensionality_reduction/random_projection.py:
--------------------------------------------------------------------------------
1 | """This module comprises PatchCore Sampling Methods for the embedding.
2 | 
3 | - Random Sparse Projector
4 |     Sparse Random Projection using PyTorch Operations
5 | """
6 | 
7 | # Copyright (C) 2022 Intel Corporation
8 | # SPDX-License-Identifier: Apache-2.0
9 | 
10 | from typing import Optional
11 | 
12 | import numpy as np
13 | import torch
14 | from sklearn.utils.random import sample_without_replacement
15 | from torch import Tensor
16 | 
17 | 
18 | class NotFittedError(ValueError, AttributeError):
19 |     """Raise Exception if estimator is used before fitting."""
20 | 
21 | 
22 | class SparseRandomProjection:
23 |     """Sparse Random Projection using PyTorch operations.
24 | 
25 |     Args:
26 |         eps (float, optional): Minimum distortion rate parameter for calculating
27 |             Johnson-Lindenstrauss minimum dimensions. Defaults to 0.1.
28 |         random_state (Optional[int], optional): Uses the seed to set the random
29 |             state for sample_without_replacement function. Defaults to None.
30 |     """
31 | 
32 |     def __init__(self, eps: float = 0.1, random_state: Optional[int] = None) -> None:
33 |         self.n_components: int
34 |         self.sparse_random_matrix: Tensor
35 |         self.eps = eps
36 |         self.random_state = random_state
37 | 
38 |     def _sparse_random_matrix(self, n_features: int):
39 |         """Random sparse matrix. Based on https://web.stanford.edu/~hastie/Papers/Ping/KDD06_rp.pdf.
40 | 
41 |         Args:
42 |             n_features (int): Dimensionality of the original source space
43 | 
44 |         Returns:
45 |             Tensor: Sparse matrix of shape (n_components, n_features).
46 |                 Although the matrix is conceptually sparse, it is stored as a
47 |                 dense tensor here (see the in-code comment below).
48 |         """
49 | 
50 |         # Density 'auto'. Factorize density
51 |         density = 1 / np.sqrt(n_features)
52 | 
53 |         if density == 1:
54 |             # skip index generation if totally dense
55 |             binomial = torch.distributions.Binomial(total_count=1, probs=0.5)
56 |             components = binomial.sample((self.n_components, n_features)) * 2 - 1
57 |             components = 1 / np.sqrt(self.n_components) * components
58 | 
59 |         else:
60 |             # Sparse matrix is not being generated here as it is stored as dense anyways
61 |             components = torch.zeros((self.n_components, n_features), dtype=torch.float64)
62 |             for i in range(self.n_components):
63 |                 # find the indices of the non-zero components for row i
64 |                 nnz_idx = torch.distributions.Binomial(total_count=n_features, probs=density).sample()
65 |                 # get nnz_idx column indices
66 |                 # pylint: disable=not-callable
67 |                 c_idx = torch.tensor(
68 |                     sample_without_replacement(
69 |                         n_population=n_features, n_samples=nnz_idx, random_state=self.random_state
70 |                     ),
71 |                     dtype=torch.int64,
72 |                 )
73 |                 data = torch.distributions.Binomial(total_count=1, probs=0.5).sample(sample_shape=c_idx.size()) * 2 - 1
74 |                 # assign data to only those columns
75 |                 components[i, c_idx] = data.double()
76 | 
77 |             components *= np.sqrt(1 / density) / np.sqrt(self.n_components)
78 | 
79 |         return components
80 | 
81 |     def johnson_lindenstrauss_min_dim(self, n_samples: int, eps: float = 0.1):
82 |         """Find a 'safe' number of components to randomly project to.
83 | 
84 |         Ref eqn 2.1 https://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf
85 | 
86 |         Args:
87 |             n_samples (int): Number of samples used to compute safe components
88 |             eps (float, optional): Minimum distortion rate. Defaults to 0.1.
89 |         """
90 | 
91 |         denominator = (eps**2 / 2) - (eps**3 / 3)
92 |         return (4 * np.log(n_samples) / denominator).astype(np.int64)
93 | 
94 |     def fit(self, embedding: Tensor) -> "SparseRandomProjection":
95 |         """Generates sparse matrix from the embedding tensor.
96 | 
97 |         Args:
98 |             embedding (Tensor): Embedding tensor whose shape determines the projection dimensions
99 | 
100 |         Returns:
101 |             (SparseRandomProjection): Return self to be used as
102 |                 >>> generator = SparseRandomProjection()
103 |                 >>> generator = generator.fit(embedding)
104 |         """
105 |         n_samples, n_features = embedding.shape
106 |         device = embedding.device
107 | 
108 |         self.n_components = self.johnson_lindenstrauss_min_dim(n_samples=n_samples, eps=self.eps)
109 | 
110 |         # Generate projection matrix
111 |         # torch can't multiply directly on sparse matrix and moving sparse matrix to cuda throws error
112 |         # (Could not run 'aten::empty_strided' with arguments from the 'SparseCsrCUDA' backend)
113 |         # hence sparse matrix is stored as a dense matrix on the device
114 |         self.sparse_random_matrix = self._sparse_random_matrix(n_features=n_features).to(device)
115 | 
116 |         return self
117 | 
118 |     def transform(self, embedding: Tensor) -> Tensor:
119 |         """Project the data by using matrix product with the random matrix.
120 | 
121 |         Args:
122 |             embedding (Tensor): Embedding of shape (n_samples, n_features)
123 |                 The input data to project into a smaller dimensional space
124 | 
125 |         Returns:
126 |             projected_embedding (Tensor): Projected tensor of shape
127 |                 (n_samples, n_components).
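        Example (editor's sketch; shapes and values are illustrative, not from the source):
            >>> import torch
            >>> embedding = torch.randn(1000, 1536)
            >>> projector = SparseRandomProjection(eps=0.9).fit(embedding)
            >>> projector.transform(embedding).shape[0]
            1000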
128 | """ 129 | if self.sparse_random_matrix is None: 130 | raise NotFittedError("`fit()` has not been called on SparseRandomProjection yet.") 131 | 132 | projected_embedding = embedding @ self.sparse_random_matrix.T.float() 133 | return projected_embedding 134 | -------------------------------------------------------------------------------- /patchcore/pre_processing/pre_process.py: -------------------------------------------------------------------------------- 1 | """Pre Process. 2 | 3 | This module contains `PreProcessor` class that applies preprocessing 4 | to an input image before the forward-pass stage. 5 | """ 6 | 7 | # Copyright (C) 2022 Intel Corporation 8 | # SPDX-License-Identifier: Apache-2.0 9 | 10 | from typing import Optional, Tuple, Union 11 | 12 | import albumentations as A 13 | from albumentations.pytorch import ToTensorV2 14 | 15 | 16 | class PreProcessor: 17 | """Applies pre-processing and data augmentations to the input and returns the transformed output. 18 | 19 | Output could be either numpy ndarray or torch tensor. 20 | When `PreProcessor` class is used for training, the output would be `torch.Tensor`. 21 | For the inference it returns a numpy array. 22 | 23 | Args: 24 | config (Optional[Union[str, A.Compose]], optional): Transformation configurations. 25 | When it is ``None``, ``PreProcessor`` only applies resizing. When it is ``str`` 26 | it loads the config via ``albumentations`` deserialisation methos . Defaults to None. 27 | image_size (Optional[Union[int, Tuple[int, int]]], optional): When there is no config, 28 | ``image_size`` resizes the image. Defaults to None. 29 | to_tensor (bool, optional): Boolean to check whether the augmented image is transformed 30 | into a tensor or not. Defaults to True. 31 | 32 | Examples: 33 | >>> import skimage 34 | >>> image = skimage.data.astronaut() 35 | 36 | >>> pre_processor = PreProcessor(image_size=256, to_tensor=False) 37 | >>> output = pre_processor(image=image) 38 | >>> output["image"].shape 39 | (256, 256, 3) 40 | 41 | >>> pre_processor = PreProcessor(image_size=256, to_tensor=True) 42 | >>> output = pre_processor(image=image) 43 | >>> output["image"].shape 44 | torch.Size([3, 256, 256]) 45 | 46 | 47 | Transforms could be read from albumentations Compose object. 48 | >>> import albumentations as A 49 | >>> from albumentations.pytorch import ToTensorV2 50 | >>> config = A.Compose([A.Resize(512, 512), ToTensorV2()]) 51 | >>> pre_processor = PreProcessor(config=config, to_tensor=False) 52 | >>> output = pre_processor(image=image) 53 | >>> output["image"].shape 54 | (512, 512, 3) 55 | >>> type(output["image"]) 56 | numpy.ndarray 57 | 58 | Transforms could be deserialized from a yaml file. 59 | >>> transforms = A.Compose([A.Resize(1024, 1024), ToTensorV2()]) 60 | >>> A.save(transforms, "/tmp/transforms.yaml", data_format="yaml") 61 | >>> pre_processor = PreProcessor(config="/tmp/transforms.yaml") 62 | >>> output = pre_processor(image=image) 63 | >>> output["image"].shape 64 | torch.Size([3, 1024, 1024]) 65 | """ 66 | 67 | def __init__( 68 | self, 69 | config: Optional[Union[str, A.Compose]] = None, 70 | image_size: Optional[Union[int, Tuple]] = None, 71 | to_tensor: bool = True, 72 | ) -> None: 73 | self.config = config 74 | self.image_size = image_size 75 | self.to_tensor = to_tensor 76 | 77 | self.transforms = self.get_transforms() 78 | 79 | def get_transforms(self) -> A.Compose: 80 | """Get transforms from config or image size. 
81 | 
82 |         Returns:
83 |             A.Compose: List of albumentation transformations to apply to the
84 |                 input image.
85 |         """
86 |         if self.config is None and self.image_size is None:
87 |             raise ValueError(
88 |                 "Both config and image_size cannot be `None`. "
89 |                 "Provide either config file to de-serialize transforms "
90 |                 "or image_size to get the default transformations"
91 |             )
92 | 
93 |         transforms: A.Compose
94 | 
95 |         if self.config is None and self.image_size is not None:
96 |             height, width = self._get_height_and_width()
97 |             transforms = A.Compose(
98 |                 [
99 |                     A.Resize(height=height, width=width, always_apply=True),
100 |                     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
101 |                     ToTensorV2(),
102 |                 ]
103 |             )
104 | 
105 |         if self.config is not None:
106 |             if isinstance(self.config, str):
107 |                 transforms = A.load(filepath=self.config, data_format="yaml")
108 |             elif isinstance(self.config, A.Compose):
109 |                 transforms = self.config
110 |             else:
111 |                 raise ValueError("config could be either ``str`` or ``A.Compose``")
112 | 
113 |         if not self.to_tensor:
114 |             if isinstance(transforms[-1], ToTensorV2):
115 |                 transforms = A.Compose(transforms[:-1])
116 | 
117 |         # always resize to specified image size
118 |         if not any(isinstance(transform, A.Resize) for transform in transforms) and self.image_size is not None:
119 |             height, width = self._get_height_and_width()
120 |             transforms = A.Compose([A.Resize(height=height, width=width, always_apply=True), transforms])
121 | 
122 |         return transforms
123 | 
124 |     def __call__(self, *args, **kwargs):
125 |         """Return transformed arguments."""
126 |         return self.transforms(*args, **kwargs)
127 | 
128 |     def _get_height_and_width(self) -> Tuple[Optional[int], Optional[int]]:
129 |         """Extract height and width from image size attribute."""
130 |         if isinstance(self.image_size, int):
131 |             return self.image_size, self.image_size
132 |         if isinstance(self.image_size, tuple):
133 |             return int(self.image_size[0]), int(self.image_size[1])
134 |         if self.image_size is None:
135 |             return None, None
136 |         raise ValueError("``image_size`` could be either int or Tuple[int, int]")
137 | 
--------------------------------------------------------------------------------
/patchcore/components/stats/multi_variate_gaussian.py:
--------------------------------------------------------------------------------
1 | """Multi Variate Gaussian Distribution."""
2 | 
3 | # Copyright (C) 2022 Intel Corporation
4 | # SPDX-License-Identifier: Apache-2.0
5 | 
6 | from typing import Any, List, Optional
7 | 
8 | import torch
9 | from torch import Tensor, nn
10 | 
11 | 
12 | class MultiVariateGaussian(nn.Module):
13 |     """Multi Variate Gaussian Distribution."""
14 | 
15 |     def __init__(self, n_features, n_patches):
16 |         super().__init__()
17 | 
18 |         self.register_buffer("mean", torch.zeros(n_features, n_patches))
19 |         self.register_buffer("inv_covariance", torch.eye(n_features).unsqueeze(0).repeat(n_patches, 1, 1))
20 | 
21 |         self.mean: Tensor
22 |         self.inv_covariance: Tensor
23 | 
24 |     @staticmethod
25 |     def _cov(
26 |         observations: Tensor,
27 |         rowvar: bool = False,
28 |         bias: bool = False,
29 |         ddof: Optional[int] = None,
30 |         aweights: Tensor = None,
31 |     ) -> Tensor:
32 |         """Estimates covariance matrix like numpy.cov.
33 | 
34 |         Args:
35 |             observations (Tensor): A 1-D or 2-D array containing multiple variables and observations.
36 |                 Each row of `observations` represents a variable, and each column a single
37 |                 observation of all those variables. Also see `rowvar` below.
38 |             rowvar (bool): If `rowvar` is True, then each row represents a
39 |                 variable, with observations in the columns. Otherwise, the relationship
40 |                 is transposed: each column represents a variable, while the rows
41 |                 contain observations. Defaults to False.
42 |             bias (bool): Default normalization (False) is by ``(N - 1)``, where ``N`` is the
43 |                 number of observations given (unbiased estimate). If `bias` is True,
44 |                 then normalization is by ``N``. These values can be overridden by using
45 |                 the keyword ``ddof`` in numpy versions >= 1.5. Defaults to False.
46 |             ddof (Optional, int): If not ``None`` the default value implied by `bias` is overridden.
47 |                 Note that ``ddof=1`` will return the unbiased estimate, even if both
48 |                 `fweights` and `aweights` are specified, and ``ddof=0`` will return
49 |                 the simple average. See the notes for the details. The default value
50 |                 is ``None``.
51 |             aweights (Tensor): 1-D array of observation vector weights. These relative weights are
52 |                 typically large for observations considered "important" and smaller for
53 |                 observations considered less "important". If ``ddof=0`` the array of
54 |                 weights can be used to assign probabilities to observation vectors. (Default value = None)
55 | 
56 | 
57 |         Returns:
58 |             The covariance matrix of the variables.
59 |         """
60 |         # ensure at least 2D
61 |         if observations.dim() == 1:
62 |             observations = observations.view(-1, 1)
63 | 
64 |         # if rows are variables (rowvar=True), transpose so that rows become observations
65 |         if rowvar and observations.shape[0] != 1:
66 |             observations = observations.t()
67 | 
68 |         if ddof is None:
69 |             if bias == 0:
70 |                 ddof = 1
71 |             else:
72 |                 ddof = 0
73 | 
74 |         weights = aweights
75 |         weights_sum: Any
76 | 
77 |         if weights is not None:
78 |             if not torch.is_tensor(weights):
79 |                 weights = torch.tensor(weights, dtype=torch.float)  # pylint: disable=not-callable
80 |             weights_sum = torch.sum(weights)
81 |             avg = torch.sum(observations * (weights / weights_sum)[:, None], 0)
82 |         else:
83 |             avg = torch.mean(observations, 0)
84 | 
85 |         # Determine the normalization
86 |         if weights is None:
87 |             fact = observations.shape[0] - ddof
88 |         elif ddof == 0:
89 |             fact = weights_sum
90 |         elif aweights is None:
91 |             fact = weights_sum - ddof
92 |         else:
93 |             fact = weights_sum - ddof * torch.sum(weights * weights) / weights_sum
94 | 
95 |         observations_m = observations.sub(avg.expand_as(observations))
96 | 
97 |         if weights is None:
98 |             x_transposed = observations_m.t()
99 |         else:
100 |             x_transposed = torch.mm(torch.diag(weights), observations_m).t()
101 | 
102 |         covariance = torch.mm(x_transposed, observations_m)
103 |         covariance = covariance / fact
104 | 
105 |         return covariance.squeeze()
106 | 
107 |     def forward(self, embedding: Tensor) -> List[Tensor]:
108 |         """Calculate multivariate Gaussian distribution.
109 | 
110 |         Args:
111 |             embedding (Tensor): CNN features whose dimensionality is reduced via either random sampling or PCA.
112 | 
113 |         Returns:
114 |             mean and inverse covariance of the multi-variate gaussian distribution that fits the features.
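        Example (editor's sketch; shapes are illustrative, not from the source):
            >>> mvg = MultiVariateGaussian(n_features=64, n_patches=256)
            >>> mean, inv_cov = mvg(torch.randn(8, 64, 16, 16))
            >>> mean.shape, inv_cov.shape
            (torch.Size([64, 256]), torch.Size([256, 64, 64]))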
115 |         """
116 |         device = embedding.device
117 | 
118 |         batch, channel, height, width = embedding.size()
119 |         embedding_vectors = embedding.view(batch, channel, height * width)
120 |         self.mean = torch.mean(embedding_vectors, dim=0)
121 |         covariance = torch.zeros(size=(channel, channel, height * width), device=device)
122 |         identity = torch.eye(channel).to(device)
123 |         for i in range(height * width):
124 |             covariance[:, :, i] = self._cov(embedding_vectors[:, :, i], rowvar=False) + 0.01 * identity
125 | 
126 |         # calculate inverse covariance as we need only the inverse
127 |         self.inv_covariance = torch.linalg.inv(covariance.permute(2, 0, 1))
128 | 
129 |         return [self.mean, self.inv_covariance]
130 | 
131 |     def fit(self, embedding: Tensor) -> List[Tensor]:
132 |         """Fit multi-variate gaussian distribution to the input embedding.
133 | 
134 |         Args:
135 |             embedding (Tensor): Embedding vector extracted from CNN.
136 | 
137 |         Returns:
138 |             Mean and the inverse covariance of the embedding.
139 |         """
140 |         return self.forward(embedding)
141 | 
--------------------------------------------------------------------------------
/dataset_2d_sup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import random
5 | import collections
6 | import torch
7 | import torchvision
8 | import cv2
9 | from torch.utils import data
10 | import matplotlib.pyplot as plt
11 | import nibabel as nib
12 | from skimage.transform import resize
13 | import SimpleITK as sitk
14 | import math
15 | from PIL import Image, ImageOps, ImageFilter, ImageChops
16 | # from torchvision.transforms import Compose
17 | from glob import glob
18 | from pdb import set_trace
19 | import albumentations as A
20 | 
21 | imagenet_mean = np.array([0.485, 0.456, 0.406])
22 | imagenet_std = np.array([0.229, 0.224, 0.225])
23 | 
24 | class DataSet2D(data.Dataset):
25 |     def __init__(self,
26 |                  root,
27 |                  seg_dir,
28 |                  obj_name,
29 |                  label=255,
30 |                  size=(256,256),
31 |                  rot_angle=(0,360),
32 |                  transform=None):
33 |         print("root:", root)
34 | 
35 |         self.root = root
36 |         self.seg_dir = seg_dir
37 |         self.size = size
38 |         self.label = label
39 |         self.rot_angle = rot_angle
40 |         self.obj_name = obj_name
41 |         self.transform = transform
42 |         self.is_mirror = True
43 |         self.files_pre = glob(f"{root}/orig_512/{obj_name}/train/good/*.png")
44 |         self.files_pre.sort()
45 |         self.files = []
46 |         for e in self.files_pre:
47 |             fname = os.path.basename(e)
48 |             lpath = f"{self.root}/{self.seg_dir}/{self.obj_name}/pred_{fname}"
49 |             if os.path.exists(lpath):
50 |                 self.files.append(e)
51 |         print('{} images are loaded!'.format(len(self.files)))
52 | 
53 |         self.aug = A.Compose([
54 |             A.ToGray(p=0.2),
55 |             A.Posterize(p=0.2),
56 |             A.Equalize(p=0.2),
57 |             A.Sharpen(p=0.2),
58 |             A.RandomBrightnessContrast(p=0.2),
59 |             A.Solarize(p=0.2),
60 |             A.ColorJitter(p=0.2)
61 |         ])
62 | 
63 |     def __len__(self):
64 |         return len(self.files)
65 | 
66 |     def __getitem__(self, index):
67 |         fpath = self.files[index]
68 |         fname = os.path.basename(fpath)
69 |         lpath = f"{self.root}/{self.seg_dir}/{self.obj_name}/pred_{fname}"
70 |         random_angle = random.choice(list(range(self.rot_angle[0],self.rot_angle[1])))
71 |         # random_angle = random.choice(list(range(360)))
72 |         try:
73 |             im = Image.open(fpath)
74 |         except FileNotFoundError:
75 |             im = Image.open(fpath.replace(".png",".jpg"))
76 |         # im = im.resize(self.size)
77 |         im = im.convert("RGB")#.filter(ImageFilter.BLUR)
78 |         # im = ImageChops.offset(im, -10, 0)
79 |         im = im.rotate(random_angle)
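        # NOTE (editor's comment): the block below converts the rotated PIL image
        # to a NumPy array, applies the photometric augmentations defined in
        # __init__, and rescales pixel values to [0, 1] before the ImageNet
        # mean/std normalization further down.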
80 |         image = np.array(im)
81 |         image = self.aug(image=image)['image']
82 |         image = np.array(image)/255
83 | 
84 |         ## coordinate vector
85 |         w,h,_ = image.shape
86 |         x = np.linspace(0,1,w)
87 |         y = np.linspace(0,1,h)
88 |         xx,yy = np.meshgrid(x,y)
89 |         coord = np.stack([xx,yy,np.zeros_like(yy)],axis=2)*255
90 |         coord_im = Image.fromarray(coord.astype(np.uint8))
91 |         coord = np.array(coord_im)[:,:,:2]/255
92 |         coord = np.swapaxes(coord, 1, 2)
93 |         coord = np.swapaxes(coord, 0, 1) # [2,w,h]
94 |         coord_orig = torch.tensor(coord).float()*30
95 |         coord_im = coord_im.rotate(random_angle)
96 |         coord = np.array(coord_im)[:,:,:2]/255
97 |         coord = np.swapaxes(coord, 1, 2)
98 |         coord = np.swapaxes(coord, 0, 1) # [2,w,h]
99 |         coord_rot = torch.tensor(coord).float()*30
100 | 
101 |         # normalize
102 |         image = image - imagenet_mean
103 |         image = image / imagenet_std
104 | 
105 |         image = np.swapaxes(image, 1, 2)
106 |         image = np.swapaxes(image, 0, 1) # [3,w,h]
107 | 
108 |         im_label = Image.open(lpath)
109 |         # im_label = im_label.resize(self.size, Image.NEAREST)
110 |         im_label = im_label.rotate(random_angle, Image.NEAREST)
111 | 
112 |         label = np.array(im_label)
113 | 
114 |         if np.random.rand(1) <= 0.5: # flip W
115 |             image = np.flip(image, axis=1).copy()
116 |             label = np.flip(label, axis=0).copy()
117 |         if np.random.rand(1) <= 0.5:
118 |             image = np.flip(image, axis=2).copy()
119 |             label = np.flip(label, axis=1).copy()
120 | 
121 |         image = torch.tensor(image).float()
122 |         label = torch.tensor(label).long()
123 | 
124 |         sample = {
125 |             "image": image,
126 |             "label": label,
127 |             "coord_rot": coord_rot,
128 |             "coord_orig": coord_orig,
129 |             "original_size": image.shape[1:],
130 | 
131 |         }
132 |         return sample
133 | 
134 | 
135 | class ValDataSet2D(data.Dataset):
136 |     def __init__(self,
137 |                  root,
138 |                  obj_name,
139 |                  label=255,
140 |                  ref_name="001",
141 |                  n_shot=1,
142 |                  n_zero=3,
143 |                  size=(256,256),
144 |                  transform=None,
145 |                  save_dir="orig_512_seg"):
146 | 
147 |         self.size = size
148 |         self.root = root
149 |         self.label = label
150 |         self.obj_name = obj_name
151 |         self.save_dir = save_dir
152 |         self.transform = transform
153 |         self.files = glob(f"{root}/orig_512/{obj_name}/*/*/*.png")
154 |         self.files.sort()
155 |         print('{} images are loaded!'.format(len(self.files)))
156 | 
157 |     def __len__(self):
158 |         return len(self.files)
159 | 
160 |     def __getitem__(self, index):
161 |         fpath = self.files[index]
162 |         lpath = self.files[index].replace("Images","Annotations")
163 | 
164 |         # read the image file
165 |         try:
166 |             im = Image.open(fpath).convert("RGB")#.filter(ImageFilter.BLUR)
167 |         except FileNotFoundError:
168 |             im = Image.open(fpath.replace(".png",".jpg")).convert("RGB")
169 | 
170 |         # im = im.resize(self.size)
171 |         image = np.array(im)/255
172 |         ## coordinate vector
173 |         w,h,_ = image.shape
174 |         x = np.linspace(0,1,w)
175 |         y = np.linspace(0,1,h)
176 |         xx,yy = np.meshgrid(x,y)
177 |         coord = np.stack([xx,yy,np.zeros_like(yy)],axis=2)*255
178 |         coord_im = Image.fromarray(coord.astype(np.uint8))
179 |         coord = np.array(coord_im)[:,:,:2]/255
180 |         coord = np.swapaxes(coord, 1, 2)
181 |         coord = np.swapaxes(coord, 0, 1) # [2,w,h]
182 |         coord_orig = torch.tensor(coord).float()*30
183 | 
184 | 
185 |         image = image - imagenet_mean
186 |         image = image / imagenet_std
187 |         image = np.swapaxes(image, 1, 2)
188 |         image = np.swapaxes(image, 0, 1) # [3,w,h]
189 |         opath = fpath.replace("orig_512",self.save_dir)
190 | 
191 |         image = torch.tensor(image).float()
192 |         sample = {
193 |             "image": image,
194 |             "opath": opath,
195 |             "coord_orig": coord_orig,
196 | "original_size": image.shape[1:], 197 | 198 | } 199 | return sample 200 | # return image, np.zeros_like(image)#, coord_orig 201 | 202 | def my_collate(batch): 203 | return batch -------------------------------------------------------------------------------- /patchcore/components/base/anomaly_module.py: -------------------------------------------------------------------------------- 1 | """Base Anomaly Module for Training Task.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | import pdb 6 | import logging 7 | from abc import ABC 8 | from typing import Any, List, Optional 9 | 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.callbacks.base import Callback 12 | from torch import Tensor, nn 13 | 14 | from metrics import ( 15 | AdaptiveThreshold, 16 | AnomalibMetricCollection, 17 | AnomalyScoreDistribution, 18 | MinMax, 19 | ) 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | class AnomalyModule(pl.LightningModule, ABC): 24 | """AnomalyModule to train, validate, predict and test images. 25 | 26 | Acts as a base class for all the Anomaly Modules in the library. 27 | """ 28 | 29 | def __init__(self): 30 | super().__init__() 31 | logger.info("Initializing %s model.", self.__class__.__name__) 32 | 33 | self.save_hyperparameters() 34 | self.model: nn.Module 35 | self.loss: Tensor 36 | self.callbacks: List[Callback] 37 | 38 | self.adaptive_threshold: bool 39 | 40 | self.image_threshold = AdaptiveThreshold() #.cpu() 41 | self.pixel_threshold = AdaptiveThreshold() #.cpu() 42 | 43 | self.training_distribution = AnomalyScoreDistribution() #.cpu() 44 | self.min_max = MinMax() #.cpu() 45 | 46 | # Create placeholders for image and pixel metrics. 47 | # If set from the config file, MetricsConfigurationCallback will 48 | # create the metric collections upon setup. 49 | 50 | self.image_metrics: AnomalibMetricCollection 51 | self.pixel_metrics: AnomalibMetricCollection 52 | 53 | def forward(self, batch): # pylint: disable=arguments-differ 54 | """Forward-pass input tensor to the module. 55 | 56 | Args: 57 | batch (Tensor): Input Tensor 58 | 59 | Returns: 60 | Tensor: Output tensor from the model. 61 | """ 62 | return self.model(batch) 63 | 64 | def validation_step(self, batch, batch_idx) -> dict: # type: ignore # pylint: disable=arguments-differ 65 | """To be implemented in the subclasses.""" 66 | raise NotImplementedError 67 | 68 | def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any: 69 | """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. 70 | 71 | By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. 72 | Override to add any processing logic. 73 | 74 | Args: 75 | batch (Tensor): Current batch 76 | batch_idx (int): Index of current batch 77 | _dataloader_idx (int): Index of the current dataloader 78 | 79 | Return: 80 | Predicted output 81 | """ 82 | outputs = self.validation_step(batch, batch_idx) 83 | self._post_process(outputs) 84 | outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value 85 | if "anomaly_maps" in outputs.keys(): 86 | outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value 87 | return outputs 88 | 89 | def test_step(self, batch, _): # pylint: disable=arguments-differ 90 | """Calls validation_step for anomaly map/score calculation. 91 | 92 | Args: 93 | batch (Tensor): Input batch 94 | _: Index of the batch. 
95 | 96 | Returns: 97 | Dictionary containing images, features, true labels and masks. 98 | These are required in `validation_epoch_end` for feature concatenation. 99 | """ 100 | return self.predict_step(batch, _) 101 | 102 | def validation_step_end(self, val_step_outputs): # pylint: disable=arguments-differ 103 | """Called at the end of each validation step.""" 104 | # self._outputs_to_cpu(val_step_outputs) 105 | self._post_process(val_step_outputs) 106 | return val_step_outputs 107 | 108 | def test_step_end(self, test_step_outputs): # pylint: disable=arguments-differ 109 | """Called at the end of each test step.""" 110 | # self._outputs_to_cpu(test_step_outputs) 111 | self._post_process(test_step_outputs) 112 | return test_step_outputs 113 | 114 | def validation_epoch_end(self, outputs): 115 | """Compute threshold and performance metrics. 116 | 117 | Args: 118 | outputs: Batch of outputs from the validation step 119 | """ 120 | if self.adaptive_threshold: 121 | self._compute_adaptive_threshold(outputs) 122 | self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs) 123 | self._log_metrics() 124 | 125 | def test_epoch_end(self, outputs): 126 | """Compute and save anomaly scores of the test set. 127 | 128 | Args: 129 | outputs: Batch of outputs from the validation step 130 | """ 131 | self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs) 132 | self.results = {} 133 | img_result = self.image_metrics.compute() 134 | for k, v in img_result.items(): 135 | self.results[k] = v 136 | pix_result = self.pixel_metrics.compute() 137 | for k, v in pix_result.items(): 138 | self.results[k] = v 139 | self._log_metrics() 140 | 141 | def _compute_adaptive_threshold(self, outputs): 142 | self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs) 143 | img_thresh = self.image_threshold.compute() 144 | self.image_threshold.value = img_thresh 145 | if "mask" in outputs[0].keys() and "anomaly_maps" in outputs[0].keys(): 146 | pix_thresh = self.pixel_threshold.compute() 147 | self.pixel_threshold.value = pix_thresh 148 | else: 149 | # self.pixel_threshold.value = self.image_threshold.value 150 | pix_thresh = img_thresh 151 | self.pixel_threshold.value = pix_thresh 152 | 153 | def _collect_outputs(self, image_metric, pixel_metric, outputs): 154 | for output in outputs: 155 | # image_metric.cpu() 156 | image_metric.update(output["pred_scores"], output["label"].int()) 157 | if "mask" in output.keys() and "anomaly_maps" in output.keys(): 158 | # pixel_metric.cpu() 159 | pixel_metric.update(output["anomaly_maps"], output["mask"].int()) 160 | 161 | def _post_process(self, outputs): 162 | """Compute labels based on model predictions.""" 163 | if "pred_scores" not in outputs and "anomaly_maps" in outputs: 164 | outputs["pred_scores"] = ( 165 | outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values 166 | ) 167 | 168 | def _outputs_to_cpu(self, output): 169 | # for output in outputs: 170 | for key, value in output.items(): 171 | if isinstance(value, Tensor): 172 | output[key] = value #.cpu() 173 | 174 | def _log_metrics(self): 175 | """Log computed performance metrics.""" 176 | if self.pixel_metrics.update_called: 177 | self.log_dict(self.pixel_metrics, prog_bar=True) 178 | self.log_dict(self.image_metrics, prog_bar=False) 179 | else: 180 | self.log_dict(self.image_metrics, prog_bar=True) 181 | -------------------------------------------------------------------------------- /patchcore/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from sklearn.metrics import roc_auc_score 4 | import matplotlib.pyplot as plt 5 | import cv2 6 | from PIL import Image 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from torchvision.transforms.functional import to_pil_image 13 | from torch.utils.data import DataLoader 14 | import csv 15 | 16 | def feature_extraction(model, train_datasets, args): 17 | train_datasets.unlabeled = False 18 | train_loader = DataLoader(train_datasets, batch_size=1, shuffle=False, num_workers=4) 19 | extracted_features = [] 20 | segmentation = [] 21 | labels = [] 22 | for idx, batch in enumerate(train_loader): 23 | print(f'Extract features {idx+1}/{len(train_loader)}', end='\r') 24 | with torch.no_grad(): 25 | img = batch[0].cuda() 26 | embed = model.model(img) 27 | embed = embed.reshape(img.shape[0], int(args.size/8), int(args.size/8), -1) 28 | extracted_features.append(embed.detach().cpu()) 29 | segmentation.append(batch[1]) 30 | labels.append(batch[2]) 31 | print('') 32 | 33 | extracted_features = torch.cat(extracted_features, dim=0) 34 | segmentation = torch.cat(segmentation, dim=0) 35 | labels = torch.cat(labels, dim=0) 36 | 37 | return extracted_features, segmentation, labels 38 | 39 | def compute_anomaly_score(test_datasets, model, args): 40 | test_datasets.unlabeled = False 41 | test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False, num_workers=4) 42 | outputs = [] 43 | for idx, batch in enumerate(test_loader): 44 | print(f'Extract features from testset - {idx + 1}/{len(test_loader)}', end='\r') 45 | img = batch[0].cuda() 46 | anomaly_map, anomaly_score, _ = model.model(img) 47 | output = {} 48 | output['anomaly_scores'] = anomaly_score.detach().cpu() 49 | output['label'] = batch[2] 50 | outputs.append(output) 51 | 52 | '''Visualize''' 53 | # fig = plt.figure(figsize=(20, 5)) 54 | # invnorm_img = (img[0].detach().cpu() * torch.Tensor(args.std_train).reshape(3, 1, 1)) + torch.Tensor(args.mean_train).reshape(3, 1, 1) 55 | # plt.subplot(1, 4, 1) 56 | # plt.title(str(anomaly_score.item())) 57 | # plt.imshow(to_pil_image(invnorm_img)) 58 | # plt.axis('off') 59 | # 60 | # plt.subplot(1, 4, 2) 61 | # anomaly_map = anomaly_map.clone().squeeze().detach().cpu().numpy() 62 | # anomaly_map_ = anomaly_map - np.min(anomaly_map) 63 | # anomaly_map = anomaly_map_ / np.max(anomaly_map) 64 | # anomaly_map = np.uint8(255 * anomaly_map) 65 | # heatmap = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET) 66 | # heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 67 | # plt.imshow(heatmap) 68 | # plt.axis('off') 69 | # 70 | # plt.subplot(1, 4, 3) 71 | # plt.imshow(np.uint8(heatmap * 0.3 + np.transpose(invnorm_img.numpy(), (1, 2, 0)) * 255 * 0.5)) 72 | # plt.axis('off') 73 | # 74 | # plt.subplot(1, 4, 4) 75 | # mask = batch[1].clone().squeeze().numpy() 76 | # mask = np.uint8(255 * mask) 77 | # mask = Image.fromarray(mask) 78 | # plt.imshow(mask) 79 | # plt.axis('off') 80 | # 81 | # plt.savefig(os.path.join(args.result_path, 'prediction', 'result-' + batch[4][0] + '-' + batch[3][0] + '.png')) 82 | # plt.clf() 83 | # plt.close() 84 | 85 | print('') 86 | 87 | anomaly_scores = [] 88 | labels = [] 89 | for output in outputs: 90 | anomaly_scores.append(output['anomaly_scores'].item()) 91 | labels.append(output['label'].item()) 92 | 93 | anomaly_scores = np.array(anomaly_scores) 94 | labels = np.array(labels) 95 | img_auc 
= roc_auc_score(labels, anomaly_scores) 96 | print("Image AUROC: ", img_auc) 97 | 98 | f = open(os.path.join(args.result_path, 'result.txt'), 'w') 99 | f.write(f'Image AUROC: {img_auc}\n') 100 | f.close() 101 | 102 | def compute_anomaly_score_standardization(test_datasets, model, mean, std, is_training, args): 103 | test_datasets.unlabeled = False 104 | test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False, num_workers=4) 105 | outputs = [] 106 | original_memory_bank = model.model.memory_bank 107 | for idx, batch in enumerate(test_loader): 108 | if is_training: 109 | model.model.memory_bank = torch.cat([original_memory_bank[:idx * int(args.size/8) * int(args.size/8)], 110 | original_memory_bank[(idx+1) * int(args.size/8) * int(args.size/8):]], 111 | dim = 0) 112 | print(f'Extract features from testset - {idx + 1}/{len(test_loader)}', end='\r') 113 | img = batch[0].cuda() 114 | anomaly_map, anomaly_score, _ = model.model(img, mean, std) 115 | output = {} 116 | output['distance'] = model.model.patch_scores 117 | output['anomaly_maps_interpolate'] = anomaly_map.detach().cpu() 118 | output['anomaly_maps'] = _.detach().cpu() 119 | output['anomaly_scores'] = anomaly_score.detach().cpu() 120 | output['label'] = batch[2] 121 | output['name'] = batch[3][0] 122 | 123 | os.makedirs(os.path.join(args.result_path, batch[3][0].split('/')[0], batch[3][0].split('/')[1]), exist_ok=True) 124 | torch.save(output, os.path.join(args.result_path, batch[3][0] + '.pt')) 125 | del output['anomaly_maps'] 126 | del output['anomaly_maps_interpolate'] 127 | outputs.append(output) 128 | print('') 129 | 130 | if not is_training: 131 | if args.dataset == 'mvtec_loco': 132 | anomal_type = ['both', 'logical', 'structural'] 133 | else: 134 | anomal_type = ['both'] 135 | for i in anomal_type: 136 | if i == 'both': 137 | anomaly_scores = [] 138 | labels = [] 139 | f = open(os.path.join(args.result_path, 'ADscore'+'_all.txt'), 'w') 140 | wr = csv.writer(f) 141 | wr.writerow(['Name', 'Score', 'Label']) 142 | for output in outputs: 143 | anomaly_scores.append(output['anomaly_scores'].item()) 144 | labels.append(output['label'].item()) 145 | wr.writerow([output['name'], output['anomaly_scores'].item(), output['label'].item()]) 146 | f.close() 147 | 148 | anomaly_scores = np.array(anomaly_scores) 149 | labels = np.array(labels) 150 | img_auc = roc_auc_score(labels, anomaly_scores) 151 | print("Image AUROC: ", img_auc) 152 | 153 | f = open(os.path.join(args.result_path, 'result'+'_all.txt'), 'w') 154 | f.write(f'Image AUROC: {img_auc}\n') 155 | f.close() 156 | else: 157 | anomaly_scores = [] 158 | labels = [] 159 | f = open(os.path.join(args.result_path, 'ADscore_'+i+ '.txt'), 'w') 160 | wr = csv.writer(f) 161 | wr.writerow(['Name', 'Score', 'Label']) 162 | for output in outputs: 163 | if 'good' in output['name'] or i in output['name']: 164 | anomaly_scores.append(output['anomaly_scores'].item()) 165 | labels.append(output['label'].item()) 166 | wr.writerow([output['name'], output['anomaly_scores'].item(), output['label'].item()]) 167 | f.close() 168 | 169 | anomaly_scores = np.array(anomaly_scores) 170 | labels = np.array(labels) 171 | img_auc = roc_auc_score(labels, anomaly_scores) 172 | print("Image AUROC: ", img_auc) 173 | 174 | f = open(os.path.join(args.result_path, 'result_'+i+'.txt'), 'w') 175 | f.write(f'Image AUROC: {img_auc}\n') 176 | f.close() -------------------------------------------------------------------------------- /patchcore/run.py: 
-------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | import os 3 | 4 | parser = ArgumentParser() 5 | parser.add_argument("--gpu", type=str, default='0') 6 | parser.add_argument("--dataset", type=str, default = 'mvtec_loco', choices=['mvtec_ad', 'mpdd', 'mtd', 'mvtec_loco', 'visa']) 7 | parser.add_argument("--mode", type=str, default='train', choices=['train', 'test', 'standardization']) 8 | parser.add_argument("--category", type=str) 9 | parser.add_argument("--size", type=int, default=256) #512 10 | parser.add_argument("--coreset_sampling_ratio", type=float, default=1.0) 11 | 12 | parser.add_argument("--datapath", type=str, default = '/media/NAS/nas_187/datasets/MVTec_AD') #'/media/NAS/nas_187/soopil/data/stanford/LOCO_AD_pre' 13 | parser.add_argument("--result_path", type=str, default='./result') 14 | parser.add_argument("--backbone", type=str, default = 'wide_resnet101_2', choices = ['resnet18', 'resnet50', 'wide_resnet50_2', 'wide_resnet101_2']) 15 | 16 | args = parser.parse_args() 17 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 18 | 19 | import torch 20 | import torch.multiprocessing 21 | from torch.utils.data import DataLoader 22 | from lightning_model import Patchcore 23 | from torchvision import transforms 24 | 25 | from PIL import Image 26 | import warnings 27 | import random 28 | import numpy as np 29 | 30 | from utils.dataset import MVTecLOCODataset, MVTecADDataset, VisADataset 31 | from utils.utils import compute_anomaly_score, compute_anomaly_score_standardization, feature_extraction 32 | 33 | seed = 2 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | warnings.filterwarnings("ignore") 42 | torch.multiprocessing.set_sharing_strategy('file_system') 43 | 44 | result_path = os.path.join(args.result_path, args.dataset, args.backbone, args.category) 45 | if args.size != 256: 46 | result_path = result_path.replace(args.category, args.category+'_'+str(args.size)) 47 | args.result_path = result_path 48 | print(result_path) 49 | 50 | if not os.path.exists(result_path): 51 | os.makedirs(result_path) 52 | 53 | '''Load Dataset''' 54 | args.mean_train = [0.485, 0.456, 0.406] 55 | args.std_train = [0.229, 0.224, 0.225] 56 | inv_normalize = transforms.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 57 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225]) 58 | data_transforms = transforms.Compose([ 59 | transforms.Resize((args.size, args.size), Image.ANTIALIAS), 60 | transforms.ToTensor(), 61 | transforms.CenterCrop(args.size), 62 | transforms.Normalize(mean=args.mean_train, std=args.std_train)]) 63 | gt_transforms = transforms.Compose([ 64 | transforms.Resize((args.size, args.size)), 65 | transforms.ToTensor(), 66 | transforms.CenterCrop(args.size)]) 67 | 68 | if args.dataset == 'mvtec_loco': 69 | args.datapath = '/media/NAS/nas_187/soopil/data/stanford/LOCO_AD_pre' 70 | train_datasets = MVTecLOCODataset(root=args.datapath, 71 | transform=data_transforms, 72 | phase='train', 73 | args=args) 74 | test_datasets = MVTecLOCODataset(root=args.datapath, 75 | transform=data_transforms, 76 | phase='test', 77 | args=args, 78 | anomal_type = None) 79 | elif args.dataset == 'mvtec_ad': 80 | args.datapath = '/media/NAS/nas_187/datasets/MVTec_AD' 81 | train_datasets = MVTecADDataset(root=args.datapath, 82 | 
transform=data_transforms, 83 | phase='train', 84 | args=args) 85 | test_datasets = MVTecADDataset(root=args.datapath, 86 | transform=data_transforms, 87 | phase='test', 88 | args=args, 89 | anomal_type = None) 90 | 91 | elif args.dataset == 'visa': 92 | args.datapath = '/media/NAS/nas_187/datasets/VisA/split/1cls' 93 | train_datasets = VisADataset(root=args.datapath, 94 | transform=data_transforms, 95 | phase='train', 96 | args=args) 97 | test_datasets = VisADataset(root=args.datapath, 98 | transform=data_transforms, 99 | phase='test', 100 | args=args, 101 | anomal_type = None) 102 | 103 | train_loader = DataLoader(train_datasets, batch_size=1, shuffle=False, num_workers=4) 104 | test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False, num_workers=4) 105 | 106 | '''Load pre-trained model''' 107 | if args.backbone == 'resnet18': 108 | args.indim = 384 109 | args.fm = 1 110 | elif args.backbone == 'resnet50': 111 | args.indim = 1536 112 | args.fm = 4 113 | elif args.backbone == 'wide_resnet50_2': 114 | args.indim = 1536 115 | args.fm = 4 116 | elif args.backbone == 'wide_resnet101_2': 117 | args.indim = 1536 118 | args.fm = 4 119 | 120 | model = Patchcore(input_size=(args.size, args.size), backbone=args.backbone, layers=['layer2', 'layer3']).cuda() 121 | model.model.feature_extractor.eval() 122 | 123 | if args.mode == 'train': 124 | if args.coreset_sampling_ratio != 1 and os.path.exists(os.path.join(args.result_path, args.backbone+'_'+str(args.coreset_sampling_ratio)+'.pt')): 125 | extracted_features = torch.load(os.path.join(args.result_path, args.backbone+'_'+str(args.coreset_sampling_ratio)+'.pt')) 126 | model.model.memory_bank = extracted_features 127 | else: 128 | '''Training''' 129 | extracted_features, segmentation, labels = feature_extraction(model, train_datasets, args) 130 | 131 | '''Build memory bank''' 132 | model.model.subsample_embedding(extracted_features.reshape(-1, extracted_features.shape[-1]), args.coreset_sampling_ratio) 133 | torch.save(model.model.memory_bank, os.path.join(args.result_path, args.backbone+'_'+str(args.coreset_sampling_ratio)+'.pt')) 134 | 135 | print(f'Memory bank : {model.model.memory_bank.shape}') 136 | 137 | # if not os.path.exists(os.path.join(result_path, 'prediction')): 138 | # os.makedirs(os.path.join(result_path, 'prediction')) 139 | 140 | '''Compute anomaly score''' 141 | model.model.feature_extractor.eval() 142 | with torch.no_grad(): 143 | model.model.training = False 144 | compute_anomaly_score(test_datasets, model, args) 145 | 146 | elif args.mode == 'standardization': 147 | '''Training''' 148 | extracted_features, segmentation, labels = feature_extraction(model, train_datasets, args) 149 | 150 | '''Build memory bank''' 151 | model.model.subsample_embedding(extracted_features.reshape(-1, extracted_features.shape[-1]), 152 | args.coreset_sampling_ratio) 153 | torch.save(model.model.memory_bank, 154 | os.path.join(args.result_path, 'memory_bank_' + str(args.coreset_sampling_ratio) + '.pt')) 155 | 156 | '''Standardization''' 157 | memory_bank = model.model.memory_bank # N, 1536 158 | mean = torch.mean(memory_bank, dim = 0) 159 | std = torch.std(memory_bank, dim = 0) 160 | 161 | '''For original patchcore''' 162 | mean = torch.zeros_like(mean) 163 | std = torch.ones_like(std) 164 | 165 | memory_bank = (memory_bank - mean) / std 166 | # torch.save(model.model.memory_bank, os.path.join(args.result_path, 'memory_bank_std' + str(args.coreset_sampling_ratio) + '.pt')) 167 | 168 | model.model.memory_bank = memory_bank 169 | print(f'Memory 
bank : {model.model.memory_bank.shape}') 170 | 171 | '''Compute distance for train set''' 172 | model.model.feature_extractor.eval() 173 | with torch.no_grad(): 174 | model.model.training = False 175 | compute_anomaly_score_standardization(train_datasets, model, mean, std, True, args) 176 | 177 | '''Compute anomaly score''' 178 | model.model.feature_extractor.eval() 179 | with torch.no_grad(): 180 | model.model.training = False 181 | compute_anomaly_score_standardization(test_datasets, model, mean, std, False, args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PSAD: Few Shot Part Segmentation Reveals Compositional Logic for Industrial Anomaly Detection 2 | 3 | This repository includes a PyTorch implementation of the paper "Few Shot Part Segmentation Reveals Compositional Logic for Industrial Anomaly Detection" accepted at [AAAI2024](https://ojs.aaai.org/index.php/AAAI/article/view/28703). 4 | 5 | ## Abstract 6 | 7 | Logical anomalies (LA) refer to data violating underlying logical constraints, e.g., the quantity, arrangement, or composition of components within an image. Accurately detecting such anomalies requires models to reason about various component types through segmentation. However, curation of pixel-level annotations for semantic segmentation is both time-consuming and expensive. Although there are some prior few-shot or unsupervised co-part segmentation algorithms, they often fail on images of industrial objects. These images have components with similar textures and shapes, and precise differentiation proves challenging. In this study, we introduce a novel component segmentation model for LA detection that leverages a few labeled samples and unlabeled images sharing logical constraints. To ensure consistent segmentation across unlabeled images, we employ a histogram matching loss in conjunction with an entropy loss. As segmentation predictions play a crucial role, we propose to enhance both local and global sample validity detection by capturing key aspects from visual semantics via three memory banks: class histograms, component composition embeddings, and patch-level representations. For effective LA detection, we propose an adaptive scaling strategy to standardize anomaly scores from different memory banks at inference. Extensive experiments on the public benchmark MVTec LOCO AD reveal our method achieves 98.1% AUROC in LA detection vs. 89.6% from competing methods. 8 | 9 | ## Few Shot Part Segmentation 10 | ![fig_fss](https://github.com/oopil/PSAD_logical_anomaly_detection/assets/44998223/68056d95-f62e-438c-804b-e3c9b001e018) 11 | 12 | ## PSAD 13 | ![fig_psad](https://github.com/oopil/PSAD_logical_anomaly_detection/assets/44998223/8e0193f3-d713-4c11-b43a-c14163ffb99f) 14 | 15 | ## Dataset 16 | - We use the MVTec LOCO AD dataset for evaluation. For details of this dataset, please see [here](https://www.mvtec.com/company/research/datasets/mvtec-loco). 17 | 18 | ## Training 19 | To implement our proposed Part Segmentation-based Anomaly Detection (PSAD), follow the procedure below. If you want to skip the preprocessing and segmentation steps 1~4, download the [preprocessed images](https://drive.google.com/file/d/1uE7pXn6XwGHxiHhNv7iCb-gj62PC-HkM/view?usp=sharing) and [segmentation maps](https://drive.google.com/file/d/1Esa06exQ2cH3c3GIozpdU2-qmwKiVBdr/view?usp=sharing) and use them as inputs. 20 | 21 | 1. Annotate a few images per category. 
22 | 23 | We used an [annotation tool](https://www.makesense.ai/) to make annotations for a few images. In most cases, we used 5 labeled images. But when there are multiple product types (e.g., juice bottle and splicing connectors), we used 1 labeled image per type, i.e., 3 images in total. 24 | 25 | 2. Preprocess images. 26 | 27 | Convert the label format, then crop and resize the images. Please check the code in the `preprocess` directory. 28 | 29 | 3. Train a few shot part segmentation model using a few labeled and numerous unlabeled images. Using different levels of features from the pretrained encoder can lead to different segmentation results. 30 | ``` 31 | CUDA_VISIBLE_DEVICES=[GPU_ID] python finetune_cnn_coord.py 32 | --n_shot 5 33 | --num_epochs 100 34 | --obj_name [OBJ_NAME] 35 | --snapshot_dir [DIR] 36 | ``` 37 | 38 | 4. Train a segmentation model again using the predicted pseudo labels. 39 | ``` 40 | run_unet.sh [GPU_ID] [DIR] 41 | ``` 42 | 43 | 5. Save predictions of [PatchCore](https://github.com/amazon-science/patchcore-inspection). Please check the `patchcore` directory. Or you can use the [scores](https://drive.google.com/file/d/1Q8RVR8rDV6oOMhRa_8fEYBM9OVQIn2eM/view?usp=drive_link) already obtained using PatchCore. They include the anomaly scores of all train and test data predicted using PatchCore without coreset sampling, as well as the scores of the train data obtained by treating each sample as test data, which are used for 'adaptive scaling'. Each .pt file (e.g., train/good/000.pt) holds a dictionary with six keys: ['distance', 'anomaly_maps_interpolate', 'anomaly_maps', 'anomaly_scores', 'label', 'name']. 'anomaly_scores' is the score used for the adaptive scaling. 44 | 45 | 6. Train and test PSAD based on the segmentation results. You can choose the memory type, e.g., "hcp", where h, c, and p denote the histogram, composition, and patch memory banks. 46 | ``` 47 | ./run_ad.sh [GPU_ID] [DIR] "hcp" "max" 48 | ``` 49 | 50 | ## Testing 51 | Our proposed PSAD and the comparison methods are evaluated using image AUROC scores. As there is no ground truth for the segmentation task, we evaluate segmentation quality indirectly through anomaly detection performance; in general, more accurate segmentation correlates with better anomaly detection.
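For reference, the image AUROC can be recomputed from the score dictionaries saved in step 5; a minimal sketch, assuming the `.pt` files are stored under a `result/` tree (the glob pattern is a placeholder for the actual output layout):

```
import glob
import torch
from sklearn.metrics import roc_auc_score

# Gather per-image anomaly scores and binary labels saved by the PatchCore step.
scores, labels = [], []
for path in glob.glob('result/**/test/**/*.pt', recursive=True):  # placeholder layout
    output = torch.load(path)
    scores.append(output['anomaly_scores'].item())
    labels.append(output['label'].item())

# Image-level AUROC over normal (0) vs. anomalous (1) test images.
print('Image AUROC:', roc_auc_score(labels, scores))
```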
52 | 53 | ## Results 54 | - Comparison with state-of-the-art methods 55 | 56 | | Type | Category | PatchCore | RD4AD | DRAEM | ST | AST | GCAD | SINBAD | ComAD | SLSG | PSAD | 57 | |:--------:|:---------------------:|:---------:|:-----:|:-----:|:-----:|:-----:|:------:|:------:|:-----:|:-----:|:------:| 58 | | LA | Breakfast Box | 74.8 | 66.7 | 75.1 | 68.9 | 80.0 | 87.0 | 96.5 | 91.1 | - | 100.0 | 59 | | | Juice Bottle | 93.9 | 93.6 | 97.8 | 82.9 | 91.6 | 100.0 | 96.6 | 95.0 | - | 99.1 | 60 | | | Pushpins | 63.6 | 63.6 | 55.7 | 59.5 | 65.1 | 97.5 | 83.4 | 95.7 | - | 100.0 | 61 | | | Screw Bag | 57.8 | 54.1 | 56.2 | 55.5 | 80.1 | 56.0 | 78.6 | 71.9 | - | 99.3 | 62 | | | Splicing Connectors | 79.2 | 75.3 | 75.2 | 65.4 | 81.8 | 89.7 | 89.3 | 93.3 | - | 91.9 | 63 | | | Average (LA) | 74.0 | 70.7 | 72.0 | 66.4 | 79.7 | 86.0 | 88.9 | 89.4 | 89.6 | 98.1 | 64 | | SA | Breakfast Box | 80.1 | 60.3 | 85.4 | 68.4 | 79.9 | 80.9 | 87.5 | 81.6 | - | 84.9 | 65 | | | Juice Bottle | 98.5 | 95.2 | 90.8 | 99.3 | 95.5 | 98.9 | 93.1 | 98.2 | - | 98.2 | 66 | | | Pushpins | 87.9 | 84.8 | 81.5 | 90.3 | 77.8 | 74.9 | 74.2 | 91.1 | - | 89.8 | 67 | | | Screw Bag | 92.0 | 89.2 | 85.0 | 87.0 | 95.9 | 70.5 | 92.2 | 88.5 | - | 95.7 | 68 | | | Splicing Connectors | 88.0 | 95.9 | 95.5 | 96.8 | 89.4 | 78.3 | 76.7 | 94.9 | - | 89.3 | 69 | | | Average (SA) | 89.3 | 85.1 | 87.6 | 88.4 | 87.7 | 80.7 | 84.7 | 90.9 | 91.4 | 91.6 | 70 | | Average | | 81.7 | 77.9 | 79.8 | 77.4 | 83.7 | 83.4 | 86.8 | 90.1 | 90.3 | 94.0 | 71 | 72 | 73 | - PSAD performance depending on segmentation model 74 | 75 | | Models | LA | SA | 76 | |-----------------------------------|------|------| 77 | | SCOPS (Hung et al. 2019) | 82.5 | 90.2 | 78 | | Part-Assembly (Gao et al. 2021) | 80.3 | 85.6 | 79 | | SegGPT (Wang et al. 2023) | 88.7 | 87.2 | 80 | | VAT (Hong et al. 2022) | 79.2 | 87.8 | 81 | | RePRI (Boudiaf et al. 2021) | 83.6 | 88.4 | 82 | | Ours (L_sup) | 95.9 | 89.6 | 83 | | Ours (L_sup + L_H) | 96.3 | 90.0 | 84 | | Ours (L_sup + L_H + L_hist) | 98.1 | 91.6 | 85 | 86 | - Qualitative evaluation on FSS models 87 | ![fss_seg](https://github.com/oopil/PSAD_logical_anomaly_detection/assets/44998223/6cb07231-d4d3-4dff-a576-13743008ab38) 88 | 89 | - Ablation study on multiple memory banks and adaptive scaling (AS) 90 | 91 | | M_hist | M_comp | M_patch | AS | LA | SA | 92 | |:------:|:------:|:-------:|:--:|:----:|:----:| 93 | | ✓ | | | | 94.2 | 71.1 | 94 | | | ✓ | | | 90.9 | 85.4 | 95 | | | | ✓ | | 73.9 | 89.3 | 96 | | ✓ | ✓ | ✓ | | 96.8 | 87.6 | 97 | | ✓ | ✓ | ✓ | ✓ | 98.1 | 91.6 | 98 | 99 | - PSAD performance using fewer normal images (the same segmentation model is used) 100 | 101 | | N_M | 100% | 50% | 25% | 12.5% | 102 | |:---------:|:----:|:----:|:----:|:-----:| 103 | | Avg AUROC | 97.4 | 97.1 | 96.6 | 96.2 | 104 | 105 | - Visualizing histograms from different memory banks 106 | ![qual_hist](https://github.com/oopil/PSAD_logical_anomaly_detection/assets/44998223/d299e1ac-6683-42f6-b446-9835adbe01d2) 107 | 108 | 112 | ## Acknowledgments 113 | Our work was inspired by many previous works related to industrial anomaly detection and few-shot segmentation, including [PatchCore](https://github.com/amazon-science/patchcore-inspection) and [RePRI](https://github.com/mboudiaf/RePRI-for-Few-Shot-Segmentation/tree/master). Thanks to the authors for their inspiring works. 
114 | -------------------------------------------------------------------------------- /patchcore/metrics/aupro.py: -------------------------------------------------------------------------------- 1 | """Implementation of AUPRO score based on TorchMetrics.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Any, Callable, List, Optional, Tuple 7 | 8 | import torch 9 | from kornia.contrib import connected_components 10 | from matplotlib.figure import Figure 11 | from torch import Tensor 12 | from torchmetrics import Metric 13 | from torchmetrics.functional import auc, roc 14 | from torchmetrics.utilities.data import dim_zero_cat 15 | 16 | from .plotting_utils import plot_figure 17 | 18 | 19 | class AUPRO(Metric): 20 | """Area under per region overlap (AUPRO) Metric.""" 21 | 22 | is_differentiable: bool = False 23 | higher_is_better: Optional[bool] = None 24 | full_state_update: bool = False 25 | preds: List[Tensor] 26 | target: List[Tensor] 27 | 28 | def __init__( 29 | self, 30 | compute_on_step: bool = True, 31 | dist_sync_on_step: bool = False, 32 | process_group: Optional[Any] = None, 33 | dist_sync_fn: Callable = None, 34 | fpr_limit: float = 0.3, 35 | ) -> None: 36 | super().__init__( 37 | compute_on_step=compute_on_step, 38 | dist_sync_on_step=dist_sync_on_step, 39 | process_group=process_group, 40 | dist_sync_fn=dist_sync_fn, 41 | ) 42 | 43 | self.add_state("preds", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable 44 | self.add_state("target", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable 45 | self.register_buffer("fpr_limit", torch.tensor(fpr_limit)) 46 | 47 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 48 | """Update state with new values. 49 | 50 | Args: 51 | preds (Tensor): predictions of the model 52 | target (Tensor): ground truth targets 53 | """ 54 | self.target.append(target) 55 | self.preds.append(preds) 56 | 57 | def _compute(self) -> Tuple[Tensor, Tensor]: 58 | """Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit. 59 | 60 | It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall 61 | PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction. 62 | 63 | Raises: 64 | ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for 65 | connected component analysis. 66 | 67 | Returns: 68 | Tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. 69 | """ 70 | target = dim_zero_cat(self.target) 71 | preds = dim_zero_cat(self.preds) 72 | 73 | # check and prepare target for labeling via kornia 74 | if target.min() < 0 or target.max() > 1: 75 | raise ValueError( 76 | ( 77 | f"kornia.contrib.connected_components expects input to lie in the interval [0, 1], but found " 78 | f"interval was [{target.min()}, {target.max()}]." 79 | ) 80 | ) 81 | target = target.unsqueeze(1) # kornia expects N1HW format 82 | target = target.type(torch.float) # kornia expects FloatTensor 83 | cca = connected_components( 84 | target, num_iterations=1000 85 | ) # Need a high iteration count here to avoid oversegmentation. 86 | 87 | preds = preds.flatten() 88 | cca = cca.flatten() 89 | target = target.flatten() 90 | 91 | # compute the global fpr-size 92 | fpr: Tensor = roc(preds, target)[0] # only need fpr 93 | output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) 94 | 95 | # compute the PRO curve by aggregating per-region tpr/fpr curves/values. 
96 | tpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) 97 | fpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) 98 | new_idx = torch.arange(0, output_size, device=preds.device, dtype=torch.float) 99 | 100 | # Loop over the labels, computing per-region tpr/fpr curves, and aggregating them. 101 | # Note that, since the ground truth is different for every call to `roc`, we also get 102 | # different/unique tpr/fpr curves (i.e. len(_fpr_idx) is different for every call). 103 | # We therefore need to resample per-region curves to a fixed sampling ratio (defined above). 104 | labels = cca.unique()[1:] # 0 is background 105 | background = cca == 0 106 | _fpr: Tensor 107 | _tpr: Tensor 108 | for label in labels: 109 | interp: bool = False 110 | new_idx[-1] = output_size - 1 111 | mask = cca == label 112 | # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other 113 | # labels in `labels` as FPs. We also don't need to return the thresholds 114 | _fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1] 115 | 116 | # catch edge-case where ROC only has fpr vals > self.fpr_limit 117 | if _fpr[_fpr <= self.fpr_limit].max() == 0: 118 | _fpr_limit = _fpr[_fpr > self.fpr_limit].min() 119 | else: 120 | _fpr_limit = self.fpr_limit 121 | 122 | _fpr_idx = torch.where(_fpr <= _fpr_limit)[0] 123 | # if computed roc curve is not specified sufficiently close to self.fpr_limit, 124 | # we include the closest higher tpr/fpr pair and linearly interpolate the tpr/fpr point at self.fpr_limit 125 | if not torch.allclose(_fpr[_fpr_idx].max(), self.fpr_limit): 126 | _tmp_idx = torch.searchsorted(_fpr, self.fpr_limit) 127 | _fpr_idx = torch.cat([_fpr_idx, _tmp_idx.unsqueeze_(0)]) 128 | _slope = 1 - ((_fpr[_tmp_idx] - self.fpr_limit) / (_fpr[_tmp_idx] - _fpr[_tmp_idx - 1])) 129 | interp = True 130 | 131 | _fpr = _fpr[_fpr_idx] 132 | _tpr = _tpr[_fpr_idx] 133 | 134 | _fpr_idx = _fpr_idx.float() 135 | _fpr_idx /= _fpr_idx.max() 136 | _fpr_idx *= new_idx.max() 137 | 138 | if interp: 139 | # last point will be sampled at self.fpr_limit 140 | new_idx[-1] = _fpr_idx[-2] + ((_fpr_idx[-1] - _fpr_idx[-2]) * _slope) 141 | 142 | _tpr = self.interp1d(_fpr_idx, _tpr, new_idx) 143 | _fpr = self.interp1d(_fpr_idx, _fpr, new_idx) 144 | tpr += _tpr 145 | fpr += _fpr 146 | 147 | # Actually perform the averaging 148 | tpr /= labels.size(0) 149 | fpr /= labels.size(0) 150 | return fpr, tpr 151 | 152 | def compute(self) -> Tensor: 153 | """First compute the PRO curve, then compute and scale the area under the curve. 154 | 155 | Returns: 156 | Tensor: Value of the AUPRO metric 157 | """ 158 | fpr, tpr = self._compute() 159 | 160 | aupro = auc(fpr, tpr) 161 | aupro = aupro / fpr[-1] # normalize the area 162 | 163 | return aupro 164 | 165 | def generate_figure(self) -> Tuple[Figure, str]: 166 | """Generate a figure containing the PRO curve and the AUPRO. 
167 | 168 | Returns: 169 | Tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging 170 | """ 171 | fpr, tpr = self._compute() 172 | aupro = self.compute() 173 | 174 | xlim = (0.0, self.fpr_limit.detach_().cpu().numpy()) 175 | ylim = (0.0, 1.0) 176 | xlabel = "Global FPR" 177 | ylabel = "Averaged Per-Region TPR" 178 | loc = "lower right" 179 | title = "PRO" 180 | 181 | fig, _axis = plot_figure(fpr, tpr, aupro, xlim, ylim, xlabel, ylabel, loc, title) 182 | 183 | return fig, "PRO" 184 | 185 | @staticmethod 186 | def interp1d(old_x: Tensor, old_y: Tensor, new_x: Tensor) -> Tensor: 187 | """Function to interpolate a 1D signal linearly to new sampling points. 188 | 189 | Args: 190 | old_x (Tensor): original 1-D x values (same size as y) 191 | old_y (Tensor): original 1-D y values (same size as x) 192 | new_x (Tensor): x-values where y should be interpolated at 193 | 194 | Returns: 195 | Tensor: y-values at corresponding new_x values. 196 | """ 197 | 198 | # Compute slope 199 | eps = torch.finfo(old_y.dtype).eps 200 | slope = (old_y[1:] - old_y[:-1]) / (eps + (old_x[1:] - old_x[:-1])) 201 | 202 | # Prepare idx for linear interpolation 203 | idx = torch.searchsorted(old_x, new_x) 204 | 205 | # searchsorted looks for the index where the values must be inserted 206 | # to preserve order, but we actually want the preceding index. 207 | idx -= 1 208 | # we clamp the index, because the number of intervals = old_x.size(0) -1, 209 | # and the left neighbour should hence be at most number of intervals -1, i.e. old_x.size(0) - 2 210 | idx = torch.clamp(idx, 0, old_x.size(0) - 2) 211 | 212 | # perform actual linear interpolation 213 | y_new = old_y[idx] + slope[idx] * (new_x - old_x[idx]) 214 | 215 | return y_new 216 | -------------------------------------------------------------------------------- /patchcore/torch_model.py: -------------------------------------------------------------------------------- 1 | """PyTorch model for the PatchCore model implementation.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from typing import Dict, List, Optional, Tuple, Union 7 | import pdb 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import Tensor, nn 12 | 13 | from components import ( 14 | DynamicBufferModule, 15 | FeatureExtractor, 16 | KCenterGreedy, 17 | ) 18 | from anomaly_map import AnomalyMapGenerator 19 | from pre_processing import Tiler 20 | 21 | class PatchcoreModel(DynamicBufferModule, nn.Module): 22 | """Patchcore Module.""" 23 | 24 | def __init__( 25 | self, 26 | input_size: Tuple[int, int], 27 | layers: List[str], 28 | backbone: str = "wide_resnet50_2", 29 | pre_trained: bool = True, 30 | num_neighbors: int = 9, 31 | ) -> None: 32 | super().__init__() 33 | self.tiler: Optional[Tiler] = None 34 | 35 | self.backbone = backbone 36 | self.layers = layers 37 | self.input_size = input_size 38 | self.num_neighbors = num_neighbors 39 | 40 | self.feature_extractor = FeatureExtractor(backbone=self.backbone, pre_trained=pre_trained, layers=self.layers) 41 | self.w = None 42 | self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) 43 | self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size) 44 | 45 | self.register_buffer("memory_bank", torch.Tensor()) 46 | self.memory_bank: torch.Tensor 47 | 48 | def feature_pooling_multiscale(self, input): 49 | out = [] 50 | for idx, p in enumerate(self.multiple_pooler): 51 | if (idx+1) != len(self.multiple_pooler): 52 | out.append(p(input)) 53 | 
else: 54 | global_feature = p(input) 55 | global_feature = torch.repeat_interleave(global_feature, dim=2, repeats=out[0].shape[2]) 56 | global_feature = torch.repeat_interleave(global_feature, dim=3, repeats=out[0].shape[3]) 57 | out.append(global_feature) 58 | 59 | return torch.cat(out, dim=1) 60 | 61 | def forward(self, input_tensor: Tensor, mean: Tensor = None, std: Tensor = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 62 | """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing. 63 | 64 | Steps performed: 65 | 1. Get features from a CNN. 66 | 2. Generate embedding based on the features. 67 | 3. Compute anomaly map in test mode. 68 | 69 | Args: 70 | input_tensor (Tensor): Input tensor 71 | 72 | Returns: 73 | Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Embedding for training, 74 | anomaly map and anomaly score for testing. 75 | """ 76 | if self.tiler: 77 | input_tensor = self.tiler.tile(input_tensor) 78 | 79 | # with torch.no_grad(): 80 | features = self.feature_extractor(input_tensor) 81 | 82 | """pre-trained correspondence layer""" 83 | if self.w: 84 | features = {layer: self.feature_pooler(self.w[layer](feature)) for layer, feature in features.items()} 85 | else: 86 | features = {layer: self.feature_pooler(feature) for layer, feature in features.items()} 87 | 88 | """original patchcore""" 89 | # features = {layer: self.feature_pooler(feature) for layer, feature in features.items()} 90 | 91 | embedding = self.generate_embedding(features) 92 | 93 | if self.tiler: 94 | embedding = self.tiler.untile(embedding) 95 | 96 | feature_map_shape = embedding.shape[-2:] 97 | embedding = self.reshape_embedding(embedding) 98 | 99 | self.embed = embedding 100 | if self.training: 101 | output = embedding 102 | else: 103 | embedding = embedding.detach().cpu() 104 | if mean != None and std != None: 105 | embedding = (embedding - mean) / std 106 | self.anomaly_map_generator.cpu() 107 | patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=self.num_neighbors) 108 | self.patch_scores = patch_scores 109 | anomaly_map, anomaly_score, anomaly_map_ = self.anomaly_map_generator( 110 | patch_scores=patch_scores, feature_map_shape=feature_map_shape 111 | ) 112 | output = (anomaly_map, anomaly_score, anomaly_map_) 113 | 114 | return output 115 | 116 | def generate_embedding(self, features: Dict[str, Tensor]) -> torch.Tensor: 117 | """Generate embedding from hierarchical feature map. 118 | 119 | Args: 120 | features: Hierarchical feature map from a CNN (ResNet18 or WideResnet) 121 | features: Dict[str:Tensor]: 122 | 123 | Returns: 124 | Embedding vector 125 | """ 126 | 127 | embeddings = features[self.layers[0]] 128 | for layer in self.layers[1:]: 129 | layer_embedding = features[layer] 130 | layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest") 131 | embeddings = torch.cat((embeddings, layer_embedding), 1) 132 | 133 | return embeddings 134 | 135 | @staticmethod 136 | def reshape_embedding(embedding: Tensor) -> Tensor: 137 | """Reshape Embedding. 138 | 139 | Reshapes Embedding to the following format: 140 | [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding] 141 | 142 | Args: 143 | embedding (Tensor): Embedding tensor extracted from CNN features. 144 | 145 | Returns: 146 | Tensor: Reshaped embedding tensor. 
147 | """ 148 | embedding_size = embedding.size(1) 149 | embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size) 150 | return embedding 151 | 152 | def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None: 153 | """Subsample embedding based on coreset sampling and store to memory. 154 | 155 | Args: 156 | embedding (np.ndarray): Embedding tensor from the CNN 157 | sampling_ratio (float): Coreset sampling ratio 158 | """ 159 | 160 | ## Coreset Subsampling 161 | if sampling_ratio < 1: 162 | sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) 163 | coreset = sampler.sample_coreset() 164 | else: 165 | coreset = embedding 166 | # sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) 167 | # coreset = sampler.sample_coreset() 168 | self.memory_bank = coreset 169 | 170 | def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor: 171 | """Nearest Neighbours using brute force method and euclidean norm. 172 | 173 | Args: 174 | embedding (Tensor): Features to compare the distance with the memory bank. 175 | n_neighbors (int): Number of neighbors to look at 176 | 177 | Returns: 178 | Tensor: Patch scores. 179 | """ 180 | # ### Cal variance 181 | # distances = torch.cdist(self.memory_bank, self.memory_bank, p=2.0) # [N, N] 182 | # _, k_nearest_idx = distances.topk(k=n_neighbors + 1, largest=False, dim=1) 183 | # k_nearest = self.memory_bank[k_nearest_idx[:, 1:]] # [N, k, C] 184 | # diff = (self.memory_bank.unsqueeze(dim=1) - k_nearest) ** 2 # [N, k, C] 185 | # var = diff.var(dim=1) # [N, C] 186 | # # var = diff.reshape(-1, diff.shape[-1]).var(dim=0).unsqueeze(dim=0).unsqueeze(dim=0) # [C] 187 | # # var = 1 188 | # 189 | # ### Cal distance with variance 190 | # distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm 191 | # _, k_nearest_idx = distances.topk(k=n_neighbors+1, largest=False, dim=1) 192 | # k_nearest = self.memory_bank[k_nearest_idx[:, 1:]] # [N, k, C] 193 | # var = torch.index_select(var, dim=0, index=k_nearest_idx[:, 1:].reshape(-1)).reshape(embedding.shape[0], n_neighbors, -1) + 1 194 | # patch_scores = ((((embedding.unsqueeze(dim=1) - k_nearest) ** 2)/var).sum(dim=-1))**0.5 195 | 196 | ## Original 197 | distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm 198 | patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1) 199 | 200 | 201 | return patch_scores 202 | 203 | def forward_neg(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 204 | """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing. 205 | 206 | Steps performed: 207 | 1. Get features from a CNN. 208 | 2. Generate embedding based on the features. 209 | 3. Compute anomaly map in test mode. 210 | 211 | Args: 212 | input_tensor (Tensor): Input tensor 213 | 214 | Returns: 215 | Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Embedding for training, 216 | anomaly map and anomaly score for testing. 
217 | """ 218 | if self.tiler: 219 | input_tensor = self.tiler.tile(input_tensor) 220 | 221 | # with torch.no_grad(): 222 | features = self.feature_extractor(input_tensor) 223 | 224 | """pre-trained correspondence layer""" 225 | if self.w: 226 | features = {layer: self.feature_pooler(self.w[layer](feature)) for layer, feature in features.items()} 227 | else: 228 | features = {layer: self.feature_pooler(feature) for layer, feature in features.items()} 229 | 230 | """original patchcore""" 231 | # features = {layer: self.feature_pooler(feature) for layer, feature in features.items()} 232 | 233 | embedding = self.generate_embedding(features) 234 | 235 | if self.tiler: 236 | embedding = self.tiler.untile(embedding) 237 | 238 | feature_map_shape = embedding.shape[-2:] 239 | embedding = self.reshape_embedding(embedding) 240 | self.embed_neg = embedding 241 | if self.training: 242 | output = embedding 243 | else: 244 | patch_scores = self.nearest_neighbors_neg(embedding=embedding, n_neighbors=self.num_neighbors) 245 | anomaly_map, anomaly_score, anomaly_map_ = self.anomaly_map_generator( 246 | patch_scores=patch_scores, feature_map_shape=feature_map_shape 247 | ) 248 | output = (anomaly_map, anomaly_score, anomaly_map_) 249 | 250 | return output 251 | 252 | def subsample_embedding_neg(self, embedding: torch.Tensor, sampling_ratio: float) -> None: 253 | """Subsample embedding based on coreset sampling and store to memory. 254 | 255 | Args: 256 | embedding (np.ndarray): Embedding tensor from the CNN 257 | sampling_ratio (float): Coreset sampling ratio 258 | """ 259 | 260 | ## Coreset Subsampling 261 | if sampling_ratio < 1: 262 | sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) 263 | coreset = sampler.sample_coreset() 264 | else: 265 | coreset = embedding 266 | # sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) 267 | # coreset = sampler.sample_coreset() 268 | self.memory_bank_neg = coreset 269 | 270 | def nearest_neighbors_neg(self, embedding: Tensor, n_neighbors: int = 1) -> Tensor: 271 | """Nearest Neighbours using brute force method and euclidean norm. 272 | 273 | Args: 274 | embedding (Tensor): Features to compare the distance with the memory bank. 275 | n_neighbors (int): Number of neighbors to look at 276 | 277 | Returns: 278 | Tensor: Patch scores. 
279 | """ 280 | # ### Cal variance 281 | # distances = torch.cdist(self.memory_bank, self.memory_bank, p=2.0) # [N, N] 282 | # _, k_nearest_idx = distances.topk(k=n_neighbors + 1, largest=False, dim=1) 283 | # k_nearest = self.memory_bank[k_nearest_idx[:, 1:]] # [N, k, C] 284 | # diff = (self.memory_bank.unsqueeze(dim=1) - k_nearest) ** 2 # [N, k, C] 285 | # var = diff.var(dim=1) # [N, C] 286 | # # var = diff.reshape(-1, diff.shape[-1]).var(dim=0).unsqueeze(dim=0).unsqueeze(dim=0) # [C] 287 | # # var = 1 288 | # 289 | # ### Cal distance with variance 290 | # distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm 291 | # _, k_nearest_idx = distances.topk(k=n_neighbors+1, largest=False, dim=1) 292 | # k_nearest = self.memory_bank[k_nearest_idx[:, 1:]] # [N, k, C] 293 | # var = torch.index_select(var, dim=0, index=k_nearest_idx[:, 1:].reshape(-1)).reshape(embedding.shape[0], n_neighbors, -1) + 1 294 | # patch_scores = ((((embedding.unsqueeze(dim=1) - k_nearest) ** 2)/var).sum(dim=-1))**0.5 295 | 296 | ## Original 297 | distances = torch.cdist(embedding, self.memory_bank_neg, p=2.0) # euclidean norm 298 | patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1) 299 | 300 | return patch_scores 301 | -------------------------------------------------------------------------------- /patchcore/components/freia/modules/all_in_one_block.py: -------------------------------------------------------------------------------- 1 | """All in One Block Module.""" 2 | 3 | # Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | 7 | # flake8: noqa 8 | # pylint: skip-file 9 | # type: ignore 10 | # pydocstyle: noqa 11 | 12 | import warnings 13 | from typing import Callable 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from scipy.stats import special_ortho_group 20 | 21 | from anomalib.models.components.freia.modules.base import InvertibleModule 22 | 23 | 24 | class AllInOneBlock(InvertibleModule): 25 | r"""Module combining the most common operations in a normalizing flow or similar model. 26 | 27 | It combines affine coupling, permutation, and global affine transformation 28 | ('ActNorm'). It can also be used as GIN coupling block, perform learned 29 | householder permutations, and use an inverted pre-permutation. The affine 30 | transformation includes a soft clamping mechanism, first used in Real-NVP. 31 | The block as a whole performs the following computation: 32 | .. math:: 33 | y = V\\,R \\; \\Psi(s_\\mathrm{global}) \\odot \\mathrm{Coupling}\\Big(R^{-1} V^{-1} x\\Big)+ t_\\mathrm{global} 34 | - The inverse pre-permutation of x (i.e. :math:`R^{-1} V^{-1}`) is optional (see 35 | ``reverse_permutation`` below). 36 | - The learned householder reflection matrix 37 | :math:`V` is also optional all together (see ``learned_householder_permutation`` 38 | below). 39 | - For the coupling, the input is split into :math:`x_1, x_2` along 40 | the channel dimension. Then the output of the coupling operation is the 41 | two halves :math:`u = \\mathrm{concat}(u_1, u_2)`. 42 | .. math:: 43 | u_1 &= x_1 \\odot \\exp \\Big( \\alpha \\; \\mathrm{tanh}\\big( s(x_2) \\big)\\Big) + t(x_2) \\\\ 44 | u_2 &= x_2 45 | Because :math:`\\mathrm{tanh}(s) \\in [-1, 1]`, this clamping mechanism prevents 46 | exploding values in the exponential. The hyperparameter :math:`\\alpha` can be adjusted. 
47 | """ 48 | 49 | def __init__( 50 | self, 51 | dims_in, 52 | dims_c=[], 53 | subnet_constructor: Callable = None, 54 | affine_clamping: float = 2.0, 55 | gin_block: bool = False, 56 | global_affine_init: float = 1.0, 57 | global_affine_type: str = "SOFTPLUS", 58 | permute_soft: bool = False, 59 | learned_householder_permutation: int = 0, 60 | reverse_permutation: bool = False, 61 | ): 62 | r"""Initialize. 63 | 64 | Args: 65 | dims_in (_type_): dims_in 66 | dims_c (list, optional): dims_c. Defaults to []. 67 | subnet_constructor (Callable, optional): class or callable ``f``, called as ``f(channels_in, channels_out)`` and 68 | should return a torch.nn.Module. Predicts coupling coefficients :math:`s, t`. Defaults to None. 69 | affine_clamping (float, optional): clamp the output of the multiplicative coefficients before 70 | exponentiation to +/- ``affine_clamping`` (see :math:`\\alpha` above). Defaults to 2.0. 71 | gin_block (bool, optional): Turn the block into a GIN block from Sorrenson et al, 2019. 72 | Makes it so that the coupling operations as a whole is volume preserving. Defaults to False. 73 | global_affine_init (float, optional): Initial value for the global affine scaling :math:`s_\mathrm{global}`.. Defaults to 1.0. 74 | global_affine_type (str, optional): ``'SIGMOID'``, ``'SOFTPLUS'``, or ``'EXP'``. Defines the activation to be used 75 | on the beta for the global affine scaling (:math:`\\Psi` above).. Defaults to "SOFTPLUS". 76 | permute_soft (bool, optional): bool, whether to sample the permutation matrix :math:`R` from :math:`SO(N)`, 77 | or to use hard permutations instead. Note, ``permute_soft=True`` is very slow 78 | when working with >512 dimensions. Defaults to False. 79 | learned_householder_permutation (int, optional): Int, if >0, turn on the matrix :math:`V` above, that represents 80 | multiple learned householder reflections. Slow if large number. 81 | Dubious whether it actually helps network performance. Defaults to 0. 82 | reverse_permutation (bool, optional): Reverse the permutation before the block, as introduced by Putzky 83 | et al, 2019. Turns on the :math:`R^{-1} V^{-1}` pre-multiplication above. Defaults to False. 84 | 85 | Raises: 86 | ValueError: _description_ 87 | ValueError: _description_ 88 | ValueError: _description_ 89 | """ 90 | 91 | super().__init__(dims_in, dims_c) 92 | 93 | channels = dims_in[0][0] 94 | # rank of the tensors means 1d, 2d, 3d tensor etc. 95 | self.input_rank = len(dims_in[0]) - 1 96 | # tuple containing all dims except for batch-dim (used at various points) 97 | self.sum_dims = tuple(range(1, 2 + self.input_rank)) 98 | 99 | if len(dims_c) == 0: 100 | self.conditional = False 101 | self.condition_channels = 0 102 | else: 103 | assert tuple(dims_c[0][1:]) == tuple( 104 | dims_in[0][1:] 105 | ), f"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}." 106 | self.conditional = True 107 | self.condition_channels = sum(dc[0] for dc in dims_c) 108 | 109 | split_len1 = channels - channels // 2 110 | split_len2 = channels // 2 111 | self.splits = [split_len1, split_len2] 112 | 113 | try: 114 | self.permute_function = {0: F.linear, 1: F.conv1d, 2: F.conv2d, 3: F.conv3d}[self.input_rank] 115 | except KeyError: 116 | raise ValueError(f"Data is {1 + self.input_rank}D. 
Must be 1D-4D.") 117 | 118 | self.in_channels = channels 119 | self.clamp = affine_clamping 120 | self.GIN = gin_block 121 | self.reverse_pre_permute = reverse_permutation 122 | self.householder = learned_householder_permutation 123 | 124 | if permute_soft and channels > 512: 125 | warnings.warn( 126 | ( 127 | "Soft permutation will take a very long time to initialize " 128 | f"with {channels} feature channels. Consider using hard permutation instead." 129 | ) 130 | ) 131 | 132 | # global_scale is used as the initial value for the global affine scale 133 | # (pre-activation). It is computed such that 134 | # global_scale_activation(global_scale) = global_affine_init 135 | # the 'magic numbers' (specifically for sigmoid) scale the activation to 136 | # a sensible range. 137 | if global_affine_type == "SIGMOID": 138 | global_scale = 2.0 - np.log(10.0 / global_affine_init - 1.0) 139 | self.global_scale_activation = lambda a: 10 * torch.sigmoid(a - 2.0) 140 | elif global_affine_type == "SOFTPLUS": 141 | global_scale = 2.0 * np.log(np.exp(0.5 * 10.0 * global_affine_init) - 1) 142 | self.softplus = nn.Softplus(beta=0.5) 143 | self.global_scale_activation = lambda a: 0.1 * self.softplus(a) 144 | elif global_affine_type == "EXP": 145 | global_scale = np.log(global_affine_init) 146 | self.global_scale_activation = lambda a: torch.exp(a) 147 | else: 148 | raise ValueError('Global affine activation must be "SIGMOID", "SOFTPLUS" or "EXP"') 149 | 150 | self.global_scale = nn.Parameter( 151 | torch.ones(1, self.in_channels, *([1] * self.input_rank)) * float(global_scale) 152 | ) 153 | self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.input_rank))) 154 | 155 | if permute_soft: 156 | w = special_ortho_group.rvs(channels) 157 | else: 158 | w = np.zeros((channels, channels)) 159 | for i, j in enumerate(np.random.permutation(channels)): 160 | w[i, j] = 1.0 161 | 162 | if self.householder: 163 | # instead of just the permutation matrix w, the learned housholder 164 | # permutation keeps track of reflection vectors vk, in addition to a 165 | # random initial permutation w_0. 166 | self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True) 167 | self.w_perm = None 168 | self.w_perm_inv = None 169 | self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False) 170 | else: 171 | self.w_perm = nn.Parameter( 172 | torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), requires_grad=False 173 | ) 174 | self.w_perm_inv = nn.Parameter( 175 | torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), requires_grad=False 176 | ) 177 | 178 | if subnet_constructor is None: 179 | raise ValueError("Please supply a callable subnet_constructor" "function or object (see docstring)") 180 | self.subnet = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1]) 181 | self.last_jac = None 182 | 183 | def _construct_householder_permutation(self): 184 | """Compute a permutation matrix. 185 | 186 | Compute a permutation matrix from the reflection vectors that are 187 | learned internally as nn.Parameters. 188 | """ 189 | w = self.w_0 190 | for vk in self.vk_householder: 191 | w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk)) 192 | 193 | for i in range(self.input_rank): 194 | w = w.unsqueeze(-1) 195 | return w 196 | 197 | def _permute(self, x, rev=False): 198 | """Perform permutation. 
199 | 200 | Performs the permutation and scaling after the coupling operation. 201 | Returns transformed outputs and the LogJacDet of the scaling operation. 202 | """ 203 | if self.GIN: 204 | scale = 1.0 205 | perm_log_jac = 0.0 206 | else: 207 | scale = self.global_scale_activation(self.global_scale) 208 | perm_log_jac = torch.sum(torch.log(scale)) 209 | 210 | if rev: 211 | return ((self.permute_function(x, self.w_perm_inv) - self.global_offset) / scale, perm_log_jac) 212 | else: 213 | return (self.permute_function(x * scale + self.global_offset, self.w_perm), perm_log_jac) 214 | 215 | def _pre_permute(self, x, rev=False): 216 | """Permute before the coupling block, only used if reverse_permutation is set.""" 217 | if rev: 218 | return self.permute_function(x, self.w_perm) 219 | else: 220 | return self.permute_function(x, self.w_perm_inv) 221 | 222 | def _affine(self, x, a, rev=False): 223 | """Perform affine coupling operation. 224 | 225 | Given the passive half, and the pre-activation outputs of the 226 | coupling subnetwork, perform the affine coupling operation. 227 | Returns both the transformed inputs and the LogJacDet. 228 | """ 229 | 230 | # the entire coupling coefficient tensor is scaled down by a 231 | # factor of ten for stability and easier initialization. 232 | a *= 0.1 233 | ch = x.shape[1] 234 | 235 | sub_jac = self.clamp * torch.tanh(a[:, :ch]) 236 | if self.GIN: 237 | sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) 238 | 239 | if not rev: 240 | return (x * torch.exp(sub_jac) + a[:, ch:], torch.sum(sub_jac, dim=self.sum_dims)) 241 | else: 242 | return ((x - a[:, ch:]) * torch.exp(-sub_jac), -torch.sum(sub_jac, dim=self.sum_dims)) 243 | 244 | def forward(self, x, c=[], rev=False, jac=True): 245 | """See base class docstring.""" 246 | if self.householder: 247 | self.w_perm = self._construct_householder_permutation() 248 | if rev or self.reverse_pre_permute: 249 | self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous() 250 | 251 | if rev: 252 | x, global_scaling_jac = self._permute(x[0], rev=True) 253 | x = (x,) 254 | elif self.reverse_pre_permute: 255 | x = (self._pre_permute(x[0], rev=False),) 256 | 257 | x1, x2 = torch.split(x[0], self.splits, dim=1) 258 | 259 | if self.conditional: 260 | x1c = torch.cat([x1, *c], 1) 261 | else: 262 | x1c = x1 263 | 264 | if not rev: 265 | a1 = self.subnet(x1c) 266 | x2, j2 = self._affine(x2, a1) 267 | else: 268 | a1 = self.subnet(x1c) 269 | x2, j2 = self._affine(x2, a1, rev=True) 270 | 271 | log_jac_det = j2 272 | x_out = torch.cat((x1, x2), 1) 273 | 274 | if not rev: 275 | x_out, global_scaling_jac = self._permute(x_out, rev=False) 276 | elif self.reverse_pre_permute: 277 | x_out = self._pre_permute(x_out, rev=True) 278 | 279 | # add the global scaling Jacobian to the total. 
280 | # trick to get the total number of non-channel dimensions: 281 | # number of elements of the first channel of the first batch member 282 | n_pixels = x_out[0, :1].numel() 283 | log_jac_det += (-1) ** rev * n_pixels * global_scaling_jac 284 | 285 | return (x_out,), log_jac_det 286 | 287 | def output_dims(self, input_dims): 288 | """Output Dims.""" 289 | return input_dims 290 | -------------------------------------------------------------------------------- /train_normal_unet.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace 2 | from PIL import Image 3 | from statistics import mean 4 | import os 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils import data 9 | import numpy as np 10 | import pickle 11 | import cv2 12 | import torch.optim as optim 13 | import scipy.misc 14 | import torchvision.models as models 15 | import torch.backends.cudnn as cudnn 16 | from torch.nn.functional import threshold, normalize 17 | import torch.nn.functional as F 18 | import matplotlib.pyplot as plt 19 | import os.path as osp 20 | # from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor 21 | from dataset_2d_sup import DataSet2D, my_collate, ValDataSet2D 22 | from torch.utils.data import DataLoader 23 | import random 24 | import timeit 25 | from tensorboardX import SummaryWriter 26 | from sklearn import metrics 27 | from math import ceil 28 | from apex import amp 29 | from apex.parallel import convert_syncbn_model 30 | import sys 31 | sys.path.append("..") 32 | 33 | 34 | start = timeit.default_timer() 35 | 36 | 37 | def get_arguments(): 38 | 39 | parser = argparse.ArgumentParser(description="Part segmentation training on MVTec LOCO") 40 | 41 | parser.add_argument("--data_dir", type=str, 42 | default="/media/NAS/nas_187/soopil/data/stanford/LOCO_AD_pre") 43 | parser.add_argument("--seg_dir", type=str, 44 | default="fss_comparison/4_onetype_5shot_3level") 45 | parser.add_argument("--obj_name", type=str, default='screw_bag') 46 | parser.add_argument("--save_dir", type=str, default='orig_512_seg/4_onetype_5shot_3level') 47 | parser.add_argument("--label", type=int, default=255) 48 | parser.add_argument("--ref_name", type=str, default='001') 49 | parser.add_argument("--snapshot_dir", type=str, default='./output/results') 50 | parser.add_argument("--level", type=int, default=3) 51 | parser.add_argument("--pretrained", type=str2bool, default=True) 52 | parser.add_argument("--n_shot", type=int, default=1) 53 | parser.add_argument("--n_zero", type=int, default=3) 54 | parser.add_argument("--input_size", type=str, default='256,256') 55 | parser.add_argument("--batch_size", type=int, default=2) 56 | parser.add_argument("--num_gpus", type=int, default=1) 57 | parser.add_argument('--local_rank', type=int, default=0) 58 | parser.add_argument("--FP16", type=str2bool, default=False) 59 | parser.add_argument("--num_epochs", type=int, default=50) 60 | parser.add_argument("--itrs_each_epoch", type=int, default=250) 61 | parser.add_argument("--learning_rate", type=float, default=1e-3) 62 | parser.add_argument("--num_classes", type=int, default=2) 63 | parser.add_argument("--num_workers", type=int, default=1) 64 | parser.add_argument("--weight_std", type=str2bool, default=True) 65 | parser.add_argument("--momentum", type=float, default=0.9) 66 | parser.add_argument("--power", type=float, default=0.9) 67 | parser.add_argument("--weight_decay", type=float, default=0.0005) 68 | parser.add_argument("--ignore_label", 
                        type=int, default=255)
69 |     parser.add_argument("--is_training", action="store_true")
70 |     parser.add_argument("--random_mirror", type=str2bool, default=True)
71 |     parser.add_argument("--random_scale", type=str2bool, default=True)
72 |     parser.add_argument("--random_seed", type=int, default=1234)
73 |     parser.add_argument("--gpu", type=str, default='None')
74 |     return parser
75 | 
76 | 
77 | class CNNSegmenter(nn.Module):
78 |     def __init__(self, num_cls=15+1, use_coord=False, level=3, pretrained=True):
79 |         # def __init__(self, num_cls=6+1):
80 |         super().__init__()
81 | 
82 |         # pretrained CNN feature extractor
83 |         self.level = level
84 |         self.use_coord = use_coord
85 | 
86 |         def hook_t(module, input, output):
87 |             self.features.append(output)
88 | 
89 |         self.enc = models.wide_resnet101_2(pretrained=pretrained)
90 |         # self.enc = models.resnet18(pretrained=True)
91 |         # in_ch = 64+128+256
92 |         # self.enc = models.resnet50(pretrained=True)
93 |         # in_ch = 64+128+256
94 | 
95 |         in_ch = 256+512+1024
96 |         if use_coord:
97 |             in_ch += 2
98 | 
99 |         self.conv1 = nn.Conv2d(1024+2, 512, 3, 1, 1)
100 |         self.conv2 = nn.Conv2d(512, 256, 3, 1, 1)
101 |         self.conv3 = nn.Conv2d(256, num_cls, 1)
102 |         self.relu = nn.ReLU()
103 |         # for param in self.enc.parameters():
104 |         #     param.requires_grad = False
105 | 
106 |         # self.enc.layer1[-1].register_forward_hook(hook_t)
107 |         # self.enc.layer2[-1].register_forward_hook(hook_t)
108 |         # self.enc.layer3[-1].register_forward_hook(hook_t)
109 |         # self.enc.layer4[-1].register_forward_hook(hook_t)
110 |         # print(self.enc)
111 | 
112 |     def forward(self, x, coord):
113 |         out = self.enc.conv1(x)
114 |         out = self.enc.bn1(out)
115 |         out = self.enc.relu(out)
116 |         out = self.enc.maxpool(out)
117 |         f0 = self.enc.layer1(out)
118 |         f1 = self.enc.layer2(f0)
119 |         out = self.enc.layer3(f1)
120 | 
121 |         # if self.use_coord:
122 |         coord = F.interpolate(coord, align_corners=True,
123 |                               mode="bilinear", size=out.shape[-2:])
124 |         out = torch.cat([out, coord], dim=1)
125 | 
126 |         out = F.interpolate(
127 |             out, align_corners=True, mode="bilinear", size=f1.shape[-2:])
128 |         # print(out.shape, f1.shape)
129 |         out = self.relu(self.conv1(out))+f1
130 | 
131 |         out = F.interpolate(
132 |             out, align_corners=True, mode="bilinear", size=f0.shape[-2:])
133 |         out = self.relu(self.conv2(out))+f0
134 | 
135 |         out = F.interpolate(
136 |             out, align_corners=True, mode="bilinear", size=x.shape[-2:])
137 |         out = self.conv3(out)
138 | 
139 |         return out
140 | 
141 | 
142 | def str2bool(v):
143 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
144 |         return True
145 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
146 |         return False
147 |     else:
148 |         raise argparse.ArgumentTypeError('Boolean value expected.')
149 | 
150 | 
151 | def lr_poly(base_lr, iter, max_iter, power):
152 |     return base_lr * ((1 - float(iter) / max_iter) ** (power))
153 | 
154 | 
155 | def adjust_learning_rate(optimizer, i_iter, lr, num_steps, power):
156 |     """Polynomially decays the learning rate from the initial LR over num_steps iterations."""
157 |     lr = lr_poly(lr, i_iter, num_steps, power)
158 |     optimizer.param_groups[0]['lr'] = lr
159 |     return lr
160 | 
161 | 
162 | def get_dice(pred, trg):  # multi-channel
163 |     # set_trace()
164 | 
165 |     eps = 0.0001
166 |     trg = trg.unsqueeze(1)
167 |     new_trg = torch.zeros_like(trg).repeat(1, pred.shape[1], 1, 1).long()
168 |     new_trg = new_trg.scatter(1, trg, 1)
169 | 
170 |     numer = (pred * new_trg).sum((2, 3))
171 |     denom = pred.sum((2, 3)) + new_trg.sum((2, 3)) + eps
172 |     dsc = (numer*2)/denom
173 |     dsc_score =
dsc.mean((0, 1)) 174 | return dsc_score 175 | 176 | 177 | def proportion_loss(pred, trg): 178 | # set_trace() 179 | eps = 0.0001 180 | trg = trg.unsqueeze(1) 181 | new_trg = torch.zeros_like(trg).repeat(1, pred.shape[1], 1, 1).long() 182 | new_trg = new_trg.scatter(1, trg, 1).float() 183 | diff = torch.abs(new_trg.mean((2, 3)) - pred.mean((2, 3))) 184 | loss = diff[:, 1:].sum() # exclude BG 185 | return loss 186 | 187 | 188 | def save_sample(pred, epoch): 189 | pred = torch.sigmoid(pred).round()*255 190 | pred = np.array(pred.detach().cpu(), dtype=np.uint8) 191 | # set_trace() 192 | for k in range(pred.shape[0]): 193 | slice = pred[k, 0] 194 | slice = np.expand_dims(slice, axis=2) 195 | slice = np.repeat(slice, 3, axis=2) 196 | im = Image.fromarray(slice) 197 | im.save(f"./sample/{epoch}_{k}.png") 198 | 199 | 200 | def main(): 201 | """Create the model and start the training.""" 202 | parser = get_arguments() 203 | print(parser) 204 | 205 | args = parser.parse_args() 206 | if args.num_gpus > 1: 207 | torch.cuda.set_device(args.local_rank) 208 | 209 | h, w = map(int, args.input_size.split(',')) 210 | input_size = (h, w) 211 | 212 | cudnn.benchmark = True 213 | seed = args.random_seed 214 | torch.manual_seed(seed) 215 | if torch.cuda.is_available(): 216 | torch.cuda.manual_seed(seed) 217 | 218 | print(args.obj_name) 219 | if args.obj_name == "screw_bag": 220 | rot_angle = (0, 360) 221 | num_cls = 7 222 | use_coord = True 223 | elif args.obj_name == "breakfast_box": 224 | rot_angle = (0, 360) 225 | num_cls = 7 226 | use_coord = True 227 | elif args.obj_name == "juice_bottle": 228 | rot_angle = (-20, 20) 229 | num_cls = 9 230 | use_coord = True 231 | elif args.obj_name == "pushpins": 232 | rot_angle = (-20, 20) 233 | num_cls = 26 # 16 # 20 234 | use_coord = True 235 | elif args.obj_name == "splicing_connectors": 236 | rot_angle = (-20, 20) 237 | num_cls = 10 # 7 238 | use_coord = True 239 | else: 240 | assert False 241 | 242 | # num_cls=10 # for unsupervised co-part segmentation 243 | model = CNNSegmenter( 244 | num_cls=num_cls, 245 | use_coord=use_coord, 246 | level=args.level, 247 | pretrained=args.pretrained) 248 | model = model.cuda() 249 | model.train() 250 | 251 | optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate) 252 | 253 | save_dir = f"{args.snapshot_dir}/{args.obj_name}" 254 | if not os.path.exists(save_dir): 255 | os.makedirs(save_dir) 256 | 257 | train_dataset = DataSet2D(root=args.data_dir, 258 | seg_dir=args.seg_dir, 259 | obj_name=args.obj_name, 260 | label=args.label, 261 | rot_angle=rot_angle, 262 | size=input_size, 263 | transform=None,) 264 | train_loader = DataLoader( 265 | train_dataset, 266 | batch_size=8, 267 | shuffle=True, 268 | num_workers=8, ) # , collate_fn=my_collate) 269 | 270 | loss_fn = torch.nn.CrossEntropyLoss() 271 | 272 | all_tr_loss = [] 273 | val_best_loss = 999999 274 | losses = [] 275 | 276 | for epoch in range(args.num_epochs): 277 | epoch_losses = [] 278 | epoch_dice = [] 279 | epoch_ce = [] 280 | epoch_H = [] 281 | epoch_prop = [] 282 | # adjust_learning_rate(optimizer, epoch, args.learning_rate, args.num_epochs, args.power) 283 | 284 | for iter, batch in enumerate(train_loader): 285 | supp = batch['image'].cuda() 286 | # un = batch['un_image'].cuda() 287 | coord_orig = batch['coord_orig'].cuda() 288 | coord_rot = batch['coord_rot'].cuda() 289 | labels = batch['label'].cuda() 290 | 291 | pred_supp = model(supp, coord_rot) 292 | ce_loss = loss_fn(pred_supp, labels) * 10 293 | prob = F.softmax(pred_supp, dim=1) 294 | dice_loss = 1 - 
get_dice(prob, labels) 295 | entropy_loss = (-1*prob*((prob+1e-5).log())).mean() * 10 296 | prop_loss = proportion_loss(prob, labels) * 1 297 | loss = ce_loss + dice_loss 298 | 299 | optimizer.zero_grad() 300 | loss.backward() 301 | optimizer.step() 302 | epoch_losses.append(loss.item()) 303 | epoch_dice.append(dice_loss.item()) 304 | epoch_ce.append(ce_loss.item()) 305 | epoch_H.append(entropy_loss.item()) 306 | epoch_prop.append(prop_loss.item()) 307 | # print(f"Iter:{iter}, Loss:{loss:.4f}, Dice:{1-dice_loss:.2f}", end="\r") 308 | 309 | losses.append(epoch_losses) 310 | print( 311 | f'EPOCH: {epoch}, CE: {mean(epoch_ce):.3f}, Dice: {1-mean(epoch_dice):.3f}, H: {mean(epoch_H):.3f}, Prop: {mean(epoch_prop):.3f}, {args.obj_name}') 312 | 313 | print('save model ...') 314 | torch.save(model.state_dict(), osp.join( 315 | save_dir, f'{args.obj_name}_{args.num_epochs}.pth')) 316 | 317 | end = timeit.default_timer() 318 | print(end - start, 'seconds') 319 | 320 | val_dataset = ValDataSet2D(root=args.data_dir, 321 | obj_name=args.obj_name, 322 | label=args.label, 323 | ref_name=args.ref_name, 324 | n_shot=args.n_shot, 325 | n_zero=args.n_zero, 326 | size=input_size, 327 | transform=None, 328 | save_dir=args.save_dir) 329 | 330 | valloader = DataLoader( 331 | val_dataset, batch_size=1, shuffle=False, num_workers=4) # , collate_fn=my_collate) 332 | 333 | palette = [0, 0, 0, 204, 241, 227, 112, 142, 18, 254, 8, 23, 207, 149, 84, 202, 24, 214, 334 | 230, 192, 37, 241, 80, 68, 74, 127, 0, 2, 81, 216, 24, 240, 129, 20, 215, 125, 161, 31, 204, 335 | 254, 52, 116, 117, 198, 203, 4, 41, 68, 127, 252, 61, 21, 3, 142, 40, 10, 159, 241, 61, 36, 336 | 14, 175, 77, 144, 61, 115, 131, 79, 97, 109, 177, 163, 58, 198, 140, 17, 235, 168, 47, 128, 91, 337 | 238, 103, 45, 124, 35, 228, 101, 48, 232, 74, 124, 114, 78, 49, 30, 35, 167, 27, 137, 231, 47, 338 | 235, 32, 39, 56, 112, 32, 62, 173, 79, 86, 44, 201, 77, 47, 217, 246, 223, 57, ] 339 | # Pad with zeroes to 768 values, i.e. 
256 RGB colours 340 | palette = palette + [0]*(768-len(palette)) 341 | model.eval() 342 | # os.makedirs(f"{args.data_dir}/orig_512_seg/{args.obj_name}", exist_ok=True) 343 | for index, batch in enumerate(valloader): 344 | image = batch['image'].cuda() 345 | opath = batch['opath'][0] 346 | os.makedirs(os.path.dirname(opath), exist_ok=True) 347 | coord = batch['coord_orig'].cuda() 348 | index_str = str(index).zfill(3) 349 | 350 | image = image.cuda() 351 | coord = coord.cuda() 352 | with torch.no_grad(): 353 | pred = model(image, coord) 354 | pred = torch.argmax(pred, dim=1)[0] 355 | pred = np.array(pred.cpu(), np.uint8) 356 | pi = Image.fromarray(pred, 'P') 357 | pi.putpalette(palette) 358 | pi.show() 359 | pi.save(opath) 360 | print(f"{index}/{len(valloader)}", end='\r') 361 | 362 | 363 | if __name__ == '__main__': 364 | main() 365 | -------------------------------------------------------------------------------- /finetune_cnn_coord.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace 2 | from PIL import Image 3 | from statistics import mean 4 | import os 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils import data 9 | import numpy as np 10 | import pickle 11 | import cv2 12 | import torch.optim as optim 13 | import scipy.misc 14 | import torchvision.models as models 15 | import torch.backends.cudnn as cudnn 16 | from torch.nn.functional import threshold, normalize 17 | import torch.nn.functional as F 18 | import matplotlib.pyplot as plt 19 | import os.path as osp 20 | # from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor 21 | from dataset_2d_semi import DataSet2D, my_collate, ValDataSet2D 22 | from torch.utils.data import DataLoader 23 | import random 24 | import timeit 25 | from tensorboardX import SummaryWriter 26 | from sklearn import metrics 27 | from math import ceil 28 | from apex import amp 29 | from apex.parallel import convert_syncbn_model 30 | import sys 31 | sys.path.append("..") 32 | 33 | 34 | start = timeit.default_timer() 35 | 36 | 37 | class Encoder(nn.Module): 38 | def __init__(self, num_cls=15+1, level=3): 39 | # def __init__(self, num_cls=6+1): 40 | super().__init__() 41 | 42 | # pretrained CNN feature extractor 43 | self.init_features() 44 | 45 | def hook_t(module, input, output): 46 | self.features.append(output) 47 | 48 | self.model = models.wide_resnet101_2(pretrained=True) 49 | self.level = level 50 | # in_ch = 256+512+2 51 | if level == 3: 52 | in_ch = 256+512+1024+2 53 | elif level == 2: 54 | in_ch = 256+512+2 55 | 56 | self.conv = nn.Sequential( 57 | nn.Conv2d(in_ch, in_ch//2, 1), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(in_ch//2, num_cls, 1), 60 | 61 | ) 62 | 63 | self.model.layer1[-1].register_forward_hook(hook_t) 64 | self.model.layer2[-1].register_forward_hook(hook_t) 65 | self.model.layer3[-1].register_forward_hook(hook_t) 66 | # self.model.layer4[-1].register_forward_hook(hook_t) 67 | 68 | def init_features(self): 69 | self.features = [] 70 | 71 | def extract_ft(self, x): 72 | self.init_features() 73 | _ = self.model(x) 74 | f0 = self.features[0] 75 | f1 = F.interpolate( 76 | self.features[1], align_corners=True, mode="bilinear", size=f0.shape[-2:]) 77 | f2 = F.interpolate( 78 | self.features[2], align_corners=True, mode="bilinear", size=f0.shape[-2:]) 79 | # f3 = F.interpolate(self.features[3], align_corners=True, mode="bilinear", size=f0.shape[-2:]) 80 | # print(f0.shape, f1.shape, f2.shape) 81 | # return torch.cat([f0, 
f1], dim=1)
82 |         return torch.cat([f0, f1, f2], dim=1)
83 |         # return self.features
84 | 
85 |     def forward(self, x, coord):
86 |         fts = self.extract_ft(x)
87 |         coord = F.interpolate(coord, align_corners=True,
88 |                               mode="bilinear", size=fts.shape[-2:])
89 |         fts = torch.cat([fts, coord], dim=1)
90 |         out = self.conv(fts)
91 |         out = F.interpolate(out, align_corners=True,
92 |                             mode="bilinear", size=x.shape[-2:])
93 |         return out
94 | 
95 | 
96 | def str2bool(v):
97 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
98 |         return True
99 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
100 |         return False
101 |     else:
102 |         raise argparse.ArgumentTypeError('Boolean value expected.')
103 | 
104 | 
105 | def get_arguments():
106 | 
107 |     parser = argparse.ArgumentParser(description="Fine-tune a coordinate-conditioned CNN segmenter on MVTec LOCO AD")
108 | 
109 |     parser.add_argument("--data_dir", type=str,
110 |                         default="/media/NAS/nas_187/soopil/data/stanford/LOCO_AD_pre")
111 |     parser.add_argument("--obj_name", type=str, default='screw_bag')
112 |     parser.add_argument("--label", type=int, default=255)
113 |     parser.add_argument("--ref_name", type=str, default='001')
114 |     parser.add_argument("--snapshot_dir", type=str, default='./output/results')
115 |     parser.add_argument("--n_shot", type=int, default=1)
116 |     parser.add_argument("--n_zero", type=int, default=3)
117 |     parser.add_argument("--input_size", type=str, default='256,256')
118 |     parser.add_argument("--batch_size", type=int, default=2)
119 |     parser.add_argument("--num_gpus", type=int, default=1)
120 |     parser.add_argument('--local_rank', type=int, default=0)
121 |     parser.add_argument("--FP16", type=str2bool, default=False)
122 |     parser.add_argument("--num_epochs", type=int, default=50)
123 |     parser.add_argument("--itrs_each_epoch", type=int, default=250)
124 |     parser.add_argument("--learning_rate", type=float, default=1e-3)
125 |     parser.add_argument("--num_classes", type=int, default=2)
126 |     parser.add_argument("--num_workers", type=int, default=1)
127 |     parser.add_argument("--weight_std", type=str2bool, default=True)
128 |     parser.add_argument("--momentum", type=float, default=0.9)
129 |     parser.add_argument("--power", type=float, default=0.9)
130 |     parser.add_argument("--weight_decay", type=float, default=0.0005)
131 |     parser.add_argument("--ignore_label", type=int, default=255)
132 |     parser.add_argument("--is_training", action="store_true")
133 |     parser.add_argument("--random_mirror", type=str2bool, default=True)
134 |     parser.add_argument("--random_scale", type=str2bool, default=True)
135 |     parser.add_argument("--random_seed", type=int, default=1234)
136 |     parser.add_argument("--gpu", type=str, default='None')
137 |     return parser
138 | 
139 | 
140 | def lr_poly(base_lr, iter, max_iter, power):
141 |     return base_lr * ((1 - float(iter) / max_iter) ** (power))
142 | 
143 | 
144 | def adjust_learning_rate(optimizer, i_iter, lr, num_steps, power):
145 |     """Polynomially decays the learning rate from the initial LR over num_steps iterations."""
146 |     lr = lr_poly(lr, i_iter, num_steps, power)
147 |     optimizer.param_groups[0]['lr'] = lr
148 |     return lr
149 | 
150 | 
151 | def get_dice(pred, trg):  # multi-channel
152 |     # set_trace()
153 | 
154 |     eps = 0.0001
155 |     trg = trg.unsqueeze(1)
156 |     new_trg = torch.zeros_like(trg).repeat(1, pred.shape[1], 1, 1).long()
157 |     new_trg = new_trg.scatter(1, trg, 1)
158 | 
159 |     numer = (pred * new_trg).sum((2, 3))
160 |     denom = pred.sum((2, 3)) + new_trg.sum((2, 3)) + eps
161 |     dsc = (numer*2)/denom
162 |     dsc_score = dsc.mean((0, 1))
163 |     return dsc_score
164 | 
165 | 
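# How the one-hot Dice above is computed, on hypothetical shapes:
#   pred is softmax output of shape (B, C, H, W); trg is an integer mask (B, H, W).
#   trg.unsqueeze(1) gives (B, 1, H, W); the zeros tensor repeated to (B, C, H, W)
#   and scatter(1, trg, 1) turn it into a one-hot target.
#   Dice per (batch, class) is 2*overlap / (pred mass + target mass), averaged over
#   both dimensions, so rare classes weigh the same as frequent ones.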
166 | def proportion_loss(pred, trg, obj_name): 167 | # set_trace() 168 | eps = 0.0001 169 | trg = trg.unsqueeze(1) 170 | new_trg = torch.zeros_like(trg).repeat(1, pred.shape[1], 1, 1).long() 171 | new_trg = new_trg.scatter(1, trg, 1).float() 172 | diff = torch.abs(new_trg.mean((2, 3)) - pred.mean((2, 3))) 173 | if obj_name == "breakfast_box": 174 | loss = diff[:, 1:4].sum() + diff[:, 6].sum() # exclude BG 175 | else: 176 | loss = diff[:, 1:].sum() # exclude BG 177 | return loss 178 | 179 | 180 | def save_sample(pred, epoch): 181 | pred = torch.sigmoid(pred).round()*255 182 | pred = np.array(pred.detach().cpu(), dtype=np.uint8) 183 | # set_trace() 184 | for k in range(pred.shape[0]): 185 | slice = pred[k, 0] 186 | slice = np.expand_dims(slice, axis=2) 187 | slice = np.repeat(slice, 3, axis=2) 188 | im = Image.fromarray(slice) 189 | im.save(f"./sample/{epoch}_{k}.png") 190 | 191 | 192 | def main(): 193 | """Create the model and start the training.""" 194 | parser = get_arguments() 195 | print(parser) 196 | 197 | args = parser.parse_args() 198 | if args.num_gpus > 1: 199 | torch.cuda.set_device(args.local_rank) 200 | 201 | h, w = map(int, args.input_size.split(',')) 202 | input_size = (h, w) 203 | 204 | cudnn.benchmark = True 205 | seed = args.random_seed 206 | torch.manual_seed(seed) 207 | if torch.cuda.is_available(): 208 | torch.cuda.manual_seed(seed) 209 | 210 | if args.obj_name == "screw_bag": 211 | num_cls = 7 212 | elif args.obj_name == "breakfast_box": 213 | num_cls = 7 214 | elif args.obj_name == "juice_bottle": 215 | num_cls = 9 216 | elif args.obj_name == "pushpins": 217 | num_cls = 26 # 16 # 20 218 | elif args.obj_name == "splicing_connectors": 219 | num_cls = 10 # 7 220 | else: 221 | assert False 222 | 223 | model = Encoder(num_cls=num_cls) 224 | model = model.cuda() 225 | model.train() 226 | 227 | optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate) 228 | 229 | save_dir = f"{args.snapshot_dir}/{args.obj_name}" 230 | if not os.path.exists(save_dir): 231 | os.makedirs(save_dir) 232 | 233 | train_dataset = DataSet2D(root=args.data_dir, 234 | obj_name=args.obj_name, 235 | label=args.label, 236 | ref_name=args.ref_name, 237 | n_shot=args.n_shot, 238 | n_zero=args.n_zero, 239 | size=input_size, 240 | transform=None) 241 | train_loader = DataLoader( 242 | train_dataset, 243 | batch_size=5, 244 | shuffle=True, 245 | num_workers=8, ) # , collate_fn=my_collate) 246 | 247 | loss_fn = torch.nn.CrossEntropyLoss() 248 | 249 | all_tr_loss = [] 250 | val_best_loss = 999999 251 | losses = [] 252 | 253 | for epoch in range(args.num_epochs): 254 | epoch_losses = [] 255 | epoch_dice = [] 256 | epoch_ce = [] 257 | epoch_H = [] 258 | epoch_prop = [] 259 | # adjust_learning_rate(optimizer, epoch, args.learning_rate, args.num_epochs, args.power) 260 | 261 | for iter, batch in enumerate(train_loader): 262 | supp = batch['image'].cuda() 263 | un = batch['un_image'].cuda() 264 | coord_orig = batch['coord_orig'].cuda() 265 | coord_rot = batch['coord_rot'].cuda() 266 | labels = batch['label'].cuda() 267 | 268 | pred_supp = model(supp, coord_rot) 269 | pred_un = model(un, coord_orig) 270 | # set_trace() 271 | ce_loss = loss_fn(pred_supp, labels) * 10 272 | prob = F.softmax(pred_supp, dim=1) 273 | prob_un = F.softmax(pred_un, dim=1) 274 | dice_loss = 1 - get_dice(prob, labels) 275 | entropy_loss = (-1*prob_un*((prob_un+1e-5).log())).mean() * 10 276 | prop_loss = proportion_loss(prob_un, labels, args.obj_name) * 1 277 | # prop_loss = prob_un.sum((2,3)) 278 | # print(entropy_loss.shape) 279 | 
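            # The two unsupervised terms above act on the unlabeled branch (prob_un):
            # entropy_loss pushes unlabeled predictions toward confident, low-entropy
            # class assignments, while prop_loss matches the per-class pixel
            # proportions of the unlabeled predictions to those of the labeled masks.
            # Both join the supervised CE + Dice objective under the epoch gate below.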
# assert False 280 | loss = ce_loss + dice_loss 281 | if epoch >= 50: 282 | loss += prop_loss 283 | loss += entropy_loss 284 | 285 | optimizer.zero_grad() 286 | loss.backward() 287 | optimizer.step() 288 | epoch_losses.append(loss.item()) 289 | epoch_dice.append(dice_loss.item()) 290 | epoch_ce.append(ce_loss.item()) 291 | epoch_H.append(entropy_loss.item()) 292 | epoch_prop.append(prop_loss.item()) 293 | # print(f"Iter:{iter}, Loss:{loss:.4f}, Dice:{1-dice_loss:.2f}", end="\r") 294 | 295 | # save_sample(upscaled_masks, epoch) 296 | losses.append(epoch_losses) 297 | # print() 298 | print( 299 | f'EPOCH: {epoch}, CE: {mean(epoch_ce):.3f}, Dice: {1-mean(epoch_dice):.3f}, H: {mean(epoch_H):.3f}, Prop: {mean(epoch_prop):.3f}, {args.obj_name}') 300 | 301 | print('save model ...') 302 | torch.save(model.state_dict(), osp.join( 303 | save_dir, f'{args.obj_name}_{args.num_epochs}.pth')) 304 | 305 | end = timeit.default_timer() 306 | print(end - start, 'seconds') 307 | 308 | val_dataset = ValDataSet2D(root=args.data_dir, 309 | obj_name=args.obj_name, 310 | label=args.label, 311 | ref_name=args.ref_name, 312 | n_shot=args.n_shot, 313 | n_zero=args.n_zero, 314 | size=input_size, 315 | transform=None) 316 | 317 | valloader = DataLoader( 318 | val_dataset, batch_size=1, shuffle=False, num_workers=4) # , collate_fn=my_collate) 319 | 320 | palette = [0, 0, 0, 204, 241, 227, 112, 142, 18, 254, 8, 23, 207, 149, 84, 202, 24, 214, 321 | 230, 192, 37, 241, 80, 68, 74, 127, 0, 2, 81, 216, 24, 240, 129, 20, 215, 125, 161, 31, 204, 322 | 254, 52, 116, 117, 198, 203, 4, 41, 68, 127, 252, 61, 21, 3, 142, 40, 10, 159, 241, 61, 36, 323 | 14, 175, 77, 144, 61, 115, 131, 79, 97, 109, 177, 163, 58, 198, 140, 17, 235, 168, 47, 128, 91, 324 | 238, 103, 45, 124, 35, 228, 101, 48, 232, 74, 124, 114, 78, 49, 30, 35, 167, 27, 137, 231, 47, 325 | 235, 32, 39, 56, 112, 32, 62, 173, 79, 86, 44, 201, 77, 47, 217, 246, 223, 57, ] 326 | # Pad with zeroes to 768 values, i.e. 
256 RGB colours 327 | palette = palette + [0]*(768-len(palette)) 328 | 329 | for index, batch in enumerate(valloader): 330 | image, label, coord = batch 331 | index_str = str(index).zfill(3) 332 | 333 | image = image.cuda() 334 | coord = coord.cuda() 335 | with torch.no_grad(): 336 | pred = model(image, coord) 337 | pred = torch.argmax(pred, dim=1)[0] 338 | pred = np.array(pred.cpu(), np.uint8) 339 | pi = Image.fromarray(pred, 'P') 340 | pi.putpalette(palette) 341 | pi.show() 342 | pi.save(f"{save_dir}/pred_{index_str}.png") 343 | print(f"{index}/{len(valloader)}", end='\r') 344 | 345 | 346 | if __name__ == '__main__': 347 | main() 348 | 349 | """ 350 | CUDA_VISIBLE_DEVICES=3 python finetune_cnn_coord.py --n_shot 5 --num_epochs 100 --obj_name juice_bottle --snapshot_dir ./output/onetype_5shot_2level_selective & 351 | CUDA_VISIBLE_DEVICES=4 python finetune_cnn_coord.py --n_shot 5 --num_epochs 100 --obj_name splicing_connectors --snapshot_dir ./output/onetype_5shot_2level_selective & 352 | CUDA_VISIBLE_DEVICES=5 python finetune_cnn_coord.py --n_shot 5 --num_epochs 100 --obj_name pushpins --snapshot_dir ./output/onetype_5shot_2level_selective & 353 | CUDA_VISIBLE_DEVICES=6 python finetune_cnn_coord.py --n_shot 5 --num_epochs 100 --obj_name screw_bag --snapshot_dir ./output/onetype_5shot_3level & 354 | CUDA_VISIBLE_DEVICES=7 python finetune_cnn_coord.py --n_shot 5 --num_epochs 100 --obj_name breakfast_box --snapshot_dir ./output/onetype_5shot_2level_selective & 355 | """ 356 | 357 | 358 | # def get_dice(a, b): 359 | # dsc_list = [] 360 | # for i in range(a.shape[0]): 361 | # numer = (a[i]*b[i]).sum() 362 | # denom = a[i].sum()+b[i].sum() 363 | # dsc = (numer*2)/denom 364 | # dsc_list.append(dsc) 365 | # dsc_all = torch.tensor(dsc_list) 366 | # return dsc_all.mean() 367 | -------------------------------------------------------------------------------- /patchcore/pre_processing/tiler.py: -------------------------------------------------------------------------------- 1 | """Image Tiler.""" 2 | 3 | # Copyright (C) 2022 Intel Corporation 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | from itertools import product 7 | from math import ceil 8 | from typing import Optional, Sequence, Tuple, Union 9 | 10 | import torch 11 | import torchvision.transforms as T 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | 15 | 16 | class StrideSizeError(Exception): 17 | """StrideSizeError to raise exception when stride size is greater than the tile size.""" 18 | 19 | 20 | def compute_new_image_size(image_size: Tuple, tile_size: Tuple, stride: Tuple) -> Tuple: 21 | """This function checks if image size is divisible by tile size and stride. 22 | 23 | If not divisible, it resizes the image size to make it divisible. 24 | 25 | Args: 26 | image_size (Tuple): Original image size 27 | tile_size (Tuple): Tile size 28 | stride (Tuple): Stride 29 | 30 | Examples: 31 | >>> compute_new_image_size(image_size=(512, 512), tile_size=(256, 256), stride=(128, 128)) 32 | (512, 512) 33 | 34 | >>> compute_new_image_size(image_size=(512, 512), tile_size=(222, 222), stride=(111, 111)) 35 | (555, 555) 36 | 37 | Returns: 38 | Tuple: Updated image size that is divisible by tile size and stride. 
39 | """ 40 | 41 | def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int: 42 | """This function makes the resizing within the edge level.""" 43 | if (edge_size - tile_size) % stride != 0: 44 | edge_size = (ceil((edge_size - tile_size) / stride) * stride) + tile_size 45 | 46 | return edge_size 47 | 48 | resized_h = __compute_new_edge_size(image_size[0], tile_size[0], stride[0]) 49 | resized_w = __compute_new_edge_size(image_size[1], tile_size[1], stride[1]) 50 | 51 | return resized_h, resized_w 52 | 53 | 54 | def upscale_image(image: Tensor, size: Tuple, mode: str = "padding") -> Tensor: 55 | """Upscale image to the desired size via either padding or interpolation. 56 | 57 | Args: 58 | image (Tensor): Image 59 | size (Tuple): Tuple to which image is upscaled. 60 | mode (str, optional): Upscaling mode. Defaults to "padding". 61 | 62 | Examples: 63 | >>> image = torch.rand(1, 3, 512, 512) 64 | >>> image = upscale_image(image, size=(555, 555), mode="padding") 65 | >>> image.shape 66 | torch.Size([1, 3, 555, 555]) 67 | 68 | >>> image = torch.rand(1, 3, 512, 512) 69 | >>> image = upscale_image(image, size=(555, 555), mode="interpolation") 70 | >>> image.shape 71 | torch.Size([1, 3, 555, 555]) 72 | 73 | Returns: 74 | Tensor: Upscaled image. 75 | """ 76 | 77 | image_h, image_w = image.shape[2:] 78 | resize_h, resize_w = size 79 | 80 | if mode == "padding": 81 | pad_h = resize_h - image_h 82 | pad_w = resize_w - image_w 83 | 84 | image = F.pad(image, [0, pad_w, 0, pad_h]) 85 | elif mode == "interpolation": 86 | image = F.interpolate(input=image, size=(resize_h, resize_w)) 87 | else: 88 | raise ValueError(f"Unknown mode {mode}. Only padding and interpolation is available.") 89 | 90 | return image 91 | 92 | 93 | def downscale_image(image: Tensor, size: Tuple, mode: str = "padding") -> Tensor: 94 | """Opposite of upscaling. This image downscales image to a desired size. 95 | 96 | Args: 97 | image (Tensor): Input image 98 | size (Tuple): Size to which image is down scaled. 99 | mode (str, optional): Downscaling mode. Defaults to "padding". 100 | 101 | Examples: 102 | >>> x = torch.rand(1, 3, 512, 512) 103 | >>> y = upscale_image(image, upscale_size=(555, 555), mode="padding") 104 | >>> y = downscale_image(y, size=(512, 512), mode='padding') 105 | >>> torch.allclose(x, y) 106 | True 107 | 108 | Returns: 109 | Tensor: Downscaled image 110 | """ 111 | input_h, input_w = size 112 | if mode == "padding": 113 | image = image[:, :, :input_h, :input_w] 114 | else: 115 | image = F.interpolate(input=image, size=(input_h, input_w)) 116 | 117 | return image 118 | 119 | 120 | class Tiler: 121 | """Tile Image into (non)overlapping Patches. Images are tiled in order to efficiently process large images. 122 | 123 | Args: 124 | tile_size: Tile dimension for each patch 125 | stride: Stride length between patches 126 | remove_border_count: Number of border pixels to be removed from tile before untiling 127 | mode: Upscaling mode for image resize.Supported formats: padding, interpolation 128 | 129 | Examples: 130 | >>> import torch 131 | >>> from torchvision import transforms 132 | >>> from skimage.data import camera 133 | >>> tiler = Tiler(tile_size=256,stride=128) 134 | >>> image = transforms.ToTensor()(camera()) 135 | >>> tiles = tiler.tile(image) 136 | >>> image.shape, tiles.shape 137 | (torch.Size([3, 512, 512]), torch.Size([9, 3, 256, 256])) 138 | 139 | >>> # Perform your operations on the tiles. 
140 | 
141 |         >>> # Untile the patches to reconstruct the image
142 |         >>> reconstructed_image = tiler.untile(tiles)
143 |         >>> reconstructed_image.shape
144 |         torch.Size([1, 3, 512, 512])
145 |     """
146 | 
147 |     def __init__(
148 |         self,
149 |         tile_size: Union[int, Sequence],
150 |         stride: Optional[Union[int, Sequence]] = None,
151 |         remove_border_count: int = 0,
152 |         mode: str = "padding",
153 |         tile_count: int = 4,
154 |     ) -> None:
155 | 
156 |         self.tile_size_h, self.tile_size_w = self.__validate_size_type(tile_size)
157 |         self.tile_count = tile_count
158 | 
159 |         if stride is None:
160 |             stride = tile_size  # fall back to non-overlapping tiles so the strides are always defined
161 |         self.stride_h, self.stride_w = self.__validate_size_type(stride)
162 |         self.remove_border_count = int(remove_border_count)
163 |         self.overlapping = not (self.stride_h == self.tile_size_h and self.stride_w == self.tile_size_w)
164 |         self.mode = mode
165 | 
166 |         if self.stride_h > self.tile_size_h or self.stride_w > self.tile_size_w:
167 |             raise StrideSizeError(
168 |                 "Larger stride size than kernel size produces unreliable tiling results. "
169 |                 "Please ensure stride size is less than or equal to the tile size."
170 |             )
171 | 
172 |         if self.mode not in ["padding", "interpolation"]:
173 |             raise ValueError(f"Unknown tiling mode {self.mode}. Available modes are padding and interpolation")
174 | 
175 |         self.batch_size: int
176 |         self.num_channels: int
177 | 
178 |         self.input_h: int
179 |         self.input_w: int
180 | 
181 |         self.pad_h: int
182 |         self.pad_w: int
183 | 
184 |         self.resized_h: int
185 |         self.resized_w: int
186 | 
187 |         self.num_patches_h: int
188 |         self.num_patches_w: int
189 | 
190 |     @staticmethod
191 |     def __validate_size_type(parameter: Union[int, Sequence]) -> Tuple[int, ...]:
192 |         if isinstance(parameter, int):
193 |             output = (parameter, parameter)
194 |         elif isinstance(parameter, Sequence):
195 |             output = (parameter[0], parameter[1])
196 |         else:
197 |             raise ValueError(f"Unknown type {type(parameter)} for tile or stride size. Could be int or Sequence type.")
198 | 
199 |         if len(output) != 2:
200 |             raise ValueError(f"Length of the size type must be 2 for height and width. Got {len(output)} instead.")
201 | 
202 |         return output
203 | 
204 |     def __random_tile(self, image: Tensor) -> Tensor:
205 |         """Randomly crop tiles from the given image.
206 | 
207 |         Args:
208 |             image: input image to be cropped
209 | 
210 |         Returns: Randomly cropped tiles from the image
211 |         """
212 |         return torch.vstack([T.RandomCrop((self.tile_size_h, self.tile_size_w))(image) for _ in range(self.tile_count)])
213 | 
214 |     def __unfold(self, tensor: Tensor) -> Tensor:
215 |         """Unfolds tensor into tiles.
216 | 
217 |         This is the core function to perform tiling operation.
218 | 
219 |         Args:
220 |             tensor: Input tensor from which tiles are generated.
221 | 222 | Returns: Generated tiles 223 | """ 224 | 225 | # identify device type based on input tensor 226 | device = tensor.device 227 | 228 | # extract and calculate parameters 229 | batch, channels, image_h, image_w = tensor.shape 230 | 231 | self.num_patches_h = int((image_h - self.tile_size_h) / self.stride_h) + 1 232 | self.num_patches_w = int((image_w - self.tile_size_w) / self.stride_w) + 1 233 | 234 | # create an empty torch tensor for output 235 | tiles = torch.zeros( 236 | (self.num_patches_h, self.num_patches_w, batch, channels, self.tile_size_h, self.tile_size_w), device=device 237 | ) 238 | 239 | # fill-in output tensor with spatial patches extracted from the image 240 | for (tile_i, tile_j), (loc_i, loc_j) in zip( 241 | product(range(self.num_patches_h), range(self.num_patches_w)), 242 | product( 243 | range(0, image_h - self.tile_size_h + 1, self.stride_h), 244 | range(0, image_w - self.tile_size_w + 1, self.stride_w), 245 | ), 246 | ): 247 | tiles[tile_i, tile_j, :] = tensor[ 248 | :, :, loc_i : (loc_i + self.tile_size_h), loc_j : (loc_j + self.tile_size_w) 249 | ] 250 | 251 | # rearrange the tiles in order [tile_count * batch, channels, tile_height, tile_width] 252 | tiles = tiles.permute(2, 0, 1, 3, 4, 5) 253 | tiles = tiles.contiguous().view(-1, channels, self.tile_size_h, self.tile_size_w) 254 | 255 | return tiles 256 | 257 | def __fold(self, tiles: Tensor) -> Tensor: 258 | """Fold the tiles back into the original tensor. 259 | 260 | This is the core method to reconstruct the original image from its tiled version. 261 | 262 | Args: 263 | tiles: Tiles from the input image, generated via __unfold method. 264 | 265 | Returns: 266 | Output that is the reconstructed version of the input tensor. 267 | """ 268 | # number of channels differs between image and anomaly map, so infer from input tiles. 
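        # scale_h / scale_w below capture how the incoming tile size relates to the
        # tile size used in tile(); target locations, strides and the output size are
        # all rescaled by the same factor, so maps at a different resolution than the
        # original image (e.g. downsampled anomaly maps) fold back correctly.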
269 |         _, num_channels, tile_size_h, tile_size_w = tiles.shape
270 |         scale_h, scale_w = (tile_size_h / self.tile_size_h), (tile_size_w / self.tile_size_w)
271 |         # identify device type based on input tensor
272 |         device = tiles.device
273 |         # calculate tile size after borders removed
274 |         reduced_tile_h = tile_size_h - (2 * self.remove_border_count)
275 |         reduced_tile_w = tile_size_w - (2 * self.remove_border_count)
276 |         # reconstructed image dimension
277 |         image_size = (self.batch_size, num_channels, int(self.resized_h * scale_h), int(self.resized_w * scale_w))
278 | 
279 |         # rearrange input tiles in format [tile_count, batch, channel, tile_h, tile_w]
280 |         tiles = tiles.contiguous().view(
281 |             self.batch_size,
282 |             self.num_patches_h,
283 |             self.num_patches_w,
284 |             num_channels,
285 |             tile_size_h,
286 |             tile_size_w,
287 |         )
288 |         tiles = tiles.permute(0, 3, 1, 2, 4, 5)
289 |         tiles = tiles.contiguous().view(self.batch_size, num_channels, -1, tile_size_h, tile_size_w)
290 |         tiles = tiles.permute(2, 0, 1, 3, 4)
291 | 
292 |         # remove tile borders by defined count
293 |         tiles = tiles[
294 |             :,
295 |             :,
296 |             :,
297 |             self.remove_border_count : reduced_tile_h + self.remove_border_count,
298 |             self.remove_border_count : reduced_tile_w + self.remove_border_count,
299 |         ]
300 | 
301 |         # create tensors to store intermediate results and outputs
302 |         img = torch.zeros(image_size, device=device)
303 |         lookup = torch.zeros(image_size, device=device)
304 |         ones = torch.ones(reduced_tile_h, reduced_tile_w, device=device)
305 | 
306 |         # reconstruct image by adding patches to their respective location and
307 |         # create a lookup for patch count in every location
308 |         for patch, (loc_i, loc_j) in zip(
309 |             tiles,
310 |             product(
311 |                 range(
312 |                     self.remove_border_count,
313 |                     int(self.resized_h * scale_h) - reduced_tile_h + 1,
314 |                     int(self.stride_h * scale_h),
315 |                 ),
316 |                 range(
317 |                     self.remove_border_count,
318 |                     int(self.resized_w * scale_w) - reduced_tile_w + 1,
319 |                     int(self.stride_w * scale_w),
320 |                 ),
321 |             ),
322 |         ):
323 |             img[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += patch
324 |             lookup[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += ones
325 | 
326 |         # divide the reconstructed image by the lookup to average out the values
327 |         img = torch.divide(img, lookup)
328 |         # alternative way of removing nan values (isnan not supported by openvino)
329 |         img[img != img] = 0  # pylint: disable=comparison-with-itself
330 | 
331 |         return img
332 | 
333 |     def tile(self, image: Tensor, use_random_tiling: bool = False) -> Tensor:
334 |         """Tiles an input image to either overlapping, non-overlapping or random patches.
335 | 
336 |         Args:
337 |             image: Input image to tile.
338 | 
339 |         Examples:
340 |             >>> from anomalib.data.tiler import Tiler
341 |             >>> tiler = Tiler(tile_size=512, stride=256)
342 |             >>> image = torch.rand(size=(2, 3, 1024, 1024))
343 |             >>> image.shape
344 |             torch.Size([2, 3, 1024, 1024])
345 |             >>> tiles = tiler.tile(image)
346 |             >>> tiles.shape
347 |             torch.Size([18, 3, 512, 512])
348 | 
349 |         Returns:
350 |             Tiles generated from the image.
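            (With use_random_tiling=True, tile_count random crops of each batch image are stacked and returned instead of the regular grid.)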
351 |         """
352 |         if image.dim() == 3:
353 |             image = image.unsqueeze(0)
354 | 
355 |         self.batch_size, self.num_channels, self.input_h, self.input_w = image.shape
356 | 
357 |         if self.input_h < self.tile_size_h or self.input_w < self.tile_size_w:
358 |             raise ValueError(
359 |                 f"One of the edges of the tile size {self.tile_size_h, self.tile_size_w} "
360 |                 f"is larger than that of the image {self.input_h, self.input_w}."
361 |             )
362 | 
363 |         self.resized_h, self.resized_w = compute_new_image_size(
364 |             image_size=(self.input_h, self.input_w),
365 |             tile_size=(self.tile_size_h, self.tile_size_w),
366 |             stride=(self.stride_h, self.stride_w),
367 |         )
368 | 
369 |         image = upscale_image(image, size=(self.resized_h, self.resized_w), mode=self.mode)
370 | 
371 |         if use_random_tiling:
372 |             image_tiles = self.__random_tile(image)
373 |         else:
374 |             image_tiles = self.__unfold(image)
375 |         return image_tiles
376 | 
377 |     def untile(self, tiles: Tensor) -> Tensor:
378 |         """Untiles patches to reconstruct the original input image.
379 | 
380 |         If the patches overlap, the function averages the overlapping pixels
381 |         and returns the reconstructed image.
382 | 
383 |         Args:
384 |             tiles: Tiles from the input image, generated via tile().
385 | 
386 |         Examples:
387 |             >>> from anomalib.datasets.tiler import Tiler
388 |             >>> tiler = Tiler(tile_size=512, stride=256)
389 |             >>> image = torch.rand(size=(2, 3, 1024, 1024))
390 |             >>> image.shape
391 |             torch.Size([2, 3, 1024, 1024])
392 |             >>> tiles = tiler.tile(image)
393 |             >>> tiles.shape
394 |             torch.Size([18, 3, 512, 512])
395 |             >>> reconstructed_image = tiler.untile(tiles)
396 |             >>> reconstructed_image.shape
397 |             torch.Size([2, 3, 1024, 1024])
398 |             >>> torch.equal(image, reconstructed_image)
399 |             True
400 | 
401 |         Returns:
402 |             Output that is the reconstructed version of the input tensor.
403 |         """
404 |         image = self.__fold(tiles)
405 |         image = downscale_image(image=image, size=(self.input_h, self.input_w), mode=self.mode)
406 | 
407 |         return image
408 | 
--------------------------------------------------------------------------------
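A minimal round-trip sketch for the Tiler above, assuming the import path patchcore.pre_processing.tiler from this repo's layout; the multiplication stands in for any per-tile model inference:

    import torch
    from patchcore.pre_processing.tiler import Tiler  # import path assumed from the tree above

    tiler = Tiler(tile_size=256, stride=128, mode="padding")
    image = torch.rand(2, 3, 512, 512)   # (batch, channels, H, W)
    tiles = tiler.tile(image)            # -> torch.Size([18, 3, 256, 256]): a 3x3 grid per image
    tiles = tiles * 1.0                  # stand-in for per-tile model inference
    restored = tiler.untile(tiles)       # overlapping pixels are averaged back together
    assert restored.shape == image.shape

Note that untile() relies on state stored by the preceding tile() call (batch size, resized dimensions, patch counts), so the two must be invoked on the same Tiler instance.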