├── .gitignore
├── pdf
│   ├── PEFM.bmp
│   └── PFM.bmp
├── run_PFM.sh
├── run_PEFM.sh
├── dataset
│   ├── dagm.py
│   ├── MVTec3D_IMG.py
│   ├── mvtec.py
│   └── stc.py
├── README.md
├── models.py
├── utils.py
├── MB-PFM-VGG.py
├── MB-PFM-ResNet.py
├── PEFM_MVTec3D_PE.py
├── PEFM_AD_PE_Cat.py
└── PEFM_AD.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.pyc
*.pth
*.png
result/
result_3D/
--------------------------------------------------------------------------------
/pdf/PEFM.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smiler96/PFM-and-PEFM-for-Image-Anomaly-Detection-and-Segmentation/HEAD/pdf/PEFM.bmp
--------------------------------------------------------------------------------
/pdf/PFM.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smiler96/PFM-and-PEFM-for-Image-Anomaly-Detection-and-Segmentation/HEAD/pdf/PFM.bmp
--------------------------------------------------------------------------------
/run_PFM.sh:
--------------------------------------------------------------------------------
# Training
# python MB-PFM-ResNet.py --train --gpu_id 0 --batch_size 8 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_anomaly_detection/

# Testing
python MB-PFM-ResNet.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root D:/Dataset/mvtec_anomaly_detection/
--------------------------------------------------------------------------------
/run_PEFM.sh:
--------------------------------------------------------------------------------
# Training for MVTec AD
# python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
# python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
# python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/

# Testing for MVTec AD
python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/


# Training for MVTec 3D AD
python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/

# Testing for MVTec 3D AD
python PEFM_AD.py --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/
--------------------------------------------------------------------------------
/dataset/dagm.py:
--------------------------------------------------------------------------------
import os
import glob
import cv2
import numpy as np
from loguru import logger

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


DATA_ROOT = "D:/Dataset/DAGM2007/"
DAGMM_CLASS = ["Class1", "Class2", "Class3", "Class4", "Class5", "Class6", "Class7", "Class8", "Class9", "Class10"]
# DAGMM_CLASS = ["Class1", "Class2", "Class3", "Class4", "Class5", "Class6"]
class DAGMDataset(Dataset):
    def __init__(self, root_path=DATA_ROOT, class_name='Class1', is_train=True,
                 resize=256):

        assert class_name in DAGMM_CLASS

        self.resize = resize
        self.root_path = root_path
        self.class_name = class_name
        self.is_train = is_train

        # set transforms
        self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
                                      T.ToTensor(),
                                      T.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])])
        self.transform_mask = T.Compose([T.Resize(resize, Image.NEAREST),
                                         T.ToTensor()])

        self.x, self.mask = self.load_image()
        self.len = len(self.x)

    def load_image(self):
        phase = "Train" if self.is_train else "Test"
        img_path = os.path.join(self.root_path, self.class_name, phase)
        label_path = os.path.join(self.root_path, self.class_name, phase, "Label")
        label_file = os.path.join(label_path, "Labels.txt")
        with open(label_file, 'r') as f:
            info = f.readlines()
        info = info[1:]
        for i in range(len(info)):
            info[i] = info[i].strip('\n').split('\t')
        # print(info)

        mask_list = []
        img_list = []
        if self.is_train:
            for s in info:
                if s[1] == '0':
                    img_list.append(os.path.join(img_path, s[2]))
                    # keep both lists the same length so the assert below holds
                    mask_list.append("None")
        else:
            for s in info:
                img_list.append(os.path.join(img_path, s[2]))
                if s[1] == "1":
                    mask_list.append(os.path.join(label_path, s[4]))
                else:
                    mask_list.append("None")
        # img_list.sort()
        # mask_list.sort()
        # self.len = len(img_list) // 3
        # new_img_list = img_list[len(img_list)-self.len:len(img_list)]
        # new_mask_list = mask_list[len(img_list)-self.len:len(img_list)]
        #
        # new_mask_list[self.len//2:self.len] = mask_list[len(img_list)-self.len//2:len(img_list)]
        # new_img_list[self.len//2:self.len] = img_list[len(img_list)-self.len//2:len(img_list)]
        #
        # new_mask_list[:self.len//2] = mask_list[:self.len//2]
        # new_img_list[:self.len//2] = img_list[:self.len//2]
        #
        # img_list = new_mask_list
        # mask_list = new_mask_list
        assert len(img_list) == len(mask_list)
        return img_list, mask_list

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        x = self.x[idx]
        name = os.path.basename(x)
        x = Image.open(x).convert('RGB')
        x = self.transform_x(x)

        if self.is_train:
            mask = torch.zeros([1, self.resize, self.resize])
            y = 0
        else:
            mask = self.mask[idx]
            if mask != "None":
                mask = Image.open(mask)
                mask = self.transform_mask(mask)
                y = 1
            else:
                mask = torch.zeros([1, x.shape[1], x.shape[2]])
                y = 0

        return x, y, mask, name

if __name__ == "__main__":
    stc = DAGMDataset(is_train=False)
    data = stc[0]
    print(data)
--------------------------------------------------------------------------------
/dataset/MVTec3D_IMG.py:
--------------------------------------------------------------------------------
import cv2
import matplotlib

matplotlib.use("Agg")
import torch
from torch.utils.data import Dataset
import os
import glob
from PIL import Image
import numpy as np
from skimage.segmentation import slic, mark_boundaries
from torchvision import transforms
from imageio import imsave
# imagenet
mean_train = [0.485, 0.456, 0.406]
std_train = [0.229, 0.224, 0.225]

MVTEC_CLASSES = ['bagel', 'cable_gland', 'carrot', 'cookie', 'dowel',
                 'peach', 'potato', 'rope', 'tire', 'foam']

def denormalization(x):
    x = (((x.transpose(1, 2, 0) * std_train) + mean_train) * 255.).astype(np.uint8)
    return x


class MVTec3DDataset_IMG(Dataset):
    def __init__(self, root, transform, gt_transform, phase):
        if phase == 'train':
            self.img_path = os.path.join(root, 'train')
        else:
            self.img_path = os.path.join(root, 'test')
            self.gt_path = os.path.join(root, 'test')

        self.transform = transform
        self.gt_transform = gt_transform
        # load dataset
        self.img_paths, self.gt_paths, self.labels, self.types = self.load_dataset()  # self.labels => good : 0, anomaly : 1

    def load_dataset(self):

        img_tot_paths = []
        gt_tot_paths = []
        tot_labels = []
        tot_types = []

        defect_types = os.listdir(self.img_path)

        for defect_type in defect_types:
            if defect_type == 'good':
                img_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
                img_tot_paths.extend(img_paths)
                gt_tot_paths.extend([0]*len(img_paths))
                tot_labels.extend([0]*len(img_paths))
                tot_types.extend(['good']*len(img_paths))
            else:
                img_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
                gt_paths = glob.glob(os.path.join(self.gt_path, defect_type, 'gt') + "/*.png")
                img_paths.sort()
                gt_paths.sort()
                img_tot_paths.extend(img_paths)
                gt_tot_paths.extend(gt_paths)
                tot_labels.extend([1]*len(img_paths))
                tot_types.extend([defect_type]*len(img_paths))

        assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"

        return img_tot_paths, gt_tot_paths, tot_labels, tot_types

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path, gt, label, img_type = self.img_paths[idx], self.gt_paths[idx], self.labels[idx], self.types[idx]
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        # platform-independent split of the image path
        names = os.path.normpath(img_path).split(os.sep)
        name = names[-3] + "_" + names[-1]
        if gt == 0:
            # an all-zero mask with the image's (H, W)
            gt = torch.zeros([1, img.size()[-2], img.size()[-1]])
        else:
            gt = Image.open(gt)
            gt = self.gt_transform(gt)
            gt[gt > 0.5] = 1
            gt[gt <= 0.5] = 0

        assert img.size()[1:] == gt.size()[1:], "image.size != gt.size !!!"

        return img, label, gt, name





if __name__ == '__main__':
    categories = MVTEC_CLASSES
    dataset_path = r'../datasets/mvtec_anomaly_detection'
    load_size = 256
    input_size = 256
    data_transforms = transforms.Compose([
        transforms.Resize((load_size, load_size), Image.ANTIALIAS),
        transforms.ToTensor(),
        transforms.CenterCrop(input_size),
        transforms.Normalize(mean=mean_train,
                             std=std_train)])
    gt_transforms = transforms.Compose([
        transforms.Resize((load_size, load_size), Image.NEAREST),
        transforms.ToTensor(),
        transforms.CenterCrop(input_size)])

    for category in categories:
        phase = 'train'
        dataset = MVTec3DDataset_IMG(root=os.path.join(dataset_path, category),
                                     transform=data_transforms, gt_transform=gt_transforms, phase=phase)

        for img, label, gt, name in dataset:
            save_folder = os.path.join(dataset_path, 'spxls', category, phase)
            save_name = os.path.join(save_folder, f'{name}.bmp')
            os.makedirs(save_folder, exist_ok=True)


            pass


    pass
--------------------------------------------------------------------------------
/dataset/mvtec.py:
--------------------------------------------------------------------------------
import os
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms as T

MVTec_CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
                     'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
                     'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
# MVTec_CLASS_NAMES = ['pill', 'screw',
#                      'tile', 'toothbrush', 'transistor', 'wood', 'zipper']

class MVTecDataset(Dataset):
    def __init__(self, root_path='D:/Dataset/mvtec_anomaly_detection/', class_name='bottle', is_train=True, resize=256, trans=None, LOAD_CPU=False):
        assert class_name in MVTec_CLASS_NAMES, 'class_name: {}, should be in {}'.format(class_name, MVTec_CLASS_NAMES)

        self.root_path = root_path
        self.class_name = class_name
        self.is_train = is_train
        self.resize = resize

        # load dataset
        self.x, self.y, self.mask = self.load_dataset_folder()

        # set transforms
        if trans is None:
            self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
                                          T.ToTensor(),
                                          T.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])])
        else:
            self.transform_x = trans
            # self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
            #                               T.ToTensor()])


        self.transform_mask = T.Compose([T.Resize(resize, Image.NEAREST),
                                         T.ToTensor()])

        self.load_cpu = LOAD_CPU
        self.len = len(self.x)
        self.x_cpu = []
        self.y_cpu = []
        self.name = []
        self.mask_cpu = []
        if self.load_cpu:
            for i in range(self.len):
                # platform-independent split of the image path
                names = os.path.normpath(self.x[i]).split(os.sep)
                name = names[-2] + "_" + names[-1]
                self.name.append(name)
                x = Image.open(self.x[i]).convert('RGB')
                x = self.transform_x(x)
                self.x_cpu.append(x)

                if self.y[i] == 0:
                    mask = torch.zeros([1, self.resize, self.resize])
                else:
                    mask = Image.open(self.mask[i])
                    mask = self.transform_mask(mask)
                self.mask_cpu.append(mask)
                self.y_cpu.append(self.y[i])

    def __getitem__(self, idx):
        if self.load_cpu:
            x, y, mask, name = self.x_cpu[idx], self.y_cpu[idx], self.mask_cpu[idx], self.name[idx]
        else:
            x, y, mask = self.x[idx], self.y[idx], self.mask[idx]
            # platform-independent split of the image path
            names = os.path.normpath(x).split(os.sep)
            name = names[-2] + "_" + names[-1]
            x = Image.open(x).convert('RGB')
            x = self.transform_x(x)

            if y == 0:
                mask = torch.zeros([1, self.resize, self.resize])
            else:
                mask = Image.open(mask)
                mask = self.transform_mask(mask)

        return x, y, mask, name

    def __len__(self):
        return len(self.x)

    def load_dataset_folder(self):
        phase = 'train' if self.is_train else 'test'
        x, y, mask = [], [], []

        img_dir = os.path.join(self.root_path, self.class_name, phase)
        gt_dir = os.path.join(self.root_path, self.class_name, 'ground_truth')

        img_types = sorted(os.listdir(img_dir))
        for img_type in img_types:

            # load images
            img_type_dir = os.path.join(img_dir, img_type)
            if not os.path.isdir(img_type_dir):
                continue
            img_fpath_list = sorted([os.path.join(img_type_dir, f)
                                     for f in os.listdir(img_type_dir)
                                     if f.endswith('.png')])
            x.extend(img_fpath_list)

            # load gt labels
            if img_type == 'good':
                y.extend([0] * len(img_fpath_list))
                mask.extend([None] * len(img_fpath_list))
            else:
                y.extend([1] * len(img_fpath_list))
                gt_type_dir = os.path.join(gt_dir, img_type)
                img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list]
                gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png')
                                 for img_fname in img_fname_list]
                mask.extend(gt_fpath_list)

        assert len(x) == len(y), 'number of x and y should be same'

        return list(x), list(y), list(mask)


if __name__ == "__main__":
    dataset_ = MVTecDataset()
    a = dataset_[0]
    print(a)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PFM and PEFM for Image Anomaly Detection and Segmentation

## Abstract

### Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping [(PFM-TII)](https://ieeexplore.ieee.org/document/9795121)

Image anomaly detection and segmentation are important for the development of automatic product quality inspection in intelligent manufacturing. Because normal data can be collected easily while abnormal samples are rare, unsupervised methods based on reconstruction and embedding have been the main line of study for anomaly detection. However, both detection performance and computing time still need further improvement. This paper proposes a novel framework, named Pre-trained Feature Mapping (PFM), for unsupervised image anomaly detection and segmentation. The proposed PFM maps the image from one pre-trained feature space to another to detect anomalies effectively. Bidirectional and multi-hierarchical bidirectional pre-trained feature mappings are further proposed and studied to improve the performance. The proposed framework achieves better results on the well-known MVTec AD dataset than state-of-the-art methods, with an area under the receiver operating characteristic curve of 97.5% for anomaly detection and 97.3% for anomaly segmentation over all 15 categories. The proposed framework is also superior in terms of computing time.
Extensive ablation studies are also conducted to show the effectiveness and efficiency of the proposed framework.

<div align="center">
    <img src="pdf/PFM.bmp" alt="PFM">
</div>
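
For reference, the core mapping idea fits in a few lines. The sketch below is an illustration only: a single branch with a 1x1 convolution standing in for the projection network, and arbitrarily chosen ResNet layers. The actual bidirectional, multi-hierarchical projection networks are implemented in `MB-PFM-ResNet.py` and `MB-PFM-VGG.py`.

```python
import torch
import torch.nn.functional as F
from torchvision.models import resnet18, resnet34

# two frozen pre-trained extractors define the source and target feature spaces
src = torch.nn.Sequential(*list(resnet18(pretrained=True).children())[:6]).eval()  # up to layer2, 128 ch
tgt = torch.nn.Sequential(*list(resnet34(pretrained=True).children())[:6]).eval()  # up to layer2, 128 ch
mapper = torch.nn.Conv2d(128, 128, kernel_size=1)  # stand-in for the projection network

x = torch.rand(1, 3, 256, 256)  # a normal training image
with torch.no_grad():
    fs = F.normalize(src(x), p=2, dim=1)  # source features
    ft = F.normalize(tgt(x), p=2, dim=1)  # target features
loss = ((mapper(fs) - ft) ** 2).mean()           # train the mapper on normal images only
score_map = ((mapper(fs) - ft) ** 2).sum(dim=1)  # at test time, large mapping error marks an anomaly
```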
12 | 13 | ### Position Encoding Enhanced Feature Mapping for Image Anomaly Detection [(PEFM CASE)](https://www.researchgate.net/publication/361254312_Position_Encoding_Enhanced_Feature_Mapping_for_Image_Anomaly_Detection) 14 | 15 | Image anomaly detection is an important stage for automatic visual inspection in intelligent manufacturing systems. The wide-ranging anomalies in images, such as various sizes, shapes, and colors, make automatic visual inspection challenging. Previous work on image anomaly detection has achieved significant advancements. However, there are still limitations in terms of detection performance and efficiency. In this paper, a novel Position Encoding enhanced Feature Mapping (PEFM) method is proposed to address the problem of image anomaly detection, detecting the anomalies by mapping a pair of pre-trained features embedded with position encodes. Experiment results show that the proposed PEFM achieves better performance and efficiency than the state-of-the-art methods on the MVTec AD dataset, an AUCROC of 98.30% and an AUCPRO of 95.52%, and achieves the AUCPRO of 94.0% on the MVTec 3D AD dataset. 16 | 17 |
18 | PEFM 19 |
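
The position encoding follows the 2D sinusoidal scheme credited below to [PositionalEncoding2D](https://github.com/wzlxjtu/PositionalEncoding2D). A minimal sketch of such an encoding and of attaching it to a feature map is shown here; the channel count and the concatenation are illustrative, not the exact embedding used in `PEFM_AD.py`.

```python
import math
import torch

def positional_encoding_2d(channels: int, height: int, width: int) -> torch.Tensor:
    """2D sinusoidal encoding; `channels` must be divisible by 4 (half for y, half for x)."""
    pe = torch.zeros(channels, height, width)
    d = channels // 2
    div = torch.exp(torch.arange(0., d, 2) * -(math.log(10000.0) / d))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:d:2] = torch.sin(pos_w * div).t().unsqueeze(1).repeat(1, height, 1)
    pe[1:d:2] = torch.cos(pos_w * div).t().unsqueeze(1).repeat(1, height, 1)
    pe[d::2] = torch.sin(pos_h * div).t().unsqueeze(2).repeat(1, 1, width)
    pe[d + 1::2] = torch.cos(pos_h * div).t().unsqueeze(2).repeat(1, 1, width)
    return pe

feat = torch.rand(1, 256, 32, 32)                     # a pre-trained feature map
pe = positional_encoding_2d(64, 32, 32).unsqueeze(0)  # 64 extra position channels
feat_pe = torch.cat([feat, pe], dim=1)                # position-aware features to be mapped
```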

## Usage [![Framework: PyTorch](https://img.shields.io/badge/Framework-PyTorch-orange.svg)](https://pytorch.org/)

### PFM (TII)

```bash
# Training
python MB-PFM-ResNet.py --train --gpu_id 0 --batch_size 8 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
# Testing
python MB-PFM-ResNet.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root D:/Dataset/mvtec_anomaly_detection/
```

### PEFM (CASE)

```bash
# Training for MVTec AD
python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/

# Testing for MVTec AD
python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/


# Training for MVTec 3D AD
python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/

# Testing for MVTec 3D AD
python PEFM_AD.py --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/
```

## Citation

If this work is helpful for your research, please consider citing these papers:

```BibTeX
@ARTICLE{PFM,
  author={Wan, Qian and Gao, Liang and Li, Xinyu and Wen, Long},
  journal={IEEE Transactions on Industrial Informatics},
  title={Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping},
  year={2022},
  volume={},
  number={},
  pages={},
  doi={10.1109/TII.2022.3182385}
}
@INPROCEEDINGS{PEFM,
  author={Wan, Qian and Cao, Yunkang and Gao, Liang and Shen, Weiming and Li, Xinyu},
  booktitle={2022 IEEE 18th International Conference on Automation Science and Engineering (CASE)},
  title={Position Encoding Enhanced Feature Mapping for Image Anomaly Detection},
  year={2022},
  volume={},
  number={},
  pages={},
  doi={}
}
```

## Acknowledgment

Thanks to the following excellent works:
- [SPADE](https://github.com/byungjae89/SPADE-pytorch)
- [PositionalEncoding2D](https://github.com/wzlxjtu/PositionalEncoding2D)
--------------------------------------------------------------------------------
/dataset/stc.py:
--------------------------------------------------------------------------------
import os
import glob
import cv2
import numpy as np
from loguru import logger

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


DATA_ROOT = "D:/Dataset/ShanghaiTech/"
BIN_ROOT = os.path.join(DATA_ROOT, 'bin')
os.makedirs(BIN_ROOT, exist_ok=True)

INTERVAL = 5

def capture_training_frames():

    video_path = os.path.join(DATA_ROOT, 'training', 'videos')
    video_list = glob.glob(video_path + '/*.avi')
    video_list.sort()

    frame_root = os.path.join(BIN_ROOT, "train", "image")

    for v in video_list:
        video_name = os.path.basename(v)

        scene_class = video_name.split('_')[0]
        frame_path = os.path.join(frame_root, scene_class)
        os.makedirs(frame_path, exist_ok=True)

        cap = cv2.VideoCapture(v)
        cnt = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if cnt % INTERVAL == 0:
                # cv2.imshow(video_name, frame)
                # cv2.waitKey(0)
                frame = cv2.resize(frame, (256, 256))
                frame_name = video_name.split('.')[0] + f"_{cnt}.png"
                cv2.imwrite(os.path.join(frame_path, frame_name), frame)
                logger.info(os.path.join(frame_path, frame_name))
            cnt += 1
        cap.release()

def split_testing_frames():
    testing_frames_root = os.path.join(DATA_ROOT, "testing", "frames")
    testing_frame_mask_root = os.path.join(DATA_ROOT, "testing", "test_frame_mask")
    testing_pixel_mask_root = os.path.join(DATA_ROOT, "testing", "test_pixel_mask")

    img_root = os.path.join(BIN_ROOT, "test", 'image')
    gt_root = os.path.join(BIN_ROOT, "test", 'groundtruth')

    scene_folders = os.listdir(testing_frames_root)
    scene_folders.sort()
    for sf in scene_folders:
        scene_class = sf.split("_")[0]
        img_path = os.path.join(img_root, scene_class)
        os.makedirs(img_path, exist_ok=True)
        gt_path = os.path.join(gt_root, scene_class)
        os.makedirs(gt_path, exist_ok=True)

        frames_path = os.path.join(testing_frames_root, sf)
        frames_list = glob.glob(frames_path + "/*.*")
        frames_list.sort()

        frames_pixel_masks = np.load(os.path.join(testing_pixel_mask_root, sf + ".npy"))
        # print(np.max(frames_pixel_masks))
        for cnt, f in enumerate(frames_list):
            if cnt % 1 == 0:
                # frame
                frame = cv2.imread(f)
                frame = cv2.resize(frame, (256, 256))
                frame_name = os.path.basename(f).split('.')[0] + '.png'
                cv2.imwrite(os.path.join(img_path, frame_name), frame)
                logger.info(os.path.join(img_path, frame_name))

                # gt; interpolation must be passed by keyword (the third
                # positional argument of cv2.resize is dst, not the flag)
                gt = frames_pixel_masks[cnt] * 255
                gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_NEAREST)
                cv2.imwrite(os.path.join(gt_path, frame_name), gt)


STC_CLASS = ["02", "01", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"]
# STC_CLASS = ["02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "13"]
class STCDataset(Dataset):
    def __init__(self, root_path=BIN_ROOT, scene_class='01', is_train=True,
                 resize=256, trans=None):

        assert scene_class in STC_CLASS

        self.root_path = root_path
        self.scene_class = scene_class
        self.is_train = is_train

        # set transforms
        if trans is None:
            self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
                                          T.ToTensor(),
                                          T.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])])
        else:
            self.transform_x = trans
        self.transform_mask = T.Compose([T.Resize(resize, Image.NEAREST),
                                         T.ToTensor()])

        self.x, self.mask = self.load_image()
        # self.len = len(self.x) // 3
        self.len = len(self.x)

    def load_image(self):
        if self.is_train:
            img_path = os.path.join(self.root_path, "train", "image", self.scene_class)
            img_list = glob.glob(img_path + "/*.png")
            mask_list = None
        else:
            img_path = os.path.join(self.root_path, "test", "image", self.scene_class)
            mask_path = os.path.join(self.root_path, "test", "groundtruth", self.scene_class)
            img_list = glob.glob(img_path + "/*.png")
            mask_list = glob.glob(mask_path + "/*.png")
            # only the test split has masks, so sort and check lengths only here
            img_list.sort()
            mask_list.sort()
            assert len(img_list) == len(mask_list)
        return img_list, mask_list

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        x = self.x[idx]
        name = os.path.basename(x)
        x = Image.open(x).convert('RGB')
        x = self.transform_x(x)

        if self.is_train:
            _, H, W = x.shape
            mask = torch.zeros((1, H, W))
        else:
            mask = self.mask[idx]
            mask = Image.open(mask)
            mask = self.transform_mask(mask)
        y = torch.max(mask)
        return x, y, mask, name

if __name__ == "__main__":
    # capture_training_frames()
    # split_testing_frames()
    stc = STCDataset(is_train=False)
    data = stc[120]
    print(data)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18, resnet34
from torchvision.models.vgg import vgg16_bn, vgg16
# from resnet_no_relu import resnet18_nr
import torch.nn.functional as F

class VGG16(nn.Module):
    def __init__(self, pretrained=False, bn=False):
        super(VGG16, self).__init__()
        if bn:
            model = vgg16_bn(pretrained=pretrained)
            self.modules = list(model.features)
            self.block0 = nn.Sequential(*self.modules[0:17])
            self.block1 = nn.Sequential(*self.modules[17:23])
            self.block2 = nn.Sequential(*self.modules[23:33])
            self.block3 = nn.Sequential(*self.modules[33:43])
        else:
            model = vgg16(pretrained=pretrained)
            self.modules = list(model.features)
            self.block0 = nn.Sequential(*self.modules[0:12])
            self.block1 = nn.Sequential(*self.modules[12:16])
            self.block2 = nn.Sequential(*self.modules[16:23])
            self.block3 = nn.Sequential(*self.modules[23:30])

    def forward(self, x):
        # 64x64x256
        out0 = self.block0(x)
        # 64x64x256
        out1 = self.block1(out0)
        # 32x32x512
        out2 = self.block2(out1)
        # 16x16x512
        out3 = self.block3(out2)
        return {"out2": out0,
                "out3": out1,
                "out4": out2,
                "out5": out3
                }


class ResNet18(nn.Module):
    def __init__(self, pretrained=False):
        super(ResNet18, self).__init__()
        model = resnet18(pretrained=pretrained)

        modules = list(model.children())
        self.block1 = nn.Sequential(*modules[0:4])
        self.block2 = modules[4]
        self.block3 = modules[5]
        self.block4 = modules[6]
        self.block5 = modules[7]

    def forward(self, x):
        x = self.block1(x)
        # 64x64x64
        out2 = self.block2(x)
        # 32x32x128
        out3 = self.block3(out2)
        # 16x16x256
        out4 = self.block4(out3)
        # 8x8x512
        out5 = self.block5(out4)
        return {"out2": out2,
                "out3": out3,
                "out4": out4,
                "out5": out5
                }


# import copy
# class ResNet18NR(nn.Module):
#     def __init__(self, pretrained=False):
#         super(ResNet18NR, self).__init__()
#         model_nr = resnet18_nr(pretrained=False)
#         modules_nr = list(model_nr.children())


#         self.block2_logvar = copy.copy(nn.Sequential(*modules_nr[0:5]))
#         self.block3_logvar = copy.copy(nn.Sequential(modules_nr[5]))
#         self.block4_logvar = copy.copy(nn.Sequential(modules_nr[6]))

#     def forward(self, x):
#         out2_logvar = self.block2_logvar(x)
#         out3_logvar = self.block3_logvar(out2_logvar)
#         out4_logvar = self.block4_logvar(out3_logvar)
#         # 8x8x512
#         return {"out2": out2_logvar,
#                 "out3": out3_logvar,
#                 "out4": out4_logvar
#                 }



class ResNet34(nn.Module):
    def __init__(self, pretrained=False):
        super(ResNet34, self).__init__()
        model = resnet34(pretrained=pretrained)

        modules = list(model.children())
        self.block1 = nn.Sequential(*modules[0:4])
        self.block2 = modules[4]
        self.block3 = modules[5]
        self.block4 = modules[6]
        self.block5 = modules[7]

    def forward(self, x):
        x = self.block1(x)
        # 64x64x64
        out2 = self.block2(x)
        # 32x32x128
        out3 = self.block3(out2)
        # 16x16x256
        out4 = self.block4(out3)
        # 8x8x512
        out5 = self.block5(out4)
        return {"out2": out2,
                "out3": out3,
                "out4": out4,
                "out5": out5
                }


class KDLoss(nn.Module):
    def __init__(self, loss_type=None):
        super(KDLoss, self).__init__()
        '''
        loss type: l2, l1, consine, l2+consine, l2norm+l2
        '''
        self.type = loss_type

    def get_loss_map(self, feat_T, feat_S):
        '''
        :param feat_T: NxCxHxW
        :param feat_S: NxCxHxW
        :return:
        '''
        if self.type == "l2norm+l2":
            feat_T = F.normalize(feat_T, p=2, dim=1)
            feat_S = F.normalize(feat_S, p=2, dim=1)

            loss_map = 0.5 * ((feat_T - feat_S) ** 2)
            loss_map = torch.sum(loss_map, dim=1)

        elif self.type == "l1norm+l2":
            feat_T = F.normalize(feat_T, p=1, dim=1)
            feat_S = F.normalize(feat_S, p=1, dim=1)

            loss_map = 0.5 * ((feat_T - feat_S) ** 2)
            loss_map = torch.sum(loss_map, dim=1)

        elif self.type == "consine":
            feat_T = F.normalize(feat_T, p=2, dim=1)
            feat_S = F.normalize(feat_S, p=2, dim=1)
            loss_map = 1 - torch.sum(torch.mul(feat_T, feat_S), dim=1)

        elif self.type == "l2":
            loss_map = (feat_T - feat_S) ** 2
            loss_map = torch.sum(loss_map, dim=1)

        elif self.type == "l1":
            loss_map = torch.abs(feat_T - feat_S)
            loss_map = torch.sum(loss_map, dim=1)

        else:
            raise NotImplementedError

        return loss_map

    def forward(self, feat_T, feat_S):
        loss_map = self.get_loss_map(feat_T, feat_S)
        # if self.type == "consine":
        #     return torch.mean(loss_map)
        # else:
        #     return torch.sum(loss_map)
        return torch.sum(torch.mean(loss_map, dim=(1, 2)))
        # if use mean, must increase the learning rate
        # return torch.mean(loss_map)
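

# Illustrative usage of KDLoss (not part of the original file; names are
# arbitrary): with the "l2norm+l2" setting, the per-pixel distance between two
# feature maps serves as an anomaly map, while forward() gives the training loss.
def _demo_kd_loss():
    feat_teacher = torch.rand(2, 256, 32, 32)
    feat_student = torch.rand(2, 256, 32, 32)
    criterion = KDLoss(loss_type="l2norm+l2")
    loss = criterion(feat_teacher, feat_student)                     # scalar training loss
    score_map = criterion.get_loss_map(feat_teacher, feat_student)   # 2x32x32 anomaly map
    print(loss.item(), score_map.shape)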

class Conv_BN_PRelu(nn.Module):
    def __init__(self, in_dim, out_dim, k=1, s=1, p=0, bn=True, prelu=True):
        super(Conv_BN_PRelu, self).__init__()
        self.conv = [
            nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p),
        ]
        if bn:
            self.conv.append(nn.BatchNorm2d(out_dim))
        if prelu:
            self.conv.append(nn.PReLU())

        self.conv = nn.Sequential(*self.conv)

    def forward(self, x):
        return self.conv(x)


class NonLocalAttention(nn.Module):
    def __init__(self, channel=256, reduction=2, rescale=1.0):
        super(NonLocalAttention, self).__init__()
        # self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        # self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        # self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

        self.conv_match1 = Conv_BN_PRelu(channel, channel//reduction, 1, bn=False, prelu=True)
        self.conv_match2 = Conv_BN_PRelu(channel, channel//reduction, 1, bn=False, prelu=True)
        self.conv_assembly = Conv_BN_PRelu(channel, channel, 1, bn=False, prelu=True)
        self.rescale = rescale

    def forward(self, input):
        x_embed_1 = self.conv_match1(input)
        x_embed_2 = self.conv_match2(input)
        x_assembly = self.conv_assembly(input)

        N, C, H, W = x_embed_1.shape
        x_embed_1 = x_embed_1.permute(0, 2, 3, 1).view((N, H*W, C))
        x_embed_2 = x_embed_2.view(N, C, H*W)
        score = torch.matmul(x_embed_1, x_embed_2)
        score = F.softmax(score, dim=2)
        x_assembly = x_assembly.view(N, -1, H*W).permute(0, 2, 1)
        x_final = torch.matmul(score, x_assembly)
        x_final = x_final.permute(0, 2, 1).view(N, -1, H, W)
        return x_final + input*self.rescale
        # return x_final


if __name__ == "__main__":
    x = torch.rand([2, 3, 256, 256])
    # ResNet18NR (commented out above) depends on resnet_no_relu, so the smoke
    # test uses the plain ResNet18 wrapper instead
    T = ResNet18(pretrained=True)
    T(x)
    S = ResNet18(pretrained=False)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os

from skimage import measure
from sklearn.metrics import auc
from loguru import logger

import cv2
import numpy as np

def save_score_map(score_maps, names, class_name, save_path, loss_name=None):
    assert loss_name is not None
    num = len(score_maps)
    os.makedirs(save_path, exist_ok=True)
    score_maps = np.array(score_maps)
    max_mum = np.max(score_maps)
    min_mum = np.min(score_maps)
    score_maps = (score_maps - min_mum) / (max_mum - min_mum) * 255.0
    for _idx in range(num):
        score_map = np.uint8(np.squeeze(score_maps[_idx]))
        score_map = cv2.applyColorMap(score_map, cv2.COLORMAP_JET)

        _name = names[_idx].split('.')[0]
        path0 = os.path.join(save_path, f"{class_name}_{_name}_{loss_name}.png")
        cv2.imwrite(path0, score_map)

def visualize(test_imgs, test_masks, score_maps, names, class_name, save_path, num=100, trans='imagenet'):
    num = min(num, len(test_imgs))
    os.makedirs(save_path, exist_ok=True)
    score_maps = np.array(score_maps)
    max_mum = np.max(score_maps)
    min_mum = np.min(score_maps)
    score_maps = (score_maps - min_mum) / (max_mum - min_mum) * 255.0
    for _idx in range(num):
        test_img = test_imgs[_idx]
        test_img = denormalize(test_img, trans=trans)
        test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)

        test_mask = test_masks[_idx].transpose(1, 2, 0).squeeze()

        score_map = np.uint8(np.squeeze(score_maps[_idx]))
        # score_map = cv2.normalize(score_map, score_map, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        score_map = cv2.applyColorMap(score_map, cv2.COLORMAP_JET)
        # cv2.imshow("score_map", score_map)
        # cv2.waitKey(0)
        # res_img = cv2.addWeighted(test_img, 0.4, score_map, 0.6, 0)
        # test_img = draw_detect(test_img, test_mask)

        name = names[_idx].split('.')[0]
        path0 = os.path.join(save_path, f"{class_name}_{name}.png")
        cv2.imwrite(path0, score_map)

        test_img_mask = draw_detect(test_img, test_mask)
        path1 = os.path.join(save_path, f"{class_name}_{name}_gt.png")
        cv2.imwrite(path1, test_img_mask)


def draw_detect(img, label):
    assert len(label.shape) == 2
    label = np.uint8(label*255)
    _, label = cv2.threshold(label, 5, 255, 0)
    # [-2] keeps this working on both OpenCV 3 (3 return values) and OpenCV 4 (2 return values)
    label_cnts = cv2.findContours(label, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # img = cv2.drawContours(img, label_cnts, -1, (0, 0, 255), 1)
    mask = np.zeros(img.shape)
    mask = cv2.fillPoly(mask, label_cnts, color=(0, 0, 255))
    label = np.expand_dims(255-label, axis=2)/255
    img = img*(np.concatenate([label, label, label], axis=2)) + mask
    img = np.clip(img, 0, 255)
    img = np.uint8(img)

    # cv2.imshow("img", img)
    # cv2.waitKey(0)
    return img


def denormalize(img, trans='imagenet'):
    if trans == 'imagenet':
        std = np.array([0.229, 0.224, 0.225])
        mean = np.array([0.485, 0.456, 0.406])
        x = (((img.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
    elif trans == 'coco':
        std = np.array([0.5, 0.5, 0.5])
        mean = np.array([0.5, 0.5, 0.5])
        x = (((img.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
    elif trans == 'navie':
        x = (img.transpose(1, 2, 0) * 255.).astype(np.uint8)
    elif trans == 'no':
        x = (img.transpose(1, 2, 0)).astype(np.uint8)
    else:
        raise NotImplementedError
    return x

import os
from PIL import Image

from torchvision import transforms as T
def load_image(path):
    transform_x = T.Compose([T.Resize(256, Image.ANTIALIAS),
                             T.CenterCrop(224),
                             T.ToTensor(),
                             T.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])])
    x = Image.open(path).convert('RGB')
    x = transform_x(x)
    x = x.unsqueeze(0)
    return x


def cal_pro_metric_new(labeled_imgs, score_imgs, fpr_thresh=0.3, max_steps=2000, class_name=None):
    labeled_imgs = np.array(labeled_imgs).squeeze(1)
    labeled_imgs[labeled_imgs <= 0.45] = 0
    labeled_imgs[labeled_imgs > 0.45] = 1
    labeled_imgs = labeled_imgs.astype(bool)  # np.bool was removed in recent NumPy
    score_imgs = np.array(score_imgs).squeeze(1)

    max_th = score_imgs.max()
    min_th = score_imgs.min()
    delta = (max_th - min_th) / max_steps

    ious_mean = []
    ious_std = []
    pros_mean = []
    pros_std = []
    threds = []
    fprs = []
    binary_score_maps = np.zeros_like(score_imgs, dtype=bool)
    for step in range(max_steps):
        thred = max_th - step * delta
        # segmentation
        binary_score_maps[score_imgs <= thred] = 0
        binary_score_maps[score_imgs > thred] = 1

        pro = []  # per region overlap
        iou = []  # per image iou
        # pro: find each connected gt region, compute the overlapped pixels between the gt region and predicted region
        # iou: for each image, compute the ratio, i.e. intersection/union between the gt and predicted binary map
        for i in range(len(binary_score_maps)):  # for i-th image
            # pro (per region level)
            label_map = measure.label(labeled_imgs[i], connectivity=2)
            props = measure.regionprops(label_map)
            for prop in props:
                x_min, y_min, x_max, y_max = prop.bbox
                cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
                # cropped_mask = masks[i][x_min:x_max, y_min:y_max]
                cropped_mask = prop.filled_image  # corrected!
                intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum()
                pro.append(intersection / prop.area)
            # iou (per image level)
            intersection = np.logical_and(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum()
            union = np.logical_or(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum()
            if labeled_imgs[i].any():  # skip images whose gt has no anomaly pixels
                iou.append(intersection / union)
        # against steps and average metrics on the testing data
        ious_mean.append(np.array(iou).mean())
        # print("per image mean iou:", np.array(iou).mean())
        ious_std.append(np.array(iou).std())
        pros_mean.append(np.array(pro).mean())
        pros_std.append(np.array(pro).std())
        # fpr for pro-auc
        masks_neg = ~labeled_imgs
        fpr = np.logical_and(masks_neg, binary_score_maps).sum() / masks_neg.sum()
        fprs.append(fpr)
        threds.append(thred)

    # as array
    threds = np.array(threds)
    pros_mean = np.array(pros_mean)
    pros_std = np.array(pros_std)
    fprs = np.array(fprs)


    # default 30% fpr vs pro, pro_auc
    idx = fprs <= fpr_thresh  # find the indexes of fprs that are less than expect_fpr (default 0.3)
    fprs_selected = fprs[idx]
    fprs_selected = rescale(fprs_selected)  # rescale fpr [0, 0.3] -> [0, 1]
    pros_mean_selected = pros_mean[idx]
    pro_auc_score = auc(fprs_selected, pros_mean_selected)
    # print("pro auc ({}% FPR):".format(int(expect_fpr * 100)), pro_auc_score)
    return pro_auc_score



def rescale(x):
    return (x - x.min()) / (x.max() - x.min())
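

# Illustrative sanity check for the PRO metric above (not part of the original
# file; the toy data is made up): a score map that closely matches the ground
# truth should give a PRO-AUC near 1.0.
def _demo_pro_metric():
    gt = np.zeros((4, 1, 64, 64), dtype=np.float32)
    gt[:, :, 20:30, 20:30] = 1.0  # one square defect region per image
    scores = gt + 0.05 * np.random.rand(4, 1, 64, 64).astype(np.float32)
    print(cal_pro_metric_new(list(gt), list(scores), fpr_thresh=0.3, max_steps=200))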

import torch
import random
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False





from thop import profile
from thop import clever_format
def calculate_flops(device, xs, xt, model):
    model = model.eval().to(device)
    xs = xs.to(device)
    xt = xt.to(device)

    flops, params = profile(model, inputs=(xs, xt,))
    flops, params = clever_format([flops, params], "%.3f")
    print(f"[INFO] flops: {flops}")
    print(f"[INFO] params: {params}")
    return flops, params


if __name__ == "__main__":
    # NOTE: requires MDFP_Dual_Norm_AD.py, which is not included in this
    # repository; its DualProjectionNet corresponds to the one in MB-PFM-VGG.py
    from MDFP_Dual_Norm_AD import DualProjectionNet

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    ins = [64, 128, 256]
    outs = [256, 512, 1024]
    latents = [200, 400, 900]
    xs = [torch.rand([1, 64, 64, 64]), torch.rand([1, 128, 32, 32]), torch.rand([1, 256, 16, 16])]
    xt = [torch.rand([1, 256, 64, 64]), torch.rand([1, 512, 32, 32]), torch.rand([1, 1024, 16, 16])]
    for _in, _out, _latent, _xs, _xt in zip(ins, outs, latents, xs, xt):
        model1 = DualProjectionNet(in_dim=_in, out_dim=_out, latent_dim=_latent)
        flops, params = calculate_flops(device, model=model1, xs=_xs, xt=_xt)

--------------------------------------------------------------------------------
/MB-PFM-VGG.py:
--------------------------------------------------------------------------------
'''
Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping
'''
import shutil

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet18, resnet34
from torchvision.models.vgg import vgg16_bn, vgg19_bn
import os
import numpy as np
from sklearn.metrics import roc_auc_score
from utils import visualize, cal_pro_metric_new, set_seed
from loguru import logger
import argparse


class PretrainedModel(nn.Module):
    def __init__(self, model_name):
        super(PretrainedModel, self).__init__()
        if "resnet" in model_name:
            model = eval(model_name)(pretrained=True)
            modules = list(model.children())
            self.block1 = nn.Sequential(*modules[0:4])
            self.block2 = modules[4]
            self.block3 = modules[5]
            self.block4 = modules[6]
            self.block5 = modules[7]
        elif "vgg" in model_name:
            model = eval(model_name)(pretrained=True)
            self.modules = list(model.features)
            if model_name == "vgg16_bn":
                self.block1 = nn.Sequential(*self.modules[0:14])
                self.block2 = nn.Sequential(*self.modules[14:23])
                self.block3 = nn.Sequential(*self.modules[23:33])
                self.block4 = nn.Sequential(*self.modules[33:43])
            else:
                self.block1 = nn.Sequential(*self.modules[0:14])
                self.block2 = nn.Sequential(*self.modules[14:26])
                self.block3 = nn.Sequential(*self.modules[26:39])
                self.block4 = nn.Sequential(*self.modules[39:52])

        else:
            raise NotImplementedError

    def forward(self, x):
        # B x 64 x 64 x 64
        out1 = self.block1(x)
        # B x 128 x 32 x 32
        out2 = self.block2(out1)
        # B x 256 x 16 x 16
        out3 = self.block3(out2)
        out4 = self.block4(out3)
        return {"out2": out2,
                "out3": out3,
                "out4": out4
                }

# x = torch.rand([1, 3, 256, 256])
# # T = ResNetS(model_name='resnet18', pretrained=True)
# T = PretrainedModel(model_name='vgg16_bn')
# y = T(x)
#
# for key in y.keys():
#     print(f"{key}: {y[key].shape}")

class Conv_BN_Relu(nn.Module):
    def __init__(self, in_dim, out_dim, k=1, s=1, p=0, bn=True, relu=True):
        super(Conv_BN_Relu, self).__init__()
        self.conv = [
            nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p),
        ]
        if bn:
            self.conv.append(nn.BatchNorm2d(out_dim))
        if relu:
            self.conv.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*self.conv)

    def forward(self, x):
        return self.conv(x)



class DualProjectionNet(nn.Module):
    def __init__(self, in_dim=512, out_dim=512, latent_dim=256):
        super(DualProjectionNet, self).__init__()
        self.encoder1 = nn.Sequential(*[
            Conv_BN_Relu(in_dim, in_dim//2+latent_dim),
            Conv_BN_Relu(in_dim//2+latent_dim, 2*latent_dim),
            # Conv_BN_Relu(2*latent_dim, latent_dim),
        ])

        self.shared_coder = Conv_BN_Relu(2*latent_dim, latent_dim, bn=False, relu=False)
        self.decoder1 = nn.Sequential(*[
            Conv_BN_Relu(latent_dim, 2*latent_dim),
            Conv_BN_Relu(2*latent_dim, out_dim//2+latent_dim),
            Conv_BN_Relu(out_dim//2+latent_dim, out_dim, bn=False, relu=False),
        ])


        self.encoder2 = nn.Sequential(*[
            Conv_BN_Relu(out_dim, out_dim // 2 + latent_dim),
            Conv_BN_Relu(out_dim // 2 + latent_dim, 2 * latent_dim),
            # Conv_BN_Relu(2 * latent_dim, latent_dim),
        ])

        self.decoder2 = nn.Sequential(*[
            Conv_BN_Relu(latent_dim, 2 * latent_dim),
            Conv_BN_Relu(2 * latent_dim, in_dim // 2 + latent_dim),
            Conv_BN_Relu(in_dim // 2 + latent_dim, in_dim, bn=False, relu=False),
        ])


    def forward(self, xs, xt):
        xt_hat = self.encoder1(xs)
        xt_hat = self.shared_coder(xt_hat)
        xt_hat = self.decoder1(xt_hat)

        xs_hat = self.encoder2(xt)
        xs_hat = self.shared_coder(xs_hat)
        xs_hat = self.decoder2(xs_hat)

        return xs_hat, xt_hat
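

# Shape sanity check for DualProjectionNet (illustration only, not original
# code; the dimensions below are made up): features from two pre-trained
# spaces are mapped onto each other through the shared latent coder.
def _demo_dual_projection():
    net = DualProjectionNet(in_dim=256, out_dim=512, latent_dim=200)
    xs = torch.rand(1, 256, 32, 32)  # hypothetical source feature map
    xt = torch.rand(1, 512, 32, 32)  # hypothetical target feature map
    xs_hat, xt_hat = net(xs, xt)
    print(xs_hat.shape, xt_hat.shape)  # -> [1, 256, 32, 32], [1, 512, 32, 32]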


class DFP_AD(object):
    def __init__(self, type='vgg'):
        if type == "resnet":
            self.Agent1 = PretrainedModel(model_name="resnet18")
            self.Agent2 = PretrainedModel(model_name="resnet34")

        elif type == "vgg":
            self.Agent1 = PretrainedModel(model_name="vgg16_bn")
            self.Agent2 = PretrainedModel(model_name="vgg19_bn")

    def register(self, **kwargs):
        self.class_name = kwargs['class_name']
        self.device = kwargs['device']
        self.trainloader = kwargs['trainloader']
        self.testloader = kwargs['testloader']


        self.projector2 = DualProjectionNet(in_dim=256, out_dim=256, latent_dim=200)
        self.optimizer2 = torch.optim.Adam(self.projector2.parameters(), lr=kwargs["lr2"], weight_decay=kwargs["weight_decay"])
        self.projector3 = DualProjectionNet(in_dim=512, out_dim=512, latent_dim=400)
        self.optimizer3 = torch.optim.Adam(self.projector3.parameters(), lr=kwargs["lr3"], weight_decay=kwargs["weight_decay"])
        self.projector4 = DualProjectionNet(in_dim=512, out_dim=512, latent_dim=400)
        self.optimizer4 = torch.optim.Adam(self.projector4.parameters(), lr=kwargs["lr4"], weight_decay=kwargs["weight_decay"])

        self.Agent1.to(self.device).eval()
        self.Agent2.to(self.device).eval()
        self.projector2.to(self.device)
        self.projector3.to(self.device)
        self.projector4.to(self.device)

        self.save_root = "./result/MB-PFM-VGG_{}/".format(kwargs["seed"])
        os.makedirs(os.path.join(self.save_root, "ckpt"), exist_ok=True)
        self.ckpt2 = os.path.join(self.save_root, "ckpt/{}_2.pth".format(kwargs["class_name"]))
        self.ckpt3 = os.path.join(self.save_root, "ckpt/{}_3.pth".format(kwargs["class_name"]))
        self.ckpt4 = os.path.join(self.save_root, "ckpt/{}_4.pth".format(kwargs["class_name"]))
        os.makedirs(os.path.join(self.save_root, "tblogs"), exist_ok=True)
        self.tblog = os.path.join(self.save_root, "tblogs/{}".format(kwargs["class_name"]))


    def get_agent_out(self, x):
        out_a1 = self.Agent1(x)
        out_a2 = self.Agent2(x)
        for key in out_a2.keys():
            out_a1[key] = F.normalize(out_a1[key], p=2)
            out_a2[key] = F.normalize(out_a2[key], p=2)
        return out_a1, out_a2

    def train(self, epochs=100):
        if not os.path.exists(self.ckpt2):
            if os.path.exists(self.tblog):
                shutil.rmtree(self.tblog)
            os.makedirs(self.tblog, exist_ok=True)
            self.writer = SummaryWriter(log_dir=self.tblog)
            for ep in range(0, epochs):
                self.projector2.train()
                self.projector3.train()
                self.projector4.train()
                for i, (x, _, _, _) in enumerate(self.trainloader):
                    x = x.to(self.device)
                    out_a1, out_a2 = self.get_agent_out(x)

                    # project_out2 = self.projector2(out_a1["out2"].detach())
                    # loss2 = torch.mean((out_a2["out2"].detach() - project_out2) ** 2)
                    project_out21, project_out22 = self.projector2(out_a1["out2"].detach(), out_a2["out2"].detach())
                    # detach the frozen agent features so gradients only reach
                    # the projector (same as loss3/loss4 below)
                    loss21 = torch.mean((out_a1["out2"].detach() - project_out21) ** 2)
                    loss22 = torch.mean((out_a2["out2"].detach() - project_out22) ** 2)
                    loss2 = loss21 + loss22
                    self.optimizer2.zero_grad()
                    loss2.backward()
                    self.optimizer2.step()

                    project_out31, project_out32 = self.projector3(out_a1["out3"].detach(), out_a2["out3"].detach())
                    loss31 = torch.mean((out_a1["out3"].detach() - project_out31) ** 2)
                    loss32 = torch.mean((out_a2["out3"].detach() - project_out32) ** 2)
                    loss3 = loss31 + loss32
                    self.optimizer3.zero_grad()
                    loss3.backward()
                    self.optimizer3.step()

                    project_out41, project_out42 = self.projector4(out_a1["out4"].detach(), out_a2["out4"].detach())
                    loss41 = torch.mean((out_a1["out4"].detach() - project_out41) ** 2)
                    loss42 = torch.mean((out_a2["out4"].detach() - project_out42) ** 2)
                    loss4 = loss41 + loss42
                    self.optimizer4.zero_grad()
                    loss4.backward()
                    self.optimizer4.step()

                    print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss2: {loss2.item():.6f} | loss3: {loss3.item():.6f} | loss4: {loss4.item():.6f}")
                    self.writer.add_scalar('Train/loss2', loss2.item(), ep*len(self.trainloader)+i)
                    self.writer.add_scalar('Train/loss3', loss3.item(), ep*len(self.trainloader)+i)
                    self.writer.add_scalar('Train/loss4', loss4.item(), ep*len(self.trainloader)+i)
                    # print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss: {loss.item():.5f} | loss1: {loss1.item():.5f} | loss2: {loss2.item():.5f}")
                    # self.writer.add_scalar('Train/loss_l2', loss1.item(), ep*len(self.trainloader)+i)
                    # self.writer.add_scalar('Train/loss_norm_l2', loss2.item(), ep*len(self.trainloader)+i)

                torch.save(self.projector2.state_dict(), self.ckpt2)
                torch.save(self.projector3.state_dict(), self.ckpt3)
                torch.save(self.projector4.state_dict(), self.ckpt4)
                if ep % 10 == 0:
                    metrix = self.test(cal_pro=False)
                    logger.info(f"Epoch-{ep}, {self.class_name} | all: {metrix['all'][0]:.5f}, {metrix['all'][1]:.5f} | 2: {metrix['2'][0]:.5f}, {metrix['2'][1]:.5f}"
                                f"| 3: {metrix['3'][0]:.5f}, {metrix['3'][1]:.5f} | 4: {metrix['4'][0]:.5f}, {metrix['4'][1]:.5f}")
                    self.writer.add_scalar('Val/imge_auc2', metrix['2'][0], ep)
                    self.writer.add_scalar('Val/pixel_auc2', metrix['2'][1], ep)
                    self.writer.add_scalar('Val/imge_auc3', metrix['3'][0], ep)
                    self.writer.add_scalar('Val/pixel_auc3', metrix['3'][1], ep)
                    self.writer.add_scalar('Val/imge_auc4', metrix['4'][0], ep)
                    self.writer.add_scalar('Val/pixel_auc4', metrix['4'][1], ep)
                    self.writer.add_scalar('Val/imge_auc', metrix['all'][0], ep)
                    self.writer.add_scalar('Val/pixel_auc', metrix['all'][1], ep)
            self.writer.close()
        else:
            pass

    def test(self, cal_pro=False):
        self.load_project_model()
        self.projector2.eval()
        self.projector3.eval()
        self.projector4.eval()
        with torch.no_grad():

            test_y_list = []
            test_mask_list = []
            test_img_list = []
            test_img_name_list = []
            # pixel-level
            score_map_list = []
            score_list = []

            score2_map_list = []
            score2_list = []
            score3_map_list = []
            score3_list = []
            score4_map_list = []
            score4_list = []

            for x, y, mask, name in self.testloader:
                test_y_list.extend(y.detach().cpu().numpy())
                test_mask_list.extend(mask.detach().cpu().numpy())
                test_img_list.extend(x.detach().cpu().numpy())
                test_img_name_list.extend(name)

                x = x.to(self.device)
                _, _, H, W = x.shape
                out_a1, out_a2 = self.get_agent_out(x)

                project_out21, project_out22 = self.projector2(out_a1["out2"], out_a2["out2"])
                loss21_map = torch.sum((out_a1["out2"] - project_out21) ** 2, dim=1, keepdim=True)
                loss22_map = torch.sum((out_a2["out2"] - project_out22) ** 2, dim=1, keepdim=True)
                loss2_map = (loss21_map + loss22_map) / 2.0
                score2_map = F.interpolate(loss2_map, size=(H, W), mode='bilinear', align_corners=False)
                score2_map = score2_map.cpu().detach().numpy()
                score2_map_list.extend(score2_map)
                score2_list.extend(np.squeeze(np.max(np.max(score2_map, axis=2), axis=2), 1))

                project_out31, project_out32 = self.projector3(out_a1["out3"], out_a2["out3"])
                loss31_map = torch.sum((out_a1["out3"] - project_out31) ** 2, dim=1, keepdim=True)
                loss32_map = torch.sum((out_a2["out3"] - project_out32) ** 2, dim=1, keepdim=True)
                loss3_map = (loss31_map + loss32_map) / 2.0
                score3_map = F.interpolate(loss3_map, size=(H, W), mode='bilinear', align_corners=False)
                score3_map = score3_map.cpu().detach().numpy()
                score3_map_list.extend(score3_map)
                score3_list.extend(np.squeeze(np.max(np.max(score3_map, axis=2), axis=2), 1))


                project_out41, project_out42 = self.projector4(out_a1["out4"], out_a2["out4"])
                loss41_map = torch.sum((out_a1["out4"] - project_out41) ** 2, dim=1, keepdim=True)
                loss42_map = torch.sum((out_a2["out4"] - project_out42) ** 2, dim=1, keepdim=True)
                loss4_map = (loss41_map + loss42_map) / 2.0
                score4_map = F.interpolate(loss4_map, size=(H, W), mode='bilinear', align_corners=False)
                score4_map = score4_map.cpu().detach().numpy()
                score4_map_list.extend(score4_map)
                score4_list.extend(np.squeeze(np.max(np.max(score4_map, axis=2), axis=2), 1))

                score_map = (score4_map + score3_map + score2_map) / 3
                score_map_list.extend(score_map)
                score_list.extend(np.squeeze(np.max(np.max(score_map, axis=2), axis=2), 1))

            visualize(test_img_list, test_mask_list, score_map_list, test_img_name_list, self.class_name,
                      f"{self.save_root}image/", 10000)
            # ROCAUC
            # imge_auc2, pixel_auc2, pixel_pro2 = self.cal_auc(score2_list, score2_map_list, test_y_list, test_mask_list)
            # imge_auc3, pixel_auc3, pixel_pro3 = self.cal_auc(score3_list, score3_map_list, test_y_list, test_mask_list)
            # imge_auc4, pixel_auc4, pixel_pro4 = self.cal_auc(score4_list, score4_map_list, test_y_list, test_mask_list)
            imge_auc, pixel_auc, pixel_pro = self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list)
            # print(f"pixel AUC: {pixel_level_ROCAUC:.5f}")
            # metrix = {"2": [imge_auc2, pixel_auc2, pixel_pro2], "3": [imge_auc3, pixel_auc3, pixel_pro3], "4": [imge_auc4,
            #           pixel_auc4, pixel_pro4], "all": [imge_auc, pixel_auc, pixel_pro]}
            metrix = {"all": [imge_auc, pixel_auc, pixel_pro]}
            return metrix

    def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list):
flatten_y_list = np.array(test_y_list).ravel()
325 | flatten_score_list = np.array(score_list).ravel()
326 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list)
327 | 
328 | flatten_mask_list = np.concatenate(test_mask_list).ravel()
329 | flatten_score_map_list = np.concatenate(score_map_list).ravel()
330 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list)
331 | pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3)
332 | return image_level_ROCAUC, pixel_level_ROCAUC, pro_auc_score
333 | 
334 | def load_project_model(self):
335 | self.projector2.load_state_dict(torch.load(self.ckpt2))
336 | self.projector3.load_state_dict(torch.load(self.ckpt3))
337 | self.projector4.load_state_dict(torch.load(self.ckpt4))
338 | 
339 | 
340 | 
341 | def parse_args():
342 | parser = argparse.ArgumentParser('MB-PFM-VGG')
343 | parser.add_argument("--seed", type=int, default=888)
344 | parser.add_argument("--gpu_id", type=str, default="0")
345 | parser.add_argument("--train", action='store_true')  # a flag: pass --train to train, omit it to test only (the old type=bool made "--train" require a value and broke the run scripts)
346 | 
347 | parser.add_argument("--data_trans", type=str, default='imagenet', choices=['naive', 'imagenet'])
348 | 
349 | parser.add_argument("--loss_type", type=str, default='l2norm+l2',
350 | choices=['l2norm+l2', 'l2', 'l1', 'cosine', 'l2+cosine'])
351 | 
352 | parser.add_argument("--model_name", type=str, default='vgg16', choices=['vgg16', 'resnet18'])
353 | parser.add_argument("--epochs", type=int, default=200)
354 | parser.add_argument("--batch_size", type=int, default=8) # 6 or 20 for train
355 | parser.add_argument("--lr2", type=float, default=3e-3)
356 | parser.add_argument("--lr3", type=float, default=3e-4)
357 | parser.add_argument("--lr4", type=float, default=3e-4)
358 | parser.add_argument("--weight_decay", type=float, default=1e-5)
359 | 
360 | parser.add_argument("--latent_dim", type=int, default=200)
361 | 
362 | parser.add_argument("--data_root", type=str, default="D:/Dataset/mvtec_anomaly_detection/")
363 | parser.add_argument("--resize", type=int, default=256)
364 | 
365 | parser.add_argument("--post_smooth", type=int, default=0)
366 | args = parser.parse_args()
367 | return args
368 | 
369 | 
370 | if __name__ == "__main__":
371 | 
372 | args = parse_args()
373 | set_seed(args.seed)
374 | logger.add(
375 | f'./result/MB-PFM-VGG_{args.seed}/logger-{args.data_trans}-{args.loss_type}-{args.resize}-{args.model_name}.txt',
376 | rotation="200 MB",
377 | backtrace=True,
378 | diagnose=True)
379 | logger.info(str(args))
380 | 
381 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
382 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
383 | 
384 | from dataset.mvtec import MVTecDataset, MVTec_CLASS_NAMES
385 | from torch.utils.data import DataLoader
386 | from torchvision import transforms as T
387 | from PIL import Image
388 | 
389 | if args.data_trans == 'naive':
390 | trans_x = T.Compose([T.Resize(args.resize, Image.ANTIALIAS),
391 | T.ToTensor()])
392 | else:
393 | trans_x = T.Compose([T.Resize(args.resize, Image.ANTIALIAS),
394 | T.ToTensor(),
395 | T.Normalize(mean=[0.485, 0.456, 0.406],
396 | std=[0.229, 0.224, 0.225])])
397 | 
398 | image_aucs = []
399 | pixel_aucs = []
400 | pro_30s = []
401 | for class_name in MVTec_CLASS_NAMES:
402 | torch.cuda.empty_cache()
403 | trainset = MVTecDataset(root_path=args.data_root, is_train=True, class_name=class_name, resize=args.resize,
404 | trans=trans_x)
405 | trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True,
406 | num_workers=4)
407 | 
408 | testset = MVTecDataset(root_path=args.data_root, is_train=False, class_name=class_name, resize=args.resize,
409 | trans=trans_x)
410 | testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
411 | 
412 | model = DFP_AD(type="vgg")
413 | model.register(class_name=class_name, trainloader=trainloader, testloader=testloader, loss_type=args.loss_type,
414 | data_trans=args.data_trans, size=args.resize, device=device,
415 | latent_dim=args.latent_dim,
416 | lr2=args.lr2, lr3=args.lr3, lr4=args.lr4, weight_decay=args.weight_decay,
417 | seed=args.seed)
418 | if args.train:
419 | model.train(epochs=args.epochs)
420 | # else:
421 | # model.load_student_weight()
422 | metrix = model.test(cal_pro=True)
423 | image_aucs.append(metrix["all"][0])
424 | pixel_aucs.append(metrix["all"][1])
425 | pro_30s.append(metrix["all"][2])
426 | logger.info(f"{class_name}, image auc: {metrix['all'][0]:.5f}, pixel auc: {metrix['all'][1]:.5f}, pixel pro0.3: {metrix['all'][2]:.5f}")
427 | 
428 | i_auc = np.mean(np.array(image_aucs))
429 | p_auc = np.mean(np.array(pixel_aucs))
430 | pro_auc = np.mean(np.array(pro_30s))
431 | logger.info(f"total, image AUC: {i_auc:.5f} | pixel AUC: {p_auc:.5f} | pixel PRO0.3: {pro_auc:.5f}")
432 | 
--------------------------------------------------------------------------------
/MB-PFM-ResNet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping
3 | '''
4 | import shutil
5 | 
6 | from torch.utils.tensorboard import SummaryWriter
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101
11 | from torchvision.models.vgg import vgg16_bn, vgg19_bn
12 | import os
13 | import numpy as np
14 | from sklearn.metrics import roc_auc_score
15 | from utils import visualize, set_seed, cal_pro_metric_new
16 | from loguru import logger
17 | import argparse
18 | import time
19 | from scipy.ndimage import gaussian_filter
20 | import matplotlib.pyplot as plt
21 | 
22 | class PretrainedModel(nn.Module):
23 | def __init__(self, model_name):
24 | super(PretrainedModel, self).__init__()
25 | if "resnet" in model_name:
26 | model = eval(model_name)(pretrained=True)
27 | modules = list(model.children())
28 | self.block1 = nn.Sequential(*modules[0:4])
29 | self.block2 = modules[4]
30 | self.block3 = modules[5]
31 | self.block4 = modules[6]
32 | self.block5 = modules[7]
33 | elif "vgg" in model_name:
34 | # fix: the VGG conv layers live in model.features; the old code sliced self.modules, which is an nn.Module method and crashes
35 | features = list(eval(model_name)(pretrained=True).features)
36 | if model_name == "vgg16_bn":
37 | self.block1 = nn.Sequential(*features[0:14])
38 | self.block2 = nn.Sequential(*features[14:23])
39 | self.block3 = nn.Sequential(*features[23:33])
40 | self.block4 = nn.Sequential(*features[33:43])
41 | else:
42 | self.block1 = nn.Sequential(*features[0:14])
43 | self.block2 = nn.Sequential(*features[14:26])
44 | self.block3 = nn.Sequential(*features[26:39])
45 | self.block4 = nn.Sequential(*features[39:52])
46 | else:
47 | raise NotImplementedError
48 | 
49 | def forward(self, x):
50 | # out1: stem features at 1/4 input resolution (B x 64 x 64 x 64 for a 256x256 input)
51 | out1 = self.block1(x)
52 | # out2: 1/4 resolution (B x 256 x 64 x 64 for resnet50)
53 | out2 = self.block2(out1)
54 | # out3: 1/8 resolution (B x 512 x 32 x 32 for resnet50)
55 | # (the shape comments describe the default resnet50/resnet101 agents)
56 | out3 = self.block3(out2)
57 | # out4: 1/16 resolution (B x 1024 x 16 x 16 for resnet50)
58 | out4 = self.block4(out3)
59 | return {"out2": out2,
60 | "out3": out3,
61 | "out4": out4
62 | }
63 | 
64 | class Conv_BN_Relu(nn.Module):
65 | def __init__(self, in_dim, out_dim, k=1,
s=1, p=0, bn=True, relu=True): 66 | super(Conv_BN_Relu, self).__init__() 67 | self.conv = [ 68 | nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p), 69 | ] 70 | if bn: 71 | self.conv.append(nn.BatchNorm2d(out_dim)) 72 | if relu: 73 | self.conv.append(nn.ReLU(inplace=True)) 74 | 75 | self.conv = nn.Sequential(*self.conv) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | 80 | 81 | class DualProjectionNet(nn.Module): 82 | def __init__(self, in_dim=512, out_dim=512, latent_dim=256): 83 | super(DualProjectionNet, self).__init__() 84 | self.encoder1 = nn.Sequential(*[ 85 | Conv_BN_Relu(in_dim, in_dim//2+latent_dim), 86 | Conv_BN_Relu(in_dim//2+latent_dim, 2*latent_dim), 87 | # Conv_BN_Relu(2*latent_dim, latent_dim), 88 | ]) 89 | 90 | self.shared_coder = Conv_BN_Relu(2*latent_dim, latent_dim, bn=False, relu=False) 91 | 92 | self.decoder1 = nn.Sequential(*[ 93 | Conv_BN_Relu(latent_dim, 2*latent_dim), 94 | Conv_BN_Relu(2*latent_dim, out_dim//2+latent_dim), 95 | Conv_BN_Relu(out_dim//2+latent_dim, out_dim, bn=False, relu=False), 96 | ]) 97 | 98 | 99 | self.encoder2 = nn.Sequential(*[ 100 | Conv_BN_Relu(out_dim, out_dim // 2 + latent_dim), 101 | Conv_BN_Relu(out_dim // 2 + latent_dim, 2 * latent_dim), 102 | # Conv_BN_Relu(2 * latent_dim, latent_dim), 103 | ]) 104 | 105 | self.decoder2 = nn.Sequential(*[ 106 | Conv_BN_Relu(latent_dim, 2 * latent_dim), 107 | Conv_BN_Relu(2 * latent_dim, in_dim // 2 + latent_dim), 108 | Conv_BN_Relu(in_dim // 2 + latent_dim, in_dim, bn=False, relu=False), 109 | ]) 110 | 111 | 112 | def forward(self, xs, xt): 113 | xt_hat = self.encoder1(xs) 114 | xt_hat = self.shared_coder(xt_hat) 115 | xt_hat = self.decoder1(xt_hat) 116 | 117 | xs_hat = self.encoder2(xt) 118 | xs_hat = self.shared_coder(xs_hat) 119 | xs_hat = self.decoder2(xs_hat) 120 | 121 | return xs_hat, xt_hat 122 | 123 | 124 | class DFP_AD(object): 125 | def __init__(self, agent_S='resnet50', agent_T="resnet101"): 126 | self.s_name = agent_S 127 | self.t_name = agent_T 128 | if agent_S == "resnet18" or agent_S == "resnet34": 129 | self.Agent1 = PretrainedModel(model_name=agent_S) 130 | self.indim = [64, 128, 256] 131 | # self.outdim = [50, 100, 200] 132 | elif agent_S == "resnet50": 133 | self.Agent1 = PretrainedModel(model_name=agent_S) 134 | self.indim = [256, 512, 1024] 135 | # self.Agent2 = PretrainedModel(model_name="resnet34") 136 | if agent_T == "resnet50" or agent_T == "resnet101": 137 | # self.Agent1 = PretrainedModel(model_name="vgg16") 138 | self.Agent2 = PretrainedModel(model_name=agent_T) 139 | self.outdim = [256, 512, 1024] 140 | self.latent_dim = [200, 400, 900] 141 | 142 | def register(self, **kwargs): 143 | self.class_name = kwargs['class_name'] 144 | self.device = kwargs['device'] 145 | self.trainloader = kwargs['trainloader'] 146 | self.testloader = kwargs['testloader'] 147 | 148 | self.projector2 = DualProjectionNet(in_dim=self.indim[0], out_dim=self.outdim[0], latent_dim=self.latent_dim[0]) 149 | self.optimizer2 = torch.optim.Adam(self.projector2.parameters(), lr=kwargs["lr2"], weight_decay=kwargs["weight_decay"]) 150 | self.projector3 = DualProjectionNet(in_dim=self.indim[1], out_dim=self.outdim[1], latent_dim=self.latent_dim[1]) 151 | self.optimizer3 = torch.optim.Adam(self.projector3.parameters(), lr=kwargs["lr3"], weight_decay=kwargs["weight_decay"]) 152 | self.projector4 = DualProjectionNet(in_dim=self.indim[2], out_dim=self.outdim[2], latent_dim=self.latent_dim[2]) 153 | self.optimizer4 = torch.optim.Adam(self.projector4.parameters(), lr=kwargs["lr4"], 
weight_decay=kwargs["weight_decay"])
154 | 
155 | self.Agent1.to(self.device).eval()
156 | self.Agent2.to(self.device).eval()
157 | 
158 | self.projector2.to(self.device)
159 | self.projector3.to(self.device)
160 | self.projector4.to(self.device)
161 | 
162 | self.save_root = "./result/MB-PFM_{}-{}_{}/".format(self.s_name, self.t_name, kwargs["seed"])
163 | os.makedirs(os.path.join(self.save_root, "ckpt"), exist_ok=True)
164 | self.ckpt2 = os.path.join(self.save_root, "ckpt/{}_2.pth".format(kwargs["class_name"]))
165 | self.ckpt3 = os.path.join(self.save_root, "ckpt/{}_3.pth".format(kwargs["class_name"]))
166 | self.ckpt4 = os.path.join(self.save_root, "ckpt/{}_4.pth".format(kwargs["class_name"]))
167 | os.makedirs(os.path.join(self.save_root, "tblogs"), exist_ok=True)
168 | self.tblog = os.path.join(self.save_root, "tblogs/{}".format(kwargs["class_name"]))
169 | 
170 | 
171 | def get_agent_out(self, x):
172 | out_a1 = self.Agent1(x)
173 | out_a2 = self.Agent2(x)
174 | for key in out_a2.keys():
175 | out_a1[key] = F.normalize(out_a1[key], p=2)
176 | out_a2[key] = F.normalize(out_a2[key], p=2)
177 | return out_a1, out_a2
178 | 
179 | 
180 | def train(self, epochs=100):
181 | if not os.path.exists(self.ckpt2):
182 | if os.path.exists(self.tblog):
183 | shutil.rmtree(self.tblog)
184 | os.makedirs(self.tblog, exist_ok=True)
185 | self.writer = SummaryWriter(log_dir=self.tblog)
186 | for ep in range(0, epochs):
187 | self.projector2.train()
188 | self.projector3.train()
189 | self.projector4.train()
190 | for i, (x, _, _, _) in enumerate(self.trainloader):
191 | x = x.to(self.device)
192 | out_a1, out_a2 = self.get_agent_out(x)
193 | 
194 | # project_out2 = self.projector2(out_a1["out2"].detach())
195 | # loss2 = torch.mean((out_a2["out2"].detach() - project_out2) ** 2)
196 | project_out21, project_out22 = self.projector2(out_a1["out2"].detach(), out_a2["out2"].detach())
197 | loss21 = torch.mean((out_a1["out2"].detach() - project_out21) ** 2)  # targets detached, as at levels 3 and 4, so gradients only update the projector
198 | loss22 = torch.mean((out_a2["out2"].detach() - project_out22) ** 2)
199 | loss2 = loss21 + loss22
200 | self.optimizer2.zero_grad()
201 | loss2.backward()
202 | self.optimizer2.step()
203 | 
204 | project_out31, project_out32 = self.projector3(out_a1["out3"].detach(), out_a2["out3"].detach())
205 | loss31 = torch.mean((out_a1["out3"].detach() - project_out31) ** 2)
206 | loss32 = torch.mean((out_a2["out3"].detach() - project_out32) ** 2)
207 | loss3 = loss31 + loss32
208 | self.optimizer3.zero_grad()
209 | loss3.backward()
210 | self.optimizer3.step()
211 | 
212 | project_out41, project_out42 = self.projector4(out_a1["out4"].detach(), out_a2["out4"].detach())
213 | loss41 = torch.mean((out_a1["out4"].detach() - project_out41) ** 2)
214 | loss42 = torch.mean((out_a2["out4"].detach() - project_out42) ** 2)
215 | loss4 = loss41 + loss42
216 | self.optimizer4.zero_grad()
217 | loss4.backward()
218 | self.optimizer4.step()
219 | 
220 | print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss2: {loss2.item():.6f} | loss3: {loss3.item():.6f} | loss4: {loss4.item():.6f}")
221 | self.writer.add_scalar('Train/loss2', loss2.item(), ep*len(self.trainloader)+i)
222 | self.writer.add_scalar('Train/loss3', loss3.item(), ep*len(self.trainloader)+i)
223 | self.writer.add_scalar('Train/loss4', loss4.item(), ep*len(self.trainloader)+i)
224 | # print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss: {loss.item():.5f} | loss1: {loss1.item():.5f} | loss2: {loss2.item():.5f}")
225 | # self.writer.add_scalar('Train/loss_l2', loss1.item(), ep*len(self.trainloader)+i)
226 | # 
self.writer.add_scalar('Train/loss_norm_l2', loss2.item(), ep*len(self.trainloader)+i)
227 | 
228 | torch.save(self.projector2.state_dict(), self.ckpt2)
229 | torch.save(self.projector3.state_dict(), self.ckpt3)
230 | torch.save(self.projector4.state_dict(), self.ckpt4)
231 | if ep % 10 == 0:
232 | metrix = self.test(cal_pro=False)
233 | logger.info(f"Epoch-{ep}, {self.class_name} | all: {metrix['all'][0]:.5f}, {metrix['all'][1]:.5f}")
234 | # test() currently returns only the fused 'all' entry (plus 'time'); the per-level metrics ('2', '3', '4') are commented out there, so logging them would raise a KeyError
235 | # self.writer.add_scalar('Val/image_auc2', metrix['2'][0], ep)
236 | # self.writer.add_scalar('Val/pixel_auc2', metrix['2'][1], ep)
237 | # self.writer.add_scalar('Val/image_auc3', metrix['3'][0], ep)
238 | # self.writer.add_scalar('Val/pixel_auc3', metrix['3'][1], ep)
239 | # self.writer.add_scalar('Val/image_auc4', metrix['4'][0], ep)
240 | # self.writer.add_scalar('Val/pixel_auc4', metrix['4'][1], ep)
241 | self.writer.add_scalar('Val/image_auc', metrix['all'][0], ep)
242 | self.writer.add_scalar('Val/pixel_auc', metrix['all'][1], ep)
243 | self.writer.close()
244 | else:
245 | pass
246 | 
247 | 
248 | def statistic_var(self, c=False):  # c=True: estimate per-channel error variances on the training set; c=False: use constant 1 (no scaling)
249 | if c:
250 | self.var21 = 0
251 | self.var22 = 0
252 | self.var31 = 0
253 | self.var32 = 0
254 | self.var41 = 0
255 | self.var42 = 0
256 | with torch.no_grad():
257 | for i, (x, _, _, _) in enumerate(self.trainloader):
258 | torch.cuda.empty_cache()
259 | x = x.to(self.device)
260 | out_a1, out_a2 = self.get_agent_out(x)
261 | project_out21, project_out22 = self.projector2(out_a1["out2"], out_a2["out2"])
262 | var21 = (out_a1["out2"] - project_out21) ** 2
263 | var22 = (out_a2["out2"] - project_out22) ** 2
264 | self.var21 += torch.mean(var21, dim=0, keepdim=True)
265 | self.var22 += torch.mean(var22, dim=0, keepdim=True)
266 | 
267 | project_out31, project_out32 = self.projector3(out_a1["out3"], out_a2["out3"])
268 | var31 = (out_a1["out3"] - project_out31) ** 2
269 | var32 = (out_a2["out3"] - project_out32) ** 2
270 | self.var31 += torch.mean(var31, dim=0, keepdim=True)
271 | self.var32 += torch.mean(var32, dim=0, keepdim=True)
272 | 
273 | project_out41, project_out42 = self.projector4(out_a1["out4"], out_a2["out4"])
274 | var41 = (out_a1["out4"] - project_out41) ** 2
275 | var42 = (out_a2["out4"] - project_out42) ** 2
276 | self.var41 += torch.mean(var41, dim=0, keepdim=True)
277 | self.var42 += torch.mean(var42, dim=0, keepdim=True)
278 | self.var21 /= len(self.trainloader)
279 | self.var22 /= len(self.trainloader)
280 | self.var31 /= len(self.trainloader)
281 | self.var32 /= len(self.trainloader)
282 | self.var41 /= len(self.trainloader)
283 | self.var42 /= len(self.trainloader)  # this normalization was missing for var42
284 | else:
285 | self.var21 = 1
286 | self.var22 = 1
287 | self.var31 = 1
288 | self.var32 = 1
289 | self.var41 = 1
290 | self.var42 = 1
291 | 
292 | 
293 | def test(self, cal_pro=False):
294 | self.load_project_model()
295 | self.projector2.eval()
296 | self.projector3.eval()
297 | self.projector4.eval()
298 | 
299 | self.statistic_var()
300 | 
301 | with torch.no_grad():
302 | 
303 | test_y_list = []
304 | test_mask_list = []
305 | test_img_list = []
306 | test_img_name_list = []
307 | # pixel-level
308 | score_map_list = []
309 | score_list = []
310 | 
311 | score2_map_list = []
312 | score2_list = []
313 | score3_map_list = []
314 | score3_list = []
315 | score4_map_list = []
316 | score4_list = []
317 | 
318 | start_t = time.time()
319 | for x, y, mask, name in self.testloader:
320 | test_y_list.extend(y.detach().cpu().numpy())
321 | 
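# accumulate labels, masks, raw images and file names; they are consumed below by cal_auc() and visualize()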
test_mask_list.extend(mask.detach().cpu().numpy()) 322 | test_img_list.extend(x.detach().cpu().numpy()) 323 | test_img_name_list.extend(name) 324 | 325 | x = x.to(self.device) 326 | _, _, H, W = x.shape 327 | out_a1, out_a2 = self.get_agent_out(x) 328 | 329 | project_out21, project_out22 = self.projector2(out_a1["out2"], out_a2["out2"]) 330 | loss21_map = torch.sum((out_a1["out2"] - project_out21) ** 2 / self.var21, dim=1, keepdim=True) 331 | loss22_map = torch.sum((out_a2["out2"] - project_out22) ** 2 / self.var22, dim=1, keepdim=True) 332 | 333 | loss2_map = (loss21_map + loss22_map) / 2.0 334 | score2_map = F.interpolate(loss2_map, size=(H, W), mode='bilinear', align_corners=False) 335 | score2_map = score2_map.cpu().detach().numpy() 336 | score2_map_list.extend(score2_map) 337 | score2_list.extend(np.squeeze(np.max(np.max(score2_map, axis=2), axis=2), 1)) 338 | 339 | project_out31, project_out32 = self.projector3(out_a1["out3"], out_a2["out3"]) 340 | loss31_map = torch.sum((out_a1["out3"] - project_out31) ** 2 / self.var31, dim=1, keepdim=True) 341 | loss32_map = torch.sum((out_a2["out3"] - project_out32) ** 2 / self.var32, dim=1, keepdim=True) 342 | 343 | loss3_map = (loss31_map + loss32_map) / 2.0 344 | score3_map = F.interpolate(loss3_map, size=(H, W), mode='bilinear', align_corners=False) 345 | score3_map = score3_map.cpu().detach().numpy() 346 | score3_map_list.extend(score3_map) 347 | score3_list.extend(np.squeeze(np.max(np.max(score3_map, axis=2), axis=2), 1)) 348 | 349 | 350 | project_out41, project_out42 = self.projector4(out_a1["out4"], out_a2["out4"]) 351 | loss41_map = torch.sum((out_a1["out4"] - project_out41) ** 2 / self.var41, dim=1, keepdim=True) 352 | loss42_map = torch.sum((out_a2["out4"] - project_out42) ** 2 / self.var42, dim=1, keepdim=True) 353 | 354 | loss4_map = (loss41_map + loss42_map) / 2.0 355 | score4_map = F.interpolate(loss4_map, size=(H, W), mode='bilinear', align_corners=False) 356 | score4_map = score4_map.cpu().detach().numpy() 357 | score4_map_list.extend(score4_map) 358 | score4_list.extend(np.squeeze(np.max(np.max(score4_map, axis=2), axis=2), 1)) 359 | 360 | score_map = (score4_map + score3_map + score2_map) / 3 361 | # score_map = gaussian_filter(score_map.squeeze(), sigma=4) 362 | 363 | score_map_list.extend(score_map) 364 | # score_list.extend(np.squeeze(np.max(np.max(score_map, axis=2), axis=2), 1)) 365 | score_map = np.reshape(score_map, (score_map.shape[0], -1)) 366 | score_list.extend(np.max(score_map, 1)) 367 | 368 | end_t = time.time() 369 | t_per_imge = end_t - start_t 370 | t_per_imge = t_per_imge / len(score_list) 371 | 372 | visualize(test_img_list, test_mask_list, score_map_list, test_img_name_list, self.class_name, 373 | f"{self.save_root}image/", 10000) 374 | 375 | # ROCAUC 376 | # imge_auc2, pixel_auc2, pixel_pro2 = self.cal_auc(score2_list, score2_map_list, test_y_list, test_mask_list) 377 | # imge_auc3, pixel_auc3, pixel_pro3 = self.cal_auc(score3_list, score3_map_list, test_y_list, test_mask_list) 378 | # imge_auc4, pixel_auc4, pixel_pro4 = self.cal_auc(score4_list, score4_map_list, test_y_list, test_mask_list) 379 | imge_auc, pixel_auc, pixel_pro = self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list) 380 | # print(f"pixel AUC: {pixel_level_ROCAUC:.5f}") 381 | # metrix = {"2": [imge_auc2, pixel_auc2, pixel_pro2], "3": [imge_auc3, pixel_auc3, pixel_pro3],"4":[imge_auc4, 382 | # pixel_auc4, pixel_pro4], "all": [imge_auc, pixel_auc, pixel_pro], "time": t_per_imge} 383 | metrix = {"all": [imge_auc, pixel_auc, 
pixel_pro], "time": t_per_imge} 384 | 385 | 386 | # test_y_list = np.array(test_y_list) 387 | # score_list = np.array(score_list) 388 | # p_index = test_y_list == 1 389 | # n_index = test_y_list == 0 390 | 391 | # p_score = score_list[p_index] 392 | # n_score = score_list[n_index] 393 | # data = [n_score, p_score] 394 | # np.save(os.path.join(self.save_root, f"{self.class_name}_image.npy"), (p_score, n_score)) 395 | # # import seaborn as sns 396 | # # sns.boxplot(data=data) 397 | # # sns.histplot(p_score, kde=False, color="r") 398 | # # sns.histplot(n_score, kde=False, color="b") 399 | 400 | # # plt.show() 401 | return metrix 402 | 403 | def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list): 404 | flatten_y_list = np.array(test_y_list).ravel() 405 | flatten_score_list = np.array(score_list).ravel() 406 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list) 407 | 408 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 409 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 410 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list) 411 | # pro_auc_score = 0 412 | pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) 413 | return image_level_ROCAUC, pixel_level_ROCAUC, pro_auc_score 414 | 415 | def load_project_model(self): 416 | self.projector2.load_state_dict(torch.load(self.ckpt2)) 417 | self.projector3.load_state_dict(torch.load(self.ckpt3)) 418 | self.projector4.load_state_dict(torch.load(self.ckpt4)) 419 | 420 | def center_crop(img, dim): 421 | 422 | """Returns center cropped image 423 | Args: 424 | img: image to be center cropped 425 | dim: dimensions (width, height) to be cropped 426 | """ 427 | width, height = img.shape[1], img.shape[0] 428 | 429 | # process crop width and height for max available dimension 430 | crop_width = dim[0] if dim[0]