├── gpeno ├── face_detect │ ├── utils │ │ ├── __init__.py │ │ ├── nms │ │ │ ├── __init__.py │ │ │ └── py_cpu_nms.py │ │ ├── timer.py │ │ └── box_utils.py │ ├── facemodels │ │ ├── __init__.py │ │ ├── retinaface.py │ │ └── net.py │ ├── layers │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── multibox_loss.py │ │ └── functions │ │ │ └── prior_box.py │ ├── .DS_Store │ ├── data │ │ ├── __init__.py │ │ ├── config.py │ │ ├── wider_face.py │ │ └── data_augment.py │ └── retinaface_detection.py ├── face_parse │ ├── mask.png │ ├── test.png │ ├── face_parsing.py │ ├── face_parsing_broken.py │ ├── parse_model.py │ └── blocks.py ├── face_model │ ├── op │ │ ├── __init__.py │ │ ├── fused_bias_act.cpp │ │ ├── upfirdn2d.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── fused_act.py │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ └── face_gan.py ├── requirements.txt ├── training │ ├── lpips │ │ ├── weights │ │ │ ├── v0.0 │ │ │ │ ├── alex.pth │ │ │ │ ├── vgg.pth │ │ │ │ └── squeeze.pth │ │ │ └── v0.1 │ │ │ │ ├── alex.pth │ │ │ │ ├── vgg.pth │ │ │ │ └── squeeze.pth │ │ ├── __init__.py │ │ ├── pretrained_networks.py │ │ ├── lpips.py │ │ └── trainer.py │ ├── loss │ │ ├── id_loss.py │ │ ├── model_irse.py │ │ └── helpers.py │ └── data_loader │ │ └── dataset_face.py ├── misc │ ├── cog.yaml │ ├── predict.py │ └── onnx_export.py ├── __init_paths.py ├── face_inpainting.py ├── face_colorization.py ├── .gitignore ├── distributed.py ├── face_enhancement.py ├── demo.py ├── align_faces.py └── README.md ├── .gitattributes ├── workflows └── workflow_gpeno.png ├── CHANGELOG.md ├── .github └── workflows │ └── publish_action.yml ├── pyproject.toml ├── README.md └── __init__.py /gpeno/face_detect/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gpeno/face_detect/facemodels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gpeno/face_detect/utils/nms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /gpeno/face_detect/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /gpeno/face_parse/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/face_parse/mask.png -------------------------------------------------------------------------------- /gpeno/face_parse/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/face_parse/test.png -------------------------------------------------------------------------------- /gpeno/face_detect/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/face_detect/.DS_Store -------------------------------------------------------------------------------- /workflows/workflow_gpeno.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/workflows/workflow_gpeno.png -------------------------------------------------------------------------------- /gpeno/face_detect/layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | 3 | __all__ = ['MultiBoxLoss'] 4 | -------------------------------------------------------------------------------- /gpeno/face_model/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /gpeno/requirements.txt: -------------------------------------------------------------------------------- 1 | ninja 2 | torch 3 | torchvision 4 | opencv-python 5 | numpy 6 | scikit-image 7 | scipy 8 | pillow 9 | tqdm -------------------------------------------------------------------------------- /gpeno/training/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /gpeno/training/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /gpeno/training/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /gpeno/training/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /gpeno/training/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /gpeno/training/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /gpeno/face_detect/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .wider_face import WiderFaceDetection, detection_collate 2 | from .data_augment import * 3 | from .config import * 4 | 
-------------------------------------------------------------------------------- /gpeno/misc/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | python_version: "3.8" 4 | system_packages: 5 | - "libgl1-mesa-glx" 6 | - "libglib2.0-0" 7 | - "ninja-build" 8 | python_packages: 9 | - "torch==1.7.1" 10 | - "torchvision==0.8.2" 11 | - "numpy==1.20.1" 12 | - "ipython==7.21.0" 13 | - "Pillow==8.3.1" 14 | - "scikit-image==0.18.3" 15 | - "opencv-python==4.5.3.56" 16 | 17 | predict: "predict.py:Predictor" 18 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | All notable changes to this project will be documented in this file. 2 | 3 |
0.1.0 - 15 April 2025 4 | 5 | **Added** 6 | - New `colorize` setting: Bring old photos to life 7 | - Returns the original face and enhanced face in addition to the resulting masked image 8 | 9 |
10 | 11 |
0.0.1 - 24 December 2024 12 | 13 | **Added** 14 | - Initial release 15 | 16 |
-------------------------------------------------------------------------------- /gpeno/__init_paths.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | path = osp.join(this_dir, 'face_detect') 17 | add_path(path) 18 | 19 | path = osp.join(this_dir, 'face_parse') 20 | add_path(path) 21 | 22 | path = osp.join(this_dir, 'face_model') 23 | add_path(path) 24 | -------------------------------------------------------------------------------- /.github/workflows/publish_action.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'SparknightLLC' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | -------------------------------------------------------------------------------- /gpeno/face_inpainting.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | from face_model.face_gan import FaceGAN 6 | 7 | class FaceInpainting(object): 8 | def __init__(self, base_dir='./', in_size=1024, out_size=1024, model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'): 9 | self.facegan = FaceGAN(base_dir, in_size, out_size, model, channel_multiplier, narrow, key, device=device) 10 | 11 | # make sure the face image is well aligned. Please refer to face_enhancement.py 12 | def process(self, brokenf, aligned=True): 13 | # complete the face 14 | out = self.facegan.process(brokenf) 15 | 16 | return out, [brokenf], [out] 17 | 18 | 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml 2 | [project] 3 | name = "comfyui-gpeno" # Unique identifier for your node. Immutable after creation. 4 | description = "A node for ComfyUI that performs GPEN face restoration on the input image(s). Significantly faster than other implementations of GPEN." 5 | version = "0.1.0" # Custom Node version. Must be semantically versioned. 6 | license = { file = "LICENSE.txt" } 7 | dependencies = [] # Filled in from requirements.txt 8 | 9 | [project.urls] 10 | Repository = "https://github.com/SparknightLLC/ComfyUI-GPENO" 11 | 12 | [tool.comfy] 13 | PublisherId = "sparknight" # TODO (fill in Publisher ID from Comfy Registry Website). 14 | DisplayName = "ComfyUI-GPENO" # Display name for the Custom Node. Can be changed later. 15 | Icon = "https://example.com/icon.png" # SVG, PNG, JPG or GIF (MAX. 
800x400px) -------------------------------------------------------------------------------- /gpeno/face_model/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /gpeno/face_model/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /gpeno/face_detect/utils/nms/py_cpu_nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | 10 | def py_cpu_nms(dets, thresh): 11 | """Pure Python NMS baseline.""" 12 | x1 = dets[:, 0] 13 | y1 = dets[:, 1] 14 | x2 = dets[:, 2] 15 | y2 = dets[:, 3] 16 | scores = dets[:, 4] 17 | 18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 19 | order = scores.argsort()[::-1] 20 | 21 | keep = [] 22 | while order.size > 0: 23 | i = order[0] 24 | keep.append(i) 25 | xx1 = np.maximum(x1[i], x1[order[1:]]) 26 | yy1 = np.maximum(y1[i], y1[order[1:]]) 27 | xx2 = np.minimum(x2[i], x2[order[1:]]) 28 | yy2 = np.minimum(y2[i], y2[order[1:]]) 29 | 30 | w = np.maximum(0.0, xx2 - xx1 + 1) 31 | h = np.maximum(0.0, yy2 - yy1 + 1) 32 | inter = w * h 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | 35 | inds = np.where(ovr <= thresh)[0] 36 | order = order[inds + 1] 37 | 38 | return keep 39 | 
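# Usage sketch (illustrative values, not part of the original module): `dets` is an
# (N, 5) float array of [x1, y1, x2, y2, score] rows; the function returns the indices
# of the boxes kept after greedily suppressing any box whose IoU with an already-kept,
# higher-scoring box exceeds `thresh`.
if __name__ == "__main__":
    dets = np.array([[10, 10, 60, 60, 0.9],
                     [12, 12, 62, 62, 0.8],      # heavily overlaps the first box
                     [100, 100, 150, 150, 0.7]])
    print(py_cpu_nms(dets, thresh=0.5))  # indices 0 and 2 are kept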
-------------------------------------------------------------------------------- /gpeno/face_detect/data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | cfg_mnet = { 4 | 'name': 'detection_mobilenet0.25_Final', 5 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 6 | 'steps': [8, 16, 32], 7 | 'variance': [0.1, 0.2], 8 | 'clip': False, 9 | 'loc_weight': 2.0, 10 | 'gpu_train': True, 11 | 'batch_size': 32, 12 | 'ngpu': 1, 13 | 'epoch': 250, 14 | 'decay1': 190, 15 | 'decay2': 220, 16 | 'image_size': 640, 17 | 'pretrain': False, 18 | 'return_layers': { 19 | 'stage1': 1, 20 | 'stage2': 2, 21 | 'stage3': 3 22 | }, 23 | 'in_channel': 32, 24 | 'out_channel': 64, 25 | } 26 | 27 | cfg_re50 = { 28 | 'name': 'detection_Resnet50_Final', #'Resnet50', 29 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 30 | 'steps': [8, 16, 32], 31 | 'variance': [0.1, 0.2], 32 | 'clip': False, 33 | 'loc_weight': 2.0, 34 | 'gpu_train': True, 35 | 'batch_size': 24, 36 | 'ngpu': 4, 37 | 'epoch': 100, 38 | 'decay1': 70, 39 | 'decay2': 90, 40 | 'image_size': 840, 41 | 'pretrain': False, 42 | 'return_layers': { 43 | 'layer2': 1, 44 | 'layer3': 2, 45 | 'layer4': 3 46 | }, 47 | 'in_channel': 256, 48 | 'out_channel': 256, 49 | } 50 | -------------------------------------------------------------------------------- /gpeno/face_colorization.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import cv2 6 | from face_model.face_gan import FaceGAN 7 | 8 | class FaceColorization(object): 9 | def __init__(self, base_dir='./', in_size=1024, out_size=1024, model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'): 10 | self.facegan = FaceGAN(base_dir, in_size, out_size, model, channel_multiplier, narrow, key, device=device) 11 | 12 | def post_process(self, gray, out): 13 | out_rs = cv2.resize(out, gray.shape[:2][::-1]) 14 | gray_yuv = cv2.cvtColor(gray, cv2.COLOR_BGR2YUV) 15 | out_yuv = cv2.cvtColor(out_rs, cv2.COLOR_BGR2YUV) 16 | 17 | out_yuv[:, :, 0] = gray_yuv[:, :, 0] 18 | final = cv2.cvtColor(out_yuv, cv2.COLOR_YUV2BGR) 19 | 20 | return final 21 | 22 | # make sure the face image is well aligned. Please refer to face_enhancement.py 23 | def process(self, gray, aligned=True): 24 | # colorize the face 25 | out = self.facegan.process(gray) 26 | 27 | if gray.shape[:2] != out.shape[:2]: 28 | out = self.post_process(gray, out) 29 | 30 | return out, [gray], [out] 31 | 32 | 33 | -------------------------------------------------------------------------------- /gpeno/face_detect/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | 10 | 11 | class Timer(object): 12 | """A simple timer.""" 13 | def __init__(self): 14 | self.total_time = 0. 15 | self.calls = 0 16 | self.start_time = 0. 17 | self.diff = 0. 18 | self.average_time = 0. 
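        # Accumulators updated by toc() and reset by clear(); average_time is total_time / calls.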
19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.total_time += self.diff 28 | self.calls += 1 29 | self.average_time = self.total_time / self.calls 30 | if average: 31 | return self.average_time 32 | else: 33 | return self.diff 34 | 35 | def clear(self): 36 | self.total_time = 0. 37 | self.calls = 0 38 | self.start_time = 0. 39 | self.diff = 0. 40 | self.average_time = 0. 41 | -------------------------------------------------------------------------------- /gpeno/face_detect/layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import product as product 3 | import numpy as np 4 | from math import ceil 5 | 6 | 7 | class PriorBox(object): 8 | def __init__(self, cfg, image_size=None, phase='train'): 9 | super(PriorBox, self).__init__() 10 | self.min_sizes = cfg['min_sizes'] 11 | self.steps = cfg['steps'] 12 | self.clip = cfg['clip'] 13 | self.image_size = image_size 14 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] 15 | self.name = "s" 16 | 17 | def forward(self): 18 | anchors = [] 19 | for k, f in enumerate(self.feature_maps): 20 | min_sizes = self.min_sizes[k] 21 | for i, j in product(range(f[0]), range(f[1])): 22 | for min_size in min_sizes: 23 | s_kx = min_size / self.image_size[1] 24 | s_ky = min_size / self.image_size[0] 25 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] 26 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] 27 | for cy, cx in product(dense_cy, dense_cx): 28 | anchors += [cx, cy, s_kx, s_ky] 29 | 30 | # back to torch land 31 | output = torch.Tensor(anchors).view(-1, 4) 32 | if self.clip: 33 | output.clamp_(max=1, min=0) 34 | return output 35 | -------------------------------------------------------------------------------- /gpeno/training/loss/id_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from model_irse import Backbone 5 | 6 | class IDLoss(nn.Module): 7 | def __init__(self, base_dir='./', device='cuda', ckpt_dict=None): 8 | super(IDLoss, self).__init__() 9 | print('Loading ResNet ArcFace') 10 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se').to(device) 11 | if ckpt_dict is None: 12 | self.facenet.load_state_dict(torch.load(os.path.join(base_dir, 'weights', 'model_ir_se50.pth'), map_location=torch.device('cpu'))) 13 | else: 14 | self.facenet.load_state_dict(ckpt_dict) 15 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 16 | self.facenet.eval() 17 | 18 | def extract_feats(self, x): 19 | _, _, h, w = x.shape 20 | assert h==w 21 | ss = h//256 22 | x = x[:, :, 35*ss:-33*ss, 32*ss:-36*ss] # Crop interesting region 23 | x = self.face_pool(x) 24 | x_feats = self.facenet(x) 25 | return x_feats 26 | 27 | def forward(self, y_hat, y, x): 28 | n_samples = x.shape[0] 29 | x_feats = self.extract_feats(x) 30 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 31 | y_hat_feats = self.extract_feats(y_hat) 32 | y_feats = y_feats.detach() 33 | loss = 0 34 | sim_improvement = 0 35 | id_logs = [] 36 | count = 0 37 | for i in range(n_samples): 38 | diff_target = y_hat_feats[i].dot(y_feats[i]) 39 | 
diff_input = y_hat_feats[i].dot(x_feats[i]) 40 | diff_views = y_feats[i].dot(x_feats[i]) 41 | id_logs.append({'diff_target': float(diff_target), 42 | 'diff_input': float(diff_input), 43 | 'diff_views': float(diff_views)}) 44 | loss += 1 - diff_target 45 | id_diff = float(diff_target) - float(diff_views) 46 | sim_improvement += id_diff 47 | count += 1 48 | 49 | return loss / count, sim_improvement / count, id_logs 50 | 51 | -------------------------------------------------------------------------------- /gpeno/face_parse/face_parsing.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import os 6 | import cv2 7 | import torch 8 | import numpy as np 9 | from parse_model import ParseNet 10 | import torch.nn.functional as F 11 | 12 | 13 | class FaceParse: 14 | 15 | def __init__(self, base_dir='./', model='ParseNet-latest', device='cuda'): 16 | self.mfile = os.path.join(base_dir, 'facerestore_models', model + '.pth') 17 | self.size = 512 18 | self.device = device 19 | self.MASK_COLORMAP = torch.tensor([0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 255, 255, 0], device=self.device) 20 | self.load_model() 21 | print("FaceParse initialized") 22 | 23 | def load_model(self): 24 | self.faceparse = ParseNet(self.size, self.size, 32, 64, 19, norm_type='bn', relu_type='LeakyReLU', ch_range=[32, 256]) 25 | self.faceparse.load_state_dict(torch.load(self.mfile, map_location=self.device)) 26 | self.faceparse.to(self.device) 27 | self.faceparse.eval() 28 | 29 | def process(self, im): 30 | im = cv2.resize(im, (self.size, self.size)) 31 | imt = self.img2tensor(im) 32 | with torch.no_grad(): 33 | pred_mask, _ = self.faceparse(imt) 34 | mask = self.tensor2mask(pred_mask) 35 | return mask 36 | 37 | def process_tensor(self, imt): 38 | imt = F.interpolate(imt.flip(1) * 2 - 1, (self.size, self.size)) 39 | with torch.no_grad(): 40 | pred_mask, _ = self.faceparse(imt) 41 | mask = pred_mask.argmax(dim=1) 42 | mask = self.MASK_COLORMAP[mask].unsqueeze(0) 43 | return mask 44 | 45 | def img2tensor(self, img): 46 | img = img[..., ::-1] # Convert BGR to RGB 47 | img = (img / 255.0) * 2 - 1 # Scale image to [-1, 1] 48 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device) 49 | return img_tensor.float() 50 | 51 | def tensor2mask(self, tensor): 52 | if len(tensor.shape) < 4: 53 | tensor = tensor.unsqueeze(0) 54 | if tensor.shape[1] > 1: 55 | tensor = tensor.argmax(dim=1) 56 | tensor = tensor.squeeze(1).cpu().numpy() 57 | color_maps = [self.MASK_COLORMAP[t].cpu().numpy().astype(np.uint8) for t in tensor] 58 | return color_maps 59 | -------------------------------------------------------------------------------- /gpeno/face_parse/face_parsing_broken.py: -------------------------------------------------------------------------------- 1 | '''Modified for use with Unprompted.''' 2 | ''' 3 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 4 | @author: yangxy (yangtao9009@gmail.com) 5 | ''' 6 | 7 | import os 8 | import cv2 9 | import torch 10 | import numpy as np 11 | from parse_model import ParseNet 12 | import torch.nn.functional as F 13 | 14 | 15 | class FaceParse(object): 16 | def __init__(self, base_dir='./', model='ParseNet-latest', device='cuda'): 17 | self.mfile = os.path.join(base_dir, 'gpen', model + '.pth') 18 | self.size = 512 19 | self.device 
= device 20 | self.MASK_COLORMAP = np.array([0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 255, 255, 0]) 21 | self.load_model() 22 | 23 | def load_model(self): 24 | self.faceparse = ParseNet(self.size, self.size, 32, 64, 19, norm_type='bn', relu_type='LeakyReLU', ch_range=[32, 256]) 25 | self.faceparse.load_state_dict(torch.load(self.mfile, map_location=torch.device('cpu'))) 26 | self.faceparse.to(self.device) 27 | self.faceparse.eval() 28 | 29 | def process(self, im): 30 | im = cv2.resize(im, (self.size, self.size)) 31 | imt = self.img2tensor(im) 32 | with torch.no_grad(): 33 | pred_mask, _ = self.faceparse(imt) 34 | mask = self.tensor2mask(pred_mask) 35 | return mask 36 | 37 | def img2tensor(self, img): 38 | img = img[..., ::-1] # BGR to RGB 39 | img = img / 255.0 * 2 - 1 40 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device).float() 41 | return img_tensor 42 | 43 | def tensor2mask(self, tensor): 44 | if tensor.shape[1] > 1: 45 | tensor = tensor.argmax(dim=1) 46 | tensor = tensor.squeeze(0).data.cpu().numpy() 47 | mask = np.take(self.MASK_COLORMAP, tensor) 48 | if mask.ndim == 3: 49 | mask = mask[:, :, 0] # Ensure the mask is 2D 50 | return mask.astype(np.float32) 51 | 52 | def process_tensor(self, imt): 53 | imt = F.interpolate(imt.flip(1) * 2 - 1, (self.size, self.size)) 54 | with torch.no_grad(): 55 | pred_mask, _ = self.faceparse(imt) 56 | mask = pred_mask.argmax(dim=1).squeeze(0) 57 | mask = torch.where(mask < len(self.MASK_COLORMAP), torch.tensor(self.MASK_COLORMAP, device=mask.device)[mask], mask) 58 | return mask.unsqueeze(0) 59 | -------------------------------------------------------------------------------- /gpeno/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /gpeno/face_model/face_gan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import torch 6 | import os 7 | import cv2 8 | import glob 9 | import numpy as np 10 | from torch import nn 11 | import torch.nn.functional as F 12 | 13 | from torchvision import transforms, utils 14 | from .gpen_model import FullGenerator, FullGenerator_SR 15 | 16 | 17 | class FaceGAN(object): 18 | 19 | def __init__(self, base_dir='./', in_size=512, out_size=None, model=None, channel_multiplier=2, narrow=1, key=None, is_norm=True, device='cuda'): 20 | print(f"Initializing FaceGAN on {device} device...") 21 | self.mfile = os.path.join(base_dir, 'facerestore_models', model + '.pth') 22 | self.n_mlp = 8 23 | self.device = device 24 | self.is_norm = is_norm 25 | self.in_resolution = in_size 26 | self.out_resolution = in_size if out_size is None else out_size 27 | self.key = key 28 | self.load_model(channel_multiplier, narrow) 29 | print(f"FaceGAN initialized") 30 | 31 | def load_model(self, channel_multiplier=2, narrow=1): 32 | if self.in_resolution == self.out_resolution: 33 | self.model = FullGenerator(self.in_resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device) 34 | else: 35 | self.model = FullGenerator_SR(self.in_resolution, self.out_resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device) 36 | 37 | pretrained_dict = torch.load(self.mfile, map_location=self.device) 38 | 39 | #if self.key is not None: 40 | # pretrained_dict = pretrained_dict[self.key] 41 | 42 | self.model.load_state_dict(pretrained_dict) 43 | self.model.to(self.device) 44 | self.model.eval() 45 | 46 | def process(self, img): 47 | torch.backends.cudnn.deterministic = True 48 | torch.backends.cudnn.benchmark = False 49 | img = cv2.resize(img, (self.in_resolution, self.in_resolution)) 50 | img_t = self.img2tensor(img) 51 | 52 | with torch.no_grad(): 53 | out, __ = self.model(img_t) 54 | # del img_t 55 | 56 | out = self.tensor2img(out) 57 | 58 | return out 59 | 60 | def img2tensor(self, img): 61 | img_t = torch.from_numpy(img).to(self.device) / 255. 
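        # Pixel values are now in [0, 1]; the optional normalization below maps them to [-1, 1] for the generator.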
62 | if self.is_norm: 63 | img_t = (img_t - 0.5) / 0.5 64 | img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB 65 | return img_t 66 | 67 | def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8): 68 | if self.is_norm: 69 | img_t = img_t * 0.5 + 0.5 70 | img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR 71 | img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax 72 | 73 | return img_np.astype(imtype) 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-GPENO 2 | 3 | A node for [ComfyUI](https://github.com/comfyanonymous/ComfyUI) that performs [GPEN face restoration](https://github.com/yangxy/GPEN) on the input image(s). The "O" in "GPENO" stands for "Optimized," as I have implemented various performance improvements that significantly boost speed. 4 | 5 | ![workflow_gpeno](workflows/workflow_gpeno.png) 6 | 7 | ### Installation 8 | 9 | Simply drag the image above into ComfyUI and use [ComfyUI Manager » Install Missing Custom Nodes](https://github.com/ltdrdata/ComfyUI-Manager). 10 | 11 | > [!NOTE] 12 | > ComfyUI-GPENO will download GPEN dependencies upon first use. The destinations are `comfyui/models/facerestore_models` and `comfyui/models/facedetection`. 13 | 14 | ### Performance Evaluation 15 | 16 | Face restoration is commonly applied as a post-processing step in faceswap workflows. I ran a few tests against the popular [ReActor node](https://github.com/Gourieff/ComfyUI-ReActor), which includes both face restoration and GPEN features: 17 | 18 | - ReActor end-to-end time using the `GPEN-BFR-512` model for face restoration is about **1.4 seconds** on my GeForce 3090. 19 | - ReActor followed by this GPENO node with the same model takes about **0.7 seconds** - almost exactly 2x speedup for no loss of quality. 20 | - Applying the GPENO node directly to my test images took about **0.5 seconds**. 21 | 22 | Note that your inference time will depend on a number of factors, including input resolution, the number of faces in the image and so on. But you can probably expect 2-3x speedup compared to other implementations of GPEN. 23 | 24 | ### Advantages 25 | 26 | Apart from the speed, here are some other reasons why you might want to use GPENO in your projects: 27 | 28 | 1. It is not coupled with other functions such as faceswap, making it easy to apply to an image without additional processing 29 | 2. It uses a global cache, which helps you save VRAM if you have multiple instances of GPENO in your workflow 30 | 3. It exposes more controls for interfacing with GPEN than you would typically find in an all-in-one node 31 | 32 | ### Inputs 33 | 34 | - `image`: Image or a list of batch images for processing with GPEN 35 | - `use_global_cache`: If enabled, the model will be loaded once and shared across all instances of this node. This saves VRAM if you are using multiple instances of GPENO in your flow, but the settings must remain the same for all instances. 36 | - `unload`: If enabled, the model will be freed from the cache at the start of this node's execution (if applicable), and it will not be cached again. 37 | - Please refer to the GPEN repository for more information on remaining controls 38 | 39 | --- 40 | 41 | This node was adapted from the `[restore_faces]` shortcode of [Unprompted](https://github.com/ThereforeGames/unprompted), my Automatic1111 extension. 
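For the curious, the `use_global_cache` and `unload` inputs roughly correspond to a settings-keyed model cache shared by every GPENO node in the workflow. The sketch below only illustrates the idea; the names are hypothetical and not the node's actual API.

```python
# Illustrative sketch of a settings-keyed global model cache (hypothetical names).
_MODEL_CACHE = {}

def get_model(settings, load_fn, use_global_cache=True, unload=False):
    """Return a model for `settings`, reusing a cached instance when allowed."""
    if unload:
        _MODEL_CACHE.pop(settings, None)  # free the cached model up front...
        return load_fn()                  # ...and skip caching the fresh one
    if not use_global_cache:
        return load_fn()
    if settings not in _MODEL_CACHE:      # only the first instance pays the load cost
        _MODEL_CACHE[settings] = load_fn()
    return _MODEL_CACHE[settings]
```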
-------------------------------------------------------------------------------- /gpeno/face_parse/parse_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Created by chaofengc (chaofenghust@gmail.com) 3 | 4 | @Modified by yangxy (yangtao9009@gmail.com) 5 | ''' 6 | 7 | from blocks import * 8 | import torch 9 | from torch import nn 10 | import numpy as np 11 | 12 | def define_P(in_size=512, out_size=512, min_feat_size=32, relu_type='LeakyReLU', isTrain=False, weight_path=None): 13 | net = ParseNet(in_size, out_size, min_feat_size, 64, 19, norm_type='bn', relu_type=relu_type, ch_range=[32, 256]) 14 | if not isTrain: 15 | net.eval() 16 | if weight_path is not None: 17 | net.load_state_dict(torch.load(weight_path)) 18 | return net 19 | 20 | 21 | class ParseNet(nn.Module): 22 | def __init__(self, 23 | in_size=128, 24 | out_size=128, 25 | min_feat_size=32, 26 | base_ch=64, 27 | parsing_ch=19, 28 | res_depth=10, 29 | relu_type='prelu', 30 | norm_type='bn', 31 | ch_range=[32, 512], 32 | ): 33 | super().__init__() 34 | self.res_depth = res_depth 35 | act_args = {'norm_type': norm_type, 'relu_type': relu_type} 36 | min_ch, max_ch = ch_range 37 | 38 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) 39 | min_feat_size = min(in_size, min_feat_size) 40 | 41 | down_steps = int(np.log2(in_size//min_feat_size)) 42 | up_steps = int(np.log2(out_size//min_feat_size)) 43 | 44 | # =============== define encoder-body-decoder ==================== 45 | self.encoder = [] 46 | self.encoder.append(ConvLayer(3, base_ch, 3, 1)) 47 | head_ch = base_ch 48 | for i in range(down_steps): 49 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) 50 | self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) 51 | head_ch = head_ch * 2 52 | 53 | self.body = [] 54 | for i in range(res_depth): 55 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) 56 | 57 | self.decoder = [] 58 | for i in range(up_steps): 59 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) 60 | self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) 61 | head_ch = head_ch // 2 62 | 63 | self.encoder = nn.Sequential(*self.encoder) 64 | self.body = nn.Sequential(*self.body) 65 | self.decoder = nn.Sequential(*self.decoder) 66 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) 67 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) 68 | 69 | def forward(self, x): 70 | feat = self.encoder(x) 71 | x = feat + self.body(feat) 72 | x = self.decoder(x) 73 | out_img = self.out_img_conv(x) 74 | out_mask = self.out_mask_conv(x) 75 | return out_mask, out_img 76 | 77 | 78 | -------------------------------------------------------------------------------- /gpeno/face_model/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /gpeno/training/loss/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | #from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | from helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 4 | 5 | """ 6 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Backbone(Module): 11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 12 | super(Backbone, self).__init__() 13 | assert input_size in [112, 224], "input_size should be 112 or 224" 14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 16 | blocks = get_blocks(num_layers) 17 | if mode == 'ir': 18 | unit_module = bottleneck_IR 19 | elif mode == 'ir_se': 20 | unit_module = bottleneck_IR_SE 21 | 
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 22 | BatchNorm2d(64), 23 | PReLU(64)) 24 | if input_size == 112: 25 | self.output_layer = Sequential(BatchNorm2d(512), 26 | Dropout(drop_ratio), 27 | Flatten(), 28 | Linear(512 * 7 * 7, 512), 29 | BatchNorm1d(512, affine=affine)) 30 | else: 31 | self.output_layer = Sequential(BatchNorm2d(512), 32 | Dropout(drop_ratio), 33 | Flatten(), 34 | Linear(512 * 14 * 14, 512), 35 | BatchNorm1d(512, affine=affine)) 36 | 37 | modules = [] 38 | for block in blocks: 39 | for bottleneck in block: 40 | modules.append(unit_module(bottleneck.in_channel, 41 | bottleneck.depth, 42 | bottleneck.stride)) 43 | self.body = Sequential(*modules) 44 | 45 | def forward(self, x): 46 | x = self.input_layer(x) 47 | x = self.body(x) 48 | x = self.output_layer(x) 49 | return l2_norm(x) 50 | 51 | 52 | def IR_50(input_size): 53 | """Constructs a ir-50 model.""" 54 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 55 | return model 56 | 57 | 58 | def IR_101(input_size): 59 | """Constructs a ir-101 model.""" 60 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 61 | return model 62 | 63 | 64 | def IR_152(input_size): 65 | """Constructs a ir-152 model.""" 66 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 67 | return model 68 | 69 | 70 | def IR_SE_50(input_size): 71 | """Constructs a ir_se-50 model.""" 72 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 73 | return model 74 | 75 | 76 | def IR_SE_101(input_size): 77 | """Constructs a ir_se-101 model.""" 78 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 79 | return model 80 | 81 | 82 | def IR_SE_152(input_size): 83 | """Constructs a ir_se-152 model.""" 84 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 85 | return model 86 | -------------------------------------------------------------------------------- /gpeno/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 
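    # Serialize the payload with pickle so arbitrary Python objects can be exchanged as byte tensors.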
75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /gpeno/face_model/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load, _import_module_from_library 9 | 10 | # if running GPEN without cuda, please comment line 11-19 11 | if platform.system() == 'Linux' and torch.cuda.is_available(): 12 | module_path = os.path.dirname(__file__) 13 | fused = load( 14 | 'fused', 15 | sources=[ 16 | os.path.join(module_path, 'fused_bias_act.cpp'), 17 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 18 | ], 19 | ) 20 | 21 | 22 | #fused = _import_module_from_library('fused', '/tmp/torch_extensions/fused', True) 23 | 24 | 25 | class FusedLeakyReLUFunctionBackward(Function): 26 | @staticmethod 27 | def forward(ctx, grad_output, out, negative_slope, scale): 28 | ctx.save_for_backward(out) 29 | ctx.negative_slope = negative_slope 30 | ctx.scale = scale 31 | 32 | empty = grad_output.new_empty(0) 33 | 34 | grad_input = fused.fused_bias_act( 35 | grad_output, empty, out, 3, 1, negative_slope, scale 36 | ) 37 | 38 | dim = [0] 39 | 40 | if grad_input.ndim > 2: 41 | dim += list(range(2, grad_input.ndim)) 42 | 43 | grad_bias = grad_input.sum(dim).detach() 44 | 45 | return grad_input, grad_bias 46 | 47 | @staticmethod 48 | def backward(ctx, gradgrad_input, gradgrad_bias): 49 | out, = ctx.saved_tensors 50 | gradgrad_out = fused.fused_bias_act( 51 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 52 | ) 53 | 54 | return gradgrad_out, None, None, None 55 | 56 | 57 | class FusedLeakyReLUFunction(Function): 58 | @staticmethod 59 | def forward(ctx, input, bias, negative_slope, scale): 60 | empty = input.new_empty(0) 61 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 62 | ctx.save_for_backward(out) 63 | 
ctx.negative_slope = negative_slope 64 | ctx.scale = scale 65 | 66 | return out 67 | 68 | @staticmethod 69 | def backward(ctx, grad_output): 70 | out, = ctx.saved_tensors 71 | 72 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 73 | grad_output, out, ctx.negative_slope, ctx.scale 74 | ) 75 | 76 | return grad_input, grad_bias, None, None 77 | 78 | 79 | class FusedLeakyReLU(nn.Module): 80 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, device='cpu'): 81 | super().__init__() 82 | 83 | self.bias = nn.Parameter(torch.zeros(channel)) 84 | self.negative_slope = negative_slope 85 | self.scale = scale 86 | self.device = device 87 | 88 | def forward(self, input): 89 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale, self.device) 90 | 91 | 92 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5, device='cpu'): 93 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu': 94 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 95 | else: 96 | return scale * F.leaky_relu(input + bias.view((1, -1)+(1,)*(len(input.shape)-2)), negative_slope=negative_slope) 97 | -------------------------------------------------------------------------------- /gpeno/face_detect/data/wider_face.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import sys 4 | import torch 5 | import torch.utils.data as data 6 | import cv2 7 | import numpy as np 8 | 9 | class WiderFaceDetection(data.Dataset): 10 | def __init__(self, txt_path, preproc=None): 11 | self.preproc = preproc 12 | self.imgs_path = [] 13 | self.words = [] 14 | f = open(txt_path,'r') 15 | lines = f.readlines() 16 | isFirst = True 17 | labels = [] 18 | for line in lines: 19 | line = line.rstrip() 20 | if line.startswith('#'): 21 | if isFirst is True: 22 | isFirst = False 23 | else: 24 | labels_copy = labels.copy() 25 | self.words.append(labels_copy) 26 | labels.clear() 27 | path = line[2:] 28 | path = txt_path.replace('label.txt','images/') + path 29 | self.imgs_path.append(path) 30 | else: 31 | line = line.split(' ') 32 | label = [float(x) for x in line] 33 | labels.append(label) 34 | 35 | self.words.append(labels) 36 | 37 | def __len__(self): 38 | return len(self.imgs_path) 39 | 40 | def __getitem__(self, index): 41 | img = cv2.imread(self.imgs_path[index]) 42 | height, width, _ = img.shape 43 | 44 | labels = self.words[index] 45 | annotations = np.zeros((0, 15)) 46 | if len(labels) == 0: 47 | return annotations 48 | for idx, label in enumerate(labels): 49 | annotation = np.zeros((1, 15)) 50 | # bbox 51 | annotation[0, 0] = label[0] # x1 52 | annotation[0, 1] = label[1] # y1 53 | annotation[0, 2] = label[0] + label[2] # x2 54 | annotation[0, 3] = label[1] + label[3] # y2 55 | 56 | # landmarks 57 | annotation[0, 4] = label[4] # l0_x 58 | annotation[0, 5] = label[5] # l0_y 59 | annotation[0, 6] = label[7] # l1_x 60 | annotation[0, 7] = label[8] # l1_y 61 | annotation[0, 8] = label[10] # l2_x 62 | annotation[0, 9] = label[11] # l2_y 63 | annotation[0, 10] = label[13] # l3_x 64 | annotation[0, 11] = label[14] # l3_y 65 | annotation[0, 12] = label[16] # l4_x 66 | annotation[0, 13] = label[17] # l4_y 67 | if (annotation[0, 4]<0): 68 | annotation[0, 14] = -1 69 | else: 70 | annotation[0, 14] = 1 71 | 72 | annotations = np.append(annotations, annotation, axis=0) 73 | target = np.array(annotations) 74 | if self.preproc is not None: 75 | img, target = self.preproc(img, target) 76 | 77 
| return torch.from_numpy(img), target 78 | 79 | def detection_collate(batch): 80 | """Custom collate fn for dealing with batches of images that have a different 81 | number of associated object annotations (bounding boxes). 82 | 83 | Arguments: 84 | batch: (tuple) A tuple of tensor images and lists of annotations 85 | 86 | Return: 87 | A tuple containing: 88 | 1) (tensor) batch of images stacked on their 0 dim 89 | 2) (list of tensors) annotations for a given image are stacked on 0 dim 90 | """ 91 | targets = [] 92 | imgs = [] 93 | for _, sample in enumerate(batch): 94 | for _, tup in enumerate(sample): 95 | if torch.is_tensor(tup): 96 | imgs.append(tup) 97 | elif isinstance(tup, type(np.empty(0))): 98 | annos = torch.from_numpy(tup).float() 99 | targets.append(annos) 100 | 101 | return (torch.stack(imgs, 0), targets) 102 | -------------------------------------------------------------------------------- /gpeno/training/data_loader/dataset_face.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import glob 5 | import math 6 | import random 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset 10 | 11 | import degradations 12 | 13 | 14 | class GFPGAN_degradation(object): 15 | def __init__(self): 16 | self.kernel_list = ['iso', 'aniso'] 17 | self.kernel_prob = [0.5, 0.5] 18 | self.blur_kernel_size = 41 19 | self.blur_sigma = [0.1, 10] 20 | self.downsample_range = [0.8, 8] 21 | self.noise_range = [0, 20] 22 | self.jpeg_range = [60, 100] 23 | self.gray_prob = 0.2 24 | self.color_jitter_prob = 0.0 25 | self.color_jitter_pt_prob = 0.0 26 | self.shift = 20/255. 27 | 28 | def degrade_process(self, img_gt): 29 | if random.random() > 0.5: 30 | img_gt = cv2.flip(img_gt, 1) 31 | 32 | h, w = img_gt.shape[:2] 33 | 34 | # random color jitter 35 | if np.random.uniform() < self.color_jitter_prob: 36 | jitter_val = np.random.uniform(-self.shift, self.shift, 3).astype(np.float32) 37 | img_gt = img_gt + jitter_val 38 | img_gt = np.clip(img_gt, 0, 1) 39 | 40 | # random grayscale 41 | if np.random.uniform() < self.gray_prob: 42 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) 43 | img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) 44 | 45 | # ------------------------ generate lq image ------------------------ # 46 | # blur 47 | kernel = degradations.random_mixed_kernels( 48 | self.kernel_list, 49 | self.kernel_prob, 50 | self.blur_kernel_size, 51 | self.blur_sigma, 52 | self.blur_sigma, [-math.pi, math.pi], 53 | noise_range=None) 54 | img_lq = cv2.filter2D(img_gt, -1, kernel) 55 | # downsample 56 | scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) 57 | img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) 58 | 59 | # noise 60 | if self.noise_range is not None: 61 | img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) 62 | # jpeg compression 63 | if self.jpeg_range is not None: 64 | img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) 65 | 66 | # round and clip 67 | img_lq = np.clip((img_lq * 255.0).round(), 0, 255) / 255. 
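        # Quantize to 8-bit levels and rescale to [0, 1], mimicking an image that was saved and reloaded.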
68 | 69 | # resize to original size 70 | img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) 71 | 72 | return img_gt, img_lq 73 | 74 | class FaceDataset(Dataset): 75 | def __init__(self, path, resolution=512): 76 | self.resolution = resolution 77 | 78 | self.HQ_imgs = glob.glob(os.path.join(path, '*.*')) 79 | self.length = len(self.HQ_imgs) 80 | 81 | self.degrader = GFPGAN_degradation() 82 | 83 | def __len__(self): 84 | return self.length 85 | 86 | def __getitem__(self, index): 87 | img_gt = cv2.imread(self.HQ_imgs[index], cv2.IMREAD_COLOR) 88 | img_gt = cv2.resize(img_gt, (self.resolution, self.resolution), interpolation=cv2.INTER_AREA) 89 | 90 | # BFR degradation 91 | # We adopt the degradation of GFPGAN for simplicity, which however differs from our implementation in the paper. 92 | # Data degradation plays a key role in BFR. Please replace it with your own methods. 93 | img_gt = img_gt.astype(np.float32)/255. 94 | img_gt, img_lq = self.degrader.degrade_process(img_gt) 95 | 96 | img_gt = (torch.from_numpy(img_gt) - 0.5) / 0.5 97 | img_lq = (torch.from_numpy(img_lq) - 0.5) / 0.5 98 | 99 | img_gt = img_gt.permute(2, 0, 1).flip(0) # BGR->RGB 100 | img_lq = img_lq.permute(2, 0, 1).flip(0) # BGR->RGB 101 | 102 | return img_lq, img_gt 103 | 104 | -------------------------------------------------------------------------------- /gpeno/training/loss/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /gpeno/face_parse/blocks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | from torch.nn import functional as F 6 | import numpy as np 7 | 8 | class NormLayer(nn.Module): 9 | """Normalization Layers. 10 | ------------ 11 | # Arguments 12 | - channels: input channels, for batch norm and instance norm. 13 | - input_size: input shape without batch size, for layer norm. 
14 | """ 15 | def __init__(self, channels, normalize_shape=None, norm_type='bn', ref_channels=None): 16 | super(NormLayer, self).__init__() 17 | norm_type = norm_type.lower() 18 | self.norm_type = norm_type 19 | if norm_type == 'bn': 20 | self.norm = nn.BatchNorm2d(channels, affine=True) 21 | elif norm_type == 'in': 22 | self.norm = nn.InstanceNorm2d(channels, affine=False) 23 | elif norm_type == 'gn': 24 | self.norm = nn.GroupNorm(32, channels, affine=True) 25 | elif norm_type == 'pixel': 26 | self.norm = lambda x: F.normalize(x, p=2, dim=1) 27 | elif norm_type == 'layer': 28 | self.norm = nn.LayerNorm(normalize_shape) 29 | elif norm_type == 'none': 30 | self.norm = lambda x: x*1.0 31 | else: 32 | assert 1==0, 'Norm type {} not support.'.format(norm_type) 33 | 34 | def forward(self, x, ref=None): 35 | if self.norm_type == 'spade': 36 | return self.norm(x, ref) 37 | else: 38 | return self.norm(x) 39 | 40 | 41 | class ReluLayer(nn.Module): 42 | """Relu Layer. 43 | ------------ 44 | # Arguments 45 | - relu type: type of relu layer, candidates are 46 | - ReLU 47 | - LeakyReLU: default relu slope 0.2 48 | - PRelu 49 | - SELU 50 | - none: direct pass 51 | """ 52 | def __init__(self, channels, relu_type='relu'): 53 | super(ReluLayer, self).__init__() 54 | relu_type = relu_type.lower() 55 | if relu_type == 'relu': 56 | self.func = nn.ReLU(True) 57 | elif relu_type == 'leakyrelu': 58 | self.func = nn.LeakyReLU(0.2, inplace=True) 59 | elif relu_type == 'prelu': 60 | self.func = nn.PReLU(channels) 61 | elif relu_type == 'selu': 62 | self.func = nn.SELU(True) 63 | elif relu_type == 'none': 64 | self.func = lambda x: x*1.0 65 | else: 66 | assert 1==0, 'Relu type {} not support.'.format(relu_type) 67 | 68 | def forward(self, x): 69 | return self.func(x) 70 | 71 | 72 | class ConvLayer(nn.Module): 73 | def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True, bias=True): 74 | super(ConvLayer, self).__init__() 75 | self.use_pad = use_pad 76 | self.norm_type = norm_type 77 | if norm_type in ['bn']: 78 | bias = False 79 | 80 | stride = 2 if scale == 'down' else 1 81 | 82 | self.scale_func = lambda x: x 83 | if scale == 'up': 84 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') 85 | 86 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.)/2))) 87 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 88 | 89 | self.relu = ReluLayer(out_channels, relu_type) 90 | self.norm = NormLayer(out_channels, norm_type=norm_type) 91 | 92 | def forward(self, x): 93 | out = self.scale_func(x) 94 | if self.use_pad: 95 | out = self.reflection_pad(out) 96 | out = self.conv2d(out) 97 | out = self.norm(out) 98 | out = self.relu(out) 99 | return out 100 | 101 | 102 | class ResidualBlock(nn.Module): 103 | """ 104 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html 105 | """ 106 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): 107 | super(ResidualBlock, self).__init__() 108 | 109 | if scale == 'none' and c_in == c_out: 110 | self.shortcut_func = lambda x: x 111 | else: 112 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) 113 | 114 | scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} 115 | scale_conf = scale_config_dict[scale] 116 | 117 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) 118 | self.conv2 = 
ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') 119 | 120 | def forward(self, x): 121 | identity = self.shortcut_func(x) 122 | 123 | res = self.conv1(x) 124 | res = self.conv2(res) 125 | return identity + res 126 | 127 | 128 | -------------------------------------------------------------------------------- /gpeno/face_detect/facemodels/retinaface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.detection.backbone_utils as backbone_utils 4 | import torchvision.models._utils as _utils 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | from facemodels.net import MobileNetV1 as MobileNetV1 9 | from facemodels.net import FPN as FPN 10 | from facemodels.net import SSH as SSH 11 | 12 | 13 | class ClassHead(nn.Module): 14 | 15 | def __init__(self, inchannels=512, num_anchors=3): 16 | super(ClassHead, self).__init__() 17 | self.num_anchors = num_anchors 18 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 19 | 20 | def forward(self, x): 21 | out = self.conv1x1(x) 22 | out = out.permute(0, 2, 3, 1).contiguous() 23 | 24 | return out.view(out.shape[0], -1, 2) 25 | 26 | 27 | class BboxHead(nn.Module): 28 | 29 | def __init__(self, inchannels=512, num_anchors=3): 30 | super(BboxHead, self).__init__() 31 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) 32 | 33 | def forward(self, x): 34 | out = self.conv1x1(x) 35 | out = out.permute(0, 2, 3, 1).contiguous() 36 | 37 | return out.view(out.shape[0], -1, 4) 38 | 39 | 40 | class LandmarkHead(nn.Module): 41 | 42 | def __init__(self, inchannels=512, num_anchors=3): 43 | super(LandmarkHead, self).__init__() 44 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 45 | 46 | def forward(self, x): 47 | out = self.conv1x1(x) 48 | out = out.permute(0, 2, 3, 1).contiguous() 49 | 50 | return out.view(out.shape[0], -1, 10) 51 | 52 | 53 | class RetinaFace(nn.Module): 54 | 55 | def __init__(self, cfg=None, phase='train'): 56 | """ 57 | :param cfg: Network related settings. 58 | :param phase: train or test. 59 | """ 60 | super(RetinaFace, self).__init__() 61 | self.phase = phase 62 | backbone = None 63 | 64 | if cfg['name'] == 'detection_mobilenet0.25_Final': 65 | backbone = MobileNetV1() 66 | if cfg['pretrain']: 67 | checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu')) 68 | from collections import OrderedDict 69 | new_state_dict = OrderedDict() 70 | for k, v in checkpoint['state_dict'].items(): 71 | name = k[7:] # remove module. 
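# i.e. drop the 'module.' prefix that nn.DataParallel prepends to parameter names,
# so the checkpoint keys line up with this single-GPU backbone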
72 | new_state_dict[name] = v 73 | # load params 74 | backbone.load_state_dict(new_state_dict) 75 | elif cfg['name'] == 'detection_Resnet50_Final': 76 | import torchvision.models as models 77 | backbone = models.resnet50(pretrained=cfg['pretrain']) 78 | 79 | self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers']) 80 | in_channels_stage2 = cfg['in_channel'] 81 | in_channels_list = [ 82 | in_channels_stage2 * 2, 83 | in_channels_stage2 * 4, 84 | in_channels_stage2 * 8, 85 | ] 86 | out_channels = cfg['out_channel'] 87 | self.fpn = FPN(in_channels_list, out_channels) 88 | self.ssh1 = SSH(out_channels, out_channels) 89 | self.ssh2 = SSH(out_channels, out_channels) 90 | self.ssh3 = SSH(out_channels, out_channels) 91 | 92 | self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel']) 93 | self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) 94 | self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) 95 | 96 | def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): 97 | classhead = nn.ModuleList() 98 | for i in range(fpn_num): 99 | classhead.append(ClassHead(inchannels, anchor_num)) 100 | return classhead 101 | 102 | def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): 103 | bboxhead = nn.ModuleList() 104 | for i in range(fpn_num): 105 | bboxhead.append(BboxHead(inchannels, anchor_num)) 106 | return bboxhead 107 | 108 | def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): 109 | landmarkhead = nn.ModuleList() 110 | for i in range(fpn_num): 111 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 112 | return landmarkhead 113 | 114 | def forward(self, inputs): 115 | out = self.body(inputs) 116 | 117 | # FPN 118 | fpn = self.fpn(out) 119 | 120 | # SSH 121 | feature1 = self.ssh1(fpn[0]) 122 | feature2 = self.ssh2(fpn[1]) 123 | feature3 = self.ssh3(fpn[2]) 124 | features = [feature1, feature2, feature3] 125 | 126 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) 127 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) 128 | ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1) 129 | 130 | if self.phase == 'train': 131 | output = (bbox_regressions, classifications, ldm_regressions) 132 | else: 133 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) 134 | return output 135 | -------------------------------------------------------------------------------- /gpeno/misc/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from pathlib import Path 4 | import glob 5 | import numpy as np 6 | import cv2 7 | from zipfile import ZipFile 8 | from PIL import Image 9 | import shutil 10 | import cog 11 | from face_enhancement import FaceEnhancement 12 | from face_colorization import FaceColorization 13 | from face_inpainting import FaceInpainting, brush_stroke_mask 14 | 15 | 16 | class Predictor(cog.Predictor): 17 | def setup(self): 18 | faceenhancer_model = {'name': 'GPEN-BFR-256', 'size': 256, 'channel_multiplier': 1, 'narrow': 0.5} 19 | self.faceenhancer = FaceEnhancement(size=faceenhancer_model['size'], model=faceenhancer_model['name'], 20 | channel_multiplier=faceenhancer_model['channel_multiplier'], 21 | narrow=faceenhancer_model['narrow']) 22 | faceinpainter_model = {'name': 'GPEN-Inpainting-1024', 
'size': 1024} 23 | self.faceinpainter = FaceInpainting(size=faceinpainter_model['size'], model=faceinpainter_model['name'], 24 | channel_multiplier=2) 25 | facecolorizer_model = {'name': 'GPEN-Colorization-1024', 'size': 1024} 26 | self.facecolorizer = FaceColorization(size=facecolorizer_model['size'], model=facecolorizer_model['name'], 27 | channel_multiplier=2) 28 | 29 | @cog.input( 30 | "image", 31 | type=Path, 32 | help="input image", 33 | ) 34 | @cog.input( 35 | "task", 36 | type=str, 37 | options=['Face Restoration', 'Face Colorization', 'Face Inpainting'], 38 | default='Face Restoration', 39 | help="choose task type" 40 | ) 41 | @cog.input( 42 | "output_individual", 43 | type=bool, 44 | default=False, 45 | help="whether outputs individual enhanced faces, valid for Face Restoration. When set to true, a zip folder of " 46 | "all the enhanced faces in the input will be generated for download." 47 | ) 48 | @cog.input( 49 | "broken_image", 50 | type=bool, 51 | default=True, 52 | help="whether the input image is broken, valid for Face Inpainting. When set to True, the output will be the " 53 | "'fixed' image. When set to False, the image will randomly add brush strokes to simulate a broken image, " 54 | "and the output will be broken + fixed image" 55 | ) 56 | def predict(self, image, task='Face Restoration', output_individual=False, broken_image=True): 57 | out_path = Path(tempfile.mkdtemp()) / "out.png" 58 | if task == 'Face Restoration': 59 | im = cv2.imread(str(image), cv2.IMREAD_COLOR) # BGR 60 | assert isinstance(im, np.ndarray), 'input filename error' 61 | im = cv2.resize(im, (0, 0), fx=2, fy=2) 62 | img, orig_faces, enhanced_faces = self.faceenhancer.process(im) 63 | cv2.imwrite(str(out_path), img) 64 | if output_individual: 65 | zip_folder = 'out_zip' 66 | os.makedirs(zip_folder, exist_ok=True) 67 | out_path = Path(tempfile.mkdtemp()) / "out.zip" 68 | try: 69 | cv2.imwrite(os.path.join(zip_folder, 'whole_image.jpg'), img) 70 | for m, ef in enumerate(enhanced_faces): 71 | cv2.imwrite(os.path.join(zip_folder, f'face_{m}.jpg'), ef) 72 | img_list = sorted(glob.glob(os.path.join(zip_folder, '*'))) 73 | with ZipFile(str(out_path), 'w') as zipfile: 74 | for img in img_list: 75 | zipfile.write(img) 76 | finally: 77 | clean_folder(zip_folder) 78 | elif task == 'Face Colorization': 79 | grayf = cv2.imread(str(image), cv2.IMREAD_GRAYSCALE) 80 | grayf = cv2.cvtColor(grayf, cv2.COLOR_GRAY2BGR) # channel: 1->3 81 | colorf = self.facecolorizer.process(grayf) 82 | cv2.imwrite(str(out_path), colorf) 83 | else: 84 | originf = cv2.imread(str(image), cv2.IMREAD_COLOR) 85 | brokenf = originf 86 | if not broken_image: 87 | brokenf = np.asarray(brush_stroke_mask(Image.fromarray(originf))) 88 | completef = self.faceinpainter.process(brokenf) 89 | brokenf = cv2.resize(brokenf, completef.shape[:2]) 90 | out_img = completef if broken_image else np.hstack((brokenf, completef)) 91 | cv2.imwrite(str(out_path), out_img) 92 | 93 | return out_path 94 | 95 | 96 | def clean_folder(folder): 97 | for filename in os.listdir(folder): 98 | file_path = os.path.join(folder, filename) 99 | try: 100 | if os.path.isfile(file_path) or os.path.islink(file_path): 101 | os.unlink(file_path) 102 | elif os.path.isdir(file_path): 103 | shutil.rmtree(file_path) 104 | except Exception as e: 105 | print('Failed to delete %s. 
Reason: %s' % (file_path, e)) -------------------------------------------------------------------------------- /gpeno/face_detect/retinaface_detection.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import os 6 | import torch 7 | import numpy as np 8 | from data import cfg_re50, cfg_mnet 9 | from layers.functions.prior_box import PriorBox 10 | from gpeno.face_detect.utils.nms.py_cpu_nms import py_cpu_nms 11 | import cv2 12 | from facemodels.retinaface import RetinaFace 13 | from gpeno.face_detect.utils.box_utils import decode, decode_landm 14 | import torch.nn.functional as F 15 | 16 | 17 | class RetinaFaceDetection(object): 18 | 19 | def __init__(self, base_dir, device='cuda', network='RetinaFace-R50'): 20 | torch.set_grad_enabled(False) 21 | print(f"Initializing RetinaFaceDetection on device {device}...") 22 | self.pretrained_path = os.path.join(base_dir, 'facedetection', network + '.pth') 23 | self.device = device 24 | if network == "detection_Resnet50_Final": 25 | self.cfg = cfg_re50 26 | else: 27 | self.cfg = cfg_mnet 28 | self.net = RetinaFace(cfg=self.cfg, phase='test') 29 | self.net = self.net.to(self.device) 30 | 31 | self.load_model() 32 | 33 | self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) 34 | 35 | def check_keys(self, pretrained_state_dict): 36 | ckpt_keys = set(pretrained_state_dict.keys()) 37 | model_keys = set(self.net.state_dict().keys()) 38 | used_pretrained_keys = model_keys & ckpt_keys 39 | unused_pretrained_keys = ckpt_keys - model_keys 40 | missing_keys = model_keys - ckpt_keys 41 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 42 | return True 43 | 44 | def remove_prefix(self, state_dict, prefix): 45 | ''' Old style model is stored with all names of parameters sharing common prefix 'module.' 
''' 46 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 47 | return {f(key): value for key, value in state_dict.items()} 48 | 49 | def load_model(self): 50 | pretrained_dict = torch.load(self.pretrained_path, map_location=self.device) 51 | 52 | if "state_dict" in pretrained_dict.keys(): 53 | pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.') 54 | else: 55 | pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') 56 | # self.check_keys(pretrained_dict) 57 | self.net.load_state_dict(pretrained_dict, strict=False) 58 | self.net.eval() 59 | 60 | print("Finished loading model") 61 | 62 | def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False): 63 | img = np.float32(img_raw) 64 | 65 | im_height, im_width = img.shape[:2] 66 | ss = 1.0 67 | # tricky 68 | if max(im_height, im_width) > 1500: 69 | ss = 1000.0 / max(im_height, im_width) 70 | img = cv2.resize(img, (0, 0), fx=ss, fy=ss) 71 | im_height, im_width = img.shape[:2] 72 | 73 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) 74 | img -= (104, 117, 123) 75 | img = img.transpose(2, 0, 1) 76 | img = torch.from_numpy(img).unsqueeze(0) 77 | img = img.to(self.device) 78 | scale = scale.to(self.device) 79 | 80 | loc, conf, landms = self.net(img) # forward pass 81 | 82 | del img 83 | 84 | priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) 85 | priors = priorbox.forward() 86 | priors = priors.to(self.device) 87 | prior_data = priors.data 88 | boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) 89 | boxes = boxes * scale / resize 90 | boxes = boxes.cpu().numpy() 91 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1] 92 | landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance']) 93 | scale1 = torch.Tensor([im_width, im_height, im_width, im_height, im_width, im_height, im_width, im_height, im_width, im_height]) 94 | scale1 = scale1.to(self.device) 95 | landms = landms * scale1 / resize 96 | landms = landms.cpu().numpy() 97 | 98 | # ignore low scores 99 | inds = np.where(scores > confidence_threshold)[0] 100 | boxes = boxes[inds] 101 | landms = landms[inds] 102 | scores = scores[inds] 103 | 104 | # keep top-K before NMS 105 | order = scores.argsort()[::-1][:top_k] 106 | boxes = boxes[order] 107 | landms = landms[order] 108 | scores = scores[order] 109 | 110 | # do NMS 111 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 112 | keep = py_cpu_nms(dets, nms_threshold) 113 | # keep = nms(dets, nms_threshold,force_cpu=args.cpu) 114 | dets = dets[keep, :] 115 | landms = landms[keep] 116 | 117 | # keep top-K faster NMS 118 | dets = dets[:keep_top_k, :] 119 | landms = landms[:keep_top_k, :] 120 | 121 | # sort faces(delete) 122 | ''' 123 | fscores = [det[4] for det in dets] 124 | sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index 125 | tmp = [landms[idx] for idx in sorted_idx] 126 | landms = np.asarray(tmp) 127 | ''' 128 | 129 | landms = landms.reshape((-1, 5, 2)) 130 | landms = landms.transpose((0, 2, 1)) 131 | landms = landms.reshape( 132 | -1, 133 | 10, 134 | ) 135 | return dets / ss, landms / ss 136 | -------------------------------------------------------------------------------- /gpeno/face_detect/facemodels/net.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import 
torchvision.models._utils as _utils 5 | import torchvision.models as models 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | def conv_bn(inp, oup, stride = 1, leaky = 0): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 14 | ) 15 | 16 | def conv_bn_no_relu(inp, oup, stride): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 19 | nn.BatchNorm2d(oup), 20 | ) 21 | 22 | def conv_bn1X1(inp, oup, stride, leaky=0): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), 25 | nn.BatchNorm2d(oup), 26 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 27 | ) 28 | 29 | def conv_dw(inp, oup, stride, leaky=0.1): 30 | return nn.Sequential( 31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 32 | nn.BatchNorm2d(inp), 33 | nn.LeakyReLU(negative_slope= leaky,inplace=True), 34 | 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.LeakyReLU(negative_slope= leaky,inplace=True), 38 | ) 39 | 40 | class SSH(nn.Module): 41 | def __init__(self, in_channel, out_channel): 42 | super(SSH, self).__init__() 43 | assert out_channel % 4 == 0 44 | leaky = 0 45 | if (out_channel <= 64): 46 | leaky = 0.1 47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1) 48 | 49 | self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky) 50 | self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) 51 | 52 | self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky) 53 | self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) 54 | 55 | def forward(self, input): 56 | conv3X3 = self.conv3X3(input) 57 | 58 | conv5X5_1 = self.conv5X5_1(input) 59 | conv5X5 = self.conv5X5_2(conv5X5_1) 60 | 61 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 62 | conv7X7 = self.conv7x7_3(conv7X7_2) 63 | 64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 65 | out = F.relu(out) 66 | return out 67 | 68 | class FPN(nn.Module): 69 | def __init__(self,in_channels_list,out_channels): 70 | super(FPN,self).__init__() 71 | leaky = 0 72 | if (out_channels <= 64): 73 | leaky = 0.1 74 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky) 75 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky) 76 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky) 77 | 78 | self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky) 79 | self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky) 80 | 81 | def forward(self, input): 82 | # names = list(input.keys()) 83 | input = list(input.values()) 84 | 85 | output1 = self.output1(input[0]) 86 | output2 = self.output2(input[1]) 87 | output3 = self.output3(input[2]) 88 | 89 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") 90 | output2 = output2 + up3 91 | output2 = self.merge2(output2) 92 | 93 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") 94 | output1 = output1 + up2 95 | output1 = self.merge1(output1) 96 | 97 | out = [output1, output2, output3] 98 | return out 99 | 100 | 101 | 102 | class MobileNetV1(nn.Module): 103 | def __init__(self): 104 | super(MobileNetV1, self).__init__() 105 | self.stage1 = nn.Sequential( 106 | conv_bn(3, 8, 2, leaky = 0.1), # 3 107 | conv_dw(8, 16, 1), # 7 108 | conv_dw(16, 32, 2), 
# 11 109 | conv_dw(32, 32, 1), # 19 110 | conv_dw(32, 64, 2), # 27 111 | conv_dw(64, 64, 1), # 43 112 | ) 113 | self.stage2 = nn.Sequential( 114 | conv_dw(64, 128, 2), # 43 + 16 = 59 115 | conv_dw(128, 128, 1), # 59 + 32 = 91 116 | conv_dw(128, 128, 1), # 91 + 32 = 123 117 | conv_dw(128, 128, 1), # 123 + 32 = 155 118 | conv_dw(128, 128, 1), # 155 + 32 = 187 119 | conv_dw(128, 128, 1), # 187 + 32 = 219 120 | ) 121 | self.stage3 = nn.Sequential( 122 | conv_dw(128, 256, 2), # 219 +3 2 = 241 123 | conv_dw(256, 256, 1), # 241 + 64 = 301 124 | ) 125 | self.avg = nn.AdaptiveAvgPool2d((1,1)) 126 | self.fc = nn.Linear(256, 1000) 127 | 128 | def forward(self, x): 129 | x = self.stage1(x) 130 | x = self.stage2(x) 131 | x = self.stage3(x) 132 | x = self.avg(x) 133 | # x = self.model(x) 134 | x = x.view(-1, 256) 135 | x = self.fc(x) 136 | return x 137 | 138 | -------------------------------------------------------------------------------- /gpeno/face_detect/layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from gpeno.face_detect.utils.box_utils import match, log_sum_exp 6 | from data import cfg_mnet 7 | 8 | GPU = cfg_mnet['gpu_train'] 9 | 10 | 11 | class MultiBoxLoss(nn.Module): 12 | """SSD Weighted Loss Function 13 | Compute Targets: 14 | 1) Produce Confidence Target Indices by matching ground truth boxes 15 | with (default) 'priorboxes' that have jaccard index > threshold parameter 16 | (default threshold: 0.5). 17 | 2) Produce localization target by 'encoding' variance into offsets of ground 18 | truth boxes and their matched 'priorboxes'. 19 | 3) Hard negative mining to filter the excessive number of negative examples 20 | that comes with using a large number of default bounding boxes. 21 | (default negative:positive ratio 3:1) 22 | Objective Loss: 23 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 24 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 25 | weighted by α which is set to 1 by cross val. 26 | Args: 27 | c: class confidences, 28 | l: predicted boxes, 29 | g: ground truth boxes 30 | N: number of matched default boxes 31 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 32 | """ 33 | 34 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target): 35 | super(MultiBoxLoss, self).__init__() 36 | self.num_classes = num_classes 37 | self.threshold = overlap_thresh 38 | self.background_label = bkg_label 39 | self.encode_target = encode_target 40 | self.use_prior_for_matching = prior_for_matching 41 | self.do_neg_mining = neg_mining 42 | self.negpos_ratio = neg_pos 43 | self.neg_overlap = neg_overlap 44 | self.variance = [0.1, 0.2] 45 | 46 | def forward(self, predictions, priors, targets): 47 | """Multibox Loss 48 | Args: 49 | predictions (tuple): A tuple containing loc preds, conf preds, 50 | and prior boxes from SSD net. 51 | conf shape: torch.size(batch_size,num_priors,num_classes) 52 | loc shape: torch.size(batch_size,num_priors,4) 53 | priors shape: torch.size(num_priors,4) 54 | 55 | ground_truth (tensor): Ground truth boxes and labels for a batch, 56 | shape: [batch_size,num_objs,5] (last idx is the label). 
57 | """ 58 | 59 | loc_data, conf_data, landm_data = predictions 60 | priors = priors 61 | num = loc_data.size(0) 62 | num_priors = (priors.size(0)) 63 | 64 | # match priors (default boxes) and ground truth boxes 65 | loc_t = torch.Tensor(num, num_priors, 4) 66 | landm_t = torch.Tensor(num, num_priors, 10) 67 | conf_t = torch.LongTensor(num, num_priors) 68 | for idx in range(num): 69 | truths = targets[idx][:, :4].data 70 | labels = targets[idx][:, -1].data 71 | landms = targets[idx][:, 4:14].data 72 | defaults = priors.data 73 | match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx) 74 | if GPU: 75 | loc_t = loc_t.cuda() 76 | conf_t = conf_t.cuda() 77 | landm_t = landm_t.cuda() 78 | 79 | zeros = torch.tensor(0).cuda() 80 | # landm Loss (Smooth L1) 81 | # Shape: [batch,num_priors,10] 82 | pos1 = conf_t > zeros 83 | num_pos_landm = pos1.long().sum(1, keepdim=True) 84 | N1 = max(num_pos_landm.data.sum().float(), 1) 85 | pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data) 86 | landm_p = landm_data[pos_idx1].view(-1, 10) 87 | landm_t = landm_t[pos_idx1].view(-1, 10) 88 | loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum') 89 | 90 | pos = conf_t != zeros 91 | conf_t[pos] = 1 92 | 93 | # Localization Loss (Smooth L1) 94 | # Shape: [batch,num_priors,4] 95 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 96 | loc_p = loc_data[pos_idx].view(-1, 4) 97 | loc_t = loc_t[pos_idx].view(-1, 4) 98 | loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') 99 | 100 | # Compute max conf across batch for hard negative mining 101 | batch_conf = conf_data.view(-1, self.num_classes) 102 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 103 | 104 | # Hard Negative Mining 105 | loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now 106 | loss_c = loss_c.view(num, -1) 107 | _, loss_idx = loss_c.sort(1, descending=True) 108 | _, idx_rank = loss_idx.sort(1) 109 | num_pos = pos.long().sum(1, keepdim=True) 110 | num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1) 111 | neg = idx_rank < num_neg.expand_as(idx_rank) 112 | 113 | # Confidence Loss Including Positive and Negative Examples 114 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 115 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 116 | conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes) 117 | targets_weighted = conf_t[(pos + neg).gt(0)] 118 | loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') 119 | 120 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 121 | N = max(num_pos.data.sum().float(), 1) 122 | loss_l /= N 123 | loss_c /= N 124 | loss_landm /= N1 125 | 126 | return loss_l, loss_c, loss_landm 127 | -------------------------------------------------------------------------------- /gpeno/face_enhancement.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from gpeno.face_detect.retinaface_detection import RetinaFaceDetection 5 | # from gpeno.face_detect.batch_face import RetinaFace 6 | from gpeno.face_parse.face_parsing import FaceParse 7 | from gpeno.face_model.face_gan import FaceGAN 8 | # from gpen.sr_model.real_esrnet import RealESRNet 9 | from gpeno.align_faces import warp_and_crop_face, get_reference_facial_points 10 | 11 | 12 | class FaceEnhancement(object): 13 | 14 | def __init__(self, args, base_dir='./', in_size=512, out_size=None, model=None, use_sr=True, device='cuda', interp=3, 
backbone="RetinaFace-R50", log=None, colorize=False): 15 | self.log = log 16 | # self.log.debug("Initializing FaceEnhancement...") 17 | self.facedetector = RetinaFaceDetection(base_dir, device, network=backbone) 18 | # self.facedetector = RetinaFace() 19 | self.facegan = FaceGAN(base_dir, in_size, out_size, model, args.channel_multiplier, args.narrow, args.key, device=device) 20 | # self.srmodel = RealESRNet(base_dir, args.sr_model, args.sr_scale, args.tile_size, device=device) 21 | self.faceparser = FaceParse(base_dir, device=device) 22 | self.use_sr = use_sr 23 | self.in_size = in_size 24 | self.out_size = in_size if out_size is None else out_size 25 | self.threshold = 0.9 26 | self.alpha = args.alpha 27 | self.interp = interp 28 | self.colorize = colorize 29 | 30 | if self.colorize: 31 | self.colorizer = FaceGAN(base_dir, 1024, 1024, "GPEN-Colorization-1024", args.channel_multiplier, args.narrow, None, device) 32 | 33 | self.mask = np.zeros((512, 512), np.float32) 34 | cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA) 35 | self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) 36 | 37 | self.kernel = np.array(([0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]), dtype="float32") 38 | self.reference_5pts = get_reference_facial_points((self.in_size, self.in_size), 0.25, (0, 0), True) 39 | 40 | def mask_postprocess(self, mask, thres=26): 41 | mask[:thres, :] = 0 42 | mask[-thres:, :] = 0 43 | mask[:, :thres] = 0 44 | mask[:, -thres:] = 0 45 | mask = cv2.GaussianBlur(mask, (101, 101), 4) 46 | return mask.astype(np.float32) 47 | 48 | def colorize_face(self, face): 49 | # Convert BGR (OpenCV) to RGB before colorizing 50 | rgb_input = cv2.cvtColor(face, cv2.COLOR_BGR2RGB) 51 | color_face = self.colorizer.process(rgb_input) 52 | color_face = cv2.cvtColor(color_face, cv2.COLOR_RGB2BGR) 53 | 54 | if face.shape[:2] != color_face.shape[:2]: 55 | out_rs = cv2.resize(color_face, face.shape[:2][::-1]) 56 | gray_yuv = cv2.cvtColor(face, cv2.COLOR_BGR2YUV) 57 | out_yuv = cv2.cvtColor(out_rs, cv2.COLOR_BGR2YUV) 58 | 59 | out_yuv[:, :, 0] = gray_yuv[:, :, 0] 60 | color_face = cv2.cvtColor(out_yuv, cv2.COLOR_YUV2BGR) 61 | 62 | return color_face 63 | 64 | def process(self, img, aligned=False): 65 | orig_faces, enhanced_faces = [], [] 66 | if aligned: 67 | print("Aligned is true") 68 | ef = self.facegan.process(img) 69 | if self.colorize: 70 | ef = self.colorize_face(ef) 71 | orig_faces.append(img) 72 | enhanced_faces.append(ef) 73 | 74 | # if self.use_sr: 75 | # ef = self.srmodel.process(ef) 76 | 77 | return ef, orig_faces, enhanced_faces 78 | 79 | # if self.use_sr: 80 | # img_sr = self.srmodel.process(img) 81 | # if img_sr is not None: 82 | # img = cv2.resize(img, img_sr.shape[:2][::-1]) 83 | 84 | with torch.no_grad(): 85 | print("Starting face detection") 86 | facebs, landms = self.facedetector.detect(img) 87 | # faces = self.facedetector.detect(img) 88 | 89 | # self.log.debug("Face detection complete") 90 | print("Face detection complete") 91 | 92 | height, width = img.shape[:2] 93 | full_mask = np.zeros((height, width), dtype=np.float32) 94 | full_img = np.zeros(img.shape, dtype=np.uint8) 95 | 96 | # for i, (faceb, facial5points, score) in enumerate(faces): 97 | for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): 98 | if faceb[4] < self.threshold: 99 | continue 100 | fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0]) 101 | 102 | facial5points = np.reshape(facial5points, (2, 5)) 103 | 104 | print("Starting face alignment...") 105 | of, tfm_inv 
= warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts, crop_size=(self.in_size, self.in_size)) 106 | 107 | print("Face alignment complete") 108 | 109 | # enhance the face 110 | ef = self.facegan.process(of) 111 | 112 | if self.colorize: 113 | ef = self.colorize_face(ef) 114 | 115 | # self.log.debug("Face GAN complete") 116 | print("Face GAN complete") 117 | 118 | orig_faces.append(of) 119 | enhanced_faces.append(ef) 120 | 121 | tmp_mask = self.mask 122 | tmp_mask = self.mask_postprocess(self.faceparser.process(ef)[0] / 255.) 123 | # self.log.debug("Mask postprocessing complete") 124 | tmp_mask = cv2.resize(tmp_mask, (self.in_size, self.in_size), interpolation=self.interp) 125 | 126 | tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=self.interp) 127 | 128 | # from PIL import Image 129 | 130 | # tmp_mask_pil = Image.fromarray(tmp_mask) 131 | # Apply the inverse affine transformation using PIL 132 | # tmp_mask_pil = tmp_mask_pil.transform((width, height), Image.AFFINE, tuple(tfm_inv.flatten()), resample=self.interp) 133 | # tmp_mask = np.array(tmp_mask_pil) 134 | 135 | if min(fh, fw) < 100: # gaussian filter for small faces 136 | ef = cv2.filter2D(ef, -1, self.kernel) 137 | 138 | ef = cv2.addWeighted(ef, self.alpha, of, 1. - self.alpha, 0.0) 139 | 140 | if self.in_size != self.out_size: 141 | ef = cv2.resize(ef, (self.in_size, self.in_size), interpolation=self.interp) 142 | tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=self.interp) 143 | 144 | mask = tmp_mask - full_mask 145 | full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] 146 | full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] 147 | 148 | full_mask = full_mask[:, :, np.newaxis] 149 | # if self.use_sr and img_sr is not None: 150 | # img = cv2.convertScaleAbs(img_sr * (1 - full_mask) + full_img * full_mask) 151 | # else: 152 | img = cv2.convertScaleAbs(img * (1 - full_mask) + full_img * full_mask) 153 | 154 | print("Postprocessing complete") 155 | 156 | return img, orig_faces, enhanced_faces 157 | -------------------------------------------------------------------------------- /gpeno/misc/onnx_export.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | from torch.autograd import Variable 9 | 10 | import __init_paths 11 | #from face_model import model 12 | from face_model.gpen_model import FullGenerator 13 | 14 | def model_load(model, path): 15 | """Load model.""" 16 | 17 | if not os.path.exists(path): 18 | print("Model '{}' does not exist.".format(path)) 19 | return 20 | 21 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 22 | 23 | model.load_state_dict(state_dict) 24 | 25 | 26 | def export_onnx(model, path, force_cpu): 27 | """Export onnx model.""" 28 | 29 | import onnx 30 | import onnxruntime 31 | #from onnx import optimizer 32 | import numpy as np 33 | 34 | onnx_file_name = os.path.join(path, model+".onnx") 35 | model_weight_file = os.path.join(path, model+".pth") 36 | dummy_input = Variable(torch.randn(1, 3, 1024, 1024)) 37 | 38 | # 1. Create and load model. 39 | model_setenv(force_cpu) 40 | torch_model = get_model(model_weight_file) 41 | torch_model.eval() 42 | 43 | # 2. 
Model export 44 | print("Export model ...") 45 | 46 | input_names = ["input"] 47 | output_names = ["output"] 48 | device = model_device() 49 | # torch.onnx.export(torch_model, dummy_input.to(device), onnx_file_name, 50 | # input_names=input_names, 51 | # output_names=output_names, 52 | # verbose=False, 53 | # opset_version=12, 54 | # keep_initializers_as_inputs=False, 55 | # export_params=True, 56 | # operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK) 57 | torch.onnx.export(torch_model, dummy_input.to(device), onnx_file_name, 58 | input_names=input_names, 59 | output_names=output_names, 60 | verbose=False, 61 | opset_version=10, 62 | operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK) 63 | 64 | # 3. Optimize model 65 | print('Checking model ...') 66 | onnx_model = onnx.load(onnx_file_name) 67 | onnx.checker.check_model(onnx_model) 68 | # https://github.com/onnx/optimizer 69 | print('Done checking model ...') 70 | # 4. Visual model 71 | # python -c "import netron; netron.start('output/image_zoom.onnx')" 72 | 73 | def verify_onnx(model, path, force_cpu): 74 | """Verify onnx model.""" 75 | 76 | import onnxruntime 77 | import numpy as np 78 | 79 | 80 | model_weight_file = os.path.join(path, model+".pth") 81 | 82 | model_weight_file = "./weights/GPEN-512.pth" 83 | 84 | model_setenv(force_cpu) 85 | torch_model = get_model(model_weight_file) 86 | torch_model.eval() 87 | 88 | onnx_file_name = os.path.join(path, model+".onnx") 89 | onnxruntime_engine = onnxruntime.InferenceSession(onnx_file_name) 90 | 91 | def to_numpy(tensor): 92 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() 93 | 94 | dummy_input = Variable(torch.randn(1, 3, 512, 512)) 95 | with torch.no_grad(): 96 | torch_output, _ = torch_model(dummy_input) 97 | onnxruntime_inputs = {onnxruntime_engine.get_inputs()[0].name: to_numpy(dummy_input)} 98 | onnxruntime_outputs = onnxruntime_engine.run(None, onnxruntime_inputs) 99 | np.testing.assert_allclose(to_numpy(torch_output), onnxruntime_outputs[0], rtol=1e-02, atol=1e-02) 100 | print("Example: Onnx model has been tested with ONNXRuntime, the result looks good !") 101 | 102 | def get_model(checkpoint): 103 | """Create encoder model.""" 104 | 105 | #model_setenv() 106 | model = FullGenerator(1024, 512, 8, 2, narrow=1) #TODO 107 | model_load(model, checkpoint) 108 | device = model_device() 109 | model.to(device) 110 | return model 111 | 112 | 113 | def model_device(): 114 | """Please call after model_setenv. """ 115 | 116 | return torch.device(os.environ["DEVICE"]) 117 | 118 | 119 | def model_setenv(cpu_only): 120 | """Setup environ ...""" 121 | 122 | # random init ... 
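# fixed seeds make torch.randn() calls issued after this point deterministic
# (e.g. the dummy input built in verify_onnx), so repeated verification runs
# compare the same tensors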
123 | import random 124 | random.seed(42) 125 | torch.manual_seed(42) 126 | 127 | # Set default device to avoid exceptions 128 | if cpu_only: 129 | os.environ["DEVICE"] = 'cpu' 130 | else: 131 | if os.environ.get("DEVICE") != "cuda" and os.environ.get("DEVICE") != "cpu": 132 | os.environ["DEVICE"] = 'cuda' if torch.cuda.is_available() else 'cpu' 133 | if os.environ["DEVICE"] == 'cuda': 134 | torch.backends.cudnn.enabled = True 135 | torch.backends.cudnn.benchmark = True 136 | 137 | print("Running Environment:") 138 | print("----------------------------------------------") 139 | #print(" PWD: ", os.environ["PWD"]) 140 | print(" DEVICE: ", os.environ["DEVICE"]) 141 | 142 | # def export_torch(model, path): 143 | # """Export torch model.""" 144 | 145 | # script_file = os.path.join(path, model+".pt") 146 | # weight_file = os.path.join(path, model+".onnx") 147 | 148 | # # 1. Load model 149 | # print("Loading model ...") 150 | # model = get_model(weight_file) 151 | # model.eval() 152 | 153 | # # 2. Model export 154 | # print("Export model ...") 155 | # dummy_input = Variable(torch.randn(1, 3, 512, 512)) 156 | # device = model_device() 157 | # traced_script_module = torch.jit.trace(model, dummy_input.to(device), _force_outplace=True) 158 | # traced_script_module.save(script_file) 159 | 160 | 161 | if __name__ == '__main__': 162 | """Test model ...""" 163 | import argparse 164 | 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument('--model', type=str, required=True) 167 | parser.add_argument('--path', type=str, default='./') 168 | parser.add_argument('--export', help="Export onnx model", action='store_true') 169 | parser.add_argument('--verify', help="Verify onnx model", action='store_true') 170 | parser.add_argument('--force-cpu', dest='force_cpu', help="Verify onnx model", action='store_true') 171 | 172 | args = parser.parse_args() 173 | 174 | # export_torch() 175 | 176 | 177 | 178 | if args.export: 179 | export_onnx(model = args.model, path = args.path, force_cpu=args.force_cpu) 180 | 181 | if args.verify: 182 | verify_onnx(model = args.model, path = args.path, force_cpu=args.force_cpu) 183 | -------------------------------------------------------------------------------- /gpeno/face_model/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load, _import_module_from_library 8 | 9 | # if running GPEN without cuda, please comment line 10-18 10 | if platform.system() == 'Linux' and torch.cuda.is_available(): 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_op = load( 13 | 'upfirdn2d', 14 | sources=[ 15 | os.path.join(module_path, 'upfirdn2d.cpp'), 16 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 17 | ], 18 | ) 19 | 20 | 21 | #upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True) 22 | 23 | class UpFirDn2dBackward(Function): 24 | @staticmethod 25 | def forward( 26 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 27 | ): 28 | 29 | up_x, up_y = up 30 | down_x, down_y = down 31 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 32 | 33 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 34 | 35 | grad_input = upfirdn2d_op.upfirdn2d( 36 | grad_output, 37 | grad_kernel, 38 | down_x, 39 | down_y, 40 | up_x, 41 | up_y, 42 | g_pad_x0, 43 | g_pad_x1, 44 | g_pad_y0, 45 | 
g_pad_y1, 46 | ) 47 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 48 | 49 | ctx.save_for_backward(kernel) 50 | 51 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 52 | 53 | ctx.up_x = up_x 54 | ctx.up_y = up_y 55 | ctx.down_x = down_x 56 | ctx.down_y = down_y 57 | ctx.pad_x0 = pad_x0 58 | ctx.pad_x1 = pad_x1 59 | ctx.pad_y0 = pad_y0 60 | ctx.pad_y1 = pad_y1 61 | ctx.in_size = in_size 62 | ctx.out_size = out_size 63 | 64 | return grad_input 65 | 66 | @staticmethod 67 | def backward(ctx, gradgrad_input): 68 | kernel, = ctx.saved_tensors 69 | 70 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 71 | 72 | gradgrad_out = upfirdn2d_op.upfirdn2d( 73 | gradgrad_input, 74 | kernel, 75 | ctx.up_x, 76 | ctx.up_y, 77 | ctx.down_x, 78 | ctx.down_y, 79 | ctx.pad_x0, 80 | ctx.pad_x1, 81 | ctx.pad_y0, 82 | ctx.pad_y1, 83 | ) 84 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 85 | gradgrad_out = gradgrad_out.view( 86 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 87 | ) 88 | 89 | return gradgrad_out, None, None, None, None, None, None, None, None 90 | 91 | 92 | class UpFirDn2d(Function): 93 | @staticmethod 94 | def forward(ctx, input, kernel, up, down, pad): 95 | up_x, up_y = up 96 | down_x, down_y = down 97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 98 | 99 | kernel_h, kernel_w = kernel.shape 100 | batch, channel, in_h, in_w = input.shape 101 | ctx.in_size = input.shape 102 | 103 | input = input.reshape(-1, in_h, in_w, 1) 104 | 105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 106 | 107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 109 | ctx.out_size = (out_h, out_w) 110 | 111 | ctx.up = (up_x, up_y) 112 | ctx.down = (down_x, down_y) 113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 114 | 115 | g_pad_x0 = kernel_w - pad_x0 - 1 116 | g_pad_y0 = kernel_h - pad_y0 - 1 117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 119 | 120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 121 | 122 | out = upfirdn2d_op.upfirdn2d( 123 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 124 | ) 125 | # out = out.view(major, out_h, out_w, minor) 126 | out = out.view(-1, channel, out_h, out_w) 127 | 128 | return out 129 | 130 | @staticmethod 131 | def backward(ctx, grad_output): 132 | kernel, grad_kernel = ctx.saved_tensors 133 | 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), device='cpu'): 150 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu': 151 | out = UpFirDn2d.apply( 152 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 153 | ) 154 | else: 155 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 156 | 157 | return out 158 | 159 | 160 | def upfirdn2d_native( 161 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 162 | ): 163 | input = input.permute(0, 2, 3, 1) 164 | _, in_h, in_w, minor = input.shape 165 | kernel_h, kernel_w = kernel.shape 166 | out = input.view(-1, in_h, 1, in_w, 1, 
minor) 167 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 168 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 169 | 170 | out = F.pad( 171 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 172 | ) 173 | out = out[ 174 | :, 175 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 176 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 177 | :, 178 | ] 179 | 180 | out = out.permute(0, 3, 1, 2) 181 | out = out.reshape( 182 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 183 | ) 184 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 185 | out = F.conv2d(out, w) 186 | out = out.reshape( 187 | -1, 188 | minor, 189 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 190 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 191 | ) 192 | # out = out.permute(0, 2, 3, 1) 193 | return out[:, :, ::down_y, ::down_x] 194 | 195 | -------------------------------------------------------------------------------- /gpeno/demo.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) 3 | @author: yangxy (yangtao9009@gmail.com) 4 | ''' 5 | import os 6 | import cv2 7 | import glob 8 | import time 9 | import math 10 | import argparse 11 | import numpy as np 12 | from PIL import Image, ImageDraw 13 | import __init_paths 14 | from face_enhancement import FaceEnhancement 15 | from face_colorization import FaceColorization 16 | from face_inpainting import FaceInpainting 17 | from segmentation2face import Segmentation2Face 18 | 19 | def brush_stroke_mask(img, color=(255,255,255)): 20 | min_num_vertex = 8 21 | max_num_vertex = 28 22 | mean_angle = 2*math.pi / 5 23 | angle_range = 2*math.pi / 15 24 | min_width = 12 25 | max_width = 80 26 | def generate_mask(H, W, img=None): 27 | average_radius = math.sqrt(H*H+W*W) / 8 28 | mask = Image.new('RGB', (W, H), 0) 29 | if img is not None: mask = img #Image.fromarray(img) 30 | 31 | for _ in range(np.random.randint(1, 4)): 32 | num_vertex = np.random.randint(min_num_vertex, max_num_vertex) 33 | angle_min = mean_angle - np.random.uniform(0, angle_range) 34 | angle_max = mean_angle + np.random.uniform(0, angle_range) 35 | angles = [] 36 | vertex = [] 37 | for i in range(num_vertex): 38 | if i % 2 == 0: 39 | angles.append(2*math.pi - np.random.uniform(angle_min, angle_max)) 40 | else: 41 | angles.append(np.random.uniform(angle_min, angle_max)) 42 | 43 | h, w = mask.size 44 | vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h)))) 45 | for i in range(num_vertex): 46 | r = np.clip( 47 | np.random.normal(loc=average_radius, scale=average_radius//2), 48 | 0, 2*average_radius) 49 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) 50 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) 51 | vertex.append((int(new_x), int(new_y))) 52 | 53 | draw = ImageDraw.Draw(mask) 54 | width = int(np.random.uniform(min_width, max_width)) 55 | draw.line(vertex, fill=color, width=width) 56 | for v in vertex: 57 | draw.ellipse((v[0] - width//2, 58 | v[1] - width//2, 59 | v[0] + width//2, 60 | v[1] + width//2), 61 | fill=color) 62 | 63 | return mask 64 | 65 | width, height = img.size 66 | mask = generate_mask(height, width, img) 67 | return mask 68 | 69 | if __name__=='__main__': 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--model', type=str, default='GPEN-BFR-512', help='GPEN model') 72 | parser.add_argument('--task', type=str, 
default='FaceEnhancement', help='task of GPEN model') 73 | parser.add_argument('--key', type=str, default=None, help='key of GPEN model') 74 | parser.add_argument('--in_size', type=int, default=512, help='in resolution of GPEN') 75 | parser.add_argument('--out_size', type=int, default=None, help='out resolution of GPEN') 76 | parser.add_argument('--channel_multiplier', type=int, default=2, help='channel multiplier of GPEN') 77 | parser.add_argument('--narrow', type=float, default=1, help='channel narrow scale') 78 | parser.add_argument('--alpha', type=float, default=1, help='blending the results') 79 | parser.add_argument('--use_sr', action='store_true', help='use sr or not') 80 | parser.add_argument('--use_cuda', action='store_true', help='use cuda or not') 81 | parser.add_argument('--save_face', action='store_true', help='save face or not') 82 | parser.add_argument('--aligned', action='store_true', help='input are aligned faces or not') 83 | parser.add_argument('--sr_model', type=str, default='realesrnet', help='SR model') 84 | parser.add_argument('--sr_scale', type=int, default=2, help='SR scale') 85 | parser.add_argument('--tile_size', type=int, default=0, help='tile size for SR to avoid OOM') 86 | parser.add_argument('--indir', type=str, default='examples/imgs', help='input folder') 87 | parser.add_argument('--outdir', type=str, default='results/outs-BFR', help='output folder') 88 | parser.add_argument('--ext', type=str, default='.jpg', help='extension of output') 89 | args = parser.parse_args() 90 | 91 | #model = {'name':'GPEN-BFR-512', 'size':512, 'channel_multiplier':2, 'narrow':1} 92 | #model = {'name':'GPEN-BFR-256', 'size':256, 'channel_multiplier':1, 'narrow':0.5} 93 | 94 | os.makedirs(args.outdir, exist_ok=True) 95 | 96 | if args.task == 'FaceEnhancement': 97 | processer = FaceEnhancement(args, in_size=args.in_size, model=args.model, use_sr=args.use_sr, device='cuda' if args.use_cuda else 'cpu') 98 | elif args.task == 'FaceColorization': 99 | processer = FaceColorization(in_size=args.in_size, model=args.model, device='cuda' if args.use_cuda else 'cpu') 100 | elif args.task == 'FaceInpainting': 101 | processer = FaceInpainting(in_size=args.in_size, model=args.model, device='cuda' if args.use_cuda else 'cpu') 102 | elif args.task == 'Segmentation2Face': 103 | processer = Segmentation2Face(in_size=args.in_size, model=args.model, is_norm=False, device='cuda' if args.use_cuda else 'cpu') 104 | 105 | 106 | files = sorted(glob.glob(os.path.join(args.indir, '*.*g'))) 107 | for n, file in enumerate(files[:]): 108 | filename, ext = os.path.splitext(os.path.basename(file)) 109 | 110 | img = cv2.imread(file, cv2.IMREAD_COLOR) # BGR 111 | if not isinstance(img, np.ndarray): print(filename, 'error'); continue 112 | #img = cv2.resize(img, (0,0), fx=2, fy=2) # optional 113 | 114 | if args.task == 'FaceInpainting': 115 | img = np.asarray(brush_stroke_mask(Image.fromarray(img))) 116 | 117 | img_out, orig_faces, enhanced_faces = processer.process(img, aligned=args.aligned) 118 | 119 | img = cv2.resize(img, img_out.shape[:2][::-1]) 120 | cv2.imwrite(f'{args.outdir}/{filename}_COMP{args.ext}', np.hstack((img, img_out))) 121 | cv2.imwrite(f'{args.outdir}/{filename}_GPEN{args.ext}', img_out) 122 | 123 | if args.save_face: 124 | for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)): 125 | of = cv2.resize(of, ef.shape[:2]) 126 | cv2.imwrite(f'{args.outdir}/{filename}_face{m:02d}{args.ext}', np.hstack((of, ef))) 127 | 128 | if n%10==0: print(n, filename) 129 | 
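# Illustrative invocation (assuming the GPEN-BFR-512 weights are already in place;
# the paths shown are the argparse defaults above):
#   python demo.py --task FaceEnhancement --model GPEN-BFR-512 --in_size 512 \
#       --use_cuda --indir examples/imgs --outdir results/outs-BFR --save_face
# For each input <name> this writes <name>_COMP.jpg (input and restored result side
# by side), <name>_GPEN.jpg (restored result only) and, with --save_face, per-face
# before/after crops <name>_faceNN.jpg.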
-------------------------------------------------------------------------------- /gpeno/training/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | # from torch.autograd import Variable 9 | 10 | from lpips.trainer import * 11 | from lpips.lpips import * 12 | 13 | # class PerceptualLoss(torch.nn.Module): 14 | # def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | # super(PerceptualLoss, self).__init__() 17 | # print('Setting up Perceptual loss...') 18 | # self.use_gpu = use_gpu 19 | # self.spatial = spatial 20 | # self.gpu_ids = gpu_ids 21 | # self.model = dist_model.DistModel() 22 | # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 23 | # print('...[%s] initialized'%self.model.name()) 24 | # print('...Done') 25 | 26 | # def forward(self, pred, target, normalize=False): 27 | # """ 28 | # Pred and target are Variables. 29 | # If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | # If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | # Inputs pred and target are Nx3xHxW 33 | # Output pytorch Variable N long 34 | # """ 35 | 36 | # if normalize: 37 | # target = 2 * target - 1 38 | # pred = 2 * pred - 1 39 | 40 | # return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | from skimage.measure import compare_ssim 54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 
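# invert the normalization done in tensor2tensorlab: rescale to LAB units, then
# add the 50 offset back onto the L channel before converting to RGB below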
91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def load_image(path): 104 | if(path[-3:] == 'dng'): 105 | import rawpy 106 | with rawpy.imread(path) as raw: 107 | img = raw.postprocess() 108 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 109 | import cv2 110 | return cv2.imread(path)[:,:,::-1] 111 | else: 112 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 113 | 114 | return img 115 | 116 | def rgb2lab(input): 117 | from skimage import color 118 | return color.rgb2lab(input / 255.) 119 | 120 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 121 | image_numpy = image_tensor[0].cpu().float().numpy() 122 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 123 | return image_numpy.astype(imtype) 124 | 125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 126 | return torch.Tensor((image / factor - cent) 127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | def tensor2vec(vector_tensor): 130 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 131 | 132 | 133 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 134 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 135 | image_numpy = image_tensor[0].cpu().float().numpy() 136 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 137 | return image_numpy.astype(imtype) 138 | 139 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 140 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 141 | return torch.Tensor((image / factor - cent) 142 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 143 | 144 | 145 | 146 | def voc_ap(rec, prec, use_07_metric=False): 147 | """ ap = voc_ap(rec, prec, [use_07_metric]) 148 | Compute VOC AP given precision and recall. 149 | If use_07_metric is true, uses the 150 | VOC 07 11 point method (default:False). 151 | """ 152 | if use_07_metric: 153 | # 11 point metric 154 | ap = 0. 155 | for t in np.arange(0., 1.1, 0.1): 156 | if np.sum(rec >= t) == 0: 157 | p = 0 158 | else: 159 | p = np.max(prec[rec >= t]) 160 | ap = ap + p / 11. 
161 | else: 162 | # correct AP calculation 163 | # first append sentinel values at the end 164 | mrec = np.concatenate(([0.], rec, [1.])) 165 | mpre = np.concatenate(([0.], prec, [0.])) 166 | 167 | # compute the precision envelope 168 | for i in range(mpre.size - 1, 0, -1): 169 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 170 | 171 | # to calculate area under PR curve, look for points 172 | # where X axis (recall) changes value 173 | i = np.where(mrec[1:] != mrec[:-1])[0] 174 | 175 | # and sum (\Delta recall) * prec 176 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 177 | return ap 178 | 179 | -------------------------------------------------------------------------------- /gpeno/face_detect/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | from gpeno.face_detect.utils.box_utils import matrix_iof 5 | 6 | 7 | def _crop(image, boxes, labels, landm, img_dim): 8 | height, width, _ = image.shape 9 | pad_image_flag = True 10 | 11 | for _ in range(250): 12 | """ 13 | if random.uniform(0, 1) <= 0.2: 14 | scale = 1.0 15 | else: 16 | scale = random.uniform(0.3, 1.0) 17 | """ 18 | PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0] 19 | scale = random.choice(PRE_SCALES) 20 | short_side = min(width, height) 21 | w = int(scale * short_side) 22 | h = w 23 | 24 | if width == w: 25 | l = 0 26 | else: 27 | l = random.randrange(width - w) 28 | if height == h: 29 | t = 0 30 | else: 31 | t = random.randrange(height - h) 32 | roi = np.array((l, t, l + w, t + h)) 33 | 34 | value = matrix_iof(boxes, roi[np.newaxis]) 35 | flag = (value >= 1) 36 | if not flag.any(): 37 | continue 38 | 39 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2 40 | mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1) 41 | boxes_t = boxes[mask_a].copy() 42 | labels_t = labels[mask_a].copy() 43 | landms_t = landm[mask_a].copy() 44 | landms_t = landms_t.reshape([-1, 5, 2]) 45 | 46 | if boxes_t.shape[0] == 0: 47 | continue 48 | 49 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]] 50 | 51 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) 52 | boxes_t[:, :2] -= roi[:2] 53 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) 54 | boxes_t[:, 2:] -= roi[:2] 55 | 56 | # landm 57 | landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2] 58 | landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0])) 59 | landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2]) 60 | landms_t = landms_t.reshape([-1, 10]) 61 | 62 | # make sure that the cropped image contains at least one face > 16 pixel at training image scale 63 | b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim 64 | b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim 65 | mask_b = np.minimum(b_w_t, b_h_t) > 0.0 66 | boxes_t = boxes_t[mask_b] 67 | labels_t = labels_t[mask_b] 68 | landms_t = landms_t[mask_b] 69 | 70 | if boxes_t.shape[0] == 0: 71 | continue 72 | 73 | pad_image_flag = False 74 | 75 | return image_t, boxes_t, labels_t, landms_t, pad_image_flag 76 | return image, boxes, labels, landm, pad_image_flag 77 | 78 | 79 | def _distort(image): 80 | 81 | def _convert(image, alpha=1, beta=0): 82 | tmp = image.astype(float) * alpha + beta 83 | tmp[tmp < 0] = 0 84 | tmp[tmp > 255] = 255 85 | image[:] = tmp 86 | 87 | image = image.copy() 88 | 89 | if random.randrange(2): 90 | 91 | #brightness distortion 92 | if random.randrange(2): 93 | _convert(image, beta=random.uniform(-32, 32)) 94 | 95 | #contrast distortion 96 | if 
random.randrange(2): 97 | _convert(image, alpha=random.uniform(0.5, 1.5)) 98 | 99 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 100 | 101 | #saturation distortion 102 | if random.randrange(2): 103 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 104 | 105 | #hue distortion 106 | if random.randrange(2): 107 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 108 | tmp %= 180 109 | image[:, :, 0] = tmp 110 | 111 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 112 | 113 | else: 114 | 115 | #brightness distortion 116 | if random.randrange(2): 117 | _convert(image, beta=random.uniform(-32, 32)) 118 | 119 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 120 | 121 | #saturation distortion 122 | if random.randrange(2): 123 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 124 | 125 | #hue distortion 126 | if random.randrange(2): 127 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 128 | tmp %= 180 129 | image[:, :, 0] = tmp 130 | 131 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 132 | 133 | #contrast distortion 134 | if random.randrange(2): 135 | _convert(image, alpha=random.uniform(0.5, 1.5)) 136 | 137 | return image 138 | 139 | 140 | def _expand(image, boxes, fill, p): 141 | if random.randrange(2): 142 | return image, boxes 143 | 144 | height, width, depth = image.shape 145 | 146 | scale = random.uniform(1, p) 147 | w = int(scale * width) 148 | h = int(scale * height) 149 | 150 | left = random.randint(0, w - width) 151 | top = random.randint(0, h - height) 152 | 153 | boxes_t = boxes.copy() 154 | boxes_t[:, :2] += (left, top) 155 | boxes_t[:, 2:] += (left, top) 156 | expand_image = np.empty((h, w, depth), dtype=image.dtype) 157 | expand_image[:, :] = fill 158 | expand_image[top:top + height, left:left + width] = image 159 | image = expand_image 160 | 161 | return image, boxes_t 162 | 163 | 164 | def _mirror(image, boxes, landms): 165 | _, width, _ = image.shape 166 | if random.randrange(2): 167 | image = image[:, ::-1] 168 | boxes = boxes.copy() 169 | boxes[:, 0::2] = width - boxes[:, 2::-2] 170 | 171 | # landm 172 | landms = landms.copy() 173 | landms = landms.reshape([-1, 5, 2]) 174 | landms[:, :, 0] = width - landms[:, :, 0] 175 | tmp = landms[:, 1, :].copy() 176 | landms[:, 1, :] = landms[:, 0, :] 177 | landms[:, 0, :] = tmp 178 | tmp1 = landms[:, 4, :].copy() 179 | landms[:, 4, :] = landms[:, 3, :] 180 | landms[:, 3, :] = tmp1 181 | landms = landms.reshape([-1, 10]) 182 | 183 | return image, boxes, landms 184 | 185 | 186 | def _pad_to_square(image, rgb_mean, pad_image_flag): 187 | if not pad_image_flag: 188 | return image 189 | height, width, _ = image.shape 190 | long_side = max(width, height) 191 | image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) 192 | image_t[:, :] = rgb_mean 193 | image_t[0:0 + height, 0:0 + width] = image 194 | return image_t 195 | 196 | 197 | def _resize_subtract_mean(image, insize, rgb_mean): 198 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] 199 | interp_method = interp_methods[random.randrange(5)] 200 | image = cv2.resize(image, (insize, insize), interpolation=interp_method) 201 | image = image.astype(np.float32) 202 | image -= rgb_mean 203 | return image.transpose(2, 0, 1) 204 | 205 | 206 | class preproc(object): 207 | 208 | def __init__(self, img_dim, rgb_means): 209 | self.img_dim = img_dim 210 | self.rgb_means = rgb_means 211 | 212 | def __call__(self, image, targets): 213 | assert targets.shape[0] > 0, "this image does not have gt" 214 
| 215 | boxes = targets[:, :4].copy() 216 | labels = targets[:, -1].copy() 217 | landm = targets[:, 4:-1].copy() 218 | 219 | image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) 220 | image_t = _distort(image_t) 221 | image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag) 222 | image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t) 223 | height, width, _ = image_t.shape 224 | image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) 225 | boxes_t[:, 0::2] /= width 226 | boxes_t[:, 1::2] /= height 227 | 228 | landm_t[:, 0::2] /= width 229 | landm_t[:, 1::2] /= height 230 | 231 | labels_t = np.expand_dims(labels_t, 1) 232 | targets_t = np.hstack((boxes_t, landm_t, labels_t)) 233 | 234 | return image_t, targets_t 235 | -------------------------------------------------------------------------------- /gpeno/training/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | 
self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def __init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, 
h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /gpeno/align_faces.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 24 15:43:29 2017 4 | @author: zhaoy 5 | """ 6 | """ 7 | @Modified by yangxy (yangtao9009@gmail.com) 8 | """ 9 | import cv2 10 | import numpy as np 11 | from skimage import transform as trans 12 | 13 | # reference facial points, a list of coordinates (x,y) 14 | REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], [33.54930115, 92.3655014], [62.72990036, 92.20410156]] 15 | 16 | DEFAULT_CROP_SIZE = (96, 112) 17 | 18 | 19 | def _umeyama(src, dst, estimate_scale=True, scale=1.0): 20 | """Estimate N-D similarity transformation with or without scaling. 21 | Parameters 22 | ---------- 23 | src : (M, N) array 24 | Source coordinates. 25 | dst : (M, N) array 26 | Destination coordinates. 27 | estimate_scale : bool 28 | Whether to estimate scaling factor. 29 | Returns 30 | ------- 31 | T : (N + 1, N + 1) 32 | The homogeneous similarity transformation matrix. The matrix contains 33 | NaN values only if the problem is not well-conditioned. 34 | References 35 | ---------- 36 | .. [1] "Least-squares estimation of transformation parameters between two 37 | point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` 38 | """ 39 | 40 | num = src.shape[0] 41 | dim = src.shape[1] 42 | 43 | # Compute mean of src and dst. 44 | src_mean = src.mean(axis=0) 45 | dst_mean = dst.mean(axis=0) 46 | 47 | # Subtract mean from src and dst. 48 | src_demean = src - src_mean 49 | dst_demean = dst - dst_mean 50 | 51 | # Eq. (38). 52 | A = dst_demean.T @ src_demean / num 53 | 54 | # Eq. (39). 55 | d = np.ones((dim, ), dtype=np.double) 56 | if np.linalg.det(A) < 0: 57 | d[dim - 1] = -1 58 | 59 | T = np.eye(dim + 1, dtype=np.double) 60 | 61 | U, S, V = np.linalg.svd(A) 62 | 63 | # Eq. (40) and (43). 64 | rank = np.linalg.matrix_rank(A) 65 | if rank == 0: 66 | return np.nan * T 67 | elif rank == dim - 1: 68 | if np.linalg.det(U) * np.linalg.det(V) > 0: 69 | T[:dim, :dim] = U @ V 70 | else: 71 | s = d[dim - 1] 72 | d[dim - 1] = -1 73 | T[:dim, :dim] = U @ np.diag(d) @ V 74 | d[dim - 1] = s 75 | else: 76 | T[:dim, :dim] = U @ np.diag(d) @ V 77 | 78 | if estimate_scale: 79 | # Eq. (41) and (42). 
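# Umeyama's scale (Eq. 41/42): c = trace(diag(d) @ diag(S)) / sigma_src^2.
# S is returned as a 1-D vector of singular values, so the trace reduces to S @ d,
# and src_demean.var(axis=0).sum() is the mean squared deviation (sigma_src^2) of the source points.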
80 | scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) 81 | else: 82 | scale = scale 83 | 84 | T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) 85 | T[:dim, :dim] *= scale 86 | 87 | return T, scale 88 | 89 | 90 | class FaceWarpException(Exception): 91 | 92 | def __str__(self): 93 | return 'In File {}:{}'.format(__file__, super.__str__(self)) 94 | 95 | 96 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): 97 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 98 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 99 | 100 | # 0) make the inner region a square 101 | if default_square: 102 | size_diff = max(tmp_crop_size) - tmp_crop_size 103 | tmp_5pts += size_diff / 2 104 | tmp_crop_size += size_diff 105 | 106 | if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): 107 | print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) 108 | return tmp_5pts 109 | 110 | if (inner_padding_factor == 0 and outer_padding == (0, 0)): 111 | if output_size is None: 112 | print('No paddings to do: return default reference points') 113 | return tmp_5pts 114 | else: 115 | raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) 116 | 117 | # check output size 118 | if not (0 <= inner_padding_factor <= 1.0): 119 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 120 | 121 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): 122 | output_size = tmp_crop_size * \ 123 | (1 + inner_padding_factor * 2).astype(np.int32) 124 | output_size += np.array(outer_padding) 125 | print(' deduced from paddings, output_size = ', output_size) 126 | 127 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): 128 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]' 129 | 'and outer_padding[1] < output_size[1])') 130 | 131 | # 1) pad the inner region according inner_padding_factor 132 | # print('---> STEP1: pad the inner region according inner_padding_factor') 133 | if inner_padding_factor > 0: 134 | size_diff = tmp_crop_size * inner_padding_factor * 2 135 | tmp_5pts += size_diff / 2 136 | tmp_crop_size += np.round(size_diff).astype(np.int32) 137 | 138 | # print(' crop_size = ', tmp_crop_size) 139 | # print(' reference_5pts = ', tmp_5pts) 140 | 141 | # 2) resize the padded inner region 142 | # print('---> STEP2: resize the padded inner region') 143 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 144 | # print(' crop_size = ', tmp_crop_size) 145 | # print(' size_bf_outer_pad = ', size_bf_outer_pad) 146 | 147 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 148 | raise FaceWarpException('Must have (output_size - outer_padding)' 149 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 150 | 151 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 152 | # print(' resize scale_factor = ', scale_factor) 153 | tmp_5pts = tmp_5pts * scale_factor 154 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 155 | # tmp_5pts = tmp_5pts + size_diff / 2 156 | tmp_crop_size = size_bf_outer_pad 157 | # print(' crop_size = ', tmp_crop_size) 158 | # print(' reference_5pts = ', tmp_5pts) 159 | 160 | # 3) add outer_padding to make output_size 161 | reference_5point = tmp_5pts + np.array(outer_padding) 162 | tmp_crop_size = output_size 
163 | # print('---> STEP3: add outer_padding to make output_size') 164 | # print(' crop_size = ', tmp_crop_size) 165 | # print(' reference_5pts = ', tmp_5pts) 166 | # 167 | # print('===> end get_reference_facial_points\n') 168 | 169 | return reference_5point 170 | 171 | 172 | def get_affine_transform_matrix(src_pts, dst_pts): 173 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 174 | n_pts = src_pts.shape[0] 175 | ones = np.ones((n_pts, 1), src_pts.dtype) 176 | src_pts_ = np.hstack([src_pts, ones]) 177 | dst_pts_ = np.hstack([dst_pts, ones]) 178 | 179 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 180 | 181 | if rank == 3: 182 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) 183 | elif rank == 2: 184 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) 185 | 186 | return tfm 187 | 188 | 189 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="smilarity"): #smilarity cv2_affine affine 190 | if reference_pts is None: 191 | if crop_size[0] == 96 and crop_size[1] == 112: 192 | reference_pts = REFERENCE_FACIAL_POINTS 193 | else: 194 | default_square = False 195 | inner_padding_factor = 0 196 | outer_padding = (0, 0) 197 | output_size = crop_size 198 | 199 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, default_square) 200 | 201 | ref_pts = np.float32(reference_pts) 202 | ref_pts_shp = ref_pts.shape 203 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 204 | raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') 205 | 206 | if ref_pts_shp[0] == 2: 207 | ref_pts = ref_pts.T 208 | 209 | src_pts = np.float32(facial_pts) 210 | src_pts_shp = src_pts.shape 211 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 212 | raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') 213 | 214 | if src_pts_shp[0] == 2: 215 | src_pts = src_pts.T 216 | 217 | if src_pts.shape != ref_pts.shape: 218 | raise FaceWarpException('facial_pts and reference_pts must have the same shape') 219 | 220 | if align_type == 'cv2_affine': 221 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 222 | tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) 223 | elif align_type == 'affine': 224 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 225 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) 226 | else: 227 | params, scale = _umeyama(src_pts, ref_pts) 228 | tfm = params[:2, :] 229 | 230 | params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) 231 | tfm_inv = params[:2, :] 232 | 233 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3) 234 | 235 | return face_img, tfm_inv 236 | -------------------------------------------------------------------------------- /gpeno/training/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from . 
import pretrained_networks as pn 9 | import torch.nn 10 | 11 | import lpips 12 | 13 | 14 | def spatial_average(in_tens, keepdim=True): 15 | return in_tens.mean([2, 3], keepdim=keepdim) 16 | 17 | 18 | def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W 19 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 20 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 21 | 22 | 23 | # Learned perceptual metric 24 | class LPIPS(nn.Module): 25 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): 26 | # lpips - [True] means with linear calibration on top of base network 27 | # pretrained - [True] means load linear weights 28 | 29 | super(LPIPS, self).__init__() 30 | if (verbose): 31 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 32 | 33 | self.pnet_type = net 34 | self.pnet_tune = pnet_tune 35 | self.pnet_rand = pnet_rand 36 | self.spatial = spatial 37 | self.lpips = lpips # false means baseline of just averaging all layers 38 | self.version = version 39 | self.scaling_layer = ScalingLayer() 40 | 41 | if (self.pnet_type in ['vgg', 'vgg16']): 42 | net_type = pn.vgg16 43 | self.chns = [64, 128, 256, 512, 512] 44 | elif (self.pnet_type == 'alex'): 45 | net_type = pn.alexnet 46 | self.chns = [64, 192, 384, 256, 256] 47 | elif (self.pnet_type == 'squeeze'): 48 | net_type = pn.squeezenet 49 | self.chns = [64, 128, 256, 384, 384, 512, 512] 50 | self.L = len(self.chns) 51 | 52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 53 | 54 | if (lpips): 55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 60 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 61 | if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet 62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 64 | self.lins += [self.lin5, self.lin6] 65 | self.lins = nn.ModuleList(self.lins) 66 | 67 | if (pretrained): 68 | if (model_path is None): 69 | import inspect 70 | import os 71 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'gpen/v%s/%s.pth' % (version, net))) 72 | 73 | if (verbose): 74 | print('Loading model from: %s' % model_path) 75 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 76 | 77 | if (eval_mode): 78 | self.eval() 79 | 80 | def forward(self, in0, in1, retPerLayer=False, normalize=False): 81 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 82 | in0 = 2 * in0 - 1 83 | in1 = 2 * in1 - 1 84 | 85 | # v0.0 - original release had a bug, where input was not scaled 86 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1) 87 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 88 | feats0, feats1, diffs = {}, {}, {} 89 | 90 | for kk in range(self.L): 91 | feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), 
lpips.normalize_tensor(outs1[kk]) 92 | diffs[kk] = (feats0[kk] - feats1[kk])**2 93 | 94 | if (self.lpips): 95 | if (self.spatial): 96 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 97 | else: 98 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 99 | else: 100 | if (self.spatial): 101 | res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 102 | else: 103 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] 104 | 105 | val = res[0] 106 | for l in range(1, self.L): 107 | val += res[l] 108 | 109 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 110 | # b = torch.max(self.lins[kk](feats0[kk]**2)) 111 | # for kk in range(self.L): 112 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 113 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) 114 | # a = a/self.L 115 | # from IPython import embed 116 | # embed() 117 | # return 10*torch.log10(b/a) 118 | 119 | if (retPerLayer): 120 | return (val, res) 121 | else: 122 | return val 123 | 124 | 125 | class ScalingLayer(nn.Module): 126 | def __init__(self): 127 | super(ScalingLayer, self).__init__() 128 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 129 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 130 | 131 | def forward(self, inp): 132 | return (inp - self.shift) / self.scale 133 | 134 | 135 | class NetLinLayer(nn.Module): 136 | ''' A single linear layer which does a 1x1 conv ''' 137 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 138 | super(NetLinLayer, self).__init__() 139 | 140 | layers = [ 141 | nn.Dropout(), 142 | ] if (use_dropout) else [] 143 | layers += [ 144 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 145 | ] 146 | self.model = nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | return self.model(x) 150 | 151 | 152 | class Dist2LogitLayer(nn.Module): 153 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 154 | def __init__(self, chn_mid=32, use_sigmoid=True): 155 | super(Dist2LogitLayer, self).__init__() 156 | 157 | layers = [ 158 | nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), 159 | ] 160 | layers += [ 161 | nn.LeakyReLU(0.2, True), 162 | ] 163 | layers += [ 164 | nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), 165 | ] 166 | layers += [ 167 | nn.LeakyReLU(0.2, True), 168 | ] 169 | layers += [ 170 | nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), 171 | ] 172 | if (use_sigmoid): 173 | layers += [ 174 | nn.Sigmoid(), 175 | ] 176 | self.model = nn.Sequential(*layers) 177 | 178 | def forward(self, d0, d1, eps=0.1): 179 | return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) 180 | 181 | 182 | class BCERankingLoss(nn.Module): 183 | def __init__(self, chn_mid=32): 184 | super(BCERankingLoss, self).__init__() 185 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 186 | # self.parameters = list(self.net.parameters()) 187 | self.loss = torch.nn.BCELoss() 188 | 189 | def forward(self, d0, d1, judge): 190 | per = (judge + 1.) / 2. 
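# judge is a human preference score in [-1, 1]; map it to a target probability in [0, 1]
# so it can be compared to the sigmoid output of Dist2LogitLayer via BCELoss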
191 | self.logit = self.net.forward(d0, d1) 192 | return self.loss(self.logit, per) 193 | 194 | 195 | # L2, DSSIM metrics 196 | class FakeNet(nn.Module): 197 | def __init__(self, use_gpu=True, colorspace='Lab'): 198 | super(FakeNet, self).__init__() 199 | self.use_gpu = use_gpu 200 | self.colorspace = colorspace 201 | 202 | 203 | class L2(FakeNet): 204 | def forward(self, in0, in1, retPerLayer=None): 205 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 206 | 207 | if (self.colorspace == 'RGB'): 208 | (N, C, X, Y) = in0.size() 209 | value = torch.mean(torch.mean(torch.mean((in0 - in1)**2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3).view(N) 210 | return value 211 | elif (self.colorspace == 'Lab'): 212 | value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data, to_norm=False)), lpips.tensor2np(lpips.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') 213 | ret_var = Variable(torch.Tensor((value, ))) 214 | if (self.use_gpu): 215 | ret_var = ret_var.cuda() 216 | return ret_var 217 | 218 | 219 | class DSSIM(FakeNet): 220 | def forward(self, in0, in1, retPerLayer=None): 221 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 222 | 223 | if (self.colorspace == 'RGB'): 224 | value = lpips.dssim(1. * lpips.tensor2im(in0.data), 1. * lpips.tensor2im(in1.data), range=255.).astype('float') 225 | elif (self.colorspace == 'Lab'): 226 | value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data, to_norm=False)), lpips.tensor2np(lpips.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') 227 | ret_var = Variable(torch.Tensor((value, ))) 228 | if (self.use_gpu): 229 | ret_var = ret_var.cuda() 230 | return ret_var 231 | 232 | 233 | def print_network(net): 234 | num_params = 0 235 | for param in net.parameters(): 236 | num_params += param.numel() 237 | print('Network', net) 238 | print('Total number of parameters: %d' % num_params) 239 | -------------------------------------------------------------------------------- /gpeno/README.md: -------------------------------------------------------------------------------- 1 | # GPENO has various optimizations for Unprompted by Therefore Games. 2 | 3 | # GAN Prior Embedded Network for Blind Face Restoration in the Wild 4 | 5 | [Paper](https://arxiv.org/abs/2105.06070) | [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/GPEN-cvpr21-supp.pdf) | [Demo](https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace) | [ModelScope](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement/summary) 6 | 7 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/GPEN) 8 | 9 | [Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang)1, Peiran Ren1, Xuansong Xie1, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)1,2 10 | _1[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Hangzhou, China_ 11 | _2[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk), Hong Kong, China_ 12 | 13 | #### Face Restoration 14 | 15 | 16 | 17 | 18 | 19 | 20 | #### Selfie Restoration 21 | 22 | 23 | 24 | #### Face Colorization 25 | 26 | 27 | 28 | #### Face Inpainting 29 | 30 | 31 | 32 | #### Conditional Image Synthesis (Seg2Face) 33 | 34 | 35 | 36 | ## News 37 | (2023-02-15) **GPEN-BFR-1024** and **GPEN-BFR-2048** are now publicly available. 
Please download them via \[[ModelScope2](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement-hires/summary)\]. 38 | 39 | (2023-02-15) We provide online demos via \[[ModelScope1](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement/summary)\] and \[[ModelScope2](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement-hires/summary)\]. 40 | 41 | (2022-05-16) Add x1 sr model. Add ``--tile_size`` to avoid OOM. 42 | 43 | (2022-03-15) Add x4 sr model. Try ``--sr_scale``. 44 | 45 | (2022-03-09) Add GPEN-BFR-2048 for selfies. I have to take it down due to commercial issues. Sorry about that. 46 | 47 | (2021-12-29) Add online demos [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/GPEN). Many thanks to [CJWBW](https://github.com/CJWBW) and [AK391](https://github.com/AK391). 48 | 49 | (2021-12-16) Release a simplified training code of GPEN. It differs from our implementation in the paper, but could achieve comparable performance. We strongly recommend to change the degradation model. 50 | 51 | (2021-12-09) Add face parsing to better paste restored faces back. 52 | 53 | (2021-12-09) GPEN can run on CPU now by simply discarding ``--use_cuda``. 54 | 55 | (2021-12-01) GPEN can now work on a Windows machine without compiling cuda codes. Please check it out. Thanks to [Animadversio](https://github.com/rosinality/stylegan2-pytorch/issues/81). Alternatively, you can try [GPEN-Windows](https://drive.google.com/file/d/1YJJVnPGq90e_mWZxSGGTptNQilZNfOEO/view?usp=drivesdk). Many thanks to [Cioscos](https://github.com/yangxy/GPEN/issues/74). 56 | 57 | (2021-10-22) GPEN can now work with SR methods. A SR model trained by myself is provided. Replace it with your own model if necessary. 58 | 59 | (2021-10-11) The Colab demo for GPEN is available now google colab logo. 60 | 61 | ## Download models from Modelscope 62 | 63 | - Install modelscope: 64 | ```bash 65 | pip install "modelscope[cv]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html 66 | ``` 67 | 68 | - Run the following codes: 69 | ```python 70 | import cv2 71 | from modelscope.pipelines import pipeline 72 | from modelscope.utils.constant import Tasks 73 | from modelscope.outputs import OutputKeys 74 | 75 | portrait_enhancement = pipeline(Tasks.image_portrait_enhancement, model='damo/cv_gpen_image-portrait-enhancement-hires') 76 | result = portrait_enhancement('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/marilyn_monroe_4.jpg') 77 | cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) 78 | ``` 79 | 80 | It will automatically download the GPEN models. You can find the model in the local path ``~/.cache/modelscope/hub/damo``. Please note pytorch_model.pt, pytorch_model-2048.pt are respectively the 1024 and 2048 versions. 
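If you would rather reuse these ModelScope downloads with the scripts in the Usage section below (which expect ``weights/GPEN-BFR-*.pth``), something like the following can copy them over. The cache layout and file names are assumptions based on the note above; verify them against your local ``~/.cache/modelscope/hub/damo`` directory first.

```python
# Hypothetical helper: stage ModelScope-cached GPEN weights into weights/.
# Cache layout and file names are assumptions; check your local cache first.
import glob
import os
import shutil

cache_root = os.path.expanduser('~/.cache/modelscope/hub/damo')
os.makedirs('weights', exist_ok=True)
for src in glob.glob(os.path.join(cache_root, '**', 'pytorch_model*.pt'), recursive=True):
    # Per the note above: pytorch_model.pt is the 1024 model, pytorch_model-2048.pt the 2048 one.
    dst = 'GPEN-BFR-2048.pth' if '2048' in os.path.basename(src) else 'GPEN-BFR-1024.pth'
    shutil.copy(src, os.path.join('weights', dst))
    print(src, '->', os.path.join('weights', dst))
```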
81 | 82 | ## Usage 83 | 84 | ![python](https://img.shields.io/badge/python-v3.7.4-green.svg?style=plastic) 85 | ![pytorch](https://img.shields.io/badge/pytorch-v1.7.0-green.svg?style=plastic) 86 | ![cuda](https://img.shields.io/badge/cuda-v10.2.89-green.svg?style=plastic) 87 | ![driver](https://img.shields.io/badge/driver-v460.73.01-green.svg?style=plastic) 88 | ![gcc](https://img.shields.io/badge/gcc-v7.5.0-green.svg?style=plastic) 89 | 90 | - Clone this repository: 91 | ```bash 92 | git clone https://github.com/yangxy/GPEN.git 93 | cd GPEN 94 | ``` 95 | - Download RetinaFace model and our pre-trained model (not our best model due to commercial issues) and put them into ``weights/``. 96 | 97 | [RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [ParseNet-latest](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/ParseNet-latest.pth) | [model_ir_se50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/model_ir_se50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-BFR-256-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256-D.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth) | [realesrnet_x1](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x1.pth) | [realesrnet_x2](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x2.pth) | [realesrnet_x4](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x4.pth) 98 | 99 | - Restore face images: 100 | ```bash 101 | python demo.py --task FaceEnhancement --model GPEN-BFR-512 --in_size 512 --channel_multiplier 2 --narrow 1 --use_sr --sr_scale 4 --use_cuda --save_face --indir examples/imgs --outdir examples/outs-bfr 102 | ``` 103 | 104 | - Colorize faces: 105 | ```bash 106 | python demo.py --task FaceColorization --model GPEN-Colorization-1024 --in_size 1024 --use_cuda --indir examples/grays --outdir examples/outs-colorization 107 | ``` 108 | 109 | - Complete faces: 110 | ```bash 111 | python demo.py --task FaceInpainting --model GPEN-Inpainting-1024 --in_size 1024 --use_cuda --indir examples/ffhq-10 --outdir examples/outs-inpainting 112 | ``` 113 | 114 | - Synthesize faces: 115 | ```bash 116 | python demo.py --task Segmentation2Face --model GPEN-Seg2face-512 --in_size 512 --use_cuda --indir examples/segs --outdir examples/outs-seg2face 117 | ``` 118 | 119 | - Train GPEN for BFR with 4 GPUs: 120 | ```bash 121 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train_simple.py --size 1024 --channel_multiplier 2 --narrow 1 --ckpt weights --sample results --batch 2 --path your_path_of_croped+aligned_hq_faces (e.g., FFHQ) 122 | 123 | ``` 124 | When testing your own model, set ``--key g_ema``. 125 | 126 | Please check out ``run.sh`` for more details. 
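A note on ``--key g_ema``: training checkpoints are typically saved as a dictionary holding several state dicts (generator, discriminator, EMA generator), while the released ``.pth`` files are a bare state dict, so inference needs to be told which entry to pull out. The sketch below illustrates the idea; it assumes the loader simply indexes the checkpoint by the given key, so check ``face_gan.py`` for the exact behaviour in this repository.

```python
# Illustrative sketch of why --key g_ema matters for self-trained checkpoints.
# Assumes the loader indexes the checkpoint dict by `key`; verify against face_gan.py.
import torch

def load_gpen_weights(model, ckpt_path, key=None):
    state = torch.load(ckpt_path, map_location='cpu')
    if key is not None:           # training checkpoint: {'g': ..., 'd': ..., 'g_ema': ...}
        state = state[key]        # pick out the EMA generator weights
    model.load_state_dict(state)  # released .pth files are already a bare state dict
    return model
```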
127 | 128 | ## Main idea 129 | 130 | 131 | ## Citation 132 | If our work is useful for your research, please consider citing: 133 | 134 | @inproceedings{Yang2021GPEN, 135 | title={GAN Prior Embedded Network for Blind Face Restoration in the Wild}, 136 | author={Tao Yang, Peiran Ren, Xuansong Xie, and Lei Zhang}, 137 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 138 | year={2021} 139 | } 140 | 141 | ## License 142 | © Alibaba, 2021. For academic and non-commercial use only. 143 | 144 | ## Acknowledgments 145 | We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN), and [GFPGAN](https://github.com/TencentARC/GFPGAN). 146 | 147 | ## Contact 148 | If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com. 149 | -------------------------------------------------------------------------------- /gpeno/face_model/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 
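// Locate this output tile in the intermediate (upsampled, padded) grid; floor_div then maps
// it back to the input pixels whose footprint covers the tile. Those pixels are staged in
// shared memory (sx) before the cached filter taps (sk) are applied per output sample.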
85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 
&& p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import folder_paths 2 | import torch 3 | import os 4 | import sys 5 | 6 | import cv2 7 | import glob 8 | import time 9 | import math 10 | import argparse 11 | import numpy as np 12 | from PIL import Image, ImageDraw 13 | 14 | this_path = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(this_path) 16 | 17 | import gpeno.__init_paths 18 | from gpeno.face_enhancement import FaceEnhancement 19 | 20 | models_dir = folder_paths.models_dir 21 | 22 | global_gpen_processor = None 23 | global_gpen_cache_model = "" 24 | 25 | 26 | class GPENO: 27 | """Restores faces in the image a special version of the GPEN technique that has been optimized for speed.""" 28 | 29 | @classmethod 30 | def INPUT_TYPES(cls): 31 | return { 32 | "required": { 33 | "image": ("IMAGE", ), 34 | "use_global_cache": ("BOOLEAN", { 35 | "default": True, 36 | "tooltip": "If enabled, the model will be loaded once and shared across all instances of this node. This saves VRAM if you are using multiple instances of GPENO in your flow, but the settings must remain the same for all instances." 37 | }), 38 | "unload": ("BOOLEAN", { 39 | "default": False, 40 | "tooltip": "If enabled, the model will be freed from the cache at the start of this node's execution (if applicable), and it will not be cached again." 
41 | }), 42 | "backbone": (["RetinaFace-R50", "mobilenet0.25_Final"], { 43 | "default": "RetinaFace-R50", 44 | "tooltip": "Backbone files are downloaded to `comfyui/models/facedetection`." 45 | }), 46 | "resolution_preset": (["512", "1024", "2048"], { 47 | "default": "512" 48 | }), 49 | "downscale_method": (["Bilinear", "Nearest", "Bicubic", "Area", "Lanczos"], { 50 | "default": "Bilinear" 51 | }), 52 | "channel_multiplier": ("FLOAT", { 53 | "default": 2 54 | }), 55 | "narrow": ("FLOAT", { 56 | "default": 1.0 57 | }), 58 | "alpha": ("FLOAT", { 59 | "default": 1.0 60 | }), 61 | "device": (["cpu", "cuda"], { 62 | "default": "cuda" if torch.cuda.is_available() else "cpu" 63 | }), 64 | "aligned": ("BOOLEAN", { 65 | "default": False 66 | }), 67 | "colorize": ("BOOLEAN", { 68 | "default": False, 69 | }), 70 | }, 71 | } 72 | 73 | RETURN_TYPES = ( 74 | "IMAGE", 75 | "IMAGE", 76 | "IMAGE", 77 | ) 78 | FUNCTION = "op" 79 | CATEGORY = "image" 80 | DESCRIPTION = """ 81 | Performs GPEN face restoration on the input image(s). This implementation has been optimized for speed. 82 | """ 83 | 84 | def __init__(self): 85 | self.gpen_processor = None 86 | self.gpen_cache_model = "" 87 | 88 | def op(self, image, use_global_cache, unload, backbone, resolution_preset, downscale_method, channel_multiplier, narrow, alpha, device, aligned, colorize): 89 | global global_gpen_processor 90 | global global_gpen_cache_model 91 | 92 | # Package arguments into attribute notation for use with argparse 93 | args = argparse.Namespace() 94 | args.model = f"GPEN-BFR-{resolution_preset}" 95 | args.channel_multiplier = channel_multiplier 96 | args.narrow = narrow 97 | args.alpha = alpha 98 | args.use_cuda = device 99 | args.aligned = aligned 100 | 101 | # Hardcoded arguments irrelevant to the user 102 | args.use_sr = False 103 | args.in_size = int(resolution_preset) 104 | args.out_size = 0 105 | args.sr_model = "realesrnet" 106 | args.sr_scale = 2 107 | args.key = None 108 | args.indir = "example/imgs" 109 | args.outdir = "results/outs-BFR" 110 | args.ext = ".jpg" 111 | args.save_face = False 112 | args.tile_size = 0 113 | 114 | if downscale_method == "Nearest": 115 | downscale_method = cv2.INTER_NEAREST 116 | elif downscale_method == "Bilinear": 117 | downscale_method = cv2.INTER_LINEAR 118 | elif downscale_method == "Area": 119 | downscale_method = cv2.INTER_AREA 120 | elif downscale_method == "Cubic": 121 | downscale_method = cv2.INTER_CUBIC 122 | elif downscale_method == "Lanczos": 123 | downscale_method = cv2.INTER_LANCZOS4 124 | 125 | def download_file(filename, url, logger=None, overwrite=False, headers=None): 126 | import os, requests 127 | 128 | # log = get_logger(logger) 129 | 130 | if overwrite or not os.path.exists(filename): 131 | # Make sure directory structure exists 132 | os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True) 133 | 134 | # log.info(f"Downloading file into: {filename}...") 135 | 136 | response = requests.get(url, stream=True, headers=headers) 137 | if response.status_code != 200: 138 | # log.error(f"Error when trying to download `{url}` to `{filename}`. 
Dtatus code received: {response.status_code}") 139 | return False 140 | try: 141 | with open(filename, "wb") as fout: 142 | for block in response.iter_content(4096): 143 | fout.write(block) 144 | except: 145 | # log.exception(f"Error when writing download to `{filename}`.") 146 | return False 147 | 148 | return True 149 | 150 | gpen_dir = os.path.join(models_dir, "facerestore_models") 151 | facedetect_dir = os.path.join(models_dir, "facedetection") 152 | 153 | if args.model == "GPEN-BFR-512": 154 | download_file(f"{gpen_dir}/{args.model}.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth") 155 | elif args.model == "GPEN-BFR-1024": 156 | if not download_file(f"{gpen_dir}/{args.model}.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-1024.pth"): 157 | pass 158 | # self.log.error("The download link for the 1024 model doesn't appear to work. Try installing it manually into your unprompted/models/gpen folder: https://cyberfile.me/644d") 159 | elif args.model == "GPEN-BFR-2048": 160 | download_file(f"{gpen_dir}/{args.model}.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-2048.pth") 161 | 162 | if colorize: 163 | download_file(f"{gpen_dir}/GPEN-Colorization-1024.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth") 164 | 165 | # Additional dependencies 166 | download_file(f"{facedetect_dir}/parsing_parsenet.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/ParseNet-latest.pth") 167 | # for sr 168 | # download_file(f"{facedetect_dir}/realesrnet_x2.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x2.pth") 169 | 170 | if backbone == "RetinaFace-R50": 171 | backbone = "detection_Resnet50_Final" 172 | download_file(f"{facedetect_dir}/detection_Resnet50_Final.pth", "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth") 173 | elif backbone == "mobilenet0.25_Final": 174 | backbone = "detection_mobilenet0.25_Final" 175 | download_file(f"{facedetect_dir}/detection_mobilenet0.25_Final.pth", "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth") 176 | 177 | if use_global_cache: 178 | this_gpen_processor = global_gpen_processor 179 | this_gpen_cache_model = global_gpen_cache_model 180 | else: 181 | this_gpen_processor = self.gpen_processor 182 | this_gpen_cache_model = self.gpen_cache_model 183 | 184 | if unload or not this_gpen_processor or this_gpen_cache_model != resolution_preset: 185 | print("Loading FaceEnhancement object...") 186 | face_enhancement_obj = FaceEnhancement(args, base_dir=f"{models_dir}/", in_size=args.in_size, model=args.model, use_sr=args.use_sr, device=args.use_cuda, interp=downscale_method, backbone=backbone, colorize=colorize) 187 | 188 | if use_global_cache: 189 | global_gpen_processor = face_enhancement_obj 190 | global_gpen_cache_model = resolution_preset 191 | this_gpen_processor = global_gpen_processor 192 | this_gpen_cache_model = global_gpen_cache_model 193 | else: 194 | self.gpen_processor = face_enhancement_obj 195 | self.gpen_cache_model = resolution_preset 196 | this_gpen_processor = self.gpen_processor 197 | this_gpen_cache_model = self.gpen_cache_model 198 | else: 199 | print("Using cached FaceEnhancement object.") 200 | # self.log.info("Using cached FaceEnhancement object.") 201 | 202 | print("Starting GPENO processing...") 203 | 204 | total_images = image.shape[0] 
205 | out_images = [] 206 | out_original_faces = [] 207 | out_enhanced_faces = [] 208 | 209 | for i in range(total_images): 210 | # image is a 4d tensor array in the format of [B, H, W, C] 211 | this_img = 255. * image[i].cpu().numpy() 212 | img = np.clip(this_img, 0, 255).astype(np.uint8) 213 | 214 | result, orig_faces, enhanced_faces = this_gpen_processor.process(img, aligned=args.aligned) 215 | 216 | out_images.append(result) 217 | # add each of the orig_faces list to the out_original_faces list 218 | for orig_face in orig_faces: 219 | out_original_faces.append(orig_face) 220 | for enhanced_face in enhanced_faces: 221 | out_enhanced_faces.append(enhanced_face) 222 | 223 | restored_img_np = np.array(out_images).astype(np.float32) / 255.0 224 | restored_img_tensor = torch.from_numpy(restored_img_np) 225 | 226 | restored_original_faces_np = np.array(out_original_faces).astype(np.float32) / 255.0 227 | restored_original_faces_tensor = torch.from_numpy(restored_original_faces_np) 228 | 229 | restored_enhanced_faces_np = np.array(out_enhanced_faces).astype(np.float32) / 255.0 230 | restored_enhanced_faces_tensor = torch.from_numpy(restored_enhanced_faces_np) 231 | 232 | if unload: 233 | print("Unloading GPEN from cache.") 234 | 235 | if use_global_cache: 236 | global_gpen_processor = None 237 | global_gpen_cache_model = "" 238 | else: 239 | self.gpen_cache_model = "" 240 | self.gpen_processor = None 241 | 242 | print("GPENO processing done.") 243 | return ( 244 | restored_img_tensor, 245 | restored_original_faces_tensor, 246 | restored_enhanced_faces_tensor, 247 | ) 248 | 249 | 250 | # A dictionary that contains all nodes you want to export with their names 251 | # NOTE: names should be globally unique 252 | NODE_CLASS_MAPPINGS = { 253 | "GPENO Face Restoration": GPENO, 254 | } 255 | 256 | # A dictionary that contains the friendly/humanly readable titles for the nodes 257 | NODE_DISPLAY_NAME_MAPPINGS = { 258 | "GPENO Face Restoration": "GPENO Face Restoration", 259 | } 260 | -------------------------------------------------------------------------------- /gpeno/training/lpips/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from collections import OrderedDict 8 | from torch.autograd import Variable 9 | from scipy.ndimage import zoom 10 | from tqdm import tqdm 11 | import lpips 12 | import os 13 | 14 | 15 | class Trainer(): 16 | def name(self): 17 | return self.model_name 18 | 19 | def initialize(self, model='lpips', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 20 | use_gpu=True, printNet=False, spatial=False, 21 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 22 | ''' 23 | INPUTS 24 | model - ['lpips'] for linearly calibrated network 25 | ['baseline'] for off-the-shelf network 26 | ['L2'] for L2 distance in Lab colorspace 27 | ['SSIM'] for ssim in RGB colorspace 28 | net - ['squeeze','alex','vgg'] 29 | model_path - if None, will look in weights/[NET_NAME].pth 30 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 31 | use_gpu - bool - whether or not to use a GPU 32 | printNet - bool - whether or not to print network architecture out 33 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 34 | is_train - bool - [True] for training mode 35 | lr - float - initial learning rate 36 | beta1 - float - initial momentum term 
for adam 37 | version - 0.1 for latest, 0.0 was original (with a bug) 38 | gpu_ids - int array - [0] by default, gpus to use 39 | ''' 40 | self.use_gpu = use_gpu 41 | self.gpu_ids = gpu_ids 42 | self.model = model 43 | self.net = net 44 | self.is_train = is_train 45 | self.spatial = spatial 46 | self.model_name = '%s [%s]'%(model,net) 47 | 48 | if(self.model == 'lpips'): # pretrained net + linear layer 49 | self.net = lpips.LPIPS(pretrained=not is_train, net=net, version=version, lpips=True, spatial=spatial, 50 | pnet_rand=pnet_rand, pnet_tune=pnet_tune, 51 | use_dropout=True, model_path=model_path, eval_mode=False) 52 | elif(self.model=='baseline'): # pretrained network 53 | self.net = lpips.LPIPS(pnet_rand=pnet_rand, net=net, lpips=False) 54 | elif(self.model in ['L2','l2']): 55 | self.net = lpips.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 56 | self.model_name = 'L2' 57 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 58 | self.net = lpips.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 59 | self.model_name = 'SSIM' 60 | else: 61 | raise ValueError("Model [%s] not recognized." % self.model) 62 | 63 | self.parameters = list(self.net.parameters()) 64 | 65 | if self.is_train: # training mode 66 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 67 | self.rankLoss = lpips.BCERankingLoss() 68 | self.parameters += list(self.rankLoss.net.parameters()) 69 | self.lr = lr 70 | self.old_lr = lr 71 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 72 | else: # test mode 73 | self.net.eval() 74 | 75 | if(use_gpu): 76 | self.net.to(gpu_ids[0]) 77 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 78 | if(self.is_train): 79 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 80 | 81 | if(printNet): 82 | print('---------- Networks initialized -------------') 83 | networks.print_network(self.net) 84 | print('-----------------------------------------------') 85 | 86 | def forward(self, in0, in1, retPerLayer=False): 87 | ''' Function computes the distance between image patches in0 and in1 88 | INPUTS 89 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 90 | OUTPUT 91 | computed distances between in0 and in1 92 | ''' 93 | 94 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 95 | 96 | # ***** TRAINING FUNCTIONS ***** 97 | def optimize_parameters(self): 98 | self.forward_train() 99 | self.optimizer_net.zero_grad() 100 | self.backward_train() 101 | self.optimizer_net.step() 102 | self.clamp_weights() 103 | 104 | def clamp_weights(self): 105 | for module in self.net.modules(): 106 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 107 | module.weight.data = torch.clamp(module.weight.data,min=0) 108 | 109 | def set_input(self, data): 110 | self.input_ref = data['ref'] 111 | self.input_p0 = data['p0'] 112 | self.input_p1 = data['p1'] 113 | self.input_judge = data['judge'] 114 | 115 | if(self.use_gpu): 116 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 117 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 118 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 119 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 120 | 121 | self.var_ref = Variable(self.input_ref,requires_grad=True) 122 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 123 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 124 | 125 | def forward_train(self): # run forward pass 126 
| self.d0 = self.forward(self.var_ref, self.var_p0) 127 | self.d1 = self.forward(self.var_ref, self.var_p1) 128 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 129 | 130 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 131 | 132 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 133 | 134 | return self.loss_total 135 | 136 | def backward_train(self): 137 | torch.mean(self.loss_total).backward() 138 | 139 | def compute_accuracy(self,d0,d1,judge): 140 | ''' d0, d1 are Variables, judge is a Tensor ''' 141 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 197 | self.old_lr = lr 198 | 199 | 200 | def get_image_paths(self): 201 | return self.image_paths 202 | 203 | def save_done(self, flag=False): 204 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 205 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 206 | 207 | 208 | def score_2afc_dataset(data_loader, func, name=''): 209 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 210 | distance function 'func' in dataset 'data_loader' 211 | INPUTS 212 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 213 | func - callable distance function - calling d=func(in0,in1) should take 2 214 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 215 | OUTPUTS 216 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 217 | [1] - dictionary with following elements 218 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 219 | gts - N array in [0,1], preferred patch selected by human evaluators 220 | (closer to "0" for left patch p0, "1" for right patch p1, 221 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 222 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 223 | CONSTS 224 | N - number of test triplets in data_loader 225 | ''' 226 | 227 | d0s = [] 228 | d1s = [] 229 | gts = [] 230 | 231 | for data in tqdm(data_loader.load_data(), desc=name): 232 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 233 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 234 | gts+=data['judge'].cpu().numpy().flatten().tolist() 235 | 236 | d0s = np.array(d0s) 237 | d1s = np.array(d1s) 238 | gts = np.array(gts) 239 | scores = (d0s [A,1,2] -> [A,B,2] 32 | [B,2] -> [1,B,2] -> [A,B,2] 33 | Then we compute the area of intersect between box_a and box_b. 34 | Args: 35 | box_a: (tensor) bounding boxes, Shape: [A,4]. 36 | box_b: (tensor) bounding boxes, Shape: [B,4]. 37 | Return: 38 | (tensor) intersection area, Shape: [A,B]. 39 | """ 40 | A = box_a.size(0) 41 | B = box_b.size(0) 42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 46 | inter = torch.clamp((max_xy - min_xy), min=0) 47 | return inter[:, :, 0] * inter[:, :, 1] 48 | 49 | 50 | def jaccard(box_a, box_b): 51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 52 | is simply the intersection over union of two boxes. Here we operate on 53 | ground truth boxes and default boxes. 
54 |     E.g.:
55 |         A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56 |     Args:
57 |         box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58 |         box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59 |     Return:
60 |         jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61 |     """
62 |     inter = intersect(box_a, box_b)
63 |     area_a = ((box_a[:, 2]-box_a[:, 0]) *
64 |               (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
65 |     area_b = ((box_b[:, 2]-box_b[:, 0]) *
66 |               (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
67 |     union = area_a + area_b - inter
68 |     return inter / union  # [A,B]
69 | 
70 | 
71 | def matrix_iou(a, b):
72 |     """
73 |     return iou of a and b, numpy version for data augmentation
74 |     """
75 |     lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76 |     rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77 | 
78 |     area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79 |     area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80 |     area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81 |     return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82 | 
83 | 
84 | def matrix_iof(a, b):
85 |     """
86 |     return iof of a and b, numpy version for data augmentation
87 |     """
88 |     lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89 |     rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90 | 
91 |     area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92 |     area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93 |     return area_i / np.maximum(area_a[:, np.newaxis], 1)
94 | 
95 | 
96 | def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
97 |     """Match each prior box with the ground truth box of the highest jaccard
98 |     overlap, encode the bounding boxes, then return the matched indices
99 |     corresponding to both confidence and location preds.
100 |     Args:
101 |         threshold: (float) The overlap threshold used when matching boxes.
102 |         truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
103 |         priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104 |         variances: (tensor) Variances corresponding to each prior coord,
105 |             Shape: [num_priors, 4].
106 |         labels: (tensor) All the class labels for the image, Shape: [num_obj].
107 |         landms: (tensor) Ground truth landms, Shape [num_obj, 10].
108 |         loc_t: (tensor) Tensor to be filled w/ encoded location targets.
109 |         conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
110 |         landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
111 |         idx: (int) current batch index
112 |     Return:
113 |         The matched indices corresponding to 1)location 2)confidence 3)landm preds.
114 |     """
115 |     # jaccard index
116 |     overlaps = jaccard(
117 |         truths,
118 |         point_form(priors)
119 |     )
120 |     # (Bipartite Matching)
121 |     # [1,num_objects] best prior for each ground truth
122 |     best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
123 | 
124 |     # ignore hard gt
125 |     valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
126 |     best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
127 |     if best_prior_idx_filter.shape[0] <= 0:
128 |         loc_t[idx] = 0
129 |         conf_t[idx] = 0
130 |         return
131 | 
132 |     # [1,num_priors] best ground truth for each prior
133 |     best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
134 |     best_truth_idx.squeeze_(0)
135 |     best_truth_overlap.squeeze_(0)
136 |     best_prior_idx.squeeze_(1)
137 |     best_prior_idx_filter.squeeze_(1)
138 |     best_prior_overlap.squeeze_(1)
139 |     best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
140 |     # TODO refactor: index best_prior_idx with long tensor
141 |     # ensure every gt matches with its prior of max overlap
142 |     for j in range(best_prior_idx.size(0)):  # determine which ground-truth box this anchor is responsible for predicting
143 |         best_truth_idx[best_prior_idx[j]] = j
144 |     matches = truths[best_truth_idx]  # Shape: [num_priors,4] gather the ground-truth bbox matched to each anchor
145 |     conf = labels[best_truth_idx]  # Shape: [num_priors] gather the label matched to each anchor
146 |     conf[best_truth_overlap < threshold] = 0  # label as background: anchors with overlap < 0.35 are all treated as negative samples
147 |     loc = encode(matches, priors, variances)
148 | 
149 |     matches_landm = landms[best_truth_idx]
150 |     landm = encode_landm(matches_landm, priors, variances)
151 |     loc_t[idx] = loc  # [num_priors,4] encoded offsets to learn
152 |     conf_t[idx] = conf  # [num_priors] top class label for each prior
153 |     landm_t[idx] = landm
154 | 
155 | 
156 | def encode(matched, priors, variances):
157 |     """Encode the variances from the priorbox layers into the ground truth boxes
158 |     we have matched (based on jaccard overlap) with the prior boxes.
159 |     Args:
160 |         matched: (tensor) Coords of ground truth for each prior in point-form
161 |             Shape: [num_priors, 4].
162 |         priors: (tensor) Prior boxes in center-offset form
163 |             Shape: [num_priors,4].
164 |         variances: (list[float]) Variances of priorboxes
165 |     Return:
166 |         encoded boxes (tensor), Shape: [num_priors, 4]
167 |     """
168 | 
169 |     # dist b/t match center and prior's center
170 |     g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
171 |     # encode variance
172 |     g_cxcy /= (variances[0] * priors[:, 2:])
173 |     # match wh / prior wh
174 |     g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
175 |     g_wh = torch.log(g_wh) / variances[1]
176 |     # return target for smooth_l1_loss
177 |     return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]
178 | 
179 | def encode_landm(matched, priors, variances):
180 |     """Encode the variances from the priorbox layers into the ground truth boxes
181 |     we have matched (based on jaccard overlap) with the prior boxes.
182 |     Args:
183 |         matched: (tensor) Coords of ground truth for each prior in point-form
184 |             Shape: [num_priors, 10].
185 |         priors: (tensor) Prior boxes in center-offset form
186 |             Shape: [num_priors,4].
187 | variances: (list[float]) Variances of priorboxes 188 | Return: 189 | encoded landm (tensor), Shape: [num_priors, 10] 190 | """ 191 | 192 | # dist b/t match center and prior's center 193 | matched = torch.reshape(matched, (matched.size(0), 5, 2)) 194 | priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) 195 | priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) 196 | priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) 197 | priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) 198 | priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) 199 | g_cxcy = matched[:, :, :2] - priors[:, :, :2] 200 | # encode variance 201 | g_cxcy /= (variances[0] * priors[:, :, 2:]) 202 | # g_cxcy /= priors[:, :, 2:] 203 | g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) 204 | # return target for smooth_l1_loss 205 | return g_cxcy 206 | 207 | 208 | # Adapted from https://github.com/Hakuyume/chainer-ssd 209 | def decode(loc, priors, variances): 210 | """Decode locations from predictions using priors to undo 211 | the encoding we did for offset regression at train time. 212 | Args: 213 | loc (tensor): location predictions for loc layers, 214 | Shape: [num_priors,4] 215 | priors (tensor): Prior boxes in center-offset form. 216 | Shape: [num_priors,4]. 217 | variances: (list[float]) Variances of priorboxes 218 | Return: 219 | decoded bounding box predictions 220 | """ 221 | 222 | boxes = torch.cat(( 223 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 224 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 225 | boxes[:, :2] -= boxes[:, 2:] / 2 226 | boxes[:, 2:] += boxes[:, :2] 227 | return boxes 228 | 229 | def decode_landm(pre, priors, variances): 230 | """Decode landm from predictions using priors to undo 231 | the encoding we did for offset regression at train time. 232 | Args: 233 | pre (tensor): landm predictions for loc layers, 234 | Shape: [num_priors,10] 235 | priors (tensor): Prior boxes in center-offset form. 236 | Shape: [num_priors,4]. 237 | variances: (list[float]) Variances of priorboxes 238 | Return: 239 | decoded landm predictions 240 | """ 241 | landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], 242 | priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], 243 | priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], 244 | priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], 245 | priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], 246 | ), dim=1) 247 | return landms 248 | 249 | 250 | def log_sum_exp(x): 251 | """Utility function for computing log_sum_exp while determining 252 | This will be used to determine unaveraged confidence loss across 253 | all examples in a batch. 254 | Args: 255 | x (Variable(tensor)): conf_preds from conf layers 256 | """ 257 | x_max = x.data.max() 258 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max 259 | 260 | 261 | # Original author: Francisco Massa: 262 | # https://github.com/fmassa/object-detection.torch 263 | # Ported to PyTorch by Max deGroot (02/01/2017) 264 | def nms(boxes, scores, overlap=0.5, top_k=200): 265 | """Apply non-maximum suppression at test time to avoid detecting too many 266 | overlapping bounding boxes for a given object. 267 | Args: 268 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 269 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 
270 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 271 | top_k: (int) The Maximum number of box preds to consider. 272 | Return: 273 | The indices of the kept boxes with respect to num_priors. 274 | """ 275 | 276 | keep = torch.Tensor(scores.size(0)).fill_(0).long() 277 | if boxes.numel() == 0: 278 | return keep 279 | x1 = boxes[:, 0] 280 | y1 = boxes[:, 1] 281 | x2 = boxes[:, 2] 282 | y2 = boxes[:, 3] 283 | area = torch.mul(x2 - x1, y2 - y1) 284 | v, idx = scores.sort(0) # sort in ascending order 285 | # I = I[v >= 0.01] 286 | idx = idx[-top_k:] # indices of the top-k largest vals 287 | xx1 = boxes.new() 288 | yy1 = boxes.new() 289 | xx2 = boxes.new() 290 | yy2 = boxes.new() 291 | w = boxes.new() 292 | h = boxes.new() 293 | 294 | # keep = torch.Tensor() 295 | count = 0 296 | while idx.numel() > 0: 297 | i = idx[-1] # index of current largest val 298 | # keep.append(i) 299 | keep[count] = i 300 | count += 1 301 | if idx.size(0) == 1: 302 | break 303 | idx = idx[:-1] # remove kept element from view 304 | # load bboxes of next highest vals 305 | torch.index_select(x1, 0, idx, out=xx1) 306 | torch.index_select(y1, 0, idx, out=yy1) 307 | torch.index_select(x2, 0, idx, out=xx2) 308 | torch.index_select(y2, 0, idx, out=yy2) 309 | # store element-wise max with next highest score 310 | xx1 = torch.clamp(xx1, min=x1[i]) 311 | yy1 = torch.clamp(yy1, min=y1[i]) 312 | xx2 = torch.clamp(xx2, max=x2[i]) 313 | yy2 = torch.clamp(yy2, max=y2[i]) 314 | w.resize_as_(xx2) 315 | h.resize_as_(yy2) 316 | w = xx2 - xx1 317 | h = yy2 - yy1 318 | # check sizes of xx1 and xx2.. after each iteration 319 | w = torch.clamp(w, min=0.0) 320 | h = torch.clamp(h, min=0.0) 321 | inter = w*h 322 | # IoU = i / (area(a) + area(b) - i) 323 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 324 | union = (rem_areas - inter) + area[i] 325 | IoU = inter/union # store result in iou 326 | # keep only elements with an IoU <= overlap 327 | idx = idx[IoU.le(overlap)] 328 | return keep, count 329 | 330 | 331 | --------------------------------------------------------------------------------
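As a quick sanity check of the box conventions used throughout box_utils.py, the standalone sketch below (not part of the repository) computes the jaccard overlap of two hand-picked boxes in (x1, y1, x2, y2) form and shows which one a greedy NMS pass with overlap=0.5 would keep. The helper `iou_xyxy` is a single-pair restatement of intersect()/jaccard() above, written out so the arithmetic is easy to follow.

import torch

def iou_xyxy(box_a, box_b):
    # Intersection-over-union for two [4] tensors in (x1, y1, x2, y2) form,
    # mirroring intersect()/jaccard() for the single-box case.
    ix1 = torch.max(box_a[0], box_b[0])
    iy1 = torch.max(box_a[1], box_b[1])
    ix2 = torch.min(box_a[2], box_b[2])
    iy2 = torch.min(box_a[3], box_b[3])
    inter = (ix2 - ix1).clamp(min=0) * (iy2 - iy1).clamp(min=0)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

a = torch.tensor([0.0, 0.0, 10.0, 10.0])   # higher-scoring detection
b = torch.tensor([1.0, 1.0, 11.0, 11.0])   # overlapping, lower-scoring duplicate
print(iou_xyxy(a, b))  # ~0.68 > 0.5, so greedy NMS keeps `a` and suppresses `b`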