├── .gitignore ├── README.md ├── ZSSGAN ├── __init__.py ├── criteria │ └── clip_loss.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── generate_videos.py ├── legacy.py ├── mapper │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── latents_dataset.py │ ├── facial_recognition │ │ ├── __init__.py │ │ ├── helpers.py │ │ └── model_irse.py │ ├── latent_mappers.py │ ├── mapper_criteria │ │ ├── __init__.py │ │ ├── clip_loss.py │ │ └── id_loss.py │ ├── options │ │ ├── __init__.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── scripts │ │ ├── inference.py │ │ └── train.py │ ├── styleclip_mapper.py │ ├── stylegan2 │ │ ├── __init__.py │ │ ├── model.py │ │ └── op │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ └── upfirdn2d.py │ └── training │ │ ├── __init__.py │ │ ├── coach.py │ │ ├── ranger.py │ │ └── train_utils.py ├── model │ ├── ZSSGAN.py │ ├── ZSSGAN_IDE3D.py │ ├── ZSSGAN_eg3d.py │ └── sg2_model.py ├── op │ ├── __init__.py │ ├── conv2d_gradfix.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── options │ └── train_options.py ├── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── filtered_lrelu.cpp │ │ ├── filtered_lrelu.cu │ │ ├── filtered_lrelu.h │ │ ├── filtered_lrelu.py │ │ ├── filtered_lrelu_ns.cu │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py ├── train.py ├── utils │ ├── file_utils.py │ ├── text_templates.py │ └── training_utils.py └── volumetric_rendering.py ├── assets └── pipeline-2.jpg ├── eg3d ├── camera_utils.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── legacy.py ├── shape_utils.py ├── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── filtered_lrelu.cpp │ │ ├── filtered_lrelu.cu │ │ ├── filtered_lrelu.h │ │ ├── filtered_lrelu.py │ │ ├── filtered_lrelu_ns.cu │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py └── training │ ├── __init__.py │ ├── augment.py │ ├── crosssection_utils.py │ ├── dataset.py │ ├── dual_discriminator.py │ ├── networks_stylegan2.py │ ├── networks_stylegan3.py │ ├── superresolution.py │ ├── training_loop.py │ ├── triplane.py │ └── volumetric_rendering │ ├── __init__.py │ ├── math_utils.py │ ├── ray_marcher.py │ ├── ray_sampler.py │ └── renderer.py ├── preprocess ├── extract_3dmm.py ├── extract_camera.py ├── extract_landmark.py ├── extract_mask.py ├── mirror_padding.py ├── process_camera.py ├── run_crop.py ├── run_total.py ├── transform_into_goae_data_format.py └── video2frames.py ├── spi ├── configs │ ├── __init__.py │ ├── global_config.py │ ├── hyperparameters.py │ └── paths_config.py ├── criteria │ ├── __init__.py │ ├── bbox_cx_loss.py │ ├── id_loss │ │ ├── helpers.py │ │ ├── id_loss.py │ │ └── model_irse.py │ ├── l2_loss.py │ ├── lpips │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── networks.py │ │ └── utils.py │ └── 
tv_loss.py ├── data │ └── images_dataset.py ├── run_inversion.py ├── training │ ├── coaches │ │ ├── base_coach.py │ │ ├── inference_coach.py │ │ ├── pti_coach.py │ │ └── rot_bbox_cx_coach.py │ └── projectors │ │ ├── mirror_projector.py │ │ ├── w_plus_projector.py │ │ └── w_projector.py └── utils │ ├── camera_utils.py │ ├── load_utils.py │ ├── log_utils.py │ ├── mask_utils.py │ ├── metric_utils.py │ ├── rotate.py │ └── video_utils.py ├── test └── images │ └── anne_hathaway.jpg └── third_part ├── Deep3DFaceRecon_pytorch ├── models │ ├── __init__.py │ ├── arcface_torch │ │ ├── README.md │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── iresnet.py │ │ │ ├── iresnet2060.py │ │ │ └── mobilefacenet.py │ │ ├── configs │ │ │ ├── 3millions.py │ │ │ ├── 3millions_pfc.py │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── glint360k_mbf.py │ │ │ ├── glint360k_r100.py │ │ │ ├── glint360k_r18.py │ │ │ ├── glint360k_r34.py │ │ │ ├── glint360k_r50.py │ │ │ ├── ms1mv3_mbf.py │ │ │ ├── ms1mv3_r18.py │ │ │ ├── ms1mv3_r2060.py │ │ │ ├── ms1mv3_r34.py │ │ │ ├── ms1mv3_r50.py │ │ │ └── speed.py │ │ ├── dataset.py │ │ ├── docs │ │ │ ├── eval.md │ │ │ ├── install.md │ │ │ ├── modelzoo.md │ │ │ └── speed_benchmark.md │ │ ├── eval │ │ │ ├── __init__.py │ │ │ └── verification.py │ │ ├── eval_ijbc.py │ │ ├── inference.py │ │ ├── losses.py │ │ ├── onnx_helper.py │ │ ├── onnx_ijbc.py │ │ ├── partial_fc.py │ │ ├── requirement.txt │ │ ├── run.sh │ │ ├── torch2onnx.py │ │ ├── train.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── plot.py │ │ │ ├── utils_amp.py │ │ │ ├── utils_callbacks.py │ │ │ ├── utils_config.py │ │ │ ├── utils_logging.py │ │ │ └── utils_os.py │ ├── base_model.py │ ├── bfm.py │ ├── facerecon_model.py │ ├── losses.py │ ├── networks.py │ └── template_model.py ├── options │ ├── __init__.py │ ├── base_options.py │ ├── inference_options.py │ └── test_options.py ├── temp │ └── 2.png ├── test.py └── util │ ├── BBRegressorParam_r.mat │ ├── __init__.py │ ├── generate_list.py │ ├── html.py │ ├── load_mats.py │ ├── nvdiffrast.py │ ├── preprocess.py │ ├── skin_mask.py │ ├── test_mean_face.txt │ ├── util.py │ └── visualizer.py └── bisenet ├── bisenet.py └── resnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | .history 2 | __pycache__ 3 | test/dataset 4 | test/output 5 | # ZSSGAN -------------------------------------------------------------------------------- /ZSSGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/__init__.py -------------------------------------------------------------------------------- /ZSSGAN/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/__init__.py -------------------------------------------------------------------------------- /ZSSGAN/mapper/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/datasets/__init__.py -------------------------------------------------------------------------------- /ZSSGAN/mapper/datasets/latents_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class LatentsDataset(Dataset): 5 | 6 | def __init__(self, latents, opts): 7 | self.latents = latents 8 | self.opts = opts 9 | 10 | def __len__(self): 11 | return self.latents.shape[0] 12 | 13 | def __getitem__(self, index): 14 | 15 | return self.latents[index] 16 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/facial_recognition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/facial_recognition/__init__.py -------------------------------------------------------------------------------- /ZSSGAN/mapper/facial_recognition/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/facial_recognition/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from mapper.facial_recognition.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, 
affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/latent_mappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module 4 | 5 | from mapper.stylegan2.model import EqualLinear, PixelNorm 6 | 7 | 8 | class Mapper(Module): 9 | 10 | def __init__(self, opts): 11 | super(Mapper, self).__init__() 12 | 13 | self.opts = opts 14 | layers = [PixelNorm()] 15 | 16 | for i in range(4): 17 | layers.append( 18 | EqualLinear( 19 | 512, 512, lr_mul=0.01, activation='fused_lrelu' 20 | ) 21 | ) 22 | 23 | self.mapping = nn.Sequential(*layers) 24 | 25 | 26 | def forward(self, x): 27 | x = self.mapping(x) 28 | return x 29 | 30 | 31 | class SingleMapper(Module): 32 | 33 | def __init__(self, opts): 34 | super(SingleMapper, self).__init__() 35 | 36 | self.opts = opts 37 | 38 | self.mapping = Mapper(opts) 39 | 40 | def forward(self, x): 41 | out = self.mapping(x) 42 | return out 43 | 44 | 45 | class LevelsMapper(Module): 46 | 47 | def __init__(self, opts): 48 | super(LevelsMapper, self).__init__() 49 | 50 | self.opts = opts 51 | 52 | if not opts.no_coarse_mapper: 53 | self.course_mapping = Mapper(opts) 54 | if not opts.no_medium_mapper: 55 | self.medium_mapping = Mapper(opts) 56 | if not opts.no_fine_mapper: 57 | self.fine_mapping = Mapper(opts) 58 | 59 | def forward(self, x): 60 | x_coarse = x[:, :4, :] 61 | x_medium = x[:, 4:8, :] 62 | x_fine = x[:, 8:, :] 63 | 64 | if not self.opts.no_coarse_mapper: 65 | x_coarse = self.course_mapping(x_coarse) 66 | else: 67 | x_coarse = torch.zeros_like(x_coarse) 68 | if not self.opts.no_medium_mapper: 69 | x_medium = self.medium_mapping(x_medium) 70 | else: 71 | x_medium = torch.zeros_like(x_medium) 72 | if not self.opts.no_fine_mapper: 73 | x_fine = 
self.fine_mapping(x_fine) 74 | else: 75 | x_fine = torch.zeros_like(x_fine) 76 | 77 | 78 | out = torch.cat([x_coarse, x_medium, x_fine], dim=1) 79 | 80 | return out 81 | 82 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/mapper_criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/mapper_criteria/__init__.py -------------------------------------------------------------------------------- /ZSSGAN/mapper/mapper_criteria/clip_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import clip 4 | 5 | 6 | class CLIPLoss(torch.nn.Module): 7 | 8 | def __init__(self, opts): 9 | super(CLIPLoss, self).__init__() 10 | self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") 11 | self.upsample = torch.nn.Upsample(scale_factor=7) 12 | self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32) 13 | 14 | self.model_16, _ = clip.load("ViT-B/16", device="cuda") 15 | 16 | self.mse_loss = torch.nn.MSELoss() 17 | 18 | def forward(self, image, text): 19 | image = self.avg_pool(self.upsample(image)) 20 | similarity = 1.0 * (1 - self.model(image, text)[0] / 100) + 1.0 * (1.0 - self.model_16(image, text)[0] / 100) 21 | return similarity 22 | 23 | def norm_loss(self, image_pre, image_post): 24 | norm_pre = self.model.encode_image(self.avg_pool(self.upsample(image_pre))).norm(dim=-1) 25 | norm_post = self.model.encode_image(self.avg_pool(self.upsample(image_post))).norm(dim=-1) 26 | 27 | return self.mse_loss(norm_pre, norm_post) -------------------------------------------------------------------------------- /ZSSGAN/mapper/mapper_criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from mapper.facial_recognition.model_irse import Backbone 5 | 6 | class IDLoss(nn.Module): 7 | def __init__(self, opts): 8 | super(IDLoss, self).__init__() 9 | print('Loading ResNet ArcFace') 10 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 11 | self.facenet.load_state_dict(torch.load(opts.ir_se50_weights)) 12 | self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | self.opts = opts 16 | 17 | def extract_feats(self, x): 18 | if x.shape[2] != 256: 19 | x = self.pool(x) 20 | x = x[:, :, 35:223, 32:220] # Crop interesting region 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | 25 | def forward(self, y_hat, y): 26 | n_samples = y.shape[0] 27 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 28 | y_hat_feats = self.extract_feats(y_hat) 29 | y_feats = y_feats.detach() 30 | loss = 0 31 | sim_improvement = 0 32 | count = 0 33 | for i in range(n_samples): 34 | diff_target = y_hat_feats[i].dot(y_feats[i]) 35 | loss += 1 - diff_target 36 | count += 1 37 | 38 | return loss / count, sim_improvement / count 39 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/options/__init__.py 
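The two criteria above — CLIPLoss (text–image similarity over ViT-B/32 and ViT-B/16) and IDLoss (ArcFace identity similarity) — are combined with a latent L2 term during mapper training. The repository's actual weighting lives in mapper/training/coach.py, which is not inlined in this listing; the snippet below is only a minimal sketch of that combination, using the option names defined in mapper/options/train_options.py (description, clip_lambda, id_lambda, latent_l2_lambda) and assuming images and latents are already on the GPU.

```python
# Illustrative sketch only -- not the repository's Coach implementation.
import clip
import torch
from mapper.mapper_criteria.clip_loss import CLIPLoss
from mapper.mapper_criteria.id_loss import IDLoss

def mapper_loss(opts, w, w_hat, img, img_hat):
    """Weighted sum of the three mapper losses; all tensors assumed on CUDA."""
    clip_loss = CLIPLoss(opts)                      # ViT-B/32 + ViT-B/16 CLIP similarity
    id_loss = IDLoss(opts)                          # ArcFace identity similarity
    text = clip.tokenize([opts.description]).cuda()

    loss_clip = clip_loss(img_hat, text).mean()     # pull the edit toward the text prompt
    loss_id, _ = id_loss(img_hat, img)              # keep the identity of the source face
    loss_l2 = ((w_hat - w) ** 2).mean()             # stay close to the original latent
    return (opts.clip_lambda * loss_clip
            + opts.id_lambda * loss_id
            + opts.latent_l2_lambda * loss_l2)
```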
-------------------------------------------------------------------------------- /ZSSGAN/mapper/options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TestOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | # arguments for inference script 12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 13 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint') 14 | self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side') 15 | 16 | self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') 17 | self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") 18 | self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") 19 | self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") 20 | self.parser.add_argument('--stylegan_size', default=1024, type=int) 21 | 22 | 23 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 24 | self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation") 25 | self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') 26 | 27 | self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data') 28 | 29 | def parse(self): 30 | opts = self.parser.parse_args() 31 | return opts -------------------------------------------------------------------------------- /ZSSGAN/mapper/options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TrainOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 12 | self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') 13 | self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") 14 | self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") 15 | self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") 16 | self.parser.add_argument('--latents_train_path', default=None, type=str, help="The latents for the training") 17 | self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation") 18 | self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given") 19 | self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given") 20 | 21 | self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training') 22 | self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference') 23 | self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') 24 | self.parser.add_argument('--test_workers', default=2, type=int, help='Number of 
test/inference dataloader workers') 25 | 26 | self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate') 27 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 28 | 29 | self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') 30 | self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor') 31 | self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor') 32 | self.parser.add_argument('--norm_lambda', default=0.5, type=float, help='CLIP embedding norm regularization factor') 33 | 34 | self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights') 35 | self.parser.add_argument('--stylegan_size', default=1024, type=int) 36 | self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") 37 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint') 38 | 39 | self.parser.add_argument('--max_steps', default=50001, type=int, help='Maximum number of training steps') 40 | self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') 41 | self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard') 42 | self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval') 43 | self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval') 44 | 45 | self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt') 46 | 47 | 48 | def parse(self): 49 | opts = self.parser.parse_args() 50 | return opts -------------------------------------------------------------------------------- /ZSSGAN/mapper/scripts/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import sys 9 | import time 10 | 11 | from tqdm import tqdm 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from mapper.datasets.latents_dataset import LatentsDataset 17 | 18 | from mapper.options.test_options import TestOptions 19 | from mapper.styleclip_mapper import StyleCLIPMapper 20 | 21 | 22 | def run(test_opts): 23 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 24 | os.makedirs(out_path_results, exist_ok=True) 25 | 26 | # update test options with options used during training 27 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 28 | opts = ckpt['opts'] 29 | opts.update(vars(test_opts)) 30 | opts = Namespace(**opts) 31 | 32 | net = StyleCLIPMapper(opts) 33 | net.eval() 34 | net.cuda() 35 | 36 | test_latents = torch.load(opts.latents_test_path) 37 | dataset = LatentsDataset(latents=test_latents.cpu(), 38 | opts=opts) 39 | dataloader = DataLoader(dataset, 40 | batch_size=opts.test_batch_size, 41 | shuffle=False, 42 | num_workers=int(opts.test_workers), 43 | drop_last=True) 44 | 45 | if opts.n_images is None: 46 | opts.n_images = len(dataset) 47 | 48 | global_i = 0 49 | global_time = [] 50 | for 
input_batch in tqdm(dataloader): 51 | if global_i >= opts.n_images: 52 | break 53 | with torch.no_grad(): 54 | input_cuda = input_batch.cuda().float() 55 | tic = time.time() 56 | result_batch = run_on_batch(input_cuda, net, test_opts.couple_outputs) 57 | toc = time.time() 58 | global_time.append(toc - tic) 59 | 60 | for i in range(opts.test_batch_size): 61 | im_path = str(global_i).zfill(5) 62 | if test_opts.couple_outputs: 63 | couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)]) 64 | torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"), normalize=True, range=(-1, 1)) 65 | else: 66 | torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"), normalize=True, range=(-1, 1)) 67 | torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt")) 68 | 69 | global_i += 1 70 | 71 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 72 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 73 | print(result_str) 74 | 75 | with open(stats_path, 'w') as f: 76 | f.write(result_str) 77 | 78 | 79 | def run_on_batch(inputs, net, couple_outputs=False): 80 | w = inputs 81 | with torch.no_grad(): 82 | w_hat = w + 0.1 * net.mapper(w) 83 | x_hat, w_hat = net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1) 84 | result_batch = (x_hat, w_hat) 85 | if couple_outputs: 86 | x, _ = net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1) 87 | result_batch = (x_hat, w_hat, x) 88 | return result_batch 89 | 90 | 91 | if __name__ == '__main__': 92 | test_opts = TestOptions().parse() 93 | run(test_opts) 94 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import os 5 | import json 6 | import sys 7 | import pprint 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from mapper.options.train_options import TrainOptions 13 | from mapper.training.coach import Coach 14 | 15 | 16 | def main(opts): 17 | # if os.path.exists(opts.exp_dir): 18 | # raise Exception('Oops... 
{} already exists'.format(opts.exp_dir)) 19 | os.makedirs(opts.exp_dir, exist_ok=True) 20 | 21 | opts_dict = vars(opts) 22 | pprint.pprint(opts_dict) 23 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 24 | json.dump(opts_dict, f, indent=4, sort_keys=True) 25 | 26 | coach = Coach(opts) 27 | coach.train() 28 | 29 | 30 | if __name__ == '__main__': 31 | opts = TrainOptions().parse() 32 | main(opts) 33 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/styleclip_mapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from mapper import latent_mappers 4 | from mapper.stylegan2.model import Generator 5 | 6 | 7 | def get_keys(d, name): 8 | if 'state_dict' in d: 9 | d = d['state_dict'] 10 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 11 | return d_filt 12 | 13 | 14 | class StyleCLIPMapper(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(StyleCLIPMapper, self).__init__() 18 | self.opts = opts 19 | # Define architecture 20 | self.mapper = self.set_mapper() 21 | self.decoder = Generator(self.opts.stylegan_size, 512, 8) 22 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 23 | # Load weights if needed 24 | self.load_weights() 25 | 26 | self.decoder.cuda() 27 | with torch.no_grad(): 28 | self.latent_avg = self.decoder.mean_latent(4096).cuda() 29 | 30 | def set_mapper(self): 31 | if self.opts.mapper_type == 'SingleMapper': 32 | mapper = latent_mappers.SingleMapper(self.opts) 33 | elif self.opts.mapper_type == 'LevelsMapper': 34 | mapper = latent_mappers.LevelsMapper(self.opts) 35 | else: 36 | raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type)) 37 | return mapper 38 | 39 | def load_weights(self): 40 | if self.opts.checkpoint_path is not None: 41 | print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path)) 42 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 43 | self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True) 44 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) 45 | else: 46 | print('Loading decoder weights from pretrained!') 47 | ckpt = torch.load(self.opts.stylegan_weights) 48 | self.decoder.load_state_dict(ckpt['g_ema'], strict=False) 49 | 50 | def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 51 | inject_latent=None, return_latents=False, alpha=None): 52 | if input_code: 53 | codes = x 54 | else: 55 | codes = self.mapper(x) 56 | 57 | if latent_mask is not None: 58 | for i in latent_mask: 59 | if inject_latent is not None: 60 | if alpha is not None: 61 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 62 | else: 63 | codes[:, i] = inject_latent[:, i] 64 | else: 65 | codes[:, i] = 0 66 | 67 | input_is_latent = not input_code 68 | images, result_latent = self.decoder([codes], 69 | input_is_latent=input_is_latent, 70 | randomize_noise=randomize_noise, 71 | return_latents=return_latents) 72 | 73 | if resize: 74 | images = self.face_pool(images) 75 | 76 | if return_latents: 77 | return images, result_latent 78 | else: 79 | return images 80 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/stylegan2/__init__.py 
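StyleCLIPMapper above wires the latent mapper and the StyleGAN2 decoder together; mapper/scripts/inference.py (run_on_batch) shows the intended call pattern. The sketch below condenses that flow into a self-contained snippet; the checkpoint path is a hypothetical placeholder, and the 0.1 residual scale simply mirrors run_on_batch.

```python
# Condensed from mapper/scripts/inference.py (run_on_batch). 'mapper_ckpt.pt'
# is a hypothetical path; the checkpoint is assumed to store its training
# options as a dict under 'opts', as inference.py expects.
from argparse import Namespace
import torch
from mapper.styleclip_mapper import StyleCLIPMapper

ckpt = torch.load('mapper_ckpt.pt', map_location='cpu')
opts = Namespace(**ckpt['opts'])
opts.checkpoint_path = 'mapper_ckpt.pt'      # tell StyleCLIPMapper which weights to load
net = StyleCLIPMapper(opts).eval().cuda()

w = torch.randn(1, 18, 512).cuda()           # stands in for an inverted W+ latent
with torch.no_grad():
    w_hat = w + 0.1 * net.mapper(w)          # small residual edit, as in run_on_batch
    img, _ = net.decoder([w_hat], input_is_latent=True, return_latents=True,
                         randomize_noise=False, truncation=1)
```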
-------------------------------------------------------------------------------- /ZSSGAN/mapper/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | 11 | class FusedLeakyReLU(nn.Module): 12 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 13 | super().__init__() 14 | 15 | self.bias = nn.Parameter(torch.zeros(channel)) 16 | self.negative_slope = negative_slope 17 | self.scale = scale 18 | 19 | def forward(self, input): 20 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 21 | 22 | 23 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 24 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 25 | input = input.cuda() 26 | if input.ndim == 3: 27 | return ( 28 | F.leaky_relu( 29 | input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope 30 | ) 31 | * scale 32 | ) 33 | else: 34 | return ( 35 | F.leaky_relu( 36 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 37 | ) 38 | * scale 39 | ) 40 | 41 | -------------------------------------------------------------------------------- /ZSSGAN/mapper/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | 11 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 12 | out = upfirdn2d_native( 13 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 14 | ) 15 | 16 | return out 17 | 18 | 19 | def upfirdn2d_native( 20 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 21 | ): 22 | _, channel, in_h, in_w = input.shape 23 | input = input.reshape(-1, in_h, in_w, 1) 24 | 25 | _, in_h, in_w, minor = input.shape 26 | kernel_h, kernel_w = kernel.shape 27 | 28 | out = input.view(-1, in_h, 1, in_w, 1, minor) 29 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 30 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 31 | 32 | out = F.pad( 33 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 34 | ) 35 | out = out[ 36 | :, 37 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 38 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 39 | :, 40 | ] 41 | 42 | out = out.permute(0, 3, 1, 2) 43 | out = out.reshape( 44 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 45 | ) 46 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 47 | out = F.conv2d(out, w) 48 | out = out.reshape( 49 | -1, 50 | minor, 51 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 52 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 53 | ) 54 | out = out.permute(0, 2, 3, 1) 55 | out = out[:, ::down_y, ::down_x, :] 56 | 57 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 58 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 59 | 60 | return out.view(-1, channel, out_h, out_w) -------------------------------------------------------------------------------- 
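The upfirdn2d above is the pure-PyTorch fallback (upsample, FIR filter, downsample) for the CUDA op of the same name. A small usage sketch follows; the 4-tap [1, 3, 3, 1] binomial kernel and the pad=(2, 1) choice follow the usual StyleGAN2 blur convention and are illustrative, not repository code.

```python
# Sketch of a typical same-resolution blur via upfirdn2d (up=1, down=1).
import torch
from mapper.stylegan2.op import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()               # normalized separable binomial kernel

x = torch.randn(1, 3, 64, 64)                # NCHW feature map
blurred = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))  # pad chosen so output stays 64x64
print(blurred.shape)                         # torch.Size([1, 3, 64, 64])
```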
/ZSSGAN/mapper/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/ZSSGAN/mapper/training/__init__.py -------------------------------------------------------------------------------- /ZSSGAN/mapper/training/train_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def aggregate_loss_dict(agg_loss_dict): 3 | mean_vals = {} 4 | for output in agg_loss_dict: 5 | for key in output: 6 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 7 | for key in mean_vals: 8 | if len(mean_vals[key]) > 0: 9 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 10 | else: 11 | print('{} has no value'.format(key)) 12 | mean_vals[key] = 0 13 | return mean_vals 14 | -------------------------------------------------------------------------------- /ZSSGAN/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /ZSSGAN/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, 
channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /ZSSGAN/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /ZSSGAN/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? 
x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /ZSSGAN/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 
22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 
22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 
11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /ZSSGAN/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 
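// upfirdn2d = upsample, FIR-filter, and downsample a batch of 2D images in a single fused
// CUDA pass. The struct below packs the tensor pointers, geometry, and loop/launch counts
// that the kernel needs.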
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /ZSSGAN/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | from torchvision import utils 6 | 7 | import cv2 8 | 9 | def get_dir_img_list(dir_path, valid_exts=[".png", ".jpg", ".jpeg"]): 10 | file_list = [os.path.join(dir_path, file_name) for file_name in os.listdir(dir_path) 11 | if os.path.splitext(file_name)[1].lower() in valid_exts] 12 | 13 | return file_list 14 | 15 | def copytree(src, dst, symlinks=False, ignore=None): 16 | if not os.path.exists(dst): 17 | os.makedirs(dst) 18 | for item in os.listdir(src): 19 | s = os.path.join(src, item) 20 | d = os.path.join(dst, item) 21 | if os.path.isdir(s): 22 | copytree(s, d, symlinks, ignore) 23 | else: 24 | if not os.path.exists(d) or os.stat(s).st_mtime - os.stat(d).st_mtime > 1: 25 | shutil.copy2(s, d) 26 | 27 | def save_images(images: torch.Tensor, output_dir: str, file_prefix: str, nrows: int, iteration: int) -> None: 28 | utils.save_image( 29 | images, 30 | os.path.join(output_dir, f"{file_prefix}_{str(iteration).zfill(6)}.jpg"), 31 | nrow=nrows, 32 | normalize=True, 33 | range=(-1, 1), 34 | ) 35 | 36 | def save_torch_img(img: torch.Tensor, output_dir: str, file_name: str) -> None: 37 | img = img.permute(1, 2, 0).cpu().detach().numpy() 38 | 39 | img = img[:, :, ::-1] # RGB to BGR for cv2 saving 40 | cv2.imwrite(os.path.join(output_dir, file_name), img) 41 | 42 | def resize_img(img: torch.Tensor, size: int) -> torch.Tensor: 43 | return torch.nn.functional.interpolate(img.unsqueeze(0), (size, size))[0] 44 | 45 | def save_paper_image_grid(sampled_images: torch.Tensor, sample_dir: str, file_name: str): 46 | img = (sampled_images + 1.0) * 126 # de-normalize 47 | 48 | half_size = img.size()[-1] // 2 49 | quarter_size = half_size // 2 50 | 51 | base_fig = torch.cat([img[0], img[1]], dim=2) 52 | sub_cols = [torch.cat([resize_img(img[i + j], half_size) for j in range(2)], dim=1) for i in range(2, 8, 2)] 53 | base_fig = torch.cat([base_fig, *sub_cols], dim=2) 54 | 55 | sub_cols = [torch.cat([resize_img(img[i + j], quarter_size) for j in range(4)], dim=1) for i in range(8, 16, 4)] 56 | base_fig = torch.cat([base_fig, *sub_cols], dim=2) 57 | 58 | save_torch_img(base_fig, sample_dir, file_name) 59 | 60 | def save_paper_animal_grid(sampled_images: 
torch.Tensor, sample_dir: str, file_name: str): 61 | 62 | img = (sampled_images + 1.0) * 126 # de-normalize 63 | 64 | half_size = img.size()[-1] // 2 65 | quarter_size = half_size // 2 66 | 67 | base_fig = torch.cat([img[0]], dim=2) 68 | sub_cols = [torch.cat([resize_img(img[i + j], half_size) for j in range(2)], dim=1) for i in range(1, 5, 2)] 69 | base_fig = torch.cat([base_fig, *sub_cols], dim=2) 70 | 71 | sub_cols = [torch.cat([resize_img(img[i + j], quarter_size) for j in range(4)], dim=1) for i in range(5, 13, 4)] 72 | base_fig = torch.cat([base_fig, *sub_cols], dim=2) 73 | 74 | save_torch_img(base_fig, sample_dir, file_name) -------------------------------------------------------------------------------- /ZSSGAN/utils/text_templates.py: -------------------------------------------------------------------------------- 1 | imagenet_templates = [ 2 | 'a bad photo of a {}.', 3 | 'a sculpture of a {}.', 4 | 'a photo of the hard to see {}.', 5 | 'a low resolution photo of the {}.', 6 | 'a rendering of a {}.', 7 | 'graffiti of a {}.', 8 | 'a bad photo of the {}.', 9 | 'a cropped photo of the {}.', 10 | 'a tattoo of a {}.', 11 | 'the embroidered {}.', 12 | 'a photo of a hard to see {}.', 13 | 'a bright photo of a {}.', 14 | 'a photo of a clean {}.', 15 | 'a photo of a dirty {}.', 16 | 'a dark photo of the {}.', 17 | 'a drawing of a {}.', 18 | 'a photo of my {}.', 19 | 'the plastic {}.', 20 | 'a photo of the cool {}.', 21 | 'a close-up photo of a {}.', 22 | 'a black and white photo of the {}.', 23 | 'a painting of the {}.', 24 | 'a painting of a {}.', 25 | 'a pixelated photo of the {}.', 26 | 'a sculpture of the {}.', 27 | 'a bright photo of the {}.', 28 | 'a cropped photo of a {}.', 29 | 'a plastic {}.', 30 | 'a photo of the dirty {}.', 31 | 'a jpeg corrupted photo of a {}.', 32 | 'a blurry photo of the {}.', 33 | 'a photo of the {}.', 34 | 'a good photo of the {}.', 35 | 'a rendering of the {}.', 36 | 'a {} in a video game.', 37 | 'a photo of one {}.', 38 | 'a doodle of a {}.', 39 | 'a close-up photo of the {}.', 40 | 'a photo of a {}.', 41 | 'the origami {}.', 42 | 'the {} in a video game.', 43 | 'a sketch of a {}.', 44 | 'a doodle of the {}.', 45 | 'a origami {}.', 46 | 'a low resolution photo of a {}.', 47 | 'the toy {}.', 48 | 'a rendition of the {}.', 49 | 'a photo of the clean {}.', 50 | 'a photo of a large {}.', 51 | 'a rendition of a {}.', 52 | 'a photo of a nice {}.', 53 | 'a photo of a weird {}.', 54 | 'a blurry photo of a {}.', 55 | 'a cartoon {}.', 56 | 'art of a {}.', 57 | 'a sketch of the {}.', 58 | 'a embroidered {}.', 59 | 'a pixelated photo of a {}.', 60 | 'itap of the {}.', 61 | 'a jpeg corrupted photo of the {}.', 62 | 'a good photo of a {}.', 63 | 'a plushie {}.', 64 | 'a photo of the nice {}.', 65 | 'a photo of the small {}.', 66 | 'a photo of the weird {}.', 67 | 'the cartoon {}.', 68 | 'art of the {}.', 69 | 'a drawing of the {}.', 70 | 'a photo of the large {}.', 71 | 'a black and white photo of a {}.', 72 | 'the plushie {}.', 73 | 'a dark photo of a {}.', 74 | 'itap of a {}.', 75 | 'graffiti of the {}.', 76 | 'a toy {}.', 77 | 'itap of my {}.', 78 | 'a photo of a cool {}.', 79 | 'a photo of a small {}.', 80 | 'a tattoo of the {}.', 81 | ] 82 | 83 | part_templates = [ 84 | 'the paw of a {}.', 85 | 'the nose of a {}.', 86 | 'the eye of the {}.', 87 | 'the ears of a {}.', 88 | 'an eye of a {}.', 89 | 'the tongue of a {}.', 90 | 'the fur of the {}.', 91 | 'colorful {} fur.', 92 | 'a snout of a {}.', 93 | 'the teeth of the {}.', 94 | 'the {}s fangs.', 95 | 'a claw of the 
{}.', 96 | 'the face of the {}', 97 | 'a neck of a {}', 98 | 'the head of the {}', 99 | ] 100 | 101 | imagenet_templates_small = [ 102 | 'a photo of a {}.', 103 | 'a rendering of a {}.', 104 | 'a cropped photo of the {}.', 105 | 'the photo of a {}.', 106 | 'a photo of a clean {}.', 107 | 'a photo of a dirty {}.', 108 | 'a dark photo of the {}.', 109 | 'a photo of my {}.', 110 | 'a photo of the cool {}.', 111 | 'a close-up photo of a {}.', 112 | 'a bright photo of the {}.', 113 | 'a cropped photo of a {}.', 114 | 'a photo of the {}.', 115 | 'a good photo of the {}.', 116 | 'a photo of one {}.', 117 | 'a close-up photo of the {}.', 118 | 'a rendition of the {}.', 119 | 'a photo of the clean {}.', 120 | 'a rendition of a {}.', 121 | 'a photo of a nice {}.', 122 | 'a good photo of a {}.', 123 | 'a photo of the nice {}.', 124 | 'a photo of the small {}.', 125 | 'a photo of the weird {}.', 126 | 'a photo of the large {}.', 127 | 'a photo of a cool {}.', 128 | 'a photo of a small {}.', 129 | ] -------------------------------------------------------------------------------- /ZSSGAN/utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | 5 | def make_noise(batch, latent_dim, n_noise, device): 6 | if n_noise == 1: 7 | return torch.randn(batch, latent_dim, device=device) 8 | 9 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 10 | 11 | return noises 12 | 13 | 14 | def mixing_noise(batch, latent_dim, prob, device): 15 | if prob > 0 and random.random() < prob: 16 | return make_noise(batch, latent_dim, 2, device) 17 | 18 | else: 19 | return [make_noise(batch, latent_dim, 1, device)] 20 | 21 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 22 | noise = torch.randn_like(fake_img) / math.sqrt( 23 | fake_img.shape[2] * fake_img.shape[3] 24 | ) 25 | grad, = torch.autograd.grad( 26 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 27 | ) 28 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 29 | 30 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 31 | 32 | path_penalty = (path_lengths - path_mean).pow(2).mean() 33 | 34 | return path_penalty, path_mean.detach(), path_lengths -------------------------------------------------------------------------------- /assets/pipeline-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/assets/pipeline-2.jpg -------------------------------------------------------------------------------- /eg3d/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
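# Re-export the two helpers used throughout the code base: EasyDict (a dict that also
# supports attribute-style access) and make_cache_dir_path() for resolving the download
# cache directory.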
10 | 11 | from .util import EasyDict, make_cache_dir_path 12 | -------------------------------------------------------------------------------- /eg3d/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | //------------------------------------------------------------------------ 14 | // CUDA kernel parameters. 15 | 16 | struct bias_act_kernel_params 17 | { 18 | const void* x; // [sizeX] 19 | const void* b; // [sizeB] or NULL 20 | const void* xref; // [sizeX] or NULL 21 | const void* yref; // [sizeX] or NULL 22 | const void* dy; // [sizeX] or NULL 23 | void* y; // [sizeX] 24 | 25 | int grad; 26 | int act; 27 | float alpha; 28 | float gain; 29 | float clamp; 30 | 31 | int sizeX; 32 | int sizeB; 33 | int stepB; 34 | int loopX; 35 | }; 36 | 37 | //------------------------------------------------------------------------ 38 | // CUDA kernel selection. 
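// choose_bias_act_kernel() returns a pointer to the CUDA kernel matching the activation /
// gradient mode described by bias_act_kernel_params; in the upstream sources it is
// templated on the scalar type and instantiated in bias_act.cu.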
39 | 40 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 41 | 42 | //------------------------------------------------------------------------ 43 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "filtered_lrelu.cu" 14 | 15 | // Template/kernel specializations for no signs mode (no gradients required). 16 | 17 | // Full op, 32-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Full op, 64-bit indexing. 22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 24 | 25 | // Activation/signs only for generic variant. 64-bit indexing. 26 | template void* choose_filtered_lrelu_act_kernel(void); 27 | template void* choose_filtered_lrelu_act_kernel(void); 28 | template void* choose_filtered_lrelu_act_kernel(void); 29 | 30 | // Copy filters to constant memory. 31 | template cudaError_t copy_filters(cudaStream_t stream); 32 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "filtered_lrelu.cu" 14 | 15 | // Template/kernel specializations for sign read mode. 16 | 17 | // Full op, 32-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Full op, 64-bit indexing. 
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 24 | 25 | // Activation/signs only for generic variant. 64-bit indexing. 26 | template void* choose_filtered_lrelu_act_kernel(void); 27 | template void* choose_filtered_lrelu_act_kernel(void); 28 | template void* choose_filtered_lrelu_act_kernel(void); 29 | 30 | // Copy filters to constant memory. 31 | template cudaError_t copy_filters(cudaStream_t stream); 32 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "filtered_lrelu.cu" 14 | 15 | // Template/kernel specializations for sign write mode. 16 | 17 | // Full op, 32-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Full op, 64-bit indexing. 22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 24 | 25 | // Activation/signs only for generic variant. 64-bit indexing. 26 | template void* choose_filtered_lrelu_act_kernel(void); 27 | template void* choose_filtered_lrelu_act_kernel(void); 28 | template void* choose_filtered_lrelu_act_kernel(void); 29 | 30 | // Copy filters to constant memory. 31 | template cudaError_t copy_filters(cudaStream_t stream); 32 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
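# Usage sketch (hypothetical tensors; shapes only need to be broadcast-compatible):
#   y = fma(a, b, c)            # y = a * b + c, same result as torch.addcmul(c, a, b)
# The custom autograd Function below makes the backward pass slightly cheaper than
# differentiating through torch.addcmul() directly.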
10 | 11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 12 | 13 | import torch 14 | 15 | #---------------------------------------------------------------------------- 16 | 17 | def fma(a, b, c): # => a * b + c 18 | return _FusedMultiplyAdd.apply(a, b, c) 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 23 | @staticmethod 24 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 25 | out = torch.addcmul(c, a, b) 26 | ctx.save_for_backward(a, b) 27 | ctx.c_shape = c.shape 28 | return out 29 | 30 | @staticmethod 31 | def backward(ctx, dout): # pylint: disable=arguments-differ 32 | a, b = ctx.saved_tensors 33 | c_shape = ctx.c_shape 34 | da = None 35 | db = None 36 | dc = None 37 | 38 | if ctx.needs_input_grad[0]: 39 | da = _unbroadcast(dout * b, a.shape) 40 | 41 | if ctx.needs_input_grad[1]: 42 | db = _unbroadcast(dout * a, b.shape) 43 | 44 | if ctx.needs_input_grad[2]: 45 | dc = _unbroadcast(dout, c_shape) 46 | 47 | return da, db, dc 48 | 49 | #---------------------------------------------------------------------------- 50 | 51 | def _unbroadcast(x, shape): 52 | extra_dims = x.ndim - len(shape) 53 | assert extra_dims >= 0 54 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 55 | if len(dim): 56 | x = x.sum(dim=dim, keepdim=True) 57 | if extra_dims: 58 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 59 | assert x.shape == shape 60 | return x 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Custom replacement for `torch.nn.functional.grid_sample` that 12 | supports arbitrarily high order gradients between the input and output. 13 | Only works on 2D images and assumes 14 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 15 | 16 | import torch 17 | 18 | # pylint: disable=redefined-builtin 19 | # pylint: disable=arguments-differ 20 | # pylint: disable=protected-access 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | enabled = False # Enable the custom op by setting this to true. 
25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 62 | ctx.save_for_backward(grid) 63 | return grad_input, grad_grid 64 | 65 | @staticmethod 66 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 67 | _ = grad2_grad_grid # unused 68 | grid, = ctx.saved_tensors 69 | grad2_grad_output = None 70 | grad2_input = None 71 | grad2_grid = None 72 | 73 | if ctx.needs_input_grad[0]: 74 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 75 | 76 | assert not ctx.needs_input_grad[2] 77 | return grad2_grad_output, grad2_input, grad2_grid 78 | 79 | #---------------------------------------------------------------------------- 80 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include 14 | 15 | //------------------------------------------------------------------------ 16 | // CUDA kernel parameters. 
17 | 18 | struct upfirdn2d_kernel_params 19 | { 20 | const void* x; 21 | const float* f; 22 | void* y; 23 | 24 | int2 up; 25 | int2 down; 26 | int2 pad0; 27 | int flip; 28 | float gain; 29 | 30 | int4 inSize; // [width, height, channel, batch] 31 | int4 inStride; 32 | int2 filterSize; // [width, height] 33 | int2 filterStride; 34 | int4 outSize; // [width, height, channel, batch] 35 | int4 outStride; 36 | int sizeMinor; 37 | int sizeMajor; 38 | 39 | int loopMinor; 40 | int loopMajor; 41 | int loopX; 42 | int launchMinor; 43 | int launchMajor; 44 | }; 45 | 46 | //------------------------------------------------------------------------ 47 | // CUDA kernel specialization. 48 | 49 | struct upfirdn2d_kernel_spec 50 | { 51 | void* kernel; 52 | int tileOutW; 53 | int tileOutH; 54 | int loopMinor; 55 | int loopX; 56 | }; 57 | 58 | //------------------------------------------------------------------------ 59 | // CUDA kernel selection. 60 | 61 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 62 | 63 | //------------------------------------------------------------------------ 64 | -------------------------------------------------------------------------------- /eg3d/training/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/training/crosssection_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
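# Visualisation helper: sample_cross_section() below evaluates the generator's density
# field (sigma) on a regular resolution x resolution grid spanning an axis-aligned plane
# of extent `w` through the origin, and returns it as an (N, 1, resolution, resolution) map.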
10 | 11 | import torch 12 | 13 | def sample_cross_section(G, ws, resolution=256, w=1.2): 14 | axis=0 15 | A, B = torch.meshgrid(torch.linspace(w/2, -w/2, resolution, device=ws.device), torch.linspace(-w/2, w/2, resolution, device=ws.device), indexing='ij') 16 | A, B = A.reshape(-1, 1), B.reshape(-1, 1) 17 | C = torch.zeros_like(A) 18 | coordinates = [A, B] 19 | coordinates.insert(axis, C) 20 | coordinates = torch.cat(coordinates, dim=-1).expand(ws.shape[0], -1, -1) 21 | 22 | sigma = G.sample_mixed(coordinates, torch.randn_like(coordinates), ws)['sigma'] 23 | return sigma.reshape(-1, 1, resolution, resolution) 24 | 25 | # if __name__ == '__main__': 26 | # sample_crossection(None) -------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty -------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/ray_marcher.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. 13 | Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 14 | """ 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | class MipRayMarcher2(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | 25 | def run_forward(self, colors, densities, depths, rendering_options): 26 | deltas = depths[:, :, 1:] - depths[:, :, :-1] 27 | colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 28 | densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 29 | depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 30 | 31 | 32 | if rendering_options['clamp_mode'] == 'softplus': 33 | densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better 34 | else: 35 | assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!" 
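        # Standard volume-rendering quadrature (as in NeRF / MipNeRF), computed below:
        #   alpha_i  = 1 - exp(-sigma_i * delta_i)
        #   T_i      = prod_{j < i} (1 - alpha_j)      # accumulated transmittance
        #   weight_i = alpha_i * T_i
        # Composited colour and depth are then weighted sums of the per-sample values.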
36 | 37 | density_delta = densities_mid * deltas 38 | 39 | alpha = 1 - torch.exp(-density_delta) 40 | 41 | alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) 42 | weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] 43 | 44 | composite_rgb = torch.sum(weights * colors_mid, -2) 45 | weight_total = weights.sum(2) 46 | composite_depth = torch.sum(weights * depths_mid, -2) / weight_total 47 | 48 | # clip the composite to min/max range of depths 49 | composite_depth = torch.nan_to_num(composite_depth, float('inf')) 50 | composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) 51 | 52 | if rendering_options.get('white_back', False): 53 | composite_rgb = composite_rgb + 1 - weight_total 54 | 55 | composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) 56 | 57 | return composite_rgb, composite_depth, weights 58 | 59 | 60 | def forward(self, colors, densities, depths, rendering_options): 61 | composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) 62 | 63 | return composite_rgb, composite_depth, weights -------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/ray_sampler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | The ray sampler is a module that takes in camera matrices and resolution and batches of rays. 13 | Expects cam2world matrices that use the OpenCV camera coordinate system conventions. 14 | """ 15 | 16 | import torch 17 | 18 | class RaySampler(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None 22 | 23 | 24 | def forward(self, cam2world_matrix, intrinsics, resolution): 25 | """ 26 | Create batches of rays and return origins and directions. 
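        One ray is generated per pixel of the rendered image, so M = resolution**2 below.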
27 | 28 | cam2world_matrix: (N, 4, 4) 29 | intrinsics: (N, 3, 3) 30 | resolution: int 31 | 32 | ray_origins: (N, M, 3) 33 | ray_dirs: (N, M, 2) 34 | """ 35 | N, M = cam2world_matrix.shape[0], resolution**2 36 | cam_locs_world = cam2world_matrix[:, :3, 3] 37 | fx = intrinsics[:, 0, 0] 38 | fy = intrinsics[:, 1, 1] 39 | cx = intrinsics[:, 0, 2] 40 | cy = intrinsics[:, 1, 2] 41 | sk = intrinsics[:, 0, 1] 42 | 43 | uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), indexing='ij')) * (1./resolution) + (0.5/resolution) 44 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0) 45 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) 46 | 47 | x_cam = uv[:, :, 0].view(N, -1) 48 | y_cam = uv[:, :, 1].view(N, -1) 49 | z_cam = torch.ones((N, M), device=cam2world_matrix.device) 50 | 51 | x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam 52 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam 53 | 54 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) 55 | 56 | world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] 57 | 58 | ray_dirs = world_rel_points - cam_locs_world[:, None, :] 59 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) 60 | 61 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) 62 | 63 | return ray_origins, ray_dirs -------------------------------------------------------------------------------- /preprocess/extract_landmark.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import tqdm 4 | import glob 5 | import face_alignment 6 | import numpy as np 7 | import torch 8 | 9 | 10 | detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D) 11 | # detector = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D) # update to TWO_D if using newest face_alignment library 12 | 13 | 14 | def get_landmark(image): 15 | """ 16 | :param images: PIL 17 | :return: numpy (1, 68) 18 | """ 19 | lm = detector.get_landmarks_from_image(np.array(image)) 20 | assert lm is not None, 'No face detect error!' 
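    # get_landmarks_from_image returns a list with one (68, 2) array per detected face
    # (or None when no face is found); only the first face's landmarks are kept.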
21 | # lm_np = np.expand_dims(lm[0], axis=0) 22 | return lm[0] 23 | 24 | 25 | def extract_landmark(input_dir, output_dir, mode='png'): 26 | os.makedirs(output_dir, exist_ok=True) 27 | # root = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/cropped/' 28 | # out = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/mask/' 29 | 30 | image_list = sorted(glob.glob(f'{input_dir}/*.{mode}')) 31 | # for image_path in tqdm.tqdm(image_list): 32 | for image_path in image_list: 33 | image_name = os.path.basename(image_path).split('.')[0] 34 | image = Image.open(image_path).convert("RGB").resize((256, 256)) 35 | lm = get_landmark(image) 36 | np.save(os.path.join(output_dir, image_name + '.npy'), lm) 37 | 38 | 39 | # extract_landmark(input_dir='/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/hua_2/cropped', 40 | # output_dir='/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/hua_2/lm/') 41 | -------------------------------------------------------------------------------- /preprocess/extract_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath('.')) 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import torch 8 | import glob 9 | import tqdm 10 | import torch.nn.functional as F 11 | from spi.utils.load_utils import load_bisenet 12 | 13 | 14 | def cal_face_mask(bisenet, image): 15 | # print(image.shape) 16 | target_images = image.clone() 17 | out = bisenet(target_images)[0] # 1, 19, 512, 512 18 | # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 19 | # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] 20 | att_list = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13] 21 | parsing = torch.argmax(out, dim=1, keepdim=True) 22 | mask = torch.zeros(image.shape[0], 1, 512, 512).cuda() 23 | for att in att_list: 24 | mask += (parsing == att) 25 | # log_image(mask, 'predict_mask_512') 26 | mask = F.interpolate(mask, size=(256, 256), mode='nearest') 27 | # print('mask shape', mask.shape) 28 | # log_image(mask, 'predict_mask') 29 | return mask 30 | 31 | 32 | bisenet = load_bisenet() 33 | 34 | 35 | def cal_mask(bisenet, image): 36 | # print(image.shape) 37 | target_images = image.clone() 38 | out = bisenet(target_images)[0] # 1, 19, 512, 512 39 | # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 40 | # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] 41 | # att_list = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13] 42 | parsing = torch.argmax(out, dim=1, keepdim=True) 43 | # mask = torch.zeros(image.shape[0], 1, 512, 512).cuda() 44 | # for att in att_list: 45 | # mask += (parsing == att) 46 | # log_image(mask, 'predict_mask_512') 47 | # mask = F.interpolate(mask, size=(256, 256), mode='nearest') 48 | # print('mask shape', mask.shape) 49 | # log_image(mask, 'predict_mask') 50 | return parsing 51 | 52 | def extract_mask(input_dir, output_dir, mode='png'): 53 | image_list = sorted(glob.glob(f'{input_dir}/*.{mode}')) 54 | # for image_path in tqdm.tqdm(image_list): 55 | for image_path in image_list: 56 | image_name = os.path.basename(image_path).split('.')[0] 57 | image = Image.open(image_path).resize((512, 512)) 58 | image = torch.from_numpy(np.asarray(image)).unsqueeze(0).permute(0, 3, 1, 2) 59 | # 
print(image.shape) 60 | image = (image.to('cuda').to(torch.float32) / 127.5 - 1) 61 | mask = cal_mask(bisenet, image) 62 | torch.save(mask.cpu(), os.path.join(output_dir, image_name + '.pt')) 63 | 64 | 65 | def main(): 66 | root = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/cropped/' 67 | out = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/mask/' 68 | 69 | 70 | # root = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/personal/origin/barack_obama/' 71 | 72 | image_list = sorted(glob.glob(f'{root}/*.png')) 73 | for image_path in tqdm.tqdm(image_list): 74 | image_name = os.path.basename(image_path).split('.')[0] 75 | image = Image.open(image_path).resize((512, 512)) 76 | image = torch.from_numpy(np.asarray(image)).unsqueeze(0).permute(0, 3, 1, 2) 77 | # print(image.shape) 78 | image = (image.to('cuda').to(torch.float32) / 127.5 - 1) 79 | mask = cal_face_mask(bisenet, image) 80 | torch.save(mask.cpu(), os.path.join(out, image_name + '.pt')) -------------------------------------------------------------------------------- /preprocess/mirror_padding.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import PIL.Image as Image 5 | import tqdm 6 | import scipy.ndimage 7 | 8 | # input_dir = '/home/jachoon/Desktop/Code/pose-free_3dgan_inversion/*.jpg' 9 | # out_dir = './padded' 10 | input_dir = '/apdcephfs_cq2/share_1290939/feiiyin/dataset/CelebAHQ_test/*.jpg' 11 | out_dir = '/apdcephfs_cq2/share_1290939/feiiyin/dataset/CelebAHQ_test_padding/' 12 | if os.path.isdir(out_dir) == 0: 13 | os.mkdir(out_dir) 14 | 15 | image_list = glob.glob(input_dir) 16 | stripstr = input_dir.split('/')[-1].lstrip('*') 17 | for i in tqdm.tqdm(image_list): 18 | image_name = i.split('/')[-1].rstrip(stripstr) 19 | # if image_name not in ['10252', '27065']: 20 | # continue 21 | image = np.array(Image.open(i).convert('RGB')) 22 | pad_num = 250 23 | image_pad= np.pad(image, ((pad_num, pad_num),(pad_num, pad_num),(0, 0)), 'reflect') 24 | h, w, _ = image_pad.shape 25 | y, x, _ = np.mgrid[:h, :w, :1] 26 | # mask = np.ones_like(image_pad)[:, :, :1] 27 | 28 | # mask[pad_num:pad_num+1024, pad_num:pad_num+1024, :] = np.zeros_like(mask[pad_num:pad_num+1024, pad_num:pad_num+1024, :]) 29 | 30 | mask = 1.0 - np.minimum(np.minimum(np.float32(x) / pad_num, np.float32(y) / pad_num), np.minimum(np.float32(w-1-x) / pad_num, np.float32(h-1-y) / pad_num)) 31 | 32 | #for j in range(1, 20): 33 | # import pdb; pdb.set_trace() 34 | image_pad = image_pad.astype(np.float32) 35 | image_pad += (scipy.ndimage.gaussian_filter(image_pad, [5,5,0]) - image_pad) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 36 | 37 | #image_pad += ((np.median(image_pad, axis=(0,1)) - image_pad) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)).astype(np.uint8) 38 | Image.fromarray(image_pad.astype(np.uint8)).save(out_dir + image_name + '.png') -------------------------------------------------------------------------------- /preprocess/process_camera.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from PIL import Image, ImageOps 4 | import os 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | 9 | def fix_intrinsics(intrinsics): 10 | intrinsics = np.array(intrinsics).copy() 11 | assert intrinsics.shape == (3, 3), intrinsics 12 | intrinsics[0,0] = 2985.29/700 13 | intrinsics[1,1] = 2985.29/700 14 | intrinsics[0,2] = 1/2 15 | intrinsics[1,2] = 1/2 16 | assert 
intrinsics[0,1] == 0 17 | assert intrinsics[2,2] == 1 18 | assert intrinsics[1,0] == 0 19 | assert intrinsics[2,0] == 0 20 | assert intrinsics[2,1] == 0 21 | return intrinsics 22 | 23 | def fix_pose(pose): 24 | COR = np.array([0, 0, 0.175]) 25 | pose = np.array(pose).copy() 26 | location = pose[:3, 3] 27 | direction = (location - COR) / np.linalg.norm(location - COR) 28 | pose[:3, 3] = direction * 2.7 + COR 29 | return pose 30 | 31 | def fix_pose_orig(pose): 32 | pose = np.array(pose).copy() 33 | location = pose[:3, 3] 34 | radius = np.linalg.norm(location) 35 | pose[:3, 3] = pose[:3, 3]/radius * 2.7 36 | return pose 37 | 38 | def flip_yaw(pose_matrix): 39 | flipped = pose_matrix.copy() 40 | flipped[0, 1] *= -1 41 | flipped[0, 2] *= -1 42 | flipped[1, 0] *= -1 43 | flipped[2, 0] *= -1 44 | flipped[0, 3] *= -1 45 | return flipped 46 | 47 | 48 | def process_camera(pose, intrinsics): 49 | 50 | # if args.mode == 'cor': 51 | # pose = fix_pose(pose) 52 | # elif args.mode == 'orig': 53 | pose = fix_pose_orig(pose) 54 | # else: 55 | # assert False, "invalid mode" 56 | intrinsics = fix_intrinsics(intrinsics) 57 | label = np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]) 58 | return label 59 | 60 | 61 | # if __name__ == '__main__': 62 | # parser = argparse.ArgumentParser() 63 | # parser.add_argument("--source", type=str) 64 | # parser.add_argument("--dest", type=str, default=None) 65 | # parser.add_argument("--max_images", type=int, default=None) 66 | # parser.add_argument("--mode", type=str, default="orig", choices=["orig", "cor"]) 67 | # args = parser.parse_args() 68 | 69 | # camera_dataset_file = os.path.join(args.source, 'cameras.json') 70 | 71 | # with open(camera_dataset_file, "r") as f: 72 | # cameras = json.load(f) 73 | 74 | # dataset = {'labels':[]} 75 | 76 | # max_images = args.max_images if args.max_images is not None else len(cameras) 77 | # for i, filename in tqdm(enumerate(cameras), total=max_images): 78 | # if (max_images is not None and i >= max_images): break 79 | 80 | # pose = cameras[filename]['pose'] 81 | # intrinsics = cameras[filename]['intrinsics'] 82 | 83 | # if args.mode == 'cor': 84 | # pose = fix_pose(pose) 85 | # elif args.mode == 'orig': 86 | # pose = fix_pose_orig(pose) 87 | # else: 88 | # assert False, "invalid mode" 89 | # intrinsics = fix_intrinsics(intrinsics) 90 | # label = np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 91 | 92 | # image_path = os.path.join(args.source, filename) 93 | # img = Image.open(image_path) 94 | 95 | # dataset["labels"].append([filename, label]) 96 | # os.makedirs(os.path.dirname(os.path.join(args.dest, filename)), exist_ok=True) 97 | # img.save(os.path.join(args.dest, filename)) 98 | 99 | 100 | # flipped_img = ImageOps.mirror(img) 101 | # flipped_pose = flip_yaw(pose) 102 | # label = np.concatenate([flipped_pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 103 | # base, ext = filename.split('.')[0], '.' 
+ filename.split('.')[1] 104 | # flipped_filename = base + '_mirror' + ext 105 | # dataset["labels"].append([flipped_filename, label]) 106 | # flipped_img.save(os.path.join(args.dest, flipped_filename)) -------------------------------------------------------------------------------- /preprocess/run_crop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath('.')) 4 | import glob 5 | import tqdm 6 | 7 | 8 | from preprocess.extract_camera import CameraExtractor 9 | 10 | root = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/hua_2/frames/' 11 | out = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/hua_2/cropped/' 12 | os.makedirs(out, exist_ok=True) 13 | extractor = CameraExtractor(outdir=out) 14 | # root = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/personal/origin/barack_obama/' 15 | 16 | 17 | 18 | image_list = sorted(glob.glob(f'{root}/*.jpg')) 19 | for image_path in tqdm.tqdm(image_list): 20 | print(image_path) 21 | extractor.extract(image_path) -------------------------------------------------------------------------------- /preprocess/run_total.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath('.')) 4 | sys.path.append(os.path.abspath('./eg3d/')) 5 | import argparse 6 | import glob 7 | import tqdm 8 | 9 | from preprocess.extract_camera import CameraExtractor 10 | from preprocess.extract_landmark import extract_landmark 11 | from preprocess.extract_mask import extract_mask 12 | from preprocess.video2frames import video2frames, video2frames_mirror 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Training') 17 | parser.add_argument('--input_root', type=str, default='./test/images/') 18 | parser.add_argument('--output_root', type=str, default='./test/dataset/') 19 | parser.add_argument('--mode', type=str, default='jpg') 20 | args = parser.parse_args() 21 | return args 22 | 23 | def main(): 24 | args = parse_args() 25 | # Real dataset 26 | input_dir = args.input_root 27 | root = args.output_root 28 | 29 | image_input = os.path.join(root, f'input') 30 | c_out = os.path.join(root, f'c') 31 | crop_out = os.path.join(root, f'crop') 32 | lm_out = os.path.join(root, f'lm') 33 | mask_out = os.path.join(root, f'mask') 34 | 35 | for out in [image_input, c_out, crop_out, lm_out, mask_out]: 36 | os.makedirs(out, exist_ok=True) 37 | 38 | extractor = CameraExtractor(crop_outdir=None, c_outdir=None, mode=None) 39 | 40 | mode = args.mode 41 | image_list = sorted(glob.glob(f'{input_dir}/*.{mode}')) 42 | 43 | # filter_name_list = ['taylor_swift_4'] 44 | # image_list = [f'{input_dir}/{name_i}.{mode}' for name_i in filter_name_list] 45 | 46 | # Transform the png into jpg 47 | # for image_path in tqdm.tqdm(image_list): 48 | # Image.open(image_path).convert('RGB').save(image_path.replace(mode, 'png')) 49 | 50 | for image_path in tqdm.tqdm(image_list): 51 | try: 52 | name = os.path.basename(image_path).split('.')[0] 53 | 54 | frame_root = os.path.join(image_input, name) 55 | os.makedirs(frame_root, exist_ok=True) 56 | 57 | frame_crop_path = os.path.join(crop_out, name, f'target.{mode}') 58 | if not os.path.exists(frame_crop_path): 59 | os.system(f'cp {image_path} {frame_root}/target.{mode}') 60 | 61 | # video2frames 62 | # video_path = os.path.join(video_dir, f'{name}.mp4') 63 | # if os.path.exists(video_path): 64 | # 
os.system(f'cp {video_path} {video_input}/{name}.mp4') 65 | # video2frames(video_path=video_path, output_dir=frame_root, mode=mode) 66 | 67 | # Crop & C 68 | _crop_outdir = os.path.join(crop_out, name) 69 | _c_outdir = os.path.join(c_out, name) 70 | os.makedirs(_crop_outdir, exist_ok=True) 71 | os.makedirs(_c_outdir, exist_ok=True) 72 | extractor.set_path(crop_outdir=_crop_outdir, c_outdir=_c_outdir, mode=mode) 73 | 74 | frame_list = sorted(glob.glob(f'{frame_root}/*.{mode}')) 75 | for f in frame_list: 76 | extractor.extract(f) 77 | 78 | # Lm 79 | _lm_outdir = os.path.join(lm_out, name) 80 | os.makedirs(_lm_outdir, exist_ok=True) 81 | extract_landmark(input_dir=_crop_outdir, output_dir=_lm_outdir, mode=mode) 82 | 83 | # Mask 84 | _mask_outdir = os.path.join(mask_out, name) 85 | os.makedirs(_mask_outdir, exist_ok=True) 86 | extract_mask(input_dir=_crop_outdir, output_dir=_mask_outdir, mode=mode) 87 | except Exception: 88 | print(image_path, name) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() -------------------------------------------------------------------------------- /preprocess/transform_into_goae_data_format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | import argparse 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Training') 9 | parser.add_argument('--input_root', type=str, default=None) 10 | parser.add_argument('--output_root', type=str, default=None) 11 | args = parser.parse_args() 12 | return args 13 | 14 | args = parse_args() 15 | spi_output_root = args.input_root # "/weka/home-feiyin/code/GOAE/example/spi_output" 16 | output_root = args.output_root # "example/input" 17 | os.makedirs(output_root, exist_ok=True) 18 | 19 | image_path_list = sorted(glob.glob(os.path.join(spi_output_root, "crop", "*"))) 20 | output_json = {"labels": []} 21 | 22 | for image_root in image_path_list: 23 | image_name = os.path.basename(image_root) 24 | mode = "jpg" 25 | image_path = os.path.join(image_root, f"target.{mode}") 26 | if not os.path.exists(image_path): 27 | mode = "png" 28 | image_path = os.path.join(image_root, f"target.{mode}") 29 | 30 | target_path = os.path.join(output_root, f"{image_name}.{mode}") 31 | cmd = f"cp {image_path} {target_path}" 32 | os.system(cmd) 33 | 34 | camera_path = os.path.join(spi_output_root, "c", image_name, "target.npy") 35 | camera = np.load(camera_path) 36 | 37 | output_json["labels"].append([f"{image_name}.{mode}", camera.tolist()]) 38 | 39 | output_json_path = os.path.join(output_root, "label.json") 40 | with open(output_json_path, 'w', encoding='utf-8') as f: 41 | json.dump(output_json, f, ensure_ascii=False, indent=4) 42 | 43 | -------------------------------------------------------------------------------- /preprocess/video2frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import cv2 4 | 5 | 6 | def read_video(filename, uplimit=1000): 7 | frames = [] 8 | cap = cv2.VideoCapture(filename) 9 | cnt = 0 10 | while cap.isOpened(): 11 | ret, frame = cap.read() 12 | if ret: 13 | frame = cv2.resize(frame, (512, 512)) 14 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 15 | frames.append(frame) 16 | else: 17 | break 18 | cnt += 1 19 | if cnt >= uplimit: 20 | break 21 | cap.release() 22 | assert len(frames) > 0, f'{filename}: video with no frames!' 
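    # Frames come back as 512x512 RGB uint8 arrays (OpenCV decodes BGR, converted above),
    # capped at `uplimit` frames per video.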
23 | return frames 24 | 25 | 26 | 27 | def video2frames(video_path, output_dir, mode, prefix=None): 28 | os.makedirs(output_dir, exist_ok=True) 29 | 30 | start_i = 0 31 | frames = read_video(video_path) 32 | for _i in range(len(frames)): 33 | path = os.path.join(output_dir, f'{start_i:03d}.{mode}') 34 | if prefix is not None: 35 | path = os.path.join(output_dir, f'{prefix}#{start_i:03d}.{mode}') 36 | Image.fromarray(frames[_i]).save(path) 37 | start_i += 1 38 | 39 | 40 | def video2frames_mirror(video_path, output_dir, mode): 41 | os.makedirs(output_dir, exist_ok=True) 42 | 43 | start_i = 0 44 | frames = read_video(video_path) 45 | for _i in range(len(frames)): 46 | Image.fromarray(frames[_i]).save(os.path.join(output_dir, f'{start_i:03d}.{mode}')) 47 | start_i += 1 48 | 49 | frames = read_video(video_path) 50 | for _i in range(len(frames)): 51 | Image.fromarray(frames[_i]).transpose(Image.FLIP_LEFT_RIGHT).save(os.path.join(output_dir, f'{start_i:03d}.{mode}')) 52 | start_i += 1 53 | 54 | 55 | 56 | 57 | def main(): 58 | input_path = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/hua_2' 59 | output_dir = '/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D_result/mpti_output/pti_input/hua_2/frames' 60 | 61 | os.makedirs(output_dir, exist_ok=True) 62 | 63 | # videos = [os.path.join(input_path, 'hua_gen.mp4'), os.path.join(input_path, 'hua_flip_shape.mp4')] 64 | videos = [os.path.join(input_path, 'hua_flip_shape.mp4')] 65 | 66 | start_i = 0 67 | for video_path in videos: 68 | print(video_path) 69 | frames = read_video(video_path) 70 | for _i in range(len(frames)): 71 | Image.fromarray(frames[_i]).save(os.path.join(output_dir, f'{start_i:03d}.jpg')) 72 | start_i += 1 73 | -------------------------------------------------------------------------------- /spi/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/spi/configs/__init__.py -------------------------------------------------------------------------------- /spi/configs/global_config.py: -------------------------------------------------------------------------------- 1 | ## Device 2 | cuda_visible_devices = '0' 3 | device = 'cuda:0' 4 | 5 | ## Logs 6 | training_step = 1 7 | log_snapshot = 500 8 | pivotal_training_steps = 0 9 | model_snapshot_interval = 400 10 | 11 | ## Run name to be updated during PTI 12 | run_name = '' 13 | -------------------------------------------------------------------------------- /spi/configs/hyperparameters.py: -------------------------------------------------------------------------------- 1 | ## Architechture 2 | lpips_type = 'vgg' 3 | max_images_to_invert = 3000 4 | 5 | # w stage 6 | use_encoder = False 7 | use_G_avg = False 8 | first_inv_type = 'sg' # 'mir', 'sgw+' 9 | optim_type = 'adam' 10 | first_inv_steps = 500 11 | 12 | # G stage 1 13 | LPIPS_value_threshold = 0.05 14 | G_1_step = 0 15 | G_1_type = None 16 | G_2_step = 0 17 | use_adapt_yaw_range = False 18 | description = None 19 | 20 | 21 | ## Locality regularization 22 | latent_ball_num_of_samples = 1 23 | locality_regularization_interval = 1 24 | use_locality_regularization = False 25 | regulizer_l2_lambda = 0.1 26 | regulizer_lpips_lambda = 0.1 27 | regulizer_alpha = 30 28 | reg_w_loss_weight = 1 29 | 30 | ## Loss 31 | pt_l2_lambda = 1 32 | pt_lpips_lambda = 1 33 | pt_tv_lambda = 0 # 0.25 34 | pt_rot_lambda = 0.1 35 | pt_mirror_rot_lambda = 0.05 36 | pt_depth_lambda = 1 37 | 38 | 39 | ## 
Optimization 40 | pti_learning_rate = 3e-4 41 | first_inv_lr = 5e-3 42 | train_batch_size = 1 43 | use_last_w_pivots = False 44 | load_embedding_coach_name = None 45 | w_space_index = 14 46 | -------------------------------------------------------------------------------- /spi/configs/paths_config.py: -------------------------------------------------------------------------------- 1 | EG3D_PATH = 'checkpoints/ffhqrebalanced512-128.pkl' 2 | IDLOSS_PATH = 'checkpoints/model_ir_se50.pth' 3 | LPIPS_PATH = '' 4 | BISENET_PATH = 'checkpoints/bisenet.pth' 5 | VGG_PATH = 'checkpoints/vgg16.pt' 6 | 7 | 8 | # Dirs for output files 9 | root = 'test/output/' 10 | checkpoints_dir = root + 'checkpoints/' 11 | embedding_base_dir = root + 'embedding/' 12 | experiments_output_dir = root + 'experiments/' 13 | # Used in the final test 14 | images_output_dir = root + 'image/' 15 | mirror_images_output_dir = root + 'image_m/' 16 | video_output_dir = root + 'video/' 17 | -------------------------------------------------------------------------------- /spi/criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/spi/criteria/__init__.py -------------------------------------------------------------------------------- /spi/criteria/id_loss/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import math 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 5 | 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. """ 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut 122 | -------------------------------------------------------------------------------- /spi/criteria/id_loss/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self, path_ir_se50, num_scales=1): 9 | super(IDLoss, self).__init__() 10 | # print(f'Loading ResNet ArcFace from: {path_ir_se50}') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(path_ir_se50)) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | self.num_scales = num_scales 16 | 17 | def extract_feats(self, x): 18 | x = x[:, :, 35:223, 32:220] # Crop interesting region 19 | x = self.face_pool(x) 20 | x_feats = self.facenet(x) 21 | return x_feats 22 | 23 | def calculate_similarity(self, x, y): 24 | assert x.shape[0] == 1 25 | x_feats = self.extract_feats(x) 26 | y_feats = self.extract_feats(y) 27 | simmilar = x_feats[0].dot(y_feats[0]) 28 | return simmilar 29 | 30 | def calculate_batch_similarity(self, x, y): 31 | x_feats = self.extract_feats(x) 32 | y_feats = self.extract_feats(y) # torch.Size([100, 512]) torch.Size([100, 512]) 33 | simmilar = x_feats * y_feats 34 | simmilar 
= torch.sum(simmilar, dim=-1) 35 | simmilar = torch.mean(simmilar) 36 | return simmilar 37 | 38 | def forward(self, x, y): 39 | n_samples = x.shape[0] 40 | loss = 0.0 41 | for _scale in range(self.num_scales): 42 | x_feats = self.extract_feats(x) 43 | y_feats = self.extract_feats(y) 44 | for i in range(n_samples): 45 | diff_target = y_feats[i].dot(x_feats[i]) 46 | loss += 1 - diff_target 47 | 48 | if _scale != self.num_scales - 1: 49 | x = F.interpolate(x, mode='bilinear', scale_factor=0.5, align_corners=False, recompute_scale_factor=True) 50 | y = F.interpolate(y, mode='bilinear', scale_factor=0.5, align_corners=False, recompute_scale_factor=True) 51 | return loss / n_samples 52 | 53 | def psp_forward(self, y_hat, y, x): 54 | n_samples = x.shape[0] 55 | x_feats = self.extract_feats(x) 56 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 57 | y_hat_feats = self.extract_feats(y_hat) 58 | y_feats = y_feats.detach() 59 | loss = 0 60 | sim_improvement = 0 61 | id_logs = [] 62 | count = 0 63 | for i in range(n_samples): 64 | diff_target = y_hat_feats[i].dot(y_feats[i]) 65 | diff_input = y_hat_feats[i].dot(x_feats[i]) 66 | diff_views = y_feats[i].dot(x_feats[i]) 67 | id_logs.append({'diff_target': float(diff_target), 68 | 'diff_input': float(diff_input), 69 | 'diff_views': float(diff_views)}) 70 | loss += 1 - diff_target 71 | id_diff = float(diff_target) - float(diff_views) 72 | sim_improvement += id_diff 73 | count += 1 74 | 75 | return loss / count, sim_improvement / count, id_logs 76 | -------------------------------------------------------------------------------- /spi/criteria/id_loss/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | from torch_utils import persistence 4 | 5 | """ 6 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Backbone(Module): 11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 12 | super(Backbone, self).__init__() 13 | assert input_size in [112, 224], "input_size should be 112 or 224" 14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 16 | blocks = get_blocks(num_layers) 17 | if mode == 'ir': 18 | unit_module = bottleneck_IR 19 | elif mode == 'ir_se': 20 | unit_module = bottleneck_IR_SE 21 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 22 | BatchNorm2d(64), 23 | PReLU(64)) 24 | if input_size == 112: 25 | self.output_layer = Sequential(BatchNorm2d(512), 26 | Dropout(drop_ratio), 27 | Flatten(), 28 | Linear(512 * 7 * 7, 512), 29 | BatchNorm1d(512, affine=affine)) 30 | else: 31 | self.output_layer = Sequential(BatchNorm2d(512), 32 | Dropout(drop_ratio), 33 | Flatten(), 34 | Linear(512 * 14 * 14, 512), 35 | BatchNorm1d(512, affine=affine)) 36 | 37 | modules = [] 38 | for block in blocks: 39 | for bottleneck in block: 40 | modules.append(unit_module(bottleneck.in_channel, 41 | bottleneck.depth, 42 | bottleneck.stride)) 43 | self.body = Sequential(*modules) 44 | 45 | def forward(self, x): 46 | x = self.input_layer(x) 47 | x = self.body(x) 48 | x = self.output_layer(x) 49 | return l2_norm(x) 50 | 51 | 52 | def IR_50(input_size): 53 | """Constructs a ir-50 model.""" 54 | model = 
Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 55 | return model 56 | 57 | 58 | def IR_101(input_size): 59 | """Constructs a ir-101 model.""" 60 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 61 | return model 62 | 63 | 64 | def IR_152(input_size): 65 | """Constructs a ir-152 model.""" 66 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 67 | return model 68 | 69 | 70 | def IR_SE_50(input_size): 71 | """Constructs a ir_se-50 model.""" 72 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 73 | return model 74 | 75 | 76 | def IR_SE_101(input_size): 77 | """Constructs a ir_se-101 model.""" 78 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 79 | return model 80 | 81 | 82 | def IR_SE_152(input_size): 83 | """Constructs a ir_se-152 model.""" 84 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 85 | return model -------------------------------------------------------------------------------- /spi/criteria/l2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | l2_criterion = torch.nn.MSELoss(reduction='mean') 4 | 5 | 6 | def l2_loss(real_images, generated_images): 7 | loss = l2_criterion(real_images, generated_images) 8 | return loss 9 | -------------------------------------------------------------------------------- /spi/criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/spi/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /spi/criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | from heapq import nsmallest 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .networks import get_network, LinLayers 7 | from .utils import get_state_dict 8 | 9 | 10 | class LPIPS(nn.Module): 11 | r"""Creates a criterion that measures 12 | Learned Perceptual Image Patch Similarity (LPIPS). 13 | Arguments: 14 | net_type (str): the network type to compare the features: 15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 16 | version (str): the version of LPIPS. Default: 0.1. 
17 | """ 18 | def __init__(self, net_type: str='alex', version: str='0.1', num_scales=1): 19 | 20 | assert version in ['0.1'], 'v0.1 is only supported now' 21 | 22 | super(LPIPS, self).__init__() 23 | 24 | # pretrained network 25 | self.net = get_network(net_type).to("cuda") 26 | 27 | # linear layers 28 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 29 | self.lin.load_state_dict(get_state_dict(net_type, version)) 30 | self.num_scales = num_scales 31 | 32 | def forward(self, x: torch.Tensor, y: torch.Tensor, conf_sigma=None, mask=None): 33 | EPS = 1e-7 34 | loss = 0.0 35 | n_samples = x.shape[0] 36 | 37 | if x.shape[-1] > 256: 38 | x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False) 39 | y = F.interpolate(y, size=(256, 256), mode='bilinear', align_corners=False) 40 | 41 | for _scale in range(self.num_scales): 42 | # import pdb; pdb.set_trace() 43 | feat_x, feat_y = self.net(x), self.net(y) 44 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 45 | 46 | if conf_sigma is not None: 47 | ori_diff = diff 48 | diff = [] 49 | for di in ori_diff: 50 | H = di.shape[-1] 51 | _conf_sigma = F.interpolate(conf_sigma, mode='area', size=(H, H)) 52 | di = di / (2 * _conf_sigma**2 + EPS) + (_conf_sigma + EPS).log() 53 | diff.append(di) 54 | 55 | if mask is not None: 56 | ori_diff = diff 57 | diff = [] 58 | for di in ori_diff: 59 | H = di.shape[-1] 60 | _mask = F.interpolate(mask, mode='area', size=(H, H)) 61 | di = di * _mask 62 | diff.append(di) 63 | 64 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 65 | loss += torch.sum(torch.cat(res, 0)) 66 | 67 | if _scale != self.num_scales - 1: 68 | x = F.interpolate(x, mode='bilinear', scale_factor=0.5, align_corners=False, recompute_scale_factor=True) 69 | y = F.interpolate(y, mode='bilinear', scale_factor=0.5, align_corners=False, recompute_scale_factor=True) 70 | 71 | return loss / n_samples 72 | -------------------------------------------------------------------------------- /spi/criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = 
self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /spi/criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict -------------------------------------------------------------------------------- /spi/criteria/tv_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # density_reg = 0.25 5 | # density_reg_every = 4 6 | density_reg_p_dist = 0.004 7 | box_warp = 1 8 | 9 | def cal_tv_loss(ws, G): 10 | initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 11 | perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * density_reg_p_dist 12 | 13 | all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) 14 | sigma = G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] 15 | sigma_initial = sigma[:, :sigma.shape[1]//2] 16 | sigma_perturbed = sigma[:, sigma.shape[1]//2:] 17 | 18 | TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) 19 | return TVloss 20 | 21 | 22 | def cal_monotonic_loss(ws, G): 23 | initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front 24 | perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * box_warp # Behind 25 | 26 | all_coordinates = 
torch.cat([initial_coordinates, perturbed_coordinates], dim=1) 27 | sigma = G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] 28 | sigma_initial = sigma[:, :sigma.shape[1]//2] 29 | sigma_perturbed = sigma[:, sigma.shape[1]//2:] 30 | 31 | monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10 32 | return monotonic_loss 33 | -------------------------------------------------------------------------------- /spi/training/coaches/inference_coach.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | from spi.configs import paths_config, hyperparameters, global_config 5 | from .base_coach import BaseCoach 6 | from spi.utils.mask_utils import calculate_face_mask 7 | from spi.utils.camera_utils import cal_mirror_c 8 | 9 | 10 | class InferenceCoach(BaseCoach): 11 | 12 | def __init__(self, data_loader, use_wandb): 13 | super().__init__(data_loader, use_wandb) 14 | self.coach_name = 'InferenceCoach' 15 | self.build_name() 16 | 17 | def train(self): 18 | use_ball_holder = True 19 | paths_config.experiments_output_dir += f'{self.coach_name}' 20 | output_dir = paths_config.experiments_output_dir 21 | 22 | for idx, data in tqdm(enumerate(self.data_loader)): 23 | if self.image_counter >= hyperparameters.max_images_to_invert: 24 | break 25 | image_name = data['name'][0] 26 | 27 | image = data['img'].to(global_config.device) 28 | camera = data['c'].to(global_config.device) 29 | mask = data['mask'].to(global_config.device)[:, 0] 30 | fg_mask = 1 - (mask == 0).float() 31 | face_mask = calculate_face_mask(mask).float() 32 | 33 | fg_mask_m = torch.flip(fg_mask, dims=[3]) 34 | face_mask_m = torch.flip(face_mask, dims=[3]) 35 | camera_m = cal_mirror_c(camera=camera) 36 | image_m = torch.flip(image, dims=[3]) 37 | 38 | paths_config.experiments_output_dir = os.path.join(output_dir, image_name) 39 | os.makedirs(paths_config.experiments_output_dir, exist_ok=True) 40 | 41 | ckpt_path = os.path.join(paths_config.checkpoints_dir, hyperparameters.load_embedding_coach_name, f'{image_name}.pt') 42 | 43 | w_pivot, camera, self.G = self.load(ckpt_path) 44 | self.log_video(w_pivot, self.G, os.path.join(paths_config.video_output_dir, f'{image_name}.mp4')) 45 | 46 | 47 | -------------------------------------------------------------------------------- /spi/training/coaches/pti_coach.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | from spi.utils.camera_utils import cal_mirror_c 5 | from spi.configs import paths_config, hyperparameters, global_config 6 | from .base_coach import BaseCoach 7 | from spi.utils.log_utils import log_image_from_w, log_image 8 | from spi.criteria.l2_loss import l2_loss 9 | 10 | 11 | 12 | class SingleIDCoach(BaseCoach): 13 | 14 | def __init__(self, data_loader, use_wandb): 15 | super().__init__(data_loader, use_wandb) 16 | self.coach_name = 'PTI_coach' 17 | self.build_name() 18 | 19 | def calc_loss(self, x, gt, use_ball_holder, new_G, w, c): 20 | loss = 0.0 21 | if hyperparameters.pt_l2_lambda > 0: 22 | l2_loss_val = l2_loss(x, gt) 23 | loss += l2_loss_val * hyperparameters.pt_l2_lambda 24 | if hyperparameters.pt_lpips_lambda > 0: 25 | loss_lpips = self.lpips_loss(x, gt) 26 | loss_lpips = torch.squeeze(loss_lpips) 27 | loss += loss_lpips * hyperparameters.pt_lpips_lambda 28 | 29 | if use_ball_holder and hyperparameters.use_locality_regularization: 30 | 
ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(new_G, w, c) 31 | loss += ball_holder_loss_val 32 | return loss, loss_lpips 33 | 34 | def train(self): 35 | use_ball_holder = True 36 | paths_config.experiments_output_dir += f'{self.coach_name}' 37 | output_dir = paths_config.experiments_output_dir 38 | 39 | for idx, data in tqdm(enumerate(self.data_loader)): 40 | if self.image_counter >= hyperparameters.max_images_to_invert: 41 | break 42 | image_name = data['name'][0] 43 | 44 | image = data['img'].to(global_config.device) 45 | camera = data['c'].to(global_config.device) 46 | mask = data['mask'].to(global_config.device)[:, 0] 47 | fg_mask = 1 - (mask == 0).float() 48 | 49 | paths_config.experiments_output_dir = os.path.join(output_dir, image_name) 50 | os.makedirs(paths_config.experiments_output_dir, exist_ok=True) 51 | 52 | if self.use_wandb: 53 | log_image(image, 'target_image') 54 | 55 | self.restart_training() 56 | 57 | w_pivot = self.get_inversion(image_name, image, camera, fg_mask=fg_mask) 58 | 59 | log_images_counter = 0 60 | real_images_batch = image.to(global_config.device) 61 | 62 | for i in tqdm(range(hyperparameters.G_1_step)): 63 | generated_images = self.G.synthesis(w_pivot, camera, noise_mode='const')['image'] 64 | target_images = real_images_batch 65 | 66 | loss, loss_lpips = self.calc_loss(generated_images, target_images, image_name, self.G, use_ball_holder, w_pivot) 67 | 68 | self.optimizer.zero_grad() 69 | 70 | if loss_lpips <= hyperparameters.LPIPS_value_threshold: 71 | break 72 | 73 | loss.backward() 74 | self.optimizer.step() 75 | 76 | use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0 77 | 78 | if self.use_wandb and log_images_counter % global_config.log_snapshot == 0: 79 | log_image_from_w(w_pivot, camera, self.G, f'{image_name}_G1_inv_{log_images_counter}') 80 | 81 | global_config.training_step += 1 82 | log_images_counter += 1 83 | 84 | self.image_counter += 1 85 | 86 | if self.use_wandb and hyperparameters.G_1_step > 0: 87 | camera_m = cal_mirror_c(camera=camera) 88 | G1_inv = log_image_from_w(w_pivot, camera, self.G, f'{image_name}_G1_inv') 89 | G1_inv_m = log_image_from_w(w_pivot, camera_m, self.G, f'{image_name}_G1_inv_m') 90 | self.log_video(w_pivot, self.G, path=os.path.join(paths_config.experiments_output_dir, f'{image_name}_G1_inv.mp4')) 91 | 92 | self.cal_metric(G1_inv, image, 'G1_inv', fake_m=G1_inv_m) 93 | 94 | self.post_process(w_pivot, camera, self.G, image_name) 95 | 96 | paths_config.experiments_output_dir = output_dir 97 | if self.use_wandb: 98 | self.log_metric() 99 | -------------------------------------------------------------------------------- /spi/training/projectors/w_plus_projector.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | from spi.configs import global_config, hyperparameters 7 | from spi.utils.log_utils import log_image 8 | 9 | 10 | def project( 11 | G, 12 | target: torch.Tensor, 13 | c, 14 | lpips_func, 15 | *, 16 | initial_w=None, 17 | num_steps=1000, 18 | w_avg_samples=10000, 19 | initial_learning_rate=0.01, 20 | initial_noise_factor=0.05, 21 | lr_rampdown_length=0.25, 22 | lr_rampup_length=0.05, 23 | noise_ramp_length=0.75, 24 | regularize_noise_weight=1e5, 25 | verbose=False, 26 | device: torch.device, 27 | image_log_step=global_config.log_snapshot, 28 | w_name: str 29 | ): 30 | assert target.shape[1:] == 
(G.img_channels, G.img_resolution, G.img_resolution) 31 | G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore 32 | 33 | # Compute w stats. 34 | # logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') 35 | z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) 36 | c_samples = c.repeat(w_avg_samples, 1) 37 | w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples) # [N, L, C] 38 | w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] 39 | 40 | w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] 41 | w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 42 | 43 | # Setup noise inputs. 44 | noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name} 45 | 46 | if initial_w is not None: 47 | start_w = initial_w 48 | else: 49 | start_w = w_avg 50 | start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1) 51 | 52 | w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, requires_grad=False) # pylint: disable=not-callable 53 | w_opt.requires_grad_(True) 54 | 55 | optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=hyperparameters.first_inv_lr) 56 | 57 | # Init noise. 58 | for buf in noise_bufs.values(): 59 | buf[:] = torch.randn_like(buf) 60 | buf.requires_grad = True 61 | 62 | for step in tqdm(range(num_steps)): 63 | 64 | # Learning rate schedule. 65 | t = step / num_steps 66 | w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 67 | lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) 68 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) 69 | lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) 70 | lr = initial_learning_rate * lr_ramp 71 | for param_group in optimizer.param_groups: 72 | param_group['lr'] = lr 73 | 74 | # Synth images from opt_w. 75 | w_noise = torch.randn_like(w_opt) * w_noise_scale 76 | ws = (w_opt + w_noise) 77 | synth_images = G.synthesis(ws, c, noise_mode='const')['image'] 78 | 79 | # Features for synth images. 80 | dist = lpips_func(synth_images, target) 81 | 82 | # Noise regularization. 83 | reg_loss = 0.0 84 | for v in noise_bufs.values(): 85 | noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() 86 | while True: 87 | reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 88 | reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 89 | if noise.shape[2] <= 8: 90 | break 91 | noise = F.avg_pool2d(noise, kernel_size=2) 92 | 93 | loss = dist + reg_loss * regularize_noise_weight # + reg_avg_loss * 1 94 | 95 | # Step 96 | optimizer.zero_grad(set_to_none=True) 97 | loss.backward() 98 | optimizer.step() 99 | 100 | # Normalize noise. 
101 | with torch.no_grad(): 102 | for buf in noise_bufs.values(): 103 | buf -= buf.mean() 104 | buf *= buf.square().mean().rsqrt() 105 | 106 | if verbose and step % image_log_step == 0: 107 | print(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') 108 | with torch.no_grad(): 109 | global_config.training_step += 1 110 | log_image(synth_images, f'w_{w_name}_{step}') 111 | 112 | del G 113 | return w_opt 114 | -------------------------------------------------------------------------------- /spi/utils/load_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath('.')) 4 | 5 | import eg3d.dnnlib as dnnlib 6 | from eg3d.torch_utils import misc 7 | import eg3d.legacy as legacy 8 | import torch 9 | from eg3d.training.triplane import TriPlaneGenerator 10 | from eg3d.training.dual_discriminator import DualDiscriminator 11 | from spi.configs import paths_config 12 | 13 | 14 | 15 | def load_eg3d(reload_modules=True, device='cuda', network_pkl=None): 16 | if network_pkl is None: 17 | network_pkl = paths_config.EG3D_PATH 18 | 19 | with dnnlib.util.open_url(network_pkl) as f: 20 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 21 | 22 | # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code 23 | if reload_modules: 24 | # print("Reloading Modules!") 25 | G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device) 26 | misc.copy_params_and_buffers(G, G_new, require_all=True) 27 | G_new.neural_rendering_resolution = G.neural_rendering_resolution 28 | G_new.rendering_kwargs = G.rendering_kwargs 29 | G = G_new 30 | 31 | G.neural_rendering_resolution = 128 32 | G.eval() 33 | return G 34 | 35 | 36 | def load_bisenet(): 37 | from third_part.bisenet.bisenet import BiSeNet 38 | net = BiSeNet(19) 39 | path = paths_config.BISENET_PATH 40 | ckpt = torch.load(path, map_location='cpu') 41 | net.load_state_dict(ckpt) 42 | net.to('cuda') 43 | net.eval() 44 | return net 45 | 46 | 47 | def load_sg_vgg(): 48 | url = paths_config.VGG_PATH 49 | with dnnlib.util.open_url(url) as f: 50 | vgg16 = torch.jit.load(f).eval() 51 | return vgg16 -------------------------------------------------------------------------------- /spi/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import spi.configs.paths_config as paths_config 5 | 6 | 7 | def log_image_from_w(w, c, G, name): 8 | if len(w.size()) <= 2: 9 | w = w.unsqueeze(0) 10 | with torch.no_grad(): 11 | img_tensor = G.synthesis(w, c, noise_mode='const')['image'] 12 | img = img_tensor[0].permute(1, 2, 0) 13 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy() 14 | Image.fromarray(img).save(os.path.join(paths_config.experiments_output_dir, name + '.jpg')) 15 | return img_tensor 16 | 17 | 18 | def tensor2im(var, vmin=-1, vmax=1): 19 | # var shape: (3, H, W) 20 | if len(var.shape) == 4: 21 | var = var[0] 22 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 23 | var = ((var - vmin) / (vmax - vmin)) 24 | var[var < 0] = 0 25 | var[var > 1] = 1 26 | var = var * 255 27 | return Image.fromarray(var.astype('uint8')) 28 | 29 | 30 | def tensor2depth(var): 31 | # var shape: (3, H, W) 32 | if len(var.shape) == 4: 33 | var = var[0] 34 | var = var.cpu().detach()[0].numpy() 35 | vmax = var.max() 36 | vmin = var.min() 37 | var = ((var 
- vmin) / (vmax - vmin)) 38 | var[var < 0] = 0 39 | var[var > 1] = 1 40 | var = var * 255 41 | return Image.fromarray(var.astype('uint8')) 42 | 43 | 44 | def log_image(t, name, vmin=-1, vmax=1, mode='jpg'): 45 | t = t.type(torch.FloatTensor) 46 | if len(t.shape) == 4: 47 | t = t[0] 48 | if t.shape[0] == 1: 49 | img = tensor2depth(t) 50 | else: 51 | img = tensor2im(t, vmin=vmin, vmax=vmax) 52 | img.save(os.path.join(paths_config.experiments_output_dir, name + '.' + mode)) 53 | return 54 | 55 | 56 | def log_mask(t, name): 57 | img = tensor2im(t) 58 | img.save(os.path.join(paths_config.experiments_output_dir, name + '.jpg')) 59 | return -------------------------------------------------------------------------------- /spi/utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calculate_face_mask(mask): 5 | att_list = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13] 6 | face_mask = torch.zeros_like(mask) 7 | for att in att_list: 8 | face_mask += (mask == att) 9 | return face_mask 10 | 11 | 12 | 13 | 14 | 15 | 16 | def calculate_face_mask_w_model(bisenet, image): 17 | out = bisenet(image)[0] # 1, 19, 512, 512 18 | att_list = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13] 19 | parsing = torch.argmax(out, dim=1, keepdim=True) 20 | N, _, H, W = image.shape 21 | mask = torch.zeros(N, 1, H, W).float().to(image.device) 22 | for att in att_list: 23 | mask += (parsing == att) 24 | return mask -------------------------------------------------------------------------------- /spi/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | from spi.criteria.l2_loss import l2_loss 2 | from spi.criteria.lpips.lpips import LPIPS 3 | from spi.criteria.id_loss.id_loss import IDLoss 4 | from spi.configs import paths_config 5 | 6 | class Metric(): 7 | 8 | def __init__(self) -> None: 9 | self.lpips_loss = LPIPS(net_type='vgg').to('cuda').eval() 10 | path_ir_se50 = paths_config.IDLOSS_PATH 11 | self.id_loss = IDLoss(path_ir_se50).to('cuda').eval() 12 | 13 | def run(self, gt, fake): 14 | l2 = l2_loss(gt, fake) 15 | lpips = self.lpips_loss(gt, fake) 16 | id_sim = self.id_loss.calculate_similarity(gt, fake) 17 | return l2.item(), lpips.item(), id_sim.item() 18 | 19 | 20 | 21 | def metric(gt, w, c, G, lpips_func): 22 | fake = G.synthesis(w, c, noise_mode='const')['image'] 23 | 24 | l2 = l2_loss(gt, fake) 25 | lpips = lpips_func(gt, fake) 26 | return l2, lpips 27 | 28 | -------------------------------------------------------------------------------- /test/images/anne_hathaway.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/test/images/anne_hathaway.jpg -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights. 9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from .base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "third_part.Deep3DFaceRecon_pytorch.models." + model_name + "_model" 33 | # model_filename = "models." + model_name + "_model" 34 | modellib = importlib.import_module(model_filename) 35 | model = None 36 | target_model_name = model_name.replace('_', '') + 'model' 37 | for name, cls in modellib.__dict__.items(): 38 | if name.lower() == target_model_name.lower() \ 39 | and issubclass(cls, BaseModel): 40 | model = cls 41 | 42 | if model is None: 43 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 44 | exit(0) 45 | 46 | return model 47 | 48 | 49 | def get_option_setter(model_name): 50 | """Return the static method <modify_commandline_options> of the model class.""" 51 | model_class = find_model_using_name(model_name) 52 | return model_class.modify_commandline_options 53 | 54 | 55 | def create_model(opt): 56 | """Create a model given the option. 57 | 58 | This function wraps the class CustomDatasetDataLoader.
59 | This is the main interface between this package and 'train.py'/'test.py' 60 | 61 | Example: 62 | >>> from models import create_model 63 | >>> model = create_model(opt) 64 | """ 65 | model = find_model_using_name(opt.model) 66 | instance = model(opt) 67 | print("model [%s] was created" % type(instance).__name__) 68 | return instance 69 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r200": 16 | return iresnet200(False, **kwargs) 17 | elif name == "r2060": 18 | from .iresnet2060 import iresnet2060 19 | return iresnet2060(False, **kwargs) 20 | elif name == "mbf": 21 | fp16 = kwargs.get("fp16", False) 22 | num_features = kwargs.get("num_features", 512) 23 | return get_mbf(fp16=fp16, num_features=num_features) 24 | else: 25 | raise ValueError() -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions_pfc.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py 
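The get_model factory in backbones/__init__.py above maps a short network name ("r18", "r50", "mbf", ...) to a backbone constructor, and the configs/*.py modules in this directory expose the matching fields as an EasyDict. The following is a minimal usage sketch, not code from this repository: the import path assumes the arcface_torch directory is the import root, and it assumes the iresnet constructors accept the num_features/fp16 keyword arguments they take in the upstream insightface/arcface_torch code.

```python
# Usage sketch only (not part of the repository).
# Assumes arcface_torch is the import root and that the iresnet constructors accept
# num_features/fp16 keyword arguments, as in the upstream insightface/arcface_torch code.
import torch
from easydict import EasyDict as edict
from backbones import get_model

config = edict()
config.network = "r50"            # same fields the config files in this directory define
config.embedding_size = 512
config.fp16 = False

# get_model forwards the keyword arguments to the selected backbone constructor.
backbone = get_model(config.network, num_features=config.embedding_size, fp16=config.fp16).eval()

with torch.no_grad():
    face = torch.randn(1, 3, 112, 112)   # ArcFace backbones expect 112x112 aligned face crops
    embedding = backbone(face)           # -> shape (1, config.embedding_size)
print(embedding.shape)
```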
-------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = "ms1mv3_arcface_r50" 12 | 13 | config.dataset = "ms1m-retinaface-t1" 14 | config.embedding_size = 512 15 | config.sample_rate = 1 16 | config.fp16 = False 17 | config.momentum = 0.9 18 | config.weight_decay = 5e-4 19 | config.batch_size = 128 20 | config.lr = 0.1 # batch size is 512 21 | 22 | if config.dataset == "emore": 23 | config.rec = "/train_tmp/faces_emore" 24 | config.num_classes = 85742 25 | config.num_image = 5822653 26 | config.num_epoch = 16 27 | config.warmup_epoch = -1 28 | config.decay_epoch = [8, 14, ] 29 | config.val_targets = ["lfw", ] 30 | 31 | elif config.dataset == "ms1m-retinaface-t1": 32 | config.rec = "/train_tmp/ms1m-retinaface-t1" 33 | config.num_classes = 93431 34 | config.num_image = 5179510 35 | config.num_epoch = 25 36 | config.warmup_epoch = -1 37 | config.decay_epoch = [11, 17, 22] 38 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 39 | 40 | elif config.dataset == "glint360k": 41 | config.rec = "/train_tmp/glint360k" 42 | config.num_classes = 360232 43 | config.num_image = 17091657 44 | config.num_epoch = 20 45 | config.warmup_epoch = -1 46 | config.decay_epoch = [8, 12, 15, 18] 47 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 48 | 49 | elif config.dataset == "webface": 50 | config.rec = "/train_tmp/faces_webface_112x112" 51 | config.num_classes = 10572 52 | config.num_image = "forget" 53 | config.num_epoch = 34 54 | config.warmup_epoch = -1 55 | config.decay_epoch = [20, 28, 32] 56 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 57 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = 
None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | 
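The glint360k_* configs above and the ms1mv3_* configs that follow share one schema and differ mainly in the loss, backbone, dataset fields, and schedule. Each is a plain Python module exposing an EasyDict named config, so one can be picked by name at runtime along the lines of the sketch below; this is an illustration only, the repository's own train.py may wire config selection differently, and the configs.* module path is an assumption.

```python
# Sketch only: selecting one of the config modules in this directory by name.
# Assumes arcface_torch is the import root; the repository's train.py may differ.
import importlib

def load_config(name: str):
    """Import configs/<name>.py and return the EasyDict it defines as `config`."""
    module = importlib.import_module(f"configs.{name}")
    return module.config

cfg = load_config("glint360k_r50")
print(cfg.network, cfg.loss, cfg.num_classes)   # -> r50 cosface 360232
```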
-------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 30 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 20, 25] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r2060.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r2060" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 64 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r34" 10 | config.resume = False 11 | 
config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/speed.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 100 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/dataset.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import os 3 | import queue as Queue 4 | import threading 5 | 6 | import mxnet as mx 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset 10 | from torchvision import transforms 11 | 12 | 13 | class BackgroundGenerator(threading.Thread): 14 | def __init__(self, generator, local_rank, max_prefetch=6): 15 | super(BackgroundGenerator, self).__init__() 16 | self.queue = Queue.Queue(max_prefetch) 17 | self.generator = generator 18 | self.local_rank = local_rank 19 | self.daemon = True 20 | self.start() 21 | 22 | def run(self): 23 | torch.cuda.set_device(self.local_rank) 24 | for item in self.generator: 25 | self.queue.put(item) 26 | self.queue.put(None) 27 | 28 | def next(self): 29 | next_item = self.queue.get() 30 | if next_item is None: 31 | raise StopIteration 32 | return next_item 33 | 34 | def __next__(self): 35 | return self.next() 36 | 37 | def 
__iter__(self): 38 | return self 39 | 40 | 41 | class DataLoaderX(DataLoader): 42 | 43 | def __init__(self, local_rank, **kwargs): 44 | super(DataLoaderX, self).__init__(**kwargs) 45 | self.stream = torch.cuda.Stream(local_rank) 46 | self.local_rank = local_rank 47 | 48 | def __iter__(self): 49 | self.iter = super(DataLoaderX, self).__iter__() 50 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 51 | self.preload() 52 | return self 53 | 54 | def preload(self): 55 | self.batch = next(self.iter, None) 56 | if self.batch is None: 57 | return None 58 | with torch.cuda.stream(self.stream): 59 | for k in range(len(self.batch)): 60 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 61 | 62 | def __next__(self): 63 | torch.cuda.current_stream().wait_stream(self.stream) 64 | batch = self.batch 65 | if batch is None: 66 | raise StopIteration 67 | self.preload() 68 | return batch 69 | 70 | 71 | class MXFaceDataset(Dataset): 72 | def __init__(self, root_dir, local_rank): 73 | super(MXFaceDataset, self).__init__() 74 | self.transform = transforms.Compose( 75 | [transforms.ToPILImage(), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 79 | ]) 80 | self.root_dir = root_dir 81 | self.local_rank = local_rank 82 | path_imgrec = os.path.join(root_dir, 'train.rec') 83 | path_imgidx = os.path.join(root_dir, 'train.idx') 84 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') 85 | s = self.imgrec.read_idx(0) 86 | header, _ = mx.recordio.unpack(s) 87 | if header.flag > 0: 88 | self.header0 = (int(header.label[0]), int(header.label[1])) 89 | self.imgidx = np.array(range(1, int(header.label[0]))) 90 | else: 91 | self.imgidx = np.array(list(self.imgrec.keys)) 92 | 93 | def __getitem__(self, index): 94 | idx = self.imgidx[index] 95 | s = self.imgrec.read_idx(idx) 96 | header, img = mx.recordio.unpack(s) 97 | label = header.label 98 | if not isinstance(label, numbers.Number): 99 | label = label[0] 100 | label = torch.tensor(label, dtype=torch.long) 101 | sample = mx.image.imdecode(img).asnumpy() 102 | if self.transform is not None: 103 | sample = self.transform(sample) 104 | return sample, label 105 | 106 | def __len__(self): 107 | return len(self.imgidx) 108 | 109 | 110 | class SyntheticDataset(Dataset): 111 | def __init__(self, local_rank): 112 | super(SyntheticDataset, self).__init__() 113 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 114 | img = np.transpose(img, (2, 0, 1)) 115 | img = torch.from_numpy(img).squeeze(0).float() 116 | img = ((img / 255) - 0.5) / 0.5 117 | self.img = img 118 | self.label = 1 119 | 120 | def __getitem__(self, index): 121 | return self.img, self.label 122 | 123 | def __len__(self): 124 | return 1000000 125 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Eval on ICCV2021-MFR 2 | 3 | coming soon. 4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. 
Eval IJBC With Pytorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | ## Inference 28 | 29 | ```shell 30 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 31 | ``` 32 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | ## v1.8.0 2 | ### Linux and Windows 3 | ```shell 4 | # CUDA 11.0 5 | pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 6 | 7 | # CUDA 10.2 8 | pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 9 | 10 | # CPU only 11 | pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 12 | 13 | ``` 14 | 15 | 16 | ## v1.7.1 17 | ### Linux and Windows 18 | ```shell 19 | # CUDA 11.0 20 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | # CUDA 10.2 23 | pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 24 | 25 | # CUDA 10.1 26 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 27 | 28 | # CUDA 9.2 29 | pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 30 | 31 | # CPU only 32 | pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 33 | ``` 34 | 35 | 36 | ## v1.6.0 37 | 38 | ### Linux and Windows 39 | ```shell 40 | # CUDA 10.2 41 | pip install torch==1.6.0 torchvision==0.7.0 42 | 43 | # CUDA 10.1 44 | pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 45 | 46 | # CUDA 9.2 47 | pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html 48 | 49 | # CPU only 50 | pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 51 | ``` -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/modelzoo.md -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/__init__.py -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/inference.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from backbones import get_model 8 | 9 | 10 | @torch.no_grad() 11 | def inference(weight, name, img): 12 | if img is None: 13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 14 | else: 15 | img = cv2.imread(img) 16 | img = cv2.resize(img, (112, 112)) 17 | 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | img = np.transpose(img, (2, 0, 1)) 20 | img = torch.from_numpy(img).unsqueeze(0).float() 21 | img.div_(255).sub_(0.5).div_(0.5) 22 | net = get_model(name, fp16=False) 23 | net.load_state_dict(torch.load(weight)) 24 | net.eval() 25 | feat = net(img).numpy() 26 | print(feat) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 31 | parser.add_argument('--network', type=str, default='r50', help='backbone network') 32 | parser.add_argument('--weight', type=str, default='') 33 | parser.add_argument('--img', type=str, default=None) 34 | args = parser.parse_args() 35 | inference(args.weight, args.network, args.img) 36 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def get_loss(name): 6 | if name == "cosface": 7 | return CosFace() 8 | elif name == "arcface": 9 | return ArcFace() 10 | else: 11 | raise ValueError() 12 | 13 | 14 | class CosFace(nn.Module): 15 | def __init__(self, s=64.0, m=0.40): 16 | super(CosFace, self).__init__() 17 | self.s = s 18 | self.m = m 19 | 20 | def forward(self, cosine, label): 21 | index = torch.where(label != -1)[0] 22 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 23 | m_hot.scatter_(1, label[index, None], self.m) 24 | cosine[index] -= m_hot 25 | ret = cosine * self.s 26 | return ret 27 | 28 | 29 | class ArcFace(nn.Module): 30 | def __init__(self, s=64.0, m=0.5): 31 | super(ArcFace, self).__init__() 32 | self.s = s 33 | self.m = m 34 | 35 | def forward(self, cosine: torch.Tensor, label): 36 | index = torch.where(label != -1)[0] 37 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 38 | m_hot.scatter_(1, label[index, None], self.m) 39 | cosine.acos_() 40 | cosine[index] += m_hot 41 | cosine.cos_().mul_(self.s) 42 | return cosine 43 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/requirement.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | easydict 3 | mxnet 4 | onnx 5 | sklearn 6 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 2 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh 3 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/torch2onnx.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import torch 4 | 5 | 6 | def convert_onnx(net, path_module, output, opset=11, simplify=False): 7 | assert isinstance(net, torch.nn.Module) 8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 9 | img = img.astype(np.float) 10 | img = (img / 255. - 0.5) / 0.5 # torch style norm 11 | img = img.transpose((2, 0, 1)) 12 | img = torch.from_numpy(img).unsqueeze(0).float() 13 | 14 | weight = torch.load(path_module) 15 | net.load_state_dict(weight) 16 | net.eval() 17 | torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) 18 | model = onnx.load(output) 19 | graph = model.graph 20 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' 21 | if simplify: 22 | from onnxsim import simplify 23 | model, check = simplify(model) 24 | assert check, "Simplified ONNX model could not be validated" 25 | onnx.save(model, output) 26 | 27 | 28 | if __name__ == '__main__': 29 | import os 30 | import argparse 31 | from backbones import get_model 32 | 33 | parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') 34 | parser.add_argument('input', type=str, help='input backbone.pth file or path') 35 | parser.add_argument('--output', type=str, default=None, help='output onnx path') 36 | parser.add_argument('--network', type=str, default=None, help='backbone network') 37 | parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') 38 | args = parser.parse_args() 39 | input_file = args.input 40 | if os.path.isdir(input_file): 41 | input_file = os.path.join(input_file, "backbone.pth") 42 | assert os.path.exists(input_file) 43 | model_name = os.path.basename(os.path.dirname(input_file)).lower() 44 | params = model_name.split("_") 45 | if len(params) >= 3 and params[1] in ('arcface', 'cosface'): 46 | if args.network is None: 47 | args.network = params[2] 48 | assert args.network is not None 49 | print(args) 50 | backbone_onnx = get_model(args.network, dropout=0) 51 | 52 | output_path = args.output 53 | if output_path is None: 54 | output_path = os.path.join(os.path.dirname(__file__), 'onnx') 55 | if not os.path.exists(output_path): 56 | os.makedirs(output_path) 57 | assert os.path.isdir(output_path) 58 | output_file = os.path.join(output_path, "%s.onnx" % model_name) 59 | convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) 60 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/__init__.py -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/plot.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 10 | from prettytable import PrettyTable 11 | from sklearn.metrics import roc_curve, auc 12 | 13 | image_path = "/data/anxiang/IJB_release/IJBC" 14 | files = [ 15 | 
"./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" 16 | ] 17 | 18 | 19 | def read_template_pair_list(path): 20 | pairs = pd.read_csv(path, sep=' ', header=None).values 21 | t1 = pairs[:, 0].astype(np.int) 22 | t2 = pairs[:, 1].astype(np.int) 23 | label = pairs[:, 2].astype(np.int) 24 | return t1, t2, label 25 | 26 | 27 | p1, p2, label = read_template_pair_list( 28 | os.path.join('%s/meta' % image_path, 29 | '%s_template_pair_label.txt' % 'ijbc')) 30 | 31 | methods = [] 32 | scores = [] 33 | for file in files: 34 | methods.append(file.split('/')[-2]) 35 | scores.append(np.load(file)) 36 | 37 | methods = np.array(methods) 38 | scores = dict(zip(methods, scores)) 39 | colours = dict( 40 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 41 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 42 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 43 | fig = plt.figure() 44 | for method in methods: 45 | fpr, tpr, _ = roc_curve(label, scores[method]) 46 | roc_auc = auc(fpr, tpr) 47 | fpr = np.flipud(fpr) 48 | tpr = np.flipud(tpr) # select largest tpr at same fpr 49 | plt.plot(fpr, 50 | tpr, 51 | color=colours[method], 52 | lw=1, 53 | label=('[%s (AUC = %0.4f %%)]' % 54 | (method.split('-')[-1], roc_auc * 100))) 55 | tpr_fpr_row = [] 56 | tpr_fpr_row.append("%s-%s" % (method, "IJBC")) 57 | for fpr_iter in np.arange(len(x_labels)): 58 | _, min_index = min( 59 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 60 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 61 | tpr_fpr_table.add_row(tpr_fpr_row) 62 | plt.xlim([10 ** -6, 0.1]) 63 | plt.ylim([0.3, 1.0]) 64 | plt.grid(linestyle='--', linewidth=1) 65 | plt.xticks(x_labels) 66 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 67 | plt.xscale('log') 68 | plt.xlabel('False Positive Rate') 69 | plt.ylabel('True Positive Rate') 70 | plt.title('ROC on IJB') 71 | plt.legend(loc="lower right") 72 | print(tpr_fpr_table) 73 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_amp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | 5 | if torch.__version__ < '1.9': 6 | Iterable = torch._six.container_abcs.Iterable 7 | else: 8 | import collections 9 | 10 | Iterable = collections.abc.Iterable 11 | from torch.cuda.amp import GradScaler 12 | 13 | 14 | class _MultiDeviceReplicator(object): 15 | """ 16 | Lazily serves copies of a tensor to requested devices. Copies are cached per-device. 
17 | """ 18 | 19 | def __init__(self, master_tensor: torch.Tensor) -> None: 20 | assert master_tensor.is_cuda 21 | self.master = master_tensor 22 | self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} 23 | 24 | def get(self, device) -> torch.Tensor: 25 | retval = self._per_device_tensors.get(device, None) 26 | if retval is None: 27 | retval = self.master.to(device=device, non_blocking=True, copy=True) 28 | self._per_device_tensors[device] = retval 29 | return retval 30 | 31 | 32 | class MaxClipGradScaler(GradScaler): 33 | def __init__(self, init_scale, max_scale: float, growth_interval=100): 34 | GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) 35 | self.max_scale = max_scale 36 | 37 | def scale_clip(self): 38 | if self.get_scale() == self.max_scale: 39 | self.set_growth_factor(1) 40 | elif self.get_scale() < self.max_scale: 41 | self.set_growth_factor(2) 42 | elif self.get_scale() > self.max_scale: 43 | self._scale.fill_(self.max_scale) 44 | self.set_growth_factor(1) 45 | 46 | def scale(self, outputs): 47 | """ 48 | Multiplies ('scales') a tensor or list of tensors by the scale factor. 49 | 50 | Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned 51 | unmodified. 52 | 53 | Arguments: 54 | outputs (Tensor or iterable of Tensors): Outputs to scale. 55 | """ 56 | if not self._enabled: 57 | return outputs 58 | self.scale_clip() 59 | # Short-circuit for the common case. 60 | if isinstance(outputs, torch.Tensor): 61 | assert outputs.is_cuda 62 | if self._scale is None: 63 | self._lazy_init_scale_growth_tracker(outputs.device) 64 | assert self._scale is not None 65 | return outputs * self._scale.to(device=outputs.device, non_blocking=True) 66 | 67 | # Invoke the more complex machinery only if we're treating multiple outputs. 
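# apply_scale below recursively walks lists/tuples of tensors and multiplies each tensor by a per-device cached copy of the scale factor (via _MultiDeviceReplicator), so nested outputs are scaled without repeatedly transferring the scale tensor.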
68 | stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale 69 | 70 | def apply_scale(val): 71 | if isinstance(val, torch.Tensor): 72 | assert val.is_cuda 73 | if len(stash) == 0: 74 | if self._scale is None: 75 | self._lazy_init_scale_growth_tracker(val.device) 76 | assert self._scale is not None 77 | stash.append(_MultiDeviceReplicator(self._scale)) 78 | return val * stash[0].get(val.device) 79 | elif isinstance(val, Iterable): 80 | iterable = map(apply_scale, val) 81 | if isinstance(val, list) or isinstance(val, tuple): 82 | return type(val)(iterable) 83 | else: 84 | return iterable 85 | else: 86 | raise ValueError("outputs must be a Tensor or an iterable of Tensors") 87 | 88 | return apply_scale(outputs) 89 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path as osp 3 | 4 | 5 | def get_config(config_file): 6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/' 7 | temp_config_name = osp.basename(config_file) 8 | temp_module_name = osp.splitext(temp_config_name)[0] 9 | config = importlib.import_module("configs.base") 10 | cfg = config.config 11 | config = importlib.import_module("configs.%s" % temp_module_name) 12 | job_cfg = config.config 13 | cfg.update(job_cfg) 14 | if cfg.output is None: 15 | cfg.output = osp.join('work_dirs', temp_module_name) 16 | return cfg -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_os.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/utils_os.py -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/models/losses.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from kornia.geometry import warp_affine 5 | import torch.nn.functional as F 6 | 7 | def resize_n_crop(image, M, dsize=112): 8 | # image: (b, c, h, w) 9 | # M : (b, 2, 3) 10 | return warp_affine(image, M, dsize=(dsize, dsize)) 11 | 12 | ### perceptual level loss 13 | class PerceptualLoss(nn.Module): 14 | def __init__(self, recog_net, input_size=112): 15 | super(PerceptualLoss, self).__init__() 16 | self.recog_net = recog_net 17 | self.preprocess = lambda x: 2 * x - 1 18 | self.input_size=input_size 19 | def forward(self, imageA, imageB, M): 20 | """ 21 | 1 - cosine distance 22 | Parameters: 23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order 24 | imageB --same as imageA 25 | """ 26 | 27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) 28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) 29 | 30 | # freeze bn 31 | self.recog_net.eval() 32 | 33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) 34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) 35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 36 | # assert torch.sum((cosine_d > 1).float()) == 0 37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 38 | 39 | def perceptual_loss(id_featureA, id_featureB): 40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 41 | # assert torch.sum((cosine_d > 1).float()) == 0 42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 43 | 44 | ### image level loss 45 | def photo_loss(imageA, imageB, mask, eps=1e-6): 46 | """ 47 | l2 norm (with sqrt, to ensure backward stability, use eps, otherwise NaN may occur) 48 | Parameters: 49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order 50 | imageB --same as imageA 51 | """ 52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask 53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) 54 | return loss 55 | 56 | def landmark_loss(predict_lm, gt_lm, weight=None): 57 | """ 58 | weighted mse loss 59 | Parameters: 60 | predict_lm --torch.tensor (B, 68, 2) 61 | gt_lm --torch.tensor (B, 68, 2) 62 | weight --numpy.array (1, 68) 63 | """ 64 | if weight is None: 65 | weight = np.ones([68]) 66 | weight[28:31] = 20 67 | weight[-8:] = 20 68 | weight = np.expand_dims(weight, 0) 69 | weight = torch.tensor(weight).to(predict_lm.device) 70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight 71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) 72 | return loss 73 | 74 | 75 | ### regularization 76 | def reg_loss(coeffs_dict, opt=None): 77 | """ 78 | l2 norm without the sqrt, from yu's implementation (mse) 79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss 80 | Parameters: 81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans 82 | 83 | """ 84 | # coefficient regularization to ensure plausible 3d faces 85 | if opt: 86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex 87 | else: 88 | w_id, w_exp, w_tex = 1, 1, 1 89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ 90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ 91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2) 92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0] 93 | 94 | # gamma regularization to ensure a nearly-monochromatic light 95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) 96 | gamma_mean =
torch.mean(gamma, dim=1, keepdims=True) 97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2) 98 | 99 | return creg_loss, gamma_loss 100 | 101 | def reflectance_loss(texture, mask): 102 | """ 103 | minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo 104 | Parameters: 105 | texture --torch.tensor, (B, N, 3) 106 | mask --torch.tensor, (N), 1 or 0 107 | 108 | """ 109 | mask = mask.reshape([1, mask.shape[0], 1]) 110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) 111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) 112 | return loss 113 | 114 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/options/inference_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class InferenceOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 13 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 14 | 15 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 16 | parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') 17 | parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') 18 | parser.add_argument('--save_split_files', action='store_true', help='save split files or not') 19 | parser.add_argument('--inference_batch_size', type=int, default=8) 20 | 21 | # Dropout and Batchnorm has different behavior during training and test. 22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/options/test_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the test options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | 6 | 7 | class TestOptions(BaseOptions): 8 | """This class includes test options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) # define shared options 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') 18 | 19 | # Dropout and Batchnorm has different behavior during training and test. 
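# isTrain is therefore set to False so that downstream model setup skips training-only components (optimizers, schedulers) and uses evaluation behavior.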
20 | self.isTrain = False 21 | return parser 22 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/temp/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/temp/2.png -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/test.py: -------------------------------------------------------------------------------- 1 | """This script is the test script for Deep3DFaceRecon_pytorch 2 | Change most absolute path to relative path 3 | """ 4 | 5 | import os 6 | import sys 7 | from PIL import Image 8 | import numpy as np 9 | import torch 10 | 11 | # print(__package__) 12 | __package__ = 'third_part.Deep3DFaceRecon_pytorch' 13 | sys.path.append(os.path.abspath('../../')) 14 | # sys.path.append(os.path.abspath('./models/')) 15 | # sys.path.append(os.path.abspath('./options/')) 16 | # sys.path.append(os.path.abspath('./util/')) 17 | # print(sys.path) 18 | # sys.path.append('.') 19 | # sys.path.append() 20 | 21 | from .options.test_options import TestOptions 22 | from .models import create_model 23 | from .util.visualizer import MyVisualizer 24 | from .util.preprocess import align_img 25 | from .util.load_mats import load_lm3d 26 | 27 | # from data.flist_dataset import default_flist_reader 28 | # from scipy.io import loadmat, savemat 29 | 30 | 31 | def get_data_path(root='examples'): 32 | 33 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')] 34 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path] 35 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path] 36 | 37 | return im_path, lm_path 38 | 39 | def read_data(im_path, lm_path, lm3d_std, to_tensor=True): 40 | # to RGB 41 | im = Image.open(im_path).convert('RGB') 42 | W,H = im.size 43 | lm = np.loadtxt(lm_path).astype(np.float32) 44 | lm = lm.reshape([-1, 2]) 45 | lm[:, -1] = H - 1 - lm[:, -1] 46 | _, im, lm, _ = align_img(im, lm, lm3d_std) 47 | if to_tensor: 48 | im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) 49 | lm = torch.tensor(lm).unsqueeze(0) 50 | return im, lm 51 | 52 | def main(rank, opt, name='examples'): 53 | device = torch.device(rank) 54 | torch.cuda.set_device(device) 55 | model = create_model(opt) 56 | model.setup(opt) 57 | model.device = device 58 | model.parallelize() 59 | model.eval() 60 | visualizer = MyVisualizer(opt) 61 | 62 | im_path, lm_path = get_data_path(name) 63 | lm3d_std = load_lm3d(opt.bfm_folder) 64 | 65 | for i in range(len(im_path)): 66 | print(i, im_path[i]) 67 | img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','') 68 | if not os.path.isfile(lm_path[i]): 69 | print('no landmark file') 70 | continue 71 | im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std) 72 | data = { 73 | 'imgs': im_tensor, 74 | 'lms': lm_tensor 75 | } 76 | model.set_input(data) # unpack data from data loader 77 | with torch.no_grad(): 78 | # model.test() 79 | model.forward() 80 | pred_coeff = {key: model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict} 81 | import pdb; pdb.set_trace() 82 | # model.test() # run inference 83 | visuals = model.get_current_visuals() # get image results 84 | 
visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1], 85 | save_results=True, count=i, name=img_name, add_image=False) 86 | # 87 | # model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.obj')) # save reconstruction meshes 88 | # model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.mat')) # save predicted coefficients 89 | 90 | 91 | if __name__ == '__main__': 92 | cmd = '--checkpoints_dir /apdcephfs/share_1290939/feiiyin/TH/PIRender/Deep3DFaceRecon_pytorch/checkpoints \ 93 | --bfm_folder /apdcephfs/share_1290939/feiiyin/TH/PIRender/Deep3DFaceRecon_pytorch/BFM --name=model_name \ 94 | --epoch=20 --img_folder=/apdcephfs_cq2/share_1290939/feiiyin/InvEG3D2/third_part/Deep3DFaceRecon_pytorch/datasets/' 95 | # --epoch=20 --img_folder=temp' 96 | 97 | opt = TestOptions(cmd_line=cmd).parse() # get test options 98 | main(0, opt, opt.img_folder) 99 | 100 | 101 | # python test.py -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/util/BBRegressorParam_r.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiiYin/SPI/fcb1f578f9e7112760a8725c74e4a5a68950c8bd/third_part/Deep3DFaceRecon_pytorch/util/BBRegressorParam_r.mat -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | from .util import * 3 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/util/generate_list.py: -------------------------------------------------------------------------------- 1 | """This script is to generate training list files for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | 6 | # save path to training data 7 | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): 8 | save_path = os.path.join(save_folder, mode) 9 | if not os.path.isdir(save_path): 10 | os.makedirs(save_path) 11 | with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: 12 | fd.writelines([i + '\n' for i in lms_list]) 13 | 14 | with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: 15 | fd.writelines([i + '\n' for i in imgs_list]) 16 | 17 | with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: 18 | fd.writelines([i + '\n' for i in msks_list]) 19 | 20 | # check if the path is valid 21 | def check_list(rlms_list, rimgs_list, rmsks_list): 22 | lms_list, imgs_list, msks_list = [], [], [] 23 | for i in range(len(rlms_list)): 24 | flag = 'false' 25 | lm_path = rlms_list[i] 26 | im_path = rimgs_list[i] 27 | msk_path = rmsks_list[i] 28 | if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): 29 | flag = 'true' 30 | lms_list.append(rlms_list[i]) 31 | imgs_list.append(rimgs_list[i]) 32 | msks_list.append(rmsks_list[i]) 33 | print(i, rlms_list[i], flag) 34 | return lms_list, imgs_list, msks_list 35 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/util/html.py: -------------------------------------------------------------------------------- 1 | import 
dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as <add_header> (add a text header to the HTML file), 10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/ 19 | title (str) -- the webpage name 20 | refresh (int) -- how often the website refresh itself; if 0; no refreshing 21 | """ 22 | self.title = title 23 | self.web_dir = web_dir 24 | self.img_dir = os.path.join(self.web_dir, 'images') 25 | if not os.path.exists(self.web_dir): 26 | os.makedirs(self.web_dir) 27 | if not os.path.exists(self.img_dir): 28 | os.makedirs(self.img_dir) 29 | 30 | self.doc = dominate.document(title=title) 31 | if refresh > 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HTML file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/util/nvdiffrast.py: -------------------------------------------------------------------------------- 1 | """This script is the differentiable renderer for Deep3DFaceRecon_pytorch 2 | Attention, antialiasing step is missing in current version.
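MeshRenderer below builds an NDC projection matrix from the rasterization FOV and uses nvdiffrast to rasterize the mesh into a foreground mask, a depth map and, optionally, interpolated per-vertex features.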
3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import kornia 8 | from kornia.geometry.camera import pixel2cam 9 | import numpy as np 10 | from typing import List 11 | import nvdiffrast.torch as dr 12 | from scipy.io import loadmat 13 | from torch import nn 14 | 15 | def ndc_projection(x=0.1, n=1.0, f=50.0): 16 | return np.array([[n/x, 0, 0, 0], 17 | [ 0, n/-x, 0, 0], 18 | [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], 19 | [ 0, 0, -1, 0]]).astype(np.float32) 20 | 21 | class MeshRenderer(nn.Module): 22 | def __init__(self, 23 | rasterize_fov, 24 | znear=0.1, 25 | zfar=10, 26 | rasterize_size=224): 27 | super(MeshRenderer, self).__init__() 28 | 29 | x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear 30 | self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( 31 | torch.diag(torch.tensor([1., -1, -1, 1]))) 32 | self.rasterize_size = rasterize_size 33 | self.glctx = None 34 | 35 | def forward(self, vertex, tri, feat=None): 36 | """ 37 | Return: 38 | mask -- torch.tensor, size (B, 1, H, W) 39 | depth -- torch.tensor, size (B, 1, H, W) 40 | features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None 41 | 42 | Parameters: 43 | vertex -- torch.tensor, size (B, N, 3) 44 | tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles 45 | feat(optional) -- torch.tensor, size (B, C), features 46 | """ 47 | device = vertex.device 48 | rsize = int(self.rasterize_size) 49 | ndc_proj = self.ndc_proj.to(device) 50 | # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v 51 | if vertex.shape[-1] == 3: 52 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) 53 | vertex[..., 1] = -vertex[..., 1] 54 | 55 | 56 | vertex_ndc = vertex @ ndc_proj.t() 57 | if self.glctx is None: 58 | self.glctx = dr.RasterizeGLContext(device=device) 59 | print("create glctx on device cuda:%d"%device.index) 60 | 61 | ranges = None 62 | if isinstance(tri, List) or len(tri.shape) == 3: 63 | vum = vertex_ndc.shape[1] 64 | fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) 65 | fstartidx = torch.cumsum(fnum, dim=0) - fnum 66 | ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() 67 | for i in range(tri.shape[0]): 68 | tri[i] = tri[i] + i*vum 69 | vertex_ndc = torch.cat(vertex_ndc, dim=0) 70 | tri = torch.cat(tri, dim=0) 71 | 72 | # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] 73 | tri = tri.type(torch.int32).contiguous() 74 | rast_out, _ = dr.rasterize(self.glctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges) 75 | 76 | depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri) 77 | depth = depth.permute(0, 3, 1, 2) 78 | mask = (rast_out[..., 3] > 0).float().unsqueeze(1) 79 | depth = mask * depth 80 | 81 | 82 | image = None 83 | if feat is not None: 84 | image, _ = dr.interpolate(feat, rast_out, tri) 85 | image = image.permute(0, 3, 1, 2) 86 | image = mask * image 87 | 88 | return mask, depth, image 89 | 90 | -------------------------------------------------------------------------------- /third_part/Deep3DFaceRecon_pytorch/util/test_mean_face.txt: -------------------------------------------------------------------------------- 1 | -5.228591537475585938e+01 2 | 2.078247070312500000e-01 3 | -5.064269638061523438e+01 4 | -1.315765380859375000e+01 5 | -4.952939224243164062e+01 6 | -2.592591094970703125e+01 7 | -4.793047332763671875e+01 8 | -3.832135772705078125e+01 9 | 
-4.512159729003906250e+01 10 | -5.059623336791992188e+01 11 | -3.917720794677734375e+01 12 | -6.043736648559570312e+01 13 | -2.929953765869140625e+01 14 | -6.861183166503906250e+01 15 | -1.719801330566406250e+01 16 | -7.572736358642578125e+01 17 | -1.961936950683593750e+00 18 | -7.862001037597656250e+01 19 | 1.467941284179687500e+01 20 | -7.607844543457031250e+01 21 | 2.744073486328125000e+01 22 | -6.915261840820312500e+01 23 | 3.855677795410156250e+01 24 | -5.950350570678710938e+01 25 | 4.478240966796875000e+01 26 | -4.867547225952148438e+01 27 | 4.714337158203125000e+01 28 | -3.800830078125000000e+01 29 | 4.940315246582031250e+01 30 | -2.496297454833984375e+01 31 | 5.117234802246093750e+01 32 | -1.241538238525390625e+01 33 | 5.190507507324218750e+01 34 | 8.244247436523437500e-01 35 | -4.150688934326171875e+01 36 | 2.386329650878906250e+01 37 | -3.570307159423828125e+01 38 | 3.017010498046875000e+01 39 | -2.790358734130859375e+01 40 | 3.212951660156250000e+01 41 | -1.941773223876953125e+01 42 | 3.156523132324218750e+01 43 | -1.138106536865234375e+01 44 | 2.841992187500000000e+01 45 | 5.993263244628906250e+00 46 | 2.895182800292968750e+01 47 | 1.343590545654296875e+01 48 | 3.189880371093750000e+01 49 | 2.203153991699218750e+01 50 | 3.302221679687500000e+01 51 | 2.992478942871093750e+01 52 | 3.099150085449218750e+01 53 | 3.628388977050781250e+01 54 | 2.765748596191406250e+01 55 | -1.933914184570312500e+00 56 | 1.405374145507812500e+01 57 | -2.153038024902343750e+00 58 | 5.772636413574218750e+00 59 | -2.270050048828125000e+00 60 | -2.121643066406250000e+00 61 | -2.218330383300781250e+00 62 | -1.068978118896484375e+01 63 | -1.187252044677734375e+01 64 | -1.997912597656250000e+01 65 | -6.879402160644531250e+00 66 | -2.143579864501953125e+01 67 | -1.227821350097656250e+00 68 | -2.193494415283203125e+01 69 | 4.623237609863281250e+00 70 | -2.152721405029296875e+01 71 | 9.721397399902343750e+00 72 | -1.953671264648437500e+01 73 | -3.648714447021484375e+01 74 | 9.811126708984375000e+00 75 | -3.130242919921875000e+01 76 | 1.422447967529296875e+01 77 | -2.212834930419921875e+01 78 | 1.493019866943359375e+01 79 | -1.500880432128906250e+01 80 | 1.073588562011718750e+01 81 | -2.095037078857421875e+01 82 | 9.054298400878906250e+00 83 | -3.050099182128906250e+01 84 | 8.704177856445312500e+00 85 | 1.173237609863281250e+01 86 | 1.054329681396484375e+01 87 | 1.856353759765625000e+01 88 | 1.535009765625000000e+01 89 | 2.893331909179687500e+01 90 | 1.451992797851562500e+01 91 | 3.452944946289062500e+01 92 | 1.065280151367187500e+01 93 | 2.875990295410156250e+01 94 | 8.654792785644531250e+00 95 | 1.942100524902343750e+01 96 | 9.422447204589843750e+00 97 | -2.204488372802734375e+01 98 | -3.983994293212890625e+01 99 | -1.324458312988281250e+01 100 | -3.467377471923828125e+01 101 | -6.749649047851562500e+00 102 | -3.092894744873046875e+01 103 | -9.183349609375000000e-01 104 | -3.196458435058593750e+01 105 | 4.220649719238281250e+00 106 | -3.090406036376953125e+01 107 | 1.089889526367187500e+01 108 | -3.497008514404296875e+01 109 | 1.874589538574218750e+01 110 | -4.065438079833984375e+01 111 | 1.124106597900390625e+01 112 | -4.438417816162109375e+01 113 | 5.181709289550781250e+00 114 | -4.649170684814453125e+01 115 | -1.158607482910156250e+00 116 | -4.680406951904296875e+01 117 | -7.918922424316406250e+00 118 | -4.671575164794921875e+01 119 | -1.452505493164062500e+01 120 | -4.416526031494140625e+01 121 | -2.005007171630859375e+01 122 | -3.997841644287109375e+01 123 | -1.054919433593750000e+01 124 | 
-3.849683380126953125e+01 125 | -1.051826477050781250e+00 126 | -3.794863128662109375e+01 127 | 6.412681579589843750e+00 128 | -3.804645538330078125e+01 129 | 1.627674865722656250e+01 130 | -4.039697265625000000e+01 131 | 6.373878479003906250e+00 132 | -4.087213897705078125e+01 133 | -8.551712036132812500e-01 134 | -4.157129669189453125e+01 135 | -1.014953613281250000e+01 136 | -4.128469085693359375e+01 137 | -------------------------------------------------------------------------------- /third_part/bisenet/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for 
name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if module.bias is not None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() --------------------------------------------------------------------------------
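Resnet18.get_params() above splits the backbone's parameters into weights that normally receive weight decay (Conv2d/Linear weights) and parameters that usually should not (biases and BatchNorm affine parameters). The snippet below is a minimal usage sketch, not part of the repository: it only illustrates how the two groups might be passed to an optimizer. The import path, learning rate and weight-decay value are assumptions.

```python
from torch import optim

# Hypothetical import path; assumes the repository root is on PYTHONPATH
# and the bisenet directory is importable as a package.
from third_part.bisenet.resnet import Resnet18

net = Resnet18()  # init_weight() loads ImageNet ResNet-18 weights via model_zoo
wd_params, nowd_params = net.get_params()

optimizer = optim.SGD(
    [
        {"params": wd_params, "weight_decay": 5e-4},   # conv/linear weights: decayed
        {"params": nowd_params, "weight_decay": 0.0},  # biases and BatchNorm params: not decayed
    ],
    lr=1e-2,        # illustrative value
    momentum=0.9,
)
```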