├── LICENSE ├── README.md ├── StyleCLIP ├── .gitignore ├── criteria │ ├── __init__.py │ ├── clip_loss.py │ └── id_loss.py ├── models │ ├── .DS_Store │ ├── __init__.py │ ├── facial_recognition │ │ ├── __init__.py │ │ ├── helpers.py │ │ └── model_irse.py │ └── stylegan2 │ │ ├── __init__.py │ │ ├── model.py │ │ └── op │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── upfirdn2d.py ├── optimization │ ├── .DS_Store │ ├── __init__.py │ ├── mixin.py │ ├── pca.py │ └── text.py └── utils.py ├── assets ├── pipeline.png └── teaser.png ├── environment.yml ├── examples ├── 01.jpg ├── aerith.png ├── david.jpg ├── mona_lisa.jpg ├── pope.jpg └── van_gouh.jpg ├── licenses ├── LICENSE-StyleCLIP └── LICENSE-restyle_encoder ├── main.py ├── pregen_latents ├── aging │ ├── d_latent │ ├── d_latent_baseline │ └── d_latent_oldGrayHair ├── angry │ └── d_latent ├── eyesClose │ └── d_latent ├── headsTurn │ └── d_latent └── smile │ └── d_latent └── restyle_encoder ├── .gitignore ├── LICENSE ├── cog.yaml ├── configs ├── __init__.py ├── data_configs.py ├── paths_config.py └── transforms_config.py ├── criteria ├── __init__.py ├── id_loss.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── moco_loss.py └── w_norm.py ├── datasets ├── __init__.py ├── gt_res_dataset.py ├── images_dataset.py └── inference_dataset.py ├── docs ├── 02530.jpg ├── 2441.jpg ├── 2598.jpg ├── 346.jpg ├── ardern.jpg ├── macron.jpg ├── merkel.jpg └── teaser.jpg ├── editing ├── __init__.py ├── inference_editing.py ├── interfacegan_directions │ ├── age.pt │ ├── pose.pt │ └── smile.pt └── latent_editor.py ├── environment └── restyle_env.yaml ├── licenses ├── LICENSE_S-aiueo32 ├── LICENSE_TreB1eN ├── LICENSE_eladrich ├── LICENSE_lessw2020 ├── LICENSE_omertov └── LICENSE_rosinality ├── main.py ├── models ├── __init__.py ├── e4e.py ├── e4e_modules │ ├── __init__.py │ ├── discriminator.py │ └── latent_codes_pool.py ├── encoders │ ├── __init__.py │ ├── fpn_encoders.py │ ├── helpers.py │ ├── map2style.py │ ├── model_irse.py │ ├── restyle_e4e_encoders.py │ └── restyle_psp_encoders.py ├── mtcnn │ ├── __init__.py │ ├── mtcnn.py │ └── mtcnn_pytorch │ │ ├── __init__.py │ │ └── src │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── box_utils.py │ │ ├── detector.py │ │ ├── first_stage.py │ │ ├── get_nets.py │ │ ├── matlab_cp2tform.py │ │ ├── visualization_utils.py │ │ └── weights │ │ ├── onet.npy │ │ ├── pnet.npy │ │ └── rnet.npy ├── psp.py └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── options ├── __init__.py ├── e4e_train_options.py ├── test_options.py └── train_options.py ├── scriptsLocal ├── __init__.py ├── align_faces_parallel.py ├── calc_id_loss_parallel.py ├── calc_losses_on_images.py ├── encoder_bootstrapping_inference.py ├── inference_iterative.py ├── inference_iterative_save_coupled.py ├── train_restyle_e4e.py └── train_restyle_psp.py ├── training ├── __init__.py ├── coach_restyle_e4e.py ├── coach_restyle_psp.py └── ranger.py └── utils ├── __init__.py ├── common.py ├── data_utils.py ├── inference_utils.py ├── model_utils.py └── train_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Qiucheng Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /StyleCLIP/.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /latents/ 3 | /results/ 4 | /mine/ 5 | *.pt 6 | *.pth -------------------------------------------------------------------------------- /StyleCLIP/criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/criteria/__init__.py -------------------------------------------------------------------------------- /StyleCLIP/criteria/clip_loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import clip 4 | 5 | 6 | class CLIPLoss(torch.nn.Module): 7 | 8 | def __init__(self, opts): 9 | super(CLIPLoss, self).__init__() 10 | self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") 11 | self.upsample = torch.nn.Upsample(scale_factor=7) 12 | self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32) 13 | 14 | def forward(self, image, text): 15 | image = self.avg_pool(self.upsample(image)) 16 | similarity = 1 - self.model(image, text)[0] / 100 17 | return similarity -------------------------------------------------------------------------------- /StyleCLIP/criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from StyleCLIP.models.facial_recognition.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self, opts): 9 | super(IDLoss, self).__init__() 10 | # print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(opts.ir_se50_weights)) 13 | self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 14 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 15 | self.facenet.eval() 16 | self.facenet.cuda() 17 | self.opts = opts 18 | 19 | def extract_feats(self, x): 20 | if x.shape[2] != 256: 21 | x = self.pool(x) 22 | x = x[:, :, 35:223, 32:220] # Crop interesting region 23 | x = self.face_pool(x) 24 | x_feats = self.facenet(x) 25 | return x_feats 26 | 27 | def forward(self, y_hat, y): 28 | n_samples = y.shape[0] 29 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 30 | y_hat_feats = self.extract_feats(y_hat) 31 | y_feats = y_feats.detach() 32 | loss = 0 33 | sim_improvement = 0 34 | count = 0 35 | for i in 
range(n_samples): 36 | diff_target = y_hat_feats[i].dot(y_feats[i]) 37 | loss += 1 - diff_target 38 | count += 1 39 | 40 | return loss / count, sim_improvement / count 41 | -------------------------------------------------------------------------------- /StyleCLIP/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/models/.DS_Store -------------------------------------------------------------------------------- /StyleCLIP/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/models/__init__.py -------------------------------------------------------------------------------- /StyleCLIP/models/facial_recognition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/models/facial_recognition/__init__.py -------------------------------------------------------------------------------- /StyleCLIP/models/facial_recognition/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /StyleCLIP/models/facial_recognition/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from StyleCLIP.models.facial_recognition.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | 
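                                       # with a 112x112 input the IR-SE body downsamples to a
                                       # 7x7x512 feature map, hence the 512 * 7 * 7 flatten above
                                       # (the 224 branch below uses 14x14 instead)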
BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /StyleCLIP/models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /StyleCLIP/models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /StyleCLIP/models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | 11 | class FusedLeakyReLU(nn.Module): 12 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 13 | super().__init__() 14 | 15 | self.bias = nn.Parameter(torch.zeros(channel)) 16 | self.negative_slope = negative_slope 17 | self.scale = scale 18 | 19 | def forward(self, input): 20 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 21 | 22 | 23 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 24 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 25 | input = input.cuda() 26 | if input.ndim == 3: 27 | return ( 28 | F.leaky_relu( 29 | input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope 30 | ) 31 | * scale 32 | ) 33 | else: 34 | return ( 35 | F.leaky_relu( 36 | input + bias.view(1, bias.shape[0], *rest_dim), 
negative_slope=negative_slope 37 | ) 38 | * scale 39 | ) 40 | 41 | -------------------------------------------------------------------------------- /StyleCLIP/models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | module_path = os.path.dirname(__file__) 8 | 9 | 10 | 11 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 12 | out = upfirdn2d_native( 13 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 14 | ) 15 | 16 | return out 17 | 18 | 19 | def upfirdn2d_native( 20 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 21 | ): 22 | _, channel, in_h, in_w = input.shape 23 | input = input.reshape(-1, in_h, in_w, 1) 24 | 25 | _, in_h, in_w, minor = input.shape 26 | kernel_h, kernel_w = kernel.shape 27 | 28 | out = input.view(-1, in_h, 1, in_w, 1, minor) 29 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 30 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 31 | 32 | out = F.pad( 33 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 34 | ) 35 | out = out[ 36 | :, 37 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 38 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 39 | :, 40 | ] 41 | 42 | out = out.permute(0, 3, 1, 2) 43 | out = out.reshape( 44 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 45 | ) 46 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 47 | out = F.conv2d(out, w) 48 | out = out.reshape( 49 | -1, 50 | minor, 51 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 52 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 53 | ) 54 | out = out.permute(0, 2, 3, 1) 55 | out = out[:, ::down_y, ::down_x, :] 56 | 57 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 58 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 59 | 60 | return out.view(-1, channel, out_h, out_w) -------------------------------------------------------------------------------- /StyleCLIP/optimization/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/optimization/.DS_Store -------------------------------------------------------------------------------- /StyleCLIP/optimization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/StyleCLIP/optimization/__init__.py -------------------------------------------------------------------------------- /StyleCLIP/optimization/mixin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | from skimage import img_as_ubyte 5 | 6 | from StyleCLIP.models.stylegan2.model import Generator 7 | import imageio 8 | 9 | 10 | def mixin(args, 11 | latents_text_pca, 12 | latent_img, 13 | scale, 14 | num_frames=100, 15 | fps=60, 16 | exp_name=''): 17 | video_dir = 'results/' 18 | if not os.path.exists(video_dir): 19 | os.makedirs(video_dir, exist_ok=True) 20 | 21 | with torch.no_grad(): 22 | latent_none = latents_text_pca[0] 23 | latent_full = latents_text_pca[-1] 24 | latent_zero = latent_img.unsqueeze(0) 25 | diff = latent_full - latent_none 26 | latents = [ 27 | latent_zero + diff * (i / 100) * scale for i in 
range(0, 100) 28 | ] 29 | g_ema = Generator(args.stylegan_size, 512, 8) 30 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 31 | g_ema.eval() 32 | g_ema = g_ema.cuda() 33 | 34 | images = [] 35 | for frame in range(num_frames): 36 | img_gen_full, _ = g_ema( 37 | [latents[frame]], 38 | input_is_latent=True, 39 | randomize_noise=False, 40 | input_is_stylespace=args.work_in_stylespace) 41 | img_gen_full = (img_gen_full - torch.min(img_gen_full)) / ( 42 | torch.max(img_gen_full) - torch.min(img_gen_full)) 43 | img_gen_full = img_as_ubyte(img_gen_full.squeeze(0).detach().cpu()) 44 | images.append(img_gen_full.swapaxes(0, 2).swapaxes(0, 1)) 45 | imageio.mimsave(os.path.join(video_dir, f"{exp_name}.gif"), 46 | images, 47 | fps=fps) 48 | -------------------------------------------------------------------------------- /StyleCLIP/optimization/pca.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | from StyleCLIP.models.stylegan2.model import Generator 8 | 9 | 10 | class RPCA_gpu: 11 | """ low-rank and sparse matrix decomposition via RPCA [1] with CUDA capabilities """ 12 | def __init__(self, D, mu=None, lmbda=None): 13 | self.D = D 14 | self.S = torch.zeros_like(self.D) 15 | self.Y = torch.zeros_like(self.D) 16 | self.mu = mu or (np.prod(self.D.shape) / 17 | (4 * self.norm_p(self.D, 2))).item() 18 | self.mu_inv = 1 / self.mu 19 | self.lmbda = lmbda or 1 / np.sqrt(np.max(self.D.shape)) 20 | 21 | @staticmethod 22 | def norm_p(M, p): 23 | return torch.sum(torch.pow(M, p)) 24 | 25 | @staticmethod 26 | def shrink(M, tau): 27 | return torch.sign(M) * F.relu( 28 | torch.abs(M) - tau) # hack to save memory 29 | 30 | def svd_threshold(self, M, tau): 31 | U, s, V = torch.svd(M, some=True) 32 | return torch.mm(U, torch.mm(torch.diag(self.shrink(s, tau)), V.t())) 33 | 34 | def fit(self, tol=None, max_iter=1000, iter_print=100): 35 | i, err = 0, np.inf 36 | Sk, Yk, Lk = self.S, self.Y, torch.zeros_like(self.D) 37 | _tol = tol or 1e-7 * self.norm_p(torch.abs(self.D), 2) 38 | while err > _tol and i < max_iter: 39 | Lk = self.svd_threshold(self.D - Sk + self.mu_inv * Yk, 40 | self.mu_inv) 41 | Sk = self.shrink(self.D - Lk + (self.mu_inv * Yk), 42 | self.mu_inv * self.lmbda) 43 | Yk = Yk + self.mu * (self.D - Lk - Sk) 44 | err = self.norm_p(torch.abs(self.D - Lk - Sk), 2) / self.norm_p( 45 | self.D, 2) 46 | i += 1 47 | # if (i % iter_print) == 0 or i == 1 or i > max_iter or err <= _tol: 48 | # print(f'Iteration: {i}; Error: {err:0.4e}') 49 | self.L, self.S = Lk, Sk 50 | return Lk, Sk 51 | 52 | 53 | def get_pca_latent(args, latents, text, degrees, exp_name): 54 | save_dir = 'text_pca/' 55 | if not os.path.exists(save_dir): 56 | os.makedirs(save_dir, exist_ok=True) 57 | text_latents = [] 58 | new_latents = [torch.zeros_like(l) for l in latents] 59 | for i in range(latents[0].shape[0]): 60 | new_tensor = torch.zeros(0, 512).to("cuda") 61 | for j in range(len(new_latents)): 62 | new_tensor = torch.cat((new_tensor, latents[j][i].reshape(1, -1)), 63 | dim=0) 64 | # results = torch.pca_lowrank(new_tensor, q=4, center=False) 65 | solver = RPCA_gpu(new_tensor) 66 | new_tensor_lowrank, _ = solver.fit() 67 | results = torch.pca_lowrank(new_tensor_lowrank, q=4, center=False) 68 | 69 | tmp = torch.matmul(results[0], torch.diag(results[1])) 70 | tmp = torch.matmul(tmp, torch.transpose(results[2], 0, 1)) 71 | for j in range(len(new_latents)): 72 | 
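                # copy row j of the rank-4 PCA reconstruction (computed from the
                # RPCA low-rank component above) back into latent j at style index i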
new_latents[j][i] = tmp[j] 73 | 74 | g_ema = Generator(args.stylegan_size, 512, 8) 75 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 76 | g_ema.eval() 77 | g_ema = g_ema.cuda() 78 | 79 | for i, degree in enumerate(degrees): 80 | text_latents.append(torch.unsqueeze(new_latents[i], 0)) 81 | img_gen, _ = g_ema([torch.unsqueeze(new_latents[i], 0)], 82 | input_is_latent=True, 83 | randomize_noise=False, 84 | input_is_stylespace=args.work_in_stylespace) 85 | torchvision.utils.save_image(img_gen, 86 | f"{save_dir}/{exp_name}_{degree}.png", 87 | normalize=True, 88 | range=(-1, 1)) 89 | return text_latents 90 | -------------------------------------------------------------------------------- /StyleCLIP/optimization/text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import numpy as np 5 | import random 6 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 7 | import torch 8 | import torchvision 9 | from torch import optim 10 | from tqdm import tqdm 11 | import clip 12 | 13 | from StyleCLIP.criteria.clip_loss import CLIPLoss 14 | from StyleCLIP.criteria.id_loss import IDLoss 15 | from StyleCLIP.models.stylegan2.model import Generator 16 | from StyleCLIP.models.stylegan2.model import Generator 17 | 18 | 19 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 20 | lr_ramp = min(1, (1 - t) / rampdown) 21 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 22 | lr_ramp = lr_ramp * min(1, t / rampup) 23 | 24 | return initial_lr * lr_ramp 25 | 26 | 27 | def get_text_latent(args, 28 | text, 29 | degrees=[ 30 | "10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", 31 | "90%", "100%" 32 | ], 33 | seed=0, 34 | exp_name=''): 35 | torch.backends.cudnn.benchmark = False 36 | torch.backends.cudnn.deterministic = True 37 | random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | np.random.seed(seed) 42 | 43 | latents = [] 44 | temp_dir = 'text/' 45 | if not os.path.exists(temp_dir): 46 | os.makedirs(temp_dir, exist_ok=True) 47 | with torch.no_grad(): 48 | g_ema = Generator(1024, 512, 8) 49 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 50 | g_ema.eval() 51 | g_ema = g_ema.cuda() 52 | model, _ = clip.load("ViT-B/32", device="cuda") 53 | mean_latent = g_ema.mean_latent(4096) 54 | upsample_func = torch.nn.Upsample(1024) 55 | 56 | # set image from a random latent 57 | latent_code_init_not_trunc = torch.randn(1, 512).cuda() 58 | _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], 59 | return_latents=True, 60 | truncation=args.truncation, 61 | truncation_latent=mean_latent) 62 | 63 | latents.append(latent_code_init.detach().clone().squeeze(0)) 64 | 65 | # start editing 66 | img_orig, _ = g_ema([latent_code_init], 67 | input_is_latent=True, 68 | randomize_noise=False) 69 | img_orig = upsample_func(img_orig) 70 | 71 | clip_loss = CLIPLoss(args).cuda() 72 | id_loss = IDLoss(args).cuda() 73 | torchvision.utils.save_image(img_orig, 74 | f"{temp_dir}/{exp_name}_org.png", 75 | normalize=True, 76 | range=(-1, 1)) 77 | 78 | for degree in tqdm(degrees): 79 | full_text = eval("f'{}'".format(text)) 80 | text_inputs = torch.cat([clip.tokenize(full_text)]).cuda() 81 | random.seed(seed) 82 | torch.manual_seed(seed) 83 | torch.cuda.manual_seed(seed) 84 | torch.cuda.manual_seed_all(seed) 85 | np.random.seed(seed) 86 | 87 | latent = latent_code_init.detach().clone() 88 | latent.requires_grad = True 89 | 90 | 
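            # A sketch of the per-degree loop below (args fields are the ones this
            # file references: lr, step, mode, work_in_stylespace, l2_lambda, id_lambda):
            #   1. anneal the learning rate with get_lr()'s ramp-up / cosine ramp-down,
            #   2. regenerate the image from the current latent with the frozen g_ema,
            #   3. minimize the CLIP loss plus, in "edit" mode, l2_lambda * ||latent -
            #      latent_code_init||^2 and id_lambda * identity loss.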
optimizer = optim.Adam([latent], lr=args.lr) 91 | # pbar = tqdm(range(args.step)) 92 | 93 | for i in range(args.step): 94 | random.seed(seed+i) 95 | torch.manual_seed(seed+i) 96 | torch.cuda.manual_seed(seed+i) 97 | torch.cuda.manual_seed_all(seed+i) 98 | np.random.seed(seed+i) 99 | t = i / args.step 100 | lr = get_lr(t, args.lr) 101 | optimizer.param_groups[0]["lr"] = lr 102 | 103 | img_gen, _ = g_ema([latent], 104 | input_is_latent=True, 105 | randomize_noise=False, 106 | input_is_stylespace=args.work_in_stylespace) 107 | 108 | c_loss = clip_loss(img_gen, text_inputs) 109 | 110 | if args.id_lambda > 0: 111 | i_loss = id_loss(img_gen, img_orig)[0] 112 | else: 113 | i_loss = 0 114 | 115 | if args.mode == "edit": 116 | if args.work_in_stylespace: 117 | l2_loss = sum([ 118 | ((latent_code_init[c] - latent[c])**2).sum() 119 | for c in range(len(latent_code_init)) 120 | ]) 121 | else: 122 | l2_loss = ((latent_code_init - latent)**2).sum() 123 | loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss 124 | else: 125 | loss = c_loss 126 | 127 | optimizer.zero_grad() 128 | loss.backward() 129 | optimizer.step() 130 | 131 | with torch.no_grad(): 132 | img_gen, _ = g_ema([latent.detach()], 133 | input_is_latent=True, 134 | randomize_noise=False, 135 | input_is_stylespace=args.work_in_stylespace) 136 | torchvision.utils.save_image(img_gen, 137 | f"{temp_dir}/{exp_name}_{degree}.png", 138 | normalize=True, 139 | range=(-1, 1)) 140 | latents.append(latent.detach().clone().squeeze(0)) 141 | torch.cuda.empty_cache() 142 | 143 | return latents 144 | -------------------------------------------------------------------------------- /StyleCLIP/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | google_drive_paths = { 5 | "stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT", 6 | 7 | "mapper/pretrained/afro.pt": "https://drive.google.com/uc?id=1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ", 8 | "mapper/pretrained/angry.pt": "https://drive.google.com/uc?id=1g82HEH0jFDrcbCtn3M22gesWKfzWV_ma", 9 | "mapper/pretrained/beyonce.pt": "https://drive.google.com/uc?id=1KJTc-h02LXs4zqCyo7pzCp0iWeO6T9fz", 10 | "mapper/pretrained/bobcut.pt": "https://drive.google.com/uc?id=1IvyqjZzKS-vNdq_OhwapAcwrxgLAY8UF", 11 | "mapper/pretrained/bowlcut.pt": "https://drive.google.com/uc?id=1xwdxI2YCewSt05dEHgkpmmzoauPjEnnZ", 12 | "mapper/pretrained/curly_hair.pt": "https://drive.google.com/uc?id=1xZ7fFB12Ci6rUbUfaHPpo44xUFzpWQ6M", 13 | "mapper/pretrained/depp.pt": "https://drive.google.com/uc?id=1FPiJkvFPG_y-bFanxLLP91wUKuy-l3IV", 14 | "mapper/pretrained/hilary_clinton.pt": "https://drive.google.com/uc?id=1X7U2zj2lt0KFifIsTfOOzVZXqYyCWVll", 15 | "mapper/pretrained/mohawk.pt": "https://drive.google.com/uc?id=1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09", 16 | "mapper/pretrained/purple_hair.pt": "https://drive.google.com/uc?id=14H0CGXWxePrrKIYmZnDD2Ccs65EEww75", 17 | "mapper/pretrained/surprised.pt": "https://drive.google.com/uc?id=1F-mPrhO-UeWrV1QYMZck63R43aLtPChI", 18 | "mapper/pretrained/taylor_swift.pt": "https://drive.google.com/uc?id=10jHuHsKKJxuf3N0vgQbX_SMEQgFHDrZa", 19 | "mapper/pretrained/trump.pt": "https://drive.google.com/uc?id=14v8D0uzy4tOyfBU3ca9T0AzTt3v-dNyh", 20 | "mapper/pretrained/zuckerberg.pt": "https://drive.google.com/uc?id=1NjDcMUL8G-pO3i_9N6EPpQNXeMc3Ar1r", 21 | 22 | "example_celebs.pt": "https://drive.google.com/uc?id=1VL3lP4avRhz75LxSza6jgDe-pHd2veQG" 23 | } 24 | 25 | 26 | def 
ensure_checkpoint_exists(model_weights_filename): 27 | if not os.path.isfile(model_weights_filename) and ( 28 | model_weights_filename in google_drive_paths 29 | ): 30 | gdrive_url = google_drive_paths[model_weights_filename] 31 | try: 32 | from gdown import download as drive_download 33 | 34 | drive_download(gdrive_url, model_weights_filename, quiet=False) 35 | except ModuleNotFoundError: 36 | print( 37 | "gdown module not found.", 38 | "pip3 install gdown or, manually download the checkpoint file:", 39 | gdrive_url 40 | ) 41 | 42 | if not os.path.isfile(model_weights_filename) and ( 43 | model_weights_filename not in google_drive_paths 44 | ): 45 | print( 46 | model_weights_filename, 47 | " not found, you may need to manually download the model weights." 48 | ) 49 | 50 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/assets/pipeline.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/assets/teaser.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: test-grasping 2 | dependencies: 3 | - python=3.9.7 4 | - pip: 5 | - dlib==19.23.0 6 | - imageio==2.9.0 7 | - imageio-ffmpeg==0.4.7 8 | - numpy==1.22.3 9 | - pillow==9.1.0 10 | - scipy==1.7.1 11 | - tqdm==4.64.0 12 | - typing-extensions==4.2.0 13 | - clip-by-openai 14 | -------------------------------------------------------------------------------- /examples/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/examples/01.jpg -------------------------------------------------------------------------------- /examples/aerith.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/examples/aerith.png -------------------------------------------------------------------------------- /examples/david.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/examples/david.jpg -------------------------------------------------------------------------------- /examples/mona_lisa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/examples/mona_lisa.jpg -------------------------------------------------------------------------------- /examples/pope.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/examples/pope.jpg -------------------------------------------------------------------------------- /examples/van_gouh.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/examples/van_gouh.jpg -------------------------------------------------------------------------------- /licenses/LICENSE-StyleCLIP: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Or Patashnik, Zongze Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/LICENSE-restyle_encoder: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pregen_latents/aging/d_latent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/aging/d_latent -------------------------------------------------------------------------------- /pregen_latents/aging/d_latent_baseline: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/aging/d_latent_baseline -------------------------------------------------------------------------------- /pregen_latents/aging/d_latent_oldGrayHair: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/aging/d_latent_oldGrayHair -------------------------------------------------------------------------------- /pregen_latents/angry/d_latent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/angry/d_latent -------------------------------------------------------------------------------- /pregen_latents/eyesClose/d_latent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/eyesClose/d_latent -------------------------------------------------------------------------------- /pregen_latents/headsTurn/d_latent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/headsTurn/d_latent -------------------------------------------------------------------------------- /pregen_latents/smile/d_latent: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/pregen_latents/smile/d_latent -------------------------------------------------------------------------------- /restyle_encoder/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | pretrained_models/ 131 | -------------------------------------------------------------------------------- /restyle_encoder/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /restyle_encoder/cog.yaml: -------------------------------------------------------------------------------- 1 | image: "r8.im/yuval-alaluf/restyle_encoder" 2 | build: 3 | gpu: true 4 | python_version: "3.8" 5 | system_packages: 6 | - "cmake" 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | - "zip" 10 | - "ninja-build" 11 | python_packages: 12 | - "Pillow==8.3.1" 13 | - "cmake==3.21.1" 14 | - "dlib==19.22.1" 15 | - "imageio==2.9.0" 16 | - "ipython==7.21.0" 17 | - "matplotlib==3.1.3" 18 | - "numpy==1.21.1" 19 | - "opencv-python==4.5.3.56" 20 | - "scipy==1.4.1" 21 | - "tensorboard==2.2.1" 22 | - "torch==1.8.0" 23 | - "torchvision==0.9.0" 24 | - "tqdm==4.42.1" 25 | pre_install: 26 | - "mkdir /content" 27 | - "wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip" 28 | - "unzip ninja-linux.zip -d /usr/local/bin/" 29 | - "update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force" 30 | - "wget -O /content/shape_predictor_68_face_landmarks.dat.bz2 http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2" 31 | - "cd /content && bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2" 32 | predict: predict.py:Predictor -------------------------------------------------------------------------------- /restyle_encoder/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/configs/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/configs/data_configs.py: -------------------------------------------------------------------------------- 1 | from configs import transforms_config 2 | from configs.paths_config import dataset_paths 3 | 4 | 5 | DATASETS = { 6 | 'ffhq_encode': { 7 | 'transforms': transforms_config.EncodeTransforms, 8 | 'train_source_root': dataset_paths['ffhq'], 9 | 'train_target_root': dataset_paths['ffhq'], 10 | 'test_source_root': dataset_paths['celeba_test'], 11 | 'test_target_root': dataset_paths['celeba_test'] 12 | }, 13 | "cars_encode": { 14 | 'transforms': transforms_config.CarsEncodeTransforms, 15 | 'train_source_root': dataset_paths['cars_train'], 16 | 'train_target_root': dataset_paths['cars_train'], 17 | 'test_source_root': dataset_paths['cars_test'], 18 | 'test_target_root': dataset_paths['cars_test'] 19 | }, 20 | "church_encode": { 21 | 'transforms': transforms_config.EncodeTransforms, 22 | 'train_source_root': dataset_paths['church_train'], 23 | 'train_target_root': dataset_paths['church_train'], 24 | 'test_source_root': dataset_paths['church_test'], 25 | 'test_target_root': dataset_paths['church_test'] 26 | }, 27 | "horse_encode": { 28 | 'transforms': transforms_config.EncodeTransforms, 29 | 'train_source_root': dataset_paths['horse_train'], 30 | 'train_target_root': dataset_paths['horse_train'], 31 | 'test_source_root': dataset_paths['horse_test'], 32 | 'test_target_root': dataset_paths['horse_test'] 33 | }, 34 | "afhq_wild_encode": { 35 | 'transforms': transforms_config.EncodeTransforms, 36 | 'train_source_root': dataset_paths['afhq_wild_train'], 37 | 'train_target_root': dataset_paths['afhq_wild_train'], 38 | 'test_source_root': dataset_paths['afhq_wild_test'], 39 | 'test_target_root': dataset_paths['afhq_wild_test'] 40 | }, 41 | "toonify": { 42 | 'transforms': transforms_config.EncodeTransforms, 43 | 
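        # toonify reuses the FFHQ / celeba_test entries; note that every path in
        # configs/paths_config.py defaults to an empty string and must be set locally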
'train_source_root': dataset_paths['ffhq'], 44 | 'train_target_root': dataset_paths['ffhq'], 45 | 'test_source_root': dataset_paths['celeba_test'], 46 | 'test_target_root': dataset_paths['celeba_test'] 47 | } 48 | } -------------------------------------------------------------------------------- /restyle_encoder/configs/paths_config.py: -------------------------------------------------------------------------------- 1 | dataset_paths = { 2 | 'ffhq': '', 3 | 'celeba_test': '', 4 | 5 | 'cars_train': '', 6 | 'cars_test': '', 7 | 8 | 'church_train': '', 9 | 'church_test': '', 10 | 11 | 'horse_train': '', 12 | 'horse_test': '', 13 | 14 | 'afhq_wild_train': '', 15 | 'afhq_wild_test': '' 16 | } 17 | 18 | model_paths = { 19 | 'ir_se50': 'pretrained_models/model_ir_se50.pth', 20 | 'resnet34': 'pretrained_models/resnet34-333f7ec4.pth', 21 | 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt', 22 | 'stylegan_cars': 'pretrained_models/stylegan2-car-config-f.pt', 23 | 'stylegan_church': 'pretrained_models/stylegan2-church-config-f.pt', 24 | 'stylegan_horse': 'pretrained_models/stylegan2-horse-config-f.pt', 25 | 'stylegan_ada_wild': 'pretrained_models/afhqwild.pt', 26 | 'stylegan_toonify': 'pretrained_models/ffhq_cartoon_blended.pt', 27 | 'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat', 28 | 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth', 29 | 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy', 30 | 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy', 31 | 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy', 32 | 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pt' 33 | } 34 | -------------------------------------------------------------------------------- /restyle_encoder/configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | 4 | 5 | class TransformsConfig(object): 6 | 7 | def __init__(self, opts): 8 | self.opts = opts 9 | 10 | @abstractmethod 11 | def get_transforms(self): 12 | pass 13 | 14 | 15 | class EncodeTransforms(TransformsConfig): 16 | 17 | def __init__(self, opts): 18 | super(EncodeTransforms, self).__init__(opts) 19 | 20 | def get_transforms(self): 21 | transforms_dict = { 22 | 'transform_gt_train': transforms.Compose([ 23 | transforms.Resize((256, 256)), 24 | transforms.RandomHorizontalFlip(0.5), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 27 | 'transform_source': None, 28 | 'transform_test': transforms.Compose([ 29 | transforms.Resize((256, 256)), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 32 | 'transform_inference': transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 36 | } 37 | return transforms_dict 38 | 39 | 40 | class CarsEncodeTransforms(TransformsConfig): 41 | 42 | def __init__(self, opts): 43 | super(CarsEncodeTransforms, self).__init__(opts) 44 | 45 | def get_transforms(self): 46 | transforms_dict = { 47 | 'transform_gt_train': transforms.Compose([ 48 | transforms.Resize((192, 256)), 49 | transforms.RandomHorizontalFlip(0.5), 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 52 | 'transform_source': None, 53 | 'transform_test': transforms.Compose([ 54 | transforms.Resize((192, 256)), 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 
0.5, 0.5])]), 57 | 'transform_inference': transforms.Compose([ 58 | transforms.Resize((192, 256)), 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 61 | } 62 | return transforms_dict 63 | -------------------------------------------------------------------------------- /restyle_encoder/criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/criteria/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths_config import model_paths 4 | from models.encoders.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def extract_feats(self, x): 17 | x = x[:, :, 35:223, 32:220] # Crop interesting region 18 | x = self.face_pool(x) 19 | x_feats = self.facenet(x) 20 | return x_feats 21 | 22 | def forward(self, y_hat, y, x): 23 | n_samples = x.shape[0] 24 | x_feats = self.extract_feats(x) 25 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 26 | y_hat_feats = self.extract_feats(y_hat) 27 | y_feats = y_feats.detach() 28 | loss = 0 29 | sim_improvement = 0 30 | id_logs = [] 31 | count = 0 32 | for i in range(n_samples): 33 | diff_target = y_hat_feats[i].dot(y_feats[i]) 34 | diff_input = y_hat_feats[i].dot(x_feats[i]) 35 | diff_views = y_feats[i].dot(x_feats[i]) 36 | id_logs.append({'diff_target': float(diff_target), 37 | 'diff_input': float(diff_input), 38 | 'diff_views': float(diff_views)}) 39 | loss += 1 - diff_target 40 | id_diff = float(diff_target) - float(diff_views) 41 | sim_improvement += id_diff 42 | count += 1 43 | 44 | return loss / count, sim_improvement / count, id_logs 45 | -------------------------------------------------------------------------------- /restyle_encoder/criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 
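    Example (a sketch; assumes 4-D image batches living on the same CUDA
    device that the sub-networks are moved to in __init__):
        criterion = LPIPS(net_type='alex')
        distance = criterion(y_hat, y)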
15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /restyle_encoder/criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | 
self.set_requires_grad(False) -------------------------------------------------------------------------------- /restyle_encoder/criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /restyle_encoder/criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | 6 | 7 | class MocoLoss(nn.Module): 8 | 9 | def __init__(self): 10 | super(MocoLoss, self).__init__() 11 | print("Loading MOCO model from path: {}".format(model_paths["moco"])) 12 | self.model = self.__load_model() 13 | self.model.cuda() 14 | self.model.eval() 15 | 16 | @staticmethod 17 | def __load_model(): 18 | import torchvision.models as models 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ['fc.weight', 'fc.bias']: 23 | param.requires_grad = False 24 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 25 | state_dict = checkpoint['state_dict'] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 30 | # remove prefix 31 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 32 | # delete renamed or unused k 33 | del state_dict[k] 34 | msg = model.load_state_dict(state_dict, strict=False) 35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 36 | # remove output layer 37 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 38 | return model 39 | 40 | def extract_feats(self, x): 41 | x = F.interpolate(x, size=224) 42 | x_feats = self.model(x) 43 | x_feats = nn.functional.normalize(x_feats, dim=1) 44 | x_feats = x_feats.squeeze() 45 | return x_feats 46 | 47 | def forward(self, y_hat, y, x): 48 | n_samples = x.shape[0] 49 | x_feats = self.extract_feats(x) 50 | y_feats = self.extract_feats(y) 51 | y_hat_feats = self.extract_feats(y_hat) 52 | y_feats = y_feats.detach() 53 | loss = 0 54 | sim_improvement = 0 55 | sim_logs = [] 56 | count = 0 57 | for i in range(n_samples): 58 | diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | diff_input = y_hat_feats[i].dot(x_feats[i]) 60 | diff_views = y_feats[i].dot(x_feats[i]) 61 | sim_logs.append({'diff_target': float(diff_target), 62 | 'diff_input': float(diff_input), 63 | 'diff_views': 
float(diff_views)}) 64 | loss += 1 - diff_target 65 | sim_diff = float(diff_target) - float(diff_views) 66 | sim_improvement += sim_diff 67 | count += 1 68 | 69 | return loss / count, sim_improvement / count, sim_logs 70 | -------------------------------------------------------------------------------- /restyle_encoder/criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /restyle_encoder/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/datasets/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/datasets/gt_res_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | 5 | 6 | class GTResDataset(Dataset): 7 | 8 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): 9 | self.pairs = [] 10 | for f in os.listdir(root_path): 11 | image_path = os.path.join(root_path, f) 12 | gt_path = os.path.join(gt_dir, f) 13 | if f.endswith(".jpg") or f.endswith(".png") or f.endswith(".jpeg"): 14 | self.pairs.append([image_path, gt_path, None]) 15 | self.transform = transform 16 | self.transform_train = transform_train 17 | 18 | def __len__(self): 19 | return len(self.pairs) 20 | 21 | def __getitem__(self, index): 22 | from_path, to_path, _ = self.pairs[index] 23 | from_im = Image.open(from_path).convert('RGB') 24 | to_im = Image.open(to_path).convert('RGB') 25 | if self.transform: 26 | to_im = self.transform(to_im) 27 | from_im = self.transform(from_im) 28 | return from_im, to_im 29 | -------------------------------------------------------------------------------- /restyle_encoder/datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class ImagesDataset(Dataset): 7 | 8 | def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): 9 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 10 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 11 | self.source_transform = source_transform 12 | self.target_transform = target_transform 13 | self.opts = opts 14 | 15 | def __len__(self): 16 | return len(self.source_paths) 17 | 18 | def __getitem__(self, index): 19 | from_path = self.source_paths[index] 20 | to_path = self.target_paths[index] 21 | 22 | from_im = Image.open(from_path).convert('RGB') 23 | to_im = Image.open(to_path).convert('RGB') 24 | 25 | if self.target_transform: 26 | to_im = self.target_transform(to_im) 27 | 28 | if self.source_transform: 29 | from_im = self.source_transform(from_im) 30 | else: 31 | from_im = to_im 32 | 33 | return from_im, to_im 34 
| -------------------------------------------------------------------------------- /restyle_encoder/datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class InferenceDataset(Dataset): 7 | 8 | def __init__(self, root, opts, transform=None): 9 | self.paths = sorted(data_utils.make_dataset(root)) 10 | self.transform = transform 11 | self.opts = opts 12 | 13 | def __len__(self): 14 | return len(self.paths) 15 | 16 | def __getitem__(self, index): 17 | from_path = self.paths[index] 18 | from_im = Image.open(from_path).convert('RGB') 19 | if self.transform: 20 | from_im = self.transform(from_im) 21 | return from_im 22 | -------------------------------------------------------------------------------- /restyle_encoder/docs/02530.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/02530.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/2441.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/2441.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/2598.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/2598.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/346.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/346.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/ardern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/ardern.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/macron.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/macron.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/merkel.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/merkel.jpg -------------------------------------------------------------------------------- /restyle_encoder/docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/docs/teaser.jpg -------------------------------------------------------------------------------- 
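(Editor's note: the dataset classes above are thin wrappers around a folder of images. Below is a minimal sketch of how InferenceDataset is typically wired up for inversion; the folder name and batch size are illustrative only, and the import form assumes the restyle_encoder directory is on sys.path, as the repository's own scripts arrange.)

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets.inference_dataset import InferenceDataset

# The repository normalizes [0, 1] image tensors to [-1, 1] before encoding.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# "aligned_faces" is a hypothetical folder of pre-aligned face crops.
# opts is stored by the dataset but not used when loading images, so None is fine here.
dataset = InferenceDataset(root="aligned_faces", opts=None, transform=transform)
loader = DataLoader(dataset, batch_size=2, shuffle=False, drop_last=False)

for batch in loader:
    batch = batch.cuda().float()  # ready to be passed to the ReStyle encoder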
/restyle_encoder/editing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/editing/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/editing/inference_editing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | from tqdm import tqdm 4 | import time 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | import sys 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from configs import data_configs 15 | from datasets.inference_dataset import InferenceDataset 16 | from editing.latent_editor import LatentEditor 17 | from models.e4e import e4e 18 | from options.test_options import TestOptions 19 | from utils.common import tensor2im 20 | from utils.inference_utils import get_average_image, run_on_batch 21 | 22 | 23 | def run(): 24 | """ 25 | This script can be used to perform inversion and editing. Please note that this script supports editing using 26 | only the ReStyle-e4e model and currently supports editing using three edit directions found using InterFaceGAN 27 | (age, smile, and pose) on the faces domain. 28 | For performing the edits please provide the arguments `--edit_directions` and `--factor_ranges`. For example, 29 | setting these values to be `--edit_directions=age,smile,pose` and `--factor_ranges=5,5,5` will use a lambda range 30 | between -5 and 5 for each of the attributes. These should be comma-separated lists of the same length. You may 31 | get better results by playing around with the factor ranges for each edit. 32 | """ 33 | test_opts = TestOptions().parse() 34 | 35 | out_path_results = os.path.join(test_opts.exp_dir, 'editing_results') 36 | out_path_coupled = os.path.join(test_opts.exp_dir, 'editing_coupled') 37 | 38 | os.makedirs(out_path_results, exist_ok=True) 39 | os.makedirs(out_path_coupled, exist_ok=True) 40 | 41 | # update test options with options used during training 42 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 43 | opts = ckpt['opts'] 44 | opts.update(vars(test_opts)) 45 | opts = Namespace(**opts) 46 | net = e4e(opts) 47 | net.eval() 48 | net.cuda() 49 | 50 | print('Loading dataset for {}'.format(opts.dataset_type)) 51 | if opts.dataset_type != "ffhq_encode": 52 | raise ValueError("Editing script only supports edits on the faces domain!") 53 | dataset_args = data_configs.DATASETS[opts.dataset_type] 54 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 55 | dataset = InferenceDataset(root=opts.data_path, 56 | transform=transforms_dict['transform_inference'], 57 | opts=opts) 58 | dataloader = DataLoader(dataset, 59 | batch_size=opts.test_batch_size, 60 | shuffle=False, 61 | num_workers=int(opts.test_workers), 62 | drop_last=False) 63 | 64 | if opts.n_images is None: 65 | opts.n_images = len(dataset) 66 | 67 | latent_editor = LatentEditor(net.decoder) 68 | opts.edit_directions = opts.edit_directions.split(',') 69 | opts.factor_ranges = [int(factor) for factor in opts.factor_ranges.split(',')] 70 | if len(opts.edit_directions) != len(opts.factor_ranges): 71 | raise ValueError("Invalid edit directions and factor ranges. Please provide a single factor range for each" 72 | f"edit direction. 
Given: {opts.edit_directions} and {opts.factor_ranges}") 73 | 74 | avg_image = get_average_image(net, opts) 75 | 76 | global_i = 0 77 | global_time = [] 78 | for input_batch in tqdm(dataloader): 79 | if global_i >= opts.n_images: 80 | break 81 | with torch.no_grad(): 82 | input_cuda = input_batch.cuda().float() 83 | tic = time.time() 84 | result_batch = edit_batch(input_cuda, net, avg_image, latent_editor, opts) 85 | toc = time.time() 86 | global_time.append(toc - tic) 87 | 88 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 89 | for i in range(input_batch.shape[0]): 90 | 91 | im_path = dataset.paths[global_i] 92 | results = result_batch[i] 93 | 94 | inversion = results.pop('inversion') 95 | input_im = tensor2im(input_batch[i]) 96 | 97 | all_edit_results = [] 98 | for edit_name, edit_res in results.items(): 99 | res = np.array(input_im.resize(resize_amount)) # set the input image 100 | res = np.concatenate([res, np.array(inversion.resize(resize_amount))], axis=1) # set the inversion 101 | for result in edit_res: 102 | res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1) 103 | res_im = Image.fromarray(res) 104 | all_edit_results.append(res_im) 105 | 106 | edit_save_dir = os.path.join(out_path_results, edit_name) 107 | os.makedirs(edit_save_dir, exist_ok=True) 108 | res_im.save(os.path.join(edit_save_dir, os.path.basename(im_path))) 109 | 110 | # save final concatenated result if all factor ranges are equal 111 | if opts.factor_ranges.count(opts.factor_ranges[0]) == len(opts.factor_ranges): 112 | coupled_res = np.concatenate(all_edit_results, axis=0) 113 | im_save_path = os.path.join(out_path_coupled, os.path.basename(im_path)) 114 | Image.fromarray(coupled_res).save(im_save_path) 115 | 116 | global_i += 1 117 | 118 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 119 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 120 | print(result_str) 121 | 122 | with open(stats_path, 'w') as f: 123 | f.write(result_str) 124 | 125 | 126 | def edit_batch(inputs, net, avg_image, latent_editor, opts): 127 | y_hat, latents = get_inversions_on_batch(inputs, net, avg_image, opts) 128 | # store all results for each sample, split by the edit direction 129 | results = {idx: {'inversion': tensor2im(y_hat[idx])} for idx in range(len(inputs))} 130 | for edit_direction, factor_range in zip(opts.edit_directions, opts.factor_ranges): 131 | edit_res = latent_editor.apply_interfacegan(latents=latents, 132 | direction=edit_direction, 133 | factor_range=(-1 * factor_range, factor_range)) 134 | # store the results for each sample 135 | for idx, sample_res in edit_res.items(): 136 | results[idx][edit_direction] = sample_res 137 | return results 138 | 139 | 140 | def get_inversions_on_batch(inputs, net, avg_image, opts): 141 | result_batch, result_latents = run_on_batch(inputs, net, opts, avg_image) 142 | # we'll take the final inversion as the inversion to edit 143 | y_hat = [result_batch[idx][-1] for idx in range(len(result_batch))] 144 | latents = [torch.from_numpy(result_latents[idx][-1]).cuda() for idx in range(len(result_batch))] 145 | return y_hat, torch.stack(latents) 146 | 147 | 148 | if __name__ == '__main__': 149 | run() -------------------------------------------------------------------------------- /restyle_encoder/editing/interfacegan_directions/age.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/editing/interfacegan_directions/age.pt -------------------------------------------------------------------------------- /restyle_encoder/editing/interfacegan_directions/pose.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/editing/interfacegan_directions/pose.pt -------------------------------------------------------------------------------- /restyle_encoder/editing/interfacegan_directions/smile.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/editing/interfacegan_directions/smile.pt -------------------------------------------------------------------------------- /restyle_encoder/editing/latent_editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.common import tensor2im 3 | 4 | 5 | class LatentEditor(object): 6 | 7 | def __init__(self, stylegan_generator): 8 | self.generator = stylegan_generator 9 | self.interfacegan_directions = { 10 | 'age': torch.load('editing/interfacegan_directions/age.pt').cuda(), 11 | 'smile': torch.load('editing/interfacegan_directions/smile.pt').cuda(), 12 | 'pose': torch.load('editing/interfacegan_directions/pose.pt').cuda() 13 | } 14 | 15 | def apply_interfacegan(self, latents, direction, factor=1, factor_range=None): 16 | edit_latents = [] 17 | direction = self.interfacegan_directions[direction] 18 | if factor_range is not None: # Apply a range of editing factors. 
for example, (-5, 5) 19 | for f in range(*factor_range): 20 | edit_latent = latents + f * direction 21 | edit_latents.append(edit_latent) 22 | edit_latents = torch.stack(edit_latents).transpose(0, 1) 23 | else: 24 | edit_latents = latents + factor * direction 25 | return self._latents_to_image(edit_latents) 26 | 27 | def _latents_to_image(self, all_latents): 28 | sample_results = {} 29 | with torch.no_grad(): 30 | for idx, sample_latents in enumerate(all_latents): 31 | images, _ = self.generator([sample_latents], randomize_noise=False, input_is_latent=True) 32 | sample_results[idx] = [tensor2im(image) for image in images] 33 | return sample_results 34 | -------------------------------------------------------------------------------- /restyle_encoder/environment/restyle_env.yaml: -------------------------------------------------------------------------------- 1 | name: restyle_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - ca-certificates=2020.4.5.1=hecc5488_0 8 | - certifi=2020.4.5.1=py36h9f0ad1d_0 9 | - libedit=3.1.20181209=hc058e9b_0 10 | - libffi=3.2.1=hd88cf55_4 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - ninja=1.10.0=hc9558a2_0 15 | - openssl=1.1.1g=h516909a_0 16 | - pip=20.0.2=py36_3 17 | - python=3.6.7=h0371630_0 18 | - python_abi=3.6=1_cp36m 19 | - readline=7.0=h7b6447c_5 20 | - setuptools=46.4.0=py36_0 21 | - sqlite=3.31.1=h62c20be_1 22 | - tk=8.6.8=hbc83047_0 23 | - wheel=0.34.2=py36_0 24 | - xz=5.2.5=h7b6447c_0 25 | - zlib=1.2.11=h7b6447c_3 26 | - pip: 27 | - scipy==1.4.1 28 | - matplotlib==3.2.1 29 | - tqdm==4.46.0 30 | - numpy==1.18.4 31 | - opencv-python==4.2.0.34 32 | - pillow==7.1.2 33 | - tensorboard==2.2.1 34 | - torch==1.6.0 35 | - torchvision==0.4.2 36 | prefix: ~/anaconda3/envs/restyle_env 37 | 38 | -------------------------------------------------------------------------------- /restyle_encoder/licenses/LICENSE_S-aiueo32: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /restyle_encoder/licenses/LICENSE_TreB1eN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TreB1eN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /restyle_encoder/licenses/LICENSE_eladrich: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /restyle_encoder/licenses/LICENSE_omertov: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 omertov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /restyle_encoder/licenses/LICENSE_rosinality: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /restyle_encoder/main.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import os 3 | import sys 4 | import torch 5 | import torchvision.transforms as transforms 6 | import dlib 7 | 8 | from models.psp import pSp 9 | from models.psp import pSp 10 | from scriptsLocal.align_faces_parallel import align_face 11 | from utils.inference_utils import run_batch_latent 12 | 13 | sys.path.append("./restyle_encoder") 14 | 15 | 16 | def run_alignment(image_path): 17 | if not os.path.exists("shape_predictor_68_face_landmarks.dat"): 18 | print('Downloading files for aligning face image...') 19 | os.system( 20 | 'wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2' 21 | ) 22 | os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2') 23 | print('Done.') 24 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 25 | aligned_image = align_face(filepath=image_path, predictor=predictor) 26 | print("Aligned image has shape: {}".format(aligned_image.size)) 27 | return aligned_image 28 | 29 | 30 | def get_avg_image(net): 31 | avg_image = net(net.latent_avg.unsqueeze(0), 32 | input_code=True, 33 | randomize_noise=False, 34 | return_latents=False, 35 | average_code=True)[0] 36 | avg_image = avg_image.to('cuda').float().detach() 37 | return avg_image 38 | 39 | 40 | def get_org_latent(image_path): 41 | model_path = "restyle_encoder/pretrained_models/restyle_psp_ffhq_encode.pt" 42 | transform = transforms.Compose([ 43 | transforms.Resize((256, 256)), 44 | transforms.ToTensor(), 45 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 46 | ]) 47 | with torch.no_grad(): 48 | ckpt = torch.load(model_path, map_location='cpu') 49 | opts = ckpt['opts'] 50 | opts['checkpoint_path'] = model_path 51 | opts = Namespace(**opts) 52 | net = pSp(opts).cuda().eval() 53 | input_image = run_alignment(image_path).convert('RGB') 54 | input_image.resize((256, 256)) 55 | transformed_image = transform(input_image) 56 | opts.n_iters_per_batch = 5 57 | opts.resize_outputs = False 58 | avg_image = get_avg_image(net) 59 | latents = run_batch_latent( 60 | transformed_image.unsqueeze(0).cuda(), net, opts, avg_image) 61 | return latents[0][4] # the inverted results from the last iteration 62 | -------------------------------------------------------------------------------- /restyle_encoder/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/models/e4e.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import math 5 | import torch 6 | from torch import nn 7 | 8 | from restyle_encoder.models.stylegan2.model import Generator 9 | from restyle_encoder.configs.paths_config import model_paths 10 | from restyle_encoder.models.encoders import restyle_e4e_encoders 11 | from restyle_encoder.utils.model_utils import RESNET_MAPPING 12 | 13 | 14 | class e4e(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(e4e, self).__init__() 18 | self.set_opts(opts) 19 | self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 20 | # Define architecture 21 | self.encoder 
= self.set_encoder() 22 | self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2) 23 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 24 | # Load weights if needed 25 | self.load_weights() 26 | 27 | def set_encoder(self): 28 | if self.opts.encoder_type == 'ProgressiveBackboneEncoder': 29 | encoder = restyle_e4e_encoders.ProgressiveBackboneEncoder(50, 'ir_se', self.n_styles, self.opts) 30 | elif self.opts.encoder_type == 'ResNetProgressiveBackboneEncoder': 31 | encoder = restyle_e4e_encoders.ResNetProgressiveBackboneEncoder(self.n_styles, self.opts) 32 | else: 33 | raise Exception(f'{self.opts.encoder_type} is not a valid encoders') 34 | return encoder 35 | 36 | def load_weights(self): 37 | if self.opts.checkpoint_path is not None: 38 | print(f'Loading ReStyle e4e from checkpoint: {self.opts.checkpoint_path}') 39 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 40 | self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False) 41 | self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True) 42 | self.__load_latent_avg(ckpt) 43 | else: 44 | encoder_ckpt = self.__get_encoder_checkpoint() 45 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 46 | print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}') 47 | ckpt = torch.load(self.opts.stylegan_weights) 48 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True) 49 | self.__load_latent_avg(ckpt, repeat=self.n_styles) 50 | 51 | def forward(self, x, latent=None, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 52 | inject_latent=None, return_latents=False, alpha=None, average_code=False, input_is_full=False): 53 | if input_code: 54 | codes = x 55 | else: 56 | codes = self.encoder(x) 57 | # residual step 58 | if x.shape[1] == 6 and latent is not None: 59 | # learn error with respect to previous iteration 60 | codes = codes + latent 61 | else: 62 | # first iteration is with respect to the avg latent code 63 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 64 | 65 | if latent_mask is not None: 66 | for i in latent_mask: 67 | if inject_latent is not None: 68 | if alpha is not None: 69 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 70 | else: 71 | codes[:, i] = inject_latent[:, i] 72 | else: 73 | codes[:, i] = 0 74 | 75 | if average_code: 76 | input_is_latent = True 77 | else: 78 | input_is_latent = (not input_code) or (input_is_full) 79 | 80 | images, result_latent = self.decoder([codes], 81 | input_is_latent=input_is_latent, 82 | randomize_noise=randomize_noise, 83 | return_latents=return_latents) 84 | 85 | if resize: 86 | images = self.face_pool(images) 87 | 88 | if return_latents: 89 | return images, result_latent 90 | else: 91 | return images 92 | 93 | def set_opts(self, opts): 94 | self.opts = opts 95 | 96 | def __load_latent_avg(self, ckpt, repeat=None): 97 | if 'latent_avg' in ckpt: 98 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 99 | if repeat is not None: 100 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 101 | else: 102 | self.latent_avg = None 103 | 104 | def __get_encoder_checkpoint(self): 105 | if "ffhq" in self.opts.dataset_type: 106 | print('Loading encoders weights from irse50!') 107 | encoder_ckpt = torch.load(model_paths['ir_se50']) 108 | # Transfer the RGB input of the irse50 network to the first 3 input channels of pSp's encoder 109 | if self.opts.input_nc != 3: 110 | shape = encoder_ckpt['input_layer.0.weight'].shape 111 | 
altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 112 | altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight'] 113 | encoder_ckpt['input_layer.0.weight'] = altered_input_layer 114 | return encoder_ckpt 115 | else: 116 | print('Loading encoders weights from resnet34!') 117 | encoder_ckpt = torch.load(model_paths['resnet34']) 118 | # Transfer the RGB input of the resnet34 network to the first 3 input channels of pSp's encoder 119 | if self.opts.input_nc != 3: 120 | shape = encoder_ckpt['conv1.weight'].shape 121 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 122 | altered_input_layer[:, :3, :, :] = encoder_ckpt['conv1.weight'] 123 | encoder_ckpt['conv1.weight'] = altered_input_layer 124 | mapped_encoder_ckpt = dict(encoder_ckpt) 125 | for p, v in encoder_ckpt.items(): 126 | for original_name, psp_name in RESNET_MAPPING.items(): 127 | if original_name in p: 128 | mapped_encoder_ckpt[p.replace(original_name, psp_name)] = v 129 | mapped_encoder_ckpt.pop(p) 130 | return encoder_ckpt 131 | 132 | @staticmethod 133 | def __get_keys(d, name): 134 | if 'state_dict' in d: 135 | d = d['state_dict'] 136 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 137 | return d_filt 138 | -------------------------------------------------------------------------------- /restyle_encoder/models/e4e_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/e4e_modules/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/models/e4e_modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LatentCodesDiscriminator(nn.Module): 5 | def __init__(self, style_dim, n_mlp): 6 | super().__init__() 7 | 8 | self.style_dim = style_dim 9 | 10 | layers = [] 11 | for i in range(n_mlp-1): 12 | layers.append( 13 | nn.Linear(style_dim, style_dim) 14 | ) 15 | layers.append(nn.LeakyReLU(0.2)) 16 | layers.append(nn.Linear(512, 1)) 17 | self.mlp = nn.Sequential(*layers) 18 | 19 | def forward(self, w): 20 | return self.mlp(w) 21 | -------------------------------------------------------------------------------- /restyle_encoder/models/e4e_modules/latent_codes_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class LatentCodesPool: 6 | """This class implements latent codes buffer that stores previously generated w latent codes. 7 | This buffer enables us to update discriminators using a history of generated w's 8 | rather than the ones produced by the latest encoder. 9 | """ 10 | 11 | def __init__(self, pool_size): 12 | """Initialize the ImagePool class 13 | Parameters: 14 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 15 | """ 16 | self.pool_size = pool_size 17 | if self.pool_size > 0: # create an empty pool 18 | self.num_ws = 0 19 | self.ws = [] 20 | 21 | def query(self, ws): 22 | """Return w's from the pool. 23 | Parameters: 24 | ws: the latest generated w's from the generator 25 | Returns w's from the buffer. 26 | By 50/100, the buffer will return input w's. 
27 | By 50/100, the buffer will return w's previously stored in the buffer, 28 | and insert the current w's to the buffer. 29 | """ 30 | if self.pool_size == 0: # if the buffer size is 0, do nothing 31 | return ws 32 | return_ws = [] 33 | for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) 34 | # w = torch.unsqueeze(image.data, 0) 35 | if w.ndim == 2: 36 | i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate 37 | w = w[i] 38 | self.handle_w(w, return_ws) 39 | return_ws = torch.stack(return_ws, 0) # collect all the images and return 40 | return return_ws 41 | 42 | def handle_w(self, w, return_ws): 43 | if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer 44 | self.num_ws = self.num_ws + 1 45 | self.ws.append(w) 46 | return_ws.append(w) 47 | else: 48 | p = random.uniform(0, 1) 49 | if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer 50 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 51 | tmp = self.ws[random_id].clone() 52 | self.ws[random_id] = w 53 | return_ws.append(tmp) 54 | else: # by another 50% chance, the buffer will return the current image 55 | return_ws.append(w) 56 | -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/encoders/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/fpn_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 5 | from torchvision.models.resnet import resnet34 6 | 7 | from restyle_encoder.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 8 | from restyle_encoder.models.encoders.map2style import GradualStyleBlock 9 | 10 | 11 | class GradualStyleEncoder(Module): 12 | """ 13 | Original encoder architecture from pixel2style2pixel. This classes uses an FPN-based architecture applied over 14 | an ResNet IRSE-50 backbone. 15 | Note this class is designed to be used for the human facial domain. 
16 | """ 17 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 18 | super(GradualStyleEncoder, self).__init__() 19 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 20 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 21 | blocks = get_blocks(num_layers) 22 | if mode == 'ir': 23 | unit_module = bottleneck_IR 24 | elif mode == 'ir_se': 25 | unit_module = bottleneck_IR_SE 26 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | self.coarse_ind = 3 40 | self.middle_ind = 7 41 | for i in range(self.style_count): 42 | if i < self.coarse_ind: 43 | style = GradualStyleBlock(512, 512, 16) 44 | elif i < self.middle_ind: 45 | style = GradualStyleBlock(512, 512, 32) 46 | else: 47 | style = GradualStyleBlock(512, 512, 64) 48 | self.styles.append(style) 49 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 50 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 51 | 52 | def _upsample_add(self, x, y): 53 | _, _, H, W = y.size() 54 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 55 | 56 | def forward(self, x): 57 | x = self.input_layer(x) 58 | 59 | latents = [] 60 | modulelist = list(self.body._modules.values()) 61 | for i, l in enumerate(modulelist): 62 | x = l(x) 63 | if i == 6: 64 | c1 = x 65 | elif i == 20: 66 | c2 = x 67 | elif i == 23: 68 | c3 = x 69 | 70 | for j in range(self.coarse_ind): 71 | latents.append(self.styles[j](c3)) 72 | 73 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 74 | for j in range(self.coarse_ind, self.middle_ind): 75 | latents.append(self.styles[j](p2)) 76 | 77 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 78 | for j in range(self.middle_ind, self.style_count): 79 | latents.append(self.styles[j](p1)) 80 | 81 | out = torch.stack(latents, dim=1) 82 | return out 83 | 84 | 85 | class ResNetGradualStyleEncoder(Module): 86 | """ 87 | Original encoder architecture from pixel2style2pixel. This classes uses an FPN-based architecture applied over 88 | an ResNet34 backbone. 
89 | """ 90 | def __init__(self, n_styles=18, opts=None): 91 | super(ResNetGradualStyleEncoder, self).__init__() 92 | 93 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 94 | self.bn1 = BatchNorm2d(64) 95 | self.relu = PReLU(64) 96 | 97 | resnet_basenet = resnet34(pretrained=True) 98 | blocks = [ 99 | resnet_basenet.layer1, 100 | resnet_basenet.layer2, 101 | resnet_basenet.layer3, 102 | resnet_basenet.layer4 103 | ] 104 | 105 | modules = [] 106 | for block in blocks: 107 | for bottleneck in block: 108 | modules.append(bottleneck) 109 | 110 | self.body = Sequential(*modules) 111 | 112 | self.styles = nn.ModuleList() 113 | self.style_count = n_styles 114 | self.coarse_ind = 3 115 | self.middle_ind = 7 116 | for i in range(self.style_count): 117 | if i < self.coarse_ind: 118 | style = GradualStyleBlock(512, 512, 16) 119 | elif i < self.middle_ind: 120 | style = GradualStyleBlock(512, 512, 32) 121 | else: 122 | style = GradualStyleBlock(512, 512, 64) 123 | self.styles.append(style) 124 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 125 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 126 | 127 | def _upsample_add(self, x, y): 128 | _, _, H, W = y.size() 129 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | 136 | latents = [] 137 | modulelist = list(self.body._modules.values()) 138 | for i, l in enumerate(modulelist): 139 | x = l(x) 140 | if i == 6: 141 | c1 = x 142 | elif i == 12: 143 | c2 = x 144 | elif i == 15: 145 | c3 = x 146 | 147 | for j in range(self.coarse_ind): 148 | latents.append(self.styles[j](c3)) 149 | 150 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 151 | for j in range(self.coarse_ind, self.middle_ind): 152 | latents.append(self.styles[j](p2)) 153 | 154 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 155 | for j in range(self.middle_ind, self.style_count): 156 | latents.append(self.styles[j](p1)) 157 | 158 | out = torch.stack(latents, dim=1) 159 | return out 160 | -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. 
""" 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/map2style.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from 
torch.nn import Conv2d, Module 4 | 5 | from restyle_encoder.models.stylegan2.model import EqualLinear 6 | 7 | 8 | class GradualStyleBlock(Module): 9 | def __init__(self, in_c, out_c, spatial): 10 | super(GradualStyleBlock, self).__init__() 11 | self.out_c = out_c 12 | self.spatial = spatial 13 | num_pools = int(np.log2(spatial)) 14 | modules = [] 15 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 16 | nn.LeakyReLU()] 17 | for i in range(num_pools - 1): 18 | modules += [ 19 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 20 | nn.LeakyReLU() 21 | ] 22 | self.convs = nn.Sequential(*modules) 23 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 24 | 25 | def forward(self, x): 26 | x = self.convs(x) 27 | x = x.view(-1, self.out_c) 28 | x = self.linear(x) 29 | return x 30 | -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 
77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/restyle_e4e_encoders.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from torch import nn 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | from torchvision.models import resnet34 5 | 6 | from restyle_encoder.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 7 | from restyle_encoder.models.encoders.map2style import GradualStyleBlock 8 | 9 | 10 | class ProgressiveStage(Enum): 11 | WTraining = 0 12 | Delta1Training = 1 13 | Delta2Training = 2 14 | Delta3Training = 3 15 | Delta4Training = 4 16 | Delta5Training = 5 17 | Delta6Training = 6 18 | Delta7Training = 7 19 | Delta8Training = 8 20 | Delta9Training = 9 21 | Delta10Training = 10 22 | Delta11Training = 11 23 | Delta12Training = 12 24 | Delta13Training = 13 25 | Delta14Training = 14 26 | Delta15Training = 15 27 | Delta16Training = 16 28 | Delta17Training = 17 29 | Inference = 18 30 | 31 | 32 | class ProgressiveBackboneEncoder(Module): 33 | """ 34 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 35 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE50 backbone with the 36 | progressive training scheme from e4e_modules. 37 | Note this class is designed to be used for the human facial domain. 
38 | """ 39 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 40 | super(ProgressiveBackboneEncoder, self).__init__() 41 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 42 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 43 | blocks = get_blocks(num_layers) 44 | if mode == 'ir': 45 | unit_module = bottleneck_IR 46 | elif mode == 'ir_se': 47 | unit_module = bottleneck_IR_SE 48 | 49 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 50 | BatchNorm2d(64), 51 | PReLU(64)) 52 | modules = [] 53 | for block in blocks: 54 | for bottleneck in block: 55 | modules.append(unit_module(bottleneck.in_channel, 56 | bottleneck.depth, 57 | bottleneck.stride)) 58 | self.body = Sequential(*modules) 59 | 60 | self.styles = nn.ModuleList() 61 | self.style_count = n_styles 62 | for i in range(self.style_count): 63 | style = GradualStyleBlock(512, 512, 16) 64 | self.styles.append(style) 65 | self.progressive_stage = ProgressiveStage.Inference 66 | 67 | def get_deltas_starting_dimensions(self): 68 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 69 | return list(range(self.style_count)) # Each dimension has a delta applied to 70 | 71 | def set_progressive_stage(self, new_stage: ProgressiveStage): 72 | # In this encoder we train all the pyramid (At least as a first stage experiment 73 | self.progressive_stage = new_stage 74 | print('Changed progressive stage to: ', new_stage) 75 | 76 | def forward(self, x): 77 | x = self.input_layer(x) 78 | x = self.body(x) 79 | 80 | # get initial w0 from first map2style layer 81 | w0 = self.styles[0](x) 82 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 83 | 84 | # learn the deltas up to the current stage 85 | stage = self.progressive_stage.value 86 | for i in range(1, min(stage + 1, self.style_count)): 87 | delta_i = self.styles[i](x) 88 | w[:, i] += delta_i 89 | return w 90 | 91 | 92 | class ResNetProgressiveBackboneEncoder(Module): 93 | """ 94 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 95 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone with the 96 | progressive training scheme from e4e_modules. 
97 | """ 98 | def __init__(self, n_styles=18, opts=None): 99 | super(ResNetProgressiveBackboneEncoder, self).__init__() 100 | 101 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 102 | self.bn1 = BatchNorm2d(64) 103 | self.relu = PReLU(64) 104 | 105 | resnet_basenet = resnet34(pretrained=True) 106 | blocks = [ 107 | resnet_basenet.layer1, 108 | resnet_basenet.layer2, 109 | resnet_basenet.layer3, 110 | resnet_basenet.layer4 111 | ] 112 | modules = [] 113 | for block in blocks: 114 | for bottleneck in block: 115 | modules.append(bottleneck) 116 | self.body = Sequential(*modules) 117 | 118 | self.styles = nn.ModuleList() 119 | self.style_count = n_styles 120 | for i in range(self.style_count): 121 | style = GradualStyleBlock(512, 512, 16) 122 | self.styles.append(style) 123 | self.progressive_stage = ProgressiveStage.Inference 124 | 125 | def get_deltas_starting_dimensions(self): 126 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 127 | return list(range(self.style_count)) # Each dimension has a delta applied to 128 | 129 | def set_progressive_stage(self, new_stage: ProgressiveStage): 130 | # In this encoder we train all the pyramid (At least as a first stage experiment 131 | self.progressive_stage = new_stage 132 | print('Changed progressive stage to: ', new_stage) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | x = self.bn1(x) 137 | x = self.relu(x) 138 | x = self.body(x) 139 | 140 | # get initial w0 from first map2style layer 141 | w0 = self.styles[0](x) 142 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 143 | 144 | # learn the deltas up to the current stage 145 | stage = self.progressive_stage.value 146 | for i in range(1, min(stage + 1, self.style_count)): 147 | delta_i = self.styles[i](x) 148 | w[:, i] += delta_i 149 | return w 150 | -------------------------------------------------------------------------------- /restyle_encoder/models/encoders/restyle_psp_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | from torchvision.models.resnet import resnet34 5 | 6 | from restyle_encoder.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 7 | from restyle_encoder.models.encoders.map2style import GradualStyleBlock 8 | 9 | 10 | class BackboneEncoder(Module): 11 | """ 12 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 13 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE-50 backbone. 14 | Note this class is designed to be used for the human facial domain. 
15 | """ 16 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 17 | super(BackboneEncoder, self).__init__() 18 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 19 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 20 | blocks = get_blocks(num_layers) 21 | if mode == 'ir': 22 | unit_module = bottleneck_IR 23 | elif mode == 'ir_se': 24 | unit_module = bottleneck_IR_SE 25 | 26 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | for i in range(self.style_count): 40 | style = GradualStyleBlock(512, 512, 16) 41 | self.styles.append(style) 42 | 43 | def forward(self, x): 44 | x = self.input_layer(x) 45 | x = self.body(x) 46 | latents = [] 47 | for j in range(self.style_count): 48 | latents.append(self.styles[j](x)) 49 | out = torch.stack(latents, dim=1) 50 | return out 51 | 52 | 53 | class ResNetBackboneEncoder(Module): 54 | """ 55 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 56 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone. 57 | """ 58 | def __init__(self, n_styles=18, opts=None): 59 | super(ResNetBackboneEncoder, self).__init__() 60 | 61 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 62 | self.bn1 = BatchNorm2d(64) 63 | self.relu = PReLU(64) 64 | 65 | resnet_basenet = resnet34(pretrained=True) 66 | blocks = [ 67 | resnet_basenet.layer1, 68 | resnet_basenet.layer2, 69 | resnet_basenet.layer3, 70 | resnet_basenet.layer4 71 | ] 72 | modules = [] 73 | for block in blocks: 74 | for bottleneck in block: 75 | modules.append(bottleneck) 76 | self.body = Sequential(*modules) 77 | 78 | self.styles = nn.ModuleList() 79 | self.style_count = n_styles 80 | for i in range(self.style_count): 81 | style = GradualStyleBlock(512, 512, 16) 82 | self.styles.append(style) 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = self.bn1(x) 87 | x = self.relu(x) 88 | x = self.body(x) 89 | latents = [] 90 | for j in range(self.style_count): 91 | latents.append(self.styles[j](x)) 92 | out = torch.stack(latents, dim=1) 93 | return out 94 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/mtcnn/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 5 | from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 7 | from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face 8 | 9 | device = 
'cuda:0' 10 | 11 | 12 | class MTCNN(): 13 | def __init__(self): 14 | print(device) 15 | self.pnet = PNet().to(device) 16 | self.rnet = RNet().to(device) 17 | self.onet = ONet().to(device) 18 | self.pnet.eval() 19 | self.rnet.eval() 20 | self.onet.eval() 21 | self.refrence = get_reference_facial_points(default_square=True) 22 | 23 | def align(self, img): 24 | _, landmarks = self.detect_faces(img) 25 | if len(landmarks) == 0: 26 | return None, None 27 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 28 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 29 | return Image.fromarray(warped_face), tfm 30 | 31 | def align_multi(self, img, limit=None, min_face_size=30.0): 32 | boxes, landmarks = self.detect_faces(img, min_face_size) 33 | if limit: 34 | boxes = boxes[:limit] 35 | landmarks = landmarks[:limit] 36 | faces = [] 37 | tfms = [] 38 | for landmark in landmarks: 39 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 40 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 41 | faces.append(Image.fromarray(warped_face)) 42 | tfms.append(tfm) 43 | return boxes, faces, tfms 44 | 45 | def detect_faces(self, image, min_face_size=20.0, 46 | thresholds=[0.15, 0.25, 0.35], 47 | nms_thresholds=[0.7, 0.7, 0.7]): 48 | """ 49 | Arguments: 50 | image: an instance of PIL.Image. 51 | min_face_size: a float number. 52 | thresholds: a list of length 3. 53 | nms_thresholds: a list of length 3. 54 | 55 | Returns: 56 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 57 | bounding boxes and facial landmarks. 58 | """ 59 | 60 | # BUILD AN IMAGE PYRAMID 61 | width, height = image.size 62 | min_length = min(height, width) 63 | 64 | min_detection_size = 12 65 | factor = 0.707 # sqrt(0.5) 66 | 67 | # scales for scaling the image 68 | scales = [] 69 | 70 | # scales the image so that 71 | # minimum size that we can detect equals to 72 | # minimum face size that we want to detect 73 | m = min_detection_size / min_face_size 74 | min_length *= m 75 | 76 | factor_count = 0 77 | while min_length > min_detection_size: 78 | scales.append(m * factor ** factor_count) 79 | min_length *= factor 80 | factor_count += 1 81 | 82 | # STAGE 1 83 | 84 | # it will be returned 85 | bounding_boxes = [] 86 | 87 | with torch.no_grad(): 88 | # run P-Net on different scales 89 | for s in scales: 90 | boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) 91 | bounding_boxes.append(boxes) 92 | 93 | # collect boxes (and offsets, and scores) from different scales 94 | bounding_boxes = [i for i in bounding_boxes if i is not None] 95 | bounding_boxes = np.vstack(bounding_boxes) 96 | 97 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 98 | bounding_boxes = bounding_boxes[keep] 99 | 100 | # use offsets predicted by pnet to transform bounding boxes 101 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 102 | # shape [n_boxes, 5] 103 | 104 | bounding_boxes = convert_to_square(bounding_boxes) 105 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 106 | 107 | # STAGE 2 108 | 109 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 110 | img_boxes = torch.FloatTensor(img_boxes).to(device) 111 | 112 | output = self.rnet(img_boxes) 113 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 114 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 115 | 116 | keep = np.where(probs[:, 1] > 
thresholds[1])[0] 117 | bounding_boxes = bounding_boxes[keep] 118 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 119 | offsets = offsets[keep] 120 | 121 | keep = nms(bounding_boxes, nms_thresholds[1]) 122 | bounding_boxes = bounding_boxes[keep] 123 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 124 | bounding_boxes = convert_to_square(bounding_boxes) 125 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 126 | 127 | # STAGE 3 128 | 129 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 130 | if len(img_boxes) == 0: 131 | return [], [] 132 | img_boxes = torch.FloatTensor(img_boxes).to(device) 133 | output = self.onet(img_boxes) 134 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 135 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 136 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 137 | 138 | keep = np.where(probs[:, 1] > thresholds[2])[0] 139 | bounding_boxes = bounding_boxes[keep] 140 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 141 | offsets = offsets[keep] 142 | landmarks = landmarks[keep] 143 | 144 | # compute landmark points 145 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 146 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 147 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 148 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 149 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 150 | 151 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 152 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 153 | bounding_boxes = bounding_boxes[keep] 154 | landmarks = landmarks[keep] 155 | 156 | return bounding_boxes, landmarks 157 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .get_nets import PNet, RNet, ONet 4 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 5 | from .first_stage import run_first_stage 6 | 7 | 8 | def detect_faces(image, min_face_size=20.0, 9 | thresholds=[0.6, 0.7, 0.8], 10 | nms_thresholds=[0.7, 0.7, 0.7]): 11 | """ 12 | Arguments: 13 | image: an instance of PIL.Image. 14 | min_face_size: a float number. 15 | thresholds: a list of length 3. 16 | nms_thresholds: a list of length 3. 17 | 18 | Returns: 19 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 20 | bounding boxes and facial landmarks. 
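    The implementation below follows the standard three-stage MTCNN cascade: P-Net is run over an
    image pyramid to propose candidate windows, which are then refined and filtered by R-Net and
    O-Net using the per-stage score thresholds and NMS overlap thresholds described above.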
21 | """ 22 | 23 | # LOAD MODELS 24 | pnet = PNet() 25 | rnet = RNet() 26 | onet = ONet() 27 | onet.eval() 28 | 29 | # BUILD AN IMAGE PYRAMID 30 | width, height = image.size 31 | min_length = min(height, width) 32 | 33 | min_detection_size = 12 34 | factor = 0.707 # sqrt(0.5) 35 | 36 | # scales for scaling the image 37 | scales = [] 38 | 39 | # scales the image so that 40 | # minimum size that we can detect equals to 41 | # minimum face size that we want to detect 42 | m = min_detection_size / min_face_size 43 | min_length *= m 44 | 45 | factor_count = 0 46 | while min_length > min_detection_size: 47 | scales.append(m * factor ** factor_count) 48 | min_length *= factor 49 | factor_count += 1 50 | 51 | # STAGE 1 52 | 53 | # it will be returned 54 | bounding_boxes = [] 55 | 56 | with torch.no_grad(): 57 | # run P-Net on different scales 58 | for s in scales: 59 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 60 | bounding_boxes.append(boxes) 61 | 62 | # collect boxes (and offsets, and scores) from different scales 63 | bounding_boxes = [i for i in bounding_boxes if i is not None] 64 | bounding_boxes = np.vstack(bounding_boxes) 65 | 66 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 67 | bounding_boxes = bounding_boxes[keep] 68 | 69 | # use offsets predicted by pnet to transform bounding boxes 70 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 71 | # shape [n_boxes, 5] 72 | 73 | bounding_boxes = convert_to_square(bounding_boxes) 74 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 75 | 76 | # STAGE 2 77 | 78 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 79 | img_boxes = torch.FloatTensor(img_boxes) 80 | 81 | output = rnet(img_boxes) 82 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 83 | probs = output[1].data.numpy() # shape [n_boxes, 2] 84 | 85 | keep = np.where(probs[:, 1] > thresholds[1])[0] 86 | bounding_boxes = bounding_boxes[keep] 87 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 88 | offsets = offsets[keep] 89 | 90 | keep = nms(bounding_boxes, nms_thresholds[1]) 91 | bounding_boxes = bounding_boxes[keep] 92 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 93 | bounding_boxes = convert_to_square(bounding_boxes) 94 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 95 | 96 | # STAGE 3 97 | 98 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 99 | if len(img_boxes) == 0: 100 | return [], [] 101 | img_boxes = torch.FloatTensor(img_boxes) 102 | output = onet(img_boxes) 103 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 104 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 105 | probs = output[2].data.numpy() # shape [n_boxes, 2] 106 | 107 | keep = np.where(probs[:, 1] > thresholds[2])[0] 108 | bounding_boxes = bounding_boxes[keep] 109 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 110 | offsets = offsets[keep] 111 | landmarks = landmarks[keep] 112 | 113 | # compute landmark points 114 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 115 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 116 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 117 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 118 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 119 | 120 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 121 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 122 | bounding_boxes = 
bounding_boxes[keep] 123 | landmarks = landmarks[keep] 124 | 125 | return bounding_boxes, landmarks 126 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from PIL import Image 4 | import numpy as np 5 | from .box_utils import nms, _preprocess 6 | 7 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | device = 'cuda:0' 9 | 10 | 11 | def run_first_stage(image, net, scale, threshold): 12 | """Run P-Net, generate bounding boxes, and do NMS. 13 | 14 | Arguments: 15 | image: an instance of PIL.Image. 16 | net: an instance of pytorch's nn.Module, P-Net. 17 | scale: a float number, 18 | scale width and height of the image by this number. 19 | threshold: a float number, 20 | threshold on the probability of a face when generating 21 | bounding boxes from predictions of the net. 22 | 23 | Returns: 24 | a float numpy array of shape [n_boxes, 9], 25 | bounding boxes with scores and offsets (4 + 1 + 4). 26 | """ 27 | 28 | # scale the image and convert it to a float array 29 | width, height = image.size 30 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 31 | img = image.resize((sw, sh), Image.BILINEAR) 32 | img = np.asarray(img, 'float32') 33 | 34 | img = torch.FloatTensor(_preprocess(img)).to(device) 35 | with torch.no_grad(): 36 | output = net(img) 37 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 38 | offsets = output[0].cpu().data.numpy() 39 | # probs: probability of a face at each sliding window 40 | # offsets: transformations to true bounding boxes 41 | 42 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 43 | if len(boxes) == 0: 44 | return None 45 | 46 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 47 | return boxes[keep] 48 | 49 | 50 | def _generate_bboxes(probs, offsets, scale, threshold): 51 | """Generate bounding boxes at places 52 | where there is probably a face. 53 | 54 | Arguments: 55 | probs: a float numpy array of shape [n, m]. 56 | offsets: a float numpy array of shape [1, 4, n, m]. 57 | scale: a float number, 58 | width and height of the image were scaled by this number. 59 | threshold: a float number. 60 | 61 | Returns: 62 | a float numpy array of shape [n_boxes, 9] 63 | """ 64 | 65 | # applying P-Net is equivalent, in some sense, to 66 | # moving 12x12 window with stride 2 67 | stride = 2 68 | cell_size = 12 69 | 70 | # indices of boxes where there is probably a face 71 | inds = np.where(probs > threshold) 72 | 73 | if inds[0].size == 0: 74 | return np.array([]) 75 | 76 | # transformations of bounding boxes 77 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 78 | # they are defined as: 79 | # w = x2 - x1 + 1 80 | # h = y2 - y1 + 1 81 | # x1_true = x1 + tx1*w 82 | # x2_true = x2 + tx2*w 83 | # y1_true = y1 + ty1*h 84 | # y2_true = y2 + ty2*h 85 | 86 | offsets = np.array([tx1, ty1, tx2, ty2]) 87 | score = probs[inds[0], inds[1]] 88 | 89 | # P-Net is applied to scaled images 90 | # so we need to rescale bounding boxes back 91 | bounding_boxes = np.vstack([ 92 | np.round((stride * inds[1] + 1.0) / scale), 93 | np.round((stride * inds[0] + 1.0) / scale), 94 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 95 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 96 | score, offsets 97 | ]) 98 | # why one is added? 
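    # (Editor's note, hedged: the +1.0 offsets most likely come from the original MTCNN reference
    #  implementation, which used MATLAB-style 1-based pixel coordinates; they are kept here so the
    #  generated boxes stay consistent with the convention the pretrained weights expect.)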
99 | 100 | return bounding_boxes.T 101 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | from configs.paths_config import model_paths 8 | PNET_PATH = model_paths["mtcnn_pnet"] 9 | ONET_PATH = model_paths["mtcnn_onet"] 10 | RNET_PATH = model_paths["mtcnn_rnet"] 11 | 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w]. 24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | 37 | # suppose we have input with size HxW, then 38 | # after first layer: H - 2, 39 | # after pool: ceil((H - 2)/2), 40 | # after second conv: ceil((H - 2)/2) - 2, 41 | # after last conv: ceil((H - 2)/2) - 4, 42 | # and the same for W 43 | 44 | self.features = nn.Sequential(OrderedDict([ 45 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 46 | ('prelu1', nn.PReLU(10)), 47 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 48 | 49 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 50 | ('prelu2', nn.PReLU(16)), 51 | 52 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 53 | ('prelu3', nn.PReLU(32)) 54 | ])) 55 | 56 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 57 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 58 | 59 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 60 | for n, p in self.named_parameters(): 61 | p.data = torch.FloatTensor(weights[n]) 62 | 63 | def forward(self, x): 64 | """ 65 | Arguments: 66 | x: a float tensor with shape [batch_size, 3, h, w]. 67 | Returns: 68 | b: a float tensor with shape [batch_size, 4, h', w']. 69 | a: a float tensor with shape [batch_size, 2, h', w']. 70 | """ 71 | x = self.features(x) 72 | a = self.conv4_1(x) 73 | b = self.conv4_2(x) 74 | a = F.softmax(a, dim=-1) 75 | return b, a 76 | 77 | 78 | class RNet(nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential(OrderedDict([ 84 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 85 | ('prelu1', nn.PReLU(28)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 87 | 88 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 89 | ('prelu2', nn.PReLU(48)), 90 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 91 | 92 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 93 | ('prelu3', nn.PReLU(64)), 94 | 95 | ('flatten', Flatten()), 96 | ('conv4', nn.Linear(576, 128)), 97 | ('prelu4', nn.PReLU(128)) 98 | ])) 99 | 100 | self.conv5_1 = nn.Linear(128, 2) 101 | self.conv5_2 = nn.Linear(128, 4) 102 | 103 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 104 | for n, p in self.named_parameters(): 105 | p.data = torch.FloatTensor(weights[n]) 106 | 107 | def forward(self, x): 108 | """ 109 | Arguments: 110 | x: a float tensor with shape [batch_size, 3, h, w]. 111 | Returns: 112 | b: a float tensor with shape [batch_size, 4]. 113 | a: a float tensor with shape [batch_size, 2]. 
114 | """ 115 | x = self.features(x) 116 | a = self.conv5_1(x) 117 | b = self.conv5_2(x) 118 | a = F.softmax(a, dim=-1) 119 | return b, a 120 | 121 | 122 | class ONet(nn.Module): 123 | 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 129 | ('prelu1', nn.PReLU(32)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 131 | 132 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 133 | ('prelu2', nn.PReLU(64)), 134 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 135 | 136 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 137 | ('prelu3', nn.PReLU(64)), 138 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | 140 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 141 | ('prelu4', nn.PReLU(128)), 142 | 143 | ('flatten', Flatten()), 144 | ('conv5', nn.Linear(1152, 256)), 145 | ('drop5', nn.Dropout(0.25)), 146 | ('prelu5', nn.PReLU(256)), 147 | ])) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /restyle_encoder/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /restyle_encoder/models/psp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import math 5 | import torch 6 | from torch import nn 7 | 8 | from restyle_encoder.models.stylegan2.model import Generator 9 | from restyle_encoder.configs.paths_config import model_paths 10 | from restyle_encoder.models.encoders import fpn_encoders, restyle_psp_encoders 11 | from restyle_encoder.utils.model_utils import RESNET_MAPPING 12 | 13 | 14 | class pSp(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(pSp, self).__init__() 18 | self.set_opts(opts) 19 | self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 20 | # Define architecture 21 | self.encoder = self.set_encoder() 22 | self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2) 23 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 24 | # Load weights if needed 25 | self.load_weights() 26 | 27 | def set_encoder(self): 28 | if self.opts.encoder_type == 'GradualStyleEncoder': 29 | encoder = fpn_encoders.GradualStyleEncoder(50, 'ir_se', self.n_styles, self.opts) 30 | elif self.opts.encoder_type == 'ResNetGradualStyleEncoder': 31 | encoder = fpn_encoders.ResNetGradualStyleEncoder(self.n_styles, self.opts) 32 | elif self.opts.encoder_type == 'BackboneEncoder': 33 | encoder = restyle_psp_encoders.BackboneEncoder(50, 'ir_se', self.n_styles, self.opts) 34 | elif self.opts.encoder_type == 'ResNetBackboneEncoder': 35 | encoder = restyle_psp_encoders.ResNetBackboneEncoder(self.n_styles, self.opts) 36 | else: 37 | raise Exception(f'{self.opts.encoder_type} is not a valid encoders') 38 | return encoder 39 | 40 | def load_weights(self): 41 | if self.opts.checkpoint_path is not None: 42 | print(f'Loading ReStyle pSp from checkpoint: {self.opts.checkpoint_path}') 43 | ckpt = 
torch.load(self.opts.checkpoint_path, map_location='cpu') 44 | self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False) 45 | self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True) 46 | self.__load_latent_avg(ckpt) 47 | else: 48 | encoder_ckpt = self.__get_encoder_checkpoint() 49 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 50 | print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}') 51 | ckpt = torch.load(self.opts.stylegan_weights) 52 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True) 53 | self.__load_latent_avg(ckpt, repeat=self.n_styles) 54 | 55 | def forward(self, x, latent=None, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 56 | inject_latent=None, return_latents=False, alpha=None, average_code=False, input_is_full=False): 57 | if input_code: 58 | codes = x 59 | else: 60 | codes = self.encoder(x) 61 | # residual step 62 | if x.shape[1] == 6 and latent is not None: 63 | # learn error with respect to previous iteration 64 | codes = codes + latent 65 | else: 66 | # first iteration is with respect to the avg latent code 67 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 68 | 69 | if latent_mask is not None: 70 | for i in latent_mask: 71 | if inject_latent is not None: 72 | if alpha is not None: 73 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 74 | else: 75 | codes[:, i] = inject_latent[:, i] 76 | else: 77 | codes[:, i] = 0 78 | 79 | if average_code: 80 | input_is_latent = True 81 | else: 82 | input_is_latent = (not input_code) or (input_is_full) 83 | 84 | images, result_latent = self.decoder([codes], 85 | input_is_latent=input_is_latent, 86 | randomize_noise=randomize_noise, 87 | return_latents=return_latents) 88 | 89 | if resize: 90 | images = self.face_pool(images) 91 | 92 | if return_latents: 93 | return images, result_latent 94 | else: 95 | return images 96 | 97 | def set_opts(self, opts): 98 | self.opts = opts 99 | 100 | def __load_latent_avg(self, ckpt, repeat=None): 101 | if 'latent_avg' in ckpt: 102 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 103 | if repeat is not None: 104 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 105 | else: 106 | self.latent_avg = None 107 | 108 | def __get_encoder_checkpoint(self): 109 | if "ffhq" in self.opts.dataset_type: 110 | print('Loading encoders weights from irse50!') 111 | encoder_ckpt = torch.load(model_paths['ir_se50']) 112 | # Transfer the RGB input of the irse50 network to the first 3 input channels of pSp's encoder 113 | if self.opts.input_nc != 3: 114 | shape = encoder_ckpt['input_layer.0.weight'].shape 115 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 116 | altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight'] 117 | encoder_ckpt['input_layer.0.weight'] = altered_input_layer 118 | return encoder_ckpt 119 | else: 120 | print('Loading encoders weights from resnet34!') 121 | encoder_ckpt = torch.load(model_paths['resnet34']) 122 | # Transfer the RGB input of the resnet34 network to the first 3 input channels of pSp's encoder 123 | if self.opts.input_nc != 3: 124 | shape = encoder_ckpt['conv1.weight'].shape 125 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 126 | altered_input_layer[:, :3, :, :] = encoder_ckpt['conv1.weight'] 127 | encoder_ckpt['conv1.weight'] = altered_input_layer 128 | mapped_encoder_ckpt = dict(encoder_ckpt) 
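            # Build a copy of the checkpoint with the torchvision ResNet34 parameter names translated to the
            # pSp encoder's naming scheme via RESNET_MAPPING: each matching key is re-inserted under its new
            # name and the old key is dropped.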
129 | for p, v in encoder_ckpt.items(): 130 | for original_name, psp_name in RESNET_MAPPING.items(): 131 | if original_name in p: 132 | mapped_encoder_ckpt[p.replace(original_name, psp_name)] = v 133 | mapped_encoder_ckpt.pop(p) 134 | return encoder_ckpt 135 | 136 | @staticmethod 137 | def __get_keys(d, name): 138 | if 'state_dict' in d: 139 | d = d['state_dict'] 140 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 141 | return d_filt 142 | -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return 
fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /restyle_encoder/models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = 
gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 
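    # Downsampling step of upfirdn: after the upsampling, padding, and FIR filtering above, keep only
    # every down_y-th row and every down_x-th column of the filtered result.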
184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /restyle_encoder/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/options/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/options/e4e_train_options.py: -------------------------------------------------------------------------------- 1 | from options.train_options import TrainOptions 2 | 3 | 4 | class e4eTrainOptions(TrainOptions): 5 | 6 | def __init__(self): 7 | super(e4eTrainOptions, self).__init__() 8 | 9 | def initialize(self): 10 | super(e4eTrainOptions, self).initialize() 11 | self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, 12 | help='Dw loss multiplier') 13 | self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, 14 | help='Dw learning rate') 15 | self.parser.add_argument("--r1", type=float, default=10, 16 | help="weight of the r1 regularization") 17 | self.parser.add_argument("--d_reg_every", type=int, default=16, 18 | help="interval for applying r1 regularization") 19 | self.parser.add_argument('--use_w_pool', action='store_true', 20 | help='Whether to store a latnet codes pool for the discriminator\'s training') 21 | self.parser.add_argument("--w_pool_size", type=int, default=50, 22 | help="W\'s pool size, depends on --use_w_pool") 23 | 24 | # e4e_modules specific 25 | self.parser.add_argument('--delta_norm', type=int, default=2, 26 | help="norm type of the deltas") 27 | self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, 28 | help="lambda for delta norm loss") 29 | 30 | # Progressive training 31 | self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None, 32 | help="The training steps of training new deltas. 
steps[i] starts the delta_i training") 33 | self.parser.add_argument('--progressive_start', type=int, default=None, 34 | help="The training step to start training the deltas, overrides progressive_steps") 35 | self.parser.add_argument('--progressive_step_every', type=int, default=2_000, 36 | help="Amount of training steps for each progressive step") 37 | 38 | # Save additional training info to enable future training continuation from produced checkpoints 39 | self.parser.add_argument('--save_training_data', action='store_true', 40 | help='Save intermediate training data to resume training from the checkpoint') 41 | self.parser.add_argument('--sub_exp_dir', default=None, type=str, 42 | help='Name of sub experiment directory') 43 | self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str, 44 | help='Path to training checkpoint, works when --save_training_data was set to True') 45 | self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None, 46 | help="Name of training parameters to update the loaded training checkpoint") 47 | 48 | def parse(self): 49 | opts = self.parser.parse_args() 50 | return opts 51 | -------------------------------------------------------------------------------- /restyle_encoder/options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TestOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | # arguments for inference script 12 | self.parser.add_argument('--exp_dir', type=str, 13 | help='Path to experiment output directory') 14 | self.parser.add_argument('--checkpoint_path', default=None, type=str, 15 | help='Path to ReStyle model checkpoint') 16 | self.parser.add_argument('--data_path', type=str, default='gt_images', 17 | help='Path to directory of images to evaluate') 18 | self.parser.add_argument('--resize_outputs', action='store_true', 19 | help='Whether to resize outputs to 256x256 or keep at original output resolution') 20 | self.parser.add_argument('--test_batch_size', default=2, type=int, 21 | help='Batch size for testing and inference') 22 | self.parser.add_argument('--test_workers', default=2, type=int, 23 | help='Number of test/inference dataloader workers') 24 | self.parser.add_argument('--n_images', type=int, default=None, 25 | help='Number of images to output. 
If None, run on all data') 26 | 27 | # arguments for iterative inference 28 | self.parser.add_argument('--n_iters_per_batch', default=5, type=int, 29 | help='Number of forward passes per batch during training.') 30 | 31 | # arguments for encoder bootstrapping 32 | self.parser.add_argument('--model_1_checkpoint_path', default=None, type=str, 33 | help='Path to encoder used to initialize encoder bootstrapping inference.') 34 | self.parser.add_argument('--model_2_checkpoint_path', default=None, type=str, 35 | help='Path to encoder used to iteratively translate images following ' 36 | 'model 1\'s initialization.') 37 | 38 | # arguments for editing 39 | self.parser.add_argument('--edit_directions', type=str, default='age,smile,pose', 40 | help='comma-separated list of which edit directions top perform.') 41 | self.parser.add_argument('--factor_ranges', type=str, default='5,5,5', 42 | help='comma-separated list of max ranges for each corresponding edit.') 43 | 44 | 45 | def parse(self): 46 | opts = self.parser.parse_args() 47 | return opts 48 | -------------------------------------------------------------------------------- /restyle_encoder/options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | class TrainOptions: 4 | 5 | def __init__(self): 6 | self.parser = ArgumentParser() 7 | self.initialize() 8 | 9 | def initialize(self): 10 | # general setup 11 | self.parser.add_argument('--exp_dir', type=str, 12 | help='Path to experiment output directory') 13 | self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, 14 | help='Type of dataset/experiment to run') 15 | self.parser.add_argument('--encoder_type', default='BackboneEncoder', type=str, 16 | help='Which encoder to use') 17 | self.parser.add_argument('--input_nc', default=6, type=int, 18 | help='Number of input image channels to the ReStyle encoder. 
Should be set to 6.') 19 | self.parser.add_argument('--output_size', default=1024, type=int, 20 | help='Output size of generator') 21 | 22 | # batch size and dataloader works 23 | self.parser.add_argument('--batch_size', default=4, type=int, 24 | help='Batch size for training') 25 | self.parser.add_argument('--test_batch_size', default=2, type=int, 26 | help='Batch size for testing and inference') 27 | self.parser.add_argument('--workers', default=4, type=int, 28 | help='Number of train dataloader workers') 29 | self.parser.add_argument('--test_workers', default=2, type=int, 30 | help='Number of test/inference dataloader workers') 31 | 32 | # optimizers 33 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, 34 | help='Optimizer learning rate') 35 | self.parser.add_argument('--optim_name', default='ranger', type=str, 36 | help='Which optimizer to use') 37 | self.parser.add_argument('--train_decoder', default=False, type=bool, 38 | help='Whether to train the decoder model') 39 | self.parser.add_argument('--start_from_latent_avg', action='store_true', 40 | help='Whether to add average latent vector to generate codes from encoder.') 41 | 42 | # loss lambdas 43 | self.parser.add_argument('--lpips_lambda', default=0, type=float, 44 | help='LPIPS loss multiplier factor') 45 | self.parser.add_argument('--id_lambda', default=0, type=float, 46 | help='ID loss multiplier factor') 47 | self.parser.add_argument('--l2_lambda', default=0, type=float, 48 | help='L2 loss multiplier factor') 49 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, 50 | help='W-norm loss multiplier factor') 51 | self.parser.add_argument('--moco_lambda', default=0, type=float, 52 | help='Moco feature loss multiplier factor') 53 | 54 | # weights and checkpoint paths 55 | self.parser.add_argument('--stylegan_weights', default=None, type=str, 56 | help='Path to StyleGAN model weights') 57 | self.parser.add_argument('--checkpoint_path', default=None, type=str, 58 | help='Path to ReStyle model checkpoint') 59 | 60 | # intervals for logging, validation, and saving 61 | self.parser.add_argument('--max_steps', default=500000, type=int, 62 | help='Maximum number of training steps') 63 | self.parser.add_argument('--image_interval', default=100, type=int, 64 | help='Interval for logging train images during training') 65 | self.parser.add_argument('--board_interval', default=50, type=int, 66 | help='Interval for logging metrics to tensorboard') 67 | self.parser.add_argument('--val_interval', default=1000, type=int, 68 | help='Validation interval') 69 | self.parser.add_argument('--save_interval', default=None, type=int, 70 | help='Model checkpoint interval') 71 | 72 | # arguments for iterative encoding 73 | self.parser.add_argument('--n_iters_per_batch', default=5, type=int, 74 | help='Number of forward passes per batch during training') 75 | 76 | def parse(self): 77 | opts = self.parser.parse_args() 78 | return opts 79 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/scriptsLocal/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/calc_id_loss_parallel.py: -------------------------------------------------------------------------------- 1 | from argparse 
import ArgumentParser 2 | import time 3 | import numpy as np 4 | import os 5 | import json 6 | import sys 7 | from PIL import Image 8 | import multiprocessing as mp 9 | import math 10 | import torch 11 | import torchvision.transforms as trans 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from models.mtcnn.mtcnn import MTCNN 17 | from models.encoders.model_irse import IR_101 18 | from configs.paths_config import model_paths 19 | CIRCULAR_FACE_PATH = model_paths['circular_face'] 20 | 21 | 22 | def chunks(lst, n): 23 | """Yield successive n-sized chunks from lst.""" 24 | for i in range(0, len(lst), n): 25 | yield lst[i:i + n] 26 | 27 | 28 | def extract_on_paths(file_paths): 29 | facenet = IR_101(input_size=112) 30 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 31 | facenet.cuda() 32 | facenet.eval() 33 | mtcnn = MTCNN() 34 | id_transform = trans.Compose([ 35 | trans.ToTensor(), 36 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 37 | ]) 38 | 39 | pid = mp.current_process().name 40 | print('\t{} is starting to extract on {} images'.format(pid, len(file_paths))) 41 | tot_count = len(file_paths) 42 | count = 0 43 | 44 | scores_dict = {} 45 | for res_path, gt_path in file_paths: 46 | count += 1 47 | if count % 100 == 0: 48 | print('{} done with {}/{}'.format(pid, count, tot_count)) 49 | if True: 50 | input_im = Image.open(res_path) 51 | input_im, _ = mtcnn.align(input_im) 52 | if input_im is None: 53 | print('{} skipping {}'.format(pid, res_path)) 54 | continue 55 | 56 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 57 | 58 | result_im = Image.open(gt_path) 59 | result_im, _ = mtcnn.align(result_im) 60 | if result_im is None: 61 | print('{} skipping {}'.format(pid, gt_path)) 62 | continue 63 | 64 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 65 | score = float(input_id.dot(result_id)) 66 | scores_dict[os.path.basename(gt_path)] = score 67 | 68 | return scores_dict 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser(add_help=False) 73 | parser.add_argument('--num_threads', type=int, default=4) 74 | parser.add_argument('--output_path', type=str, default='inference_results', help='path to inference outputs') 75 | parser.add_argument('--gt_path', type=str, default='gt_images', help='path to gt images') 76 | args = parser.parse_args() 77 | return args 78 | 79 | 80 | def run(args): 81 | for step in sorted(os.listdir(args.output_path)): 82 | if not step.isdigit(): 83 | continue 84 | step_outputs_path = os.path.join(args.output_path, step) 85 | if os.path.isdir(step_outputs_path): 86 | print('#' * 80) 87 | print(f'Running on step: {step}') 88 | print('#' * 80) 89 | run_on_step_output(step=step, args=args) 90 | 91 | 92 | def run_on_step_output(step, args): 93 | file_paths = [] 94 | step_outputs_path = os.path.join(args.output_path, step) 95 | for f in os.listdir(step_outputs_path): 96 | image_path = os.path.join(step_outputs_path, f) 97 | gt_path = os.path.join(args.gt_path, f) 98 | if f.endswith(".jpg") or f.endswith('.png') or f.endswith('.jpeg'): 99 | file_paths.append([image_path, gt_path.replace('.png', '.jpg')]) 100 | 101 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 102 | pool = mp.Pool(args.num_threads) 103 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 104 | 105 | tic = time.time() 106 | results = pool.map(extract_on_paths, file_chunks) 107 | scores_dict = {} 108 | for d in results: 109 | scores_dict.update(d) 110 | 111 | all_scores = 
list(scores_dict.values()) 112 | mean = np.mean(all_scores) 113 | std = np.std(all_scores) 114 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 115 | print(result_str) 116 | 117 | out_path = os.path.join(os.path.dirname(args.output_path), 'inference_metrics') 118 | if not os.path.exists(out_path): 119 | os.makedirs(out_path) 120 | 121 | with open(os.path.join(out_path, f'stat_id_step_{step}.txt'), 'w') as f: 122 | f.write(result_str) 123 | with open(os.path.join(out_path, f'scores_id_step_{step}.json'), 'w') as f: 124 | json.dump(scores_dict, f) 125 | 126 | toc = time.time() 127 | print('Mischief managed in {}s'.format(toc - tic)) 128 | 129 | 130 | if __name__ == '__main__': 131 | args = parse_args() 132 | run(args) 133 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/calc_losses_on_images.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import json 4 | import sys 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torchvision.transforms as transforms 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from criteria.lpips.lpips import LPIPS 15 | from datasets.gt_res_dataset import GTResDataset 16 | 17 | 18 | def parse_args(): 19 | parser = ArgumentParser(add_help=False) 20 | parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) 21 | parser.add_argument('--output_path', type=str, default='results') 22 | parser.add_argument('--gt_path', type=str, default='gt_images') 23 | parser.add_argument('--workers', type=int, default=4) 24 | parser.add_argument('--batch_size', type=int, default=4) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def run(args): 30 | for step in sorted(os.listdir(args.output_path)): 31 | if not step.isdigit(): 32 | continue 33 | step_outputs_path = os.path.join(args.output_path, step) 34 | if os.path.isdir(step_outputs_path): 35 | print('#' * 80) 36 | print(f'Running on step: {step}') 37 | print('#' * 80) 38 | run_on_step_output(step=step, args=args) 39 | 40 | 41 | def run_on_step_output(step, args): 42 | 43 | transform = transforms.Compose([transforms.Resize((256, 256)), 44 | transforms.ToTensor(), 45 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 46 | 47 | step_outputs_path = os.path.join(args.output_path, step) 48 | 49 | print('Loading dataset') 50 | dataset = GTResDataset(root_path=step_outputs_path, 51 | gt_dir=args.gt_path, 52 | transform=transform) 53 | 54 | dataloader = DataLoader(dataset, 55 | batch_size=args.batch_size, 56 | shuffle=False, 57 | num_workers=int(args.workers), 58 | drop_last=True) 59 | 60 | if args.mode == 'lpips': 61 | loss_func = LPIPS(net_type='alex') 62 | elif args.mode == 'l2': 63 | loss_func = torch.nn.MSELoss() 64 | else: 65 | raise Exception('Not a valid mode!') 66 | loss_func.cuda() 67 | 68 | global_i = 0 69 | scores_dict = {} 70 | all_scores = [] 71 | for result_batch, gt_batch in tqdm(dataloader): 72 | for i in range(args.batch_size): 73 | loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda())) 74 | all_scores.append(loss) 75 | im_path = dataset.pairs[global_i][0] 76 | scores_dict[os.path.basename(im_path)] = loss 77 | global_i += 1 78 | 79 | all_scores = list(scores_dict.values()) 80 | mean = np.mean(all_scores) 81 | std = np.std(all_scores) 82 | result_str = 'Average loss is 
{:.2f}+-{:.2f}'.format(mean, std) 83 | print('Finished with ', step_outputs_path) 84 | print(result_str) 85 | 86 | out_path = os.path.join(os.path.dirname(args.output_path), 'inference_metrics') 87 | if not os.path.exists(out_path): 88 | os.makedirs(out_path) 89 | 90 | with open(os.path.join(out_path, f'stat_{args.mode}_step_{step}.txt'), 'w') as f: 91 | f.write(result_str) 92 | with open(os.path.join(out_path, f'scores_{args.mode}_step_{step}.json'), 'w') as f: 93 | json.dump(scores_dict, f) 94 | 95 | 96 | if __name__ == '__main__': 97 | args = parse_args() 98 | run(args) 99 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/encoder_bootstrapping_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | 4 | from tqdm import tqdm 5 | import time 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | import sys 11 | 12 | from utils.inference_utils import get_average_image 13 | 14 | sys.path.append(".") 15 | sys.path.append("..") 16 | 17 | from configs import data_configs 18 | from datasets.inference_dataset import InferenceDataset 19 | from options.test_options import TestOptions 20 | from models.psp import pSp 21 | from models.e4e import e4e 22 | from utils.model_utils import ENCODER_TYPES 23 | from utils.common import tensor2im 24 | 25 | 26 | def run(): 27 | test_opts = TestOptions().parse() 28 | 29 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 30 | os.makedirs(out_path_results, exist_ok=True) 31 | 32 | # load model used for initializing encoder bootstrapping 33 | ckpt = torch.load(test_opts.model_1_checkpoint_path, map_location='cpu') 34 | opts = ckpt['opts'] 35 | opts.update(vars(test_opts)) 36 | opts['checkpoint_path'] = test_opts.model_1_checkpoint_path 37 | opts = Namespace(**opts) 38 | if opts.encoder_type in ENCODER_TYPES['pSp']: 39 | net1 = pSp(opts) 40 | else: 41 | net1 = e4e(opts) 42 | net1.eval() 43 | net1.cuda() 44 | 45 | # load model used for translating input image after initialization 46 | ckpt = torch.load(test_opts.model_2_checkpoint_path, map_location='cpu') 47 | opts = ckpt['opts'] 48 | opts.update(vars(test_opts)) 49 | opts['checkpoint_path'] = test_opts.model_2_checkpoint_path 50 | opts = Namespace(**opts) 51 | if opts.encoder_type in ENCODER_TYPES['pSp']: 52 | net2 = pSp(opts) 53 | else: 54 | net2 = e4e(opts) 55 | net2.eval() 56 | net2.cuda() 57 | 58 | print('Loading dataset for {}'.format(opts.dataset_type)) 59 | dataset_args = data_configs.DATASETS[opts.dataset_type] 60 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 61 | dataset = InferenceDataset(root=opts.data_path, 62 | transform=transforms_dict['transform_inference'], 63 | opts=opts) 64 | dataloader = DataLoader(dataset, 65 | batch_size=opts.test_batch_size, 66 | shuffle=False, 67 | num_workers=int(opts.test_workers), 68 | drop_last=False) 69 | 70 | if opts.n_images is None: 71 | opts.n_images = len(dataset) 72 | 73 | # get the image corresponding to the latent average 74 | avg_image = get_average_image(net1, opts) 75 | 76 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 77 | 78 | global_i = 0 79 | global_time = [] 80 | for input_batch in tqdm(dataloader): 81 | if global_i >= opts.n_images: 82 | break 83 | with torch.no_grad(): 84 | input_cuda = input_batch.cuda().float() 85 | tic = time.time() 86 | 
result_batch = run_on_batch(input_cuda, net1, net2, opts, avg_image) 87 | toc = time.time() 88 | global_time.append(toc - tic) 89 | 90 | for i in range(input_batch.shape[0]): 91 | results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch + 1)] 92 | im_path = dataset.paths[global_i] 93 | 94 | input_im = tensor2im(input_batch[i]) 95 | 96 | # save step-by-step results side-by-side 97 | res = np.array(results[0].resize(resize_amount)) 98 | for idx, result in enumerate(results[1:]): 99 | res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1) 100 | res = np.concatenate([res, input_im.resize(resize_amount)], axis=1) 101 | Image.fromarray(res).save(os.path.join(out_path_results, os.path.basename(im_path))) 102 | 103 | global_i += 1 104 | 105 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 106 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 107 | print(result_str) 108 | 109 | with open(stats_path, 'w') as f: 110 | f.write(result_str) 111 | 112 | 113 | def run_on_batch(inputs, net1, net2, opts, avg_image): 114 | y_hat, latent = None, None 115 | results_batch = {idx: [] for idx in range(inputs.shape[0])} 116 | 117 | # initialize using the first net 118 | avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1) 119 | x_input = torch.cat([inputs, avg_image_for_batch], dim=1) 120 | y_hat, latent = net1.forward(x_input, 121 | latent=latent, 122 | randomize_noise=False, 123 | return_latents=True, 124 | resize=opts.resize_outputs) 125 | for idx in range(inputs.shape[0]): 126 | results_batch[idx].append(y_hat[idx]) 127 | y_hat = net1.face_pool(y_hat) 128 | 129 | # iteratively translate using the resulting latent and generated image 130 | for iter in range(opts.n_iters_per_batch): 131 | x_input = torch.cat([inputs, y_hat], dim=1) 132 | y_hat, latent = net2.forward(x_input, 133 | latent=latent, 134 | randomize_noise=False, 135 | return_latents=True, 136 | resize=opts.resize_outputs) 137 | for idx in range(inputs.shape[0]): 138 | results_batch[idx].append(y_hat[idx]) 139 | y_hat = net1.face_pool(y_hat) 140 | 141 | return results_batch 142 | 143 | 144 | if __name__ == '__main__': 145 | run() 146 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/inference_iterative.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | from tqdm import tqdm 4 | import time 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import sys 9 | 10 | sys.path.append(".") 11 | sys.path.append("..") 12 | 13 | from configs import data_configs 14 | from datasets.inference_dataset import InferenceDataset 15 | from options.test_options import TestOptions 16 | from models.psp import pSp 17 | from models.e4e import e4e 18 | from utils.model_utils import ENCODER_TYPES 19 | from utils.common import tensor2im 20 | from utils.inference_utils import run_on_batch, get_average_image 21 | 22 | 23 | def run(): 24 | test_opts = TestOptions().parse() 25 | 26 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 27 | os.makedirs(out_path_results, exist_ok=True) 28 | 29 | # update test options with options used during training 30 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 31 | opts = ckpt['opts'] 32 | opts.update(vars(test_opts)) 33 | opts = Namespace(**opts) 34 | 35 | if opts.encoder_type in 
ENCODER_TYPES['pSp']: 36 | net = pSp(opts) 37 | else: 38 | net = e4e(opts) 39 | 40 | net.eval() 41 | net.cuda() 42 | 43 | print('Loading dataset for {}'.format(opts.dataset_type)) 44 | dataset_args = data_configs.DATASETS[opts.dataset_type] 45 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 46 | dataset = InferenceDataset(root=opts.data_path, 47 | transform=transforms_dict['transform_inference'], 48 | opts=opts) 49 | dataloader = DataLoader(dataset, 50 | batch_size=opts.test_batch_size, 51 | shuffle=False, 52 | num_workers=int(opts.test_workers), 53 | drop_last=False) 54 | 55 | if opts.n_images is None: 56 | opts.n_images = len(dataset) 57 | 58 | # get the image corresponding to the latent average 59 | avg_image = get_average_image(net, opts) 60 | 61 | if opts.dataset_type == "cars_encode": 62 | resize_amount = (256, 192) if opts.resize_outputs else (512, 384) 63 | else: 64 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 65 | 66 | global_i = 0 67 | global_time = [] 68 | all_latents = {} 69 | for input_batch in tqdm(dataloader): 70 | if global_i >= opts.n_images: 71 | break 72 | 73 | with torch.no_grad(): 74 | input_cuda = input_batch.cuda().float() 75 | tic = time.time() 76 | result_batch, result_latents = run_on_batch(input_cuda, net, opts, avg_image) 77 | toc = time.time() 78 | global_time.append(toc - tic) 79 | 80 | for i in range(input_batch.shape[0]): 81 | results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch)] 82 | im_path = dataset.paths[global_i] 83 | 84 | # save step-by-step results side-by-side 85 | for idx, result in enumerate(results): 86 | save_dir = os.path.join(out_path_results, str(idx)) 87 | os.makedirs(save_dir, exist_ok=True) 88 | result.resize(resize_amount).save(os.path.join(save_dir, os.path.basename(im_path))) 89 | 90 | # store all latents with dict pairs (image_name, latents) 91 | all_latents[os.path.basename(im_path)] = result_latents[i] 92 | 93 | global_i += 1 94 | 95 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 96 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 97 | print(result_str) 98 | 99 | with open(stats_path, 'w') as f: 100 | f.write(result_str) 101 | 102 | # save all latents as npy file 103 | np.save(os.path.join(test_opts.exp_dir, 'latents.npy'), all_latents) 104 | 105 | 106 | if __name__ == '__main__': 107 | run() 108 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/inference_iterative_save_coupled.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | from tqdm import tqdm 4 | import time 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | import sys 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from configs import data_configs 15 | from datasets.inference_dataset import InferenceDataset 16 | from options.test_options import TestOptions 17 | from models.psp import pSp 18 | from models.e4e import e4e 19 | from utils.model_utils import ENCODER_TYPES 20 | from utils.common import tensor2im 21 | from utils.inference_utils import run_on_batch, get_average_image 22 | 23 | 24 | def run(): 25 | test_opts = TestOptions().parse() 26 | 27 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled') 28 | os.makedirs(out_path_coupled, exist_ok=True) 29 | 30 | # update 
test options with options used during training 31 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 32 | opts = ckpt['opts'] 33 | opts.update(vars(test_opts)) 34 | opts = Namespace(**opts) 35 | 36 | if opts.encoder_type in ENCODER_TYPES['pSp']: 37 | net = pSp(opts) 38 | else: 39 | net = e4e(opts) 40 | 41 | net.eval() 42 | net.cuda() 43 | 44 | print('Loading dataset for {}'.format(opts.dataset_type)) 45 | dataset_args = data_configs.DATASETS[opts.dataset_type] 46 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 47 | dataset = InferenceDataset(root=opts.data_path, 48 | transform=transforms_dict['transform_inference'], 49 | opts=opts) 50 | dataloader = DataLoader(dataset, 51 | batch_size=opts.test_batch_size, 52 | shuffle=False, 53 | num_workers=int(opts.test_workers), 54 | drop_last=False) 55 | 56 | if opts.n_images is None: 57 | opts.n_images = len(dataset) 58 | 59 | # get the image corresponding to the latent average 60 | avg_image = get_average_image(net, opts) 61 | 62 | if opts.dataset_type == "cars_encode": 63 | resize_amount = (256, 192) if opts.resize_outputs else (512, 384) 64 | else: 65 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 66 | 67 | global_i = 0 68 | global_time = [] 69 | for input_batch in tqdm(dataloader): 70 | if global_i >= opts.n_images: 71 | break 72 | 73 | with torch.no_grad(): 74 | input_cuda = input_batch.cuda().float() 75 | tic = time.time() 76 | result_batch, result_latents = run_on_batch(input_cuda, net, opts, avg_image) 77 | toc = time.time() 78 | global_time.append(toc - tic) 79 | 80 | for i in range(input_batch.shape[0]): 81 | results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch)] 82 | im_path = dataset.paths[global_i] 83 | 84 | # save step-by-step results side-by-side 85 | input_im = tensor2im(input_batch[i]) 86 | res = np.array(results[0].resize(resize_amount)) 87 | for idx, result in enumerate(results[1:]): 88 | res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1) 89 | res = np.concatenate([res, input_im.resize(resize_amount)], axis=1) 90 | 91 | Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path))) 92 | 93 | global_i += 1 94 | 95 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 96 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 97 | print(result_str) 98 | 99 | with open(stats_path, 'w') as f: 100 | f.write(result_str) 101 | 102 | 103 | if __name__ == '__main__': 104 | run() 105 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/train_restyle_e4e.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import os 5 | import json 6 | import math 7 | import sys 8 | import pprint 9 | import torch 10 | from argparse import Namespace 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | from options.e4e_train_options import e4eTrainOptions 16 | from training.coach_restyle_e4e import Coach 17 | 18 | 19 | def main(): 20 | opts = e4eTrainOptions().parse() 21 | previous_train_ckpt = None 22 | if opts.resume_training_from_ckpt: 23 | opts, previous_train_ckpt = load_train_checkpoint(opts) 24 | else: 25 | setup_progressive_steps(opts) 26 | create_initial_experiment_dir(opts) 27 | 28 | coach = Coach(opts, previous_train_ckpt) 29 | coach.train() 30 | 31 | 32 | def 
load_train_checkpoint(opts): 33 | train_ckpt_path = opts.resume_training_from_ckpt 34 | previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu') 35 | new_opts_dict = vars(opts) 36 | opts = previous_train_ckpt['opts'] 37 | opts['resume_training_from_ckpt'] = train_ckpt_path 38 | update_new_configs(opts, new_opts_dict) 39 | pprint.pprint(opts) 40 | opts = Namespace(**opts) 41 | if opts.sub_exp_dir is not None: 42 | sub_exp_dir = opts.sub_exp_dir 43 | opts.exp_dir = os.path.join(opts.exp_dir, sub_exp_dir) 44 | create_initial_experiment_dir(opts) 45 | return opts, previous_train_ckpt 46 | 47 | 48 | def setup_progressive_steps(opts): 49 | log_size = int(math.log(opts.output_size, 2)) 50 | num_style_layers = 2 * log_size - 2 51 | num_deltas = num_style_layers - 1 52 | if opts.progressive_start is not None: # If progressive delta training 53 | opts.progressive_steps = [0] 54 | next_progressive_step = opts.progressive_start 55 | for i in range(num_deltas): 56 | opts.progressive_steps.append(next_progressive_step) 57 | next_progressive_step += opts.progressive_step_every 58 | 59 | assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \ 60 | "Invalid progressive training input" 61 | 62 | 63 | def is_valid_progressive_steps(opts, num_style_layers): 64 | return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0 65 | 66 | 67 | def create_initial_experiment_dir(opts): 68 | os.makedirs(opts.exp_dir, exist_ok=True) 69 | opts_dict = vars(opts) 70 | pprint.pprint(opts_dict) 71 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 72 | json.dump(opts_dict, f, indent=4, sort_keys=True) 73 | 74 | 75 | def update_new_configs(ckpt_opts, new_opts): 76 | for k, v in new_opts.items(): 77 | if k not in ckpt_opts: 78 | ckpt_opts[k] = v 79 | if new_opts['update_param_list']: 80 | for param in new_opts['update_param_list']: 81 | ckpt_opts[param] = new_opts[param] 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /restyle_encoder/scriptsLocal/train_restyle_psp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import os 5 | import json 6 | import sys 7 | import pprint 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from options.train_options import TrainOptions 13 | from training.coach_restyle_psp import Coach 14 | 15 | 16 | def main(): 17 | opts = TrainOptions().parse() 18 | os.makedirs(opts.exp_dir, exist_ok=True) 19 | 20 | opts_dict = vars(opts) 21 | pprint.pprint(opts_dict) 22 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 23 | json.dump(opts_dict, f, indent=4, sort_keys=True) 24 | 25 | coach = Coach(opts) 26 | coach.train() 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /restyle_encoder/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/training/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/training/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient 
Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 124 | # compute mean moving avg 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | 127 | buffered = self.radam_buffer[int(state['step'] % 10)] 128 | 129 | if state['step'] == buffered[0]: 130 | N_sma, step_size = buffered[1], buffered[2] 131 | else: 132 | buffered[0] = state['step'] 133 | beta2_t = beta2 ** state['step'] 134 | N_sma_max = 2 / (1 - beta2) - 1 135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 136 | buffered[1] = N_sma 137 | if N_sma > self.N_sma_threshhold: 138 | step_size = math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 140 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 141 | else: 142 | step_size = 1.0 / (1 - beta1 ** state['step']) 143 | buffered[2] = step_size 144 | 145 | if group['weight_decay'] != 0: 146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 147 | 148 | # apply lr 149 | if N_sma > self.N_sma_threshhold: 150 | denom = exp_avg_sq.sqrt().add_(group['eps']) 151 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 152 | else: 153 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 154 | 155 | p.data.copy_(p_data_fp32) 156 | 157 | # integrated look 
ahead... 158 | # we do it at the param level instead of group level 159 | if state['step'] % group['k'] == 0: 160 | slow_p = state['slow_buffer'] # get access to slow param tensor 161 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 163 | 164 | return loss -------------------------------------------------------------------------------- /restyle_encoder/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuqiuche/micromotion-styleGAN/d4ff949b0d08814f49603850bb50a98346905a7b/restyle_encoder/utils/__init__.py -------------------------------------------------------------------------------- /restyle_encoder/utils/common.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def tensor2im(var): 6 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 7 | var = ((var + 1) / 2) 8 | var[var < 0] = 0 9 | var[var > 1] = 1 10 | var = var * 255 11 | return Image.fromarray(var.astype('uint8')) 12 | 13 | 14 | def vis_faces(log_hooks): 15 | display_count = len(log_hooks) 16 | n_outputs = len(log_hooks[0]['output_face']) if type(log_hooks[0]['output_face']) == list else 1 17 | fig = plt.figure(figsize=(6 + (n_outputs * 2), 4 * display_count)) 18 | gs = fig.add_gridspec(display_count, (2 + n_outputs)) 19 | for i in range(display_count): 20 | hooks_dict = log_hooks[i] 21 | fig.add_subplot(gs[i, 0]) 22 | vis_faces_iterative(hooks_dict, fig, gs, i) 23 | plt.tight_layout() 24 | return fig 25 | 26 | 27 | def vis_faces_iterative(hooks_dict, fig, gs, i): 28 | plt.imshow(hooks_dict['input_face']) 29 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 30 | fig.add_subplot(gs[i, 1]) 31 | plt.imshow(hooks_dict['target_face']) 32 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), float(hooks_dict['diff_target']))) 33 | for idx, output_idx in enumerate(range(len(hooks_dict['output_face']) - 1, -1, -1)): 34 | output_image, similarity = hooks_dict['output_face'][output_idx] 35 | fig.add_subplot(gs[i, 2 + idx]) 36 | plt.imshow(output_image) 37 | plt.title('Output {}\n Target Sim={:.2f}'.format(output_idx, float(similarity))) 38 | -------------------------------------------------------------------------------- /restyle_encoder/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 20 | for root, _, fnames in sorted(os.walk(dir)): 21 | for fname in fnames: 22 | if is_image_file(fname): 23 | path = os.path.join(root, fname) 24 | images.append(path) 25 | return images 26 | -------------------------------------------------------------------------------- /restyle_encoder/utils/inference_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def 
get_average_image(net, opts): 5 | avg_image = net(net.latent_avg.unsqueeze(0), 6 | input_code=True, 7 | randomize_noise=False, 8 | return_latents=False, 9 | average_code=True)[0] 10 | avg_image = avg_image.to('cuda').float().detach() 11 | if opts.dataset_type == "cars_encode": 12 | avg_image = avg_image[:, 32:224, :] 13 | return avg_image 14 | 15 | 16 | def run_on_batch(inputs, net, opts, avg_image): 17 | y_hat, latent = None, None 18 | results_batch = {idx: [] for idx in range(inputs.shape[0])} 19 | results_latent = {idx: [] for idx in range(inputs.shape[0])} 20 | for iter in range(opts.n_iters_per_batch): 21 | if iter == 0: 22 | avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1) 23 | x_input = torch.cat([inputs, avg_image_for_batch], dim=1) 24 | else: 25 | x_input = torch.cat([inputs, y_hat], dim=1) 26 | 27 | y_hat, latent = net.forward(x_input, 28 | latent=latent, 29 | randomize_noise=False, 30 | return_latents=True, 31 | resize=opts.resize_outputs) 32 | 33 | if opts.dataset_type == "cars_encode": 34 | if opts.resize_outputs: 35 | y_hat = y_hat[:, :, 32:224, :] 36 | else: 37 | y_hat = y_hat[:, :, 64:448, :] 38 | 39 | # store intermediate outputs 40 | for idx in range(inputs.shape[0]): 41 | results_batch[idx].append(y_hat[idx]) 42 | results_latent[idx].append(latent[idx].cpu().numpy()) 43 | 44 | # resize input to 256 before feeding into next iteration 45 | if opts.dataset_type == "cars_encode": 46 | y_hat = torch.nn.AdaptiveAvgPool2d((192, 256))(y_hat) 47 | else: 48 | y_hat = net.face_pool(y_hat) 49 | 50 | return results_batch, results_latent 51 | 52 | 53 | def run_batch_latent(inputs, net, opts, avg_image): 54 | y_hat, latent = None, None 55 | results_latent = {idx: [] for idx in range(inputs.shape[0])} 56 | for iter in range(opts.n_iters_per_batch): 57 | if iter == 0: 58 | avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1) 59 | x_input = torch.cat([inputs, avg_image_for_batch], dim=1) 60 | else: 61 | x_input = torch.cat([inputs, y_hat], dim=1) 62 | 63 | y_hat, latent = net.forward(x_input, 64 | latent=latent, 65 | randomize_noise=False, 66 | return_latents=True, 67 | resize=opts.resize_outputs) 68 | 69 | if opts.dataset_type == "cars_encode": 70 | if opts.resize_outputs: 71 | y_hat = y_hat[:, :, 32:224, :] 72 | else: 73 | y_hat = y_hat[:, :, 64:448, :] 74 | 75 | for idx in range(inputs.shape[0]): 76 | results_latent[idx].append(latent[idx]) 77 | 78 | # resize input to 256 before feeding into next iteration 79 | if opts.dataset_type == "cars_encode": 80 | y_hat = torch.nn.AdaptiveAvgPool2d((192, 256))(y_hat) 81 | else: 82 | y_hat = net.face_pool(y_hat) 83 | 84 | return results_latent 85 | 86 | -------------------------------------------------------------------------------- /restyle_encoder/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # specify the encoder types for pSp and e4e - this is mainly used for the inference scripts 2 | ENCODER_TYPES = { 3 | 'pSp': ['GradualStyleEncoder', 'ResNetGradualStyleEncoder', 'BackboneEncoder', 'ResNetBackboneEncoder'], 4 | 'e4e': ['ProgressiveBackboneEncoder', 'ResNetProgressiveBackboneEncoder'] 5 | } 6 | 7 | RESNET_MAPPING = { 8 | 'layer1.0': 'body.0', 9 | 'layer1.1': 'body.1', 10 | 'layer1.2': 'body.2', 11 | 'layer2.0': 'body.3', 12 | 'layer2.1': 'body.4', 13 | 'layer2.2': 'body.5', 14 | 'layer2.3': 'body.6', 15 | 'layer3.0': 'body.7', 16 | 'layer3.1': 'body.8', 17 | 'layer3.2': 'body.9', 18 | 'layer3.3': 'body.10', 19 | 'layer3.4': 
'body.11', 20 | 'layer3.5': 'body.12', 21 | 'layer4.0': 'body.13', 22 | 'layer4.1': 'body.14', 23 | 'layer4.2': 'body.15', 24 | } 25 | -------------------------------------------------------------------------------- /restyle_encoder/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def aggregate_loss_dict(agg_loss_dict): 3 | mean_vals = {} 4 | for output in agg_loss_dict: 5 | for key in output: 6 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 7 | for key in mean_vals: 8 | if len(mean_vals[key]) > 0: 9 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 10 | else: 11 | print('{} has no value'.format(key)) 12 | mean_vals[key] = 0 13 | return mean_vals 14 | --------------------------------------------------------------------------------
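
The --input_nc default of 6 in the training options above and the torch.cat([inputs, y_hat], dim=1) calls in restyle_encoder/utils/inference_utils.py reflect the same design: at every ReStyle iteration the encoder is fed the real input image stacked channel-wise with the current reconstruction (or with the average-latent image on the first pass). The following shape-only sketch uses random tensors in place of real images, just to make the 6-channel input concrete:

import torch

batch_size, height, width = 2, 256, 256

# stand-ins for a batch of 3-channel input images and the previous reconstruction
inputs = torch.randn(batch_size, 3, height, width)
y_hat_prev = torch.randn(batch_size, 3, height, width)

# the ReStyle encoder input: image and current estimate stacked along the
# channel axis, which is why --input_nc defaults to 6
x_input = torch.cat([inputs, y_hat_prev], dim=1)
print(x_input.shape)  # torch.Size([2, 6, 256, 256])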
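
restyle_encoder/utils/model_utils.py is what lets the inference scripts above decide whether a checkpoint should be loaded as a pSp or an e4e model: they simply test opts.encoder_type against ENCODER_TYPES['pSp']. A minimal illustration of that dispatch, assuming the restyle_encoder root is on sys.path the way the scripts arrange it:

from utils.model_utils import ENCODER_TYPES  # assumes the restyle_encoder root is on sys.path

# the same membership test used in inference_iterative.py and the other scripts
encoder_type = 'ResNetProgressiveBackboneEncoder'
model_family = 'pSp' if encoder_type in ENCODER_TYPES['pSp'] else 'e4e'
print(model_family)  # e4e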
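
setup_progressive_steps in restyle_encoder/scriptsLocal/train_restyle_e4e.py turns the generator resolution into a per-layer unlock schedule for the e4e deltas. The standalone restatement below mirrors that arithmetic for a 1024px generator; the progressive_start and progressive_step_every values are hypothetical, since their defaults live in e4e_train_options.py, which is not shown here:

import math
from argparse import Namespace

# hypothetical option values; the real defaults come from e4e_train_options.py
opts = Namespace(output_size=1024, progressive_start=20000, progressive_step_every=2000)

log_size = int(math.log2(opts.output_size))   # same as int(math.log(opts.output_size, 2)); 10 for 1024px
num_style_layers = 2 * log_size - 2           # 18 style inputs for a 1024px StyleGAN2
num_deltas = num_style_layers - 1             # 17 progressively unlocked deltas

# the schedule setup_progressive_steps builds: the first layer trains from step 0,
# each later delta is unlocked progressive_step_every steps after the previous one
steps = [0] + [opts.progressive_start + i * opts.progressive_step_every for i in range(num_deltas)]

# this is exactly what is_valid_progressive_steps verifies
assert len(steps) == num_style_layers and steps[0] == 0
print(steps[:4], steps[-1])                   # [0, 20000, 22000, 24000] 52000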
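
training/ranger.py subclasses torch.optim.Optimizer, so it drops into an ordinary training loop; --optim_name already defaults to 'ranger' in the training options. A minimal sketch under two assumptions: the restyle_encoder root is on sys.path (as the training scripts arrange), and the installed PyTorch still accepts the deprecated add_/addcmul_ call signatures used inside ranger.py, as the pinned environment presumably does. The toy model and data are placeholders:

import torch
from training.ranger import Ranger  # assumes the restyle_encoder root is on sys.path

# toy regression model and data, purely for illustration
model = torch.nn.Linear(16, 1)
x, y = torch.randn(64, 16), torch.randn(64, 1)

# lr maps to --learning_rate; gradient centralization (use_gc) is on by default
optimizer = Ranger(model.parameters(), lr=1e-4)

for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # RAdam update plus a Lookahead sync every k=6 steps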
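
restyle_encoder/utils/common.tensor2im expects a single CHW tensor in the generator's [-1, 1] output range and returns an 8-bit PIL image. A quick shape check with a random tensor, again assuming the restyle_encoder root and its dependencies are importable:

import torch
from utils.common import tensor2im  # assumes the restyle_encoder root is on sys.path

# a fake 3x256x256 generator output in [-1, 1]
fake_output = torch.rand(3, 256, 256) * 2 - 1
image = tensor2im(fake_output)
print(image.size, image.mode)  # (256, 256) RGB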
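
Finally, utils/train_utils.aggregate_loss_dict just averages each key over a list of per-batch loss dicts, which is presumably how the training coaches reduce validation metrics. A tiny self-contained demonstration with made-up loss names and values:

from utils.train_utils import aggregate_loss_dict  # assumes the restyle_encoder root is on sys.path

# two hypothetical per-batch loss dicts
batch_losses = [
    {'loss_l2': 0.50, 'loss_id': 0.25},
    {'loss_l2': 0.25, 'loss_id': 0.50},
]
print(aggregate_loss_dict(batch_losses))  # {'loss_l2': 0.375, 'loss_id': 0.375}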