├── .gitignore ├── GPEN ├── README.md ├── __init_paths.py ├── align_faces.py ├── face_enhancement.py ├── face_model │ ├── face_gan.py │ ├── model.py │ └── op │ │ ├── __init__.py │ │ ├── fused_act.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu ├── requirements.txt ├── retinaface │ ├── .DS_Store │ ├── data │ │ ├── FDDB │ │ │ └── img_list.txt │ │ ├── __init__.py │ │ ├── config.py │ │ ├── data_augment.py │ │ └── wider_face.py │ ├── facemodels │ │ ├── __init__.py │ │ ├── net.py │ │ └── retinaface.py │ ├── layers │ │ ├── __init__.py │ │ ├── functions │ │ │ └── prior_box.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ └── multibox_loss.py │ ├── retinaface_detection.py │ └── utils │ │ ├── __init__.py │ │ ├── box_utils.py │ │ ├── nms │ │ ├── __init__.py │ │ └── py_cpu_nms.py │ │ └── timer.py └── sr_model │ ├── arch_util.py │ ├── real_esrnet.py │ └── rrdbnet_arch.py ├── LICENSE.md ├── README.md ├── assets ├── driving.mp4 └── source.jpg ├── camera_client.py ├── camera_local.py ├── demo_utils.py ├── face-vid2vid ├── LICENSE.md ├── README.md ├── animate.py ├── augmentation.py ├── config │ └── vox-256-spade.yaml ├── crop-video.py ├── demo.py ├── frames_dataset.py ├── logger.py ├── modules │ ├── dense_motion.py │ ├── discriminator.py │ ├── generator.py │ ├── hopenet.py │ ├── keypoint_detector.py │ ├── model.py │ └── util.py ├── run.py ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py └── train.py ├── remote_server.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | 30 | *.pth 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
-------------------------------------------------------------------------------- /GPEN/README.md: --------------------------------------------------------------------------------
1 | # GAN Prior Embedded Network for Blind Face Restoration in the Wild
2 |
3 | [Paper](https://arxiv.org/abs/2105.06070) | [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/GPEN-cvpr21-supp.pdf) | [Demo](https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace)
4 |
5 | [Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang)¹, Peiran Ren¹, Xuansong Xie¹, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)¹,²
6 | _¹[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Hangzhou, China_
7 | _²[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk), Hong Kong, China_
8 |
9 | #### Face Restoration
10 | 11 | 12 | 13 | 14 | 15 |
16 | #### Face Colorization
17 | 18 | 19 |
20 | #### Face Inpainting
21 | 22 | 23 |
24 | #### Conditional Image Synthesis (Seg2Face)
25 | 26 | 27 |
28 | ## News
29 | (2021-07-06) The training code will be released soon. Stay tuned.
30 |
31 | (2021-10-11) The Colab demo for GPEN is now available.
32 |
33 | (2021-10-22) GPEN can now work with SR methods. An SR model trained by us is provided; replace it with your own model if necessary.
34 | 35 | ## Usage 36 | 37 | ![python](https://img.shields.io/badge/python-v3.7.4-green.svg?style=plastic) 38 | ![pytorch](https://img.shields.io/badge/pytorch-v1.7.0-green.svg?style=plastic) 39 | ![cuda](https://img.shields.io/badge/cuda-v10.2.89-green.svg?style=plastic) 40 | ![driver](https://img.shields.io/badge/driver-v460.73.01-green.svg?style=plastic) 41 | ![gcc](https://img.shields.io/badge/gcc-v7.5.0-green.svg?style=plastic) 42 | 43 | - Clone this repository: 44 | ```bash 45 | git clone https://github.com/yangxy/GPEN.git 46 | cd GPEN 47 | ``` 48 | - Download RetinaFace model and our pre-trained model (not our best model due to commercial issues) and put them into ``weights/``. 49 | 50 | [RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth) | [rrdb_realesrnet_psnr](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/rrdb_realesrnet_psnr.pth) 51 | 52 | - Restore face images: 53 | ```bash 54 | python face_enhancement.py --model GPEN-BFR-512 --size 512 --channel_multiplier 2 --narrow 1 --use_sr --indir examples/imgs --outdir examples/outs-BFR 55 | ``` 56 | 57 | - Colorize faces: 58 | ```bash 59 | python face_colorization.py 60 | ``` 61 | 62 | - Complete faces: 63 | ```bash 64 | python face_inpainting.py 65 | ``` 66 | 67 | - Synthesize faces: 68 | ```bash 69 | python segmentation2face.py 70 | ``` 71 | 72 | ## Main idea 73 | 74 | 75 | ## Citation 76 | If our work is useful for your research, please consider citing: 77 | 78 | @inproceedings{Yang2021GPEN, 79 | title={GAN Prior Embedded Network for Blind Face Restoration in the Wild}, 80 | author={Tao Yang, Peiran Ren, Xuansong Xie, and Lei Zhang}, 81 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 82 | year={2021} 83 | } 84 | 85 | ## License 86 | © Alibaba, 2021. For academic and non-commercial use only. 87 | 88 | ## Acknowledgments 89 | We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN). 90 | 91 | ## Contact 92 | If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com. 
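
In addition to the CLI above, the vendored `face_enhancement.py` exposes a `FaceEnhancement` class that can be driven from Python. The sketch below is not part of the original instructions: it assumes you run it from the root of this repository (so that `GPEN` is importable as a package), that a CUDA GPU is available, and that the weights are already in `GPEN/weights/` (otherwise the class tries to download them). The model names mirror the script defaults; the image paths are placeholders.

```python
import cv2
from GPEN.face_enhancement import FaceEnhancement

# Mirrors the defaults of face_enhancement.py's CLI; adjust to your setup.
enhancer = FaceEnhancement(
    size=512,
    model="GPEN-BFR-512",
    use_sr=True,                # also super-resolve the background with Real-ESRNet
    sr_model="realesrnet_x2",
    channel_multiplier=2,
    narrow=1,
)

img = cv2.imread("examples/imgs/input.jpg", cv2.IMREAD_COLOR)  # BGR, as process() expects
restored, orig_faces, enhanced_faces = enhancer.process(img)
cv2.imwrite("examples/outs-BFR/input_GPEN.jpg", restored)
```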
93 | -------------------------------------------------------------------------------- /GPEN/__init_paths.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import os.path as osp 6 | import sys 7 | 8 | def add_path(path): 9 | if path not in sys.path: 10 | sys.path.insert(0, path) 11 | 12 | this_dir = osp.dirname(__file__) 13 | 14 | path = osp.join(this_dir, 'retinaface') 15 | add_path(path) 16 | 17 | path = osp.join(this_dir, 'face_model') 18 | add_path(path) 19 | 20 | path = osp.join(this_dir, 'sr_model') 21 | add_path(path) -------------------------------------------------------------------------------- /GPEN/align_faces.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 24 15:43:29 2017 4 | @author: zhaoy 5 | """ 6 | """ 7 | @Modified by yangxy (yangtao9009@gmail.com) 8 | """ 9 | import cv2 10 | import numpy as np 11 | from skimage import transform as trans 12 | 13 | # reference facial points, a list of coordinates (x,y) 14 | REFERENCE_FACIAL_POINTS = [ 15 | [30.29459953, 51.69630051], 16 | [65.53179932, 51.50139999], 17 | [48.02519989, 71.73660278], 18 | [33.54930115, 92.3655014], 19 | [62.72990036, 92.20410156], 20 | ] 21 | 22 | DEFAULT_CROP_SIZE = (96, 112) 23 | 24 | 25 | def _umeyama(src, dst, estimate_scale=True, scale=1.0): 26 | """Estimate N-D similarity transformation with or without scaling. 27 | Parameters 28 | ---------- 29 | src : (M, N) array 30 | Source coordinates. 31 | dst : (M, N) array 32 | Destination coordinates. 33 | estimate_scale : bool 34 | Whether to estimate scaling factor. 35 | Returns 36 | ------- 37 | T : (N + 1, N + 1) 38 | The homogeneous similarity transformation matrix. The matrix contains 39 | NaN values only if the problem == not well-conditioned. 40 | References 41 | ---------- 42 | .. [1] "Least-squares estimation of transformation parameters between two 43 | point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` 44 | """ 45 | 46 | num = src.shape[0] 47 | dim = src.shape[1] 48 | 49 | # Compute mean of src and dst. 50 | src_mean = src.mean(axis=0) 51 | dst_mean = dst.mean(axis=0) 52 | 53 | # Subtract mean from src and dst. 54 | src_demean = src - src_mean 55 | dst_demean = dst - dst_mean 56 | 57 | # Eq. (38). 58 | A = dst_demean.T @ src_demean / num 59 | 60 | # Eq. (39). 61 | d = np.ones((dim,), dtype=np.double) 62 | if np.linalg.det(A) < 0: 63 | d[dim - 1] = -1 64 | 65 | T = np.eye(dim + 1, dtype=np.double) 66 | 67 | U, S, V = np.linalg.svd(A) 68 | 69 | # Eq. (40) and (43). 70 | rank = np.linalg.matrix_rank(A) 71 | if rank == 0: 72 | return np.nan * T 73 | elif rank == dim - 1: 74 | if np.linalg.det(U) * np.linalg.det(V) > 0: 75 | T[:dim, :dim] = U @ V 76 | else: 77 | s = d[dim - 1] 78 | d[dim - 1] = -1 79 | T[:dim, :dim] = U @ np.diag(d) @ V 80 | d[dim - 1] = s 81 | else: 82 | T[:dim, :dim] = U @ np.diag(d) @ V 83 | 84 | if estimate_scale: 85 | # Eq. (41) and (42). 
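# The optimal scale (Umeyama, Eq. 41-42) is s = trace(diag(d) @ diag(S)) / sigma_src^2,
# i.e. (S @ d) divided by the mean squared deviation of the source points from their
# centroid, which is what src_demean.var(axis=0).sum() computes below.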
86 | scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) 87 | else: 88 | scale = scale 89 | 90 | T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) 91 | T[:dim, :dim] *= scale 92 | 93 | return T, scale 94 | 95 | 96 | class FaceWarpException(Exception): 97 | def __str__(self): 98 | return "In File {}:{}".format(__file__, super.__str__(self)) 99 | 100 | 101 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): 102 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 103 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 104 | 105 | # 0) make the inner region a square 106 | if default_square: 107 | size_diff = max(tmp_crop_size) - tmp_crop_size 108 | tmp_5pts += size_diff / 2 109 | tmp_crop_size += size_diff 110 | 111 | if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]: 112 | print("output_size == DEFAULT_CROP_SIZE {}: return default reference points".format(tmp_crop_size)) 113 | return tmp_5pts 114 | 115 | if inner_padding_factor == 0 and outer_padding == (0, 0): 116 | if output_size is None: 117 | print("No paddings to do: return default reference points") 118 | return tmp_5pts 119 | else: 120 | raise FaceWarpException("No paddings to do, output_size must be None or {}".format(tmp_crop_size)) 121 | 122 | # check output size 123 | if not (0 <= inner_padding_factor <= 1.0): 124 | raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)") 125 | 126 | if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None: 127 | output_size = tmp_crop_size * (1 + inner_padding_factor * 2).astype(np.int32) 128 | output_size += np.array(outer_padding) 129 | print(" deduced from paddings, output_size = ", output_size) 130 | 131 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): 132 | raise FaceWarpException("Not (outer_padding[0] < output_size[0]" "and outer_padding[1] < output_size[1])") 133 | 134 | # 1) pad the inner region according inner_padding_factor 135 | # print('---> STEP1: pad the inner region according inner_padding_factor') 136 | if inner_padding_factor > 0: 137 | size_diff = tmp_crop_size * inner_padding_factor * 2 138 | tmp_5pts += size_diff / 2 139 | tmp_crop_size += np.round(size_diff).astype(np.int32) 140 | 141 | # print(' crop_size = ', tmp_crop_size) 142 | # print(' reference_5pts = ', tmp_5pts) 143 | 144 | # 2) resize the padded inner region 145 | # print('---> STEP2: resize the padded inner region') 146 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 147 | # print(' crop_size = ', tmp_crop_size) 148 | # print(' size_bf_outer_pad = ', size_bf_outer_pad) 149 | 150 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 151 | raise FaceWarpException("Must have (output_size - outer_padding)" "= some_scale * (crop_size * (1.0 + inner_padding_factor)") 152 | 153 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 154 | # print(' resize scale_factor = ', scale_factor) 155 | tmp_5pts = tmp_5pts * scale_factor 156 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 157 | # tmp_5pts = tmp_5pts + size_diff / 2 158 | tmp_crop_size = size_bf_outer_pad 159 | # print(' crop_size = ', tmp_crop_size) 160 | # print(' reference_5pts = ', tmp_5pts) 161 | 162 | # 3) add outer_padding to make output_size 163 | reference_5point = tmp_5pts + np.array(outer_padding) 164 | tmp_crop_size = output_size 165 | # print('---> 
STEP3: add outer_padding to make output_size') 166 | # print(' crop_size = ', tmp_crop_size) 167 | # print(' reference_5pts = ', tmp_5pts) 168 | # 169 | # print('===> end get_reference_facial_points\n') 170 | 171 | return reference_5point 172 | 173 | 174 | def get_affine_transform_matrix(src_pts, dst_pts): 175 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 176 | n_pts = src_pts.shape[0] 177 | ones = np.ones((n_pts, 1), src_pts.dtype) 178 | src_pts_ = np.hstack([src_pts, ones]) 179 | dst_pts_ = np.hstack([dst_pts, ones]) 180 | 181 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 182 | 183 | if rank == 3: 184 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) 185 | elif rank == 2: 186 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) 187 | 188 | return tfm 189 | 190 | 191 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="smilarity"): # smilarity cv2_affine affine 192 | if reference_pts is None: 193 | if crop_size[0] == 96 and crop_size[1] == 112: 194 | reference_pts = REFERENCE_FACIAL_POINTS 195 | else: 196 | default_square = False 197 | inner_padding_factor = 0 198 | outer_padding = (0, 0) 199 | output_size = crop_size 200 | 201 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, default_square) 202 | ref_pts = np.float32(reference_pts) 203 | ref_pts_shp = ref_pts.shape 204 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 205 | raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2") 206 | 207 | if ref_pts_shp[0] == 2: 208 | ref_pts = ref_pts.T 209 | 210 | src_pts = np.float32(facial_pts) 211 | src_pts_shp = src_pts.shape 212 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 213 | raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2") 214 | 215 | if src_pts_shp[0] == 2: 216 | src_pts = src_pts.T 217 | 218 | if src_pts.shape != ref_pts.shape: 219 | raise FaceWarpException("facial_pts and reference_pts must have the same shape") 220 | 221 | if align_type == "cv2_affine": 222 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 223 | tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) 224 | elif align_type == "affine": 225 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 226 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) 227 | else: 228 | params, scale = _umeyama(src_pts, ref_pts) 229 | tfm = params[:2, :] 230 | 231 | params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) 232 | tfm_inv = params[:2, :] 233 | 234 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3) 235 | 236 | return face_img, tfm_inv 237 | -------------------------------------------------------------------------------- /GPEN/face_enhancement.py: -------------------------------------------------------------------------------- 1 | """ 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | """ 5 | import os 6 | import cv2 7 | import glob 8 | import time 9 | import argparse 10 | import numpy as np 11 | from PIL import Image 12 | from skimage import transform as tf 13 | 14 | import GPEN.__init_paths as init_paths 15 | from GPEN.retinaface.retinaface_detection import RetinaFaceDetection 16 | from GPEN.face_model.face_gan import FaceGAN 17 | from GPEN.sr_model.real_esrnet import RealESRNet 18 | from GPEN.align_faces import warp_and_crop_face, get_reference_facial_points 19 | 20 | def check_ckpts(model, 
sr_model): 21 | # check if checkpoints are downloaded 22 | try: 23 | ckpts_folder = os.path.join(os.path.dirname(__file__), "weights") 24 | if not os.path.exists(ckpts_folder): 25 | print("Downloading checkpoints...") 26 | from gdown import download_folder 27 | file_id = "1epln5c8HW1QXfVz6444Fe0hG-vRNavi6" 28 | download_folder(id=file_id, output=ckpts_folder, quiet=False, use_cookies=False) 29 | else: 30 | print("Checkpoints already downloaded, skipping...") 31 | except Exception as e: 32 | print(e) 33 | raise Exception("Error while downloading checkpoints") 34 | 35 | 36 | class FaceEnhancement(object): 37 | def __init__(self, base_dir=os.path.dirname(__file__), size=512, model=None, use_sr=True, sr_model=None, channel_multiplier=2, narrow=1, use_facegan=True): 38 | check_ckpts(model, sr_model) 39 | 40 | self.facedetector = RetinaFaceDetection(base_dir) 41 | self.facegan = FaceGAN(base_dir, size, model, channel_multiplier, narrow) 42 | self.srmodel = RealESRNet(base_dir, sr_model) 43 | self.use_sr = use_sr 44 | self.size = size 45 | self.threshold = 0.9 46 | self.use_facegan = use_facegan 47 | 48 | # the mask for pasting restored faces back 49 | self.mask = np.zeros((512, 512), np.float32) 50 | cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA) 51 | self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11) 52 | self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11) 53 | 54 | self.kernel = np.array(([0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]), dtype="float32") 55 | 56 | # get the reference 5 landmarks position in the crop settings 57 | default_square = True 58 | inner_padding_factor = 0.25 59 | outer_padding = (0, 0) 60 | self.reference_5pts = get_reference_facial_points((self.size, self.size), inner_padding_factor, outer_padding, default_square) 61 | 62 | def process(self, img): 63 | if self.use_sr: 64 | img_sr = self.srmodel.process(img) 65 | if img_sr is not None: 66 | img = cv2.resize(img, img_sr.shape[:2][::-1]) 67 | 68 | facebs, landms = self.facedetector.detect(img) 69 | 70 | orig_faces, enhanced_faces = [], [] 71 | height, width = img.shape[:2] 72 | full_mask = np.zeros((height, width), dtype=np.float32) 73 | full_img = np.zeros(img.shape, dtype=np.uint8) 74 | 75 | for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): 76 | if faceb[4] < self.threshold: 77 | continue 78 | fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0]) 79 | 80 | facial5points = np.reshape(facial5points, (2, 5)) 81 | 82 | of, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts, crop_size=(self.size, self.size)) 83 | 84 | # enhance the face 85 | ef = self.facegan.process(of) if self.use_facegan else of 86 | 87 | orig_faces.append(of) 88 | enhanced_faces.append(ef) 89 | 90 | tmp_mask = self.mask 91 | tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) 92 | tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3) 93 | 94 | if min(fh, fw) < 100: # gaussian filter for small faces 95 | ef = cv2.filter2D(ef, -1, self.kernel) 96 | 97 | tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) 98 | 99 | mask = tmp_mask - full_mask 100 | full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] 101 | full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] 102 | 103 | full_mask = full_mask[:, :, np.newaxis] 104 | if self.use_sr and img_sr is not None: 105 | img = cv2.convertScaleAbs(img_sr * (1 - full_mask) + full_img * full_mask) 106 | else: 107 | img = cv2.convertScaleAbs(img * (1 - full_mask) + full_img * 
full_mask) 108 | 109 | return img, orig_faces, enhanced_faces 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument("--model", type=str, default="GPEN-BFR-512", help="GPEN model") 115 | parser.add_argument("--size", type=int, default=512, help="resolution of GPEN") 116 | parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier of GPEN") 117 | parser.add_argument("--narrow", type=float, default=1, help="channel narrow scale") 118 | parser.add_argument("--use_sr", action="store_true", help="use sr or not") 119 | parser.add_argument("--sr_model", type=str, default="realesrnet_x2", help="SR model") 120 | parser.add_argument("--sr_scale", type=int, default=2, help="SR scale") 121 | parser.add_argument("--indir", type=str, default="examples/imgs", help="input folder") 122 | parser.add_argument("--outdir", type=str, default="results/outs-BFR", help="output folder") 123 | args = parser.parse_args() 124 | 125 | # model = {'name':'GPEN-BFR-512', 'size':512, 'channel_multiplier':2, 'narrow':1} 126 | # model = {'name':'GPEN-BFR-256', 'size':256, 'channel_multiplier':1, 'narrow':0.5} 127 | 128 | os.makedirs(args.outdir, exist_ok=True) 129 | 130 | faceenhancer = FaceEnhancement( 131 | size=args.size, 132 | model=args.model, 133 | use_sr=args.use_sr, 134 | sr_model=args.sr_model, 135 | channel_multiplier=args.channel_multiplier, 136 | narrow=args.narrow, 137 | ) 138 | 139 | files = sorted(glob.glob(os.path.join(args.indir, "*.*g"))) 140 | for n, file in enumerate(files[:]): 141 | filename = os.path.basename(file) 142 | 143 | im = cv2.imread(file, cv2.IMREAD_COLOR) # BGR 144 | if not isinstance(im, np.ndarray): 145 | print(filename, "error") 146 | continue 147 | # im = cv2.resize(im, (0,0), fx=2, fy=2) # optional 148 | 149 | img, orig_faces, enhanced_faces = faceenhancer.process(im) 150 | 151 | im = cv2.resize(im, img.shape[:2][::-1]) 152 | cv2.imwrite(os.path.join(args.outdir, ".".join(filename.split(".")[:-1]) + "_COMP.jpg"), np.hstack((im, img))) 153 | cv2.imwrite(os.path.join(args.outdir, ".".join(filename.split(".")[:-1]) + "_GPEN.jpg"), img) 154 | 155 | for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)): 156 | of = cv2.resize(of, ef.shape[:2]) 157 | cv2.imwrite(os.path.join(args.outdir, ".".join(filename.split(".")[:-1]) + "_face%02d" % m + ".jpg"), np.hstack((of, ef))) 158 | 159 | if n % 10 == 0: 160 | print(n, filename) 161 | -------------------------------------------------------------------------------- /GPEN/face_model/face_gan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import torch 6 | import os 7 | import cv2 8 | import glob 9 | import numpy as np 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms, utils 13 | from model import FullGenerator 14 | import torch 15 | 16 | class FaceGAN(object): 17 | def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1, is_norm=True): 18 | self.mfile = os.path.join(base_dir, 'weights', model+'.pth') 19 | self.n_mlp = 8 20 | self.is_norm = is_norm 21 | self.resolution = size 22 | self.load_model(channel_multiplier, narrow) 23 | 24 | def load_model(self, channel_multiplier=2, narrow=1): 25 | self.model = FullGenerator(self.resolution, 512, self.n_mlp, channel_multiplier, 
narrow=narrow).cuda() 26 | pretrained_dict = torch.load(self.mfile) 27 | self.model.load_state_dict(pretrained_dict) 28 | self.model.eval() 29 | 30 | def process(self, img): 31 | img = cv2.resize(img, (self.resolution, self.resolution)) 32 | img_t = self.img2tensor(img) 33 | 34 | with torch.no_grad(): 35 | out, __ = self.model(img_t) 36 | 37 | out = self.tensor2img(out) 38 | 39 | return out 40 | 41 | def img2tensor(self, img): 42 | img_t = torch.from_numpy(img).cuda()/255. 43 | if self.is_norm: 44 | img_t = (img_t - 0.5) / 0.5 45 | img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB 46 | return img_t 47 | 48 | def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8): 49 | if self.is_norm: 50 | img_t = img_t * 0.5 + 0.5 51 | img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR 52 | img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax 53 | 54 | return img_np.astype(imtype) 55 | -------------------------------------------------------------------------------- /GPEN/face_model/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /GPEN/face_model/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load, _import_module_from_library 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | #fused = _import_module_from_library('fused', '/tmp/torch_extensions/fused', True) 19 | 20 | 21 | class FusedLeakyReLUFunctionBackward(Function): 22 | @staticmethod 23 | def forward(ctx, grad_output, out, negative_slope, scale): 24 | ctx.save_for_backward(out) 25 | ctx.negative_slope = negative_slope 26 | ctx.scale = scale 27 | 28 | empty = grad_output.new_empty(0) 29 | 30 | grad_input = fused.fused_bias_act( 31 | grad_output, empty, out, 3, 1, negative_slope, scale 32 | ) 33 | 34 | dim = [0] 35 | 36 | if grad_input.ndim > 2: 37 | dim += list(range(2, grad_input.ndim)) 38 | 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | return grad_input, grad_bias 42 | 43 | @staticmethod 44 | def backward(ctx, gradgrad_input, gradgrad_bias): 45 | out, = ctx.saved_tensors 46 | gradgrad_out = fused.fused_bias_act( 47 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 48 | ) 49 | 50 | return gradgrad_out, None, None, None 51 | 52 | 53 | class FusedLeakyReLUFunction(Function): 54 | @staticmethod 55 | def forward(ctx, input, bias, negative_slope, scale): 56 | empty = input.new_empty(0) 57 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 58 | ctx.save_for_backward(out) 59 | ctx.negative_slope = negative_slope 60 | ctx.scale = scale 61 | 62 | return out 63 | 64 | @staticmethod 65 | def backward(ctx, grad_output): 66 | out, = ctx.saved_tensors 67 | 68 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 69 | grad_output, out, ctx.negative_slope, ctx.scale 70 | ) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 77 | super().__init__() 78 | 
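# Learnable per-channel bias; the fused CUDA op adds it, applies leaky ReLU with
# `negative_slope`, and multiplies by `scale` (default sqrt(2)) in a single kernel.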
79 | self.bias = nn.Parameter(torch.zeros(channel)) 80 | self.negative_slope = negative_slope 81 | self.scale = scale 82 | 83 | def forward(self, input): 84 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 85 | 86 | 87 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 88 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 89 | -------------------------------------------------------------------------------- /GPEN/face_model/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /GPEN/face_model/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /GPEN/face_model/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /GPEN/face_model/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load, _import_module_from_library 6 | 7 | 8 | module_path = os.path.dirname(__file__) 9 | upfirdn2d_op = load( 10 | 'upfirdn2d', 11 | sources=[ 12 | os.path.join(module_path, 'upfirdn2d.cpp'), 13 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 14 | ], 15 | ) 16 | 17 | #upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True) 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def 
backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | out = UpFirDn2d.apply( 147 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 148 | ) 149 | 150 | return out 151 | 152 | 153 | def upfirdn2d_native( 154 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 155 | ): 156 | _, in_h, in_w, minor = input.shape 157 | kernel_h, kernel_w = kernel.shape 158 | 159 | out = input.view(-1, in_h, 1, in_w, 1, minor) 160 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 161 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 162 | 163 | out = F.pad( 164 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 165 | ) 166 | out = out[ 167 | :, 168 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 169 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 170 | :, 171 | ] 172 | 173 | out = out.permute(0, 3, 1, 2) 174 | out = out.reshape( 175 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 176 | ) 177 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 178 | out = F.conv2d(out, w) 179 | out = out.reshape( 180 | -1, 181 | minor, 182 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 183 | 
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 184 | ) 185 | out = out.permute(0, 2, 3, 1) 186 | 187 | return out[:, ::down_y, ::down_x, :] 188 | 189 | -------------------------------------------------------------------------------- /GPEN/face_model/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int 
out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | 
p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /GPEN/requirements.txt: -------------------------------------------------------------------------------- 1 | ninja 2 | torch 3 | torchvision 4 | opencv-python 5 | numpy 6 | scikit-image 7 | pillow 8 | pyyaml==5.4.1 -------------------------------------------------------------------------------- /GPEN/retinaface/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/.DS_Store -------------------------------------------------------------------------------- /GPEN/retinaface/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .wider_face import WiderFaceDetection, detection_collate 2 | from .data_augment import * 3 | from .config import * 4 | -------------------------------------------------------------------------------- /GPEN/retinaface/data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | cfg_mnet = { 4 | 'name': 'mobilenet0.25', 5 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 6 | 'steps': [8, 16, 32], 7 | 'variance': [0.1, 0.2], 8 | 'clip': False, 9 | 'loc_weight': 2.0, 10 | 'gpu_train': True, 11 | 'batch_size': 32, 12 | 'ngpu': 1, 13 | 'epoch': 250, 14 | 'decay1': 190, 15 | 'decay2': 220, 16 | 'image_size': 640, 17 | 'pretrain': False, 18 | 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3}, 19 | 'in_channel': 32, 20 | 'out_channel': 64 21 | } 22 | 23 | cfg_re50 = { 24 | 'name': 'Resnet50', 25 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 26 | 'steps': [8, 16, 32], 27 | 'variance': [0.1, 0.2], 28 | 'clip': False, 29 | 'loc_weight': 2.0, 30 | 'gpu_train': True, 31 | 'batch_size': 24, 32 | 'ngpu': 4, 33 | 'epoch': 100, 34 | 'decay1': 70, 35 | 'decay2': 90, 36 | 'image_size': 840, 37 | 'pretrain': False, 38 | 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3}, 39 | 'in_channel': 256, 40 | 'out_channel': 256 41 | } 42 | 43 | -------------------------------------------------------------------------------- /GPEN/retinaface/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | from 
utils.box_utils import matrix_iof 5 | 6 | 7 | def _crop(image, boxes, labels, landm, img_dim): 8 | height, width, _ = image.shape 9 | pad_image_flag = True 10 | 11 | for _ in range(250): 12 | """ 13 | if random.uniform(0, 1) <= 0.2: 14 | scale = 1.0 15 | else: 16 | scale = random.uniform(0.3, 1.0) 17 | """ 18 | PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0] 19 | scale = random.choice(PRE_SCALES) 20 | short_side = min(width, height) 21 | w = int(scale * short_side) 22 | h = w 23 | 24 | if width == w: 25 | l = 0 26 | else: 27 | l = random.randrange(width - w) 28 | if height == h: 29 | t = 0 30 | else: 31 | t = random.randrange(height - h) 32 | roi = np.array((l, t, l + w, t + h)) 33 | 34 | value = matrix_iof(boxes, roi[np.newaxis]) 35 | flag = (value >= 1) 36 | if not flag.any(): 37 | continue 38 | 39 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2 40 | mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1) 41 | boxes_t = boxes[mask_a].copy() 42 | labels_t = labels[mask_a].copy() 43 | landms_t = landm[mask_a].copy() 44 | landms_t = landms_t.reshape([-1, 5, 2]) 45 | 46 | if boxes_t.shape[0] == 0: 47 | continue 48 | 49 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]] 50 | 51 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) 52 | boxes_t[:, :2] -= roi[:2] 53 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) 54 | boxes_t[:, 2:] -= roi[:2] 55 | 56 | # landm 57 | landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2] 58 | landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0])) 59 | landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2]) 60 | landms_t = landms_t.reshape([-1, 10]) 61 | 62 | 63 | # make sure that the cropped image contains at least one face > 16 pixel at training image scale 64 | b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim 65 | b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim 66 | mask_b = np.minimum(b_w_t, b_h_t) > 0.0 67 | boxes_t = boxes_t[mask_b] 68 | labels_t = labels_t[mask_b] 69 | landms_t = landms_t[mask_b] 70 | 71 | if boxes_t.shape[0] == 0: 72 | continue 73 | 74 | pad_image_flag = False 75 | 76 | return image_t, boxes_t, labels_t, landms_t, pad_image_flag 77 | return image, boxes, labels, landm, pad_image_flag 78 | 79 | 80 | def _distort(image): 81 | 82 | def _convert(image, alpha=1, beta=0): 83 | tmp = image.astype(float) * alpha + beta 84 | tmp[tmp < 0] = 0 85 | tmp[tmp > 255] = 255 86 | image[:] = tmp 87 | 88 | image = image.copy() 89 | 90 | if random.randrange(2): 91 | 92 | #brightness distortion 93 | if random.randrange(2): 94 | _convert(image, beta=random.uniform(-32, 32)) 95 | 96 | #contrast distortion 97 | if random.randrange(2): 98 | _convert(image, alpha=random.uniform(0.5, 1.5)) 99 | 100 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 101 | 102 | #saturation distortion 103 | if random.randrange(2): 104 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 105 | 106 | #hue distortion 107 | if random.randrange(2): 108 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 109 | tmp %= 180 110 | image[:, :, 0] = tmp 111 | 112 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 113 | 114 | else: 115 | 116 | #brightness distortion 117 | if random.randrange(2): 118 | _convert(image, beta=random.uniform(-32, 32)) 119 | 120 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 121 | 122 | #saturation distortion 123 | if random.randrange(2): 124 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 125 | 126 | #hue distortion 127 | if random.randrange(2): 128 | tmp = image[:, :, 0].astype(int) + 
random.randint(-18, 18) 129 | tmp %= 180 130 | image[:, :, 0] = tmp 131 | 132 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 133 | 134 | #contrast distortion 135 | if random.randrange(2): 136 | _convert(image, alpha=random.uniform(0.5, 1.5)) 137 | 138 | return image 139 | 140 | 141 | def _expand(image, boxes, fill, p): 142 | if random.randrange(2): 143 | return image, boxes 144 | 145 | height, width, depth = image.shape 146 | 147 | scale = random.uniform(1, p) 148 | w = int(scale * width) 149 | h = int(scale * height) 150 | 151 | left = random.randint(0, w - width) 152 | top = random.randint(0, h - height) 153 | 154 | boxes_t = boxes.copy() 155 | boxes_t[:, :2] += (left, top) 156 | boxes_t[:, 2:] += (left, top) 157 | expand_image = np.empty( 158 | (h, w, depth), 159 | dtype=image.dtype) 160 | expand_image[:, :] = fill 161 | expand_image[top:top + height, left:left + width] = image 162 | image = expand_image 163 | 164 | return image, boxes_t 165 | 166 | 167 | def _mirror(image, boxes, landms): 168 | _, width, _ = image.shape 169 | if random.randrange(2): 170 | image = image[:, ::-1] 171 | boxes = boxes.copy() 172 | boxes[:, 0::2] = width - boxes[:, 2::-2] 173 | 174 | # landm 175 | landms = landms.copy() 176 | landms = landms.reshape([-1, 5, 2]) 177 | landms[:, :, 0] = width - landms[:, :, 0] 178 | tmp = landms[:, 1, :].copy() 179 | landms[:, 1, :] = landms[:, 0, :] 180 | landms[:, 0, :] = tmp 181 | tmp1 = landms[:, 4, :].copy() 182 | landms[:, 4, :] = landms[:, 3, :] 183 | landms[:, 3, :] = tmp1 184 | landms = landms.reshape([-1, 10]) 185 | 186 | return image, boxes, landms 187 | 188 | 189 | def _pad_to_square(image, rgb_mean, pad_image_flag): 190 | if not pad_image_flag: 191 | return image 192 | height, width, _ = image.shape 193 | long_side = max(width, height) 194 | image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) 195 | image_t[:, :] = rgb_mean 196 | image_t[0:0 + height, 0:0 + width] = image 197 | return image_t 198 | 199 | 200 | def _resize_subtract_mean(image, insize, rgb_mean): 201 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] 202 | interp_method = interp_methods[random.randrange(5)] 203 | image = cv2.resize(image, (insize, insize), interpolation=interp_method) 204 | image = image.astype(np.float32) 205 | image -= rgb_mean 206 | return image.transpose(2, 0, 1) 207 | 208 | 209 | class preproc(object): 210 | 211 | def __init__(self, img_dim, rgb_means): 212 | self.img_dim = img_dim 213 | self.rgb_means = rgb_means 214 | 215 | def __call__(self, image, targets): 216 | assert targets.shape[0] > 0, "this image does not have gt" 217 | 218 | boxes = targets[:, :4].copy() 219 | labels = targets[:, -1].copy() 220 | landm = targets[:, 4:-1].copy() 221 | 222 | image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) 223 | image_t = _distort(image_t) 224 | image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag) 225 | image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t) 226 | height, width, _ = image_t.shape 227 | image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) 228 | boxes_t[:, 0::2] /= width 229 | boxes_t[:, 1::2] /= height 230 | 231 | landm_t[:, 0::2] /= width 232 | landm_t[:, 1::2] /= height 233 | 234 | labels_t = np.expand_dims(labels_t, 1) 235 | targets_t = np.hstack((boxes_t, landm_t, labels_t)) 236 | 237 | return image_t, targets_t 238 | -------------------------------------------------------------------------------- 
/GPEN/retinaface/data/wider_face.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import sys 4 | import torch 5 | import torch.utils.data as data 6 | import cv2 7 | import numpy as np 8 | 9 | class WiderFaceDetection(data.Dataset): 10 | def __init__(self, txt_path, preproc=None): 11 | self.preproc = preproc 12 | self.imgs_path = [] 13 | self.words = [] 14 | f = open(txt_path,'r') 15 | lines = f.readlines() 16 | isFirst = True 17 | labels = [] 18 | for line in lines: 19 | line = line.rstrip() 20 | if line.startswith('#'): 21 | if isFirst is True: 22 | isFirst = False 23 | else: 24 | labels_copy = labels.copy() 25 | self.words.append(labels_copy) 26 | labels.clear() 27 | path = line[2:] 28 | path = txt_path.replace('label.txt','images/') + path 29 | self.imgs_path.append(path) 30 | else: 31 | line = line.split(' ') 32 | label = [float(x) for x in line] 33 | labels.append(label) 34 | 35 | self.words.append(labels) 36 | 37 | def __len__(self): 38 | return len(self.imgs_path) 39 | 40 | def __getitem__(self, index): 41 | img = cv2.imread(self.imgs_path[index]) 42 | height, width, _ = img.shape 43 | 44 | labels = self.words[index] 45 | annotations = np.zeros((0, 15)) 46 | if len(labels) == 0: 47 | return annotations 48 | for idx, label in enumerate(labels): 49 | annotation = np.zeros((1, 15)) 50 | # bbox 51 | annotation[0, 0] = label[0] # x1 52 | annotation[0, 1] = label[1] # y1 53 | annotation[0, 2] = label[0] + label[2] # x2 54 | annotation[0, 3] = label[1] + label[3] # y2 55 | 56 | # landmarks 57 | annotation[0, 4] = label[4] # l0_x 58 | annotation[0, 5] = label[5] # l0_y 59 | annotation[0, 6] = label[7] # l1_x 60 | annotation[0, 7] = label[8] # l1_y 61 | annotation[0, 8] = label[10] # l2_x 62 | annotation[0, 9] = label[11] # l2_y 63 | annotation[0, 10] = label[13] # l3_x 64 | annotation[0, 11] = label[14] # l3_y 65 | annotation[0, 12] = label[16] # l4_x 66 | annotation[0, 13] = label[17] # l4_y 67 | if (annotation[0, 4]<0): 68 | annotation[0, 14] = -1 69 | else: 70 | annotation[0, 14] = 1 71 | 72 | annotations = np.append(annotations, annotation, axis=0) 73 | target = np.array(annotations) 74 | if self.preproc is not None: 75 | img, target = self.preproc(img, target) 76 | 77 | return torch.from_numpy(img), target 78 | 79 | def detection_collate(batch): 80 | """Custom collate fn for dealing with batches of images that have a different 81 | number of associated object annotations (bounding boxes). 
82 | 83 | Arguments: 84 | batch: (tuple) A tuple of tensor images and lists of annotations 85 | 86 | Return: 87 | A tuple containing: 88 | 1) (tensor) batch of images stacked on their 0 dim 89 | 2) (list of tensors) annotations for a given image are stacked on 0 dim 90 | """ 91 | targets = [] 92 | imgs = [] 93 | for _, sample in enumerate(batch): 94 | for _, tup in enumerate(sample): 95 | if torch.is_tensor(tup): 96 | imgs.append(tup) 97 | elif isinstance(tup, type(np.empty(0))): 98 | annos = torch.from_numpy(tup).float() 99 | targets.append(annos) 100 | 101 | return (torch.stack(imgs, 0), targets) 102 | -------------------------------------------------------------------------------- /GPEN/retinaface/facemodels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/facemodels/__init__.py -------------------------------------------------------------------------------- /GPEN/retinaface/facemodels/net.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models._utils as _utils 5 | import torchvision.models as models 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | def conv_bn(inp, oup, stride = 1, leaky = 0): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 14 | ) 15 | 16 | def conv_bn_no_relu(inp, oup, stride): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 19 | nn.BatchNorm2d(oup), 20 | ) 21 | 22 | def conv_bn1X1(inp, oup, stride, leaky=0): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), 25 | nn.BatchNorm2d(oup), 26 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 27 | ) 28 | 29 | def conv_dw(inp, oup, stride, leaky=0.1): 30 | return nn.Sequential( 31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 32 | nn.BatchNorm2d(inp), 33 | nn.LeakyReLU(negative_slope= leaky,inplace=True), 34 | 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.LeakyReLU(negative_slope= leaky,inplace=True), 38 | ) 39 | 40 | class SSH(nn.Module): 41 | def __init__(self, in_channel, out_channel): 42 | super(SSH, self).__init__() 43 | assert out_channel % 4 == 0 44 | leaky = 0 45 | if (out_channel <= 64): 46 | leaky = 0.1 47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1) 48 | 49 | self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky) 50 | self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) 51 | 52 | self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky) 53 | self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) 54 | 55 | def forward(self, input): 56 | conv3X3 = self.conv3X3(input) 57 | 58 | conv5X5_1 = self.conv5X5_1(input) 59 | conv5X5 = self.conv5X5_2(conv5X5_1) 60 | 61 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 62 | conv7X7 = self.conv7x7_3(conv7X7_2) 63 | 64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 65 | out = F.relu(out) 66 | return out 67 | 68 | class FPN(nn.Module): 69 | def __init__(self,in_channels_list,out_channels): 70 | super(FPN,self).__init__() 71 | leaky = 0 72 | if (out_channels <= 64): 73 | leaky = 0.1 74 | self.output1 = conv_bn1X1(in_channels_list[0], 
out_channels, stride = 1, leaky = leaky) 75 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky) 76 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky) 77 | 78 | self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky) 79 | self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky) 80 | 81 | def forward(self, input): 82 | # names = list(input.keys()) 83 | input = list(input.values()) 84 | 85 | output1 = self.output1(input[0]) 86 | output2 = self.output2(input[1]) 87 | output3 = self.output3(input[2]) 88 | 89 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") 90 | output2 = output2 + up3 91 | output2 = self.merge2(output2) 92 | 93 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") 94 | output1 = output1 + up2 95 | output1 = self.merge1(output1) 96 | 97 | out = [output1, output2, output3] 98 | return out 99 | 100 | 101 | 102 | class MobileNetV1(nn.Module): 103 | def __init__(self): 104 | super(MobileNetV1, self).__init__() 105 | self.stage1 = nn.Sequential( 106 | conv_bn(3, 8, 2, leaky = 0.1), # 3 107 | conv_dw(8, 16, 1), # 7 108 | conv_dw(16, 32, 2), # 11 109 | conv_dw(32, 32, 1), # 19 110 | conv_dw(32, 64, 2), # 27 111 | conv_dw(64, 64, 1), # 43 112 | ) 113 | self.stage2 = nn.Sequential( 114 | conv_dw(64, 128, 2), # 43 + 16 = 59 115 | conv_dw(128, 128, 1), # 59 + 32 = 91 116 | conv_dw(128, 128, 1), # 91 + 32 = 123 117 | conv_dw(128, 128, 1), # 123 + 32 = 155 118 | conv_dw(128, 128, 1), # 155 + 32 = 187 119 | conv_dw(128, 128, 1), # 187 + 32 = 219 120 | ) 121 | self.stage3 = nn.Sequential( 122 | conv_dw(128, 256, 2), # 219 +3 2 = 241 123 | conv_dw(256, 256, 1), # 241 + 64 = 301 124 | ) 125 | self.avg = nn.AdaptiveAvgPool2d((1,1)) 126 | self.fc = nn.Linear(256, 1000) 127 | 128 | def forward(self, x): 129 | x = self.stage1(x) 130 | x = self.stage2(x) 131 | x = self.stage3(x) 132 | x = self.avg(x) 133 | # x = self.model(x) 134 | x = x.view(-1, 256) 135 | x = self.fc(x) 136 | return x 137 | 138 | -------------------------------------------------------------------------------- /GPEN/retinaface/facemodels/retinaface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.detection.backbone_utils as backbone_utils 4 | import torchvision.models._utils as _utils 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | from facemodels.net import MobileNetV1 as MobileNetV1 9 | from facemodels.net import FPN as FPN 10 | from facemodels.net import SSH as SSH 11 | 12 | 13 | 14 | class ClassHead(nn.Module): 15 | def __init__(self,inchannels=512,num_anchors=3): 16 | super(ClassHead,self).__init__() 17 | self.num_anchors = num_anchors 18 | self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0) 19 | 20 | def forward(self,x): 21 | out = self.conv1x1(x) 22 | out = out.permute(0,2,3,1).contiguous() 23 | 24 | return out.view(out.shape[0], -1, 2) 25 | 26 | class BboxHead(nn.Module): 27 | def __init__(self,inchannels=512,num_anchors=3): 28 | super(BboxHead,self).__init__() 29 | self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0) 30 | 31 | def forward(self,x): 32 | out = self.conv1x1(x) 33 | out = out.permute(0,2,3,1).contiguous() 34 | 35 | return out.view(out.shape[0], -1, 4) 36 | 37 | class LandmarkHead(nn.Module): 38 | def 
__init__(self,inchannels=512,num_anchors=3): 39 | super(LandmarkHead,self).__init__() 40 | self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0) 41 | 42 | def forward(self,x): 43 | out = self.conv1x1(x) 44 | out = out.permute(0,2,3,1).contiguous() 45 | 46 | return out.view(out.shape[0], -1, 10) 47 | 48 | class RetinaFace(nn.Module): 49 | def __init__(self, cfg = None, phase = 'train'): 50 | """ 51 | :param cfg: Network related settings. 52 | :param phase: train or test. 53 | """ 54 | super(RetinaFace,self).__init__() 55 | self.phase = phase 56 | backbone = None 57 | if cfg['name'] == 'mobilenet0.25': 58 | backbone = MobileNetV1() 59 | if cfg['pretrain']: 60 | checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu')) 61 | from collections import OrderedDict 62 | new_state_dict = OrderedDict() 63 | for k, v in checkpoint['state_dict'].items(): 64 | name = k[7:] # remove module. 65 | new_state_dict[name] = v 66 | # load params 67 | backbone.load_state_dict(new_state_dict) 68 | elif cfg['name'] == 'Resnet50': 69 | import torchvision.models as models 70 | backbone = models.resnet50(pretrained=cfg['pretrain']) 71 | 72 | self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers']) 73 | in_channels_stage2 = cfg['in_channel'] 74 | in_channels_list = [ 75 | in_channels_stage2 * 2, 76 | in_channels_stage2 * 4, 77 | in_channels_stage2 * 8, 78 | ] 79 | out_channels = cfg['out_channel'] 80 | self.fpn = FPN(in_channels_list,out_channels) 81 | self.ssh1 = SSH(out_channels, out_channels) 82 | self.ssh2 = SSH(out_channels, out_channels) 83 | self.ssh3 = SSH(out_channels, out_channels) 84 | 85 | self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel']) 86 | self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) 87 | self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) 88 | 89 | def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2): 90 | classhead = nn.ModuleList() 91 | for i in range(fpn_num): 92 | classhead.append(ClassHead(inchannels,anchor_num)) 93 | return classhead 94 | 95 | def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2): 96 | bboxhead = nn.ModuleList() 97 | for i in range(fpn_num): 98 | bboxhead.append(BboxHead(inchannels,anchor_num)) 99 | return bboxhead 100 | 101 | def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2): 102 | landmarkhead = nn.ModuleList() 103 | for i in range(fpn_num): 104 | landmarkhead.append(LandmarkHead(inchannels,anchor_num)) 105 | return landmarkhead 106 | 107 | def forward(self,inputs): 108 | out = self.body(inputs) 109 | 110 | # FPN 111 | fpn = self.fpn(out) 112 | 113 | # SSH 114 | feature1 = self.ssh1(fpn[0]) 115 | feature2 = self.ssh2(fpn[1]) 116 | feature3 = self.ssh3(fpn[2]) 117 | features = [feature1, feature2, feature3] 118 | 119 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) 120 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1) 121 | ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1) 122 | 123 | if self.phase == 'train': 124 | output = (bbox_regressions, classifications, ldm_regressions) 125 | else: 126 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) 127 | return output -------------------------------------------------------------------------------- 
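`RetinaFace` above is driven entirely by a config dict; the constructor reads `cfg['name']`, `cfg['pretrain']`, `cfg['return_layers']`, `cfg['in_channel']` and `cfg['out_channel']`, which normally come from `data/config.py`. The sketch below wires up the MobileNet-0.25 variant with illustrative values (not copied from the real config file; `in_channel = 32` is simply what matches the 64/128/256 stage widths of `MobileNetV1` above) and checks the output shapes.

```python
# Rough sketch: build the MobileNet-0.25 RetinaFace defined above and run a dummy forward pass.
# The cfg values here are illustrative assumptions; the real ones live in data/config.py.
import torch
from facemodels.retinaface import RetinaFace

cfg = {
    "name": "mobilenet0.25",
    "pretrain": False,                                    # skip loading the backbone checkpoint
    "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
    "in_channel": 32,                                     # stages output 2x/4x/8x this (64/128/256)
    "out_channel": 64,
}

net = RetinaFace(cfg=cfg, phase="test").eval()

with torch.no_grad():
    loc, conf, landms = net(torch.randn(1, 3, 640, 640))

# Two anchors per location on strides 8/16/32 -> 2 * (80*80 + 40*40 + 20*20) = 16800 priors.
print(loc.shape, conf.shape, landms.shape)                # (1, 16800, 4) (1, 16800, 2) (1, 16800, 10)
```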
/GPEN/retinaface/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /GPEN/retinaface/layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import product as product 3 | import numpy as np 4 | from math import ceil 5 | 6 | 7 | class PriorBox(object): 8 | def __init__(self, cfg, image_size=None, phase='train'): 9 | super(PriorBox, self).__init__() 10 | self.min_sizes = cfg['min_sizes'] 11 | self.steps = cfg['steps'] 12 | self.clip = cfg['clip'] 13 | self.image_size = image_size 14 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] 15 | self.name = "s" 16 | 17 | def forward(self): 18 | anchors = [] 19 | for k, f in enumerate(self.feature_maps): 20 | min_sizes = self.min_sizes[k] 21 | for i, j in product(range(f[0]), range(f[1])): 22 | for min_size in min_sizes: 23 | s_kx = min_size / self.image_size[1] 24 | s_ky = min_size / self.image_size[0] 25 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] 26 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] 27 | for cy, cx in product(dense_cy, dense_cx): 28 | anchors += [cx, cy, s_kx, s_ky] 29 | 30 | # back to torch land 31 | output = torch.Tensor(anchors).view(-1, 4) 32 | if self.clip: 33 | output.clamp_(max=1, min=0) 34 | return output 35 | -------------------------------------------------------------------------------- /GPEN/retinaface/layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | 3 | __all__ = ['MultiBoxLoss'] 4 | -------------------------------------------------------------------------------- /GPEN/retinaface/layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from utils.box_utils import match, log_sum_exp 6 | from data import cfg_mnet 7 | GPU = cfg_mnet['gpu_train'] 8 | 9 | class MultiBoxLoss(nn.Module): 10 | """SSD Weighted Loss Function 11 | Compute Targets: 12 | 1) Produce Confidence Target Indices by matching ground truth boxes 13 | with (default) 'priorboxes' that have jaccard index > threshold parameter 14 | (default threshold: 0.5). 15 | 2) Produce localization target by 'encoding' variance into offsets of ground 16 | truth boxes and their matched 'priorboxes'. 17 | 3) Hard negative mining to filter the excessive number of negative examples 18 | that comes with using a large number of default bounding boxes. 19 | (default negative:positive ratio 3:1) 20 | Objective Loss: 21 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 22 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 23 | weighted by α which is set to 1 by cross val. 24 | Args: 25 | c: class confidences, 26 | l: predicted boxes, 27 | g: ground truth boxes 28 | N: number of matched default boxes 29 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
30 | """ 31 | 32 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target): 33 | super(MultiBoxLoss, self).__init__() 34 | self.num_classes = num_classes 35 | self.threshold = overlap_thresh 36 | self.background_label = bkg_label 37 | self.encode_target = encode_target 38 | self.use_prior_for_matching = prior_for_matching 39 | self.do_neg_mining = neg_mining 40 | self.negpos_ratio = neg_pos 41 | self.neg_overlap = neg_overlap 42 | self.variance = [0.1, 0.2] 43 | 44 | def forward(self, predictions, priors, targets): 45 | """Multibox Loss 46 | Args: 47 | predictions (tuple): A tuple containing loc preds, conf preds, 48 | and prior boxes from SSD net. 49 | conf shape: torch.size(batch_size,num_priors,num_classes) 50 | loc shape: torch.size(batch_size,num_priors,4) 51 | priors shape: torch.size(num_priors,4) 52 | 53 | ground_truth (tensor): Ground truth boxes and labels for a batch, 54 | shape: [batch_size,num_objs,5] (last idx is the label). 55 | """ 56 | 57 | loc_data, conf_data, landm_data = predictions 58 | priors = priors 59 | num = loc_data.size(0) 60 | num_priors = (priors.size(0)) 61 | 62 | # match priors (default boxes) and ground truth boxes 63 | loc_t = torch.Tensor(num, num_priors, 4) 64 | landm_t = torch.Tensor(num, num_priors, 10) 65 | conf_t = torch.LongTensor(num, num_priors) 66 | for idx in range(num): 67 | truths = targets[idx][:, :4].data 68 | labels = targets[idx][:, -1].data 69 | landms = targets[idx][:, 4:14].data 70 | defaults = priors.data 71 | match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx) 72 | if GPU: 73 | loc_t = loc_t.cuda() 74 | conf_t = conf_t.cuda() 75 | landm_t = landm_t.cuda() 76 | 77 | zeros = torch.tensor(0).cuda() 78 | # landm Loss (Smooth L1) 79 | # Shape: [batch,num_priors,10] 80 | pos1 = conf_t > zeros 81 | num_pos_landm = pos1.long().sum(1, keepdim=True) 82 | N1 = max(num_pos_landm.data.sum().float(), 1) 83 | pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data) 84 | landm_p = landm_data[pos_idx1].view(-1, 10) 85 | landm_t = landm_t[pos_idx1].view(-1, 10) 86 | loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum') 87 | 88 | 89 | pos = conf_t != zeros 90 | conf_t[pos] = 1 91 | 92 | # Localization Loss (Smooth L1) 93 | # Shape: [batch,num_priors,4] 94 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 95 | loc_p = loc_data[pos_idx].view(-1, 4) 96 | loc_t = loc_t[pos_idx].view(-1, 4) 97 | loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') 98 | 99 | # Compute max conf across batch for hard negative mining 100 | batch_conf = conf_data.view(-1, self.num_classes) 101 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 102 | 103 | # Hard Negative Mining 104 | loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now 105 | loss_c = loss_c.view(num, -1) 106 | _, loss_idx = loss_c.sort(1, descending=True) 107 | _, idx_rank = loss_idx.sort(1) 108 | num_pos = pos.long().sum(1, keepdim=True) 109 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 110 | neg = idx_rank < num_neg.expand_as(idx_rank) 111 | 112 | # Confidence Loss Including Positive and Negative Examples 113 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 114 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 115 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) 116 | targets_weighted = conf_t[(pos+neg).gt(0)] 117 | loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') 118 | 
119 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 120 | N = max(num_pos.data.sum().float(), 1) 121 | loss_l /= N 122 | loss_c /= N 123 | loss_landm /= N1 124 | 125 | return loss_l, loss_c, loss_landm 126 | -------------------------------------------------------------------------------- /GPEN/retinaface/retinaface_detection.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import os 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import numpy as np 9 | from data import cfg_re50 10 | from layers.functions.prior_box import PriorBox 11 | from utils.nms.py_cpu_nms import py_cpu_nms 12 | import cv2 13 | from facemodels.retinaface import RetinaFace 14 | from utils.box_utils import decode, decode_landm 15 | import time 16 | import torch 17 | 18 | class RetinaFaceDetection(object): 19 | def __init__(self, base_dir, network='RetinaFace-R50'): 20 | torch.set_grad_enabled(False) 21 | cudnn.benchmark = True 22 | self.pretrained_path = os.path.join(base_dir, 'weights', network+'.pth') 23 | self.device = torch.cuda.current_device() 24 | self.cfg = cfg_re50 25 | self.net = RetinaFace(cfg=self.cfg, phase='test') 26 | self.load_model() 27 | self.net = self.net.cuda() 28 | self.net_trt = None 29 | 30 | def check_keys(self, pretrained_state_dict): 31 | ckpt_keys = set(pretrained_state_dict.keys()) 32 | model_keys = set(self.net.state_dict().keys()) 33 | used_pretrained_keys = model_keys & ckpt_keys 34 | unused_pretrained_keys = ckpt_keys - model_keys 35 | missing_keys = model_keys - ckpt_keys 36 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 37 | return True 38 | 39 | def remove_prefix(self, state_dict, prefix): 40 | ''' Old style model is stored with all names of parameters sharing common prefix 'module.' 
''' 41 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 42 | return {f(key): value for key, value in state_dict.items()} 43 | 44 | def load_model(self, load_to_cpu=False): 45 | if load_to_cpu: 46 | pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage) 47 | else: 48 | pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage.cuda()) 49 | if "state_dict" in pretrained_dict.keys(): 50 | pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.') 51 | else: 52 | pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') 53 | self.check_keys(pretrained_dict) 54 | self.net.load_state_dict(pretrained_dict, strict=False) 55 | self.net.eval() 56 | 57 | def build_trt(self, img_raw): 58 | img = np.float32(img_raw) 59 | 60 | img -= (104, 117, 123) 61 | img = img.transpose(2, 0, 1) 62 | img = torch.from_numpy(img).unsqueeze(0) 63 | img = img.cuda() 64 | 65 | print('building trt model FaceGAN') 66 | from torch2trt import torch2trt 67 | self.net_trt = torch2trt(self.net, [img], fp16_mode=True) 68 | del self.net 69 | print('sucessfully built') 70 | 71 | def detect_trt(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False): 72 | img = np.float32(img_raw) 73 | 74 | im_height, im_width = img.shape[:2] 75 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) 76 | img -= (104, 117, 123) 77 | img = img.transpose(2, 0, 1) 78 | img = torch.from_numpy(img).unsqueeze(0) 79 | img = img.cuda() 80 | scale = scale.cuda() 81 | 82 | loc, conf, landms = self.net_trt(img) # forward pass 83 | 84 | priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) 85 | priors = priorbox.forward() 86 | priors = priors.cuda() 87 | prior_data = priors.data 88 | boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) 89 | boxes = boxes * scale / resize 90 | boxes = boxes.cpu().numpy() 91 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1] 92 | landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance']) 93 | scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2], 94 | img.shape[3], img.shape[2], img.shape[3], img.shape[2], 95 | img.shape[3], img.shape[2]]) 96 | scale1 = scale1.cuda() 97 | landms = landms * scale1 / resize 98 | landms = landms.cpu().numpy() 99 | 100 | # ignore low scores 101 | inds = np.where(scores > confidence_threshold)[0] 102 | boxes = boxes[inds] 103 | landms = landms[inds] 104 | scores = scores[inds] 105 | 106 | # keep top-K before NMS 107 | order = scores.argsort()[::-1][:top_k] 108 | boxes = boxes[order] 109 | landms = landms[order] 110 | scores = scores[order] 111 | 112 | # do NMS 113 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 114 | keep = py_cpu_nms(dets, nms_threshold) 115 | # keep = nms(dets, nms_threshold,force_cpu=args.cpu) 116 | dets = dets[keep, :] 117 | landms = landms[keep] 118 | 119 | # keep top-K faster NMS 120 | dets = dets[:keep_top_k, :] 121 | landms = landms[:keep_top_k, :] 122 | 123 | # sort faces(delete) 124 | ''' 125 | fscores = [det[4] for det in dets] 126 | sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index 127 | tmp = [landms[idx] for idx in sorted_idx] 128 | landms = np.asarray(tmp) 129 | ''' 130 | 131 | landms = landms.reshape((-1, 5, 2)) 132 | landms = landms.transpose((0, 2, 1)) 133 | landms = landms.reshape(-1, 10, ) 134 | return dets, 
landms 135 | 136 | 137 | def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False): 138 | img = np.float32(img_raw) 139 | 140 | im_height, im_width = img.shape[:2] 141 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) 142 | img -= (104, 117, 123) 143 | img = img.transpose(2, 0, 1) 144 | img = torch.from_numpy(img).unsqueeze(0) 145 | img = img.cuda() 146 | scale = scale.cuda() 147 | 148 | loc, conf, landms = self.net(img) # forward pass 149 | 150 | priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) 151 | priors = priorbox.forward() 152 | priors = priors.cuda() 153 | prior_data = priors.data 154 | boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) 155 | boxes = boxes * scale / resize 156 | boxes = boxes.cpu().numpy() 157 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1] 158 | landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance']) 159 | scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2], 160 | img.shape[3], img.shape[2], img.shape[3], img.shape[2], 161 | img.shape[3], img.shape[2]]) 162 | scale1 = scale1.cuda() 163 | landms = landms * scale1 / resize 164 | landms = landms.cpu().numpy() 165 | 166 | # ignore low scores 167 | inds = np.where(scores > confidence_threshold)[0] 168 | boxes = boxes[inds] 169 | landms = landms[inds] 170 | scores = scores[inds] 171 | 172 | # keep top-K before NMS 173 | order = scores.argsort()[::-1][:top_k] 174 | boxes = boxes[order] 175 | landms = landms[order] 176 | scores = scores[order] 177 | 178 | # do NMS 179 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 180 | keep = py_cpu_nms(dets, nms_threshold) 181 | # keep = nms(dets, nms_threshold,force_cpu=args.cpu) 182 | dets = dets[keep, :] 183 | landms = landms[keep] 184 | 185 | # keep top-K faster NMS 186 | dets = dets[:keep_top_k, :] 187 | landms = landms[:keep_top_k, :] 188 | 189 | # sort faces(delete) 190 | ''' 191 | fscores = [det[4] for det in dets] 192 | sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index 193 | tmp = [landms[idx] for idx in sorted_idx] 194 | landms = np.asarray(tmp) 195 | ''' 196 | 197 | landms = landms.reshape((-1, 5, 2)) 198 | landms = landms.transpose((0, 2, 1)) 199 | landms = landms.reshape(-1, 10, ) 200 | return dets, landms 201 | -------------------------------------------------------------------------------- /GPEN/retinaface/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/utils/__init__.py -------------------------------------------------------------------------------- /GPEN/retinaface/utils/nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/GPEN/retinaface/utils/nms/__init__.py -------------------------------------------------------------------------------- /GPEN/retinaface/utils/nms/py_cpu_nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # 
-------------------------------------------------------- 7 | 8 | import numpy as np 9 | 10 | def py_cpu_nms(dets, thresh): 11 | """Pure Python NMS baseline.""" 12 | x1 = dets[:, 0] 13 | y1 = dets[:, 1] 14 | x2 = dets[:, 2] 15 | y2 = dets[:, 3] 16 | scores = dets[:, 4] 17 | 18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 19 | order = scores.argsort()[::-1] 20 | 21 | keep = [] 22 | while order.size > 0: 23 | i = order[0] 24 | keep.append(i) 25 | xx1 = np.maximum(x1[i], x1[order[1:]]) 26 | yy1 = np.maximum(y1[i], y1[order[1:]]) 27 | xx2 = np.minimum(x2[i], x2[order[1:]]) 28 | yy2 = np.minimum(y2[i], y2[order[1:]]) 29 | 30 | w = np.maximum(0.0, xx2 - xx1 + 1) 31 | h = np.maximum(0.0, yy2 - yy1 + 1) 32 | inter = w * h 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | 35 | inds = np.where(ovr <= thresh)[0] 36 | order = order[inds + 1] 37 | 38 | return keep 39 | -------------------------------------------------------------------------------- /GPEN/retinaface/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | 10 | 11 | class Timer(object): 12 | """A simple timer.""" 13 | def __init__(self): 14 | self.total_time = 0. 15 | self.calls = 0 16 | self.start_time = 0. 17 | self.diff = 0. 18 | self.average_time = 0. 19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.total_time += self.diff 28 | self.calls += 1 29 | self.average_time = self.total_time / self.calls 30 | if average: 31 | return self.average_time 32 | else: 33 | return self.diff 34 | 35 | def clear(self): 36 | self.total_time = 0. 37 | self.calls = 0 38 | self.start_time = 0. 39 | self.diff = 0. 40 | self.average_time = 0. 41 | -------------------------------------------------------------------------------- /GPEN/sr_model/arch_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | from torch.nn import init as init 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | 8 | @torch.no_grad() 9 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 10 | """Initialize network weights. 11 | 12 | Args: 13 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 14 | scale (float): Scale initialized weights, especially for residual 15 | blocks. Default: 1. 16 | bias_fill (float): The value to fill bias. Default: 0 17 | kwargs (dict): Other arguments for initialization function. 
18 | """ 19 | if not isinstance(module_list, list): 20 | module_list = [module_list] 21 | for module in module_list: 22 | for m in module.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | init.kaiming_normal_(m.weight, **kwargs) 25 | m.weight.data *= scale 26 | if m.bias is not None: 27 | m.bias.data.fill_(bias_fill) 28 | elif isinstance(m, nn.Linear): 29 | init.kaiming_normal_(m.weight, **kwargs) 30 | m.weight.data *= scale 31 | if m.bias is not None: 32 | m.bias.data.fill_(bias_fill) 33 | elif isinstance(m, _BatchNorm): 34 | init.constant_(m.weight, 1) 35 | if m.bias is not None: 36 | m.bias.data.fill_(bias_fill) 37 | 38 | 39 | def make_layer(basic_block, num_basic_block, **kwarg): 40 | """Make layers by stacking the same blocks. 41 | 42 | Args: 43 | basic_block (nn.module): nn.module class for basic block. 44 | num_basic_block (int): number of blocks. 45 | 46 | Returns: 47 | nn.Sequential: Stacked blocks in nn.Sequential. 48 | """ 49 | layers = [] 50 | for _ in range(num_basic_block): 51 | layers.append(basic_block(**kwarg)) 52 | return nn.Sequential(*layers) 53 | 54 | 55 | class ResidualBlockNoBN(nn.Module): 56 | """Residual block without BN. 57 | 58 | It has a style of: 59 | ---Conv-ReLU-Conv-+- 60 | |________________| 61 | 62 | Args: 63 | num_feat (int): Channel number of intermediate features. 64 | Default: 64. 65 | res_scale (float): Residual scale. Default: 1. 66 | pytorch_init (bool): If set to True, use pytorch default init, 67 | otherwise, use default_init_weights. Default: False. 68 | """ 69 | 70 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 71 | super(ResidualBlockNoBN, self).__init__() 72 | self.res_scale = res_scale 73 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 74 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 75 | self.relu = nn.ReLU(inplace=True) 76 | 77 | if not pytorch_init: 78 | default_init_weights([self.conv1, self.conv2], 0.1) 79 | 80 | def forward(self, x): 81 | identity = x 82 | out = self.conv2(self.relu(self.conv1(x))) 83 | return identity + out * self.res_scale 84 | 85 | 86 | class Upsample(nn.Sequential): 87 | """Upsample module. 88 | 89 | Args: 90 | scale (int): Scale factor. Supported scales: 2^n and 3. 91 | num_feat (int): Channel number of intermediate features. 92 | """ 93 | 94 | def __init__(self, scale, num_feat): 95 | m = [] 96 | if (scale & (scale - 1)) == 0: # scale = 2^n 97 | for _ in range(int(math.log(scale, 2))): 98 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 99 | m.append(nn.PixelShuffle(2)) 100 | elif scale == 3: 101 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 102 | m.append(nn.PixelShuffle(3)) 103 | else: 104 | raise ValueError(f'scale {scale} is not supported. ' 105 | 'Supported scales: 2^n and 3.') 106 | super(Upsample, self).__init__(*m) 107 | 108 | # TODO: may write a cpp file 109 | def pixel_unshuffle(x, scale): 110 | """ Pixel unshuffle. 111 | 112 | Args: 113 | x (Tensor): Input feature with shape (b, c, hh, hw). 114 | scale (int): Downsample ratio. 115 | 116 | Returns: 117 | Tensor: the pixel unshuffled feature. 
118 | """ 119 | b, c, hh, hw = x.size() 120 | out_channel = c * (scale**2) 121 | assert hh % scale == 0 and hw % scale == 0 122 | h = hh // scale 123 | w = hw // scale 124 | x_view = x.view(b, c, h, scale, w, scale) 125 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) -------------------------------------------------------------------------------- /GPEN/sr_model/real_esrnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from rrdbnet_arch import RRDBNet 5 | from torch.nn import functional as F 6 | import torch 7 | 8 | 9 | class RealESRNet(object): 10 | def __init__(self, base_dir=os.path.dirname(__file__), model=None, scale=2): 11 | self.base_dir = base_dir 12 | self.scale = scale 13 | self.load_srmodel(base_dir, model) 14 | self.srmodel_trt = None 15 | 16 | def load_srmodel(self, base_dir, model): 17 | self.scale = 2 if "x2" in model else 4 if "x4" in model else -1 18 | if self.scale == -1: 19 | raise Exception("Scale not supported") 20 | self.srmodel = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=23, num_grow_ch=32, scale=self.scale) 21 | if model is None: 22 | loadnet = torch.load(os.path.join(self.base_dir, 'weights', 'realesrnet_x2.pth')) 23 | else: 24 | loadnet = torch.load(os.path.join(self.base_dir, 'weights', model+'.pth')) 25 | self.srmodel.load_state_dict(loadnet['params_ema'], strict=True) 26 | self.srmodel.eval() 27 | self.srmodel = self.srmodel.cuda() 28 | 29 | def build_trt(self, img): 30 | img = img.astype(np.float32) / 255. 31 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() 32 | img = img.unsqueeze(0).cuda() 33 | print('building trt model srmodel') 34 | from torch2trt import torch2trt 35 | self.srmodel_trt = torch2trt(self.srmodel, [img], fp16_mode=True) 36 | print('sucessfully built') 37 | del self.srmodel 38 | 39 | def process_trt(self, img): 40 | img = img.astype(np.float32) / 255. 41 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() 42 | img = img.unsqueeze(0).cuda() 43 | 44 | if self.scale == 2: 45 | mod_scale = 2 46 | elif self.scale == 1: 47 | mod_scale = 4 48 | else: 49 | mod_scale = None 50 | if mod_scale is not None: 51 | h_pad, w_pad = 0, 0 52 | _, _, h, w = img.size() 53 | if (h % mod_scale != 0): 54 | h_pad = (mod_scale - h % mod_scale) 55 | if (w % mod_scale != 0): 56 | w_pad = (mod_scale - w % mod_scale) 57 | img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') 58 | try: 59 | with torch.no_grad(): 60 | output = self.srmodel_trt(img) 61 | # remove extra pad 62 | if mod_scale is not None: 63 | _, _, h, w = output.size() 64 | output = output[:, :, 0:h - h_pad, 0:w - w_pad] 65 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() 66 | output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) 67 | output = (output * 255.0).round().astype(np.uint8) 68 | 69 | return output 70 | except: 71 | return None 72 | 73 | def process(self, img): 74 | img = img.astype(np.float32) / 255. 
75 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() 76 | img = img.unsqueeze(0).cuda() 77 | # print(img.shape) 78 | 79 | if self.scale == 2: 80 | mod_scale = 2 81 | elif self.scale == 1: 82 | mod_scale = 4 83 | else: 84 | mod_scale = None 85 | if mod_scale is not None: 86 | h_pad, w_pad = 0, 0 87 | _, _, h, w = img.size() 88 | if (h % mod_scale != 0): 89 | h_pad = (mod_scale - h % mod_scale) 90 | if (w % mod_scale != 0): 91 | w_pad = (mod_scale - w % mod_scale) 92 | img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') 93 | try: 94 | with torch.no_grad(): 95 | output = self.srmodel(img) 96 | # remove extra pad 97 | if mod_scale is not None: 98 | _, _, h, w = output.size() 99 | output = output[:, :, 0:h - h_pad, 0:w - w_pad] 100 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() 101 | output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) 102 | output = (output * 255.0).round().astype(np.uint8) 103 | 104 | return output 105 | except: 106 | return None 107 | 108 | -------------------------------------------------------------------------------- /GPEN/sr_model/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from arch_util import default_init_weights, make_layer, pixel_unshuffle 6 | 7 | 8 | class ResidualDenseBlock(nn.Module): 9 | """Residual Dense Block. 10 | 11 | Used in RRDB block in ESRGAN. 12 | 13 | Args: 14 | num_feat (int): Channel number of intermediate features. 15 | num_grow_ch (int): Channels for each growth. 16 | """ 17 | 18 | def __init__(self, num_feat=64, num_grow_ch=32): 19 | super(ResidualDenseBlock, self).__init__() 20 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 21 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 22 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 25 | 26 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 27 | 28 | # initialization 29 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 30 | 31 | def forward(self, x): 32 | x1 = self.lrelu(self.conv1(x)) 33 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 34 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 35 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 36 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 37 | # Emperically, we use 0.2 to scale the residual for better performance 38 | return x5 * 0.2 + x 39 | 40 | 41 | class RRDB(nn.Module): 42 | """Residual in Residual Dense Block. 43 | 44 | Used in RRDB-Net in ESRGAN. 45 | 46 | Args: 47 | num_feat (int): Channel number of intermediate features. 48 | num_grow_ch (int): Channels for each growth. 
49 | """ 50 | 51 | def __init__(self, num_feat, num_grow_ch=32): 52 | super(RRDB, self).__init__() 53 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 54 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | 57 | def forward(self, x): 58 | out = self.rdb1(x) 59 | out = self.rdb2(out) 60 | out = self.rdb3(out) 61 | # Emperically, we use 0.2 to scale the residual for better performance 62 | return out * 0.2 + x 63 | 64 | class RRDBNet(nn.Module): 65 | """Networks consisting of Residual in Residual Dense Block, which is used 66 | in ESRGAN. 67 | 68 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 69 | 70 | We extend ESRGAN for scale x2 and scale x1. 71 | Note: This is one option for scale 1, scale 2 in RRDBNet. 72 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 73 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 74 | 75 | Args: 76 | num_in_ch (int): Channel number of inputs. 77 | num_out_ch (int): Channel number of outputs. 78 | num_feat (int): Channel number of intermediate features. 79 | Default: 64 80 | num_block (int): Block number in the trunk network. Defaults: 23 81 | num_grow_ch (int): Channels for each growth. Default: 32. 82 | """ 83 | 84 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 85 | super(RRDBNet, self).__init__() 86 | self.scale = scale 87 | if scale == 2: 88 | num_in_ch = num_in_ch * 4 89 | elif scale == 1: 90 | num_in_ch = num_in_ch * 16 91 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 92 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 93 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 94 | # upsample 95 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 96 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 98 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 99 | 100 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 101 | 102 | def forward(self, x): 103 | if self.scale == 2: 104 | feat = pixel_unshuffle(x, scale=2) 105 | elif self.scale == 1: 106 | feat = pixel_unshuffle(x, scale=4) 107 | else: 108 | feat = x 109 | feat = self.conv_first(feat) 110 | body_feat = self.conv_body(self.body(feat)) 111 | feat = feat + body_feat 112 | # upsample 113 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 114 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 115 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 116 | return out 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Face Animation in Real Time 2 | One-shot face animation using webcam, capable of running in real time. 3 | ## **Examples Results** 4 | (Driving video | Result video) 5 | - Original Result without Face Restoration 6 | 7 | https://github.com/sky24h/Face_Mapping_Real_Time/assets/26270672/231778e3-0f37-42c3-8cb0-cf849b22c8a8 8 | 9 | - With Face Restoration 10 | 11 | https://github.com/sky24h/Face_Mapping_Real_Time/assets/26270672/323fb958-77b4-444d-9a21-4d0245fb108c 12 | 13 | # How to Use 14 | ### 0. Dependencies 15 | ``` 16 | pip install -r requirements.txt 17 | ``` 18 | ### **1. 
For local webcam**
19 | Tested on an RTX 3090: about 17 FPS without face restoration and 10 FPS with face restoration.
20 | ```
21 | python camera_local.py --source_image ./assets/source.jpg --restore_face False
22 | ```
23 | The model itself only outputs 256x256 frames, but you can set --output_size to 512 or larger to get a resized result.
24 |
25 | ### **2. For input driving video**
26 | ```
27 | python camera_local.py --source_image ./assets/source.jpg --restore_face False --driving_video ./assets/driving.mp4 --result_video ./result_video.mp4 --output_size 512
28 | ```
29 | The driving video does not require any preprocessing; it is valid as long as every frame contains a face.
30 |
31 | ### **3. For remote access (Not recommended)**
32 | First, forward the port between the server and the client, for example with VS Code remote SSH as described [here](https://code.visualstudio.com/docs/editor/port-forwarding).
33 | Then run the server side on the remote server and the client side on the local machine.
34 |
35 | Note that, due to network latency, the frame rate is low (only 1-2 FPS).
36 |
37 | Server side:
38 | ```
39 | python remote_server.py --source_image ./assets/source.jpg --restore_face False
40 | ```
41 |
42 | Client side (copy only this file to the local machine):
43 | ```
44 | python camera_client.py
45 | ```
46 | ---
47 | ### Pre-trained Models
48 | All necessary pre-trained models should be downloaded automatically when running the demo.
49 | If you need to download them manually, please refer to the following links:
50 |
51 | [Motion Transfer Model](https://drive.google.com/file/d/11ZgyjKI5OcB7klcsIdPpCCX38AIX8Soc/view?usp=drive_link)
52 |
53 | [GPEN (Face Restoration Model)](https://drive.google.com/drive/folders/1epln5c8HW1QXfVz6444Fe0hG-vRNavi6?usp=drive_link)
54 |
55 |
56 | # Acknowledgement
57 | Motion transfer is modified from [zhanglonghao1992](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
58 | Face restoration is modified from [GPEN](https://github.com/yangxy/GPEN).
59 |
60 | Thanks to the authors for their great work!
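### Using the pipeline from Python
If you want to drive the animation from your own code instead of the scripts above, the entry point is `FaceAnimationClass` in `demo_utils.py`, used exactly as in `camera_local.py`. A minimal sketch (the driving-frame path is just an example; frames are BGR as returned by OpenCV, and the native output is 256x256, so larger sizes are plain resizes):

```python
# Sketch mirroring camera_local.py; any BGR frame that contains a face will work as the driver.
import cv2
from demo_utils import FaceAnimationClass

faceanimation = FaceAnimationClass(source_image_path="./assets/source.jpg", use_sr=False)

driving_frame = cv2.imread("./assets/driving_frame.jpg")   # example path, not shipped with the repo
face, result  = faceanimation.inference(driving_frame)     # cropped driving face, animated source face
result = cv2.resize(result, (512, 512))                    # optional: upscale the 256x256 output
cv2.imwrite("./animated.jpg", result)
```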
61 | -------------------------------------------------------------------------------- /assets/driving.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/assets/driving.mp4 -------------------------------------------------------------------------------- /assets/source.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky24h/Face_Animation_Real_Time/7796148b94c21d929c80b842ef7157426e5c2ca5/assets/source.jpg -------------------------------------------------------------------------------- /camera_client.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import numpy as np 4 | import asyncio 5 | import websockets 6 | from argparse import ArgumentParser 7 | 8 | websocket_port = 8066 9 | 10 | 11 | class VideoCamera(object): 12 | def __init__(self, CameraSize=(640, 480)): 13 | self.video = cv2.VideoCapture(0) 14 | self.video.set(cv2.CAP_PROP_FRAME_WIDTH, CameraSize[0]) 15 | self.video.set(cv2.CAP_PROP_FRAME_HEIGHT, CameraSize[1]) 16 | self.video.set(cv2.CAP_PROP_FPS, 24) 17 | # check if camera opened successfully 18 | if not self.video.isOpened(): 19 | raise Exception("Camera not found") 20 | 21 | def __del__(self): 22 | self.video.release() 23 | 24 | def get_frame(self): 25 | success, image = self.video.read() 26 | image = cv2.flip(image, 1) 27 | return image 28 | 29 | 30 | async def send_image(image, ScreenSize=512, SendSize=256): 31 | # Encode the image as bytes 32 | _, image_data = cv2.imencode(".jpg", cv2.resize(image, (SendSize, SendSize)), [int(cv2.IMWRITE_JPEG_QUALITY), 90]) 33 | image_bytes = image_data.tobytes() 34 | # print size 35 | print("Image size: ", len(image_bytes)) 36 | 37 | # Connect to the FastAPI WebSocket server 38 | async with websockets.connect("ws://localhost:{}/ws".format(websocket_port)) as websocket: 39 | # Send the image to the server 40 | await websocket.send(image_bytes) 41 | print("Image sent to the server") 42 | # Receive and process the processed frame 43 | try: 44 | processed_frame_data = await websocket.recv() 45 | 46 | # Decode the processed frame 47 | processed_frame = cv2.imdecode(np.frombuffer(processed_frame_data, dtype=np.uint8), -1) 48 | processed_frame = cv2.resize(processed_frame, (ScreenSize * 2, ScreenSize)) 49 | # return processed_frame 50 | except Exception as e: 51 | print(e) 52 | # return image 53 | processed_frame = np.ones((ScreenSize, ScreenSize, 3), dtype=np.uint8) * 255 54 | cv2.putText(processed_frame, "No response from the server", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) 55 | cv2.imshow("Frame", processed_frame) 56 | cv2.waitKey(1) 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = ArgumentParser() 61 | parser.add_argument("--output_size", default=512, type=int, help="size of the output video") 62 | args = parser.parse_args() 63 | 64 | 65 | camera = VideoCamera() 66 | 67 | frame_count = 0 68 | times = [] 69 | while True: 70 | image = camera.get_frame() 71 | frame_count += 1 72 | time_start = time.time() 73 | asyncio.run(send_image(image, ScreenSize=args.output_size, SendSize=256)) 74 | times.append(time.time() - time_start) 75 | if frame_count % 10 == 0: 76 | print("FPS: {:.2f}".format(1 / np.mean(times))) 77 | times = [] 78 | -------------------------------------------------------------------------------- /camera_local.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import numpy as np 4 | 5 | from argparse import ArgumentParser 6 | from demo_utils import FaceAnimationClass 7 | 8 | 9 | class VideoCamera(object): 10 | def __init__(self, video_path=0, CameraSize=(640, 480)): 11 | self.video_path = video_path 12 | self.video = cv2.VideoCapture(video_path) if video_path != 0 else cv2.VideoCapture(0) 13 | self.video.set(cv2.CAP_PROP_FRAME_WIDTH, CameraSize[0]) 14 | self.video.set(cv2.CAP_PROP_FRAME_HEIGHT, CameraSize[1]) 15 | self.video.set(cv2.CAP_PROP_FPS, 24) 16 | # check if camera opened successfully 17 | if video_path == 0 and not self.video.isOpened(): 18 | raise Exception("Camera not found") 19 | elif video_path != 0 and not self.video.isOpened(): 20 | raise Exception("Video file not found") 21 | 22 | def __del__(self): 23 | self.video.release() 24 | 25 | def get_frame(self): 26 | success, image = self.video.read() 27 | image = cv2.flip(image, 1) if self.video_path == 0 else image 28 | return image 29 | 30 | 31 | def process_frame(image, ScreenSize=512): 32 | face, result = faceanimation.inference(image) 33 | if face.shape[1] != ScreenSize or face.shape[0] != ScreenSize: 34 | face = cv2.resize(face, (ScreenSize, ScreenSize)) 35 | if result.shape[0] != ScreenSize or result.shape[1] != ScreenSize: 36 | result = cv2.resize(result, (ScreenSize, ScreenSize)) 37 | return cv2.hconcat([face, result]) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = ArgumentParser() 42 | parser.add_argument("--source_image", default="./assets/source.jpg", help="path to source image") 43 | parser.add_argument("--driving_video", default=None, help="path to driving video") 44 | parser.add_argument("--result_video", default="./result_video.mp4", help="path to output") 45 | parser.add_argument("--output_size", default=512, type=int, help="size of the output video") 46 | parser.add_argument("--restore_face", default=False, type=str, help="restore face") 47 | args = parser.parse_args() 48 | restore_face = True if args.restore_face == 'True' else False if args.restore_face == 'False' else exit('restore_face must be True or False') 49 | 50 | if args.driving_video is None: 51 | video_path = 0 52 | print("Using webcam") 53 | # create window for displaying results 54 | cv2.namedWindow("Face Animation", cv2.WINDOW_NORMAL) 55 | else: 56 | video_path = args.driving_video 57 | print("Using driving video: {}".format(video_path)) 58 | camera = VideoCamera(video_path=video_path) 59 | faceanimation = FaceAnimationClass(source_image_path=args.source_image, use_sr=restore_face) 60 | 61 | frames = [] if args.result_video is not None else None 62 | frame_count = 0 63 | times = [] 64 | while True: 65 | time_start = time.time() 66 | image = camera.get_frame() 67 | if image is None and frame_count != 0: 68 | print("Video ended") 69 | break 70 | try: 71 | res = process_frame(image, ScreenSize=args.output_size) 72 | frame_count += 1 73 | times.append(time.time() - time_start) 74 | if frame_count % 100 == 0: 75 | print("FPS: {:.2f}".format(1 / np.mean(times))) 76 | times = [] 77 | frames.append(res) if args.result_video is not None else None 78 | # display results if using webcam 79 | if args.driving_video is None: 80 | cv2.imshow("Face Animation", res) 81 | if cv2.waitKey(1) & 0xFF == ord("q"): 82 | break 83 | except Exception as e: 84 | print(e) 85 | raise e 86 | 87 | if args.result_video is not None: 88 | import imageio 89 | from tqdm import tqdm 90 | 91 | writer = 
imageio.get_writer(args.result_video, fps=24, quality=9, macro_block_size=1, codec="libx264", pixelformat="yuv420p") 92 | for frame in tqdm(frames): 93 | writer.append_data(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 94 | writer.close() 95 | print("Video saved to {}".format(args.result_video)) 96 | -------------------------------------------------------------------------------- /face-vid2vid/README.md: -------------------------------------------------------------------------------- 1 | # One-Shot Free-View Neural Talking Head Synthesis 2 | Unofficial pytorch implementation of paper "One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing". 3 | 4 | ```Python 3.6``` and ```Pytorch 1.7``` are used. 5 | 6 | 7 | Updates: 8 | -------- 9 | ```2021.11.05``` : 10 | * Replace Jacobian with the rotation matrix (Assuming J = R) to avoid estimating Jacobian. 11 | * Correct the rotation matrix. 12 | 13 | ```2021.11.17``` : 14 | * Better Generator, better performance (models and checkpoints have been released). 15 | 16 | Driving | Beta Version | FOMM | New Version: 17 | 18 | 19 | https://user-images.githubusercontent.com/17874285/142828000-db7b324e-c2fd-4fdc-a272-04fb8adbc88a.mp4 20 | 21 | 22 | -------- 23 | Driving | FOMM | Ours: 24 | ![show](https://github.com/zhanglonghao1992/ReadmeImages/blob/master/images/081.gif) 25 | 26 | Free-View: 27 | ![show](https://github.com/zhanglonghao1992/ReadmeImages/blob/master/images/concat.gif) 28 | 29 | Train: 30 | -------- 31 | ``` 32 | python run.py --config config/vox-256.yaml --device_ids 0,1,2,3,4,5,6,7 33 | ``` 34 | 35 | Demo: 36 | -------- 37 | ``` 38 | python demo.py --config config/vox-256.yaml --checkpoint path/to/checkpoint --source_image path/to/source --driving_video path/to/driving --relative --adapt_scale --find_best_frame 39 | ``` 40 | free-view (e.g. yaw=20, pitch=roll=0): 41 | ``` 42 | python demo.py --config config/vox-256.yaml --checkpoint path/to/checkpoint --source_image path/to/source --driving_video path/to/driving --relative --adapt_scale --find_best_frame --free_view --yaw 20 --pitch 0 --roll 0 43 | ``` 44 | Note: run ```crop-video.py --inp driving_video.mp4``` first to get the cropping suggestion and crop the raw video. 45 | 46 | Pretrained Model: 47 | -------- 48 | 49 | Model | Train Set | Baidu Netdisk | Media Fire | 50 | ------- |------------ |----------- |-------- | 51 | Vox-256-Beta| VoxCeleb-v1 | [Baidu](https://pan.baidu.com/s/1lLS4ArbK2yWelsL-EtwU8g) (PW: c0tc)| [MF](https://www.mediafire.com/folder/rw51an7tk7bh2/TalkingHead) | 52 | Vox-256-New | VoxCeleb-v1 | - | [MF](https://www.mediafire.com/folder/fcvtkn21j57bb/TalkingHead_Update) | 53 | Vox-512 | VoxCeleb-v2 | soon | soon | 54 | 55 | Note: 56 | 1. For now, the Beta Version is not well tuned. 57 | 2. For free-view synthesis, it is recommended that Yaw, Pitch and Roll are within ±45°, ±20° and ±20° respectively. 58 | 3. Face Restoration algorithms ([GPEN](https://github.com/yangxy/GPEN)) can be used for post-processing to significantly improve the resolution. 59 | ![show](https://github.com/zhanglonghao1992/ReadmeImages/blob/master/images/s%20r.gif) 60 | 61 | 62 | Acknowlegement: 63 | -------- 64 | Thanks to [NV](https://github.com/NVlabs/face-vid2vid), [AliaksandrSiarohin](https://github.com/AliaksandrSiarohin/first-order-model) and [DeepHeadPose](https://github.com/DriverDistraction/DeepHeadPose). 
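For reference, the `--relative` and `--adapt_scale` options in the demo commands correspond to the keypoint normalization in `animate.py` (`normalize_kp`): the driving motion is applied as an offset to the source keypoints, scaled by the ratio of convex-hull volumes between the source and the initial driving frame. A simplified sketch of that update (assuming the demo forwards these flags to `normalize_kp`, and ignoring the Jacobian branch):

```python
# Simplified sketch of normalize_kp's relative-motion update (see animate.py).
import numpy as np
from scipy.spatial import ConvexHull

kp_source          = np.random.rand(15, 3)        # dummy (num_kp, 3) keypoints
kp_driving_initial = np.random.rand(15, 3)
kp_driving         = kp_driving_initial + 0.05 * np.random.randn(15, 3)

scale  = np.sqrt(ConvexHull(kp_source).volume) / np.sqrt(ConvexHull(kp_driving_initial).volume)
kp_new = kp_source + scale * (kp_driving - kp_driving_initial)
```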
65 | -------------------------------------------------------------------------------- /face-vid2vid/animate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | import imageio 8 | from scipy.spatial import ConvexHull 9 | import numpy as np 10 | 11 | from sync_batchnorm import DataParallelWithCallback 12 | 13 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, 14 | use_relative_movement=False, use_relative_jacobian=False): 15 | if adapt_movement_scale: 16 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume 17 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume 18 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 19 | else: 20 | adapt_movement_scale = 1 21 | 22 | kp_new = {k: v for k, v in kp_driving.items()} 23 | 24 | if use_relative_movement: 25 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) 26 | kp_value_diff *= adapt_movement_scale 27 | kp_new['value'] = kp_value_diff + kp_source['value'] 28 | 29 | if use_relative_jacobian: 30 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) 31 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) 32 | 33 | return kp_new 34 | -------------------------------------------------------------------------------- /face-vid2vid/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://github.com/hassony2/torch_videovision 3 | """ 4 | 5 | import numbers 6 | 7 | import random 8 | import numpy as np 9 | import PIL 10 | 11 | from skimage.transform import resize, rotate 12 | from skimage.util import pad 13 | import torchvision 14 | 15 | import warnings 16 | 17 | from skimage import img_as_ubyte, img_as_float 18 | 19 | 20 | def crop_clip(clip, min_h, min_w, h, w): 21 | if isinstance(clip[0], np.ndarray): 22 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 23 | 24 | elif isinstance(clip[0], PIL.Image.Image): 25 | cropped = [ 26 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 27 | ] 28 | else: 29 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 30 | 'but got list of {0}'.format(type(clip[0]))) 31 | return cropped 32 | 33 | 34 | def pad_clip(clip, h, w): 35 | im_h, im_w = clip[0].shape[:2] 36 | pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2) 37 | pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2) 38 | 39 | return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') 40 | 41 | 42 | def resize_clip(clip, size, interpolation='bilinear'): 43 | if isinstance(clip[0], np.ndarray): 44 | if isinstance(size, numbers.Number): 45 | im_h, im_w, im_c = clip[0].shape 46 | # Min spatial dim already matches minimal size 47 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 48 | and im_h == size): 49 | return clip 50 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 51 | size = (new_w, new_h) 52 | else: 53 | size = size[1], size[0] 54 | 55 | scaled = [ 56 | resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, 57 | mode='constant', anti_aliasing=True) for img in clip 58 | ] 59 | elif isinstance(clip[0], PIL.Image.Image): 60 | if isinstance(size, numbers.Number): 61 | im_w, im_h = clip[0].size 62 | # Min spatial dim already matches minimal size 
63 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 64 | and im_h == size): 65 | return clip 66 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 67 | size = (new_w, new_h) 68 | else: 69 | size = size[1], size[0] 70 | if interpolation == 'bilinear': 71 | pil_inter = PIL.Image.NEAREST 72 | else: 73 | pil_inter = PIL.Image.BILINEAR 74 | scaled = [img.resize(size, pil_inter) for img in clip] 75 | else: 76 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 77 | 'but got list of {0}'.format(type(clip[0]))) 78 | return scaled 79 | 80 | 81 | def get_resize_sizes(im_h, im_w, size): 82 | if im_w < im_h: 83 | ow = size 84 | oh = int(size * im_h / im_w) 85 | else: 86 | oh = size 87 | ow = int(size * im_w / im_h) 88 | return oh, ow 89 | 90 | 91 | class RandomFlip(object): 92 | def __init__(self, time_flip=False, horizontal_flip=False): 93 | self.time_flip = time_flip 94 | self.horizontal_flip = horizontal_flip 95 | 96 | def __call__(self, clip): 97 | if random.random() < 0.5 and self.time_flip: 98 | return clip[::-1] 99 | if random.random() < 0.5 and self.horizontal_flip: 100 | return [np.fliplr(img) for img in clip] 101 | 102 | return clip 103 | 104 | 105 | class RandomResize(object): 106 | """Resizes a list of (H x W x C) numpy.ndarray to the final size 107 | The larger the original image is, the more times it takes to 108 | interpolate 109 | Args: 110 | interpolation (str): Can be one of 'nearest', 'bilinear' 111 | defaults to nearest 112 | size (tuple): (widht, height) 113 | """ 114 | 115 | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'): 116 | self.ratio = ratio 117 | self.interpolation = interpolation 118 | 119 | def __call__(self, clip): 120 | scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) 121 | 122 | if isinstance(clip[0], np.ndarray): 123 | im_h, im_w, im_c = clip[0].shape 124 | elif isinstance(clip[0], PIL.Image.Image): 125 | im_w, im_h = clip[0].size 126 | 127 | new_w = int(im_w * scaling_factor) 128 | new_h = int(im_h * scaling_factor) 129 | new_size = (new_w, new_h) 130 | resized = resize_clip( 131 | clip, new_size, interpolation=self.interpolation) 132 | 133 | return resized 134 | 135 | 136 | class RandomCrop(object): 137 | """Extract random crop at the same location for a list of videos 138 | Args: 139 | size (sequence or int): Desired output size for the 140 | crop in format (h, w) 141 | """ 142 | 143 | def __init__(self, size): 144 | if isinstance(size, numbers.Number): 145 | size = (size, size) 146 | 147 | self.size = size 148 | 149 | def __call__(self, clip): 150 | """ 151 | Args: 152 | img (PIL.Image or numpy.ndarray): List of videos to be cropped 153 | in format (h, w, c) in numpy.ndarray 154 | Returns: 155 | PIL.Image or numpy.ndarray: Cropped list of videos 156 | """ 157 | h, w = self.size 158 | if isinstance(clip[0], np.ndarray): 159 | im_h, im_w, im_c = clip[0].shape 160 | elif isinstance(clip[0], PIL.Image.Image): 161 | im_w, im_h = clip[0].size 162 | else: 163 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 164 | 'but got list of {0}'.format(type(clip[0]))) 165 | 166 | clip = pad_clip(clip, h, w) 167 | im_h, im_w = clip.shape[1:3] 168 | x1 = 0 if h == im_h else random.randint(0, im_w - w) 169 | y1 = 0 if w == im_w else random.randint(0, im_h - h) 170 | cropped = crop_clip(clip, y1, x1, h, w) 171 | 172 | return cropped 173 | 174 | 175 | class RandomRotation(object): 176 | """Rotate entire clip randomly by a random angle within 177 | given bounds 178 | Args: 179 | degrees (sequence or int): Range of degrees to select 
from 180 | If degrees is a number instead of sequence like (min, max), 181 | the range of degrees, will be (-degrees, +degrees). 182 | """ 183 | 184 | def __init__(self, degrees): 185 | if isinstance(degrees, numbers.Number): 186 | if degrees < 0: 187 | raise ValueError('If degrees is a single number,' 188 | 'must be positive') 189 | degrees = (-degrees, degrees) 190 | else: 191 | if len(degrees) != 2: 192 | raise ValueError('If degrees is a sequence,' 193 | 'it must be of len 2.') 194 | 195 | self.degrees = degrees 196 | 197 | def __call__(self, clip): 198 | """ 199 | Args: 200 | img (PIL.Image or numpy.ndarray): List of videos to be cropped 201 | in format (h, w, c) in numpy.ndarray 202 | Returns: 203 | PIL.Image or numpy.ndarray: Cropped list of videos 204 | """ 205 | angle = random.uniform(self.degrees[0], self.degrees[1]) 206 | if isinstance(clip[0], np.ndarray): 207 | rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] 208 | elif isinstance(clip[0], PIL.Image.Image): 209 | rotated = [img.rotate(angle) for img in clip] 210 | else: 211 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 212 | 'but got list of {0}'.format(type(clip[0]))) 213 | 214 | return rotated 215 | 216 | 217 | class ColorJitter(object): 218 | """Randomly change the brightness, contrast and saturation and hue of the clip 219 | Args: 220 | brightness (float): How much to jitter brightness. brightness_factor 221 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 222 | contrast (float): How much to jitter contrast. contrast_factor 223 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 224 | saturation (float): How much to jitter saturation. saturation_factor 225 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 226 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 227 | [-hue, hue]. Should be >=0 and <= 0.5. 
228 | """ 229 | 230 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 231 | self.brightness = brightness 232 | self.contrast = contrast 233 | self.saturation = saturation 234 | self.hue = hue 235 | 236 | def get_params(self, brightness, contrast, saturation, hue): 237 | if brightness > 0: 238 | brightness_factor = random.uniform( 239 | max(0, 1 - brightness), 1 + brightness) 240 | else: 241 | brightness_factor = None 242 | 243 | if contrast > 0: 244 | contrast_factor = random.uniform( 245 | max(0, 1 - contrast), 1 + contrast) 246 | else: 247 | contrast_factor = None 248 | 249 | if saturation > 0: 250 | saturation_factor = random.uniform( 251 | max(0, 1 - saturation), 1 + saturation) 252 | else: 253 | saturation_factor = None 254 | 255 | if hue > 0: 256 | hue_factor = random.uniform(-hue, hue) 257 | else: 258 | hue_factor = None 259 | return brightness_factor, contrast_factor, saturation_factor, hue_factor 260 | 261 | def __call__(self, clip): 262 | """ 263 | Args: 264 | clip (list): list of PIL.Image 265 | Returns: 266 | list PIL.Image : list of transformed PIL.Image 267 | """ 268 | if isinstance(clip[0], np.ndarray): 269 | brightness, contrast, saturation, hue = self.get_params( 270 | self.brightness, self.contrast, self.saturation, self.hue) 271 | 272 | # Create img transform function sequence 273 | img_transforms = [] 274 | if brightness is not None: 275 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 276 | if saturation is not None: 277 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 278 | if hue is not None: 279 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 280 | if contrast is not None: 281 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 282 | random.shuffle(img_transforms) 283 | img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, 284 | img_as_float] 285 | 286 | with warnings.catch_warnings(): 287 | warnings.simplefilter("ignore") 288 | jittered_clip = [] 289 | for img in clip: 290 | jittered_img = img 291 | for func in img_transforms: 292 | jittered_img = func(jittered_img) 293 | jittered_clip.append(jittered_img.astype('float32')) 294 | elif isinstance(clip[0], PIL.Image.Image): 295 | brightness, contrast, saturation, hue = self.get_params( 296 | self.brightness, self.contrast, self.saturation, self.hue) 297 | 298 | # Create img transform function sequence 299 | img_transforms = [] 300 | if brightness is not None: 301 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 302 | if saturation is not None: 303 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 304 | if hue is not None: 305 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 306 | if contrast is not None: 307 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 308 | random.shuffle(img_transforms) 309 | 310 | # Apply to all videos 311 | jittered_clip = [] 312 | for img in clip: 313 | for func in img_transforms: 314 | jittered_img = func(img) 315 | jittered_clip.append(jittered_img) 316 | 317 | else: 318 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 319 | 'but got list of {0}'.format(type(clip[0]))) 320 | return jittered_clip 321 | 322 | 
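# Usage sketch (illustrative only). AllAugmentationTransform below chains the configured
# transforms in a fixed order -- flip, rotation, resize, crop, color jitter -- and applies each
# one to the whole clip (a list of H x W x C frames), so both frames sampled from a video
# receive identical augmentation. With the parameters from config/vox-256-spade.yaml:
#
#     transform = AllAugmentationTransform(
#         flip_param={'horizontal_flip': True, 'time_flip': True},
#         jitter_param={'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1},
#     )
#     augmented_clip = transform(clip)  # clip: list of float32 numpy frames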
323 | class AllAugmentationTransform: 324 | def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None): 325 | self.transforms = [] 326 | 327 | if flip_param is not None: 328 | self.transforms.append(RandomFlip(**flip_param)) 329 | 330 | if rotation_param is not None: 331 | self.transforms.append(RandomRotation(**rotation_param)) 332 | 333 | if resize_param is not None: 334 | self.transforms.append(RandomResize(**resize_param)) 335 | 336 | if crop_param is not None: 337 | self.transforms.append(RandomCrop(**crop_param)) 338 | 339 | if jitter_param is not None: 340 | self.transforms.append(ColorJitter(**jitter_param)) 341 | 342 | def __call__(self, clip): 343 | for t in self.transforms: 344 | clip = t(clip) 345 | return clip 346 | -------------------------------------------------------------------------------- /face-vid2vid/config/vox-256-spade.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: None 6 | augmentation_params: 7 | flip_param: 8 | horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 15 20 | image_channel: 3 21 | feature_channel: 32 22 | estimate_jacobian: False 23 | kp_detector_params: 24 | temperature: 0.1 25 | block_expansion: 32 26 | max_features: 1024 27 | scale_factor: 0.25 28 | num_blocks: 5 29 | reshape_channel: 16384 # 16384 = 1024 * 16 30 | reshape_depth: 16 31 | he_estimator_params: 32 | block_expansion: 64 33 | max_features: 2048 34 | num_bins: 66 35 | generator_params: 36 | block_expansion: 64 37 | max_features: 512 38 | num_down_blocks: 2 39 | reshape_channel: 32 40 | reshape_depth: 16 # 512 = 32 * 16 41 | num_resblocks: 6 42 | estimate_occlusion_map: True 43 | dense_motion_params: 44 | block_expansion: 32 45 | max_features: 1024 46 | num_blocks: 5 47 | # reshape_channel: 32 48 | reshape_depth: 16 49 | compress: 4 50 | discriminator_params: 51 | scales: [1] 52 | block_expansion: 32 53 | max_features: 512 54 | num_blocks: 4 55 | sn: True 56 | 57 | train_params: 58 | num_epochs: 200 59 | num_repeats: 75 60 | epoch_milestones: [180,] 61 | lr_generator: 2.0e-4 62 | lr_discriminator: 2.0e-4 63 | lr_kp_detector: 2.0e-4 64 | lr_he_estimator: 2.0e-4 65 | gan_mode: 'hinge' # hinge or ls 66 | batch_size: 1 67 | scales: [1, 0.5, 0.25, 0.125] 68 | checkpoint_freq: 60 69 | hopenet_snapshot: './checkpoints/hopenet_robust_alpha1.pkl' 70 | transform_params: 71 | sigma_affine: 0.05 72 | sigma_tps: 0.005 73 | points_tps: 5 74 | loss_weights: 75 | generator_gan: 1 76 | discriminator_gan: 1 77 | feature_matching: [10, 10, 10, 10] 78 | perceptual: [10, 10, 10, 10, 10] 79 | equivariance_value: 10 80 | equivariance_jacobian: 0 81 | keypoint: 10 82 | headpose: 20 83 | expression: 5 84 | 85 | visualizer_params: 86 | kp_size: 5 87 | draw_border: True 88 | colormap: 'gist_rainbow' 89 | -------------------------------------------------------------------------------- /face-vid2vid/crop-video.py: -------------------------------------------------------------------------------- 1 | import face_alignment 2 | import skimage.io 3 | import numpy 4 | from argparse import ArgumentParser 5 | from skimage import img_as_ubyte 6 | from skimage.transform import resize 7 | from tqdm import tqdm 8 | import os 9 | import imageio 10 | import numpy as np 11 | import warnings 12 | 
warnings.filterwarnings("ignore") 13 | 14 | def extract_bbox(frame, fa): 15 | if max(frame.shape[0], frame.shape[1]) > 640: 16 | scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0 17 | frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor))) 18 | frame = img_as_ubyte(frame) 19 | else: 20 | scale_factor = 1 21 | frame = frame[..., :3] 22 | bboxes = fa.face_detector.detect_from_image(frame[..., ::-1]) 23 | if len(bboxes) == 0: 24 | return [] 25 | return np.array(bboxes)[:, :-1] * scale_factor 26 | 27 | 28 | 29 | def bb_intersection_over_union(boxA, boxB): 30 | xA = max(boxA[0], boxB[0]) 31 | yA = max(boxA[1], boxB[1]) 32 | xB = min(boxA[2], boxB[2]) 33 | yB = min(boxA[3], boxB[3]) 34 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 35 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 36 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 37 | iou = interArea / float(boxAArea + boxBArea - interArea) 38 | return iou 39 | 40 | 41 | def join(tube_bbox, bbox): 42 | xA = min(tube_bbox[0], bbox[0]) 43 | yA = min(tube_bbox[1], bbox[1]) 44 | xB = max(tube_bbox[2], bbox[2]) 45 | yB = max(tube_bbox[3], bbox[3]) 46 | return (xA, yA, xB, yB) 47 | 48 | 49 | def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1): 50 | left, top, right, bot = tube_bbox 51 | width = right - left 52 | height = bot - top 53 | 54 | #Computing aspect preserving bbox 55 | width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) 56 | height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) 57 | 58 | left = int(left - width_increase * width) 59 | top = int(top - height_increase * height) 60 | right = int(right + width_increase * width) 61 | bot = int(bot + height_increase * height) 62 | 63 | top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1]) 64 | h, w = bot - top, right - left 65 | 66 | start = start / fps 67 | end = end / fps 68 | time = end - start 69 | 70 | scale = f'{image_shape[0]}:{image_shape[1]}' 71 | 72 | return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4' 73 | 74 | 75 | def compute_bbox_trajectories(trajectories, fps, frame_shape, args): 76 | commands = [] 77 | for i, (bbox, tube_bbox, start, end) in enumerate(trajectories): 78 | if (end - start) > args.min_frames: 79 | command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase) 80 | commands.append(command) 81 | return commands 82 | 83 | 84 | def process_video(args): 85 | device = 'cpu' if args.cpu else 'cuda' 86 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device) 87 | video = imageio.get_reader(args.inp) 88 | 89 | trajectories = [] 90 | previous_frame = None 91 | fps = video.get_meta_data()['fps'] 92 | commands = [] 93 | try: 94 | for i, frame in tqdm(enumerate(video)): 95 | frame_shape = frame.shape 96 | bboxes = extract_bbox(frame, fa) 97 | ## For each trajectory check the criterion 98 | not_valid_trajectories = [] 99 | valid_trajectories = [] 100 | 101 | for trajectory in trajectories: 102 | tube_bbox = trajectory[0] 103 | intersection = 0 104 | for bbox in bboxes: 105 | intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox)) 106 | if intersection > args.iou_with_initial: 107 | valid_trajectories.append(trajectory) 108 | 
else: 109 | not_valid_trajectories.append(trajectory) 110 | 111 | commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args) 112 | trajectories = valid_trajectories 113 | 114 | ## Assign bbox to trajectories, create new trajectories 115 | for bbox in bboxes: 116 | intersection = 0 117 | current_trajectory = None 118 | for trajectory in trajectories: 119 | tube_bbox = trajectory[0] 120 | current_intersection = bb_intersection_over_union(tube_bbox, bbox) 121 | if intersection < current_intersection and current_intersection > args.iou_with_initial: 122 | intersection = bb_intersection_over_union(tube_bbox, bbox) 123 | current_trajectory = trajectory 124 | 125 | ## Create new trajectory 126 | if current_trajectory is None: 127 | trajectories.append([bbox, bbox, i, i]) 128 | else: 129 | current_trajectory[3] = i 130 | current_trajectory[1] = join(current_trajectory[1], bbox) 131 | 132 | 133 | except IndexError as e: 134 | raise (e) 135 | 136 | commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args) 137 | return commands 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = ArgumentParser() 142 | 143 | parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))), 144 | help="Image shape") 145 | parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount') 146 | parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox") 147 | parser.add_argument("--inp", required=True, help='Input image or video') 148 | parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames') 149 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") 150 | 151 | 152 | args = parser.parse_args() 153 | 154 | commands = process_video(args) 155 | for command in commands: 156 | print (command) 157 | 158 | -------------------------------------------------------------------------------- /face-vid2vid/frames_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io, img_as_float32 3 | from skimage.color import gray2rgb 4 | from sklearn.model_selection import train_test_split 5 | from imageio import mimread 6 | 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | import pandas as pd 10 | from augmentation import AllAugmentationTransform 11 | import glob 12 | 13 | 14 | def read_video(name, frame_shape): 15 | """ 16 | Read video which can be: 17 | - an image of concatenated frames 18 | - '.mp4' and'.gif' 19 | - folder with videos 20 | """ 21 | 22 | if os.path.isdir(name): 23 | frames = sorted(os.listdir(name)) 24 | num_frames = len(frames) 25 | video_array = np.array( 26 | [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]) 27 | elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): 28 | image = io.imread(name) 29 | 30 | if len(image.shape) == 2 or image.shape[2] == 1: 31 | image = gray2rgb(image) 32 | 33 | if image.shape[2] == 4: 34 | image = image[..., :3] 35 | 36 | image = img_as_float32(image) 37 | 38 | video_array = np.moveaxis(image, 1, 0) 39 | 40 | video_array = video_array.reshape((-1,) + frame_shape) 41 | video_array = np.moveaxis(video_array, 1, 2) 42 | elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): 43 | video = np.array(mimread(name)) 44 | if len(video.shape) == 3: 45 | video = 
np.array([gray2rgb(frame) for frame in video]) 46 | if video.shape[-1] == 4: 47 | video = video[..., :3] 48 | video_array = img_as_float32(video) 49 | else: 50 | raise Exception("Unknown file extensions %s" % name) 51 | 52 | return video_array 53 | 54 | 55 | class FramesDataset(Dataset): 56 | """ 57 | Dataset of videos, each video can be represented as: 58 | - an image of concatenated frames 59 | - '.mp4' or '.gif' 60 | - folder with all frames 61 | """ 62 | 63 | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, 64 | random_seed=0, pairs_list=None, augmentation_params=None): 65 | self.root_dir = root_dir 66 | self.videos = os.listdir(root_dir) 67 | self.frame_shape = tuple(frame_shape) 68 | self.pairs_list = pairs_list 69 | self.id_sampling = id_sampling 70 | if os.path.exists(os.path.join(root_dir, 'train')): 71 | assert os.path.exists(os.path.join(root_dir, 'test')) 72 | print("Use predefined train-test split.") 73 | if id_sampling: 74 | train_videos = {os.path.basename(video).split('#')[0] for video in 75 | os.listdir(os.path.join(root_dir, 'train'))} 76 | train_videos = list(train_videos) 77 | else: 78 | train_videos = os.listdir(os.path.join(root_dir, 'train')) 79 | test_videos = os.listdir(os.path.join(root_dir, 'test')) 80 | self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') 81 | else: 82 | print("Use random train-test split.") 83 | train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) 84 | 85 | if is_train: 86 | self.videos = train_videos 87 | else: 88 | self.videos = test_videos 89 | 90 | self.is_train = is_train 91 | 92 | if self.is_train: 93 | self.transform = AllAugmentationTransform(**augmentation_params) 94 | else: 95 | self.transform = None 96 | 97 | def __len__(self): 98 | return len(self.videos) 99 | 100 | def __getitem__(self, idx): 101 | if self.is_train and self.id_sampling: 102 | name = self.videos[idx] 103 | path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) 104 | else: 105 | name = self.videos[idx] 106 | path = os.path.join(self.root_dir, name) 107 | 108 | video_name = os.path.basename(path) 109 | 110 | if self.is_train and os.path.isdir(path): 111 | frames = os.listdir(path) 112 | num_frames = len(frames) 113 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) 114 | video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx] 115 | else: 116 | video_array = read_video(path, frame_shape=self.frame_shape) 117 | num_frames = len(video_array) 118 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range( 119 | num_frames) 120 | video_array = video_array[frame_idx] 121 | 122 | if self.transform is not None: 123 | video_array = self.transform(video_array) 124 | 125 | out = {} 126 | if self.is_train: 127 | source = np.array(video_array[0], dtype='float32') 128 | driving = np.array(video_array[1], dtype='float32') 129 | 130 | out['driving'] = driving.transpose((2, 0, 1)) 131 | out['source'] = source.transpose((2, 0, 1)) 132 | else: 133 | video = np.array(video_array, dtype='float32') 134 | out['video'] = video.transpose((3, 0, 1, 2)) 135 | 136 | out['name'] = video_name 137 | 138 | return out 139 | 140 | 141 | class DatasetRepeater(Dataset): 142 | """ 143 | Pass several times over the same dataset for better i/o performance 144 | """ 145 | 146 | def __init__(self, dataset, num_repeats=100): 147 | self.dataset = dataset 148 | 
self.num_repeats = num_repeats 149 | 150 | def __len__(self): 151 | return self.num_repeats * self.dataset.__len__() 152 | 153 | def __getitem__(self, idx): 154 | return self.dataset[idx % self.dataset.__len__()] 155 | -------------------------------------------------------------------------------- /face-vid2vid/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import imageio 5 | 6 | import os 7 | from skimage.draw import circle_perimeter 8 | 9 | import matplotlib.pyplot as plt 10 | import collections 11 | 12 | 13 | class Logger: 14 | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name="log.txt"): 15 | self.loss_list = [] 16 | self.cpk_dir = log_dir 17 | self.visualizations_dir = os.path.join(log_dir, "train-vis") 18 | if not os.path.exists(self.visualizations_dir): 19 | os.makedirs(self.visualizations_dir) 20 | self.log_file = open(os.path.join(log_dir, log_file_name), "a") 21 | self.zfill_num = zfill_num 22 | self.visualizer = Visualizer(**visualizer_params) 23 | self.checkpoint_freq = checkpoint_freq 24 | self.epoch = 0 25 | self.best_loss = float("inf") 26 | self.names = None 27 | 28 | def log_scores(self, loss_names): 29 | loss_mean = np.array(self.loss_list).mean(axis=0) 30 | 31 | loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) 32 | loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string 33 | 34 | print(loss_string, file=self.log_file) 35 | self.loss_list = [] 36 | self.log_file.flush() 37 | 38 | def visualize_rec(self, inp, out): 39 | image = self.visualizer.visualize(inp["driving"], inp["source"], out) 40 | imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image) 41 | 42 | def save_cpk(self, emergent=False): 43 | cpk = {k: v.state_dict() for k, v in self.models.items()} 44 | cpk["epoch"] = self.epoch 45 | cpk_path = os.path.join(self.cpk_dir, "%s-checkpoint.pth.tar" % str(self.epoch).zfill(self.zfill_num)) 46 | if not (os.path.exists(cpk_path) and emergent): 47 | torch.save(cpk, cpk_path) 48 | 49 | @staticmethod 50 | def load_cpk( 51 | checkpoint_path, 52 | generator=None, 53 | discriminator=None, 54 | kp_detector=None, 55 | he_estimator=None, 56 | optimizer_generator=None, 57 | optimizer_discriminator=None, 58 | optimizer_kp_detector=None, 59 | optimizer_he_estimator=None, 60 | ): 61 | checkpoint = torch.load(checkpoint_path) 62 | if generator is not None: 63 | generator.load_state_dict(checkpoint["generator"]) 64 | if kp_detector is not None: 65 | kp_detector.load_state_dict(checkpoint["kp_detector"]) 66 | if he_estimator is not None: 67 | he_estimator.load_state_dict(checkpoint["he_estimator"]) 68 | if discriminator is not None: 69 | try: 70 | discriminator.load_state_dict(checkpoint["discriminator"]) 71 | except: 72 | print("No discriminator in the state-dict. Dicriminator will be randomly initialized") 73 | if optimizer_generator is not None: 74 | optimizer_generator.load_state_dict(checkpoint["optimizer_generator"]) 75 | if optimizer_discriminator is not None: 76 | try: 77 | optimizer_discriminator.load_state_dict(checkpoint["optimizer_discriminator"]) 78 | except RuntimeError as e: 79 | print("No discriminator optimizer in the state-dict. 
Optimizer will be not initialized") 80 | if optimizer_kp_detector is not None: 81 | optimizer_kp_detector.load_state_dict(checkpoint["optimizer_kp_detector"]) 82 | if optimizer_he_estimator is not None: 83 | optimizer_he_estimator.load_state_dict(checkpoint["optimizer_he_estimator"]) 84 | 85 | return checkpoint["epoch"] 86 | 87 | def __enter__(self): 88 | return self 89 | 90 | def __exit__(self, exc_type, exc_val, exc_tb): 91 | if "models" in self.__dict__: 92 | self.save_cpk() 93 | self.log_file.close() 94 | 95 | def log_iter(self, losses): 96 | losses = collections.OrderedDict(losses.items()) 97 | if self.names is None: 98 | self.names = list(losses.keys()) 99 | self.loss_list.append(list(losses.values())) 100 | 101 | def log_epoch(self, epoch, models, inp, out): 102 | self.epoch = epoch 103 | self.models = models 104 | if (self.epoch + 1) % self.checkpoint_freq == 0: 105 | self.save_cpk() 106 | self.log_scores(self.names) 107 | self.visualize_rec(inp, out) 108 | 109 | 110 | class Visualizer: 111 | def __init__(self, kp_size=5, draw_border=False, colormap="gist_rainbow"): 112 | self.kp_size = kp_size 113 | self.draw_border = draw_border 114 | self.colormap = plt.get_cmap(colormap) 115 | 116 | def draw_image_with_kp(self, image, kp_array): 117 | image = np.copy(image) 118 | spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] 119 | kp_array = spatial_size * (kp_array + 1) / 2 120 | num_kp = kp_array.shape[0] 121 | for kp_ind, kp in enumerate(kp_array): 122 | rr, cc = circle_perimeter(kp[1], kp[0], self.kp_size, shape=image.shape[:2]) 123 | image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] 124 | return image 125 | 126 | def create_image_column_with_kp(self, images, kp): 127 | image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) 128 | return self.create_image_column(image_array) 129 | 130 | def create_image_column(self, images): 131 | if self.draw_border: 132 | images = np.copy(images) 133 | images[:, :, [0, -1]] = (1, 1, 1) 134 | images[:, :, [0, -1]] = (1, 1, 1) 135 | return np.concatenate(list(images), axis=0) 136 | 137 | def create_image_grid(self, *args): 138 | out = [] 139 | for arg in args: 140 | if type(arg) == tuple: 141 | out.append(self.create_image_column_with_kp(arg[0], arg[1])) 142 | else: 143 | out.append(self.create_image_column(arg)) 144 | return np.concatenate(out, axis=1) 145 | 146 | def visualize(self, driving, source, out): 147 | images = [] 148 | 149 | # Source image with keypoints 150 | source = source.data.cpu() 151 | kp_source = out["kp_source"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d 152 | source = np.transpose(source, [0, 2, 3, 1]) 153 | images.append((source, kp_source)) 154 | 155 | # Equivariance visualization 156 | if "transformed_frame" in out: 157 | transformed = out["transformed_frame"].data.cpu().numpy() 158 | transformed = np.transpose(transformed, [0, 2, 3, 1]) 159 | transformed_kp = out["transformed_kp"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d 160 | images.append((transformed, transformed_kp)) 161 | 162 | # Driving image with keypoints 163 | kp_driving = out["kp_driving"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d 164 | driving = driving.data.cpu().numpy() 165 | driving = np.transpose(driving, [0, 2, 3, 1]) 166 | images.append((driving, kp_driving)) 167 | 168 | # Result 169 | prediction = out["prediction"].data.cpu().numpy() 170 | prediction = np.transpose(prediction, [0, 2, 3, 1]) 171 | images.append(prediction) 172 | 173 | ## Occlusion map 174 | if "occlusion_map" in out: 175 | 
occlusion_map = out["occlusion_map"].data.cpu().repeat(1, 3, 1, 1) 176 | occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() 177 | occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) 178 | images.append(occlusion_map) 179 | 180 | ## Mask 181 | if "mask" in out: 182 | for i in range(out["mask"].shape[1]): 183 | mask = out["mask"][:, i : (i + 1)].data.cpu().sum(2).repeat(1, 3, 1, 1) # (n, 3, h, w) 184 | # mask = F.softmax(mask.view(mask.shape[0], mask.shape[1], -1), dim=2).view(mask.shape) 185 | mask = F.interpolate(mask, size=source.shape[1:3]).numpy() 186 | mask = np.transpose(mask, [0, 2, 3, 1]) 187 | 188 | if i != 0: 189 | color = np.array(self.colormap((i - 1) / (out["mask"].shape[1] - 1)))[:3] 190 | else: 191 | color = np.array((0, 0, 0)) 192 | 193 | color = color.reshape((1, 1, 1, 3)) 194 | 195 | if i != 0: 196 | images.append(mask * color) 197 | else: 198 | images.append(mask) 199 | 200 | image = self.create_image_grid(*images) 201 | image = (255 * image).astype(np.uint8) 202 | return image 203 | -------------------------------------------------------------------------------- /face-vid2vid/modules/dense_motion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | from modules.util import Hourglass, make_coordinate_grid, kp2gaussian 5 | 6 | from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d 7 | 8 | 9 | class DenseMotionNetwork(nn.Module): 10 | """ 11 | Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving 12 | """ 13 | 14 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, 15 | estimate_occlusion_map=False): 16 | super(DenseMotionNetwork, self).__init__() 17 | # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) 18 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) 19 | 20 | self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) 21 | 22 | self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) 23 | self.norm = BatchNorm3d(compress, affine=True) 24 | 25 | if estimate_occlusion_map: 26 | # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) 27 | self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) 28 | else: 29 | self.occlusion = None 30 | 31 | self.num_kp = num_kp 32 | 33 | 34 | def create_sparse_motions(self, feature, kp_driving, kp_source): 35 | bs, _, d, h, w = feature.shape 36 | identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) 37 | identity_grid = identity_grid.view(1, 1, d, h, w, 3) 38 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) 39 | 40 | k = coordinate_grid.shape[1] 41 | 42 | # if 'jacobian' in kp_driving: 43 | if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: 44 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) 45 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) 46 | jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) 47 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) 48 | coordinate_grid = coordinate_grid.squeeze(-1) 49 | 
''' 50 | if 'rot' in kp_driving: 51 | rot_s = kp_source['rot'] 52 | rot_d = kp_driving['rot'] 53 | rot = torch.einsum('bij, bjk->bki', rot_s, torch.inverse(rot_d)) 54 | rot = rot.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) 55 | rot = rot.repeat(1, k, d, h, w, 1, 1) 56 | # print(rot.shape) 57 | coordinate_grid = torch.matmul(rot, coordinate_grid.unsqueeze(-1)) 58 | coordinate_grid = coordinate_grid.squeeze(-1) 59 | # print(coordinate_grid.shape) 60 | ''' 61 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) 62 | 63 | #adding background feature 64 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) 65 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) 66 | 67 | # sparse_motions = driving_to_source 68 | 69 | return sparse_motions 70 | 71 | def create_deformed_feature(self, feature, sparse_motions): 72 | bs, _, d, h, w = feature.shape 73 | feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) 74 | feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) 75 | sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) 76 | sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) 77 | sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) 78 | return sparse_deformed 79 | 80 | def create_heatmap_representations(self, feature, kp_driving, kp_source): 81 | spatial_size = feature.shape[3:] 82 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) 83 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) 84 | heatmap = gaussian_driving - gaussian_source 85 | 86 | # adding background feature 87 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) 88 | heatmap = torch.cat([zeros, heatmap], dim=1) 89 | heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) 90 | return heatmap 91 | 92 | def forward(self, feature, kp_driving, kp_source): 93 | bs, _, d, h, w = feature.shape 94 | 95 | feature = self.compress(feature) 96 | feature = self.norm(feature) 97 | feature = F.relu(feature) 98 | 99 | out_dict = dict() 100 | sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) 101 | deformed_feature = self.create_deformed_feature(feature, sparse_motion) 102 | 103 | heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) 104 | 105 | input = torch.cat([heatmap, deformed_feature], dim=2) 106 | input = input.view(bs, -1, d, h, w) 107 | 108 | # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) 109 | 110 | prediction = self.hourglass(input) 111 | 112 | mask = self.mask(prediction) 113 | mask = F.softmax(mask, dim=1) 114 | out_dict['mask'] = mask 115 | mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) 116 | sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) 117 | deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) 118 | deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) 119 | 120 | out_dict['deformation'] = deformation 121 | 122 | if self.occlusion: 123 | bs, c, d, h, w = prediction.shape 124 | prediction = prediction.view(bs, -1, h, w) 125 | occlusion_map = torch.sigmoid(self.occlusion(prediction)) 126 | out_dict['occlusion_map'] 
= occlusion_map 127 | 128 | return out_dict 129 | -------------------------------------------------------------------------------- /face-vid2vid/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from modules.util import kp2gaussian 4 | import torch 5 | 6 | 7 | class DownBlock2d(nn.Module): 8 | """ 9 | Simple block for processing video (encoder). 10 | """ 11 | 12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): 13 | super(DownBlock2d, self).__init__() 14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 15 | 16 | if sn: 17 | self.conv = nn.utils.spectral_norm(self.conv) 18 | 19 | if norm: 20 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 21 | else: 22 | self.norm = None 23 | self.pool = pool 24 | 25 | def forward(self, x): 26 | out = x 27 | out = self.conv(out) 28 | if self.norm: 29 | out = self.norm(out) 30 | out = F.leaky_relu(out, 0.2) 31 | if self.pool: 32 | out = F.avg_pool2d(out, (2, 2)) 33 | return out 34 | 35 | 36 | class Discriminator(nn.Module): 37 | """ 38 | Discriminator similar to Pix2Pix 39 | """ 40 | 41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, 42 | sn=False, **kwargs): 43 | super(Discriminator, self).__init__() 44 | 45 | down_blocks = [] 46 | for i in range(num_blocks): 47 | down_blocks.append( 48 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), 49 | min(max_features, block_expansion * (2 ** (i + 1))), 50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) 51 | 52 | self.down_blocks = nn.ModuleList(down_blocks) 53 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 54 | if sn: 55 | self.conv = nn.utils.spectral_norm(self.conv) 56 | 57 | def forward(self, x): 58 | feature_maps = [] 59 | out = x 60 | 61 | for down_block in self.down_blocks: 62 | feature_maps.append(down_block(out)) 63 | out = feature_maps[-1] 64 | prediction_map = self.conv(out) 65 | 66 | return feature_maps, prediction_map 67 | 68 | 69 | class MultiScaleDiscriminator(nn.Module): 70 | """ 71 | Multi-scale (scale) discriminator 72 | """ 73 | 74 | def __init__(self, scales=(), **kwargs): 75 | super(MultiScaleDiscriminator, self).__init__() 76 | self.scales = scales 77 | discs = {} 78 | for scale in scales: 79 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) 80 | self.discs = nn.ModuleDict(discs) 81 | 82 | def forward(self, x): 83 | out_dict = {} 84 | for scale, disc in self.discs.items(): 85 | scale = str(scale).replace('-', '.') 86 | key = 'prediction_' + scale 87 | feature_maps, prediction_map = disc(x[key]) 88 | out_dict['feature_maps_' + scale] = feature_maps 89 | out_dict['prediction_map_' + scale] = prediction_map 90 | return out_dict 91 | -------------------------------------------------------------------------------- /face-vid2vid/modules/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock 5 | from modules.dense_motion import DenseMotionNetwork 6 | 7 | 8 | class OcclusionAwareGenerator(nn.Module): 9 | """ 10 | Generator follows NVIDIA architecture. 
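    The source image is encoded into a 2D feature map, reshaped into a 3D feature volume,
    warped with the dense motion field predicted from the source and driving keypoints,
    multiplied by the predicted occlusion map, and decoded back into the output frame
    (see forward()).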
11 | """ 12 | 13 | def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, 14 | num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): 15 | super(OcclusionAwareGenerator, self).__init__() 16 | 17 | if dense_motion_params is not None: 18 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, 19 | estimate_occlusion_map=estimate_occlusion_map, 20 | **dense_motion_params) 21 | else: 22 | self.dense_motion_network = None 23 | 24 | self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 25 | 26 | down_blocks = [] 27 | for i in range(num_down_blocks): 28 | in_features = min(max_features, block_expansion * (2 ** i)) 29 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 30 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 31 | self.down_blocks = nn.ModuleList(down_blocks) 32 | 33 | self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) 34 | 35 | self.reshape_channel = reshape_channel 36 | self.reshape_depth = reshape_depth 37 | 38 | self.resblocks_3d = torch.nn.Sequential() 39 | for i in range(num_resblocks): 40 | self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) 41 | 42 | out_features = block_expansion * (2 ** (num_down_blocks)) 43 | self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) 44 | self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) 45 | 46 | self.resblocks_2d = torch.nn.Sequential() 47 | for i in range(num_resblocks): 48 | self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) 49 | 50 | up_blocks = [] 51 | for i in range(num_down_blocks): 52 | in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) 53 | out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) 54 | up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 55 | self.up_blocks = nn.ModuleList(up_blocks) 56 | 57 | self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) 58 | self.estimate_occlusion_map = estimate_occlusion_map 59 | self.image_channel = image_channel 60 | 61 | def deform_input(self, inp, deformation): 62 | _, d_old, h_old, w_old, _ = deformation.shape 63 | _, _, d, h, w = inp.shape 64 | if d_old != d or h_old != h or w_old != w: 65 | deformation = deformation.permute(0, 4, 1, 2, 3) 66 | deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') 67 | deformation = deformation.permute(0, 2, 3, 4, 1) 68 | return F.grid_sample(inp, deformation) 69 | 70 | def forward(self, source_image, kp_driving, kp_source): 71 | # Encoding (downsampling) part 72 | out = self.first(source_image) 73 | for i in range(len(self.down_blocks)): 74 | out = self.down_blocks[i](out) 75 | out = self.second(out) 76 | bs, c, h, w = out.shape 77 | # print(out.shape) 78 | feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) 79 | feature_3d = self.resblocks_3d(feature_3d) 80 | 81 | # Transforming feature representation according to deformation and occlusion 82 | output_dict = {} 83 | if self.dense_motion_network is not None: 84 | dense_motion = 
self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, 85 | kp_source=kp_source) 86 | output_dict['mask'] = dense_motion['mask'] 87 | 88 | if 'occlusion_map' in dense_motion: 89 | occlusion_map = dense_motion['occlusion_map'] 90 | output_dict['occlusion_map'] = occlusion_map 91 | else: 92 | occlusion_map = None 93 | deformation = dense_motion['deformation'] 94 | out = self.deform_input(feature_3d, deformation) 95 | 96 | bs, c, d, h, w = out.shape 97 | out = out.view(bs, c*d, h, w) 98 | out = self.third(out) 99 | out = self.fourth(out) 100 | 101 | if occlusion_map is not None: 102 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: 103 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') 104 | out = out * occlusion_map 105 | 106 | # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image 107 | 108 | # Decoding part 109 | out = self.resblocks_2d(out) 110 | for i in range(len(self.up_blocks)): 111 | out = self.up_blocks[i](out) 112 | out = self.final(out) 113 | out = F.sigmoid(out) 114 | 115 | output_dict["prediction"] = out 116 | 117 | return output_dict 118 | 119 | 120 | class SPADEDecoder(nn.Module): 121 | def __init__(self): 122 | super().__init__() 123 | ic = 256 124 | oc = 64 125 | norm_G = 'spadespectralinstance' 126 | label_nc = 256 127 | 128 | self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) 129 | self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) 130 | self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) 131 | self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) 132 | self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) 133 | self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) 134 | self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) 135 | self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) 136 | self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) 137 | self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) 138 | self.up = nn.Upsample(scale_factor=2) 139 | 140 | def forward(self, feature): 141 | seg = feature 142 | x = self.fc(feature) 143 | x = self.G_middle_0(x, seg) 144 | x = self.G_middle_1(x, seg) 145 | x = self.G_middle_2(x, seg) 146 | x = self.G_middle_3(x, seg) 147 | x = self.G_middle_4(x, seg) 148 | x = self.G_middle_5(x, seg) 149 | x = self.up(x) 150 | x = self.up_0(x, seg) # 256, 128, 128 151 | x = self.up(x) 152 | x = self.up_1(x, seg) # 64, 256, 256 153 | 154 | x = self.conv_img(F.leaky_relu(x, 2e-1)) 155 | # x = torch.tanh(x) 156 | x = F.sigmoid(x) 157 | 158 | return x 159 | 160 | 161 | class OcclusionAwareSPADEGenerator(nn.Module): 162 | 163 | def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, 164 | num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): 165 | super(OcclusionAwareSPADEGenerator, self).__init__() 166 | 167 | if dense_motion_params is not None: 168 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, 169 | estimate_occlusion_map=estimate_occlusion_map, 170 | **dense_motion_params) 171 | else: 172 | self.dense_motion_network = None 173 | 174 | self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) 175 | 176 | down_blocks = [] 177 | for i in range(num_down_blocks): 178 | in_features = min(max_features, block_expansion 
* (2 ** i)) 179 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 180 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 181 | self.down_blocks = nn.ModuleList(down_blocks) 182 | 183 | self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) 184 | 185 | self.reshape_channel = reshape_channel 186 | self.reshape_depth = reshape_depth 187 | 188 | self.resblocks_3d = torch.nn.Sequential() 189 | for i in range(num_resblocks): 190 | self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) 191 | 192 | out_features = block_expansion * (2 ** (num_down_blocks)) 193 | self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) 194 | self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) 195 | 196 | self.estimate_occlusion_map = estimate_occlusion_map 197 | self.image_channel = image_channel 198 | 199 | self.decoder = SPADEDecoder() 200 | 201 | def deform_input(self, inp, deformation): 202 | _, d_old, h_old, w_old, _ = deformation.shape 203 | _, _, d, h, w = inp.shape 204 | if d_old != d or h_old != h or w_old != w: 205 | deformation = deformation.permute(0, 4, 1, 2, 3) 206 | deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') 207 | deformation = deformation.permute(0, 2, 3, 4, 1) 208 | return F.grid_sample(inp, deformation) 209 | 210 | def forward(self, source_image, kp_driving, kp_source, fp16=False): 211 | if fp16: 212 | source_image = source_image.half() 213 | kp_driving['value'] = kp_driving['value'].half() 214 | kp_source['value'] = kp_source['value'].half() 215 | # Encoding (downsampling) part 216 | out = self.first(source_image) 217 | for i in range(len(self.down_blocks)): 218 | out = self.down_blocks[i](out) 219 | out = self.second(out) 220 | bs, c, h, w = out.shape 221 | # print(out.shape) 222 | feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) 223 | feature_3d = self.resblocks_3d(feature_3d) 224 | 225 | # Transforming feature representation according to deformation and occlusion 226 | output_dict = {} 227 | if self.dense_motion_network is not None: 228 | dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, 229 | kp_source=kp_source) 230 | output_dict['mask'] = dense_motion['mask'] 231 | 232 | if 'occlusion_map' in dense_motion: 233 | occlusion_map = dense_motion['occlusion_map'] 234 | output_dict['occlusion_map'] = occlusion_map 235 | else: 236 | occlusion_map = None 237 | deformation = dense_motion['deformation'] 238 | out = self.deform_input(feature_3d, deformation) 239 | 240 | bs, c, d, h, w = out.shape 241 | out = out.view(bs, c*d, h, w) 242 | out = self.third(out) 243 | out = self.fourth(out) 244 | 245 | if occlusion_map is not None: 246 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: 247 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') 248 | out = out * occlusion_map 249 | 250 | # Decoding part 251 | out = self.decoder(out) 252 | 253 | output_dict["prediction"] = out 254 | 255 | return output_dict -------------------------------------------------------------------------------- /face-vid2vid/modules/hopenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import math 5 | import 
torch.nn.functional as F 6 | 7 | class Hopenet(nn.Module): 8 | # Hopenet with 3 output layers for yaw, pitch and roll 9 | # Predicts Euler angles by binning and regression with the expected value 10 | def __init__(self, block, layers, num_bins): 11 | self.inplanes = 64 12 | super(Hopenet, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 14 | bias=False) 15 | self.bn1 = nn.BatchNorm2d(64) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 18 | self.layer1 = self._make_layer(block, 64, layers[0]) 19 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 20 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 21 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 22 | self.avgpool = nn.AvgPool2d(7) 23 | self.fc_yaw = nn.Linear(512 * block.expansion, num_bins) 24 | self.fc_pitch = nn.Linear(512 * block.expansion, num_bins) 25 | self.fc_roll = nn.Linear(512 * block.expansion, num_bins) 26 | 27 | # Vestigial layer from previous experiments 28 | self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3) 29 | 30 | for m in self.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 33 | m.weight.data.normal_(0, math.sqrt(2. / n)) 34 | elif isinstance(m, nn.BatchNorm2d): 35 | m.weight.data.fill_(1) 36 | m.bias.data.zero_() 37 | 38 | def _make_layer(self, block, planes, blocks, stride=1): 39 | downsample = None 40 | if stride != 1 or self.inplanes != planes * block.expansion: 41 | downsample = nn.Sequential( 42 | nn.Conv2d(self.inplanes, planes * block.expansion, 43 | kernel_size=1, stride=stride, bias=False), 44 | nn.BatchNorm2d(planes * block.expansion), 45 | ) 46 | 47 | layers = [] 48 | layers.append(block(self.inplanes, planes, stride, downsample)) 49 | self.inplanes = planes * block.expansion 50 | for i in range(1, blocks): 51 | layers.append(block(self.inplanes, planes)) 52 | 53 | return nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | x = self.conv1(x) 57 | x = self.bn1(x) 58 | x = self.relu(x) 59 | x = self.maxpool(x) 60 | 61 | x = self.layer1(x) 62 | x = self.layer2(x) 63 | x = self.layer3(x) 64 | x = self.layer4(x) 65 | 66 | x = self.avgpool(x) 67 | x = x.view(x.size(0), -1) 68 | pre_yaw = self.fc_yaw(x) 69 | pre_pitch = self.fc_pitch(x) 70 | pre_roll = self.fc_roll(x) 71 | 72 | return pre_yaw, pre_pitch, pre_roll 73 | 74 | class ResNet(nn.Module): 75 | # ResNet for regression of 3 Euler angles. 76 | def __init__(self, block, layers, num_classes=1000): 77 | self.inplanes = 64 78 | super(ResNet, self).__init__() 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 80 | bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 84 | self.layer1 = self._make_layer(block, 64, layers[0]) 85 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 86 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 87 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 88 | self.avgpool = nn.AvgPool2d(7) 89 | self.fc_angles = nn.Linear(512 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 95 | elif isinstance(m, nn.BatchNorm2d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | downsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes * block.expansion, 104 | kernel_size=1, stride=stride, bias=False), 105 | nn.BatchNorm2d(planes * block.expansion), 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, downsample)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.conv1(x) 118 | x = self.bn1(x) 119 | x = self.relu(x) 120 | x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | x = self.layer4(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc_angles(x) 130 | return x 131 | 132 | class AlexNet(nn.Module): 133 | # AlexNet laid out as a Hopenet - classify Euler angles in bins and 134 | # regress the expected value. 135 | def __init__(self, num_bins): 136 | super(AlexNet, self).__init__() 137 | self.features = nn.Sequential( 138 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 139 | nn.ReLU(inplace=True), 140 | nn.MaxPool2d(kernel_size=3, stride=2), 141 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 142 | nn.ReLU(inplace=True), 143 | nn.MaxPool2d(kernel_size=3, stride=2), 144 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 147 | nn.ReLU(inplace=True), 148 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 149 | nn.ReLU(inplace=True), 150 | nn.MaxPool2d(kernel_size=3, stride=2), 151 | ) 152 | self.classifier = nn.Sequential( 153 | nn.Dropout(), 154 | nn.Linear(256 * 6 * 6, 4096), 155 | nn.ReLU(inplace=True), 156 | nn.Dropout(), 157 | nn.Linear(4096, 4096), 158 | nn.ReLU(inplace=True), 159 | ) 160 | self.fc_yaw = nn.Linear(4096, num_bins) 161 | self.fc_pitch = nn.Linear(4096, num_bins) 162 | self.fc_roll = nn.Linear(4096, num_bins) 163 | 164 | def forward(self, x): 165 | x = self.features(x) 166 | x = x.view(x.size(0), 256 * 6 * 6) 167 | x = self.classifier(x) 168 | yaw = self.fc_yaw(x) 169 | pitch = self.fc_pitch(x) 170 | roll = self.fc_roll(x) 171 | return yaw, pitch, roll 172 | -------------------------------------------------------------------------------- /face-vid2vid/modules/keypoint_detector.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d 6 | from modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck 7 | 8 | 9 | class KPDetector(nn.Module): 10 | """ 11 | Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
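    A sketch of the soft-argmax used by gaussian2kp below, assuming the heatmap has
    already been softmax-normalised over its spatial locations and reshaped to
    (B, K, D, H, W); the keypoint is the expectation of a coordinate grid whose values
    lie in [-1, 1]:

        grid = make_coordinate_grid(heatmap.shape[2:], heatmap.type())   # (D, H, W, 3), values in [-1, 1]
        value = (heatmap.unsqueeze(-1) * grid).sum(dim=(2, 3, 4))        # (B, K, 3) expected keypoint positions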
12 | """ 13 | 14 | def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, 15 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): 16 | super(KPDetector, self).__init__() 17 | 18 | self.predictor = KPHourglass(block_expansion, in_features=image_channel, 19 | max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) 20 | 21 | # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) 22 | self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) 23 | 24 | if estimate_jacobian: 25 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp 26 | # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) 27 | self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) 28 | ''' 29 | initial as: 30 | [[1 0 0] 31 | [0 1 0] 32 | [0 0 1]] 33 | ''' 34 | self.jacobian.weight.data.zero_() 35 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) 36 | else: 37 | self.jacobian = None 38 | 39 | self.temperature = temperature 40 | self.scale_factor = scale_factor 41 | if self.scale_factor != 1: 42 | self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) 43 | 44 | def gaussian2kp(self, heatmap): 45 | """ 46 | Extract the mean from a heatmap 47 | """ 48 | shape = heatmap.shape 49 | heatmap = heatmap.unsqueeze(-1) 50 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) 51 | value = (heatmap * grid).sum(dim=(2, 3, 4)) 52 | kp = {'value': value} 53 | 54 | return kp 55 | 56 | def forward(self, x): 57 | if self.scale_factor != 1: 58 | x = self.down(x) 59 | 60 | feature_map = self.predictor(x) 61 | prediction = self.kp(feature_map) 62 | 63 | final_shape = prediction.shape 64 | heatmap = prediction.view(final_shape[0], final_shape[1], -1) 65 | heatmap = F.softmax(heatmap / self.temperature, dim=2) 66 | heatmap = heatmap.view(*final_shape) 67 | 68 | out = self.gaussian2kp(heatmap) 69 | 70 | if self.jacobian is not None: 71 | jacobian_map = self.jacobian(feature_map) 72 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], 73 | final_shape[3], final_shape[4]) 74 | heatmap = heatmap.unsqueeze(2) 75 | 76 | jacobian = heatmap * jacobian_map 77 | jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) 78 | jacobian = jacobian.sum(dim=-1) 79 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) 80 | out['jacobian'] = jacobian 81 | 82 | return out 83 | 84 | 85 | class HEEstimator(nn.Module): 86 | """ 87 | Estimating head pose and expression. 
88 | """ 89 | 90 | def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): 91 | super(HEEstimator, self).__init__() 92 | 93 | self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) 94 | self.norm1 = BatchNorm2d(block_expansion, affine=True) 95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 96 | 97 | self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) 98 | self.norm2 = BatchNorm2d(256, affine=True) 99 | 100 | self.block1 = nn.Sequential() 101 | for i in range(3): 102 | self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) 103 | 104 | self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) 105 | self.norm3 = BatchNorm2d(512, affine=True) 106 | self.block2 = ResBottleneck(in_features=512, stride=2) 107 | 108 | self.block3 = nn.Sequential() 109 | for i in range(3): 110 | self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) 111 | 112 | self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) 113 | self.norm4 = BatchNorm2d(1024, affine=True) 114 | self.block4 = ResBottleneck(in_features=1024, stride=2) 115 | 116 | self.block5 = nn.Sequential() 117 | for i in range(5): 118 | self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) 119 | 120 | self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) 121 | self.norm5 = BatchNorm2d(2048, affine=True) 122 | self.block6 = ResBottleneck(in_features=2048, stride=2) 123 | 124 | self.block7 = nn.Sequential() 125 | for i in range(2): 126 | self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) 127 | 128 | self.fc_roll = nn.Linear(2048, num_bins) 129 | self.fc_pitch = nn.Linear(2048, num_bins) 130 | self.fc_yaw = nn.Linear(2048, num_bins) 131 | 132 | self.fc_t = nn.Linear(2048, 3) 133 | 134 | self.fc_exp = nn.Linear(2048, 3*num_kp) 135 | 136 | def forward(self, x): 137 | out = self.conv1(x) 138 | out = self.norm1(out) 139 | out = F.relu(out) 140 | out = self.maxpool(out) 141 | 142 | out = self.conv2(out) 143 | out = self.norm2(out) 144 | out = F.relu(out) 145 | 146 | out = self.block1(out) 147 | 148 | out = self.conv3(out) 149 | out = self.norm3(out) 150 | out = F.relu(out) 151 | out = self.block2(out) 152 | 153 | out = self.block3(out) 154 | 155 | out = self.conv4(out) 156 | out = self.norm4(out) 157 | out = F.relu(out) 158 | out = self.block4(out) 159 | 160 | out = self.block5(out) 161 | 162 | out = self.conv5(out) 163 | out = self.norm5(out) 164 | out = F.relu(out) 165 | out = self.block6(out) 166 | 167 | out = self.block7(out) 168 | 169 | out = F.adaptive_avg_pool2d(out, 1) 170 | out = out.view(out.shape[0], -1) 171 | 172 | yaw = self.fc_roll(out) 173 | pitch = self.fc_pitch(out) 174 | roll = self.fc_yaw(out) 175 | t = self.fc_t(out) 176 | exp = self.fc_exp(out) 177 | 178 | return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} 179 | -------------------------------------------------------------------------------- /face-vid2vid/run.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use("Agg") 4 | 5 | import os, sys 6 | import yaml 7 | from argparse import ArgumentParser 8 | from time import gmtime, strftime 9 | from shutil import copy 10 | 11 | from frames_dataset import FramesDataset 12 | 13 | from modules.generator import 
OcclusionAwareGenerator, OcclusionAwareSPADEGenerator 14 | from modules.discriminator import MultiScaleDiscriminator 15 | from modules.keypoint_detector import KPDetector, HEEstimator 16 | 17 | import torch 18 | 19 | from train import train 20 | 21 | if __name__ == "__main__": 22 | if sys.version_info[0] < 3: 23 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") 24 | 25 | parser = ArgumentParser() 26 | parser.add_argument("--config", default="config/vox-256.yaml", help="path to config") 27 | parser.add_argument( 28 | "--mode", 29 | default="train", 30 | choices=[ 31 | "train", 32 | ], 33 | ) 34 | parser.add_argument("--gen", default="original", choices=["original", "spade"]) 35 | parser.add_argument("--log_dir", default="log", help="path to log into") 36 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") 37 | parser.add_argument( 38 | "--device_ids", 39 | default="0, 1, 2, 3, 4, 5, 6, 7", 40 | type=lambda x: list(map(int, x.split(","))), 41 | help="Names of the devices comma separated.", 42 | ) 43 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture") 44 | parser.set_defaults(verbose=False) 45 | 46 | opt = parser.parse_args() 47 | with open(opt.config) as f: 48 | config = yaml.load(f, Loader=yaml.FullLoader) 49 | 50 | if opt.checkpoint is not None: 51 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) 52 | else: 53 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split(".")[0]) 54 | log_dir += " " + strftime("%d_%m_%y_%H.%M.%S", gmtime()) 55 | 56 | if opt.gen == "original": 57 | generator = OcclusionAwareGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"]) 58 | elif opt.gen == "spade": 59 | generator = OcclusionAwareSPADEGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"]) 60 | 61 | if torch.cuda.is_available(): 62 | print("cuda is available") 63 | generator.to(opt.device_ids[0]) 64 | if opt.verbose: 65 | print(generator) 66 | 67 | discriminator = MultiScaleDiscriminator(**config["model_params"]["discriminator_params"], **config["model_params"]["common_params"]) 68 | if torch.cuda.is_available(): 69 | discriminator.to(opt.device_ids[0]) 70 | if opt.verbose: 71 | print(discriminator) 72 | 73 | kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"]) 74 | 75 | if torch.cuda.is_available(): 76 | kp_detector.to(opt.device_ids[0]) 77 | 78 | if opt.verbose: 79 | print(kp_detector) 80 | 81 | he_estimator = HEEstimator(**config["model_params"]["he_estimator_params"], **config["model_params"]["common_params"]) 82 | 83 | if torch.cuda.is_available(): 84 | he_estimator.to(opt.device_ids[0]) 85 | 86 | dataset = FramesDataset(is_train=(opt.mode == "train"), **config["dataset_params"]) 87 | 88 | if not os.path.exists(log_dir): 89 | os.makedirs(log_dir) 90 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): 91 | copy(opt.config, log_dir) 92 | 93 | if opt.mode == "train": 94 | print("Training...") 95 | train(config, generator, discriminator, kp_detector, he_estimator, opt.checkpoint, log_dir, dataset, opt.device_ids) 96 | -------------------------------------------------------------------------------- /face-vid2vid/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # 
Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /face-vid2vid/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 
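        Example (a simplified sketch of the calling pattern used by `_SynchronizedBatchNorm`;
        the names below are illustrative):

            master = SyncMaster(master_callback)
            pipe = master.register_slave(copy_id)        # on each non-master replica, during replication
            slave_out = pipe.run_slave(local_message)    # slaves: send their message, wait for the reply
            master_out = master.run_master(master_msg)   # master: gathers all messages, replies to slaves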
92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /face-vid2vid/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 
53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /face-vid2vid/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /face-vid2vid/train.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from logger import Logger 7 | from modules.model import GeneratorFullModel, DiscriminatorFullModel 8 | 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | 11 | from sync_batchnorm import DataParallelWithCallback 12 | 13 | from frames_dataset import DatasetRepeater 14 | 15 | 16 | def train(config, generator, discriminator, kp_detector, he_estimator, checkpoint, log_dir, dataset, device_ids): 17 | train_params = config["train_params"] 18 | 19 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params["lr_generator"], betas=(0.5, 0.999)) 20 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params["lr_discriminator"], betas=(0.5, 0.999)) 21 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params["lr_kp_detector"], betas=(0.5, 0.999)) 22 | optimizer_he_estimator = torch.optim.Adam(he_estimator.parameters(), lr=train_params["lr_he_estimator"], betas=(0.5, 0.999)) 23 | 24 | if checkpoint is not None: 25 | start_epoch = Logger.load_cpk( 26 | checkpoint, 27 | generator, 28 | discriminator, 29 | kp_detector, 30 | he_estimator, 31 | optimizer_generator, 32 | optimizer_discriminator, 33 | optimizer_kp_detector, 34 | optimizer_he_estimator, 35 | ) 36 | else: 37 | start_epoch = 0 38 | 39 | scheduler_generator = MultiStepLR(optimizer_generator, train_params["epoch_milestones"], gamma=0.1, last_epoch=start_epoch - 1) 40 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params["epoch_milestones"], gamma=0.1, last_epoch=start_epoch - 1) 41 | scheduler_kp_detector = MultiStepLR( 42 | optimizer_kp_detector, train_params["epoch_milestones"], gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_kp_detector"] != 0) 43 | ) 44 | scheduler_he_estimator = MultiStepLR( 45 | optimizer_he_estimator, train_params["epoch_milestones"], gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_kp_detector"] != 0) 46 | ) 47 | 48 | if "num_repeats" in train_params or train_params["num_repeats"] != 1: 49 | dataset = DatasetRepeater(dataset, train_params["num_repeats"]) 50 | dataloader = DataLoader(dataset, batch_size=train_params["batch_size"], shuffle=True, num_workers=16, drop_last=True) 51 | 52 | generator_full = GeneratorFullModel( 53 | kp_detector, 54 | he_estimator, 55 | generator, 56 | discriminator, 57 | train_params, 58 | estimate_jacobian=config["model_params"]["common_params"]["estimate_jacobian"], 59 | ) 60 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) 61 | 62 | if torch.cuda.is_available(): 63 | generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids) 64 | 
discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids) 65 | 66 | with Logger(log_dir=log_dir, visualizer_params=config["visualizer_params"], checkpoint_freq=train_params["checkpoint_freq"]) as logger: 67 | for epoch in trange(start_epoch, train_params["num_epochs"]): 68 | for x in dataloader: 69 | losses_generator, generated = generator_full(x) 70 | 71 | loss_values = [val.mean() for val in losses_generator.values()] 72 | loss = sum(loss_values) 73 | 74 | loss.backward() 75 | optimizer_generator.step() 76 | optimizer_generator.zero_grad() 77 | optimizer_kp_detector.step() 78 | optimizer_kp_detector.zero_grad() 79 | optimizer_he_estimator.step() 80 | optimizer_he_estimator.zero_grad() 81 | 82 | if train_params["loss_weights"]["generator_gan"] != 0: 83 | optimizer_discriminator.zero_grad() 84 | losses_discriminator = discriminator_full(x, generated) 85 | loss_values = [val.mean() for val in losses_discriminator.values()] 86 | loss = sum(loss_values) 87 | 88 | loss.backward() 89 | optimizer_discriminator.step() 90 | optimizer_discriminator.zero_grad() 91 | else: 92 | losses_discriminator = {} 93 | 94 | losses_generator.update(losses_discriminator) 95 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} 96 | logger.log_iter(losses=losses) 97 | 98 | scheduler_generator.step() 99 | scheduler_discriminator.step() 100 | scheduler_kp_detector.step() 101 | scheduler_he_estimator.step() 102 | 103 | logger.log_epoch( 104 | epoch, 105 | { 106 | "generator": generator, 107 | "discriminator": discriminator, 108 | "kp_detector": kp_detector, 109 | "he_estimator": he_estimator, 110 | "optimizer_generator": optimizer_generator, 111 | "optimizer_discriminator": optimizer_discriminator, 112 | "optimizer_kp_detector": optimizer_kp_detector, 113 | "optimizer_he_estimator": optimizer_he_estimator, 114 | }, 115 | inp=x, 116 | out=generated, 117 | ) 118 | -------------------------------------------------------------------------------- /remote_server.py: -------------------------------------------------------------------------------- 1 | import io 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | from argparse import ArgumentParser 6 | 7 | 8 | from fastapi import FastAPI, WebSocket 9 | from fastapi.websockets import WebSocketDisconnect 10 | from demo_utils import FaceAnimationClass 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument("--source_image", default="./assets/source.jpg", help="path to source image") 14 | parser.add_argument("--restore_face", default="False", type=str, help="restore face (True or False)") 15 | args = parser.parse_args() 16 | restore_face = True if args.restore_face == 'True' else False if args.restore_face == 'False' else exit('restore_face must be True or False') 17 | 18 | 19 | faceanimation = FaceAnimationClass(source_image_path=args.source_image, use_sr=restore_face) 20 | # the remote server runs at a lower fps than the local camera, so run face detection more often and smooth more strongly 21 | faceanimation.detect_interval = 2 22 | faceanimation.smooth_factor = 0.8 23 | 24 | 25 | app = FastAPI() 26 | websocket_port = 8066 27 | 28 | 29 | # WebSocket endpoint to receive and process images 30 | @app.websocket("/ws") 31 | async def websocket_endpoint(websocket: WebSocket): 32 | await websocket.accept() 33 | try: 34 | while True: 35 | # Receive the image as a binary stream 36 | image_data = await websocket.receive_bytes() 37 | processed_image = process_image(image_data) 38 | # Send
the processed image back to the client 39 | await websocket.send_bytes(processed_image) 40 | except WebSocketDisconnect: 41 | pass 42 | 43 | 44 | def process_image(image_data): 45 | image = Image.open(io.BytesIO(image_data)) 46 | image_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 47 | face, result = faceanimation.inference(image_cv2) 48 | # resize to 256x256 49 | if face.shape[1] != 256 or face.shape[0] != 256: 50 | face = cv2.resize(face, (256, 256)) 51 | if result.shape[0] != 256 or result.shape[1] != 256: 52 | result = cv2.resize(result, (256, 256)) 53 | result = cv2.hconcat([face, result]) 54 | _, processed_image_data = cv2.imencode(".jpg", result, [cv2.IMWRITE_JPEG_QUALITY, 95]) 55 | return processed_image_data.tobytes() 56 | 57 | 58 | if __name__ == "__main__": 59 | import uvicorn 60 | 61 | uvicorn.run(app, host="0.0.0.0", port=websocket_port) 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.11.0 2 | numpy>=1.23.5 3 | PyYAML 4 | imageio[ffmpeg] 5 | batch-face 6 | gdown 7 | scipy --------------------------------------------------------------------------------
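remote_server.py above exposes a single /ws WebSocket endpoint: each incoming message is one JPEG-encoded frame, and the reply is a 512x256 JPEG with the cropped face and the animated result side by side. A minimal client sketch for that protocol, assuming the third-party `websockets` package and a locally running server; the URI and frame path are placeholders (fastapi/uvicorn on the server side, and websockets here, are not pinned in requirements.txt above and are assumed to be installed separately):

    import asyncio
    import cv2
    import numpy as np
    import websockets

    async def animate_frame(frame_bgr, uri="ws://localhost:8066/ws"):
        # Encode one driving frame as JPEG, send it, and decode the server's reply
        # (cropped face next to the animated result).
        _, buf = cv2.imencode(".jpg", frame_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
        async with websockets.connect(uri) as ws:
            await ws.send(buf.tobytes())
            reply = await ws.recv()
        return cv2.imdecode(np.frombuffer(reply, dtype=np.uint8), cv2.IMREAD_COLOR)

    # e.g. out = asyncio.run(animate_frame(cv2.imread("driving_frame.jpg")))

This sketch opens one connection per frame for simplicity; a streaming client would keep the socket open and loop send/receive inside a single `async with` block.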