.
├── README.md
├── dataset.py
├── evaluate.py
├── main.py
├── models
│   ├── __init__.py
│   ├── losses.py
│   └── models.py
├── train.py
└── utils
    ├── __init__.py
    ├── args.py
    ├── dataload.py
    ├── dataset_info.py
    ├── pdb.py
    ├── train_eval_utils.py
    └── visual.py
/README.md:
--------------------------------------------------------------------------------
1 | ### Environment Requirement
2 |
3 | + PyTorch >= 1.0.0
4 | + Python >= 3.6 (numpy, scipy, matplotlib, tqdm)
5 | + OpenCV == 3.4.5
6 | + Platform: Linux
7 |
8 | ### Get program
9 |
10 | ```git clone git@github.com:FunkyKoki/Look_At_Boundary_PyTorch.git```
11 |
12 | The program structure is as follows:
13 |
14 | ```
15 | .
16 | ├── dataset.py
17 | ├── evaluate.py
18 | ├── models
19 | │ ├── __init__.py
20 | │ ├── losses.py
21 | │ └── models.py
22 | ├── README.md
23 | ├── train.py
24 | ├── utils
25 | │ ├── args.py
26 | │ ├── dataload.py
27 | │ ├── dataset_info.py
28 | │ ├── __init__.py
29 | │ ├── pdb.py
30 | │ ├── train_eval_utils.py
31 | │ └── visual.py
32 | └── weights
33 | └── ckpts
34 | ```
35 |
36 | ### Dataset Preparation
37 |
38 | This program supports four popular facial landmark datasets: 300W, AFLW, COFW, and WFLW. The dataset folder structure is as follows:
39 |
40 | ```
41 | .
42 | ├── 300W
43 | │ ├── afw
44 | │ ├── helen
45 | │ │ ├── testset
46 | │ │ └── trainset
47 | │ ├── ibug
48 | │ ├── lfpw
49 | │ │ ├── testset
50 | │ │ └── trainset
51 | │ ├── test_imgs
52 | │ └── testset
53 | │ ├── 01_Indoor
54 | │ └── 02_Outdoor
55 | ├── AFLW
56 | │ ├── 0
57 | │ ├── 2
58 | │ └── 3
59 | ├── COFW
60 | │ ├── test_imgs
61 | │ └── train_imgs
62 | └── WFLW
63 | └── WFLW_images
64 | ├── 0--Parade
65 | ├── 10--People_Marching
66 | ├── 11--Meeting
67 | ├── 12--Group
68 | ├── 13--Interview
69 | ├── 14--Traffic
70 | ├── 15--Stock_Market
71 | ├── 16--Award_Ceremony
72 | ├── 17--Ceremony
73 | ├── 18--Concerts
74 | ├── 19--Couple
75 | ├── 1--Handshaking
76 | ├── 20--Family_Group
77 | ├── 21--Festival
78 | ├── 22--Picnic
79 | ```
80 |
81 | Tips: Pay attention to the ```test_imgs``` and ```testset``` folders in the 300W dataset: the ```test_imgs``` pictures are human faces from COFW re-annotated with 68 landmarks, which is why they are placed here. Some other details are written in readme.txt.
82 |
83 | The annotation files can be downloaded from https://pan.baidu.com/s/1hYFcz260IB0pMISbHbxoTg (extraction code: ```tuz9```). The annotation format is \[x1, y1, x2, y2, …, xn, yn, bboxleft, bboxtop, bboxright, bboxbottom, picH, picW, pic_route\]: the landmark coordinates, the bounding box position, the height and width of the initial picture, and the path of the picture, in that order.
84 |
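As a minimal sketch of how one record could be unpacked (assuming the fields are whitespace-separated, which is an assumption about the file layout rather than something the repo confirms):

```python
# Hypothetical parser for a single annotation record; field order follows the README above.
def parse_annotation(line):
    fields = line.split()                       # assumed whitespace-separated
    pic_route = fields[-1]                      # path of the picture
    *coords, left, top, right, bottom, pic_h, pic_w = map(float, fields[:-1])
    xs, ys = coords[0::2], coords[1::2]         # x1, x2, ... and y1, y2, ...
    return xs, ys, (left, top, right, bottom), (pic_h, pic_w), pic_route
```
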
85 | ### Model Evaluation
86 |
87 | The WFLW training model can be downloaded from https://pan.baidu.com/s/1tM3oJFUHmP4kJA7enXVLjA (extraction code: ```tbgi```); put it in the ```weights``` folder. This model was trained for 900 epochs.
88 |
89 | When evaluating, you can configure the parameters in utils/args.py or just set them from the terminal. For example, to evaluate on the ```Pose Testset``` normalized by ```inter_ocular``` distance:
90 |
91 | ```python evaluate.py --dataset WFLW --split pose --eval_epoch 900 --norm_way inter_ocular```
92 |
93 | ### Model Training
94 |
95 | Configure almost everything in utils/args.py, or set it from the terminal:
96 |
97 | ```python train.py --dataset WFLW --split train --loss_type L2```
98 |
99 | Tips: This program integrates ```Wingloss``` and ```Pose-based Data Balancing```; if you want to use them, just enable them ^_^.
100 |
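For reference, ```Wingloss``` here is the piecewise Wing loss implemented in models/losses.py; a minimal element-wise sketch with the repository defaults (w=10, epsilon=2):

```python
import math

# Element-wise Wing loss, mirroring models/losses.py (defaults: w=10, epsilon=2).
def wing(x, w=10.0, epsilon=2.0):
    C = w - w * math.log(1 + w / epsilon)   # constant that joins the two pieces at |x| = w
    t = abs(x)
    return w * math.log(1 + t / epsilon) if t < w else t - C
```
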
101 | ### In the end
102 |
103 | Fuck every LICENSE.
104 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from utils import args, get_annotations_list, get_item_from
3 |
4 |
5 | class GeneralDataset(data.Dataset):
6 |
7 | def __init__(self, dataset='WFLW', split='train'):
8 | self.dataset = dataset
9 | self.split = split
10 | self.list = get_annotations_list(dataset, split, ispdb=args.PDB)
11 |
12 | def __len__(self):
13 | return len(self.list)
14 |
15 | def __getitem__(self, item):
16 | return get_item_from(self.dataset, self.split, self.list[item])
17 |
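18 | # Example usage (a minimal sketch; the dataset/split names must be known to utils.dataset_info):
19 | #   trainset = GeneralDataset(dataset='WFLW', split='train')
20 | #   loader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)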
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import time
3 | import numpy as np
4 | from dataset import GeneralDataset
5 | from models import *
6 | from utils import *
7 |
8 |
9 | def evaluate(arg):
10 | devices = torch.device('cuda:'+arg.gpu_id)
11 | error_rate = []
12 | failure_count = 0
13 | max_threshold = arg.max_threshold
14 |
15 | testset = GeneralDataset(dataset=arg.dataset, split=arg.split)
16 | dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, pin_memory=True)
17 |
18 | print('***** Normal Evaluating *****')
19 | print('Evaluating parameters:\n' +
20 | '# Dataset: ' + arg.dataset + '\n' +
21 | '# Dataset split: ' + arg.split + '\n' +
22 | '# Epoch of the model: ' + str(arg.eval_epoch) + '\n' +
23 | '# Normalize way: ' + arg.norm_way + '\n' +
24 | '# Max threshold: ' + str(arg.max_threshold) + '\n')
25 |
26 | print('Loading network ...')
27 | estimator = Estimator(stacks=arg.hour_stack, msg_pass=arg.msg_pass)
28 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2*kp_num[arg.dataset])
29 | estimator = load_weights(estimator, arg.save_folder+'estimator_'+str(arg.eval_epoch)+'.pth', devices)
30 | regressor = load_weights(regressor, arg.save_folder+arg.dataset+'_regressor_'+str(arg.eval_epoch)+'.pth', devices)
31 | if arg.cuda:
32 | estimator = estimator.cuda(device=devices)
33 | regressor = regressor.cuda(device=devices)
34 | estimator.eval()
35 | regressor.eval()
36 | print('Loading network done!\nStart testing ...')
37 |
38 | time_records = []
39 | with torch.no_grad():
40 | for data in tqdm.tqdm(dataloader):
41 | start = time.time()
42 |
43 | input_images, gt_coords_xy, gt_heatmap, coords_xy, bbox, img_name = data
44 | gt_coords_xy = gt_coords_xy.squeeze().numpy()
45 | bbox = bbox.squeeze().numpy()
46 | error_normalize_factor = calc_normalize_factor(arg.dataset, coords_xy.numpy(), arg.norm_way) \
47 | if arg.norm_way in ['inter_pupil', 'inter_ocular'] else (bbox[2] - bbox[0])
48 | input_images = input_images.unsqueeze(1)
49 | input_images = input_images.cuda(device=devices) if arg.cuda else input_images
50 |
51 | pred_heatmaps = estimator(input_images)
52 | pred_coords = regressor(input_images, pred_heatmaps[-1].detach()).detach().cpu().squeeze().numpy()
53 | pred_coords_map_back = inverse_affine(arg, pred_coords, bbox)
54 |
55 | time_records.append(time.time() - start)
56 |
57 | error_rate_i = calc_error_rate_i(
58 | arg.dataset,
59 | pred_coords_map_back,
60 | coords_xy[0].numpy(),
61 | error_normalize_factor
62 | )
63 |
64 | if arg.eval_visual:
65 | eval_heatmap(arg, pred_heatmaps[-1], img_name, bbox, save_img=arg.save_img)
66 | eval_pred_points(arg, pred_coords, img_name, bbox, save_img=arg.save_img)
67 |
68 | failure_count = failure_count + 1 if error_rate_i > max_threshold else failure_count
69 | error_rate.append(error_rate_i)
70 |
71 | area_under_curve, auc_record = calc_auc(arg.dataset, arg.split, error_rate, max_threshold)
72 | error_rate = sum(error_rate) / dataset_size[arg.dataset][arg.split] * 100
73 | failure_rate = failure_count / dataset_size[arg.dataset][arg.split] * 100
74 |
75 | print('\nEvaluating results:\n# AUC: {:.4f}\n# Error Rate: {:.2f}%\n# Failure Rate: {:.2f}%\n'.format(
76 | area_under_curve, error_rate, failure_rate))
77 | print('Average speed: {:.2f}FPS'.format(1./np.mean(np.array(time_records))))
78 |
79 |
80 | def evaluate_with_gt_heatmap(arg):
81 | devices = torch.device('cuda:' + arg.gpu_id)
82 | error_rate = []
83 | failure_count = 0
84 | max_threshold = arg.max_threshold
85 |
86 | testset = GeneralDataset(dataset=arg.dataset, split=arg.split)
87 | dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, pin_memory=True)
88 |
89 | print('***** Evaluating with ground truth heatmap *****')
90 | print('Evaluating parameters:\n' +
91 | '# Dataset: ' + arg.dataset + '\n' +
92 | '# Dataset split: ' + arg.split + '\n' +
93 | '# Epoch of the model: ' + str(arg.eval_epoch) + '\n' +
94 | '# Normalize way: ' + arg.norm_way + '\n' +
95 | '# Max threshold: ' + str(arg.max_threshold) + '\n')
96 |
97 | print('Loading network...')
98 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2 * kp_num[arg.dataset])
99 | regressor = load_weights(regressor, arg.save_folder + arg.dataset + '_regressor_' + str(arg.eval_epoch) + '.pth',
100 | devices)
101 | if arg.cuda:
102 | regressor = regressor.cuda(device=devices)
103 | regressor.eval()
104 | print('Loading network done!\nStart testing...')
105 |
106 | time_records = []
107 | with torch.no_grad():
108 | for data in tqdm.tqdm(dataloader):
109 | start = time.time()
110 |
111 | input_images, gt_coords_xy, gt_heatmap, coords_xy, bbox, img_name = data
112 | bbox = bbox.squeeze().numpy()
113 | error_normalize_factor = calc_normalize_factor(arg.dataset, coords_xy.numpy(), arg.norm_way) \
114 | if arg.norm_way in ['inter_pupil', 'inter_ocular'] else (bbox[2] - bbox[0])
115 | input_images = input_images.unsqueeze(1)
116 | input_images = input_images.cuda(device=devices) if arg.cuda else input_images
117 | gt_heatmap = gt_heatmap.cuda(device=devices) if arg.cuda else gt_heatmap
118 |
119 | pred_coords = regressor(input_images, gt_heatmap).detach().cpu().squeeze().numpy()
120 | pred_coords_map_back = inverse_affine(arg, pred_coords, bbox)
121 |
122 | time_records.append(time.time() - start)
123 |
124 | error_rate_i = calc_error_rate_i(
125 | arg.dataset,
126 | pred_coords_map_back,
127 | coords_xy[0].numpy(),
128 | error_normalize_factor
129 | )
130 |
131 | if arg.eval_visual:
132 | eval_gt_pred_points(arg, gt_coords_xy, pred_coords, img_name, bbox, save_img=arg.save_img)
133 |
134 | failure_count = failure_count + 1 if error_rate_i > max_threshold else failure_count
135 | error_rate.append(error_rate_i)
136 |
137 | area_under_curve, auc_record = calc_auc(arg.dataset, arg.split, error_rate, max_threshold)
138 | error_rate = sum(error_rate) / dataset_size[arg.dataset][arg.split] * 100
139 | failure_rate = failure_count / dataset_size[arg.dataset][arg.split] * 100
140 |
141 | print('\nEvaluating results:\n# AUC: {:.4f}\n# Error Rate: {:.2f}%\n# Failure Rate: {:.2f}%\n'.format(
142 | area_under_curve, error_rate, failure_rate))
143 | print('Average speed: {:.2f}FPS'.format(1. / np.mean(np.array(time_records))))
144 |
145 |
146 | def evaluate_nparts(arg):
147 | devices = torch.device('cuda:' + arg.gpu_id)
148 | error_rate = []
149 |
150 | testset = GeneralDataset(dataset=arg.dataset, split=arg.split)
151 | dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, pin_memory=True)
152 |
153 | print('***** Evaluating Different Parts *****')
154 | print('Evaluating parameters:\n' +
155 | '# Dataset: ' + arg.dataset + '\n' +
156 | '# Dataset split: ' + arg.split + '\n' +
157 | '# Epoch of the model: ' + str(arg.eval_epoch) + '\n' +
158 | '# Normalize way: ' + arg.norm_way + '\n' +
159 | '# Max threshold: ' + str(arg.max_threshold) + '\n')
160 |
161 | print('Loading network ...')
162 | estimator = Estimator(stacks=arg.hour_stack, msg_pass=arg.msg_pass)
163 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2 * kp_num[arg.dataset])
164 | estimator = load_weights(estimator, arg.save_folder + 'estimator_' + str(arg.eval_epoch) + '.pth', devices)
165 | regressor = load_weights(regressor, arg.save_folder + arg.dataset + '_regressor_' + str(arg.eval_epoch) + '.pth',
166 | devices)
167 | if arg.cuda:
168 | estimator = estimator.cuda(device=devices)
169 | regressor = regressor.cuda(device=devices)
170 | estimator.eval()
171 | regressor.eval()
172 | print('Loading network done!\nStart testing ...')
173 |
174 | time_records = []
175 | with torch.no_grad():
176 | for data in tqdm.tqdm(dataloader):
177 | start = time.time()
178 |
179 | input_images, gt_coords_xy, gt_heatmap, coords_xy, bbox, img_name = data
180 | gt_coords_xy = gt_coords_xy.squeeze().numpy()
181 | bbox = bbox.squeeze().numpy()
182 | error_normalize_factor = calc_normalize_factor(arg.dataset, coords_xy.numpy(), arg.norm_way) \
183 | if arg.norm_way in ['inter_pupil', 'inter_ocular'] else (bbox[2] - bbox[0])
184 | input_images = input_images.unsqueeze(1)
185 | input_images = input_images.cuda(device=devices) if arg.cuda else input_images
186 |
187 | pred_heatmaps = estimator(input_images)
188 | pred_coords = regressor(input_images, pred_heatmaps[-1].detach()).detach().cpu().squeeze().numpy()
189 | pred_coords_map_back = inverse_affine(arg, pred_coords, bbox)
190 |
191 | time_records.append(time.time() - start)
192 |
193 | error_rate_i = calc_error_rate_i_nparts(
194 | arg.dataset,
195 | pred_coords_map_back,
196 | coords_xy[0].numpy(),
197 | error_normalize_factor
198 | )
199 |
200 | if arg.eval_visual:
201 | eval_heatmap(arg, pred_heatmaps[-1], img_name, bbox, save_img=arg.save_img)
202 | eval_pred_points(arg, pred_coords, img_name, bbox, save_img=arg.save_img)
203 |
204 | error_rate.append(error_rate_i)
205 |
206 | error_rate = np.sum(np.array(error_rate), 0) / dataset_size[arg.dataset][arg.split] * 100
207 |
208 | print(f'\nEvaluating results:'
209 | f'\nChin Error Rate: {error_rate[0]}%'
210 | f'\nBrow Error Rate: {error_rate[1]}%'
211 | f'\nNose Error Rate: {error_rate[2]}%'
212 | f'\nEyes Error Rate: {error_rate[3]}%'
213 | f'\nMouth Error Rate: {error_rate[4]}%')
214 | print('Average speed: {:.2f}FPS'.format(1. / np.mean(np.array(time_records))))
215 |
216 |
217 | if __name__ == '__main__':
218 | evaluate_nparts(args)
219 |
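220 | # Note: __main__ runs the per-part protocol by default; call evaluate(args) or
221 | # evaluate_with_gt_heatmap(args) for the other protocols, e.g.
222 | #   python evaluate.py --dataset WFLW --split pose --eval_epoch 900 --norm_way inter_ocular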
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import dlib
4 | import time
5 | import copy
6 | import numpy as np
7 | from dataset import GeneralDataset
8 | from models import *
9 | from utils import *
10 | import torch.nn.functional as F
11 | import matplotlib.pyplot as plt
12 | # from scipy.interpolate import spline  # unused here; spline was removed from recent SciPy
13 |
14 | use_dataset = 'WFLW'
15 | use_epoch = '750'
16 |
17 | # load network
18 | devices = torch.device('cuda:0')
19 | print('***** WFLW trained Model Evaluating *****')
20 | print('Loading network ...')
21 | estimator = Estimator()
22 | regressor = Regressor(output=2*kp_num[use_dataset])
23 | estimator = load_weights(estimator, 'estimator_'+use_epoch+'.pth', devices)
24 | regressor = load_weights(regressor, use_dataset+'_regressor_'+use_epoch+'.pth', devices)
25 | estimator = estimator.cuda(device=devices)
26 | regressor = regressor.cuda(device=devices)
27 | estimator.eval()
28 | regressor.eval()
29 | print('Loading network done!\nStart testing ...')
30 |
31 | # detect face and facial landmark
32 | rescale_ratio = 0.2/2
33 | cv2.namedWindow("Face Detector")
34 | cap = cv2.VideoCapture(0)
35 | face_keypoint_coords = []
36 |
37 | while cap.isOpened():  # loop while the camera is open
38 | ret, img = cap.read()  # read one frame from the camera into img
39 | if ret is True:  # the frame was read successfully
40 | cv2.imshow('Image', img)
41 | k = cv2.waitKey(100)
42 | if k == ord('c') or k == ord('C'):
43 | t_start = str(int(time.time()))
44 |
45 | face_detector = dlib.cnn_face_detection_model_v1('mmod_human_face_detector.dat')
46 | rec = face_detector(img, 1)
47 |
48 | if len(rec) == 0:
49 | print('No Face Detected!')
50 | else:
51 | print('Detect ' + str(len(rec)) + ' face(s).')
52 |
53 | with torch.no_grad():
54 | for face_i in range(len(rec)):
55 | t = str(int(time.time()))
56 |
57 | rec_list = rec.pop().rect
58 | height = rec_list.bottom() - rec_list.top()
59 | width = rec_list.right() - rec_list.left()
60 | bbox = [
61 | int(rec_list.left() - rescale_ratio * width),
62 | int(rec_list.top() - rescale_ratio * height),
63 | int(rec_list.right() + rescale_ratio * width),
64 | int(rec_list.bottom() + rescale_ratio * height)
65 | ]
66 | position_before = np.float32([
67 | [int(bbox[0]), int(bbox[1])],
68 | [int(bbox[0]), int(bbox[3])],
69 | [int(bbox[2]), int(bbox[3])]
70 | ])
71 | position_after = np.float32([
72 | [0, 0],
73 | [0, 255],
74 | [255, 255]
75 | ])
76 | crop_matrix = cv2.getAffineTransform(position_before, position_after)
77 | face_img = cv2.warpAffine(img, crop_matrix, (256, 256))
78 | face_gray = convert_img_to_gray(face_img)
79 | face_norm = pic_normalize(face_gray)
80 |
81 | input_face = torch.Tensor(face_norm)
82 | input_face = input_face.unsqueeze(0)
83 | input_face = input_face.unsqueeze(0).cuda()
84 |
85 | pred_heatmaps = estimator(input_face)
86 | pred_coords = regressor(input_face, pred_heatmaps[-1].detach()).detach().cpu().squeeze().numpy()
87 |
88 | for kp_index in range(kp_num[use_dataset]):
89 | cv2.circle(
90 | face_img,
91 | (int(pred_coords[2 * kp_index]), int(pred_coords[2 * kp_index + 1])),
92 | 2,
93 | (0, 0, 255),
94 | -1
95 | )
96 | show_img(face_img, 'face_small_keypoint'+str(face_i), 500, 650, keep=True)
97 | cv2.imwrite('./pics/face_' + t + '_0.png', face_img)
98 |
99 | heatmaps = F.interpolate(
100 | pred_heatmaps[-1],
101 | scale_factor=4,
102 | mode='bilinear',
103 | align_corners=True
104 | )
105 | heatmaps = heatmaps.squeeze(0).detach().cpu().numpy()
106 | heatmaps_sum = heatmaps[0]
107 | for heatmaps_index in range(boundary_num - 1):
108 | heatmaps_sum += heatmaps[heatmaps_index + 1]
109 | plt.axis('off')
110 | plt.imshow(heatmaps_sum, interpolation='nearest', vmax=1., vmin=0.)
111 | fig = plt.gcf()
112 | fig.set_size_inches(2.56 / 3, 2.56 / 3)
113 | plt.gca().xaxis.set_major_locator(plt.NullLocator())
114 | plt.gca().yaxis.set_major_locator(plt.NullLocator())
115 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
116 | plt.margins(0, 0)
117 | fig.savefig('hm.png', format='png', transparent=True, dpi=300, pad_inches=0)
118 | hm = cv2.imread('hm.png')
119 | syn = cv2.addWeighted(face_img, 0.4, hm, 0.6, 0)
120 | show_img(syn, 'face_small_boundary'+str(face_i), 900, 650)
121 | cv2.imwrite('./pics/face_' + t + '_1.png', syn)
122 |
123 | pred_coords_copy = copy.deepcopy(pred_coords)
124 | for i in range(kp_num[use_dataset]):
125 | pred_coords_copy[2 * i] = \
126 | bbox[0] + pred_coords_copy[2 * i] / 255 * (bbox[2] - bbox[0])
127 | pred_coords_copy[2 * i + 1] = bbox[1] + pred_coords_copy[2 * i + 1] / 255 * (
128 | bbox[3] - bbox[1])
129 | face_keypoint_coords.append(pred_coords_copy)
130 |
131 | if len(face_keypoint_coords) != 0:
132 | for face_id, coords in enumerate(face_keypoint_coords):
133 | for kp_index in range(kp_num[use_dataset]):
134 | cv2.circle(
135 | img,
136 | (int(coords[2 * kp_index]), int(coords[2 * kp_index + 1])),
137 | 2,
138 | (0, 0, 255),
139 | -1
140 | )
141 | show_img(img, 'face_whole', 1400, 650)
142 | cv2.imwrite('./pics/face_' + t_start + '.png', img)
143 | face_keypoint_coords = []
144 |
145 | if k == ord('q') or k == ord('Q'):
146 | break
147 |
148 | print('QUIT.')
149 | if os.path.exists('hm.png'):
150 | os.remove('hm.png')
151 | cap.release()  # release the camera
152 | cv2.destroyAllWindows()
153 |
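154 | # Usage note: run with a webcam attached; press 'c' to capture and annotate the
155 | # current frame, press 'q' to quit. dlib's mmod_human_face_detector.dat and the
156 | # trained WFLW weights (epoch 750) are expected to be available to load_weights.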
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import WingLoss
2 | from .models import Estimator, Regressor, Discrim
3 |
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class HeatmapLoss(nn.Module):
7 | def __init__(self):
8 | super(HeatmapLoss, self).__init__()
9 |
10 | def forward(self, pred, gt):
11 | assert pred.size() == gt.size()
12 | loss = ((pred - gt)**2)
13 | loss = loss.sum(dim=3).sum(dim=2).sum(dim=1).mean() / 2.
14 | return loss
15 |
16 |
17 | class WingLoss(nn.Module):
18 |
19 | def __init__(self, w=10, epsilon=2, weight=None):
20 | super(WingLoss, self).__init__()
21 | self.w = w
22 | self.epsilon = epsilon
23 | self.C = self.w - self.w * np.log(1 + self.w / self.epsilon)
24 | self.weight = weight
25 |
26 | def forward(self, predictions, targets):
27 | x = predictions - targets
28 | if self.weight is not None:
29 | x = x * self.weight
30 | t = torch.abs(x)
31 |
32 | return torch.mean(torch.where(t < self.w, self.w * torch.log(1 + t / self.epsilon), t - self.C))
33 |
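34 | # Example (a sketch): with w=10 and epsilon=2, a residual of |x| = 1 falls in the
35 | # logarithmic piece, so every element contributes w * log(1 + 1/epsilon) ~= 4.05:
36 | #   criterion = WingLoss(w=10, epsilon=2)
37 | #   loss = criterion(torch.zeros(4, 196), torch.ones(4, 196))  # ~= 10 * ln(1.5)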
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .losses import HeatmapLoss
5 |
6 |
7 | class Bottleneck(nn.Module):
8 | expansion = 4
9 |
10 | def __init__(self, inplanes, planes, stride=1, downsample=None):
11 | super(Bottleneck, self).__init__()
12 | self.bn1 = nn.BatchNorm2d(inplanes)
13 | self.relu1 = nn.ReLU(inplace=False)
14 | self.conv1 = nn.Conv2d(inplanes, planes, padding=0,
15 | kernel_size=1, stride=1, bias=False)
16 | self.bn2 = nn.BatchNorm2d(planes)
17 | self.relu2 = nn.ReLU(inplace=False)
18 | self.conv2 = nn.Conv2d(planes, planes, padding=1,
19 | kernel_size=3, stride=stride, bias=False)
20 | self.bn3 = nn.BatchNorm2d(planes)
21 | self.relu3 = nn.ReLU(inplace=False)
22 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, padding=0,
23 | kernel_size=1, stride=1, bias=False)
24 | if stride != 1 or inplanes != planes * self.expansion:
25 | downsample = nn.Conv2d(inplanes, planes * self.expansion, padding=0,
26 | kernel_size=1, stride=stride, bias=False)
27 | self.downsample = downsample
28 |
29 | for m in self.modules():
30 | if m.__class__.__name__ in ['Conv2d']:
31 | nn.init.kaiming_uniform_(m.weight.data)
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | if self.downsample is not None:
37 | out = self.conv1(x)
38 | out = self.bn2(out)
39 | out = self.relu2(out)
40 | out = self.conv2(out)
41 | out = self.bn3(out)
42 | out = self.relu3(out)
43 | out = self.conv3(out)
44 | else:
45 | out = self.bn1(x)
46 | out = self.relu1(out)
47 | out = self.conv1(out)
48 | out = self.bn2(out)
49 | out = self.relu2(out)
50 | out = self.conv2(out)
51 | out = self.bn3(out)
52 | out = self.relu3(out)
53 | out = self.conv3(out)
54 |
55 | if self.downsample is not None:
56 | residual = self.downsample(x)
57 |
58 | out = out + residual
59 |
60 | return out
61 |
62 |
63 | class Hourglass(nn.Module):
64 |
65 | def __init__(self, block=Bottleneck, num_blocks=1, planes=64, depth=4):
66 | super(Hourglass, self).__init__()
67 | self.depth = depth
68 | self.maxpool = nn.MaxPool2d(2, stride=2)
69 | self.hg = self._make_hourglass(block, num_blocks, planes, depth)
70 |
71 | for m in self.modules():
72 | if m.__class__.__name__ in ['Conv2d']:
73 | nn.init.kaiming_uniform_(m.weight.data)
74 |
75 | @staticmethod
76 | def _make_residual(block, num_blocks, planes):
77 | layers = []
78 | for index in range(0, num_blocks):
79 | layers.append(block(planes * block.expansion, planes))
80 | return nn.Sequential(*layers)
81 |
82 | def _make_hourglass(self, block, num_blocks, planes, depth):
83 | hourglass = []
84 | for index in range(depth):
85 | res = []
86 | for j in range(3):
87 | res.append(self._make_residual(block, num_blocks, planes))
88 | if index == 0:
89 | res.append(self._make_residual(block, num_blocks, planes))
90 | hourglass.append(nn.ModuleList(res))
91 | return nn.ModuleList(hourglass)
92 |
93 | def _hourglass_forward(self, n, x):
94 | up1 = self.hg[n - 1][0](x)
95 | low1 = self.maxpool(x)
96 | low1 = self.hg[n - 1][1](low1)
97 |
98 | if n > 1:
99 | low2 = self._hourglass_forward(n - 1, low1)
100 | else:
101 | low2 = self.hg[n - 1][3](low1)
102 | low3 = self.hg[n - 1][2](low2)
103 | up2 = F.interpolate(low3, scale_factor=2, mode='bilinear', align_corners=True)
104 | out = up1 + up2
105 | return out
106 |
107 | def forward(self, x):
108 | return self._hourglass_forward(self.depth, x)
109 |
110 |
111 | class FMFHourglass(nn.Module):
112 |
113 | def __init__(self, planes, depth):
114 | super(FMFHourglass, self).__init__()
115 | self.depth = depth
116 | self.maxpool = nn.MaxPool2d(2, stride=2)
117 | hourglass = []
118 | for index in range(depth):
119 | res = []
120 | for j in range(3):
121 | res.append(Bottleneck(planes * Bottleneck.expansion, planes))
122 | if index == depth - 1:
123 | del(res[-1])
124 | hourglass.append(nn.ModuleList(res))
125 | self.hg = nn.ModuleList(hourglass)
126 |
127 | for m in self.modules():
128 | if m.__class__.__name__ in ['Conv2d']:
129 | nn.init.kaiming_uniform_(m.weight.data)
130 |
131 | def _hourglass_forward(self, n, x):
132 | up1 = self.hg[n - 1][2](x)
133 | low1 = self.maxpool(x)
134 | low1 = self.hg[n - 1][0](low1)
135 |
136 | if n > 1:
137 | low2 = self._hourglass_forward(n - 1, low1)
138 | low2 = self.hg[n - 1][1](low2)
139 | else:
140 | low2 = self.hg[n - 1][1](low1)
141 | up2 = F.interpolate(low2, scale_factor=2, mode='bilinear', align_corners=True)
142 | out = up1 + up2
143 | return out
144 |
145 | def forward(self, x):
146 | out = self.maxpool(x)
147 | out = self.hg[self.depth-1][0](out)
148 | if self.depth > 1:
149 | out = self._hourglass_forward(self.depth-1, out)
150 | out = self.hg[self.depth-1][1](out)
151 | out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True)
152 | return out
153 |
154 |
155 | class MessagePassing(nn.Module):
156 | pass_order = {'A': ['1', '13', '12', '11', '10', '5', '4', '7', '9', '6', '8', '2', '3'],
157 | 'B': ['2', '3', '6', '8', '7', '9', '4', '5', '10', '11', '12', '13', '1']}
158 | boundary_relation = {'A': {'1': ['2', '3', '7', '9', '13'],
159 | '2': [],
160 | '3': [],
161 | '4': ['7', '9'],
162 | '5': ['4'],
163 | '6': ['2'],
164 | '7': ['6'],
165 | '8': ['3'],
166 | '9': ['8'],
167 | '10': ['5'],
168 | '11': ['10'],
169 | '12': ['11'],
170 | '13': ['12']},
171 | 'B': {'1': [],
172 | '2': ['1', '6'],
173 | '3': ['1', '8'],
174 | '4': ['5'],
175 | '5': ['10'],
176 | '6': ['7'],
177 | '7': ['1', '4'],
178 | '8': ['9'],
179 | '9': ['1', '4'],
180 | '10': ['11'],
181 | '11': ['12'],
182 | '12': ['13'],
183 | '13': ['1']}}
184 |
185 | def __init__(self, classes=13, step=2, inchannels=256, channels=16, first=0, last=0):
186 | super(MessagePassing, self).__init__()
187 | self.first = first  # 1 if this is the first message-passing stage
188 | self.last = last  # 1 if this is the last message-passing stage
189 | self.classes = classes # boundary number: 13
190 | self.step = step # message passing steps: 2
191 | prepare_conv, prepare_bn, prepare_relu = [], [], []
192 | after_bn, after_relu, after_conv = [], [], []
193 | inner_level_pass, inter_level_pass = [], []
194 | for index in range(2 * classes):
195 | prepare_conv.append(nn.Conv2d(inchannels, channels, padding=0,
196 | kernel_size=1, stride=1, bias=False))
197 | prepare_bn.append(nn.BatchNorm2d(channels))
198 | prepare_relu.append(nn.ReLU())
199 | for index in range(classes):
200 | after_bn.append(nn.BatchNorm2d(2*channels))
201 | after_relu.append(nn.ReLU())
202 | after_conv.append(nn.Conv2d(2*channels, 1, padding=0,
203 | kernel_size=1, stride=1, bias=False))
204 | for item in self.pass_order['A']:
205 | for index in range(len(self.boundary_relation['A'][item])):
206 | inner_level_pass.append(self._make_passing())
207 | for item in self.pass_order['B']:
208 | for index in range(len(self.boundary_relation['B'][item])):
209 | inner_level_pass.append(self._make_passing())
210 | if self.last == 0:
211 | for index in range(2*self.classes):
212 | inter_level_pass.append(self._make_passing())
213 | self.pre_conv = nn.ModuleList(prepare_conv)
214 | self.pre_bn = nn.ModuleList(prepare_bn)
215 | self.pre_relu = nn.ModuleList(prepare_relu)
216 | self.aft_bn = nn.ModuleList(after_bn)
217 | self.aft_relu = nn.ModuleList(after_relu)
218 | self.aft_conv = nn.ModuleList(after_conv)
219 | self.inner_pass = nn.ModuleList(inner_level_pass)
220 | self.inter_pass = nn.ModuleList(inter_level_pass)
221 |
222 | for m in self.modules():
223 | if m.__class__.__name__ in ['Conv2d']:
224 | nn.init.kaiming_uniform_(m.weight.data)
225 |
226 | def _make_passing(self, inplanes=16, planes=8, pad=3, ker_size=7, stride=1, bias=False):
227 | passing = []
228 | for pass_step in range(self.step):
229 | if pass_step == 0:
230 | passing.append(nn.Conv2d(inplanes, planes, padding=pad,
231 | kernel_size=ker_size, stride=stride, bias=bias))
232 | passing.append(nn.BatchNorm2d(planes))
233 | passing.append(nn.ReLU())
234 | elif pass_step == self.step - 1:
235 | passing.append(nn.Conv2d(planes, inplanes, padding=pad,
236 | kernel_size=ker_size, stride=stride, bias=bias))
237 | else:
238 | passing.append(nn.Conv2d(planes, planes, padding=pad,
239 | kernel_size=ker_size, stride=stride, bias=bias))
240 | passing.append(nn.BatchNorm2d(planes))
241 | passing.append(nn.ReLU())
242 | return nn.Sequential(*passing)
243 |
244 | def forward(self, x, ahead_msg):
245 | inner_msg_count = 0
246 | feature_map = []
247 | result = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [],
248 | '10': [], '11': [], '12': [], '13': []}
249 | result_a = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [],
250 | '10': [], '11': [], '12': [], '13': []}
251 | result_b = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [],
252 | '10': [], '11': [], '12': [], '13': []}
253 | msg_box_a = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [],
254 | '10': [], '11': [], '12': [], '13': []}
255 | msg_box_b = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [],
256 | '10': [], '11': [], '12': [], '13': []}
257 | inter_level_msg = {'A': {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [],
258 | '9': [], '10': [], '11': [], '12': [], '13': []},
259 | 'B': {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [],
260 | '9': [], '10': [], '11': [], '12': [], '13': []}}
261 |
262 | for index in range(self.classes): # direction 'A'
263 | out = self.pre_conv[index](x)
264 | for get_msg_index in range(len(msg_box_a[self.pass_order['A'][index]])): # get inner level msg
265 | out = out + msg_box_a[self.pass_order['A'][index]][get_msg_index]
266 | if self.first == 0:  # not the first message-passing stage: receive inter-level msg
267 | out = out + ahead_msg['A'][self.pass_order['A'][index]][0]
268 | out = self.pre_bn[index](out)
269 | out = self.pre_relu[index](out)
270 | result_a[self.pass_order['A'][index]].append(out) # save to be concatenated
271 | for send_msg_index in range(len(self.boundary_relation['A'][self.pass_order['A'][index]])): # message pass
272 | temp = self.inner_pass[inner_msg_count](out)
273 | inner_msg_count = inner_msg_count + 1
274 | msg_box_a[self.boundary_relation['A'][self.pass_order['A'][index]][send_msg_index]].append(temp)
275 | if self.last == 0:  # not the last message-passing stage: send messages to the next stack
276 | temp = self.inter_pass[index](out)
277 | inter_level_msg['A'][self.pass_order['A'][index]].append(temp)
278 |
279 | for index in range(self.classes): # direction 'B'
280 | out = self.pre_conv[index + self.classes](x)
281 | for get_msg_index in range(len(msg_box_b[self.pass_order['B'][index]])): # get inner level msg
282 | out = out + msg_box_b[self.pass_order['B'][index]][get_msg_index]
283 | if self.first == 0:  # not the first message-passing stage: receive inter-level msg
284 | out = out + ahead_msg['B'][self.pass_order['B'][index]][0]
285 | out = self.pre_bn[index + self.classes](out)
286 | out = self.pre_relu[index + self.classes](out)
287 | result_b[self.pass_order['B'][index]].append(out) # save to be concatenated
288 | for send_msg_index in range(len(self.boundary_relation['B'][self.pass_order['B'][index]])): # message pass
289 | temp = self.inner_pass[inner_msg_count](out)
290 | inner_msg_count = inner_msg_count + 1
291 | msg_box_b[self.boundary_relation['B'][self.pass_order['B'][index]][send_msg_index]].append(temp)
292 | if self.last == 0:  # not the last message-passing stage: send messages to the next stack
293 | temp = self.inter_pass[index + self.classes](out)
294 | inter_level_msg['B'][self.pass_order['B'][index]].append(temp)
295 |
296 | for index in range(self.classes): # concatenation and conv to get feature_map
297 | result[str(index + 1)] = torch.cat((result_a[str(index + 1)][0],
298 | result_b[str(index + 1)][0]), 1) # after concat: 1 32 64 64
299 | result[str(index + 1)] = self.aft_bn[index](result[str(index + 1)])
300 | result[str(index + 1)] = self.aft_relu[index](result[str(index + 1)])
301 | result[str(index + 1)] = self.aft_conv[index](result[str(index + 1)])
302 |
303 | feature_map.append(result['1'])
304 | for index in range(self.classes - 1): # concat all 'classes' feature maps
305 | feature_map[0] = torch.cat((feature_map[0], result[str(index + 2)]), 1)
306 |
307 | if self.last == 0:  # not the last stack: also return the inter-level messages besides the feature map
308 | return feature_map[0], inter_level_msg
309 | else:
310 | return feature_map[0]
311 |
312 |
313 | class Estimator(nn.Module):
314 |
315 | def __init__(self, stacks=4, msg_pass=1):
316 | super(Estimator, self).__init__()
317 | self.stacks = stacks
318 | self.msg_pass = msg_pass
319 | self.hm_loss = HeatmapLoss()
320 | self.conv1 = nn.Conv2d(1, 64, padding=3, kernel_size=7,
321 | stride=2, bias=False)
322 | self.conv1_bn = nn.BatchNorm2d(64)
323 | self.conv1_relu = nn.ReLU(inplace=False)
324 | self.pre_res_1 = Bottleneck(64, 32)
325 | self.pool1 = nn.MaxPool2d(3, stride=2, padding=1) # problem, need to see the source code of caffe
326 | self.pre_res_2 = Bottleneck(128, 32)
327 | self.pre_res_2_bn = nn.BatchNorm2d(128)
328 | self.pre_res_2_relu = nn.ReLU(inplace=False)
329 | self.hourglass_0 = Bottleneck(128, 64)
330 | hg, mp = [], []
331 | linear_1_res, linear_1_bn, linear_1_relu, linear_1_conv = [], [], [], []
332 | linear_2_bn, linear_2_relu, linear_2_conv = [], [], []
333 | linear_3 = []
334 | linear_mp_bn, linear_mp_relu, linear_mp_conv = [], [], []
335 | for index in range(self.stacks):
336 | hg.append(Hourglass())
337 | linear_1_res.append(Bottleneck(256, 64))
338 | linear_1_bn.append(nn.BatchNorm2d(256))
339 | linear_1_relu.append(nn.ReLU())
340 | linear_1_conv.append(nn.Conv2d(256, 256, padding=0, kernel_size=1,
341 | stride=1, bias=False))
342 | if msg_pass:
343 | if index == 0:
344 | mp.append(MessagePassing(first=1))
345 | elif index == self.stacks - 1:
346 | mp.append(MessagePassing(last=1))
347 | else:
348 | mp.append(MessagePassing())
349 | else:
350 | linear_mp_bn.append(nn.BatchNorm2d(256))
351 | linear_mp_relu.append(nn.ReLU())
352 | linear_mp_conv.append(nn.Conv2d(256, 13, padding=0, kernel_size=1,
353 | stride=1, bias=False))
354 | if index != self.stacks - 1:
355 | linear_2_bn.append(nn.BatchNorm2d(256))
356 | linear_2_relu.append(nn.ReLU())
357 | linear_2_conv.append(nn.Conv2d(256, 256, padding=0, kernel_size=1,
358 | stride=1, bias=False))
359 | linear_3.append(nn.Conv2d(13, 256, padding=0, kernel_size=1,
360 | stride=1, bias=False))
361 | self.hg = nn.ModuleList(hg)
362 | self.linear_1_res = nn.ModuleList(linear_1_res)
363 | self.linear_1_bn = nn.ModuleList(linear_1_bn)
364 | self.linear_1_relu = nn.ModuleList(linear_1_relu)
365 | self.linear_1_conv = nn.ModuleList(linear_1_conv)
366 | self.mp = nn.ModuleList(mp)
367 | self.linear_2_bn = nn.ModuleList(linear_2_bn)
368 | self.linear_2_relu = nn.ModuleList(linear_2_relu)
369 | self.linear_2_conv = nn.ModuleList(linear_2_conv)
370 | self.linear_3 = nn.ModuleList(linear_3)
371 | self.linear_mp_bn = nn.ModuleList(linear_mp_bn)
372 | self.linear_mp_relu = nn.ModuleList(linear_mp_relu)
373 | self.linear_mp_conv = nn.ModuleList(linear_mp_conv)
374 |
375 | for m in self.modules():
376 | if m.__class__.__name__ in ['Conv2d']:
377 | nn.init.kaiming_uniform_(m.weight.data)
378 |
379 | def forward(self, x):
380 | heatmaps = [] # save all the stacks output feature maps
381 | inter_level_msg = []
382 | out = self.conv1(x)
383 | out = self.conv1_bn(out)
384 | out = self.conv1_relu(out)
385 | out = self.pre_res_1(out)
386 | out = self.pool1(out)
387 | out = self.pre_res_2(out)
388 | out = self.pre_res_2_bn(out)
389 | out = self.pre_res_2_relu(out)
390 | out = self.hourglass_0(out)
391 | for index in range(self.stacks):
392 | temp = self.hg[index](out)
393 | temp = self.linear_1_res[index](temp)
394 | temp = self.linear_1_bn[index](temp)
395 | temp = self.linear_1_relu[index](temp)
396 | temp = self.linear_1_conv[index](temp)
397 | if self.msg_pass:
398 | if index != self.stacks - 1:
399 | heatmap, inter_level_msg = self.mp[index](temp, inter_level_msg)
400 | else:
401 | heatmap = self.mp[index](temp, inter_level_msg)
402 | else:
403 | heatmap = self.linear_mp_bn[index](temp)
404 | heatmap = self.linear_mp_relu[index](heatmap)
405 | heatmap = self.linear_mp_conv[index](heatmap)
406 | heatmaps.append(heatmap)
407 | if index != self.stacks - 1:
408 | temp = self.linear_2_bn[index](temp)
409 | temp = self.linear_2_relu[index](temp)
410 | linear2_out = self.linear_2_conv[index](temp)
411 | linear3_out = self.linear_3[index](heatmap)
412 | out = out + linear2_out + linear3_out
413 | return heatmaps  # the output heatmap of every stack, appended in order
414 |
415 | def calc_loss(self, pred_heatmaps, gt_heatmap):
416 | heatmap_loss = []
417 | for stack in range(self.stacks):
418 | heatmap_loss.append(self.hm_loss(pred_heatmaps[stack], gt_heatmap))
419 | heatmap_loss = torch.stack(heatmap_loss, dim=0)
420 | heatmap_loss = torch.sum(heatmap_loss)
421 | return heatmap_loss
422 |
423 |
424 | class Regressor(nn.Module):
425 |
426 | def __init__(self, classes=13, fuse_stages=4, planes=16, output=196):
427 | super(Regressor, self).__init__()
428 | self.classes = classes
429 | self.FMF_stages = 3
430 | self.fuse_stages = fuse_stages
431 | self.planes = planes
432 | self.conv1 = nn.Conv2d(14, self.planes, padding=3, kernel_size=7, stride=2, bias=False) \
433 | if fuse_stages > 0 else nn.Conv2d(1, self.planes, padding=3, kernel_size=7, stride=2, bias=False)
434 | self.bn1 = nn.BatchNorm2d(self.planes)
435 | self.bn2 = nn.BatchNorm2d(256)  # the regressor's final BatchNorm
436 | self.relu1 = nn.ReLU(inplace=False)
437 | self.relu2 = nn.ReLU(inplace=False)  # the last ReLU before the fully connected layers
438 | self.pool1 = nn.MaxPool2d(3, stride=2, padding=1) # problem, need to see the source code of caffe (solved)
439 | baseline_bn, baseline_relu, baseline_res_1, baseline_res_2 = [], [], [], []
440 | pre_fmf_bn, pre_fmf_relu, pre_fmf_conv = [], [], []
441 | aft_fmf_bn, aft_fmf_relu, aft_fmf_conv = [], [], []
442 | tanh = []
443 | fmfhourglass = []
444 | for index in range(self.FMF_stages + 1):
445 | if index == 0:
446 | baseline_bn.append(nn.BatchNorm2d(self.planes))
447 | baseline_relu.append(nn.ReLU())
448 | baseline_res_1.append(Bottleneck(self.planes, self.planes//2))
449 | baseline_res_2.append(Bottleneck(self.planes * 2, self.planes//2))
450 | else:
451 | baseline_bn.append(nn.BatchNorm2d(self.planes * pow(2, index)))
452 | baseline_relu.append(nn.ReLU())
453 | baseline_res_1.append(Bottleneck(self.planes * pow(2, index), self.planes * pow(2, index-1), stride=2))
454 | baseline_res_2.append(Bottleneck(self.planes * pow(2, index+1), self.planes * pow(2, index-1)))
455 | for index in range(self.FMF_stages):
456 | pre_fmf_bn.append(nn.BatchNorm2d(self.planes * pow(2, index+1) + self.classes))
457 | pre_fmf_relu.append(nn.ReLU())
458 | pre_fmf_conv.append(nn.Conv2d(self.planes*pow(2, index+1) + self.classes, self.planes*pow(2, index+1),
459 | padding=0, kernel_size=1, stride=1, bias=False))
460 | for index in range(self.FMF_stages):
461 | fmfhourglass.append(FMFHourglass(planes=8*pow(2, index), depth=3-index))
462 | for index in range(self.FMF_stages):
463 | aft_fmf_bn.append(nn.BatchNorm2d(self.planes * pow(2, index + 1)))
464 | aft_fmf_bn.append(nn.BatchNorm2d(self.planes * pow(2, index + 1)))
465 | aft_fmf_relu.append(nn.ReLU())
466 | aft_fmf_relu.append(nn.ReLU())
467 | aft_fmf_conv.append(nn.Conv2d(self.planes * pow(2, index + 1), self.planes * pow(2, index + 1),
468 | padding=0, kernel_size=1, stride=1, bias=False))
469 | aft_fmf_conv.append(nn.Conv2d(self.planes * pow(2, index + 1), self.planes * pow(2, index + 1),
470 | padding=0, kernel_size=1, stride=1, bias=False))
471 | tanh.append(nn.Tanh())
472 | self.bl_bn = nn.ModuleList(baseline_bn)
473 | self.bl_relu = nn.ModuleList(baseline_relu)
474 | self.bl_res_1 = nn.ModuleList(baseline_res_1)
475 | self.bl_res_2 = nn.ModuleList(baseline_res_2)
476 | self.pre_fmf_bn = nn.ModuleList(pre_fmf_bn)
477 | self.pre_fmf_relu = nn.ModuleList(pre_fmf_relu)
478 | self.pre_fmf_conv = nn.ModuleList(pre_fmf_conv)
479 | self.FMF_Hourglass = nn.ModuleList(fmfhourglass)
480 | self.aft_fmf_bn = nn.ModuleList(aft_fmf_bn)
481 | self.aft_fmf_relu = nn.ModuleList(aft_fmf_relu)
482 | self.aft_fmf_conv = nn.ModuleList(aft_fmf_conv)
483 | self.tanh = nn.ModuleList(tanh)
484 | self.fc1 = nn.Linear(256 * 8 * 8, 256)  # sizes are hard-coded for 256x256 input; generality is not a goal here yet
485 | self.fc2 = nn.Linear(256, 256)
486 | self.fc3 = nn.Linear(256, output)
487 | self.fc_relu1 = nn.ReLU(inplace=False)
488 | self.fc_relu2 = nn.ReLU(inplace=False)
489 |
490 | for m in self.modules():
491 | if m.__class__.__name__ in ['Conv2d']:
492 | nn.init.kaiming_uniform_(m.weight.data)
493 |
494 | @staticmethod
495 | def num_flat_features(x):
496 | size = x.size()[1:] # all dimensions except the batch dimension
497 | num_features = 1
498 | for s in size:
499 | num_features *= s
500 | return num_features
501 |
502 | def forward(self, input_img, heatmap):
503 | data_concat = []
504 | if self.fuse_stages > 0:
505 | out = F.interpolate(heatmap, scale_factor=4, mode='bilinear', align_corners=True)
506 | data_concat.append(input_img)
507 | for index in range(self.classes - 1):
508 | data_concat[0] = torch.cat((data_concat[0], input_img), 1)
509 | out = data_concat[0]*out
510 | out = torch.cat((out, input_img), 1)
511 | else:
512 | out = input_img
513 | out = self.conv1(out)
514 | out = self.bn1(out)
515 | out = self.relu1(out)
516 | out = self.pool1(out)
517 | out = self.bl_bn[0](out)
518 | out = self.bl_relu[0](out)
519 | out = self.bl_res_1[0](out)
520 | out = self.bl_res_2[0](out)
521 | for index in range(self.FMF_stages):
522 | if index < self.fuse_stages - 1:
523 | temp = F.interpolate(heatmap, scale_factor=pow(2, -1*index), mode='bilinear', align_corners=True)
524 | temp_out = torch.cat((temp, out), 1)
525 | temp_out = self.pre_fmf_bn[index](temp_out)
526 | temp_out = self.pre_fmf_relu[index](temp_out)
527 | temp_out = self.pre_fmf_conv[index](temp_out)
528 | temp_out = self.FMF_Hourglass[index](temp_out)
529 | temp_out = self.aft_fmf_bn[2 * index](temp_out)
530 | temp_out = self.aft_fmf_relu[2 * index](temp_out)
531 | temp_out = self.aft_fmf_conv[2 * index](temp_out)
532 | temp_out = self.aft_fmf_bn[2 * index + 1](temp_out)
533 | temp_out = self.aft_fmf_relu[2 * index + 1](temp_out)
534 | temp_out = self.aft_fmf_conv[2 * index + 1](temp_out)
535 | temp_out = self.tanh[index](temp_out)
536 | temp_out = temp_out * out
537 | out = temp_out + out
538 | out = self.bl_bn[index+1](out)
539 | out = self.bl_relu[index + 1](out)
540 | out = self.bl_res_1[index + 1](out)
541 | out = self.bl_res_2[index + 1](out)
542 | out = self.bn2(out)
543 | out = self.relu2(out)
544 | out = out.view(-1, self.num_flat_features(out))
545 | out = self.fc1(out)
546 | out = self.fc_relu1(out)
547 | out = self.fc2(out)
548 | out = self.fc_relu2(out)
549 | out = self.fc3(out)
550 |
551 | return out
552 |
553 |
554 | class Discrim(nn.Module):
555 | channels, linear_n = [13, 64, 192, 384, 256, 256], [4096, 1024, 256, 13]
556 | ker_size, strd, pad = [2, 5, 3, 3, 3], [2, 1, 1, 1, 1], [0, 2, 1, 1, 1]
557 | maxpool_mask = [1, 1, 0, 0, 1]
558 |
559 | def __init__(self, conv_layers=5, linear_layers=3):
560 | super(Discrim, self).__init__()
561 | conv_features = []
562 | linear_classify = []
563 | for index in range(conv_layers):
564 | conv_features.append(nn.Conv2d(Discrim.channels[index], Discrim.channels[index + 1],
565 | kernel_size=Discrim.ker_size[index],
566 | stride=Discrim.strd[index],
567 | padding=Discrim.pad[index],
568 | bias=False))
569 | conv_features.append(nn.BatchNorm2d(Discrim.channels[index + 1]))
570 | conv_features.append(nn.ReLU(inplace=False))
571 | if Discrim.maxpool_mask[index] == 1:
572 | conv_features.append(nn.MaxPool2d(3, stride=2, padding=1))
573 | else:
574 | conv_features.append(nn.ReLU(inplace=False))
575 | for index in range(linear_layers):
576 | linear_classify.append(nn.Linear(Discrim.linear_n[index], Discrim.linear_n[index+1]))
577 | if index != linear_layers - 1:
578 | linear_classify.append(nn.ReLU(inplace=False))
579 | else:
580 | linear_classify.append(nn.Sigmoid())
581 | self.features = nn.Sequential(*conv_features)
582 | self.classifier = nn.Sequential(*linear_classify)
583 |
584 | @staticmethod
585 | def num_flat_features(x):
586 | size = x.size()[1:] # all dimensions except the batch dimension
587 | num_features = 1
588 | for s in size:
589 | num_features *= s
590 | return num_features
591 |
592 | def forward(self, x):
593 | out = self.features(x)
594 | out = out.view(-1, self.num_flat_features(out))
595 | out = self.classifier(out)
596 | return out
597 |
598 |
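599 | # Shape sanity check (a sketch, using the 256x256 grayscale input assumed elsewhere):
600 | #   est, reg = Estimator(stacks=4, msg_pass=1), Regressor(output=196)
601 | #   img = torch.randn(1, 1, 256, 256)
602 | #   hms = est(img)               # list of 4 heatmaps, each of shape 1 x 13 x 64 x 64
603 | #   coords = reg(img, hms[-1])   # flattened landmark coordinates, shape 1 x 196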
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from models import WingLoss, Estimator, Regressor, Discrim
5 | from dataset import GeneralDataset
6 | from utils import *
7 | import tqdm
8 |
9 | if not os.path.exists(args.save_folder):
10 | os.mkdir(args.save_folder)
11 | if not os.path.exists(args.resume_folder):
12 | os.mkdir(args.resume_folder)
13 |
14 |
15 | def train(arg):
16 | epoch = None
17 | devices = get_devices_list(arg)
18 |
19 | print('***** Normal Training *****')
20 | print('Training parameters:\n' +
21 | '# Dataset: ' + arg.dataset + '\n' +
22 | '# Dataset split: ' + arg.split + '\n' +
23 | '# Batchsize: ' + str(arg.batch_size) + '\n' +
24 | '# Num workers: ' + str(arg.workers) + '\n' +
25 | '# PDB: ' + str(arg.PDB) + '\n' +
26 | '# Use GPU: ' + str(arg.cuda) + '\n' +
27 | '# Start lr: ' + str(arg.lr) + '\n' +
28 | '# Max epoch: ' + str(arg.max_epoch) + '\n' +
29 | '# Loss type: ' + arg.loss_type + '\n' +
30 | '# Resumed model: ' + str(arg.resume_epoch > 0))
31 | if arg.resume_epoch > 0:
32 | print('# Resumed epoch: ' + str(arg.resume_epoch))
33 |
34 | print('Creating networks ...')
35 | estimator, regressor, discrim = create_model(arg, devices)
36 | estimator.train()
37 | regressor.train()
38 |     assert discrim is not None, 'train() drives the GAN branch unconditionally; arg.GAN must be True'
39 |     discrim.train()
40 | print('Creating networks done!')
41 |
42 | optimizer_estimator = torch.optim.SGD(estimator.parameters(), lr=arg.lr, momentum=arg.momentum,
43 | weight_decay=arg.weight_decay)
44 | optimizer_regressor = torch.optim.SGD(regressor.parameters(), lr=arg.lr, momentum=arg.momentum,
45 | weight_decay=arg.weight_decay)
46 | optimizer_discrim = torch.optim.SGD(discrim.parameters(), lr=arg.lr, momentum=arg.momentum,
47 | weight_decay=arg.weight_decay) if discrim is not None else None
48 |
49 | if arg.loss_type == 'L2':
50 | criterion = nn.MSELoss()
51 | elif arg.loss_type == 'L1':
52 | criterion = nn.L1Loss()
53 | elif arg.loss_type == 'smoothL1':
54 | criterion = nn.SmoothL1Loss()
55 | else:
56 | criterion = WingLoss(w=arg.wingloss_w, epsilon=arg.wingloss_e)
57 |
58 | print('Loading dataset ...')
59 | trainset = GeneralDataset(dataset=arg.dataset)
60 | print('Loading dataset done!')
61 |
62 | d_fake = (torch.zeros(arg.batch_size, 13)).cuda(device=devices[0]) if arg.GAN \
63 | else torch.zeros(arg.batch_size, 13)
64 |
65 | # evolving training
66 | print('Start training ...')
67 | for epoch in range(arg.resume_epoch, arg.max_epoch):
68 | forward_times_per_epoch, sum_loss_estimator, sum_loss_regressor = 0, 0., 0.
69 | dataloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=arg.shuffle,
70 | num_workers=arg.workers, pin_memory=True)
71 |
72 | if epoch in arg.step_values:
73 | optimizer_estimator.param_groups[0]['lr'] *= arg.gamma
74 | optimizer_regressor.param_groups[0]['lr'] *= arg.gamma
75 | optimizer_discrim.param_groups[0]['lr'] *= arg.gamma
76 |
77 | for data in tqdm.tqdm(dataloader):
78 | forward_times_per_epoch += 1
79 | input_images, gt_coords_xy, gt_heatmap, _, _, _ = data
80 | true_batchsize = input_images.size()[0]
81 | input_images = input_images.unsqueeze(1)
82 | input_images = input_images.cuda(device=devices[0])
83 | gt_coords_xy = gt_coords_xy.cuda(device=devices[0])
84 | gt_heatmap = gt_heatmap.cuda(device=devices[0])
85 |
86 | optimizer_estimator.zero_grad()
87 | heatmaps = estimator(input_images)
88 | loss_G = estimator.calc_loss(heatmaps, gt_heatmap)
89 | loss_A = torch.mean(torch.log2(1. - discrim(heatmaps[-1])))
90 | loss_estimator = loss_G + loss_A
91 | loss_estimator.backward()
92 | optimizer_estimator.step()
93 |
94 |             sum_loss_estimator += loss_estimator.item()  # accumulate the scalar, not the graph-attached tensor
95 |
96 | optimizer_discrim.zero_grad()
97 | loss_D_real = -torch.mean(torch.log2(discrim(gt_heatmap)))
98 | loss_D_fake = -torch.mean(torch.log2(1.-torch.abs(discrim(heatmaps[-1].detach()) -
99 | d_fake[:true_batchsize])))
100 | loss_D = loss_D_real + loss_D_fake
101 | loss_D.backward()
102 | optimizer_discrim.step()
103 |
104 | optimizer_regressor.zero_grad()
105 | out = regressor(input_images, heatmaps[-1].detach())
106 | loss_regressor = criterion(out, gt_coords_xy)
107 | loss_regressor.backward()
108 | optimizer_regressor.step()
109 |
110 | d_fake = (calc_d_fake(arg.dataset, out.detach(), gt_coords_xy, true_batchsize,
111 | arg.batch_size)).cuda(device=devices[0])
112 |
113 |             sum_loss_regressor += loss_regressor.item()
114 |
115 | if (epoch+1) % arg.save_interval == 0:
116 | torch.save(estimator.state_dict(), arg.save_folder + 'estimator_'+str(epoch+1)+'.pth')
117 | torch.save(discrim.state_dict(), arg.save_folder + 'discrim_'+str(epoch+1)+'.pth')
118 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset+'_regressor_'+str(epoch+1)+'.pth')
119 |
120 |         print('\nepoch: {:0>4d} | loss_estimator: {:.2f} | loss_regressor: {:.2f}'.format(
121 |             epoch,
122 |             sum_loss_estimator/forward_times_per_epoch,
123 |             sum_loss_regressor/forward_times_per_epoch
124 |         ))
125 |
126 | torch.save(estimator.state_dict(), arg.save_folder + 'estimator_'+str(epoch+1)+'.pth')
127 | torch.save(discrim.state_dict(), arg.save_folder + 'discrim_'+str(epoch+1)+'.pth')
128 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset+'_regressor_'+str(epoch+1)+'.pth')
129 | print('Training done!')
130 |
131 |
132 | def train_with_gt_heatmap(arg):
133 | epoch = None
134 | devices = get_devices_list(arg)
135 |
136 | print('***** Training with ground truth heatmap *****')
137 | print('Training parameters:\n' +
138 | '# Dataset: ' + arg.dataset + '\n' +
139 | '# Dataset split: ' + arg.split + '\n' +
140 | '# Batchsize: ' + str(arg.batch_size) + '\n' +
141 | '# Num workers: ' + str(arg.workers) + '\n' +
142 | '# PDB: ' + str(arg.PDB) + '\n' +
143 | '# Use GPU: ' + str(arg.cuda) + '\n' +
144 | '# Start lr: ' + str(arg.lr) + '\n' +
145 | '# Lr step values: ' + str(arg.step_values) + '\n' +
146 | '# Lr step gamma: ' + str(arg.gamma) + '\n' +
147 | '# Max epoch: ' + str(arg.max_epoch) + '\n' +
148 | '# Loss type: ' + arg.loss_type + '\n' +
149 | '# Resumed model: ' + str(arg.resume_epoch > 0))
150 | if arg.resume_epoch > 0:
151 | print('# Resumed epoch: ' + str(arg.resume_epoch))
152 |
153 | print('Creating networks ...')
154 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2 * kp_num[arg.dataset])
155 |     regressor = load_weights(regressor, arg.resume_folder + arg.dataset + '_regressor_' +
156 |                              str(arg.resume_epoch) + '.pth', devices[0]) if arg.resume_epoch > 0 else regressor
157 | regressor = regressor.cuda(device=devices[0])
158 | regressor.train()
159 | print('Creating networks done!')
160 |
161 | optimizer_regressor = torch.optim.SGD(regressor.parameters(), lr=arg.lr, momentum=arg.momentum,
162 | weight_decay=arg.weight_decay)
163 |
164 | if arg.loss_type == 'L2':
165 | criterion = nn.MSELoss()
166 | elif arg.loss_type == 'L1':
167 | criterion = nn.L1Loss()
168 | elif arg.loss_type == 'smoothL1':
169 | criterion = nn.SmoothL1Loss()
170 | else:
171 | criterion = WingLoss(w=arg.wingloss_w, epsilon=arg.wingloss_e)
172 |
173 | print('Loading dataset ...')
174 | trainset = GeneralDataset(dataset=arg.dataset)
175 | print('Loading dataset done!')
176 |
177 | print('Start training ...')
178 | for epoch in range(arg.resume_epoch, arg.max_epoch):
179 | forward_times_per_epoch, sum_loss_regressor = 0, 0.
180 | dataloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=arg.shuffle,
181 | num_workers=arg.workers, pin_memory=True)
182 |
183 | if epoch in arg.step_values:
184 | optimizer_regressor.param_groups[0]['lr'] *= arg.gamma
185 |
186 | for data in tqdm.tqdm(dataloader):
187 | forward_times_per_epoch += 1
188 | input_images, gt_coords_xy, gt_heatmap, _, _, _ = data
189 | input_images = input_images.unsqueeze(1)
190 | input_images = input_images.cuda(device=devices[0])
191 | gt_coords_xy = gt_coords_xy.cuda(device=devices[0])
192 | gt_heatmap = gt_heatmap.cuda(device=devices[0])
193 |
194 | optimizer_regressor.zero_grad()
195 | out = regressor(input_images, gt_heatmap)
196 | loss_regressor = criterion(out, gt_coords_xy)
197 | loss_regressor.backward()
198 | optimizer_regressor.step()
199 |
200 |             sum_loss_regressor += loss_regressor.item()  # scalar accumulation, as in train()
201 |
202 | if (epoch + 1) % arg.save_interval == 0:
203 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset + '_regressor_' + str(epoch + 1) + '.pth')
204 |
205 |         print('\nepoch: {:0>4d} | loss_regressor: {:.2f}'.format(
206 |             epoch,
207 |             sum_loss_regressor / forward_times_per_epoch
208 |         ))
209 |
210 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset + '_regressor_' + str(epoch + 1) + '.pth')
211 | print('Training done!')
212 |
213 |
214 | if __name__ == '__main__':
215 | train(args)
216 |
--------------------------------------------------------------------------------
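
The adversarial bookkeeping in train() is compact, so here is a dummy-tensor sketch of the three log2 terms. The D below is a hypothetical stand-in for models.Discrim, used only so the signs and shapes are easy to verify:

    import torch

    def D(hm):  # stand-in for models.Discrim: (B, 13, 64, 64) heatmaps -> (B, 13) scores in (0, 1)
        return torch.sigmoid(hm.mean(dim=(2, 3)))

    fake_hm = torch.rand(4, 13, 64, 64)   # estimator output (detached before the D update)
    real_hm = torch.rand(4, 13, 64, 64)   # ground-truth heatmaps
    d_fake = torch.zeros(4, 13)           # per-boundary quality flags from calc_d_fake

    loss_A = torch.mean(torch.log2(1. - D(fake_hm)))        # added to the estimator loss
    loss_D = -torch.mean(torch.log2(D(real_hm))) \
             - torch.mean(torch.log2(1. - torch.abs(D(fake_hm) - d_fake)))
    print(loss_A.item(), loss_D.item())
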
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .args import args
2 | from .dataload import *
3 | from .dataset_info import *
4 | from .train_eval_utils import *
5 | from .visual import *
6 |
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='LAB')
4 |
5 | # dataset
6 | parser.add_argument('--dataset_route', default='/home/jin/new_datasets/', type=str)
7 | parser.add_argument('--dataset', default='WFLW', type=str)
8 | parser.add_argument('--split', default='test', type=str)
9 |
10 | # dataloader
11 | parser.add_argument('--crop_size', default=256, type=int)
12 | parser.add_argument('--batch_size', default=4, type=int)
13 | parser.add_argument('--workers', default=8, type=int)
14 | parser.add_argument('--shuffle', default=True, type=bool)  # caveat: type=bool maps any non-empty string to True
15 | parser.add_argument('--PDB', default=False, type=bool)
16 | parser.add_argument('--RGB', default=False, type=bool)
17 | parser.add_argument('--trans_ratio', default=0.1, type=float)
18 | parser.add_argument('--rotate_limit', default=20., type=float)
19 | parser.add_argument('--scale_ratio', default=0.1, type=float)
20 |
21 | # devices
22 | parser.add_argument('--cuda', default=True, type=bool)
23 | parser.add_argument('--gpu_id', default='0', type=str)
24 |
25 | # learning parameters
26 | parser.add_argument('--momentum', default=0.9, type=float)
27 | parser.add_argument('--weight_decay', default=5e-4, type=float)
28 | parser.add_argument('--lr', default=2e-5, type=float)
29 | parser.add_argument('--gamma', default=0.2, type=float)
30 | parser.add_argument('--step_values', default=[1000, 1500], type=int, nargs='+')  # type=list would split a CLI string into characters
31 | parser.add_argument('--max_epoch', default=2000, type=int)
32 |
33 | # losses setting
34 | parser.add_argument('--loss_type', default='smoothL1', type=str,
35 | choices=['L1', 'L2', 'smoothL1', 'wingloss'])
36 | parser.add_argument('--wingloss_w', default=10, type=int)
37 | parser.add_argument('--wingloss_e', default=2, type=int)
38 |
39 | # resume training parameters
40 | parser.add_argument('--resume_epoch', default=0, type=int)
41 | parser.add_argument('--resume_folder', default='./weights/ckpts/', type=str)
42 |
43 | # model saving parameters
44 | parser.add_argument('--save_folder', default='./weights/', type=str)
45 | parser.add_argument('--save_interval', default=100, type=int)
46 |
47 | # model setting
48 | parser.add_argument('--hour_stack', default=4, type=int)
49 | parser.add_argument('--msg_pass', default=True, type=bool)
50 | parser.add_argument('--GAN', default=True, type=bool)
51 | parser.add_argument('--fuse_stage', default=4, type=int)
52 | parser.add_argument('--sigma', default=1.0, type=float)
53 | parser.add_argument('--theta', default=1.5, type=float)
54 | parser.add_argument('--delta', default=0.8, type=float)
55 |
56 | # evaluate parameters
57 | parser.add_argument('--eval_epoch', default=750, type=int)
58 | parser.add_argument('--max_threshold', default=0.1, type=float)
59 | parser.add_argument('--norm_way', default='inter_ocular', type=str,
60 | choices=['inter_pupil', 'inter_ocular', 'face_size'])
61 | parser.add_argument('--eval_visual', default=False, type=bool)
62 | parser.add_argument('--save_img', default=False, type=bool)
63 |
64 | args = parser.parse_args()
65 |
66 | assert args.resume_epoch < args.step_values[0]
67 | assert args.resume_epoch < args.max_epoch
68 | assert args.step_values[-1] < args.max_epoch
69 |
--------------------------------------------------------------------------------
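
With the nargs='+' form of --step_values above, the parser round-trips as below. This is a hypothetical invocation with placeholder values; note that importing utils.args already parses sys.argv once for the module-level args:

    from utils.args import parser

    cfg = parser.parse_args(['--dataset', 'WFLW', '--batch_size', '8',
                             '--step_values', '1000', '1500'])
    print(cfg.dataset, cfg.batch_size, cfg.step_values)   # WFLW 8 [1000, 1500]
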
/utils/dataload.py:
--------------------------------------------------------------------------------
1 | from .dataset_info import *
2 | from .args import args
3 | from .pdb import pdb
4 | from .visual import *
5 |
6 | import cv2
7 | import time
8 | import random
9 | import numpy as np
10 | from scipy.interpolate import splprep, splev
11 |
12 |
13 | def get_annotations_list(dataset, split, ispdb=False):
14 | annotations = []
15 | annotation_file = open(dataset_route[dataset] + dataset + '_' + split + '_annos.txt')
16 |
17 | for line in range(dataset_size[dataset][split]):
18 | annotations.append(annotation_file.readline().rstrip().split())
19 | annotation_file.close()
20 |
21 | if ispdb:
22 | annos = []
23 | allshapes = np.zeros((2 * kp_num[dataset], len(annotations)))
24 | for line_index, line in enumerate(annotations):
25 | coord_x = np.array(list(map(float, line[:2*kp_num[dataset]:2])))
26 | coord_y = np.array(list(map(float, line[1:2*kp_num[dataset]:2])))
27 | position_before = np.float32([[int(line[-7]), int(line[-6])],
28 | [int(line[-7]), int(line[-4])],
29 | [int(line[-5]), int(line[-4])]])
30 | position_after = np.float32([[0, 0],
31 | [0, args.crop_size - 1],
32 | [args.crop_size - 1, args.crop_size - 1]])
33 | crop_matrix = cv2.getAffineTransform(position_before, position_after)
34 | coord_x_after_crop = crop_matrix[0][0] * coord_x + crop_matrix[0][1] * coord_y + crop_matrix[0][2]
35 | coord_y_after_crop = crop_matrix[1][0] * coord_x + crop_matrix[1][1] * coord_y + crop_matrix[1][2]
36 | allshapes[0:kp_num[dataset], line_index] = list(coord_x_after_crop)
37 | allshapes[kp_num[dataset]:2*kp_num[dataset], line_index] = list(coord_y_after_crop)
38 | newidx = pdb(dataset, allshapes, dataset_pdb_numbins[dataset])
39 | for id_index in newidx:
40 | annos.append(annotations[int(id_index)])
41 | return annos
42 |
43 | return annotations
44 |
45 |
46 | def convert_img_to_gray(img):
47 |     if len(img.shape) == 2 or img.shape[2] == 1:  # already grayscale
48 |         return img
49 | elif img.shape[2] == 4:
50 | gray = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
51 | return gray
52 | elif img.shape[2] == 3:
53 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
54 | return gray
55 | else:
56 | raise Exception("img shape wrong!\n")
57 |
58 |
59 | def get_random_transform_param(split, bbox):
60 | translation, trans_dir, rotation, scaling, flip, gaussian_blur = 0, 0, 0, 1., 0, 0
61 | if split in ['train']:
62 | random.seed(time.time())
63 | translate_param = int(args.trans_ratio * abs(bbox[2] - bbox[0]))
64 | translation = random.randint(-translate_param, translate_param)
65 | trans_dir = random.randint(0, 3) # LU:0 RU:1 LL:2 RL:3
66 | rotation = random.uniform(-args.rotate_limit, args.rotate_limit)
67 | scaling = random.uniform(1-args.scale_ratio, 1+args.scale_ratio)
68 | flip = random.randint(0, 1)
69 | gaussian_blur = random.randint(0, 1)
70 | return translation, trans_dir, rotation, scaling, flip, gaussian_blur
71 |
72 |
73 | def further_transform(pic, bbox, flip, gaussian_blur):
74 | if flip == 1:
75 | pic = cv2.flip(pic, 1)
76 | if abs(bbox[2] - bbox[0]) < 120 or gaussian_blur == 0:
77 | return pic
78 | else:
79 | return cv2.GaussianBlur(pic, (5, 5), 1)
80 |
81 |
82 | def get_affine_matrix(crop_size, rotation, scaling):
83 | center = (crop_size / 2.0, crop_size / 2.0)
84 | return cv2.getRotationMatrix2D(center, rotation, scaling)
85 |
86 |
87 | def pic_normalize(pic):  # for speed; currently supports grayscale pictures only
88 | pic = np.float32(pic)
89 | mean, std = cv2.meanStdDev(pic)
90 | pic_channel = 1 if len(pic.shape) == 2 else 3
91 | for channel in range(0, pic_channel):
92 | if std[channel][0] < 1e-6:
93 | std[channel][0] = 1
94 | pic = (pic - mean) / std
95 | return np.float32(pic)
96 |
97 |
98 | def get_cropped_coords(dataset, crop_matrix, coord_x, coord_y, flip=0):
99 | coord_x, coord_y = np.array(coord_x), np.array(coord_y)
100 | temp_x = crop_matrix[0][0] * coord_x + crop_matrix[0][1] * coord_y + crop_matrix[0][2] if flip == 0 else \
101 | float(args.crop_size) - (crop_matrix[0][0] * coord_x + crop_matrix[0][1] * coord_y + crop_matrix[0][2]) - 1
102 | temp_y = crop_matrix[1][0] * coord_x + crop_matrix[1][1] * coord_y + crop_matrix[1][2]
103 | if flip:
104 | temp_x = temp_x[np.array(flip_relation[dataset])[:, 1]]
105 | temp_y = temp_y[np.array(flip_relation[dataset])[:, 1]]
106 | return temp_x, temp_y
107 |
108 |
109 | def get_gt_coords(dataset, affine_matrix, coord_x, coord_y):
110 | out = np.zeros(2*kp_num[dataset])
111 | out[:2*kp_num[dataset]:2] = affine_matrix[0][0] * coord_x + affine_matrix[0][1] * coord_y + affine_matrix[0][2]
112 | out[1:2*kp_num[dataset]:2] = affine_matrix[1][0] * coord_x + affine_matrix[1][1] * coord_y + affine_matrix[1][2]
113 | return np.array(np.float32(out))
114 |
115 |
116 | def get_gt_heatmap(dataset, gt_coords):
117 | coord_x, coord_y, gt_heatmap = [], [], []
118 | for index in range(boundary_num):
119 | gt_heatmap.append(np.ones((64, 64)))
120 |         # boundary curves are drawn with value 0 onto this all-ones canvas below
121 | boundary_x = {'chin': [], 'leb': [], 'reb': [], 'bon': [], 'breath': [], 'lue': [], 'lle': [],
122 | 'rue': [], 'rle': [], 'usul': [], 'lsul': [], 'usll': [], 'lsll': []}
123 | boundary_y = {'chin': [], 'leb': [], 'reb': [], 'bon': [], 'breath': [], 'lue': [], 'lle': [],
124 | 'rue': [], 'rle': [], 'usul': [], 'lsul': [], 'usll': [], 'lsll': []}
125 | points = {'chin': [], 'leb': [], 'reb': [], 'bon': [], 'breath': [], 'lue': [], 'lle': [],
126 | 'rue': [], 'rle': [], 'usul': [], 'lsul': [], 'usll': [], 'lsll': []}
127 | resize_matrix = cv2.getAffineTransform(np.float32([[0, 0], [0, args.crop_size-1],
128 | [args.crop_size-1, args.crop_size-1]]),
129 | np.float32([[0, 0], [0, heatmap_size-1],
130 | [heatmap_size-1, heatmap_size-1]]))
131 | for kp_index in range(kp_num[dataset]):
132 | coord_x.append(
133 | resize_matrix[0][0] * gt_coords[2 * kp_index] +
134 | resize_matrix[0][1] * gt_coords[2 * kp_index + 1] +
135 | resize_matrix[0][2] + random.uniform(-0.2, 0.2)
136 | )
137 | coord_y.append(
138 | resize_matrix[1][0] * gt_coords[2 * kp_index] +
139 | resize_matrix[1][1] * gt_coords[2 * kp_index + 1] +
140 | resize_matrix[1][2] + random.uniform(-0.2, 0.2)
141 | )
142 | for boundary_index in range(boundary_num):
143 | for kp_index in range(
144 | point_range[dataset][boundary_index][0],
145 | point_range[dataset][boundary_index][1]
146 | ):
147 | boundary_x[boundary_keys[boundary_index]].append(coord_x[kp_index])
148 | boundary_y[boundary_keys[boundary_index]].append(coord_y[kp_index])
149 | if boundary_keys[boundary_index] in boundary_special.keys() and\
150 | dataset in boundary_special[boundary_keys[boundary_index]]:
151 | boundary_x[boundary_keys[boundary_index]].append(
152 | coord_x[duplicate_point[dataset][boundary_keys[boundary_index]]])
153 | boundary_y[boundary_keys[boundary_index]].append(
154 | coord_y[duplicate_point[dataset][boundary_keys[boundary_index]]])
155 | for k_index, k in enumerate(boundary_keys):
156 | if point_num_per_boundary[dataset][k_index] >= 2.:
157 |         if len(boundary_x[k]) == len(set(boundary_x[k])) or len(boundary_y[k]) == len(set(boundary_y[k])):  # splprep cannot fit fully duplicated points
158 | points[k].append(boundary_x[k])
159 | points[k].append(boundary_y[k])
160 | res = splprep(points[k], s=0.0, k=1)
161 | u_new = np.linspace(res[1].min(), res[1].max(), interp_points_num[k])
162 | boundary_x[k], boundary_y[k] = splev(u_new, res[0], der=0)
163 | for index, k in enumerate(boundary_keys):
164 | if point_num_per_boundary[dataset][index] >= 2.:
165 | for i in range(len(boundary_x[k]) - 1):
166 | cv2.line(gt_heatmap[index], (int(boundary_x[k][i]), int(boundary_y[k][i])),
167 | (int(boundary_x[k][i+1]), int(boundary_y[k][i+1])), 0)
168 | else:
169 | cv2.circle(gt_heatmap[index], (int(boundary_x[k][0]), int(boundary_y[k][0])), 2, 0, -1)
170 | gt_heatmap[index] = np.uint8(gt_heatmap[index])
171 | gt_heatmap[index] = cv2.distanceTransform(gt_heatmap[index], cv2.DIST_L2, 5)
172 | gt_heatmap[index] = np.float32(np.array(gt_heatmap[index]))
173 | gt_heatmap[index] = gt_heatmap[index].reshape(64*64)
174 |         mask = gt_heatmap[index] < 3. * args.sigma  # pixels within 3*sigma of the boundary
175 |         gt_heatmap[index][mask] = np.exp(-gt_heatmap[index][mask] ** 2 /
176 |                                          (2. * args.sigma * args.sigma))  # exp(-d^2/(2*sigma^2)); parenthesized correctly
177 |         gt_heatmap[index][~mask] = 0.
178 | gt_heatmap[index] = gt_heatmap[index].reshape([64, 64])
179 | return np.array(gt_heatmap)
180 |
181 |
182 | def get_item_from(dataset, split, annotation):
183 | pic = cv2.imread(dataset_route[dataset]+annotation[-1])
184 | pic = convert_img_to_gray(pic) if not args.RGB else pic
185 | coord_x = list(map(float, annotation[:2*kp_num[dataset]:2]))
186 | coord_y = list(map(float, annotation[1:2*kp_num[dataset]:2]))
187 | coord_xy = np.array(np.float32(list(map(float, annotation[:2*kp_num[dataset]]))))
188 | bbox = np.array(list(map(int, annotation[-7:-3])))
189 |
190 | translation, trans_dir, rotation, scaling, flip, gaussian_blur = get_random_transform_param(split, bbox)
191 |
192 | position_before = np.float32([[int(bbox[0]) + pow(-1, trans_dir+1)*translation,
193 | int(bbox[1]) + pow(-1, trans_dir//2+1)*translation],
194 | [int(bbox[0]) + pow(-1, trans_dir+1)*translation,
195 | int(bbox[3]) + pow(-1, trans_dir//2+1)*translation],
196 | [int(bbox[2]) + pow(-1, trans_dir+1)*translation,
197 | int(bbox[3]) + pow(-1, trans_dir//2+1)*translation]])
198 | position_after = np.float32([[0, 0],
199 | [0, args.crop_size - 1],
200 | [args.crop_size - 1, args.crop_size - 1]])
201 | crop_matrix = cv2.getAffineTransform(position_before, position_after)
202 | pic_crop = cv2.warpAffine(pic, crop_matrix, (args.crop_size, args.crop_size))
203 |     pic_crop = further_transform(pic_crop, bbox, flip, gaussian_blur) if split in ['train'] else pic_crop  # use the passed-in split, not args.split
204 | affine_matrix = get_affine_matrix(args.crop_size, rotation, scaling)
205 | pic_affine = cv2.warpAffine(pic_crop, affine_matrix, (args.crop_size, args.crop_size))
206 | pic_affine = pic_normalize(pic_affine) if not args.RGB else pic_affine
207 |
208 | coord_x_cropped, coord_y_cropped = get_cropped_coords(dataset, crop_matrix, coord_x, coord_y, flip=flip)
209 | gt_coords_xy = get_gt_coords(dataset, affine_matrix, coord_x_cropped, coord_y_cropped)
210 |
211 | gt_heatmap = get_gt_heatmap(dataset, gt_coords_xy)
212 |
213 | return pic_affine, gt_coords_xy, gt_heatmap, coord_xy, bbox, annotation[-1]
214 |
--------------------------------------------------------------------------------
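
The core trick in get_gt_heatmap is worth isolating: draw each boundary with value 0 on a canvas of ones, take an L2 distance transform, then convert distances into Gaussian activations. A self-contained sketch, with sigma and the 64x64 size mirroring the defaults here:

    import cv2
    import numpy as np

    sigma = 1.0
    canvas = np.ones((64, 64), np.uint8)
    cv2.line(canvas, (10, 32), (54, 32), 0)                # one synthetic "boundary"
    dist = cv2.distanceTransform(canvas, cv2.DIST_L2, 5)   # distance of each pixel to the line
    heatmap = np.where(dist < 3. * sigma,
                       np.exp(-dist ** 2 / (2. * sigma * sigma)), 0.)
    print(heatmap[32, 32], heatmap[0, 0])                  # 1.0 on the boundary, 0.0 far away
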
/utils/dataset_info.py:
--------------------------------------------------------------------------------
1 | from .args import args
2 |
3 | heatmap_size = 64
4 | boundary_num = 13
5 |
6 | boundary_keys = ['chin', 'leb', 'reb', 'bon', 'breath', 'lue', 'lle', 'rue', 'rle', 'usul', 'lsul', 'usll', 'lsll']
7 |
8 | interp_points_num = {
9 | 'chin': 120,
10 | 'leb': 32,
11 | 'reb': 32,
12 | 'bon': 32,
13 | 'breath': 25,
14 | 'lue': 25,
15 | 'lle': 25,
16 | 'rue': 25,
17 | 'rle': 25,
18 | 'usul': 32,
19 | 'lsul': 32,
20 | 'usll': 32,
21 | 'lsll': 32
22 | }
23 |
24 | dataset_pdb_numbins = {
25 | '300W': 9,
26 | 'AFLW': 17,
27 | 'COFW': 7,
28 | 'WFLW': 13
29 | }
30 |
31 | dataset_route = {
32 | '300W': args.dataset_route+'/300W/',
33 | 'AFLW': args.dataset_route+'/AFLW/',
34 | 'COFW': args.dataset_route+'/COFW/',
35 | 'WFLW': args.dataset_route+'/WFLW/'
36 | }
37 |
38 | dataset_size = {
39 | '300W': {
40 | 'train': 3148,
41 | 'common_subset': 554,
42 | 'challenge_subset': 135,
43 | 'fullset': 689,
44 | '300W_testset': 600,
45 |         'COFW68': 507  # used to test models trained on the 300W trainset
46 | },
47 | 'AFLW': {
48 | 'train': 20000,
49 | 'test': 24386,
50 | 'frontal': 1314
51 | },
52 | 'COFW': {
53 | 'train': 1345,
54 | 'test': 507
55 | },
56 | 'WFLW': {
57 | 'train': 7500,
58 | 'test': 2500,
59 | 'pose': 326,
60 | 'expression': 314,
61 | 'illumination': 698,
62 | 'makeup': 206,
63 | 'occlusion': 736,
64 | 'blur': 773
65 | }
66 | }
67 |
68 | kp_num = {
69 | '300W': 68,
70 | 'AFLW': 19,
71 | 'COFW': 29,
72 | 'WFLW': 98
73 | }
74 |
75 | point_num_per_boundary = {
76 | '300W': [17., 5., 5., 4., 5., 4., 4., 4., 4., 7., 5., 5., 7.],
77 | 'AFLW': [1., 3., 3., 1., 2., 3., 3., 3., 3., 3., 3., 3., 3.],
78 | 'COFW': [1., 3., 3., 1., 3., 3., 3., 3., 3., 3., 1., 1., 3.],
79 | 'WFLW': [33., 9., 9., 4., 5., 5., 5., 5., 5., 7., 5., 5., 7.]
80 | }
81 |
82 | boundary_special = {  # some boundaries use keypoints that intersect other boundaries discontinuously; handled specially
83 | 'lle': ['300W', 'COFW', 'WFLW'],
84 | 'rle': ['300W', 'COFW', 'WFLW'],
85 | 'usll': ['300W', 'WFLW'],
86 | 'lsll': ['300W', 'COFW', 'WFLW']
87 | }
88 |
89 | duplicate_point = {  # 0-based indices of the keypoints that must be reused
90 | '300W': {
91 | 'lle': 36,
92 | 'rle': 42,
93 | 'usll': 60,
94 | 'lsll': 48
95 | },
96 | 'COFW': {
97 | 'lle': 13,
98 | 'rle': 17,
99 | 'lsll': 21
100 | },
101 | 'WFLW': {
102 | 'lle': 60,
103 | 'rle': 68,
104 | 'usll': 88,
105 | 'lsll': 76
106 | }
107 | }
108 |
109 | point_range = {  # note: half-open ranges [start, end), listed in boundary order; indices start from 0
110 | '300W': [
111 | [0, 17], [17, 22], [22, 27], [27, 31], [31, 36],
112 | [36, 40], [39, 42], [42, 46], [45, 48], [48, 55],
113 | [60, 65], [64, 68], [54, 60]
114 | ],
115 | 'AFLW': [
116 | [0, 1], [1, 4], [4, 7], [7, 8], [8, 10],
117 | [10, 13], [10, 13], [13, 16], [13, 16], [16, 19],
118 | [16, 19], [16, 19], [16, 19]
119 | ],
120 | 'COFW': [
121 | [0, 1], [1, 4], [5, 8], [9, 10], [10, 13],
122 | [13, 16], [15, 17], [17, 20], [19, 21], [21, 24],
123 | [25, 26], [26, 27], [23, 25]
124 | ],
125 | 'WFLW': [
126 | [0, 33], [33, 38], [42, 47], [51, 55], [55, 60],
127 | [60, 65], [64, 68], [68, 73], [72, 76], [76, 83],
128 | [88, 93], [92, 96], [82, 88]
129 | ]
130 | }
131 |
132 | flip_relation = {
133 | '300W': [
134 | [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11],
135 | [6, 10], [7, 9], [8, 8], [9, 7], [10, 6], [11, 5],
136 | [12, 4], [13, 3], [14, 2], [15, 1], [16, 0], [17, 26],
137 | [18, 25], [19, 24], [20, 23], [21, 22], [22, 21], [23, 20],
138 | [24, 19], [25, 18], [26, 17], [27, 27], [28, 28], [29, 29],
139 | [30, 30], [31, 35], [32, 34], [33, 33], [34, 32], [35, 31],
140 | [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
141 | [42, 39], [43, 38], [44, 37], [45, 36], [46, 41], [47, 40],
142 | [48, 54], [49, 53], [50, 52], [51, 51], [52, 50], [53, 49],
143 | [54, 48], [55, 59], [56, 58], [57, 57], [58, 56], [59, 55],
144 | [60, 64], [61, 63], [62, 62], [63, 61], [64, 60], [65, 67],
145 | [66, 66], [67, 65]
146 | ],
147 | 'AFLW': [
148 | [0, 0], [1, 6], [2, 5], [3, 4], [4, 3], [5, 2],
149 | [6, 1], [7, 7], [8, 9], [9, 8], [10, 15], [11, 14],
150 | [12, 13], [13, 12], [14, 11], [15, 10], [16, 18], [17, 17],
151 | [18, 16]
152 | ],
153 | 'COFW': [
154 | [0, 0], [1, 7], [2, 6], [3, 5], [4, 8], [5, 3],
155 | [6, 2], [7, 1], [8, 4], [9, 9], [10, 12], [11, 11],
156 | [12, 10], [13, 19], [14, 18], [15, 17], [16, 20], [17, 15],
157 | [18, 14], [19, 13], [20, 16], [21, 23], [22, 22], [23, 21],
158 | [24, 24], [25, 25], [26, 26], [27, 28], [28, 27]
159 | ],
160 | 'WFLW': [
161 | [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27],
162 | [6, 26], [7, 25], [8, 24], [9, 23], [10, 22], [11, 21],
163 | [12, 20], [13, 19], [14, 18], [15, 17], [16, 16], [17, 15],
164 | [18, 14], [19, 13], [20, 12], [21, 11], [22, 10], [23, 9],
165 | [24, 8], [25, 7], [26, 6], [27, 5], [28, 4], [29, 3],
166 | [30, 2], [31, 1], [32, 0], [33, 46], [34, 45], [35, 44],
167 | [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47],
168 | [42, 37], [43, 36], [44, 35], [45, 34], [46, 33], [47, 41],
169 | [48, 40], [49, 39], [50, 38], [51, 51], [52, 52], [53, 53],
170 | [54, 54], [55, 59], [56, 58], [57, 57], [58, 56], [59, 55],
171 | [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75],
172 | [66, 74], [67, 73], [68, 64], [69, 63], [70, 62], [71, 61],
173 | [72, 60], [73, 67], [74, 66], [75, 65], [76, 82], [77, 81],
174 | [78, 80], [79, 79], [80, 78], [81, 77], [82, 76], [83, 87],
175 | [84, 86], [85, 85], [86, 84], [87, 83], [88, 92], [89, 91],
176 | [90, 90], [91, 89], [92, 88], [93, 95], [94, 94], [95, 93],
177 | [96, 97], [97, 96]
178 | ]
179 | }
180 |
181 | lo_eye_corner_index_x = {'300W': 72, 'AFLW': 20, 'COFW': 26, 'WFLW': 120}
182 | lo_eye_corner_index_y = {'300W': 73, 'AFLW': 21, 'COFW': 27, 'WFLW': 121}
183 | ro_eye_corner_index_x = {'300W': 90, 'AFLW': 30, 'COFW': 38, 'WFLW': 144}
184 | ro_eye_corner_index_y = {'300W': 91, 'AFLW': 31, 'COFW': 39, 'WFLW': 145}
185 | l_eye_center_index_x = {'300W': [72, 74, 76, 78, 80, 82], 'AFLW': 22, 'COFW': 54, 'WFLW': 192}
186 | l_eye_center_index_y = {'300W': [73, 75, 77, 79, 81, 83], 'AFLW': 23, 'COFW': 55, 'WFLW': 193}
187 | r_eye_center_index_x = {'300W': [84, 86, 88, 90, 92, 94], 'AFLW': 28, 'COFW': 56, 'WFLW': 194}
188 | r_eye_center_index_y = {'300W': [85, 87, 89, 91, 93, 95], 'AFLW': 29, 'COFW': 57, 'WFLW': 195}
189 |
190 | nparts = {  # [chin, brow, nose, eyes, mouth], five parts in total
191 | '300W': [
192 | [0, 17], [17, 27], [27, 36], [36, 48], [48, 68]
193 | ],
194 | 'WFLW': [
195 | [0, 33], [33, 51], [51, 60], [60, 76], [76, 96]
196 | ]
197 | }
198 |
--------------------------------------------------------------------------------
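
These tables are consumed by index everywhere else, so a tiny lookup example may help; point_range rows are half-open [start, end) spans into the landmark list:

    from utils.dataset_info import point_range, boundary_keys, kp_num

    b = boundary_keys.index('chin')            # boundary 0
    start, end = point_range['WFLW'][b]        # [0, 33): WFLW's 33 chin points
    print(kp_num['WFLW'], list(range(start, end))[:5])   # 98 [0, 1, 2, 3, 4]
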
/utils/pdb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.decomposition import PCA
3 |
4 |
5 | def procrustes(X, Y, scaling=True, reflection='best'):
6 | n, m = X.shape
7 | ny, my = Y.shape
8 | muX = X.mean(0)
9 | muY = Y.mean(0)
10 | X0 = X - muX
11 | Y0 = Y - muY
12 | ssX = (X0**2.).sum()
13 | ssY = (Y0**2.).sum()
14 | # centred Frobenius norm
15 | normX = np.sqrt(ssX)
16 | normY = np.sqrt(ssY)
17 | # scale to equal (unit) norm
18 | X0 = X0/normX if normX > 1e-6 else X0
19 | Y0 = Y0/normY if normY > 1e-6 else Y0
20 | if my < m:
21 |         Y0 = np.concatenate((Y0, np.zeros((n, m-my))), 0)  # np.zeros takes a shape tuple
22 | # optimum rotation matrix of Y
23 | A = np.dot(X0.T, Y0)
24 | U, s, Vt = np.linalg.svd(A, full_matrices=False)
25 | V = Vt.T
26 | T = np.dot(V, U.T)
27 |     if reflection != 'best':  # '!=', not 'is not': identity checks against literals are unreliable
28 | # does the current solution use a reflection?
29 | have_reflection = np.linalg.det(T) < 0
30 | # if that's not what was specified, force another reflection
31 | if reflection != have_reflection:
32 | V[:, -1] *= -1
33 | s[-1] *= -1
34 | T = np.dot(V, U.T)
35 | traceTA = s.sum()
36 | if scaling:
37 | # optimum scaling of Y
38 | b = traceTA * normX / normY if normY > 1e-6 else traceTA * normX
39 |         # standardised distance between X and b*Y*T + c
40 | d = 1 - traceTA**2
41 | # transformed coords
42 | Z = normX*traceTA*np.dot(Y0, T) + muX
43 | else:
44 | b = 1
45 | d = 1 + ssY/ssX - 2 * traceTA * normY / normX
46 | Z = normY*np.dot(Y0, T) + muX
47 | # transformation matrix
48 | if my < m:
49 | T = T[:my, :]
50 | c = muX - b*np.dot(muY, T)
51 | # transformation values
52 | tform = {'rotation': T, 'scale': b, 'translation': c}
53 |
54 | return d, Z, tform
55 |
56 |
57 | # input as array
58 | def pdb(dataset, allShapes, numBins):
59 | alignedShape = allShapes
60 | meanShape = np.mean(alignedShape, 1)
61 | for i in range(len(alignedShape[0])):
62 | _, tmpS, _ = procrustes(meanShape.reshape((-1, 2), order='F'),
63 | alignedShape[:, i].reshape((-1, 2), order='F'))
64 | alignedShape[:, i] = tmpS.reshape((1, -1), order='F')
65 |
66 | meanShape = np.mean(alignedShape, 1)
67 | meanShape = meanShape.repeat(len(alignedShape[0])).reshape(-1, len(alignedShape[0]))
68 | alignedShape = alignedShape - meanShape
69 | pca = PCA(n_components=2) if dataset in ['AFLW', 'COFW'] else PCA(n_components=1)
70 | posePara = pca.fit_transform(np.transpose(alignedShape))
71 |
72 | absPosePara = np.abs(posePara[:, 1]) if dataset in ['AFLW', 'COFW'] else np.abs(posePara)
73 | maxPosePara = np.max(absPosePara)
74 | maxSampleInBins = np.max(np.histogram(absPosePara, numBins)[0])
75 |
76 | newIdx = np.array([])
77 | for i in range(numBins):
78 | tmp1 = set([index for index in range(len(absPosePara))
79 | if absPosePara[index] >= i*maxPosePara/numBins])
80 | tmp2 = set([index for index in range(len(absPosePara))
81 | if absPosePara[index] <= (i+1)*maxPosePara/numBins])
82 | tmpTrainIdx = np.array(list(tmp1 & tmp2))
83 | ratio = round(maxSampleInBins / len(tmpTrainIdx)) if len(tmpTrainIdx) > 0 else 0
84 | newIdx = np.insert(newIdx, 0, values=tmpTrainIdx.repeat(ratio), axis=0)
85 | return newIdx
86 |
--------------------------------------------------------------------------------
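
pdb() expects the column layout built in utils/dataload.get_annotations_list: column i is one sample's flattened shape [x_0..x_{N-1}, y_0..y_{N-1}]. A toy run with random placeholder shapes:

    import numpy as np
    from utils.pdb import pdb

    rng = np.random.RandomState(0)
    shapes = rng.rand(2 * 98, 200) * 255.   # 200 fake WFLW-style samples
    new_index = pdb('WFLW', shapes, 13)     # oversampled indices into the 200 samples
    print(len(new_index) >= 200)            # True: rare-pose bins are repeated
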
/utils/train_eval_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn as cudnn
3 | from collections import OrderedDict
4 | import numpy as np
5 | from sklearn.metrics import auc
6 | from utils import *
7 |
8 |
9 | def get_devices_list(arg):
10 | devices_list = [torch.device('cpu')]
11 | if arg.cuda and torch.cuda.is_available():
12 | devices_list = []
13 | for dev in arg.gpu_id.split(','):
14 | devices_list.append(torch.device('cuda:'+dev))
15 | cudnn.benchmark = True
16 | cudnn.enabled = True
17 | return devices_list
18 |
19 |
20 | def load_weights(net, pth_file, device):
21 | state_dict = torch.load(pth_file, map_location=device)
22 | # create new OrderedDict that does not contain `module.`
23 | new_state_dict = OrderedDict()
24 | for k, v in state_dict.items():
25 | head = k[:7]
26 | if head == 'module.':
27 | name = k[7:] # remove `module.`
28 | else:
29 | name = k
30 | new_state_dict[name] = v
31 | net.load_state_dict(new_state_dict)
32 | return net
33 |
34 |
35 | def create_model(arg, devices_list):
36 | from models import Estimator, Discrim, Regressor
37 |
38 | estimator = Estimator(stacks=arg.hour_stack, msg_pass=arg.msg_pass)
39 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2*kp_num[arg.dataset])
40 | discrim = Discrim() if arg.GAN else None
41 |
42 | if arg.resume_epoch > 0:
43 | estimator = load_weights(estimator, arg.resume_folder + 'estimator_' + str(arg.resume_epoch) + '.pth',
44 | devices_list[0])
45 | regressor = load_weights(regressor, arg.resume_folder + arg.dataset+'_regressor_' +
46 | str(arg.resume_epoch) + '.pth', devices_list[0])
47 | discrim = load_weights(discrim, arg.resume_folder + 'discrim_' + str(arg.resume_epoch) + '.pth',
48 | devices_list[0]) if arg.GAN else None
49 |
50 | if arg.cuda:
51 | estimator = estimator.cuda(device=devices_list[0])
52 | regressor = regressor.cuda(device=devices_list[0])
53 | discrim = discrim.cuda(device=devices_list[0]) if arg.GAN else None
54 |
55 | return estimator, regressor, discrim
56 |
57 |
58 | def calc_d_fake(dataset, pred_coords, gt_coords, bcsize, bcsize_set):
59 | error_regressor = (pred_coords - gt_coords) ** 2
60 | dist_regressor = torch.zeros(bcsize, kp_num[dataset])
61 | dfake = torch.zeros(bcsize_set, boundary_num)
62 | for batch in range(bcsize):
63 | dist_regressor[batch, :] = \
64 | (error_regressor[batch][:2*kp_num[dataset]:2] + error_regressor[batch][1:2*kp_num[dataset]:2]) \
65 | < args.theta*args.theta
66 | for batch_index in range(bcsize):
67 | for boundary_index in range(boundary_num):
68 | for kp_index in range(
69 | point_range[dataset][boundary_index][0],
70 | point_range[dataset][boundary_index][1]
71 | ):
72 | if dist_regressor[batch_index][kp_index] == 1:
73 | dfake[batch_index][boundary_index] += 1
74 | if boundary_keys[boundary_index] in boundary_special.keys() and \
75 | dataset in boundary_special[boundary_keys[boundary_index]] and \
76 | dist_regressor[batch_index][duplicate_point[dataset][boundary_keys[boundary_index]]] == 1:
77 | dfake[batch_index][boundary_index] += 1
78 | for boundary_index in range(boundary_num):
79 | if dfake[batch_index][boundary_index] / point_num_per_boundary[dataset][boundary_index] < args.delta:
80 | dfake[batch_index][boundary_index] = 0.
81 | else:
82 | dfake[batch_index][boundary_index] = 1.
83 | if bcsize < bcsize_set:
84 | for batch_index in range(bcsize, bcsize_set):
85 | dfake[batch_index] = dfake[batch_index - bcsize]
86 | return dfake
87 |
88 |
89 | def calc_normalize_factor(dataset, gt_coords_xy, normalize_way='inter_pupil'):
90 | if normalize_way == 'inter_ocular':
91 | error_normalize_factor = np.sqrt(
92 | (gt_coords_xy[0][lo_eye_corner_index_x[dataset]] - gt_coords_xy[0][ro_eye_corner_index_x[dataset]]) *
93 | (gt_coords_xy[0][lo_eye_corner_index_x[dataset]] - gt_coords_xy[0][ro_eye_corner_index_x[dataset]]) +
94 | (gt_coords_xy[0][lo_eye_corner_index_y[dataset]] - gt_coords_xy[0][ro_eye_corner_index_y[dataset]]) *
95 | (gt_coords_xy[0][lo_eye_corner_index_y[dataset]] - gt_coords_xy[0][ro_eye_corner_index_y[dataset]]))
96 | return error_normalize_factor
97 | elif normalize_way == 'inter_pupil':
98 |         if not isinstance(l_eye_center_index_x[dataset], list):
99 | error_normalize_factor = np.sqrt(
100 | (gt_coords_xy[0][l_eye_center_index_x[dataset]] - gt_coords_xy[0][r_eye_center_index_x[dataset]]) *
101 | (gt_coords_xy[0][l_eye_center_index_x[dataset]] - gt_coords_xy[0][r_eye_center_index_x[dataset]]) +
102 | (gt_coords_xy[0][l_eye_center_index_y[dataset]] - gt_coords_xy[0][r_eye_center_index_y[dataset]]) *
103 | (gt_coords_xy[0][l_eye_center_index_y[dataset]] - gt_coords_xy[0][r_eye_center_index_y[dataset]]))
104 | return error_normalize_factor
105 | else:
106 | length = len(l_eye_center_index_x[dataset])
107 | l_eye_x_avg, l_eye_y_avg, r_eye_x_avg, r_eye_y_avg = 0., 0., 0., 0.
108 | for i in range(length):
109 | l_eye_x_avg += gt_coords_xy[0][l_eye_center_index_x[dataset][i]]
110 | l_eye_y_avg += gt_coords_xy[0][l_eye_center_index_y[dataset][i]]
111 | r_eye_x_avg += gt_coords_xy[0][r_eye_center_index_x[dataset][i]]
112 | r_eye_y_avg += gt_coords_xy[0][r_eye_center_index_y[dataset][i]]
113 | l_eye_x_avg /= length
114 | l_eye_y_avg /= length
115 | r_eye_x_avg /= length
116 | r_eye_y_avg /= length
117 | error_normalize_factor = np.sqrt((l_eye_x_avg - r_eye_x_avg) * (l_eye_x_avg - r_eye_x_avg) +
118 | (l_eye_y_avg - r_eye_y_avg) * (l_eye_y_avg - r_eye_y_avg))
119 | return error_normalize_factor
120 |
121 |
122 | def inverse_affine(arg, pred_coords, bbox):
123 | import copy
124 | pred_coords = copy.deepcopy(pred_coords)
125 | for i in range(kp_num[arg.dataset]):
126 | pred_coords[2 * i] = bbox[0] + pred_coords[2 * i]/(arg.crop_size-1)*(bbox[2] - bbox[0])
127 | pred_coords[2 * i + 1] = bbox[1] + pred_coords[2 * i + 1]/(arg.crop_size-1)*(bbox[3] - bbox[1])
128 | return pred_coords
129 |
130 |
131 | def calc_error_rate_i(dataset, pred_coords, gt_coords_xy, error_normalize_factor):
132 | temp, error = (pred_coords - gt_coords_xy)**2, 0.
133 | for i in range(kp_num[dataset]):
134 | error += np.sqrt(temp[2*i] + temp[2*i+1])
135 | return error/kp_num[dataset]/error_normalize_factor
136 |
137 |
138 | def calc_error_rate_i_nparts(dataset, pred_coords, gt_coords_xy, error_normalize_factor):
139 | assert dataset in nparts.keys()
140 | temp, error = (pred_coords - gt_coords_xy)**2, [0., 0., 0., 0., 0.]
141 | for i in range(len(nparts[dataset])):
142 | for j in range(nparts[dataset][i][0], nparts[dataset][i][1]):
143 | error[i] += np.sqrt(temp[2*j] + temp[2*j+1])
144 | error[i] = error[i]/(nparts[dataset][i][1] - nparts[dataset][i][0])/error_normalize_factor
145 | return error
146 |
147 |
148 | def calc_auc(dataset, split, error_rate, max_threshold):
149 | error_rate = np.array(error_rate)
150 | threshold = np.linspace(0, max_threshold, num=2000)
151 | accuracys = np.zeros(threshold.shape)
152 | for i in range(threshold.size):
153 | accuracys[i] = np.sum(error_rate < threshold[i]) * 1.0 / dataset_size[dataset][split]
154 | return auc(threshold, accuracys) / max_threshold, accuracys
155 |
--------------------------------------------------------------------------------
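
calc_auc integrates the cumulative error distribution: accuracys[i] is the fraction of test images whose normalized error falls under threshold[i]. A sketch with placeholder errors; any array of per-image NMEs works, as long as its length matches dataset_size:

    import numpy as np
    from utils.train_eval_utils import calc_auc

    errors = np.abs(np.random.randn(2500)) * 0.03        # placeholder WFLW-test NMEs
    auc_value, accuracys = calc_auc('WFLW', 'test', errors, max_threshold=0.1)
    print(round(auc_value, 3), accuracys[-1])            # normalized AUC, recall at 0.1
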
/utils/visual.py:
--------------------------------------------------------------------------------
1 | from .dataset_info import *
2 |
3 | import cv2
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import matplotlib.pyplot as plt
7 | from scipy.interpolate import make_interp_spline  # scipy.interpolate.spline was removed from SciPy
8 |
9 |
10 | def show_img(pic, name='pic', x=0, y=0, wait=0, keep=False):
11 | cv2.imshow(name, pic)
12 | cv2.moveWindow(name, x, y)
13 | if keep is False:
14 | cv2.waitKey(wait)
15 | cv2.destroyAllWindows()
16 |
17 |
18 | def watch_gray_heatmap(gt_heatmap):
19 | heatmap_sum = gt_heatmap[0]
20 | for index in range(boundary_num - 1):
21 | heatmap_sum += gt_heatmap[index + 1]
22 | show_img(heatmap_sum, 'heatmap_sum')
23 |
24 |
25 | def watch_pic_kp(dataset, pic, kp):
26 | for kp_index in range(kp_num[dataset]):
27 | cv2.circle(
28 | pic,
29 | (int(kp[2*kp_index]), int(kp[2*kp_index+1])),
30 | 2,
31 | (0, 0, 255),
32 | -1
33 | )
34 | show_img(pic)
35 |
36 |
37 | def watch_pic_kp_xy(dataset, pic, coord_x, coord_y):
38 | for kp_index in range(kp_num[dataset]):
39 | cv2.circle(
40 | pic,
41 | (int(coord_x[kp_index]), int(coord_y[kp_index])),
42 | 1,
43 | (0, 0, 255)
44 | )
45 | show_img(pic)
46 |
47 |
48 | def eval_heatmap(arg, heatmaps, img_name, bbox, save_img=False):
49 | heatmaps = F.interpolate(heatmaps, scale_factor=4, mode='bilinear', align_corners=True)
50 | heatmaps = heatmaps.squeeze(0).detach().cpu().numpy()
51 | heatmaps_sum = heatmaps[0]
52 | for heatmaps_index in range(boundary_num-1):
53 | heatmaps_sum += heatmaps[heatmaps_index+1]
54 | plt.axis('off')
55 | plt.imshow(heatmaps_sum, interpolation='nearest', vmax=1., vmin=0.)
56 | if save_img:
57 | import os
58 | if not os.path.exists('./imgs'):
59 | os.mkdir('./imgs')
60 | fig = plt.gcf()
61 | fig.set_size_inches(2.56 / 3, 2.56 / 3)
62 | plt.gca().xaxis.set_major_locator(plt.NullLocator())
63 | plt.gca().yaxis.set_major_locator(plt.NullLocator())
64 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
65 | plt.margins(0, 0)
66 | name = (img_name[0]).split('/')[-1]
67 | fig.savefig('./imgs/'+name.split('.')[0]+'_hm.png', format='png', transparent=True, dpi=300, pad_inches=0)
68 |
69 | pic = cv2.imread(dataset_route[arg.dataset] + img_name[0])
70 | position_before = np.float32([
71 | [int(bbox[0]), int(bbox[1])],
72 | [int(bbox[0]), int(bbox[3])],
73 | [int(bbox[2]), int(bbox[3])]
74 | ])
75 | position_after = np.float32([
76 | [0, 0],
77 | [0, arg.crop_size - 1],
78 | [arg.crop_size - 1, arg.crop_size - 1]
79 | ])
80 |     crop_matrix = cv2.getAffineTransform(position_before, position_after)
81 |     pic = cv2.warpAffine(pic, crop_matrix, (arg.crop_size, arg.crop_size))
82 |     if save_img:  # 'name' and the '_hm.png' file only exist when the figure was saved above
83 |         cv2.imwrite('./imgs/' + name.split('.')[0] + '_init.png', pic)
84 |         hm = cv2.imread('./imgs/'+name.split('.')[0]+'_hm.png')
85 |         syn = cv2.addWeighted(pic, 0.4, hm, 0.6, 0)
86 |         cv2.imwrite('./imgs/'+name.split('.')[0]+'_syn.png', syn)
87 |     plt.show()
88 |
89 | def eval_pred_points(arg, pred_coords, img_name, bbox, save_img=False):
90 | pic = cv2.imread(dataset_route[arg.dataset] + img_name[0])
91 | position_before = np.float32([
92 | [int(bbox[0]), int(bbox[1])],
93 | [int(bbox[0]), int(bbox[3])],
94 | [int(bbox[2]), int(bbox[3])]
95 | ])
96 | position_after = np.float32([
97 | [0, 0],
98 | [0, arg.crop_size - 1],
99 | [arg.crop_size - 1, arg.crop_size - 1]
100 | ])
101 | crop_matrix = cv2.getAffineTransform(position_before, position_after)
102 | pic = cv2.warpAffine(pic, crop_matrix, (arg.crop_size, arg.crop_size))
103 |
104 | for coord_index in range(kp_num[arg.dataset]):
105 | cv2.circle(pic, (int(pred_coords[2 * coord_index]), int(pred_coords[2 * coord_index + 1])), 2, (0, 0, 255))
106 | if save_img:
107 | import os
108 | if not os.path.exists('./imgs'):
109 | os.mkdir('./imgs')
110 | name = (img_name[0]).split('/')[-1]
111 | cv2.imwrite('./imgs/'+name.split('.')[0]+'_lmk.png', pic)
112 | show_img(pic)
113 |
114 |
115 | def eval_gt_pred_points(arg, gt_coords, pred_coords, img_name, bbox, save_img=False):
116 | pic = cv2.imread(dataset_route[arg.dataset] + img_name[0])
117 | position_before = np.float32([
118 | [int(bbox[0]), int(bbox[1])],
119 | [int(bbox[0]), int(bbox[3])],
120 | [int(bbox[2]), int(bbox[3])]
121 | ])
122 | position_after = np.float32([
123 | [0, 0],
124 | [0, arg.crop_size - 1],
125 | [arg.crop_size - 1, arg.crop_size - 1]
126 | ])
127 | crop_matrix = cv2.getAffineTransform(position_before, position_after)
128 | pic = cv2.warpAffine(pic, crop_matrix, (arg.crop_size, arg.crop_size))
129 |
130 | for coord_index in range(kp_num[arg.dataset]):
131 | cv2.circle(pic, (int(pred_coords[2 * coord_index]), int(pred_coords[2 * coord_index + 1])), 2, (0, 0, 255))
132 | cv2.circle(pic, (int(gt_coords[2 * coord_index]), int(gt_coords[2 * coord_index + 1])), 2, (0, 255, 0))
133 | if save_img:
134 | import os
135 | if not os.path.exists('./imgs'):
136 | os.mkdir('./imgs')
137 | name = (img_name[0]).split('/')[-1]
138 | cv2.imwrite('./imgs/'+name.split('.')[0]+'_lmk.png', pic)
139 | show_img(pic)
140 |
141 |
142 | def eval_CED(auc_record):
143 | error = np.linspace(0., 0.1, 21)
144 | error_new = np.linspace(error.min(), error.max(), 300)
145 | auc_value = np.array([auc_record[0], auc_record[99], auc_record[199], auc_record[299],
146 | auc_record[399], auc_record[499], auc_record[599], auc_record[699],
147 | auc_record[799], auc_record[899], auc_record[999], auc_record[1099],
148 | auc_record[1199], auc_record[1299], auc_record[1399], auc_record[1499],
149 | auc_record[1599], auc_record[1699], auc_record[1799], auc_record[1899],
150 | auc_record[1999]])
151 | CFSS_auc_value = np.array([0., 0., 0., 0., 0.,
152 | 0., 0.02, 0.09, 0.18, 0.30,
153 | 0.45, 0.60, 0.70, 0.75, 0.79,
154 | 0.82, 0.85, 0.87, 0.88, 0.89, 0.90])
155 | SAPM_auc_value = np.array([0., 0., 0., 0., 0.,
156 | 0., 0., 0., 0.02, 0.08,
157 | 0.17, 0.28, 0.43, 0.58, 0.71,
158 | 0.78, 0.83, 0.86, 0.89, 0.91, 0.92])
159 | TCDCN_auc_value = np.array([0., 0., 0., 0., 0.,
160 | 0., 0., 0.02, 0.05, 0.10,
161 | 0.19, 0.29, 0.38, 0.47, 0.56,
162 | 0.64, 0.70, 0.75, 0.79, 0.82, 0.826])
163 |     auc_smooth = make_interp_spline(error, auc_value)(error_new)
164 |     CFSS_auc_smooth = make_interp_spline(error, CFSS_auc_value)(error_new)
165 |     SAPM_auc_smooth = make_interp_spline(error, SAPM_auc_value)(error_new)
166 |     TCDCN_auc_smooth = make_interp_spline(error, TCDCN_auc_value)(error_new)
167 | plt.plot(error_new, auc_smooth, 'r-')
168 | plt.plot(error_new, CFSS_auc_smooth, 'g-')
169 | plt.plot(error_new, SAPM_auc_smooth, 'y-')
170 | plt.plot(error_new, TCDCN_auc_smooth, 'm-')
171 | plt.legend(['LAB, Error: 5.35%, Failure: 4.73%',
172 | 'CFSS, Error: 6.28%, Failure: 9.07%',
173 | 'SAPM, Error: 6.64%, Failure: 5.72%',
174 | 'TCDCN, Error: 7.66%, Failure: 16.17%'], loc=4)
175 | plt.plot(error, auc_value, 'rs')
176 | plt.plot(error, CFSS_auc_value, 'go')
177 | plt.plot(error, SAPM_auc_value, 'y^')
178 | plt.plot(error, TCDCN_auc_value, 'mx')
179 | plt.axis([0., 0.1, 0., 1.])
180 | plt.show()
181 |
--------------------------------------------------------------------------------
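
The accuracys vector returned by calc_auc is exactly the 2000-point auc_record that eval_CED expects, so the two plug together directly; placeholder errors again:

    import numpy as np
    from utils.train_eval_utils import calc_auc
    from utils.visual import eval_CED

    errors = np.abs(np.random.randn(2500)) * 0.03
    _, accuracys = calc_auc('WFLW', 'test', errors, 0.1)
    eval_CED(accuracys)   # overlays this run on the hard-coded CFSS/SAPM/TCDCN curves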