├── gpeno
│ ├── face_detect
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ ├── nms
│ │ │ │ ├── __init__.py
│ │ │ │ └── py_cpu_nms.py
│ │ │ ├── timer.py
│ │ │ └── box_utils.py
│ │ ├── facemodels
│ │ │ ├── __init__.py
│ │ │ ├── retinaface.py
│ │ │ └── net.py
│ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ └── multibox_loss.py
│ │ │ └── functions
│ │ │   └── prior_box.py
│ │ ├── .DS_Store
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── wider_face.py
│ │ │ └── data_augment.py
│ │ └── retinaface_detection.py
│ ├── face_parse
│ │ ├── mask.png
│ │ ├── test.png
│ │ ├── face_parsing.py
│ │ ├── face_parsing_broken.py
│ │ ├── parse_model.py
│ │ └── blocks.py
│ ├── face_model
│ │ ├── op
│ │ │ ├── __init__.py
│ │ │ ├── fused_bias_act.cpp
│ │ │ ├── upfirdn2d.cpp
│ │ │ ├── fused_bias_act_kernel.cu
│ │ │ ├── fused_act.py
│ │ │ ├── upfirdn2d.py
│ │ │ └── upfirdn2d_kernel.cu
│ │ └── face_gan.py
│ ├── requirements.txt
│ ├── training
│ │ ├── lpips
│ │ │ ├── weights
│ │ │ │ ├── v0.0
│ │ │ │ │ ├── alex.pth
│ │ │ │ │ ├── vgg.pth
│ │ │ │ │ └── squeeze.pth
│ │ │ │ └── v0.1
│ │ │ │   ├── alex.pth
│ │ │ │   ├── vgg.pth
│ │ │ │   └── squeeze.pth
│ │ │ ├── __init__.py
│ │ │ ├── pretrained_networks.py
│ │ │ ├── lpips.py
│ │ │ └── trainer.py
│ │ ├── loss
│ │ │ ├── id_loss.py
│ │ │ ├── model_irse.py
│ │ │ └── helpers.py
│ │ └── data_loader
│ │   └── dataset_face.py
│ ├── misc
│ │ ├── cog.yaml
│ │ ├── predict.py
│ │ └── onnx_export.py
│ ├── __init_paths.py
│ ├── face_inpainting.py
│ ├── face_colorization.py
│ ├── .gitignore
│ ├── distributed.py
│ ├── face_enhancement.py
│ ├── demo.py
│ ├── align_faces.py
│ └── README.md
├── .gitattributes
├── workflows
│ └── workflow_gpeno.png
├── CHANGELOG.md
├── .github
│ └── workflows
│   └── publish_action.yml
├── pyproject.toml
├── README.md
└── __init__.py
/gpeno/face_detect/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gpeno/face_detect/facemodels/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gpeno/face_detect/utils/nms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/gpeno/face_detect/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/gpeno/face_parse/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/face_parse/mask.png
--------------------------------------------------------------------------------
/gpeno/face_parse/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/face_parse/test.png
--------------------------------------------------------------------------------
/gpeno/face_detect/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/face_detect/.DS_Store
--------------------------------------------------------------------------------
/workflows/workflow_gpeno.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/workflows/workflow_gpeno.png
--------------------------------------------------------------------------------
/gpeno/face_detect/layers/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .multibox_loss import MultiBoxLoss
2 |
3 | __all__ = ['MultiBoxLoss']
4 |
--------------------------------------------------------------------------------
/gpeno/face_model/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/gpeno/requirements.txt:
--------------------------------------------------------------------------------
1 | ninja
2 | torch
3 | torchvision
4 | opencv-python
5 | numpy
6 | scikit-image
7 | scipy
8 | pillow
9 | tqdm
--------------------------------------------------------------------------------
/gpeno/training/lpips/weights/v0.0/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.0/alex.pth
--------------------------------------------------------------------------------
/gpeno/training/lpips/weights/v0.0/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.0/vgg.pth
--------------------------------------------------------------------------------
/gpeno/training/lpips/weights/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.1/alex.pth
--------------------------------------------------------------------------------
/gpeno/training/lpips/weights/v0.1/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.1/vgg.pth
--------------------------------------------------------------------------------
/gpeno/training/lpips/weights/v0.0/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.0/squeeze.pth
--------------------------------------------------------------------------------
/gpeno/training/lpips/weights/v0.1/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SparknightLLC/ComfyUI-GPENO/HEAD/gpeno/training/lpips/weights/v0.1/squeeze.pth
--------------------------------------------------------------------------------
/gpeno/face_detect/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .wider_face import WiderFaceDetection, detection_collate
2 | from .data_augment import *
3 | from .config import *
4 |
--------------------------------------------------------------------------------
/gpeno/misc/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | python_version: "3.8"
4 | system_packages:
5 | - "libgl1-mesa-glx"
6 | - "libglib2.0-0"
7 | - "ninja-build"
8 | python_packages:
9 | - "torch==1.7.1"
10 | - "torchvision==0.8.2"
11 | - "numpy==1.20.1"
12 | - "ipython==7.21.0"
13 | - "Pillow==8.3.1"
14 | - "scikit-image==0.18.3"
15 | - "opencv-python==4.5.3.56"
16 |
17 | predict: "predict.py:Predictor"
18 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | All notable changes to this project will be documented in this file.
2 |
3 | 0.1.0 - 15 April 2025
4 |
5 | **Added**
6 | - New `colorize` setting: Bring old photos to life
7 | - Returns the original face and enhanced face in addition to the resulting masked image
8 |
9 |
10 |
11 | 0.0.1 - 24 December 2024
12 |
13 | **Added**
14 | - Initial release
15 |
16 |
--------------------------------------------------------------------------------
/gpeno/__init_paths.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import os.path as osp
6 | import sys
7 |
8 |
9 | def add_path(path):
10 | if path not in sys.path:
11 | sys.path.insert(0, path)
12 |
13 |
14 | this_dir = osp.dirname(__file__)
15 |
16 | path = osp.join(this_dir, 'face_detect')
17 | add_path(path)
18 |
19 | path = osp.join(this_dir, 'face_parse')
20 | add_path(path)
21 |
22 | path = osp.join(this_dir, 'face_model')
23 | add_path(path)
24 |
--------------------------------------------------------------------------------
/.github/workflows/publish_action.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "pyproject.toml"
9 |
10 | permissions:
11 | issues: write
12 |
13 | jobs:
14 | publish-node:
15 | name: Publish Custom Node to registry
16 | runs-on: ubuntu-latest
17 | if: ${{ github.repository_owner == 'SparknightLLC' }}
18 | steps:
19 | - name: Check out code
20 | uses: actions/checkout@v4
21 | - name: Publish Custom Node
22 | uses: Comfy-Org/publish-node-action@v1
23 | with:
24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
25 |
--------------------------------------------------------------------------------
/gpeno/face_inpainting.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | from face_model.face_gan import FaceGAN
6 |
7 | class FaceInpainting(object):
8 | def __init__(self, base_dir='./', in_size=1024, out_size=1024, model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'):
9 | self.facegan = FaceGAN(base_dir, in_size, out_size, model, channel_multiplier, narrow, key, device=device)
10 |
11 | # make sure the face image is well aligned. Please refer to face_enhancement.py
12 | def process(self, brokenf, aligned=True):
13 | # complete the face
14 | out = self.facegan.process(brokenf)
15 |
16 | return out, [brokenf], [out]
17 |
18 |
19 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # pyproject.toml
2 | [project]
3 | name = "comfyui-gpeno" # Unique identifier for your node. Immutable after creation.
4 | description = "A node for ComfyUI that performs GPEN face restoration on the input image(s). Significantly faster than other implementations of GPEN."
5 | version = "0.1.0" # Custom Node version. Must be semantically versioned.
6 | license = { file = "LICENSE.txt" }
7 | dependencies = [] # Filled in from requirements.txt
8 |
9 | [project.urls]
10 | Repository = "https://github.com/SparknightLLC/ComfyUI-GPENO"
11 |
12 | [tool.comfy]
13 | PublisherId = "sparknight" # TODO (fill in Publisher ID from Comfy Registry Website).
14 | DisplayName = "ComfyUI-GPENO" # Display name for the Custom Node. Can be changed later.
15 | Icon = "https://example.com/icon.png" # SVG, PNG, JPG or GIF (MAX. 800x400px)
--------------------------------------------------------------------------------
/gpeno/face_model/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/gpeno/face_model/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/gpeno/face_detect/utils/nms/py_cpu_nms.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 |
10 | def py_cpu_nms(dets, thresh):
11 | """Pure Python NMS baseline."""
12 | x1 = dets[:, 0]
13 | y1 = dets[:, 1]
14 | x2 = dets[:, 2]
15 | y2 = dets[:, 3]
16 | scores = dets[:, 4]
17 |
18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
19 | order = scores.argsort()[::-1]
20 |
21 | keep = []
22 | while order.size > 0:
23 | i = order[0]
24 | keep.append(i)
25 | xx1 = np.maximum(x1[i], x1[order[1:]])
26 | yy1 = np.maximum(y1[i], y1[order[1:]])
27 | xx2 = np.minimum(x2[i], x2[order[1:]])
28 | yy2 = np.minimum(y2[i], y2[order[1:]])
29 |
30 | w = np.maximum(0.0, xx2 - xx1 + 1)
31 | h = np.maximum(0.0, yy2 - yy1 + 1)
32 | inter = w * h
33 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
34 |
35 | inds = np.where(ovr <= thresh)[0]
36 | order = order[inds + 1]
37 |
38 | return keep
39 |
--------------------------------------------------------------------------------
/gpeno/face_detect/data/config.py:
--------------------------------------------------------------------------------
1 | # config.py
2 |
3 | cfg_mnet = {
4 | 'name': 'detection_mobilenet0.25_Final',
5 | 'min_sizes': [[16, 32], [64, 128], [256, 512]],
6 | 'steps': [8, 16, 32],
7 | 'variance': [0.1, 0.2],
8 | 'clip': False,
9 | 'loc_weight': 2.0,
10 | 'gpu_train': True,
11 | 'batch_size': 32,
12 | 'ngpu': 1,
13 | 'epoch': 250,
14 | 'decay1': 190,
15 | 'decay2': 220,
16 | 'image_size': 640,
17 | 'pretrain': False,
18 | 'return_layers': {
19 | 'stage1': 1,
20 | 'stage2': 2,
21 | 'stage3': 3
22 | },
23 | 'in_channel': 32,
24 | 'out_channel': 64,
25 | }
26 |
27 | cfg_re50 = {
28 | 'name': 'detection_Resnet50_Final', #'Resnet50',
29 | 'min_sizes': [[16, 32], [64, 128], [256, 512]],
30 | 'steps': [8, 16, 32],
31 | 'variance': [0.1, 0.2],
32 | 'clip': False,
33 | 'loc_weight': 2.0,
34 | 'gpu_train': True,
35 | 'batch_size': 24,
36 | 'ngpu': 4,
37 | 'epoch': 100,
38 | 'decay1': 70,
39 | 'decay2': 90,
40 | 'image_size': 840,
41 | 'pretrain': False,
42 | 'return_layers': {
43 | 'layer2': 1,
44 | 'layer3': 2,
45 | 'layer4': 3
46 | },
47 | 'in_channel': 256,
48 | 'out_channel': 256,
49 | }
50 |
--------------------------------------------------------------------------------
/gpeno/face_colorization.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import cv2
6 | from face_model.face_gan import FaceGAN
7 |
8 | class FaceColorization(object):
9 | def __init__(self, base_dir='./', in_size=1024, out_size=1024, model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'):
10 | self.facegan = FaceGAN(base_dir, in_size, out_size, model, channel_multiplier, narrow, key, device=device)
11 |
12 | def post_process(self, gray, out):
13 | out_rs = cv2.resize(out, gray.shape[:2][::-1])
14 | gray_yuv = cv2.cvtColor(gray, cv2.COLOR_BGR2YUV)
15 | out_yuv = cv2.cvtColor(out_rs, cv2.COLOR_BGR2YUV)
16 |
17 | out_yuv[:, :, 0] = gray_yuv[:, :, 0]
18 | final = cv2.cvtColor(out_yuv, cv2.COLOR_YUV2BGR)
19 |
20 | return final
21 |
22 | # make sure the face image is well aligned. Please refer to face_enhancement.py
23 | def process(self, gray, aligned=True):
24 | # colorize the face
25 | out = self.facegan.process(gray)
26 |
27 | if gray.shape[:2] != out.shape[:2]:
28 | out = self.post_process(gray, out)
29 |
30 | return out, [gray], [out]
31 |
32 |
33 |
--------------------------------------------------------------------------------
/gpeno/face_detect/utils/timer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import time
9 |
10 |
11 | class Timer(object):
12 | """A simple timer."""
13 | def __init__(self):
14 | self.total_time = 0.
15 | self.calls = 0
16 | self.start_time = 0.
17 | self.diff = 0.
18 | self.average_time = 0.
19 |
20 | def tic(self):
21 | # using time.time instead of time.clock because time.clock
22 | # does not normalize for multithreading
23 | self.start_time = time.time()
24 |
25 | def toc(self, average=True):
26 | self.diff = time.time() - self.start_time
27 | self.total_time += self.diff
28 | self.calls += 1
29 | self.average_time = self.total_time / self.calls
30 | if average:
31 | return self.average_time
32 | else:
33 | return self.diff
34 |
35 | def clear(self):
36 | self.total_time = 0.
37 | self.calls = 0
38 | self.start_time = 0.
39 | self.diff = 0.
40 | self.average_time = 0.
41 |
--------------------------------------------------------------------------------
/gpeno/face_detect/layers/functions/prior_box.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from itertools import product as product
3 | import numpy as np
4 | from math import ceil
5 |
6 |
7 | class PriorBox(object):
8 | def __init__(self, cfg, image_size=None, phase='train'):
9 | super(PriorBox, self).__init__()
10 | self.min_sizes = cfg['min_sizes']
11 | self.steps = cfg['steps']
12 | self.clip = cfg['clip']
13 | self.image_size = image_size
14 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
15 | self.name = "s"
16 |
17 | def forward(self):
18 | anchors = []
19 | for k, f in enumerate(self.feature_maps):
20 | min_sizes = self.min_sizes[k]
21 | for i, j in product(range(f[0]), range(f[1])):
22 | for min_size in min_sizes:
23 | s_kx = min_size / self.image_size[1]
24 | s_ky = min_size / self.image_size[0]
25 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
26 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
27 | for cy, cx in product(dense_cy, dense_cx):
28 | anchors += [cx, cy, s_kx, s_ky]
29 |
30 | # back to torch land
31 | output = torch.Tensor(anchors).view(-1, 4)
32 | if self.clip:
33 | output.clamp_(max=1, min=0)
34 | return output
35 |
--------------------------------------------------------------------------------
/gpeno/training/loss/id_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | from model_irse import Backbone
5 |
6 | class IDLoss(nn.Module):
7 | def __init__(self, base_dir='./', device='cuda', ckpt_dict=None):
8 | super(IDLoss, self).__init__()
9 | print('Loading ResNet ArcFace')
10 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se').to(device)
11 | if ckpt_dict is None:
12 | self.facenet.load_state_dict(torch.load(os.path.join(base_dir, 'weights', 'model_ir_se50.pth'), map_location=torch.device('cpu')))
13 | else:
14 | self.facenet.load_state_dict(ckpt_dict)
15 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
16 | self.facenet.eval()
17 |
18 | def extract_feats(self, x):
19 | _, _, h, w = x.shape
20 | assert h==w
21 | ss = h//256
22 | x = x[:, :, 35*ss:-33*ss, 32*ss:-36*ss] # Crop interesting region
23 | x = self.face_pool(x)
24 | x_feats = self.facenet(x)
25 | return x_feats
26 |
27 | def forward(self, y_hat, y, x):
28 | n_samples = x.shape[0]
29 | x_feats = self.extract_feats(x)
30 | y_feats = self.extract_feats(y) # Otherwise use the feature from there
31 | y_hat_feats = self.extract_feats(y_hat)
32 | y_feats = y_feats.detach()
33 | loss = 0
34 | sim_improvement = 0
35 | id_logs = []
36 | count = 0
37 | for i in range(n_samples):
38 | diff_target = y_hat_feats[i].dot(y_feats[i])
39 | diff_input = y_hat_feats[i].dot(x_feats[i])
40 | diff_views = y_feats[i].dot(x_feats[i])
41 | id_logs.append({'diff_target': float(diff_target),
42 | 'diff_input': float(diff_input),
43 | 'diff_views': float(diff_views)})
44 | loss += 1 - diff_target
45 | id_diff = float(diff_target) - float(diff_views)
46 | sim_improvement += id_diff
47 | count += 1
48 |
49 | return loss / count, sim_improvement / count, id_logs
50 |
51 |
--------------------------------------------------------------------------------
/gpeno/face_parse/face_parsing.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import os
6 | import cv2
7 | import torch
8 | import numpy as np
9 | from parse_model import ParseNet
10 | import torch.nn.functional as F
11 |
12 |
13 | class FaceParse:
14 |
15 | def __init__(self, base_dir='./', model='ParseNet-latest', device='cuda'):
16 | self.mfile = os.path.join(base_dir, 'facerestore_models', model + '.pth')
17 | self.size = 512
18 | self.device = device
19 | self.MASK_COLORMAP = torch.tensor([0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 255, 255, 0], device=self.device)
20 | self.load_model()
21 | print("FaceParse initialized")
22 |
23 | def load_model(self):
24 | self.faceparse = ParseNet(self.size, self.size, 32, 64, 19, norm_type='bn', relu_type='LeakyReLU', ch_range=[32, 256])
25 | self.faceparse.load_state_dict(torch.load(self.mfile, map_location=self.device))
26 | self.faceparse.to(self.device)
27 | self.faceparse.eval()
28 |
29 | def process(self, im):
30 | im = cv2.resize(im, (self.size, self.size))
31 | imt = self.img2tensor(im)
32 | with torch.no_grad():
33 | pred_mask, _ = self.faceparse(imt)
34 | mask = self.tensor2mask(pred_mask)
35 | return mask
36 |
37 | def process_tensor(self, imt):
38 | imt = F.interpolate(imt.flip(1) * 2 - 1, (self.size, self.size))
39 | with torch.no_grad():
40 | pred_mask, _ = self.faceparse(imt)
41 | mask = pred_mask.argmax(dim=1)
42 | mask = self.MASK_COLORMAP[mask].unsqueeze(0)
43 | return mask
44 |
45 | def img2tensor(self, img):
46 | img = img[..., ::-1] # Convert BGR to RGB
47 | img = (img / 255.0) * 2 - 1 # Scale image to [-1, 1]
48 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
49 | return img_tensor.float()
50 |
51 | def tensor2mask(self, tensor):
52 | if len(tensor.shape) < 4:
53 | tensor = tensor.unsqueeze(0)
54 | if tensor.shape[1] > 1:
55 | tensor = tensor.argmax(dim=1)
56 | tensor = tensor.squeeze(1).cpu().numpy()
57 | color_maps = [self.MASK_COLORMAP[t].cpu().numpy().astype(np.uint8) for t in tensor]
58 | return color_maps
59 |
--------------------------------------------------------------------------------
/gpeno/face_parse/face_parsing_broken.py:
--------------------------------------------------------------------------------
1 | '''Modified for use with Unprompted.'''
2 | '''
3 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
4 | @author: yangxy (yangtao9009@gmail.com)
5 | '''
6 |
7 | import os
8 | import cv2
9 | import torch
10 | import numpy as np
11 | from parse_model import ParseNet
12 | import torch.nn.functional as F
13 |
14 |
15 | class FaceParse(object):
16 | def __init__(self, base_dir='./', model='ParseNet-latest', device='cuda'):
17 | self.mfile = os.path.join(base_dir, 'gpen', model + '.pth')
18 | self.size = 512
19 | self.device = device
20 | self.MASK_COLORMAP = np.array([0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 255, 255, 0])
21 | self.load_model()
22 |
23 | def load_model(self):
24 | self.faceparse = ParseNet(self.size, self.size, 32, 64, 19, norm_type='bn', relu_type='LeakyReLU', ch_range=[32, 256])
25 | self.faceparse.load_state_dict(torch.load(self.mfile, map_location=torch.device('cpu')))
26 | self.faceparse.to(self.device)
27 | self.faceparse.eval()
28 |
29 | def process(self, im):
30 | im = cv2.resize(im, (self.size, self.size))
31 | imt = self.img2tensor(im)
32 | with torch.no_grad():
33 | pred_mask, _ = self.faceparse(imt)
34 | mask = self.tensor2mask(pred_mask)
35 | return mask
36 |
37 | def img2tensor(self, img):
38 | img = img[..., ::-1] # BGR to RGB
39 | img = img / 255.0 * 2 - 1
40 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device).float()
41 | return img_tensor
42 |
43 | def tensor2mask(self, tensor):
44 | if tensor.shape[1] > 1:
45 | tensor = tensor.argmax(dim=1)
46 | tensor = tensor.squeeze(0).data.cpu().numpy()
47 | mask = np.take(self.MASK_COLORMAP, tensor)
48 | if mask.ndim == 3:
49 | mask = mask[:, :, 0] # Ensure the mask is 2D
50 | return mask.astype(np.float32)
51 |
52 | def process_tensor(self, imt):
53 | imt = F.interpolate(imt.flip(1) * 2 - 1, (self.size, self.size))
54 | with torch.no_grad():
55 | pred_mask, _ = self.faceparse(imt)
56 | mask = pred_mask.argmax(dim=1).squeeze(0)
57 | mask = torch.where(mask < len(self.MASK_COLORMAP), torch.tensor(self.MASK_COLORMAP, device=mask.device)[mask], mask)
58 | return mask.unsqueeze(0)
59 |
--------------------------------------------------------------------------------
/gpeno/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/gpeno/face_model/face_gan.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import torch
6 | import os
7 | import cv2
8 | import glob
9 | import numpy as np
10 | from torch import nn
11 | import torch.nn.functional as F
12 |
13 | from torchvision import transforms, utils
14 | from .gpen_model import FullGenerator, FullGenerator_SR
15 |
16 |
17 | class FaceGAN(object):
18 |
19 | def __init__(self, base_dir='./', in_size=512, out_size=None, model=None, channel_multiplier=2, narrow=1, key=None, is_norm=True, device='cuda'):
20 | print(f"Initializing FaceGAN on {device} device...")
21 | self.mfile = os.path.join(base_dir, 'facerestore_models', model + '.pth')
22 | self.n_mlp = 8
23 | self.device = device
24 | self.is_norm = is_norm
25 | self.in_resolution = in_size
26 | self.out_resolution = in_size if out_size is None else out_size
27 | self.key = key
28 | self.load_model(channel_multiplier, narrow)
29 | print(f"FaceGAN initialized")
30 |
31 | def load_model(self, channel_multiplier=2, narrow=1):
32 | if self.in_resolution == self.out_resolution:
33 | self.model = FullGenerator(self.in_resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device)
34 | else:
35 | self.model = FullGenerator_SR(self.in_resolution, self.out_resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow, device=self.device)
36 |
37 | pretrained_dict = torch.load(self.mfile, map_location=self.device)
38 |
39 | #if self.key is not None:
40 | # pretrained_dict = pretrained_dict[self.key]
41 |
42 | self.model.load_state_dict(pretrained_dict)
43 | self.model.to(self.device)
44 | self.model.eval()
45 |
46 | def process(self, img):
47 | torch.backends.cudnn.deterministic = True
48 | torch.backends.cudnn.benchmark = False
49 | img = cv2.resize(img, (self.in_resolution, self.in_resolution))
50 | img_t = self.img2tensor(img)
51 |
52 | with torch.no_grad():
53 | out, __ = self.model(img_t)
54 | # del img_t
55 |
56 | out = self.tensor2img(out)
57 |
58 | return out
59 |
60 | def img2tensor(self, img):
61 | img_t = torch.from_numpy(img).to(self.device) / 255.
62 | if self.is_norm:
63 | img_t = (img_t - 0.5) / 0.5
64 | img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB
65 | return img_t
66 |
67 | def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8):
68 | if self.is_norm:
69 | img_t = img_t * 0.5 + 0.5
70 | img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
71 | img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax
72 |
73 | return img_np.astype(imtype)
74 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-GPENO
2 |
3 | A node for [ComfyUI](https://github.com/comfyanonymous/ComfyUI) that performs [GPEN face restoration](https://github.com/yangxy/GPEN) on the input image(s). The "O" in "GPENO" stands for "Optimized," as I have implemented various performance improvements that significantly boost speed.
4 |
5 | 
6 |
7 | ### Installation
8 |
9 | Simply drag the image above into ComfyUI and use [ComfyUI Manager » Install Missing Custom Nodes](https://github.com/ltdrdata/ComfyUI-Manager).
10 |
11 | > [!NOTE]
12 | > ComfyUI-GPENO will download GPEN dependencies upon first use. The destinations are `comfyui/models/facerestore_models` and `comfyui/models/facedetection`.
13 |
14 | ### Performance Evaluation
15 |
16 | Face restoration is commonly applied as a post-processing step in faceswap workflows. I ran a few tests against the popular [ReActor node](https://github.com/Gourieff/ComfyUI-ReActor), which includes both face restoration and GPEN features:
17 |
18 | - ReActor end-to-end time using the `GPEN-BFR-512` model for face restoration is about **1.4 seconds** on my GeForce 3090.
19 | - ReActor followed by this GPENO node using the same model takes about **0.7 seconds** - almost exactly a 2x speedup with no loss of quality.
20 | - Applying the GPENO node directly to my test images took about **0.5 seconds**.
21 |
22 | Note that your inference time will depend on a number of factors, including input resolution and the number of faces in the image, but you can generally expect a 2-3x speedup over other implementations of GPEN.
23 |
24 | ### Advantages
25 |
26 | Apart from the speed, here are some other reasons why you might want to use GPENO in your projects:
27 |
28 | 1. It is not coupled with other functions such as faceswap, making it easy to apply to an image without additional processing
29 | 2. It uses a global cache, which helps you save VRAM if you have multiple instances of GPENO in your workflow
30 | 3. It exposes more controls for interfacing with GPEN than you would typically find in an all-in-one node
31 |
32 | ### Inputs
33 |
34 | - `image`: An image or batch of images to process with GPEN
35 | - `use_global_cache`: If enabled, the model will be loaded once and shared across all instances of this node. This saves VRAM if you are using multiple instances of GPENO in your workflow, but the settings must remain the same for all instances (see the sketch at the end of this README).
36 | - `unload`: If enabled, the model will be freed from the cache at the start of this node's execution (if applicable), and it will not be cached again.
37 | - Please refer to the GPEN repository for more information on the remaining controls
38 |
39 | ---
40 |
41 | This node was adapted from the `[restore_faces]` shortcode of [Unprompted](https://github.com/ThereforeGames/unprompted), my Automatic1111 extension.
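42 | 
43 | For reference, the `use_global_cache` and `unload` inputs can be pictured with the minimal sketch below. It is illustrative only: `_CACHED_MODEL`, `load_gpen` and `get_model` are hypothetical names standing in for the node's internals, not its actual API.
44 | 
45 | ```python
46 | # Minimal sketch only; the names below are hypothetical, not the node's real internals.
47 | _CACHED_MODEL = None  # shared by every GPENO instance in the workflow
48 | 
49 | 
50 | def load_gpen(**settings):
51 |     """Placeholder for the actual model construction (FaceGAN, detector, parser)."""
52 |     return object()
53 | 
54 | 
55 | def get_model(settings, use_global_cache=True, unload=False):
56 |     global _CACHED_MODEL
57 |     if unload:
58 |         _CACHED_MODEL = None  # free the cached copy at the start of execution...
59 |         return load_gpen(**settings)  # ...and do not cache the new one
60 |     if not use_global_cache:
61 |         return load_gpen(**settings)  # per-instance model; uses more VRAM
62 |     if _CACHED_MODEL is None:
63 |         _CACHED_MODEL = load_gpen(**settings)  # loaded once, then shared
64 |     return _CACHED_MODEL  # assumes every instance uses the same settings
65 | ```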
--------------------------------------------------------------------------------
/gpeno/face_parse/parse_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Created by chaofengc (chaofenghust@gmail.com)
3 |
4 | @Modified by yangxy (yangtao9009@gmail.com)
5 | '''
6 |
7 | from blocks import *
8 | import torch
9 | from torch import nn
10 | import numpy as np
11 |
12 | def define_P(in_size=512, out_size=512, min_feat_size=32, relu_type='LeakyReLU', isTrain=False, weight_path=None):
13 | net = ParseNet(in_size, out_size, min_feat_size, 64, 19, norm_type='bn', relu_type=relu_type, ch_range=[32, 256])
14 | if not isTrain:
15 | net.eval()
16 | if weight_path is not None:
17 | net.load_state_dict(torch.load(weight_path))
18 | return net
19 |
20 |
21 | class ParseNet(nn.Module):
22 | def __init__(self,
23 | in_size=128,
24 | out_size=128,
25 | min_feat_size=32,
26 | base_ch=64,
27 | parsing_ch=19,
28 | res_depth=10,
29 | relu_type='prelu',
30 | norm_type='bn',
31 | ch_range=[32, 512],
32 | ):
33 | super().__init__()
34 | self.res_depth = res_depth
35 | act_args = {'norm_type': norm_type, 'relu_type': relu_type}
36 | min_ch, max_ch = ch_range
37 |
38 | ch_clip = lambda x: max(min_ch, min(x, max_ch))
39 | min_feat_size = min(in_size, min_feat_size)
40 |
41 | down_steps = int(np.log2(in_size//min_feat_size))
42 | up_steps = int(np.log2(out_size//min_feat_size))
43 |
44 | # =============== define encoder-body-decoder ====================
45 | self.encoder = []
46 | self.encoder.append(ConvLayer(3, base_ch, 3, 1))
47 | head_ch = base_ch
48 | for i in range(down_steps):
49 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
50 | self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
51 | head_ch = head_ch * 2
52 |
53 | self.body = []
54 | for i in range(res_depth):
55 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
56 |
57 | self.decoder = []
58 | for i in range(up_steps):
59 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
60 | self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
61 | head_ch = head_ch // 2
62 |
63 | self.encoder = nn.Sequential(*self.encoder)
64 | self.body = nn.Sequential(*self.body)
65 | self.decoder = nn.Sequential(*self.decoder)
66 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
67 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
68 |
69 | def forward(self, x):
70 | feat = self.encoder(x)
71 | x = feat + self.body(feat)
72 | x = self.decoder(x)
73 | out_img = self.out_img_conv(x)
74 | out_mask = self.out_mask_conv(x)
75 | return out_mask, out_img
76 |
77 |
78 |
--------------------------------------------------------------------------------
/gpeno/face_model/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 | 
17 | 
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/gpeno/training/loss/model_irse.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2 | #from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3 | from helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
4 |
5 | """
6 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7 | """
8 |
9 |
10 | class Backbone(Module):
11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
12 | super(Backbone, self).__init__()
13 | assert input_size in [112, 224], "input_size should be 112 or 224"
14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
16 | blocks = get_blocks(num_layers)
17 | if mode == 'ir':
18 | unit_module = bottleneck_IR
19 | elif mode == 'ir_se':
20 | unit_module = bottleneck_IR_SE
21 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
22 | BatchNorm2d(64),
23 | PReLU(64))
24 | if input_size == 112:
25 | self.output_layer = Sequential(BatchNorm2d(512),
26 | Dropout(drop_ratio),
27 | Flatten(),
28 | Linear(512 * 7 * 7, 512),
29 | BatchNorm1d(512, affine=affine))
30 | else:
31 | self.output_layer = Sequential(BatchNorm2d(512),
32 | Dropout(drop_ratio),
33 | Flatten(),
34 | Linear(512 * 14 * 14, 512),
35 | BatchNorm1d(512, affine=affine))
36 |
37 | modules = []
38 | for block in blocks:
39 | for bottleneck in block:
40 | modules.append(unit_module(bottleneck.in_channel,
41 | bottleneck.depth,
42 | bottleneck.stride))
43 | self.body = Sequential(*modules)
44 |
45 | def forward(self, x):
46 | x = self.input_layer(x)
47 | x = self.body(x)
48 | x = self.output_layer(x)
49 | return l2_norm(x)
50 |
51 |
52 | def IR_50(input_size):
53 | """Constructs a ir-50 model."""
54 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
55 | return model
56 |
57 |
58 | def IR_101(input_size):
59 | """Constructs a ir-101 model."""
60 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
61 | return model
62 |
63 |
64 | def IR_152(input_size):
65 | """Constructs a ir-152 model."""
66 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
67 | return model
68 |
69 |
70 | def IR_SE_50(input_size):
71 | """Constructs a ir_se-50 model."""
72 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
73 | return model
74 |
75 |
76 | def IR_SE_101(input_size):
77 | """Constructs a ir_se-101 model."""
78 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
79 | return model
80 |
81 |
82 | def IR_SE_152(input_size):
83 | """Constructs a ir_se-152 model."""
84 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
85 | return model
86 |
--------------------------------------------------------------------------------
/gpeno/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | def get_rank():
10 | if not dist.is_available():
11 | return 0
12 |
13 | if not dist.is_initialized():
14 | return 0
15 |
16 | return dist.get_rank()
17 |
18 |
19 | def synchronize():
20 | if not dist.is_available():
21 | return
22 |
23 | if not dist.is_initialized():
24 | return
25 |
26 | world_size = dist.get_world_size()
27 |
28 | if world_size == 1:
29 | return
30 |
31 | dist.barrier()
32 |
33 |
34 | def get_world_size():
35 | if not dist.is_available():
36 | return 1
37 |
38 | if not dist.is_initialized():
39 | return 1
40 |
41 | return dist.get_world_size()
42 |
43 |
44 | def reduce_sum(tensor):
45 | if not dist.is_available():
46 | return tensor
47 |
48 | if not dist.is_initialized():
49 | return tensor
50 |
51 | tensor = tensor.clone()
52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53 |
54 | return tensor
55 |
56 |
57 | def gather_grad(params):
58 | world_size = get_world_size()
59 |
60 | if world_size == 1:
61 | return
62 |
63 | for param in params:
64 | if param.grad is not None:
65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66 | param.grad.data.div_(world_size)
67 |
68 |
69 | def all_gather(data):
70 | world_size = get_world_size()
71 |
72 | if world_size == 1:
73 | return [data]
74 |
75 | buffer = pickle.dumps(data)
76 | storage = torch.ByteStorage.from_buffer(buffer)
77 | tensor = torch.ByteTensor(storage).to('cuda')
78 |
79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81 | dist.all_gather(size_list, local_size)
82 | size_list = [int(size.item()) for size in size_list]
83 | max_size = max(size_list)
84 |
85 | tensor_list = []
86 | for _ in size_list:
87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88 |
89 | if local_size != max_size:
90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91 | tensor = torch.cat((tensor, padding), 0)
92 |
93 | dist.all_gather(tensor_list, tensor)
94 |
95 | data_list = []
96 |
97 | for size, tensor in zip(size_list, tensor_list):
98 | buffer = tensor.cpu().numpy().tobytes()[:size]
99 | data_list.append(pickle.loads(buffer))
100 |
101 | return data_list
102 |
103 |
104 | def reduce_loss_dict(loss_dict):
105 | world_size = get_world_size()
106 |
107 | if world_size < 2:
108 | return loss_dict
109 |
110 | with torch.no_grad():
111 | keys = []
112 | losses = []
113 |
114 | for k in sorted(loss_dict.keys()):
115 | keys.append(k)
116 | losses.append(loss_dict[k])
117 |
118 | losses = torch.stack(losses, 0)
119 | dist.reduce(losses, dst=0)
120 |
121 | if dist.get_rank() == 0:
122 | losses /= world_size
123 |
124 | reduced_losses = {k: v for k, v in zip(keys, losses)}
125 |
126 | return reduced_losses
127 |
--------------------------------------------------------------------------------
/gpeno/face_model/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Function
8 | from torch.utils.cpp_extension import load, _import_module_from_library
9 |
10 | # if running GPEN without cuda, please comment line 11-19
11 | if platform.system() == 'Linux' and torch.cuda.is_available():
12 | module_path = os.path.dirname(__file__)
13 | fused = load(
14 | 'fused',
15 | sources=[
16 | os.path.join(module_path, 'fused_bias_act.cpp'),
17 | os.path.join(module_path, 'fused_bias_act_kernel.cu'),
18 | ],
19 | )
20 |
21 |
22 | #fused = _import_module_from_library('fused', '/tmp/torch_extensions/fused', True)
23 |
24 |
25 | class FusedLeakyReLUFunctionBackward(Function):
26 | @staticmethod
27 | def forward(ctx, grad_output, out, negative_slope, scale):
28 | ctx.save_for_backward(out)
29 | ctx.negative_slope = negative_slope
30 | ctx.scale = scale
31 |
32 | empty = grad_output.new_empty(0)
33 |
34 | grad_input = fused.fused_bias_act(
35 | grad_output, empty, out, 3, 1, negative_slope, scale
36 | )
37 |
38 | dim = [0]
39 |
40 | if grad_input.ndim > 2:
41 | dim += list(range(2, grad_input.ndim))
42 |
43 | grad_bias = grad_input.sum(dim).detach()
44 |
45 | return grad_input, grad_bias
46 |
47 | @staticmethod
48 | def backward(ctx, gradgrad_input, gradgrad_bias):
49 | out, = ctx.saved_tensors
50 | gradgrad_out = fused.fused_bias_act(
51 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
52 | )
53 |
54 | return gradgrad_out, None, None, None
55 |
56 |
57 | class FusedLeakyReLUFunction(Function):
58 | @staticmethod
59 | def forward(ctx, input, bias, negative_slope, scale):
60 | empty = input.new_empty(0)
61 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
62 | ctx.save_for_backward(out)
63 | ctx.negative_slope = negative_slope
64 | ctx.scale = scale
65 |
66 | return out
67 |
68 | @staticmethod
69 | def backward(ctx, grad_output):
70 | out, = ctx.saved_tensors
71 |
72 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
73 | grad_output, out, ctx.negative_slope, ctx.scale
74 | )
75 |
76 | return grad_input, grad_bias, None, None
77 |
78 |
79 | class FusedLeakyReLU(nn.Module):
80 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, device='cpu'):
81 | super().__init__()
82 |
83 | self.bias = nn.Parameter(torch.zeros(channel))
84 | self.negative_slope = negative_slope
85 | self.scale = scale
86 | self.device = device
87 |
88 | def forward(self, input):
89 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale, self.device)
90 |
91 |
92 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5, device='cpu'):
93 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu':
94 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
95 | else:
96 | return scale * F.leaky_relu(input + bias.view((1, -1)+(1,)*(len(input.shape)-2)), negative_slope=negative_slope)
97 |
--------------------------------------------------------------------------------
/gpeno/face_detect/data/wider_face.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import sys
4 | import torch
5 | import torch.utils.data as data
6 | import cv2
7 | import numpy as np
8 |
9 | class WiderFaceDetection(data.Dataset):
10 | def __init__(self, txt_path, preproc=None):
11 | self.preproc = preproc
12 | self.imgs_path = []
13 | self.words = []
14 | f = open(txt_path,'r')
15 | lines = f.readlines()
16 | isFirst = True
17 | labels = []
18 | for line in lines:
19 | line = line.rstrip()
20 | if line.startswith('#'):
21 | if isFirst is True:
22 | isFirst = False
23 | else:
24 | labels_copy = labels.copy()
25 | self.words.append(labels_copy)
26 | labels.clear()
27 | path = line[2:]
28 | path = txt_path.replace('label.txt','images/') + path
29 | self.imgs_path.append(path)
30 | else:
31 | line = line.split(' ')
32 | label = [float(x) for x in line]
33 | labels.append(label)
34 |
35 | self.words.append(labels)
36 |
37 | def __len__(self):
38 | return len(self.imgs_path)
39 |
40 | def __getitem__(self, index):
41 | img = cv2.imread(self.imgs_path[index])
42 | height, width, _ = img.shape
43 |
44 | labels = self.words[index]
45 | annotations = np.zeros((0, 15))
46 | if len(labels) == 0:
47 | return annotations
48 | for idx, label in enumerate(labels):
49 | annotation = np.zeros((1, 15))
50 | # bbox
51 | annotation[0, 0] = label[0] # x1
52 | annotation[0, 1] = label[1] # y1
53 | annotation[0, 2] = label[0] + label[2] # x2
54 | annotation[0, 3] = label[1] + label[3] # y2
55 |
56 | # landmarks
57 | annotation[0, 4] = label[4] # l0_x
58 | annotation[0, 5] = label[5] # l0_y
59 | annotation[0, 6] = label[7] # l1_x
60 | annotation[0, 7] = label[8] # l1_y
61 | annotation[0, 8] = label[10] # l2_x
62 | annotation[0, 9] = label[11] # l2_y
63 | annotation[0, 10] = label[13] # l3_x
64 | annotation[0, 11] = label[14] # l3_y
65 | annotation[0, 12] = label[16] # l4_x
66 | annotation[0, 13] = label[17] # l4_y
67 | if (annotation[0, 4]<0):
68 | annotation[0, 14] = -1
69 | else:
70 | annotation[0, 14] = 1
71 |
72 | annotations = np.append(annotations, annotation, axis=0)
73 | target = np.array(annotations)
74 | if self.preproc is not None:
75 | img, target = self.preproc(img, target)
76 |
77 | return torch.from_numpy(img), target
78 |
79 | def detection_collate(batch):
80 | """Custom collate fn for dealing with batches of images that have a different
81 | number of associated object annotations (bounding boxes).
82 |
83 | Arguments:
84 | batch: (tuple) A tuple of tensor images and lists of annotations
85 |
86 | Return:
87 | A tuple containing:
88 | 1) (tensor) batch of images stacked on their 0 dim
89 | 2) (list of tensors) annotations for a given image are stacked on 0 dim
90 | """
91 | targets = []
92 | imgs = []
93 | for _, sample in enumerate(batch):
94 | for _, tup in enumerate(sample):
95 | if torch.is_tensor(tup):
96 | imgs.append(tup)
97 | elif isinstance(tup, type(np.empty(0))):
98 | annos = torch.from_numpy(tup).float()
99 | targets.append(annos)
100 |
101 | return (torch.stack(imgs, 0), targets)
102 |
--------------------------------------------------------------------------------
/gpeno/training/data_loader/dataset_face.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import glob
5 | import math
6 | import random
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset
10 |
11 | import degradations
12 |
13 |
14 | class GFPGAN_degradation(object):
15 | def __init__(self):
16 | self.kernel_list = ['iso', 'aniso']
17 | self.kernel_prob = [0.5, 0.5]
18 | self.blur_kernel_size = 41
19 | self.blur_sigma = [0.1, 10]
20 | self.downsample_range = [0.8, 8]
21 | self.noise_range = [0, 20]
22 | self.jpeg_range = [60, 100]
23 | self.gray_prob = 0.2
24 | self.color_jitter_prob = 0.0
25 | self.color_jitter_pt_prob = 0.0
26 | self.shift = 20/255.
27 |
28 | def degrade_process(self, img_gt):
29 | if random.random() > 0.5:
30 | img_gt = cv2.flip(img_gt, 1)
31 |
32 | h, w = img_gt.shape[:2]
33 |
34 | # random color jitter
35 | if np.random.uniform() < self.color_jitter_prob:
36 | jitter_val = np.random.uniform(-self.shift, self.shift, 3).astype(np.float32)
37 | img_gt = img_gt + jitter_val
38 | img_gt = np.clip(img_gt, 0, 1)
39 |
40 | # random grayscale
41 | if np.random.uniform() < self.gray_prob:
42 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
43 | img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
44 |
45 | # ------------------------ generate lq image ------------------------ #
46 | # blur
47 | kernel = degradations.random_mixed_kernels(
48 | self.kernel_list,
49 | self.kernel_prob,
50 | self.blur_kernel_size,
51 | self.blur_sigma,
52 | self.blur_sigma, [-math.pi, math.pi],
53 | noise_range=None)
54 | img_lq = cv2.filter2D(img_gt, -1, kernel)
55 | # downsample
56 | scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
57 | img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
58 |
59 | # noise
60 | if self.noise_range is not None:
61 | img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
62 | # jpeg compression
63 | if self.jpeg_range is not None:
64 | img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
65 |
66 | # round and clip
67 | img_lq = np.clip((img_lq * 255.0).round(), 0, 255) / 255.
68 |
69 | # resize to original size
70 | img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
71 |
72 | return img_gt, img_lq
73 |
74 | class FaceDataset(Dataset):
75 | def __init__(self, path, resolution=512):
76 | self.resolution = resolution
77 |
78 | self.HQ_imgs = glob.glob(os.path.join(path, '*.*'))
79 | self.length = len(self.HQ_imgs)
80 |
81 | self.degrader = GFPGAN_degradation()
82 |
83 | def __len__(self):
84 | return self.length
85 |
86 | def __getitem__(self, index):
87 | img_gt = cv2.imread(self.HQ_imgs[index], cv2.IMREAD_COLOR)
88 | img_gt = cv2.resize(img_gt, (self.resolution, self.resolution), interpolation=cv2.INTER_AREA)
89 |
90 | # BFR degradation
91 | # We adopt the degradation of GFPGAN for simplicity, which however differs from our implementation in the paper.
92 | # Data degradation plays a key role in BFR. Please replace it with your own methods.
93 | img_gt = img_gt.astype(np.float32)/255.
94 | img_gt, img_lq = self.degrader.degrade_process(img_gt)
95 |
96 | img_gt = (torch.from_numpy(img_gt) - 0.5) / 0.5
97 | img_lq = (torch.from_numpy(img_lq) - 0.5) / 0.5
98 |
99 | img_gt = img_gt.permute(2, 0, 1).flip(0) # BGR->RGB
100 | img_lq = img_lq.permute(2, 0, 1).flip(0) # BGR->RGB
101 |
102 | return img_lq, img_gt
103 |
104 |
--------------------------------------------------------------------------------
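Note: a minimal sketch of how the FaceDataset above might be driven during training. The folder path, batch size, and worker count are illustrative assumptions, not values taken from this repository's training script.

from torch.utils.data import DataLoader

dataset = FaceDataset('data/hq_faces', resolution=512)   # placeholder folder of aligned HQ face crops
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for img_lq, img_gt in loader:
    # both are [B, 3, 512, 512] RGB tensors normalized to [-1, 1]
    print(img_lq.shape, img_gt.shape)
    break
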
/gpeno/training/loss/helpers.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
4 |
5 | """
6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7 | """
8 |
9 |
10 | class Flatten(Module):
11 | def forward(self, input):
12 | return input.view(input.size(0), -1)
13 |
14 |
15 | def l2_norm(input, axis=1):
16 | norm = torch.norm(input, 2, axis, True)
17 | output = torch.div(input, norm)
18 | return output
19 |
20 |
21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
22 | """ A named tuple describing a ResNet block. """
23 |
24 |
25 | def get_block(in_channel, depth, num_units, stride=2):
26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
27 |
28 |
29 | def get_blocks(num_layers):
30 | if num_layers == 50:
31 | blocks = [
32 | get_block(in_channel=64, depth=64, num_units=3),
33 | get_block(in_channel=64, depth=128, num_units=4),
34 | get_block(in_channel=128, depth=256, num_units=14),
35 | get_block(in_channel=256, depth=512, num_units=3)
36 | ]
37 | elif num_layers == 100:
38 | blocks = [
39 | get_block(in_channel=64, depth=64, num_units=3),
40 | get_block(in_channel=64, depth=128, num_units=13),
41 | get_block(in_channel=128, depth=256, num_units=30),
42 | get_block(in_channel=256, depth=512, num_units=3)
43 | ]
44 | elif num_layers == 152:
45 | blocks = [
46 | get_block(in_channel=64, depth=64, num_units=3),
47 | get_block(in_channel=64, depth=128, num_units=8),
48 | get_block(in_channel=128, depth=256, num_units=36),
49 | get_block(in_channel=256, depth=512, num_units=3)
50 | ]
51 | else:
52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
53 | return blocks
54 |
55 |
56 | class SEModule(Module):
57 | def __init__(self, channels, reduction):
58 | super(SEModule, self).__init__()
59 | self.avg_pool = AdaptiveAvgPool2d(1)
60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
61 | self.relu = ReLU(inplace=True)
62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
63 | self.sigmoid = Sigmoid()
64 |
65 | def forward(self, x):
66 | module_input = x
67 | x = self.avg_pool(x)
68 | x = self.fc1(x)
69 | x = self.relu(x)
70 | x = self.fc2(x)
71 | x = self.sigmoid(x)
72 | return module_input * x
73 |
74 |
75 | class bottleneck_IR(Module):
76 | def __init__(self, in_channel, depth, stride):
77 | super(bottleneck_IR, self).__init__()
78 | if in_channel == depth:
79 | self.shortcut_layer = MaxPool2d(1, stride)
80 | else:
81 | self.shortcut_layer = Sequential(
82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
83 | BatchNorm2d(depth)
84 | )
85 | self.res_layer = Sequential(
86 | BatchNorm2d(in_channel),
87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
89 | )
90 |
91 | def forward(self, x):
92 | shortcut = self.shortcut_layer(x)
93 | res = self.res_layer(x)
94 | return res + shortcut
95 |
96 |
97 | class bottleneck_IR_SE(Module):
98 | def __init__(self, in_channel, depth, stride):
99 | super(bottleneck_IR_SE, self).__init__()
100 | if in_channel == depth:
101 | self.shortcut_layer = MaxPool2d(1, stride)
102 | else:
103 | self.shortcut_layer = Sequential(
104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
105 | BatchNorm2d(depth)
106 | )
107 | self.res_layer = Sequential(
108 | BatchNorm2d(in_channel),
109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
110 | PReLU(depth),
111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
112 | BatchNorm2d(depth),
113 | SEModule(depth, 16)
114 | )
115 |
116 | def forward(self, x):
117 | shortcut = self.shortcut_layer(x)
118 | res = self.res_layer(x)
119 | return res + shortcut
120 |
--------------------------------------------------------------------------------
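For context, these helpers are composed elsewhere (model_irse.py) into an IR-SE backbone. A hedged sketch of that composition using the classes above; the names here are illustrative and not copied from model_irse.py.

from torch.nn import Sequential

blocks = get_blocks(50)                      # four stages of Bottleneck specs
modules = []
for stage in blocks:
    for b in stage:                          # b.in_channel, b.depth, b.stride
        modules.append(bottleneck_IR_SE(b.in_channel, b.depth, b.stride))
trunk = Sequential(*modules)                 # 3 + 4 + 14 + 3 = 24 bottlenecks for IR-SE-50
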
/gpeno/face_parse/blocks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.parameter import Parameter
5 | from torch.nn import functional as F
6 | import numpy as np
7 |
8 | class NormLayer(nn.Module):
9 | """Normalization Layers.
10 | ------------
11 | # Arguments
12 | - channels: input channels, for batch norm and instance norm.
13 | - input_size: input shape without batch size, for layer norm.
14 | """
15 | def __init__(self, channels, normalize_shape=None, norm_type='bn', ref_channels=None):
16 | super(NormLayer, self).__init__()
17 | norm_type = norm_type.lower()
18 | self.norm_type = norm_type
19 | if norm_type == 'bn':
20 | self.norm = nn.BatchNorm2d(channels, affine=True)
21 | elif norm_type == 'in':
22 | self.norm = nn.InstanceNorm2d(channels, affine=False)
23 | elif norm_type == 'gn':
24 | self.norm = nn.GroupNorm(32, channels, affine=True)
25 | elif norm_type == 'pixel':
26 | self.norm = lambda x: F.normalize(x, p=2, dim=1)
27 | elif norm_type == 'layer':
28 | self.norm = nn.LayerNorm(normalize_shape)
29 | elif norm_type == 'none':
30 | self.norm = lambda x: x*1.0
31 | else:
32 |             raise ValueError('Norm type {} not supported.'.format(norm_type))
33 |
34 | def forward(self, x, ref=None):
35 | if self.norm_type == 'spade':
36 | return self.norm(x, ref)
37 | else:
38 | return self.norm(x)
39 |
40 |
41 | class ReluLayer(nn.Module):
42 | """Relu Layer.
43 | ------------
44 | # Arguments
45 | - relu type: type of relu layer, candidates are
46 | - ReLU
47 | - LeakyReLU: default relu slope 0.2
48 | - PRelu
49 | - SELU
50 | - none: direct pass
51 | """
52 | def __init__(self, channels, relu_type='relu'):
53 | super(ReluLayer, self).__init__()
54 | relu_type = relu_type.lower()
55 | if relu_type == 'relu':
56 | self.func = nn.ReLU(True)
57 | elif relu_type == 'leakyrelu':
58 | self.func = nn.LeakyReLU(0.2, inplace=True)
59 | elif relu_type == 'prelu':
60 | self.func = nn.PReLU(channels)
61 | elif relu_type == 'selu':
62 | self.func = nn.SELU(True)
63 | elif relu_type == 'none':
64 | self.func = lambda x: x*1.0
65 | else:
66 |             raise ValueError('Relu type {} not supported.'.format(relu_type))
67 |
68 | def forward(self, x):
69 | return self.func(x)
70 |
71 |
72 | class ConvLayer(nn.Module):
73 | def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True, bias=True):
74 | super(ConvLayer, self).__init__()
75 | self.use_pad = use_pad
76 | self.norm_type = norm_type
77 | if norm_type in ['bn']:
78 | bias = False
79 |
80 | stride = 2 if scale == 'down' else 1
81 |
82 | self.scale_func = lambda x: x
83 | if scale == 'up':
84 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
85 |
86 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.)/2)))
87 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
88 |
89 | self.relu = ReluLayer(out_channels, relu_type)
90 | self.norm = NormLayer(out_channels, norm_type=norm_type)
91 |
92 | def forward(self, x):
93 | out = self.scale_func(x)
94 | if self.use_pad:
95 | out = self.reflection_pad(out)
96 | out = self.conv2d(out)
97 | out = self.norm(out)
98 | out = self.relu(out)
99 | return out
100 |
101 |
102 | class ResidualBlock(nn.Module):
103 | """
104 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
105 | """
106 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
107 | super(ResidualBlock, self).__init__()
108 |
109 | if scale == 'none' and c_in == c_out:
110 | self.shortcut_func = lambda x: x
111 | else:
112 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
113 |
114 | scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
115 | scale_conf = scale_config_dict[scale]
116 |
117 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
118 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
119 |
120 | def forward(self, x):
121 | identity = self.shortcut_func(x)
122 |
123 | res = self.conv1(x)
124 | res = self.conv2(res)
125 | return identity + res
126 |
127 |
128 |
--------------------------------------------------------------------------------
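A quick, hedged shape check for these building blocks (sizes are illustrative), using the ResidualBlock defined above: a 'down' block halves the spatial resolution, an 'up' block doubles it, and the shortcut keeps the residual sum well-formed.

import torch

x = torch.randn(1, 32, 64, 64)
down = ResidualBlock(32, 64, relu_type='prelu', norm_type='bn', scale='down')
up = ResidualBlock(64, 32, relu_type='prelu', norm_type='bn', scale='up')

y = down(x)   # [1, 64, 32, 32]
z = up(y)     # [1, 32, 64, 64]
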
/gpeno/face_detect/facemodels/retinaface.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models.detection.backbone_utils as backbone_utils
4 | import torchvision.models._utils as _utils
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 |
8 | from facemodels.net import MobileNetV1 as MobileNetV1
9 | from facemodels.net import FPN as FPN
10 | from facemodels.net import SSH as SSH
11 |
12 |
13 | class ClassHead(nn.Module):
14 |
15 | def __init__(self, inchannels=512, num_anchors=3):
16 | super(ClassHead, self).__init__()
17 | self.num_anchors = num_anchors
18 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
19 |
20 | def forward(self, x):
21 | out = self.conv1x1(x)
22 | out = out.permute(0, 2, 3, 1).contiguous()
23 |
24 | return out.view(out.shape[0], -1, 2)
25 |
26 |
27 | class BboxHead(nn.Module):
28 |
29 | def __init__(self, inchannels=512, num_anchors=3):
30 | super(BboxHead, self).__init__()
31 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
32 |
33 | def forward(self, x):
34 | out = self.conv1x1(x)
35 | out = out.permute(0, 2, 3, 1).contiguous()
36 |
37 | return out.view(out.shape[0], -1, 4)
38 |
39 |
40 | class LandmarkHead(nn.Module):
41 |
42 | def __init__(self, inchannels=512, num_anchors=3):
43 | super(LandmarkHead, self).__init__()
44 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
45 |
46 | def forward(self, x):
47 | out = self.conv1x1(x)
48 | out = out.permute(0, 2, 3, 1).contiguous()
49 |
50 | return out.view(out.shape[0], -1, 10)
51 |
52 |
53 | class RetinaFace(nn.Module):
54 |
55 | def __init__(self, cfg=None, phase='train'):
56 | """
57 | :param cfg: Network related settings.
58 | :param phase: train or test.
59 | """
60 | super(RetinaFace, self).__init__()
61 | self.phase = phase
62 | backbone = None
63 |
64 | if cfg['name'] == 'detection_mobilenet0.25_Final':
65 | backbone = MobileNetV1()
66 | if cfg['pretrain']:
67 | checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
68 | from collections import OrderedDict
69 | new_state_dict = OrderedDict()
70 | for k, v in checkpoint['state_dict'].items():
71 | name = k[7:] # remove module.
72 | new_state_dict[name] = v
73 | # load params
74 | backbone.load_state_dict(new_state_dict)
75 | elif cfg['name'] == 'detection_Resnet50_Final':
76 | import torchvision.models as models
77 | backbone = models.resnet50(pretrained=cfg['pretrain'])
78 |
79 | self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
80 | in_channels_stage2 = cfg['in_channel']
81 | in_channels_list = [
82 | in_channels_stage2 * 2,
83 | in_channels_stage2 * 4,
84 | in_channels_stage2 * 8,
85 | ]
86 | out_channels = cfg['out_channel']
87 | self.fpn = FPN(in_channels_list, out_channels)
88 | self.ssh1 = SSH(out_channels, out_channels)
89 | self.ssh2 = SSH(out_channels, out_channels)
90 | self.ssh3 = SSH(out_channels, out_channels)
91 |
92 | self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
93 | self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
94 | self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
95 |
96 | def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
97 | classhead = nn.ModuleList()
98 | for i in range(fpn_num):
99 | classhead.append(ClassHead(inchannels, anchor_num))
100 | return classhead
101 |
102 | def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
103 | bboxhead = nn.ModuleList()
104 | for i in range(fpn_num):
105 | bboxhead.append(BboxHead(inchannels, anchor_num))
106 | return bboxhead
107 |
108 | def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
109 | landmarkhead = nn.ModuleList()
110 | for i in range(fpn_num):
111 | landmarkhead.append(LandmarkHead(inchannels, anchor_num))
112 | return landmarkhead
113 |
114 | def forward(self, inputs):
115 | out = self.body(inputs)
116 |
117 | # FPN
118 | fpn = self.fpn(out)
119 |
120 | # SSH
121 | feature1 = self.ssh1(fpn[0])
122 | feature2 = self.ssh2(fpn[1])
123 | feature3 = self.ssh3(fpn[2])
124 | features = [feature1, feature2, feature3]
125 |
126 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
127 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
128 | ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
129 |
130 | if self.phase == 'train':
131 | output = (bbox_regressions, classifications, ldm_regressions)
132 | else:
133 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
134 | return output
135 |
--------------------------------------------------------------------------------
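Each head reshapes a level's map from [B, anchors*K, H, W] to [B, H*W*anchors, K], and the three levels are concatenated so there is exactly one row per prior box. A hedged back-of-the-envelope count; the feature sizes assume a 640x640 input and FPN strides of 8/16/32, which is the usual RetinaFace setup rather than anything read from this repo.

levels = [(80, 80), (40, 40), (20, 20)]   # H, W per FPN level for a 640x640 input
anchors_per_loc = 2                       # anchor_num default in the _make_*_head helpers
total_priors = sum(h * w * anchors_per_loc for h, w in levels)   # 16800
# bbox_regressions: [B, 16800, 4], classifications: [B, 16800, 2], ldm_regressions: [B, 16800, 10]
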
/gpeno/misc/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from pathlib import Path
4 | import glob
5 | import numpy as np
6 | import cv2
7 | from zipfile import ZipFile
8 | from PIL import Image
9 | import shutil
10 | import cog
11 | from face_enhancement import FaceEnhancement
12 | from face_colorization import FaceColorization
13 | from face_inpainting import FaceInpainting, brush_stroke_mask
14 |
15 |
16 | class Predictor(cog.Predictor):
17 | def setup(self):
18 | faceenhancer_model = {'name': 'GPEN-BFR-256', 'size': 256, 'channel_multiplier': 1, 'narrow': 0.5}
19 | self.faceenhancer = FaceEnhancement(size=faceenhancer_model['size'], model=faceenhancer_model['name'],
20 | channel_multiplier=faceenhancer_model['channel_multiplier'],
21 | narrow=faceenhancer_model['narrow'])
22 | faceinpainter_model = {'name': 'GPEN-Inpainting-1024', 'size': 1024}
23 | self.faceinpainter = FaceInpainting(size=faceinpainter_model['size'], model=faceinpainter_model['name'],
24 | channel_multiplier=2)
25 | facecolorizer_model = {'name': 'GPEN-Colorization-1024', 'size': 1024}
26 | self.facecolorizer = FaceColorization(size=facecolorizer_model['size'], model=facecolorizer_model['name'],
27 | channel_multiplier=2)
28 |
29 | @cog.input(
30 | "image",
31 | type=Path,
32 | help="input image",
33 | )
34 | @cog.input(
35 | "task",
36 | type=str,
37 | options=['Face Restoration', 'Face Colorization', 'Face Inpainting'],
38 | default='Face Restoration',
39 | help="choose task type"
40 | )
41 | @cog.input(
42 | "output_individual",
43 | type=bool,
44 | default=False,
45 |         help="whether to output individual enhanced faces; only valid for Face Restoration. When set to true, a zip archive "
46 |              "containing all enhanced faces found in the input will be generated for download."
47 | )
48 | @cog.input(
49 | "broken_image",
50 | type=bool,
51 | default=True,
52 |         help="whether the input image is already broken; only valid for Face Inpainting. When set to True, the output is the "
53 |              "'fixed' image. When set to False, random brush strokes are added to simulate a broken image, "
54 |              "and the output is the broken and fixed images side by side"
55 | )
56 | def predict(self, image, task='Face Restoration', output_individual=False, broken_image=True):
57 | out_path = Path(tempfile.mkdtemp()) / "out.png"
58 | if task == 'Face Restoration':
59 | im = cv2.imread(str(image), cv2.IMREAD_COLOR) # BGR
60 | assert isinstance(im, np.ndarray), 'input filename error'
61 | im = cv2.resize(im, (0, 0), fx=2, fy=2)
62 | img, orig_faces, enhanced_faces = self.faceenhancer.process(im)
63 | cv2.imwrite(str(out_path), img)
64 | if output_individual:
65 | zip_folder = 'out_zip'
66 | os.makedirs(zip_folder, exist_ok=True)
67 | out_path = Path(tempfile.mkdtemp()) / "out.zip"
68 | try:
69 | cv2.imwrite(os.path.join(zip_folder, 'whole_image.jpg'), img)
70 | for m, ef in enumerate(enhanced_faces):
71 | cv2.imwrite(os.path.join(zip_folder, f'face_{m}.jpg'), ef)
72 | img_list = sorted(glob.glob(os.path.join(zip_folder, '*')))
73 | with ZipFile(str(out_path), 'w') as zipfile:
74 |                     for img_file in img_list:
75 |                         zipfile.write(img_file)
76 | finally:
77 | clean_folder(zip_folder)
78 | elif task == 'Face Colorization':
79 | grayf = cv2.imread(str(image), cv2.IMREAD_GRAYSCALE)
80 | grayf = cv2.cvtColor(grayf, cv2.COLOR_GRAY2BGR) # channel: 1->3
81 | colorf = self.facecolorizer.process(grayf)
82 | cv2.imwrite(str(out_path), colorf)
83 | else:
84 | originf = cv2.imread(str(image), cv2.IMREAD_COLOR)
85 | brokenf = originf
86 | if not broken_image:
87 | brokenf = np.asarray(brush_stroke_mask(Image.fromarray(originf)))
88 | completef = self.faceinpainter.process(brokenf)
89 |             brokenf = cv2.resize(brokenf, completef.shape[:2][::-1])  # cv2.resize expects (width, height)
90 | out_img = completef if broken_image else np.hstack((brokenf, completef))
91 | cv2.imwrite(str(out_path), out_img)
92 |
93 | return out_path
94 |
95 |
96 | def clean_folder(folder):
97 | for filename in os.listdir(folder):
98 | file_path = os.path.join(folder, filename)
99 | try:
100 | if os.path.isfile(file_path) or os.path.islink(file_path):
101 | os.unlink(file_path)
102 | elif os.path.isdir(file_path):
103 | shutil.rmtree(file_path)
104 | except Exception as e:
105 | print('Failed to delete %s. Reason: %s' % (file_path, e))
--------------------------------------------------------------------------------
/gpeno/face_detect/retinaface_detection.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import os
6 | import torch
7 | import numpy as np
8 | from data import cfg_re50, cfg_mnet
9 | from layers.functions.prior_box import PriorBox
10 | from gpeno.face_detect.utils.nms.py_cpu_nms import py_cpu_nms
11 | import cv2
12 | from facemodels.retinaface import RetinaFace
13 | from gpeno.face_detect.utils.box_utils import decode, decode_landm
14 | import torch.nn.functional as F
15 |
16 |
17 | class RetinaFaceDetection(object):
18 |
19 | def __init__(self, base_dir, device='cuda', network='RetinaFace-R50'):
20 | torch.set_grad_enabled(False)
21 | print(f"Initializing RetinaFaceDetection on device {device}...")
22 | self.pretrained_path = os.path.join(base_dir, 'facedetection', network + '.pth')
23 | self.device = device
24 | if network == "detection_Resnet50_Final":
25 | self.cfg = cfg_re50
26 | else:
27 | self.cfg = cfg_mnet
28 | self.net = RetinaFace(cfg=self.cfg, phase='test')
29 | self.net = self.net.to(self.device)
30 |
31 | self.load_model()
32 |
33 | self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)
34 |
35 | def check_keys(self, pretrained_state_dict):
36 | ckpt_keys = set(pretrained_state_dict.keys())
37 | model_keys = set(self.net.state_dict().keys())
38 | used_pretrained_keys = model_keys & ckpt_keys
39 | unused_pretrained_keys = ckpt_keys - model_keys
40 | missing_keys = model_keys - ckpt_keys
41 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
42 | return True
43 |
44 | def remove_prefix(self, state_dict, prefix):
45 |         ''' Old-style checkpoints store all parameter names with a shared 'module.' prefix '''
46 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
47 | return {f(key): value for key, value in state_dict.items()}
48 |
49 | def load_model(self):
50 | pretrained_dict = torch.load(self.pretrained_path, map_location=self.device)
51 |
52 | if "state_dict" in pretrained_dict.keys():
53 | pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.')
54 | else:
55 | pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
56 | # self.check_keys(pretrained_dict)
57 | self.net.load_state_dict(pretrained_dict, strict=False)
58 | self.net.eval()
59 |
60 | print("Finished loading model")
61 |
62 | def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
63 | img = np.float32(img_raw)
64 |
65 | im_height, im_width = img.shape[:2]
66 | ss = 1.0
67 | # tricky
68 | if max(im_height, im_width) > 1500:
69 | ss = 1000.0 / max(im_height, im_width)
70 | img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
71 | im_height, im_width = img.shape[:2]
72 |
73 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
74 | img -= (104, 117, 123)
75 | img = img.transpose(2, 0, 1)
76 | img = torch.from_numpy(img).unsqueeze(0)
77 | img = img.to(self.device)
78 | scale = scale.to(self.device)
79 |
80 | loc, conf, landms = self.net(img) # forward pass
81 |
82 | del img
83 |
84 | priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
85 | priors = priorbox.forward()
86 | priors = priors.to(self.device)
87 | prior_data = priors.data
88 | boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
89 | boxes = boxes * scale / resize
90 | boxes = boxes.cpu().numpy()
91 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
92 | landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
93 | scale1 = torch.Tensor([im_width, im_height, im_width, im_height, im_width, im_height, im_width, im_height, im_width, im_height])
94 | scale1 = scale1.to(self.device)
95 | landms = landms * scale1 / resize
96 | landms = landms.cpu().numpy()
97 |
98 | # ignore low scores
99 | inds = np.where(scores > confidence_threshold)[0]
100 | boxes = boxes[inds]
101 | landms = landms[inds]
102 | scores = scores[inds]
103 |
104 | # keep top-K before NMS
105 | order = scores.argsort()[::-1][:top_k]
106 | boxes = boxes[order]
107 | landms = landms[order]
108 | scores = scores[order]
109 |
110 | # do NMS
111 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
112 | keep = py_cpu_nms(dets, nms_threshold)
113 | # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
114 | dets = dets[keep, :]
115 | landms = landms[keep]
116 |
117 |         # keep top-K after NMS
118 | dets = dets[:keep_top_k, :]
119 | landms = landms[:keep_top_k, :]
120 |
121 | # sort faces(delete)
122 | '''
123 | fscores = [det[4] for det in dets]
124 | sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
125 | tmp = [landms[idx] for idx in sorted_idx]
126 | landms = np.asarray(tmp)
127 | '''
128 |
129 | landms = landms.reshape((-1, 5, 2))
130 | landms = landms.transpose((0, 2, 1))
131 | landms = landms.reshape(
132 | -1,
133 | 10,
134 | )
135 | return dets / ss, landms / ss
136 |
--------------------------------------------------------------------------------
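detect() follows the standard chain: decode boxes and landmarks against the priors, drop low scores, keep the top-K, then greedy NMS. For intuition, a self-contained toy version of the same IoU suppression that py_cpu_nms performs; the boxes, scores, and the 0.4 threshold here are made up.

import numpy as np

dets = np.array([[10, 10, 60, 60, 0.95],      # x1, y1, x2, y2, score
                 [12, 12, 62, 62, 0.90],      # heavily overlaps the first box
                 [100, 100, 160, 160, 0.80]])

x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
    i = order[0]
    keep.append(i)
    xx1 = np.maximum(x1[i], x1[order[1:]])
    yy1 = np.maximum(y1[i], y1[order[1:]])
    xx2 = np.minimum(x2[i], x2[order[1:]])
    yy2 = np.minimum(y2[i], y2[order[1:]])
    inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
    iou = inter / (areas[i] + areas[order[1:]] - inter)
    order = order[1:][iou <= 0.4]

# keep -> [0, 2]: the second box is suppressed by the first
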
/gpeno/face_detect/facemodels/net.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models._utils as _utils
5 | import torchvision.models as models
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 | def conv_bn(inp, oup, stride = 1, leaky = 0):
10 | return nn.Sequential(
11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
12 | nn.BatchNorm2d(oup),
13 | nn.LeakyReLU(negative_slope=leaky, inplace=True)
14 | )
15 |
16 | def conv_bn_no_relu(inp, oup, stride):
17 | return nn.Sequential(
18 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
19 | nn.BatchNorm2d(oup),
20 | )
21 |
22 | def conv_bn1X1(inp, oup, stride, leaky=0):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
25 | nn.BatchNorm2d(oup),
26 | nn.LeakyReLU(negative_slope=leaky, inplace=True)
27 | )
28 |
29 | def conv_dw(inp, oup, stride, leaky=0.1):
30 | return nn.Sequential(
31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
32 | nn.BatchNorm2d(inp),
33 | nn.LeakyReLU(negative_slope= leaky,inplace=True),
34 |
35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36 | nn.BatchNorm2d(oup),
37 | nn.LeakyReLU(negative_slope= leaky,inplace=True),
38 | )
39 |
40 | class SSH(nn.Module):
41 | def __init__(self, in_channel, out_channel):
42 | super(SSH, self).__init__()
43 | assert out_channel % 4 == 0
44 | leaky = 0
45 | if (out_channel <= 64):
46 | leaky = 0.1
47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
48 |
49 | self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
50 | self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
51 |
52 | self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
53 | self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
54 |
55 | def forward(self, input):
56 | conv3X3 = self.conv3X3(input)
57 |
58 | conv5X5_1 = self.conv5X5_1(input)
59 | conv5X5 = self.conv5X5_2(conv5X5_1)
60 |
61 | conv7X7_2 = self.conv7X7_2(conv5X5_1)
62 | conv7X7 = self.conv7x7_3(conv7X7_2)
63 |
64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
65 | out = F.relu(out)
66 | return out
67 |
68 | class FPN(nn.Module):
69 | def __init__(self,in_channels_list,out_channels):
70 | super(FPN,self).__init__()
71 | leaky = 0
72 | if (out_channels <= 64):
73 | leaky = 0.1
74 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
75 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
76 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
77 |
78 | self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
79 | self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
80 |
81 | def forward(self, input):
82 | # names = list(input.keys())
83 | input = list(input.values())
84 |
85 | output1 = self.output1(input[0])
86 | output2 = self.output2(input[1])
87 | output3 = self.output3(input[2])
88 |
89 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
90 | output2 = output2 + up3
91 | output2 = self.merge2(output2)
92 |
93 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
94 | output1 = output1 + up2
95 | output1 = self.merge1(output1)
96 |
97 | out = [output1, output2, output3]
98 | return out
99 |
100 |
101 |
102 | class MobileNetV1(nn.Module):
103 | def __init__(self):
104 | super(MobileNetV1, self).__init__()
105 | self.stage1 = nn.Sequential(
106 | conv_bn(3, 8, 2, leaky = 0.1), # 3
107 | conv_dw(8, 16, 1), # 7
108 | conv_dw(16, 32, 2), # 11
109 | conv_dw(32, 32, 1), # 19
110 | conv_dw(32, 64, 2), # 27
111 | conv_dw(64, 64, 1), # 43
112 | )
113 | self.stage2 = nn.Sequential(
114 | conv_dw(64, 128, 2), # 43 + 16 = 59
115 | conv_dw(128, 128, 1), # 59 + 32 = 91
116 | conv_dw(128, 128, 1), # 91 + 32 = 123
117 | conv_dw(128, 128, 1), # 123 + 32 = 155
118 | conv_dw(128, 128, 1), # 155 + 32 = 187
119 | conv_dw(128, 128, 1), # 187 + 32 = 219
120 | )
121 | self.stage3 = nn.Sequential(
122 |             conv_dw(128, 256, 2), # 219 + 32 = 251
123 |             conv_dw(256, 256, 1), # 251 + 64 = 315
124 | )
125 | self.avg = nn.AdaptiveAvgPool2d((1,1))
126 | self.fc = nn.Linear(256, 1000)
127 |
128 | def forward(self, x):
129 | x = self.stage1(x)
130 | x = self.stage2(x)
131 | x = self.stage3(x)
132 | x = self.avg(x)
133 | # x = self.model(x)
134 | x = x.view(-1, 256)
135 | x = self.fc(x)
136 | return x
137 |
138 |
--------------------------------------------------------------------------------
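For orientation, the three stages of the MobileNetV1 above downsample by 8x, 16x, and 32x overall, which is what the FPN consumes. A hedged shape check with an illustrative 640x640 input, bypassing the classifier head:

import torch

net = MobileNetV1()
x = torch.randn(1, 3, 640, 640)
c1 = net.stage1(x)    # [1, 64, 80, 80]   stride 8
c2 = net.stage2(c1)   # [1, 128, 40, 40]  stride 16
c3 = net.stage3(c2)   # [1, 256, 20, 20]  stride 32
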
/gpeno/face_detect/layers/modules/multibox_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from gpeno.face_detect.utils.box_utils import match, log_sum_exp
6 | from data import cfg_mnet
7 |
8 | GPU = cfg_mnet['gpu_train']
9 |
10 |
11 | class MultiBoxLoss(nn.Module):
12 | """SSD Weighted Loss Function
13 | Compute Targets:
14 | 1) Produce Confidence Target Indices by matching ground truth boxes
15 | with (default) 'priorboxes' that have jaccard index > threshold parameter
16 | (default threshold: 0.5).
17 | 2) Produce localization target by 'encoding' variance into offsets of ground
18 | truth boxes and their matched 'priorboxes'.
19 | 3) Hard negative mining to filter the excessive number of negative examples
20 | that comes with using a large number of default bounding boxes.
21 | (default negative:positive ratio 3:1)
22 | Objective Loss:
23 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
24 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
25 | weighted by α which is set to 1 by cross val.
26 | Args:
27 | c: class confidences,
28 | l: predicted boxes,
29 | g: ground truth boxes
30 | N: number of matched default boxes
31 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
32 | """
33 |
34 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
35 | super(MultiBoxLoss, self).__init__()
36 | self.num_classes = num_classes
37 | self.threshold = overlap_thresh
38 | self.background_label = bkg_label
39 | self.encode_target = encode_target
40 | self.use_prior_for_matching = prior_for_matching
41 | self.do_neg_mining = neg_mining
42 | self.negpos_ratio = neg_pos
43 | self.neg_overlap = neg_overlap
44 | self.variance = [0.1, 0.2]
45 |
46 | def forward(self, predictions, priors, targets):
47 | """Multibox Loss
48 | Args:
49 | predictions (tuple): A tuple containing loc preds, conf preds,
50 | and prior boxes from SSD net.
51 | conf shape: torch.size(batch_size,num_priors,num_classes)
52 | loc shape: torch.size(batch_size,num_priors,4)
53 | priors shape: torch.size(num_priors,4)
54 |
55 | ground_truth (tensor): Ground truth boxes and labels for a batch,
56 | shape: [batch_size,num_objs,5] (last idx is the label).
57 | """
58 |
59 | loc_data, conf_data, landm_data = predictions
60 | priors = priors
61 | num = loc_data.size(0)
62 | num_priors = (priors.size(0))
63 |
64 | # match priors (default boxes) and ground truth boxes
65 | loc_t = torch.Tensor(num, num_priors, 4)
66 | landm_t = torch.Tensor(num, num_priors, 10)
67 | conf_t = torch.LongTensor(num, num_priors)
68 | for idx in range(num):
69 | truths = targets[idx][:, :4].data
70 | labels = targets[idx][:, -1].data
71 | landms = targets[idx][:, 4:14].data
72 | defaults = priors.data
73 | match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
74 | if GPU:
75 | loc_t = loc_t.cuda()
76 | conf_t = conf_t.cuda()
77 | landm_t = landm_t.cuda()
78 |
79 |         zeros = torch.tensor(0).cuda() if GPU else torch.tensor(0)
80 | # landm Loss (Smooth L1)
81 | # Shape: [batch,num_priors,10]
82 | pos1 = conf_t > zeros
83 | num_pos_landm = pos1.long().sum(1, keepdim=True)
84 | N1 = max(num_pos_landm.data.sum().float(), 1)
85 | pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
86 | landm_p = landm_data[pos_idx1].view(-1, 10)
87 | landm_t = landm_t[pos_idx1].view(-1, 10)
88 | loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
89 |
90 | pos = conf_t != zeros
91 | conf_t[pos] = 1
92 |
93 | # Localization Loss (Smooth L1)
94 | # Shape: [batch,num_priors,4]
95 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
96 | loc_p = loc_data[pos_idx].view(-1, 4)
97 | loc_t = loc_t[pos_idx].view(-1, 4)
98 | loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
99 |
100 | # Compute max conf across batch for hard negative mining
101 | batch_conf = conf_data.view(-1, self.num_classes)
102 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
103 |
104 | # Hard Negative Mining
105 | loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
106 | loss_c = loss_c.view(num, -1)
107 | _, loss_idx = loss_c.sort(1, descending=True)
108 | _, idx_rank = loss_idx.sort(1)
109 | num_pos = pos.long().sum(1, keepdim=True)
110 | num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
111 | neg = idx_rank < num_neg.expand_as(idx_rank)
112 |
113 | # Confidence Loss Including Positive and Negative Examples
114 | pos_idx = pos.unsqueeze(2).expand_as(conf_data)
115 | neg_idx = neg.unsqueeze(2).expand_as(conf_data)
116 | conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
117 | targets_weighted = conf_t[(pos + neg).gt(0)]
118 | loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
119 |
120 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
121 | N = max(num_pos.data.sum().float(), 1)
122 | loss_l /= N
123 | loss_c /= N
124 | loss_landm /= N1
125 |
126 | return loss_l, loss_c, loss_landm
127 |
--------------------------------------------------------------------------------
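A hedged sketch of how the MultiBoxLoss above is typically constructed for face detection training. The argument values mirror the common RetinaFace reference setup (two classes, 0.35 matching threshold, 7:1 hard-negative mining); they are not read from this repository's training code.

criterion = MultiBoxLoss(
    num_classes=2,            # background / face
    overlap_thresh=0.35,
    prior_for_matching=True,
    bkg_label=0,
    neg_mining=True,
    neg_pos=7,                # negatives kept per positive during mining
    neg_overlap=0.35,
    encode_target=False,
)
# loss_l, loss_c, loss_landm = criterion(predictions, priors, targets)
# total = loc_weight * loss_l + loss_c + loss_landm   # loc_weight comes from the detector config
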
/gpeno/face_enhancement.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from gpeno.face_detect.retinaface_detection import RetinaFaceDetection
5 | # from gpeno.face_detect.batch_face import RetinaFace
6 | from gpeno.face_parse.face_parsing import FaceParse
7 | from gpeno.face_model.face_gan import FaceGAN
8 | # from gpen.sr_model.real_esrnet import RealESRNet
9 | from gpeno.align_faces import warp_and_crop_face, get_reference_facial_points
10 |
11 |
12 | class FaceEnhancement(object):
13 |
14 | def __init__(self, args, base_dir='./', in_size=512, out_size=None, model=None, use_sr=True, device='cuda', interp=3, backbone="RetinaFace-R50", log=None, colorize=False):
15 | self.log = log
16 | # self.log.debug("Initializing FaceEnhancement...")
17 | self.facedetector = RetinaFaceDetection(base_dir, device, network=backbone)
18 | # self.facedetector = RetinaFace()
19 | self.facegan = FaceGAN(base_dir, in_size, out_size, model, args.channel_multiplier, args.narrow, args.key, device=device)
20 | # self.srmodel = RealESRNet(base_dir, args.sr_model, args.sr_scale, args.tile_size, device=device)
21 | self.faceparser = FaceParse(base_dir, device=device)
22 | self.use_sr = use_sr
23 | self.in_size = in_size
24 | self.out_size = in_size if out_size is None else out_size
25 | self.threshold = 0.9
26 | self.alpha = args.alpha
27 | self.interp = interp
28 | self.colorize = colorize
29 |
30 | if self.colorize:
31 | self.colorizer = FaceGAN(base_dir, 1024, 1024, "GPEN-Colorization-1024", args.channel_multiplier, args.narrow, None, device)
32 |
33 | self.mask = np.zeros((512, 512), np.float32)
34 | cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA)
35 | self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4)
36 |
37 | self.kernel = np.array(([0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]), dtype="float32")
38 | self.reference_5pts = get_reference_facial_points((self.in_size, self.in_size), 0.25, (0, 0), True)
39 |
40 | def mask_postprocess(self, mask, thres=26):
41 | mask[:thres, :] = 0
42 | mask[-thres:, :] = 0
43 | mask[:, :thres] = 0
44 | mask[:, -thres:] = 0
45 | mask = cv2.GaussianBlur(mask, (101, 101), 4)
46 | return mask.astype(np.float32)
47 |
48 | def colorize_face(self, face):
49 | # Convert BGR (OpenCV) to RGB before colorizing
50 | rgb_input = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
51 | color_face = self.colorizer.process(rgb_input)
52 | color_face = cv2.cvtColor(color_face, cv2.COLOR_RGB2BGR)
53 |
54 | if face.shape[:2] != color_face.shape[:2]:
55 | out_rs = cv2.resize(color_face, face.shape[:2][::-1])
56 | gray_yuv = cv2.cvtColor(face, cv2.COLOR_BGR2YUV)
57 | out_yuv = cv2.cvtColor(out_rs, cv2.COLOR_BGR2YUV)
58 |
59 | out_yuv[:, :, 0] = gray_yuv[:, :, 0]
60 | color_face = cv2.cvtColor(out_yuv, cv2.COLOR_YUV2BGR)
61 |
62 | return color_face
63 |
64 | def process(self, img, aligned=False):
65 | orig_faces, enhanced_faces = [], []
66 | if aligned:
67 | print("Aligned is true")
68 | ef = self.facegan.process(img)
69 | if self.colorize:
70 | ef = self.colorize_face(ef)
71 | orig_faces.append(img)
72 | enhanced_faces.append(ef)
73 |
74 | # if self.use_sr:
75 | # ef = self.srmodel.process(ef)
76 |
77 | return ef, orig_faces, enhanced_faces
78 |
79 | # if self.use_sr:
80 | # img_sr = self.srmodel.process(img)
81 | # if img_sr is not None:
82 | # img = cv2.resize(img, img_sr.shape[:2][::-1])
83 |
84 | with torch.no_grad():
85 | print("Starting face detection")
86 | facebs, landms = self.facedetector.detect(img)
87 | # faces = self.facedetector.detect(img)
88 |
89 | # self.log.debug("Face detection complete")
90 | print("Face detection complete")
91 |
92 | height, width = img.shape[:2]
93 | full_mask = np.zeros((height, width), dtype=np.float32)
94 | full_img = np.zeros(img.shape, dtype=np.uint8)
95 |
96 | # for i, (faceb, facial5points, score) in enumerate(faces):
97 | for i, (faceb, facial5points) in enumerate(zip(facebs, landms)):
98 | if faceb[4] < self.threshold:
99 | continue
100 | fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0])
101 |
102 | facial5points = np.reshape(facial5points, (2, 5))
103 |
104 | print("Starting face alignment...")
105 | of, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=self.reference_5pts, crop_size=(self.in_size, self.in_size))
106 |
107 | print("Face alignment complete")
108 |
109 | # enhance the face
110 | ef = self.facegan.process(of)
111 |
112 | if self.colorize:
113 | ef = self.colorize_face(ef)
114 |
115 | # self.log.debug("Face GAN complete")
116 | print("Face GAN complete")
117 |
118 | orig_faces.append(of)
119 | enhanced_faces.append(ef)
120 |
121 |                 # build a soft blending mask from the parsed face regions
122 | tmp_mask = self.mask_postprocess(self.faceparser.process(ef)[0] / 255.)
123 | # self.log.debug("Mask postprocessing complete")
124 | tmp_mask = cv2.resize(tmp_mask, (self.in_size, self.in_size), interpolation=self.interp)
125 |
126 | tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=self.interp)
127 |
128 | # from PIL import Image
129 |
130 | # tmp_mask_pil = Image.fromarray(tmp_mask)
131 | # Apply the inverse affine transformation using PIL
132 | # tmp_mask_pil = tmp_mask_pil.transform((width, height), Image.AFFINE, tuple(tfm_inv.flatten()), resample=self.interp)
133 | # tmp_mask = np.array(tmp_mask_pil)
134 |
135 | if min(fh, fw) < 100: # gaussian filter for small faces
136 | ef = cv2.filter2D(ef, -1, self.kernel)
137 |
138 | ef = cv2.addWeighted(ef, self.alpha, of, 1. - self.alpha, 0.0)
139 |
140 | if self.in_size != self.out_size:
141 | ef = cv2.resize(ef, (self.in_size, self.in_size), interpolation=self.interp)
142 | tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=self.interp)
143 |
144 | mask = tmp_mask - full_mask
145 | full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)]
146 | full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)]
147 |
148 | full_mask = full_mask[:, :, np.newaxis]
149 | # if self.use_sr and img_sr is not None:
150 | # img = cv2.convertScaleAbs(img_sr * (1 - full_mask) + full_img * full_mask)
151 | # else:
152 | img = cv2.convertScaleAbs(img * (1 - full_mask) + full_img * full_mask)
153 |
154 | print("Postprocessing complete")
155 |
156 | return img, orig_faces, enhanced_faces
157 |
--------------------------------------------------------------------------------
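The compositing at the end of process() is a per-pixel alpha blend: each restored face is warped back with its blurred parsing mask, and the frame is combined once as img * (1 - full_mask) + full_img * full_mask. A one-pixel numeric restatement of that blend (values illustrative):

import numpy as np

img_px, enhanced_px, m = np.float32(120), np.float32(180), np.float32(0.75)
out_px = img_px * (1 - m) + enhanced_px * m   # 165.0; cv2.convertScaleAbs then rounds/clips to uint8
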
/gpeno/misc/onnx_export.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import sys
4 |
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 | from torch.autograd import Variable
9 |
10 | import __init_paths
11 | #from face_model import model
12 | from face_model.gpen_model import FullGenerator
13 |
14 | def model_load(model, path):
15 | """Load model."""
16 |
17 | if not os.path.exists(path):
18 | print("Model '{}' does not exist.".format(path))
19 | return
20 |
21 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
22 |
23 | model.load_state_dict(state_dict)
24 |
25 |
26 | def export_onnx(model, path, force_cpu):
27 | """Export onnx model."""
28 |
29 | import onnx
30 | import onnxruntime
31 | #from onnx import optimizer
32 | import numpy as np
33 |
34 | onnx_file_name = os.path.join(path, model+".onnx")
35 | model_weight_file = os.path.join(path, model+".pth")
36 | dummy_input = Variable(torch.randn(1, 3, 1024, 1024))
37 |
38 | # 1. Create and load model.
39 | model_setenv(force_cpu)
40 | torch_model = get_model(model_weight_file)
41 | torch_model.eval()
42 |
43 | # 2. Model export
44 | print("Export model ...")
45 |
46 | input_names = ["input"]
47 | output_names = ["output"]
48 | device = model_device()
49 | # torch.onnx.export(torch_model, dummy_input.to(device), onnx_file_name,
50 | # input_names=input_names,
51 | # output_names=output_names,
52 | # verbose=False,
53 | # opset_version=12,
54 | # keep_initializers_as_inputs=False,
55 | # export_params=True,
56 | # operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
57 | torch.onnx.export(torch_model, dummy_input.to(device), onnx_file_name,
58 | input_names=input_names,
59 | output_names=output_names,
60 | verbose=False,
61 | opset_version=10,
62 | operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
63 |
64 | # 3. Optimize model
65 | print('Checking model ...')
66 | onnx_model = onnx.load(onnx_file_name)
67 | onnx.checker.check_model(onnx_model)
68 | # https://github.com/onnx/optimizer
69 | print('Done checking model ...')
70 | # 4. Visual model
71 | # python -c "import netron; netron.start('output/image_zoom.onnx')"
72 |
73 | def verify_onnx(model, path, force_cpu):
74 | """Verify onnx model."""
75 |
76 | import onnxruntime
77 | import numpy as np
78 |
79 |
80 | model_weight_file = os.path.join(path, model+".pth")
81 |
82 | model_weight_file = "./weights/GPEN-512.pth"
83 |
84 | model_setenv(force_cpu)
85 | torch_model = get_model(model_weight_file)
86 | torch_model.eval()
87 |
88 | onnx_file_name = os.path.join(path, model+".onnx")
89 | onnxruntime_engine = onnxruntime.InferenceSession(onnx_file_name)
90 |
91 | def to_numpy(tensor):
92 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
93 |
94 | dummy_input = Variable(torch.randn(1, 3, 512, 512))
95 | with torch.no_grad():
96 | torch_output, _ = torch_model(dummy_input)
97 | onnxruntime_inputs = {onnxruntime_engine.get_inputs()[0].name: to_numpy(dummy_input)}
98 | onnxruntime_outputs = onnxruntime_engine.run(None, onnxruntime_inputs)
99 | np.testing.assert_allclose(to_numpy(torch_output), onnxruntime_outputs[0], rtol=1e-02, atol=1e-02)
100 |     print("Exported ONNX model verified against ONNXRuntime; outputs match within tolerance.")
101 |
102 | def get_model(checkpoint):
103 | """Create encoder model."""
104 |
105 | #model_setenv()
106 | model = FullGenerator(1024, 512, 8, 2, narrow=1) #TODO
107 | model_load(model, checkpoint)
108 | device = model_device()
109 | model.to(device)
110 | return model
111 |
112 |
113 | def model_device():
114 | """Please call after model_setenv. """
115 |
116 | return torch.device(os.environ["DEVICE"])
117 |
118 |
119 | def model_setenv(cpu_only):
120 | """Setup environ ..."""
121 |
122 | # random init ...
123 | import random
124 | random.seed(42)
125 | torch.manual_seed(42)
126 |
127 | # Set default device to avoid exceptions
128 | if cpu_only:
129 | os.environ["DEVICE"] = 'cpu'
130 | else:
131 | if os.environ.get("DEVICE") != "cuda" and os.environ.get("DEVICE") != "cpu":
132 | os.environ["DEVICE"] = 'cuda' if torch.cuda.is_available() else 'cpu'
133 | if os.environ["DEVICE"] == 'cuda':
134 | torch.backends.cudnn.enabled = True
135 | torch.backends.cudnn.benchmark = True
136 |
137 | print("Running Environment:")
138 | print("----------------------------------------------")
139 | #print(" PWD: ", os.environ["PWD"])
140 | print(" DEVICE: ", os.environ["DEVICE"])
141 |
142 | # def export_torch(model, path):
143 | # """Export torch model."""
144 |
145 | # script_file = os.path.join(path, model+".pt")
146 | # weight_file = os.path.join(path, model+".onnx")
147 |
148 | # # 1. Load model
149 | # print("Loading model ...")
150 | # model = get_model(weight_file)
151 | # model.eval()
152 |
153 | # # 2. Model export
154 | # print("Export model ...")
155 | # dummy_input = Variable(torch.randn(1, 3, 512, 512))
156 | # device = model_device()
157 | # traced_script_module = torch.jit.trace(model, dummy_input.to(device), _force_outplace=True)
158 | # traced_script_module.save(script_file)
159 |
160 |
161 | if __name__ == '__main__':
162 | """Test model ..."""
163 | import argparse
164 |
165 | parser = argparse.ArgumentParser()
166 | parser.add_argument('--model', type=str, required=True)
167 | parser.add_argument('--path', type=str, default='./')
168 | parser.add_argument('--export', help="Export onnx model", action='store_true')
169 | parser.add_argument('--verify', help="Verify onnx model", action='store_true')
170 | parser.add_argument('--force-cpu', dest='force_cpu', help="Verify onnx model", action='store_true')
171 |
172 | args = parser.parse_args()
173 |
174 | # export_torch()
175 |
176 |
177 |
178 | if args.export:
179 | export_onnx(model = args.model, path = args.path, force_cpu=args.force_cpu)
180 |
181 | if args.verify:
182 | verify_onnx(model = args.model, path = args.path, force_cpu=args.force_cpu)
183 |
--------------------------------------------------------------------------------
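Once a model is exported, it can be sanity-checked with onnxruntime alone, mirroring what verify_onnx does without the PyTorch side. A hedged sketch; the file name and the 512x512 input size are placeholders.

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("GPEN-BFR-512.onnx")
dummy = np.random.randn(1, 3, 512, 512).astype(np.float32)
outputs = sess.run(None, {sess.get_inputs()[0].name: dummy})
print([o.shape for o in outputs])
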
/gpeno/face_model/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load, _import_module_from_library
8 |
9 | # the CUDA extension below is built only on Linux with CUDA available; other setups fall back to upfirdn2d_native
10 | if platform.system() == 'Linux' and torch.cuda.is_available():
11 | module_path = os.path.dirname(__file__)
12 | upfirdn2d_op = load(
13 | 'upfirdn2d',
14 | sources=[
15 | os.path.join(module_path, 'upfirdn2d.cpp'),
16 | os.path.join(module_path, 'upfirdn2d_kernel.cu'),
17 | ],
18 | )
19 |
20 |
21 | #upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True)
22 |
23 | class UpFirDn2dBackward(Function):
24 | @staticmethod
25 | def forward(
26 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
27 | ):
28 |
29 | up_x, up_y = up
30 | down_x, down_y = down
31 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
32 |
33 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
34 |
35 | grad_input = upfirdn2d_op.upfirdn2d(
36 | grad_output,
37 | grad_kernel,
38 | down_x,
39 | down_y,
40 | up_x,
41 | up_y,
42 | g_pad_x0,
43 | g_pad_x1,
44 | g_pad_y0,
45 | g_pad_y1,
46 | )
47 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
48 |
49 | ctx.save_for_backward(kernel)
50 |
51 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
52 |
53 | ctx.up_x = up_x
54 | ctx.up_y = up_y
55 | ctx.down_x = down_x
56 | ctx.down_y = down_y
57 | ctx.pad_x0 = pad_x0
58 | ctx.pad_x1 = pad_x1
59 | ctx.pad_y0 = pad_y0
60 | ctx.pad_y1 = pad_y1
61 | ctx.in_size = in_size
62 | ctx.out_size = out_size
63 |
64 | return grad_input
65 |
66 | @staticmethod
67 | def backward(ctx, gradgrad_input):
68 | kernel, = ctx.saved_tensors
69 |
70 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
71 |
72 | gradgrad_out = upfirdn2d_op.upfirdn2d(
73 | gradgrad_input,
74 | kernel,
75 | ctx.up_x,
76 | ctx.up_y,
77 | ctx.down_x,
78 | ctx.down_y,
79 | ctx.pad_x0,
80 | ctx.pad_x1,
81 | ctx.pad_y0,
82 | ctx.pad_y1,
83 | )
84 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
85 | gradgrad_out = gradgrad_out.view(
86 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
87 | )
88 |
89 | return gradgrad_out, None, None, None, None, None, None, None, None
90 |
91 |
92 | class UpFirDn2d(Function):
93 | @staticmethod
94 | def forward(ctx, input, kernel, up, down, pad):
95 | up_x, up_y = up
96 | down_x, down_y = down
97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
98 |
99 | kernel_h, kernel_w = kernel.shape
100 | batch, channel, in_h, in_w = input.shape
101 | ctx.in_size = input.shape
102 |
103 | input = input.reshape(-1, in_h, in_w, 1)
104 |
105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
106 |
107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
109 | ctx.out_size = (out_h, out_w)
110 |
111 | ctx.up = (up_x, up_y)
112 | ctx.down = (down_x, down_y)
113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
114 |
115 | g_pad_x0 = kernel_w - pad_x0 - 1
116 | g_pad_y0 = kernel_h - pad_y0 - 1
117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
119 |
120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
121 |
122 | out = upfirdn2d_op.upfirdn2d(
123 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
124 | )
125 | # out = out.view(major, out_h, out_w, minor)
126 | out = out.view(-1, channel, out_h, out_w)
127 |
128 | return out
129 |
130 | @staticmethod
131 | def backward(ctx, grad_output):
132 | kernel, grad_kernel = ctx.saved_tensors
133 |
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), device='cpu'):
150 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu':
151 | out = UpFirDn2d.apply(
152 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
153 | )
154 | else:
155 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
156 |
157 | return out
158 |
159 |
160 | def upfirdn2d_native(
161 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
162 | ):
163 | input = input.permute(0, 2, 3, 1)
164 | _, in_h, in_w, minor = input.shape
165 | kernel_h, kernel_w = kernel.shape
166 | out = input.view(-1, in_h, 1, in_w, 1, minor)
167 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
168 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
169 |
170 | out = F.pad(
171 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
172 | )
173 | out = out[
174 | :,
175 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
176 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
177 | :,
178 | ]
179 |
180 | out = out.permute(0, 3, 1, 2)
181 | out = out.reshape(
182 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
183 | )
184 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
185 | out = F.conv2d(out, w)
186 | out = out.reshape(
187 | -1,
188 | minor,
189 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
190 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
191 | )
192 | # out = out.permute(0, 2, 3, 1)
193 | return out[:, :, ::down_y, ::down_x]
194 |
195 |
--------------------------------------------------------------------------------
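Because upfirdn2d_native is pure PyTorch, it doubles as a readable reference for what the CUDA kernel computes: zero-insertion upsampling, padding, FIR filtering with the flipped kernel, then decimation. A small hedged example that forces the CPU path through the upfirdn2d wrapper above (kernel and sizes are illustrative):

import torch

x = torch.randn(1, 3, 8, 8)
k = torch.ones(2, 2) / 4.0                          # simple 2x2 averaging FIR kernel
y = upfirdn2d(x, k, up=2, down=1, pad=(1, 0), device='cpu')
print(y.shape)                                      # torch.Size([1, 3, 16, 16])
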
/gpeno/demo.py:
--------------------------------------------------------------------------------
1 | '''
2 | @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3 | @author: yangxy (yangtao9009@gmail.com)
4 | '''
5 | import os
6 | import cv2
7 | import glob
8 | import time
9 | import math
10 | import argparse
11 | import numpy as np
12 | from PIL import Image, ImageDraw
13 | import __init_paths
14 | from face_enhancement import FaceEnhancement
15 | from face_colorization import FaceColorization
16 | from face_inpainting import FaceInpainting
17 | from segmentation2face import Segmentation2Face
18 |
19 | def brush_stroke_mask(img, color=(255,255,255)):
20 | min_num_vertex = 8
21 | max_num_vertex = 28
22 | mean_angle = 2*math.pi / 5
23 | angle_range = 2*math.pi / 15
24 | min_width = 12
25 | max_width = 80
26 | def generate_mask(H, W, img=None):
27 | average_radius = math.sqrt(H*H+W*W) / 8
28 | mask = Image.new('RGB', (W, H), 0)
29 | if img is not None: mask = img #Image.fromarray(img)
30 |
31 | for _ in range(np.random.randint(1, 4)):
32 | num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
33 | angle_min = mean_angle - np.random.uniform(0, angle_range)
34 | angle_max = mean_angle + np.random.uniform(0, angle_range)
35 | angles = []
36 | vertex = []
37 | for i in range(num_vertex):
38 | if i % 2 == 0:
39 | angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
40 | else:
41 | angles.append(np.random.uniform(angle_min, angle_max))
42 |
43 | h, w = mask.size
44 | vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
45 | for i in range(num_vertex):
46 | r = np.clip(
47 | np.random.normal(loc=average_radius, scale=average_radius//2),
48 | 0, 2*average_radius)
49 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
50 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
51 | vertex.append((int(new_x), int(new_y)))
52 |
53 | draw = ImageDraw.Draw(mask)
54 | width = int(np.random.uniform(min_width, max_width))
55 | draw.line(vertex, fill=color, width=width)
56 | for v in vertex:
57 | draw.ellipse((v[0] - width//2,
58 | v[1] - width//2,
59 | v[0] + width//2,
60 | v[1] + width//2),
61 | fill=color)
62 |
63 | return mask
64 |
65 | width, height = img.size
66 | mask = generate_mask(height, width, img)
67 | return mask
68 |
69 | if __name__=='__main__':
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--model', type=str, default='GPEN-BFR-512', help='GPEN model')
72 | parser.add_argument('--task', type=str, default='FaceEnhancement', help='task of GPEN model')
73 | parser.add_argument('--key', type=str, default=None, help='key of GPEN model')
74 | parser.add_argument('--in_size', type=int, default=512, help='in resolution of GPEN')
75 | parser.add_argument('--out_size', type=int, default=None, help='out resolution of GPEN')
76 | parser.add_argument('--channel_multiplier', type=int, default=2, help='channel multiplier of GPEN')
77 | parser.add_argument('--narrow', type=float, default=1, help='channel narrow scale')
78 | parser.add_argument('--alpha', type=float, default=1, help='blending the results')
79 | parser.add_argument('--use_sr', action='store_true', help='use sr or not')
80 | parser.add_argument('--use_cuda', action='store_true', help='use cuda or not')
81 | parser.add_argument('--save_face', action='store_true', help='save face or not')
82 | parser.add_argument('--aligned', action='store_true', help='input are aligned faces or not')
83 | parser.add_argument('--sr_model', type=str, default='realesrnet', help='SR model')
84 | parser.add_argument('--sr_scale', type=int, default=2, help='SR scale')
85 | parser.add_argument('--tile_size', type=int, default=0, help='tile size for SR to avoid OOM')
86 | parser.add_argument('--indir', type=str, default='examples/imgs', help='input folder')
87 | parser.add_argument('--outdir', type=str, default='results/outs-BFR', help='output folder')
88 | parser.add_argument('--ext', type=str, default='.jpg', help='extension of output')
89 | args = parser.parse_args()
90 |
91 | #model = {'name':'GPEN-BFR-512', 'size':512, 'channel_multiplier':2, 'narrow':1}
92 | #model = {'name':'GPEN-BFR-256', 'size':256, 'channel_multiplier':1, 'narrow':0.5}
93 |
94 | os.makedirs(args.outdir, exist_ok=True)
95 |
96 | if args.task == 'FaceEnhancement':
97 | processer = FaceEnhancement(args, in_size=args.in_size, model=args.model, use_sr=args.use_sr, device='cuda' if args.use_cuda else 'cpu')
98 | elif args.task == 'FaceColorization':
99 | processer = FaceColorization(in_size=args.in_size, model=args.model, device='cuda' if args.use_cuda else 'cpu')
100 | elif args.task == 'FaceInpainting':
101 | processer = FaceInpainting(in_size=args.in_size, model=args.model, device='cuda' if args.use_cuda else 'cpu')
102 | elif args.task == 'Segmentation2Face':
103 | processer = Segmentation2Face(in_size=args.in_size, model=args.model, is_norm=False, device='cuda' if args.use_cuda else 'cpu')
104 |
105 |
106 |     files = sorted(glob.glob(os.path.join(args.indir, '*.*g')))  # matches extensions ending in 'g', e.g. .jpg/.jpeg/.png
107 | for n, file in enumerate(files[:]):
108 | filename, ext = os.path.splitext(os.path.basename(file))
109 |
110 | img = cv2.imread(file, cv2.IMREAD_COLOR) # BGR
111 | if not isinstance(img, np.ndarray): print(filename, 'error'); continue
112 | #img = cv2.resize(img, (0,0), fx=2, fy=2) # optional
113 |
114 | if args.task == 'FaceInpainting':
115 | img = np.asarray(brush_stroke_mask(Image.fromarray(img)))
116 |
117 | img_out, orig_faces, enhanced_faces = processer.process(img, aligned=args.aligned)
118 |
119 | img = cv2.resize(img, img_out.shape[:2][::-1])
120 | cv2.imwrite(f'{args.outdir}/{filename}_COMP{args.ext}', np.hstack((img, img_out)))
121 | cv2.imwrite(f'{args.outdir}/{filename}_GPEN{args.ext}', img_out)
122 |
123 | if args.save_face:
124 | for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)):
125 |                 of = cv2.resize(of, ef.shape[:2][::-1])  # cv2.resize expects (width, height)
126 | cv2.imwrite(f'{args.outdir}/{filename}_face{m:02d}{args.ext}', np.hstack((of, ef)))
127 |
128 | if n%10==0: print(n, filename)
129 |
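A minimal, hedged sketch of using the `brush_stroke_mask` helper above on its own (for preparing FaceInpainting inputs). It assumes it runs in the context of this script (or imports the helper from it); the file names are placeholders:

```python
import cv2
import numpy as np
from PIL import Image

img = cv2.imread('face.png', cv2.IMREAD_COLOR)  # hypothetical input image (BGR), as in the loop above

# Draw random white brush strokes onto the image; the inpainting model later fills these regions.
masked = np.asarray(brush_stroke_mask(Image.fromarray(img)))

cv2.imwrite('face_masked.png', masked)
```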
--------------------------------------------------------------------------------
/gpeno/training/lpips/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import torch
8 | # from torch.autograd import Variable
9 |
10 | from lpips.trainer import *
11 | from lpips.lpips import *
12 |
13 | # class PerceptualLoss(torch.nn.Module):
14 | # def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
15 | # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16 | # super(PerceptualLoss, self).__init__()
17 | # print('Setting up Perceptual loss...')
18 | # self.use_gpu = use_gpu
19 | # self.spatial = spatial
20 | # self.gpu_ids = gpu_ids
21 | # self.model = dist_model.DistModel()
22 | # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
23 | # print('...[%s] initialized'%self.model.name())
24 | # print('...Done')
25 |
26 | # def forward(self, pred, target, normalize=False):
27 | # """
28 | # Pred and target are Variables.
29 | # If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30 | # If normalize is False, assumes the images are already between [-1,+1]
31 |
32 | # Inputs pred and target are Nx3xHxW
33 | # Output pytorch Variable N long
34 | # """
35 |
36 | # if normalize:
37 | # target = 2 * target - 1
38 | # pred = 2 * pred - 1
39 |
40 | # return self.model.forward(target, pred)
41 |
42 | def normalize_tensor(in_feat,eps=1e-10):
43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44 | return in_feat/(norm_factor+eps)
45 |
46 | def l2(p0, p1, range=255.):
47 | return .5*np.mean((p0 / range - p1 / range)**2)
48 |
49 | def psnr(p0, p1, peak=255.):
50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51 |
52 | def dssim(p0, p1, range=255.):
53 |     from skimage.measure import compare_ssim  # moved to skimage.metrics.structural_similarity in newer scikit-image
54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
55 |
56 | def rgb2lab(in_img,mean_cent=False):
57 | from skimage import color
58 | img_lab = color.rgb2lab(in_img)
59 | if(mean_cent):
60 | img_lab[:,:,0] = img_lab[:,:,0]-50
61 | return img_lab
62 |
63 | def tensor2np(tensor_obj):
64 | # change dimension of a tensor object into a numpy array
65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
66 |
67 | def np2tensor(np_obj):
68 |     # change dimension of an np array into a tensor
69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
70 |
71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
72 | # image tensor to lab tensor
73 | from skimage import color
74 |
75 | img = tensor2im(image_tensor)
76 | img_lab = color.rgb2lab(img)
77 | if(mc_only):
78 | img_lab[:,:,0] = img_lab[:,:,0]-50
79 | if(to_norm and not mc_only):
80 | img_lab[:,:,0] = img_lab[:,:,0]-50
81 | img_lab = img_lab/100.
82 |
83 | return np2tensor(img_lab)
84 |
85 | def tensorlab2tensor(lab_tensor,return_inbnd=False):
86 | from skimage import color
87 | import warnings
88 | warnings.filterwarnings("ignore")
89 |
90 | lab = tensor2np(lab_tensor)*100.
91 | lab[:,:,0] = lab[:,:,0]+50
92 |
93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
94 | if(return_inbnd):
95 | # convert back to lab, see if we match
96 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
97 | mask = 1.*np.isclose(lab_back,lab,atol=2.)
98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
99 | return (im2tensor(rgb_back),mask)
100 | else:
101 | return im2tensor(rgb_back)
102 |
103 | def load_image(path):
104 | if(path[-3:] == 'dng'):
105 | import rawpy
106 | with rawpy.imread(path) as raw:
107 | img = raw.postprocess()
108 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'):
109 | import cv2
110 | return cv2.imread(path)[:,:,::-1]
111 | else:
112 |         import matplotlib.pyplot as plt; img = (255*plt.imread(path)[:,:,:3]).astype('uint8')  # matplotlib is only needed for this fallback
113 |
114 | return img
115 |
116 | def rgb2lab(input):  # note: overrides the earlier rgb2lab; this variant expects a 0-255 image
117 | from skimage import color
118 | return color.rgb2lab(input / 255.)
119 |
120 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
121 | image_numpy = image_tensor[0].cpu().float().numpy()
122 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
123 | return image_numpy.astype(imtype)
124 |
125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
126 | return torch.Tensor((image / factor - cent)
127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128 |
129 | def tensor2vec(vector_tensor):
130 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
131 |
132 |
133 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
134 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
135 | image_numpy = image_tensor[0].cpu().float().numpy()
136 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
137 | return image_numpy.astype(imtype)
138 |
139 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
140 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
141 | return torch.Tensor((image / factor - cent)
142 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
143 |
144 |
145 |
146 | def voc_ap(rec, prec, use_07_metric=False):
147 | """ ap = voc_ap(rec, prec, [use_07_metric])
148 | Compute VOC AP given precision and recall.
149 | If use_07_metric is true, uses the
150 | VOC 07 11 point method (default:False).
151 | """
152 | if use_07_metric:
153 | # 11 point metric
154 | ap = 0.
155 | for t in np.arange(0., 1.1, 0.1):
156 | if np.sum(rec >= t) == 0:
157 | p = 0
158 | else:
159 | p = np.max(prec[rec >= t])
160 | ap = ap + p / 11.
161 | else:
162 | # correct AP calculation
163 | # first append sentinel values at the end
164 | mrec = np.concatenate(([0.], rec, [1.]))
165 | mpre = np.concatenate(([0.], prec, [0.]))
166 |
167 | # compute the precision envelope
168 | for i in range(mpre.size - 1, 0, -1):
169 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
170 |
171 | # to calculate area under PR curve, look for points
172 | # where X axis (recall) changes value
173 | i = np.where(mrec[1:] != mrec[:-1])[0]
174 |
175 | # and sum (\Delta recall) * prec
176 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
177 | return ap
178 |
179 |
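A hedged sketch of the conversion and metric helpers defined above: `im2tensor` maps a uint8 HxWx3 image to a 1x3xHxW tensor in [-1, 1] (with the default `cent=1, factor=255/2`), `tensor2im` inverts it, and `psnr`/`l2` compare two uint8 images. The random arrays are placeholders:

```python
import numpy as np

# Two placeholder uint8 RGB images (e.g. a restored face and its reference).
a = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
b = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

t = im2tensor(a)        # shape (1, 3, 64, 64), values in [-1, 1]
a_back = tensor2im(t)   # round-trips back to a uint8 HxWx3 array

print(psnr(a, b))       # peak signal-to-noise ratio between the two images
print(l2(a, b))         # mean squared error after scaling to [0, 1]
```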
--------------------------------------------------------------------------------
/gpeno/face_detect/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | from gpeno.face_detect.utils.box_utils import matrix_iof
5 |
6 |
7 | def _crop(image, boxes, labels, landm, img_dim):
8 | height, width, _ = image.shape
9 | pad_image_flag = True
10 |
11 | for _ in range(250):
12 | """
13 | if random.uniform(0, 1) <= 0.2:
14 | scale = 1.0
15 | else:
16 | scale = random.uniform(0.3, 1.0)
17 | """
18 | PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
19 | scale = random.choice(PRE_SCALES)
20 | short_side = min(width, height)
21 | w = int(scale * short_side)
22 | h = w
23 |
24 | if width == w:
25 | l = 0
26 | else:
27 | l = random.randrange(width - w)
28 | if height == h:
29 | t = 0
30 | else:
31 | t = random.randrange(height - h)
32 | roi = np.array((l, t, l + w, t + h))
33 |
34 | value = matrix_iof(boxes, roi[np.newaxis])
35 | flag = (value >= 1)
36 | if not flag.any():
37 | continue
38 |
39 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2
40 | mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
41 | boxes_t = boxes[mask_a].copy()
42 | labels_t = labels[mask_a].copy()
43 | landms_t = landm[mask_a].copy()
44 | landms_t = landms_t.reshape([-1, 5, 2])
45 |
46 | if boxes_t.shape[0] == 0:
47 | continue
48 |
49 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
50 |
51 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
52 | boxes_t[:, :2] -= roi[:2]
53 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
54 | boxes_t[:, 2:] -= roi[:2]
55 |
56 | # landm
57 | landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
58 | landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
59 | landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
60 | landms_t = landms_t.reshape([-1, 10])
61 |
62 |         # keep boxes that remain valid after cropping (the original constraint of a face > 16 pixels at training scale was relaxed to > 0 here)
63 | b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
64 | b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
65 | mask_b = np.minimum(b_w_t, b_h_t) > 0.0
66 | boxes_t = boxes_t[mask_b]
67 | labels_t = labels_t[mask_b]
68 | landms_t = landms_t[mask_b]
69 |
70 | if boxes_t.shape[0] == 0:
71 | continue
72 |
73 | pad_image_flag = False
74 |
75 | return image_t, boxes_t, labels_t, landms_t, pad_image_flag
76 | return image, boxes, labels, landm, pad_image_flag
77 |
78 |
79 | def _distort(image):
80 |
81 | def _convert(image, alpha=1, beta=0):
82 | tmp = image.astype(float) * alpha + beta
83 | tmp[tmp < 0] = 0
84 | tmp[tmp > 255] = 255
85 | image[:] = tmp
86 |
87 | image = image.copy()
88 |
89 | if random.randrange(2):
90 |
91 | #brightness distortion
92 | if random.randrange(2):
93 | _convert(image, beta=random.uniform(-32, 32))
94 |
95 | #contrast distortion
96 | if random.randrange(2):
97 | _convert(image, alpha=random.uniform(0.5, 1.5))
98 |
99 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
100 |
101 | #saturation distortion
102 | if random.randrange(2):
103 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
104 |
105 | #hue distortion
106 | if random.randrange(2):
107 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
108 | tmp %= 180
109 | image[:, :, 0] = tmp
110 |
111 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
112 |
113 | else:
114 |
115 | #brightness distortion
116 | if random.randrange(2):
117 | _convert(image, beta=random.uniform(-32, 32))
118 |
119 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
120 |
121 | #saturation distortion
122 | if random.randrange(2):
123 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))
124 |
125 | #hue distortion
126 | if random.randrange(2):
127 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
128 | tmp %= 180
129 | image[:, :, 0] = tmp
130 |
131 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
132 |
133 | #contrast distortion
134 | if random.randrange(2):
135 | _convert(image, alpha=random.uniform(0.5, 1.5))
136 |
137 | return image
138 |
139 |
140 | def _expand(image, boxes, fill, p):
141 | if random.randrange(2):
142 | return image, boxes
143 |
144 | height, width, depth = image.shape
145 |
146 | scale = random.uniform(1, p)
147 | w = int(scale * width)
148 | h = int(scale * height)
149 |
150 | left = random.randint(0, w - width)
151 | top = random.randint(0, h - height)
152 |
153 | boxes_t = boxes.copy()
154 | boxes_t[:, :2] += (left, top)
155 | boxes_t[:, 2:] += (left, top)
156 | expand_image = np.empty((h, w, depth), dtype=image.dtype)
157 | expand_image[:, :] = fill
158 | expand_image[top:top + height, left:left + width] = image
159 | image = expand_image
160 |
161 | return image, boxes_t
162 |
163 |
164 | def _mirror(image, boxes, landms):
165 | _, width, _ = image.shape
166 | if random.randrange(2):
167 | image = image[:, ::-1]
168 | boxes = boxes.copy()
169 | boxes[:, 0::2] = width - boxes[:, 2::-2]
170 |
171 | # landm
172 | landms = landms.copy()
173 | landms = landms.reshape([-1, 5, 2])
174 | landms[:, :, 0] = width - landms[:, :, 0]
175 | tmp = landms[:, 1, :].copy()
176 | landms[:, 1, :] = landms[:, 0, :]
177 | landms[:, 0, :] = tmp
178 | tmp1 = landms[:, 4, :].copy()
179 | landms[:, 4, :] = landms[:, 3, :]
180 | landms[:, 3, :] = tmp1
181 | landms = landms.reshape([-1, 10])
182 |
183 | return image, boxes, landms
184 |
185 |
186 | def _pad_to_square(image, rgb_mean, pad_image_flag):
187 | if not pad_image_flag:
188 | return image
189 | height, width, _ = image.shape
190 | long_side = max(width, height)
191 | image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
192 | image_t[:, :] = rgb_mean
193 | image_t[0:0 + height, 0:0 + width] = image
194 | return image_t
195 |
196 |
197 | def _resize_subtract_mean(image, insize, rgb_mean):
198 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
199 | interp_method = interp_methods[random.randrange(5)]
200 | image = cv2.resize(image, (insize, insize), interpolation=interp_method)
201 | image = image.astype(np.float32)
202 | image -= rgb_mean
203 | return image.transpose(2, 0, 1)
204 |
205 |
206 | class preproc(object):
207 |
208 | def __init__(self, img_dim, rgb_means):
209 | self.img_dim = img_dim
210 | self.rgb_means = rgb_means
211 |
212 | def __call__(self, image, targets):
213 | assert targets.shape[0] > 0, "this image does not have gt"
214 |
215 | boxes = targets[:, :4].copy()
216 | labels = targets[:, -1].copy()
217 | landm = targets[:, 4:-1].copy()
218 |
219 | image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
220 | image_t = _distort(image_t)
221 | image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
222 | image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
223 | height, width, _ = image_t.shape
224 | image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
225 | boxes_t[:, 0::2] /= width
226 | boxes_t[:, 1::2] /= height
227 |
228 | landm_t[:, 0::2] /= width
229 | landm_t[:, 1::2] /= height
230 |
231 | labels_t = np.expand_dims(labels_t, 1)
232 | targets_t = np.hstack((boxes_t, landm_t, labels_t))
233 |
234 | return image_t, targets_t
235 |
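A hedged usage sketch of the `preproc` pipeline above, assuming the `gpeno` package is importable. Judging from the slicing in `__call__`, the targets array is (N, 15): 4 box coordinates, 10 landmark coordinates and 1 label per face; the image, box and mean values below are placeholders:

```python
import numpy as np

img_dim = 640
rgb_means = (104, 117, 123)   # placeholder BGR means, as commonly used for RetinaFace-style training
aug = preproc(img_dim, rgb_means)

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
# One face: [x1, y1, x2, y2, 5 * (lx, ly), label]
targets = np.array([[100., 80., 220., 230.,
                     130., 120., 190., 120., 160., 160., 140., 200., 180., 200.,
                     1.]], dtype=np.float32)

image_t, targets_t = aug(image, targets)
print(image_t.shape)    # (3, 640, 640), float32, mean-subtracted
print(targets_t.shape)  # (M, 15), with coordinates normalized to [0, 1]
```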
--------------------------------------------------------------------------------
/gpeno/training/lpips/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models as tv
4 |
5 | class squeezenet(torch.nn.Module):
6 | def __init__(self, requires_grad=False, pretrained=True):
7 | super(squeezenet, self).__init__()
8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
9 | self.slice1 = torch.nn.Sequential()
10 | self.slice2 = torch.nn.Sequential()
11 | self.slice3 = torch.nn.Sequential()
12 | self.slice4 = torch.nn.Sequential()
13 | self.slice5 = torch.nn.Sequential()
14 | self.slice6 = torch.nn.Sequential()
15 | self.slice7 = torch.nn.Sequential()
16 | self.N_slices = 7
17 | for x in range(2):
18 | self.slice1.add_module(str(x), pretrained_features[x])
19 | for x in range(2,5):
20 | self.slice2.add_module(str(x), pretrained_features[x])
21 | for x in range(5, 8):
22 | self.slice3.add_module(str(x), pretrained_features[x])
23 | for x in range(8, 10):
24 | self.slice4.add_module(str(x), pretrained_features[x])
25 | for x in range(10, 11):
26 | self.slice5.add_module(str(x), pretrained_features[x])
27 | for x in range(11, 12):
28 | self.slice6.add_module(str(x), pretrained_features[x])
29 | for x in range(12, 13):
30 | self.slice7.add_module(str(x), pretrained_features[x])
31 | if not requires_grad:
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def forward(self, X):
36 | h = self.slice1(X)
37 | h_relu1 = h
38 | h = self.slice2(h)
39 | h_relu2 = h
40 | h = self.slice3(h)
41 | h_relu3 = h
42 | h = self.slice4(h)
43 | h_relu4 = h
44 | h = self.slice5(h)
45 | h_relu5 = h
46 | h = self.slice6(h)
47 | h_relu6 = h
48 | h = self.slice7(h)
49 | h_relu7 = h
50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
52 |
53 | return out
54 |
55 |
56 | class alexnet(torch.nn.Module):
57 | def __init__(self, requires_grad=False, pretrained=True):
58 | super(alexnet, self).__init__()
59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
60 | self.slice1 = torch.nn.Sequential()
61 | self.slice2 = torch.nn.Sequential()
62 | self.slice3 = torch.nn.Sequential()
63 | self.slice4 = torch.nn.Sequential()
64 | self.slice5 = torch.nn.Sequential()
65 | self.N_slices = 5
66 | for x in range(2):
67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
68 | for x in range(2, 5):
69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
70 | for x in range(5, 8):
71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
72 | for x in range(8, 10):
73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
74 | for x in range(10, 12):
75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
76 | if not requires_grad:
77 | for param in self.parameters():
78 | param.requires_grad = False
79 |
80 | def forward(self, X):
81 | h = self.slice1(X)
82 | h_relu1 = h
83 | h = self.slice2(h)
84 | h_relu2 = h
85 | h = self.slice3(h)
86 | h_relu3 = h
87 | h = self.slice4(h)
88 | h_relu4 = h
89 | h = self.slice5(h)
90 | h_relu5 = h
91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
93 |
94 | return out
95 |
96 | class vgg16(torch.nn.Module):
97 | def __init__(self, requires_grad=False, pretrained=True):
98 | super(vgg16, self).__init__()
99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
100 | self.slice1 = torch.nn.Sequential()
101 | self.slice2 = torch.nn.Sequential()
102 | self.slice3 = torch.nn.Sequential()
103 | self.slice4 = torch.nn.Sequential()
104 | self.slice5 = torch.nn.Sequential()
105 | self.N_slices = 5
106 | for x in range(4):
107 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
108 | for x in range(4, 9):
109 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
110 | for x in range(9, 16):
111 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
112 | for x in range(16, 23):
113 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
114 | for x in range(23, 30):
115 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
116 | if not requires_grad:
117 | for param in self.parameters():
118 | param.requires_grad = False
119 |
120 | def forward(self, X):
121 | h = self.slice1(X)
122 | h_relu1_2 = h
123 | h = self.slice2(h)
124 | h_relu2_2 = h
125 | h = self.slice3(h)
126 | h_relu3_3 = h
127 | h = self.slice4(h)
128 | h_relu4_3 = h
129 | h = self.slice5(h)
130 | h_relu5_3 = h
131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
133 |
134 | return out
135 |
136 |
137 |
138 | class resnet(torch.nn.Module):
139 | def __init__(self, requires_grad=False, pretrained=True, num=18):
140 | super(resnet, self).__init__()
141 | if(num==18):
142 | self.net = tv.resnet18(pretrained=pretrained)
143 | elif(num==34):
144 | self.net = tv.resnet34(pretrained=pretrained)
145 | elif(num==50):
146 | self.net = tv.resnet50(pretrained=pretrained)
147 | elif(num==101):
148 | self.net = tv.resnet101(pretrained=pretrained)
149 | elif(num==152):
150 | self.net = tv.resnet152(pretrained=pretrained)
151 | self.N_slices = 5
152 |
153 | self.conv1 = self.net.conv1
154 | self.bn1 = self.net.bn1
155 | self.relu = self.net.relu
156 | self.maxpool = self.net.maxpool
157 | self.layer1 = self.net.layer1
158 | self.layer2 = self.net.layer2
159 | self.layer3 = self.net.layer3
160 | self.layer4 = self.net.layer4
161 |
162 | def forward(self, X):
163 | h = self.conv1(X)
164 | h = self.bn1(h)
165 | h = self.relu(h)
166 | h_relu1 = h
167 | h = self.maxpool(h)
168 | h = self.layer1(h)
169 | h_conv2 = h
170 | h = self.layer2(h)
171 | h_conv3 = h
172 | h = self.layer3(h)
173 | h_conv4 = h
174 | h = self.layer4(h)
175 | h_conv5 = h
176 |
177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
179 |
180 | return out
181 |
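A hedged sketch of how these wrappers are typically consumed: a forward pass returns a namedtuple with one intermediate activation per slice. The input tensor is a placeholder; `pretrained=True` downloads torchvision weights on first use:

```python
import torch

net = vgg16(requires_grad=False, pretrained=True)
x = torch.randn(1, 3, 64, 64)   # placeholder image batch

with torch.no_grad():
    feats = net(x)

print([f.shape for f in feats])  # five feature maps: relu1_2 ... relu5_3
```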
--------------------------------------------------------------------------------
/gpeno/align_faces.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Apr 24 15:43:29 2017
4 | @author: zhaoy
5 | """
6 | """
7 | @Modified by yangxy (yangtao9009@gmail.com)
8 | """
9 | import cv2
10 | import numpy as np
11 | from skimage import transform as trans
12 |
13 | # reference facial points, a list of coordinates (x,y)
14 | REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
15 |
16 | DEFAULT_CROP_SIZE = (96, 112)
17 |
18 |
19 | def _umeyama(src, dst, estimate_scale=True, scale=1.0):
20 | """Estimate N-D similarity transformation with or without scaling.
21 | Parameters
22 | ----------
23 | src : (M, N) array
24 | Source coordinates.
25 | dst : (M, N) array
26 | Destination coordinates.
27 | estimate_scale : bool
28 | Whether to estimate scaling factor.
29 | Returns
30 | -------
31 | T : (N + 1, N + 1)
32 | The homogeneous similarity transformation matrix. The matrix contains
33 |         NaN values only if the problem is not well-conditioned. This modified version also returns the estimated (or given) scale factor.
34 | References
35 | ----------
36 | .. [1] "Least-squares estimation of transformation parameters between two
37 | point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
38 | """
39 |
40 | num = src.shape[0]
41 | dim = src.shape[1]
42 |
43 | # Compute mean of src and dst.
44 | src_mean = src.mean(axis=0)
45 | dst_mean = dst.mean(axis=0)
46 |
47 | # Subtract mean from src and dst.
48 | src_demean = src - src_mean
49 | dst_demean = dst - dst_mean
50 |
51 | # Eq. (38).
52 | A = dst_demean.T @ src_demean / num
53 |
54 | # Eq. (39).
55 | d = np.ones((dim, ), dtype=np.double)
56 | if np.linalg.det(A) < 0:
57 | d[dim - 1] = -1
58 |
59 | T = np.eye(dim + 1, dtype=np.double)
60 |
61 | U, S, V = np.linalg.svd(A)
62 |
63 | # Eq. (40) and (43).
64 | rank = np.linalg.matrix_rank(A)
65 |     if rank == 0:
66 |         return np.nan * T, 1.0  # degenerate input; keep the (T, scale) return signature
67 | elif rank == dim - 1:
68 | if np.linalg.det(U) * np.linalg.det(V) > 0:
69 | T[:dim, :dim] = U @ V
70 | else:
71 | s = d[dim - 1]
72 | d[dim - 1] = -1
73 | T[:dim, :dim] = U @ np.diag(d) @ V
74 | d[dim - 1] = s
75 | else:
76 | T[:dim, :dim] = U @ np.diag(d) @ V
77 |
78 | if estimate_scale:
79 | # Eq. (41) and (42).
80 | scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
81 | else:
82 | scale = scale
83 |
84 | T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
85 | T[:dim, :dim] *= scale
86 |
87 | return T, scale
88 |
89 |
90 | class FaceWarpException(Exception):
91 |
92 | def __str__(self):
93 |         return 'In File {}:{}'.format(__file__, super().__str__())
94 |
95 |
96 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
97 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
98 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
99 |
100 | # 0) make the inner region a square
101 | if default_square:
102 | size_diff = max(tmp_crop_size) - tmp_crop_size
103 | tmp_5pts += size_diff / 2
104 | tmp_crop_size += size_diff
105 |
106 | if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
107 | print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
108 | return tmp_5pts
109 |
110 | if (inner_padding_factor == 0 and outer_padding == (0, 0)):
111 | if output_size is None:
112 | print('No paddings to do: return default reference points')
113 | return tmp_5pts
114 | else:
115 | raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
116 |
117 | # check output size
118 | if not (0 <= inner_padding_factor <= 1.0):
119 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
120 |
121 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
122 |         output_size = (tmp_crop_size *
123 |                        (1 + inner_padding_factor * 2)).astype(np.int32)
124 | output_size += np.array(outer_padding)
125 | print(' deduced from paddings, output_size = ', output_size)
126 |
127 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
128 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
129 | 'and outer_padding[1] < output_size[1])')
130 |
131 | # 1) pad the inner region according inner_padding_factor
132 | # print('---> STEP1: pad the inner region according inner_padding_factor')
133 | if inner_padding_factor > 0:
134 | size_diff = tmp_crop_size * inner_padding_factor * 2
135 | tmp_5pts += size_diff / 2
136 | tmp_crop_size += np.round(size_diff).astype(np.int32)
137 |
138 | # print(' crop_size = ', tmp_crop_size)
139 | # print(' reference_5pts = ', tmp_5pts)
140 |
141 | # 2) resize the padded inner region
142 | # print('---> STEP2: resize the padded inner region')
143 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
144 | # print(' crop_size = ', tmp_crop_size)
145 | # print(' size_bf_outer_pad = ', size_bf_outer_pad)
146 |
147 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
148 | raise FaceWarpException('Must have (output_size - outer_padding)'
149 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
150 |
151 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
152 | # print(' resize scale_factor = ', scale_factor)
153 | tmp_5pts = tmp_5pts * scale_factor
154 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
155 | # tmp_5pts = tmp_5pts + size_diff / 2
156 | tmp_crop_size = size_bf_outer_pad
157 | # print(' crop_size = ', tmp_crop_size)
158 | # print(' reference_5pts = ', tmp_5pts)
159 |
160 | # 3) add outer_padding to make output_size
161 | reference_5point = tmp_5pts + np.array(outer_padding)
162 | tmp_crop_size = output_size
163 | # print('---> STEP3: add outer_padding to make output_size')
164 | # print(' crop_size = ', tmp_crop_size)
165 | # print(' reference_5pts = ', tmp_5pts)
166 | #
167 | # print('===> end get_reference_facial_points\n')
168 |
169 | return reference_5point
170 |
171 |
172 | def get_affine_transform_matrix(src_pts, dst_pts):
173 | tfm = np.float32([[1, 0, 0], [0, 1, 0]])
174 | n_pts = src_pts.shape[0]
175 | ones = np.ones((n_pts, 1), src_pts.dtype)
176 | src_pts_ = np.hstack([src_pts, ones])
177 | dst_pts_ = np.hstack([dst_pts, ones])
178 |
179 |     A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)
180 |
181 | if rank == 3:
182 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
183 | elif rank == 2:
184 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
185 |
186 | return tfm
187 |
188 |
189 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="similarity"):  # similarity | cv2_affine | affine
190 | if reference_pts is None:
191 | if crop_size[0] == 96 and crop_size[1] == 112:
192 | reference_pts = REFERENCE_FACIAL_POINTS
193 | else:
194 | default_square = False
195 | inner_padding_factor = 0
196 | outer_padding = (0, 0)
197 | output_size = crop_size
198 |
199 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, default_square)
200 |
201 | ref_pts = np.float32(reference_pts)
202 | ref_pts_shp = ref_pts.shape
203 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
204 | raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
205 |
206 | if ref_pts_shp[0] == 2:
207 | ref_pts = ref_pts.T
208 |
209 | src_pts = np.float32(facial_pts)
210 | src_pts_shp = src_pts.shape
211 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
212 | raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
213 |
214 | if src_pts_shp[0] == 2:
215 | src_pts = src_pts.T
216 |
217 | if src_pts.shape != ref_pts.shape:
218 | raise FaceWarpException('facial_pts and reference_pts must have the same shape')
219 |
220 | if align_type == 'cv2_affine':
221 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
222 | tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
223 | elif align_type == 'affine':
224 | tfm = get_affine_transform_matrix(src_pts, ref_pts)
225 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
226 | else:
227 | params, scale = _umeyama(src_pts, ref_pts)
228 | tfm = params[:2, :]
229 |
230 | params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
231 | tfm_inv = params[:2, :]
232 |
233 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
234 |
235 | return face_img, tfm_inv
236 |
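A hedged sketch of the alignment entry point above: given five facial landmarks (two eyes, nose tip, two mouth corners) in the source image, `warp_and_crop_face` estimates a similarity transform against the reference points and returns the aligned crop plus the inverse transform for pasting a restored crop back. The image path, landmark values and padding settings are illustrative, not taken from the rest of this repository:

```python
import cv2
import numpy as np

img = cv2.imread('face.jpg')          # hypothetical input image (BGR)
facial5points = np.float32([          # placeholder landmarks: eyes, nose, mouth corners
    [182.0, 232.0], [262.0, 228.0],
    [224.0, 282.0], [192.0, 328.0], [256.0, 324.0]])

# Reference points for a 512x512 square crop; the padding values are illustrative.
reference = get_reference_facial_points(
    output_size=(512, 512), inner_padding_factor=0.25,
    outer_padding=(0, 0), default_square=True)

aligned, tfm_inv = warp_and_crop_face(img, facial5points, reference_pts=reference, crop_size=(512, 512))

# tfm_inv maps the (possibly restored) crop back into the original image frame.
back = cv2.warpAffine(aligned, tfm_inv, (img.shape[1], img.shape[0]))
```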
--------------------------------------------------------------------------------
/gpeno/training/lpips/lpips.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch.autograd import Variable
7 | import numpy as np
8 | from . import pretrained_networks as pn
9 | import torch.nn
10 |
11 | import lpips
12 |
13 |
14 | def spatial_average(in_tens, keepdim=True):
15 | return in_tens.mean([2, 3], keepdim=keepdim)
16 |
17 |
18 | def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W
19 | in_H, in_W = in_tens.shape[2], in_tens.shape[3]
20 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
21 |
22 |
23 | # Learned perceptual metric
24 | class LPIPS(nn.Module):
25 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
26 | # lpips - [True] means with linear calibration on top of base network
27 | # pretrained - [True] means load linear weights
28 |
29 | super(LPIPS, self).__init__()
30 | if (verbose):
31 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
32 |
33 | self.pnet_type = net
34 | self.pnet_tune = pnet_tune
35 | self.pnet_rand = pnet_rand
36 | self.spatial = spatial
37 | self.lpips = lpips # false means baseline of just averaging all layers
38 | self.version = version
39 | self.scaling_layer = ScalingLayer()
40 |
41 | if (self.pnet_type in ['vgg', 'vgg16']):
42 | net_type = pn.vgg16
43 | self.chns = [64, 128, 256, 512, 512]
44 | elif (self.pnet_type == 'alex'):
45 | net_type = pn.alexnet
46 | self.chns = [64, 192, 384, 256, 256]
47 | elif (self.pnet_type == 'squeeze'):
48 | net_type = pn.squeezenet
49 | self.chns = [64, 128, 256, 384, 384, 512, 512]
50 | self.L = len(self.chns)
51 |
52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
53 |
54 | if (lpips):
55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
60 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
61 | if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet
62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
64 | self.lins += [self.lin5, self.lin6]
65 | self.lins = nn.ModuleList(self.lins)
66 |
67 | if (pretrained):
68 | if (model_path is None):
69 | import inspect
70 | import os
71 |                 model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))  # calibration weights shipped in the weights/ folder next to this file
72 |
73 | if (verbose):
74 | print('Loading model from: %s' % model_path)
75 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
76 |
77 | if (eval_mode):
78 | self.eval()
79 |
80 | def forward(self, in0, in1, retPerLayer=False, normalize=False):
81 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
82 | in0 = 2 * in0 - 1
83 | in1 = 2 * in1 - 1
84 |
85 | # v0.0 - original release had a bug, where input was not scaled
86 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
87 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
88 | feats0, feats1, diffs = {}, {}, {}
89 |
90 | for kk in range(self.L):
91 | feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk])
92 | diffs[kk] = (feats0[kk] - feats1[kk])**2
93 |
94 | if (self.lpips):
95 | if (self.spatial):
96 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
97 | else:
98 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
99 | else:
100 | if (self.spatial):
101 | res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
102 | else:
103 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
104 |
105 | val = res[0]
106 | for l in range(1, self.L):
107 | val += res[l]
108 |
109 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
110 | # b = torch.max(self.lins[kk](feats0[kk]**2))
111 | # for kk in range(self.L):
112 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
113 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
114 | # a = a/self.L
115 | # from IPython import embed
116 | # embed()
117 | # return 10*torch.log10(b/a)
118 |
119 | if (retPerLayer):
120 | return (val, res)
121 | else:
122 | return val
123 |
124 |
125 | class ScalingLayer(nn.Module):
126 | def __init__(self):
127 | super(ScalingLayer, self).__init__()
128 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
129 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
130 |
131 | def forward(self, inp):
132 | return (inp - self.shift) / self.scale
133 |
134 |
135 | class NetLinLayer(nn.Module):
136 | ''' A single linear layer which does a 1x1 conv '''
137 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
138 | super(NetLinLayer, self).__init__()
139 |
140 | layers = [
141 | nn.Dropout(),
142 | ] if (use_dropout) else []
143 | layers += [
144 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
145 | ]
146 | self.model = nn.Sequential(*layers)
147 |
148 | def forward(self, x):
149 | return self.model(x)
150 |
151 |
152 | class Dist2LogitLayer(nn.Module):
153 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
154 | def __init__(self, chn_mid=32, use_sigmoid=True):
155 | super(Dist2LogitLayer, self).__init__()
156 |
157 | layers = [
158 | nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
159 | ]
160 | layers += [
161 | nn.LeakyReLU(0.2, True),
162 | ]
163 | layers += [
164 | nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
165 | ]
166 | layers += [
167 | nn.LeakyReLU(0.2, True),
168 | ]
169 | layers += [
170 | nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
171 | ]
172 | if (use_sigmoid):
173 | layers += [
174 | nn.Sigmoid(),
175 | ]
176 | self.model = nn.Sequential(*layers)
177 |
178 | def forward(self, d0, d1, eps=0.1):
179 | return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
180 |
181 |
182 | class BCERankingLoss(nn.Module):
183 | def __init__(self, chn_mid=32):
184 | super(BCERankingLoss, self).__init__()
185 | self.net = Dist2LogitLayer(chn_mid=chn_mid)
186 | # self.parameters = list(self.net.parameters())
187 | self.loss = torch.nn.BCELoss()
188 |
189 | def forward(self, d0, d1, judge):
190 | per = (judge + 1.) / 2.
191 | self.logit = self.net.forward(d0, d1)
192 | return self.loss(self.logit, per)
193 |
194 |
195 | # L2, DSSIM metrics
196 | class FakeNet(nn.Module):
197 | def __init__(self, use_gpu=True, colorspace='Lab'):
198 | super(FakeNet, self).__init__()
199 | self.use_gpu = use_gpu
200 | self.colorspace = colorspace
201 |
202 |
203 | class L2(FakeNet):
204 | def forward(self, in0, in1, retPerLayer=None):
205 | assert (in0.size()[0] == 1) # currently only supports batchSize 1
206 |
207 | if (self.colorspace == 'RGB'):
208 | (N, C, X, Y) = in0.size()
209 | value = torch.mean(torch.mean(torch.mean((in0 - in1)**2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3).view(N)
210 | return value
211 | elif (self.colorspace == 'Lab'):
212 | value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data, to_norm=False)), lpips.tensor2np(lpips.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
213 | ret_var = Variable(torch.Tensor((value, )))
214 | if (self.use_gpu):
215 | ret_var = ret_var.cuda()
216 | return ret_var
217 |
218 |
219 | class DSSIM(FakeNet):
220 | def forward(self, in0, in1, retPerLayer=None):
221 | assert (in0.size()[0] == 1) # currently only supports batchSize 1
222 |
223 | if (self.colorspace == 'RGB'):
224 | value = lpips.dssim(1. * lpips.tensor2im(in0.data), 1. * lpips.tensor2im(in1.data), range=255.).astype('float')
225 | elif (self.colorspace == 'Lab'):
226 | value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data, to_norm=False)), lpips.tensor2np(lpips.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
227 | ret_var = Variable(torch.Tensor((value, )))
228 | if (self.use_gpu):
229 | ret_var = ret_var.cuda()
230 | return ret_var
231 |
232 |
233 | def print_network(net):
234 | num_params = 0
235 | for param in net.parameters():
236 | num_params += param.numel()
237 | print('Network', net)
238 | print('Total number of parameters: %d' % num_params)
239 |
--------------------------------------------------------------------------------
/gpeno/README.md:
--------------------------------------------------------------------------------
1 | # GPENO: GPEN with various optimizations, built for Unprompted by Therefore Games
2 |
3 | # GAN Prior Embedded Network for Blind Face Restoration in the Wild
4 |
5 | [Paper](https://arxiv.org/abs/2105.06070) | [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/GPEN-cvpr21-supp.pdf) | [Demo](https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace) | [ModelScope](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement/summary)
6 |
7 | [Hugging Face Spaces demo](https://huggingface.co/spaces/akhaliq/GPEN)
8 |
9 | [Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang)<sup>1</sup>, Peiran Ren<sup>1</sup>, Xuansong Xie<sup>1</sup>, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)<sup>1,2</sup>
10 | _<sup>1</sup>[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Hangzhou, China_
11 | _<sup>2</sup>[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk), Hong Kong, China_
12 |
13 | #### Face Restoration
14 |
15 |
16 |
17 |
18 |
19 |
20 | #### Selfie Restoration
21 |
22 |
23 |
24 | #### Face Colorization
25 |
26 |
27 |
28 | #### Face Inpainting
29 |
30 |
31 |
32 | #### Conditional Image Synthesis (Seg2Face)
33 |
34 |
35 |
36 | ## News
37 | (2023-02-15) **GPEN-BFR-1024** and **GPEN-BFR-2048** are now publicly available. Please download them via \[[ModelScope2](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement-hires/summary)\].
38 |
39 | (2023-02-15) We provide online demos via \[[ModelScope1](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement/summary)\] and \[[ModelScope2](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement-hires/summary)\].
40 |
41 | (2022-05-16) Add x1 sr model. Add ``--tile_size`` to avoid OOM.
42 |
43 | (2022-03-15) Add x4 sr model. Try ``--sr_scale``.
44 |
45 | (2022-03-09) Add GPEN-BFR-2048 for selfies. I had to take it down due to commercial issues. Sorry about that.
46 |
47 | (2021-12-29) Add [online demos on Hugging Face Spaces](https://huggingface.co/spaces/akhaliq/GPEN). Many thanks to [CJWBW](https://github.com/CJWBW) and [AK391](https://github.com/AK391).
48 |
49 | (2021-12-16) Release simplified training code for GPEN. It differs from our implementation in the paper but can achieve comparable performance. We strongly recommend changing the degradation model.
50 |
51 | (2021-12-09) Add face parsing to better paste restored faces back.
52 |
53 | (2021-12-09) GPEN can now run on CPU; simply omit ``--use_cuda``.
54 |
55 | (2021-12-01) GPEN can now work on a Windows machine without compiling CUDA code. Please check it out. Thanks to [Animadversio](https://github.com/rosinality/stylegan2-pytorch/issues/81). Alternatively, you can try [GPEN-Windows](https://drive.google.com/file/d/1YJJVnPGq90e_mWZxSGGTptNQilZNfOEO/view?usp=drivesdk). Many thanks to [Cioscos](https://github.com/yangxy/GPEN/issues/74).
56 |
57 | (2021-10-22) GPEN can now work with SR methods. An SR model trained by myself is provided. Replace it with your own model if necessary.
58 |
59 | (2021-10-11) The Colab demo for GPEN is now available.
60 |
61 | ## Download models from Modelscope
62 |
63 | - Install modelscope:
64 | ```bash
65 | pip install "modelscope[cv]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
66 | ```
67 |
68 | - Run the following codes:
69 | ```python
70 | import cv2
71 | from modelscope.pipelines import pipeline
72 | from modelscope.utils.constant import Tasks
73 | from modelscope.outputs import OutputKeys
74 |
75 | portrait_enhancement = pipeline(Tasks.image_portrait_enhancement, model='damo/cv_gpen_image-portrait-enhancement-hires')
76 | result = portrait_enhancement('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/marilyn_monroe_4.jpg')
77 | cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
78 | ```
79 |
80 | This will automatically download the GPEN models. You can find them in the local path ``~/.cache/modelscope/hub/damo``. Note that ``pytorch_model.pt`` and ``pytorch_model-2048.pt`` are the 1024 and 2048 versions, respectively.
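
If the snapshot has already been downloaded into that cache, the pipeline can also be pointed at the local directory; a hedged sketch (the exact cache layout may differ between modelscope versions):

```python
import os
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

local_dir = os.path.expanduser('~/.cache/modelscope/hub/damo/cv_gpen_image-portrait-enhancement-hires')
portrait_enhancement = pipeline(Tasks.image_portrait_enhancement, model=local_dir)
```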
81 |
82 | ## Usage
83 |
84 | 
85 | 
86 | 
87 | 
88 | 
89 |
90 | - Clone this repository:
91 | ```bash
92 | git clone https://github.com/yangxy/GPEN.git
93 | cd GPEN
94 | ```
95 | - Download the RetinaFace model and our pre-trained model (not our best model due to commercial issues) and put them into ``weights/``.
96 |
97 | [RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [ParseNet-latest](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/ParseNet-latest.pth) | [model_ir_se50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/model_ir_se50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-BFR-256-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256-D.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth) | [realesrnet_x1](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x1.pth) | [realesrnet_x2](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x2.pth) | [realesrnet_x4](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x4.pth)
98 |
99 | - Restore face images:
100 | ```bash
101 | python demo.py --task FaceEnhancement --model GPEN-BFR-512 --in_size 512 --channel_multiplier 2 --narrow 1 --use_sr --sr_scale 4 --use_cuda --save_face --indir examples/imgs --outdir examples/outs-bfr
102 | ```
103 |
104 | - Colorize faces:
105 | ```bash
106 | python demo.py --task FaceColorization --model GPEN-Colorization-1024 --in_size 1024 --use_cuda --indir examples/grays --outdir examples/outs-colorization
107 | ```
108 |
109 | - Complete faces:
110 | ```bash
111 | python demo.py --task FaceInpainting --model GPEN-Inpainting-1024 --in_size 1024 --use_cuda --indir examples/ffhq-10 --outdir examples/outs-inpainting
112 | ```
113 |
114 | - Synthesize faces:
115 | ```bash
116 | python demo.py --task Segmentation2Face --model GPEN-Seg2face-512 --in_size 512 --use_cuda --indir examples/segs --outdir examples/outs-seg2face
117 | ```
118 |
119 | - Train GPEN for BFR with 4 GPUs:
120 | ```bash
121 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train_simple.py --size 1024 --channel_multiplier 2 --narrow 1 --ckpt weights --sample results --batch 2 --path your_path_of_cropped+aligned_hq_faces (e.g., FFHQ)
122 |
123 | ```
124 | When testing your own model, set ``--key g_ema``.
125 |
126 | Please check out ``run.sh`` for more details.
127 |
128 | ## Main idea
129 |
130 |
131 | ## Citation
132 | If our work is useful for your research, please consider citing:
133 |
134 | @inproceedings{Yang2021GPEN,
135 | title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
136 | author={Tao Yang, Peiran Ren, Xuansong Xie, and Lei Zhang},
137 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
138 | year={2021}
139 | }
140 |
141 | ## License
142 | © Alibaba, 2021. For academic and non-commercial use only.
143 |
144 | ## Acknowledgments
145 | We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN), and [GFPGAN](https://github.com/TencentARC/GFPGAN).
146 |
147 | ## Contact
148 | If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com.
149 |
--------------------------------------------------------------------------------
/gpeno/face_model/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19 | int c = a / b;
20 |
21 | if (c * b > a) {
22 | c--;
23 | }
24 |
25 | return c;
26 | }
27 |
28 |
29 | struct UpFirDn2DKernelParams {
30 | int up_x;
31 | int up_y;
32 | int down_x;
33 | int down_y;
34 | int pad_x0;
35 | int pad_x1;
36 | int pad_y0;
37 | int pad_y1;
38 |
39 | int major_dim;
40 | int in_h;
41 | int in_w;
42 | int minor_dim;
43 | int kernel_h;
44 | int kernel_w;
45 | int out_h;
46 | int out_w;
47 | int loop_major;
48 | int loop_x;
49 | };
50 |
51 |
52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56 |
57 | __shared__ volatile float sk[kernel_h][kernel_w];
58 | __shared__ volatile float sx[tile_in_h][tile_in_w];
59 |
60 | int minor_idx = blockIdx.x;
61 | int tile_out_y = minor_idx / p.minor_dim;
62 | minor_idx -= tile_out_y * p.minor_dim;
63 | tile_out_y *= tile_out_h;
64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65 | int major_idx_base = blockIdx.z * p.loop_major;
66 |
67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68 | return;
69 | }
70 |
71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72 | int ky = tap_idx / kernel_w;
73 | int kx = tap_idx - ky * kernel_w;
74 | scalar_t v = 0.0;
75 |
76 | if (kx < p.kernel_w & ky < p.kernel_h) {
77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78 | }
79 |
80 | sk[ky][kx] = v;
81 | }
82 |
83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87 | int tile_in_x = floor_div(tile_mid_x, up_x);
88 | int tile_in_y = floor_div(tile_mid_y, up_y);
89 |
90 | __syncthreads();
91 |
92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93 | int rel_in_y = in_idx / tile_in_w;
94 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
95 | int in_x = rel_in_x + tile_in_x;
96 | int in_y = rel_in_y + tile_in_y;
97 |
98 | scalar_t v = 0.0;
99 |
100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102 | }
103 |
104 | sx[rel_in_y][rel_in_x] = v;
105 | }
106 |
107 | __syncthreads();
108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109 | int rel_out_y = out_idx / tile_out_w;
110 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
111 | int out_x = rel_out_x + tile_out_x;
112 | int out_y = rel_out_y + tile_out_y;
113 |
114 | int mid_x = tile_mid_x + rel_out_x * down_x;
115 | int mid_y = tile_mid_y + rel_out_y * down_y;
116 | int in_x = floor_div(mid_x, up_x);
117 | int in_y = floor_div(mid_y, up_y);
118 | int rel_in_x = in_x - tile_in_x;
119 | int rel_in_y = in_y - tile_in_y;
120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122 |
123 | scalar_t v = 0.0;
124 |
125 | #pragma unroll
126 | for (int y = 0; y < kernel_h / up_y; y++)
127 | #pragma unroll
128 | for (int x = 0; x < kernel_w / up_x; x++)
129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130 |
131 | if (out_x < p.out_w & out_y < p.out_h) {
132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133 | }
134 | }
135 | }
136 | }
137 | }
138 |
139 |
140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141 | int up_x, int up_y, int down_x, int down_y,
142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143 | int curDevice = -1;
144 | cudaGetDevice(&curDevice);
145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146 |
147 | UpFirDn2DKernelParams p;
148 |
149 | auto x = input.contiguous();
150 | auto k = kernel.contiguous();
151 |
152 | p.major_dim = x.size(0);
153 | p.in_h = x.size(1);
154 | p.in_w = x.size(2);
155 | p.minor_dim = x.size(3);
156 | p.kernel_h = k.size(0);
157 | p.kernel_w = k.size(1);
158 | p.up_x = up_x;
159 | p.up_y = up_y;
160 | p.down_x = down_x;
161 | p.down_y = down_y;
162 | p.pad_x0 = pad_x0;
163 | p.pad_x1 = pad_x1;
164 | p.pad_y0 = pad_y0;
165 | p.pad_y1 = pad_y1;
166 |
167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169 |
170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171 |
172 | int mode = -1;
173 |
174 |     int tile_out_h = -1;
175 |     int tile_out_w = -1;
176 |
177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178 | mode = 1;
179 | tile_out_h = 16;
180 | tile_out_w = 64;
181 | }
182 |
183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184 | mode = 2;
185 | tile_out_h = 16;
186 | tile_out_w = 64;
187 | }
188 |
189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190 | mode = 3;
191 | tile_out_h = 16;
192 | tile_out_w = 64;
193 | }
194 |
195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196 | mode = 4;
197 | tile_out_h = 16;
198 | tile_out_w = 64;
199 | }
200 |
201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202 | mode = 5;
203 | tile_out_h = 8;
204 | tile_out_w = 32;
205 | }
206 |
207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208 | mode = 6;
209 | tile_out_h = 8;
210 | tile_out_w = 32;
211 | }
212 |
213 | dim3 block_size;
214 | dim3 grid_size;
215 |
216 |     if (tile_out_h > 0 && tile_out_w > 0) {
217 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
218 | p.loop_x = 1;
219 | block_size = dim3(32 * 8, 1, 1);
220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222 | (p.major_dim - 1) / p.loop_major + 1);
223 | }
224 |
225 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226 |         switch (mode) {
227 |         case 1:
228 |             upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230 |             );
231 |
232 |             break;
233 |
234 |         case 2:
235 |             upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237 |             );
238 |
239 |             break;
240 |
241 |         case 3:
242 |             upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244 |             );
245 |
246 |             break;
247 |
248 |         case 4:
249 |             upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251 |             );
252 |
253 |             break;
254 |
255 |         case 5:
256 |             upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258 |             );
259 |
260 |             break;
261 |
262 |         case 6:
263 |             upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
264 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265 |             );
266 |
267 |             break;
268 |         }
269 |     });
270 |
271 | return out;
272 | }
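The output size computed in `upfirdn2d_op` follows the usual upsample-pad-filter-downsample arithmetic. A small hedged Python sketch of the same formula, simplified to symmetric per-axis parameters; the numbers are placeholders:

```python
def upfirdn2d_out_size(in_h, in_w, up, down, pad0, pad1, kernel):
    # Mirrors: out = (in * up + pad0 + pad1 - kernel + down) / down (integer division)
    out_h = (in_h * up + pad0 + pad1 - kernel + down) // down
    out_w = (in_w * up + pad0 + pad1 - kernel + down) // down
    return out_h, out_w

# e.g. a 2x upsample with a 4-tap kernel and (1, 1) padding: 64 -> 127
print(upfirdn2d_out_size(64, 64, up=2, down=1, pad0=1, pad1=1, kernel=4))
```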
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import folder_paths
2 | import torch
3 | import os
4 | import sys
5 |
6 | import cv2
7 | import glob
8 | import time
9 | import math
10 | import argparse
11 | import numpy as np
12 | from PIL import Image, ImageDraw
13 |
14 | this_path = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append(this_path)
16 |
17 | import gpeno.__init_paths
18 | from gpeno.face_enhancement import FaceEnhancement
19 |
20 | models_dir = folder_paths.models_dir
21 |
22 | global_gpen_processor = None
23 | global_gpen_cache_model = ""
24 |
25 |
26 | class GPENO:
27 | 	"""Restores faces in the image using a special version of the GPEN technique that has been optimized for speed."""
28 |
29 | @classmethod
30 | def INPUT_TYPES(cls):
31 | return {
32 | "required": {
33 | "image": ("IMAGE", ),
34 | "use_global_cache": ("BOOLEAN", {
35 | "default": True,
36 | "tooltip": "If enabled, the model will be loaded once and shared across all instances of this node. This saves VRAM if you are using multiple instances of GPENO in your flow, but the settings must remain the same for all instances."
37 | }),
38 | "unload": ("BOOLEAN", {
39 | "default": False,
40 | "tooltip": "If enabled, the model will be freed from the cache at the start of this node's execution (if applicable), and it will not be cached again."
41 | }),
42 | "backbone": (["RetinaFace-R50", "mobilenet0.25_Final"], {
43 | "default": "RetinaFace-R50",
44 | "tooltip": "Backbone files are downloaded to `comfyui/models/facedetection`."
45 | }),
46 | "resolution_preset": (["512", "1024", "2048"], {
47 | "default": "512"
48 | }),
49 | "downscale_method": (["Bilinear", "Nearest", "Bicubic", "Area", "Lanczos"], {
50 | "default": "Bilinear"
51 | }),
52 | "channel_multiplier": ("FLOAT", {
53 | "default": 2
54 | }),
55 | "narrow": ("FLOAT", {
56 | "default": 1.0
57 | }),
58 | "alpha": ("FLOAT", {
59 | "default": 1.0
60 | }),
61 | "device": (["cpu", "cuda"], {
62 | "default": "cuda" if torch.cuda.is_available() else "cpu"
63 | }),
64 | "aligned": ("BOOLEAN", {
65 | "default": False
66 | }),
67 | "colorize": ("BOOLEAN", {
68 | "default": False,
69 | }),
70 | },
71 | }
72 |
73 | RETURN_TYPES = (
74 | "IMAGE",
75 | "IMAGE",
76 | "IMAGE",
77 | )
78 | FUNCTION = "op"
79 | CATEGORY = "image"
80 | DESCRIPTION = """
81 | Performs GPEN face restoration on the input image(s). This implementation has been optimized for speed.
82 | """
83 |
84 | def __init__(self):
85 | self.gpen_processor = None
86 | self.gpen_cache_model = ""
87 |
88 | def op(self, image, use_global_cache, unload, backbone, resolution_preset, downscale_method, channel_multiplier, narrow, alpha, device, aligned, colorize):
89 | global global_gpen_processor
90 | global global_gpen_cache_model
91 |
92 | # Package the node inputs into an argparse.Namespace so they match the interface the GPEN code expects
93 | args = argparse.Namespace()
94 | args.model = f"GPEN-BFR-{resolution_preset}"
95 | args.channel_multiplier = channel_multiplier
96 | args.narrow = narrow
97 | args.alpha = alpha
98 | args.use_cuda = device
99 | args.aligned = aligned
100 |
101 | # Hardcoded arguments irrelevant to the user
102 | args.use_sr = False
103 | args.in_size = int(resolution_preset)
104 | args.out_size = 0
105 | args.sr_model = "realesrnet"
106 | args.sr_scale = 2
107 | args.key = None
108 | args.indir = "example/imgs"
109 | args.outdir = "results/outs-BFR"
110 | args.ext = ".jpg"
111 | args.save_face = False
112 | args.tile_size = 0
113 |
114 | if downscale_method == "Nearest":
115 | downscale_method = cv2.INTER_NEAREST
116 | elif downscale_method == "Bilinear":
117 | downscale_method = cv2.INTER_LINEAR
118 | elif downscale_method == "Area":
119 | downscale_method = cv2.INTER_AREA
120 | elif downscale_method == "Cubic":
121 | downscale_method = cv2.INTER_CUBIC
122 | elif downscale_method == "Lanczos":
123 | downscale_method = cv2.INTER_LANCZOS4
124 |
125 | def download_file(filename, url, logger=None, overwrite=False, headers=None):
126 | import os, requests
127 |
128 | # log = get_logger(logger)
129 |
130 | if overwrite or not os.path.exists(filename):
131 | # Make sure directory structure exists
132 | os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
133 |
134 | # log.info(f"Downloading file into: {filename}...")
135 |
136 | response = requests.get(url, stream=True, headers=headers)
137 | if response.status_code != 200:
138 | # log.error(f"Error when trying to download `{url}` to `{filename}`. Status code received: {response.status_code}")
139 | return False
140 | try:
141 | with open(filename, "wb") as fout:
142 | for block in response.iter_content(4096):
143 | fout.write(block)
144 | except Exception:
145 | # log.exception(f"Error when writing download to `{filename}`.")
146 | return False
147 |
148 | return True
149 |
150 | gpen_dir = os.path.join(models_dir, "facerestore_models")
151 | facedetect_dir = os.path.join(models_dir, "facedetection")
152 |
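# Fetch the GPEN generator weights for the selected resolution preset on first use.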
153 | if args.model == "GPEN-BFR-512":
154 | download_file(f"{gpen_dir}/{args.model}.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth")
155 | elif args.model == "GPEN-BFR-1024":
156 | if not download_file(f"{gpen_dir}/{args.model}.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-1024.pth"):
157 | pass
158 | # self.log.error("The download link for the 1024 model doesn't appear to work. Try installing it manually into your unprompted/models/gpen folder: https://cyberfile.me/644d")
159 | elif args.model == "GPEN-BFR-2048":
160 | download_file(f"{gpen_dir}/{args.model}.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-2048.pth")
161 |
162 | if colorize:
163 | download_file(f"{gpen_dir}/GPEN-Colorization-1024.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth")
164 |
165 | # Additional dependencies
166 | download_file(f"{facedetect_dir}/parsing_parsenet.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/ParseNet-latest.pth")
167 | # for sr
168 | # download_file(f"{facedetect_dir}/realesrnet_x2.pth", "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x2.pth")
169 |
170 | if backbone == "RetinaFace-R50":
171 | backbone = "detection_Resnet50_Final"
172 | download_file(f"{facedetect_dir}/detection_Resnet50_Final.pth", "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth")
173 | elif backbone == "mobilenet0.25_Final":
174 | backbone = "detection_mobilenet0.25_Final"
175 | download_file(f"{facedetect_dir}/detection_mobilenet0.25_Final.pth", "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth")
176 |
177 | if use_global_cache:
178 | this_gpen_processor = global_gpen_processor
179 | this_gpen_cache_model = global_gpen_cache_model
180 | else:
181 | this_gpen_processor = self.gpen_processor
182 | this_gpen_cache_model = self.gpen_cache_model
183 |
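# Rebuild the FaceEnhancement pipeline when nothing is cached yet, the cached resolution differs, or an unload was requested.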
184 | if unload or not this_gpen_processor or this_gpen_cache_model != resolution_preset:
185 | print("Loading FaceEnhancement object...")
186 | face_enhancement_obj = FaceEnhancement(args, base_dir=f"{models_dir}/", in_size=args.in_size, model=args.model, use_sr=args.use_sr, device=args.use_cuda, interp=downscale_method, backbone=backbone, colorize=colorize)
187 |
188 | if use_global_cache:
189 | global_gpen_processor = face_enhancement_obj
190 | global_gpen_cache_model = resolution_preset
191 | this_gpen_processor = global_gpen_processor
192 | this_gpen_cache_model = global_gpen_cache_model
193 | else:
194 | self.gpen_processor = face_enhancement_obj
195 | self.gpen_cache_model = resolution_preset
196 | this_gpen_processor = self.gpen_processor
197 | this_gpen_cache_model = self.gpen_cache_model
198 | else:
199 | print("Using cached FaceEnhancement object.")
200 | # self.log.info("Using cached FaceEnhancement object.")
201 |
202 | print("Starting GPENO processing...")
203 |
204 | total_images = image.shape[0]
205 | out_images = []
206 | out_original_faces = []
207 | out_enhanced_faces = []
208 |
209 | for i in range(total_images):
210 | # image is a 4d tensor array in the format of [B, H, W, C]
211 | this_img = 255. * image[i].cpu().numpy()
212 | img = np.clip(this_img, 0, 255).astype(np.uint8)
213 |
214 | result, orig_faces, enhanced_faces = this_gpen_processor.process(img, aligned=args.aligned)
215 |
216 | out_images.append(result)
217 | # append each detected face from orig_faces / enhanced_faces to the output lists
218 | for orig_face in orig_faces:
219 | out_original_faces.append(orig_face)
220 | for enhanced_face in enhanced_faces:
221 | out_enhanced_faces.append(enhanced_face)
222 |
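# Stack the per-image results into float tensors shaped [B, H, W, C] and scaled to [0, 1], as ComfyUI expects.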
223 | restored_img_np = np.array(out_images).astype(np.float32) / 255.0
224 | restored_img_tensor = torch.from_numpy(restored_img_np)
225 |
226 | restored_original_faces_np = np.array(out_original_faces).astype(np.float32) / 255.0
227 | restored_original_faces_tensor = torch.from_numpy(restored_original_faces_np)
228 |
229 | restored_enhanced_faces_np = np.array(out_enhanced_faces).astype(np.float32) / 255.0
230 | restored_enhanced_faces_tensor = torch.from_numpy(restored_enhanced_faces_np)
231 |
232 | if unload:
233 | print("Unloading GPEN from cache.")
234 |
235 | if use_global_cache:
236 | global_gpen_processor = None
237 | global_gpen_cache_model = ""
238 | else:
239 | self.gpen_cache_model = ""
240 | self.gpen_processor = None
241 |
242 | print("GPENO processing done.")
243 | return (
244 | restored_img_tensor,
245 | restored_original_faces_tensor,
246 | restored_enhanced_faces_tensor,
247 | )
248 |
249 |
250 | # A dictionary that contains all nodes you want to export with their names
251 | # NOTE: names should be globally unique
252 | NODE_CLASS_MAPPINGS = {
253 | "GPENO Face Restoration": GPENO,
254 | }
255 |
256 | # A dictionary that contains the friendly/humanly readable titles for the nodes
257 | NODE_DISPLAY_NAME_MAPPINGS = {
258 | "GPENO Face Restoration": "GPENO Face Restoration",
259 | }
260 |
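# A minimal debugging sketch (assumptions: a ComfyUI environment where `folder_paths` resolves,
# and `img` is a [1, H, W, 3] float tensor in [0, 1]) showing how the node can be driven directly:
#
# node = GPENO()
# restored, orig_faces, enhanced_faces = node.op(
#     img, use_global_cache=False, unload=True, backbone="RetinaFace-R50",
#     resolution_preset="512", downscale_method="Bilinear", channel_multiplier=2,
#     narrow=1.0, alpha=1.0, device="cpu", aligned=False, colorize=False)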
--------------------------------------------------------------------------------
/gpeno/training/lpips/trainer.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from collections import OrderedDict
8 | from torch.autograd import Variable
9 | from scipy.ndimage import zoom
10 | from tqdm import tqdm
11 | import lpips
12 | import os
13 |
14 |
15 | class Trainer():
16 | def name(self):
17 | return self.model_name
18 |
19 | def initialize(self, model='lpips', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
20 | use_gpu=True, printNet=False, spatial=False,
21 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
22 | '''
23 | INPUTS
24 | model - ['lpips'] for linearly calibrated network
25 | ['baseline'] for off-the-shelf network
26 | ['L2'] for L2 distance in Lab colorspace
27 | ['SSIM'] for ssim in RGB colorspace
28 | net - ['squeeze','alex','vgg']
29 | model_path - if None, will look in weights/[NET_NAME].pth
30 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
31 | use_gpu - bool - whether or not to use a GPU
32 | printNet - bool - whether or not to print network architecture out
33 | spatial - bool - whether to output an array containing varying distances across spatial dimensions
34 | is_train - bool - [True] for training mode
35 | lr - float - initial learning rate
36 | beta1 - float - initial momentum term for adam
37 | version - 0.1 for latest, 0.0 was original (with a bug)
38 | gpu_ids - int array - [0] by default, gpus to use
39 | '''
40 | self.use_gpu = use_gpu
41 | self.gpu_ids = gpu_ids
42 | self.model = model
43 | self.net = net
44 | self.is_train = is_train
45 | self.spatial = spatial
46 | self.model_name = '%s [%s]'%(model,net)
47 |
48 | if(self.model == 'lpips'): # pretrained net + linear layer
49 | self.net = lpips.LPIPS(pretrained=not is_train, net=net, version=version, lpips=True, spatial=spatial,
50 | pnet_rand=pnet_rand, pnet_tune=pnet_tune,
51 | use_dropout=True, model_path=model_path, eval_mode=False)
52 | elif(self.model=='baseline'): # pretrained network
53 | self.net = lpips.LPIPS(pnet_rand=pnet_rand, net=net, lpips=False)
54 | elif(self.model in ['L2','l2']):
55 | self.net = lpips.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
56 | self.model_name = 'L2'
57 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
58 | self.net = lpips.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
59 | self.model_name = 'SSIM'
60 | else:
61 | raise ValueError("Model [%s] not recognized." % self.model)
62 |
63 | self.parameters = list(self.net.parameters())
64 |
65 | if self.is_train: # training mode
66 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
67 | self.rankLoss = lpips.BCERankingLoss()
68 | self.parameters += list(self.rankLoss.net.parameters())
69 | self.lr = lr
70 | self.old_lr = lr
71 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
72 | else: # test mode
73 | self.net.eval()
74 |
75 | if(use_gpu):
76 | self.net.to(gpu_ids[0])
77 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
78 | if(self.is_train):
79 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
80 |
81 | if(printNet):
82 | print('---------- Networks initialized -------------')
83 | networks.print_network(self.net)
84 | print('-----------------------------------------------')
85 |
86 | def forward(self, in0, in1, retPerLayer=False):
87 | ''' Function computes the distance between image patches in0 and in1
88 | INPUTS
89 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
90 | OUTPUT
91 | computed distances between in0 and in1
92 | '''
93 |
94 | return self.net.forward(in0, in1, retPerLayer=retPerLayer)
95 |
96 | # ***** TRAINING FUNCTIONS *****
97 | def optimize_parameters(self):
98 | self.forward_train()
99 | self.optimizer_net.zero_grad()
100 | self.backward_train()
101 | self.optimizer_net.step()
102 | self.clamp_weights()
103 |
104 | def clamp_weights(self):
105 | for module in self.net.modules():
106 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
107 | module.weight.data = torch.clamp(module.weight.data,min=0)
108 |
109 | def set_input(self, data):
110 | self.input_ref = data['ref']
111 | self.input_p0 = data['p0']
112 | self.input_p1 = data['p1']
113 | self.input_judge = data['judge']
114 |
115 | if(self.use_gpu):
116 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
117 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
118 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
119 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
120 |
121 | self.var_ref = Variable(self.input_ref,requires_grad=True)
122 | self.var_p0 = Variable(self.input_p0,requires_grad=True)
123 | self.var_p1 = Variable(self.input_p1,requires_grad=True)
124 |
125 | def forward_train(self): # run forward pass
126 | self.d0 = self.forward(self.var_ref, self.var_p0)
127 | self.d1 = self.forward(self.var_ref, self.var_p1)
128 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
129 |
130 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
131 |
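# judge is in [0, 1]; map it to [-1, 1] before feeding the ranking loss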
132 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
133 |
134 | return self.loss_total
135 |
136 | def backward_train(self):
137 | torch.mean(self.loss_total).backward()
138 |
139 | def compute_accuracy(self,d0,d1,judge):
140 | ''' d0, d1 are Variables, judge is a Tensor '''
141 | d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
196 | print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
197 | self.old_lr = lr
198 |
199 |
200 | def get_image_paths(self):
201 | return self.image_paths
202 |
203 | def save_done(self, flag=False):
204 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
205 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
206 |
207 |
208 | def score_2afc_dataset(data_loader, func, name=''):
209 | ''' Function computes Two Alternative Forced Choice (2AFC) score using
210 | distance function 'func' in dataset 'data_loader'
211 | INPUTS
212 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
213 | func - callable distance function - calling d=func(in0,in1) should take 2
214 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N
215 | OUTPUTS
216 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
217 | [1] - dictionary with following elements
218 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches
219 | gts - N array in [0,1], preferred patch selected by human evaluators
220 | (closer to "0" for left patch p0, "1" for right patch p1,
221 | "0.6" means 60pct people preferred right patch, 40pct preferred left)
222 | scores - N array in [0,1], corresponding to what percentage function agreed with humans
223 | CONSTS
224 | N - number of test triplets in data_loader
225 | '''
226 |
227 | d0s = []
228 | d1s = []
229 | gts = []
230 |
231 | for data in tqdm(data_loader.load_data(), desc=name):
232 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
233 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
234 | gts+=data['judge'].cpu().numpy().flatten().tolist()
235 |
236 | d0s = np.array(d0s)
237 | d1s = np.array(d1s)
238 | gts = np.array(gts)
239 | scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
240 | return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
--------------------------------------------------------------------------------
/gpeno/face_detect/utils/box_utils.py:
--------------------------------------------------------------------------------
29 | def intersect(box_a, box_b):
30 | """ We resize both tensors to [A,B,2] without new malloc:
31 | [A,2] -> [A,1,2] -> [A,B,2]
32 | [B,2] -> [1,B,2] -> [A,B,2]
33 | Then we compute the area of intersect between box_a and box_b.
34 | Args:
35 | box_a: (tensor) bounding boxes, Shape: [A,4].
36 | box_b: (tensor) bounding boxes, Shape: [B,4].
37 | Return:
38 | (tensor) intersection area, Shape: [A,B].
39 | """
40 | A = box_a.size(0)
41 | B = box_b.size(0)
42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46 | inter = torch.clamp((max_xy - min_xy), min=0)
47 | return inter[:, :, 0] * inter[:, :, 1]
48 |
49 |
50 | def jaccard(box_a, box_b):
51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52 | is simply the intersection over union of two boxes. Here we operate on
53 | ground truth boxes and default boxes.
54 | E.g.:
55 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56 | Args:
57 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59 | Return:
60 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61 | """
62 | inter = intersect(box_a, box_b)
63 | area_a = ((box_a[:, 2]-box_a[:, 0]) *
64 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65 | area_b = ((box_b[:, 2]-box_b[:, 0]) *
66 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67 | union = area_a + area_b - inter
68 | return inter / union # [A,B]
69 |
70 |
71 | def matrix_iou(a, b):
72 | """
73 | return iou of a and b, numpy version for data augmentation
74 | """
75 | lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76 | rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77 |
78 | area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79 | area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80 | area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81 | return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82 |
83 |
84 | def matrix_iof(a, b):
85 | """
86 | return iof of a and b, numpy version for data augmentation
87 | """
88 | lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89 | rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90 |
91 | area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92 | area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93 | return area_i / np.maximum(area_a[:, np.newaxis], 1)
94 |
95 |
96 | def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
97 | """Match each prior box with the ground truth box of the highest jaccard
98 | overlap, encode the bounding boxes, then return the matched indices
99 | corresponding to both confidence and location preds.
100 | Args:
101 | threshold: (float) The overlap threshold used when matching boxes.
102 | truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
103 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104 | variances: (tensor) Variances corresponding to each prior coord,
105 | Shape: [num_priors, 4].
106 | labels: (tensor) All the class labels for the image, Shape: [num_obj].
107 | landms: (tensor) Ground truth landms, Shape [num_obj, 10].
108 | loc_t: (tensor) Tensor to be filled w/ encoded location targets.
109 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
110 | landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
111 | idx: (int) current batch index
112 | Return:
113 | The matched indices corresponding to 1)location 2)confidence 3)landm preds.
114 | """
115 | # jaccard index
116 | overlaps = jaccard(
117 | truths,
118 | point_form(priors)
119 | )
120 | # (Bipartite Matching)
121 | # [1,num_objects] best prior for each ground truth
122 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
123 |
124 | # ignore hard gt
125 | valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
126 | best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
127 | if best_prior_idx_filter.shape[0] <= 0:
128 | loc_t[idx] = 0
129 | conf_t[idx] = 0
130 | return
131 |
132 | # [1,num_priors] best ground truth for each prior
133 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
134 | best_truth_idx.squeeze_(0)
135 | best_truth_overlap.squeeze_(0)
136 | best_prior_idx.squeeze_(1)
137 | best_prior_idx_filter.squeeze_(1)
138 | best_prior_overlap.squeeze_(1)
139 | best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
140 | # TODO refactor: index best_prior_idx with long tensor
141 | # ensure every gt matches with its prior of max overlap
142 | for j in range(best_prior_idx.size(0)): # determine which ground-truth box each best-matching anchor is assigned to predict
143 | best_truth_idx[best_prior_idx[j]] = j
144 | matches = truths[best_truth_idx] # Shape: [num_priors,4] - the ground-truth box matched to each anchor
145 | conf = labels[best_truth_idx] # Shape: [num_priors] - the label matched to each anchor
146 | conf[best_truth_overlap < threshold] = 0 # label as background: anchors with overlap below the threshold become negative samples
147 | loc = encode(matches, priors, variances)
148 |
149 | matches_landm = landms[best_truth_idx]
150 | landm = encode_landm(matches_landm, priors, variances)
151 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
152 | conf_t[idx] = conf # [num_priors] top class label for each prior
153 | landm_t[idx] = landm
154 |
155 |
156 | def encode(matched, priors, variances):
157 | """Encode the variances from the priorbox layers into the ground truth boxes
158 | we have matched (based on jaccard overlap) with the prior boxes.
159 | Args:
160 | matched: (tensor) Coords of ground truth for each prior in point-form
161 | Shape: [num_priors, 4].
162 | priors: (tensor) Prior boxes in center-offset form
163 | Shape: [num_priors,4].
164 | variances: (list[float]) Variances of priorboxes
165 | Return:
166 | encoded boxes (tensor), Shape: [num_priors, 4]
167 | """
168 |
169 | # dist b/t match center and prior's center
170 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
171 | # encode variance
172 | g_cxcy /= (variances[0] * priors[:, 2:])
173 | # match wh / prior wh
174 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
175 | g_wh = torch.log(g_wh) / variances[1]
176 | # return target for smooth_l1_loss
177 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
178 |
179 | def encode_landm(matched, priors, variances):
180 | """Encode the variances from the priorbox layers into the ground truth boxes
181 | we have matched (based on jaccard overlap) with the prior boxes.
182 | Args:
183 | matched: (tensor) Coords of ground truth for each prior in point-form
184 | Shape: [num_priors, 10].
185 | priors: (tensor) Prior boxes in center-offset form
186 | Shape: [num_priors,4].
187 | variances: (list[float]) Variances of priorboxes
188 | Return:
189 | encoded landm (tensor), Shape: [num_priors, 10]
190 | """
191 |
192 | # dist b/t match center and prior's center
193 | matched = torch.reshape(matched, (matched.size(0), 5, 2))
194 | priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
195 | priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
196 | priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
197 | priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
198 | priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
199 | g_cxcy = matched[:, :, :2] - priors[:, :, :2]
200 | # encode variance
201 | g_cxcy /= (variances[0] * priors[:, :, 2:])
202 | # g_cxcy /= priors[:, :, 2:]
203 | g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
204 | # return target for smooth_l1_loss
205 | return g_cxcy
206 |
207 |
208 | # Adapted from https://github.com/Hakuyume/chainer-ssd
209 | def decode(loc, priors, variances):
210 | """Decode locations from predictions using priors to undo
211 | the encoding we did for offset regression at train time.
212 | Args:
213 | loc (tensor): location predictions for loc layers,
214 | Shape: [num_priors,4]
215 | priors (tensor): Prior boxes in center-offset form.
216 | Shape: [num_priors,4].
217 | variances: (list[float]) Variances of priorboxes
218 | Return:
219 | decoded bounding box predictions
220 | """
221 |
222 | boxes = torch.cat((
223 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
224 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
225 | boxes[:, :2] -= boxes[:, 2:] / 2
226 | boxes[:, 2:] += boxes[:, :2]
227 | return boxes
228 |
229 | def decode_landm(pre, priors, variances):
230 | """Decode landm from predictions using priors to undo
231 | the encoding we did for offset regression at train time.
232 | Args:
233 | pre (tensor): landm predictions for loc layers,
234 | Shape: [num_priors,10]
235 | priors (tensor): Prior boxes in center-offset form.
236 | Shape: [num_priors,4].
237 | variances: (list[float]) Variances of priorboxes
238 | Return:
239 | decoded landm predictions
240 | """
241 | landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
242 | priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
243 | priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
244 | priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
245 | priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
246 | ), dim=1)
247 | return landms
248 |
249 |
250 | def log_sum_exp(x):
251 | """Utility function for computing log_sum_exp while determining
252 | This will be used to determine unaveraged confidence loss across
253 | all examples in a batch.
254 | Args:
255 | x (Variable(tensor)): conf_preds from conf layers
256 | """
257 | x_max = x.data.max()
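# subtract the max before exponentiating for numerical stability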
258 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
259 |
260 |
261 | # Original author: Francisco Massa:
262 | # https://github.com/fmassa/object-detection.torch
263 | # Ported to PyTorch by Max deGroot (02/01/2017)
264 | def nms(boxes, scores, overlap=0.5, top_k=200):
265 | """Apply non-maximum suppression at test time to avoid detecting too many
266 | overlapping bounding boxes for a given object.
267 | Args:
268 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
269 | scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
270 | overlap: (float) The overlap thresh for suppressing unnecessary boxes.
271 | top_k: (int) The Maximum number of box preds to consider.
272 | Return:
273 | The indices of the kept boxes with respect to num_priors.
274 | """
275 |
276 | keep = torch.Tensor(scores.size(0)).fill_(0).long()
277 | if boxes.numel() == 0:
278 | return keep
279 | x1 = boxes[:, 0]
280 | y1 = boxes[:, 1]
281 | x2 = boxes[:, 2]
282 | y2 = boxes[:, 3]
283 | area = torch.mul(x2 - x1, y2 - y1)
284 | v, idx = scores.sort(0) # sort in ascending order
285 | # I = I[v >= 0.01]
286 | idx = idx[-top_k:] # indices of the top-k largest vals
287 | xx1 = boxes.new()
288 | yy1 = boxes.new()
289 | xx2 = boxes.new()
290 | yy2 = boxes.new()
291 | w = boxes.new()
292 | h = boxes.new()
293 |
294 | # keep = torch.Tensor()
295 | count = 0
296 | while idx.numel() > 0:
297 | i = idx[-1] # index of current largest val
298 | # keep.append(i)
299 | keep[count] = i
300 | count += 1
301 | if idx.size(0) == 1:
302 | break
303 | idx = idx[:-1] # remove kept element from view
304 | # load bboxes of next highest vals
305 | torch.index_select(x1, 0, idx, out=xx1)
306 | torch.index_select(y1, 0, idx, out=yy1)
307 | torch.index_select(x2, 0, idx, out=xx2)
308 | torch.index_select(y2, 0, idx, out=yy2)
309 | # store element-wise max with next highest score
310 | xx1 = torch.clamp(xx1, min=x1[i])
311 | yy1 = torch.clamp(yy1, min=y1[i])
312 | xx2 = torch.clamp(xx2, max=x2[i])
313 | yy2 = torch.clamp(yy2, max=y2[i])
314 | w.resize_as_(xx2)
315 | h.resize_as_(yy2)
316 | w = xx2 - xx1
317 | h = yy2 - yy1
318 | # check sizes of xx1 and xx2.. after each iteration
319 | w = torch.clamp(w, min=0.0)
320 | h = torch.clamp(h, min=0.0)
321 | inter = w*h
322 | # IoU = i / (area(a) + area(b) - i)
323 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas
324 | union = (rem_areas - inter) + area[i]
325 | IoU = inter/union # store result in iou
326 | # keep only elements with an IoU <= overlap
327 | idx = idx[IoU.le(overlap)]
328 | return keep, count
329 |
330 |
331 |
--------------------------------------------------------------------------------