├── .gitignore
├── GPEN
│   ├── README.md
│   ├── __init_paths.py
│   ├── align_faces.py
│   ├── face_enhancement.py
│   ├── face_model
│   │   ├── face_gan.py
│   │   ├── model.py
│   │   └── op
│   │       ├── __init__.py
│   │       ├── fused_act.py
│   │       ├── fused_bias_act.cpp
│   │       ├── fused_bias_act_kernel.cu
│   │       ├── upfirdn2d.cpp
│   │       ├── upfirdn2d.py
│   │       └── upfirdn2d_kernel.cu
│   ├── requirements.txt
│   ├── retinaface
│   │   ├── .DS_Store
│   │   ├── data
│   │   │   ├── FDDB
│   │   │   │   └── img_list.txt
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── data_augment.py
│   │   │   └── wider_face.py
│   │   ├── facemodels
│   │   │   ├── __init__.py
│   │   │   ├── net.py
│   │   │   └── retinaface.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── functions
│   │   │   │   └── prior_box.py
│   │   │   └── modules
│   │   │       ├── __init__.py
│   │   │       └── multibox_loss.py
│   │   ├── retinaface_detection.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── box_utils.py
│   │       ├── nms
│   │       │   ├── __init__.py
│   │       │   └── py_cpu_nms.py
│   │       └── timer.py
│   └── sr_model
│       ├── arch_util.py
│       ├── real_esrnet.py
│       └── rrdbnet_arch.py
├── LICENSE.md
├── README.md
├── assets
│   ├── driving.mp4
│   └── source.jpg
├── camera_client.py
├── camera_local.py
├── demo_utils.py
├── face-vid2vid
│   ├── LICENSE.md
│   ├── README.md
│   ├── animate.py
│   ├── augmentation.py
│   ├── config
│   │   └── vox-256-spade.yaml
│   ├── crop-video.py
│   ├── demo.py
│   ├── frames_dataset.py
│   ├── logger.py
│   ├── modules
│   │   ├── dense_motion.py
│   │   ├── discriminator.py
│   │   ├── generator.py
│   │   ├── hopenet.py
│   │   ├── keypoint_detector.py
│   │   ├── model.py
│   │   └── util.py
│   ├── run.py
│   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── batchnorm.py
│   │   ├── comm.py
│   │   ├── replicate.py
│   │   └── unittest.py
│   └── train.py
├── remote_server.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 |
30 | *.pth
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/GPEN/README.md:
--------------------------------------------------------------------------------
1 | # GAN Prior Embedded Network for Blind Face Restoration in the Wild
2 |
3 | [Paper](https://arxiv.org/abs/2105.06070) | [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/GPEN-cvpr21-supp.pdf) | [Demo](https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace)
4 |
5 | [Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang)<sup>1</sup>, Peiran Ren<sup>1</sup>, Xuansong Xie<sup>1</sup>, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)<sup>1,2</sup>
6 | _<sup>1</sup>[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Hangzhou, China_
7 | _<sup>2</sup>[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk), Hong Kong, China_
8 |
9 | #### Face Restoration
10 |
11 |
12 |
13 |
14 |
15 |
16 | #### Face Colorization
17 |
18 |
19 |
20 | #### Face Inpainting
21 |
22 |
23 |
24 | #### Conditional Image Synthesis (Seg2Face)
25 |
26 |
27 |
28 | ## News
29 | (2021-07-06) The training code will be released soon. Stay tuned.
30 |
31 | (2021-10-11) The Colab demo for GPEN is available now.
32 |
33 | (2021-10-22) GPEN can now work with SR methods. A self-trained SR model is provided; replace it with your own model if necessary.
34 |
35 | ## Usage
36 |
37 | 
38 | 
39 | 
40 | 
41 | 
42 |
43 | - Clone this repository:
44 | ```bash
45 | git clone https://github.com/yangxy/GPEN.git
46 | cd GPEN
47 | ```
48 | - Download the RetinaFace model and our pre-trained models (not our best ones, due to commercial issues) and put them into ``weights/``.
49 |
50 | [RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth) | [rrdb_realesrnet_psnr](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/rrdb_realesrnet_psnr.pth)
51 |
52 | - Restore face images:
53 | ```bash
54 | python face_enhancement.py --model GPEN-BFR-512 --size 512 --channel_multiplier 2 --narrow 1 --use_sr --indir examples/imgs --outdir examples/outs-BFR
55 | ```
56 |
57 | - Colorize faces:
58 | ```bash
59 | python face_colorization.py
60 | ```
61 |
62 | - Complete faces:
63 | ```bash
64 | python face_inpainting.py
65 | ```
66 |
67 | - Synthesize faces:
68 | ```bash
69 | python segmentation2face.py
70 | ```
71 |
72 | ## Main idea
73 |
74 |
75 | ## Citation
76 | If our work is useful for your research, please consider citing:
77 |
78 | @inproceedings{Yang2021GPEN,
79 | title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
80 | author={Tao Yang and Peiran Ren and Xuansong Xie and Lei Zhang},
81 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
82 | year={2021}
83 | }
84 |
85 | ## License
86 | © Alibaba, 2021. For academic and non-commercial use only.
87 |
88 | ## Acknowledgments
89 | We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
90 |
91 | ## Contact
92 | If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com.
93 |
--------------------------------------------------------------------------------
/GPEN/__init_paths.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import os.path as osp
6 | import sys
7 |
8 | def add_path(path):
9 | if path not in sys.path:
10 | sys.path.insert(0, path)
11 |
12 | this_dir = osp.dirname(__file__)
13 |
14 | path = osp.join(this_dir, 'retinaface')
15 | add_path(path)
16 |
17 | path = osp.join(this_dir, 'face_model')
18 | add_path(path)
19 |
20 | path = osp.join(this_dir, 'sr_model')
21 | add_path(path)
--------------------------------------------------------------------------------
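The three `add_path` calls above are what let the GPEN sub-packages be imported by bare module names (e.g. `from model import FullGenerator` inside `face_model/face_gan.py`). A minimal sketch of the effect, assuming it is run from the `GPEN/` directory:

```python
# Illustrative only: shows what importing GPEN/__init_paths.py does to sys.path.
import sys

import __init_paths  # noqa: F401 -- importing the module runs the add_path() calls

# retinaface/, face_model/ and sr_model/ are now on sys.path, so their modules
# resolve by bare name (e.g. `import model`, `import face_gan`).
print([p for p in sys.path if p.endswith(("retinaface", "face_model", "sr_model"))])
```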
/GPEN/align_faces.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Apr 24 15:43:29 2017
4 | @author: zhaoy
5 | """
6 | """
7 | @Modified by yangxy (yangtao9009@gmail.com)
8 | """
9 | import cv2
10 | import numpy as np
11 | from skimage import transform as trans
12 |
13 | # reference facial points, a list of coordinates (x,y)
14 | REFERENCE_FACIAL_POINTS = [
15 | [30.29459953, 51.69630051],
16 | [65.53179932, 51.50139999],
17 | [48.02519989, 71.73660278],
18 | [33.54930115, 92.3655014],
19 | [62.72990036, 92.20410156],
20 | ]
21 |
22 | DEFAULT_CROP_SIZE = (96, 112)
23 |
24 |
25 | def _umeyama(src, dst, estimate_scale=True, scale=1.0):
26 | """Estimate N-D similarity transformation with or without scaling.
27 | Parameters
28 | ----------
29 | src : (M, N) array
30 | Source coordinates.
31 | dst : (M, N) array
32 | Destination coordinates.
33 | estimate_scale : bool
34 | Whether to estimate scaling factor.
35 | Returns
36 | -------
37 | T : (N + 1, N + 1)
38 | The homogeneous similarity transformation matrix. The matrix contains
39 | NaN values only if the problem is not well-conditioned.
40 | References
41 | ----------
42 | .. [1] "Least-squares estimation of transformation parameters between two
43 | point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
44 | """
45 |
46 | num = src.shape[0]
47 | dim = src.shape[1]
48 |
49 | # Compute mean of src and dst.
50 | src_mean = src.mean(axis=0)
51 | dst_mean = dst.mean(axis=0)
52 |
53 | # Subtract mean from src and dst.
54 | src_demean = src - src_mean
55 | dst_demean = dst - dst_mean
56 |
57 | # Eq. (38).
58 | A = dst_demean.T @ src_demean / num
59 |
60 | # Eq. (39).
61 | d = np.ones((dim,), dtype=np.double)
62 | if np.linalg.det(A) < 0:
63 | d[dim - 1] = -1
64 |
65 | T = np.eye(dim + 1, dtype=np.double)
66 |
67 | U, S, V = np.linalg.svd(A)
68 |
69 | # Eq. (40) and (43).
70 | rank = np.linalg.matrix_rank(A)
71 | if rank == 0:
72 | return np.nan * T, scale  # degenerate input: keep the (T, scale) return shape
73 | elif rank == dim - 1:
74 | if np.linalg.det(U) * np.linalg.det(V) > 0:
75 | T[:dim, :dim] = U @ V
76 | else:
77 | s = d[dim - 1]
78 | d[dim - 1] = -1
79 | T[:dim, :dim] = U @ np.diag(d) @ V
80 | d[dim - 1] = s
81 | else:
82 | T[:dim, :dim] = U @ np.diag(d) @ V
83 |
84 | if estimate_scale:
85 | # Eq. (41) and (42).
86 | scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
87 | else:
88 | scale = scale
89 |
90 | T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
91 | T[:dim, :dim] *= scale
92 |
93 | return T, scale
94 |
95 |
96 | class FaceWarpException(Exception):
97 | def __str__(self):
98 | return "In File {}:{}".format(__file__, super.__str__(self))
99 |
100 |
101 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
102 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
103 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
104 |
105 | # 0) make the inner region a square
106 | if default_square:
107 | size_diff = max(tmp_crop_size) - tmp_crop_size
108 | tmp_5pts += size_diff / 2
109 | tmp_crop_size += size_diff
110 |
111 | if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]:
112 | print("output_size == DEFAULT_CROP_SIZE {}: return default reference points".format(tmp_crop_size))
113 | return tmp_5pts
114 |
115 | if inner_padding_factor == 0 and outer_padding == (0, 0):
116 | if output_size is None:
117 | print("No paddings to do: return default reference points")
118 | return tmp_5pts
119 | else:
120 | raise FaceWarpException("No paddings to do, output_size must be None or {}".format(tmp_crop_size))
121 |
122 | # check output size
123 | if not (0 <= inner_padding_factor <= 1.0):
124 | raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)")
125 |
126 | if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None:
127 | output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
128 | output_size += np.array(outer_padding)
129 | print(" deduced from paddings, output_size = ", output_size)
130 |
131 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
132 | raise FaceWarpException("Not (outer_padding[0] < output_size[0]" "and outer_padding[1] < output_size[1])")
133 |
134 | # 1) pad the inner region according inner_padding_factor
135 | # print('---> STEP1: pad the inner region according inner_padding_factor')
136 | if inner_padding_factor > 0:
137 | size_diff = tmp_crop_size * inner_padding_factor * 2
138 | tmp_5pts += size_diff / 2
139 | tmp_crop_size += np.round(size_diff).astype(np.int32)
140 |
141 | # print(' crop_size = ', tmp_crop_size)
142 | # print(' reference_5pts = ', tmp_5pts)
143 |
144 | # 2) resize the padded inner region
145 | # print('---> STEP2: resize the padded inner region')
146 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
147 | # print(' crop_size = ', tmp_crop_size)
148 | # print(' size_bf_outer_pad = ', size_bf_outer_pad)
149 |
150 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
151 | raise FaceWarpException("Must have (output_size - outer_padding)" "= some_scale * (crop_size * (1.0 + inner_padding_factor)")
152 |
153 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
154 | # print(' resize scale_factor = ', scale_factor)
155 | tmp_5pts = tmp_5pts * scale_factor
156 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
157 | # tmp_5pts = tmp_5pts + size_diff / 2
158 | tmp_crop_size = size_bf_outer_pad
159 | # print(' crop_size = ', tmp_crop_size)
160 | # print(' reference_5pts = ', tmp_5pts)
161 |
162 | # 3) add outer_padding to make output_size
163 | reference_5point = tmp_5pts + np.array(outer_padding)
164 | tmp_crop_size = output_size
165 | # print('---> STEP3: add outer_padding to make output_size')
166 | # print(' crop_size = ', tmp_crop_size)
167 | # print(' reference_5pts = ', tmp_5pts)
168 | #
169 | # print('===> end get_reference_facial_points\n')
170 |
171 | return reference_5point
172 |
173 |
174 | def get_affine_transform_matrix(src_pts, dst_pts):
175 | tfm = np.float32([[1, 0, 0], [0, 1, 0]])
176 | n_pts = src_pts.shape[0]
177 | ones = np.ones((n_pts, 1), src_pts.dtype)
178 | src_pts_ = np.hstack([src_pts, ones])
179 | dst_pts_ = np.hstack([dst_pts, ones])
180 |
181 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
182 |
183 | if rank == 3:
184 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
185 | elif rank == 2:
186 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
187 |
188 | return tfm
189 |
190 |
191 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="similarity"):  # similarity | cv2_affine | affine
192 | if reference_pts is None:
193 | if crop_size[0] == 96 and crop_size[1] == 112:
194 | reference_pts = REFERENCE_FACIAL_POINTS
195 | else:
196 | default_square = False
197 | inner_padding_factor = 0
198 | outer_padding = (0, 0)
199 | output_size = crop_size
200 |
201 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, default_square)
202 | ref_pts = np.float32(reference_pts)
203 | ref_pts_shp = ref_pts.shape
204 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
205 | raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2")
206 |
207 | if ref_pts_shp[0] == 2:
208 | ref_pts = ref_pts.T
209 |
210 | src_pts = np.float32(facial_pts)
211 | src_pts_shp = src_pts.shape
212 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
213 | raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2")
214 |
215 | if src_pts_shp[0] == 2:
216 | src_pts = src_pts.T
217 |
218 | if src_pts.shape != ref_pts.shape:
219 | raise FaceWarpException("facial_pts and reference_pts must have the same shape")
220 |
221 | if align_type == "cv2_affine":
222 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
223 | tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
224 | elif align_type == "affine":
225 | tfm = get_affine_transform_matrix(src_pts, ref_pts)
226 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
227 | else:
228 | params, scale = _umeyama(src_pts, ref_pts)
229 | tfm = params[:2, :]
230 |
231 | params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
232 | tfm_inv = params[:2, :]
233 |
234 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
235 |
236 | return face_img, tfm_inv
237 |
--------------------------------------------------------------------------------
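To make the flow above concrete, here is a small hedged sketch (not part of the repo; the landmark coordinates and image path are made up) of how `warp_and_crop_face` is typically driven: five detected landmarks go in, an aligned 512x512 crop and the inverse affine for pasting the result back come out. This mirrors what `FaceEnhancement.process` does in `face_enhancement.py` below.

```python
# Hypothetical usage sketch for GPEN/align_faces.py (all values are illustrative).
import cv2
import numpy as np
from align_faces import get_reference_facial_points, warp_and_crop_face

size = 512
# Reference 5-point template for a 512x512 crop, matching FaceEnhancement's settings.
reference_5pts = get_reference_facial_points(
    (size, size), inner_padding_factor=0.25, outer_padding=(0, 0), default_square=True
)

img = cv2.imread("some_face.jpg")            # placeholder input image (BGR)
facial5points = np.array([                   # made-up detector output: (x, y) per landmark
    [180.0, 210.0], [330.0, 205.0],          # eye centers
    [255.0, 300.0],                          # nose tip
    [195.0, 380.0], [320.0, 375.0],          # mouth corners
])

# Similarity-align the face to the reference template; keep the inverse transform
# so the enhanced crop can later be warped back into the original frame.
aligned, tfm_inv = warp_and_crop_face(
    img, facial5points, reference_pts=reference_5pts, crop_size=(size, size)
)
pasted_back = cv2.warpAffine(aligned, tfm_inv, (img.shape[1], img.shape[0]), flags=3)
```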
/GPEN/face_enhancement.py:
--------------------------------------------------------------------------------
1 | """
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | """
5 | import os
6 | import cv2
7 | import glob
8 | import time
9 | import argparse
10 | import numpy as np
11 | from PIL import Image
12 | from skimage import transform as tf
13 |
14 | import GPEN.__init_paths as init_paths
15 | from GPEN.retinaface.retinaface_detection import RetinaFaceDetection
16 | from GPEN.face_model.face_gan import FaceGAN
17 | from GPEN.sr_model.real_esrnet import RealESRNet
18 | from GPEN.align_faces import warp_and_crop_face, get_reference_facial_points
19 |
20 | def check_ckpts(model, sr_model):
21 | # check if checkpoints are downloaded
22 | try:
23 | ckpts_folder = os.path.join(os.path.dirname(__file__), "weights")
24 | if not os.path.exists(ckpts_folder):
25 | print("Downloading checkpoints...")
26 | from gdown import download_folder
27 | file_id = "1epln5c8HW1QXfVz6444Fe0hG-vRNavi6"
28 | download_folder(id=file_id, output=ckpts_folder, quiet=False, use_cookies=False)
29 | else:
30 | print("Checkpoints already downloaded, skipping...")
31 | except Exception as e:
32 | print(e)
33 | raise Exception("Error while downloading checkpoints")
34 |
35 |
36 | class FaceEnhancement(object):
37 | def __init__(self, base_dir=os.path.dirname(__file__), size=512, model=None, use_sr=True, sr_model=None, channel_multiplier=2, narrow=1, use_facegan=True):
38 | check_ckpts(model, sr_model)
39 |
40 | self.facedetector = RetinaFaceDetection(base_dir)
41 | self.facegan = FaceGAN(base_dir, size, model, channel_multiplier, narrow)
42 | self.srmodel = RealESRNet(base_dir, sr_model)
43 | self.use_sr = use_sr
44 | self.size = size
45 | self.threshold = 0.9
46 | self.use_facegan = use_facegan
47 |
48 | # the mask for pasting restored faces back
49 | self.mask = np.zeros((512, 512), np.float32)
50 | cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA)
51 | self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
52 | self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
53 |
54 | self.kernel = np.array(([0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]), dtype="float32")
55 |
56 | # get the reference 5 landmarks position in the crop settings
57 | default_square = True
58 | inner_padding_factor = 0.25
59 | outer_padding = (0, 0)
60 | self.reference_5pts = get_reference_facial_points((self.size, self.size), inner_padding_factor, outer_padding, default_square)
61 |
62 | def process(self, img):
63 | if self.use_sr:
64 | img_sr = self.srmodel.process(img)
65 | if img_sr is not None:
66 | img = cv2.resize(img, img_sr.shape[:2][::-1])
67 |
68 | facebs, landms = self.facedetector.detect(img)
69 |
70 | orig_faces, enhanced_faces = [], []
71 | height, width = img.shape[:2]
72 | full_mask = np.zeros((height, width), dtype=np.float32)
73 | full_img = np.zeros(img.shape, dtype=np.uint8)
74 |
75 | for i, (faceb, facial5points) in enumerate(zip(facebs, landms)):
76 | if faceb[4] < self.threshold:
77 | continue
78 | fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0])
79 |
80 | facial5points = np.reshape(facial5points, (2, 5))
81 |
82 | of, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts, crop_size=(self.size, self.size))
83 |
84 | # enhance the face
85 | ef = self.facegan.process(of) if self.use_facegan else of
86 |
87 | orig_faces.append(of)
88 | enhanced_faces.append(ef)
89 |
90 | tmp_mask = self.mask
91 | tmp_mask = cv2.resize(tmp_mask, ef.shape[:2])
92 | tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3)
93 |
94 | if min(fh, fw) < 100: # gaussian filter for small faces
95 | ef = cv2.filter2D(ef, -1, self.kernel)
96 |
97 | tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3)
98 |
99 | mask = tmp_mask - full_mask
100 | full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)]
101 | full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)]
102 |
103 | full_mask = full_mask[:, :, np.newaxis]
104 | if self.use_sr and img_sr is not None:
105 | img = cv2.convertScaleAbs(img_sr * (1 - full_mask) + full_img * full_mask)
106 | else:
107 | img = cv2.convertScaleAbs(img * (1 - full_mask) + full_img * full_mask)
108 |
109 | return img, orig_faces, enhanced_faces
110 |
111 |
112 | if __name__ == "__main__":
113 | parser = argparse.ArgumentParser()
114 | parser.add_argument("--model", type=str, default="GPEN-BFR-512", help="GPEN model")
115 | parser.add_argument("--size", type=int, default=512, help="resolution of GPEN")
116 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of GPEN")
117 | parser.add_argument("--narrow", type=float, default=1, help="channel narrow scale")
118 | parser.add_argument("--use_sr", action="store_true", help="use sr or not")
119 | parser.add_argument("--sr_model", type=str, default="realesrnet_x2", help="SR model")
120 | parser.add_argument("--sr_scale", type=int, default=2, help="SR scale")
121 | parser.add_argument("--indir", type=str, default="examples/imgs", help="input folder")
122 | parser.add_argument("--outdir", type=str, default="results/outs-BFR", help="output folder")
123 | args = parser.parse_args()
124 |
125 | # model = {'name':'GPEN-BFR-512', 'size':512, 'channel_multiplier':2, 'narrow':1}
126 | # model = {'name':'GPEN-BFR-256', 'size':256, 'channel_multiplier':1, 'narrow':0.5}
127 |
128 | os.makedirs(args.outdir, exist_ok=True)
129 |
130 | faceenhancer = FaceEnhancement(
131 | size=args.size,
132 | model=args.model,
133 | use_sr=args.use_sr,
134 | sr_model=args.sr_model,
135 | channel_multiplier=args.channel_multiplier,
136 | narrow=args.narrow,
137 | )
138 |
139 | files = sorted(glob.glob(os.path.join(args.indir, "*.*g")))
140 | for n, file in enumerate(files[:]):
141 | filename = os.path.basename(file)
142 |
143 | im = cv2.imread(file, cv2.IMREAD_COLOR) # BGR
144 | if not isinstance(im, np.ndarray):
145 | print(filename, "error")
146 | continue
147 | # im = cv2.resize(im, (0,0), fx=2, fy=2) # optional
148 |
149 | img, orig_faces, enhanced_faces = faceenhancer.process(im)
150 |
151 | im = cv2.resize(im, img.shape[:2][::-1])
152 | cv2.imwrite(os.path.join(args.outdir, ".".join(filename.split(".")[:-1]) + "_COMP.jpg"), np.hstack((im, img)))
153 | cv2.imwrite(os.path.join(args.outdir, ".".join(filename.split(".")[:-1]) + "_GPEN.jpg"), img)
154 |
155 | for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)):
156 | of = cv2.resize(of, ef.shape[:2])
157 | cv2.imwrite(os.path.join(args.outdir, ".".join(filename.split(".")[:-1]) + "_face%02d" % m + ".jpg"), np.hstack((of, ef)))
158 |
159 | if n % 10 == 0:
160 | print(n, filename)
161 |
--------------------------------------------------------------------------------
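For completeness, a hedged sketch of driving `FaceEnhancement` programmatically, as the camera/server scripts in this repo do. It assumes a CUDA GPU, that the script is run from the repository root (so `GPEN` resolves as a package), that the weights download in `check_ckpts` succeeds, and that the SR checkpoint name below matches a file in `weights/`; the input path is a placeholder.

```python
# Hypothetical programmatic use of FaceEnhancement (paths/checkpoint names are assumptions).
import cv2

from GPEN.face_enhancement import FaceEnhancement

enhancer = FaceEnhancement(
    size=512,
    model="GPEN-BFR-512",
    use_sr=False,                 # skip RealESRNet background super-resolution
    sr_model="realesrnet_x2",     # still loaded at construction; this name must exist in weights/
    channel_multiplier=2,
    narrow=1,
)

frame = cv2.imread("examples/imgs/some_input.jpg")   # placeholder BGR input
restored, orig_faces, enhanced_faces = enhancer.process(frame)
cv2.imwrite("restored.jpg", restored)
```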
/GPEN/face_model/face_gan.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import torch
6 | import os
7 | import cv2
8 | import glob
9 | import numpy as np
10 | from torch import nn
11 | import torch.nn.functional as F
12 | from torchvision import transforms, utils
13 | from model import FullGenerator
15 |
16 | class FaceGAN(object):
17 | def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1, is_norm=True):
18 | self.mfile = os.path.join(base_dir, 'weights', model+'.pth')
19 | self.n_mlp = 8
20 | self.is_norm = is_norm
21 | self.resolution = size
22 | self.load_model(channel_multiplier, narrow)
23 |
24 | def load_model(self, channel_multiplier=2, narrow=1):
25 | self.model = FullGenerator(self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow).cuda()
26 | pretrained_dict = torch.load(self.mfile)
27 | self.model.load_state_dict(pretrained_dict)
28 | self.model.eval()
29 |
30 | def process(self, img):
31 | img = cv2.resize(img, (self.resolution, self.resolution))
32 | img_t = self.img2tensor(img)
33 |
34 | with torch.no_grad():
35 | out, __ = self.model(img_t)
36 |
37 | out = self.tensor2img(out)
38 |
39 | return out
40 |
41 | def img2tensor(self, img):
42 | img_t = torch.from_numpy(img).cuda()/255.
43 | if self.is_norm:
44 | img_t = (img_t - 0.5) / 0.5
45 | img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB
46 | return img_t
47 |
48 | def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
49 | if self.is_norm:
50 | img_t = img_t * 0.5 + 0.5
51 | img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
52 | img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax
53 |
54 | return img_np.astype(imtype)
55 |
--------------------------------------------------------------------------------
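A standalone restatement (illustrative, CPU-only; the real methods also move data to the GPU with `.cuda()`) of the `img2tensor` / `tensor2img` convention above: BGR uint8 in [0, 255] maps to RGB float in [-1, 1] and back, with `flip` handling the channel-order swap.

```python
# Round-trip demo of the normalization used by FaceGAN.img2tensor / tensor2img.
import numpy as np
import torch

img = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)        # fake BGR uint8 crop

t = torch.from_numpy(img) / 255.0                              # uint8 -> float in [0, 1]
t = (t - 0.5) / 0.5                                            # is_norm=True: map to [-1, 1]
t = t.permute(2, 0, 1).unsqueeze(0).flip(1)                    # HWC BGR -> NCHW RGB

back = (t * 0.5 + 0.5).squeeze(0).permute(1, 2, 0).flip(2)     # NCHW RGB -> HWC BGR, [0, 1]
back = (np.clip(back.numpy(), 0, 1) * 255).astype(np.uint8)    # truncating cast, as in tensor2img

print(np.abs(img.astype(int) - back.astype(int)).max())        # 0 or 1 (float truncation only)
```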
/GPEN/face_model/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/GPEN/face_model/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load, _import_module_from_library
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | fused = load(
11 | 'fused',
12 | sources=[
13 | os.path.join(module_path, 'fused_bias_act.cpp'),
14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'),
15 | ],
16 | )
17 |
18 | #fused = _import_module_from_library('fused', '/tmp/torch_extensions/fused', True)
19 |
20 |
21 | class FusedLeakyReLUFunctionBackward(Function):
22 | @staticmethod
23 | def forward(ctx, grad_output, out, negative_slope, scale):
24 | ctx.save_for_backward(out)
25 | ctx.negative_slope = negative_slope
26 | ctx.scale = scale
27 |
28 | empty = grad_output.new_empty(0)
29 |
30 | grad_input = fused.fused_bias_act(
31 | grad_output, empty, out, 3, 1, negative_slope, scale
32 | )
33 |
34 | dim = [0]
35 |
36 | if grad_input.ndim > 2:
37 | dim += list(range(2, grad_input.ndim))
38 |
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | return grad_input, grad_bias
42 |
43 | @staticmethod
44 | def backward(ctx, gradgrad_input, gradgrad_bias):
45 | out, = ctx.saved_tensors
46 | gradgrad_out = fused.fused_bias_act(
47 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
48 | )
49 |
50 | return gradgrad_out, None, None, None
51 |
52 |
53 | class FusedLeakyReLUFunction(Function):
54 | @staticmethod
55 | def forward(ctx, input, bias, negative_slope, scale):
56 | empty = input.new_empty(0)
57 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
58 | ctx.save_for_backward(out)
59 | ctx.negative_slope = negative_slope
60 | ctx.scale = scale
61 |
62 | return out
63 |
64 | @staticmethod
65 | def backward(ctx, grad_output):
66 | out, = ctx.saved_tensors
67 |
68 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
69 | grad_output, out, ctx.negative_slope, ctx.scale
70 | )
71 |
72 | return grad_input, grad_bias, None, None
73 |
74 |
75 | class FusedLeakyReLU(nn.Module):
76 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
77 | super().__init__()
78 |
79 | self.bias = nn.Parameter(torch.zeros(channel))
80 | self.negative_slope = negative_slope
81 | self.scale = scale
82 |
83 | def forward(self, input):
84 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
85 |
86 |
87 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
88 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
89 |
--------------------------------------------------------------------------------
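For readers without the CUDA toolchain, a hedged CPU restatement (my own sketch, not the repo's API) of what the fused op computes in the forward pass for 4-D inputs: add a per-channel bias, apply leaky ReLU, and multiply by a constant scale (sqrt(2) by default, the StyleGAN2 variance-preserving convention). The custom kernel additionally implements the backward passes, which this reference omits.

```python
# CPU reference for the forward semantics of fused_leaky_relu (illustrative only).
import torch
import torch.nn.functional as F

def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Per-channel bias broadcast over (N, C, H, W), then LeakyReLU, then constant scaling.
    return F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope) * scale

x = torch.randn(2, 8, 16, 16)
b = torch.randn(8)
print(fused_leaky_relu_ref(x, b).shape)  # torch.Size([2, 8, 16, 16])
```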
/GPEN/face_model/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/GPEN/face_model/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/GPEN/face_model/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/GPEN/face_model/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.autograd import Function
4 | from torch.nn import functional as F  # needed by upfirdn2d_native below (F.pad / F.conv2d)
5 | from torch.utils.cpp_extension import load, _import_module_from_library
6 |
7 |
8 | module_path = os.path.dirname(__file__)
9 | upfirdn2d_op = load(
10 | 'upfirdn2d',
11 | sources=[
12 | os.path.join(module_path, 'upfirdn2d.cpp'),
13 | os.path.join(module_path, 'upfirdn2d_kernel.cu'),
14 | ],
15 | )
16 |
17 | #upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True)
18 |
19 | class UpFirDn2dBackward(Function):
20 | @staticmethod
21 | def forward(
22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23 | ):
24 |
25 | up_x, up_y = up
26 | down_x, down_y = down
27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
28 |
29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
30 |
31 | grad_input = upfirdn2d_op.upfirdn2d(
32 | grad_output,
33 | grad_kernel,
34 | down_x,
35 | down_y,
36 | up_x,
37 | up_y,
38 | g_pad_x0,
39 | g_pad_x1,
40 | g_pad_y0,
41 | g_pad_y1,
42 | )
43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
44 |
45 | ctx.save_for_backward(kernel)
46 |
47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
48 |
49 | ctx.up_x = up_x
50 | ctx.up_y = up_y
51 | ctx.down_x = down_x
52 | ctx.down_y = down_y
53 | ctx.pad_x0 = pad_x0
54 | ctx.pad_x1 = pad_x1
55 | ctx.pad_y0 = pad_y0
56 | ctx.pad_y1 = pad_y1
57 | ctx.in_size = in_size
58 | ctx.out_size = out_size
59 |
60 | return grad_input
61 |
62 | @staticmethod
63 | def backward(ctx, gradgrad_input):
64 | kernel, = ctx.saved_tensors
65 |
66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
67 |
68 | gradgrad_out = upfirdn2d_op.upfirdn2d(
69 | gradgrad_input,
70 | kernel,
71 | ctx.up_x,
72 | ctx.up_y,
73 | ctx.down_x,
74 | ctx.down_y,
75 | ctx.pad_x0,
76 | ctx.pad_x1,
77 | ctx.pad_y0,
78 | ctx.pad_y1,
79 | )
80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
81 | gradgrad_out = gradgrad_out.view(
82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
83 | )
84 |
85 | return gradgrad_out, None, None, None, None, None, None, None, None
86 |
87 |
88 | class UpFirDn2d(Function):
89 | @staticmethod
90 | def forward(ctx, input, kernel, up, down, pad):
91 | up_x, up_y = up
92 | down_x, down_y = down
93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
94 |
95 | kernel_h, kernel_w = kernel.shape
96 | batch, channel, in_h, in_w = input.shape
97 | ctx.in_size = input.shape
98 |
99 | input = input.reshape(-1, in_h, in_w, 1)
100 |
101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
102 |
103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
105 | ctx.out_size = (out_h, out_w)
106 |
107 | ctx.up = (up_x, up_y)
108 | ctx.down = (down_x, down_y)
109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
110 |
111 | g_pad_x0 = kernel_w - pad_x0 - 1
112 | g_pad_y0 = kernel_h - pad_y0 - 1
113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
115 |
116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
117 |
118 | out = upfirdn2d_op.upfirdn2d(
119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
120 | )
121 | # out = out.view(major, out_h, out_w, minor)
122 | out = out.view(-1, channel, out_h, out_w)
123 |
124 | return out
125 |
126 | @staticmethod
127 | def backward(ctx, grad_output):
128 | kernel, grad_kernel = ctx.saved_tensors
129 |
130 | grad_input = UpFirDn2dBackward.apply(
131 | grad_output,
132 | kernel,
133 | grad_kernel,
134 | ctx.up,
135 | ctx.down,
136 | ctx.pad,
137 | ctx.g_pad,
138 | ctx.in_size,
139 | ctx.out_size,
140 | )
141 |
142 | return grad_input, None, None, None, None
143 |
144 |
145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
146 | out = UpFirDn2d.apply(
147 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
148 | )
149 |
150 | return out
151 |
152 |
153 | def upfirdn2d_native(
154 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
155 | ):
156 | _, in_h, in_w, minor = input.shape
157 | kernel_h, kernel_w = kernel.shape
158 |
159 | out = input.view(-1, in_h, 1, in_w, 1, minor)
160 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
161 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
162 |
163 | out = F.pad(
164 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
165 | )
166 | out = out[
167 | :,
168 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
169 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
170 | :,
171 | ]
172 |
173 | out = out.permute(0, 3, 1, 2)
174 | out = out.reshape(
175 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
176 | )
177 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
178 | out = F.conv2d(out, w)
179 | out = out.reshape(
180 | -1,
181 | minor,
182 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
183 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
184 | )
185 | out = out.permute(0, 2, 3, 1)
186 |
187 | return out[:, ::down_y, ::down_x, :]
188 |
189 |
--------------------------------------------------------------------------------
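A hedged usage sketch for the `upfirdn2d` wrapper above, covering the classic StyleGAN2 case of 2x upsampling through a separable binomial blur. It assumes the CUDA extension builds (ninja + nvcc + a GPU) and that `op/` is importable; the `make_kernel` helper is borrowed from stylegan2-pytorch and is an assumption here, not part of this file.

```python
# Illustrative use of the upfirdn2d wrapper (the CUDA extension is compiled on first import).
import torch
from upfirdn2d import upfirdn2d

def make_kernel(k):
    # 1-D binomial tap list -> normalized 2-D separable kernel (as in stylegan2-pytorch).
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    return k / k.sum()

x = torch.randn(1, 3, 64, 64, device="cuda")
kernel = make_kernel([1, 3, 3, 1]).to("cuda")

# 2x upsample: zero-stuff, low-pass with the 4-tap blur, pad so the output is exactly 128x128.
kernel_up = kernel * 4                       # compensate for the energy lost when zero-stuffing
p = kernel.shape[0] - 2                      # = 2 for a 4-tap kernel and factor 2
up = upfirdn2d(x, kernel_up, up=2, down=1, pad=((p + 1) // 2 + 1, p // 2))
print(up.shape)  # torch.Size([1, 3, 128, 128])
```

The output size follows `out = (in * up + pad0 + pad1 - kernel) // down + 1`, which is also how `UpFirDn2d.forward` computes `ctx.out_size`.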
/GPEN/face_model/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19 | int c = a / b;
20 |
21 | if (c * b > a) {
22 | c--;
23 | }
24 |
25 | return c;
26 | }
27 |
28 |
29 | struct UpFirDn2DKernelParams {
30 | int up_x;
31 | int up_y;
32 | int down_x;
33 | int down_y;
34 | int pad_x0;
35 | int pad_x1;
36 | int pad_y0;
37 | int pad_y1;
38 |
39 | int major_dim;
40 | int in_h;
41 | int in_w;
42 | int minor_dim;
43 | int kernel_h;
44 | int kernel_w;
45 | int out_h;
46 | int out_w;
47 | int loop_major;
48 | int loop_x;
49 | };
50 |
51 |
52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56 |
57 | __shared__ volatile float sk[kernel_h][kernel_w];
58 | __shared__ volatile float sx[tile_in_h][tile_in_w];
59 |
60 | int minor_idx = blockIdx.x;
61 | int tile_out_y = minor_idx / p.minor_dim;
62 | minor_idx -= tile_out_y * p.minor_dim;
63 | tile_out_y *= tile_out_h;
64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65 | int major_idx_base = blockIdx.z * p.loop_major;
66 |
67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68 | return;
69 | }
70 |
71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72 | int ky = tap_idx / kernel_w;
73 | int kx = tap_idx - ky * kernel_w;
74 | scalar_t v = 0.0;
75 |
76 | if (kx < p.kernel_w & ky < p.kernel_h) {
77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78 | }
79 |
80 | sk[ky][kx] = v;
81 | }
82 |
83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87 | int tile_in_x = floor_div(tile_mid_x, up_x);
88 | int tile_in_y = floor_div(tile_mid_y, up_y);
89 |
90 | __syncthreads();
91 |
92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93 | int rel_in_y = in_idx / tile_in_w;
94 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
95 | int in_x = rel_in_x + tile_in_x;
96 | int in_y = rel_in_y + tile_in_y;
97 |
98 | scalar_t v = 0.0;
99 |
100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102 | }
103 |
104 | sx[rel_in_y][rel_in_x] = v;
105 | }
106 |
107 | __syncthreads();
108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109 | int rel_out_y = out_idx / tile_out_w;
110 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
111 | int out_x = rel_out_x + tile_out_x;
112 | int out_y = rel_out_y + tile_out_y;
113 |
114 | int mid_x = tile_mid_x + rel_out_x * down_x;
115 | int mid_y = tile_mid_y + rel_out_y * down_y;
116 | int in_x = floor_div(mid_x, up_x);
117 | int in_y = floor_div(mid_y, up_y);
118 | int rel_in_x = in_x - tile_in_x;
119 | int rel_in_y = in_y - tile_in_y;
120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122 |
123 | scalar_t v = 0.0;
124 |
125 | #pragma unroll
126 | for (int y = 0; y < kernel_h / up_y; y++)
127 | #pragma unroll
128 | for (int x = 0; x < kernel_w / up_x; x++)
129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130 |
131 | if (out_x < p.out_w & out_y < p.out_h) {
132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133 | }
134 | }
135 | }
136 | }
137 | }
138 |
139 |
140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141 | int up_x, int up_y, int down_x, int down_y,
142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143 | int curDevice = -1;
144 | cudaGetDevice(&curDevice);
145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146 |
147 | UpFirDn2DKernelParams p;
148 |
149 | auto x = input.contiguous();
150 | auto k = kernel.contiguous();
151 |
152 | p.major_dim = x.size(0);
153 | p.in_h = x.size(1);
154 | p.in_w = x.size(2);
155 | p.minor_dim = x.size(3);
156 | p.kernel_h = k.size(0);
157 | p.kernel_w = k.size(1);
158 | p.up_x = up_x;
159 | p.up_y = up_y;
160 | p.down_x = down_x;
161 | p.down_y = down_y;
162 | p.pad_x0 = pad_x0;
163 | p.pad_x1 = pad_x1;
164 | p.pad_y0 = pad_y0;
165 | p.pad_y1 = pad_y1;
166 |
167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169 |
170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171 |
172 | int mode = -1;
173 |
174 | int tile_out_h;
175 | int tile_out_w;
176 |
177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178 | mode = 1;
179 | tile_out_h = 16;
180 | tile_out_w = 64;
181 | }
182 |
183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184 | mode = 2;
185 | tile_out_h = 16;
186 | tile_out_w = 64;
187 | }
188 |
189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190 | mode = 3;
191 | tile_out_h = 16;
192 | tile_out_w = 64;
193 | }
194 |
195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196 | mode = 4;
197 | tile_out_h = 16;
198 | tile_out_w = 64;
199 | }
200 |
201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202 | mode = 5;
203 | tile_out_h = 8;
204 | tile_out_w = 32;
205 | }
206 |
207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208 | mode = 6;
209 | tile_out_h = 8;
210 | tile_out_w = 32;
211 | }
212 |
213 | dim3 block_size;
214 | dim3 grid_size;
215 |
216 | if (tile_out_h > 0 && tile_out_w) {
217 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
218 | p.loop_x = 1;
219 | block_size = dim3(32 * 8, 1, 1);
220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222 | (p.major_dim - 1) / p.loop_major + 1);
223 | }
224 |
225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226 | switch (mode) {
227 | case 1:
228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230 | );
231 |
232 | break;
233 |
234 | case 2:
235 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237 | );
238 |
239 | break;
240 |
241 | case 3:
242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244 | );
245 |
246 | break;
247 |
248 | case 4:
249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251 | );
252 |
253 | break;
254 |
255 | case 5:
256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258 | );
259 |
260 | break;
261 |
262 | case 6:
263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265 | );
266 |
267 | break;
268 | }
269 | });
270 |
271 | return out;
272 | }
--------------------------------------------------------------------------------
/GPEN/requirements.txt:
--------------------------------------------------------------------------------
1 | ninja
2 | torch
3 | torchvision
4 | opencv-python
5 | numpy
6 | scikit-image
7 | pillow
8 | pyyaml==5.4.1
--------------------------------------------------------------------------------
/GPEN/retinaface/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/.DS_Store
--------------------------------------------------------------------------------
/GPEN/retinaface/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .wider_face import WiderFaceDetection, detection_collate
2 | from .data_augment import *
3 | from .config import *
4 |
--------------------------------------------------------------------------------
/GPEN/retinaface/data/config.py:
--------------------------------------------------------------------------------
1 | # config.py
2 |
3 | cfg_mnet = {
4 | 'name': 'mobilenet0.25',
5 | 'min_sizes': [[16, 32], [64, 128], [256, 512]],
6 | 'steps': [8, 16, 32],
7 | 'variance': [0.1, 0.2],
8 | 'clip': False,
9 | 'loc_weight': 2.0,
10 | 'gpu_train': True,
11 | 'batch_size': 32,
12 | 'ngpu': 1,
13 | 'epoch': 250,
14 | 'decay1': 190,
15 | 'decay2': 220,
16 | 'image_size': 640,
17 | 'pretrain': False,
18 | 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
19 | 'in_channel': 32,
20 | 'out_channel': 64
21 | }
22 |
23 | cfg_re50 = {
24 | 'name': 'Resnet50',
25 | 'min_sizes': [[16, 32], [64, 128], [256, 512]],
26 | 'steps': [8, 16, 32],
27 | 'variance': [0.1, 0.2],
28 | 'clip': False,
29 | 'loc_weight': 2.0,
30 | 'gpu_train': True,
31 | 'batch_size': 24,
32 | 'ngpu': 4,
33 | 'epoch': 100,
34 | 'decay1': 70,
35 | 'decay2': 90,
36 | 'image_size': 840,
37 | 'pretrain': False,
38 | 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
39 | 'in_channel': 256,
40 | 'out_channel': 256
41 | }
42 |
43 |
--------------------------------------------------------------------------------
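These dictionaries feed the RetinaFace prior/anchor generator (`layers/functions/prior_box.py`): each list in `min_sizes` is attached to the feature map whose stride appears at the same index in `steps`. As a hedged sanity check (my own arithmetic, not code from the repo), the anchor count for a 640x640 input under `cfg_mnet` works out as below.

```python
# Rough anchor-count check for cfg_mnet at a 640x640 input (illustrative arithmetic).
from math import ceil

cfg = {
    "min_sizes": [[16, 32], [64, 128], [256, 512]],
    "steps": [8, 16, 32],
}
image_size = 640

total = 0
for step, sizes in zip(cfg["steps"], cfg["min_sizes"]):
    fm = ceil(image_size / step)          # feature-map side for this stride
    total += fm * fm * len(sizes)         # one anchor per cell per min_size

print(total)  # 16800 = (80*80 + 40*40 + 20*20) * 2
```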
/GPEN/retinaface/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | from utils.box_utils import matrix_iof
5 |
6 |
7 | def _crop(image, boxes, labels, landm, img_dim):
8 | height, width, _ = image.shape
9 | pad_image_flag = True
10 |
11 | for _ in range(250):
12 | """
13 | if random.uniform(0, 1) <= 0.2:
14 | scale = 1.0
15 | else:
16 | scale = random.uniform(0.3, 1.0)
17 | """
18 | PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
19 | scale = random.choice(PRE_SCALES)
20 | short_side = min(width, height)
21 | w = int(scale * short_side)
22 | h = w
23 |
24 | if width == w:
25 | l = 0
26 | else:
27 | l = random.randrange(width - w)
28 | if height == h:
29 | t = 0
30 | else:
31 | t = random.randrange(height - h)
32 | roi = np.array((l, t, l + w, t + h))
33 |
34 | value = matrix_iof(boxes, roi[np.newaxis])
35 | flag = (value >= 1)
36 | if not flag.any():
37 | continue
38 |
39 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2
40 | mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
41 | boxes_t = boxes[mask_a].copy()
42 | labels_t = labels[mask_a].copy()
43 | landms_t = landm[mask_a].copy()
44 | landms_t = landms_t.reshape([-1, 5, 2])
45 |
46 | if boxes_t.shape[0] == 0:
47 | continue
48 |
49 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
50 |
51 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
52 | boxes_t[:, :2] -= roi[:2]
53 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
54 | boxes_t[:, 2:] -= roi[:2]
55 |
56 | # landm
57 | landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
58 | landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
59 | landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
60 | landms_t = landms_t.reshape([-1, 10])
61 |
62 |
63 | # make sure that the cropped image contains at least one face > 16 pixel at training image scale
64 | b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
65 | b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
66 | mask_b = np.minimum(b_w_t, b_h_t) > 0.0
67 | boxes_t = boxes_t[mask_b]
68 | labels_t = labels_t[mask_b]
69 | landms_t = landms_t[mask_b]
70 |
71 | if boxes_t.shape[0] == 0:
72 | continue
73 |
74 | pad_image_flag = False
75 |
76 | return image_t, boxes_t, labels_t, landms_t, pad_image_flag
77 | return image, boxes, labels, landm, pad_image_flag
78 |
79 |
80 | def _distort(image):
81 |
82 | def _convert(image, alpha=1, beta=0):
83 | tmp = image.astype(float) * alpha + beta
84 | tmp[tmp < 0] = 0
85 | tmp[tmp > 255] = 255
86 | image[:] = tmp
87 |
88 | image = image.copy()
89 |
90 | if random.randrange(2):
91 |
92 | #brightness distortion
93 | if random.randrange(2):
94 | _convert(image, beta=random.uniform(-32, 32))
95 |
96 | #contrast distortion
97 | if random.randrange(2):
98 | _convert(image, alpha=random.uniform(0.5, 1.5))
99 |
100 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
101 |
102 | #saturation distortion
103 | if random.randrange(2):
104 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
105 |
106 | #hue distortion
107 | if random.randrange(2):
108 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
109 | tmp %= 180
110 | image[:, :, 0] = tmp
111 |
112 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
113 |
114 | else:
115 |
116 | #brightness distortion
117 | if random.randrange(2):
118 | _convert(image, beta=random.uniform(-32, 32))
119 |
120 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
121 |
122 | #saturation distortion
123 | if random.randrange(2):
124 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
125 |
126 | #hue distortion
127 | if random.randrange(2):
128 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
129 | tmp %= 180
130 | image[:, :, 0] = tmp
131 |
132 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
133 |
134 | #contrast distortion
135 | if random.randrange(2):
136 | _convert(image, alpha=random.uniform(0.5, 1.5))
137 |
138 | return image
139 |
140 |
141 | def _expand(image, boxes, fill, p):
142 | if random.randrange(2):
143 | return image, boxes
144 |
145 | height, width, depth = image.shape
146 |
147 | scale = random.uniform(1, p)
148 | w = int(scale * width)
149 | h = int(scale * height)
150 |
151 | left = random.randint(0, w - width)
152 | top = random.randint(0, h - height)
153 |
154 | boxes_t = boxes.copy()
155 | boxes_t[:, :2] += (left, top)
156 | boxes_t[:, 2:] += (left, top)
157 | expand_image = np.empty(
158 | (h, w, depth),
159 | dtype=image.dtype)
160 | expand_image[:, :] = fill
161 | expand_image[top:top + height, left:left + width] = image
162 | image = expand_image
163 |
164 | return image, boxes_t
165 |
166 |
167 | def _mirror(image, boxes, landms):
168 | _, width, _ = image.shape
169 | if random.randrange(2):
170 | image = image[:, ::-1]
171 | boxes = boxes.copy()
172 | boxes[:, 0::2] = width - boxes[:, 2::-2]
173 |
174 | # landm
175 | landms = landms.copy()
176 | landms = landms.reshape([-1, 5, 2])
177 | landms[:, :, 0] = width - landms[:, :, 0]
178 | tmp = landms[:, 1, :].copy()
179 | landms[:, 1, :] = landms[:, 0, :]
180 | landms[:, 0, :] = tmp
181 | tmp1 = landms[:, 4, :].copy()
182 | landms[:, 4, :] = landms[:, 3, :]
183 | landms[:, 3, :] = tmp1
184 | landms = landms.reshape([-1, 10])
185 |
186 | return image, boxes, landms
187 |
188 |
189 | def _pad_to_square(image, rgb_mean, pad_image_flag):
190 | if not pad_image_flag:
191 | return image
192 | height, width, _ = image.shape
193 | long_side = max(width, height)
194 | image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
195 | image_t[:, :] = rgb_mean
196 | image_t[0:0 + height, 0:0 + width] = image
197 | return image_t
198 |
199 |
200 | def _resize_subtract_mean(image, insize, rgb_mean):
201 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
202 | interp_method = interp_methods[random.randrange(5)]
203 | image = cv2.resize(image, (insize, insize), interpolation=interp_method)
204 | image = image.astype(np.float32)
205 | image -= rgb_mean
206 | return image.transpose(2, 0, 1)
207 |
208 |
209 | class preproc(object):
210 |
211 | def __init__(self, img_dim, rgb_means):
212 | self.img_dim = img_dim
213 | self.rgb_means = rgb_means
214 |
215 | def __call__(self, image, targets):
216 | assert targets.shape[0] > 0, "this image does not have gt"
217 |
218 | boxes = targets[:, :4].copy()
219 | labels = targets[:, -1].copy()
220 | landm = targets[:, 4:-1].copy()
221 |
222 | image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
223 | image_t = _distort(image_t)
224 | image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag)
225 | image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
226 | height, width, _ = image_t.shape
227 | image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
228 | boxes_t[:, 0::2] /= width
229 | boxes_t[:, 1::2] /= height
230 |
231 | landm_t[:, 0::2] /= width
232 | landm_t[:, 1::2] /= height
233 |
234 | labels_t = np.expand_dims(labels_t, 1)
235 | targets_t = np.hstack((boxes_t, landm_t, labels_t))
236 |
237 | return image_t, targets_t
238 |
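For reference, here is a minimal, hedged sketch of running the `preproc` pipeline above on a single image/annotation pair. The image, box, landmark, and mean values are placeholders (not taken from this repository), and the sketch assumes this module's own imports (`cv2`, `random`):

```python
import numpy as np

# one face: [x1, y1, x2, y2, 10 landmark coords, label] -- the layout preproc.__call__ expects
targets = np.array([[100., 100., 200., 200.,
                     120., 130., 180., 130., 150., 160., 130., 180., 170., 180.,
                     1.]], dtype=np.float32)
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # dummy BGR frame

p = preproc(img_dim=640, rgb_means=(104, 117, 123))  # mean values assumed, matching the detection code
img_chw, targets_norm = p(image, targets)
# img_chw: (3, 640, 640) float32, mean-subtracted, CHW
# targets_norm: (num_kept_faces, 15) with box/landmark coordinates normalized to [0, 1]
```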
--------------------------------------------------------------------------------
/GPEN/retinaface/data/wider_face.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import sys
4 | import torch
5 | import torch.utils.data as data
6 | import cv2
7 | import numpy as np
8 |
9 | class WiderFaceDetection(data.Dataset):
10 | def __init__(self, txt_path, preproc=None):
11 | self.preproc = preproc
12 | self.imgs_path = []
13 | self.words = []
14 | f = open(txt_path,'r')
15 | lines = f.readlines()
16 | isFirst = True
17 | labels = []
18 | for line in lines:
19 | line = line.rstrip()
20 | if line.startswith('#'):
21 | if isFirst is True:
22 | isFirst = False
23 | else:
24 | labels_copy = labels.copy()
25 | self.words.append(labels_copy)
26 | labels.clear()
27 | path = line[2:]
28 | path = txt_path.replace('label.txt','images/') + path
29 | self.imgs_path.append(path)
30 | else:
31 | line = line.split(' ')
32 | label = [float(x) for x in line]
33 | labels.append(label)
34 |
35 | self.words.append(labels)
36 |
37 | def __len__(self):
38 | return len(self.imgs_path)
39 |
40 | def __getitem__(self, index):
41 | img = cv2.imread(self.imgs_path[index])
42 | height, width, _ = img.shape
43 |
44 | labels = self.words[index]
45 | annotations = np.zeros((0, 15))
46 | if len(labels) == 0:
47 | return annotations
48 | for idx, label in enumerate(labels):
49 | annotation = np.zeros((1, 15))
50 | # bbox
51 | annotation[0, 0] = label[0] # x1
52 | annotation[0, 1] = label[1] # y1
53 | annotation[0, 2] = label[0] + label[2] # x2
54 | annotation[0, 3] = label[1] + label[3] # y2
55 |
56 | # landmarks
57 | annotation[0, 4] = label[4] # l0_x
58 | annotation[0, 5] = label[5] # l0_y
59 | annotation[0, 6] = label[7] # l1_x
60 | annotation[0, 7] = label[8] # l1_y
61 | annotation[0, 8] = label[10] # l2_x
62 | annotation[0, 9] = label[11] # l2_y
63 | annotation[0, 10] = label[13] # l3_x
64 | annotation[0, 11] = label[14] # l3_y
65 | annotation[0, 12] = label[16] # l4_x
66 | annotation[0, 13] = label[17] # l4_y
67 | if (annotation[0, 4]<0):
68 | annotation[0, 14] = -1
69 | else:
70 | annotation[0, 14] = 1
71 |
72 | annotations = np.append(annotations, annotation, axis=0)
73 | target = np.array(annotations)
74 | if self.preproc is not None:
75 | img, target = self.preproc(img, target)
76 |
77 | return torch.from_numpy(img), target
78 |
79 | def detection_collate(batch):
80 | """Custom collate fn for dealing with batches of images that have a different
81 | number of associated object annotations (bounding boxes).
82 |
83 | Arguments:
84 | batch: (tuple) A tuple of tensor images and lists of annotations
85 |
86 | Return:
87 | A tuple containing:
88 | 1) (tensor) batch of images stacked on their 0 dim
89 | 2) (list of tensors) annotations for a given image are stacked on 0 dim
90 | """
91 | targets = []
92 | imgs = []
93 | for _, sample in enumerate(batch):
94 | for _, tup in enumerate(sample):
95 | if torch.is_tensor(tup):
96 | imgs.append(tup)
97 | elif isinstance(tup, type(np.empty(0))):
98 | annos = torch.from_numpy(tup).float()
99 | targets.append(annos)
100 |
101 | return (torch.stack(imgs, 0), targets)
102 |
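A hedged sketch of wiring this dataset into a `DataLoader`. The label path is a placeholder, `preproc` comes from `data_augment.py` in this package, and the `840` / `(104, 117, 123)` values are assumptions in the style of the ResNet-50 training configuration:

```python
from torch.utils.data import DataLoader

dataset = WiderFaceDetection("./data/widerface/train/label.txt", preproc=preproc(840, (104, 117, 123)))
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=detection_collate)
images, targets = next(iter(loader))
# images: (4, 3, 840, 840) float tensor stacked on dim 0
# targets: list of 4 tensors, one (num_faces, 15) tensor per image
```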
--------------------------------------------------------------------------------
/GPEN/retinaface/facemodels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/facemodels/__init__.py
--------------------------------------------------------------------------------
/GPEN/retinaface/facemodels/net.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models._utils as _utils
5 | import torchvision.models as models
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 | def conv_bn(inp, oup, stride = 1, leaky = 0):
10 | return nn.Sequential(
11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
12 | nn.BatchNorm2d(oup),
13 | nn.LeakyReLU(negative_slope=leaky, inplace=True)
14 | )
15 |
16 | def conv_bn_no_relu(inp, oup, stride):
17 | return nn.Sequential(
18 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
19 | nn.BatchNorm2d(oup),
20 | )
21 |
22 | def conv_bn1X1(inp, oup, stride, leaky=0):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
25 | nn.BatchNorm2d(oup),
26 | nn.LeakyReLU(negative_slope=leaky, inplace=True)
27 | )
28 |
29 | def conv_dw(inp, oup, stride, leaky=0.1):
30 | return nn.Sequential(
31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
32 | nn.BatchNorm2d(inp),
33 | nn.LeakyReLU(negative_slope= leaky,inplace=True),
34 |
35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36 | nn.BatchNorm2d(oup),
37 | nn.LeakyReLU(negative_slope= leaky,inplace=True),
38 | )
39 |
40 | class SSH(nn.Module):
41 | def __init__(self, in_channel, out_channel):
42 | super(SSH, self).__init__()
43 | assert out_channel % 4 == 0
44 | leaky = 0
45 | if (out_channel <= 64):
46 | leaky = 0.1
47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
48 |
49 | self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
50 | self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
51 |
52 | self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
53 | self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
54 |
55 | def forward(self, input):
56 | conv3X3 = self.conv3X3(input)
57 |
58 | conv5X5_1 = self.conv5X5_1(input)
59 | conv5X5 = self.conv5X5_2(conv5X5_1)
60 |
61 | conv7X7_2 = self.conv7X7_2(conv5X5_1)
62 | conv7X7 = self.conv7x7_3(conv7X7_2)
63 |
64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
65 | out = F.relu(out)
66 | return out
67 |
68 | class FPN(nn.Module):
69 | def __init__(self,in_channels_list,out_channels):
70 | super(FPN,self).__init__()
71 | leaky = 0
72 | if (out_channels <= 64):
73 | leaky = 0.1
74 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
75 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
76 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
77 |
78 | self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
79 | self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
80 |
81 | def forward(self, input):
82 | # names = list(input.keys())
83 | input = list(input.values())
84 |
85 | output1 = self.output1(input[0])
86 | output2 = self.output2(input[1])
87 | output3 = self.output3(input[2])
88 |
89 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
90 | output2 = output2 + up3
91 | output2 = self.merge2(output2)
92 |
93 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
94 | output1 = output1 + up2
95 | output1 = self.merge1(output1)
96 |
97 | out = [output1, output2, output3]
98 | return out
99 |
100 |
101 |
102 | class MobileNetV1(nn.Module):
103 | def __init__(self):
104 | super(MobileNetV1, self).__init__()
105 | self.stage1 = nn.Sequential(
106 | conv_bn(3, 8, 2, leaky = 0.1), # 3
107 | conv_dw(8, 16, 1), # 7
108 | conv_dw(16, 32, 2), # 11
109 | conv_dw(32, 32, 1), # 19
110 | conv_dw(32, 64, 2), # 27
111 | conv_dw(64, 64, 1), # 43
112 | )
113 | self.stage2 = nn.Sequential(
114 | conv_dw(64, 128, 2), # 43 + 16 = 59
115 | conv_dw(128, 128, 1), # 59 + 32 = 91
116 | conv_dw(128, 128, 1), # 91 + 32 = 123
117 | conv_dw(128, 128, 1), # 123 + 32 = 155
118 | conv_dw(128, 128, 1), # 155 + 32 = 187
119 | conv_dw(128, 128, 1), # 187 + 32 = 219
120 | )
121 | self.stage3 = nn.Sequential(
122 | conv_dw(128, 256, 2), # 219 +3 2 = 241
123 | conv_dw(256, 256, 1), # 241 + 64 = 301
124 | )
125 | self.avg = nn.AdaptiveAvgPool2d((1,1))
126 | self.fc = nn.Linear(256, 1000)
127 |
128 | def forward(self, x):
129 | x = self.stage1(x)
130 | x = self.stage2(x)
131 | x = self.stage3(x)
132 | x = self.avg(x)
133 | # x = self.model(x)
134 | x = x.view(-1, 256)
135 | x = self.fc(x)
136 | return x
137 |
138 |
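A quick shape check for the building blocks above (random inputs, purely illustrative; it reuses this module's `torch` import):

```python
x = torch.rand(1, 3, 224, 224)
print(MobileNetV1()(x).shape)                         # torch.Size([1, 1000])
print(SSH(64, 256)(torch.rand(1, 64, 20, 20)).shape)  # torch.Size([1, 256, 20, 20])
```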
--------------------------------------------------------------------------------
/GPEN/retinaface/facemodels/retinaface.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models.detection.backbone_utils as backbone_utils
4 | import torchvision.models._utils as _utils
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 |
8 | from facemodels.net import MobileNetV1 as MobileNetV1
9 | from facemodels.net import FPN as FPN
10 | from facemodels.net import SSH as SSH
11 |
12 |
13 |
14 | class ClassHead(nn.Module):
15 | def __init__(self,inchannels=512,num_anchors=3):
16 | super(ClassHead,self).__init__()
17 | self.num_anchors = num_anchors
18 | self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
19 |
20 | def forward(self,x):
21 | out = self.conv1x1(x)
22 | out = out.permute(0,2,3,1).contiguous()
23 |
24 | return out.view(out.shape[0], -1, 2)
25 |
26 | class BboxHead(nn.Module):
27 | def __init__(self,inchannels=512,num_anchors=3):
28 | super(BboxHead,self).__init__()
29 | self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
30 |
31 | def forward(self,x):
32 | out = self.conv1x1(x)
33 | out = out.permute(0,2,3,1).contiguous()
34 |
35 | return out.view(out.shape[0], -1, 4)
36 |
37 | class LandmarkHead(nn.Module):
38 | def __init__(self,inchannels=512,num_anchors=3):
39 | super(LandmarkHead,self).__init__()
40 | self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
41 |
42 | def forward(self,x):
43 | out = self.conv1x1(x)
44 | out = out.permute(0,2,3,1).contiguous()
45 |
46 | return out.view(out.shape[0], -1, 10)
47 |
48 | class RetinaFace(nn.Module):
49 | def __init__(self, cfg = None, phase = 'train'):
50 | """
51 | :param cfg: Network related settings.
52 | :param phase: train or test.
53 | """
54 | super(RetinaFace,self).__init__()
55 | self.phase = phase
56 | backbone = None
57 | if cfg['name'] == 'mobilenet0.25':
58 | backbone = MobileNetV1()
59 | if cfg['pretrain']:
60 | checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
61 | from collections import OrderedDict
62 | new_state_dict = OrderedDict()
63 | for k, v in checkpoint['state_dict'].items():
64 | name = k[7:] # remove module.
65 | new_state_dict[name] = v
66 | # load params
67 | backbone.load_state_dict(new_state_dict)
68 | elif cfg['name'] == 'Resnet50':
69 | import torchvision.models as models
70 | backbone = models.resnet50(pretrained=cfg['pretrain'])
71 |
72 | self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
73 | in_channels_stage2 = cfg['in_channel']
74 | in_channels_list = [
75 | in_channels_stage2 * 2,
76 | in_channels_stage2 * 4,
77 | in_channels_stage2 * 8,
78 | ]
79 | out_channels = cfg['out_channel']
80 | self.fpn = FPN(in_channels_list,out_channels)
81 | self.ssh1 = SSH(out_channels, out_channels)
82 | self.ssh2 = SSH(out_channels, out_channels)
83 | self.ssh3 = SSH(out_channels, out_channels)
84 |
85 | self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
86 | self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
87 | self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
88 |
89 | def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
90 | classhead = nn.ModuleList()
91 | for i in range(fpn_num):
92 | classhead.append(ClassHead(inchannels,anchor_num))
93 | return classhead
94 |
95 | def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
96 | bboxhead = nn.ModuleList()
97 | for i in range(fpn_num):
98 | bboxhead.append(BboxHead(inchannels,anchor_num))
99 | return bboxhead
100 |
101 | def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
102 | landmarkhead = nn.ModuleList()
103 | for i in range(fpn_num):
104 | landmarkhead.append(LandmarkHead(inchannels,anchor_num))
105 | return landmarkhead
106 |
107 | def forward(self,inputs):
108 | out = self.body(inputs)
109 |
110 | # FPN
111 | fpn = self.fpn(out)
112 |
113 | # SSH
114 | feature1 = self.ssh1(fpn[0])
115 | feature2 = self.ssh2(fpn[1])
116 | feature3 = self.ssh3(fpn[2])
117 | features = [feature1, feature2, feature3]
118 |
119 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
120 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
121 | ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
122 |
123 | if self.phase == 'train':
124 | output = (bbox_regressions, classifications, ldm_regressions)
125 | else:
126 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
127 | return output
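A forward-shape sketch for the detector. The `cfg_demo` values below mirror the usual ResNet-50 configuration but are spelled out here as assumptions, and the input is random:

```python
cfg_demo = {'name': 'Resnet50', 'pretrain': False,
            'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
            'in_channel': 256, 'out_channel': 256}
net = RetinaFace(cfg=cfg_demo, phase='train')
bbox, cls, ldm = net(torch.rand(1, 3, 640, 640))
# (80*80 + 40*40 + 20*20) cells * 2 anchors = 16800 priors for a 640x640 input
print(bbox.shape, cls.shape, ldm.shape)  # (1, 16800, 4) (1, 16800, 2) (1, 16800, 10)
```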
--------------------------------------------------------------------------------
/GPEN/retinaface/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/GPEN/retinaface/layers/functions/prior_box.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from itertools import product as product
3 | import numpy as np
4 | from math import ceil
5 |
6 |
7 | class PriorBox(object):
8 | def __init__(self, cfg, image_size=None, phase='train'):
9 | super(PriorBox, self).__init__()
10 | self.min_sizes = cfg['min_sizes']
11 | self.steps = cfg['steps']
12 | self.clip = cfg['clip']
13 | self.image_size = image_size
14 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
15 | self.name = "s"
16 |
17 | def forward(self):
18 | anchors = []
19 | for k, f in enumerate(self.feature_maps):
20 | min_sizes = self.min_sizes[k]
21 | for i, j in product(range(f[0]), range(f[1])):
22 | for min_size in min_sizes:
23 | s_kx = min_size / self.image_size[1]
24 | s_ky = min_size / self.image_size[0]
25 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
26 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
27 | for cy, cx in product(dense_cy, dense_cx):
28 | anchors += [cx, cy, s_kx, s_ky]
29 |
30 | # back to torch land
31 | output = torch.Tensor(anchors).view(-1, 4)
32 | if self.clip:
33 | output.clamp_(max=1, min=0)
34 | return output
35 |
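A minimal sketch of generating priors. The `min_sizes`/`steps` values follow the common RetinaFace configuration but are written out here as assumptions:

```python
cfg_demo = {'min_sizes': [[16, 32], [64, 128], [256, 512]], 'steps': [8, 16, 32], 'clip': False}
priors = PriorBox(cfg_demo, image_size=(640, 640)).forward()
# 2 anchors per cell over 80x80, 40x40 and 20x20 grids -> 16800 priors of (cx, cy, w, h) in [0, 1]
print(priors.shape)  # torch.Size([16800, 4])
```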
--------------------------------------------------------------------------------
/GPEN/retinaface/layers/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .multibox_loss import MultiBoxLoss
2 |
3 | __all__ = ['MultiBoxLoss']
4 |
--------------------------------------------------------------------------------
/GPEN/retinaface/layers/modules/multibox_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from utils.box_utils import match, log_sum_exp
6 | from data import cfg_mnet
7 | GPU = cfg_mnet['gpu_train']
8 |
9 | class MultiBoxLoss(nn.Module):
10 | """SSD Weighted Loss Function
11 | Compute Targets:
12 | 1) Produce Confidence Target Indices by matching ground truth boxes
13 | with (default) 'priorboxes' that have jaccard index > threshold parameter
14 | (default threshold: 0.5).
15 | 2) Produce localization target by 'encoding' variance into offsets of ground
16 | truth boxes and their matched 'priorboxes'.
17 | 3) Hard negative mining to filter the excessive number of negative examples
18 | that comes with using a large number of default bounding boxes.
19 | (default negative:positive ratio 3:1)
20 | Objective Loss:
21 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
22 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
23 | weighted by α which is set to 1 by cross val.
24 | Args:
25 | c: class confidences,
26 | l: predicted boxes,
27 | g: ground truth boxes
28 | N: number of matched default boxes
29 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
30 | """
31 |
32 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
33 | super(MultiBoxLoss, self).__init__()
34 | self.num_classes = num_classes
35 | self.threshold = overlap_thresh
36 | self.background_label = bkg_label
37 | self.encode_target = encode_target
38 | self.use_prior_for_matching = prior_for_matching
39 | self.do_neg_mining = neg_mining
40 | self.negpos_ratio = neg_pos
41 | self.neg_overlap = neg_overlap
42 | self.variance = [0.1, 0.2]
43 |
44 | def forward(self, predictions, priors, targets):
45 | """Multibox Loss
46 | Args:
47 | predictions (tuple): A tuple containing loc preds, conf preds,
48 | and prior boxes from SSD net.
49 | conf shape: torch.size(batch_size,num_priors,num_classes)
50 | loc shape: torch.size(batch_size,num_priors,4)
51 | priors shape: torch.size(num_priors,4)
52 |
53 | ground_truth (tensor): Ground truth boxes and labels for a batch,
54 | shape: [batch_size,num_objs,5] (last idx is the label).
55 | """
56 |
57 | loc_data, conf_data, landm_data = predictions
58 | priors = priors
59 | num = loc_data.size(0)
60 | num_priors = (priors.size(0))
61 |
62 | # match priors (default boxes) and ground truth boxes
63 | loc_t = torch.Tensor(num, num_priors, 4)
64 | landm_t = torch.Tensor(num, num_priors, 10)
65 | conf_t = torch.LongTensor(num, num_priors)
66 | for idx in range(num):
67 | truths = targets[idx][:, :4].data
68 | labels = targets[idx][:, -1].data
69 | landms = targets[idx][:, 4:14].data
70 | defaults = priors.data
71 | match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
72 | if GPU:
73 | loc_t = loc_t.cuda()
74 | conf_t = conf_t.cuda()
75 | landm_t = landm_t.cuda()
76 |
77 | zeros = torch.tensor(0).cuda()
78 | # landm Loss (Smooth L1)
79 | # Shape: [batch,num_priors,10]
80 | pos1 = conf_t > zeros
81 | num_pos_landm = pos1.long().sum(1, keepdim=True)
82 | N1 = max(num_pos_landm.data.sum().float(), 1)
83 | pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
84 | landm_p = landm_data[pos_idx1].view(-1, 10)
85 | landm_t = landm_t[pos_idx1].view(-1, 10)
86 | loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
87 |
88 |
89 | pos = conf_t != zeros
90 | conf_t[pos] = 1
91 |
92 | # Localization Loss (Smooth L1)
93 | # Shape: [batch,num_priors,4]
94 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
95 | loc_p = loc_data[pos_idx].view(-1, 4)
96 | loc_t = loc_t[pos_idx].view(-1, 4)
97 | loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
98 |
99 | # Compute max conf across batch for hard negative mining
100 | batch_conf = conf_data.view(-1, self.num_classes)
101 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
102 |
103 | # Hard Negative Mining
104 | loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
105 | loss_c = loss_c.view(num, -1)
106 | _, loss_idx = loss_c.sort(1, descending=True)
107 | _, idx_rank = loss_idx.sort(1)
108 | num_pos = pos.long().sum(1, keepdim=True)
109 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
110 | neg = idx_rank < num_neg.expand_as(idx_rank)
111 |
112 | # Confidence Loss Including Positive and Negative Examples
113 | pos_idx = pos.unsqueeze(2).expand_as(conf_data)
114 | neg_idx = neg.unsqueeze(2).expand_as(conf_data)
115 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
116 | targets_weighted = conf_t[(pos+neg).gt(0)]
117 | loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
118 |
119 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
120 | N = max(num_pos.data.sum().float(), 1)
121 | loss_l /= N
122 | loss_c /= N
123 | loss_landm /= N1
124 |
125 | return loss_l, loss_c, loss_landm
126 |
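A shape-level sketch of how this loss is typically wired during training. The hyper-parameter values are assumptions, and a CUDA device is required because the loss moves its matched targets to the GPU:

```python
criterion = MultiBoxLoss(num_classes=2, overlap_thresh=0.35, prior_for_matching=True,
                         bkg_label=0, neg_mining=True, neg_pos=7, neg_overlap=0.35,
                         encode_target=False)
# predictions = (loc, conf, landm) from RetinaFace in 'train' phase
# priors      = PriorBox(cfg, image_size=(H, W)).forward()
# targets     = list of (num_faces, 15) tensors, e.g. from detection_collate
# loss_l, loss_c, loss_landm = criterion(predictions, priors, targets)
# total = 2.0 * loss_l + loss_c + loss_landm   # the 2.0 localization weight is an assumption
```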
--------------------------------------------------------------------------------
/GPEN/retinaface/retinaface_detection.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import os
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import numpy as np
9 | from data import cfg_re50
10 | from layers.functions.prior_box import PriorBox
11 | from utils.nms.py_cpu_nms import py_cpu_nms
12 | import cv2
13 | from facemodels.retinaface import RetinaFace
14 | from utils.box_utils import decode, decode_landm
15 | import time
16 | import torch
17 |
18 | class RetinaFaceDetection(object):
19 | def __init__(self, base_dir, network='RetinaFace-R50'):
20 | torch.set_grad_enabled(False)
21 | cudnn.benchmark = True
22 | self.pretrained_path = os.path.join(base_dir, 'weights', network+'.pth')
23 | self.device = torch.cuda.current_device()
24 | self.cfg = cfg_re50
25 | self.net = RetinaFace(cfg=self.cfg, phase='test')
26 | self.load_model()
27 | self.net = self.net.cuda()
28 | self.net_trt = None
29 |
30 | def check_keys(self, pretrained_state_dict):
31 | ckpt_keys = set(pretrained_state_dict.keys())
32 | model_keys = set(self.net.state_dict().keys())
33 | used_pretrained_keys = model_keys & ckpt_keys
34 | unused_pretrained_keys = ckpt_keys - model_keys
35 | missing_keys = model_keys - ckpt_keys
36 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
37 | return True
38 |
39 | def remove_prefix(self, state_dict, prefix):
40 | ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
41 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
42 | return {f(key): value for key, value in state_dict.items()}
43 |
44 | def load_model(self, load_to_cpu=False):
45 | if load_to_cpu:
46 | pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage)
47 | else:
48 | pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage.cuda())
49 | if "state_dict" in pretrained_dict.keys():
50 | pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.')
51 | else:
52 | pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
53 | self.check_keys(pretrained_dict)
54 | self.net.load_state_dict(pretrained_dict, strict=False)
55 | self.net.eval()
56 |
57 | def build_trt(self, img_raw):
58 | img = np.float32(img_raw)
59 |
60 | img -= (104, 117, 123)
61 | img = img.transpose(2, 0, 1)
62 | img = torch.from_numpy(img).unsqueeze(0)
63 | img = img.cuda()
64 |
65 |         print('building trt model RetinaFace')
66 | from torch2trt import torch2trt
67 | self.net_trt = torch2trt(self.net, [img], fp16_mode=True)
68 | del self.net
69 |         print('successfully built')
70 |
71 | def detect_trt(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
72 | img = np.float32(img_raw)
73 |
74 | im_height, im_width = img.shape[:2]
75 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
76 | img -= (104, 117, 123)
77 | img = img.transpose(2, 0, 1)
78 | img = torch.from_numpy(img).unsqueeze(0)
79 | img = img.cuda()
80 | scale = scale.cuda()
81 |
82 | loc, conf, landms = self.net_trt(img) # forward pass
83 |
84 | priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
85 | priors = priorbox.forward()
86 | priors = priors.cuda()
87 | prior_data = priors.data
88 | boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
89 | boxes = boxes * scale / resize
90 | boxes = boxes.cpu().numpy()
91 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
92 | landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
93 | scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
94 | img.shape[3], img.shape[2], img.shape[3], img.shape[2],
95 | img.shape[3], img.shape[2]])
96 | scale1 = scale1.cuda()
97 | landms = landms * scale1 / resize
98 | landms = landms.cpu().numpy()
99 |
100 | # ignore low scores
101 | inds = np.where(scores > confidence_threshold)[0]
102 | boxes = boxes[inds]
103 | landms = landms[inds]
104 | scores = scores[inds]
105 |
106 | # keep top-K before NMS
107 | order = scores.argsort()[::-1][:top_k]
108 | boxes = boxes[order]
109 | landms = landms[order]
110 | scores = scores[order]
111 |
112 | # do NMS
113 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
114 | keep = py_cpu_nms(dets, nms_threshold)
115 | # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
116 | dets = dets[keep, :]
117 | landms = landms[keep]
118 |
119 | # keep top-K faster NMS
120 | dets = dets[:keep_top_k, :]
121 | landms = landms[:keep_top_k, :]
122 |
123 | # sort faces(delete)
124 | '''
125 | fscores = [det[4] for det in dets]
126 | sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
127 | tmp = [landms[idx] for idx in sorted_idx]
128 | landms = np.asarray(tmp)
129 | '''
130 |
131 | landms = landms.reshape((-1, 5, 2))
132 | landms = landms.transpose((0, 2, 1))
133 | landms = landms.reshape(-1, 10, )
134 | return dets, landms
135 |
136 |
137 | def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
138 | img = np.float32(img_raw)
139 |
140 | im_height, im_width = img.shape[:2]
141 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
142 | img -= (104, 117, 123)
143 | img = img.transpose(2, 0, 1)
144 | img = torch.from_numpy(img).unsqueeze(0)
145 | img = img.cuda()
146 | scale = scale.cuda()
147 |
148 | loc, conf, landms = self.net(img) # forward pass
149 |
150 | priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
151 | priors = priorbox.forward()
152 | priors = priors.cuda()
153 | prior_data = priors.data
154 | boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
155 | boxes = boxes * scale / resize
156 | boxes = boxes.cpu().numpy()
157 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
158 | landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
159 | scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
160 | img.shape[3], img.shape[2], img.shape[3], img.shape[2],
161 | img.shape[3], img.shape[2]])
162 | scale1 = scale1.cuda()
163 | landms = landms * scale1 / resize
164 | landms = landms.cpu().numpy()
165 |
166 | # ignore low scores
167 | inds = np.where(scores > confidence_threshold)[0]
168 | boxes = boxes[inds]
169 | landms = landms[inds]
170 | scores = scores[inds]
171 |
172 | # keep top-K before NMS
173 | order = scores.argsort()[::-1][:top_k]
174 | boxes = boxes[order]
175 | landms = landms[order]
176 | scores = scores[order]
177 |
178 | # do NMS
179 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
180 | keep = py_cpu_nms(dets, nms_threshold)
181 | # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
182 | dets = dets[keep, :]
183 | landms = landms[keep]
184 |
185 | # keep top-K faster NMS
186 | dets = dets[:keep_top_k, :]
187 | landms = landms[:keep_top_k, :]
188 |
189 | # sort faces(delete)
190 | '''
191 | fscores = [det[4] for det in dets]
192 | sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
193 | tmp = [landms[idx] for idx in sorted_idx]
194 | landms = np.asarray(tmp)
195 | '''
196 |
197 | landms = landms.reshape((-1, 5, 2))
198 | landms = landms.transpose((0, 2, 1))
199 | landms = landms.reshape(-1, 10, )
200 | return dets, landms
201 |
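A hedged usage sketch. It assumes `RetinaFace-R50.pth` sits under `<base_dir>/weights`, a CUDA device is available, and the image path is a placeholder:

```python
detector = RetinaFaceDetection(base_dir='./GPEN/retinaface')  # base_dir is an assumption
img_bgr = cv2.imread('face.jpg')                              # hypothetical input image
facebs, landms = detector.detect(img_bgr)
# facebs: (N, 5) rows of [x1, y1, x2, y2, score]
# landms: (N, 10) flattened 5-point landmarks, ordered (x1..x5, y1..y5)
```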
--------------------------------------------------------------------------------
/GPEN/retinaface/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/utils/__init__.py
--------------------------------------------------------------------------------
/GPEN/retinaface/utils/nms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/utils/nms/__init__.py
--------------------------------------------------------------------------------
/GPEN/retinaface/utils/nms/py_cpu_nms.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 |
10 | def py_cpu_nms(dets, thresh):
11 | """Pure Python NMS baseline."""
12 | x1 = dets[:, 0]
13 | y1 = dets[:, 1]
14 | x2 = dets[:, 2]
15 | y2 = dets[:, 3]
16 | scores = dets[:, 4]
17 |
18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
19 | order = scores.argsort()[::-1]
20 |
21 | keep = []
22 | while order.size > 0:
23 | i = order[0]
24 | keep.append(i)
25 | xx1 = np.maximum(x1[i], x1[order[1:]])
26 | yy1 = np.maximum(y1[i], y1[order[1:]])
27 | xx2 = np.minimum(x2[i], x2[order[1:]])
28 | yy2 = np.minimum(y2[i], y2[order[1:]])
29 |
30 | w = np.maximum(0.0, xx2 - xx1 + 1)
31 | h = np.maximum(0.0, yy2 - yy1 + 1)
32 | inter = w * h
33 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
34 |
35 | inds = np.where(ovr <= thresh)[0]
36 | order = order[inds + 1]
37 |
38 | return keep
39 |
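A tiny sanity check for the NMS routine above (box coordinates and scores are illustrative only; it reuses this module's `numpy` import):

```python
dets = np.array([
    [10, 10, 50, 50, 0.95],      # highest score, kept
    [12, 12, 52, 52, 0.90],      # IoU ~0.83 with the first box, suppressed at thresh=0.4
    [100, 100, 140, 140, 0.80],  # no overlap, kept
], dtype=np.float32)
print(py_cpu_nms(dets, thresh=0.4))  # -> [0, 2]
```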
--------------------------------------------------------------------------------
/GPEN/retinaface/utils/timer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import time
9 |
10 |
11 | class Timer(object):
12 | """A simple timer."""
13 | def __init__(self):
14 | self.total_time = 0.
15 | self.calls = 0
16 | self.start_time = 0.
17 | self.diff = 0.
18 | self.average_time = 0.
19 |
20 | def tic(self):
21 |         # using time.time instead of time.clock because time.clock
22 |         # does not normalize for multithreading
23 | self.start_time = time.time()
24 |
25 | def toc(self, average=True):
26 | self.diff = time.time() - self.start_time
27 | self.total_time += self.diff
28 | self.calls += 1
29 | self.average_time = self.total_time / self.calls
30 | if average:
31 | return self.average_time
32 | else:
33 | return self.diff
34 |
35 | def clear(self):
36 | self.total_time = 0.
37 | self.calls = 0
38 | self.start_time = 0.
39 | self.diff = 0.
40 | self.average_time = 0.
41 |
--------------------------------------------------------------------------------
/GPEN/sr_model/arch_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn as nn
4 | from torch.nn import functional as F
5 | from torch.nn import init as init
6 | from torch.nn.modules.batchnorm import _BatchNorm
7 |
8 | @torch.no_grad()
9 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
10 | """Initialize network weights.
11 |
12 | Args:
13 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
14 | scale (float): Scale initialized weights, especially for residual
15 | blocks. Default: 1.
16 | bias_fill (float): The value to fill bias. Default: 0
17 | kwargs (dict): Other arguments for initialization function.
18 | """
19 | if not isinstance(module_list, list):
20 | module_list = [module_list]
21 | for module in module_list:
22 | for m in module.modules():
23 | if isinstance(m, nn.Conv2d):
24 | init.kaiming_normal_(m.weight, **kwargs)
25 | m.weight.data *= scale
26 | if m.bias is not None:
27 | m.bias.data.fill_(bias_fill)
28 | elif isinstance(m, nn.Linear):
29 | init.kaiming_normal_(m.weight, **kwargs)
30 | m.weight.data *= scale
31 | if m.bias is not None:
32 | m.bias.data.fill_(bias_fill)
33 | elif isinstance(m, _BatchNorm):
34 | init.constant_(m.weight, 1)
35 | if m.bias is not None:
36 | m.bias.data.fill_(bias_fill)
37 |
38 |
39 | def make_layer(basic_block, num_basic_block, **kwarg):
40 | """Make layers by stacking the same blocks.
41 |
42 | Args:
43 | basic_block (nn.module): nn.module class for basic block.
44 | num_basic_block (int): number of blocks.
45 |
46 | Returns:
47 | nn.Sequential: Stacked blocks in nn.Sequential.
48 | """
49 | layers = []
50 | for _ in range(num_basic_block):
51 | layers.append(basic_block(**kwarg))
52 | return nn.Sequential(*layers)
53 |
54 |
55 | class ResidualBlockNoBN(nn.Module):
56 | """Residual block without BN.
57 |
58 | It has a style of:
59 | ---Conv-ReLU-Conv-+-
60 | |________________|
61 |
62 | Args:
63 | num_feat (int): Channel number of intermediate features.
64 | Default: 64.
65 | res_scale (float): Residual scale. Default: 1.
66 | pytorch_init (bool): If set to True, use pytorch default init,
67 | otherwise, use default_init_weights. Default: False.
68 | """
69 |
70 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
71 | super(ResidualBlockNoBN, self).__init__()
72 | self.res_scale = res_scale
73 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
74 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
75 | self.relu = nn.ReLU(inplace=True)
76 |
77 | if not pytorch_init:
78 | default_init_weights([self.conv1, self.conv2], 0.1)
79 |
80 | def forward(self, x):
81 | identity = x
82 | out = self.conv2(self.relu(self.conv1(x)))
83 | return identity + out * self.res_scale
84 |
85 |
86 | class Upsample(nn.Sequential):
87 | """Upsample module.
88 |
89 | Args:
90 | scale (int): Scale factor. Supported scales: 2^n and 3.
91 | num_feat (int): Channel number of intermediate features.
92 | """
93 |
94 | def __init__(self, scale, num_feat):
95 | m = []
96 | if (scale & (scale - 1)) == 0: # scale = 2^n
97 | for _ in range(int(math.log(scale, 2))):
98 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
99 | m.append(nn.PixelShuffle(2))
100 | elif scale == 3:
101 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
102 | m.append(nn.PixelShuffle(3))
103 | else:
104 | raise ValueError(f'scale {scale} is not supported. '
105 | 'Supported scales: 2^n and 3.')
106 | super(Upsample, self).__init__(*m)
107 |
108 | # TODO: may write a cpp file
109 | def pixel_unshuffle(x, scale):
110 | """ Pixel unshuffle.
111 |
112 | Args:
113 | x (Tensor): Input feature with shape (b, c, hh, hw).
114 | scale (int): Downsample ratio.
115 |
116 | Returns:
117 | Tensor: the pixel unshuffled feature.
118 | """
119 | b, c, hh, hw = x.size()
120 | out_channel = c * (scale**2)
121 | assert hh % scale == 0 and hw % scale == 0
122 | h = hh // scale
123 | w = hw // scale
124 | x_view = x.view(b, c, h, scale, w, scale)
125 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
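A quick check of `pixel_unshuffle` on a toy tensor (illustrative only; it reuses this module's `torch` import):

```python
x = torch.arange(16.).view(1, 1, 4, 4)
y = pixel_unshuffle(x, scale=2)
print(y.shape)  # torch.Size([1, 4, 2, 2]) -- spatial size halved, channel count multiplied by 4
```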
--------------------------------------------------------------------------------
/GPEN/sr_model/real_esrnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from rrdbnet_arch import RRDBNet
5 | from torch.nn import functional as F
6 | import torch
7 |
8 |
9 | class RealESRNet(object):
10 | def __init__(self, base_dir=os.path.dirname(__file__), model=None, scale=2):
11 | self.base_dir = base_dir
12 | self.scale = scale
13 | self.load_srmodel(base_dir, model)
14 | self.srmodel_trt = None
15 |
16 | def load_srmodel(self, base_dir, model):
17 |         self.scale = 2 if model is None or "x2" in model else 4 if "x4" in model else -1  # default to the x2 model when no name is given
18 | if self.scale == -1:
19 | raise Exception("Scale not supported")
20 | self.srmodel = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=23, num_grow_ch=32, scale=self.scale)
21 | if model is None:
22 | loadnet = torch.load(os.path.join(self.base_dir, 'weights', 'realesrnet_x2.pth'))
23 | else:
24 | loadnet = torch.load(os.path.join(self.base_dir, 'weights', model+'.pth'))
25 | self.srmodel.load_state_dict(loadnet['params_ema'], strict=True)
26 | self.srmodel.eval()
27 | self.srmodel = self.srmodel.cuda()
28 |
29 | def build_trt(self, img):
30 | img = img.astype(np.float32) / 255.
31 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
32 | img = img.unsqueeze(0).cuda()
33 | print('building trt model srmodel')
34 | from torch2trt import torch2trt
35 | self.srmodel_trt = torch2trt(self.srmodel, [img], fp16_mode=True)
36 |         print('successfully built')
37 | del self.srmodel
38 |
39 | def process_trt(self, img):
40 | img = img.astype(np.float32) / 255.
41 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
42 | img = img.unsqueeze(0).cuda()
43 |
44 | if self.scale == 2:
45 | mod_scale = 2
46 | elif self.scale == 1:
47 | mod_scale = 4
48 | else:
49 | mod_scale = None
50 | if mod_scale is not None:
51 | h_pad, w_pad = 0, 0
52 | _, _, h, w = img.size()
53 | if (h % mod_scale != 0):
54 | h_pad = (mod_scale - h % mod_scale)
55 | if (w % mod_scale != 0):
56 | w_pad = (mod_scale - w % mod_scale)
57 | img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
58 | try:
59 | with torch.no_grad():
60 | output = self.srmodel_trt(img)
61 | # remove extra pad
62 | if mod_scale is not None:
63 | _, _, h, w = output.size()
64 | output = output[:, :, 0:h - h_pad, 0:w - w_pad]
65 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
66 | output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
67 | output = (output * 255.0).round().astype(np.uint8)
68 |
69 | return output
70 | except:
71 | return None
72 |
73 | def process(self, img):
74 | img = img.astype(np.float32) / 255.
75 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
76 | img = img.unsqueeze(0).cuda()
77 | # print(img.shape)
78 |
79 | if self.scale == 2:
80 | mod_scale = 2
81 | elif self.scale == 1:
82 | mod_scale = 4
83 | else:
84 | mod_scale = None
85 | if mod_scale is not None:
86 | h_pad, w_pad = 0, 0
87 | _, _, h, w = img.size()
88 | if (h % mod_scale != 0):
89 | h_pad = (mod_scale - h % mod_scale)
90 | if (w % mod_scale != 0):
91 | w_pad = (mod_scale - w % mod_scale)
92 | img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
93 | try:
94 | with torch.no_grad():
95 | output = self.srmodel(img)
96 | # remove extra pad
97 | if mod_scale is not None:
98 | _, _, h, w = output.size()
99 | output = output[:, :, 0:h - h_pad, 0:w - w_pad]
100 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
101 | output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
102 | output = (output * 255.0).round().astype(np.uint8)
103 |
104 | return output
105 | except:
106 | return None
107 |
108 |
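A hedged usage sketch. It assumes the `realesrnet_x2` weights exist under `<base_dir>/weights` and that a CUDA device is available; the input is a dummy frame:

```python
sr = RealESRNet(model='realesrnet_x2')
img_bgr = np.zeros((100, 100, 3), dtype=np.uint8)  # any BGR uint8 image, e.g. from cv2.imread
out = sr.process(img_bgr)
# out: BGR uint8 image upscaled x2 -> (200, 200, 3), or None if inference failed
```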
--------------------------------------------------------------------------------
/GPEN/sr_model/rrdbnet_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 | from arch_util import default_init_weights, make_layer, pixel_unshuffle
6 |
7 |
8 | class ResidualDenseBlock(nn.Module):
9 | """Residual Dense Block.
10 |
11 | Used in RRDB block in ESRGAN.
12 |
13 | Args:
14 | num_feat (int): Channel number of intermediate features.
15 | num_grow_ch (int): Channels for each growth.
16 | """
17 |
18 | def __init__(self, num_feat=64, num_grow_ch=32):
19 | super(ResidualDenseBlock, self).__init__()
20 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
21 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
22 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
23 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
24 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
25 |
26 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27 |
28 | # initialization
29 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
30 |
31 | def forward(self, x):
32 | x1 = self.lrelu(self.conv1(x))
33 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
34 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
35 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
36 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
37 |         # Empirically, we use 0.2 to scale the residual for better performance
38 | return x5 * 0.2 + x
39 |
40 |
41 | class RRDB(nn.Module):
42 | """Residual in Residual Dense Block.
43 |
44 | Used in RRDB-Net in ESRGAN.
45 |
46 | Args:
47 | num_feat (int): Channel number of intermediate features.
48 | num_grow_ch (int): Channels for each growth.
49 | """
50 |
51 | def __init__(self, num_feat, num_grow_ch=32):
52 | super(RRDB, self).__init__()
53 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
54 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
55 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
56 |
57 | def forward(self, x):
58 | out = self.rdb1(x)
59 | out = self.rdb2(out)
60 | out = self.rdb3(out)
61 |         # Empirically, we use 0.2 to scale the residual for better performance
62 | return out * 0.2 + x
63 |
64 | class RRDBNet(nn.Module):
65 | """Networks consisting of Residual in Residual Dense Block, which is used
66 | in ESRGAN.
67 |
68 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
69 |
70 | We extend ESRGAN for scale x2 and scale x1.
71 |     Note: This is one option for scale 1 and scale 2 in RRDBNet.
72 |     We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
73 |     and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
74 |
75 | Args:
76 | num_in_ch (int): Channel number of inputs.
77 | num_out_ch (int): Channel number of outputs.
78 | num_feat (int): Channel number of intermediate features.
79 | Default: 64
80 | num_block (int): Block number in the trunk network. Defaults: 23
81 | num_grow_ch (int): Channels for each growth. Default: 32.
82 | """
83 |
84 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
85 | super(RRDBNet, self).__init__()
86 | self.scale = scale
87 | if scale == 2:
88 | num_in_ch = num_in_ch * 4
89 | elif scale == 1:
90 | num_in_ch = num_in_ch * 16
91 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
92 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
93 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
94 | # upsample
95 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
96 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
98 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
99 |
100 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
101 |
102 | def forward(self, x):
103 | if self.scale == 2:
104 | feat = pixel_unshuffle(x, scale=2)
105 | elif self.scale == 1:
106 | feat = pixel_unshuffle(x, scale=4)
107 | else:
108 | feat = x
109 | feat = self.conv_first(feat)
110 | body_feat = self.conv_body(self.body(feat))
111 | feat = feat + body_feat
112 | # upsample
113 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
114 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
115 | out = self.conv_last(self.lrelu(self.conv_hr(feat)))
116 | return out
117 |
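A minimal forward-shape sketch of the x2 path. A small `num_block` is used to keep the example light (the checkpoints loaded in this repository use 23 blocks), and the input is random:

```python
net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=2, num_grow_ch=32, scale=2)
x = torch.rand(1, 3, 64, 64)
print(net(x).shape)  # torch.Size([1, 3, 128, 128]) -- pixel-unshuffle to 32x32, then two x2 upsamples
```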
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Face Animation in Real Time
2 | One-shot face animation from a webcam, capable of running in real time.
3 | ## **Example Results**
4 | (Driving video | Result video)
5 | - Original Result without Face Restoration
6 |
7 | https://github.com/sky24h/Face_Mapping_Real_Time/assets/26270672/231778e3-0f37-42c3-8cb0-cf849b22c8a8
8 |
9 | - With Face Restoration
10 |
11 | https://github.com/sky24h/Face_Mapping_Real_Time/assets/26270672/323fb958-77b4-444d-9a21-4d0245fb108c
12 |
13 | # How to Use
14 | ### 0. Dependencies
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 | ### **1. For local webcam**
19 | Tested on an RTX 3090: about 17 FPS without face restoration and about 10 FPS with it.
20 | ```
21 | python camera_local.py --source_image ./assets/source.jpg --restore_face False
22 | ```
23 | The model itself only outputs 256x256 frames, but you can set the output size to 512x512 or larger to get a resized result.
24 |
25 | ### **2. For input driving video**
26 | ```
27 | python camera_local.py --source_image ./assets/source.jpg --restore_face False --driving_video ./assets/driving.mp4 --result_video ./result_video.mp4 --output_size 512
28 | ```
29 | The driving video does not require any preprocessing; it can be used as long as every frame contains a face.
30 |
31 | ### **3. For remote access (Not recommended)**
32 | First, forward the port between the server and the client, for example using VS Code remote SSH as described [here](https://code.visualstudio.com/docs/editor/port-forwarding).
33 | Then run the server side on the remote server, and run the client side on the local machine.
34 |
35 | Note that, due to network latency, the frame rate is low (only 1~2 FPS).
36 |
37 | Server Side:
38 | ```
39 | python remote_server.py --source_image ./assets/source.jpg --restore_face False
40 | ```
41 |
42 | Client Side (copy only `camera_client.py` to the local machine):
43 | ```
44 | python camera_client.py
45 | ```
46 | ---
47 | ### Pre-trained Models
48 | All necessary pre-trained models should be downloaded automatically when running the demo.
49 | If you somehow need to download them manually, please refer to the following links:
50 |
51 | [Motion Transfer Model](https://drive.google.com/file/d/11ZgyjKI5OcB7klcsIdPpCCX38AIX8Soc/view?usp=drive_link)
52 |
53 | [GPEN (Face Restoration Model)](https://drive.google.com/drive/folders/1epln5c8HW1QXfVz6444Fe0hG-vRNavi6?usp=drive_link)
54 |
55 |
56 | # Acknowledgement
57 | Motion transfer is modified from [zhanglonghao1992](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
58 | Face restoration is modified from [GPEN](https://github.com/yangxy/GPEN).
59 |
60 | Thanks to the authors for their great work!
61 |
--------------------------------------------------------------------------------
/assets/driving.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/assets/driving.mp4
--------------------------------------------------------------------------------
/assets/source.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/assets/source.jpg
--------------------------------------------------------------------------------
/camera_client.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import time
3 | import numpy as np
4 | import asyncio
5 | import websockets
6 | from argparse import ArgumentParser
7 |
8 | websocket_port = 8066
9 |
10 |
11 | class VideoCamera(object):
12 | def __init__(self, CameraSize=(640, 480)):
13 | self.video = cv2.VideoCapture(0)
14 | self.video.set(cv2.CAP_PROP_FRAME_WIDTH, CameraSize[0])
15 | self.video.set(cv2.CAP_PROP_FRAME_HEIGHT, CameraSize[1])
16 | self.video.set(cv2.CAP_PROP_FPS, 24)
17 | # check if camera opened successfully
18 | if not self.video.isOpened():
19 | raise Exception("Camera not found")
20 |
21 | def __del__(self):
22 | self.video.release()
23 |
24 | def get_frame(self):
25 | success, image = self.video.read()
26 | image = cv2.flip(image, 1)
27 | return image
28 |
29 |
30 | async def send_image(image, ScreenSize=512, SendSize=256):
31 | # Encode the image as bytes
32 | _, image_data = cv2.imencode(".jpg", cv2.resize(image, (SendSize, SendSize)), [int(cv2.IMWRITE_JPEG_QUALITY), 90])
33 | image_bytes = image_data.tobytes()
34 | # print size
35 | print("Image size: ", len(image_bytes))
36 |
37 | # Connect to the FastAPI WebSocket server
38 | async with websockets.connect("ws://localhost:{}/ws".format(websocket_port)) as websocket:
39 | # Send the image to the server
40 | await websocket.send(image_bytes)
41 | print("Image sent to the server")
42 | # Receive and process the processed frame
43 | try:
44 | processed_frame_data = await websocket.recv()
45 |
46 | # Decode the processed frame
47 | processed_frame = cv2.imdecode(np.frombuffer(processed_frame_data, dtype=np.uint8), -1)
48 | processed_frame = cv2.resize(processed_frame, (ScreenSize * 2, ScreenSize))
49 | # return processed_frame
50 | except Exception as e:
51 | print(e)
52 | # return image
53 | processed_frame = np.ones((ScreenSize, ScreenSize, 3), dtype=np.uint8) * 255
54 | cv2.putText(processed_frame, "No response from the server", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
55 | cv2.imshow("Frame", processed_frame)
56 | cv2.waitKey(1)
57 |
58 |
59 | if __name__ == "__main__":
60 | parser = ArgumentParser()
61 | parser.add_argument("--output_size", default=512, type=int, help="size of the output video")
62 | args = parser.parse_args()
63 |
64 |
65 | camera = VideoCamera()
66 |
67 | frame_count = 0
68 | times = []
69 | while True:
70 | image = camera.get_frame()
71 | frame_count += 1
72 | time_start = time.time()
73 | asyncio.run(send_image(image, ScreenSize=args.output_size, SendSize=256))
74 | times.append(time.time() - time_start)
75 | if frame_count % 10 == 0:
76 | print("FPS: {:.2f}".format(1 / np.mean(times)))
77 | times = []
78 |
--------------------------------------------------------------------------------
/camera_local.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import time
3 | import numpy as np
4 |
5 | from argparse import ArgumentParser
6 | from demo_utils import FaceAnimationClass
7 |
8 |
9 | class VideoCamera(object):
10 | def __init__(self, video_path=0, CameraSize=(640, 480)):
11 | self.video_path = video_path
12 | self.video = cv2.VideoCapture(video_path) if video_path != 0 else cv2.VideoCapture(0)
13 | self.video.set(cv2.CAP_PROP_FRAME_WIDTH, CameraSize[0])
14 | self.video.set(cv2.CAP_PROP_FRAME_HEIGHT, CameraSize[1])
15 | self.video.set(cv2.CAP_PROP_FPS, 24)
16 | # check if camera opened successfully
17 | if video_path == 0 and not self.video.isOpened():
18 | raise Exception("Camera not found")
19 | elif video_path != 0 and not self.video.isOpened():
20 | raise Exception("Video file not found")
21 |
22 | def __del__(self):
23 | self.video.release()
24 |
25 | def get_frame(self):
26 | success, image = self.video.read()
27 | image = cv2.flip(image, 1) if self.video_path == 0 else image
28 | return image
29 |
30 |
31 | def process_frame(image, ScreenSize=512):
32 | face, result = faceanimation.inference(image)
33 | if face.shape[1] != ScreenSize or face.shape[0] != ScreenSize:
34 | face = cv2.resize(face, (ScreenSize, ScreenSize))
35 | if result.shape[0] != ScreenSize or result.shape[1] != ScreenSize:
36 | result = cv2.resize(result, (ScreenSize, ScreenSize))
37 | return cv2.hconcat([face, result])
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = ArgumentParser()
42 | parser.add_argument("--source_image", default="./assets/source.jpg", help="path to source image")
43 | parser.add_argument("--driving_video", default=None, help="path to driving video")
44 | parser.add_argument("--result_video", default="./result_video.mp4", help="path to output")
45 | parser.add_argument("--output_size", default=512, type=int, help="size of the output video")
46 |     parser.add_argument("--restore_face", default="False", type=str, help="restore face, must be 'True' or 'False'")
47 | args = parser.parse_args()
48 | restore_face = True if args.restore_face == 'True' else False if args.restore_face == 'False' else exit('restore_face must be True or False')
49 |
50 | if args.driving_video is None:
51 | video_path = 0
52 | print("Using webcam")
53 | # create window for displaying results
54 | cv2.namedWindow("Face Animation", cv2.WINDOW_NORMAL)
55 | else:
56 | video_path = args.driving_video
57 | print("Using driving video: {}".format(video_path))
58 | camera = VideoCamera(video_path=video_path)
59 | faceanimation = FaceAnimationClass(source_image_path=args.source_image, use_sr=restore_face)
60 |
61 | frames = [] if args.result_video is not None else None
62 | frame_count = 0
63 | times = []
64 | while True:
65 | time_start = time.time()
66 | image = camera.get_frame()
67 | if image is None and frame_count != 0:
68 | print("Video ended")
69 | break
70 | try:
71 | res = process_frame(image, ScreenSize=args.output_size)
72 | frame_count += 1
73 | times.append(time.time() - time_start)
74 | if frame_count % 100 == 0:
75 | print("FPS: {:.2f}".format(1 / np.mean(times)))
76 | times = []
77 | frames.append(res) if args.result_video is not None else None
78 | # display results if using webcam
79 | if args.driving_video is None:
80 | cv2.imshow("Face Animation", res)
81 | if cv2.waitKey(1) & 0xFF == ord("q"):
82 | break
83 | except Exception as e:
84 | print(e)
85 | raise e
86 |
87 | if args.result_video is not None:
88 | import imageio
89 | from tqdm import tqdm
90 |
91 | writer = imageio.get_writer(args.result_video, fps=24, quality=9, macro_block_size=1, codec="libx264", pixelformat="yuv420p")
92 | for frame in tqdm(frames):
93 | writer.append_data(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
94 | writer.close()
95 | print("Video saved to {}".format(args.result_video))
96 |
--------------------------------------------------------------------------------
/face-vid2vid/README.md:
--------------------------------------------------------------------------------
1 | # One-Shot Free-View Neural Talking Head Synthesis
2 | Unofficial PyTorch implementation of the paper "One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing".
3 |
4 | ```Python 3.6``` and ```Pytorch 1.7``` are used.
5 |
6 |
7 | Updates:
8 | --------
9 | ```2021.11.05``` :
10 | * Replace Jacobian with the rotation matrix (Assuming J = R) to avoid estimating Jacobian.
11 | * Correct the rotation matrix.
12 |
13 | ```2021.11.17``` :
14 | * Better Generator, better performance (models and checkpoints have been released).
15 |
16 | Driving | Beta Version | FOMM | New Version:
17 |
18 |
19 | https://user-images.githubusercontent.com/17874285/142828000-db7b324e-c2fd-4fdc-a272-04fb8adbc88a.mp4
20 |
21 |
22 | --------
23 | Driving | FOMM | Ours:
24 | 
25 |
26 | Free-View:
27 | 
28 |
29 | Train:
30 | --------
31 | ```
32 | python run.py --config config/vox-256.yaml --device_ids 0,1,2,3,4,5,6,7
33 | ```
34 |
35 | Demo:
36 | --------
37 | ```
38 | python demo.py --config config/vox-256.yaml --checkpoint path/to/checkpoint --source_image path/to/source --driving_video path/to/driving --relative --adapt_scale --find_best_frame
39 | ```
40 | free-view (e.g. yaw=20, pitch=roll=0):
41 | ```
42 | python demo.py --config config/vox-256.yaml --checkpoint path/to/checkpoint --source_image path/to/source --driving_video path/to/driving --relative --adapt_scale --find_best_frame --free_view --yaw 20 --pitch 0 --roll 0
43 | ```
44 | Note: run ```crop-video.py --inp driving_video.mp4``` first to get the cropping suggestion and crop the raw video.
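For example, the cropping step might look like this; the exact ffmpeg command is printed by ```crop-video.py``` and depends on the detected face trajectory (the numbers below are placeholders):
```
python crop-video.py --inp driving_video.mp4
# prints one ffmpeg command per detected face trajectory, e.g. (placeholder values):
# ffmpeg -i driving_video.mp4 -ss 0.0 -t 10.0 -filter:v "crop=512:512:100:50, scale=256:256" crop.mp4
```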
45 |
46 | Pretrained Model:
47 | --------
48 |
49 | Model | Train Set | Baidu Netdisk | Media Fire |
50 | ------- |------------ |----------- |-------- |
51 | Vox-256-Beta| VoxCeleb-v1 | [Baidu](https://pan.baidu.com/s/1lLS4ArbK2yWelsL-EtwU8g) (PW: c0tc)| [MF](https://www.mediafire.com/folder/rw51an7tk7bh2/TalkingHead) |
52 | Vox-256-New | VoxCeleb-v1 | - | [MF](https://www.mediafire.com/folder/fcvtkn21j57bb/TalkingHead_Update) |
53 | Vox-512 | VoxCeleb-v2 | soon | soon |
54 |
55 | Note:
56 | 1. For now, the Beta Version is not well tuned.
57 | 2. For free-view synthesis, it is recommended that Yaw, Pitch and Roll are within ±45°, ±20° and ±20° respectively.
58 | 3. Face Restoration algorithms ([GPEN](https://github.com/yangxy/GPEN)) can be used for post-processing to significantly improve the resolution.
59 | 
60 |
61 |
62 | Acknowledgement:
63 | --------
64 | Thanks to [NV](https://github.com/NVlabs/face-vid2vid), [AliaksandrSiarohin](https://github.com/AliaksandrSiarohin/first-order-model) and [DeepHeadPose](https://github.com/DriverDistraction/DeepHeadPose).
65 |
--------------------------------------------------------------------------------
/face-vid2vid/animate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | import imageio
8 | from scipy.spatial import ConvexHull
9 | import numpy as np
10 |
11 | from sync_batchnorm import DataParallelWithCallback
12 |
13 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
14 | use_relative_movement=False, use_relative_jacobian=False):
15 | if adapt_movement_scale:
16 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
17 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
18 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
19 | else:
20 | adapt_movement_scale = 1
21 |
22 | kp_new = {k: v for k, v in kp_driving.items()}
23 |
24 | if use_relative_movement:
25 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
26 | kp_value_diff *= adapt_movement_scale
27 | kp_new['value'] = kp_value_diff + kp_source['value']
28 |
29 | if use_relative_jacobian:
30 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
31 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
32 |
33 | return kp_new
34 |
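
A minimal sketch of calling `normalize_kp` with dummy keypoints follows; it assumes the script is run from the ```face-vid2vid``` directory so the import resolves, and the random tensors merely stand in for the keypoint detector's real output.

```python
# Illustrative only: drive normalize_kp with random stand-in keypoints.
import torch
from animate import normalize_kp  # assumes the face-vid2vid directory is the working directory

kp_source = {'value': torch.rand(1, 15, 3) * 2 - 1}           # source keypoints in [-1, 1]
kp_driving_initial = {'value': torch.rand(1, 15, 3) * 2 - 1}  # first driving frame
kp_driving = {'value': kp_driving_initial['value'] + 0.05}    # a slightly moved later frame

kp_norm = normalize_kp(kp_source, kp_driving, kp_driving_initial,
                       adapt_movement_scale=False,
                       use_relative_movement=True,
                       use_relative_jacobian=False)
print(kp_norm['value'].shape)  # torch.Size([1, 15, 3]): source keypoints shifted by the driving motion
```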
--------------------------------------------------------------------------------
/face-vid2vid/augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Code from https://github.com/hassony2/torch_videovision
3 | """
4 |
5 | import numbers
6 |
7 | import random
8 | import numpy as np
9 | import PIL
10 |
11 | from skimage.transform import resize, rotate
12 | from skimage.util import pad
13 | import torchvision
14 |
15 | import warnings
16 |
17 | from skimage import img_as_ubyte, img_as_float
18 |
19 |
20 | def crop_clip(clip, min_h, min_w, h, w):
21 | if isinstance(clip[0], np.ndarray):
22 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
23 |
24 | elif isinstance(clip[0], PIL.Image.Image):
25 | cropped = [
26 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
27 | ]
28 | else:
29 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
30 | 'but got list of {0}'.format(type(clip[0])))
31 | return cropped
32 |
33 |
34 | def pad_clip(clip, h, w):
35 | im_h, im_w = clip[0].shape[:2]
36 | pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
37 | pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
38 |
39 | return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
40 |
41 |
42 | def resize_clip(clip, size, interpolation='bilinear'):
43 | if isinstance(clip[0], np.ndarray):
44 | if isinstance(size, numbers.Number):
45 | im_h, im_w, im_c = clip[0].shape
46 | # Min spatial dim already matches minimal size
47 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
48 | and im_h == size):
49 | return clip
50 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
51 | size = (new_w, new_h)
52 | else:
53 | size = size[1], size[0]
54 |
55 | scaled = [
56 | resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
57 | mode='constant', anti_aliasing=True) for img in clip
58 | ]
59 | elif isinstance(clip[0], PIL.Image.Image):
60 | if isinstance(size, numbers.Number):
61 | im_w, im_h = clip[0].size
62 | # Min spatial dim already matches minimal size
63 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
64 | and im_h == size):
65 | return clip
66 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
67 | size = (new_w, new_h)
68 | else:
69 | size = size[1], size[0]
70 |         if interpolation == 'bilinear':
71 |             pil_inter = PIL.Image.BILINEAR
72 |         else:
73 |             pil_inter = PIL.Image.NEAREST
74 | scaled = [img.resize(size, pil_inter) for img in clip]
75 | else:
76 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
77 | 'but got list of {0}'.format(type(clip[0])))
78 | return scaled
79 |
80 |
81 | def get_resize_sizes(im_h, im_w, size):
82 | if im_w < im_h:
83 | ow = size
84 | oh = int(size * im_h / im_w)
85 | else:
86 | oh = size
87 | ow = int(size * im_w / im_h)
88 | return oh, ow
89 |
90 |
91 | class RandomFlip(object):
92 | def __init__(self, time_flip=False, horizontal_flip=False):
93 | self.time_flip = time_flip
94 | self.horizontal_flip = horizontal_flip
95 |
96 | def __call__(self, clip):
97 | if random.random() < 0.5 and self.time_flip:
98 | return clip[::-1]
99 | if random.random() < 0.5 and self.horizontal_flip:
100 | return [np.fliplr(img) for img in clip]
101 |
102 | return clip
103 |
104 |
105 | class RandomResize(object):
106 | """Resizes a list of (H x W x C) numpy.ndarray to the final size
107 |     The larger the original image is, the longer it takes to
108 |     interpolate.
109 | Args:
110 | interpolation (str): Can be one of 'nearest', 'bilinear'
111 | defaults to nearest
112 |     size (tuple): (width, height)
113 | """
114 |
115 | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
116 | self.ratio = ratio
117 | self.interpolation = interpolation
118 |
119 | def __call__(self, clip):
120 | scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
121 |
122 | if isinstance(clip[0], np.ndarray):
123 | im_h, im_w, im_c = clip[0].shape
124 | elif isinstance(clip[0], PIL.Image.Image):
125 | im_w, im_h = clip[0].size
126 |
127 | new_w = int(im_w * scaling_factor)
128 | new_h = int(im_h * scaling_factor)
129 | new_size = (new_w, new_h)
130 | resized = resize_clip(
131 | clip, new_size, interpolation=self.interpolation)
132 |
133 | return resized
134 |
135 |
136 | class RandomCrop(object):
137 | """Extract random crop at the same location for a list of videos
138 | Args:
139 | size (sequence or int): Desired output size for the
140 | crop in format (h, w)
141 | """
142 |
143 | def __init__(self, size):
144 | if isinstance(size, numbers.Number):
145 | size = (size, size)
146 |
147 | self.size = size
148 |
149 | def __call__(self, clip):
150 | """
151 | Args:
152 | img (PIL.Image or numpy.ndarray): List of videos to be cropped
153 | in format (h, w, c) in numpy.ndarray
154 | Returns:
155 | PIL.Image or numpy.ndarray: Cropped list of videos
156 | """
157 | h, w = self.size
158 | if isinstance(clip[0], np.ndarray):
159 | im_h, im_w, im_c = clip[0].shape
160 | elif isinstance(clip[0], PIL.Image.Image):
161 | im_w, im_h = clip[0].size
162 | else:
163 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
164 | 'but got list of {0}'.format(type(clip[0])))
165 |
166 | clip = pad_clip(clip, h, w)
167 | im_h, im_w = clip.shape[1:3]
168 |         x1 = 0 if w == im_w else random.randint(0, im_w - w)
169 |         y1 = 0 if h == im_h else random.randint(0, im_h - h)
170 | cropped = crop_clip(clip, y1, x1, h, w)
171 |
172 | return cropped
173 |
174 |
175 | class RandomRotation(object):
176 | """Rotate entire clip randomly by a random angle within
177 | given bounds
178 | Args:
179 | degrees (sequence or int): Range of degrees to select from
180 | If degrees is a number instead of sequence like (min, max),
181 | the range of degrees, will be (-degrees, +degrees).
182 | """
183 |
184 | def __init__(self, degrees):
185 | if isinstance(degrees, numbers.Number):
186 | if degrees < 0:
187 | raise ValueError('If degrees is a single number,'
188 | 'must be positive')
189 | degrees = (-degrees, degrees)
190 | else:
191 | if len(degrees) != 2:
192 | raise ValueError('If degrees is a sequence,'
193 | 'it must be of len 2.')
194 |
195 | self.degrees = degrees
196 |
197 | def __call__(self, clip):
198 | """
199 | Args:
200 | img (PIL.Image or numpy.ndarray): List of videos to be cropped
201 | in format (h, w, c) in numpy.ndarray
202 | Returns:
203 | PIL.Image or numpy.ndarray: Cropped list of videos
204 | """
205 | angle = random.uniform(self.degrees[0], self.degrees[1])
206 | if isinstance(clip[0], np.ndarray):
207 | rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
208 | elif isinstance(clip[0], PIL.Image.Image):
209 | rotated = [img.rotate(angle) for img in clip]
210 | else:
211 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
212 | 'but got list of {0}'.format(type(clip[0])))
213 |
214 | return rotated
215 |
216 |
217 | class ColorJitter(object):
218 | """Randomly change the brightness, contrast and saturation and hue of the clip
219 | Args:
220 | brightness (float): How much to jitter brightness. brightness_factor
221 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
222 | contrast (float): How much to jitter contrast. contrast_factor
223 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
224 | saturation (float): How much to jitter saturation. saturation_factor
225 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
226 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
227 | [-hue, hue]. Should be >=0 and <= 0.5.
228 | """
229 |
230 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
231 | self.brightness = brightness
232 | self.contrast = contrast
233 | self.saturation = saturation
234 | self.hue = hue
235 |
236 | def get_params(self, brightness, contrast, saturation, hue):
237 | if brightness > 0:
238 | brightness_factor = random.uniform(
239 | max(0, 1 - brightness), 1 + brightness)
240 | else:
241 | brightness_factor = None
242 |
243 | if contrast > 0:
244 | contrast_factor = random.uniform(
245 | max(0, 1 - contrast), 1 + contrast)
246 | else:
247 | contrast_factor = None
248 |
249 | if saturation > 0:
250 | saturation_factor = random.uniform(
251 | max(0, 1 - saturation), 1 + saturation)
252 | else:
253 | saturation_factor = None
254 |
255 | if hue > 0:
256 | hue_factor = random.uniform(-hue, hue)
257 | else:
258 | hue_factor = None
259 | return brightness_factor, contrast_factor, saturation_factor, hue_factor
260 |
261 | def __call__(self, clip):
262 | """
263 | Args:
264 | clip (list): list of PIL.Image
265 | Returns:
266 | list PIL.Image : list of transformed PIL.Image
267 | """
268 | if isinstance(clip[0], np.ndarray):
269 | brightness, contrast, saturation, hue = self.get_params(
270 | self.brightness, self.contrast, self.saturation, self.hue)
271 |
272 | # Create img transform function sequence
273 | img_transforms = []
274 | if brightness is not None:
275 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
276 | if saturation is not None:
277 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
278 | if hue is not None:
279 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
280 | if contrast is not None:
281 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
282 | random.shuffle(img_transforms)
283 | img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
284 | img_as_float]
285 |
286 | with warnings.catch_warnings():
287 | warnings.simplefilter("ignore")
288 | jittered_clip = []
289 | for img in clip:
290 | jittered_img = img
291 | for func in img_transforms:
292 | jittered_img = func(jittered_img)
293 | jittered_clip.append(jittered_img.astype('float32'))
294 | elif isinstance(clip[0], PIL.Image.Image):
295 | brightness, contrast, saturation, hue = self.get_params(
296 | self.brightness, self.contrast, self.saturation, self.hue)
297 |
298 | # Create img transform function sequence
299 | img_transforms = []
300 | if brightness is not None:
301 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
302 | if saturation is not None:
303 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
304 | if hue is not None:
305 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
306 | if contrast is not None:
307 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
308 | random.shuffle(img_transforms)
309 |
310 | # Apply to all videos
311 | jittered_clip = []
312 |             for img in clip:
313 |                 for func in img_transforms:
314 |                     img = func(img)  # chain the jitter ops on the running result
315 |                 jittered_clip.append(img)
316 |
317 | else:
318 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
319 | 'but got list of {0}'.format(type(clip[0])))
320 | return jittered_clip
321 |
322 |
323 | class AllAugmentationTransform:
324 | def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
325 | self.transforms = []
326 |
327 | if flip_param is not None:
328 | self.transforms.append(RandomFlip(**flip_param))
329 |
330 | if rotation_param is not None:
331 | self.transforms.append(RandomRotation(**rotation_param))
332 |
333 | if resize_param is not None:
334 | self.transforms.append(RandomResize(**resize_param))
335 |
336 | if crop_param is not None:
337 | self.transforms.append(RandomCrop(**crop_param))
338 |
339 | if jitter_param is not None:
340 | self.transforms.append(ColorJitter(**jitter_param))
341 |
342 | def __call__(self, clip):
343 | for t in self.transforms:
344 | clip = t(clip)
345 | return clip
346 |
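
A minimal sketch of applying the pipeline that the training config enables (flip + color jitter) to a dummy two-frame clip; it assumes the ```face-vid2vid``` directory is importable and float32 frames in [0, 1], matching what `FramesDataset` feeds in.

```python
# Illustrative only: run the config's augmentations on random frames.
import numpy as np
from augmentation import AllAugmentationTransform

augmentation_params = {
    'flip_param': {'horizontal_flip': True, 'time_flip': True},
    'jitter_param': {'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1},
}
transform = AllAugmentationTransform(**augmentation_params)

clip = [np.random.rand(256, 256, 3).astype('float32') for _ in range(2)]  # two fake frames
augmented = transform(clip)
print(len(augmented), augmented[0].shape, augmented[0].dtype)  # 2 (256, 256, 3) float32
```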
--------------------------------------------------------------------------------
/face-vid2vid/config/vox-256-spade.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | root_dir:
3 | frame_shape: [256, 256, 3]
4 | id_sampling: True
5 | pairs_list: None
6 | augmentation_params:
7 | flip_param:
8 | horizontal_flip: True
9 | time_flip: True
10 | jitter_param:
11 | brightness: 0.1
12 | contrast: 0.1
13 | saturation: 0.1
14 | hue: 0.1
15 |
16 |
17 | model_params:
18 | common_params:
19 | num_kp: 15
20 | image_channel: 3
21 | feature_channel: 32
22 | estimate_jacobian: False
23 | kp_detector_params:
24 | temperature: 0.1
25 | block_expansion: 32
26 | max_features: 1024
27 | scale_factor: 0.25
28 | num_blocks: 5
29 | reshape_channel: 16384 # 16384 = 1024 * 16
30 | reshape_depth: 16
31 | he_estimator_params:
32 | block_expansion: 64
33 | max_features: 2048
34 | num_bins: 66
35 | generator_params:
36 | block_expansion: 64
37 | max_features: 512
38 | num_down_blocks: 2
39 | reshape_channel: 32
40 | reshape_depth: 16 # 512 = 32 * 16
41 | num_resblocks: 6
42 | estimate_occlusion_map: True
43 | dense_motion_params:
44 | block_expansion: 32
45 | max_features: 1024
46 | num_blocks: 5
47 | # reshape_channel: 32
48 | reshape_depth: 16
49 | compress: 4
50 | discriminator_params:
51 | scales: [1]
52 | block_expansion: 32
53 | max_features: 512
54 | num_blocks: 4
55 | sn: True
56 |
57 | train_params:
58 | num_epochs: 200
59 | num_repeats: 75
60 | epoch_milestones: [180,]
61 | lr_generator: 2.0e-4
62 | lr_discriminator: 2.0e-4
63 | lr_kp_detector: 2.0e-4
64 | lr_he_estimator: 2.0e-4
65 | gan_mode: 'hinge' # hinge or ls
66 | batch_size: 1
67 | scales: [1, 0.5, 0.25, 0.125]
68 | checkpoint_freq: 60
69 | hopenet_snapshot: './checkpoints/hopenet_robust_alpha1.pkl'
70 | transform_params:
71 | sigma_affine: 0.05
72 | sigma_tps: 0.005
73 | points_tps: 5
74 | loss_weights:
75 | generator_gan: 1
76 | discriminator_gan: 1
77 | feature_matching: [10, 10, 10, 10]
78 | perceptual: [10, 10, 10, 10, 10]
79 | equivariance_value: 10
80 | equivariance_jacobian: 0
81 | keypoint: 10
82 | headpose: 20
83 | expression: 5
84 |
85 | visualizer_params:
86 | kp_size: 5
87 | draw_border: True
88 | colormap: 'gist_rainbow'
89 |
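
A minimal sketch of reading this file with PyYAML and picking out a few of the fields above; how run.py actually consumes the config is not shown here.

```python
# Illustrative only: load the config and inspect a few entries.
import yaml

with open('config/vox-256-spade.yaml') as f:
    config = yaml.safe_load(f)

common = config['model_params']['common_params']
print(common['num_kp'], common['estimate_jacobian'])          # 15 False
print(config['train_params']['batch_size'])                   # 1
print(list(config['dataset_params']['augmentation_params']))  # ['flip_param', 'jitter_param']
```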
--------------------------------------------------------------------------------
/face-vid2vid/crop-video.py:
--------------------------------------------------------------------------------
1 | import face_alignment
2 | import skimage.io
3 | import numpy
4 | from argparse import ArgumentParser
5 | from skimage import img_as_ubyte
6 | from skimage.transform import resize
7 | from tqdm import tqdm
8 | import os
9 | import imageio
10 | import numpy as np
11 | import warnings
12 | warnings.filterwarnings("ignore")
13 |
14 | def extract_bbox(frame, fa):
15 | if max(frame.shape[0], frame.shape[1]) > 640:
16 | scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
17 | frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
18 | frame = img_as_ubyte(frame)
19 | else:
20 | scale_factor = 1
21 | frame = frame[..., :3]
22 | bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
23 | if len(bboxes) == 0:
24 | return []
25 | return np.array(bboxes)[:, :-1] * scale_factor
26 |
27 |
28 |
29 | def bb_intersection_over_union(boxA, boxB):
30 | xA = max(boxA[0], boxB[0])
31 | yA = max(boxA[1], boxB[1])
32 | xB = min(boxA[2], boxB[2])
33 | yB = min(boxA[3], boxB[3])
34 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
35 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
36 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
37 | iou = interArea / float(boxAArea + boxBArea - interArea)
38 | return iou
39 |
40 |
41 | def join(tube_bbox, bbox):
42 | xA = min(tube_bbox[0], bbox[0])
43 | yA = min(tube_bbox[1], bbox[1])
44 | xB = max(tube_bbox[2], bbox[2])
45 | yB = max(tube_bbox[3], bbox[3])
46 | return (xA, yA, xB, yB)
47 |
48 |
49 | def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
50 | left, top, right, bot = tube_bbox
51 | width = right - left
52 | height = bot - top
53 |
54 | #Computing aspect preserving bbox
55 | width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
56 | height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
57 |
58 | left = int(left - width_increase * width)
59 | top = int(top - height_increase * height)
60 | right = int(right + width_increase * width)
61 | bot = int(bot + height_increase * height)
62 |
63 | top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
64 | h, w = bot - top, right - left
65 |
66 | start = start / fps
67 | end = end / fps
68 | time = end - start
69 |
70 | scale = f'{image_shape[0]}:{image_shape[1]}'
71 |
72 | return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
73 |
74 |
75 | def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
76 | commands = []
77 | for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
78 | if (end - start) > args.min_frames:
79 | command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
80 | commands.append(command)
81 | return commands
82 |
83 |
84 | def process_video(args):
85 | device = 'cpu' if args.cpu else 'cuda'
86 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
87 | video = imageio.get_reader(args.inp)
88 |
89 | trajectories = []
90 | previous_frame = None
91 | fps = video.get_meta_data()['fps']
92 | commands = []
93 | try:
94 | for i, frame in tqdm(enumerate(video)):
95 | frame_shape = frame.shape
96 | bboxes = extract_bbox(frame, fa)
97 | ## For each trajectory check the criterion
98 | not_valid_trajectories = []
99 | valid_trajectories = []
100 |
101 | for trajectory in trajectories:
102 | tube_bbox = trajectory[0]
103 | intersection = 0
104 | for bbox in bboxes:
105 | intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
106 | if intersection > args.iou_with_initial:
107 | valid_trajectories.append(trajectory)
108 | else:
109 | not_valid_trajectories.append(trajectory)
110 |
111 | commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
112 | trajectories = valid_trajectories
113 |
114 | ## Assign bbox to trajectories, create new trajectories
115 | for bbox in bboxes:
116 | intersection = 0
117 | current_trajectory = None
118 | for trajectory in trajectories:
119 | tube_bbox = trajectory[0]
120 | current_intersection = bb_intersection_over_union(tube_bbox, bbox)
121 | if intersection < current_intersection and current_intersection > args.iou_with_initial:
122 |                     intersection = current_intersection
123 | current_trajectory = trajectory
124 |
125 | ## Create new trajectory
126 | if current_trajectory is None:
127 | trajectories.append([bbox, bbox, i, i])
128 | else:
129 | current_trajectory[3] = i
130 | current_trajectory[1] = join(current_trajectory[1], bbox)
131 |
132 |
133 | except IndexError as e:
134 | raise (e)
135 |
136 | commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
137 | return commands
138 |
139 |
140 | if __name__ == "__main__":
141 | parser = ArgumentParser()
142 |
143 | parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
144 | help="Image shape")
145 | parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
146 |     parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed IoU with the initial bbox")
147 | parser.add_argument("--inp", required=True, help='Input image or video')
148 | parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
149 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
150 |
151 |
152 | args = parser.parse_args()
153 |
154 | commands = process_video(args)
155 | for command in commands:
156 | print (command)
157 |
158 |
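A worked example of the IoU that `bb_intersection_over_union` above computes, with the arithmetic spelled out inline (the boxes are made up):

```python
# Inclusive pixel coordinates (x1, y1, x2, y2), matching the +1 terms in the function above.
boxA = (0, 0, 9, 9)     # 10 x 10 box
boxB = (5, 5, 14, 14)   # 10 x 10 box shifted by 5 pixels in x and y
inter = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]) + 1) * \
        max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]) + 1)  # 5 * 5 = 25
union = 10 * 10 + 10 * 10 - inter                                  # 100 + 100 - 25 = 175
print(inter / union)  # 0.142857..., the value bb_intersection_over_union(boxA, boxB) returns
```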
--------------------------------------------------------------------------------
/face-vid2vid/frames_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from skimage import io, img_as_float32
3 | from skimage.color import gray2rgb
4 | from sklearn.model_selection import train_test_split
5 | from imageio import mimread
6 |
7 | import numpy as np
8 | from torch.utils.data import Dataset
9 | import pandas as pd
10 | from augmentation import AllAugmentationTransform
11 | import glob
12 |
13 |
14 | def read_video(name, frame_shape):
15 | """
16 | Read video which can be:
17 | - an image of concatenated frames
18 |       - '.mp4' or '.gif'
19 | - folder with videos
20 | """
21 |
22 | if os.path.isdir(name):
23 | frames = sorted(os.listdir(name))
24 | num_frames = len(frames)
25 | video_array = np.array(
26 | [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
27 | elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
28 | image = io.imread(name)
29 |
30 | if len(image.shape) == 2 or image.shape[2] == 1:
31 | image = gray2rgb(image)
32 |
33 | if image.shape[2] == 4:
34 | image = image[..., :3]
35 |
36 | image = img_as_float32(image)
37 |
38 | video_array = np.moveaxis(image, 1, 0)
39 |
40 | video_array = video_array.reshape((-1,) + frame_shape)
41 | video_array = np.moveaxis(video_array, 1, 2)
42 | elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
43 | video = np.array(mimread(name))
44 | if len(video.shape) == 3:
45 | video = np.array([gray2rgb(frame) for frame in video])
46 | if video.shape[-1] == 4:
47 | video = video[..., :3]
48 | video_array = img_as_float32(video)
49 | else:
50 | raise Exception("Unknown file extensions %s" % name)
51 |
52 | return video_array
53 |
54 |
55 | class FramesDataset(Dataset):
56 | """
57 | Dataset of videos, each video can be represented as:
58 | - an image of concatenated frames
59 | - '.mp4' or '.gif'
60 | - folder with all frames
61 | """
62 |
63 | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
64 | random_seed=0, pairs_list=None, augmentation_params=None):
65 | self.root_dir = root_dir
66 | self.videos = os.listdir(root_dir)
67 | self.frame_shape = tuple(frame_shape)
68 | self.pairs_list = pairs_list
69 | self.id_sampling = id_sampling
70 | if os.path.exists(os.path.join(root_dir, 'train')):
71 | assert os.path.exists(os.path.join(root_dir, 'test'))
72 | print("Use predefined train-test split.")
73 | if id_sampling:
74 | train_videos = {os.path.basename(video).split('#')[0] for video in
75 | os.listdir(os.path.join(root_dir, 'train'))}
76 | train_videos = list(train_videos)
77 | else:
78 | train_videos = os.listdir(os.path.join(root_dir, 'train'))
79 | test_videos = os.listdir(os.path.join(root_dir, 'test'))
80 | self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
81 | else:
82 | print("Use random train-test split.")
83 | train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
84 |
85 | if is_train:
86 | self.videos = train_videos
87 | else:
88 | self.videos = test_videos
89 |
90 | self.is_train = is_train
91 |
92 | if self.is_train:
93 | self.transform = AllAugmentationTransform(**augmentation_params)
94 | else:
95 | self.transform = None
96 |
97 | def __len__(self):
98 | return len(self.videos)
99 |
100 | def __getitem__(self, idx):
101 | if self.is_train and self.id_sampling:
102 | name = self.videos[idx]
103 | path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
104 | else:
105 | name = self.videos[idx]
106 | path = os.path.join(self.root_dir, name)
107 |
108 | video_name = os.path.basename(path)
109 |
110 | if self.is_train and os.path.isdir(path):
111 | frames = os.listdir(path)
112 | num_frames = len(frames)
113 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
114 | video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
115 | else:
116 | video_array = read_video(path, frame_shape=self.frame_shape)
117 | num_frames = len(video_array)
118 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
119 | num_frames)
120 | video_array = video_array[frame_idx]
121 |
122 | if self.transform is not None:
123 | video_array = self.transform(video_array)
124 |
125 | out = {}
126 | if self.is_train:
127 | source = np.array(video_array[0], dtype='float32')
128 | driving = np.array(video_array[1], dtype='float32')
129 |
130 | out['driving'] = driving.transpose((2, 0, 1))
131 | out['source'] = source.transpose((2, 0, 1))
132 | else:
133 | video = np.array(video_array, dtype='float32')
134 | out['video'] = video.transpose((3, 0, 1, 2))
135 |
136 | out['name'] = video_name
137 |
138 | return out
139 |
140 |
141 | class DatasetRepeater(Dataset):
142 | """
143 | Pass several times over the same dataset for better i/o performance
144 | """
145 |
146 | def __init__(self, dataset, num_repeats=100):
147 | self.dataset = dataset
148 | self.num_repeats = num_repeats
149 |
150 | def __len__(self):
151 | return self.num_repeats * self.dataset.__len__()
152 |
153 | def __getitem__(self, idx):
154 | return self.dataset[idx % self.dataset.__len__()]
155 |
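
A minimal sketch of how `FramesDataset` and `DatasetRepeater` plug into a `DataLoader`, using the augmentation parameters from the training config; `data/vox` is a placeholder path to a prepared dataset and the snippet assumes the ```face-vid2vid``` directory is importable.

```python
# Illustrative only: build the training dataset and pull one batch.
from torch.utils.data import DataLoader
from frames_dataset import FramesDataset, DatasetRepeater

augmentation_params = {
    'flip_param': {'horizontal_flip': True, 'time_flip': True},
    'jitter_param': {'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1},
}
dataset = FramesDataset(root_dir='data/vox', frame_shape=(256, 256, 3),  # placeholder dataset path
                        id_sampling=True, is_train=True,
                        augmentation_params=augmentation_params)
dataset = DatasetRepeater(dataset, num_repeats=75)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4, drop_last=True)

batch = next(iter(loader))
print(batch['source'].shape, batch['driving'].shape)  # torch.Size([1, 3, 256, 256]) each
```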
--------------------------------------------------------------------------------
/face-vid2vid/logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import imageio
5 |
6 | import os
7 | from skimage.draw import circle_perimeter
8 |
9 | import matplotlib.pyplot as plt
10 | import collections
11 |
12 |
13 | class Logger:
14 | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name="log.txt"):
15 | self.loss_list = []
16 | self.cpk_dir = log_dir
17 | self.visualizations_dir = os.path.join(log_dir, "train-vis")
18 | if not os.path.exists(self.visualizations_dir):
19 | os.makedirs(self.visualizations_dir)
20 | self.log_file = open(os.path.join(log_dir, log_file_name), "a")
21 | self.zfill_num = zfill_num
22 | self.visualizer = Visualizer(**visualizer_params)
23 | self.checkpoint_freq = checkpoint_freq
24 | self.epoch = 0
25 | self.best_loss = float("inf")
26 | self.names = None
27 |
28 | def log_scores(self, loss_names):
29 | loss_mean = np.array(self.loss_list).mean(axis=0)
30 |
31 | loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
32 | loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
33 |
34 | print(loss_string, file=self.log_file)
35 | self.loss_list = []
36 | self.log_file.flush()
37 |
38 | def visualize_rec(self, inp, out):
39 | image = self.visualizer.visualize(inp["driving"], inp["source"], out)
40 | imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
41 |
42 | def save_cpk(self, emergent=False):
43 | cpk = {k: v.state_dict() for k, v in self.models.items()}
44 | cpk["epoch"] = self.epoch
45 | cpk_path = os.path.join(self.cpk_dir, "%s-checkpoint.pth.tar" % str(self.epoch).zfill(self.zfill_num))
46 | if not (os.path.exists(cpk_path) and emergent):
47 | torch.save(cpk, cpk_path)
48 |
49 | @staticmethod
50 | def load_cpk(
51 | checkpoint_path,
52 | generator=None,
53 | discriminator=None,
54 | kp_detector=None,
55 | he_estimator=None,
56 | optimizer_generator=None,
57 | optimizer_discriminator=None,
58 | optimizer_kp_detector=None,
59 | optimizer_he_estimator=None,
60 | ):
61 | checkpoint = torch.load(checkpoint_path)
62 | if generator is not None:
63 | generator.load_state_dict(checkpoint["generator"])
64 | if kp_detector is not None:
65 | kp_detector.load_state_dict(checkpoint["kp_detector"])
66 | if he_estimator is not None:
67 | he_estimator.load_state_dict(checkpoint["he_estimator"])
68 | if discriminator is not None:
69 | try:
70 | discriminator.load_state_dict(checkpoint["discriminator"])
71 | except:
72 |                 print("No discriminator in the state-dict. Discriminator will be randomly initialized")
73 | if optimizer_generator is not None:
74 | optimizer_generator.load_state_dict(checkpoint["optimizer_generator"])
75 | if optimizer_discriminator is not None:
76 | try:
77 | optimizer_discriminator.load_state_dict(checkpoint["optimizer_discriminator"])
78 | except RuntimeError as e:
79 |                 print("No discriminator optimizer in the state-dict. Optimizer will not be initialized")
80 | if optimizer_kp_detector is not None:
81 | optimizer_kp_detector.load_state_dict(checkpoint["optimizer_kp_detector"])
82 | if optimizer_he_estimator is not None:
83 | optimizer_he_estimator.load_state_dict(checkpoint["optimizer_he_estimator"])
84 |
85 | return checkpoint["epoch"]
86 |
87 | def __enter__(self):
88 | return self
89 |
90 | def __exit__(self, exc_type, exc_val, exc_tb):
91 | if "models" in self.__dict__:
92 | self.save_cpk()
93 | self.log_file.close()
94 |
95 | def log_iter(self, losses):
96 | losses = collections.OrderedDict(losses.items())
97 | if self.names is None:
98 | self.names = list(losses.keys())
99 | self.loss_list.append(list(losses.values()))
100 |
101 | def log_epoch(self, epoch, models, inp, out):
102 | self.epoch = epoch
103 | self.models = models
104 | if (self.epoch + 1) % self.checkpoint_freq == 0:
105 | self.save_cpk()
106 | self.log_scores(self.names)
107 | self.visualize_rec(inp, out)
108 |
109 |
110 | class Visualizer:
111 | def __init__(self, kp_size=5, draw_border=False, colormap="gist_rainbow"):
112 | self.kp_size = kp_size
113 | self.draw_border = draw_border
114 | self.colormap = plt.get_cmap(colormap)
115 |
116 | def draw_image_with_kp(self, image, kp_array):
117 | image = np.copy(image)
118 | spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
119 | kp_array = spatial_size * (kp_array + 1) / 2
120 | num_kp = kp_array.shape[0]
121 | for kp_ind, kp in enumerate(kp_array):
122 |             rr, cc = circle_perimeter(int(kp[1]), int(kp[0]), self.kp_size, shape=image.shape[:2])
123 | image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
124 | return image
125 |
126 | def create_image_column_with_kp(self, images, kp):
127 | image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
128 | return self.create_image_column(image_array)
129 |
130 | def create_image_column(self, images):
131 | if self.draw_border:
132 | images = np.copy(images)
133 | images[:, :, [0, -1]] = (1, 1, 1)
134 | images[:, :, [0, -1]] = (1, 1, 1)
135 | return np.concatenate(list(images), axis=0)
136 |
137 | def create_image_grid(self, *args):
138 | out = []
139 | for arg in args:
140 | if type(arg) == tuple:
141 | out.append(self.create_image_column_with_kp(arg[0], arg[1]))
142 | else:
143 | out.append(self.create_image_column(arg))
144 | return np.concatenate(out, axis=1)
145 |
146 | def visualize(self, driving, source, out):
147 | images = []
148 |
149 | # Source image with keypoints
150 | source = source.data.cpu()
151 | kp_source = out["kp_source"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d
152 | source = np.transpose(source, [0, 2, 3, 1])
153 | images.append((source, kp_source))
154 |
155 | # Equivariance visualization
156 | if "transformed_frame" in out:
157 | transformed = out["transformed_frame"].data.cpu().numpy()
158 | transformed = np.transpose(transformed, [0, 2, 3, 1])
159 | transformed_kp = out["transformed_kp"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d
160 | images.append((transformed, transformed_kp))
161 |
162 | # Driving image with keypoints
163 | kp_driving = out["kp_driving"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d
164 | driving = driving.data.cpu().numpy()
165 | driving = np.transpose(driving, [0, 2, 3, 1])
166 | images.append((driving, kp_driving))
167 |
168 | # Result
169 | prediction = out["prediction"].data.cpu().numpy()
170 | prediction = np.transpose(prediction, [0, 2, 3, 1])
171 | images.append(prediction)
172 |
173 | ## Occlusion map
174 | if "occlusion_map" in out:
175 | occlusion_map = out["occlusion_map"].data.cpu().repeat(1, 3, 1, 1)
176 | occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
177 | occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
178 | images.append(occlusion_map)
179 |
180 | ## Mask
181 | if "mask" in out:
182 | for i in range(out["mask"].shape[1]):
183 | mask = out["mask"][:, i : (i + 1)].data.cpu().sum(2).repeat(1, 3, 1, 1) # (n, 3, h, w)
184 | # mask = F.softmax(mask.view(mask.shape[0], mask.shape[1], -1), dim=2).view(mask.shape)
185 | mask = F.interpolate(mask, size=source.shape[1:3]).numpy()
186 | mask = np.transpose(mask, [0, 2, 3, 1])
187 |
188 | if i != 0:
189 | color = np.array(self.colormap((i - 1) / (out["mask"].shape[1] - 1)))[:3]
190 | else:
191 | color = np.array((0, 0, 0))
192 |
193 | color = color.reshape((1, 1, 1, 3))
194 |
195 | if i != 0:
196 | images.append(mask * color)
197 | else:
198 | images.append(mask)
199 |
200 | image = self.create_image_grid(*images)
201 | image = (255 * image).astype(np.uint8)
202 | return image
203 |
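
A minimal sketch of resuming from a checkpoint with the static `Logger.load_cpk`; the checkpoint path is a placeholder and `generator`, `kp_detector` and `he_estimator` are assumed to be already-constructed modules matching the saved state dicts.

```python
# Illustrative only: restore model weights and get the epoch to resume from.
from logger import Logger

start_epoch = Logger.load_cpk(
    'path/to/checkpoint.pth.tar',   # placeholder checkpoint path
    generator=generator,            # assumed: generator built from the config
    kp_detector=kp_detector,        # assumed: keypoint detector built from the config
    he_estimator=he_estimator,      # assumed: head-pose / expression estimator
)
print('resuming from epoch', start_epoch)
```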
--------------------------------------------------------------------------------
/face-vid2vid/modules/dense_motion.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 | from modules.util import Hourglass, make_coordinate_grid, kp2gaussian
5 |
6 | from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
7 |
8 |
9 | class DenseMotionNetwork(nn.Module):
10 | """
11 |     Module that predicts dense motion from a sparse motion representation given by kp_source and kp_driving
12 | """
13 |
14 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
15 | estimate_occlusion_map=False):
16 | super(DenseMotionNetwork, self).__init__()
17 | # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
18 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
19 |
20 | self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
21 |
22 | self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
23 | self.norm = BatchNorm3d(compress, affine=True)
24 |
25 | if estimate_occlusion_map:
26 | # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
27 | self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
28 | else:
29 | self.occlusion = None
30 |
31 | self.num_kp = num_kp
32 |
33 |
34 | def create_sparse_motions(self, feature, kp_driving, kp_source):
35 | bs, _, d, h, w = feature.shape
36 | identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
37 | identity_grid = identity_grid.view(1, 1, d, h, w, 3)
38 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
39 |
40 | k = coordinate_grid.shape[1]
41 |
42 | # if 'jacobian' in kp_driving:
43 | if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
44 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
45 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
46 | jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
47 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
48 | coordinate_grid = coordinate_grid.squeeze(-1)
49 | '''
50 | if 'rot' in kp_driving:
51 | rot_s = kp_source['rot']
52 | rot_d = kp_driving['rot']
53 | rot = torch.einsum('bij, bjk->bki', rot_s, torch.inverse(rot_d))
54 | rot = rot.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
55 | rot = rot.repeat(1, k, d, h, w, 1, 1)
56 | # print(rot.shape)
57 | coordinate_grid = torch.matmul(rot, coordinate_grid.unsqueeze(-1))
58 | coordinate_grid = coordinate_grid.squeeze(-1)
59 | # print(coordinate_grid.shape)
60 | '''
61 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
62 |
63 | #adding background feature
64 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
65 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
66 |
67 | # sparse_motions = driving_to_source
68 |
69 | return sparse_motions
70 |
71 | def create_deformed_feature(self, feature, sparse_motions):
72 | bs, _, d, h, w = feature.shape
73 | feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
74 | feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
75 | sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
76 | sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)
77 | sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
78 | return sparse_deformed
79 |
80 | def create_heatmap_representations(self, feature, kp_driving, kp_source):
81 | spatial_size = feature.shape[3:]
82 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
83 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
84 | heatmap = gaussian_driving - gaussian_source
85 |
86 | # adding background feature
87 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
88 | heatmap = torch.cat([zeros, heatmap], dim=1)
89 | heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
90 | return heatmap
91 |
92 | def forward(self, feature, kp_driving, kp_source):
93 | bs, _, d, h, w = feature.shape
94 |
95 | feature = self.compress(feature)
96 | feature = self.norm(feature)
97 | feature = F.relu(feature)
98 |
99 | out_dict = dict()
100 | sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
101 | deformed_feature = self.create_deformed_feature(feature, sparse_motion)
102 |
103 | heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
104 |
105 | input = torch.cat([heatmap, deformed_feature], dim=2)
106 | input = input.view(bs, -1, d, h, w)
107 |
108 | # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w)
109 |
110 | prediction = self.hourglass(input)
111 |
112 | mask = self.mask(prediction)
113 | mask = F.softmax(mask, dim=1)
114 | out_dict['mask'] = mask
115 | mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
116 | sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
117 | deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w)
118 | deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
119 |
120 | out_dict['deformation'] = deformation
121 |
122 | if self.occlusion:
123 | bs, c, d, h, w = prediction.shape
124 | prediction = prediction.view(bs, -1, h, w)
125 | occlusion_map = torch.sigmoid(self.occlusion(prediction))
126 | out_dict['occlusion_map'] = occlusion_map
127 |
128 | return out_dict
129 |
--------------------------------------------------------------------------------
/face-vid2vid/modules/discriminator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from modules.util import kp2gaussian
4 | import torch
5 |
6 |
7 | class DownBlock2d(nn.Module):
8 | """
9 | Simple block for processing video (encoder).
10 | """
11 |
12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
13 | super(DownBlock2d, self).__init__()
14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
15 |
16 | if sn:
17 | self.conv = nn.utils.spectral_norm(self.conv)
18 |
19 | if norm:
20 | self.norm = nn.InstanceNorm2d(out_features, affine=True)
21 | else:
22 | self.norm = None
23 | self.pool = pool
24 |
25 | def forward(self, x):
26 | out = x
27 | out = self.conv(out)
28 | if self.norm:
29 | out = self.norm(out)
30 | out = F.leaky_relu(out, 0.2)
31 | if self.pool:
32 | out = F.avg_pool2d(out, (2, 2))
33 | return out
34 |
35 |
36 | class Discriminator(nn.Module):
37 | """
38 | Discriminator similar to Pix2Pix
39 | """
40 |
41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
42 | sn=False, **kwargs):
43 | super(Discriminator, self).__init__()
44 |
45 | down_blocks = []
46 | for i in range(num_blocks):
47 | down_blocks.append(
48 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
49 | min(max_features, block_expansion * (2 ** (i + 1))),
50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
51 |
52 | self.down_blocks = nn.ModuleList(down_blocks)
53 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
54 | if sn:
55 | self.conv = nn.utils.spectral_norm(self.conv)
56 |
57 | def forward(self, x):
58 | feature_maps = []
59 | out = x
60 |
61 | for down_block in self.down_blocks:
62 | feature_maps.append(down_block(out))
63 | out = feature_maps[-1]
64 | prediction_map = self.conv(out)
65 |
66 | return feature_maps, prediction_map
67 |
68 |
69 | class MultiScaleDiscriminator(nn.Module):
70 | """
71 | Multi-scale (scale) discriminator
72 | """
73 |
74 | def __init__(self, scales=(), **kwargs):
75 | super(MultiScaleDiscriminator, self).__init__()
76 | self.scales = scales
77 | discs = {}
78 | for scale in scales:
79 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
80 | self.discs = nn.ModuleDict(discs)
81 |
82 | def forward(self, x):
83 | out_dict = {}
84 | for scale, disc in self.discs.items():
85 | scale = str(scale).replace('-', '.')
86 | key = 'prediction_' + scale
87 | feature_maps, prediction_map = disc(x[key])
88 | out_dict['feature_maps_' + scale] = feature_maps
89 | out_dict['prediction_map_' + scale] = prediction_map
90 | return out_dict
91 |
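
A minimal smoke test of the discriminator at the single scale the training config uses (scales: [1], block_expansion: 32, num_blocks: 4, sn: True); it assumes the ```face-vid2vid``` directory is the working directory so `modules.discriminator` imports resolve.

```python
# Illustrative only: one forward pass through the multi-scale wrapper with a random image.
import torch
from modules.discriminator import MultiScaleDiscriminator

disc = MultiScaleDiscriminator(scales=[1], num_channels=3, block_expansion=32,
                               max_features=512, num_blocks=4, sn=True)
x = {'prediction_1': torch.randn(1, 3, 256, 256)}
out = disc(x)
print(out['prediction_map_1'].shape)  # torch.Size([1, 1, 26, 26]): patch-level real/fake logits
print(len(out['feature_maps_1']))     # 4 intermediate feature maps for feature matching
```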
--------------------------------------------------------------------------------
/face-vid2vid/modules/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
5 | from modules.dense_motion import DenseMotionNetwork
6 |
7 |
8 | class OcclusionAwareGenerator(nn.Module):
9 | """
10 | Generator follows NVIDIA architecture.
11 | """
12 |
13 | def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
14 | num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
15 | super(OcclusionAwareGenerator, self).__init__()
16 |
17 | if dense_motion_params is not None:
18 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
19 | estimate_occlusion_map=estimate_occlusion_map,
20 | **dense_motion_params)
21 | else:
22 | self.dense_motion_network = None
23 |
24 | self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
25 |
26 | down_blocks = []
27 | for i in range(num_down_blocks):
28 | in_features = min(max_features, block_expansion * (2 ** i))
29 | out_features = min(max_features, block_expansion * (2 ** (i + 1)))
30 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
31 | self.down_blocks = nn.ModuleList(down_blocks)
32 |
33 | self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
34 |
35 | self.reshape_channel = reshape_channel
36 | self.reshape_depth = reshape_depth
37 |
38 | self.resblocks_3d = torch.nn.Sequential()
39 | for i in range(num_resblocks):
40 | self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
41 |
42 | out_features = block_expansion * (2 ** (num_down_blocks))
43 | self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
44 | self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
45 |
46 | self.resblocks_2d = torch.nn.Sequential()
47 | for i in range(num_resblocks):
48 | self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
49 |
50 | up_blocks = []
51 | for i in range(num_down_blocks):
52 | in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
53 | out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
54 | up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
55 | self.up_blocks = nn.ModuleList(up_blocks)
56 |
57 | self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
58 | self.estimate_occlusion_map = estimate_occlusion_map
59 | self.image_channel = image_channel
60 |
61 | def deform_input(self, inp, deformation):
62 | _, d_old, h_old, w_old, _ = deformation.shape
63 | _, _, d, h, w = inp.shape
64 | if d_old != d or h_old != h or w_old != w:
65 | deformation = deformation.permute(0, 4, 1, 2, 3)
66 | deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
67 | deformation = deformation.permute(0, 2, 3, 4, 1)
68 | return F.grid_sample(inp, deformation)
69 |
70 | def forward(self, source_image, kp_driving, kp_source):
71 | # Encoding (downsampling) part
72 | out = self.first(source_image)
73 | for i in range(len(self.down_blocks)):
74 | out = self.down_blocks[i](out)
75 | out = self.second(out)
76 | bs, c, h, w = out.shape
77 | # print(out.shape)
78 | feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w)
79 | feature_3d = self.resblocks_3d(feature_3d)
80 |
81 | # Transforming feature representation according to deformation and occlusion
82 | output_dict = {}
83 | if self.dense_motion_network is not None:
84 | dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
85 | kp_source=kp_source)
86 | output_dict['mask'] = dense_motion['mask']
87 |
88 | if 'occlusion_map' in dense_motion:
89 | occlusion_map = dense_motion['occlusion_map']
90 | output_dict['occlusion_map'] = occlusion_map
91 | else:
92 | occlusion_map = None
93 | deformation = dense_motion['deformation']
94 | out = self.deform_input(feature_3d, deformation)
95 |
96 | bs, c, d, h, w = out.shape
97 | out = out.view(bs, c*d, h, w)
98 | out = self.third(out)
99 | out = self.fourth(out)
100 |
101 | if occlusion_map is not None:
102 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
103 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
104 | out = out * occlusion_map
105 |
106 | # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image
107 |
108 | # Decoding part
109 | out = self.resblocks_2d(out)
110 | for i in range(len(self.up_blocks)):
111 | out = self.up_blocks[i](out)
112 | out = self.final(out)
113 |         out = torch.sigmoid(out)
114 |
115 | output_dict["prediction"] = out
116 |
117 | return output_dict
118 |
119 |
120 | class SPADEDecoder(nn.Module):
121 | def __init__(self):
122 | super().__init__()
123 | ic = 256
124 | oc = 64
125 | norm_G = 'spadespectralinstance'
126 | label_nc = 256
127 |
128 | self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
129 | self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
130 | self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
131 | self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
132 | self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
133 | self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
134 | self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
135 | self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
136 | self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
137 | self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
138 | self.up = nn.Upsample(scale_factor=2)
139 |
140 | def forward(self, feature):
141 | seg = feature
142 | x = self.fc(feature)
143 | x = self.G_middle_0(x, seg)
144 | x = self.G_middle_1(x, seg)
145 | x = self.G_middle_2(x, seg)
146 | x = self.G_middle_3(x, seg)
147 | x = self.G_middle_4(x, seg)
148 | x = self.G_middle_5(x, seg)
149 | x = self.up(x)
150 | x = self.up_0(x, seg) # 256, 128, 128
151 | x = self.up(x)
152 | x = self.up_1(x, seg) # 64, 256, 256
153 |
154 | x = self.conv_img(F.leaky_relu(x, 2e-1))
155 | # x = torch.tanh(x)
156 |         x = torch.sigmoid(x)
157 |
158 | return x
159 |
160 |
161 | class OcclusionAwareSPADEGenerator(nn.Module):
162 |
163 | def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
164 | num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
165 | super(OcclusionAwareSPADEGenerator, self).__init__()
166 |
167 | if dense_motion_params is not None:
168 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
169 | estimate_occlusion_map=estimate_occlusion_map,
170 | **dense_motion_params)
171 | else:
172 | self.dense_motion_network = None
173 |
174 | self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
175 |
176 | down_blocks = []
177 | for i in range(num_down_blocks):
178 | in_features = min(max_features, block_expansion * (2 ** i))
179 | out_features = min(max_features, block_expansion * (2 ** (i + 1)))
180 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
181 | self.down_blocks = nn.ModuleList(down_blocks)
182 |
183 | self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
184 |
185 | self.reshape_channel = reshape_channel
186 | self.reshape_depth = reshape_depth
187 |
188 | self.resblocks_3d = torch.nn.Sequential()
189 | for i in range(num_resblocks):
190 | self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
191 |
192 | out_features = block_expansion * (2 ** (num_down_blocks))
193 | self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
194 | self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
195 |
196 | self.estimate_occlusion_map = estimate_occlusion_map
197 | self.image_channel = image_channel
198 |
199 | self.decoder = SPADEDecoder()
200 |
201 | def deform_input(self, inp, deformation):
202 | _, d_old, h_old, w_old, _ = deformation.shape
203 | _, _, d, h, w = inp.shape
204 | if d_old != d or h_old != h or w_old != w:
205 | deformation = deformation.permute(0, 4, 1, 2, 3)
206 | deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
207 | deformation = deformation.permute(0, 2, 3, 4, 1)
208 | return F.grid_sample(inp, deformation)
209 |
210 | def forward(self, source_image, kp_driving, kp_source, fp16=False):
211 | if fp16:
212 | source_image = source_image.half()
213 | kp_driving['value'] = kp_driving['value'].half()
214 | kp_source['value'] = kp_source['value'].half()
215 | # Encoding (downsampling) part
216 | out = self.first(source_image)
217 | for i in range(len(self.down_blocks)):
218 | out = self.down_blocks[i](out)
219 | out = self.second(out)
220 | bs, c, h, w = out.shape
221 | # print(out.shape)
222 | feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w)
223 | feature_3d = self.resblocks_3d(feature_3d)
224 |
225 | # Transforming feature representation according to deformation and occlusion
226 | output_dict = {}
227 | if self.dense_motion_network is not None:
228 | dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
229 | kp_source=kp_source)
230 | output_dict['mask'] = dense_motion['mask']
231 |
232 | if 'occlusion_map' in dense_motion:
233 | occlusion_map = dense_motion['occlusion_map']
234 | output_dict['occlusion_map'] = occlusion_map
235 | else:
236 | occlusion_map = None
237 | deformation = dense_motion['deformation']
238 | out = self.deform_input(feature_3d, deformation)
239 |
240 | bs, c, d, h, w = out.shape
241 | out = out.view(bs, c*d, h, w)
242 | out = self.third(out)
243 | out = self.fourth(out)
244 |
245 | if occlusion_map is not None:
246 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
247 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
248 | out = out * occlusion_map
249 |
250 | # Decoding part
251 | out = self.decoder(out)
252 |
253 | output_dict["prediction"] = out
254 |
255 | return output_dict
--------------------------------------------------------------------------------
/face-vid2vid/modules/hopenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import math
5 | import torch.nn.functional as F
6 |
7 | class Hopenet(nn.Module):
8 | # Hopenet with 3 output layers for yaw, pitch and roll
9 | # Predicts Euler angles by binning and regression with the expected value
10 | def __init__(self, block, layers, num_bins):
11 | self.inplanes = 64
12 | super(Hopenet, self).__init__()
13 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
14 | bias=False)
15 | self.bn1 = nn.BatchNorm2d(64)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
18 | self.layer1 = self._make_layer(block, 64, layers[0])
19 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
20 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
21 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
22 | self.avgpool = nn.AvgPool2d(7)
23 | self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
24 | self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
25 | self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
26 |
27 | # Vestigial layer from previous experiments
28 | self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
29 |
30 | for m in self.modules():
31 | if isinstance(m, nn.Conv2d):
32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
33 | m.weight.data.normal_(0, math.sqrt(2. / n))
34 | elif isinstance(m, nn.BatchNorm2d):
35 | m.weight.data.fill_(1)
36 | m.bias.data.zero_()
37 |
38 | def _make_layer(self, block, planes, blocks, stride=1):
39 | downsample = None
40 | if stride != 1 or self.inplanes != planes * block.expansion:
41 | downsample = nn.Sequential(
42 | nn.Conv2d(self.inplanes, planes * block.expansion,
43 | kernel_size=1, stride=stride, bias=False),
44 | nn.BatchNorm2d(planes * block.expansion),
45 | )
46 |
47 | layers = []
48 | layers.append(block(self.inplanes, planes, stride, downsample))
49 | self.inplanes = planes * block.expansion
50 | for i in range(1, blocks):
51 | layers.append(block(self.inplanes, planes))
52 |
53 | return nn.Sequential(*layers)
54 |
55 | def forward(self, x):
56 | x = self.conv1(x)
57 | x = self.bn1(x)
58 | x = self.relu(x)
59 | x = self.maxpool(x)
60 |
61 | x = self.layer1(x)
62 | x = self.layer2(x)
63 | x = self.layer3(x)
64 | x = self.layer4(x)
65 |
66 | x = self.avgpool(x)
67 | x = x.view(x.size(0), -1)
68 | pre_yaw = self.fc_yaw(x)
69 | pre_pitch = self.fc_pitch(x)
70 | pre_roll = self.fc_roll(x)
71 |
72 | return pre_yaw, pre_pitch, pre_roll
73 |
74 | class ResNet(nn.Module):
75 | # ResNet for regression of 3 Euler angles.
76 | def __init__(self, block, layers, num_classes=1000):
77 | self.inplanes = 64
78 | super(ResNet, self).__init__()
79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
80 | bias=False)
81 | self.bn1 = nn.BatchNorm2d(64)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
84 | self.layer1 = self._make_layer(block, 64, layers[0])
85 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
86 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
87 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
88 | self.avgpool = nn.AvgPool2d(7)
89 | self.fc_angles = nn.Linear(512 * block.expansion, num_classes)
90 |
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv2d):
93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
94 | m.weight.data.normal_(0, math.sqrt(2. / n))
95 | elif isinstance(m, nn.BatchNorm2d):
96 | m.weight.data.fill_(1)
97 | m.bias.data.zero_()
98 |
99 | def _make_layer(self, block, planes, blocks, stride=1):
100 | downsample = None
101 | if stride != 1 or self.inplanes != planes * block.expansion:
102 | downsample = nn.Sequential(
103 | nn.Conv2d(self.inplanes, planes * block.expansion,
104 | kernel_size=1, stride=stride, bias=False),
105 | nn.BatchNorm2d(planes * block.expansion),
106 | )
107 |
108 | layers = []
109 | layers.append(block(self.inplanes, planes, stride, downsample))
110 | self.inplanes = planes * block.expansion
111 | for i in range(1, blocks):
112 | layers.append(block(self.inplanes, planes))
113 |
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x):
117 | x = self.conv1(x)
118 | x = self.bn1(x)
119 | x = self.relu(x)
120 | x = self.maxpool(x)
121 |
122 | x = self.layer1(x)
123 | x = self.layer2(x)
124 | x = self.layer3(x)
125 | x = self.layer4(x)
126 |
127 | x = self.avgpool(x)
128 | x = x.view(x.size(0), -1)
129 | x = self.fc_angles(x)
130 | return x
131 |
132 | class AlexNet(nn.Module):
133 | # AlexNet laid out as a Hopenet - classify Euler angles in bins and
134 | # regress the expected value.
135 | def __init__(self, num_bins):
136 | super(AlexNet, self).__init__()
137 | self.features = nn.Sequential(
138 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
139 | nn.ReLU(inplace=True),
140 | nn.MaxPool2d(kernel_size=3, stride=2),
141 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
142 | nn.ReLU(inplace=True),
143 | nn.MaxPool2d(kernel_size=3, stride=2),
144 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
145 | nn.ReLU(inplace=True),
146 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
147 | nn.ReLU(inplace=True),
148 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
149 | nn.ReLU(inplace=True),
150 | nn.MaxPool2d(kernel_size=3, stride=2),
151 | )
152 | self.classifier = nn.Sequential(
153 | nn.Dropout(),
154 | nn.Linear(256 * 6 * 6, 4096),
155 | nn.ReLU(inplace=True),
156 | nn.Dropout(),
157 | nn.Linear(4096, 4096),
158 | nn.ReLU(inplace=True),
159 | )
160 | self.fc_yaw = nn.Linear(4096, num_bins)
161 | self.fc_pitch = nn.Linear(4096, num_bins)
162 | self.fc_roll = nn.Linear(4096, num_bins)
163 |
164 | def forward(self, x):
165 | x = self.features(x)
166 | x = x.view(x.size(0), 256 * 6 * 6)
167 | x = self.classifier(x)
168 | yaw = self.fc_yaw(x)
169 | pitch = self.fc_pitch(x)
170 | roll = self.fc_roll(x)
171 | return yaw, pitch, roll
172 |
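173 | # Illustrative sketch: the networks above return per-bin logits, and a continuous
174 | # angle is usually recovered as the expected value over the bins. The 66-bin /
175 | # 3-degree layout below follows the original Hopenet setup and is an assumption
176 | # here, not something enforced by this file.
177 | if __name__ == "__main__":
178 |     num_bins = 66
179 |     logits = torch.randn(1, num_bins)  # e.g. pre_yaw from Hopenet.forward
180 |     idx = torch.arange(num_bins, dtype=torch.float32)
181 |     prob = F.softmax(logits, dim=1)
182 |     degrees = torch.sum(prob * idx, dim=1) * 3 - 99  # expectation over bin centers, in degrees
183 |     print(degrees)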
--------------------------------------------------------------------------------
/face-vid2vid/modules/keypoint_detector.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
6 | from modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
7 |
8 |
9 | class KPDetector(nn.Module):
10 | """
11 | Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint.
12 | """
13 |
14 | def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
15 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
16 | super(KPDetector, self).__init__()
17 |
18 | self.predictor = KPHourglass(block_expansion, in_features=image_channel,
19 | max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
20 |
21 | # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
22 | self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
23 |
24 | if estimate_jacobian:
25 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
26 | # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
27 | self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
28 | '''
29 | initial as:
30 | [[1 0 0]
31 | [0 1 0]
32 | [0 0 1]]
33 | '''
34 | self.jacobian.weight.data.zero_()
35 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
36 | else:
37 | self.jacobian = None
38 |
39 | self.temperature = temperature
40 | self.scale_factor = scale_factor
41 | if self.scale_factor != 1:
42 | self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
43 |
44 | def gaussian2kp(self, heatmap):
45 | """
46 | Extract the mean from a heatmap
47 | """
48 | shape = heatmap.shape
49 | heatmap = heatmap.unsqueeze(-1)
50 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
51 | value = (heatmap * grid).sum(dim=(2, 3, 4))
52 | kp = {'value': value}
53 |
54 | return kp
55 |
56 | def forward(self, x):
57 | if self.scale_factor != 1:
58 | x = self.down(x)
59 |
60 | feature_map = self.predictor(x)
61 | prediction = self.kp(feature_map)
62 |
63 | final_shape = prediction.shape
64 | heatmap = prediction.view(final_shape[0], final_shape[1], -1)
65 | heatmap = F.softmax(heatmap / self.temperature, dim=2)
66 | heatmap = heatmap.view(*final_shape)
67 |
68 | out = self.gaussian2kp(heatmap)
69 |
70 | if self.jacobian is not None:
71 | jacobian_map = self.jacobian(feature_map)
72 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
73 | final_shape[3], final_shape[4])
74 | heatmap = heatmap.unsqueeze(2)
75 |
76 | jacobian = heatmap * jacobian_map
77 | jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
78 | jacobian = jacobian.sum(dim=-1)
79 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
80 | out['jacobian'] = jacobian
81 |
82 | return out
83 |
84 |
85 | class HEEstimator(nn.Module):
86 | """
87 | Estimating head pose and expression.
88 | """
89 |
90 | def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
91 | super(HEEstimator, self).__init__()
92 |
93 | self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
94 | self.norm1 = BatchNorm2d(block_expansion, affine=True)
95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 |
97 | self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
98 | self.norm2 = BatchNorm2d(256, affine=True)
99 |
100 | self.block1 = nn.Sequential()
101 | for i in range(3):
102 | self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
103 |
104 | self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
105 | self.norm3 = BatchNorm2d(512, affine=True)
106 | self.block2 = ResBottleneck(in_features=512, stride=2)
107 |
108 | self.block3 = nn.Sequential()
109 | for i in range(3):
110 | self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
111 |
112 | self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
113 | self.norm4 = BatchNorm2d(1024, affine=True)
114 | self.block4 = ResBottleneck(in_features=1024, stride=2)
115 |
116 | self.block5 = nn.Sequential()
117 | for i in range(5):
118 | self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
119 |
120 | self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
121 | self.norm5 = BatchNorm2d(2048, affine=True)
122 | self.block6 = ResBottleneck(in_features=2048, stride=2)
123 |
124 | self.block7 = nn.Sequential()
125 | for i in range(2):
126 | self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
127 |
128 | self.fc_roll = nn.Linear(2048, num_bins)
129 | self.fc_pitch = nn.Linear(2048, num_bins)
130 | self.fc_yaw = nn.Linear(2048, num_bins)
131 |
132 | self.fc_t = nn.Linear(2048, 3)
133 |
134 | self.fc_exp = nn.Linear(2048, 3*num_kp)
135 |
136 | def forward(self, x):
137 | out = self.conv1(x)
138 | out = self.norm1(out)
139 | out = F.relu(out)
140 | out = self.maxpool(out)
141 |
142 | out = self.conv2(out)
143 | out = self.norm2(out)
144 | out = F.relu(out)
145 |
146 | out = self.block1(out)
147 |
148 | out = self.conv3(out)
149 | out = self.norm3(out)
150 | out = F.relu(out)
151 | out = self.block2(out)
152 |
153 | out = self.block3(out)
154 |
155 | out = self.conv4(out)
156 | out = self.norm4(out)
157 | out = F.relu(out)
158 | out = self.block4(out)
159 |
160 | out = self.block5(out)
161 |
162 | out = self.conv5(out)
163 | out = self.norm5(out)
164 | out = F.relu(out)
165 | out = self.block6(out)
166 |
167 | out = self.block7(out)
168 |
169 | out = F.adaptive_avg_pool2d(out, 1)
170 | out = out.view(out.shape[0], -1)
171 |
172 |         yaw = self.fc_roll(out)    # note: the yaw/roll layer names are cross-wired (fc_roll -> yaw, fc_yaw -> roll);
173 |         pitch = self.fc_pitch(out)
174 |         roll = self.fc_yaw(out)    # kept unchanged so pretrained weights keyed by these layer names still load
175 | t = self.fc_t(out)
176 | exp = self.fc_exp(out)
177 |
178 | return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
179 |
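180 | # Illustrative sketch: KPDetector.gaussian2kp is a soft-argmax, i.e. the
181 | # softmax-normalized heatmap is multiplied by a coordinate grid in [-1, 1] and
182 | # summed, giving the expected keypoint position. The 2D toy below mirrors the 3D
183 | # computation; the shapes and peak location are assumptions for illustration.
184 | if __name__ == "__main__":
185 |     h = w = 5
186 |     heatmap = torch.zeros(1, 1, h, w)
187 |     heatmap[0, 0, 1, 3] = 10.0  # sharp peak at row 1, col 3
188 |     heatmap = F.softmax(heatmap.view(1, 1, -1), dim=2).view(1, 1, h, w)
189 |     ys = torch.linspace(-1, 1, h).view(h, 1).expand(h, w)
190 |     xs = torch.linspace(-1, 1, w).view(1, w).expand(h, w)
191 |     grid = torch.stack([xs, ys], dim=-1)  # (h, w, 2) coordinate pairs in [-1, 1]
192 |     value = (heatmap.unsqueeze(-1) * grid).sum(dim=(2, 3))  # (1, 1, 2) expected position
193 |     print(value)  # close to (x=0.5, y=-0.5) for the peak above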
--------------------------------------------------------------------------------
/face-vid2vid/run.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 |
3 | matplotlib.use("Agg")
4 |
5 | import os, sys
6 | import yaml
7 | from argparse import ArgumentParser
8 | from time import gmtime, strftime
9 | from shutil import copy
10 |
11 | from frames_dataset import FramesDataset
12 |
13 | from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
14 | from modules.discriminator import MultiScaleDiscriminator
15 | from modules.keypoint_detector import KPDetector, HEEstimator
16 |
17 | import torch
18 |
19 | from train import train
20 |
21 | if __name__ == "__main__":
22 | if sys.version_info[0] < 3:
23 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
24 |
25 | parser = ArgumentParser()
26 | parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
27 | parser.add_argument(
28 | "--mode",
29 | default="train",
30 | choices=[
31 | "train",
32 | ],
33 | )
34 | parser.add_argument("--gen", default="original", choices=["original", "spade"])
35 | parser.add_argument("--log_dir", default="log", help="path to log into")
36 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
37 | parser.add_argument(
38 | "--device_ids",
39 | default="0, 1, 2, 3, 4, 5, 6, 7",
40 | type=lambda x: list(map(int, x.split(","))),
41 | help="Names of the devices comma separated.",
42 | )
43 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
44 | parser.set_defaults(verbose=False)
45 |
46 | opt = parser.parse_args()
47 | with open(opt.config) as f:
48 | config = yaml.load(f, Loader=yaml.FullLoader)
49 |
50 | if opt.checkpoint is not None:
51 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
52 | else:
53 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split(".")[0])
54 | log_dir += " " + strftime("%d_%m_%y_%H.%M.%S", gmtime())
55 |
56 | if opt.gen == "original":
57 | generator = OcclusionAwareGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"])
58 | elif opt.gen == "spade":
59 | generator = OcclusionAwareSPADEGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"])
60 |
61 | if torch.cuda.is_available():
62 | print("cuda is available")
63 | generator.to(opt.device_ids[0])
64 | if opt.verbose:
65 | print(generator)
66 |
67 | discriminator = MultiScaleDiscriminator(**config["model_params"]["discriminator_params"], **config["model_params"]["common_params"])
68 | if torch.cuda.is_available():
69 | discriminator.to(opt.device_ids[0])
70 | if opt.verbose:
71 | print(discriminator)
72 |
73 | kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"])
74 |
75 | if torch.cuda.is_available():
76 | kp_detector.to(opt.device_ids[0])
77 |
78 | if opt.verbose:
79 | print(kp_detector)
80 |
81 | he_estimator = HEEstimator(**config["model_params"]["he_estimator_params"], **config["model_params"]["common_params"])
82 |
83 | if torch.cuda.is_available():
84 | he_estimator.to(opt.device_ids[0])
85 |
86 | dataset = FramesDataset(is_train=(opt.mode == "train"), **config["dataset_params"])
87 |
88 | if not os.path.exists(log_dir):
89 | os.makedirs(log_dir)
90 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
91 | copy(opt.config, log_dir)
92 |
93 | if opt.mode == "train":
94 | print("Training...")
95 | train(config, generator, discriminator, kp_detector, he_estimator, opt.checkpoint, log_dir, dataset, opt.device_ids)
96 |
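97 | # Example invocation (paths and device ids below are placeholders; adjust to your setup):
98 | #   python run.py --config config/vox-256-spade.yaml --gen spade --device_ids 0,1 --log_dir log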
--------------------------------------------------------------------------------
/face-vid2vid/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/face-vid2vid/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as a one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, data parallel triggers a callback on each module, so every slave device should
60 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
62 |     collected and passed to the registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be
64 |     passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually is the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         Messages are first collected from each device (including the master device), and then
106 |         the callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 |
109 | Args:
110 | master_msg: the message that the master want to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
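139 | 
140 | # Illustrative sketch exercising the register/run protocol described above with
141 | # plain threads instead of GPU replicas; the summing callback and the two fake
142 | # "slaves" are assumptions for demonstration only.
143 | if __name__ == "__main__":
144 |     def _sum_callback(intermediates):
145 |         # intermediates: [(identifier, message), ...]; reply with the total to every device
146 |         total = sum(msg for _, msg in intermediates)
147 |         return [(identifier, total) for identifier, _ in intermediates]
148 | 
149 |     master = SyncMaster(_sum_callback)
150 |     pipes = [master.register_slave(i) for i in (1, 2)]
151 | 
152 |     replies = {}
153 | 
154 |     def _slave(pipe, msg):
155 |         replies[pipe.identifier] = pipe.run_slave(msg)
156 | 
157 |     workers = [threading.Thread(target=_slave, args=(p, p.identifier)) for p in pipes]
158 |     for t in workers:
159 |         t.start()
160 |     master_reply = master.run_master(10)  # collects 10, 1, 2 -> everyone receives 13
161 |     for t in workers:
162 |         t.join()
163 |     print(master_reply, replies)  # 13 {1: 13, 2: 13}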
--------------------------------------------------------------------------------
/face-vid2vid/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` on each module will be invoked after the module is
55 |     created by the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/face-vid2vid/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 |             np.allclose(npa, npb, atol=atol, rtol=rtol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/face-vid2vid/train.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import torch
3 |
4 | from torch.utils.data import DataLoader
5 |
6 | from logger import Logger
7 | from modules.model import GeneratorFullModel, DiscriminatorFullModel
8 |
9 | from torch.optim.lr_scheduler import MultiStepLR
10 |
11 | from sync_batchnorm import DataParallelWithCallback
12 |
13 | from frames_dataset import DatasetRepeater
14 |
15 |
16 | def train(config, generator, discriminator, kp_detector, he_estimator, checkpoint, log_dir, dataset, device_ids):
17 | train_params = config["train_params"]
18 |
19 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params["lr_generator"], betas=(0.5, 0.999))
20 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params["lr_discriminator"], betas=(0.5, 0.999))
21 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params["lr_kp_detector"], betas=(0.5, 0.999))
22 | optimizer_he_estimator = torch.optim.Adam(he_estimator.parameters(), lr=train_params["lr_he_estimator"], betas=(0.5, 0.999))
23 |
24 | if checkpoint is not None:
25 | start_epoch = Logger.load_cpk(
26 | checkpoint,
27 | generator,
28 | discriminator,
29 | kp_detector,
30 | he_estimator,
31 | optimizer_generator,
32 | optimizer_discriminator,
33 | optimizer_kp_detector,
34 | optimizer_he_estimator,
35 | )
36 | else:
37 | start_epoch = 0
38 |
39 | scheduler_generator = MultiStepLR(optimizer_generator, train_params["epoch_milestones"], gamma=0.1, last_epoch=start_epoch - 1)
40 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params["epoch_milestones"], gamma=0.1, last_epoch=start_epoch - 1)
41 | scheduler_kp_detector = MultiStepLR(
42 | optimizer_kp_detector, train_params["epoch_milestones"], gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_kp_detector"] != 0)
43 | )
44 | scheduler_he_estimator = MultiStepLR(
45 | optimizer_he_estimator, train_params["epoch_milestones"], gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_kp_detector"] != 0)
46 | )
47 |
48 |     if "num_repeats" in train_params and train_params["num_repeats"] != 1:
49 | dataset = DatasetRepeater(dataset, train_params["num_repeats"])
50 | dataloader = DataLoader(dataset, batch_size=train_params["batch_size"], shuffle=True, num_workers=16, drop_last=True)
51 |
52 | generator_full = GeneratorFullModel(
53 | kp_detector,
54 | he_estimator,
55 | generator,
56 | discriminator,
57 | train_params,
58 | estimate_jacobian=config["model_params"]["common_params"]["estimate_jacobian"],
59 | )
60 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
61 |
62 | if torch.cuda.is_available():
63 | generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
64 | discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
65 |
66 | with Logger(log_dir=log_dir, visualizer_params=config["visualizer_params"], checkpoint_freq=train_params["checkpoint_freq"]) as logger:
67 | for epoch in trange(start_epoch, train_params["num_epochs"]):
68 | for x in dataloader:
69 | losses_generator, generated = generator_full(x)
70 |
71 | loss_values = [val.mean() for val in losses_generator.values()]
72 | loss = sum(loss_values)
73 |
74 | loss.backward()
75 | optimizer_generator.step()
76 | optimizer_generator.zero_grad()
77 | optimizer_kp_detector.step()
78 | optimizer_kp_detector.zero_grad()
79 | optimizer_he_estimator.step()
80 | optimizer_he_estimator.zero_grad()
81 |
82 | if train_params["loss_weights"]["generator_gan"] != 0:
83 | optimizer_discriminator.zero_grad()
84 | losses_discriminator = discriminator_full(x, generated)
85 | loss_values = [val.mean() for val in losses_discriminator.values()]
86 | loss = sum(loss_values)
87 |
88 | loss.backward()
89 | optimizer_discriminator.step()
90 | optimizer_discriminator.zero_grad()
91 | else:
92 | losses_discriminator = {}
93 |
94 | losses_generator.update(losses_discriminator)
95 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
96 | logger.log_iter(losses=losses)
97 |
98 | scheduler_generator.step()
99 | scheduler_discriminator.step()
100 | scheduler_kp_detector.step()
101 | scheduler_he_estimator.step()
102 |
103 | logger.log_epoch(
104 | epoch,
105 | {
106 | "generator": generator,
107 | "discriminator": discriminator,
108 | "kp_detector": kp_detector,
109 | "he_estimator": he_estimator,
110 | "optimizer_generator": optimizer_generator,
111 | "optimizer_discriminator": optimizer_discriminator,
112 | "optimizer_kp_detector": optimizer_kp_detector,
113 | "optimizer_he_estimator": optimizer_he_estimator,
114 | },
115 | inp=x,
116 | out=generated,
117 | )
118 |
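119 | # For orientation, the train_params keys read by train() above are:
120 | #   lr_generator, lr_discriminator, lr_kp_detector, lr_he_estimator,
121 | #   epoch_milestones, num_repeats, batch_size, num_epochs, checkpoint_freq,
122 | #   and loss_weights["generator_gan"]; their values come from the YAML config
123 | #   handed in by run.py.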
--------------------------------------------------------------------------------
/remote_server.py:
--------------------------------------------------------------------------------
1 | import io
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | from argparse import ArgumentParser
6 |
7 |
8 | from fastapi import FastAPI, WebSocket
9 | from fastapi.websockets import WebSocketDisconnect
10 | from demo_utils import FaceAnimationClass
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument("--source_image", default="./assets/source.jpg", help="path to source image")
14 | parser.add_argument("--restore_face", default="False", type=str, help="restore face ('True' or 'False')")
15 | args = parser.parse_args()
16 | restore_face = True if args.restore_face == 'True' else False if args.restore_face == 'False' else exit('restore_face must be True or False')
17 |
18 |
19 | faceanimation = FaceAnimationClass(source_image_path=args.source_image, use_sr=restore_face)
20 | # remote server fps is lower than local camera fps, so we need to increase the frequency of face detection and increase the smooth factor
21 | faceanimation.detect_interval = 2
22 | faceanimation.smooth_factor = 0.8
23 |
24 |
25 | app = FastAPI()
26 | websocket_port = 8066
27 |
28 |
29 | # WebSocket endpoint to receive and process images
30 | @app.websocket("/ws")
31 | async def websocket_endpoint(websocket: WebSocket):
32 | await websocket.accept()
33 | try:
34 | while True:
35 | # Receive the image as a binary stream
36 | image_data = await websocket.receive_bytes()
37 | processed_image = process_image(image_data)
38 | # Send the processed image back to the client
39 | await websocket.send_bytes(processed_image)
40 | except WebSocketDisconnect:
41 | pass
42 |
43 |
44 | def process_image(image_data):
45 | image = Image.open(io.BytesIO(image_data))
46 | image_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
47 | face, result = faceanimation.inference(image_cv2)
48 | # resize to 256x256
49 | if face.shape[1] != 256 or face.shape[0] != 256:
50 | face = cv2.resize(face, (256, 256))
51 | if result.shape[0] != 256 or result.shape[1] != 256:
52 | result = cv2.resize(result, (256, 256))
53 | result = cv2.hconcat([face, result])
54 | _, processed_image_data = cv2.imencode(".jpg", result, [cv2.IMWRITE_JPEG_QUALITY, 95])
55 | return processed_image_data.tobytes()
56 |
57 |
58 | if __name__ == "__main__":
59 | import uvicorn
60 |
61 | uvicorn.run(app, host="0.0.0.0", port=websocket_port)
62 |
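63 | # Minimal client sketch for this endpoint (illustrative only, not the project's
64 | # own client script); it assumes the `websockets` package and placeholder
65 | # file names.
66 | #
67 | #   import asyncio
68 | #   import websockets
69 | #
70 | #   async def main():
71 | #       async with websockets.connect("ws://localhost:8066/ws") as ws:
72 | #           with open("input.jpg", "rb") as f:
73 | #               await ws.send(f.read())
74 | #           processed = await ws.recv()  # JPEG bytes: cropped face | animated result
75 | #           with open("output.jpg", "wb") as f:
76 | #               f.write(processed)
77 | #
78 | #   asyncio.run(main())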
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.11.0
2 | numpy>=1.23.5
3 | PyYAML
4 | imageio[ffmpeg]
5 | batch-face
6 | gdown
7 | scipy
--------------------------------------------------------------------------------