├── .gitignore
├── pdf
│   ├── PEFM.bmp
│   └── PFM.bmp
├── run_PFM.sh
├── run_PEFM.sh
├── dataset
│   ├── dagm.py
│   ├── MVTec3D_IMG.py
│   ├── mvtec.py
│   └── stc.py
├── README.md
├── models.py
├── utils.py
├── MB-PFM-VGG.py
├── MB-PFM-ResNet.py
├── PEFM_MVTec3D_PE.py
├── PEFM_AD_PE_Cat.py
└── PEFM_AD.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.pth
4 | *.png
5 | result/
6 | result_3D/
--------------------------------------------------------------------------------
/pdf/PEFM.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smiler96/PFM-and-PEFM-for-Image-Anomaly-Detection-and-Segmentation/HEAD/pdf/PEFM.bmp
--------------------------------------------------------------------------------
/pdf/PFM.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smiler96/PFM-and-PEFM-for-Image-Anomaly-Detection-and-Segmentation/HEAD/pdf/PFM.bmp
--------------------------------------------------------------------------------
/run_PFM.sh:
--------------------------------------------------------------------------------
1 | # Training
2 | # python MB-PFM-ResNet.py --train --gpu_id 0 --batch_size 8 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
3 |
4 | # Testing
5 | python MB-PFM-ResNet.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root D:/Dataset/mvtec_anomaly_detection/
--------------------------------------------------------------------------------
/run_PEFM.sh:
--------------------------------------------------------------------------------
1 | # Training for MVTec AD
2 | # python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
3 | # python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
4 | # python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
5 |
6 | # Testing for MVTec AD
7 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
8 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
9 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
10 |
11 |
12 | # Training for MVTec 3D AD
13 | python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/
14 |
15 | # Testing for MVTec 3D AD
16 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/
--------------------------------------------------------------------------------
/dataset/dagm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import cv2
4 | import numpy as np
5 | from loguru import logger
6 |
7 | from PIL import Image
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms as T
11 |
12 |
13 | DATA_ROOT = "D:/Dataset/DAGM2007/"
14 | DAGMM_CLASS = ["Class1", "Class2", "Class3", "Class4", "Class5", "Class6", "Class7", "Class8", "Class9", "Class10"]
15 | # DAGMM_CLASS = [ "Class1", "Class2", "Class3", "Class4", "Class5", "Class6"]
16 | class DAGMDataset(Dataset):
17 | def __init__(self, root_path=DATA_ROOT, class_name='Class1', is_train=True,
18 | resize=256):
19 |
20 | assert class_name in DAGMM_CLASS
21 |
22 | self.resize = resize
23 | self.root_path = root_path
24 | self.class_name = class_name
25 | self.is_train = is_train
26 |
27 | # set transforms
28 | self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
29 | T.ToTensor(),
30 | T.Normalize(mean=[0.485, 0.456, 0.406],
31 | std=[0.229, 0.224, 0.225])])
32 | self.transform_mask = T.Compose([T.Resize(resize, Image.NEAREST),
33 | T.ToTensor()])
34 |
35 | self.x, self.mask = self.load_image()
36 | self.len = len(self.x)
37 |
38 | def load_image(self):
39 | phase = "Train" if self.is_train else "Test"
40 | img_path = os.path.join(self.root_path, self.class_name, phase)
41 | label_path = os.path.join(self.root_path, self.class_name, phase, "Label")
42 | label_file = os.path.join(label_path, "Labels.txt")
43 | with open(label_file, 'r') as f:
44 | info = f.readlines()
45 | info = info[1:]
46 | for i in range(len(info)):
47 | info[i] = info[i].strip('\n').split('\t')
48 | # print(info)
49 |
50 | mask_list = []
51 | img_list = []
52 | if self.is_train:
53 | for s in info:
54 | if s[1] == '0':
55 | img_list.append(os.path.join(img_path, s[2]))
56 | else:
57 | for s in info:
58 | img_list.append(os.path.join(img_path, s[2]))
59 | if s[1] == "1":
60 | mask_list.append(os.path.join(label_path, s[4]))
61 | else:
62 | mask_list.append("None")
63 | # img_list.sort()
64 | # mask_list.sort()
65 | # self.len = len(img_list) // 3
66 | # new_img_list = img_list[len(img_list)-self.len:len(img_list)]
67 | # new_mask_list = mask_list[len(img_list)-self.len:len(img_list)]
68 | #
69 | # new_mask_list[self.len//2:self.len] = mask_list[len(img_list)-self.len//2:len(img_list)]
70 | # new_img_list[self.len//2:self.len] = img_list[len(img_list)-self.len//2:len(img_list)]
71 | #
72 | # new_mask_list[:self.len//2] = mask_list[:self.len//2]
73 | # new_img_list[:self.len//2] = img_list[:self.len//2]
74 | #
75 | # img_list = new_mask_list
76 | # mask_list = new_mask_list
77 | assert len(img_list) == len(mask_list)
78 | return img_list, mask_list
79 |
80 | def __len__(self):
81 | return self.len
82 |
83 | def __getitem__(self, idx):
84 | x = self.x[idx]
85 | name = os.path.basename(x)
86 | x = Image.open(x).convert('RGB')
87 | x = self.transform_x(x)
88 |
89 | if self.is_train:
90 | mask = torch.zeros([1, self.resize, self.resize])
91 | y = 0
92 | else:
93 | mask = self.mask[idx]
94 | if mask != "None":
95 | mask = Image.open(mask)
96 | mask = self.transform_mask(mask)
97 | y = 1
98 | else:
99 | mask = torch.zeros([1, x.shape[1], x.shape[2]])
100 | y = 0
101 |
102 | return x, y, mask, name
103 |
104 | if __name__ == "__main__":
105 |     dagm = DAGMDataset(is_train=False)
106 |     data = dagm[0]
107 | print(data)
--------------------------------------------------------------------------------
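
Usage sketch (not a repository file): `DAGMDataset` yields `(image, label, mask, name)` tuples, so it drops straight into a standard `DataLoader`. A minimal sketch, assuming the DAGM2007 data sits at the hard-coded `DATA_ROOT` above and the repo root is on `PYTHONPATH`:

```python
# Minimal sketch: iterate DAGM test samples through a DataLoader.
from torch.utils.data import DataLoader
from dataset.dagm import DAGMDataset

testset = DAGMDataset(class_name='Class1', is_train=False, resize=256)
loader = DataLoader(testset, batch_size=1, shuffle=False)
for x, y, mask, name in loader:
    # x: 1x3x256x256 normalized image, y: 0/1 label, mask: 1x1x256x256, name: file name
    print(name, y.item(), x.shape, mask.shape)
    break
```
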
/dataset/MVTec3D_IMG.py:
--------------------------------------------------------------------------------
1 | import cv2  # 'import cv2.cv2' only worked with old opencv-python wheels
2 | import matplotlib
3 |
4 | matplotlib.use("Agg")
5 | import torch
6 | from torch.utils.data import Dataset
7 | import os
8 | import glob
9 | from PIL import Image
10 | import numpy as np
11 | from skimage.segmentation import slic, mark_boundaries
12 | from torchvision import transforms
13 | from imageio import imsave
14 | # imagenet
15 | mean_train = [0.485, 0.456, 0.406]
16 | std_train = [0.229, 0.224, 0.225]
17 |
18 | MVTEC_CLASSES=['bagel', 'cable_gland','carrot', 'cookie','dowel',
19 | 'peach', 'potato', 'rope', 'tire' , 'foam']
20 |
21 | def denormalization(x):
22 | x = (((x.transpose(1, 2, 0) * std_train) + mean_train) * 255.).astype(np.uint8)
23 | return x
24 |
25 |
26 | class MVTec3DDataset_IMG(Dataset):
27 | def __init__(self, root, transform, gt_transform, phase):
28 | if phase=='train':
29 | self.img_path = os.path.join(root, 'train')
30 | else:
31 | self.img_path = os.path.join(root, 'test')
32 | self.gt_path = os.path.join(root, 'test')
33 |
34 | self.transform = transform
35 | self.gt_transform = gt_transform
36 | # load dataset
37 | self.img_paths, self.gt_paths, self.labels, self.types = self.load_dataset() # self.labels => good : 0, anomaly : 1
38 |
39 | def load_dataset(self):
40 |
41 | img_tot_paths = []
42 | gt_tot_paths = []
43 | tot_labels = []
44 | tot_types = []
45 |
46 | defect_types = os.listdir(self.img_path)
47 |
48 | for defect_type in defect_types:
49 | if defect_type == 'good':
50 | img_paths = glob.glob(os.path.join(self.img_path, defect_type,'rgb') + "/*.png")
51 | img_tot_paths.extend(img_paths)
52 | gt_tot_paths.extend([0]*len(img_paths))
53 | tot_labels.extend([0]*len(img_paths))
54 | tot_types.extend(['good']*len(img_paths))
55 | else:
56 | img_paths = glob.glob(os.path.join(self.img_path, defect_type,'rgb') + "/*.png")
57 | gt_paths = glob.glob(os.path.join(self.gt_path, defect_type, 'gt') + "/*.png")
58 | img_paths.sort()
59 | gt_paths.sort()
60 | img_tot_paths.extend(img_paths)
61 | gt_tot_paths.extend(gt_paths)
62 | tot_labels.extend([1]*len(img_paths))
63 | tot_types.extend([defect_type]*len(img_paths))
64 |
65 | assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
66 |
67 | return img_tot_paths, gt_tot_paths, tot_labels, tot_types
68 |
69 | def __len__(self):
70 | return len(self.img_paths)
71 |
72 | def __getitem__(self, idx):
73 | img_path, gt, label, img_type = self.img_paths[idx], self.gt_paths[idx], self.labels[idx], self.types[idx]
74 | img = Image.open(img_path).convert('RGB')
75 | img = self.transform(img)
76 |         names = os.path.normpath(img_path).split(os.sep)  # works for both Windows and POSIX paths
77 | name = names[-3] + "_" + names[-1]
78 | if gt == 0:
79 |             gt = torch.zeros([1, img.size()[-2], img.size()[-1]])
80 | else:
81 | gt = Image.open(gt)
82 | gt = self.gt_transform(gt)
83 | gt[gt > 0.5] = 1
84 | gt[gt <= 0.5] = 0
85 |
86 | assert img.size()[1:] == gt.size()[1:], "image.size != gt.size !!!"
87 |
88 | return img, label, gt, name
89 |
90 |
91 |
92 |
93 |
94 | if __name__ == '__main__':
95 | categories = MVTEC_CLASSES
96 |     dataset_path = r'../datasets/mvtec_3d_anomaly_detection'  # the classes above are MVTec 3D AD categories
97 | load_size = 256
98 | input_size = 256
99 | data_transforms = transforms.Compose([
100 | transforms.Resize((load_size, load_size), Image.ANTIALIAS),
101 | transforms.ToTensor(),
102 | transforms.CenterCrop(input_size),
103 | transforms.Normalize(mean=mean_train,
104 | std=std_train)])
105 | gt_transforms = transforms.Compose([
106 | transforms.Resize((load_size, load_size), Image.NEAREST),
107 | transforms.ToTensor(),
108 | transforms.CenterCrop(input_size)])
109 |
110 | for category in categories:
111 | phase = 'train'
112 |         dataset = MVTec3DDataset_IMG(root=os.path.join(dataset_path, category),
113 |                                      transform=data_transforms, gt_transform=gt_transforms, phase=phase)
114 |
115 |         for img, label, gt, name in dataset:
116 | save_folder = os.path.join(dataset_path, 'spxls', category, phase)
117 | save_name = os.path.join(save_folder, f'{name}.bmp')
118 | os.makedirs(save_folder, exist_ok=True)
119 |
120 |
121 | pass
122 |
123 |
124 | pass
125 |
--------------------------------------------------------------------------------
/dataset/mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms as T
7 |
8 | MVTec_CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
9 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
10 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
11 | # MVTec_CLASS_NAMES = [ 'pill', 'screw',
12 | # 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
13 |
14 | class MVTecDataset(Dataset):
15 | def __init__(self, root_path='D:/Dataset/mvtec_anomaly_detection/', class_name='bottle', is_train=True, resize=256, trans=None, LOAD_CPU=False):
16 | assert class_name in MVTec_CLASS_NAMES, 'class_name: {}, should be in {}'.format(class_name, MVTec_CLASS_NAMES)
17 |
18 | self.root_path = root_path
19 | self.class_name = class_name
20 | self.is_train = is_train
21 | self.resize = resize
22 |
23 | # load dataset
24 | self.x, self.y, self.mask = self.load_dataset_folder()
25 |
26 | # set transforms
27 | if trans is None:
28 | self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
29 | T.ToTensor(),
30 | T.Normalize(mean=[0.485, 0.456, 0.406],
31 | std=[0.229, 0.224, 0.225])])
32 | else:
33 | self.transform_x = trans
34 | # self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
35 | # T.ToTensor()])
36 |
37 |
38 | self.transform_mask = T.Compose([T.Resize(resize, Image.NEAREST),
39 | T.ToTensor()])
40 |
41 | self.load_cpu = LOAD_CPU
42 | self.len = len(self.x)
43 | self.x_cpu = []
44 | self.y_cpu = []
45 | self.name = []
46 | self.mask_cpu = []
47 | if self.load_cpu:
48 | for i in range(self.len):
49 |                 names = os.path.normpath(self.x[i]).split(os.sep)
50 | name = names[-2] + "_" + names[-1]
51 | self.name.append(name)
52 | x = Image.open(self.x[i]).convert('RGB')
53 | x = self.transform_x(x)
54 | self.x_cpu.append(x)
55 |
56 | if self.y[i] == 0:
57 | mask = torch.zeros([1, self.resize, self.resize])
58 | else:
59 | mask = Image.open(self.mask[i])
60 | mask = self.transform_mask(mask)
61 | self.mask_cpu.append(mask)
62 | self.y_cpu.append(self.y[i])
63 |
64 | def __getitem__(self, idx):
65 | if self.load_cpu:
66 | x, y, mask, name = self.x_cpu[idx], self.y_cpu[idx], self.mask_cpu[idx], self.name[idx]
67 | else:
68 | x, y, mask = self.x[idx], self.y[idx], self.mask[idx]
69 |             # os.sep keeps the name split correct on both Windows and POSIX
70 |             names = os.path.normpath(x).split(os.sep)
71 | name = names[-2] + "_" + names[-1]
72 | x = Image.open(x).convert('RGB')
73 | x = self.transform_x(x)
74 |
75 | if y == 0:
76 | mask = torch.zeros([1, self.resize, self.resize])
77 | else:
78 | mask = Image.open(mask)
79 | mask = self.transform_mask(mask)
80 |
81 | return x, y, mask, name
82 |
83 | def __len__(self):
84 | return len(self.x)
85 |
86 | def load_dataset_folder(self):
87 | phase = 'train' if self.is_train else 'test'
88 | x, y, mask = [], [], []
89 |
90 | img_dir = os.path.join(self.root_path, self.class_name, phase)
91 | gt_dir = os.path.join(self.root_path, self.class_name, 'ground_truth')
92 |
93 | img_types = sorted(os.listdir(img_dir))
94 | for img_type in img_types:
95 |
96 | # load images
97 | img_type_dir = os.path.join(img_dir, img_type)
98 | if not os.path.isdir(img_type_dir):
99 | continue
100 | img_fpath_list = sorted([os.path.join(img_type_dir, f)
101 | for f in os.listdir(img_type_dir)
102 | if f.endswith('.png')])
103 | x.extend(img_fpath_list)
104 |
105 | # load gt labels
106 | if img_type == 'good':
107 | y.extend([0] * len(img_fpath_list))
108 | mask.extend([None] * len(img_fpath_list))
109 | else:
110 | y.extend([1] * len(img_fpath_list))
111 | gt_type_dir = os.path.join(gt_dir, img_type)
112 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list]
113 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png')
114 | for img_fname in img_fname_list]
115 | mask.extend(gt_fpath_list)
116 |
117 | assert len(x) == len(y), 'number of x and y should be same'
118 |
119 | return list(x), list(y), list(mask)
120 |
121 |
122 | if __name__ == "__main__":
123 | dataset_ = MVTecDataset()
124 | a = dataset_[0]
125 | print(a)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PFM and PEFM for Image Anomaly Detection and Segmentation
2 |
3 | ## Abstract
4 |
5 | ### Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping [(PFM-TII)](https://ieeexplore.ieee.org/document/9795121)
6 |
7 | Image anomaly detection and segmentation are important for the development of automatic product quality inspection in intelligent manufacturing. Because normal data can be collected easily while abnormal samples are rare, unsupervised methods based on reconstruction and embedding have been the main line of work for anomaly detection, but their detection performance and computing time still need improvement. This paper proposes a novel framework, named Pre-trained Feature Mapping (PFM), for unsupervised image anomaly detection and segmentation. The proposed PFM maps an image from one pre-trained feature space to another to detect anomalies effectively. Bidirectional and multi-hierarchical bidirectional pre-trained feature mapping are further proposed and studied to improve performance. The proposed framework achieves better results than state-of-the-art methods on the well-known MVTec AD dataset, with an area under the receiver operating characteristic curve of 97.5% for anomaly detection and 97.3% for anomaly segmentation over all 15 categories, and it is also superior in terms of computing time. Extensive ablation studies further demonstrate the effectiveness and efficiency of the proposed framework.
8 |
9 |
10 | ![PFM framework](pdf/PFM.bmp)
11 |
12 |
13 | ### Position Encoding Enhanced Feature Mapping for Image Anomaly Detection [(PEFM CASE)](https://www.researchgate.net/publication/361254312_Position_Encoding_Enhanced_Feature_Mapping_for_Image_Anomaly_Detection)
14 |
15 | Image anomaly detection is an important stage of automatic visual inspection in intelligent manufacturing systems. The wide range of anomalies in images, varying in size, shape, and color, makes automatic visual inspection challenging. Previous work on image anomaly detection has achieved significant advances, but detection performance and efficiency remain limited. In this paper, a novel Position Encoding enhanced Feature Mapping (PEFM) method is proposed for image anomaly detection, detecting anomalies by mapping a pair of pre-trained features embedded with position encodings. Experimental results show that the proposed PEFM achieves better performance and efficiency than state-of-the-art methods on the MVTec AD dataset, with an AUCROC of 98.30% and an AUCPRO of 95.52%, and achieves an AUCPRO of 94.0% on the MVTec 3D AD dataset.
16 |
17 | ![PEFM framework](pdf/PEFM.bmp)

19 |
20 |
21 | ## Usage ([PyTorch](https://pytorch.org/))
22 |
23 | ### PFM (TII)
24 |
25 | ```bash
26 | # Training
27 | python MB-PFM-ResNet.py --train --gpu_id 0 --batch_size 8 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
28 | # Testing
29 | python MB-PFM-ResNet.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root D:/Dataset/mvtec_anomaly_detection/
30 | ```
31 |
32 | ### PEFM (CASE)
33 |
34 | ```bash
35 | # Training for MVTec AD
36 | python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
37 | python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
38 | python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root /home/dlwanqian/data/mvtec_anomaly_detection/
39 |
40 | # Testing for MVTec AD
41 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 128 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
42 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
43 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 512 --data_trans imagenet --loss_type l2norm+l2 --pe_required --data_root D:/Dataset/mvtec_anomaly_detection/
44 |
45 |
46 | # Training for MVTec 3D AD
47 | python PEFM_AD.py --train --gpu_id 0 --batch_size 16 --epochs 200 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/
48 |
49 | # Testing for MVTec 3D AD
50 | python PEFM_AD.py --gpu_id 0 --batch_size 1 --resize 256 --data_trans imagenet --loss_type l2norm+l2 --data_root /home/dlwanqian/data/mvtec_3d_anomaly_detection/
51 | ```
52 |
53 | ## Citation
54 |
55 | If this work helps your research, please consider citing these papers:
56 |
57 | ```BibTeX
58 | @ARTICLE{PFM,
59 | author={Wan, Qian and Gao, Liang and Li, Xinyu and Wen, Long},
60 | journal={IEEE Transactions on Industrial Informatics},
61 | title={Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping},
62 | year={2022},
63 | volume={},
64 | number={},
65 | pages={},
66 | doi={10.1109/TII.2022.3182385}
67 | }
68 | @INPROCEEDINGS{PEFM,
69 |   author={Wan, Qian and Cao, Yunkang and Gao, Liang and Shen, Weiming and Li, Xinyu},
70 | booktitle={2022 IEEE 18th International Conference on Automation Science and Engineering (CASE)},
71 | title={Position Encoding Enhanced Feature Mapping for Image Anomaly Detection},
72 | year={2022},
73 | volume={},
74 | number={},
75 | pages={},
76 | doi={}
77 | }
78 | ```
79 |
80 | ## Acknowledgment
81 |
82 | Thanks to the authors of these excellent works:
83 | - [SPADE](https://github.com/byungjae89/SPADE-pytorch)
84 | - [PositionalEncoding2D](https://github.com/wzlxjtu/PositionalEncoding2D)
85 |
--------------------------------------------------------------------------------
/dataset/stc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import cv2
4 | import numpy as np
5 | from loguru import logger
6 |
7 | from PIL import Image
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms as T
11 |
12 |
13 | DATA_ROOT = "D:/Dataset/ShanghaiTech/"
14 | BIN_ROOT = os.path.join(DATA_ROOT, 'bin')
15 | os.makedirs(BIN_ROOT, exist_ok=True)
16 |
17 | INTERVAL = 5
18 |
19 | def capture_training_frames():
20 |
21 | video_path = os.path.join(DATA_ROOT, 'training', 'videos')
22 | video_list = glob.glob(video_path + '/*.avi')
23 | video_list.sort()
24 |
25 | frame_root = os.path.join(BIN_ROOT, "train", "image")
26 |
27 | for v in video_list:
28 | video_name = os.path.basename(v)
29 |
30 | scene_class = video_name.split('_')[0]
31 | frame_path = os.path.join(frame_root, scene_class)
32 | os.makedirs(frame_path, exist_ok=True)
33 |
34 | cap = cv2.VideoCapture(v)
35 | cnt = 0
36 | while (cap.isOpened()):
37 | ret, frame = cap.read()
38 | if ret == False:
39 | break
40 | if cnt % INTERVAL == 0:
41 | # cv2.imshow(video_name, frame)
42 | # cv2.waitKey(0)
43 | frame = cv2.resize(frame, (256, 256))
44 | frame_name = video_name.split('.')[0] + f"_{cnt}.png"
45 | cv2.imwrite(os.path.join(frame_path, frame_name), frame)
46 | logger.info(os.path.join(frame_path, frame_name))
47 | cnt += 1
48 | cap.release()
49 |
50 | def split_testing_frames():
51 | testing_frames_root = os.path.join(DATA_ROOT, "testing", "frames")
52 | testing_frame_mask_root = os.path.join(DATA_ROOT, "testing", "test_frame_mask")
53 | testing_pixel_mask_root = os.path.join(DATA_ROOT, "testing", "test_pixel_mask")
54 |
55 | img_root = os.path.join(BIN_ROOT, "test", 'image')
56 | gt_root = os.path.join(BIN_ROOT, "test", 'groundtruth')
57 |
58 | scene_folders = os.listdir(testing_frames_root)
59 | scene_folders.sort()
60 | for sf in scene_folders:
61 |         scene_class = sf.split("_")[0]
62 |         img_path = os.path.join(img_root, scene_class)
63 |         os.makedirs(img_path, exist_ok=True)
64 |         gt_path = os.path.join(gt_root, scene_class)
65 | os.makedirs(gt_path, exist_ok=True)
66 |
67 | frames_path = os.path.join(testing_frames_root, sf)
68 | frames_list = glob.glob(frames_path + "/*.*")
69 | frames_list.sort()
70 |
71 | frames_pixel_masks = np.load(os.path.join(testing_pixel_mask_root, sf + ".npy"))
72 | # print(np.max(frames_pixel_masks))
73 | for cnt, f in enumerate(frames_list):
74 | if cnt % 1 == 0:
75 | # frame
76 | frame = cv2.imread(f)
77 | frame = cv2.resize(frame, (256, 256))
78 | frame_name = os.path.basename(f).split('.')[0] + '.png'
79 | cv2.imwrite(os.path.join(img_path, frame_name), frame)
80 | logger.info(os.path.join(img_path, frame_name))
81 |
82 | # gt
83 | gt = frames_pixel_masks[cnt] * 255
84 |             gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_NEAREST)  # the third positional arg of cv2.resize is dst, so name the interpolation
85 | cv2.imwrite(os.path.join(gt_path, frame_name), gt)
86 |
87 |
88 | STC_CLASS = ["02", "01", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"]
89 | # STC_CLASS = ["02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "13"]
90 | class STCDataset(Dataset):
91 | def __init__(self, root_path=BIN_ROOT, scene_class='01', is_train=True,
92 | resize=256, trans=None):
93 |
94 | assert scene_class in STC_CLASS
95 |
96 | self.root_path = root_path
97 | self.scene_class = scene_class
98 | self.is_train = is_train
99 |
100 | # set transforms
101 | if trans is None:
102 | self.transform_x = T.Compose([T.Resize(resize, Image.ANTIALIAS),
103 | T.ToTensor(),
104 | T.Normalize(mean=[0.485, 0.456, 0.406],
105 | std=[0.229, 0.224, 0.225])])
106 | else:
107 | self.transform_x = trans
108 | self.transform_mask = T.Compose([T.Resize(resize, Image.NEAREST),
109 | T.ToTensor()])
110 |
111 | self.x, self.mask = self.load_image()
112 | # self.len = len(self.x) // 3
113 | self.len = len(self.x)
114 |
115 | def load_image(self):
116 | if self.is_train:
117 | img_path = os.path.join(self.root_path, "train", "image", self.scene_class)
118 | img_list = glob.glob(img_path + "/*.png")
119 | mask_list = None
120 | else:
121 | img_path = os.path.join(self.root_path, "test", "image", self.scene_class)
122 | mask_path = os.path.join(self.root_path, "test", "groundtruth", self.scene_class)
123 | img_list = glob.glob(img_path + "/*.png")
124 | mask_list = glob.glob(mask_path + "/*.png")
125 | img_list.sort()
126 | mask_list.sort()
127 | assert len(img_list) == len(mask_list)
128 | return img_list, mask_list
129 |
130 | def __len__(self):
131 | return self.len
132 |
133 | def __getitem__(self, idx):
134 | x = self.x[idx]
135 | name = os.path.basename(x)
136 | x = Image.open(x).convert('RGB')
137 | x = self.transform_x(x)
138 |
139 | if self.is_train:
140 | _, H, W = x.shape
141 | mask = torch.zeros((1, H, W))
142 | else:
143 | mask = self.mask[idx]
144 | mask = Image.open(mask)
145 | mask = self.transform_mask(mask)
146 | y = torch.max(mask)
147 | return x, y, mask, name
148 |
149 | if __name__ == "__main__":
150 | # capture_training_frames()
151 | # split_testing_frames()
152 | stc = STCDataset(is_train=False)
153 | data = stc[120]
154 | print(data)
--------------------------------------------------------------------------------
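
Usage sketch (not a repository file): the two helpers above must run once to turn the raw ShanghaiTech videos and `.npy` pixel masks into the `bin/` folders that `STCDataset` reads; the order below is inferred from the commented-out calls in the `__main__` block. Paths come from the hard-coded Windows `DATA_ROOT`, so adjust it before running:

```python
# Minimal sketch of the intended preprocessing order for ShanghaiTech.
from dataset.stc import capture_training_frames, split_testing_frames, STCDataset

capture_training_frames()   # samples every INTERVAL-th training frame into bin/train/image/<scene>
split_testing_frames()      # dumps test frames and pixel masks into bin/test/{image,groundtruth}/<scene>
train_set = STCDataset(scene_class='01', is_train=True)
print(len(train_set))
```
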
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.resnet import resnet18, resnet34
4 | from torchvision.models.vgg import vgg16_bn, vgg16
5 | # from resnet_no_relu import resnet18_nr
6 | import torch.nn.functional as F
7 |
8 | class VGG16(nn.Module):
9 | def __init__(self, pretrained=False, bn=False):
10 | super(VGG16, self).__init__()
11 |         if bn:
12 |             model = vgg16_bn(pretrained=pretrained)
13 |             features = list(model.features)  # keep as a local: assigning self.modules would shadow nn.Module.modules()
14 |             self.block0 = nn.Sequential(*features[0:17])
15 |             self.block1 = nn.Sequential(*features[17:23])
16 |             self.block2 = nn.Sequential(*features[23:33])
17 |             self.block3 = nn.Sequential(*features[33:43])
18 |         else:
19 |             model = vgg16(pretrained=pretrained)
20 |             features = list(model.features)
21 |             self.block0 = nn.Sequential(*features[0:12])
22 |             self.block1 = nn.Sequential(*features[12:16])
23 |             self.block2 = nn.Sequential(*features[16:23])
24 |             self.block3 = nn.Sequential(*features[23:30])
25 |
26 | def forward(self, x):
27 | # 64x64x256
28 | out0 = self.block0(x)
29 | # 64x64x256
30 | out1 = self.block1(out0)
31 | # 32x32x512
32 | out2 = self.block2(out1)
33 | # 16x16x512
34 | out3 = self.block3(out2)
35 | return {"out2": out0,
36 | "out3": out1,
37 | "out4": out2,
38 | "out5": out3
39 | }
40 |
41 |
42 | class ResNet18(nn.Module):
43 | def __init__(self, pretrained=False):
44 | super(ResNet18, self).__init__()
45 | model = resnet18(pretrained=pretrained)
46 |
47 | modules = list(model.children())
48 | self.block1 = nn.Sequential(*modules[0:4])
49 | self.block2 = modules[4]
50 | self.block3 = modules[5]
51 | self.block4 = modules[6]
52 | self.block5 = modules[7]
53 |
54 | def forward(self, x):
55 | x = self.block1(x)
56 | # 64x64x64
57 | out2 = self.block2(x)
58 | # 32x32x128
59 | out3 = self.block3(out2)
60 | # 16x16x256
61 | out4 = self.block4(out3)
62 | # 8x8x512
63 | out5 = self.block5(out4)
64 | return {"out2": out2,
65 | "out3": out3,
66 | "out4": out4,
67 | "out5": out5
68 | }
69 |
70 |
71 | # import copy
72 | # class ResNet18NR(nn.Module):
73 | # def __init__(self, pretrained=False):
74 | # super(ResNet18NR, self).__init__()
75 | # model_nr = resnet18_nr(pretrained=False)
76 | # modules_nr = list(model_nr.children())
77 |
78 |
79 | # self.block2_logvar = copy.copy(nn.Sequential(*modules_nr[0:5]))
80 | # self.block3_logvar = copy.copy(nn.Sequential(modules_nr[5]))
81 | # self.block4_logvar = copy.copy(nn.Sequential(modules_nr[6]))
82 |
83 | # def forward(self, x):
84 | # out2_logvar = self.block2_logvar(x)
85 | # out3_logvar = self.block3_logvar(out2_logvar)
86 | # out4_logvar = self.block4_logvar(out3_logvar)
87 | # # 8x8x512
88 | # return {"out2": out2_logvar,
89 | # "out3": out3_logvar,
90 | # "out4": out4_logvar
91 | # }
92 |
93 |
94 |
95 | class ResNet34(nn.Module):
96 | def __init__(self, pretrained=False):
97 | super(ResNet34, self).__init__()
98 | model = resnet34(pretrained=pretrained)
99 |
100 | modules = list(model.children())
101 | self.block1 = nn.Sequential(*modules[0:4])
102 | self.block2 = modules[4]
103 | self.block3 = modules[5]
104 | self.block4 = modules[6]
105 | self.block5 = modules[7]
106 |
107 | def forward(self, x):
108 | x = self.block1(x)
109 | # 64x64x64
110 | out2 = self.block2(x)
111 | # 32x32x128
112 | out3 = self.block3(out2)
113 | # 16x16x256
114 | out4 = self.block4(out3)
115 | # 8x8x512
116 | out5 = self.block5(out4)
117 | return {"out2": out2,
118 | "out3": out3,
119 | "out4": out4,
120 | "out5": out5
121 | }
122 |
123 |
124 | class KDLoss(nn.Module):
125 | def __init__(self, loss_type=None):
126 | super(KDLoss, self).__init__()
127 | '''
128 |         loss type: l2, l1, consine (cosine), l2norm+l2, l1norm+l2
129 | '''
130 | self.type = loss_type
131 |
132 | def get_loss_map(self, feat_T, feat_S):
133 | '''
134 | :param feat_T: NxCxHxW
135 | :param feat_S: NxCxHxW
136 | :return:
137 | '''
138 | if self.type == "l2norm+l2":
139 | feat_T = F.normalize(feat_T, p=2, dim=1)
140 | feat_S = F.normalize(feat_S, p=2, dim=1)
141 |
142 | loss_map = 0.5 * ((feat_T - feat_S) ** 2)
143 | loss_map = torch.sum(loss_map, dim=1)
144 |
145 | elif self.type == "l1norm+l2":
146 | feat_T = F.normalize(feat_T, p=1, dim=1)
147 | feat_S = F.normalize(feat_S, p=1, dim=1)
148 |
149 | loss_map = 0.5 * ((feat_T - feat_S) ** 2)
150 | loss_map = torch.sum(loss_map, dim=1)
151 |
152 | elif self.type == "consine":
153 | feat_T = F.normalize(feat_T, p=2, dim=1)
154 | feat_S = F.normalize(feat_S, p=2, dim=1)
155 | loss_map = 1 - torch.sum(torch.mul(feat_T, feat_S), dim=1)
156 |
157 | elif self.type == "l2":
158 | loss_map = (feat_T - feat_S) ** 2
159 | loss_map = torch.sum(loss_map, dim=1)
160 |
161 | elif self.type == "l1":
162 | loss_map = torch.abs(feat_T - feat_S)
163 | loss_map = torch.sum(loss_map, dim=1)
164 |
165 | else:
166 | raise NotImplementedError
167 |
168 | return loss_map
169 |
170 | def forward(self, feat_T, feat_S):
171 | loss_map = self.get_loss_map(feat_T, feat_S)
172 | # if self.type == "consine":
173 | # return torch.mean(loss_map)
174 | # else:
175 | # return torch.sum(loss_map)
176 | return torch.sum(torch.mean(loss_map, dim=(1, 2)))
177 | # if use mean, must increase the learning rate
178 | # return torch.mean(loss_map)
179 |
180 |
181 | class Conv_BN_PRelu(nn.Module):
182 | def __init__(self, in_dim, out_dim, k=1, s=1, p=0, bn=True, prelu=True):
183 | super(Conv_BN_PRelu, self).__init__()
184 | self.conv = [
185 | nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p),
186 | ]
187 | if bn:
188 | self.conv.append(nn.BatchNorm2d(out_dim))
189 | if prelu:
190 | self.conv.append(nn.PReLU())
191 |
192 | self.conv = nn.Sequential(*self.conv)
193 |
194 | def forward(self, x):
195 | return self.conv(x)
196 |
197 |
198 | class NonLocalAttention(nn.Module):
199 | def __init__(self, channel=256, reduction=2, rescale=1.0):
200 | super(NonLocalAttention, self).__init__()
201 | # self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
202 | # self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
203 | # self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,bn=False, act=nn.PReLU())
204 |
205 | self.conv_match1 = Conv_BN_PRelu(channel, channel//reduction, 1, bn=False, prelu=True)
206 | self.conv_match2 = Conv_BN_PRelu(channel, channel//reduction, 1, bn=False, prelu=True)
207 | self.conv_assembly = Conv_BN_PRelu(channel, channel, 1,bn=False, prelu=True)
208 | self.rescale = rescale
209 |
210 | def forward(self, input):
211 | x_embed_1 = self.conv_match1(input)
212 | x_embed_2 = self.conv_match2(input)
213 | x_assembly = self.conv_assembly(input)
214 |
215 | N,C,H,W = x_embed_1.shape
216 | x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C))
217 | x_embed_2 = x_embed_2.view(N,C,H*W)
218 | score = torch.matmul(x_embed_1, x_embed_2)
219 | score = F.softmax(score, dim=2)
220 | x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1)
221 | x_final = torch.matmul(score, x_assembly)
222 | x_final = x_final.permute(0,2,1).view(N,-1,H,W)
223 | return x_final + input*self.rescale
224 | # return x_final
225 |
226 |
227 | if __name__ == "__main__":
228 | x = torch.rand([2, 3, 256, 256])
229 |     T = ResNet18(True)   # ResNet18NR depends on the commented-out resnet_no_relu import; smoke-test with ResNet18 instead
230 |     T(x)
231 |     S = ResNet18(False)
232 |
--------------------------------------------------------------------------------
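
Usage sketch (not a repository file): how the pieces of models.py fit together. `ResNet18` and `ResNet34` share per-stage channel widths, so their feature dicts can be compared stage by stage with `KDLoss`; `pretrained=False` below is only to avoid a weight download (the PFM setup uses `pretrained=True`):

```python
# Minimal sketch: score the discrepancy between two backbone feature pyramids.
import torch
from models import ResNet18, ResNet34, KDLoss

agent1 = ResNet18(pretrained=False).eval()
agent2 = ResNet34(pretrained=False).eval()
criterion = KDLoss(loss_type="l2norm+l2")

x = torch.rand(2, 3, 256, 256)
with torch.no_grad():
    f1, f2 = agent1(x), agent2(x)   # dicts keyed "out2".."out5"

loss_map = criterion.get_loss_map(f1["out3"], f2["out3"])  # per-pixel map, N x 32 x 32
loss = criterion(f1["out3"], f2["out3"])                   # scalar training loss
print(loss_map.shape, loss.item())
```
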
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from skimage import measure
4 | from sklearn.metrics import auc
5 | from loguru import logger
6 |
7 | import cv2
8 | import numpy as np
9 |
10 | def save_score_map(score_maps, names, class_name, save_path, loss_name=None):
11 | assert loss_name is not None
12 | num = len(score_maps)
13 | os.makedirs(save_path, exist_ok=True)
14 | score_maps = np.array(score_maps)
15 | max_mum = np.max(score_maps)
16 | min_mum = np.min(score_maps)
17 | score_maps = (score_maps - min_mum) / (max_mum - min_mum) * 255.0
18 | for _idx in range(num):
19 | score_map = np.uint8(np.squeeze(score_maps[_idx]))
20 | score_map = cv2.applyColorMap(score_map, cv2.COLORMAP_JET)
21 |
22 | _name = names[_idx].split('.')[0]
23 | path0 = os.path.join(save_path, f"{class_name}_{_name}_{loss_name}.png")
24 | cv2.imwrite(path0, score_map)
25 |
26 | def visualize(test_imgs, test_masks, score_maps, names, class_name, save_path, num=100, trans='imagenet'):
27 | num = min(num, len(test_imgs))
28 | os.makedirs(save_path, exist_ok=True)
29 | score_maps = np.array(score_maps)
30 | max_mum = np.max(score_maps)
31 | min_mum = np.min(score_maps)
32 | score_maps = (score_maps - min_mum) / (max_mum - min_mum) * 255.0
33 | for _idx in range(num):
34 | test_img = test_imgs[_idx]
35 | test_img = denormalize(test_img, trans=trans)
36 | test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
37 |
38 | test_mask = test_masks[_idx].transpose(1, 2, 0).squeeze()
39 |
40 | score_map = np.uint8(np.squeeze(score_maps[_idx]))
41 | # score_map = cv2.normalize(score_map, score_map, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
42 | score_map = cv2.applyColorMap(score_map, cv2.COLORMAP_JET)
43 | # cv2.imshow("score_map", score_map)
44 | # cv2.waitKey(0)
45 | # res_img = cv2.addWeighted(test_img, 0.4, score_map, 0.6, 0)
46 | # test_img = draw_detect(test_img, test_mask)
47 |
48 | name = names[_idx].split('.')[0]
49 | path0 = os.path.join(save_path, f"{class_name}_{name}.png")
50 | cv2.imwrite(path0, score_map)
51 |
52 | test_img_mask = draw_detect(test_img, test_mask)
53 | path1 = os.path.join(save_path, f"{class_name}_{name}_gt.png")
54 | cv2.imwrite(path1, test_img_mask)
55 |
56 |
57 | def draw_detect(img, label):
58 | assert len(label.shape) == 2
59 | label = np.uint8(label*255)
60 | _, label = cv2.threshold(label, 5, 255, 0)
61 |     label_cnts = cv2.findContours(label, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]  # [-2] is the contour list in both OpenCV 3 and 4
62 |
63 | # img = cv2.drawContours(img, label_cnts, -1, (0, 0, 255), 1)
64 | mask = np.zeros(img.shape)
65 | mask = cv2.fillPoly(mask, label_cnts, color=(0, 0, 255))
66 | label=np.expand_dims(255-label, axis=2)/255
67 | img = img*(np.concatenate([label, label, label], axis=2)) + mask
68 | img = np.clip(img, 0, 255)
69 | img = np.uint8(img)
70 |
71 | # cv2.imshow("img", img)
72 | # cv2.waitKey(0)
73 | return img
74 |
75 |
76 | def denormalize(img, trans='imagenet'):
77 | if trans == 'imagenet':
78 | std = np.array([0.229, 0.224, 0.225])
79 | mean = np.array([0.485, 0.456, 0.406])
80 | x = (((img.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
81 | elif trans == 'coco':
82 | std = np.array([0.5, 0.5, 0.5])
83 | mean = np.array([0.5, 0.5, 0.5])
84 | x = (((img.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
85 | elif trans == 'navie':
86 | x = (img.transpose(1, 2, 0) * 255.).astype(np.uint8)
87 | elif trans == 'no':
88 | x = (img.transpose(1, 2, 0)).astype(np.uint8)
89 | else:
90 | raise NotImplementedError
91 | return x
92 |
93 | import os
94 | from PIL import Image
95 |
96 | from torchvision import transforms as T
97 | def load_image(path):
98 | transform_x = T.Compose([T.Resize(256, Image.ANTIALIAS),
99 | T.CenterCrop(224),
100 | T.ToTensor(),
101 | T.Normalize(mean=[0.485, 0.456, 0.406],
102 | std=[0.229, 0.224, 0.225])])
103 | x = Image.open(path).convert('RGB')
104 | x = transform_x(x)
105 | x = x.unsqueeze(0)
106 | return x
107 |
108 |
109 | def cal_pro_metric_new(labeled_imgs, score_imgs, fpr_thresh=0.3, max_steps=2000, class_name=None):
110 | labeled_imgs = np.array(labeled_imgs).squeeze(1)
111 | labeled_imgs[labeled_imgs <= 0.45] = 0
112 | labeled_imgs[labeled_imgs > 0.45] = 1
113 |     labeled_imgs = labeled_imgs.astype(bool)  # use Python bool: np.bool was removed in NumPy 1.24
114 | score_imgs = np.array(score_imgs).squeeze(1)
115 |
116 | max_th = score_imgs.max()
117 | min_th = score_imgs.min()
118 | delta = (max_th - min_th) / max_steps
119 |
120 | ious_mean = []
121 | ious_std = []
122 | pros_mean = []
123 | pros_std = []
124 | threds = []
125 | fprs = []
126 |     binary_score_maps = np.zeros_like(score_imgs, dtype=bool)
127 | for step in range(max_steps):
128 | thred = max_th - step * delta
129 | # segmentation
130 | binary_score_maps[score_imgs <= thred] = 0
131 | binary_score_maps[score_imgs > thred] = 1
132 |
133 | pro = [] # per region overlap
134 | iou = [] # per image iou
135 | # pro: find each connected gt region, compute the overlapped pixels between the gt region and predicted region
136 | # iou: for each image, compute the ratio, i.e. intersection/union between the gt and predicted binary map
137 | for i in range(len(binary_score_maps)): # for i th image
138 | # pro (per region level)
139 | label_map = measure.label(labeled_imgs[i], connectivity=2)
140 | props = measure.regionprops(label_map)
141 | for prop in props:
142 | x_min, y_min, x_max, y_max = prop.bbox
143 | cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
144 | # cropped_mask = masks[i][x_min:x_max, y_min:y_max]
145 | cropped_mask = prop.filled_image # corrected!
146 | intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum()
147 | pro.append(intersection / prop.area)
148 | # iou (per image level)
149 | intersection = np.logical_and(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum()
150 | union = np.logical_or(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum()
151 |             if labeled_imgs[i].any():  # skip images whose ground truth has no anomaly pixels
152 | iou.append(intersection / union)
153 | # against steps and average metrics on the testing data
154 | ious_mean.append(np.array(iou).mean())
155 | # print("per image mean iou:", np.array(iou).mean())
156 | ious_std.append(np.array(iou).std())
157 | pros_mean.append(np.array(pro).mean())
158 | pros_std.append(np.array(pro).std())
159 | # fpr for pro-auc
160 | masks_neg = ~labeled_imgs
161 | fpr = np.logical_and(masks_neg, binary_score_maps).sum() / masks_neg.sum()
162 | fprs.append(fpr)
163 | threds.append(thred)
164 |
165 | # as array
166 | threds = np.array(threds)
167 | pros_mean = np.array(pros_mean)
168 | pros_std = np.array(pros_std)
169 | fprs = np.array(fprs)
170 |
171 |
172 | # default 30% fpr vs pro, pro_auc
173 |     idx = fprs <= fpr_thresh  # indices where fpr stays within fpr_thresh (default 0.3)
174 | fprs_selected = fprs[idx]
175 | fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1]
176 | pros_mean_selected = pros_mean[idx]
177 | pro_auc_score = auc(fprs_selected, pros_mean_selected)
178 | # print("pro auc ({}% FPR):".format(int(expect_fpr * 100)), pro_auc_score)
179 | return pro_auc_score
180 |
181 |
182 |
183 | def rescale(x):
184 | return (x - x.min()) / (x.max() - x.min())
185 |
186 | import torch
187 | import random
188 | def set_seed(seed):
189 | torch.manual_seed(seed)
190 | torch.cuda.manual_seed_all(seed)
191 | np.random.seed(seed)
192 | random.seed(seed)
193 | torch.backends.cudnn.deterministic = True
194 | torch.backends.cudnn.benchmark = False
195 |
196 |
197 |
198 |
199 |
200 | from thop import profile
201 | from thop import clever_format
202 | def calculate_flops(device, xs, xt, model):
203 | model = model.eval().to(device)
204 | xs = xs.to(device)
205 | xt = xt.to(device)
206 |
207 | flops, params = profile(model, inputs=(xs, xt,))
208 | flops, params = clever_format([flops, params], "%.3f")
209 | print(f"[INFO] flops: {flops}")
210 | print(f"[INFO] params: {params}")
211 | return flops, params
212 |
213 |
214 | if __name__ == "__main__":
215 |     from MDFP_Dual_Norm_AD import DualProjectionNet  # note: this module is not included in this repo; an equivalent DualProjectionNet is defined in MB-PFM-VGG.py
216 |
217 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
218 | ins = [64, 128, 256]
219 | outs = [256, 512, 1024]
220 | latents = [200, 400, 900]
221 | xs = [torch.rand([1, 64, 64, 64]), torch.rand([1, 128, 32, 32]), torch.rand([1, 256, 16, 16])]
222 | xt = [torch.rand([1, 256, 64, 64]), torch.rand([1, 512, 32, 32]), torch.rand([1, 1024, 16, 16])]
223 | for _in, _out, _latent, _xs, _xt in zip(ins, outs, latents, xs, xt):
224 | model1 = DualProjectionNet(in_dim=_in, out_dim=_out, latent_dim=_latent)
225 | flops, params = calculate_flops(device, model=model1, xs=_xs, xt=_xt)
226 |
227 |
--------------------------------------------------------------------------------
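
Usage sketch (not a repository file): `cal_pro_metric_new` expects lists of 1xHxW ground-truth masks and score maps, so synthetic data is enough to exercise it (importing `utils` pulls in `loguru` and `thop`, so the repo's dependencies must be installed):

```python
# Minimal sketch: PRO-AUC on synthetic masks/scores; defect pixels score highest,
# so the result should be close to 1.
import numpy as np
from utils import cal_pro_metric_new

rng = np.random.default_rng(0)
gt = np.zeros((4, 1, 64, 64), dtype=np.float32)
gt[:, :, 20:30, 20:30] = 1.0                                # one square defect per image
scores = rng.random((4, 1, 64, 64)).astype(np.float32) * 0.2
scores[:, :, 20:30, 20:30] += 0.8
print(cal_pro_metric_new(list(gt), list(scores), fpr_thresh=0.3, max_steps=200))
```
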
/MB-PFM-VGG.py:
--------------------------------------------------------------------------------
1 | '''
2 | Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping
3 | '''
4 | import shutil
5 |
6 | from torch.utils.tensorboard import SummaryWriter
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.models.resnet import resnet18, resnet34
11 | from torchvision.models.vgg import vgg16_bn, vgg19_bn
12 | import os
13 | import numpy as np
14 | from sklearn.metrics import roc_auc_score
15 | from utils import visualize, cal_pro_metric_new, set_seed
16 | from loguru import logger
17 | import argparse
18 |
19 |
20 | class PretrainedModel(nn.Module):
21 | def __init__(self, model_name):
22 | super(PretrainedModel, self).__init__()
23 | if "resnet" in model_name:
24 | model = eval(model_name)(pretrained=True)
25 | modules = list(model.children())
26 | self.block1 = nn.Sequential(*modules[0:4])
27 | self.block2 = modules[4]
28 | self.block3 = modules[5]
29 | self.block4 = modules[6]
30 | self.block5 = modules[7]
31 | elif "vgg" in model_name:
32 | model = eval(model_name)(pretrained=True)
33 |             features = list(model.features)  # local, so nn.Module.modules() is not shadowed
34 |             if model_name == "vgg16_bn":
35 |                 self.block1 = nn.Sequential(*features[0:14])
36 |                 self.block2 = nn.Sequential(*features[14:23])
37 |                 self.block3 = nn.Sequential(*features[23:33])
38 |                 self.block4 = nn.Sequential(*features[33:43])
39 |             else:
40 |                 self.block1 = nn.Sequential(*features[0:14])
41 |                 self.block2 = nn.Sequential(*features[14:26])
42 |                 self.block3 = nn.Sequential(*features[26:39])
43 |                 self.block4 = nn.Sequential(*features[39:52])
44 |
45 | else:
46 | raise NotImplementedError
47 |
48 | def forward(self, x):
49 | # B x 64 x 64 x 64
50 | out1 = self.block1(x)
51 | # B x 128 x 32 x 32
52 | out2 = self.block2(out1)
53 | # B x 256 x 16 x 16
54 | # 32x32x128
55 | out3 = self.block3(out2)
56 | # 16x16x256
57 | out4 = self.block4(out3)
58 | return {"out2": out2,
59 | "out3": out3,
60 | "out4": out4
61 | }
62 |
63 | # x = torch.rand([1, 3, 256, 256])
64 | # # T = ResNetS(model_name='resnet18', pretrained=True)
65 | # T = PretrainedModel(model_name='vgg16_bn')
66 | # y = T(x)
67 | #
68 | # for key in y.keys():
69 | # print(f"{key}: {y[key].shape}")
70 |
71 | class Conv_BN_Relu(nn.Module):
72 | def __init__(self, in_dim, out_dim, k=1, s=1, p=0, bn=True, relu=True):
73 | super(Conv_BN_Relu, self).__init__()
74 | self.conv = [
75 | nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p),
76 | ]
77 | if bn:
78 | self.conv.append(nn.BatchNorm2d(out_dim))
79 | if relu:
80 | self.conv.append(nn.ReLU(inplace=True))
81 |
82 | self.conv = nn.Sequential(*self.conv)
83 |
84 | def forward(self, x):
85 | return self.conv(x)
86 |
87 |
88 |
89 | class DualProjectionNet(nn.Module):
90 | def __init__(self, in_dim=512, out_dim=512, latent_dim=256):
91 | super(DualProjectionNet, self).__init__()
92 | self.encoder1 = nn.Sequential(*[
93 | Conv_BN_Relu(in_dim, in_dim//2+latent_dim),
94 | Conv_BN_Relu(in_dim//2+latent_dim, 2*latent_dim),
95 | # Conv_BN_Relu(2*latent_dim, latent_dim),
96 | ])
97 |
98 | self.shared_coder = Conv_BN_Relu(2*latent_dim, latent_dim, bn=False, relu=False)
99 |
100 | self.decoder1 = nn.Sequential(*[
101 | Conv_BN_Relu(latent_dim, 2*latent_dim),
102 | Conv_BN_Relu(2*latent_dim, out_dim//2+latent_dim),
103 | Conv_BN_Relu(out_dim//2+latent_dim, out_dim, bn=False, relu=False),
104 | ])
105 |
106 |
107 | self.encoder2 = nn.Sequential(*[
108 | Conv_BN_Relu(out_dim, out_dim // 2 + latent_dim),
109 | Conv_BN_Relu(out_dim // 2 + latent_dim, 2 * latent_dim),
110 | # Conv_BN_Relu(2 * latent_dim, latent_dim),
111 | ])
112 |
113 | self.decoder2 = nn.Sequential(*[
114 | Conv_BN_Relu(latent_dim, 2 * latent_dim),
115 | Conv_BN_Relu(2 * latent_dim, in_dim // 2 + latent_dim),
116 | Conv_BN_Relu(in_dim // 2 + latent_dim, in_dim, bn=False, relu=False),
117 | ])
118 |
119 |
120 | def forward(self, xs, xt):
121 | xt_hat = self.encoder1(xs)
122 | xt_hat = self.shared_coder(xt_hat)
123 | xt_hat = self.decoder1(xt_hat)
124 |
125 | xs_hat = self.encoder2(xt)
126 | xs_hat = self.shared_coder(xs_hat)
127 | xs_hat = self.decoder2(xs_hat)
128 |
129 | return xs_hat, xt_hat
130 |
131 |
132 | class DFP_AD(object):
133 | def __init__(self, type='vgg'):
134 | if type == "resnet":
135 | self.Agent1 = PretrainedModel(model_name="resnet18")
136 | self.Agent2 = PretrainedModel(model_name="resnet34")
137 |
138 | elif type == "vgg":
139 | self.Agent1 = PretrainedModel(model_name="vgg16_bn")
140 | self.Agent2 = PretrainedModel(model_name="vgg19_bn")
141 |
142 | def register(self, **kwargs):
143 | self.class_name = kwargs['class_name']
144 | self.device = kwargs['device']
145 | self.trainloader = kwargs['trainloader']
146 | self.testloader = kwargs['testloader']
147 |
148 |
149 | self.projector2 = DualProjectionNet(in_dim=256, out_dim=256, latent_dim=200)
150 | self.optimizer2 = torch.optim.Adam(self.projector2.parameters(), lr=kwargs["lr2"], weight_decay=kwargs["weight_decay"])
151 | self.projector3 = DualProjectionNet(in_dim=512, out_dim=512, latent_dim=400)
152 | self.optimizer3 = torch.optim.Adam(self.projector3.parameters(), lr=kwargs["lr3"], weight_decay=kwargs["weight_decay"])
153 | self.projector4 = DualProjectionNet(in_dim=512, out_dim=512, latent_dim=400)
154 | self.optimizer4 = torch.optim.Adam(self.projector4.parameters(), lr=kwargs["lr4"], weight_decay=kwargs["weight_decay"])
155 |
156 | self.Agent1.to(self.device).eval()
157 | self.Agent2.to(self.device).eval()
158 | self.projector2.to(self.device)
159 | self.projector3.to(self.device)
160 | self.projector4.to(self.device)
161 |
162 | self.save_root = "./result/MB-PFM-VGG_{}/".format(kwargs["seed"])
163 | os.makedirs(os.path.join(self.save_root, "ckpt"), exist_ok=True)
164 | self.ckpt2 = os.path.join(self.save_root, "ckpt/{}_2.pth".format(kwargs["class_name"]))
165 | self.ckpt3 = os.path.join(self.save_root, "ckpt/{}_3.pth".format(kwargs["class_name"]))
166 | self.ckpt4 = os.path.join(self.save_root, "ckpt/{}_4.pth".format(kwargs["class_name"]))
167 | os.makedirs(os.path.join(self.save_root, "tblogs"), exist_ok=True)
168 | self.tblog = os.path.join(self.save_root, "tblogs/{}".format(kwargs["class_name"]))
169 |
170 |
171 | def get_agent_out(self, x):
172 | out_a1 = self.Agent1(x)
173 | out_a2 = self.Agent2(x)
174 | for key in out_a2.keys():
175 | out_a1[key] = F.normalize(out_a1[key], p=2)
176 | out_a2[key] = F.normalize(out_a2[key], p=2)
177 | return out_a1, out_a2
178 |
179 | def train(self, epochs=100):
180 | if not os.path.exists(self.ckpt2):
181 | if os.path.exists(self.tblog):
182 | shutil.rmtree(self.tblog)
183 | os.makedirs(self.tblog, exist_ok=True)
184 | self.writer = SummaryWriter(log_dir=self.tblog)
185 | for ep in range(0, epochs):
186 | self.projector2.train()
187 | self.projector3.train()
188 | self.projector4.train()
189 | for i, (x, _, _, _) in enumerate(self.trainloader):
190 | x = x.to(self.device)
191 | out_a1, out_a2 = self.get_agent_out(x)
192 |
193 | # project_out2 = self.projector2(out_a1["out2"].detach())
194 | # loss2 = torch.mean((out_a2["out2"].detach() - project_out2) ** 2)
195 | project_out21, project_out22 = self.projector2(out_a1["out2"].detach(), out_a2["out2"].detach())
196 |                     loss21 = torch.mean((out_a1["out2"].detach() - project_out21) ** 2)  # the agent features are frozen targets, so detach them (as for loss3/loss4)
197 |                     loss22 = torch.mean((out_a2["out2"].detach() - project_out22) ** 2)
198 | loss2 = loss21 + loss22
199 | self.optimizer2.zero_grad()
200 | loss2.backward()
201 | self.optimizer2.step()
202 |
203 | project_out31, project_out32 = self.projector3(out_a1["out3"].detach(), out_a2["out3"].detach())
204 | loss31 = torch.mean((out_a1["out3"].detach() - project_out31) ** 2)
205 | loss32 = torch.mean((out_a2["out3"].detach() - project_out32) ** 2)
206 | loss3 = loss31 + loss32
207 | self.optimizer3.zero_grad()
208 | loss3.backward()
209 | self.optimizer3.step()
210 |
211 | project_out41, project_out42 = self.projector4(out_a1["out4"].detach(), out_a2["out4"].detach())
212 | loss41 = torch.mean((out_a1["out4"].detach() - project_out41) ** 2)
213 | loss42 = torch.mean((out_a2["out4"].detach() - project_out42) ** 2)
214 | loss4 = loss41 + loss42
215 | self.optimizer4.zero_grad()
216 | loss4.backward()
217 | self.optimizer4.step()
218 |
219 | print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss2: {loss2.item():.6f} | loss3: {loss3.item():.6f} | loss4: {loss4.item():.6f}")
220 | self.writer.add_scalar('Train/loss2', loss2.item(), ep*len(self.trainloader)+i)
221 | self.writer.add_scalar('Train/loss3', loss3.item(), ep*len(self.trainloader)+i)
222 | self.writer.add_scalar('Train/loss4', loss4.item(), ep*len(self.trainloader)+i)
223 | # print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss: {loss.item():.5f} | loss1: {loss1.item():.5f} | loss2: {loss2.item():.5f}")
224 | # self.writer.add_scalar('Train/loss_l2', loss1.item(), ep*len(self.trainloader)+i)
225 | # self.writer.add_scalar('Train/loss_norm_l2', loss2.item(), ep*len(self.trainloader)+i)
226 |
227 | torch.save(self.projector2.state_dict(), self.ckpt2)
228 | torch.save(self.projector3.state_dict(), self.ckpt3)
229 | torch.save(self.projector4.state_dict(), self.ckpt4)
230 | if ep % 10 == 0:
231 | metrix = self.test(cal_pro=False)
232 | logger.info(f"Epoch-{ep}, {self.class_name} | all: {metrix['all'][0]:.5f}, {metrix['all'][1]:.5f} | 2: {metrix['2'][0]:.5f}, {metrix['2'][1]:.5f}"
233 | f"| 3: {metrix['3'][0]:.5f}, {metrix['3'][1]:.5f} | 4: {metrix['4'][0]:.5f}, {metrix['4'][1]:.5f}")
234 | self.writer.add_scalar('Val/imge_auc2', metrix['2'][0], ep)
235 | self.writer.add_scalar('Val/pixel_auc2', metrix['2'][1], ep)
236 | self.writer.add_scalar('Val/imge_auc3', metrix['3'][0], ep)
237 | self.writer.add_scalar('Val/pixel_auc3', metrix['3'][1], ep)
238 | self.writer.add_scalar('Val/imge_auc4', metrix['4'][0], ep)
239 | self.writer.add_scalar('Val/pixel_auc4', metrix['4'][1], ep)
240 | self.writer.add_scalar('Val/imge_auc', metrix['all'][0], ep)
241 | self.writer.add_scalar('Val/pixel_auc', metrix['all'][1], ep)
242 | self.writer.close()
243 | else:
244 | pass
245 |
246 | def test(self, cal_pro=False):
247 | self.load_project_model()
248 | self.projector2.eval()
249 | self.projector3.eval()
250 | self.projector4.eval()
251 | with torch.no_grad():
252 |
253 | test_y_list = []
254 | test_mask_list = []
255 | test_img_list = []
256 | test_img_name_list = []
257 | # pixel-level
258 | score_map_list = []
259 | score_list = []
260 |
261 | score2_map_list = []
262 | score2_list = []
263 | score3_map_list = []
264 | score3_list = []
265 | score4_map_list = []
266 | score4_list = []
267 |
268 | for x, y, mask, name in self.testloader:
269 | test_y_list.extend(y.detach().cpu().numpy())
270 | test_mask_list.extend(mask.detach().cpu().numpy())
271 | test_img_list.extend(x.detach().cpu().numpy())
272 | test_img_name_list.extend(name)
273 |
274 | x = x.to(self.device)
275 | _, _, H, W = x.shape
276 | out_a1, out_a2 = self.get_agent_out(x)
277 |
278 | project_out21, project_out22 = self.projector2(out_a1["out2"], out_a2["out2"])
279 | loss21_map = torch.sum((out_a1["out2"] - project_out21) ** 2, dim=1, keepdim=True)
280 | loss22_map = torch.sum((out_a2["out2"] - project_out22) ** 2, dim=1, keepdim=True)
281 | loss2_map = (loss21_map + loss22_map) / 2.0
282 | score2_map = F.interpolate(loss2_map, size=(H, W), mode='bilinear', align_corners=False)
283 | score2_map = score2_map.cpu().detach().numpy()
284 | score2_map_list.extend(score2_map)
285 | score2_list.extend(np.squeeze(np.max(np.max(score2_map, axis=2), axis=2), 1))
286 |
287 | project_out31, project_out32 = self.projector3(out_a1["out3"], out_a2["out3"])
288 | loss31_map = torch.sum((out_a1["out3"] - project_out31) ** 2, dim=1, keepdim=True)
289 | loss32_map = torch.sum((out_a2["out3"] - project_out32) ** 2, dim=1, keepdim=True)
290 | loss3_map = (loss31_map + loss32_map) / 2.0
291 | score3_map = F.interpolate(loss3_map, size=(H, W), mode='bilinear', align_corners=False)
292 | score3_map = score3_map.cpu().detach().numpy()
293 | score3_map_list.extend(score3_map)
294 | score3_list.extend(np.squeeze(np.max(np.max(score3_map, axis=2), axis=2), 1))
295 |
296 |
297 | project_out41, project_out42 = self.projector4(out_a1["out4"], out_a2["out4"])
298 | loss41_map = torch.sum((out_a1["out4"] - project_out41) ** 2, dim=1, keepdim=True)
299 | loss42_map = torch.sum((out_a2["out4"] - project_out42) ** 2, dim=1, keepdim=True)
300 | loss4_map = (loss41_map + loss42_map) / 2.0
301 | score4_map = F.interpolate(loss4_map, size=(H, W), mode='bilinear', align_corners=False)
302 | score4_map = score4_map.cpu().detach().numpy()
303 | score4_map_list.extend(score4_map)
304 | score4_list.extend(np.squeeze(np.max(np.max(score4_map, axis=2), axis=2), 1))
305 |
306 | score_map = (score4_map + score3_map + score2_map) / 3
307 | score_map_list.extend(score_map)
308 | score_list.extend(np.squeeze(np.max(np.max(score_map, axis=2), axis=2), 1))
309 |
310 | visualize(test_img_list, test_mask_list, score_map_list, test_img_name_list, self.class_name,
311 | f"{self.save_root}image/", 10000)
312 | # ROCAUC
313 | # imge_auc2, pixel_auc2, pixel_pro2 = self.cal_auc(score2_list, score2_map_list, test_y_list, test_mask_list)
314 | # imge_auc3, pixel_auc3, pixel_pro3 = self.cal_auc(score3_list, score3_map_list, test_y_list, test_mask_list)
315 | # imge_auc4, pixel_auc4, pixel_pro4 = self.cal_auc(score4_list, score4_map_list, test_y_list, test_mask_list)
316 | imge_auc, pixel_auc, pixel_pro = self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list)
317 | # print(f"pixel AUC: {pixel_level_ROCAUC:.5f}")
318 | # metrix = {"2": [imge_auc2, pixel_auc2, pixel_pro2], "3": [imge_auc3, pixel_auc3, pixel_pro3],"4":[imge_auc4,
319 | # pixel_auc4, pixel_pro4], "all": [imge_auc, pixel_auc, pixel_pro]}
320 |         metrix = {"all": [imge_auc, pixel_auc, pixel_pro]}
321 | return metrix
322 |
323 |     def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list, cal_pro=True):
324 | flatten_y_list = np.array(test_y_list).ravel()
325 | flatten_score_list = np.array(score_list).ravel()
326 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list)
327 |
328 | flatten_mask_list = np.concatenate(test_mask_list).ravel()
329 | flatten_score_map_list = np.concatenate(score_map_list).ravel()
330 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list)
331 |         pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) if cal_pro else 0.0  # PRO is costly; honor cal_pro
332 | return image_level_ROCAUC, pixel_level_ROCAUC, pro_auc_score
333 |
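    # Note: cal_auc ranks per-image max scores against the image labels and
    # every pixel of every score map against the ground-truth masks. Toy
    # example of the image-level term (values illustrative):
    #   roc_auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]) == 0.75
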
334 | def load_project_model(self):
335 | self.projector2.load_state_dict(torch.load(self.ckpt2))
336 | self.projector3.load_state_dict(torch.load(self.ckpt3))
337 | self.projector4.load_state_dict(torch.load(self.ckpt4))
338 |
339 |
340 |
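# A minimal sketch (editorial addition, not part of the original file) of the
# score fusion performed in test() above; the helper name
# `_fuse_score_maps_sketch` is ours and it relies on the file's torch/F
# imports. Each per-level error map (B, 1, h, w) is upsampled to the input
# size, the levels are averaged into the pixel-level score map, and the
# image-level score is the map's maximum response.
def _fuse_score_maps_sketch(loss_maps, size):
    maps = [F.interpolate(m, size=size, mode='bilinear', align_corners=False)
            for m in loss_maps]
    fused = torch.stack(maps, dim=0).mean(dim=0)       # (B, 1, H, W) pixel scores
    image_scores = fused.flatten(1).max(dim=1).values  # (B,) per-image scores
    return fused, image_scores
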
341 | def parse_args():
342 |     parser = argparse.ArgumentParser('MB-PFM')  # prog name; 'STFPM_Center' was a leftover from another script
343 | parser.add_argument("--seed", type=int, default=888)
344 | parser.add_argument("--gpu_id", type=str, default="0")
345 |     parser.add_argument("--train", action='store_true')  # flag: bare --train enables training ('type=bool' mis-parsed any given value as True)
346 |
347 |     parser.add_argument("--data_trans", type=str, default='imagenet', choices=['naive', 'imagenet'])
348 |
349 |     parser.add_argument("--loss_type", type=str, default='l2norm+l2',
350 |                         choices=['l2norm+l2', 'l2', 'l1', 'cosine', 'l2+cosine'])
351 |
352 | parser.add_argument("--model_name", type=str, default='vgg16', choices=['vgg16', 'resnet18'])
353 | parser.add_argument("--epochs", type=int, default=200)
354 | parser.add_argument("--batch_size", type=int, default=8) # 6 or 20 for train
355 | parser.add_argument("--lr2", type=float, default=3e-3)
356 | parser.add_argument("--lr3", type=float, default=3e-4)
357 | parser.add_argument("--lr4", type=float, default=3e-4)
358 | parser.add_argument("--weight_decay", type=float, default=1e-5)
359 |
360 | parser.add_argument("--latent_dim", type=int, default=200)
361 |
362 | parser.add_argument("--data_root", type=str, default="D:/Dataset/mvtec_anomaly_detection/")
363 | parser.add_argument("--resize", type=int, default=256)
364 |
365 | parser.add_argument("--post_smooth", type=int, default=0)
366 | args = parser.parse_args()
367 | return args
368 |
369 |
370 | if __name__ == "__main__":
371 |
372 | args = parse_args()
373 | set_seed(args.seed)
374 | logger.add(
375 | f'./result/MB-PFM-VGG_{args.seed}/logger-{args.data_trans}-{args.loss_type}-{args.resize}-{args.model_name}.txt',
376 | rotation="200 MB",
377 | backtrace=True,
378 | diagnose=True)
379 | logger.info(str(args))
380 |
381 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
382 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
383 |
384 | from dataset.mvtec import MVTecDataset, MVTec_CLASS_NAMES
385 | from torch.utils.data import DataLoader
386 | from torchvision import transforms as T
387 | from PIL import Image
388 |
389 |     if args.data_trans == 'naive':
390 |         trans_x = T.Compose([T.Resize(args.resize, Image.LANCZOS),  # ANTIALIAS is a removed alias of LANCZOS
391 |                              T.ToTensor()])
392 |     else:
393 |         trans_x = T.Compose([T.Resize(args.resize, Image.LANCZOS),
394 |                              T.ToTensor(),
395 |                              T.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
396 |                                          std=[0.229, 0.224, 0.225])])
397 |
398 | image_aucs = []
399 | pixel_aucs = []
400 | pro_30s = []
401 | for class_name in MVTec_CLASS_NAMES:
402 | torch.cuda.empty_cache()
403 | trainset = MVTecDataset(root_path=args.data_root, is_train=True, class_name=class_name, resize=args.resize,
404 | trans=trans_x)
405 | trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True,
406 | num_workers=4)
407 |
408 | testset = MVTecDataset(root_path=args.data_root, is_train=False, class_name=class_name, resize=args.resize,
409 | trans=trans_x)
410 | testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
411 |
412 | model = DFP_AD(type="vgg")
413 | model.register(class_name=class_name, trainloader=trainloader, testloader=testloader, loss_type=args.loss_type,
414 | data_trans=args.data_trans, size=args.resize, device=device,
415 | latent_dim=args.latent_dim,
416 | lr2=args.lr2, lr3=args.lr3, lr4=args.lr4, weight_decay=args.weight_decay,
417 | seed=args.seed)
418 | if args.train:
419 | model.train(epochs=args.epochs)
420 |         # else:
421 |         #     model.load_project_model()
422 | metrix = model.test(cal_pro=True)
423 | image_aucs.append(metrix["all"][0])
424 | pixel_aucs.append(metrix["all"][1])
425 | pro_30s.append(metrix["all"][2])
426 | logger.info(f"{class_name}, image auc: {metrix['all'][0]:.5f}, pixel auc: {metrix['all'][1]:.5f}, pixel pro0.3: {metrix['all'][2]:.5f}")
427 |
428 | i_auc = np.mean(np.array(image_aucs))
429 | p_auc = np.mean(np.array(pixel_aucs))
430 | pro_auc = np.mean(np.array(pro_30s))
431 |     logger.info(f"total, image AUC: {i_auc:.5f} | pixel AUC: {p_auc:.5f} | pixel PRO0.3: {pro_auc:.5f}")
432 |
--------------------------------------------------------------------------------
/MB-PFM-ResNet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Unsupervised Image Anomaly Detection and Segmentation Based on Pre-trained Feature Mapping
3 | '''
4 | import shutil
5 |
6 | from torch.utils.tensorboard import SummaryWriter
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101
11 | from torchvision.models.vgg import vgg16_bn, vgg19_bn
12 | import os
13 | import numpy as np
14 | from sklearn.metrics import roc_auc_score
15 | from utils import visualize, set_seed, cal_pro_metric_new
16 | from loguru import logger
17 | import argparse
18 | import time
19 | from scipy.ndimage import gaussian_filter
20 | import matplotlib.pyplot as plt
21 |
22 |
23 |
24 | class PretrainedModel(nn.Module):
25 | def __init__(self, model_name):
26 | super(PretrainedModel, self).__init__()
27 | if "resnet" in model_name:
28 | model = eval(model_name)(pretrained=True)
29 | modules = list(model.children())
30 | self.block1 = nn.Sequential(*modules[0:4])
31 | self.block2 = modules[4]
32 | self.block3 = modules[5]
33 | self.block4 = modules[6]
34 | self.block5 = modules[7]
35 |         elif "vgg" in model_name:
36 |             model = eval(model_name)(pretrained=True)
37 |             # fix: slice the pretrained feature layers; the original indexed
38 |             # `self.modules`, which is an nn.Module method, not a layer list
39 |             modules = list(model.features.children())
40 |             # block boundaries for vgg16_bn vs. vgg19_bn (original indices kept)
41 |             cuts = [0, 14, 23, 33, 43] if model_name == "vgg16_bn" else [0, 14, 26, 39, 52]
42 |             self.block1 = nn.Sequential(*modules[cuts[0]:cuts[1]])
43 |             self.block2 = nn.Sequential(*modules[cuts[1]:cuts[2]])
44 |             self.block3 = nn.Sequential(*modules[cuts[2]:cuts[3]])
45 |             self.block4 = nn.Sequential(*modules[cuts[3]:cuts[4]])
46 |         else:
47 |             raise NotImplementedError
48 |
49 | def forward(self, x):
50 |         # stem (conv1+bn+relu+maxpool) for ResNet agents: B x 64 x H/4 x W/4
51 |         out1 = self.block1(x)
52 |         # layer1, stride 4: 64 ch for resnet18/34, 256 for resnet50/101
53 |         out2 = self.block2(out1)
54 |         # layer2, stride 8: 128 ch for resnet18/34,
55 |         # 512 for resnet50/101
56 |         out3 = self.block3(out2)
57 |         # layer3, stride 16: 256 ch for resnet18/34, 1024 for resnet50/101
58 |         out4 = self.block4(out3)
59 | return {"out2": out2,
60 | "out3": out3,
61 | "out4": out4
62 | }
63 |
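# A minimal shape-check sketch (editorial addition, not part of the original
# file); the function name is ours and it is never called. With resnet18 and a
# 256x256 input, out2/out3/out4 arrive at strides 4/8/16 with 64/128/256
# channels -- exactly the indim list DFP_AD uses below.
def _feature_shape_sketch():
    model = PretrainedModel("resnet18").eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 256))
    for key, value in feats.items():
        print(key, tuple(value.shape))
    # out2 (1, 64, 64, 64) | out3 (1, 128, 32, 32) | out4 (1, 256, 16, 16)
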
64 | class Conv_BN_Relu(nn.Module):
65 | def __init__(self, in_dim, out_dim, k=1, s=1, p=0, bn=True, relu=True):
66 | super(Conv_BN_Relu, self).__init__()
67 | self.conv = [
68 | nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p),
69 | ]
70 | if bn:
71 | self.conv.append(nn.BatchNorm2d(out_dim))
72 | if relu:
73 | self.conv.append(nn.ReLU(inplace=True))
74 |
75 | self.conv = nn.Sequential(*self.conv)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
79 |
80 |
81 | class DualProjectionNet(nn.Module):
82 | def __init__(self, in_dim=512, out_dim=512, latent_dim=256):
83 | super(DualProjectionNet, self).__init__()
84 | self.encoder1 = nn.Sequential(*[
85 | Conv_BN_Relu(in_dim, in_dim//2+latent_dim),
86 | Conv_BN_Relu(in_dim//2+latent_dim, 2*latent_dim),
87 | # Conv_BN_Relu(2*latent_dim, latent_dim),
88 | ])
89 |
90 | self.shared_coder = Conv_BN_Relu(2*latent_dim, latent_dim, bn=False, relu=False)
91 |
92 | self.decoder1 = nn.Sequential(*[
93 | Conv_BN_Relu(latent_dim, 2*latent_dim),
94 | Conv_BN_Relu(2*latent_dim, out_dim//2+latent_dim),
95 | Conv_BN_Relu(out_dim//2+latent_dim, out_dim, bn=False, relu=False),
96 | ])
97 |
98 |
99 | self.encoder2 = nn.Sequential(*[
100 | Conv_BN_Relu(out_dim, out_dim // 2 + latent_dim),
101 | Conv_BN_Relu(out_dim // 2 + latent_dim, 2 * latent_dim),
102 | # Conv_BN_Relu(2 * latent_dim, latent_dim),
103 | ])
104 |
105 | self.decoder2 = nn.Sequential(*[
106 | Conv_BN_Relu(latent_dim, 2 * latent_dim),
107 | Conv_BN_Relu(2 * latent_dim, in_dim // 2 + latent_dim),
108 | Conv_BN_Relu(in_dim // 2 + latent_dim, in_dim, bn=False, relu=False),
109 | ])
110 |
111 |
112 | def forward(self, xs, xt):
113 | xt_hat = self.encoder1(xs)
114 | xt_hat = self.shared_coder(xt_hat)
115 | xt_hat = self.decoder1(xt_hat)
116 |
117 | xs_hat = self.encoder2(xt)
118 | xs_hat = self.shared_coder(xs_hat)
119 | xs_hat = self.decoder2(xs_hat)
120 |
121 | return xs_hat, xt_hat
122 |
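# A minimal sketch (editorial addition, not part of the original file) of the
# dual mapping: the first output reconstructs the first input from the second
# and vice versa, through the shared 1x1 bottleneck; the name and the toy
# dimensions are ours.
def _dual_projection_sketch():
    net = DualProjectionNet(in_dim=64, out_dim=256, latent_dim=32).eval()
    xs = torch.randn(2, 64, 32, 32)    # agent-1 features (in_dim channels)
    xt = torch.randn(2, 256, 32, 32)   # agent-2 features (out_dim channels)
    with torch.no_grad():
        xs_hat, xt_hat = net(xs, xt)   # xs_hat is predicted from xt, xt_hat from xs
    assert xs_hat.shape == xs.shape and xt_hat.shape == xt.shape
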
123 |
124 | class DFP_AD(object):
125 | def __init__(self, agent_S='resnet50', agent_T="resnet101"):
126 | self.s_name = agent_S
127 | self.t_name = agent_T
128 | if agent_S == "resnet18" or agent_S == "resnet34":
129 | self.Agent1 = PretrainedModel(model_name=agent_S)
130 | self.indim = [64, 128, 256]
131 | # self.outdim = [50, 100, 200]
132 | elif agent_S == "resnet50":
133 | self.Agent1 = PretrainedModel(model_name=agent_S)
134 | self.indim = [256, 512, 1024]
135 | # self.Agent2 = PretrainedModel(model_name="resnet34")
136 | if agent_T == "resnet50" or agent_T == "resnet101":
137 | # self.Agent1 = PretrainedModel(model_name="vgg16")
138 | self.Agent2 = PretrainedModel(model_name=agent_T)
139 | self.outdim = [256, 512, 1024]
140 | self.latent_dim = [200, 400, 900]
141 |
142 | def register(self, **kwargs):
143 | self.class_name = kwargs['class_name']
144 | self.device = kwargs['device']
145 | self.trainloader = kwargs['trainloader']
146 | self.testloader = kwargs['testloader']
147 |
148 | self.projector2 = DualProjectionNet(in_dim=self.indim[0], out_dim=self.outdim[0], latent_dim=self.latent_dim[0])
149 | self.optimizer2 = torch.optim.Adam(self.projector2.parameters(), lr=kwargs["lr2"], weight_decay=kwargs["weight_decay"])
150 | self.projector3 = DualProjectionNet(in_dim=self.indim[1], out_dim=self.outdim[1], latent_dim=self.latent_dim[1])
151 | self.optimizer3 = torch.optim.Adam(self.projector3.parameters(), lr=kwargs["lr3"], weight_decay=kwargs["weight_decay"])
152 | self.projector4 = DualProjectionNet(in_dim=self.indim[2], out_dim=self.outdim[2], latent_dim=self.latent_dim[2])
153 | self.optimizer4 = torch.optim.Adam(self.projector4.parameters(), lr=kwargs["lr4"], weight_decay=kwargs["weight_decay"])
154 |
155 | self.Agent1.to(self.device).eval()
156 | self.Agent2.to(self.device).eval()
157 |
158 | self.projector2.to(self.device)
159 | self.projector3.to(self.device)
160 | self.projector4.to(self.device)
161 |
162 | self.save_root = "./result/MB-PFM_{}-{}_{}/".format(self.s_name, self.t_name, kwargs["seed"])
163 | os.makedirs(os.path.join(self.save_root, "ckpt"), exist_ok=True)
164 | self.ckpt2 = os.path.join(self.save_root, "ckpt/{}_2.pth".format(kwargs["class_name"]))
165 | self.ckpt3 = os.path.join(self.save_root, "ckpt/{}_3.pth".format(kwargs["class_name"]))
166 | self.ckpt4 = os.path.join(self.save_root, "ckpt/{}_4.pth".format(kwargs["class_name"]))
167 | os.makedirs(os.path.join(self.save_root, "tblogs"), exist_ok=True)
168 | self.tblog = os.path.join(self.save_root, "tblogs/{}".format(kwargs["class_name"]))
169 |
170 |
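    # Note: register() builds one DualProjectionNet plus one Adam optimizer per
    # feature level (with level-specific learning rates lr2/lr3/lr4), so the
    # three levels are trained independently while both backbone agents stay
    # frozen in eval mode.
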
171 | def get_agent_out(self, x):
172 | out_a1 = self.Agent1(x)
173 | out_a2 = self.Agent2(x)
174 | for key in out_a2.keys():
175 | out_a1[key] = F.normalize(out_a1[key], p=2)
176 | out_a2[key] = F.normalize(out_a2[key], p=2)
177 | return out_a1, out_a2
178 |
179 |
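    # Note: get_agent_out() L2-normalizes each feature map along the channel
    # dimension (F.normalize defaults to dim=1), so for unit vectors a and b
    # the squared errors used below satisfy ||a - b||^2 = 2 - 2*cos(a, b),
    # i.e. they behave like per-position cosine distances.
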
180 | def train(self, epochs=100):
181 | if not os.path.exists(self.ckpt2):
182 | if os.path.exists(self.tblog):
183 | shutil.rmtree(self.tblog)
184 | os.makedirs(self.tblog, exist_ok=True)
185 | self.writer = SummaryWriter(log_dir=self.tblog)
186 | for ep in range(0, epochs):
187 | self.projector2.train()
188 | self.projector3.train()
189 | self.projector4.train()
190 | for i, (x, _, _, _) in enumerate(self.trainloader):
191 | x = x.to(self.device)
192 | out_a1, out_a2 = self.get_agent_out(x)
193 |
194 | # project_out2 = self.projector2(out_a1["out2"].detach())
195 | # loss2 = torch.mean((out_a2["out2"].detach() - project_out2) ** 2)
196 | project_out21, project_out22 = self.projector2(out_a1["out2"].detach(), out_a2["out2"].detach())
197 | loss21 = torch.mean((out_a1["out2"] - project_out21) ** 2)
198 | loss22 = torch.mean((out_a2["out2"] - project_out22) ** 2)
199 | loss2 = loss21 + loss22
200 | self.optimizer2.zero_grad()
201 | loss2.backward()
202 | self.optimizer2.step()
203 |
204 | project_out31, project_out32 = self.projector3(out_a1["out3"].detach(), out_a2["out3"].detach())
205 | loss31 = torch.mean((out_a1["out3"].detach() - project_out31) ** 2)
206 | loss32 = torch.mean((out_a2["out3"].detach() - project_out32) ** 2)
207 | loss3 = loss31 + loss32
208 | self.optimizer3.zero_grad()
209 | loss3.backward()
210 | self.optimizer3.step()
211 |
212 | project_out41, project_out42 = self.projector4(out_a1["out4"].detach(), out_a2["out4"].detach())
213 | loss41 = torch.mean((out_a1["out4"].detach() - project_out41) ** 2)
214 | loss42 = torch.mean((out_a2["out4"].detach() - project_out42) ** 2)
215 | loss4 = loss41 + loss42
216 | self.optimizer4.zero_grad()
217 | loss4.backward()
218 | self.optimizer4.step()
219 |
220 | print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss2: {loss2.item():.6f} | loss3: {loss3.item():.6f} | loss4: {loss4.item():.6f}")
221 | self.writer.add_scalar('Train/loss2', loss2.item(), ep*len(self.trainloader)+i)
222 | self.writer.add_scalar('Train/loss3', loss3.item(), ep*len(self.trainloader)+i)
223 | self.writer.add_scalar('Train/loss4', loss4.item(), ep*len(self.trainloader)+i)
224 | # print(f"Epoch-{ep}-Step-{i}, {self.class_name} | loss: {loss.item():.5f} | loss1: {loss1.item():.5f} | loss2: {loss2.item():.5f}")
225 | # self.writer.add_scalar('Train/loss_l2', loss1.item(), ep*len(self.trainloader)+i)
226 | # self.writer.add_scalar('Train/loss_norm_l2', loss2.item(), ep*len(self.trainloader)+i)
227 |
228 | torch.save(self.projector2.state_dict(), self.ckpt2)
229 | torch.save(self.projector3.state_dict(), self.ckpt3)
230 | torch.save(self.projector4.state_dict(), self.ckpt4)
231 | if ep % 10 == 0:
232 | metrix = self.test(cal_pro=False)
233 |                     # test() returns only the fused ("all") metrics (the
234 |                     # per-level entries are commented out), so log just those
235 |                     logger.info(f"Epoch-{ep}, {self.class_name} | image auc: {metrix['all'][0]:.5f}, "
236 |                                 f"pixel auc: {metrix['all'][1]:.5f}")
237 |                     self.writer.add_scalar('Val/imge_auc', metrix['all'][0], ep)
238 |                     self.writer.add_scalar('Val/pixel_auc', metrix['all'][1], ep)
243 | self.writer.close()
244 | else:
245 | pass
246 |
247 |
248 | def statistic_var(self, c=False):
249 | if c:
250 | self.var21 = 0
251 | self.var22 = 0
252 | self.var31 = 0
253 | self.var32 = 0
254 | self.var41 = 0
255 | self.var42 = 0
256 | with torch.no_grad():
257 | for i, (x, _, _, _) in enumerate(self.trainloader):
258 | torch.cuda.empty_cache()
259 | x = x.to(self.device)
260 | out_a1, out_a2 = self.get_agent_out(x)
261 | project_out21, project_out22 = self.projector2(out_a1["out2"], out_a2["out2"])
262 | var21 = (out_a1["out2"] - project_out21) ** 2
263 | var22 = (out_a2["out2"] - project_out22) ** 2
264 | self.var21 += torch.mean(var21, dim=0, keepdim=True)
265 | self.var22 += torch.mean(var22, dim=0, keepdim=True)
266 |
267 | project_out31, project_out32 = self.projector3(out_a1["out3"], out_a2["out3"])
268 | var31 = (out_a1["out3"] - project_out31) ** 2
269 | var32 = (out_a2["out3"] - project_out32) ** 2
270 | self.var31 += torch.mean(var31, dim=0, keepdim=True)
271 | self.var32 += torch.mean(var32, dim=0, keepdim=True)
272 |
273 | project_out41, project_out42 = self.projector4(out_a1["out4"], out_a2["out4"])
274 | var41 = (out_a1["out4"] - project_out41) ** 2
275 | var42 = (out_a2["out4"] - project_out42) ** 2
276 | self.var41 += torch.mean(var41, dim=0, keepdim=True)
277 | self.var42 += torch.mean(var42, dim=0, keepdim=True)
278 |
279 |                 # fix: average all six accumulators (var42 was previously skipped)
280 |                 n_batches = len(self.trainloader)
281 |                 self.var21, self.var22 = self.var21 / n_batches, self.var22 / n_batches
282 |                 self.var31, self.var32 = self.var31 / n_batches, self.var32 / n_batches
283 |                 self.var41, self.var42 = self.var41 / n_batches, self.var42 / n_batches
284 | else:
285 | self.var21 = 1
286 | self.var22 = 1
287 | self.var31 = 1
288 | self.var32 = 1
289 | self.var41 = 1
290 | self.var42 = 1
291 |
292 |
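    # Note: with c=True, statistic_var() turns the averaged squared training
    # errors into per-channel variance estimates that rescale the test error
    # maps (a diagonal, Mahalanobis-style normalization); with the default
    # c=False every self.var* is 1 and the maps are used unscaled, which is
    # what the call in test() below relies on.
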
293 | def test(self, cal_pro=False):
294 | self.load_project_model()
295 | self.projector2.eval()
296 | self.projector3.eval()
297 | self.projector4.eval()
298 |
299 | self.statistic_var()
300 |
301 | with torch.no_grad():
302 |
303 | test_y_list = []
304 | test_mask_list = []
305 | test_img_list = []
306 | test_img_name_list = []
307 | # pixel-level
308 | score_map_list = []
309 | score_list = []
310 |
311 | score2_map_list = []
312 | score2_list = []
313 | score3_map_list = []
314 | score3_list = []
315 | score4_map_list = []
316 | score4_list = []
317 |
318 | start_t = time.time()
319 | for x, y, mask, name in self.testloader:
320 | test_y_list.extend(y.detach().cpu().numpy())
321 | test_mask_list.extend(mask.detach().cpu().numpy())
322 | test_img_list.extend(x.detach().cpu().numpy())
323 | test_img_name_list.extend(name)
324 |
325 | x = x.to(self.device)
326 | _, _, H, W = x.shape
327 | out_a1, out_a2 = self.get_agent_out(x)
328 |
329 | project_out21, project_out22 = self.projector2(out_a1["out2"], out_a2["out2"])
330 | loss21_map = torch.sum((out_a1["out2"] - project_out21) ** 2 / self.var21, dim=1, keepdim=True)
331 | loss22_map = torch.sum((out_a2["out2"] - project_out22) ** 2 / self.var22, dim=1, keepdim=True)
332 |
333 | loss2_map = (loss21_map + loss22_map) / 2.0
334 | score2_map = F.interpolate(loss2_map, size=(H, W), mode='bilinear', align_corners=False)
335 | score2_map = score2_map.cpu().detach().numpy()
336 | score2_map_list.extend(score2_map)
337 | score2_list.extend(np.squeeze(np.max(np.max(score2_map, axis=2), axis=2), 1))
338 |
339 | project_out31, project_out32 = self.projector3(out_a1["out3"], out_a2["out3"])
340 | loss31_map = torch.sum((out_a1["out3"] - project_out31) ** 2 / self.var31, dim=1, keepdim=True)
341 | loss32_map = torch.sum((out_a2["out3"] - project_out32) ** 2 / self.var32, dim=1, keepdim=True)
342 |
343 | loss3_map = (loss31_map + loss32_map) / 2.0
344 | score3_map = F.interpolate(loss3_map, size=(H, W), mode='bilinear', align_corners=False)
345 | score3_map = score3_map.cpu().detach().numpy()
346 | score3_map_list.extend(score3_map)
347 | score3_list.extend(np.squeeze(np.max(np.max(score3_map, axis=2), axis=2), 1))
348 |
349 |
350 | project_out41, project_out42 = self.projector4(out_a1["out4"], out_a2["out4"])
351 | loss41_map = torch.sum((out_a1["out4"] - project_out41) ** 2 / self.var41, dim=1, keepdim=True)
352 | loss42_map = torch.sum((out_a2["out4"] - project_out42) ** 2 / self.var42, dim=1, keepdim=True)
353 |
354 | loss4_map = (loss41_map + loss42_map) / 2.0
355 | score4_map = F.interpolate(loss4_map, size=(H, W), mode='bilinear', align_corners=False)
356 | score4_map = score4_map.cpu().detach().numpy()
357 | score4_map_list.extend(score4_map)
358 | score4_list.extend(np.squeeze(np.max(np.max(score4_map, axis=2), axis=2), 1))
359 |
360 | score_map = (score4_map + score3_map + score2_map) / 3
361 | # score_map = gaussian_filter(score_map.squeeze(), sigma=4)
362 |
363 | score_map_list.extend(score_map)
364 | # score_list.extend(np.squeeze(np.max(np.max(score_map, axis=2), axis=2), 1))
365 | score_map = np.reshape(score_map, (score_map.shape[0], -1))
366 | score_list.extend(np.max(score_map, 1))
367 |
368 | end_t = time.time()
369 | t_per_imge = end_t - start_t
370 | t_per_imge = t_per_imge / len(score_list)
371 |
372 | visualize(test_img_list, test_mask_list, score_map_list, test_img_name_list, self.class_name,
373 | f"{self.save_root}image/", 10000)
374 |
375 | # ROCAUC
376 | # imge_auc2, pixel_auc2, pixel_pro2 = self.cal_auc(score2_list, score2_map_list, test_y_list, test_mask_list)
377 | # imge_auc3, pixel_auc3, pixel_pro3 = self.cal_auc(score3_list, score3_map_list, test_y_list, test_mask_list)
378 | # imge_auc4, pixel_auc4, pixel_pro4 = self.cal_auc(score4_list, score4_map_list, test_y_list, test_mask_list)
379 |             imge_auc, pixel_auc, pixel_pro = self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list, cal_pro=cal_pro)
380 | # print(f"pixel AUC: {pixel_level_ROCAUC:.5f}")
381 | # metrix = {"2": [imge_auc2, pixel_auc2, pixel_pro2], "3": [imge_auc3, pixel_auc3, pixel_pro3],"4":[imge_auc4,
382 | # pixel_auc4, pixel_pro4], "all": [imge_auc, pixel_auc, pixel_pro], "time": t_per_imge}
383 | metrix = {"all": [imge_auc, pixel_auc, pixel_pro], "time": t_per_imge}
384 |
385 |
386 | # test_y_list = np.array(test_y_list)
387 | # score_list = np.array(score_list)
388 | # p_index = test_y_list == 1
389 | # n_index = test_y_list == 0
390 |
391 | # p_score = score_list[p_index]
392 | # n_score = score_list[n_index]
393 | # data = [n_score, p_score]
394 | # np.save(os.path.join(self.save_root, f"{self.class_name}_image.npy"), (p_score, n_score))
395 | # # import seaborn as sns
396 | # # sns.boxplot(data=data)
397 | # # sns.histplot(p_score, kde=False, color="r")
398 | # # sns.histplot(n_score, kde=False, color="b")
399 |
400 | # # plt.show()
401 | return metrix
402 |
403 |     def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list, cal_pro=True):
404 | flatten_y_list = np.array(test_y_list).ravel()
405 | flatten_score_list = np.array(score_list).ravel()
406 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list)
407 |
408 | flatten_mask_list = np.concatenate(test_mask_list).ravel()
409 | flatten_score_map_list = np.concatenate(score_map_list).ravel()
410 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list)
411 |         # PRO is costly; honor the cal_pro flag and skip it for quick evaluations
412 |         pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) if cal_pro else 0.0
413 | return image_level_ROCAUC, pixel_level_ROCAUC, pro_auc_score
414 |
415 | def load_project_model(self):
416 | self.projector2.load_state_dict(torch.load(self.ckpt2))
417 | self.projector3.load_state_dict(torch.load(self.ckpt3))
418 | self.projector4.load_state_dict(torch.load(self.ckpt4))
419 |
420 | def center_crop(img, dim):
422 | """Returns center cropped image
423 | Args:
424 | img: image to be center cropped
425 | dim: dimensions (width, height) to be cropped
426 | """
427 | width, height = img.shape[1], img.shape[0]
428 |
429 | # process crop width and height for max available dimension
430 | crop_width = dim[0] if dim[0]