├── libs ├── DECA │ ├── decalib │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── deca.cpython-38.pyc │ │ │ ├── deca.cpython-39.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── __init__.cpython-39.pyc │ │ ├── models │ │ │ ├── __pycache__ │ │ │ │ ├── FLAME.cpython-38.pyc │ │ │ │ ├── lbs.cpython-38.pyc │ │ │ │ ├── resnet.cpython-38.pyc │ │ │ │ ├── decoders.cpython-38.pyc │ │ │ │ └── encoders.cpython-38.pyc │ │ │ ├── encoders.py │ │ │ ├── decoders.py │ │ │ ├── resnet.py │ │ │ ├── FLAME.py │ │ │ └── lbs.py │ │ ├── utils │ │ │ ├── __pycache__ │ │ │ │ ├── config.cpython-38.pyc │ │ │ │ ├── util.cpython-38.pyc │ │ │ │ ├── renderer.cpython-38.pyc │ │ │ │ ├── renderer.cpython-39.pyc │ │ │ │ └── rotation_converter.cpython-38.pyc │ │ │ ├── config.py │ │ │ └── rotation_converter.py │ │ └── datasets │ │ │ ├── __pycache__ │ │ │ ├── datasets.cpython-38.pyc │ │ │ ├── detectors.cpython-38.pyc │ │ │ └── detectors_2.cpython-38.pyc │ │ │ ├── detectors.py │ │ │ └── datasets.py │ ├── estimate_DECA.py │ └── README.md ├── criteria │ ├── lpips │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── lpips.cpython-38.pyc │ │ │ ├── lpips.cpython-39.pyc │ │ │ ├── utils.cpython-38.pyc │ │ │ ├── utils.cpython-39.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── networks.cpython-38.pyc │ │ │ └── networks.cpython-39.pyc │ │ ├── utils.py │ │ ├── lpips.py │ │ └── networks.py │ ├── l2_loss.py │ ├── id_loss.py │ ├── losses.py │ ├── model_irse.py │ └── helpers.py ├── configs │ ├── ranges_FFHQ.npy │ ├── random_latent_codes_100.npy │ ├── __pycache__ │ │ └── config_models.cpython-38.pyc │ └── config_models.py ├── models │ ├── StyleGAN2 │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── fused_act.cpython-37.pyc │ │ │ │ ├── fused_act.cpython-38.pyc │ │ │ │ ├── fused_act.cpython-39.pyc │ │ │ │ ├── upfirdn2d.cpython-38.pyc │ │ │ │ └── upfirdn2d.cpython-39.pyc │ │ │ ├── fused_bias_act.cpp │ │ │ ├── upfirdn2d.cpp │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ ├── __pycache__ │ │ │ ├── model.cpython-37.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ └── model.cpython-39.pyc │ │ └── convert_weight.py │ ├── mask_predictor.py │ └── inversion │ │ ├── psp.py │ │ ├── helpers.py │ │ └── psp_encoders.py └── utilities │ ├── __pycache__ │ ├── utils.cpython-38.pyc │ └── stylespace_utils.cpython-38.pyc │ ├── image_utils.py │ ├── dataloader.py │ ├── utils.py │ ├── ffhq_cropping.py │ ├── stylespace_utils.py │ └── utils_inference.py ├── images ├── architecture.png ├── source_target.png └── source_target_gif.gif ├── requirements.txt ├── extract_statistics.py ├── run_trainer.py ├── README.md └── run_inference.py /libs/DECA/decalib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /libs/criteria/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/images/architecture.png -------------------------------------------------------------------------------- /images/source_target.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/images/source_target.png -------------------------------------------------------------------------------- /images/source_target_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/images/source_target_gif.gif -------------------------------------------------------------------------------- /libs/configs/ranges_FFHQ.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/configs/ranges_FFHQ.npy -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /libs/configs/random_latent_codes_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/configs/random_latent_codes_100.npy -------------------------------------------------------------------------------- /libs/utilities/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/utilities/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/__pycache__/deca.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/__pycache__/deca.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/__pycache__/deca.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/__pycache__/deca.cpython-39.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/lpips.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/lpips.cpython-38.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/lpips.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/lpips.cpython-39.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /libs/configs/__pycache__/config_models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/configs/__pycache__/config_models.cpython-38.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/models/__pycache__/FLAME.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/models/__pycache__/FLAME.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/models/__pycache__/lbs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/models/__pycache__/lbs.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/utils/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- 
/libs/criteria/lpips/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /libs/criteria/lpips/__pycache__/networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/criteria/lpips/__pycache__/networks.cpython-39.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/__pycache__/renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/utils/__pycache__/renderer.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/__pycache__/renderer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/utils/__pycache__/renderer.cpython-39.pyc -------------------------------------------------------------------------------- /libs/utilities/__pycache__/stylespace_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/utilities/__pycache__/stylespace_utils.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/datasets/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/models/__pycache__/decoders.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/models/__pycache__/decoders.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/models/__pycache__/encoders.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/models/__pycache__/encoders.cpython-38.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/fused_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/fused_act.cpython-37.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/fused_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/fused_act.cpython-38.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/fused_act.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/fused_act.cpython-39.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/upfirdn2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/upfirdn2d.cpython-38.pyc -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/__pycache__/upfirdn2d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/models/StyleGAN2/op/__pycache__/upfirdn2d.cpython-39.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/__pycache__/detectors.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/datasets/__pycache__/detectors.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/__pycache__/detectors_2.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/datasets/__pycache__/detectors_2.cpython-38.pyc -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/__pycache__/rotation_converter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/StyleMask/HEAD/libs/DECA/decalib/utils/__pycache__/rotation_converter.cpython-38.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | imageio 5 | Pillow 6 | ninja 7 | opencv-python 8 | tqdm 9 | scikit-image 10 | scikit-learn 11 | wandb 12 | face-alignment 13 | kornia==0.4.1 14 | chumpy 15 | -------------------------------------------------------------------------------- /libs/criteria/l2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | l2_criterion = torch.nn.MSELoss(reduction='mean') 4 | 5 | 6 | def l2_loss(real_images, generated_images): 7 | loss = l2_criterion(real_images, generated_images) 8 | return loss 9 | -------------------------------------------------------------------------------- /libs/configs/config_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | stylegan2_ffhq_1024 = { 5 | 'image_resolution': 1024, 6 | 'channel_multiplier': 2, 7 | 'gan_weights': './pretrained_models/stylegan2-ffhq-config-f_1024.pt', 8 | 9 | 'stylespace_dim': 6048, 10 | 'split_sections': [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 128, 128, 64, 64, 32], 11 | 12 | 'e4e_inversion_model': './pretrained_models/e4e_ffhq_encode_1024.pt', 13 | 'expression_ranges': './libs/configs/ranges_FFHQ.npy' # Used for evaluation 14 | } 15 | 16 | -------------------------------------------------------------------------------- /libs/models/mask_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MaskPredictor(nn.Module): 6 | def __init__(self, input_dim, output_dim, inner_dim=1024): 7 | super(MaskPredictor, self).__init__() 8 | 9 | 10 | self.masknet = nn.Sequential(nn.Linear(input_dim, inner_dim, bias=True), 11 | nn.ReLU(), 12 | nn.Linear(inner_dim, output_dim, bias=True), 13 | ) 14 | 15 | self.initialization() 16 | 17 | def initialization(self): 18 | torch.nn.init.normal_(self.masknet[0].weight, mean=0.0, std=0.01) 19 | torch.nn.init.normal_(self.masknet[2].weight, mean=0.0, std=0.01) 20 | 21 | def forward(self, input): 22 | 23 | out = self.masknet(input) 24 | out = torch.nn.Sigmoid()(out) 25 | 26 | return out 27 | 28 | -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | //#include <torch/extension.h> 2 | 3 | #pragma warning(push, 0) 4 | #include <torch/extension.h> 5 | #pragma warning(pop) 6 | 7 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);
CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 15 | int act, int grad, float alpha, float scale) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(bias); 18 | 19 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 24 | } -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #pragma warning(push, 0) 2 | #include <torch/extension.h> 3 | #pragma warning(pop) 4 | 5 | 6 | 7 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 8 | int up_x, int up_y, int down_x, int down_y, 9 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 16 | int up_x, int up_y, int down_x, int down_y, 17 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(kernel); 20 | 21 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 26 | } -------------------------------------------------------------------------------- /libs/criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | # print(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | # if torch.isnan(x).any(): 9 | # # print(gradients_keep) 10 | # pdb.set_trace() 11 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)+1e-9) 12 | return x / (norm_factor + eps) 13 | 14 | 15 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 16 | # build url 17 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 18 | + f'master/lpips/weights/v{version}/{net_type}.pth' 19 | 20 | # download 21 | old_state_dict = torch.hub.load_state_dict_from_url( 22 | url, progress=True, 23 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 24 | ) 25 | 26 | # rename keys 27 | new_state_dict = OrderedDict() 28 | for key, val in old_state_dict.items(): 29 | new_key = key 30 | new_key = new_key.replace('lin', '') 31 | new_key = new_key.replace('model.', '') 32 | new_state_dict[new_key] = val 33 | 34 | return new_state_dict 35 | -------------------------------------------------------------------------------- /libs/criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures Learned Perceptual Image Patch Similarity (LPIPS). 10 | Adapted from https://github.com/eladrich/pixel2style2pixel. 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | # pretrained network 22 | self.net = get_network(net_type).to("cuda") 23 | 24 | # linear layers 25 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 26 | self.lin.load_state_dict(get_state_dict(net_type, version)) 27 | 28 | def forward(self, x: torch.Tensor, y: torch.Tensor): 29 | feat_x, feat_y = self.net(x), self.net(y) 30 | 31 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 32 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 33 | 34 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 35 | -------------------------------------------------------------------------------- /libs/criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .model_irse import Backbone 4 | import os 5 | import torch.backends.cudnn as cudnn 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self, pretrained_model_path = './pretrained_models/model_ir_se50.pth'): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace for identity loss') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | if not os.path.exists(pretrained_model_path): 13 | print('ir_se50 model does not exist in {}'.format(pretrained_model_path)) 14 | exit() 15 | self.facenet.load_state_dict(torch.load(pretrained_model_path)) 16 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 17 | self.facenet.eval() 18 | self.criterion = nn.CosineSimilarity(dim=1, eps=1e-6) 19 | 20 | def extract_feats(self, x, crop = True): 21 | if crop: 22 | x = x[:, :, 35:223, 32:220] # Crop interesting region 23 | x = self.face_pool(x) 24 | x_feats = self.facenet(x) 25 | return x_feats 26 | 27 | def forward(self, y_hat, y, crop = True): 28 | n_samples = y.shape[0] 29 | y_feats = self.extract_feats(y, crop) 30 | y_hat_feats = self.extract_feats(y_hat, crop) 31 | cosine_sim = self.criterion(y_hat_feats, y_feats.detach()) 32 | loss = 1 - cosine_sim 33 | loss = torch.mean(loss) 34 | return loss 35 | -------------------------------------------------------------------------------- /libs/DECA/decalib/models/encoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | import torch 19 | import torch.nn.functional as F 20 | from . 
import resnet 21 | 22 | class ResnetEncoder(nn.Module): 23 | def __init__(self, outsize, last_op=None): 24 | super(ResnetEncoder, self).__init__() 25 | feature_size = 2048 26 | self.encoder = resnet.load_ResNet50Model() #out: 2048 27 | ### regressor 28 | self.layers = nn.Sequential( 29 | nn.Linear(feature_size, 1024), 30 | nn.ReLU(), 31 | nn.Linear(1024, outsize) 32 | ) 33 | self.last_op = last_op 34 | 35 | def forward(self, inputs): 36 | features = self.encoder(inputs) 37 | parameters = self.layers(features) 38 | if self.last_op: 39 | parameters = self.last_op(parameters) 40 | return parameters 41 | -------------------------------------------------------------------------------- /libs/utilities/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import torchvision 5 | import os 6 | 7 | " Read image from path" 8 | def read_image_opencv(image_path): 9 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) # BGR order!!!! 10 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 11 | 12 | return img.astype('uint8') 13 | 14 | " image numpy array to tensor [-1,1] range " 15 | def image_to_tensor(image): 16 | max_val = 1 17 | min_val = -1 18 | if image.shape[0]>256: 19 | image, _ = image_resize(image, 256) 20 | image_tensor = torch.tensor(np.transpose(image,(2,0,1))).float().div(255.0) 21 | image_tensor = image_tensor * (max_val - min_val) + min_val 22 | 23 | return image_tensor 24 | 25 | def tensor_to_255(image): 26 | img_tmp = image.clone() 27 | min_val = -1 28 | max_val = 1 29 | img_tmp.clamp_(min=min_val, max=max_val) 30 | img_tmp.add_(-min_val).div_(max_val - min_val + 1e-5) 31 | img_tmp = img_tmp.mul(255.0).add(0.0) 32 | return img_tmp 33 | 34 | def torch_image_resize(image, width = None, height = None): 35 | dim = None 36 | (h, w) = image.shape[1:] 37 | # if both the width and height are None, then return the 38 | # original image 39 | if width is None and height is None: 40 | return image 41 | 42 | # check to see if the width is None 43 | if width is None: 44 | # calculate the ratio of the height and construct the 45 | # dimensions 46 | r = height / float(h) 47 | dim = (height, int(w * r)) 48 | scale = r 49 | # otherwise, the height is None 50 | else: 51 | # calculate the ratio of the width and construct the 52 | # dimensions 53 | r = width / float(w) 54 | dim = (int(h * r), width) 55 | scale = r 56 | image = image.unsqueeze(0) 57 | image = torch.nn.functional.interpolate(image, size=dim, mode='bilinear') 58 | return image.squeeze(0) -------------------------------------------------------------------------------- /libs/models/inversion/psp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import math 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import torch 8 | from torch import nn 9 | import torchvision.transforms as transforms 10 | import os 11 | 12 | from libs.models.inversion import psp_encoders 13 | 14 | 15 | 16 | def get_keys(d, name): 17 | if 'state_dict' in d: 18 | d = d['state_dict'] 19 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 20 | return d_filt 21 | 22 | 23 | class pSp(nn.Module): 24 | 25 | def __init__(self, opts): 26 | super(pSp, self).__init__() 27 | 28 | self.opts = opts 29 | # compute number of style inputs based on the output resolution 30 | self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 31 | self.n_styles = 
self.opts.n_styles 32 | # Define architecture 33 | self.encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) 34 | # Load weights if needed 35 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 36 | self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) 37 | 38 | self.__load_latent_avg(ckpt) 39 | 40 | def forward(self, real_image, randomize_noise=False, inject_latent=None, return_latents=False, alpha=None, average_code=False, input_is_full=False): 41 | 42 | codes = self.encoder(real_image) 43 | if self.latent_avg is not None: 44 | if codes.ndim == 2: 45 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] 46 | else: 47 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 48 | return codes 49 | 50 | def __load_latent_avg(self, ckpt, repeat=None): 51 | if 'latent_avg' in ckpt: 52 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 53 | if repeat is not None: 54 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 55 | else: 56 | self.latent_avg = None -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/detectors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch 18 | # from libs.pose_estimation.fan_model.models import FAN, ResNetDepth 19 | # from libs.pose_estimation.fan_model.utils import * 20 | from enum import Enum 21 | # from libs.pose_estimation.sfd.sfd_detector import SFDDetector as FaceDetector 22 | 23 | class FAN(object): 24 | def __init__(self): 25 | import face_alignment 26 | self.model = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) 27 | 28 | def run(self, image): 29 | ''' 30 | image: 0-255, uint8, rgb, [h, w, 3] 31 | return: detected box list 32 | ''' 33 | 34 | out = self.model.get_landmarks(image) 35 | if out is None: 36 | return [0], 'error' 37 | else: 38 | kpt = out[0].squeeze() 39 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 40 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 41 | bbox = [left,top, right, bottom] 42 | return bbox, 'kpt68' 43 | 44 | 45 | class MTCNN(object): 46 | def __init__(self, device = 'cpu'): 47 | ''' 48 | https://github.com/timesler/facenet-pytorch/blob/master/examples/infer.ipynb 49 | ''' 50 | from facenet_pytorch import MTCNN as mtcnn 51 | self.device = device 52 | self.model = mtcnn(keep_all=True) 53 | def run(self, input): 54 | ''' 55 | image: 0-255, uint8, rgb, [h, w, 3] 56 | return: detected box 57 | ''' 58 | out = self.model.detect(input[None,...]) 59 | if out[0][0] is None: 60 | return [0] 61 | else: 62 | bbox = out[0][0].squeeze() 63 | return bbox, 'bbox' 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /libs/DECA/estimate_DECA.py: -------------------------------------------------------------------------------- 1 | """ 2 | """ 3 | 4 | import torch 5 | import numpy as np 6 | import cv2 7 | import os 8 | 9 | from .decalib.deca import DECA 10 | from .decalib.datasets import datasets 11 | from .decalib.utils import util 12 | from .decalib.utils.config import cfg as deca_cfg 13 | from .decalib.utils.rotation_converter import * 14 | 15 | 16 | class DECA_model(): 17 | 18 | def __init__(self, device): 19 | deca_cfg.model.use_tex = False 20 | dir_path = os.path.dirname(os.path.realpath(__file__)) 21 | models_path = os.path.join(dir_path, 'data') 22 | if not os.path.exists(models_path): 23 | print('Please download the required data for DECA model. 
See Readme.') 24 | exit() 25 | self.deca = DECA(config = deca_cfg, device=device) 26 | self.data = datasets.TestData() 27 | 28 | 'Batch torch tensor' 29 | def extract_DECA_params(self, images): 30 | 31 | p_tensor = torch.zeros(images.shape[0], 6).cuda() 32 | alpha_shp_tensor = torch.zeros(images.shape[0], 100).cuda() 33 | alpha_exp_tensor = torch.zeros(images.shape[0], 50).cuda() 34 | angles = torch.zeros(images.shape[0], 3).cuda() 35 | cam = torch.zeros(images.shape[0], 3).cuda() 36 | for batch in range(images.shape[0]): 37 | image_prepro, error_flag = self.data.get_image_tensor(images[batch].clone()) 38 | if not error_flag: 39 | codedict = self.deca.encode(image_prepro.unsqueeze(0).cuda()) 40 | pose = codedict['pose'][:,:3] 41 | pose = rad2deg(batch_axis2euler(pose)) 42 | p_tensor[batch] = codedict['pose'][0] 43 | alpha_shp_tensor[batch] = codedict['shape'][0] 44 | alpha_exp_tensor[batch] = codedict['exp'][0] 45 | cam[batch] = codedict['cam'][0] 46 | angles[batch] = pose 47 | else: 48 | angles[batch][0] = -180 49 | angles[batch][1] = -180 50 | angles[batch][2] = -180 51 | 52 | return p_tensor, alpha_shp_tensor, alpha_exp_tensor, angles, cam 53 | 54 | def calculate_shape(self, coefficients, image = None, save_path = None, prefix = None): 55 | landmarks2d, landmarks3d, points = self.deca.decode(coefficients) 56 | return landmarks2d, landmarks3d, points 57 | 58 | -------------------------------------------------------------------------------- /libs/criteria/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | """ 5 | Calculate shape losses 6 | """ 7 | 8 | class Losses(): 9 | def __init__(self): 10 | self.criterion_mse = torch.nn.MSELoss() 11 | self.criterion_l1 = torch.nn.L1Loss() 12 | self.image_deca_size = 224 13 | 14 | def calculate_pixel_wise_loss(self, images_shifted, images): 15 | 16 | pixel_wise_loss = self.criterion_l1(images, images_shifted) 17 | 18 | return pixel_wise_loss 19 | 20 | def calculate_shape_loss(self, shape_gt, shape_reenacted, normalize = False): 21 | criterion_l1 = torch.nn.L1Loss() 22 | if normalize: 23 | shape_gt_norm = shape_gt/200 #self.image_deca_size 24 | shape_reenacted_norm = shape_reenacted/200 #self.image_deca_size 25 | loss = criterion_l1(shape_gt_norm, shape_reenacted_norm) 26 | else: 27 | loss = criterion_l1(shape_gt, shape_reenacted) 28 | return loss 29 | 30 | def calculate_eye_loss(self, shape_gt, shape_reenacted): 31 | criterion_l1 = torch.nn.L1Loss() 32 | shape_gt_norm = shape_gt.clone() 33 | shape_reenacted_norm = shape_reenacted.clone() 34 | # shape_gt_norm = shape_gt_norm/self.image_deca_size 35 | # shape_reenacted_norm = shape_reenacted_norm/self.image_deca_size 36 | eye_pairs = [(36, 39), (37, 41), (38, 40), (42, 45), (43, 47), (44, 46)] 37 | loss = 0 38 | for i in range(len(eye_pairs)): 39 | pair = eye_pairs[i] 40 | d_gt = abs(shape_gt[:, pair[0],:] - shape_gt[:, pair[1],:]) 41 | d_e = abs(shape_reenacted[:, pair[0],:] - shape_reenacted[:, pair[1],:]) 42 | loss += criterion_l1(d_gt, d_e) 43 | 44 | loss = loss/len(eye_pairs) 45 | return loss 46 | 47 | def calculate_mouth_loss(self, shape_gt, shape_reenacted): 48 | criterion_l1 = torch.nn.L1Loss() 49 | shape_gt_norm = shape_gt.clone() 50 | shape_reenacted_norm = shape_reenacted.clone() 51 | # shape_gt_norm = shape_gt_norm/self.image_deca_size 52 | # shape_reenacted_norm = shape_reenacted_norm/self.image_deca_size 53 | mouth_pairs = [(48, 54), (49, 59), (50, 58), (51, 57), (52, 56), (53, 55), (60, 64), (61, 67), 
(62, 66), (63, 65)] 54 | loss = 0 55 | for i in range(len(mouth_pairs)): 56 | pair = mouth_pairs[i] 57 | d_gt = abs(shape_gt[:, pair[0],:] - shape_gt[:, pair[1],:]) 58 | d_e = abs(shape_reenacted[:, pair[0],:] - shape_reenacted[:, pair[1],:]) 59 | loss += criterion_l1(d_gt, d_e) 60 | 61 | loss = loss/len(mouth_pairs) 62 | return loss 63 | 64 | -------------------------------------------------------------------------------- /libs/utilities/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | """ 3 | import torch 4 | import os 5 | import glob 6 | import cv2 7 | import numpy as np 8 | from torchvision import transforms, utils 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | from libs.utilities.utils import make_noise 13 | 14 | np.random.seed(0) 15 | 16 | class CustomDataset_validation(Dataset): 17 | 18 | def __init__(self, synthetic_dataset_path = None, validation_pairs = None, shuffle = True): 19 | """ 20 | Args: 21 | synthetic_dataset_path: path to synthetic latent codes. If None generate random 22 | num_samples: how many samples for validation 23 | 24 | """ 25 | self.shuffle = shuffle 26 | self.validation_pairs = validation_pairs 27 | self.synthetic_dataset_path = synthetic_dataset_path 28 | 29 | if self.synthetic_dataset_path is not None: 30 | z_codes = np.load(self.synthetic_dataset_path) 31 | z_codes = torch.from_numpy(z_codes) 32 | if self.validation_pairs is not None: 33 | self.num_samples = 2 * self.validation_pairs 34 | if z_codes.shape[0] > self.num_samples: 35 | z_codes = z_codes[:self.num_samples] 36 | else: 37 | self.num_samples = z_codes.shape[0] 38 | self.validation_pairs = int(self.num_samples/2) 39 | else: 40 | self.validation_pairs = int(z_codes.shape[0]/2) 41 | self.num_samples = 2 * self.validation_pairs 42 | 43 | self.fixed_source_w = z_codes[:self.validation_pairs, :] 44 | self.fixed_target_w = z_codes[self.validation_pairs:2*self.validation_pairs, :] 45 | else: 46 | self.fixed_source_w = make_noise(self.validation_pairs, 512, None) 47 | self.fixed_target_w = make_noise(self.validation_pairs, 512, None) 48 | # Save random generated latent codes 49 | save_path = './libs/configs/random_latent_codes_{}.npy'.format(self.validation_pairs) 50 | z_codes = torch.cat((self.fixed_source_w, self.fixed_target_w), dim = 0) 51 | np.save(save_path, z_codes.detach().cpu().numpy()) 52 | 53 | self.transform = transforms.Compose([ 54 | transforms.Resize((256, 256)), 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 57 | 58 | 59 | def __len__(self): 60 | return self.validation_pairs 61 | 62 | def __getitem__(self, index): 63 | 64 | source_w = self.fixed_source_w[index] 65 | target_w = self.fixed_target_w[index] 66 | sample = { 67 | 'source_w': source_w, 68 | 'target_w': target_w 69 | } 70 | return sample 71 | 72 | -------------------------------------------------------------------------------- /libs/DECA/decalib/models/decoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 
8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | class Generator(nn.Module): 20 | def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode = 'bilinear'): 21 | super(Generator, self).__init__() 22 | self.out_scale = out_scale 23 | 24 | self.init_size = 32 // 4 # Initial size before upsampling 25 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) 26 | self.conv_blocks = nn.Sequential( 27 | nn.BatchNorm2d(128), 28 | nn.Upsample(scale_factor=2, mode=sample_mode), #16 29 | nn.Conv2d(128, 128, 3, stride=1, padding=1), 30 | nn.BatchNorm2d(128, 0.8), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Upsample(scale_factor=2, mode=sample_mode), #32 33 | nn.Conv2d(128, 64, 3, stride=1, padding=1), 34 | nn.BatchNorm2d(64, 0.8), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | nn.Upsample(scale_factor=2, mode=sample_mode), #64 37 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 38 | nn.BatchNorm2d(64, 0.8), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | nn.Upsample(scale_factor=2, mode=sample_mode), #128 41 | nn.Conv2d(64, 32, 3, stride=1, padding=1), 42 | nn.BatchNorm2d(32, 0.8), 43 | nn.LeakyReLU(0.2, inplace=True), 44 | nn.Upsample(scale_factor=2, mode=sample_mode), #256 45 | nn.Conv2d(32, 16, 3, stride=1, padding=1), 46 | nn.BatchNorm2d(16, 0.8), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Conv2d(16, out_channels, 3, stride=1, padding=1), 49 | nn.Tanh(), 50 | ) 51 | 52 | def forward(self, noise): 53 | out = self.l1(noise) 54 | out = out.view(out.shape[0], 128, self.init_size, self.init_size) 55 | img = self.conv_blocks(out) 56 | return img*self.out_scale -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 
| empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /libs/criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | 
self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Default config for DECA 3 | ''' 4 | from yacs.config import CfgNode as CN 5 | import argparse 6 | import yaml 7 | import os 8 | 9 | cfg = CN() 10 | 11 | abs_deca_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) 12 | cfg.deca_dir = abs_deca_dir 13 | cfg.device = 'cuda' 14 | cfg.device_id = '0' 15 | 16 | cfg.pretrained_modelpath = os.path.join(cfg.deca_dir, 'data', 'deca_model.tar') 17 | 18 | # ---------------------------------------------------------------------------- # 19 | # Options for Face model 20 | # ---------------------------------------------------------------------------- # 21 | cfg.model = CN() 22 | cfg.model.topology_path = os.path.join(cfg.deca_dir, 'data', 'head_template.obj') 23 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip 24 | cfg.model.dense_template_path = os.path.join(cfg.deca_dir, 'data', 'texture_data_256.npy') 25 | cfg.model.fixed_displacement_path = os.path.join(cfg.deca_dir, 'data', 'fixed_displacement_256.npy') 26 | cfg.model.flame_model_path = os.path.join(cfg.deca_dir, 'data', 'generic_model.pkl') 27 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.deca_dir, 'data', 'landmark_embedding.npy') 28 | cfg.model.face_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_mask.png') 29 | cfg.model.face_eye_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_eye_mask.png') 30 | cfg.model.mean_tex_path = os.path.join(cfg.deca_dir, 'data', 'mean_texture.jpg') 31 | cfg.model.tex_path = os.path.join(cfg.deca_dir, 'data', 'FLAME_albedo_from_BFM.npz') 32 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM 33 | cfg.model.uv_size = 256 34 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 35 | cfg.model.n_shape = 100 36 | cfg.model.n_tex = 50 37 | cfg.model.n_exp = 50 38 | cfg.model.n_cam = 3 39 | cfg.model.n_pose = 6 40 | cfg.model.n_light = 27 41 | cfg.model.use_tex = False 42 | cfg.model.jaw_type = 'aa' # default use axis angle, another option: euler 43 | 44 | ## details 45 | cfg.model.n_detail = 128 46 | cfg.model.max_z = 0.01 47 | 48 | # ---------------------------------------------------------------------------- # 49 | # Options for Dataset 50 | # ---------------------------------------------------------------------------- # 51 | cfg.dataset = CN() 52 | cfg.dataset.batch_size = 24 53 | cfg.dataset.num_workers = 2 54 | cfg.dataset.image_size = 224 55 | 56 | def get_cfg_defaults(): 57 | """Get a yacs CfgNode object with default values for my_project.""" 58 | # Return a clone so that the defaults will not be altered 59 | # This is for the "local variable" use pattern 60 | return cfg.clone() 61 | 62 | def update_cfg(cfg, cfg_file): 63 | cfg.merge_from_file(cfg_file) 64 | return cfg.clone() 65 | 66 | def parse_args(): 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--cfg', type=str, help='cfg file path') 69 | 70 | args = parser.parse_args() 71 | print(args, end='\n\n') 72 | 73 | cfg = get_cfg_defaults() 74 | if args.cfg is not None: 
75 | cfg_file = args.cfg 76 | cfg = update_cfg(cfg, args.cfg) 77 | cfg.cfg_file = cfg_file 78 | 79 | return cfg 80 | -------------------------------------------------------------------------------- /libs/criteria/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 |     int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 |     int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 |     scalar_t zero = 0.0;
24 |
25 |     for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 |         scalar_t x = p_x[xi];
27 |
28 |         if (use_bias) {
29 |             x += p_b[(xi / step_b) % size_b];
30 |         }
31 |
32 |         scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 |         scalar_t y;
35 |
36 |         switch (act * 10 + grad) {
37 |             default:
38 |             case 10: y = x; break;
39 |             case 11: y = x; break;
40 |             case 12: y = 0.0; break;
41 |
42 |             case 30: y = (x > 0.0) ? x : x * alpha; break;
43 |             case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 |             case 32: y = 0.0; break;
45 |         }
46 |
47 |         out[xi] = y * scale;
48 |     }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 |     int act, int grad, float alpha, float scale) {
54 |     int curDevice = -1;
55 |     cudaGetDevice(&curDevice);
56 |     cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 |     auto x = input.contiguous();
59 |     auto b = bias.contiguous();
60 |     auto ref = refer.contiguous();
61 |
62 |     int use_bias = b.numel() ? 1 : 0;
63 |     int use_ref = ref.numel() ? 1 : 0;
64 |
65 |     int size_x = x.numel();
66 |     int size_b = b.numel();
67 |     int step_b = 1;
68 |
69 |     for (int i = 1 + 1; i < x.dim(); i++) {
70 |         step_b *= x.size(i);
71 |     }
72 |
73 |     int loop_x = 4;
74 |     int block_size = 4 * 32;
75 |     int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 |     auto y = torch::empty_like(x);
78 |
79 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 |             act,
86 |             grad,
87 |             alpha,
88 |             scale,
89 |             loop_x,
90 |             size_x,
91 |             step_b,
92 |             size_b,
93 |             use_bias,
94 |             use_ref
95 |         );
96 |     });
97 |
98 |     return y;
99 | } --------------------------------------------------------------------------------
/libs/utilities/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torchvision import utils as torch_utils
5 | import glob
6 | from datetime import datetime
7 | import json
8 |
9 | from libs.utilities.stylespace_utils import encoder, decoder
10 |
11 | def make_path(filepath):
12 |     if not os.path.exists(filepath):
13 |         os.makedirs(filepath, exist_ok = True)
14 |
15 | def save_arguments_json(args, save_path, filename):
16 |     out_json = os.path.join(save_path, filename)
17 |     # datetime object containing current date and time
18 |     now = datetime.now()
19 |     dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
20 |     with open(out_json, 'w') as out:
21 |         stat_dict = args
22 |         json.dump(stat_dict, out)
23 |
24 | def get_files_frompath(path, types):
25 |     files_grabbed = []
26 |     for files in types:
27 |         files_grabbed.extend(glob.glob(os.path.join(path, files)))
28 |     files_grabbed.sort()
29 |     return files_grabbed
30 |
31 | def make_noise(batch, dim, truncation=None):
32 |     if
isinstance(dim, int): 33 | dim = [dim] 34 | if truncation is None or truncation == 1.0: 35 | return torch.randn([batch] + dim) 36 | else: 37 | return torch.from_numpy(truncated_noise([batch] + dim, truncation)).to(torch.float) 38 | 39 | def calculate_shapemodel(deca_model, images, image_space = 'gan'): 40 | img_tmp = images.clone() 41 | if image_space == 'gan': 42 | # invert image from [-1,1] to [0,255] 43 | min_val = -1; max_val = 1 44 | img_tmp.clamp_(min=min_val, max=max_val) 45 | img_tmp.add_(-min_val).div_(max_val - min_val + 1e-5) 46 | img_tmp = img_tmp.mul(255.0) 47 | 48 | p_tensor, alpha_shp_tensor, alpha_exp_tensor, angles, cam = deca_model.extract_DECA_params(img_tmp) # params dictionary 49 | out_dict = {} 50 | out_dict['pose'] = p_tensor 51 | out_dict['alpha_exp'] = alpha_exp_tensor 52 | out_dict['alpha_shp'] = alpha_shp_tensor 53 | out_dict['cam'] = cam 54 | 55 | return out_dict, angles.cuda() 56 | 57 | def generate_image(G, latent_code, truncation, trunc, image_resolution, split_sections, input_is_latent = False, return_latents = False, resize_image = True): 58 | 59 | img, _ = G([latent_code], return_latents = return_latents, truncation = truncation, truncation_latent = trunc, input_is_latent = input_is_latent) 60 | style_space, w, noise = encoder(G, latent_code, truncation, trunc, size = image_resolution, input_is_latent = input_is_latent) 61 | if resize_image: 62 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 63 | img = face_pool(img) 64 | 65 | return img, style_space, w, noise 66 | 67 | def generate_new_stylespace(style_source, style_target, mask, num_layers_control = None): 68 | if num_layers_control is not None: 69 | new_style_space = style_source.clone() 70 | mask_size = mask.shape[1] 71 | new_style_space[:, :mask_size] = mask * style_target[:, :mask_size] + (1-mask) * style_source[:, :mask_size] 72 | else: 73 | new_style_space = mask * style_target + (1-mask) * style_source 74 | return new_style_space 75 | 76 | def save_image(image, save_image_dir): 77 | grid = torch_utils.save_image( 78 | image, 79 | save_image_dir, 80 | normalize=True, 81 | range=(-1, 1), 82 | ) 83 | 84 | def save_grid(source_img, target_img, reenacted_img, save_path): 85 | dim = source_img.shape[2] 86 | grid_image = torch.zeros(3, dim , 3 * dim) 87 | grid_image[:, :, :dim] = source_img.squeeze(0) 88 | grid_image[:, :, dim:dim*2] = target_img.squeeze(0) 89 | grid_image[:, :, dim*2:] = reenacted_img.squeeze(0) 90 | save_image(grid_image, save_path) 91 | 92 | -------------------------------------------------------------------------------- /libs/DECA/decalib/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import os, sys 17 | import torch 18 | from torch.utils.data import Dataset, DataLoader 19 | import torchvision.transforms as transforms 20 | import numpy as np 21 | import cv2 22 | import scipy 23 | from skimage.io import imread, imsave 24 | from skimage.transform import estimate_transform, warp, resize, rescale 25 | from glob import glob 26 | import scipy.io 27 | import torch 28 | import kornia 29 | 30 | from . import detectors 31 | 32 | class TestData(Dataset): 33 | def __init__(self, iscrop=True, crop_size=224, scale=1.25): 34 | ''' 35 | testpath: folder, imagepath_list, image path, video path 36 | ''' 37 | self.crop_size = crop_size 38 | self.scale = scale 39 | self.iscrop = iscrop 40 | self.resolution_inp = crop_size 41 | self.face_detector = detectors.FAN() # CHANGE 42 | 43 | 44 | def bbox2point(self, left, right, top, bottom, type='bbox'): 45 | ''' bbox from detector and landmarks are different 46 | ''' 47 | if type=='kpt68': 48 | old_size = (right - left + bottom - top)/2*1.1 49 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 50 | elif type=='bbox': 51 | old_size = (right - left + bottom - top)/2 52 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.12]) 53 | else: 54 | raise NotImplementedError 55 | return old_size, center 56 | 57 | def get_image_tensor(self, image): 58 | " image: tensor 3x256x256" 59 | img_tmp = image.clone() 60 | img_tmp = img_tmp.permute(1,2,0) 61 | bbox, bbox_type = self.face_detector.run(img_tmp) 62 | if bbox_type != 'error': 63 | if len(bbox) < 4: 64 | print('no face detected! run original image') 65 | left = 0; right = h-1; top=0; bottom=w-1 66 | else: 67 | left = bbox[0]; right=bbox[2] 68 | top = bbox[1]; bottom=bbox[3] 69 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type) 70 | size = int(old_size*self.scale) 71 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 72 | 73 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 74 | tform = estimate_transform('similarity', src_pts, DST_PTS) 75 | theta = torch.tensor(tform.params, dtype=torch.float32).unsqueeze(0).cuda() 76 | 77 | image_tensor = image.clone() 78 | image_tensor = image_tensor.unsqueeze(0) 79 | dst_image = kornia.warp_affine(image_tensor, theta[:,:2,:], dsize=(224, 224)) 80 | dst_image = dst_image.div(255.) 81 | 82 | return dst_image.squeeze(0), False 83 | 84 | else: 85 | 86 | return image, True -------------------------------------------------------------------------------- /extract_statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to extract the npy file with the min, max values of facial pose parameters (yaw, pitch, roll, jaw and expressions) 3 | 1. Generate a set of random synthetic images 4 | 2. Use DECA model to extract the facial shape and the corresponding parameters 5 | 3. 
Calculate min, max values 6 | """ 7 | 8 | 9 | import os 10 | import glob 11 | import numpy as np 12 | from PIL import Image 13 | import torch 14 | from torch.nn import functional as F 15 | import matplotlib.pyplot as plt 16 | import json 17 | import cv2 18 | from tqdm import tqdm 19 | import argparse 20 | from torchvision import utils as torch_utils 21 | import warnings 22 | warnings.filterwarnings("ignore") 23 | 24 | from libs.configs.config_models import * 25 | from libs.utilities.utils import make_noise, make_path, calculate_shapemodel 26 | from libs.DECA.estimate_DECA import DECA_model 27 | from libs.models.StyleGAN2.model import Generator as StyleGAN2Generator 28 | 29 | 30 | 31 | def extract_stats(statistics): 32 | 33 | num_stats = statistics.shape[1] 34 | statistics = np.transpose(statistics, (1, 0)) 35 | ranges = [] 36 | for i in range(statistics.shape[0]): 37 | pred = statistics[i, :] 38 | max_ = np.amax(pred) 39 | min_ = np.amin(pred) 40 | if i == 0: 41 | label = 'yaw' 42 | elif i == 1: 43 | label = 'pitch' 44 | elif i == 2: 45 | label = 'roll' 46 | elif i == 3: 47 | label = 'jaw' 48 | else: 49 | label = 'exp_{:02d}'.format(i) 50 | 51 | print('{}/{} Min {:.2f} Max {:.2f}'.format(label, i, min_, max_)) 52 | 53 | ranges.append([min_, max_]) 54 | 55 | return ranges 56 | 57 | 58 | 59 | if __name__ == '__main__': 60 | 61 | num_images = 2000 62 | 63 | image_resolution = 1024 64 | dataset = 'FFHQ' 65 | 66 | output_path = './{}_stats'.format(dataset) 67 | make_path(output_path) 68 | 69 | gan_weights = stylegan2_ffhq_1024['gan_weights'] 70 | channel_multiplier = stylegan2_ffhq_1024['channel_multiplier'] 71 | 72 | print('----- Load generator from {} -----'.format(gan_weights)) 73 | truncation = 0.7 74 | generator = StyleGAN2Generator(image_resolution, 512, 8, channel_multiplier= channel_multiplier) 75 | generator.load_state_dict(torch.load(gan_weights)['g_ema'], strict = True) 76 | generator.cuda().eval() 77 | trunc = generator.mean_latent(4096).detach().clone() 78 | 79 | shape_model = DECA_model('cuda') 80 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 81 | 82 | statistics = [] 83 | with torch.no_grad(): 84 | for i in tqdm(range(num_images)): 85 | z = make_noise(1, 512).cuda() 86 | source_img = generator([z], return_latents = False, truncation = truncation, truncation_latent = trunc, input_is_latent = False)[0] 87 | source_img = face_pool(source_img) 88 | params_source, angles_source = calculate_shapemodel(shape_model, source_img) 89 | 90 | yaw = angles_source[:,0][0].detach().cpu().numpy() 91 | pitch = angles_source[:,1][0].detach().cpu().numpy() 92 | roll = angles_source[:, 2][0].detach().cpu().numpy() 93 | exp = params_source['alpha_exp'][0].detach().cpu().numpy() 94 | jaw = params_source['pose'][0, 3].detach().cpu().numpy() 95 | 96 | tmp = np.zeros(54) 97 | tmp[0] = yaw 98 | tmp[1] = pitch 99 | tmp[2] = roll 100 | tmp[3] = jaw 101 | tmp[4:] = exp 102 | # np.save(os.path.join(output_path, '{:07d}.npy'.format(i)), tmp) 103 | statistics.append(tmp) 104 | 105 | statistics = np.asarray(statistics) 106 | np.save(os.path.join(output_path, 'stats_all.npy'), statistics) 107 | 108 | ranges = extract_stats(statistics) 109 | 110 | np.save(os.path.join(output_path, 'ranges_{}.npy'.format(dataset)), ranges) 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /libs/utilities/ffhq_cropping.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Aling and crop images like in FFHQ dataset 3 | 
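The alignment follows the FFHQ recipe: an oriented crop rectangle is computed from the eye and mouth landmark positions and is then warped to the requested output size.
Illustrative usage (assuming `image` is an RGB uint8 array and `landmarks` a (68, 2) array from a 68-point landmark detector):
    cropped = align_crop_image(image, landmarks, transform_size=4096, output_size=256)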
Code from https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 4 | ''' 5 | import numpy as np 6 | import cv2 7 | import os 8 | import glob 9 | import matplotlib.pyplot as plt 10 | import collections 11 | import PIL.Image 12 | import PIL.ImageFile 13 | from PIL import Image 14 | import scipy.ndimage 15 | 16 | 17 | def align_crop_image(image, landmarks, transform_size = 4096, output_size = 256): 18 | lm = landmarks 19 | lm_chin = lm[0 : 17] # left-right 20 | lm_eyebrow_left = lm[17 : 22] # left-right 21 | lm_eyebrow_right = lm[22 : 27] # left-right 22 | lm_nose = lm[27 : 31] # top-down 23 | lm_nostrils = lm[31 : 36] # top-down 24 | lm_eye_left = lm[36 : 42] # left-clockwise 25 | lm_eye_right = lm[42 : 48] # left-clockwise 26 | lm_mouth_outer = lm[48 : 60] # left-clockwise 27 | lm_mouth_inner = lm[60 : 68] # left-clockwise 28 | 29 | # Calculate auxiliary vectors. 30 | eye_left = np.mean(lm_eye_left, axis=0) 31 | eye_right = np.mean(lm_eye_right, axis=0) 32 | eye_avg = (eye_left + eye_right) * 0.5 33 | eye_to_eye = eye_right - eye_left 34 | mouth_left = lm_mouth_outer[0] 35 | mouth_right = lm_mouth_outer[6] 36 | mouth_avg = (mouth_left + mouth_right) * 0.5 37 | eye_to_mouth = mouth_avg - eye_avg 38 | 39 | # Choose oriented crop rectangle. 40 | 41 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 42 | x /= np.hypot(*x) 43 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 44 | y = np.flipud(x) * [-1, 1] 45 | c = eye_avg + eye_to_mouth * 0.1 46 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 47 | qsize = np.hypot(*x) * 2 48 | 49 | 50 | img = Image.fromarray(image) 51 | 52 | shrink = int(np.floor(qsize / output_size * 0.5)) 53 | if shrink > 1: 54 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 55 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 56 | quad /= shrink 57 | qsize /= shrink 58 | 59 | # Crop. 60 | border = max(int(np.rint(qsize * 0.1)), 3) 61 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 62 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 63 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 64 | img = img.crop(crop) 65 | quad -= crop[0:2] 66 | # Pad. 67 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 68 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 69 | enable_padding = True 70 | if enable_padding and max(pad) > border - 4: 71 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 72 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 73 | h, w, _ = img.shape 74 | y, x, _ = np.ogrid[:h, :w, :1] 75 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 76 | blur = qsize * 0.01 77 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 78 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 79 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 80 | 81 | quad += pad[:2] 82 | 83 | 84 | # Transform. 
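    # PIL's QUAD transform below maps the oriented source quadrilateral `quad` (flattened as upper-left,
    # lower-left, lower-right, upper-right corners) onto the full transform_size x transform_size output
    # square using bilinear sampling; the result is then downscaled to output_size if needed.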
85 | transform_size = 256 86 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 87 | 88 | if output_size < transform_size: 89 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 90 | 91 | pix = np.array(img) 92 | return pix 93 | -------------------------------------------------------------------------------- /libs/criteria/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /libs/utilities/stylespace_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | import os 5 | import math 6 | 7 | 8 | 9 | def conv_warper(layer, input, style, noise): 10 | # the conv should change 11 | conv = layer.conv 12 | batch, in_channel, height, width = input.shape 13 | 14 | style = style.view(batch, 1, in_channel, 1, 1) 15 | weight = conv.scale * conv.weight * style 16 | 17 | if conv.demodulate: 18 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 19 | weight = weight * demod.view(batch, conv.out_channel, 1, 1, 1) 20 | 21 | weight = weight.view( 22 | batch * conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size 23 | ) 24 | 25 | if conv.upsample: 26 | input = input.view(1, batch * in_channel, height, width) 27 | weight = weight.view( 28 | batch, conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size 29 | ) 30 | weight = weight.transpose(1, 2).reshape( 31 | batch * in_channel, conv.out_channel, conv.kernel_size, conv.kernel_size 32 | ) 33 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 34 | _, _, height, width = out.shape 35 | out = out.view(batch, conv.out_channel, height, width) 36 | out = 
conv.blur(out) 37 | 38 | elif conv.downsample: 39 | input = conv.blur(input) 40 | _, _, height, width = input.shape 41 | input = input.view(1, batch * in_channel, height, width) 42 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 43 | _, _, height, width = out.shape 44 | out = out.view(batch, conv.out_channel, height, width) 45 | 46 | else: 47 | input = input.view(1, batch * in_channel, height, width) 48 | out = F.conv2d(input, weight, padding=conv.padding, groups=batch) 49 | _, _, height, width = out.shape 50 | out = out.view(batch, conv.out_channel, height, width) 51 | 52 | out = layer.noise(out, noise=noise) 53 | out = layer.activate(out) 54 | 55 | return out 56 | 57 | def decoder(G, style_space, latent, noise, resize_image = True): 58 | # an decoder warper for G 59 | out = G.input(latent) 60 | out = conv_warper(G.conv1, out, style_space[0], noise[0]) 61 | skip = G.to_rgb1(out, latent[:, 1]) 62 | 63 | 64 | i = 1 65 | for conv1, conv2, noise1, noise2, to_rgb in zip( 66 | G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs 67 | ): 68 | 69 | out = conv_warper(conv1, out, style_space[i], noise=noise1) 70 | out = conv_warper(conv2, out, style_space[i+1], noise=noise2) 71 | skip = to_rgb(out, latent[:, i + 2], skip) 72 | i += 2 73 | 74 | image = skip 75 | 76 | if resize_image: 77 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 78 | image = face_pool(image) 79 | return image 80 | 81 | def encoder(G, noise, truncation, truncation_latent, size = 256, input_is_latent = False): 82 | style_space = [] 83 | # an encoder warper for G 84 | inject_index = None 85 | if not input_is_latent: 86 | inject_index = G.n_latent 87 | styles = [noise] 88 | styles = [G.style(s) for s in styles] 89 | else: 90 | styles = [noise] 91 | 92 | n_latent = int(math.log(size, 2))* 2 - 2 93 | if truncation < 1: 94 | style_t = [] 95 | for style in styles: 96 | style_t.append( 97 | truncation_latent + truncation * (style - truncation_latent) 98 | ) 99 | styles = style_t 100 | 101 | if len(styles) < 2: 102 | inject_index = n_latent 103 | if styles[0].ndim < 3: 104 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 105 | 106 | else: 107 | latent = styles[0] 108 | 109 | else: 110 | if inject_index is None: 111 | inject_index = random.randint(1, n_latent - 1) 112 | 113 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 114 | latent2 = styles[1].unsqueeze(1).repeat(1, n_latent - inject_index, 1) 115 | latent = torch.cat([latent, latent2], 1) 116 | 117 | noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)] 118 | 119 | style_space.append(G.conv1.conv.modulation(latent[:, 0])) 120 | i = 1 121 | for conv1, conv2, noise1, noise2, to_rgb in zip( 122 | G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs 123 | ): 124 | style_space.append(conv1.conv.modulation(latent[:, i])) 125 | style_space.append(conv2.conv.modulation(latent[:, i+1])) 126 | i += 2 127 | 128 | return style_space, latent, noise 129 | 130 | -------------------------------------------------------------------------------- /libs/utilities/utils_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torchvision import utils as torch_utils 5 | import cv2 6 | from skimage import io 7 | 8 | from libs.utilities.image_utils import read_image_opencv, torch_image_resize 9 | from libs.utilities.ffhq_cropping import align_crop_image 10 | 11 | def calculate_evaluation_metrics(params_shifted, 
params_target, angles_shifted, angles_target, imgs_shifted, imgs_source, id_loss_, exp_ranges): 12 | 13 | 14 | ############ Evaluation ############ 15 | yaw_reenacted = angles_shifted[:,0][0].detach().cpu().numpy() 16 | pitch_reenacted = angles_shifted[:,1][0].detach().cpu().numpy() 17 | roll_reenacted = angles_shifted[:,2][0].detach().cpu().numpy() 18 | exp_reenacted = params_shifted['alpha_exp'][0].detach().cpu().numpy() 19 | jaw_reenacted = params_shifted['pose'][0, 3].detach().cpu().numpy() 20 | 21 | yaw_target = angles_target[:,0][0].detach().cpu().numpy() 22 | pitch_target = angles_target[:,1][0].detach().cpu().numpy() 23 | roll_target = angles_target[:,2][0].detach().cpu().numpy() 24 | exp_target = params_target['alpha_exp'][0].detach().cpu().numpy() 25 | jaw_target = params_target['pose'][0, 3].detach().cpu().numpy() 26 | 27 | exp_error = [] 28 | num_expressions = 20 29 | max_range = exp_ranges[3][1] 30 | min_range = exp_ranges[3][0] 31 | jaw_target = (jaw_target - min_range)/(max_range-min_range) 32 | jaw_reenacted = (jaw_reenacted - min_range)/(max_range-min_range) 33 | exp_error.append(abs(jaw_reenacted - jaw_target)) 34 | 35 | for j in range(num_expressions): 36 | max_range = exp_ranges[j+4][1] 37 | min_range = exp_ranges[j+4][0] 38 | target = (exp_target[j] - min_range)/(max_range-min_range) 39 | reenacted = (exp_reenacted[j] - min_range)/(max_range-min_range) 40 | exp_error.append(abs(reenacted - target) ) 41 | exp_error = np.mean(exp_error) 42 | 43 | ## normalize exp coef in [0,1] 44 | # exp_error = [] 45 | # num_expressions = 12 # len(exp_target) 46 | # for j in range(num_expressions): 47 | # exp_error.append(abs(exp_reenacted[j] - exp_target[j]) ) 48 | # exp_error.append(abs(jaw_reenacted - jaw_target)) 49 | # exp_error = np.mean(exp_error) 50 | 51 | pose = (abs(yaw_reenacted-yaw_target) + abs(pitch_reenacted-pitch_target) + abs(roll_reenacted-roll_target))/3 52 | ################################################ 53 | 54 | ###### CSIM ###### 55 | loss_identity = id_loss_(imgs_shifted, imgs_source) 56 | csim = 1 - loss_identity.data.item() 57 | 58 | return csim, pose, exp_error 59 | 60 | def generate_grid_image(source, target, reenacted): 61 | num_images = source.shape[0] # batch size 62 | width = 256; height = 256 63 | grid_image = torch.zeros((3, num_images*height, 3*width)) 64 | for i in range(num_images): 65 | s = i*height 66 | e = s + height 67 | grid_image[:, s:e, :width] = source[i, :, :, :] 68 | grid_image[:, s:e, width:2*width] = target[i, :, :, :] 69 | grid_image[:, s:e, 2*width:] = reenacted[i, :, :, :] 70 | 71 | if grid_image.shape[1] > 1000: # height 72 | grid_image = torch_image_resize(grid_image, height = 800) 73 | return grid_image 74 | 75 | " Crop images using facial landmarks like FFHQ " 76 | def preprocess_image(image_path, landmarks_est, save_filename = None): 77 | 78 | image = read_image_opencv(image_path) 79 | landmarks = landmarks_est.get_landmarks(image)[0] 80 | landmarks = np.asarray(landmarks) 81 | 82 | img = align_crop_image(image, landmarks) 83 | 84 | if img is not None and save_filename is not None: 85 | cv2.imwrite(save_filename, cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR)) 86 | if img is not None: 87 | return img 88 | else: 89 | print('Error with image preprocessing') 90 | exit() 91 | 92 | " Invert real image into the latent space of StyleGAN2 " 93 | def invert_image(image, encoder, generator, truncation, trunc, save_path = None, save_name = None): 94 | with torch.no_grad(): 95 | latent_codes = encoder(image) 96 | inverted_images, _ = 
generator([latent_codes], input_is_latent=True, return_latents = False, truncation= truncation, truncation_latent=trunc) 97 | 98 | if save_path is not None and save_name is not None: 99 | grid = torch_utils.save_image( 100 | inverted_images, 101 | os.path.join(save_path, '{}.png'.format(save_name)), 102 | normalize=True, 103 | range=(-1, 1), 104 | ) 105 | # Latent code 106 | latent_code = latent_codes[0].detach().cpu().numpy() 107 | save_dir = os.path.join(save_path, '{}.npy'.format(save_name)) 108 | np.save(save_dir, latent_code) 109 | 110 | return inverted_images, latent_codes -------------------------------------------------------------------------------- /run_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import random 4 | import sys 5 | import json 6 | import argparse 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | sys.dont_write_bytecode = True 10 | 11 | from libs.trainer import Trainer 12 | 13 | 14 | 15 | def main(): 16 | """ 17 | Training script. 18 | Options: 19 | ######### General ########### 20 | --experiment_path : path to save experiment 21 | --use_wandb : use wandb to log losses and evaluation metrics 22 | --log_images_wandb : if True log images on wandb 23 | --project_wandb : Project name for wandb 24 | --resume_training_model : Path to model to continue training or None 25 | 26 | ######### Generator ######### 27 | --dataset_type : voxceleb or ffhq 28 | --image_resolution : image resolution of pre-trained GAN model. image resolution for voxceleb dataset is 256 29 | 30 | ######### Dataset ######### 31 | --synthetic_dataset_path : set synthetic dataset path for evaluation. npy file with random synthetic latent codes. 32 | 33 | ######### Training ######### 34 | --lr : set the learning rate of direction matrix model 35 | --num_layers_control : number of layers to apply the mask 36 | --max_iter : set maximum number of training iterations 37 | --batch_size : set training batch size 38 | 39 | --lambda_identity : identity loss weight 40 | --lambda_perceptual : perceptual loss weight 41 | --lambda_shape : shape loss weight 42 | --use_recurrent_cycle_loss : use recurrent cycle loss 43 | 44 | --steps_per_log : set number iterations per log 45 | --steps_per_save_models : set number iterations per saving model 46 | --steps_per_evaluation : set number iterations per model evaluation during training 47 | --validation_pairs : number of validation pairs for evaluation 48 | --num_pairs_log : number of pairs to visualize during evaluation 49 | 50 | ###################### 51 | 52 | python run_trainer.py --experiment_path ./training_attempts/exp_v00 53 | """ 54 | parser = argparse.ArgumentParser(description="training script") 55 | 56 | ######### General ########### 57 | parser.add_argument('--experiment_path', type=str, required = True, help="path to save the experiment") 58 | parser.add_argument('--use_wandb', dest='use_wandb', action='store_true', help="use wandb to log losses and evaluation metrics") 59 | parser.set_defaults(use_wandb=False) 60 | parser.add_argument('--log_images_wandb', dest='log_images_wandb', action='store_true', help="if True log images on wandb") 61 | parser.set_defaults(log_images_wandb=False) 62 | parser.add_argument('--project_wandb', type=str, default = 'stylespace', help="Project name for wandb") 63 | 64 | parser.add_argument('--resume_training_model', type=str, default = None, help="Path to model or None") 65 | 66 | ######### Generator ######### 67 | 
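    # --image_resolution must match the pre-trained StyleGAN2 checkpoint: 1024 for the FFHQ generator, 256 for a VoxCeleb generator (see the docstring above).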
    parser.add_argument('--image_resolution', type=int, default=1024, choices=(256, 1024), help="image resolution of the pre-trained GAN model")
68 |     parser.add_argument('--dataset_type', type=str, default='ffhq', help="set dataset name (ffhq or voxceleb)")
69 |
70 |     ######### Dataset #########
71 |     parser.add_argument('--synthetic_dataset_path', type=str, default=None, help="set synthetic dataset path for evaluation")
72 |
73 |     ######### Training #########
74 |     parser.add_argument('--lr', type=float, default=0.0001, help="set the learning rate of the direction matrix model")
75 |     parser.add_argument('--num_layers_control', type=int, default=12, help="set number of layers to apply the mask")
76 |     parser.add_argument('--max_iter', type=int, default=100000, help="set maximum number of training iterations")
77 |     parser.add_argument('--batch_size', type=int, default=12, help="set training batch size")
78 |     parser.add_argument('--test_batch_size', type=int, default=2, help="set test batch size")
79 |     parser.add_argument('--workers', type=int, default=1, help="set number of workers")
80 |
81 |     parser.add_argument('--lambda_identity', type=float, default=10.0, help="identity loss weight")
82 |     parser.add_argument('--lambda_perceptual', type=float, default=0.0, help="perceptual loss weight")
83 |     parser.add_argument('--lambda_shape', type=float, default=1.0, help="shape loss weight")
84 |     parser.add_argument('--use_recurrent_cycle_loss', dest='use_recurrent_cycle_loss', action='store_false', help="recurrent cycle loss is enabled by default; pass this flag to disable it")
85 |
86 |     parser.add_argument('--steps_per_log', type=int, default=10, help="set number of iterations per log")
87 |     parser.add_argument('--steps_per_save_models', type=int, default=1000, help="set number of iterations per model save")
88 |     parser.add_argument('--steps_per_evaluation', type=int, default=1000, help="set number of iterations per evaluation during training")
89 |     parser.add_argument('--validation_pairs', type=int, default=100, help="number of validation pairs for evaluation")
90 |     parser.add_argument('--num_pairs_log', type=int, default=4, help="number of pairs to visualize on the reenactment figure")
91 |
92 |     # Parse given arguments
93 |     args = parser.parse_args()
94 |     args = vars(args) # convert to dictionary
95 |
96 |     # Create output dir and save current arguments
97 |     experiment_path = args['experiment_path']
98 |     experiment_path = experiment_path + '_{}_{}'.format(args['dataset_type'], args['image_resolution'])
99 |     args['experiment_path'] = experiment_path
100 |     # Set up trainer
101 |     print("#.
Experiment: {}".format(experiment_path)) 102 | 103 | 104 | trainer = Trainer(args) 105 | trainer.train() 106 | 107 | 108 | if __name__ == '__main__': 109 | main() 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | 8 | module_path = os.path.dirname(__file__) 9 | upfirdn2d_op = load( 10 | 'upfirdn2d', 11 | sources=[ 12 | os.path.join(module_path, 'upfirdn2d.cpp'), 13 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class UpFirDn2dBackward(Function): 19 | @staticmethod 20 | def forward( 21 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 22 | ): 23 | 24 | up_x, up_y = up 25 | down_x, down_y = down 26 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 27 | 28 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 29 | 30 | grad_input = upfirdn2d_op.upfirdn2d( 31 | grad_output, 32 | grad_kernel, 33 | down_x, 34 | down_y, 35 | up_x, 36 | up_y, 37 | g_pad_x0, 38 | g_pad_x1, 39 | g_pad_y0, 40 | g_pad_y1, 41 | ) 42 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 43 | 44 | ctx.save_for_backward(kernel) 45 | 46 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 47 | 48 | ctx.up_x = up_x 49 | ctx.up_y = up_y 50 | ctx.down_x = down_x 51 | ctx.down_y = down_y 52 | ctx.pad_x0 = pad_x0 53 | ctx.pad_x1 = pad_x1 54 | ctx.pad_y0 = pad_y0 55 | ctx.pad_y1 = pad_y1 56 | ctx.in_size = in_size 57 | ctx.out_size = out_size 58 | 59 | return grad_input 60 | 61 | @staticmethod 62 | def backward(ctx, gradgrad_input): 63 | kernel, = ctx.saved_tensors 64 | 65 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 66 | 67 | gradgrad_out = upfirdn2d_op.upfirdn2d( 68 | gradgrad_input, 69 | kernel, 70 | ctx.up_x, 71 | ctx.up_y, 72 | ctx.down_x, 73 | ctx.down_y, 74 | ctx.pad_x0, 75 | ctx.pad_x1, 76 | ctx.pad_y0, 77 | ctx.pad_y1, 78 | ) 79 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 80 | gradgrad_out = gradgrad_out.view( 81 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 82 | ) 83 | 84 | return gradgrad_out, None, None, None, None, None, None, None, None 85 | 86 | 87 | class UpFirDn2d(Function): 88 | @staticmethod 89 | def forward(ctx, input, kernel, up, down, pad): 90 | up_x, up_y = up 91 | down_x, down_y = down 92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 93 | 94 | kernel_h, kernel_w = kernel.shape 95 | batch, channel, in_h, in_w = input.shape 96 | ctx.in_size = input.shape 97 | 98 | input = input.reshape(-1, in_h, in_w, 1) 99 | 100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 101 | 102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 104 | ctx.out_size = (out_h, out_w) 105 | 106 | ctx.up = (up_x, up_y) 107 | ctx.down = (down_x, down_y) 108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 109 | 110 | g_pad_x0 = kernel_w - pad_x0 - 1 111 | g_pad_y0 = kernel_h - pad_y0 - 1 112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 114 | 115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 116 | 117 | out = upfirdn2d_op.upfirdn2d( 118 | input, kernel, up_x, 
up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 119 | ) 120 | # out = out.view(major, out_h, out_w, minor) 121 | out = out.view(-1, channel, out_h, out_w) 122 | 123 | return out 124 | 125 | @staticmethod 126 | def backward(ctx, grad_output): 127 | kernel, grad_kernel = ctx.saved_tensors 128 | 129 | grad_input = UpFirDn2dBackward.apply( 130 | grad_output, 131 | kernel, 132 | grad_kernel, 133 | ctx.up, 134 | ctx.down, 135 | ctx.pad, 136 | ctx.g_pad, 137 | ctx.in_size, 138 | ctx.out_size, 139 | ) 140 | 141 | return grad_input, None, None, None, None 142 | 143 | 144 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 145 | out = UpFirDn2d.apply( 146 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 147 | ) 148 | 149 | return out 150 | 151 | 152 | def upfirdn2d_native( 153 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 154 | ): 155 | _, in_h, in_w, minor = input.shape 156 | kernel_h, kernel_w = kernel.shape 157 | 158 | out = input.view(-1, in_h, 1, in_w, 1, minor) 159 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 160 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 161 | 162 | out = F.pad( 163 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 164 | ) 165 | out = out[ 166 | :, 167 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 169 | :, 170 | ] 171 | 172 | out = out.permute(0, 3, 1, 2) 173 | out = out.reshape( 174 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 175 | ) 176 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 177 | out = F.conv2d(out, w) 178 | out = out.reshape( 179 | -1, 180 | minor, 181 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 182 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 183 | ) 184 | out = out.permute(0, 2, 3, 1) 185 | 186 | return out[:, ::down_y, ::down_x, :] 187 | 188 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleMask: Disentangling the Style Space of StyleGAN2 for Neural Face Reenactment 2 | 3 | Authors official PyTorch implementation of **[StyleMask: Disentangling the Style Space of StyleGAN2 for Neural Face Reenactment](https://arxiv.org/abs/2209.13375)**. This paper has been accepted for publication at IEEE Conference on Automatic Face and Gesture Recognition, 2023. 4 | 5 |

6 | 7 |

8 | 9 | >**StyleMask: Disentangling the Style Space of StyleGAN2 for Neural Face Reenactment**
10 | > Stella Bounareli, Christos Tzelepis, Vasileios Argyriou, Ioannis Patras, Georgios Tzimiropoulos
11 | > 12 | > **Abstract**: In this paper we address the problem of neural face reenactment, where, given a pair of a source and a target facial image, we need to transfer the target's pose (defined as the head pose and its facial expressions) to the source image, by preserving at the same time the source's identity characteristics (e.g., facial shape, hair style, etc), even in the challenging case where the source and the target faces belong to different identities. In doing so, we address some of the limitations of the state-of-the-art works, namely, a) that they depend on paired training data (i.e., source and target faces have the same identity), b) that they rely on labeled data during inference, and c) that they do not preserve identity in large head pose changes. More specifically, we propose a framework that, using unpaired randomly generated facial images, learns to disentangle the identity characteristics of the face from its pose by incorporating the recently introduced style space $\mathcal{S}$ of StyleGAN2, a latent representation space that exhibits remarkable disentanglement properties. By capitalizing on this, we learn to successfully mix a pair of source and target style codes using supervision from a 3D model. The resulting latent code, that is subsequently used for reenactment, consists of latent units corresponding to the facial pose of the target only and of units corresponding to the identity of the source only, leading to notable improvement in the reenactment performance compared to recent state-of-the-art methods. In comparison to state of the art, we quantitatively and qualitatively show that the proposed method produces higher quality results even on extreme pose variations. Finally, we report results on real images by first embedding them on the latent space of the pretrained generator. 13 | 14 | 15 | 16 | ## Face Reenactment Results 17 |
18 | 19 | 20 |

21 | 22 |

23 | 24 |

25 | 26 |

27 | 28 | # Installation 29 | 30 | * Python 3.5+ 31 | * Linux 32 | * NVIDIA GPU + CUDA CuDNN 33 | * Pytorch (>=1.5) 34 | * [Pytorch3d](https://github.com/facebookresearch/pytorch3d) 35 | * [DECA](https://github.com/YadiraF/DECA) 36 | 37 | We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/). 38 | 39 | ``` 40 | conda create -n python38 python=3.8 41 | conda activate python38 42 | conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=11.0 -c pytorch 43 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 44 | conda install pytorch3d -c pytorch3d 45 | pip install -r requirements.txt 46 | 47 | ``` 48 | 49 | # Pretrained Models 50 | 51 | In order to use our method, make sure to download and save the required models under `./pretrained_models` path. 52 | 53 | | Path | Description 54 | | :--- | :---------- 55 | |[StyleGAN2-FFHQ-1024](https://drive.google.com/file/d/1I01HVu9UUyzAV7rNNnbzyFqPF0msebD7/view?usp=share_link) | Official StyleGAN2 model trained on FFHQ 1024x1024 output resolution converted using [rosinality](https://github.com/rosinality/stylegan2-pytorch)(FFHQ 1024x1024 output resolution). 56 | |[e4e-FFHQ-1024](https://drive.google.com/file/d/1DexTMA3QMRNwQ3Xhdojki8UAYuY3g0uu/view?usp=share_link) | Official e4e inversion model trained on FFHQ dataset taken from [e4e](https://github.com/omertov/encoder4editing). In case of using real images use this model to invert them into the latent space of StyleGAN2. 57 | |[stylemask-model](https://drive.google.com/file/d/1_V_MnFB8rh5qrQ3zJk00fKXb81uHjIU8/view?usp=share_link) | Our pretrained StyleMask model on FFHQ 1024x1024 output resolution. 58 | 59 | # Inference 60 | Given a pair of images or latent codes transfer the target facial pose into the source face. 61 | Source and target paths could be None (generate random latent codes pairs), latent code files, image files or directories with images or latent codes. In case of input paths are real images the script will use the e4e inversion model to get the inverted latent codes. 62 | 63 | ``` 64 | python run_inference.py --output_path ./results --save_grid 65 | ``` 66 | 67 | # Training 68 | 69 | We provide additional models needed during training. 70 | 71 | | Path | Description 72 | | :--- | :---------- 73 | |[IR-SE50 Model](https://drive.google.com/file/d/1s5pWag4AwqQyhue6HH-M_f2WDV4IVZEl/view?usp=sharing) | Pretrained IR-SE50 model taken from [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our identity loss. 74 | |[DECA models](https://drive.google.com/file/d/1BHVJAEXscaXMj_p2rOsHYF_vaRRRHQbA/view?usp=sharing) | Pretrained models taken from [DECA](https://github.com/YadiraF/DECA). Extract data.tar.gz under `./libs/DECA/`. 75 | 76 | By default, we assume that all pretrained models are downloaded and saved to the directory `./pretrained_models`. 77 | 78 | ``` 79 | python run_trainer.py --experiment_path ./training_attempts/exp_v00 80 | ``` 81 | 82 | ## Citation 83 | 84 | [1] Stella Bounareli, Christos Tzelepis, Argyriou Vasileios, Ioannis Patras, Georgios Tzimiropoulos. StyleMask: Disentangling the Style Space of StyleGAN2 for Neural Face Reenactment. IEEE Conference on Automatic Face and Gesture Recognition (FG), 2023. 
85 | 86 | Bibtex entry: 87 | 88 | ```bibtex 89 | @article{bounareli2022StyleMask, 90 | author = {Bounareli, Stella and Tzelepis, Christos and Argyriou, Vasileios and Patras, Ioannis and Tzimiropoulos, Georgios}, 91 | title = {StyleMask: Disentangling the Style Space of StyleGAN2 for Neural Face Reenactment}, 92 | journal = {IEEE Conference on Automatic Face and Gesture Recognition}, 93 | year = {2023}, 94 | } 95 | ``` 96 | 97 | 98 | 99 | ## Acknowledgment 100 | 101 | This work was supported by the EU's Horizon 2020 programme H2020-951911 [AI4Media](https://www.ai4media.eu/) project. 102 | 103 | -------------------------------------------------------------------------------- /libs/models/inversion/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Linear 4 | import torch.nn.functional as F 5 | 6 | """ 7 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | return output 20 | 21 | 22 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 23 | """ A named tuple describing a ResNet block. """ 24 | 25 | 26 | def get_block(in_channel, depth, num_units, stride=2): 27 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 28 | 29 | 30 | def get_blocks(num_layers): 31 | if num_layers == 50: 32 | blocks = [ 33 | get_block(in_channel=64, depth=64, num_units=3), 34 | get_block(in_channel=64, depth=128, num_units=4), 35 | get_block(in_channel=128, depth=256, num_units=14), 36 | get_block(in_channel=256, depth=512, num_units=3) 37 | ] 38 | elif num_layers == 100: 39 | blocks = [ 40 | get_block(in_channel=64, depth=64, num_units=3), 41 | get_block(in_channel=64, depth=128, num_units=13), 42 | get_block(in_channel=128, depth=256, num_units=30), 43 | get_block(in_channel=256, depth=512, num_units=3) 44 | ] 45 | elif num_layers == 152: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=8), 49 | get_block(in_channel=128, depth=256, num_units=36), 50 | get_block(in_channel=256, depth=512, num_units=3) 51 | ] 52 | else: 53 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 54 | return blocks 55 | 56 | 57 | class SEModule(Module): 58 | def __init__(self, channels, reduction): 59 | super(SEModule, self).__init__() 60 | self.avg_pool = AdaptiveAvgPool2d(1) 61 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 62 | self.relu = ReLU(inplace=True) 63 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 64 | self.sigmoid = Sigmoid() 65 | 66 | def forward(self, x): 67 | module_input = x 68 | x = self.avg_pool(x) 69 | x = self.fc1(x) 70 | x = self.relu(x) 71 | x = self.fc2(x) 72 | x = self.sigmoid(x) 73 | return module_input * x 74 | 75 | 76 | class bottleneck_IR(Module): 77 | def __init__(self, in_channel, depth, stride): 78 | super(bottleneck_IR, self).__init__() 79 | if in_channel == depth: 80 | self.shortcut_layer = MaxPool2d(1, stride) 81 | else: 82 | self.shortcut_layer = Sequential( 83 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 84 | BatchNorm2d(depth) 85 | ) 86 | self.res_layer = Sequential( 87 | BatchNorm2d(in_channel), 88 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 89 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 90 | ) 91 | 92 | def forward(self, x): 93 | shortcut = self.shortcut_layer(x) 94 | res = self.res_layer(x) 95 | return res + shortcut 96 | 97 | 98 | class bottleneck_IR_SE(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR_SE, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth) 107 | ) 108 | 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | 120 | shortcut = self.shortcut_layer(x) 121 | res = self.res_layer(x) 122 | return res + shortcut 123 | 124 | 125 | class SeparableConv2d(torch.nn.Module): 126 | 127 | def __init__(self, in_channels, out_channels, kernel_size, bias=False): 128 | super(SeparableConv2d, self).__init__() 129 | self.depthwise = Conv2d(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, bias=bias, padding=1) 130 | self.pointwise = Conv2d(in_channels, out_channels, kernel_size=1, bias=bias) 131 | 132 | def forward(self, x): 133 | out = self.depthwise(x) 134 | out = self.pointwise(out) 135 | return out 136 | 137 | 138 | def _upsample_add(x, y): 139 | """Upsample and add two feature maps. 140 | Args: 141 | x: (Variable) top feature map to be upsampled. 142 | y: (Variable) lateral feature map. 143 | Returns: 144 | (Variable) added feature map. 145 | Note in PyTorch, when input size is odd, the upsampled feature map 146 | with `F.upsample(..., scale_factor=2, mode='nearest')` 147 | maybe not equal to the lateral feature map size. 148 | e.g. 149 | original input size: [N,_,15,15] -> 150 | conv2d feature map size: [N,_,8,8] -> 151 | upsampled feature map size: [N,_,16,16] 152 | So we choose bilinear upsample which supports arbitrary output sizes. 
153 | """ 154 | _, _, H, W = y.size() 155 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 156 | 157 | 158 | class SeparableBlock(Module): 159 | 160 | def __init__(self, input_size, kernel_channels_in, kernel_channels_out, kernel_size): 161 | super(SeparableBlock, self).__init__() 162 | 163 | self.input_size = input_size 164 | self.kernel_size = kernel_size 165 | self.kernel_channels_in = kernel_channels_in 166 | self.kernel_channels_out = kernel_channels_out 167 | 168 | self.make_kernel_in = Linear(input_size, kernel_size * kernel_size * kernel_channels_in) 169 | self.make_kernel_out = Linear(input_size, kernel_size * kernel_size * kernel_channels_out) 170 | 171 | self.kernel_linear_in = Linear(kernel_channels_in, kernel_channels_in) 172 | self.kernel_linear_out = Linear(kernel_channels_out, kernel_channels_out) 173 | 174 | def forward(self, features): 175 | 176 | features = features.view(-1, self.input_size) 177 | 178 | kernel_in = self.make_kernel_in(features).view(-1, self.kernel_size, self.kernel_size, 1, self.kernel_channels_in) 179 | kernel_out = self.make_kernel_out(features).view(-1, self.kernel_size, self.kernel_size, self.kernel_channels_out, 1) 180 | 181 | kernel = torch.matmul(kernel_out, kernel_in) 182 | 183 | kernel = self.kernel_linear_in(kernel).permute(0, 1, 2, 4, 3) 184 | kernel = self.kernel_linear_out(kernel) 185 | kernel = kernel.permute(0, 4, 3, 1, 2) 186 | 187 | return kernel 188 | -------------------------------------------------------------------------------- /libs/DECA/README.md: -------------------------------------------------------------------------------- 1 | # DECA 3D shape model 2 | 3 | https://github.com/YadiraF/DECA 4 | 5 | Install pytorch 3D 6 | https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md 7 | 8 | 9 | # Installation 10 | 11 | 12 | ## Requirements 13 | 14 | ### Core library 15 | 16 | The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/Pytorch. It is advised to use PyTorch3D with GPU support in order to use all the features. 17 | 18 | - Linux or macOS or Windows 19 | - Python 3.6, 3.7 or 3.8 20 | - PyTorch 1.4, 1.5.0, 1.5.1, 1.6.0, 1.7.0, or 1.7.1. 21 | - torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this. 22 | - gcc & g++ ≥ 4.9 23 | - [fvcore](https://github.com/facebookresearch/fvcore) 24 | - [ioPath](https://github.com/facebookresearch/iopath) 25 | - If CUDA is to be used, use a version which is supported by the corresponding pytorch version and at least version 9.2. 26 | - If CUDA is to be used and you are building from source, the CUB library must be available. We recommend version 1.10.0. 27 | 28 | The runtime dependencies can be installed by running: 29 | ``` 30 | conda create -n pytorch3d python=3.8 31 | conda activate pytorch3d 32 | conda install -c pytorch pytorch=1.7.1 torchvision cudatoolkit=10.2 33 | conda install -c conda-forge -c fvcore -c iopath fvcore iopath 34 | ``` 35 | 36 | For the CUB build time dependency, if you are using conda, you can continue with 37 | ``` 38 | conda install -c bottler nvidiacub 39 | ``` 40 | Otherwise download the CUB library from https://github.com/NVIDIA/cub/releases and unpack it to a folder of your choice. 41 | Define the environment variable CUB_HOME before building and point it to the directory that contains `CMakeLists.txt` for CUB. 
42 | For example on Linux/Mac, 43 | ``` 44 | curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz 45 | tar xzf 1.10.0.tar.gz 46 | export CUB_HOME=$PWD/cub-1.10.0 47 | ``` 48 | 49 | ### Tests/Linting and Demos 50 | 51 | For developing on top of PyTorch3D or contributing, you will need to run the linter and tests. If you want to run any of the notebook tutorials as `docs/tutorials` or the examples in `docs/examples` you will also need matplotlib and OpenCV. 52 | - scikit-image 53 | - black 54 | - isort 55 | - flake8 56 | - matplotlib 57 | - tdqm 58 | - jupyter 59 | - imageio 60 | - plotly 61 | - opencv-python 62 | 63 | These can be installed by running: 64 | ``` 65 | # Demos and examples 66 | conda install jupyter 67 | pip install scikit-image matplotlib imageio plotly opencv-python 68 | 69 | # Tests/Linting 70 | pip install black 'isort<5' flake8 flake8-bugbear flake8-comprehensions 71 | ``` 72 | 73 | ## Installing prebuilt binaries for PyTorch3D 74 | After installing the above dependencies, run one of the following commands: 75 | 76 | ### 1. Install with CUDA support from Anaconda Cloud, on Linux only 77 | 78 | ``` 79 | # Anaconda Cloud 80 | conda install pytorch3d -c pytorch3d 81 | ``` 82 | 83 | Or, to install a nightly (non-official, alpha) build: 84 | ``` 85 | # Anaconda Cloud 86 | conda install pytorch3d -c pytorch3d-nightly 87 | ``` 88 | ### 2. Install from PyPI, on Mac only. 89 | This works with pytorch 1.7.1 only. The build is CPU only. 90 | ``` 91 | pip install pytorch3d 92 | ``` 93 | 94 | ### 3. Install wheels for Linux 95 | We have prebuilt wheels with CUDA for Linux for PyTorch 1.7.0 and 1.7.1, for each of the CUDA versions that they support. 96 | These are installed in a special way. 97 | For example, to install for Python 3.6, PyTorch 1.7.0 and CUDA 10.1 98 | ``` 99 | pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py36_cu101_pyt170/download.html 100 | ``` 101 | 102 | In general, from inside IPython, or in Google Colab or a jupyter notebook, you can install with 103 | ``` 104 | import sys 105 | import torch 106 | version_str="".join([ 107 | f"py3{sys.version_info.minor}_cu", 108 | torch.version.cuda.replace(".",""), 109 | f"_pyt{torch.__version__[0:5:2]}" 110 | ]) 111 | !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html 112 | ``` 113 | 114 | ## Building / installing from source. 115 | CUDA support will be included if CUDA is available in pytorch or if the environment variable 116 | `FORCE_CUDA` is set to `1`. 117 | 118 | ### 1. Install from GitHub 119 | ``` 120 | pip install "git+https://github.com/facebookresearch/pytorch3d.git" 121 | ``` 122 | To install using the code of the released version instead of from the main branch, use the following instead. 123 | ``` 124 | pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" 125 | ``` 126 | 127 | For CUDA builds with versions earlier than CUDA 11, set `CUB_HOME` before building as described above. 128 | 129 | **Install from Github on macOS:** 130 | Some environment variables should be provided, like this. 131 | ``` 132 | MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install "git+https://github.com/facebookresearch/pytorch3d.git" 133 | ``` 134 | 135 | ### 2. Install from a local clone 136 | ``` 137 | git clone https://github.com/facebookresearch/pytorch3d.git 138 | cd pytorch3d && pip install -e . 
139 | ``` 140 | To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then `pip install -e .`. You often need to rebuild pytorch3d after reinstalling PyTorch. For CUDA builds with versions earlier than CUDA 11, set `CUB_HOME` before building as described above. 141 | 142 | **Install from local clone on macOS:** 143 | ``` 144 | MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install -e . 145 | ``` 146 | 147 | **Install from local clone on Windows:** 148 | 149 | If you are using pre-compiled pytorch 1.4 and torchvision 0.5, you should make the following changes to the pytorch source code to successfully compile with Visual Studio 2019 (MSVC 19.16.27034) and CUDA 10.1. 150 | 151 | Change python/Lib/site-packages/torch/include/csrc/jit/script/module.h 152 | 153 | L466, 476, 493, 506, 536 154 | ``` 155 | -static constexpr * 156 | +static const * 157 | ``` 158 | Change python/Lib/site-packages/torch/include/csrc/jit/argument_spec.h 159 | 160 | L190 161 | ``` 162 | -static constexpr size_t DEPTH_LIMIT = 128; 163 | +static const size_t DEPTH_LIMIT = 128; 164 | ``` 165 | 166 | Change python/Lib/site-packages/torch/include/pybind11/cast.h 167 | 168 | L1449 169 | ``` 170 | -explicit operator type&() { return *(this->value); } 171 | +explicit operator type& () { return *((type*)(this->value)); } 172 | ``` 173 | 174 | After patching, you can go to "x64 Native Tools Command Prompt for VS 2019" to compile and install 175 | ``` 176 | cd pytorch3d 177 | python3 setup.py install 178 | ``` 179 | After installing, verify whether all unit tests have passed 180 | ``` 181 | cd tests 182 | python3 -m unittest discover -p *.py 183 | ``` 184 | 185 | # FAQ 186 | 187 | ### Can I use Docker? 188 | 189 | We don't provide a docker file but see [#113](https://github.com/facebookresearch/pytorch3d/issues/113) for a docker file shared by a user (NOTE: this has not been tested by the PyTorch3D team). 190 | -------------------------------------------------------------------------------- /libs/models/StyleGAN2/convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import pickle 5 | import math 6 | 7 | import torch 8 | import numpy as np 9 | from torchvision import utils 10 | 11 | from models.StyleGAN2.model import Generator, Discriminator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + '/weight'].value().eval() 16 | mod_weight = vars[source_name + '/mod_weight'].value().eval() 17 | mod_bias = vars[source_name + '/mod_bias'].value().eval() 18 | noise = vars[source_name + '/noise_strength'].value().eval() 19 | bias = vars[source_name + '/bias'].value().eval() 20 | 21 | dic = { 22 | 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | 'conv.modulation.weight': mod_weight.transpose((1, 0)), 24 | 'conv.modulation.bias': mod_bias + 1, 25 | 'noise.weight': np.array([noise]), 26 | 'activate.bias': bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + '.' 
+ k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + '.conv.weight'] = torch.flip( 36 | dic_torch[target_name + '.conv.weight'], [3, 4] 37 | ) 38 | 39 | return dic_torch 40 | 41 | 42 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 43 | weight = vars[source_name + '/weight'].value().eval() 44 | 45 | dic = {'weight': weight.transpose((3, 2, 0, 1))} 46 | 47 | if bias: 48 | dic['bias'] = vars[source_name + '/bias'].value().eval() 49 | 50 | dic_torch = {} 51 | dic_torch[target_name + '.{}.weight'.format(start)] = torch.from_numpy(dic['weight']) 52 | 53 | if bias: 54 | dic_torch[target_name + '.{}.bias'.format(start + 1)] = torch.from_numpy(dic['bias']) 55 | 56 | return dic_torch 57 | 58 | 59 | def convert_torgb(vars, source_name, target_name): 60 | weight = vars[source_name + '/weight'].value().eval() 61 | mod_weight = vars[source_name + '/mod_weight'].value().eval() 62 | mod_bias = vars[source_name + '/mod_bias'].value().eval() 63 | bias = vars[source_name + '/bias'].value().eval() 64 | 65 | dic = { 66 | 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 67 | 'conv.modulation.weight': mod_weight.transpose((1, 0)), 68 | 'conv.modulation.bias': mod_bias + 1, 69 | 'bias': bias.reshape((1, 3, 1, 1)), 70 | } 71 | 72 | dic_torch = {} 73 | 74 | for k, v in dic.items(): 75 | dic_torch[target_name + '.' + k] = torch.from_numpy(v) 76 | 77 | return dic_torch 78 | 79 | 80 | def convert_dense(vars, source_name, target_name): 81 | weight = vars[source_name + '/weight'].value().eval() 82 | bias = vars[source_name + '/bias'].value().eval() 83 | 84 | dic = {'weight': weight.transpose((1, 0)), 'bias': bias} 85 | 86 | dic_torch = {} 87 | 88 | for k, v in dic.items(): 89 | dic_torch[target_name + '.' + k] = torch.from_numpy(v) 90 | 91 | return dic_torch 92 | 93 | 94 | def update(state_dict, new): 95 | for k, v in new.items(): 96 | if k not in state_dict: 97 | raise KeyError(k + ' is not found') 98 | 99 | if v.shape != state_dict[k].shape: 100 | raise ValueError('Shape mismatch: {} vs {}'.format(v.shape, state_dict[k].shape)) 101 | 102 | state_dict[k] = v 103 | 104 | 105 | def discriminator_fill_statedict(statedict, vars, size): 106 | log_size = int(math.log(size, 2)) 107 | 108 | update(statedict, convert_conv(vars, '{}x{}/FromRGB'.format(size, size), 'convs.0')) 109 | 110 | conv_i = 1 111 | 112 | for i in range(log_size - 2, 0, -1): 113 | reso = 4 * 2 ** i 114 | update(statedict, convert_conv(vars, '{}x{}/Conv0'.format(reso, reso), 'convs.{}.conv1'.format(conv_i))) 115 | update(statedict, convert_conv(vars, '{}x{}/Conv1_down'.format(reso, reso), 'convs.{}.conv2'.format(conv_i), start=1)) 116 | update(statedict, convert_conv(vars, '{}x{}/Skip'.format(reso, reso), 'convs.{}.skip'.format(conv_i), start=1, bias=False)) 117 | conv_i += 1 118 | 119 | update(statedict, convert_conv(vars, '4x4/Conv', 'final_conv')) 120 | update(statedict, convert_dense(vars, '4x4/Dense0', 'final_linear.0')) 121 | update(statedict, convert_dense(vars, 'Output', 'final_linear.1')) 122 | 123 | return statedict 124 | 125 | 126 | def fill_statedict(state_dict, vars, size): 127 | log_size = int(math.log(size, 2)) 128 | 129 | for i in range(8): 130 | update(state_dict, convert_dense(vars, 'G_mapping/Dense{}'.format(i), 'style.{}'.format(i + 1))) 131 | 132 | update( 133 | state_dict, 134 | { 135 | 'input.input': torch.from_numpy( 136 | vars['G_synthesis/4x4/Const/const'].value().eval() 137 | ) 138 | }, 139 | ) 140 | 141 | update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 
'to_rgb1')) 142 | 143 | for i in range(log_size - 2): 144 | reso = 4 * 2 ** (i + 1) 145 | update( 146 | state_dict, 147 | convert_torgb(vars, 'G_synthesis/{}x{}/ToRGB'.format(reso, reso), 'to_rgbs.{}'.format(i)), 148 | ) 149 | 150 | update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1')) 151 | 152 | conv_i = 0 153 | 154 | for i in range(log_size - 2): 155 | reso = 4 * 2 ** (i + 1) 156 | update( 157 | state_dict, 158 | convert_modconv( 159 | vars, 160 | 'G_synthesis/{}x{}/Conv0_up'.format(reso, reso), 161 | 'convs.{}'.format(conv_i), 162 | flip=True, 163 | ), 164 | ) 165 | update( 166 | state_dict, 167 | convert_modconv( 168 | vars, 169 | 'G_synthesis/{}x{}/Conv1'.format(reso, reso), 170 | 'convs.{}'.format(conv_i + 1) 171 | ), 172 | ) 173 | conv_i += 2 174 | 175 | for i in range(0, (log_size - 2) * 2 + 1): 176 | update( 177 | state_dict, 178 | { 179 | 'noises.noise_{}'.format(i): torch.from_numpy( 180 | vars['G_synthesis/noise{}'.format(i)].value().eval() 181 | ) 182 | }, 183 | ) 184 | 185 | return state_dict 186 | 187 | 188 | if __name__ == '__main__': 189 | device = 'cuda' 190 | 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument('--repo', type=str, required=True) 193 | parser.add_argument('--gen', action='store_true') 194 | parser.add_argument('--disc', action='store_true') 195 | parser.add_argument('path', metavar='PATH') 196 | 197 | args = parser.parse_args() 198 | 199 | sys.path.append(args.repo) 200 | 201 | import dnnlib 202 | from dnnlib import tflib 203 | 204 | tflib.init_tf() 205 | 206 | with open(args.path, 'rb') as f: 207 | generator, discriminator, g_ema = pickle.load(f) 208 | 209 | size = g_ema.output_shape[2] 210 | 211 | g = Generator(size, 512, 8) 212 | state_dict = g.state_dict() 213 | state_dict = fill_statedict(state_dict, g_ema.vars, size) 214 | 215 | g.load_state_dict(state_dict) 216 | 217 | latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval()) 218 | 219 | ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg} 220 | 221 | if args.gen: 222 | g_train = Generator(size, 512, 8) 223 | g_train_state = g_train.state_dict() 224 | g_train_state = fill_statedict(g_train_state, generator.vars, size) 225 | ckpt['g'] = g_train_state 226 | 227 | if args.disc: 228 | disc = Discriminator(size) 229 | d_state = disc.state_dict() 230 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 231 | ckpt['d'] = d_state 232 | 233 | name = os.path.splitext(os.path.basename(args.path))[0] 234 | torch.save(ckpt, name + '.pt') 235 | 236 | batch_size = {256: 16, 512: 9, 1024: 4} 237 | n_sample = batch_size.get(size, 25) 238 | 239 | g = g.to(device) 240 | 241 | z = np.random.RandomState(0).randn(n_sample, 512).astype('float32') 242 | 243 | with torch.no_grad(): 244 | img_pt, _ = g([torch.from_numpy(z).to(device)], truncation=0.5, truncation_latent=latent_avg.to(device)) 245 | 246 | Gs_kwargs = dnnlib.EasyDict() 247 | Gs_kwargs.randomize_noise = False 248 | img_tf = g_ema.run(z, None, **Gs_kwargs) 249 | img_tf = torch.from_numpy(img_tf).to(device) 250 | 251 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(0.0, 1.0) 252 | 253 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 254 | utils.save_image(img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)) 255 | 256 | -------------------------------------------------------------------------------- /libs/models/StyleGAN2/op/upfirdn2d_kernel.cu: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114
| int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | 
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( 229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>( 236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( 243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>( 250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>( 257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>( 264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /libs/DECA/decalib/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Soubhik Sanyal 3 | Copyright (c) 2019, Soubhik Sanyal 4 | All rights reserved. 5 | Loads different resnet models 6 | """ 7 | ''' 8 | file: Resnet.py 9 | date: 2018_05_02 10 | author: zhangxiong(1025679612@qq.com) 11 | mark: copied from pytorch source code 12 | ''' 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch 17 | from torch.nn.parameter import Parameter 18 | import torch.optim as optim 19 | import numpy as np 20 | import math 21 | import torchvision 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layers, num_classes=1000): 25 | self.inplanes = 64 26 | super(ResNet, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | self.layer1 = self._make_layer(block, 64, layers[0]) 33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 36 | self.avgpool = nn.AvgPool2d(7, stride=1) 37 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample)) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | 70 | x = self.layer1(x) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x1 = self.layer4(x) 74 | 75 | x2 = self.avgpool(x1) 76 | x2 = x2.view(x2.size(0), -1) 77 | # x = self.fc(x) 78 | ## x2: [bz, 2048] for shape 79 | ## x1: [bz, 2048, 7, 7] for texture 80 | return x2 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 93 | self.bn3 = nn.BatchNorm2d(planes * 4) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | def conv3x3(in_planes, out_planes, stride=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 123 | padding=1, bias=False) 124 | 125 | class BasicBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BasicBlock, self).__init__() 130 | self.conv1 = conv3x3(inplanes, planes, stride) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.conv2 = conv3x3(planes, planes) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | def copy_parameter_from_resnet(model, resnet_dict): 157 | cur_state_dict = model.state_dict() 158 | # import ipdb; ipdb.set_trace() 159 | for name, param in list(resnet_dict.items())[0:None]: 160 | if name not in cur_state_dict: 161 | # print(name, ' not available in reconstructed resnet') 162 | continue 163 | if isinstance(param, 
Parameter): 164 | param = param.data 165 | try: 166 | cur_state_dict[name].copy_(param) 167 | except: 168 | print(name, ' is inconsistent!') 169 | continue 170 | # print('copy resnet state dict finished!') 171 | # import ipdb; ipdb.set_trace() 172 | 173 | def load_ResNet50Model(): 174 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 175 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = True).state_dict()) 176 | return model 177 | 178 | def load_ResNet101Model(): 179 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 180 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict()) 181 | return model 182 | 183 | def load_ResNet152Model(): 184 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 185 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict()) 186 | return model 187 | 188 | # model.load_state_dict(checkpoint['model_state_dict']) 189 | 190 | 191 | ######## Unet 192 | 193 | class DoubleConv(nn.Module): 194 | """(convolution => [BN] => ReLU) * 2""" 195 | 196 | def __init__(self, in_channels, out_channels): 197 | super().__init__() 198 | self.double_conv = nn.Sequential( 199 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 200 | nn.BatchNorm2d(out_channels), 201 | nn.ReLU(inplace=True), 202 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 203 | nn.BatchNorm2d(out_channels), 204 | nn.ReLU(inplace=True) 205 | ) 206 | 207 | def forward(self, x): 208 | return self.double_conv(x) 209 | 210 | 211 | class Down(nn.Module): 212 | """Downscaling with maxpool then double conv""" 213 | 214 | def __init__(self, in_channels, out_channels): 215 | super().__init__() 216 | self.maxpool_conv = nn.Sequential( 217 | nn.MaxPool2d(2), 218 | DoubleConv(in_channels, out_channels) 219 | ) 220 | 221 | def forward(self, x): 222 | return self.maxpool_conv(x) 223 | 224 | 225 | class Up(nn.Module): 226 | """Upscaling then double conv""" 227 | 228 | def __init__(self, in_channels, out_channels, bilinear=True): 229 | super().__init__() 230 | 231 | # if bilinear, use the normal convolutions to reduce the number of channels 232 | if bilinear: 233 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 234 | else: 235 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 236 | 237 | self.conv = DoubleConv(in_channels, out_channels) 238 | 239 | def forward(self, x1, x2): 240 | x1 = self.up(x1) 241 | # input is CHW 242 | diffY = x2.size()[2] - x1.size()[2] 243 | diffX = x2.size()[3] - x1.size()[3] 244 | 245 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 246 | diffY // 2, diffY - diffY // 2]) 247 | # if you have padding issues, see 248 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 249 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 250 | x = torch.cat([x2, x1], dim=1) 251 | return self.conv(x) 252 | 253 | 254 | class OutConv(nn.Module): 255 | def __init__(self, in_channels, out_channels): 256 | super(OutConv, self).__init__() 257 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 258 | 259 | def forward(self, x): 260 | return self.conv(x) 261 | 262 | class UNet(nn.Module): 263 | def __init__(self, n_channels, n_classes, bilinear=True): 264 | super(UNet, self).__init__() 265 | self.n_channels = n_channels 266 | self.n_classes = n_classes 267 | self.bilinear = bilinear 268 | 269 | self.inc = DoubleConv(n_channels, 
64) 270 | self.down1 = Down(64, 128) 271 | self.down2 = Down(128, 256) 272 | self.down3 = Down(256, 512) 273 | self.down4 = Down(512, 512) 274 | self.up1 = Up(1024, 256, bilinear) 275 | self.up2 = Up(512, 128, bilinear) 276 | self.up3 = Up(256, 64, bilinear) 277 | self.up4 = Up(128, 64, bilinear) 278 | self.outc = OutConv(64, n_classes) 279 | 280 | def forward(self, x): 281 | x1 = self.inc(x) 282 | x2 = self.down1(x1) 283 | x3 = self.down2(x2) 284 | x4 = self.down3(x3) 285 | x5 = self.down4(x4) 286 | x = self.up1(x5, x4) 287 | x = self.up2(x, x3) 288 | x = self.up3(x, x2) 289 | x = self.up4(x, x1) 290 | x = F.normalize(x) 291 | return x -------------------------------------------------------------------------------- /libs/models/inversion/psp_encoders.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 7 | 8 | from libs.models.inversion.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add 9 | from libs.models.StyleGAN2.model import EqualLinear, ScaledLeakyReLU, EqualConv2d 10 | 11 | 12 | class ProgressiveStage(Enum): 13 | WTraining = 0 14 | Delta1Training = 1 15 | Delta2Training = 2 16 | Delta3Training = 3 17 | Delta4Training = 4 18 | Delta5Training = 5 19 | Delta6Training = 6 20 | Delta7Training = 7 21 | Delta8Training = 8 22 | Delta9Training = 9 23 | Delta10Training = 10 24 | Delta11Training = 11 25 | Delta12Training = 12 26 | Delta13Training = 13 27 | Delta14Training = 14 28 | Delta15Training = 15 29 | Delta16Training = 16 30 | Delta17Training = 17 31 | Inference = 18 32 | 33 | 34 | class GradualStyleBlock(Module): 35 | def __init__(self, in_c, out_c, spatial): 36 | super(GradualStyleBlock, self).__init__() 37 | self.out_c = out_c 38 | self.spatial = spatial 39 | num_pools = int(np.log2(spatial)) 40 | modules = [] 41 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 42 | nn.LeakyReLU()] 43 | for i in range(num_pools - 1): 44 | modules += [ 45 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 46 | nn.LeakyReLU() 47 | ] 48 | self.convs = nn.Sequential(*modules) 49 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 50 | 51 | def forward(self, x): 52 | x = self.convs(x) 53 | x = x.view(-1, self.out_c) 54 | x = self.linear(x) 55 | return x 56 | 57 | 58 | class GradualStyleEncoder(Module): 59 | def __init__(self, num_layers, mode='ir', opts=None): 60 | super(GradualStyleEncoder, self).__init__() 61 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 62 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 63 | blocks = get_blocks(num_layers) 64 | if mode == 'ir': 65 | unit_module = bottleneck_IR 66 | elif mode == 'ir_se': 67 | unit_module = bottleneck_IR_SE 68 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 69 | BatchNorm2d(64), 70 | PReLU(64)) 71 | modules = [] 72 | for block in blocks: 73 | for bottleneck in block: 74 | modules.append(unit_module(bottleneck.in_channel, 75 | bottleneck.depth, 76 | bottleneck.stride)) 77 | self.body = Sequential(*modules) 78 | 79 | self.styles = nn.ModuleList() 80 | log_size = int(math.log(opts.output_size, 2)) 81 | self.style_count = 2 * log_size - 2 82 | self.coarse_ind = 3 83 | self.middle_ind = 7 84 | for i in range(self.style_count): 85 | if i < self.coarse_ind: 86 | style = GradualStyleBlock(512, 512, 16) 87 | elif i < self.middle_ind: 
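# NOTE: the style heads are split by FPN level -- indices below coarse_ind read the deepest 16x16
# feature map (c3), indices in [coarse_ind, middle_ind) read the 32x32 map (p2), and the remaining
# fine styles read the 64x64 map (p1); the spatial argument of each GradualStyleBlock matches the
# resolution of the map it receives in forward().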
88 | style = GradualStyleBlock(512, 512, 32) 89 | else: 90 | style = GradualStyleBlock(512, 512, 64) 91 | self.styles.append(style) 92 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 93 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 94 | 95 | def forward(self, x): 96 | x = self.input_layer(x) 97 | 98 | latents = [] 99 | modulelist = list(self.body._modules.values()) 100 | for i, l in enumerate(modulelist): 101 | x = l(x) 102 | if i == 6: 103 | c1 = x 104 | elif i == 20: 105 | c2 = x 106 | elif i == 23: 107 | c3 = x 108 | 109 | for j in range(self.coarse_ind): 110 | latents.append(self.styles[j](c3)) 111 | 112 | p2 = _upsample_add(c3, self.latlayer1(c2)) 113 | for j in range(self.coarse_ind, self.middle_ind): 114 | latents.append(self.styles[j](p2)) 115 | 116 | p1 = _upsample_add(p2, self.latlayer2(c1)) 117 | for j in range(self.middle_ind, self.style_count): 118 | latents.append(self.styles[j](p1)) 119 | 120 | out = torch.stack(latents, dim=1) 121 | return out 122 | 123 | 124 | class Encoder4Editing(Module): 125 | def __init__(self, num_layers, mode='ir', opts=None): 126 | super(Encoder4Editing, self).__init__() 127 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 128 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 129 | blocks = get_blocks(num_layers) 130 | if mode == 'ir': 131 | unit_module = bottleneck_IR 132 | elif mode == 'ir_se': 133 | unit_module = bottleneck_IR_SE 134 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 135 | BatchNorm2d(64), 136 | PReLU(64)) 137 | modules = [] 138 | for block in blocks: 139 | for bottleneck in block: 140 | modules.append(unit_module(bottleneck.in_channel, 141 | bottleneck.depth, 142 | bottleneck.stride)) 143 | self.body = Sequential(*modules) 144 | 145 | self.styles = nn.ModuleList() 146 | log_size = int(math.log(opts.output_size, 2)) 147 | self.style_count = 2 * log_size - 2 148 | self.coarse_ind = 3 149 | self.middle_ind = 7 150 | 151 | for i in range(self.style_count): 152 | if i < self.coarse_ind: 153 | style = GradualStyleBlock(512, 512, 16) 154 | elif i < self.middle_ind: 155 | style = GradualStyleBlock(512, 512, 32) 156 | else: 157 | style = GradualStyleBlock(512, 512, 64) 158 | self.styles.append(style) 159 | 160 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 161 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 162 | 163 | self.progressive_stage = ProgressiveStage.Inference 164 | 165 | def get_deltas_starting_dimensions(self): 166 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 167 | return list(range(self.style_count)) # Each dimension has a delta applied to it 168 | 169 | def set_progressive_stage(self, new_stage: ProgressiveStage): 170 | self.progressive_stage = new_stage 171 | print('Changed progressive stage to: ', new_stage) 172 | 173 | def forward(self, x): 174 | x = self.input_layer(x) 175 | 176 | modulelist = list(self.body._modules.values()) 177 | for i, l in enumerate(modulelist): 178 | x = l(x) 179 | if i == 6: 180 | c1 = x 181 | elif i == 20: 182 | c2 = x 183 | elif i == 23: 184 | c3 = x 185 | 186 | # Infer main W and duplicate it 187 | w0 = self.styles[0](c3) 188 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 189 | stage = self.progressive_stage.value 190 | features = c3 191 | for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas 192 | if i == self.coarse_ind: 193 | p2 = 
_upsample_add(c3, self.latlayer1(c2)) # FPN's middle features 194 | features = p2 195 | elif i == self.middle_ind: 196 | p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features 197 | features = p1 198 | delta_i = self.styles[i](features) 199 | w[:, i] += delta_i 200 | 201 | return w 202 | 203 | 204 | class BackboneEncoderUsingLastLayerIntoW(Module): 205 | def __init__(self, num_layers, mode='ir', opts=None): 206 | super(BackboneEncoderUsingLastLayerIntoW, self).__init__() 207 | print('Using BackboneEncoderUsingLastLayerIntoW') 208 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 209 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 210 | blocks = get_blocks(num_layers) 211 | if mode == 'ir': 212 | unit_module = bottleneck_IR 213 | elif mode == 'ir_se': 214 | unit_module = bottleneck_IR_SE 215 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 216 | BatchNorm2d(64), 217 | PReLU(64)) 218 | self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) 219 | self.linear = EqualLinear(512, 512, lr_mul=1) 220 | modules = [] 221 | for block in blocks: 222 | for bottleneck in block: 223 | modules.append(unit_module(bottleneck.in_channel, 224 | bottleneck.depth, 225 | bottleneck.stride)) 226 | self.body = Sequential(*modules) 227 | log_size = int(math.log(opts.output_size, 2)) 228 | self.style_count = 2 * log_size - 2 229 | 230 | def forward(self, x): 231 | x = self.input_layer(x) 232 | x = self.body(x) 233 | x = self.output_pool(x) 234 | x = x.view(-1, 512) 235 | x = self.linear(x) 236 | return x.repeat(self.style_count, 1, 1).permute(1, 0, 2) 237 | 238 | # Consultation encoder 239 | class ResidualEncoder(Module): 240 | def __init__(self, opts=None): 241 | super(ResidualEncoder, self).__init__() 242 | self.conv_layer1 = Sequential(Conv2d(3, 32, (3, 3), 1, 1, bias=False), 243 | BatchNorm2d(32), 244 | PReLU(32)) 245 | 246 | self.conv_layer2 = Sequential(*[bottleneck_IR(32,48,2), bottleneck_IR(48,48,1), bottleneck_IR(48,48,1)]) 247 | 248 | self.conv_layer3 = Sequential(*[bottleneck_IR(48,64,2), bottleneck_IR(64,64,1), bottleneck_IR(64,64,1)]) 249 | 250 | self.condition_scale3 = nn.Sequential( 251 | EqualConv2d(64, 512, 3, stride=1, padding=1, bias=True ), 252 | ScaledLeakyReLU(0.2), 253 | EqualConv2d(512, 512, 3, stride=1, padding=1, bias=True )) 254 | 255 | self.condition_shift3 = nn.Sequential( 256 | EqualConv2d(64, 512, 3, stride=1, padding=1, bias=True ), 257 | ScaledLeakyReLU(0.2), 258 | EqualConv2d(512, 512, 3, stride=1, padding=1, bias=True )) 259 | 260 | 261 | 262 | def get_deltas_starting_dimensions(self): 263 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 264 | return list(range(self.style_count)) # Each dimension has a delta applied to it 265 | 266 | 267 | 268 | def forward(self, x): 269 | conditions = [] 270 | feat1 = self.conv_layer1(x) 271 | feat2 = self.conv_layer2(feat1) 272 | feat3 = self.conv_layer3(feat2) 273 | 274 | scale = self.condition_scale3(feat3) 275 | scale = torch.nn.functional.interpolate(scale, size=(64,64) , mode='bilinear') 276 | conditions.append(scale.clone()) 277 | shift = self.condition_shift3(feat3) 278 | shift = torch.nn.functional.interpolate(shift, size=(64,64) , mode='bilinear') 279 | conditions.append(shift.clone()) 280 | return conditions 281 | 282 | 283 | # ADA 284 | class ResidualAligner(Module): 285 | def __init__(self, opts=None): 286 | super(ResidualAligner, self).__init__() 287 | self.conv_layer1 = Sequential(Conv2d(6, 16, (3, 3), 1, 1, bias=False), 
288 | BatchNorm2d(16), 289 | PReLU(16)) 290 | 291 | self.conv_layer2 = Sequential(*[bottleneck_IR(16,32,2), bottleneck_IR(32,32,1), bottleneck_IR(32,32,1)]) 292 | self.conv_layer3 = Sequential(*[bottleneck_IR(32,48,2), bottleneck_IR(48,48,1), bottleneck_IR(48,48,1)]) 293 | self.conv_layer4 = Sequential(*[bottleneck_IR(48,64,2), bottleneck_IR(64,64,1), bottleneck_IR(64,64,1)]) 294 | 295 | self.dconv_layer1 = Sequential(*[bottleneck_IR(112,64,1), bottleneck_IR(64,32,1), bottleneck_IR(32,32,1)]) 296 | self.dconv_layer2 = Sequential(*[bottleneck_IR(64,32,1), bottleneck_IR(32,16,1), bottleneck_IR(16,16,1)]) 297 | self.dconv_layer3 = Sequential(*[bottleneck_IR(32,16,1), bottleneck_IR(16,3,1), bottleneck_IR(3,3,1)]) 298 | 299 | def forward(self, x): 300 | feat1 = self.conv_layer1(x) 301 | feat2 = self.conv_layer2(feat1) 302 | feat3 = self.conv_layer3(feat2) 303 | feat4 = self.conv_layer4(feat3) 304 | 305 | feat4 = torch.nn.functional.interpolate(feat4, size=(64,64) , mode='bilinear') 306 | dfea1 = self.dconv_layer1(torch.cat((feat4, feat3),1)) 307 | dfea1 = torch.nn.functional.interpolate(dfea1, size=(128,128) , mode='bilinear') 308 | dfea2 = self.dconv_layer2(torch.cat( (dfea1, feat2),1)) 309 | dfea2 = torch.nn.functional.interpolate(dfea2, size=(256,256) , mode='bilinear') 310 | dfea3 = self.dconv_layer3(torch.cat( (dfea2, feat1),1)) 311 | 312 | res_aligned = dfea3 313 | 314 | return res_aligned 315 | 316 | -------------------------------------------------------------------------------- /libs/DECA/decalib/models/FLAME.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import pickle 20 | import torch.nn.functional as F 21 | 22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler 23 | 24 | def to_tensor(array, dtype=torch.float32): 25 | if 'torch.tensor' not in str(type(array)): 26 | return torch.tensor(array, dtype=dtype) 27 | def to_np(array, dtype=np.float32): 28 | if 'scipy.sparse' in str(type(array)): 29 | array = array.todense() 30 | return np.array(array, dtype=dtype) 31 | 32 | class Struct(object): 33 | def __init__(self, **kwargs): 34 | for key, val in kwargs.items(): 35 | setattr(self, key, val) 36 | 37 | class FLAME(nn.Module): 38 | """ 39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py 40 | Given flame parameters this class generates a differentiable FLAME function 41 | which outputs the a mesh and 2D/3D facial landmarks 42 | """ 43 | def __init__(self, config): 44 | super(FLAME, self).__init__() 45 | print("creating the FLAME Decoder") 46 | with open(config.flame_model_path, 'rb') as f: 47 | ss = pickle.load(f, encoding='latin1') 48 | flame_model = Struct(**ss) 49 | 50 | self.dtype = torch.float32 51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long)) 52 | # The vertices of the template model 53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype)) 54 | # The shape components and expression 55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype) 56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2) 57 | self.register_buffer('shapedirs', shapedirs) 58 | # The pose components 59 | num_pose_basis = flame_model.posedirs.shape[-1] 60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T 61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) 62 | # 63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)) 64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1 65 | self.register_buffer('parents', parents) 66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype)) 67 | 68 | # Fixing Eyeball and neck rotation 69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False) 70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, 71 | requires_grad=False)) 72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) 73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, 74 | requires_grad=False)) 75 | 76 | # Static and Dynamic Landmark embeddings for FLAME 77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1') 78 | lmk_embeddings = lmk_embeddings[()] 79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long()) 80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype)) 81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long()) 82 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype)) 83 | 
self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long()) 84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype)) 85 | 86 | neck_kin_chain = []; NECK_IDX=1 87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long) 88 | while curr_idx != -1: 89 | neck_kin_chain.append(curr_idx) 90 | curr_idx = self.parents[curr_idx] 91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain)) 92 | 93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx, 94 | dynamic_lmk_b_coords, 95 | neck_kin_chain, dtype=torch.float32): 96 | """ 97 | Selects the face contour depending on the reletive position of the head 98 | Input: 99 | vertices: N X num_of_vertices X 3 100 | pose: N X full pose 101 | dynamic_lmk_faces_idx: The list of contour face indexes 102 | dynamic_lmk_b_coords: The list of contour barycentric weights 103 | neck_kin_chain: The tree to consider for the relative rotation 104 | dtype: Data type 105 | return: 106 | The contour face indexes and the corresponding barycentric weights 107 | """ 108 | 109 | batch_size = pose.shape[0] 110 | 111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 112 | neck_kin_chain) 113 | rot_mats = batch_rodrigues( 114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 115 | 116 | rel_rot_mat = torch.eye(3, device=pose.device, 117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1) 118 | for idx in range(len(neck_kin_chain)): 119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 120 | 121 | y_rot_angle = torch.round( 122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 123 | max=39)).to(dtype=torch.long) 124 | 125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 128 | y_rot_angle = (neg_mask * neg_vals + 129 | (1 - neg_mask) * y_rot_angle) 130 | 131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 132 | 0, y_rot_angle) 133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 134 | 0, y_rot_angle) 135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 136 | 137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords): 138 | """ 139 | Calculates landmarks by barycentric interpolation 140 | Input: 141 | vertices: torch.tensor NxVx3, dtype = torch.float32 142 | The tensor of input vertices 143 | faces: torch.tensor (N*F)x3, dtype = torch.long 144 | The faces of the mesh 145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long 146 | The tensor with the indices of the faces used to calculate the 147 | landmarks. 
148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32 149 | The tensor of barycentric coordinates that are used to interpolate 150 | the landmarks 151 | 152 | Returns: 153 | landmarks: torch.tensor NxLx3, dtype = torch.float32 154 | The coordinates of the landmarks for each mesh in the batch 155 | """ 156 | # Extract the indices of the vertices for each face 157 | # NxLx3 158 | batch_size, num_verts = vertices.shape[:2] 159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1) 161 | 162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to( 163 | device=vertices.device) * num_verts 164 | 165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces] 166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 167 | return landmarks 168 | 169 | def seletec_3d68(self, vertices): 170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1), 172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1)) 173 | return landmarks3d 174 | 175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None): 176 | """ 177 | Input: 178 | shape_params: N X number of shape parameters 179 | expression_params: N X number of expression parameters 180 | pose_params: N X number of pose parameters (6) 181 | return: 182 | vertices: N X V X 3 183 | landmarks: N X number of landmarks X 3 184 | """ 185 | batch_size = shape_params.shape[0] 186 | if eye_pose_params is None: 187 | eye_pose_params = self.eye_pose.expand(batch_size, -1) 188 | betas = torch.cat([shape_params, expression_params], dim=1) 189 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1) 190 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 191 | 192 | vertices, _ = lbs(betas, full_pose, template_vertices, 193 | self.shapedirs, self.posedirs, 194 | self.J_regressor, self.parents, 195 | self.lbs_weights, dtype=self.dtype) 196 | 197 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 198 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 199 | 200 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords( 201 | full_pose, self.dynamic_lmk_faces_idx, 202 | self.dynamic_lmk_bary_coords, 203 | self.neck_kin_chain, dtype=self.dtype) 204 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1) 205 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) 206 | 207 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor, 208 | lmk_faces_idx, 209 | lmk_bary_coords) 210 | bz = vertices.shape[0] 211 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 212 | self.full_lmk_faces_idx.repeat(bz, 1), 213 | self.full_lmk_bary_coords.repeat(bz, 1, 1)) 214 | return vertices, landmarks2d, landmarks3d 215 | 216 | class FLAMETex(nn.Module): 217 | """ 218 | FLAME texture: 219 | https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64 220 | FLAME texture converted from BFM: 221 | https://github.com/TimoBolkart/BFM_to_FLAME 222 | """ 223 | def __init__(self, config): 224 | super(FLAMETex, self).__init__() 225 | if config.tex_type == 'BFM': 226 | mu_key = 'MU' 227 | pc_key = 'PC' 228 | n_pc = 199 229 | tex_path = config.tex_path 230 | tex_space = np.load(tex_path)
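# The loaded archive is a linear (PCA) texture space: the mean texture plus principal components.
# forward() below reconstructs a texture as mean + components weighted by texcode, and only the
# first n_tex components are kept when the buffers are registered.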
231 | texture_mean = tex_space[mu_key].reshape(1, -1) 232 | texture_basis = tex_space[pc_key].reshape(-1, n_pc) 233 | 234 | elif config.tex_type == 'FLAME': 235 | mu_key = 'mean' 236 | pc_key = 'tex_dir' 237 | n_pc = 200 238 | tex_path = config.flame_tex_path 239 | tex_space = np.load(tex_path) 240 | texture_mean = tex_space[mu_key].reshape(1, -1)/255. 241 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255. 242 | else: 243 | print('texture type ', config.tex_type, 'not exist!') 244 | raise NotImplementedError 245 | 246 | n_tex = config.n_tex 247 | num_components = texture_basis.shape[1] 248 | texture_mean = torch.from_numpy(texture_mean).float()[None,...] 249 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...] 250 | self.register_buffer('texture_mean', texture_mean) 251 | self.register_buffer('texture_basis', texture_basis) 252 | 253 | def forward(self, texcode): 254 | ''' 255 | texcode: [batchsize, n_tex] 256 | texture: [bz, 3, 256, 256], range: 0-1 257 | ''' 258 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1) 259 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2) 260 | texture = F.interpolate(texture, [256, 256]) 261 | texture = texture[:,[2,1,0], :,:] 262 | return texture -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import random 4 | import sys 5 | import argparse 6 | from argparse import Namespace 7 | import torch 8 | from torch import nn 9 | import numpy as np 10 | import warnings 11 | from tqdm import tqdm 12 | warnings.filterwarnings("ignore") 13 | sys.dont_write_bytecode = True 14 | 15 | seed = 0 16 | random.seed(seed) 17 | import face_alignment 18 | 19 | from libs.models.StyleGAN2.model import Generator as StyleGAN2Generator 20 | from libs.models.mask_predictor import MaskPredictor 21 | from libs.utilities.utils import make_noise, generate_image, generate_new_stylespace, save_image, save_grid, get_files_frompath 22 | from libs.utilities.stylespace_utils import decoder 23 | from libs.configs.config_models import stylegan2_ffhq_1024 24 | from libs.utilities.utils_inference import preprocess_image, invert_image 25 | from libs.utilities.image_utils import image_to_tensor 26 | from libs.models.inversion.psp import pSp 27 | 28 | class Inference_demo(): 29 | 30 | def __init__(self, args): 31 | self.args = args 32 | 33 | self.device = 'cuda' 34 | self.output_path = args['output_path'] 35 | arguments_json = os.path.join(self.output_path, 'arguments.json') 36 | self.masknet_path = args['masknet_path'] 37 | self.image_resolution = args['image_resolution'] 38 | self.dataset = args['dataset'] 39 | 40 | self.source_path = args['source_path'] 41 | self.target_path = args['target_path'] 42 | self.num_pairs = args['num_pairs'] 43 | 44 | self.save_grid = args['save_grid'] 45 | self.save_image = args['save_image'] 46 | self.resize_image = args['resize_image'] 47 | 48 | if not os.path.exists(self.output_path): 49 | os.makedirs(self.output_path, exist_ok=True) 50 | 51 | def load_models(self, inversion): 52 | 53 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 54 | 55 | if self.dataset == 'ffhq' and self.image_resolution == 1024: 56 | self.generator_path = stylegan2_ffhq_1024['gan_weights'] 57 | self.channel_multiplier = stylegan2_ffhq_1024['channel_multiplier'] 58 | self.split_sections = stylegan2_ffhq_1024['split_sections'] 59 | 
self.stylespace_dim = stylegan2_ffhq_1024['stylespace_dim'] 60 | else: 61 | print('Incorect dataset type {} and image resolution {}'.format(self.dataset, self.image_resolution)) 62 | 63 | if os.path.exists(self.generator_path): 64 | print('----- Load generator from {} -----'.format(self.generator_path)) 65 | 66 | self.G = StyleGAN2Generator(self.image_resolution, 512, 8, channel_multiplier = self.channel_multiplier) 67 | self.G.load_state_dict(torch.load(self.generator_path)['g_ema'], strict = True) 68 | self.G.cuda().eval() 69 | # use truncation 70 | self.truncation = 0.7 71 | self.trunc =self.G.mean_latent(4096).detach().clone() 72 | 73 | else: 74 | print('Please download the pretrained model for StyleGAN2 generator and save it into ./pretrained_models path') 75 | exit() 76 | 77 | if os.path.exists(self.masknet_path): 78 | print('----- Load mask network from {} -----'.format(self.masknet_path)) 79 | ckpt = torch.load(self.masknet_path, map_location=torch.device('cpu')) 80 | self.num_layers_control = ckpt['num_layers_control'] 81 | self.mask_net = nn.ModuleDict({}) 82 | for layer_idx in range(self.num_layers_control): 83 | network_name_str = 'network_{:02d}'.format(layer_idx) 84 | 85 | # Net info 86 | stylespace_dim_layer = self.split_sections[layer_idx] 87 | input_dim = stylespace_dim_layer 88 | output_dim = stylespace_dim_layer 89 | inner_dim = stylespace_dim_layer 90 | 91 | network_module = MaskPredictor(input_dim, output_dim, inner_dim = inner_dim) 92 | self.mask_net.update({network_name_str: network_module}) 93 | self.mask_net.load_state_dict(ckpt['mask_net']) 94 | self.mask_net.cuda().eval() 95 | else: 96 | print('Please download the pretrained model for Mask network and save it into ./pretrained_models path') 97 | exit() 98 | 99 | if inversion: 100 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cuda') 101 | ### Load inversion model only when the input is image. ### 102 | self.encoder_path = stylegan2_ffhq_1024['e4e_inversion_model'] 103 | print('----- Load e4e encoder from {} -----'.format(self.encoder_path)) 104 | ckpt = torch.load(self.encoder_path, map_location='cpu') 105 | opts = ckpt['opts'] 106 | opts['output_size'] = self.image_resolution 107 | opts['checkpoint_path'] = self.encoder_path 108 | opts['device'] = 'cuda' 109 | opts['channel_multiplier'] = self.channel_multiplier 110 | opts['dataset'] = self.dataset 111 | opts = Namespace(**opts) 112 | self.encoder = pSp(opts) 113 | self.encoder.cuda().eval() 114 | 115 | def load_samples(self, filepath): 116 | inversion = False 117 | if filepath is None: 118 | # Generate random latent code 119 | files_grabbed = [] 120 | for i in range(self.num_pairs): 121 | files_grabbed.append(make_noise(1, 512)) 122 | else: 123 | if os.path.isdir(filepath): 124 | ## Check if files inside directory are images. 
Else check if latent codes 125 | files_grabbed = get_files_frompath(filepath, ['*.png', '*.jpg']) 126 | if len(files_grabbed) == 0: 127 | files_grabbed = get_files_frompath(filepath, ['*.npy']) 128 | if len(files_grabbed) == 0: 129 | print('Please specify correct path: folder with images (.png, .jpg) or latent codes (.npy)') 130 | exit() 131 | z_codes = [] 132 | for file_ in files_grabbed: 133 | z_codes.append(torch.from_numpy(np.load(file_)).cuda()) 134 | z_codes = torch.cat(z_codes).unsqueeze(0) 135 | files_grabbed = z_codes 136 | else: 137 | inversion = True # invert real images 138 | 139 | elif os.path.isfile(filepath): 140 | 141 | head, tail = os.path.split(filepath) 142 | ext = tail.split('.')[-1] 143 | # Check if file is image 144 | if ext == 'png' or ext == 'jpg': 145 | files_grabbed = [filepath] 146 | inversion = True 147 | elif ext == 'npy': 148 | z_codes = torch.from_numpy(np.load(filepath)).unsqueeze(1) 149 | files_grabbed = z_codes 150 | else: 151 | print('Wrong path. Expected file image (.png, .jpg) or latent code (.npy)') 152 | exit() 153 | else: 154 | print('Wrong path. Expected file image (.png, .jpg) or latent code (.npy)') 155 | exit() 156 | 157 | return files_grabbed, inversion 158 | 159 | def reenact_pair(self, source_code, target_code): 160 | 161 | with torch.no_grad(): 162 | # Get source style space 163 | source_img, style_source, w_source, noise_source = generate_image(self.G, source_code, self.truncation, self.trunc, self.image_resolution, self.split_sections, 164 | input_is_latent = self.input_is_latent, return_latents= True, resize_image = self.resize_image) 165 | 166 | # Get target style space 167 | target_img, style_target, w_target, noise_target = generate_image(self.G, target_code, self.truncation, self.trunc, self.image_resolution, self.split_sections, 168 | input_is_latent = self.input_is_latent, return_latents= True, resize_image = self.resize_image) 169 | 170 | # Get reenacted image 171 | masks_per_layer = [] 172 | for layer_idx in range(self.num_layers_control): 173 | network_name_str = 'network_{:02d}'.format(layer_idx) 174 | style_source_idx = style_source[layer_idx] 175 | style_target_idx = style_target[layer_idx] 176 | styles = style_source_idx - style_target_idx 177 | mask_idx = self.mask_net[network_name_str](styles) 178 | masks_per_layer.append(mask_idx) 179 | 180 | mask = torch.cat(masks_per_layer, dim=1) 181 | style_source = torch.cat(style_source, dim=1) 182 | style_target = torch.cat(style_target, dim=1) 183 | 184 | new_style_space = generate_new_stylespace(style_source, style_target, mask, num_layers_control = self.num_layers_control) 185 | new_style_space = list(torch.split(tensor=new_style_space, split_size_or_sections=self.split_sections, dim=1)) 186 | reenacted_img = decoder(self.G, new_style_space, w_source, noise_source, resize_image = self.resize_image) 187 | 188 | return source_img, target_img, reenacted_img 189 | 190 | def check_paths(self): 191 | assert type(self.target_path) == type(self.source_path), \ 192 | "Source path and target path should have the same type, None, files (.png, .jpg or .npy) or directories with files of type .png, .jpg or .npy" 193 | 194 | if self.source_path is not None and self.target_path is not None: 195 | if os.path.isdir(self.source_path): 196 | assert os.path.isdir(self.target_path), \ 197 | "Source path and target path should have the same type, None, files (.png, .jpg or .npy) or directories with files of type .png, .jpg or .npy" 198 | 199 | if os.path.isfile(self.source_path): 200 | assert 
os.path.isfile(self.target_path), \ 201 | "Source path and target path should have the same type: None, files (.png, .jpg or .npy), or directories with files of type .png, .jpg or .npy" 202 | 203 | def run(self): 204 | 205 | self.check_paths() 206 | source_samples, inversion = self.load_samples(self.source_path) 207 | target_samples, inversion = self.load_samples(self.target_path) 208 | 209 | assert len(source_samples) == len(target_samples), "Number of source samples should be the same as the number of target samples" 210 | 211 | 212 | self.load_models(inversion) 213 | self.num_pairs = len(source_samples) 214 | 215 | print('Reenact {} pairs'.format(self.num_pairs)) 216 | 217 | for i in tqdm(range(self.num_pairs)): 218 | if inversion: # Real image 219 | # Preprocess and invert the real images into the W+ latent space using the Encoder4Editing (e4e) method 220 | cropped_image = preprocess_image(source_samples[i], self.fa, save_filename = None) 221 | source_img = image_to_tensor(cropped_image).unsqueeze(0).cuda() 222 | inv_image, source_code = invert_image(source_img, self.encoder, self.G, self.truncation, self.trunc) 223 | 224 | cropped_image = preprocess_image(target_samples[i], self.fa) 225 | target_img = image_to_tensor(cropped_image).unsqueeze(0).cuda() 226 | inv_image, target_code = invert_image(target_img, self.encoder, self.G, self.truncation, self.trunc) 227 | self.input_is_latent = True 228 | else: # synthetic latent code 229 | if self.source_path is not None: 230 | source_code = source_samples[i].cuda() 231 | target_code = target_samples[i].cuda() 232 | if source_code.ndim == 2: 233 | self.input_is_latent = False # Z space 1 X 512 234 | elif source_code.ndim == 3: 235 | self.truncation = 1.0 236 | self.trunc = None 237 | self.input_is_latent = True # W space 1 X 18 X 512 238 | else: 239 | source_code = source_samples[i].cuda() 240 | target_code = target_samples[i].cuda() 241 | self.input_is_latent = False # Z space 242 | 243 | source_img, target_img, reenacted_img = self.reenact_pair(source_code, target_code) 244 | 245 | if self.save_grid: 246 | save_grid(source_img, target_img, reenacted_img, os.path.join(self.output_path, 'grid_{:04d}.png').format(i)) 247 | if self.save_image: 248 | save_image(reenacted_img, os.path.join(self.output_path, '{:04d}.png').format(i)) 249 | 250 | def main(): 251 | """ 252 | Inference script. 253 | 254 | Options: 255 | ######### General ########### 256 | --output_path : path to save output images 257 | 258 | --source_path : an image file (.png, .jpg), a latent code (.npy), a directory with images or latent codes, or None. 259 | If source_path is None, a random latent code is generated. 260 | --target_path : an image file (.png, .jpg), a latent code (.npy), a directory with images or latent codes, or None. 261 | If target_path is None, a random latent code is generated.
262 | 263 | --masknet_path : path to the pretrained mask network model 264 | --dataset : dataset (ffhq) 265 | --image_resolution : image resolution (1024) 266 | --num_pairs : number of pairs to reenact 267 | 268 | ########## Visualization ########## 269 | --save_grid : Generate a figure with the source, target and reenacted images 270 | --save_image : Save only the reenacted image 271 | --resize_image : Resize image from 1024 to 256 272 | 273 | 274 | python run_inference.py --output_path ./results --save_grid 275 | 276 | """ 277 | parser = argparse.ArgumentParser(description="inference script") 278 | 279 | ######### General ######### 280 | parser.add_argument('--output_path', type=str, required = True, help="path to save output images") 281 | parser.add_argument('--source_path', type=str, default = None, help='path to source samples (latent codes or images)') 282 | parser.add_argument('--target_path', type=str, default = None, help='path to target samples (latent codes or images)') 283 | 284 | parser.add_argument('--masknet_path', type=str, default = './pretrained_models/mask_network_1024.pt', help="path to the pretrained mask network model") 285 | parser.add_argument('--dataset', type=str, default = 'ffhq', help="dataset") 286 | parser.add_argument('--image_resolution', type=int, default = 1024, help="image resolution") 287 | 288 | parser.add_argument('--num_pairs', type=int, default = 4, help="number of random pairs to reenact") 289 | 290 | parser.add_argument('--save_grid', dest='save_grid', action='store_true', help="Generate a figure with the source, target and reenacted images") 291 | parser.set_defaults(save_grid=False) 292 | parser.add_argument('--save_image', dest='save_image', action='store_true', help="Save only the reenacted image") 293 | parser.set_defaults(save_image=False) 294 | parser.add_argument('--resize_image', dest='resize_image', action='store_true', help="Resize image from 1024 to 256") 295 | parser.set_defaults(resize_image=False) 296 | 297 | 298 | # Parse given arguments 299 | args = parser.parse_args() 300 | args = vars(args) # convert to dictionary 301 | 302 | inf = Inference_demo(args) 303 | inf.run() 304 | 305 | 306 | 307 | 308 | if __name__ == '__main__': 309 | main() 310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- /libs/DECA/decalib/utils/rotation_converter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved.
12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | 18 | ''' Rotation Converter 19 | Representations: 20 | euler angle(3), angle axis(3), rotation matrix(3x3), quaternion(4), continous repre 21 | Ref: 22 | https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html# 23 | smplx/lbs 24 | ''' 25 | 26 | pi = torch.Tensor([3.14159265358979323846]) 27 | def rad2deg(tensor): 28 | """Function that converts angles from radians to degrees. 29 | 30 | See :class:`~torchgeometry.RadToDeg` for details. 31 | 32 | Args: 33 | tensor (Tensor): Tensor of arbitrary shape. 34 | 35 | Returns: 36 | Tensor: Tensor with same shape as input. 37 | 38 | Example: 39 | >>> input = tgm.pi * torch.rand(1, 3, 3) 40 | >>> output = tgm.rad2deg(input) 41 | """ 42 | if not torch.is_tensor(tensor): 43 | raise TypeError("Input type is not a torch.Tensor. Got {}" 44 | .format(type(tensor))) 45 | 46 | return 180. * tensor / pi.to(tensor.device).type(tensor.dtype) 47 | 48 | def deg2rad(tensor): 49 | """Function that converts angles from degrees to radians. 50 | 51 | See :class:`~torchgeometry.DegToRad` for details. 52 | 53 | Args: 54 | tensor (Tensor): Tensor of arbitrary shape. 55 | 56 | Returns: 57 | Tensor: Tensor with same shape as input. 58 | 59 | Examples:: 60 | 61 | >>> input = 360. * torch.rand(1, 3, 3) 62 | >>> output = tgm.deg2rad(input) 63 | """ 64 | if not torch.is_tensor(tensor): 65 | raise TypeError("Input type is not a torch.Tensor. Got {}" 66 | .format(type(tensor))) 67 | 68 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180. 69 | 70 | ######### to quaternion 71 | def euler_to_quaternion(r): 72 | x = r[..., 0] 73 | y = r[..., 1] 74 | z = r[..., 2] 75 | 76 | z = z/2.0 77 | y = y/2.0 78 | x = x/2.0 79 | cz = torch.cos(z) 80 | sz = torch.sin(z) 81 | cy = torch.cos(y) 82 | sy = torch.sin(y) 83 | cx = torch.cos(x) 84 | sx = torch.sin(x) 85 | quaternion = torch.zeros_like(r.repeat(1,2))[..., :4].to(r.device) 86 | quaternion[..., 0] += cx*cy*cz - sx*sy*sz 87 | quaternion[..., 1] += cx*sy*sz + cy*cz*sx 88 | quaternion[..., 2] += cx*cz*sy - sx*cy*sz 89 | quaternion[..., 3] += cx*cy*sz + sx*cz*sy 90 | return quaternion 91 | 92 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): 93 | """Convert 3x4 rotation matrix to 4d quaternion vector 94 | 95 | This algorithm is based on algorithm described in 96 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 97 | 98 | Args: 99 | rotation_matrix (Tensor): the rotation matrix to convert. 100 | 101 | Return: 102 | Tensor: the rotation in quaternion 103 | 104 | Shape: 105 | - Input: :math:`(N, 3, 4)` 106 | - Output: :math:`(N, 4)` 107 | 108 | Example: 109 | >>> input = torch.rand(4, 3, 4) # Nx3x4 110 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 111 | """ 112 | if not torch.is_tensor(rotation_matrix): 113 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 114 | type(rotation_matrix))) 115 | 116 | if len(rotation_matrix.shape) > 3: 117 | raise ValueError( 118 | "Input size must be a three dimensional tensor. Got {}".format( 119 | rotation_matrix.shape)) 120 | # if not rotation_matrix.shape[-2:] == (3, 4): 121 | # raise ValueError( 122 | # "Input size must be a N x 3 x 4 tensor. 
Got {}".format( 123 | # rotation_matrix.shape)) 124 | 125 | rmat_t = torch.transpose(rotation_matrix, 1, 2) 126 | 127 | mask_d2 = rmat_t[:, 2, 2] < eps 128 | 129 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] 130 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] 131 | 132 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 133 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 134 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 135 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) 136 | t0_rep = t0.repeat(4, 1).t() 137 | 138 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 139 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 140 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 141 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) 142 | t1_rep = t1.repeat(4, 1).t() 143 | 144 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 145 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], 146 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2], 147 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) 148 | t2_rep = t2.repeat(4, 1).t() 149 | 150 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 151 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 152 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 153 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) 154 | t3_rep = t3.repeat(4, 1).t() 155 | 156 | mask_c0 = mask_d2 * mask_d0_d1.float() 157 | mask_c1 = mask_d2 * (1 - mask_d0_d1.float()) 158 | mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1 159 | mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float()) 160 | mask_c0 = mask_c0.view(-1, 1).type_as(q0) 161 | mask_c1 = mask_c1.view(-1, 1).type_as(q1) 162 | mask_c2 = mask_c2.view(-1, 1).type_as(q2) 163 | mask_c3 = mask_c3.view(-1, 1).type_as(q3) 164 | 165 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 166 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa 167 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa 168 | q *= 0.5 169 | return q 170 | 171 | # def angle_axis_to_quaternion(theta): 172 | # batch_size = theta.shape[0] 173 | # l1norm = torch.norm(theta + 1e-8, p=2, dim=1) 174 | # angle = torch.unsqueeze(l1norm, -1) 175 | # normalized = torch.div(theta, angle) 176 | # angle = angle * 0.5 177 | # v_cos = torch.cos(angle) 178 | # v_sin = torch.sin(angle) 179 | # quat = torch.cat([v_cos, v_sin * normalized], dim=1) 180 | # return quat 181 | 182 | def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor: 183 | """Convert an angle axis to a quaternion. 184 | 185 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 186 | 187 | Args: 188 | angle_axis (torch.Tensor): tensor with angle axis. 189 | 190 | Return: 191 | torch.Tensor: tensor with quaternion. 192 | 193 | Shape: 194 | - Input: :math:`(*, 3)` where `*` means, any number of dimensions 195 | - Output: :math:`(*, 4)` 196 | 197 | Example: 198 | >>> angle_axis = torch.rand(2, 4) # Nx4 199 | >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3 200 | """ 201 | if not torch.is_tensor(angle_axis): 202 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 203 | type(angle_axis))) 204 | 205 | if not angle_axis.shape[-1] == 3: 206 | raise ValueError("Input must be a tensor of shape Nx3 or 3. 
Got {}" 207 | .format(angle_axis.shape)) 208 | # unpack input and compute conversion 209 | a0: torch.Tensor = angle_axis[..., 0:1] 210 | a1: torch.Tensor = angle_axis[..., 1:2] 211 | a2: torch.Tensor = angle_axis[..., 2:3] 212 | theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2 213 | 214 | theta: torch.Tensor = torch.sqrt(theta_squared) 215 | half_theta: torch.Tensor = theta * 0.5 216 | 217 | mask: torch.Tensor = theta_squared > 0.0 218 | ones: torch.Tensor = torch.ones_like(half_theta) 219 | 220 | k_neg: torch.Tensor = 0.5 * ones 221 | k_pos: torch.Tensor = torch.sin(half_theta) / theta 222 | k: torch.Tensor = torch.where(mask, k_pos, k_neg) 223 | w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones) 224 | 225 | quaternion: torch.Tensor = torch.zeros_like(angle_axis) 226 | quaternion[..., 0:1] += a0 * k 227 | quaternion[..., 1:2] += a1 * k 228 | quaternion[..., 2:3] += a2 * k 229 | 230 | # print(quaternion) 231 | return torch.cat([w, quaternion], dim=-1) 232 | 233 | #### quaternion to 234 | def quaternion_to_rotation_matrix(quat): 235 | """Convert quaternion coefficients to rotation matrix. 236 | Args: 237 | quat: size = [B, 4] 4 <===>(w, x, y, z) 238 | Returns: 239 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 240 | """ 241 | norm_quat = quat 242 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) 243 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] 244 | 245 | B = quat.size(0) 246 | 247 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 248 | wx, wy, wz = w * x, w * y, w * z 249 | xy, xz, yz = x * y, x * z, y * z 250 | 251 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 252 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 253 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) 254 | return rotMat 255 | 256 | def quaternion_to_angle_axis(quaternion: torch.Tensor): 257 | """Convert quaternion vector to angle axis of rotation. TODO: CORRECT 258 | 259 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 260 | 261 | Args: 262 | quaternion (torch.Tensor): tensor with quaternions. 263 | 264 | Return: 265 | torch.Tensor: tensor with angle axis of rotation. 266 | 267 | Shape: 268 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions 269 | - Output: :math:`(*, 3)` 270 | 271 | Example: 272 | >>> quaternion = torch.rand(2, 4) # Nx4 273 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 274 | """ 275 | if not torch.is_tensor(quaternion): 276 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 277 | type(quaternion))) 278 | 279 | if not quaternion.shape[-1] == 4: 280 | raise ValueError("Input must be a tensor of shape Nx4 or 4. 
Got {}" 281 | .format(quaternion.shape)) 282 | # unpack input and compute conversion 283 | q1: torch.Tensor = quaternion[..., 1] 284 | q2: torch.Tensor = quaternion[..., 2] 285 | q3: torch.Tensor = quaternion[..., 3] 286 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 287 | 288 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) 289 | cos_theta: torch.Tensor = quaternion[..., 0] 290 | two_theta: torch.Tensor = 2.0 * torch.where( 291 | cos_theta < 0.0, 292 | torch.atan2(-sin_theta, -cos_theta), 293 | torch.atan2(sin_theta, cos_theta)) 294 | 295 | k_pos: torch.Tensor = two_theta / sin_theta 296 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device) 297 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) 298 | 299 | angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3] 300 | angle_axis[..., 0] += q1 * k 301 | angle_axis[..., 1] += q2 * k 302 | angle_axis[..., 2] += q3 * k 303 | return angle_axis 304 | 305 | #### batch converter 306 | def batch_euler2axis(r): 307 | return quaternion_to_angle_axis(euler_to_quaternion(r)) 308 | 309 | def batch_euler2matrix(r): 310 | return quaternion_to_rotation_matrix(euler_to_quaternion(r)) 311 | 312 | def batch_matrix2euler(rot_mats): 313 | # Calculates rotation matrix to euler angles 314 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 315 | ### only y? 316 | # TODO: 317 | # sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 318 | # rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 319 | # return torch.atan2(-rot_mats[:, 2, 0], sy) 320 | batch_index = 0 321 | yaw = torch.zeros(rot_mats.shape[0],1) 322 | pitch = torch.zeros(rot_mats.shape[0],1) 323 | roll = torch.zeros(rot_mats.shape[0],1) 324 | for R in rot_mats: 325 | 326 | if R[2, 0] > 0.998: 327 | z = 0 328 | x = np.pi / 2 329 | y = z + atan2(-R[0, 1], -R[0, 2]) 330 | elif R[2, 0] < -0.998: 331 | z = 0 332 | x = -np.pi / 2 333 | y = -z + torch.atan2(R[0, 1], R[0, 2]) 334 | else: 335 | x = torch.asin(R[2, 0]) 336 | y = torch.atan2(R[2, 1] / torch.cos(x), R[2, 2] / torch.cos(x)) 337 | z = torch.atan2(R[1, 0] / torch.cos(x), R[0, 0] / torch.cos(x)) 338 | 339 | yaw[batch_index] = x 340 | pitch[batch_index] = y 341 | roll[batch_index] = z 342 | batch_index = batch_index + 1 343 | angles = torch.zeros(1, 3) 344 | angles[:,0] = x 345 | angles[:,1] = y 346 | angles[:,2] = z 347 | return angles 348 | 349 | def batch_matrix2axis(rot_mats): 350 | return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rot_mats)) 351 | 352 | def batch_axis2matrix(theta): 353 | # angle axis to rotation matrix 354 | # theta N x 3 355 | # return quat2mat(quat) 356 | # batch_rodrigues 357 | return quaternion_to_rotation_matrix(angle_axis_to_quaternion(theta)) 358 | 359 | def batch_axis2euler(theta): 360 | return batch_matrix2euler(batch_axis2matrix(theta)) 361 | 362 | 363 | 364 | def batch_orth_proj(X, camera): 365 | ''' 366 | X is N x num_pquaternion_to_angle_axisoints x 3 367 | ''' 368 | camera = camera.clone().view(-1, 1, 3) 369 | X_trans = X[:, :, :2] + camera[:, :, 1:] 370 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2) 371 | Xn = (camera[:, :, 0:1] * X_trans) 372 | return Xn 373 | 374 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 375 | ''' same as batch_matrix2axis 376 | Calculates the rotation matrices for a batch of rotation vectors 377 | Parameters 378 | ---------- 379 | rot_vecs: torch.tensor Nx3 380 | array of N axis-angle vectors 381 | Returns 382 | ------- 383 | R: torch.tensor Nx3x3 384 | 
The rotation matrices for the given axis-angle parameters 385 | ''' 386 | 387 | batch_size = rot_vecs.shape[0] 388 | device = rot_vecs.device 389 | 390 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 391 | rot_dir = rot_vecs / angle 392 | 393 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 394 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 395 | 396 | # Bx1 arrays 397 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 398 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 399 | 400 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 401 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 402 | .view((batch_size, 3, 3)) 403 | 404 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 405 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 406 | return rot_mat 407 | -------------------------------------------------------------------------------- /libs/DECA/decalib/models/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | def rot_mat_to_euler(rot_mats): 27 | # Calculates rotation matrix to euler angles 28 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 29 | 30 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 31 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 32 | return torch.atan2(-rot_mats[:, 2, 0], sy) 33 | 34 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, 35 | dynamic_lmk_b_coords, 36 | neck_kin_chain, dtype=torch.float32): 37 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks 38 | 39 | 40 | To do so, we first compute the rotation of the neck around the y-axis 41 | and then use a pre-computed look-up table to find the faces and the 42 | barycentric coordinates that will be used. 43 | 44 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) 45 | for providing the original TensorFlow implementation and for the LUT. 46 | 47 | Parameters 48 | ---------- 49 | vertices: torch.tensor BxVx3, dtype = torch.float32 50 | The tensor of input vertices 51 | pose: torch.tensor Bx(Jx3), dtype = torch.float32 52 | The current pose of the body model 53 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long 54 | The look-up table from neck rotation to faces 55 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 56 | The look-up table from neck rotation to barycentric coordinates 57 | neck_kin_chain: list 58 | A python list that contains the indices of the joints that form the 59 | kinematic chain of the neck. 
60 | dtype: torch.dtype, optional 61 | 62 | Returns 63 | ------- 64 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long 65 | A tensor of size BxL that contains the indices of the faces that 66 | will be used to compute the current dynamic landmarks. 67 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 68 | A tensor of size BxL that contains the indices of the faces that 69 | will be used to compute the current dynamic landmarks. 70 | ''' 71 | 72 | batch_size = vertices.shape[0] 73 | 74 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 75 | neck_kin_chain) 76 | rot_mats = batch_rodrigues( 77 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 78 | 79 | rel_rot_mat = torch.eye(3, device=vertices.device, 80 | dtype=dtype).unsqueeze_(dim=0) 81 | for idx in range(len(neck_kin_chain)): 82 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 83 | 84 | y_rot_angle = torch.round( 85 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 86 | max=39)).to(dtype=torch.long) 87 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 88 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 89 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 90 | y_rot_angle = (neg_mask * neg_vals + 91 | (1 - neg_mask) * y_rot_angle) 92 | 93 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 94 | 0, y_rot_angle) 95 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 96 | 0, y_rot_angle) 97 | 98 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 99 | 100 | 101 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): 102 | ''' Calculates landmarks by barycentric interpolation 103 | 104 | Parameters 105 | ---------- 106 | vertices: torch.tensor BxVx3, dtype = torch.float32 107 | The tensor of input vertices 108 | faces: torch.tensor Fx3, dtype = torch.long 109 | The faces of the mesh 110 | lmk_faces_idx: torch.tensor L, dtype = torch.long 111 | The tensor with the indices of the faces used to calculate the 112 | landmarks. 
113 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 114 | The tensor of barycentric coordinates that are used to interpolate 115 | the landmarks 116 | 117 | Returns 118 | ------- 119 | landmarks: torch.tensor BxLx3, dtype = torch.float32 120 | The coordinates of the landmarks for each mesh in the batch 121 | ''' 122 | # Extract the indices of the vertices for each face 123 | # BxLx3 124 | batch_size, num_verts = vertices.shape[:2] 125 | device = vertices.device 126 | 127 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 128 | batch_size, -1, 3) 129 | 130 | lmk_faces += torch.arange( 131 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts 132 | 133 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( 134 | batch_size, -1, 3, 3) 135 | 136 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 137 | return landmarks 138 | 139 | 140 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 141 | lbs_weights, pose2rot=True, dtype=torch.float32): 142 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 143 | 144 | Parameters 145 | ---------- 146 | betas : torch.tensor BxNB 147 | The tensor of shape parameters 148 | pose : torch.tensor Bx(J + 1) * 3 149 | The pose parameters in axis-angle format 150 | v_template torch.tensor BxVx3 151 | The template mesh that will be deformed 152 | shapedirs : torch.tensor 1xNB 153 | The tensor of PCA shape displacements 154 | posedirs : torch.tensor Px(V * 3) 155 | The pose PCA coefficients 156 | J_regressor : torch.tensor JxV 157 | The regressor array that is used to calculate the joints from 158 | the position of the vertices 159 | parents: torch.tensor J 160 | The array that describes the kinematic tree for the model 161 | lbs_weights: torch.tensor N x V x (J + 1) 162 | The linear blend skinning weights that represent how much the 163 | rotation matrix of each part affects each vertex 164 | pose2rot: bool, optional 165 | Flag on whether to convert the input pose tensor to rotation 166 | matrices. The default value is True. If False, then the pose tensor 167 | should already contain rotation matrices and have a size of 168 | Bx(J + 1)x9 169 | dtype: torch.dtype, optional 170 | 171 | Returns 172 | ------- 173 | verts: torch.tensor BxVx3 174 | The vertices of the mesh after applying the shape and pose 175 | displacements. 176 | joints: torch.tensor BxJx3 177 | The joints of the model 178 | ''' 179 | 180 | batch_size = max(betas.shape[0], pose.shape[0]) 181 | device = betas.device 182 | 183 | # Add shape contribution 184 | v_shaped = v_template + blend_shapes(betas, shapedirs) 185 | 186 | # Get the joints 187 | # NxJx3 array 188 | J = vertices2joints(J_regressor, v_shaped) 189 | 190 | # 3. Add pose blend shapes 191 | # N x J x 3 x 3 192 | ident = torch.eye(3, dtype=dtype, device=device) 193 | if pose2rot: 194 | rot_mats = batch_rodrigues( 195 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) 196 | 197 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 198 | # (N x P) x (P, V * 3) -> N x V x 3 199 | pose_offsets = torch.matmul(pose_feature, posedirs) \ 200 | .view(batch_size, -1, 3) 201 | else: 202 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 203 | rot_mats = pose.view(batch_size, -1, 3, 3) 204 | 205 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 206 | posedirs).view(batch_size, -1, 3) 207 | 208 | v_posed = pose_offsets + v_shaped 209 | # 4. 
Get the global joint location 210 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 211 | 212 | # 5. Do skinning: 213 | # W is N x V x (J + 1) 214 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 215 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 216 | num_joints = J_regressor.shape[0] 217 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ 218 | .view(batch_size, -1, 4, 4) 219 | 220 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], 221 | dtype=dtype, device=device) 222 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 223 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 224 | 225 | verts = v_homo[:, :, :3, 0] 226 | 227 | return verts, J_transformed 228 | 229 | 230 | def vertices2joints(J_regressor, vertices): 231 | ''' Calculates the 3D joint locations from the vertices 232 | 233 | Parameters 234 | ---------- 235 | J_regressor : torch.tensor JxV 236 | The regressor array that is used to calculate the joints from the 237 | position of the vertices 238 | vertices : torch.tensor BxVx3 239 | The tensor of mesh vertices 240 | 241 | Returns 242 | ------- 243 | torch.tensor BxJx3 244 | The location of the joints 245 | ''' 246 | 247 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 248 | 249 | 250 | def blend_shapes(betas, shape_disps): 251 | ''' Calculates the per vertex displacement due to the blend shapes 252 | 253 | 254 | Parameters 255 | ---------- 256 | betas : torch.tensor Bx(num_betas) 257 | Blend shape coefficients 258 | shape_disps: torch.tensor Vx3x(num_betas) 259 | Blend shapes 260 | 261 | Returns 262 | ------- 263 | torch.tensor BxVx3 264 | The per-vertex displacement due to shape deformation 265 | ''' 266 | 267 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 268 | # i.e. Multiply each shape displacement by its corresponding beta and 269 | # then sum them. 
270 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 271 | return blend_shape 272 | 273 | 274 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 275 | ''' Calculates the rotation matrices for a batch of rotation vectors 276 | Parameters 277 | ---------- 278 | rot_vecs: torch.tensor Nx3 279 | array of N axis-angle vectors 280 | Returns 281 | ------- 282 | R: torch.tensor Nx3x3 283 | The rotation matrices for the given axis-angle parameters 284 | ''' 285 | 286 | batch_size = rot_vecs.shape[0] 287 | device = rot_vecs.device 288 | 289 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 290 | rot_dir = rot_vecs / angle 291 | 292 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 293 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 294 | 295 | # Bx1 arrays 296 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 297 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 298 | 299 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 300 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 301 | .view((batch_size, 3, 3)) 302 | 303 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 304 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 305 | return rot_mat 306 | 307 | 308 | def transform_mat(R, t): 309 | ''' Creates a batch of transformation matrices 310 | Args: 311 | - R: Bx3x3 array of a batch of rotation matrices 312 | - t: Bx3x1 array of a batch of translation vectors 313 | Returns: 314 | - T: Bx4x4 Transformation matrix 315 | ''' 316 | # No padding left or right, only add an extra row 317 | return torch.cat([F.pad(R, [0, 0, 0, 1]), 318 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 319 | 320 | 321 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 322 | """ 323 | Applies a batch of rigid transformations to the joints 324 | 325 | Parameters 326 | ---------- 327 | rot_mats : torch.tensor BxNx3x3 328 | Tensor of rotation matrices 329 | joints : torch.tensor BxNx3 330 | Locations of joints 331 | parents : torch.tensor BxN 332 | The kinematic tree of each object 333 | dtype : torch.dtype, optional: 334 | The data type of the created tensors, the default is torch.float32 335 | 336 | Returns 337 | ------- 338 | posed_joints : torch.tensor BxNx3 339 | The locations of the joints after applying the pose rotations 340 | rel_transforms : torch.tensor BxNx4x4 341 | The relative (with respect to the root joint) rigid transformations 342 | for all the joints 343 | """ 344 | 345 | joints = torch.unsqueeze(joints, dim=-1) 346 | 347 | rel_joints = joints.clone() 348 | rel_joints[:, 1:] -= joints[:, parents[1:]] 349 | 350 | # transforms_mat = transform_mat( 351 | # rot_mats.view(-1, 3, 3), 352 | # rel_joints.view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4) 353 | transforms_mat = transform_mat( 354 | rot_mats.view(-1, 3, 3), 355 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) 356 | 357 | transform_chain = [transforms_mat[:, 0]] 358 | for i in range(1, parents.shape[0]): 359 | # Subtract the joint location at the rest pose 360 | # No need for rotation, since it's identity when at rest 361 | curr_res = torch.matmul(transform_chain[parents[i]], 362 | transforms_mat[:, i]) 363 | transform_chain.append(curr_res) 364 | 365 | transforms = torch.stack(transform_chain, dim=1) 366 | 367 | # The last column of the transformations contains the posed joints 368 | posed_joints = transforms[:, :, :3, 3] 369 | 370 | # The last column of the transformations contains the posed joints 371 | 
posed_joints = transforms[:, :, :3, 3] 372 | 373 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 374 | 375 | rel_transforms = transforms - F.pad( 376 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) 377 | 378 | return posed_joints, rel_transforms --------------------------------------------------------------------------------
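
Usage sketch for run_inference.py: the snippet below drives the script programmatically instead of via the CLI shown in its docstring, assuming Inference_demo is defined in run_inference.py (as its use in main() suggests) and that its constructor accepts the same keys the argparse block defines; all paths are illustrative placeholders.

# Programmatic sketch of the run_inference.py entry point (assumes the repo
# root is on PYTHONPATH and the pretrained checkpoints are downloaded).
# The keys mirror the argparse definitions; the paths are hypothetical.
from run_inference import Inference_demo

args = {
    'output_path': './results',
    'source_path': None,          # None -> random latent codes are sampled
    'target_path': None,
    'masknet_path': './pretrained_models/mask_network_1024.pt',
    'dataset': 'ffhq',
    'image_resolution': 1024,
    'num_pairs': 4,
    'save_grid': True,            # save source / target / reenacted grids
    'save_image': False,
    'resize_image': False,
}

inf = Inference_demo(args)
inf.run()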
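
Sketch of the per-layer masking flow in reenact_pair(): the stand-in mask networks, the split_sections widths, and the convex blend below are assumptions used only to illustrate the tensor shapes; the repo's actual MaskPredictor and generate_new_stylespace (not shown in this listing) may differ.

# Toy illustration of the per-layer masking flow: per-layer style deltas are
# fed to a mask network, the masks are concatenated, and source and target
# style codes are blended channel-wise before being split back per layer.
import torch
import torch.nn as nn

split_sections = [512, 512, 256]   # hypothetical per-layer style widths

# Stand-in for the repo's MaskPredictor modules (an assumption, not the real class)
mask_net = nn.ModuleDict({
    'network_{:02d}'.format(i): nn.Sequential(
        nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d), nn.Sigmoid())
    for i, d in enumerate(split_sections)
})

style_source = [torch.randn(1, d) for d in split_sections]
style_target = [torch.randn(1, d) for d in split_sections]

masks = [mask_net['network_{:02d}'.format(i)](s - t)
         for i, (s, t) in enumerate(zip(style_source, style_target))]

mask = torch.cat(masks, dim=1)
src = torch.cat(style_source, dim=1)
tgt = torch.cat(style_target, dim=1)

# Assumed blending rule: keep source channels where the mask is ~1 and take
# the target's channels where it is ~0 (the real rule is in generate_new_stylespace).
new_style = mask * src + (1.0 - mask) * tgt
new_style_space = list(torch.split(new_style, split_sections, dim=1))
print([s.shape for s in new_style_space])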
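
Round-trip sketch for the converters in rotation_converter.py, assuming the repository root is on PYTHONPATH so the module is importable as libs.DECA.decalib.utils.rotation_converter.

# Euler angles -> rotation matrix -> quaternion -> axis-angle, using the
# functions defined above. Only shapes are checked here; sign conventions
# between representations can differ.
import torch
from libs.DECA.decalib.utils.rotation_converter import (
    deg2rad,
    batch_euler2matrix,
    rotation_matrix_to_quaternion,
    quaternion_to_angle_axis,
)

euler = deg2rad(torch.tensor([[10.0, 20.0, 30.0]]))   # (B, 3), radians
R = batch_euler2matrix(euler)                         # (B, 3, 3)
q = rotation_matrix_to_quaternion(R)                  # (B, 4)
aa = quaternion_to_angle_axis(q)                      # (B, 3), axis-angle
print(R.shape, q.shape, aa.shape)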
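
Sanity-check sketch for batch_rodrigues() in lbs.py: a 90-degree rotation about the z-axis should map the x-axis onto the y-axis (same import-path assumption as above).

# Rodrigues formula check: R = I + sin(theta) K + (1 - cos(theta)) K^2
import math
import torch
from libs.DECA.decalib.models.lbs import batch_rodrigues

aa = torch.tensor([[0.0, 0.0, math.pi / 2.0]])   # axis-angle: pi/2 about z
R = batch_rodrigues(aa)                          # (1, 3, 3)
v = torch.tensor([[1.0, 0.0, 0.0]]).unsqueeze(-1)
print(torch.matmul(R, v).squeeze(-1))            # ~tensor([[0., 1., 0.]])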
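
Sketch for vertices2landmarks() in lbs.py: with a single triangle and barycentric weights of one third each, the interpolated landmark is the triangle's centroid (same import-path assumption as above).

# Barycentric landmark interpolation on a minimal one-triangle mesh.
import torch
from libs.DECA.decalib.models.lbs import vertices2landmarks

vertices = torch.tensor([[[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]]])                 # (B=1, V=3, 3)
faces = torch.tensor([[0, 1, 2]], dtype=torch.long)          # (F=1, 3)
lmk_faces_idx = torch.tensor([[0]], dtype=torch.long)        # (B, L=1)
lmk_bary_coords = torch.full((1, 1, 3), 1.0 / 3.0)           # (B, L, 3)

print(vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords))
# -> tensor([[[0.3333, 0.3333, 0.0000]]]), i.e. the triangle centroid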