├── utils ├── __init__.py ├── visualizer.py └── utils.py ├── models ├── __init__.py ├── embedding.py ├── attention.py └── model.py ├── FOD-framework.jpg ├── requirements.txt ├── datasets ├── __init__.py ├── btad.py ├── mvtec_3d.py └── mvtec.py ├── losses.py ├── LICENSE ├── create_distance_maps.py ├── README.md ├── main.py ├── create_ref_features.py └── trainer.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FOD 2 | 3 | 4 | __all__ = ['FOD'] -------------------------------------------------------------------------------- /FOD-framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xcyao00/FOD/HEAD/FOD-framework.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | matplotlib 4 | scipy == 1.9.3 5 | timm == 0.6.12 6 | scikit-learn == 1.1.3 7 | torch == 1.13.1 8 | torchvision == 0.14.1 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mvtec import MVTecDataset, MVTEC_CLASS_NAMES 2 | from .btad import BTADDataset, BTAD_CLASS_NAMES 3 | from .mvtec_3d import MVTec3DDataset, MVTEC3D_CLASS_NAMES -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def kl_loss(p, q): 5 | # p: (N, n_heads, L, L) q: (N, n_heads, L, L) 6 | logits = p * (torch.log(p + 0.0001) - torch.log(q + 0.0001)) # (N, n_heads, L, L) 7 | kl = torch.sum(logits, dim=-1) # (N, n_heads, L) 8 | 9 | return torch.mean(kl, dim=1) # (N, L) 10 | 11 | 12 | def entropy_loss(p): 13 | # p: (N, n_heads, L, L) 14 | logits = -p * torch.log(p + 0.0001) 15 | entropy = torch.sum(logits, dim=-1) # (N, n_heads, L) 16 | 17 | return torch.mean(entropy, dim=1) # (N, L) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 @ Shanghai Jiao Tong University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /create_distance_maps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | 5 | 6 | def idx_1d_to_pos_2d(idx, width=16): 7 | h = idx // width 8 | w = idx % width 9 | 10 | return h, w 11 | 12 | 13 | if __name__ == '__main__': 14 | os.makedirs("distances", exist_ok=True) 15 | seg_lens = [1024, 256] 16 | for seq_len in seg_lens: 17 | print("Creating {}x{} distance map...".format(seq_len, seq_len)) 18 | width = int(math.sqrt(seq_len)) 19 | distances_x = np.zeros((seq_len, seq_len), dtype=np.float32) # position distances 20 | distances_y = np.zeros((seq_len, seq_len), dtype=np.float32) 21 | for idx1 in range(seq_len): 22 | for idx2 in range(seq_len): 23 | h1, w1 = idx_1d_to_pos_2d(idx1, width=width) # convert 1d index to 2d position 24 | h2, w2 = idx_1d_to_pos_2d(idx2, width=width) 25 | # position distances are represented by 1d index relation 26 | distances_x[idx1][idx2] = abs(w1 - w2) ** 2 27 | distances_y[idx1][idx2] = abs(h1 - h2) ** 2 28 | np.save(os.path.join("distances", "distances_x_{}.npy".format(seq_len)), distances_x) 29 | np.save(os.path.join("distances", "distances_y_{}.npy".format(seq_len)), distances_y) -------------------------------------------------------------------------------- /models/embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | def __init__(self, d_model, h=16, w=16, device=torch.device('cuda')): 8 | super(PositionalEmbedding, self).__init__() 9 | 10 | if d_model % 4 != 0: 11 | raise ValueError("Cannot use sin/cos positional encoding with odd dimension (got dim={:d})".format(d_model)) 12 | self.d_model = d_model 13 | self.h = h 14 | self.w = w 15 | 16 | pos_embed = torch.zeros(self.d_model, self.h, self.w) 17 | # Each dimension use half of D 18 | half_d_model = self.d_model // 2 19 | div_term = torch.exp(torch.arange(0.0, half_d_model, 2) * -(math.log(1e4) / half_d_model)) 20 | pos_w = torch.arange(0.0, self.w).unsqueeze(1) 21 | pos_h = torch.arange(0.0, self.h).unsqueeze(1) 22 | pos_embed[0:half_d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, self.h, 1) 23 | pos_embed[1:half_d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, self.h, 1) 24 | pos_embed[half_d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, self.w) 25 | pos_embed[half_d_model+1::2,:, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, self.w) 26 | 27 | self.pos_embed = pos_embed.to(device) 28 | 29 | def forward(self): 30 | return self.pos_embed 31 | 32 | 33 | class ProjectEmbedding(nn.Module): 34 | def __init__(self, in_channels, d_model): 35 | super(ProjectEmbedding, self).__init__() 36 | 37 | self.project = nn.Conv1d(in_channels=in_channels, out_channels=d_model, kernel_size=1) 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv1d): 40 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') 41 | 42 | def forward(self, x): 43 | x = 
self.project(x.permute(0, 2, 1)).transpose(1, 2) 44 | 45 | return x 46 | 47 | 48 | class Embedding2D(nn.Module): 49 | def __init__(self, in_channels, d_model, dropout=0.0, h=16, w=16, with_pos_embed=True, device=torch.device('cuda')): 50 | super(Embedding2D, self).__init__() 51 | 52 | self.project_embedding = ProjectEmbedding(in_channels, d_model) 53 | if with_pos_embed: 54 | self.position_embedding = PositionalEmbedding(d_model, h=h, w=w, device=device) 55 | self.with_pos_embed = with_pos_embed 56 | 57 | self.dropout = nn.Dropout(p=dropout) 58 | 59 | def forward(self, x): 60 | x = self.project_embedding(x) 61 | 62 | if self.with_pos_embed: 63 | pos_embed = self.position_embedding() 64 | pos_embed = pos_embed.unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 65 | pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(x.shape[0], -1, x.shape[-1]) 66 | 67 | x = x + pos_embed 68 | 69 | return self.dropout(x) 70 | 71 | 72 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | class Attention2D(nn.Module): 8 | def __init__(self, 9 | seq_len=256, 10 | d_model=512, 11 | num_heads=8, 12 | scale_factor=None, 13 | dropout=0.0, 14 | device=torch.device('cuda')): 15 | super(Attention2D, self).__init__() 16 | 17 | self.seq_len = seq_len 18 | self.width = int(math.sqrt(seq_len)) 19 | self.scale_factor = scale_factor or 1. / math.sqrt(d_model // num_heads) 20 | self.dropout = nn.Dropout(dropout) 21 | self.num_heads = num_heads 22 | self.device = device 23 | 24 | self.query_projection = nn.Linear(d_model, d_model) 25 | self.key_projection = nn.Linear(d_model, d_model) 26 | self.value_projection = nn.Linear(d_model, d_model) 27 | self.sigma_projection = nn.Linear(d_model, 2 * num_heads) 28 | self.out_projection = nn.Linear(d_model, d_model) 29 | 30 | distances_x = np.load('distances/distances_x_{}.npy'.format(seq_len)) 31 | distances_y = np.load('distances/distances_y_{}.npy'.format(seq_len)) 32 | distances_x = torch.from_numpy(distances_x) 33 | distances_y = torch.from_numpy(distances_y) 34 | self.distances_x = distances_x.to(device) 35 | self.distances_y = distances_y.to(device) 36 | 37 | def forward(self, query, key, value, return_attention=True): 38 | B, L, _ = query.shape 39 | _, S, _ = key.shape 40 | 41 | if return_attention: # the sigma will be learned in the intra and inter correlation branches 42 | sigma = self.sigma_projection(query).view(B, L, self.num_heads, -1) 43 | query = self.query_projection(query).view(B, L, self.num_heads, -1) 44 | key = self.key_projection(key).view(B, S, self.num_heads, -1) 45 | value = self.value_projection(value).view(B, S, self.num_heads, -1) 46 | 47 | scores = torch.einsum("blhe,bshe->bhls", query, key) 48 | 49 | # attn: (N, n_heads, L, L) 50 | attn = self.scale_factor * scores 51 | 52 | if return_attention: 53 | sigma = sigma.transpose(1, 2) # (B, L, n_heads, 2) -> (B, n_heads, L, 2) 54 | sigma = torch.sigmoid(sigma * 5) + 1e-5 55 | sigma = torch.pow(3, sigma) - 1 # can change these hyperparameter 56 | 57 | sigma1 = sigma[:, :, :, 0] # (B, n_heads, L) 58 | sigma2 = sigma[:, :, :, 1] # (B, n_heads, L) 59 | sigma1 = sigma1.unsqueeze(-1).repeat(1, 1, 1, self.seq_len) # (B, n_heads, L, L) 60 | sigma2 = sigma2.unsqueeze(-1).repeat(1, 1, 1, self.seq_len) # (B, n_heads, L, L) 61 | 62 | # (B, n_heads, L, L) 63 | distances_x = 
self.distances_x.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).to(self.device) 64 | distances_y = self.distances_y.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).to(self.device) 65 | # Gaussian distance prior 66 | target = 1.0 / (2 * math.pi * sigma1 * sigma2) * torch.exp(-distances_y / (2 * sigma1 ** 2) -distances_x / (2 * sigma2 ** 2)) 67 | 68 | softmax_scores = self.dropout(torch.softmax(attn, dim=-1)) 69 | out = torch.einsum("bhls,bshd->blhd", softmax_scores, value) 70 | 71 | out = out.contiguous().view(B, L, -1) 72 | out = self.out_projection(out) 73 | 74 | if return_attention: 75 | return out, softmax_scores, target 76 | else: 77 | return out, None, None 78 | 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [Focus the Discrepancy: Intra- and Inter-Correlation Learning for Image Anomaly Detection](https://arxiv.org/abs/2308.02983) 2 | 3 | PyTorch implementation for the ICCV 2023 paper, Focus the Discrepancy: Intra- and Inter-Correlation Learning for Image Anomaly Detection. 4 | 5 | 6 | 7 | --- 8 | 9 | ## Installation 10 | Install all packages (the same versions as ours) with the following command: 11 | ``` 12 | $ pip3 install -r requirements.txt 13 | ``` 14 | 15 | ## Download Datasets 16 | Please download the MVTecAD dataset from [MVTecAD dataset](https://www.mvtec.com/de/unternehmen/forschung/datasets/mvtec-ad/), the BTAD dataset from [BTAD dataset](http://avires.dimi.uniud.it/papers/btad/btad.zip), and the MVTec3D dataset from [MVTec3D dataset](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad). 17 | 18 | ## Creating Distance Maps 19 | Please run the following command to create the distance maps used in the target correlations. 20 | 21 | ``` 22 | python create_distance_maps.py 23 | ``` 24 | 25 | ## Creating Reference Features 26 | Please run the following commands to generate the external reference features (based on ``wide_resnet50``). 27 | 28 | ```bash 29 | # For MVTecAD 30 | python create_ref_features.py --dataset mvtec --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --save_path rfeatures_w50 31 | # For BTAD 32 | python create_ref_features.py --dataset btad --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --save_path rfeatures_w50 33 | # For MVTec3D-RGB 34 | python create_ref_features.py --dataset mvtec3d --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --save_path rfeatures_w50 35 | ``` 36 | 37 | 38 | ## Training and Evaluating 39 | In this repository, we use ``wide_resnet50`` as the feature extractor by default, 40 | as we find it achieves slightly better results than the ``efficientnet-b6`` reported in the paper.
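Training iterates over all classes of the selected dataset and reports per-class as well as mean image/pixel AUCs. A trained model can later be re-evaluated without retraining by switching `--mode` to `test`. The command below is only an illustrative sketch for MVTecAD: it assumes that checkpoints were saved under the default `--save_path`/`--save_prefix` layout (i.e., `checkpoints/mvtec`), and `--vis` additionally saves localization maps.

```bash
# Illustrative test-only run (assumes checkpoints were written to checkpoints/mvtec during training)
python main.py --dataset mvtec --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --rfeatures_path rfeatures_w50 --with_intra --with_inter --save_prefix mvtec --mode test --checkpoint checkpoints/mvtec --vis
```

The training commands for each dataset are listed below.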
41 | 42 | - Run code for training and evaluating MVTecAD 43 | ```bash 44 | python main.py --dataset mvtec --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --rfeatures_path rfeatures_w50 --with_intra --with_inter --save_prefix mvtec 45 | ``` 46 | - Run code for training and evaluating BTAD 47 | ```bash 48 | python main.py --dataset btad --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --rfeatures_path rfeatures_w50 --with_intra --with_inter --save_prefix btad 49 | ``` 50 | - Run code for training and evaluating MVTec3D-RGB 51 | ```bash 52 | python main.py --dataset mvtec3d --data_path /path/to/your/dataset --backbone_arch wide_resnet50_2 --rfeatures_path rfeatures_w50 --with_intra --with_inter --save_prefix mvtec3d 53 | ``` 54 | 55 | ## Citation 56 | 57 | If you find this repository useful, please consider citing our work: 58 | ``` 59 | @article{FOD, 60 | title={Focus the Discrepancy: Intra- and Inter-Correlation Learning for Image Anomaly Detection}, 61 | author={Xincheng Yao and Ruoqi Li and Zefeng Qian and Yan Luo and Chongyang Zhang}, 62 | year={2023}, 63 | booktitle={International Conference on Computer Vision 2023}, 64 | url={https://arxiv.org/abs/2308.02983}, 65 | primaryClass={cs.CV} 66 | } 67 | ``` 68 | 69 | If you are interested in our work, you can also follow our other works: [BGAD (CVPR2023)](https://github.com/xcyao00/BGAD), [PMAD (AAAI2023)](https://github.com/xcyao00/PMAD), [ResAD (NeurIPS2024)](https://github.com/xcyao00/ResAD), [HGAD (ECCV2024)](https://github.com/xcyao00/HGAD). Or, you can follow our github page [xcyao00](https://github.com/xcyao00). 70 | 71 | -------------------------------------------------------------------------------- /datasets/btad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms as T 6 | 7 | 8 | BTAD_CLASS_NAMES = ['01', '02', '03'] 9 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 10 | IMAGENET_STD = [0.229, 0.224, 0.225] 11 | 12 | 13 | class BTADDataset(Dataset): 14 | def __init__(self, 15 | data_path, 16 | classname, 17 | resize=256, 18 | cropsize=256, 19 | is_train=True): 20 | assert classname in BTAD_CLASS_NAMES, 'class_name: {}, should be in {}'.format(classname, BTAD_CLASS_NAMES) 21 | self.dataset_path = data_path 22 | self.class_name = classname 23 | self.is_train = is_train 24 | self.cropsize = cropsize 25 | # load dataset 26 | self.x, self.y, self.mask, self.img_types = self.load_dataset_folder() 27 | # set transforms 28 | if is_train: 29 | self.transform_x = T.Compose([ 30 | T.Resize(resize, Image.ANTIALIAS), 31 | #T.RandomRotation(5), 32 | T.CenterCrop(cropsize), 33 | T.ToTensor()]) 34 | # test: 35 | else: 36 | self.transform_x = T.Compose([ 37 | T.Resize(resize, Image.ANTIALIAS), 38 | T.CenterCrop(cropsize), 39 | T.ToTensor()]) 40 | # mask 41 | self.transform_mask = T.Compose([ 42 | T.Resize(resize, Image.NEAREST), 43 | T.CenterCrop(cropsize), 44 | T.ToTensor()]) 45 | 46 | self.normalize = T.Compose([T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) 47 | 48 | def __getitem__(self, idx): 49 | img_path, y, mask, img_type = self.x[idx], self.y[idx], self.mask[idx], self.img_types[idx] 50 | 51 | x = Image.open(img_path).convert('RGB') 52 | 53 | x = self.normalize(self.transform_x(x)) 54 | 55 | if y == 0: 56 | mask = torch.zeros([1, self.cropsize, self.cropsize]) 57 | else: 58 | mask = Image.open(mask) 59 | mask = self.transform_mask(mask) 60 | 61 | 
return x, y, mask, os.path.basename(img_path[:-4]), img_type 62 | 63 | def __len__(self): 64 | return len(self.x) 65 | 66 | def load_dataset_folder(self): 67 | phase = 'train' if self.is_train else 'test' 68 | x, y, mask, types = [], [], [], [] 69 | 70 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 71 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth') 72 | 73 | img_types = sorted(os.listdir(img_dir)) 74 | for img_type in img_types: 75 | 76 | # load images 77 | img_type_dir = os.path.join(img_dir, img_type) 78 | if not os.path.isdir(img_type_dir): 79 | continue 80 | img_fpath_list = sorted([os.path.join(img_type_dir, f) 81 | for f in os.listdir(img_type_dir) 82 | if f.endswith('.bmp') or f.endswith('.png')]) 83 | x.extend(img_fpath_list) 84 | 85 | # load gt labels 86 | if img_type == 'ok': 87 | y.extend([0] * len(img_fpath_list)) 88 | mask.extend([None] * len(img_fpath_list)) 89 | types.extend(['ok'] * len(img_fpath_list)) 90 | else: 91 | y.extend([1] * len(img_fpath_list)) 92 | gt_type_dir = os.path.join(gt_dir, img_type) 93 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 94 | if self.class_name == '03': 95 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '.bmp') 96 | for img_fname in img_fname_list] 97 | else: 98 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '.png') 99 | for img_fname in img_fname_list] 100 | mask.extend(gt_fpath_list) 101 | types.extend([img_type] * len(img_fpath_list)) 102 | 103 | assert len(x) == len(y), 'number of x and y should be same' 104 | 105 | return list(x), list(y), list(mask), list(types) -------------------------------------------------------------------------------- /datasets/mvtec_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms as T 7 | 8 | 9 | # URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz' 10 | MVTEC3D_CLASS_NAMES = ['bagel', 'cable_gland', 'carrot', 'cookie', 'dowel', 11 | 'foam', 'peach', 'potato', 'rope', 'tire'] 12 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 13 | IMAGENET_STD = [0.229, 0.224, 0.225] 14 | 15 | 16 | class MVTec3DDataset(Dataset): 17 | def __init__(self, 18 | data_path, 19 | classname, 20 | resize=256, 21 | cropsize=256, 22 | is_train=True): 23 | assert classname in MVTEC3D_CLASS_NAMES, 'class_name: {}, should be in {}'.format(classname, MVTEC3D_CLASS_NAMES) 24 | self.dataset_path = data_path 25 | self.class_name = classname 26 | self.is_train = is_train 27 | self.cropsize = cropsize 28 | # load dataset 29 | self.x, self.y, self.mask, self.img_types = self.load_dataset_folder() 30 | # set transforms 31 | if is_train: 32 | self.transform_x = T.Compose([ 33 | T.Resize(resize, Image.ANTIALIAS), 34 | #T.RandomRotation(5), 35 | T.CenterCrop(cropsize), 36 | T.ToTensor()]) 37 | # test: 38 | else: 39 | self.transform_x = T.Compose([ 40 | T.Resize(resize, Image.ANTIALIAS), 41 | T.CenterCrop(cropsize), 42 | T.ToTensor()]) 43 | # mask 44 | self.transform_mask = T.Compose([ 45 | T.Resize(resize, Image.NEAREST), 46 | T.CenterCrop(cropsize), 47 | T.ToTensor()]) 48 | 49 | self.normalize = T.Compose([T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) 50 | 51 | def __getitem__(self, idx): 52 | img_path, y, mask, img_type = self.x[idx], self.y[idx], self.mask[idx], self.img_types[idx] 53 | 54 | x = 
Image.open(img_path).convert('RGB') 55 | 56 | x = self.normalize(self.transform_x(x)) 57 | 58 | if y == 0: 59 | mask = torch.zeros([1, self.cropsize, self.cropsize]) 60 | else: 61 | mask = Image.open(mask) 62 | mask = np.array(mask) 63 | mask[mask != 0] = 255 64 | mask = Image.fromarray(mask) 65 | mask = self.transform_mask(mask) 66 | 67 | return x, y, mask, os.path.basename(img_path[:-4]), img_type 68 | 69 | def __len__(self): 70 | return len(self.x) 71 | 72 | def load_dataset_folder(self): 73 | phase = 'train' if self.is_train else 'test' 74 | x, y, mask, types = [], [], [], [] 75 | 76 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 77 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'test') 78 | 79 | img_types = sorted(os.listdir(img_dir)) 80 | for img_type in img_types: 81 | 82 | # load images 83 | img_type_dir = os.path.join(img_dir, img_type, 'rgb') 84 | if not os.path.isdir(img_type_dir): 85 | continue 86 | img_fpath_list = sorted([os.path.join(img_type_dir, f) 87 | for f in os.listdir(img_type_dir) 88 | if f.endswith('.png')]) 89 | x.extend(img_fpath_list) 90 | 91 | # load gt labels 92 | if img_type == 'good': 93 | y.extend([0] * len(img_fpath_list)) 94 | mask.extend([None] * len(img_fpath_list)) 95 | types.extend(['good'] * len(img_fpath_list)) 96 | else: 97 | y.extend([1] * len(img_fpath_list)) 98 | gt_type_dir = os.path.join(gt_dir, img_type, 'gt') 99 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 100 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '.png') 101 | for img_fname in img_fname_list] 102 | mask.extend(gt_fpath_list) 103 | types.extend([img_type] * len(img_fpath_list)) 104 | 105 | assert len(x) == len(y), 'number of x and y should be same' 106 | 107 | return list(x), list(y), list(mask), list(types) -------------------------------------------------------------------------------- /datasets/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms as T 7 | 8 | 9 | # URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz' 10 | MVTEC_CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 11 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 12 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper'] 13 | 14 | 15 | class MVTecDataset(Dataset): 16 | def __init__(self, c, is_train=True, excluded_images=None): 17 | assert c.class_name in MVTEC_CLASS_NAMES, 'class_name: {}, should be in {}'.format(c.class_name, MVTEC_CLASS_NAMES) 18 | self.dataset_path = c.data_path 19 | self.class_name = c.class_name 20 | self.is_train = is_train 21 | self.cropsize = c.crop_size 22 | # load dataset 23 | self.x, self.y, self.mask, self.img_types = self.load_dataset_folder() 24 | # set transforms 25 | if is_train: 26 | self.transform_x = T.Compose([ 27 | T.Resize(c.img_size, Image.ANTIALIAS), 28 | #T.RandomRotation(5), 29 | T.CenterCrop(c.crop_size), 30 | T.ToTensor()]) 31 | # test: 32 | else: 33 | self.transform_x = T.Compose([ 34 | T.Resize(c.img_size, Image.ANTIALIAS), 35 | T.CenterCrop(c.crop_size), 36 | T.ToTensor()]) 37 | # mask 38 | self.transform_mask = T.Compose([ 39 | T.Resize(c.img_size, Image.NEAREST), 40 | T.CenterCrop(c.crop_size), 41 | T.ToTensor()]) 42 | 43 | self.normalize = T.Compose([T.Normalize(c.norm_mean, c.norm_std)]) 44 | 45 | def 
__getitem__(self, idx): 46 | img_path, y, mask, img_type = self.x[idx], self.y[idx], self.mask[idx], self.img_types[idx] 47 | 48 | x = Image.open(img_path) 49 | if self.class_name in ['zipper', 'screw', 'grid']: # handle greyscale classes 50 | x = np.expand_dims(np.array(x), axis=2) 51 | x = np.concatenate([x, x, x], axis=2) 52 | 53 | x = Image.fromarray(x.astype('uint8')).convert('RGB') 54 | 55 | x = self.normalize(self.transform_x(x)) 56 | 57 | if y == 0: 58 | mask = torch.zeros([1, self.cropsize[0], self.cropsize[1]]) 59 | else: 60 | mask = Image.open(mask) 61 | mask = self.transform_mask(mask) 62 | 63 | return x, y, mask, os.path.basename(img_path[:-4]), img_type 64 | 65 | def __len__(self): 66 | return len(self.x) 67 | 68 | def load_dataset_folder(self): 69 | phase = 'train' if self.is_train else 'test' 70 | x, y, mask, types = [], [], [], [] 71 | 72 | img_dir = os.path.join(self.dataset_path, self.class_name, phase) 73 | gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth') 74 | 75 | img_types = sorted(os.listdir(img_dir)) 76 | for img_type in img_types: 77 | 78 | # load images 79 | img_type_dir = os.path.join(img_dir, img_type) 80 | if not os.path.isdir(img_type_dir): 81 | continue 82 | img_fpath_list = sorted([os.path.join(img_type_dir, f) 83 | for f in os.listdir(img_type_dir) 84 | if f.endswith('.png')]) 85 | x.extend(img_fpath_list) 86 | 87 | # load gt labels 88 | if img_type == 'good': 89 | y.extend([0] * len(img_fpath_list)) 90 | mask.extend([None] * len(img_fpath_list)) 91 | types.extend(['good'] * len(img_fpath_list)) 92 | else: 93 | y.extend([1] * len(img_fpath_list)) 94 | gt_type_dir = os.path.join(gt_dir, img_type) 95 | img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list] 96 | gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png') 97 | for img_fname in img_fname_list] 98 | mask.extend(gt_fpath_list) 99 | types.extend([img_type] * len(img_fpath_list)) 100 | 101 | assert len(x) == len(y), 'number of x and y should be same' 102 | 103 | return list(x), list(y), list(mask), list(types) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch.backends import cudnn 4 | from utils.utils import * 5 | 6 | from trainer import Trainer 7 | from datasets.mvtec import MVTEC_CLASS_NAMES 8 | from datasets.btad import BTAD_CLASS_NAMES 9 | from datasets.mvtec_3d import MVTEC3D_CLASS_NAMES 10 | 11 | 12 | def main(args): 13 | cudnn.benchmark = True 14 | init_seeds(3407) 15 | 16 | trainer = Trainer(args) 17 | 18 | if args.mode == 'train': 19 | img_auc, pix_auc = trainer.train() 20 | elif args.mode == 'test': 21 | img_auc, pix_auc = trainer.test(vis=args.vis, checkpoint_path=args.checkpoint) 22 | print("Class Name: {}".format(args.class_name)) 23 | print('Image AUC: {}'.format(img_auc)) 24 | print('Pixel AUC: {}'.format(pix_auc)) 25 | 26 | return img_auc, pix_auc 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | # basic config 32 | parser.add_argument('--lr', type=float, default=1e-4) 33 | parser.add_argument('--num_epochs', type=int, default=100) 34 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 35 | # dataset config 36 | parser.add_argument('--dataset', default='mvtec', type=str, metavar='D', 37 | help='dataset name: mvtec/btad/mvtec3d (default: mvtec)') 38 | parser.add_argument('--data_path', 
default='/data/to/your/path', type=str) 39 | parser.add_argument('--class_name', default='none', type=str, metavar='C', 40 | help='class name for MVTecAD (default: none)') 41 | parser.add_argument('--inp_size', default=256, type=int, metavar='C', 42 | help='image resize dimensions (default: 256)') 43 | parser.add_argument('--batch_size', default=1, type=int, metavar='B', 44 | help='train batch size (default: 32)') 45 | parser.add_argument('--num_workers', default=4, type=int, metavar='G', 46 | help='number of data loading workers (default: 4)') 47 | # model config 48 | parser.add_argument('--backbone_arch', default='wide_resnet50_2', type=str, metavar='A', 49 | help='feature extractor: (default: wide_resnet50_2)') 50 | parser.add_argument('--feature_levels', default=2, type=int, metavar='L', 51 | help='number of feature layers (default: 2)') 52 | parser.add_argument('--rfeatures_path', default='rfeatures_w50', type=str, metavar='A', 53 | help='path to reference features (default: rfeatures_w50)') 54 | parser.add_argument('--with_intra', action='store_true', default=True, 55 | help='Learning intra correlations (default: True)') 56 | parser.add_argument('--with_inter', action='store_true', default=True, 57 | help='Learning inter correlations (default: True)') 58 | parser.add_argument('--lambda1', type=float, default=0.5) 59 | parser.add_argument('--lambda2', type=float, default=0.5) 60 | # misc 61 | parser.add_argument('--save_path', type=str, default='checkpoints') 62 | parser.add_argument('--save_prefix', type=str, default='mvtec') 63 | parser.add_argument('--checkpoint', default='', type=str, metavar='D', 64 | help='used in test phase, set the same as save_path/save_prefix') 65 | parser.add_argument('--vis', action='store_true', default=False, 66 | help='Visualize localization map (default: False)') 67 | 68 | args = parser.parse_args() 69 | 70 | args.device = torch.device("cuda") 71 | args.img_size = (args.inp_size, args.inp_size) 72 | args.crop_size = (args.inp_size, args.inp_size) 73 | args.norm_mean, args.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 74 | 75 | args_dict = vars(args) 76 | print('------------ Options -------------') 77 | for k, v in sorted(args_dict.items()): 78 | print('%s: %s' % (str(k), str(v))) 79 | print('-------------- End ----------------') 80 | 81 | if args.dataset == 'mvtec': 82 | CLASS_NAMES = MVTEC_CLASS_NAMES 83 | elif args.dataset == 'btad': 84 | CLASS_NAMES = BTAD_CLASS_NAMES 85 | elif args.dataset == 'mvtec3d': 86 | CLASS_NAMES = MVTEC3D_CLASS_NAMES 87 | else: 88 | raise ValueError("Not recognized dataset: {}!".format(args.dataset)) 89 | 90 | img_aucs, pix_aucs = [], [] 91 | for class_name in CLASS_NAMES: 92 | args.class_name = class_name 93 | img_auc, pix_auc = main(args) 94 | img_aucs.append(img_auc) 95 | pix_aucs.append(pix_auc) 96 | for i, class_name in enumerate(CLASS_NAMES): 97 | print(f'{class_name}: Image-AUC: {img_aucs[i]}, Pixel-AUC: {pix_aucs[i]}') 98 | print('Mean Image-AUC: {}'.format(np.mean(img_aucs))) 99 | print('Mean Pixel-AUC: {}'.format(np.mean(pix_aucs))) -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import numpy as np 4 | from scipy.ndimage import gaussian_filter 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def denormalization(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 10 | mean = np.array(mean) 11 | std = 
np.array(std) 12 | x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8) 13 | return x 14 | 15 | 16 | class Visualizer(object): 17 | def __init__(self, root, prefix=''): 18 | self.root = root 19 | self.prefix = prefix 20 | os.makedirs(self.root, exist_ok=True) 21 | os.makedirs(os.path.join(self.root, 'normal_ok'), exist_ok=True) 22 | os.makedirs(os.path.join(self.root, 'normal_nok'), exist_ok=True) 23 | os.makedirs(os.path.join(self.root, 'anomaly_ok'), exist_ok=True) 24 | os.makedirs(os.path.join(self.root, 'anomaly_nok'), exist_ok=True) 25 | 26 | def set_prefix(self, prefix): 27 | self.prefix = prefix 28 | 29 | def plot(self, test_imgs, scores, img_scores, gt_masks, file_names, img_types, img_threshold): 30 | """ 31 | Args: 32 | test_imgs (ndarray): shape (N, 3, h, w) 33 | scores (ndarray): shape (N, h, w) 34 | img_scores (ndarray): shape (N, ) 35 | gt_masks (ndarray): shape (N, 1, h, w) 36 | """ 37 | vmax = scores.max() * 255. 38 | vmin = scores.min() * 255. + 10 39 | vmax = vmax - 220 40 | norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) 41 | for i in range(len(scores)): 42 | img = test_imgs[i] 43 | img = denormalization(img) 44 | gt_mask = gt_masks[i].squeeze() 45 | score = scores[i] 46 | #score = gaussian_filter(score, sigma=4) 47 | 48 | heat_map = score * 255 49 | fig_img, ax_img = plt.subplots(1, 3, figsize=(9, 3)) 50 | 51 | fig_img.subplots_adjust(wspace=0.05, hspace=0) 52 | for ax_i in ax_img: 53 | ax_i.axes.xaxis.set_visible(False) 54 | ax_i.axes.yaxis.set_visible(False) 55 | 56 | ax_img[0].imshow(img) 57 | ax_img[0].title.set_text('Input image') 58 | ax_img[1].imshow(gt_mask, cmap='gray') 59 | ax_img[1].title.set_text('GroundTruth') 60 | ax_img[2].imshow(heat_map, cmap='jet', norm=norm, interpolation='none') 61 | ax_img[2].imshow(img, cmap='gray', alpha=0.7, interpolation='none') 62 | ax_img[2].title.set_text('Segmentation') 63 | 64 | if img_types[i] == 'good': 65 | if img_scores[i] <= img_threshold: 66 | fig_img.savefig(os.path.join(self.root, 'normal_ok', img_types[i] + '_' + file_names[i]), dpi=300) 67 | else: 68 | fig_img.savefig(os.path.join(self.root, 'normal_nok', img_types[i] + '_' + file_names[i]), dpi=300) 69 | else: 70 | if img_scores[i] > img_threshold: 71 | fig_img.savefig(os.path.join(self.root, 'anomaly_ok', img_types[i] + '_' + file_names[i]), dpi=300) 72 | else: 73 | fig_img.savefig(os.path.join(self.root, 'anomaly_nok', img_types[i] + '_' + file_names[i]), dpi=300) 74 | 75 | #fig_img.savefig(os.path.join(self.root, str(i) + '.png'), dpi=1000) 76 | 77 | plt.close() 78 | 79 | 80 | def visualize_correlations(attn, img, file_name, img_type, args): 81 | L = attn.shape[0] 82 | H = W = int(math.sqrt(L)) 83 | attn = attn.reshape(H, W, H, W).cpu().numpy() 84 | # downsampling factor for the feature level 85 | factor = 16 86 | 87 | # let's select 4 reference points for visualization 88 | idxs = [(60, 60), (110, 110), (130, 130), (140, 160)] 89 | 90 | # here we create the canvas 91 | fig = plt.figure(constrained_layout=True, figsize=(25 * 0.7, 8.5 * 0.7)) 92 | # and we add one plot per reference point 93 | gs = fig.add_gridspec(2, 4) 94 | axs = [fig.add_subplot(gs[0, 0]), 95 | fig.add_subplot(gs[1, 0]), 96 | fig.add_subplot(gs[0, -1]), 97 | fig.add_subplot(gs[1, -1])] 98 | 99 | # for each one of the reference points, let's plot the self-attention for that point 100 | for idx_o, ax in zip(idxs, axs): 101 | idx = (idx_o[0] // factor, idx_o[1] // factor) 102 | ax.imshow(attn[idx[0], idx[1], ...], cmap='cividis', interpolation='nearest') 103 | 
ax.axis('off') 104 | ax.set_title(f'global-correlation{idx_o}') 105 | 106 | # and now let's add the central image, with the reference points as red circles 107 | img = denormalization(img.squeeze(0)) 108 | fcenter_ax = fig.add_subplot(gs[:, 1:-1]) 109 | fcenter_ax.imshow(img) 110 | for (y, x) in idxs: 111 | scale = img.shape[0] / img.shape[0] 112 | x = ((x // factor) + 0.5) * factor 113 | y = ((y // factor) + 0.5) * factor 114 | fcenter_ax.add_patch(plt.Circle((x * scale, y * scale), factor // 4 , color='r')) 115 | fcenter_ax.axis('off') 116 | 117 | os.makedirs(os.path.join('vis_results', 'attn', args.class_name), exist_ok=True) 118 | fig.savefig(os.path.join('vis_results', 'attn', args.class_name, img_type + '_' + file_name + '.png'), dpi=1000) -------------------------------------------------------------------------------- /create_ref_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timm 3 | import pickle 4 | import argparse 5 | from collections import OrderedDict 6 | 7 | from tqdm import tqdm 8 | 9 | import torch 10 | from datasets.mvtec import MVTecDataset, MVTEC_CLASS_NAMES 11 | from datasets.btad import BTADDataset, BTAD_CLASS_NAMES 12 | from datasets.mvtec_3d import MVTec3DDataset, MVTEC3D_CLASS_NAMES 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser('FOD') 17 | 18 | parser.add_argument('--save_path', type=str, default='./rfeatures_w50') 19 | 20 | parser.add_argument('--dataset', default='mvtec', type=str, metavar='D', 21 | help='dataset name: mvtec/btad/mvtec3d (default: mvtec)') 22 | parser.add_argument('--data_path', default='/data2/yxc/datasets/mvtec_anomaly_detection', type=str) 23 | parser.add_argument('--class_name', default='none', type=str, metavar='C', 24 | help='class name for MVTecAD (default: none)') 25 | parser.add_argument('--inp_size', default=256, type=int, metavar='C', 26 | help='image resize dimensions (default: 256)') 27 | parser.add_argument('--batch_size', default=32, type=int, metavar='B', 28 | help='train batch size (default: 32)') 29 | parser.add_argument('--num_workers', default=4, type=int, metavar='G', 30 | help='number of data loading workers (default: 4)') 31 | 32 | parser.add_argument('--backbone_arch', default='wide_resnet50_2', type=str, metavar='A', 33 | help='feature extractor: (default: efficientnet_b6)') 34 | parser.add_argument('--feature_levels', default=3, type=int, metavar='L', 35 | help='number of feature layers (default: 3)') 36 | 37 | return parser.parse_args() 38 | 39 | 40 | def main(): 41 | # device setup 42 | use_cuda = torch.cuda.is_available() 43 | device = torch.device('cuda' if use_cuda else 'cpu') 44 | 45 | args = parse_args() 46 | args.img_size = (args.inp_size, args.inp_size) 47 | args.crop_size = (args.inp_size, args.inp_size) 48 | args.norm_mean, args.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 49 | 50 | os.makedirs(args.save_path, exist_ok=True) 51 | 52 | # load model 53 | encoder = timm.create_model(args.backbone_arch, features_only=True, 54 | out_indices=[i + 1 for i in range(args.feature_levels)], pretrained=True) 55 | encoder = encoder.to(device).eval() 56 | print("Feature Dimensions", encoder.feature_info.channels()) 57 | 58 | if args.dataset == 'mvtec': 59 | CLASS_NAMES = MVTEC_CLASS_NAMES 60 | elif args.dataset == 'btad': 61 | CLASS_NAMES = BTAD_CLASS_NAMES 62 | elif args.dataset == 'mvtec3d': 63 | CLASS_NAMES = MVTEC3D_CLASS_NAMES 64 | else: 65 | raise ValueError("Not recognized dataset: {}!".format(args.dataset)) 66 
| 67 | for class_name in CLASS_NAMES: 68 | args.class_name = class_name 69 | if args.class_name in MVTEC_CLASS_NAMES: 70 | dataset = MVTecDataset(args, is_train=True) 71 | elif args.class_name in BTAD_CLASS_NAMES: 72 | dataset = BTADDataset(args.data_path, classname=args.class_name, resize=256, cropsize=256, is_train=True) 73 | elif args.class_name in MVTEC3D_CLASS_NAMES: 74 | dataset = MVTec3DDataset(args.data_path, classname=args.class_name, resize=256, cropsize=256, is_train=True) 75 | else: 76 | raise ValueError('Invalid Class Name: {}'.format(args.class_name)) 77 | 78 | kwargs = {'num_workers': args.num_workers, 'pin_memory': True} 79 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False, **kwargs) 80 | 81 | train_outputs = OrderedDict([('layer0', []), ('layer1', []), ('layer2', [])]) 82 | 83 | # extract train set features 84 | train_feature_filepath = os.path.join(args.save_path, '%s.pkl' % class_name) 85 | if not os.path.exists(train_feature_filepath): 86 | for (images, _, _, _, _) in tqdm(loader, '| feature extraction | train | %s |' % class_name): 87 | # model prediction 88 | with torch.no_grad(): 89 | outputs = encoder(images.to(device)) 90 | # get intermediate layer outputs 91 | for k, v in zip(train_outputs.keys(), outputs): 92 | train_outputs[k].append(v.cpu().detach()) 93 | # every single feature level, calculate mean and cov statistics. 94 | for k, v in train_outputs.items(): 95 | embedding_vectors = torch.cat(v, 0) 96 | m = torch.nn.AvgPool2d(3, 1, 1) 97 | embedding_vectors = m(embedding_vectors) 98 | 99 | B, C, H, W = embedding_vectors.size() # (32, 256, 56, 56) 100 | embedding_vectors = embedding_vectors.view(B, C, H * W) 101 | 102 | mean = torch.mean(embedding_vectors, dim=0).numpy() # (C, H*W) 103 | 104 | train_outputs[k] = mean 105 | with open(train_feature_filepath, 'wb') as f: 106 | pickle.dump(train_outputs, f) 107 | else: 108 | print('load train set feature from: %s' % train_feature_filepath) 109 | with open(train_feature_filepath, 'rb') as f: 110 | train_outputs = pickle.load(f) 111 | 112 | 113 | if __name__ == '__main__': 114 | main() -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .attention import Attention2D 8 | from .embedding import Embedding2D 9 | 10 | 11 | class Mlp(nn.Module): 12 | def __init__(self, 13 | in_features, 14 | hidden_features=None, 15 | out_features=None, 16 | act_layer=nn.GELU, 17 | drop=0.): 18 | super().__init__() 19 | 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = nn.Linear(in_features, hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.fc2(x) 31 | x = self.drop(x) 32 | 33 | return x 34 | 35 | 36 | class EncoderLayer(nn.Module): 37 | def __init__(self, 38 | self_attention, 39 | cross_attention, 40 | d_model, 41 | d_feed_foward=None, 42 | dropout=0.1): 43 | super(EncoderLayer, self).__init__() 44 | 45 | d_feed_foward = d_feed_foward or 4 * d_model 46 | self.self_attention = self_attention 47 | self.cross_attention = cross_attention 48 | self.ffn = Mlp(d_model, d_feed_foward, drop=dropout) 49 | self.norm1 = 
nn.LayerNorm(d_model) 50 | self.norm2 = nn.LayerNorm(d_model) 51 | self.dropout = nn.Dropout(dropout) 52 | #self.ffn_proj = Mlp(2 * d_model, d_feed_foward, out_features=d_model, drop=dropout) 53 | 54 | def forward(self, x, ref_x=None, with_intra=True, with_inter=True): 55 | if with_intra and with_inter: # intra correlation + inter correlation 56 | new_x, intra_corr, intra_target = self.self_attention( 57 | x, x, x, 58 | return_attention=True 59 | ) 60 | new_x = x + self.dropout(new_x) 61 | 62 | new_ref_x, inter_corr, inter_target = self.cross_attention( 63 | x, ref_x, ref_x, 64 | return_attention=True 65 | ) 66 | ref_x = x + self.dropout(new_ref_x) 67 | new_x = new_x - ref_x # I2Correlation: residual input 68 | # or concatenation input 69 | # new_x = torch.cat([new_x, ref_x], dim=-1) 70 | # new_x = self.ffn_proj(new_x) 71 | elif with_inter: # only inter correlation 72 | new_x, inter_corr, inter_target = self.cross_attention( 73 | x, ref_x, ref_x, 74 | return_attention=True 75 | ) 76 | new_x = x + self.dropout(new_x) 77 | elif with_intra: # only intra correlation 78 | new_x, intra_corr, intra_target = self.self_attention( 79 | x, x, x, 80 | return_attention=True 81 | ) 82 | new_x = x + self.dropout(new_x) 83 | else: # only patch-wise reconstruction 84 | new_x, _, _ = self.self_attention( 85 | x, x, x, 86 | return_attention=False 87 | ) 88 | new_x = x + self.dropout(new_x) 89 | 90 | y = x = self.norm1(new_x) 91 | y = self.ffn(y) 92 | out = self.norm2(x + y) 93 | 94 | if with_intra and with_inter: 95 | return out, intra_corr, intra_target, inter_corr, inter_target 96 | elif with_intra: 97 | return out, intra_corr, intra_target, None, None 98 | elif with_inter: 99 | return out, None, None, inter_corr, inter_target 100 | else: 101 | return out, None, None, None, None 102 | 103 | 104 | class Encoder(nn.Module): 105 | def __init__(self, encode_layers, norm_layer=None): 106 | super(Encoder, self).__init__() 107 | 108 | self.encode_layers = nn.ModuleList(encode_layers) 109 | self.norm = norm_layer 110 | 111 | def forward(self, x, ref_x=None, with_intra=True, with_inter=True): 112 | intra_corrs_list = [] 113 | intra_targets_list = [] 114 | inter_corrs_list = [] 115 | inter_targets_list = [] 116 | for layer in self.encode_layers: 117 | x, intra_corrs, intra_targets, inter_corrs, inter_targets = layer(x, ref_x, with_intra, with_inter) 118 | intra_corrs_list.append(intra_corrs) 119 | intra_targets_list.append(intra_targets) 120 | inter_corrs_list.append(inter_corrs) 121 | inter_targets_list.append(inter_targets) 122 | 123 | if self.norm is not None: 124 | x = self.norm(x) 125 | 126 | return x, intra_corrs_list, intra_targets_list, inter_corrs_list, inter_targets_list 127 | 128 | 129 | class FOD(nn.Module): 130 | def __init__(self, 131 | seq_len, 132 | in_channels, 133 | out_channels, 134 | d_model=512, 135 | n_heads=4, 136 | n_layers=3, 137 | d_feed_foward_scale=4, 138 | dropout=0.0, 139 | args=None): 140 | super(FOD, self).__init__() 141 | 142 | d_feed_foward = d_model * d_feed_foward_scale 143 | 144 | # embedding 145 | h = w = int(math.sqrt(seq_len)) 146 | self.embedding = Embedding2D(in_channels, d_model, dropout, h=h, w=w) 147 | 148 | self.with_intra = args.with_intra 149 | self.with_inter = args.with_inter 150 | if self.with_inter: 151 | # changing here for non 256x256 input 152 | mappings = {4096: 'layer0', 1024: 'layer1', 256: 'layer2', 64: 'layer3'} 153 | layer_name = mappings[seq_len] 154 | 155 | self.ref_embedding = Embedding2D(in_channels, d_model, dropout, h=h, w=w, with_pos_embed=True, 
device=args.device) 156 | ref_feature_filepath = os.path.join(args.rfeatures_path, '%s.pkl' % args.class_name) 157 | with open(ref_feature_filepath, 'rb') as f: 158 | ref_feats = pickle.load(f) 159 | ref_feats = ref_feats[layer_name] 160 | self.ref_feats = torch.from_numpy(ref_feats[0] if isinstance(ref_feats, list) else ref_feats).to(args.device) 161 | self.ref_feats = self.ref_feats.unsqueeze(0).repeat([args.batch_size, 1, 1]).permute(0, 2, 1) # (N, L, dim) 162 | 163 | # Encoder 164 | self.encoder = Encoder( 165 | [ 166 | EncoderLayer( 167 | Attention2D( 168 | seq_len, d_model, n_heads, dropout=dropout, device=args.device), 169 | Attention2D( 170 | seq_len, d_model, n_heads, dropout=dropout, device=args.device), 171 | d_model, 172 | d_feed_foward, 173 | dropout=dropout 174 | ) for l in range(n_layers) 175 | ], 176 | norm_layer=torch.nn.LayerNorm(d_model) 177 | ) 178 | 179 | self.projection = nn.Linear(d_model, out_channels, bias=True) 180 | 181 | def forward(self, x, train=True): 182 | # (N, L, dim) -> (N, L, d_model) 183 | emb = self.embedding(x) 184 | if self.with_inter: 185 | ref_emb = self.ref_embedding(self.ref_feats if train else self.ref_feats[0:1, :, :]) 186 | emb, intra_corrs, intra_targets, inter_corrs, inter_targets = self.encoder(emb, ref_emb, self.with_intra, self.with_inter) 187 | else: 188 | emb, intra_corrs, intra_targets, inter_corrs, inter_targets = self.encoder(emb, None, self.with_intra, self.with_inter) 189 | # (N, L, d_model) -> (N, L, dim) 190 | x_out = self.projection(emb) 191 | 192 | return x_out, intra_corrs, intra_targets, inter_corrs, inter_targets 193 | 194 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from sklearn.metrics import auc 10 | from skimage.measure import label, regionprops 11 | 12 | 13 | def to_var(x, volatile=False): 14 | if torch.cuda.is_available(): 15 | x = x.cuda() 16 | return Variable(x, volatile=volatile) 17 | 18 | 19 | def mkdir(directory): 20 | if not os.path.exists(directory): 21 | os.makedirs(directory) 22 | 23 | 24 | def init_seeds(seed=0): 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | 31 | 32 | def embedding_concat(x, y): 33 | # from https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master 34 | B, C1, H1, W1 = x.size() 35 | _, C2, H2, W2 = y.size() 36 | s = int(H1 / H2) 37 | x = F.unfold(x, kernel_size=s, dilation=1, stride=s) 38 | x = x.view(B, C1, -1, H2, W2) 39 | z = x.new_zeros(B, C1 + C2, x.size(2), H2, W2) 40 | for i in range(x.size(2)): 41 | z[:, :, i, :, :] = torch.cat((x[:, :, i, :, :], y), 1) 42 | z = z.view(B, -1, H2 * W2) 43 | z = F.fold(z, kernel_size=s, output_size=(H1, W1), stride=s) 44 | 45 | return z 46 | 47 | 48 | def compute_pro_retrieval_metrics(scores, gt_mask): 49 | """ 50 | calculate segmentation AUPRO, from https://github.com/YoungGod/DFR 51 | """ 52 | max_step = 1000 53 | expect_fpr = 0.3 # default 30% 54 | max_th = scores.max() 55 | min_th = scores.min() 56 | delta = (max_th - min_th) / max_step 57 | ious_mean = [] 58 | ious_std = [] 59 | pros_mean = [] 60 | pros_std = [] 61 | threds = [] 62 | fprs = [] 63 | binary_score_maps = np.zeros_like(scores, 
dtype=np.bool) 64 | for step in range(max_step): 65 | thred = max_th - step * delta 66 | # segmentation 67 | binary_score_maps[scores <= thred] = 0 68 | binary_score_maps[scores > thred] = 1 69 | pro = [] # per region overlap 70 | iou = [] # per image iou 71 | # pro: find each connected gt region, compute the overlapped pixels between the gt region and predicted region 72 | # iou: for each image, compute the ratio, i.e. intersection/union between the gt and predicted binary map 73 | for i in range(len(binary_score_maps)): # for i th image 74 | # pro (per region level) 75 | label_map = label(gt_mask[i], connectivity=2) 76 | props = regionprops(label_map) 77 | for prop in props: 78 | x_min, y_min, x_max, y_max = prop.bbox # find the bounding box of an anomaly region 79 | cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max] 80 | # cropped_mask = gt_mask[i][x_min:x_max, y_min:y_max] # bug! 81 | cropped_mask = prop.filled_image # corrected! 82 | intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum() 83 | pro.append(intersection / prop.area) 84 | # iou (per image level) 85 | intersection = np.logical_and(binary_score_maps[i], gt_mask[i]).astype(np.float32).sum() 86 | union = np.logical_or(binary_score_maps[i], gt_mask[i]).astype(np.float32).sum() 87 | if gt_mask[i].any() > 0: # when the gt have no anomaly pixels, skip it 88 | iou.append(intersection / union) 89 | # against steps and average metrics on the testing data 90 | ious_mean.append(np.array(iou).mean()) 91 | #print("per image mean iou:", np.array(iou).mean()) 92 | ious_std.append(np.array(iou).std()) 93 | pros_mean.append(np.array(pro).mean()) 94 | pros_std.append(np.array(pro).std()) 95 | # fpr for pro-auc 96 | gt_masks_neg = ~gt_mask 97 | fpr = np.logical_and(gt_masks_neg, binary_score_maps).sum() / gt_masks_neg.sum() 98 | fprs.append(fpr) 99 | threds.append(thred) 100 | # as array 101 | threds = np.array(threds) 102 | pros_mean = np.array(pros_mean) 103 | pros_std = np.array(pros_std) 104 | fprs = np.array(fprs) 105 | ious_mean = np.array(ious_mean) 106 | ious_std = np.array(ious_std) 107 | # best per image iou 108 | best_miou = ious_mean.max() 109 | #print(f"Best IOU: {best_miou:.4f}") 110 | # default 30% fpr vs pro, pro_auc 111 | idx = fprs <= expect_fpr # find the indexs of fprs that is less than expect_fpr (default 0.3) 112 | fprs_selected = fprs[idx] 113 | fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1] 114 | pros_mean_selected = pros_mean[idx] 115 | pix_pro_auc = auc(fprs_selected, pros_mean_selected) 116 | 117 | return pix_pro_auc 118 | 119 | 120 | def compute_aupr_retrieval_metrics( 121 | predicted_masks, 122 | ground_truth_masks, 123 | include_optimal_threshold_rates=False, 124 | ): 125 | """ 126 | Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations 127 | and ground truth segmentation masks. 128 | Args: 129 | predicted_masks: [list of np.arrays or np.array] [NxHxW] Contains 130 | generated segmentation masks. 
131 | ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains 132 | predefined ground truth segmentation masks 133 | """ 134 | pred_mask = copy.deepcopy(predicted_masks) 135 | gt_mask = copy.deepcopy(ground_truth_masks) 136 | num = 200 137 | out = {} 138 | 139 | if pred_mask is None or gt_mask is None: 140 | for key in out: 141 | out[key].append(float('nan')) 142 | else: 143 | fprs, tprs = [], [] 144 | precisions, f1s = [], [] 145 | gt_mask = np.array(gt_mask, np.uint8) 146 | 147 | t = (gt_mask == 1) 148 | f = ~t 149 | n_true = t.sum() 150 | n_false = f.sum() 151 | th_min = pred_mask.min() - 1e-8 152 | th_max = pred_mask.max() + 1e-8 153 | pred_gt = pred_mask[t] 154 | th_gt_min = pred_gt.min() 155 | th_gt_max = pred_gt.max() 156 | 157 | ''' 158 | Using scikit learn to compute pixel au_roc results in a memory error since it tries to store the NxHxW float score values. 159 | To avoid this, we compute the tp, fp, tn, fn at equally spaced thresholds in the range between min of predicted 160 | scores and maximum of predicted scores 161 | ''' 162 | percents = np.linspace(100, 0, num=num // 2) 163 | th_gt_per = np.percentile(pred_gt, percents) 164 | th_unif = np.linspace(th_gt_max, th_gt_min, num=num // 2) 165 | thresholds = np.concatenate([th_gt_per, th_unif, [th_min, th_max]]) 166 | thresholds = np.flip(np.sort(thresholds)) 167 | 168 | if n_true == 0 or n_false == 0: 169 | raise ValueError("gt_submasks must contains at least one normal and anomaly samples") 170 | 171 | for th in thresholds: 172 | p = (pred_mask > th).astype(np.uint8) 173 | p = (p == 1) 174 | fp = (p & f).sum() 175 | tp = (p & t).sum() 176 | 177 | fpr = fp / n_false 178 | tpr = tp / n_true 179 | if tp + fp > 0: 180 | prec = tp / (tp + fp) 181 | else: 182 | prec = 1.0 183 | if prec > 0. 
and tpr > 0.: 184 | f1 = (2 * prec * tpr) / (prec + tpr) 185 | else: 186 | f1 = 0.0 187 | fprs.append(fpr) 188 | tprs.append(tpr) 189 | precisions.append(prec) 190 | f1s.append(f1) 191 | 192 | roc_auc = auc(fprs, tprs) 193 | roc_auc = round(roc_auc, 4) 194 | pr_auc = auc(tprs, precisions) 195 | pr_auc = round(pr_auc, 4) 196 | out['roc_auc'] = (roc_auc) 197 | out['pr_auc'] = (pr_auc) 198 | out['fpr'] = (fprs) 199 | out['tpr'] = (tprs) 200 | out['precision'] = (precisions) 201 | out['f1'] = (f1s) 202 | out['thresholds'] = (thresholds) 203 | 204 | return pr_auc 205 | 206 | def rescale(x): 207 | return (x - x.min()) / (x.max() - x.min()) 208 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import timm 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from scipy.ndimage import gaussian_filter 9 | from sklearn.metrics import roc_auc_score 10 | from sklearn.metrics import precision_recall_curve 11 | 12 | from utils.utils import * 13 | from utils.visualizer import Visualizer 14 | from datasets.mvtec import MVTecDataset, MVTEC_CLASS_NAMES 15 | from datasets.btad import BTADDataset, BTAD_CLASS_NAMES 16 | from datasets.mvtec_3d import MVTec3DDataset, MVTEC3D_CLASS_NAMES 17 | from models.model import FOD 18 | from losses import kl_loss, entropy_loss 19 | 20 | 21 | class Trainer(object): 22 | def __init__(self, args): 23 | self.args = args 24 | 25 | if args.class_name in MVTEC_CLASS_NAMES: 26 | train_dataset = MVTecDataset(args, is_train=True) 27 | test_dataset = MVTecDataset(args, is_train=False) 28 | elif args.class_name in BTAD_CLASS_NAMES: 29 | train_dataset = BTADDataset(args.data_path, classname=args.class_name, resize=self.args.inp_size, cropsize=self.args.inp_size, is_train=True) 30 | test_dataset = BTADDataset(args.data_path, classname=args.class_name, resize=self.args.inp_size, cropsize=self.args.inp_size, is_train=False) 31 | elif args.class_name in MVTEC3D_CLASS_NAMES: 32 | train_dataset = MVTec3DDataset(args.data_path, classname=args.class_name, resize=self.args.inp_size, cropsize=self.args.inp_size, is_train=True) 33 | test_dataset = MVTec3DDataset(args.data_path, classname=args.class_name, resize=self.args.inp_size, cropsize=self.args.inp_size, is_train=False) 34 | else: 35 | raise ValueError('Invalid Class Name: {}'.format(args.class_name)) 36 | 37 | kwargs = {'num_workers': args.num_workers, 'pin_memory': True} 38 | self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False, **kwargs) 39 | self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=False, **kwargs) 40 | 41 | self.build_model() 42 | self.l2_criterion = nn.MSELoss() 43 | self.cos_criterion = nn.CosineSimilarity(dim=-1) 44 | 45 | def build_model(self): 46 | encoder = timm.create_model(self.args.backbone_arch, features_only=True, 47 | out_indices=[2, 3], pretrained=True) 48 | self.encoder = encoder.to(self.args.device).eval() 49 | 50 | feat_dims = encoder.feature_info.channels() 51 | print("Feature Dimensions:", feat_dims) 52 | 53 | models = [] 54 | self.seq_lens = [1024, 256] 55 | self.ws = [32, 16] # feature map height/width 56 | for seq_len, in_channels, d_model in zip(self.seq_lens, feat_dims, [256, 512]): 57 | model = FOD(seq_len=seq_len, 58 | in_channels=in_channels, 59 | out_channels=in_channels, 60 | 
60 |                         d_model=d_model,
61 |                         n_heads=8,
62 |                         n_layers=3,
63 |                         args=self.args)
64 |             print('One Model...Done')
65 |             models.append(model.to(self.args.device))
66 |         self.models = models
67 |         print('Creating Models...Done')
68 |         params = list(models[0].parameters())
69 |         for l in range(1, self.args.feature_levels):
70 |             params += list(models[l].parameters())
71 |         self.optimizer = torch.optim.Adam(params, lr=self.args.lr)
72 |         self.avg_pool = torch.nn.AvgPool2d(3, 1, 1)
73 | 
74 |     def train(self):
75 |         path = os.path.join(self.args.save_path, self.args.save_prefix)
76 |         if not os.path.exists(path):
77 |             os.makedirs(path)
78 | 
79 |         start_time = time.time()
80 |         train_steps = len(self.train_loader)
81 |         best_img_auc, best_pix_auc = 0.0, 0.0
82 |         for epoch in range(self.args.num_epochs):
83 |             print("======================TRAIN MODE======================")
84 |             iter_count = 0
85 |             loss_rec_list, loss_intra_entropy_list, loss_inter_entropy_list = [], [], []
86 |             loss_corr_list, loss_target_list = [], []
87 | 
88 |             epoch_time = time.time()
89 |             for model in self.models:
90 |                 model.train()
91 |             for i, (images, _, _, _, _) in enumerate(self.train_loader):
92 |                 iter_count += 1
93 |                 images = images.float().to(self.args.device)  # (N, 3, H, W)
94 | 
95 |                 with torch.no_grad():
96 |                     features = self.encoder(images)
97 | 
98 |                 for fl in range(self.args.feature_levels):
99 |                     m = torch.nn.AvgPool2d(3, 1, 1)
100 |                     input = m(features[fl])
101 |                     N, D, _, _ = input.shape
102 |                     input = input.permute(0, 2, 3, 1).reshape(N, -1, D)
103 | 
104 |                     # output: reconstructed features, (N, L, dim)
105 |                     # intra_corrs: intra correlations, list[(N, num_heads, L, L)]
106 |                     # intra_targets: intra target correlations, list[(N, num_heads, L, L)]
107 |                     # inter_corrs: inter correlations, list[(N, num_heads, L, L)]
108 |                     # inter_targets: inter target correlations, list[(N, num_heads, L, L)]
109 |                     # len of list is attention layers of transformer
110 |                     model = self.models[fl]
111 |                     output, intra_corrs, intra_targets, inter_corrs, inter_targets = model(input)
112 | 
113 |                     if self.args.with_intra:
114 |                         loss_intra1, loss_intra2, loss_intra_entropy = 0.0, 0.0, 0.0
115 |                         for l in range(len(intra_targets)):
116 |                             L = intra_targets[l].shape[-1]
117 |                             norm_targets = (intra_targets[l] / torch.unsqueeze(torch.sum(intra_targets[l], dim=-1), dim=-1).repeat(1, 1, 1, L)).detach()
118 |                             # optimizing intra correlations
119 |                             loss_intra1 += torch.mean(kl_loss(norm_targets, intra_corrs[l])) + torch.mean(kl_loss(intra_corrs[l], norm_targets))
120 | 
121 |                             norm_targets = intra_targets[l] / torch.unsqueeze(torch.sum(intra_targets[l], dim=-1), dim=-1).repeat(1, 1, 1, L)
122 |                             loss_intra2 += torch.mean(kl_loss(norm_targets, intra_corrs[l].detach())) + torch.mean(kl_loss(intra_corrs[l].detach(), norm_targets))
123 | 
124 |                             loss_intra_entropy += torch.mean(entropy_loss(intra_corrs[l]))
125 | 
126 |                         loss_intra1 = loss_intra1 / len(intra_targets)
127 |                         loss_intra2 = loss_intra2 / len(intra_targets)
128 |                         loss_intra_entropy = loss_intra_entropy / len(intra_targets)
129 | 
130 |                     if self.args.with_inter:
131 |                         loss_inter1, loss_inter2, loss_inter_entropy = 0.0, 0.0, 0.0
132 |                         for l in range(len(inter_targets)):
133 |                             L = inter_targets[l].shape[-1]
134 |                             norm_targets = (inter_targets[l] / torch.unsqueeze(torch.sum(inter_targets[l], dim=-1), dim=-1).repeat(1, 1, 1, L)).detach()
135 |                             # optimizing inter correlations
136 |                             loss_inter1 += torch.mean(kl_loss(norm_targets, inter_corrs[l])) + torch.mean(kl_loss(inter_corrs[l], norm_targets))
137 | 
138 |                             norm_targets = inter_targets[l] / torch.unsqueeze(torch.sum(inter_targets[l], dim=-1), dim=-1).repeat(1, 1, 1, L)
139 |                             loss_inter2 += torch.mean(kl_loss(norm_targets, inter_corrs[l].detach())) + torch.mean(kl_loss(inter_corrs[l].detach(), norm_targets))
140 | 
141 |                             loss_inter_entropy += torch.mean(entropy_loss(inter_corrs[l]))
142 | 
143 |                         loss_inter1 = loss_inter1 / len(inter_targets)
144 |                         loss_inter2 = loss_inter2 / len(inter_targets)
145 |                         loss_inter_entropy = loss_inter_entropy / len(inter_targets)
146 | 
147 |                     loss_rec = self.l2_criterion(output, input) + torch.mean(1 - self.cos_criterion(output, input))  # mse + cosine
148 | 
149 |                     if self.args.with_intra and self.args.with_inter:  # patch-wise reconstruction + intra correlation + inter correlation
150 |                         loss1 = loss_rec + self.args.lambda1 * loss_intra2 - self.args.lambda1 * loss_inter2
151 |                         loss2 = loss_rec - self.args.lambda1 * loss_intra1 - self.args.lambda2 * loss_intra_entropy + self.args.lambda1 * loss_inter1 + self.args.lambda2 * loss_inter_entropy
152 |                     elif self.args.with_intra:  # patch-wise reconstruction + intra correlation
153 |                         loss1 = loss_rec + self.args.lambda1 * loss_intra2
154 |                         loss2 = loss_rec - self.args.lambda1 * loss_intra1 - self.args.lambda2 * loss_intra_entropy
155 |                     elif self.args.with_inter:  # patch-wise reconstruction + inter correlation
156 |                         loss1 = loss_rec - self.args.lambda1 * loss_inter2
157 |                         loss2 = loss_rec + self.args.lambda1 * loss_inter1 + self.args.lambda2 * loss_inter_entropy
158 |                     else:  # only patch-wise reconstruction
159 |                         loss = loss_rec
160 | 
161 |                     loss_rec_list.append(loss_rec.item())
162 |                     if self.args.with_intra and self.args.with_inter:
163 |                         loss_target_list.append((loss_intra2 - loss_inter2).item())
164 |                         loss_corr_list.append((-loss_intra1 + loss_inter1).item())
165 |                         loss_intra_entropy_list.append(loss_intra_entropy.item())
166 |                         loss_inter_entropy_list.append(loss_inter_entropy.item())
167 |                     elif self.args.with_intra:
168 |                         loss_target_list.append((loss_intra2).item())
169 |                         loss_corr_list.append((-loss_intra1).item())
170 |                         loss_intra_entropy_list.append(loss_intra_entropy.item())
171 |                     elif self.args.with_inter:
172 |                         loss_target_list.append((-loss_inter2).item())
173 |                         loss_corr_list.append((loss_inter1).item())
174 |                         loss_inter_entropy_list.append(loss_inter_entropy.item())
175 | 
176 |                     self.optimizer.zero_grad()
177 |                     if not self.args.with_intra and not self.args.with_inter:  # only patch-wise reconstruction
178 |                         loss.backward()
179 |                     else:
180 |                         # Two-stage optimization
181 |                         loss1.backward(retain_graph=True)
182 |                         loss2.backward()
183 |                     self.optimizer.step()
184 | 
185 |             speed = (time.time() - start_time) / iter_count
186 |             left_time = speed * ((self.args.num_epochs - epoch) * train_steps - i)
187 |             print("Epoch: {} cost time: {}s | speed: {:.4f}s/iter | left time: {:.4f}s".format(epoch + 1, time.time() - epoch_time, speed, left_time))
188 |             iter_count = 0
189 |             start_time = time.time()
190 | 
191 |             if self.args.with_intra and self.args.with_inter:
192 |                 print(
193 |                     "Epoch: {0}, Steps: {1} | Rec Loss: {2:.7f} | Target Loss: {3:.7f} | Corr Loss: {4:.7f} | Intra Entropy: {5:.7f} | Inter Entropy: {6:.7f}".format(
194 |                         epoch + 1, train_steps, np.average(loss_rec_list), np.average(loss_target_list), np.average(loss_corr_list), np.average(loss_intra_entropy_list), np.average(loss_inter_entropy_list)))
195 |             elif self.args.with_intra:
196 |                 print(
197 |                     "Epoch: {0}, Steps: {1} | Rec Loss: {2:.7f} | Target Loss: {3:.7f} | Corr Loss: {4:.7f} | Intra Entropy: {5:.7f}".format(
198 |                         epoch + 1, train_steps, np.average(loss_rec_list), np.average(loss_target_list), np.average(loss_corr_list), np.average(loss_intra_entropy_list)))
199 |             elif self.args.with_inter:
200 |                 print(
201 |                     "Epoch: {0}, Steps: {1} | Rec Loss: {2:.7f} | Target Loss: {3:.7f} | Corr Loss: {4:.7f} | Inter Entropy: {5:.7f}".format(
202 |                         epoch + 1, train_steps, np.average(loss_rec_list), np.average(loss_target_list), np.average(loss_corr_list), np.average(loss_inter_entropy_list)))
203 |             else:
204 |                 print(
205 |                     "Epoch: {0}, Steps: {1} | Rec Loss: {2:.7f}".format(epoch + 1, train_steps, np.average(loss_rec_list)))
206 | 
207 |             img_auc, pix_auc = self.test(vis=False)
208 | 
209 |             print("Epoch: {0}, Class Name: {1}, Image AUC: {2:.7f} | Pixel AUC: {3:.7f}".format(epoch + 1, self.args.class_name, img_auc, pix_auc))
210 | 
211 |             if img_auc > best_img_auc:
212 |                 best_img_auc = img_auc
213 |                 state = {'state_dict': [model.state_dict() for model in self.models]}
214 |                 torch.save(state, os.path.join(path, self.args.class_name + '-img.pth'))
215 |             if pix_auc > best_pix_auc:
216 |                 best_pix_auc = pix_auc
217 |                 state = {'state_dict': [model.state_dict() for model in self.models]}
218 |                 torch.save(state, os.path.join(path, self.args.class_name + '-pix.pth'))
219 | 
220 |         return best_img_auc, best_pix_auc
221 | 
222 |     def test(self, vis=False, checkpoint_path=None):
223 |         if checkpoint_path is not None:
224 |             checkpoint = torch.load(os.path.join(checkpoint_path, self.args.class_name + '-pix.pth'))
225 |             state_dict = checkpoint['state_dict']
226 |             for i, model in enumerate(self.models):
227 |                 model.load_state_dict(state_dict[i])
228 |         for model in self.models:
229 |             model.eval()
230 |         temperature = 1
231 | 
232 |         print("======================TEST MODE======================")
233 | 
234 |         l2_criterion = nn.MSELoss(reduction='none')
235 |         cos_criterion = nn.CosineSimilarity(dim=-1)
236 | 
237 |         scores_list = [list() for _ in range(self.args.feature_levels)]
238 |         test_imgs, gt_label_list, gt_mask_list, file_names, img_types = [], [], [], [], []
239 |         for i, (image, label, mask, file_name, img_type) in enumerate(self.test_loader):
240 |             test_imgs.append(image.cpu().numpy())
241 |             gt_label_list.extend(label)
242 |             gt_mask_list.extend(mask.numpy())
243 |             file_names.extend(file_name)
244 |             img_types.extend(img_type)
245 | 
246 |             image = image.float().to(self.args.device)
247 | 
248 |             with torch.no_grad():
249 |                 features = self.encoder(image)
250 | 
251 |             for fl in range(self.args.feature_levels):
252 |                 m = torch.nn.AvgPool2d(3, 1, 1)
253 |                 input = m(features[fl])
254 |                 N, D, _, _ = input.shape
255 |                 input = input.permute(0, 2, 3, 1).reshape(N, -1, D)
256 | 
257 |                 model = self.models[fl]
258 |                 output, intra_corrs, intra_targets, inter_corrs, inter_targets = model(input, train=False)
259 | 
260 |                 rec_score = torch.mean(l2_criterion(input, output), dim=-1) + 1 - cos_criterion(input, output)
261 | 
262 |                 if self.args.with_intra:
263 |                     correlations1, correlations2, entropys = 0.0, 0.0, 0.0
264 |                     for l in range(len(intra_targets)):
265 |                         L = intra_targets[l].shape[-1]
266 |                         norm_targets = intra_targets[l] / torch.unsqueeze(torch.sum(intra_targets[l], dim=-1), dim=-1).repeat(1, 1, 1, L)
267 |                         correlations1 += kl_loss(intra_corrs[l], norm_targets) * temperature
268 |                         correlations2 += kl_loss(norm_targets, intra_corrs[l]) * temperature
269 | 
270 |                         entropys += entropy_loss(intra_corrs[l])
271 | 
272 |                     corrs = (correlations1 + correlations2) / len(intra_targets)
273 |                     intra_score = torch.softmax((-corrs), dim=-1)
274 |                     # entropys = entropys / len(intra_targets)
275 |                     # ent_score = torch.softmax(-entropys, dim=-1)
276 |                 if self.args.with_inter:
277 |                     correlations1, correlations2, entropys = 0.0, 0.0, 0.0
278 |                     for l in range(len(inter_targets)):
279 |                         L = inter_targets[l].shape[-1]
280 |                         norm_targets = inter_targets[l] / torch.unsqueeze(torch.sum(inter_targets[l], dim=-1), dim=-1).repeat(1, 1, 1, L)
281 |                         correlations1 += kl_loss(inter_corrs[l], norm_targets) * temperature
282 |                         correlations2 += kl_loss(norm_targets, inter_corrs[l]) * temperature
283 | 
284 |                         entropys += entropy_loss(inter_corrs[l])
285 | 
286 |                     corrs = (correlations1 + correlations2) / len(inter_targets)
287 |                     inter_score = torch.softmax((-corrs), dim=-1)
288 |                     inter_score = torch.max(inter_score) - inter_score
289 |                     # entropys = entropys / len(inter_targets)
290 |                     # ent_score = torch.softmax(-entropys, dim=-1)
291 |                     # ent_score = torch.max(ent_score) - ent_score
292 | 
293 |                 if self.args.with_intra and self.args.with_inter:
294 |                     # we find that only use inter_score can get slightly better results,
295 |                     # but in training the intra-correlations learning is still necessary for achieving the best results
296 |                     score = rec_score * inter_score
297 |                 elif self.args.with_intra:
298 |                     score = rec_score * intra_score
299 |                 elif self.args.with_inter:
300 |                     score = rec_score * inter_score
301 |                 else:
302 |                     score = rec_score
303 |                 score = score.detach()  # (N, L)
304 |                 score = score.reshape(score.shape[0], self.ws[fl], self.ws[fl])
305 |                 score = F.interpolate(score.unsqueeze(1),
306 |                                       size=self.args.inp_size, mode='bilinear', align_corners=True).squeeze().cpu().numpy()
307 |                 scores_list[fl].append(score)
308 | 
309 |         lvl_scores = []
310 |         for l in range(self.args.feature_levels):
311 |             lvl_score = np.stack(scores_list[l], axis=0)  # (N, 256, 256)
312 |             lvl_scores.append(lvl_score)
313 | 
314 |         scores = np.zeros_like(lvl_scores[0])
315 |         for l in range(self.args.feature_levels):
316 |             scores += lvl_scores[l]
317 |         scores = scores / self.args.feature_levels
318 | 
319 |         # scores = np.ones_like(lvl_scores[0])
320 |         # for l in range(self.args.feature_levels):
321 |         #     scores *= lvl_scores[l]
322 | 
323 |         gt_mask = np.squeeze(np.asarray(gt_mask_list, dtype=bool), axis=1)  # builtin bool: np.bool is removed in newer NumPy
324 |         pix_auc = roc_auc_score(gt_mask.flatten(), scores.flatten())
325 | 
326 |         for i in range(scores.shape[0]):
327 |             scores[i] = gaussian_filter(scores[i], sigma=4)
328 | 
329 |         # image and pixel level auroc
330 |         img_scores = np.max(scores, axis=(1, 2))
331 |         gt_label = np.asarray(gt_label_list, dtype=bool)
332 |         img_auc = roc_auc_score(gt_label, img_scores)
333 | 
334 |         if vis:
335 |             precision, recall, thresholds = precision_recall_curve(gt_label, img_scores)
336 |             a = 2 * precision * recall
337 |             b = precision + recall
338 |             f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
339 |             img_threshold = thresholds[np.argmax(f1)]
340 | 
341 |             visualizer = Visualizer(f'vis_results/{self.args.save_prefix}/{self.args.class_name}')
342 |             max_score = np.max(scores)
343 |             min_score = np.min(scores)
344 |             scores = (scores - min_score) / (max_score - min_score)
345 |             test_imgs = np.concatenate(test_imgs, axis=0)
346 |             visualizer.plot(test_imgs, scores, img_scores, gt_mask, file_names, img_types, img_threshold)
347 | 
348 |         return img_auc, pix_auc
349 | 
--------------------------------------------------------------------------------
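
Note on the test-time scoring in trainer.py: the per-patch reconstruction error is weighted by a correlation-discrepancy term built from the symmetric KL divergence in losses.py, then reshaped back into a spatial anomaly map. The sketch below distills that computation for a single feature level and mirrors the intra-correlation branch (the inter branch additionally flips its score via torch.max(inter_score) - inter_score). The function name score_one_level and the toy tensors are illustrative only and not part of the repository; t.sum(dim=-1, keepdim=True) is used as a broadcasting shorthand for the unsqueeze/repeat normalisation in the original code.

import torch
import torch.nn.functional as F

from losses import kl_loss  # symmetric-KL building block shipped with the repo


def score_one_level(inp, out, corrs, targets, w):
    # inp/out: (N, L, D) encoder features and their reconstruction
    # corrs/targets: per-layer lists of (N, n_heads, L, L) correlation maps
    rec_score = torch.mean((inp - out) ** 2, dim=-1) + 1 - F.cosine_similarity(inp, out, dim=-1)  # (N, L)
    disc = 0.0
    for c, t in zip(corrs, targets):
        norm_t = t / t.sum(dim=-1, keepdim=True)               # row-normalised target correlations
        disc = disc + kl_loss(c, norm_t) + kl_loss(norm_t, c)   # symmetric KL per patch, (N, L)
    disc = disc / len(corrs)
    corr_score = torch.softmax(-disc, dim=-1)                   # small discrepancy -> large weight
    return (rec_score * corr_score).reshape(-1, w, w)           # (N, w, w) anomaly map


# toy shapes: 2 images, 4 heads, 16x16 patches (L = 256), 512-dim features
N, H, L, D, w = 2, 4, 256, 512, 16
inp, out = torch.randn(N, L, D), torch.randn(N, L, D)
corrs = [torch.softmax(torch.randn(N, H, L, L), dim=-1)]
targets = [torch.rand(N, H, L, L) + 1e-3]
print(score_one_level(inp, out, corrs, targets, w).shape)  # torch.Size([2, 16, 16])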
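
The repository's main.py (listed in the file tree but not reproduced in this dump) is the usual entry point. As a rough, illustrative sketch only: the argument names below are exactly the attributes that trainer.py reads from args, while the default values and the backbone choice are assumptions, and the dataset classes (which also receive args) may expect further fields not shown here.

# sketch_run.py -- hypothetical driver, not part of the repository
import argparse
from trainer import Trainer

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='./data/mvtec')          # assumed default
parser.add_argument('--class_name', type=str, default='bottle')
parser.add_argument('--inp_size', type=int, default=256)
parser.add_argument('--backbone_arch', type=str, default='wide_resnet50_2')   # assumed timm backbone
parser.add_argument('--feature_levels', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--lambda1', type=float, default=1.0)
parser.add_argument('--lambda2', type=float, default=1.0)
parser.add_argument('--with_intra', action='store_true')
parser.add_argument('--with_inter', action='store_true')
parser.add_argument('--save_path', type=str, default='./checkpoints')
parser.add_argument('--save_prefix', type=str, default='fod')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()

trainer = Trainer(args)
best_img_auc, best_pix_auc = trainer.train()
print('best image AUC: {:.4f} | best pixel AUC: {:.4f}'.format(best_img_auc, best_pix_auc))

Per the trainer code above, checkpoints land in save_path/save_prefix as <class_name>-img.pth and <class_name>-pix.pth, and test(vis=True, checkpoint_path=...) reloads the pixel-AUC checkpoint for visualisation.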