├── metric ├── __init__.py ├── mvtec3d │ └── __init__.py └── mvtec_loco │ └── src │ └── __init__.py ├── models ├── __init__.py ├── igd │ ├── __init__.py │ └── net_igd.py ├── graphcore │ ├── __init__.py │ ├── gcn_lib │ │ ├── __init__.py │ │ ├── pos_embed.py │ │ └── torch_nn.py │ └── net_graphcore.py ├── patchcore │ └── __init__.py ├── pointcore │ ├── __init__.py │ ├── descriptor │ │ ├── __init__.py │ │ ├── point_mlp │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── activation.py │ │ │ │ └── block.py │ │ │ ├── __init__.py │ │ │ └── point_mlp.py │ │ ├── dgcnn │ │ │ └── dgcnn.py │ │ └── neural_point │ │ │ └── neural_point.py │ └── pointcore.py ├── resnet │ ├── __init__.py │ └── resnet.py ├── _patchcore │ ├── __init__.py │ ├── sampling_base.py │ └── kcenter_greedy.py ├── cfa │ ├── __init__.py │ ├── net_cfa.py │ ├── coordconv.py │ └── cfa.py ├── net_csflow │ ├── __init__.py │ └── net_csflow.py ├── _example │ ├── __init__.py │ └── net_example.py ├── cutpaste │ ├── __init__.py │ ├── model.py │ └── density.py ├── reverse │ ├── __init__.py │ ├── net_reverse.py │ └── blocks.py ├── simplenet │ └── __init__.py ├── fastflow │ └── func.py ├── devnet │ └── devnet_resnet18.py ├── dra │ └── dra_resnet18.py ├── softpatch │ ├── multi_variate_gaussian.py │ └── sampler.py └── favae │ ├── func.py │ └── net_favae.py ├── tools ├── __init__.py ├── visualize.py ├── utils.py └── record_helper.py ├── data_io ├── __init__.py ├── fewshot.py ├── semi.py └── noisy.py ├── augmentation ├── __init__.py ├── type.py └── perlin.py ├── configuration ├── __init__.py ├── 2_train_base │ ├── federated_learning.yaml │ └── centralized_learning.yaml ├── 1_model_base │ ├── transferad.yaml │ ├── padim.yaml │ ├── _example.yaml │ ├── cfa.yaml │ ├── spade.yaml │ ├── _patchcore.yaml │ ├── fastflow.yaml │ ├── igd.yaml │ ├── favae.yaml │ ├── cutpaste.yaml │ ├── reverse.yaml │ ├── stpm.yaml │ ├── devnet.yaml │ ├── patchcore.yaml │ ├── draem.yaml │ ├── simplenet.yaml │ ├── dra.yaml │ ├── softpatch.yaml │ ├── graphcore.yaml │ ├── dne.yaml │ └── csflow.yaml ├── 3_dataset_base │ ├── mtd.yaml │ ├── coad.yaml │ ├── dagm.yaml │ ├── mpdd.yaml │ ├── _example.yaml │ ├── miadloco.yaml │ ├── mvtec2d.yaml │ ├── mvtec3d.yaml │ ├── visa.yaml │ ├── btad.yaml │ ├── mvtec2df3d.yaml │ └── mvtecloco.yaml ├── device.py ├── registration.py └── config.py ├── loss_function ├── __init__.py ├── reverse.py ├── deviation.py ├── binaryfocal.py ├── ssim.py └── focal.py ├── checkpoints ├── vit │ └── download.txt └── graphcore │ └── pretrain │ └── download.txt ├── .gitignore ├── paradigms ├── federated │ └── f2d.py └── centralized │ └── c3d.py ├── requirements.txt ├── arch ├── _example.py ├── fastflow.py ├── base.py ├── csflow.py ├── simplenet.py ├── cfa.py ├── patchcore.py ├── cutpaste.py ├── reverse.py ├── draem.py ├── softpatch.py ├── devnet.py ├── favae.py └── stpm.py ├── main.py └── dataset ├── mtd.py ├── _example.py ├── btad.py ├── mpdd.py ├── dagm.py ├── mvtec2df3d.py └── visa.py /metric/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_io/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/igd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configuration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/graphcore/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/patchcore/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/pointcore/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /loss_function/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/_patchcore/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/cfa/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/net_csflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metric/mvtec3d/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/_example/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /configuration/2_train_base/federated_learning.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/point_mlp/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/dgcnn/dgcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn -------------------------------------------------------------------------------- /models/cutpaste/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | -------------------------------------------------------------------------------- /models/reverse/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | -------------------------------------------------------------------------------- /models/simplenet/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/point_mlp/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | -------------------------------------------------------------------------------- /checkpoints/vit/download.txt: -------------------------------------------------------------------------------- 1 | https://console.cloud.google.com/storage/browser/_details/vit_models/sam/ViT-B_16.npz -------------------------------------------------------------------------------- /configuration/1_model_base/transferad.yaml: -------------------------------------------------------------------------------- 1 | model: tad 2 | net: resnet18 3 | num_epochs: 100 4 | train_batch_size: 8 -------------------------------------------------------------------------------- /models/graphcore/gcn_lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_nn import * 2 | from .torch_edge import * 3 | from .torch_vertex import * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__/ 2 | */*/__pycache__/ 3 | */*/*/__pycache__/ 4 | work_dir/ 5 | checkpoints/ 6 | .vscode/ 7 | .idea/ 8 | .history/ -------------------------------------------------------------------------------- /metric/mvtec_loco/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .image import * 3 | from .metrics import * 4 | from .util import * 5 | -------------------------------------------------------------------------------- /checkpoints/graphcore/pretrain/download.txt: -------------------------------------------------------------------------------- 1 | https://github.com/huawei-noah/Efficient-AI-Backbones/releases/tag/vig 2 | https://github.com/huawei-noah/Efficient-AI-Backbones/releases/tag/pyramid-vig -------------------------------------------------------------------------------- /models/_example/net_example.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class NetExample(nn.Module): 4 | def __init__(self, args): 5 | super(NetExample, self).__init__() 6 | self.args = args -------------------------------------------------------------------------------- /configuration/3_dataset_base/mtd.yaml: -------------------------------------------------------------------------------- 1 | dataset: mtd 2 | root_path: '/disk4/xgy' 3 | data_path: '/mtd' 4 | num_task: 1 5 | data_size: 256 6 | 
data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/coad.yaml: -------------------------------------------------------------------------------- 1 | dataset: coad 2 | root_path: '/disk4/xgy' 3 | data_path: '/coad' 4 | num_task: 21 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/dagm.yaml: -------------------------------------------------------------------------------- 1 | dataset: dagm 2 | root_path: '/disk4/xgy' 3 | data_path: '/dagm' 4 | num_task: 10 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/mpdd.yaml: -------------------------------------------------------------------------------- 1 | dataset: mpdd 2 | root_path: '/disk4/xgy' 3 | data_path: '/mpdd' 4 | num_task: 6 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/_example.yaml: -------------------------------------------------------------------------------- 1 | dataset: _example 2 | root_path: '/disk4/xgy' 3 | data_path: '/_example' 4 | num_task: 2 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/miadloco.yaml: -------------------------------------------------------------------------------- 1 | dataset: miadloco 2 | root_path: '/disk4/xgy' 3 | data_path: '/miadloco' 4 | num_task: 3 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/mvtec2d.yaml: -------------------------------------------------------------------------------- 1 | dataset: mvtec2d 2 | root_path: '/disk4/xgy' 3 | data_path: '/mvtec2d' 4 | num_task: 15 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/mvtec3d.yaml: -------------------------------------------------------------------------------- 1 | dataset: mvtec3d 2 | root_path: '/disk4/xgy' 3 | data_path: '/mvtec3d' 4 | num_task: 10 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 -------------------------------------------------------------------------------- /configuration/3_dataset_base/visa.yaml: -------------------------------------------------------------------------------- 1 | dataset: visa 2 | root_path: '/disk4/xgy' 3 | data_path: '/visa' 4 | num_task: 12 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 9 | seed: 66 -------------------------------------------------------------------------------- /configuration/3_dataset_base/btad.yaml: -------------------------------------------------------------------------------- 1 | dataset: btad 2 | root_path: '/disk4/xgy' 3 | data_path: '/btad' 4 | num_task: 3 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 9 | 10 | 
-------------------------------------------------------------------------------- /configuration/3_dataset_base/mvtec2df3d.yaml: -------------------------------------------------------------------------------- 1 | dataset: mvtec2df3d 2 | root_path: '/disk4/xgy' 3 | data_path: '/mvtec2df3d' 4 | num_task: 10 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 9 | -------------------------------------------------------------------------------- /configuration/1_model_base/padim.yaml: -------------------------------------------------------------------------------- 1 | model: padim 2 | net: resnet18 3 | num_epochs: 1 4 | train_batch_size: 32 5 | valid_batch_size: 1 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 256 9 | data_crop_size: 256 10 | mask_size: 256 11 | mask_crop_size: 256 -------------------------------------------------------------------------------- /models/cfa/net_cfa.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from models.cfa.resnet import resnet18 3 | 4 | class NetCFA(nn.Module): 5 | def __init__(self, args): 6 | super(NetCFA, self).__init__() 7 | self.args = args 8 | 9 | self.resnet18 = resnet18(pretrained=True, progress=True) -------------------------------------------------------------------------------- /models/pointcore/pointcore.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['NetPointCore'] 5 | 6 | class NetPointCore(nn.Module): 7 | def __init__(self): 8 | super(NetPointCore, self).__init__() 9 | pass 10 | 11 | def forward(self, x): 12 | pass -------------------------------------------------------------------------------- /configuration/1_model_base/_example.yaml: -------------------------------------------------------------------------------- 1 | model: _example 2 | net: net_example 3 | data_size: 256 4 | data_crop_size: 256 5 | mask_size: 256 6 | mask_crop_size: 256 7 | num_epochs: 50 8 | train_batch_size: 4 9 | valid_batch_size: 1 10 | train_aug_type: normal 11 | valid_aug_type: normal 12 | -------------------------------------------------------------------------------- /configuration/3_dataset_base/mvtecloco.yaml: -------------------------------------------------------------------------------- 1 | dataset: mvtecloco 2 | root_path: '/disk4/xgy' 3 | data_path: '/mvtecloco' 4 | num_task: 5 5 | data_size: 256 6 | data_crop_size: 256 7 | mask_size: 256 8 | mask_crop_size: 256 9 | niceness: 19 10 | curve_max_distance: 0.01 11 | num_parallel_workers: 8 -------------------------------------------------------------------------------- /configuration/1_model_base/cfa.yaml: -------------------------------------------------------------------------------- 1 | model: cfa 2 | net: net_cfa 3 | data_size: 256 4 | data_crop_size: 256 5 | mask_size: 256 6 | mask_crop_size: 256 7 | num_epochs: 50 # 50 8 | train_batch_size: 4 9 | valid_batch_size: 1 10 | train_aug_type: normal 11 | valid_aug_type: normal 12 | 13 | gamma_c: 1 14 | gamma_d: 1 -------------------------------------------------------------------------------- /configuration/1_model_base/spade.yaml: -------------------------------------------------------------------------------- 1 | model: spade 2 | net: wide_resnet50 3 | num_epochs: 1 # 4 | train_batch_size: 8 # 32 5 | valid_batch_size: 1 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 256 9 | data_crop_size: 256 10 | mask_size: 256 11 | 
mask_crop_size: 256 12 | 13 | _name: spade 14 | _top_k: 5 15 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/neural_point/neural_point.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['NeuralPoint'] 5 | 6 | class NeuralPoint(nn.Module): 7 | def __init__(self): 8 | super(NeuralPoint, self).__init__() 9 | pass 10 | 11 | def forward(self, x): 12 | pass 13 | -------------------------------------------------------------------------------- /configuration/1_model_base/_patchcore.yaml: -------------------------------------------------------------------------------- 1 | model: _patchcore 2 | net: wide_resnet50 # resnet18/wide_resnet50 3 | sampler_percentage: 0.001 4 | n_neighbours: 9 5 | lr: 0.0001 6 | num_epochs: 1 7 | train_batch_size: 32 8 | valid_batch_size: 1 9 | # optimizer 10 | beta1: 0.5 11 | beta2: 0.999 12 | train_aug_type: normal 13 | valid_aug_type: normal 14 | 15 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/point_mlp/point_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['PointMLP'] 5 | 6 | class PointMLP(nn.Module): 7 | def __init__(self, num_affinity_points): 8 | super(PointMLP, self).__init__() 9 | self.num_affinity_points = num_affinity_points 10 | 11 | def forward(self, x): 12 | pass -------------------------------------------------------------------------------- /paradigms/federated/f2d.py: -------------------------------------------------------------------------------- 1 | from configuration.device import assign_service 2 | from rich import print 3 | from tools.utils import * 4 | import yaml 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | class FederatedAD2D(): 10 | def __init__(self, args): 11 | self.args = args 12 | 13 | 14 | def run_work_flow(self): 15 | pass -------------------------------------------------------------------------------- /paradigms/centralized/c3d.py: -------------------------------------------------------------------------------- 1 | from configuration.device import assign_service 2 | from rich import print 3 | from tools.utils import * 4 | import yaml 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | class CentralizedAD3D(): 10 | def __init__(self, args): 11 | self.args = args 12 | 13 | 14 | def run_work_flow(self): 15 | pass -------------------------------------------------------------------------------- /configuration/1_model_base/fastflow.yaml: -------------------------------------------------------------------------------- 1 | model: fastflow 2 | num_epochs: 500 3 | net: net_fastflow 4 | data_size: 256 5 | data_crop_size: 256 6 | backbone_name: resnet18 7 | flow_steps: 8 8 | hidden_ratio: 1.0 9 | _optimizer_name: adam 10 | _base_lr: 0.001 11 | _weight_decay: 0.00001 12 | train_aug_type: normal 13 | valid_aug_type: normal 14 | train_batch_size: 32 15 | valid_batch_size: 1 -------------------------------------------------------------------------------- /models/igd/net_igd.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from models.igd.mvtec_module import twoin1Generator256, VisualDiscriminator256 3 | 4 | class NetIGD(nn.Module): 5 | def __init__(self, args): 6 | super(NetIGD, self).__init__() 7 | 
self.args = args 8 | 9 | self.g = twoin1Generator256(64, latent_dimension=self.args._latent_dimension) 10 | self.d = VisualDiscriminator256(64) -------------------------------------------------------------------------------- /configuration/1_model_base/igd.yaml: -------------------------------------------------------------------------------- 1 | model: igd 2 | net: net_igd 3 | num_epochs: 256 # 256 4 | train_batch_size: 16 5 | valid_batch_size: 1 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | 9 | _name: igd 10 | _latent_dimension: 128 11 | _data_range: 4.7579 12 | _max_epoch: 256 13 | _optimizer_name: adam 14 | _base_lr: 0.0001 15 | _gamma: 0.2 16 | _beta1: 0 17 | _beta2: 0.9 18 | _weight_decay: 0.000001 -------------------------------------------------------------------------------- /configuration/1_model_base/favae.yaml: -------------------------------------------------------------------------------- 1 | model: favae 2 | net: net_favae 3 | num_epochs: 100 # 100 4 | train_batch_size: 64 # 64 5 | valid_batch_size: 1 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 128 9 | data_crop_size: 128 10 | mask_size: 128 11 | mask_crop_size: 128 12 | 13 | _name: favae 14 | _kld_weight: 1.0 15 | _optimizer_name: adam 16 | _weight_decay: 0.00001 17 | _base_lr: 0.005 18 | -------------------------------------------------------------------------------- /configuration/1_model_base/cutpaste.yaml: -------------------------------------------------------------------------------- 1 | model: cutpaste 2 | train_batch_size: 32 3 | valid_batch_size: 1 4 | net: vit_b_16 5 | num_epochs: 256 #256 6 | train_aug_type: cutpaste 7 | valid_aug_type: normal 8 | data_size: 224 9 | data_crop_size: 224 10 | mask_size: 224 11 | mask_crop_size: 224 12 | 13 | _name: cutpaste 14 | _base_lr: 0.0001 15 | _optimizer_name: adam 16 | _weight_decay: 0.00001 17 | _num_classes: 2 18 | _pretrained: True -------------------------------------------------------------------------------- /configuration/1_model_base/reverse.yaml: -------------------------------------------------------------------------------- 1 | model: reverse 2 | net: net_reverse 3 | num_epochs: 200 # 200 4 | train_batch_size: 8 # 32 5 | valid_batch_size: 1 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 256 9 | data_crop_size: 256 10 | mask_size: 256 11 | mask_crop_size: 256 12 | 13 | _name: reverse 14 | _optimizer_name: adam 15 | _base_lr: 0.005 16 | _gamma: 0.2 17 | _beta1: 0.5 18 | _beta2: 0.999 19 | _weight_decay: 0.000001 -------------------------------------------------------------------------------- /configuration/1_model_base/stpm.yaml: -------------------------------------------------------------------------------- 1 | model: stpm 2 | net: resnet18 3 | num_epochs: 100 # 4 | train_batch_size: 8 # 32 5 | valid_batch_size: 1 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 256 9 | data_crop_size: 256 10 | mask_size: 256 11 | mask_crop_size: 256 12 | 13 | _name: stpm 14 | _optimizer_name: sgd 15 | _weight_decay: 0.00001 # 0.00003; csflow: 0.00001 16 | _momentum: 0.9 17 | _warmup_epochs: 10 18 | _warmup_lr: 0 19 | _base_lr: 0.4 # 0.0001; csflow: 0.0002; revdis:0.005 20 | _final_lr: 0 -------------------------------------------------------------------------------- /configuration/device.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import fcntl 3 | import struct 4 | from configuration.registration import server_data 5 | 6 | 7 | def 
get_ip_address(ifname): 8 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 9 | info = fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', bytes(ifname[:15], 'utf-8'))) 10 | 11 | return socket.inet_ntoa(info[20:24]) 12 | 13 | def assign_service(moda='eno1'): 14 | # moda: eno1, lo 15 | ip = get_ip_address(moda) 16 | root_path = server_data[ip] 17 | 18 | return ip, root_path 19 | 20 | -------------------------------------------------------------------------------- /models/fastflow/func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['AverageMeter'] 5 | 6 | class AverageMeter: 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rich==12.4.4 2 | # numpy==1.19.2 3 | opencv_python==4.5.4.58 4 | scipy==1.5.4 5 | tqdm==4.62.3 6 | PyYAML==6.0 7 | # dgl==0.8.0post2 8 | open3d==0.15.2 9 | scikit-image==0.19.2 10 | scikit-learn==1.0.2 11 | munch==2.5.0 12 | # torch==1.10.1 13 | ninja==1.10.2.3 14 | matplotlib==3.5.1 15 | tifffile==2021.11.2 16 | imgaug==0.4.0 17 | kornia==0.6.5 18 | torchmetrics==0.9.1 19 | faiss-gpu==1.7.2 20 | timm==0.6.11 21 | efficientnet_pytorch==0.7.1 22 | FrEIA==0.2 23 | ignite==0.4.10 24 | einops==0.4.1 25 | torchprofile==0.0.4 26 | pytorch-msssim==0.2.1 -------------------------------------------------------------------------------- /loss_function/reverse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['reverse_loss'] 5 | 6 | def reverse_loss(a, b): 7 | #mse_loss = torch.nn.MSELoss() 8 | cos_loss = torch.nn.CosineSimilarity() 9 | loss = 0 10 | for item in range(len(a)): 11 | #print(a[item].shape) 12 | #print(b[item].shape) 13 | #loss += 0.1*mse_loss(a[item], b[item]) 14 | loss += torch.mean(1-cos_loss(a[item].view(a[item].shape[0],-1), 15 | b[item].view(b[item].shape[0],-1))) 16 | return loss -------------------------------------------------------------------------------- /models/devnet/devnet_resnet18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class DevNetResNet18(nn.Module): 6 | def __init__(self): 7 | super(DevNetResNet18, self).__init__() 8 | self.net = models.resnet18(pretrained=True) 9 | 10 | def forward(self, x): 11 | x = self.net.conv1(x) 12 | x = self.net.bn1(x) 13 | x = self.net.relu(x) 14 | x = self.net.maxpool(x) 15 | x = self.net.layer1(x) 16 | x = self.net.layer2(x) 17 | x = self.net.layer3(x) 18 | x = self.net.layer4(x) 19 | return x -------------------------------------------------------------------------------- /loss_function/deviation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['DeviationLoss'] 6 | 7 | class DeviationLoss(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, y_pred, y_true): 12 | confidence_margin = 5. 
13 | ref = torch.normal(mean=0., std=torch.full([5000], 1.)).to(y_pred.device) 14 | dev = (y_pred - torch.mean(ref)) / torch.std(ref) 15 | inlier_loss = torch.abs(dev) 16 | outlier_loss = torch.abs((confidence_margin - dev).clamp_(min=0.)) 17 | dev_loss = (1 - y_true) * inlier_loss + y_true * outlier_loss 18 | 19 | return torch.mean(dev_loss) 20 | -------------------------------------------------------------------------------- /configuration/1_model_base/devnet.yaml: -------------------------------------------------------------------------------- 1 | model: devnet 2 | net: resnet18 3 | semi: false 4 | semi_anomaly_num: 10 # _n_anomaly 5 | semi_overlap: false 6 | num_epochs: 30 # 50 7 | train_batch_size: 48 8 | valid_batch_size: 1 9 | train_aug_type: normal 10 | valid_aug_type: normal 11 | data_size: 448 12 | data_crop_size: 448 13 | mask_size: 448 14 | mask_crop_size: 448 15 | 16 | _name: devnet 17 | _batch_size: 48 18 | _steps_per_epoch: 20 19 | _ramdn_seed: 42 20 | _no-cuda: True 21 | _classname: 'capsule' 22 | _img_size: 448 23 | _n_scales: 2 24 | _criterion: 'deviation' 25 | _topk: 0.1 26 | _n_anomaly: 10 27 | _optimizer_name: adam 28 | _base_lr: 0.0002 29 | _weight_decay: 0.00001 30 | _step_size: 100 31 | _gamma: 0.1 -------------------------------------------------------------------------------- /models/_patchcore/sampling_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | __all__ = ['SamplingMethod'] 5 | class SamplingMethod(object): 6 | __metaclass__ = abc.ABCMeta 7 | 8 | @abc.abstractmethod 9 | def __init__(self, X, y, **kwargs): 10 | self.X = X 11 | self.y = y 12 | 13 | def flatten_X(self): 14 | shape = self.X.shape 15 | flat_X = self.X 16 | if len(shape) > 2: 17 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:]))) 18 | return flat_X 19 | 20 | 21 | @abc.abstractmethod 22 | def select_batch_(self): 23 | return 24 | 25 | def select_batch(self, **kwargs): 26 | return self.select_batch_(**kwargs) 27 | 28 | def to_dict(self): 29 | return None -------------------------------------------------------------------------------- /configuration/1_model_base/patchcore.yaml: -------------------------------------------------------------------------------- 1 | model: patchcore 2 | net: wide_resnet50 # resnet18/wide_resnet50 3 | data_size: 256 4 | data_crop_size: 256 5 | mask_size: 256 6 | mask_crop_size: 256 7 | num_epochs: 50 8 | train_batch_size: 4 9 | valid_batch_size: 1 10 | train_aug_type: normal 11 | valid_aug_type: normal 12 | 13 | _input_shape: 14 | - 3 15 | - 256 16 | - 256 17 | _faiss_on_gpu: True 18 | _faiss_num_workers: 8 19 | _layers_to_extract_from: 20 | - layer2 21 | - layer3 22 | _pretrain_embed_dimension: 1024 23 | _target_embed_dimension: 1024 24 | _anomaly_scorer_num_nn: 1 25 | _patch_size: 3 26 | 27 | # identity, greedy_coreset, approx_greedy_coreset 28 | _sampler_name: approx_greedy_coreset 29 | sampler_percentage: 0.1 -------------------------------------------------------------------------------- /models/pointcore/descriptor/point_mlp/modules/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['get_activation'] 6 | 7 | def get_activation(activation): 8 | if activation.lower() == 'gelu': 9 | return nn.GELU() 10 | elif activation.lower() == 'rrelu': 11 | return nn.RReLU(inplace=True) 12 | elif activation.lower() == 'selu': 13 | return nn.SELU(inplace=True) 14 | elif 
activation.lower() == 'silu': 15 | return nn.SiLU(inplace=True) 16 | elif activation.lower() == 'hardswish': 17 | return nn.Hardswish(inplace=True) 18 | elif activation.lower() == 'leakyrelu': 19 | return nn.LeakyReLU(inplace=True) 20 | else: 21 | return nn.ReLU(inplace=True) -------------------------------------------------------------------------------- /arch/_example.py: -------------------------------------------------------------------------------- 1 | from arch.base import ModelBase 2 | from models._example.net_example import NetExample 3 | 4 | 5 | __all__ = ['Example'] 6 | 7 | class ModelExample(ModelBase): 8 | def __init__(self, config): 9 | super(ModelExample, self).__init__(config) 10 | self.config = config 11 | self.net = NetExample(self.config) 12 | 13 | def train_model(self, train_loader, task_id, inf=''): 14 | pass 15 | 16 | def prediction(self, valid_loader, task_id=None): 17 | # implement these for test 18 | self.img_pred_list = [] # list 19 | self.img_gt_list = [] # list 20 | self.pixel_pred_list = [] # list 21 | self.pixel_gt_list = [] # list 22 | self.img_path_list = [] # list -------------------------------------------------------------------------------- /configuration/1_model_base/draem.yaml: -------------------------------------------------------------------------------- 1 | model: draem 2 | train_batch_size: 16 # 8 3 | valid_batch_size: 1 4 | net: net_draem 5 | num_epochs: 50 # 700 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 256 9 | data_crop_size: 256 10 | mask_size: 256 11 | mask_crop_size: 256 12 | 13 | _name: draem 14 | _pretrained: True 15 | _use_dis: False 16 | _fix_head: True 17 | _save_anormal : True 18 | _n_feat: 304 19 | _fc_internal: 1024 20 | _n_coupling_blocks: 4 21 | _clamp: 3 22 | _n_scales: 3 23 | 24 | _optimizer_name: adam 25 | _weight_decay: 0.00003 26 | _momentum: 0.9 27 | _warmup_epochs: 10 28 | _warmup_lr: 0 29 | _base_lr: 0.0001 30 | _gamma: 0.2 31 | _final_lr: 0 32 | _test_epochs: 10 33 | _alpha: 0.4 34 | _beta: 0.5 35 | _num_classes: 2 36 | _eval_classifier: density # head, density 37 | _visualization: False 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /configuration/1_model_base/simplenet.yaml: -------------------------------------------------------------------------------- 1 | model: simplenet 2 | net: wide_resnet50 3 | data_size: 256 # 288 4 | data_crop_size: 256 5 | mask_size: 256 6 | mask_crop_size: 256 7 | num_epochs: 1 # 40 8 | train_batch_size: 4 9 | valid_batch_size: 1 10 | train_aug_type: normal 11 | valid_aug_type: normal 12 | 13 | _layers_to_extract_from: 14 | - layer2 15 | - layer3 16 | _pretrain_embed_dimension: 1536 17 | _target_embed_dimension: 1536 18 | _patchsize: 3 19 | _patchstride: 1 20 | _embedding_size: None # 256 21 | # _meta_epochs: 40 # num_epochs 22 | _aed_meta_epochs: 1 23 | _gan_epochs: 4 # 4 24 | _noise_std: 0.015 25 | _mix_noise: 1 26 | _noise_type: 'GAU' 27 | _dsc_layers: 2 28 | _dsc_hidden: 1024 # 1024 29 | _dsc_margin: .8 # .5 30 | _dsc_lr: 0.0002 31 | _train_backbone: False 32 | _auto_noise: 0 33 | _cos_lr: False 34 | _lr: 1e-3 35 | _pre_proj: 1 # 1 36 | _proj_layer_type: 0 -------------------------------------------------------------------------------- /configuration/1_model_base/dra.yaml: -------------------------------------------------------------------------------- 1 | model: dra 2 | net: net_dra 3 | semi: True 4 | semi_anomaly_num: 2 # _nAnomaly: 10 5 | semi_overlap: false 6 | num_epochs: 1 # 30 7 | train_batch_size: 48 8 
| valid_batch_size: 1 9 | train_aug_type: normal 10 | valid_aug_type: normal 11 | data_size: 224 12 | data_crop_size: 224 13 | mask_size: 224 14 | mask_crop_size: 224 15 | ref_num: 5 # _nRef 16 | 17 | _name: dra 18 | _batch_size: 48 19 | _steps_per_epoch: 20 20 | _cont_rate: 0 21 | _test_threshold: 0 22 | _test_rate: 0 23 | _ramdn_seed: 42 24 | _no-cuda: True 25 | _classname: 'capsule' 26 | _img_size: 224 27 | # _nAnomaly: 10 28 | _n_scales: 2 29 | _criterion: 'deviation' 30 | _topk: 0.1 31 | _know_class: None 32 | _pretrain_dir: None 33 | _total_heads: 4 34 | # _nRef: 5 35 | _outlier_root: None 36 | _optimizer_name: adam 37 | _base_lr: 0.0002 38 | _weight_decay: 0.00001 39 | _step_size: 10 40 | _gamma: 0.1 -------------------------------------------------------------------------------- /configuration/1_model_base/softpatch.yaml: -------------------------------------------------------------------------------- 1 | model: softpatch 2 | net: wide_resnet50 # resnet18/wide_resnet50 3 | data_size: 256 4 | data_crop_size: 256 5 | mask_size: 256 6 | mask_crop_size: 256 7 | num_epochs: 50 8 | train_batch_size: 4 9 | valid_batch_size: 1 10 | train_aug_type: normal 11 | valid_aug_type: normal 12 | 13 | _input_shape: 14 | - 3 15 | - 256 16 | - 256 17 | _faiss_on_gpu: True 18 | _faiss_num_workers: 8 19 | _layers_to_extract_from: 20 | - layer2 21 | - layer3 22 | _pretrain_embed_dimension: 1024 23 | _target_embed_dimension: 1024 24 | _anomaly_scorer_num_nn: 1 25 | _patch_size: 3 26 | 27 | # identity, greedy_coreset, approx_greedy_coreset, weighted_greedy_coreset 28 | _sampler_name: weighted_greedy_coreset 29 | sampler_percentage: 0.1 30 | 31 | _lof_k: 5 32 | _threshold: 0.15 33 | # 'lof', 'nearest', 'gaussian' or 'lof_gpu' 34 | _weight_method: gaussian 35 | _soft_weight_flag: True -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from configuration.config import parse_arguments_main 2 | from paradigms.centralized.c2d import CentralizedAD2D 3 | from paradigms.centralized.c3d import CentralizedAD3D 4 | from paradigms.federated.f2d import FederatedAD2D 5 | 6 | 7 | def main(args): 8 | # centralized learning for 2d anomaly detection 9 | if args.paradigm == 'c2d': 10 | work = CentralizedAD2D(args=args) 11 | work.run_work_flow() 12 | 13 | # centralized learning for 3d anomaly detection 14 | if args.paradigm == 'c3d': 15 | work = CentralizedAD3D(args=args) 16 | work.run_work_flow() 17 | 18 | # federated learning for 2d anomaly detection 19 | if args.paradigm == 'f2d': 20 | work = FederatedAD2D(args=args) 21 | work.run_work_flow() 22 | 23 | 24 | if __name__ == '__main__': 25 | args = parse_arguments_main() 26 | main(args) -------------------------------------------------------------------------------- /loss_function/binaryfocal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['BinaryFocalLoss'] 7 | 8 | class BinaryFocalLoss(nn.Module): 9 | def __init__(self, alpha=1, gamma=2, logits=True, reduce=True): 10 | super(BinaryFocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.gamma = gamma 13 | self.logits = logits 14 | self.reduce = reduce 15 | 16 | def forward(self, inputs, targets): 17 | if self.logits: 18 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 19 | else: 20 | BCE_loss = F.binary_cross_entropy(inputs, 
targets, reduction='none') 21 | pt = torch.exp(-BCE_loss) 22 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 23 | 24 | if self.reduce: 25 | return torch.mean(F_loss) 26 | else: 27 | return F_loss -------------------------------------------------------------------------------- /data_io/fewshot.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | from torch.utils.data import Dataset 4 | 5 | 6 | __all__ = ['FewShot', 'extract_fewshot_data'] 7 | 8 | class FewShot(Dataset): 9 | def __init__(self, data) -> None: 10 | self.data = data 11 | 12 | def __getitem__(self, idx): 13 | return self.data[idx] 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def extract_fewshot_data(train_dataset, fewshot_exm=1): 19 | fewshot_exm_src = fewshot_exm 20 | # construct train_fewshot_dataset 21 | train_fewshot_dataset = copy.deepcopy(train_dataset) 22 | for i, num in enumerate(train_dataset.sample_num_in_task): 23 | if fewshot_exm > num: 24 | fewshot_exm = num 25 | chosen_samples = random.sample(train_fewshot_dataset.sample_indices_in_task[i], fewshot_exm) 26 | train_fewshot_dataset.sample_indices_in_task[i] = chosen_samples 27 | train_fewshot_dataset.sample_num_in_task[i] = fewshot_exm 28 | fewshot_exm = fewshot_exm_src 29 | 30 | return train_fewshot_dataset 31 | 32 | -------------------------------------------------------------------------------- /configuration/2_train_base/centralized_learning.yaml: -------------------------------------------------------------------------------- 1 | # centralized learning 2 | learning_mode: 'centralized' 3 | # vanilla learning 4 | vanilla: false 5 | # semi-supervised learning 6 | semi: false 7 | semi_anomaly_num: 5 8 | semi_overlap: false 9 | # reference 10 | ref_num: 5 11 | # continual learning 12 | continual: false 13 | # fewshot learning 14 | fewshot: false 15 | fewshot_exm: 5 16 | fewshot_num_dg: 4 17 | fewshot_data_aug: false 18 | fewshot_feat_aug: false 19 | fewshot_aug_type: ['normal'] 20 | # noisy label 21 | noisy: false 22 | noisy_ratio: 0.1 23 | noisy_overlap: false 24 | # transfer 25 | transfer: false 26 | transfer_target_sample_num: 8 27 | gpu_id: 0 28 | gpu_ids: 29 | - '01234567' 30 | num_workers: 8 31 | # work file 32 | work_dir: ./work_dir 33 | save_log: true 34 | debug: false 35 | batch_limit: 2 36 | seed: 66 37 | # chosen task ids 38 | train_task_id: 39 | - 0 40 | valid_task_id: 41 | - 0 42 | # running task id 43 | train_task_id_tmp: 0 44 | valid_task_id_tmp: 0 45 | # vis img 46 | vis: false 47 | vis_num: 10 48 | # vis embedding 49 | vis_em: false 50 | server_moda: eno1 -------------------------------------------------------------------------------- /configuration/1_model_base/graphcore.yaml: -------------------------------------------------------------------------------- 1 | model: graphcore 2 | net: vig_ti_224_gelu 3 | data_size: 224 4 | data_crop_size: 224 5 | mask_size: 224 6 | mask_crop_size: 224 7 | pretrained: True 8 | train_batch_size: 32 9 | valid_batch_size: 1 10 | train_aug_type: normal 11 | valid_aug_type: normal 12 | num_epochs: 1 13 | n_neighbours: 9 14 | sampler_percentage: 0.001 15 | #####vig parameters##### 16 | drop_rate: 0 17 | drop_path_rate: 0 18 | drop_connect_rate: 0 # legacy issue, deprecated, use drop_path 19 | drop_block_rate: 0 20 | gp: None # global pooling 21 | bn_tf: False # bn ad, use tensorflow batchnorm defaults for models that support it 22 | bn_momentum: None 23 | bn_eps: None 24 | checkpoint_path: './checkpoints/graphcore/pretrain/' 
25 | local_smoothing: False 26 | #####optimizer###### 27 | opt: sgd 28 | opt_eps: None 29 | #opt_betas: None 30 | momentum: 0.9 31 | weight_decay: 0.0001 32 | clip_grad: None 33 | lr: 0.01 34 | lr_noise: None 35 | lr_noise_pct: 0.67 36 | lr_noise_std: 1.0 37 | lr_cycle_mul: 1.0 38 | lr_cycle_limit: 1 39 | warmup_lr: 0.0001 40 | min_lr: 1e-5 41 | #######scheduler###### 42 | sched: step 43 | start_epoch: 0 44 | decay_epochs: 30 45 | warmup_epochs: 3 46 | cooldown_epochs: 10 47 | patience_epochs: 10 48 | decay_rate: 0.1 49 | epochs: 200 50 | layer_num_1: 3 51 | layer_num_2: 4 52 | layer_num_3: 5 53 | 54 | 55 | -------------------------------------------------------------------------------- /tools/visualize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from tools.utils import create_folders 5 | 6 | __all__ = ['cv2heatmap', 'heatmap_on_image', 'min_max_norm', 'save_anomaly_map'] 7 | 8 | 9 | def cv2heatmap(gray): 10 | heatmap = cv2.applyColorMap(np.uint8(gray), cv2.COLORMAP_JET) 11 | return heatmap 12 | 13 | 14 | def heatmap_on_image(heatmap, image): 15 | if heatmap.shape != image.shape: 16 | heatmap = cv2.resize(heatmap, (image.shape[0], image.shape[1])) 17 | out = np.float32(heatmap)/255 + np.float32(image)/255 18 | out = out / np.max(out) 19 | return np.uint8(255 * out) 20 | 21 | 22 | def min_max_norm(image): 23 | a_min, a_max = image.min(), image.max() 24 | return (image-a_min)/(a_max - a_min) 25 | 26 | 27 | def save_anomaly_map(anomaly_map, input_img, mask, file_path): 28 | if anomaly_map.shape != input_img.shape: 29 | anomaly_map = cv2.resize(anomaly_map, (input_img.shape[0], input_img.shape[1])) 30 | 31 | anomaly_map_norm = min_max_norm(anomaly_map) 32 | heatmap = cv2heatmap(anomaly_map_norm * 255) 33 | 34 | heatmap_on_img = heatmap_on_image(heatmap, input_img) 35 | create_folders(file_path) 36 | 37 | cv2.imwrite(os.path.join(file_path, 'input.jpg'), input_img) 38 | cv2.imwrite(os.path.join(file_path, 'heatmap.jpg'), heatmap) 39 | cv2.imwrite(os.path.join(file_path, 'heatmap_on_img.jpg'), heatmap_on_img) 40 | cv2.imwrite(os.path.join(file_path, 'mask.jpg'), mask * 255) 41 | -------------------------------------------------------------------------------- /configuration/1_model_base/dne.yaml: -------------------------------------------------------------------------------- 1 | model: dne 2 | train_batch_size: 32 3 | valid_batch_size: 1 4 | net: vit_b16 # resnet, vit, net_csflow, net_draem, net_revdis 5 | num_epochs: 50 # 50 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 224 9 | data_crop_size: 224 10 | mask_size: 224 11 | mask_crop_size: 224 12 | 13 | # _name: seq-mtd-mvtec # seq-mvtec or seq-mtd-mvtec 14 | _image_size: 224 # 224; draem, revdis: 256; csflow: 768; 15 | # _num_workers: 4 16 | _data_incre_setting: mul # one: 10+1+1+1+1+1 mul: 3+3+3+3+3 17 | _n_classes_per_task: 3 # one_class_incre:1 mul_class_incre: 3 18 | _n_tasks: 6 # seq-mtd-mvtec, one_class_incre:6, mul_class_incre: 5 19 | _dataset_order: 1 # 1, 2, 3 20 | _strong_augmentation: True # strong augmentation: cutpaste, maskimg, etc.; weak augmentation: ColorJitter, RandomRotation, etc. 
21 | _random_aug: False 22 | 23 | _name: dne # panda, dis, cutpaste, csflow, draem, revdis, upper 24 | # _net: vit # resnet, vit, net_csflow, net_draem, net_revdis 25 | _pretrained: True 26 | # plug in 27 | _use_dis: False 28 | # dis, discat 29 | _fix_head: False 30 | _save_anormal : True 31 | # cflow 32 | _n_feat: 304 33 | _fc_internal: 1024 34 | _n_coupling_blocks: 4 35 | _clamp: 3 36 | _n_scales: 3 37 | 38 | _optimizer_name: adam 39 | _weight_decay: 0.00003 # 0.00003; csflow: 0.00001 40 | _momentum: 0.9 41 | _warmup_epochs: 10 42 | _warmup_lr: 0 43 | _base_lr: 0.0001 # 0.0001 44 | _final_lr: 0 45 | # _num_epochs: 50 46 | _batch_size: 32 # 32 47 | _test_epochs: 10 48 | _alpha: 0.4 49 | _beta: 0.5 50 | _num_classes: 2 51 | 52 | _eval_classifier: density # head, density 53 | _batch_size: 32 # 32 54 | _visualization: False 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /configuration/1_model_base/csflow.yaml: -------------------------------------------------------------------------------- 1 | model: csflow 2 | train_batch_size: 16 3 | valid_batch_size: 1 4 | net: net_csflow # resnet, vit, net_csflow, net_draem, net_revdis 5 | num_epochs: 240 # 4*60 6 | train_aug_type: normal 7 | valid_aug_type: normal 8 | data_size: 768 9 | data_crop_size: 768 10 | mask_size: 768 11 | mask_crop_size: 768 12 | 13 | # _name: seq-mvtec # seq-mvtec or seq-mtd-mvtec 14 | _image_size: 768 # 224; draem, revdis: 256; csflow: 768; 15 | _data_incre_setting: mul # one: 10+1+1+1+1+1 mul: 3+3+3+3+3 16 | _n_classes_per_task: 3 # one_class_incre:1 mul_class_incre: 3 17 | _n_tasks: 5 # seq-mtd-mvtec, one_class_incre:6, mul_class_incre: 5 18 | _dataset_order: 1 # 1, 2, 3 19 | _strong_augmentation: True # strong augmentation: cutpaste, maskimg, etc.; weak augmentation: ColorJitter, RandomRotation, etc. 
20 | _random_aug: False 21 | 22 | _name: csflow # panda, dis, cutpaste, csflow, draem, revdis, upper 23 | # _net: net_csflow # resnet, vit, net_csflow, net_draem, net_revdis 24 | _pretrained: True 25 | # panda 26 | _use_dis: False 27 | # dis, discat 28 | _fix_head: True 29 | _save_anormal : True 30 | # cflow 31 | _n_feat: 304 32 | _fc_internal: 1024 33 | _n_coupling_blocks: 4 34 | _clamp: 3 35 | _n_scales: 3 36 | 37 | _optimizer_name: adam 38 | _weight_decay: 0.00001 # 0.00003; csflow: 0.00001 39 | _momentum: 0.9 40 | _warmup_epochs: 10 41 | _warmup_lr: 0 42 | _base_lr: 0.0002 # 0.0001; csflow: 0.0002; revdis:0.005 43 | _final_lr: 0 44 | # _num_epochs: 50 45 | _test_epochs: 10 46 | _alpha: 0.4 47 | _beta: 0.5 48 | _num_classes: 2 49 | 50 | _eval_classifier: density # head, density 51 | # _batch_size: 16 # 32, revdis,draem:1, csflow:16 52 | _visualization: False 53 | 54 | 55 | -------------------------------------------------------------------------------- /models/cutpaste/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models import resnet18 3 | 4 | __all__ = ['ProjectionNet'] 5 | 6 | class ProjectionNet(nn.Module): 7 | def __init__(self, pretrained=True, head_layers=[512,512,512,512,512,512,512,512,128], num_classes=2): 8 | super(ProjectionNet, self).__init__() 9 | #self.resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=pretrained) 10 | self.resnet18 = resnet18(pretrained=pretrained) 11 | 12 | # create MPL head as seen in the code in: https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 13 | # TODO: check if this is really the right architecture 14 | last_layer = 512 15 | sequential_layers = [] 16 | for num_neurons in head_layers: 17 | sequential_layers.append(nn.Linear(last_layer, num_neurons)) 18 | sequential_layers.append(nn.BatchNorm1d(num_neurons)) 19 | sequential_layers.append(nn.ReLU(inplace=True)) 20 | last_layer = num_neurons 21 | 22 | #the last layer without activation 23 | 24 | head = nn.Sequential( 25 | *sequential_layers 26 | ) 27 | self.resnet18.fc = nn.Identity() 28 | self.head = head 29 | self.out = nn.Linear(last_layer, num_classes) 30 | 31 | def forward(self, x): 32 | embeds = self.resnet18(x) 33 | tmp = self.head(embeds) 34 | logits = self.out(tmp) 35 | return embeds, logits 36 | 37 | def freeze_resnet(self): 38 | # freez full resnet18 39 | for param in self.resnet18.parameters(): 40 | param.requires_grad = False 41 | 42 | #unfreeze head: 43 | for param in self.resnet18.fc.parameters(): 44 | param.requires_grad = True 45 | 46 | def unfreeze(self): 47 | #unfreeze all: 48 | for param in self.parameters(): 49 | param.requires_grad = True -------------------------------------------------------------------------------- /arch/fastflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from arch.base import ModelBase 4 | from models.fastflow.net import NetFastFlow 5 | from models.fastflow.func import AverageMeter 6 | from optimizer.optimizer import get_optimizer 7 | 8 | __all__ = ['FastFlow'] 9 | 10 | class FastFlow(ModelBase): 11 | def __init__(self, config): 12 | super(FastFlow, self).__init__(config) 13 | self.config = config 14 | 15 | self.net = NetFastFlow(self.config).to(self.device) 16 | self.optimizer = get_optimizer(self.config, self.net.parameters()) 17 | self.scheduler = None 18 | 19 | def train_model(self, train_loader, inf=''): 20 | self.net.train() 21 | 
self.clear_all_list() 22 | loss_meter = AverageMeter() 23 | for epoch in range(self.config['num_epochs']): 24 | for batch_id, batch in enumerate(train_loader): 25 | img = batch['img'].to(self.device) 26 | ret = self.net(img) 27 | loss = ret['loss'] 28 | # backward 29 | self.optimizer.zero_grad() 30 | loss.backward() 31 | self.optimizer.step() 32 | loss_meter.update(loss.item()) 33 | 34 | def prediction(self, valid_loader, task_id=None): 35 | self.net.eval() 36 | self.clear_all_list() 37 | 38 | with torch.no_grad(): 39 | for batch_id, batch in enumerate(valid_loader): 40 | img = batch['img'].to(self.device) 41 | mask = batch['mask'] 42 | label = batch['label'] 43 | mask[mask>=0.5] = 1 44 | mask[mask<0.5] = 0 45 | mask_np = mask.numpy()[0,0].astype(int) 46 | 47 | ret = self.net(img) 48 | 49 | outputs = ret["anomaly_map"].cpu().detach().numpy() 50 | self.pixel_gt_list.append(mask_np) 51 | self.pixel_pred_list.append(outputs[0,0,:,:]) 52 | self.img_gt_list.append(label.numpy()[0]) 53 | self.img_pred_list.append(np.max(outputs)) 54 | self.img_path_list.append(batch['img_src']) 55 | -------------------------------------------------------------------------------- /augmentation/type.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as T 2 | from augmentation.cutpaste_aug import * 3 | 4 | __all__ = ['aug_type'] 5 | 6 | def aug_type(augment_type, args): 7 | if augment_type == 'normal': 8 | img_transform = T.Compose([T.Resize((args['data_size'], args['data_size'])), 9 | T.CenterCrop(args['data_crop_size']), 10 | T.ToTensor(), 11 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 12 | ]) 13 | 14 | mask_transform = T.Compose([T.Resize(args['mask_size']), 15 | T.CenterCrop(args['mask_crop_size']), 16 | T.ToTensor(), 17 | ]) 18 | 19 | elif augment_type == 'cutpaste': 20 | after_cutpaste_transform = T.Compose([T.RandomRotation(90), 21 | T.ToTensor(), 22 | T.Normalize(mean=[0.485, 0.456, 0.406], 23 | std=[0.229, 0.224, 0.225]) 24 | ]) 25 | 26 | img_transform = T.Compose([T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 27 | T.Resize((args['data_crop_size'], args['data_crop_size'])), 28 | CutPasteNormal(transform=after_cutpaste_transform) 29 | #T.RandomChoice([CutPasteNormal(transform=after_cutpaste_transform), 30 | # CutPasteScar(transform=after_cutpaste_transform)]) 31 | ]) 32 | 33 | mask_transform = T.Compose([T.Resize(args['mask_size']), 34 | T.CenterCrop(args['mask_crop_size']), 35 | T.ToTensor(), 36 | ]) 37 | else: 38 | raise NotImplementedError('The Augmentation Type Has Not Been Implemented Yet') 39 | 40 | 41 | 42 | return img_transform, mask_transform 43 | -------------------------------------------------------------------------------- /models/dra/dra_resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | import torch.nn.functional as F 5 | 6 | 7 | class DraResNet18(nn.Module): 8 | def __init__(self): 9 | super(DraResNet18, self).__init__() 10 | self.net = models.resnet18(pretrained=True) 11 | 12 | def forward(self, x): 13 | x = self.net.conv1(x) 14 | x = self.net.bn1(x) 15 | x = self.net.relu(x) 16 | x = self.net.maxpool(x) 17 | x = self.net.layer1(x) 18 | x = self.net.layer2(x) 19 | x = self.net.layer3(x) 20 | x = self.net.layer4(x) 21 | return x 22 | 23 | class HolisticHead(nn.Module): 24 | def __init__(self, in_dim, dropout=0): 25 | super(HolisticHead, 
self).__init__() 26 | self.fc1 = nn.Linear(in_dim, 256) 27 | self.fc2 = nn.Linear(256, 1) 28 | self.drop = nn.Dropout(dropout) 29 | 30 | def forward(self, x): 31 | x = F.adaptive_avg_pool2d(x, (1, 1)) 32 | x = x.view(x.size(0), -1) 33 | x = self.drop(F.relu(self.fc1(x))) 34 | x = self.fc2(x) 35 | return torch.abs(x) 36 | 37 | class PlainHead(nn.Module): 38 | def __init__(self, in_dim, topk_rate=0.1): 39 | super(PlainHead, self).__init__() 40 | self.scoring = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1, padding=0) 41 | self.topk_rate = topk_rate 42 | 43 | def forward(self, x): 44 | x = self.scoring(x) 45 | x = x.view(int(x.size(0)), -1) 46 | topk = max(int(x.size(1) * self.topk_rate), 1) 47 | x = torch.topk(torch.abs(x), topk, dim=1)[0] 48 | x = torch.mean(x, dim=1).view(-1, 1) 49 | return x 50 | 51 | class CompositeHead(PlainHead): 52 | def __init__(self, in_dim, topk=0.1): 53 | super(CompositeHead, self).__init__(in_dim, topk) 54 | self.conv = nn.Sequential(nn.Conv2d(in_dim, in_dim, 3, padding=1), 55 | nn.BatchNorm2d(in_dim), 56 | nn.ReLU()) 57 | 58 | def forward(self, x, ref): 59 | ref = torch.mean(ref, dim=0).repeat([x.size(0), 1, 1, 1]) 60 | x = ref - x 61 | x = self.conv(x) 62 | x = super().forward(x) 63 | return x 64 | 65 | 66 | -------------------------------------------------------------------------------- /arch/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tools.utils import parse_device_list, seed_everything 3 | from metric.cal_metric import CalMetric 4 | from tools.record_helper import RecordHelper 5 | 6 | class ModelBase(): 7 | def __init__(self, config): 8 | self.config = config 9 | self.device, device_ids = parse_device_list(self.config['gpu_ids'], int(self.config['gpu_id'])) 10 | seed_everything(self.config['seed']) 11 | 12 | # for training 13 | self.net = None 14 | self.optimizer = None 15 | self.scheduler = None 16 | 17 | # for test 18 | self.img_pred_list = [] # list 19 | self.img_gt_list = [] # list 20 | self.pixel_pred_list = [] # list 21 | self.pixel_gt_list = [] # list 22 | self.img_path_list = [] # list 23 | 24 | # for computing result 25 | self.metric = CalMetric(self.config) 26 | # for recording result 27 | self.recorder = RecordHelper(self.config) 28 | 29 | def train_model(self, train_loader, task_id, inf=''): 30 | pass 31 | 32 | def train_epoch(self, train_loader, task_id, inf=''): 33 | pass 34 | 35 | def prediction(self, valid_loader, task_id=None): 36 | pass 37 | 38 | def visualization(self, vis_loader, task_id=None): 39 | self.clear_all_list() 40 | 41 | self.prediction(vis_loader, task_id) 42 | if len(self.pixel_gt_list)!=0 : 43 | self.recorder.record_images(self.img_pred_list, self.img_gt_list, 44 | self.pixel_pred_list, self.pixel_gt_list, 45 | self.img_path_list) 46 | 47 | def clear_all_list(self): 48 | self.img_pred_list = [] 49 | self.img_gt_list = [] 50 | self.pixel_pred_list = [] 51 | self.pixel_gt_list = [] 52 | self.img_path_list = [] 53 | 54 | def cal_metric_all(self, task_id): 55 | # Logica AD Evaluation Needs Task ID and File Path 56 | return self.metric.cal_metric(self.img_pred_list, self.img_gt_list, 57 | self.pixel_pred_list, self.pixel_gt_list, 58 | self.img_path_list, task_id, self.config['file_path']) 59 | -------------------------------------------------------------------------------- /configuration/registration.py: -------------------------------------------------------------------------------- 1 | # setting 2 | setting_name = ['vanilla', 'fewshot', 
'semi', 'noisy', 'continual', 'transfer'] 3 | 4 | # add new dataset 5 | dataset_name = {'_example': ('dataset._example', '_example', 'Example'), 6 | 'mvtec2d': ('dataset.mvtec2d', 'mvtec2d', 'MVTec2D'), 7 | 'mvtec2df3d': ('dataset.mvtec2df3d', 'mvtec2df3d', 'MVTec2DF3D'), 8 | 'mvtecloco': ('dataset.mvtecloco', 'mvtecloco', 'MVTecLoco'), 9 | 'mpdd': ('dataset.mpdd', 'mpdd', 'MPDD'), 10 | 'btad': ('dataset.btad', 'btad', 'BTAD'), 11 | 'mtd': ('dataset.mtd', 'mtd', 'MTD'), 12 | 'mvtec3d': ('dataset.mvtec3d', 'mvtec3d', 'MVTec3D'), 13 | 'visa': ('dataset.visa', 'visa', 'VisA'), 14 | 'dagm': ('dataset.dagm', 'dagm', 'DAGM'), 15 | 'coad': ('dataset.coad', 'coad', 'COAD'), 16 | } 17 | 18 | # add new model 19 | model_name = {'_example': ('arch._example', '_example', 'Example'), 20 | '_patchcore': ('arch._patchcore', '_patchcore', 'PatchCore'), 21 | 'patchcore': ('arch.patchcore', 'patchcore', 'PatchCore'), 22 | 'padim': ('arch.padim', 'padim', 'PaDim'), 23 | 'csflow': ('arch.csflow', 'csflow', 'CSFlow'), 24 | 'dne': ('arch.dne', 'dne', 'DNE'), 25 | 'draem': ('arch.draem', 'draem', 'DRAEM'), 26 | 'igd': ('arch.igd', 'igd', 'IGD'), 27 | 'dra': ('arch.dra', 'dra', 'DRA'), 28 | 'devnet': ('arch.devnet', 'devnet', 'DevNet'), 29 | 'favae': ('arch.favae', 'favae', 'FAVAE'), 30 | 'fastflow': ('arch.fastflow', 'fastflow', 'FastFlow'), 31 | 'cfa': ('arch.cfa', 'cfa', 'CFA'), 32 | 'reverse': ('arch.reverse', 'reverse', 'REVERSE'), 33 | 'spade': ('arch.spade', 'spade', 'SPADE'), 34 | 'stpm': ('arch.stpm', 'stpm', 'STPM'), 35 | 'cutpaste': ('arch.cutpaste', 'cutpaste', 'CutPaste'), 36 | 'graphcore': ('arch.graphcore', 'graphcore', 'GraphCore'), 37 | 'simplenet': ('arch.simplenet', 'simplenet', 'SimpleNet'), 38 | 'softpatch': ('arch.softpatch', 'softpatch', 'SoftPatch'), 39 | } 40 | 41 | # server config, ip: dataset root path 42 | server_data = {'127.0.0.1': '/home/robot/data', 43 | '172.18.36.108': '/ssd2/m3lab/data/open-ad', 44 | '172.18.36.107': '/ssd-sata1/wjb/data/open-ad', 45 | } -------------------------------------------------------------------------------- /models/graphcore/net_graphcore.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.graphcore.pyramid_vig import * 4 | from models.graphcore.vig import * 5 | from timm.models import create_model 6 | from tools.utils import * 7 | from torchprofile import profile_macs 8 | 9 | __all__ = ['NetGraphCore'] 10 | 11 | def graphcore_ck_name(model_name, ck_path): 12 | model_name_splits = model_name.split('_') 13 | print(model_name_splits) 14 | if model_name_splits[0] == 'pvig': 15 | ck_name = ck_path+model_name_splits[0]+ '_' + model_name_splits[1] + '.pth.tar' 16 | elif model_name_splits[0] == 'vig': 17 | ck_name = ck_path+model_name_splits[0]+ '_' + model_name_splits[1] + '.pth' 18 | else: 19 | raise FileNotFoundError 20 | 21 | return ck_name 22 | 23 | class NetGraphCore(nn.Module): 24 | def __init__(self, config): 25 | super().__init__() 26 | self.config = config 27 | create_folders(self.config.checkpoint_path) 28 | 29 | self.model = create_model(self.config.net, 30 | pretrained=self.config.pretrained, 31 | num_classes=1000, 32 | drop_rate=self.config.drop_rate, 33 | drop_path_rate=self.config.drop_path_rate, 34 | drop_block_rate=self.config.drop_block_rate, 35 | global_pool=self.config.gp, 36 | bn_tf=self.config.bn_tf, 37 | bn_momentum=self.config.bn_momentum, 38 | bn_eps=self.config.bn_eps) 39 | 40 | # Loading pretrained model 41 | if self.config.checkpoint_path is 
not None: 42 | ck_name = graphcore_ck_name(self.config.net, self.config.checkpoint_path) 43 | print('Loading:', ck_name) 44 | state_dict = torch.load(ck_name) 45 | self.model.load_state_dict(state_dict, strict=False) 46 | print('Pretrain weights loaded') 47 | 48 | # Flops Calculation 49 | #print(self.model) 50 | if hasattr(self.model, 'default_cfg'): 51 | default_cfg = self.model.default_cfg 52 | input_size = [1] + list(default_cfg['input_size']) 53 | else: 54 | input_size = [1, 3, 224, 224] 55 | 56 | input = torch.randn(input_size) 57 | 58 | self.model.eval() 59 | macs = profile_macs(self.model, input) 60 | self.model.train() 61 | print('model flops:', macs, 'input_size:', input_size) 62 | 63 | -------------------------------------------------------------------------------- /models/pointcore/descriptor/point_mlp/modules/block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.pointcore.descriptor.point_mlp.modules.activation import get_activation 4 | 5 | __all__ = ['MLP1D'] 6 | 7 | class MLP1D(nn.Module): 8 | def __init__(self, inc, ouc, ks=1, bias=True, activation='relu'): 9 | super(MLP1D, self).__init__() 10 | self.inc = inc 11 | self.ouc = ouc 12 | self.ks = ks 13 | self.bias = bias 14 | self.activation = get_activation(activation) 15 | self.net = nn.Sequential( 16 | nn.Conv1d(in_channels=self.inc, out_channels=self.ouc, kernel_size=self.ks, bias=self.bias), 17 | nn.BatchNorm1d(num_features=self.ouc), 18 | self.activation) 19 | 20 | 21 | def forward(self, x): 22 | output = self.net(x) 23 | return output 24 | 25 | class MLPRes1D(nn.Module): 26 | 27 | def __init__(self, inc, ks=1, bias=True, groups=1, res_expansion=1.0, activation='relu'): 28 | super(MLPRes1D, self).__init__() 29 | self.inc = inc 30 | self.ks = ks 31 | self.bias = bias 32 | self.activation = get_activation(activation) 33 | self.res_expansion = res_expansion 34 | self.groups = groups 35 | 36 | # main branch first part -- the part before group convolution 37 | self.net1 = nn.Sequential( 38 | nn.Conv1d(in_channels=self.inc, out_channels=int(self.inc * self.res_expansion), 39 | kernel_size=self.ks, groups=self.groups, bias=self.bias), 40 | nn.BatchNorm1d(num_features=int(self.inc * self.res_expansion)), 41 | self.activation) 42 | 43 | # main branch second part -- the group convolution part 44 | if self.groups > 1: 45 | self.net2 = nn.Sequential( 46 | nn.Conv1d(in_channels=int(self.inc*self.res_expansion), out_channels=self.inc, 47 | kernel_size=self.ks, groups=self.groups, bias=self.bias), 48 | nn.BatchNorm1d(self.inc), 49 | self.activation, 50 | nn.Conv1d(in_channels=self.inc, out_channels=self.inc, kernel_size=self.ks, 51 | bias=self.bias), 52 | nn.BatchNorm1d(self.inc) 53 | ) 54 | else: 55 | self.net2 = nn.Sequential( 56 | nn.Conv1d(in_channels=int(self.inc*self.res_expansion), out_channels=self.inc, 57 | kernel_size=self.ks, bias=self.bias), 58 | nn.BatchNorm1d(self.inc) 59 | ) 60 | 61 | def forward(self, x): 62 | output_main_branch = self.net2(self.net1(x)) 63 | output = self.activation(output_main_branch + x) 64 | return output -------------------------------------------------------------------------------- /models/cfa/coordconv.py: -------------------------------------------------------------------------------- 1 | import torch.nn.modules.conv as conv 2 | import torch.nn as nn 3 | import torch 4 | 5 | __all__ = ['CoordConv2d', 'AddCoords'] 6 | 7 | class CoordConv2d(conv.Conv2d): 8 | def __init__(self, in_channels, out_channels, kernel_size, device, stride=1, 9 | padding=0, 
dilation=1, groups=1, bias=True, with_r=False): 10 | super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size, 11 | stride, padding, dilation, groups, bias) 12 | self.rank = 2 13 | self.addcoords = AddCoords(self.rank, device, with_r) 14 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels, 15 | kernel_size, stride, padding, dilation, groups, bias) 16 | 17 | def forward(self, input_tensor): 18 | out = self.addcoords(input_tensor) 19 | #self.conv.to(self.device) 20 | out = self.conv(out) 21 | 22 | return out 23 | 24 | class AddCoords(nn.Module): 25 | def __init__(self, rank, device, with_r=False): 26 | super(AddCoords, self).__init__() 27 | self.rank = rank 28 | self.with_r = with_r 29 | self.device = device 30 | 31 | def forward(self, input_tensor): 32 | batch_size_shape, _, dim_y, dim_x = input_tensor.shape 33 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) 34 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) 35 | 36 | xx_range = torch.arange(dim_y, dtype=torch.int32) 37 | yy_range = torch.arange(dim_x, dtype=torch.int32) 38 | xx_range = xx_range[None, None, :, None] 39 | yy_range = yy_range[None, None, :, None] 40 | 41 | xx_channel = torch.matmul(xx_range, xx_ones) 42 | yy_channel = torch.matmul(yy_range, yy_ones) 43 | 44 | yy_channel = yy_channel.permute(0, 1, 3, 2) 45 | 46 | xx_channel = xx_channel.float() / (dim_y - 1) 47 | yy_channel = yy_channel.float() / (dim_x - 1) 48 | 49 | xx_channel = xx_channel * 2 - 1 50 | yy_channel = yy_channel * 2 - 1 51 | 52 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) 53 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) 54 | 55 | if torch.cuda.is_available: 56 | input_tensor = input_tensor.to(self.device) 57 | xx_channel = xx_channel.to(self.device) 58 | yy_channel = yy_channel.to(self.device) 59 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 60 | 61 | if self.with_r: 62 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 63 | out = torch.cat([out, rr], dim=1) 64 | 65 | return out -------------------------------------------------------------------------------- /arch/csflow.py: -------------------------------------------------------------------------------- 1 | from __future__ import nested_scopes 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | import argparse 6 | 7 | from arch.base import ModelBase 8 | from models.net_csflow.net_csflow import NetCSFlow 9 | from optimizer.optimizer import get_optimizer 10 | 11 | __all__ = ['CSFlow'] 12 | 13 | class _CSFlow(nn.Module): 14 | def __init__(self, args, net, optimizer, scheduler): 15 | super(_CSFlow, self).__init__() 16 | self.args = args 17 | self.optimizer = optimizer 18 | self.scheduler = scheduler 19 | self.net = net 20 | self.net.feature_extractor.eval() 21 | 22 | def forward(self, epoch, inputs): 23 | self.optimizer.zero_grad() 24 | embeds, z, log_jac_det = self.net(inputs) 25 | # yy, rev_y, zz = self.net.revward(inputs) 26 | loss = torch.mean(0.5 * torch.sum(z ** 2, dim=(1,)) - log_jac_det) / z.shape[1] 27 | 28 | loss.backward() 29 | self.optimizer.step() 30 | if self.scheduler is not None: 31 | self.scheduler.step(epoch) 32 | 33 | class CSFlow(ModelBase): 34 | def __init__(self, config): 35 | super(CSFlow, self).__init__(config) 36 | self.config = config 37 | args = argparse.Namespace(**self.config) 38 | self.net = NetCSFlow(args) 39 | self.optimizer = get_optimizer(self.config, self.net.density_estimator.parameters()) 40 | self.model = _CSFlow(args, 
self.net, self.optimizer, self.scheduler).to(self.device) 41 | 42 | def train_model(self, train_loader, task_id, inf=''): 43 | self.net.density_estimator.train() 44 | 45 | for epoch in range(self.config['num_epochs']): 46 | for batch_id, batch in enumerate(train_loader): 47 | inputs = batch['img'].to(self.device) 48 | self.model(epoch, inputs) 49 | 50 | def prediction(self, valid_loader, task_id): 51 | self.net.eval() 52 | self.clear_all_list() 53 | 54 | test_z, test_labels = [], [] 55 | with torch.no_grad(): 56 | for batch_id, batch in enumerate(valid_loader): 57 | inputs = batch['img'].to(self.device) 58 | labels = batch['label'].to(self.device) 59 | 60 | _, z, jac = self.net(inputs) 61 | z = z[..., None].cpu().data.numpy() 62 | score = np.mean(z ** 2, axis=(1, 2)) 63 | test_z.append(score) 64 | test_labels.append(labels.cpu().data.numpy()) 65 | self.img_path_list.append(batch['img_src']) 66 | 67 | test_labels = np.concatenate(test_labels) 68 | is_anomaly = np.array([0 if l == 0 else 1 for l in test_labels]) 69 | anomaly_score = np.concatenate(test_z, axis=0) 70 | self.img_gt_list = is_anomaly 71 | self.img_pred_list = anomaly_score -------------------------------------------------------------------------------- /models/cutpaste/density.py: -------------------------------------------------------------------------------- 1 | from sklearn.covariance import LedoitWolf 2 | from sklearn.neighbors import KernelDensity 3 | import torch 4 | 5 | __all__ = ['Density', 'GaussianDensityTorch', 'GaussianDensitySklearn'] 6 | 7 | class Density(object): 8 | def fit(self, embeddings): 9 | raise NotImplementedError 10 | 11 | def predict(self, embeddings): 12 | raise NotImplementedError 13 | 14 | 15 | class GaussianDensityTorch(object): 16 | """Gaussian Density estimation similar to the implementation used by Ripple et al. 17 | The code of Ripple et al. can be found here: https://github.com/ORippler/gaussian-ad-mvtec. 18 | """ 19 | def fit(self, embeddings): 20 | self.mean = torch.mean(embeddings, axis=0) 21 | self.inv_cov = torch.Tensor(LedoitWolf().fit(embeddings.cpu()).precision_,device="cpu") 22 | return self.mean, self.inv_cov 23 | 24 | def predict(self, embeddings): 25 | distances = self.mahalanobis_distance(embeddings, self.mean, self.inv_cov) 26 | return distances 27 | 28 | @staticmethod 29 | def mahalanobis_distance( 30 | values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor 31 | ): 32 | """Compute the batched mahalanobis distance. 33 | values is a batch of feature vectors. 34 | mean is either the mean of the distribution to compare, or a second 35 | batch of feature vectors. 36 | inv_covariance is the inverse covariance of the target distribution. 37 | 38 | from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308 39 | """ 40 | assert values.dim() == 2 41 | assert 1 <= mean.dim() <= 2 42 | assert len(inv_covariance.shape) == 2 43 | assert values.shape[1] == mean.shape[-1] 44 | assert mean.shape[-1] == inv_covariance.shape[0] 45 | assert inv_covariance.shape[0] == inv_covariance.shape[1] 46 | 47 | if mean.dim() == 1: # Distribution mean. 48 | mean = mean.unsqueeze(0) 49 | x_mu = values - mean # batch x features 50 | # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise 51 | dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu) 52 | return dist.sqrt() 53 | 54 | class GaussianDensitySklearn(): 55 | """Li et al. use sklearn for density estimation. 
56 | This implementation uses sklearn KernelDensity module for fitting and predicting. 57 | """ 58 | def fit(self, embeddings): 59 | # estimate KDE parameters 60 | # use grid search cross-validation to optimize the bandwidth 61 | self.kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(embeddings) 62 | 63 | def predict(self, embeddings): 64 | scores = self.kde.score_samples(embeddings) 65 | 66 | # invert scores, so they fit to the class labels for the auc calculation 67 | scores = -scores 68 | 69 | return scores 70 | -------------------------------------------------------------------------------- /models/graphcore/gcn_lib/pos_embed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | __all__ = ['get_2d_relative_pos_embed', 'get_2d_sincos_pos_embed', 4 | 'get_2d_sincos_pos_embed_from_grid', 'get_1d_sincos_pos_embed_from_grid'] 5 | # -------------------------------------------------------- 6 | # relative position embedding 7 | # References: https://arxiv.org/abs/2009.13658 8 | # -------------------------------------------------------- 9 | def get_2d_relative_pos_embed(embed_dim, grid_size): 10 | """ 11 | grid_size: int of the grid height and width 12 | return: 13 | pos_embed: [grid_size*grid_size, grid_size*grid_size] 14 | """ 15 | pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) 16 | relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1] 17 | return relative_pos 18 | 19 | 20 | # -------------------------------------------------------- 21 | # 2D sine-cosine position embedding 22 | # References: 23 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 24 | # MoCo v3: https://github.com/facebookresearch/moco-v3 25 | # -------------------------------------------------------- 26 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 27 | """ 28 | grid_size: int of the grid height and width 29 | return: 30 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 31 | """ 32 | grid_h = np.arange(grid_size, dtype=np.float32) 33 | grid_w = np.arange(grid_size, dtype=np.float32) 34 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 35 | grid = np.stack(grid, axis=0) 36 | 37 | grid = grid.reshape([2, 1, grid_size, grid_size]) 38 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 39 | if cls_token: 40 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 41 | return pos_embed 42 | 43 | 44 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 45 | assert embed_dim % 2 == 0 46 | 47 | # use half of dimensions to encode grid_h 48 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 49 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 50 | 51 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 52 | return emb 53 | 54 | 55 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 56 | """ 57 | embed_dim: output dimension for each position 58 | pos: a list of positions to be encoded: size (M,) 59 | out: (M, D) 60 | """ 61 | assert embed_dim % 2 == 0 62 | omega = np.arange(embed_dim // 2, dtype=np.float) 63 | omega /= embed_dim / 2. 64 | omega = 1. 
/ 10000**omega # (D/2,) 65 | 66 | pos = pos.reshape(-1) # (M,) 67 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 68 | 69 | emb_sin = np.sin(out) # (M, D/2) 70 | emb_cos = np.cos(out) # (M, D/2) 71 | 72 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 73 | return emb -------------------------------------------------------------------------------- /models/resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet18, resnet34 5 | 6 | 7 | __all__ = ['ResNetModel'] 8 | 9 | class ResNetModel(nn.Module): 10 | def __init__(self, pretrained=True, head_layers=[512, 512, 512, 512, 512, 512, 512, 512, 128], num_classes=2): 11 | super(ResNetModel, self).__init__() 12 | # self.resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=pretrained) 13 | self.backbone = resnet18(pretrained=pretrained) 14 | 15 | # create MLP head as seen in the code in: https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 16 | # TODO: check if this is really the right architecture 17 | last_layer = 512 18 | sequential_layers = [] 19 | for num_neurons in head_layers: 20 | sequential_layers.append(nn.Linear(last_layer, num_neurons)) 21 | sequential_layers.append(nn.BatchNorm1d(num_neurons)) 22 | sequential_layers.append(nn.ReLU(inplace=True)) 23 | last_layer = num_neurons 24 | 25 | head = nn.Sequential( 26 | *sequential_layers 27 | ) 28 | self.backbone.fc = nn.Identity() 29 | self.head = nn.Sequential( 30 | head, 31 | nn.Linear(last_layer, num_classes) 32 | ) 33 | 34 | self.feature_extractor = torch.nn.Sequential(*(list(self.backbone.children())[:7])) 35 | self.dim_redu = torch.nn.Sequential(*(list(self.backbone.children())[7:9])) 36 | 37 | def forward_features(self, x): 38 | embeds = self.backbone(x) 39 | return embeds 40 | 41 | def forward(self, x): 42 | embeds = self.forward_features(x) 43 | logits = self.head(embeds) 44 | 45 | dim4_embeds = self.feature_extractor(x) # (64, 256, 14, 14) 46 | tmp_embeds = self.dim_redu(dim4_embeds) # (64, 512, 1, 1) 47 | dim2_embeds = torch.flatten(tmp_embeds, 1) # same value as embeds, (64, 512) 48 | 49 | return embeds, logits 50 | 51 | def freeze_resnet(self): 52 | # freeze full resnet 53 | for param in self.backbone.parameters(): 54 | param.requires_grad = False 55 | # unfreeze head: 56 | for param in self.backbone.fc.parameters(): 57 | param.requires_grad = True 58 | 59 | def unfreeze(self): 60 | # unfreeze all: 61 | for param in self.parameters(): 62 | param.requires_grad = True 63 | 64 | def freeze_parameters(self, train_fc=False): 65 | for p in self.backbone.conv1.parameters(): 66 | p.requires_grad = False 67 | for p in self.backbone.bn1.parameters(): 68 | p.requires_grad = False 69 | for p in self.backbone.layer1.parameters(): 70 | p.requires_grad = False 71 | for p in self.backbone.layer2.parameters(): 72 | p.requires_grad = False 73 | if not train_fc: 74 | for p in self.backbone.fc.parameters(): 75 | p.requires_grad = False 76 | -------------------------------------------------------------------------------- /data_io/semi.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | 5 | __all__ = ['extract_semi_data'] 6 | 7 | def extract_semi_data(train_dataset, valid_dataset, anomaly_num=10, anomaly_overlap=False, upper_ratio=0.75): 8 | valid_sample_nums = [0] + valid_dataset.sample_num_in_task 9 | 
valid_sample_indice = valid_dataset.sample_indices_in_task 10 | 11 | # obtain semi label in validation set 12 | semi_indices = [] 13 | for i in range(valid_dataset.num_task): 14 | anomaly_index = [] 15 | for k, j in enumerate(range(valid_sample_nums[i], valid_sample_nums[i] + valid_sample_nums[i+1])): 16 | z = sum(valid_sample_nums[:i+1]) + k 17 | label = valid_dataset.labels_list[z] 18 | if label == 1: 19 | anomaly_index.append(valid_sample_indice[i][k]) 20 | # set semi data to be less than 50 percent of these in test set 21 | anomaly_num_max = int(len(anomaly_index) * upper_ratio) 22 | if anomaly_num >= anomaly_num_max: 23 | anomaly_num = anomaly_num_max 24 | semi_index = random.sample(anomaly_index, anomaly_num) 25 | 26 | semi_indices.append(semi_index) 27 | 28 | # construct valid_semi_dataset 29 | valid_semi_dataset = copy.deepcopy(valid_dataset) 30 | if not anomaly_overlap: 31 | for task_id in range(valid_dataset.num_task): 32 | valid_semi_dataset.sample_indices_in_task[task_id] = list(set(valid_sample_indice[task_id]) - set(semi_indices[task_id])) 33 | valid_semi_dataset.sample_num_in_task[task_id] = len(valid_semi_dataset.sample_indices_in_task[task_id]) 34 | 35 | # construct semi_dataset 36 | semi_dataset = copy.deepcopy(valid_dataset) 37 | for task_id in range(valid_dataset.num_task): 38 | semi_dataset.sample_indices_in_task[task_id] = semi_indices[task_id] 39 | semi_dataset.sample_num_in_task[task_id] = len(semi_dataset.sample_indices_in_task[task_id]) 40 | 41 | # construct train_semi_dataset 42 | train_semi_dataset = copy.deepcopy(train_dataset) 43 | for task_id in range(valid_dataset.num_task): 44 | for img_id in semi_indices[task_id]: 45 | train_semi_dataset.imgs_list.append(semi_dataset.imgs_list[img_id]) 46 | train_semi_dataset.labels_list.append(semi_dataset.labels_list[img_id]) 47 | train_semi_dataset.masks_list.append(semi_dataset.masks_list[img_id]) 48 | train_semi_dataset.task_ids_list.append(semi_dataset.task_ids_list[img_id]) 49 | 50 | for task_id in range(train_dataset.num_task): 51 | semi_indices = [] 52 | for i in range(semi_dataset.sample_num_in_task[task_id]): 53 | local_idx = i + len(train_dataset.imgs_list) + sum(semi_dataset.sample_num_in_task[:task_id]) 54 | semi_indices.append(int(local_idx)) 55 | 56 | train_semi_dataset.sample_indices_in_task[task_id].extend(semi_indices) 57 | train_semi_dataset.sample_num_in_task[task_id] += semi_dataset.sample_num_in_task[task_id] 58 | 59 | return train_semi_dataset, valid_semi_dataset, semi_dataset -------------------------------------------------------------------------------- /models/reverse/net_reverse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Type, Any, Callable, Union, List, Optional 4 | from models.reverse.blocks import * 5 | from models.reverse.encoder import * 6 | from models.reverse.decoder import * 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | __all__ = ['enc_wide_resnet_50_2', 'dec_wide_resnet_50_2'] 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 19 | 'resnet152': 
'https://download.pytorch.org/models/resnet152-394f9c45.pth', 20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | def _resnet( 27 | arch: str, 28 | block: Type[Union[EncBasicBlock, EncBottleneck, DecBasicBlock, DecBottleneck]], 29 | layers: List[int], 30 | pretrained: bool, 31 | progress: bool, 32 | phase: str, 33 | **kwargs: Any 34 | ): 35 | if phase == 'encode': 36 | model = EncResNet(block, layers, **kwargs) 37 | elif phase == 'decode': 38 | model = DecResNet(block, layers, **kwargs) 39 | if pretrained: 40 | state_dict = load_state_dict_from_url(model_urls[arch], 41 | progress=progress) 42 | #for k,v in list(state_dict.items()): 43 | # if 'layer4' in k or 'fc' in k: 44 | # state_dict.pop(k) 45 | model.load_state_dict(state_dict) 46 | return model 47 | 48 | def enc_wide_resnet_50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any): 49 | kwargs['width_per_group'] = 64 * 2 50 | return _resnet('wide_resnet50_2', EncBottleneck, [3, 4, 6, 3], 51 | pretrained, progress, phase='encode', **kwargs) 52 | 53 | def dec_wide_resnet_50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any): 54 | kwargs['width_per_group'] = 64 * 2 55 | return _resnet('wide_resnet50_2', DecBottleneck, [3, 4, 6, 3], 56 | pretrained, progress, phase='decode', **kwargs) 57 | 58 | def bn_layer(**kwargs): 59 | return BNLayer(AttnBottleneck, 3, **kwargs) 60 | 61 | class NetReverse(nn.Module): 62 | def __init__(self, args): 63 | super(NetReverse, self).__init__() 64 | self.args = args 65 | 66 | self.encoder = enc_wide_resnet_50_2(pretrained=True) 67 | self.decoder = dec_wide_resnet_50_2(pretrained=False) 68 | self.bn = bn_layer() 69 | 70 | 71 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import torch 6 | import random 7 | import yaml 8 | import time 9 | import shutil 10 | import torchvision 11 | 12 | __all__ = ['seed_everything', 'parse_device_list', 'merge_config', 'override_config', 'extract_config', 'create_folders', 13 | 'record_path', 'save_arg', 'save_log', 'save_script', 'save_image'] 14 | 15 | 16 | def seed_everything(seed): 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | # some cudnn methods can be random even after fixing the seed 23 | # unless you tell it to be deterministic 24 | torch.backends.cudnn.deterministic = True 25 | 26 | def parse_device_list(device_ids_string, id_choice=None): 27 | device_ids = [int(i) for i in device_ids_string[0]] 28 | id_choice = 0 if id_choice is None else id_choice 29 | device = device_ids[id_choice] 30 | device = torch.device("cuda", device) 31 | return device, device_ids 32 | 33 | def override_config(previous, new): 34 | config = previous 35 | for new_key in new.keys(): 36 | config[new_key] = new[new_key] 37 | 38 | return config 39 | 40 | def merge_config(config, args): 41 | """ 42 | args overlaps config, the args is given a high priority 43 | """ 44 | for key_arg in dir(args): 45 | value = getattr(args, key_arg) 
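        # (Editor's illustration, not part of the original file.) merge_config gives
        # command-line args priority over the YAML config: for example, with
        # config = {'lr': 0.1, 'seed': 42} and an args namespace where lr=0.01 and
        # seed is None, this loop overrides config['lr'] = 0.01 and leaves
        # config['seed'] = 42 untouched; the is_int check below keeps integer
        # values such as 0 counting as "given" even though they are falsy.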
46 | is_int = (type(value)==int) 47 | if (getattr(args, key_arg) or is_int) and (key_arg in config.keys()): 48 | config[key_arg] = getattr(args, key_arg) 49 | 50 | return config 51 | 52 | def extract_config(args): 53 | config = dict() 54 | for key_arg in vars(args): 55 | value = getattr(args, key_arg) 56 | is_int = (type(value)==int) 57 | if vars(args)[key_arg] or is_int: 58 | config[key_arg] = vars(args)[key_arg] 59 | 60 | return config 61 | 62 | def record_path(para_dict): 63 | # mkdir ./work_dir/fed/brats/time-dir 64 | localtime = time.asctime(time.localtime(time.time())) 65 | file_path = '{}/{}/{}/{}'.format( 66 | para_dict['work_dir'], para_dict['learning_mode'], para_dict['dataset'], localtime) 67 | 68 | os.makedirs(file_path) 69 | 70 | return file_path 71 | 72 | def save_arg(para_dict, file_path): 73 | with open('{}/config.yaml'.format(file_path), 'w') as f: 74 | yaml.dump(para_dict, f) 75 | 76 | def save_log(infor, file_path, description=None): 77 | localtime = time.asctime(time.localtime(time.time())) 78 | infor = '[{}] {}'.format(localtime, infor) 79 | 80 | with open('{}/log{}.txt'.format(file_path, description), 'a') as f: 81 | print(infor, file=f) 82 | 83 | def save_script(src_file, file_path): 84 | shutil.copy2(src_file, file_path) 85 | 86 | def save_image(image, name, image_path): 87 | if not os.path.exists(image_path): 88 | os.makedirs(image_path) 89 | 90 | torchvision.utils.save_image(image, '{}/{}'.format(image_path, name), normalize=False) 91 | 92 | def create_folders(tag_path): 93 | if not os.path.exists(tag_path): 94 | os.makedirs(tag_path) 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /data_io/noisy.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | 5 | __all__ = ['extract_noisy_data'] 6 | 7 | def extract_noisy_data(train_dataset, valid_dataset, noisy_ratio=0.1, noisy_overlap=False, upper_ratio=0.75): 8 | valid_sample_nums = [0] + valid_dataset.sample_num_in_task 9 | valid_sample_indice = valid_dataset.sample_indices_in_task 10 | 11 | # obtain noisy label in validation set 12 | noisy_indices = [] 13 | for task_id in range(valid_dataset.num_task): 14 | anomaly_index = [] 15 | for k, j in enumerate(range(valid_sample_nums[task_id], valid_sample_nums[task_id] + valid_sample_nums[task_id+1])): 16 | z = sum(valid_sample_nums[:task_id+1]) + k 17 | label = valid_dataset.labels_list[z] 18 | if label == 1: 19 | anomaly_index.append(valid_sample_indice[task_id][k]) 20 | # set noisy data to be less than 75 percent of these in test set 21 | noise_num = int(noisy_ratio * train_dataset.sample_num_in_task[task_id]) 22 | anomaly_num_max = int(len(anomaly_index) * upper_ratio) 23 | if noise_num >= anomaly_num_max: 24 | noise_num = anomaly_num_max 25 | noise_index = random.sample(anomaly_index, noise_num) 26 | 27 | noisy_indices.append(noise_index) 28 | 29 | # construct valid_noisy_dataset 30 | valid_noisy_dataset = copy.deepcopy(valid_dataset) 31 | if not noisy_overlap: 32 | for task_id in range(valid_dataset.num_task): 33 | valid_noisy_dataset.sample_indices_in_task[task_id] = list(set(valid_sample_indice[task_id]) - set(noisy_indices[task_id])) 34 | valid_noisy_dataset.sample_num_in_task[task_id] = len(valid_noisy_dataset.sample_indices_in_task[task_id]) 35 | 36 | # construct noisy_dataset 37 | noisy_dataset = copy.deepcopy(valid_dataset) 38 | for task_id in range(valid_dataset.num_task): 39 | noisy_dataset.sample_indices_in_task[task_id] = 
noisy_indices[task_id] 40 | noisy_dataset.sample_num_in_task[task_id] = len(noisy_dataset.sample_indices_in_task[task_id]) 41 | 42 | # construct train_noisy_dataset 43 | train_noisy_dataset = copy.deepcopy(train_dataset) 44 | for task_id in range(valid_dataset.num_task): 45 | for img_id in noisy_indices[task_id]: 46 | train_noisy_dataset.imgs_list.append(noisy_dataset.imgs_list[img_id]) 47 | train_noisy_dataset.labels_list.append(noisy_dataset.labels_list[img_id]) 48 | train_noisy_dataset.masks_list.append(noisy_dataset.masks_list[img_id]) 49 | train_noisy_dataset.task_ids_list.append(noisy_dataset.task_ids_list[img_id]) 50 | 51 | for task_id in range(train_dataset.num_task): 52 | noisy_indices = [] 53 | for i in range(noisy_dataset.sample_num_in_task[task_id]): 54 | local_idx = i + len(train_dataset.imgs_list) + sum(noisy_dataset.sample_num_in_task[:task_id]) 55 | noisy_indices.append(int(local_idx)) 56 | 57 | train_noisy_dataset.sample_indices_in_task[task_id].extend(noisy_indices) 58 | train_noisy_dataset.sample_num_in_task[task_id] += noisy_dataset.sample_num_in_task[task_id] 59 | 60 | return train_noisy_dataset, valid_noisy_dataset, noisy_dataset 61 | -------------------------------------------------------------------------------- /loss_function/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import exp 5 | 6 | 7 | __all__ = ['SSIMLoss'] 8 | 9 | class SSIMLoss(nn.Module): 10 | def __init__(self, window_size=11, size_average=True, val_range=None): 11 | super(SSIMLoss, self).__init__() 12 | self.window_size = window_size 13 | self.size_average = size_average 14 | self.val_range = val_range 15 | 16 | # Assume 1 channel for SSIM 17 | self.channel = 1 18 | self.window = create_window(window_size).cuda() 19 | 20 | def forward(self, img1, img2): 21 | (_, channel, _, _) = img1.size() 22 | 23 | if channel == self.channel and self.window.dtype == img1.dtype: 24 | window = self.window 25 | else: 26 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 27 | self.window = window 28 | self.channel = channel 29 | 30 | s_score, ssim_map = ssim(img1, img2, window=window, window_size=self.window_size, 31 | size_average=self.size_average) 32 | return 1.0 - s_score 33 | 34 | def gaussian(window_size, sigma): 35 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 36 | return gauss/gauss.sum() 37 | 38 | def create_window(window_size, channel=1): 39 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 40 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 41 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 42 | return window 43 | 44 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 45 | if val_range is None: 46 | if torch.max(img1) > 128: 47 | max_val = 255 48 | else: 49 | max_val = 1 50 | 51 | if torch.min(img1) < -0.5: 52 | min_val = -1 53 | else: 54 | min_val = 0 55 | l = max_val - min_val 56 | else: 57 | l = val_range 58 | 59 | padd = window_size//2 60 | (_, channel, height, width) = img1.size() 61 | if window is None: 62 | real_size = min(window_size, height, width) 63 | window = create_window(real_size, channel=channel).to(img1.device) 64 | 65 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 66 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 67 | 68 | 
mu1_sq = mu1.pow(2) 69 | mu2_sq = mu2.pow(2) 70 | mu1_mu2 = mu1 * mu2 71 | 72 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 73 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 74 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 75 | 76 | c1 = (0.01 * l) ** 2 77 | c2 = (0.03 * l) ** 2 78 | 79 | v1 = 2.0 * sigma12 + c2 80 | v2 = sigma1_sq + sigma2_sq + c2 81 | cs = torch.mean(v1 / v2) # contrast sensitivity 82 | 83 | ssim_map = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2) 84 | 85 | if size_average: 86 | ret = ssim_map.mean() 87 | else: 88 | ret = ssim_map.mean(1).mean(1).mean(1) 89 | 90 | if full: 91 | return ret, cs 92 | return ret, ssim_map 93 | 94 | -------------------------------------------------------------------------------- /arch/simplenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from arch.base import ModelBase 3 | from models.simplenet.simplenet import SimpleNet as _SimpleNet 4 | from torchvision import models 5 | 6 | __all__ = ['SimpleNet'] 7 | 8 | class SimpleNet(ModelBase): 9 | def __init__(self, config): 10 | super(SimpleNet, self).__init__(config) 11 | self.config = config 12 | 13 | if self.config['net'] == 'wide_resnet50': 14 | self.net = models.wide_resnet50_2(pretrained=True, progress=True).to(self.device) 15 | 16 | self.simplenet = _SimpleNet(self.device) 17 | self.simplenet.load( 18 | backbone=self.net, 19 | layers_to_extract_from=self.config['_layers_to_extract_from'], 20 | device=self.device, 21 | input_shape=(3, self.config['data_crop_size'], self.config['data_crop_size']), 22 | pretrain_embed_dimension=self.config['_pretrain_embed_dimension'], 23 | target_embed_dimension=self.config['_target_embed_dimension'], 24 | patchsize=self.config['_patchsize'], 25 | embedding_size=self.config['_embedding_size'], 26 | meta_epochs=self.config['num_epochs'], 27 | aed_meta_epochs=self.config['_aed_meta_epochs'], 28 | gan_epochs=self.config['_gan_epochs'], 29 | noise_std=self.config['_noise_std'], 30 | dsc_layers=self.config['_dsc_layers'], 31 | dsc_hidden=self.config['_dsc_hidden'], 32 | dsc_margin=self.config['_dsc_margin'], 33 | dsc_lr=self.config['_dsc_lr'], 34 | auto_noise=self.config['_auto_noise'], 35 | train_backbone=self.config['_train_backbone'], 36 | cos_lr=self.config['_cos_lr'], 37 | pre_proj=self.config['_pre_proj'], 38 | proj_layer_type=self.config['_proj_layer_type'], 39 | mix_noise=self.config['_mix_noise'], 40 | ) 41 | 42 | def train_model(self, train_loader, task_id, inf=''): 43 | self.simplenet.train_discriminator(train_loader) 44 | 45 | def prediction(self, valid_loader, task_id): 46 | self.clear_all_list() 47 | 48 | scores, segmentations, labels_gt, masks_gt, img_srcs = self.simplenet.predict(valid_loader) 49 | 50 | scores = np.array(scores) 51 | min_scores = scores.min(axis=-1).reshape(-1, 1) 52 | max_scores = scores.max(axis=-1).reshape(-1, 1) 53 | scores = (scores - min_scores) / (max_scores - min_scores) 54 | scores = np.mean(scores, axis=0) 55 | 56 | segmentations = np.array(segmentations) 57 | min_scores = segmentations.reshape(len(segmentations), -1).min(axis=-1).reshape(-1, 1, 1, 1) 58 | max_scores = segmentations.reshape(len(segmentations), -1).max(axis=-1).reshape(-1, 1, 1, 1) 59 | segmentations = (segmentations - min_scores) / (max_scores - min_scores) 60 | segmentations = np.mean(segmentations, axis=0) 61 | segmentations[segmentations >= 0.5] = 1 62 | 
segmentations[segmentations < 0.5] = 0 63 | segmentations = np.array(segmentations, dtype='uint8') 64 | masks_gt = np.array(masks_gt).squeeze().astype(int) 65 | 66 | self.pixel_gt_list = [mask for mask in masks_gt] 67 | self.pixel_pred_list = [seg for seg in segmentations] 68 | self.img_gt_list = labels_gt 69 | self.img_pred_list = scores 70 | self.img_path_list = img_srcs -------------------------------------------------------------------------------- /arch/cfa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from arch.base import ModelBase 3 | from models.cfa.net_cfa import NetCFA 4 | from models.cfa.cfa import DSVDD 5 | import torch.nn.functional as F 6 | from scipy.ndimage import gaussian_filter 7 | 8 | __all__ = ['CFA'] 9 | 10 | class CFA(ModelBase): 11 | def __init__(self, config): 12 | super(CFA, self).__init__(config) 13 | self.config = config 14 | self.net = NetCFA(self.config).resnet18.to(self.device) 15 | 16 | @staticmethod 17 | def upsample(x, size, mode): 18 | return F.interpolate(x.unsqueeze(1), size=size, mode=mode, align_corners=False).squeeze().numpy() 19 | 20 | @staticmethod 21 | def gaussian_smooth(x, sigma=4): 22 | bs = x.shape[0] 23 | for i in range(0, bs): 24 | x[i] = gaussian_filter(x[i], sigma=sigma) 25 | 26 | return x 27 | 28 | @staticmethod 29 | def rescale(x): 30 | return (x - x.min()) / (x.max() - x.min()) 31 | 32 | def train_model(self, train_loader, task_id, inf=''): 33 | self.net.eval() 34 | 35 | self.loss_fn = DSVDD(model=self.net, data_loader=train_loader, 36 | cnn='resnet18', gamma_c=self.config['gamma_c'], 37 | gamma_d=self.config['gamma_d'], device=self.device) 38 | self.loss_fn = self.loss_fn.to(self.device) 39 | self.loss_fn.train() 40 | 41 | params = [{'params' : self.loss_fn.parameters()},] 42 | optimizer = torch.optim.AdamW(params=params, lr=1e-3, weight_decay=5e-4, 43 | amsgrad=True) 44 | 45 | for epoch in range(self.config['num_epochs']): 46 | for batch_id, batch in enumerate(train_loader): 47 | optimizer.zero_grad() 48 | img = batch['img'].to(self.device) 49 | p = self.net(img) 50 | 51 | loss, _ = self.loss_fn(p) 52 | loss.backward() 53 | optimizer.step() 54 | 55 | def prediction(self, valid_loader, task_id=None): 56 | self.loss_fn.eval() 57 | self.clear_all_list() 58 | heatmaps = None 59 | 60 | with torch.no_grad(): 61 | for batch_id, batch in enumerate(valid_loader): 62 | img = batch['img'].to(self.device) 63 | label = batch['label'] 64 | mask = batch['mask'].to(self.device) 65 | mask[mask >= 0.5] = 1 66 | mask[mask < 0.5] = 0 67 | 68 | self.img_gt_list.append(label.cpu().detach().numpy()) 69 | self.pixel_gt_list.append(mask.cpu().detach().numpy()[0,0,:,:]) 70 | self.img_path_list.append(batch['img_src']) 71 | 72 | p = self.net(img) 73 | 74 | _, score = self.loss_fn(p) 75 | heatmap = score.cpu().detach() 76 | heatmap = torch.mean(heatmap, dim=1) 77 | heatmaps = torch.cat((heatmaps, heatmap), dim=0) if heatmaps != None else heatmap 78 | 79 | heatmaps = self.upsample(heatmaps, size=img.size(2), mode='bilinear') 80 | heatmaps = self.gaussian_smooth(heatmaps, sigma=4) 81 | 82 | scores = self.rescale(heatmaps) 83 | for i in range(scores.shape[0]): 84 | self.pixel_pred_list.append(scores[i]) 85 | img_scores = scores.reshape(scores.shape[0], -1).max(axis=1) 86 | self.img_pred_list = img_scores -------------------------------------------------------------------------------- /loss_function/focal.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import torch.nn as nn 3 | 4 | 5 | __all__ = ['FocalLoss'] 6 | 7 | class FocalLoss(nn.Module): 8 | """ 9 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 10 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 11 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 12 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 13 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 14 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 15 | focus on hard misclassified example 16 | :param smooth: (float,double) smooth value when cross entropy 17 | :param balance_index: (int) balance class index, should be specific when alpha is float 18 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 19 | """ 20 | 21 | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 22 | super(FocalLoss, self).__init__() 23 | self.apply_nonlin = apply_nonlin 24 | self.alpha = alpha 25 | self.gamma = gamma 26 | self.balance_index = balance_index 27 | self.smooth = smooth 28 | self.size_average = size_average 29 | 30 | if self.smooth is not None: 31 | if self.smooth < 0 or self.smooth > 1.0: 32 | raise ValueError('smooth value should be in [0,1]') 33 | 34 | def forward(self, logit, target): 35 | if self.apply_nonlin is not None: 36 | logit = self.apply_nonlin(logit) 37 | num_class = logit.shape[1] 38 | 39 | if logit.dim() > 2: 40 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 41 | logit = logit.view(logit.size(0), logit.size(1), -1) 42 | logit = logit.permute(0, 2, 1).contiguous() 43 | logit = logit.view(-1, logit.size(-1)) 44 | target = torch.squeeze(target, 1) 45 | target = target.view(-1, 1) 46 | alpha = self.alpha 47 | 48 | if alpha is None: 49 | alpha = torch.ones(num_class, 1) 50 | elif isinstance(alpha, (list, np.ndarray)): 51 | assert len(alpha) == num_class 52 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 53 | alpha = alpha / alpha.sum() 54 | elif isinstance(alpha, float): 55 | alpha = torch.ones(num_class, 1) 56 | alpha = alpha * (1 - self.alpha) 57 | alpha[self.balance_index] = self.alpha 58 | 59 | else: 60 | raise TypeError('Not support alpha type') 61 | 62 | if alpha.device != logit.device: 63 | alpha = alpha.to(logit.device) 64 | 65 | idx = target.cpu().long() 66 | 67 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 68 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 69 | if one_hot_key.device != logit.device: 70 | one_hot_key = one_hot_key.to(logit.device) 71 | 72 | if self.smooth: 73 | one_hot_key = torch.clamp( 74 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 75 | pt = (one_hot_key * logit).sum(1) + self.smooth 76 | logpt = pt.log() 77 | 78 | gamma = self.gamma 79 | 80 | alpha = alpha[idx] 81 | alpha = torch.squeeze(alpha) 82 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 83 | 84 | if self.size_average: 85 | loss = loss.mean() 86 | return loss -------------------------------------------------------------------------------- /arch/patchcore.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from arch.base import ModelBase 3 | from models.patchcore.patchcore import PatchCore as patchcore_official 4 | from models.patchcore import common 5 | from models.patchcore import sampler 6 | from torchvision import 
models 7 | 8 | __all__ = ['PatchCore'] 9 | 10 | class PatchCore(ModelBase): 11 | def __init__(self, config): 12 | super(PatchCore, self).__init__(config) 13 | self.config = config 14 | 15 | if self.config['net'] == 'resnet18': 16 | self.net = models.resnet18(pretrained=True, progress=True).to(self.device) 17 | if self.config['net'] == 'wide_resnet50': 18 | self.net = models.wide_resnet50_2(pretrained=True, progress=True).to(self.device) 19 | 20 | self.sampler = self.get_sampler(self.config['_sampler_name']) 21 | self.nn_method = common.FaissNN(self.config['_faiss_on_gpu'], self.config['_faiss_num_workers']) 22 | 23 | self.patchcore_instance = patchcore_official(self.device) 24 | self.patchcore_instance.load( 25 | backbone=self.net, 26 | layers_to_extract_from=self.config['_layers_to_extract_from'], 27 | device=self.device, 28 | input_shape=self.config['_input_shape'], 29 | pretrain_embed_dimension=self.config['_pretrain_embed_dimension'], 30 | target_embed_dimension=self.config['_target_embed_dimension'], 31 | patchsize=self.config['_patch_size'], 32 | featuresampler=self.sampler, 33 | anomaly_scorer_num_nn=self.config['_anomaly_scorer_num_nn'], 34 | nn_method=self.nn_method, 35 | ) 36 | 37 | def get_sampler(self, name): 38 | if name == 'identity': 39 | return sampler.IdentitySampler() 40 | elif name == 'greedy_coreset': 41 | return sampler.GreedyCoresetSampler(self.config['sampler_percentage'], self.device) 42 | elif name == 'approx_greedy_coreset': 43 | return sampler.ApproximateGreedyCoresetSampler(self.config['sampler_percentage'], self.device) 44 | else: 45 | raise ValueError('No This Sampler: {}'.format(name)) 46 | 47 | def train_model(self, train_loader, task_id, inf=''): 48 | self.patchcore_instance.eval() 49 | self.patchcore_instance.fit(train_loader) 50 | 51 | def prediction(self, valid_loader, task_id=None): 52 | self.patchcore_instance.eval() 53 | self.clear_all_list() 54 | 55 | scores, segmentations, labels_gt, masks_gt, img_srcs = self.patchcore_instance.predict(valid_loader) 56 | 57 | scores = np.array(scores) 58 | min_scores = scores.min(axis=-1).reshape(-1, 1) 59 | max_scores = scores.max(axis=-1).reshape(-1, 1) 60 | scores = (scores - min_scores) / (max_scores - min_scores) 61 | scores = np.mean(scores, axis=0) 62 | 63 | segmentations = np.array(segmentations) 64 | min_scores = segmentations.reshape(len(segmentations), -1).min(axis=-1).reshape(-1, 1, 1, 1) 65 | max_scores = segmentations.reshape(len(segmentations), -1).max(axis=-1).reshape(-1, 1, 1, 1) 66 | segmentations = (segmentations - min_scores) / (max_scores - min_scores) 67 | segmentations = np.mean(segmentations, axis=0) 68 | segmentations[segmentations >= 0.5] = 1 69 | segmentations[segmentations < 0.5] = 0 70 | segmentations = np.array(segmentations, dtype='uint8') 71 | masks_gt = np.array(masks_gt).squeeze().astype(int) 72 | 73 | self.pixel_gt_list = [mask for mask in masks_gt] 74 | self.pixel_pred_list = [seg for seg in segmentations] 75 | self.img_gt_list = labels_gt 76 | self.img_pred_list = scores 77 | self.img_path_list = img_srcs 78 | -------------------------------------------------------------------------------- /tools/record_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from tools.visualize import save_anomaly_map 4 | from configuration.registration import setting_name 5 | from rich import print 6 | 7 | __all__ = ['RecordHelper'] 8 | 9 | class RecordHelper(): 10 | def __init__(self, config): 11 | self.config = config 12 | 
13 | def update(self, config): 14 | self.config = config 15 | 16 | def printer(self, info): 17 | print(info) 18 | 19 | def paradigm_name(self): 20 | for s in setting_name: 21 | if self.config[s]: 22 | return s 23 | 24 | print('Add new setting in record_helper.py!') 25 | return 'unknown' 26 | 27 | def record_result(self, result): 28 | paradim = self.paradigm_name() 29 | save_dir = '{}/benchmark/{}/{}/{}/task_{}'.format(self.config['work_dir'], paradim, self.config['dataset'], 30 | self.config['model'], self.config['train_task_id_tmp']) 31 | if not os.path.exists(save_dir): 32 | os.makedirs(save_dir) 33 | 34 | save_path = save_dir + '/result.txt' 35 | if paradim == 'vanilla': 36 | save_path = save_path 37 | if paradim == 'semi': 38 | save_path = '{}/result_{}_num.txt'.format(save_dir, self.config['semi_anomaly_num']) 39 | if paradim == 'fewshot': 40 | save_path = '{}/result_{}_{}_shot.txt'.format(save_dir, ''.join(self.config['fewshot_aug_type']), self.config['fewshot_exm']) 41 | if paradim == 'continual': 42 | save_path = '{}/result_{}_task.txt'.format(save_dir, self.config['valid_task_id_tmp']) 43 | if paradim == 'noisy': 44 | save_path = '{}/result_{}_ratio.txt'.format(save_dir, self.config['noisy_ratio']) 45 | if paradim == 'transfer': 46 | save_path = '{}/result_from_{}_to_{}.txt'.format(save_dir, self.config['train_task_id'][0], self.config['valid_task_id'][0]) 47 | 48 | with open(save_path, 'a') as f: 49 | print(result, file=f) 50 | 51 | def record_images(self, img_pred_list, img_gt_list, pixel_pred_list, pixel_gt_list, img_path_list): 52 | paradim = self.paradigm_name() 53 | save_dir = '{}/benchmark/{}/{}/{}/task_{}'.format(self.config['work_dir'], paradim, self.config['dataset'], 54 | self.config['model'], self.config['train_task_id_tmp']) 55 | 56 | if paradim == 'vanilla': 57 | save_dir = save_dir + '/vis' 58 | if paradim == 'semi': 59 | save_dir = '{}/vis_{}_num'.format(save_dir, self.config['semi_anomaly_num']) 60 | if paradim == 'fewshot': 61 | save_dir = '{}/vis_{}_{}_shot'.format(save_dir, ''.join(self.config['fewshot_aug_type']), self.config['fewshot_exm']) 62 | if paradim == 'continual': 63 | save_dir = '{}/vis_{}_task'.format(save_dir, self.config['valid_task_id_tmp']) 64 | if paradim == 'noisy': 65 | save_dir = '{}/vis_{}_ratio'.format(save_dir, self.config['noisy_ratio']) 66 | if paradim == 'transfer': 67 | save_dir = '{}/vis_from_{}_to_{}'.format(save_dir, self.config['train_task_id'][0], self.config['valid_task_id'][0]) 68 | 69 | if not os.path.exists(save_dir): 70 | os.makedirs(save_dir) 71 | 72 | for i in range(len(img_path_list)): 73 | img_src = cv2.imread(img_path_list[i][0]) 74 | img_src = cv2.resize(img_src, pixel_pred_list[0].shape) 75 | path_dir = img_path_list[i][0].split('/') 76 | save_path = '{}/{}_{}'.format(save_dir, path_dir[-2], path_dir[-1][:-4]) 77 | 78 | save_anomaly_map(anomaly_map=pixel_pred_list[i], input_img=img_src, mask=pixel_gt_list[i], file_path=save_path) 79 | -------------------------------------------------------------------------------- /arch/cutpaste.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import argparse 5 | 6 | from arch.base import ModelBase 7 | from models.vit.vit import ViT 8 | from models.cutpaste.density import GaussianDensityTorch 9 | from optimizer.optimizer import get_optimizer 10 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 11 | 12 | __all__ = ['CutPaste'] 13 | 14 | class 
_CutPaste(nn.Module): 15 | def __init__(self, args, net, optimizer, scheduler): 16 | super(_CutPaste, self).__init__() 17 | self.args = args 18 | self.optimizer = optimizer 19 | self.scheduler = scheduler 20 | self.net = net 21 | self.softmax = nn.Softmax(dim=1) 22 | self.cross_entropy = nn.CrossEntropyLoss() 23 | 24 | def forward(self, epoch, inputs, labels, one_epoch_embeds, *args): 25 | num = int(len(inputs) / 2) 26 | self.optimizer.zero_grad() 27 | embeds, outs = self.net(inputs) 28 | one_epoch_embeds.append(embeds[:num].detach().cpu()) 29 | loss = self.cross_entropy(self.softmax(outs), labels.long()) 30 | loss.backward() 31 | self.optimizer.step() 32 | if self.scheduler is not None: 33 | self.scheduler.step(epoch) 34 | 35 | def training_epoch(self, density, one_epoch_embeds, *args): 36 | one_epoch_embeds = torch.cat(one_epoch_embeds) 37 | one_epoch_embeds = F.normalize(one_epoch_embeds, p=2, dim=1) 38 | _, _ = density.fit(one_epoch_embeds) 39 | return density 40 | 41 | class CutPaste(ModelBase): 42 | def __init__(self, config): 43 | super(CutPaste, self).__init__(config) 44 | self.config = config 45 | args = argparse.Namespace(**self.config) 46 | self.net = ViT(num_classes=self.config['_num_classes'], pretrained=self.config['_pretrained'], checkpoint_path='./checkpoints/vit/vit_b_16.npz') 47 | self.optimizer = get_optimizer(self.config, self.net.parameters()) 48 | self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, self.config['num_epochs']) 49 | 50 | self.model = _CutPaste(args, self.net, self.optimizer, self.scheduler).to(self.device) 51 | self.density = GaussianDensityTorch() 52 | self.one_epoch_embeds = [] 53 | 54 | def train_model(self, train_loader, task_id, inf=''): 55 | for epoch in range(self.config['num_epochs']): 56 | for batch_id, batch in enumerate(train_loader): 57 | self.net.train() 58 | imgs = batch['img'] 59 | inputs = [img.to(self.device) for img in imgs] 60 | labels = torch.arange(len(inputs), device=self.device) 61 | labels = labels.repeat_interleave(inputs[0].size(0)) 62 | inputs = torch.cat(inputs, dim=0) 63 | self.model(epoch, inputs, labels, self.one_epoch_embeds) 64 | 65 | def prediction(self, valid_loader, task_id=None): 66 | self.net.eval() 67 | density = self.model.training_epoch(self.density, self.one_epoch_embeds) 68 | labels = [] 69 | embeds = [] 70 | 71 | with torch.no_grad(): 72 | for batch_id, batch in enumerate(valid_loader): 73 | input = batch['img'].to(self.device) 74 | label = batch['label'].to(self.device) 75 | self.img_path_list.append(batch['img_src']) 76 | 77 | embed = self.net.forward_features(input) 78 | embeds.append(embed.cpu()) 79 | labels.append(label.cpu()) 80 | 81 | labels = torch.cat(labels) 82 | embeds = torch.cat(embeds) 83 | embeds = F.normalize(embeds, p=2, dim=1) 84 | 85 | distances = density.predict(embeds) 86 | self.img_gt_list = labels 87 | self.img_pred_list = distances 88 | -------------------------------------------------------------------------------- /arch/reverse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | from torch.nn import functional as F 5 | from arch.base import ModelBase 6 | from scipy.ndimage import gaussian_filter 7 | from loss_function.reverse import reverse_loss 8 | from models.reverse.net_reverse import NetReverse 9 | from optimizer.optimizer import get_optimizer 10 | 11 | __all__ = ['REVERSE'] 12 | 13 | class REVERSE(ModelBase): 14 | def __init__(self, config): 15 | super(REVERSE, 
self).__init__(config) 16 | self.config = config 17 | 18 | args = argparse.Namespace(**self.config) 19 | self.net = NetReverse(args) 20 | self.optimizer = get_optimizer(self.config, list(self.net.decoder.parameters()) + list(self.net.bn.parameters())) 21 | self.encoder = self.net.encoder.to(self.device) 22 | self.decoder = self.net.decoder.to(self.device) 23 | self.bn = self.net.bn.to(self.device) 24 | 25 | def train_model(self, train_loader, task_id, inf=''): 26 | self.encoder.eval() 27 | self.bn.train() 28 | self.decoder.train() 29 | loss_list = [] 30 | 31 | for epoch in range(self.config['num_epochs']): 32 | for batch_id, batch in enumerate(train_loader): 33 | img = batch['img'].to(self.device) 34 | 35 | inputs = self.encoder(img) 36 | outputs = self.decoder(self.bn(inputs)) 37 | loss = reverse_loss(inputs, outputs) 38 | loss_list.append(loss.item()) 39 | 40 | self.optimizer.zero_grad() 41 | loss.backward() 42 | self.optimizer.step() 43 | 44 | def prediction(self, valid_loader, task_id): 45 | self.encoder.eval() 46 | self.bn.eval() 47 | self.decoder.eval() 48 | self.clear_all_list() 49 | 50 | self.pixel_gt_list = [] 51 | self.pixel_pred_list = [] 52 | self.img_gt_list = [] 53 | self.img_pred_list = [] 54 | 55 | with torch.no_grad(): 56 | for batch_id, batch in enumerate(valid_loader): 57 | img = batch['img'].to(self.device) 58 | mask = batch['mask'] 59 | label = batch['label'] 60 | inputs = self.encoder(img) 61 | outputs = self.decoder(self.bn(inputs)) 62 | 63 | anomaly_map, _ = self.cal_anomaly_map(inputs, outputs, img.shape[-1], amap_mode='a') 64 | anomaly_map = gaussian_filter(anomaly_map, sigma=4) 65 | 66 | mask[mask >= 0.5] = 1 67 | mask[mask < 0.5] = 0 68 | 69 | self.pixel_gt_list.append(mask.cpu().numpy()[0,0].astype(int)) 70 | self.pixel_pred_list.append(anomaly_map) 71 | self.img_gt_list.append(label.numpy()[0]) 72 | self.img_pred_list.append(np.max(anomaly_map)) 73 | self.img_path_list.append(batch['img_src']) 74 | 75 | def cal_anomaly_map(self, fs_list, ft_list, out_size=256, amap_mode='full'): 76 | if amap_mode == 'mul': 77 | anomaly_map = np.ones([out_size, out_size]) 78 | else: 79 | anomaly_map = np.zeros([out_size, out_size]) 80 | 81 | a_map_list = [] 82 | for i in range(len(ft_list)): 83 | fs = fs_list[i] 84 | ft = ft_list[i] 85 | #fs_norm = F.normalize(fs, p=2) 86 | #ft_norm = F.normalize(ft, p=2) 87 | a_map = 1 - F.cosine_similarity(fs, ft) 88 | a_map = torch.unsqueeze(a_map, dim=1) 89 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True) 90 | a_map = a_map[0, 0, :, :].to('cpu').detach().numpy() 91 | a_map_list.append(a_map) 92 | if amap_mode == 'mul': 93 | anomaly_map *= a_map 94 | else: 95 | anomaly_map += a_map 96 | return anomaly_map, a_map_list 97 | -------------------------------------------------------------------------------- /models/graphcore/gcn_lib/torch_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Sequential as Seq, Linear as Lin, Conv2d 4 | 5 | __all__ = ['act_layer', 'norm_layer', 'MLP', 'BasicConv'] 6 | ############################## 7 | # Basic layers 8 | ############################## 9 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 10 | # activation layer 11 | 12 | act = act.lower() 13 | if act == 'relu': 14 | layer = nn.ReLU(inplace) 15 | elif act == 'leakyrelu': 16 | layer = nn.LeakyReLU(neg_slope, inplace) 17 | elif act == 'prelu': 18 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 19 
| elif act == 'gelu': 20 | layer = nn.GELU() 21 | elif act == 'hswish': 22 | layer = nn.Hardswish(inplace) 23 | else: 24 | raise NotImplementedError('activation layer [%s] is not found' % act) 25 | return layer 26 | 27 | 28 | def norm_layer(norm, nc): 29 | # normalization layer 2d 30 | norm = norm.lower() 31 | if norm == 'batch': 32 | layer = nn.BatchNorm2d(nc, affine=True) 33 | elif norm == 'instance': 34 | layer = nn.InstanceNorm2d(nc, affine=False) 35 | else: 36 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 37 | return layer 38 | 39 | 40 | class MLP(Seq): 41 | def __init__(self, channels, act='relu', norm=None, bias=True): 42 | m = [] 43 | for i in range(1, len(channels)): 44 | m.append(Lin(channels[i - 1], channels[i], bias)) 45 | if act is not None and act.lower() != 'none': 46 | m.append(act_layer(act)) 47 | if norm is not None and norm.lower() != 'none': 48 | m.append(norm_layer(norm, channels[-1])) 49 | super(MLP, self).__init__(*m) 50 | 51 | 52 | class BasicConv(Seq): 53 | def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.): 54 | m = [] 55 | for i in range(1, len(channels)): 56 | m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias, groups=4)) 57 | if norm is not None and norm.lower() != 'none': 58 | m.append(norm_layer(norm, channels[-1])) 59 | if act is not None and act.lower() != 'none': 60 | m.append(act_layer(act)) 61 | if drop > 0: 62 | m.append(nn.Dropout2d(drop)) 63 | 64 | super(BasicConv, self).__init__(*m) 65 | 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | nn.init.kaiming_normal_(m.weight) 72 | if m.bias is not None: 73 | nn.init.zeros_(m.bias) 74 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 75 | m.weight.data.fill_(1) 76 | m.bias.data.zero_() 77 | 78 | 79 | def batched_index_select(x, idx): 80 | """fetches neighbors features from a given neighbor idx 81 | Args: 82 | x (Tensor): input feature Tensor 83 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. 84 | idx (Tensor): edge_idx 85 | :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. 86 | Returns: 87 | Tensor: output neighbors features 88 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 
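        For example, x of shape (2, 64, 196, 1) together with idx of shape (2, 196, 9) yields neighbor features of shape (2, 64, 196, 9).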
89 | """ 90 | batch_size, num_dims, num_vertices_reduced = x.shape[:3] 91 | _, num_vertices, k = idx.shape 92 | idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced 93 | idx = idx + idx_base 94 | idx = idx.contiguous().view(-1) 95 | 96 | x = x.transpose(2, 1) 97 | feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :] 98 | feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous() 99 | return feature -------------------------------------------------------------------------------- /arch/draem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import argparse 5 | from arch.base import ModelBase 6 | from models.dream.draem import NetDRAEM 7 | from loss_function.focal import FocalLoss 8 | from loss_function.ssim import SSIMLoss 9 | from augmentation.draem_aug import DraemAugData 10 | from optimizer.optimizer import get_optimizer 11 | 12 | __all__ = ['DRAEM', 'weights_init'] 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv') != -1: 17 | m.weight.data.normal_(0.0, 0.02) 18 | elif classname.find('BatchNorm') != -1: 19 | m.weight.data.normal_(1.0, 0.02) 20 | m.bias.data.fill_(0) 21 | 22 | class DRAEM(ModelBase): 23 | def __init__(self, config): 24 | super(DRAEM, self).__init__(config) 25 | self.config = config 26 | 27 | args = argparse.Namespace(**self.config) 28 | self.net = NetDRAEM(args).to(self.device) 29 | self.optimizer = get_optimizer(self.config, list(self.net.reconstructive_subnetwork.parameters()) + list(self.net.discriminative_subnetwork.parameters())) 30 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, [args.num_epochs * 0.8, args.num_epochs * 0.9], gamma=args._gamma, last_epoch=-1) 31 | self.dream_aug = DraemAugData(self.config['root_path'] + '/dtd/images', [args.data_size, args.data_size]) 32 | 33 | self.net.reconstructive_subnetwork.apply(weights_init) 34 | self.net.discriminative_subnetwork.apply(weights_init) 35 | 36 | self.loss_l2 = nn.modules.loss.MSELoss() 37 | self.loss_ssim = SSIMLoss() 38 | self.loss_focal = FocalLoss() 39 | 40 | def train_model(self, train_loader, task_id, inf=''): 41 | self.net.train() 42 | 43 | for epoch in range(self.config['num_epochs']): 44 | for batch_id, batch in enumerate(train_loader): 45 | inputs, masks, labels = self.dream_aug.transform_batch(batch['img'], batch['label'], batch['mask']) 46 | inputs = inputs.to(self.device) 47 | masks = masks.to(self.device) 48 | 49 | rec_imgs, out_masks = self.net(inputs) 50 | 51 | out_masks_sm = torch.softmax(out_masks, dim=1) 52 | l2_loss = self.loss_l2(rec_imgs, inputs) 53 | ssim_loss = self.loss_ssim(rec_imgs, inputs) 54 | segment_loss = self.loss_focal(out_masks_sm, masks) 55 | loss = l2_loss + ssim_loss + segment_loss 56 | 57 | self.optimizer.zero_grad() 58 | loss.backward() 59 | self.optimizer.step() 60 | 61 | self.scheduler.step() 62 | 63 | def prediction(self, valid_loader, task_id): 64 | self.net.eval() 65 | self.clear_all_list() 66 | 67 | with torch.no_grad(): 68 | for batch_id, batch in enumerate(valid_loader): 69 | inputs = batch['img'].to(self.device) 70 | labels = batch['label'].numpy() 71 | mask = batch['mask'].numpy() 72 | 73 | _, out_masks = self.net(inputs) 74 | out_masks_sm = torch.softmax(out_masks, dim=1) 75 | out_mask_cv = out_masks_sm[0, 1, :, :].detach().cpu().numpy() 76 | outs_mask_averaged = 
torch.nn.functional.avg_pool2d(out_masks_sm[:, 1:, :, :], 77 | 21, stride=1, padding=21 // 2).cpu().detach().numpy() 78 | image_score = np.max(outs_mask_averaged) 79 | self.pixel_pred_list.append(out_mask_cv) 80 | self.img_pred_list.append(image_score) 81 | 82 | mask[mask >= 0.5] = 1 83 | mask[mask < 0.5] = 0 84 | mask_np = mask[0, 0].astype(int) 85 | self.pixel_gt_list.append(mask_np) 86 | self.img_gt_list.append(labels[0]) 87 | self.img_path_list.append(batch['img_src']) -------------------------------------------------------------------------------- /models/_patchcore/kcenter_greedy.py: -------------------------------------------------------------------------------- 1 | """Returns points that minimizes the maximum distance of any point to a center. 2 | Implements the k-Center-Greedy method in 3 | Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for 4 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017 5 | Distance metric defaults to l2 distance. Features used to calculate distance 6 | are either raw features or if a model has transform method then uses the output 7 | of model.transform(X). 8 | Can be extended to a robust k centers algorithm that ignores a certain number of 9 | outlier datapoints. Resulting centers are solution to multiple integer program. 10 | """ 11 | 12 | import numpy as np 13 | from sklearn.metrics import pairwise_distances 14 | from models._patchcore.sampling_base import SamplingMethod 15 | 16 | __all__ = ['KCenterGreedy'] 17 | 18 | class KCenterGreedy(SamplingMethod): 19 | 20 | def __init__(self, X, y, metric='euclidean'): 21 | self.X = X 22 | self.y = y 23 | self.flat_X = self.flatten_X() 24 | self.name = 'kcenter' 25 | self.features = self.flat_X 26 | self.metric = metric 27 | self.min_distances = None 28 | self.n_obs = self.X.shape[0] 29 | self.already_selected = [] 30 | 31 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False): 32 | """Update min distances given cluster centers. 33 | Args: 34 | cluster_centers: indices of cluster centers 35 | only_new: only calculate distance for newly selected points and update 36 | min_distances. 37 | rest_dist: whether to reset min_distances. 38 | """ 39 | 40 | if reset_dist: 41 | self.min_distances = None 42 | if only_new: 43 | cluster_centers = [d for d in cluster_centers 44 | if d not in self.already_selected] 45 | if cluster_centers: 46 | # Update min_distances for all examples given new cluster center. 47 | x = self.features[cluster_centers] 48 | dist = pairwise_distances(self.features, x, metric=self.metric) 49 | 50 | if self.min_distances is None: 51 | self.min_distances = np.min(dist, axis=1).reshape(-1,1) 52 | else: 53 | self.min_distances = np.minimum(self.min_distances, dist) 54 | 55 | def select_batch_(self, model, already_selected, N, **kwargs): 56 | """ 57 | Diversity promoting active learning method that greedily forms a batch 58 | to minimize the maximum distance to a cluster center among all unlabeled 59 | datapoints. 60 | Args: 61 | model: model with scikit-like API with decision_function implemented 62 | already_selected: index of datapoints already selected 63 | N: batch size 64 | Returns: 65 | indices of points selected to minimize distance to cluster centers 66 | """ 67 | 68 | try: 69 | # Assumes that the transform function takes in original data and not 70 | # flattened data. 
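    # Greedy rule: each of the N iterations below adds the point whose distance to its nearest
    # already-chosen center is largest, so the selected batch approximately minimizes the maximum
    # point-to-center distance. If model.transform is unavailable, the except branch falls back
    # to the raw flattened features.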
71 | # print('Getting transformed features...') 72 | self.features = model.transform(self.X) 73 | # print('Calculating distances...') 74 | self.update_distances(already_selected, only_new=False, reset_dist=True) 75 | except: 76 | # print('Using flat_X as features.') 77 | self.update_distances(already_selected, only_new=True, reset_dist=False) 78 | 79 | new_batch = [] 80 | 81 | for _ in range(N): 82 | if self.already_selected is None: 83 | # Initialize centers with a randomly selected datapoint 84 | ind = np.random.choice(np.arange(self.n_obs)) 85 | else: 86 | ind = np.argmax(self.min_distances) 87 | # New examples should not be in already selected since those points 88 | # should have min_distance of zero to a cluster center. 89 | assert ind not in already_selected 90 | 91 | self.update_distances([ind], only_new=True, reset_dist=False) 92 | new_batch.append(ind) 93 | #print('Maximum distance from cluster centers is %0.2f' % max(self.min_distances)) 94 | self.already_selected = already_selected 95 | 96 | return new_batch -------------------------------------------------------------------------------- /arch/softpatch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from arch.base import ModelBase 3 | from models.softpatch.softpatch import SoftPatch as softpatch 4 | from models.patchcore import common 5 | from models.softpatch import sampler 6 | from torchvision import models 7 | 8 | __all__ = ['SoftPatch'] 9 | 10 | class SoftPatch(ModelBase): 11 | def __init__(self, config): 12 | super(SoftPatch, self).__init__(config) 13 | self.config = config 14 | 15 | if self.config['net'] == 'resnet18': 16 | self.net = models.resnet18(pretrained=True, progress=True).to(self.device) 17 | if self.config['net'] == 'wide_resnet50': 18 | self.net = models.wide_resnet50_2(pretrained=True, progress=True).to(self.device) 19 | 20 | self.sampler = self.get_sampler(self.config['_sampler_name']) 21 | self.nn_method = common.FaissNN(self.config['_faiss_on_gpu'], self.config['_faiss_num_workers']) 22 | 23 | self.patchcore_instance = softpatch(self.device) 24 | self.patchcore_instance.load( 25 | backbone=self.net, 26 | layers_to_extract_from=self.config['_layers_to_extract_from'], 27 | device=self.device, 28 | input_shape=self.config['_input_shape'], 29 | pretrain_embed_dimension=self.config['_pretrain_embed_dimension'], 30 | target_embed_dimension=self.config['_target_embed_dimension'], 31 | patchsize=self.config['_patch_size'], 32 | featuresampler=self.sampler, 33 | anomaly_scorer_num_nn=self.config['_anomaly_scorer_num_nn'], 34 | nn_method=self.nn_method, 35 | lof_k=self.config['_lof_k'], 36 | threshold=self.config['_threshold'], 37 | weight_method=self.config['_weight_method'], 38 | soft_weight_flag=self.config['_soft_weight_flag'], 39 | ) 40 | 41 | def get_sampler(self, name): 42 | if name == 'identity': 43 | return sampler.IdentitySampler() 44 | elif name == 'greedy_coreset': 45 | return sampler.GreedyCoresetSampler(self.config['sampler_percentage'], self.device) 46 | elif name == 'approx_greedy_coreset': 47 | return sampler.ApproximateGreedyCoresetSampler(self.config['sampler_percentage'], self.device) 48 | elif name == 'weighted_greedy_coreset': 49 | return sampler.WeightedGreedyCoresetSampler(self.config['sampler_percentage'], self.device) 50 | else: 51 | raise ValueError('No This Sampler: {}'.format(name)) 52 | 53 | def train_model(self, train_loader, task_id, inf=''): 54 | self.patchcore_instance.eval() 55 | 
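        # No model parameters are updated in this step: fit() runs the pretrained backbone over the
        # training set to collect patch features and builds the soft-weighted coreset memory bank
        # (see the lof_k / threshold / weight_method options passed to load() above).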
self.patchcore_instance.fit(train_loader) 56 | 57 | def prediction(self, valid_loader, task_id=None): 58 | self.patchcore_instance.eval() 59 | self.clear_all_list() 60 | 61 | scores, segmentations, labels_gt, masks_gt, img_srcs = self.patchcore_instance.predict(valid_loader) 62 | 63 | scores = np.array(scores) 64 | min_scores = scores.min(axis=-1).reshape(-1, 1) 65 | max_scores = scores.max(axis=-1).reshape(-1, 1) 66 | scores = (scores - min_scores) / (max_scores - min_scores) 67 | scores = np.mean(scores, axis=0) 68 | 69 | segmentations = np.array(segmentations) 70 | min_scores = segmentations.reshape(len(segmentations), -1).min(axis=-1).reshape(-1, 1, 1, 1) 71 | max_scores = segmentations.reshape(len(segmentations), -1).max(axis=-1).reshape(-1, 1, 1, 1) 72 | segmentations = (segmentations - min_scores) / (max_scores - min_scores) 73 | segmentations = np.mean(segmentations, axis=0) 74 | segmentations[segmentations >= 0.5] = 1 75 | segmentations[segmentations < 0.5] = 0 76 | segmentations = np.array(segmentations, dtype='uint8') 77 | masks_gt = np.array(masks_gt).squeeze().astype(int) 78 | 79 | self.pixel_gt_list = [mask for mask in masks_gt] 80 | self.pixel_pred_list = [seg for seg in segmentations] 81 | self.img_gt_list = labels_gt 82 | self.img_pred_list = scores 83 | self.img_path_list = img_srcs 84 | -------------------------------------------------------------------------------- /models/net_csflow/net_csflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from efficientnet_pytorch import EfficientNet 5 | from models.net_csflow.freia_funcs import * 6 | 7 | class NetCSFlow(nn.Module): 8 | def __init__(self, args): 9 | super(NetCSFlow, self).__init__() 10 | self.args = args 11 | self.feature_extractor = EfficientNet.from_pretrained('efficientnet-b5') 12 | for param in self.feature_extractor.parameters(): 13 | param.requires_grad = False 14 | self.map_size = (self.args._image_size // 12, self.args._image_size // 12) 15 | self.kernel_sizes = [3] * (self.args._n_coupling_blocks - 1) + [5] 16 | self.density_estimator = self.get_cs_flow_model(input_dim=self.args._n_feat) 17 | 18 | def get_cs_flow_model(self, input_dim): 19 | nodes = list() 20 | nodes.append(InputNode(input_dim, self.map_size[0], self.map_size[1], name='input')) 21 | nodes.append(InputNode(input_dim, self.map_size[0] // 2, self.map_size[1] // 2, name='input2')) 22 | nodes.append(InputNode(input_dim, self.map_size[0] // 4, self.map_size[1] // 4, name='input3')) 23 | 24 | for k in range(self.args._n_coupling_blocks): 25 | if k == 0: 26 | node_to_permute = [nodes[-3].out0, nodes[-2].out0, nodes[-1].out0] 27 | else: 28 | node_to_permute = [nodes[-1].out0, nodes[-1].out1, nodes[-1].out2] 29 | 30 | nodes.append(Node(node_to_permute, ParallelPermute, {'seed': k}, name=F'permute_{k}')) 31 | nodes.append(Node([nodes[-1].out0, nodes[-1].out1, nodes[-1].out2], parallel_glow_coupling_layer, 32 | {'clamp': self.args._clamp, 'F_class': CrossConvolutions, 33 | 'F_args': {'channels_hidden': self.args._fc_internal, 34 | 'kernel_size': self.kernel_sizes[k], 'block_no': k}}, 35 | name=F'fc1_{k}')) 36 | 37 | nodes.append(OutputNode([nodes[-1].out0], name='output_end0')) 38 | nodes.append(OutputNode([nodes[-2].out1], name='output_end1')) 39 | nodes.append(OutputNode([nodes[-3].out2], name='output_end2')) 40 | nf = ReversibleGraphNet(nodes, n_jac=3) 41 | return nf 42 | 43 | def eff_ext(self, x, use_layer=36): 44 | x = 
self.feature_extractor._swish(self.feature_extractor._bn0(self.feature_extractor._conv_stem(x))) 45 | # Blocks 46 | for idx, block in enumerate(self.feature_extractor._blocks): 47 | drop_connect_rate = self.feature_extractor._global_params.drop_connect_rate 48 | if drop_connect_rate: 49 | drop_connect_rate *= float(idx) / len(self.feature_extractor._blocks) # scale drop connect_rate 50 | x = block(x, drop_connect_rate=drop_connect_rate) 51 | if idx == use_layer: 52 | return x 53 | 54 | def forward_features(self, x): 55 | y = list() 56 | for s in range(self.args._n_scales): 57 | feat_s = F.interpolate(x, size=( 58 | self.args._image_size // (2 ** s), self.args._image_size // (2 ** s))) if s > 0 else x 59 | feat_s = self.eff_ext(feat_s) 60 | y.append(feat_s) 61 | return y 62 | 63 | def forward_logits(self, y): 64 | z, log_jac_det = self.density_estimator(y), self.density_estimator.jacobian(run_forward=False) 65 | z = torch.cat([z[i].reshape(z[i].shape[0], -1) for i in range(len(z))], dim=1) 66 | log_jac_det = sum(log_jac_det) 67 | return z, log_jac_det 68 | 69 | def forward(self, x): 70 | # embeds = torch.cat([y[i].reshape(y[i].shape[0], -1) for i in range(len(y))], dim=1) 71 | # z, log_jac_det = self.density_estimator(y), self.density_estimator.jacobian(run_forward=False) 72 | # z = torch.cat([z[i].reshape(z[i].shape[0], -1) for i in range(len(z))], dim=1) 73 | # log_jac_det = sum(log_jac_det) 74 | y = self.forward_features(x) # y(16, 512, 24/12/6, 24/12/6) 75 | z, log_jac_det = self.forward_logits(y) # z(16, 229824) 76 | 77 | zz = self.density_estimator(y) 78 | yy = self.density_estimator(zz, rev=True) 79 | return y, z, log_jac_det 80 | 81 | 82 | -------------------------------------------------------------------------------- /arch/devnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import argparse 4 | import torch.nn.functional as F 5 | from arch.base import ModelBase 6 | from loss_function.deviation import DeviationLoss 7 | from loss_function.binaryfocal import BinaryFocalLoss 8 | from models.devnet.devnet_resnet18 import DevNetResNet18 9 | from optimizer.optimizer import get_optimizer 10 | 11 | __all__ = ['DevNet'] 12 | 13 | def build_criterion(criterion): 14 | if criterion == 'deviation': 15 | return DeviationLoss() 16 | elif criterion == 'BCE': 17 | return torch.nn.BCEWithLogitsLoss() 18 | elif criterion == 'focal': 19 | return BinaryFocalLoss() 20 | elif criterion == 'CE': 21 | return torch.nn.CrossEntropyLoss() 22 | else: 23 | raise NotImplementedError 24 | 25 | class _DevNet(nn.Module): 26 | def __init__(self, args, net): 27 | super(_DevNet, self).__init__() 28 | self.args = args 29 | self.net = net 30 | 31 | self.conv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, padding=0) 32 | 33 | def forward(self, image): 34 | if self.args._n_scales == 0: 35 | raise ValueError 36 | 37 | image_pyramid = list() 38 | for s in range(self.args._n_scales): 39 | image_scaled = F.interpolate(image, size=self.args._img_size // (2 ** s)) if s > 0 else image 40 | feature = self.net(image_scaled) 41 | 42 | scores = self.conv(feature) 43 | if self.args._topk > 0: 44 | scores = scores.view(int(scores.size(0)), -1) 45 | topk = max(int(scores.size(1) * self.args._topk), 1) 46 | scores = torch.topk(torch.abs(scores), topk, dim=1)[0] 47 | scores = torch.mean(scores, dim=1).view(-1, 1) 48 | else: 49 | scores = scores.view(int(scores.size(0)), -1) 50 | scores = torch.mean(scores, dim=1).view(-1, 1) 51 | 52 | 
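            # at this point `scores` holds one scalar per image for the current pyramid scale;
            # the per-scale scores are concatenated and averaged below to form the image-level score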
image_pyramid.append(scores) 53 | scores = torch.cat(image_pyramid, dim=1) 54 | score = torch.mean(scores, dim=1) 55 | return score.view(-1, 1) 56 | 57 | class DevNet(ModelBase): 58 | def __init__(self, config): 59 | super(DevNet, self).__init__(config) 60 | self.config = config 61 | 62 | self.net = DevNetResNet18() 63 | self.optimizer = get_optimizer(self.config, self.net.parameters()) 64 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.config['_step_size'], gamma=self.config['_gamma']) 65 | args = argparse.Namespace(**self.config) 66 | self.model = _DevNet(args, self.net).to(self.device) 67 | self.criterion = build_criterion(self.config['_criterion']) 68 | 69 | def train_model(self, train_loader, task_id, inf=''): 70 | self.model.train() 71 | self.scheduler.step() 72 | 73 | train_loss = 0. 74 | 75 | for epoch in range(self.config['num_epochs']): 76 | for batch_id, batch in enumerate(train_loader): 77 | image = batch['img'].to(self.device) 78 | target = batch['label'].to(self.device) 79 | 80 | output = self.model(image) 81 | loss = self.criterion(output, target.unsqueeze(1).float()) 82 | self.optimizer.zero_grad() 83 | loss.backward() 84 | 85 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 86 | self.optimizer.step() 87 | train_loss += loss.item() 88 | 89 | def prediction(self, valid_loader, task_id): 90 | self.model.eval() 91 | self.clear_all_list() 92 | test_loss = 0.0 93 | 94 | for batch_id, batch in enumerate(valid_loader): 95 | image = batch['img'].to(self.device) 96 | target = batch['label'].to(self.device) 97 | 98 | with torch.no_grad(): 99 | output = self.model(image.float()) 100 | loss = self.criterion(output, target.unsqueeze(1).float()) 101 | test_loss += loss.item() 102 | 103 | self.img_gt_list.append(target.cpu().numpy()[0]) 104 | self.img_pred_list.append(output.data.cpu().numpy()[0]) 105 | self.img_path_list.append(batch['img_src']) 106 | -------------------------------------------------------------------------------- /models/softpatch/multi_variate_gaussian.py: -------------------------------------------------------------------------------- 1 | """Multi Variate Gaussian Distribution.""" 2 | 3 | from typing import Any, List, Optional 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | 8 | 9 | class MultiVariateGaussian(nn.Module): 10 | """Multi Variate Gaussian Distribution.""" 11 | 12 | def __init__(self, n_features, n_patches): 13 | super().__init__() 14 | 15 | self.register_buffer("mean", torch.zeros(n_features, n_patches)) 16 | self.register_buffer("inv_covariance", torch.eye(n_features).unsqueeze(0).repeat(n_patches, 1, 1)) 17 | 18 | self.mean: Tensor 19 | self.inv_covariance: Tensor 20 | 21 | @staticmethod 22 | def _cov( 23 | observations: Tensor, 24 | rowvar: bool = False, 25 | bias: bool = False, 26 | ddof: Optional[int] = None, 27 | aweights: Tensor = None, 28 | ) -> Tensor: 29 | 30 | # ensure at least 2D 31 | if observations.dim() == 1: 32 | observations = observations.view(-1, 1) 33 | 34 | # treat each column as a data point, each row as a variable 35 | if rowvar and observations.shape[0] != 1: 36 | observations = observations.t() 37 | 38 | if ddof is None: 39 | if bias == 0: 40 | ddof = 1 41 | else: 42 | ddof = 0 43 | 44 | weights = aweights 45 | weights_sum: Any 46 | 47 | if weights is not None: 48 | if not torch.is_tensor(weights): 49 | weights = torch.tensor(weights, dtype=torch.float) # pylint: disable=not-callable 50 | weights_sum = torch.sum(weights) 51 | avg = torch.sum(observations * (weights / 
weights_sum)[:, None], 0) 52 | else: 53 | avg = torch.mean(observations, 0) 54 | 55 | # Determine the normalization 56 | if weights is None: 57 | fact = observations.shape[0] - ddof 58 | elif ddof == 0: 59 | fact = weights_sum 60 | elif aweights is None: 61 | fact = weights_sum - ddof 62 | else: 63 | fact = weights_sum - ddof * torch.sum(weights * weights) / weights_sum 64 | 65 | observations_m = observations.sub(avg.expand_as(observations)) 66 | 67 | if weights is None: 68 | x_transposed = observations_m.t() 69 | else: 70 | x_transposed = torch.mm(torch.diag(weights), observations_m).t() 71 | 72 | covariance = torch.mm(x_transposed, observations_m) 73 | covariance = covariance / fact 74 | 75 | return covariance.squeeze() 76 | 77 | def forward(self, embedding: Tensor) -> List[Tensor]: 78 | """Calculate multivariate Gaussian distribution. 79 | 80 | Args: 81 | embedding (Tensor): CNN features whose dimensionality is reduced via either random sampling or PCA. 82 | 83 | Returns: 84 | mean and inverse covariance of the multi-variate gaussian distribution that fits the features. 85 | """ 86 | device = embedding.device 87 | patch, _, channel = embedding.shape 88 | embedding_vectors = embedding.permute(1, 2, 0) 89 | 90 | # batch, channel, height, width = embedding.size() 91 | # embedding_vectors = embedding.view(batch, channel, height * width) 92 | self.mean = torch.mean(embedding_vectors, dim=0) 93 | covariance = torch.zeros(size=(channel, channel, patch), device=device) 94 | identity = torch.eye(channel).to(device) 95 | for i in range(patch): 96 | covariance[:, :, i] = self._cov(embedding_vectors[:, :, i], rowvar=False) + 0.01 * identity 97 | # (evals, evecs) = torch.eig(covariance[:, :, i]) # 98 | # compaction[i] = evals[:, 0].max() #torch.max(evals[:, 0]) 99 | # calculate inverse covariance as we need only the inverse 100 | self.inv_covariance = torch.linalg.inv(covariance.permute(2, 0, 1)) 101 | # compaction = covariance.norm(p=2, dim=(0, 1)) 102 | 103 | return [self.mean, self.inv_covariance] # 104 | # return [self.mean, self.inv_covariance, compaction] # 105 | 106 | def fit(self, embedding: Tensor) -> List[Tensor]: 107 | """Fit multi-variate gaussian distribution to the input embedding. 108 | 109 | Args: 110 | embedding (Tensor): Embedding vector extracted from CNN. 111 | 112 | Returns: 113 | Mean and the covariance of the embedding. 
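            For an embedding of shape (n_patches, batch, n_features), the returned mean has shape
            (n_features, n_patches) and the returned inverse covariance has shape (n_patches, n_features, n_features).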
114 | """ 115 | return self.forward(embedding) 116 | -------------------------------------------------------------------------------- /models/favae/func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import random 5 | 6 | __all__ = ['AverageMeter', 'time_string', 'convert_secs2time', 'time_file_str', 7 | 'print_log', 'feature_extractor', 'EarlyStop'] 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | 27 | def time_string(): 28 | ISOTIMEFORMAT = '%Y-%m-%d %X' 29 | string = '[{}]'.format(time.strftime(ISOTIMEFORMAT, time.localtime())) 30 | return string 31 | 32 | 33 | def convert_secs2time(epoch_time): 34 | need_hour = int(epoch_time / 3600) 35 | need_mins = int((epoch_time - 3600 * need_hour) / 60) 36 | need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) 37 | return need_hour, need_mins, need_secs 38 | 39 | 40 | def time_file_str(): 41 | ISOTIMEFORMAT = '%Y-%m-%d' 42 | string = '{}'.format(time.strftime(ISOTIMEFORMAT, time.localtime())) 43 | return string + '-{}'.format(random.randint(1, 10000)) 44 | 45 | 46 | def print_log(print_string, log): 47 | print("{:}".format(print_string)) 48 | log.write('{:}\n'.format(print_string)) 49 | log.flush() 50 | 51 | def feature_extractor(x, model, target_layers): 52 | target_activations = list() 53 | for name, module in model._modules.items(): 54 | x = module(x) 55 | if name in target_layers: 56 | target_activations += [x] 57 | return target_activations, x 58 | 59 | def denormalization(x): 60 | # mean = np.array([0.485, 0.456, 0.406]) 61 | # std = np.array([0.229, 0.224, 0.225]) 62 | # x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8) 63 | x = (x.transpose(1, 2, 0) * 255.).astype(np.uint8) 64 | return x 65 | 66 | 67 | def rescale(x): 68 | return (x - x.min()) / (x.max() - x.min()) 69 | 70 | 71 | class EarlyStop(): 72 | """Used to early stop the training if validation loss doesn't improve after a given patience.""" 73 | def __init__(self, patience=20, verbose=True, delta=0, save_name="checkpoint.pt"): 74 | """ 75 | Args: 76 | patience (int): How long to wait after last time validation loss improved. 77 | Default: 20 78 | verbose (bool): If True, prints a message for each validation loss improvement. 79 | Default: False 80 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 81 | Default: 0 82 | save_name (string): The filename with which the model and the optimizer is saved when improved. 
83 | Default: "checkpoint.pt" 84 | """ 85 | self.patience = patience 86 | self.verbose = verbose 87 | self.save_name = save_name 88 | self.counter = 0 89 | self.best_score = None 90 | self.early_stop = False 91 | self.val_loss_min = np.Inf 92 | self.delta = delta 93 | 94 | def __call__(self, val_loss, model, optimizer, log): 95 | 96 | score = -val_loss 97 | 98 | if self.best_score is None: 99 | self.best_score = score 100 | self.save_checkpoint(val_loss, model, optimizer, log) 101 | elif score < self.best_score - self.delta: 102 | self.counter += 1 103 | print_log((f'EarlyStopping counter: {self.counter} out of {self.patience}'), log) 104 | if self.counter >= self.patience: 105 | self.early_stop = True 106 | else: 107 | self.best_score = score 108 | self.save_checkpoint(val_loss, model, optimizer, log) 109 | self.counter = 0 110 | 111 | return self.early_stop 112 | 113 | def save_checkpoint(self, val_loss, model, optimizer, log): 114 | '''Saves model when validation loss decrease.''' 115 | if self.verbose: 116 | print_log((f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...'), 117 | log) 118 | state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()} 119 | torch.save(state, self.save_name) 120 | self.val_loss_min = val_loss -------------------------------------------------------------------------------- /models/favae/net_favae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | #from torchsummary import summary 4 | 5 | __all__ = ['NetFAVAE'] 6 | 7 | class NetFAVAE(nn.Module): 8 | def __init__(self, input_channel=3, z_dim=100): 9 | super(NetFAVAE, self).__init__() 10 | 11 | # encode 12 | self.encode = nn.Sequential( 13 | nn.Conv2d(input_channel, 128, kernel_size=4, stride=2, padding=1), # 128 => 64 14 | nn.BatchNorm2d(128), 15 | nn.LeakyReLU(negative_slope=0.2), 16 | nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1), # 64 => 32 17 | nn.BatchNorm2d(128), 18 | nn.LeakyReLU(negative_slope=0.2), 19 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), 20 | nn.BatchNorm2d(256), 21 | nn.LeakyReLU(negative_slope=0.2), 22 | nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1), # 32 => 16 23 | nn.BatchNorm2d(256), 24 | nn.LeakyReLU(negative_slope=0.2), 25 | nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), 26 | nn.BatchNorm2d(512), 27 | nn.LeakyReLU(negative_slope=0.2), 28 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), # 16 => 8 29 | nn.BatchNorm2d(512), 30 | nn.LeakyReLU(negative_slope=0.2), 31 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), 32 | nn.BatchNorm2d(512), 33 | nn.LeakyReLU(negative_slope=0.2), 34 | nn.Conv2d(512, 32, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(32), 36 | nn.LeakyReLU(negative_slope=0.2), 37 | nn.Conv2d(32, 200, kernel_size=8, stride=1), # 8 => 1 38 | nn.Flatten(), 39 | Split()) 40 | 41 | # decode 42 | self.decode = nn.Sequential( 43 | DeFlatten(), 44 | nn.ConvTranspose2d(100, 32, kernel_size=8, stride=1), # 1 => 8 45 | nn.BatchNorm2d(32), 46 | nn.LeakyReLU(negative_slope=0.2), 47 | nn.ConvTranspose2d(32, 512, kernel_size=3, stride=1, padding=1), 48 | nn.BatchNorm2d(512), 49 | nn.LeakyReLU(negative_slope=0.2), 50 | nn.ConvTranspose2d(512, 512, kernel_size=3, stride=1, padding=1), 51 | nn.BatchNorm2d(512), 52 | nn.LeakyReLU(negative_slope=0.2), 53 | nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1), # 8 => 16 54 | nn.BatchNorm2d(512), 55 | nn.LeakyReLU(negative_slope=0.2), 56 | 
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=1, padding=1), 57 | nn.BatchNorm2d(256), 58 | nn.LeakyReLU(negative_slope=0.2), 59 | nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1), # 16 => 32 60 | nn.BatchNorm2d(256), 61 | nn.LeakyReLU(negative_slope=0.2), 62 | nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1), 63 | nn.BatchNorm2d(128), 64 | nn.LeakyReLU(negative_slope=0.2), 65 | nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), # 32 => 64 66 | nn.BatchNorm2d(128), 67 | nn.LeakyReLU(negative_slope=0.2), 68 | nn.ConvTranspose2d(128, input_channel, kernel_size=4, stride=2, padding=1), # 64 => 128 69 | nn.Identity(), 70 | nn.Sigmoid() 71 | # nn.Tanh() 72 | ) 73 | 74 | self.adapter = nn.ModuleList([Adapter_model(128), Adapter_model(256), Adapter_model(512)]) 75 | 76 | def reparameterize(self, mu, log_var): 77 | if self.training: 78 | std = log_var.mul(0.5).exp_() 79 | eps = std.new(std.size()).normal_() 80 | return eps.mul(std).add_(mu) 81 | else: 82 | return mu 83 | 84 | def forward(self, x): 85 | mu, logvar = self.encode(x) 86 | z = self.reparameterize(mu, logvar) 87 | 88 | return z, self.decode(z), mu, logvar 89 | 90 | 91 | class DeFlatten(nn.Module): 92 | def forward(self, x): 93 | return x.view(x.shape[0], 100, 1, 1) 94 | 95 | 96 | class Split(nn.Module): 97 | def forward(self, x): 98 | mu, logvar = x.chunk(2, dim=1) 99 | return mu, logvar 100 | 101 | 102 | class Adapter_model(nn.Module): 103 | def __init__(self, channel=128): 104 | super(Adapter_model, self).__init__() 105 | 106 | self.conv = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=1, stride=1), nn.ReLU(), 107 | nn.Conv2d(channel, channel, kernel_size=1, stride=1)) 108 | 109 | def forward(self, x): 110 | return self.conv(x) -------------------------------------------------------------------------------- /configuration/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | 5 | __all__ = ['parse_arguments_main'] 6 | 7 | def parse_arguments_main(): 8 | parser = argparse.ArgumentParser() 9 | ## learning paradigm 10 | parser.add_argument('--paradigm', '-p', type=str, default='c2d', choices=['c2d', 'c3d', 'f2d']) 11 | 12 | # ----------------------------- centralized learning ----------------------------- # 13 | parser.add_argument('--dataset', '-d', type=str, default='mvtec2d', choices=['_example', 'mvtec2d', 'mvtec3d', 'mpdd', 'mvtecloco', 'mtd', 14 | 'btad', 'mvtec2df3d', 'visa', 'dagm', 'coad']) 15 | parser.add_argument('--model', '-m', type=str, default='softpatch', choices=['_example', '_patchcore', 'patchcore', 'csflow', 'dne', 16 | 'draem', 'igd', 'cutpaste', 'devnet', 'dra', 'favae', 'padim', 'reverse', 'spade', 'fastflow', 'softpatch', 'cfa', 'stpm', 17 | 'simplenet']) 18 | parser.add_argument('--net', '-n', type=str, default='wide_resnet50', choices=['net_example', 'wide_resnet50', 'resnet18', 'net_csflow', 19 | 'vit_b_16', 'net_draem', 'net_dra', 'net_igd', 'net_reverse', 'net_favae', 'net_fastflow', 'net_cfa', 'net_devnet', 20 | 'vig_ti_224_gelu']) 21 | 22 | parser.add_argument('--root-path', '-rp', type=str, default=None) 23 | parser.add_argument('--data-path', '-dp', type=str, default=None) 24 | 25 | parser.add_argument('--train-task-id', '-tid', type=int, default=[0], nargs='+') 26 | parser.add_argument('--valid-task-id', '-vid', type=int, default=[0], nargs='+') 27 | parser.add_argument('--sampler-percentage', '-sp',
type=float, default=None) 28 | 29 | # vanilla 30 | parser.add_argument('--vanilla', '-v', action='store_true', default=False) 31 | 32 | # semi-supervised 33 | parser.add_argument('--semi', '-s', action='store_true', default=False) 34 | parser.add_argument('--semi-anomaly-num', '-san', type=int, default=None) 35 | parser.add_argument('--semi-overlap', '-so', action='store_true', default=False) 36 | 37 | # continual 38 | parser.add_argument('--continual', '-c', action='store_true', default=False) 39 | 40 | # fewshot 41 | parser.add_argument('--fewshot', '-f', action='store_true', default=False) 42 | parser.add_argument('--fewshot-exm', '-fe', type=int, default=None) 43 | parser.add_argument('--fewshot-data-aug', '-fda', action='store_true', default=False) 44 | parser.add_argument('--fewshot-feat-aug', '-ffa', action='store_true', default=False) 45 | parser.add_argument('--fewshot-num-dg', '-fnd', type=int, default=None) 46 | parser.add_argument('--fewshot-aug-type', '-fat', default=None, nargs='+', 47 | choices=['normal', 'rotation', 'scale', 'translate', 'flip', 'color_jitter', 'perspective']) 48 | 49 | # noisy label 50 | parser.add_argument('--noisy', '-z', action='store_true', default=False) 51 | parser.add_argument('--noisy-overlap', '-no', action='store_true', default=False) 52 | parser.add_argument('--noisy-ratio', '-nr', type=float, default=None) 53 | 54 | # transfer 55 | parser.add_argument('--transfer', '-t', action='store_true', default=False) 56 | parser.add_argument('--transfer-target-sample-num', '-ttn', type=int, default=None) 57 | 58 | # data augmentation type 59 | parser.add_argument('--train-aug-type', '-tag', default=None, choices=['normal', 'cutpaste'], help='data augmentation type') 60 | parser.add_argument('--valid-aug-type', '-vag', default=None, choices=['normal', 'cutpaste'], help='data augmentation type') 61 | 62 | # universal 63 | parser.add_argument('--gpu-id', '-g', type=str, default=None) 64 | parser.add_argument('--server-moda', '-sm', type=str, default=None, choices=['eno1', 'lo']) 65 | parser.add_argument('--num-epochs', '-ne', type=int, default=None) 66 | parser.add_argument('--seed', type=int, default=None) 67 | parser.add_argument('--debug', action='store_true', default=False) 68 | parser.add_argument('--vis', '-vis', action='store_true', default=True) 69 | parser.add_argument('--vis-em', action='store_true', default=False) 70 | 71 | parser.add_argument('--save-model', action='store_true', default=False) 72 | parser.add_argument('--load-model', action='store_true', default=False) 73 | parser.add_argument('--load-model-dir', type=str, default=None) 74 | 75 | # ----------------------------- federated learning ----------------------------- # 76 | parser.add_argument('--fed-aggregate-method', '-fam', type=str, default=None) 77 | parser.add_argument('--num-round', type=int, default=None) 78 | 79 | 80 | args = parser.parse_args() 81 | return args 82 | 83 | -------------------------------------------------------------------------------- /models/softpatch/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union 4 | from models.patchcore.sampler import GreedyCoresetSampler as GreedyCoresetSamplerBase 5 | import tqdm 6 | 7 | class IdentitySampler: 8 | def run( 9 | self, features: Union[torch.Tensor, np.ndarray] 10 | ) -> Union[torch.Tensor, np.ndarray]: 11 | return features 12 | 13 | class GreedyCoresetSampler(GreedyCoresetSamplerBase): 14 | def __init__( 15 | 
self, 16 | percentage: float, 17 | device: torch.device, 18 | dimension_to_project_features_to=128, 19 | ): 20 | """Greedy Coreset sampling base class.""" 21 | super().__init__(percentage, device, dimension_to_project_features_to) 22 | 23 | def run( 24 | self, features: Union[torch.Tensor, np.ndarray] 25 | ) -> Union[torch.Tensor, np.ndarray]: 26 | """Subsamples features using Greedy Coreset. 27 | 28 | Args: 29 | features: [N x D] 30 | """ 31 | if self.percentage == 1: 32 | return features 33 | self._store_type(features) 34 | if isinstance(features, np.ndarray): 35 | features = torch.from_numpy(features) 36 | reduced_features = self._reduce_features(features) 37 | sample_indices = self._compute_greedy_coreset_indices(reduced_features) 38 | features = features[sample_indices] 39 | 40 | return self._restore_type(features), sample_indices 41 | 42 | class ApproximateGreedyCoresetSampler(GreedyCoresetSampler): 43 | def __init__( 44 | self, 45 | percentage: float, 46 | device: torch.device, 47 | number_of_starting_points: int = 10, 48 | dimension_to_project_features_to: int = 128, 49 | ): 50 | """Approximate Greedy Coreset sampling base class.""" 51 | self.number_of_starting_points = number_of_starting_points 52 | super().__init__(percentage, device, dimension_to_project_features_to) 53 | 54 | def _compute_greedy_coreset_indices(self, features: torch.Tensor) -> np.ndarray: 55 | """Runs approximate iterative greedy coreset selection. 56 | 57 | This greedy coreset implementation does not require computation of the 58 | full N x N distance matrix and thus requires a lot less memory, however 59 | at the cost of increased sampling times. 60 | 61 | Args: 62 | features: [NxD] input feature bank to sample. 63 | """ 64 | number_of_starting_points = np.clip( 65 | self.number_of_starting_points, None, len(features) 66 | ) 67 | start_points = np.random.choice( 68 | len(features), number_of_starting_points, replace=False 69 | ).tolist() 70 | 71 | approximate_distance_matrix = self._compute_batchwise_differences( 72 | features, features[start_points] 73 | ) 74 | approximate_coreset_anchor_distances = torch.mean( 75 | approximate_distance_matrix, axis=-1 76 | ).reshape(-1, 1) 77 | coreset_indices = [] 78 | num_coreset_samples = int(len(features) * self.percentage) 79 | 80 | with torch.no_grad(): 81 | for _ in tqdm.tqdm(range(num_coreset_samples), desc="Subsampling..."): 82 | select_idx = torch.argmax(approximate_coreset_anchor_distances).item() 83 | coreset_indices.append(select_idx) 84 | coreset_select_distance = self._compute_batchwise_differences( 85 | features, features[select_idx : select_idx + 1] # noqa: E203 86 | ) 87 | approximate_coreset_anchor_distances = torch.cat( 88 | [approximate_coreset_anchor_distances, coreset_select_distance], 89 | dim=-1, 90 | ) 91 | approximate_coreset_anchor_distances = torch.min( 92 | approximate_coreset_anchor_distances, dim=1 93 | ).values.reshape(-1, 1) 94 | 95 | return np.array(coreset_indices) 96 | 97 | class WeightedGreedyCoresetSampler(ApproximateGreedyCoresetSampler): 98 | def __init__( 99 | self, 100 | percentage: float, 101 | device: torch.device, 102 | number_of_starting_points: int = 10, 103 | dimension_to_project_features_to: int = 128, 104 | ): 105 | """Approximate Greedy Coreset sampling base class.""" 106 | self.number_of_starting_points = number_of_starting_points 107 | super().__init__(percentage, device, dimension_to_project_features_to) 108 | self.sampling_weight = None 109 | 110 | def set_sampling_weight(self, sampling_weight): 111 | 
self.sampling_weight = sampling_weight -------------------------------------------------------------------------------- /models/cfa/cfa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from sklearn.cluster import KMeans 5 | from models.cfa.coordconv import CoordConv2d 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['DSVDD', 'Descriptor'] 9 | 10 | class DSVDD(nn.Module): 11 | def __init__(self, model, data_loader, cnn, gamma_c, gamma_d, device): 12 | super(DSVDD, self).__init__() 13 | self.device = device 14 | 15 | self.C = 0 16 | self.nu = 1e-3 17 | self.scale = None 18 | 19 | self.gamma_c = gamma_c 20 | self.gamma_d = gamma_d 21 | self.alpha = 1e-1 22 | self.K = 3 23 | self.J = 3 24 | 25 | self.r = nn.Parameter(1e-5*torch.ones(1), requires_grad=True) 26 | self.Descriptor = Descriptor(self.gamma_d, cnn, self.device).to(self.device) 27 | model.to(self.device) 28 | self._init_centroid(model, data_loader) 29 | self.C = rearrange(self.C, 'b c h w -> (b h w) c').detach() 30 | 31 | if self.gamma_c > 1: 32 | self.C = self.C.cpu().detach().numpy() 33 | self.C = KMeans(n_clusters=(self.scale**2)//self.gamma_c, max_iter=3000).fit(self.C).cluster_centers_ 34 | self.C = torch.Tensor(self.C).to(device) 35 | 36 | self.C = self.C.transpose(-1, -2).detach() 37 | self.C = nn.Parameter(self.C, requires_grad=False) 38 | 39 | def forward(self, p): 40 | phi_p = self.Descriptor(p) 41 | phi_p = rearrange(phi_p, 'b c h w -> b (h w) c') 42 | 43 | features = torch.sum(torch.pow(phi_p, 2), 2, keepdim=True) 44 | centers = torch.sum(torch.pow(self.C, 2), 0, keepdim=True) 45 | f_c = 2 * torch.matmul(phi_p, (self.C)) 46 | dist = features + centers - f_c 47 | dist = torch.sqrt(dist) 48 | 49 | n_neighbors = self.K 50 | dist = dist.topk(n_neighbors, largest=False).values 51 | 52 | dist = (F.softmin(dist, dim=-1)[:, :, 0]) * dist[:, :, 0] 53 | dist = dist.unsqueeze(-1) 54 | 55 | score = rearrange(dist, 'b (h w) c -> b c h w', h=self.scale) 56 | 57 | loss = 0 58 | if self.training: 59 | loss = self._soft_boundary(phi_p) 60 | 61 | return loss, score 62 | 63 | def _soft_boundary(self, phi_p): 64 | features = torch.sum(torch.pow(phi_p, 2), 2, keepdim=True) 65 | centers = torch.sum(torch.pow(self.C, 2), 0, keepdim=True) 66 | f_c = 2 * torch.matmul(phi_p, (self.C)) 67 | dist = features + centers - f_c 68 | n_neighbors = self.K + self.J 69 | dist = dist.topk(n_neighbors, largest=False).values 70 | 71 | score = (dist[:, : , :self.K] - self.r**2) 72 | L_att = (1/self.nu) * torch.mean(torch.max(torch.zeros_like(score), score)) 73 | 74 | score = (self.r**2 - dist[:, : , self.J:]) 75 | L_rep = (1/self.nu) * torch.mean(torch.max(torch.zeros_like(score), score - self.alpha)) 76 | 77 | loss = L_att + L_rep 78 | 79 | return loss 80 | 81 | def _init_centroid(self, model, data_loader): 82 | for i, batch in enumerate(data_loader): 83 | img = batch['img'].to(self.device) 84 | p = model(img) 85 | self.scale = p[0].size(2) 86 | phi_p = self.Descriptor(p) 87 | self.C = ((self.C * i) + torch.mean(phi_p, dim=0, keepdim=True).detach()) / (i+1) 88 | 89 | 90 | class Descriptor(nn.Module): 91 | def __init__(self, gamma_d, cnn, device): 92 | super(Descriptor, self).__init__() 93 | self.cnn = cnn 94 | self.device = device 95 | 96 | if cnn == 'wide_resnet50': 97 | dim = 1792 98 | self.layer = CoordConv2d(dim, dim//gamma_d, 1, self.device).to(self.device) 99 | elif cnn == 'resnet18': 100 | dim = 448 101 | self.layer = CoordConv2d(dim, 
dim//gamma_d, 1, self.device).to(self.device) 102 | elif cnn == 'efficientnet': 103 | dim = 568 104 | self.layer = CoordConv2d(dim, 2*dim//gamma_d, 1, self.device).to(self.device) 105 | elif cnn == 'vgg': 106 | dim = 1280 107 | self.layer = CoordConv2d(dim, dim//gamma_d, 1, self.device).to(self.device) 108 | 109 | 110 | def forward(self, p): 111 | sample = None 112 | for o in p: 113 | o = F.avg_pool2d(o, 3, 1, 1) / o.size(1) if self.cnn == 'effnet-b5' else F.avg_pool2d(o, 3, 1, 1) 114 | sample = o if sample is None else torch.cat((sample, F.interpolate(o, sample.size(2), mode='bilinear')), dim=1) 115 | 116 | sample 117 | phi_p = self.layer(sample) 118 | return phi_p -------------------------------------------------------------------------------- /dataset/mtd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | __all__ = ['MTD', 'mtd_classes'] 10 | 11 | 12 | def mtd_classes(): 13 | return ['mtd'] 14 | 15 | 16 | class MTD(Dataset): 17 | def __init__(self, data_path, learning_mode='centralized', phase='train', 18 | data_transform=None, num_task=1): 19 | 20 | self.data_path = data_path 21 | self.learning_mode = learning_mode 22 | self.phase = phase 23 | self.img_transform = data_transform[0] 24 | self.mask_transform = data_transform[1] 25 | self.class_name = mtd_classes() 26 | assert set(self.class_name) <= set(mtd_classes()) 27 | 28 | self.num_task = num_task 29 | self.class_in_task = [] 30 | 31 | self.imgs_list = [] 32 | self.labels_list = [] 33 | self.masks_list = [] 34 | self.task_ids_list = [] 35 | 36 | # continual 37 | self.sample_num_in_task = [] 38 | self.sample_indices_in_task = [] 39 | 40 | # load dataset 41 | self.load_dataset() 42 | self.allocate_task_data() 43 | 44 | def __getitem__(self, idx): 45 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 46 | 47 | img = Image.open(img_src).convert('RGB') 48 | img = self.img_transform(img) 49 | 50 | if label == 0: 51 | if isinstance(img, tuple): 52 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 53 | else: 54 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 55 | else: 56 | mask = Image.open(mask) 57 | mask = self.mask_transform(mask) 58 | 59 | return { 60 | 'img': img, 'label':label, 'mask':mask, 'task_id':task_id, 'img_src': img_src, 61 | } 62 | 63 | def __len__(self): 64 | return len(self.imgs_list) 65 | 66 | def load_dataset(self): 67 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 68 | # train directory: only good cases 69 | # test directory: bad and good cases 70 | # ground truth directory: only bad case 71 | 72 | # get classes in each task group 73 | # only one task 74 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 75 | # get data 76 | for id, class_in_task in enumerate(self.class_in_task): 77 | x, y, mask = [], [], [] 78 | for class_name in class_in_task: 79 | img_dir = os.path.join(self.data_path, self.phase) 80 | gt_dir = os.path.join(self.data_path, 'ground_truth') 81 | 82 | img_types = sorted(os.listdir(img_dir)) 83 | for img_type in img_types: 84 | 85 | # load images 86 | img_type_dir = os.path.join(img_dir, img_type) 87 | if not os.path.isdir(img_type_dir): 88 | continue 89 | img_path_list = sorted([os.path.join(img_type_dir, f) 90 | for f in os.listdir(img_type_dir) 91 | if f.endswith('.jpg')]) 92 | 
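                    # images under 'good' are labeled 0 and paired with a placeholder (None) mask entry; every
                    # other sub-folder is treated as defective, labeled 1, and paired with its ground-truth mask path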
x.extend(img_path_list) 93 | 94 | if img_type == 'good': 95 | y.extend([0] * len(img_path_list)) 96 | mask.extend([None] * len(img_path_list)) 97 | else: 98 | y.extend([1] * len(img_path_list)) 99 | gt_type_dir = os.path.join(gt_dir, img_type) 100 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 101 | gt_path_list = [os.path.join(gt_type_dir, img_fname + '.png') 102 | for img_fname in img_name_list] 103 | mask.extend(gt_path_list) 104 | # continual 105 | task_id = [id for i in range(len(x))] 106 | self.sample_num_in_task.append(len(x)) 107 | 108 | self.imgs_list.extend(x) 109 | self.labels_list.extend(y) 110 | self.masks_list.extend(mask) 111 | self.task_ids_list.extend(task_id) 112 | 113 | def allocate_task_data(self): 114 | start = 0 115 | for num in self.sample_num_in_task: 116 | end = start + num 117 | indice = [i for i in range(start, end)] 118 | random.shuffle(indice) 119 | self.sample_indices_in_task.append(indice) 120 | start = end 121 | 122 | # split the arr into n chunks 123 | @staticmethod 124 | def split_chunks(arr, m): 125 | n = int(math.ceil(len(arr) / float(m))) 126 | return [arr[i:i + n] for i in range(0, len(arr), n)] -------------------------------------------------------------------------------- /arch/favae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from arch.base import ModelBase 6 | from models.favae.net_favae import NetFAVAE 7 | from torchvision import models 8 | from models.favae.func import EarlyStop, AverageMeter 9 | from scipy.ndimage import gaussian_filter 10 | from optimizer.optimizer import get_optimizer 11 | 12 | 13 | __all__ = ['FAVAE'] 14 | 15 | class FAVAE(ModelBase): 16 | def __init__(self, config): 17 | super(FAVAE, self).__init__(config) 18 | self.config = config 19 | 20 | self.vaenet = NetFAVAE().to(self.device) 21 | self.optimizer = get_optimizer(self.config, self.vaenet.parameters()) 22 | self.scheduler = None 23 | self.teacher = models.vgg16(pretrained=True).to(self.device) 24 | for param in self.teacher.parameters(): 25 | param.requires_grad = False 26 | 27 | self.early_stop = EarlyStop(patience=20, save_name='favae.pt') 28 | self.criterion_1 = nn.MSELoss(reduction='sum') 29 | self.criterion_2 = nn.MSELoss(reduction='none') 30 | 31 | def feature_extractor(self, x, model, target_layers): 32 | target_activations = list() 33 | for name, module in model._modules.items(): 34 | x = module(x) 35 | if name in target_layers: 36 | target_activations += [x] 37 | return target_activations, x 38 | 39 | def train_model(self, train_loader, task_id, inf=''): 40 | self.vaenet.train() 41 | self.teacher.eval() 42 | 43 | losses = AverageMeter() 44 | for epoch in range(self.config['num_epochs']): 45 | for batch_id, batch in enumerate(train_loader): 46 | img = batch['img'].to(self.device) 47 | z, output, mu, log_var = self.vaenet(img) 48 | s_activations, _ = self.feature_extractor(z, self.vaenet.decode, target_layers=['10', '16', '22']) 49 | t_activations, _ = self.feature_extractor(img, self.teacher.features, target_layers=['7', '14', '21']) 50 | 51 | self.optimizer.zero_grad() 52 | mse_loss = self.criterion_1(output, img) 53 | kld_loss = 0.5 * torch.sum(-1 - log_var + torch.exp(log_var) + mu**2) 54 | for i in range(len(s_activations)): 55 | s_act = self.vaenet.adapter[i](s_activations[-(i + 1)]) 56 | mse_loss += self.criterion_1(s_act, t_activations[i]) 57 | loss = mse_loss + 
self.config['_kld_weight'] * kld_loss 58 | losses.update(loss.sum().item(), img.size(0)) 59 | 60 | loss.backward() 61 | self.optimizer.step() 62 | 63 | def prediction(self, valid_loader, task_id=None): 64 | self.vaenet.eval() 65 | self.teacher.eval() 66 | self.clear_all_list() 67 | 68 | pixel_pred_list = [] 69 | gt_mask_list = [] 70 | recon_imgs = [] 71 | 72 | with torch.no_grad(): 73 | for batch_id, batch in enumerate(valid_loader): 74 | img = batch['img'].to(self.device) 75 | mask = batch['mask'].numpy() 76 | label = batch['label'] 77 | z, output, mu, log_var = self.vaenet(img) 78 | s_activations, _ = self.feature_extractor(z, self.vaenet.decode, target_layers=['10', '16', '22']) 79 | t_activations, _ = self.feature_extractor(img, self.teacher.features, target_layers=['7', '14', '21']) 80 | 81 | score = self.criterion_2(output, img).sum(1, keepdim=True) 82 | 83 | for i in range(len(s_activations)): 84 | s_act = self.vaenet.adapter[i](s_activations[-(i + 1)]) 85 | mse_loss = self.criterion_2(s_act, t_activations[i]).sum(1, keepdim=True) 86 | score += F.interpolate(mse_loss, size=img.size(2), mode='bilinear', align_corners=False) 87 | 88 | score = score.squeeze().cpu().numpy() 89 | 90 | for i in range(score.shape[0]): 91 | score[i] = gaussian_filter(score[i], sigma=4) 92 | pixel_pred_list.append(score.reshape(img.size(2),img.size(2))) 93 | recon_imgs.extend(output.cpu().numpy()) 94 | mask[mask >= 0.5] = 1 95 | mask[mask < 0.5] = 0 96 | gt_mask_list.append(mask[0, 0].astype(int)) 97 | self.img_gt_list.append(label.numpy()[0]) 98 | self.img_pred_list.append(np.max(score)) 99 | self.img_path_list.append(batch['img_src']) 100 | 101 | max_anomaly_score = np.array(pixel_pred_list).max() 102 | min_anomaly_score = np.array(pixel_pred_list).min() 103 | pixel_pred_list = (pixel_pred_list - min_anomaly_score) / (max_anomaly_score - min_anomaly_score) 104 | self.pixel_gt_list = gt_mask_list 105 | self.pixel_pred_list = pixel_pred_list -------------------------------------------------------------------------------- /augmentation/perlin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | __all__ = ['lerp_np', 'generate_fractal_noise_2d', 'generate_perlin_noise_2d', 'rand_perlin_2d_np', 6 | 'rand_perlin_2d_torch', 'rand_perlin_2d_octaves'] 7 | 8 | def lerp_np(x,y,w): 9 | fin_out = (y-x)*w + x 10 | return fin_out 11 | 12 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5): 13 | noise = np.zeros(shape) 14 | frequency = 1 15 | amplitude = 1 16 | for _ in range(octaves): 17 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1])) 18 | frequency *= 2 19 | amplitude *= persistence 20 | return noise 21 | 22 | 23 | def generate_perlin_noise_2d(shape, res): 24 | def f(t): 25 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 26 | 27 | delta = (res[0] / shape[0], res[1] / shape[1]) 28 | d = (shape[0] // res[0], shape[1] // res[1]) 29 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 30 | # Gradients 31 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 32 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 33 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 34 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 35 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 36 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 37 | # Ramps 38 | n00 = np.sum(grid * g00, 2) 39 | n10 = 
np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 40 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 41 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 42 | # Interpolation 43 | t = f(grid) 44 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 45 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 46 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 47 | 48 | 49 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 50 | delta = (res[0] / shape[0], res[1] / shape[1]) 51 | d = (shape[0] // res[0], shape[1] // res[1]) 52 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 53 | 54 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 55 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 56 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1) 57 | 58 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1) 59 | dot = lambda grad, shift: ( 60 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 61 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 62 | 63 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 64 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 65 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 66 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 67 | t = fade(grid[:shape[0], :shape[1]]) 68 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) 69 | 70 | 71 | def rand_perlin_2d_torch(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 72 | delta = (res[0] / shape[0], res[1] / shape[1]) 73 | d = (shape[0] // res[0], shape[1] // res[1]) 74 | 75 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 76 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 77 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 78 | 79 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 80 | 0).repeat_interleave( 81 | d[1], 1) 82 | dot = lambda grad, shift: ( 83 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 84 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) 85 | 86 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 87 | 88 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 89 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 90 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 91 | t = fade(grid[:shape[0], :shape[1]]) 92 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) 93 | 94 | 95 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 96 | noise = torch.zeros(shape) 97 | frequency = 1 98 | amplitude = 1 99 | for _ in range(octaves): 100 | noise += amplitude * rand_perlin_2d_torch(shape, (frequency * res[0], frequency * res[1])) 101 | frequency *= 2 102 | amplitude *= persistence 103 | return noise -------------------------------------------------------------------------------- /models/reverse/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Type, Any, Callable, Union, List, Optional 4 | 5 | __all__ = ['conv3x3', 'conv1x1', 
'deconv2x2', 'AttnBasicBlock', 'AttnBottleneck'] 6 | 7 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=dilation, groups=groups, bias=False, dilation=dilation) 11 | 12 | 13 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | def deconv2x2(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1): 18 | """2x2 transposed convolution""" 19 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=stride, 20 | groups=groups, bias=False, dilation=dilation) 21 | 22 | class AttnBasicBlock(nn.Module): 23 | expansion: int = 1 24 | 25 | def __init__( 26 | self, 27 | inplanes: int, 28 | planes: int, 29 | stride: int = 1, 30 | downsample: Optional[nn.Module] = None, 31 | groups: int = 1, 32 | base_width: int = 64, 33 | dilation: int = 1, 34 | norm_layer: Optional[Callable[..., nn.Module]] = None, 35 | attention: bool = True, 36 | ): 37 | super(AttnBasicBlock, self).__init__() 38 | self.attention = attention 39 | #print("Attention:", self.attention) 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | #self.cbam = GLEAM(planes, 16) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x: torch.Tensor): 57 | #if self.attention: 58 | # x = self.cbam(x) 59 | identity = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | class AttnBottleneck(nn.Module): 78 | 79 | expansion: int = 4 80 | 81 | def __init__( 82 | self, 83 | inplanes: int, 84 | planes: int, 85 | stride: int = 1, 86 | downsample: Optional[nn.Module] = None, 87 | groups: int = 1, 88 | base_width: int = 64, 89 | dilation: int = 1, 90 | norm_layer: Optional[Callable[..., nn.Module]] = None, 91 | attention: bool = True, 92 | ): 93 | super(AttnBottleneck, self).__init__() 94 | self.attention = attention 95 | #print("Attention:",self.attention) 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = conv1x1(inplanes, width) 101 | self.bn1 = norm_layer(width) 102 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 103 | self.bn2 = norm_layer(width) 104 | self.conv3 = conv1x1(width, planes * self.expansion) 105 | self.bn3 = norm_layer(planes * self.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | #self.cbam = GLEAM([int(planes * self.expansion/4), 108 | # int(planes * self.expansion//2), 109 | # planes * self.expansion], 16) 110 | 
self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x: torch.Tensor): 114 | #if self.attention: 115 | # x = self.cbam(x) 116 | identity = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | 129 | if self.downsample is not None: 130 | identity = self.downsample(x) 131 | 132 | 133 | out += identity 134 | out = self.relu(out) 135 | 136 | return out 137 | -------------------------------------------------------------------------------- /dataset/_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | __all__ = ['Example', 'example_classes'] 10 | 11 | def example_classes(): 12 | return ['class_1', 'class_2'] 13 | 14 | 15 | class Example(Dataset): 16 | def __init__(self, data_path, learning_mode='centralized', phase='train', 17 | data_transform=None, num_task=2): 18 | 19 | self.data_path = data_path 20 | self.learning_mode = learning_mode 21 | self.phase = phase 22 | self.class_name = example_classes() 23 | self.img_transform = data_transform[0] 24 | self.mask_transform = data_transform[1] 25 | assert set(self.class_name) <= set(example_classes()) 26 | 27 | self.num_task = num_task 28 | self.class_in_task = [] 29 | 30 | self.imgs_list = [] 31 | self.labels_list = [] 32 | self.masks_list = [] 33 | self.task_ids_list = [] 34 | 35 | # mark each sample task id 36 | self.sample_num_in_task = [] 37 | self.sample_indices_in_task = [] 38 | 39 | # load dataset 40 | self.load_dataset() 41 | self.allocate_task_data() 42 | 43 | def __getitem__(self, idx): 44 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 45 | 46 | img = Image.open(img_src).convert('RGB') 47 | img = self.img_transform(img) 48 | 49 | if label == 0: 50 | if isinstance(img, tuple): 51 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 52 | else: 53 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 54 | else: 55 | mask = Image.open(mask) 56 | mask = self.mask_transform(mask) 57 | 58 | return { 59 | 'img': img, 'label': label, 'mask': mask, 'task_id': task_id, 'img_src': img_src, 60 | } 61 | 62 | def __len__(self): 63 | return len(self.imgs_list) 64 | 65 | def load_dataset(self): 66 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 67 | # train directory: only good cases 68 | # test directory: bad and good cases 69 | # ground truth directory: only bad case 70 | 71 | # get classes in each task group 72 | # If num_task is 2, each task constain each class 73 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 74 | # get data 75 | for id, class_in_task in enumerate(self.class_in_task): 76 | x, y, mask = [], [], [] 77 | for class_name in class_in_task: 78 | img_dir = os.path.join(self.data_path, class_name, self.phase) 79 | gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') 80 | 81 | img_types = sorted(os.listdir(img_dir)) 82 | for img_type in img_types: 83 | 84 | # load images 85 | img_type_dir = os.path.join(img_dir, img_type) 86 | if not os.path.isdir(img_type_dir): 87 | continue 88 | img_path_list = sorted([os.path.join(img_type_dir, f) 89 | for f in os.listdir(img_type_dir) 90 | if f.endswith('.png')]) 
91 | x.extend(img_path_list) 92 | 93 | if img_type == 'good': 94 | y.extend([0] * len(img_path_list)) 95 | mask.extend([None] * len(img_path_list)) 96 | else: 97 | y.extend([1] * len(img_path_list)) 98 | gt_type_dir = os.path.join(gt_dir, img_type) 99 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 100 | gt_path_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') 101 | for img_fname in img_name_list] 102 | mask.extend(gt_path_list) 103 | 104 | task_id = [id for i in range(len(x))] 105 | self.sample_num_in_task.append(len(x)) 106 | 107 | self.imgs_list.extend(x) 108 | self.labels_list.extend(y) 109 | self.masks_list.extend(mask) 110 | self.task_ids_list.extend(task_id) 111 | 112 | def allocate_task_data(self): 113 | start = 0 114 | for num in self.sample_num_in_task: 115 | end = start + num 116 | indice = [i for i in range(start, end)] 117 | random.shuffle(indice) 118 | self.sample_indices_in_task.append(indice) 119 | start = end 120 | 121 | # split the arr into n chunks 122 | @staticmethod 123 | def split_chunks(arr, m): 124 | n = int(math.ceil(len(arr) / float(m))) 125 | return [arr[i:i + n] for i in range(0, len(arr), n)] -------------------------------------------------------------------------------- /dataset/btad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | __all__ = ['BTAD', 'btad_classes'] 10 | 11 | def btad_classes(): 12 | return ['01', '02', '03'] 13 | 14 | class BTAD(Dataset): 15 | def __init__(self, data_path, learning_mode='centralized', phase='train', 16 | data_transform=None, num_task=3): 17 | 18 | self.data_path = data_path 19 | self.learning_mode = learning_mode 20 | self.phase = phase 21 | self.img_transform = data_transform[0] 22 | self.mask_transform = data_transform[1] 23 | self.class_name = btad_classes() 24 | assert set(self.class_name) <= set(btad_classes()) 25 | 26 | self.num_task = num_task 27 | self.class_in_task = [] 28 | 29 | self.imgs_list = [] 30 | self.labels_list = [] 31 | self.masks_list = [] 32 | self.task_ids_list = [] 33 | 34 | self.sample_num_in_task = [] 35 | self.sample_indices_in_task = [] 36 | 37 | # load dataset 38 | self.load_dataset() 39 | self.allocate_task_data() 40 | 41 | def __getitem__(self, idx): 42 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 43 | 44 | img = Image.open(img_src).convert('RGB') 45 | img = self.img_transform(img) 46 | 47 | if label == 0: 48 | if isinstance(img, tuple): 49 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 50 | else: 51 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 52 | else: 53 | mask = Image.open(mask) 54 | mask = self.mask_transform(mask) 55 | 56 | return { 57 | 'img': img, 'label':label, 'mask':mask, 'task_id':task_id, 'img_src': img_src, 58 | } 59 | 60 | def __len__(self): 61 | return len(self.imgs_list) 62 | 63 | def load_dataset(self): 64 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 65 | # train directory: only good cases 66 | # test directory: bad and good cases 67 | # ground truth directory: only bad case 68 | 69 | # get classes in each task group 70 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 71 | # get data 72 | for id, class_in_task in enumerate(self.class_in_task): 73 | x, y, mask = [], [], [] 74 | for class_name in 
class_in_task: 75 | 76 | img_dir = os.path.join(self.data_path, class_name, self.phase) 77 | gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') 78 | 79 | img_types = sorted(os.listdir(img_dir)) 80 | for img_type in img_types: 81 | 82 | # load images 83 | img_type_dir = os.path.join(img_dir, img_type) 84 | if not os.path.isdir(img_type_dir): 85 | continue 86 | img_path_list = sorted([os.path.join(img_type_dir, f) 87 | for f in os.listdir(img_type_dir) 88 | if f.endswith(('.png', '.bmp'))]) 89 | x.extend(img_path_list) 90 | 91 | if img_type == 'good': 92 | y.extend([0] * len(img_path_list)) 93 | mask.extend([None] * len(img_path_list)) 94 | else: 95 | y.extend([1] * len(img_path_list)) 96 | gt_type_dir = os.path.join(gt_dir, img_type) 97 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 98 | suffix = '.png' 99 | if class_name == '03': 100 | suffix = '.bmp' 101 | gt_path_list = [os.path.join(gt_type_dir, img_fname + suffix) 102 | for img_fname in img_name_list] 103 | mask.extend(gt_path_list) 104 | # continual 105 | task_id = [id for i in range(len(x))] 106 | self.sample_num_in_task.append(len(x)) 107 | 108 | self.imgs_list.extend(x) 109 | self.labels_list.extend(y) 110 | self.masks_list.extend(mask) 111 | self.task_ids_list.extend(task_id) 112 | 113 | def allocate_task_data(self): 114 | start = 0 115 | for num in self.sample_num_in_task: 116 | end = start + num 117 | indice = [i for i in range(start, end)] 118 | random.shuffle(indice) 119 | self.sample_indices_in_task.append(indice) 120 | start = end 121 | 122 | # split the arr into n chunks 123 | @staticmethod 124 | def split_chunks(arr, m): 125 | n = int(math.ceil(len(arr) / float(m))) 126 | return [arr[i:i + n] for i in range(0, len(arr), n)] -------------------------------------------------------------------------------- /dataset/mpdd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | __all__ = ['MPDD', 'mpdd_classes'] 10 | 11 | def mpdd_classes(): 12 | return [ "bracket_black", "bracket_brown", "bracket_white", 13 | "connector", "metal_plate", "tubes"] 14 | 15 | class MPDD(Dataset): 16 | def __init__(self, data_path, learning_mode='centralized', phase='train', 17 | data_transform=None, num_task=6): 18 | 19 | self.data_path = data_path 20 | self.learning_mode = learning_mode 21 | self.phase = phase 22 | self.img_transform = data_transform[0] 23 | self.mask_transform = data_transform[1] 24 | self.class_name = mpdd_classes() 25 | assert set(self.class_name) <= set(mpdd_classes()) 26 | 27 | self.num_task = num_task 28 | self.class_in_task = [] 29 | 30 | self.imgs_list = [] 31 | self.labels_list = [] 32 | self.masks_list = [] 33 | self.task_ids_list = [] 34 | 35 | # continual 36 | self.sample_num_in_task = [] 37 | self.sample_indices_in_task = [] 38 | 39 | # load dataset 40 | self.load_dataset() 41 | self.allocate_task_data() 42 | 43 | def __getitem__(self, idx): 44 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 45 | 46 | img = Image.open(img_src).convert('RGB') 47 | img = self.img_transform(img) 48 | 49 | if label == 0: 50 | if isinstance(img, tuple): 51 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 52 | else: 53 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 54 | else: 55 | mask = Image.open(mask) 56 | mask = 
self.mask_transform(mask) 57 | 58 | return { 59 | 'img': img, 'label':label, 'mask':mask, 'task_id':task_id, 'img_src': img_src, 60 | } 61 | 62 | def __len__(self): 63 | return len(self.imgs_list) 64 | 65 | def load_dataset(self): 66 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 67 | # train directory: only good cases 68 | # test directory: bad and good cases 69 | # ground truth directory: only bad case 70 | 71 | # get classes in each task group 72 | # If num_task is 5, each task constain each class 73 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 74 | # get data 75 | for id, class_in_task in enumerate(self.class_in_task): 76 | x, y, mask = [], [], [] 77 | for class_name in class_in_task: 78 | img_dir = os.path.join(self.data_path, class_name, self.phase) 79 | gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') 80 | 81 | img_types = sorted(os.listdir(img_dir)) 82 | for img_type in img_types: 83 | 84 | # load images 85 | img_type_dir = os.path.join(img_dir, img_type) 86 | if not os.path.isdir(img_type_dir): 87 | continue 88 | img_path_list = sorted([os.path.join(img_type_dir, f) 89 | for f in os.listdir(img_type_dir) 90 | if f.endswith('.png')]) 91 | x.extend(img_path_list) 92 | 93 | if img_type == 'good': 94 | y.extend([0] * len(img_path_list)) 95 | mask.extend([None] * len(img_path_list)) 96 | else: 97 | y.extend([1] * len(img_path_list)) 98 | gt_type_dir = os.path.join(gt_dir, img_type) 99 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 100 | gt_path_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') 101 | for img_fname in img_name_list] 102 | mask.extend(gt_path_list) 103 | # continual 104 | task_id = [id for i in range(len(x))] 105 | self.sample_num_in_task.append(len(x)) 106 | 107 | self.imgs_list.extend(x) 108 | self.labels_list.extend(y) 109 | self.masks_list.extend(mask) 110 | self.task_ids_list.extend(task_id) 111 | 112 | def allocate_task_data(self): 113 | start = 0 114 | for num in self.sample_num_in_task: 115 | end = start + num 116 | indice = [i for i in range(start, end)] 117 | random.shuffle(indice) 118 | self.sample_indices_in_task.append(indice) 119 | start = end 120 | 121 | # split the arr into n chunks 122 | @staticmethod 123 | def split_chunks(arr, m): 124 | n = int(math.ceil(len(arr) / float(m))) 125 | return [arr[i:i + n] for i in range(0, len(arr), n)] -------------------------------------------------------------------------------- /dataset/dagm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | __all__ = ['DAGM', 'dagm_classes'] 10 | 11 | def dagm_classes(): 12 | return ["Class1","Class2","Class3","Class4","Class5", 13 | "Class6","Class7","Class8","Class9","Class10"] 14 | 15 | 16 | class DAGM(Dataset): 17 | def __init__(self, data_path, learning_mode='centralized', phase='train', 18 | data_transform=None, num_task=15): 19 | 20 | self.data_path = data_path 21 | self.learning_mode = learning_mode 22 | self.phase = phase 23 | self.class_name = dagm_classes() 24 | self.img_transform = data_transform[0] 25 | self.mask_transform = data_transform[1] 26 | assert set(self.class_name) <= set(dagm_classes()) 27 | 28 | self.num_task = num_task 29 | self.class_in_task = [] 30 | 31 | self.imgs_list = [] 32 | self.labels_list = [] 33 | self.masks_list = [] 34 | self.task_ids_list = [] 35 | 36 | 
# mark each sample task id 37 | self.sample_num_in_task = [] 38 | self.sample_indices_in_task = [] 39 | 40 | # load dataset 41 | self.load_dataset() 42 | self.allocate_task_data() 43 | 44 | def __getitem__(self, idx): 45 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 46 | 47 | img = Image.open(img_src).convert('RGB') 48 | img = self.img_transform(img) 49 | 50 | if label == 0: 51 | if isinstance(img, tuple): 52 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 53 | else: 54 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 55 | else: 56 | mask = Image.open(mask) 57 | mask = self.mask_transform(mask) 58 | 59 | return { 60 | 'img': img, 'label': label, 'mask': mask, 'task_id': task_id, 'img_src': img_src, 61 | } 62 | 63 | def __len__(self): 64 | return len(self.imgs_list) 65 | 66 | def load_dataset(self): 67 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 68 | # train directory: only good cases 69 | # test directory: bad and good cases 70 | # ground truth directory: only bad case 71 | 72 | # get classes in each task group 73 | # If num_task is 15, each task constain each class 74 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 75 | # get data 76 | for id, class_in_task in enumerate(self.class_in_task): 77 | x, y, mask = [], [], [] 78 | for class_name in class_in_task: 79 | img_dir = os.path.join(self.data_path, class_name, self.phase) 80 | gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') 81 | 82 | img_types = sorted(os.listdir(img_dir)) 83 | for img_type in img_types: 84 | 85 | # load images 86 | img_type_dir = os.path.join(img_dir, img_type) 87 | if not os.path.isdir(img_type_dir): 88 | continue 89 | img_path_list = sorted([os.path.join(img_type_dir, f) 90 | for f in os.listdir(img_type_dir) 91 | if f.endswith('.PNG')]) 92 | x.extend(img_path_list) 93 | 94 | if img_type == 'good': 95 | y.extend([0] * len(img_path_list)) 96 | mask.extend([None] * len(img_path_list)) 97 | else: 98 | y.extend([1] * len(img_path_list)) 99 | gt_type_dir = os.path.join(gt_dir, img_type) 100 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 101 | gt_path_list = [os.path.join(gt_type_dir, img_fname + '_label.PNG') 102 | for img_fname in img_name_list] 103 | mask.extend(gt_path_list) 104 | 105 | task_id = [id for i in range(len(x))] 106 | self.sample_num_in_task.append(len(x)) 107 | 108 | self.imgs_list.extend(x) 109 | self.labels_list.extend(y) 110 | self.masks_list.extend(mask) 111 | self.task_ids_list.extend(task_id) 112 | 113 | def allocate_task_data(self): 114 | start = 0 115 | for num in self.sample_num_in_task: 116 | end = start + num 117 | indice = [i for i in range(start, end)] 118 | random.shuffle(indice) 119 | self.sample_indices_in_task.append(indice) 120 | start = end 121 | 122 | # split the arr into n chunks 123 | @staticmethod 124 | def split_chunks(arr, m): 125 | n = int(math.ceil(len(arr) / float(m))) 126 | return [arr[i:i + n] for i in range(0, len(arr), n)] -------------------------------------------------------------------------------- /arch/stpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import copy 5 | from arch.base import ModelBase 6 | from torchvision import models 7 | from optimizer.optimizer import get_optimizer 8 | 9 | __all__ = ['STPM'] 10 | 11 | class STPM(ModelBase): 12 | def 
__init__(self, config): 13 | super(STPM, self).__init__(config) 14 | self.config = config 15 | 16 | if self.config['net'] == 'resnet18': 17 | self.net = models.resnet18(pretrained=True, progress=True).to(self.device) 18 | self.net_student = self.net 19 | self.net_teacher = copy.deepcopy(self.net).to(self.device) 20 | self.optimizer = get_optimizer(self.config, self.net_student.parameters()) 21 | 22 | self.features_teacher = [] 23 | self.features_student = [] 24 | self.get_layer_features() 25 | 26 | self.criterion = torch.nn.MSELoss(reduction='sum') 27 | 28 | def cal_anomaly_map(self, feat_teachers, feat_students, out_size=224): 29 | anomaly_map = np.ones([out_size, out_size]) 30 | a_map_list = [] 31 | for i in range(len(feat_teachers)): 32 | fs = feat_students[i] 33 | ft = feat_teachers[i] 34 | fs_norm = F.normalize(fs, p=2) 35 | ft_norm = F.normalize(ft, p=2) 36 | a_map = 1 - F.cosine_similarity(fs_norm, ft_norm) 37 | a_map = torch.unsqueeze(a_map, dim=1) 38 | a_map = F.interpolate(a_map, size=out_size, mode='bilinear') 39 | a_map = a_map[0,0,:,:].to('cpu').detach().numpy() 40 | a_map_list.append(a_map) 41 | anomaly_map *= a_map 42 | 43 | return anomaly_map, a_map_list 44 | 45 | def get_layer_features(self): 46 | 47 | def hook_t(module, input, output): 48 | self.features_teacher.append(output) 49 | 50 | def hook_s(module, input, output): 51 | self.features_student.append(output) 52 | 53 | self.net_teacher.layer1[-1].register_forward_hook(hook_t) 54 | self.net_teacher.layer2[-1].register_forward_hook(hook_t) 55 | self.net_teacher.layer3[-1].register_forward_hook(hook_t) 56 | 57 | self.net_student.layer1[-1].register_forward_hook(hook_s) 58 | self.net_student.layer2[-1].register_forward_hook(hook_s) 59 | self.net_student.layer3[-1].register_forward_hook(hook_s) 60 | 61 | def cal_loss(self, feat_teachers, feat_students, criterion): 62 | total_loss = 0 63 | for i in range(len(feat_teachers)): 64 | fs = feat_students[i] 65 | ft = feat_teachers[i] 66 | _, _, h, w = fs.shape 67 | fs_norm = F.normalize(fs, p=2) 68 | ft_norm = F.normalize(ft, p=2) 69 | f_loss = (0.5/(w*h))*criterion(fs_norm, ft_norm) 70 | total_loss += f_loss 71 | 72 | return total_loss 73 | 74 | def train_model(self, train_loader, task_id, inf=''): 75 | self.net_teacher.eval() 76 | self.net_student.train() 77 | 78 | for epoch in range(self.config['num_epochs']): 79 | for batch_id, batch in enumerate(train_loader): 80 | img = batch['img'].to(self.device) 81 | self.optimizer.zero_grad() 82 | 83 | with torch.set_grad_enabled(True): 84 | self.features_teacher.clear() 85 | self.features_student.clear() 86 | 87 | _ = self.net_teacher(img) 88 | _ = self.net_student(img) 89 | 90 | loss = self.cal_loss(feat_teachers=self.features_teacher, feat_students=self.features_student, 91 | criterion=self.criterion) 92 | loss.backward() 93 | self.optimizer.step() 94 | 95 | def prediction(self, valid_loader, task_id): 96 | self.net_teacher.eval() 97 | self.net_student.eval() 98 | self.clear_all_list() 99 | 100 | for batch_id, batch in enumerate(valid_loader): 101 | img = batch['img'].to(self.device) 102 | label = batch['label'] 103 | mask = batch['mask'] 104 | mask[mask>=0.5] = 1 105 | mask[mask<0.5] = 0 106 | self.img_path_list.append(batch['img_src']) 107 | 108 | with torch.set_grad_enabled(False): 109 | self.features_teacher.clear() 110 | self.features_student.clear() 111 | 112 | _ = self.net_teacher(img) 113 | _ = self.net_student(img) 114 | 115 | anomaly_map, _ = self.cal_anomaly_map(feat_teachers=self.features_teacher, 
feat_students=self.features_student, 116 | out_size=self.config['data_crop_size']) 117 | 118 | self.pixel_pred_list.append(anomaly_map) 119 | self.pixel_gt_list.append(mask.cpu().numpy()[0,0].astype(int)) 120 | self.img_pred_list.append(np.max(anomaly_map)) 121 | self.img_gt_list.append(label.numpy()[0]) 122 | -------------------------------------------------------------------------------- /dataset/mvtec2df3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | __all__ = ['MVTec2DF3D', 'mvtec2df3d_classes'] 10 | 11 | def mvtec2df3d_classes(): 12 | return [ "bagel", "cable_gland", "carrot", "cookie", "dowel", 13 | "foam", "peach", "potato", "rope", "tire"] 14 | 15 | 16 | class MVTec2DF3D(Dataset): 17 | def __init__(self, data_path, learning_mode='centralized', phase='train', 18 | data_transform=None, num_task=10): 19 | 20 | self.data_path = data_path 21 | self.learning_mode = learning_mode 22 | self.phase = phase 23 | self.img_transform = data_transform[0] 24 | self.mask_transform = data_transform[1] 25 | self.class_name = mvtec2df3d_classes() 26 | assert set(self.class_name) <= set(mvtec2df3d_classes()) 27 | 28 | self.num_task = num_task 29 | self.class_in_task = [] 30 | 31 | self.imgs_list = [] 32 | self.labels_list = [] 33 | self.masks_list = [] 34 | self.task_ids_list = [] 35 | 36 | # mark each sample task id 37 | self.sample_num_in_task = [] 38 | self.sample_indices_in_task = [] 39 | 40 | # load dataset 41 | self.load_dataset() 42 | self.allocate_task_data() 43 | 44 | def __getitem__(self, idx): 45 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 46 | 47 | img = Image.open(img_src).convert('RGB') 48 | img = self.img_transform(img) 49 | 50 | if label == 0: 51 | if isinstance(img, tuple): 52 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 53 | else: 54 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 55 | else: 56 | mask = Image.open(mask) 57 | mask = self.mask_transform(mask) 58 | 59 | return { 60 | 'img': img, 'label': label, 'mask': mask, 'task_id': task_id, 'img_src': img_src, 61 | } 62 | 63 | def __len__(self): 64 | return len(self.imgs_list) 65 | 66 | def load_dataset(self): 67 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 68 | # train directory: only good cases 69 | # test directory: bad and good cases 70 | # ground truth directory: only bad case 71 | 72 | # get classes in each task group 73 | 74 | # If num_task is 10, each task constain each class 75 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 76 | # get data 77 | for id, class_in_task in enumerate(self.class_in_task): 78 | x, y, mask = [], [], [] 79 | for class_name in class_in_task: 80 | img_dir = os.path.join(self.data_path, class_name, self.phase) 81 | gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') 82 | 83 | img_types = sorted(os.listdir(img_dir)) 84 | for img_type in img_types: 85 | 86 | # load images 87 | img_type_dir = os.path.join(img_dir, img_type) 88 | if not os.path.isdir(img_type_dir): 89 | continue 90 | img_path_list = sorted([os.path.join(img_type_dir, f) 91 | for f in os.listdir(img_type_dir) 92 | if f.endswith('.png')]) 93 | x.extend(img_path_list) 94 | 95 | if img_type == 'good': 96 | y.extend([0] * len(img_path_list)) 97 | mask.extend([None] * len(img_path_list)) 98 | else: 
99 | y.extend([1] * len(img_path_list)) 100 | gt_type_dir = os.path.join(gt_dir, img_type) 101 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 102 | gt_path_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') 103 | for img_fname in img_name_list] 104 | mask.extend(gt_path_list) 105 | 106 | task_id = [id for i in range(len(x))] 107 | self.sample_num_in_task.append(len(x)) 108 | 109 | self.imgs_list.extend(x) 110 | self.labels_list.extend(y) 111 | self.masks_list.extend(mask) 112 | self.task_ids_list.extend(task_id) 113 | 114 | def allocate_task_data(self): 115 | start = 0 116 | for num in self.sample_num_in_task: 117 | end = start + num 118 | indice = [i for i in range(start, end)] 119 | random.shuffle(indice) 120 | self.sample_indices_in_task.append(indice) 121 | start = end 122 | 123 | # split the arr into n chunks 124 | @staticmethod 125 | def split_chunks(arr, m): 126 | n = int(math.ceil(len(arr) / float(m))) 127 | return [arr[i:i + n] for i in range(0, len(arr), n)] -------------------------------------------------------------------------------- /dataset/visa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import random 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms as T 8 | 9 | 10 | __all__ = ['VisA', 'visa_classes'] 11 | 12 | def visa_classes(): 13 | return ["candle", "capsules", "cashew", "chewinggum", "fryum", "macaroni1", 14 | "macaroni2", "pcb1", "pcb2", "pcb3", "pcb4", "pipe_fryum"] 15 | 16 | 17 | class VisA(Dataset): 18 | def __init__(self, data_path, learning_mode='centralized', phase='train', 19 | data_transform=None, num_task=12): 20 | 21 | self.data_path = data_path 22 | self.learning_mode = learning_mode 23 | self.phase = phase 24 | self.class_name = visa_classes() 25 | self.img_transform = data_transform[0] 26 | self.mask_transform = data_transform[1] 27 | assert set(self.class_name) <= set(visa_classes()) 28 | 29 | self.num_task = num_task 30 | self.class_in_task = [] 31 | 32 | self.imgs_list = [] 33 | self.labels_list = [] 34 | self.masks_list = [] 35 | self.task_ids_list = [] 36 | 37 | # mark each sample task id 38 | self.sample_num_in_task = [] 39 | self.sample_indices_in_task = [] 40 | 41 | # load dataset 42 | self.load_dataset() 43 | self.allocate_task_data() 44 | 45 | 46 | def __getitem__(self, idx): 47 | img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx] 48 | 49 | img = Image.open(img_src).convert('RGB') 50 | img = self.img_transform(img) 51 | 52 | if label == 0: 53 | if isinstance(img, tuple): 54 | mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]]) 55 | else: 56 | mask = torch.zeros([1, img.shape[1], img.shape[2]]) 57 | else: 58 | mask = Image.open(mask) 59 | mask = self.mask_transform(mask) 60 | 61 | return { 62 | 'img': img, 'label': label, 'mask': mask, 'task_id': task_id, 'img_src': img_src, 63 | } 64 | 65 | def __len__(self): 66 | return len(self.imgs_list) 67 | 68 | 69 | def load_dataset(self): 70 | # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth 71 | # train directory: only good cases 72 | # test directory: bad and good cases 73 | # ground truth directory: only bad case 74 | 75 | # get classes in each task group 76 | # If num_task is 15, each task constain each class 77 | self.class_in_task = self.split_chunks(self.class_name, self.num_task) 78 | # get data 79 | for 
id, class_in_task in enumerate(self.class_in_task): 80 | x, y, mask = [], [], [] 81 | for class_name in class_in_task: 82 | img_dir = os.path.join(self.data_path, class_name, self.phase) 83 | gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') 84 | 85 | img_types = sorted(os.listdir(img_dir)) 86 | for img_type in img_types: 87 | 88 | # load images 89 | img_type_dir = os.path.join(img_dir, img_type) 90 | if not os.path.isdir(img_type_dir): 91 | continue 92 | img_path_list = sorted([os.path.join(img_type_dir, f) 93 | for f in os.listdir(img_type_dir) 94 | if f.endswith('.JPG')]) 95 | x.extend(img_path_list) 96 | 97 | if img_type == 'good': 98 | y.extend([0] * len(img_path_list)) 99 | mask.extend([None] * len(img_path_list)) 100 | else: 101 | y.extend([1] * len(img_path_list)) 102 | gt_type_dir = os.path.join(gt_dir, img_type) 103 | img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list] 104 | gt_path_list = [os.path.join(gt_type_dir, img_fname + '.png') 105 | for img_fname in img_name_list] 106 | mask.extend(gt_path_list) 107 | 108 | task_id = [id for i in range(len(x))] 109 | self.sample_num_in_task.append(len(x)) 110 | 111 | self.imgs_list.extend(x) 112 | self.labels_list.extend(y) 113 | self.masks_list.extend(mask) 114 | self.task_ids_list.extend(task_id) 115 | 116 | def allocate_task_data(self): 117 | start = 0 118 | for num in self.sample_num_in_task: 119 | end = start + num 120 | indice = [i for i in range(start, end)] 121 | random.shuffle(indice) 122 | self.sample_indices_in_task.append(indice) 123 | start = end 124 | 125 | # split the arr into n chunks 126 | @staticmethod 127 | def split_chunks(arr, m): 128 | n = int(math.ceil(len(arr) / float(m))) 129 | return [arr[i:i + n] for i in range(0, len(arr), n)] --------------------------------------------------------------------------------
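The dataset classes above (Example, BTAD, MPDD, DAGM, MVTec2DF3D, VisA) share one interface: a (img_transform, mask_transform) pair passed as data_transform, per-task bookkeeping in sample_num_in_task / sample_indices_in_task, and dict-shaped samples with 'img', 'label', 'mask', 'task_id' and 'img_src' keys. Below is a minimal usage sketch for per-task (continual) iteration; the data path, crop size, and normalization constants are illustrative placeholders rather than values taken from this repository's configuration files.

# Minimal sketch: build a VisA dataset and iterate over it one task at a time.
# Assumes the repository root is on PYTHONPATH and that './data/visa' holds a
# local copy of the dataset; path, image size, and normalization are placeholders.
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms as T

from dataset.visa import VisA

img_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
mask_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])

train_set = VisA(data_path='./data/visa', phase='train',
                 data_transform=(img_transform, mask_transform), num_task=12)

# sample_indices_in_task[k] holds the shuffled indices of task k, so a per-task
# loader only needs a SubsetRandomSampler over those indices.
for task_id, indices in enumerate(train_set.sample_indices_in_task):
    loader = DataLoader(train_set, batch_size=8,
                        sampler=SubsetRandomSampler(indices))
    for batch in loader:
        img, label = batch['img'], batch['label']  # mask is all zeros for 'good' images
        # feed the batch (or the loader and task_id) to a model's train_model()

The same pattern applies to the other loaders above; only the class name, the default num_task, and the file suffixes they expect on disk differ.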