├── libs ├── DECA │ ├── decalib │ │ ├── __init__.py │ │ ├── models │ │ │ ├── encoders.py │ │ │ ├── decoders.py │ │ │ └── resnet.py │ │ ├── datasets │ │ │ ├── detectors.py │ │ │ └── datasets.py │ │ └── utils │ │ │ └── config.py │ ├── .gitignore │ ├── README.md │ └── estimate_DECA.py ├── criteria │ ├── lpips │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── lpips.py │ │ └── networks.py │ ├── l2_loss.py │ ├── PTI │ │ ├── global_config.py │ │ ├── hyperparameters.py │ │ ├── base_coach.py │ │ └── localitly_regulizer.py │ ├── id_loss.py │ ├── losses.py │ ├── model_irse.py │ └── helpers.py ├── .gitignore ├── face_models │ ├── sfd │ │ ├── __init__.py │ │ ├── sfd_detector.py │ │ ├── detect.py │ │ ├── bbox.py │ │ ├── net_s3fd.py │ │ └── core.py │ ├── ffhq_cropping.py │ ├── landmarks_estimation.py │ └── fan_model │ │ └── models.py ├── gan │ ├── StyleGAN2 │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── upfirdn2d.cpp │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.py │ │ │ ├── conv2d_gradfix.py │ │ │ └── upfirdn2d_kernel.cu │ │ └── convert_weight.py │ └── encoder4editing │ │ ├── helpers.py │ │ └── psp_encoders.py ├── configs │ ├── ranges_FFHQ.npy │ ├── ranges_voxceleb.npy │ ├── config_models.py │ ├── config_arguments.py │ └── config_directions.py ├── utilities │ ├── utils.py │ ├── visualization.py │ ├── utils_inference.py │ ├── image_utils.py │ └── generic.py ├── models │ └── direction_matrix.py ├── optimization.py └── datasets │ ├── dataloader_inversion.py │ ├── dataloader.py │ └── dataloader_paired.py ├── .gitignore ├── images ├── cross.gif ├── self.gif ├── example.png ├── architecture.png └── gif_editing.gif ├── inference_examples ├── 0002775.png └── lWOTF8SdzJw#2614-2801.mp4 ├── requirements.txt ├── extract_statistics.py ├── invert_images.py ├── run_trainer.py └── README.md /libs/DECA/decalib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/criteria/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /libs/DECA/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pth 3 | training_attempts/ 4 | results/ 5 | __pycache__ 6 | -------------------------------------------------------------------------------- /libs/face_models/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /images/cross.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/images/cross.gif -------------------------------------------------------------------------------- /images/self.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/images/self.gif -------------------------------------------------------------------------------- /images/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/images/example.png -------------------------------------------------------------------------------- /libs/DECA/README.md: -------------------------------------------------------------------------------- 1 | # DECA 3D shape model 2 | 3 | Code taken from [DECA](https://github.com/YadiraF/DECA) 4 | 5 | -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/images/architecture.png -------------------------------------------------------------------------------- /images/gif_editing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/images/gif_editing.gif -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /libs/configs/ranges_FFHQ.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/libs/configs/ranges_FFHQ.npy -------------------------------------------------------------------------------- /inference_examples/0002775.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/inference_examples/0002775.png -------------------------------------------------------------------------------- /libs/configs/ranges_voxceleb.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/libs/configs/ranges_voxceleb.npy -------------------------------------------------------------------------------- /inference_examples/lWOTF8SdzJw#2614-2801.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/stylegan_directions_face_reenactment/HEAD/inference_examples/lWOTF8SdzJw#2614-2801.mp4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | imageio 5 | Pillow 6 | ninja 7 | opencv-python 8 | tqdm 9 | scikit-image 10 | sklearn 11 | wandb 12 | face-alignment 13 | kornia==0.4.1 14 | chumpy 15 | -------------------------------------------------------------------------------- /libs/criteria/l2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | l2_criterion = torch.nn.MSELoss(reduction='mean') 4 | 5 | 6 | def l2_loss(real_images, generated_images): 7 | loss = 
l2_criterion(real_images, generated_images) 8 | return loss 9 | -------------------------------------------------------------------------------- /libs/criteria/PTI/global_config.py: -------------------------------------------------------------------------------- 1 | ## Device 2 | cuda_visible_devices = '0' 3 | device = 'cuda:0' 4 | 5 | ## Logs 6 | training_step = 1 7 | image_rec_result_log_snapshot = 100 8 | pivotal_training_steps = 0 9 | model_snapshot_interval = 400 10 | 11 | ## Run name to be updated during PTI 12 | run_name = '' 13 | -------------------------------------------------------------------------------- /libs/configs/config_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | stylegan2_voxceleb_256 = { 5 | 'image_resolution': 256, 6 | 'channel_multiplier': 1, 7 | 'gan_weights': './pretrained_models/stylegan-voxceleb.pt', 8 | } 9 | 10 | stylegan2_ffhq_256 = { 11 | 'image_resolution': 256, 12 | 'channel_multiplier': 2, 13 | 'gan_weights': './pretrained_models/stylegan_rosinality.pt', 14 | } 15 | 16 | stylegan2_ffhq_1024 = { 17 | 'image_resolution': 1024, 18 | 'channel_multiplier': 2, 19 | 'gan_weights': './pretrained_models/stylegan2-ffhq-config-f_1024.pt', 20 | } 21 | 22 | -------------------------------------------------------------------------------- /libs/criteria/PTI/hyperparameters.py: -------------------------------------------------------------------------------- 1 | ## Architechture 2 | lpips_type = 'alex' 3 | first_inv_type = 'w' 4 | optim_type = 'adam' 5 | 6 | ## Locality regularization 7 | latent_ball_num_of_samples = 1 8 | locality_regularization_interval = 1 9 | use_locality_regularization = False 10 | regulizer_l2_lambda = 0.1 11 | regulizer_lpips_lambda = 0.1 12 | regulizer_alpha = 10 13 | 14 | ## Loss 15 | pt_l2_lambda = 1 16 | pt_lpips_lambda = 1 17 | 18 | ## Steps 19 | LPIPS_value_threshold = 0.06 20 | max_pti_steps = 350 21 | first_inv_steps = 450 22 | max_images_to_invert = 30 23 | 24 | ## Optimization 25 | pti_learning_rate = 3e-4 26 | first_inv_lr = 5e-3 27 | train_batch_size = 1 28 | use_last_w_pivots = False 29 | 30 | 31 | ############ NEW ############### 32 | # ## Optimization 33 | # max_pti_steps = 50 34 | # regulizer_alpha = 30 35 | # pti_learning_rate = 3e-4 36 | # first_inv_lr = 5e-3 37 | # stitching_tuning_lr = 3e-4 38 | # pti_adam_beta1 = 0.9 39 | # lr_rampdown_length = 0.25 40 | # lr_rampup_length = 0.05 41 | # use_lr_ramp = False 42 | -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | //#include 2 | 3 | #pragma warning(push, 0) 4 | #include 5 | #pragma warning(pop) 6 | 7 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 15 | int act, int grad, float alpha, float scale) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(bias); 18 | 19 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 20 | } 21 | 22 | 
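// Exposes the fused_bias_act wrapper to Python via pybind11; fused_act.py JIT-compiles and loads this extension with torch.utils.cpp_extension.load.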
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 24 | } -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #pragma warning(push, 0) 2 | #include 3 | #pragma warning(pop) 4 | 5 | 6 | 7 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 8 | int up_x, int up_y, int down_x, int down_y, 9 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 16 | int up_x, int up_y, int down_x, int down_y, 17 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(kernel); 20 | 21 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 26 | } -------------------------------------------------------------------------------- /libs/criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | # print(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | # if torch.isnan(x).any(): 9 | # # print(gradients_keep) 10 | # pdb.set_trace() 11 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)+1e-9) 12 | return x / (norm_factor + eps) 13 | 14 | 15 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 16 | # build url 17 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 18 | + f'master/lpips/weights/v{version}/{net_type}.pth' 19 | 20 | # download 21 | old_state_dict = torch.hub.load_state_dict_from_url( 22 | url, progress=True, 23 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 24 | ) 25 | 26 | # rename keys 27 | new_state_dict = OrderedDict() 28 | for key, val in old_state_dict.items(): 29 | new_key = key 30 | new_key = new_key.replace('lin', '') 31 | new_key = new_key.replace('model.', '') 32 | new_state_dict[new_key] = val 33 | 34 | return new_state_dict 35 | -------------------------------------------------------------------------------- /libs/criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures https://github.com/eladrich/pixel2style2pixel 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 
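Note: intermediate activations are L2-normalized over the channel dimension (utils.normalize_activation) before the pretrained 1x1 linear layers weight the squared feature differences.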
15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | # pretrained network 22 | self.net = get_network(net_type).to("cuda") 23 | 24 | # linear layers 25 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 26 | self.lin.load_state_dict(get_state_dict(net_type, version)) 27 | 28 | def forward(self, x: torch.Tensor, y: torch.Tensor): 29 | feat_x, feat_y = self.net(x), self.net(y) 30 | 31 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 32 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 33 | 34 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 35 | -------------------------------------------------------------------------------- /libs/criteria/PTI/base_coach.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import pickle 4 | from argparse import Namespace 5 | import wandb 6 | import os.path 7 | import sys 8 | 9 | sys.path.append( '.' ) 10 | sys.path.append( '..' ) 11 | 12 | import torch 13 | from torchvision import transforms 14 | from libs.criteria.PTI.localitly_regulizer import Space_Regulizer 15 | from libs.criteria.PTI import hyperparameters 16 | 17 | l2_criterion = torch.nn.MSELoss(reduction='mean') 18 | 19 | 20 | def l2_loss(real_images, generated_images): 21 | loss = l2_criterion(real_images, generated_images) 22 | return loss 23 | 24 | def calc_loss(generated_images, real_images, new_G, use_ball_holder, w_batch, space_regulizer, lpips_loss, pt_l2_lambda): 25 | loss = 0.0 26 | 27 | # if hyperparameters.pt_l2_lambda > 0: 28 | if pt_l2_lambda > 0: 29 | l2_loss_val = l2_loss(generated_images, real_images) 30 | loss += l2_loss_val * pt_l2_lambda 31 | if hyperparameters.pt_lpips_lambda > 0: 32 | loss_lpips = lpips_loss(generated_images, real_images) 33 | loss_lpips = torch.squeeze(loss_lpips) 34 | 35 | loss += loss_lpips * hyperparameters.pt_lpips_lambda 36 | 37 | if use_ball_holder: 38 | ball_holder_loss_val = space_regulizer.space_regulizer_loss(new_G, w_batch, use_wandb=False) 39 | loss += ball_holder_loss_val 40 | 41 | return loss, l2_loss_val, loss_lpips, ball_holder_loss_val 42 | else: 43 | return loss, l2_loss_val, loss_lpips 44 | -------------------------------------------------------------------------------- /libs/criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .model_irse import Backbone 4 | import os 5 | import torch.backends.cudnn as cudnn 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self, pretrained_model_path = './pretrained_models/model_ir_se50.pth'): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace for identity loss') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | if not os.path.exists(pretrained_model_path): 13 | print('ir_se50 model does not exist in {}'.format(pretrained_model_path)) 14 | exit() 15 | self.facenet.load_state_dict(torch.load(pretrained_model_path)) 16 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 17 | self.facenet.eval() 18 | self.criterion = nn.CosineSimilarity(dim=1, eps=1e-6) 19 | 20 | def extract_feats(self, x, crop = True): 21 | if crop: 22 | x = x[:, :, 35:223, 32:220] # Crop interesting region 23 | x = self.face_pool(x) 24 | x_feats = self.facenet(x) 25 | return x_feats 26 | 27 | def forward(self, y_hat, y, crop = True): 28 | 
n_samples = y.shape[0] 29 | y_feats = self.extract_feats(y, crop) 30 | y_hat_feats = self.extract_feats(y_hat, crop) 31 | cosine_sim = self.criterion(y_hat_feats, y_feats.detach()) 32 | loss = 1 - cosine_sim 33 | loss = torch.mean(loss) 34 | return loss 35 | -------------------------------------------------------------------------------- /libs/DECA/decalib/models/encoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | import torch 19 | import torch.nn.functional as F 20 | from . import resnet 21 | 22 | class ResnetEncoder(nn.Module): 23 | def __init__(self, outsize, last_op=None): 24 | super(ResnetEncoder, self).__init__() 25 | feature_size = 2048 26 | self.encoder = resnet.load_ResNet50Model() #out: 2048 27 | ### regressor 28 | self.layers = nn.Sequential( 29 | nn.Linear(feature_size, 1024), 30 | nn.ReLU(), 31 | nn.Linear(1024, outsize) 32 | ) 33 | self.last_op = last_op 34 | 35 | def forward(self, inputs): 36 | features = self.encoder(inputs) 37 | parameters = self.layers(features) 38 | if self.last_op: 39 | parameters = self.last_op(parameters) 40 | return parameters 41 | -------------------------------------------------------------------------------- /libs/configs/config_arguments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | # Additional arguments for training 5 | 6 | arguments = { 7 | 'shift_scale': 6.0, # set the maximum shift scale 8 | 'min_shift': 0.1, # set the minimum shift 9 | 'learned_directions': 15, # set the number of directions to learn 10 | 'num_layers_shift': 8, # set number of layers to add the shift 11 | 'w_plus': True, # set w_plus True to find directions in the W+ space 12 | 'disentanglement_50': True, # set True to train half images on the batch changing only one direction 13 | 14 | 15 | 'lambda_identity': 10.0, # identity loss weight 16 | 'lambda_perceptual': 10.0, # perceptual loss weight 17 | 'lambda_pixel_wise': 1.0, # pixel wise loss weight, only on paired data 18 | 'lambda_shape': 1.0, # shape loss weight 19 | 'lambda_mouth_shape': 1.0, # mouth shape loss weight 20 | 'lambda_eye_shape': 1.0, # eye shape loss weight 21 | 'lambda_w_reg': 0.0, # w regularizer 22 | 23 | 'steps_per_log': 10, # set number iterations per log 24 | 'steps_per_save': 1000, # set number iterations per saving model 25 | 'steps_per_ev_log': 1000, # set number iterations per evaluation 26 | 'validation_samples': 100, # number of samples for evaluation 27 | 28 | 'reenactment_fig': True, # generate reenactment figure during evaluation 29 | 'num_pairs_log': 4, # how many pairs on the reenactment figure 30 | 'gif': False, # generate gif with directions durion evaluation 31 | 'evaluation': 
True, # evaluate model during training 32 | 33 | } -------------------------------------------------------------------------------- /libs/face_models/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | import sys 5 | import matplotlib.pyplot as plt 6 | from .core import FaceDetector 7 | 8 | from .net_s3fd import s3fd 9 | from .bbox import * 10 | from .detect import * 11 | import torch.backends.cudnn as cudnn 12 | 13 | 14 | models_urls = { 15 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 16 | } 17 | 18 | 19 | class SFDDetector(FaceDetector): 20 | def __init__(self, device, path_to_detector=None, verbose=False): 21 | super(SFDDetector, self).__init__(device, verbose) 22 | 23 | self.device = device 24 | model_weights = torch.load(path_to_detector) 25 | 26 | self.face_detector = s3fd() 27 | self.face_detector.load_state_dict(model_weights) 28 | self.face_detector.to(self.device) 29 | self.face_detector.eval() 30 | 31 | def detect_from_batch(self, tensor): 32 | 33 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device) 34 | 35 | new_bboxlists = [] 36 | for i in range(bboxlists.shape[0]): 37 | bboxlist = bboxlists[i] 38 | keep = nms(bboxlist, 0.3) 39 | # print(keep) 40 | if len(keep)>0: 41 | bboxlist = bboxlist[keep, :] 42 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 43 | new_bboxlists.append(bboxlist) 44 | else: 45 | new_bboxlists.append([]) 46 | 47 | return new_bboxlists 48 | 49 | @property 50 | def reference_scale(self): 51 | return 195 52 | 53 | @property 54 | def reference_x_shift(self): 55 | return 0 56 | 57 | @property 58 | def reference_y_shift(self): 59 | return 0 60 | -------------------------------------------------------------------------------- /libs/utilities/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import json 5 | from datetime import datetime 6 | import glob 7 | 8 | def get_image_files(path): 9 | types = ('*.png', '*.jpg') # the tuple of file types 10 | files_grabbed = [] 11 | for files in types: 12 | files_grabbed.extend(glob.glob(os.path.join(path, files))) 13 | files_grabbed.sort() 14 | return files_grabbed 15 | 16 | def get_files_frompath(path, types): 17 | files_grabbed = [] 18 | for files in types: 19 | files_grabbed.extend(glob.glob(os.path.join(path, files))) 20 | files_grabbed.sort() 21 | return files_grabbed 22 | 23 | def make_path(path): 24 | if not os.path.exists(path): 25 | os.makedirs(path, exist_ok = True) 26 | 27 | def save_arguments_json(args, save_path, filename): 28 | out_json = os.path.join(save_path, filename) 29 | # datetime object containing current date and time 30 | now = datetime.now() 31 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S") 32 | with open(out_json, 'w') as out: 33 | stat_dict = args 34 | json.dump(stat_dict, out) 35 | 36 | def read_arguments_json(filename): 37 | with open(filename) as json_file: 38 | data = json.load(json_file) 39 | arguments_dict = data 40 | 41 | return arguments_dict 42 | 43 | def delete_files(file_list): 44 | """Delete files with filenames in given list. 
45 | Args: 46 | file_list (list): list of filenames to be deleted 47 | """ 48 | for file in file_list: 49 | try: 50 | os.remove(file) 51 | except OSError: 52 | pass 53 | 54 | def make_noise(batch, dim, truncation=None): 55 | if isinstance(dim, int): 56 | dim = [dim] 57 | if truncation is None or truncation == 1.0: 58 | return torch.randn([batch] + dim) 59 | else: 60 | return torch.from_numpy(truncated_noise([batch] + dim, truncation)).to(torch.float) 61 | 62 | def one_hot(dims, value, indx): 63 | vec = torch.zeros(dims) 64 | vec[indx] = value 65 | return vec 66 | -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/detectors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch 18 | # from libs.pose_estimation.fan_model.models import FAN, ResNetDepth 19 | # from libs.pose_estimation.fan_model.utils import * 20 | from enum import Enum 21 | # from libs.pose_estimation.sfd.sfd_detector import SFDDetector as FaceDetector 22 | 23 | class FAN(object): 24 | def __init__(self): 25 | import face_alignment 26 | self.model = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) 27 | 28 | def run(self, image): 29 | ''' 30 | image: 0-255, uint8, rgb, [h, w, 3] 31 | return: detected box list 32 | ''' 33 | 34 | out = self.model.get_landmarks(image) 35 | if out is None: 36 | return [0], 'error' 37 | else: 38 | kpt = out[0].squeeze() 39 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 40 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 41 | bbox = [left,top, right, bottom] 42 | return bbox, 'kpt68' 43 | 44 | 45 | class MTCNN(object): 46 | def __init__(self, device = 'cpu'): 47 | ''' 48 | https://github.com/timesler/facenet-pytorch/blob/master/examples/infer.ipynb 49 | ''' 50 | from facenet_pytorch import MTCNN as mtcnn 51 | self.device = device 52 | self.model = mtcnn(keep_all=True) 53 | def run(self, input): 54 | ''' 55 | image: 0-255, uint8, rgb, [h, w, 3] 56 | return: detected box 57 | ''' 58 | out = self.model.detect(input[None,...]) 59 | if out[0][0] is None: 60 | return [0] 61 | else: 62 | bbox = out[0][0].squeeze() 63 | return bbox, 'bbox' 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /libs/models/direction_matrix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | class DirectionMatrix(nn.Module): 7 | def __init__(self, shift_dim, input_dim=None, out_dim=None, inner_dim=512, 8 | bias=True, w_plus = False, num_layers = 14, initialization = 'normal'): 9 | super(DirectionMatrix, self).__init__() 10 | 
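# Direction matrix A: a single linear layer that maps a one-hot shift vector (the chosen direction index holding the shift magnitude) to a shift in the latent space; with w_plus=True the output is reshaped to a separate shift per generator layer.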
self.shift_dim = shift_dim 11 | self.input_dim = input_dim if input_dim is not None else np.product(shift_dim) 12 | self.out_dim = out_dim if out_dim is not None else np.product(shift_dim) 13 | self.w_plus = w_plus 14 | self.num_layers = num_layers 15 | 16 | if self.w_plus: 17 | print("Linear Direction matrix-A in w+ space: input dimension {}, output dimension {}, shift dimension {} ".format(self.input_dim, 18 | self.out_dim, self.shift_dim)) 19 | else: 20 | print("Linear Direction matrix-A type : input dimension {}, output dimension {}, shift dimension {} ".format(self.input_dim, 21 | self.out_dim, self.shift_dim)) 22 | 23 | if self.w_plus: 24 | out_dim = self.out_dim * num_layers 25 | else: 26 | out_dim = self.out_dim 27 | 28 | self.linear = nn.Linear(self.input_dim, out_dim, bias=bias) 29 | self.linear.weight.data = torch.zeros_like(self.linear.weight.data) 30 | 31 | if initialization == 'normal': 32 | torch.nn.init.normal_(self.linear.weight, mean=0.0, std=0.03) 33 | if not self.w_plus and initialization == 'eye': 34 | min_dim = int(min(self.input_dim, out_dim)) 35 | self.linear.weight.data[:min_dim, :min_dim] = torch.eye(min_dim) 36 | if self.w_plus and initialization == 'eye': 37 | min_dim = int(min(self.input_dim, out_dim)) 38 | for layer_cnt in range(num_layers): 39 | self.linear.weight.data[layer_cnt*self.out_dim:(layer_cnt*self.out_dim + min_dim), :min_dim] = torch.eye(min_dim) 40 | 41 | def forward(self, input): 42 | input = input.view([-1, self.input_dim]) 43 | 44 | out = self.linear(input) 45 | if self.w_plus: 46 | out = out.view(len(input), self.num_layers, self.shift_dim) 47 | 48 | return out 49 | -------------------------------------------------------------------------------- /libs/criteria/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | """ 5 | Calculate shape losses 6 | """ 7 | 8 | class Losses(): 9 | def __init__(self): 10 | self.criterion_mse = torch.nn.MSELoss() 11 | self.criterion_l1 = torch.nn.L1Loss() 12 | self.image_deca_size = 224 13 | 14 | def calculate_pixel_wise_loss(self, images_shifted, images): 15 | 16 | pixel_wise_loss = self.criterion_l1(images, images_shifted) 17 | 18 | return pixel_wise_loss 19 | 20 | def calculate_shape_loss(self, shape_gt, shape_reenacted, normalize = False): 21 | criterion_l1 = torch.nn.L1Loss() 22 | if normalize: 23 | shape_gt_norm = shape_gt/200 #self.image_deca_size 24 | shape_reenacted_norm = shape_reenacted/200 #self.image_deca_size 25 | loss = criterion_l1(shape_gt_norm, shape_reenacted_norm) 26 | else: 27 | loss = criterion_l1(shape_gt, shape_reenacted) 28 | return loss 29 | 30 | def calculate_eye_loss(self, shape_gt, shape_reenacted): 31 | criterion_l1 = torch.nn.L1Loss() 32 | shape_gt_norm = shape_gt.clone() 33 | shape_reenacted_norm = shape_reenacted.clone() 34 | # shape_gt_norm = shape_gt_norm/self.image_deca_size 35 | # shape_reenacted_norm = shape_reenacted_norm/self.image_deca_size 36 | eye_pairs = [(36, 39), (37, 41), (38, 40), (42, 45), (43, 47), (44, 46)] 37 | loss = 0 38 | for i in range(len(eye_pairs)): 39 | pair = eye_pairs[i] 40 | d_gt = abs(shape_gt[:, pair[0],:] - shape_gt[:, pair[1],:]) 41 | d_e = abs(shape_reenacted[:, pair[0],:] - shape_reenacted[:, pair[1],:]) 42 | loss += criterion_l1(d_gt, d_e) 43 | 44 | loss = loss/len(eye_pairs) 45 | return loss 46 | 47 | def calculate_mouth_loss(self, shape_gt, shape_reenacted): 48 | criterion_l1 = torch.nn.L1Loss() 49 | shape_gt_norm = shape_gt.clone() 50 | shape_reenacted_norm 
= shape_reenacted.clone() 51 | # shape_gt_norm = shape_gt_norm/self.image_deca_size 52 | # shape_reenacted_norm = shape_reenacted_norm/self.image_deca_size 53 | mouth_pairs = [(48, 54), (49, 59), (50, 58), (51, 57), (52, 56), (53, 55), (60, 64), (61, 67), (62, 66), (63, 65)] 54 | loss = 0 55 | for i in range(len(mouth_pairs)): 56 | pair = mouth_pairs[i] 57 | d_gt = abs(shape_gt[:, pair[0],:] - shape_gt[:, pair[1],:]) 58 | d_e = abs(shape_reenacted[:, pair[0],:] - shape_reenacted[:, pair[1],:]) 59 | loss += criterion_l1(d_gt, d_e) 60 | 61 | loss = loss/len(mouth_pairs) 62 | return loss 63 | 64 | -------------------------------------------------------------------------------- /libs/face_models/ffhq_cropping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Crop images using facial landmarks 3 | """ 4 | import numpy as np 5 | import cv2 6 | import os 7 | import collections 8 | import PIL.Image 9 | import PIL.ImageFile 10 | from PIL import Image 11 | import scipy.ndimage 12 | 13 | def pad_img_to_fit_bbox(img, x1, x2, y1, y2, crop_box): 14 | img_or = img.copy() 15 | img = cv2.copyMakeBorder(img, 16 | -min(0, y1), max(y2 - img.shape[0], 0), 17 | -min(0, x1), max(x2 - img.shape[1], 0), cv2.BORDER_REFLECT) 18 | 19 | y2 += -min(0, y1) 20 | y1 += -min(0, y1) 21 | x2 += -min(0, x1) 22 | x1 += -min(0, x1) 23 | 24 | pad = crop_box 25 | pad = (max(-pad[0], 0), max(-pad[1], 0), max(pad[2] - img_or.shape[1] , 0), max(pad[3] - img_or.shape[0] , 0)) 26 | 27 | h, w, _ = img.shape 28 | y, x, _ = np.ogrid[:h, :w, :1] 29 | pad = np.array(pad, dtype=np.float32) 30 | pad[pad == 0] = 1e-10 31 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 32 | img = np.array(img, dtype=np.float32) 33 | blur = 5.0 34 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 35 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 36 | 37 | return img, x1, x2, y1, y2 38 | 39 | def crop_from_bbox(img, bbox): 40 | """ 41 | bbox: tuple, (x1, y1, x2, y2) 42 | x: horizontal, y: vertical, exclusive 43 | """ 44 | x1, y1, x2, y2 = bbox 45 | if x1 < 0 or y1 < 0 or x2 > img.shape[1] or y2 > img.shape[0]: 46 | img, x1, x2, y1, y2 = pad_img_to_fit_bbox(img, x1, x2, y1, y2, bbox) 47 | return img[y1:y2, x1:x2] 48 | 49 | def crop_using_landmarks(image, landmarks): 50 | image_size = 256 51 | center = ((landmarks.min(0) + landmarks.max(0)) / 2).round().astype(int) 52 | size = int(max(landmarks[:, 0].max() - landmarks[:, 0].min(), landmarks[:, 1].max() - landmarks[:, 1].min())) 53 | try: 54 | center[1] -= size // 6 55 | except: 56 | return None 57 | 58 | # Crop images and poses 59 | h, w, _ = image.shape 60 | img = Image.fromarray(image) 61 | crop_box = (center[0]-size, center[1]-size, center[0]+size, center[1]+size) 62 | image = crop_from_bbox(image, crop_box) 63 | try: 64 | img = Image.fromarray(image.astype(np.uint8)) 65 | img = img.resize((image_size, image_size), Image.BICUBIC) 66 | pix = np.array(img) 67 | return pix 68 | except: 69 | return None 70 | -------------------------------------------------------------------------------- /libs/DECA/estimate_DECA.py: -------------------------------------------------------------------------------- 1 | """ 2 | The pipeline of 3D shape model prediction 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import cv2 8 | import os 9 | 10 | from .decalib.deca import DECA 11 | from 
.decalib.datasets import datasets 12 | from .decalib.utils import util 13 | from .decalib.utils.config import cfg as deca_cfg 14 | from .decalib.utils.rotation_converter import * 15 | 16 | 17 | class DECA_model(): 18 | 19 | def __init__(self, device): 20 | deca_cfg.model.use_tex = False 21 | dir_path = os.path.dirname(os.path.realpath(__file__)) 22 | models_path = os.path.join(dir_path, 'data') 23 | if not os.path.exists(models_path): 24 | print('Please download the required data for DECA model. See Readme.') 25 | exit() 26 | self.deca = DECA(config = deca_cfg, device=device) 27 | self.data = datasets.TestData() 28 | 29 | 'Batch torch tensor' 30 | def extract_DECA_params(self, images): 31 | 32 | p_tensor = torch.zeros(images.shape[0], 6).cuda() 33 | alpha_shp_tensor = torch.zeros(images.shape[0], 100).cuda() 34 | alpha_exp_tensor = torch.zeros(images.shape[0], 50).cuda() 35 | angles = torch.zeros(images.shape[0], 3).cuda() 36 | cam = torch.zeros(images.shape[0], 3).cuda() 37 | for batch in range(images.shape[0]): 38 | image_prepro, error_flag = self.data.get_image_tensor(images[batch].clone()) 39 | if not error_flag: 40 | codedict = self.deca.encode(image_prepro.unsqueeze(0).cuda()) 41 | pose = codedict['pose'][:,:3] 42 | pose = rad2deg(batch_axis2euler(pose)) 43 | p_tensor[batch] = codedict['pose'][0] 44 | alpha_shp_tensor[batch] = codedict['shape'][0] 45 | alpha_exp_tensor[batch] = codedict['exp'][0] 46 | cam[batch] = codedict['cam'][0] 47 | angles[batch] = pose 48 | else: 49 | angles[batch][0] = -180 50 | angles[batch][1] = -180 51 | angles[batch][2] = -180 52 | 53 | return p_tensor, alpha_shp_tensor, alpha_exp_tensor, angles, cam 54 | 55 | def calculate_shape(self, coefficients, image = None, save_path = None, prefix = None): 56 | landmarks2d, landmarks3d, points = self.deca.decode(coefficients) 57 | return landmarks2d, landmarks3d, points 58 | 59 | -------------------------------------------------------------------------------- /libs/DECA/decalib/models/decoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | class Generator(nn.Module): 20 | def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode = 'bilinear'): 21 | super(Generator, self).__init__() 22 | self.out_scale = out_scale 23 | 24 | self.init_size = 32 // 4 # Initial size before upsampling 25 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) 26 | self.conv_blocks = nn.Sequential( 27 | nn.BatchNorm2d(128), 28 | nn.Upsample(scale_factor=2, mode=sample_mode), #16 29 | nn.Conv2d(128, 128, 3, stride=1, padding=1), 30 | nn.BatchNorm2d(128, 0.8), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Upsample(scale_factor=2, mode=sample_mode), #32 33 | nn.Conv2d(128, 64, 3, stride=1, padding=1), 34 | nn.BatchNorm2d(64, 0.8), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | nn.Upsample(scale_factor=2, mode=sample_mode), #64 37 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 38 | nn.BatchNorm2d(64, 0.8), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | nn.Upsample(scale_factor=2, mode=sample_mode), #128 41 | nn.Conv2d(64, 32, 3, stride=1, padding=1), 42 | nn.BatchNorm2d(32, 0.8), 43 | nn.LeakyReLU(0.2, inplace=True), 44 | nn.Upsample(scale_factor=2, mode=sample_mode), #256 45 | nn.Conv2d(32, 16, 3, stride=1, padding=1), 46 | nn.BatchNorm2d(16, 0.8), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Conv2d(16, out_channels, 3, stride=1, padding=1), 49 | nn.Tanh(), 50 | ) 51 | 52 | def forward(self, noise): 53 | out = self.l1(noise) 54 | out = out.view(out.shape[0], 128, self.init_size, self.init_size) 55 | img = self.conv_blocks(out) 56 | return img*self.out_scale -------------------------------------------------------------------------------- /libs/criteria/PTI/localitly_regulizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from libs.criteria import l2_loss 5 | from libs.criteria.PTI import hyperparameters 6 | from libs.criteria.PTI import global_config 7 | 8 | 9 | class Space_Regulizer: 10 | def __init__(self, original_G, lpips_net): 11 | self.original_G = original_G 12 | self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha 13 | self.lpips_loss = lpips_net 14 | 15 | def get_morphed_w_code(self, new_w_code, fixed_w): 16 | interpolation_direction = new_w_code - fixed_w 17 | interpolation_direction_norm = torch.norm(interpolation_direction, p=2) 18 | direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm 19 | result_w = fixed_w + direction_to_move 20 | self.morphing_regulizer_alpha * fixed_w + (1 - self.morphing_regulizer_alpha) * new_w_code 21 | 22 | return result_w 23 | 24 | def get_image_from_ws(self, w_codes, G): 25 | return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes]) 26 | 27 | def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch): 28 | loss = 0.0 29 | from torchvision import utils as torch_utils 30 | trunc = new_G.mean_latent(4096).detach().clone() 31 | z_samples = np.random.randn(num_of_sampled_latents, 512) 32 | 33 | w_samples = self.original_G.get_latent(torch.from_numpy(z_samples).float().to(global_config.device)) 34 | territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples] 35 | for w_code in territory_indicator_ws: 36 | 37 | new_img, _ = 
new_G([w_code], return_latents = False, truncation = 0.7, truncation_latent = trunc, input_is_latent = True) 38 | with torch.no_grad(): 39 | old_img, _ = self.original_G([w_code], return_latents = False, truncation = 0.7, truncation_latent = trunc, input_is_latent = True) 40 | 41 | if hyperparameters.regulizer_l2_lambda > 0: 42 | l2_loss_val = l2_loss.l2_loss(old_img, new_img) 43 | loss += l2_loss_val * hyperparameters.regulizer_l2_lambda 44 | 45 | if hyperparameters.regulizer_lpips_lambda > 0: 46 | loss_lpips = self.lpips_loss(old_img, new_img) 47 | loss_lpips = torch.mean(torch.squeeze(loss_lpips)) 48 | loss += loss_lpips * hyperparameters.regulizer_lpips_lambda 49 | 50 | return loss / len(territory_indicator_ws) 51 | 52 | def space_regulizer_loss(self, new_G, w_batch): 53 | ret_val = self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch) 54 | return ret_val 55 | -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | 
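# Usage sketch (illustrative, not part of the original file): for a feature map x of shape (N, C, H, W),
# act = FusedLeakyReLU(channel=C); y = act(x) adds a learned per-channel bias, applies a LeakyReLU with
# negative_slope=0.2 and multiplies by scale=sqrt(2), all in one fused CUDA kernel.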
-------------------------------------------------------------------------------- /libs/optimization.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import glob 5 | import sys 6 | from tqdm import tqdm 7 | sys.path.append(".") 8 | sys.path.append("..") 9 | 10 | from libs.criteria.PTI.localitly_regulizer import Space_Regulizer 11 | from libs.criteria.PTI.base_coach import calc_loss 12 | from libs.criteria.PTI import hyperparameters 13 | from libs.criteria import id_loss 14 | from libs.criteria import l2_loss 15 | from libs.criteria.lpips.lpips import LPIPS 16 | 17 | 18 | 19 | def print_losses(metrics_dict, step): 20 | out_text = '[step {}]'.format(step) 21 | for key, value in metrics_dict.items(): 22 | out_text += (' | {}: {:.5f}'.format(key, value)) 23 | print(out_text) 24 | 25 | def optimize_g(generator, latent, real_imgs, opt_steps = 200, lr = 3e-3, use_ball_holder = False, optimize_all = False): 26 | trunc = generator.mean_latent(4096).detach().clone() 27 | truncation = 0.7 28 | generator_copy = copy.deepcopy(generator) 29 | generator.train() 30 | 31 | if not optimize_all: # Optimize some of the generator's parameters 32 | parameters = list(generator.convs[11].parameters()) + list(generator.convs[10].parameters()) + list(generator.convs[9].parameters()) \ 33 | + list(generator.convs[8].parameters()) + list(generator.convs[7].parameters()) + list(generator.convs[6].parameters()) \ 34 | + list(generator.convs[5].parameters()) + list(generator.convs[4].parameters()) 35 | optimizer = torch.optim.Adam(parameters, lr=lr) 36 | pt_l2_lambda = 100 37 | else: 38 | optimizer = torch.optim.Adam(generator.parameters(), lr=lr) 39 | pt_l2_lambda = 1 40 | 41 | print('********** Start optimization for {} steps **********'.format(opt_steps)) 42 | 43 | lpips_loss_fun = LPIPS(net_type='alex').cuda().eval() 44 | space_regulizer = Space_Regulizer(generator_copy, lpips_loss_fun) 45 | 46 | # for step in range(opt_steps): 47 | for step in tqdm(range(opt_steps)): 48 | 49 | loss_dict = {} 50 | imgs_gen, _ = generator([latent], input_is_latent=True, return_latents = False, truncation=truncation, truncation_latent=trunc) 51 | 52 | if use_ball_holder: 53 | loss, l2_loss_val, lpips_loss, ball_holder_loss_val = calc_loss(imgs_gen, real_imgs, 54 | generator, use_ball_holder, latent, space_regulizer, lpips_loss_fun, pt_l2_lambda) 55 | loss_dict["ball_holder_loss_val"] = ball_holder_loss_val.item() 56 | else: 57 | loss, l2_loss_val, lpips_loss = calc_loss(imgs_gen, real_imgs, 58 | generator, use_ball_holder, latent, space_regulizer, lpips_loss_fun, pt_l2_lambda) 59 | 60 | loss_dict['loss'] = loss.item() 61 | loss_dict["lpips_loss"] = lpips_loss.item() 62 | loss_dict['l2_loss'] = l2_loss_val.item() 63 | 64 | # print_losses(loss_dict, step) 65 | 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | 70 | print('********** End optimization **********') 71 | 72 | return generator 73 | -------------------------------------------------------------------------------- /libs/face_models/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | from .net_s3fd import s3fd 16 | from .bbox import * 17 | import matplotlib.pyplot 
as plt 18 | 19 | 20 | def detect(net, img, device): 21 | img = img - np.array([104, 117, 123]) 22 | img = img.transpose(2, 0, 1) 23 | # Creates a batch of 1 24 | img = img.reshape((1,) + img.shape) 25 | 26 | 27 | if torch.cuda.current_device() == 0: 28 | torch.backends.cudnn.benchmark = True 29 | 30 | img = torch.from_numpy(img).float().to(device) 31 | 32 | 33 | return batch_detect(net, img, device) 34 | 35 | 36 | def batch_detect(net, img_batch, device): 37 | """ 38 | Inputs: 39 | - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width) 40 | """ 41 | 42 | BB, CC, HH, WW = img_batch.size() 43 | 44 | with torch.no_grad(): 45 | olist = net(img_batch.float()) # patched uint8_t overflow error 46 | 47 | 48 | for i in range(len(olist) // 2): 49 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 50 | 51 | bboxlists = [] 52 | 53 | olist = [oelem.data.cpu() for oelem in olist] 54 | for j in range(BB): 55 | bboxlist = [] 56 | for i in range(len(olist) // 2): 57 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 58 | FB, FC, FH, FW = ocls.size() # feature map size 59 | stride = 2**(i + 2) # 4,8,16,32,64,128 60 | anchor = stride * 4 61 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 62 | 63 | for Iindex, hindex, windex in poss: 64 | 65 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 66 | score = ocls[j, 1, hindex, windex] 67 | loc = oreg[j, :, hindex, windex].contiguous().view(1, 4) 68 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 69 | variances = [0.1, 0.2] 70 | box = decode(loc, priors, variances) 71 | x1, y1, x2, y2 = box[0] * 1.0 72 | bboxlist.append([x1, y1, x2, y2, score]) 73 | bboxlists.append(bboxlist) 74 | 75 | bboxlists = np.array(bboxlists) 76 | 77 | if 0 == len(bboxlists): 78 | bboxlists = np.zeros((1, 1, 5)) 79 | 80 | 81 | return bboxlists 82 | 83 | 84 | def flip_detect(net, img, device): 85 | img = cv2.flip(img, 1) 86 | b = detect(net, img, device) 87 | 88 | bboxlist = np.zeros(b.shape) 89 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 90 | bboxlist[:, 1] = b[:, 1] 91 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 92 | bboxlist[:, 3] = b[:, 3] 93 | bboxlist[:, 4] = b[:, 4] 94 | return bboxlist 95 | 96 | 97 | def pts_to_bb(pts): 98 | min_x, min_y = np.min(pts, axis=0) 99 | max_x, max_y = np.max(pts, axis=0) 100 | return np.array([min_x, min_y, max_x, max_y]) 101 | -------------------------------------------------------------------------------- /libs/configs/config_directions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | voxceleb_dict = { 6 | 7 | 'yaw_direction': 0, 8 | 'pitch_direction': 1, 9 | 'roll_direction': 2, 10 | 'jaw_direction': 3, 11 | 'yaw_scale': 40, 12 | 'pitch_scale': 20, 13 | 'roll_scale': 20, 14 | 'ranges_filepath': './libs/configs/ranges_voxceleb.npy' 15 | } 16 | 17 | 18 | ffhq_dict = { 19 | 20 | 'yaw_direction': 0, 21 | 'pitch_direction': 1, 22 | 'roll_direction': -1, 23 | 'jaw_direction': 3, 24 | 'yaw_scale': 40, 25 | 'pitch_scale': 20, 26 | 'roll_scale': 20, 27 | 'ranges_filepath': './libs/configs/ranges_FFHQ.npy' 28 | } 29 | 30 | def get_direction_ranges(range_filepath): 31 | 32 | if os.path.exists(range_filepath): 33 | exp_ranges = np.load(range_filepath) 34 | exp_ranges = np.asarray(exp_ranges).astype('float64') 35 | else: 36 | print('{} does not exists'.format(range_filepath)) 37 | exit() 38 | 39 | return exp_ranges 40 | 41 | " For inference -> generate interpolation gifs" 42 | def 
get_direction_info(direction_index, angle_directions, a_jaw, b_jaw, directions_exp, shift_scale, angle_scales, 43 | count_pose, shifts_count, params_source, angles_source): 44 | if direction_index == angle_directions[0]: 45 | type_direction = 'yaw'; pose = True 46 | angle_scale = angle_scales[0] 47 | source_angle = angles_source[:,0][0].detach().cpu().numpy() 48 | max_angle = 30; min_angle = -30 49 | elif direction_index == angle_directions[1]: 50 | type_direction = 'pitch'; pose = True 51 | angle_scale = angle_scales[1] 52 | source_angle = angles_source[:,1][0].detach().cpu().numpy() 53 | max_angle = 15; min_angle = -15 54 | elif direction_index == angle_directions[2]: 55 | type_direction = 'roll'; pose = True 56 | angle_scale = angle_scales[2] 57 | source_angle = angles_source[:,2][0].detach().cpu().numpy() 58 | max_angle = 15; min_angle = -15 59 | else: 60 | if direction_index == count_pose - 1: 61 | type_direction = 'jaw'; pose = False 62 | jaw_exp_source = params_source['pose'][0, 3] 63 | start_pose = a_jaw * jaw_exp_source + b_jaw 64 | start_pose = start_pose.detach().cpu().numpy() 65 | else: 66 | index = next((index for (index, d) in enumerate(directions_exp) if d['A_direction'] == direction_index), None) 67 | if index is not None: 68 | ind_exp = directions_exp[index]['exp_component'] 69 | type_direction = 'exp_{:02d}'.format(ind_exp); pose = False 70 | exp_target = params_source['alpha_exp'][0][ind_exp] 71 | a = directions_exp[index]['a'] 72 | b = directions_exp[index]['b'] 73 | start_pose = a * exp_target + b 74 | start_pose = start_pose.detach().cpu().numpy() 75 | 76 | if pose: 77 | start_pose = source_angle * shift_scale / angle_scale 78 | min_shift = (-shift_scale - start_pose) 79 | max_shift = (shift_scale - start_pose) + 1e-5 80 | else: 81 | min_shift = (-shift_scale - start_pose) 82 | max_shift = (shift_scale - start_pose) + 1e-5 83 | step = shift_scale / shifts_count 84 | 85 | return type_direction, start_pose, min_shift, max_shift, step -------------------------------------------------------------------------------- /libs/criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - 
self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Default config for DECA 3 | ''' 4 | from yacs.config import CfgNode as CN 5 | import argparse 6 | import yaml 7 | import os 8 | 9 | cfg = CN() 10 | 11 | abs_deca_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) 12 | cfg.deca_dir = abs_deca_dir 13 | cfg.device = 'cuda' 14 | cfg.device_id = '0' 15 | 16 | cfg.pretrained_modelpath = os.path.join(cfg.deca_dir, 'data', 'deca_model.tar') 17 | 18 | # ---------------------------------------------------------------------------- # 19 | # Options for Face model 20 | # ---------------------------------------------------------------------------- # 21 | cfg.model = CN() 22 | cfg.model.topology_path = os.path.join(cfg.deca_dir, 'data', 'head_template.obj') 23 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip 24 | cfg.model.dense_template_path = os.path.join(cfg.deca_dir, 'data', 'texture_data_256.npy') 25 | cfg.model.fixed_displacement_path = os.path.join(cfg.deca_dir, 'data', 'fixed_displacement_256.npy') 26 | cfg.model.flame_model_path = os.path.join(cfg.deca_dir, 'data', 'generic_model.pkl') 27 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.deca_dir, 'data', 'landmark_embedding.npy') 28 | cfg.model.face_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_mask.png') 29 | cfg.model.face_eye_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_eye_mask.png') 30 | cfg.model.mean_tex_path = os.path.join(cfg.deca_dir, 'data', 'mean_texture.jpg') 31 | cfg.model.tex_path = os.path.join(cfg.deca_dir, 'data', 'FLAME_albedo_from_BFM.npz') 32 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM 33 | cfg.model.uv_size = 256 34 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 35 | cfg.model.n_shape = 100 36 | cfg.model.n_tex = 50 37 | cfg.model.n_exp = 50 38 | cfg.model.n_cam = 3 39 | cfg.model.n_pose = 6 40 | cfg.model.n_light = 27 41 | cfg.model.use_tex = False 42 | cfg.model.jaw_type = 'aa' # default use axis angle, another option: euler 43 | 44 | ## details 45 | cfg.model.n_detail = 128 46 | cfg.model.max_z = 
0.01 47 | 48 | # ---------------------------------------------------------------------------- # 49 | # Options for Dataset 50 | # ---------------------------------------------------------------------------- # 51 | cfg.dataset = CN() 52 | cfg.dataset.batch_size = 24 53 | cfg.dataset.num_workers = 2 54 | cfg.dataset.image_size = 224 55 | 56 | def get_cfg_defaults(): 57 | """Get a yacs CfgNode object with default values for my_project.""" 58 | # Return a clone so that the defaults will not be altered 59 | # This is for the "local variable" use pattern 60 | return cfg.clone() 61 | 62 | def update_cfg(cfg, cfg_file): 63 | cfg.merge_from_file(cfg_file) 64 | return cfg.clone() 65 | 66 | def parse_args(): 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--cfg', type=str, help='cfg file path') 69 | 70 | args = parser.parse_args() 71 | print(args, end='\n\n') 72 | 73 | cfg = get_cfg_defaults() 74 | if args.cfg is not None: 75 | cfg_file = args.cfg 76 | cfg = update_cfg(cfg, args.cfg) 77 | cfg.cfg_file = cfg_file 78 | 79 | return cfg 80 | -------------------------------------------------------------------------------- /libs/utilities/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import io 5 | import os 6 | import cv2 7 | 8 | from libs.utilities.utils import one_hot 9 | from libs.utilities.generic import calculate_shapemodel, generate_image 10 | from libs.configs.config_directions import get_direction_info 11 | from libs.utilities.image_utils import tensor_to_image 12 | 13 | def get_shifted_image(G, A_matrix, source_code, shift, direction_index, truncation, trunc, w_plus, input_is_latent, num_layers_shift): 14 | shift_vector = one_hot(A_matrix.input_dim, shift, direction_index).cuda() 15 | latent_shift = A_matrix(shift_vector) 16 | shifted_image = generate_image( G, source_code, truncation, trunc, w_plus, shift_code=latent_shift, 17 | input_is_latent = input_is_latent, num_layers_shift = num_layers_shift) 18 | return shifted_image 19 | 20 | @torch.no_grad() 21 | def make_interpolation_chart(G, A_matrix, shape_model, z, directions_exp, angle_scales, angle_directions, info, max_directions = None): 22 | 23 | input_is_latent = info['input_is_latent'] 24 | learned_dims = info['learned_dims'] 25 | count_pose = info['count_pose'] 26 | a_jaw = info['a_jaw'] 27 | b_jaw = info['b_jaw'] 28 | shift_scale = info['shift_scale'] 29 | shifts_count = info['shifts_count'] 30 | w_plus = info['w_plus'] 31 | num_layers_shift = info['num_layers_shift'] 32 | 33 | grids = []; types = [] 34 | truncation = 0.7 35 | trunc = G.mean_latent(4096).detach().clone() 36 | 37 | if max_directions is not None: 38 | if learned_dims > max_directions: 39 | learned_dims = max_directions 40 | 41 | # Generate original image and calculate shape model 42 | original_img = generate_image( G, z, truncation, trunc, w_plus, input_is_latent = input_is_latent, 43 | num_layers_shift = num_layers_shift, return_latents = False) 44 | params_original, angles_original = calculate_shapemodel(shape_model, original_img) 45 | 46 | for direction_index in range(learned_dims): 47 | shifted_images = [] 48 | type_direction, start_pose, min_shift, max_shift, step = get_direction_info(direction_index, angle_directions, a_jaw, b_jaw, directions_exp, 49 | shift_scale, angle_scales, count_pose, shifts_count, params_original, angles_original) 50 | min_shift = (-shift_scale - start_pose) 51 | max_shift = (shift_scale - start_pose) + 
1e-5 52 | # min_shift --> start pose 53 | for shift in np.arange(min_shift, start_pose, step): 54 | shifted_image = get_shifted_image(G, A_matrix, z, shift, direction_index, truncation, trunc, w_plus, input_is_latent, num_layers_shift) 55 | shifted_images.append(shifted_image[0].detach().cpu()) 56 | 57 | # start pose -> max_shift 58 | for shift in np.arange(start_pose, max_shift, step): 59 | shifted_image = get_shifted_image(G, A_matrix, z, shift, direction_index, truncation, trunc, w_plus, input_is_latent, num_layers_shift) 60 | shifted_images.append(shifted_image[0].detach().cpu()) 61 | 62 | grids.append(shifted_images) 63 | types.append(type_direction) 64 | 65 | 66 | if len(grids) > 0: 67 | for i in range(len(grids)): 68 | for j in range(len(grids[i])): 69 | im = grids[i][j] 70 | im = tensor_to_image(im) 71 | grids[i][j] = im.astype(np.uint8) 72 | 73 | return grids, types -------------------------------------------------------------------------------- /libs/criteria/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 
76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 
5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import os, sys 17 | import torch 18 | from torch.utils.data import Dataset, DataLoader 19 | import torchvision.transforms as transforms 20 | import numpy as np 21 | import cv2 22 | import scipy 23 | from skimage.io import imread, imsave 24 | from skimage.transform import estimate_transform, warp, resize, rescale 25 | from glob import glob 26 | import scipy.io 27 | import torch 28 | import kornia 29 | 30 | from . import detectors 31 | 32 | class TestData(Dataset): 33 | def __init__(self, iscrop=True, crop_size=224, scale=1.25): 34 | ''' 35 | testpath: folder, imagepath_list, image path, video path 36 | ''' 37 | self.crop_size = crop_size 38 | self.scale = scale 39 | self.iscrop = iscrop 40 | self.resolution_inp = crop_size 41 | self.face_detector = detectors.FAN() # CHANGE 42 | 43 | 44 | def bbox2point(self, left, right, top, bottom, type='bbox'): 45 | ''' bbox from detector and landmarks are different 46 | ''' 47 | if type=='kpt68': 48 | old_size = (right - left + bottom - top)/2*1.1 49 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 50 | elif type=='bbox': 51 | old_size = (right - left + bottom - top)/2 52 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.12]) 53 | else: 54 | raise NotImplementedError 55 | return old_size, center 56 | 57 | def get_image_tensor(self, image): 58 | " image: tensor 3x256x256" 59 | img_tmp = image.clone() 60 | img_tmp = img_tmp.permute(1,2,0) 61 | bbox, bbox_type = self.face_detector.run(img_tmp) 62 | if bbox_type != 'error': 63 | if len(bbox) < 4: 64 | print('no face detected! run original image') 65 | left = 0; right = h-1; top=0; bottom=w-1 66 | else: 67 | left = bbox[0]; right=bbox[2] 68 | top = bbox[1]; bottom=bbox[3] 69 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type) 70 | size = int(old_size*self.scale) 71 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 72 | 73 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 74 | tform = estimate_transform('similarity', src_pts, DST_PTS) 75 | theta = torch.tensor(tform.params, dtype=torch.float32).unsqueeze(0).cuda() 76 | 77 | image_tensor = image.clone() 78 | image_tensor = image_tensor.unsqueeze(0) 79 | dst_image = kornia.warp_affine(image_tensor, theta[:,:2,:], dsize=(224, 224)) 80 | dst_image = dst_image.div(255.) 81 | 82 | return dst_image.squeeze(0), False 83 | 84 | else: 85 | 86 | return image, True -------------------------------------------------------------------------------- /extract_statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to extract the npy file with the min, max values of facial pose parameters (yaw, pitch, roll, jaw and expressions) 3 | 1. Generate a set of random synthetic images 4 | 2. 
Use DECA model to extract the facial shape and the corresponding parameters 5 | 3. Calculate min, max values 6 | """ 7 | 8 | import os 9 | import glob 10 | import numpy as np 11 | from PIL import Image 12 | import torch 13 | from torch.nn import functional as F 14 | import matplotlib.pyplot as plt 15 | import json 16 | import cv2 17 | from tqdm import tqdm 18 | import argparse 19 | from torchvision import utils as torch_utils 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | from libs.utilities.utils import make_noise, make_path 24 | from libs.DECA.estimate_DECA import DECA_model 25 | from libs.gan.StyleGAN2.model import Generator as StyleGAN2Generator 26 | from libs.utilities.generic import calculate_shapemodel, generate_image, save_image 27 | from libs.configs.config_models import * 28 | 29 | 30 | 31 | def extract_stats(statistics): 32 | 33 | num_stats = statistics.shape[1] 34 | statistics = np.transpose(statistics, (1, 0)) 35 | ranges = [] 36 | for i in range(statistics.shape[0]): 37 | pred = statistics[i, :] 38 | max_ = np.amax(pred) 39 | min_ = np.amin(pred) 40 | if i == 0: 41 | label = 'yaw' 42 | elif i == 1: 43 | label = 'pitch' 44 | elif i == 2: 45 | label = 'roll' 46 | elif i == 3: 47 | label = 'jaw' 48 | else: 49 | label = 'exp_{:02d}'.format(i) 50 | 51 | print('{}/{} Min {:.2f} Max {:.2f}'.format(label, i, min_, max_)) 52 | 53 | ranges.append([min_, max_]) 54 | 55 | return ranges 56 | 57 | 58 | if __name__ == '__main__': 59 | 60 | num_images = 2000 61 | 62 | image_resolution = 256 63 | dataset = 'voxceleb' 64 | 65 | output_path = './{}_stats'.format(dataset) 66 | make_path(output_path) 67 | 68 | gan_weights = stylegan2_voxceleb_256['gan_weights'] 69 | channel_multiplier = stylegan2_voxceleb_256['channel_multiplier'] 70 | 71 | print('----- Load generator from {} -----'.format(gan_weights)) 72 | truncation = 0.7 73 | generator = StyleGAN2Generator(image_resolution, 512, 8, channel_multiplier= channel_multiplier) 74 | generator.load_state_dict(torch.load(gan_weights)['g_ema'], strict = True) 75 | generator.cuda().eval() 76 | trunc = generator.mean_latent(4096).detach().clone() 77 | 78 | shape_model = DECA_model('cuda') 79 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 80 | 81 | statistics = [] 82 | with torch.no_grad(): 83 | for i in tqdm(range(num_images)): 84 | z = make_noise(1, 512).cuda() 85 | source_img = generator([z], return_latents = False, truncation = truncation, truncation_latent = trunc, input_is_latent = False)[0] 86 | source_img = face_pool(source_img) 87 | params_source, angles_source = calculate_shapemodel(shape_model, source_img) 88 | 89 | yaw = angles_source[:,0][0].detach().cpu().numpy() 90 | pitch = angles_source[:,1][0].detach().cpu().numpy() 91 | roll = angles_source[:, 2][0].detach().cpu().numpy() 92 | exp = params_source['alpha_exp'][0].detach().cpu().numpy() 93 | jaw = params_source['pose'][0, 3].detach().cpu().numpy() 94 | 95 | tmp = np.zeros(54) 96 | tmp[0] = yaw 97 | tmp[1] = pitch 98 | tmp[2] = roll 99 | tmp[3] = jaw 100 | tmp[4:] = exp 101 | # np.save(os.path.join(output_path, '{:07d}.npy'.format(i)), tmp) 102 | statistics.append(tmp) 103 | 104 | statistics = np.asarray(statistics) 105 | np.save(os.path.join(output_path, 'stats_all.npy'), statistics) 106 | 107 | ranges = extract_stats(statistics) 108 | 109 | np.save(os.path.join(output_path, 'ranges_{}.npy'.format(dataset)), ranges) 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | --------------------------------------------------------------------------------
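Note on how these statistics are consumed: the ranges_{dataset}.npy file written above is later read by get_direction_ranges / initialize_directions (libs/utilities/generic.py, further down in this listing), which fit a least-squares line through the points (min, -shift_scale) and (max, +shift_scale) so that a raw DECA parameter value can be converted into a shift magnitude along its learned direction. The snippet below is a minimal, self-contained sketch of that mapping and is not part of the repository; the toy range values and the shift_scale of 6.0 (implied by the jaw fit points (min_jaw, -6) and (max_jaw, 6) in generic.py) are illustrative assumptions.

```python
# Editorial sketch: maps a [min, max] parameter range onto direction shifts,
# mirroring the lstsq line fit used in initialize_directions (libs/utilities/generic.py).
# The toy ranges and shift_scale=6.0 below are assumptions, not repository values.
import numpy as np

def fit_param_to_shift(param_min, param_max, shift_scale=6.0):
    # Fit shift = a * param + b so that param_min -> -shift_scale and param_max -> +shift_scale.
    A = np.vstack([[param_min, param_max], np.ones(2)]).T
    (a, b), *_ = np.linalg.lstsq(A, np.array([-shift_scale, shift_scale]), rcond=None)
    return a, b

# Toy stand-in for the first rows of ranges_voxceleb.npy: one [min, max] pair per parameter.
toy_ranges = np.array([
    [-35.0, 35.0],   # yaw
    [-20.0, 20.0],   # pitch
    [-15.0, 15.0],   # roll
    [-0.02, 0.25],   # jaw (FLAME pose component 3)
])
a_jaw, b_jaw = fit_param_to_shift(*toy_ranges[3])
print('jaw value 0.10 maps to shift {:.2f}'.format(a_jaw * 0.10 + b_jaw))
```

get_direction_info in libs/configs/config_directions.py uses the same fitted a and b: it places the source parameter at start_pose = a * value + b and then sweeps shifts over [-shift_scale - start_pose, shift_scale - start_pose] around it.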
/libs/utilities/utils_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torchvision import utils as torch_utils 5 | import cv2 6 | 7 | from libs.utilities.image_utils import * 8 | from libs.face_models.ffhq_cropping import crop_using_landmarks 9 | from libs.utilities.utils import make_path 10 | 11 | def generate_video(images, video_path, fps = 25): 12 | dim = (images[0].shape[1], images[0].shape[0]) 13 | com_video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'MP4V') , fps, dim) 14 | 15 | for image in images: 16 | com_video.write(np.uint8(image)) 17 | 18 | com_video.release() 19 | 20 | def generate_grid_image(source, target, reenacted): 21 | num_images = source.shape[0] # batch size 22 | width = 256; height = 256 23 | grid_image = torch.zeros((3, num_images*height, 3*width)) 24 | for i in range(num_images): 25 | s = i*height 26 | e = s + height 27 | grid_image[:, s:e, :width] = source[i, :, :, :] 28 | grid_image[:, s:e, width:2*width] = target[i, :, :, :] 29 | grid_image[:, s:e, 2*width:] = reenacted[i, :, :, :] 30 | 31 | if grid_image.shape[1] > 1000: # height 32 | grid_image = torch_image_resize(grid_image, height = 800) 33 | return grid_image 34 | 35 | def extract_frames(video_path, fps = 25, save_frames = None, get_only_first = False): 36 | if not get_only_first: 37 | print('Extract frames with {} fps'.format(fps)) 38 | if save_frames is not None: 39 | make_path(save_frames) 40 | cap = cv2.VideoCapture(video_path) 41 | counter = 0 42 | frames = [] 43 | while cap.isOpened(): 44 | ret, frame = cap.read() 45 | if not ret: 46 | break 47 | if get_only_first: 48 | return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 49 | if counter % fps == 0: 50 | if save_frames is not None: 51 | cv2.imwrite(os.path.join(save_frames, '{:06d}.png'.format(counter)), frame) 52 | frames.append( cv2.cvtColor(frame.copy(), cv2.COLOR_RGB2BGR)) 53 | counter += 1 54 | 55 | cap.release() 56 | cv2.destroyAllWindows() 57 | 58 | return np.asarray(frames) 59 | 60 | " Crop images using facial landmarks " 61 | def preprocess_image(image_path, landmarks_est, save_filename = None): 62 | 63 | if os.path.isfile(image_path): 64 | image = read_image_opencv(image_path) 65 | else: 66 | image = image_path 67 | image, scale = image_resize(image, width = 1000) 68 | image_tensor = torch.tensor(np.transpose(image, (2,0,1))).float().cuda() 69 | 70 | with torch.no_grad(): 71 | landmarks = landmarks_est.detect_landmarks(image_tensor.unsqueeze(0)) 72 | landmarks = landmarks[0].detach().cpu().numpy() 73 | landmarks = np.asarray(landmarks) 74 | 75 | img = crop_using_landmarks(image, landmarks) 76 | if img is not None and save_filename is not None: 77 | cv2.imwrite(save_filename, cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR)) 78 | if img is not None: 79 | return img 80 | else: 81 | print('Error with image preprocessing') 82 | exit() 83 | 84 | " Invert real image into the latent space of StyleGAN2 " 85 | def invert_image(image, encoder, generator, truncation, trunc, save_path = None, save_name = None): 86 | with torch.no_grad(): 87 | latent_codes = encoder(image) 88 | inverted_images, _ = generator([latent_codes], input_is_latent=True, return_latents = False, truncation= truncation, truncation_latent=trunc) 89 | 90 | if save_path is not None and save_name is not None: 91 | grid = torch_utils.save_image( 92 | inverted_images, 93 | os.path.join(save_path, '{}.png'.format(save_name)), 94 | normalize=True, 95 | range=(-1, 1), 96 | ) 97 | # 
Latent code 98 | latent_code = latent_codes[0].detach().cpu().numpy() 99 | save_dir = os.path.join(save_path, '{}.npy'.format(save_name)) 100 | np.save(save_dir, latent_code) 101 | 102 | return inverted_images, latent_codes -------------------------------------------------------------------------------- /libs/datasets/dataloader_inversion.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | import glob 5 | import cv2 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from torchvision import utils as torch_utils 9 | 10 | class DatasetInversion(torch.utils.data.Dataset): 11 | 12 | def __init__(self, root_path, num_images = None): 13 | """ 14 | Args: 15 | root_path: VoxCeleb dataset images are saved as: id_index/video_id/frames_path/*.png 16 | """ 17 | self.root_path = root_path 18 | self.images_files = None 19 | 20 | ids_path = glob.glob(os.path.join(root_path, '*/')) 21 | ids_path.sort() 22 | count_videos = 0 23 | # print('Dataset has {} identities'.format(len(ids_path))) 24 | for i, id_path in enumerate(ids_path): 25 | id_index = id_path.split('/')[-2] 26 | videos_path = glob.glob(os.path.join(id_path, '*/')) 27 | videos_path.sort() 28 | count_videos += len(videos_path) 29 | for j, video_path in enumerate(videos_path): 30 | images_files_tmp = glob.glob(os.path.join(video_path, 'frames_cropped', '*.png')) 31 | images_files_tmp.sort() 32 | if self.images_files is None: 33 | self.images_files = images_files_tmp 34 | else: 35 | self.images_files = np.concatenate((self.images_files, images_files_tmp)) 36 | 37 | if num_images is not None: 38 | self.images_files = self.images_files[:num_images] 39 | self.images_files.sort() 40 | 41 | self.len_images = self.get_length() 42 | self.indices = [ x for x in range(self.len_images) ] 43 | self.indices_temporal = self.indices.copy() 44 | 45 | print('Dataset has {} identities, {} videos and {} frames'.format(len(ids_path), count_videos, len(self.images_files))) 46 | 47 | def __len__(self): 48 | 'Denotes the total number of samples' 49 | return len(self.images_files) 50 | 51 | def __getitem__(self, index): 52 | 53 | img_name = self.images_files[index] 54 | tmp = img_name.split('/') 55 | filenames = tmp[-1] 56 | video_indices = tmp[-3] 57 | id_indices = tmp[-4] 58 | 59 | images = self.image_to_tensor(img_name) 60 | 61 | out_dict = { 62 | 'images': images, 63 | 'filenames': filenames, 64 | 'id_indices': id_indices, 65 | 'video_indices': video_indices 66 | } 67 | 68 | return out_dict 69 | 70 | def get_length(self): 71 | 72 | return len(self.images_files) 73 | 74 | def get_sample(self, idx): 75 | 76 | img_name = self.images_files[idx] 77 | tmp = img_name.split('/') 78 | filename = tmp[-1] 79 | video_index = tmp[-3] 80 | id_index = tmp[-4] 81 | 82 | image = self.image_to_tensor(img_name) 83 | 84 | return image, filename, id_index, video_index 85 | 86 | def get_batch(self, batch_size = 2): 87 | 88 | images = torch.zeros(batch_size, 3, 256, 256) 89 | filenames = [] 90 | id_indices = []; video_indices = [] 91 | for i in range(int(batch_size)): 92 | if len(self.indices_temporal)>0: 93 | index_pick = self.indices_temporal[-1] 94 | image, filename, id_index, video_index = self.get_sample(index_pick) 95 | self.indices_temporal.pop(-1) 96 | filenames.append(filename) 97 | id_indices.append(id_index) 98 | video_indices.append(video_index) 99 | images[i] = image 100 | 101 | out_dict = { 102 | 'images': images, 103 | 'filenames': filenames, 104 | 'id_indices': id_indices, 105 | 
'video_indices': video_indices 106 | } 107 | return out_dict 108 | 109 | 'Transform images to tensor [-1,1]. Generators space' 110 | def image_to_tensor(self, image_file): 111 | max_val = 1 112 | min_val = -1 113 | # for image_file in image_files: 114 | image = cv2.imread(image_file, cv2.IMREAD_COLOR) # BGR order!!!! 115 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype('uint8') 116 | 117 | 118 | if image.shape[0]>256: 119 | image, _ = image_resize(image, 256) 120 | 121 | image_tensor = torch.tensor(np.transpose(image,(2,0,1))).float().div(255.0) 122 | image_tensor = image_tensor * (max_val - min_val) + min_val 123 | 124 | return image_tensor -------------------------------------------------------------------------------- /libs/criteria/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /libs/face_models/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import cv2 5 | import random 6 | import datetime 7 | import time 8 | import math 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | try: 14 | from iou import IOU 15 | except BaseException: 16 | # IOU cython speedup 10x 17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): 18 | sa = abs((ax2 - ax1) * (ay2 - ay1)) 19 | sb = abs((bx2 - bx1) * (by2 - by1)) 20 | x1, y1 = max(ax1, bx1), max(ay1, by1) 21 | x2, y2 = min(ax2, bx2), min(ay2, by2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | if w < 0 or h < 0: 25 | return 0.0 26 | else: 27 | return 1.0 * w * h / (sa + sb - w * h) 28 | 29 | 30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): 31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh 33 | dw, dh = math.log(ww / aww), math.log(hh / ahh) 34 | return dx, dy, dw, dh 35 | 36 | 37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): 38 | xc, yc = dx * aww + axc, dy * ahh + ayc 39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh 40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 41 | return x1, y1, 
x2, y2 42 | 43 | 44 | def nms(dets, thresh): 45 | # print(dets) 46 | if 0 == len(dets): 47 | return [] 48 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 49 | # print(x1,x2,y1,y2) 50 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 51 | order = scores.argsort()[::-1] 52 | 53 | keep = [] 54 | while order.size > 0: 55 | i = order[0] 56 | keep.append(i) 57 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 58 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 59 | 60 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 61 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 62 | 63 | inds = np.where(ovr <= thresh)[0] 64 | order = order[inds + 1] 65 | 66 | return keep 67 | 68 | 69 | def encode(matched, priors, variances): 70 | """Encode the variances from the priorbox layers into the ground truth boxes 71 | we have matched (based on jaccard overlap) with the prior boxes. 72 | Args: 73 | matched: (tensor) Coords of ground truth for each prior in point-form 74 | Shape: [num_priors, 4]. 75 | priors: (tensor) Prior boxes in center-offset form 76 | Shape: [num_priors,4]. 77 | variances: (list[float]) Variances of priorboxes 78 | Return: 79 | encoded boxes (tensor), Shape: [num_priors, 4] 80 | """ 81 | 82 | # dist b/t match center and prior's center 83 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 84 | # encode variance 85 | g_cxcy /= (variances[0] * priors[:, 2:]) 86 | # match wh / prior wh 87 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 88 | g_wh = torch.log(g_wh) / variances[1] 89 | # return target for smooth_l1_loss 90 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 91 | 92 | 93 | def decode(loc, priors, variances): 94 | """Decode locations from predictions using priors to undo 95 | the encoding we did for offset regression at train time. 96 | Args: 97 | loc (tensor): location predictions for loc layers, 98 | Shape: [num_priors,4] 99 | priors (tensor): Prior boxes in center-offset form. 100 | Shape: [num_priors,4]. 
101 | variances: (list[float]) Variances of priorboxes 102 | Return: 103 | decoded bounding box predictions 104 | """ 105 | 106 | boxes = torch.cat(( 107 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 108 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 109 | boxes[:, :2] -= boxes[:, 2:] / 2 110 | boxes[:, 2:] += boxes[:, :2] 111 | return boxes 112 | -------------------------------------------------------------------------------- /libs/utilities/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # import scipy.misc 4 | import cv2 5 | import torchvision 6 | import os 7 | 8 | def torch_image_resize(image, width = None, height = None): 9 | dim = None 10 | (h, w) = image.shape[1:] 11 | # if both the width and height are None, then return the 12 | # original image 13 | if width is None and height is None: 14 | return image 15 | 16 | # check to see if the width is None 17 | if width is None: 18 | # calculate the ratio of the height and construct the 19 | # dimensions 20 | r = height / float(h) 21 | dim = (height, int(w * r)) 22 | scale = r 23 | # otherwise, the height is None 24 | else: 25 | # calculate the ratio of the width and construct the 26 | # dimensions 27 | r = width / float(w) 28 | dim = (int(h * r), width) 29 | scale = r 30 | image = image.unsqueeze(0) 31 | image = torch.nn.functional.interpolate(image, size=dim, mode='bilinear') 32 | return image.squeeze(0) 33 | 34 | 35 | " Resize numpy array image" 36 | def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA): 37 | # initialize the dimensions of the image to be resized and 38 | # grab the image size 39 | dim = None 40 | (h, w) = image.shape[:2] 41 | 42 | # if both the width and height are None, then return the 43 | # original image 44 | if width is None and height is None: 45 | return image 46 | 47 | # check to see if the width is None 48 | if width is None: 49 | # calculate the ratio of the height and construct the 50 | # dimensions 51 | r = height / float(h) 52 | dim = (int(w * r), height) 53 | scale = r 54 | 55 | # otherwise, the height is None 56 | else: 57 | # calculate the ratio of the width and construct the 58 | # dimensions 59 | r = width / float(w) 60 | dim = (width, int(h * r)) 61 | scale = r 62 | 63 | # resize the image 64 | resized = cv2.resize(image, dim, interpolation = inter) 65 | 66 | return resized, scale 67 | 68 | " Read image from path" 69 | def read_image_opencv(image_path): 70 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) # BGR order!!!! 
71 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 72 | 73 | return img.astype('uint8') 74 | 75 | "Transform image tensor to numpy array" 76 | def im_to_numpy(img): 77 | img = img.detach().cpu().numpy() 78 | img = np.transpose(img, (1, 2, 0)) # H*W*C 79 | return img 80 | 81 | "Transform numpy 255 to image torch div(255) 3 x H x W" 82 | def numpy_255_to_torch_1(image): 83 | image_tensor = torch.tensor(np.transpose(image,(2,0,1))).float().div(255.0) 84 | return image_tensor 85 | 86 | " Trasnform torch tensor images from range [-1,1] to [0,255]" 87 | def torch_range_1_to_255(image): 88 | img_tmp = image.clone() 89 | min_val = -1 90 | max_val = 1 91 | img_tmp.clamp_(min=min_val, max=max_val) 92 | img_tmp.add_(-min_val).div_(max_val - min_val + 1e-5) 93 | img_tmp = img_tmp.mul(255.0) 94 | return img_tmp 95 | 96 | " Trasnform torch tensor to numpy images from range [-1,1] to [0,255]" 97 | def tensor_to_image(image_tensor): 98 | if image_tensor.ndim == 4: 99 | image_tensor = image_tensor.squeeze(0) 100 | 101 | min_val = -1 102 | max_val = 1 103 | image_tensor.clamp_(min=min_val, max=max_val) 104 | image_tensor.add_(-min_val).div_(max_val - min_val + 1e-5) 105 | image_tensor = image_tensor.mul(255.0).add(0.0) 106 | 107 | image_tensor = image_tensor.detach().cpu().numpy() 108 | image_tensor = np.transpose(image_tensor, (1, 2, 0)) 109 | 110 | return image_tensor 111 | 112 | " Load image from file path to tensor [-1,1] range " 113 | def image_to_tensor(image_file): 114 | max_val = 1 115 | min_val = -1 116 | if os.path.isfile(image_file): 117 | image = cv2.imread(image_file, cv2.IMREAD_COLOR) # BGR order 118 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype('uint8') 119 | else: 120 | image = image_file 121 | if image.shape[0]>256: 122 | image, _ = image_resize(image, 256) 123 | image_tensor = torch.tensor(np.transpose(image,(2,0,1))).float().div(255.0) 124 | image_tensor = image_tensor * (max_val - min_val) + min_val 125 | 126 | return image_tensor 127 | 128 | " Draw a red rectangle around image " 129 | def add_border(tensor): 130 | border = 3 131 | for ch in range(tensor.shape[0]): 132 | color = 1.0 if ch == 0 else -1 133 | tensor[ch, :border, :] = color 134 | tensor[ch, -border:,] = color 135 | tensor[ch, :, :border] = color 136 | tensor[ch, :, -border:] = color 137 | return tensor -------------------------------------------------------------------------------- /invert_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Invert VoxCeleb dataset using Encoder4Editing method https://github.com/omertov/encoder4editing 3 | 4 | Inputs: 5 | --input_path: Path to voxceleb dataset. 
The dataset format should be: id_index/video_id/frames_cropped/*.png 6 | 7 | Inverted images will be saved in input_path as: 8 | id_index/video_id/inversion/frames/*.png 9 | id_index/video_id/inversion/latent_codes/*.npy 10 | 11 | python invert_images.py --input_path /datasets/VoxCeleb1/VoxCeleb1_test 12 | """ 13 | 14 | import os 15 | import numpy as np 16 | import torch 17 | from torchvision import utils as torch_utils 18 | from torch.utils.data import DataLoader 19 | from tqdm import tqdm 20 | from argparse import ArgumentParser 21 | import warnings 22 | warnings.filterwarnings("ignore") 23 | 24 | from libs.gan.StyleGAN2.model import Generator as StyleGAN2Generator 25 | from libs.datasets.dataloader_inversion import DatasetInversion 26 | from libs.gan.encoder4editing.psp_encoders import Encoder4Editing 27 | 28 | 29 | parser = ArgumentParser() 30 | 31 | parser.add_argument("--input_path", required = True, help='Path to VoxCeleb dataset') 32 | 33 | parser.add_argument("--generator_path", default = './pretrained_models/stylegan-voxceleb.pt', help='Path to generator model') 34 | parser.add_argument("--encoder_path", default = './pretrained_models/e4e-voxceleb.pt', help='Path to encoder model') 35 | parser.add_argument("--batch_size", default = 4, help='batch size') 36 | parser.add_argument("--image_resolution", default = 256, help='Path to generator model') 37 | parser.add_argument("--channel_multiplier", default = 1, help='Path to generator model') 38 | 39 | 40 | def make_path(path): 41 | if not os.path.exists(path): 42 | os.makedirs(path) 43 | 44 | class Inversion(object): 45 | 46 | def __init__(self): 47 | 48 | args = parser.parse_args() 49 | self.input_path = args.input_path 50 | 51 | self.output_path = self.input_path # output_path same with the input_path" 52 | self.generator_path = args.generator_path 53 | self.encoder_path = args.encoder_path 54 | self.channel_multiplier = args.channel_multiplier 55 | self.batch_size = args.batch_size 56 | self.image_resolution = args.image_resolution 57 | 58 | 59 | def load_networks(self): 60 | print('----- Load encoder from {} -----'.format(self.encoder_path)) 61 | 62 | self.encoder = Encoder4Editing(50, 'ir_se', self.image_resolution) 63 | ckpt = torch.load(self.encoder_path) 64 | self.encoder.load_state_dict(ckpt['e']) 65 | self.encoder.cuda().eval() 66 | 67 | print('----- Load generator from {} -----'.format(self.generator_path)) 68 | self.generator = StyleGAN2Generator(self.image_resolution, 512, 8, channel_multiplier = self.channel_multiplier) 69 | self.generator.load_state_dict(torch.load(self.generator_path)['g_ema'], strict = False) 70 | self.generator.cuda().eval() 71 | 72 | self.truncation = 0.7 73 | self.trunc = self.generator.mean_latent(4096).detach().clone() 74 | 75 | def configure_dataset(self): 76 | self.dataset = DatasetInversion(self.input_path, num_images=None) 77 | 78 | self.dataloader = DataLoader(self.dataset, 79 | batch_size=self.batch_size, 80 | shuffle=False, 81 | num_workers=1, 82 | drop_last=False) 83 | 84 | def run_inversion_dataset(self): 85 | 86 | if not os.path.exists(self.output_path): 87 | os.makedirs(self.output_path, exist_ok=True) 88 | 89 | self.configure_dataset() 90 | self.load_networks() 91 | step = 0 92 | 93 | for batch_idx, batch in enumerate(tqdm(self.dataloader)): 94 | 95 | sample_dict = batch 96 | filenames = sample_dict['filenames'] 97 | id_indices = sample_dict['id_indices'] 98 | video_indices = sample_dict['video_indices'] 99 | images = sample_dict['images'].cuda() 100 | 101 | with torch.no_grad(): 102 
| latent_codes = self.encoder(images) 103 | inverted_images, _ = self.generator([latent_codes], input_is_latent=True, return_latents = False, truncation= self.truncation, truncation_latent=self.trunc) 104 | 105 | 106 | # Save inverted images and latent codes 107 | for i in range(len(id_indices)): 108 | output_path_local = os.path.join(self.output_path, id_indices[i], video_indices[i], 'inversion') 109 | make_path(output_path_local) 110 | make_path(os.path.join(output_path_local, 'frames')) 111 | save_dir = os.path.join(output_path_local, 'frames', filenames[i]) 112 | grid = torch_utils.save_image( 113 | inverted_images[i], 114 | save_dir, 115 | normalize=True, 116 | range=(-1, 1), 117 | ) 118 | 119 | # Latent code 120 | latent_code = latent_codes[i].detach().cpu().numpy() 121 | make_path(os.path.join(output_path_local, 'latent_codes')) 122 | latent_filename = filenames[i].split('.')[0] 123 | latent_filename = latent_filename + '.npy' 124 | save_dir = os.path.join(output_path_local, 'latent_codes', latent_filename) 125 | np.save(save_dir, latent_code) 126 | 127 | step += self.batch_size 128 | 129 | if __name__ == "__main__": 130 | 131 | inversion = Inversion() 132 | inversion.run_inversion_dataset() 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /libs/gan/encoder4editing/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 5 | 6 | """ 7 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | return output 20 | 21 | 22 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 23 | """ A named tuple describing a ResNet block. """ 24 | 25 | 26 | def get_block(in_channel, depth, num_units, stride=2): 27 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 28 | 29 | 30 | def get_blocks(num_layers): 31 | if num_layers == 50: 32 | blocks = [ 33 | get_block(in_channel=64, depth=64, num_units=3), 34 | get_block(in_channel=64, depth=128, num_units=4), 35 | get_block(in_channel=128, depth=256, num_units=14), 36 | get_block(in_channel=256, depth=512, num_units=3) 37 | ] 38 | elif num_layers == 100: 39 | blocks = [ 40 | get_block(in_channel=64, depth=64, num_units=3), 41 | get_block(in_channel=64, depth=128, num_units=13), 42 | get_block(in_channel=128, depth=256, num_units=30), 43 | get_block(in_channel=256, depth=512, num_units=3) 44 | ] 45 | elif num_layers == 152: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=8), 49 | get_block(in_channel=128, depth=256, num_units=36), 50 | get_block(in_channel=256, depth=512, num_units=3) 51 | ] 52 | else: 53 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 54 | return blocks 55 | 56 | 57 | class SEModule(Module): 58 | def __init__(self, channels, reduction): 59 | super(SEModule, self).__init__() 60 | self.avg_pool = AdaptiveAvgPool2d(1) 61 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 62 | self.relu = ReLU(inplace=True) 63 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 64 | self.sigmoid = Sigmoid() 65 | 66 | def forward(self, x): 67 | module_input = x 68 | x = self.avg_pool(x) 69 | x = self.fc1(x) 70 | x = self.relu(x) 71 | x = self.fc2(x) 72 | x = self.sigmoid(x) 73 | return module_input * x 74 | 75 | 76 | class bottleneck_IR(Module): 77 | def __init__(self, in_channel, depth, stride): 78 | super(bottleneck_IR, self).__init__() 79 | if in_channel == depth: 80 | self.shortcut_layer = MaxPool2d(1, stride) 81 | else: 82 | self.shortcut_layer = Sequential( 83 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 84 | BatchNorm2d(depth) 85 | ) 86 | self.res_layer = Sequential( 87 | BatchNorm2d(in_channel), 88 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 89 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 90 | ) 91 | 92 | def forward(self, x): 93 | shortcut = self.shortcut_layer(x) 94 | res = self.res_layer(x) 95 | return res + shortcut 96 | 97 | 98 | class bottleneck_IR_SE(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR_SE, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth) 107 | ) 108 | self.res_layer = Sequential( 109 | BatchNorm2d(in_channel), 110 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 111 | PReLU(depth), 112 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 113 | BatchNorm2d(depth), 114 | SEModule(depth, 16) 115 | ) 116 | 117 | def forward(self, x): 118 | shortcut = self.shortcut_layer(x) 119 | res = self.res_layer(x) 120 | return res + shortcut 121 | 122 | 123 | def _upsample_add(x, y): 124 | """Upsample and add two feature maps. 125 | Args: 126 | x: (Variable) top feature map to be upsampled. 127 | y: (Variable) lateral feature map. 128 | Returns: 129 | (Variable) added feature map. 130 | Note in PyTorch, when input size is odd, the upsampled feature map 131 | with `F.upsample(..., scale_factor=2, mode='nearest')` 132 | maybe not equal to the lateral feature map size. 133 | e.g. 134 | original input size: [N,_,15,15] -> 135 | conv2d feature map size: [N,_,8,8] -> 136 | upsampled feature map size: [N,_,16,16] 137 | So we choose bilinear upsample which supports arbitrary output sizes. 
138 | """ 139 | _, _, H, W = y.size() 140 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 141 | -------------------------------------------------------------------------------- /libs/utilities/generic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from numpy import ones,vstack 4 | from numpy.linalg import lstsq 5 | import os 6 | import json 7 | from torchvision import utils as torch_utils 8 | 9 | from libs.configs.config_directions import get_direction_ranges, voxceleb_dict, ffhq_dict 10 | from libs.utilities.image_utils import * 11 | 12 | 13 | def save_image(image, save_image_path): 14 | 15 | grid = torch_utils.save_image( 16 | image, 17 | save_image_path, 18 | normalize=True, 19 | range=(-1, 1), 20 | ) 21 | 22 | def calculate_shapemodel(deca_model, images, image_space = 'gan'): 23 | img_tmp = images.clone() 24 | if image_space == 'gan': 25 | img_tmp = torch_range_1_to_255(img_tmp) 26 | 27 | p_tensor, alpha_shp_tensor, alpha_exp_tensor, angles, cam = deca_model.extract_DECA_params(img_tmp) # params dictionary 28 | out_dict = {} 29 | out_dict['pose'] = p_tensor 30 | out_dict['alpha_exp'] = alpha_exp_tensor 31 | out_dict['alpha_shp'] = alpha_shp_tensor 32 | out_dict['cam'] = cam 33 | 34 | return out_dict, angles.cuda() 35 | 36 | def initialize_directions(dataset_type, learned_directions, shift_scale): 37 | if dataset_type == 'voxceleb': 38 | ranges = get_direction_ranges(voxceleb_dict['ranges_filepath']) 39 | jaw_range = ranges[3] 40 | max_jaw = jaw_range[1] 41 | min_jaw = jaw_range[0] 42 | exp_ranges = ranges[4:] 43 | 44 | angle_scales = np.zeros(3) 45 | angle_scales[0] = voxceleb_dict['yaw_scale'] 46 | angle_scales[1] = voxceleb_dict['pitch_scale'] 47 | angle_scales[2] = voxceleb_dict['roll_scale'] 48 | 49 | angle_directions = np.zeros(3) 50 | angle_directions[0] = int(voxceleb_dict['yaw_direction']) 51 | angle_directions[1] = int(voxceleb_dict['pitch_direction']) 52 | angle_directions[2] = int(voxceleb_dict['roll_direction']) 53 | 54 | else: 55 | angle_scales = np.zeros(3) 56 | angle_scales[0] = ffhq_dict['yaw_scale'] 57 | angle_scales[1] = ffhq_dict['pitch_scale'] 58 | angle_scales[2] = ffhq_dict['roll_scale'] 59 | 60 | angle_directions = np.zeros(3) 61 | angle_directions[0] = ffhq_dict['yaw_direction'] 62 | angle_directions[1] = ffhq_dict['pitch_direction'] 63 | angle_directions[2] = ffhq_dict['roll_direction'] 64 | exp_ranges = get_direction_ranges(ffhq_dict['ranges_filepath']) 65 | 66 | jaw_range = exp_ranges[3] 67 | jaw_range = jaw_range 68 | max_jaw = jaw_range[1] 69 | min_jaw = jaw_range[0] 70 | exp_ranges = exp_ranges[4:] 71 | 72 | directions_exp = [] 73 | count_pose = 0 74 | if angle_directions[0] != -1: 75 | count_pose += 1 76 | if angle_directions[1] != -1: 77 | count_pose += 1 78 | if angle_directions[2] != -1: 79 | count_pose += 1 80 | count_pose += 1 # Jaw 81 | num_expressions = learned_directions - count_pose 82 | 83 | 84 | for i in range(num_expressions): 85 | dict_3d = {} 86 | dict_3d['exp_component'] = i 87 | dict_3d['A_direction'] = i + count_pose 88 | dict_3d['max_shift'] = exp_ranges[i][1] 89 | dict_3d['min_shift'] = exp_ranges[i][0] 90 | 91 | points = [(dict_3d['min_shift'], - shift_scale),(dict_3d['max_shift'], shift_scale)] 92 | x_coords, y_coords = zip(*points) 93 | A = vstack([x_coords,ones(len(x_coords))]).T 94 | m, c = lstsq(A, y_coords, rcond=None)[0] 95 | dict_3d['a'] = m 96 | dict_3d['b'] = c 97 | directions_exp.append(dict_3d) 98 | 99 | 100 | 
points = [(min_jaw, -6),(max_jaw, 6)] 101 | x_coords, y_coords = zip(*points) 102 | A = vstack([x_coords,ones(len(x_coords))]).T 103 | m, c = lstsq(A, y_coords, rcond=None)[0] 104 | a_jaw = m 105 | b_jaw = c 106 | 107 | jaw_dict = { 108 | 'a': a_jaw, 109 | 'b': b_jaw, 110 | 'max': max_jaw, 111 | 'min': min_jaw 112 | } 113 | 114 | return count_pose, num_expressions, directions_exp, jaw_dict, angle_scales, angle_directions 115 | 116 | def get_shifted_latent_code(G, z, shift, input_is_latent = False, truncation=1, truncation_latent = None, w_plus = False, num_layers = None): 117 | inject_index = G.n_latent 118 | if not input_is_latent: # Z space 119 | w = G.get_latent(z) 120 | latent = w.unsqueeze(1).repeat(1, inject_index, 1) 121 | else: # W space 122 | latent = z.clone() 123 | if not w_plus: # shift = B x 512 124 | if num_layers is None: # add shift in all layers 125 | shift_rep = shift.unsqueeze(1) 126 | shift_rep = shift_rep.repeat(1, inject_index, 1) 127 | latent += shift_rep 128 | else: 129 | for i in range(num_layers): 130 | latent[:, i,:] += shift 131 | 132 | else:# shift= B x num_layers x 512 133 | latent[:, :shift.shape[1],:] += shift 134 | 135 | return latent 136 | 137 | def generate_image( G, latent_code, truncation, trunc, w_plus = True, num_layers_shift = 8, shift_code = None, input_is_latent = False, return_latents = False): 138 | if shift_code is None: 139 | imgs = G([latent_code], return_latents = return_latents, truncation = truncation, truncation_latent = trunc, input_is_latent = input_is_latent) 140 | else: 141 | shifted_code = get_shifted_latent_code(G, latent_code, shift_code, input_is_latent = input_is_latent, truncation=truncation, 142 | truncation_latent = trunc, w_plus = w_plus, num_layers = num_layers_shift) 143 | imgs = G([shifted_code], return_latents = return_latents, truncation = truncation, truncation_latent = trunc, input_is_latent = True) 144 | image = imgs[0] 145 | latent_w = imgs[1] 146 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 147 | if image.shape[2] > 256: 148 | image = face_pool(image) 149 | if return_latents: 150 | return image, latent_w 151 | else: 152 | return image -------------------------------------------------------------------------------- /libs/face_models/sfd/net_s3fd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 13 | self.weight.data *= 0.0 14 | self.weight.data += self.scale 15 | 16 | def forward(self, x): 17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 18 | x = x / norm * self.weight.view(1, -1, 1, 1) 19 | return x 20 | 21 | 22 | class s3fd(nn.Module): 23 | def __init__(self): 24 | super(s3fd, self).__init__() 25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 27 | 28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 30 | 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | 
35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 42 | 43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) 45 | 46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 48 | 49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 51 | 52 | self.conv3_3_norm = L2Norm(256, scale=10) 53 | self.conv4_3_norm = L2Norm(512, scale=8) 54 | self.conv5_3_norm = L2Norm(512, scale=5) 55 | 56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 62 | 63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | h = F.relu(self.conv1_1(x)) 72 | h = F.relu(self.conv1_2(h)) 73 | h = F.max_pool2d(h, 2, 2) 74 | 75 | h = F.relu(self.conv2_1(h)) 76 | h = F.relu(self.conv2_2(h)) 77 | h = F.max_pool2d(h, 2, 2) 78 | 79 | h = F.relu(self.conv3_1(h)) 80 | h = F.relu(self.conv3_2(h)) 81 | h = F.relu(self.conv3_3(h)) 82 | f3_3 = h 83 | h = F.max_pool2d(h, 2, 2) 84 | 85 | h = F.relu(self.conv4_1(h)) 86 | h = F.relu(self.conv4_2(h)) 87 | h = F.relu(self.conv4_3(h)) 88 | f4_3 = h 89 | h = F.max_pool2d(h, 2, 2) 90 | 91 | h = F.relu(self.conv5_1(h)) 92 | h = F.relu(self.conv5_2(h)) 93 | h = F.relu(self.conv5_3(h)) 94 | f5_3 = h 95 | h = F.max_pool2d(h, 2, 2) 96 | 97 | h = F.relu(self.fc6(h)) 98 | h = F.relu(self.fc7(h)) 99 | ffc7 = h 100 | h = F.relu(self.conv6_1(h)) 101 | h = F.relu(self.conv6_2(h)) 102 | f6_2 = h 103 | h = F.relu(self.conv7_1(h)) 104 | h = F.relu(self.conv7_2(h)) 105 | f7_2 = h 106 | 107 | f3_3 = self.conv3_3_norm(f3_3) 108 | f4_3 = self.conv4_3_norm(f4_3) 109 | f5_3 = self.conv5_3_norm(f5_3) 110 | 111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3) 112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 117 | cls4 = self.fc7_mbox_conf(ffc7) 118 | reg4 = self.fc7_mbox_loc(ffc7) 119 | cls5 = self.conv6_2_mbox_conf(f6_2) 120 | 
reg5 = self.conv6_2_mbox_loc(f6_2) 121 | cls6 = self.conv7_2_mbox_conf(f7_2) 122 | reg6 = self.conv7_2_mbox_loc(f7_2) 123 | 124 | # max-out background label 125 | chunk = torch.chunk(cls1, 4, 1) 126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 127 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 128 | 129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 130 | -------------------------------------------------------------------------------- /run_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import random 4 | import sys 5 | import json 6 | import argparse 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | sys.dont_write_bytecode = True 10 | 11 | from libs.trainer import Trainer 12 | from libs.configs.config_arguments import arguments 13 | 14 | root_path = os.getcwd() 15 | 16 | def main(): 17 | """ 18 | Training script. 19 | Options: 20 | ######### General ########### 21 | --experiment_path : path to save experiment 22 | --use_wandb : use wandb to log losses and evaluation metrics 23 | --log_images_wandb : if True log images on wandb 24 | --project_wandb : Project name for wandb 25 | --resume_training_model : Path to model to continue training or None 26 | 27 | ######### Generator ######### 28 | --dataset_type : voxceleb or ffhq 29 | --image_resolution : image resolution of pre-trained GAN model. image resolution for voxceleb dataset is 256 30 | 31 | ######### Dataset ######### 32 | --synthetic_dataset_path : set synthetic dataset path for evaluation 33 | --train_dataset_path : set training dataset path 34 | --test_dataset_path : set testing dataset path 35 | 36 | ######### Direction matrix A ######### 37 | --training_method : set training method: 38 | synthetic -> training only with synthetic images 39 | real -> training only with real images 40 | real_synthetic -> training with synthetic and real images 41 | paired -> training with paired images 42 | 43 | --lr : set the learning rate of direction matrix model 44 | 45 | ######### Training ######### 46 | --max_iter : set maximum number of training iterations 47 | --batch_size : set training batch size 48 | 49 | Phase 1: Train with synthetic images only. Evaluation during training on synthetic images. 50 | python run_trainer.py --experiment_path ./training_attempts/exp_v00 --training_method synthetic 51 | 52 | Phase 2: Train with both synthetic and real images. Evaluation during training on real images, source images are real target images are synthetic. 53 | python run_trainer.py --experiment_path ./training_attempts/exp_v00 --training_method real_synthetic \ 54 | --train_dataset_path /datasets/VoxCeleb1/VoxCeleb_videos\ 55 | --test_dataset_path /datasets/VoxCeleb1/VoxCeleb_videos_test 56 | 57 | python run_trainer.py --experiment_path ./training_attempts/test/exp_v00 --training_method paired --batch_size 4 \ 58 | --train_dataset_path /home/stella/Desktop/datasets/VoxCeleb1/VoxCeleb_few_shot \ 59 | --test_dataset_path /home/stella/Desktop/datasets/VoxCeleb1/VoxCeleb_few_shot --use_wandb --log_images_wandb 60 | 61 | Phase 3: Train with paired data. Evaluation during training on paired images. 
62 | python run_trainer.py --experiment_path ./training_attempts/exp_v00 --training_method paired --batch_size 4 \ 63 | --train_dataset_path /datasets/VoxCeleb1/VoxCeleb_videos \ 64 | --test_dataset_path /datasets/VoxCeleb1/VoxCeleb_videos_test 65 | 66 | """ 67 | parser = argparse.ArgumentParser(description="training script") 68 | 69 | ######### General ########### 70 | parser.add_argument('--experiment_path', type=str, required = True, help="path to save the experiment") 71 | parser.add_argument('--use_wandb', dest='use_wandb', action='store_true', help="use wandb to log losses and evaluation metrics") 72 | parser.set_defaults(use_wandb=False) 73 | parser.add_argument('--log_images_wandb', dest='log_images_wandb', action='store_true', help="if True log images on wandb") 74 | parser.set_defaults(log_images_wandb=False) 75 | parser.add_argument('--project_wandb', type=str, default = 'face-reenactment', help="Project name for wandb") 76 | 77 | parser.add_argument('--resume_training_model', type=str, default = None, help="Path to model or None") 78 | 79 | ######### Generator ######### 80 | parser.add_argument('--image_resolution', type=int, default=256, choices=(256, 1024), help="image resolution of pre-trained GAN modeln") 81 | parser.add_argument('--dataset_type', type=str, default='voxceleb', choices=('voxceleb', 'ffhq'), help="set dataset name") 82 | 83 | ######### Dataset ######### 84 | parser.add_argument('--synthetic_dataset_path', type=str, default=None, help="set synthetic dataset path for evaluation") 85 | parser.add_argument('--train_dataset_path', type=str, default=None, help="set training dataset path") 86 | parser.add_argument('--test_dataset_path', type=str, default=None, help="set testing dataset path") 87 | parser.add_argument('--training_method', type=str, default='synthetic', choices=('synthetic', 'real', 'real_synthetic', 'paired'), help="set training method") 88 | parser.add_argument('--lr', type=float, default=0.0001, help=" set the learning rate of direction matrix model") 89 | 90 | ######### Training ######### 91 | parser.add_argument('--max_iter', type=int, default=100000, help="set maximum number of training iterations") 92 | parser.add_argument('--batch_size', type=int, default=12, help="set training batch size") 93 | parser.add_argument('--test_batch_size', type=int, default=2, help="set evaluation batch size") 94 | parser.add_argument('--workers', type=int, default=1, help="") 95 | 96 | 97 | # Parse given arguments 98 | args = parser.parse_args() 99 | args = vars(args) # convert to dictionary 100 | 101 | args.update(arguments) # add arguments from libs.configs.config_arguments.py 102 | 103 | # Create output dir and save current arguments 104 | experiment_path = args['experiment_path'] 105 | experiment_path = experiment_path + '_{}_{}'.format(args['dataset_type'], args['training_method']) 106 | args['experiment_path'] = experiment_path 107 | args['root_path'] = root_path 108 | # Set up trainer 109 | print("#. 
Experiment: {}".format(experiment_path)) 110 | 111 | 112 | trainer = Trainer(args) 113 | 114 | training_method = args['training_method'] 115 | if training_method == 'synthetic': 116 | trainer.train() 117 | elif training_method == 'real' or training_method == 'real_synthetic': 118 | trainer.train_real() 119 | elif training_method == 'paired': 120 | trainer.train_paired() 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /libs/face_models/sfd/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from skimage import io 8 | 9 | 10 | class FaceDetector(object): 11 | """An abstract class representing a face detector. 12 | 13 | Any other face detection implementation must subclass it. All subclasses 14 | must implement ``detect_from_image``, that return a list of detected 15 | bounding boxes. Optionally, for speed considerations detect from path is 16 | recommended. 17 | """ 18 | 19 | def __init__(self, device, verbose): 20 | self.device = device 21 | self.verbose = verbose 22 | 23 | # if verbose: 24 | # if 'cpu' in device: 25 | # logger = logging.getLogger(__name__) 26 | # logger.warning("Detection running on CPU, this may be potentially slow.") 27 | 28 | # if 'cpu' not in device and 'cuda' not in device: 29 | # if verbose: 30 | # logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 31 | # raise ValueError 32 | 33 | def detect_from_image(self, tensor_or_path): 34 | """Detects faces in a given image. 35 | 36 | This function detects the faces present in a provided BGR(usually) 37 | image. The input can be either the image itself or the path to it. 38 | 39 | Arguments: 40 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 41 | to an image or the image itself. 42 | 43 | Example:: 44 | 45 | >>> path_to_image = 'data/image_01.jpg' 46 | ... detected_faces = detect_from_image(path_to_image) 47 | [A list of bounding boxes (x1, y1, x2, y2)] 48 | >>> image = cv2.imread(path_to_image) 49 | ... detected_faces = detect_from_image(image) 50 | [A list of bounding boxes (x1, y1, x2, y2)] 51 | 52 | """ 53 | raise NotImplementedError 54 | 55 | def detect_from_batch(self, tensor): 56 | """Detects faces in a given image. 57 | 58 | This function detects the faces present in a provided BGR(usually) 59 | image. The input can be either the image itself or the path to it. 60 | 61 | Arguments: 62 | tensor {torch.tensor} -- image batch tensor. 63 | 64 | Example:: 65 | 66 | >>> path_to_image = 'data/image_01.jpg' 67 | ... detected_faces = detect_from_image(path_to_image) 68 | [A list of bounding boxes (x1, y1, x2, y2)] 69 | >>> image = cv2.imread(path_to_image) 70 | ... detected_faces = detect_from_image(image) 71 | [A list of bounding boxes (x1, y1, x2, y2)] 72 | 73 | """ 74 | raise NotImplementedError 75 | 76 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 77 | """Detects faces from all the images present in a given directory. 
78 | 79 | Arguments: 80 | path {string} -- a string containing a path that points to the folder containing the images 81 | 82 | Keyword Arguments: 83 | extensions {list} -- list of string containing the extensions to be 84 | consider in the following format: ``.extension_name`` (default: 85 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 86 | folder recursively (default: {False}) show_progress_bar {bool} -- 87 | display a progressbar (default: {True}) 88 | 89 | Example: 90 | >>> directory = 'data' 91 | ... detected_faces = detect_from_directory(directory) 92 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 93 | 94 | """ 95 | if self.verbose: 96 | logger = logging.getLogger(__name__) 97 | 98 | if len(extensions) == 0: 99 | if self.verbose: 100 | logger.error("Expected at list one extension, but none was received.") 101 | raise ValueError 102 | 103 | if self.verbose: 104 | logger.info("Constructing the list of images.") 105 | additional_pattern = '/**/*' if recursive else '/*' 106 | files = [] 107 | for extension in extensions: 108 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 109 | 110 | if self.verbose: 111 | logger.info("Finished searching for images. %s images found", len(files)) 112 | logger.info("Preparing to run the detection.") 113 | 114 | predictions = {} 115 | for image_path in tqdm(files, disable=not show_progress_bar): 116 | if self.verbose: 117 | logger.info("Running the face detector on image: %s", image_path) 118 | predictions[image_path] = self.detect_from_image(image_path) 119 | 120 | if self.verbose: 121 | logger.info("The detector was successfully run on all %s images", len(files)) 122 | 123 | return predictions 124 | 125 | @property 126 | def reference_scale(self): 127 | raise NotImplementedError 128 | 129 | @property 130 | def reference_x_shift(self): 131 | raise NotImplementedError 132 | 133 | @property 134 | def reference_y_shift(self): 135 | raise NotImplementedError 136 | 137 | @staticmethod 138 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): 139 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 140 | 141 | Arguments: 142 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 143 | """ 144 | if isinstance(tensor_or_path, str): 145 | return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path) 146 | elif torch.is_tensor(tensor_or_path): 147 | # Call cpu in case its coming from cuda 148 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() 149 | elif isinstance(tensor_or_path, np.ndarray): 150 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path 151 | else: 152 | raise TypeError 153 | -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, 
grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | 
if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /libs/face_models/landmarks_estimation.py: -------------------------------------------------------------------------------- 1 | """ 2 | The face detector used is SFD (taken from face-alignment) 3 | https://github.com/1adrianb/face-alignment 4 | """ 5 | import os 6 | import numpy as np 7 | import cv2 8 | from enum import Enum 9 | import torch 10 | from torch.utils.model_zoo import load_url 11 | 12 | 13 | from libs.face_models.sfd.sfd_detector import SFDDetector as FaceDetector 14 | from libs.face_models.fan_model.models import FAN, ResNetDepth 15 | from libs.face_models.fan_model.utils import * 16 | 17 | models_urls = { 18 | '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-11f355bf06.pth.tar', 19 | '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-7835d9f11d.pth.tar', 20 | 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-2a464da4ea.pth.tar', 21 | } 22 | 23 | class LandmarksType(Enum): 24 | """Enum class defining the type of landmarks to detect. 25 | 26 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face 27 | ``_2halfD`` - this points represent the projection of the 3D points into 3D 28 | ``_3D`` - detect the points ``(x,y,z)``` in a 3D space 29 | 30 | """ 31 | _2D = 1 32 | _2halfD = 2 33 | _3D = 3 34 | 35 | class NetworkSize(Enum): 36 | # TINY = 1 37 | # SMALL = 2 38 | # MEDIUM = 3 39 | LARGE = 4 40 | 41 | def __new__(cls, value): 42 | member = object.__new__(cls) 43 | member._value_ = value 44 | return member 45 | 46 | def __int__(self): 47 | return self.value 48 | 49 | 50 | def get_preds_fromhm(hm, center=None, scale=None): 51 | """Obtain (x,y) coordinates given a set of N heatmaps. 
If the center 52 | and the scale is provided the function will return the points also in 53 | the original coordinate frame. 54 | 55 | Arguments: 56 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 57 | 58 | Keyword Arguments: 59 | center {torch.tensor} -- the center of the bounding box (default: {None}) 60 | scale {float} -- face scale (default: {None}) 61 | """ 62 | max, idx = torch.max( 63 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 64 | idx = idx + 1 65 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 66 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 67 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 68 | 69 | for i in range(preds.size(0)): 70 | for j in range(preds.size(1)): 71 | hm_ = hm[i, j, :] 72 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 73 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 74 | diff = torch.FloatTensor( 75 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 76 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 77 | preds[i, j].add_(diff.sign_().mul_(.25)) 78 | 79 | preds.add_(-.5) 80 | 81 | preds_orig = torch.zeros(preds.size()) 82 | if center is not None and scale is not None: 83 | for i in range(hm.size(0)): 84 | for j in range(hm.size(1)): 85 | preds_orig[i, j] = transform( 86 | preds[i, j], center, scale, hm.size(2), True) 87 | 88 | return preds, preds_orig 89 | 90 | def draw_detected_face(img, face): 91 | x_min = int(face[0]) 92 | y_min = int(face[1]) 93 | x_max = int(face[2]) 94 | y_max = int(face[3]) 95 | 96 | cv2.rectangle(img, (int(x_min),int(y_min)), (int(x_max),int(y_max)), (255,0,0), 2) 97 | 98 | return img 99 | 100 | 101 | class LandmarksEstimation(): 102 | def __init__(self, type = '3D', path_to_detector = './pretrained_models/s3fd-619a316812.pth'): 103 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 104 | # Load all needed models - Face detector and Pose detector 105 | network_size = NetworkSize.LARGE 106 | network_size = int(network_size) 107 | if type == '3D': 108 | self.landmarks_type = LandmarksType._3D 109 | else: 110 | self.landmarks_type = LandmarksType._2D 111 | self.flip_input = False 112 | 113 | #################### SFD face detection ################### 114 | if not os.path.exists(path_to_detector): 115 | print('Pretrained model of SFD face detector does not exist in {}'.format(path_to_detector)) 116 | exit() 117 | self.face_detector = FaceDetector(device=self.device, verbose=False, path_to_detector = path_to_detector) 118 | ########################################################### 119 | 120 | ################### Initialise the face alignemnt networks ################### 121 | self.face_alignment_net = FAN(network_size) 122 | if self.landmarks_type == LandmarksType._2D: # 123 | network_name = '2DFAN-' + str(network_size) 124 | else: 125 | network_name = '3DFAN-' + str(network_size) 126 | fan_weights = load_url(models_urls[network_name], map_location=lambda storage, loc: storage) 127 | self.face_alignment_net.load_state_dict(fan_weights) 128 | self.face_alignment_net.to(self.device) 129 | self.face_alignment_net.eval() 130 | ############################################################################## 131 | 132 | # Initialiase the depth prediciton network if 3D landmarks 133 | if self.landmarks_type == LandmarksType._3D: 134 | self.depth_prediciton_net = ResNetDepth() 135 | depth_weights = load_url(models_urls['depth'], map_location=lambda storage, loc: storage) 136 | depth_dict = { 137 | k.replace('module.', ''): v for k, 138 | v in 
depth_weights['state_dict'].items()} 139 | self.depth_prediciton_net.load_state_dict(depth_dict) 140 | self.depth_prediciton_net.to(self.device) 141 | self.depth_prediciton_net.eval() 142 | 143 | def get_landmarks(self, face, image): 144 | 145 | center = torch.FloatTensor( 146 | [(face[2] + face[0]) / 2.0, 147 | (face[3] + face[1]) / 2.0]) 148 | 149 | center[1] = center[1] - (face[3] - face[1]) * 0.12 150 | scale = (face[2] - face[0] + face[3] - face[1]) / self.face_detector.reference_scale 151 | 152 | inp = crop_torch(image, center, scale).float().cuda() 153 | inp = inp.div(255.0) 154 | 155 | out = self.face_alignment_net(inp)[-1] 156 | 157 | if self.flip_input: 158 | out = out + flip(self.face_alignment_net(flip(inp)) 159 | [-1], is_label=True) 160 | out = out.cpu() 161 | 162 | pts, pts_img = get_preds_fromhm(out, center, scale) 163 | out = out.cuda() 164 | # Added 3D landmark support 165 | if self.landmarks_type == LandmarksType._3D: 166 | pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2) 167 | heatmaps = torch.zeros((68,256,256), dtype=torch.float32) 168 | for i in range(68): 169 | if pts[i, 0] > 0: 170 | heatmaps[i] = draw_gaussian( 171 | heatmaps[i], pts[i], 2) 172 | 173 | heatmaps = heatmaps.unsqueeze(0) 174 | 175 | heatmaps = heatmaps.to(self.device) 176 | depth_pred = self.depth_prediciton_net( 177 | torch.cat((inp, heatmaps), 1)).view(68, 1) 178 | 179 | pts_img = pts_img.cuda() 180 | pts_img = torch.cat( 181 | (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1) 182 | else: 183 | pts, pts_img = pts.view(-1, 68, 2) * 4, pts_img.view(-1, 68, 2) 184 | 185 | return pts_img, out 186 | 187 | def detect_landmarks(self, image): 188 | 189 | if len(image.shape) == 3: 190 | image = image.unsqueeze(0) 191 | 192 | if self.device == 'cuda': 193 | image = image.cuda() 194 | 195 | with torch.no_grad(): 196 | detected_faces = self.face_detector.detect_from_batch(image) 197 | 198 | if self.landmarks_type == LandmarksType._3D: 199 | landmarks = torch.empty((1, 68, 3)) 200 | else: 201 | landmarks = torch.empty((1, 68, 2)) 202 | 203 | for face in detected_faces[0]: 204 | conf = face[4] 205 | if conf > 0.99: 206 | pts_img, heatmaps = self.get_landmarks(face, image) 207 | landmarks[0] = pts_img 208 | 209 | return landmarks -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | 
output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | 
"aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Finding Directions in GAN's Latent Space for Neural Face Reenactment 2 | 3 | Authors official PyTorch implementation of the **[Finding Directions in GAN's Latent Space for Neural Face Reenactment](https://arxiv.org/abs/2202.00046)**. This paper has been accepted as an oral presentation at British Machine Vision Conference (BMVC), 2022. If you use this code for your research, please [**cite**](#citation) our paper. 4 | 5 |

6 | 7 |

8 | 9 | >**Finding Directions in GAN's Latent Space for Neural Face Reenactment**
10 | > Stella Bounareli, Vasileios Argyriou, Georgios Tzimiropoulos
11 | > 12 | > **Abstract**: This paper is on face/head reenactment where the goal is to transfer the facial pose (3D head orientation and expression) of a target face to a source face. Previous methods focus on learning embedding networks for identity and pose disentanglement which proves to be a rather hard task, degrading the quality of the generated images. We take a different approach, bypassing the training of such networks, by using (fine-tuned) pre-trained GANs which have been shown capable of producing high-quality facial images. Because GANs are characterized by weak controllability, the core of our approach is a method to discover which directions in latent GAN space are responsible for controlling facial pose and expression variations. We present a simple pipeline to learn such directions with the aid of a 3D shape model which, by construction, already captures disentangled directions for facial pose, identity and expression. Moreover, we show that by embedding real images in the GAN latent space, our method can be successfully used for the reenactment of real-world faces. Our method features several favorable properties including using a single source image (one-shot) and enabling cross-person reenactment. Our qualitative and quantitative results show that our approach often produces reenacted faces of significantly higher quality than those produced by state-of-the-art methods for the standard benchmarks of VoxCeleb1 & 2. 13 | 14 | 15 | 16 | 17 | ## Face Reenactment Results on VoxCeleb1 dataset 18 | 19 | > Real image editing of head pose and expression 20 | 21 |

22 | 23 |

24 | 25 | > Self- and Cross-subject Reenactment 26 | 27 |

28 | 29 | 30 |

31 | 32 | 33 | # Installation 34 | 35 | * Python 3.5+ 36 | * Linux 37 | * NVIDIA GPU + CUDA CuDNN 38 | * Pytorch (>=1.5) 39 | * [Pytorch3d](https://github.com/facebookresearch/pytorch3d) 40 | * [DECA](https://github.com/YadiraF/DECA) 41 | 42 | We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/). 43 | 44 | ``` 45 | conda create -n python38 python=3.8 46 | conda activate python38 47 | conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=11.0 -c pytorch 48 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 49 | conda install pytorch3d -c pytorch3d 50 | pip install -r requirements.txt 51 | 52 | ``` 53 | 54 | # Pretrained Models 55 | 56 | We provide a StyleGAN2 model trained using [StyleGAN2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch) and an [e4e](https://github.com/omertov/encoder4editing) inversion model trained on [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) dataset. 57 | 58 | 59 | | Path | Description 60 | | :--- | :---------- 61 | |[StyleGAN2-VoxCeleb1](https://drive.google.com/file/d/1cBwIFwq6cYIA5iR8tEvj6BIL7Ji7azIH/view?usp=sharing) | StyleGAN2 trained on VoxCeleb1 dataset. 62 | |[e4e-VoxCeleb1](https://drive.google.com/file/d/1TRATaREBi4VCMITUZV0ZO2XFU3YZKGlQ/view?usp=share_link) | e4e trained on VoxCeleb1 dataset. 63 | 64 | 65 | # Auxiliary Models 66 | 67 | We provide additional auxiliary models needed during training. 68 | 69 | | Path | Description 70 | | :--- | :---------- 71 | |[face-detector](https://drive.google.com/file/d/1IWqJUTAZCelAZrUzfU38zK_ZM25fK32S/view?usp=share_link) | Pretrained face detector taken from [face-alignment](https://github.com/1adrianb/face-alignment). 72 | |[IR-SE50 Model](https://drive.google.com/file/d/1s5pWag4AwqQyhue6HH-M_f2WDV4IVZEl/view?usp=sharing) | Pretrained IR-SE50 model taken from [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our identity loss. 73 | |[DECA model](https://drive.google.com/file/d/1BHVJAEXscaXMj_p2rOsHYF_vaRRRHQbA/view?usp=sharing) | Pretrained model taken from [DECA](https://github.com/YadiraF/DECA). Extract data.tar.gz under `./libs/DECA/`. 74 | 75 | 76 | 77 | By default, we assume that all pretrained models are downloaded and saved to the directory `./pretrained_models`. 78 | 79 | 80 | # Preparing your Data 81 | 82 | 1. Download and preprocess the VoxCeleb dataset using [VoxCeleb_preprocessing](https://github.com/StelaBou/voxceleb_preprocessing). 83 | 84 | 85 | 2. Invert real images into the latent space of the pretrained StyleGAN2 using the [Encoder4Editing](https://arxiv.org/abs/2102.02766) method. 86 | ``` 87 | python invert_images.py --input_path path/to/voxdataset 88 | ``` 89 | 90 | The dataset is saved as: 91 | 92 | ``` 93 | .path/to/voxdataset 94 | |-- id10271 # identity index 95 | | |-- 37nktPRUJ58 # video index 96 | | | |-- frames_cropped # preprocessed frames 97 | | | | |-- 00_000025.png 98 | | | | |-- ... 99 | | | |-- inversion 100 | | | | |-- frames # inverted frames 101 | | | | | |-- 00_000025.png 102 | | | | | |-- .. 103 | | | | |-- latent_codes # inverted latent_codes 104 | | | | | |-- 00_000025.npy 105 | | | | | |-- .. 106 | | |-- Zjc7Xy7aT8c 107 | | | | ... 108 | |-- id10273 109 | | | ... 110 | ``` 111 | 112 | The correct preprocessing of the dataset is important to reenact the images. Different preprocessing will lead in poor performance. 113 | Example: 114 |

115 | 116 |
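For reference, the following minimal sketch (not part of the original repository) shows how a single preprocessed frame and its inverted latent code can be read back from the directory layout above. The identity/video folder names are the hypothetical ones from the example tree, and the `inject_index x 512` shape check mirrors the assumption made by the training dataloader in `libs/datasets/dataloader.py`.

```
import os, glob
import numpy as np
from PIL import Image

# Hypothetical video folder, following the example tree above
video_dir = 'path/to/voxdataset/id10271/37nktPRUJ58'

frames = sorted(glob.glob(os.path.join(video_dir, 'frames_cropped', '*.png')))
codes = sorted(glob.glob(os.path.join(video_dir, 'inversion', 'latent_codes', '*.npy')))

image = Image.open(frames[0]).convert('RGB')   # preprocessed (cropped) frame
w = np.load(codes[0])                          # inverted latent code for that frame
assert w.ndim == 2 and w.shape[1] == 512       # inject_index x 512, as the dataloader expects
```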

117 | 118 | # Training 119 | 120 | To train our model, make sure to download and save the required models under `./pretrained_models` path and that the training and testing data are configured as described above. Please check `run_trainer.py` and `./libs/configs/config_arguments.py` for the training arguments. 121 | 122 | Example of training using paired data: 123 | ``` 124 | python run_trainer.py \ 125 | --experiment_path ./training_attempts/exp_v00 \ 126 | --train_dataset_path path_to_training_dataset \ 127 | --test_dataset_path path_to_test_dataset \ 128 | --training_method paired 129 | ``` 130 | 131 | 132 | # Inference 133 | Download our pretrained model [A-matrix](https://drive.google.com/file/d/11aNalhBnPREFQT9i9wmQE0-fkQac4SzQ/view?usp=share_link) and save it under `./pretrained_models` path. 134 | 135 | ## Facial image editing: 136 | 137 | Given as input an image or a latent code, change only one facial attribute that corresponds to one of our learned directions. 138 | ``` 139 | python run_facial_editing.py \ 140 | --source_path ./inference_examples/0002775.png \ 141 | --output_path ./results/facial_editing \ 142 | --directions 0 1 2 3 4 \ 143 | --save_gif \ 144 | --optimize_generator 145 | ``` 146 | 147 | ## Face reenactment (self or cross): 148 | 149 | Given as input a source identity and a target video, reenact the source face. The source and target faces could have the same identity or different identity. 150 | ``` 151 | python run_inference.py \ 152 | --source_path ./inference_examples/0002775.png \ 153 | --target_path ./inference_examples/lWOTF8SdzJw#2614-2801.mp4 \ 154 | --output_path ./results/ \ 155 | --save_video 156 | ``` 157 | 158 | 159 | # Citation 160 | 161 | [1] Stella Bounareli, Argyriou Vasileios and Georgios Tzimiropoulos. Finding Directions in GAN's Latent Space for Neural Face Reenactment. 162 | 163 | Bibtex entry: 164 | 165 | ```bibtex 166 | @article{bounareli2022finding, 167 | title={Finding Directions in GAN's Latent Space for Neural Face Reenactment}, 168 | author={Bounareli, Stella and Argyriou, Vasileios and Tzimiropoulos, Georgios}, 169 | journal={British Machine Vision Conference (BMVC)}, 170 | year={2022} 171 | } 172 | 173 | ``` 174 | -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import pickle 5 | import math 6 | 7 | import torch 8 | import numpy as np 9 | from torchvision import utils 10 | 11 | from models.StyleGAN2.model import Generator, Discriminator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + '/weight'].value().eval() 16 | mod_weight = vars[source_name + '/mod_weight'].value().eval() 17 | mod_bias = vars[source_name + '/mod_bias'].value().eval() 18 | noise = vars[source_name + '/noise_strength'].value().eval() 19 | bias = vars[source_name + '/bias'].value().eval() 20 | 21 | dic = { 22 | 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | 'conv.modulation.weight': mod_weight.transpose((1, 0)), 24 | 'conv.modulation.bias': mod_bias + 1, 25 | 'noise.weight': np.array([noise]), 26 | 'activate.bias': bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + '.' 
+ k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + '.conv.weight'] = torch.flip( 36 | dic_torch[target_name + '.conv.weight'], [3, 4] 37 | ) 38 | 39 | return dic_torch 40 | 41 | 42 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 43 | weight = vars[source_name + '/weight'].value().eval() 44 | 45 | dic = {'weight': weight.transpose((3, 2, 0, 1))} 46 | 47 | if bias: 48 | dic['bias'] = vars[source_name + '/bias'].value().eval() 49 | 50 | dic_torch = {} 51 | dic_torch[target_name + '.{}.weight'.format(start)] = torch.from_numpy(dic['weight']) 52 | 53 | if bias: 54 | dic_torch[target_name + '.{}.bias'.format(start + 1)] = torch.from_numpy(dic['bias']) 55 | 56 | return dic_torch 57 | 58 | 59 | def convert_torgb(vars, source_name, target_name): 60 | weight = vars[source_name + '/weight'].value().eval() 61 | mod_weight = vars[source_name + '/mod_weight'].value().eval() 62 | mod_bias = vars[source_name + '/mod_bias'].value().eval() 63 | bias = vars[source_name + '/bias'].value().eval() 64 | 65 | dic = { 66 | 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 67 | 'conv.modulation.weight': mod_weight.transpose((1, 0)), 68 | 'conv.modulation.bias': mod_bias + 1, 69 | 'bias': bias.reshape((1, 3, 1, 1)), 70 | } 71 | 72 | dic_torch = {} 73 | 74 | for k, v in dic.items(): 75 | dic_torch[target_name + '.' + k] = torch.from_numpy(v) 76 | 77 | return dic_torch 78 | 79 | 80 | def convert_dense(vars, source_name, target_name): 81 | weight = vars[source_name + '/weight'].value().eval() 82 | bias = vars[source_name + '/bias'].value().eval() 83 | 84 | dic = {'weight': weight.transpose((1, 0)), 'bias': bias} 85 | 86 | dic_torch = {} 87 | 88 | for k, v in dic.items(): 89 | dic_torch[target_name + '.' + k] = torch.from_numpy(v) 90 | 91 | return dic_torch 92 | 93 | 94 | def update(state_dict, new): 95 | for k, v in new.items(): 96 | if k not in state_dict: 97 | raise KeyError(k + ' is not found') 98 | 99 | if v.shape != state_dict[k].shape: 100 | raise ValueError('Shape mismatch: {} vs {}'.format(v.shape, state_dict[k].shape)) 101 | 102 | state_dict[k] = v 103 | 104 | 105 | def discriminator_fill_statedict(statedict, vars, size): 106 | log_size = int(math.log(size, 2)) 107 | 108 | update(statedict, convert_conv(vars, '{}x{}/FromRGB'.format(size, size), 'convs.0')) 109 | 110 | conv_i = 1 111 | 112 | for i in range(log_size - 2, 0, -1): 113 | reso = 4 * 2 ** i 114 | update(statedict, convert_conv(vars, '{}x{}/Conv0'.format(reso, reso), 'convs.{}.conv1'.format(conv_i))) 115 | update(statedict, convert_conv(vars, '{}x{}/Conv1_down'.format(reso, reso), 'convs.{}.conv2'.format(conv_i), start=1)) 116 | update(statedict, convert_conv(vars, '{}x{}/Skip'.format(reso, reso), 'convs.{}.skip'.format(conv_i), start=1, bias=False)) 117 | conv_i += 1 118 | 119 | update(statedict, convert_conv(vars, '4x4/Conv', 'final_conv')) 120 | update(statedict, convert_dense(vars, '4x4/Dense0', 'final_linear.0')) 121 | update(statedict, convert_dense(vars, 'Output', 'final_linear.1')) 122 | 123 | return statedict 124 | 125 | 126 | def fill_statedict(state_dict, vars, size): 127 | log_size = int(math.log(size, 2)) 128 | 129 | for i in range(8): 130 | update(state_dict, convert_dense(vars, 'G_mapping/Dense{}'.format(i), 'style.{}'.format(i + 1))) 131 | 132 | update( 133 | state_dict, 134 | { 135 | 'input.input': torch.from_numpy( 136 | vars['G_synthesis/4x4/Const/const'].value().eval() 137 | ) 138 | }, 139 | ) 140 | 141 | update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 
'to_rgb1')) 142 | 143 | for i in range(log_size - 2): 144 | reso = 4 * 2 ** (i + 1) 145 | update( 146 | state_dict, 147 | convert_torgb(vars, 'G_synthesis/{}x{}/ToRGB'.format(reso, reso), 'to_rgbs.{}'.format(i)), 148 | ) 149 | 150 | update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1')) 151 | 152 | conv_i = 0 153 | 154 | for i in range(log_size - 2): 155 | reso = 4 * 2 ** (i + 1) 156 | update( 157 | state_dict, 158 | convert_modconv( 159 | vars, 160 | 'G_synthesis/{}x{}/Conv0_up'.format(reso, reso), 161 | 'convs.{}'.format(conv_i), 162 | flip=True, 163 | ), 164 | ) 165 | update( 166 | state_dict, 167 | convert_modconv( 168 | vars, 169 | 'G_synthesis/{}x{}/Conv1'.format(reso, reso), 170 | 'convs.{}'.format(conv_i + 1) 171 | ), 172 | ) 173 | conv_i += 2 174 | 175 | for i in range(0, (log_size - 2) * 2 + 1): 176 | update( 177 | state_dict, 178 | { 179 | 'noises.noise_{}'.format(i): torch.from_numpy( 180 | vars['G_synthesis/noise{}'.format(i)].value().eval() 181 | ) 182 | }, 183 | ) 184 | 185 | return state_dict 186 | 187 | 188 | if __name__ == '__main__': 189 | device = 'cuda' 190 | 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('--repo', type=str, required=True) 193 | parser.add_argument('--gen', action='store_true') 194 | parser.add_argument('--disc', action='store_true') 195 | parser.add_argument('path', metavar='PATH') 196 | 197 | args = parser.parse_args() 198 | 199 | sys.path.append(args.repo) 200 | 201 | import dnnlib 202 | from dnnlib import tflib 203 | 204 | tflib.init_tf() 205 | 206 | with open(args.path, 'rb') as f: 207 | generator, discriminator, g_ema = pickle.load(f) 208 | 209 | size = g_ema.output_shape[2] 210 | 211 | g = Generator(size, 512, 8) 212 | state_dict = g.state_dict() 213 | state_dict = fill_statedict(state_dict, g_ema.vars, size) 214 | 215 | g.load_state_dict(state_dict) 216 | 217 | latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval()) 218 | 219 | ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg} 220 | 221 | if args.gen: 222 | g_train = Generator(size, 512, 8) 223 | g_train_state = g_train.state_dict() 224 | g_train_state = fill_statedict(g_train_state, generator.vars, size) 225 | ckpt['g'] = g_train_state 226 | 227 | if args.disc: 228 | disc = Discriminator(size) 229 | d_state = disc.state_dict() 230 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 231 | ckpt['d'] = d_state 232 | 233 | name = os.path.splitext(os.path.basename(args.path))[0] 234 | torch.save(ckpt, name + '.pt') 235 | 236 | batch_size = {256: 16, 512: 9, 1024: 4} 237 | n_sample = batch_size.get(size, 25) 238 | 239 | g = g.to(device) 240 | 241 | z = np.random.RandomState(0).randn(n_sample, 512).astype('float32') 242 | 243 | with torch.no_grad(): 244 | img_pt, _ = g([torch.from_numpy(z).to(device)], truncation=0.5, truncation_latent=latent_avg.to(device)) 245 | 246 | Gs_kwargs = dnnlib.EasyDict() 247 | Gs_kwargs.randomize_noise = False 248 | img_tf = g_ema.run(z, None, **Gs_kwargs) 249 | img_tf = torch.from_numpy(img_tf).to(device) 250 | 251 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(0.0, 1.0) 252 | 253 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 254 | utils.save_image(img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)) 255 | 256 | -------------------------------------------------------------------------------- /libs/datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
CustomDataset: Custom Dataloader for real images using VoxCeleb dataset 3 | CustomDataset_testset_synthetic: Custom Dataloader for synthetic images for evaluation 4 | CustomDataset_testset_real: Custom Dataloader for real images for evaluation 5 | """ 6 | import torch 7 | import os 8 | import glob 9 | import cv2 10 | import numpy as np 11 | from torchvision import transforms, utils 12 | from PIL import Image 13 | from torch.utils.data import Dataset 14 | 15 | from libs.utilities.utils import make_noise 16 | 17 | np.random.seed(0) 18 | 19 | class CustomDataset(Dataset): 20 | 21 | def __init__(self, dataset_path): 22 | """ 23 | VoxCeleb dataset format: id_index/video_index/frames_inverted/latent_codes 24 | Args: 25 | dataset_path (string): Path to voxceleb dataset 26 | """ 27 | 28 | self.dataset_path = dataset_path 29 | self.get_dataset() 30 | 31 | self.transform = transforms.Compose([ 32 | transforms.Resize((256, 256)), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 35 | 36 | 37 | def get_dataset(self): 38 | 39 | ids_path = glob.glob(os.path.join(self.dataset_path, '*/')) 40 | ids_path.sort() 41 | if len(ids_path) == 0: 42 | print('Dataset has no identities in path {}'.format(self.dataset_path)) 43 | exit() 44 | 45 | real_images = None; inv_images = None; w = None 46 | counter_ids = 0; counter_videos = 0 47 | samples = [] 48 | for i, ids in enumerate(ids_path): 49 | 50 | id_index = ids.split('/')[-2] 51 | videos_path = glob.glob(os.path.join(ids, '*/')) 52 | videos_path.sort() 53 | 54 | for j, video_path in enumerate(videos_path): 55 | video_index = video_path.split('/')[-2] 56 | 57 | if not os.path.exists(os.path.join(video_path, 'inversion')): 58 | print('Path with inverted latent codes does not exist.') 59 | exit() 60 | inv_images_path = glob.glob(os.path.join(video_path, 'inversion', 'frames', '*.png')) # Inverted 61 | inv_images_path.sort() 62 | codes_path = glob.glob(os.path.join(video_path, 'inversion', 'latent_codes', '*.npy')) 63 | codes_path.sort() 64 | real_images_path = glob.glob(os.path.join(video_path, 'frames_cropped', '*.png')) 65 | real_images_path.sort() 66 | 67 | dict_sample = { 68 | 'id_index': id_index, 69 | 'video_index': video_index, 70 | 'real_images': real_images_path, 71 | 'codes': codes_path, 72 | 'inv_images': inv_images_path, 73 | } 74 | 75 | if real_images is None: 76 | real_images = real_images_path 77 | w = codes_path 78 | inv_images = inv_images_path 79 | else: 80 | real_images = np.concatenate((real_images, real_images_path), axis=0) 81 | w = np.concatenate((w, codes_path), axis=0) 82 | inv_images = np.concatenate((inv_images, inv_images_path), axis=0) 83 | counter_videos += 1 84 | samples.append(dict_sample) 85 | 86 | counter_ids += 1 87 | 88 | real_images = np.asarray(real_images) 89 | w = np.asarray(w) 90 | inv_images = np.asarray(inv_images) 91 | 92 | self.real_images = real_images 93 | self.inv_images = inv_images 94 | self.w = w 95 | self.counter_ids = counter_ids 96 | self.counter_videos = counter_videos 97 | 98 | def get_length(self, train = True): 99 | return len(self.real_images), self.counter_ids, self.counter_videos 100 | 101 | def __len__(self): 102 | return len(self.real_images) 103 | 104 | def __getitem__(self, index): 105 | 106 | real_image_path = self.real_images[index] 107 | real_img = Image.open(real_image_path) 108 | real_img = real_img.convert('RGB') 109 | real_img = self.transform(real_img) 110 | 111 | inv_image_path = self.inv_images[index] 112 | inv_img = Image.open(inv_image_path) 113 | 
inv_img = inv_img.convert('RGB') 114 | inv_img = self.transform(inv_img) 115 | 116 | w_file = self.w[index] 117 | latent_code = np.load(w_file) 118 | latent_code = torch.from_numpy(latent_code) 119 | assert latent_code.ndim == 2, 'latent code dimensions should be inject_index x 512 while now is {}'.format(latent_code.shape) 120 | 121 | sample = { 122 | 'real_img': real_img, 123 | 'inv_img': inv_img, 124 | 'w': latent_code, 125 | } 126 | return sample 127 | 128 | class CustomDataset_testset_synthetic(Dataset): 129 | 130 | def __init__(self, synthetic_dataset_path = None, num_samples = None, shuffle = True): 131 | """ 132 | VoxCeleb dataset format: id_index/video_index/frames_inverted/latent_codes 133 | Args: 134 | synthetic_dataset_path: path to synthetic latent codes. If None generate random 135 | num_samples: how many samples for validation 136 | 137 | """ 138 | self.shuffle = shuffle 139 | self.num_samples = num_samples 140 | self.synthetic_dataset_path = synthetic_dataset_path 141 | 142 | 143 | if self.synthetic_dataset_path is not None: 144 | z_codes = np.load(self.synthetic_dataset_path) 145 | z_codes = torch.from_numpy(z_codes) 146 | self.fixed_source_w = z_codes[:self.num_samples, :] 147 | self.fixed_target_w = z_codes[self.num_samples:2*self.num_samples, :] 148 | else: 149 | self.fixed_source_w = make_noise(self.num_samples, 512, None) 150 | self.fixed_target_w = make_noise(self.num_samples, 512, None) 151 | # Save random generated latent codes 152 | save_path = './libs/configs/random_latent_codes_{}.npy'.format(2*self.num_samples) 153 | z_codes = torch.cat((self.fixed_source_w, self.fixed_target_w), dim = 0) 154 | z_codes = z_codes.detach().cpu().numpy() 155 | np.save(save_path, z_codes) 156 | 157 | self.transform = transforms.Compose([ 158 | transforms.Resize((256, 256)), 159 | transforms.ToTensor(), 160 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 161 | 162 | 163 | def __len__(self): 164 | return self.num_samples 165 | 166 | def __getitem__(self, index): 167 | 168 | source_w = self.fixed_source_w[index] 169 | target_w = self.fixed_target_w[index] 170 | sample = { 171 | 'source_w': source_w, 172 | 'target_w': target_w 173 | } 174 | return sample 175 | 176 | class CustomDataset_testset_real(Dataset): 177 | 178 | def __init__(self, dataset_path, suffle = True, num_samples = None): 179 | """ 180 | VoxCeleb dataset format: id_index/video_index/frames_inverted/latent_codes 181 | Args: 182 | dataset_path (string): Path to voxceleb dataset 183 | num_samples: how many samples for validation 184 | """ 185 | self.num_samples = num_samples 186 | self.dataset_path = dataset_path 187 | self.suffle = suffle 188 | 189 | self.get_dataset() 190 | self.fixed_target_w = make_noise(self.num_samples, 512, None) 191 | 192 | self.transform = transforms.Compose([ 193 | transforms.Resize((256, 256)), 194 | transforms.ToTensor(), 195 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 196 | 197 | def get_dataset(self): 198 | 199 | ids_path = glob.glob(os.path.join(self.dataset_path, '*/')) 200 | ids_path.sort() 201 | if len(ids_path) == 0: 202 | print('Dataset has no identities in path {}'.format(self.dataset_path)) 203 | exit() 204 | 205 | real_images = None; inv_images = None; w = None 206 | counter_ids = 0; counter_videos = 0 207 | samples = [] 208 | for i, ids in enumerate(ids_path): 209 | 210 | id_index = ids.split('/')[-2] 211 | videos_path = glob.glob(os.path.join(ids, '*/')) 212 | videos_path.sort() 213 | 214 | for j, video_path in enumerate(videos_path): 215 | video_index = 
video_path.split('/')[-2] 216 | if not os.path.exists(os.path.join(video_path, 'inversion')): 217 | print('Path with inverted latent codes does not exist.') 218 | exit() 219 | codes_path = glob.glob(os.path.join(video_path, 'inversion', 'latent_codes', '*.npy')) 220 | codes_path.sort() 221 | if w is None: 222 | w = codes_path 223 | else: 224 | w = np.concatenate((w, codes_path), axis=0) 225 | counter_videos += 1 226 | counter_ids += 1 227 | 228 | self.w = w 229 | self.counter_ids = counter_ids 230 | self.counter_videos = counter_videos 231 | 232 | self.w = np.asarray(self.w) 233 | if self.suffle: 234 | r = np.random.permutation(len(self.w)) 235 | self.w = self.w[r.astype(int)] 236 | 237 | if self.num_samples < len(self.w): 238 | self.w = self.w[:self.num_samples] 239 | 240 | def get_length(self): 241 | return self.num_samples 242 | 243 | def __len__(self): 244 | return self.num_samples 245 | 246 | def __getitem__(self, index): 247 | 248 | w_file = self.w[index] 249 | source_w = np.load(w_file) 250 | source_w = torch.from_numpy(source_w) 251 | assert source_w.ndim == 2, 'latent code dimensions should be inject_index x 512 while now is {}'.format(source_w.shape) 252 | 253 | target_w = self.fixed_target_w[index] 254 | sample = { 255 | 'source_w': source_w, 256 | 'target_w': target_w 257 | } 258 | return sample 259 | -------------------------------------------------------------------------------- /libs/gan/encoder4editing/psp_encoders.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 7 | 8 | from libs.gan.encoder4editing.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add 9 | from libs.gan.StyleGAN2.model import EqualLinear 10 | 11 | class ProgressiveStage(Enum): 12 | WTraining = 0 13 | Delta1Training = 1 14 | Delta2Training = 2 15 | Delta3Training = 3 16 | Delta4Training = 4 17 | Delta5Training = 5 18 | Delta6Training = 6 19 | Delta7Training = 7 20 | Delta8Training = 8 21 | Delta9Training = 9 22 | Delta10Training = 10 23 | Delta11Training = 11 24 | Delta12Training = 12 25 | Delta13Training = 13 26 | Delta14Training = 14 27 | Delta15Training = 15 28 | Delta16Training = 16 29 | Delta17Training = 17 30 | Inference = 18 31 | 32 | 33 | class GradualStyleBlock(Module): 34 | def __init__(self, in_c, out_c, spatial): 35 | super(GradualStyleBlock, self).__init__() 36 | self.out_c = out_c 37 | self.spatial = spatial 38 | num_pools = int(np.log2(spatial)) 39 | modules = [] 40 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 41 | nn.LeakyReLU()] 42 | for i in range(num_pools - 1): 43 | modules += [ 44 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 45 | nn.LeakyReLU() 46 | ] 47 | self.convs = nn.Sequential(*modules) 48 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 49 | 50 | def forward(self, x): 51 | x = self.convs(x) 52 | x = x.view(-1, self.out_c) 53 | x = self.linear(x) 54 | return x 55 | 56 | 57 | class GradualStyleEncoder(Module): 58 | def __init__(self, num_layers, mode='ir', image_resolution = 256 ): 59 | super(GradualStyleEncoder, self).__init__() 60 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 61 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 62 | blocks = get_blocks(num_layers) 63 | if mode == 'ir': 64 | unit_module = bottleneck_IR 65 | elif mode == 'ir_se': 66 | unit_module = 
bottleneck_IR_SE 67 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 68 | BatchNorm2d(64), 69 | PReLU(64)) 70 | modules = [] 71 | for block in blocks: 72 | for bottleneck in block: 73 | modules.append(unit_module(bottleneck.in_channel, 74 | bottleneck.depth, 75 | bottleneck.stride)) 76 | self.body = Sequential(*modules) 77 | 78 | self.styles = nn.ModuleList() 79 | log_size = int(math.log(image_resolution, 2)) 80 | self.style_count = 2 * log_size - 2 81 | self.coarse_ind = 3 82 | self.middle_ind = 7 83 | for i in range(self.style_count): 84 | if i < self.coarse_ind: 85 | style = GradualStyleBlock(512, 512, 16) 86 | elif i < self.middle_ind: 87 | style = GradualStyleBlock(512, 512, 32) 88 | else: 89 | style = GradualStyleBlock(512, 512, 64) 90 | self.styles.append(style) 91 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 92 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 93 | 94 | def forward(self, x): 95 | x = self.input_layer(x) 96 | 97 | latents = [] 98 | modulelist = list(self.body._modules.values()) 99 | for i, l in enumerate(modulelist): 100 | x = l(x) 101 | if i == 6: 102 | c1 = x 103 | elif i == 20: 104 | c2 = x 105 | elif i == 23: 106 | c3 = x 107 | 108 | for j in range(self.coarse_ind): 109 | latents.append(self.styles[j](c3)) 110 | 111 | p2 = _upsample_add(c3, self.latlayer1(c2)) 112 | for j in range(self.coarse_ind, self.middle_ind): 113 | latents.append(self.styles[j](p2)) 114 | 115 | p1 = _upsample_add(p2, self.latlayer2(c1)) 116 | for j in range(self.middle_ind, self.style_count): 117 | latents.append(self.styles[j](p1)) 118 | 119 | out = torch.stack(latents, dim=1) 120 | return out 121 | 122 | class Encoder4Editing(Module): 123 | def __init__(self, num_layers, mode='ir', image_resolution = 256 ): 124 | super(Encoder4Editing, self).__init__() 125 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 126 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 127 | blocks = get_blocks(num_layers) 128 | if mode == 'ir': 129 | unit_module = bottleneck_IR 130 | elif mode == 'ir_se': 131 | unit_module = bottleneck_IR_SE 132 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 133 | BatchNorm2d(64), 134 | PReLU(64)) 135 | modules = [] 136 | for block in blocks: 137 | for bottleneck in block: 138 | modules.append(unit_module(bottleneck.in_channel, 139 | bottleneck.depth, 140 | bottleneck.stride)) 141 | self.body = Sequential(*modules) 142 | 143 | self.styles = nn.ModuleList() 144 | log_size = int(math.log(image_resolution, 2)) 145 | self.style_count = 2 * log_size - 2 146 | self.coarse_ind = 3 147 | self.middle_ind = 7 148 | 149 | for i in range(self.style_count): 150 | if i < self.coarse_ind: 151 | style = GradualStyleBlock(512, 512, 16) 152 | elif i < self.middle_ind: 153 | style = GradualStyleBlock(512, 512, 32) 154 | else: 155 | style = GradualStyleBlock(512, 512, 64) 156 | self.styles.append(style) 157 | 158 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 159 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 160 | 161 | self.progressive_stage = ProgressiveStage.Inference 162 | 163 | def get_deltas_starting_dimensions(self): 164 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 165 | return list(range(self.style_count)) # Each dimension has a delta applied to it 166 | 167 | def set_progressive_stage(self, new_stage: ProgressiveStage): 168 | self.progressive_stage = 
new_stage 169 | print('Changed progressive stage to: ', new_stage) 170 | 171 | def forward(self, x): 172 | x = self.input_layer(x) 173 | 174 | modulelist = list(self.body._modules.values()) 175 | for i, l in enumerate(modulelist): 176 | x = l(x) 177 | if i == 6: 178 | c1 = x 179 | elif i == 20: 180 | c2 = x 181 | elif i == 23: 182 | c3 = x 183 | 184 | # Infer main W and duplicate it 185 | w0 = self.styles[0](c3) 186 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 187 | stage = self.progressive_stage.value 188 | 189 | features = c3 190 | for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas 191 | if i == self.coarse_ind: 192 | p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features 193 | features = p2 194 | elif i == self.middle_ind: 195 | p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features 196 | features = p1 197 | delta_i = self.styles[i](features) 198 | w[:, i] += delta_i 199 | return w 200 | 201 | class BackboneEncoderUsingLastLayerIntoW(Module): 202 | def __init__(self, num_layers, mode='ir', image_resolution = 256): 203 | super(BackboneEncoderUsingLastLayerIntoW, self).__init__() 204 | print('Using BackboneEncoderUsingLastLayerIntoW') 205 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 206 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 207 | blocks = get_blocks(num_layers) 208 | if mode == 'ir': 209 | unit_module = bottleneck_IR 210 | elif mode == 'ir_se': 211 | unit_module = bottleneck_IR_SE 212 | input_nc = 3 213 | self.input_layer = Sequential(Conv2d(input_nc, 64, (3, 3), 1, 1, bias=False), 214 | BatchNorm2d(64), 215 | PReLU(64)) 216 | self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) 217 | self.linear = EqualLinear(512, 512, lr_mul=1) 218 | modules = [] 219 | for block in blocks: 220 | for bottleneck in block: 221 | modules.append(unit_module(bottleneck.in_channel, 222 | bottleneck.depth, 223 | bottleneck.stride)) 224 | self.body = Sequential(*modules) 225 | 226 | def forward(self, x): 227 | x = self.input_layer(x) 228 | x = self.body(x) 229 | x = self.output_pool(x) 230 | x = x.view(-1, 512) 231 | x = self.linear(x) 232 | return x -------------------------------------------------------------------------------- /libs/face_models/fan_model/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 10 | stride=strd, padding=padding, bias=bias) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes): 15 | super(ConvBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 22 | 23 | if in_planes != out_planes: 24 | self.downsample = nn.Sequential( 25 | nn.BatchNorm2d(in_planes), 26 | nn.ReLU(True), 27 | nn.Conv2d(in_planes, out_planes, 28 | kernel_size=1, stride=1, bias=False), 29 | ) 30 | else: 31 | self.downsample = None 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out1 = self.bn1(x) 37 | out1 = F.relu(out1, True) 38 | out1 = 
self.conv1(out1) 39 | 40 | out2 = self.bn2(out1) 41 | out2 = F.relu(out2, True) 42 | out2 = self.conv2(out2) 43 | 44 | out3 = self.bn3(out2) 45 | out3 = F.relu(out3, True) 46 | out3 = self.conv3(out3) 47 | 48 | out3 = torch.cat((out1, out2, out3), 1) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(residual) 52 | 53 | out3 += residual 54 | 55 | return out3 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class HourGlass(nn.Module): 99 | def __init__(self, num_modules, depth, num_features): 100 | super(HourGlass, self).__init__() 101 | self.num_modules = num_modules 102 | self.depth = depth 103 | self.features = num_features 104 | 105 | self._generate_network(self.depth) 106 | 107 | def _generate_network(self, level): 108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) 109 | 110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) 111 | 112 | if level > 1: 113 | self._generate_network(level - 1) 114 | else: 115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) 116 | 117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) 118 | 119 | def _forward(self, level, inp): 120 | # Upper branch 121 | up1 = inp 122 | up1 = self._modules['b1_' + str(level)](up1) 123 | 124 | # Lower branch 125 | low1 = F.avg_pool2d(inp, 2, stride=2) 126 | low1 = self._modules['b2_' + str(level)](low1) 127 | 128 | if level > 1: 129 | low2 = self._forward(level - 1, low1) 130 | else: 131 | low2 = low1 132 | low2 = self._modules['b2_plus_' + str(level)](low2) 133 | 134 | low3 = low2 135 | low3 = self._modules['b3_' + str(level)](low3) 136 | 137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 138 | 139 | return up1 + up2 140 | 141 | def forward(self, x): 142 | return self._forward(self.depth, x) 143 | 144 | 145 | class FAN(nn.Module): 146 | 147 | def __init__(self, num_modules=1): 148 | super(FAN, self).__init__() 149 | self.num_modules = num_modules 150 | 151 | # Base part 152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 153 | self.bn1 = nn.BatchNorm2d(64) 154 | self.conv2 = ConvBlock(64, 128) 155 | self.conv3 = ConvBlock(128, 128) 156 | self.conv4 = ConvBlock(128, 256) 157 | 158 | # Stacking part 159 | for hg_module in range(self.num_modules): 160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 162 | self.add_module('conv_last' + str(hg_module), 
163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 165 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 166 | 68, kernel_size=1, stride=1, padding=0)) 167 | 168 | if hg_module < self.num_modules - 1: 169 | self.add_module( 170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 171 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 172 | 256, kernel_size=1, stride=1, padding=0)) 173 | 174 | def forward(self, x): 175 | x = F.relu(self.bn1(self.conv1(x)), True) 176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 177 | x = self.conv3(x) 178 | x = self.conv4(x) 179 | 180 | previous = x 181 | 182 | outputs = [] 183 | for i in range(self.num_modules): 184 | hg = self._modules['m' + str(i)](previous) 185 | 186 | ll = hg 187 | ll = self._modules['top_m_' + str(i)](ll) 188 | 189 | ll = F.relu(self._modules['bn_end' + str(i)] 190 | (self._modules['conv_last' + str(i)](ll)), True) 191 | 192 | # Predict heatmaps 193 | tmp_out = self._modules['l' + str(i)](ll) 194 | outputs.append(tmp_out) 195 | 196 | if i < self.num_modules - 1: 197 | ll = self._modules['bl' + str(i)](ll) 198 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 199 | previous = previous + ll + tmp_out_ 200 | 201 | # x.register_hook(lambda grad: print('images',grad)) 202 | return outputs 203 | 204 | 205 | class ResNetDepth(nn.Module): 206 | 207 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): 208 | self.inplanes = 64 209 | super(ResNetDepth, self).__init__() 210 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, 211 | bias=False) 212 | self.bn1 = nn.BatchNorm2d(64) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 215 | self.layer1 = self._make_layer(block, 64, layers[0]) 216 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 217 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 218 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 219 | self.avgpool = nn.AvgPool2d(7) 220 | self.fc = nn.Linear(512 * block.expansion, num_classes) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 225 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 226 | elif isinstance(m, nn.BatchNorm2d): 227 | m.weight.data.fill_(1) 228 | m.bias.data.zero_() 229 | 230 | def _make_layer(self, block, planes, blocks, stride=1): 231 | downsample = None 232 | if stride != 1 or self.inplanes != planes * block.expansion: 233 | downsample = nn.Sequential( 234 | nn.Conv2d(self.inplanes, planes * block.expansion, 235 | kernel_size=1, stride=stride, bias=False), 236 | nn.BatchNorm2d(planes * block.expansion), 237 | ) 238 | 239 | layers = [] 240 | layers.append(block(self.inplanes, planes, stride, downsample)) 241 | self.inplanes = planes * block.expansion 242 | for i in range(1, blocks): 243 | layers.append(block(self.inplanes, planes)) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def forward(self, x): 248 | # print(x.shape) 249 | # x.register_hook(lambda grad: print('images',grad)) 250 | x = self.conv1(x) 251 | x = self.bn1(x) 252 | x = self.relu(x) 253 | x = self.maxpool(x) 254 | 255 | x = self.layer1(x) 256 | x = self.layer2(x) 257 | x = self.layer3(x) 258 | x = self.layer4(x) 259 | 260 | x = self.avgpool(x) 261 | x = x.view(x.size(0), -1) 262 | x = self.fc(x) 263 | 264 | 265 | return x 266 | -------------------------------------------------------------------------------- /libs/gan/StyleGAN2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < 
p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 
&& p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( 229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>( 236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( 243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>( 250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>( 257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>( 264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /libs/DECA/decalib/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Soubhik Sanyal 3 | Copyright (c) 2019, Soubhik Sanyal 4 | All rights reserved. 
5 | Loads different resnet models 6 | """ 7 | ''' 8 | file: Resnet.py 9 | date: 2018_05_02 10 | author: zhangxiong(1025679612@qq.com) 11 | mark: copied from pytorch source code 12 | ''' 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch 17 | from torch.nn.parameter import Parameter 18 | import torch.optim as optim 19 | import numpy as np 20 | import math 21 | import torchvision 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layers, num_classes=1000): 25 | self.inplanes = 64 26 | super(ResNet, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | self.layer1 = self._make_layer(block, 64, layers[0]) 33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 36 | self.avgpool = nn.AvgPool2d(7, stride=1) 37 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2. / n)) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample)) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | 70 | x = self.layer1(x) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x1 = self.layer4(x) 74 | 75 | x2 = self.avgpool(x1) 76 | x2 = x2.view(x2.size(0), -1) 77 | # x = self.fc(x) 78 | ## x2: [bz, 2048] for shape 79 | ## x1: [bz, 2048, 7, 7] for texture 80 | return x2 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 93 | self.bn3 = nn.BatchNorm2d(planes * 4) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | def 
conv3x3(in_planes, out_planes, stride=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 123 | padding=1, bias=False) 124 | 125 | class BasicBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BasicBlock, self).__init__() 130 | self.conv1 = conv3x3(inplanes, planes, stride) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.conv2 = conv3x3(planes, planes) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | def copy_parameter_from_resnet(model, resnet_dict): 157 | cur_state_dict = model.state_dict() 158 | # import ipdb; ipdb.set_trace() 159 | for name, param in list(resnet_dict.items())[0:None]: 160 | if name not in cur_state_dict: 161 | # print(name, ' not available in reconstructed resnet') 162 | continue 163 | if isinstance(param, Parameter): 164 | param = param.data 165 | try: 166 | cur_state_dict[name].copy_(param) 167 | except: 168 | print(name, ' is inconsistent!') 169 | continue 170 | # print('copy resnet state dict finished!') 171 | # import ipdb; ipdb.set_trace() 172 | 173 | def load_ResNet50Model(): 174 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 175 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = True).state_dict()) 176 | return model 177 | 178 | def load_ResNet101Model(): 179 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 180 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict()) 181 | return model 182 | 183 | def load_ResNet152Model(): 184 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 185 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict()) 186 | return model 187 | 188 | # model.load_state_dict(checkpoint['model_state_dict']) 189 | 190 | 191 | ######## Unet 192 | 193 | class DoubleConv(nn.Module): 194 | """(convolution => [BN] => ReLU) * 2""" 195 | 196 | def __init__(self, in_channels, out_channels): 197 | super().__init__() 198 | self.double_conv = nn.Sequential( 199 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 200 | nn.BatchNorm2d(out_channels), 201 | nn.ReLU(inplace=True), 202 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 203 | nn.BatchNorm2d(out_channels), 204 | nn.ReLU(inplace=True) 205 | ) 206 | 207 | def forward(self, x): 208 | return self.double_conv(x) 209 | 210 | 211 | class Down(nn.Module): 212 | """Downscaling with maxpool then double conv""" 213 | 214 | def __init__(self, in_channels, out_channels): 215 | super().__init__() 216 | self.maxpool_conv = nn.Sequential( 217 | nn.MaxPool2d(2), 218 | DoubleConv(in_channels, out_channels) 219 | ) 220 | 221 | def forward(self, x): 222 | return self.maxpool_conv(x) 223 | 224 | 225 | class Up(nn.Module): 226 | """Upscaling then double conv""" 227 | 228 | def __init__(self, in_channels, out_channels, bilinear=True): 229 | super().__init__() 230 | 231 | # if bilinear, use the normal convolutions to reduce the number of channels 232 | if bilinear: 233 | self.up = 
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 234 | else: 235 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 236 | 237 | self.conv = DoubleConv(in_channels, out_channels) 238 | 239 | def forward(self, x1, x2): 240 | x1 = self.up(x1) 241 | # input is CHW 242 | diffY = x2.size()[2] - x1.size()[2] 243 | diffX = x2.size()[3] - x1.size()[3] 244 | 245 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 246 | diffY // 2, diffY - diffY // 2]) 247 | # if you have padding issues, see 248 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 249 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 250 | x = torch.cat([x2, x1], dim=1) 251 | return self.conv(x) 252 | 253 | 254 | class OutConv(nn.Module): 255 | def __init__(self, in_channels, out_channels): 256 | super(OutConv, self).__init__() 257 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 258 | 259 | def forward(self, x): 260 | return self.conv(x) 261 | 262 | class UNet(nn.Module): 263 | def __init__(self, n_channels, n_classes, bilinear=True): 264 | super(UNet, self).__init__() 265 | self.n_channels = n_channels 266 | self.n_classes = n_classes 267 | self.bilinear = bilinear 268 | 269 | self.inc = DoubleConv(n_channels, 64) 270 | self.down1 = Down(64, 128) 271 | self.down2 = Down(128, 256) 272 | self.down3 = Down(256, 512) 273 | self.down4 = Down(512, 512) 274 | self.up1 = Up(1024, 256, bilinear) 275 | self.up2 = Up(512, 128, bilinear) 276 | self.up3 = Up(256, 64, bilinear) 277 | self.up4 = Up(128, 64, bilinear) 278 | self.outc = OutConv(64, n_classes) 279 | 280 | def forward(self, x): 281 | x1 = self.inc(x) 282 | x2 = self.down1(x1) 283 | x3 = self.down2(x2) 284 | x4 = self.down3(x3) 285 | x5 = self.down4(x4) 286 | x = self.up1(x5, x4) 287 | x = self.up2(x, x3) 288 | x = self.up3(x, x2) 289 | x = self.up4(x, x1) 290 | x = F.normalize(x) 291 | return x -------------------------------------------------------------------------------- /libs/datasets/dataloader_paired.py: -------------------------------------------------------------------------------- 1 | """ 2 | CustomDataset_paired: Custom Dataloader for real paired images training using VoxCeleb dataset 3 | CustomDataset_paired_validation: Custom Dataloader for real paired images evaluation using VoxCeleb dataset 4 | 5 | """ 6 | import torch 7 | import os 8 | import glob 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | import torchvision.transforms as transforms 13 | 14 | class CustomDataset_paired(): 15 | 16 | def __init__(self, dataset_path, num_samples = None, max_pairs = 2): 17 | """ 18 | VoxCeleb dataset format: id_index/video_index/frames_cropped/*.png 19 | id_index/video_index/inversion/frames/*.png 20 | id_index/video_index/inversion/latent_codes/*.npy 21 | Args: 22 | dataset_path (string): Path to dataset with inverted images. 
23 | num_samples: how many samples for validation 24 | 25 | """ 26 | self.dataset_path = dataset_path 27 | self.num_samples = num_samples 28 | self.max_pairs = max_pairs 29 | self.transform = transforms.Compose([ 30 | transforms.Resize((256, 256)), 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 33 | 34 | self.get_dataset() 35 | 36 | def get_dataset(self): 37 | 38 | ids_path = glob.glob(os.path.join(self.dataset_path, '*/')) 39 | ids_path.sort() 40 | if len(ids_path) == 0: 41 | print('Dataset has no identities in path {}'.format(self.dataset_path)) 42 | exit() 43 | 44 | sum_images = 0 45 | counter_ids = 0; counter_videos = 0; counter_imgs = 0 46 | self.samples_dict = {} 47 | self.videos_dict = {} 48 | for i, ids in enumerate(ids_path): 49 | id_index = ids.split('/')[-2] 50 | self.videos_dict.update( {id_index: dict()} ) 51 | videos_path = glob.glob(os.path.join(ids, '*/')) 52 | videos_path.sort() 53 | counter_ids += 1 54 | for k, video_path in enumerate(videos_path): 55 | video_id = video_path.split('/')[-2] 56 | images_path = glob.glob(os.path.join(video_path, 'frames_cropped', '*.png')) # real frames 57 | images_path.sort() 58 | 59 | if not os.path.exists(os.path.join(video_path, 'inversion')): 60 | print('Path with inverted latent codes does not exist.') 61 | exit() 62 | 63 | latent_codes_path = glob.glob(os.path.join(video_path, 'inversion', 'latent_codes', '*.npy')) 64 | latent_codes_path.sort() 65 | 66 | if len(images_path) > 0 and len(latent_codes_path) > 0: 67 | indices = np.random.permutation(len(images_path)) 68 | images_path = np.asarray(images_path); latent_codes_path = np.asarray(latent_codes_path) 69 | images_path = images_path[indices.astype(int)] 70 | latent_codes_path = latent_codes_path[indices.astype(int)] 71 | dict_ = { 72 | 'num_frames': len(images_path), 73 | 'frames': images_path, 74 | 'latent_codes': latent_codes_path, 75 | } 76 | self.videos_dict[id_index].update( {video_id: dict_}) 77 | 78 | if len(images_path) >= 2: 79 | images_path_source = images_path[:self.max_pairs] 80 | 81 | for j, image_path in enumerate(images_path_source): 82 | data = [id_index, video_id, j] 83 | self.samples_dict.update( {counter_imgs: data} ) 84 | counter_imgs += 1 85 | counter_videos += 1 86 | 87 | 88 | self.num_samples = counter_imgs 89 | self.counter_ids = counter_ids 90 | self.counter_videos = counter_videos 91 | 92 | def get_length(self): 93 | return self.num_samples, self.counter_ids, self.counter_videos 94 | 95 | def __len__(self): 96 | return self.num_samples 97 | 98 | def __getitem__(self, index): 99 | 100 | source_sample = self.samples_dict[index] 101 | source_id = source_sample[0] 102 | source_video = source_sample[1] 103 | source_index = source_sample[2] 104 | # Get target sample from the same video sequence 105 | video_dict = self.videos_dict[source_id][source_video] 106 | frames_path = video_dict['frames'] 107 | latent_codes_path = video_dict['latent_codes'] 108 | num_frames = video_dict['num_frames'] 109 | 110 | target_index = np.random.randint(num_frames, size = 1)[0] 111 | while target_index == source_index: 112 | target_index = np.random.randint(num_frames, size = 1)[0] 113 | 114 | source_img_path = frames_path[source_index] 115 | source_code_path = latent_codes_path[source_index] 116 | target_img_path = frames_path[target_index] 117 | target_code_path = latent_codes_path[target_index] 118 | 119 | 120 | # Source sample 121 | source_img = Image.open(source_img_path) 122 | source_img = source_img.convert('RGB') 123 | source_img = 
self.transform(source_img) 124 | source_latent_code = np.load(source_code_path) 125 | source_latent_code = torch.from_numpy(source_latent_code) 126 | # assert source_latent_code.ndim == 2, 'latent code dimensions should be inject_index x 512 while now is {}'.format(source_latent_code.shape) 127 | 128 | # Target sample 129 | target_img = Image.open(target_img_path) 130 | target_img = target_img.convert('RGB') 131 | target_img = self.transform(target_img) 132 | target_latent_code = np.load(target_code_path) 133 | target_latent_code = torch.from_numpy(target_latent_code) 134 | 135 | # assert target_latent_code.ndim == 2, 'latent code dimensions should be inject_index x 512 while now is {}'.format(target_latent_code.shape) 136 | if target_latent_code.ndim == 3: 137 | target_latent_code = target_latent_code.squeeze(0) 138 | if source_latent_code.ndim == 3: 139 | source_latent_code = source_latent_code.squeeze(0) 140 | 141 | sample = { 142 | 'source_img': source_img, 143 | 'source_latent_code': source_latent_code, 144 | 'target_img': target_img, 145 | 'target_latent_code': target_latent_code, 146 | 147 | } 148 | return sample 149 | 150 | 151 | class CustomDataset_paired_validation(): 152 | 153 | def __init__(self, dataset_path, num_samples = None): 154 | """ 155 | VoxCeleb dataset format: id_index/video_index/frames_cropped/*.png 156 | id_index/video_index/inversion/frames/*.png 157 | id_index/video_index/inversion/latent_codes/*.npy 158 | Args: 159 | dataset_path (string): Path to dataset with inverted images. 160 | num_samples: how many samples for validation 161 | 162 | """ 163 | 164 | self.dataset_path = dataset_path 165 | self.num_samples = num_samples 166 | self.transform = transforms.Compose([ 167 | transforms.Resize((256, 256)), 168 | transforms.ToTensor(), 169 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 170 | 171 | self.get_dataset() 172 | 173 | def get_dataset(self): 174 | 175 | ids_path = glob.glob(os.path.join(self.dataset_path, '*/')) 176 | ids_path.sort() 177 | if len(ids_path) == 0: 178 | print('Dataset has no identities in path {}'.format(self.dataset_path)) 179 | exit() 180 | 181 | sum_images = 0 182 | counter_ids = 0; counter_videos = 0; counter_imgs = 0 183 | self.samples_dict = {} 184 | self.videos_dict = {} 185 | for i, ids in enumerate(ids_path): 186 | id_index = ids.split('/')[-2] 187 | self.videos_dict.update( {id_index: dict()} ) 188 | videos_path = glob.glob(os.path.join(ids, '*/')) 189 | videos_path.sort() 190 | counter_ids += 1 191 | for k, video_path in enumerate(videos_path): 192 | video_id = video_path.split('/')[-2] 193 | images_path = glob.glob(os.path.join(video_path, 'frames_cropped', '*.png')) # real frames 194 | images_path.sort() 195 | if not os.path.exists(os.path.join(video_path, 'inversion')): 196 | print('Path with inverted latent codes does not exist.') 197 | exit() 198 | latent_codes_path = glob.glob(os.path.join(video_path, 'inversion', 'latent_codes', '*.npy')) 199 | latent_codes_path.sort() 200 | 201 | dict_ = { 202 | 'num_frames': len(images_path), 203 | 'frames': images_path, 204 | 'latent_codes': latent_codes_path 205 | } 206 | self.videos_dict[id_index].update( {video_id: dict_}) 207 | 208 | if len(images_path) >= 2: 209 | for j, image_path in enumerate(images_path): 210 | target_index = np.random.randint(len(images_path), size = 1)[0] 211 | while target_index == j: 212 | target_index = np.random.randint(len(images_path), size = 1)[0] 213 | data = [id_index, video_id, j, target_index] 214 | self.samples_dict.update( {counter_imgs: 
data} ) 215 | counter_imgs += 1 216 | counter_videos += 1 217 | 218 | self.num_samples = counter_imgs 219 | self.counter_ids = counter_ids 220 | self.counter_videos = counter_videos 221 | 222 | def get_length(self): 223 | return self.num_samples 224 | 225 | def __len__(self): 226 | return self.num_samples 227 | 228 | def __getitem__(self, index): 229 | 230 | source_sample = self.samples_dict[index] 231 | source_id = source_sample[0] 232 | source_video = source_sample[1] 233 | source_index = source_sample[2] 234 | target_index = source_sample[3] 235 | # Get target sample from the same video sequence 236 | video_dict = self.videos_dict[source_id][source_video] 237 | frames_path = video_dict['frames'] 238 | latent_codes_path = video_dict['latent_codes'] 239 | num_frames = video_dict['num_frames'] 240 | 241 | source_img_path = frames_path[source_index] 242 | source_code_path = latent_codes_path[source_index] 243 | target_img_path = frames_path[target_index] 244 | target_code_path = latent_codes_path[target_index] 245 | 246 | 247 | # Source sample 248 | source_img = Image.open(source_img_path) 249 | source_img = source_img.convert('RGB') 250 | source_img = self.transform(source_img) 251 | source_latent_code = np.load(source_code_path) 252 | source_latent_code = torch.from_numpy(source_latent_code) 253 | # assert source_latent_code.ndim == 2, 'latent code dimensions should be inject_index x 512 while now is {}'.format(source_latent_code.shape) 254 | 255 | # Target sample 256 | target_img = Image.open(target_img_path) 257 | target_img = target_img.convert('RGB') 258 | target_img = self.transform(target_img) 259 | target_latent_code = np.load(target_code_path) 260 | target_latent_code = torch.from_numpy(target_latent_code) 261 | 262 | # assert target_latent_code.ndim == 2, 'latent code dimensions should be inject_index x 512 while now is {}'.format(target_latent_code.shape) 263 | if target_latent_code.ndim == 3: 264 | target_latent_code = target_latent_code.squeeze(0) 265 | if source_latent_code.ndim == 3: 266 | source_latent_code = source_latent_code.squeeze(0) 267 | 268 | sample = { 269 | 'source_img': source_img, 270 | 'source_latent_code': source_latent_code, 271 | 'target_img': target_img, 272 | 'target_latent_code': target_latent_code, 273 | 274 | } 275 | return sample 276 | --------------------------------------------------------------------------------
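Usage note (editorial sketch, not a file from the repository): the paired dataloaders in libs/datasets/dataloader_paired.py expect the VoxCeleb layout described in their docstrings (id_index/video_index/frames_cropped/*.png for cropped frames and id_index/video_index/inversion/latent_codes/*.npy for the corresponding inverted latent codes). A minimal way to exercise CustomDataset_paired with a standard PyTorch DataLoader, assuming a hypothetical dataset root, might look like the following sketch:

import torch
from libs.datasets.dataloader_paired import CustomDataset_paired

# Hypothetical root; it must follow the id_index/video_index layout described above.
dataset = CustomDataset_paired('/data/voxceleb_inverted', max_pairs=2)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# Each sample pairs a source and a target frame from the same video clip:
# images are 3x256x256 tensors normalized to [-1, 1], and the latent codes are
# the per-frame StyleGAN2 codes loaded from the inversion .npy files.
print(batch['source_img'].shape, batch['source_latent_code'].shape)
print(batch['target_img'].shape, batch['target_latent_code'].shape)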