├── .gitignore
├── .idea
│   ├── D4Net.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── CAM.py
├── LICENSE
├── README.md
├── ckpt
│   └── d4net
│       └── ckptPlaceHolder
├── config.py
├── infer.py
├── resnext
│   ├── __init__.py
│   ├── config.py
│   ├── resnext101.py
│   └── resnext_101_32x4d_.py
├── train.py
└── util
    ├── __init__.py
    ├── folder_classify.py
    ├── folder_mvtec.py
    └── misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CAM.py:
--------------------------------------------------------------------------------
1 | # simple implementation of CAM in PyTorch for the networks such as ResNet, DenseNet, SqueezeNet, Inception
2 |
3 | import io
4 | import requests
5 | from PIL import Image
6 | from torchvision import models, transforms
7 | from torch.autograd import Variable
8 | from torch.nn import functional as F
9 | import torch.nn as nn
10 | import numpy as np
11 | import torch
12 | import cv2
13 | import os
14 | import pdb
15 |
16 | from resnext import D4Net
17 |
18 |
19 |
20 | class Net(nn.Module):
21 | def __init__(self, num_cls):
22 | super(Net, self).__init__()
23 | self.net = D4Net(num_cls=2)
24 | self.layer0 = self.net.layer0
25 | self.layer1 = self.net.layer1
26 | self.layer2 = self.net.layer2
27 | self.layer3 = self.net.layer3
28 | self.layer4 = self.net.layer4
29 |
30 | # manus
31 | pretrained_path = os.path.join('./ckpt', 'd4net',
32 | 'd4net.pth')
33 | self.net.load_state_dict(torch.load(pretrained_path))
34 |
35 | def forward(self, x1, x2):
36 | output = self.net(x1, x2)
37 |
38 | layer0_x1 = self.layer0(x1)
39 | layer1_x1 = self.layer1(layer0_x1)
40 | layer2_x1 = self.layer2(layer1_x1)
41 | layer3_x1 = self.layer3(layer2_x1)
42 | layer4_x1 = self.layer4(layer3_x1)
43 |
44 | layer0_x2 = self.layer0(x2)
45 | layer1_x2 = self.layer1(layer0_x2)
46 | layer2_x2 = self.layer2(layer1_x2)
47 | layer3_x2 = self.layer3(layer2_x2)
48 | layer4_x2 = self.layer4(layer3_x2)
49 | difference = layer4_x2 - layer4_x1
50 |
51 | return output, difference
52 |
53 |
54 |
55 | img_root = 'the path to image directory'
56 | img_name = ''
57 | ref_name = ''
58 |
59 | net = Net(num_cls=2)
60 | finalconv_name = 'layer4'
61 |
62 | net.eval()
63 | with torch.no_grad():
64 |
65 | # hook the feature extractor
66 | features_blobs = []
67 | def hook_feature(module, input, output):
68 | features_blobs.append(output.data.cpu().numpy())
69 |
70 | net._modules.get(finalconv_name).register_forward_hook(hook_feature)
71 |
72 | # get the softmax weight
73 | params = list(net.parameters())
74 | weight_softmax = np.squeeze(params[-2].data.numpy())
75 |
76 | def returnCAM(feature_conv, weight_softmax, class_idx):
77 | # generate the class activation maps upsample to 512x512
78 | size_upsample = (512, 512)
79 | bz, nc, h, w = feature_conv.shape
80 | output_cam = []
81 | for idx in class_idx:
82 | cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))
83 | cam = cam.reshape(h, w)
84 | cam = cam - np.min(cam)
85 | cam_img = cam / np.max(cam)
86 | cam_img = np.uint8(255 * cam_img)
87 | output_cam.append(cv2.resize(cam_img, size_upsample))
88 | return output_cam
89 |
90 |
91 | preprocess = transforms.Compose([
92 | transforms.Resize((224,224)),
93 | transforms.ToTensor(),
94 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
95 | ])
96 |
97 |
98 | img_path = os.path.join(img_root, img_name)
99 | ref_path = os.path.join(img_root, ref_name)
100 |     img = Image.open(img_path).convert('RGB')
101 |     ref = Image.open(ref_path).convert('RGB')
102 |
103 | img_tensor = preprocess(img)
104 | ref_tensor = preprocess(ref)
105 | img_variable = Variable(img_tensor.unsqueeze(0))
106 | ref_variable = Variable(ref_tensor.unsqueeze(0))
107 |
108 |     logit, difference = net(img_variable, ref_variable)  # Net returns (class logits, layer4 feature difference)
109 |
110 | h_x = F.softmax(logit, dim=1).data.squeeze()
111 | probs, idx = h_x.sort(0, True)
112 | probs = probs.numpy()
113 | idx = idx.numpy()
114 |
115 |     CAMs = returnCAM(difference.cpu().numpy(), weight_softmax, [idx[0]])  # weight the difference map, which is what feeds the classifier
116 |
117 |     if not os.path.isdir('visual_result'): os.mkdir('visual_result')  # ensure the output directory exists
118 | # render the CAM and output
119 | # print('output CAM.jpg for the top1 prediction: %s'%classes[idx[0]])
120 | img = cv2.imread(img_path)
121 | height, width, _ = img.shape
122 | heatmap = cv2.applyColorMap(cv2.resize(CAMs[0],(width, height)), cv2.COLORMAP_JET)
123 | result = heatmap * 0.3 + img * 0.5
124 |     cv2.imwrite(os.path.join('visual_result', 'result.jpg'), result)
125 |
126 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jiaxing Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # D4Net: De-Deformation Defect Detection Network for Non-Rigid Products with Large Patterns
2 | by Xuemiao Xu^, Jiaxing Chen^, Huaidong Zhang\*, and Wing W. Y. Ng\* (^ joint first authors, \* joint corresponding authors) [[paper link](https://www.sciencedirect.com/science/article/abs/pii/S0020025520304710)]
3 |
4 | This implementation is written by Jiaxing Chen at the South China University of Technology.
5 |
6 | ## Citation
7 |
8 | @article{xu2020d4net,
9 | title={D4Net: De-Deformation Defect Detection Network for Non-Rigid Products with Large Patterns},
10 |   author={Xu, Xuemiao and Chen, Jiaxing and Zhang, Huaidong and Ng, Wing W. Y.},
11 | journal={Information Sciences},
12 | volume={547},
13 | pages={763--776},
14 | year={2021},
15 | publisher={Elsevier}
16 | }
17 |
18 | ## LFLP Dataset
19 |
20 | Due to the influence of COVID-19, the LFLP dataset will be released after the author returns to school. [[LFLP dataset link](https://drive.google.com/drive/folders/1t9TSmZiDb5mVElaqS6M9XEUV9w_BuJiF?usp=sharing)]
21 |
22 | ### The LFLP testing set was updated (2021/04/22); the training set is coming soon.
23 |
24 | ## Trained Model
25 |
26 | You can download the trained model reported in our paper from [Google Drive](https://drive.google.com/file/d/1knTpVXt3gKGxqHMZQKz-T1q0r3TlsQxf/view?usp=sharing).
27 |
28 | ## Requirement
29 |
30 | - Python 2.7
31 | - PyTorch 0.4.0
32 | - torchvision
33 | - numpy
34 |
35 | ## Training
36 |
37 | 1. Set the path of the pretrained ResNeXt model in `resnext/config.py`
38 | 2. Set the path of the LFLP dataset in `config.py` (a sketch of both settings follows below)
39 | 3. Run `python train.py`
40 |
41 | Training *hyper-parameters* are gathered at the beginning of *train.py*; you can conveniently change them as needed.
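For concreteness, here is a minimal sketch of the two path settings from steps 1 and 2 (both paths are placeholders for your own locations, not the repository's defaults):

    # resnext/config.py -- step 1: point this at the pretrained ResNeXt weights
    resnext101_32_path = '/path/to/resnext_101_32x4d.pth'

    # config.py -- step 2: point this at the LFLP dataset root, which should
    # contain 'train' and 'val' subdirectories (see util/folder_classify.py)
    data_root = '/path/to/LFLP'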
42 |
43 | ## Testing
44 |
45 | 1. Put the trained model in `ckpt/d4net`
46 | 2. Run `python infer.py`
47 |
48 | *Settings* for testing are gathered at the beginning of *infer.py*; you can conveniently change them as needed.
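If you prefer to call the model from your own script, here is a minimal inference sketch for a single test/reference pair; it mirrors the preprocessing and model loading in *infer.py* (the image file names are placeholders):

    import torch
    from PIL import Image
    from torchvision import transforms
    from resnext import D4Net

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # load the trained snapshot from ckpt/d4net
    net = D4Net(num_cls=2)
    net.load_state_dict(torch.load('./ckpt/d4net/d4net.pth'))
    net.cuda().eval()

    test = transform(Image.open('test.JPG').convert('RGB')).unsqueeze(0).cuda()
    ref = transform(Image.open('ref.JPG').convert('RGB')).unsqueeze(0).cuda()
    with torch.no_grad():
        logits, _ = net(test, ref)      # D4Net returns (class logits, min cosine similarity)
        pred = logits.max(1)[1].item()  # 1 = defect-free, 0 = defective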
49 |
--------------------------------------------------------------------------------
/ckpt/d4net/ckptPlaceHolder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubBingoChen/D4Net/a9c07679fb66e3101e804aacd388508e52ab49aa/ckpt/d4net/ckptPlaceHolder
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | data_root = '/media/b3-542/0DFD5CD11721CA55/lace data refine/pair_perceptual'
5 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch
3 | import random
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from torch import optim, nn
7 | from torch.autograd import Variable
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | import time
11 |
12 | from config import *
13 | from resnext import D4Net
14 | # from util.folder_mvtec import PairLoader
15 | from util.folder_classify import PairLoader
16 | import json
17 | from util.misc import AvgMeter, check_mkdir, cal_metric
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 |
20 | ckpt_path = './ckpt'
21 |
22 | args = {
23 | 'exp_name': 'd4net',
24 | 'val_batch_size': 157,
25 | 'snapshot': 'd4net.pth',
26 | }
27 |
28 | val_transform = transforms.Compose([
29 | transforms.Resize((224, 224)),
30 | transforms.ToTensor(),
31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
32 | ])
33 | val_set = PairLoader(mode='val', transform=val_transform)
34 | # val_set = PairLoader(transform=val_transform)
35 | val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=12, drop_last=False)
36 |
37 | def main():
38 |
39 | snapshot_path = os.path.join(ckpt_path, args['exp_name'], args['snapshot'])
40 |
41 | net = D4Net(num_cls=2)
42 | net.load_state_dict(torch.load(snapshot_path))
43 | net.cuda().eval()
44 |
45 | with torch.no_grad():
46 |
47 | val_length = len(val_loader)
48 | acc_meter = AvgMeter()
49 | precision_meter = AvgMeter()
50 | recall_meter = AvgMeter()
51 | f1_meter = AvgMeter()
52 |
53 | for vi, data in enumerate(val_loader, 0):
54 | print '%d/%d' % (vi + 1, val_length)
55 | test_images, ref_images, labels = data
56 |
57 | val_batch_size = labels.size(0)
58 |
59 | test_images = Variable(test_images).cuda()
60 | ref_images = Variable(ref_images).cuda()
61 | labels = Variable(labels).cuda()
62 |
63 | outputs, _ = net(test_images, ref_images)
64 |
65 | accuracy, precision, recall, f1_score = cal_metric(outputs.data, labels.data)
66 | acc_meter.update(accuracy, val_batch_size)
67 | precision_meter.update(precision, val_batch_size)
68 | recall_meter.update(recall, val_batch_size)
69 | f1_meter.update(f1_score, val_batch_size)
70 |
71 |
72 | log1 = args['exp_name'] + 'mvtec'
73 |         log2 = '\nacc: precision: recall: f1:\n'
74 | log3 = '%.4f \t %.4f \t %.4f \t %.4f\n' % (
75 | acc_meter.avg, precision_meter.avg, recall_meter.avg, f1_meter.avg)
76 |
77 | with open('experiment_result.txt', 'a') as f:
78 | f.write(log1 + log2 + log3 + '\n\n')
79 |
80 | if __name__ == '__main__':
81 | main()
--------------------------------------------------------------------------------
/resnext/__init__.py:
--------------------------------------------------------------------------------
1 | from resnext101 import ResNeXt101, D4Net
2 |
3 |
--------------------------------------------------------------------------------
/resnext/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | pytorch_pretrained_root = '/home/b3-542/Packages/Models/PyTorch Pretrained'
3 | resnext101_32_path = os.path.join(pytorch_pretrained_root, 'ResNeXt', 'resnext_101_32x4d.pth')
4 | pretrained_res18_path = os.path.join(pytorch_pretrained_root, 'ResNet', 'resnet18-5c106cde.pth')
5 |
6 |
7 |
--------------------------------------------------------------------------------
/resnext/resnext101.py:
--------------------------------------------------------------------------------
1 | import resnext_101_32x4d_
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from config import *
7 |
8 | def initialize_weights(*models):
9 | for model in models:
10 | for module in model.modules():
11 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
12 | nn.init.kaiming_normal(module.weight)
13 | if module.bias is not None:
14 | module.bias.data.zero_()
15 | elif isinstance(module, nn.BatchNorm2d):
16 | module.weight.data.fill_(1)
17 | module.bias.data.zero_()
18 |
19 |
20 | class ResNeXt101(nn.Module):
21 | def __init__(self):
22 | super(ResNeXt101, self).__init__()
23 | net = resnext_101_32x4d_.resnext_101_32x4d
24 | net.load_state_dict(torch.load(resnext101_32_path))
25 |
26 | net = list(net.children())
27 | self.layer0 = nn.Sequential(*net[:4])
28 | self.layer1 = net[4]
29 | self.layer2 = net[5]
30 | self.layer3 = net[6]
31 | self.layer4 = net[7]
32 | net.pop()
33 | self.view = net.pop()
34 |
35 | def forward(self, x):
36 | layer0 = self.layer0(x)
37 | layer1 = self.layer1(layer0)
38 | layer2 = self.layer2(layer1)
39 | layer3 = self.layer3(layer2)
40 | layer4 = self.layer4(layer3)
41 | return layer4
42 |
43 |
44 | class D4Net(nn.Module):
45 | def __init__(self, num_cls):
46 | super(D4Net, self).__init__()
47 | net = resnext_101_32x4d_.resnext_101_32x4d
48 | net.load_state_dict(torch.load(resnext101_32_path))
49 |
50 | net = list(net.children())
51 | self.layer0 = nn.Sequential(*net[:4])
52 | self.layer1 = net[4]
53 | self.layer2 = net[5]
54 | self.layer3 = net[6]
55 | self.layer4 = net[7]
56 | self.max_pool = nn.AdaptiveMaxPool2d(1)
57 |
58 | net.pop()
59 | view = net.pop()
60 | self.classifier = nn.Sequential(view, nn.Linear(2048, num_cls))
61 |
62 | def forward(self, x1, x2):
63 | layer0_x1 = self.layer0(x1)
64 | layer1_x1 = self.layer1(layer0_x1)
65 | layer2_x1 = self.layer2(layer1_x1)
66 | layer3_x1 = self.layer3(layer2_x1)
67 | layer4_x1 = self.layer4(layer3_x1)
68 |
69 | layer0_x2 = self.layer0(x2)
70 | layer1_x2 = self.layer1(layer0_x2)
71 | layer2_x2 = self.layer2(layer1_x2)
72 | layer3_x2 = self.layer3(layer2_x2)
73 | layer4_x2 = self.layer4(layer3_x2)
74 |
75 | difference = layer4_x2 - layer4_x1
76 | difference = self.max_pool(difference)
77 | output = self.classifier(difference)
78 |
79 |         return output, -self.max_pool(-F.cosine_similarity(layer4_x2, layer4_x1))  # minimum spatial cosine similarity (max pool of the negated map)
80 |
--------------------------------------------------------------------------------
/resnext/resnext_101_32x4d_.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from functools import reduce
5 |
6 |
7 | class LambdaBase(nn.Sequential):
8 | def __init__(self, fn, *args):
9 | super(LambdaBase, self).__init__(*args)
10 | self.lambda_func = fn
11 |
12 | def forward_prepare(self, input):
13 | output = []
14 | for module in self._modules.values():
15 | output.append(module(input))
16 | return output if output else input
17 |
18 |
19 | class Lambda(LambdaBase):
20 | def forward(self, input):
21 | return self.lambda_func(self.forward_prepare(input))
22 |
23 |
24 | class LambdaMap(LambdaBase):
25 | def forward(self, input):
26 | return list(map(self.lambda_func, self.forward_prepare(input)))
27 |
28 |
29 | class LambdaReduce(LambdaBase):
30 | def forward(self, input):
31 | return reduce(self.lambda_func, self.forward_prepare(input))
32 |
33 |
34 | resnext_101_32x4d = nn.Sequential( # Sequential,
35 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
36 | nn.BatchNorm2d(64),
37 | nn.ReLU(),
38 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
39 | nn.Sequential( # Sequential,
40 | nn.Sequential( # Sequential,
41 | LambdaMap(lambda x: x, # ConcatTable,
42 | nn.Sequential( # Sequential,
43 | nn.Sequential( # Sequential,
44 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
45 | nn.BatchNorm2d(128),
46 | nn.ReLU(),
47 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
48 | nn.BatchNorm2d(128),
49 | nn.ReLU(),
50 | ),
51 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
52 | nn.BatchNorm2d(256),
53 | ),
54 | nn.Sequential( # Sequential,
55 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
56 | nn.BatchNorm2d(256),
57 | ),
58 | ),
59 | LambdaReduce(lambda x, y: x + y), # CAddTable,
60 | nn.ReLU(),
61 | ),
62 | nn.Sequential( # Sequential,
63 | LambdaMap(lambda x: x, # ConcatTable,
64 | nn.Sequential( # Sequential,
65 | nn.Sequential( # Sequential,
66 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
67 | nn.BatchNorm2d(128),
68 | nn.ReLU(),
69 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
70 | nn.BatchNorm2d(128),
71 | nn.ReLU(),
72 | ),
73 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
74 | nn.BatchNorm2d(256),
75 | ),
76 | Lambda(lambda x: x), # Identity,
77 | ),
78 | LambdaReduce(lambda x, y: x + y), # CAddTable,
79 | nn.ReLU(),
80 | ),
81 | nn.Sequential( # Sequential,
82 | LambdaMap(lambda x: x, # ConcatTable,
83 | nn.Sequential( # Sequential,
84 | nn.Sequential( # Sequential,
85 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
86 | nn.BatchNorm2d(128),
87 | nn.ReLU(),
88 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
89 | nn.BatchNorm2d(128),
90 | nn.ReLU(),
91 | ),
92 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
93 | nn.BatchNorm2d(256),
94 | ),
95 | Lambda(lambda x: x), # Identity,
96 | ),
97 | LambdaReduce(lambda x, y: x + y), # CAddTable,
98 | nn.ReLU(),
99 | ),
100 | ),
101 | nn.Sequential( # Sequential,
102 | nn.Sequential( # Sequential,
103 | LambdaMap(lambda x: x, # ConcatTable,
104 | nn.Sequential( # Sequential,
105 | nn.Sequential( # Sequential,
106 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
107 | nn.BatchNorm2d(256),
108 | nn.ReLU(),
109 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
110 | nn.BatchNorm2d(256),
111 | nn.ReLU(),
112 | ),
113 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
114 | nn.BatchNorm2d(512),
115 | ),
116 | nn.Sequential( # Sequential,
117 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
118 | nn.BatchNorm2d(512),
119 | ),
120 | ),
121 | LambdaReduce(lambda x, y: x + y), # CAddTable,
122 | nn.ReLU(),
123 | ),
124 | nn.Sequential( # Sequential,
125 | LambdaMap(lambda x: x, # ConcatTable,
126 | nn.Sequential( # Sequential,
127 | nn.Sequential( # Sequential,
128 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
129 | nn.BatchNorm2d(256),
130 | nn.ReLU(),
131 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
132 | nn.BatchNorm2d(256),
133 | nn.ReLU(),
134 | ),
135 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
136 | nn.BatchNorm2d(512),
137 | ),
138 | Lambda(lambda x: x), # Identity,
139 | ),
140 | LambdaReduce(lambda x, y: x + y), # CAddTable,
141 | nn.ReLU(),
142 | ),
143 | nn.Sequential( # Sequential,
144 | LambdaMap(lambda x: x, # ConcatTable,
145 | nn.Sequential( # Sequential,
146 | nn.Sequential( # Sequential,
147 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
148 | nn.BatchNorm2d(256),
149 | nn.ReLU(),
150 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
151 | nn.BatchNorm2d(256),
152 | nn.ReLU(),
153 | ),
154 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
155 | nn.BatchNorm2d(512),
156 | ),
157 | Lambda(lambda x: x), # Identity,
158 | ),
159 | LambdaReduce(lambda x, y: x + y), # CAddTable,
160 | nn.ReLU(),
161 | ),
162 | nn.Sequential( # Sequential,
163 | LambdaMap(lambda x: x, # ConcatTable,
164 | nn.Sequential( # Sequential,
165 | nn.Sequential( # Sequential,
166 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
167 | nn.BatchNorm2d(256),
168 | nn.ReLU(),
169 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
170 | nn.BatchNorm2d(256),
171 | nn.ReLU(),
172 | ),
173 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
174 | nn.BatchNorm2d(512),
175 | ),
176 | Lambda(lambda x: x), # Identity,
177 | ),
178 | LambdaReduce(lambda x, y: x + y), # CAddTable,
179 | nn.ReLU(),
180 | ),
181 | ),
182 | nn.Sequential( # Sequential,
183 | nn.Sequential( # Sequential,
184 | LambdaMap(lambda x: x, # ConcatTable,
185 | nn.Sequential( # Sequential,
186 | nn.Sequential( # Sequential,
187 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
188 | nn.BatchNorm2d(512),
189 | nn.ReLU(),
190 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
191 | nn.BatchNorm2d(512),
192 | nn.ReLU(),
193 | ),
194 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
195 | nn.BatchNorm2d(1024),
196 | ),
197 | nn.Sequential( # Sequential,
198 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
199 | nn.BatchNorm2d(1024),
200 | ),
201 | ),
202 | LambdaReduce(lambda x, y: x + y), # CAddTable,
203 | nn.ReLU(),
204 | ),
205 | nn.Sequential( # Sequential,
206 | LambdaMap(lambda x: x, # ConcatTable,
207 | nn.Sequential( # Sequential,
208 | nn.Sequential( # Sequential,
209 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
210 | nn.BatchNorm2d(512),
211 | nn.ReLU(),
212 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
213 | nn.BatchNorm2d(512),
214 | nn.ReLU(),
215 | ),
216 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
217 | nn.BatchNorm2d(1024),
218 | ),
219 | Lambda(lambda x: x), # Identity,
220 | ),
221 | LambdaReduce(lambda x, y: x + y), # CAddTable,
222 | nn.ReLU(),
223 | ),
224 | nn.Sequential( # Sequential,
225 | LambdaMap(lambda x: x, # ConcatTable,
226 | nn.Sequential( # Sequential,
227 | nn.Sequential( # Sequential,
228 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
229 | nn.BatchNorm2d(512),
230 | nn.ReLU(),
231 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
232 | nn.BatchNorm2d(512),
233 | nn.ReLU(),
234 | ),
235 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
236 | nn.BatchNorm2d(1024),
237 | ),
238 | Lambda(lambda x: x), # Identity,
239 | ),
240 | LambdaReduce(lambda x, y: x + y), # CAddTable,
241 | nn.ReLU(),
242 | ),
243 | nn.Sequential( # Sequential,
244 | LambdaMap(lambda x: x, # ConcatTable,
245 | nn.Sequential( # Sequential,
246 | nn.Sequential( # Sequential,
247 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
248 | nn.BatchNorm2d(512),
249 | nn.ReLU(),
250 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
251 | nn.BatchNorm2d(512),
252 | nn.ReLU(),
253 | ),
254 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
255 | nn.BatchNorm2d(1024),
256 | ),
257 | Lambda(lambda x: x), # Identity,
258 | ),
259 | LambdaReduce(lambda x, y: x + y), # CAddTable,
260 | nn.ReLU(),
261 | ),
262 | nn.Sequential( # Sequential,
263 | LambdaMap(lambda x: x, # ConcatTable,
264 | nn.Sequential( # Sequential,
265 | nn.Sequential( # Sequential,
266 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
267 | nn.BatchNorm2d(512),
268 | nn.ReLU(),
269 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
270 | nn.BatchNorm2d(512),
271 | nn.ReLU(),
272 | ),
273 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
274 | nn.BatchNorm2d(1024),
275 | ),
276 | Lambda(lambda x: x), # Identity,
277 | ),
278 | LambdaReduce(lambda x, y: x + y), # CAddTable,
279 | nn.ReLU(),
280 | ),
281 | nn.Sequential( # Sequential,
282 | LambdaMap(lambda x: x, # ConcatTable,
283 | nn.Sequential( # Sequential,
284 | nn.Sequential( # Sequential,
285 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
286 | nn.BatchNorm2d(512),
287 | nn.ReLU(),
288 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
289 | nn.BatchNorm2d(512),
290 | nn.ReLU(),
291 | ),
292 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
293 | nn.BatchNorm2d(1024),
294 | ),
295 | Lambda(lambda x: x), # Identity,
296 | ),
297 | LambdaReduce(lambda x, y: x + y), # CAddTable,
298 | nn.ReLU(),
299 | ),
300 | nn.Sequential( # Sequential,
301 | LambdaMap(lambda x: x, # ConcatTable,
302 | nn.Sequential( # Sequential,
303 | nn.Sequential( # Sequential,
304 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
305 | nn.BatchNorm2d(512),
306 | nn.ReLU(),
307 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
308 | nn.BatchNorm2d(512),
309 | nn.ReLU(),
310 | ),
311 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
312 | nn.BatchNorm2d(1024),
313 | ),
314 | Lambda(lambda x: x), # Identity,
315 | ),
316 | LambdaReduce(lambda x, y: x + y), # CAddTable,
317 | nn.ReLU(),
318 | ),
319 | nn.Sequential( # Sequential,
320 | LambdaMap(lambda x: x, # ConcatTable,
321 | nn.Sequential( # Sequential,
322 | nn.Sequential( # Sequential,
323 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
324 | nn.BatchNorm2d(512),
325 | nn.ReLU(),
326 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
327 | nn.BatchNorm2d(512),
328 | nn.ReLU(),
329 | ),
330 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
331 | nn.BatchNorm2d(1024),
332 | ),
333 | Lambda(lambda x: x), # Identity,
334 | ),
335 | LambdaReduce(lambda x, y: x + y), # CAddTable,
336 | nn.ReLU(),
337 | ),
338 | nn.Sequential( # Sequential,
339 | LambdaMap(lambda x: x, # ConcatTable,
340 | nn.Sequential( # Sequential,
341 | nn.Sequential( # Sequential,
342 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
343 | nn.BatchNorm2d(512),
344 | nn.ReLU(),
345 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
346 | nn.BatchNorm2d(512),
347 | nn.ReLU(),
348 | ),
349 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
350 | nn.BatchNorm2d(1024),
351 | ),
352 | Lambda(lambda x: x), # Identity,
353 | ),
354 | LambdaReduce(lambda x, y: x + y), # CAddTable,
355 | nn.ReLU(),
356 | ),
357 | nn.Sequential( # Sequential,
358 | LambdaMap(lambda x: x, # ConcatTable,
359 | nn.Sequential( # Sequential,
360 | nn.Sequential( # Sequential,
361 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
362 | nn.BatchNorm2d(512),
363 | nn.ReLU(),
364 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
365 | nn.BatchNorm2d(512),
366 | nn.ReLU(),
367 | ),
368 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
369 | nn.BatchNorm2d(1024),
370 | ),
371 | Lambda(lambda x: x), # Identity,
372 | ),
373 | LambdaReduce(lambda x, y: x + y), # CAddTable,
374 | nn.ReLU(),
375 | ),
376 | nn.Sequential( # Sequential,
377 | LambdaMap(lambda x: x, # ConcatTable,
378 | nn.Sequential( # Sequential,
379 | nn.Sequential( # Sequential,
380 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
381 | nn.BatchNorm2d(512),
382 | nn.ReLU(),
383 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
384 | nn.BatchNorm2d(512),
385 | nn.ReLU(),
386 | ),
387 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
388 | nn.BatchNorm2d(1024),
389 | ),
390 | Lambda(lambda x: x), # Identity,
391 | ),
392 | LambdaReduce(lambda x, y: x + y), # CAddTable,
393 | nn.ReLU(),
394 | ),
395 | nn.Sequential( # Sequential,
396 | LambdaMap(lambda x: x, # ConcatTable,
397 | nn.Sequential( # Sequential,
398 | nn.Sequential( # Sequential,
399 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
400 | nn.BatchNorm2d(512),
401 | nn.ReLU(),
402 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
403 | nn.BatchNorm2d(512),
404 | nn.ReLU(),
405 | ),
406 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
407 | nn.BatchNorm2d(1024),
408 | ),
409 | Lambda(lambda x: x), # Identity,
410 | ),
411 | LambdaReduce(lambda x, y: x + y), # CAddTable,
412 | nn.ReLU(),
413 | ),
414 | nn.Sequential( # Sequential,
415 | LambdaMap(lambda x: x, # ConcatTable,
416 | nn.Sequential( # Sequential,
417 | nn.Sequential( # Sequential,
418 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
419 | nn.BatchNorm2d(512),
420 | nn.ReLU(),
421 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
422 | nn.BatchNorm2d(512),
423 | nn.ReLU(),
424 | ),
425 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
426 | nn.BatchNorm2d(1024),
427 | ),
428 | Lambda(lambda x: x), # Identity,
429 | ),
430 | LambdaReduce(lambda x, y: x + y), # CAddTable,
431 | nn.ReLU(),
432 | ),
433 | nn.Sequential( # Sequential,
434 | LambdaMap(lambda x: x, # ConcatTable,
435 | nn.Sequential( # Sequential,
436 | nn.Sequential( # Sequential,
437 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
438 | nn.BatchNorm2d(512),
439 | nn.ReLU(),
440 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
441 | nn.BatchNorm2d(512),
442 | nn.ReLU(),
443 | ),
444 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
445 | nn.BatchNorm2d(1024),
446 | ),
447 | Lambda(lambda x: x), # Identity,
448 | ),
449 | LambdaReduce(lambda x, y: x + y), # CAddTable,
450 | nn.ReLU(),
451 | ),
452 | nn.Sequential( # Sequential,
453 | LambdaMap(lambda x: x, # ConcatTable,
454 | nn.Sequential( # Sequential,
455 | nn.Sequential( # Sequential,
456 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
457 | nn.BatchNorm2d(512),
458 | nn.ReLU(),
459 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
460 | nn.BatchNorm2d(512),
461 | nn.ReLU(),
462 | ),
463 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
464 | nn.BatchNorm2d(1024),
465 | ),
466 | Lambda(lambda x: x), # Identity,
467 | ),
468 | LambdaReduce(lambda x, y: x + y), # CAddTable,
469 | nn.ReLU(),
470 | ),
471 | nn.Sequential( # Sequential,
472 | LambdaMap(lambda x: x, # ConcatTable,
473 | nn.Sequential( # Sequential,
474 | nn.Sequential( # Sequential,
475 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
476 | nn.BatchNorm2d(512),
477 | nn.ReLU(),
478 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
479 | nn.BatchNorm2d(512),
480 | nn.ReLU(),
481 | ),
482 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
483 | nn.BatchNorm2d(1024),
484 | ),
485 | Lambda(lambda x: x), # Identity,
486 | ),
487 | LambdaReduce(lambda x, y: x + y), # CAddTable,
488 | nn.ReLU(),
489 | ),
490 | nn.Sequential( # Sequential,
491 | LambdaMap(lambda x: x, # ConcatTable,
492 | nn.Sequential( # Sequential,
493 | nn.Sequential( # Sequential,
494 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
495 | nn.BatchNorm2d(512),
496 | nn.ReLU(),
497 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
498 | nn.BatchNorm2d(512),
499 | nn.ReLU(),
500 | ),
501 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
502 | nn.BatchNorm2d(1024),
503 | ),
504 | Lambda(lambda x: x), # Identity,
505 | ),
506 | LambdaReduce(lambda x, y: x + y), # CAddTable,
507 | nn.ReLU(),
508 | ),
509 | nn.Sequential( # Sequential,
510 | LambdaMap(lambda x: x, # ConcatTable,
511 | nn.Sequential( # Sequential,
512 | nn.Sequential( # Sequential,
513 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
514 | nn.BatchNorm2d(512),
515 | nn.ReLU(),
516 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
517 | nn.BatchNorm2d(512),
518 | nn.ReLU(),
519 | ),
520 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
521 | nn.BatchNorm2d(1024),
522 | ),
523 | Lambda(lambda x: x), # Identity,
524 | ),
525 | LambdaReduce(lambda x, y: x + y), # CAddTable,
526 | nn.ReLU(),
527 | ),
528 | nn.Sequential( # Sequential,
529 | LambdaMap(lambda x: x, # ConcatTable,
530 | nn.Sequential( # Sequential,
531 | nn.Sequential( # Sequential,
532 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
533 | nn.BatchNorm2d(512),
534 | nn.ReLU(),
535 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
536 | nn.BatchNorm2d(512),
537 | nn.ReLU(),
538 | ),
539 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
540 | nn.BatchNorm2d(1024),
541 | ),
542 | Lambda(lambda x: x), # Identity,
543 | ),
544 | LambdaReduce(lambda x, y: x + y), # CAddTable,
545 | nn.ReLU(),
546 | ),
547 | nn.Sequential( # Sequential,
548 | LambdaMap(lambda x: x, # ConcatTable,
549 | nn.Sequential( # Sequential,
550 | nn.Sequential( # Sequential,
551 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
552 | nn.BatchNorm2d(512),
553 | nn.ReLU(),
554 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
555 | nn.BatchNorm2d(512),
556 | nn.ReLU(),
557 | ),
558 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
559 | nn.BatchNorm2d(1024),
560 | ),
561 | Lambda(lambda x: x), # Identity,
562 | ),
563 | LambdaReduce(lambda x, y: x + y), # CAddTable,
564 | nn.ReLU(),
565 | ),
566 | nn.Sequential( # Sequential,
567 | LambdaMap(lambda x: x, # ConcatTable,
568 | nn.Sequential( # Sequential,
569 | nn.Sequential( # Sequential,
570 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
571 | nn.BatchNorm2d(512),
572 | nn.ReLU(),
573 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
574 | nn.BatchNorm2d(512),
575 | nn.ReLU(),
576 | ),
577 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
578 | nn.BatchNorm2d(1024),
579 | ),
580 | Lambda(lambda x: x), # Identity,
581 | ),
582 | LambdaReduce(lambda x, y: x + y), # CAddTable,
583 | nn.ReLU(),
584 | ),
585 | nn.Sequential( # Sequential,
586 | LambdaMap(lambda x: x, # ConcatTable,
587 | nn.Sequential( # Sequential,
588 | nn.Sequential( # Sequential,
589 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
590 | nn.BatchNorm2d(512),
591 | nn.ReLU(),
592 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
593 | nn.BatchNorm2d(512),
594 | nn.ReLU(),
595 | ),
596 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
597 | nn.BatchNorm2d(1024),
598 | ),
599 | Lambda(lambda x: x), # Identity,
600 | ),
601 | LambdaReduce(lambda x, y: x + y), # CAddTable,
602 | nn.ReLU(),
603 | ),
604 | nn.Sequential( # Sequential,
605 | LambdaMap(lambda x: x, # ConcatTable,
606 | nn.Sequential( # Sequential,
607 | nn.Sequential( # Sequential,
608 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
609 | nn.BatchNorm2d(512),
610 | nn.ReLU(),
611 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
612 | nn.BatchNorm2d(512),
613 | nn.ReLU(),
614 | ),
615 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
616 | nn.BatchNorm2d(1024),
617 | ),
618 | Lambda(lambda x: x), # Identity,
619 | ),
620 | LambdaReduce(lambda x, y: x + y), # CAddTable,
621 | nn.ReLU(),
622 | ),
623 | ),
624 | nn.Sequential( # Sequential,
625 | nn.Sequential( # Sequential,
626 | LambdaMap(lambda x: x, # ConcatTable,
627 | nn.Sequential( # Sequential,
628 | nn.Sequential( # Sequential,
629 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
630 | nn.BatchNorm2d(1024),
631 | nn.ReLU(),
632 | nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
633 | nn.BatchNorm2d(1024),
634 | nn.ReLU(),
635 | ),
636 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
637 | nn.BatchNorm2d(2048),
638 | ),
639 | nn.Sequential( # Sequential,
640 | nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
641 | nn.BatchNorm2d(2048),
642 | ),
643 | ),
644 | LambdaReduce(lambda x, y: x + y), # CAddTable,
645 | nn.ReLU(),
646 | ),
647 | nn.Sequential( # Sequential,
648 | LambdaMap(lambda x: x, # ConcatTable,
649 | nn.Sequential( # Sequential,
650 | nn.Sequential( # Sequential,
651 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
652 | nn.BatchNorm2d(1024),
653 | nn.ReLU(),
654 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
655 | nn.BatchNorm2d(1024),
656 | nn.ReLU(),
657 | ),
658 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
659 | nn.BatchNorm2d(2048),
660 | ),
661 | Lambda(lambda x: x), # Identity,
662 | ),
663 | LambdaReduce(lambda x, y: x + y), # CAddTable,
664 | nn.ReLU(),
665 | ),
666 | nn.Sequential( # Sequential,
667 | LambdaMap(lambda x: x, # ConcatTable,
668 | nn.Sequential( # Sequential,
669 | nn.Sequential( # Sequential,
670 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
671 | nn.BatchNorm2d(1024),
672 | nn.ReLU(),
673 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
674 | nn.BatchNorm2d(1024),
675 | nn.ReLU(),
676 | ),
677 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
678 | nn.BatchNorm2d(2048),
679 | ),
680 | Lambda(lambda x: x), # Identity,
681 | ),
682 | LambdaReduce(lambda x, y: x + y), # CAddTable,
683 | nn.ReLU(),
684 | ),
685 | ),
686 | nn.AvgPool2d((7, 7), (1, 1)),
687 | Lambda(lambda x: x.view(x.size(0), -1)), # View,
688 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear,
689 | )
690 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch
3 | import random
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from torch import optim, nn
7 | from torch.autograd import Variable
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 |
11 | from config import *
12 | from resnext import D4Net
13 | from util.folder_classify import PairLoader
14 | import json
15 | from util.misc import AvgMeter, check_mkdir, cal_accuracy
16 |
17 |
18 | ckpt_path = './ckpt'
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20 |
21 | args = {
22 | 'exp_name': 'd4net',
23 | 'margin1': 0.2,
24 | 'margin2': 0.5,
25 | 'iter_num': 20000,
26 | 'train_batch_size': 16,
27 | 'val_batch_size': 160,
28 | 'last_iter': 0,
29 | 'val_print_freq': 1000,
30 | 'lr': 1e-3,
31 | 'lr_decay': 0.9,
32 | 'weight_decay': 5e-4,
33 | 'momentum': 0.9,
34 | 'resize_size': (224, 224),
35 | 'snapshot': ''
36 | }
37 |
38 | input_size, _ = args['resize_size']
39 |
40 | train_transform = transforms.Compose([
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ColorJitter(0.1),
43 | transforms.Resize(args['resize_size']),
44 | transforms.ToTensor(),
45 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
46 | ])
47 |
48 | val_transform = transforms.Compose([
49 | transforms.Resize(args['resize_size']),
50 | transforms.ToTensor(),
51 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
52 | ])
53 |
54 | # mode, transform
55 |
56 | train_set = PairLoader(mode='train', transform=train_transform)
57 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], shuffle=True, num_workers=8, drop_last=True)
58 |
59 | val_set = PairLoader(mode='val', transform=val_transform)
60 | val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=12)
61 |
62 | criterion = nn.CrossEntropyLoss().cuda()
63 |
64 | def main():
65 | try:
66 | train(args['exp_name'])
67 |     except Exception:
68 |         import traceback; traceback.print_exc()  # report the full traceback instead of swallowing the error
69 |
70 |
71 | def train(exp_name):
72 |
73 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
74 |
75 | net = D4Net(num_cls=2)
76 | net.cuda().train()
77 |
78 | best_record = {}
79 |
80 | optimizer = optim.SGD([
81 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
82 | 'lr': 2 * args['lr']},
83 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
84 | 'lr': args['lr'], 'weight_decay': args['weight_decay']}
85 | ], momentum=args['momentum'])
86 |
87 | if len(args['snapshot']) == 0:
88 | best_record['loss'] = 1e10
89 | best_record['iter'] = 0
90 | best_record['lr'] = args['lr']
91 | else:
92 | print('training resumes from \'%s\'' % args['snapshot'])
93 | net.load_state_dict(torch.load(os.path.join(ckpt_path, args['exp_name'], args['snapshot'] + '.pth')))
94 | optimizer.load_state_dict(
95 | torch.load(os.path.join(ckpt_path, args['exp_name'], args['snapshot'] + '_optim.pth')))
96 | optimizer.param_groups[0]['lr'] = 2 * args['lr']
97 | optimizer.param_groups[1]['lr'] = args['lr']
98 |
99 | best_record['loss'] = 1e10
100 | best_record['iter'] = 0
101 | best_record['lr'] = args['lr']
102 |
103 | check_mkdir(ckpt_path)
104 | check_mkdir(os.path.join(ckpt_path, args['exp_name']))
105 | with open(log_path, 'w') as f:
106 | f.write(str(args) + '\n\n')
107 | print 'start to train'
108 |
109 | curr_iter = args['last_iter']
110 |
111 | while True:
112 | class_loss_meter = AvgMeter()
113 | margin_loss_meter = AvgMeter()
114 | margin_loss2_meter = AvgMeter()
115 | acc_meter = AvgMeter()
116 |
117 | for i, data in enumerate(train_loader):
118 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
119 | ) ** args['lr_decay']
120 | optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
121 | ) ** args['lr_decay']
122 |
123 | test_images, ref_images, labels = data
124 | batch_size = labels.size(0)
125 |
126 | test_images = Variable(test_images).cuda()
127 | ref_images = Variable(ref_images).cuda()
128 | labels = Variable(labels).cuda()
129 |
130 | optimizer.zero_grad()
131 |
132 | outputs, cos_similarity = net(test_images, ref_images)
133 |
134 | class_loss = criterion(outputs, labels)
135 |
136 |             pos_number = cos_similarity[labels == 1].size(0)
137 |
138 |             # the margin losses need both defect-free (label 1) and defective (label 0) pairs in the batch
139 |             if 0 < pos_number < batch_size:
140 |                 pos_cos_similarity = torch.min(cos_similarity[labels == 1])  # hardest (least similar) defect-free pair
141 |                 neg_cos_similarity = torch.max(cos_similarity[labels == 0])  # hardest (most similar) defective pair
142 |
143 |                 # keep the two hardest pairs at least margin2 apart, and the defect-free similarity within margin1 of 1
144 |                 margin_loss = torch.max(Variable(torch.zeros_like(neg_cos_similarity)).cuda(), -torch.abs(neg_cos_similarity - pos_cos_similarity) + args['margin2'])
145 |                 margin2_loss = torch.max(Variable(torch.zeros_like(neg_cos_similarity)).cuda(), torch.abs(1 - pos_cos_similarity) - args['margin1'])
146 |                 loss = margin_loss + class_loss + margin2_loss
147 |                 margin_loss_meter.update(margin_loss.item())
148 |                 margin_loss2_meter.update(margin2_loss.item())
149 | else:
150 | loss = class_loss
151 |
152 | loss.backward()
153 | optimizer.step()
154 |
155 | accuracy = cal_accuracy(outputs.data, labels.data)
156 | acc_meter.update(accuracy[0], batch_size)
157 | class_loss_meter.update(class_loss.item(), batch_size)
158 |
159 |             log = '[iter %d], [class_loss %.4f], [margin_loss %.4f], [margin2_loss %.4f], [train_acc %.4f], [lr %.8f]' % (
160 |                 curr_iter + 1, class_loss_meter.avg, margin_loss_meter.avg, margin_loss2_meter.avg, acc_meter.avg, optimizer.param_groups[1]['lr'])
161 | print log
162 | with open(log_path, 'a') as f:
163 | f.write(log + '\n')
164 |
165 | if (curr_iter + 1) % args['val_print_freq'] == 0:
166 |
167 | val_loss, val_acc = validate(net)
168 | log = 'iter_%d_loss_%.4f_valacc_%.4f_lr_%.8f' % (
169 | curr_iter + 1, val_loss, val_acc, optimizer.param_groups[1]['lr'])
170 | print '--------------------------------------------------------------------------------'
171 | print log
172 |
173 | if best_record['loss'] > val_loss:
174 | best_record['loss'] = val_loss
175 | best_record['valacc'] = val_acc
176 | best_record['iter'] = curr_iter
177 | best_record['lr'] = optimizer.param_groups[1]['lr']
178 |
179 | torch.save(net.state_dict(), os.path.join(ckpt_path, args['exp_name'], log + '.pth'))
180 | torch.save(optimizer.state_dict(),
181 | os.path.join(ckpt_path, args['exp_name'], log + '_optim.pth'))
182 |
183 | with open(os.path.join(ckpt_path, args['exp_name'] + ".txt"), "a") as f:
184 | f.write(log + '\n')
185 |
186 | print '[best]: [iter %d], [val_loss %.4f], [val_acc %.4f], [lr %.8f]' % (
187 | best_record['iter'] + 1, best_record['loss'], best_record['valacc'], best_record['lr']
188 | )
189 | print '--------------------------------------------------------------------------------'
190 |
191 |             curr_iter += 1
192 |             if curr_iter > args['iter_num']:
193 | return
194 |
195 |
196 | def validate(net):
197 | print 'validating...'
198 | net.eval()
199 |
200 | val_length = len(val_loader)
201 | with torch.no_grad():
202 | loss_meter = AvgMeter()
203 | acc_meter = AvgMeter()
204 | for vi, data in enumerate(val_loader, 0):
205 | print '%d/%d' % (vi+1, val_length)
206 | test_images, ref_images, labels = data
207 |
208 | val_batch_size = labels.size(0)
209 |
210 | test_images = Variable(test_images).cuda()
211 | ref_images = Variable(ref_images).cuda()
212 | labels = Variable(labels).cuda()
213 |
214 | outputs, _ = net(test_images, ref_images)
215 | loss = criterion(outputs, labels)
216 | loss_meter.update(loss.item(), val_batch_size)
217 |
218 | accuracy = cal_accuracy(outputs.data, labels.data)
219 | acc_meter.update(accuracy[0], val_batch_size)
220 |
221 | net.train()
222 | return loss_meter.avg, acc_meter.avg
223 |
224 |
225 | if __name__ == '__main__':
226 | main()
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubBingoChen/D4Net/a9c07679fb66e3101e804aacd388508e52ab49aa/util/__init__.py
--------------------------------------------------------------------------------
/util/folder_classify.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import scipy.io as sio
7 | from PIL import Image
8 | from torch.utils import data
9 | import json
10 | from config import data_root
11 |
12 |
13 |
14 |
15 | def make_dataset(mode):
16 | assert (mode in ['train', 'val'])
17 |
18 | dataset_path = os.path.join(data_root, mode)
19 |
20 | items = []
21 | for dir_name in os.listdir(dataset_path):
22 |
23 | lace_path = os.path.join(dataset_path, dir_name)
24 | if os.path.isdir(lace_path):
25 |
26 | json_path = os.path.join(lace_path, 'pair_infos.json')
27 | if not os.path.isfile(json_path):
28 | continue
29 | h_json = open(json_path, 'r')
30 | f_json = json.load(h_json)
31 |
32 | img_list = [os.path.splitext(f)[0] for f in os.listdir(lace_path) if 'img' in f]
33 |
34 | for img_name in img_list:
35 | img_path1 = os.path.join(lace_path, img_name + '.JPG')
36 | img_path2 = os.path.join(lace_path, 'ref' + img_name[3:] + '.JPG')
37 |
38 | if (not os.path.isfile(img_path1)) or (not os.path.isfile(img_path2)) or (img_name not in f_json.keys()):
39 | continue
40 |
41 | item = (dir_name, img_path1, img_path2, f_json[img_name]['gt'])
42 | items.append(item)
43 |
44 | h_json.close()
45 |
46 | print(mode, ' dataset filenum: ', len(items))
47 | return items
48 |
49 |
50 | class PairLoader(data.Dataset):
51 | def __init__(self, mode, transform):
52 | self.imgs = make_dataset(mode)
53 | if len(self.imgs) == 0:
54 | raise (RuntimeError('Found 0 images, please check the data set'))
55 | self.mode = mode
56 | self.transform = transform
57 |
58 | def __getitem__(self, index):
59 | dir_name, img_path1, img_path2, gt = self.imgs[index]
60 |
61 | target_ori = Image.open(img_path1).convert('RGB')
62 | ref_ori = Image.open(img_path2).convert('RGB')
63 |
64 | if self.transform is not None:
65 | target_img = self.transform(target_ori)
66 | ref_img = self.transform(ref_ori)
67 |
68 |         # binarize the label: 1 = defect-free, anything else = defective
69 |         if gt == 1:
70 |             gt = 1
71 |         else:
72 |             gt = 0
73 |
74 | return target_img, ref_img, gt
75 |
76 | def __len__(self):
77 | return len(self.imgs)
78 |
79 |
--------------------------------------------------------------------------------
/util/folder_mvtec.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import scipy.io as sio
7 | from PIL import Image
8 | from torch.utils import data
9 | import json
10 |
11 |
12 | input_root = '/media/b3-542/0DFD5CD11721CA55/mvtec_anomaly_detection'
13 |
14 |
15 | def make_dataset():
16 |
17 |
18 | items = []
19 | dirs = [f for f in os.listdir(input_root) if os.path.isdir(os.path.join(input_root, f))]
20 |     for dir_name in sorted(dirs)[3:4]:  # NOTE: only a single MVTec category is selected here
21 | # for dir_name in sorted(os.listdir(input_root))[1:2]:
22 | print(dir_name)
23 | # os.path.join(input_root, dir_name, 'test')
24 | img_root = os.path.join(input_root, dir_name, 'test')
25 |
26 | json_path = os.path.join(input_root, '%s.json' % (dir_name))
27 | if not os.path.isfile(json_path):
28 | continue
29 | h_json = open(json_path, 'r')
30 | f_json = json.load(h_json)
31 |
32 | img_list = []
33 | for img_dir in sorted(os.listdir(img_root)):
34 | # if 'good' not in img_dir:
35 | if 'good' in img_dir:
36 | continue
37 | for img_name in os.listdir(os.path.join(img_root, img_dir)):
38 | img_list.append(os.path.join(img_root, img_dir, img_name))
39 |
40 | img_list = sorted(img_list)
41 | for img_name in img_list:
42 | img_path1 = img_name
43 | # if img_path1 in delete_list:
44 | # continue
45 |             if img_name not in f_json:  # look up the reference image first to avoid a KeyError
46 |                 continue
47 |             img_path2 = f_json[img_name]
48 |             if (not os.path.isfile(img_path1)) or (not os.path.isfile(img_path2)):
49 |                 continue
50 | gt = 0
51 | if 'good' in img_path1:
52 | gt = 1
53 |
54 | item = (img_path1, img_path2, gt)
55 | items.append(item)
56 |
57 | h_json.close()
58 |
59 | print('dataset filenum: ', len(items))
60 | return items
61 |
62 |
63 | class PairLoader(data.Dataset):
64 | def __init__(self, transform):
65 | self.imgs = make_dataset()
66 | if len(self.imgs) == 0:
67 | raise (RuntimeError('Found 0 images, please check the data set'))
68 | self.transform = transform
69 |
70 | def __getitem__(self, index):
71 | img_path1, img_path2, gt = self.imgs[index]
72 |
73 | target_ori = Image.open(img_path1).convert('RGB')
74 | ref_ori = Image.open(img_path2).convert('RGB')
75 |
76 | if self.transform is not None:
77 | target_img = self.transform(target_ori)
78 | ref_img = self.transform(ref_ori)
79 |
80 | gt = int(gt)
81 |
82 | return target_img, ref_img, gt
83 |
84 | def __len__(self):
85 | return len(self.imgs)
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 | from skimage.measure import compare_psnr, compare_ssim
5 | import os
6 | from torch import nn
7 | from torch.autograd import Variable
8 | from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score, precision_recall_curve
9 |
10 |
11 |
12 | class AvgMeter(object):
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def check_mkdir(dir_name):
30 | if not os.path.exists(dir_name):
31 | os.mkdir(dir_name)
32 |
33 |
34 |
35 |
36 | # def calc_psnr(im1, im2):
37 | # im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
38 | # im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
39 | # return compare_psnr(im1_y, im2_y)
40 |
41 | def calc_ssim(im1, im2):
42 | # im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
43 | im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
44 | # im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
45 | im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)
46 | return compare_ssim(im1_y, im2_y)
47 |
48 |
49 | def cal_accuracy(output, target, topk=(1,)):
50 | """Computes the precision@k for the specified values of k"""
51 | maxk = max(topk)
52 | batch_size = target.size(0)
53 |
54 | _, pred = output.topk(maxk, 1, True, True)
55 | pred = pred.t()
56 | correct = pred.eq(target.view(1, -1).expand_as(pred))
57 |
58 | res = []
59 | for k in topk:
60 | correct_k = correct[:k].view(-1).float().sum(0)
61 | res.append(correct_k.div_(batch_size).item())
62 | return res
63 |
64 |
65 | def cal_metric(output, target):
66 |     maxk = 1
67 |     _, pred = output.topk(maxk, 1, True, True)
68 |     # sklearn expects flat numpy arrays on the CPU
69 |     pred, target = pred.cpu().numpy().ravel(), target.cpu().numpy()
70 |     accuracy = accuracy_score(target, pred)
71 |     precision = precision_score(target, pred, pos_label=0)
72 |     recall = recall_score(target, pred, pos_label=0)
73 |     f1_value = f1_score(target, pred, pos_label=0)
74 |     return accuracy, precision, recall, f1_value
75 |
76 |
77 | def cal_pr_curve(pred, target):
78 | precision, recall, thresholds = precision_recall_curve(target, pred, pos_label=0)
79 | return precision, recall, thresholds
80 |
81 |
82 | def cal_metric_with_path(output, target, image_path):
83 | maxk = 1
84 | _, pred = output.topk(maxk, 1, True, True)
85 |     pred, target = pred.cpu().numpy().ravel(), target.cpu().numpy()  # flat numpy arrays for sklearn and the indexing below
86 | for i in range(len(pred)):
87 | if (target[i] == 0 and pred[i] != target[i]):
88 | print image_path[i]
89 | with open('failure_recall_error_record.txt', 'a') as f:
90 | f.write(image_path[i] + '\n')
91 |
92 | accuracy = accuracy_score(target, pred)
93 | precision = precision_score(target, pred, pos_label=0)
94 |
95 | recall = recall_score(target, pred, pos_label=0)
96 | f1_value = f1_score(target, pred, pos_label=0)
97 | return accuracy, precision, recall, f1_value
98 |
99 |
100 |
101 |
102 |
103 | def get_center_loss(centers, features, target, alpha):
104 | batch_size = target.size(0)
105 | features_dim = features.size(1)
106 | target_expand = target.view(batch_size, 1).expand(batch_size, features_dim)
107 | centers_var = Variable(centers)
108 | centers_batch = centers_var.gather(0, target_expand)
109 | criterion = nn.MSELoss()
110 | center_loss = criterion(features, centers_batch)
111 |
112 | diff = centers_batch - features
113 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True,
114 | return_counts=True)
115 | appear_times = torch.from_numpy(unique_count).gather(0, torch.from_numpy(unique_reverse))
116 | appear_times_expand = appear_times.view(-1, 1).expand(batch_size, features_dim).type(torch.FloatTensor)
117 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6)
118 | diff_cpu *= alpha
119 | for i in range(batch_size):
120 | centers[target.data[i]] -= diff_cpu[i].type(centers.type())
121 |
122 | return center_loss, centers
--------------------------------------------------------------------------------