├── README.md
├── codes
│   ├── LRN.py
│   ├── __init__.py
│   ├── datasets.py
│   ├── inspection.py
│   ├── mvtecad.py
│   ├── nearest_neighbor.py
│   ├── networks.py
│   └── utils.py
├── data.npy
├── doc
│   └── svdd_result.jpeg
├── heat_map.py
├── requirements.txt
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # SCL-VI: Self-supervised Context Learning for Visual Inspection of Industrial Defects
2 | 
3 | 
4 | 
5 | We address the challenge of detecting object defects with a self-supervised learning approach that solves a jigsaw-puzzle pretext task.
6 | 
7 | ## Results
8 | ![segmentation](./doc/svdd_result.jpeg)
9 | 
10 | ## Dependencies
11 | This project was completed some time ago, so the dependencies below may need adjustment in newer environments.
12 | - Tested with Python 3.8
13 | - [Pytorch](http://pytorch.org/) v1.6.0
14 | 
15 | ## Dataset
16 | - Dataset : [MvTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad/)
17 | 
18 | ## Run Training
19 | - python train.py --obj=cable --lambda_value=1 --D=64 --epochs=400 --lr=1e-4 --gpu=0
20 | 
21 | ## Run Affinity Testing
22 | - python test.py --obj=cable --gpu=0
23 | - In enc.load(obj, N), N is the epoch index of the training checkpoint to load
24 | 
25 | ## Anomaly maps
26 | - python heat_map.py --obj=cable
27 | - In enc.load(obj, N), N is the epoch index of the training checkpoint to load
28 | 
29 | ## Details:
30 | - The network input must be 256x256
31 | - data.npy lists the patch-pair index combinations and their relative-position labels used by the jigsaw task
--------------------------------------------------------------------------------
/codes/LRN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class LRN(nn.Module):
5 |     # Local Response Normalization (AlexNet-style), implemented via average pooling.
6 |     def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
7 |         super(LRN, self).__init__()
8 |         self.ACROSS_CHANNELS = ACROSS_CHANNELS
9 |         if ACROSS_CHANNELS:
10 |             # Average the squared activations over a window of neighboring channels.
11 |             self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
12 |                                         stride=1, padding=(int((local_size - 1.0) / 2), 0, 0))
13 |         else:
14 |             # Average the squared activations over a spatial neighborhood.
15 |             self.average = nn.AvgPool2d(kernel_size=local_size,
16 |                                         stride=1, padding=int((local_size - 1.0) / 2))
17 |         self.alpha = alpha
18 |         self.beta = beta
19 | 
20 |     def forward(self, x):
21 |         if self.ACROSS_CHANNELS:
22 |             div = x.pow(2).unsqueeze(1)
23 |             div = self.average(div).squeeze(1)
24 |             div = div.mul(self.alpha).add(1.0).pow(self.beta)
25 |         else:
26 |             div = x.pow(2)
27 |             div = self.average(div)
28 |             div = div.mul(self.alpha).add(1.0).pow(self.beta)
29 |         x = x.div(div)
30 |         return x
--------------------------------------------------------------------------------
/codes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangpeng000/VisualInspection/a5933402284662bbab1d218c188cb16788de6a4e/codes/__init__.py
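A quick shape check for the LRN module above — a sketch that assumes only a working PyTorch install and is not shipped with the repository:

    import torch
    from codes.LRN import LRN

    lrn = LRN(local_size=5, alpha=1e-4, beta=0.75)   # the settings used in networks.py
    x = torch.randn(2, 96, 31, 31)                   # NCHW feature map
    y = lrn(x)
    assert y.shape == x.shape                        # LRN only rescales activations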
--------------------------------------------------------------------------------
/codes/datasets.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | 
8 | __all__ = ['SVDD_Dataset', 'PositionDataset', 'MyJigsawPositionDataset']
9 | 
10 | jigsaw_num = 4
11 | 
12 | def generate_coords(H, W, K):
13 |     h = np.random.randint(0, H - K + 1)
14 |     w = np.random.randint(0, W - K + 1)
15 |     return h, w
16 | 
17 | 
18 | def generate_coords_mine(H, W, K):
19 |     h = np.random.randint(K, H - K + 1)
20 |     w = np.random.randint(K, W - K + 1)
21 |     return h, w
22 | 
23 | def generate_coords_position(H, W, K):
24 |     with task('P1'):
25 |         p1 = generate_coords(H, W, K)
26 |         h1, w1 = p1
27 | 
28 |     pos = np.random.randint(8)
29 | 
30 |     with task('P2'):
31 |         J = K // 4
32 |         K3_4 = 3 * K // 4
33 |         h_dir, w_dir = pos_to_diff[pos]
34 |         h_del, w_del = np.random.randint(J, size=2)
35 | 
36 |         h_diff = h_dir * (h_del + K3_4)
37 |         w_diff = w_dir * (w_del + K3_4)
38 | 
39 |         h2 = h1 + h_diff
40 |         w2 = w1 + w_diff
41 | 
42 |         h2 = np.clip(h2, 0, H - K)
43 |         w2 = np.clip(w2, 0, W - K)
44 | 
45 |         p2 = (h2, w2)
46 | 
47 |     return p1, p2, pos
48 | 
49 | 
50 | def generate_coords_svdd(H, W, K):
51 |     with task('P1'):
52 |         p1 = generate_coords(H, W, K)
53 |         h1, w1 = p1
54 | 
55 |     with task('P2'):
56 |         J = K // 32
57 |         h_jit, w_jit = 0, 0
58 | 
59 |         while h_jit == 0 and w_jit == 0:
60 |             h_jit = np.random.randint(-J, J + 1)
61 |             w_jit = np.random.randint(-J, J + 1)
62 | 
63 |         h2 = h1 + h_jit
64 |         w2 = w1 + w_jit
65 | 
66 |         h2 = np.clip(h2, 0, H - K)
67 |         w2 = np.clip(w2, 0, W - K)
68 | 
69 |         p2 = (h2, w2)
70 | 
71 |     return p1, p2
72 | 
73 | 
74 | # Maps a position class (0-7) to the (row, col) direction of the second patch.
75 | pos_to_diff = {
76 |     0: (-1, -1),
77 |     1: (-1, 0),
78 |     2: (-1, 1),
79 |     3: (0, -1),
80 |     4: (0, 1),
81 |     5: (1, -1),
82 |     6: (1, 0),
83 |     7: (1, 1)
84 | }
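# A sketch (not part of this file) of what generate_coords_position samples:
# p1 is a uniformly placed K x K patch corner, and p2 is displaced from p1 by
# between 3/4 * K and K pixels along one of the eight directions in pos_to_diff,
# then clipped to the image border. For example:
#
#     p1, p2, pos = generate_coords_position(256, 256, 64)
#     h_dir, w_dir = pos_to_diff[pos]   # e.g. pos == 0 means p2 is up-left of p1
#     # before clipping, p2 == (p1[0] + h_dir * d_h, p1[1] + w_dir * d_w)
#     # with d_h, d_w in [48, 64) for K == 64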
85 | 
86 | 
87 | def generate_coords_position_mine(H, W, K):
88 |     # Anchor patch, placed uniformly at random.
89 |     with task('P_STD'):
90 |         p_std = generate_coords(H, W, K)
91 |         h_std, w_std = p_std
92 | 
93 |     # A companion patch jittered by at most K // 32 pixels around the anchor.
94 |     with task('P_STD2'):
95 |         J = K // 32
96 |         h_jit, w_jit = 0, 0
97 |         while h_jit == 0 and w_jit == 0:
98 |             h_jit = np.random.randint(-J, J + 1)
99 |             w_jit = np.random.randint(-J, J + 1)
100 | 
101 |         h_std2 = np.clip(h_std + h_jit, 0, H - K)
102 |         w_std2 = np.clip(w_std + w_jit, 0, W - K)
103 |         p_std2 = (h_std2, w_std2)
104 | 
105 |     # One neighbor patch for each of the eight directions in pos_to_diff.
106 |     J = K // 4
107 |     K3_4 = 3 * K // 4
108 |     neighbors = []
109 |     for pos in range(8):
110 |         with task(f'P{pos}'):
111 |             h_dir, w_dir = pos_to_diff[pos]
112 |             h_del, w_del = np.random.randint(J, size=2)
113 |             h = np.clip(h_std + h_dir * (h_del + K3_4), 0, H - K)
114 |             w = np.clip(w_std + w_dir * (w_del + K3_4), 0, W - K)
115 |             neighbors.append((h, w))
116 | 
117 |     p0, p1, p2, p3, p4, p5, p6, p7 = neighbors
118 | 
119 |     # data.npy indexes into this tuple, so the return order must stay fixed.
120 |     return p0, p1, p2, p3, p_std2, p4, p5, p6, p7, p_std
121 | 
122 | 
123 | class SVDD_Dataset(Dataset):
124 | 
125 |     def __init__(self, memmap, K=64, repeat=1):
126 |         super().__init__()
127 |         self.arr = np.asarray(memmap)
128 |         self.K = K
129 |         self.repeat = repeat
130 | 
131 |     def __len__(self):
132 |         N = self.arr.shape[0]
133 |         return N * self.repeat
134 | 
135 |     def __getitem__(self, idx):
136 |         N = self.arr.shape[0]
137 |         K = self.K
138 |         n = idx % N
139 | 
140 |         p1, p2 = generate_coords_svdd(256, 256, K)
141 | 
142 |         image = self.arr[n]
143 | 
144 |         patch1 = crop_image_CHW(image, p1, K)
145 |         patch2 = crop_image_CHW(image, p2, K)
146 | 
147 |         return patch1, patch2
148 | 
149 |     @staticmethod
150 |     def infer(enc, batch):
151 |         x1s, x2s = batch
152 |         h1s = enc(x1s)
153 |         h2s = enc(x2s)
154 |         diff = h1s - h2s
155 |         l2 = diff.norm(dim=1)  # per-pair L2 distance between embeddings
156 |         loss = l2.mean()
157 | 
158 |         return loss
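# A minimal sketch (not part of this file) of SVDD_Dataset.infer with a toy
# encoder standing in for the repository's hierarchical one:
#
#     import torch
#     import torch.nn as nn
#     toy_enc = nn.Sequential(nn.Conv2d(3, 8, 64), nn.Flatten())  # 64x64 patch -> 8-D code
#     x1 = torch.randn(4, 3, 64, 64)
#     x2 = torch.randn(4, 3, 64, 64)
#     loss = SVDD_Dataset.infer(toy_enc, (x1, x2))  # mean L2 gap between paired embeddings
#
# Minimizing this term pulls the embeddings of spatially adjacent patches
# together; it is the SVDD loss weighted by --lambda_value in train.py.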
159 | 
160 | 
161 | class PositionDataset(Dataset):
162 | 
163 |     def __init__(self, x, K=64, repeat=1):
164 |         super(PositionDataset, self).__init__()
165 |         self.x = np.asarray(x)
166 |         self.K = K
167 |         self.repeat = repeat
168 | 
169 |     def __len__(self):
170 |         N = self.x.shape[0]
171 |         return N * self.repeat
172 | 
173 |     def __getitem__(self, idx):
174 |         N = self.x.shape[0]
175 |         K = self.K
176 |         n = idx % N
177 | 
178 |         image = self.x[n]
179 |         p1, p2, pos = generate_coords_position(256, 256, K)
180 | 
181 |         patch1 = crop_image_CHW(image, p1, K).copy()
182 |         patch2 = crop_image_CHW(image, p2, K).copy()
183 | 
184 |         # Color-shift and pixel noise discourage shortcut solutions that rely
185 |         # on low-level color statistics instead of spatial context.
186 |         rgbshift1 = np.random.normal(scale=0.02, size=(3, 1, 1))
187 |         rgbshift2 = np.random.normal(scale=0.02, size=(3, 1, 1))
188 |         patch1 += rgbshift1
189 |         patch2 += rgbshift2
190 | 
191 |         noise1 = np.random.normal(scale=0.02, size=(3, K, K))
192 |         noise2 = np.random.normal(scale=0.02, size=(3, K, K))
193 |         patch1 += noise1
194 |         patch2 += noise2
195 | 
196 |         return patch1, patch2, pos
197 | 
198 | class MyJigsawPositionDataset(Dataset):
199 |     def __init__(self, x, K=64, repeat=1):
200 |         super(MyJigsawPositionDataset, self).__init__()
201 |         self.x = np.asarray(x)
202 |         self.K = K
203 |         self.repeat = repeat
204 | 
205 |     def __len__(self):
206 |         N = self.x.shape[0]
207 |         return N * self.repeat
208 | 
209 |     def __getitem__(self, idx):
210 |         N = self.x.shape[0]
211 |         K = self.K
212 |         n = idx % N
213 | 
214 |         image = self.x[n]
215 | 
216 |         position = generate_coords_position_mine(256, 256, K)
217 |         # Each row of data.npy is (index of patch 1, index of patch 2, position label).
218 |         npy = np.load('data.npy')
219 |         order = np.random.randint(len(npy))
220 | 
221 |         patch1 = crop_image_CHW(image, position[npy[order][0]], K).copy()
222 |         patch2 = crop_image_CHW(image, position[npy[order][1]], K).copy()
223 |         pos = npy[order][2]
224 | 
225 |         rgbshift1 = np.random.normal(scale=0.02, size=(3, 1, 1))
226 |         rgbshift2 = np.random.normal(scale=0.02, size=(3, 1, 1))
227 |         patch1 += rgbshift1
228 |         patch2 += rgbshift2
229 | 
230 |         noise1 = np.random.normal(scale=0.02, size=(3, K, K))
231 |         noise2 = np.random.normal(scale=0.02, size=(3, K, K))
232 |         patch1 += noise1
233 |         patch2 += noise2
234 | 
235 |         return patch1, patch2, pos
--------------------------------------------------------------------------------
/codes/inspection.py:
--------------------------------------------------------------------------------
1 | from codes import mvtecad
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from .utils import PatchDataset_NCHW, NHWC2NCHW, distribute_scores
6 | 
7 | __all__ = ['eval_encoder_NN_multiK', 'eval_embeddings_NN_multiK']
8 | 
9 | 
10 | def weights_e_alpha(l2_maps):  # distance-aware weighting of the NN distances
11 |     N, I, J, NN = l2_maps.shape
12 |     weights_alpha = np.ones((N, I, J, NN), dtype=np.float32)
13 |     weights_1_alpha = np.ones((N, I, J, NN), dtype=np.float32)
14 |     weights_e_1_alpha = np.ones((N, I, J, NN), dtype=np.float32)
15 | 
16 |     weights = np.ones((N, I, J, NN), dtype=np.float32)
17 | 
18 |     l2_maps_e_1_alpha_sum = np.ones((N, I, J, NN), dtype=np.float32)
19 |     result_NN = np.ones((N, I, J, NN), dtype=np.float32)
20 | 
21 |     l2_maps_sum = np.sum(l2_maps, axis=-1)
22 |     l2_maps_sum[l2_maps_sum == 0] = 1  # avoid division by zero
23 | 
24 | 
25 |     for n in range(N):
26 |         for i in range(I):
27 |             for j in range(J):
28 |                 weights_alpha[n, i, j, :] = l2_maps[n, i, j, :] / l2_maps_sum[n, i, j]
29 |                 weights_alpha[weights_alpha == 0] = 1
30 |                 weights_1_alpha[n, i, j, :] = 1 / weights_alpha[n, i, j, :]
31 |                 weights_1_alpha[weights_1_alpha > 20] = 15  # cap the inverse weights
32 |                 weights_e_1_alpha[n, i, j, :] = np.exp(weights_1_alpha[n, i, j, :])
33 |                 l2_maps_e_1_alpha_sum[n, i, j] = np.sum(weights_e_1_alpha[n, i, j, :], axis=-1)
34 |                 weights[n, i, j, :] = weights_e_1_alpha[n, i, j, :] / l2_maps_e_1_alpha_sum[n, i, j]
35 |                 result_NN[n, i, j, :] = l2_maps[n, i, j, :] * weights[n, i, j, :]
36 | 
37 |     result = np.sum(result_NN, axis=-1)  # weighted sum over the NN axis -> (N, I, J)
38 |     return result
39 | 
40 | def infer(x, enc, K, S):  # one embedding per K x K patch, stride S
41 |     x = NHWC2NCHW(x)
42 | 
43 |     dataset = PatchDataset_NCHW(x, K=K, S=S)
44 |     loader = DataLoader(dataset, batch_size=64, shuffle=False, pin_memory=True)
45 | 
46 |     embs = np.empty((dataset.N, dataset.row_num, dataset.col_num, enc.D), dtype=np.float32)  # [-1, I, J, D]
47 | 
48 |     enc = enc.eval()
49 |     with torch.no_grad():
50 |         for xs, ns, iis, js in loader:
51 |             xs = xs.cuda()
52 | 
53 |             embedding = enc(xs)
54 |             embedding = embedding.detach().cpu().numpy()
55 | 
56 |             for embed, n, i, j in zip(embedding, ns, iis, js):
57 |                 embs[n, i, j] = np.squeeze(embed)
58 |     return embs
59 | 
60 | 
61 | def assess_anomaly_maps(obj,
anomaly_maps): 62 | auroc_seg = mvtecad.segmentation_auroc(obj, anomaly_maps) 63 | anomaly_scores = anomaly_maps.max(axis=-1).max(axis=-1) 64 | auroc_det = mvtecad.detection_auroc(obj, anomaly_scores) 65 | return auroc_det, auroc_seg 66 | 67 | 68 | def measure_emb_NN(emb_te, emb_tr, method='kdt', NN=1): 69 | from .nearest_neighbor import search_NN 70 | 71 | D = emb_tr.shape[-1] 72 | 73 | train_emb_all = emb_tr.reshape(-1, D) 74 | 75 | l2_maps, _ = search_NN(emb_te, train_emb_all, method=method, NN=NN) 76 | 77 | anomaly_maps = weights_e_alpha(l2_maps) 78 | 79 | return anomaly_maps 80 | 81 | 82 | 83 | def eval_encoder_NN_multiK(enc, obj, maps_num): 84 | 85 | x_tr = mvtecad.get_x_standardized(obj, mode='train') 86 | x_te = mvtecad.get_x_standardized(obj, mode='test') 87 | 88 | embs64_tr = infer(x_tr, enc, K=64, S=16) 89 | embs64_te = infer(x_te, enc, K=64, S=16) 90 | 91 | embs32_tr = infer(x_tr, enc.enc, K=32, S=4) 92 | embs32_te = infer(x_te, enc.enc, K=32, S=4) 93 | 94 | 95 | embs64 = embs64_tr, embs64_te 96 | embs32 = embs32_tr, embs32_te 97 | 98 | 99 | return eval_embeddings_NN_multiK(obj, embs64, embs32, NN=maps_num) 100 | 101 | 102 | def eval_embeddings_NN_multiK(obj, embs64, embs32, NN=1): 103 | emb_tr, emb_te = embs64 104 | 105 | maps_64 = measure_emb_NN(emb_te, emb_tr, method='kdt', NN=NN) 106 | maps_64 = distribute_scores(maps_64, (256, 256), K=64, S=16) 107 | det_64, seg_64 = assess_anomaly_maps(obj, maps_64) 108 | 109 | emb_tr, emb_te = embs32 110 | maps_32 = measure_emb_NN(emb_te, emb_tr, method='ngt', NN=NN) 111 | maps_32 = distribute_scores(maps_32, (256, 256), K=32, S=4) 112 | det_32, seg_32 = assess_anomaly_maps(obj, maps_32) 113 | 114 | maps_sum = maps_64 + maps_32 115 | det_sum, seg_sum = assess_anomaly_maps(obj, maps_sum) 116 | 117 | maps_mult = maps_64 * maps_32 118 | det_mult, seg_mult = assess_anomaly_maps(obj, maps_mult) 119 | 120 | return { 121 | 'det_64': det_64, 122 | 'seg_64': seg_64, 123 | 124 | 'det_32': det_32, 125 | 'seg_32': seg_32, 126 | 127 | 'det_sum': det_sum, 128 | 'seg_sum': seg_sum, 129 | 130 | 'det_mult': det_mult, 131 | 'seg_mult': seg_mult, 132 | 133 | 'maps_64': maps_64, 134 | 'maps_32': maps_32, 135 | 'maps_sum': maps_sum, 136 | 'maps_mult': maps_mult, 137 | } 138 | -------------------------------------------------------------------------------- /codes/mvtecad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | from imageio import imread 5 | from glob import glob 6 | from sklearn.metrics import roc_auc_score 7 | import os 8 | 9 | 10 | DATASET_PATH = './data/MVTec' 11 | 12 | 13 | __all__ = ['objs', 'set_root_path', 14 | 'get_x', 'get_x_standardized', 15 | 'detection_auroc', 'segmentation_auroc'] 16 | 17 | 18 | objs = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 19 | 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 20 | 'transistor', 'wood', 'zipper'] 21 | 22 | 23 | def resize(image, shape=(256, 256)): 24 | return np.array(Image.fromarray(image).resize(shape[::-1])) 25 | 26 | 27 | 28 | def bilinears(images, shape) -> np.ndarray: 29 | import cv2 30 | N = images.shape[0] 31 | 32 | new_shape = (N,) + shape 33 | ret = np.zeros(new_shape, dtype=images.dtype) 34 | for i in range(N): 35 | 36 | ret[i] = cv2.resize(images[i], dsize=shape[::-1], interpolation=cv2.INTER_LINEAR) 37 | return ret 38 | 39 | 40 | 41 | def gray2rgb(images): 42 | tile_shape = tuple(np.ones(len(images.shape), dtype=int)) 43 | tile_shape += (3,) 44 | 45 | images = 
np.tile(np.expand_dims(images, axis=-1), tile_shape) 46 | 47 | return images 48 | 49 | 50 | 51 | def set_root_path(new_path): 52 | global DATASET_PATH 53 | DATASET_PATH = new_path 54 | 55 | 56 | def get_x(obj, mode='train'): 57 | 58 | fpattern = os.path.join(DATASET_PATH, f'{obj}/{mode}/*/*.png') 59 | 60 | fpaths = sorted(glob(fpattern)) 61 | 62 | if mode == 'test': 63 | 64 | fpaths1 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths)) 65 | fpaths2 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) == 'good', fpaths)) 66 | 67 | images1 = np.asarray(list(map(imread, fpaths1))) 68 | images2 = np.asarray(list(map(imread, fpaths2))) 69 | images = np.concatenate([images1, images2]) 70 | 71 | else: 72 | images = np.asarray(list(map(imread, fpaths))) 73 | 74 | if images.shape[-1] != 3: 75 | images = gray2rgb(images) 76 | images = list(map(resize, images)) 77 | images = np.asarray(images) 78 | return images 79 | 80 | 81 | def get_x_standardized(obj, mode='train'): 82 | x = get_x(obj, mode=mode) 83 | mean = get_mean(obj) 84 | return (x.astype(np.float32) - mean) / 255 85 | 86 | 87 | def get_label(obj): 88 | fpattern = os.path.join(DATASET_PATH, f'{obj}/test/*/*.png') 89 | fpaths = sorted(glob(fpattern)) 90 | fpaths1 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths)) 91 | fpaths2 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) == 'good', fpaths)) 92 | 93 | Nanomaly = len(fpaths1) 94 | Nnormal = len(fpaths2) 95 | labels = np.zeros(Nanomaly + Nnormal, dtype=np.int32) 96 | labels[:Nanomaly] = 1 97 | return labels 98 | 99 | 100 | def get_mask(obj): 101 | fpattern = os.path.join(DATASET_PATH, f'{obj}/ground_truth/*/*.png') 102 | fpaths = sorted(glob(fpattern)) 103 | masks = np.asarray(list(map(lambda fpath: resize(imread(fpath), (256, 256)), fpaths))) 104 | Nanomaly = masks.shape[0] 105 | Nnormal = len(glob(os.path.join(DATASET_PATH, f'{obj}/test/good/*.png'))) 106 | 107 | masks[masks <= 128] = 0 108 | masks[masks > 128] = 255 109 | results = np.zeros((Nanomaly + Nnormal,) + masks.shape[1:], dtype=masks.dtype) 110 | results[:Nanomaly] = masks 111 | 112 | return results 113 | 114 | 115 | def get_mean(obj): 116 | images = get_x(obj, mode='train') 117 | mean = images.astype(np.float32).mean(axis=0) 118 | return mean 119 | 120 | 121 | 122 | def detection_auroc(obj, anomaly_scores): 123 | label = get_label(obj) # 1: anomaly 0: normal 124 | auroc = roc_auc_score(label, anomaly_scores) 125 | return auroc 126 | 127 | 128 | def segmentation_auroc(obj, anomaly_maps): 129 | gt = get_mask(obj) 130 | gt = gt.astype(np.int32) 131 | gt[gt == 255] = 1 132 | 133 | anomaly_maps = bilinears(anomaly_maps, (256, 256)) 134 | auroc = roc_auc_score(gt.flatten(), anomaly_maps.flatten()) 135 | return auroc 136 | 137 | -------------------------------------------------------------------------------- /codes/nearest_neighbor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import shutil 3 | import os 4 | 5 | 6 | __all__ = ['search_NN'] 7 | 8 | 9 | 10 | def search_NN(test_emb, train_emb_flat, NN=1, method='kdt'): 11 | if method == 'ngt': 12 | return search_NN_ngt(test_emb, train_emb_flat, NN=NN) 13 | 14 | from sklearn.neighbors import KDTree 15 | kdt = KDTree(train_emb_flat) 16 | 17 | Ntest, I, J, D = test_emb.shape 18 | closest_inds = np.empty((Ntest, I, J, NN), dtype=np.int32) 19 | l2_maps = np.empty((Ntest, I, J, NN), dtype=np.float32) 20 | 21 | for n in 
range(Ntest):
22 |         for i in range(I):
23 |             dists, inds = kdt.query(test_emb[n, i, :, :], return_distance=True, k=NN)  # query one grid row at a time
24 |             closest_inds[n, i, :, :] = inds[:, :]
25 |             l2_maps[n, i, :, :] = dists[:, :]
26 | 
27 |     return l2_maps, closest_inds
28 | 
29 | 
30 | def search_NN_ngt(test_emb, train_emb_flat, NN=1):
31 |     import ngtpy
32 | 
33 |     Ntest, I, J, D = test_emb.shape
34 |     closest_inds = np.empty((Ntest, I, J, NN), dtype=np.int32)
35 |     l2_maps = np.empty((Ntest, I, J, NN), dtype=np.float32)
36 | 
37 |     dpath = f'/tmp/{os.getpid()}'  # throwaway on-disk NGT index
38 |     ngtpy.create(dpath, D)
39 |     index = ngtpy.Index(dpath)
40 |     index.batch_insert(train_emb_flat)
41 | 
42 |     for n in range(Ntest):
43 |         for i in range(I):
44 |             for j in range(J):
45 |                 query = test_emb[n, i, j, :]
46 |                 results = index.search(query, NN)
47 |                 inds = [result[0] for result in results]
48 | 
49 |                 closest_inds[n, i, j, :] = inds
50 |                 vecs = np.asarray([index.get_object(inds[nn]) for nn in range(NN)])
51 |                 dists = np.linalg.norm(query - vecs, axis=-1)
52 |                 l2_maps[n, i, j, :] = dists
53 |     shutil.rmtree(dpath)
54 | 
55 |     return l2_maps, closest_inds
--------------------------------------------------------------------------------
/codes/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import math
5 | from .utils import makedirpath
6 | from .LRN import LRN
7 | from .datasets import jigsaw_num
8 | 
9 | __all__ = ['MyJigsawPositionHierEncoder', 'MyJigsawPositionDeepEncoder', 'MyJigsawPositionEncoder', 'MyJigsawPositionClassifier']
10 | 
11 | 
12 | 
13 | def forward_hier(x, emb_small, K):  # embed the four K/2 quadrants, then re-tile them spatially
14 |     K_2 = K // 2
15 |     n = x.size(0)
16 |     x1 = x[..., :K_2, :K_2]
17 |     x2 = x[..., :K_2, K_2:]
18 |     x3 = x[..., K_2:, :K_2]
19 |     x4 = x[..., K_2:, K_2:]
20 |     xx = torch.cat([x1, x2, x3, x4], dim=0)
21 | 
22 |     hh = emb_small(xx)
23 | 
24 | 
25 |     h1 = hh[:n]
26 |     h2 = hh[n: 2 * n]
27 |     h3 = hh[2 * n: 3 * n]
28 |     h4 = hh[3 * n:]
29 | 
30 | 
31 |     h12 = torch.cat([h1, h2], dim=3)
32 |     h34 = torch.cat([h3, h4], dim=3)
33 | 
34 |     h = torch.cat([h12, h34], dim=2)
35 |     return h
36 | 
37 | 
38 | 
39 | xent = nn.CrossEntropyLoss()
40 | 
41 | class NormalizedLinear(nn.Module):
42 |     __constants__ = ['bias', 'in_features', 'out_features']
43 | 
44 |     def __init__(self, in_features, out_features, bias=True):
45 |         super(NormalizedLinear, self).__init__()
46 |         self.in_features = in_features
47 |         self.out_features = out_features
48 |         # weight has shape (out_features, in_features)
49 |         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
50 |         if bias:
51 |             # bias has shape (out_features,)
52 |             self.bias = nn.Parameter(torch.Tensor(out_features))
53 |         else:
54 |             self.register_parameter('bias', None)
55 |         self.reset_parameters()
56 | 
57 |     def reset_parameters(self):
58 |         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
59 |         if self.bias is not None:
60 |             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
61 |             bound = 1 / math.sqrt(fan_in)
62 |             nn.init.uniform_(self.bias, -bound, bound)
63 | 
64 |     def forward(self, x):
65 |         with torch.no_grad():  # normalize the weight columns without tracking gradients
66 |             w = self.weight / self.weight.data.norm(keepdim=True, dim=0)
67 |         return F.linear(x, w, self.bias)
68 | 
69 |     def extra_repr(self):
70 |         return 'in_features={}, out_features={}, bias={}'.format(
71 |             self.in_features, self.out_features, self.bias is not None
72 |         )
73 | 
74 | 
75 | class MyJigsawPositionEncoder(nn.Module):
76 |     def __init__(self, K, D=64, bias=True):
77 |         super().__init__()
78 | 
79 | 
80 |         self.conv1 = nn.Conv2d(3, 64, 5, 2, 0, bias=bias)
81 |         self.conv2 = 
nn.Conv2d(64, 64, 5, 2, 0, bias=bias) 82 | self.conv3 = nn.Conv2d(64, 128, 5, 2, 0, bias=bias) 83 | self.conv4 = nn.Conv2d(128, D, 5, 1, 0, bias=bias) 84 | 85 | self.K = K 86 | self.D = D 87 | self.bias = bias 88 | 89 | def forward(self, x): 90 | h = self.conv1(x) 91 | h = F.leaky_relu(h, 0.1) 92 | 93 | h = self.conv2(h) 94 | h = F.leaky_relu(h, 0.1) 95 | 96 | h = self.conv3(h) 97 | 98 | 99 | if self.K == 64: 100 | h = F.leaky_relu(h, 0.1) 101 | h = self.conv4(h) 102 | 103 | h = torch.tanh(h) 104 | 105 | return h 106 | 107 | class MyJigsawPositionDeepEncoder(nn.Module): 108 | def __init__(self, K, D=64, bias=True): 109 | super().__init__() 110 | self.conv = nn.Sequential( 111 | nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=0), 112 | nn.GroupNorm(12, 96), 113 | nn.ReLU(inplace=True), 114 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 115 | LRN(local_size=5, alpha=0.0001, beta=0.75), 116 | 117 | nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2), 118 | nn.GroupNorm(32, 256), 119 | nn.ReLU(inplace=True), 120 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0), 121 | LRN(local_size=5, alpha=0.0001, beta=0.75), 122 | 123 | nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1), 124 | nn.GroupNorm(48, 384), 125 | nn.ReLU(inplace=True), 126 | 127 | nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1), 128 | nn.GroupNorm(48, 384), 129 | nn.ReLU(inplace=True), 130 | 131 | nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1), 132 | nn.GroupNorm(32, 256), 133 | nn.ReLU(inplace=True), 134 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 135 | 136 | nn.Conv2d(256, 128, kernel_size=2, stride=1, padding=0), 137 | nn.GroupNorm(16, 128), 138 | 139 | nn.Conv2d(128, D, kernel_size=2, stride=1, padding=0), 140 | nn.GroupNorm(int(D/8), D) 141 | ) 142 | 143 | self.K = K 144 | self.D = D 145 | 146 | def forward(self, x): 147 | x = self.conv(x) 148 | 149 | x = torch.tanh(x) 150 | 151 | return x 152 | 153 | def save(self, name): 154 | fpath = self.fpath_from_name(name) 155 | makedirpath(fpath) 156 | torch.save(self.state_dict(), fpath) 157 | 158 | def load(self, name): 159 | fpath = self.fpath_from_name(name) 160 | self.load_state_dict(torch.load(fpath)) 161 | 162 | @staticmethod 163 | def fpath_from_name(name): 164 | return f'ckpts/{name}/myjigsawposition_encdeep.pkl' 165 | 166 | class MyJigsawPositionHierEncoder(nn.Module): 167 | def __init__(self, K, D=64, bias=True): 168 | super().__init__() 169 | 170 | if K > 64: 171 | self.enc = MyJigsawPositionHierEncoder(K // 2, D, bias=bias) 172 | 173 | 174 | elif K == 64: 175 | self.enc = MyJigsawPositionDeepEncoder(K // 2, D, bias=bias) 176 | 177 | else: 178 | raise ValueError() 179 | 180 | self.conv1 = nn.Conv2d(D, 128, 2, 1, 0, bias=bias) 181 | self.conv2 = nn.Conv2d(128, D, 1, 1, 0, bias=bias) 182 | 183 | self.K = K 184 | self.D = D 185 | 186 | def forward(self, x): 187 | 188 | h = forward_hier(x, self.enc, K=self.K) 189 | 190 | 191 | h = self.conv1(h) 192 | h = F.leaky_relu(h, 0.1) 193 | 194 | h = self.conv2(h) 195 | h = torch.tanh(h) 196 | 197 | return h 198 | 199 | def save(self, name, i): 200 | fpath = self.fpath_from_name(name, i) 201 | makedirpath(fpath) 202 | torch.save(self.state_dict(), fpath) 203 | 204 | def load(self, name, i): 205 | fpath = self.fpath_from_name(name, i) 206 | self.load_state_dict(torch.load(fpath)) 207 | print('Encoder has been loaded!') 208 | 209 | @staticmethod 210 | def fpath_from_name(name, i): 211 | return f'ckpts/{name}/myjigsawposition_enchier_{i}_step.pkl' 212 | 213 | class 
MyJigsawPositionClassifier(nn.Module): 214 | def __init__(self, K, D, class_num=12): 215 | super().__init__() 216 | self.D = D 217 | 218 | self.fc1 = nn.Linear(D, 128) 219 | self.act1 = nn.LeakyReLU(0.1) 220 | 221 | self.fc2 = nn.Linear(128, 128) 222 | self.act2 = nn.LeakyReLU(0.1) 223 | 224 | self.fc3 = NormalizedLinear(128, class_num) 225 | 226 | self.K = K 227 | 228 | def save(self, name): 229 | fpath = self.fpath_from_name(name) 230 | makedirpath(fpath) 231 | torch.save(self.state_dict(), fpath) 232 | 233 | def load(self, name): 234 | fpath = self.fpath_from_name(name) 235 | self.load_state_dict(torch.load(fpath)) 236 | 237 | def fpath_from_name(self, name): 238 | return f'ckpts/{name}/position_classifier_K{self.K}.pkl' 239 | 240 | @staticmethod 241 | def infer(c, enc, batch): 242 | 243 | x1s, x2s, ys = batch 244 | ys = ys.long().cuda() 245 | 246 | 247 | h1 = enc(x1s) 248 | h2 = enc(x2s) 249 | 250 | logits = c(h1, h2) 251 | 252 | loss = xent(logits, ys) 253 | return loss 254 | 255 | def forward(self, h1, h2): 256 | h1 = h1.view(-1, self.D) 257 | h2 = h2.view(-1, self.D) 258 | 259 | 260 | h = h1 - h2 261 | 262 | h = self.fc1(h) 263 | h = self.act1(h) 264 | 265 | h = self.fc2(h) 266 | h = self.act2(h) 267 | 268 | h = self.fc3(h) 269 | 270 | return h -------------------------------------------------------------------------------- /codes/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import _pickle as p 4 | from torch.utils.data import Dataset 5 | import torch 6 | from contextlib import contextmanager 7 | from PIL import Image 8 | 9 | 10 | __all__ = ['crop_image_CHW', 'PatchDataset_NCHW', 'NHWC2NCHW_normalize', 'NHWC2NCHW', 11 | 'save_binary', 'load_binary', 'makedirpath', 'task', 'DictionaryConcatDataset', 12 | 'to_device', 'distribute_scores', 'resize'] 13 | 14 | 15 | def to_device(obj, device, non_blocking=False): 16 | 17 | 18 | if isinstance(obj, torch.Tensor): 19 | return obj.to(device, non_blocking=non_blocking) 20 | 21 | 22 | if isinstance(obj, dict): 23 | 24 | return {k: to_device(v, device, non_blocking=non_blocking) 25 | for k, v in obj.items()} 26 | 27 | 28 | if isinstance(obj, list): 29 | return [to_device(v, device, non_blocking=non_blocking) 30 | for v in obj] 31 | 32 | 33 | if isinstance(obj, tuple): 34 | return tuple([to_device(v, device, non_blocking=non_blocking) 35 | for v in obj]) 36 | 37 | 38 | @contextmanager 39 | def task(_): 40 | yield 41 | 42 | 43 | class DictionaryConcatDataset(Dataset): 44 | 45 | def __init__(self, d_of_datasets): 46 | self.d_of_datasets = d_of_datasets 47 | lengths = [len(d) for d in d_of_datasets.values()] 48 | self._length = min(lengths) 49 | self.keys = self.d_of_datasets.keys() 50 | assert min(lengths) == max(lengths), 'Length of the datasets should be the same' 51 | 52 | def __getitem__(self, idx): 53 | return { 54 | key: self.d_of_datasets[key][idx] 55 | for key in self.keys 56 | } 57 | 58 | def __len__(self): 59 | return self._length 60 | 61 | def crop_CHW(image, i, j, K, S=1): 62 | if S == 1: 63 | h, w = i, j 64 | else: 65 | h = S * i 66 | w = S * j 67 | return image[:, h: h + K, w: w + K] 68 | 69 | 70 | def cnn_output_size(H, K, S=1, P=0) -> int: 71 | """ 72 | 73 | :param int H: input_size 74 | :param int K: filter_size 75 | :param int S: stride 76 | :param int P: padding 77 | :return: 78 | """ 79 | return 1 + (H - K + 2 * P) // S 80 | 81 | def crop_image_CHW(image, coord, K): 82 | h, w = coord 83 | return image[:, h: h + K, w: w + K] 84 | 85 | 86 | 
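# A worked example (not part of this file) of the patch grid that
# PatchDataset_NCHW below enumerates, using cnn_output_size:
#
#     for 256x256 images with K=64, S=16 (the K64 setting in inspection.py):
#         I = J = 1 + (256 - 64) // 16 = 13    -> a 13 x 13 grid of patches
#     for K=32, S=4 (the K32 setting):
#         I = J = 1 + (256 - 32) // 4 = 57     -> a 57 x 57 grid of patches
#
# so inspection.infer() stores one D-dimensional embedding per grid cell.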
class PatchDataset_NCHW(Dataset): 87 | def __init__(self, memmap, tfs=None, K=32, S=1): 88 | super().__init__() 89 | self.arr = memmap 90 | self.tfs = tfs 91 | self.S = S 92 | self.K = K 93 | self.N = self.arr.shape[0] 94 | 95 | def __len__(self): 96 | return self.N * self.row_num * self.col_num 97 | 98 | @property 99 | def row_num(self): 100 | N, C, H, W = self.arr.shape 101 | K = self.K 102 | S = self.S 103 | I = cnn_output_size(H, K=K, S=S) 104 | return I 105 | 106 | @property 107 | def col_num(self): 108 | N, C, H, W = self.arr.shape 109 | K = self.K 110 | S = self.S 111 | J = cnn_output_size(W, K=K, S=S) 112 | return J 113 | 114 | 115 | def __getitem__(self, idx): 116 | N = self.N 117 | n, i, j = np.unravel_index(idx, (N, self.row_num, self.col_num)) 118 | K = self.K 119 | S = self.S 120 | 121 | image = self.arr[n] 122 | 123 | patch = crop_CHW(image, i, j, K, S) 124 | 125 | if self.tfs: 126 | patch = self.tfs(patch) 127 | 128 | return patch, n, i, j 129 | 130 | 131 | 132 | def NHWC2NCHW_normalize(x): 133 | x = (x / 255.).astype(np.float32) 134 | return np.transpose(x, [0, 3, 1, 2]) 135 | 136 | 137 | 138 | def NHWC2NCHW(x): 139 | return np.transpose(x, [0, 3, 1, 2]) 140 | 141 | 142 | def load_binary(fpath, encoding='ASCII'): 143 | with open(fpath, 'rb') as f: 144 | return p.load(f, encoding=encoding) 145 | 146 | 147 | def save_binary(d, fpath): 148 | with open(fpath, 'wb') as f: 149 | p.dump(d, f) 150 | 151 | 152 | def makedirpath(fpath: str): 153 | 154 | dpath = os.path.dirname(fpath) 155 | if dpath: 156 | os.makedirs(dpath, exist_ok=True) 157 | 158 | 159 | def distribute_scores(score_masks, output_shape, K: int, S: int) -> np.ndarray: 160 | N = score_masks.shape[0] 161 | results = [distribute_score(score_masks[n], output_shape, K, S) for n in range(N)] 162 | return np.asarray(results) 163 | 164 | 165 | def distribute_score(score_mask, output_shape, K: int, S: int) -> np.ndarray: 166 | H, W = output_shape 167 | mask = np.zeros([H, W], dtype=np.float32) 168 | cnt = np.zeros([H, W], dtype=np.int32) 169 | 170 | I, J = score_mask.shape[:2] 171 | for i in range(I): 172 | for j in range(J): 173 | h, w = i * S, j * S 174 | 175 | 176 | mask[h: h + K, w: w + K] += score_mask[i, j] 177 | 178 | cnt[h: h + K, w: w + K] += 1 179 | 180 | 181 | cnt[cnt == 0] = 1 182 | 183 | return mask / cnt 184 | 185 | 186 | def resize(image, shape=(256, 256)): 187 | 188 | return np.array(Image.fromarray(image).resize(shape[::-1])) 189 | -------------------------------------------------------------------------------- /data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangpeng000/VisualInspection/a5933402284662bbab1d218c188cb16788de6a4e/data.npy -------------------------------------------------------------------------------- /doc/svdd_result.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangpeng000/VisualInspection/a5933402284662bbab1d218c188cb16788de6a4e/doc/svdd_result.jpeg -------------------------------------------------------------------------------- /heat_map.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | from codes import mvtecad 4 | from tqdm import tqdm 5 | from codes.utils import resize, makedirpath 6 | 7 | from skimage import morphology 8 | from skimage.segmentation import mark_boundaries 9 | import os 10 | import numpy as np 11 | import matplotlib 12 | from 
scipy.ndimage import gaussian_filter 13 | from sklearn.metrics import precision_recall_curve 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--obj', default='transistor') 17 | args = parser.parse_args() 18 | 19 | 20 | def plot_fig(test_img, scores, gts, threshold, obj): 21 | 22 | num = len(scores) 23 | vmax = scores.max() * 255. 24 | vmin = scores.min() * 255. 25 | for i in range(num): 26 | img = test_img[i] 27 | gt = gts[i] 28 | heat_map = scores[i] * 255 29 | mask = scores[i] 30 | mask[mask > threshold] = 1 31 | mask[mask <= threshold] = 0 32 | kernel = morphology.disk(4) 33 | mask = morphology.opening(mask, kernel) 34 | mask *= 255 35 | vis_img = mark_boundaries(img, mask, color=(1, 0, 0), mode='thick') 36 | fig_img, ax_img = plt.subplots(1, 5, figsize=(20, 5)) 37 | fig_img.subplots_adjust(right=0.9) 38 | norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) 39 | for ax_i in ax_img: 40 | ax_i.axes.xaxis.set_visible(False) 41 | ax_i.axes.yaxis.set_visible(False) 42 | ax_img[0].imshow(img) 43 | ax_img[1].imshow(gt, cmap='gray') 44 | ax = ax_img[2].imshow(heat_map, cmap='jet', norm=norm) 45 | ax_img[2].imshow(img, cmap='gray', interpolation='none') 46 | ax_img[2].imshow(heat_map, cmap='jet', alpha=0.5, interpolation='none') 47 | ax_img[3].imshow(mask, cmap='gray') 48 | ax_img[4].imshow(vis_img) 49 | left = 0.92 50 | bottom = 0.15 51 | width = 0.015 52 | height = 1 - 2 * bottom 53 | rect = [left, bottom, width, height] 54 | cbar_ax = fig_img.add_axes(rect) 55 | cb = plt.colorbar(ax, shrink=0.6, cax=cbar_ax, fraction=0.046) 56 | cb.ax.tick_params(labelsize=8) 57 | font = { 58 | 'family': 'serif', 59 | 'color': 'black', 60 | 'weight': 'normal', 61 | 'size': 8, 62 | } 63 | 64 | fpath = f'anomaly_maps/{obj}/{i:03d}.png' 65 | makedirpath(fpath) 66 | fig_img.savefig(fpath) 67 | plt.close() 68 | 69 | def denormalization(x): 70 | mean = np.array([0.485, 0.456, 0.406]) 71 | std = np.array([0.229, 0.224, 0.225]) 72 | x = (((x * std) + mean) * 255.).astype(np.uint8) 73 | 74 | return x 75 | 76 | 77 | def main(): 78 | from codes.inspection import eval_encoder_NN_multiK 79 | from codes.networks import MyJigsawPositionHierEncoder 80 | 81 | obj = args.obj 82 | 83 | enc = MyJigsawPositionHierEncoder(K=64, D=64).cuda() 84 | enc.load(obj, 0) 85 | enc.eval() 86 | results = eval_encoder_NN_multiK(enc, obj, 1) 87 | score_map = results['maps_mult'] 88 | 89 | images = mvtecad.get_x(obj, mode='test') 90 | 91 | masks = mvtecad.get_mask(obj) 92 | masks[masks==255] = 1 93 | 94 | 95 | for i in range(score_map.shape[0]): 96 | score_map[i] = gaussian_filter(score_map[i], sigma=2) 97 | 98 | max_score = score_map.max() 99 | min_score = score_map.min() 100 | scores = (score_map - min_score) / (max_score - min_score) 101 | 102 | gt_mask = np.asarray(masks) 103 | precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), scores.flatten()) 104 | a = 2 * precision * recall 105 | b = precision + recall 106 | f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0) 107 | threshold = thresholds[np.argmax(f1)] 108 | 109 | plot_fig(images, scores, masks, threshold, obj) 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | scikit-image 4 | scikit-learn 5 | torch 6 | tqdm 7 | Pillow 8 | imageio 9 | opencv-python 10 | ngt 
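Checkpoint layout expected by test.py and heat_map.py below: enc.load(obj, N) resolves, via MyJigsawPositionHierEncoder.fpath_from_name, to a per-epoch file written during training. Assuming training has saved at least epoch 0 for obj=cable, the layout looks like:

    ckpts/
    └── cable/
        ├── myjigsawposition_enchier_0_step.pkl
        ├── myjigsawposition_enchier_1_step.pkl
        └── ...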
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--obj', default='transistor') 6 | parser.add_argument('--maps_num', default=5, type=int) 7 | parser.add_argument('--gpu', default='3', type=str) 8 | 9 | args = parser.parse_args() 10 | 11 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 12 | 13 | def do_evaluate_encoder_multiK(obj): 14 | from codes.inspection import eval_encoder_NN_multiK 15 | from codes.networks import MyJigsawPositionHierEncoder 16 | 17 | enc = MyJigsawPositionHierEncoder(K=64, D=64).cuda() 18 | enc.load(obj, 0) 19 | enc.eval() 20 | for i in range(1, 21): 21 | results = eval_encoder_NN_multiK(enc, obj, i) 22 | 23 | det_64 = results['det_64'] 24 | seg_64 = results['seg_64'] 25 | 26 | det_32 = results['det_32'] 27 | seg_32 = results['seg_32'] 28 | 29 | det_sum = results['det_sum'] 30 | seg_sum = results['seg_sum'] 31 | 32 | det_mult = results['det_mult'] 33 | seg_mult = results['seg_mult'] 34 | 35 | print('Maps NUM is {}'.format(i)) 36 | print( 37 | f'| K64 | Det: {det_64:.3f} Seg:{seg_64:.3f} | K32 | Det: {det_32:.3f} Seg:{seg_32:.3f} | sum | Det: {det_sum:.3f} Seg:{seg_sum:.3f} | mult | Det: {det_mult:.3f} Seg:{seg_mult:.3f} ({obj})') 38 | 39 | 40 | ######################### 41 | 42 | 43 | def main(): 44 | do_evaluate_encoder_multiK(args.obj) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from codes import mvtecad 4 | from functools import reduce 5 | from torch.utils.data import DataLoader 6 | from codes.datasets import * 7 | from codes.networks import * 8 | from codes.inspection import eval_encoder_NN_multiK 9 | from codes.utils import * 10 | import os 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--obj', default='transistor_plus', type=str) 15 | parser.add_argument('--lambda_value', default=0.001, type=float) 16 | parser.add_argument('--D', default=64, type=int) 17 | 18 | parser.add_argument('--epochs', default=400, type=int) 19 | parser.add_argument('--lr', default=1e-4, type=float) 20 | 21 | parser.add_argument('--gpu', default='0', type=str) 22 | parser.add_argument('--maps_num', default=1, type=int) 23 | 24 | args = parser.parse_args() 25 | 26 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 27 | 28 | def train(): 29 | obj = args.obj 30 | D = args.D 31 | lr = args.lr 32 | 33 | with task('Networks'): 34 | 35 | enc = MyJigsawPositionHierEncoder(64, D).cuda() 36 | 37 | cls_64 = MyJigsawPositionClassifier(64, D).cuda() 38 | cls_32 = MyJigsawPositionClassifier(32, D).cuda() 39 | 40 | modules = [enc, cls_64, cls_32] 41 | params = [list(module.parameters()) for module in modules] 42 | 43 | params = reduce(lambda x, y: x + y, params) 44 | 45 | opt = torch.optim.Adam(params=params, lr=lr) 46 | 47 | with task('Datasets'): 48 | 49 | train_x = mvtecad.get_x_standardized(obj, mode='train') 50 | train_x = NHWC2NCHW(train_x) 51 | 52 | rep = 100 53 | datasets = dict() 54 | 55 | datasets[f'pos_64'] = MyJigsawPositionDataset(train_x, K=64, repeat=rep) 56 | datasets[f'pos_32'] = MyJigsawPositionDataset(train_x, K=32, repeat=rep) 57 | 58 | datasets[f'svdd_64'] = SVDD_Dataset(train_x, K=64, repeat=rep) 59 | 
datasets[f'svdd_32'] = SVDD_Dataset(train_x, K=32, repeat=rep)
60 | 
61 |     dataset = DictionaryConcatDataset(datasets)
62 |     loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
63 | 
64 |     print('Start training')
65 |     for i_epoch in range(args.epochs):
66 |         if i_epoch != 0:
67 |             for module in modules:
68 |                 module.train()
69 | 
70 |         for d in loader:
71 |             d = to_device(d, 'cuda', non_blocking=True)
72 |             opt.zero_grad()
73 | 
74 |             loss_pos_64 = MyJigsawPositionClassifier.infer(cls_64, enc, d['pos_64'])
75 |             loss_pos_32 = MyJigsawPositionClassifier.infer(cls_32, enc.enc, d['pos_32'])
76 | 
77 |             loss_svdd_64 = SVDD_Dataset.infer(enc, d['svdd_64'])
78 |             loss_svdd_32 = SVDD_Dataset.infer(enc.enc, d['svdd_32'])
79 | 
80 |             loss = loss_pos_64 + loss_pos_32 + args.lambda_value * (loss_svdd_64 + loss_svdd_32)  # position losses + lambda-weighted SVDD losses
81 | 
82 |             loss.backward()
83 |             opt.step()
84 | 
85 |         aurocs = eval_encoder_NN_multiK(enc, obj, args.maps_num)
86 | 
87 |         log_result(obj, aurocs)
88 |         enc.save(obj, i_epoch)
89 | 
90 | 
91 | def log_result(obj, aurocs):
92 |     det_64 = aurocs['det_64'] * 100
93 |     seg_64 = aurocs['seg_64'] * 100
94 | 
95 |     det_32 = aurocs['det_32'] * 100
96 |     seg_32 = aurocs['seg_32'] * 100
97 | 
98 |     det_sum = aurocs['det_sum'] * 100
99 |     seg_sum = aurocs['seg_sum'] * 100
100 | 
101 |     det_mult = aurocs['det_mult'] * 100
102 |     seg_mult = aurocs['seg_mult'] * 100
103 | 
104 |     print(
105 |         f'|K64| Det: {det_64:4.1f} Seg: {seg_64:4.1f} |K32| Det: {det_32:4.1f} Seg: {seg_32:4.1f} |sum| Det: {det_sum:4.1f} Seg: {seg_sum:4.1f} |mult| Det: {det_mult:4.1f} Seg: {seg_mult:4.1f} ({obj})')
106 | 
107 | 
108 | if __name__ == '__main__':
109 |     train()
110 | 
--------------------------------------------------------------------------------
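For reference, a minimal sketch for inspecting data.npy, assuming it is a 2-D integer array as the indexing in MyJigsawPositionDataset implies — each row reads as (index of patch 1, index of patch 2, relative-position label), with the indices selecting entries of the 10-tuple returned by generate_coords_position_mine:

    import numpy as np

    pairs = np.load('data.npy')
    print(pairs.shape)   # expected (num_pairs, 3)
    for a, b, label in pairs[:5]:
        print(f'patch {a} vs patch {b} -> position class {label}')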