├── README.md ├── dataset.py ├── loss_fun.py ├── model.py ├── requirements.txt ├── run_training.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # MemSeg 2 | 3 | Official implementation of MemSeg: [A semi-supervised method for image surface defect detection using differences and commonalities](https://arxiv.org/abs/2205.00908) 4 | 5 | This project contains only the inference code for MemSeg. 6 | 7 | The weight files can be obtained from the following link: https://drive.google.com/drive/folders/1elQDo0vaW7NTMYWcNcw_gW65lvZgVbno?usp=share_link 8 | 9 | The complete code of MemSeg will be open-sourced as soon as it has been cleaned up. 10 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from pathlib import Path 4 | from PIL import Image 5 | from joblib import Parallel, delayed 6 | from torchvision import transforms 7 | 8 | class Repeat(Dataset): 9 | def __init__(self, org_dataset, new_length): 10 | self.org_dataset = org_dataset 11 | self.org_length = len(self.org_dataset) 12 | self.new_length = new_length 13 | 14 | def __len__(self): 15 | return self.new_length 16 | 17 | def __getitem__(self, idx): 18 | return self.org_dataset[idx % self.org_length] 19 | 20 | class MVTecAT(Dataset): 21 | """MVTec AD dataset.""" 22 | 23 | def __init__(self, root_dir, defect_name, size, transform=None, mode="train",memory_number = 15): 24 | """ 25 | Args: 26 | root_dir (string): Directory with the MVTec AD dataset. 27 | defect_name (string): defect category to load (the name of the object class to be inspected) 28 | transform: Transform to apply to data 29 | mode: "train" loads the training samples, "test" the test samples (default: "train"); memory_number: number of leading good images set aside for the memory bank 30 | """ 31 | self.root_dir = Path(root_dir) 32 | self.defect_name = defect_name 33 | self.transform = transform 34 | self.mode = mode 35 | self.size = size 36 | 37 | self.test_transform = transforms.Compose([]) 38 | self.test_transform.transforms.append(transforms.ToTensor()) 39 | 40 | # collect the image paths for the selected split 41 | if self.mode == "train": 42 | self.image_names = list((self.root_dir / defect_name / "train" / "good").glob("*.png")) 43 | self.image_names.sort() 44 | self.image_names = self.image_names[memory_number:] 45 | #print(self.image_names) 46 | print("loading images") 47 | # during training we cache the resized images in memory for performance reasons (not a good coding style) 48 | #self.imgs = [Image.open(file).resize((size,size)).convert("RGB") for file in self.image_names] 49 | train_transform = transforms.Compose([]) 50 | train_transform.transforms.append(transforms.Resize(size, Image.ANTIALIAS)) 51 | self.imgs = Parallel(n_jobs=10)(delayed(lambda file: train_transform(Image.open(file).convert("RGB")))(file) for file in self.image_names) 52 | print(f"loaded {len(self.imgs)} images") 53 | else: 54 | #test mode 55 | self.image_names = list((self.root_dir / defect_name / "test").glob(str(Path("*") / "*.png"))) 56 | self.image_names.sort() 57 | #self.imagemask_names = list((self.root_dir / defect_name / "ground_truth").glob(str(Path("*") / "*.png"))) 58 | 59 | def __len__(self): 60 | return len(self.image_names) 61 | 62 | def __getitem__(self, idx): 63 | if self.mode == "train": 64 | # img = Image.open(self.image_names[idx]) 65 | # img = img.convert("RGB") 66 | img = self.imgs[idx].copy() 67 | if self.transform is not None: 68 | img = self.transform(img) 69 | return img 70 | 71 | 72 | 73 | 
else: 74 | filename = self.image_names[idx] 75 | 76 | try: # build the ground-truth mask path from the test image path 77 | temp = self.image_names[idx] 78 | temp = str(temp) 79 | list_path = temp.split('/') 80 | list_path[-3] = 'ground_truth' 81 | Path_mask = '' 82 | for ii in list_path: 83 | Path_mask += ii 84 | Path_mask += "/" 85 | Path_mask = Path_mask[0:-5] 86 | Path_mask += '_mask.png' 87 | img_mask = Image.open(Path_mask).convert('L') 88 | img_mask = img_mask.resize((int(self.size), int(self.size))) 89 | img_mask = self.test_transform(img_mask) 90 | except: # good test images have no ground-truth mask 91 | img_mask = torch.zeros((1, int(self.size), int(self.size))) 92 | 93 | 94 | # print(img_mask.shape) 95 | label = filename.parts[-2] 96 | img = Image.open(filename) 97 | img = img.resize((self.size,self.size)).convert("RGB") 98 | if self.transform is not None: 99 | img = self.transform(img) 100 | return img, label != "good",img_mask 101 | -------------------------------------------------------------------------------- /loss_fun.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from math import exp 6 | 7 | class FocalLoss(nn.Module): 8 | """ 9 | copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 10 | This is an implementation of Focal Loss with label-smoothed cross entropy support, as proposed in 11 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 12 | Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt) 13 | :param alpha: (tensor) 3D or 4D scalar factor for this criterion 14 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 15 | focus on hard misclassified examples 16 | :param smooth: (float,double) smoothing value for the cross entropy 17 | :param balance_index: (int) balance class index, should be specified when alpha is a float 18 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 19 | """ 20 | 21 | def __init__(self, apply_nonlin=None, alpha=None, gamma=4, balance_index=0, smooth=1e-5, size_average=True): 22 | super(FocalLoss, self).__init__() 23 | self.apply_nonlin = apply_nonlin 24 | self.alpha = alpha 25 | self.gamma = gamma 26 | self.balance_index = balance_index 27 | self.smooth = smooth 28 | self.size_average = size_average 29 | 30 | if self.smooth is not None: 31 | if self.smooth < 0 or self.smooth > 1.0: 32 | raise ValueError('smooth value should be in [0,1]') 33 | 34 | def forward(self, logit, target): 35 | if self.apply_nonlin is not None: 36 | logit = self.apply_nonlin(logit) 37 | num_class = logit.shape[1] 38 | 39 | if logit.dim() > 2: 40 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
41 | logit = logit.view(logit.size(0), logit.size(1), -1) 42 | logit = logit.permute(0, 2, 1).contiguous() 43 | logit = logit.view(-1, logit.size(-1)) 44 | target = torch.squeeze(target, 1) 45 | target = target.view(-1, 1) 46 | alpha = self.alpha 47 | 48 | if alpha is None: 49 | alpha = torch.ones(num_class, 1) 50 | alpha[1] = alpha[1]*1.2 51 | elif isinstance(alpha, (list, np.ndarray)): 52 | assert len(alpha) == num_class 53 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 54 | alpha = alpha / alpha.sum() 55 | elif isinstance(alpha, float): 56 | alpha = torch.ones(num_class, 1) 57 | alpha = alpha * (1 - self.alpha) 58 | alpha[self.balance_index] = self.alpha 59 | 60 | else: 61 | raise TypeError('Not support alpha type') 62 | 63 | if alpha.device != logit.device: 64 | alpha = alpha.to(logit.device) 65 | 66 | idx = target.cpu().long() 67 | 68 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 69 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 70 | if one_hot_key.device != logit.device: 71 | one_hot_key = one_hot_key.to(logit.device) 72 | 73 | if self.smooth: 74 | one_hot_key = torch.clamp( 75 | one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) 76 | pt = (one_hot_key * logit).sum(1) + self.smooth 77 | logpt = pt.log() 78 | 79 | gamma = self.gamma 80 | 81 | alpha = alpha[idx] 82 | alpha = torch.squeeze(alpha) 83 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 84 | 85 | if self.size_average: 86 | loss = loss.mean() 87 | return loss 88 | 89 | def gaussian(window_size, sigma): 90 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 91 | return gauss/gauss.sum() 92 | 93 | def create_window(window_size, channel=1): 94 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 95 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 96 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 97 | return window 98 | 99 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 100 | if val_range is None: 101 | if torch.max(img1) > 128: 102 | max_val = 255 103 | else: 104 | max_val = 1 105 | 106 | if torch.min(img1) < -0.5: 107 | min_val = -1 108 | else: 109 | min_val = 0 110 | l = max_val - min_val 111 | else: 112 | l = val_range 113 | 114 | padd = window_size//2 115 | (_, channel, height, width) = img1.size() 116 | if window is None: 117 | real_size = min(window_size, height, width) 118 | window = create_window(real_size, channel=channel).to(img1.device) 119 | 120 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 121 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 122 | 123 | mu1_sq = mu1.pow(2) 124 | mu2_sq = mu2.pow(2) 125 | mu1_mu2 = mu1 * mu2 126 | 127 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 128 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 129 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 130 | 131 | c1 = (0.01 * l) ** 2 132 | c2 = (0.03 * l) ** 2 133 | 134 | v1 = 2.0 * sigma12 + c2 135 | v2 = sigma1_sq + sigma2_sq + c2 136 | cs = torch.mean(v1 / v2) # contrast sensitivity 137 | 138 | ssim_map = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2) 139 | 140 | if size_average: 141 | ret = ssim_map.mean() 142 | else: 143 | ret = ssim_map.mean(1).mean(1).mean(1) 144 | 145 | if full: 146 | return ret, cs 147 | return ret, ssim_map 148 | 149 | 150 | class SSIM(torch.nn.Module): 
151 | def __init__(self, window_size=11, size_average=True, val_range=None): 152 | super(SSIM, self).__init__() 153 | self.window_size = window_size 154 | self.size_average = size_average 155 | self.val_range = val_range 156 | 157 | # Assume 1 channel for SSIM 158 | self.channel = 1 159 | self.window = create_window(window_size).cuda() 160 | 161 | def forward(self, img1, img2): 162 | (_, channel, _, _) = img1.size() 163 | 164 | if channel == self.channel and self.window.dtype == img1.dtype: 165 | window = self.window 166 | else: 167 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 168 | self.window = window 169 | self.channel = channel 170 | 171 | s_score, ssim_map = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 172 | return 1.0 - s_score 173 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet18,wide_resnet50_2 5 | from torchsummary import summary 6 | from collections import OrderedDict 7 | import pickle 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | import torch.nn.functional as F 13 | 14 | 15 | class h_sigmoid(nn.Module): 16 | def __init__(self, inplace=True): 17 | super(h_sigmoid, self).__init__() 18 | self.relu = nn.ReLU6(inplace=inplace) 19 | 20 | def forward(self, x): 21 | return self.relu(x + 3) / 6 22 | 23 | 24 | class h_swish(nn.Module): 25 | def __init__(self, inplace=True): 26 | super(h_swish, self).__init__() 27 | self.sigmoid = h_sigmoid(inplace=inplace) 28 | 29 | def forward(self, x): 30 | return x * self.sigmoid(x) 31 | 32 | 33 | class CoordAtt(nn.Module): 34 | def __init__(self, inp, oup, reduction=32): 35 | super(CoordAtt, self).__init__() 36 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 37 | self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 38 | 39 | mip = max(8, inp // reduction) 40 | 41 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 42 | self.bn1 = nn.BatchNorm2d(mip) 43 | self.act = h_swish() 44 | 45 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 46 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 47 | 48 | def forward(self, x): 49 | #identity = x 50 | 51 | n, c, h, w = x.size() 52 | x_h = self.pool_h(x) 53 | x_w = self.pool_w(x).permute(0, 1, 3, 2) 54 | 55 | y = torch.cat([x_h, x_w], dim=2) 56 | y = self.conv1(y) 57 | y = self.bn1(y) 58 | y = self.act(y) 59 | 60 | x_h, x_w = torch.split(y, [h, w], dim=2) 61 | x_w = x_w.permute(0, 1, 3, 2) 62 | 63 | a_h = self.conv_h(x_h).sigmoid() 64 | a_w = self.conv_w(x_w).sigmoid() 65 | 66 | out = a_w * a_h 67 | 68 | return out 69 | 70 | 71 | 72 | 73 | 74 | class Decoder (nn.Module): 75 | def __init__(self, base_width, out_channels=1): 76 | super(Decoder , self).__init__() 77 | 78 | self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 79 | nn.Conv2d(512, 256, kernel_size=3, padding=1), 80 | nn.BatchNorm2d(256), 81 | nn.ReLU(inplace=True)) 82 | self.db1 = nn.Sequential( 83 | nn.Conv2d(256+256, 256, kernel_size=3, padding=1), 84 | nn.BatchNorm2d(256), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 87 | nn.BatchNorm2d(256), 88 | nn.ReLU(inplace=True) 89 | ) 90 | self.db1_shor_cut = nn.Sequential( nn.Conv2d(256+256, 256, kernel_size=3, padding=1), 91 | 
nn.BatchNorm2d(256), 92 | nn.ReLU(inplace=True),) 93 | 94 | 95 | 96 | self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 97 | nn.Conv2d(256, 128, kernel_size=3, padding=1), 98 | nn.BatchNorm2d(128), 99 | nn.ReLU(inplace=True)) 100 | self.db2 = nn.Sequential( 101 | nn.Conv2d(128+128, 128, kernel_size=3, padding=1), 102 | nn.BatchNorm2d(128), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 105 | nn.BatchNorm2d(128), 106 | nn.ReLU(inplace=True) 107 | ) 108 | self.db2_shor_cut = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), 109 | nn.BatchNorm2d(128), 110 | nn.ReLU(inplace=True), ) 111 | #self.inception = InceptionB(192) 112 | 113 | self.up3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 114 | nn.Conv2d(128, 64, kernel_size=3, padding=1), 115 | nn.BatchNorm2d(64), 116 | nn.ReLU(inplace=True)) 117 | self.db3 = nn.Sequential( 118 | nn.Conv2d(64+64,64, kernel_size=3, padding=1), 119 | nn.BatchNorm2d(64), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 122 | nn.BatchNorm2d(64), 123 | nn.ReLU(inplace=True) 124 | ) 125 | self.db3_shor_cut = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), 126 | nn.BatchNorm2d(64), 127 | nn.ReLU(inplace=True), ) 128 | 129 | self.up4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 130 | nn.Conv2d(64, 32, kernel_size=3, padding=1), 131 | nn.BatchNorm2d( 32), 132 | nn.ReLU(inplace=True)) 133 | 134 | self.db4 = nn.Sequential( 135 | nn.Conv2d(32+64, 48, kernel_size=3, padding=1), 136 | nn.BatchNorm2d(48), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(48, 48, kernel_size=3, padding=1), 139 | nn.BatchNorm2d(48), 140 | nn.ReLU(inplace=True) 141 | ) 142 | self.db4_shor_cut = nn.Sequential(nn.Conv2d(96, 48, kernel_size=3, padding=1), 143 | nn.BatchNorm2d(48), 144 | nn.ReLU(inplace=True), ) 145 | 146 | self.up5 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 147 | nn.Conv2d(48, 48, kernel_size=3, padding=1), 148 | nn.BatchNorm2d(48), 149 | nn.ReLU(inplace=True)) 150 | 151 | self.se = SE(in_chnls=32,ratio=4) 152 | 153 | 154 | self.db5 = nn.Sequential( 155 | nn.Conv2d(48, 24, kernel_size=3, padding=1), 156 | nn.BatchNorm2d(24), 157 | nn.ReLU(inplace=True), 158 | nn.Conv2d(24, 24, kernel_size=3, padding=1), 159 | nn.BatchNorm2d(24), 160 | nn.ReLU(inplace=True), 161 | ) 162 | self.db5_shor_cut = nn.Sequential(nn.Conv2d(48, 24, kernel_size=3, padding=1), 163 | nn.BatchNorm2d(24), 164 | nn.ReLU(inplace=True), ) 165 | 166 | self.res_bn_relu = nn.Sequential(nn.BatchNorm2d(24), 167 | nn.ReLU(inplace=True), ) 168 | self.final_out = nn.Sequential( 169 | nn.Conv2d(24, 24, kernel_size=3, padding=1), 170 | nn.BatchNorm2d(24), 171 | nn.ReLU(inplace=True), 172 | nn.Conv2d(24, 2, kernel_size=3, padding=1), 173 | #nn.Sigmoid(), 174 | 175 | ) 176 | 177 | 178 | self.Init() 179 | 180 | 181 | def Init(self): 182 | mo_list = [self.up1,self.up2,self.up3,self.up4,self.up5, self.db1,self.db2,self.db3,self.db4,self.db5,self.final_out,self.se,self.res_bn_relu, 183 | self.db1_shor_cut,self.db2_shor_cut,self.db3_shor_cut,self.db4_shor_cut,self.db5_shor_cut] 184 | for m in mo_list: 185 | #for m in self.block_down1.modules(): 186 | if isinstance(m, nn.Conv2d): 187 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 188 | elif isinstance(m, nn.BatchNorm2d): 189 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 190 | torch.nn.init.constant_(m.bias.data, 0.0) 191 | elif isinstance(m, 
nn.Linear): 192 | torch.nn.init.xavier_uniform_(m.weight) 193 | torch.nn.init.constant_(m.bias, 0) 194 | 195 | def forward(self, forward_out,aggregate1,aggregate2,aggregate3,bn_out_128x128,x): 196 | up1 = self.up1(forward_out) 197 | cat = torch.cat((up1,aggregate3),dim=1) 198 | db1 = self.db1(cat) 199 | #db1 = db1 + self.db1_shor_cut(cat) 200 | 201 | up2 = self.up2(db1) 202 | cat = torch.cat((up2,aggregate2),dim=1) 203 | db2 = self.db2(cat) 204 | #db2 = db2 + self.db2_shor_cut(cat) 205 | 206 | 207 | up3 = self.up3(db2) 208 | cat = torch.cat((up3,aggregate1),dim=1) 209 | db3 = self.db3(cat) 210 | #db3 = db3 + self.db3_shor_cut(cat) 211 | 212 | up4 = self.up4(db3) 213 | cat = torch.cat((up4, bn_out_128x128), dim=1) 214 | db4 = self.db4(cat) 215 | #db4 = db4 + self.db4_shor_cut(cat) 216 | 217 | up5 = self.up5(db4) 218 | db5 = self.db5(up5) 219 | 220 | 221 | 222 | 223 | out = self.final_out(db5) 224 | 225 | 226 | 227 | return out 228 | 229 | 230 | 231 | class Encoder(nn.Module): 232 | def __init__(self, pretrained=True, head_layers=[512,512,512,512,512,512,512,512,128], num_classes=2,data_type=None): 233 | super(Encoder, self).__init__() 234 | #self.resnet18 = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=pretrained) 235 | self.resnet18 = resnet18(pretrained=pretrained) 236 | self.resnet18.avgpool=nn.Identity() 237 | self.resnet18.fc = nn.Identity() 238 | self.Init() 239 | 240 | 241 | 242 | 243 | def Init(self,): 244 | for param in self.resnet18.parameters(): 245 | param.requires_grad = False 246 | for param in self.resnet18.layer4.parameters(): 247 | param.requires_grad = True 248 | 249 | 250 | 251 | 252 | def forward(self, x): 253 | 254 | #layer1_out_64x64, layer2_out_32x32, layer3_out_16x16, layer_final_256x256 = self.get_feature(x) 255 | 256 | x = self.resnet18.conv1(x) 257 | x = self.resnet18.bn1(x) 258 | x = self.resnet18.relu(x) 259 | x = self.resnet18.maxpool(x) 260 | x = self.resnet18.layer1(x) 261 | x = self.resnet18.layer2(x) 262 | x = self.resnet18.layer3(x) 263 | forward_out = self.resnet18.layer4(x) 264 | 265 | #bn_out_128x128 = 5555 266 | return forward_out,#bn_out_128x128,layer1_out_64x64,layer2_out_32x32,layer3_out_16x16,layer_final_256x256 267 | 268 | 269 | 270 | 271 | class BasicConv2d(nn.Module): 272 | def __init__(self, in_channels, out_channels, **kwargs): 273 | super(BasicConv2d, self).__init__() 274 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 275 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 276 | def forward(self, x): 277 | x = self.conv(x) 278 | x = self.bn(x) 279 | return F.relu(x, inplace=True) 280 | 281 | 282 | 283 | class InceptionB(nn.Module): 284 | def __init__(self, in_channels): 285 | super(InceptionB, self).__init__() 286 | self.branch3x3 = BasicConv2d(in_channels, int(in_channels/2), kernel_size=3, padding=1) 287 | 288 | self.branch3x3dbl_1 = BasicConv2d(in_channels, int(in_channels/4), kernel_size=1) 289 | self.branch3x3dbl_2 = BasicConv2d(int(in_channels/4), int(in_channels/4), kernel_size=3, padding=1) 290 | self.branch3x3dbl_3 = BasicConv2d(int(in_channels/4), int(in_channels/4), kernel_size=3, padding=1) 291 | 292 | self.branch1x1 = BasicConv2d(in_channels, int(in_channels/4), kernel_size=1) 293 | def forward(self, x): 294 | branch3x3 = self.branch3x3(x) 295 | 296 | branch3x3dbl = self.branch3x3dbl_1(x) 297 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 298 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 299 | 300 | branch1x1 = self.branch1x1(x) 301 | 302 | outputs = [branch3x3, branch3x3dbl, 
branch1x1] 303 | return torch.cat(outputs, 1) 304 | 305 | 306 | class SE(nn.Module): 307 | def __init__(self, in_chnls, ratio=7,out_chnls=66): 308 | super(SE, self).__init__() 309 | self.f = BasicConv2d(in_chnls,in_chnls,kernel_size=3, padding=1) 310 | self.squeeze = nn.AdaptiveAvgPool2d((1, 1)) 311 | self.compress = nn.Conv2d(in_chnls, in_chnls // ratio, 1, 1, 0) 312 | self.excitation = nn.Conv2d(in_chnls // ratio, in_chnls, 1, 1, 0) 313 | 314 | def forward(self, x): 315 | # b, c, h, w = x.size() 316 | #out = self.f(x) 317 | out = self.squeeze(x) 318 | # out = out.view(b,c) 319 | out = self.compress(out) 320 | out = F.relu(out) 321 | out_t = self.excitation(out) # .view(b,c,1,1) 322 | out = torch.sigmoid(out_t) 323 | return out 324 | 325 | class Aggregate (nn.Module): 326 | def __init__(self,use_se=None,duochidu=True): 327 | super(Aggregate , self).__init__() 328 | self.use_se = use_se 329 | self.duochidu = duochidu 330 | 331 | 332 | self.inception1 = InceptionB(in_channels = 64) 333 | self.inception2 = InceptionB(in_channels = 128) 334 | self.inception3 = InceptionB(in_channels = 256) 335 | 336 | self.layer1_64x64_1 = nn.Sequential( 337 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 338 | nn.BatchNorm2d(128), 339 | nn.ReLU(inplace=True), 340 | ) 341 | self.se1 = SE(in_chnls=128, ratio=4) 342 | self.layer1_64x64_2 = nn.Sequential( 343 | nn.Conv2d(128, 64, kernel_size=3, padding=1), 344 | nn.BatchNorm2d(64), 345 | nn.ReLU(inplace=True), 346 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 347 | nn.BatchNorm2d(64), 348 | nn.ReLU(inplace=True), 349 | ) 350 | 351 | self.layer2_32x32_1 = nn.Sequential( 352 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 353 | nn.BatchNorm2d(256), 354 | nn.ReLU(inplace=True), 355 | ) 356 | self.se2 = SE(in_chnls=256, ratio=4) 357 | self.layer2_32x32_2 = nn.Sequential( 358 | nn.Conv2d(256, 128, kernel_size=3, padding=1), 359 | nn.BatchNorm2d(128), 360 | nn.ReLU(inplace=True), 361 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 362 | nn.BatchNorm2d(128), 363 | nn.ReLU(inplace=True), 364 | ) 365 | 366 | 367 | self.layer3_16x16_1 = nn.Sequential( 368 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 369 | nn.BatchNorm2d(512), 370 | nn.ReLU(inplace=True), 371 | ) 372 | self.se3 = SE(in_chnls=512, ratio=4) 373 | self.layer3_16x16_2 = nn.Sequential( 374 | nn.Conv2d(512, 256, kernel_size=3, padding=1), 375 | nn.BatchNorm2d(256), 376 | nn.ReLU(inplace=True), 377 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 378 | nn.BatchNorm2d(256), 379 | nn.ReLU(inplace=True), 380 | ) 381 | 382 | 383 | 384 | 385 | self.layer3_up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 386 | nn.Conv2d(256, 128, kernel_size=3, padding=1), 387 | nn.BatchNorm2d(128), 388 | nn.ReLU(inplace=True)) 389 | self.layer2_up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 390 | nn.Conv2d(128, 64, kernel_size=3, padding=1), 391 | nn.BatchNorm2d(64), 392 | nn.ReLU(inplace=True)) 393 | 394 | self.coordatt128 = CoordAtt(inp=128,oup=128) 395 | self.coordatt256 = CoordAtt(inp=256, oup=256) 396 | self.coordatt512 = CoordAtt(inp=512, oup=512) 397 | 398 | self.Init() 399 | 400 | 401 | def Init(self): 402 | mo_list = [self.se1,self.se2,self.se3,self.inception1,self.inception2,self.inception3, 403 | self.layer1_64x64_1,self.layer1_64x64_2,self.layer2_32x32_1,self.layer2_32x32_2,self.layer3_16x16_1,self.layer3_16x16_2, 404 | self.layer3_up,self.layer2_up, 405 | self.coordatt128,self.coordatt256,self.coordatt512] 406 | for m in mo_list: 407 | #for 
m in self.block_down1.modules(): 408 | if isinstance(m, nn.Conv2d): 409 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 410 | elif isinstance(m, nn.BatchNorm2d): 411 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 412 | torch.nn.init.constant_(m.bias.data, 0.0) 413 | elif isinstance(m, nn.Linear): 414 | torch.nn.init.xavier_uniform_(m.weight) 415 | torch.nn.init.constant_(m.bias, 0) 416 | 417 | def forward(self, layer1_out_64x64_o, layer2_out_32x32_o, layer3_out_16x16_o, layer1_out_64x64_m, layer2_out_32x32_m,layer3_out_16x16_m): 418 | 419 | 420 | score_map1_temp = torch.mean(layer1_out_64x64_m, dim=1).unsqueeze(1) 421 | score_map2_temp = torch.mean(layer2_out_32x32_m, dim=1).unsqueeze(1) 422 | score_map3_temp = torch.mean(layer3_out_16x16_m, dim=1).unsqueeze(1) 423 | 424 | 425 | score_map2 = F.interpolate(score_map2_temp, size=64, 426 | mode='bilinear', align_corners=False) # 对任何尺度 上采样到224,224 427 | score_map3 = F.interpolate(score_map3_temp, size=64, 428 | mode='bilinear', align_corners=False) # 对任何尺度 上采样到224,224 429 | score_map1 = score_map1_temp * score_map2 * score_map3 430 | 431 | 432 | score_map3 = F.interpolate(score_map3_temp, size=32, 433 | mode='bilinear', align_corners=False) # 对任何尺度 上采样到224,224 434 | score_map2 = score_map2_temp*score_map3 435 | score_map3 = score_map3_temp 436 | 437 | 438 | 439 | layer1_out_64x64 = torch.cat((layer1_out_64x64_o, layer1_out_64x64_m),dim=1) 440 | layer2_out_32x32 = torch.cat((layer2_out_32x32_o, layer2_out_32x32_m),dim=1) 441 | layer3_out_16x16 = torch.cat((layer3_out_16x16_o, layer3_out_16x16_m),dim=1) 442 | 443 | out1 = self.layer1_64x64_1(layer1_out_64x64) 444 | if self.use_se: 445 | weight1 = self.se1(layer1_out_64x64) 446 | else: 447 | #print('ssssssssss') 448 | weight1 = self.coordatt128(layer1_out_64x64) 449 | out1_temp = out1*weight1 450 | out1 = self.layer1_64x64_2(out1_temp) 451 | 452 | 453 | out2 = self.layer2_32x32_1(layer2_out_32x32) 454 | if self.use_se: 455 | weight2 = self.se2(layer2_out_32x32) 456 | else: 457 | weight2 = self.coordatt256(layer2_out_32x32) 458 | out2_temp = out2*weight2 459 | out2 = self.layer2_32x32_2(out2_temp) 460 | 461 | out3 = self.layer3_16x16_1(layer3_out_16x16) 462 | if self.use_se: 463 | weight3 = self.se3(layer3_out_16x16) 464 | else: 465 | weight3 = self.coordatt512(layer3_out_16x16) 466 | out3_temp = out3*weight3 467 | out3 = self.layer3_16x16_2(out3_temp) 468 | 469 | 470 | aggregate1 = out1 #B,64,64,64 471 | aggregate2 = out2 #B,128,32,32 472 | aggregate3 = out3 #B,256,16,16 473 | 474 | if self.duochidu: 475 | #print("ddd") 476 | temp3 = self.layer3_up(aggregate3) 477 | aggregate2 = aggregate2 + temp3 478 | 479 | temp2 = self.layer2_up(aggregate2) 480 | aggregate1 = aggregate1+temp2 481 | else: 482 | #print('ppp') 483 | pass 484 | 485 | 486 | return aggregate1*score_map1, aggregate2*score_map2 ,aggregate3*score_map3 487 | 488 | #return aggregate1 , aggregate2 , aggregate3 489 | 490 | class ProjectionNet(nn.Module): 491 | def __init__(self,out_features=False,num_classes = 2,data_type=None,use_se=None,use_duibi=False,duochidu=True): 492 | super(ProjectionNet, self).__init__() 493 | self.encoder_segment = Encoder(data_type=data_type) 494 | self.decoder_segment = Decoder (base_width=64) 495 | self.aggregate = Aggregate(use_se=use_se,duochidu = duochidu) 496 | #self.segment_act = torch.nn.Sigmoid() 497 | self.out_features = out_features 498 | self.use_duibi = use_duibi 499 | 500 | self.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1)) 501 | self.fc1 = nn.Linear(256, 64) 502 | 
nn.init.xavier_uniform_(self.fc1.weight) 503 | self.bn = nn.BatchNorm1d(128) 504 | self.fc2 = nn.Linear(128, num_classes) 505 | nn.init.xavier_uniform_(self.fc2.weight) 506 | 507 | self.bott_conv = nn.Sequential( 508 | nn.Conv2d(512, 256, kernel_size=3, stride=2, padding=1, bias=False), 509 | nn.BatchNorm2d(256), 510 | nn.ReLU(inplace=True), 511 | ) 512 | 513 | 514 | for m in self.bott_conv.modules(): 515 | if isinstance(m, nn.Conv2d): 516 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 517 | elif isinstance(m, nn.BatchNorm2d): 518 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 519 | torch.nn.init.constant_(m.bias.data, 0.0) 520 | def forward(self, x, layer1_out_64x64_o, layer2_out_32x32_o, layer3_out_16x16_o, layer1_out_64x64_m, layer2_out_32x32_m,layer3_out_16x16_m,bn_out_128x128): 521 | forward_out = self.encoder_segment(x) 522 | forward_out = forward_out[0] 523 | aggregate1,aggregate2,aggregate3 = self.aggregate(layer1_out_64x64_o, layer2_out_32x32_o, layer3_out_16x16_o, layer1_out_64x64_m, layer2_out_32x32_m,layer3_out_16x16_m) 524 | output_segment = self.decoder_segment(forward_out,aggregate1,aggregate2,aggregate3,bn_out_128x128,x) 525 | layer_final_256x256 = 1 526 | 527 | if self.use_duibi: 528 | cls_result = forward_out.flatten(start_dim=1) 529 | else: 530 | cls_result = 1 531 | 532 | 533 | 534 | if self.out_features: 535 | return output_segment,cls_result 536 | else: 537 | return output_segment,cls_result,layer_final_256x256 538 | 539 | 540 | 541 | 542 | 543 | 544 | if __name__ == '__main__': 545 | model = ProjectionNet(data_type='wood') 546 | for name, value in model.named_parameters(): 547 | print(name) 548 | print(value.requires_grad) 549 | 550 | #print(model) 551 | model.to(torch.device("cuda")) 552 | summary(model, (3, 256, 256)) 553 | 554 | 555 | 556 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | scikit-learn 4 | pandas 5 | seaborn 6 | tqdm 7 | tensorboard 8 | numpy 9 | matplotlib 10 | Pillow 11 | joblib 12 | torchsummary 13 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | # head dims:512,512,512,512,512,512,512,512,128 2 | # code is based on: https://github.com/google-research/deep_representation_one_class 3 | import warnings 4 | warnings.filterwarnings("ignore") 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | import datetime 8 | import argparse 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt 11 | import os 12 | from collections import OrderedDict 13 | from torchvision.models import resnet18 14 | from matplotlib import cm 15 | import torch 16 | from torch import optim 17 | from torch.utils.data import DataLoader 18 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 19 | from torch.utils.tensorboard import SummaryWriter 20 | from torchvision import transforms 21 | from loss_fun import FocalLoss, SSIM 22 | import pickle 23 | from dataset import MVTecAT, Repeat 24 | #from cutpaste import CutPasteNormal,CutPasteScar, cut_paste_collate_fn ,CutPastePlus2 25 | from model import ProjectionNet 26 | from utils import str2bool 27 | from sklearn.metrics import roc_auc_score 28 | import numpy as np 29 | from PIL import Image 30 | 31 | Get_feature = resnet18(pretrained=True) 32 | Get_feature.to('cuda:0') 33 | Get_feature.eval() 34 | Hook_outputs = [] 35 | 36 | 37 | def hook(module, input, output): 38 | 
Hook_outputs.append(output) 39 | Get_feature.bn1.register_forward_hook(hook) 40 | Get_feature.layer1[-1].register_forward_hook(hook) 41 | Get_feature.layer2[-1].register_forward_hook(hook) 42 | Get_feature.layer3[-1].register_forward_hook(hook) 43 | 44 | 45 | 46 | 47 | def run_training(data_type="screw", 48 | model_dir="models", 49 | epochs=256, 50 | pretrained=True, 51 | test_epochs=10, 52 | freeze_resnet=20, 53 | learninig_rate=0.03, 54 | optim_name="SGD", 55 | batch_size=64, 56 | head_layer=8, 57 | #cutpate_type=CutPasteNormal, 58 | device = "cuda", 59 | workers=8, 60 | size = 256, 61 | args = None, 62 | duochidu=True, 63 | use_jiegou_only = False, 64 | use_wenli_only = False, 65 | without_qianjing=False,use_duibi = False, 66 | se=False,gg=False,without_loss = [] ,test_memory_samples = False,memory_samples=15,test_toy_dataset = False, 67 | MVTECAD_DATA_PATH=None, 68 | 69 | ): 70 | torch.multiprocessing.freeze_support() 71 | weight_decay = 0.00003 72 | momentum = 0.9 73 | model_name = f"model-{data_type}" + '-{date:%Y-%m-%d_%H_%M_%S}'.format(date=datetime.datetime.now() ) 74 | print(without_loss) 75 | min_scale = 1 76 | if without_loss[1]: 77 | learninig_rate=1 78 | 79 | # create Training Dataset and Dataloader 80 | after_cutpaste_transform = transforms.Compose([]) 81 | after_cutpaste_transform.transforms.append(transforms.ToTensor()) 82 | after_cutpaste_transform.transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225])) 84 | 85 | train_transform = transforms.Compose([]) 86 | # train_transform.transforms.append(transforms.ColorJitter(brightness=0.1, saturation=0.1)) 87 | # train_transform.transforms.append(cutpate_type(transform = after_cutpaste_transform,args = args,data_type = data_type, 88 | # use_wenli_only = use_wenli_only,use_jiegou_only=use_jiegou_only,without_qianjing=without_qianjing)) 89 | 90 | if test_memory_samples: 91 | all_num = memory_samples 92 | else: 93 | all_num = 30 94 | if data_type == 'screw_new' or data_type == 'screw': 95 | all_num = 120 96 | if data_type == 'toothbrush': 97 | all_num = 10 98 | 99 | train_data = MVTecAT(MVTECAD_DATA_PATH, data_type, transform = train_transform, size=int(size * (1/min_scale)),memory_number = all_num) 100 | dataloader = DataLoader(Repeat(train_data, 3000), batch_size=batch_size, drop_last=True, 101 | shuffle=True, num_workers=workers, collate_fn=None, 102 | pin_memory=True,) 103 | 104 | 105 | num_classes = 2 106 | model = ProjectionNet(num_classes=num_classes,data_type=data_type,use_se=se,use_duibi = use_duibi, duochidu=duochidu) 107 | 108 | weights = torch.load(f"Test_pth/model-{data_type}.tch") 109 | 110 | model.load_state_dict(weights) 111 | model.to(device) 112 | if test_memory_samples: 113 | with open(f"memory_features/train_{data_type}.pkl", 'rb',) as f: 114 | train_memory = pickle.load(f) 115 | else: 116 | with open(f"memory_features/train_{data_type}.pkl", 'rb',) as f: 117 | train_memory = pickle.load(f) 118 | 119 | 120 | loss_focal = FocalLoss() 121 | loss_l1 = torch.nn.L1Loss() 122 | 123 | if optim_name == "sgd": 124 | optimizer = optim.SGD(model.parameters(), lr=learninig_rate, momentum=momentum, weight_decay=weight_decay) 125 | scheduler = CosineAnnealingWarmRestarts(optimizer, epochs) 126 | #scheduler = None 127 | elif optim_name == "adam": 128 | optimizer = optim.Adam(model.parameters(), lr=learninig_rate, weight_decay=weight_decay) 129 | scheduler = None 130 | else: 131 | print(f"ERROR unkown optimizer: {optim_name}") 132 | 133 | step = 0 134 | num_batches = len(dataloader) 
135 | def get_data_inf(): 136 | while True: 137 | for out in enumerate(dataloader): 138 | yield out 139 | dataloader_inf = get_data_inf() 140 | 141 | 142 | if not os.path.isdir(f"./test_out/{model_name}"): 143 | os.mkdir(f"./test_out/{model_name}") 144 | 145 | 146 | model.eval() 147 | Test_AUC_MMM_, per_pixel_rocauc,= eval_model(model_name, 148 | data_type, 149 | device=device, 150 | save_plots=False, 151 | size=size, 152 | show_training_data=False, 153 | model=model, 154 | step=step, 155 | Get_feature=Get_feature, 156 | train_memory=train_memory, 157 | test_memory_samples=test_memory_samples, 158 | memory_samples=memory_samples, 159 | test_toy_dataset = test_toy_dataset 160 | ) 161 | return Test_AUC_MMM_,per_pixel_rocauc 162 | 163 | 164 | 165 | 166 | 167 | def get_feature(Get_feature, x, train_memory,type_i,test_memory_samples,memory_samples): 168 | global Hook_outputs 169 | with torch.no_grad(): 170 | _ = Get_feature(x) 171 | # get intermediate layer outputs 172 | m = torch.nn.AvgPool2d(3, 1, 1) 173 | test_outputs = OrderedDict([('layer1', []), ('layer2', []), ('layer3', []), ]) 174 | for k, v in zip(test_outputs.keys(), Hook_outputs[1:]): 175 | test_outputs[k].append(m(v)) 176 | # initialize hook outputs 177 | 178 | for k, v in test_outputs.items(): 179 | test_outputs[k] = torch.cat(v, 0) 180 | 181 | layer1_feature_diff = [] 182 | layer2_feature_diff = [] 183 | layer3_feature_diff = [] 184 | sim_di = [] 185 | for t_idx in range(test_outputs['layer3'].shape[0]): # 对每一个样本遍历 186 | for layer_name in ['layer1', 'layer2', 'layer3']: # for each layer 187 | # construct a gallery of features at all pixel locations of the K nearest neighbors 188 | topk_feat_map = train_memory[layer_name] 189 | test_feat_map = test_outputs[layer_name][t_idx:t_idx + 1] # 1,256,56,56 190 | 191 | # calculate distance matrix 192 | dist_matrix_list = [] 193 | dist_matrix_all = [] 194 | if test_memory_samples: 195 | all_num = memory_samples 196 | else: 197 | all_num = 30 198 | if type_i == 'toothbrush': 199 | all_num = 10 200 | if type_i == 'screw': 201 | all_num = 120 202 | #print(all_num) 203 | #all_num = topk_feat_map.shape[0] 204 | for iii in range(all_num): 205 | dist_matrix_list.append(torch.pow(topk_feat_map[iii] - test_feat_map[0], 2) ** 0.5) 206 | dist_matrix_all.append(dist_matrix_list[iii].sum()) 207 | idx_ = torch.argmin(torch.Tensor(dist_matrix_all)) 208 | dist_matrix = dist_matrix_list[idx_] 209 | sim_di.append(idx_) 210 | 211 | 212 | # k nearest features from the gallery (k=1) 213 | # score_map = torch.min(dist_matrix, dim=0)[0]#56,56 666666666!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
214 | 215 | #score_map = torch.mean(dist_matrix_list[idx_], dim=0).cpu() 216 | if layer_name == "layer1": 217 | layer1_feature_diff.append(dist_matrix) 218 | if layer_name == "layer2": 219 | layer2_feature_diff.append(dist_matrix) 220 | if layer_name == "layer3": 221 | layer3_feature_diff.append(dist_matrix) 222 | 223 | layer1_feature = torch.stack(layer1_feature_diff) 224 | layer2_feature = torch.stack(layer2_feature_diff) 225 | layer3_feature = torch.stack(layer3_feature_diff) 226 | 227 | m_scale = 100 # 三级管的时候 这里只放大1000倍 228 | o_scale = 100 229 | if type_i=='transistor': 230 | m_scale=1000 231 | if type_i=='cable' or type_i=='transistor' :#or type_i=='screw_new' or type_i=='screw': 232 | o_scale = 1 # 80 233 | 234 | 235 | layer3_out_16x16_m = layer3_feature.cuda() * m_scale 236 | layer2_out_32x32_m = layer2_feature.cuda() * m_scale 237 | layer1_out_64x64_m = layer1_feature.cuda() * m_scale 238 | 239 | layer1_out_64x64_o = Hook_outputs[1] * o_scale 240 | layer2_out_32x32_o = Hook_outputs[2] * o_scale 241 | layer3_out_16x16_o = Hook_outputs[3] * o_scale 242 | 243 | # m = torch.nn.AvgPool2d(3, 1, 1) 244 | bn_out_128x128 = (Hook_outputs[0]) 245 | Hook_outputs = [] 246 | return layer1_out_64x64_o, layer2_out_32x32_o, layer3_out_16x16_o, layer1_out_64x64_m, layer2_out_32x32_m, layer3_out_16x16_m, bn_out_128x128 247 | 248 | 249 | 250 | 251 | 252 | test_data_eval = None 253 | test_transform = None 254 | cached_type = None 255 | def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, show_training_data=True, model=None, 256 | train_embed=None, head_layer=8, step=0 , Get_feature = "Get_feature",test_memory_samples = False,memory_samples=15, 257 | train_memory = "train_memory",test_toy_dataset = False): 258 | 259 | global test_data_eval, test_transform, cached_type,Hook_outputs 260 | 261 | if test_data_eval is None or cached_type != defect_type: 262 | cached_type = defect_type 263 | test_transform = transforms.Compose([]) 264 | test_transform.transforms.append(transforms.Resize(size, Image.ANTIALIAS)) 265 | test_transform.transforms.append(transforms.ToTensor()) 266 | test_transform.transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], 267 | std=[0.229, 0.224, 0.225])) 268 | if test_toy_dataset : 269 | test_data_eval = MVTecAT("c:/xunlei_cloud/mvtec_anomaly_detection.tar(1)/toy_dataset/", defect_type, 270 | size, transform=test_transform, mode="test") 271 | else: 272 | test_data_eval = MVTecAT("/home/wangyizhuo/Documents/DATA_SETS/dataset_anomaly_detection/", defect_type, 273 | size, transform=test_transform, mode="test") 274 | 275 | dataloader_test = DataLoader(test_data_eval, batch_size=16, 276 | shuffle=False, num_workers=0) 277 | 278 | 279 | # get embeddings for test data 280 | labels = [] 281 | output_segments = [] 282 | logits = [] 283 | true_masks = [] 284 | with torch.no_grad(): 285 | index_ = 0 286 | for x, label, img_mask in dataloader_test: # x维度为B,3,256,256 287 | x = x.to(device) 288 | layer1_out_64x64_o, layer2_out_32x32_o, layer3_out_16x16_o, layer1_out_64x64_m, layer2_out_32x32_m, layer3_out_16x16_m, bn_out_128x128 = get_feature(Get_feature, x, 289 | train_memory,type_i = defect_type,test_memory_samples = test_memory_samples,memory_samples=memory_samples) 290 | output_segment, logit, layer_final_256x256 = model(x,layer1_out_64x64_o, layer2_out_32x32_o, layer3_out_16x16_o, layer1_out_64x64_m, layer2_out_32x32_m, layer3_out_16x16_m, bn_out_128x128) 291 | true_masks.append(img_mask.cpu()) 292 | # save 293 | 
output_segments.append(output_segment.cpu()) 294 | labels.append(label.cpu()) 295 | 296 | labels = torch.cat(labels) # 83 297 | output_segments = torch.cat(output_segments) # 83,512 298 | output_segments = torch.softmax(output_segments, dim=1) 299 | 300 | #logits = torch.cat(logits) # 83,512 301 | true_masks = torch.cat(true_masks) # 83,512 302 | 303 | true_masks = true_masks.numpy() 304 | output_segments = output_segments.numpy() 305 | output_segments = output_segments[:, 1, :, :] 306 | 307 | #Get AUC from seg: 308 | MAX_anormaly = [] 309 | for im_index in range(output_segments.shape[0]): 310 | MAX_anormaly.append(output_segments[im_index].max()) 311 | all_auc = [] 312 | auc_score_max = roc_auc_score(labels, np.array(MAX_anormaly)) 313 | all_auc.append(auc_score_max) 314 | 315 | 316 | MAX_anormaly_100 = [] 317 | 318 | 319 | if not os.path.isdir(f"./test_out/{modelname}"): 320 | os.mkdir(f"./test_out/{modelname}") 321 | if not os.path.isdir(f"./test_out/{modelname}/test"): 322 | os.mkdir(f"./test_out/{modelname}/test") 323 | # if not os.path.isdir(f"./test_out/{modelname}/test/{step}"): 324 | # os.mkdir(f"./test_out/{modelname}/test/{step}") 325 | 326 | for im_index in range(output_segments.shape[0]): 327 | tempp = output_segments[im_index].flatten() 328 | tempp.sort() 329 | #tempp = tempp*tempp[62500] 330 | MAX_anormaly_100.append(tempp[65436:65536].mean()) # image score = mean of the 100 largest pixel scores (256*256 = 65536 pixels) 331 | 332 | 333 | auc_score_max_100_mean = roc_auc_score(labels, np.array(MAX_anormaly_100)) 334 | all_auc.append(auc_score_max_100_mean) 335 | for iiii in range(output_segments.shape[0]): 336 | plt.imsave(f"./test_out/{modelname}/test/pred_{iiii}.jpg", output_segments[iiii, :, :], 337 | cmap=cm.gray) 338 | 339 | true_masks = true_masks.flatten().astype(np.uint32) 340 | output_segments = output_segments.flatten() 341 | # fpr, tpr, _ = roc_curve(true_masks.flatten(), output_segments.flatten()) 342 | per_pixel_rocauc = roc_auc_score(true_masks, output_segments) 343 | Test_AUC_MMM_ = auc_score_max_100_mean 344 | print(f"-----------------------------------{defect_type}----Test_AUC_MMM",Test_AUC_MMM_,'Test_pixel_AUC:',per_pixel_rocauc) 345 | return Test_AUC_MMM_,per_pixel_rocauc 346 | 347 | 348 | 349 | 350 | if __name__ == '__main__': 351 | parser = argparse.ArgumentParser(description='Training defect detection as described in the CutPaste Paper.') 352 | parser.add_argument('--type', default="all", 353 | help='MVTec defect dataset types to train, separated by "," (default: "all": train all defect types)') 354 | 355 | parser.add_argument('--epochs', default=3701, type=int, 356 | help='number of epochs to train the model , (default: 256)') 357 | 358 | parser.add_argument('--model_dir', default="models", 359 | help='output folder of the models , (default: models)') 360 | 361 | parser.add_argument('--no-pretrained', dest='pretrained', default=True, action='store_false', 362 | help='use pretrained values to initialize ResNet18 , (default: True)') 363 | 364 | parser.add_argument('--test_epochs', default=50, type=int, 365 | help='interval to calculate the auc during training, if -1 do not calculate test scores, (default: 10)') 366 | 367 | parser.add_argument('--freeze_resnet', default=60000, type=int, 368 | help='number of epochs to freeze resnet (default: 20)') 369 | 370 | parser.add_argument('--lr', default=0.5, type=float,#screw_new0.08 else 0.04 371 | help='learning rate (default: 0.03)') 372 | 373 | parser.add_argument('--optim', default="sgd", 374 | help='optimizing algorithm, values: [sgd, adam] (default: "sgd")') 375 | 376 | 
parser.add_argument('--batch_size', default=4, type=int,# 377 | help='batch size, real batchsize is depending on cut paste config normal cutaout has effective batchsize of 2x batchsize (dafault: "64")') 378 | 379 | parser.add_argument('--head_layer', default=1, type=int, 380 | help='number of layers in the projection head (default: 1)') 381 | 382 | parser.add_argument('--variant', default="plus2", choices=['normal', 'scar', '3way', 'union' ,'plus','plus2'], help='cutpaste variant to use (dafault: "3way")') 383 | 384 | parser.add_argument('--cuda', default=True, type=str2bool, 385 | help='use cuda for training (default: False)') 386 | 387 | parser.add_argument('--workers', default=0, type=int, help="number of workers to use for data loading (default:8)") 388 | 389 | parser.add_argument('--MVTECAD_DATA_PATH', default="mvtec_datasets/",help='') 390 | #parser.add_argument('--MVTECAD_DATA_PATH', default="all", help='') 391 | 392 | 393 | args = parser.parse_args() 394 | print(args)#zipper leather metal_nut capsule pill 395 | all_types = [ 396 | 'zipper', 397 | 'tile', 398 | 'cable', 399 | 'hazelnut', 400 | 'metal_nut', 401 | 'toothbrush', 402 | 'leather', 403 | 'carpet', 404 | 405 | 'bottle', 406 | 'transistor', 407 | 'screw', 408 | 409 | 'grid', 410 | 'capsule', 411 | 'pill', 412 | 'wood', 413 | ] 414 | 415 | if args.type == "all": 416 | types = all_types 417 | else: 418 | types = args.type.split(",") 419 | 420 | # variant_map = {'normal':CutPasteNormal, 'scar':CutPasteScar, "plus2":CutPastePlus2} 421 | # variant = variant_map[args.variant] 422 | variant=None 423 | 424 | device = "cuda" if args.cuda else "cpu" 425 | print(f"using device: {device}") 426 | 427 | 428 | Path(args.model_dir).mkdir(exist_ok=True, parents=True) 429 | # save config. 430 | with open(Path(args.model_dir) / "run_config.txt", "w") as f: 431 | f.write(str(args)) 432 | for memory_samples in [30]: 433 | for _ in range(1): 434 | all_img_auc = [] 435 | all_pixel_auc = [] 436 | for data_type in types: 437 | args.epochs = 2701 438 | args.lr = 0.04 439 | print(f"======================================================={data_type}_{memory_samples}=======================================================") 440 | torch.cuda.empty_cache() 441 | Test_AUC_MMM_,per_pixel_rocauc = run_training(data_type, 442 | model_dir=Path(args.model_dir), 443 | epochs=args.epochs, 444 | pretrained=args.pretrained, 445 | test_epochs=args.test_epochs, 446 | freeze_resnet=args.freeze_resnet, 447 | learninig_rate=args.lr, 448 | optim_name=args.optim, 449 | batch_size=args.batch_size, 450 | head_layer=args.head_layer, 451 | device=device, 452 | #cutpate_type=variant, 453 | workers=args.workers, 454 | #variant = args.variant, 455 | args = args, 456 | se=False, 457 | use_jiegou_only=False, 458 | use_wenli_only=False, 459 | without_qianjing=False, 460 | gg = False, 461 | without_loss = [False,False,False], 462 | test_memory_samples = False, 463 | memory_samples = memory_samples, 464 | use_duibi=False, 465 | duochidu=True, 466 | test_toy_dataset = False, 467 | MVTECAD_DATA_PATH=args.MVTECAD_DATA_PATH 468 | )#c, focal,L1 469 | all_img_auc.append(Test_AUC_MMM_) 470 | all_pixel_auc.append(per_pixel_rocauc) 471 | print(f'\n\n ALL Image roc-auc={np.mean(np.array(all_img_auc))}') 472 | print(f' ALL Pixel roc-auc={np.mean(np.array(all_pixel_auc))}') 473 | 474 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def str2bool(v): 2 | """argparse 
handles type=bool in a weird way. 3 | See this Stack Overflow answer: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 4 | We can use this function as a type converter for boolean values. 5 | """ 6 | if isinstance(v, bool): 7 | return v 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 11 | return False 12 | else: 13 | import argparse # local import: utils.py has no top-level argparse import 14 | raise argparse.ArgumentTypeError('Boolean value expected.') --------------------------------------------------------------------------------
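
For reference, `run_training.py` is driven entirely from its `__main__` block; the sketch below shows roughly how a single MVTec AD category could be evaluated programmatically instead. This is a minimal, hedged sketch, not part of the released code: it assumes the released weights are unpacked to `Test_pth/model-<category>.tch`, the memory features to `memory_features/train_<category>.pkl`, a CUDA device is available (the ResNet-18 feature extractor is moved to `cuda:0` when the module is imported), and the hard-coded test-set root inside `eval_model()` has been changed to your local MVTec AD location.

```python
import os
from run_training import run_training  # importing this module builds the ResNet-18 feature extractor on cuda:0

os.makedirs("test_out", exist_ok=True)   # eval_model() writes prediction maps under ./test_out/<model_name>/test/

img_auc, pixel_auc = run_training(
    data_type="wood",                     # any MVTec AD category name
    device="cuda",
    optim_name="sgd",
    batch_size=4,
    workers=0,
    without_loss=[False, False, False],   # the default empty list would raise an IndexError
    test_memory_samples=False,
    memory_samples=30,
    MVTECAD_DATA_PATH="mvtec_datasets/",  # root folder holding the MVTec AD categories (adjust to your layout)
)
print(f"image-level AUROC: {img_auc:.4f}  pixel-level AUROC: {pixel_auc:.4f}")
```

The `without_loss`, `memory_samples`, and path arguments mirror the values used in the script's own `__main__` loop; everything else keeps the function's defaults.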