├── README.md
├── codes
│   ├── LRN.py
│   ├── __init__.py
│   ├── datasets.py
│   ├── inspection.py
│   ├── mvtecad.py
│   ├── nearest_neighbor.py
│   ├── networks.py
│   └── utils.py
├── data.npy
├── doc
│   └── svdd_result.jpeg
├── heat_map.py
├── requirements.txt
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # SCL-VI: Self-supervised Context Learning for Visual Inspection of Industrial Defects
2 |
3 |
4 |
5 | We address the challenge of detecting industrial object defects with a self-supervised approach that learns visual context by solving a jigsaw-puzzle pretext task.
6 |
7 | ## Results
8 | ![results](doc/svdd_result.jpeg)
9 |
10 | ## Dependencies
11 | Since this project was done a while ago, the environment below may need some adjustment.
12 | - Tested with Python 3.8
13 | - [PyTorch](http://pytorch.org/) v1.6.0
14 |
15 | ## Dataset
16 | - Dataset : [MvTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad/)
17 |
18 | ## Run Training
19 | - python train.py --obj=cable --lambda_value=1 --D=64 --epochs=400 --lr=1e-4 --gpu=0
20 |
21 | ## Run Affinity Testing
22 | - python test.py --obj=cable --gpu=0
23 | - `enc.load(obj, N)`: `N` is the index of the training checkpoint to load (checkpoints are saved as `ckpts/<obj>/myjigsawposition_enchier_<N>_step.pkl`)
24 |
25 | ## Anomaly maps
26 | - python heat_map.py --obj=cable
27 | - `enc.load(obj, N)`: as in testing, `N` selects the training checkpoint to load
28 |
29 | ## Details
30 | - The network input must be 256x256.
31 | - data.npy stores pairs of indices into the ten patch positions sampled per image, together with their relative-position labels; see the sketch below.
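32 |
33 | A minimal sketch of its layout (assuming `data.npy` is read exactly as in `codes/datasets.py`):
34 |
35 | ```python
36 | import numpy as np
37 |
38 | npy = np.load('data.npy')
39 | row = npy[0]
40 | # row[0], row[1]: indices into the ten patch positions sampled per image
41 | # row[2]: the relative-position class label used by the classifier
42 | first_idx, second_idx, pos_label = row[0], row[1], row[2]
43 | ```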
--------------------------------------------------------------------------------
/codes/LRN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class LRN(nn.Module):
5 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
6 | super(LRN, self).__init__()
7 | self.ACROSS_CHANNELS = ACROSS_CHANNELS
8 | if ACROSS_CHANNELS:
9 | self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
10 | stride=1, padding=(int((local_size - 1.0) / 2), 0, 0))
11 | else:
12 | self.average = nn.AvgPool2d(kernel_size=local_size,
13 | stride=1, padding=int((local_size - 1.0) / 2))
14 | self.alpha = alpha
15 | self.beta = beta
16 |
17 |
18 | def forward(self, x):
19 | if self.ACROSS_CHANNELS:
20 | div = x.pow(2).unsqueeze(1)
21 | div = self.average(div).squeeze(1)
22 | div = div.mul(self.alpha).add(1.0).pow(self.beta)
23 | else:
24 | div = x.pow(2)
25 | div = self.average(div)
26 | div = div.mul(self.alpha).add(1.0).pow(self.beta)
27 | x = x.div(div)
28 | return x
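29 |
30 |
31 | if __name__ == '__main__':
32 |     # Minimal smoke test (an illustrative sketch, not part of training):
33 |     # LRN divides each activation by a local average of squared
34 |     # activations, so the output shape equals the input shape.
35 |     import torch
36 |     x = torch.randn(2, 64, 32, 32)
37 |     y = LRN(local_size=5, alpha=1e-4, beta=0.75)(x)
38 |     assert y.shape == x.shape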
--------------------------------------------------------------------------------
/codes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangpeng000/VisualInspection/a5933402284662bbab1d218c188cb16788de6a4e/codes/__init__.py
--------------------------------------------------------------------------------
/codes/datasets.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 |
8 | __all__ = ['SVDD_Dataset', 'PositionDataset', 'MyJigsawPositionDataset']
9 |
10 | jigsaw_num = 4
11 |
12 | def generate_coords(H, W, K):
13 | h = np.random.randint(0, H - K + 1)
14 | w = np.random.randint(0, W - K + 1)
15 | return h, w
16 |
17 |
18 | def generate_coords_mine(H, W, K):
19 | h = np.random.randint(K, H - K + 1)
20 | w = np.random.randint(K, W - K + 1)
21 | return h, w
22 |
23 | def generate_coords_position(H, W, K):
24 | with task('P1'):
25 | p1 = generate_coords(H, W, K)
26 | h1, w1 = p1
27 |
28 | pos = np.random.randint(8)
29 |
30 | with task('P2'):
31 | J = K // 4
32 |
33 | K3_4 = 3 * K // 4
34 | h_dir, w_dir = pos_to_diff[pos]
35 | h_del, w_del = np.random.randint(J, size=2)
36 |
37 | h_diff = h_dir * (h_del + K3_4)
38 | w_diff = w_dir * (w_del + K3_4)
39 |
40 | h2 = h1 + h_diff
41 | w2 = w1 + w_diff
42 |
43 | h2 = np.clip(h2, 0, H - K)
44 | w2 = np.clip(w2, 0, W - K)
45 |
46 | p2 = (h2, w2)
47 |
48 | return p1, p2, pos
49 |
50 |
51 | def generate_coords_svdd(H, W, K):
52 |
53 | with task('P1'):
54 | p1 = generate_coords(H, W, K)
55 | h1, w1 = p1
56 |
57 | with task('P2'):
58 | J = K // 32
59 |
60 | h_jit, w_jit = 0, 0
61 |
62 | while h_jit == 0 and w_jit == 0:
63 | h_jit = np.random.randint(-J, J + 1)
64 | w_jit = np.random.randint(-J, J + 1)
65 |
66 | h2 = h1 + h_jit
67 | w2 = w1 + w_jit
68 |
69 | h2 = np.clip(h2, 0, H - K)
70 | w2 = np.clip(w2, 0, W - K)
71 |
72 | p2 = (h2, w2)
73 |
74 | return p1, p2
75 |
76 |
77 | pos_to_diff = {
78 | 0: (-1, -1),
79 | 1: (-1, 0),
80 | 2: (-1, 1),
81 | 3: (0, -1),
82 | 4: (0, 1),
83 | 5: (1, -1),
84 | 6: (1, 0),
85 | 7: (1, 1)
86 | }
87 |
88 |
89 | def generate_coords_position_mine(H, W, K):
90 | with task('P_STD'):
91 | p_std = generate_coords(H, W, K)
92 | h_std, w_std = p_std
93 |
94 | with task('P_STD2'):
95 | J = K // 32
96 | h_jit, w_jit = 0, 0
97 |
98 | while h_jit == 0 and w_jit == 0:
99 | h_jit = np.random.randint(-J, J + 1)
100 | w_jit = np.random.randint(-J, J + 1)
101 |
102 | h_std2 = h_std + h_jit
103 | w_std2 = w_std + w_jit
104 |
105 | h_std2 = np.clip(h_std2, 0, H - K)
106 | w_std2 = np.clip(w_std2, 0, W - K)
107 |
108 | p_std2 = (h_std2, w_std2)
109 |
110 | with task('P0-P7'):
111 |     # The eight neighboring patches share the same sampling rule; only
112 |     # the direction (pos_to_diff) differs, so generate them in one loop.
113 |     J = K // 4
114 |     K3_4 = 3 * K // 4
115 |
116 |     ps = []
117 |     for pos in range(8):
118 |         h_dir, w_dir = pos_to_diff[pos]
119 |         h_del, w_del = np.random.randint(J, size=2)
120 |
121 |         h_diff = h_dir * (h_del + K3_4)
122 |         w_diff = w_dir * (w_del + K3_4)
123 |
124 |         h = np.clip(h_std + h_diff, 0, H - K)
125 |         w = np.clip(w_std + w_diff, 0, W - K)
126 |         ps.append((h, w))
127 |
128 |     p0, p1, p2, p3, p4, p5, p6, p7 = ps
129 |
130 | return p0, p1, p2, p3, p_std2, p4, p5, p6, p7, p_std
263 |
264 |
265 |
266 |
267 | class SVDD_Dataset(Dataset):
268 |
269 | def __init__(self, memmap, K=64, repeat=1):
270 | super().__init__()
271 | self.arr = np.asarray(memmap)
272 | self.K = K
273 | self.repeat = repeat
274 |
275 | def __len__(self):
276 | N = self.arr.shape[0]
277 | return N * self.repeat
278 |
279 | def __getitem__(self, idx):
280 | N = self.arr.shape[0]
281 | K = self.K
282 | n = idx % N
283 |
284 | p1, p2 = generate_coords_svdd(256, 256, K)
285 |
286 | image = self.arr[n]
287 |
288 | patch1 = crop_image_CHW(image, p1, K)
289 | patch2 = crop_image_CHW(image, p2, K)
290 |
291 | return patch1, patch2
292 |
293 | @staticmethod
294 | def infer(enc, batch):
295 |
296 | x1s, x2s = batch
297 | h1s = enc(x1s)
298 | h2s = enc(x2s)
299 | diff = h1s - h2s
300 | l2 = diff.norm(dim=1)
301 | loss = l2.mean()
302 |
303 | return loss
304 |
305 |
306 | class PositionDataset(Dataset):
307 |
308 | def __init__(self, x, K=64, repeat=1):
309 | super(PositionDataset, self).__init__()
310 | self.x = np.asarray(x)
311 | self.K = K
312 | self.repeat = repeat
313 |
314 | def __len__(self):
315 | N = self.x.shape[0]
316 | return N * self.repeat
317 |
318 | def __getitem__(self, idx):
319 | N = self.x.shape[0]
320 | K = self.K
321 | n = idx % N
322 |
323 | image = self.x[n]
324 | p1, p2, pos = generate_coords_position(256, 256, K)
325 |
326 | patch1 = crop_image_CHW(image, p1, K).copy()
327 | patch2 = crop_image_CHW(image, p2, K).copy()
328 |
329 | rgbshift1 = np.random.normal(scale=0.02, size=(3, 1, 1))
330 | rgbshift2 = np.random.normal(scale=0.02, size=(3, 1, 1))
331 |
332 | patch1 += rgbshift1
333 | patch2 += rgbshift2
334 |
335 |
336 | noise1 = np.random.normal(scale=0.02, size=(3, K, K))
337 | noise2 = np.random.normal(scale=0.02, size=(3, K, K))
338 |
339 | patch1 += noise1
340 | patch2 += noise2
341 |
342 | return patch1, patch2, pos
343 |
344 | class MyJigsawPositionDataset(Dataset):
345 | def __init__(self, x, K=64, repeat=1):
346 |     super(MyJigsawPositionDataset, self).__init__()
347 |     self.x = np.asarray(x)
348 |     self.K = K
349 |     self.repeat = repeat
350 |     self.npy = np.load('data.npy')  # load the pair-index table once, not per item
350 |
351 | def __len__(self):
352 | N = self.x.shape[0]
353 | return N * self.repeat
354 |
355 | def __getitem__(self, idx):
356 | N = self.x.shape[0]
357 | K = self.K
358 | n = idx % N
359 |
360 | image = self.x[n]
361 |
362 | position = generate_coords_position_mine(256, 256, K)
363 | npy = self.npy
364 | order = np.random.randint(len(npy))
365 |
366 | patch1 = crop_image_CHW(image, position[npy[order][0]], K).copy()
367 | patch2 = crop_image_CHW(image, position[npy[order][1]], K).copy()
368 | pos = npy[order][2]
369 |
370 | rgbshift1 = np.random.normal(scale=0.02, size=(3, 1, 1))
371 | rgbshift2 = np.random.normal(scale=0.02, size=(3, 1, 1))
372 |
373 | patch1 += rgbshift1
374 | patch2 += rgbshift2
375 |
376 |
377 | noise1 = np.random.normal(scale=0.02, size=(3, K, K))
378 | noise2 = np.random.normal(scale=0.02, size=(3, K, K))
379 |
380 | patch1 += noise1
381 | patch2 += noise2
382 |
383 | return patch1, patch2, pos
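384 |
385 |
386 | if __name__ == '__main__':
387 |     # Usage sketch with dummy data (illustrative; run from the repo root
388 |     # as `python -m codes.datasets`; real training feeds standardized
389 |     # MVTec images shaped (N, 3, 256, 256)):
390 |     fake_x = np.random.rand(4, 3, 256, 256).astype(np.float32)
391 |     ds = PositionDataset(fake_x, K=64, repeat=2)
392 |     patch1, patch2, pos = ds[0]
393 |     print(patch1.shape, patch2.shape, pos)  # (3, 64, 64) (3, 64, 64) 0..7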
--------------------------------------------------------------------------------
/codes/inspection.py:
--------------------------------------------------------------------------------
1 | from codes import mvtecad
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from .utils import PatchDataset_NCHW, NHWC2NCHW, distribute_scores
6 |
7 | __all__ = ['eval_encoder_NN_multiK', 'eval_embeddings_NN_multiK']
8 |
9 |
10 | def weights_e_alpha(l2_maps):
11 | N, I, J, NN = l2_maps.shape
12 | weights_alpha = np.ones((N, I, J, NN), dtype=np.float32)
13 | weights_1_alpha = np.ones((N, I, J, NN), dtype=np.float32)
14 | weights_e_1_alpha = np.ones((N, I, J, NN), dtype=np.float32)
15 |
16 | weights = np.ones((N, I, J, NN), dtype=np.float32)
17 |
18 | l2_maps_e_1_alpha_sum = np.ones((N, I, J, NN), dtype=np.float32)
19 | result_NN = np.ones((N, I, J, NN), dtype=np.float32)
20 |
21 | l2_maps_sum = np.sum(l2_maps, axis=-1)
22 | l2_maps_sum[l2_maps_sum == 0] = 1
23 |
24 |
25 | for n in range(N):
26 | for i in range(I):
27 | for j in range(J):
28 | weights_alpha[n, i, j, :] = l2_maps[n, i, j, :]/l2_maps_sum[n, i, j]
29 | weights_alpha[weights_alpha == 0] = 1
30 | weights_1_alpha[n, i, j, :] = 1 / weights_alpha[n, i, j, :]
31 | weights_1_alpha[weights_1_alpha > 20] = 15
32 | weights_e_1_alpha[n, i, j, :] = np.exp(weights_1_alpha[n, i, j, :])
33 | l2_maps_e_1_alpha_sum[n, i, j] = np.sum(weights_e_1_alpha[n, i, j, :], axis=-1)
34 | weights[n, i, j, :] = weights_e_1_alpha[n, i, j, :] / l2_maps_e_1_alpha_sum[n, i, j]
35 | result_NN[n, i, j, :] = l2_maps[n, i, j, :] * weights[n, i, j, :]
36 |
37 | result = np.sum(result_NN, axis=-1)
38 | return result
39 |
40 | def infer(x, enc, K, S):
41 | x = NHWC2NCHW(x)
42 |
43 | dataset = PatchDataset_NCHW(x, K=K, S=S)
44 | loader = DataLoader(dataset, batch_size=64, shuffle=False, pin_memory=True)
45 |
46 | embs = np.empty((dataset.N, dataset.row_num, dataset.col_num, enc.D), dtype=np.float32) # [-1, I, J, D]
47 |
48 | enc = enc.eval()
49 | with torch.no_grad():
50 | for xs, ns, iis, js in loader:
51 | xs = xs.cuda()
52 |
53 | embedding = enc(xs)
54 | embedding = embedding.detach().cpu().numpy()
55 |
56 | for embed, n, i, j in zip(embedding, ns, iis, js):
57 | embs[n, i, j] = np.squeeze(embed)
58 | return embs
59 |
60 |
61 | def assess_anomaly_maps(obj, anomaly_maps):
62 | auroc_seg = mvtecad.segmentation_auroc(obj, anomaly_maps)
63 | anomaly_scores = anomaly_maps.max(axis=-1).max(axis=-1)
64 | auroc_det = mvtecad.detection_auroc(obj, anomaly_scores)
65 | return auroc_det, auroc_seg
66 |
67 |
68 | def measure_emb_NN(emb_te, emb_tr, method='kdt', NN=1):
69 | from .nearest_neighbor import search_NN
70 |
71 | D = emb_tr.shape[-1]
72 |
73 | train_emb_all = emb_tr.reshape(-1, D)
74 |
75 | l2_maps, _ = search_NN(emb_te, train_emb_all, method=method, NN=NN)
76 |
77 | anomaly_maps = weights_e_alpha(l2_maps)
78 |
79 | return anomaly_maps
80 |
81 |
82 |
83 | def eval_encoder_NN_multiK(enc, obj, maps_num):
84 |
85 | x_tr = mvtecad.get_x_standardized(obj, mode='train')
86 | x_te = mvtecad.get_x_standardized(obj, mode='test')
87 |
88 | embs64_tr = infer(x_tr, enc, K=64, S=16)
89 | embs64_te = infer(x_te, enc, K=64, S=16)
90 |
91 | embs32_tr = infer(x_tr, enc.enc, K=32, S=4)
92 | embs32_te = infer(x_te, enc.enc, K=32, S=4)
93 |
94 |
95 | embs64 = embs64_tr, embs64_te
96 | embs32 = embs32_tr, embs32_te
97 |
98 |
99 | return eval_embeddings_NN_multiK(obj, embs64, embs32, NN=maps_num)
100 |
101 |
102 | def eval_embeddings_NN_multiK(obj, embs64, embs32, NN=1):
103 | emb_tr, emb_te = embs64
104 |
105 | maps_64 = measure_emb_NN(emb_te, emb_tr, method='kdt', NN=NN)
106 | maps_64 = distribute_scores(maps_64, (256, 256), K=64, S=16)
107 | det_64, seg_64 = assess_anomaly_maps(obj, maps_64)
108 |
109 | emb_tr, emb_te = embs32
110 | maps_32 = measure_emb_NN(emb_te, emb_tr, method='ngt', NN=NN)
111 | maps_32 = distribute_scores(maps_32, (256, 256), K=32, S=4)
112 | det_32, seg_32 = assess_anomaly_maps(obj, maps_32)
113 |
114 | maps_sum = maps_64 + maps_32
115 | det_sum, seg_sum = assess_anomaly_maps(obj, maps_sum)
116 |
117 | maps_mult = maps_64 * maps_32
118 | det_mult, seg_mult = assess_anomaly_maps(obj, maps_mult)
119 |
120 | return {
121 | 'det_64': det_64,
122 | 'seg_64': seg_64,
123 |
124 | 'det_32': det_32,
125 | 'seg_32': seg_32,
126 |
127 | 'det_sum': det_sum,
128 | 'seg_sum': seg_sum,
129 |
130 | 'det_mult': det_mult,
131 | 'seg_mult': seg_mult,
132 |
133 | 'maps_64': maps_64,
134 | 'maps_32': maps_32,
135 | 'maps_sum': maps_sum,
136 | 'maps_mult': maps_mult,
137 | }
138 |
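139 |
140 | if __name__ == '__main__':
141 |     # Shape sketch (illustrative; run from the repo root as
142 |     # `python -m codes.inspection`): weights_e_alpha aggregates the NN
143 |     # nearest-neighbor distances per patch, shape (N, I, J, NN), into one
144 |     # weighted score per patch, shape (N, I, J); 1/alpha values above 20
145 |     # are clamped to 15 before exponentiation.
146 |     fake_l2 = np.abs(np.random.randn(1, 5, 5, 3)).astype(np.float32) + 0.1
147 |     print(weights_e_alpha(fake_l2).shape)  # (1, 5, 5)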
--------------------------------------------------------------------------------
/codes/mvtecad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | from imageio import imread
5 | from glob import glob
6 | from sklearn.metrics import roc_auc_score
7 | import os
8 |
9 |
10 | DATASET_PATH = './data/MVTec'
11 |
12 |
13 | __all__ = ['objs', 'set_root_path',
14 | 'get_x', 'get_x_standardized',
15 | 'detection_auroc', 'segmentation_auroc']
16 |
17 |
18 | objs = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
19 | 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
20 | 'transistor', 'wood', 'zipper']
21 |
22 |
23 | def resize(image, shape=(256, 256)):
24 | return np.array(Image.fromarray(image).resize(shape[::-1]))
25 |
26 |
27 |
28 | def bilinears(images, shape) -> np.ndarray:
29 | import cv2
30 | N = images.shape[0]
31 |
32 | new_shape = (N,) + shape
33 | ret = np.zeros(new_shape, dtype=images.dtype)
34 | for i in range(N):
35 |
36 | ret[i] = cv2.resize(images[i], dsize=shape[::-1], interpolation=cv2.INTER_LINEAR)
37 | return ret
38 |
39 |
40 |
41 | def gray2rgb(images):
42 | tile_shape = tuple(np.ones(len(images.shape), dtype=int))
43 | tile_shape += (3,)
44 |
45 | images = np.tile(np.expand_dims(images, axis=-1), tile_shape)
46 |
47 | return images
48 |
49 |
50 |
51 | def set_root_path(new_path):
52 | global DATASET_PATH
53 | DATASET_PATH = new_path
54 |
55 |
56 | def get_x(obj, mode='train'):
57 |
58 | fpattern = os.path.join(DATASET_PATH, f'{obj}/{mode}/*/*.png')
59 |
60 | fpaths = sorted(glob(fpattern))
61 |
62 | if mode == 'test':
63 |
64 | fpaths1 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths))
65 | fpaths2 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) == 'good', fpaths))
66 |
67 | images1 = np.asarray(list(map(imread, fpaths1)))
68 | images2 = np.asarray(list(map(imread, fpaths2)))
69 | images = np.concatenate([images1, images2])
70 |
71 | else:
72 | images = np.asarray(list(map(imread, fpaths)))
73 |
74 | if images.shape[-1] != 3:
75 | images = gray2rgb(images)
76 | images = list(map(resize, images))
77 | images = np.asarray(images)
78 | return images
79 |
80 |
81 | def get_x_standardized(obj, mode='train'):
82 | x = get_x(obj, mode=mode)
83 | mean = get_mean(obj)
84 | return (x.astype(np.float32) - mean) / 255
85 |
86 |
87 | def get_label(obj):
88 | fpattern = os.path.join(DATASET_PATH, f'{obj}/test/*/*.png')
89 | fpaths = sorted(glob(fpattern))
90 | fpaths1 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths))
91 | fpaths2 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) == 'good', fpaths))
92 |
93 | Nanomaly = len(fpaths1)
94 | Nnormal = len(fpaths2)
95 | labels = np.zeros(Nanomaly + Nnormal, dtype=np.int32)
96 | labels[:Nanomaly] = 1
97 | return labels
98 |
99 |
100 | def get_mask(obj):
101 | fpattern = os.path.join(DATASET_PATH, f'{obj}/ground_truth/*/*.png')
102 | fpaths = sorted(glob(fpattern))
103 | masks = np.asarray(list(map(lambda fpath: resize(imread(fpath), (256, 256)), fpaths)))
104 | Nanomaly = masks.shape[0]
105 | Nnormal = len(glob(os.path.join(DATASET_PATH, f'{obj}/test/good/*.png')))
106 |
107 | masks[masks <= 128] = 0
108 | masks[masks > 128] = 255
109 | results = np.zeros((Nanomaly + Nnormal,) + masks.shape[1:], dtype=masks.dtype)
110 | results[:Nanomaly] = masks
111 |
112 | return results
113 |
114 |
115 | def get_mean(obj):
116 | images = get_x(obj, mode='train')
117 | mean = images.astype(np.float32).mean(axis=0)
118 | return mean
119 |
120 |
121 |
122 | def detection_auroc(obj, anomaly_scores):
123 | label = get_label(obj) # 1: anomaly 0: normal
124 | auroc = roc_auc_score(label, anomaly_scores)
125 | return auroc
126 |
127 |
128 | def segmentation_auroc(obj, anomaly_maps):
129 | gt = get_mask(obj)
130 | gt = gt.astype(np.int32)
131 | gt[gt == 255] = 1
132 |
133 | anomaly_maps = bilinears(anomaly_maps, (256, 256))
134 | auroc = roc_auc_score(gt.flatten(), anomaly_maps.flatten())
135 | return auroc
136 |
137 |
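138 |
139 | if __name__ == '__main__':
140 |     # Usage sketch (assumes MVTec AD has been extracted under
141 |     # ./data/MVTec, e.g. ./data/MVTec/cable/train/good/000.png):
142 |     set_root_path('./data/MVTec')
143 |     x = get_x_standardized('cable', mode='train')
144 |     print(x.shape, x.dtype)  # (N, 256, 256, 3) float32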
--------------------------------------------------------------------------------
/codes/nearest_neighbor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import shutil
3 | import os
4 |
5 |
6 | __all__ = ['search_NN']
7 |
8 |
9 |
10 | def search_NN(test_emb, train_emb_flat, NN=1, method='kdt'):
11 | if method == 'ngt':
12 | return search_NN_ngt(test_emb, train_emb_flat, NN=NN)
13 |
14 | from sklearn.neighbors import KDTree
15 | kdt = KDTree(train_emb_flat)
16 |
17 | Ntest, I, J, D = test_emb.shape
18 | closest_inds = np.empty((Ntest, I, J, NN), dtype=np.int32)
19 | l2_maps = np.empty((Ntest, I, J, NN), dtype=np.float32)
20 |
21 | for n in range(Ntest):
22 | for i in range(I):
23 | dists, inds = kdt.query(test_emb[n, i, :, :], return_distance=True, k=NN)
24 | closest_inds[n, i, :, :] = inds[:, :]
25 | l2_maps[n, i, :, :] = dists[:, :]
26 |
27 | return l2_maps, closest_inds
28 |
29 |
30 | def search_NN_ngt(test_emb, train_emb_flat, NN=1):
31 | import ngtpy
32 |
33 | Ntest, I, J, D = test_emb.shape
34 | closest_inds = np.empty((Ntest, I, J, NN), dtype=np.int32)
35 | l2_maps = np.empty((Ntest, I, J, NN), dtype=np.float32)
36 |
37 | dpath = f'/tmp/{os.getpid()}'
38 | ngtpy.create(dpath, D)
39 | index = ngtpy.Index(dpath)
40 | index.batch_insert(train_emb_flat)
41 |
42 | for n in range(Ntest):
43 | for i in range(I):
44 | for j in range(J):
45 | query = test_emb[n, i, j, :]
46 | results = index.search(query, NN)
47 | inds = [result[0] for result in results]
48 |
49 | closest_inds[n, i, j, :] = inds
50 | vecs = np.asarray([index.get_object(inds[nn]) for nn in range(NN)])
51 | dists = np.linalg.norm(query - vecs, axis=-1)
52 | l2_maps[n, i, j, :] = dists
53 | shutil.rmtree(dpath)
54 |
55 | return l2_maps, closest_inds
56 |
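57 |
58 | if __name__ == '__main__':
59 |     # Quick check of the KD-tree path (illustrative; method='ngt'
60 |     # additionally requires the ngtpy package):
61 |     test_emb = np.random.rand(2, 4, 4, 8).astype(np.float32)
62 |     train_emb_flat = np.random.rand(100, 8).astype(np.float32)
63 |     l2_maps, closest_inds = search_NN(test_emb, train_emb_flat, NN=3)
64 |     print(l2_maps.shape, closest_inds.shape)  # (2, 4, 4, 3) each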
--------------------------------------------------------------------------------
/codes/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import math
5 | from .utils import makedirpath
6 | from .LRN import LRN
7 | from .datasets import jigsaw_num
8 |
9 | __all__ = ['MyJigsawPositionHierEncoder', 'MyJigsawPositionDeepEncoder', 'MyJigsawPositionEncoder', 'MyJigsawPositionClassifier']
10 |
11 |
12 |
13 | def forward_hier(x, emb_small, K):
14 | K_2 = K // 2
15 | n = x.size(0)
16 | x1 = x[..., :K_2, :K_2]
17 | x2 = x[..., :K_2, K_2:]
18 | x3 = x[..., K_2:, :K_2]
19 | x4 = x[..., K_2:, K_2:]
20 | xx = torch.cat([x1, x2, x3, x4], dim=0)
21 |
22 | hh = emb_small(xx)
23 |
24 |
25 | h1 = hh[:n]
26 | h2 = hh[n: 2 * n]
27 | h3 = hh[2 * n: 3 * n]
28 | h4 = hh[3 * n:]
29 |
30 |
31 | h12 = torch.cat([h1, h2], dim=3)
32 | h34 = torch.cat([h3, h4], dim=3)
33 |
34 | h = torch.cat([h12, h34], dim=2)
35 | return h
36 |
37 |
38 |
39 | xent = nn.CrossEntropyLoss()
40 |
41 | class NormalizedLinear(nn.Module):
42 | __constants__ = ['bias', 'in_features', 'out_features']
43 |
44 | def __init__(self, in_features, out_features, bias=True):
45 | super(NormalizedLinear, self).__init__()
46 | self.in_features = in_features
47 | self.out_features = out_features
48 | # shape of the weight matrix: (out_features, in_features)
49 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
50 | if bias:
51 | # shape of the bias vector: (out_features,)
52 | self.bias = nn.Parameter(torch.Tensor(out_features))
53 | else:
54 | self.register_parameter('bias', None)
55 | self.reset_parameters()
56 |
57 | def reset_parameters(self):
58 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
59 | if self.bias is not None:
60 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
61 | bound = 1 / math.sqrt(fan_in)
62 | nn.init.uniform_(self.bias, -bound, bound)
63 |
64 | def forward(self, x):
65 | with torch.no_grad():
66 | w = self.weight / self.weight.data.norm(keepdim=True, dim=0)
67 | return F.linear(x, w, self.bias)
68 |
69 | def extra_repr(self):
70 | return 'in_features={}, out_features={}, bias={}'.format(
71 | self.in_features, self.out_features, self.bias is not None
72 | )
73 |
74 |
75 | class MyJigsawPositionEncoder(nn.Module):
76 | def __init__(self, K, D=64, bias=True):
77 | super().__init__()
78 |
79 |
80 | self.conv1 = nn.Conv2d(3, 64, 5, 2, 0, bias=bias)
81 | self.conv2 = nn.Conv2d(64, 64, 5, 2, 0, bias=bias)
82 | self.conv3 = nn.Conv2d(64, 128, 5, 2, 0, bias=bias)
83 | self.conv4 = nn.Conv2d(128, D, 5, 1, 0, bias=bias)
84 |
85 | self.K = K
86 | self.D = D
87 | self.bias = bias
88 |
89 | def forward(self, x):
90 | h = self.conv1(x)
91 | h = F.leaky_relu(h, 0.1)
92 |
93 | h = self.conv2(h)
94 | h = F.leaky_relu(h, 0.1)
95 |
96 | h = self.conv3(h)
97 |
98 |
99 | if self.K == 64:
100 | h = F.leaky_relu(h, 0.1)
101 | h = self.conv4(h)
102 |
103 | h = torch.tanh(h)
104 |
105 | return h
106 |
107 | class MyJigsawPositionDeepEncoder(nn.Module):
108 | def __init__(self, K, D=64, bias=True):
109 | super().__init__()
110 | self.conv = nn.Sequential(
111 | nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=0),
112 | nn.GroupNorm(12, 96),
113 | nn.ReLU(inplace=True),
114 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
115 | LRN(local_size=5, alpha=0.0001, beta=0.75),
116 |
117 | nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
118 | nn.GroupNorm(32, 256),
119 | nn.ReLU(inplace=True),
120 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
121 | LRN(local_size=5, alpha=0.0001, beta=0.75),
122 |
123 | nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
124 | nn.GroupNorm(48, 384),
125 | nn.ReLU(inplace=True),
126 |
127 | nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
128 | nn.GroupNorm(48, 384),
129 | nn.ReLU(inplace=True),
130 |
131 | nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
132 | nn.GroupNorm(32, 256),
133 | nn.ReLU(inplace=True),
134 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
135 |
136 | nn.Conv2d(256, 128, kernel_size=2, stride=1, padding=0),
137 | nn.GroupNorm(16, 128),
138 |
139 | nn.Conv2d(128, D, kernel_size=2, stride=1, padding=0),
140 | nn.GroupNorm(int(D/8), D)
141 | )
142 |
143 | self.K = K
144 | self.D = D
145 |
146 | def forward(self, x):
147 | x = self.conv(x)
148 |
149 | x = torch.tanh(x)
150 |
151 | return x
152 |
153 | def save(self, name):
154 | fpath = self.fpath_from_name(name)
155 | makedirpath(fpath)
156 | torch.save(self.state_dict(), fpath)
157 |
158 | def load(self, name):
159 | fpath = self.fpath_from_name(name)
160 | self.load_state_dict(torch.load(fpath))
161 |
162 | @staticmethod
163 | def fpath_from_name(name):
164 | return f'ckpts/{name}/myjigsawposition_encdeep.pkl'
165 |
166 | class MyJigsawPositionHierEncoder(nn.Module):
167 | def __init__(self, K, D=64, bias=True):
168 | super().__init__()
169 |
170 | if K > 64:
171 | self.enc = MyJigsawPositionHierEncoder(K // 2, D, bias=bias)
172 |
173 |
174 | elif K == 64:
175 | self.enc = MyJigsawPositionDeepEncoder(K // 2, D, bias=bias)
176 |
177 | else:
178 | raise ValueError()
179 |
180 | self.conv1 = nn.Conv2d(D, 128, 2, 1, 0, bias=bias)
181 | self.conv2 = nn.Conv2d(128, D, 1, 1, 0, bias=bias)
182 |
183 | self.K = K
184 | self.D = D
185 |
186 | def forward(self, x):
187 |
188 | h = forward_hier(x, self.enc, K=self.K)
189 |
190 |
191 | h = self.conv1(h)
192 | h = F.leaky_relu(h, 0.1)
193 |
194 | h = self.conv2(h)
195 | h = torch.tanh(h)
196 |
197 | return h
198 |
199 | def save(self, name, i):
200 | fpath = self.fpath_from_name(name, i)
201 | makedirpath(fpath)
202 | torch.save(self.state_dict(), fpath)
203 |
204 | def load(self, name, i):
205 | fpath = self.fpath_from_name(name, i)
206 | self.load_state_dict(torch.load(fpath))
207 | print('Encoder has been loaded!')
208 |
209 | @staticmethod
210 | def fpath_from_name(name, i):
211 | return f'ckpts/{name}/myjigsawposition_enchier_{i}_step.pkl'
212 |
213 | class MyJigsawPositionClassifier(nn.Module):
214 | def __init__(self, K, D, class_num=12):
215 | super().__init__()
216 | self.D = D
217 |
218 | self.fc1 = nn.Linear(D, 128)
219 | self.act1 = nn.LeakyReLU(0.1)
220 |
221 | self.fc2 = nn.Linear(128, 128)
222 | self.act2 = nn.LeakyReLU(0.1)
223 |
224 | self.fc3 = NormalizedLinear(128, class_num)
225 |
226 | self.K = K
227 |
228 | def save(self, name):
229 | fpath = self.fpath_from_name(name)
230 | makedirpath(fpath)
231 | torch.save(self.state_dict(), fpath)
232 |
233 | def load(self, name):
234 | fpath = self.fpath_from_name(name)
235 | self.load_state_dict(torch.load(fpath))
236 |
237 | def fpath_from_name(self, name):
238 | return f'ckpts/{name}/position_classifier_K{self.K}.pkl'
239 |
240 | @staticmethod
241 | def infer(c, enc, batch):
242 |
243 | x1s, x2s, ys = batch
244 | ys = ys.long().cuda()
245 |
246 |
247 | h1 = enc(x1s)
248 | h2 = enc(x2s)
249 |
250 | logits = c(h1, h2)
251 |
252 | loss = xent(logits, ys)
253 | return loss
254 |
255 | def forward(self, h1, h2):
256 | h1 = h1.view(-1, self.D)
257 | h2 = h2.view(-1, self.D)
258 |
259 |
260 | h = h1 - h2
261 |
262 | h = self.fc1(h)
263 | h = self.act1(h)
264 |
265 | h = self.fc2(h)
266 | h = self.act2(h)
267 |
268 | h = self.fc3(h)
269 |
270 | return h
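271 |
272 |
273 | if __name__ == '__main__':
274 |     # Shape sketch (illustrative; run from the repo root as
275 |     # `python -m codes.networks`): the hierarchical encoder embeds a
276 |     # 64x64 patch by encoding its four 32x32 quadrants with the deep
277 |     # encoder and fusing them with two convolutions.
278 |     enc = MyJigsawPositionHierEncoder(K=64, D=64)
279 |     x = torch.randn(2, 3, 64, 64)
280 |     print(enc(x).shape)  # torch.Size([2, 64, 1, 1])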
--------------------------------------------------------------------------------
/codes/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import _pickle as p
4 | from torch.utils.data import Dataset
5 | import torch
6 | from contextlib import contextmanager
7 | from PIL import Image
8 |
9 |
10 | __all__ = ['crop_image_CHW', 'PatchDataset_NCHW', 'NHWC2NCHW_normalize', 'NHWC2NCHW',
11 | 'save_binary', 'load_binary', 'makedirpath', 'task', 'DictionaryConcatDataset',
12 | 'to_device', 'distribute_scores', 'resize']
13 |
14 |
15 | def to_device(obj, device, non_blocking=False):
16 |
17 |
18 | if isinstance(obj, torch.Tensor):
19 | return obj.to(device, non_blocking=non_blocking)
20 |
21 |
22 | if isinstance(obj, dict):
23 |
24 | return {k: to_device(v, device, non_blocking=non_blocking)
25 | for k, v in obj.items()}
26 |
27 |
28 | if isinstance(obj, list):
29 | return [to_device(v, device, non_blocking=non_blocking)
30 | for v in obj]
31 |
32 |
33 | if isinstance(obj, tuple):
34 | return tuple([to_device(v, device, non_blocking=non_blocking)
35 | for v in obj])
36 | # fall through: leave objects of any other type unchanged
37 | return obj
38 | @contextmanager
39 | def task(_):
40 | yield
41 |
42 |
43 | class DictionaryConcatDataset(Dataset):
44 |
45 | def __init__(self, d_of_datasets):
46 | self.d_of_datasets = d_of_datasets
47 | lengths = [len(d) for d in d_of_datasets.values()]
48 | self._length = min(lengths)
49 | self.keys = self.d_of_datasets.keys()
50 | assert min(lengths) == max(lengths), 'Length of the datasets should be the same'
51 |
52 | def __getitem__(self, idx):
53 | return {
54 | key: self.d_of_datasets[key][idx]
55 | for key in self.keys
56 | }
57 |
58 | def __len__(self):
59 | return self._length
60 |
61 | def crop_CHW(image, i, j, K, S=1):
62 | if S == 1:
63 | h, w = i, j
64 | else:
65 | h = S * i
66 | w = S * j
67 | return image[:, h: h + K, w: w + K]
68 |
69 |
70 | def cnn_output_size(H, K, S=1, P=0) -> int:
71 |     """Spatial output size of a convolution or pooling layer.
72 |
73 |     :param int H: input size
74 |     :param int K: kernel (filter) size
75 |     :param int S: stride
76 |     :param int P: padding
77 |     :return: output size, i.e. 1 + (H - K + 2P) // S
78 |     """
79 | return 1 + (H - K + 2 * P) // S
80 |
81 | def crop_image_CHW(image, coord, K):
82 | h, w = coord
83 | return image[:, h: h + K, w: w + K]
84 |
85 |
86 | class PatchDataset_NCHW(Dataset):
87 | def __init__(self, memmap, tfs=None, K=32, S=1):
88 | super().__init__()
89 | self.arr = memmap
90 | self.tfs = tfs
91 | self.S = S
92 | self.K = K
93 | self.N = self.arr.shape[0]
94 |
95 | def __len__(self):
96 | return self.N * self.row_num * self.col_num
97 |
98 | @property
99 | def row_num(self):
100 | N, C, H, W = self.arr.shape
101 | K = self.K
102 | S = self.S
103 | I = cnn_output_size(H, K=K, S=S)
104 | return I
105 |
106 | @property
107 | def col_num(self):
108 | N, C, H, W = self.arr.shape
109 | K = self.K
110 | S = self.S
111 | J = cnn_output_size(W, K=K, S=S)
112 | return J
113 |
114 |
115 | def __getitem__(self, idx):
116 | N = self.N
117 | n, i, j = np.unravel_index(idx, (N, self.row_num, self.col_num))
118 | K = self.K
119 | S = self.S
120 |
121 | image = self.arr[n]
122 |
123 | patch = crop_CHW(image, i, j, K, S)
124 |
125 | if self.tfs:
126 | patch = self.tfs(patch)
127 |
128 | return patch, n, i, j
129 |
130 |
131 |
132 | def NHWC2NCHW_normalize(x):
133 | x = (x / 255.).astype(np.float32)
134 | return np.transpose(x, [0, 3, 1, 2])
135 |
136 |
137 |
138 | def NHWC2NCHW(x):
139 | return np.transpose(x, [0, 3, 1, 2])
140 |
141 |
142 | def load_binary(fpath, encoding='ASCII'):
143 | with open(fpath, 'rb') as f:
144 | return p.load(f, encoding=encoding)
145 |
146 |
147 | def save_binary(d, fpath):
148 | with open(fpath, 'wb') as f:
149 | p.dump(d, f)
150 |
151 |
152 | def makedirpath(fpath: str):
153 |
154 | dpath = os.path.dirname(fpath)
155 | if dpath:
156 | os.makedirs(dpath, exist_ok=True)
157 |
158 |
159 | def distribute_scores(score_masks, output_shape, K: int, S: int) -> np.ndarray:
160 | N = score_masks.shape[0]
161 | results = [distribute_score(score_masks[n], output_shape, K, S) for n in range(N)]
162 | return np.asarray(results)
163 |
164 |
165 | def distribute_score(score_mask, output_shape, K: int, S: int) -> np.ndarray:
166 | H, W = output_shape
167 | mask = np.zeros([H, W], dtype=np.float32)
168 | cnt = np.zeros([H, W], dtype=np.int32)
169 |
170 | I, J = score_mask.shape[:2]
171 | for i in range(I):
172 | for j in range(J):
173 | h, w = i * S, j * S
174 |
175 |
176 | mask[h: h + K, w: w + K] += score_mask[i, j]
177 |
178 | cnt[h: h + K, w: w + K] += 1
179 |
180 |
181 | cnt[cnt == 0] = 1
182 |
183 | return mask / cnt
184 |
185 |
186 | def resize(image, shape=(256, 256)):
187 |
188 | return np.array(Image.fromarray(image).resize(shape[::-1]))
189 |
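190 |
191 | if __name__ == '__main__':
192 |     # Patch-grid sketch (illustrative): with K=64 and stride S=16 on a
193 |     # 256x256 image there are 1 + (256 - 64) // 16 = 13 patch rows and
194 |     # columns, which is what cnn_output_size computes.
195 |     print(cnn_output_size(256, K=64, S=16))  # 13
196 |     ds = PatchDataset_NCHW(torch.zeros(2, 3, 256, 256), K=64, S=16)
197 |     print(len(ds), ds.row_num, ds.col_num)  # 338 13 13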
--------------------------------------------------------------------------------
/data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangpeng000/VisualInspection/a5933402284662bbab1d218c188cb16788de6a4e/data.npy
--------------------------------------------------------------------------------
/doc/svdd_result.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangpeng000/VisualInspection/a5933402284662bbab1d218c188cb16788de6a4e/doc/svdd_result.jpeg
--------------------------------------------------------------------------------
/heat_map.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | from codes import mvtecad
4 | from tqdm import tqdm
5 | from codes.utils import resize, makedirpath
6 |
7 | from skimage import morphology
8 | from skimage.segmentation import mark_boundaries
9 | import os
10 | import numpy as np
11 | import matplotlib
12 | from scipy.ndimage import gaussian_filter
13 | from sklearn.metrics import precision_recall_curve
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--obj', default='transistor')
17 | args = parser.parse_args()
18 |
19 |
20 | def plot_fig(test_img, scores, gts, threshold, obj):
21 |
22 | num = len(scores)
23 | vmax = scores.max() * 255.
24 | vmin = scores.min() * 255.
25 | for i in range(num):
26 | img = test_img[i]
27 | gt = gts[i]
28 | heat_map = scores[i] * 255
29 | mask = scores[i].copy()  # work on a copy so the shared scores array is not thresholded in place
30 | mask[mask > threshold] = 1
31 | mask[mask <= threshold] = 0
32 | kernel = morphology.disk(4)
33 | mask = morphology.opening(mask, kernel)
34 | mask *= 255
35 | vis_img = mark_boundaries(img, mask, color=(1, 0, 0), mode='thick')
36 | fig_img, ax_img = plt.subplots(1, 5, figsize=(20, 5))
37 | fig_img.subplots_adjust(right=0.9)
38 | norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
39 | for ax_i in ax_img:
40 | ax_i.axes.xaxis.set_visible(False)
41 | ax_i.axes.yaxis.set_visible(False)
42 | ax_img[0].imshow(img)
43 | ax_img[1].imshow(gt, cmap='gray')
44 | ax = ax_img[2].imshow(heat_map, cmap='jet', norm=norm)
45 | ax_img[2].imshow(img, cmap='gray', interpolation='none')
46 | ax_img[2].imshow(heat_map, cmap='jet', alpha=0.5, interpolation='none')
47 | ax_img[3].imshow(mask, cmap='gray')
48 | ax_img[4].imshow(vis_img)
49 | left = 0.92
50 | bottom = 0.15
51 | width = 0.015
52 | height = 1 - 2 * bottom
53 | rect = [left, bottom, width, height]
54 | cbar_ax = fig_img.add_axes(rect)
55 | cb = plt.colorbar(ax, shrink=0.6, cax=cbar_ax, fraction=0.046)
56 | cb.ax.tick_params(labelsize=8)
57 | font = {
58 | 'family': 'serif',
59 | 'color': 'black',
60 | 'weight': 'normal',
61 | 'size': 8,
62 | }
63 |
64 | fpath = f'anomaly_maps/{obj}/{i:03d}.png'
65 | makedirpath(fpath)
66 | fig_img.savefig(fpath)
67 | plt.close()
68 |
69 | def denormalization(x):
70 | mean = np.array([0.485, 0.456, 0.406])
71 | std = np.array([0.229, 0.224, 0.225])
72 | x = (((x * std) + mean) * 255.).astype(np.uint8)
73 |
74 | return x
75 |
76 |
77 | def main():
78 | from codes.inspection import eval_encoder_NN_multiK
79 | from codes.networks import MyJigsawPositionHierEncoder
80 |
81 | obj = args.obj
82 |
83 | enc = MyJigsawPositionHierEncoder(K=64, D=64).cuda()
84 | enc.load(obj, 0)
85 | enc.eval()
86 | results = eval_encoder_NN_multiK(enc, obj, 1)
87 | score_map = results['maps_mult']
88 |
89 | images = mvtecad.get_x(obj, mode='test')
90 |
91 | masks = mvtecad.get_mask(obj)
92 | masks[masks==255] = 1
93 |
94 |
95 | for i in range(score_map.shape[0]):
96 | score_map[i] = gaussian_filter(score_map[i], sigma=2)
97 |
98 | max_score = score_map.max()
99 | min_score = score_map.min()
100 | scores = (score_map - min_score) / (max_score - min_score)
101 |
102 | gt_mask = np.asarray(masks)
103 | precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), scores.flatten())
104 | a = 2 * precision * recall
105 | b = precision + recall
106 | f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
107 | threshold = thresholds[np.argmax(f1)]
108 |
109 | plot_fig(images, scores, masks, threshold, obj)
110 |
111 |
112 | if __name__ == '__main__':
113 | main()
114 |
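115 | # Note: the mask threshold above is the F1-optimal operating point of the
116 | # pixel-level precision-recall curve, i.e. the candidate threshold that
117 | # maximizes F1 = 2 * precision * recall / (precision + recall).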
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | scikit-image
4 | scikit-learn
5 | torch
6 | tqdm
7 | Pillow
8 | imageio
9 | opencv-python
10 | ngt
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--obj', default='transistor')
6 | parser.add_argument('--maps_num', default=5, type=int)
7 | parser.add_argument('--gpu', default='3', type=str)
8 |
9 | args = parser.parse_args()
10 |
11 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
12 |
13 | def do_evaluate_encoder_multiK(obj):
14 | from codes.inspection import eval_encoder_NN_multiK
15 | from codes.networks import MyJigsawPositionHierEncoder
16 |
17 | enc = MyJigsawPositionHierEncoder(K=64, D=64).cuda()
18 | enc.load(obj, 0)
19 | enc.eval()
20 | for i in range(1, args.maps_num + 1):
21 | results = eval_encoder_NN_multiK(enc, obj, i)
22 |
23 | det_64 = results['det_64']
24 | seg_64 = results['seg_64']
25 |
26 | det_32 = results['det_32']
27 | seg_32 = results['seg_32']
28 |
29 | det_sum = results['det_sum']
30 | seg_sum = results['seg_sum']
31 |
32 | det_mult = results['det_mult']
33 | seg_mult = results['seg_mult']
34 |
35 | print('Maps NUM is {}'.format(i))
36 | print(
37 | f'| K64 | Det: {det_64:.3f} Seg:{seg_64:.3f} | K32 | Det: {det_32:.3f} Seg:{seg_32:.3f} | sum | Det: {det_sum:.3f} Seg:{seg_sum:.3f} | mult | Det: {det_mult:.3f} Seg:{seg_mult:.3f} ({obj})')
38 |
39 |
40 | #########################
41 |
42 |
43 | def main():
44 | do_evaluate_encoder_multiK(args.obj)
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from codes import mvtecad
4 | from functools import reduce
5 | from torch.utils.data import DataLoader
6 | from codes.datasets import *
7 | from codes.networks import *
8 | from codes.inspection import eval_encoder_NN_multiK
9 | from codes.utils import *
10 | import os
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument('--obj', default='transistor_plus', type=str)
15 | parser.add_argument('--lambda_value', default=0.001, type=float)
16 | parser.add_argument('--D', default=64, type=int)
17 |
18 | parser.add_argument('--epochs', default=400, type=int)
19 | parser.add_argument('--lr', default=1e-4, type=float)
20 |
21 | parser.add_argument('--gpu', default='0', type=str)
22 | parser.add_argument('--maps_num', default=1, type=int)
23 |
24 | args = parser.parse_args()
25 |
26 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
27 |
28 | def train():
29 | obj = args.obj
30 | D = args.D
31 | lr = args.lr
32 |
33 | with task('Networks'):
34 |
35 | enc = MyJigsawPositionHierEncoder(64, D).cuda()
36 |
37 | cls_64 = MyJigsawPositionClassifier(64, D).cuda()
38 | cls_32 = MyJigsawPositionClassifier(32, D).cuda()
39 |
40 | modules = [enc, cls_64, cls_32]
41 | params = [list(module.parameters()) for module in modules]
42 |
43 | params = reduce(lambda x, y: x + y, params)
44 |
45 | opt = torch.optim.Adam(params=params, lr=lr)
46 |
47 | with task('Datasets'):
48 |
49 | train_x = mvtecad.get_x_standardized(obj, mode='train')
50 | train_x = NHWC2NCHW(train_x)
51 |
52 | rep = 100
53 | datasets = dict()
54 |
55 | datasets['pos_64'] = MyJigsawPositionDataset(train_x, K=64, repeat=rep)
56 | datasets['pos_32'] = MyJigsawPositionDataset(train_x, K=32, repeat=rep)
57 |
58 | datasets['svdd_64'] = SVDD_Dataset(train_x, K=64, repeat=rep)
59 | datasets['svdd_32'] = SVDD_Dataset(train_x, K=32, repeat=rep)
60 |
61 | dataset = DictionaryConcatDataset(datasets)
62 | loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
63 |
64 | print('Start training')
65 | for i_epoch in range(args.epochs):
66 | if i_epoch != 0:
67 | for module in modules:
68 | module.train()
69 |
70 | for d in loader:
71 | d = to_device(d, 'cuda', non_blocking=True)
72 | opt.zero_grad()
73 |
74 | loss_pos_64 = MyJigsawPositionClassifier.infer(cls_64, enc, d['pos_64'])
75 | loss_pos_32 = MyJigsawPositionClassifier.infer(cls_32, enc.enc, d['pos_32'])
76 |
77 | loss_svdd_64 = SVDD_Dataset.infer(enc, d['svdd_64'])
78 | loss_svdd_32 = SVDD_Dataset.infer(enc.enc, d['svdd_32'])
79 |
80 | loss = loss_pos_64 + loss_pos_32 + args.lambda_value * (loss_svdd_64 + loss_svdd_32)
81 |
82 | loss.backward()
83 | opt.step()
84 |
85 | aurocs = eval_encoder_NN_multiK(enc, obj, args.maps_num)
86 |
87 | log_result(obj, aurocs)
88 | enc.save(obj, i_epoch)
89 |
90 |
91 | def log_result(obj, aurocs):
92 | det_64 = aurocs['det_64'] * 100
93 | seg_64 = aurocs['seg_64'] * 100
94 |
95 | det_32 = aurocs['det_32'] * 100
96 | seg_32 = aurocs['seg_32'] * 100
97 |
98 | det_sum = aurocs['det_sum'] * 100
99 | seg_sum = aurocs['seg_sum'] * 100
100 |
101 | det_mult = aurocs['det_mult'] * 100
102 | seg_mult = aurocs['seg_mult'] * 100
103 |
104 | print(
105 | f'|K64| Det: {det_64:4.1f} Seg: {seg_64:4.1f} |K32| Det: {det_32:4.1f} Seg: {seg_32:4.1f} |sum| Det: {det_sum:4.1f} Seg: {seg_sum:4.1f} |mult| Det: {det_mult:4.1f} Seg: {seg_mult:4.1f} ({obj})')
106 |
107 |
108 | if __name__ == '__main__':
109 | train()
110 |
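111 |
112 | # Note: the total objective combines the two jigsaw position-classification
113 | # losses with the two SVDD feature-consistency losses, weighted by
114 | # --lambda_value:
115 | #     loss = loss_pos_64 + loss_pos_32 + lambda_value * (loss_svdd_64 + loss_svdd_32)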
--------------------------------------------------------------------------------