├── .idea
│   ├── face_swap.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── README.md
├── dataset
│   ├── __pycache__
│   │   └── reader.cpython-37.pyc
│   ├── reader.py
│   ├── video_clips_path.txt
│   └── video_landmarks_path.txt
├── face-alignment_projection
│   ├── conda
│   │   ├── conda_upload.sh
│   │   └── meta.yaml
│   ├── docs
│   │   └── images
│   │       ├── 2dlandmarks.png
│   │       └── face-alignment-adrian.gif
│   ├── examples
│   │   └── detect_landmarks_in_image.py
│   ├── face_alignment
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── api.cpython-36.pyc
│   │   │   ├── models.cpython-36.pyc
│   │   │   └── utils.cpython-36.pyc
│   │   ├── api.py
│   │   ├── detection
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── core.cpython-36.pyc
│   │   │   ├── core.py
│   │   │   ├── dlib
│   │   │   │   ├── __init__.py
│   │   │   │   └── dlib_detector.py
│   │   │   ├── folder
│   │   │   │   ├── __init__.py
│   │   │   │   └── folder_detector.py
│   │   │   └── sfd
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   ├── __init__.cpython-36.pyc
│   │   │       │   ├── bbox.cpython-36.pyc
│   │   │       │   ├── detect.cpython-36.pyc
│   │   │       │   ├── net_s3fd.cpython-36.pyc
│   │   │       │   └── sfd_detector.cpython-36.pyc
│   │   │       ├── bbox.py
│   │   │       ├── detect.py
│   │   │       ├── net_s3fd.py
│   │   │       └── sfd_detector.py
│   │   ├── models.py
│   │   └── utils.py
│   ├── get_landmarks.py
│   ├── songyiren.jpg
│   └── test
│       ├── assets
│       │   └── aflw-test.jpg
│       ├── facealignment_test.py
│       ├── smoke_test.py
│       └── test_utils.py
├── model
│   ├── D.py
│   ├── E.py
│   ├── FaceSwapModel.py
│   ├── G.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── D.cpython-36.pyc
│   │   ├── D.cpython-37.pyc
│   │   ├── E.cpython-36.pyc
│   │   ├── E.cpython-37.pyc
│   │   ├── FaceSwapModel.cpython-37.pyc
│   │   ├── G.cpython-36.pyc
│   │   ├── G.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── loss.cpython-36.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── resblocks.cpython-36.pyc
│   │   ├── resblocks.cpython-37.pyc
│   │   ├── vgg_face.cpython-36.pyc
│   │   └── vgg_face.cpython-37.pyc
│   ├── loss.py
│   ├── resblocks.py
│   └── vgg_face.py
├── train.py
├── training_visual
│   ├── temp_fake_gt_landmark_100.jpg
│   ├── temp_fake_gt_landmark_1000.jpg
│   ├── temp_fake_gt_landmark_1100.jpg
│   ├── temp_fake_gt_landmark_1200.jpg
│   ├── temp_fake_gt_landmark_1300.jpg
│   ├── temp_fake_gt_landmark_1400.jpg
│   ├── temp_fake_gt_landmark_1500.jpg
│   ├── temp_fake_gt_landmark_1600.jpg
│   ├── temp_fake_gt_landmark_1700.jpg
│   ├── temp_fake_gt_landmark_1800.jpg
│   ├── temp_fake_gt_landmark_1900.jpg
│   ├── temp_fake_gt_landmark_200.jpg
│   ├── temp_fake_gt_landmark_2000.jpg
│   ├── temp_fake_gt_landmark_2100.jpg
│   ├── temp_fake_gt_landmark_2200.jpg
│   ├── temp_fake_gt_landmark_2300.jpg
│   ├── temp_fake_gt_landmark_300.jpg
│   ├── temp_fake_gt_landmark_400.jpg
│   ├── temp_fake_gt_landmark_500.jpg
│   ├── temp_fake_gt_landmark_600.jpg
│   ├── temp_fake_gt_landmark_700.jpg
│   ├── temp_fake_gt_landmark_800.jpg
│   └── temp_fake_gt_landmark_900.jpg
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Few-Shot-Adversarial-Learning-for-face-swap
2 | This is an unofficial re-implementation of the paper "Few-Shot Adversarial Learning of Realistic Neural Talking Head Models", based on PyTorch.
3 |
4 | # Description
5 | The paper, from the Samsung AI lab, presents a novel and efficient method for face swapping with impressive performance. I am a student interested in this paper, so I tried to reproduce it.
6 | I have written the script for building the landmark dataset and the dataloader pipeline, and I have constructed the whole network, including the Embedder, Generator, and Discriminator with their loss functions. Everything follows the paper.
7 | However, because the paper does not fully specify the network, there are probably some mistakes in the Generator, especially in the adaptive instance normalization (a minimal AdaIN sketch is included after this README). The training process currently produces weird results, which you can inspect in the "training_visual" folder.
8 | **To gain a better understanding of this amazing work, I am open-sourcing this project and invite anyone who is interested to become a contributor.** If someone reproduces it successfully in the future, please let me know. Thanks!
9 |
10 | # How to use
11 |
12 | 1. Get landmark information.
13 |
14 | Download [VoxCeleb (about 36 GB)](http://www.robots.ox.ac.uk/~vgg/research/CMBiometrics/data/dense-face-frames.tar.gz). After unzipping the dataset, rename its root directory to "voxcelb1", or adapt "./face-alignment_projection/get_landmarks.py" to your environment.
15 | > python get_landmarks.py
16 |
17 | For a quick test you do not need to preprocess all of the VoxCeleb data, because it is time-consuming; I use only about 200 video clips. Two files are generated: "video_clips_path.txt" records all clips that have been preprocessed, and "video_landmarks_path.txt" records the paths of the landmark files.
18 |
19 | 2. Download the VGGFace weights for the perceptual loss.
20 |
21 | You can download them from [here](http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/vgg_m_face_bn_dag.py), listed [on this website](http://www.robots.ox.ac.uk/~albanie/pytorch-models.html).
22 | Make sure to put the weight file under the "./pretrained" directory.
23 |
24 | 3. Train the whole network.
25 |
26 | > python ./train.py
27 |
28 | Every 100 iterations, you can inspect temporary results (fake image, ground truth, and landmarks) under the "./training_visual" directory.
29 |
30 | # Cite
31 | [1] Few-Shot Adversarial Learning of Realistic Neural Talking Head Models
32 |
33 | [2] The face-alignment code is from [this great work](https://github.com/1adrianb/face-alignment)
34 |
--------------------------------------------------------------------------------
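
The README above points to the adaptive instance normalization (AdaIN) in the Generator as a likely source of error. For reference, here is a minimal AdaIN sketch in PyTorch; the tensor names and the way the style statistics are supplied are assumptions for illustration, not the repository's actual `G.py` implementation.

```python
import torch


def adain(content, style_mean, style_std, eps=1e-5):
    """Adaptive instance normalization: normalize each (N, C) feature map of
    `content` to zero mean / unit std, then re-scale and re-shift it with the
    style statistics (e.g. predicted from the embedding vector)."""
    n, c = content.shape[:2]
    feat = content.view(n, c, -1)
    mean = feat.mean(dim=2).view(n, c, 1, 1)
    std = feat.std(dim=2).view(n, c, 1, 1) + eps
    normalized = (content - mean) / std
    return normalized * style_std.view(n, c, 1, 1) + style_mean.view(n, c, 1, 1)


# Toy usage: a 4-channel feature map styled by per-channel statistics.
x = torch.randn(1, 4, 8, 8)
style_mean = torch.randn(1, 4)
style_std = torch.rand(1, 4) + 0.5
y = adain(x, style_mean, style_std)
print(y.shape)  # torch.Size([1, 4, 8, 8])
```
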
/dataset/__pycache__/reader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/dataset/__pycache__/reader.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/reader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset,DataLoader
2 | import torch
3 | import os
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 | from torchvision.transforms import transforms
8 | from torchvision.transforms import functional as F
9 | import random
10 |
11 | class Reader(Dataset):
12 | def __init__(self,clip_txt,transform):
13 | super(Reader,self).__init__()
14 | f = open(clip_txt, 'r')
15 | clip_list = f.readlines()
16 | f.close()
17 | self.clip_list = {}
18 | self.landmarks_list = []
19 | for c in clip_list:
20 | c = c.strip()
21 |             d = c.split('data')  # e.g. ['/.../', 'set/voxcelb1/', '/<id>/<clip>'] for a '.../dataset/voxcelb1/data/...' path
22 |             d = d[0] + 'dataset/voxcelb1/' 'landmarks' + d[2]  # adjacent string literals concatenate: swap 'data' for 'landmarks'
23 | imgs = os.listdir(c)
24 | self.clip_list[c] = imgs
25 | self.landmarks_list.append(d)
26 | self.trans = transform
27 | self.flip = F.hflip
28 | def __getitem__(self, index):
29 | landmark_path = self.landmarks_list[index]
30 | clip_path = landmark_path.split('landmarks')
31 | clip_path = clip_path[0] + 'data' + clip_path[1]
32 | clip_imgs = self.clip_list[clip_path]
33 | np.random.shuffle(clip_imgs)
34 |         imgs_for_e_path = clip_imgs[:8]  # first 8 frames feed the embedder
35 |         imgs_for_training_path = clip_imgs[8:]  # remaining frames are used as training targets
36 | imgs_for_e = []
37 | landmarks_for_e = []
38 | imgs_for_training = []
39 | landmarks_for_training = []
40 | for p in imgs_for_e_path:
41 | img = Image.open(os.path.join(clip_path,p)).convert('RGB')
42 | landmark = Image.open(os.path.join(landmark_path,p)).convert('RGB')
43 | is_flip = random.random() > 0.5
44 | if is_flip:
45 | img = self.flip(img)
46 | landmark = self.flip(landmark)
47 | img = self.trans(img)
48 | landmark = self.trans(landmark)
49 | imgs_for_e.append(img)
50 | landmarks_for_e.append(landmark)
51 | for p in imgs_for_training_path:
52 | img = Image.open(os.path.join(clip_path, p)).convert('RGB')
53 | landmark = Image.open(os.path.join(landmark_path, p)).convert('RGB')
54 | is_flip = random.random() > 0.5
55 | if is_flip:
56 | img = self.flip(img)
57 | landmark = self.flip(landmark)
58 | img = self.trans(img)
59 | landmark = self.trans(landmark)
60 | imgs_for_training.append(img)
61 | landmarks_for_training.append(landmark)
62 | return {'imgs_e':imgs_for_e,
63 | 'landmarks_e':landmarks_for_e,
64 | 'imgs_training':imgs_for_training,
65 | 'landmarks_training':landmarks_for_training
66 | }
67 | def __len__(self):
68 | return len(self.landmarks_list)
69 |
70 | def get_loader(clip_txt,batchsize,num_workers):
71 | trans = transforms.Compose([
72 | transforms.Resize((256,256)),
73 | transforms.ToTensor(),
74 | transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
75 | ])
76 | dataset = Reader(clip_txt,trans)
77 | loader = DataLoader(dataset,batchsize,pin_memory=True,drop_last=True,num_workers=num_workers)
78 | return loader, len(dataset)
79 | if __name__=='__main__':
80 | clip_txt = 'video_clips_path.txt'
81 | batchsize = 1
82 | num_workers = 0
83 | loader,_ = get_loader(clip_txt,batchsize,num_workers)
84 | for d in loader:
85 | for k, v in d.items():
86 | print(k)
87 | print(type(v))
88 | print(type(v[0]))
89 | print(v[0].size()) # torch.Size([1, 3, 256, 256])
90 |
91 | print('**********')
92 |
--------------------------------------------------------------------------------
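
The path handling in `Reader.__init__` above relies on the directory layout described in the README (a "voxcelb1" root containing sibling "data" and "landmarks" folders). A small sketch of that mapping, using a hypothetical clip path, may make the string surgery easier to follow:

```python
# Hypothetical clip path; only the '.../dataset/voxcelb1/data/...' layout matters.
clip_path = '/home/user/dataset/voxcelb1/data/id00012/clip_000'

# Reader.__init__ splits on the literal substring 'data' ...
parts = clip_path.split('data')
# parts == ['/home/user/', 'set/voxcelb1/', '/id00012/clip_000']

# ... and rebuilds the matching landmark path by swapping 'data' for 'landmarks'.
landmark_path = parts[0] + 'dataset/voxcelb1/' + 'landmarks' + parts[2]
print(landmark_path)  # /home/user/dataset/voxcelb1/landmarks/id00012/clip_000
```
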
/face-alignment_projection/conda/conda_upload.sh:
--------------------------------------------------------------------------------
1 | PCG_NAME=face_alignment
2 | USER=1adrianb
3 |
4 | mkdir ~/conda-build
5 | conda config --set anaconda_upload no
6 | conda build conda/
7 | anaconda -t $CONDA_UPLOAD_TOKEN upload -u $USER /home/travis/miniconda/envs/test-environment/conda-bld/noarch/face_alignment-1.0.1-py_1.tar.bz2 --force
--------------------------------------------------------------------------------
/face-alignment_projection/conda/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set version = "1.0.1" %}
2 |
3 | package:
4 | name: face_alignment
5 | version: {{ version }}
6 |
7 | source:
8 | path: ..
9 |
10 | build:
11 | number: 1
12 | noarch: python
13 | script: python setup.py install --single-version-externally-managed --record=record.txt
14 |
15 | requirements:
16 | build:
17 | - setuptools
18 | - python
19 | run:
20 | - python
21 | - pytorch
22 | - numpy
23 | - scikit-image
24 | - scipy
25 | - opencv
26 | - tqdm
27 |
28 | about:
29 | home: https://github.com/1adrianb/face-alignment
30 | license: BSD
31 | license_file: LICENSE
32 |   summary: A 2D and 3D face alignment library in Python
33 |
34 | extra:
35 | recipe-maintainers:
36 | - 1adrianb
37 |
--------------------------------------------------------------------------------
/face-alignment_projection/docs/images/2dlandmarks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/docs/images/2dlandmarks.png
--------------------------------------------------------------------------------
/face-alignment_projection/docs/images/face-alignment-adrian.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/docs/images/face-alignment-adrian.gif
--------------------------------------------------------------------------------
/face-alignment_projection/examples/detect_landmarks_in_image.py:
--------------------------------------------------------------------------------
1 | import face_alignment
2 | import numpy as np
3 | from mpl_toolkits.mplot3d import Axes3D
4 | import matplotlib.pyplot as plt
5 | from skimage import io
6 |
7 | # Run the 3D face alignment on a test image, with CUDA.
8 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cuda:0', flip_input=True)
9 |
10 | input = io.imread('../test/assets/aflw-test.jpg')
11 | preds = fa.get_landmarks(input)[-1]
12 |
13 | #TODO: Make this nice
14 | fig = plt.figure(figsize=plt.figaspect(.5))
15 | ax = fig.add_subplot(1, 2, 1)
16 | ax.imshow(input)
17 | ax.plot(preds[0:17,0],preds[0:17,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
18 | ax.plot(preds[17:22,0],preds[17:22,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
19 | ax.plot(preds[22:27,0],preds[22:27,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
20 | ax.plot(preds[27:31,0],preds[27:31,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
21 | ax.plot(preds[31:36,0],preds[31:36,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
22 | ax.plot(preds[36:42,0],preds[36:42,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
23 | ax.plot(preds[42:48,0],preds[42:48,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
24 | ax.plot(preds[48:60,0],preds[48:60,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
25 | ax.plot(preds[60:68,0],preds[60:68,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
26 | ax.axis('off')
27 |
28 | ax = fig.add_subplot(1, 2, 2, projection='3d')
29 | surf = ax.scatter(preds[:,0]*1.2,preds[:,1],preds[:,2],c="cyan", alpha=1.0, edgecolor='b')
30 | ax.plot3D(preds[:17,0]*1.2,preds[:17,1], preds[:17,2], color='blue' )
31 | ax.plot3D(preds[17:22,0]*1.2,preds[17:22,1],preds[17:22,2], color='blue')
32 | ax.plot3D(preds[22:27,0]*1.2,preds[22:27,1],preds[22:27,2], color='blue')
33 | ax.plot3D(preds[27:31,0]*1.2,preds[27:31,1],preds[27:31,2], color='blue')
34 | ax.plot3D(preds[31:36,0]*1.2,preds[31:36,1],preds[31:36,2], color='blue')
35 | ax.plot3D(preds[36:42,0]*1.2,preds[36:42,1],preds[36:42,2], color='blue')
36 | ax.plot3D(preds[42:48,0]*1.2,preds[42:48,1],preds[42:48,2], color='blue')
37 | ax.plot3D(preds[48:,0]*1.2,preds[48:,1],preds[48:,2], color='blue' )
38 |
39 | ax.view_init(elev=90., azim=90.)
40 | ax.set_xlim(ax.get_xlim()[::-1])
41 | plt.show()
42 |
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | __author__ = """Adrian Bulat"""
4 | __email__ = 'adrian.bulat@nottingham.ac.uk'
5 | __version__ = '1.0.1'
6 |
7 | from .api import FaceAlignment, LandmarksType, NetworkSize
8 |
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/__pycache__/api.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/__pycache__/api.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/api.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from torch.utils.model_zoo import load_url
5 | from enum import Enum
6 | from skimage import io
7 | from skimage import color
8 | import numpy as np
9 | import cv2
10 | try:
11 | import urllib.request as request_file
12 | except BaseException:
13 | import urllib as request_file
14 |
15 | from .models import FAN, ResNetDepth
16 | from .utils import *
17 |
18 |
19 | class LandmarksType(Enum):
20 | """Enum class defining the type of landmarks to detect.
21 |
22 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
23 |     ``_2halfD`` - these points represent the projection of the 3D points into a 2D space
24 |     ``_3D`` - detect the points ``(x,y,z)`` in a 3D space
25 |
26 | """
27 | _2D = 1
28 | _2halfD = 2
29 | _3D = 3
30 |
31 |
32 | class NetworkSize(Enum):
33 | # TINY = 1
34 | # SMALL = 2
35 | # MEDIUM = 3
36 | LARGE = 4
37 |
38 | def __new__(cls, value):
39 | member = object.__new__(cls)
40 | member._value_ = value
41 | return member
42 |
43 | def __int__(self):
44 | return self.value
45 |
46 | models_urls = {
47 | '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-11f355bf06.pth.tar',
48 | '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-7835d9f11d.pth.tar',
49 | 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-2a464da4ea.pth.tar',
50 | }
51 |
52 |
53 | class FaceAlignment:
54 | def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
55 | device='cuda', flip_input=False, face_detector='sfd', verbose=False):
56 | self.device = device
57 | self.flip_input = flip_input
58 | self.landmarks_type = landmarks_type
59 | self.verbose = verbose
60 |
61 | network_size = int(network_size)
62 |
63 | if 'cuda' in device:
64 | torch.backends.cudnn.benchmark = True
65 |
66 | # Get the face detector
67 | face_detector_module = __import__('face_alignment.detection.' + face_detector,
68 | globals(), locals(), [face_detector], 0)
69 | self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
70 |
71 |         # Initialise the face alignment networks
72 | self.face_alignment_net = FAN(network_size)
73 | if landmarks_type == LandmarksType._2D:
74 | network_name = '2DFAN-' + str(network_size)
75 | else:
76 | network_name = '3DFAN-' + str(network_size)
77 |
78 | fan_weights = load_url(models_urls[network_name], map_location=lambda storage, loc: storage)
79 | self.face_alignment_net.load_state_dict(fan_weights)
80 |
81 | self.face_alignment_net.to(device)
82 | self.face_alignment_net.eval()
83 |
84 |         # Initialise the depth prediction network
85 | if landmarks_type == LandmarksType._3D:
86 | self.depth_prediciton_net = ResNetDepth()
87 |
88 | depth_weights = load_url(models_urls['depth'], map_location=lambda storage, loc: storage)
89 | depth_dict = {
90 | k.replace('module.', ''): v for k,
91 | v in depth_weights['state_dict'].items()}
92 | self.depth_prediciton_net.load_state_dict(depth_dict)
93 |
94 | self.depth_prediciton_net.to(device)
95 | self.depth_prediciton_net.eval()
96 |
97 | def get_landmarks(self, image_or_path, detected_faces=None):
98 | """Deprecated, please use get_landmarks_from_image
99 |
100 | Arguments:
101 | image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it.
102 |
103 | Keyword Arguments:
104 | detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
105 | in the image (default: {None})
106 | """
107 | return self.get_landmarks_from_image(image_or_path, detected_faces)
108 |
109 | def get_landmarks_from_image(self, image_or_path, detected_faces=None):
110 | """Predict the landmarks for each face present in the image.
111 |
112 |         This function predicts a set of 68 2D or 3D landmark points, one set for each face present in the image.
113 |         If detected_faces is None the method will also run a face detector.
114 |
115 | Arguments:
116 | image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it.
117 |
118 | Keyword Arguments:
119 | detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
120 | in the image (default: {None})
121 | """
122 | if isinstance(image_or_path, str):
123 | try:
124 | image = io.imread(image_or_path)
125 | except IOError:
126 | print("error opening file :: ", image_or_path)
127 | return None
128 | else:
129 | image = image_or_path
130 |
131 | if image.ndim == 2:
132 | image = color.gray2rgb(image)
133 | elif image.ndim == 4:
134 | image = image[..., :3]
135 |
136 | if detected_faces is None:
137 | detected_faces = self.face_detector.detect_from_image(image[..., ::-1].copy())
138 |
139 | if len(detected_faces) == 0:
140 | print("Warning: No faces were detected.")
141 | return None
142 |
143 | torch.set_grad_enabled(False)
144 | landmarks = []
145 | for i, d in enumerate(detected_faces):
146 | if (d[2]-d[0])*(d[3]-d[1])<75*75:
147 | continue
148 | center = torch.FloatTensor(
149 | [d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
150 | center[1] = center[1] - (d[3] - d[1]) * 0.12
151 | scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
152 |
153 | inp = crop(image, center, scale)
154 | inp = torch.from_numpy(inp.transpose(
155 | (2, 0, 1))).float()
156 |
157 | inp = inp.to(self.device)
158 | inp.div_(255.0).unsqueeze_(0)
159 |
160 | out = self.face_alignment_net(inp)[-1].detach()
161 | if self.flip_input:
162 | out += flip(self.face_alignment_net(flip(inp))
163 | [-1].detach(), is_label=True)
164 | out = out.cpu()
165 |
166 | pts, pts_img = get_preds_fromhm(out, center, scale)
167 | pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)
168 |
169 | if self.landmarks_type == LandmarksType._3D:
170 | heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
171 | for i in range(68):
172 | if pts[i, 0] > 0:
173 | heatmaps[i] = draw_gaussian(
174 | heatmaps[i], pts[i], 2)
175 | heatmaps = torch.from_numpy(
176 | heatmaps).unsqueeze_(0)
177 |
178 | heatmaps = heatmaps.to(self.device)
179 | depth_pred = self.depth_prediciton_net(
180 | torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
181 | pts_img = torch.cat(
182 | (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
183 |
184 | landmarks.append(pts_img.numpy())
185 |
186 | return landmarks
187 |
188 | def get_landmarks_from_directory(self, path, extensions=['.jpg', '.png'], recursive=True, show_progress_bar=True):
189 | detected_faces = self.face_detector.detect_from_directory(path, extensions, recursive, show_progress_bar)
190 |
191 | predictions = {}
192 | for image_path, bounding_boxes in detected_faces.items():
193 | image = io.imread(image_path)
194 | preds = self.get_landmarks_from_image(image, bounding_boxes)
195 | predictions[image_path] = preds
196 |
197 | return predictions
198 |
199 | @staticmethod
200 | def remove_models(self):
201 | base_path = os.path.join(appdata_dir('face_alignment'), "data")
202 | for data_model in os.listdir(base_path):
203 | file_path = os.path.join(base_path, data_model)
204 | try:
205 | if os.path.isfile(file_path):
206 | print('Removing ' + data_model + ' ...')
207 | os.unlink(file_path)
208 | except Exception as e:
209 | print(e)
210 |
--------------------------------------------------------------------------------
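
`FaceAlignment.get_landmarks_from_image` is the entry point used when building the landmark dataset. Below is a minimal usage sketch, assuming the pretrained FAN weights can be downloaded and using the test asset shipped with the repository; note that this repository's modified `api.py` skips faces smaller than roughly 75x75 pixels.

```python
import face_alignment
from skimage import io

# 2D landmarks on CPU; pass device='cuda' for GPU inference.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu')

image = io.imread('face-alignment_projection/test/assets/aflw-test.jpg')
landmarks = fa.get_landmarks_from_image(image)

# Returns None if no (sufficiently large) face is found, otherwise a list of
# (68, 2) numpy arrays, one per detected face.
if landmarks is not None:
    print(landmarks[0].shape)  # (68, 2)
```
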
/face-alignment_projection/face_alignment/detection/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FaceDetector
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/__pycache__/core.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/__pycache__/core.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/core.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import glob
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | import cv2
7 | from skimage import io
8 |
9 |
10 | class FaceDetector(object):
11 | """An abstract class representing a face detector.
12 |
13 | Any other face detection implementation must subclass it. All subclasses
14 |     must implement ``detect_from_image``, which returns a list of detected
15 |     bounding boxes. Optionally, for speed considerations, detecting from a path is
16 | recommended.
17 | """
18 |
19 | def __init__(self, device, verbose):
20 | self.device = device
21 | self.verbose = verbose
22 |
23 | if verbose:
24 | if 'cpu' in device:
25 | logger = logging.getLogger(__name__)
26 | logger.warning("Detection running on CPU, this may be potentially slow.")
27 |
28 | if 'cpu' not in device and 'cuda' not in device:
29 | if verbose:
30 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
31 | raise ValueError
32 |
33 | def detect_from_image(self, tensor_or_path):
34 | """Detects faces in a given image.
35 |
36 |         This function detects the faces present in a provided (usually BGR)
37 | image. The input can be either the image itself or the path to it.
38 |
39 | Arguments:
40 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
41 | to an image or the image itself.
42 |
43 | Example::
44 |
45 | >>> path_to_image = 'data/image_01.jpg'
46 | ... detected_faces = detect_from_image(path_to_image)
47 | [A list of bounding boxes (x1, y1, x2, y2)]
48 | >>> image = cv2.imread(path_to_image)
49 | ... detected_faces = detect_from_image(image)
50 | [A list of bounding boxes (x1, y1, x2, y2)]
51 |
52 | """
53 | raise NotImplementedError
54 |
55 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
56 | """Detects faces from all the images present in a given directory.
57 |
58 | Arguments:
59 | path {string} -- a string containing a path that points to the folder containing the images
60 |
61 | Keyword Arguments:
62 |             extensions {list} -- list of strings containing the extensions to be
63 |                 considered, in the following format: ``.extension_name``
64 |                 (default: {['.jpg', '.png']})
65 |             recursive {bool} -- whether to scan the folder recursively (default: {False})
66 |             show_progress_bar {bool} -- display a progress bar (default: {True})
67 |
68 | Example:
69 | >>> directory = 'data'
70 | ... detected_faces = detect_from_directory(directory)
71 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
72 |
73 | """
74 | if self.verbose:
75 | logger = logging.getLogger(__name__)
76 |
77 | if len(extensions) == 0:
78 | if self.verbose:
79 |                 logger.error("Expected at least one extension, but none was received.")
80 | raise ValueError
81 |
82 | if self.verbose:
83 | logger.info("Constructing the list of images.")
84 | additional_pattern = '/**/*' if recursive else '/*'
85 | files = []
86 | for extension in extensions:
87 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
88 |
89 | if self.verbose:
90 | logger.info("Finished searching for images. %s images found", len(files))
91 | logger.info("Preparing to run the detection.")
92 |
93 | predictions = {}
94 | for image_path in tqdm(files, disable=not show_progress_bar):
95 | if self.verbose:
96 | logger.info("Running the face detector on image: %s", image_path)
97 | predictions[image_path] = self.detect_from_image(image_path)
98 |
99 | if self.verbose:
100 | logger.info("The detector was successfully run on all %s images", len(files))
101 |
102 | return predictions
103 |
104 | @property
105 | def reference_scale(self):
106 | raise NotImplementedError
107 |
108 | @property
109 | def reference_x_shift(self):
110 | raise NotImplementedError
111 |
112 | @property
113 | def reference_y_shift(self):
114 | raise NotImplementedError
115 |
116 | @staticmethod
117 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
118 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
119 |
120 | Arguments:
121 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
122 | """
123 | if isinstance(tensor_or_path, str):
124 | return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path)
125 | elif torch.is_tensor(tensor_or_path):
126 | # Call cpu in case its coming from cuda
127 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
128 | elif isinstance(tensor_or_path, np.ndarray):
129 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
130 | else:
131 | raise TypeError
132 |
--------------------------------------------------------------------------------
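
The `FaceDetector` docstring above describes the subclassing contract. A minimal, hypothetical subclass that satisfies it, always reporting one full-frame box, might look like this:

```python
from face_alignment.detection.core import FaceDetector


class WholeImageDetector(FaceDetector):
    """Toy detector: always reports a single box covering the whole image."""

    def __init__(self, device='cpu', verbose=False):
        super(WholeImageDetector, self).__init__(device, verbose)

    def detect_from_image(self, tensor_or_path):
        image = self.tensor_or_path_to_ndarray(tensor_or_path)
        h, w = image.shape[:2]
        return [[0, 0, w, h]]  # one (x1, y1, x2, y2) bounding box

    @property
    def reference_scale(self):
        return 195

    @property
    def reference_x_shift(self):
        return 0

    @property
    def reference_y_shift(self):
        return 0
```
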
/face-alignment_projection/face_alignment/detection/dlib/__init__.py:
--------------------------------------------------------------------------------
1 | from .dlib_detector import DlibDetector as FaceDetector
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/dlib/dlib_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import dlib
4 |
5 | try:
6 | import urllib.request as request_file
7 | except BaseException:
8 | import urllib as request_file
9 |
10 | from ..core import FaceDetector
11 | from ...utils import appdata_dir
12 |
13 |
14 | class DlibDetector(FaceDetector):
15 | def __init__(self, device, path_to_detector=None, verbose=False):
16 | super().__init__(device, verbose)
17 |
18 |         print('Warning: this detector is deprecated. Please use a different one, e.g. S3FD.')
19 | base_path = os.path.join(appdata_dir('face_alignment'), "data")
20 |
21 | # Initialise the face detector
22 | if 'cuda' in device:
23 | if path_to_detector is None:
24 | path_to_detector = os.path.join(
25 | base_path, "mmod_human_face_detector.dat")
26 |
27 | if not os.path.isfile(path_to_detector):
28 | print("Downloading the face detection CNN. Please wait...")
29 |
30 | path_to_temp_detector = os.path.join(
31 | base_path, "mmod_human_face_detector.dat.download")
32 |
33 | if os.path.isfile(path_to_temp_detector):
34 | os.remove(os.path.join(path_to_temp_detector))
35 |
36 | request_file.urlretrieve(
37 | "https://www.adrianbulat.com/downloads/dlib/mmod_human_face_detector.dat",
38 | os.path.join(path_to_temp_detector))
39 |
40 | os.rename(os.path.join(path_to_temp_detector), os.path.join(path_to_detector))
41 |
42 | self.face_detector = dlib.cnn_face_detection_model_v1(path_to_detector)
43 | else:
44 | self.face_detector = dlib.get_frontal_face_detector()
45 |
46 | def detect_from_image(self, tensor_or_path):
47 | image = self.tensor_or_path_to_ndarray(tensor_or_path, rgb=False)
48 |
49 | detected_faces = self.face_detector(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
50 |
51 | if 'cuda' not in self.device:
52 | detected_faces = [[d.left(), d.top(), d.right(), d.bottom()] for d in detected_faces]
53 | else:
54 | detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces]
55 |
56 | return detected_faces
57 |
58 | @property
59 | def reference_scale(self):
60 | return 195
61 |
62 | @property
63 | def reference_x_shift(self):
64 | return 0
65 |
66 | @property
67 | def reference_y_shift(self):
68 | return 0
69 |
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/folder/__init__.py:
--------------------------------------------------------------------------------
1 | from .folder_detector import FolderDetector as FaceDetector
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/folder/folder_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 | from ..core import FaceDetector
6 |
7 |
8 | class FolderDetector(FaceDetector):
9 | '''This is a simple helper module that assumes the faces were detected already
10 | (either previously or are provided as ground truth).
11 |
12 | The class expects to find the bounding boxes in the same format used by
13 |     the rest of face detectors, namely ``list[(x1,y1,x2,y2),...]``.
14 | For each image the detector will search for a file with the same name and with one of the
15 | following extensions: .npy, .t7 or .pth
16 |
17 | '''
18 |
19 | def __init__(self, device, path_to_detector=None, verbose=False):
20 | super(FolderDetector, self).__init__(device, verbose)
21 |
22 | def detect_from_image(self, tensor_or_path):
23 | # Only strings supported
24 | if not isinstance(tensor_or_path, str):
25 | raise ValueError
26 |
27 | base_name = os.path.splitext(tensor_or_path)[0]
28 |
29 | if os.path.isfile(base_name + '.npy'):
30 | detected_faces = np.load(base_name + '.npy')
31 | elif os.path.isfile(base_name + '.t7'):
32 | detected_faces = torch.load(base_name + '.t7')
33 | elif os.path.isfile(base_name + '.pth'):
34 | detected_faces = torch.load(base_name + '.pth')
35 | else:
36 | raise FileNotFoundError
37 |
38 | if not isinstance(detected_faces, list):
39 | raise TypeError
40 |
41 | return detected_faces
42 |
43 | @property
44 | def reference_scale(self):
45 | return 195
46 |
47 | @property
48 | def reference_x_shift(self):
49 | return 0
50 |
51 | @property
52 | def reference_y_shift(self):
53 | return 0
54 |
--------------------------------------------------------------------------------
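
As the docstring explains, `FolderDetector` looks for a sidecar file with the same base name as the image. A small sketch of preparing such a file (the image name is hypothetical); storing the boxes as a plain Python list in a `.pth` file keeps the `isinstance(..., list)` check above happy:

```python
import torch

# Boxes for 'frame_0001.jpg' in the usual (x1, y1, x2, y2) format.
boxes = [[30, 40, 210, 250]]
torch.save(boxes, 'frame_0001.pth')

# FolderDetector('cpu').detect_from_image('frame_0001.jpg') will now load
# and return this list instead of running a real detector.
```
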
/face-alignment_projection/face_alignment/detection/sfd/__init__.py:
--------------------------------------------------------------------------------
1 | from .sfd_detector import SFDDetector as FaceDetector
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/sfd/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/sfd/__pycache__/bbox.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/sfd/__pycache__/detect.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/sfd/__pycache__/net_s3fd.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/face_alignment/detection/sfd/__pycache__/sfd_detector.cpython-36.pyc
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/bbox.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import cv2
5 | import random
6 | import datetime
7 | import time
8 | import math
9 | import argparse
10 | import numpy as np
11 | import torch
12 |
13 | try:
14 | from iou import IOU
15 | except BaseException:
16 | # IOU cython speedup 10x
17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18 | sa = abs((ax2 - ax1) * (ay2 - ay1))
19 | sb = abs((bx2 - bx1) * (by2 - by1))
20 | x1, y1 = max(ax1, bx1), max(ay1, by1)
21 | x2, y2 = min(ax2, bx2), min(ay2, by2)
22 | w = x2 - x1
23 | h = y2 - y1
24 | if w < 0 or h < 0:
25 | return 0.0
26 | else:
27 | return 1.0 * w * h / (sa + sb - w * h)
28 |
29 |
30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33 | dw, dh = math.log(ww / aww), math.log(hh / ahh)
34 | return dx, dy, dw, dh
35 |
36 |
37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38 | xc, yc = dx * aww + axc, dy * ahh + ayc
39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41 | return x1, y1, x2, y2
42 |
43 |
44 | def nms(dets, thresh):
45 | if 0 == len(dets):
46 | return []
47 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49 | order = scores.argsort()[::-1]
50 |
51 | keep = []
52 | while order.size > 0:
53 | i = order[0]
54 | keep.append(i)
55 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57 |
58 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60 |
61 | inds = np.where(ovr <= thresh)[0]
62 | order = order[inds + 1]
63 |
64 | return keep
65 |
66 |
67 | def encode(matched, priors, variances):
68 | """Encode the variances from the priorbox layers into the ground truth boxes
69 | we have matched (based on jaccard overlap) with the prior boxes.
70 | Args:
71 | matched: (tensor) Coords of ground truth for each prior in point-form
72 | Shape: [num_priors, 4].
73 | priors: (tensor) Prior boxes in center-offset form
74 | Shape: [num_priors,4].
75 | variances: (list[float]) Variances of priorboxes
76 | Return:
77 | encoded boxes (tensor), Shape: [num_priors, 4]
78 | """
79 |
80 | # dist b/t match center and prior's center
81 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82 | # encode variance
83 | g_cxcy /= (variances[0] * priors[:, 2:])
84 | # match wh / prior wh
85 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86 | g_wh = torch.log(g_wh) / variances[1]
87 | # return target for smooth_l1_loss
88 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89 |
90 |
91 | def decode(loc, priors, variances):
92 | """Decode locations from predictions using priors to undo
93 | the encoding we did for offset regression at train time.
94 | Args:
95 | loc (tensor): location predictions for loc layers,
96 | Shape: [num_priors,4]
97 | priors (tensor): Prior boxes in center-offset form.
98 | Shape: [num_priors,4].
99 | variances: (list[float]) Variances of priorboxes
100 | Return:
101 | decoded bounding box predictions
102 | """
103 |
104 | boxes = torch.cat((
105 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107 | boxes[:, :2] -= boxes[:, 2:] / 2
108 | boxes[:, 2:] += boxes[:, :2]
109 | return boxes
110 |
--------------------------------------------------------------------------------
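
The `encode`/`decode` pair above implements the standard SSD-style box parameterization. A quick sanity-check sketch (with made-up numbers) showing that `decode` inverts `encode` for the same priors and variances, assuming the `face_alignment` package is importable:

```python
import torch
from face_alignment.detection.sfd.bbox import encode, decode

priors = torch.tensor([[0.50, 0.50, 0.20, 0.30]])   # (cx, cy, w, h)
boxes = torch.tensor([[0.40, 0.35, 0.60, 0.65]])    # (x1, y1, x2, y2)
variances = [0.1, 0.2]

deltas = encode(boxes, priors, variances)
recovered = decode(deltas, priors, variances)
print(torch.allclose(recovered, boxes, atol=1e-6))  # True
```
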
/face-alignment_projection/face_alignment/detection/sfd/detect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import os
5 | import sys
6 | import cv2
7 | import random
8 | import datetime
9 | import math
10 | import argparse
11 | import numpy as np
12 |
13 | import scipy.io as sio
14 | import zipfile
15 | from .net_s3fd import s3fd
16 | from .bbox import *
17 |
18 |
19 | def detect(net, img, device):
20 | img = img - np.array([104, 117, 123])
21 | img = img.transpose(2, 0, 1)
22 | img = img.reshape((1,) + img.shape)
23 |
24 | if 'cuda' in device:
25 | torch.backends.cudnn.benchmark = True
26 |
27 | img = torch.from_numpy(img).float().to(device)
28 | BB, CC, HH, WW = img.size()
29 | with torch.no_grad():
30 | olist = net(img)
31 |
32 | bboxlist = []
33 | for i in range(len(olist) // 2):
34 | olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35 | olist = [oelem.data.cpu() for oelem in olist]
36 | for i in range(len(olist) // 2):
37 | ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38 | FB, FC, FH, FW = ocls.size() # feature map size
39 | stride = 2**(i + 2) # 4,8,16,32,64,128
40 | anchor = stride * 4
41 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42 | for Iindex, hindex, windex in poss:
43 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44 | score = ocls[0, 1, hindex, windex]
45 | loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47 | variances = [0.1, 0.2]
48 | box = decode(loc, priors, variances)
49 | x1, y1, x2, y2 = box[0] * 1.0
50 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51 | bboxlist.append([x1, y1, x2, y2, score])
52 | bboxlist = np.array(bboxlist)
53 | if 0 == len(bboxlist):
54 | bboxlist = np.zeros((1, 5))
55 |
56 | return bboxlist
57 |
58 |
59 | def flip_detect(net, img, device):
60 | img = cv2.flip(img, 1)
61 | b = detect(net, img, device)
62 |
63 | bboxlist = np.zeros(b.shape)
64 | bboxlist[:, 0] = img.shape[1] - b[:, 2]
65 | bboxlist[:, 1] = b[:, 1]
66 | bboxlist[:, 2] = img.shape[1] - b[:, 0]
67 | bboxlist[:, 3] = b[:, 3]
68 | bboxlist[:, 4] = b[:, 4]
69 | return bboxlist
70 |
71 |
72 | def pts_to_bb(pts):
73 | min_x, min_y = np.min(pts, axis=0)
74 | max_x, max_y = np.max(pts, axis=0)
75 | return np.array([min_x, min_y, max_x, max_y])
76 |
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/net_s3fd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class L2Norm(nn.Module):
7 | def __init__(self, n_channels, scale=1.0):
8 | super(L2Norm, self).__init__()
9 | self.n_channels = n_channels
10 | self.scale = scale
11 | self.eps = 1e-10
12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13 | self.weight.data *= 0.0
14 | self.weight.data += self.scale
15 |
16 | def forward(self, x):
17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18 | x = x / norm * self.weight.view(1, -1, 1, 1)
19 | return x
20 |
21 |
22 | class s3fd(nn.Module):
23 | def __init__(self):
24 | super(s3fd, self).__init__()
25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27 |
28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30 |
31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34 |
35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38 |
39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42 |
43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45 |
46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48 |
49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51 |
52 | self.conv3_3_norm = L2Norm(256, scale=10)
53 | self.conv4_3_norm = L2Norm(512, scale=8)
54 | self.conv5_3_norm = L2Norm(512, scale=5)
55 |
56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62 |
63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69 |
70 | def forward(self, x):
71 | h = F.relu(self.conv1_1(x))
72 | h = F.relu(self.conv1_2(h))
73 | h = F.max_pool2d(h, 2, 2)
74 |
75 | h = F.relu(self.conv2_1(h))
76 | h = F.relu(self.conv2_2(h))
77 | h = F.max_pool2d(h, 2, 2)
78 |
79 | h = F.relu(self.conv3_1(h))
80 | h = F.relu(self.conv3_2(h))
81 | h = F.relu(self.conv3_3(h))
82 | f3_3 = h
83 | h = F.max_pool2d(h, 2, 2)
84 |
85 | h = F.relu(self.conv4_1(h))
86 | h = F.relu(self.conv4_2(h))
87 | h = F.relu(self.conv4_3(h))
88 | f4_3 = h
89 | h = F.max_pool2d(h, 2, 2)
90 |
91 | h = F.relu(self.conv5_1(h))
92 | h = F.relu(self.conv5_2(h))
93 | h = F.relu(self.conv5_3(h))
94 | f5_3 = h
95 | h = F.max_pool2d(h, 2, 2)
96 |
97 | h = F.relu(self.fc6(h))
98 | h = F.relu(self.fc7(h))
99 | ffc7 = h
100 | h = F.relu(self.conv6_1(h))
101 | h = F.relu(self.conv6_2(h))
102 | f6_2 = h
103 | h = F.relu(self.conv7_1(h))
104 | h = F.relu(self.conv7_2(h))
105 | f7_2 = h
106 |
107 | f3_3 = self.conv3_3_norm(f3_3)
108 | f4_3 = self.conv4_3_norm(f4_3)
109 | f5_3 = self.conv5_3_norm(f5_3)
110 |
111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117 | cls4 = self.fc7_mbox_conf(ffc7)
118 | reg4 = self.fc7_mbox_loc(ffc7)
119 | cls5 = self.conv6_2_mbox_conf(f6_2)
120 | reg5 = self.conv6_2_mbox_loc(f6_2)
121 | cls6 = self.conv7_2_mbox_conf(f7_2)
122 | reg6 = self.conv7_2_mbox_loc(f7_2)
123 |
124 | # max-out background label
125 | chunk = torch.chunk(cls1, 4, 1)
126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127 | cls1 = torch.cat([bmax, chunk[3]], dim=1)
128 |
129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
130 |
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/detection/sfd/sfd_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from torch.utils.model_zoo import load_url
4 |
5 | from ..core import FaceDetector
6 |
7 | from .net_s3fd import s3fd
8 | from .bbox import *
9 | from .detect import *
10 |
11 | models_urls = {
12 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13 | }
14 |
15 |
16 | class SFDDetector(FaceDetector):
17 | def __init__(self, device, path_to_detector=None, verbose=False):
18 | super(SFDDetector, self).__init__(device, verbose)
19 |
20 | # Initialise the face detector
21 | if path_to_detector is None:
22 | model_weights = load_url(models_urls['s3fd'])
23 | else:
24 | model_weights = torch.load(path_to_detector)
25 |
26 | self.face_detector = s3fd()
27 | self.face_detector.load_state_dict(model_weights)
28 | self.face_detector.to(device)
29 | self.face_detector.eval()
30 |
31 | def detect_from_image(self, tensor_or_path):
32 | image = self.tensor_or_path_to_ndarray(tensor_or_path)
33 |
34 | bboxlist = detect(self.face_detector, image, device=self.device)
35 | keep = nms(bboxlist, 0.3)
36 | bboxlist = bboxlist[keep, :]
37 | bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38 |
39 | return bboxlist
40 |
41 | @property
42 | def reference_scale(self):
43 | return 195
44 |
45 | @property
46 | def reference_x_shift(self):
47 | return 0
48 |
49 | @property
50 | def reference_y_shift(self):
51 | return 0
52 |
--------------------------------------------------------------------------------
/face-alignment_projection/face_alignment/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10 | stride=strd, padding=padding, bias=bias)
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | def __init__(self, in_planes, out_planes):
15 | super(ConvBlock, self).__init__()
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22 |
23 | if in_planes != out_planes:
24 | self.downsample = nn.Sequential(
25 | nn.BatchNorm2d(in_planes),
26 | nn.ReLU(True),
27 | nn.Conv2d(in_planes, out_planes,
28 | kernel_size=1, stride=1, bias=False),
29 | )
30 | else:
31 | self.downsample = None
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out1 = self.bn1(x)
37 | out1 = F.relu(out1, True)
38 | out1 = self.conv1(out1)
39 |
40 | out2 = self.bn2(out1)
41 | out2 = F.relu(out2, True)
42 | out2 = self.conv2(out2)
43 |
44 | out3 = self.bn3(out2)
45 | out3 = F.relu(out3, True)
46 | out3 = self.conv3(out3)
47 |
48 | out3 = torch.cat((out1, out2, out3), 1)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(residual)
52 |
53 | out3 += residual
54 |
55 | return out3
56 |
57 |
58 | class Bottleneck(nn.Module):
59 |
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67 | padding=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70 | self.bn3 = nn.BatchNorm2d(planes * 4)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class HourGlass(nn.Module):
99 | def __init__(self, num_modules, depth, num_features):
100 | super(HourGlass, self).__init__()
101 | self.num_modules = num_modules
102 | self.depth = depth
103 | self.features = num_features
104 |
105 | self._generate_network(self.depth)
106 |
107 | def _generate_network(self, level):
108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109 |
110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111 |
112 | if level > 1:
113 | self._generate_network(level - 1)
114 | else:
115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116 |
117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118 |
119 | def _forward(self, level, inp):
120 | # Upper branch
121 | up1 = inp
122 | up1 = self._modules['b1_' + str(level)](up1)
123 |
124 | # Lower branch
125 | low1 = F.avg_pool2d(inp, 2, stride=2)
126 | low1 = self._modules['b2_' + str(level)](low1)
127 |
128 | if level > 1:
129 | low2 = self._forward(level - 1, low1)
130 | else:
131 | low2 = low1
132 | low2 = self._modules['b2_plus_' + str(level)](low2)
133 |
134 | low3 = low2
135 | low3 = self._modules['b3_' + str(level)](low3)
136 |
137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138 |
139 | return up1 + up2
140 |
141 | def forward(self, x):
142 | return self._forward(self.depth, x)
143 |
144 |
145 | class FAN(nn.Module):
146 |
147 | def __init__(self, num_modules=1):
148 | super(FAN, self).__init__()
149 | self.num_modules = num_modules
150 |
151 | # Base part
152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153 | self.bn1 = nn.BatchNorm2d(64)
154 | self.conv2 = ConvBlock(64, 128)
155 | self.conv3 = ConvBlock(128, 128)
156 | self.conv4 = ConvBlock(128, 256)
157 |
158 | # Stacking part
159 | for hg_module in range(self.num_modules):
160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162 | self.add_module('conv_last' + str(hg_module),
163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165 | self.add_module('l' + str(hg_module), nn.Conv2d(256,
166 | 68, kernel_size=1, stride=1, padding=0))
167 |
168 | if hg_module < self.num_modules - 1:
169 | self.add_module(
170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171 | self.add_module('al' + str(hg_module), nn.Conv2d(68,
172 | 256, kernel_size=1, stride=1, padding=0))
173 |
174 | def forward(self, x):
175 | x = F.relu(self.bn1(self.conv1(x)), True)
176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177 | x = self.conv3(x)
178 | x = self.conv4(x)
179 |
180 | previous = x
181 |
182 | outputs = []
183 | for i in range(self.num_modules):
184 | hg = self._modules['m' + str(i)](previous)
185 |
186 | ll = hg
187 | ll = self._modules['top_m_' + str(i)](ll)
188 |
189 | ll = F.relu(self._modules['bn_end' + str(i)]
190 | (self._modules['conv_last' + str(i)](ll)), True)
191 |
192 | # Predict heatmaps
193 | tmp_out = self._modules['l' + str(i)](ll)
194 | outputs.append(tmp_out)
195 |
196 | if i < self.num_modules - 1:
197 | ll = self._modules['bl' + str(i)](ll)
198 | tmp_out_ = self._modules['al' + str(i)](tmp_out)
199 | previous = previous + ll + tmp_out_
200 |
201 | return outputs
202 |
203 |
204 | class ResNetDepth(nn.Module):
205 |
206 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207 | self.inplanes = 64
208 | super(ResNetDepth, self).__init__()
209 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210 | bias=False)
211 | self.bn1 = nn.BatchNorm2d(64)
212 | self.relu = nn.ReLU(inplace=True)
213 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214 | self.layer1 = self._make_layer(block, 64, layers[0])
215 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218 | self.avgpool = nn.AvgPool2d(7)
219 | self.fc = nn.Linear(512 * block.expansion, num_classes)
220 |
221 | for m in self.modules():
222 | if isinstance(m, nn.Conv2d):
223 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224 | m.weight.data.normal_(0, math.sqrt(2. / n))
225 | elif isinstance(m, nn.BatchNorm2d):
226 | m.weight.data.fill_(1)
227 | m.bias.data.zero_()
228 |
229 | def _make_layer(self, block, planes, blocks, stride=1):
230 | downsample = None
231 | if stride != 1 or self.inplanes != planes * block.expansion:
232 | downsample = nn.Sequential(
233 | nn.Conv2d(self.inplanes, planes * block.expansion,
234 | kernel_size=1, stride=stride, bias=False),
235 | nn.BatchNorm2d(planes * block.expansion),
236 | )
237 |
238 | layers = []
239 | layers.append(block(self.inplanes, planes, stride, downsample))
240 | self.inplanes = planes * block.expansion
241 | for i in range(1, blocks):
242 | layers.append(block(self.inplanes, planes))
243 |
244 | return nn.Sequential(*layers)
245 |
246 | def forward(self, x):
247 | x = self.conv1(x)
248 | x = self.bn1(x)
249 | x = self.relu(x)
250 | x = self.maxpool(x)
251 |
252 | x = self.layer1(x)
253 | x = self.layer2(x)
254 | x = self.layer3(x)
255 | x = self.layer4(x)
256 |
257 | x = self.avgpool(x)
258 | x = x.view(x.size(0), -1)
259 | x = self.fc(x)
260 |
261 | return x
262 |
--------------------------------------------------------------------------------
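A minimal shape-check sketch for the stacked-hourglass FAN defined above (random weights here, so the heatmaps themselves are meaningless; in practice face_alignment loads pretrained weights before use):

import torch
from face_alignment.models import FAN

net = FAN(num_modules=1).eval()
x = torch.randn(1, 3, 256, 256)            # a 256x256 face crop, e.g. from utils.crop()
with torch.no_grad():
    heatmaps = net(x)                      # one output tensor per hourglass module
print(len(heatmaps), heatmaps[0].shape)    # 1 torch.Size([1, 68, 64, 64]), one heatmap per landmark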
/face-alignment_projection/face_alignment/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import time
5 | import torch
6 | import math
7 | import numpy as np
8 | import cv2
9 |
10 |
11 | def _gaussian(
12 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14 | mean_vert=0.5):
15 | # handle some defaults
16 | if width is None:
17 | width = size
18 | if height is None:
19 | height = size
20 | if sigma_horz is None:
21 | sigma_horz = sigma
22 | if sigma_vert is None:
23 | sigma_vert = sigma
24 | center_x = mean_horz * width + 0.5
25 | center_y = mean_vert * height + 0.5
26 | gauss = np.empty((height, width), dtype=np.float32)
27 | # generate kernel
28 | for i in range(height):
29 | for j in range(width):
30 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32 | if normalize:
33 | gauss = gauss / np.sum(gauss)
34 | return gauss
35 |
36 |
37 | def draw_gaussian(image, point, sigma):
38 | # Check if the gaussian is inside
39 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42 | return image
43 | size = 6 * sigma + 1
44 | g = _gaussian(size)
45 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49 | assert (g_x[0] > 0 and g_y[1] > 0)
50 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52 | image[image > 1] = 1
53 | return image
54 |
55 |
56 | def transform(point, center, scale, resolution, invert=False):
57 | """Generate and affine transformation matrix.
58 |
59 | Given a set of points, a center, a scale and a targer resolution, the
60 | function generates and affine transformation matrix. If invert is ``True``
61 | it will produce the inverse transformation.
62 |
63 | Arguments:
64 | point {torch.tensor} -- the input 2D point
65 | center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66 | scale {float} -- the scale of the face/object
67 | resolution {float} -- the output resolution
68 |
69 | Keyword Arguments:
70 | invert {bool} -- define whether the function should produce the direct or the
71 | inverse transformation matrix (default: {False})
72 | """
73 | _pt = torch.ones(3)
74 | _pt[0] = point[0]
75 | _pt[1] = point[1]
76 |
77 | h = 200.0 * scale
78 | t = torch.eye(3)
79 | t[0, 0] = resolution / h
80 | t[1, 1] = resolution / h
81 | t[0, 2] = resolution * (-center[0] / h + 0.5)
82 | t[1, 2] = resolution * (-center[1] / h + 0.5)
83 |
84 | if invert:
85 | t = torch.inverse(t)
86 |
87 | new_point = (torch.matmul(t, _pt))[0:2]
88 |
89 | return new_point.int()
90 |
91 |
92 | def crop(image, center, scale, resolution=256.0):
93 | """Center crops an image or set of heatmaps
94 |
95 | Arguments:
96 | image {numpy.array} -- an rgb image
97 | center {numpy.array} -- the center of the object, usually the same as of the bounding box
98 | scale {float} -- scale of the face
99 |
100 | Keyword Arguments:
101 | resolution {float} -- the size of the output cropped image (default: {256.0})
102 |
103 | Returns:
104 | numpy.array -- the image cropped around the center and resized to (resolution, resolution)
105 | """
106 | # Crop the image around the center point; input is expected to be an np.ndarray
107 | ul = transform([1, 1], center, scale, resolution, True)
108 | br = transform([resolution, resolution], center, scale, resolution, True)
109 | # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110 | if image.ndim > 2:
111 | newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112 | image.shape[2]], dtype=np.int32)
113 | newImg = np.zeros(newDim, dtype=np.uint8)
114 | else:
115 | newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
116 | newImg = np.zeros(newDim, dtype=np.uint8)
117 | ht = image.shape[0]
118 | wd = image.shape[1]
119 | newX = np.array(
120 | [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121 | newY = np.array(
122 | [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123 | oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124 | oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125 | newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126 | ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127 | newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128 | interpolation=cv2.INTER_LINEAR)
129 | return newImg
130 |
131 |
132 | def get_preds_fromhm(hm, center=None, scale=None):
133 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134 | and the scale is provided the function will return the points also in
135 | the original coordinate frame.
136 |
137 | Arguments:
138 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139 |
140 | Keyword Arguments:
141 | center {torch.tensor} -- the center of the bounding box (default: {None})
142 | scale {float} -- face scale (default: {None})
143 | """
144 | max, idx = torch.max(
145 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146 | idx += 1
147 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150 |
151 | for i in range(preds.size(0)):
152 | for j in range(preds.size(1)):
153 | hm_ = hm[i, j, :]
154 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156 | diff = torch.FloatTensor(
157 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159 | preds[i, j].add_(diff.sign_().mul_(.25))
160 |
161 | preds.add_(-.5)
162 |
163 | preds_orig = torch.zeros(preds.size())
164 | if center is not None and scale is not None:
165 | for i in range(hm.size(0)):
166 | for j in range(hm.size(1)):
167 | preds_orig[i, j] = transform(
168 | preds[i, j], center, scale, hm.size(2), True)
169 |
170 | return preds, preds_orig
171 |
172 |
173 | def shuffle_lr(parts, pairs=None):
174 | """Shuffle the points left-right according to the axis of symmetry
175 | of the object.
176 |
177 | Arguments:
178 | parts {torch.tensor} -- a 3D or 4D object containing the
179 | heatmaps.
180 |
181 | Keyword Arguments:
182 | pairs {list of integers} -- [order of the flipped points] (default: {None})
183 | """
184 | if pairs is None:
185 | pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
186 | 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
187 | 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
188 | 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
189 | 62, 61, 60, 67, 66, 65]
190 | if parts.ndimension() == 3:
191 | parts = parts[pairs, ...]
192 | else:
193 | parts = parts[:, pairs, ...]
194 |
195 | return parts
196 |
197 |
198 | def flip(tensor, is_label=False):
199 | """Flip an image or a set of heatmaps left-right
200 |
201 | Arguments:
202 | tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
203 |
204 | Keyword Arguments:
205 | is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
206 | """
207 | if not torch.is_tensor(tensor):
208 | tensor = torch.from_numpy(tensor)
209 |
210 | if is_label:
211 | tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
212 | else:
213 | tensor = tensor.flip(tensor.ndimension() - 1)
214 |
215 | return tensor
216 |
217 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
218 |
219 |
220 | def appdata_dir(appname=None, roaming=False):
221 | """ appdata_dir(appname=None, roaming=False)
222 |
223 | Get the path to the application directory, where applications are allowed
224 | to write user specific files (e.g. configurations). For non-user specific
225 | data, consider using common_appdata_dir().
226 | If appname is given, a subdir is appended (and created if necessary).
227 | If roaming is True, will prefer a roaming directory (Windows Vista/7).
228 | """
229 |
230 | # Define default user directory
231 | userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
232 | if userDir is None:
233 | userDir = os.path.expanduser('~')
234 | if not os.path.isdir(userDir): # pragma: no cover
235 | userDir = '/var/tmp' # issue #54
236 |
237 | # Get system app data dir
238 | path = None
239 | if sys.platform.startswith('win'):
240 | path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
241 | path = (path2 or path1) if roaming else (path1 or path2)
242 | elif sys.platform.startswith('darwin'):
243 | path = os.path.join(userDir, 'Library', 'Application Support')
244 | # On Linux and as fallback
245 | if not (path and os.path.isdir(path)):
246 | path = userDir
247 |
248 | # Maybe we should store things local to the executable (in case of a
249 | # portable distro or a frozen application that wants to be portable)
250 | prefix = sys.prefix
251 | if getattr(sys, 'frozen', None):
252 | prefix = os.path.abspath(os.path.dirname(sys.executable))
253 | for reldir in ('settings', '../settings'):
254 | localpath = os.path.abspath(os.path.join(prefix, reldir))
255 | if os.path.isdir(localpath): # pragma: no cover
256 | try:
257 | open(os.path.join(localpath, 'test.write'), 'wb').close()
258 | os.remove(os.path.join(localpath, 'test.write'))
259 | except IOError:
260 | pass # We cannot write in this directory
261 | else:
262 | path = localpath
263 | break
264 |
265 | # Get path specific for this app
266 | if appname:
267 | if path == userDir:
268 | appname = '.' + appname.lstrip('.') # Make it a hidden directory
269 | path = os.path.join(path, appname)
270 | if not os.path.isdir(path): # pragma: no cover
271 | os.mkdir(path)
272 |
273 | # Done
274 | return path
275 |
--------------------------------------------------------------------------------
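A small round-trip sketch of the heatmap utilities above (mirroring test/test_utils.py): place Gaussian blobs at known 64x64 coordinates with draw_gaussian, then recover them with get_preds_fromhm. The recovered points are only approximate by construction.

import numpy as np
import torch
from face_alignment.utils import draw_gaussian, get_preds_fromhm

pts = np.random.randint(5, 59, size=(68, 2)).astype('float32')
heatmaps = np.zeros((68, 64, 64), dtype='float32')
for k in range(68):
    heatmaps[k] = draw_gaussian(heatmaps[k], pts[k], 2)
hm = torch.from_numpy(heatmaps).unsqueeze(0)     # [1, 68, 64, 64], the shape FAN outputs
preds, _ = get_preds_fromhm(hm)
print(np.abs(preds[0].numpy() - pts).max())      # typically within a couple of pixels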
/face-alignment_projection/get_landmarks.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy
3 | import face_alignment
4 | import numpy as np
5 | import skimage.io as io
6 | import os
7 |
8 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True)
9 | predefined_color = [128,64,160,244,35,220,80,200]
10 |
11 | dataset = '../../dataset/voxcelb1/data'
12 | person_name = os.listdir(dataset)
13 | num_classes = len(person_name)
14 | video_path = [os.path.abspath(os.path.join(dataset,i)) for i in person_name]
15 |
16 | # save all video clips
17 | all_video_path = []
18 | f = open('video_clips_path.txt','w')
19 | print('saving video path!!')
20 | for v in video_path:
21 | each_person_clip = os.path.join(v,'1.6')
22 | each_person_clip = [ os.path.join(each_person_clip,i,'1') for i in os.listdir(each_person_clip)]
23 | all_video_path += each_person_clip
24 | for path in each_person_clip:
25 | print(path)
26 | f.write(path+'\n')
27 | f.close()
28 | if not os.path.exists('../../dataset/voxcelb1/landmarks'):
29 | os.makedirs('../../dataset/voxcelb1/landmarks')
30 |
31 | print('saving landmarks!!!')
32 | f = open('video_landmarks_path.txt','w')
33 | for v in all_video_path:
34 | print(v)
35 | for i in os.listdir(v):
36 | img = os.path.join(v,i)
37 | basename = os.path.dirname(img)
38 | filename = os.path.basename(img)
39 | input = io.imread(img)
40 | preds = fa.get_landmarks(input)
41 | if preds is None:
42 | continue
43 | preds = preds[-1]
44 | saved = np.ones_like(input,dtype=np.uint8)*255
45 | for i in range(17-1):
46 | cv2.line(saved,(preds[i,0],preds[i,1]),(preds[i+1,0],preds[i+1,1]),[predefined_color[0],250,10],2)
47 | for i in range(17,22-1):
48 | cv2.line(saved,(preds[i,0],preds[i,1]),(preds[i+1,0],preds[i+1,1]),[predefined_color[1],10,250],2)
49 | for i in range(22,27-1):
50 | cv2.line(saved,(preds[i,0],preds[i,1]),(preds[i+1,0],preds[i+1,1]),[predefined_color[2],150,150],2)
51 | for i in range(27,31-1):
52 | cv2.line(saved, (preds[i, 0], preds[i, 1]), (preds[i + 1, 0], preds[i + 1, 1]), [predefined_color[3],0,0],
53 | 2)
54 | for i in range(31,36-1):
55 | cv2.line(saved,(preds[i,0],preds[i,1]),(preds[i+1,0],preds[i+1,1]),[predefined_color[4],50,0],2)
56 | for i in range(36,42-1):
57 | cv2.line(saved,(preds[i,0],preds[i,1]),(preds[i+1,0],preds[i+1,1]),[predefined_color[5],0,50],2)
58 | for i in range(42,48-1):
59 | cv2.line(saved,(preds[i,0],preds[i,1]),(preds[i+1,0],preds[i+1,1]),[predefined_color[6],180,30],2)
60 | for i in range(48,60-1):
61 | cv2.line(saved, (preds[i, 0], preds[i, 1]), (preds[i + 1, 0], preds[i + 1, 1]), [predefined_color[7],20,180],
62 | 2)
63 | basename = basename.split('data')
64 | basename = basename[0] + 'dataset/voxcelb1/' 'landmarks' + basename[2]
65 | if not os.path.exists(basename):
66 | os.makedirs(basename)
67 | filename = os.path.join(basename,filename)
68 | print(filename)
69 | f.write(filename + '\n')
70 | cv2.imwrite(filename,saved)
71 | f.close()
72 |
73 |
74 |
--------------------------------------------------------------------------------
/face-alignment_projection/songyiren.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/songyiren.jpg
--------------------------------------------------------------------------------
/face-alignment_projection/test/assets/aflw-test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/face-alignment_projection/test/assets/aflw-test.jpg
--------------------------------------------------------------------------------
/face-alignment_projection/test/facealignment_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import face_alignment
3 |
4 |
5 | class Tester(unittest.TestCase):
6 | def test_predict_points(self):
7 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu')
8 | fa.get_landmarks('test/assets/aflw-test.jpg')
9 |
10 | if __name__ == '__main__':
11 | unittest.main()
12 |
--------------------------------------------------------------------------------
/face-alignment_projection/test/smoke_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import face_alignment
3 |
--------------------------------------------------------------------------------
/face-alignment_projection/test/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from face_alignment.utils import *
3 | import numpy as np
4 | import torch
5 |
6 |
7 | class Tester(unittest.TestCase):
8 | def test_flip_is_label(self):
9 | # Generate the points
10 | heatmaps = torch.from_numpy(np.random.randint(1, high=250, size=(68, 64, 64)).astype('float32'))
11 |
12 | flipped_heatmaps = flip(flip(heatmaps.clone(), is_label=True), is_label=True)
13 |
14 | assert np.allclose(heatmaps.numpy(), flipped_heatmaps.numpy())
15 |
16 | def test_flip_is_image(self):
17 | fake_image = torch.rand(3, 256, 256)
18 | flipped_fake_image = flip(flip(fake_image.clone()))
19 |
20 | assert np.allclose(fake_image.numpy(), flipped_fake_image.numpy())
21 |
22 | def test_getpreds(self):
23 | pts = torch.from_numpy(np.random.randint(1, high=63, size=(68, 2)).astype('float32'))
24 |
25 | heatmaps = np.zeros((68, 256, 256))
26 | for i in range(68):
27 | if pts[i, 0] > 0:
28 | heatmaps[i] = draw_gaussian(heatmaps[i], pts[i], 2)
29 | heatmaps = torch.from_numpy(np.expand_dims(heatmaps, axis=0))
30 |
31 | preds, _ = get_preds_fromhm(heatmaps)
32 |
33 | assert np.allclose(pts.numpy(), preds.numpy(), atol=5)
34 |
35 | if __name__ == '__main__':
36 | unittest.main()
37 |
--------------------------------------------------------------------------------
/model/D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | from torch.nn import utils
6 |
7 | from model.resblocks import Block
8 | from model.resblocks import OptimizedBlock
9 |
10 |
11 | class D(nn.Module):
12 |
13 | def __init__(self,input_nc=6 ,num_features=64, num_classes=0, activation=F.relu):
14 | super(D, self).__init__()
15 | self.num_features = num_features
16 | self.num_classes = num_classes
17 | self.activation = activation
18 | if num_classes is None:
19 | raise ValueError('no given num_classes')
20 | self.block1 = OptimizedBlock(input_nc, num_features)
21 | self.block2 = Block(num_features, num_features * 2,
22 | activation=activation, downsample=True)
23 | self.block3 = Block(num_features * 2, num_features * 4,
24 | activation=activation, downsample=True)
25 | self.block4 = Block(num_features * 4, num_features * 8,
26 | activation=activation, downsample=True)
27 | self.block5 = Block(num_features * 8, num_features * 8,
28 | activation=activation, downsample=True) # 8
29 |
30 |
31 | self.w_0 = nn.Parameter(torch.Tensor(1,num_features * 8))
32 | nn.init.xavier_normal_(self.w_0.data)
33 | self.b = nn.Parameter(torch.zeros(1))
34 |
35 | if num_classes > 0:
36 | self.W = utils.spectral_norm(
37 | nn.Embedding(num_classes, num_features * 8))
38 |
39 | def _initialize(self):
40 | optional_W = getattr(self, 'W', None)
41 | if optional_W is not None:
42 | init.xavier_uniform_(optional_W.weight.data)
43 |
44 |
45 | def forward(self, x, y): # y is a tensor like [1, 23, 45, ...]: video/identity indices in the range 0 to num_classes-1
46 | # h = x
47 | h1 = self.block1(x)
48 | h2 = self.block2(h1)
49 | h3 = self.block3(h2)
50 | h4 = self.block4(h3)
51 | h5 = self.block5(h4)
52 | h5 = self.activation(h5)
53 | # Global pooling
54 | v_loss = torch.sum(h5, dim=(2, 3)) # B,C
55 | v_loss = torch.sum(v_loss * (self.W(y) + self.w_0.squeeze()),dim=1) + self.b # ([B])
56 | v_loss = torch.mean(v_loss)
57 | return [v_loss, h1, h2, h3, h4, h5,self.W(y)]
58 |
59 |
60 | if __name__=='__main__':
61 | model = D(num_classes=100).cuda()
62 | y = torch.ones([]).long().cuda()
63 | data = torch.randn(1,6,256,256).cuda()
64 | output = model(data,y)
65 | v_loss = output[0]
66 | print(v_loss.size()) # torch.Size([])
67 | w = output[-1]
68 | print(w.size()) # torch.Size([512]), since y is a scalar here
69 |
--------------------------------------------------------------------------------
/model/E.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | from torch.nn import utils
6 |
7 | from model.resblocks import Block
8 | from model.resblocks import OptimizedBlock
9 |
10 |
11 | class E(nn.Module):
12 |
13 | def __init__(self, num_features=64, activation=F.relu):
14 | super(E, self).__init__()
15 | self.num_features = num_features
16 | self.activation = activation
17 |
18 | self.block1 = OptimizedBlock(6, num_features) # 128
19 | self.block2 = Block(num_features, num_features * 2,
20 | activation=activation, downsample=True) # 64
21 | self.block3 = Block(num_features * 2, num_features * 4,
22 | activation=activation, downsample=True) # 32
23 | self.block4 = Block(num_features * 4, num_features * 8,
24 | activation=activation, downsample=True) # 16
25 | self.block5 = Block(num_features * 8, num_features * 8,
26 | activation=activation, downsample=True) # 8
27 | self.l6 = utils.spectral_norm(nn.Linear(num_features * 8, 256))
28 | self.l7 = utils.spectral_norm(nn.Linear(num_features * 8, 256))
29 | self._initialize()
30 |
31 | def _initialize(self):
32 | init.xavier_uniform_(self.l6.weight.data)
33 | optional_l_y = getattr(self, 'l_y', None)
34 | if optional_l_y is not None:
35 | init.xavier_uniform_(optional_l_y.weight.data)
36 |
37 | def forward(self, x,):
38 | h = self.block1(x)
39 | h = self.block2(h)
40 | h = self.block3(h)
41 | h = self.block4(h)
42 | h = self.block5(h)
43 | h = self.activation(h)
44 | # Global pooling
45 | h = torch.sum(h, dim=(2, 3))
46 | h = self.activation(h) # 512
47 | e = torch.mean(h,dim=0,keepdim=True)
48 | out1 = self.l6(h) # B,256  the P (projection) matrix
49 | size = out1.size() # B,256
50 | assert len(size)==2, 'size is not right'
51 | mean = out1.view(size[0],size[1],1,1) # reshape to (B, 256, 1, 1)
52 | mean = torch.mean(mean,dim=0,keepdim=True)
53 | std = self.l7(h) # B,256
54 | std = std.view(size[0],size[1],1,1)
55 | std = torch.mean(std,dim=0,keepdim=True)
56 | return e, mean, std
57 |
58 |
59 | if __name__=='__main__':
60 | model = E().cuda()
61 | data = torch.randn(8,6,256,256).cuda()
62 | output = model(data)
63 | e = output[0]
64 | mean = output[1]
65 | print(e.size()) # torch.Size([1, 512])
66 | print(mean.size()) # torch.Size([1, 256, 1, 1])
67 |
--------------------------------------------------------------------------------
/model/FaceSwapModel.py:
--------------------------------------------------------------------------------
1 | from model.D import D
2 | from model.G import G
3 | from model.E import E
4 | import numpy as np
5 | import torch
6 | from torch.nn import Module
7 | from torch import nn
8 | from model.loss import VGGLoss,VGGFaceLoss,CNTLoss,AdvLoss,MCHLoss,DLoss
9 | # divide whole model into E and G&D
10 | class GDModel(Module):
11 | def __init__(self,num_classes):
12 | super(GDModel, self).__init__()
13 | self.g = G(input_nc=3)
14 | self.d = D(num_classes=num_classes)
15 | self.cntloss = CNTLoss()
16 | self.advloss = AdvLoss()
17 | self.mchloss = MCHLoss()
18 | self.dloss = DLoss()
19 |
20 |
21 | def g_forward(self,landmark):
22 | return self.g(landmark)
23 | def d_forward(self,fake_img,real_image,y):
24 | fake_info, real_info = self.d(fake_img,y), self.d(real_image,y)
25 | return fake_info, real_info
26 | def cal_cnt_loss(self,fake_image, real_image):
27 | return self.cntloss(fake_image,real_image)
28 | def cal_adv_loss(self,fake_info,real_info): # include FM loss
29 | fake_v_loss = fake_info[0]
30 | fake_v_features = fake_info[1:6]
31 | real_v_features = real_info[1:6]
32 | return self.advloss(fake_v_features,real_v_features,fake_v_loss)
33 | def cal_mch_loss(self,fake_info):
34 | w = fake_info[6]
35 | return self.mchloss(w,self.e)
36 | def cal_d_loss(self,fake_info, real_info):
37 | fake_v_loss = fake_info[0]
38 | real_v_loss = real_info[0]
39 | return self.dloss(real_v_loss,fake_v_loss)
40 | def update_GDModel(self,mean_y,std_y,e):
41 | self.e = e
42 | self.g.update_adain(mean_y,std_y)
43 |
44 | def for_test_inference(self,landmark,y,x):
45 | x_landmark = torch.cat((landmark,x),1)
46 | fake_image = self.g_forward(landmark) # from -1 to 1
47 | fake_landmark = torch.cat((landmark,fake_image),1)
48 | fake_info,real_info = self.d_forward(fake_landmark,x_landmark,y)
49 | g_loss = self.cal_cnt_loss(fake_image,x) + self.cal_adv_loss(fake_info,real_info) +\
50 | self.cal_mch_loss(fake_info)
51 | d_loss = self.cal_d_loss(fake_info,real_info)
52 | print(g_loss.size())
53 | print(d_loss.size())
54 |
55 | def forward(self, landmark,y,x):
56 | x_landmark = torch.cat((landmark,x),1)
57 | fake_image = self.g_forward(landmark) # from -1 to 1
58 | fake_landmark = torch.cat((landmark,fake_image),1)
59 | fake_info,real_info = self.d_forward(fake_landmark,x_landmark,y)
60 | g_loss = self.cal_cnt_loss(fake_image,x) + self.cal_adv_loss(fake_info,real_info) +\
61 | self.cal_mch_loss(fake_info)
62 | d_loss = self.cal_d_loss(fake_info,real_info)
63 | return fake_image,g_loss,d_loss
64 |
65 |
66 | if __name__=='__main__':
67 | landmark = torch.randn(2, 3, 224, 224).cuda()
68 | y = torch.LongTensor(np.random.randint(0, 50, size=[2])).cuda()
69 | x = torch.randn(2, 3, 224, 224).cuda()
70 | x_landmark = torch.cat((landmark, x), 1)
71 | e_net = E().cuda()
72 | e, mean_y, std_y = e_net(x_landmark)
73 | model = GDModel(num_classes=30,).cuda()
74 | model.update_GDModel(mean_y,std_y,e)
75 | model.for_test_inference(landmark,y,x)
76 |
77 |
78 |
--------------------------------------------------------------------------------
/model/G.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.utils import spectral_norm
3 | from torch.nn import Module
4 | from torch import nn
5 | from torch.nn import init
6 | import math
7 |
8 | class G(Module):
9 | def __init__(self,input_nc,):
10 | super(G,self).__init__()
11 | activation = nn.ReLU()
12 |
13 | model = [nn.ReflectionPad2d(4),
14 | spectral_norm(nn.Conv2d(input_nc,128,9,1,bias=False)),
15 | nn.InstanceNorm2d(128,affine=True),
16 | activation]
17 | model = nn.Sequential(*model)
18 |
19 | # downsampling
20 | in_nc = 128
21 | out_nc = 2* in_nc # 256, 256
22 | for i in range(2):
23 | model.add_module('res_down%d'% i,Residual_Block(in_nc,out_nc))
24 | in_nc = out_nc
25 | # residual same resolution
26 | for i in range(4):
27 | model.add_module('res_iden%d'% i, Residual_Ident(in_nc,256,))
28 |
29 | # upsampling
30 | for i in range(2):
31 | model.add_module('up%d'%i, Up(in_nc,in_nc))
32 | # in_nc = in_nc//2
33 | # reduce dimension from 64 to 3
34 | model.add_module('head',nn.Sequential(
35 | nn.ReflectionPad2d(1),
36 | spectral_norm(nn.Conv2d(256,128,3,1)),
37 | activation,
38 |
39 | nn.ReflectionPad2d(1),
40 | spectral_norm(nn.Conv2d(128, 64, 3, 1)),
41 | activation,
42 |
43 | nn.ReflectionPad2d(1),
44 | spectral_norm(nn.Conv2d(64, 3, 3, 1)),
45 | nn.Tanh() # -1 ,1
46 |
47 | ))
48 | self.model = model
49 | def forward(self,x):
50 | x= self.model(x)
51 | # print('g mean : ',x.mean())
52 | return x
53 | def update_adain(self,mean_y,std_y):
54 | for m in self.modules():
55 | if isinstance(m,Adain):
56 | m.update_mean_std(mean_y,std_y)
57 |
58 |
59 | class Residual_Block(Module):
60 | def __init__(self,input_nc,output_nc):
61 | super(Residual_Block, self).__init__()
62 | activation = nn.ReLU(True)
63 | self.left = nn.Sequential(*[spectral_norm(nn.Conv2d(input_nc,output_nc,1,1,padding=0,bias=False)),
64 | nn.InstanceNorm2d(output_nc,affine=True),
65 | activation,
66 | nn.AvgPool2d(3, 2, padding=1)
67 | ])
68 | self.right = nn.Sequential(*[
69 | nn.ReflectionPad2d(1),
70 | spectral_norm(nn.Conv2d(input_nc,output_nc,3,1,padding=0,bias=False)),
71 | nn.InstanceNorm2d(output_nc,affine=True),
72 | activation,
73 | nn.ReflectionPad2d(1),
74 | spectral_norm(nn.Conv2d(output_nc, output_nc, 3, 1, padding=0, bias=False)),
75 | nn.InstanceNorm2d(output_nc,affine=True),
76 | activation,
77 | nn.AvgPool2d(3, 2, padding=1)
78 | ])
79 |
80 | def _initialize(self):
81 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2))
82 | for m in self.modules():
83 | if isinstance(m, nn.Conv2d):
84 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
85 | m.weight.data.normal_(0, math.sqrt(2. / n))
86 | elif isinstance(m, nn.InstanceNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | else:
90 | raise ValueError('no initialization rule defined for this layer type')
91 |
92 | def forward(self, x):
93 | x1 = self.right(x)
94 | x2 = self.left(x)
95 | return x1+x2
96 |
97 | class Residual_Ident(Module):
98 | def __init__(self,in_nc,out_nc,):
99 | super(Residual_Ident,self).__init__()
100 | activation = nn.ReLU(True)
101 | model = [nn.ReflectionPad2d(1),
102 | spectral_norm(nn.Conv2d(in_nc,out_nc,3,stride=1,padding=0,bias=False)),
103 | nn.InstanceNorm2d(out_nc,affine=False),
104 | Adain(),
105 | activation
106 | ]
107 | self.model = nn.Sequential(*model)
108 | def forward(self, x):
109 | return self.model(x)
110 |
111 | class Adain(Module):
112 | def __init__(self,):
113 | super(Adain,self).__init__()
114 |
115 | def forward(self, x):
116 | size = x.size()
117 | x = x*self.std_y.expand(size) + self.mean_y.expand(size)
118 | return x
119 | def update_mean_std(self,mean_y,std_y): # be used before forward
120 | self.mean_y = mean_y
121 | self.std_y = std_y
122 |
123 | class Up(Module):
124 | def __init__(self,in_nc,out_nc):
125 | super(Up, self).__init__()
126 | activation = nn.ReLU(True)
127 | model = [
128 | nn.Upsample(scale_factor=2,mode='bilinear'),
129 | nn.ReflectionPad2d(1),
130 | spectral_norm(nn.Conv2d(in_nc,out_nc,3,1,bias=False)),
131 | nn.InstanceNorm2d(out_nc,affine=False),
132 | Adain(),
133 | activation
134 | ]
135 | self.model = nn.Sequential(*model)
136 | def forward(self, x):
137 | x = self.model(x)
138 | return x
139 |
140 | if __name__=='__main__':
141 | std_y = torch.randn(2,256,1,1).cuda()
142 | x = torch.randn(2,6,224,224).cuda()
143 | mean_y = torch.zeros(2,256,1,1).cuda()
144 |
145 | model = G(6).cuda()
146 | model.update_adain(mean_y,std_y)
147 | out = model(x)
148 | print(out.size()) # torch.Size([2, 3, 224, 224])
149 |
150 |
151 |
152 | # for m in model.modules():
153 | # if isinstance(m,Adain):
154 | # m.update_mean_std(mean_y,std_y)
155 | # for m in model.children():
156 | # for c in m.modules():
157 | # if isinstance(c,Adain):
158 | # # print(c.mean_y==mean_y)
159 | # # print(c.std_y==std_y)
160 | # print('*********') # printed 6 times
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__init__.py
--------------------------------------------------------------------------------
/model/__pycache__/D.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/D.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/D.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/D.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/E.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/E.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/E.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/E.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/FaceSwapModel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/FaceSwapModel.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/G.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/G.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/G.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/G.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/resblocks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/resblocks.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/resblocks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/resblocks.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/vgg_face.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/vgg_face.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/vgg_face.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/model/__pycache__/vgg_face.cpython-37.pyc
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch
3 | import torch.nn as nn
4 | import functools
5 | from torch.autograd import Variable
6 | import numpy as np
7 | from torchvision.transforms import transforms
8 | from model.vgg_face import vgg_m_face_bn_dag
9 |
10 | class Vgg19(torch.nn.Module):
11 | def __init__(self, requires_grad=False):
12 | super(Vgg19, self).__init__()
13 | vgg_pretrained_features = models.vgg19(pretrained=True).features
14 | self.slice1 = torch.nn.Sequential()
15 | self.slice2 = torch.nn.Sequential()
16 | self.slice3 = torch.nn.Sequential()
17 | self.slice4 = torch.nn.Sequential()
18 | self.slice5 = torch.nn.Sequential()
19 | for x in range(2):
20 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
21 | for x in range(1, 7):
22 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
23 | for x in range(7, 12):
24 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
25 | for x in range(12, 21):
26 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
27 | for x in range(21, 30):
28 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
29 |
30 | if not requires_grad:
31 | for param in self.parameters():
32 | param.requires_grad = False
33 |
34 | def forward(self, X):
35 | # self.zero_grad()
36 | h_relu1 = self.slice1(X)
37 | h_relu2 = self.slice2(h_relu1)
38 | h_relu3 = self.slice3(h_relu2)
39 | h_relu4 = self.slice4(h_relu3)
40 | h_relu5 = self.slice5(h_relu4)
41 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
42 | return out
43 |
44 | class VGGLoss(nn.Module):
45 | def __init__(self):
46 | super(VGGLoss,self).__init__()
47 | self.vgg = Vgg19()
48 | self.criterion = nn.L1Loss(reduction='mean')
49 | self.weight = [0.01] *5
50 | def forward(self, x,y):
51 | x_vgg, y_vgg = self.vgg(x),self.vgg(y)
52 | loss = 0
53 | for i in range(len(x_vgg)):
54 | loss = self.weight[i]/float(len(x_vgg)) * self.criterion(x_vgg[i],y_vgg[i].detach()) + loss
55 | return loss
56 | class VGGFaceLoss(nn.Module):
57 | def __init__(self):
58 | super(VGGFaceLoss,self).__init__()
59 | self.vggface = vgg_m_face_bn_dag('./pretrained/vgg_m_face_bn_dag.pth')
60 | # self.vggface = vgg_m_face_bn_dag(None)
61 | self.criterion = nn.L1Loss(reduction='mean')
62 | self.weights = [0.002] * 5
63 | def forward(self, x,y):
64 | x_vggface, y_vggface = self.vggface(x), self.vggface(y)
65 | loss = 0
66 | for i in range(len(x_vggface)):
67 | loss += self.weights[i]/5.0 * self.criterion(x_vggface[i],y_vggface[i].detach())
68 | return loss
69 |
70 | class CNTLoss(nn.Module):
71 | def __init__(self):
72 | super(CNTLoss, self).__init__()
73 | self.vggloss = VGGLoss()
74 | self.vggfaceloss = VGGFaceLoss()
75 | self.vgg_preprocessed = transforms.Normalize([0.485,0.456,0.406],
76 | [0.229,0.224,0.225])
77 | def forward(self,x,y):
78 | x = (x+1)*127.5
79 | y = (y+1)*127.5
80 | x1 = self.vgg_preprocessed(x.squeeze(0)/255.0).unsqueeze(0)
81 | y1 = self.vgg_preprocessed(y.squeeze(0)/255.0).unsqueeze(0)
82 | loss = 0.5 * (self.vggfaceloss(x,y) + self.vggloss(x1,y1))
83 | return loss
84 |
85 |
86 | class AdvLoss(nn.Module):
87 | def __init__(self,):
88 | super(AdvLoss,self).__init__()
89 | self.criterion = nn.L1Loss(reduction='mean')
90 |
91 | def forward(self, fake_feature,real_feature,v_loss):
92 | fm_loss = 0
93 | feat_weights = 10.0 / len(fake_feature)
94 | for i in range(len(fake_feature)-1):
95 | fm_loss += feat_weights * self.criterion(fake_feature[i],real_feature[i].detach())
96 | return -v_loss + fm_loss # the adversarial loss passed to G must be a scalar
97 |
98 |
99 | class MCHLoss(nn.Module):
100 | def __init__(self):
101 | super(MCHLoss,self).__init__()
102 | self.criterion = nn.L1Loss(reduction='mean')
103 |
104 | def forward(self, w,e):
105 | return 80 * self.criterion(w, e.detach())
106 |
107 | class DLoss(nn.Module):
108 | def __init__(self):
109 | super(DLoss,self).__init__()
110 | def forward(self, real_vloss, fake_vloss):
111 | d_loss = torch.mean(torch.relu(1. - real_vloss)) +\
112 | torch.mean(torch.relu(1. + fake_vloss))
113 | return d_loss
114 |
--------------------------------------------------------------------------------
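A quick sanity sketch of the simpler loss terms defined above, using dummy tensors; CNTLoss is left out here because it needs the pretrained VGG19 and VGG-Face weights.

import torch
from model.loss import DLoss, MCHLoss

real_v = torch.tensor([0.7])       # discriminator realism score for a real (landmark, frame) pair
fake_v = torch.tensor([-0.2])      # score for a generated pair
print(DLoss()(real_v, fake_v))     # hinge loss: relu(1 - 0.7) + relu(1 + (-0.2)) = 1.1

w = torch.randn(1, 512)            # discriminator embedding column W(y) for identity y
e = torch.randn(1, 512)            # embedder output e for the same identity
print(MCHLoss()(w, e))             # 80 * mean L1 distance, which pulls W(y) towards e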
/model/resblocks.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 | from torch.nn import utils
8 |
9 |
10 | class Block(nn.Module):
11 |
12 | def __init__(self, in_ch, out_ch, h_ch=None, ksize=3, pad=1,
13 | activation=F.relu, downsample=False):
14 | super(Block, self).__init__()
15 |
16 | self.activation = activation
17 | self.downsample = downsample
18 |
19 | self.learnable_sc = (in_ch != out_ch) or downsample
20 | if h_ch is None:
21 | h_ch = in_ch
22 | else:
23 | h_ch = out_ch
24 |
25 | self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, h_ch, ksize, 1, pad))
26 | self.c2 = utils.spectral_norm(nn.Conv2d(h_ch, out_ch, ksize, 1, pad))
27 | if self.learnable_sc:
28 | self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0))
29 |
30 | self._initialize()
31 |
32 | def _initialize(self):
33 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2))
34 | init.xavier_uniform_(self.c2.weight.data, math.sqrt(2))
35 | if self.learnable_sc:
36 | init.xavier_uniform_(self.c_sc.weight.data)
37 |
38 | def forward(self, x):
39 | return self.shortcut(x) + self.residual(x)
40 |
41 | def shortcut(self, x):
42 | if self.learnable_sc:
43 | x = self.c_sc(x)
44 | if self.downsample:
45 | return F.avg_pool2d(x, 2)
46 | return x
47 |
48 | def residual(self, x):
49 | h = self.c1(self.activation(x))
50 | h = self.c2(self.activation(h))
51 | if self.downsample:
52 | h = F.avg_pool2d(h, 2)
53 | return h
54 |
55 |
56 | class OptimizedBlock(nn.Module):
57 |
58 | def __init__(self, in_ch, out_ch, ksize=3, pad=1, activation=F.relu):
59 | super(OptimizedBlock, self).__init__()
60 | self.activation = activation
61 |
62 | self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, ksize, 1, pad))
63 | self.c2 = utils.spectral_norm(nn.Conv2d(out_ch, out_ch, ksize, 1, pad))
64 | self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0))
65 |
66 | self._initialize()
67 |
68 | def _initialize(self):
69 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2))
70 | init.xavier_uniform_(self.c2.weight.data, math.sqrt(2))
71 | init.xavier_uniform_(self.c_sc.weight.data)
72 |
73 | def forward(self, x):
74 | return self.shortcut(x) + self.residual(x)
75 |
76 | def shortcut(self, x):
77 | return self.c_sc(F.avg_pool2d(x, 2))
78 |
79 | def residual(self, x):
80 | h = self.activation(self.c1(x))
81 | return F.avg_pool2d(self.c2(h), 2)
82 |
--------------------------------------------------------------------------------
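A shape sketch for the spectrally normalized residual blocks above, the building blocks of D and E: OptimizedBlock always halves the spatial size, while Block halves it only when downsample=True.

import torch
from model.resblocks import Block, OptimizedBlock

x = torch.randn(2, 6, 256, 256)               # an (image, landmark) pair stacked along channels
h = OptimizedBlock(6, 64)(x)                  # -> [2, 64, 128, 128]
h = Block(64, 128, downsample=True)(h)        # -> [2, 128, 64, 64]
h = Block(128, 128, downsample=False)(h)      # -> [2, 128, 64, 64], identity shortcut
print(h.shape)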
/model/vgg_face.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class Vgg_m_face_bn_dag(nn.Module):
7 |
8 | def __init__(self):
9 | super(Vgg_m_face_bn_dag, self).__init__()
10 | self.meta = {'mean': [131.45376586914062, 103.98748016357422, 91.46234893798828],
11 | 'std': [1, 1, 1],
12 | 'imageSize': [224, 224, 3]}
13 | self.mean = torch.FloatTensor(self.meta['mean']).unsqueeze(0).unsqueeze(2).unsqueeze(3).cuda()
14 | self.resize = torch.nn.UpsamplingBilinear2d(size=(224,224))
15 | self.conv1 = nn.Conv2d(3, 96, kernel_size=[7, 7], stride=(2, 2))
16 | self.bn49 = nn.BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
17 | self.relu1 = nn.ReLU()
18 | self.pool1 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
19 | self.conv2 = nn.Conv2d(96, 256, kernel_size=[5, 5], stride=(2, 2), padding=(1, 1))
20 | self.bn50 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
21 | self.relu2 = nn.ReLU()
22 | self.pool2 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=(0, 0), dilation=1, ceil_mode=True)
23 | self.conv3 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
24 | self.bn51 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
25 | self.relu3 = nn.ReLU()
26 | self.conv4 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
27 | self.bn52 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
28 | self.relu4 = nn.ReLU()
29 | self.conv5 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
30 | self.bn53 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
31 | self.relu5 = nn.ReLU()
32 | self.pool5 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
33 | self.fc6 = nn.Conv2d(512, 4096, kernel_size=[6, 6], stride=(1, 1))
34 | self.bn54 = nn.BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
35 | self.relu6 = nn.ReLU()
36 | self.fc7 = nn.Conv2d(4096, 4096, kernel_size=[1, 1], stride=(1, 1))
37 | self.bn55 = nn.BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
38 | self.relu7 = nn.ReLU()
39 | self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True)
40 |
41 | for param in self.parameters():
42 | param.requires_grad = False
43 | # for m in self.modules():
44 | # if isinstance(m,nn.BatchNorm2d):
45 | # m = m.eval()
46 |
47 | def forward(self, x0):
48 | # self.zero_grad()
49 | x0 = self.resize(x0)
50 | x0 = x0 - self.mean.expand(x0.size())
51 | x1 = self.conv1(x0)
52 | x2 = self.bn49(x1)
53 | x3 = self.relu1(x2)
54 | x4 = self.pool1(x3)
55 | x5 = self.conv2(x4)
56 | x6 = self.bn50(x5)
57 | x7 = self.relu2(x6)
58 | x8 = self.pool2(x7)
59 | x9 = self.conv3(x8)
60 | x10 = self.bn51(x9)
61 | x11 = self.relu3(x10)
62 | x12 = self.conv4(x11)
63 | x13 = self.bn52(x12)
64 | x14 = self.relu4(x13)
65 | x15 = self.conv5(x14)
66 | x16 = self.bn53(x15)
67 | x17 = self.relu5(x16)
68 | x18 = self.pool5(x17)
69 | x19 = self.fc6(x18)
70 | x20 = self.bn54(x19)
71 | x21 = self.relu6(x20)
72 | x22 = self.fc7(x21)
73 | x23 = self.bn55(x22)
74 | x24_preflatten = self.relu7(x23)
75 | x24 = x24_preflatten.view(x24_preflatten.size(0), -1)
76 | x25 = self.fc8(x24)
77 | return [x1, x6, x11, x18, x25]
78 |
79 | def vgg_m_face_bn_dag(weights_path=None, **kwargs):
80 | """
81 | load imported model instance
82 |
83 | Args:
84 | weights_path (str): If set, loads model weights from the given path
85 | """
86 | model = Vgg_m_face_bn_dag()
87 | if weights_path:
88 | state_dict = torch.load(weights_path)
89 | model.load_state_dict(state_dict)
90 | return model
--------------------------------------------------------------------------------
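A minimal sketch of the feature extractor above. Note that __init__ pins the channel-mean buffer with .cuda(), so this assumes a CUDA device; without a weights_path the network runs with random weights, which is enough to check shapes but not to compute a meaningful perceptual loss.

import torch
from model.vgg_face import vgg_m_face_bn_dag

net = vgg_m_face_bn_dag(weights_path=None).cuda().eval()
x = torch.rand(1, 3, 256, 256).cuda() * 255.0   # pixel-range input; the model subtracts its own mean
feats = net(x)                                  # [x1, x6, x11, x18, x25], the taps used by VGGFaceLoss
for f in feats:
    print(f.shape)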
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from model.E import E
5 | from model.FaceSwapModel import GDModel
6 | from dataset.reader import get_loader
7 | from torch.optim import Adam
8 | from utils import update_learning_rate
9 | from torch import nn
10 | import visdom
11 | import cv2
12 |
13 | torch.manual_seed(0)
14 | torch.cuda.manual_seed_all(0)
15 |
16 | torch.backends.cudnn.benchmark = True
17 | torch.backends.cudnn.deterministic = True
18 | # vis = visdom.Visdom()
19 |
20 | clip_txt = './dataset/video_clips_path.txt'
21 | batchsize = 1
22 | num_workers = 4
23 | epoches = 60
24 | loader, num_classes = get_loader(clip_txt,batchsize,num_workers)
25 | E = E().cuda()
26 | E.train()
27 | model = GDModel(num_classes).cuda()
28 | model.train()
29 | for m in model.cntloss.vggfaceloss.vggface.modules():
30 | if isinstance(m,nn.BatchNorm2d):
31 | m.eval()
32 |
33 | # BatchNorm in training mode needs the per-batch mean and std, so the batch size cannot be 1; but with eval() gradients cannot be computed
34 | # model.cntloss.vggfaceloss.vggface.eval()
35 | # model.cntloss.vggloss.vgg.eval()
36 | #
37 |
38 | # mutilple GPUs
39 | # E = torch.nn.DataParallel(E,)
40 | # model = torch.nn.DataParallel(model)
41 |
42 |
43 | # define optim
44 | g_e_parameters = list(E.parameters())
45 | g_e_parameters += list(model.g.parameters())
46 |
47 | lr = 2e-4
48 | g_e_optim = Adam(g_e_parameters,lr = lr ,betas=(0.5,0.999) )
49 | d_optim = Adam(model.d.parameters(),lr = 5 * lr,betas=(0.5,0.999))
50 |
51 | global_step = 0
52 | for epoch in range(epoches): # 'epoch' rather than 'e': the embedding returned by E() below is also named e
53 | current_clip_number = 0
54 | print('current epoch is %d' % epoch)
55 | for d in loader:
56 | print('current_clip_number is %d' % current_clip_number)
57 | # calculate e, std_y, mean_y for adaptive instance norm
58 | data_for_e = d['imgs_e']
59 | data_for_e = torch.cat(data_for_e,0).cuda()
60 | landmark_for_e = d['landmarks_e']
61 | landmark_for_e = torch.cat(landmark_for_e,0).cuda()
62 | batch_data = d['imgs_training']
63 | batch_data = torch.cat(batch_data,0)
64 | batch_landmark = d['landmarks_training']
65 | batch_landmark = torch.cat(batch_landmark,0)
66 |
67 | # e,mean_y,std_y = E(torch.cat((data_for_e,landmark_for_e),1))
68 | # model.update_GDModel(mean_y,std_y,e)
69 |
70 | # print(data_for_e.size())
71 | # print(landmark_for_e.size())
72 | for b,l in zip(batch_data,batch_landmark): # b and l are 3-dim tensors
73 |
74 | e, mean_y, std_y = E(torch.cat((data_for_e, landmark_for_e), 1))
75 | model.update_GDModel(mean_y, std_y, e)
76 |
77 | global_step += 1
78 | b = b.unsqueeze(0).cuda()
79 | l = l.unsqueeze(0).cuda()
80 | y = torch.tensor(current_clip_number).long().cuda()
81 | fake_img, g_loss, d_loss = model(l,y,b)
82 |
83 | model.cntloss.vggfaceloss.vggface.zero_grad()
84 | model.cntloss.vggloss.vgg.zero_grad()
85 | g_e_optim.zero_grad()
86 | g_loss.backward(retain_graph=True)
87 | g_e_optim.step()
88 |
89 |
90 | d_optim.zero_grad()
91 | d_loss.backward()
92 | d_optim.step()
93 |
94 | if global_step%100==0:
95 | fake_img = np.transpose(np.uint8((fake_img.cpu().data.numpy()[0]/2.0 + 0.5)*255),[1,2,0])
96 | b_ = np.transpose(np.uint8((b.cpu().data.numpy()[0]/2.0 + 0.5)*255),[1,2,0])
97 | l_ = np.transpose(np.uint8((l.cpu().data.numpy()[0]/2.0 + 0.5)*255),[1,2,0])
98 | # temp = np.stack((fake_img,b_,l_)) # 3, 3 ,256,256
99 | temp = np.concatenate((fake_img[:,:,::-1],b_[:,:,::-1],l_[:,:,::-1]),axis=1)
100 | cv2.imwrite('./training_visual/temp_fake_gt_landmark_%d.jpg'%global_step,temp)
101 | # vis.images(temp,nrow=1,win='temp_results')
102 |
103 | print('***************')
104 | current_clip_number += 1
105 | if global_step % 50 == 0:
106 | saved = {'e':E.state_dict(),
107 | 'g_d': model.state_dict()}
108 | torch.save(saved,'./saved_models/e_g_d%d.pth'%global_step)
109 | if (epoch+1) % 10 == 0:
110 | lr = lr/2.0
111 | update_learning_rate(g_e_optim,lr)
112 | update_learning_rate(d_optim,5*lr)
113 |
--------------------------------------------------------------------------------
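The training loop in train.py conditions the generator through adaptive instance normalization: the embedder E digests a few (frame, landmark) pairs of the target identity into an embedding e plus per-channel statistics mean_y / std_y, and model.update_GDModel injects those statistics into the generator's normalization layers. A minimal sketch of the AdaIN operation itself is given below; the function name, tensor shapes, and the random inputs are illustrative assumptions, not the repository's actual E/G interface.

import torch

def adaptive_instance_norm(x, mean_y, std_y, eps=1e-5):
    # x:      a generator feature map, shape (N, C, H, W)
    # mean_y: per-channel means predicted by the embedder, shape (N, C)
    # std_y:  per-channel stds predicted by the embedder,  shape (N, C)
    mean_x = x.mean(dim=(2, 3), keepdim=True)   # instance statistics of the content features
    std_x = x.std(dim=(2, 3), keepdim=True) + eps
    normalized = (x - mean_x) / std_x           # whiten per channel and per sample
    return normalized * std_y[..., None, None] + mean_y[..., None, None]

# illustrative usage with random tensors
feat = torch.randn(1, 64, 32, 32)
mean_y = torch.randn(1, 64)
std_y = torch.rand(1, 64) + 0.5
print(adaptive_instance_norm(feat, mean_y, std_y).shape)  # torch.Size([1, 64, 32, 32])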
/training_visual/temp_fake_gt_landmark_100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_100.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1000.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1100.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1200.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1300.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1400.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1400.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1500.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1500.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1600.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1600.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1700.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1700.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1800.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_1900.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_1900.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_200.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_2000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_2000.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_2100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_2100.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_2200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_2200.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_2300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_2300.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_300.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_400.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_400.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_500.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_500.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_600.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_600.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_700.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_700.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_800.jpg
--------------------------------------------------------------------------------
/training_visual/temp_fake_gt_landmark_900.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shoutOutYangJie/Few-Shot-Adversarial-Learning-for-face-swap/927b10bcf066507caba470bdf66bf543fcb2520b/training_visual/temp_fake_gt_landmark_900.jpg
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from torch.optim import Adam
2 | from torchvision.models import vgg16
3 |
4 | def update_learning_rate(optim,lr):
5 | for param_group in optim.param_groups:
6 | param_group['lr'] = lr
7 |
8 |
9 | if __name__=='__main__':
10 | model = vgg16()
11 | optim = Adam(model.parameters(),lr=3)
12 | print(optim.param_groups[0]['lr'])
13 | update_learning_rate(optim,0.001)
14 | print(optim.param_groups[0]['lr'])
--------------------------------------------------------------------------------
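The helper defined here, update_learning_rate, is what train.py calls every 10 epochs to halve the encoder/generator learning rate while keeping the discriminator at five times that value. The same schedule could also be written with the built-in torch.optim.lr_scheduler.StepLR; the sketch below only illustrates that equivalence with two stand-in modules, so the names g, d and the toy loop are assumptions rather than code from this repository.

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

g = nn.Linear(8, 8)   # stand-in for the encoder/generator parameters
d = nn.Linear(8, 8)   # stand-in for the discriminator

base_lr = 2e-4
g_optim = Adam(g.parameters(), lr=base_lr, betas=(0.5, 0.999))
d_optim = Adam(d.parameters(), lr=5 * base_lr, betas=(0.5, 0.999))

# halve both rates every 10 epochs, preserving the 1:5 ratio used in train.py
g_sched = StepLR(g_optim, step_size=10, gamma=0.5)
d_sched = StepLR(d_optim, step_size=10, gamma=0.5)

for epoch in range(60):
    g_optim.step()   # stand-ins for one epoch of real parameter updates
    d_optim.step()
    g_sched.step()
    d_sched.step()

print(g_optim.param_groups[0]['lr'], d_optim.param_groups[0]['lr'])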