├── LICENSE ├── README.md ├── computer_flops_params.py ├── computer_time.py ├── config.py ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── dagm.cpython-36.pyc │ ├── dataset.cpython-36.pyc │ ├── dataset.cpython-38.pyc │ ├── logodataset.cpython-38.pyc │ ├── mvtec.cpython-36.pyc │ ├── mvtec.cpython-38.pyc │ ├── perlin.cpython-36.pyc │ └── perlin.cpython-38.pyc ├── dagm.py ├── dataset.py ├── logodataset.py ├── mvtec.py └── perlin.py ├── models ├── TFA_Net_model.py ├── __init__.py ├── __pycache__ │ ├── TFA_Net_model.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── misc.cpython-36.pyc │ ├── misc.cpython-38.pyc │ ├── model_MAE.cpython-36.pyc │ ├── model_MAE.cpython-38.pyc │ ├── networks.cpython-36.pyc │ ├── networks.cpython-38.pyc │ ├── utils.cpython-36.pyc │ └── utils.cpython-38.pyc ├── efficientnet │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── model.cpython-36.pyc │ │ ├── model.cpython-38.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-38.pyc │ ├── model.py │ └── utils.py ├── misc.py ├── model_MAE.py ├── networks.py └── utils.py ├── network.png ├── ref_find.py ├── test.py ├── test_loco_mvtec.py ├── test_mvlogo.py ├── train.py ├── train_mvlogo.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── metric.cpython-38.pyc ├── gen_mask.py ├── image.py ├── metric.py ├── myutil.py ├── sPro.py ├── test.py └── tools.py └── vis_attention_map.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Wei Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Template-based Feature Aggregation Network for industrial anomaly detection (EAAI 2024)](https://www.sciencedirect.com/science/article/abs/pii/S0952197623019942) 2 | PyTorch implementation of the EAAI 2024 paper, Template-based Feature Aggregation Network for industrial anomaly detection. 3 | ![Network architecture](network.png) 4 | # Download Datasets 5 | Please download the MVTec AD dataset from [MVTec AD dataset](https://www.mvtec.com/de/unternehmen/forschung/datasets/mvtec-ad/) and the MVTec LOCO AD dataset from [MVTec LOCO AD dataset](https://www.mvtec.com/company/research/datasets/mvtec-loco). 
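# Quick Start
A minimal inference sketch, distilled from `computer_flops_params.py` and `models/TFA_Net_model.py` (the values in `config.py` are placeholders; edit `data_root`, `class_name`, and `backbone_name` for your setup):
```python
import torch
from config import DefaultConfig
from models.TFA_Net_model import *

opt = DefaultConfig()
model = eval(opt.model_name)(opt)   # builds RB_VIT_dir with a WideResnet50 backbone
model.eval()

img = torch.rand(1, 3, 256, 256)    # query image
ref = torch.rand(1, 3, 256, 256)    # aligned template (reference) image
with torch.no_grad():
    feat, ref_feat, recon_feat, loss = model(img, ref, 'test')

# 256x256 distance- and direction-based anomaly maps
dis_map, dir_map = model.a_map(feat, recon_feat)
```
Training and evaluation entry points live in `train.py`, `test.py`, `train_mvlogo.py`, `test_mvlogo.py`, and `test_loco_mvtec.py`.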
6 | # Citation 7 | If you find this repository useful, please consider citing our work: 8 | ``` 9 | @article{luo2024template, 10 | title={Template-based Feature Aggregation Network for industrial anomaly detection}, 11 | author={Luo, Wei and Yao, Haiming and Yu, Wenyong}, 12 | journal={Engineering Applications of Artificial Intelligence}, 13 | volume={131}, 14 | pages={107810}, 15 | year={2024}, 16 | publisher={Elsevier} 17 | } 18 | ``` 19 | 20 | -------------------------------------------------------------------------------- /computer_flops_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from thop import profile 3 | from config import DefaultConfig 4 | from models.TFA_Net_model import * 5 | opt = DefaultConfig() 6 | model = eval(opt.model_name)(opt) 7 | model.eval() 8 | input = torch.rand(1, 3, 256, 256) 9 | flops, params = profile(model, (input, input, 'test')) 10 | print('flops: ', str(flops/1e9)+'G', 'params: ', str(params/1e6)+'M') -------------------------------------------------------------------------------- /computer_time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from config import DefaultConfig 3 | from ref_find import get_pos_sample 4 | from models.TFA_Net_model import * 5 | import numpy as np 6 | opt = DefaultConfig() 7 | model = eval(opt.model_name)(opt) 8 | 9 | device = torch.device('cuda:0') 10 | 11 | model.to(device) 12 | dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float).to(device) 13 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 14 | repetitions = 300 15 | timings=np.zeros((repetitions,1)) 16 | model.eval() 17 | #GPU-WARM-UP 18 | for _ in range(10): 19 | _ = model(dummy_input, dummy_input, 'test') 20 | # MEASURE PERFORMANCE 21 | with torch.no_grad(): 22 | for rep in range(repetitions): 23 | starter.record() 24 | _ = model(dummy_input, dummy_input, 'test') 25 | ender.record() 26 | # WAIT FOR GPU SYNC 27 | torch.cuda.synchronize() 28 | curr_time = starter.elapsed_time(ender) 29 | timings[rep] = curr_time 30 | mean_syn = np.sum(timings) / repetitions 31 | std_syn = np.std(timings) 32 | mean_fps = 1000. 
/ mean_syn 33 | print(' * Mean {mean_syn:.3f} ms Std {std_syn:.3f} ms FPS {mean_fps:.2f}'.format(mean_syn=mean_syn, std_syn=std_syn, mean_fps=mean_fps)) 34 | print(mean_syn) -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | model_name_list = ['RB_VIT_dir'] 4 | 5 | class DefaultConfig(object): 6 | class_name = 'bottle' 7 | data_root = r'data/mvtec_anomaly_detection' 8 | device = torch.device('cuda:0') 9 | model_name = model_name_list[0] 10 | batch_size = 4 11 | iter = 0 12 | niter = 400 13 | lr = 0.0001 14 | lr_decay = 0.90 15 | weight_decay = 1e-5 16 | momentum = 0.9 17 | nc = 3 18 | isTrain = True 19 | backbone_name = 'WideResnet50' 20 | referenc_img_file = '' 21 | resume = '' 22 | k = 4 23 | 24 | 25 | if __name__ == '__main__': 26 | opt = DefaultConfig() 27 | opt.train = 1 28 | print(opt.train) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dagm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/dagm.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/logodataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/logodataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/mvtec.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/mvtec.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/mvtec.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/mvtec.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/perlin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/perlin.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/perlin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/datasets/__pycache__/perlin.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/dagm.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import tarfile 3 | from PIL import Image 4 | # import urllib.request 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | import imgaug.augmenters as iaa 10 | import glob 11 | from datasets.perlin import rand_perlin_2d_np 12 | import numpy as np 13 | import cv2 14 | import random 15 | 16 | # URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz' 17 | CLASS_NAMES = [ 18 | 'wallpaper', 'cement', 'MAGtile', 'fabric','wood','carpet'] 19 | 20 | class MVTecDataset(Dataset): 21 | def __init__(self, 22 | dataset_path='../data/DAGM', 23 | class_name='wallpaper', 24 | is_train=False, 25 | resize=256, 26 | anomaly_sourec_path='../data/nature/' 27 | ): 28 | assert class_name in CLASS_NAMES, 'class_name: {}, should be in {}'.format(class_name, CLASS_NAMES) 29 | self.dataset_path = dataset_path 30 | self.class_name = class_name 31 | self.is_train = is_train 32 | self.resize = resize 33 | self.anomaly_source_paths = sorted(glob.glob(anomaly_sourec_path+"/*.JPEG")) 34 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True), 35 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)), 36 | iaa.pillike.EnhanceSharpness(), 37 | iaa.AddToHueAndSaturation((-50,50),per_channel=True), 38 | iaa.Solarize(0.5, threshold=(32,128)), 39 | iaa.Posterize(), 40 | iaa.Invert(), 41 | iaa.pillike.Autocontrast(), 42 | iaa.pillike.Equalize(), 43 | iaa.Affine(rotate=(-45, 45)) 44 | ] 45 | 46 | 47 | self.x, self.y, self.mask = self.load_dataset_folder() 48 | 49 | # set transforms 50 | self.transform_x = transforms.Compose([ 51 | transforms.Resize(resize, Image.ANTIALIAS), 52 | transforms.ToTensor()]) 53 | self.transform_mask = transforms.Compose( 54 | [transforms.Resize(resize, Image.NEAREST), 55 | transforms.ToTensor()]) 56 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 57 | 58 | def randAugmenter(self): 59 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False) 60 | aug = iaa.Sequential([self.augmenters[aug_ind[0]], 61 | self.augmenters[aug_ind[1]], 62 | self.augmenters[aug_ind[2]]] 63 | ) 64 | return aug 65 | 66 | def augment_image(self, image, anomaly_source_path): 67 | 68 | random_nature_img_name = random.sample(anomaly_source_path, 1)[0] 69 | aug = self.randAugmenter() 70 | perlin_scale = 6 71 | min_perlin_scale = 0 72 | anomaly_source_img = cv2.imread(random_nature_img_name) 73 | # cv2.imwrite('luowei4.jpg', anomaly_source_img) 74 | anomaly_source_img = 
cv2.resize(anomaly_source_img, dsize=(self.resize, self.resize)) 75 | cv2.imwrite('luowei3.jpg', anomaly_source_img) 76 | 77 | anomaly_img_augmented = aug(image=anomaly_source_img) 78 | cv2.imwrite('luowei4.jpg', anomaly_img_augmented) 79 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 80 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 81 | 82 | perlin_noise = rand_perlin_2d_np((self.resize, self.resize), (perlin_scalex, perlin_scaley)) 83 | perlin_noise = self.rot(image=perlin_noise) 84 | threshold = 0.5 85 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 86 | perlin_thr = np.expand_dims(perlin_thr, axis=2) 87 | 88 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0 89 | 90 | beta = torch.rand(1).numpy()[0] 91 | 92 | augmented_image = image * (1 - perlin_thr) + img_thr*perlin_thr 93 | 94 | # no_anomaly = torch.rand(1).numpy()[0] 95 | # if no_anomaly > 0.8: 96 | # image = image.astype(np.float32) 97 | # return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0], dtype=np.float32) 98 | 99 | # augmented_image = augmented_image.astype(np.float32) 100 | msk = (perlin_thr).astype(np.float32) 101 | # augmented_image = msk * augmented_image + (1 - msk) * image 102 | has_anomaly = 1.0 103 | if np.sum(msk) == 0: 104 | has_anomaly = 0.0 105 | return augmented_image, msk, np.array([has_anomaly], dtype=np.float32) 106 | 107 | 108 | def __getitem__(self, idx): 109 | # print(self.x) 110 | x, y, mask = self.x[idx], self.y[idx], self.mask[idx] 111 | aug_x, aug_mask, aug_label = self.random_anomaly(x) 112 | 113 | aug_x = Image.fromarray(np.uint8(aug_x)) 114 | aug_x = self.transform_x(aug_x) 115 | aug_mask = aug_mask.reshape(aug_mask.shape[0], aug_mask.shape[1]) 116 | aug_mask = Image.fromarray(np.uint8(aug_mask*255)) 117 | aug_mask = self.transform_mask(aug_mask) 118 | x = Image.open(x).convert('RGB') 119 | x = self.transform_x(x) 120 | 121 | if y == 0: 122 | mask = torch.zeros([1, self.resize, self.resize]) 123 | else: 124 | mask = cv2.imread(mask, 0) 125 | cv2.imwrite('2.jpg', mask) 126 | mask = Image.fromarray(np.uint8(mask)) 127 | mask = self.transform_mask(mask) 128 | 129 | 130 | # if self.train_stage == 1: 131 | # return x, y, mask 132 | # elif self.train_stage == 2: 133 | return x, y, mask, aug_x, aug_mask, aug_label 134 | 135 | 136 | def __len__(self): 137 | return len(self.x) 138 | 139 | def random_anomaly(self, image_path): 140 | image = cv2.imread(image_path) 141 | image = cv2.resize(image, dsize=(self.resize, self.resize)) 142 | # image = image/255.0 143 | 144 | image = np.array(image).astype(np.float32) / 255.0 145 | # print(image.shape) 146 | aug_img, aug_mask, aug_label = self.augment_image(image, self.anomaly_source_paths) 147 | # print(aug_img.shape) 148 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 149 | aug_img = cv2.cvtColor(np.uint8(aug_img*255), cv2.COLOR_BGR2RGB) 150 | # aug_img = np.transpose(aug_img, (2, 0, 1)) 151 | # aug_mask = np.transpose(aug_mask, (2, 0, 1)) 152 | return aug_img, aug_mask, aug_label 153 | 154 | def load_dataset_folder(self): 155 | phase = 'train' if self.is_train else 'test' 156 | x, y, mask = [], [], [] 157 | 158 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 159 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth') 160 | 161 | img_types = sorted(os.listdir(img_dir)) 162 | for img_type in img_types: 163 | 164 | # load images 165 | 
img_type_dir = os.path.join(img_dir, img_type) 166 | if not os.path.isdir(img_type_dir): 167 | continue 168 | img_fpath_list = sorted( 169 | [os.path.join(img_type_dir, f) for f in os.listdir(img_type_dir) if f.endswith('.png')]) 170 | x.extend(img_fpath_list) 171 | 172 | # load gt labels 173 | if img_type == 'good': 174 | # print('ok') 175 | y.extend([0] * len(img_fpath_list)) 176 | mask.extend([None] * len(img_fpath_list)) 177 | else: 178 | y.extend([1] * len(img_fpath_list)) 179 | gt_type_dir = os.path.join(gt_dir, img_type) 180 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 181 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') for img_fname in img_fname_list] 182 | # print(gt_fpath_list) 183 | mask.extend(gt_fpath_list) 184 | 185 | assert len(x) == len(y), 'number of x and y should be same' 186 | 187 | return list(x), list(y), list(mask) 188 | 189 | def tensor_to_np(tensor_img): 190 | np_img = np.array(tensor_img) 191 | np_img = np.transpose(np_img, (1, 2, 0)) 192 | if np_img.shape[2] == 3: 193 | np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR) 194 | return np_img 195 | if __name__ == '__main__': 196 | mvtec = MVTecDataset() 197 | x, y, mask, aug_x, aug_mask,_ = mvtec[1] 198 | # print(x) 199 | # print(y.shape) 200 | # print(mask.shape) 201 | # print(aug_x) 202 | # print(aug_mask) 203 | x = tensor_to_np(x) 204 | cv2.imwrite('luowei1.jpg', x*255) 205 | np_img = tensor_to_np(aug_x) 206 | cv2.imwrite('luowei.jpg', np_img*255) 207 | aug_mask = tensor_to_np(mask) 208 | cv2.imwrite('luowei2.jpg', aug_mask*255) 209 | 210 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import tarfile 3 | from PIL import Image 4 | # import urllib.request 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | import imgaug.augmenters as iaa 9 | import glob 10 | from datasets.perlin import rand_perlin_2d_np 11 | import numpy as np 12 | import cv2 13 | import random 14 | 15 | # URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz' 16 | CLASS_NAMES = [ 17 | 'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 18 | 'toothbrush', 'transistor', 'wood', 'zipper', 'ubolt' 19 | ] 20 | 21 | class MVTecDataset(Dataset): 22 | def __init__(self, 23 | dataset_path='../data/mvtec_anomaly_detection', 24 | class_name='leather', 25 | is_train=True, 26 | resize=256, 27 | ): 28 | assert class_name in CLASS_NAMES, 'class_name: {}, should be in {}'.format(class_name, CLASS_NAMES) 29 | self.dataset_path = dataset_path 30 | self.class_name = class_name 31 | self.is_train = is_train 32 | self.resize = resize 33 | 34 | self.x, self.y, self.mask = self.load_dataset_folder() 35 | self.len = len(self.x) 36 | self.name = [] 37 | # set transforms 38 | self.transform_x = transforms.Compose([ 39 | transforms.Resize(resize, Image.ANTIALIAS), 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 42 | std= [0.229, 0.224, 0.225])]) 43 | 44 | self.transform_mask = transforms.Compose( 45 | [transforms.Resize(resize, Image.NEAREST), 46 | transforms.ToTensor()]) 47 | 48 | 49 | for i in range(self.len): 50 | names = self.x[i].split("\\") 51 | name = names[-2]+"!"+names[-1] 52 | self.name.append(name) 53 | 54 | def __getitem__(self, idx): 55 | 56 | x, y, mask, name 
= self.x[idx], self.y[idx], self.mask[idx], self.name[idx] 57 | x = Image.open(x).convert('RGB') 58 | x = self.transform_x(x) 59 | 60 | if y == 0: 61 | mask = torch.zeros([1, self.resize, self.resize]) 62 | else: 63 | mask = Image.open(mask).convert('L') 64 | mask = self.transform_mask(mask) 65 | # if self.train_stage == 1: 66 | # return x, y, mask 67 | # elif self.train_stage == 2: 68 | return x, y, mask, name 69 | 70 | 71 | def __len__(self): 72 | return len(self.x) 73 | 74 | def load_dataset_folder(self): 75 | phase = 'train' if self.is_train else 'test' 76 | x, y, mask = [], [], [] 77 | 78 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 79 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth') 80 | 81 | img_types = sorted(os.listdir(img_dir)) 82 | for img_type in img_types: 83 | # load images 84 | img_type_dir = os.path.join(img_dir, img_type) 85 | if not os.path.isdir(img_type_dir): 86 | continue 87 | img_fpath_list = sorted( 88 | [os.path.join(img_type_dir, f) for f in os.listdir(img_type_dir) if f.endswith('.png') or f.endswith('.bmp')]) 89 | x.extend(img_fpath_list) 90 | 91 | # load gt labels 92 | if img_type == 'good': 93 | y.extend([0] * len(img_fpath_list)) 94 | mask.extend([None] * len(img_fpath_list)) 95 | else: 96 | y.extend([1] * len(img_fpath_list)) 97 | gt_type_dir = os.path.join(gt_dir, img_type) 98 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 99 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') for img_fname in img_fname_list] 100 | mask.extend(gt_fpath_list) 101 | 102 | assert len(x) == len(y), 'number of x and y should be same' 103 | 104 | return list(x), list(y), list(mask) 105 | 106 | def tensor_to_np(tensor_img): 107 | np_img = np.array(tensor_img) 108 | np_img = np.transpose(np_img, (1, 2, 0)) 109 | if np_img.shape[2] == 3: 110 | np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR) 111 | return np_img 112 | def denormalize(img): 113 | std = np.array([0.229, 0.224, 0.225]) 114 | mean = np.array([0.485, 0.456, 0.406]) 115 | x = (((img.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8) 116 | return x 117 | if __name__ == '__main__': 118 | mvtec = MVTecDataset() 119 | x, y, mask, aug_x, aug_mask,_ = mvtec[0] 120 | # print(x) 121 | # print(y.shape) 122 | # print(mask.shape) 123 | # print(aug_x) 124 | # print(aug_mask) 125 | x = tensor_to_np(x) 126 | cv2.imwrite('luowei1.jpg', x*255) 127 | np_img = tensor_to_np(aug_x) 128 | cv2.imwrite('luowei.jpg', np_img*255) 129 | aug_mask = tensor_to_np(aug_mask) 130 | cv2.imwrite('luowei2.jpg', aug_mask*255) 131 | -------------------------------------------------------------------------------- /datasets/logodataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import tarfile 3 | from PIL import Image 4 | # import urllib.request 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | import imgaug.augmenters as iaa 9 | import glob 10 | from datasets.perlin import rand_perlin_2d_np 11 | import numpy as np 12 | import cv2 13 | import random 14 | 15 | # URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz' 16 | CLASS_NAMES = [ 17 | 'breakfast_box', 'juice_bottle', 'pushpins', 'screw_bag', 'splicing_connectors'] 18 | 19 | class MVTecDataset(Dataset): 20 | def __init__(self, 21 | dataset_path='../data/mvtec_anomaly_detection', 22 | class_name='breakfast_box', 23 | is_train=True, 24 | 
resize=(256, 256), 25 | ): 26 | assert class_name in CLASS_NAMES, 'class_name: {}, should be in {}'.format(class_name, CLASS_NAMES) 27 | self.dataset_path = dataset_path 28 | self.class_name = class_name 29 | self.is_train = is_train 30 | self.resize = resize 31 | 32 | self.x, self.y, self.mask = self.load_dataset_folder() 33 | self.len = len(self.x) 34 | self.name = [] 35 | # set transforms 36 | self.transform_x = transforms.Compose([ 37 | transforms.Resize(resize, Image.ANTIALIAS), 38 | transforms.ToTensor(), 39 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 40 | std= [0.229, 0.224, 0.225])]) 41 | 42 | self.transform_mask = transforms.Compose( 43 | [transforms.Resize(resize, Image.NEAREST), 44 | transforms.ToTensor()]) 45 | 46 | 47 | for i in range(self.len): 48 | names = self.x[i].split("\\") 49 | name = names[-2]+"!"+names[-1] 50 | self.name.append(name) 51 | 52 | def __getitem__(self, idx): 53 | 54 | x, y, mask, name = self.x[idx], self.y[idx], self.mask[idx], self.name[idx] 55 | x = Image.open(x).convert('RGB') 56 | x = self.transform_x(x) 57 | 58 | if y == 0: 59 | mask = torch.zeros([1, *self.resize]) 60 | else: 61 | mask = Image.open(mask).convert('L') 62 | mask = self.transform_mask(mask) 63 | # if self.train_stage == 1: 64 | # return x, y, mask 65 | # elif self.train_stage == 2: 66 | return x, y, mask, name 67 | 68 | 69 | def __len__(self): 70 | return len(self.x) 71 | 72 | def load_dataset_folder(self): 73 | phase = 'train' if self.is_train else 'test' 74 | x, y, mask = [], [], [] 75 | 76 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 77 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth') 78 | 79 | img_types = sorted(os.listdir(img_dir)) 80 | for img_type in img_types: 81 | # load images 82 | img_type_dir = os.path.join(img_dir, img_type) 83 | if not os.path.isdir(img_type_dir): 84 | continue 85 | img_fpath_list = sorted( 86 | [os.path.join(img_type_dir, f) for f in os.listdir(img_type_dir) if f.endswith('.png')]) 87 | x.extend(img_fpath_list) 88 | 89 | # load gt labels 90 | if img_type == 'good': 91 | y.extend([0] * len(img_fpath_list)) 92 | mask.extend([None] * len(img_fpath_list)) 93 | else: 94 | y.extend([1] * len(img_fpath_list)) 95 | gt_type_dir = os.path.join(gt_dir, img_type) 96 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 97 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname, '000.png') for img_fname in img_fname_list] 98 | mask.extend(gt_fpath_list) 99 | 100 | assert len(x) == len(y), 'number of x and y should be same' 101 | 102 | return list(x), list(y), list(mask) 103 | 104 | def tensor_to_np(tensor_img): 105 | np_img = np.array(tensor_img) 106 | np_img = np.transpose(np_img, (1, 2, 0)) 107 | if np_img.shape[2] == 3: 108 | np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR) 109 | return np_img 110 | def denormalize(img): 111 | std = np.array([0.229, 0.224, 0.225]) 112 | mean = np.array([0.485, 0.456, 0.406]) 113 | x = (((img.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8) 114 | return x 115 | if __name__ == '__main__': 116 | mvtec = MVTecDataset() 117 | x, y, mask, aug_x, aug_mask,_ = mvtec[0] 118 | # print(x) 119 | # print(y.shape) 120 | # print(mask.shape) 121 | # print(aug_x) 122 | # print(aug_mask) 123 | x = tensor_to_np(x) 124 | cv2.imwrite('luowei1.jpg', x*255) 125 | np_img = tensor_to_np(aug_x) 126 | cv2.imwrite('luowei.jpg', np_img*255) 127 | aug_mask = tensor_to_np(aug_mask) 128 | cv2.imwrite('luowei2.jpg', aug_mask*255) 129 | 
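For reference, a minimal loading sketch for the `MVTecDataset` class defined in `datasets/dataset.py` above (illustrative only; the path is an assumption, and note that the class derives sample names via `split("\\")`, which expects Windows-style paths — on POSIX systems `os.sep` would be needed instead):
```python
from torch.utils.data import DataLoader
from datasets.dataset import MVTecDataset

# Hypothetical local layout: data/mvtec_anomaly_detection/<class_name>/{train,test,ground_truth}
test_set = MVTecDataset(dataset_path='data/mvtec_anomaly_detection',
                        class_name='bottle', is_train=False, resize=256)
loader = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=4)

for x, y, mask, name in loader:
    # x: (B, 3, 256, 256) ImageNet-normalized images
    # y: (B,) labels, 0 = good, 1 = anomalous
    # mask: (B, 1, 256, 256) ground-truth masks (all zeros for good samples)
    pass
```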
-------------------------------------------------------------------------------- /datasets/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import tarfile 3 | from PIL import Image 4 | # import urllib.request 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | import imgaug.augmenters as iaa 10 | import glob 11 | from datasets.perlin import rand_perlin_2d_np 12 | import numpy as np 13 | import cv2 14 | import random 15 | 16 | # URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz' 17 | CLASS_NAMES = [ 18 | 'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 19 | 'toothbrush', 'transistor', 'wood', 'zipper' 20 | ] 21 | 22 | class MVTecDataset(Dataset): 23 | def __init__(self, 24 | dataset_path='../data/mvtec_anomaly_detection', 25 | class_name='leather', 26 | is_train=True, 27 | resize=256, 28 | anomaly_sourec_path='../data/nature/' 29 | ): 30 | assert class_name in CLASS_NAMES, 'class_name: {}, should be in {}'.format(class_name, CLASS_NAMES) 31 | self.dataset_path = dataset_path 32 | self.class_name = class_name 33 | self.is_train = is_train 34 | self.resize = resize 35 | self.anomaly_source_paths = sorted(glob.glob(anomaly_sourec_path+"/*.JPEG")) 36 | self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True), 37 | iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)), 38 | iaa.pillike.EnhanceSharpness(), 39 | iaa.AddToHueAndSaturation((-50,50),per_channel=True), 40 | iaa.Solarize(0.5, threshold=(32,128)), 41 | iaa.Posterize(), 42 | iaa.Invert(), 43 | iaa.pillike.Autocontrast(), 44 | iaa.pillike.Equalize(), 45 | iaa.Affine(rotate=(-45, 45)) 46 | ] 47 | 48 | 49 | self.x, self.y, self.mask = self.load_dataset_folder() 50 | 51 | # set transforms 52 | self.transform_x = transforms.Compose([ 53 | transforms.Resize(resize, Image.ANTIALIAS), 54 | transforms.ToTensor()]) 55 | self.transform_mask = transforms.Compose( 56 | [transforms.Resize(resize, Image.NEAREST), 57 | transforms.ToTensor()]) 58 | self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 59 | 60 | def randAugmenter(self): 61 | aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False) 62 | aug = iaa.Sequential([self.augmenters[aug_ind[0]], 63 | self.augmenters[aug_ind[1]], 64 | self.augmenters[aug_ind[2]]] 65 | ) 66 | return aug 67 | 68 | def augment_image(self, image, anomaly_source_path): 69 | 70 | random_nature_img_name = random.sample(anomaly_source_path, 1)[0] 71 | aug = self.randAugmenter() 72 | perlin_scale = 6 73 | min_perlin_scale = 0 74 | anomaly_source_img = cv2.imread(random_nature_img_name) 75 | # cv2.imwrite('luowei4.jpg', anomaly_source_img) 76 | anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.resize, self.resize)) 77 | cv2.imwrite('luowei3.jpg', anomaly_source_img) 78 | 79 | anomaly_img_augmented = aug(image=anomaly_source_img) 80 | cv2.imwrite('luowei4.jpg', anomaly_img_augmented) 81 | perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 82 | perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0]) 83 | 84 | perlin_noise = rand_perlin_2d_np((self.resize, self.resize), (perlin_scalex, perlin_scaley)) 85 | perlin_noise = self.rot(image=perlin_noise) 86 | threshold = 0.5 87 | perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 88 | 
perlin_thr = np.expand_dims(perlin_thr, axis=2) 89 | 90 | img_thr = anomaly_img_augmented.astype(np.float32) * perlin_thr / 255.0 91 | 92 | beta = torch.rand(1).numpy()[0] 93 | 94 | augmented_image = image * (1 - perlin_thr) + img_thr*perlin_thr 95 | 96 | # no_anomaly = torch.rand(1).numpy()[0] 97 | # if no_anomaly > 0.8: 98 | # image = image.astype(np.float32) 99 | # return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0], dtype=np.float32) 100 | 101 | # augmented_image = augmented_image.astype(np.float32) 102 | msk = (perlin_thr).astype(np.float32) 103 | # augmented_image = msk * augmented_image + (1 - msk) * image 104 | has_anomaly = 1.0 105 | if np.sum(msk) == 0: 106 | has_anomaly = 0.0 107 | return augmented_image, msk, np.array([has_anomaly], dtype=np.float32) 108 | 109 | 110 | def __getitem__(self, idx): 111 | # print(self.x) 112 | x, y, mask = self.x[idx], self.y[idx], self.mask[idx] 113 | aug_x, aug_mask, aug_label = self.random_anomaly(x) 114 | 115 | aug_x = Image.fromarray(np.uint8(aug_x)) 116 | aug_x = self.transform_x(aug_x) 117 | aug_mask = aug_mask.reshape(aug_mask.shape[0], aug_mask.shape[1]) 118 | aug_mask = Image.fromarray(np.uint8(aug_mask*255)) 119 | aug_mask = self.transform_mask(aug_mask) 120 | x = Image.open(x).convert('RGB') 121 | x = self.transform_x(x) 122 | 123 | if y == 0: 124 | mask = torch.zeros([1, self.resize, self.resize]) 125 | else: 126 | mask = Image.open(mask) 127 | mask = self.transform_mask(mask) 128 | # if self.train_stage == 1: 129 | # return x, y, mask 130 | # elif self.train_stage == 2: 131 | return x, y, mask, aug_x,aug_mask, aug_label 132 | 133 | 134 | def __len__(self): 135 | return len(self.x) 136 | 137 | def random_anomaly(self, image_path): 138 | image = cv2.imread(image_path) 139 | image = cv2.resize(image, dsize=(self.resize, self.resize)) 140 | # image = image/255.0 141 | 142 | image = np.array(image).astype(np.float32) / 255.0 143 | # print(image.shape) 144 | aug_img, aug_mask, aug_label = self.augment_image(image, self.anomaly_source_paths) 145 | # print(aug_img.shape) 146 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 147 | aug_img = cv2.cvtColor(np.uint8(aug_img*255), cv2.COLOR_BGR2RGB) 148 | # aug_img = np.transpose(aug_img, (2, 0, 1)) 149 | # aug_mask = np.transpose(aug_mask, (2, 0, 1)) 150 | return aug_img, aug_mask, aug_label 151 | 152 | def load_dataset_folder(self): 153 | phase = 'train' if self.is_train else 'test' 154 | x, y, mask = [], [], [] 155 | 156 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 157 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth') 158 | 159 | img_types = sorted(os.listdir(img_dir)) 160 | for img_type in img_types: 161 | 162 | # load images 163 | img_type_dir = os.path.join(img_dir, img_type) 164 | if not os.path.isdir(img_type_dir): 165 | continue 166 | img_fpath_list = sorted( 167 | [os.path.join(img_type_dir, f) for f in os.listdir(img_type_dir) if f.endswith('.png')]) 168 | x.extend(img_fpath_list) 169 | 170 | # load gt labels 171 | if img_type == 'good': 172 | y.extend([0] * len(img_fpath_list)) 173 | mask.extend([None] * len(img_fpath_list)) 174 | else: 175 | y.extend([1] * len(img_fpath_list)) 176 | gt_type_dir = os.path.join(gt_dir, img_type) 177 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 178 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') for img_fname in img_fname_list] 179 | mask.extend(gt_fpath_list) 180 | 181 | assert len(x) == len(y), 'number of x and y should 
be same' 182 | 183 | return list(x), list(y), list(mask) 184 | 185 | def tensor_to_np(tensor_img): 186 | np_img = np.array(tensor_img) 187 | np_img = np.transpose(np_img, (1, 2, 0)) 188 | if np_img.shape[2] == 3: 189 | np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR) 190 | return np_img 191 | if __name__ == '__main__': 192 | mvtec = MVTecDataset() 193 | x, y, mask, aug_x, aug_mask,_ = mvtec[0] 194 | # print(x) 195 | # print(y.shape) 196 | # print(mask.shape) 197 | # print(aug_x) 198 | # print(aug_mask) 199 | x = tensor_to_np(x) 200 | cv2.imwrite('luowei1.jpg', x*255) 201 | np_img = tensor_to_np(aug_x) 202 | cv2.imwrite('luowei.jpg', np_img*255) 203 | aug_mask = tensor_to_np(aug_mask) 204 | cv2.imwrite('luowei2.jpg', aug_mask*255) 205 | -------------------------------------------------------------------------------- /datasets/perlin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | def lerp_np(x,y,w): 6 | fin_out = (y-x)*w + x 7 | return fin_out 8 | 9 | def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5): 10 | noise = np.zeros(shape) 11 | frequency = 1 12 | amplitude = 1 13 | for _ in range(octaves): 14 | noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1])) 15 | frequency *= 2 16 | amplitude *= persistence 17 | return noise 18 | 19 | 20 | def generate_perlin_noise_2d(shape, res): 21 | def f(t): 22 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 23 | 24 | delta = (res[0] / shape[0], res[1] / shape[1]) 25 | d = (shape[0] // res[0], shape[1] // res[1]) 26 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 27 | # Gradients 28 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 29 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 30 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 31 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 32 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 33 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 34 | # Ramps 35 | n00 = np.sum(grid * g00, 2) 36 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 37 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 38 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 39 | # Interpolation 40 | t = f(grid) 41 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 42 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 43 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 44 | 45 | 46 | def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 47 | delta = (res[0] / shape[0], res[1] / shape[1]) 48 | d = (shape[0] // res[0], shape[1] // res[1]) 49 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 50 | 51 | angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) 52 | gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) 53 | tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1) 54 | 55 | tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1) 56 | dot = lambda grad, shift: ( 57 | np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 58 | axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) 59 | 60 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 61 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 62 | n01 = 
dot(tile_grads([0, -1], [1, None]), [0, -1]) 63 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 64 | t = fade(grid[:shape[0], :shape[1]]) 65 | return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) 66 | 67 | 68 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): 69 | delta = (res[0] / shape[0], res[1] / shape[1]) 70 | d = (shape[0] // res[0], shape[1] // res[1]) 71 | 72 | grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 73 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 74 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 75 | 76 | tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 77 | 0).repeat_interleave( 78 | d[1], 1) 79 | dot = lambda grad, shift: ( 80 | torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), 81 | dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) 82 | 83 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 84 | 85 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 86 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 87 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 88 | t = fade(grid[:shape[0], :shape[1]]) 89 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) 90 | 91 | 92 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 93 | noise = torch.zeros(shape) 94 | frequency = 1 95 | amplitude = 1 96 | for _ in range(octaves): 97 | noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) 98 | frequency *= 2 99 | amplitude *= persistence 100 | return noise -------------------------------------------------------------------------------- /models/TFA_Net_model.py: -------------------------------------------------------------------------------- 1 | from models.model_MAE import * 2 | from models.networks import * 3 | 4 | from torch import nn 5 | 6 | 7 | class RB_VIT_dir(nn.Module): 8 | def __init__(self, opt): 9 | super(RB_VIT_dir, self).__init__() 10 | 11 | if opt.backbone_name == 'D_VGG': 12 | self.Feature_extractor = D_VGG().eval() 13 | self.Roncon_model = RB_MAE_dir(in_chans=768) 14 | 15 | if opt.backbone_name == 'VGG': 16 | self.Feature_extractor = VGG().eval() 17 | self.Roncon_model = RB_MAE_dir(in_chans=960, patch_size=4) 18 | 19 | if opt.backbone_name == 'Resnet34': 20 | self.Feature_extractor = Resnet34().eval() 21 | self.Roncon_model = RB_MAE_dir(in_chans=512) 22 | 23 | if opt.backbone_name == 'Resnet50': 24 | self.Feature_extractor = Resnet50().eval() 25 | self.Roncon_model = RB_MAE_dir(in_chans=1856) 26 | 27 | if opt.backbone_name == 'WideResnet50': 28 | self.Feature_extractor = WideResNet50().eval() 29 | self.Roncon_model = RB_MAE_dir(in_chans=1856, patch_size=opt.k) 30 | 31 | if opt.backbone_name == 'Resnet101': 32 | self.Feature_extractor = Resnet101().eval() 33 | self.Roncon_model = RB_MAE_dir(in_chans=1856) 34 | 35 | if opt.backbone_name == 'WideResnet101': 36 | self.Feature_extractor = WideResnet101().eval() 37 | self.Roncon_model = RB_MAE_dir(in_chans=1856) 38 | 39 | if opt.backbone_name == 'MobileNet': 40 | self.Feature_extractor = MobileNet().eval() 41 | self.Roncon_model = RB_MAE_dir(in_chans=104) 42 | 43 | 44 | def forward(self, imgs, ref_imgs, stages): 45 | deep_feature = self.Feature_extractor(imgs) 46 | ref_deep_feature = 
self.Feature_extractor(ref_imgs) 47 | loss, pre_feature, _ = self.Roncon_model(deep_feature, ref_deep_feature) 48 | pre_feature_recon = self.Roncon_model.unpatchify(pre_feature) 49 | # vis_feature = [self.Roncon_model.unpatchify(i) for i in vis_feature] 50 | return deep_feature, ref_deep_feature, pre_feature_recon, loss 51 | 52 | def a_map(self, deep_feature, recon_feature): 53 | # recon_feature = self.Roncon_model.unpatchify(pre_feature) 54 | batch_size = recon_feature.shape[0] 55 | dis_map = torch.mean((deep_feature - recon_feature) ** 2, dim=1, keepdim=True) 56 | dis_map = nn.functional.interpolate(dis_map, size=(256, 256), mode="bilinear", align_corners=True).squeeze(1) 57 | dis_map = dis_map.clone().squeeze(0).cpu().detach().numpy() 58 | 59 | dir_map = 1 - torch.nn.CosineSimilarity()(deep_feature, recon_feature) 60 | dir_map = dir_map.reshape(batch_size, 1, 64, 64) 61 | dir_map = nn.functional.interpolate(dir_map, size=(256, 256), mode="bilinear", align_corners=True).squeeze(1) 62 | dir_map = dir_map.clone().squeeze(0).cpu().detach().numpy() 63 | return dis_map, dir_map 64 | 65 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models import * -------------------------------------------------------------------------------- /models/__pycache__/TFA_Net_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/TFA_Net_model.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_MAE.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/model_MAE.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_MAE.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/model_MAE.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | """__init__.py - all efficientnet models. 2 | """ 3 | 4 | # Author: lukemelas (github username) 5 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 6 | # With adjustments and added comments by workingcoder (github username). 7 | 8 | 9 | __version__ = "0.7.1" 10 | from .model import EfficientNet 11 | 12 | __all__ = [ 13 | "EfficientNet", 14 | "efficientnet_b0", 15 | "efficientnet_b1", 16 | "efficientnet_b2", 17 | "efficientnet_b3", 18 | "efficientnet_b4", 19 | "efficientnet_b5", 20 | "efficientnet_b6", 21 | "efficientnet_b7", 22 | "efficientnet_b8", 23 | "efficientnet_l2", 24 | ] 25 | 26 | 27 | def efficientnet_b0(pretrained, outblocks, outstrides, pretrained_model=""): 28 | return build_efficient( 29 | "efficientnet-b0", pretrained, outblocks, outstrides, pretrained_model 30 | ) 31 | 32 | 33 | def efficientnet_b1(pretrained, outblocks, outstrides, pretrained_model=""): 34 | return build_efficient( 35 | "efficientnet-b1", pretrained, outblocks, outstrides, pretrained_model 36 | ) 37 | 38 | 39 | def efficientnet_b2(pretrained, outblocks, outstrides, pretrained_model=""): 40 | return build_efficient( 41 | "efficientnet-b2", pretrained, outblocks, outstrides, pretrained_model 42 | ) 43 | 44 | 45 | def efficientnet_b3(pretrained, outblocks, outstrides, pretrained_model=""): 46 | return build_efficient( 47 | "efficientnet-b3", pretrained, outblocks, outstrides, pretrained_model 48 | ) 49 | 50 | 51 | def efficientnet_b4(pretrained, outblocks, outstrides, pretrained_model=""): 52 | return build_efficient( 53 | "efficientnet-b4", pretrained, outblocks, outstrides, pretrained_model 54 | ) 55 | 56 | 57 | def efficientnet_b5(pretrained, outblocks, outstrides, pretrained_model=""): 58 | return build_efficient( 59 | "efficientnet-b5", pretrained, outblocks, outstrides, pretrained_model 60 | ) 61 | 62 | 63 | def efficientnet_b6(pretrained, outblocks, outstrides, pretrained_model=""): 64 | return build_efficient( 65 | "efficientnet-b6", pretrained, outblocks, outstrides, pretrained_model 66 | ) 67 | 68 | 69 | def efficientnet_b7(pretrained, outblocks, outstrides, pretrained_model=""): 70 | return build_efficient( 71 | "efficientnet-b7", pretrained, outblocks, outstrides, 
pretrained_model 72 | ) 73 | 74 | 75 | def efficientnet_b8(pretrained, outblocks, outstrides, pretrained_model=""): 76 | return build_efficient( 77 | "efficientnet-b8", pretrained, outblocks, outstrides, pretrained_model 78 | ) 79 | 80 | 81 | def efficientnet_l2(pretrained, outblocks, outstrides, pretrained_model=""): 82 | return build_efficient( 83 | "efficientnet-l2", pretrained, outblocks, outstrides, pretrained_model 84 | ) 85 | 86 | 87 | def build_efficient(model_name, pretrained, outblocks, outstrides, pretrained_model=""): 88 | if pretrained: 89 | model = EfficientNet.from_pretrained( 90 | model_name, 91 | outblocks=outblocks, 92 | outstrides=outstrides, 93 | pretrained_model=pretrained_model, 94 | ) 95 | else: 96 | model = EfficientNet.from_name( 97 | model_name, outblocks=outblocks, outstrides=outstrides 98 | ) 99 | return model 100 | -------------------------------------------------------------------------------- /models/efficientnet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/efficientnet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/efficientnet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/efficientnet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/efficientnet/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/efficientnet/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /models/efficientnet/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/efficientnet/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /models/efficientnet/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/efficientnet/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/efficientnet/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/models/efficientnet/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 
| '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed 
init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | print("Resume checkpoint %s" % args.resume) 324 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 325 | optimizer.load_state_dict(checkpoint['optimizer']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if 'scaler' in checkpoint: 328 | loss_scaler.load_state_dict(checkpoint['scaler']) 329 | print("With optim & sched!") 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 
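# NOTE: a minimal usage sketch of all_reduce_mean (hypothetical names): average a
# python scalar, such as a per-batch loss, across ranks before logging it once:
#   avg_loss = all_reduce_mean(loss.item())  # returns the input unchanged when world_size == 1
#   if is_main_process():
#       print('avg loss across ranks:', avg_loss)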
334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | return x_reduce.item() 339 | else: 340 | return x -------------------------------------------------------------------------------- /models/model_MAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.vision_transformer import PatchEmbed, Block 4 | from models.utils import get_2d_sincos_pos_embed 5 | 6 | class RB_MAE_dir(nn.Module): 7 | """ Masked Autoencoder with VisionTransformer backbone 8 | """ 9 | 10 | def __init__(self, img_size=64, patch_size=4, in_chans=960, 11 | embed_dim=768, depth=12, num_heads=12, 12 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 13 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 14 | super(RB_MAE_dir, self).__init__() 15 | self.len_keep = 0 # initialized here; updated in random_masking 16 | self.in_chans = in_chans 17 | 18 | # -------------------------------------------------------------------------- 19 | # MAE encoder specifics 20 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 21 | num_patches = self.patch_embed.num_patches 22 | 23 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 24 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 25 | requires_grad=False) # fixed sin-cos embedding 26 | 27 | self.blocks = nn.ModuleList([ 28 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 29 | for i in range(depth)]) 30 | self.norm = norm_layer(embed_dim) 31 | # -------------------------------------------------------------------------- 32 | 33 | # -------------------------------------------------------------------------- 34 | # MAE decoder specifics 35 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 36 | 37 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), 38 | requires_grad=False) # fixed sin-cos embedding 39 | 40 | self.decoder_blocks = nn.ModuleList([ 41 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 42 | for i in range(decoder_depth)]) 43 | 44 | self.decoder_norm = norm_layer(decoder_embed_dim) 45 | 46 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) 47 | 48 | # -------------------------------------------------------------------------- 49 | 50 | self.norm_pix_loss = norm_pix_loss 51 | 52 | self.initialize_weights() 53 | 54 | def initialize_weights(self): 55 | # initialization 56 | # initialize (and freeze) pos_embed by sin-cos embedding 57 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), 58 | cls_token=True) 59 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 60 | 61 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], 62 | int(self.patch_embed.num_patches ** .5), cls_token=True) 63 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 64 | 65 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 66 | 67 | w = self.patch_embed.proj.weight.data 68 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 69 | 70 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
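# NOTE: shape arithmetic implied by the constructor defaults above (a sketch):
# img_size=64 with patch_size=4 gives a 16x16 grid, i.e. 256 patches, so
# pos_embed is [1, 257, 768] (cls token included) and decoder_pred maps each
# decoded token back to patch_size**2 * in_chans = 16 * 960 = 15360 values.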
71 | torch.nn.init.normal_(self.cls_token, std=.02) 72 | 73 | # initialize nn.Linear and nn.LayerNorm 74 | self.apply(self._init_weights) 75 | 76 | def _init_weights(self, m): 77 | if isinstance(m, nn.Linear): 78 | # we use xavier_uniform following official JAX ViT: 79 | torch.nn.init.xavier_uniform_(m.weight) 80 | if isinstance(m, nn.Linear) and m.bias is not None: 81 | nn.init.constant_(m.bias, 0) 82 | elif isinstance(m, nn.LayerNorm): 83 | nn.init.constant_(m.bias, 0) 84 | nn.init.constant_(m.weight, 1.0) 85 | 86 | def patchify(self, imgs): 87 | """ 88 | imgs: (N, 3, H, W) 89 | x: (N, L, patch_size**2 *3) 90 | """ 91 | p = self.patch_embed.patch_size[0] 92 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 93 | 94 | h = w = imgs.shape[2] // p 95 | x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p)) 96 | x = torch.einsum('nchpwq->nhwpqc', x) 97 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * self.in_chans)) 98 | return x 99 | 100 | def unpatchify(self, x): 101 | """ 102 | x: (N, L, patch_size**2 *3) 103 | imgs: (N, 3, H, W) 104 | """ 105 | p = self.patch_embed.patch_size[0] 106 | h = w = int(x.shape[1] ** .5) 107 | assert h * w == x.shape[1] 108 | 109 | x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans)) 110 | x = torch.einsum('nhwpqc->nchpwq', x) 111 | imgs = x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p)) 112 | return imgs 113 | 114 | def random_masking(self, x, x_ref, mask_ratio): 115 | """ 116 | Perform per-sample random masking by per-sample shuffling. 117 | Per-sample shuffling is done by argsort random noise. 118 | x: [N, L, D], sequence 119 | """ 120 | 121 | N, L, D = x.shape # batch, length, dim 122 | 123 | len_keep = int(L * (1 - mask_ratio)) 124 | self.len_keep = len_keep 125 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 126 | noise1 = torch.rand(N, L, device=x.device) 127 | 128 | # sort noise for each sample 129 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 130 | ids_restore = torch.argsort(ids_shuffle, dim=1) 131 | 132 | ids_shuffle1 = torch.argsort(noise1, dim=1) # ascend: small is keep, large is remove 133 | ids_restore1 = torch.argsort(ids_shuffle1, dim=1) 134 | 135 | # keep the first subset 136 | ids_keep0 = ids_shuffle[:, :len_keep] 137 | ids_keep0_ref = ids_shuffle1[:, :len_keep] 138 | x_masked0 = torch.gather(x, dim=1, index=ids_keep0.unsqueeze(-1).repeat(1, 1, D)) 139 | x_ref_mask0 = torch.gather(x_ref, dim=1, index=ids_keep0_ref.unsqueeze(-1).repeat(1, 1, D)) 140 | 141 | # keep the second subset 142 | ids_keep1 = ids_shuffle[:, len_keep:] 143 | ids_keep1_ref = ids_shuffle1[:, len_keep:] 144 | x_masked1 = torch.gather(x, dim=1, index=ids_keep1.unsqueeze(-1).repeat(1, 1, D)) 145 | x_ref_mask1 = torch.gather(x_ref, dim=1, index=ids_keep1_ref.unsqueeze(-1).repeat(1, 1, D)) 146 | 147 | # generate the binary mask: 0 is keep, 1 is remove 148 | mask = torch.ones([N, L], device=x.device) 149 | mask[:, :len_keep] = 0 150 | # unshuffle to get the binary mask 151 | mask = torch.gather(mask, dim=1, index=ids_restore) 152 | 153 | x_fuse0 = torch.cat([x_masked0, x_ref_mask0], dim=1) 154 | x_fuse1 = torch.cat([x_masked1, x_ref_mask1], dim=1) 155 | return x_fuse0, x_fuse1, mask, ids_restore, ids_restore1 156 | 157 | def forward_encoder(self, x, x_ref, mask_ratio): 158 | # embed patches 159 | x = self.patch_embed(x) 160 | x_ref = self.patch_embed(x_ref) 161 | # add pos embed w/o cls token 162 | x = x + self.pos_embed[:, 1:, :] 163 | x_ref = x_ref + self.pos_embed[:, 1:, :] 
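# NOTE: token bookkeeping under the default config: x and x_ref are each
# [N, 256, 768] after patch embedding, so the concatenation below yields
# [N, 1 + 256 + 256, 768] = [N, 513, 768]. forward_decoder later keeps only the
# cls token plus the reference half (tokens 257:513) for reconstruction.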
164 | # masking: length -> length * mask_ratio 165 | # x1, x2, mask, ids_restore, ids_restore1 = self.random_masking(x, x_ref, mask_ratio) 166 | 167 | # append cls token 168 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 169 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 170 | x1 = torch.cat((cls_tokens, x, x_ref), dim=1) 171 | 172 | # cls_token = self.cls_token + self.pos_embed[:, :1, :] 173 | # cls_tokens = cls_token.expand(x.shape[0], -1, -1) 174 | # x2 = torch.cat((cls_tokens, x2), dim=1) 175 | 176 | # apply Transformer blocks 177 | for blk in self.blocks: 178 | x1 = blk(x1) 179 | x1 = self.norm(x1) 180 | 181 | # for blk in self.blocks: 182 | # x2 = blk(x2) 183 | # x2 = self.norm(x2) 184 | 185 | return x1 186 | 187 | def forward_decoder(self, x1): 188 | # embed tokens 189 | cls = x1[:, :1, :] 190 | x_ref = x1[:, self.patch_embed.num_patches+1:, :] 191 | x_ref = torch.cat([cls, x_ref], dim=1) 192 | x1 = self.decoder_embed(x_ref) 193 | # x2 = self.decoder_embed(x2) 194 | 195 | # x1_no_cls = x1[:, 1:, :] 196 | # x2_no_cls = x2[:, 1:, :] 197 | # x_ori_1 = x1_no_cls[:, :self.len_keep, :] 198 | # x_ori_2 = x2_no_cls[:, :self.len_keep, :] 199 | # 200 | # x_ref_1 = x1_no_cls[:, self.len_keep:, :] 201 | # x_ref_2 = x2_no_cls[:, self.len_keep:, :] 202 | # 203 | # x_ori = torch.cat([x_ori_1, x_ori_2], dim=1) 204 | # x_ref = torch.cat([x_ref_1, x_ref_2], dim=1) 205 | # 206 | # # x_ori_ = torch.gather(x_ori, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_ori.shape[2])) 207 | # # x_ref_ = torch.gather(x_ref, dim=1, index=ids_restore1.unsqueeze(-1).repeat(1, 1, x_ref.shape[2])) 208 | # # x_ = torch.cat([x2[:, 1:, :], x1[:, 1:, :]], dim=1) # no cls token 209 | # # x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])) # unshuffle 210 | # 211 | # x = torch.cat([x1[:, :1, :], x_ref_], dim=1) # append cls token 212 | 213 | # add pos embed 214 | x = x1 + self.decoder_pos_embed 215 | 216 | # apply Transformer blocks 217 | for blk in self.decoder_blocks: 218 | x = blk(x) 219 | x = self.decoder_norm(x) 220 | 221 | # predictor projection 222 | x = self.decoder_pred(x) 223 | 224 | # remove cls token 225 | x = x[:, 1:, :] 226 | 227 | return x 228 | 229 | def forward_loss(self, imgs, pred): 230 | """ 231 | imgs: [N, 3, H, W] 232 | pred: [N, L, p*p*3] 233 | mask: [N, L], 0 is keep, 1 is remove, 234 | """ 235 | target = self.patchify(imgs) 236 | if self.norm_pix_loss: 237 | mean = target.mean(dim=-1, keepdim=True) 238 | var = target.var(dim=-1, keepdim=True) 239 | target = (target - mean) / (var + 1.e-6) ** .5 240 | 241 | dis_loss = (pred - target) ** 2 242 | dis_loss = dis_loss.mean(dim=-1) # [N, L], mean loss per patch 243 | dir_loss = 1 - torch.nn.CosineSimilarity(-1)(pred, target) 244 | 245 | loss = 5 * dir_loss.mean() + dis_loss.mean() # mean loss on removed patches 246 | return loss 247 | 248 | def forward(self, imgs, ref_imgs, mask_ratio=0.5): 249 | latent1 = self.forward_encoder(imgs, ref_imgs, mask_ratio) 250 | pred = self.forward_decoder(latent1) # [N, L, p*p*3] 251 | loss = self.forward_loss(imgs, pred) 252 | return loss, pred, None 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from torchvision import models 5 | from torchvision.models.vgg import vgg16, 
vgg19,vgg19_bn 6 | from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, wide_resnet50_2, wide_resnet101_2 7 | import torch.nn.functional as F 8 | # from torchsummary import summary 9 | from torchvision.models import mobilenet_v2 10 | from models.efficientnet import model 11 | 12 | 13 | # class EfficientNet(nn.Module): 14 | # def __init__(self): 15 | # super(EfficientNet, self).__init__() 16 | # efficient_net = model.EfficientNet.from_pretrained('efficientnet-b0') 17 | # self.efficient_net = efficient_net.eval() 18 | # 19 | # def forward(self, input_): 20 | # with torch.no_grad(): 21 | # features = self.efficient_net.extract_features(input_) 22 | # for feature in features: 23 | # print(feature.shape) 24 | # # 25 | # # f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 26 | # # f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 27 | # # f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 28 | # # f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 29 | # # 30 | # # f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 31 | # return features 32 | # 33 | # model = EfficientNet() 34 | # input_tensor = torch.rand(1,3,256,256) 35 | # output_tensor =model(input_tensor) 36 | 37 | class MobileNet(nn.Module): 38 | def __init__(self): 39 | super(MobileNet, self).__init__() 40 | mobilenet = mobilenet_v2(True) 41 | layers = mobilenet.features 42 | # for i in range(15): 43 | # print(layers[i]) 44 | # print("=================================") 45 | self.layer1 = layers[:1] 46 | self.layer2 = layers[1:2] 47 | self.layer3 = layers[2:4] 48 | self.layer4 = layers[4:7] 49 | 50 | def forward(self, input_): 51 | out1 = self.layer1(input_) 52 | out2 = self.layer2(out1) 53 | out3 = self.layer3(out2) 54 | out4 = self.layer4(out3) 55 | 56 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 57 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 58 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 59 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 60 | 61 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 62 | return f_ 63 | 64 | 65 | 66 | class VGG(nn.Module): 67 | def __init__(self): 68 | super(VGG, self).__init__() 69 | vgg = vgg19(True) 70 | layers = vgg.features 71 | # print(layers) 72 | self.layer1 = layers[:5] 73 | self.layer2 = layers[5:10] 74 | self.layer3 = layers[10:19] 75 | self.layer4 = layers[19:28] 76 | 77 | def forward(self, input_): 78 | out1 = self.layer1(input_) 79 | out2 = self.layer2(out1) 80 | out3 = self.layer3(out2) 81 | out4 = self.layer4(out3) 82 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 83 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 84 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 85 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 86 | 87 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 88 | 89 | return f_ 90 | 91 | # class VGG(nn.Module): 92 | # def __init__(self): 93 | # super(VGG, self).__init__() 94 | # vgg = vgg19_bn(True) 95 | # layers = vgg.features 96 | # # print(layers) 97 | # self.layer1 = layers[:6] 98 | # self.layer2 = layers[6:13] 99 | # self.layer3 = layers[13:26] 100 | # self.layer4 = layers[26:39] 101 | # self.layer5 = layers[39:] 102 | # 103 | # def forward(self, input_): 104 | # out1 = self.layer1(input_) 105 
| # out2 = self.layer2(out1) 106 | # out3 = self.layer3(out2) 107 | # out4 = self.layer4(out3) 108 | # f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 109 | # f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 110 | # f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 111 | # f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 112 | # 113 | # f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 114 | # return f_ 115 | 116 | class Resnet34(nn.Module): 117 | def __init__(self): 118 | super(Resnet34, self).__init__() 119 | resnet = resnet34(True) 120 | 121 | modules = list(resnet.children()) 122 | self.block1 = nn.Sequential(*modules[0:4]) 123 | self.block2 = modules[4] 124 | self.block3 = modules[5] 125 | self.block4 = modules[6] 126 | 127 | 128 | def forward(self, input_): 129 | out1 = self.block1(input_) 130 | out2 = self.block2(out1) 131 | out3 = self.block3(out2) 132 | out4 = self.block4(out3) 133 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 134 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 135 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 136 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 137 | 138 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 139 | return f_ 140 | 141 | 142 | class Resnet50(nn.Module): 143 | def __init__(self): 144 | super(Resnet50, self).__init__() 145 | resnet = resnet50(True) 146 | 147 | modules = list(resnet.children()) 148 | self.block1 = nn.Sequential(*modules[0:4]) 149 | self.block2 = modules[4] 150 | self.block3 = modules[5] 151 | self.block4 = modules[6] 152 | 153 | 154 | def forward(self, input_): 155 | out1 = self.block1(input_) 156 | out2 = self.block2(out1) 157 | out3 = self.block3(out2) 158 | out4 = self.block4(out3) 159 | 160 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 161 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 162 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 163 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 164 | 165 | 166 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 167 | return f_ 168 | 169 | class WideResNet50(nn.Module): 170 | def __init__(self): 171 | super().__init__() 172 | wideresnet50 = wide_resnet50_2(True) 173 | modules = list(wideresnet50.children()) 174 | self.block1 = nn.Sequential(*modules[0:4]) 175 | self.block2 = modules[4] 176 | self.block3 = modules[5] 177 | self.block4 = modules[6] 178 | 179 | 180 | def forward(self, input_): 181 | out1 = self.block1(input_) 182 | out2 = self.block2(out1) 183 | out3 = self.block3(out2) 184 | out4 = self.block4(out3) 185 | 186 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 187 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 188 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 189 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 190 | 191 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 192 | return f_ 193 | 194 | class Resnet101(nn.Module): 195 | def __init__(self): 196 | super(Resnet101, self).__init__() 197 | resnet = resnet101(True) 198 | 199 | modules = list(resnet.children()) 200 | self.block1 = nn.Sequential(*modules[0:4]) 201 | self.block2 = modules[4] 202 | self.block3 = modules[5] 203 | 
self.block4 = modules[6] 204 | 205 | 206 | def forward(self, input_): 207 | out1 = self.block1(input_) 208 | out2 = self.block2(out1) 209 | out3 = self.block3(out2) 210 | out4 = self.block4(out3) 211 | 212 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 213 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 214 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 215 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 216 | 217 | 218 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 219 | f_ = F.normalize(f_) 220 | return f_ 221 | 222 | 223 | class WideResnet101(nn.Module): 224 | def __init__(self): 225 | super(WideResnet101, self).__init__() 226 | wideresnet101 = wide_resnet101_2(True) 227 | 228 | modules = list(wideresnet101.children()) 229 | self.block1 = nn.Sequential(*modules[0:4]) 230 | self.block2 = modules[4] 231 | self.block3 = modules[5] 232 | self.block4 = modules[6] 233 | 234 | def forward(self, input_): 235 | out1 = self.block1(input_) 236 | out2 = self.block2(out1) 237 | out3 = self.block3(out2) 238 | out4 = self.block4(out3) 239 | 240 | f1_ = F.interpolate(out1, size=(64, 64), mode='bilinear', align_corners=True) 241 | f2_ = F.interpolate(out2, size=(64, 64), mode='bilinear', align_corners=True) 242 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 243 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 244 | 245 | f_ = torch.cat([f1_, f2_, f3_, f4_], dim=1) 246 | f_ = F.normalize(f_) 247 | return f_ 248 | 249 | class D_VGG(nn.Module): 250 | def __init__(self): 251 | super(D_VGG, self).__init__() 252 | vgg = vgg19(True) 253 | layers = vgg.features 254 | # print(layers) 255 | self.layer1 = layers[:5] 256 | self.layer2 = layers[5:10] 257 | self.layer3 = layers[10:19] 258 | self.layer4 = layers[19:28] 259 | 260 | def forward(self, input_): 261 | out1 = self.layer1(input_) 262 | out2 = self.layer2(out1) 263 | out3 = self.layer3(out2) 264 | out4 = self.layer4(out3) 265 | 266 | f3_ = F.interpolate(out3, size=(64, 64), mode='bilinear', align_corners=True) 267 | f4_ = F.interpolate(out4, size=(64, 64), mode='bilinear', align_corners=True) 268 | 269 | f_ = torch.cat([f3_, f4_], dim=1) 270 | 271 | return f_ 272 | 273 | class IMAGE(nn.Module): 274 | def __init__(self): 275 | super(IMAGE, self).__init__() 276 | 277 | def forward(self, input_): 278 | input_ = F.interpolate(input_, size=(256, 256), mode='bilinear', align_corners=True) 279 | 280 | return input_ 281 | 282 | if __name__ == '__main__': 283 | a = torch.rand((1, 3, 256, 256)) 284 | pre_fea = WideResNet50() 285 | print(pre_fea(a)) -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
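# A worked example of the position-embedding helpers below (comments only; sizes hypothetical):
#   pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, cls_token=True)
#   pos.shape -> (257, 768): one all-zero row for the cls token plus 16*16 grid rows.
# Each grid axis consumes half of embed_dim; within that half the features split
# into sin and cos banks over geometrically spaced frequencies 1 / 10000**omega.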
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed from NumPy; use the explicit dtype 57 | omega /= embed_dim / 2. 58 | omega = 1. / 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | 98 | import math 99 | 100 | def adjust_learning_rate(optimizer, epoch): 101 | """Step-decay the learning rate: scale it by 0.1 once epoch reaches decay_epochs (no cosine schedule is applied here)""" 102 | decay_epochs = 300 103 | lr = 0.0001 104 | if epoch >= decay_epochs: 105 | lr = lr * 0.1 106 | for param_group in optimizer.param_groups: 107 | if "lr_scale" in param_group: 108 | param_group["lr"] = lr * param_group["lr_scale"] 109 | else: 110 | param_group["lr"] = lr 111 | return lr -------------------------------------------------------------------------------- /network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/network.png -------------------------------------------------------------------------------- /ref_find.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | import torch 4 | 5 | def get_pos_sample(train_img_root, device, batch_size): 6 | pos_img_name = train_img_root 7 | transforms_x = transforms.Compose([transforms.Resize(256, Image.LANCZOS),  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter 8 | transforms.ToTensor()]) 9 | pos_img = transforms_x(Image.open(pos_img_name).convert('RGB')) 10 | pos_img = [pos_img.unsqueeze(0)]*batch_size 11 | pos_img = torch.cat(pos_img, dim=0) 12 | pos_img = pos_img.to(device) 13 | return pos_img 14 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # from models.RB_VIT import RB_VIT 2 | import torch 3 | from config import
DefaultConfig 4 | import os 5 | from torch import optim 6 | 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | 10 | from matplotlib import pyplot as plt 11 | import numpy as np 12 | import argparse 13 | from sklearn.metrics import roc_auc_score, f1_score, average_precision_score 14 | from utils.metric import cal_pro_metric_new 15 | import cv2 16 | from models.misc import NativeScalerWithGradNormCount as NativeScaler 17 | from losses.vgg_loss import StyleLoss, PerceptualLoss 18 | from scipy.ndimage import gaussian_filter 19 | from datasets.dataset import denormalize 20 | from ref_find import get_pos_sample 21 | from models.TFA_Net_model import * 22 | class Model(object): 23 | def __init__(self, opt, test_no_mask, test_add): 24 | super(Model, self).__init__() 25 | # if test_no_mask == True: 26 | # model_name = 'RB_VIT_dir_res_ref' 27 | # else: 28 | # model_name = opt.model_name 29 | self.opt = opt 30 | self.model = eval(opt.model_name)(opt) 31 | self.device = opt.device 32 | self.test_add = test_add 33 | self.class_name = opt.class_name 34 | self.trainloader = opt.trainloader 35 | self.testloader = opt.testloader 36 | self.loss_scaler = NativeScaler() 37 | 38 | if self.opt.resume != "": 39 | print('\nload pre-trained networks') 40 | self.opt.iter = torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}.pth'))['epoch'] 41 | print(self.opt.iter) 42 | self.model.load_state_dict(torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}.pth'))['state_dict'], strict=False) 43 | print('\ndone.\n') 44 | 45 | if self.opt.isTrain: 46 | self.model.Roncon_model.train() 47 | self.optimizer_g = optim.AdamW(self.model.Roncon_model.parameters(), lr=opt.lr, betas=(0.9, 0.95)) 48 | if test_no_mask == True: 49 | self.save_root = f"./result_test/{opt.model_name}_{opt.backbone_name}_test_no_mask" 50 | else: 51 | self.save_root = f"./result_test/{opt.model_name}_{opt.backbone_name}" 52 | 53 | if test_add == True: 54 | self.save_root += 'add' 55 | # os.makedirs(os.path.join(self.save_root, "weight"), exist_ok=True) 56 | self.ckpt_root = os.path.join(self.save_root, "weight/{}".format(self.class_name)) 57 | self.vis_root = os.path.join(self.save_root, "img/{}".format(self.class_name)) 58 | 59 | 60 | 61 | def get_max(self, tensor): 62 | a_1, _ = torch.max(tensor, dim=1, keepdim=True) 63 | a_2, _ = torch.max(a_1, dim=2, keepdim=True) 64 | a_3, _ = torch.max(a_2, dim=3, keepdim=True) 65 | return a_3 66 | def train(self): 67 | 68 | loss_now = 100000 69 | auc_now = 0. 70 | for epoch in range(self.opt.iter, self.opt.niter): 71 | self.model.Feature_extractor.eval() 72 | self.model.Roncon_model.train(True) 73 | self.model.to(self.device) 74 | loss_total = 0. 
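# NOTE: update_grad=(index + 1) % 1 == 0 below is always True, i.e. the optimizer
# steps on every batch; replacing the 1 with k would accumulate gradients over k
# batches. Also, `from losses.vgg_loss import StyleLoss, PerceptualLoss` above has
# no matching losses/ package in this repository tree and appears unused, so it
# likely needs to be removed (or the package supplied) before test.py will run.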
75 | count = 0 76 | for index, (x, _, _, _) in enumerate(tqdm(self.trainloader, ncols=80)): 77 | bs = x.shape[0] 78 | x = x.to(self.device) 79 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, bs) 80 | 81 | deep_feature,_, recon_feature, loss = self.model(x, ref_x) 82 | self.loss_scaler(loss, self.optimizer_g, parameters=self.model.Roncon_model.parameters(), update_grad=(index + 1) % 1 == 0) 83 | loss_total += loss.item() 84 | count += 1 85 | 86 | loss_total = loss_total / count 87 | print('the {} epoch is done loss:{}'.format(epoch + 1, loss_total)) 88 | if (epoch + 1) % 50 == 0: 89 | # self.test_2() 90 | x1, x2, x3, x4 = self.test() 91 | auc_roc = x1+x2 92 | if auc_roc > auc_now: 93 | auc_now = auc_roc 94 | class_rocauc[self.opt.class_name] = (x1, x2, x3, x4) 95 | print('save model') 96 | weight_dir = self.ckpt_root 97 | os.makedirs(weight_dir, exist_ok=True) 98 | torch.save({'epoch': epoch + 1, 'state_dict': self.model.state_dict()}, 99 | f'%s/{self.opt.model_name}_{self.opt.backbone_name}.pth' % (weight_dir)) 100 | 101 | 102 | 103 | def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list): 104 | flatten_y_list = np.array(test_y_list).ravel() 105 | flatten_score_list = np.array(score_list).ravel() 106 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list) 107 | image_level_AP = average_precision_score(flatten_y_list, flatten_score_list) 108 | 109 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 110 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 111 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list) 112 | pixel_level_AP = average_precision_score(flatten_mask_list, flatten_score_map_list) 113 | # pro_auc_score = 0 114 | # pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) 115 | return round(image_level_ROCAUC,3), round(pixel_level_ROCAUC,3), round(image_level_AP,3), round(pixel_level_AP,3) 116 | # return image_level_ROCAUC, pixel_level_ROCAUC 117 | 118 | def F1_score(self, score_map_list, test_mask_list): 119 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 120 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 121 | F1_score = f1_score(flatten_mask_list, flatten_score_map_list) 122 | return F1_score 123 | 124 | def filter(self, pred_mask): 125 | pred_mask_my = np.squeeze(np.squeeze(pred_mask, 0), 0) 126 | pred_mask_my = cv2.medianBlur(np.uint8(pred_mask_my * 255), 7) 127 | mean = np.mean(pred_mask_my) 128 | std = np.std(pred_mask_my) 129 | _ , binary_pred_mask = cv2.threshold(pred_mask_my, mean+2.75*std, 255, type=cv2.THRESH_BINARY) 130 | binary_pred_mask = np.uint8(binary_pred_mask/255) 131 | pred_mask_my = np.expand_dims(np.expand_dims(pred_mask_my, 0), 0) 132 | binary_pred_mask = np.expand_dims(np.expand_dims(binary_pred_mask, 0), 0) 133 | return pred_mask_my, binary_pred_mask 134 | 135 | 136 | # def thresholding(self, pred_mask_my): 137 | # np_img 138 | 139 | # return 140 | def feature_map_vis(self, feature_map_list): 141 | feature_map_list = [torch.mean(i.clone(), dim=1).squeeze(0).cpu().detach().numpy() for i in feature_map_list] 142 | # feature_map_list = [(i.squeeze(0))[25, :, :].cpu().detach().numpy() for i in feature_map_list] 143 | return feature_map_list 144 | # def feature_map_vis(self, feature_map_list): 145 | # feature_map_list = [torch.mean(i.clone(), dim=1).squeeze(0).cpu().detach().numpy() for i in feature_map_list] 146 | # return feature_map_list 147 | 148 | def test(self): 149 | test_y_list = [] 150 
| test_mask_list = [] 151 | score_list = [] 152 | score_map_list = [] 153 | 154 | for idx, (x, y, mask, name) in enumerate(tqdm(self.testloader, ncols=80)): 155 | test_y_list.extend(y.detach().cpu().numpy()) 156 | test_mask_list.extend(mask.detach().cpu().numpy()) 157 | self.model.eval() 158 | self.model.to(self.device) 159 | x = x.to(self.device) 160 | mask = mask.to(self.device) 161 | mask_cpu = mask.cpu().detach().numpy()[0, :, :, :].transpose((1, 2, 0)) 162 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, 1) 163 | deep_feature, ref_feature, recon_feature, _ = self.model(x, ref_x, None) 164 | feature_map_vis_list = self.feature_map_vis([deep_feature, ref_feature, recon_feature]) 165 | dis_amap, dir_amap = self.model.a_map(deep_feature, recon_feature) 166 | dis_amap = gaussian_filter(dis_amap, sigma=4) 167 | dir_amap = gaussian_filter(dir_amap, sigma=4) 168 | # print(type(name0])) 169 | name_list= name[0].split(r'!') 170 | # print(name_list) 171 | category, img_name = name_list[-2], name_list[-1] 172 | if self.test_add == False: 173 | # amap = dir_amap*dis_amap 174 | amap = dir_amap*5+dis_amap 175 | else: 176 | # print('ok') 177 | # print(np.max(dir_amap)) 178 | # print(np.max(dis_amap)) 179 | amap = 0.5*(dir_amap/np.max(dir_amap)) + 0.5*(dis_amap/np.max(dis_amap)) 180 | self.vis_img([x,*feature_map_vis_list, dis_amap, dir_amap, amap, mask_cpu], os.path.join(self.vis_root, category), img_name) 181 | 182 | 183 | score_list.extend(np.array(np.std(amap)).reshape(1)) 184 | score_map_list.extend(amap.reshape((1, 1, 256, 256))) 185 | 186 | 187 | image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP= self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list) 188 | # F1_score = self.F1_score(F1_score_map_list, test_mask_list) 189 | print('image_auc_roc: {} '.format(image_level_ROCAUC), 190 | 'pixel_auc_roc: {} '.format(pixel_level_ROCAUC), 191 | 'image_AP: {}'.format(image_level_AP), 192 | 'pixel_AP: {}'.format(pixel_level_AP) 193 | ) 194 | class_rocauc[self.opt.class_name] = (image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP) 195 | return image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP 196 | 197 | def vis_img(self, img_list, save_root, idx_name): 198 | os.makedirs(save_root, exist_ok=True) 199 | input_frame = denormalize(img_list[0].clone().squeeze(0).cpu().detach().numpy()) 200 | cv2_input = np.array(input_frame, dtype=np.uint8) 201 | plt.figure() 202 | plt.subplot(241) 203 | plt.imshow(cv2_input) 204 | plt.axis('off') 205 | plt.subplot(242) 206 | plt.imshow(img_list[1]) 207 | plt.axis('off') 208 | plt.subplot(243) 209 | plt.imshow(img_list[2]) 210 | plt.axis('off') 211 | plt.subplot(244) 212 | plt.imshow(img_list[3]) 213 | plt.axis('off') 214 | plt.subplot(245) 215 | plt.imshow(img_list[4], cmap='jet') 216 | plt.axis('off') 217 | plt.subplot(246) 218 | plt.imshow(img_list[5],cmap='jet') 219 | plt.axis('off') 220 | plt.subplot(247) 221 | plt.imshow(img_list[6], cmap='jet') 222 | plt.axis('off') 223 | plt.subplot(248) 224 | plt.imshow(img_list[7]) 225 | plt.axis('off') 226 | plt.savefig(os.path.join(save_root, idx_name)) 227 | plt.close() 228 | 229 | def save_img(self, img_list, save_root, idx_name): 230 | os.makedirs(save_root, exist_ok=True) 231 | input_frame = denormalize(img_list[0].clone().squeeze(0).cpu().detach().numpy()) 232 | cv2_input = np.array(input_frame, dtype=np.uint8) 233 | # plt.figure() 234 | # plt.subplot(241) 235 | plt.imsave(os.path.join(save_root, f'{idx_name}_{0}.png'), cv2_input) 236 
| plt.imsave(os.path.join(save_root, f'{idx_name}_{1}.png'), img_list[1]) 237 | plt.imsave(os.path.join(save_root, f'{idx_name}_{2}.png'), img_list[2]) 238 | plt.imsave(os.path.join(save_root, f'{idx_name}_{3}.png'), img_list[3]) 239 | plt.imsave(os.path.join(save_root, f'{idx_name}_{4}.png'), img_list[4], cmap='jet') 240 | plt.imsave(os.path.join(save_root, f'{idx_name}_{5}.png'), img_list[5], cmap='jet') 241 | plt.imsave(os.path.join(save_root, f'{idx_name}_{6}.png'), img_list[6], cmap='jet') 242 | 243 | plt.imsave(os.path.join(save_root, f'{idx_name}_{7}.png'),cv2.cvtColor(img_list[7], cv2.COLOR_GRAY2RGB), cmap='gray') 244 | # plt.axis('off') 245 | # plt.subplot(242) 246 | # plt.imwrite() 247 | # # plt.axis('off') 248 | # # plt.subplot(243) 249 | # plt.imshow(img_list[2]) 250 | # # plt.axis('off') 251 | # # plt.subplot(244) 252 | # plt.imshow(img_list[3]) 253 | # # plt.axis('off') 254 | # # plt.subplot(245) 255 | # plt.imshow(img_list[4], cmap='jet') 256 | # 257 | # # plt.axis('off') 258 | # # plt.subplot(246) 259 | # plt.imshow(img_list[5],cmap='jet') 260 | # # plt.axis('off') 261 | # # plt.subplot(247) 262 | # plt.imshow(img_list[6], cmap='jet') 263 | # # plt.axis('off') 264 | # # plt.subplot(248) 265 | # plt.imshow(img_list[7]) 266 | # plt.axis('off') 267 | # plt.savefig(os.path.join(save_root, idx_name)) 268 | # plt.close() 269 | 270 | def tensor_to_np_cpu(self, tensor): 271 | x_cpu = tensor.squeeze(0).data.cpu().numpy() 272 | x_cpu = np.transpose(x_cpu, (1, 2, 0)) 273 | return x_cpu 274 | 275 | def check(self, img): 276 | if len(img.shape) == 2: 277 | return img 278 | if img.shape[2] == 3: 279 | return img 280 | elif img.shape[2] == 1: 281 | return img.reshape(img.shape[0], img.shape[1]) 282 | 283 | MVTec_CLASS_NAMES = [ 'bottle', 'cable', 'capsule', 'carpet', 'grid', 284 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 285 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper'] 286 | 287 | class_rocauc = { 288 | 'bottle':(0, 0, 0, 0), 289 | 'cable':(0, 0, 0, 0), 290 | 'capsule':(0, 0, 0, 0), 291 | 'carpet':(0, 0, 0, 0), 292 | 'grid':(0, 0, 0, 0), 293 | 'hazelnut':(0, 0, 0, 0), 294 | 'leather':(0, 0, 0, 0), 295 | 'metal_nut':(0, 0, 0, 0), 296 | 'pill':(0, 0, 0, 0), 297 | 'screw':(0, 0, 0, 0), 298 | 'tile':(0, 0, 0, 0), 299 | 'toothbrush':(0, 0, 0, 0), 300 | 'transistor':(0, 0, 0, 0), 301 | 'wood':(0, 0, 0, 0), 302 | 'zipper':(0, 0, 0, 0)} 303 | 304 | model_name_list = ['RB_VIT', 'RB_VIT_dir', 'RB_VIT_dir_no_ref','VIT_dir', 'RB_VIT_average', 'RB_VIT_dir_mask25', 'RB_VIT_dir_mask50', 'RB_VIT_dir_mask75', 'ST_VIT', 'RB_VIT_Res', 'RB_VIT_dir_no_ref', 305 | 'RB_VIT_dir_res_ref', 'RB_VIT_dir_res_ref_mask25', 'RB_VIT_dir_res_ref_mask50', 'RB_VIT_dir_res_ref_mask75', 306 | 'RB_VIT_dir_res_ref_mask_random'] 307 | 308 | MVTec_CLASS_NAMES = [ 'transistor'] 309 | 310 | if __name__ == '__main__': 311 | opt = DefaultConfig() 312 | test_no_mask = True 313 | test_add = False 314 | from datasets.dataset import MVTecDataset 315 | from torch.utils.data import DataLoader 316 | opt.model_name = model_name_list[1] 317 | for classname in MVTec_CLASS_NAMES: 318 | opt.class_name = classname 319 | # opt.class_name = 'capsule' 320 | opt.referenc_img_file = f'data/mvtec_anomaly_detection/{opt.class_name}/train/good/000.png' 321 | save_name = opt.model_name+'_'+opt.backbone_name 322 | # opt.resume = fr'F:\LW\RB-VIT\result/{save_name}/weight/{opt.class_name}' 323 | # opt.resume = fr'result/RB_VIT_dir_WideResnet50_k=4/weight/{opt.class_name}' 324 | opt.resume = 
r'result/RB_VIT_dir_WideResnet50/weight/transistor' 325 | # if test_no_mask == True: 326 | # opt.model_name = 'RB_VIT_dir_res_ref' 327 | print(opt.class_name, opt.model_name) 328 | # print(opt.referenc_img_file) 329 | # opt.resume = r'result/RB_VIT_dir_res_ref_VGG/weight/capsule' 330 | opt.train_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=True) 331 | opt.test_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=False) 332 | opt.trainloader = DataLoader(opt.train_dataset, batch_size=opt.batch_size, shuffle=True) 333 | opt.testloader = DataLoader(opt.test_dataset, batch_size=1, shuffle=False) 334 | model = Model(opt, test_no_mask, test_add) 335 | model.test() 336 | print(class_rocauc) 337 | value = list(class_rocauc.values()) 338 | img_roc = [i[0] for i in value] 339 | pixel_roc = [i[1] for i in value] 340 | img_ap = [i[2] for i in value] 341 | pixel_ap = [i[3] for i in value] 342 | mean_img_roc = np.mean(np.array(img_roc)) 343 | mean_pixel_roc = np.mean(np.array(pixel_roc)) 344 | mean_img_ap = np.mean(np.array(img_ap)) 345 | mean_pixel_ap = np.mean(np.array(pixel_ap)) 346 | 347 | print(round(mean_img_roc,3), round(mean_pixel_roc,3), round(mean_img_ap, 3), round(mean_pixel_ap,3)) 348 | -------------------------------------------------------------------------------- /test_mvlogo.py: -------------------------------------------------------------------------------- 1 | # from models.RB_VIT import RB_VIT 2 | import torch 3 | from config import DefaultConfig 4 | import os 5 | from torch import optim 6 | 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | 10 | from matplotlib import pyplot as plt 11 | import numpy as np 12 | import argparse 13 | from sklearn.metrics import roc_auc_score, f1_score, average_precision_score 14 | from utils.metric import cal_pro_metric_new 15 | import cv2 16 | from models.misc import NativeScalerWithGradNormCount as NativeScaler 17 | from losses.vgg_loss import StyleLoss, PerceptualLoss 18 | from scipy.ndimage import gaussian_filter 19 | from datasets.dataset import denormalize 20 | from ref_find import get_pos_sample 21 | from models.TFA_Net_model import * 22 | class Model(object): 23 | def __init__(self, opt, test_no_mask, test_add): 24 | super(Model, self).__init__() 25 | # if test_no_mask == True: 26 | # model_name = 'RB_VIT_dir_res_ref' 27 | # else: 28 | # model_name = opt.model_name 29 | self.opt = opt 30 | self.model = eval(opt.model_name)(opt) 31 | self.device = opt.device 32 | self.test_add = test_add 33 | self.class_name = opt.class_name 34 | self.trainloader = opt.trainloader 35 | self.testloader = opt.testloader 36 | self.loss_scaler = NativeScaler() 37 | 38 | if self.opt.resume != "": 39 | print('\nload pre-trained networks') 40 | self.opt.iter = torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}_k={opt.k}.pth'))['epoch'] 41 | print(self.opt.iter) 42 | self.model.load_state_dict(torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}_k={opt.k}.pth'))['state_dict'], strict=False) 43 | print('\ndone.\n') 44 | 45 | if self.opt.isTrain: 46 | self.model.Roncon_model.train() 47 | self.optimizer_g = optim.AdamW(self.model.Roncon_model.parameters(), lr=opt.lr, betas=(0.9, 0.95)) 48 | if test_no_mask == True: 49 | self.save_root = f"./result_logo_test/{opt.model_name}_{opt.backbone_name}_test_no_mask" 50 | else: 51 | self.save_root = f"./result_test/{opt.model_name}_{opt.backbone_name}" 52 | 53 | if test_add == True: 54 | 
self.save_root += 'add' 55 | # os.makedirs(os.path.join(self.save_root, "weight"), exist_ok=True) 56 | self.ckpt_root = os.path.join(self.save_root, "weight/{}".format(self.class_name)) 57 | self.vis_root = os.path.join(self.save_root, "img/{}".format(self.class_name)) 58 | 59 | 60 | 61 | def get_max(self, tensor): 62 | a_1, _ = torch.max(tensor, dim=1, keepdim=True) 63 | a_2, _ = torch.max(a_1, dim=2, keepdim=True) 64 | a_3, _ = torch.max(a_2, dim=3, keepdim=True) 65 | return a_3 66 | def train(self): 67 | 68 | loss_now = 100000 69 | auc_now = 0. 70 | for epoch in range(self.opt.iter, self.opt.niter): 71 | self.model.Feature_extractor.eval() 72 | self.model.Roncon_model.train(True) 73 | self.model.to(self.device) 74 | loss_total = 0. 75 | count = 0 76 | for index, (x, _, _, _) in enumerate(tqdm(self.trainloader, ncols=80)): 77 | bs = x.shape[0] 78 | x = x.to(self.device) 79 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, bs) 80 | 81 | deep_feature,_, recon_feature, loss = self.model(x, ref_x) 82 | self.loss_scaler(loss, self.optimizer_g, parameters=self.model.Roncon_model.parameters(), update_grad=(index + 1) % 1 == 0) 83 | loss_total += loss.item() 84 | count += 1 85 | 86 | loss_total = loss_total / count 87 | print('the {} epoch is done loss:{}'.format(epoch + 1, loss_total)) 88 | if (epoch + 1) % 50 == 0: 89 | # self.test_2() 90 | x1, x2, x3, x4 = self.test() 91 | auc_roc = x1+x2 92 | if auc_roc > auc_now: 93 | auc_now = auc_roc 94 | class_rocauc[self.opt.class_name] = (x1, x2, x3, x4) 95 | print('save model') 96 | weight_dir = self.ckpt_root 97 | os.makedirs(weight_dir, exist_ok=True) 98 | torch.save({'epoch': epoch + 1, 'state_dict': self.model.state_dict()}, 99 | f'%s/{self.opt.model_name}_{self.opt.backbone_name}.pth' % (weight_dir)) 100 | 101 | 102 | 103 | def cal_auc(self, score_list, logo_score_list, struct_score_list, score_map_list, test_y_list, logo_y_list, struct_y_list, test_mask_list): 104 | flatten_y_list = np.array(test_y_list).ravel() 105 | flatten_score_list = np.array(score_list).ravel() 106 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list) 107 | flatten_logo_y_list = np.array(logo_y_list).ravel() 108 | flatten_logo_score_list = np.array(logo_score_list).ravel() 109 | logo_img_auroc = roc_auc_score(flatten_logo_y_list, flatten_logo_score_list) 110 | flatten_stru_y_list = np.array(struct_y_list).ravel() 111 | flatten_stru_score_list = np.array(struct_score_list).ravel() 112 | stru_img_auroc = roc_auc_score(flatten_stru_y_list, flatten_stru_score_list) 113 | # image_level_AP = average_precision_score(flatten_y_list, flatten_score_list) 114 | 115 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 116 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 117 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list) 118 | # pixel_level_AP = average_precision_score(flatten_mask_list, flatten_score_map_list) 119 | # pro_auc_score = 0 120 | # pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) 121 | return round(image_level_ROCAUC,3), round(pixel_level_ROCAUC,3), round(logo_img_auroc,3), round(stru_img_auroc,3) 122 | # return image_level_ROCAUC, pixel_level_ROCAUC 123 | 124 | def F1_score(self, score_map_list, test_mask_list): 125 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 126 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 127 | F1_score = f1_score(flatten_mask_list, flatten_score_map_list) 128 | return F1_score 129 | 130 | 
def filter(self, pred_mask): 131 | pred_mask_my = np.squeeze(np.squeeze(pred_mask, 0), 0) 132 | pred_mask_my = cv2.medianBlur(np.uint8(pred_mask_my * 255), 7) 133 | mean = np.mean(pred_mask_my) 134 | std = np.std(pred_mask_my) 135 | _ , binary_pred_mask = cv2.threshold(pred_mask_my, mean+2.75*std, 255, type=cv2.THRESH_BINARY) 136 | binary_pred_mask = np.uint8(binary_pred_mask/255) 137 | pred_mask_my = np.expand_dims(np.expand_dims(pred_mask_my, 0), 0) 138 | binary_pred_mask = np.expand_dims(np.expand_dims(binary_pred_mask, 0), 0) 139 | return pred_mask_my, binary_pred_mask 140 | 141 | 142 | # def thresholding(self, pred_mask_my): 143 | # np_img 144 | 145 | # return 146 | # def feature_map_vis(self, feature_map_list): 147 | # # feature_map_list = [torch.mean(i.clone(), dim=1).squeeze(0).cpu().detach().numpy() for i in feature_map_list] 148 | # feature_map_list = [(i.squeeze(0))[25, :, :].cpu().detach().numpy() for i in feature_map_list] 149 | # return feature_map_list 150 | def feature_map_vis(self, feature_map_list): 151 | feature_map_list = [torch.mean(i.clone(), dim=1).squeeze(0).cpu().detach().numpy() for i in feature_map_list] 152 | return feature_map_list 153 | 154 | def test(self): 155 | test_y_list = [] 156 | logo_y_list=[] 157 | struct_y_list = [] 158 | test_mask_list = [] 159 | score_list = [] 160 | score_map_list = [] 161 | logo_score_list = [] 162 | struct_score_list = [] 163 | 164 | for idx, (x, y, mask, name) in enumerate(tqdm(self.testloader, ncols=80)): 165 | test_y_list.extend(y.detach().cpu().numpy()) 166 | test_mask_list.extend(mask.detach().cpu().numpy()) 167 | 168 | self.model.eval() 169 | self.model.to(self.device) 170 | x = x.to(self.device) 171 | mask = mask.to(self.device) 172 | mask_cpu = mask.cpu().detach().numpy()[0, :, :, :].transpose((1, 2, 0)) 173 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, 1) 174 | deep_feature, ref_feature, recon_feature, _ = self.model(x, ref_x, None) 175 | feature_map_vis_list = self.feature_map_vis([deep_feature, ref_feature, recon_feature]) 176 | dis_amap, dir_amap = self.model.a_map(deep_feature, recon_feature) 177 | dis_amap = gaussian_filter(dis_amap, sigma=4) 178 | dir_amap = gaussian_filter(dir_amap, sigma=4) 179 | # print(type(name0])) 180 | name_list= name[0].split(r'!') 181 | # print(name_list) 182 | category, img_name = name_list[-2], name_list[-1] 183 | if category == 'logical_anomalies': 184 | logo_y_list.extend(y.detach().cpu().numpy()) 185 | elif category == 'structural_anomalies': 186 | struct_y_list.extend(y.detach().cpu().numpy()) 187 | if self.test_add == False: 188 | amap = dir_amap*dis_amap 189 | # amap = dir_amap*5+dis_amap 190 | else: 191 | # print('ok') 192 | # print(np.max(dir_amap)) 193 | # print(np.max(dis_amap)) 194 | amap = 0.5*(dir_amap/np.max(dir_amap)) + 0.5*(dis_amap/np.max(dis_amap)) 195 | self.save_img([x,*feature_map_vis_list, dis_amap, dir_amap, amap, mask_cpu], os.path.join(self.vis_root, category), img_name) 196 | 197 | 198 | score_list.extend(np.array(np.std(amap)).reshape(1)) 199 | score_map_list.extend(amap.reshape((1, 1, 256, 256))) 200 | if category == 'logical_anomalies': 201 | logo_score_list.extend(np.array(np.std(amap)).reshape(1)) 202 | elif category == 'structural_anomalies': 203 | struct_score_list.extend(np.array(np.std(amap)).reshape(1)) 204 | 205 | image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP= self.cal_auc(score_list, logo_score_list, struct_score_list, score_map_list, test_y_list, logo_y_list, struct_y_list, test_mask_list 206 | ) 207 | # 
F1_score = self.F1_score(F1_score_map_list, test_mask_list) 208 | print('image_auc_roc: {} '.format(image_level_ROCAUC), 209 | 'pixel_auc_roc: {} '.format(pixel_level_ROCAUC), 210 | 'logi_auroc: {}'.format(image_level_AP), 211 | 'stru_auroc: {}'.format(pixel_level_AP) 212 | ) 213 | class_rocauc[self.opt.class_name] = (image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP) 214 | return image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP 215 | 216 | def vis_img(self, img_list, save_root, idx_name): 217 | os.makedirs(save_root, exist_ok=True) 218 | input_frame = denormalize(img_list[0].clone().squeeze(0).cpu().detach().numpy()) 219 | cv2_input = np.array(input_frame, dtype=np.uint8) 220 | plt.figure() 221 | plt.subplot(241) 222 | plt.imshow(cv2_input) 223 | plt.axis('off') 224 | plt.subplot(242) 225 | plt.imshow(img_list[1]) 226 | plt.axis('off') 227 | plt.subplot(243) 228 | plt.imshow(img_list[2]) 229 | plt.axis('off') 230 | plt.subplot(244) 231 | plt.imshow(img_list[3]) 232 | plt.axis('off') 233 | plt.subplot(245) 234 | plt.imshow(img_list[4], cmap='jet') 235 | plt.axis('off') 236 | plt.subplot(246) 237 | plt.imshow(img_list[5],cmap='jet') 238 | plt.axis('off') 239 | plt.subplot(247) 240 | plt.imshow(img_list[6], cmap='jet') 241 | plt.axis('off') 242 | plt.subplot(248) 243 | plt.imshow(img_list[7]) 244 | plt.axis('off') 245 | plt.savefig(os.path.join(save_root, idx_name)) 246 | plt.close() 247 | 248 | def save_img(self, img_list, save_root, idx_name): 249 | os.makedirs(save_root, exist_ok=True) 250 | input_frame = denormalize(img_list[0].clone().squeeze(0).cpu().detach().numpy()) 251 | cv2_input = np.array(input_frame, dtype=np.uint8) 252 | # plt.figure() 253 | # plt.subplot(241) 254 | plt.imsave(os.path.join(save_root, f'{idx_name}_{0}.png'), cv2_input) 255 | plt.imsave(os.path.join(save_root, f'{idx_name}_{1}.png'), img_list[1]) 256 | plt.imsave(os.path.join(save_root, f'{idx_name}_{2}.png'), img_list[2]) 257 | plt.imsave(os.path.join(save_root, f'{idx_name}_{3}.png'), img_list[3]) 258 | plt.imsave(os.path.join(save_root, f'{idx_name}_{4}.png'), img_list[4], cmap='jet') 259 | plt.imsave(os.path.join(save_root, f'{idx_name}_{5}.png'), img_list[5], cmap='jet') 260 | plt.imsave(os.path.join(save_root, f'{idx_name}_{6}.png'), img_list[6], cmap='jet') 261 | 262 | plt.imsave(os.path.join(save_root, f'{idx_name}_{7}.png'),cv2.cvtColor(img_list[7], cv2.COLOR_GRAY2RGB), cmap='gray') 263 | # plt.axis('off') 264 | # plt.subplot(242) 265 | # plt.imwrite() 266 | # # plt.axis('off') 267 | # # plt.subplot(243) 268 | # plt.imshow(img_list[2]) 269 | # # plt.axis('off') 270 | # # plt.subplot(244) 271 | # plt.imshow(img_list[3]) 272 | # # plt.axis('off') 273 | # # plt.subplot(245) 274 | # plt.imshow(img_list[4], cmap='jet') 275 | # 276 | # # plt.axis('off') 277 | # # plt.subplot(246) 278 | # plt.imshow(img_list[5],cmap='jet') 279 | # # plt.axis('off') 280 | # # plt.subplot(247) 281 | # plt.imshow(img_list[6], cmap='jet') 282 | # # plt.axis('off') 283 | # # plt.subplot(248) 284 | # plt.imshow(img_list[7]) 285 | # plt.axis('off') 286 | # plt.savefig(os.path.join(save_root, idx_name)) 287 | # plt.close() 288 | 289 | def tensor_to_np_cpu(self, tensor): 290 | x_cpu = tensor.squeeze(0).data.cpu().numpy() 291 | x_cpu = np.transpose(x_cpu, (1, 2, 0)) 292 | return x_cpu 293 | 294 | def check(self, img): 295 | if len(img.shape) == 2: 296 | return img 297 | if img.shape[2] == 3: 298 | return img 299 | elif img.shape[2] == 1: 300 | return img.reshape(img.shape[0], img.shape[1]) 301 | 
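# NOTE: cal_auc above returns (image AUROC, pixel AUROC, logical-anomaly image
# AUROC, structural-anomaly image AUROC). test() unpacks the last two into
# variables named image_level_AP / pixel_level_AP, but the printed labels
# 'logi_auroc' and 'stru_auroc' reflect what they actually contain.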
302 | MVTec_CLASS_NAMES = [ 'breakfast_box', 'juice_bottle', 'pushpins', 'screw_bag', 'splicing_connectors'] 303 | 304 | class_rocauc = { 305 | 'breakfast_box':(0, 0, 0, 0), 306 | 'juice_bottle':(0, 0, 0, 0), 307 | 'pushpins':(0, 0, 0, 0), 308 | 'screw_bag':(0, 0, 0, 0), 309 | 'splicing_connectors':(0, 0, 0, 0) 310 | } 311 | 312 | model_name_list = ['RB_VIT', 'RB_VIT_dir', 'VIT_dir', 'RB_VIT_average', 'RB_VIT_dir_mask25', 'RB_VIT_dir_mask50', 'RB_VIT_dir_mask75', 'ST_VIT', 'RB_VIT_Res', 'RB_VIT_dir_no_ref', 313 | 'RB_VIT_dir_res_ref', 'RB_VIT_dir_res_ref_mask25', 'RB_VIT_dir_res_ref_mask50', 'RB_VIT_dir_res_ref_mask75', 314 | 'RB_VIT_dir_res_ref_mask_random'] 315 | 316 | if __name__ == '__main__': 317 | opt = DefaultConfig() 318 | test_no_mask = True 319 | test_add = False 320 | from datasets.logodataset import MVTecDataset 321 | from torch.utils.data import DataLoader 322 | opt.model_name = model_name_list[1] 323 | for classname in MVTec_CLASS_NAMES: 324 | opt.class_name = classname 325 | # opt.class_name = 'capsule' 326 | opt.referenc_img_file = f'data/mvtec_anomaly_detection/{opt.class_name}/train/good/000.png' 327 | save_name = opt.model_name+'_'+opt.backbone_name 328 | opt.resume = fr'F:\LW\RB-VIT\result/{save_name}/weight/{opt.class_name}' 329 | # if test_no_mask == True: 330 | # opt.model_name = 'RB_VIT_dir_res_ref' 331 | print(opt.class_name, opt.model_name) 332 | # print(opt.referenc_img_file) 333 | # opt.resume = r'result/RB_VIT_dir_res_ref_VGG/weight/capsule' 334 | opt.train_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=True) 335 | opt.test_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=False) 336 | opt.trainloader = DataLoader(opt.train_dataset, batch_size=opt.batch_size, shuffle=True) 337 | opt.testloader = DataLoader(opt.test_dataset, batch_size=1, shuffle=False) 338 | model = Model(opt, test_no_mask, test_add) 339 | model.test() 340 | print(class_rocauc) 341 | value = list(class_rocauc.values()) 342 | img_roc = [i[0] for i in value] 343 | pixel_roc = [i[1] for i in value] 344 | img_ap = [i[2] for i in value] 345 | pixel_ap = [i[3] for i in value] 346 | mean_img_roc = np.mean(np.array(img_roc)) 347 | mean_pixel_roc = np.mean(np.array(pixel_roc)) 348 | mean_img_ap = np.mean(np.array(img_ap)) 349 | mean_pixel_ap = np.mean(np.array(pixel_ap)) 350 | 351 | print(round(mean_img_roc,3), round(mean_pixel_roc,3), round(mean_img_ap, 3), round(mean_pixel_ap,3)) 352 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from config import DefaultConfig 2 | import os 3 | from torch import optim 4 | from tqdm import tqdm 5 | from matplotlib import pyplot as plt 6 | import numpy as np 7 | from sklearn.metrics import roc_auc_score, f1_score, average_precision_score 8 | import cv2 9 | from models.misc import NativeScalerWithGradNormCount as NativeScaler 10 | from scipy.ndimage import gaussian_filter 11 | from datasets.dataset import denormalize 12 | from models.TFA_Net_model import * 13 | from ref_find import get_pos_sample 14 | class Model(object): 15 | def __init__(self, opt): 16 | super(Model, self).__init__() 17 | self.opt = opt 18 | self.model = eval(opt.model_name)(opt) 19 | self.device = opt.device 20 | 21 | self.class_name = opt.class_name 22 | self.trainloader = opt.trainloader 23 | self.testloader = opt.testloader 24 | self.loss_scaler = NativeScaler() 25 | # self.opt.model_name = 
self.opt.model_name 26 | if self.opt.resume != "": 27 | print('\nload pre-trained networks') 28 | self.opt.iter = torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}_k={opt.k}.pth'))['epoch'] 29 | print(self.opt.iter) 30 | self.model.load_state_dict(torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}_k={opt.k}.pth'))['state_dict'], strict=False) 31 | print('\ndone.\n') 32 | 33 | if self.opt.isTrain: 34 | self.model.Roncon_model.train() 35 | self.optimizer_g = optim.AdamW(self.model.Roncon_model.parameters(), lr=opt.lr, betas=(0.9, 0.95)) 36 | 37 | self.save_root = f"./result/{opt.model_name}_{opt.backbone_name}_k={opt.k}/" 38 | # os.makedirs(os.path.join(self.save_root, "weight"), exist_ok=True) 39 | self.ckpt_root = os.path.join(self.save_root, "weight/{}".format(self.class_name)) 40 | self.vis_root = os.path.join(self.save_root, "img/{}".format(self.class_name)) 41 | 42 | 43 | 44 | def get_max(self, tensor): 45 | a_1, _ = torch.max(tensor, dim=1, keepdim=True) 46 | a_2, _ = torch.max(a_1, dim=2, keepdim=True) 47 | a_3, _ = torch.max(a_2, dim=3, keepdim=True) 48 | return a_3 49 | def train(self): 50 | 51 | loss_now = 100000 52 | auc_now = 0. 53 | patience = 20 54 | no_update_num = 0 55 | for epoch in range(self.opt.iter, self.opt.niter): 56 | self.model.Feature_extractor.eval() 57 | self.model.Roncon_model.train(True) 58 | self.model.to(self.device) 59 | loss_total = 0. 60 | count = 0 61 | for index, (x, _, _, _) in enumerate(tqdm(self.trainloader, ncols=80)): 62 | bs = x.shape[0] 63 | x = x.to(self.device) 64 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, bs) 65 | 66 | deep_feature, _, recon_feature, loss = self.model(x, ref_x, 'train') 67 | self.loss_scaler(loss, self.optimizer_g, parameters=self.model.Roncon_model.parameters(), update_grad=(index + 1) % 1 == 0) 68 | loss_total += loss.item() 69 | count += 1 70 | 71 | loss_total = loss_total / count 72 | print('the {} epoch is done loss:{}'.format(epoch + 1, loss_total)) 73 | if (epoch + 1) % 10 == 0: 74 | # self.test_2() 75 | x1, x2, x3, x4 = self.test() 76 | auc_roc = x1+x2 77 | if auc_roc > auc_now: 78 | no_update_num = 0 79 | auc_now = auc_roc 80 | class_rocauc[self.opt.class_name] = (x1, x2, x3, x4) 81 | print('save model') 82 | weight_dir = self.ckpt_root 83 | os.makedirs(weight_dir, exist_ok=True) 84 | torch.save({'epoch': epoch + 1, 'state_dict': self.model.state_dict()}, 85 | f'%s/{self.opt.model_name}_{self.opt.backbone_name}.pth' % (weight_dir)) 86 | else: 87 | no_update_num += 1 88 | print('no_update_num:{}'.format(no_update_num)) 89 | if no_update_num > patience: 90 | break 91 | 92 | 93 | 94 | def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list): 95 | flatten_y_list = np.array(test_y_list).ravel() 96 | flatten_score_list = np.array(score_list).ravel() 97 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list) 98 | image_level_AP = average_precision_score(flatten_y_list, flatten_score_list) 99 | 100 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 101 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 102 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list) 103 | pixel_level_AP = average_precision_score(flatten_mask_list, flatten_score_map_list) 104 | # pro_auc_score = 0 105 | # pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) 106 | return round(image_level_ROCAUC, 3), round(pixel_level_ROCAUC, 3), round(image_level_AP, 3), 
round(pixel_level_AP, 3) 107 | # return image_level_ROCAUC, pixel_level_ROCAUC 108 | 109 | def F1_score(self, score_map_list, test_mask_list): 110 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 111 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 112 | F1_score = f1_score(flatten_mask_list, flatten_score_map_list) 113 | return F1_score 114 | 115 | def filter(self, pred_mask): 116 | pred_mask_my = np.squeeze(np.squeeze(pred_mask, 0), 0) 117 | pred_mask_my = cv2.medianBlur(np.uint8(pred_mask_my * 255), 7) 118 | mean = np.mean(pred_mask_my) 119 | std = np.std(pred_mask_my) 120 | _ , binary_pred_mask = cv2.threshold(pred_mask_my, mean+2.75*std, 255, type=cv2.THRESH_BINARY) 121 | binary_pred_mask = np.uint8(binary_pred_mask/255) 122 | pred_mask_my = np.expand_dims(np.expand_dims(pred_mask_my, 0), 0) 123 | binary_pred_mask = np.expand_dims(np.expand_dims(binary_pred_mask, 0), 0) 124 | return pred_mask_my, binary_pred_mask 125 | 126 | 127 | # def thresholding(self, pred_mask_my): 128 | # np_img 129 | 130 | # return 131 | def feature_map_vis(self, feature_map_list): 132 | feature_map_list = [torch.mean(i.clone(), dim=1).squeeze(0).cpu().detach().numpy() for i in feature_map_list] 133 | return feature_map_list 134 | 135 | def test(self): 136 | test_y_list = [] 137 | test_mask_list = [] 138 | score_list = [] 139 | score_map_list = [] 140 | 141 | for idx, (x, y, mask, name) in enumerate(tqdm(self.testloader, ncols=80)): 142 | test_y_list.extend(y.detach().cpu().numpy()) 143 | test_mask_list.extend(mask.detach().cpu().numpy()) 144 | self.model.eval() 145 | self.model.to(self.device) 146 | x = x.to(self.device) 147 | mask = mask.to(self.device) 148 | mask_cpu = mask.cpu().detach().numpy()[0, :, :, :].transpose((1, 2, 0)) 149 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, 1) 150 | deep_feature, ref_feature, recon_feature, _ = self.model(x, ref_x, 'test') 151 | feature_map_vis_list = self.feature_map_vis([deep_feature, ref_feature, recon_feature]) 152 | dis_amap, dir_amap = self.model.a_map(deep_feature, recon_feature) 153 | dis_amap = gaussian_filter(dis_amap, sigma=4) 154 | dir_amap = gaussian_filter(dir_amap, sigma=4) 155 | # print(type(name0])) 156 | name_list = name[0].split(r'!') 157 | # print(name_list) 158 | category, img_name = name_list[-2], name_list[-1] 159 | amap = dir_amap*dis_amap 160 | # amap = dir_amap + dis_amap 161 | self.vis_img([x,*feature_map_vis_list, dis_amap, dir_amap, amap, mask_cpu], os.path.join(self.vis_root, category), img_name) 162 | 163 | 164 | score_list.extend(np.array(np.std(amap)).reshape(1)) 165 | score_map_list.extend(amap.reshape((1, 1, 256, 256))) 166 | 167 | 168 | image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP= self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list) 169 | # F1_score = self.F1_score(F1_score_map_list, test_mask_list) 170 | print('image_auc_roc: {} '.format(image_level_ROCAUC), 171 | 'pixel_auc_roc: {} '.format(pixel_level_ROCAUC), 172 | 'image_AP: {}'.format(image_level_AP), 173 | 'pixel_AP: {}'.format(pixel_level_AP) 174 | ) 175 | return image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP 176 | 177 | def vis_img(self, img_list, save_root, idx_name): 178 | os.makedirs(save_root, exist_ok=True) 179 | input_frame = denormalize(img_list[0].clone().squeeze(0).cpu().detach().numpy()) 180 | cv2_input = np.array(input_frame, dtype=np.uint8) 181 | plt.figure() 182 | plt.subplot(241) 183 | plt.imshow(cv2_input) 184 | plt.axis('off') 185 | 
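        # Panels 242-248 below: mean-pooled deep / reference / reconstructed
        # feature maps, then the distance, direction, and fused anomaly maps
        # ('jet'), and finally the ground-truth mask.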
plt.subplot(242) 186 | plt.imshow(img_list[1]) 187 | plt.axis('off') 188 | plt.subplot(243) 189 | plt.imshow(img_list[2]) 190 | plt.axis('off') 191 | plt.subplot(244) 192 | plt.imshow(img_list[3]) 193 | plt.axis('off') 194 | plt.subplot(245) 195 | plt.imshow(img_list[4], cmap='jet') 196 | plt.axis('off') 197 | plt.subplot(246) 198 | plt.imshow(img_list[5],cmap='jet') 199 | plt.axis('off') 200 | plt.subplot(247) 201 | plt.imshow(img_list[6], cmap='jet') 202 | plt.axis('off') 203 | plt.subplot(248) 204 | plt.imshow(img_list[7]) 205 | plt.axis('off') 206 | plt.savefig(os.path.join(save_root, idx_name)) 207 | plt.close() 208 | 209 | 210 | def tensor_to_np_cpu(self, tensor): 211 | x_cpu = tensor.squeeze(0).data.cpu().numpy() 212 | x_cpu = np.transpose(x_cpu, (1, 2, 0)) 213 | return x_cpu 214 | 215 | def check(self, img): 216 | if len(img.shape) == 2: 217 | return img 218 | if img.shape[2] == 3: 219 | return img 220 | elif img.shape[2] == 1: 221 | return img.reshape(img.shape[0], img.shape[1]) 222 | 223 | MVTec_CLASS_NAMES = ['transistor', 'bottle', 'cable', 'capsule', 'carpet', 'grid', 224 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 225 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper'] 226 | # MVTec_CLASS_NAMES = ['hazelnut'] 227 | 228 | class_rocauc = { 229 | 'bottle':(0, 0, 0, 0), 230 | 'cable':(0, 0, 0, 0), 231 | 'capsule':(0, 0, 0, 0), 232 | 'carpet':(0, 0, 0, 0), 233 | 'grid':(0, 0, 0, 0), 234 | 'hazelnut':(0, 0, 0, 0), 235 | 'leather':(0, 0, 0, 0), 236 | 'metal_nut':(0, 0, 0, 0), 237 | 'pill':(0, 0, 0, 0), 238 | 'screw':(0, 0, 0, 0), 239 | 'tile':(0, 0, 0, 0), 240 | 'toothbrush':(0, 0, 0, 0), 241 | 'transistor':(0, 0, 0, 0), 242 | 'wood':(0, 0, 0, 0), 243 | 'zipper':(0, 0, 0, 0)} 244 | 245 | if __name__ == '__main__': 246 | opt = DefaultConfig() 247 | from datasets.dataset import MVTecDataset 248 | from torch.utils.data import DataLoader 249 | for classname in MVTec_CLASS_NAMES: 250 | opt.class_name = classname 251 | # opt.class_name = 'capsule' 252 | opt.referenc_img_file = f'data/mvtec_anomaly_detection/{opt.class_name}/train/good/000.png' 253 | # opt.referenc_img_file = f'data/ref/{opt.class_name}/ref.png' 254 | # opt.referenc_img_file = f'natrual.JPEG' 255 | print(opt.class_name, opt.model_name) 256 | # print(opt.referenc_img_file) 257 | # opt.resume = r'result/RB_VIT_dir_res_ref_VGG/weight/capsule' 258 | opt.train_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=True) 259 | opt.test_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=False) 260 | opt.trainloader = DataLoader(opt.train_dataset, batch_size=opt.batch_size, shuffle=True) 261 | opt.testloader = DataLoader(opt.test_dataset, batch_size=1, shuffle=False) 262 | model = Model(opt) 263 | model.train() 264 | print(class_rocauc) 265 | value = list(class_rocauc.values()) 266 | img_roc = [i[0] for i in value] 267 | pixel_roc = [i[1] for i in value] 268 | img_ap = [i[2] for i in value] 269 | pixel_ap = [i[3] for i in value] 270 | mean_img_roc = np.mean(np.array(img_roc)) 271 | mean_pixel_roc = np.mean(np.array(pixel_roc)) 272 | mean_img_ap = np.mean(np.array(img_ap)) 273 | mean_pixel_ap = np.mean(np.array(pixel_ap)) 274 | 275 | print(round(mean_img_roc, 3), round(mean_pixel_roc, 3), round(mean_img_ap, 3), round(mean_pixel_ap, 3)) 276 | -------------------------------------------------------------------------------- /train_mvlogo.py: -------------------------------------------------------------------------------- 1 | from config import 
DefaultConfig 2 | import os 3 | from torch import optim 4 | from tqdm import tqdm 5 | from matplotlib import pyplot as plt 6 | import numpy as np 7 | from sklearn.metrics import roc_auc_score, f1_score, average_precision_score 8 | import cv2 9 | from models.misc import NativeScalerWithGradNormCount as NativeScaler 10 | from scipy.ndimage import gaussian_filter 11 | from datasets.dataset import denormalize 12 | from models.TFA_Net_model import * 13 | class Model(object): 14 | def __init__(self, opt): 15 | super(Model, self).__init__() 16 | self.opt = opt 17 | self.model = eval(opt.model_name)(opt) 18 | self.device = opt.device 19 | 20 | self.class_name = opt.class_name 21 | self.trainloader = opt.trainloader 22 | self.testloader = opt.testloader 23 | self.loss_scaler = NativeScaler() 24 | # self.opt.model_name = self.opt.model_name 25 | if self.opt.resume != "": 26 | print('\nload pre-trained networks') 27 | self.opt.iter = torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}_k={opt.k}.pth'))['epoch'] 28 | print(self.opt.iter) 29 | self.model.load_state_dict(torch.load(os.path.join(self.opt.resume, f'{opt.model_name}_{opt.backbone_name}_k={opt.k}.pth'))['state_dict'], strict=False) 30 | print('\ndone.\n') 31 | 32 | if self.opt.isTrain: 33 | self.model.Roncon_model.train() 34 | self.optimizer_g = optim.AdamW(self.model.Roncon_model.parameters(), lr=opt.lr, betas=(0.9, 0.95)) 35 | 36 | self.save_root = f"./result_logo/{opt.model_name}_{opt.backbone_name}_k={opt.k}/" 37 | # os.makedirs(os.path.join(self.save_root, "weight"), exist_ok=True) 38 | self.ckpt_root = os.path.join(self.save_root, "weight/{}".format(self.class_name)) 39 | self.vis_root = os.path.join(self.save_root, "img/{}".format(self.class_name)) 40 | 41 | 42 | 43 | def get_max(self, tensor): 44 | a_1, _ = torch.max(tensor, dim=1, keepdim=True) 45 | a_2, _ = torch.max(a_1, dim=2, keepdim=True) 46 | a_3, _ = torch.max(a_2, dim=3, keepdim=True) 47 | return a_3 48 | def train(self): 49 | 50 | loss_now = 100000 51 | auc_now = 0. 52 | patience = 20 53 | no_update_num = 0 54 | for epoch in range(self.opt.iter, self.opt.niter): 55 | self.model.Feature_extractor.eval() 56 | self.model.Roncon_model.train(True) 57 | self.model.to(self.device) 58 | loss_total = 0. 
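        # Each training batch is paired with a fixed reference (template)
        # image via get_pos_sample; the NativeScaler wrapper handles the
        # backward pass, gradient scaling, and the optimizer step in one call.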
59 | count = 0 60 | for index, (x, _, _, _) in enumerate(tqdm(self.trainloader, ncols=80)): 61 | bs = x.shape[0] 62 | x = x.to(self.device) 63 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, bs) 64 | 65 | deep_feature, _, recon_feature, loss = self.model(x, ref_x, 'train') 66 | self.loss_scaler(loss, self.optimizer_g, parameters=self.model.Roncon_model.parameters(), update_grad=(index + 1) % 1 == 0) 67 | loss_total += loss.item() 68 | count += 1 69 | 70 | loss_total = loss_total / count 71 | print('the {} epoch is done loss:{}'.format(epoch + 1, loss_total)) 72 | if (epoch + 1) % 1 == 0: 73 | # self.test_2() 74 | x1, x2, x3, x4 = self.test() 75 | auc_roc = x1+x2 76 | if auc_roc > auc_now: 77 | no_update_num = 0 78 | auc_now = auc_roc 79 | class_rocauc[self.opt.class_name] = (x1, x2, x3, x4) 80 | print('save model') 81 | weight_dir = self.ckpt_root 82 | os.makedirs(weight_dir, exist_ok=True) 83 | torch.save({'epoch': epoch + 1, 'state_dict': self.model.state_dict()}, 84 | f'%s/{self.opt.model_name}_{self.opt.backbone_name}.pth' % (weight_dir)) 85 | else: 86 | no_update_num += 1 87 | print('no_update_num:{}'.format(no_update_num)) 88 | if no_update_num > patience: 89 | break 90 | 91 | 92 | 93 | def cal_auc(self, score_list, score_map_list, test_y_list, test_mask_list): 94 | flatten_y_list = np.array(test_y_list,dtype=np.uint8).ravel() 95 | flatten_score_list = np.array(score_list).ravel() 96 | image_level_ROCAUC = roc_auc_score(flatten_y_list, flatten_score_list) 97 | image_level_AP = average_precision_score(flatten_y_list, flatten_score_list) 98 | 99 | flatten_mask_list = np.concatenate(test_mask_list).ravel().astype(np.uint8) 100 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 101 | pixel_level_ROCAUC = roc_auc_score(flatten_mask_list, flatten_score_map_list) 102 | pixel_level_AP = average_precision_score(flatten_mask_list, flatten_score_map_list) 103 | # pro_auc_score = 0 104 | # pro_auc_score = cal_pro_metric_new(test_mask_list, score_map_list, fpr_thresh=0.3) 105 | return round(image_level_ROCAUC, 3), round(pixel_level_ROCAUC, 3), round(image_level_AP, 3), round(pixel_level_AP, 3) 106 | # return image_level_ROCAUC, pixel_level_ROCAUC 107 | 108 | def F1_score(self, score_map_list, test_mask_list): 109 | flatten_mask_list = np.concatenate(test_mask_list).ravel() 110 | flatten_score_map_list = np.concatenate(score_map_list).ravel() 111 | F1_score = f1_score(flatten_mask_list, flatten_score_map_list) 112 | return F1_score 113 | 114 | def filter(self, pred_mask): 115 | pred_mask_my = np.squeeze(np.squeeze(pred_mask, 0), 0) 116 | pred_mask_my = cv2.medianBlur(np.uint8(pred_mask_my * 255), 7) 117 | mean = np.mean(pred_mask_my) 118 | std = np.std(pred_mask_my) 119 | _ , binary_pred_mask = cv2.threshold(pred_mask_my, mean+2.75*std, 255, type=cv2.THRESH_BINARY) 120 | binary_pred_mask = np.uint8(binary_pred_mask/255) 121 | pred_mask_my = np.expand_dims(np.expand_dims(pred_mask_my, 0), 0) 122 | binary_pred_mask = np.expand_dims(np.expand_dims(binary_pred_mask, 0), 0) 123 | return pred_mask_my, binary_pred_mask 124 | 125 | 126 | # def thresholding(self, pred_mask_my): 127 | # np_img 128 | 129 | # return 130 | def feature_map_vis(self, feature_map_list): 131 | feature_map_list = [torch.mean(i.clone(), dim=1).squeeze(0).cpu().detach().numpy() for i in feature_map_list] 132 | return feature_map_list 133 | 134 | def test(self): 135 | test_y_list = [] 136 | test_mask_list = [] 137 | score_list = [] 138 | score_map_list = [] 139 | 140 | for idx, (x, y, mask, name) in 
enumerate(tqdm(self.testloader, ncols=80)): 141 | test_y_list.extend(y.detach().cpu().numpy()) 142 | test_mask_list.extend(mask.detach().cpu().numpy()) 143 | self.model.eval() 144 | self.model.to(self.device) 145 | x = x.to(self.device) 146 | mask = mask.to(self.device) 147 | mask_cpu = mask.cpu().detach().numpy()[0, :, :, :].transpose((1, 2, 0)) 148 | ref_x = get_pos_sample(self.opt.referenc_img_file, self.device, 1) 149 | deep_feature, ref_feature, recon_feature, _ = self.model(x, ref_x, 'test') 150 | feature_map_vis_list = self.feature_map_vis([deep_feature, ref_feature, recon_feature]) 151 | dis_amap, dir_amap = self.model.a_map(deep_feature, recon_feature) 152 | dis_amap = gaussian_filter(dis_amap, sigma=4) 153 | dir_amap = gaussian_filter(dir_amap, sigma=4) 154 | # print(type(name0])) 155 | name_list = name[0].split(r'!') 156 | # print(name_list) 157 | category, img_name = name_list[-2], name_list[-1] 158 | amap = dir_amap*dis_amap 159 | # amap = dir_amap + dis_amap 160 | self.vis_img([x,*feature_map_vis_list, dis_amap, dir_amap, amap, mask_cpu], os.path.join(self.vis_root, category), img_name) 161 | 162 | 163 | score_list.extend(np.array(np.std(amap)).reshape(1)) 164 | score_map_list.extend(amap.reshape((1, 1, 256, 256))) 165 | 166 | 167 | image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP= self.cal_auc(score_list, score_map_list, test_y_list, test_mask_list) 168 | # F1_score = self.F1_score(F1_score_map_list, test_mask_list) 169 | print('image_auc_roc: {} '.format(image_level_ROCAUC), 170 | 'pixel_auc_roc: {} '.format(pixel_level_ROCAUC), 171 | 'image_AP: {}'.format(image_level_AP), 172 | 'pixel_AP: {}'.format(pixel_level_AP) 173 | ) 174 | return image_level_ROCAUC, pixel_level_ROCAUC, image_level_AP, pixel_level_AP 175 | 176 | def vis_img(self, img_list, save_root, idx_name): 177 | os.makedirs(save_root, exist_ok=True) 178 | input_frame = denormalize(img_list[0].clone().squeeze(0).cpu().detach().numpy()) 179 | cv2_input = np.array(input_frame, dtype=np.uint8) 180 | plt.figure() 181 | plt.subplot(241) 182 | plt.imshow(cv2_input) 183 | plt.axis('off') 184 | plt.subplot(242) 185 | plt.imshow(img_list[1]) 186 | plt.axis('off') 187 | plt.subplot(243) 188 | plt.imshow(img_list[2]) 189 | plt.axis('off') 190 | plt.subplot(244) 191 | plt.imshow(img_list[3]) 192 | plt.axis('off') 193 | plt.subplot(245) 194 | plt.imshow(img_list[4], cmap='jet') 195 | plt.axis('off') 196 | plt.subplot(246) 197 | plt.imshow(img_list[5],cmap='jet') 198 | plt.axis('off') 199 | plt.subplot(247) 200 | plt.imshow(img_list[6], cmap='jet') 201 | plt.axis('off') 202 | plt.subplot(248) 203 | plt.imshow(img_list[7]) 204 | plt.axis('off') 205 | plt.savefig(os.path.join(save_root, idx_name)) 206 | plt.close() 207 | 208 | 209 | def tensor_to_np_cpu(self, tensor): 210 | x_cpu = tensor.squeeze(0).data.cpu().numpy() 211 | x_cpu = np.transpose(x_cpu, (1, 2, 0)) 212 | return x_cpu 213 | 214 | def check(self, img): 215 | if len(img.shape) == 2: 216 | return img 217 | if img.shape[2] == 3: 218 | return img 219 | elif img.shape[2] == 1: 220 | return img.reshape(img.shape[0], img.shape[1]) 221 | 222 | MVTec_CLASS_NAMES = [ 'breakfast_box', 'juice_bottle', 'pushpins', 'screw_bag', 'splicing_connectors'] 223 | 224 | class_rocauc = { 225 | 'breakfast_box':(0, 0, 0, 0), 226 | 'juice_bottle':(0, 0, 0, 0), 227 | 'pushpins':(0, 0, 0, 0), 228 | 'screw_bag':(0, 0, 0, 0), 229 | 'splicing_connectors':(0, 0, 0, 0) 230 | } 231 | 232 | if __name__ == '__main__': 233 | opt = DefaultConfig() 234 | from datasets.logodataset 
import MVTecDataset 235 | from torch.utils.data import DataLoader 236 | for classname in MVTec_CLASS_NAMES: 237 | opt.class_name = classname 238 | # opt.class_name = 'capsule' 239 | opt.referenc_img_file = f'data/mvtec_loco_anomaly_detection/{opt.class_name}/train/good/000.png' 240 | # opt.referenc_img_file = f'data/ref/{opt.class_name}/ref.png' 241 | # opt.referenc_img_file = f'natrual.JPEG' 242 | print(opt.class_name, opt.model_name) 243 | # print(opt.referenc_img_file) 244 | # opt.resume = r'result/RB_VIT_dir_res_ref_VGG/weight/capsule' 245 | opt.train_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=True) 246 | opt.test_dataset = MVTecDataset(dataset_path=opt.data_root, class_name=opt.class_name, is_train=False) 247 | opt.trainloader = DataLoader(opt.train_dataset, batch_size=opt.batch_size, shuffle=True) 248 | opt.testloader = DataLoader(opt.test_dataset, batch_size=1, shuffle=False) 249 | model = Model(opt) 250 | model.train() 251 | print(class_rocauc) 252 | value = list(class_rocauc.values()) 253 | img_roc = [i[0] for i in value] 254 | pixel_roc = [i[1] for i in value] 255 | img_ap = [i[2] for i in value] 256 | pixel_ap = [i[3] for i in value] 257 | mean_img_roc = np.mean(np.array(img_roc)) 258 | mean_pixel_roc = np.mean(np.array(pixel_roc)) 259 | mean_img_ap = np.mean(np.array(img_ap)) 260 | mean_pixel_ap = np.mean(np.array(pixel_ap)) 261 | 262 | print(round(mean_img_roc, 3), round(mean_pixel_roc, 3), round(mean_img_ap, 3), round(mean_pixel_ap, 3)) 263 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/utils/__pycache__/metric.cpython-38.pyc -------------------------------------------------------------------------------- /utils/gen_mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def gen_mask(k_list, n, im_size): 5 | while True: 6 | Ms = [] 7 | for k in k_list: 8 | N = im_size // k 9 | rdn = np.random.permutation(N**2) 10 | additive = N**2 % n 11 | if additive > 0: 12 | rdn = np.concatenate((rdn, np.asarray([-1] * (n - additive)))) 13 | n_index = rdn.reshape(n, -1) 14 | for index in n_index: 15 | tmp = [0. if i in index else 1. 
for i in range(N**2)] 16 | tmp = np.asarray(tmp).reshape(N, N) 17 | tmp = tmp.repeat(k, 0).repeat(k, 1) 18 | Ms.append(tmp) 19 | yield Ms 20 | 21 | if __name__ == '__main__': 22 | g = gen_mask([2, 4, 8, 16], 3, 256) 23 | # b = next(next(g)) 24 | b = next(g) 25 | print(b[0].shape) 26 | # for b in g: 27 | # # b[0] = b[0].astype(np.bool) 28 | # print(b[0].shape) 29 | # cv2.imshow('1', b[0]) 30 | # # b[1] = b[1].astype(np.bool) 31 | # cv2.imshow('2', b[1]) 32 | # # b[2] = b[2].astype(np.bool) 33 | # cv2.imshow('3', b[2]) 34 | # print(b[0]+b[1]+b[2]) 35 | # 36 | # # print(np.sum(~b[0]+~b[1]+~b[2])) 37 | # # a = np.array([1, 1]) 38 | # # print(~a) 39 | # cv2.waitKey(0) 40 | # print(np.random.permutation(4)) -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | """Classes for handling ground truth and anomaly maps.""" 2 | 3 | import glob 4 | import os 5 | from functools import lru_cache 6 | from typing import Sequence, Optional, Mapping, Iterable, Tuple, Union, Any 7 | 8 | import numpy as np 9 | import tifffile 10 | from PIL import Image 11 | 12 | 13 | def get_file_path_repr(file_path: Optional[str]) -> str: 14 | if file_path is None: 15 | return 'no file path' 16 | else: 17 | parent_dir_path, file_name = os.path.split(file_path) 18 | _, parent_dir = os.path.split(parent_dir_path) 19 | return f'.../{parent_dir}/{file_name}' 20 | 21 | 22 | class DefectConfig: 23 | def __init__(self, 24 | defect_name: str, 25 | pixel_value: int, 26 | saturation_threshold: Union[int, float], 27 | relative_saturation: bool): 28 | # Input validation. 29 | assert 1 <= pixel_value <= 255 30 | if relative_saturation: 31 | assert isinstance(saturation_threshold, float) 32 | assert 0. < saturation_threshold <= 1. 33 | else: 34 | assert isinstance(saturation_threshold, int) 35 | 36 | self.defect_name = defect_name 37 | self.pixel_value = pixel_value 38 | self.saturation_threshold = saturation_threshold 39 | self.relative_saturation = relative_saturation 40 | 41 | def __repr__(self): 42 | return f'DefectConfig({self.__dict__})' 43 | 44 | 45 | class DefectsConfig: 46 | def __init__(self, entries: Sequence[DefectConfig]): 47 | # Create a pixel_value -> entry mapping for faster lookup. 48 | self.pixel_value_to_entry = {e.pixel_value: e for e in entries} 49 | 50 | @property 51 | def entries(self): 52 | return tuple(self.pixel_value_to_entry.values()) 53 | 54 | @classmethod 55 | def create_from_list(cls, defects_list: Sequence[Mapping[str, Any]]): 56 | entries = [] 57 | for defect_config in defects_list: 58 | entry = DefectConfig( 59 | defect_name=defect_config['defect_name'], 60 | pixel_value=defect_config['pixel_value'], 61 | saturation_threshold=defect_config['saturation_threshold'], 62 | relative_saturation=defect_config['relative_saturation']) 63 | entries.append(entry) 64 | return DefectsConfig(entries=entries) 65 | 66 | 67 | class GroundTruthChannel: 68 | """A channel of a ground truth map. 69 | 70 | Corresponds to exactly one defect in a ground truth map. Must not be used 71 | to represent a defect-free image. 72 | """ 73 | 74 | def __init__(self, 75 | bool_array: np.ndarray, 76 | defect_config: DefectConfig): 77 | """ 78 | Args: 79 | bool_array: A 2-D numpy array with dtype np.bool_. A True value 80 | indicates an anomalous pixel. 81 | defect_config: The DefectConfig for this channel's defect type. 82 | """ 83 | 84 | # Input validation. 
85 | # numpy dtypes need to be checked with == instead of `is`, see 86 | # https://stackoverflow.com/a/26921882/2305095 87 | # We want np.bool_ for a fast computation of unions, intersections etc. 88 | assert len(bool_array.shape) == 2 and bool_array.dtype == np.bool_ 89 | 90 | self.bool_array = bool_array 91 | self.defect_config = defect_config 92 | 93 | def get_defect_area(self): 94 | return np.sum(self.bool_array) 95 | 96 | def get_saturation_area(self): 97 | defect_area = self.get_defect_area() 98 | if self.defect_config.relative_saturation: 99 | return int(self.defect_config.saturation_threshold * defect_area) 100 | else: 101 | return np.minimum(self.defect_config.saturation_threshold, 102 | defect_area) 103 | 104 | @classmethod 105 | def create_from_integer_array(cls, 106 | np_array: np.ndarray, 107 | defects_config: DefectsConfig): 108 | """Create a new GroundTruthChannel from an integer array. 109 | 110 | Args: 111 | np_array: A 2-D array with exactly one distinct positive value. All 112 | non-positive entries must be zero and correspond to defect-free 113 | pixels. 114 | defects_config: The defects configuration for the dataset object 115 | being evaluated. 116 | """ 117 | assert np.issubdtype(np_array.dtype, np.integer) 118 | 119 | # Ensure that each channel has exactly one unique positive integer. 120 | sorted_unique = sorted(np.unique(np_array)) 121 | if len(sorted_unique) == 1: 122 | defect_id = sorted_unique[0] 123 | else: 124 | zero, defect_id = sorted_unique 125 | assert zero == 0 126 | assert defect_id > 0 127 | # Cast np.uint8 etc. to int. 128 | defect_id = int(defect_id) 129 | 130 | # Convert to bool for faster logical operations with anomaly maps. 131 | bool_array = np_array.astype(np.bool_) 132 | 133 | # Look up the defect config for this defect id. 134 | defect_config = defects_config.pixel_value_to_entry[defect_id] 135 | return GroundTruthChannel(bool_array=bool_array, 136 | defect_config=defect_config) 137 | 138 | 139 | class GroundTruthMap: 140 | """A ground truth map for an anomalous image. 141 | 142 | Each channel corresponds to one defect in the image. 143 | 144 | Use GroundTruthMap.read_from_tiff(...) to read a GroundTruthMap from a 145 | .tiff file. 146 | 147 | If defect_id_to_name is None, it is constructed based on the defect ids in 148 | the channels, using defect_id -> str(defect_id). 149 | """ 150 | 151 | def __init__(self, 152 | channels: Sequence[GroundTruthChannel], 153 | file_path: Optional[str] = None): 154 | 155 | # Input validation. 156 | assert len(channels) > 0 157 | # Ensure that each channel has the same size. 158 | first_shape = channels[0].bool_array.shape 159 | assert set(c.bool_array.shape for c in channels) == {first_shape} 160 | # Check whether some channels have larger saturation thresholds than 161 | # defect areas. 162 | for i_channel, channel in enumerate(channels): 163 | if channel.defect_config.relative_saturation: 164 | continue 165 | threshold = channel.defect_config.saturation_threshold 166 | defect_area = channel.get_defect_area() 167 | if threshold > defect_area: 168 | print(f'WARNING: Channel {i_channel + 1} (1=first) of ground' 169 | f' truth image {get_file_path_repr(file_path)} has a' 170 | f' defect area of {defect_area}, but a saturation' 171 | f' threshold of {threshold}. 
Corresponding defect' 172 | f' config: {channel.defect_config}') 173 | 174 | self.channels = tuple(channels) 175 | self.file_path = file_path 176 | 177 | @property 178 | def size(self): 179 | return self.channels[0].bool_array.shape 180 | 181 | def get_or_over_channels(self) -> np.ndarray: 182 | """Combine the channels with a logical OR operation. 183 | 184 | Returns a numpy array of type np.bool_. 185 | """ 186 | channels_np = tuple(c.bool_array for c in self.channels) 187 | return np.sum(channels_np, axis=0).astype(bool) 188 | 189 | @classmethod 190 | def read_from_png_dir(cls, 191 | png_dir: str, 192 | defects_config: DefectsConfig): 193 | """Read a GroundTruthMap from a directory containing one .png per 194 | channel. 195 | """ 196 | gt_channels = [] 197 | for png_path in sorted(glob.glob(os.path.join(png_dir, '*.png'))): 198 | image = Image.open(png_path) 199 | np_array = np.array(image) 200 | gt_channel = GroundTruthChannel.create_from_integer_array( 201 | np_array=np_array, 202 | defects_config=defects_config) 203 | gt_channels.append(gt_channel) 204 | 205 | return cls(channels=gt_channels, file_path=png_dir) 206 | 207 | 208 | class AnomalyMap: 209 | """An anomaly map generated by a model. 210 | 211 | Use AnomalyMap.read_from_tiff(...) to read an AnomalyMap from a 212 | .tiff file. 213 | """ 214 | 215 | def __init__(self, 216 | np_array: np.ndarray, 217 | file_path: Optional[str] = None): 218 | """ 219 | Args: 220 | np_array: A 2-D numpy array containing the real-valued anomaly 221 | scores. 222 | file_path: (optional) file path of the image. Not used for I/O. 223 | """ 224 | 225 | assert len(np_array.shape) == 2 226 | 227 | self.np_array = np_array 228 | self.file_path = file_path 229 | 230 | def __repr__(self): 231 | return f'AnomalyMap({get_file_path_repr(self.file_path)})' 232 | 233 | @property 234 | def size(self): 235 | return self.np_array.shape 236 | 237 | def get_binary_image(self, anomaly_threshold: float): 238 | """Return the binary anomaly map based on a given threshold. 239 | 240 | The result is a 2-D numpy array with dtype np.bool_. 241 | """ 242 | return self.get_binary_images( 243 | anomaly_thresholds=[anomaly_threshold])[0] 244 | 245 | def get_binary_images(self, anomaly_thresholds: Iterable[float]): 246 | """Return binary anomaly maps based on given thresholds. 247 | 248 | The result is a 3-D numpy array with dtype np.bool_. The first 249 | dimension has the same length as the anomaly_thresholds. 
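        Thresholding is vectorized by broadcasting the 2-D map against a
        (num_thresholds, 1, 1) array of thresholds, and results are cached
        per threshold tuple (see _get_binary_images below).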
250 |         """
251 |         return self._get_binary_images(
252 |             anomaly_thresholds=tuple(anomaly_thresholds))
253 | 
254 |     @lru_cache(maxsize=3)
255 |     def _get_binary_images(self, anomaly_thresholds: Tuple[float, ...]):
256 |         thresholds = [[[t]] for t in anomaly_thresholds]
257 |         return np.greater(self.np_array[np.newaxis, :, :], thresholds)
258 | 
259 |     @classmethod
260 |     def read_from_tiff(cls, tiff_path: str):
261 |         """Read an AnomalyMap from a TIFF-file."""
262 |         np_array = tifffile.imread(tiff_path)
263 |         assert len(np_array.shape) == 2
264 |         return cls(np_array=np_array,
265 |                    file_path=tiff_path)
-------------------------------------------------------------------------------- /utils/metric.py: --------------------------------------------------------------------------------
1 | from skimage import measure
2 | from sklearn.metrics import auc
3 | import numpy as np
4 | 
5 | def rescale(x):
6 |     return (x - x.min()) / (x.max() - x.min())
7 | 
8 | 
9 | def cal_pro_metric_new(labeled_imgs, score_imgs, fpr_thresh=0.3, max_steps=2000, class_name=None):
10 |     labeled_imgs = np.array(labeled_imgs).squeeze(1)
11 |     labeled_imgs[labeled_imgs <= 0.45] = 0
12 |     labeled_imgs[labeled_imgs > 0.45] = 1
13 |     labeled_imgs = labeled_imgs.astype(bool)  # np.bool was removed in NumPy >= 1.24
14 |     score_imgs = np.array(score_imgs).squeeze(1)
15 | 
16 |     max_th = score_imgs.max()
17 |     min_th = score_imgs.min()
18 |     delta = (max_th - min_th) / max_steps
19 | 
20 |     ious_mean = []
21 |     ious_std = []
22 |     pros_mean = []
23 |     pros_std = []
24 |     threds = []
25 |     fprs = []
26 |     binary_score_maps = np.zeros_like(score_imgs, dtype=bool)
27 |     for step in range(max_steps):
28 |         thred = max_th - step * delta
29 |         # segmentation
30 |         binary_score_maps[score_imgs <= thred] = 0
31 |         binary_score_maps[score_imgs > thred] = 1
32 | 
33 |         pro = []  # per region overlap
34 |         iou = []  # per image iou
35 |         # pro: find each connected gt region, compute the overlapped pixels between the gt region and predicted region
36 |         # iou: for each image, compute the ratio, i.e. intersection/union between the gt and predicted binary map
37 |         for i in range(len(binary_score_maps)):  # for the i-th image
38 |             # pro (per region level)
39 |             label_map = measure.label(labeled_imgs[i], connectivity=2)
40 |             props = measure.regionprops(label_map)
41 |             for prop in props:
42 |                 x_min, y_min, x_max, y_max = prop.bbox
43 |                 cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
44 |                 # cropped_mask = masks[i][x_min:x_max, y_min:y_max]
45 |                 cropped_mask = prop.filled_image  # corrected!
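                # Per-region overlap (PRO): fraction of this connected
                # GT component covered by the thresholded prediction.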
46 | intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum() 47 | pro.append(intersection / prop.area) 48 | # iou (per image level) 49 | intersection = np.logical_and(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum() 50 | union = np.logical_or(binary_score_maps[i], labeled_imgs[i]).astype(np.float32).sum() 51 | if labeled_imgs[i].any() > 0: # when the gt have no anomaly pixels, skip it 52 | iou.append(intersection / union) 53 | # against steps and average metrics on the testing data 54 | ious_mean.append(np.array(iou).mean()) 55 | # print("per image mean iou:", np.array(iou).mean()) 56 | ious_std.append(np.array(iou).std()) 57 | pros_mean.append(np.array(pro).mean()) 58 | pros_std.append(np.array(pro).std()) 59 | # fpr for pro-auc 60 | masks_neg = ~labeled_imgs 61 | fpr = np.logical_and(masks_neg, binary_score_maps).sum() / masks_neg.sum() 62 | fprs.append(fpr) 63 | threds.append(thred) 64 | 65 | # as array 66 | threds = np.array(threds) 67 | pros_mean = np.array(pros_mean) 68 | pros_std = np.array(pros_std) 69 | fprs = np.array(fprs) 70 | 71 | 72 | # default 30% fpr vs pro, pro_auc 73 | idx = fprs <= fpr_thresh # find the indexs of fprs that is less than expect_fpr (default 0.3) 74 | fprs_selected = fprs[idx] 75 | fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1] 76 | pros_mean_selected = pros_mean[idx] 77 | pro_auc_score = auc(fprs_selected, pros_mean_selected) 78 | # print("pro auc ({}% FPR):".format(int(expect_fpr * 100)), pro_auc_score) 79 | return pro_auc_score -------------------------------------------------------------------------------- /utils/myutil.py: -------------------------------------------------------------------------------- 1 | """Collection of utility functions.""" 2 | import os 3 | import platform 4 | from bisect import bisect 5 | from typing import Iterable, Sequence, List, Callable 6 | 7 | import numpy as np 8 | 9 | 10 | def is_dict_order_stable(): 11 | """Returns true, if and only if dicts always iterate in the same order.""" 12 | if platform.python_implementation() == 'CPython': 13 | required_minor = 6 14 | else: 15 | required_minor = 7 16 | major, minor, _ = platform.python_version_tuple() 17 | assert major == '3' and all(s.isdigit() for s in minor) 18 | return int(minor) >= required_minor 19 | 20 | 21 | def listdir(path, sort=True, include_hidden=False): 22 | file_names = os.listdir(path) 23 | if sort: 24 | file_names = sorted(file_names) 25 | if not include_hidden: 26 | file_names = [f for f in file_names if not f.startswith('.')] 27 | return file_names 28 | 29 | 30 | def set_niceness(niceness): 31 | # Same as os.nice, but takes an absolute niceness instead of an increment. 32 | current_niceness = os.nice(0) 33 | niceness_increment = niceness - current_niceness 34 | # Regular users are not allowed to decrease the niceness. Doing so would 35 | # raise an exception, even if the resulting niceness would be positive. 
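    # Clamping the increment at zero means the call can only raise (never
    # lower) the niceness, which avoids such an exception for regular users.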
36 | niceness_increment = max(0, niceness_increment) 37 | return os.nice(niceness_increment) 38 | 39 | 40 | def take(seq: Sequence, indices: Iterable[int]) -> List: 41 | return [seq[i] for i in indices] 42 | 43 | 44 | def flatten_2d(seq: Sequence[Sequence]) -> List: 45 | return [elem for innerseq in seq for elem in innerseq] 46 | 47 | 48 | def get_sorted_nested_arrays(nested_arrays, sort_indices, nest_level=1): 49 | return map_nested(nested_objects=nested_arrays, 50 | fun=lambda a: a[sort_indices], 51 | nest_level=nest_level) 52 | 53 | 54 | def concat_nested_arrays(head_arrays: Sequence, 55 | tail_arrays: Sequence, 56 | nest_level=1): 57 | """Concatenate numpy arrays nested in a sequence (of sequences ... 58 | of sequences). 59 | 60 | Args: 61 | head_arrays: Sequence (of sequences ... of sequences) of numpy arrays. 62 | The lengths of the nested numpy arrays may differ. 63 | tail_arrays: Sequence (of sequences ... of sequences) of numpy arrays. 64 | Must have the same structure as head_arrays. 65 | The lengths of the nested numpy arrays may differ. 66 | nest_level: Number of sequence levels. 1 means there is a sequence of 67 | arrays. 2 means there is a sequence of sequences of arrays. 68 | Must be >= 1. 69 | 70 | Returns: 71 | A sequence (of sequences ... of sequences) of numpy arrays with the 72 | same structure as head_arrays and tail_arrays containing the 73 | concatenated arrays. 74 | """ 75 | 76 | # Zip the heads and tails at the deepest level. 77 | head_tail_arrays = zip_nested(head_arrays, tail_arrays, 78 | nest_level=nest_level) 79 | 80 | def concat(args): 81 | head, tail = args 82 | return np.concatenate([head, tail]) 83 | 84 | return map_nested(nested_objects=head_tail_arrays, 85 | fun=concat, 86 | nest_level=nest_level) 87 | 88 | 89 | def map_nested(nested_objects: Sequence, fun: Callable, nest_level=1): 90 | """Apply a function to objects nested in a sequence (of sequences ... 91 | of sequences). 92 | 93 | Args: 94 | nested_objects: Sequence (of sequences ... of sequences) of objects. 95 | fun: Function to call for each object. 96 | nest_level: Number of sequence levels. 1 means there is a sequence of 97 | objects. 2 means there is a sequence of sequences of objects. 98 | Must be >= 1. 99 | 100 | Returns: 101 | A list (of lists ... of lists) of mapped objects. This list has the 102 | same structure as nested_objects. Each item is the result of 103 | applying fun to the corresponding nested object. 104 | """ 105 | assert 1 <= nest_level 106 | if nest_level == 1: 107 | return [fun(o) for o in nested_objects] 108 | else: 109 | # Go one level deeper. 110 | mapped = [] 111 | for lower_nested_objects in nested_objects: 112 | # Map the nested sequence of objects. 113 | lower_mapped = map_nested( 114 | nested_objects=lower_nested_objects, 115 | fun=fun, 116 | nest_level=nest_level - 1 117 | ) 118 | mapped.append(lower_mapped) 119 | return mapped 120 | 121 | 122 | def zip_nested(*seqs: Sequence, nest_level=1): 123 | """Zip sequences (of sequences ... of sequences) of objects at the deepest 124 | level. 125 | 126 | Args: 127 | seqs: Sequences (of sequences ... of sequences) of objects. 128 | All sequences must have the same structure (length, length of 129 | descending sequences etc.). 130 | nest_level: Number of sequence levels. 1 means each sequence is a 131 | sequence of objects. 2 means each is a sequence of sequences of 132 | objects. Must be >= 1. 133 | 134 | Returns: 135 | A list (of lists ... of lists) of tuples containing the zipped objects. 
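        For example, zip_nested([[1, 2]], [[3, 4]], nest_level=2) returns
        [[(1, 3), (2, 4)]].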
136 |         This list has the same structure as each sequence in seqs.
137 |     """
138 |     assert 1 <= nest_level
139 |     # All sequences must have the same length.
140 |     seq_length = len(seqs[0])
141 |     assert set(len(seq) for seq in seqs) == {seq_length}
142 | 
143 |     if nest_level == 1:
144 |         return list(zip(*seqs))
145 |     else:
146 |         # Zip one level deeper.
147 |         zipped = []
148 |         for i in range(seq_length):
149 |             # Get the i-th sequence in every sequence of sequences.
150 |             nested_seqs = [seq[i] for seq in seqs]
151 |             # Zip the i-th sequences.
152 |             zipped_nested = zip_nested(*nested_seqs, nest_level=nest_level - 1)
153 |             zipped.append(zipped_nested)
154 |         return zipped
155 | 
156 | 
157 | def get_auc_for_max_fpr(fprs, y_values, max_fpr, scale_to_one=True):
158 |     """Compute AUCs for varying maximum FPRs."""
159 |     assert fprs[0] == 0 and fprs[-1] == 1
160 |     auc = trapz(x=fprs,
161 |                 y=y_values,
162 |                 x_max=max_fpr)
163 |     if scale_to_one:
164 |         auc /= max_fpr
165 |     return auc
166 | 
167 | 
168 | def trapz(x, y, x_max=None):
169 |     """
170 |     This function calculates the definite integral of a curve given by
171 |     x- and corresponding y-values. In contrast to, e.g., 'numpy.trapz()',
172 |     this function allows one to define an upper bound for the integration
173 |     range by setting a value x_max.
174 | 
175 |     Points that do not have a finite x or y value will be ignored with a
176 |     warning.
177 | 
178 |     Args:
179 |         x: Samples from the domain of the function to integrate.
180 |            Need to be sorted in ascending order. May contain the same value
181 |            multiple times. In that case, the order of the corresponding
182 |            y values will affect the integration with the trapezoidal rule.
183 |         y: Values of the function corresponding to x values.
184 |         x_max: Upper limit of the integration. The y value at x_max will be
185 |                determined by interpolating between its neighbors. Must not lie
186 |                outside the range of x.
187 | 
188 |     Returns:
189 |         Area under the curve.
190 |     """
191 | 
192 |     x = np.array(x)
193 |     y = np.array(y)
194 |     finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
195 |     if not finite_mask.all():
196 |         print("""WARNING: Not all x and y values passed to trapezoid(...)
197 |               are finite. Will continue with only the finite values.""")
198 |     x = x[finite_mask]
199 |     y = y[finite_mask]
200 | 
201 |     # Introduce a correction term if x_max is not an element of x.
202 |     correction = 0.
203 |     if x_max is not None:
204 |         if x_max not in x:
205 |             # Get the insertion index that would keep x sorted after
206 |             # np.insert(x, ins, x_max).
207 |             ins = bisect(x, x_max)
208 |             # x_max must be between the minimum and the maximum, so the
209 |             # insertion point cannot be zero or len(x).
210 |             assert 0 < ins < len(x)
211 | 
212 |             # Calculate the correction term which is the integral between
213 |             # the last x[ins-1] and x_max. Since we do not know the exact value
214 |             # of y at x_max, we interpolate between y[ins] and y[ins-1].
215 |             y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) *
216 |                                      (x_max - x[ins - 1]) /
217 |                                      (x[ins] - x[ins - 1]))
218 |             correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
219 | 
220 |         # Cut off at x_max.
221 |         mask = x <= x_max
222 |         x = x[mask]
223 |         y = y[mask]
224 | 
225 |     # Return area under the curve using the trapezoidal rule.
226 |     return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
227 | 
228 | 
229 | def compute_classification_roc(
230 |         anomaly_scores_ok,
231 |         anomaly_scores_nok):
232 |     """
233 |     Compute the ROC curve for anomaly classification on the image level.
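    The threshold is swept from the lowest to the highest observed score;
    curve points are added only where the score value changes, so tied
    scores are handled correctly.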
234 | 235 | Args: 236 | anomaly_scores_ok: List of real-valued anomaly scores of anomaly-free 237 | samples. 238 | anomaly_scores_nok: List of real-valued anomaly scores of anomalous 239 | samples. 240 | 241 | Returns: 242 | fprs: List of false positive rates. 243 | tprs: List of correspoding true positive rates. 244 | """ 245 | # Merge anomaly scores into a single list, keeping track of the GT label. 246 | # 0 = anomaly-free. 1 = anomalous. 247 | anomaly_scores = [] 248 | 249 | anomaly_scores.extend([(x, 0) for x in anomaly_scores_ok]) 250 | anomaly_scores.extend([(x, 1) for x in anomaly_scores_nok]) 251 | 252 | # Sort anomaly scores. 253 | anomaly_scores = sorted(anomaly_scores, key=lambda x: x[0]) 254 | 255 | # Fetch the number of ok and nok samples. 256 | num_scores = len(anomaly_scores) 257 | num_nok = len(anomaly_scores_nok) 258 | num_ok = len(anomaly_scores_ok) 259 | 260 | # Initially, every NOK sample is correctly classified as anomalous 261 | # (tpr = 1.0), and every OK sample is incorrectly classified as anomalous 262 | # (fpr = 1.0). 263 | fprs = [1.0] 264 | tprs = [1.0] 265 | 266 | # Keep track of the current number of false and true positive predictions. 267 | num_fp = num_ok 268 | num_tp = num_nok 269 | 270 | # Compute new true and false positive rates when successively increasing 271 | # the threshold. Add points to the curve only when anomaly scores change. 272 | prev_score = None 273 | for i, (score, label) in enumerate(anomaly_scores): 274 | if label == 0: 275 | num_fp -= 1 276 | else: 277 | num_tp -= 1 278 | 279 | if (prev_score is None) or (score != prev_score) or ( 280 | i == num_scores - 1): 281 | fprs.append(num_fp / num_ok) 282 | tprs.append(num_tp / num_nok) 283 | prev_score = score 284 | 285 | # Return (FPR, TPR) pairs in increasing order. 286 | fprs = fprs[::-1] 287 | tprs = tprs[::-1] 288 | 289 | return fprs, tprs 290 | 291 | 292 | def compute_classification_auc_roc( 293 | anomaly_scores_ok, 294 | anomaly_scores_nok): 295 | """ 296 | Compute the area under the ROC curve for anomaly classification. 297 | 298 | Args: 299 | anomaly_scores_ok: List of real-valued anomaly scores of anomaly-free 300 | samples. 301 | anomaly_scores_nok: List of real-valued anomaly scores of anomalous 302 | samples. 303 | 304 | Returns: 305 | auc_roc: Area under the ROC curve. 306 | """ 307 | # Compute the ROC curve. 308 | fprs, tprs = \ 309 | compute_classification_roc(anomaly_scores_ok, anomaly_scores_nok) 310 | 311 | # Integrate its area. 312 | return trapz(fprs, tprs) -------------------------------------------------------------------------------- /utils/sPro.py: -------------------------------------------------------------------------------- 1 | """Metrics computed on a single image, but many anomaly thresholds. 2 | 3 | At the bottom, there are two functions for computing sPRO values and false 4 | positive rates efficiently for many images and many anomaly thresholds: 5 | - get_spros_of_defects_of_images(...) 6 | - get_fp_tn_areas_per_image(...) 7 | """ 8 | 9 | from concurrent.futures import ProcessPoolExecutor 10 | from typing import Sequence, Optional, MutableMapping 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | from utils.image import AnomalyMap, GroundTruthMap, GroundTruthChannel 16 | from utils.myutil import set_niceness 17 | 18 | 19 | def get_spro(gt_channel: GroundTruthChannel, 20 | anomaly_map: AnomalyMap, 21 | anomaly_threshold: float) -> float: 22 | """Compute the saturated PRO metric for a single ground truth channel 23 | (i.e. 
defect) and a single threshold. 24 | 25 | Only use this function for testing and understanding. Do not use it 26 | repeatedly for different anomaly thresholds. Use get_spros(...) for that. 27 | """ 28 | binary_anomaly_map = anomaly_map.get_binary_image(anomaly_threshold) 29 | tp = np.logical_and(binary_anomaly_map, gt_channel.bool_array) 30 | tp_area = np.sum(tp) 31 | saturation_area = gt_channel.get_saturation_area() 32 | return np.minimum(tp_area / saturation_area, 1.) 33 | 34 | 35 | def get_spros_for_thresholds(gt_channel: GroundTruthChannel, 36 | anomaly_map: AnomalyMap, 37 | anomaly_thresholds: Sequence[float]): 38 | """Compute the saturated PRO metric for a single ground truth channel 39 | (i.e. defect) and multiple thresholds. 40 | 41 | Returns: 42 | A 1-D numpy array with the same length as anomaly_thresholds 43 | containing the sPRO values. 44 | """ 45 | tp_areas = get_tp_areas_for_thresholds( 46 | gt_channel=gt_channel, 47 | anomaly_map=anomaly_map, 48 | anomaly_thresholds=anomaly_thresholds) 49 | saturation_area = gt_channel.get_saturation_area() 50 | return np.minimum(tp_areas / saturation_area, 1.) 51 | 52 | 53 | def get_spros_per_defect_for_thresholds(gt_map: Optional[GroundTruthMap], 54 | anomaly_map: AnomalyMap, 55 | anomaly_thresholds: Sequence[float]): 56 | """Compute the saturated PRO metric for a single ground truth map 57 | (containing multiple defects / channels) and multiple thresholds. 58 | 59 | Returns: 60 | A tuple of 1-D numpy arrays. The length of the tuple is given by 61 | the number of channels in the ground truth map. Each numpy array 62 | has the same length as anomaly_thresholds and contains the sPRO 63 | values for the respective channel. If gt_map is None, the 64 | returned tuple is empty. 65 | """ 66 | if gt_map is None: 67 | return [] 68 | 69 | assert anomaly_map.np_array.shape == gt_map.size 70 | spros_per_defect = [] 71 | for channel in gt_map.channels: 72 | spros = get_spros_for_thresholds(gt_channel=channel, 73 | anomaly_map=anomaly_map, 74 | anomaly_thresholds=anomaly_thresholds) 75 | spros_per_defect.append(spros) 76 | return tuple(spros_per_defect) 77 | 78 | 79 | def get_tp_areas_for_thresholds(gt_channel: GroundTruthChannel, 80 | anomaly_map: AnomalyMap, 81 | anomaly_thresholds: Sequence[float]): 82 | """Compute the true positive areas for a single ground truth channel 83 | (i.e. defect) and multiple thresholds. 84 | 85 | Returns: 86 | A 1-D numpy array with the same length as anomaly_thresholds 87 | containing the true positive areas. 88 | """ 89 | binary_anomaly_maps = anomaly_map.get_binary_images(anomaly_thresholds) 90 | tps = np.logical_and(binary_anomaly_maps, gt_channel.bool_array) 91 | tp_areas = np.sum(tps, axis=(1, 2)) 92 | return tp_areas 93 | 94 | 95 | def get_fp_areas_for_thresholds(gt_map: Optional[GroundTruthMap], 96 | anomaly_map: AnomalyMap, 97 | anomaly_thresholds: Sequence[float]): 98 | """Compute the false positive areas for a single ground truth map 99 | (containing multiple defects / channels) and multiple thresholds. 100 | 101 | Set gt_map to None for "good" images without ground truth annotations. 102 | 103 | Needs the whole GT maps to make sure that we do not mark a pixel a false 104 | positive that would be a true positive in another channel. 105 | 106 | A false positive pixel is a pixel that is defect-free in all channels of 107 | the ground truth map, but is a positive in the anomaly map. 108 | 109 | Returns: 110 | A 1-D numpy array with the same length as anomaly_thresholds 111 | containing the false positive areas. 
112 | """ 113 | 114 | binary_anomaly_maps = anomaly_map.get_binary_images(anomaly_thresholds) 115 | 116 | fp_areas: np.ndarray 117 | if gt_map is None: 118 | # This is a good image. All positive pixels are false positives. 119 | fp_areas = np.sum(binary_anomaly_maps, axis=(1, 2)) 120 | else: 121 | # False positives do not depend on a single channel, like true 122 | # positives. Only pixels that are defect-free in all ground truth 123 | # channels can be a false positive. 124 | gt_combined = gt_map.get_or_over_channels() 125 | fps = np.logical_and(binary_anomaly_maps, 126 | np.logical_not(gt_combined)) 127 | fp_areas = np.sum(fps, axis=(1, 2)) 128 | return fp_areas 129 | 130 | 131 | def get_tn_areas_for_thresholds(gt_map: Optional[GroundTruthMap], 132 | anomaly_map: AnomalyMap, 133 | anomaly_thresholds: Sequence[float], 134 | fp_areas: Optional[np.ndarray] = None): 135 | """Compute the true negative areas for a single ground truth map 136 | (containing multiple defects / channels) and multiple thresholds. 137 | 138 | Set gt_map to None for "good" images without ground truth annotations. 139 | 140 | A true negative pixel is a pixel that is defect-free in all channels of 141 | the ground truth map and is a negative in the anomaly map. 142 | 143 | The true negative area plus the false positive area equals the number of 144 | pixels that are defect-free in *all* channels of the ground truth map, 145 | see get_fp_areas_for_thresholds(...). 146 | 147 | The computation can be sped up significantly by setting fp_areas to the 148 | result of get_fp_areas_for_thresholds(...)! 149 | 150 | Returns: 151 | A 1-D numpy array with the same length as anomaly_thresholds 152 | containing the true negative areas. 153 | """ 154 | 155 | binary_anomaly_maps = anomaly_map.get_binary_images(anomaly_thresholds) 156 | 157 | tn_areas: np.ndarray 158 | if gt_map is None: 159 | # This is a good image. All negative pixels are true negatives. 160 | tn_areas = np.sum(np.logical_not(binary_anomaly_maps), axis=(1, 2)) 161 | else: 162 | # True negatives are pixels that have a negative prediction and are 163 | # marked as being defect-free in all channels of a ground truth map. 164 | gt_combined = gt_map.get_or_over_channels() 165 | 166 | if fp_areas is not None: 167 | # Compute the true negatives based on the false positives and the 168 | # total defect-free area. 169 | no_defect_area = np.sum(np.logical_not(gt_combined)) 170 | tn_areas = no_defect_area - fp_areas 171 | else: 172 | # NOT prediction=True AND NOT defect=True can be replaced with 173 | # NOT (prediction=True OR defect=True) 174 | tns = np.logical_not(np.logical_or(binary_anomaly_maps, 175 | gt_combined)) 176 | tn_areas = np.sum(tns, axis=(1, 2)) 177 | return tn_areas 178 | 179 | 180 | def _get_spros_per_defect_for_thresholds_kwargs(kwargs: MutableMapping): 181 | if 'niceness' in kwargs: 182 | set_niceness(kwargs['niceness']) 183 | del kwargs['niceness'] 184 | return get_spros_per_defect_for_thresholds(**kwargs) 185 | 186 | 187 | def get_spros_of_defects_of_images( 188 | gt_maps: Sequence[Optional[GroundTruthMap]], 189 | anomaly_maps: Sequence[AnomalyMap], 190 | anomaly_thresholds: Sequence[float], 191 | parallel_workers: Optional[int] = None, 192 | parallel_niceness: int = 19): 193 | """Compute the saturated PRO values for several images and anomaly 194 | thresholds, possibly in parallel. 195 | 196 | Args: 197 | gt_maps: Sequence of GroundTruthMap or None entries with the same 198 | length and ordering as anomaly_maps. 
Use None for "good" images
199 |             without ground truth annotations.
200 |         anomaly_maps: Must have the same length and ordering as gt_maps.
201 |         anomaly_thresholds: Thresholds for obtaining binary anomaly maps.
202 |         parallel_workers: If None (default), nothing will be parallelized
203 |             across CPUs. Otherwise, the value denotes the number of CPUs to use
204 |             for parallelism. A value of 1 will result in suboptimal performance
205 |             compared to None.
206 |         parallel_niceness: Niceness of child processes. Only applied in the
207 |             parallelized setting.
208 | 
209 |     Returns:
210 |         A list of tuples of numpy arrays. The outer list will have the same
211 |         length as gt_maps and anomaly_maps. The length of each inner
212 |         tuple is given by the number of defects per image. The length of
213 |         each numpy array is given by the number of anomaly thresholds.
214 |         "good" images will have an empty inner tuple.
215 |     """
216 |     assert len(gt_maps) == len(anomaly_maps)
217 | 
218 |     # Construct the kwargs for each call to get_spros_per_defect_for_thresholds
219 |     # via _get_spros_per_defect_for_thresholds_kwargs.
220 |     kwargs_list = []
221 |     for gt_map, anomaly_map in zip(gt_maps, anomaly_maps):
222 |         kwargs = {
223 |             'gt_map': gt_map,
224 |             'anomaly_map': anomaly_map,
225 |             'anomaly_thresholds': anomaly_thresholds
226 |         }
227 |         if parallel_workers is not None:
228 |             kwargs['niceness'] = parallel_niceness
229 |         kwargs_list.append(kwargs)
230 | 
231 |     if parallel_workers is None:
232 |         print(f'Computing mean sPROs for {len(anomaly_thresholds)} anomaly'
233 |               f' thresholds...')
234 |         spros_of_defects_of_images = [
235 |             _get_spros_per_defect_for_thresholds_kwargs(kwargs)
236 |             for kwargs in tqdm(kwargs_list)]
237 |     else:
238 |         print(f'Computing mean sPROs for {len(anomaly_thresholds)} anomaly'
239 |               f' thresholds in parallel on {parallel_workers} CPUs...')
240 |         pool = ProcessPoolExecutor(max_workers=parallel_workers)
241 |         spros_of_defects_of_images = pool.map(
242 |             _get_spros_per_defect_for_thresholds_kwargs,
243 |             kwargs_list)
244 |         spros_of_defects_of_images = list(spros_of_defects_of_images)
245 |     return spros_of_defects_of_images
246 | 
247 | 
248 | def _get_fp_areas_for_thresholds_kwargs(kwargs: MutableMapping):
249 |     if 'niceness' in kwargs:
250 |         set_niceness(kwargs['niceness'])
251 |         del kwargs['niceness']
252 |     return get_fp_areas_for_thresholds(**kwargs)
253 | 
254 | 
255 | def get_fp_tn_areas_per_image(
256 |         gt_maps: Sequence[Optional[GroundTruthMap]],
257 |         anomaly_maps: Sequence[AnomalyMap],
258 |         anomaly_thresholds: Sequence[float],
259 |         parallel_workers: Optional[int] = None,
260 |         parallel_niceness: int = 19):
261 |     """Compute the false positive and the true negative areas for several
262 |     images and anomaly thresholds, possibly in parallel.
263 | 
264 |     Args:
265 |         gt_maps: Sequence of GroundTruthMap or None entries with the same
266 |             length and ordering as anomaly_maps. Use None for "good" images
267 |             without ground truth annotations.
268 |         anomaly_maps: Must have the same length and ordering as gt_maps.
269 |         anomaly_thresholds: Thresholds for obtaining binary anomaly maps.
270 |         parallel_workers: If None (default), nothing will be parallelized
271 |             across CPUs. Otherwise, the value denotes the number of CPUs to use
272 |             for parallelism. A value of 1 will result in suboptimal performance
273 |             compared to None.
274 |         parallel_niceness: Niceness of child processes. Only applied in the
275 |             parallelized setting.
276 | 
277 |     Returns:
278 |         A list of 1-D numpy arrays.
The list has the same length as gt_maps 279 | and anomaly_maps. It contains the false positive areas for each 280 | image. Each numpy array has the same length as anomaly_thresholds. 281 | A list of 1-D numpy arrays. The list has the same length as gt_maps 282 | and anomaly_maps. It contains the true negative areas for each 283 | image. Each numpy array has the same length as anomaly_thresholds. 284 | """ 285 | assert len(gt_maps) == len(anomaly_maps) 286 | 287 | # Construct the kwargs for each call to get_fp_areas_for_thresholds via 288 | # _get_fp_areas_for_thresholds_kwargs. 289 | kwargs_list = [] 290 | for gt_map, anomaly_map in zip(gt_maps, anomaly_maps): 291 | kwargs = { 292 | 'gt_map': gt_map, 293 | 'anomaly_map': anomaly_map, 294 | 'anomaly_thresholds': anomaly_thresholds 295 | } 296 | if parallel_workers is not None: 297 | kwargs['niceness'] = parallel_niceness 298 | kwargs_list.append(kwargs) 299 | 300 | # For each anomaly threshold, compute the FP areas per image. 301 | if parallel_workers is None: 302 | print(f'Computing FPRs for {len(anomaly_thresholds)} anomaly' 303 | f' thresholds...') 304 | fp_areas_per_image = [ 305 | _get_fp_areas_for_thresholds_kwargs(kwargs) 306 | for kwargs in tqdm(kwargs_list)] 307 | else: 308 | print(f'Computing FPRs for {len(anomaly_thresholds)} anomaly' 309 | f' thresholds in parallel on {parallel_workers} CPUs...') 310 | pool = ProcessPoolExecutor(max_workers=parallel_workers) 311 | fp_areas_per_image = pool.map( 312 | _get_fp_areas_for_thresholds_kwargs, 313 | kwargs_list) 314 | fp_areas_per_image = list(fp_areas_per_image) 315 | 316 | # For each anomaly threshold, compute the TN areas per image. 317 | tn_areas_per_image = [] 318 | for gt_map, anomaly_map, fp_areas in zip( 319 | gt_maps, anomaly_maps, fp_areas_per_image): 320 | # For each image, there is only one FP area and one TN area per 321 | # anomaly threshold. 322 | tn_areas = get_tn_areas_for_thresholds( 323 | gt_map=gt_map, 324 | anomaly_map=anomaly_map, 325 | anomaly_thresholds=anomaly_thresholds, 326 | fp_areas=fp_areas) 327 | tn_areas_per_image.append(tn_areas) 328 | return fp_areas_per_image, tn_areas_per_image 329 | 330 | 331 | def get_fp_rates(fp_areas_per_image: Sequence[np.ndarray], 332 | tn_areas_per_image: Sequence[np.ndarray]): 333 | """Compute false positive rates based on the results of 334 | get_fp_tn_areas_per_image(...). 335 | 336 | Args: 337 | fp_areas_per_image: See get_fp_tn_areas_per_image(...). 338 | tn_areas_per_image: See get_fp_tn_areas_per_image(...). 339 | 340 | Returns: 341 | A 1-D numpy array with the same length as each array in 342 | fp_areas_per_image and tn_areas_per_image. For each 343 | anomaly threshold, it contains the FPR computed over all images. 344 | 345 | Raises: 346 | ZeroDivisionError if there is no defect-free pixel in any of the 347 | images. This would result in a zero division when computing the 348 | FPR for any anomaly threshold. 349 | """ 350 | total_fp_areas = np.zeros_like(fp_areas_per_image[0], dtype=np.int64) 351 | total_tn_areas = np.zeros_like(fp_areas_per_image[0], dtype=np.int64) 352 | for fp_areas, tn_areas in zip(fp_areas_per_image, tn_areas_per_image): 353 | assert len(fp_areas) == len(tn_areas) 354 | total_fp_areas += fp_areas 355 | total_tn_areas += tn_areas 356 | 357 | # If there is no defect-free pixel in any of the images, there cannot be 358 | # false positives or true negatives. Then, TN+FP will be zero, regardless 359 | # of the anomaly threshold. 
Otherwise, TN+FP will be positive, regardless 360 | # of the anomaly threshold. 361 | # Therefore, we prevent division by zero by checking if the sum of TN+FP 362 | # is zero for any of the thresholds. 363 | if total_tn_areas[0] + total_fp_areas[0] == 0: 364 | assert np.sum(total_fp_areas + total_tn_areas) == 0 365 | raise ZeroDivisionError 366 | fp_rates = total_fp_areas / (total_tn_areas + total_fp_areas) 367 | return fp_rates -------------------------------------------------------------------------------- /utils/test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luow23/TFA-Net/926f6f5eebe43d13c95148aa819eae6503072365/utils/test.py -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch.nn.functional as F 8 | 9 | 10 | def pil_loader(path): 11 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 12 | with open(path, 'rb') as f: 13 | img = Image.open(f) 14 | return img.convert('RGB') 15 | 16 | 17 | def default_loader(path): 18 | return pil_loader(path) 19 | 20 | 21 | def tensor_img_to_npimg(tensor_img): 22 | """ 23 | Turn a tensor image with shape CxHxW to a numpy array image with shape HxWxC 24 | :param tensor_img: 25 | :return: a numpy array image with shape HxWxC 26 | """ 27 | if not (torch.is_tensor(tensor_img) and tensor_img.ndimension() == 3): 28 | raise NotImplementedError("Not supported tensor image. Only tensors with dimension CxHxW are supported.") 29 | npimg = np.transpose(tensor_img.numpy(), (1, 2, 0)) 30 | npimg = npimg.squeeze() 31 | assert isinstance(npimg, np.ndarray) and (npimg.ndim in {2, 3}) 32 | return npimg 33 | 34 | 35 | # Change the values of tensor x from range [0, 1] to [-1, 1] 36 | def normalize(x): 37 | return x.mul_(2).add_(-1) 38 | 39 | def same_padding(images, ksizes, strides, rates): 40 | assert len(images.size()) == 4 41 | batch_size, channel, rows, cols = images.size() 42 | out_rows = (rows + strides[0] - 1) // strides[0] 43 | out_cols = (cols + strides[1] - 1) // strides[1] 44 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 45 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 46 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 47 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 48 | # Pad the input 49 | padding_top = int(padding_rows / 2.) 50 | padding_left = int(padding_cols / 2.) 51 | padding_bottom = padding_rows - padding_top 52 | padding_right = padding_cols - padding_left 53 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 54 | images = torch.nn.ZeroPad2d(paddings)(images) 55 | return images 56 | 57 | 58 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 59 | """ 60 | Extract patches from images and put them in the C output dimension. 61 | :param padding: 62 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 63 | :param ksizes: [ksize_rows, ksize_cols]. 
The size of the sliding window for
 64 |                    each dimension of images
 65 |     :param strides: [stride_rows, stride_cols]
 66 |     :param rates: [dilation_rows, dilation_cols]
 67 |     :return: A Tensor
 68 |     """
 69 |     assert len(images.size()) == 4
 70 |     assert padding in ['same', 'valid']
 71 |     batch_size, channel, height, width = images.size()
 72 | 
 73 |     if padding == 'same':
 74 |         images = same_padding(images, ksizes, strides, rates)
 75 |     elif padding == 'valid':
 76 |         pass
 77 |     else:
 78 |         raise NotImplementedError('Unsupported padding type: {}.\
 79 |                 Only "same" or "valid" are supported.'.format(padding))
 80 | 
 81 |     unfold = torch.nn.Unfold(kernel_size=ksizes,
 82 |                              dilation=rates,
 83 |                              padding=0,
 84 |                              stride=strides)
 85 |     patches = unfold(images)
 86 |     return patches  # [N, C*k*k, L], L is the total number of such blocks
 87 | 
 88 | 
 89 | def random_bbox(config, batch_size):
 90 |     """Generate a random tlhw (top, left, height, width) with configuration.
 91 | 
 92 |     Args:
 93 |         config: Dict that must contain image_shape, mask_shape, margin and mask_batch_same.
 94 | 
 95 |     Returns:
 96 |         tuple: (top, left, height, width)
 97 | 
 98 |     """
 99 |     img_height, img_width, _ = config['image_shape']
100 |     h, w = config['mask_shape']
101 |     margin_height, margin_width = config['margin']
102 |     maxt = img_height - margin_height - h
103 |     maxl = img_width - margin_width - w
104 |     bbox_list = []
105 |     if config['mask_batch_same']:
106 |         t = np.random.randint(margin_height, maxt)
107 |         l = np.random.randint(margin_width, maxl)
108 |         bbox_list.append((t, l, h, w))
109 |         bbox_list = bbox_list * batch_size
110 |     else:
111 |         for i in range(batch_size):
112 |             t = np.random.randint(margin_height, maxt)
113 |             l = np.random.randint(margin_width, maxl)
114 |             bbox_list.append((t, l, h, w))
115 | 
116 |     return torch.tensor(bbox_list, dtype=torch.int64)
117 | 
118 | 
119 | def test_random_bbox():
120 |     image_shape = [256, 256, 3]
121 |     mask_shape, margin = [128, 128], [0, 0]
122 |     config = {'image_shape': image_shape, 'mask_shape': mask_shape, 'margin': margin, 'mask_batch_same': True}
123 |     bbox = random_bbox(config, batch_size=1)  # fixed: random_bbox takes (config, batch_size), not a bare shape
124 |     return bbox
125 | 
126 | 
127 | def bbox2mask(bboxes, height, width, max_delta_h, max_delta_w):
128 |     batch_size = bboxes.size(0)
129 |     mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32)
130 |     for i in range(batch_size):
131 |         bbox = bboxes[i]
132 |         delta_h = np.random.randint(max_delta_h // 2 + 1)
133 |         delta_w = np.random.randint(max_delta_w // 2 + 1)
134 |         mask[i, :, bbox[0] + delta_h:bbox[0] + bbox[2] - delta_h, bbox[1] + delta_w:bbox[1] + bbox[3] - delta_w] = 1.
135 |     return mask
136 | 
137 | 
138 | def test_bbox2mask():
139 |     image_shape = [256, 256, 3]
140 |     mask_shape, margin = [128, 128], [0, 0]
141 |     max_delta_shape = [32, 32]
142 |     config = {'image_shape': image_shape, 'mask_shape': mask_shape, 'margin': margin, 'mask_batch_same': True}
143 |     bbox = random_bbox(config, batch_size=1)  # fixed: random_bbox takes (config, batch_size)
144 |     mask = bbox2mask(bbox, image_shape[0], image_shape[1], max_delta_shape[0], max_delta_shape[1])
145 |     return mask
146 | 
147 | 
148 | def local_patch(x, bbox_list):
149 |     assert len(x.size()) == 4
150 |     patches = []
151 |     for i, bbox in enumerate(bbox_list):
152 |         t, l, h, w = bbox
153 |         patches.append(x[i, :, t:t + h, l:l + w])
154 |     return torch.stack(patches, dim=0)
155 | 
156 | 
157 | def mask_image(x, bboxes, config):
158 |     height, width, _ = config['image_shape']
159 |     max_delta_h, max_delta_w = config['max_delta_shape']
160 |     mask = bbox2mask(bboxes, height, width, max_delta_h, max_delta_w)
161 |     if x.is_cuda:
162 |         mask = mask.cuda()
163 | 
164 |     if config['mask_type'] == 'hole':
165 |         result = x * (1.
- mask) 166 | elif config['mask_type'] == 'mosaic': 167 | # TODO: Matching the mosaic patch size and the mask size 168 | mosaic_unit_size = config['mosaic_unit_size'] 169 | downsampled_image = F.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest') 170 | upsampled_image = F.interpolate(downsampled_image, size=(height, width), mode='nearest') 171 | result = upsampled_image * mask + x * (1. - mask) 172 | else: 173 | raise NotImplementedError('Not implemented mask type.') 174 | 175 | return result, mask 176 | 177 | 178 | def spatial_discounting_mask(config): 179 | """Generate spatial discounting mask constant. 180 | 181 | Spatial discounting mask is first introduced in publication: 182 | Generative Image Inpainting with Contextual Attention, Yu et al. 183 | 184 | Args: 185 | config: Config should have configuration including HEIGHT, WIDTH, 186 | DISCOUNTED_MASK. 187 | 188 | Returns: 189 | tf.Tensor: spatial discounting mask 190 | 191 | """ 192 | gamma = config['spatial_discounting_gamma'] 193 | height, width = config['mask_shape'] 194 | shape = [1, 1, height, width] 195 | if config['discounted_mask']: 196 | mask_values = np.ones((height, width)) 197 | for i in range(height): 198 | for j in range(width): 199 | mask_values[i, j] = max( 200 | gamma ** min(i, height - i), 201 | gamma ** min(j, width - j)) 202 | mask_values = np.expand_dims(mask_values, 0) 203 | mask_values = np.expand_dims(mask_values, 0) 204 | else: 205 | mask_values = np.ones(shape) 206 | spatial_discounting_mask_tensor = torch.tensor(mask_values, dtype=torch.float32) 207 | if config['cuda']: 208 | spatial_discounting_mask_tensor = spatial_discounting_mask_tensor.cuda() 209 | return spatial_discounting_mask_tensor 210 | 211 | 212 | def reduce_mean(x, axis=None, keepdim=False): 213 | if not axis: 214 | axis = range(len(x.shape)) 215 | for i in sorted(axis, reverse=True): 216 | x = torch.mean(x, dim=i, keepdim=keepdim) 217 | return x 218 | 219 | 220 | def reduce_std(x, axis=None, keepdim=False): 221 | if not axis: 222 | axis = range(len(x.shape)) 223 | for i in sorted(axis, reverse=True): 224 | x = torch.std(x, dim=i, keepdim=keepdim) 225 | return x 226 | 227 | 228 | def reduce_sum(x, axis=None, keepdim=False): 229 | if not axis: 230 | axis = range(len(x.shape)) 231 | for i in sorted(axis, reverse=True): 232 | x = torch.sum(x, dim=i, keepdim=keepdim) 233 | return x 234 | 235 | 236 | def flow_to_image(flow): 237 | """Transfer flow map to image. 238 | Part of code forked from flownet. 239 | """ 240 | out = [] 241 | maxu = -999. 242 | maxv = -999. 243 | minu = 999. 244 | minv = 999. 245 | maxrad = -1 246 | for i in range(flow.shape[0]): 247 | u = flow[i, :, :, 0] 248 | v = flow[i, :, :, 1] 249 | idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7) 250 | u[idxunknow] = 0 251 | v[idxunknow] = 0 252 | maxu = max(maxu, np.max(u)) 253 | minu = min(minu, np.min(u)) 254 | maxv = max(maxv, np.max(v)) 255 | minv = min(minv, np.min(v)) 256 | rad = np.sqrt(u ** 2 + v ** 2) 257 | maxrad = max(maxrad, np.max(rad)) 258 | u = u / (maxrad + np.finfo(float).eps) 259 | v = v / (maxrad + np.finfo(float).eps) 260 | img = compute_color(u, v) 261 | out.append(img) 262 | return np.float32(np.uint8(out)) 263 | 264 | 265 | def pt_flow_to_image(flow): 266 | """Transfer flow map to image. 267 | Part of code forked from flownet. 
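 
    Note (added; an inference from the indexing below, not an original
    guarantee): ``flow`` is expected to be shaped [N, 2, H, W], and the
    result stacks one color image per batch entry via pt_compute_color(u, v).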
268 |     """
269 |     out = []
270 |     maxu = torch.tensor(-999)
271 |     maxv = torch.tensor(-999)
272 |     minu = torch.tensor(999)
273 |     minv = torch.tensor(999)
274 |     maxrad = torch.tensor(-1.)  # float, so it promotes cleanly with rad below
275 |     if torch.cuda.is_available():
276 |         maxu = maxu.cuda()
277 |         maxv = maxv.cuda()
278 |         minu = minu.cuda()
279 |         minv = minv.cuda()
280 |         maxrad = maxrad.cuda()
281 |     for i in range(flow.shape[0]):
282 |         u = flow[i, 0, :, :]
283 |         v = flow[i, 1, :, :]
284 |         idxunknow = (torch.abs(u) > 1e7) | (torch.abs(v) > 1e7)  # logical OR instead of bool addition
285 |         u[idxunknow] = 0
286 |         v[idxunknow] = 0
287 |         maxu = torch.max(maxu, torch.max(u))
288 |         minu = torch.min(minu, torch.min(u))
289 |         maxv = torch.max(maxv, torch.max(v))
290 |         minv = torch.min(minv, torch.min(v))
291 |         rad = torch.sqrt((u ** 2 + v ** 2).float())  # fixed: the int64 cast truncated the magnitudes
292 |         maxrad = torch.max(maxrad, torch.max(rad))
293 |         u = u / (maxrad + torch.finfo(torch.float32).eps)
294 |         v = v / (maxrad + torch.finfo(torch.float32).eps)
295 |         # TODO: change the following to pytorch
296 |         img = pt_compute_color(u, v)
297 |         out.append(img)
298 | 
299 |     return torch.stack(out, dim=0)
300 | 
301 | 
302 | def highlight_flow(flow):
303 |     """Convert flow into middlebury color code image.
304 |     """
305 |     out = []
306 |     s = flow.shape
307 |     for i in range(flow.shape[0]):
308 |         img = np.ones((s[1], s[2], 3)) * 144.
309 |         u = flow[i, :, :, 0]
310 |         v = flow[i, :, :, 1]
311 |         for h in range(s[1]):
312 |             for w in range(s[2]):  # fixed: iterate over the width, not the height again
313 |                 ui = int(u[h, w])  # flow values are used as pixel indices, so cast to int
314 |                 vi = int(v[h, w])
315 |                 img[ui, vi, :] = 255.
316 |         out.append(img)
317 |     return np.float32(np.uint8(out))
318 | 
319 | 
320 | def pt_highlight_flow(flow):
321 |     """Convert flow into middlebury color code image.
322 |     """
323 |     out = []
324 |     s = flow.shape
325 |     for i in range(flow.shape[0]):
326 |         img = np.ones((s[1], s[2], 3)) * 144.
327 |         u = flow[i, :, :, 0]
328 |         v = flow[i, :, :, 1]
329 |         for h in range(s[1]):
330 |             for w in range(s[2]):  # fixed: iterate over the width, not the height again
331 |                 ui = int(u[h, w])
332 |                 vi = int(v[h, w])
333 |                 img[ui, vi, :] = 255.
334 |         out.append(img)
335 |     return np.float32(np.uint8(out))
336 | 
337 | 
338 | def compute_color(u, v):
339 |     h, w = u.shape
340 |     img = np.zeros([h, w, 3])
341 |     nanIdx = np.isnan(u) | np.isnan(v)
342 |     u[nanIdx] = 0
343 |     v[nanIdx] = 0
344 |     # colorwheel = COLORWHEEL
345 |     colorwheel = make_color_wheel()
346 |     ncols = np.size(colorwheel, 0)
347 |     rad = np.sqrt(u ** 2 + v ** 2)
348 |     a = np.arctan2(-v, -u) / np.pi
349 |     fk = (a + 1) / 2 * (ncols - 1) + 1
350 |     k0 = np.floor(fk).astype(int)
351 |     k1 = k0 + 1
352 |     k1[k1 == ncols + 1] = 1
353 |     f = fk - k0
354 |     for i in range(np.size(colorwheel, 1)):
355 |         tmp = colorwheel[:, i]
356 |         col0 = tmp[k0 - 1] / 255
357 |         col1 = tmp[k1 - 1] / 255
358 |         col = (1 - f) * col0 + f * col1
359 |         idx = rad <= 1
360 |         col[idx] = 1 - rad[idx] * (1 - col[idx])
361 |         notidx = np.logical_not(idx)
362 |         col[notidx] *= 0.75
363 |         img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
364 |     return img
365 | 
366 | 
367 | def pt_compute_color(u, v):
368 |     h, w = u.shape
369 |     img = torch.zeros([3, h, w])
370 |     if torch.cuda.is_available():
371 |         img = img.cuda()
372 |     nanIdx = torch.isnan(u) | torch.isnan(v)  # logical OR instead of bool addition
373 |     u[nanIdx] = 0.
374 |     v[nanIdx] = 0.
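    # Added explanation of the step below: each (u, v) vector is mapped to a
    # hue on the Middlebury color wheel. atan2(-v, -u) gives a fractional
    # wheel index fk; k0 and k1 are the two neighbouring wheel entries, f
    # blends between them, and the radius rad scales the saturation.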
375 |     # colorwheel = COLORWHEEL
376 |     colorwheel = pt_make_color_wheel()
377 |     if torch.cuda.is_available():
378 |         colorwheel = colorwheel.cuda()
379 |     ncols = colorwheel.size()[0]
380 |     rad = torch.sqrt((u ** 2 + v ** 2).to(torch.float32))
381 |     a = torch.atan2(-v.to(torch.float32), -u.to(torch.float32)) / np.pi
382 |     fk = (a + 1) / 2 * (ncols - 1) + 1
383 |     k0 = torch.floor(fk).to(torch.int64)
384 |     k1 = k0 + 1
385 |     k1[k1 == ncols + 1] = 1
386 |     f = fk - k0.to(torch.float32)
387 |     for i in range(colorwheel.size()[1]):
388 |         tmp = colorwheel[:, i]
389 |         col0 = tmp[k0 - 1]
390 |         col1 = tmp[k1 - 1]
391 |         col = (1 - f) * col0 + f * col1
392 |         idx = rad <= 1. / 255.
393 |         col[idx] = 1 - rad[idx] * (1 - col[idx])
394 |         notidx = (idx != 0)
395 |         col[notidx] *= 0.75
396 |         img[i, :, :] = col * (~nanIdx).to(torch.float32)  # fixed: `1 - bool_tensor` is unsupported; invert with ~
397 |     return img
398 | 
399 | 
400 | def make_color_wheel():
401 |     RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
402 |     ncols = RY + YG + GC + CB + BM + MR
403 |     colorwheel = np.zeros([ncols, 3])
404 |     col = 0
405 |     # RY
406 |     colorwheel[0:RY, 0] = 255
407 |     colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
408 |     col += RY
409 |     # YG
410 |     colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
411 |     colorwheel[col:col + YG, 1] = 255
412 |     col += YG
413 |     # GC
414 |     colorwheel[col:col + GC, 1] = 255
415 |     colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
416 |     col += GC
417 |     # CB
418 |     colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
419 |     colorwheel[col:col + CB, 2] = 255
420 |     col += CB
421 |     # BM
422 |     colorwheel[col:col + BM, 2] = 255
423 |     colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
424 |     col += BM
425 |     # MR
426 |     colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
427 |     colorwheel[col:col + MR, 0] = 255
428 |     return colorwheel
429 | 
430 | 
431 | def pt_make_color_wheel():
432 |     RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
433 |     ncols = RY + YG + GC + CB + BM + MR
434 |     colorwheel = torch.zeros([ncols, 3])
435 |     col = 0
436 |     # RY
437 |     colorwheel[0:RY, 0] = 1.
438 |     colorwheel[0:RY, 1] = torch.arange(0, RY, dtype=torch.float32) / RY
439 |     col += RY
440 |     # YG
441 |     colorwheel[col:col + YG, 0] = 1. - (torch.arange(0, YG, dtype=torch.float32) / YG)
442 |     colorwheel[col:col + YG, 1] = 1.
443 |     col += YG
444 |     # GC
445 |     colorwheel[col:col + GC, 1] = 1.
446 |     colorwheel[col:col + GC, 2] = torch.arange(0, GC, dtype=torch.float32) / GC
447 |     col += GC
448 |     # CB
449 |     colorwheel[col:col + CB, 1] = 1. - (torch.arange(0, CB, dtype=torch.float32) / CB)
450 |     colorwheel[col:col + CB, 2] = 1.
451 |     col += CB
452 |     # BM
453 |     colorwheel[col:col + BM, 2] = 1.
454 |     colorwheel[col:col + BM, 0] = torch.arange(0, BM, dtype=torch.float32) / BM
455 |     col += BM
456 |     # MR
457 |     colorwheel[col:col + MR, 2] = 1. - (torch.arange(0, MR, dtype=torch.float32) / MR)
458 |     colorwheel[col:col + MR, 0] = 1.
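    # Added note: each row of `colorwheel` now holds one RGB hue; rows
    # 0..ncols-1 traverse the hue circle R -> Y -> G -> C -> B -> M -> R.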
459 |     return colorwheel
460 | 
461 | 
462 | def is_image_file(filename):
463 |     IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
464 |     filename_lower = filename.lower()
465 |     return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS)
466 | 
467 | 
468 | def deprocess(img):
469 |     img = img.add_(1).div_(2)
470 |     return img
471 | 
472 | 
473 | # get configs
474 | def get_config(config):
475 |     with open(config, 'r') as stream:
476 |         return yaml.safe_load(stream)  # safe_load: yaml.load without a Loader is deprecated and unsafe
477 | 
478 | 
479 | # Get model list for resume
480 | def get_model_list(dirname, key, iteration=0):
481 |     if not os.path.exists(dirname):
482 |         return None
483 |     gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
484 |                   os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
485 |     if len(gen_models) == 0:  # fixed: the list comprehension never yields None, but it can be empty
486 |         return None
487 |     gen_models.sort()
488 |     if iteration == 0:
489 |         last_model_name = gen_models[-1]
490 |     else:
491 |         for model_name in gen_models:
492 |             if '{:0>8d}'.format(iteration) in model_name:
493 |                 return model_name
494 |         raise ValueError('No models found with this iteration')
495 |     return last_model_name
496 | 
497 | 
498 | if __name__ == '__main__':
499 |     test_random_bbox()
500 |     mask = test_bbox2mask()
501 |     print(mask.shape)
502 |     import matplotlib.pyplot as plt
503 | 
504 |     plt.imshow(mask[0, 0], cmap='gray')  # mask is [N, 1, H, W]; imshow needs a 2-D array
505 |     plt.show()
506 | 
--------------------------------------------------------------------------------
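
A minimal usage sketch for the masking utilities in `utils/tools.py` (added for illustration; the dictionary keys are exactly the ones `random_bbox`, `bbox2mask` and `mask_image` read, but the concrete values here are assumptions, not settings shipped with this repository):

```python
import torch
from utils.tools import random_bbox, mask_image

config = {
    'image_shape': [256, 256, 3],   # H, W, C of the input images
    'mask_shape': [128, 128],       # size of the random box
    'margin': [0, 0],               # keep-out border for the box position
    'mask_batch_same': True,        # one box shared by the whole batch
    'max_delta_shape': [32, 32],    # random shrink applied inside bbox2mask
    'mask_type': 'hole',            # 'hole' zeroes the region; 'mosaic' pixelates it
    'mosaic_unit_size': 8,          # only used when mask_type == 'mosaic'
}

x = torch.rand(4, 3, 256, 256)               # dummy image batch
bboxes = random_bbox(config, batch_size=4)   # one (top, left, h, w) per image
masked, mask = mask_image(x, bboxes, config)
print(masked.shape, mask.shape)              # [4, 3, 256, 256], [4, 1, 256, 256]
```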